Skip to content

Commit

Permalink
Support start_index in GenerateSegmentIndices (#338)
Browse files Browse the repository at this point in the history
  • Loading branch information
ckkelvinchan committed May 31, 2021
1 parent 6238c69 commit 8473c2b
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 3 deletions.
8 changes: 7 additions & 1 deletion demo/restoration_video_demo.py
Expand Up @@ -13,6 +13,11 @@ def parse_args():
parser.add_argument('checkpoint', help='checkpoint file')
parser.add_argument('input_dir', help='directory of the input video')
parser.add_argument('output_dir', help='directory of the output video')
parser.add_argument(
'--start_idx',
type=int,
default=0,
help='index corresponds to the first frame of the sequence')
parser.add_argument(
'--filename_tmpl',
default='{:08d}.png',
Expand All @@ -34,7 +39,8 @@ def main():
args.config, args.checkpoint, device=torch.device('cuda', args.device))

output = restoration_video_inference(model, args.input_dir,
args.window_size, args.filename_tmpl)
args.window_size, args.start_idx,
args.filename_tmpl)
for i in range(0, output.size(1)):
output_i = output[:, i, :, :, :]
output_i = tensor2img(output_i)
Expand Down
7 changes: 6 additions & 1 deletion mmedit/apis/restoration_video_inference.py
Expand Up @@ -18,7 +18,8 @@ def pad_sequence(data, window_size):
return data


def restoration_video_inference(model, img_dir, window_size, filename_tmpl):
def restoration_video_inference(model, img_dir, window_size, start_idx,
filename_tmpl):
"""Inference image with the model.
Args:
Expand All @@ -27,6 +28,9 @@ def restoration_video_inference(model, img_dir, window_size, filename_tmpl):
window_size (int): The window size used in sliding-window framework.
This value should be set according to the settings of the network.
A value smaller than 0 means using recurrent framework.
start_idx (int): The index corresponds to the first frame in the
sequence.
filename_tmpl (str): Template for file name.
Returns:
Tensor: The predicted restoration result.
Expand All @@ -39,6 +43,7 @@ def restoration_video_inference(model, img_dir, window_size, filename_tmpl):
dict(
type='GenerateSegmentIndices',
interval_list=[1],
start_idx=start_idx,
filename_tmpl=filename_tmpl),
dict(
type='LoadImageFromFileList',
Expand Down
6 changes: 5 additions & 1 deletion mmedit/datasets/pipelines/augmentation.py
Expand Up @@ -927,12 +927,15 @@ class GenerateSegmentIndices:
interval_list (list[int]): Interval list for temporal augmentation.
It will randomly pick an interval from interval_list and sample
frame index with the interval.
start_idx (int): The index corresponds to the first frame in the
sequence. Default: 0.
filename_tmpl (str): Template for file name. Default: '{:08d}.png'.
"""

def __init__(self, interval_list, filename_tmpl='{:08d}.png'):
def __init__(self, interval_list, start_idx=0, filename_tmpl='{:08d}.png'):
self.interval_list = interval_list
self.filename_tmpl = filename_tmpl
self.start_idx = start_idx

def __call__(self, results):
"""Call function.
Expand Down Expand Up @@ -961,6 +964,7 @@ def __call__(self, results):
0, self.sequence_length - num_input_frames * interval + 1)
end_frame_idx = start_frame_idx + num_input_frames * interval
neighbor_list = list(range(start_frame_idx, end_frame_idx, interval))
neighbor_list = [v + self.start_idx for v in neighbor_list]

# add the corresponding file paths
lq_path_root = results['lq_path']
Expand Down

0 comments on commit 8473c2b

Please sign in to comment.