-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[20230424 v0.4.0] Add strong data augmentations; support the argument 'in_chans' for all models; support imagenet21k (22k)
- Loading branch information
1 parent
85a5143
commit 5b17f6a
Showing
45 changed files
with
323 additions
and
220 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,24 @@ | ||
# Copyright (c) QIU, Tian. All rights reserved. | ||
|
||
from .default import DefaultCriterion | ||
from .cross_entropy import CrossEntropy, LabelSmoothingCrossEntropy, SoftTargetCrossEntropy | ||
|
||
|
||
def build_criterion(args):
    """Build and return a classification criterion selected by ``args.criterion``.

    Supported names (case-insensitive):
      - 'ce'                 -> CrossEntropy (hard integer labels)
      - 'label_smoothing_ce' -> LabelSmoothingCrossEntropy (uses ``args.label_smoothing``)
      - 'soft_target_ce' / 'default' -> SoftTargetCrossEntropy (handles both hard
        and soft targets, hence it serves as the default)

    Raises:
        ValueError: if ``args.criterion`` does not match any known criterion.
    """
    criterion_name = args.criterion.lower()

    if criterion_name == 'ce':
        losses = ['labels']
        weight_dict = {'loss_ce': 1}
        return CrossEntropy(losses=losses, weight_dict=weight_dict)

    if criterion_name == 'label_smoothing_ce':
        losses = ['labels']
        weight_dict = {'loss_ce': 1}
        return LabelSmoothingCrossEntropy(losses=losses, weight_dict=weight_dict, smoothing=args.label_smoothing)

    if criterion_name in ['soft_target_ce', 'default']:
        losses = ['labels']
        weight_dict = {'loss_ce': 1}
        return SoftTargetCrossEntropy(losses=losses, weight_dict=weight_dict)

    raise ValueError(f"Criterion '{criterion_name}' is not found.")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
# Copyright (c) QIU, Tian. All rights reserved. | ||
|
||
import torch | ||
import torch.nn.functional as F | ||
|
||
from ._base_ import BaseCriterion | ||
from ..utils.misc import accuracy | ||
|
||
__all__ = ['CrossEntropy', 'LabelSmoothingCrossEntropy', 'SoftTargetCrossEntropy'] | ||
|
||
|
||
class CrossEntropy(BaseCriterion):
    """Standard cross-entropy criterion for hard (integer class index) targets."""

    def __init__(self, losses: list, weight_dict: dict):
        super().__init__(losses, weight_dict)

    def loss_labels(self, outputs, targets, **kwargs):
        """Compute the classification loss.

        Args:
            outputs: logits tensor, or a dict that MUST contain the key 'logits'.
            targets: integer class-index tensor.

        Returns:
            dict with 'loss_ce' (mean cross-entropy) and 'class_error'
            (100 - top-1 accuracy, for logging only).
        """
        if isinstance(outputs, dict):
            assert 'logits' in outputs.keys(), \
                f"When using 'loss_labels(self, outputs, targets, **kwargs)' in '{self.__class__.__name__}', " \
                f"if 'outputs' is a dict, 'logits' MUST be the key."
            outputs = outputs["logits"]

        loss_ce = F.cross_entropy(outputs, targets, reduction='mean')
        losses = {'loss_ce': loss_ce, 'class_error': 100 - accuracy(outputs, targets)[0]}

        return losses
|
||
|
||
class LabelSmoothingCrossEntropy(BaseCriterion):
    """Cross-entropy with label smoothing during training.

    At training time the target distribution assigns ``1 - smoothing`` to the
    true class and spreads ``smoothing`` uniformly over all classes; at eval
    time it falls back to plain cross-entropy.
    """

    def __init__(self, losses: list, weight_dict: dict, smoothing: float = 0.1):
        super().__init__(losses, weight_dict)
        self.smoothing = smoothing
        self.confidence = 1. - smoothing  # weight on the true-class log-prob

    def loss_labels(self, outputs, targets, training, **kwargs):
        """Compute the (optionally smoothed) classification loss.

        Args:
            outputs: logits tensor, or a dict that MUST contain the key 'logits'.
            targets: integer class-index tensor.
            training: if True, apply label smoothing; otherwise plain CE.

        Returns:
            dict with 'loss_ce' and 'class_error' (100 - top-1 accuracy).
        """
        if isinstance(outputs, dict):
            assert 'logits' in outputs.keys(), \
                f"When using 'loss_labels(self, outputs, targets, **kwargs)' in '{self.__class__.__name__}', " \
                f"if 'outputs' is a dict, 'logits' MUST be the key."
            outputs = outputs["logits"]

        if training:
            logprobs = F.log_softmax(outputs, dim=-1)
            # NLL of the true class; gather needs targets shaped (N, 1).
            nll_loss = -logprobs.gather(dim=-1, index=targets.unsqueeze(1))
            nll_loss = nll_loss.squeeze(1)
            # Uniform component: mean log-prob over all classes.
            smooth_loss = -logprobs.mean(dim=-1)
            loss_ce = (self.confidence * nll_loss + self.smoothing * smooth_loss).mean()
        else:
            loss_ce = F.cross_entropy(outputs, targets, reduction='mean')

        losses = {'loss_ce': loss_ce, 'class_error': 100 - accuracy(outputs, targets)[0]}

        return losses
|
||
|
||
class SoftTargetCrossEntropy(BaseCriterion):  # Compatible with 'CrossEntropy'
    """Cross-entropy for soft (probability distribution) targets.

    Also accepts hard integer targets (1-D), in which case it behaves exactly
    like ``CrossEntropy`` — this is what makes it a safe 'default' criterion
    when augmentations such as mixup/cutmix may or may not produce soft labels.
    """

    def __init__(self, losses: list, weight_dict: dict):
        super().__init__(losses, weight_dict)

    def loss_labels(self, outputs, targets, **kwargs):
        """Compute the classification loss for hard or soft targets.

        Args:
            outputs: logits tensor, or a dict that MUST contain the key 'logits'.
            targets: 1-D integer class indices, or 2-D per-class probabilities.

        Returns:
            dict with 'loss_ce'; 'class_error' is only reported for hard
            targets, where top-1 accuracy is well defined.
        """
        if isinstance(outputs, dict):
            assert 'logits' in outputs.keys(), \
                f"When using 'loss_labels(self, outputs, targets, **kwargs)' in '{self.__class__.__name__}', " \
                f"if 'outputs' is a dict, 'logits' MUST be the key."
            outputs = outputs["logits"]

        if targets.dim() == 1:
            # Hard labels: identical to plain cross-entropy.
            loss_ce = F.cross_entropy(outputs, targets, reduction='mean')
            losses = {'loss_ce': loss_ce, 'class_error': 100 - accuracy(outputs, targets)[0]}
        else:
            # Soft labels: -sum(p * log_softmax(logits)) averaged over the batch.
            loss_ce = torch.sum(-targets * F.log_softmax(outputs, dim=-1), dim=-1).mean()
            losses = {'loss_ce': loss_ce}

        return losses
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.