# GLOBAL VARIABLES 

In [111]:
import os

# Dataset location: every picture is read from <cwd>/datasets/<FILE>
FILE = "Aberdeen"  # folder name; a ".zip" archive was once considered here
EXTENSION = ".jpg"  # only files with this suffix are loaded
PATH = os.path.join(os.getcwd(), "datasets", FILE)

# Optimisation hyper-parameters
BATCH_SIZE = 16
NUM_EPOCH = 10
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-4

# Run-mode switches
DO_LEARN = False   # True: train + evaluate; False: load the saved model and predict once
SAVE_MODEL = True  # write the model to disk after training

# MANAGEMENT OF THE DATA 

In [70]:
import numpy as np
import random
from random import shuffle
from string import digits

import warnings
warnings.filterwarnings('ignore')

import glob
import random 

import os.path as osp
from PIL import Image

import matplotlib.pyplot as plt
%matplotlib inline
import torch
import torchvision
import torchvision.transforms as transforms
import torch.utils.data


"""
STILL TODO: 
    - Better management of the dataset when composing the triplets
    - Build separate training and testing sets
"""
class Face_DS(torch.utils.data.Dataset):
    """Triplet dataset of face pictures for training a siamese network.

    Every file in ``root`` matching ``'*' + EXTENSION`` is opened, converted to
    RGB and passed through ``transform``. The identity (label) of a picture is
    its file name without directory and extension and with every digit removed,
    e.g. ``"Alice12.jpg"`` -> ``"Alice"``.

    One triplet ``[reference, positive, negative]`` is built per picture:
    ``positive`` is a picture of the same identity, drawn without replacement
    from a shuffled order (so it can be the reference itself), and ``negative``
    is a random picture of another identity. Every triplet has the target
    ``[1, 0]``.

    Args:
        root: directory holding the pictures (defaults to ``PATH``).
        train: stored as ``self.train`` but not used yet -- no train/test split
            is performed, so both values yield the same kind of triplets.
        transform: callable applied to each PIL image; defaults to
            ``transforms.ToTensor()``.

    Raises:
        ValueError: if all pictures belong to a single identity, because no
            negative example can then be drawn.
    """

    def __init__(self, root=PATH, train=False, transform=None):
        print("A Face Dataset is building ... ")
        self.root = root
        self.train = train  # TODO: use it to select a training or testing split
        self.transform = transforms.ToTensor() if transform is None else transform

        faces_dic = self._load_faces(root)
        self.train_data, train_labels = self._build_triplets(faces_dic)
        # One [1, 0] row per triplet; an empty folder gives an empty tensor
        self.train_labels = torch.tensor(train_labels)

    def _load_faces(self, root):
        """Return ``{label: [transformed image, ...]}`` for every picture in ``root``."""
        faces_dic = {}
        # glob order is filesystem-dependent; sorting makes the label order
        # (and therefore seeded sampling) reproducible
        for fn in sorted(glob.glob(osp.join(root, '*' + EXTENSION))):
            label = self._label_from_filename(fn)
            # The context manager releases the file handle once the image is converted
            with Image.open(fn) as img:
                formated_image = self.transform(img.convert("RGB"))
            faces_dic.setdefault(label, []).append(formated_image)
        return faces_dic

    @staticmethod
    def _label_from_filename(fn):
        """Map ``'<dir>/Alice12.jpg'`` to ``'Alice'`` (drop directory, extension and digits)."""
        name = osp.basename(fn)
        if EXTENSION and name.endswith(EXTENSION):
            name = name[:-len(EXTENSION)]
        return name.translate(str.maketrans('', '', digits))

    @staticmethod
    def _build_triplets(faces_dic):
        """Return ``(triplets, targets)``: one ``[ref, pos, neg]`` triplet and one ``[1, 0]`` per picture."""
        all_labels = list(faces_dic)
        if len(all_labels) == 1:
            raise ValueError(
                "At least two identities are needed to draw negative examples, "
                "found only {!r}".format(all_labels[0]))

        triplets, targets = [], []
        for label_idx, pictures_list in enumerate(faces_dic.values()):
            # Each picture of this identity serves exactly once as a positive, in random order
            pictures_indexes_pos = list(range(len(pictures_list)))
            shuffle(pictures_indexes_pos)
            # Every other identity is a candidate for the negative example
            labels_neg = all_labels[:label_idx] + all_labels[label_idx + 1:]

            for picture_ref in pictures_list:
                picture_positive = pictures_list[pictures_indexes_pos.pop()]
                label_neg = random.choice(labels_neg)
                picture_negative = random.choice(faces_dic[label_neg])
                # Kept as a list of three tensors on purpose (no torch.stack)
                triplets.append([picture_ref, picture_positive, picture_negative])
                targets.append([1, 0])
        return triplets, targets

    # torch.utils.data.Dataset subclasses must provide __getitem__ and __len__
    def __getitem__(self, index, visualization=False):
        """Return ``([ref, pos, neg], target)`` for triplet ``index``.

        The images are tensors produced by ``transform``. ``target`` is
        ``tensor([1, 0])``: 1 for the (ref, pos) pair (same person) and 0 for
        the (ref, neg) pair (different people).

        ``visualization`` is accepted for backward compatibility but ignored.
        """
        return self.train_data[index], self.train_labels[index]

    def __len__(self):
        """Total number of triplets (one per loaded picture)."""
        return len(self.train_data)

