Skip to content

Commit

Permalink
Add private func.
Browse files Browse the repository at this point in the history
  • Loading branch information
JackyTown committed Aug 28, 2020
1 parent 0c3757e commit 3cdfd5c
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 118 deletions.
220 changes: 108 additions & 112 deletions mmaction/models/heads/ssn_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,35 @@ def __init__(self, stpp_stage=(1, (1, 2), 1), num_segments_list=(2, 5, 2)):

self.num_segments_list = num_segments_list

def _extract_stage_feature(self, stage_feat, stage_parts, num_multipliers,
                           scale_factors, num_samples):
    """Extract stage feature based on structured temporal pyramid pooling.

    Every entry ``p`` of ``stage_parts`` splits dim 1 of ``stage_feat`` into
    ``p`` consecutive parts. Each part is averaged over dim 1, divided by
    ``num_multipliers`` and, when ``scale_factors`` is given, multiplied by
    the matching per-sample scale factor.

    Args:
        stage_feat (torch.Tensor): Stage features to be STPP. It is sliced
            as ``stage_feat[:, start:stop, :]``, so it must be 3-D with the
            pooled (stage length) axis at dim 1.
        stage_parts (tuple): Config of STPP; an iterable of ints, each the
            number of parts of one pyramid level.
        num_multipliers (int): Total number of parts in the stage.
        scale_factors (torch.Tensor | None): Ratios of the effective
            sampling lengths to augmented lengths. Must hold exactly
            ``num_samples`` elements (it is reshaped with
            ``view(num_samples, 1)``). ``None`` disables the scaling.
        num_samples (int): Number of samples.

    Returns:
        list[torch.Tensor]: One pooled feature per part, ordered level by
        level and part by part. The caller is expected to concatenate
        them.
    """
    # The reshaped scale column does not depend on the part being pooled,
    # so build it once instead of once per part.
    sample_scale = None
    if scale_factors is not None:
        sample_scale = scale_factors.view(num_samples, 1)

    stage_stpp_feat = []
    stage_len = stage_feat.size(1)
    for stage_part in stage_parts:
        # ``stage_part + 1`` part boundaries; the 1e-5 slack keeps the last
        # boundary (== stage_len) inside the half-open arange interval.
        ticks = torch.arange(0, stage_len + 1e-5,
                             stage_len / stage_part).int()
        for i in range(stage_part):
            part_feat = stage_feat[:, ticks[i]:ticks[i + 1], :].mean(
                dim=1) / num_multipliers
            if sample_scale is not None:
                part_feat = part_feat * sample_scale
            stage_stpp_feat.append(part_feat)
    return stage_stpp_feat

def forward(self, x, scale_factors):
"""Defines the computation performed at every call.
Expand All @@ -73,47 +102,19 @@ def forward(self, x, scale_factors):

scale_factors = scale_factors.view(-1, 2)

def extract_stage_feature(stage_feat, stage_parts, num_multipliers,
scale_factors):
"""Extract stage feature based on structured temporal pyramid
pooling.
Args:
stage_feat (torch.Tensor): Stage features to be STPP.
stage_parts (tuple): Config of STPP.
num_multipliers (int): Total number of parts in the stage.
scale_factors (list): Ratios of the effective sampling lengths
to augmented lengths.
Returns:
torch.Tensor: Features of the stage.
"""
stage_stpp_feat = []
stage_len = stage_feat.size(1)
for stage_part in stage_parts:
ticks = torch.arange(0, stage_len + 1e-5,
stage_len / stage_part).int()
for i in range(stage_part):
part_feat = stage_feat[:, ticks[i]:ticks[i + 1], :].mean(
dim=1) / num_multipliers
if scale_factors is not None:
part_feat = (
part_feat * scale_factors.view(num_samples, 1))
stage_stpp_feat.append(part_feat)
return stage_stpp_feat

stage_stpp_feats = []
stage_stpp_feats.extend(
extract_stage_feature(x[:, :x0, :], self.stpp_stages[0],
self.multiplier_list[0], scale_factors[:,
0]))
self._extract_stage_feature(x[:, :x0, :], self.stpp_stages[0],
self.multiplier_list[0],
scale_factors[:, 0], num_samples))
stage_stpp_feats.extend(
extract_stage_feature(x[:, x0:x1, :], self.stpp_stages[1],
self.multiplier_list[1], None))
self._extract_stage_feature(x[:, x0:x1, :], self.stpp_stages[1],
self.multiplier_list[1], None,
num_samples))
stage_stpp_feats.extend(
extract_stage_feature(x[:, x1:, :], self.stpp_stages[2],
self.multiplier_list[2], scale_factors[:,
1]))
self._extract_stage_feature(x[:, x1:, :], self.stpp_stages[2],
self.multiplier_list[2],
scale_factors[:, 1], num_samples))
stpp_feat = torch.cat(stage_stpp_feats, dim=1)

course_feat = x[:, x0:x1, :].mean(dim=1)
Expand All @@ -125,30 +126,30 @@ class STPPTest(nn.Module):
Args:
num_classes (int): Number of classes to be classified.
with_regression (bool): Whether to perform regression or not.
use_regression (bool): Whether to perform regression or not.
Default: True.
stpp_stage (tuple): Config of structured temporal pyramid pooling.
Default: (1, (1, 2), 1).
"""

