In [1]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
print(gpu_info)

Thu Jul  1 16:07:54 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.80       Driver Version: 460.80       CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  TITAN V             Off  | 00000000:3B:00.0 Off |                  N/A |
| 28%   39C    P2    38W / 250W |   2979MiB / 12066MiB |     44%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  TITAN V             Off  | 00000000:5E:00.0 Off |                  N/A |
| 28%   41C    P2    40W / 250W |   3077MiB / 12066MiB |     17%      Default |
|       

In [2]:
import os

# Work from the `src` directory of the local pytorch-metric-learning checkout so
# the `pytorch_metric_learning` imports and the relative dataset paths used
# later resolve. The location can be overridden with the PML_SRC_DIR environment
# variable; the default is the original author's path.
os.chdir(os.environ.get('PML_SRC_DIR',
                        '/home/l/liny/ruofan/pytorch-metric-learning/src'))

# Expose GPUs 0 and 1 to this process. This is set before `torch` is imported
# in the next cell. Written as a plain comma-separated list (no embedded space),
# which is the documented format for this variable.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

In [3]:
from pytorch_metric_learning import losses, miners, distances, reducers, testers
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator
from torchvision import datasets
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np

from pytorch_metric_learning.models import bninception
from pytorch_metric_learning import samplers
from pytorch_metric_learning.datasets.Cub2011 import *

In [4]:
device = torch.device("cuda")

# ---- experiment settings -------------------------------------------------
batch_size, num_epochs = 64, 200
num_classes, hidden_dim = 200, 512
centers_per_class = 5
exp_name = 'Cub2011'
result_dir = './log'
os.makedirs(result_dir, exist_ok=True)  # checkpoints are written here

# ---- image preprocessing -------------------------------------------------
# Both pipelines resize to 224x224, convert to a tensor and normalise each
# channel with mean 0.5 / std 0.5; only the training pipeline adds a random
# horizontal flip (applied before the resize).
_image_size = (224, 224)
_channel_mean = (0.5, 0.5, 0.5)
_channel_std = (0.5, 0.5, 0.5)

train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.Resize(_image_size),
    transforms.ToTensor(),
    transforms.Normalize(_channel_mean, _channel_std),
])

test_transform = transforms.Compose([
    transforms.Resize(_image_size),
    transforms.ToTensor(),
    transforms.Normalize(_channel_mean, _channel_std),
])

In [5]:
# Training split, using the augmenting transform.
train_data = Cub2011(
    root='pytorch_metric_learning/datasets/Cub2011/',
    transform=train_transform,
    train=True,
)

In [6]:
# Held-out split, using the deterministic (no flip) transform.
test_data = Cub2011(
    root='pytorch_metric_learning/datasets/Cub2011/',
    transform=test_transform,
    train=False,
)

In [7]:
# Sampler over the training labels with m=5 samples per class; one pass of the
# sampler yields 100000 indices before a new iteration starts.
sampler = samplers.MPerClassSampler(
    train_data.labels, m=5, length_before_new_iter=100000
)

In [8]:
# Training batches are ordered by the sampler; test batches keep dataset order.
train_loader = torch.utils.data.DataLoader(
    dataset=train_data,
    batch_size=batch_size,
    sampler=sampler,
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_data,
    batch_size=batch_size,
    shuffle=False,
)

In [10]:
# Sanity check: show the labels of the first training batch only.
first_batch = next(iter(train_loader), None)
if first_batch is not None:
    data, target = first_batch
    print(target)

tensor([167, 167, 167, 167, 167, 133, 133, 133, 133, 133,  71,  71,  71,  71,
         71, 144, 144, 144, 144, 144, 194, 194, 194, 194, 194,  60,  60,  60,
         60,  60,  28,  28,  28,  28,  28, 151, 151, 151, 151, 151, 170, 170,
        170, 170, 170,  20,  20,  20,  20,  20, 123, 123, 123, 123, 123, 115,
        115, 115, 115, 115, 132, 132, 132, 132])


In [11]:
# BN-Inception backbone producing `hidden_dim`-dimensional embeddings, built
# with pretrained=None, then wrapped in DataParallel and moved to the device.
backbone = bninception(dim=hidden_dim, pretrained=None)
model = nn.DataParallel(backbone).to(device)

