In [1]:
%matplotlib inline

import torch
import numpy as np
import pandas as pd
from copy import deepcopy
from tabulate import tabulate
from torch import optim
import torchvision.utils
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn import CrossEntropyLoss
import torchvision.transforms as transforms
from torch.utils.data import DataLoader,random_split

import config
from utils import imshow
from models import SiameseNetwork, CNN
from training import trainSiamese,inferenceSiamese, trainCNN
from datasets import SiameseNetworkDataset, CNNDataset,generate_csv_compare
from loss_functions import ContrastiveLoss

# generate_csv(config.training_dir)

import os

# Model checkpoints are later loaded from ./state_dict (see the evaluation cell),
# so make sure the directory exists. exist_ok=True makes this idempotent and
# avoids the check-then-create race of a separate os.path.exists() test.
os.makedirs('state_dict', exist_ok=True)

## Generate Dataset with a few images per class

In [2]:
# Write the three CSV index files used by the rest of this notebook:
# one for the Siamese dataset, one for the CNN dataset and one for the test set.
# num_per_class presumably limits how many images are kept per identity
# (see the section heading) -- confirm against datasets.generate_csv_compare.
generate_csv_compare(
    config.training_dir,
    config.compare_siamese_csv,
    config.compare_cnn_csv,
    config.compare_test_csv,
    num_per_class=4,
)

Data directory:  ../../datasets/AT&T Database of Faces/faces/training


## Training Siamese Network

In [None]:
# Build the Siamese dataset and split it 90% / 10% into train and validation sets.
# Normalize(0, 1) subtracts 0 and divides by 1, i.e. it leaves ToTensor() values as they are.
siamese_transform = transforms.Compose([
    transforms.Resize((config.img_height, config.img_width)),
    transforms.ToTensor(),
    transforms.Normalize(0, 1),
])

siamese_dataset = SiameseNetworkDataset(config.compare_siamese_csv,
                                        transform=siamese_transform,
                                        should_invert=False)

# Use len() rather than calling the __len__ dunder directly.
num_total = len(siamese_dataset)
num_train = round(0.9 * num_total)
num_validate = num_total - num_train  # remainder, so the two sizes always sum to num_total

# NOTE(review): random_split draws from the global torch RNG; without a seed the
# train/validation split (and therefore the logged metrics) differs between runs.
siamese_train, siamese_valid = random_split(siamese_dataset, [num_train, num_validate])

siamese_train_dataloader = DataLoader(siamese_train,
                                      shuffle=True,
                                      num_workers=8,
                                      batch_size=config.train_batch_size)

# Validation is evaluated one pair at a time (batch_size=1).
siamese_valid_dataloader = DataLoader(siamese_valid,
                                      shuffle=True,
                                      num_workers=8,
                                      batch_size=1)

# Training: contrastive loss, Adam, step-wise learning-rate decay.
netS = SiameseNetwork().cuda()
criterionS = ContrastiveLoss()
optimizer = optim.Adam(netS.parameters(), lr=config.learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, config.step_size, config.gamma)

# dict_nameS is the checkpoint file name that the evaluation cell loads from ./state_dict.
netS, train_loss_historyS, valid_loss_historyS, dict_nameS = trainSiamese(
    netS, criterionS, optimizer, scheduler,
    siamese_train_dataloader, siamese_valid_dataloader,
    config.train_number_epochs, do_show=True)

Epoch  0  training


100%|█████████████████████████████████████████████| 2/2 [00:00<00:00,  2.08it/s]


Epoch  0  validating


100%|███████████████████████████████████████████| 21/21 [00:00<00:00, 57.31it/s]


+1/-1 ratrio:0.52/0.48
+1 features max:3.69	 min:0.78	 mean:2.02	 median 1.69
-1 features max:3.77	 min:1.97	 mean:3.04	 median 3.23
Feature = 0	Threshold = 2.286705	Polorization = 1
Epoch-0	 Train loss: 1.8813e+00	 Valid loss: 2.6000e+00	 Valid error: 0.1905
Epoch  1  training


100%|█████████████████████████████████████████████| 2/2 [00:00<00:00,  3.69it/s]


Epoch  1  validating


100%|███████████████████████████████████████████| 21/21 [00:00<00:00, 60.94it/s]