def __init__(self,
num_classes,
with_regression=True,
use_regression=True,
stpp_stage=(1, (1, 2), 1)):
super().__init__()

self.activity_score_len = num_classes + 1
self.complete_score_len = num_classes
self.reg_score_len = num_classes * 2
self.with_regression = with_regression
self.use_regression = use_regression

starting_parts, starting_multiplier = parse_stage_config(stpp_stage[0])
course_parts, course_multiplier = parse_stage_config(stpp_stage[1])
ending_parts, ending_multiplier = parse_stage_config(stpp_stage[2])

self.num_multipliers = (
starting_multiplier + course_multiplier + ending_multiplier)
if self.with_regression:
if self.use_regression:
self.feat_dim = (
self.activity_score_len + self.num_multipliers *
(self.complete_score_len + self.reg_score_len))
Expand All @@ -166,6 +167,56 @@ def __init__(self,
self.complete_slice.stop, self.complete_slice.stop +
self.reg_score_len * self.num_multipliers)

def _pyramids_pooling(self, out_scores, index, raw_scores, ticks,
                      scale_factors, score_len, stpp_stage):
    """Pool raw scores of one proposal with the STPP pyramid.

    Walks over every stage and every part of ``stpp_stage``. For each part
    that covers at least one row of ``raw_scores``, the rows are averaged,
    multiplied by the stage's scale factor and accumulated into
    ``out_scores[index, :]``. The first stage uses ``scale_factors[0]``,
    the last stage uses ``scale_factors[1]`` and every stage in between
    is left unscaled.

    Args:
        out_scores (torch.Tensor): Scores to be returned; row ``index`` is
            updated in place.
        index (int): Index of output scores.
        raw_scores (torch.Tensor): Raw scores before STPP. Rows are
            selected by ticks; each part owns a block of ``score_len``
            consecutive columns.
        ticks (torch.Tensor): Ticks of raw scores, one boundary more than
            there are stages. Entries must be tensors because ``.float()``
            is called on them.
        scale_factors (list): Ratios of the effective sampling lengths
            to augmented lengths.
        score_len (int): Length of the score.
        stpp_stage (tuple): Config of STPP; each stage is an iterable of
            part counts.

    Returns:
        torch.Tensor: The same ``out_scores`` object that was passed in.
    """
    final_stage = len(stpp_stage) - 1
    # Running index of the current part across all stages; it selects the
    # column block of ``raw_scores`` that belongs to that part.
    part_no = 0
    for stage_no, part_counts in enumerate(stpp_stage):
        if stage_no == 0:
            stage_scale = scale_factors[0]
        elif stage_no == final_stage:
            stage_scale = scale_factors[1]
        else:
            stage_scale = 1.0

        parts_in_stage = sum(part_counts)
        stage_begin = ticks[stage_no]
        # A stage always spans at least one tick.
        stage_end = max(ticks[stage_no] + 1, ticks[stage_no + 1]).float()

        # Stage lies completely outside the available rows: skip it, but
        # keep the column bookkeeping aligned for the following stages.
        if stage_end <= 0 or stage_begin >= raw_scores.size(0):
            part_no += parts_in_stage
            continue

        for num_parts in part_counts:
            step = (stage_end - stage_begin) / num_parts
            bounds = torch.arange(stage_begin, stage_end + 1e-5, step).int()

            for k in range(num_parts):
                row_lo = bounds[k]
                row_hi = bounds[k + 1]
                if row_hi - row_lo >= 1:
                    col_lo = part_no * score_len
                    col_hi = (part_no + 1) * score_len
                    window = raw_scores[row_lo:row_hi, col_lo:col_hi]
                    pooled = window.mean(dim=0) * stage_scale
                    out_scores[index, :] += pooled.detach().cpu()
                # Empty parts still consume their column block.
                part_no += 1

    return out_scores

