Skip to content

Commit

Permalink
Minor fix unittest.
Browse files Browse the repository at this point in the history
  • Loading branch information
xusu committed Dec 15, 2020
1 parent eab20ff commit c35fb4f
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion tests/test_models/test_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


class ExampleHead(BaseHead):
# use a ExampleHead to success BaseHead
# use a ExampleHead to succeed BaseHead
def init_weights(self):
pass

Expand All @@ -24,6 +24,15 @@ def test_base_head():
assert 'loss_cls' in losses.keys()
assert losses.get('loss_cls') > 0, 'cls loss should be non-zero'

head = ExampleHead(3, 400, dict(type='CrossEntropyLoss', loss_weight=2.0))

cls_scores = torch.rand((3, 4))
# When truth is non-empty then cls loss should be nonzero for random inputs
gt_labels = torch.LongTensor([2] * 3).squeeze()
losses = head.loss(cls_scores, gt_labels)
assert 'loss_cls' in losses.keys()
assert losses.get('loss_cls') > 0, 'cls loss should be non-zero'


def test_i3d_head():
"""Test loss method, layer construction, attributes and forward function in
Expand Down

0 comments on commit c35fb4f

Please sign in to comment.