## Triplet Loss on Totally Looks Like dataset

This notebook is inspired by [this Keras tutorial](https://keras.io/examples/vision/siamese_network/) by Hazem Essam and Santiago L. Valdarrama.

The goal is to showcase the use of siamese networks and triplet loss to do representation learning using a CNN. It will also showcase data generators and data augmentation techniques.

### Dataset

The dataset considered is the [Totally Looks Like](https://sites.google.com/view/totally-looks-like-dataset) dataset, which consists of web-curated pairs of similar-looking images:

Image pair 1               |  Image pair 2
:-------------------------:|:-------------------------:
![](https://lh3.googleusercontent.com/fzSVA9pEAjAtRicUNFywUoO4qSR6r7P7YVrO6zIVVQFmAG1ZqYF2ORNnUJlng56qwsPts6gcv5GKdMl0Lm8cYP04PGvrqJMCxaehwWM2TDWU7iRb=w1280)  |  ![](https://lh3.googleusercontent.com/1jGi3A6JP6OJKNzHm_gfPnb79WQV2HYQ7Xe2FnZMj3kjKM4VThsZfGS_IRohzOYRZ1tswWHvQKjuCnF90tP4jdATGRfZ6eN6RgPs8v4Lvf_BspEE=w1280)

The goal is to extract a generic, human-like perceptual representation with a CNN. The next cell downloads the dataset and unzips it (run it as soon as possible: it will download a few hundred megabytes).

In [None]:
import os
import os.path as op
from urllib.request import urlretrieve

# TODO: confirm this is the final dataset URL
URL = "https://github.com/m2dsupsdlclass/lectures-labs/releases/download/totallylookslike/dataset_totally.zip"
FILENAME = "dataset_totally.zip"

# Skip the download when the archive is already present locally.
if not op.exists(FILENAME):
    print('Downloading %s to %s...' % (URL, FILENAME))
    urlretrieve(URL, FILENAME)

import zipfile

# Extract only once: the presence of the "anchors" folder is used as the
# marker that the archive has already been unpacked in the current directory.
if not op.exists("anchors"):
    print('Extracting image files...')
    with zipfile.ZipFile(FILENAME, 'r') as zip_ref:
        zip_ref.extractall('.')

# The image folder paths are defined in the next cell, once `pathlib.Path`
# has been imported (using `Path` here would raise a NameError).

In [None]:
from pathlib import Path

# Image folders created by extracting the dataset archive into the current
# working directory.
data_root = Path(".")
anchor_images_path = data_root / "anchors"
positive_images_path = data_root / "positives"

In [None]:
def open_image(filename, target_shape=(256, 256)):
    """Read a JPEG file and return it as a resized float32 image tensor.

    Args:
        filename: Path of the JPEG file, passed to ``tf.io.read_file``.
        target_shape: Size handed to ``tf.image.resize``.

    Returns:
        The image decoded with 3 channels, converted to ``tf.float32`` with
        ``tf.image.convert_image_dtype`` and resized to ``target_shape``.
    """
    raw_bytes = tf.io.read_file(filename)
    decoded = tf.image.decode_jpeg(raw_bytes, channels=3)
    as_float = tf.image.convert_image_dtype(decoded, tf.float32)
    return tf.image.resize(as_float, target_shape)

In [None]:
import tensorflow as tf

# Sorted file paths of both folders. The i-th anchor is later zipped with the
# i-th positive — presumably both folders use matching file names; verify
# against the dataset layout.
anchor_images = sorted(str(path) for path in anchor_images_path.iterdir())
positive_images = sorted(str(path) for path in positive_images_path.iterdir())

anchor_count, positive_count = len(anchor_images), len(positive_images)
print(anchor_count, positive_count)

# File-name datasets (reused further down to build triplets) ...
anchor_dataset_files = tf.data.Dataset.from_tensor_slices(anchor_images)
positive_dataset_files = tf.data.Dataset.from_tensor_slices(positive_images)

# ... and their decoded-image counterparts.
anchor_dataset = anchor_dataset_files.map(open_image)
positive_dataset = positive_dataset_files.map(open_image)

In [None]:
import matplotlib.pyplot as plt 

def visualize(anchor, positive, negative=None):
    """Show an anchor/positive pair, or a full triplet, side by side.

    Args:
        anchor: Image displayed in the first slot.
        positive: Image displayed in the second slot.
        negative: Optional image displayed in a third slot; when ``None``
            only two slots are drawn.
    """
    images = [anchor, positive]
    if negative is not None:
        images.append(negative)

    fig = plt.figure(figsize=(9, 9))
    axes = fig.subplots(1, len(images))

    for ax, image in zip(axes, images):
        ax.imshow(image)
        # Hide both axes so that only the picture is visible.
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

# Fetch the first element of each image dataset and display them as a pair.
anc, pos = (next(iter(ds)) for ds in (anchor_dataset, positive_dataset))
visualize(anc, pos)

In [None]:
# Data augmentations: random horizontal flip, random rotation (factor 0.15)
# and a random 224x224 crop, which is also the input size of the CNN.
# The layers are accessed through `tf.keras.layers` because the `layers`
# alias is only imported in a later cell; using it here would raise a
# NameError when the notebook is run top to bottom.
data_augmentation = tf.keras.Sequential([
    tf.keras.layers.RandomFlip("horizontal"),
    tf.keras.layers.RandomRotation(0.15),
    tf.keras.layers.RandomCrop(224, 224),
])

In [None]:
import numpy as np 

# To generate the list of negative images, let's randomize the list of
# available images and concatenate them together.
# The shuffles below mutate the Python lists in place; the anchor/positive
# file datasets were created from these lists before this point, so the
# anchor <-> positive pairing is not affected.
rng = np.random.RandomState(seed=42)
rng.shuffle(anchor_images)
rng.shuffle(positive_images)

negative_images = anchor_images + positive_images
np.random.RandomState(seed=32).shuffle(negative_images)

negative_dataset_files = tf.data.Dataset.from_tensor_slices(negative_images)
negative_dataset_files = negative_dataset_files.shuffle(buffer_size=4096)

# Build the triplet dataset of file names.
triplet_files = tf.data.Dataset.zip(
    (anchor_dataset_files, positive_dataset_files, negative_dataset_files)
)
# Shuffle ONCE with a fixed order. With the default
# `reshuffle_each_iteration=True`, `take`/`skip` below would pick different
# elements on every pass, so training and validation triplets would overlap.
triplet_files = triplet_files.shuffle(
    buffer_size=1024, seed=42, reshuffle_each_iteration=False
)


def preprocess_triplets(anchor, positive, negative):
    """Load and augment the three images of a triplet.

    Args:
        anchor, positive, negative: File names of the three images.

    Returns:
        A tuple of three images, each loaded with ``open_image`` and passed
        through ``data_augmentation``.
    """
    return (
        data_augmentation(open_image(anchor)),
        data_augmentation(open_image(positive)),
        data_augmentation(open_image(negative)),
    )


# Full (unsplit) dataset of preprocessed triplets.
dataset = triplet_files.map(preprocess_triplets)

# Let's now split our dataset in train and validation. The split is done on
# file names so that the two subsets are disjoint.
train_count = round(anchor_count * 0.8)
train_files = triplet_files.take(train_count)
val_files = triplet_files.skip(train_count)

# Only the training triplets are reshuffled at every epoch; shuffling file
# names (rather than decoded images) keeps the shuffle buffer small.
train_dataset = train_files.shuffle(buffer_size=1024).map(preprocess_triplets)
train_dataset = train_dataset.batch(32, drop_remainder=False)
train_dataset = train_dataset.prefetch(8)

val_dataset = val_files.map(preprocess_triplets)
val_dataset = val_dataset.batch(32, drop_remainder=False)
val_dataset = val_dataset.prefetch(8)

In [None]:
# Pull a single batch out of the training pipeline as NumPy arrays and print
# the shape of each of its three components.
first_batch = train_dataset.take(1).as_numpy_iterator()
anc_batch, pos_batch, neg_batch = next(first_batch)
print(anc_batch.shape, pos_batch.shape, neg_batch.shape)

In [None]:
# Display a random triplet of the batch. The upper bound is taken from the
# batch itself: batches are built with `drop_remainder=False`, so a batch
# may contain fewer than 32 elements.
idx = np.random.randint(0, len(anc_batch))
visualize(anc_batch[idx], pos_batch[idx], neg_batch[idx])

In [None]:
from tensorflow.keras.applications import resnet
from tensorflow.keras import applications
from tensorflow.keras import layers
from tensorflow.keras import losses
from tensorflow.keras import optimizers
from tensorflow.keras import metrics
from tensorflow.keras import Model
from tensorflow.keras.applications.resnet50 import preprocess_input


# ResNet50 backbone with ImageNet weights and without its classification top.
base_cnn = resnet.ResNet50(weights="imagenet", input_shape=(224, 224, 3), include_top=False)
input_img = layers.Input((224, 224, 3))
# Feed the model input directly to the backbone: `resnet.preprocess_input`
# is applied by the callers before the embedding model is invoked.
# (The original code referenced an undefined `input_preprocessed` here.)
resnet_output = base_cnn(input_img)

# Embedding head: two Dense + BatchNormalization stages followed by a linear
# projection to a 256-dimensional embedding.
flatten = layers.Flatten()(resnet_output)
dense1 = layers.Dense(512, activation="relu")(flatten)
dense1 = layers.BatchNormalization()(dense1)
dense2 = layers.Dense(256, activation="relu")(dense1)
dense2 = layers.BatchNormalization()(dense2)
output = layers.Dense(256)(dense2)

embedding = Model(input_img, output, name="Embedding")

# Freeze every backbone layer located before "conv5_block1_out"; that layer
# and all the following ones stay trainable.
trainable = False
for layer in base_cnn.layers:
    if layer.name == "conv5_block1_out":
        trainable = True
    layer.trainable = trainable

In [None]:
class DistanceLayer(layers.Layer):
    """Layer computing the two squared distances used by the triplet loss.

    Given the embeddings of an anchor, a positive and a negative example, it
    returns the squared Euclidean distance (summed over the last axis)
    between anchor and positive, and between anchor and negative.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, anchor, positive, negative):
        """Return the ``(anchor-positive, anchor-negative)`` squared distances."""

        def squared_distance(left, right):
            return tf.reduce_sum(tf.square(left - right), -1)

        return (
            squared_distance(anchor, positive),
            squared_distance(anchor, negative),
        )


# One 224x224 RGB input per role of the triplet.
triplet_inputs = [
    layers.Input(name=role, shape=(224, 224, 3))
    for role in ("anchor", "positive", "negative")
]
anchor_input, positive_input, negative_input = triplet_inputs

# The same embedding model (shared weights) is applied to the three inputs,
# after the ResNet-specific preprocessing.
# NOTE(review): `resnet.preprocess_input` is documented for pixel values in
# [0, 255], whereas `open_image` converts images with
# `tf.image.convert_image_dtype(..., tf.float32)` — confirm the value range
# fed to the network is the intended one.
anchor_emb, positive_emb, negative_emb = (
    embedding(resnet.preprocess_input(image_input))
    for image_input in triplet_inputs
)

distances = DistanceLayer()(anchor_emb, positive_emb, negative_emb)

siamese_network = Model(inputs=triplet_inputs, outputs=distances)

In [None]:
class SiameseModel(Model):
    """Keras model wrapping the siamese network with custom train/test steps.

    The wrapped network outputs the anchor-positive and anchor-negative
    distances; this model turns them into the triplet loss:

       L(A, P, N) = max(‖f(A) - f(P)‖² - ‖f(A) - f(N)‖² + margin, 0)

    Args:
        siamese_network: Model returning the ``(ap_distance, an_distance)``
            pair for a triplet of inputs.
        margin: Margin added to the distance difference before clipping at 0.
    """

    def __init__(self, siamese_network, margin=0.5):
        super().__init__()
        self.siamese_network = siamese_network
        self.margin = margin
        # Running mean of the loss, reported by both train and test steps.
        self.loss_tracker = metrics.Mean(name="loss")

    def call(self, inputs):
        """Forward the inputs to the wrapped siamese network."""
        return self.siamese_network(inputs)

    def train_step(self, data):
        """Run one optimization step on a batch of triplets."""
        # Record the forward pass so the loss can be differentiated.
        with tf.GradientTape() as tape:
            triplet_loss = self._compute_loss(data)

        # Gradients of the loss with respect to the trainable weights of the
        # siamese network, applied with the optimizer given to `compile()`.
        weights = self.siamese_network.trainable_weights
        grads = tape.gradient(triplet_loss, weights)
        self.optimizer.apply_gradients(zip(grads, weights))

        return self._update_loss_metric(triplet_loss)

    def test_step(self, data):
        """Evaluate the loss on a batch of triplets without updating weights."""
        return self._update_loss_metric(self._compute_loss(data))

    def _update_loss_metric(self, loss):
        """Feed ``loss`` to the tracker and return the logs dictionary."""
        self.loss_tracker.update_state(loss)
        return {"loss": self.loss_tracker.result()}

    def _compute_loss(self, data):
        """Return the triplet loss computed from the network's two distances."""
        ap_distance, an_distance = self.siamese_network(data)
        # Distance difference plus margin, clipped so it is never negative.
        return tf.maximum(ap_distance - an_distance + self.margin, 0.0)

    @property
    def metrics(self):
        # Listing the tracker here lets Keras reset it automatically.
        return [self.loss_tracker]


In [None]:
# Wrap the siamese network in the custom training model, compile it with an
# Adam optimizer and train for two epochs with validation.
siamese_model = SiameseModel(siamese_network)
adam = optimizers.Adam(0.0001)
siamese_model.compile(optimizer=adam)
siamese_model.fit(train_dataset, epochs=2, validation_data=val_dataset)

## Find most similar images in test dataset

In [None]:
# NOTE(review): `shared_conv` and `all_imgs` are not defined anywhere in the
# visible notebook — this cell looks like a leftover from another version
# and would raise a NameError as is; confirm whether it should be removed.
emb = shared_conv.predict(all_imgs)
# Divide each embedding by its norm along the last axis (L2 normalization).
emb = emb / np.linalg.norm(emb, axis=-1, keepdims=True)
# Flatten each image to a vector of 60*60*3 values
# (assumes 60x60 images with 3 channels — TODO confirm).
pixelwise = np.reshape(all_imgs, (all_imgs.shape[0], 60*60*3))

In [None]:
# Compile the embedding sub-model with default settings (no optimizer or
# loss is passed); it is only used for inference below.
embedding.compile()

In [None]:
from functools import partial

# Load the images directly at the input size of the network.
open_img = partial(open_image, target_shape=(224,224))

# Build the pipeline from the plain `negative_images` list rather than from
# `negative_dataset_files`: the latter is shuffled, so the rows of `out`
# could not be matched back to file names. Here row i of `out` is the
# embedding of `negative_images[i]`.
d = (
    tf.data.Dataset.from_tensor_slices(negative_images)
    .map(open_img)
    .map(resnet.preprocess_input)
    .batch(32, drop_remainder=False)
    .prefetch(8)
)
out = embedding.predict(d)


In [None]:
# Shape of the array returned by `embedding.predict` above.
out.shape