Add SSN model. (#55)
Co-authored-by: lizz <innerlee@users.noreply.github.com>
JackyTown and innerlee committed Aug 31, 2020
1 parent f183c1c commit 1ae0122
Showing 8 changed files with 776 additions and 12 deletions.
5 changes: 4 additions & 1 deletion mmaction/models/heads/__init__.py
@@ -1,7 +1,10 @@
 from .base import BaseHead
 from .i3d_head import I3DHead
 from .slowfast_head import SlowFastHead
+from .ssn_head import SSNHead
 from .tsm_head import TSMHead
 from .tsn_head import TSNHead

-__all__ = ['TSNHead', 'I3DHead', 'BaseHead', 'TSMHead', 'SlowFastHead']
+__all__ = [
+    'TSNHead', 'I3DHead', 'BaseHead', 'TSMHead', 'SlowFastHead', 'SSNHead'
+]
413 changes: 413 additions & 0 deletions mmaction/models/heads/ssn_head.py

Large diffs are not rendered by default.
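Because the 413-line ssn_head.py diff is collapsed, here is only a rough, hypothetical interface stub of SSNHead reconstructed from how ssn.py (further down) calls it. None of this is the committed implementation, and the real head takes many more constructor arguments.

# Hypothetical stub, inferred from the call sites in mmaction/models/localizers/ssn.py.
# The committed SSNHead (mmaction/models/heads/ssn_head.py) is not reproduced here.
class SSNHead:
    num_classes: int      # SSN.forward_test reshapes bbox_preds with this
    consensus = None      # exposes .num_multipliers, handed to prepare_test_fc()

    def init_weights(self):
        ...

    def prepare_test_fc(self, num_multipliers):
        # Returns True once the test-time fully connected layers are prepared.
        ...

    def forward(self, inputs, test_mode=False):
        # Training call:  (features, proposal_scale_factor)
        #   -> (activity_scores, completeness_scores, bbox_preds)
        # Testing call:   (features, proposal_tick_list, scale_factor_list)
        #   -> (output, activity_scores, completeness_scores, bbox_preds)
        ...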

3 changes: 2 additions & 1 deletion mmaction/models/localizers/__init__.py
@@ -1,5 +1,6 @@
 from .base import BaseLocalizer
 from .bmn import BMN
 from .bsn import PEM, TEM
+from .ssn import SSN

-__all__ = ['PEM', 'TEM', 'BMN', 'BaseLocalizer']
+__all__ = ['PEM', 'TEM', 'BMN', 'SSN', 'BaseLocalizer']
32 changes: 25 additions & 7 deletions mmaction/models/localizers/base.py
@@ -5,6 +5,8 @@
 import torch.distributed as dist
 import torch.nn as nn

+from .. import builder
+

 class BaseLocalizer(nn.Module, metaclass=ABCMeta):
     """Base class for localizers.
@@ -14,12 +16,30 @@ class BaseLocalizer(nn.Module, metaclass=ABCMeta):
     Methods:``forward_test``, supporting to forward when testing.
     """

-    def __init__(self):
+    def __init__(self, backbone, cls_head, train_cfg=None, test_cfg=None):
         super().__init__()
+        self.backbone = builder.build_backbone(backbone)
+        self.cls_head = builder.build_head(cls_head)
+
+        self.train_cfg = train_cfg
+        self.test_cfg = test_cfg
+        self.init_weights()

     def init_weights(self):
         """Weight initialization for model."""
-        pass
+        self.backbone.init_weights()
+        self.cls_head.init_weights()

+    def extract_feat(self, imgs):
+        """Extract features through a backbone.
+
+        Args:
+            imgs (torch.Tensor): The input images.
+        Returns:
+            torch.Tensor: The extracted features.
+        """
+        x = self.backbone(imgs)
+        return x
+
     @abstractmethod
     def forward_train(self, imgs, labels):
@@ -31,14 +51,12 @@ def forward_test(self, imgs):
         """Defines the computation performed at testing."""
         pass

-    def forward(self, imgs, label=None, return_loss=True):
+    def forward(self, imgs, return_loss=True, **kwargs):
         """Define the computation performed at every call."""
         if return_loss:
-            if label is None:
-                raise ValueError('Label should not be None.')
-            return self.forward_train(imgs, label)
+            return self.forward_train(imgs, **kwargs)
         else:
-            return self.forward_test(imgs)
+            return self.forward_test(imgs, **kwargs)

     @staticmethod
     def _parse_losses(losses):
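With this change BaseLocalizer owns model assembly: it builds backbone and cls_head from config dicts via the builder, initializes their weights, and dispatches forward() to forward_train or forward_test depending on return_loss. Below is a minimal sketch of a subclass, purely illustrative (the class name, loss, and placeholder configs are invented, not part of this commit); note that BMN, TEM and PEM further down intentionally bypass this constructor with super(BaseLocalizer, self).__init__() because they assemble their own layers rather than a backbone/cls_head pair.

import torch

from mmaction.models.localizers.base import BaseLocalizer
from mmaction.models.registry import LOCALIZERS


@LOCALIZERS.register_module()
class ToyLocalizer(BaseLocalizer):
    """Hypothetical localizer, only to illustrate the BaseLocalizer contract."""

    def forward_train(self, imgs, labels, **kwargs):
        # extract_feat() is inherited from BaseLocalizer and simply runs the backbone.
        cls_score = self.cls_head(self.extract_feat(imgs))
        return dict(loss_cls=torch.nn.functional.cross_entropy(cls_score, labels))

    def forward_test(self, imgs, **kwargs):
        return self.cls_head(self.extract_feat(imgs)).cpu().numpy()


# Illustrative usage only; the backbone/cls_head config dicts must name modules
# that are registered in mmaction:
# model = ToyLocalizer(backbone=dict(...), cls_head=dict(...))
# losses = model(imgs, labels=labels, return_loss=True)   # -> forward_train
# results = model(imgs, return_loss=False)                # -> forward_test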
2 changes: 1 addition & 1 deletion mmaction/models/localizers/bmn.py
@@ -50,7 +50,7 @@ def __init__(self,
                  hidden_dim_1d=256,
                  hidden_dim_2d=128,
                  hidden_dim_3d=512):
-        super().__init__()
+        super(BaseLocalizer, self).__init__()

         self.tscale = temporal_dim
         self.boundary_ratio = boundary_ratio
5 changes: 3 additions & 2 deletions mmaction/models/localizers/bsn.py
@@ -45,7 +45,8 @@ def __init__(self,
                  conv1_ratio=1,
                  conv2_ratio=1,
                  conv3_ratio=0.01):
-        super().__init__()
+        super(BaseLocalizer, self).__init__()
+
         self.temporal_dim = temporal_dim
         self.boundary_ratio = boundary_ratio
         self.feat_dim = tem_feat_dim
@@ -264,7 +265,7 @@ def __init__(self,
                  fc1_ratio=0.1,
                  fc2_ratio=0.1,
                  output_dim=1):
-        super().__init__()
+        super(BaseLocalizer, self).__init__()

         self.feat_dim = pem_feat_dim
         self.hidden_dim = pem_hidden_dim
129 changes: 129 additions & 0 deletions mmaction/models/localizers/ssn.py
@@ -0,0 +1,129 @@
import torch
import torch.nn as nn

from .. import builder
from ..registry import LOCALIZERS
from .base import BaseLocalizer


@LOCALIZERS.register_module()
class SSN(BaseLocalizer):
"""Temporal Action Detection with Structured Segment Networks.
Args:
backbone (dict): Config for building backbone.
cls_head (dict): Config for building classification head.
in_channels (int): Number of channels for input data.
Default: 3.
spatial_type (str): Type of spatial pooling.
Default: 'avg'.
dropout_ratio (float): Ratio of dropout.
Default: 0.5.
loss_cls (dict): Config for building loss.
Default: ``dict(type='SSNLoss')``.
train_cfg (dict): Config for training. Default: None.
test_cfg (dict): Config for testing. Default: None.
"""

def __init__(self,
backbone,
cls_head,
in_channels=3,
spatial_type='avg',
dropout_ratio=0.5,
loss_cls=dict(type='SSNLoss'),
train_cfg=None,
test_cfg=None):

super().__init__(backbone, cls_head, train_cfg, test_cfg)

self.is_test_prepared = False
self.in_channels = in_channels

self.spatial_type = spatial_type
if self.spatial_type == 'avg':
self.pool = nn.AvgPool2d((7, 7), stride=1, padding=0)
elif self.spatial_type == 'max':
self.pool = nn.MaxPool2d((7, 7), stride=1, padding=0)
else:
self.pool = None

self.dropout_ratio = dropout_ratio
if self.dropout_ratio != 0:
self.dropout = nn.Dropout(p=self.dropout_ratio)
else:
self.dropout = None
self.loss_cls = builder.build_loss(loss_cls)

def forward_train(self, imgs, proposal_scale_factor, proposal_type,
proposal_labels, reg_targets, **kwargs):
"""Define the computation performed at every call when training."""
imgs = imgs.reshape((-1, self.in_channels) + imgs.shape[4:])
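        # The reshape above folds the leading sample/proposal dimensions into the
        # batch axis so the 2D backbone receives a plain (N, in_channels, H, W) batch.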

x = self.extract_feat(imgs)

if self.pool:
x = self.pool(x)
if self.dropout is not None:
x = self.dropout(x)

activity_scores, completeness_scores, bbox_preds = self.cls_head(
(x, proposal_scale_factor))

loss = self.loss_cls(activity_scores, completeness_scores, bbox_preds,
proposal_type, proposal_labels, reg_targets,
self.train_cfg)
loss_dict = dict(**loss)

return loss_dict

def forward_test(self, imgs, relative_proposal_list, scale_factor_list,
proposal_tick_list, reg_norm_consts, **kwargs):
"""Define the computation performed at every call when testing."""
num_crops = imgs.shape[0]
imgs = imgs.reshape((num_crops, -1, self.in_channels) + imgs.shape[3:])
num_ticks = imgs.shape[1]

output = []
minibatch_size = self.test_cfg.ssn.sampler.batch_size
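        # Run feature extraction over the proposal ticks in chunks of
        # `minibatch_size` to bound peak memory; the chunks are concatenated below.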
for idx in range(0, num_ticks, minibatch_size):
chunk = imgs[:, idx:idx +
minibatch_size, :, :, :].view((-1, ) + imgs.shape[2:])
x = self.extract_feat(chunk)
if self.pool:
x = self.pool(x)
# Merge crop to save memory.
x = x.reshape((num_crops, x.size(0) // num_crops, -1)).mean(dim=0)
output.append(x)
output = torch.cat(output, dim=0)

relative_proposal_list = relative_proposal_list.squeeze(0)
proposal_tick_list = proposal_tick_list.squeeze(0)
scale_factor_list = scale_factor_list.squeeze(0)
reg_norm_consts = reg_norm_consts.squeeze(0)

if not self.is_test_prepared:
self.is_test_prepared = self.cls_head.prepare_test_fc(
self.cls_head.consensus.num_multipliers)

(output, activity_scores, completeness_scores,
bbox_preds) = self.cls_head(
(output, proposal_tick_list, scale_factor_list), test_mode=True)

if bbox_preds is not None:
bbox_preds = bbox_preds.view(-1, self.cls_head.num_classes, 2)
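            # Undo regression-target normalization: scale by reg_norm_consts[1]
            # and shift by reg_norm_consts[0] (presumably the stds and means).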
bbox_preds[:, :, 0] = (
bbox_preds[:, :, 0] * reg_norm_consts[1, 0] +
reg_norm_consts[0, 0])
bbox_preds[:, :, 1] = (
bbox_preds[:, :, 1] * reg_norm_consts[1, 1] +
reg_norm_consts[0, 1])

return (relative_proposal_list.cpu().numpy(),
activity_scores.cpu().numpy(),
completeness_scores.cpu().numpy(),
bbox_preds.cpu().numpy())
else:
return (relative_proposal_list.cpu().numpy(),
activity_scores.cpu().numpy(),
completeness_scores.cpu().numpy(), None)
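Finally, a hedged sketch of how such a model could be described in a config file. The backbone, dropout value, and head fields are placeholders rather than the settings shipped with this commit, while the test_cfg entry mirrors what forward_test above actually reads (mmcv's Config turns these nested dicts into attribute-accessible objects).

# Hypothetical config sketch, not a configuration file from this commit.
model = dict(
    type='SSN',
    backbone=dict(type='ResNet', depth=50),  # placeholder; see the real SSN configs
    spatial_type='avg',
    dropout_ratio=0.8,                       # placeholder value
    loss_cls=dict(type='SSNLoss'),
    cls_head=dict(type='SSNHead'))           # the committed SSNHead takes more arguments
# forward_test reads test_cfg.ssn.sampler.batch_size, so a test config must provide it:
test_cfg = dict(ssn=dict(sampler=dict(batch_size=8)))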