Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add class weight #509

Merged
merged 25 commits into from
Jan 4, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
c5f00a8
resolve comments
Oct 16, 2020
05575c1
update changelog
Oct 16, 2020
eb09070
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
Oct 19, 2020
81302a4
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
Oct 22, 2020
43a649a
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
Oct 27, 2020
755809e
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
Oct 29, 2020
d478c9d
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
Nov 2, 2020
08bbc06
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
Nov 7, 2020
ff958e6
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
Nov 8, 2020
d0e192d
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
Nov 11, 2020
a52c536
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
Nov 11, 2020
81a2029
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
Nov 16, 2020
e03d2a9
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
Nov 17, 2020
2a9b57f
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
Nov 27, 2020
28001ff
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
Nov 30, 2020
46cc5dd
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
Dec 1, 2020
667818a
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
Dec 18, 2020
34398a8
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
Dec 18, 2020
a5be690
add class_weight in loss arguments
Jan 4, 2021
9667211
switch to mmcv 1.2.4
Jan 4, 2021
95ebd9b
Merge branch 'master' into add_class_weight
kennymckormick Jan 4, 2021
6d61310
use v1.1.1 as mmcv version lower bound
Jan 4, 2021
9b601e0
Merge branch 'add_class_weight' of github.com:kennymckormick/mmaction…
Jan 4, 2021
a841cc9
reorganize code
Jan 4, 2021
8a28890
resolve comments
Jan 4, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ jobs:
- name: Install PyTorch
run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html
- name: Install MMCV
run: pip install mmcv-full==1.2.2 -f https://download.openmmlab.com/mmcv/dist/cpu/torch${{matrix.torch}}/index.html
run: pip install mmcv-full==1.2.4 -f https://download.openmmlab.com/mmcv/dist/cpu/torch${{matrix.torch}}/index.html
- name: Install MMDet
run: pip install -q git+https://github.com/open-mmlab/mmdetection/
- name: Install unittest dependencies
Expand Down Expand Up @@ -138,7 +138,7 @@ jobs:
run: pip install torch==${{matrix.torch}} torchvision==${{matrix.torchvision}} -f https://download.pytorch.org/whl/torch_stable.html
- name: Install mmaction dependencies
run: |
pip install mmcv-full==1.2.2 -f https://download.openmmlab.com/mmcv/dist/${{matrix.mmcv}}/index.html
pip install mmcv-full==1.2.4 -f https://download.openmmlab.com/mmcv/dist/${{matrix.mmcv}}/index.html
pip install -q git+https://github.com/open-mmlab/mmdetection/
pip install -r requirements.txt
- name: Build and install
Expand Down
41 changes: 39 additions & 2 deletions mmaction/models/losses/cross_entropy_loss.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import torch
import torch.nn.functional as F

from ..registry import LOSSES
Expand All @@ -6,7 +7,22 @@

@LOSSES.register_module()
class CrossEntropyLoss(BaseWeightedLoss):
"""Cross Entropy Loss."""
"""Cross Entropy Loss.

Args:
loss_weight (float): Factor scalar multiplied on the loss.
Default: 1.0.
class_weight (list[float] | None): Loss weight for each class. If set
as None, use the same weight 1 for all classes. Only applies
to CrossEntropyLoss and BCELossWithLogits (should not be set when
using other losses). Default: None.
"""

def __init__(self, loss_weight=1.0, class_weight=None):
super().__init__(loss_weight=loss_weight)
self.class_weight = None
if class_weight is not None:
self.class_weight = torch.Tensor(class_weight)

def _forward(self, cls_score, label, **kwargs):
"""Forward function.
Expand All @@ -20,13 +36,31 @@ def _forward(self, cls_score, label, **kwargs):
Returns:
torch.Tensor: The returned CrossEntropy loss.
"""
if self.class_weight is not None:
assert 'weight' not in kwargs, "The key 'weight' already exists."
kwargs['weight'] = self.class_weight.to(cls_score.device)
kennymckormick marked this conversation as resolved.
Show resolved Hide resolved
loss_cls = F.cross_entropy(cls_score, label, **kwargs)
return loss_cls


@LOSSES.register_module()
class BCELossWithLogits(BaseWeightedLoss):
"""Binary Cross Entropy Loss with logits."""
"""Binary Cross Entropy Loss with logits.

Args:
loss_weight (float): Factor scalar multiplied on the loss.
Default: 1.0.
class_weight (list[float] | None): Loss weight for each class. If set
as None, use the same weight 1 for all classes. Only applies
to CrossEntropyLoss and BCELossWithLogits (should not be set when
using other losses). Default: None.
"""

def __init__(self, loss_weight=1.0, class_weight=None):
super().__init__(loss_weight=loss_weight)
self.class_weight = None
if class_weight is not None:
self.class_weight = torch.Tensor(class_weight)

def _forward(self, cls_score, label, **kwargs):
"""Forward function.
Expand All @@ -40,6 +74,9 @@ def _forward(self, cls_score, label, **kwargs):
Returns:
torch.Tensor: The returned bce loss with logits.
"""
if self.class_weight is not None:
assert 'weight' not in kwargs, "The key 'weight' already exists."
kwargs['weight'] = self.class_weight.to(cls_score.device)
kennymckormick marked this conversation as resolved.
Show resolved Hide resolved
loss_cls = F.binary_cross_entropy_with_logits(cls_score, label,
**kwargs)
return loss_cls
19 changes: 18 additions & 1 deletion tests/test_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,21 +81,38 @@ def test_hvu_loss():

def test_cross_entropy_loss():
cls_scores = torch.rand((3, 4))
gt_labels = torch.LongTensor([2] * 3).squeeze()
gt_labels = torch.LongTensor([0, 1, 2]).squeeze()

cross_entropy_loss = CrossEntropyLoss()
output_loss = cross_entropy_loss(cls_scores, gt_labels)
assert torch.equal(output_loss, F.cross_entropy(cls_scores, gt_labels))

weight = torch.rand(4)
class_weight = weight.numpy().tolist()
cross_entropy_loss = CrossEntropyLoss(class_weight=class_weight)
output_loss = cross_entropy_loss(cls_scores, gt_labels)
assert torch.equal(output_loss,
F.cross_entropy(cls_scores, gt_labels, weight=weight))


def test_bce_loss_with_logits():
cls_scores = torch.rand((3, 4))
gt_labels = torch.rand((3, 4))

bce_loss_with_logits = BCELossWithLogits()
output_loss = bce_loss_with_logits(cls_scores, gt_labels)
assert torch.equal(
output_loss, F.binary_cross_entropy_with_logits(cls_scores, gt_labels))

weight = torch.rand(4)
class_weight = weight.numpy().tolist()
bce_loss_with_logits = BCELossWithLogits(class_weight=class_weight)
output_loss = bce_loss_with_logits(cls_scores, gt_labels)
assert torch.equal(
output_loss,
F.binary_cross_entropy_with_logits(
cls_scores, gt_labels, weight=weight))


def test_nll_loss():
cls_scores = torch.randn(3, 3)
Expand Down