# Modifications of triplet loss for deep metric learning tasks

This notebook contains code for experiments with CUB-200-2011 dataset. 

Requirements: pandas, numpy, torch, wandb, timm, open-metric-learning

In [None]:
!pip uninstall -y numpy torch torchvision torchaudio
!pip install git+https://github.com/OML-Team/open-metric-learning.git
!pip uninstall -y nvidia_cublas_cu11
!pip install wandb

In [1]:
import torch
torch.cuda.is_available()

True

In [None]:
# !pip install git+https://github.com/OML-Team/open-metric-learning.git
# !pip install wandb
!wandb login "2a3ffdce0110826a26805443c7575053621bc696"

!wget "https://data.caltech.edu/records/65de6-vp158/files/CUB_200_2011.tgz"
!tar -zxvf CUB_200_2011.tgz

!wget "https://raw.githubusercontent.com/OML-Team/open-metric-learning/main/pipelines/datasets_converters/convert_cub.py"
!python convert_cub.py --dataset_root=/kaggle/working/CUB_200_2011 --no_bboxes

In [2]:
import datetime as dt
from pathlib import Path

import numpy as np
import pandas as pd
import timm
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm

from oml.datasets.base import DatasetWithLabels
from oml.inference.flat import inference_on_images
from oml.losses.triplet import TripletLossWithMiner
from oml.miners.cross_batch import TripletMinerWithMemory
from oml.miners.inbatch_hard_cluster import HardClusterMiner
from oml.miners.inbatch_hard_tri import HardTripletsMiner
from oml.miners.inbatch_nhard_tri import NHardTripletsMiner
from oml.miners.inbatch_all_tri import AllTripletsMiner
from oml.models.vit.vit import ViTExtractor
from oml.samplers.balance import BalanceSampler
from oml.transforms.images.albumentations import (
    get_augs_albu,
    get_normalisation_resize_albu
)
from oml.transforms.images.torchvision import (
    get_augs_hypvit,
    get_normalisation_resize_hypvit
)

import wandb



In [3]:
from oml.functional.metrics import (
    calc_gt_mask,
    calc_mask_to_ignore,
    calc_retrieval_metrics,
)


def compute_metrics(dist_mat, labels, is_query, is_gallery, **metrics):
    mask_gt = calc_gt_mask(labels=labels, is_query=is_query, is_gallery=is_gallery)
    mask_to_ignore = calc_mask_to_ignore(is_query=is_query, is_gallery=is_gallery)
    return calc_retrieval_metrics(dist_mat, mask_gt, mask_to_ignore, **metrics)


def transform_metrics_for_wandb_logging(metrics_value):
    res = {}
    for metric_name in metrics_value:
        for k in metrics_value[metric_name]:
            res[metric_name + '/' + str(k)] = metrics_value[metric_name][k].item()
    return res


def save_model(path, num_epochs, model, optimizer, scheduler=None):
    '''Save on GPU'''
    data = {
        'num_epochs': num_epochs,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict() if scheduler is not None else None
    }
    torch.save(data, path)


def load_model(path, device, model, optimizer=None, scheduler=None):
    '''Load on GPU'''
    data = torch.load(path)
    model.load_state_dict(data['model_state_dict'])
    model.to(device)
    if optimizer is not None:
        optimizer.load_state_dict(data['optimizer_state_dict'])
    if scheduler is not None:
        scheduler.load_state_dict(data['scheduler_state_dict'])
    return data['num_epochs']


@torch.no_grad()
def inference(model, valid_loader, device):
    embeds, labels = [], []
    for batch in valid_loader:
        # embeds += [F.normalize(model.body(batch['input_tensors'].to(device)), p=2, dim=1)]
        embeds += [model(batch['input_tensors'].to(device))]
        labels += [batch['labels']]
    return torch.cat(embeds, dim=0).cpu(), torch.cat(labels, dim=0).cpu()