+1/-1 ratrio:0.52/0.48
+1 features max:2.02	 min:0.54	 mean:1.14	 median 1.10
-1 features max:3.19	 min:1.36	 mean:2.27	 median 2.39
Feature = 0	Threshold = 1.295573	Polorization = 1
Epoch-1	 Train loss: 3.0861e+00	 Valid loss: 7.8085e-01	 Valid error: 0.1429
new model saved
Epoch  2  training


100%|█████████████████████████████████████████████| 2/2 [00:00<00:00,  3.76it/s]


Epoch  2  validating


100%|███████████████████████████████████████████| 21/21 [00:00<00:00, 59.98it/s]


+1/-1 ratrio:0.52/0.48
+1 features max:1.50	 min:0.50	 mean:0.91	 median 0.85
-1 features max:2.79	 min:1.32	 mean:2.23	 median 2.36
Feature = 0	Threshold = 1.247275	Polorization = 1
Epoch-2	 Train loss: 2.0570e+00	 Valid loss: 4.7689e-01	 Valid error: 0.0476
new model saved
Epoch  3  training


100%|█████████████████████████████████████████████| 2/2 [00:00<00:00,  3.62it/s]


Epoch  3  validating


100%|███████████████████████████████████████████| 21/21 [00:00<00:00, 60.60it/s]


+1/-1 ratrio:0.52/0.48
+1 features max:1.31	 min:0.42	 mean:0.81	 median 0.83
-1 features max:2.95	 min:1.24	 mean:2.13	 median 2.26
Feature = 0	Threshold = 1.159618	Polorization = 1
Epoch-3	 Train loss: 1.0245e+00	 Valid loss: 3.8605e-01	 Valid error: 0.0476
new model saved
Epoch  4  training


100%|█████████████████████████████████████████████| 2/2 [00:00<00:00,  3.67it/s]


Epoch  4  validating


100%|███████████████████████████████████████████| 21/21 [00:00<00:00, 61.66it/s]


+1/-1 ratrio:0.52/0.48
+1 features max:1.06	 min:0.31	 mean:0.74	 median 0.76
-1 features max:3.04	 min:1.22	 mean:2.11	 median 2.13
Feature = 0	Threshold = 1.141061	Polorization = 1
Epoch-4	 Train loss: 7.0866e-01	 Valid loss: 3.1503e-01	 Valid error: 0.0000
new model saved
Epoch  5  training


100%|█████████████████████████████████████████████| 2/2 [00:00<00:00,  3.54it/s]


Epoch  5  validating


100%|███████████████████████████████████████████| 21/21 [00:00<00:00, 59.38it/s]


+1/-1 ratrio:0.52/0.48
+1 features max:0.86	 min:0.22	 mean:0.62	 median 0.66
-1 features max:3.07	 min:1.24	 mean:2.09	 median 2.02
Feature = 0	Threshold = 1.051778	Polorization = 1
Epoch-5	 Train loss: 5.0828e-01	 Valid loss: 2.2378e-01	 Valid error: 0.0000
new model saved
Epoch  6  training


100%|█████████████████████████████████████████████| 2/2 [00:00<00:00,  3.75it/s]


Epoch  6  validating


100%|███████████████████████████████████████████| 21/21 [00:00<00:00, 59.58it/s]


+1/-1 ratrio:0.52/0.48
+1 features max:0.82	 min:0.20	 mean:0.53	 median 0.57
-1 features max:2.94	 min:1.25	 mean:2.03	 median 1.98
Feature = 0	Threshold = 1.036302	Polorization = 1
Epoch-6	 Train loss: 3.7266e-01	 Valid loss: 1.6859e-01	 Valid error: 0.0000
new model saved
Epoch  7  training


100%|█████████████████████████████████████████████| 2/2 [00:00<00:00,  3.79it/s]


Epoch  7  validating


100%|███████████████████████████████████████████| 21/21 [00:00<00:00, 56.70it/s]


+1/-1 ratrio:0.52/0.48
+1 features max:0.90	 min:0.18	 mean:0.48	 median 0.54
-1 features max:2.72	 min:1.25	 mean:1.94	 median 1.93
Feature = 0	Threshold = 1.072485	Polorization = 1
Epoch-7	 Train loss: 2.5396e-01	 Valid loss: 1.4027e-01	 Valid error: 0.0000
new model saved
Epoch  8  training


