In [1]:
import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt

# TF 1.x API: turn on eager mode so tensors are evaluated immediately.
# Later cells rely on this (they call `.numpy()` and print concrete tensor values).
tf.enable_eager_execution()

%matplotlib inline

# Reference (presumably the source of the triplet-loss helpers defined below):
# https://omoindrot.github.io/triplet-loss#a-better-implementation-with-online-triplet-mining

  from ._conv import register_converters as _register_converters


In [14]:
def _pairwise_distances(embeddings, squared=False):
    """Compute the 2D matrix of Euclidean distances between all the embeddings.

    Args:
        embeddings: floating-point tensor of shape (batch_size, embed_dim)
        squared: Boolean. If true, output is the pairwise squared euclidean distance matrix.
                 If false, output is the pairwise euclidean distance matrix.

    Returns:
        pairwise_distances: tensor of shape (batch_size, batch_size) with the same dtype
            as `embeddings`; entry [i, j] is the distance between embeddings i and j.
    """
    # Gram matrix: entry [i, j] is the dot product <embeddings[i], embeddings[j]>.
    # shape (batch_size, batch_size)
    dot_product = tf.matmul(embeddings, tf.transpose(embeddings))

    # The diagonal of the Gram matrix holds the squared L2 norm of each embedding.
    # Reusing it (instead of recomputing the norms) makes the diagonal of the
    # result below cancel to exactly 0.
    # shape (batch_size,)
    square_norm = tf.diag_part(dot_product)

    # ||a - b||^2 = ||a||^2 - 2 <a, b> + ||b||^2, built by broadcasting a
    # (1, batch_size) row against a (batch_size, 1) column.
    # shape (batch_size, batch_size)
    distances = tf.expand_dims(square_norm, 0) - 2.0 * dot_product + tf.expand_dims(square_norm, 1)

    # Rounding errors can leave tiny negative values; clamp everything to >= 0.0.
    distances = tf.maximum(distances, 0.0)

    if not squared:
        # The gradient of sqrt is infinite at 0.0 (e.g. on the diagonal), so add a
        # small epsilon exactly where distances == 0.0 before taking the root.
        # The mask is cast to the dtype of `distances` (rather than hard-coded
        # float32) so that float64 embeddings do not trigger a dtype mismatch.
        zero_mask = tf.cast(tf.equal(distances, 0.0), distances.dtype)
        distances = distances + zero_mask * 1e-16

        distances = tf.sqrt(distances)

        # Undo the epsilon: entries that were exactly 0.0 are set back to 0.0.
        distances = distances * (1.0 - zero_mask)

    return distances



def batch_all_triplet_loss(labels, embeddings, margin, squared=False):
    """Build the triplet loss over a batch of embeddings.

    We generate all the valid triplets and average the loss over the positive ones
    (i.e. the triplets whose loss is strictly greater than zero).

    Args:
        labels: labels of the batch, of size (batch_size,)
        embeddings: floating-point tensor of shape (batch_size, embed_dim)
        margin: margin for triplet loss
        squared: Boolean. If true, the loss is built from squared euclidean distances.
                 If false, it is built from (non-squared) euclidean distances.

    Returns:
        A tuple `(triplet_loss, fraction_positive_triplets)`:
        triplet_loss: scalar tensor, the loss summed over all valid triplets and
            divided by the number of triplets with positive loss
        fraction_positive_triplets: scalar tensor, number of triplets with positive
            loss divided by the number of valid triplets
    """
    # Get the pairwise distance matrix, shape (batch_size, batch_size)
    pairwise_dist = _pairwise_distances(embeddings, squared=squared)

    # d(anchor, positive) laid out as (batch_size, batch_size, 1) and
    # d(anchor, negative) as (batch_size, 1, batch_size) so they broadcast together.
    anchor_positive_dist = tf.expand_dims(pairwise_dist, 2)
    anchor_negative_dist = tf.expand_dims(pairwise_dist, 1)

    # 3D tensor of size (batch_size, batch_size, batch_size):
    # triplet_loss[i, j, k] is the loss for anchor=i, positive=j, negative=k.
    triplet_loss = anchor_positive_dist - anchor_negative_dist + margin

    # Put to zero the invalid triplets
    # (where label(a) != label(p) or label(n) == label(a) or a == p).
    # Cast to the loss dtype (not a fixed float32) so float64 inputs also work.
    mask = tf.cast(_get_triplet_mask(labels), triplet_loss.dtype)
    triplet_loss = tf.multiply(mask, triplet_loss)

    # Remove negative losses (i.e. the easy triplets)
    triplet_loss = tf.maximum(triplet_loss, 0.0)

    # Count number of positive triplets (where triplet_loss > 0)
    valid_triplets = tf.cast(tf.greater(triplet_loss, 1e-16), triplet_loss.dtype)
    num_positive_triplets = tf.reduce_sum(valid_triplets)
    num_valid_triplets = tf.reduce_sum(mask)
    # The small epsilon keeps both divisions finite when a count is zero.
    fraction_positive_triplets = num_positive_triplets / (num_valid_triplets + 1e-16)

    # Get final mean triplet loss over the positive valid triplets
    triplet_loss = tf.reduce_sum(triplet_loss) / (num_positive_triplets + 1e-16)

    return triplet_loss, fraction_positive_triplets