@torch.no_grad()
def track_additional_valid_metrics(embeds, labels, dist_mat):
    additionel_metrics = {}
    
    class_centers, class_sizes = [], []
    for label, count in zip(*torch.unique(labels, return_counts=True)):
        class_embeds = embeds[labels == label]
        class_center = torch.mean(class_embeds, dim=0)
        class_centers += [class_center]
        class_variance = torch.sum((class_embeds - class_center.unsqueeze(0)) ** 2) / count
        class_sizes += [torch.sqrt(class_variance)]
    
    class_sizes = torch.tensor(class_sizes)
    class_centers = torch.stack(class_centers, dim=0)
    class_centers_dist_mat = torch.cdist(class_centers, class_centers, p=2)
    n_classes = class_centers.shape[0]
    
    additionel_metrics['additional/class_sizes/min'] = torch.min(class_sizes)
    additionel_metrics['additional/class_sizes/max'] = torch.max(class_sizes)
    additionel_metrics['additional/class_sizes/mean'] = torch.mean(class_sizes)
    
    additionel_metrics['additional/inter_class_dist/min'] = \
        torch.min(class_centers_dist_mat[class_centers_dist_mat > 0])
    additionel_metrics['additional/inter_class_dist/max'] = torch.max(class_centers_dist_mat)
    additionel_metrics['additional/inter_class_dist/mean'] = \
        torch.sum(class_centers_dist_mat) / float(n_classes) / (n_classes - 1)
    
    return additionel_metrics


@torch.no_grad()
def validation(model, valid_loader, metrics, device):
    model.eval()
    embeds, labels = inference(model, valid_loader, device)
    print(f'Inference finished: {dt.datetime.now()}')

    dist_mat = torch.cdist(embeds, embeds, p=2)
    mask = torch.ones(len(embeds))
    metrics_value = compute_metrics(dist_mat, labels, mask, mask, **metrics)
    wandb_metrics_value = transform_metrics_for_wandb_logging(metrics_value)
    wandb_metrics_value.update(track_additional_valid_metrics(embeds, labels, dist_mat))
    print(wandb_metrics_value, end='\n\n')
    
    return wandb_metrics_value


@torch.no_grad()
def add_train_metrics(wandb_metrics_value, model, train_loader, metrics, device):
    train_metrics_value = validation(model, train_loader, metrics, device)
    for metric, value in train_metrics_value.items():
        wandb_metrics_value['train/' + metric] = value


def curve_relu(x, gamma):
    gamma = float(gamma)
    assert gamma >= 1.0
    
    res = torch.clone(x)
    mask = (res < 0)
    res[mask] = 0.0
    res[torch.logical_not(mask)] **= gamma
    return res


def relu_threshold(x, t):
    t = float(t)
    assert t > 0
    
    res = torch.clone(x)
    res[res < 0] = 0.0
    res[res > t] = t
    return res

In [4]:
def freeze(model):
    def fr(m):
        for param in m.parameters():
            param.requires_grad = False
    fr(model.patch_embed)
    fr(model.pos_drop)


def rm_head(m):
    names = set(x[0] for x in m.named_children())
    target = {"head", "fc", "head_dist"}
    for x in names & target:
        m.add_module(x, nn.Identity())


class NormLayer(nn.Module):
    def forward(self, x):
        return F.normalize(x, p=2, dim=1)


class Extractor(nn.Module):
    def __init__(self):
        super().__init__()
        self.body = timm.create_model('vit_small_patch16_224', pretrained=True)
        # self.body = torch.hub.load("facebookresearch/dino:main", 'dino_vits16')
        # self.body = ViTExtractor('vits8_dino', arch='vits8', normalise_features=False)
        freeze(self.body)   # freeze MLPs for patch embeds
        
#         self.head = nn.Sequential(nn.Linear(384, 384), NormLayer())
#         nn.init.constant_(self.head[0].bias.data, 0)
#         nn.init.orthogonal_(self.head[0].weight.data)
        rm_head(self.body)
        # self.head = nn.Identity()
        self.head = NormLayer()
    
    def forward(self, x):
        return self.head(self.body(x))

# Triplet loss implementation from OML library

In [5]:
from typing import Optional

from oml.functional.losses import get_reduced
from oml.interfaces.miners import labels2list
from oml.utils.misc_torch import elementwise_dist

from torch import Tensor



