Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] HVU Training #235

Merged
merged 38 commits into from
Oct 21, 2020
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
7a3d259
add hvu video dataset
Oct 1, 2020
5bb55ae
add aux_info interface
Oct 2, 2020
b9c9039
add HVU Loss
Oct 3, 2020
abf64c5
update arg name
Oct 3, 2020
24d3f9a
fix bug
Oct 3, 2020
e00b5c6
fix bug
Oct 3, 2020
c28a001
put HVU label loading in loading.py
Oct 4, 2020
b1127d8
update HVU Dataset, so that it accept both video and rawframe annotat…
Oct 4, 2020
4a72546
fix bug
Oct 4, 2020
815e5f1
update
Oct 4, 2020
828e1c1
fix bug
Oct 5, 2020
65c19b3
fix bug
Oct 6, 2020
31c7df5
fix bug
Oct 6, 2020
48c7ac4
fix bug
Oct 6, 2020
bff20c0
fix bug in HVU Loss
Oct 7, 2020
ee87824
resolve comments
Oct 8, 2020
e4d8948
resolve comments
Oct 10, 2020
82c3115
add option: `reduction = "sum"`
Oct 10, 2020
f7a3a20
bug fix
Oct 10, 2020
cae8439
deal with multi class label in a lazy style
Oct 10, 2020
d36ba78
bug fix
Oct 10, 2020
65dbc19
Merge branch 'master' into hvu_training
kennymckormick Oct 10, 2020
0003de8
fix evaluation bug
Oct 10, 2020
41d44b7
fix testing
Oct 10, 2020
bab5921
add testing for HVULoss and LoadHVULabel
Oct 11, 2020
61d1775
fix testing bug
Oct 11, 2020
4e64dd0
bug fix
Oct 11, 2020
b1bbffb
add testing for HVU Dataset
Oct 13, 2020
aecceeb
add mean_average_precision, rename the old one as mmit_mean_average_p…
Oct 13, 2020
037c0c6
fix hvu dataset testing
Oct 13, 2020
a72a2c7
fix testing bug
Oct 13, 2020
c5f00a8
resolve comments
Oct 16, 2020
05575c1
update changelog
Oct 16, 2020
eb09070
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
Oct 19, 2020
332bad7
Merge branch 'master' into hvu_training
Oct 19, 2020
0d67e97
update changelog
Oct 19, 2020
86ac35d
Merge branch 'master' into hvu_training
kennymckormick Oct 21, 2020
6c6eda2
fix linting
Oct 21, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion mmaction/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
from .base import BaseDataset
from .builder import build_dataloader, build_dataset
from .dataset_wrappers import RepeatDataset
from .hvu_dataset import HVUDataset
from .rawframe_dataset import RawframeDataset
from .ssn_dataset import SSNDataset
from .video_dataset import VideoDataset

__all__ = [
'VideoDataset', 'build_dataloader', 'build_dataset', 'RepeatDataset',
'RawframeDataset', 'BaseDataset', 'ActivityNetDataset', 'SSNDataset'
'RawframeDataset', 'BaseDataset', 'ActivityNetDataset', 'SSNDataset',
'HVUDataset'
]
4 changes: 2 additions & 2 deletions mmaction/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,10 @@ def load_json_annotations(self):
num_videos = len(video_infos)
path_key = 'frame_dir' if 'frame_dir' in video_infos[0] else 'filename'
for i in range(num_videos):
path_value = video_infos[i][path_key]
if self.data_prefix is not None:
path_value = video_infos[i][path_key]
path_value = osp.join(self.data_prefix, path_value)
video_infos[i][path_key] = path_value
video_infos[i][path_key] = path_value
if self.multi_class:
assert self.num_classes is not None
onehot = torch.zeros(self.num_classes)
Expand Down
186 changes: 186 additions & 0 deletions mmaction/datasets/hvu_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
import os.path as osp

import mmcv
import numpy as np
from mmcv.utils import print_log

from ..core import mean_average_precision
from .base import BaseDataset
from .registry import DATASETS