def _get_triplet_mask(labels):
    """Return a boolean 3D mask where mask[a, p, n] is True iff (a, p, n) is a valid triplet.

    A triplet is valid when a != p, labels[a] == labels[p] and labels[a] != labels[n].
    The two label conditions already imply that n differs from both a and p.

    Args:
        labels: tensor of shape (batch_size,)

    Returns:
        mask: tf.bool tensor of shape (batch_size, batch_size, batch_size)
    """
    # labels_equal[i, j] is True iff labels[i] == labels[j]
    # shape (batch_size, batch_size)
    labels_equal = tf.equal(tf.expand_dims(labels, 0), tf.expand_dims(labels, 1))

    # indices_not_equal[i, j] is True iff i != j
    indices_equal = tf.cast(tf.eye(tf.shape(labels)[0]), tf.bool)
    indices_not_equal = tf.logical_not(indices_equal)

    # Valid (anchor, positive) pairs: same label, different index -> (B, B, 1)
    anchor_positive_ok = tf.expand_dims(tf.logical_and(labels_equal, indices_not_equal), 2)
    # Valid (anchor, negative) pairs: different label -> (B, 1, B)
    anchor_negative_ok = tf.expand_dims(tf.logical_not(labels_equal), 1)

    return tf.logical_and(anchor_positive_ok, anchor_negative_ok)




In [3]:
# Toy batch for the experiments below: 7 random embeddings of dimension 10.
# No seed is set in the visible cells, so the values differ from run to run.
embeddings = tf.random_normal([7, 10])

In [4]:
# Redo the first steps of _pairwise_distances by hand so each intermediate
# tensor can be inspected in the cells below.
# Gram matrix of all pairwise dot products.
dot_product = tf.matmul(embeddings, tf.transpose(embeddings))
# Its diagonal: the squared L2 norm of each embedding.
square_norm = tf.diag_part(dot_product)
# ||a - b||^2 = ||a||^2 - 2 <a, b> + ||b||^2 via broadcasting.
# No clamping and no sqrt here, so these are *squared* distances.
distances = tf.expand_dims(square_norm, 0) - 2.0 * dot_product + tf.expand_dims(square_norm, 1)


In [5]:
# Sanity check: dot_product[i, j] should equal the dot product of embeddings i and j.
print(embeddings.shape, embeddings[0].shape)
print(embeddings[0])
# <e0, e0>: compare with dot_product[0, 0] in the matrix printed last.
print(tf.tensordot(embeddings[0], embeddings[0], axes=1))
print('-'*10)
# <e1, e2>: compare with dot_product[1, 2].
print(tf.tensordot(embeddings[1], embeddings[2], axes=1))
print('-'*10)
print(dot_product.numpy())

(7, 10) (10,)
tf.Tensor(
[-1.0018165   0.33871594  0.9615877   2.672127   -0.8645353   0.7446462
  0.2539306   1.199053   -0.89967215 -0.09572788], shape=(10,), dtype=float32)
tf.Tensor(12.80598, shape=(), dtype=float32)
----------
tf.Tensor(2.3669274, shape=(), dtype=float32)
----------
[[12.805981    6.4123864   0.6581936   4.5202904  -1.3989016  -1.0348586
   2.466448  ]
 [ 6.4123864  14.190258    2.3669274  -1.3510969  -0.37516654 -2.522407
   7.6524663 ]
 [ 0.6581936   2.3669274   6.287444   -4.3116736  -1.537779   -2.247109
  -0.79516757]
 [ 4.5202904  -1.3510969  -4.3116736  19.387615    2.36764     4.153995
  -1.3458203 ]
 [-1.3989017  -0.37516657 -1.5377787   2.3676403   8.851185    2.3401558
  -4.468193  ]
 [-1.0348585  -2.522407   -2.247109    4.1539955   2.3401558   5.2345133
  -2.0617304 ]
 [ 2.4664478   7.6524663  -0.79516757 -1.3458201  -4.468193   -2.0617304
  10.947656  ]]


In [6]:
# Diagonal of dot_product: one squared norm per embedding.
square_norm

<tf.Tensor: id=16, shape=(7,), dtype=float32, numpy=
array([12.805981 , 14.190258 ,  6.287444 , 19.387615 ,  8.851185 ,
        5.2345133, 10.947656 ], dtype=float32)>