class MyTripletLoss(nn.Module):
    """
    Class, which combines classical `TripletMarginLoss` and `SoftTripletLoss`.
    The idea of `SoftTripletLoss` is the following:
    instead of using the classical formula
    ``loss = relu(margin + positive_distance - negative_distance)``
    we use
    ``loss = log1p(exp(positive_distance - negative_distance))``.
    It may help to solve the often problem when `TripletMarginLoss` converges to it's
    margin value (also known as `dimension collapse`).
    """

    criterion_name = "triplet"  # for better logging

    def __init__(self, margin: Optional[float], reduction: str = "mean", need_logs: bool = False):
        """
        Args:
            margin: Margin value, set ``None`` to use `SoftTripletLoss`
            reduction: ``mean``, ``sum`` or ``none``
            need_logs: Set ``True`` if you want to store logs
        """
        assert reduction in ("mean", "sum", "none")
        # assert (margin is None) or (margin > 0)

        super().__init__()

        self.margin = margin
        self.reduction = reduction
        self.need_logs = need_logs
        self.last_logs: Dict[str, float] = {}
        
        self.log_dap, self.log_dan, self.log_dpn = [], [], []

    def forward(self, anchor: Tensor, positive: Tensor, negative: Tensor) -> Tensor:
        """
        Args:
            anchor: Anchor features with the shape of ``(batch_size, feat)``
            positive: Positive features with the shape of ``(batch_size, feat)``
            negative: Negative features with the shape of ``(batch_size, feat)``
        Returns:
            Loss value
        """
        assert anchor.shape == positive.shape == negative.shape

        positive_dist = elementwise_dist(x1=anchor, x2=positive, p=2)
        negative_dist = elementwise_dist(x1=anchor, x2=negative, p=2)
        pos_neg_dist = elementwise_dist(x1=positive, x2=negative, p=2)
        
        self.log_dap.append(positive_dist)
        self.log_dan.append(negative_dist)
        self.log_dpn.append(pos_neg_dist)

        if self.margin is None:
            # here is the soft version of TripletLoss without margin
            loss = torch.log1p(torch.exp(positive_dist - negative_dist))
        else:
            # loss = torch.relu(self.margin + positive_dist - negative_dist)
            loss = curve_relu(self.margin + positive_dist - negative_dist, gamma=2.0)
            # loss = relu_threshold(self.margin + positive_dist - negative_dist, t=self.margin + 0.45)

        if self.need_logs:
            self.last_logs = {
                "active_tri": float((loss.clone().detach() > 0).float().mean()),
                "pos_dist": float(positive_dist.clone().detach().mean().item()),
                "neg_dist": float(negative_dist.clone().detach().mean().item()),
            }

        loss = get_reduced(loss, reduction=self.reduction)

        return loss
    
    def summary(self):
        res = {}
        dap, dan, dpn = map(torch.cat, [self.log_dap, self.log_dan, self.log_dpn])
        diff = dan - dap
        for data, name in zip([dap, dan, dpn, diff], 
                              ['d(a,p)', 'd(a,n)', 'd(p,n)', 'd(a,n)-d(a,p)']):
            res[name + '/min'] = torch.min(data).item()
            res[name + '/mean'] = torch.mean(data).item()
            res[name + '/std'] = torch.std(data).item()
            res[name + '/max'] = torch.max(data).item()
        self.log_dap, self.log_dan, self.log_dpn = [], [], []
        return res


class MyTripletLossWithMiner(nn.Module):
    def __init__(self, triplet_loss, miner):
        super().__init__()
        self.tri_loss = triplet_loss
        self.miner = miner
    
    def forward(self, features, labels):
        labels_list = labels2list(labels)
        anchor, positive, negative = self.miner.sample(features=features, labels=labels_list)
        loss = self.tri_loss(anchor=anchor, positive=positive, negative=negative)
        return loss

In [6]:
dataset_root = '/kaggle/working/CUB_200_2011'
num_workers = 2
valid_batch_size = 128
n_labels = 24
n_instances = 4

lr = 1e-5
wd = 0.01

df = pd.read_csv(dataset_root + '/df.csv')
# use trainval split as in DML articles
df[['is_query', 'is_gallery']] = np.nan
df.loc[df['label'] <= 100, 'split'] = 'train'
df.loc[df['label'] > 100, 'split'] = 'validation'
df.loc[df['label'] > 100, ['is_query', 'is_gallery']] = True

df_train = df[df['split'] == 'train']
df_valid = df[df['split'] == 'validation']

# train_transforms = get_augs_albu(224)
# valid_transforms = get_normalisation_resize_albu(224)
mean, std = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
train_transforms = get_augs_hypvit(224, mean=mean, std=std)
valid_transforms = get_normalisation_resize_hypvit(256, mean=mean, std=std)

train_dataset = DatasetWithLabels(df_train, transform=train_transforms, dataset_root=dataset_root)
train_dataset_metrics = DatasetWithLabels(df_train, transform=valid_transforms, dataset_root=dataset_root)
valid_dataset = DatasetWithLabels(df_valid, transform=valid_transforms, dataset_root=dataset_root)