@DATASETS.register_module()
class HVUDataset(BaseDataset):
"""HVU dataset, which support the recognition tags of multiple categories.
kennymckormick marked this conversation as resolved.
Show resolved Hide resolved
Accept both video annotation files or rawframe annotation files.

The dataset loads videos or raw frames and apply specified transforms to
kennymckormick marked this conversation as resolved.
Show resolved Hide resolved
return a dict containing the frame tensors and other information.

The ann_file is a json file with multiple dictionaries, and each dictionary
indicates a sample video with the filename and tags, the tags are organized
as different categories. Example of a video dictionary:

.. code-block:: txt

{
'filename': 'gD_G1b0wV5I_001015_001035.mp4',
'label': {
'concept': [250, 131, 42, 51, 57, 155, 122],
'object': [1570, 508],
'event': [16],
'action': [180],
'scene': [206]
}
}

Example of a rawframe dictionary:

.. code-block:: txt

{
'frame_dir': 'gD_G1b0wV5I_001015_001035',
'total_frames': 61
'label': {
'concept': [250, 131, 42, 51, 57, 155, 122],
'object': [1570, 508],
'event': [16],
'action': [180],
'scene': [206]
}
}


Args:
ann_file (str): Path to the annotation file, should be a json file.
pipeline (list[dict | callable]): A sequence of data transforms.
tag_categories (list[str]): List of category names of tags.
tag_category_nums (list[int]): List of number of tags in each category.
**kwargs: Keyword arguments for ``BaseDataset``.
"""
kennymckormick marked this conversation as resolved.
Show resolved Hide resolved

def __init__(self,
ann_file,
pipeline,
tag_categories,
tag_category_nums,
filename_tmpl=None,
**kwargs):
assert len(tag_categories) == len(tag_category_nums)
self.tag_categories = tag_categories
self.tag_category_nums = tag_category_nums
self.filename_tmpl = filename_tmpl
self.num_categories = len(self.tag_categories)
self.num_tags = sum(self.tag_category_nums)
self.category2num = {
k: v
for k, v in zip(tag_categories, tag_category_nums)
}
kennymckormick marked this conversation as resolved.
Show resolved Hide resolved
self.start_idx = [0]
for i in range(self.num_categories - 1):
self.start_idx.append(self.start_idx[-1] +
self.tag_category_nums[i])
self.category2startidx = {
k: v
for k, v in zip(tag_categories, self.start_idx)
}
kennymckormick marked this conversation as resolved.
Show resolved Hide resolved
self.start_index = 0
if 'start_index' in kwargs:
self.start_index = kwargs.pop('start_index')
kennymckormick marked this conversation as resolved.
Show resolved Hide resolved
self.dataset_type = None
super().__init__(
ann_file, pipeline, start_index=self.start_index, **kwargs)

def load_annotations(self):
"""Load annotation file to get video information."""
assert self.ann_file.endswith('.json')
return self.load_json_annotations()

def load_json_annotations(self):
video_infos = mmcv.load(self.ann_file)
num_videos = len(video_infos)

video_info0 = video_infos[0]
assert ('filename' in video_info0) != ('frame_dir' in video_info0)
path_key = 'filename' if 'filename' in video_info0 else 'frame_dir'
self.dataset_type = 'video' if path_key == 'filename' else 'rawframe'
if self.dataset_type == 'rawframe':
assert self.filename_tmpl is not None

for i in range(num_videos):
path_value = video_infos[i][path_key]
if self.data_prefix is not None:
path_value = osp.join(self.data_prefix, path_value)
video_infos[i][path_key] = path_value

# We will convert label to torch tensors in the pipeline
video_infos[i]['categories'] = self.tag_categories
video_infos[i]['category_nums'] = self.tag_category_nums
if self.dataset_type == 'rawframe':
video_infos[i]['filename_tmpl'] = self.filename_tmpl
video_infos[i]['start_index'] = self.start_index
video_infos[i]['modality'] = self.modality

return video_infos

def evaluate(self, results, metrics='mean_average_precision', logger=None):
"""Evaluation in HVU Video Dataset. We only support evaluating mAP for
each tag categories. Since some tag categories are missing for some
videos, we can not evaluate mAP for all tags.

Args:
results (list): Output results.
metrics (str | sequence[str]): Metrics to be performed.
Defaults: 'mean_average_precision'.
logger (logging.Logger | None): Logger for recording.
Default: None.

Return:
dict: Evaluation results dict.
"""
if not isinstance(results, list):
raise TypeError(f'results must be a list, but got {type(results)}')
assert len(results) == len(self), (
f'The length of results is not equal to the dataset len: '
f'{len(results)} != {len(self)}')

