In [8]:
import numpy as np
import torch
import torch.nn as nn

In [17]:
def triplet_loss(anchor, label, embeddings, database, criterion=None, mining=True):
    """
    Given an anchor, find its positive P and hardest negative N, compute the
    triplet loss and backpropagate it.

    Parameters:
        anchor: feature embedding for the anchor (torch tensor or array-like).
            Gradients reach it only if it is a tensor with requires_grad=True.
            Typically shape (D,) or (1, D).
        label: label for the anchor; looked up in `database`.
        embeddings: 2-D np.ndarray or torch tensor of embeddings, one row per
            index used in `database` - assumed to be in dataset order and
            normalized.
        database: mapping from label to the row index of that label's
            embedding in `embeddings`.
        criterion: loss module called as criterion(anchor, positive, negative).
            When None, a fresh nn.TripletMarginLoss(margin=1.0) is built.
        mining: currently ignored - the hardest negative (the nearest row
            other than the positive) is always mined. Kept for backward
            compatibility. TODO: implement a cheaper non-mining path.

    Returns:
        loss: the calculated loss tensor. loss.backward() has already been
        called whenever the loss is attached to the autograd graph. The loss
        is also printed.

    Raises:
        ValueError: if `embeddings` is not 2-D or has fewer than two rows,
            so no negative candidate exists.
    """
    # The distances are only used to pick indices. Compute them on detached
    # numpy copies, because np.linalg.norm on a tensor that requires grad fails.
    anchor_t = anchor if torch.is_tensor(anchor) else torch.as_tensor(anchor)
    anchor_np = anchor_t.detach().cpu().numpy()
    if torch.is_tensor(embeddings):
        embeddings_np = embeddings.detach().cpu().numpy()
    else:
        embeddings_np = np.asarray(embeddings)
    if embeddings_np.ndim != 2 or len(embeddings_np) < 2:
        raise ValueError(
            "embeddings must be 2-D with at least two rows to mine a negative"
        )

    pairwise_dist = np.linalg.norm(embeddings_np - anchor_np, axis=1)

    p_idx = database[label]

    # Hardest negative: the nearest row that is not the positive. Masking the
    # positive with +inf avoids the empty slice (and the argmin ValueError)
    # that a split-slice search hits when p_idx == 0. argmin still breaks ties
    # toward the lowest index.
    masked_dist = pairwise_dist.astype(float)  # astype copies by default
    masked_dist[p_idx] = np.inf
    n_idx = int(np.argmin(masked_dist))

    if torch.is_tensor(embeddings):
        # Index the tensor itself so gradients can also flow into embeddings.
        positive = embeddings[p_idx]
        negative = embeddings[n_idx]
    else:
        positive = torch.as_tensor(
            embeddings_np[p_idx], dtype=anchor_t.dtype, device=anchor_t.device
        )
        negative = torch.as_tensor(
            embeddings_np[n_idx], dtype=anchor_t.dtype, device=anchor_t.device
        )
    # A single row is (D,). Match the anchor's layout (e.g. (1, D)) so the
    # criterion sees inputs with the same number of dimensions.
    positive = positive.reshape(anchor_t.shape)
    negative = negative.reshape(anchor_t.shape)

    if criterion is None:
        criterion = nn.TripletMarginLoss(margin=1.0)

    loss = criterion(anchor_t, positive, negative)
    # backward() raises if nothing in the graph requires grad (e.g. a plain
    # numpy anchor with numpy embeddings), so only propagate when a graph exists.
    if loss.requires_grad:
        loss.backward()
    print("Loss: ", loss)
    return loss

In [16]:
# Scratch check of the broadcasting used in triplet_loss: subtracting a (3,)
# vector from a (2, 3) matrix subtracts it from every row, and the norm over
# the last axis then yields one Euclidean distance per row.
a = np.array([
    [1, 2, 3],
    [2, 3, 4],
])
b = np.array([1, 2, 3])
print(np.subtract(a, b))
pairwise_dist = np.linalg.norm(np.subtract(a, b), axis=-1)
print(pairwise_dist)

[[0 0 0]
 [1 1 1]]
[0.         1.73205081]


In [25]:
# argmin on a slice returns a position relative to the slice start: 0 here,
# even though that element sits at index 2 of b. Add the slice offset back to
# recover an index into the full array.
b[2:].argmin()

0

In [26]:
# Inspect the slice that argmin ran on: the one-element tail of b.
b[slice(2, None)]

array([3])