In [1]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
print(gpu_info)

Thu Jul  1 10:23:17 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.80       Driver Version: 460.80       CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  TITAN V             Off  | 00000000:3B:00.0 Off |                  N/A |
| 28%   41C    P2    97W / 250W |   8730MiB / 12066MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  TITAN V             Off  | 00000000:5E:00.0 Off |                  N/A |
| 28%   43C    P2    91W / 250W |   5428MiB / 12066MiB |      0%      Default |
|       

In [2]:
import os
os.chdir('/home/l/liny/ruofan/pytorch-metric-learning/src')
os.environ["CUDA_VISIBLE_DEVICES"]="0, 1"

In [3]:
from pytorch_metric_learning import losses, miners, distances, reducers, testers
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator
from torchvision import datasets
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np

from pytorch_metric_learning.models import bninception
from pytorch_metric_learning import samplers

In [4]:
device = torch.device("cuda")

train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

test_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

batch_size = 64
num_epochs = 200
result_dir = './log'
exp_name = 'logo2k'
os.makedirs(result_dir, exist_ok=True)

In [5]:
import pickle
from PIL import Image, ImageOps
from typing import List, Union, Callable

class GetLoader(torch.utils.data.Dataset):
    '''Define customized dataset
    Args:
        data_root:
            Path to directory holding the images to load.
        data_list:
            Path to txt file which map images to labels.
        label_dict:
            Dict which converts label in plain text to label in int.
        transform:
            Transformations to apply.
        grayscale:
            Grayscale model/RGB model, default is RGB
    
    '''
    def __init__(self, data_root, data_list, label_dict, transform=None, grayscale=False):
        
        self.transform = transform
        self.data_root = data_root
        self.grayscale = grayscale
        data_list = [x.strip('\n') for x in open(data_list).readlines()]

        with open(label_dict, 'rb') as handle:
            self.label_dict = pickle.load(handle)

        self.classes = list(self.label_dict.keys())

        self.n_data = len(data_list)

        self.img_paths = []
        self.labels = []
        self.targets = []

        for data in data_list:
            image_path = data
            label = image_path.split('/')[0]
            self.img_paths.append(image_path)
            self.labels.append(label)
            self.targets.append(self.label_dict[label])

    def __getitem__(self, index):

        img_path, label= self.img_paths[index], self.labels[index]
        img_path_full = os.path.join(self.data_root, img_path)
        if self.grayscale:
            img = Image.open(img_path_full).convert('L').convert('RGB')
        else:
            img = Image.open(img_path_full).convert('RGB')

        img = ImageOps.expand(img, (
            (max(img.size) - img.size[0]) // 2, (max(img.size) - img.size[1]) // 2,
            (max(img.size) - img.size[0]) // 2, (max(img.size) - img.size[1]) // 2), fill=(255, 255, 255))

        label = self.label_dict[label]
        if self.transform is not None:
            img = self.transform(img)

        return img, label

    def __len__(self):
        return self.n_data

In [6]:
train_data = GetLoader(data_root='/home/l/liny/ruofan/lightly/datasets/logo2k/train/', 
                           data_list='/home/l/liny/ruofan/PhishIntention/src/siamese_retrain/logo2k/train.txt', 
                           label_dict='/home/l/liny/ruofan/PhishIntention/src/siamese_retrain/logo2k/logo2k_labeldict.pkl',
                           transform=train_transform)

test_data = GetLoader(data_root='/home/l/liny/ruofan/lightly/datasets/logo2k/test/', 
                      data_list='/home/l/liny/ruofan/PhishIntention/src/siamese_retrain/logo2k/test.txt', 
                      label_dict='/home/l/liny/ruofan/PhishIntention/src/siamese_retrain/logo2k/logo2k_labeldict.pkl',
                      transform=test_transform)

In [7]:
sampler = samplers.MPerClassSampler(train_data.labels, 
                                    m=5, 
                                    length_before_new_iter=100000)

In [8]:
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, sampler=sampler)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False)

In [15]:
model = bninception(dim=512, pretrained=None)
model = torch.nn.DataParallel(model).to(device)