metrics = metrics if isinstance(metrics, (list, tuple)) else [metrics]

# There should be only one metric in the metrics list:
# 'mean_average_precision'
assert len(metrics) == 1
metric = metrics[0]
assert metric == 'mean_average_precision'

gt_labels = [ann['label'] for ann in self.video_infos]

eval_results = {}
for i, category in enumerate(self.tag_categories):

start_idx = self.category2startidx[category]
num = self.category2num[category]
preds = [
result[start_idx:start_idx + num]
for video_idx, result in enumerate(results)
if category in gt_labels[video_idx]
]
gts = [
gt_label[category]
for video_idx, gt_label in enumerate(gt_labels)
kennymckormick marked this conversation as resolved.
Show resolved Hide resolved
if category in gt_label
]

# convert label list to ndarray
def label2array(label):
kennymckormick marked this conversation as resolved.
Show resolved Hide resolved
arr = np.zeros(num, dtype=np.float32)
arr[label] = 1.
return arr

gts = [label2array(item) for item in gts]

mAP = mean_average_precision(preds, gts)
eval_results[f'{category}_mAP'] = mAP
log_msg = f'\n{category}_mAP\t{mAP:.4f}'
print_log(log_msg, logger=logger)

return eval_results
58 changes: 58 additions & 0 deletions mmaction/datasets/pipelines/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,71 @@

import mmcv
import numpy as np
import torch
from mmcv.fileio import FileClient
from torch.nn.modules.utils import _pair

from ...utils import get_random_string, get_shm_dir, get_thread_id
from ..registry import PIPELINES


@PIPELINES.register_module()
class LoadHVULabel(object):
"""Convert the HVU label from dictionaries to torch tensors.

Required keys are "label", "categories", "category_nums", added or modified
keys are "label", "mask" and "category_mask".
"""

def __init__(self, **kwargs):
self.hvu_initialized = False

def init_hvu_info(self, categories, category_nums):
assert len(categories) == len(category_nums)
self.categories = categories
self.category_nums = category_nums
self.num_categories = len(self.categories)
self.num_tags = sum(self.category_nums)
self.category2num = {k: v for k, v in zip(categories, category_nums)}
kennymckormick marked this conversation as resolved.
Show resolved Hide resolved
self.start_idx = [0]
for i in range(self.num_categories - 1):
self.start_idx.append(self.start_idx[-1] + self.category_nums[i])
self.category2startidx = {
k: v
for k, v in zip(categories, self.start_idx)
}
kennymckormick marked this conversation as resolved.
Show resolved Hide resolved
self.hvu_initialized = True

def __call__(self, results):
"""Convert the label dictionary to 3 tensors: "label", "mask" and
"category_mask".

Args:
results (dict): The resulting dict to be modified and passed
to the next transform in pipeline.
"""

if not self.hvu_initialized:
self.init_hvu_info(results['categories'], results['category_nums'])

onehot = torch.zeros(self.num_tags)
onehot_mask = torch.zeros(self.num_tags)
category_mask = torch.zeros(self.num_categories)

for category, tags in results['label'].items():
category_mask[self.categories.index(category)] = 1.
start_idx = self.category2startidx[category]
category_num = self.category2num[category]
tags = [idx + start_idx for idx in tags]
onehot[tags] = 1.
onehot_mask[start_idx:category_num] = 1.

results['label'] = onehot
results['mask'] = onehot_mask
results['category_mask'] = category_mask
return results


