Additive margin softmax (#1131)
* Added Additive-margin softmax

* update changelog.md
Atharva-Phatak committed Mar 27, 2021
1 parent 77599e5 commit 2e3ef50
Showing 5 changed files with 221 additions and 26 deletions.
47 changes: 27 additions & 20 deletions CHANGELOG.md
@@ -9,6 +9,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added


- Additive Margin SoftMax (AMSoftmax) ([#1125](https://github.com/catalyst-team/catalyst/issues/1125))
- Generalized Mean Pooling (GeM) ([#1084](https://github.com/catalyst-team/catalyst/issues/1084))
- Key-value support for CriterionCallback ([#1130](https://github.com/catalyst-team/catalyst/issues/1130))
- Engine configuration through cmd ([#1134](https://github.com/catalyst-team/catalyst/issues/1134))
@@ -22,7 +29,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Removed

-


### Fixed

@@ -97,8 +104,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- ([#1002](https://github.com/catalyst-team/catalyst/pull/1002))
- a few docs
- ([#998](https://github.com/catalyst-team/catalyst/pull/998))
- ``reciprocal_rank`` metric
- unified recsys metrics preprocessing
- ([#1018](https://github.com/catalyst-team/catalyst/pull/1018))
- readme examples for all supported metrics under ``catalyst.metrics``
- ``wrap_metric_fn_with_activation`` for model outputs wrapping with activation
@@ -146,7 +153,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- ([#1018](https://github.com/catalyst-team/catalyst/pull/1014))
- ClasswiseIouCallback/ClasswiseJaccardCallback marked as deprecated (should be refactored in future releases)



### Fixed

@@ -163,7 +170,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added
- DCG, nDCG metrics ([#881](https://github.com/catalyst-team/catalyst/pull/881))
- MAP calculations ([#968](https://github.com/catalyst-team/catalyst/pull/968))
- hitrate calculations ([#975](https://github.com/catalyst-team/catalyst/pull/975))
- extra functions for classification metrics ([#966](https://github.com/catalyst-team/catalyst/pull/966))
- `OneOf` and `OneOfV2` batch transforms ([#951](https://github.com/catalyst-team/catalyst/pull/951))
@@ -197,7 +204,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added

- MRR metrics calculation ([#886](https://github.com/catalyst-team/catalyst/pull/886))
- docs for MetricCallbacks ([#947](https://github.com/catalyst-team/catalyst/pull/947))
- SoftMax, CosFace, ArcFace layers to contrib ([#939](https://github.com/catalyst-team/catalyst/pull/939))
- ArcMargin layer to contrib ([#957](https://github.com/catalyst-team/catalyst/pull/957))
- AdaCos to contrib ([#958](https://github.com/catalyst-team/catalyst/pull/958))
@@ -218,7 +225,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Removed

-

### Fixed

@@ -243,7 +250,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Removed

-

### Fixed

@@ -271,13 +278,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Removed

-

### Fixed

- autoresume option for Config API ([#907](https://github.com/catalyst-team/catalyst/pull/907))
- a few issues with TF projector ([#917](https://github.com/catalyst-team/catalyst/pull/917))
- batch sampler speed issue ([#921](https://github.com/catalyst-team/catalyst/pull/921))
- add apex key-value optimizer support ([#924](https://github.com/catalyst-team/catalyst/pull/924))
- runtime warning for PyTorch 1.6 ([#920](https://github.com/catalyst-team/catalyst/pull/920))
- Apex syncbn usage ([#920](https://github.com/catalyst-team/catalyst/pull/920))
@@ -377,7 +384,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Removed

-

### Fixed

@@ -407,7 +414,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Removed

-

### Fixed

@@ -461,35 +468,35 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
## [20.04] - 2020-04-06

### Added
-


### Changed

-

### Removed

-

### Fixed

-


## [YY.MM.R] - YYYY-MM-DD

### Added

-

### Changed

-

### Removed

-

### Fixed

-
1 change: 1 addition & 0 deletions catalyst/contrib/nn/modules/__init__.py
@@ -1,6 +1,7 @@
# flake8: noqa
from torch.nn.modules import *

from catalyst.contrib.nn.modules.amsoftmax import AMSoftmax
from catalyst.contrib.nn.modules.arcface import ArcFace, SubCenterArcFace
from catalyst.contrib.nn.modules.arcmargin import ArcMarginProduct
from catalyst.contrib.nn.modules.common import (
106 changes: 106 additions & 0 deletions catalyst/contrib/nn/modules/amsoftmax.py
@@ -0,0 +1,106 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class AMSoftmax(nn.Module):
"""Implementation of
`AMSoftmax: Additive Margin Softmax for Face Verification`_.
.. _AMSoftmax\: Additive Margin Softmax for Face Verification:
https://arxiv.org/pdf/1801.05599.pdf
Args:
in_features: size of each input sample.
out_features: size of each output sample.
s: norm of input feature.
Default: ``64.0``.
m: margin.
Default: ``0.5``.
eps: operation accuracy.
Default: ``1e-6``.
Shape:
- Input: :math:`(batch, H_{in})` where
:math:`H_{in} = in\_features`.
- Output: :math:`(batch, H_{out})` where
:math:`H_{out} = out\_features`.
Example:
>>> layer = AMSoftmax(5, 10, s=1.31, m=0.5)
>>> loss_fn = nn.CrossEntropyLoss()
>>> embedding = torch.randn(3, 5, requires_grad=True)
>>> target = torch.empty(3, dtype=torch.long).random_(10)
>>> output = layer(embedding, target)
>>> loss = loss_fn(output, target)
>>> loss.backward()
"""

def __init__( # noqa: D107
self,
in_features: int,
out_features: int,
s: float = 64.0,
m: float = 0.5,
eps: float = 1e-6,
):
super(AMSoftmax, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.s = s
self.m = m
self.eps = eps

self.weight = nn.Parameter(torch.FloatTensor(out_features, in_features))
nn.init.xavier_uniform_(self.weight)

def __repr__(self) -> str:
"""Object representation."""
rep = (
"AMSoftmax("
f"in_features={self.in_features},"
f"out_features={self.out_features},"
f"s={self.s},"
f"m={self.m},"
f"eps={self.eps}"
")"
)
return rep

def forward(self, input: torch.Tensor, target: torch.LongTensor = None) -> torch.Tensor:
"""
Args:
input: input features,
expected shapes ``BxF`` where ``B``
is batch dimension and ``F`` is an
input feature dimension.
target: target classes,
expected shapes ``B`` where
``B`` is batch dimension.
If `None`, the projection on the
class centroids is returned.
Default is `None`.
Returns:
tensor (logits) with shapes ``BxC``
where ``C`` is a number of classes
(out_features).
"""
cos_theta = F.linear(F.normalize(input), F.normalize(self.weight))

if target is None:
return cos_theta

cos_theta = torch.clamp(cos_theta, -1.0 + self.eps, 1.0 - self.eps)

one_hot = torch.zeros_like(cos_theta)
one_hot.scatter_(1, target.view(-1, 1).long(), 1)

logits = torch.where(one_hot.bool(), cos_theta - self.m, cos_theta)
logits *= self.s

return logits


__all__ = ["AMSoftmax"]
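
For context on how the new layer is used: the forward pass above returns `s * (cos θ - m)` for the target class and `s * cos θ` for every other class, so the output plugs directly into `nn.CrossEntropyLoss` as a classification head. A minimal training-step sketch follows; the embedding network, dimensions, and hyperparameters (`s=30.0`, `m=0.35`) are illustrative assumptions, not part of this commit.

import torch
import torch.nn as nn

from catalyst.contrib.nn.modules import AMSoftmax

# illustrative setup: a small embedding network and 10 classes
embedder = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 128))
head = AMSoftmax(in_features=128, out_features=10, s=30.0, m=0.35)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    list(embedder.parameters()) + list(head.parameters()), lr=1e-3
)

# one training step on random data
features = torch.randn(16, 784)
targets = torch.randint(0, 10, (16,))

embeddings = embedder(features)
logits = head(embeddings, targets)  # margin subtracted from target-class cosine, then scaled by s
loss = criterion(logits, targets)

optimizer.zero_grad()
loss.backward()
optimizer.step()

# at inference, omit the target to get raw cosine similarities to the class centroids
with torch.no_grad():
    cosines = head(embedder(features))  # shape (16, 10), values in [-1, 1]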
2 changes: 1 addition & 1 deletion catalyst/contrib/nn/modules/arcface.py
@@ -30,7 +30,7 @@ class ArcFace(nn.Module):
Example:
>>> layer = ArcFace(5, 10, s=1.31, m=0.5)
>>> loss_fn = nn.CrosEntropyLoss()
>>> loss_fn = nn.CrossEntropyLoss()
>>> embedding = torch.randn(3, 5, requires_grad=True)
>>> target = torch.empty(3, dtype=torch.long).random_(10)
>>> output = layer(embedding, target)