100%|█████████████████████████████████████████████| 2/2 [00:00<00:00,  3.60it/s]


Epoch  8  validating


100%|███████████████████████████████████████████| 21/21 [00:00<00:00, 60.31it/s]


+1/-1 ratrio:0.52/0.48
+1 features max:0.83	 min:0.19	 mean:0.44	 median 0.48
-1 features max:2.73	 min:1.25	 mean:1.86	 median 1.84
Feature = 0	Threshold = 1.040514	Polorization = 1
Epoch-8	 Train loss: 1.8681e-01	 Valid loss: 1.2006e-01	 Valid error: 0.0000
new model saved
Epoch  9  training


100%|█████████████████████████████████████████████| 2/2 [00:00<00:00,  3.53it/s]


Epoch  9  validating


100%|███████████████████████████████████████████| 21/21 [00:00<00:00, 63.07it/s]


+1/-1 ratrio:0.52/0.48
+1 features max:0.70	 min:0.17	 mean:0.41	 median 0.43
-1 features max:2.74	 min:1.15	 mean:1.78	 median 1.74
Feature = 0	Threshold = 0.925272	Polorization = 1
Epoch-9	 Train loss: 1.6611e-01	 Valid loss: 1.0276e-01	 Valid error: 0.0000
new model saved
Epoch  10  training


100%|█████████████████████████████████████████████| 2/2 [00:00<00:00,  3.61it/s]


Epoch  10  validating


100%|███████████████████████████████████████████| 21/21 [00:00<00:00, 59.45it/s]


+1/-1 ratrio:0.52/0.48
+1 features max:0.61	 min:0.15	 mean:0.39	 median 0.43
-1 features max:2.81	 min:1.05	 mean:1.73	 median 1.66
Feature = 0	Threshold = 0.831309	Polorization = 1
Epoch-10	 Train loss: 1.0522e-01	 Valid loss: 9.0969e-02	 Valid error: 0.0000
new model saved
Epoch  11  training


100%|█████████████████████████████████████████████| 2/2 [00:00<00:00,  3.70it/s]


Epoch  11  validating


100%|███████████████████████████████████████████| 21/21 [00:00<00:00, 60.80it/s]


+1/-1 ratrio:0.52/0.48
+1 features max:0.54	 min:0.15	 mean:0.36	 median 0.37
-1 features max:2.91	 min:1.02	 mean:1.71	 median 1.59
Feature = 0	Threshold = 0.775260	Polorization = 1
Epoch-11	 Train loss: 1.0049e-01	 Valid loss: 7.8440e-02	 Valid error: 0.0000
new model saved
Epoch  12  training


100%|█████████████████████████████████████████████| 2/2 [00:00<00:00,  3.72it/s]


Epoch  12  validating


100%|███████████████████████████████████████████| 21/21 [00:00<00:00, 53.28it/s]


+1/-1 ratrio:0.52/0.48
+1 features max:0.49	 min:0.14	 mean:0.33	 median 0.34
-1 features max:2.94	 min:1.01	 mean:1.69	 median 1.54
Feature = 0	Threshold = 0.753760	Polorization = 1
Epoch-12	 Train loss: 8.0217e-02	 Valid loss: 6.4323e-02	 Valid error: 0.0000
new model saved
Epoch  13  training


100%|█████████████████████████████████████████████| 2/2 [00:00<00:00,  3.64it/s]


Epoch  13  validating


100%|███████████████████████████████████████████| 21/21 [00:00<00:00, 58.10it/s]


+1/-1 ratrio:0.52/0.48
+1 features max:0.45	 min:0.14	 mean:0.32	 median 0.33
-1 features max:2.91	 min:1.01	 mean:1.67	 median 1.50
Feature = 0	Threshold = 0.731186	Polorization = 1
Epoch-13	 Train loss: 7.5649e-02	 Valid loss: 5.8767e-02	 Valid error: 0.0000
new model saved
Epoch  14  training


100%|█████████████████████████████████████████████| 2/2 [00:00<00:00,  3.64it/s]


Epoch  14  validating


100%|███████████████████████████████████████████| 21/21 [00:00<00:00, 61.08it/s]