sampler = BalanceSampler(train_dataset.get_labels(), n_labels=n_labels, n_instances=n_instances)
train_loader = DataLoader(train_dataset, batch_sampler=sampler, num_workers=num_workers)
train_loader_metrics = DataLoader(train_dataset_metrics, batch_size=valid_batch_size, num_workers=num_workers)
valid_loader = DataLoader(valid_dataset, batch_size=valid_batch_size, num_workers=num_workers)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = Extractor().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
scheduler = None
# criterion = TripletLossWithMiner(margin=0.15, miner=HardTripletsMiner())
criterion = MyTripletLossWithMiner(MyTripletLoss(margin=0.2), 
                                   NHardTripletsMiner(n_positive=(2, 2), n_negative=(1, 1)))

# Training

In [7]:
# torch.manual_seed(42)

n_epochs = 10000
valid_period = 5

metrics = {
    'cmc_top_k': [1],  # to calculate cmc@1
    'map_top_k': [5],  # to calculate map@5
    'precision_top_k': [],
    'fmr_vals': []
}

wandb_init_data = {
    'project': 'TP3',
    'name': 'run',
    'save_code': True,
    'config': {
        'model': 'ViT',
        'optimizer': optimizer,
        'scheduler': scheduler,
        'sampler': {
            'name': 'balanced',
            'n_labels': n_labels,
            'n_instances': n_instances
        },
        
        'valid_period': valid_period,

        'dataset': 'CUB_200_2011',
        'num_epochs': n_epochs,
        'dataloader_num_workers': num_workers,
        'script': _ih[-1]
    }
}

with wandb.init(**wandb_init_data) as run:
    print('Evaluating pre-trained model before training')
    wandb_metrics_value = validation(model, valid_loader, metrics, device)
    add_train_metrics(wandb_metrics_value, model, train_loader_metrics, metrics, device)
    wandb.log(wandb_metrics_value)
    best_cmc1 = wandb_metrics_value['cmc/1']
    
    for epoch in range(n_epochs):
        model.train()
        for batch in train_loader:
            optimizer.zero_grad()
            embeddings = model(batch['input_tensors'].to(device))
            loss = criterion(embeddings, batch['labels'].to(device))
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 3)
            optimizer.step()

        if (epoch + 1) % valid_period == 0:
            print(f'{epoch + 1} training epochs finished\nValidation started: {dt.datetime.now()}')
            with torch.inference_mode():
                wandb_metrics_value = validation(model, valid_loader, metrics, device)
                add_train_metrics(wandb_metrics_value, model, train_loader_metrics, metrics, device)
                wandb_metrics_value.update(criterion.tri_loss.summary())
                wandb.log(wandb_metrics_value)
                
                if wandb_metrics_value['cmc/1'] > best_cmc1:
                    best_cmc1 = wandb_metrics_value['cmc/1']
                    save_model('best.pt', epoch + 1, model, optimizer, scheduler)
                    wandb.save('best.pt')
                    print(f'\nNew best CMC@1 {best_cmc1} at {epoch + 1} epoch\n')

