In [1]:
# coding: utf-8
import os
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models

import pickle

import multiprocessing
num_cpus = multiprocessing.cpu_count()

# Setup
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(num_cpus, device)

8 cuda:0


In [2]:
# parameters
dim_embedding = 1000
k = 30

In [3]:
def default_image_loader(path):
    return Image.open(path).convert('RGB')

class TripletImageLoader(torch.utils.data.Dataset):
    def __init__(self, base_path, triplets_file_name, transform=None,
                 loader=default_image_loader):
        self.base_path = base_path  
        self.filenamelist = []
        triplets = []
        for line in open(triplets_file_name):
            triplets.append((line.split()[0], line.split()[1], line.split()[2])) # Q, P, N
        self.triplets = triplets
        self.transform = transform
        self.loader = loader

    def __getitem__(self, index):
        path1, path2, path3 = self.triplets[index]
        img1 = self.loader(os.path.join(self.base_path,path1))
        img2 = self.loader(os.path.join(self.base_path,path2))
        img3 = self.loader(os.path.join(self.base_path,path3))
        if self.transform is not None:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
            img3 = self.transform(img3)

        return img1, img2, img3

    def __len__(self):
        return len(self.triplets)
    
class EmbeddingImageLoader(torch.utils.data.Dataset):
    def __init__(self, base_path, image_file_name, transform=None,
                 loader=default_image_loader):
        self.base_path = base_path  
        self.filenamelist = []
        files = []
        for line in open(image_file_name):
            files.append((line.rstrip('\n')))
        self.files = files
        self.transform = transform
        self.loader = loader

    def __getitem__(self, index):
        path = self.files[index]
        img = self.loader(os.path.join(self.base_path,path))
        if self.transform is not None:
            img = self.transform(img)

        return img

    def __len__(self):
        return len(self.files)

In [4]:
# transformations
pretrain_image_size = 224