def forward(self, x, proposal_ticks, scale_factors):
"""Defines the computation performed at every call.
Expand All @@ -192,83 +243,28 @@ def forward(self, x, proposal_ticks, scale_factors):
dtype=x.dtype)
raw_complete_scores = x[:, self.complete_slice]

if self.with_regression:
if self.use_regression:
out_reg_scores = torch.zeros((num_ticks, self.reg_score_len),
dtype=x.dtype)
raw_reg_scores = x[:, self.reg_slice]
else:
out_reg_scores = None
raw_reg_scores = None

def pyramids_pooling_(out_scores, index, raw_scores, ticks,
scale_factors, score_len, stpp_stage):
"""Perform pyramids pooling.
Args:
out_scores (torch.Tensor): Scores to be returned.
index (int): Index of output scores.
raw_scores (torch.Tensor): Raw scores before STPP.
ticks (list): Ticks of raw scores.
scale_factors (list): Ratios of the effective sampling lengths
to augmented lengths.
score_len (int): Length of the score.
stpp_stage (tuple): Config of STPP.
"""
offset = 0
for stage_idx, stage_cfg in enumerate(stpp_stage):
if stage_idx == 0:
scale_factor = scale_factors[0]
elif stage_idx == len(stpp_stage) - 1:
scale_factor = scale_factors[1]
else:
scale_factor = 1.0

sum_parts = sum(stage_cfg)
tick_left = ticks[stage_idx]
tick_right = max(ticks[stage_idx] + 1,
ticks[stage_idx + 1]).float()

if tick_right <= 0 or tick_left >= raw_scores.size(0):
offset += sum_parts
continue
for num_parts in stage_cfg:
part_ticks = torch.arange(tick_left, tick_right + 1e-5,
(tick_right - tick_left) /
num_parts).int()

for i in range(num_parts):
part_tick_left = part_ticks[i]
part_tick_right = part_ticks[i + 1]
if part_tick_right - part_tick_left >= 1:
raw_score = raw_scores[
part_tick_left:part_tick_right,
offset * score_len:(offset + 1) * score_len]
raw_scale_score = raw_score.mean(
dim=0) * scale_factor
out_scores[
index, :] += raw_scale_score.detach().cpu()
offset += 1

return out_scores

for i in range(num_ticks):
ticks = proposal_ticks[i]

out_activity_scores[i, :] = raw_activity_scores[
ticks[1]:max(ticks[1] + 1, ticks[2]), :].mean(dim=0)

out_complete_scores = pyramids_pooling_(out_complete_scores, i,
raw_complete_scores, ticks,
scale_factors[i],
self.complete_score_len,
self.stpp_stage)
out_complete_scores = self._pyramids_pooling(
out_complete_scores, i, raw_complete_scores, ticks,
scale_factors[i], self.complete_score_len, self.stpp_stage)

if self.with_regression:
out_reg_scores = pyramids_pooling_(out_reg_scores, i,
raw_reg_scores, ticks,
scale_factors[i],
self.reg_score_len,
self.stpp_stage)
if self.use_regression:
out_reg_scores = self._pyramids_pooling(
out_reg_scores, i, raw_reg_scores, ticks, scale_factors[i],
self.reg_score_len, self.stpp_stage)

return out_activity_scores, out_complete_scores, out_reg_scores

Expand All @@ -282,7 +278,7 @@ class SSNHead(nn.Module):
in_channels (int): Number of channels for input data. Default: 1024.
num_classes (int): Number of classes to be classified. Default: 20.
consensus (dict): Config of segmental consensus.
with_regression (bool): Whether to perform regression or not.
use_regression (bool): Whether to perform regression or not.
Default: True.
init_std (float): Std value for Initiation. Default: 0.001.
"""
Expand All @@ -296,14 +292,14 @@ def __init__(self,
standalong_classifier=True,
stpp_cfg=(1, 1, 1),
num_seg=(2, 5, 2)),
with_regression=True,
use_regression=True,
init_std=0.001):

super().__init__()

self.dropout_ratio = dropout_ratio
self.num_classes = num_classes
self.with_regression = with_regression
self.use_regression = use_regression
self.init_std = init_std

