## Triplet loss numpy

In [24]:
import numpy as np

logits = np.random.randn(4, 3)
labels = np.concatenate([np.zeros([2, 1]), np.ones([2, 1])], axis=0)
logits_norm = logits / np.linalg.norm(logits)

In [41]:
print(logits)
print(logits_norm, 2)
print(labels)

[[ 1.12180787 -1.11408089 -1.35190712]
 [ 0.64929267  0.53293354 -0.70270401]
 [-0.08280154 -0.23974176  0.30269162]
 [ 0.32925159  0.60438217  1.43597645]]
[[ 0.39131491 -0.38861954 -0.47157934]
 [ 0.22648968  0.18590067 -0.2451209 ]
 [-0.02888327 -0.08362798  0.10558648]
 [ 0.11485127  0.21082376  0.50090485]] 2
[[ 0.]
 [ 0.]
 [ 1.]
 [ 1.]]


In [116]:
margin=0.5

### tensorflow

In [83]:
def pairwise_distance(logits):
    pairwise_distances_squared = np.add(
        np.sum(np.square(logits), axis=1, keepdims=True),
        np.sum(np.square(logits.T), axis=0, keepdims=True),
    ) - 2. * np.dot(logits, logits.T)

    error_mask = np.less_equal(pairwise_distances_squared, 0.0)
    pairwise_distances = np.multiply(pairwise_distances_squared, np.logical_not(error_mask).astype(float))
    
    num_data = logits.shape[0]
    mask_offdiagonals = np.ones_like(pairwise_distances) - np.diag(np.ones([num_data]))
    pairwise_distances = np.multiply(pairwise_distances, mask_offdiagonals)
    return pairwise_distances

def masked_minimum(data, mask):
    axis_maximums = np.max(data, axis=1, keepdims=True)
    return np.min(np.multiply(data - axis_maximums, mask), axis=1, keepdims=True) + axis_maximums

def masked_maximum(data, mask):
    axis_minimums = np.min(data, axis=1, keepdims=True)
    return np.max(data - axis_minimums, axis=1, keepdims=True)

In [121]:
pdist_matrix = pairwise_distance(logits)
adjacency = np.equal(labels, labels.T)
adjacency_not = np.logical_not(adjacency)

batch_size = logits.shape[0]

pdist_matrix_tile = np.tile(pdist_matrix, [batch_size, 1])

mask = np.logical_and(
    np.tile(adjacency_not, [batch_size, 1]),
    np.greater(pdist_matrix_tile, pdist_matrix.T.reshape([-1,1])),
)

mask_final = np.greater(np.sum(mask.astype(float), axis=1, keepdims=True), 0.0).reshape([batch_size, batch_size])
mask_final = mask_final.T

mask = mask.astype(float)
adjacency_not = adjacency_not.astype(float)

negatives_outside = masked_minimum(pdist_matrix_tile, mask).reshape([batch_size, batch_size]).T
negatives_inside = np.tile(masked_maximum(pdist_matrix, adjacency_not), [1, batch_size])

semi_hard_negatives = np.where(mask_final, negatives_outside, negatives_inside)

loss_mat = np.add(margin, pdist_matrix - semi_hard_negatives)

mask_positives = adjacency.astype(float) - np.diag(np.ones([batch_size]))

num_positives = np.sum(mask_positives)

In [122]:
num_positives

4.0

In [123]:
loss_mat

array([[-4.45324973, -1.09585791, -5.90030585,  0.5       ],
       [-0.82409353, -1.64380941, -2.03767593,  0.5       ],
       [ 0.5       , -2.30944032, -1.64380941, -2.28658223],
       [ 0.5       , -6.17207023, -2.01481784, -4.18148534]])

In [124]:
pdist_matrix

array([[  0.        ,   3.35739181,   4.95324973,  11.35355557],
       [  3.35739181,   0.        ,   2.14380941,   4.68148534],
       [  4.95324973,   2.14380941,   0.        ,   2.1666675 ],
       [ 11.35355557,   4.68148534,   2.1666675 ,   0.        ]])

In [125]:
semi_hard_negatives

array([[  4.95324973,   4.95324973,  11.35355557,  11.35355557],
       [  4.68148534,   2.14380941,   4.68148534,   4.68148534],
       [  4.95324973,   4.95324973,   2.14380941,   4.95324973],
       [ 11.35355557,  11.35355557,   4.68148534,   4.68148534]])

### custom

In [None]:
def _pairwise_distances(logits):
    dot_product = np.dot(logits, logits.T)
    square_norm = np.diagonal(dot_product)
    
    distances = np.add(
        np.expand_dims(square_norm, 0),
        np.expand_dims(square_norm, 1)
    ) - 2. * dot_product
    distances = np.maximum(distances, 0.)
    
    return distances

def _get_anchor_positive_triplet_mask(labels):
    indices_equal = np.eye(labels.shape[0]).astype(bool)
    indices_not_equal = np.logical_not(indices_equal)
    labels_equal = np.equal(labels, labels.T)
    
    return np.logical_and(labels_equal, indices_not_equal)

def _get_anchor_negative_triplet_mask(labels):
    return np.logical_not(
        np.equal(labels, labels.T)
    )

def _get_triplet_mask(labels):
    indices_equal = np.eye(labels.shape[0]).astype(bool)
    indices_not_equal = np.logical_not(indices_equal)
    
    i_not_equal_j = np.expand_dims(indices_not_equal, 2)
    i_not_equal_k = np.expand_dims(indices_not_equal, 1)
    j_not_equal_k = np.expand_dims(indices_not_equal, 0)
    distinct_indices = np.logical_and(np.logical_and(i_not_equal_j, i_not_equal_k), j_not_equal_k)
    
    labels_equal = np.equal(labels, labels.T)
    i_equal_j = np.expand_dims(labels_equal, 2)
    i_equal_k = np.expand_dims(labels_equal, 1)
    valid_labels = np.logical_and(i_equal_j, np.logical_not(i_equal_k))
    
    return np.logical_and(distinct_indices, valid_labels)

In [None]:
def batch_hard_triplet_loss(logits_norm, labels):
    pairwise_dist = _pairwise_distances(logits_norm)

    mask_anchor_positive = _get_anchor_positive_triplet_mask(labels).astype(float)

    anchor_positive_dist = np.multiply(mask_anchor_positive, pairwise_dist)

    hardest_positive_dist = np.max(anchor_positive_dist, axis=1, keepdims=True)

    mask_anchor_negative = _get_anchor_negative_triplet_mask(labels).astype(float)

    max_anchor_negative_dist = np.max(pairwise_dist, axis=1, keepdims=True)
    anchor_negative_dist = pairwise_dist + max_anchor_negative_dist * (1. - mask_anchor_negative)

    hardest_negative_dist = np.min(anchor_negative_dist, axis=1, keepdims=True)
    
    return hardest_positive_dist, hardest_negative_dist

In [None]:
batch_hard_triplet_loss(logits_norm, labels)

In [None]:
def batch_all_triplet_loss(logits_norm, labels, margin=0.5):
    pairwise_dist = _pairwise_distances(logits_norm)
    
    anchor_positive_dist = np.expand_dims(pairwise_dist, 2)
    anchor_negative_dist = np.expand_dims(pairwise_dist, 1)
    
    triplet_loss = anchor_positive_dist - anchor_negative_dist + margin
    
    mask = _get_triplet_mask(labels).astype(float)
    
    return triplet_loss

In [None]:
batch_all_triplet_loss(logits_norm, labels)