+1/-1 ratrio:0.52/0.48
+1 features max:0.41	 min:0.13	 mean:0.30	 median 0.30
-1 features max:2.86	 min:0.97	 mean:1.65	 median 1.47
Feature = 0	Threshold = 0.689867	Polorization = 1
Epoch-14	 Train loss: 6.2531e-02	 Valid loss: 5.2869e-02	 Valid error: 0.0000
new model saved
Epoch  15  training


100%|█████████████████████████████████████████████| 2/2 [00:00<00:00,  3.67it/s]


Epoch  15  validating


100%|███████████████████████████████████████████| 21/21 [00:00<00:00, 59.06it/s]


+1/-1 ratrio:0.52/0.48
+1 features max:0.40	 min:0.12	 mean:0.29	 median 0.28
-1 features max:2.87	 min:0.93	 mean:1.63	 median 1.43
Feature = 0	Threshold = 0.669250	Polorization = 1
Epoch-15	 Train loss: 4.7165e-02	 Valid loss: 4.9372e-02	 Valid error: 0.0000
new model saved
Epoch  16  training


100%|█████████████████████████████████████████████| 2/2 [00:00<00:00,  3.73it/s]


Epoch  16  validating


100%|███████████████████████████████████████████| 21/21 [00:00<00:00, 59.47it/s]


+1/-1 ratrio:0.52/0.48
+1 features max:0.41	 min:0.12	 mean:0.28	 median 0.26
-1 features max:2.88	 min:0.90	 mean:1.62	 median 1.40
Feature = 0	Threshold = 0.655669	Polorization = 1
Epoch-16	 Train loss: 3.2432e-02	 Valid loss: 4.6295e-02	 Valid error: 0.0000
new model saved
Epoch  17  training


100%|█████████████████████████████████████████████| 2/2 [00:00<00:00,  3.63it/s]


Epoch  17  validating


100%|███████████████████████████████████████████| 21/21 [00:00<00:00, 61.48it/s]


+1/-1 ratrio:0.52/0.48
+1 features max:0.40	 min:0.11	 mean:0.27	 median 0.25
-1 features max:2.88	 min:0.87	 mean:1.61	 median 1.38
Feature = 0	Threshold = 0.638445	Polorization = 1
Epoch-17	 Train loss: 4.5348e-02	 Valid loss: 4.3451e-02	 Valid error: 0.0000
new model saved
Epoch  18  training


100%|█████████████████████████████████████████████| 2/2 [00:00<00:00,  3.59it/s]


Epoch  18  validating


100%|███████████████████████████████████████████| 21/21 [00:00<00:00, 60.32it/s]


+1/-1 ratrio:0.52/0.48
+1 features max:0.38	 min:0.11	 mean:0.26	 median 0.24
-1 features max:2.86	 min:0.86	 mean:1.60	 median 1.37
Feature = 0	Threshold = 0.619557	Polorization = 1
Epoch-18	 Train loss: 2.6716e-02	 Valid loss: 4.0101e-02	 Valid error: 0.0000
new model saved
Epoch  19  training


100%|█████████████████████████████████████████████| 2/2 [00:00<00:00,  3.64it/s]


Epoch  19  validating


100%|███████████████████████████████████████████| 21/21 [00:00<00:00, 60.19it/s]


+1/-1 ratrio:0.52/0.48
+1 features max:0.37	 min:0.11	 mean:0.25	 median 0.24
-1 features max:2.84	 min:0.85	 mean:1.60	 median 1.38
Feature = 0	Threshold = 0.610050	Polorization = 1
Epoch-19	 Train loss: 2.4919e-02	 Valid loss: 3.7365e-02	 Valid error: 0.0000
new model saved
Epoch  20  training


100%|█████████████████████████████████████████████| 2/2 [00:00<00:00,  3.53it/s]


Epoch  20  validating


100%|███████████████████████████████████████████| 21/21 [00:00<00:00, 59.09it/s]


+1/-1 ratrio:0.52/0.48
+1 features max:0.36	 min:0.10	 mean:0.25	 median 0.24
-1 features max:2.82	 min:0.86	 mean:1.59	 median 1.39
Feature = 0	Threshold = 0.607251	Polorization = 1
Epoch-20	 Train loss: 2.0880e-02	 Valid loss: 3.5875e-02	 Valid error: 0.0000
new model saved
Epoch  21  training