if self.dropout_ratio != 0:
Expand All @@ -328,15 +324,15 @@ def __init__(self,
self.activity_fc = nn.Linear(in_channels, num_classes + 1)
self.completeness_fc = nn.Linear(self.in_channels_complete,
num_classes)
if self.with_regression:
if self.use_regression:
self.regressor_fc = nn.Linear(self.in_channels_complete,
num_classes * 2)

def init_weights(self):
"""Initiate the parameters from scratch."""
normal_init(self.activity_fc, std=self.init_std)
normal_init(self.completeness_fc, std=self.init_std)
if self.with_regression:
if self.use_regression:
normal_init(self.regressor_fc, std=self.init_std)

def prepare_test_fc(self, stpp_feat_multiplier):
Expand All @@ -354,7 +350,7 @@ def prepare_test_fc(self, stpp_feat_multiplier):
out_features = (
self.activity_fc.out_features +
self.completeness_fc.out_features * stpp_feat_multiplier)
if self.with_regression:
if self.use_regression:
out_features += (
self.regressor_fc.out_features * stpp_feat_multiplier)
self.test_fc = nn.Linear(in_features, out_features)
Expand All @@ -370,7 +366,7 @@ def prepare_test_fc(self, stpp_feat_multiplier):
weight = torch.cat((self.activity_fc.weight.data, complete_weight))
bias = torch.cat((self.activity_fc.bias.data, complete_bias))

if self.with_regression:
if self.use_regression:
reg_weight = self.regressor_fc.weight.data.view(
self.regressor_fc.out_features, stpp_feat_multiplier,
in_features).transpose(0,
Expand Down Expand Up @@ -398,7 +394,7 @@ def forward(self, x, test_mode=False):

activity_scores = self.activity_fc(activity_feat)
complete_scores = self.completeness_fc(completeness_feat)
if self.with_regression:
if self.use_regression:
bbox_preds = self.regressor_fc(completeness_feat)
bbox_preds = bbox_preds.view(-1,
self.completeness_fc.out_features,
Expand Down
12 changes: 6 additions & 6 deletions tests/test_models/test_localizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,13 +205,13 @@ def test_ssn_train():
type='STPPTrain',
stpp_stage=(1, 1, 1),
num_segments_list=(2, 5, 2)),
with_regression=True),
use_regression=True),
train_cfg=train_cfg)
dropout_cfg = copy.deepcopy(base_model_cfg)
dropout_cfg['dropout_ratio'] = 0
dropout_cfg['cls_head']['dropout_ratio'] = 0.5
non_regression_cfg = copy.deepcopy(base_model_cfg)
non_regression_cfg['cls_head']['with_regression'] = False
non_regression_cfg['cls_head']['use_regression'] = False

imgs = torch.rand(1, 8, 9, 3, 224, 224)
proposal_scale_factor = torch.Tensor([[[1.0345, 1.0345], [1.0028, 0.0028],
Expand Down Expand Up @@ -291,13 +291,13 @@ def test_ssn_test():
in_channels=512,
num_classes=20,
consensus=dict(type='STPPTest', stpp_stage=(1, 1, 1)),
with_regression=True),
use_regression=True),
test_cfg=test_cfg)
maxpool_model_cfg = copy.deepcopy(base_model_cfg)
maxpool_model_cfg['spatial_type'] = 'max'
non_regression_cfg = copy.deepcopy(base_model_cfg)
non_regression_cfg['cls_head']['with_regression'] = False
non_regression_cfg['cls_head']['consensus']['with_regression'] = False
non_regression_cfg['cls_head']['use_regression'] = False
non_regression_cfg['cls_head']['consensus']['use_regression'] = False
tuple_stage_cfg = copy.deepcopy(base_model_cfg)
tuple_stage_cfg['cls_head']['consensus']['stpp_stage'] = (1, (1, 2), 1)
str_stage_cfg = copy.deepcopy(base_model_cfg)
Expand All @@ -307,7 +307,7 @@ def test_ssn_test():
relative_proposal_list = torch.Tensor([[[0.2500, 0.6250], [0.3750,
0.7500]]])
scale_factor_list = torch.Tensor([[[1.0000, 1.0000], [1.0000, 0.2661]]])
proposal_tick_list = torch.LongTensor([[[1, 2, 5, 7], [2, 3, 6, 8]]])
proposal_tick_list = torch.LongTensor([[[1, 2, 5, 7], [20, 30, 60, 80]]])
reg_norm_consts = torch.Tensor([[[-0.0603, 0.0325], [0.0752, 0.1596]]])

localizer_ssn = build_localizer(base_model_cfg)
Expand Down

0 comments on commit 3cdfd5c

Please sign in to comment.