diff --git a/tools/data/ava_kinetics/README.md b/tools/data/ava_kinetics/README.md new file mode 100644 index 0000000000..2d28771320 --- /dev/null +++ b/tools/data/ava_kinetics/README.md @@ -0,0 +1,173 @@ +# Preparing AVA-Kinetics + +## Introduction + + + +```BibTeX +@article{li2020ava, + title={The ava-kinetics localized human actions video dataset}, + author={Li, Ang and Thotakuri, Meghana and Ross, David A and Carreira, Jo{\~a}o and Vostrikov, Alexander and Zisserman, Andrew}, + journal={arXiv preprint arXiv:2005.00214}, + year={2020} +} +``` + +For basic dataset information, please refer to the official [website](https://research.google.com/ava/index.html). +AVA-Kinetics dataset is a crossover between the AVA Actions and Kinetics datasets. You may want to first prepare the AVA datasets. In this file, we provide commands to prepare the Kinetics part and merge the two parts together. + +For model training, we will keep reading from raw frames for the AVA part, but read from videos using `decord` for the Kinetics part to accelerate training. + +Before we start, please make sure that the directory is located at `$MMACTION2/tools/data/ava_kinetics/`. + +## Step 1. Prepare the Kinetics700 dataset + +The Kinetics part of the AVA-Kinetics dataset are sampled from the Kinetics-700 dataset. + +It is best if you have prepared the Kinetics-700 dataset (only videos required) following +[Preparing Kinetics](https://github.com/open-mmlab/mmaction2/tree/master/tools/data/kinetics). We will also have alternative method to prepare these videos if you do not have enough storage (coming soon). + +We will need the videos of this dataset (`$MMACTION2/data/kinetics700/videos_train`) and the videos file list (`$MMACTION2/data/kinetics700/kinetics700_train_list_videos.txt`), which is generated by [Step 4 in Preparing Kinetics](https://github.com/open-mmlab/mmaction2/tree/master/tools/data/kinetics#step-4-generate-file-list) + +The format of the file list should be: + +``` +Path_to_video_1 label_1\n +Path_to_video_2 label_2\n +... +Path_to_video_n label_n\n +``` + +The timestamp (start and end of the video) must be contained. For example: + +``` +class602/o3lCwWyyc_s_000012_000022.mp4 602\n +``` + +It means that this video clip is the 12th to 22nd seconds of the original video. It is okay if some videos are missing, and we will ignore them in the next steps. + +## Step 2. Download Annotations + +Download the annotation tar file (recall that the directory should be located at `$MMACTION2/tools/data/ava_kinetics/`). + +```shell +wget https://storage.googleapis.com/deepmind-media/Datasets/ava_kinetics_v1_0.tar.gz +tar xf ava_kinetics_v1_0.tar.gz && rm ava_kinetics_v1_0.tar.gz +``` + +You should have the `ava_kinetics_v1_0` folder at `$MMACTION2/tools/data/ava_kinetics/`. + +## Step 3. Cut Videos + +Use `cut_kinetics.py` to find the desired videos from the Kinetics-700 dataset and trim them to contain only annotated clips. Currently we only use the train set of the Kinetics part to improve training. Validation on the Kinetics part will come soon. + +Here is the script: + +```shell +python3 cut_kinetics.py --avakinetics_anotation=$AVAKINETICS_ANOTATION \ + --kinetics_list=$KINETICS_LIST \ + --avakinetics_root=$AVAKINETICS_ROOT \ + [--num_workers=$NUM_WORKERS ] +``` + +Arguments: + +- `avakinetics_anotation`: the directory to ava-kinetics anotations. Defaults to `./ava_kinetics_v1_0`. +- `kinetics_list`: the path to the videos file list as mentioned in Step 1. If you have prepared the Kinetics700 dataset following `mmaction2`, it should be `$MMACTION2/data/kinetics700/kinetics700_train_list_videos.txt`. +- `avakinetics_root`: the directory to save the ava-kinetics dataset. Defaults to `$MMACTION2/data/ava_kinetics`. +- `num_workers`: number of workers used to cut videos. Defaults to -1 and use all available cpus. + +There should be about 100k videos. It is OK if some videos are missing and we will ignore them in the next steps. + +## Step 4. Extract RGB Frames + +This step is similar to Step 4 in [Preparing AVA](https://github.com/open-mmlab/mmaction2/tree/dev-1.x/tools/data/ava#step-4-extract-rgb-and-flow). + +Here we provide a script to extract RGB frames using ffmpeg: + +```shell +python3 extract_rgb_frames.py --avakinetics_root=$AVAKINETICS_ROOT \ + [--num_workers=$NUM_WORKERS ] +``` + +Arguments: + +- `avakinetics_root`: the directory to save the ava-kinetics dataset. Defaults to `$MMACTION2/data/ava_kinetics`. +- `num_workers`: number of workers used to extract frames. Defaults to -1 and use all available cpus. + +If you have installed denseflow, you can also use `build_rawframes.py` to extract RGB frames: + +```shell +python3 ../build_rawframes.py ../../../data/ava_kinetics/videos/ ../../../data/ava_kinetics/rawframes/ --task rgb --level 1 --mixed-ext +``` + +## Step 5. Prepare Annotations + +Use `prepare_annotation.py` to prepare the training annotations. It will generate a `kinetics_train.csv` file containning the spatial-temporal annotations for the Kinetics part, localting at `$AVAKINETICS_ROOT`. + +Here is the script: + +```shell +python3 prepare_annotation.py --avakinetics_anotation=$AVAKINETICS_ANOTATION \ + --avakinetics_root=$AVAKINETICS_ROOT \ + [--num_workers=$NUM_WORKERS] +``` + +Arguments: + +- `avakinetics_anotation`: the directory to ava-kinetics anotations. Defaults to `./ava_kinetics_v1_0`. +- `avakinetics_root`: the directory to save the ava-kinetics dataset. Defaults to `$MMACTION2/data/ava_kinetics`. +- `num_workers`: number of workers used to prepare annotations. Defaults to -1 and use all available cpus. + +## Step 6. Fetch Proposal Files + +The pre-computed proposals for AVA dataset are provided by FAIR's [Long-Term Feature Banks](https://github.com/facebookresearch/video-long-term-feature-banks). For the Kinetics part, we use `Cascade R-CNN X-101-64x4d-FPN` from [mmdetection](https://download.openmmlab.com/mmdetection/v2.0/cascade_rcnn/cascade_rcnn_x101_64x4d_fpn_1x_coco/cascade_rcnn_x101_64x4d_fpn_1x_coco_20200515_075702-43ce6a30.pth) to fetch the proposals. Here is the script: + +```shell +python3 fetch_proposal.py --avakinetics_root=$AVAKINETICS_ROOT \ + --datalist=$DATALIST \ + --picklepath=$PICKLEPATH \ + [--config=$CONFIG ] \ + [--checkpoint=$CHECKPOINT ] + +``` + +It will generate a `kinetics_proposal.pkl` file at `$MMACTION2/data/ava_kinetics/`. + +Arguments: + +- `avakinetics_root`: the directory to save the ava-kinetics dataset. Defaults to `$MMACTION2/data/ava_kinetics`. +- `datalist`: path to the `kinetics_train.csv` file generated at Step 3. +- `picklepath`: path to save the extracted proposal pickle file. +- `config`: the config file for the human detection model. Defaults to `X-101-64x4d-FPN.py`. +- `checkpoint`: the checkpoint for the human detection model. Defaults to the `mmdetection` pretraining checkpoint. + +## Step 7. Merge AVA to AVA-Kinetics + +Now we are done with the preparations for the Kinetics part. We need to merge the AVA part into the `ava_kinetics` folder (assuming you have AVA dataset ready at `$MMACTION2/data/ava`). First we make a copy of the AVA anotation to the `ava_kinetics` folder (recall that you are at `$MMACTION2/tools/data/ava_kinetics/`): + +```shell +cp -r ../../../data/ava/annotations/ ../../../data/ava_kinetics/ +``` + +Next we merge the generated anotation files of the Kinetics part to AVA. Please check: you should have two files `kinetics_train.csv` and `kinetics_proposal.pkl` at `$MMACTION2/data/ava_kinetics/` generated from Step 5 and Step 6. Run the following script to merge these two files into `$MMACTION2/data/ava_kinetics/annotations/ava_train_v2.2.csv` and `$MMACTION2/data/ava_kinetics/annotations/ava_dense_proposals_train.FAIR.recall_93.9.pkl` respectively. + +```shell +python3 merge_annotations.py --avakinetics_root=$AVAKINETICS_ROOT +``` + +Arguments: + +- `avakinetics_root`: the directory to save the ava-kinetics dataset. Defaults to `$MMACTION2/data/ava_kinetics`. + +Finally, we need to merge the rawframes of AVA part. You can either copy/move them or generate soft links. The following script is an example to use soft links: + +```shell +python3 softlink_ava.py --avakinetics_root=$AVAKINETICS_ROOT \ + --ava_root=$AVA_ROOT +``` + +Arguments: + +- `avakinetics_root`: the directory to save the ava-kinetics dataset. Defaults to `$MMACTION2/data/ava_kinetics`. +- `ava_root`: the directory to save the ava dataset. Defaults to `$MMACTION2/data/ava`. diff --git a/tools/data/ava_kinetics/X-101-64x4d-FPN.py b/tools/data/ava_kinetics/X-101-64x4d-FPN.py new file mode 100644 index 0000000000..114f80a5e2 --- /dev/null +++ b/tools/data/ava_kinetics/X-101-64x4d-FPN.py @@ -0,0 +1,147 @@ +# Copyright (c) OpenMMLab. All rights reserved. +model = dict( + type='CascadeRCNN', + data_preprocessor=dict( + type='DetDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_size_divisor=32), + backbone=dict( + type='ResNeXt', + depth=101, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=True, + style='pytorch', + init_cfg=dict( + type='Pretrained', checkpoint='open-mmlab://resnext101_64x4d'), + groups=64, + base_width=4), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + num_outs=5), + rpn_head=dict( + type='RPNHead', + in_channels=256, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + scales=[8], + ratios=[0.5, 1.0, 2.0], + strides=[4, 8, 16, 32, 64]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0.0, 0.0, 0.0, 0.0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict( + type='SmoothL1Loss', beta=0.1111111111111111, loss_weight=1.0)), + roi_head=dict( + type='CascadeRoIHead', + num_stages=3, + stage_loss_weights=[1, 0.5, 0.25], + bbox_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + bbox_head=[ + dict( + type='Shared2FCBBoxHead', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=80, + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0.0, 0.0, 0.0, 0.0], + target_stds=[0.1, 0.1, 0.2, 0.2]), + reg_class_agnostic=True, + loss_cls=dict( + type='CrossEntropyLoss', + use_sigmoid=False, + loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0, + loss_weight=1.0)), + dict( + type='Shared2FCBBoxHead', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=80, + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0.0, 0.0, 0.0, 0.0], + target_stds=[0.05, 0.05, 0.1, 0.1]), + reg_class_agnostic=True, + loss_cls=dict( + type='CrossEntropyLoss', + use_sigmoid=False, + loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0, + loss_weight=1.0)), + dict( + type='Shared2FCBBoxHead', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=80, + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0.0, 0.0, 0.0, 0.0], + target_stds=[0.033, 0.033, 0.067, 0.067]), + reg_class_agnostic=True, + loss_cls=dict( + type='CrossEntropyLoss', + use_sigmoid=False, + loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)) + ]), + test_cfg=dict( + rpn=dict( + nms_pre=1000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + score_thr=0.05, + nms=dict(type='nms', iou_threshold=0.5), + max_per_img=100))) + +test_dataloader = dict( + batch_size=1, + num_workers=2, + persistent_workers=True, + drop_last=False, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type='CocoDataset', + data_root='data/coco/', + ann_file='annotations/instances_val2017.json', + data_prefix=dict(img='val2017/'), + test_mode=True, + pipeline=[ + dict( + type='LoadImageFromFile', + file_client_args=dict(backend='disk')), + dict(type='Resize', scale=(1333, 800), keep_ratio=True), + dict(type='LoadAnnotations', with_bbox=True), + dict( + type='PackDetInputs', + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor')) + ])) + +test_evaluator = dict( + type='CocoMetric', + ann_file='data/coco/annotations/instances_val2017.json', + metric='bbox', + format_only=False) + +test_cfg = dict(type='TestLoop') diff --git a/tools/data/ava_kinetics/cut_kinetics.py b/tools/data/ava_kinetics/cut_kinetics.py new file mode 100644 index 0000000000..3582035b10 --- /dev/null +++ b/tools/data/ava_kinetics/cut_kinetics.py @@ -0,0 +1,185 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import multiprocessing +import os +from collections import defaultdict +from typing import List + +import decord + + +def get_kinetics_frames(kinetics_anotation_file: str) -> dict: + """Given the AVA-kinetics anotation file, return a lookup to map the video + id and the the set of timestamps involved of this video id. + + Args: + kinetics_anotation_file (str): Path to the AVA-like anotation file for + the kinetics subset. + Returns: + dict: the dict keys are the kinetics videos' video id. The values are + the set of timestamps involved. + """ + with open(kinetics_anotation_file) as f: + anotated_frames = [i.split(',') for i in f.readlines()] + anotated_frames = [i for i in anotated_frames if len(i) == 7] + anotated_frames = [(i[0], int(float(i[1]))) for i in anotated_frames] + + frame_lookup = defaultdict(set) + for video_id, timestamp in anotated_frames: + frame_lookup[video_id].add(timestamp) + return frame_lookup + + +def filter_missing_videos(kinetics_list: str, frame_lookup: dict) -> dict: + """Given the kinetics700 dataset list, remove the video ids from the lookup + that are missing videos or frames. + + Args: + kinetics_list (str): Path to the kinetics700 dataset list. + The content of the list should be: + ``` + Path_to_video1 label_1\n + Path_to_video2 label_2\n + ... + Path_to_videon label_n\n + ``` + The start and end of the video must be contained in the filename. + For example: + ``` + class602/o3lCwWyyc_s_000012_000022.mp4\n + ``` + frame_lookup (dict): the dict from `get_kinetics_frames`. + Returns: + dict: the dict keys are the kinetics videos' video id. The values are + the a list of tuples: + (start_of_the_video, end_of_the_video, video_path) + """ + video_lookup = defaultdict(set) + with open(kinetics_list) as f: + for line in f.readlines(): + video_path = line.split(' ')[0] # remove label information + video_name = video_path.split('/')[-1] # get the file name + video_name = video_name.split('.')[0] # remove file extensions + video_name = video_name.split('_') + video_id = '_'.join(video_name[:-2]) + if video_id not in frame_lookup: + continue + + start, end = int(video_name[-2]), int(video_name[-1]) + frames = frame_lookup[video_id] + frames = [frame for frame in frames if start < frame < end] + if len(frames) == 0: + continue + + start, end = max(start, min(frames) - 2), min(end, max(frames) + 2) + video_lookup[video_id].add((start, end, video_path)) + + # Some frame ids exist in multiple videos in the Kinetics dataset. + # The reason is the part of one video may fall into different categories. + # Remove the duplicated records + for video in video_lookup: + if len(video_lookup[video]) == 1: + continue + info_list = list(video_lookup[video]) + removed_list = [] + for i, info_i in enumerate(info_list): + start_i, end_i, _ = info_i + for j in range(i + 1, len(info_list)): + start_j, end_j, _ = info_list[j] + if start_i <= start_j and end_j <= end_i: + removed_list.append(j) + elif start_j <= start_i and end_i <= end_j: + removed_list.append(i) + new_list = [] + for i, info in enumerate(info_list): + if i not in removed_list: + new_list.append(info) + video_lookup[video] = set(new_list) + return video_lookup + + +template = ('ffmpeg -ss %d -t %d -accurate_seek -i' + ' %s -r 30 -avoid_negative_ts 1 %s') + + +def generate_cut_cmds(video_lookup: dict, data_root: str) -> List[str]: + cmds = [] + for video_id in video_lookup: + for start, end, video_path in video_lookup[video_id]: + start0 = int(video_path.split('_')[-2]) + new_path = '%s/%s_%06d_%06d.mp4' % (data_root, video_id, start, + end) + cmd = template % (start - start0, end - start, video_path, + new_path) + cmds.append(cmd) + return cmds + + +def run_cmd(cmd): + os.system(cmd) + return + + +def remove_failed_video(video_path: str) -> None: + """Given the path to the video, delete the video if it cannot be read or if + the actual length of the video is 0.75 seconds shorter than expected.""" + try: + v = decord.VideoReader(video_path) + fps = v.get_avg_fps() + num_frames = len(v) + x = video_path.split('.')[0].split('_') + time = int(x[-1]) - int(x[-2]) + if num_frames < (time - 3 / 4) * fps: + os.remove(video_path) + except: # noqa: E722 + os.remove(video_path) + return + + +if __name__ == '__main__': + p = argparse.ArgumentParser() + p.add_argument( + '--avakinetics_anotation', + type=str, + default='./ava_kinetics_v1_0', + help='the directory to ava-kinetics anotations') + p.add_argument( + '--kinetics_list', + type=str, + help='the datalist of the kinetics700 training videos') + p.add_argument( + '--num_workers', + type=int, + default=-1, + help='number of workers used for multiprocessing') + p.add_argument( + '--avakinetics_root', + type=str, + default='../../../data/ava_kinetics', + help='the path to save ava-kinetics dataset') + args = p.parse_args() + + if args.num_workers > 0: + num_workers = args.num_workers + else: + num_workers = max(multiprocessing.cpu_count() - 1, 1) + + # Find videos from the Kinetics700 dataset required for AVA-Kinetics + kinetics_train = args.avakinetics_anotation + '/kinetics_train_v1.0.csv' + frame_lookup = get_kinetics_frames(kinetics_train) + video_lookup = filter_missing_videos(args.kinetics_list, frame_lookup) + + root = args.avakinetics_root + os.makedirs(root, exist_ok=True) + video_path = root + '/videos/' + os.makedirs(video_path, exist_ok=True) + all_cmds = generate_cut_cmds(video_lookup, video_path) + + # Cut and save the videos for AVA-Kinetics + pool = multiprocessing.Pool(num_workers) + _ = pool.map(run_cmd, all_cmds) + + # Remove failed videos + videos = os.listdir(video_path) + videos = ['%s/%s' % (video_path, video) for video in videos] + _ = pool.map(remove_failed_video, videos) diff --git a/tools/data/ava_kinetics/extract_rgb_frames.py b/tools/data/ava_kinetics/extract_rgb_frames.py new file mode 100644 index 0000000000..25a3251b56 --- /dev/null +++ b/tools/data/ava_kinetics/extract_rgb_frames.py @@ -0,0 +1,51 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import multiprocessing +import os + + +def extract_rgb(video_name, frame_path, video_path): + video_id = video_name.split('.')[0] + os.makedirs('%s/%s' % (frame_path, video_id), exist_ok=True) + cmd = 'ffmpeg -i %s/%s -r 30 -q:v 1 %s/%s' % (video_path, video_name, + frame_path, video_id) + cmd += '/img_%05d.jpg' + return cmd + + +def run_cmd(cmd): + os.system(cmd) + return + + +if __name__ == '__main__': + p = argparse.ArgumentParser() + p.add_argument( + '--avakinetics_root', + type=str, + default='../../../data/ava_kinetics', + help='the path to save ava-kinetics dataset') + p.add_argument( + '--num_workers', + type=int, + default=-1, + help='number of workers used for multiprocessing') + args = p.parse_args() + + if args.num_workers > 0: + num_workers = args.num_workers + else: + num_workers = max(multiprocessing.cpu_count() - 1, 1) + + root = args.avakinetics_root + video_path = root + '/videos/' + frame_path = root + '/rawframes/' + os.makedirs(frame_path, exist_ok=True) + + all_cmds = [ + extract_rgb(video_name, frame_path, video_path) + for video_name in os.listdir(video_path) + ] + + pool = multiprocessing.Pool(num_workers) + out = pool.map(run_cmd, all_cmds) diff --git a/tools/data/ava_kinetics/fetch_proposal.py b/tools/data/ava_kinetics/fetch_proposal.py new file mode 100644 index 0000000000..6e5279d4b4 --- /dev/null +++ b/tools/data/ava_kinetics/fetch_proposal.py @@ -0,0 +1,137 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import multiprocessing as mp +import os +import pickle + +import numpy as np +from mmdet.apis import inference_detector, init_detector +from mmdet.utils import register_all_modules +from PIL import Image + + +def get_vid_from_path(path): + video_id = path.split('/')[-1] + video_id = video_id.split('_')[:-2] + return '_'.join(video_id) + + +def prepare_det_lookup(datalist, frame_root): + with open(datalist) as f: + records = f.readlines() + det_lookup = {} + for record in records: + record = record.split(',') + folder_path = record[0] + video_id = get_vid_from_path(folder_path) + frame_id = int(record[1]) + for idx in range(frame_id - 1, frame_id + 2): + proposal_id = '%s,%04d' % (video_id, idx) + det_lookup[proposal_id] = '%s/%s' % (frame_root, folder_path) + return det_lookup + + +def single_worker(rank, det_lookup, args): + detect_list = list(det_lookup) + detect_sublist = [ + detect_list[i] for i in range(len(detect_list)) + if i % args.num_gpus == rank + ] + + # register all modules in mmdet into the registries + register_all_modules() + model = init_detector( + args.config, args.checkpoint, device='cuda:%d' % rank) + + lookup = {} + for count, key in enumerate(detect_sublist): + try: + folder_path = det_lookup[key] + start = int(folder_path.split('/')[-1].split('_')[-2]) + time = int(key.split(',')[1]) + frame_id = (time - start) * 30 + 1 + frame_path = '%s/img_%05d.jpg' % (folder_path, frame_id) + img = Image.open(frame_path) + result = inference_detector(model, frame_path) + bboxes = result._pred_instances.bboxes.cpu() + scores = result._pred_instances.scores.cpu() + labels = result._pred_instances.labels.cpu() + + bboxes = bboxes[labels == 0] + scores = scores[labels == 0] + + bboxes = bboxes[scores > 0.7].numpy() + scores = scores[scores > 0.7] + if scores.numel() > 0: + result_ = [] + for idx, (h1, w1, h2, w2) in enumerate(bboxes): + h1 /= img.size[0] + h2 /= img.size[0] + w1 /= img.size[1] + w2 /= img.size[1] + score = scores[idx].item() + result_.append((h1, w1, h2, w2, score)) + lookup[key] = np.array(result_) + except: # noqa: E722 + pass + + with open('tmp_person_%d.pkl' % rank, 'wb') as f: + pickle.dump(lookup, f) + return + + +if __name__ == '__main__': + p = argparse.ArgumentParser() + p.add_argument( + '--avakinetics_root', + type=str, + default='../../../data/ava_kinetics', + help='the path to save ava-kinetics dataset') + p.add_argument( + '--datalist', + type=str, + default='../../../data/ava_kinetics/kinetics_train.csv', + help='the list for kinetics videos') + p.add_argument( + '--config', + type=str, + default='X-101-64x4d-FPN.py', + help='the human detector') + p.add_argument( + '--checkpoint', + type=str, + default='https://download.openmmlab.com/mmdetection/v2.0/' + 'cascade_rcnn/cascade_rcnn_x101_64x4d_fpn_1x_coco/' + 'cascade_rcnn_x101_64x4d_fpn_1x_coco_20200515_' + '075702-43ce6a30.pth', + help='the human detector checkpoint') + p.add_argument( + '--picklepath', + type=str, + default='../../../data/ava_kinetics/kinetics_proposal.pkl') + p.add_argument('--num_gpus', type=int, default=8) + + args = p.parse_args() + + frame_root = args.avakinetics_root + '/rawframes/' + det_lookup = prepare_det_lookup(args.datalist, frame_root) + + processes = [] + for rank in range(args.num_gpus): + ctx = mp.get_context('spawn') + p = ctx.Process(target=single_worker, args=(rank, det_lookup, args)) + p.start() + processes.append(p) + + for p in processes: + p.join() + + lookup = {} + for k in range(args.num_gpus): + one_lookup = pickle.load(open('tmp_person_%d.pkl' % k, 'rb')) + os.remove('tmp_person_%d.pkl' % k) + for key in one_lookup: + lookup[key] = one_lookup[key] + + with open(args.picklepath, 'wb') as f: + pickle.dump(lookup, f) diff --git a/tools/data/ava_kinetics/merge_annotations.py b/tools/data/ava_kinetics/merge_annotations.py new file mode 100644 index 0000000000..9b5060a8fb --- /dev/null +++ b/tools/data/ava_kinetics/merge_annotations.py @@ -0,0 +1,54 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os +import pickle + + +def check_file(path): + if os.path.isfile(path): + return + else: + path = path.split('/') + folder = '/'.join(path[:-1]) + filename = path[-1] + info = '%s not found at %s' % (filename, folder) + raise FileNotFoundError(info) + + +if __name__ == '__main__': + p = argparse.ArgumentParser() + p.add_argument( + '--avakinetics_root', + type=str, + default='../../../data/ava_kinetics', + help='the path to save ava-kinetics dataset') + root = p.parse_args().avakinetics_root + + kinetics_annot = root + '/kinetics_train.csv' + ava_annot = root + '/annotations/ava_train_v2.2.csv' + + check_file(kinetics_annot) + check_file(ava_annot) + + with open(kinetics_annot) as f: + record = f.readlines() + + with open(ava_annot) as f: + record += f.readlines() + + with open(ava_annot, 'w') as f: + for line in record: + f.write(line) + + kinetics_proposal = root + '/kinetics_proposal.pkl' + ava_proposal = root + '/annotations/' \ + 'ava_dense_proposals_train.FAIR.recall_93.9.pkl' + + check_file(kinetics_proposal) + check_file(ava_proposal) + + lookup = pickle.load(open(kinetics_proposal, 'rb')) + lookup.update(pickle.load(open(ava_proposal, 'rb'))) + + with open(ava_proposal, 'wb') as f: + pickle.dump(lookup, f) diff --git a/tools/data/ava_kinetics/prepare_annotation.py b/tools/data/ava_kinetics/prepare_annotation.py new file mode 100644 index 0000000000..00b7669d49 --- /dev/null +++ b/tools/data/ava_kinetics/prepare_annotation.py @@ -0,0 +1,90 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import multiprocessing +import os +from collections import defaultdict + +FPS = 30 + + +def get_video_info(frame_folder): + folder_name = frame_folder.split('/')[-1] + filename = folder_name.split('_') + video_id = '_'.join(filename[:-2]) + start = int(filename[-2]) + length = len(os.listdir(frame_folder)) // FPS + return (video_id, start, start + length, folder_name) + + +def get_avaialble_clips(frame_root, num_cpus): + folders = os.listdir(frame_root) + folders = ['%s/%s' % (frame_root, folder) for folder in folders] + pool = multiprocessing.Pool(num_cpus) + outputs = pool.map(get_video_info, folders) + lookup = defaultdict(list) + for record in outputs: + lookup[record[0]].append(record[1:]) + return lookup + + +def filter_train_list(kinetics_anotation_file, lookup): + with open(kinetics_anotation_file) as f: + anotated_frames = [i.split(',') for i in f.readlines()] + anotated_frames = [i for i in anotated_frames if len(i) == 7] + + filtered = [] + for line in anotated_frames: + if line[0] not in lookup: + continue + flag = False + for start, end, video_path in lookup[line[0]]: + if start < float(line[1]) < end: + flag = True + break + if flag is False: + continue + + frame_idx, x1, y1, x2, y2, label = list(map(float, line[1:7])) + frame_idx, label = int(frame_idx), int(label) + + string = (f'{video_path},{frame_idx},' + f'{x1:.3f},{y1:.3f},{x2:.3f},{y2:.3f},{label},-1\n') + + filtered.append(string) + return filtered + + +if __name__ == '__main__': + p = argparse.ArgumentParser() + p.add_argument( + '--avakinetics_anotation', + type=str, + default='./ava_kinetics_v1_0', + help='the directory to ava-kinetics anotations') + p.add_argument( + '--num_workers', + type=int, + default=-1, + help='number of workers used for multiprocessing') + p.add_argument( + '--avakinetics_root', + type=str, + default='../../../data/ava_kinetics', + help='the path to save ava-kinetics videos') + args = p.parse_args() + + if args.num_workers > 0: + num_workers = args.num_workers + else: + num_workers = max(multiprocessing.cpu_count() - 1, 1) + + frame_root = args.avakinetics_root + '/rawframes/' + frame_root = os.path.abspath(frame_root) + lookup = get_avaialble_clips(frame_root, num_workers) + + kinetics_train = args.avakinetics_anotation + '/kinetics_train_v1.0.csv' + filtered_list = filter_train_list(kinetics_train, lookup) + + with open('%s/kinetics_train.csv' % args.avakinetics_root, 'w') as f: + for line in filtered_list: + f.write(line) diff --git a/tools/data/ava_kinetics/softlink_ava.py b/tools/data/ava_kinetics/softlink_ava.py new file mode 100644 index 0000000000..18ca0688c3 --- /dev/null +++ b/tools/data/ava_kinetics/softlink_ava.py @@ -0,0 +1,24 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os + +p = argparse.ArgumentParser() +p.add_argument( + '--ava_root', + type=str, + default='../../../data/ava', + help='the path to save ava dataset') +p.add_argument( + '--avakinetics_root', + type=str, + default='../../../data/ava_kinetics', + help='the path to save ava-kinetics dataset') +args = p.parse_args() + +ava_frames = os.path.abspath(args.ava_root + '/rawframes/') +kinetics_frames = os.path.abspath(args.avakinetics_root + '/rawframes/') + +ava_folders = os.listdir(ava_frames) +for folder in ava_folders: + cmd = 'ln -s %s/%s %s/%s' % (ava_frames, folder, kinetics_frames, folder) + os.system(cmd)