# Metric learning with triplet loss

We again use a convolutional neural network with the fully-connected classification layer removed. As such, the model returns an embedding for each image. This time, we will **train the network to produce better embeddings**. More specifically, the network will learn to produce:

* embeddings that are *close* to each other for images that belong to the *same identity* and
* embeddings that are *far* from each other for images that belong to *different identities*.

## The ingredients of a training script

* Dataset(s) and data loader(s)
* Training loop (each epoch consists of a training phase followed by a validation phase)

During the training phase, we pass training data through the model and optimize the model's parameters to better fit the data. During the validation phase, we pass the validation data through the model and check how well the trained model is working.

Data is always passed through the model in *batches*, i.e. multiple images at once.

## Create datasets

In [1]:
from torchvision.datasets import ImageFolder

from lib.cnn_classifiers import train_transform, val_transform


# Training images, loaded with the training transform.
train_ds = ImageFolder('data/GroZiVitroBasic/', transform=train_transform)

# Validation is a retrieval task: query images are matched against a gallery.
# Both validation sets use the validation transform.
val_gallery_ds = ImageFolder('data/sodas/gallery', transform=val_transform)
val_query_ds = ImageFolder('data/sodas/queries', transform=val_transform)

## Create data loaders

In [2]:
# Images per batch and number of DataLoader worker processes.
# NOTE(review): the training batch is split into anchor/positive/negative
# parts, so batch_size presumably needs to be divisible by 3 — confirm
# against TripletSampler.
batch_size, num_workers = 60, 8

In [3]:
from lib.triplet_sampler import TripletSampler
from torch.utils.data import DataLoader

# Training batches are assembled by TripletSampler, which is passed as the
# batch sampler (so no batch_size/shuffle arguments here).
train_loader = DataLoader(
    train_ds,
    batch_sampler=TripletSampler(train_ds, batch_size=batch_size),
    num_workers=num_workers,
)

# Both validation loaders iterate their dataset in a fixed order.
val_gallery_loader, val_query_loader = (
    DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    for ds in (val_gallery_ds, val_query_ds)
)

## Define what should happen in a single training step

In [4]:
from torch import nn
import torch.nn.functional as F

from lib.triplet_sampler import split_triplet_tensor


def run_train_step(model, batch, batch_idx, margin=1.0):
    """
    Compute the triplet margin loss for one training batch.

    Args:
        model: network mapping a batch of images to embeddings.
        batch: ``(imgs, labels)`` pair. Both are split into anchor,
            positive and negative parts with ``split_triplet_tensor``
            (presumably matching the layout produced by TripletSampler).
        batch_idx: index of the batch within the epoch; unused here, kept
            so all training-step functions share one signature.
        margin: margin of the triplet loss (default 1.0).

    Returns:
        Scalar loss tensor from ``F.triplet_margin_loss`` (mean reduction).
    """
    imgs, labels = batch

    # Sanity-check the sampler: anchors and positives share an identity,
    # anchors and negatives never do.
    a_labels, p_labels, n_labels = split_triplet_tensor(labels)
    assert torch.all(a_labels == p_labels)
    assert torch.all(a_labels != n_labels)

    embeddings = model_forward(model, imgs)
    a_embs, p_embs, n_embs = split_triplet_tensor(embeddings)

    return F.triplet_margin_loss(a_embs, p_embs, n_embs, margin=margin)


def model_forward(model, imgs):
    """
    Pass the images through the model and return L2-normalized embeddings.

    The model output is flattened to ``(batch, features)``. This removes
    trailing singleton dimensions (for example a pooled ``(B, D, 1, 1)``
    output) while always keeping the batch dimension. ``torch.squeeze``
    would also drop the batch dimension for a single-image batch, such as
    a final partial batch, and then ``dim=1`` would be invalid.

    Args:
        model: network returning one feature tensor per image.
        imgs: batch of images whose first dimension is the batch.

    Returns:
        Tensor of shape ``(batch, features)`` with unit-norm rows. All-zero
        rows stay zero instead of becoming NaN.
    """
    embeddings = model(imgs)
    embeddings = torch.flatten(embeddings, start_dim=1)
    # F.normalize clamps the norm with a small eps, avoiding 0/0.
    return F.normalize(embeddings, p=2, dim=1)

## Call the training step on each batch in a training epoch

