In [None]:
# 1. Imports


import tensorflow as tf
from tensorflow.keras import layers
import numpy as np


In [None]:
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np

def build_embedding_model(input_shape, embedding_dim=128):
    """Build a small CNN that maps an image to a unit-length embedding vector.

    Architecture: two (Conv2D 3x3 + ReLU, MaxPooling2D) stages, Flatten,
    a linear Dense projection, then L2 normalization along the feature axis.

    Args:
        input_shape: Shape of a single input image, excluding the batch axis
            (this notebook uses ``(28, 28, 1)``).
        embedding_dim: Size of the output embedding. Defaults to 128.

    Returns:
        A ``tf.keras.Model`` producing ``(batch, embedding_dim)`` outputs,
        each row scaled to unit L2 norm.
    """
    inputs = tf.keras.Input(shape=input_shape)

    x = layers.Conv2D(32, 3, activation="relu")(inputs)
    x = layers.MaxPooling2D()(x)

    x = layers.Conv2D(64, 3, activation="relu")(x)
    x = layers.MaxPooling2D()(x)

    x = layers.Flatten()(x)
    x = layers.Dense(embedding_dim)(x)

    # Project every embedding onto the unit hypersphere. The built-in
    # UnitNormalization layer is serializable, unlike a Lambda wrapping
    # tf.nn.l2_normalize (which Keras 3 refuses to reload without
    # safe_mode=False).
    outputs = layers.UnitNormalization(axis=1)(x)

    model = tf.keras.Model(inputs, outputs)
    return model

In [None]:
# 3. Triplet Loss (Hard Mining)


def triplet_loss(labels, embeddings, margin=0.2):
    """Batch-hard triplet loss on squared Euclidean distances.

    For every anchor in the batch, take the farthest sample with the same
    label (hardest positive) and the closest sample with a different label
    (hardest negative), and penalise
    ``max(d(anchor, pos) - d(anchor, neg) + margin, 0)``.

    Args:
        labels: Class label per embedding; reshaped internally to (batch, 1).
        embeddings: 2-D tensor of shape (batch, dim).
        margin: Required gap between the positive and negative distances.

    Returns:
        Scalar tensor: the loss averaged over all anchors in the batch.
    """
    # Pairwise squared distances via ||a||^2 - 2 a.b + ||b||^2.
    dot_product = tf.matmul(embeddings, embeddings, transpose_b=True)
    square_norm = tf.linalg.diag_part(dot_product)

    distances = (
        tf.expand_dims(square_norm, 1)
        - 2.0 * dot_product
        + tf.expand_dims(square_norm, 0)
    )
    # Rounding error can leave tiny negative values; clamp them to 0.
    distances = tf.maximum(distances, 0.0)

    labels = tf.reshape(labels, (-1, 1))
    same_label = tf.equal(labels, tf.transpose(labels))
    not_self = tf.logical_not(tf.eye(tf.shape(labels)[0], dtype=tf.bool))

    # Positives: same label, excluding the anchor itself.
    positive_mask = tf.logical_and(same_label, not_self)

    # Negatives: different label. This must come from `same_label`, NOT from
    # negating `positive_mask`: that would put the diagonal (self-distance 0)
    # back into the negative set, forcing every hardest negative to 0 and
    # making "collapse all embeddings to one point" the optimum.
    negative_mask = tf.logical_not(same_label)

    # Hardest positive: largest same-label distance (0 if the anchor has no
    # positive in this batch).
    hardest_positive = tf.reduce_max(
        tf.where(positive_mask, distances, tf.zeros_like(distances)),
        axis=1
    )

    # Hardest negative: smallest different-label distance. Non-negative slots
    # are filled with the batch maximum so they never win the min.
    max_dist = tf.reduce_max(distances)
    hardest_negative = tf.reduce_min(
        tf.where(negative_mask, distances, max_dist),
        axis=1
    )

    loss = tf.maximum(hardest_positive - hardest_negative + margin, 0.0)

    return tf.reduce_mean(loss)


In [None]:
# 4. Custom Training Model