100%|█████████████████████████████████████████████| 2/2 [00:00<00:00,  3.66it/s]


Epoch  21  validating


100%|███████████████████████████████████████████| 21/21 [00:00<00:00, 55.54it/s]


+1/-1 ratrio:0.52/0.48
+1 features max:0.36	 min:0.09	 mean:0.24	 median 0.25
-1 features max:2.81	 min:0.86	 mean:1.58	 median 1.39
Feature = 0	Threshold = 0.605594	Polorization = 1
Epoch-21	 Train loss: 1.7758e-02	 Valid loss: 3.4515e-02	 Valid error: 0.0000
new model saved
Epoch  22  training


100%|█████████████████████████████████████████████| 2/2 [00:00<00:00,  3.46it/s]


Epoch  22  validating


100%|███████████████████████████████████████████| 21/21 [00:00<00:00, 60.73it/s]


+1/-1 ratrio:0.52/0.48
+1 features max:0.36	 min:0.09	 mean:0.24	 median 0.24
-1 features max:2.78	 min:0.86	 mean:1.57	 median 1.39
Feature = 0	Threshold = 0.608893	Polorization = 1
Epoch-22	 Train loss: 1.1797e-02	 Valid loss: 3.2775e-02	 Valid error: 0.0000
new model saved
Epoch  23  training


100%|█████████████████████████████████████████████| 2/2 [00:00<00:00,  3.69it/s]


Epoch  23  validating


100%|███████████████████████████████████████████| 21/21 [00:00<00:00, 62.95it/s]


+1/-1 ratrio:0.52/0.48
+1 features max:0.35	 min:0.09	 mean:0.23	 median 0.23
-1 features max:2.77	 min:0.86	 mean:1.56	 median 1.38
Feature = 0	Threshold = 0.606748	Polorization = 1
Epoch-23	 Train loss: 1.1174e-02	 Valid loss: 3.0639e-02	 Valid error: 0.0000
new model saved
Epoch  24  training


100%|█████████████████████████████████████████████| 2/2 [00:00<00:00,  3.81it/s]


Epoch  24  validating


100%|███████████████████████████████████████████| 21/21 [00:00<00:00, 45.16it/s]


+1/-1 ratrio:0.52/0.48
+1 features max:0.35	 min:0.09	 mean:0.22	 median 0.21
-1 features max:2.76	 min:0.85	 mean:1.55	 median 1.37
Feature = 0	Threshold = 0.599052	Polorization = 1
Epoch-24	 Train loss: 1.1397e-02	 Valid loss: 2.8461e-02	 Valid error: 0.0000
new model saved
Epoch  25  training


100%|█████████████████████████████████████████████| 2/2 [00:00<00:00,  3.66it/s]


Epoch  25  validating


100%|███████████████████████████████████████████| 21/21 [00:00<00:00, 59.15it/s]


+1/-1 ratrio:0.52/0.48
+1 features max:0.34	 min:0.08	 mean:0.21	 median 0.21
-1 features max:2.77	 min:0.84	 mean:1.54	 median 1.36
Feature = 0	Threshold = 0.591710	Polorization = 1
Epoch-25	 Train loss: 1.2210e-02	 Valid loss: 2.6768e-02	 Valid error: 0.0000
new model saved
Epoch  26  training


  0%|                                                     | 0/2 [00:00<?, ?it/s]

## Train CNN with Cross Entropy Loss

In [None]:
# Dataset for the plain CNN classifier: resize, convert to tensor, Normalize(0, 1).
cnn_dataset = CNNDataset(
    config.compare_cnn_csv,
    transform=transforms.Compose([
        transforms.Resize((config.img_height, config.img_width)),
        transforms.ToTensor(),
        transforms.Normalize(0, 1),
    ]),
    should_invert=False,
)

# Shuffled mini-batches for training.
cnn_train_dataloader = DataLoader(
    cnn_dataset,
    batch_size=config.train_batch_size,
    shuffle=True,
    num_workers=8,
)

# Model, loss, optimizer and LR schedule for the CNN run.
# Note: `optimizer` and `scheduler` rebind the names used by the Siamese cell.
netC = CNN().cuda()
criterionC = CrossEntropyLoss()
optimizer = optim.Adam(netC.parameters(), lr=config.learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, config.step_size, config.gamma)

