In [1]:
!pip install pytorch-metric-learning
!pip install faiss-gpu
!pip install umap-learn

Collecting pytorch-metric-learning
  Downloading pytorch_metric_learning-2.5.0-py3-none-any.whl (119 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m119.1/119.1 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch>=1.6.0->pytorch-metric-learning)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch>=1.6.0->pytorch-metric-learning)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch>=1.6.0->pytorch-metric-learning)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch>=1.6.0->pytorch-metric-learning)
  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch>=1.6.0->pytorch-met

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import torchvision
import umap.umap_ as umap
from cycler import cycler
import logging

### MNIST code originally from https://github.com/pytorch/examples/blob/master/mnist/main.py ###
from torchvision import datasets, transforms

from pytorch_metric_learning import distances, losses, miners, reducers, testers, trainers
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator

import pytorch_metric_learning.utils.logging_presets as logging_presets

from pytorch_metric_learning.samplers import MPerClassSampler

In [3]:

### MNIST code originally from https://github.com/pytorch/examples/blob/master/mnist/main.py ###
class Net(nn.Module):
    """Small convolutional network that maps images to 128-dimensional embeddings.

    Two 3x3 convolutions (each followed by ReLU), a 2x2 max-pool, dropout,
    then a single linear projection. The linear layer expects 9216 flattened
    features, which is what a 1x28x28 input produces (64 channels * 12 * 12).
    """

    def __init__(self):
        super().__init__()
        # Parameterised layers are created in conv1 -> conv2 -> fc1 order.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
        self.dropout1 = nn.Dropout2d(p=0.25)
        # Not applied in forward(); kept so the module's attributes stay the same.
        self.dropout2 = nn.Dropout2d(p=0.5)
        self.fc1 = nn.Linear(in_features=9216, out_features=128)

    def forward(self, x):
        """Return one embedding row per sample in the batch ``x``."""
        features = F.relu(self.conv1(x))
        features = F.relu(self.conv2(features))
        pooled = F.max_pool2d(features, kernel_size=2)
        flat = torch.flatten(self.dropout1(pooled), start_dim=1)
        return self.fc1(flat)

In [4]:


### MNIST code originally from https://github.com/pytorch/examples/blob/master/mnist/main.py ###
def train(model, loss_func, mining_func, device, train_loader, optimizer, epoch):
    """Run one pass over ``train_loader``, updating ``model`` after every batch.

    For each batch the data is embedded, ``mining_func`` selects tuples from
    the embeddings and labels, ``loss_func`` is evaluated on them, and the
    optimizer takes one step. A progress line is printed on every 20th batch.
    ``epoch`` is only used in that progress line. Returns nothing.
    """
    log_every = 20
    message = "Epoch {} Iteration {}: Loss = {}, Number of mined triplets = {}"

    model.train()
    for step, batch in enumerate(train_loader):
        inputs, targets = batch
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        embedded = model(inputs)
        mined = mining_func(embedded, targets)
        batch_loss = loss_func(embedded, targets, mined)
        batch_loss.backward()
        optimizer.step()

        if step % log_every:
            continue
        print(message.format(epoch, step, batch_loss, mining_func.num_triplets))

In [5]:


### convenient function from pytorch-metric-learning ###
def get_all_embeddings(dataset, model):
    """Embed every item of ``dataset`` with ``model`` using a throwaway ``BaseTester``.

    Returns whatever ``BaseTester.get_all_embeddings`` returns; callers in
    this notebook unpack it as ``(embeddings, labels)``.
    """
    return testers.BaseTester().get_all_embeddings(dataset, model)


In [6]:


### compute accuracy using AccuracyCalculator from pytorch-metric-learning ###
def test(train_set, test_set, model, accuracy_calculator):
    """Embed both splits and print the test-set Precision@1.

    Test embeddings are passed as the queries and train embeddings as the
    reference set. Nothing is returned; the result is only printed.
    """
    reference_embeddings, reference_labels = get_all_embeddings(train_set, model)
    query_embeddings, query_labels = get_all_embeddings(test_set, model)
    # squeeze(1) drops dimension 1 of the label tensors when it has size 1.
    reference_labels = reference_labels.squeeze(1)
    query_labels = query_labels.squeeze(1)

    print("Computing accuracy")
    # The trailing positional False is presumably the "reference includes
    # query" flag -- confirm against the installed pytorch-metric-learning.
    results = accuracy_calculator.get_accuracy(
        query_embeddings, query_labels, reference_embeddings, reference_labels, False
    )
    precision = results["precision_at_1"]
    print("Test set accuracy (Precision@1) = {}".format(precision))

