[<img src="https://github.com/KevinMusgrave/pytorch-metric-learning/raw/master/docs/imgs/Logo2.png">]()

## Introduction

This notebook makes use of the fantastic library `pytorch-metric-learning` developed and maintained by Kevin Musgrave. You can find the github at the following link:

- https://github.com/KevinMusgrave/pytorch-metric-learning

You can find a ton of useful metric learning modules there, along with a super friendly API for rapid training and evaluation. I recommend reading through the example notebooks because they are very well put together (below borrows from them heavily).

Here we use the library to train a basic whale detector using an efficient net backbone (https://arxiv.org/abs/1905.11946) with ArcFace loss (https://arxiv.org/abs/1801.07698). This is a very straightforward example and there are many ways to improve. Here are some suggestions:

- Change the train/validation split to better resemble the public LB.
- Change the model trunk.
- Pre-process the images by e.g. applying bounding boxes.
- Experiment with the training proceedure.

I will continue to develop this notebook over time and hopefully improve the results.

All feedback appreciated.

**Change Log**

- Version 9: switched to 384x384 dataset, added training augmentation, and switched from Adam to SGD with cosine schedule.
- Version 8 (LB: 0.245): fixed bug where same individual predicted multiple times for single image and increased the KNN search range.
- Version 6 (LB: 0.229): switched to cropped YOLO5 input, switched to b3 model, reduced epochs, and updated logging.
- Version 4 (LB: 0.190): initial notebook completed.

## Dependencies

In [1]:
!pip install timm
!pip install pytorch-metric-learning[with-hooks]

Collecting timm
  Downloading timm-0.5.4-py3-none-any.whl (431 kB)
     |████████████████████████████████| 431 kB 893 kB/s            
Installing collected packages: timm
Successfully installed timm-0.5.4
Collecting pytorch-metric-learning[with-hooks]
  Downloading pytorch_metric_learning-1.1.2-py3-none-any.whl (106 kB)
     |████████████████████████████████| 106 kB 907 kB/s            
Collecting record-keeper>=0.9.31
  Downloading record_keeper-0.9.31-py3-none-any.whl (8.2 kB)
Collecting faiss-gpu>=1.6.3
  Downloading faiss_gpu-1.7.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (85.5 MB)
     |████████████████████████████████| 85.5 MB 144 kB/s             
Installing collected packages: record-keeper, pytorch-metric-learning, faiss-gpu
Successfully installed faiss-gpu-1.7.2 pytorch-metric-learning-1.1.2 record-keeper-0.9.31


## Imports

In [2]:
import os
import glob
import pandas as pd
import numpy as np
import logging
import timm
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader
from torchvision.io import ImageReadMode, read_image
from torchvision.transforms import Compose, Lambda, Normalize, AutoAugment, AutoAugmentPolicy

import pytorch_metric_learning
import pytorch_metric_learning.utils.logging_presets as LP
from pytorch_metric_learning.utils import common_functions
from pytorch_metric_learning import losses, miners, samplers, testers, trainers
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator
from pytorch_metric_learning.utils.inference import InferenceModel

for handler in logging.root.handlers[:]:
    logging.root.removeHandler(handler)

logging.getLogger().setLevel(logging.INFO)
logging.info("VERSION %s" % pytorch_metric_learning.__version__)

INFO:root:VERSION 1.1.2


In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

## Parameters

There is no logic behind these, really. Go wild.

In [4]:
MODEL_NAME='tf_efficientnet_b3_ns'
N_CLASSES=15587
OUTPUT_SIZE = 1536
EMBEDDING_SIZE = 512
N_EPOCH=15
BATCH_SIZE=24
MODEL_LR = 1e-3
PCT_START=0.3
PATIENCE=5
N_WORKER=2
N_NEIGHBOURS = 1000

## Directories

We have now switched to using cropped images provided by Awsaf in the following notebook: https://www.kaggle.com/awsaf49/happywhale-cropped-dataset-yolov5. Please go give him an upvote if you like this notebook.

In [5]:
TRAIN_DIR = '../input/jpeg-happywhale-384x384/train_images-384-384/train_images-384-384'
TEST_DIR = '../input/jpeg-happywhale-384x384/test_images-384-384/test_images-384-384'
LOG_DIR = "../logs/{}".format(MODEL_NAME)
MODEL_DIR = "../models/{}".format(MODEL_NAME)

## Dataset

Create a basic dataset for loading images. 

Since we're planning to use pre-trained imagenet weights we need to normalize appropriately.

In [6]:

class HappyWhaleDataset(Dataset):
    def __init__(
        self,
        df: pd.DataFrame,
        image_dir: str,
        return_labels=True,
    ):
        self.df = df
        self.images = self.df["image"]
        self.image_dir = image_dir
        self.image_transform = Compose(
            [
                AutoAugment(AutoAugmentPolicy.IMAGENET),
                Lambda(lambda x: x / 255),
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                
            ]
        )
        self.return_labels = return_labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        
        image_path = os.path.join(self.image_dir, self.images.iloc[idx])
        image = read_image(path=image_path)
        image = self.image_transform(image)
        
        if self.return_labels:
            label = self.df['label'].iloc[idx]
            return image, label
        else:
            return image


# Data Split

Load in the csv:

In [7]:
df = pd.read_csv('../input/happy-whale-and-dolphin/train.csv')
df.head()

Unnamed: 0,image,species,individual_id
0,00021adfb725ed.jpg,melon_headed_whale,cadddb1636b9
1,000562241d384d.jpg,humpback_whale,1a71fbb72250
2,0007c33415ce37.jpg,false_killer_whale,60008f293a2b
3,0007d9bca26a99.jpg,bottlenose_dolphin,4b00fe572063
4,00087baf5cef7a.jpg,humpback_whale,8e5253662392


Add a label for the classes:

In [8]:
df['label'] = df.groupby('individual_id').ngroup()
df['label'].describe()

count    51033.000000
mean      7651.356240
std       4465.552697
min          0.000000
25%       3748.000000
50%       7605.000000
75%      11443.000000
max      15586.000000
Name: label, dtype: float64

Split into training and validation:

In [9]:
valid_proportion = 0.1

valid_df = df.sample(frac=valid_proportion, replace=False, random_state=1).copy()
train_df = df[~df['image'].isin(valid_df['image'])].copy()

print(train_df.shape)
print(valid_df.shape)

(45930, 4)
(5103, 4)


Reset index on both since we want to use it for KNN lookups later:

In [10]:
train_df.reset_index(drop=True, inplace=True)
valid_df.reset_index(drop=True, inplace=True)

Create our dataset objects:

In [11]:
train_dataset = HappyWhaleDataset(df=train_df, image_dir=TRAIN_DIR, return_labels=True)
len(train_dataset)

45930

In [12]:
valid_dataset = HappyWhaleDataset(df=valid_df, image_dir=TRAIN_DIR, return_labels=True)
len(valid_dataset)

5103

In [13]:
dataset_dict = {"train": train_dataset, "val": valid_dataset}

## Model Setup

We need to specify three components to build our model:

- Trunk
- Embedder
- Loss

Setup the trunk using a pre-trained model from timm:

In [14]:
trunk = timm.create_model(MODEL_NAME, pretrained=True)
trunk.classifier = common_functions.Identity()
trunk = trunk.to(device)
trunk_optimizer = optim.SGD(trunk.parameters(), lr=MODEL_LR, momentum=0.9)
trunk_schedule = optim.lr_scheduler.OneCycleLR(
    trunk_optimizer,
    max_lr=MODEL_LR,
    total_steps = N_EPOCH * int(len(train_dataset)/BATCH_SIZE),
    pct_start = PCT_START
)

INFO:timm.models.helpers:Loading pretrained weights from url (https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_ns-9d44bf68.pth)
Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_ns-9d44bf68.pth" to /root/.cache/torch/hub/checkpoints/tf_efficientnet_b3_ns-9d44bf68.pth


Add our embedder. This is just a linear layer that will create the embeddings for KNN:

In [15]:
embedder = nn.Linear(OUTPUT_SIZE, EMBEDDING_SIZE).to(device)
embedder_optimizer = optim.SGD(trunk.parameters(), lr=MODEL_LR, momentum=0.9)
embedder_schedule = optim.lr_scheduler.OneCycleLR(
    embedder_optimizer,
    max_lr=MODEL_LR,
    total_steps = N_EPOCH * int(len(train_dataset)/BATCH_SIZE),
    pct_start = PCT_START
)

And add the loss function:

In [16]:
loss_func = losses.ArcFaceLoss(num_classes=N_CLASSES, embedding_size=EMBEDDING_SIZE).to(device)
loss_optimizer = optim.SGD(trunk.parameters(), lr=MODEL_LR, momentum=0.9)
loss_schedule = optim.lr_scheduler.OneCycleLR(
    loss_optimizer,
    max_lr=MODEL_LR,
    total_steps = N_EPOCH * int(len(train_dataset)/BATCH_SIZE),
    pct_start = PCT_START
)

Setup some hooks for validation, logging and model saving at the end of the epoch:

In [17]:
record_keeper, _, _ = LP.get_record_keeper(LOG_DIR)
hooks = LP.get_hook_container(record_keeper, primary_metric='mean_average_precision')

In [18]:
tester = testers.GlobalEmbeddingSpaceTester(
    end_of_testing_hook=hooks.end_of_testing_hook,
    accuracy_calculator=AccuracyCalculator(
        include=['mean_average_precision'],
        device=torch.device("cpu"),
        k=5),
    dataloader_num_workers=N_WORKER,
    batch_size=BATCH_SIZE
)

By adding the tester as an end of epoch hook in this way, it will automatically use the embedder model to generate train and validation embeddings, then for each validation embedding find the k nearest neighbours and evaluate MAP@5. This won't take into account the `new_individual` problem, but it should give us an idea of model performance on the task regardless.

In [19]:
end_of_epoch_hook = hooks.end_of_epoch_hook(
    tester, 
    dataset_dict,
    MODEL_DIR,
    test_interval=1, 
    patience=PATIENCE, 
    splits_to_eval = [('val', ['train'])]
)

Finally, setup our trainer object:

In [20]:
trainer = trainers.MetricLossOnly(
    models={"trunk": trunk, "embedder": embedder},
    optimizers={"trunk_optimizer": trunk_optimizer, "embedder_optimizer": embedder_optimizer, "metric_loss_optimizer": loss_optimizer},
    batch_size=BATCH_SIZE,
    loss_funcs={"metric_loss": loss_func},
    mining_funcs={},
    dataset=train_dataset,
    dataloader_num_workers=N_WORKER,
    end_of_epoch_hook=end_of_epoch_hook,
    lr_schedulers={
        'trunk_scheduler_by_iteration': trunk_schedule,
        'embedder_scheduler_by_iteration': embedder_schedule,
        'metric_loss_scheduler_by_iteration': loss_schedule,
    }
)

## Model Training

Train the model:

In [21]:
trainer.train(num_epochs=N_EPOCH)

INFO:PML:Initializing dataloader
INFO:PML:Initializing dataloader iterator
INFO:PML:Done creating dataloader iterator
INFO:PML:TRAINING EPOCH 1
total_loss=41.89675: 100%|██████████| 1913/1913 [17:28<00:00,  1.82it/s]
INFO:PML:Evaluating epoch 1
INFO:PML:Getting embeddings for the val split
100%|██████████| 213/213 [01:06<00:00,  3.23it/s]
INFO:PML:Getting embeddings for the train split
100%|██████████| 1914/1914 [08:28<00:00,  3.76it/s]
INFO:PML:Computing accuracy for the val split w.r.t ['train']
INFO:PML:running k-nn with k=5
INFO:PML:embedding dimensionality is 512
INFO:PML:New best accuracy! 0.06186890128847916
INFO:PML:TRAINING EPOCH 2
total_loss=40.11728: 100%|██████████| 1913/1913 [17:07<00:00,  1.86it/s]
INFO:PML:Evaluating epoch 2
INFO:PML:Getting embeddings for the val split
100%|██████████| 213/213 [00:57<00:00,  3.70it/s]
INFO:PML:Getting embeddings for the train split
100%|██████████| 1914/1914 [08:27<00:00,  3.77it/s]
INFO:PML:Computing accuracy for the val split w.r.t ['

## Inference (validation set)

Here we want to use the validation set to help us choose the appropriate distance threshold between our query and reference images after which we classify the former as a `new_individual`. To do so, we loop through the validation set for a number of thresholds and find that which maximises our MAP@5.

Load in the best weights:

In [22]:
logging.getLogger().setLevel(logging.WARNING)

In [23]:
best_trunk_weights = glob.glob('../models/{}/trunk_best*.pth'.format(MODEL_NAME))[0]
trunk.load_state_dict(torch.load(best_trunk_weights))

<All keys matched successfully>

In [24]:
best_embedder_weights = glob.glob('../models/{}/embedder_best*.pth'.format(MODEL_NAME))[0]
embedder.load_state_dict(torch.load(best_embedder_weights))

<All keys matched successfully>

Setup the inference model object to easily generate embeddings and find nearest neighbours:

In [25]:
inference_model = InferenceModel(
    trunk=trunk,
    embedder=embedder,
    normalize_embeddings=True,
)

Train this on the training data:

In [26]:
inference_model.train_knn(train_dataset)

Loop through the validation data and loop through to find k nearest neighbours:

In [27]:
valid_dataloader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=N_WORKER, pin_memory=True)

In [28]:
valid_labels_list = []
valid_distance_list = []
valid_indices_list = []

for images, labels in tqdm(valid_dataloader):

    distances, indices = inference_model.get_nearest_neighbors(images, k=N_NEIGHBOURS)
    valid_labels_list.append(labels)
    valid_distance_list.append(distances)
    valid_indices_list.append(indices)

valid_labels = torch.cat(valid_labels_list, dim=0).cpu().numpy()
valid_distances = torch.cat(valid_distance_list, dim=0).cpu().numpy()
valid_indices = torch.cat(valid_indices_list, dim=0).cpu().numpy()

  0%|          | 0/213 [00:00<?, ?it/s]

We have the indices of the nearest neighbours in our training set, so setup the lookups to return the `individual_id`:

In [29]:
new_whale_idx = -1

train_labels = train_df['individual_id'].unique()
train_idx_lookup = train_df['individual_id'].copy().to_dict()
train_idx_lookup[-1] = 'new_individual'

valid_class_lookup = valid_df.set_index('label')['individual_id'].copy().to_dict()

Loop through a range of thresholds and find which maximises our MAP@5:

In [30]:
thresholds = [np.quantile(valid_distances, q=q) for q in np.arange(0, 1.0, 0.01)]

In [31]:
results = []

for threshold in tqdm(thresholds):

    prediction_list = []
    running_map=0

    for i in range(len(valid_distances)):

        pred_knn_idx = valid_indices[i, :].copy()  
        insert_idx = np.where(valid_distances[i, :] > threshold) 

        if insert_idx[0].size != 0:  
            pred_knn_idx = np.insert(pred_knn_idx, np.min(insert_idx[0]), new_whale_idx) 

        predicted_label_list = []
        
        for predicted_idx in pred_knn_idx:
            predicted_label = train_idx_lookup[predicted_idx]
            if len(predicted_label_list) == 5:
                break
            if (predicted_label == 'new_individual') | (predicted_label not in predicted_label_list):
                predicted_label_list.append(predicted_label)

        gt = valid_class_lookup[valid_labels[i]]

        if gt not in train_labels:
            gt = "new_individual"

        precision_vals = []

        for j in range(5):
            if predicted_label_list[j] == gt:
                precision_vals.append(1/(j+1))
            else:
                precision_vals.append(0)

        running_map += np.max(precision_vals)

    results.append([threshold, running_map / len(valid_distances)])

results_df = pd.DataFrame(results, columns=['threshold','map5'])

  0%|          | 0/100 [00:00<?, ?it/s]

In [32]:
results_df = results_df.sort_values(by='map5', ascending=False).reset_index(drop=True)
results_df.head(5)

Unnamed: 0,threshold,map5
0,0.8035,0.418257
1,0.837327,0.417042
2,0.765159,0.416654
3,0.867816,0.414766
4,0.718935,0.414407


Grab the best result:

In [33]:
threshold = results_df.loc[0, 'threshold']
threshold

0.8034998118877411

## Inference (test set)

We want to make sure we use both our training and validation images for comparison. Combine the two dataframes and add a new dataset: 

In [34]:
combined_df = pd.concat([train_df, valid_df], axis=0).reset_index(drop=True)
combined_dataset = HappyWhaleDataset(df=combined_df, image_dir=TRAIN_DIR, return_labels=True)
len(combined_dataset)

51033

Re-train the KNN model on this:

In [35]:
inference_model.train_knn(combined_dataset)

Grab the submission file:

In [36]:
test_df = pd.read_csv('../input/happy-whale-and-dolphin/sample_submission.csv')

Create our dataset and dataloader objects for the test set:

In [37]:
test_dataset = HappyWhaleDataset(df=test_df, image_dir=TEST_DIR, return_labels=False)
len(test_dataset)

27956

In [38]:
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=N_WORKER, pin_memory=True)