In [12]:
### MNIST code originally from https://github.com/pytorch/examples/blob/master/mnist/main.py ### 
def train(model, loss_func, mining_func, device, train_loader, optimizer, epoch):
    """Run one pass over ``train_loader``, updating the model after every batch.

    Args:
        model: callable network mapping a batch of inputs to embeddings; it is
            switched to training mode before the pass starts.
        loss_func: callable ``(embeddings, labels, indices_tuple) -> loss``.
        mining_func: callable ``(embeddings, labels) -> indices_tuple``; its
            ``num_triplets`` attribute is read for the progress message.
        device: device the inputs and labels are moved to.
        train_loader: iterable yielding ``(data, labels)`` batches.
        optimizer: optimizer whose ``zero_grad``/``step`` drive the update.
        epoch: epoch number, used only in the progress message.
    """
    log_template = "Epoch {} Iteration {}: Loss = {}, Number of mined triplets = {}"
    model.train()
    for step, batch in enumerate(train_loader):
        optimizer.zero_grad()  # clear gradients left over from the previous step
        inputs = batch[0].to(device)
        targets = batch[1].to(device)

        embeddings = model(inputs)
        mined = mining_func(embeddings, targets)
        loss = loss_func(embeddings, targets, mined)

        loss.backward()
        optimizer.step()

        # Progress is reported on every 100th batch, starting with batch 0.
        if step % 100 != 0:
            continue
        print(log_template.format(epoch, step, loss, mining_func.num_triplets))

### convenient function from pytorch-metric-learning ###
def get_all_embeddings(dataset, model):
    """Return whatever a fresh ``testers.BaseTester`` computes for ``dataset``.

    A new ``BaseTester`` is created on every call and its
    ``get_all_embeddings(dataset, model)`` result is returned unchanged; the
    caller in this notebook unpacks it as ``(embeddings, labels)``.
    """
    return testers.BaseTester().get_all_embeddings(dataset, model)

### compute accuracy using AccuracyCalculator from pytorch-metric-learning ###
def test(train_set, test_set, model, accuracy_calculator):
    """Embed both datasets and report precision@1 of the test set.

    Args:
        train_set: dataset whose embeddings are passed as the second argument
            to ``get_accuracy`` (together with its labels as the fourth).
        test_set: dataset whose embeddings are passed as the first argument
            (together with its labels as the third).
        model: network handed to ``get_all_embeddings``.
        accuracy_calculator: object exposing ``get_accuracy``; the result must
            contain a ``"precision_at_1"`` entry.

    Returns:
        The ``"precision_at_1"`` value from the accuracy result.
    """
    reference_emb, reference_lbl = get_all_embeddings(train_set, model)
    query_emb, query_lbl = get_all_embeddings(test_set, model)
    print("Computing accuracy")
    # NOTE(review): the trailing False presumably tells the calculator that the
    # two embedding sets come from different sources; confirm against the
    # installed AccuracyCalculator signature.
    precision_at_1 = accuracy_calculator.get_accuracy(
        query_emb, reference_emb, query_lbl, reference_lbl, False
    )["precision_at_1"]
    print("Test set accuracy (Precision@1) = {}".format(precision_at_1))
    return precision_at_1

In [13]:
### pytorch-metric-learning components: distance, loss, optimizer, miner, metric ###
distance = distances.CosineSimilarity()
reducer = reducers.MeanReducer()

soft_triple_kwargs = dict(
    num_classes=num_classes,
    embedding_size=hidden_dim,
    centers_per_class=centers_per_class,
    la=20,
    gamma=0.1,
    margin=0.01,
)
loss_func = losses.SoftTripleLoss(**soft_triple_kwargs).to(device)

# One Adam optimizer with two parameter groups: the network is updated with
# lr=1e-4, the parameters owned by the loss with lr=1e-2.
param_groups = [
    {"params": model.parameters(), "lr": 1e-4},
    {"params": loss_func.parameters(), "lr": 1e-2},
]
loss_optimizer = optim.Adam(param_groups)

mining_func = miners.TripletMarginMiner(
    margin=0.2,
    distance=distance,
    type_of_triplets="semi-hard",
)
accuracy_calculator = AccuracyCalculator(include=("precision_at_1",), k=1)

In [None]:
### main loop: train one epoch, evaluate precision@1, checkpoint periodically ###

for epoch in range(1, num_epochs + 1):
    train(model, loss_func, mining_func, device, train_loader, loss_optimizer, epoch)
    knn_acc = test(train_data, test_data, model, accuracy_calculator)

    # Checkpoints are written after the first epoch and after every 50th one.
    if epoch != 1 and epoch % 50 != 0:
        continue
    checkpoint_name = '{}_epoch{}_knnAcc{:.4f}.pt'.format(exp_name, epoch, knn_acc)
    torch.save(model.state_dict(), os.path.join(result_dir, checkpoint_name))


Epoch 1 Iteration 0: Loss = 5.0204339027404785, Number of mined triplets = 13557
Epoch 1 Iteration 100: Loss = 4.598340034484863, Number of mined triplets = 10376
Epoch 1 Iteration 200: Loss = 4.502689361572266, Number of mined triplets = 10363
Epoch 1 Iteration 300: Loss = 3.1398725509643555, Number of mined triplets = 9010
