diff --git a/demo/restoration_video_demo.py b/demo/restoration_video_demo.py
index b44e8efb14..62235919a7 100644
--- a/demo/restoration_video_demo.py
+++ b/demo/restoration_video_demo.py
@@ -13,6 +13,11 @@ def parse_args():
     parser.add_argument('checkpoint', help='checkpoint file')
     parser.add_argument('input_dir', help='directory of the input video')
     parser.add_argument('output_dir', help='directory of the output video')
+    parser.add_argument(
+        '--start_idx',
+        type=int,
+        default=0,
+        help='index corresponding to the first frame of the sequence')
     parser.add_argument(
         '--filename_tmpl',
         default='{:08d}.png',
@@ -34,7 +39,8 @@ def main():
         args.config, args.checkpoint, device=torch.device('cuda', args.device))
 
     output = restoration_video_inference(model, args.input_dir,
-                                         args.window_size, args.filename_tmpl)
+                                         args.window_size, args.start_idx,
+                                         args.filename_tmpl)
     for i in range(0, output.size(1)):
         output_i = output[:, i, :, :, :]
         output_i = tensor2img(output_i)
diff --git a/mmedit/apis/restoration_video_inference.py b/mmedit/apis/restoration_video_inference.py
index 63c1c647e5..698008835f 100644
--- a/mmedit/apis/restoration_video_inference.py
+++ b/mmedit/apis/restoration_video_inference.py
@@ -18,7 +18,8 @@ def pad_sequence(data, window_size):
     return data
 
 
-def restoration_video_inference(model, img_dir, window_size, filename_tmpl):
+def restoration_video_inference(model, img_dir, window_size, start_idx,
+                                filename_tmpl):
     """Inference image with the model.
 
     Args:
@@ -27,6 +28,9 @@ def restoration_video_inference(model, img_dir, window_size, filename_tmpl):
         window_size (int): The window size used in sliding-window framework.
             This value should be set according to the settings of the network.
             A value smaller than 0 means using recurrent framework.
+        start_idx (int): The index corresponding to the first frame in the
+            sequence.
+        filename_tmpl (str): Template for file name.
 
     Returns:
         Tensor: The predicted restoration result.
@@ -39,6 +43,7 @@ def restoration_video_inference(model, img_dir, window_size, filename_tmpl):
         dict(
             type='GenerateSegmentIndices',
             interval_list=[1],
+            start_idx=start_idx,
             filename_tmpl=filename_tmpl),
         dict(
             type='LoadImageFromFileList',
diff --git a/mmedit/datasets/pipelines/augmentation.py b/mmedit/datasets/pipelines/augmentation.py
index 1d2b97a14e..4cba61b2f1 100644
--- a/mmedit/datasets/pipelines/augmentation.py
+++ b/mmedit/datasets/pipelines/augmentation.py
@@ -927,12 +927,15 @@ class GenerateSegmentIndices:
         interval_list (list[int]): Interval list for temporal augmentation.
             It will randomly pick an interval from interval_list and sample
             frame index with the interval.
+        start_idx (int): The index corresponding to the first frame in the
+            sequence. Default: 0.
         filename_tmpl (str): Template for file name. Default: '{:08d}.png'.
     """
 
-    def __init__(self, interval_list, filename_tmpl='{:08d}.png'):
+    def __init__(self, interval_list, start_idx=0, filename_tmpl='{:08d}.png'):
         self.interval_list = interval_list
         self.filename_tmpl = filename_tmpl
+        self.start_idx = start_idx
 
     def __call__(self, results):
         """Call function.
@@ -961,6 +964,7 @@ def __call__(self, results):
             0, self.sequence_length - num_input_frames * interval + 1)
         end_frame_idx = start_frame_idx + num_input_frames * interval
         neighbor_list = list(range(start_frame_idx, end_frame_idx, interval))
+        neighbor_list = [v + self.start_idx for v in neighbor_list]
 
         # add the corresponding file paths
         lq_path_root = results['lq_path']
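For context, a minimal usage sketch of the API after this change, assuming a BasicVSR-style config/checkpoint and a frame folder whose numbering starts at 1; the paths below are placeholders, not part of this patch:

```python
import torch

from mmedit.apis import init_model, restoration_video_inference

# Placeholder paths: substitute your own config, checkpoint and frame folder.
config = 'configs/restorers/basicvsr/basicvsr_reds4.py'
checkpoint = 'checkpoints/basicvsr_reds4.pth'
input_dir = 'data/demo_sequence'  # frames named 00000001.png, 00000002.png, ...

model = init_model(config, checkpoint, device=torch.device('cuda', 0))

# window_size < 0 selects the recurrent framework; start_idx=1 matches
# sequences whose first frame is 00000001.png instead of 00000000.png.
output = restoration_video_inference(
    model, input_dir, window_size=-1, start_idx=1, filename_tmpl='{:08d}.png')
```

The demo script exposes the same option on the command line, e.g. `python demo/restoration_video_demo.py CONFIG CHECKPOINT INPUT_DIR OUTPUT_DIR --start_idx 1`.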