In [8]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing at the first `.to(device)`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Convert images to tensors, then normalise with mean 0.1307 / std 0.3081.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)

batch_size = 256

# dataset1 = training split (downloaded if missing), dataset2 = test split.
dataset1 = datasets.MNIST(".", train=True, download=True, transform=transform)
dataset2 = datasets.MNIST(".", train=False, transform=transform)
train_sampler = MPerClassSampler(dataset1.targets, m=4)  # sampler for 4 datapoints per class

# A custom sampler decides the order, so shuffle must stay False
# (DataLoader rejects shuffle=True together with a sampler).
train_loader = torch.utils.data.DataLoader(
    dataset1, batch_size=batch_size, shuffle=False, sampler=train_sampler
)
test_loader = torch.utils.data.DataLoader(dataset2, batch_size=batch_size)

In [9]:
# Fresh embedding network and optimizer for the manual training loop below.
model = Net().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.01)
num_epochs = 2


### pytorch-metric-learning stuff ###
# The same distance object is shared by the loss and the miner, and both use
# a margin of 0.2.
distance = distances.CosineSimilarity()
reducer = reducers.ThresholdReducer(low=0)
loss_func = losses.TripletMarginLoss(margin=0.2, distance=distance, reducer=reducer)
mining_func = miners.TripletMarginMiner(
    margin=0.2, distance=distance, type_of_triplets="semihard"
)
# Only precision@1 is computed, which is the key test() reads from the result.
accuracy_calculator = AccuracyCalculator(include=("precision_at_1",), k=1)
### pytorch-metric-learning stuff ###


# After every training epoch, report test-set Precision@1 using the training
# set (dataset1) as the reference set.
for epoch in range(1, num_epochs + 1):
    train(model, loss_func, mining_func, device, train_loader, optimizer, epoch)
    test(dataset1, dataset2, model, accuracy_calculator)

Epoch 1 Iteration 0: Loss = 0.10672411322593689, Number of mined triplets = 831023
Epoch 1 Iteration 20: Loss = 0.09169085323810577, Number of mined triplets = 91612
Epoch 1 Iteration 40: Loss = 0.08725952357053757, Number of mined triplets = 62340
Epoch 1 Iteration 60: Loss = 0.08429042994976044, Number of mined triplets = 45409
Epoch 1 Iteration 80: Loss = 0.08533218502998352, Number of mined triplets = 48165
Epoch 1 Iteration 100: Loss = 0.0827733650803566, Number of mined triplets = 35447
Epoch 1 Iteration 120: Loss = 0.08189624547958374, Number of mined triplets = 30935
Epoch 1 Iteration 140: Loss = 0.08429250121116638, Number of mined triplets = 29141
Epoch 1 Iteration 160: Loss = 0.0801670178771019, Number of mined triplets = 23155
Epoch 1 Iteration 180: Loss = 0.08094586431980133, Number of mined triplets = 18762
Epoch 1 Iteration 200: Loss = 0.08416017889976501, Number of mined triplets = 19353
Epoch 1 Iteration 220: Loss = 0.08524766564369202, Number of mined triplets = 22677

100%|██████████| 1875/1875 [00:16<00:00, 114.66it/s]
100%|██████████| 313/313 [00:02<00:00, 122.88it/s]


Computing accuracy
Test set accuracy (Precision@1) = 0.9828
Epoch 2 Iteration 0: Loss = 0.08110182732343674, Number of mined triplets = 19178
Epoch 2 Iteration 20: Loss = 0.08309600502252579, Number of mined triplets = 17785
Epoch 2 Iteration 40: Loss = 0.08153527975082397, Number of mined triplets = 20730
Epoch 2 Iteration 60: Loss = 0.07751911878585815, Number of mined triplets = 16077
Epoch 2 Iteration 80: Loss = 0.08110570162534714, Number of mined triplets = 22225
Epoch 2 Iteration 100: Loss = 0.08376365154981613, Number of mined triplets = 20176
Epoch 2 Iteration 120: Loss = 0.08577261865139008, Number of mined triplets = 26528
Epoch 2 Iteration 140: Loss = 0.08315284550189972, Number of mined triplets = 18558
Epoch 2 Iteration 160: Loss = 0.08728939294815063, Number of mined triplets = 22372
Epoch 2 Iteration 180: Loss = 0.08220931142568588, Number of mined triplets = 15424
Epoch 2 Iteration 200: Loss = 0.08564545214176178, Number of mined triplets = 19011
Epoch 2 Iteration 220:

100%|██████████| 1875/1875 [00:16<00:00, 114.63it/s]
100%|██████████| 313/313 [00:04<00:00, 78.02it/s]


Computing accuracy
Test set accuracy (Precision@1) = 0.9846


### inference

In [None]:
from pytorch_metric_learning.utils import common_functions as c_f
from pytorch_metric_learning.utils.inference import InferenceModel, MatchFinder

In [None]:
def print_decision(is_match):
    """Print "Same class" when ``is_match`` is truthy, otherwise "Different class"."""
    verdict = "Same class" if is_match else "Different class"
    print(verdict)


# These must match the Normalize((0.1307,), (0.3081,)) applied when the
# datasets were built. The previous three-channel RGB values belonged to a
# different dataset and displayed the images with wrong intensities.
# Single-element lists still broadcast over the 3-channel grids that
# torchvision.utils.make_grid produces from 1-channel images.
mean = [0.1307]
std = [0.3081]

# Normalize computes (x - mean) / std, so using -mean/std and 1/std inverts it.
inv_normalize = transforms.Normalize(
    mean=[-m / s for m, s in zip(mean, std)], std=[1 / s for s in std]
)


def imshow(img, figsize=(8, 4)):
    """Undo the dataset normalisation and display an image tensor.

    Args:
        img: CPU tensor laid out channels-first (C, H, W), e.g. the output
            of ``torchvision.utils.make_grid``.
        figsize: matplotlib figure size in inches.
    """
    img = inv_normalize(img)
    # Clip tiny floating-point overshoots so matplotlib gets values in [0, 1].
    npimg = np.clip(img.numpy(), 0.0, 1.0)
    plt.figure(figsize=figsize)
    # matplotlib expects channels-last (H, W, C).
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

In [None]:
# Two embeddings count as a match when their cosine similarity passes 0.7.
match_finder = MatchFinder(distance=distances.CosineSimilarity(), threshold=0.7)
inference_model = InferenceModel(model, match_finder=match_finder)

# Map each label to the dataset indices that carry it. This was previously
# never defined, so the lookup below raised NameError.
labels_to_indices = c_f.get_labels_to_indices(dataset1.targets)

# digits 1 and 6
classA, classB = labels_to_indices[1], labels_to_indices[6]

In [None]:
# Fit the inference model's nearest-neighbour index on the training set;
# presumably required before the get_nearest_neighbors calls below -- confirm.
inference_model.train_knn(dataset1)

In [None]:
# For the first image of each class, show it and its 10 nearest neighbours
# from dataset1.
for img_type in [classA, classB]:
    img = dataset1[img_type[0]][0].unsqueeze(0)
    print("query image")
    imshow(torchvision.utils.make_grid(img))
    # Named `nn_distances`, not `distances`: the old name overwrote the imported
    # `pytorch_metric_learning.distances` module and broke any re-run of cells
    # that call `distances.CosineSimilarity()`.
    nn_distances, indices = inference_model.get_nearest_neighbors(img, k=10)
    nearest_imgs = [dataset1[i][0] for i in indices.cpu()[0]]
    print("nearest images")
    imshow(torchvision.utils.make_grid(nearest_imgs))

In [None]:
# Compare two different images drawn from classA; a correct match finder
# should report "Same class".
(x, _), (y, _) = dataset1[classA[0]], dataset1[classA[1]]
imshow(torchvision.utils.make_grid(torch.stack([x, y], dim=0)))
# unsqueeze(0) adds the batch dimension to each single image.
decision = inference_model.is_match(x.unsqueeze(0), y.unsqueeze(0))
print_decision(decision)

In [None]:
# Compare one image from classA with one from classB; a correct match finder
# should report "Different class".
(x, _), (y, _) = dataset1[classA[0]], dataset1[classB[0]]
imshow(torchvision.utils.make_grid(torch.stack([x, y], dim=0)))
# unsqueeze(0) adds the batch dimension to each single image.
decision = inference_model.is_match(x.unsqueeze(0), y.unsqueeze(0))
print_decision(decision)