In [16]:
### MNIST code originally from https://github.com/pytorch/examples/blob/master/mnist/main.py ### 
def train(model, loss_func, mining_func, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, labels) in enumerate(train_loader):
        data, labels = data.to(device), labels.to(device)
        optimizer.zero_grad()
        embeddings = model(data)
        indices_tuple = mining_func(embeddings, labels)
        loss = loss_func(embeddings, labels, indices_tuple)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print("Epoch {} Iteration {}: Loss = {}, Number of mined triplets = {}".format(epoch, batch_idx, loss, mining_func.num_triplets))

### convenient function from pytorch-metric-learning ###
def get_all_embeddings(dataset, model):
    tester = testers.BaseTester()
    return tester.get_all_embeddings(dataset, model)

### compute accuracy using AccuracyCalculator from pytorch-metric-learning ###
def test(train_set, test_set, model, accuracy_calculator):
    train_embeddings, train_labels = get_all_embeddings(train_set, model)
    test_embeddings, test_labels = get_all_embeddings(test_set, model)
    print("Computing accuracy")
    accuracies = accuracy_calculator.get_accuracy(test_embeddings, 
                                                  train_embeddings,
                                                  test_labels,
                                                  train_labels,
                                                  False)
    print("Test set accuracy (Precision@1) = {}".format(accuracies["precision_at_1"]))
    return accuracies["precision_at_1"]

In [17]:
### pytorch-metric-learning stuff ###
distance = distances.CosineSimilarity()
reducer = reducers.MeanReducer()
loss_func = losses.SoftTripleLoss(num_classes=2340, 
                                  embedding_size=512, 
                                  centers_per_class=5, 
                                  la=20, 
                                  gamma=0.1, 
                                  margin=0.01).to(device)

loss_optimizer = torch.optim.Adam([{"params": model.parameters(), "lr": 1e-4},
                                   {"params": loss_func.parameters(), "lr": 1e-2}])

mining_func = miners.TripletMarginMiner(margin = 0.2, distance = distance, type_of_triplets = "semi-hard")
accuracy_calculator = AccuracyCalculator(include = ("precision_at_1",), k = 1)

In [None]:
### pytorch-metric-learning stuff ###

for epoch in range(1, num_epochs+1):
    train(model, loss_func, mining_func, device, train_loader, loss_optimizer, epoch)
    knn_acc = test(train_data, test_data, model, accuracy_calculator)
    if epoch % 50 == 0 or epoch == 1:
        torch.save(model.state_dict(), os.path.join(result_dir, '{}_epoch{}_knnAcc{:.4f}.pt'.format(exp_name, epoch, knn_acc)))


Epoch 1 Iteration 0: Loss = 7.890983581542969, Number of mined triplets = 14702
Epoch 1 Iteration 100: Loss = 6.437102794647217, Number of mined triplets = 10033
Epoch 1 Iteration 200: Loss = 6.644566535949707, Number of mined triplets = 9564
Epoch 1 Iteration 300: Loss = 5.138212203979492, Number of mined triplets = 7776
Epoch 1 Iteration 400: Loss = 4.899921894073486, Number of mined triplets = 7424
Epoch 1 Iteration 500: Loss = 6.552186012268066, Number of mined triplets = 10032
Epoch 1 Iteration 600: Loss = 5.8222832679748535, Number of mined triplets = 9376
Epoch 1 Iteration 700: Loss = 5.751879692077637, Number of mined triplets = 8978
Epoch 1 Iteration 800: Loss = 5.512117385864258, Number of mined triplets = 9762
Epoch 1 Iteration 900: Loss = 5.344035625457764, Number of mined triplets = 8410
Epoch 1 Iteration 1000: Loss = 5.829801559448242, Number of mined triplets = 9519
Epoch 1 Iteration 1100: Loss = 5.217018127441406, Number of mined triplets = 9072
Epoch 1 Iteration 1200: 

100%|██████████| 3655/3655 [04:13<00:00, 14.41it/s]
100%|██████████| 1569/1569 [01:51<00:00, 14.07it/s]