transform_train = transforms.Compose(
    [torchvision.transforms.Resize(pretrain_image_size),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
transform_test = transforms.Compose(
    [torchvision.transforms.Resize(pretrain_image_size),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

In [5]:
#base_path = "/projects/training/bauh/tiny-imagenet-200/"
base_path = ""
train_sampler_file = "E:/study/CS598/tiny-imagenet-200/train_sampler.txt"
trainset = TripletImageLoader(base_path, train_sampler_file, transform = transform_train)
trainloader = DataLoader(trainset, batch_size=10, shuffle=False, num_workers=0)

#test_sampler_file = "test_sampler.txt"
#testset = TripletImageLoader(base_path, test_sampler_file, transform = transform_test)
#testloader = DataLoader(testset, batch_size=100, shuffle=False, num_workers=num_cpus)

In [6]:
db_file = 'db.txt'
val_file = 'val.txt'

db_classes = np.array(pickle.load(open('db_classes.pkl','rb')))
val_classes = np.array(pickle.load(open('val_classes.pkl','rb')))

dataset = EmbeddingImageLoader(base_path, db_file, transform = transform_test)
valset = EmbeddingImageLoader(base_path, val_file, transform = transform_test)

testloader_db = DataLoader(dataset, batch_size=100, shuffle=False, num_workers=0)
testloader_val = DataLoader(valset, batch_size=100, shuffle=False, num_workers=0)

In [7]:
class TripletNet(nn.Module):
    def __init__(self, embedding_net):
        super(TripletNet, self).__init__()
        self.embedding_net = embedding_net

    def forward(self, Q, P, N):
        Q_embedding = self.embedding_net(Q)
        P_embedding = self.embedding_net(P)
        N_embedding = self.embedding_net(N)
        return Q_embedding, P_embedding, N_embedding

In [8]:
resnet = models.resnet101(pretrained=True)

# change output layer
#num_ftrs = net.fc.in_features
#net.fc = nn.Linear(num_ftrs, dim_embedding)


# for p in net.parameters():
#     p.requires_grad=False
# for p in net.fc.parameters():
#     p.requires_grad=True

In [9]:
net = TripletNet(resnet)

net.to(device)

criterion = nn.TripletMarginLoss(margin = 1)
optimizer = optim.SGD(net.parameters(), lr=0.001)

In [12]:
# Train
val_percision_over_time = []

net.train()
for epoch in range(1):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        net.train()
        # get the inputs
        Qin, Pin, Nin = data
        Qin, Pin, Nin = Qin.to(device), Pin.to(device), Nin.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward
        eQ, eP, eN = net(Qin, Pin, Nin)
        
        loss_tripletNet = criterion(eQ, eP, eN)
        loss_embeddingNet = eQ.norm(2) + eP.norm(2) + eN.norm(2)
        loss = loss_tripletNet + 0.001 * loss_embeddingNet
        
        #  backward + optimize
        loss.backward()

        for group in optimizer.param_groups:
            for p in group['params']:
                state = optimizer.state[p]
                if('step' in state and state['step']>=1024):
                    state['step'] = 1000

        optimizer.step()
        
        running_loss += loss.item()
        if (i+1) % 10 == 0:
            print('average 10 epoch loss at epoch number',(i+1),':\t',running_loss/10)
            running_loss = 0.0
    
        # validate
        if (i+1) % (200*500//10//10) == 0: # 1/10 of an epoch
            print("testing")
            net.eval()
            correct = 0
            total = 0
            embeddings_db = []
            embeddings_val = []
            with torch.no_grad():
                progress_counter = 0
                for data in testloader_db:
                    Qin = data
                    Qin = Qin.to(device)

                    eQ = resnet(Qin)
                    embeddings_db.append(eQ)

                    progress_counter += 100
                    if progress_counter % 1000 == 0:
                        print('db embedding progress:',progress_counter)

                progress_counter = 0
                for data in testloader_val:
                    Qin = data
                    Qin = Qin.to(device)

                    eQ = resnet(Qin)
                    embeddings_val.append(eQ)

                    progress_counter += 100
                    if progress_counter % 1000 == 0:
                        print('val embedding progress:',progress_counter)
                    
            db_embedding = torch.cat(embeddings_db)
            val_embedding = torch.cat(embeddings_val)
            
            val_top_k_percision = []
            
            for iv in range(val_embedding.size(0)):
                distances = torch.norm((db_embedding-val_embedding[iv]),p=2,dim=1)
                top_k_idx = torch.topk(distances,k=k,dim=0,largest=False,sorted=False)[1]

                top_k_classes = db_classes[top_k_idx]
                val_class = val_classes[iv]
                top_k_percision = (top_k_classes == val_class).sum()/k
            
                val_top_k_percision.append(top_k_percision)
            
            print("testing done")
            val_percision_over_time.append(val_top_k_percision)

print('Finished Training')

average 10 epoch loss at epoch number 10 :	 0.8108937978744507
average 10 epoch loss at epoch number 20 :	 0.8987368971109391
average 10 epoch loss at epoch number 30 :	 0.7618053764104843
average 10 epoch loss at epoch number 40 :	 0.7220729053020477
average 10 epoch loss at epoch number 50 :	 1.352463138103485
average 10 epoch loss at epoch number 60 :	 1.0354071706533432
average 10 epoch loss at epoch number 70 :	 0.8446289330720902
average 10 epoch loss at epoch number 80 :	 1.0003650814294816
average 10 epoch loss at epoch number 90 :	 0.8954549759626389
average 10 epoch loss at epoch number 100 :	 1.1094906479120255
average 10 epoch loss at epoch number 110 :	 1.0676131695508957
average 10 epoch loss at epoch number 120 :	 1.0442687630653382
average 10 epoch loss at epoch number 130 :	 0.6357000559568405
average 10 epoch loss at epoch number 140 :	 0.820886293053627
average 10 epoch loss at epoch number 150 :	 1.2252923548221588
average 10 epoch loss at epoch number 160 :	 0.7473

db embedding progress: 64000
db embedding progress: 65000
db embedding progress: 66000
db embedding progress: 67000
db embedding progress: 68000
db embedding progress: 69000
db embedding progress: 70000
db embedding progress: 71000
db embedding progress: 72000
db embedding progress: 73000
db embedding progress: 74000
db embedding progress: 75000
db embedding progress: 76000
db embedding progress: 77000
db embedding progress: 78000
db embedding progress: 79000
db embedding progress: 80000
db embedding progress: 81000
db embedding progress: 82000
db embedding progress: 83000
db embedding progress: 84000
db embedding progress: 85000
db embedding progress: 86000
db embedding progress: 87000
db embedding progress: 88000
db embedding progress: 89000
db embedding progress: 90000
db embedding progress: 91000
db embedding progress: 92000
db embedding progress: 93000
db embedding progress: 94000
db embedding progress: 95000
db embedding progress: 96000
db embedding progress: 97000
db embedding p

NameError: name 'db_embeddings' is not defined

In [14]:
         
db_embedding = torch.cat(embeddings_db)
val_embedding = torch.cat(embeddings_val)
            
val_top_k_percision = []
            
for iv in range(val_embedding.size(0)):
    distances = torch.norm((db_embedding-val_embedding[iv]),p=2,dim=1)
    top_k_idx = torch.topk(distances,k=k,dim=0,largest=False,sorted=False)[1]

    top_k_classes = db_classes[top_k_idx]
    val_class = val_classes[iv]
    top_k_percision = (top_k_classes == val_class).sum()/k
            
    val_top_k_percision.append(top_k_percision)
            
print("testing done")
val_percision_over_time.append(val_top_k_percision)

testing done


In [15]:
val_percision_over_time

[[0.8,
  0.0,
  0.06666666666666667,
  0.4,
  0.03333333333333333,
  0.8333333333333334,
  0.36666666666666664,
  0.6,
  1.0,
  0.3333333333333333,
  0.1,
  0.0,
  0.3333333333333333,
  0.1,
  0.0,
  0.2,
  0.06666666666666667,
  0.7666666666666667,
  0.6,
  0.06666666666666667,
  0.23333333333333334,
  0.9666666666666667,
  0.0,
  0.06666666666666667,
  0.7,
  0.8333333333333334,
  0.6333333333333333,
  0.1,
  0.2,
  0.36666666666666664,
  0.06666666666666667,
  0.1,
  0.13333333333333333,
  0.5333333333333333,
  0.03333333333333333,
  0.5666666666666667,
  0.4,
  0.03333333333333333,
  0.0,
  0.6666666666666666,
  0.1,
  0.4666666666666667,
  1.0,
  0.06666666666666667,
  0.5,
  0.06666666666666667,
  0.06666666666666667,
  0.9666666666666667,
  0.5666666666666667,
  0.1,
  0.06666666666666667,
  0.3,
  0.6666666666666666,
  0.23333333333333334,
  0.3333333333333333,
  0.23333333333333334,
  0.5333333333333333,
  0.0,
  0.3,
  0.0,
  0.26666666666666666,
  0.5333333333333333,
  0.933