In [3]:
########################
#         TEST         #
########################

#ds = Face_DS()
#ds.__getitem__(1, visualization=False)



# DEFINITION OF THE NEURAL NETWORK 

In [11]:
from torch import nn
from torch import optim
import torch.nn.functional as F


class Net(nn.Module):
    """Siamese classifier deciding whether two images show the same person.

    Both images pass through the same convolutional tower (shared weights).
    The element-wise absolute difference of their 512-d embeddings is then
    mapped to 2 logits. For 3x28x28 inputs the tower yields 256x3x3 = 2304
    features, which is the input size of ``linear1``.
    """

    def __init__(self):
        super().__init__()
        # Modules are registered in the original order so that parameter
        # initialisation and state_dict keys stay the same.
        self.conv1 = nn.Conv2d(3, 64, 7)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(64, 128, 5)
        self.conv3 = nn.Conv2d(128, 256, 5)
        self.linear1 = nn.Linear(2304, 512)
        self.linear2 = nn.Linear(512, 2)

    def _embed(self, images):
        """Run one image batch through the shared tower; returns (batch, 512) features."""
        features = self.pool1(F.relu(self.conv1(images)))
        features = F.relu(self.conv2(features))
        features = F.relu(self.conv3(features))
        flat = features.view(features.shape[0], -1)
        return F.relu(self.linear1(flat))

    def forward(self, data):
        """Return (batch, 2) logits comparing the image batches ``data[0]`` and ``data[1]``."""
        first = self._embed(data[0])
        second = self._embed(data[1])
        return self.linear2(torch.abs(second - first))

# DEFINITION OF THE TRAINING FUNCTION 

In [98]:
def train(model, device, train_loader, epoch, optimizer):
    """Run one training epoch of the siamese ``model`` over ``train_loader``.

    Each batch is ``(data, target)`` where ``data`` is ``[ref, pos, neg]``
    (three image batches) and ``target[:, 0]`` / ``target[:, 1]`` are the
    class indices of the (ref, pos) and (ref, neg) pairs. The optimised loss
    is the sum of the cross-entropy of both pairs.

    Args:
        model: network taking a list of two image batches and returning (batch, 2) logits.
        device: device the batch tensors are moved to.
        train_loader: DataLoader yielding ``(data, target)`` as described above.
        epoch: epoch number, only used in the progress message.
        optimizer: optimizer updating ``model``'s parameters.
    """
    model.train()
    seen = 0  # samples processed before the current batch

    for batch_idx, (data, target) in enumerate(train_loader):
        data = [images.to(device) for images in data]
        target = target.long().to(device)
        # No squeeze(): for a final batch of one sample it would turn the (1,)
        # targets into 0-d tensors, which cross_entropy rejects for (1, 2) logits.
        target_positive = target[:, 0]
        target_negative = target[:, 1]

        optimizer.zero_grad()
        output_positive = model(data[:2])     # (ref, pos) pair
        output_negative = model(data[0:3:2])  # (ref, neg) pair

        loss = (F.cross_entropy(output_positive, target_positive)
                + F.cross_entropy(output_negative, target_negative))
        loss.backward()
        optimizer.step()

        # Report progress every 10 batches, counting samples actually seen
        if batch_idx % 10 == 0:
            dataset_size = len(train_loader.dataset)
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, seen, dataset_size, 100. * seen / dataset_size, loss.item()))
        seen += len(target)

# DEFINITION OF THE TESTING FUNCTION 