Computing accuracy
Test set accuracy (Precision@1) = 0.06954684815354994
Epoch 2 Iteration 0: Loss = 4.698849201202393, Number of mined triplets = 8346
Epoch 2 Iteration 100: Loss = 5.577515602111816, Number of mined triplets = 8875
Epoch 2 Iteration 200: Loss = 5.384181022644043, Number of mined triplets = 9619
Epoch 2 Iteration 300: Loss = 5.572055339813232, Number of mined triplets = 9091
Epoch 2 Iteration 400: Loss = 5.139688491821289, Number of mined triplets = 8790
Epoch 2 Iteration 500: Loss = 5.322958946228027, Number of mined triplets = 9211
Epoch 2 Iteration 600: Loss = 5.1449055671691895, Number of mined triplets = 8986
Epoch 2 Iteration 700: Loss = 5.100333213806152, Number of mined triplets = 8374
Epoch 2 Iteration 800: Loss = 4.989513397216797, Number of mined triplets = 8344
Epoch 2 Iteration 900: Loss = 4.7517805099487305, Number of mined triplets = 8099
Epoch 2 Iteration 1000: Loss = 5.105037689208984, Number of mined triplets = 8115
Epoch 2 Iteration 1100: Loss = 4.63

100%|██████████| 3655/3655 [04:20<00:00, 14.01it/s]
100%|██████████| 1569/1569 [01:51<00:00, 14.01it/s]


Computing accuracy
Test set accuracy (Precision@1) = 0.0971663127784268
Epoch 3 Iteration 0: Loss = 5.184262275695801, Number of mined triplets = 8535
Epoch 3 Iteration 100: Loss = 4.109067916870117, Number of mined triplets = 6978
Epoch 3 Iteration 200: Loss = 4.816425323486328, Number of mined triplets = 7332
Epoch 3 Iteration 300: Loss = 5.017761707305908, Number of mined triplets = 7988
Epoch 3 Iteration 400: Loss = 4.506860733032227, Number of mined triplets = 7442
Epoch 3 Iteration 500: Loss = 4.686699867248535, Number of mined triplets = 7735
Epoch 3 Iteration 600: Loss = 3.8762736320495605, Number of mined triplets = 6695
Epoch 3 Iteration 700: Loss = 4.485928535461426, Number of mined triplets = 7946
Epoch 3 Iteration 800: Loss = 5.441232681274414, Number of mined triplets = 8958
Epoch 3 Iteration 900: Loss = 4.671576023101807, Number of mined triplets = 8057
Epoch 3 Iteration 1000: Loss = 4.1937456130981445, Number of mined triplets = 7352
Epoch 3 Iteration 1100: Loss = 5.602

100%|██████████| 3655/3655 [04:20<00:00, 14.05it/s]
100%|██████████| 1569/1569 [01:42<00:00, 15.29it/s]


Computing accuracy
Test set accuracy (Precision@1) = 0.12084013959974982
Epoch 4 Iteration 0: Loss = 4.6623921394348145, Number of mined triplets = 7893
Epoch 4 Iteration 100: Loss = 4.755434036254883, Number of mined triplets = 7540
Epoch 4 Iteration 200: Loss = 4.282244682312012, Number of mined triplets = 7147
Epoch 4 Iteration 300: Loss = 3.827988624572754, Number of mined triplets = 6538
Epoch 4 Iteration 400: Loss = 3.767906665802002, Number of mined triplets = 6733
Epoch 4 Iteration 500: Loss = 4.208225250244141, Number of mined triplets = 7634
Epoch 4 Iteration 600: Loss = 5.026324272155762, Number of mined triplets = 8204
Epoch 4 Iteration 700: Loss = 4.075446605682373, Number of mined triplets = 6745
Epoch 4 Iteration 800: Loss = 4.452949523925781, Number of mined triplets = 7704
Epoch 4 Iteration 900: Loss = 4.6066670417785645, Number of mined triplets = 7385
Epoch 4 Iteration 1000: Loss = 4.591436862945557, Number of mined triplets = 7647
Epoch 4 Iteration 1100: Loss = 4.28

100%|██████████| 3655/3655 [04:19<00:00, 14.08it/s]
100%|██████████| 1569/1569 [01:53<00:00, 13.78it/s]


Computing accuracy
Test set accuracy (Precision@1) = 0.1493164851617621
Epoch 5 Iteration 0: Loss = 3.8996593952178955, Number of mined triplets = 7348
Epoch 5 Iteration 100: Loss = 4.021068096160889, Number of mined triplets = 7262
Epoch 5 Iteration 200: Loss = 4.634955883026123, Number of mined triplets = 7663
Epoch 5 Iteration 300: Loss = 3.544334650039673, Number of mined triplets = 5805
