Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] HVU Training #235

Merged
merged 38 commits into from
Oct 21, 2020
Merged
Show file tree
Hide file tree
Changes from 35 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
7a3d259
add hvu video dataset
Oct 1, 2020
5bb55ae
add aux_info interface
Oct 2, 2020
b9c9039
add HVU Loss
Oct 3, 2020
abf64c5
update arg name
Oct 3, 2020
24d3f9a
fix bug
Oct 3, 2020
e00b5c6
fix bug
Oct 3, 2020
c28a001
put HVU label loading in loading.py
Oct 4, 2020
b1127d8
update HVU Dataset, so that it accept both video and rawframe annotat…
Oct 4, 2020
4a72546
fix bug
Oct 4, 2020
815e5f1
update
Oct 4, 2020
828e1c1
fix bug
Oct 5, 2020
65c19b3
fix bug
Oct 6, 2020
31c7df5
fix bug
Oct 6, 2020
48c7ac4
fix bug
Oct 6, 2020
bff20c0
fix bug in HVU Loss
Oct 7, 2020
ee87824
resolve comments
Oct 8, 2020
e4d8948
resolve comments
Oct 10, 2020
82c3115
add option: `reduction = "sum"`
Oct 10, 2020
f7a3a20
bug fix
Oct 10, 2020
cae8439
deal with multi class label in a lazy style
Oct 10, 2020
d36ba78
bug fix
Oct 10, 2020
65dbc19
Merge branch 'master' into hvu_training
kennymckormick Oct 10, 2020
0003de8
fix evaluation bug
Oct 10, 2020
41d44b7
fix testing
Oct 10, 2020
bab5921
add testing for HVULoss and LoadHVULabel
Oct 11, 2020
61d1775
fix testing bug
Oct 11, 2020
4e64dd0
bug fix
Oct 11, 2020
b1bbffb
add testing for HVU Dataset
Oct 13, 2020
aecceeb
add mean_average_precision, rename the old one as mmit_mean_average_p…
Oct 13, 2020
037c0c6
fix hvu dataset testing
Oct 13, 2020
a72a2c7
fix testing bug
Oct 13, 2020
c5f00a8
resolve comments
Oct 16, 2020
05575c1
update changelog
Oct 16, 2020
eb09070
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
Oct 19, 2020
332bad7
Merge branch 'master' into hvu_training
Oct 19, 2020
0d67e97
update changelog
Oct 19, 2020
86ac35d
Merge branch 'master' into hvu_training
kennymckormick Oct 21, 2020
6c6eda2
fix linting
Oct 21, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion configs/recognition/tsn/tsn_r101_1x1x5_50e_mmit_rgb.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@
lr_config = dict(policy='step', step=[20, 40])
total_epochs = 50
checkpoint_config = dict(interval=5)
evaluation = dict(interval=5, metrics=['mean_average_precision'])
evaluation = dict(interval=5, metrics=['mmit_mean_average_precision'])
# yapf:disable
log_config = dict(
interval=20,
Expand Down
1 change: 1 addition & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

**Improvements**
- Set default values of 'average_clips' in each config file so that there is no need to set it explicitly during testing in most cases ([#232](https://github.com/open-mmlab/mmaction2/pull/232))
- Extend HVU datatools to generate individual file list for each tag category ([#258](https://github.com/open-mmlab/mmaction2/pull/258))
- Support data preparation for Kinetics-600 and Kinetics-700 ([#254](https://github.com/open-mmlab/mmaction2/pull/254))
- Add `cfg-options` in arguments to override some settings in the used config for convenience ([#212](https://github.com/open-mmlab/mmaction2/pull/212))

Expand Down
2 changes: 1 addition & 1 deletion docs/getting_started.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [--out ${RESULT_FILE}] [-

Optional arguments:
- `RESULT_FILE`: Filename of the output results. If not specified, the results will not be saved to a file.
- `EVAL_METRICS`: Items to be evaluated on the results. Allowed values depend on the dataset, e.g., `top_k_accuracy`, `mean_class_accuracy` are available for all datasets in recognition, `mean_average_precision` for Multi-Moments in Time, `AR@AN` for ActivityNet, etc.
- `EVAL_METRICS`: Items to be evaluated on the results. Allowed values depend on the dataset, e.g., `top_k_accuracy`, `mean_class_accuracy` are available for all datasets in recognition, `mmit_mean_average_precision` for Multi-Moments in Time, `mean_average_precision` for Multi-Moments in Time and HVU single category. `AR@AN` for ActivityNet, etc.
- `--gpu-collect`: If specified, recognition results will be collected using gpu communication. Otherwise, it will save the results on different gpus to `TMPDIR` and collect them by the rank 0 worker.
- `TMPDIR`: Temporary directory used for collecting results from multiple workers, available when `--gpu-collect` is not specified.
- `OPTIONS`: Custom options used for evaluation. Allowed values depend on the arguments of the `evaluate` function in dataset.
Expand Down
5 changes: 3 additions & 2 deletions mmaction/core/evaluation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
average_recall_at_avg_proposals, confusion_matrix,
get_weighted_score, interpolated_precision_recall,
mean_average_precision, mean_class_accuracy,
pairwise_temporal_iou, softmax, top_k_accuracy)
mmit_mean_average_precision, pairwise_temporal_iou,
softmax, top_k_accuracy)
from .eval_detection import ActivityNetDetection
from .eval_hooks import DistEvalHook, EvalHook

Expand All @@ -11,5 +12,5 @@
'confusion_matrix', 'mean_average_precision', 'get_weighted_score',
'average_recall_at_avg_proposals', 'pairwise_temporal_iou',
'average_precision_at_temporal_iou', 'ActivityNetDetection', 'softmax',
'interpolated_precision_recall'
'interpolated_precision_recall', 'mmit_mean_average_precision'
]
33 changes: 31 additions & 2 deletions mmaction/core/evaluation/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,17 +104,46 @@ def top_k_accuracy(scores, labels, topk=(1, )):
return res


def mmit_mean_average_precision(scores, labels):
kennymckormick marked this conversation as resolved.
Show resolved Hide resolved
"""Mean average precision for multi-label recognition. Used for reporting
MMIT style mAP on Multi-Moments in Times. The difference is that this
method calculates average-precision for each sample and averages them among
samples.

Args:
scores (list[np.ndarray]): Prediction scores of different classes for
each sample.
labels (list[np.ndarray]): Ground truth many-hot vector for each
sample.

Returns:
np.float: The MMIT style mean average precision.
"""
results = []
for i in range(len(scores)):
precision, recall, _ = binary_precision_recall_curve(
scores[i], labels[i])
ap = -np.sum(np.diff(recall) * np.array(precision)[:-1])
results.append(ap)
return np.mean(results)


def mean_average_precision(scores, labels):
"""Mean average precision for multi-label recognition.

Args:
scores (list[np.ndarray]): Prediction scores for each class.
labels (list[np.ndarray]): Ground truth many-hot vector.
scores (list[np.ndarray]): Prediction scores of different classes for
each sample.
labels (list[np.ndarray]): Ground truth many-hot vector for each
sample.

Returns:
np.float: The mean average precision.
"""
results = []
scores = np.stack(scores).T
labels = np.stack(labels).T

for i in range(len(scores)):
precision, recall, _ = binary_precision_recall_curve(
scores[i], labels[i])
Expand Down
16 changes: 9 additions & 7 deletions mmaction/core/evaluation/eval_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,10 @@ class EvalHook(Hook):
checkpoint during evaluation when ``save_best`` is set to True.
Options are the evaluation metrics to the test dataset. e.g.,
``top1_acc``, ``top5_acc``, ``mean_class_accuracy``,
``mean_average_precision`` for action recognition dataset
(RawframeDataset and VideoDataset). ``AR@AN``, ``auc`` for action
localization dataset (ActivityNetDataset). Default: `top1_acc`.
``mean_average_precision``, ``mmit_mean_average_precision``
for action recognition dataset (RawframeDataset and VideoDataset).
``AR@AN``, ``auc`` for action localization dataset
(ActivityNetDataset). Default: `top1_acc`.
rule (str | None): Comparison rule for best score. If set to None,
it will infer a reasonable rule. Default: 'None'.
eval_kwargs (dict, optional): Arguments for evaluation.
Expand Down Expand Up @@ -144,10 +145,11 @@ class DistEvalHook(EvalHook):
key_indicator (str | None): Key indicator to measure the best
checkpoint during evaluation when ``save_best`` is set to True.
Options are the evaluation metrics to the test dataset. e.g.,
``top1_acc``, ``top5_acc``, ``mean_class_accuracy``,
``mean_average_precision`` for action recognition dataset
(RawframeDataset and VideoDataset). ``AR@AN``, ``auc`` for action
localization dataset (ActivityNetDataset). Default: `top1_acc`.
``top1_acc``, ``top5_acc``, ``mean_class_accuracy``,
``mean_average_precision``, ``mmit_mean_average_precision``
for action recognition dataset (RawframeDataset and VideoDataset).
``AR@AN``, ``auc`` for action localization dataset
(ActivityNetDataset). Default: `top1_acc`.
rule (str | None): Comparison rule for best score. If set to None,
it will infer a reasonable rule. Default: 'None'.
eval_kwargs (dict, optional): Arguments for evaluation.
Expand Down
4 changes: 3 additions & 1 deletion mmaction/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
from .base import BaseDataset
from .builder import build_dataloader, build_dataset
from .dataset_wrappers import RepeatDataset
from .hvu_dataset import HVUDataset
from .rawframe_dataset import RawframeDataset
from .ssn_dataset import SSNDataset
from .video_dataset import VideoDataset

__all__ = [
'VideoDataset', 'build_dataloader', 'build_dataset', 'RepeatDataset',
'RawframeDataset', 'BaseDataset', 'ActivityNetDataset', 'SSNDataset'
'RawframeDataset', 'BaseDataset', 'ActivityNetDataset', 'SSNDataset',
'HVUDataset'
]
23 changes: 18 additions & 5 deletions mmaction/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,15 +74,12 @@ def load_json_annotations(self):
num_videos = len(video_infos)
path_key = 'frame_dir' if 'frame_dir' in video_infos[0] else 'filename'
for i in range(num_videos):
path_value = video_infos[i][path_key]
if self.data_prefix is not None:
path_value = video_infos[i][path_key]
path_value = osp.join(self.data_prefix, path_value)
video_infos[i][path_key] = path_value
video_infos[i][path_key] = path_value
if self.multi_class:
assert self.num_classes is not None
onehot = torch.zeros(self.num_classes)
onehot[video_infos[i]['label']] = 1.
video_infos[i]['label'] = onehot
else:
assert len(video_infos[i]['label']) == 1
video_infos[i]['label'] = video_infos[i]['label'][0]
Expand Down Expand Up @@ -111,13 +108,29 @@ def prepare_train_frames(self, idx):
results = copy.deepcopy(self.video_infos[idx])
results['modality'] = self.modality
results['start_index'] = self.start_index

# prepare tensor in getitem
# If HVU, type(results['label']) is dict
if self.multi_class and type(results['label']) is list:
onehot = torch.zeros(self.num_classes)
onehot[results['label']] = 1.
results['label'] = onehot

return self.pipeline(results)

def prepare_test_frames(self, idx):
"""Prepare the frames for testing given the index."""
results = copy.deepcopy(self.video_infos[idx])
results['modality'] = self.modality
results['start_index'] = self.start_index

# prepare tensor in getitem
# If HVU, type(results['label']) is dict
if self.multi_class and type(results['label']) is list:
onehot = torch.zeros(self.num_classes)
onehot[results['label']] = 1.
results['label'] = onehot

return self.pipeline(results)

def __len__(self):
Expand Down
179 changes: 179 additions & 0 deletions mmaction/datasets/hvu_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
import os.path as osp

import mmcv
import numpy as np
from mmcv.utils import print_log

from ..core import mean_average_precision
from .base import BaseDataset
from .registry import DATASETS


@DATASETS.register_module()
class HVUDataset(BaseDataset):
"""HVU dataset, which supports the recognition tags of multiple categories.
Accept both video annotation files or rawframe annotation files.

The dataset loads videos or raw frames and applies specified transforms to
return a dict containing the frame tensors and other information.

The ann_file is a json file with multiple dictionaries, and each dictionary
indicates a sample video with the filename and tags, the tags are organized
as different categories. Example of a video dictionary:

.. code-block:: txt

{
'filename': 'gD_G1b0wV5I_001015_001035.mp4',
'label': {
'concept': [250, 131, 42, 51, 57, 155, 122],
'object': [1570, 508],
'event': [16],
'action': [180],
'scene': [206]
}
}

Example of a rawframe dictionary:

.. code-block:: txt

{
'frame_dir': 'gD_G1b0wV5I_001015_001035',
'total_frames': 61
'label': {
'concept': [250, 131, 42, 51, 57, 155, 122],
'object': [1570, 508],
'event': [16],
'action': [180],
'scene': [206]
}
}


Args:
ann_file (str): Path to the annotation file, should be a json file.
pipeline (list[dict | callable]): A sequence of data transforms.
tag_categories (list[str]): List of category names of tags.
tag_category_nums (list[int]): List of number of tags in each category.
filename_tmpl: Template for each filename. `filename_tmpl is None`
indicates video dataset is used. Default: None.
**kwargs: Keyword arguments for ``BaseDataset``.
"""
kennymckormick marked this conversation as resolved.
Show resolved Hide resolved

def __init__(self,
ann_file,
pipeline,
tag_categories,
tag_category_nums,
filename_tmpl=None,
**kwargs):
assert len(tag_categories) == len(tag_category_nums)
self.tag_categories = tag_categories
self.tag_category_nums = tag_category_nums
self.filename_tmpl = filename_tmpl
self.num_categories = len(self.tag_categories)
self.num_tags = sum(self.tag_category_nums)
self.category2num = dict(zip(tag_categories, tag_category_nums))
self.start_idx = [0]
for i in range(self.num_categories - 1):
self.start_idx.append(self.start_idx[-1] +
self.tag_category_nums[i])
self.category2startidx = dict(zip(tag_categories, self.start_idx))
self.start_index = kwargs.pop('start_index', 0)
self.dataset_type = None
super().__init__(
ann_file, pipeline, start_index=self.start_index, **kwargs)

def load_annotations(self):
"""Load annotation file to get video information."""
assert self.ann_file.endswith('.json')
return self.load_json_annotations()

def load_json_annotations(self):
video_infos = mmcv.load(self.ann_file)
num_videos = len(video_infos)

video_info0 = video_infos[0]
assert ('filename' in video_info0) != ('frame_dir' in video_info0)
path_key = 'filename' if 'filename' in video_info0 else 'frame_dir'
self.dataset_type = 'video' if path_key == 'filename' else 'rawframe'
if self.dataset_type == 'rawframe':
assert self.filename_tmpl is not None

for i in range(num_videos):
path_value = video_infos[i][path_key]
if self.data_prefix is not None:
path_value = osp.join(self.data_prefix, path_value)
video_infos[i][path_key] = path_value

# We will convert label to torch tensors in the pipeline
video_infos[i]['categories'] = self.tag_categories
video_infos[i]['category_nums'] = self.tag_category_nums
if self.dataset_type == 'rawframe':
video_infos[i]['filename_tmpl'] = self.filename_tmpl
video_infos[i]['start_index'] = self.start_index
video_infos[i]['modality'] = self.modality

return video_infos

@staticmethod
def label2array(num, label):
arr = np.zeros(num, dtype=np.float32)
arr[label] = 1.
return arr

def evaluate(self, results, metrics='mean_average_precision', logger=None):
"""Evaluation in HVU Video Dataset. We only support evaluating mAP for
each tag categories. Since some tag categories are missing for some
videos, we can not evaluate mAP for all tags.

Args:
results (list): Output results.
metrics (str | sequence[str]): Metrics to be performed.
Defaults: 'mean_average_precision'.
logger (logging.Logger | None): Logger for recording.
Default: None.

Return:
dict: Evaluation results dict.
"""
if not isinstance(results, list):
raise TypeError(f'results must be a list, but got {type(results)}')
assert len(results) == len(self), (
f'The length of results is not equal to the dataset len: '
f'{len(results)} != {len(self)}')

metrics = metrics if isinstance(metrics, (list, tuple)) else [metrics]

# There should be only one metric in the metrics list:
# 'mean_average_precision'
assert len(metrics) == 1
metric = metrics[0]
assert metric == 'mean_average_precision'

gt_labels = [ann['label'] for ann in self.video_infos]

eval_results = {}
for i, category in enumerate(self.tag_categories):

start_idx = self.category2startidx[category]
num = self.category2num[category]
preds = [
result[start_idx:start_idx + num]
for video_idx, result in enumerate(results)
if category in gt_labels[video_idx]
]
gts = [
gt_label[category] for gt_label in gt_labels
if category in gt_label
]

gts = [self.label2array(num, item) for item in gts]

mAP = mean_average_precision(preds, gts)
eval_results[f'{category}_mAP'] = mAP
log_msg = f'\n{category}_mAP\t{mAP:.4f}'
print_log(log_msg, logger=logger)

return eval_results
4 changes: 2 additions & 2 deletions mmaction/datasets/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .formating import (Collect, FormatShape, ImageToTensor, ToDataContainer,
ToTensor, Transpose)
from .loading import (DecordDecode, DecordInit, DenseSampleFrames,
FrameSelector, GenerateLocalizationLabels,
FrameSelector, GenerateLocalizationLabels, LoadHVULabel,
LoadLocalizationFeature, LoadProposals, OpenCVDecode,
OpenCVInit, PyAVDecode, PyAVInit, RawFrameDecode,
SampleFrames, SampleProposalFrames,
Expand All @@ -21,5 +21,5 @@
'GenerateLocalizationLabels', 'LoadLocalizationFeature', 'LoadProposals',
'DecordInit', 'OpenCVInit', 'PyAVInit', 'SampleProposalFrames',
'UntrimmedSampleFrames', 'RawFrameDecode', 'DecordInit', 'OpenCVInit',
'PyAVInit', 'SampleProposalFrames', 'ColorJitter'
'PyAVInit', 'SampleProposalFrames', 'ColorJitter', 'LoadHVULabel'
]
Loading