[34m[1mwandb[0m: Currently logged in as: [33mnik-fedorov[0m. Use [1m`wandb login --relogin`[0m to force relogin


Evaluating pre-trained model before training
Inference finished: 2023-05-20 09:38:17.754253
{'cmc/1': 0.830857515335083, 'map/5': 0.8621483445167542, 'additional/class_sizes/min': tensor(0.4694), 'additional/class_sizes/max': tensor(0.6891), 'additional/class_sizes/mean': tensor(0.5814), 'additional/inter_class_dist/min': tensor(0.0002), 'additional/inter_class_dist/max': tensor(1.1435), 'additional/inter_class_dist/mean': tensor(0.7877)}

Inference finished: 2023-05-20 09:38:56.015809
{'cmc/1': 0.844474732875824, 'map/5': 0.8746084570884705, 'additional/class_sizes/min': tensor(0.4522), 'additional/class_sizes/max': tensor(0.6944), 'additional/class_sizes/mean': tensor(0.5822), 'additional/inter_class_dist/min': tensor(0.0002), 'additional/inter_class_dist/max': tensor(1.1752), 'additional/inter_class_dist/mean': tensor(0.8619)}

5 training epochs finished
Validation started: 2023-05-20 09:39:13.663012
Inference finished: 2023-05-20 09:39:48.167185
{'cmc/1': 0.8408170342445374, 'map/5

0,1
additional/class_sizes/max,█▇▆▄▃▃▃▃▂▂▁▁▁▁▁▁▂▂▂▂▂▂▂▂▁▂▂▃▃▃▃▄▄▄▄▄▄▄▄▄
additional/class_sizes/mean,█▇▅▄▄▃▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂▂▂▂▃▃▃▃▃▃▃▃
additional/class_sizes/min,█▆▅▅▄▃▃▃▃▃▃▃▂▁▁▁▂▁▁▂▂▂▂▂▁▂▂▂▂▂▂▂▂▂▂▃▃▂▂▂
additional/inter_class_dist/max,▇██▇▆▆▆▆▅▄▄▄▃▅▄▃▅▃▂▂▁▁▁▂▁▂▂▂▃▃▃▂▃▂▂▄▄▂▂▁
additional/inter_class_dist/mean,▆▇█▇▇▅▅▅▅▃▂▃▂▄▃▂▃▃▃▂▁▂▁▃▃▂▃▃▄▄▄▅▅▆▅▆▅▄▅▄
additional/inter_class_dist/min,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
cmc/1,▄▇████▇▇▆▇▇▇▆▇▆▆▆▆▅▅▅▄▅▅▄▅▄▃▃▃▄▃▂▂▂▂▁▂▁▁
"d(a,n)-d(a,p)/max",▃▅▆▂▄▂▂▆▁▄▁▄▄▃▄▄▄▆▂▃▄▃▃▆▄▃▄▆▅▅▅▇▅▆▄▆▄██▆
"d(a,n)-d(a,p)/mean",▁▂▂▂▂▃▃▃▃▄▄▄▄▄▄▄▅▅▅▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇█▇███
"d(a,n)-d(a,p)/min",▂▃▅▆▅▄▂▆▂▅▃▄▆▄▅▄▃▆▅▇▆▄▄▇▄▆▁▅▅▆▇▇▄█▆▅▃▆▁▃

0,1
additional/class_sizes/max,0.63589
additional/class_sizes/mean,0.49344
additional/class_sizes/min,0.36078
additional/inter_class_dist/max,1.11313
additional/inter_class_dist/mean,0.7982
additional/inter_class_dist/min,0.00024
cmc/1,0.82968
"d(a,n)-d(a,p)/max",0.69835
"d(a,n)-d(a,p)/mean",0.26022
"d(a,n)-d(a,p)/min",-0.39955


KeyboardInterrupt: 

# Loading model from wandb and resume training

In [None]:
import wandb
best_model = wandb.restore('best.pt', run_path="nik-fedorov/TP3/f0ian1ey")

# model = ViTExtractor('vits16_dino', arch='vits16', normalise_features=False).to(device)
# model = timm.create_model('vit_small_patch16_224', pretrained=True).to(device)
load_model(best_model.name, device, model, optimizer)

In [None]:
torch.manual_seed(42)

n_epochs = 10000
valid_period = 10

metrics = {
    'cmc_top_k': [1],  # to calculate cmc@1
    'map_top_k': [5],  # to calculate map@5
    'precision_top_k': [],
    'fmr_vals': []
}

wandb_init_data = {
    'project': 'TP3',
    'name': 'run',
    'save_code': True,
    'config': {
        'model': 'ViT',
        'optimizer': optimizer,
        'scheduler': scheduler,
        'sampler': {
            'name': 'balanced',
            'n_labels': n_labels,
            'n_instances': n_instances
        },
        
        'valid_period': valid_period,

        'dataset': 'CUB_200_2011',
        'num_epochs': n_epochs,
        'dataloader_num_workers': num_workers,
        'script': _ih[-1]
    }
}

with wandb.init(**wandb_init_data) as run:
    print('Evaluating pre-trained model before training')
    wandb_metrics_value = validation(model, valid_loader, metrics, device)
    wandb.log(wandb_metrics_value)
    best_cmc1 = wandb_metrics_value['cmc/1']
    
    for epoch in range(n_epochs):
        model.train()
        for batch in train_loader:
            optimizer.zero_grad()
            embeddings = model(batch['input_tensors'].to(device))
            loss = criterion(embeddings, batch['labels'].to(device))
            loss.backward()
            optimizer.step()

        if (epoch + 1) % valid_period == 0:
            print(f'{epoch + 1} training epochs finished\nValidation started: {dt.datetime.now()}')
            with torch.inference_mode():
                wandb_metrics_value = validation(model, valid_loader, metrics, device)
                wandb.log(wandb_metrics_value)
                
                if wandb_metrics_value['cmc/1'] > best_cmc1:
                    best_cmc1 = wandb_metrics_value['cmc/1']
                    save_model('best.pt', epoch + 1, model, optimizer, scheduler)
                    wandb.save('best.pt')
                    print(f'\nNew best CMC@1 {best_cmc1} at {epoch + 1} epoch\n')