**Purpose of this notebook**

This notebook shows how machine learning can be used to measure the similarity between images.
In particular, we would like to distinguish several types of relationship between images:
1. exact duplicate
1. near-duplicate
1. similar
1. different

![alt text](categories_similarity_openfoodfact.jpg "Title")

**Proposal**  
Use machine learning to represent each image in a new space where distance is correlated with similarity.

**Hypothesis**  
A deep neural network (NN) is expected to learn patterns from its training dataset that help it discriminate between instances of that dataset. By using a trained neural network, it should be possible to represent each picture as an embedding that `is easier to discriminate`, or that allows us to build `a metric of similarity`.
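
As a minimal illustration of this hypothesis, the sketch below computes pairwise Euclidean distances between a few toy vectors with NumPy. The vectors and the helper name `pairwise_euclidean` are made up for the example; real embeddings would come from the model used later in the notebook.

```python
import numpy as np


def pairwise_euclidean(x: np.ndarray) -> np.ndarray:
    """Return the (n, n) matrix of Euclidean distances between rows of x."""
    # Broadcasting builds all row differences: shape (n, n, d).
    diff = x[:, None, :] - x[None, :, :]
    return np.sqrt((diff ** 2).sum(axis=-1))


# Toy "embeddings" (hypothetical values): rows 0 and 1 are nearly identical,
# row 2 is far away from both.
toy = np.array([
    [1.0, 0.0, 0.0],
    [1.0, 0.1, 0.0],
    [0.0, 0.0, 5.0],
])

toy_dist = pairwise_euclidean(toy)
print(np.round(toy_dist, 3))
```

If the hypothesis holds, pairs of near-duplicate images should behave like rows 0 and 1 (small distance), and unrelated images like row 2 (large distance).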

**Protocol**  
1. Load the images.
1. Download an already trained NN for images.
1. Use the backbone of the model to generate the image embeddings (more precisely, to transform the pixels of each image into another representation, called an embedding). We therefore work under the following hypothesis: `the Euclidean distance in the embedding space` is correlated with `similarity`.
1. As a by-product, look at the distances between images and tag some pairs as `exact_duplicate`, `near_duplicate`, `very_similar` or `different`.
1. Build a small model that determines the optimal threshold.
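
The last step could be sketched as follows, assuming a handful of image pairs have already been tagged by hand. The data, the helper name `best_threshold` and the choice of accuracy as the selection criterion are all assumptions made for this illustration, not the notebook's actual model. It only separates two classes; the same idea could be repeated for each boundary between categories.

```python
import numpy as np


def best_threshold(distances: np.ndarray, is_duplicate: np.ndarray) -> tuple:
    """Pick the distance cut-off that maximises accuracy on tagged pairs.

    A pair is predicted "duplicate" when its distance is <= threshold.
    """
    best_t, best_acc = 0.0, -1.0
    # Every observed distance is a candidate cut-off.
    for t in np.unique(distances):
        predicted = distances <= t
        acc = float((predicted == is_duplicate).mean())
        if acc > best_acc:
            best_t, best_acc = float(t), acc
    return best_t, best_acc


# Hypothetical tagged pairs: distance between the two images, and the manual tag.
pair_distances = np.array([0.05, 0.10, 0.30, 0.90, 1.20, 1.50])
pair_is_duplicate = np.array([True, True, True, False, False, False])

threshold_hat, accuracy = best_threshold(pair_distances, pair_is_duplicate)
print(threshold_hat, accuracy)
```

With real tags the classes will usually overlap, so a criterion such as precision at a fixed recall may be more appropriate than plain accuracy.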

\*: if you do not understand something, be curious :)

# Protocol

## Load images

In [None]:
from pathlib import Path
from datasets import Dataset, Image, load_dataset

In [None]:
# to clean data if necessary
# import os
# path = Path('../data/images').resolve()
# for dir in os.listdir(path):
#     for file in os.listdir(path / dir):
#         if 'front.' in file:
#             os.remove(path / f'{dir}/{file}')
#         if 'ingredients.' in file:
#             os.remove(path / f'{dir}/{file}')
#         if 'nutrition.' in file:
#             os.remove(path / f'{dir}/{file}')
#         if 'packaging' in file:
#             os.remove(path / f'{dir}/{file}')
#         if 'other' in file:
#             os.remove(path / f'{dir}/{file}')

In [None]:
images = load_dataset("imagefolder", data_dir="../../data/images")
images = images['train'].cast_column('image', Image(decode=True)) # all images are in train

## Load models and produce embeddings

In [None]:
from transformers import AutoFeatureExtractor, AutoModel
import torch
import torchvision.transforms as T

In [None]:
model_ckpt = "nateraw/vit-base-beans"
extractor = AutoFeatureExtractor.from_pretrained(model_ckpt)
model = AutoModel.from_pretrained(model_ckpt)

In [None]:
candidate_subset = images.filter(lambda x: x['label'] == 1)
# candidate_subset

In [None]:
# Data transformation chain.

class TransformationChain():
    
    def __init__(self):
        self.transformation_chain = T.Compose(
            [
                # Resize so that the shorter side is 256 px (for a 224 px model input),
                # then take a center crop at the model's input size.
                T.Resize(int((256 / 224) * extractor.size["height"])),
                T.CenterCrop(extractor.size["height"]),
                T.ToTensor(),
            ]
        )
        
    def __call__(self, image):
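        tensor = self.transformation_chain(image)
        # Grayscale image: repeat the single channel to get 3 channels.
        if tensor.shape[0] == 1:
            tensor = tensor.expand(3, tensor.shape[1], tensor.shape[2])
        # More than 3 channels (e.g. RGBA): keep only the first three.
        if tensor.shape[0] > 3:
            tensor = tensor[:3]
        # Normalize with the statistics expected by the pretrained model.
        tensor = T.Normalize(mean=extractor.image_mean, std=extractor.image_std)(tensor)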
        return tensor

transformation_chain = TransformationChain()

def extract_embeddings(model: torch.nn.Module):
    """Utility to compute embeddings."""
    device = model.device

    def pp(batch):
        images = batch["image"]
        image_batch_transformed = torch.stack(
            [transformation_chain(image) for image in images]
        )
        new_batch = {"pixel_values": image_batch_transformed.to(device)}
        with torch.no_grad():
            # Keep the hidden state of the first token ([CLS]) as the image embedding.
            embeddings = model(**new_batch).last_hidden_state[:, 0].cpu()
        return {"embeddings": embeddings}
    return pp

# to check 
# for i, image in enumerate(images):
#     try:
#         transformation_chain(image['image'])
#     except:
#         print('error:', i)

In [None]:
# Here, we map the embedding extraction utility over all images.
batch_size = 24
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f'{device=}')
extract_fn = extract_embeddings(model.to(device))

In [None]:
embeddings = images.map(extract_fn, batched=True, batch_size=batch_size)
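
Raw embedding norms can vary between images, which makes a single Euclidean threshold harder to interpret. One option, not used in this notebook, is to L2-normalize the embeddings first; for unit vectors the squared Euclidean distance equals `2 - 2 * cosine similarity`. The sketch below checks this on random stand-in data; `l2_normalize` and `fake_embeddings` are hypothetical names, and the width of 768 is only assumed to match a ViT-base hidden size.

```python
import numpy as np


def l2_normalize(x: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Scale each row of x to unit Euclidean length."""
    norms = np.linalg.norm(x, axis=1, keepdims=True)
    return x / np.maximum(norms, eps)


# Stand-in for np.array(embeddings["embeddings"]) (random values, not real data).
rng = np.random.default_rng(0)
fake_embeddings = rng.normal(size=(4, 768))

unit = l2_normalize(fake_embeddings)

# For unit vectors: ||a - b||^2 = 2 - 2 * cos(a, b).
cosine = unit @ unit.T
sq_dist = ((unit[:, None, :] - unit[None, :, :]) ** 2).sum(axis=-1)
print(np.allclose(sq_dist, 2 - 2 * cosine))
```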

In [None]:
from pacmap import PaCMAP
from trimap import TRIMAP
import numpy as np

## Distance / similarity to determine threshold

In [None]:
from sklearn.metrics.pairwise import euclidean_distances


projector = TRIMAP() # PaCMAP()
emb_2d = projector.fit_transform(
    np.array(embeddings.filter(lambda x: x['label'] == 1)['embeddings'])
    )

import matplotlib.pyplot as plt

plt.scatter(*emb_2d.T, s=7)
plt.show()

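threshold = 0.2
# Note: distances are computed in the 2-D projection, not in the original embedding space.
dist = euclidean_distances(emb_2d)

# Keep pairs closer than the threshold, excluding zero distances (self-pairs).
mask = (dist > 0) & (dist < threshold)
np.where(mask)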
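
Two side effects of the mask above are worth noting: every pair appears twice, as (i, j) and (j, i), and the `dist > 0` test also drops pairs at distance exactly 0, which is where exact duplicates would be expected. A possible alternative, sketched here on a hypothetical distance matrix with a made-up helper `candidate_pairs`, is to look only at the upper triangle of the matrix.

```python
import numpy as np


def candidate_pairs(dist: np.ndarray, threshold: float) -> list:
    """List each unordered pair (i, j), i < j, with dist[i, j] < threshold."""
    # k=1 skips the diagonal, so self-pairs are excluded by construction
    # and pairs at distance exactly 0 (exact duplicates) are kept.
    rows, cols = np.triu_indices(dist.shape[0], k=1)
    keep = dist[rows, cols] < threshold
    return [(int(i), int(j)) for i, j in zip(rows[keep], cols[keep])]


# Hypothetical symmetric distance matrix for four images.
toy_dist = np.array([
    [0.0, 0.0, 0.15, 3.0],
    [0.0, 0.0, 0.15, 3.0],
    [0.15, 0.15, 0.0, 2.9],
    [3.0, 3.0, 2.9, 0.0],
])

pairs = candidate_pairs(toy_dist, threshold=0.2)
print(pairs)
```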

In [None]:
# Indices in `mask` refer to the label == 1 subset used to build emb_2d.
subset = embeddings.filter(lambda x: x['label'] == 1)
rows, cols = np.where(mask == 1)
# Show the pair in the middle of the candidate list as a spot check.
mid = len(rows) // 2
i, j = int(rows[mid]), int(cols[mid])
print(i, j)
subset[i]['image'].show()
subset[j]['image'].show()

## Evaluate