@PIPELINES.register_module()
class SampleFrames(object):
"""Sample frames from the video.
Expand Down
14 changes: 8 additions & 6 deletions mmaction/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from .heads import BaseHead, I3DHead, SlowFastHead, TPNHead, TSMHead, TSNHead
from .localizers import BMN, PEM, TEM
from .losses import (BCELossWithLogits, BinaryLogisticRegressionLoss, BMNLoss,
CrossEntropyLoss, NLLLoss, OHEMHingeLoss, SSNLoss)
CrossEntropyLoss, HVULoss, NLLLoss, OHEMHingeLoss,
SSNLoss)
from .necks import TPN
from .recognizers import BaseRecognizer, recognizer2d, recognizer3d
from .registry import BACKBONES, HEADS, LOCALIZERS, LOSSES, RECOGNIZERS
Expand All @@ -16,9 +17,10 @@
'BACKBONES', 'HEADS', 'RECOGNIZERS', 'build_recognizer', 'build_head',
'build_backbone', 'recognizer2d', 'recognizer3d', 'ResNet', 'ResNet3d',
'ResNet2Plus1d', 'I3DHead', 'TSNHead', 'TSMHead', 'BaseHead',
'BaseRecognizer', 'LOSSES', 'CrossEntropyLoss', 'NLLLoss', 'ResNetTSM',
'ResNet3dSlowFast', 'SlowFastHead', 'Conv2plus1d', 'ResNet3dSlowOnly',
'BCELossWithLogits', 'LOCALIZERS', 'build_localizer', 'PEM', 'TEM',
'BinaryLogisticRegressionLoss', 'BMN', 'BMNLoss', 'build_model',
'OHEMHingeLoss', 'SSNLoss', 'ResNet3dCSN', 'ResNetTIN', 'TPN', 'TPNHead'
'BaseRecognizer', 'LOSSES', 'CrossEntropyLoss', 'NLLLoss', 'HVULoss',
'ResNetTSM', 'ResNet3dSlowFast', 'SlowFastHead', 'Conv2plus1d',
'ResNet3dSlowOnly', 'BCELossWithLogits', 'LOCALIZERS', 'build_localizer',
'PEM', 'TEM', 'BinaryLogisticRegressionLoss', 'BMN', 'BMNLoss',
'build_model', 'OHEMHingeLoss', 'SSNLoss', 'ResNet3dCSN', 'ResNetTIN',
'TPN', 'TPNHead'
]
15 changes: 12 additions & 3 deletions mmaction/models/heads/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,15 @@ def forward(self, x):
"""Defines the computation performed at every call."""
pass

def loss(self, cls_score, labels):
"""Calculate the loss given output ``cls_score`` and target ``labels``.
def loss(self, cls_score, labels, aux_info):
"""Calculate the loss given output ``cls_score``, target ``labels`` and
auxiliary tensors ``aux_info``.

Args:
cls_score (torch.Tensor): The output of the model.
labels (torch.Tensor): The target output of the model.
aux_info (dict[torch.Tensor]): The auxiliary tensors used to
calculate loss.

Returns:
dict: A dict containing field 'loss_cls'(mandatory)
Expand All @@ -94,5 +97,11 @@ def loss(self, cls_score, labels):
labels = ((1 - self.label_smooth_eps) * labels +
self.label_smooth_eps / self.num_classes)

losses['loss_cls'] = self.loss_cls(cls_score, labels)
loss_cls = self.loss_cls(cls_score, labels, aux_info)
# loss_cls may be dictionary or single tensor
if type(loss_cls) is dict:
losses.update(loss_cls)
else:
losses['loss_cls'] = loss_cls

return losses
4 changes: 3 additions & 1 deletion mmaction/models/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
from .binary_logistic_regression_loss import BinaryLogisticRegressionLoss
from .bmn_loss import BMNLoss
from .cross_entropy_loss import BCELossWithLogits, CrossEntropyLoss
from .hvu_loss import HVULoss
from .nll_loss import NLLLoss
from .ohem_hinge_loss import OHEMHingeLoss
from .ssn_loss import SSNLoss

__all__ = [
'BaseWeightedLoss', 'CrossEntropyLoss', 'NLLLoss', 'BCELossWithLogits',
'BinaryLogisticRegressionLoss', 'BMNLoss', 'OHEMHingeLoss', 'SSNLoss'
'BinaryLogisticRegressionLoss', 'BMNLoss', 'OHEMHingeLoss', 'SSNLoss',
'HVULoss'
]
9 changes: 8 additions & 1 deletion mmaction/models/losses/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,11 @@ def forward(self, *args, **kwargs):
Returns:
torch.Tensor: The calculated loss.
"""
return self._forward(*args, **kwargs) * self.loss_weight
ret = self._forward(*args, **kwargs)
if type(ret) is dict:
for k in ret:
if 'loss' in k:
ret[k] *= self.loss_weight
else:
ret *= self.loss_weight
return ret
Loading