In [1]:
import os
import numpy as np
from PIL import Image
from tqdm import tqdm

import torch
import torchvision   
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T

from sklearn.metrics import roc_auc_score

In [2]:
# Detect once whether CUDA is available; the hyperparameter cell branches on this flag.
cuda = torch.cuda.is_available()

In [3]:
# Display the detected CUDA availability.
cuda

True

In [4]:
# Preprocessing shared by every dataset in this notebook: resize, convert to a
# tensor, then normalise each of the three channels with the mean/std below
# (these values match the widely used ImageNet channel statistics).
transform = T.Compose([
    T.Resize(224),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [5]:
# Each split is read through ImageFolder from its own directory and shares the
# same preprocessing pipeline.
train_dataset = torchvision.datasets.ImageFolder(
    root='classification_data/train_data/', transform=transform)
val_dataset = torchvision.datasets.ImageFolder(
    root='classification_data/val_data/', transform=transform)
test_dataset = torchvision.datasets.ImageFolder(
    root='classification_data/test_data/', transform=transform)

In [6]:
# Hyperparameters
batch_size = 128
numEpochs = 30
lr = 0.001            # learning rate handed to the Adam optimizer
weight_decay = 5e-5

# ReduceLROnPlateau settings read by train_closs
patience = 5
factor = 0.316

# DataLoader settings: worker processes and pinned memory only when CUDA is present
num_workers = 8 if cuda else 0
pin_memory = bool(cuda)

# One output unit / one center per class folder found in the training split
num_classes = len(train_dataset.classes)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# NOTE(review): same assignment as at the end of the hyperparameter cell — appears redundant.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Sanity check: smallest value in the first validation image after preprocessing.
val_dataset[0][0].min()

In [7]:
# Training batches are reshuffled every epoch. The validation and test loaders
# keep dataset order so evaluation is reproducible and any per-sample output
# can be matched back to its dataset index.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=pin_memory)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers, pin_memory=pin_memory)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                             num_workers=num_workers, pin_memory=pin_memory)

In [28]:
class Conv_Net(nn.Module):
    """ResNet-18 wrapper that yields three outputs per forward pass.

    Args:
        num_classes (int): width of the ResNet classification layer.
        feat_dim (int): width of the extra bias-free projection whose output
            is returned first (the notebook feeds it to the center loss).
    """

    def __init__(self, num_classes, feat_dim):
        super().__init__()
        self.net = torchvision.models.resnet18(num_classes=num_classes)
        backbone_width = self.net.fc.in_features
        self.linear_closs = nn.Linear(backbone_width, feat_dim, bias=False)
        self.relu_closs = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``(closs_output, label_output, feature_output)``.

        ``feature_output`` is the un-flattened result of running every ResNet
        child except the last one; ``label_output`` is ``net.fc`` applied to
        the flattened features; ``closs_output`` is the ReLU of the
        ``linear_closs`` projection of the same flattened features.
        """
        # Apply the ResNet children in order, skipping the final one (the fc head).
        feature_output = x
        for layer in list(self.net.children())[:-1]:
            feature_output = layer(feature_output)

        flat = feature_output.view(-1, self.net.fc.in_features)
        label_output = self.net.fc(flat)
        closs_output = self.relu_closs(self.linear_closs(flat))
        return closs_output, label_output, feature_output

In [29]:
# Size the classifier from the dataset-derived class count (the same value the
# CenterLoss is built with) instead of a hard-coded number.
model = Conv_Net(num_classes=num_classes, feat_dim=10)

In [None]:
# Print the module hierarchy of the freshly built model.
print(model)

In [20]:
def init_weights(m):
    """Xavier-normal initialise the weight of an exact ``nn.Conv2d`` / ``nn.Linear``.

    Written for use with ``Module.apply``. Modules of any other type
    (subclasses of the two included) are left untouched, and only ``weight``
    is modified.
    """
    if type(m) not in (nn.Conv2d, nn.Linear):
        return
    nn.init.xavier_normal_(m.weight.data)

In [21]:
class CenterLoss(nn.Module):
    """Center loss: mean squared distance between each feature and its class center.

    Args:
        num_classes (int): number of classes.
        feat_dim (int): feature dimension.
        device (torch.device): device on which the centers are created.
    """
    def __init__(self, num_classes, feat_dim, device=torch.device('cpu')):
        super(CenterLoss, self).__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        self.device = device

        # One learnable center per class, initialised from a standard normal.
        self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).to(self.device))

    def forward(self, x, labels):
        """
        Args:
            x: feature matrix with shape (batch_size, feat_dim).
            labels: ground truth labels with shape (batch_size).

        Returns:
            Scalar tensor: the mean over the batch of ``||x_i - c_{y_i}||^2``,
            each term clamped to ``[1e-12, 1e+12]``.

        Raises:
            IndexError/RuntimeError: if a label lies outside
            ``[0, num_classes)`` (raised by the center lookup).
        """
        # Follow the live device of the centers rather than the constructor
        # argument, so the loss keeps working after the module is moved.
        labels = labels.to(device=self.centers.device, dtype=torch.long).view(-1)

        # Look up only the center of each sample's own class; this replaces the
        # full (batch x classes) distance matrix, the deprecated positional
        # addmm_ call and the per-sample Python loop.
        batch_centers = self.centers.index_select(0, labels)
        dist = (x - batch_centers).pow(2).sum(dim=1)
        dist = dist.clamp(min=1e-12, max=1e+12)  # for numerical stability

        return dist.mean()

In [22]:
def train_closs(model, data_loader, test_loader, ver_loader, task='Classification'):
    """Train ``model`` with cross-entropy plus a weighted center loss.

    Reads these notebook globals: ``optimizer_label``, ``optimizer_closs``,
    ``criterion_label``, ``criterion_closs``, ``closs_weight``, ``patience``,
    ``factor``, ``numEpochs`` and ``device``.

    Args:
        model: network returning ``(closs_feature, logits, _)``.
        data_loader: yields ``(inputs, labels)`` training batches.
        test_loader: loader passed to ``test_classify_closs`` after each epoch.
        ver_loader: loader passed to ``test_verify_closs``; only used when
            ``task`` is not ``'Classification'``.
        task: ``'Classification'`` steps the LR scheduler on validation loss;
            any other value steps it on the verification ROC score.
    """
    model.train()
    classify = (task == 'Classification')
    # Validation loss should go down ('min'); the ROC score should go up ('max').
    plateau_mode = 'min' if classify else 'max'
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer_label, mode=plateau_mode, patience=patience, factor=factor, verbose=True)

    for epoch in tqdm(range(numEpochs)):
        running_loss = 0.0
        for step, (feats, labels) in enumerate(data_loader, start=1):
            feats, labels = feats.to(device), labels.to(device)

            optimizer_label.zero_grad()
            optimizer_closs.zero_grad()

            feature, outputs, _ = model(feats)

            l_loss = criterion_label(outputs, labels.long())
            c_loss = criterion_closs(feature, labels.long())
            loss = l_loss + closs_weight * c_loss
            loss.backward()

            optimizer_label.step()
            # Undo the closs_weight scaling on the center gradients so the
            # weight does not affect how fast the centers themselves move.
            for param in criterion_closs.parameters():
                param.grad.data *= (1. / closs_weight)
            optimizer_closs.step()

            running_loss += loss.item()

            # Report the mean loss of every full window of 50 batches.
            if step % 50 == 0:
                print('Epoch: {}\tBatch: {}\tAvg-Loss: {:.4f}'.format(epoch+1, step, running_loss/50))
                running_loss = 0.0

            torch.cuda.empty_cache()
            del feats
            del labels
            del loss

        val_loss, val_acc = test_classify_closs(model, test_loader)
        print('Val Loss: {:.4f}\tVal Accuracy: {:.4f}'.format(val_loss, val_acc))
        if classify:
            scheduler.step(val_loss)
            continue

        roc_score = test_verify_closs(model, ver_loader)
        print('Roc score: {:.4f}'.format(roc_score))
        scheduler.step(roc_score)

def test_classify_closs(model, test_loader):
    """Evaluate combined loss and top-1 accuracy of ``model`` on ``test_loader``.

    Reads the notebook globals ``device``, ``criterion_label``,
    ``criterion_closs`` and ``closs_weight``.

    Args:
        model: network returning ``(closs_feature, logits, _)``.
        test_loader: yields ``(inputs, labels)`` batches.

    Returns:
        ``(mean_loss, accuracy)`` where each batch loss is weighted by the
        batch size and accuracy is the fraction of correct predictions.

    Raises:
        ZeroDivisionError: if the loader yields no samples.

    The model is put in eval mode for the pass and switched back to train
    mode before returning.
    """
    model.eval()
    test_loss = []
    accuracy = 0
    total = 0

    # Inference only: without no_grad every batch would keep an autograd graph
    # alive and waste memory (test_verify_closs already evaluates this way).
    with torch.no_grad():
        for feats, labels in test_loader:
            feats, labels = feats.to(device), labels.to(device)
            feature, outputs, _ = model(feats)

            # softmax is monotonic, so the argmax of the raw logits is the prediction
            pred_labels = outputs.argmax(dim=1).view(-1)

            l_loss = criterion_label(outputs, labels.long())
            c_loss = criterion_closs(feature, labels.long())
            loss = l_loss + closs_weight * c_loss

            accuracy += torch.sum(torch.eq(pred_labels, labels)).item()
            total += len(labels)
            # repeat the batch loss once per sample so the final mean is sample-weighted
            test_loss.extend([loss.item()] * feats.size(0))

    model.train()
    return np.mean(test_loss), accuracy/total

def test_verify_closs(model, test_loader):
    """Compute the ROC AUC of cosine similarities over labelled image pairs.

    Reads the notebook global ``device``.

    Args:
        model: network whose third output is the embedding to compare.
        test_loader: yields ``(images1, images2, labels)`` batches.

    Returns:
        The value of ``roc_auc_score(labels, similarities)``.

    The model is put in eval mode for the pass and switched back to train
    mode before returning.
    """
    cos = nn.CosineSimilarity(dim=1)  # built once instead of once per batch
    similarities = []
    true_labels = []

    model.eval()
    with torch.no_grad():
        for feats1, feats2, labels in test_loader:
            feats1, feats2 = feats1.to(device), feats2.to(device)
            _, _, output1 = model(feats1)
            _, _, output2 = model(feats2)

            # Flatten each embedding to (batch, features) so there is exactly
            # one score per pair, then bring the scores back to the CPU as
            # plain Python floats; collecting device tensors in a list and
            # calling np.array on it is slow at best and fails for CUDA tensors.
            sim = cos(output1.flatten(start_dim=1), output2.flatten(start_dim=1))
            similarities.extend(sim.cpu().tolist())
            true_labels.extend(torch.as_tensor(labels).tolist())

    model.train()
    return roc_auc_score(np.array(true_labels), np.array(similarities))

In [23]:
class VerificationDataset(Dataset):
    """Labelled image pairs read from a whitespace-separated text file.

    Each non-blank line of ``filename`` must hold three fields:
    ``<image path 1> <image path 2> <integer label>``.

    Items are ``(img1, img2, label)`` where both images have been passed
    through the notebook-level ``transform``.
    """
    def __init__(self, filename):
        self.file1_list = []
        self.file2_list = []
        self.target_list = []

        # `with` guarantees the list file is closed even if a line is malformed.
        with open(filename, "r") as infile:
            for line in infile:
                if not line.strip():
                    continue  # tolerate blank lines (e.g. a trailing newline)
                imfile1, imfile2, match = line.split()
                self.file1_list.append(imfile1)
                self.file2_list.append(imfile2)
                self.target_list.append(int(match))

    def __len__(self):
        return len(self.file1_list)

    def __getitem__(self, index):
        img1 = transform(self._load_rgb(self.file1_list[index]))
        img2 = transform(self._load_rgb(self.file2_list[index]))
        label = self.target_list[index]
        return img1, img2, label

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` and return a 3-channel copy, closing the file handle."""
        # The transform normalises three channels, so force RGB here; this is
        # also what ImageFolder's default loader does for the training data.
        with Image.open(path) as img:
            return img.convert('RGB')

In [24]:
# Labelled validation pairs; passed to train_closs as the verification loader's dataset.
verification_set = VerificationDataset("verification_pairs_val.txt")

In [25]:
# Fixed order (shuffle=False) so scores and labels are collected in file order.
ver_dataloader = DataLoader(verification_set, batch_size=32, shuffle=False, num_workers=4, pin_memory=False)

In [30]:
# Weight of the center-loss term in the total loss, and the learning rate used
# for the class centers.
closs_weight = 0.01
lr_cent = 0.5

# Two criteria: cross-entropy on the logits, center loss on the 10-dimensional
# projection produced by the model.
criterion_label = nn.CrossEntropyLoss()
criterion_closs = CenterLoss(num_classes=num_classes, feat_dim=10, device=device)

# Separate optimizers: Adam for the network weights, plain SGD for the centers.
optimizer_label = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
optimizer_closs = torch.optim.SGD(criterion_closs.parameters(), lr=lr_cent)

# Re-initialise Conv2d/Linear weights in place; as the last expression of the
# cell this also displays the model.
model.apply(init_weights)




Conv_Net(
  (net): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_runnin

In [None]:
# Move the network to the selected device, then train. With task='Verification'
# train_closs steps its LR scheduler on the ROC score computed from ver_dataloader.
model.to(device)
train_closs(model, train_dataloader, val_dataloader, ver_dataloader, task='Verification')


  0%|          | 0/30 [00:00<?, ?it/s][A

Epoch: 1	Batch: 50	Avg-Loss: 8.6159
Epoch: 1	Batch: 100	Avg-Loss: 8.4080
Epoch: 1	Batch: 150	Avg-Loss: 8.3769
Epoch: 1	Batch: 200	Avg-Loss: 8.3701
Epoch: 1	Batch: 250	Avg-Loss: 8.3614
Epoch: 1	Batch: 300	Avg-Loss: 8.3511
Epoch: 1	Batch: 350	Avg-Loss: 8.3398
Epoch: 1	Batch: 400	Avg-Loss: 8.3219
Epoch: 1	Batch: 450	Avg-Loss: 8.2837
Epoch: 1	Batch: 500	Avg-Loss: 8.2644
Epoch: 1	Batch: 550	Avg-Loss: 8.2010
Epoch: 1	Batch: 600	Avg-Loss: 8.1372
Epoch: 1	Batch: 650	Avg-Loss: 8.0878
Epoch: 1	Batch: 700	Avg-Loss: 8.0252
Epoch: 1	Batch: 750	Avg-Loss: 7.9273
Epoch: 1	Batch: 800	Avg-Loss: 7.9187
Epoch: 1	Batch: 850	Avg-Loss: 7.8474
Epoch: 1	Batch: 900	Avg-Loss: 7.7885
Epoch: 1	Batch: 950	Avg-Loss: 7.7262
Epoch: 1	Batch: 1000	Avg-Loss: 7.6783
Epoch: 1	Batch: 1050	Avg-Loss: 7.6083
Epoch: 1	Batch: 1100	Avg-Loss: 7.5411
Epoch: 1	Batch: 1150	Avg-Loss: 7.4907
Epoch: 1	Batch: 1200	Avg-Loss: 7.4391
Epoch: 1	Batch: 1250	Avg-Loss: 7.3184
Epoch: 1	Batch: 1300	Avg-Loss: 7.2802
Epoch: 1	Batch: 1350	Avg-Loss: 7


  3%|▎         | 1/30 [24:38<11:54:31, 1478.33s/it][A

Roc score: 0.8282
Epoch: 2	Batch: 50	Avg-Loss: 4.9791
Epoch: 2	Batch: 100	Avg-Loss: 4.9232
Epoch: 2	Batch: 150	Avg-Loss: 4.8520
Epoch: 2	Batch: 200	Avg-Loss: 4.8119
Epoch: 2	Batch: 250	Avg-Loss: 4.7712
Epoch: 2	Batch: 300	Avg-Loss: 4.7707
Epoch: 2	Batch: 350	Avg-Loss: 4.7272
Epoch: 2	Batch: 400	Avg-Loss: 4.6687
Epoch: 2	Batch: 450	Avg-Loss: 4.6240
Epoch: 2	Batch: 500	Avg-Loss: 4.5601
Epoch: 2	Batch: 550	Avg-Loss: 4.5297
Epoch: 2	Batch: 600	Avg-Loss: 4.4709
Epoch: 2	Batch: 650	Avg-Loss: 4.4287
Epoch: 2	Batch: 700	Avg-Loss: 4.4567
Epoch: 2	Batch: 750	Avg-Loss: 4.4300
Epoch: 2	Batch: 800	Avg-Loss: 4.3033
Epoch: 2	Batch: 850	Avg-Loss: 4.3198
Epoch: 2	Batch: 900	Avg-Loss: 4.2687
Epoch: 2	Batch: 950	Avg-Loss: 4.2978
Epoch: 2	Batch: 1000	Avg-Loss: 4.2617
Epoch: 2	Batch: 1050	Avg-Loss: 4.2246
Epoch: 2	Batch: 1100	Avg-Loss: 4.1045
Epoch: 2	Batch: 1150	Avg-Loss: 4.1156
Epoch: 2	Batch: 1200	Avg-Loss: 4.0127
Epoch: 2	Batch: 1250	Avg-Loss: 4.1000
Epoch: 2	Batch: 1300	Avg-Loss: 4.0172
Epoch: 2	Batch


  7%|▋         | 2/30 [49:30<11:31:48, 1482.46s/it][A

Roc score: 0.8788
Epoch: 3	Batch: 50	Avg-Loss: 2.8320
Epoch: 3	Batch: 100	Avg-Loss: 2.8743
Epoch: 3	Batch: 150	Avg-Loss: 2.8770
Epoch: 3	Batch: 200	Avg-Loss: 2.8154
Epoch: 3	Batch: 250	Avg-Loss: 2.8632
Epoch: 3	Batch: 300	Avg-Loss: 2.8329
Epoch: 3	Batch: 350	Avg-Loss: 2.8400
Epoch: 3	Batch: 400	Avg-Loss: 2.8050
Epoch: 3	Batch: 450	Avg-Loss: 2.8619
Epoch: 3	Batch: 500	Avg-Loss: 2.8048
Epoch: 3	Batch: 550	Avg-Loss: 2.8431
Epoch: 3	Batch: 600	Avg-Loss: 2.8059
Epoch: 3	Batch: 650	Avg-Loss: 2.7652
Epoch: 3	Batch: 700	Avg-Loss: 2.7740
Epoch: 3	Batch: 750	Avg-Loss: 2.7441
Epoch: 3	Batch: 800	Avg-Loss: 2.7927
Epoch: 3	Batch: 850	Avg-Loss: 2.7468
Epoch: 3	Batch: 900	Avg-Loss: 2.7430
Epoch: 3	Batch: 950	Avg-Loss: 2.8177
Epoch: 3	Batch: 1000	Avg-Loss: 2.7002
Epoch: 3	Batch: 1050	Avg-Loss: 2.6788
Epoch: 3	Batch: 1100	Avg-Loss: 2.6647
Epoch: 3	Batch: 1150	Avg-Loss: 2.6805
Epoch: 3	Batch: 1200	Avg-Loss: 2.6128
Epoch: 3	Batch: 1250	Avg-Loss: 2.6459
Epoch: 3	Batch: 1300	Avg-Loss: 2.6967
Epoch: 3	Batch


 10%|█         | 3/30 [1:14:15<11:07:23, 1483.10s/it][A

Roc score: 0.8780
Epoch: 4	Batch: 50	Avg-Loss: 1.9411
Epoch: 4	Batch: 100	Avg-Loss: 1.8659
Epoch: 4	Batch: 150	Avg-Loss: 1.8959
Epoch: 4	Batch: 200	Avg-Loss: 1.8780
Epoch: 4	Batch: 250	Avg-Loss: 1.8969
Epoch: 4	Batch: 300	Avg-Loss: 1.8806
Epoch: 4	Batch: 350	Avg-Loss: 1.9456
Epoch: 4	Batch: 400	Avg-Loss: 1.9321
Epoch: 4	Batch: 450	Avg-Loss: 1.9241
Epoch: 4	Batch: 500	Avg-Loss: 1.9264
Epoch: 4	Batch: 550	Avg-Loss: 1.9420
Epoch: 4	Batch: 600	Avg-Loss: 1.9190
Epoch: 4	Batch: 650	Avg-Loss: 1.9882
Epoch: 4	Batch: 700	Avg-Loss: 1.9533
Epoch: 4	Batch: 750	Avg-Loss: 1.9459
Epoch: 4	Batch: 800	Avg-Loss: 1.9377
Epoch: 4	Batch: 850	Avg-Loss: 1.8659
Epoch: 4	Batch: 900	Avg-Loss: 1.9419
Epoch: 4	Batch: 950	Avg-Loss: 1.9105
Epoch: 4	Batch: 1000	Avg-Loss: 1.9404
Epoch: 4	Batch: 1050	Avg-Loss: 1.9752
Epoch: 4	Batch: 1100	Avg-Loss: 1.9624
Epoch: 4	Batch: 1150	Avg-Loss: 1.9045
Epoch: 4	Batch: 1200	Avg-Loss: 1.8953
Epoch: 4	Batch: 1250	Avg-Loss: 1.9141
Epoch: 4	Batch: 1300	Avg-Loss: 1.9057
Epoch: 4	Batch


 13%|█▎        | 4/30 [1:38:24<10:38:21, 1473.12s/it][A

Roc score: 0.8888
Epoch: 5	Batch: 50	Avg-Loss: 1.3972
Epoch: 5	Batch: 100	Avg-Loss: 1.3890
Epoch: 5	Batch: 150	Avg-Loss: 1.4242
Epoch: 5	Batch: 200	Avg-Loss: 1.4132
Epoch: 5	Batch: 250	Avg-Loss: 1.4489
Epoch: 5	Batch: 300	Avg-Loss: 1.4621
Epoch: 5	Batch: 350	Avg-Loss: 1.4146
Epoch: 5	Batch: 400	Avg-Loss: 1.4910
Epoch: 5	Batch: 450	Avg-Loss: 1.4484
Epoch: 5	Batch: 500	Avg-Loss: 1.4487
Epoch: 5	Batch: 550	Avg-Loss: 1.4548
Epoch: 5	Batch: 600	Avg-Loss: 1.5113
Epoch: 5	Batch: 650	Avg-Loss: 1.5332
Epoch: 5	Batch: 700	Avg-Loss: 1.5171
Epoch: 5	Batch: 750	Avg-Loss: 1.5293
Epoch: 5	Batch: 800	Avg-Loss: 1.5089
Epoch: 5	Batch: 850	Avg-Loss: 1.4911
Epoch: 5	Batch: 900	Avg-Loss: 1.5139
Epoch: 5	Batch: 950	Avg-Loss: 1.5437
Epoch: 5	Batch: 1000	Avg-Loss: 1.5431
Epoch: 5	Batch: 1050	Avg-Loss: 1.5443
Epoch: 5	Batch: 1100	Avg-Loss: 1.5109
Epoch: 5	Batch: 1150	Avg-Loss: 1.5496
Epoch: 5	Batch: 1200	Avg-Loss: 1.5401
Epoch: 5	Batch: 1250	Avg-Loss: 1.5483
Epoch: 5	Batch: 1300	Avg-Loss: 1.5741
Epoch: 5	Batch


 17%|█▋        | 5/30 [2:02:27<10:10:02, 1464.11s/it][A

Roc score: 0.9097
Epoch: 6	Batch: 50	Avg-Loss: 1.1298
Epoch: 6	Batch: 100	Avg-Loss: 1.0911
Epoch: 6	Batch: 150	Avg-Loss: 1.1054
Epoch: 6	Batch: 200	Avg-Loss: 1.1280
Epoch: 6	Batch: 250	Avg-Loss: 1.1348
Epoch: 6	Batch: 300	Avg-Loss: 1.1976
Epoch: 6	Batch: 350	Avg-Loss: 1.1395
Epoch: 6	Batch: 400	Avg-Loss: 1.1905
Epoch: 6	Batch: 450	Avg-Loss: 1.1838
Epoch: 6	Batch: 500	Avg-Loss: 1.1618
Epoch: 6	Batch: 550	Avg-Loss: 1.1717
Epoch: 6	Batch: 600	Avg-Loss: 1.2440
Epoch: 6	Batch: 650	Avg-Loss: 1.2697
Epoch: 6	Batch: 700	Avg-Loss: 1.2289
Epoch: 6	Batch: 750	Avg-Loss: 1.2650
Epoch: 6	Batch: 800	Avg-Loss: 1.2302
Epoch: 6	Batch: 850	Avg-Loss: 1.3027
Epoch: 6	Batch: 900	Avg-Loss: 1.3061
Epoch: 6	Batch: 950	Avg-Loss: 1.2760
Epoch: 6	Batch: 1000	Avg-Loss: 1.2735
Epoch: 6	Batch: 1050	Avg-Loss: 1.2893
Epoch: 6	Batch: 1100	Avg-Loss: 1.2579
Epoch: 6	Batch: 1150	Avg-Loss: 1.3384
Epoch: 6	Batch: 1200	Avg-Loss: 1.2894
Epoch: 6	Batch: 1250	Avg-Loss: 1.2627
Epoch: 6	Batch: 1300	Avg-Loss: 1.3103
Epoch: 6	Batch


 20%|██        | 6/30 [2:26:39<9:44:10, 1460.44s/it] [A

Roc score: 0.9106
Epoch: 7	Batch: 50	Avg-Loss: 0.9580
Epoch: 7	Batch: 100	Avg-Loss: 0.9305
Epoch: 7	Batch: 150	Avg-Loss: 0.8673
Epoch: 7	Batch: 200	Avg-Loss: 0.9307
Epoch: 7	Batch: 250	Avg-Loss: 0.9695
Epoch: 7	Batch: 300	Avg-Loss: 0.9600
Epoch: 7	Batch: 350	Avg-Loss: 0.9474
Epoch: 7	Batch: 400	Avg-Loss: 1.0063
Epoch: 7	Batch: 450	Avg-Loss: 1.0086
Epoch: 7	Batch: 500	Avg-Loss: 1.0041
Epoch: 7	Batch: 550	Avg-Loss: 0.9968
Epoch: 7	Batch: 600	Avg-Loss: 1.0497
Epoch: 7	Batch: 650	Avg-Loss: 1.0513
Epoch: 7	Batch: 700	Avg-Loss: 1.0735
Epoch: 7	Batch: 750	Avg-Loss: 1.0591
Epoch: 7	Batch: 800	Avg-Loss: 1.0695
Epoch: 7	Batch: 850	Avg-Loss: 1.0584
Epoch: 7	Batch: 900	Avg-Loss: 1.1164
Epoch: 7	Batch: 950	Avg-Loss: 1.0878
Epoch: 7	Batch: 1000	Avg-Loss: 1.0848
Epoch: 7	Batch: 1050	Avg-Loss: 1.0897
Epoch: 7	Batch: 1100	Avg-Loss: 1.1152
Epoch: 7	Batch: 1150	Avg-Loss: 1.1483
Epoch: 7	Batch: 1200	Avg-Loss: 1.1537
Epoch: 7	Batch: 1250	Avg-Loss: 1.1128
Epoch: 7	Batch: 1300	Avg-Loss: 1.1484
Epoch: 7	Batch


 23%|██▎       | 7/30 [2:50:51<9:18:48, 1457.78s/it][A

Roc score: 0.9122
Epoch: 8	Batch: 50	Avg-Loss: 0.8104
Epoch: 8	Batch: 100	Avg-Loss: 0.7550
Epoch: 8	Batch: 150	Avg-Loss: 0.7933
Epoch: 8	Batch: 200	Avg-Loss: 0.7898
Epoch: 8	Batch: 250	Avg-Loss: 0.7995
Epoch: 8	Batch: 300	Avg-Loss: 0.8289
Epoch: 8	Batch: 350	Avg-Loss: 0.8415
Epoch: 8	Batch: 400	Avg-Loss: 0.8662
Epoch: 8	Batch: 450	Avg-Loss: 0.8494
Epoch: 8	Batch: 500	Avg-Loss: 0.8602
Epoch: 8	Batch: 550	Avg-Loss: 0.8905
Epoch: 8	Batch: 600	Avg-Loss: 0.8963
Epoch: 8	Batch: 650	Avg-Loss: 0.8592
Epoch: 8	Batch: 700	Avg-Loss: 0.8989
Epoch: 8	Batch: 750	Avg-Loss: 0.9103
Epoch: 8	Batch: 800	Avg-Loss: 0.9588
Epoch: 8	Batch: 850	Avg-Loss: 0.9329
Epoch: 8	Batch: 900	Avg-Loss: 0.9378
Epoch: 8	Batch: 950	Avg-Loss: 0.9537
Epoch: 8	Batch: 1000	Avg-Loss: 0.9594
Epoch: 8	Batch: 1050	Avg-Loss: 0.9899
Epoch: 8	Batch: 1100	Avg-Loss: 1.0006
Epoch: 8	Batch: 1150	Avg-Loss: 0.9782
Epoch: 8	Batch: 1200	Avg-Loss: 0.9950
Epoch: 8	Batch: 1250	Avg-Loss: 1.0059
Epoch: 8	Batch: 1300	Avg-Loss: 0.9700
Epoch: 8	Batch


 27%|██▋       | 8/30 [3:15:00<8:53:35, 1455.25s/it][A

Roc score: 0.9038
Epoch: 9	Batch: 50	Avg-Loss: 0.6979
Epoch: 9	Batch: 100	Avg-Loss: 0.6628
Epoch: 9	Batch: 150	Avg-Loss: 0.6751
Epoch: 9	Batch: 200	Avg-Loss: 0.7163
Epoch: 9	Batch: 250	Avg-Loss: 0.7034
Epoch: 9	Batch: 300	Avg-Loss: 0.7056
Epoch: 9	Batch: 350	Avg-Loss: 0.7382
Epoch: 9	Batch: 400	Avg-Loss: 0.7270
Epoch: 9	Batch: 450	Avg-Loss: 0.7736
Epoch: 9	Batch: 500	Avg-Loss: 0.7784
Epoch: 9	Batch: 550	Avg-Loss: 0.7973
Epoch: 9	Batch: 600	Avg-Loss: 0.7701
Epoch: 9	Batch: 650	Avg-Loss: 0.8051
Epoch: 9	Batch: 700	Avg-Loss: 0.7762
Epoch: 9	Batch: 750	Avg-Loss: 0.7898
Epoch: 9	Batch: 800	Avg-Loss: 0.8467
Epoch: 9	Batch: 850	Avg-Loss: 0.8807
Epoch: 9	Batch: 900	Avg-Loss: 0.8500
Epoch: 9	Batch: 950	Avg-Loss: 0.8616
Epoch: 9	Batch: 1000	Avg-Loss: 0.8850
Epoch: 9	Batch: 1050	Avg-Loss: 0.8885
Epoch: 9	Batch: 1100	Avg-Loss: 0.9061
Epoch: 9	Batch: 1150	Avg-Loss: 0.9373
Epoch: 9	Batch: 1200	Avg-Loss: 0.8947
Epoch: 9	Batch: 1250	Avg-Loss: 0.9104
Epoch: 9	Batch: 1300	Avg-Loss: 0.9429
Epoch: 9	Batch


 30%|███       | 9/30 [3:39:04<8:28:10, 1451.93s/it][A

Roc score: 0.9203
Epoch: 10	Batch: 50	Avg-Loss: 0.6712
Epoch: 10	Batch: 100	Avg-Loss: 0.5937
Epoch: 10	Batch: 150	Avg-Loss: 0.5493
Epoch: 10	Batch: 200	Avg-Loss: 0.6143
Epoch: 10	Batch: 250	Avg-Loss: 0.6273
Epoch: 10	Batch: 300	Avg-Loss: 0.6441
Epoch: 10	Batch: 350	Avg-Loss: 0.6437
Epoch: 10	Batch: 400	Avg-Loss: 0.6707
Epoch: 10	Batch: 450	Avg-Loss: 0.6504
Epoch: 10	Batch: 500	Avg-Loss: 0.6629
Epoch: 10	Batch: 550	Avg-Loss: 0.7089
Epoch: 10	Batch: 600	Avg-Loss: 0.7287
Epoch: 10	Batch: 650	Avg-Loss: 0.7607
Epoch: 10	Batch: 700	Avg-Loss: 0.7471
Epoch: 10	Batch: 750	Avg-Loss: 0.7760
Epoch: 10	Batch: 800	Avg-Loss: 0.7790
Epoch: 10	Batch: 850	Avg-Loss: 0.7753
Epoch: 10	Batch: 900	Avg-Loss: 0.7517
Epoch: 10	Batch: 950	Avg-Loss: 0.7519
Epoch: 10	Batch: 1000	Avg-Loss: 0.7907
Epoch: 10	Batch: 1300	Avg-Loss: 0.8875
Epoch: 10	Batch: 1350	Avg-Loss: 0.8440
Epoch: 10	Batch: 1400	Avg-Loss: 0.8642
Epoch: 10	Batch: 1450	Avg-Loss: 0.8539
Epoch: 10	Batch: 1500	Avg-Loss: 0.8565
Epoch: 10	Batch: 1550	Avg-L


 33%|███▎      | 10/30 [4:03:01<8:02:28, 1447.44s/it][A

Roc score: 0.9066
Epoch: 11	Batch: 50	Avg-Loss: 0.5899
Epoch: 11	Batch: 100	Avg-Loss: 0.5344
Epoch: 11	Batch: 150	Avg-Loss: 0.5369
Epoch: 11	Batch: 200	Avg-Loss: 0.5599
Epoch: 11	Batch: 250	Avg-Loss: 0.5596
Epoch: 11	Batch: 300	Avg-Loss: 0.5641
Epoch: 11	Batch: 350	Avg-Loss: 0.5695
Epoch: 11	Batch: 400	Avg-Loss: 0.5913
Epoch: 11	Batch: 450	Avg-Loss: 0.6100
Epoch: 11	Batch: 500	Avg-Loss: 0.6136
Epoch: 11	Batch: 550	Avg-Loss: 0.6256
Epoch: 11	Batch: 600	Avg-Loss: 0.6491
Epoch: 11	Batch: 650	Avg-Loss: 0.6658
Epoch: 11	Batch: 700	Avg-Loss: 0.7008
Epoch: 11	Batch: 750	Avg-Loss: 0.6901
Epoch: 11	Batch: 800	Avg-Loss: 0.6968
Epoch: 11	Batch: 850	Avg-Loss: 0.7438
Epoch: 11	Batch: 900	Avg-Loss: 0.7215
Epoch: 11	Batch: 950	Avg-Loss: 0.7467
Epoch: 11	Batch: 1000	Avg-Loss: 0.7441
Epoch: 11	Batch: 1050	Avg-Loss: 0.7551
Epoch: 11	Batch: 1100	Avg-Loss: 0.7260
Epoch: 11	Batch: 1150	Avg-Loss: 0.7454
Epoch: 11	Batch: 1200	Avg-Loss: 0.7946
Epoch: 11	Batch: 1250	Avg-Loss: 0.7563
Epoch: 11	Batch: 1300	Avg-L


 37%|███▋      | 11/30 [4:27:03<7:37:48, 1445.72s/it][A

Roc score: 0.9183
Epoch: 12	Batch: 50	Avg-Loss: 0.5438
Epoch: 12	Batch: 100	Avg-Loss: 0.4784
Epoch: 12	Batch: 150	Avg-Loss: 0.5064
Epoch: 12	Batch: 200	Avg-Loss: 0.4948
Epoch: 12	Batch: 250	Avg-Loss: 0.5121
Epoch: 12	Batch: 300	Avg-Loss: 0.5235
Epoch: 12	Batch: 350	Avg-Loss: 0.5332
Epoch: 12	Batch: 400	Avg-Loss: 0.5511
Epoch: 12	Batch: 450	Avg-Loss: 0.5357
Epoch: 12	Batch: 500	Avg-Loss: 0.5676
Epoch: 12	Batch: 550	Avg-Loss: 0.5975
Epoch: 12	Batch: 600	Avg-Loss: 0.5891
Epoch: 12	Batch: 650	Avg-Loss: 0.5642
Epoch: 12	Batch: 700	Avg-Loss: 0.6341
Epoch: 12	Batch: 750	Avg-Loss: 0.5955
Epoch: 12	Batch: 800	Avg-Loss: 0.6549
Epoch: 12	Batch: 850	Avg-Loss: 0.6276
Epoch: 12	Batch: 900	Avg-Loss: 0.6502
Epoch: 12	Batch: 950	Avg-Loss: 0.7094
Epoch: 12	Batch: 1000	Avg-Loss: 0.6628
Epoch: 12	Batch: 1050	Avg-Loss: 0.6906
Epoch: 12	Batch: 1100	Avg-Loss: 0.6994
Epoch: 12	Batch: 1150	Avg-Loss: 0.7274
Epoch: 12	Batch: 1200	Avg-Loss: 0.7444
Epoch: 12	Batch: 1250	Avg-Loss: 0.7411
Epoch: 12	Batch: 1300	Avg-L


 40%|████      | 12/30 [4:51:02<7:13:04, 1443.58s/it][A

Roc score: 0.9112
Epoch: 13	Batch: 50	Avg-Loss: 0.5101
Epoch: 13	Batch: 100	Avg-Loss: 0.4512
Epoch: 13	Batch: 150	Avg-Loss: 0.4624
Epoch: 13	Batch: 200	Avg-Loss: 0.4693
Epoch: 13	Batch: 250	Avg-Loss: 0.4774
Epoch: 13	Batch: 300	Avg-Loss: 0.4727
Epoch: 13	Batch: 350	Avg-Loss: 0.5036
Epoch: 13	Batch: 400	Avg-Loss: 0.5179
Epoch: 13	Batch: 450	Avg-Loss: 0.5033
Epoch: 13	Batch: 500	Avg-Loss: 0.5405
Epoch: 13	Batch: 550	Avg-Loss: 0.5603
Epoch: 13	Batch: 600	Avg-Loss: 0.5895
Epoch: 13	Batch: 650	Avg-Loss: 0.5647
Epoch: 13	Batch: 700	Avg-Loss: 0.6122
Epoch: 13	Batch: 750	Avg-Loss: 0.6126
Epoch: 13	Batch: 800	Avg-Loss: 0.6028
Epoch: 13	Batch: 850	Avg-Loss: 0.6088
Epoch: 13	Batch: 900	Avg-Loss: 0.6224
Epoch: 13	Batch: 950	Avg-Loss: 0.6166
Epoch: 13	Batch: 1000	Avg-Loss: 0.6308
Epoch: 13	Batch: 1050	Avg-Loss: 0.6648
Epoch: 13	Batch: 1100	Avg-Loss: 0.6535
Epoch: 13	Batch: 1150	Avg-Loss: 0.6493
Epoch: 13	Batch: 1200	Avg-Loss: 0.6679
Epoch: 13	Batch: 1250	Avg-Loss: 0.6695
Epoch: 13	Batch: 1300	Avg-L

In [None]:
# `test_classify` is not defined anywhere in this notebook; the evaluation
# helper defined alongside train_closs is test_classify_closs.
test_classify_closs(model, val_dataloader)

In [None]:
# Print the module hierarchy of the trained model.
print(model)

In [None]:
# Persist only the model's state_dict; the center-loss parameters and optimizer state are not saved.
torch.save(model.state_dict(), './18epochswd5e-5-adam-lranneal-resize.pt')

In [None]:
# NOTE(review): this file name differs from the one written by the save cell —
# confirm it is the intended checkpoint for the current architecture.
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# from GPU tensors also loads on a CPU-only machine.
model.load_state_dict(torch.load('./26epochswd5e-5-adam-lranneal.pt', map_location=device))

In [None]:
# `test_verify` is not defined anywhere in this notebook; the ROC-AUC helper
# defined alongside train_closs is test_verify_closs.
test_verify_closs(model, ver_dataloader)

In [None]:
class VerificationTestDataset(Dataset):
    """Unlabelled image pairs read from a whitespace-separated text file.

    Each non-blank line of ``filename`` must hold two fields:
    ``<image path 1> <image path 2>``.

    Items are ``(img1, img2)`` where both images have been passed through the
    notebook-level ``transform``.
    """
    def __init__(self, filename):
        self.file1_list = []
        self.file2_list = []

        # `with` guarantees the list file is closed even if a line is malformed.
        with open(filename, "r") as infile:
            for line in infile:
                if not line.strip():
                    continue  # tolerate blank lines (e.g. a trailing newline)
                imfile1, imfile2 = line.split()
                self.file1_list.append(imfile1)
                self.file2_list.append(imfile2)

    def __len__(self):
        return len(self.file1_list)

    def __getitem__(self, index):
        img1 = transform(self._load_rgb(self.file1_list[index]))
        img2 = transform(self._load_rgb(self.file2_list[index]))
        return img1, img2

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` and return a 3-channel copy, closing the file handle."""
        # The transform normalises three channels, so force RGB here; this is
        # also what ImageFolder's default loader does for the training data.
        with Image.open(path) as img:
            return img.convert('RGB')

In [None]:
# Unlabelled test pairs; their paths are reused later to build the submission ids.
ver_test_set = VerificationTestDataset("verification_pairs_test.txt")

In [None]:
# shuffle=False keeps scores in file order, which the submission cell relies on when zipping with the path lists.
vertest_dataloader = DataLoader(ver_test_set, batch_size=32, shuffle=False, num_workers=4, pin_memory=False)

In [None]:
def test_final(model, test_loader):
    """Return one cosine-similarity score per unlabelled image pair.

    Reads the notebook global ``device``.

    Args:
        model: network whose third output is the embedding to compare
            (the same output ``test_verify_closs`` uses).
        test_loader: yields ``(images1, images2)`` batches.

    Returns:
        1-D ``numpy`` array with one score per pair, in loader order.

    The model is moved to ``device`` and left in eval mode.
    """
    cos = nn.CosineSimilarity(dim=1)  # built once instead of once per batch
    similarities = []

    model.to(device)
    model.eval()
    with torch.no_grad():
        for feats1, feats2 in test_loader:
            feats1, feats2 = feats1.to(device), feats2.to(device)

            # Use the model's own forward pass. Stripping the last child of
            # Conv_Net and chaining the rest in a Sequential would feed the
            # ResNet logits into linear_closs, whose input width is the
            # backbone feature size, and fail with a shape mismatch.
            _, _, output1 = model(feats1)
            _, _, output2 = model(feats2)

            # Flatten to (batch, features) so there is exactly one score per pair.
            sim = cos(output1.flatten(start_dim=1), output2.flatten(start_dim=1))
            similarities.extend(sim.cpu().numpy())

    return np.array(similarities)

In [None]:
# Score every test pair with the trained model.
test_sims = test_final(model, vertest_dataloader)

In [None]:
# Flatten to 1-D without hard-coding the number of test pairs.
test_sims = np.reshape(test_sims, (-1,))

In [None]:
# Inspect the similarity scores.
test_sims

In [None]:
import pandas as pd

# Submission columns: 'Id' is the two image paths joined by a single space,
# 'Category' is the similarity score for that pair.
data = {
    'Id': ["{} {}".format(first, second)
           for first, second in zip(ver_test_set.file1_list, ver_test_set.file2_list)],
    'Category': test_sims,
}



In [None]:
# Tabulate the submission columns.
df = pd.DataFrame(data)

In [None]:
# Inspect the submission table before writing it out.
df

In [None]:
# Write the submission with a header row and without the DataFrame index column.
df.to_csv("submission.csv", header=True, index=False)