diff --git a/docs/changelog.md b/docs/changelog.md index d9ed84a8d1..7445e93b34 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -11,6 +11,7 @@ **Bug and Typo Fixes** +- Fix typo in default argument of BaseHead. ([#446](https://github.com/open-mmlab/mmaction2/pull/446)) **ModelZoo** diff --git a/mmaction/models/heads/base.py b/mmaction/models/heads/base.py index 92e9b0427f..91abacd124 100644 --- a/mmaction/models/heads/base.py +++ b/mmaction/models/heads/base.py @@ -36,7 +36,7 @@ class BaseHead(nn.Module, metaclass=ABCMeta): num_classes (int): Number of classes to be classified. in_channels (int): Number of channels in input feature. loss_cls (dict): Config for building loss. - Default: dict(type='CrossEntropyLoss'). + Default: dict(type='CrossEntropyLoss', loss_weight=1.0). multi_class (bool): Determines whether it is a multi-class recognition task. Default: False. label_smooth_eps (float): Epsilon used in label smooth. @@ -46,7 +46,7 @@ class BaseHead(nn.Module, metaclass=ABCMeta): def __init__(self, num_classes, in_channels, - loss_cls=dict(type='CrossEntropyLoss', loss_factor=1.0), + loss_cls=dict(type='CrossEntropyLoss', loss_weight=1.0), multi_class=False, label_smooth_eps=0.0): super().__init__() diff --git a/tests/test_models/test_head.py b/tests/test_models/test_head.py index b05d85f05b..272fdeadaf 100644 --- a/tests/test_models/test_head.py +++ b/tests/test_models/test_head.py @@ -6,7 +6,7 @@ class ExampleHead(BaseHead): - # use a ExampleHead to success BaseHead + # use an ExampleHead to test BaseHead def init_weights(self): pass @@ -24,6 +24,15 @@ def test_base_head(): assert 'loss_cls' in losses.keys() assert losses.get('loss_cls') > 0, 'cls loss should be non-zero' + head = ExampleHead(3, 400, dict(type='CrossEntropyLoss', loss_weight=2.0)) + + cls_scores = torch.rand((3, 4)) + # When truth is non-empty then cls loss should be nonzero for random inputs + gt_labels = torch.LongTensor([2] * 3).squeeze() + losses = head.loss(cls_scores, gt_labels) + assert 'loss_cls' in losses.keys() + assert losses.get('loss_cls') > 0, 'cls loss should be non-zero' + def test_i3d_head(): """Test loss method, layer construction, attributes and forward function in