Find the k nearest neighbours in our combined dataset:

In [39]:
test_distance_list = []
test_indices_list = []

for images in tqdm(test_dataloader):

    distances, indices = inference_model.get_nearest_neighbors(images, k=N_NEIGHBOURS)
    test_distance_list.append(distances)
    test_indices_list.append(indices)

test_distances = torch.cat(test_distance_list, dim=0).cpu().numpy()
test_indices = torch.cat(test_indices_list, dim=0).cpu().numpy()

  0%|          | 0/1165 [00:00<?, ?it/s]

Prepare the labels for lookup based on index:

In [40]:
combined_idx_lookup = combined_df['individual_id'].copy().to_dict()
combined_idx_lookup[-1] = 'new_individual'

Loop through applying the threshold we found earlier to insert `new_individual`:

In [41]:
results = []

prediction_list = []

for i in range(len(test_distances)):

    pred_knn_idx = test_indices[i, :].copy() 
    insert_idx = np.where(test_distances[i, :] > threshold)  

    if insert_idx[0].size != 0:  
        pred_knn_idx = np.insert(pred_knn_idx, np.min(insert_idx[0]), new_whale_idx)  

    predicted_label_list = []

    for predicted_idx in pred_knn_idx:
        predicted_label = combined_idx_lookup[predicted_idx]
        if len(predicted_label_list) == 5:
            break
        if (predicted_label == 'new_individual') | (predicted_label not in predicted_label_list):
            predicted_label_list.append(predicted_label)

    prediction_list.append(predicted_label_list)

