Skip to content

Commit

Permalink
Minor fix unittest.
Browse files Browse the repository at this point in the history
Fix docstring.
  • Loading branch information
xusu committed Dec 15, 2020
1 parent eab20ff commit 7f6e9ea
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 2 deletions.
2 changes: 1 addition & 1 deletion mmaction/models/heads/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class BaseHead(nn.Module, metaclass=ABCMeta):
num_classes (int): Number of classes to be classified.
in_channels (int): Number of channels in input feature.
loss_cls (dict): Config for building loss.
Default: dict(type='CrossEntropyLoss').
Default: dict(type='CrossEntropyLoss', loss_weight=1.0).
multi_class (bool): Determines whether it is a multi-class
recognition task. Default: False.
label_smooth_eps (float): Epsilon used in label smooth.
Expand Down
11 changes: 10 additions & 1 deletion tests/test_models/test_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


class ExampleHead(BaseHead):
# use a ExampleHead to success BaseHead
# use a ExampleHead to succeed BaseHead
def init_weights(self):
pass

Expand All @@ -24,6 +24,15 @@ def test_base_head():
assert 'loss_cls' in losses.keys()
assert losses.get('loss_cls') > 0, 'cls loss should be non-zero'

head = ExampleHead(3, 400, dict(type='CrossEntropyLoss', loss_weight=2.0))

cls_scores = torch.rand((3, 4))
# When truth is non-empty then cls loss should be nonzero for random inputs
gt_labels = torch.LongTensor([2] * 3).squeeze()
losses = head.loss(cls_scores, gt_labels)
assert 'loss_cls' in losses.keys()
assert losses.get('loss_cls') > 0, 'cls loss should be non-zero'


def test_i3d_head():
"""Test loss method, layer construction, attributes and forward function in
Expand Down

0 comments on commit 7f6e9ea

Please sign in to comment.