Skip to content

Commit

Permalink
Polish data preparation codes (#70)
Browse files (browse the repository at this point in the history)
  • Loading branch information
dreamerlin committed Jul 31, 2020
1 parent 22196b2 commit 9a07fc2
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 69 deletions.
14 changes: 2 additions & 12 deletions tools/data/build_file_list.py
Expand Up @@ -150,19 +150,9 @@ def build_list(split):
def main():
args = parse_args()

if args.level == 2:
# search for two-level directory
def key_func(x):
return osp.join(osp.basename(osp.dirname(x)), osp.basename(x))
else:
# Only search for one-level directory
def key_func(x):
return osp.basename(x)

if args.format == 'rawframes':
frame_info = parse_directory(
args.src_folder,
key_func=key_func,
rgb_prefix=args.rgb_prefix,
flow_x_prefix=args.flow_x_prefix,
flow_y_prefix=args.flow_y_prefix,
Expand Down Expand Up @@ -191,9 +181,9 @@ def key_func(x):
elif args.dataset == 'sthv2':
splits = parse_sthv2_splits(args.level)
elif args.dataset == 'mit':
splits = parse_mit_splits(args.level)
splits = parse_mit_splits()
elif args.dataset == 'mmit':
splits = parse_mmit_splits(args.level)
splits = parse_mmit_splits()
elif args.dataset == 'kinetics400':
splits = parse_kinetics_splits(args.level)
elif args.dataset == 'hmdb51':
Expand Down
110 changes: 53 additions & 57 deletions tools/data/parse_file_list.py
Expand Up @@ -7,17 +7,14 @@


def parse_directory(path,
key_func=lambda x: osp.basename(x),
rgb_prefix='img_',
flow_x_prefix='flow_x_',
flow_y_prefix='flow_y_',
level=1):
"""Parse directories holding extracted frames from standard benchmarks.
Args:
path (str): Folder path to parse frames.
key_func (callable): Function to do key mapping.
default: lambda x: osp.basename(x).
path (str): Directory path to parse frames.
rgb_prefix (str): Prefix of generated rgb frames name.
default: 'img_'.
flow_x_prefix (str): Prefix of generated flow x name.
Expand All @@ -31,11 +28,21 @@ def parse_directory(path,
dict: frame info dict with video id as key and tuple(path(str),
rgb_num(int), flow_x_num(int)) as value.
"""
print(f'parse frames under folder {path}')
print(f'parse frames under directory {path}')
if level == 1:
frame_folders = glob.glob(osp.join(path, '*'))
# Only search for one-level directory
def locate_directory(x):
return osp.basename(x)

frame_dirs = glob.glob(osp.join(path, '*'))

elif level == 2:
frame_folders = glob.glob(osp.join(path, '*', '*'))
# search for two-level directory
def locate_directory(x):
return osp.join(osp.basename(osp.dirname(x)), osp.basename(x))

frame_dirs = glob.glob(osp.join(path, '*', '*'))

else:
raise ValueError('level can be only 1 or 2')

Expand All @@ -55,22 +62,22 @@ def count_files(directory, prefix_list):

# check RGB
frame_dict = {}
for i, frame_folder in enumerate(frame_folders):
total_num = count_files(frame_folder,
for i, frame_dir in enumerate(frame_dirs):
total_num = count_files(frame_dir,
(rgb_prefix, flow_x_prefix, flow_y_prefix))
k = key_func(frame_folder)
dir_name = locate_directory(frame_dir)

num_x = total_num[1]
num_y = total_num[2]
if num_x != num_y:
raise ValueError(f'x and y direction have different number '
f'of flow images in video folder: {frame_folder}')
f'of flow images in video directory: {frame_dir}')
if i % 200 == 0:
print('{} videos parsed'.format(i))
print(f'{i} videos parsed')

frame_dict[k] = (frame_folder, total_num[0], num_x)
frame_dict[dir_name] = (frame_dir, total_num[0], num_x)

print('frame folder analysis done')
print('frame directory analysis done')
return frame_dict


Expand Down Expand Up @@ -126,6 +133,7 @@ def line_to_map(line):

def parse_sthv1_splits(level):
"""Parse Something-Something dataset V1 into "train", "val" splits.
Args:
level (int): Directory level of data. 1 for the single-level directory,
2 for the two-level directory.
Expand Down Expand Up @@ -222,12 +230,9 @@ def line_to_map(item, test_mode=False):
return splits


def parse_mmit_splits(level):
def parse_mmit_splits():
"""Parse Multi-Moments in Time dataset into "train", "val" splits.
Args:
level: directory level of data.
Returns:
list: "train", "val", "test" splits of Multi-Moments in Time.
"""
Expand Down Expand Up @@ -331,12 +336,9 @@ def line_to_map(x, test=False):
return splits


def parse_mit_splits(level):
def parse_mit_splits():
"""Parse Moments in Time dataset into "train", "val" splits.
Args:
level: directory level of data.
Returns:
list: "train", "val", "test" splits of Moments in Time.
"""
Expand All @@ -347,7 +349,7 @@ def parse_mit_splits(level):
cat, digit = line.rstrip().split(',')
class_mapping[cat] = int(digit)

def line_to_map(x, test=False):
def line_to_map(x):
vid = osp.splitext(x[0])[0]
label = class_mapping[osp.dirname(x[0])]
return vid, label
Expand All @@ -370,6 +372,8 @@ def parse_hmdb51_split(level):
class_index_file = 'data/hmdb51/annotations/classInd.txt'

def generate_class_index_file():
"""Generate a `classInd.txt` for HMDB51 in the same format as the one
used by UCF101, where class ids start at 1."""
frame_path = 'data/hmdb51/rawframes'
annotation_dir = 'data/hmdb51/annotations'

Expand All @@ -378,49 +382,41 @@ def generate_class_index_file():
with open(class_index_file, 'w') as f:
content = []
for class_id, class_name in enumerate(class_list):
# like `ClassInd.txt` in UCF-101, the class_id begins with 1
class_dict[class_name] = class_id + 1
cur_line = ' '.join([str(class_id + 1), class_name])
content.append(cur_line)
content = '\n'.join(content)
f.write(content)

for i in range(1, 4):
train_content = []
test_content = []
for class_name in class_dict:
filename = class_name + f'_test_split{i}.txt'
filename_path = osp.join(annotation_dir, filename)
with open(filename_path, 'r') as fin:
for line in fin:
video_info = line.strip().split()
video_name = video_info[0]
if video_info[1] == '1':
target_line = ' '.join([
osp.join(class_name, video_name),
str(class_dict[class_name])
])
train_content.append(target_line)
elif video_info[1] == '2':
target_line = ' '.join([
osp.join(class_name, video_name),
str(class_dict[class_name])
])
test_content.append(target_line)
train_content = '\n'.join(train_content)
test_content = '\n'.join(test_content)
with open(train_file_template.format(i), 'w') as fout:
content = []
for class_name in class_dict:
filename = class_name + f'_test_split{i}.txt'
filename_path = osp.join(annotation_dir, filename)
with open(filename_path, 'r') as fin:
for line in fin:
video_info = line.strip().split()
video_name = video_info[0]
if video_info[1] == '1':
target_line = ' '.join([
osp.join(class_name, video_name),
str(class_dict[class_name])
])
content.append(target_line)
content = '\n'.join(content)
fout.write(content)

for i in range(1, 4):
fout.write(train_content)
with open(test_file_template.format(i), 'w') as fout:
content = []
for class_name in class_dict:
filename = class_name + f'_test_split{i}.txt'
filename_path = osp.join(annotation_dir, filename)
with open(filename_path, 'r') as fin:
for line in fin:
video_info = line.strip().split()
video_name = video_info[0]
if video_info[1] == '2':
target_line = ' '.join([
osp.join(class_name, video_name),
str(class_dict[class_name])
])
content.append(target_line)
content = '\n'.join(content)
fout.write(content)
fout.write(test_content)

if not osp.exists(class_index_file):
generate_class_index_file()
Expand Down

0 comments on commit 9a07fc2

Please sign in to comment.