In [6]:
def test(model, device, test_loader):
    """Evaluate the siamese ``model`` on ``test_loader`` and print the results.

    Each triplet ``[ref, pos, neg]`` contributes two predictions: the
    (ref, pos) pair scored against ``target[:, 0]`` and the (ref, neg) pair
    scored against ``target[:, 1]``. The printed loss is the mean over
    batches of the summed positive + negative cross-entropy.

    Args:
        model: network taking a list of two image batches and returning (batch, 2) logits.
        device: device the batch tensors are moved to.
        test_loader: DataLoader yielding ``([ref, pos, neg], target)``.

    Returns:
        The accuracy in percent (0.0 when the loader yields nothing).
    """
    model.eval()

    accurate_labels = 0
    all_labels = 0
    total_loss = 0.0
    nb_batches = 0

    with torch.no_grad():
        for data, target in test_loader:
            data = [images.to(device) for images in data]
            target = target.long().to(device)
            # No squeeze(): it would turn a final batch of one sample into
            # 0-d targets, which cross_entropy rejects.
            pairs = ((data[:2], target[:, 0]), (data[0:3:2], target[:, 1]))
            for pair, pair_target in pairs:
                output = model(pair)
                total_loss += F.cross_entropy(output, pair_target).item()
                accurate_labels += (torch.argmax(output, dim=1) == pair_target).sum().item()
                all_labels += len(pair_target)
            nb_batches += 1

    # Guard against an empty loader instead of dividing by zero
    accuracy = 100. * accurate_labels / all_labels if all_labels else 0.0
    mean_loss = total_loss / nb_batches if nb_batches else 0.0
    print('Test accuracy: {}/{} ({:.3f}%)\tLoss: {:.6f}'.format(
        accurate_labels, all_labels, accuracy, mean_loss))
    return accuracy


# The name matches pytest's default ``test*`` pattern; mark it so pytest does
# not try to collect this evaluation helper as a test case.
test.__test__ = False
   

# DEFINITION OF THE ONE-SHOT FUNCTION 

In [119]:
def oneshot(model, device, data):
    """Predict whether the two images in ``data`` show the same person.

    Args:
        model: siamese network taking ``[images_a, images_b]`` and returning (batch, 2) logits.
        device: device the images are moved to before inference.
        data: sequence of two image tensors, each a batch holding a single
            image -- ``.item()`` below raises for larger batches.

    Returns:
        The predicted class as an int (1: same person, 0: different people).
        The caller's ``data`` list is not modified.
    """
    model.eval()

    with torch.no_grad():
        # Build a new list rather than overwriting the caller's entries
        inputs = [images.to(device) for images in data]
        output = model(inputs)
        return torch.squeeze(torch.argmax(output, dim=1)).cpu().item()

# MAIN FUNCTION 

In [122]:
#########################################
#       FUNCTION main                   #
#########################################

def main():
    """Train and save the siamese model, or load it and run a single prediction.

    ``DO_LEARN`` selects the mode. Both modes use the same model file,
    ``<cwd>/siamese_face.pt``, so the model saved by training is the one that
    prediction loads.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # Specifies where the torch.tensor is allocated
    # 28x28 crops match the 256x3x3 = 2304 input features of Net.linear1
    trans = transforms.Compose([transforms.CenterCrop(28), transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
    name_model = "siamese_face"
    extension_model = ".pt"
    # Single path shared by the save (training) and load (prediction) branches
    model_path = os.path.join(os.getcwd(), name_model + extension_model)

    model = Net().to(device)

    if DO_LEARN:

        ##################
        #  training mode
        ##################
        # NOTE: Face_DS does not use `train` yet, so both loaders see triplets
        # built from the same pictures.
        train_loader = torch.utils.data.DataLoader(Face_DS(train=True, transform=trans), batch_size=BATCH_SIZE, shuffle=True)
        test_loader = torch.utils.data.DataLoader(Face_DS(train=False, transform=trans), batch_size=BATCH_SIZE, shuffle=False)

        optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

        for epoch in range(NUM_EPOCH):
            train(model, device, train_loader, epoch, optimizer)
            test(model, device, test_loader)

        if SAVE_MODEL:
            torch.save(model, model_path)
            print("Model is saved!")

    else: # prediction
        dataset = Face_DS(train=True, transform=trans)
        prediction_loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True) # batch_size = Nb of pairs you want to test

        # map_location lets a model saved on a GPU be loaded on a CPU-only machine.
        # NOTE(review): the file holds a fully pickled nn.Module; torch>=2.6 defaults
        # to weights_only=True and will reject it -- pass weights_only=False (trusted
        # files only) or switch to saving a state_dict if loading fails.
        model = torch.load(model_path, map_location=device)

        #####################################################################
        # Data: the reference image (index 0) and the negative image (index 2)
        # of one random triplet, each as a batch of one
        #####################################################################
        triplet, _ = next(iter(prediction_loader))
        data = list(triplet[:3:2])
        same = oneshot(model, device, data)
        if same > 0:
            print('These two images represent the same person')
        else:
            print("These two images don't represent the same person")

In [123]:
if __name__ == '__main__':
    # Entry point: train or predict depending on DO_LEARN
    main()

These two images represent the same person