In [None]:
# Train the CNN classifier. trainCNN returns the model, the training-loss
# history and the checkpoint file name (dict_nameC), which the evaluation cell
# loads from ./state_dict.
netC, train_loss_historyC, dict_nameC = trainCNN(
    netC,
    criterionC,
    optimizer,
    scheduler,
    cnn_train_dataloader,
    config.train_number_epochs,
    do_show=True,
)

In [None]:
# ---------------------------------------------------------------------------
# Evaluation: compare the Siamese network (nearest class by mean embedding
# distance) against the CNN classifier (argmax of its output) on the test set.
# ---------------------------------------------------------------------------
test_dataset = CNNDataset(config.compare_test_csv,
                          transform=transforms.Compose([
                              transforms.Resize((config.img_height, config.img_width)),
                              transforms.ToTensor(),
                              transforms.Normalize(0, 1)]),
                          should_invert=False)

# shuffle=False keeps the sample order identical across the passes below, so
# true_labels, inferenceS and inferenceC line up index by index.
test_dataloader = DataLoader(test_dataset,
                             shuffle=False,
                             num_workers=1,
                             batch_size=1)
true_labels = [int(label) for _, label in test_dataloader]

# --- Siamese network -------------------------------------------------------
netS = SiameseNetwork().cuda()
netS.load_state_dict(torch.load(os.path.join("state_dict", dict_nameS)))
netS.eval()

# The CNN training images serve as the labelled gallery for the Siamese model.
siamese_precalculate_dataloader = DataLoader(cnn_dataset,
                                             shuffle=False,
                                             num_workers=1,
                                             batch_size=1)

# Group gallery embeddings by their actual label instead of relying on the
# samples being sorted by label and starting at 0. The previous run-length
# grouping appended an empty group (-> division by zero) when the first label
# was not 0, and mapped argmin indices to wrong classes for unsorted data.
embeddings_by_class = {}
with torch.no_grad():
    for data, label in siamese_precalculate_dataloader:
        v = netS.forward_once(data.cuda()).cpu().numpy()[0]
        embeddings_by_class.setdefault(int(label), []).append(v)

class_labels = sorted(embeddings_by_class)
# feature_vectors[i] holds the embeddings of class class_labels[i].
feature_vectors = [embeddings_by_class[c] for c in class_labels]

# Upload the gallery once rather than once per test image.
gallery = [[torch.tensor(feature).cuda() for feature in vectors]
           for vectors in feature_vectors]

inferenceS = []
with torch.no_grad():
    for data, _ in test_dataloader:
        current_feature = netS.forward_once(data.cuda())
        dissims = []
        for vectors in gallery:
            total = sum(float(F.pairwise_distance(current_feature, feature))
                        for feature in vectors)
            # Mean distance from the test embedding to this class's gallery.
            dissims.append(total / len(vectors))
        # Predict the class whose gallery is closest on average.
        inferenceS.append(class_labels[int(np.argmin(dissims))])

# --- CNN classifier --------------------------------------------------------
# The checkpoint dict_nameC was produced by trainCNN from a CNN instance, so it
# must be loaded into CNN (not SiameseNetwork) and queried through the module's
# regular forward call.
# NOTE(review): assumes CNN is a standard nn.Module whose forward() takes the
# image batch and returns per-class scores -- confirm against models.CNN.
netC = CNN().cuda()
netC.load_state_dict(torch.load(os.path.join("state_dict", dict_nameC)))
netC.eval()

inferenceC = []
with torch.no_grad():
    for data, _ in test_dataloader:
        scores = netC(data.cuda())
        inferenceC.append(int(np.argmax(scores.cpu().numpy())))

# --- Report ----------------------------------------------------------------
accS = sum(t == p for t, p in zip(true_labels, inferenceS)) / len(true_labels)
accC = sum(t == p for t, p in zip(true_labels, inferenceC)) / len(true_labels)

print(tabulate(np.array([true_labels, inferenceS, inferenceC]).T, headers=["true labels", "Siamese", "CNN"]))

print("Siamese accuracy: %.4f"%(accS))
print("Convnet accuracy: %.4f"%(accC))