[pytorch] Add triplet margin loss with custom distance
Summary: As discussed [here](pytorch/pytorch#43342),
this adds a Python-only implementation of the triplet margin loss that takes a
custom distance function. Whether this belongs in PyTorch Core is still under
discussion.
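
A minimal sketch of the intended usage, based on the signatures and docstring
examples in this diff (tensor shapes are illustrative):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

anchor = torch.randn(100, 128, requires_grad=True)
positive = torch.randn(100, 128, requires_grad=True)
negative = torch.randn(100, 128, requires_grad=True)

# Default distance is nn.PairwiseDistance, matching nn.TripletMarginLoss.
loss_fn = nn.TripletMarginLossWithDistance()
loss_fn(anchor, positive, negative).backward()

# A similarity function (larger means closer) flips the comparison internally.
cos_loss = nn.TripletMarginLossWithDistance(
    distance_function=nn.CosineSimilarity(), is_similarity_function=True)
cos_loss(anchor, positive, negative).backward()

# Functional form with a user-defined distance (not JIT-scriptable).
def l_infinity(x1, x2):
    return torch.max(torch.abs(x1 - x2), dim=1).values

loss = F.triplet_margin_loss_with_distance(
    anchor, positive, negative, distance_function=l_infinity, margin=1.5)
```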

* Consolidate tests, clarify functional limitations

* Documentation updates

* Remove stray imports

* Fix CI

Test Plan: python test/run_tests.py

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 672055f026a5627c7883c43625ca85bc05f5af5e
Pull Request resolved: pytorch/pytorch#43680
ethch18 committed Sep 2, 2020
1 parent 14ebb2c commit 8e1f3aa
Showing 8 changed files with 283 additions and 11 deletions.
7 changes: 5 additions & 2 deletions docs/source/nn.functional.rst
@@ -483,6 +483,11 @@ Loss functions

.. autofunction:: triplet_margin_loss

:hidden:`triplet_margin_loss_with_distance`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: triplet_margin_loss_with_distance

Vision functions
----------------

@@ -533,5 +538,3 @@ DataParallel functions (multi-GPU, distributed)
~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: torch.nn.parallel.data_parallel


3 changes: 2 additions & 1 deletion docs/source/nn.rst
@@ -10,7 +10,7 @@ These are the basic building block for graphs
:depth: 2
:local:
:backlinks: top


.. currentmodule:: torch.nn

@@ -269,6 +269,7 @@ Loss Functions
nn.CosineEmbeddingLoss
nn.MultiMarginLoss
nn.TripletMarginLoss
nn.TripletMarginLossWithDistance

Vision Layers
----------------
70 changes: 70 additions & 0 deletions test/test_nn.py
@@ -12067,6 +12067,76 @@ def test_threshold_inplace_overlap(self, device):
F.threshold(x, 0.5, 0.5, inplace=True)
F.threshold_(x, 0.5, 0.5)

def test_triplet_margin_loss_with_distance_default_parity(self, device):
# Test for `nn.TripletMarginLossWithDistance` and
# `F.triplet_margin_loss_with_distance`. Checks
# for parity against the respective non-distance-agnostic
# implementations of triplet margin loss (`nn.TripletMarginLoss`
# and `F.triplet_margin_loss`) under *default args*.

anchor = torch.randn(5, 10, device=device, requires_grad=True)
positive = torch.randn(5, 10, device=device, requires_grad=True)
negative = torch.randn(5, 10, device=device, requires_grad=True)

# functional grad and parity check
self.assertTrue(gradcheck(lambda a, p, n: F.triplet_margin_loss_with_distance(
a, p, n), (anchor, positive, negative)))
self.assertEqual(F.triplet_margin_loss_with_distance(anchor, positive, negative),
F.triplet_margin_loss(anchor, positive, negative))

# module grad and parity check
loss_base = nn.TripletMarginLoss()
loss_test = nn.TripletMarginLossWithDistance()
self.assertTrue(gradcheck(lambda a, p, n: loss_test(
a, p, n), (anchor, positive, negative)))
self.assertEqual(loss_test(anchor, positive, negative),
loss_base(anchor, positive, negative))

def test_triplet_margin_loss_with_distance(self, device):
# Test for `nn.TripletMarginLossWithDistance` and
# `F.triplet_margin_loss_with_distance`. Checks
# for parity against the respective non-distance-agnostic
# implementations of triplet margin loss (`nn.TripletMarginLoss`
# and `F.triplet_margin_loss`).

def pairwise_similarity(x, y):
return 1.0 - F.pairwise_distance(x, y)
pairwise_distance = nn.PairwiseDistance()
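# Pair each function with a flag indicating whether it is a similarity
# (larger value means closer); the flag is forwarded to the loss as
# `is_similarity_function`.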
distance_functions = ((pairwise_similarity, True), (pairwise_distance, False))

reductions = ('mean', 'none')
margins = (1.0, 1.5)
swaps = (True, False)

for (distance_fn, is_similarity_fn), reduction, margin, swap \
in itertools.product(distance_functions, reductions, margins, swaps):
anchor = torch.randn(5, 10, device=device, requires_grad=True)
positive = torch.randn(5, 10, device=device, requires_grad=True)
negative = torch.randn(5, 10, device=device, requires_grad=True)

# functional: standard gradient check
self.assertTrue(gradcheck(lambda a, p, n: F.triplet_margin_loss_with_distance(
a, p, n, distance_function=distance_fn), (anchor, positive, negative)))
# functional: parity check
self.assertEqual(F.triplet_margin_loss_with_distance(anchor, positive, negative,
distance_function=distance_fn,
is_similarity_function=is_similarity_fn,
reduction=reduction, margin=margin, swap=swap),
F.triplet_margin_loss(anchor, positive, negative,
reduction=reduction, margin=margin, swap=swap))

loss_base = nn.TripletMarginLoss(reduction=reduction, margin=margin, swap=swap)
loss_test = nn.TripletMarginLossWithDistance(distance_function=distance_fn,
is_similarity_function=is_similarity_fn,
reduction=reduction, margin=margin, swap=swap)
# module: standard gradient check
self.assertTrue(gradcheck(lambda a, p, n: loss_test(
a, p, n), (anchor, positive, negative)))
# module: parity check
self.assertEqual(loss_test(anchor, positive, negative),
loss_base(anchor, positive, negative))


class TestModuleGlobalHooks(TestCase):

def tearDown(self):
44 changes: 44 additions & 0 deletions torch/nn/functional.py
@@ -3729,6 +3729,50 @@ def triplet_margin_loss(anchor, positive, negative, margin=1.0, p=2, eps=1e-6, s
swap, reduction_enum)


def triplet_margin_loss_with_distance(anchor, positive, negative, distance_function=None, is_similarity_function=False,
margin=1.0, swap=False, reduction="mean"):
# type: (Tensor, Tensor, Tensor, Optional[Callable[[Tensor, Tensor], Tensor]], bool, float, bool, str) -> Tensor
r"""
See :class:`~torch.nn.TripletMarginLossWithDistance` for details
"""
if torch.jit.is_scripting():
raise NotImplementedError("F.triplet_margin_loss_with_distance does not support JIT scripting: "
"Callables cannot be scripted unless they are properties of "
"a module. Please use nn.TripletMarginLossWithDistance instead.")

tens_ops = (anchor, positive, negative)
if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
return handle_torch_function(
triplet_margin_loss_with_distance, tens_ops, anchor, positive, negative,
distance_function=distance_function, is_similarity_function=is_similarity_function,
margin=margin, swap=swap, reduction=reduction)

distance_function = distance_function if distance_function is not None else pairwise_distance

positive_dist = distance_function(anchor, positive)
negative_dist = distance_function(anchor, negative)

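# Distance swap (Balntas et al.): also measure the positive-to-negative
# separation and keep the harder value, i.e. the smaller distance or the
# larger similarity, as the effective negative term.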
if swap:
swap_dist = distance_function(positive, negative)
if is_similarity_function:
negative_dist = torch.max(negative_dist, swap_dist)
else:
negative_dist = torch.min(negative_dist, swap_dist)

if is_similarity_function:
output = torch.clamp(negative_dist - positive_dist + margin, min=0.0)
else:
output = torch.clamp(positive_dist - negative_dist + margin, min=0.0)
reduction_enum = _Reduction.get_enum(reduction)
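# _Reduction.get_enum maps 'none' -> 0, 'mean' -> 1, 'sum' -> 2.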

if reduction_enum == 1:
return output.mean()
elif reduction_enum == 2:
return output.sum()
else:
return output


def normalize(input, p=2, dim=1, eps=1e-12, out=None):
# type: (Tensor, float, int, float, Optional[Tensor]) -> Tensor
r"""Performs :math:`L_p` normalization of inputs over specified dimension.
12 changes: 9 additions & 3 deletions torch/nn/functional.pyi.in
@@ -22,9 +22,9 @@ GRID_SAMPLE_PADDING_MODES = Dict[str, int]
# This was necessary since the JIT uses BroadcastingList* types but static checking with mypy etc requires a `Sequence`
# type. There is no way to express the expected lengths of these lists in the current Python typing system.
#
# Functions created via `_add_docstr` in `functional.py` were merely typed as `Any` by `stubgen`, so those were
# deleted from the stub and replaced by generated declarations. See `gen_pyi` for the implementation of the code
# generation logic for those functions. In the future, it might be worth looking into using the mypy plugin system
# to encode the type semantics of `_add_docstr`, should that system ever become widespread.
def fractional_max_pool2d_with_indices(input: Tensor, kernel_size: _size, output_size: Optional[_size] = ...,
output_ratio: Optional[_ratio_any_t] = ..., return_indices: bool = ...,
@@ -311,6 +311,12 @@ def triplet_margin_loss(anchor: Tensor, positive: Tensor, negative: Tensor, marg
reduce: Optional[bool] = ..., reduction: str = ...) -> Tensor: ...


def triplet_margin_loss_with_distance(anchor: Tensor, positive: Tensor, negative: Tensor,
                                      distance_function: Optional[Callable[[Tensor, Tensor], Tensor]] = ...,
                                      is_similarity_function: bool = ..., margin: float = ...,
                                      swap: bool = ..., reduction: str = ...) -> Tensor: ...


def normalize(input: Tensor, p: float = ..., dim: int = ..., eps: float = ...,
out: Optional[Tensor] = ...) -> Tensor: ...

6 changes: 3 additions & 3 deletions torch/nn/modules/__init__.py
@@ -8,8 +8,8 @@
Hardsigmoid, Hardswish, SiLU
from .loss import L1Loss, NLLLoss, KLDivLoss, MSELoss, BCELoss, BCEWithLogitsLoss, NLLLoss2d, \
CosineEmbeddingLoss, CTCLoss, HingeEmbeddingLoss, MarginRankingLoss, \
MultiLabelMarginLoss, MultiLabelSoftMarginLoss, MultiMarginLoss, SmoothL1Loss, \
SoftMarginLoss, CrossEntropyLoss, TripletMarginLoss, TripletMarginLossWithDistance, PoissonNLLLoss
from .container import Container, Sequential, ModuleList, ModuleDict, ParameterList, ParameterDict
from .pooling import AvgPool1d, AvgPool2d, AvgPool3d, MaxPool1d, MaxPool2d, MaxPool3d, \
MaxUnpool1d, MaxUnpool2d, MaxUnpool3d, FractionalMaxPool2d, FractionalMaxPool3d, LPPool1d, LPPool2d, \
@@ -54,5 +54,5 @@
'ConstantPad3d', 'Bilinear', 'CosineSimilarity', 'Unfold', 'Fold',
'AdaptiveLogSoftmaxWithLoss', 'TransformerEncoder', 'TransformerDecoder',
'TransformerEncoderLayer', 'TransformerDecoderLayer', 'Transformer',
'Flatten', 'Unflatten', 'Hardsigmoid', 'Hardswish', 'SiLU', 'TripletMarginLossWithDistance'
]
149 changes: 147 additions & 2 deletions torch/nn/modules/loss.py
@@ -1,11 +1,14 @@
import warnings

import torch

from .distance import PairwiseDistance
from .module import Module
from .. import functional as F
from .. import _reduction as _Reduction

from torch import Tensor
from typing import Optional
from typing import Callable, Optional


class _Loss(Module):
@@ -1191,6 +1194,9 @@ class TripletMarginLoss(_Loss):
.. math::
    d(x_i, y_i) = \left\lVert {\bf x}_i - {\bf y}_i \right\rVert_p

See also :class:`~torch.nn.TripletMarginLossWithDistance`, which computes the
triplet margin loss for input tensors using a custom distance function.

Args:
margin (float, optional): Default: :math:`1`.
p (int, optional): The norm degree for pairwise distance. Default: :math:`2`.
@@ -1215,7 +1221,8 @@
Shape:
- Input: :math:`(N, D)` where :math:`D` is the vector dimension.
- Output: If :attr:`reduction` is ``'none'``, then a tensor of shape :math:`(N)`,
  or a scalar otherwise.

>>> triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)
>>> anchor = torch.randn(100, 128, requires_grad=True)
@@ -1246,6 +1253,144 @@ def forward(self, anchor: Tensor, positive: Tensor, negative: Tensor) -> Tensor:
eps=self.eps, swap=self.swap, reduction=self.reduction)


class TripletMarginLossWithDistance(_Loss):
r"""Creates a criterion that measures the triplet loss given input
tensors :math:`a`, :math:`p`, and :math:`n` (representing anchor,
positive, and negative examples, respectively) and a real-valued function
between them.

The unreduced loss (i.e., with :attr:`reduction` set to ``'none'``)
can be described as:

.. math::
    \ell(a, p, n) = L = \{l_1,\dots,l_N\}^\top, \quad
    l_i = \max \{d(a_i, p_i) - d(a_i, n_i) + {\rm margin}, 0\}

where :math:`N` is the batch size; :math:`d` is a real-valued function quantifying
the separation between two tensors, referred to as the `distance_function`;
and :math:`margin` is a non-negative margin enforced between the positive and
negative distances. The input tensors have :math:`N` elements each and can be of
any shape that the distance function can handle.

If :attr:`reduction` is not ``'none'`` (default ``'mean'``), then:

.. math::
    \ell(x, y) =
    \begin{cases}
        \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\
        \operatorname{sum}(L),  & \text{if reduction} = \text{`sum'.}
    \end{cases}

See also :class:`~torch.nn.TripletMarginLoss`, which computes the triplet
loss for input tensors using the :math:`l_p` distance as the distance function.

Args:
distance_function (callable, optional): A distance function between two Tensors which,
if specified, will be used instead of the pairwise distance. If not specified,
`nn.PairwiseDistance` will be used. Default: ``None``
is_similarity_function (bool, optional): Whether `distance_function` represents a
similarity metric, i.e., larger is closer. If True, computes the difference of
distances as :math:`d(a_i, n_i) - d(a_i, p_i)` so that larger loss values occur
when the negative example is more similar to the anchor than the positive example
is. Default: ``False``
margin (float, optional): A non-negative margin enforced between the positive and
negative distances. Larger margins penalize cases where the negative examples
are not distant enough from the anchors, relative to the positives. Default: :math:`1`.
swap (bool, optional): Whether to use the distance swap described in the paper
`Learning shallow convolutional feature descriptors with triplet losses` by
V. Balntas, E. Riba et al. If True, and if the positive example is closer to the
negative example than the anchor is, swaps the positive example and the anchor in
the loss computation. Default: ``False``.
reduction (string, optional): Specifies the (optional) reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
``'mean'``: the sum of the output will be divided by the number of
elements in the output, ``'sum'``: the output will be summed. Default: ``'mean'``

Shape:
- Input: :math:`(N, *)` where :math:`*` represents any number of additional dimensions
  as supported by the distance function.
- Output: If :attr:`reduction` is ``'none'``, then a tensor of shape :math:`(N)`,
  or a scalar otherwise.

Example::

>>> # Initialize embeddings
>>> embedding = nn.Embedding(1000, 128)
>>> anchor_ids = torch.randint(0, 1000, (1,))
>>> positive_ids = torch.randint(0, 1000, (1,))
>>> negative_ids = torch.randint(0, 1000, (1,))
>>> anchor = embedding(anchor_ids)
>>> positive = embedding(positive_ids)
>>> negative = embedding(negative_ids)
>>>
>>> # Built-in Distance Function
>>> triplet_loss = nn.TripletMarginLossWithDistance(distance_function=nn.PairwiseDistance())
>>> output = triplet_loss(anchor, positive, negative)
>>> output.backward()
>>>
>>> # Built-in Similarity Function
>>> triplet_loss = nn.TripletMarginLossWithDistance(distance_function=nn.CosineSimilarity(), is_similarity_function=True)
>>> output = triplet_loss(anchor, positive, negative)
>>> output.backward()
>>>
>>> # User-defined Distance Function
>>> def l_infinity(x1, x2):
...     return torch.max(torch.abs(x1 - x2), dim=1).values
>>>
>>> triplet_loss = nn.TripletMarginLossWithDistance(distance_function=l_infinity, margin=1.5)
>>> output = triplet_loss(anchor, positive, negative)
>>> output.backward()

Reference:
V. Balntas, et al.: Learning shallow convolutional feature descriptors with triplet losses:
http://www.bmva.org/bmvc/2016/papers/paper119/index.html
"""
__constants__ = ['is_similarity_function', 'margin', 'swap', 'reduction']
is_similarity_function: bool
margin: float
swap: bool

def __init__(self, distance_function: Optional[Callable[[Tensor, Tensor], Tensor]] = None, is_similarity_function: bool = False,
margin: float = 1.0, swap: bool = False, reduction: str = 'mean'):
super(TripletMarginLossWithDistance, self).__init__(size_average=None, reduce=None, reduction=reduction)
self.distance_function = distance_function if distance_function is not None else PairwiseDistance()
self.is_similarity_function = is_similarity_function
self.margin = margin
self.swap = swap

def forward(self, anchor: Tensor, positive: Tensor, negative: Tensor) -> Tensor:
if not torch.jit.is_scripting():
return F.triplet_margin_loss_with_distance(anchor, positive, negative,
distance_function=self.distance_function,
is_similarity_function=self.is_similarity_function,
margin=self.margin, swap=self.swap, reduction=self.reduction)
else:
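# TorchScript cannot call the functional form above (it rejects arbitrary
# callables), so the computation is duplicated inline for the scripted path.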
positive_dist = self.distance_function(anchor, positive)
negative_dist = self.distance_function(anchor, negative)

if self.swap:
swap_dist = self.distance_function(positive, negative)
if self.is_similarity_function:
negative_dist = torch.max(negative_dist, swap_dist)
else:
negative_dist = torch.min(negative_dist, swap_dist)

if self.is_similarity_function:
output = torch.clamp(negative_dist - positive_dist + self.margin, min=0.0)
else:
output = torch.clamp(positive_dist - negative_dist + self.margin, min=0.0)
reduction_enum = _Reduction.get_enum(self.reduction)

if reduction_enum == 1:
return output.mean()
elif reduction_enum == 2:
return output.sum()
else:
return output


class CTCLoss(_Loss):
r"""The Connectionist Temporal Classification loss.
3 changes: 3 additions & 0 deletions torch/overrides.py
@@ -615,6 +615,9 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.nn.functional.threshold: lambda input, threshold, value, inplace=False: -1,
torch.nn.functional.triplet_margin_loss: (lambda anchor, positive, negative, margin=1.0, p=2, eps=1e-06,
swap=False, size_average=None, reduce=None, reduction='mean': -1),
torch.nn.functional.triplet_margin_loss_with_distance: (lambda anchor, positive, negative, distance_function=None,
is_similarity_function=False, margin=1.0,
swap=False, reduction='mean': -1),
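# The dummy lambda's signature must mirror the real function's so that
# __torch_function__ override coverage can be checked against it.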
torch.nn.functional.unfold: lambda input, kernel_size, dilation=1, padding=0, stride=1: -1,
torch.nonzero: lambda input, as_tuple=False: -1,
torch.norm: lambda input, p='fro', dim=None, keepdim=False, out=None, dtype=None: -1,