class TripletModel(tf.keras.Model):
    """Trains a wrapped embedding model with batch-hard ``triplet_loss``.

    ``fit``/``evaluate`` expect batches of ``(images, labels)``. Gradients are
    applied only to ``embedding_model``'s trainable variables. Calling the
    wrapper (or ``predict``) returns the embedding model's outputs.
    """

    def __init__(self, embedding_model, margin=0.2, **kwargs):
        """
        Args:
            embedding_model: Model mapping a batch of images to embeddings.
            margin: Margin passed to ``triplet_loss``. Defaults to 0.2.
            **kwargs: Forwarded to ``tf.keras.Model``.
        """
        super().__init__(**kwargs)
        self.embedding_model = embedding_model
        self.margin = margin
        # Running mean over the epoch; without it Keras reports the last
        # batch's raw loss as the epoch value.
        self.loss_tracker = tf.keras.metrics.Mean(name="loss")

    @property
    def metrics(self):
        # Listed here so Keras resets the tracker at each epoch / evaluate.
        return [self.loss_tracker]

    def call(self, inputs, training=False):
        """Return embeddings for ``inputs`` (enables ``model(x)``/``predict``)."""
        return self.embedding_model(inputs, training=training)

    def train_step(self, data):
        images, labels = data

        with tf.GradientTape() as tape:
            embeddings = self.embedding_model(images, training=True)
            loss = triplet_loss(labels, embeddings, margin=self.margin)

        trainable_vars = self.embedding_model.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        self.loss_tracker.update_state(loss)
        return {"loss": self.loss_tracker.result()}

    def test_step(self, data):
        images, labels = data

        embeddings = self.embedding_model(images, training=False)
        loss = triplet_loss(labels, embeddings, margin=self.margin)

        self.loss_tracker.update_state(loss)
        return {"loss": self.loss_tracker.result()}


# Images → Embedding Model → Triplet Loss → Gradient → Update Embedding Model


In [None]:
# 5. Load MNIST


(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Scale pixels by 1/255 to float32 and append a trailing channel axis so each
# image has shape (H, W, 1), matching the (28, 28, 1) model input used below.
x_train = np.expand_dims(x_train.astype("float32") / 255.0, axis=-1)
x_test = np.expand_dims(x_test.astype("float32") / 255.0, axis=-1)

# int32 class labels, compared pairwise inside triplet_loss.
y_train, y_test = y_train.astype("int32"), y_test.astype("int32")


Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
[1m11490434/11490434[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step


In [None]:
# 6. Create Dataset


# Uniform shuffle (buffer covers the whole training set), seeded for
# reproducibility; prefetch overlaps input preparation with training.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(buffer_size=len(x_train), seed=42)
    .batch(64)
    .prefetch(tf.data.AUTOTUNE)
)


In [None]:
# 7. Train Model


# Seed Python, NumPy and TensorFlow RNGs so weight initialisation (and any
# other unseeded randomness) is reproducible across Restart & Run All.
tf.keras.utils.set_random_seed(42)

embedding_model = build_embedding_model((28, 28, 1))

model = TripletModel(embedding_model)

model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001))

# Keep the History so loss curves can be inspected later.
history = model.fit(train_dataset, epochs=20)


Epoch 1/20
[1m938/938[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 5ms/step - loss: 0.2009
Epoch 2/20
[1m938/938[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3ms/step - loss: 0.1998
Epoch 3/20
[1m938/938[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 3ms/step - loss: 0.1998
Epoch 4/20
[1m938/938[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3ms/step - loss: 0.1998
Epoch 5/20
[1m938/938[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3ms/step - loss: 0.1998
Epoch 6/20
[1m938/938[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3ms/step - loss: 0.1998
Epoch 7/20
[1m938/938[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3ms/step - loss: 0.1998
Epoch 8/20
[1m938/938[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3ms/step - loss: 0.1998
Epoch 9/20
[1m938/938[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3ms/step - loss: 0.1998
Epoch 10/20
[1m938/938[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3ms/step - lo

<keras.src.callbacks.history.History at 0x7c96296ce750>

In [None]:
# 8. Test Similarity


# Compare two test images in embedding space. Their true labels are printed
# too, so the SIMILAR/DIFFERENT verdict can be checked against ground truth
# (a collapsed embedding would call every pair SIMILAR).
img1, label1 = x_test[0:1], y_test[0]
img2, label2 = x_test[1:2], y_test[1]

emb1 = embedding_model(img1, training=False)
emb2 = embedding_model(img2, training=False)

# Euclidean distance between unit-norm embeddings, so it lies in [0, 2].
distance = tf.norm(emb1 - emb2)

print("Distance:", distance.numpy())

# NOTE(review): hand-picked value; tune it on held-out same/different pairs.
threshold = 0.8

if distance < threshold:
    print("Images are SIMILAR")
else:
    print("Images are DIFFERENT")

print(
    f"Ground truth: labels {label1} vs {label2} -> "
    f"{'same' if label1 == label2 else 'different'} digit"
)


Distance: 0.00069402286
Images are SIMILAR
