In [1]:
import tensorflow as tf
from tensorflow import keras

# Triplet Loss

For an image $x$, its encoding is denoted as $f(x)$, where $f$ is the function computed by the neural network.

<div style="text-align: center;">
    <img src="images/f_x.png" style="width:400px;height:200px;">
</div>

### Triplet Training:
- Training will use **triplets of images** $(A, P, N)$, where:
    - **A** is the "Anchor" image — a picture of a person.
    - **P** is the "Positive" image — a picture of the same person as the Anchor.
    - **N** is the "Negative" image — a picture of a different person than the Anchor.

- These triplets are selected from the training dataset. Let $(A^{(i)}, P^{(i)}, N^{(i)})$ denote the $i$-th training example.

- You aim to ensure that an image $A^{(i)}$ (Anchor) is **closer** to the Positive $P^{(i)}$ than to the Negative $N^{(i)}$, by at least a margin $\alpha$:

$$
\lVert f\left(A^{(i)}\right) - f\left(P^{(i)}\right) \rVert_2^2 + \alpha < \lVert f\left(A^{(i)}\right) - f\left(N^{(i)}\right) \rVert_2^2
$$
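
As a small worked check of the margin condition above, the sketch below tests whether a single triplet satisfies it. It uses plain Python lists rather than tensors to stay dependency-free; the helper names `squared_l2` and `satisfies_margin` and the toy 2-D encodings are made up for illustration and are not part of the notebook.

```python
def squared_l2(u, v):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((a - b) ** 2 for a, b in zip(u, v))


def satisfies_margin(f_a, f_p, f_n, alpha=0.2):
    """True if ||f(A) - f(P)||^2 + alpha < ||f(A) - f(N)||^2."""
    return squared_l2(f_a, f_p) + alpha < squared_l2(f_a, f_n)


# Toy 2-D encodings (hypothetical values, chosen so the arithmetic is easy).
f_a = [0.0, 1.0]        # anchor
f_p = [0.0, 0.5]        # positive: squared distance to anchor is 0.25
f_n_far = [1.0, 1.0]    # negative: squared distance to anchor is 1.0
f_n_near = [0.5, 1.0]   # negative: squared distance to anchor is 0.25

print(satisfies_margin(f_a, f_p, f_n_far))   # 0.25 + 0.2 < 1.0  -> True
print(satisfies_margin(f_a, f_p, f_n_near))  # 0.25 + 0.2 < 0.25 -> False
```

The second triplet fails even though the negative is no closer than the positive: the margin $\alpha$ requires the negative to be farther away by a clear gap, not merely tied.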

### Triplet Loss:
- We define the following **triplet loss** function:

$$
\mathcal{J} = \sum^{m}_{i=1} \left[ \underbrace{\lVert f(A^{(i)}) - f(P^{(i)}) \rVert_2^2}_\text{(1)} - \underbrace{\lVert f(A^{(i)}) - f(N^{(i)}) \rVert_2^2}_\text{(2)} + \alpha \right]_+
\tag{3}
$$

> **Note**: The notation "$[z]_+$" denotes $\max(z, 0)$, so a triplet that already satisfies the margin contributes zero to the loss.
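
To make formula (3) concrete, here is a minimal pure-Python sketch that evaluates it on two toy triplets. The function name `triplet_loss_plain` and the encodings are hypothetical and only illustrate the arithmetic; they are not the TensorFlow implementation used in this notebook.

```python
def squared_l2(u, v):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((a - b) ** 2 for a, b in zip(u, v))


def triplet_loss_plain(triplets, alpha=0.2):
    """Sum of [ ||f(A)-f(P)||^2 - ||f(A)-f(N)||^2 + alpha ]_+ over all triplets."""
    total = 0.0
    for f_a, f_p, f_n in triplets:
        z = squared_l2(f_a, f_p) - squared_l2(f_a, f_n) + alpha
        total += max(z, 0.0)  # [z]_+ : only margin violations are penalized
    return total


# Each entry is (f(A), f(P), f(N)) with made-up 2-D encodings.
triplets = [
    ([0.0, 1.0], [0.0, 0.5], [1.0, 1.0]),  # z = 0.25 - 1.0  + 0.2 = -0.55 -> 0
    ([0.0, 1.0], [0.0, 0.5], [0.5, 1.0]),  # z = 0.25 - 0.25 + 0.2 =  0.2  -> 0.2
]

print(triplet_loss_plain(triplets))  # 0.2
```

Only the second triplet violates the margin, so it alone determines the total.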

### Key Points:
- Term (1): Squared distance between the **anchor** (A) and the **positive** (P) for a given triplet; you want this to be small.
- Term (2): Squared distance between the **anchor** (A) and the **negative** (N) for a given triplet; you want this to be large.
- **$\alpha$**: The margin, a manually chosen hyperparameter. Here, $\alpha = 0.2$.

> **Note**: In most implementations, the encoding vectors are rescaled to have an L2 norm equal to 1 (i.e., $\lVert f(\text{img}) \rVert_2 = 1$). You won’t need to handle this here.
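
For reference, the rescaling mentioned in the note can be sketched in plain Python as follows. The helper `l2_normalize` and its `eps` guard are illustrative assumptions, not code from this notebook; in TensorFlow the same operation is available as `tf.math.l2_normalize`.

```python
import math


def squared_l2(u, v):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((a - b) ** 2 for a, b in zip(u, v))


def l2_normalize(v, eps=1e-12):
    """Rescale v to unit L2 norm; eps avoids division by zero for a zero vector."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / max(norm, eps) for x in v]


u = l2_normalize([3.0, 4.0])    # norm is 5, so the result is [0.6, 0.8]
w = l2_normalize([-3.0, -4.0])  # points in the opposite direction

print(u)                 # [0.6, 0.8]
print(squared_l2(u, w))  # close to 4, the largest value possible for unit vectors
```

Once encodings have unit norm, every squared distance lies between 0 and 4, which gives a fixed scale against which a margin such as $\alpha = 0.2$ can be read.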


In [2]:
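def triplet_loss(Y_true, Y_pred, alpha=0.2):
    """Triplet loss summed over the batch, as in formula (3).

    Y_true is not used in the computation; the argument presumably exists to
    match the Keras loss signature (y_true, y_pred).
    Y_pred holds the anchor, positive and negative encodings, in that order,
    each of shape (None, n_features).
    """
    anchor, positive, negative = Y_pred[0], Y_pred[1], Y_pred[2]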

    pos_dist = tf.subtract(anchor, positive)  # (None, n_features)
    pos_dist = tf.square(pos_dist)  # (None, n_features)
    pos_dist = tf.reduce_sum(pos_dist, axis=-1)  # (None, )

    neg_dist = tf.subtract(anchor, negative)  # (None, n_features)
    neg_dist = tf.square(neg_dist)  # (None, n_features)
    neg_dist = tf.reduce_sum(neg_dist, axis=-1)  #  (None, )

    loss = tf.add(tf.subtract(pos_dist, neg_dist), alpha)  # (None, )
    loss = tf.maximum(loss, 0)  # (None, )
    loss = tf.reduce_sum(loss)
    return loss