Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Polish data preparation code #70

Merged
merged 1 commit into from
Jul 31, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 2 additions & 12 deletions tools/data/build_file_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,19 +150,9 @@ def build_list(split):
def main():
args = parse_args()

if args.level == 2:
# search for two-level directory
def key_func(x):
return osp.join(osp.basename(osp.dirname(x)), osp.basename(x))
else:
# Only search for one-level directory
def key_func(x):
return osp.basename(x)

if args.format == 'rawframes':
frame_info = parse_directory(
args.src_folder,
key_func=key_func,
rgb_prefix=args.rgb_prefix,
flow_x_prefix=args.flow_x_prefix,
flow_y_prefix=args.flow_y_prefix,
Expand Down Expand Up @@ -191,9 +181,9 @@ def key_func(x):
elif args.dataset == 'sthv2':
splits = parse_sthv2_splits(args.level)
elif args.dataset == 'mit':
splits = parse_mit_splits(args.level)
splits = parse_mit_splits()
elif args.dataset == 'mmit':
splits = parse_mmit_splits(args.level)
splits = parse_mmit_splits()
elif args.dataset == 'kinetics400':
splits = parse_kinetics_splits(args.level)
elif args.dataset == 'hmdb51':
Expand Down
110 changes: 53 additions & 57 deletions tools/data/parse_file_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,14 @@


def parse_directory(path,
key_func=lambda x: osp.basename(x),
rgb_prefix='img_',
flow_x_prefix='flow_x_',
flow_y_prefix='flow_y_',
level=1):
"""Parse directories holding extracted frames from standard benchmarks.

Args:
path (str): Folder path to parse frames.
key_func (callable): Function to do key mapping.
default: lambda x: osp.basename(x).
path (str): Directory path to parse frames.
rgb_prefix (str): Prefix of generated rgb frames name.
default: 'img_'.
flow_x_prefix (str): Prefix of generated flow x name.
Expand All @@ -31,11 +28,21 @@ def parse_directory(path,
dict: frame info dict with video id as key and tuple(path(str),
rgb_num(int), flow_x_num(int)) as value.
"""
print(f'parse frames under folder {path}')
print(f'parse frames under directory {path}')
if level == 1:
frame_folders = glob.glob(osp.join(path, '*'))
# Only search for one-level directory
def locate_directory(x):
return osp.basename(x)

frame_dirs = glob.glob(osp.join(path, '*'))

elif level == 2:
frame_folders = glob.glob(osp.join(path, '*', '*'))
# search for two-level directory
def locate_directory(x):
return osp.join(osp.basename(osp.dirname(x)), osp.basename(x))

frame_dirs = glob.glob(osp.join(path, '*', '*'))

else:
raise ValueError('level can be only 1 or 2')

Expand All @@ -55,22 +62,22 @@ def count_files(directory, prefix_list):

# check RGB
frame_dict = {}
for i, frame_folder in enumerate(frame_folders):
total_num = count_files(frame_folder,
for i, frame_dir in enumerate(frame_dirs):
total_num = count_files(frame_dir,
(rgb_prefix, flow_x_prefix, flow_y_prefix))
k = key_func(frame_folder)
dir_name = locate_directory(frame_dir)

num_x = total_num[1]
num_y = total_num[2]
if num_x != num_y:
raise ValueError(f'x and y direction have different number '
f'of flow images in video folder: {frame_folder}')
f'of flow images in video directory: {frame_dir}')
if i % 200 == 0:
print('{} videos parsed'.format(i))
print(f'{i} videos parsed')

frame_dict[k] = (frame_folder, total_num[0], num_x)
frame_dict[dir_name] = (frame_dir, total_num[0], num_x)

print('frame folder analysis done')
print('frame directory analysis done')
return frame_dict


Expand Down Expand Up @@ -126,6 +133,7 @@ def line_to_map(line):

def parse_sthv1_splits(level):
"""Parse Something-Something dataset V1 into "train", "val" splits.

Args:
level (int): Directory level of data. 1 for the single-level directory,
2 for the two-level directory.
Expand Down Expand Up @@ -222,12 +230,9 @@ def line_to_map(item, test_mode=False):
return splits


def parse_mmit_splits(level):
def parse_mmit_splits():
"""Parse Multi-Moments in Time dataset into "train", "val" splits.

Args:
level: directory level of data.

Returns:
list: "train", "val", "test" splits of Multi-Moments in Time.
"""
Expand Down Expand Up @@ -331,12 +336,9 @@ def line_to_map(x, test=False):
return splits


def parse_mit_splits(level):
def parse_mit_splits():
"""Parse Moments in Time dataset into "train", "val" splits.

Args:
level: directory level of data.

Returns:
list: "train", "val", "test" splits of Moments in Time.
"""
Expand All @@ -347,7 +349,7 @@ def parse_mit_splits(level):
cat, digit = line.rstrip().split(',')
class_mapping[cat] = int(digit)

def line_to_map(x, test=False):
def line_to_map(x):
vid = osp.splitext(x[0])[0]
label = class_mapping[osp.dirname(x[0])]
return vid, label
Expand All @@ -370,6 +372,8 @@ def parse_hmdb51_split(level):
class_index_file = 'data/hmdb51/annotations/classInd.txt'

def generate_class_index_file():
"""This function will generate a `ClassInd.txt` for HMDB51 in a format
like UCF101, where class id starts with 1."""
frame_path = 'data/hmdb51/rawframes'
annotation_dir = 'data/hmdb51/annotations'

Expand All @@ -378,49 +382,41 @@ def generate_class_index_file():
with open(class_index_file, 'w') as f:
content = []
for class_id, class_name in enumerate(class_list):
# like `ClassInd.txt` in UCF-101, the class_id begins with 1
innerlee marked this conversation as resolved.
Show resolved Hide resolved
class_dict[class_name] = class_id + 1
cur_line = ' '.join([str(class_id + 1), class_name])
content.append(cur_line)
content = '\n'.join(content)
f.write(content)

for i in range(1, 4):
train_content = []
test_content = []
for class_name in class_dict:
filename = class_name + f'_test_split{i}.txt'
filename_path = osp.join(annotation_dir, filename)
with open(filename_path, 'r') as fin:
for line in fin:
video_info = line.strip().split()
video_name = video_info[0]
if video_info[1] == '1':
target_line = ' '.join([
osp.join(class_name, video_name),
str(class_dict[class_name])
])
train_content.append(target_line)
elif video_info[1] == '2':
target_line = ' '.join([
osp.join(class_name, video_name),
str(class_dict[class_name])
])
test_content.append(target_line)
train_content = '\n'.join(train_content)
test_content = '\n'.join(test_content)
with open(train_file_template.format(i), 'w') as fout:
content = []
for class_name in class_dict:
filename = class_name + f'_test_split{i}.txt'
filename_path = osp.join(annotation_dir, filename)
with open(filename_path, 'r') as fin:
for line in fin:
video_info = line.strip().split()
video_name = video_info[0]
if video_info[1] == '1':
target_line = ' '.join([
osp.join(class_name, video_name),
str(class_dict[class_name])
])
content.append(target_line)
content = '\n'.join(content)
fout.write(content)

for i in range(1, 4):
fout.write(train_content)
with open(test_file_template.format(i), 'w') as fout:
content = []
for class_name in class_dict:
filename = class_name + f'_test_split{i}.txt'
filename_path = osp.join(annotation_dir, filename)
with open(filename_path, 'r') as fin:
for line in fin:
video_info = line.strip().split()
video_name = video_info[0]
if video_info[1] == '2':
target_line = ' '.join([
osp.join(class_name, video_name),
str(class_dict[class_name])
])
content.append(target_line)
content = '\n'.join(content)
fout.write(content)
fout.write(test_content)

if not osp.exists(class_index_file):
generate_class_index_file()
Expand Down