#  Hello World of Deep Metric Learning: Siamese Contrastive loss

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/j0rd1smit/Hello_world_of_metric_learning/blob/master/Hello%20World%20of%20Deep%20Metric%20Learning%20Siamese%20Contrastive%20loss.ipynb)


In [1]:
! pip install pip install pytorch-lightning

Collecting install
  Downloading https://files.pythonhosted.org/packages/41/cf/e3e6b4d494051c07261cae8c403f0f0d0cedad43d980e5255f2c88fd5edf/install-1.3.3-py3-none-any.whl
Collecting pytorch-lightning
[?25l  Downloading https://files.pythonhosted.org/packages/ed/af/2f10c8ee22d7a05fe8c9be58ad5c55b71ab4dd895b44f0156bfd5535a708/pytorch_lightning-0.9.0-py3-none-any.whl (408kB)
[K     |████████████████████████████████| 409kB 5.8MB/s 
Collecting PyYAML>=5.1
[?25l  Downloading https://files.pythonhosted.org/packages/64/c2/b80047c7ac2478f9501676c988a5411ed5572f35d1beff9cae07d321512c/PyYAML-5.3.1.tar.gz (269kB)
[K     |████████████████████████████████| 276kB 14.2MB/s 
[?25hCollecting future>=0.17.1
[?25l  Downloading https://files.pythonhosted.org/packages/45/0b/38b06fd9b92dc2b68d58b75f900e97884c45bedd2ff83203d933cf5851c9/future-0.18.2.tar.gz (829kB)
[K     |████████████████████████████████| 829kB 14.6MB/s 
[?25hCollecting tensorboard==2.2.0
[?25l  Downloading https://files.pythonhosted

In [2]:
import pytorch_lightning as pl
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
from PIL import Image
from torchvision.datasets import MNIST
from torchvision import transforms as T
from torch.utils.data import Dataset
from sklearn.neighbors import KNeighborsClassifier
import matplotlib.pyplot as plt
import cv2
import io

### Loss function

In [3]:
class ContrastiveLoss(nn.Module):
    def __init__(self, margin, *, eps:float=1e-9):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin
        self.eps = eps

    def forward(self, output1, output2, target):
        positve_loss = target.float() * (output2 - output1).pow(2).sum(1)
        negative_loss = (1 + -1 * target).float() * F.relu(self.margin - (output2 - output1).pow(2).sum(1) + self.eps).pow(2)

        loss = positve_loss + negative_loss

        return loss.mean()

### Data

In [4]:
def get_mnist_dataset(
        *,
        train: bool,
        exclude_labels = None,
):
    exclude_labels = exclude_labels if exclude_labels is not None else []
    mean, std = 0.1307, 0.3081
    transforms = T.Compose([
                    T.ToTensor(),
                    T.Normalize((mean,), (std,))
                ])

    dataset = MNIST('../data/MNIST',
      train=train,
      download=True,
      transform=transforms,
    )

    for label in exclude_labels:
        dataset.data = dataset.data[dataset.targets != label]
        dataset.targets = dataset.targets[dataset.targets != label]

    return dataset

In [5]:
class SiameseMNIST(Dataset):  
    def __init__(self, mnist_dataset):  
        self.mnist_dataset = mnist_dataset  
  
        self.train = self.mnist_dataset.train  
        self.transform = self.mnist_dataset.transform  
  
        self.labels = self.mnist_dataset.targets  
        self.data = self.mnist_dataset.data  
        self.labels_set = set(self.labels.numpy())  
        self.label_to_indices = {label: np.where(self.labels.numpy() == label)[0] for label in self.labels_set}  
  
        if not self.train:  
	        # During validation, always pick the same tuple.
            np.random.seed(42)  
            self.val_data = [self._draw(i) for i, _ in enumerate(self.data)]  
  
  
    def __getitem__(self, index):  
        if self.train:
	         # Randomly pick during training  
            image1, image2, target, label1 = self._draw(index)  
        else:  
	        # During validation always pick the same tuple.
            image1, image2, target, label1 = self.val_data[index]  
  
        if self.transform is not None:  
            image1 = self.transform(image1)  
            image2 = self.transform(image2)  
  
        return (image1, image2), (target, label1)  
  
    def _draw(self, index):  
        image1 = self.data[index]  
        label1 = self.labels[index]  
  
        target = np.random.choice([0, 1])  
        if target == 1:  
            # Pick a random image with the same label as image1.
            siamese_index = np.random.choice(self.label_to_indices[int(label1)])  
        else:  
	        # Pick a random label that is not the same as image1.
            siamese_label = np.random.choice(list(self.labels_set - {label1}))  
            # Pick a random image with the randomly chosen label.
            siamese_index = np.random.choice(self.label_to_indices[siamese_label])  
  
        image2 = self.data[siamese_index]  
		# Load the images
        image1 = Image.fromarray(image1.numpy(), mode='L')  
        image2 = Image.fromarray(image2.numpy(), mode='L')  
  
        return image1, image2, target, label1  
  
    def __len__(self):  
        return len(self.mnist_dataset)

### Metrics

In [6]:
def knn_accuracy(embeddings, labels):
    embeddings = embeddings.detach().cpu()
    labels = labels.detach().cpu()

    return KNeighborsClassifier().fit(embeddings, labels).score(embeddings, labels)

### Visualizations

In [7]:
def create_embeddings_plot_image(embeddings, labels):
    colours = ["tab:blue", "tab:orange", "tab:green", "tab:red", "tab:purple", "tab:brown", "tab:pink", "tab:gray", "tab:olive", "tab:cyan"]

    embeddings = embeddings.detach().cpu()
    labels = labels.detach().cpu()
    for label in torch.unique(labels):
        color = colours[int(label) % len(colours)]
        idx_slice = labels == label
        plt.scatter(embeddings[idx_slice, 0], embeddings[idx_slice, 1], label=str(int(label)), c=color)

    plt.legend(loc='upper right')
    plt.grid()
    
    buf = io.BytesIO()
    plt.savefig(buf, format='jpg')
    buf.seek(0)
    image = cv2.imdecode(np.frombuffer(buf.getvalue(), np.uint8), -1)
    image = image.transpose(2, 0, 1)
    
    plt.close()

    return image

### The model

In [8]:
class ConvBackbone(nn.Module):
    def __init__(self):
        super(ConvBackbone, self).__init__()
        self.convnet = nn.Sequential(nn.Conv2d(1, 32, 5), nn.PReLU(),
                                     nn.MaxPool2d(2, stride=2),
                                     nn.Conv2d(32, 64, 5), nn.PReLU(),
                                     nn.MaxPool2d(2, stride=2))

        self.fc = nn.Sequential(nn.Linear(64 * 4 * 4, 256),
                                nn.PReLU(),
                                nn.Linear(256, 256),
                                nn.PReLU(),
                                nn.Linear(256, 2)
                                )

    def forward(self, x):
        output = self.convnet(x)
        output = output.view(output.size()[0], -1)
        output = self.fc(output)
        return output

In [9]:
class SiamaseNet(pl.LightningModule):
    def __init__(
        self, 
        batch_size,
        n_workers
    ):
        super(SiamaseNet, self).__init__()

        self.batch_size = batch_size
        self.n_workers = n_workers
        
        
        self.backbone = ConvBackbone()
        self.loss_func = ContrastiveLoss(margin=1.0)


    def forward(self, x):
        return self.backbone(x)

    def training_step(self, batch, batch_idx):
        (input1, input2), (targets, labels) = batch
        embedding1 = self(input1)
        embedding2 = self(input2)
        loss = self.loss_func(embedding1, embedding2, targets)

        log = {"train_loss": loss}
        return {"loss": loss, "log": log, "embeddings": embedding1, "labels": labels}

    def training_epoch_end(self, outputs):
        loss = torch.stack([x['loss'] for x in outputs]).mean()
        embeddings = torch.cat([x['embeddings'] for x in outputs])
        labels = torch.cat([x['labels'] for x in outputs])

        plot = create_embeddings_plot_image(embeddings, labels)
        self.logger.experiment.add_image('embedding_space/train', plot, self.current_epoch)

        accuracy = knn_accuracy(embeddings, labels)

        log = {'avg_train_loss': loss, "knn_accuracy/train": accuracy}
        return {'log': log, 'train_loss': loss}

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

    def train_dataloader(self):
        return self._create_dataloader(
            SiameseMNIST(get_mnist_dataset(
                train=True,
            )),
            shuffle=True
        )

    def _create_dataloader(self, dataset, shuffle):
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.n_workers,
            pin_memory=True
        )

    def validation_step(self, batch, batch_idx):
        (input1, input2), (targets, labels) = batch
        embedding1 = self(input1)
        embedding2 = self(input2)
        loss = self.loss_func(embedding1, embedding2, labels)

        return {"val_loss": loss, "embeddings": embedding1, "labels": labels}

    def validation_epoch_end(self, outputs):
        val_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        embeddings = torch.cat([x['embeddings'] for x in outputs])
        labels = torch.cat([x['labels'] for x in outputs])

        plot = create_embeddings_plot_image(embeddings, labels)
        self.logger.experiment.add_image('embedding_space/val', plot, self.current_epoch)

        accuracy = knn_accuracy(embeddings, labels)


        log = {'avg_val_loss': val_loss, "knn_accuracy/val": accuracy}
        return {'log': log, 'val_loss': val_loss}

    def val_dataloader(self):
        return self._create_dataloader(
            SiameseMNIST(get_mnist_dataset(train=False)),
            shuffle=False
        )

In [14]:
%load_ext tensorboard
%tensorboard --logdir lightning_logs

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


<IPython.core.display.Javascript object>

In [10]:
batch_size = 512
n_workers = 8
epochs = 20
gpus = 1

model = SiamaseNet(batch_size=batch_size, n_workers=n_workers)

trainer = pl.Trainer(
    gpus=gpus,
    max_epochs=epochs,
)

trainer.fit(model)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type            | Params
----------------------------------------------
0 | backbone  | ConvBackbone    | 380 K 
1 | loss_func | ContrastiveLoss | 0     


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ../data/MNIST/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ../data/MNIST/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/MNIST/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ../data/MNIST/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/MNIST/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/MNIST/raw/t10k-labels-idx1-ubyte.gz





HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ../data/MNIST/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/MNIST/MNIST/raw
Processing...
Done!


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)





HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Saving latest checkpoint..





1