In [7]:
# Insert a leading axis: the norms become a single row that broadcasts down the rows.
tf.expand_dims(square_norm, 0)

<tf.Tensor: id=77, shape=(1, 7), dtype=float32, numpy=
array([[12.805981 , 14.190258 ,  6.287444 , 19.387615 ,  8.851185 ,
         5.2345133, 10.947656 ]], dtype=float32)>

In [8]:
# Insert a trailing axis: the norms become a single column that broadcasts across the columns.
tf.expand_dims(square_norm, 1)

<tf.Tensor: id=80, shape=(7, 1), dtype=float32, numpy=
array([[12.805981 ],
       [14.190258 ],
       [ 6.287444 ],
       [19.387615 ],
       [ 8.851185 ],
       [ 5.2345133],
       [10.947656 ]], dtype=float32)>

In [9]:
# Hand-computed squared pairwise distances (no sqrt applied).
distances

<tf.Tensor: id=24, shape=(7, 7), dtype=float32, numpy=
array([[ 0.      , 14.171466, 17.777039, 23.153015, 24.454967, 20.11021 ,
        18.82074 ],
       [14.171466,  0.      , 15.743847, 36.280067, 23.791775, 24.469585,
         9.832981],
       [17.777039, 15.743848,  0.      , 34.29841 , 18.214188, 16.016174,
        18.825436],
       [23.153015, 36.280067, 34.29841 ,  0.      , 23.503521, 16.314138,
        33.026913],
       [24.45497 , 23.791775, 18.214188, 23.50352 ,  0.      ,  9.405386,
        28.735226],
       [20.11021 , 24.469585, 16.016174, 16.314137,  9.405386,  0.      ,
        20.30563 ],
       [18.82074 ,  9.832981, 18.825436, 33.026913, 28.735226, 20.30563 ,
         0.      ]], dtype=float32)>

In [10]:
# Same batch through the helper; squared defaults to False, so this returns
# the non-squared distances (compare with the squared `distances` above).
pairwise_dist = _pairwise_distances(embeddings)
pairwise_dist

<tf.Tensor: id=109, shape=(7, 7), dtype=float32, numpy=
array([[0.       , 3.7645009, 4.2162824, 4.811758 , 4.945196 , 4.4844403,
        4.338288 ],
       [3.7645009, 0.       , 3.9678516, 6.023294 , 4.8776813, 4.946674 ,
        3.1357584],
       [4.2162824, 3.9678519, 0.       , 5.856484 , 4.267808 , 4.002021 ,
        4.338829 ],
       [4.811758 , 6.023294 , 5.856484 , 0.       , 4.848043 , 4.0390763,
        5.746904 ],
       [4.945196 , 4.8776813, 4.267808 , 4.8480425, 0.       , 3.0668201,
        5.3605247],
       [4.4844403, 4.946674 , 4.002021 , 4.0390763, 3.0668201, 0.       ,
        4.506177 ],
       [4.338288 , 3.1357584, 4.338829 , 5.746904 , 5.3605247, 4.506177 ,
        0.       ]], dtype=float32)>

In [11]:
# Same two reshapes batch_all_triplet_loss performs: a trailing axis for the
# anchor-positive term and a middle axis for the anchor-negative term.
anchor_positive_dist = tf.expand_dims(pairwise_dist, 2)
anchor_negative_dist = tf.expand_dims(pairwise_dist, 1)
print(anchor_positive_dist.shape, anchor_negative_dist.shape)

(7, 7, 1) (7, 1, 7)


In [12]:
# Inspect the expanded anchor-positive tensor (trailing axis of size 1).
anchor_positive_dist