prediction_df = pd.DataFrame(prediction_list)
prediction_df.head()

Unnamed: 0,0,1,2,3,4
0,6f4abc5666cd,6d1cc6d00dca,new_individual,d354abb7798b,5dc511af5c99
1,new_individual,64cbd1f56354,ca2c3f068c7a,81c3d57462fb,6a28c14fc117
2,c80653b9edd0,afbfd0d6b694,31f748b822f4,20c4589b5b16,4f87f855216a
3,new_individual,ad962a1131d3,6ad3713dda3a,1492507238d8,91ed5caeb0d3
4,new_individual,d4d2cd407a48,b897fc70bb2e,e35763314b89,a18ef9f290bb


Create the prediction label:

In [42]:
prediction_df['predictions'] = prediction_df[0].astype(str) + ' ' + prediction_df[1].astype(str) + ' ' + prediction_df[2 ].astype(str) + ' ' + prediction_df[3].astype(str) + ' ' + prediction_df[4].astype(str)
prediction_df.head()

Unnamed: 0,0,1,2,3,4,predictions
0,6f4abc5666cd,6d1cc6d00dca,new_individual,d354abb7798b,5dc511af5c99,6f4abc5666cd 6d1cc6d00dca new_individual d354a...
1,new_individual,64cbd1f56354,ca2c3f068c7a,81c3d57462fb,6a28c14fc117,new_individual 64cbd1f56354 ca2c3f068c7a 81c3d...
2,c80653b9edd0,afbfd0d6b694,31f748b822f4,20c4589b5b16,4f87f855216a,c80653b9edd0 afbfd0d6b694 31f748b822f4 20c4589...
3,new_individual,ad962a1131d3,6ad3713dda3a,1492507238d8,91ed5caeb0d3,new_individual ad962a1131d3 6ad3713dda3a 14925...
4,new_individual,d4d2cd407a48,b897fc70bb2e,e35763314b89,a18ef9f290bb,new_individual d4d2cd407a48 b897fc70bb2e e3576...


Attach this to the submission:

In [43]:
submission = pd.read_csv('../input/happy-whale-and-dolphin/sample_submission.csv')
submission['predictions'] = prediction_df['predictions']
submission.head(1)

Unnamed: 0,image,predictions
0,000110707af0ba.jpg,6f4abc5666cd 6d1cc6d00dca new_individual d354a...


Save our submission:

In [44]:
submission.to_csv('submission.csv', index=False)