In [None]:
# Build 20 pairs (x[k], y[k]). Even slots hold classA images and odd slots
# hold classB images in both x and y, so every pair shares a class.
x = torch.zeros(20, 1, 28, 28)
y = torch.zeros(20, 1, 28, 28)
for i in range(0, 20, 2):
    x[i] = dataset1[classA[i]][0]
    x[i + 1] = dataset1[classB[i]][0]
    # Offset by 20 so y uses different images than x.
    y[i] = dataset1[classA[i + 20]][0]
    y[i + 1] = dataset1[classB[i + 20]][0]
# One row of 20 x images above one row of the 20 y images.
imshow(torchvision.utils.make_grid(torch.cat((x, y), dim=0), nrow=20), figsize=(30, 3))
decision = inference_model.is_match(x, y)
for d in decision:
    print_decision(d)
# Every pair is same-class, so the fraction judged a match is the accuracy.
print("accuracy = {}".format(np.sum(decision) / len(x)))

In [None]:
# Train with pytorch-metric-learning's built-in trainer; the tester below also plots a UMAP projection of the embeddings.

In [None]:
# Dictionaries in the form the pytorch-metric-learning trainer expects.
# The model and optimizer from the manual loop above are reused.
models = {"trunk": model}
optimizers = {
    "trunk_optimizer": optimizer
}
loss_funcs = {"metric_loss": loss_func}
# The miner defined earlier is called `mining_func`; `miner` was never defined
# and raised NameError.
mining_funcs = {"tuple_miner": mining_func}

# Splits handed to the end-of-epoch hook.
dataset_dict = {"val": dataset2}


In [None]:
# Build the record keeper from the logging presets (given the names
# "example_logs" and "example_tensorboard") and the hook container around it.
record_keeper, _, _ = logging_presets.get_record_keeper(
    "example_logs", "example_tensorboard"
)
hooks = logging_presets.get_hook_container(record_keeper)
# Passed to the end-of-epoch hook further down.
model_folder = "example_saved_models"


def visualizer_hook(umapper, umap_embeddings, labels, split_name, keyname, *args):
    """Scatter-plot 2-D embeddings with one colour per distinct label.

    ``umapper`` and any extra positional arguments are accepted but not used.
    ``split_name`` and ``keyname`` only appear in the log message.
    """
    logging.info(
        "UMAP plot for the {} split and label set {}".format(split_name, keyname)
    )
    unique_labels = np.unique(labels)
    # Evenly spaced colours from the lower 90% of the nipy_spectral colormap.
    palette = [
        plt.cm.nipy_spectral(frac) for frac in np.linspace(0, 0.9, len(unique_labels))
    ]
    plt.figure(figsize=(20, 15))
    plt.gca().set_prop_cycle(cycler("color", palette))
    for label in unique_labels:
        mask = labels == label
        plt.plot(umap_embeddings[mask, 0], umap_embeddings[mask, 1], ".", markersize=1)
    plt.show()


# Create the tester: it is given the hooks' end-of-testing callback, a UMAP
# visualizer, and visualizer_hook to plot the resulting projection.
tester = testers.GlobalEmbeddingSpaceTester(
    end_of_testing_hook=hooks.end_of_testing_hook,
    visualizer=umap.UMAP(),
    visualizer_hook=visualizer_hook,
    dataloader_num_workers=2,
    accuracy_calculator=AccuracyCalculator(k="max_bin_count"),
)

# End-of-epoch callback built from the tester, the "val" split in
# dataset_dict, and model_folder; it is passed to the trainer.
end_of_epoch_hook = hooks.end_of_epoch_hook(
    tester, dataset_dict, model_folder#, test_interval=1, patience=1
)

In [None]:
# Assemble the trainer from the dictionaries and hooks defined above. It
# reuses the same training set, sampler and batch size as the manual loop.
trainer = trainers.MetricLossOnly(
    models,
    optimizers,
    batch_size,
    loss_funcs,
    dataset1,
    mining_funcs=mining_funcs,
    sampler=train_sampler,
    dataloader_num_workers=2,
    end_of_iteration_hook=hooks.end_of_iteration_hook,
    end_of_epoch_hook=end_of_epoch_hook,
)

In [None]:
# Train for the same number of epochs as the manual loop (num_epochs).
trainer.train(num_epochs=num_epochs)