# Tutorial: TripletLoss for re-identification using FastAI v2

This tutorial demonstrates the use of the `TripletLoss` loss function in FastAI v2 (running over PyTorch). TripletLoss is one of the leading ways to match an image (or sound or other signal) with a large database of images, even if there are very few matches for that particular input. Joint-embedding networks (aka Twin or Siamese networks) are a common alternative for this "low k-shot" challenge. According to [that article I read](tk), triplet loss outperformed twin networks in all of the tk studies where both approaches were tried. Another option for re-identification is Scale Invariant Feature Transform (SIFT), which is implemented in OpenCV. I haven't seen any papers comparing SIFT and statistical ML techniques in real-world tasks.

The tutorial walks you through:

- Creating the necessary conda environment
- Downloading and expanding the MNIST dataset of handwritten digits
- Creating and training a model using triplet loss
- Applying this model to a few images of a new glyph (a hand-drawn star)
- Finding the closest neighbors to the new images
- Visualizing that the model tightly clusters the new glyph
- Demonstrating that the model can be used for re-identification of the new glyph

MNIST is an easy problem for modern statistical ML, with relatively small data and a small input size. This makes it a good candidate for a tutorial. I will try to highlight the places where you'll need to alter parameters when faced with a more realistic task, like animal re-identification. 

The code assumes that you have a CUDA-accelerated GPU for training and inferencing. The tutorial runs quickly on a 2080 (about 30 seconds per epoch during training, about 20ms to inference) but runs _much_ slower on a mobile GPU like a GeForce GTX 1650 (30 minutes per epoch on my Surface Book 3). I assume that porting it to CPU wouldn't be difficult, but I think it would be quite slow. If you don't have a local GPU, I'd suggest using a GPU-powered cloud compute resource, such as Azure ML. (Let me know if you'd like to see a tutorial on running this in Azure ML.) 

## The problem of re-identification

Our photo libraries are filled with images of our friends and family. A machine learning model that focuses on "classification" can successfully tell us that we have lots of photos of "smiling man" or "smiling woman." Such models will not, generally, tell us which photos are of Uncle Al and Aunt Betty. _That_ task -- re-identifying individuals -- is a different problem and requires different approaches.

Beyond identifying Uncle Al and Aunt Betty in a photo library, re-identification is a common problem for wildlife biologists. Many species have photographic catalogs taken over years and decades, and many species have some distinctive features that can be used to identify individuals. 

### Some animals and how they may be re-identified

| Species | &nbsp | &nbsp |
| --- | --- | --- |
| Humpback Whales | tk | tk |
| Tigers | tk | tk | 
| Dolphins | tk | tk | 
| Manta Rays | tk | tk | 


## Create the Python environment 

I developed this tutorial on Ubuntu 20.04, CUDA 10.2, PyTorch 1.7, FastAI 2.3. I have to admit that the rigmarole of exactly recreating a GPU-enabled virtual environment is a little beyond me, but I _think_ you have to install CUDA manually, and then when you run the `conda env create` command below, I _think_ it will download properly-configured versions of the various libraries.

Prerequisites:
1. [Install conda](tk) 
1. [Install CUDA](tk)

1. Clone this repo and change into the directory.
1. Create the environment:
    ```bash
    conda env create -f conda_environment.yml
    ```
1. Activate the environment for use with:
   ```bash
   conda activate fastai
   ```
1. Confirm the environment by running:
   ```bash
   python versions.py
   ```
Which should result in something similar to:
```bash
Pytorch : 1.7.0, FastAI : 2.3.1, CUDA? : True
```
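
If you're curious what a check like this involves, here's a minimal sketch of a script along those lines. The repo's actual `versions.py` isn't reproduced in this README, so the function name `describe_environment` and the structure are my assumptions; only the output format is taken from the sample above.

```python
import importlib


def describe_environment():
    """Build a one-line summary in the same shape as the sample output above."""
    versions = {}
    for name in ("torch", "fastai"):
        try:
            module = importlib.import_module(name)
            versions[name] = getattr(module, "__version__", "unknown")
        except Exception:
            # A missing or broken install is reported rather than raised.
            versions[name] = "not installed"

    cuda_ok = False
    if versions["torch"] != "not installed":
        import torch
        cuda_ok = torch.cuda.is_available()

    return f"Pytorch : {versions['torch']}, FastAI : {versions['fastai']}, CUDA? : {cuda_ok}"


print(describe_environment())
```

If `CUDA?` comes back `False`, the libraries installed but PyTorch can't see a GPU, and that is worth fixing before you start training.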

The environment also installs Jupyter. You may need to restart your terminal session to get `jupyter` on your path. Run:

```bash
jupyter notebook TripletLossTutorial.ipynb
```

and continue from there.

# Tutorial: TripletLoss for re-identification using FastAI v2

**If you have not done so, please see [README.md](readme.md) for instructions on creating the Python environment for this tutorial.**

## Import packages

Nothing surprising here, I think. To visualize the output, we're going to use scikit-learn's TSNE implementation. I've heard this isn't the fastest TSNE, but for 10K datapoints and a limited number of dimensions, it's fine. 

In [1]:
from fastai.vision import *
from fastai.basics import *
from fastai.vision.all import *

from loss_functions.triplet_loss import TripletLoss
from fastai.vision.augment import *
from fastai.vision.learner import cnn_learner
from fastai.vision.models import resnet34
from fastai.metrics import accuracy
from fastai.data.core import DataLoaders

In [None]:
import torch
import torchvision
import torch.nn as nn

In [None]:
from scipy import spatial
import logging

In [None]:
%matplotlib inline

import matplotlib.pyplot as plt
import matplotlib.patches as patches
from sklearn.manifold import TSNE
from matplotlib import cm
from mpl_toolkits.axes_grid1 import ImageGrid

## Confirm versions and CUDA

This is just to confirm you're running GPU-accelerated. If you don't have a GPU, I think training will be very slow. (Let me know if I'm wrong!)

In [2]:
import fastai
fastai.__version__

'2.3.2'

In [3]:
torch.cuda.is_available()

True

In [5]:
device = 0
torch.cuda.set_device(device)

## Setup the training data

The training data for this tutorial is the MNIST dataset of hand-written digits 0-9. The dataset contains 70K PNG images of size 28x28 (60K for training, 10K for testing). The PNGs are 3-channel, but all the pixels are grayscale. Dear ol' MNIST.

![A few MNIST digits](./media/mnist_sample.png)

Now, you're probably used to MNIST as a dataset for _image classification_: "To which of the trained-on categories ('0'...'9') does this query image most likely belong?" But for this tutorial, we're using it to learn "glyph re-identification." Our acid test will be seeing if images of a glyph not seen during training (a star) can be re-identified with very few samples (aka "low shot").

`URLs.MNIST`, `untar_data()`, and `get_image_files()` are from the `fastai.basics` module. The following cell will download, if necessary, the MNIST dataset and unpack it. By default, the data will be stored in `~/.fastai/data`. (It's a good idea to keep track of the size of that directory when fooling around with FastAI! Some of the datasets are _big_!).
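
Here's a small sketch for keeping an eye on that directory. The helper name `dir_size_mb` is mine, not part of FastAI, and it just walks the directory with `pathlib`.

```python
from pathlib import Path


def dir_size_mb(root):
    """Sum the sizes of all regular files under `root`, in megabytes."""
    root = Path(root).expanduser()
    if not root.exists():
        return 0.0
    total_bytes = sum(p.stat().st_size for p in root.rglob("*") if p.is_file())
    return total_bytes / 1_000_000
```

Calling `dir_size_mb('~/.fastai/data')` before and after the download cell shows how much space a dataset took.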

In [8]:
mnist = untar_data(URLs.MNIST)
fnames = get_image_files(mnist)

Now, `~/.fastai/data/mnist_png/` contains `testing/` and `training/` directories, which in turn have `0/`, `1/`, `2/`, etc. subdirectories. FastAI v2's `ImageDataLoaders` class can process this typical directory structure and take care of the boilerplate of testing vs. training sets, loading and transforming the images appropriately for processing, etc.
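
To make the layout concrete, here's a sketch that reads both the split and the label straight off a path. The helper `split_and_label` and the file names are hypothetical; they only illustrate the directory convention described above.

```python
from pathlib import Path


def split_and_label(path):
    """Return (split, label) for a path shaped like <root>/<split>/<label>/<file>.png."""
    return path.parent.parent.name, path.parent.name


# Hypothetical file names, shown only to illustrate the layout.
examples = [
    Path("mnist_png/training/7/12345.png"),
    Path("mnist_png/testing/0/42.png"),
]
for path in examples:
    print(split_and_label(path))
# ('training', '7')
# ('testing', '0')
```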

In [9]:
# The proper label for an image is the name of the directory it's in. e.g., "1" is the proper label for `1/1.png`
def label_func(x): return x.parent.name

dls = ImageDataLoaders.from_path_func(mnist, fnames, label_func)

The output of a TripletLoss-based model is an embedding (a vector representing a point in a high-dimensional space). The goal is that each of the dimensions in this embedding space represents a feature in the "how to tell individuals of this class apart" solution space. The distance between any two embeddings should then be small when the images show the same individual (or, here, the same glyph) and large when they don't.

After training, the model generates embedding values for a query image.

The length of the embedding is a parameter you choose. In a real-world animal re-identification model, it might be on the order of 128. In the case of MNIST, we can get fine results with just a few output features.
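
As a toy illustration of "distance in embedding space," here's a dependency-free sketch. The three 4-dimensional vectors are made up for the example; they are not outputs of the model.

```python
import math


def euclidean(a, b):
    """Straight-line distance between two embeddings of equal length."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


# Made-up embeddings, for illustration only.
nine_a = [1.0, 0.0, 0.0, 0.0]
nine_b = [0.8, 0.6, 0.0, 0.0]
four = [0.0, 0.0, 1.0, 0.0]

print(euclidean(nine_a, nine_b) < euclidean(nine_a, four))  # True
```

A well-trained model should behave like this toy data: two images of the same glyph land closer together than images of different glyphs.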

In [1]:
# output_embedding_length = 128
output_embedding_length = 4

In [10]:
# dls.valid_dl.new(shuffle=True)
# Original is a fastai data (valid_dl etc.)

To create a TripletLoss-based embedding, we'll start with a standard Convolutional Neural Network using our recently created `dls` `ImageDataLoaders` object, the ResNet34 architecture, and the `TripletLoss` function we imported above from this repo's `loss_functions.triplet_loss` module.

In [12]:
learn = cnn_learner(dls, resnet34, metrics=accuracy, loss_func=TripletLoss(device))

In [13]:
learn.model[1]

Sequential(
  (0): AdaptiveConcatPool2d(
    (ap): AdaptiveAvgPool2d(output_size=1)
    (mp): AdaptiveMaxPool2d(output_size=1)
  )
  (1): Flatten(full=False)
  (2): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (3): Dropout(p=0.25, inplace=False)
  (4): Linear(in_features=1024, out_features=512, bias=False)
  (5): ReLU(inplace=True)
  (6): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (7): Dropout(p=0.5, inplace=False)
  (8): Linear(in_features=512, out_features=10, bias=False)
)

We want to replace the classification head of the ResNet34 architecture with an embedding output. The `L2_norm` class rescales each embedding vector to unit length (an L2 norm of 1), so every component lies in the range $ [-1.0,1.0] $.
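
Here's a dependency-free sketch of what L2 normalization does to a single vector. The helper name `l2_normalize` is mine; the tutorial itself relies on PyTorch's `F.normalize`, which likewise divides by the norm, clamped below by a small `eps`.

```python
import math


def l2_normalize(vec, eps=1e-12):
    """Divide a vector by its Euclidean length so the result has length 1."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / max(norm, eps) for x in vec]


unit = l2_normalize([3.0, -4.0, 0.0, 0.0])
print(unit)  # [0.6, -0.8, 0.0, 0.0]
```

Note that components can be negative. What's fixed is the length of the whole vector, not the sign or size of any one component.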

In [4]:
class L2_norm(nn.Module):
    def __init__(self):
        super(L2_norm, self).__init__()

    def forward(self, x):
        return F.normalize(x, p=2, dim=-1)

In [237]:
layers = learn.model[1]
learn.model[1] = nn.Sequential(layers[0], layers[1], layers[2], layers[3], nn.Linear(in_features=1024, out_features=output_embedding_length, bias=False), L2_norm()).to(device)
    

In [None]:
learn.model[1]

We now have a model configured for TripletLoss. What happens behind the scenes when this model is trained?

Basically, the training loop will grab 3 images: a `target` image, a `positive` image, and a `negative` image. The `target` and `positive` are in the same class ('1', '2', '3', etc.) while the `negative` image is in a different class. The basic goal of training is to find weights such that `target` and `positive` have embedding values that are near each other while `target` and `negative` have embedding values that are far from each other.
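
In code, the usual margin-based form of this idea looks like the sketch below. This is the textbook formulation, not necessarily what the repo's `TripletLoss` class does internally, and the `margin` value of 0.2 is my assumption.

```python
import math


def distance(a, b):
    """Euclidean distance between two embeddings."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def triplet_loss(target, positive, negative, margin=0.2):
    """Zero once the negative is at least `margin` farther from the target than the positive."""
    return max(distance(target, positive) - distance(target, negative) + margin, 0.0)


# Made-up 2-dimensional embeddings, for illustration only.
target = [1.0, 0.0]
positive = [0.8, 0.6]
negative = [0.0, 1.0]

print(triplet_loss(target, positive, negative))  # 0.0
print(triplet_loss(target, negative, positive) > 0)  # True
```

The first call is an "already solved" triplet and contributes no loss. Swapping the roles in the second call puts the wrong image closer, so the loss is positive and training would push the weights to fix it.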

If you think about that, _most_ triplets are going to end up being 'easy'; any two random '9's are likely to share several features, while a randomly selected not-9 is unlikely to be "oh, yeah, that could be a 9." So training with just-random triplets is less efficient than training with 'hard' triplets. In a hard triplet, the `positive` and `negative` are as close as possible to the dividing line in "tell glyphs apart" embedding space. In an ideal 'hard' triplet with a `target` in class '9', the `positive` is a '9' that looks a lot like a '7', and the `negative` is a '7' that could be mistaken for a '9'. Finding hard triplets and using them for training is somewhere between "valuable" and "essential" in real-world re-identification tasks. "Triplet mining" is outside the scope of this tutorial, but you can read more about it [here](tk).
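
Just to give a flavor of what mining means, here's a sketch of one common strategy, often called "batch hard": for a given target, pick the farthest same-class example and the nearest different-class example. It is not used anywhere in this tutorial, and the function name and toy data are mine.

```python
import math


def distance(a, b):
    """Euclidean distance between two embeddings."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def hardest_triplet(target_idx, embeddings, labels):
    """Return (positive_idx, negative_idx) forming the hardest triplet for one target."""
    target = embeddings[target_idx]
    label = labels[target_idx]
    positives = [i for i, lab in enumerate(labels) if lab == label and i != target_idx]
    negatives = [i for i, lab in enumerate(labels) if lab != label]
    hardest_pos = max(positives, key=lambda i: distance(target, embeddings[i]))
    hardest_neg = min(negatives, key=lambda i: distance(target, embeddings[i]))
    return hardest_pos, hardest_neg


# Made-up embeddings and labels, for illustration only.
embeddings = [[0.0, 0.0], [0.1, 0.0], [1.0, 0.0], [1.2, 0.0], [5.0, 0.0]]
labels = ["9", "9", "9", "7", "7"]
print(hardest_triplet(0, embeddings, labels))  # (2, 3)
```

Here the '9' at index 2 and the '7' at index 3 sit right next to each other, which is exactly the kind of pair the model most needs to learn to separate.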

## Training

FastAI's `lr_find()` function does a quick sweep of learning rates. Typically, a good learning rate is one where the slope of the `lr_find()` graph is steepest.
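
To show what "steepest slope" means, here's a simplified sketch over made-up sweep values. It is a stand-in for reading the plot by eye, not FastAI's own calculation, and `steepest_lr` is a hypothetical helper.

```python
import math


def steepest_lr(lrs, losses):
    """Return the learning rate where loss falls fastest per unit of log10(lr)."""
    best_lr, best_slope = None, 0.0
    for i in range(1, len(lrs)):
        step = math.log10(lrs[i]) - math.log10(lrs[i - 1])
        slope = (losses[i] - losses[i - 1]) / step
        if slope < best_slope:
            best_lr, best_slope = lrs[i], slope
    return best_lr


# Made-up sweep values, for illustration only.
lrs = [1e-5, 1e-4, 1e-3, 1e-2, 1e-1]
losses = [1.00, 0.95, 0.60, 0.50, 2.00]
print(steepest_lr(lrs, losses))  # 0.001
```

The loss drops most sharply on the way to `1e-3`, then flattens and finally blows up at `1e-1`, so `1e-3` is the pick in this toy sweep.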

In [2]:
learn.lr_find()


For tutorial purposes and because MNIST is pretty easy, we're going to use a fast learning rate and few epochs. Real-world data sets won't be so kind!

In [None]:
lr = 0.0002
epochs = 20
learn.fit_one_cycle(epochs, slice(lr))

epoch,train_loss,valid_loss,accuracy,time
0,0.746077,0.293546,0.0005,00:29
1,0.514373,0.240476,0.000429,00:29
2,0.356138,0.159343,0.000643,00:29
3,0.197945,0.068402,0.000571,00:29
4,0.100318,0.049595,0.001,00:29
5,0.058414,0.043318,0.009286,00:29
6,0.027147,0.042105,0.025429,00:29
7,0.017346,0.041576,0.021143,00:29
8,0.011594,0.040571,0.019857,00:29
9,0.011474,0.039593,0.0045,00:29


In [None]:
model_name = f'{epochs}_epochs_{output_embedding_length}_features'

In [None]:
learn.save(model_name)

## Inferencing

To get the output embedding of a `query` image, we just use `learn.predict()` on our trained `Learner`:

In [None]:
query_img = PILImage.create('Star1.png')
result = learn.predict(query_img)
query_fingerprint = result[1].numpy()
print(query_fingerprint)

### Fingerprinting the exemplars

Create the database to which we will compare `query_fingerprint`:    

In [4]:
fnames = list(Path(mnist/'testing').rglob('*.png'))
len(fnames)


In [None]:
fnames[0]

In [None]:
# Outputs a blank line for each for some reason. Use %%capture --no-std-err when calling
def fingerprint_all(fnames):
    fingerprints = {}
    for f in fnames:
        category = label_func(f)
        img = PILImage.create(f)
        result = learn.predict(img)
        fingerprint = result[1].numpy()
        fingerprints[(category,f)] = fingerprint
    return fingerprints


In [None]:
%%capture --no-std-err
# Suppresses output
# Takes about 3 minutes on a 2080
fingerprint_db = fingerprint_all(fnames)

## k-nearest neighbors

Because our embedding values are normalized, finding the k nearest neighbors is trivial:
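
The reason it's trivial: for unit-length vectors, the dot product _is_ the cosine similarity, so ranking by dot product ranks by closeness. Here's a dependency-free sketch of that fact; the helper names and vectors are mine, for illustration.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def cosine_similarity(a, b):
    return dot(a, b) / (math.sqrt(dot(a, a)) * math.sqrt(dot(b, b)))


def top_k(database, query, k):
    """Indices of the k database vectors with the largest dot product with `query`."""
    order = sorted(range(len(database)), key=lambda i: dot(database[i], query), reverse=True)
    return order[:k]


# Made-up unit-length vectors, for illustration only.
database = [[1.0, 0.0], [0.0, 1.0], [0.6, 0.8]]
query = [0.6, 0.8]

print(math.isclose(cosine_similarity(database[0], query), dot(database[0], query)))  # True
print(top_k(database, query, 2))  # [2, 1]
```

If the vectors were not normalized, you'd have to divide by both lengths for every comparison, as `cosine_similarity` does.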

In [None]:
# Find the k nearest neighbors by cosine similarity. The vectors are unit length,
# so the dot product is the cosine similarity (larger means closer).
def find_k_nearest_neighbors(vectors, vec, k):
    similarities = np.matmul(vectors, vec.T)
    return np.argsort(-similarities.flatten())[:k]


In [None]:
fps = list(fingerprint_db.values())
closest = find_k_nearest_neighbors(fps, query_fingerprint, 10)
closest

In [None]:
list(fingerprint_db.items())[closest[0]]

In [None]:
## Put it together

def best_match_mnist(fingerprint_db, fingerprint, k):
    fps = list(fingerprint_db.values())
    keys = list(fingerprint_db.keys())
    match_indices = find_k_nearest_neighbors(fps,fingerprint,k)
    for i in match_indices:
        match_fp = fps[i]
        distance = spatial.distance.cosine(match_fp, fingerprint)
        yield (keys[i], distance)
        
list(best_match_mnist(fingerprint_db, query_fingerprint, 10))

In [None]:
def fingerprint_file(path) :
    img = PILImage.create(path)
    result = learn.predict(img)
    fingerprint = result[1].numpy()
    return fingerprint

In [None]:
fp4 = fingerprint_file('4.png')
list(best_match_mnist(fingerprint_db, fp4, 5))

In [None]:
fps1 = fingerprint_file('Star1.png')
fps2 = fingerprint_file('Star2.png')
list(best_match_mnist(fingerprint_db, fps1, 1))

In [None]:
spatial.distance.cosine(fps2, fps1)

In [None]:
list(best_match_mnist(fingerprint_db, fps2, 1))

In [None]:
fingerprint_db[('*','/Star1.png')] = fps1

In [None]:
(list(m for m in best_match_mnist(fingerprint_db, fps2, 10)))

In [None]:
(list(m for m in best_match_mnist(fingerprint_db, fps2, 10000) if m[0][0] != "8"))[0:10]

In [None]:
query_img = PILImage.create("Star2.png")
query_img

In [None]:
%matplotlib inline

In [None]:
PILImage.create("Star1.png")

I think this net is overfitted to numbers. But does that make sense? Shouldn't just... features to features create the appropriate cosine distance between star and star?

Or, maybe it's that the best diagnostic is: "OK, return the top 3 returns in the top 3 categories." So you'd get "8"s and their <0.04 distance, and you'd get FPS1 at 0.04, and then you'd get the top (whatever) 
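
Here's one possible reading of that "top few matches in the top few categories" idea, as a sketch. The function `top_matches_per_category` is hypothetical, and the sample distances are made up. The input has the same shape that `best_match_mnist` yields: `((category, path), distance)`.

```python
from collections import defaultdict


def top_matches_per_category(matches, categories=3, per_category=3):
    """Keep the closest few matches per category, then the closest few categories."""
    grouped = defaultdict(list)
    for (category, path), distance in sorted(matches, key=lambda m: m[1]):
        if len(grouped[category]) < per_category:
            grouped[category].append((path, distance))
    # Rank categories by their single best (smallest) distance.
    ranked = sorted(grouped.items(), key=lambda item: item[1][0][1])
    return ranked[:categories]


# Made-up matches, for illustration only.
matches = [
    (("8", "a.png"), 0.010), (("8", "b.png"), 0.020), (("*", "/Star1.png"), 0.040),
    (("3", "c.png"), 0.050), (("8", "d.png"), 0.030), (("8", "e.png"), 0.035),
    (("5", "f.png"), 0.200),
]
for category, hits in top_matches_per_category(matches):
    print(category, hits)
```

With these numbers, the star shows up as the second-best category even though several individual '8's are closer, which is the behavior the note above is reaching for.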

In [None]:
fps = np.array(list(fingerprint_db.values()))
fps.shape

In [None]:
def train_and_fingerprint(epochs, learning_rate, files_to_fingerprint) :
    learn.fit_one_cycle(epochs,slice(learning_rate))
    fingerprints = fingerprint_all(files_to_fingerprint)
    
    return (learn, fingerprints)

In [None]:


def visualize(fingerprints_db, fingerprints_to_highlight = [], indices_to_highlight = []):
    # Use the database that was passed in, not the global `fingerprint_db`
    l = list(fingerprints_db.values())
    fps = np.array(l + list(fingerprints_to_highlight))
    tsne = TSNE(2, random_state = 42, verbose = 1)    
    tsne_proj = tsne.fit_transform(fps)
    
    cmap = cm.get_cmap('tab20')
    fig, ax = plt.subplots(figsize=(8,8))
    num_categories = 10  # MNIST digits '0' through '9'

    for i in range(num_categories):
        matching_indices = [key for key, (val,_) in enumerate(fingerprints_db.keys()) if val == str(i)]
        plt.scatter(tsne_proj[matching_indices,0], tsne_proj[matching_indices,1], c = np.array(cmap(i)).reshape(1,4), label = i)
    ax.legend(fontsize='large', markerscale=2)
    
    # Highlight
    plt.scatter(tsne_proj[indices_to_highlight,0], tsne_proj[indices_to_highlight,1], marker = '+', c = 'k', label = 'highlight') 
    ixs_to_highlight = range(len(fps) - len(fingerprints_to_highlight),len(fps))
    (xs, ys) = (tsne_proj[ixs_to_highlight,0], tsne_proj[ixs_to_highlight,1])
    plt.scatter(xs, ys, marker = '+', c = 'k', label = 'highlight') 
    (min_x, min_y, max_x, max_y) = (min(xs), min(ys), max(xs), max(ys))
    min_x = min_x - 2
    min_y = min_y - 2
    width = max_x - min_x + 2
    height = max_y - min_y + 2
    rect = patches.Rectangle((min_x, min_y), width, height, linewidth=1, edgecolor='k', facecolor='none')
    ax.add_patch(rect)
    
    plt.savefig(model_name)
    plt.show()
    return tsne

In [None]:
%%capture 
# Use %%capture to suppress fingerprint_all output, which produces a blank line for each file for some reason
(_, fps) = train_and_fingerprint(0, .0005, fnames)

In [None]:
visualize(fps, fingerprints_to_highlight = [fps1, fps2])