In [5]:
def run_train_epoch(model, train_loader, optimizer,
                    epoch_idx, writer=None, device='cpu'):
    """
    Run a training epoch and log its mean training loss.

    Args:
        model: network being trained; put into train mode here.
        train_loader: DataLoader yielding ``(imgs, labels)`` triplet batches.
        optimizer: optimizer stepping the trainable parameters.
        epoch_idx: index of the current epoch, used as the logging step.
        writer: optional TensorBoard ``SummaryWriter``; nothing is logged
            when it is ``None``.
        device: device the batches are moved to.

    Returns:
        Mean training loss over the epoch's batches, or ``None`` if the
        loader yielded no batches.
    """
    # Put model in train mode
    model.train()

    total_loss = 0.0
    num_batches = 0

    for batch_idx, train_batch in tqdm(enumerate(train_loader),
                                       total=len(train_loader),
                                       leave=False, desc='Train batch'):
        train_batch = batch_to_device(train_batch, device)
        loss = run_train_step(model, train_batch, batch_idx)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() gives a plain float, so no autograd graph is kept around.
        total_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        return None

    mean_loss = total_loss / num_batches

    # Log one value per epoch. Writing every batch at the same step
    # (epoch_idx) would stack many points on one x-position. Using the
    # epoch index also matches the per-epoch validation logging.
    if writer is not None:
        writer.add_scalar("Loss/train", mean_loss, epoch_idx)

    return mean_loss


def batch_to_device(batch, device):
    """
    Move the tensors of a batch to ``device``.

    Args:
        batch: sequence of batch elements, e.g. ``(imgs, labels)`` as yielded
            by a DataLoader. Lists and tuples are both accepted.
        device: target device (``str`` or ``torch.device``).

    Returns:
        A new list in which every tensor element has been moved to
        ``device``. Non-tensor elements are passed through unchanged, and
        the input ``batch`` is not modified.
    """
    return [x.to(device) if isinstance(x, torch.Tensor) else x for x in batch]

## Define what should happen during a *validation epoch*

In [6]:
def run_val_epoch(model, val_gallery_loader, val_query_loader,
                  epoch_idx, writer, device='cpu'):
    """
    Evaluate retrieval quality and log per-class average precision.

    The current model embeds the gallery and query images, and the
    embeddings are compared with ``match_embeddings``. For each distinct
    class label in the gallery, ``calc_ap`` is written to ``writer`` under
    ``AP_val/<class name>`` at step ``epoch_idx``.

    Args:
        model: embedding network; put into eval mode here.
        val_gallery_loader: DataLoader over the gallery images. Its dataset's
            ``class_to_idx`` supplies the class names used for logging.
        val_query_loader: DataLoader over the query images.
        epoch_idx: index of the current epoch, used as the logging step.
        writer: TensorBoard ``SummaryWriter``.
        device: device the batches are moved to.
    """
    import warnings

    # Put model in eval mode
    model.eval()

    g_embeddings, g_labels = compute_embs_from_dataloader(
        model, val_gallery_loader, device)
    q_embeddings, q_labels = compute_embs_from_dataloader(
        model, val_query_loader, device)

    # Similarity matrix between gallery and query embeddings.
    sim_mat = match_embeddings(g_embeddings, q_embeddings).numpy()

    class_to_idx = val_gallery_loader.dataset.class_to_idx
    idx_to_class = {idx: label for label, idx in class_to_idx.items()}

    # The gallery and query datasets assign label indices independently.
    # Their labels refer to the same classes only if both use the same
    # mapping.
    q_class_to_idx = getattr(val_query_loader.dataset, 'class_to_idx', None)
    if q_class_to_idx is not None and q_class_to_idx != class_to_idx:
        warnings.warn(
            "Gallery and query datasets use different class_to_idx "
            "mappings; per-class AP values may be mislabeled.")

    # g_labels has one entry per gallery image; compute and log AP only
    # once per class.
    for label in sorted(set(g_labels)):
        ap = calc_ap(label, sim_mat, g_labels, q_labels)
        writer.add_scalar(f"AP_val/{idx_to_class[label]}", ap, epoch_idx)


def compute_embs_from_dataloader(model, data_loader, device='cpu'):
    """
    Embed every batch produced by ``data_loader``.

    Args:
        model: embedding network (the caller decides train/eval mode).
        data_loader: DataLoader yielding ``(imgs, labels)`` batches.
        device: device the batches are moved to.

    Returns:
        ``(embeddings, labels)``: the embeddings concatenated into a CPU
        tensor and the labels concatenated into a NumPy array, both in
        loader order.
    """
    emb_chunks = []
    label_chunks = []

    progress = tqdm(enumerate(data_loader), total=len(data_loader), leave=False)
    for _, batch in progress:
        imgs, labels = batch_to_device(batch, device)
        # Inference only: no autograd graph is needed for the embeddings.
        with torch.no_grad():
            emb_chunks.append(model_forward(model, imgs))
        label_chunks.append(labels)

    embeddings = torch.cat(emb_chunks).cpu()
    labels_np = torch.cat(label_chunks).cpu().numpy()
    return embeddings, labels_np

## Put everything together in a training loop

In [7]:
from tqdm.notebook import tqdm
from torch.utils.tensorboard import SummaryWriter
import torch