<tf.Tensor: id=112, shape=(7, 7, 1), dtype=float32, numpy=
array([[[0.       ],
        [3.7645009],
        [4.2162824],
        [4.811758 ],
        [4.945196 ],
        [4.4844403],
        [4.338288 ]],

       [[3.7645009],
        [0.       ],
        [3.9678516],
        [6.023294 ],
        [4.8776813],
        [4.946674 ],
        [3.1357584]],

       [[4.2162824],
        [3.9678519],
        [0.       ],
        [5.856484 ],
        [4.267808 ],
        [4.002021 ],
        [4.338829 ]],

       [[4.811758 ],
        [6.023294 ],
        [5.856484 ],
        [0.       ],
        [4.848043 ],
        [4.0390763],
        [5.746904 ]],

       [[4.945196 ],
        [4.8776813],
        [4.267808 ],
        [4.8480425],
        [0.       ],
        [3.0668201],
        [5.3605247]],

       [[4.4844403],
        [4.946674 ],
        [4.002021 ],
        [4.0390763],
        [3.0668201],
        [0.       ],
        [4.506177 ]],

       [[4.338288 ],
        [3.1357584],
     

In [13]:
# Inspect the expanded anchor-negative tensor (middle axis of size 1).
anchor_negative_dist

<tf.Tensor: id=114, shape=(7, 1, 7), dtype=float32, numpy=
array([[[0.       , 3.7645009, 4.2162824, 4.811758 , 4.945196 ,
         4.4844403, 4.338288 ]],

       [[3.7645009, 0.       , 3.9678516, 6.023294 , 4.8776813,
         4.946674 , 3.1357584]],

       [[4.2162824, 3.9678519, 0.       , 5.856484 , 4.267808 ,
         4.002021 , 4.338829 ]],

       [[4.811758 , 6.023294 , 5.856484 , 0.       , 4.848043 ,
         4.0390763, 5.746904 ]],

       [[4.945196 , 4.8776813, 4.267808 , 4.8480425, 0.       ,
         3.0668201, 5.3605247]],

       [[4.4844403, 4.946674 , 4.002021 , 4.0390763, 3.0668201,
         0.       , 4.506177 ]],

       [[4.338288 , 3.1357584, 4.338829 , 5.746904 , 5.3605247,
         4.506177 , 0.       ]]], dtype=float32)>

In [18]:
# Shrink to a batch of 2 so the broadcast 3D tensors below are small enough to read.
# NOTE: this rebinds `embeddings` and `pairwise_dist`; the 7-sample values from
# the earlier cells are no longer available after this cell runs.
embeddings = tf.random_normal([2, 10])

pairwise_dist = _pairwise_distances(embeddings)

In [19]:
# Distance matrix of the 2-sample batch.
pairwise_dist

<tf.Tensor: id=181, shape=(2, 2), dtype=float32, numpy=
array([[0.      , 5.063244],
       [5.063244, 0.      ]], dtype=float32)>

In [20]:
# Re-create the expanded views for the 2-sample batch (overwrites the 7-sample versions).
anchor_positive_dist = tf.expand_dims(pairwise_dist, 2)
anchor_negative_dist = tf.expand_dims(pairwise_dist, 1)

In [23]:
# Raw triplet loss before masking/clamping:
# triplet_loss[i, j, k] = d(i, j) - d(i, k) + margin, via broadcasting.
margin = .1
triplet_loss = anchor_positive_dist - anchor_negative_dist + margin
triplet_loss

<tf.Tensor: id=192, shape=(2, 2, 2), dtype=float32, numpy=
array([[[ 0.1     , -4.963244],
        [ 5.163244,  0.1     ]],

       [[ 0.1     ,  5.163244],
        [-4.963244,  0.1     ]]], dtype=float32)>

In [24]:
# Confirm the shapes of the two operands that were broadcast together.
print(anchor_positive_dist.shape, anchor_negative_dist.shape)

(2, 2, 1) (2, 1, 2)


In [31]:
# Broadcasting walkthrough. Uses `+` instead of `-` purely to make it easy to
# see which entries are combined: result[i, j, k] = d(i, j) + d(i, k).
print(anchor_positive_dist)
print(anchor_negative_dist)
print('-'*4)
# Slices for anchor 0: a column and a row.
print(anchor_positive_dist[0])
print(anchor_negative_dist[0])
# Column + row broadcasts to a full 2D grid, one per anchor.
print(anchor_positive_dist[0] + anchor_negative_dist[0])
print(anchor_positive_dist[1] + anchor_negative_dist[1])
# The full 3D sum stacks the per-anchor grids printed above.
print(anchor_positive_dist + anchor_negative_dist)

tf.Tensor(
[[[0.      ]
  [5.063244]]

 [[5.063244]
  [0.      ]]], shape=(2, 2, 1), dtype=float32)
tf.Tensor(
[[[0.       5.063244]]

 [[5.063244 0.      ]]], shape=(2, 1, 2), dtype=float32)
----
tf.Tensor(
[[0.      ]
 [5.063244]], shape=(2, 1), dtype=float32)
tf.Tensor([[0.       5.063244]], shape=(1, 2), dtype=float32)
tf.Tensor(
[[ 0.        5.063244]
 [ 5.063244 10.126488]], shape=(2, 2), dtype=float32)
tf.Tensor(
[[10.126488  5.063244]
 [ 5.063244  0.      ]], shape=(2, 2), dtype=float32)
tf.Tensor(
[[[ 0.        5.063244]
  [ 5.063244 10.126488]]

 [[10.126488  5.063244]
  [ 5.063244  0.      ]]], shape=(2, 2, 2), dtype=float32)


In [28]:
# Same column-plus-row broadcast in plain NumPy with integer stand-ins.
np.array([[0],
          [5]]) + np.array([[0, 5]])

array([[ 0,  5],
       [ 5, 10]])