from lib.metric_learning import match_embeddings
from lib.evaluation_metrics import calc_ap


def run_training(model, optimizer, train_loader, val_gallery_loader,
                 val_query_loader, num_epochs=10):
    """
    Train ``model`` for ``num_epochs`` epochs, validating after each one.

    Runs on CUDA when it is available and on the CPU otherwise. Training loss
    and validation AP go to a new TensorBoard ``SummaryWriter``, which writes
    under ``runs/`` by default.

    Args:
        model: embedding network to train.
        optimizer: optimizer over the model's trainable parameters.
        train_loader: DataLoader of triplet training batches.
        val_gallery_loader: DataLoader over the validation gallery.
        val_query_loader: DataLoader over the validation queries.
        num_epochs: number of training/validation epochs.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # NOTE(review): torch.optim docs recommend moving the model before
    # constructing the optimizer; here the caller builds it first.
    # Module.to() modifies the module in place, but confirm this if you
    # switch to an optimizer with pre-allocated state.
    model.to(device)

    writer = SummaryWriter()
    try:
        for epoch_idx in tqdm(range(num_epochs), desc='Epoch'):
            run_train_epoch(model, train_loader, optimizer,
                            epoch_idx, writer, device)
            run_val_epoch(model, val_gallery_loader, val_query_loader,
                          epoch_idx, writer, device)
    finally:
        # Flush and close even if training raises or is interrupted
        # (e.g. KeyboardInterrupt in a notebook), so that events already
        # logged are written out.
        writer.flush()
        writer.close()

## Start TensorBoard for logging

In [8]:
import tensorboard

%load_ext tensorboard
%tensorboard --logdir runs

## Define the model

In [9]:
from lib.metric_learning import get_cut_off_cnn

model = get_cut_off_cnn('resnet50')

# Fine-tune only `layer4`: freeze every parameter first, then re-enable
# gradients for the parameters of `layer4`.
model.requires_grad_(False)
model.layer4.requires_grad_(True)

## Run the training loop

In [10]:
from torch.optim import SGD

# Optimize only the parameters of `layer4`.
optimizer = SGD(params=model.layer4.parameters(), lr=0.01)

run_training(
    model,
    optimizer,
    train_loader,
    val_gallery_loader,
    val_query_loader,
    num_epochs=50,
)

Epoch:   0%|          | 0/50 [00:00<?, ?it/s]

Train batch:   0%|          | 0/33 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Train batch:   0%|          | 0/33 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Train batch:   0%|          | 0/33 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Train batch:   0%|          | 0/33 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Train batch:   0%|          | 0/33 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Train batch:   0%|          | 0/33 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Train batch:   0%|          | 0/33 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Train batch:   0%|          | 0/33 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Train batch:   0%|          | 0/33 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Train batch:   0%|          | 0/33 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Train batch:   0%|          | 0/33 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Train batch:   0%|          | 0/33 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Train batch:   0%|          | 0/33 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Train batch:   0%|          | 0/33 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Train batch:   0%|          | 0/33 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Train batch:   0%|          | 0/33 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Train batch:   0%|          | 0/33 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Train batch:   0%|          | 0/33 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Train batch:   0%|          | 0/33 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Train batch:   0%|          | 0/33 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Train batch:   0%|          | 0/33 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Train batch:   0%|          | 0/33 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Train batch:   0%|          | 0/33 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Train batch:   0%|          | 0/33 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Train batch:   0%|          | 0/33 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Train batch:   0%|          | 0/33 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Train batch:   0%|          | 0/33 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Train batch:   0%|          | 0/33 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Train batch:   0%|          | 0/33 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Train batch:   0%|          | 0/33 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Train batch:   0%|          | 0/33 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Train batch:   0%|          | 0/33 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Train batch:   0%|          | 0/33 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Train batch:   0%|          | 0/33 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Train batch:   0%|          | 0/33 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Train batch:   0%|          | 0/33 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Train batch:   0%|          | 0/33 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Train batch:   0%|          | 0/33 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Train batch:   0%|          | 0/33 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Train batch:   0%|          | 0/33 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Train batch:   0%|          | 0/33 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Train batch:   0%|          | 0/33 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Train batch:   0%|          | 0/33 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Train batch:   0%|          | 0/33 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Train batch:   0%|          | 0/33 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Train batch:   0%|          | 0/33 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Train batch:   0%|          | 0/33 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Train batch:   0%|          | 0/33 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Train batch:   0%|          | 0/33 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Train batch:   0%|          | 0/33 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

## Save the trained model

In [12]:
# Save only the model's state dict (its parameters and buffers), not the
# module object.
state_dict = model.state_dict()
torch.save(obj=state_dict, f='finetuned_embedding.pth')