# Siamese Neural Network (SNN)

## Importing libraries

In [None]:
%matplotlib inline
import bisect
import os
import random
from math import sqrt

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import PIL.ImageOps
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.datasets as dset
import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.utils
from PIL import Image
from torch import optim
from torch.utils.data import DataLoader, Dataset

## Configuration class

In [None]:
class Config:
    """Hold the notebook's data locations, metadata tables and training settings.

    Args:
        training_dir: folder with the training images (stored as given).
        testing_dir: folder with the testing images (stored as given).
        dir_training: path of the training metadata CSV; its first column becomes the index.
        dir_testing: path of the testing metadata CSV; its first column becomes the index.
        batch_s: mini-batch size used for training.
        n_epochs: number of training epochs.
    """

    def __init__(self, training_dir, testing_dir, dir_training, dir_testing, batch_s, n_epochs):
        self.training_dir, self.testing_dir = training_dir, testing_dir
        # Both metadata tables are loaded eagerly, training first.
        self.dset_training, self.dset_testing = (
            pd.read_csv(csv_path, index_col=0) for csv_path in (dir_training, dir_testing)
        )
        self.train_batch_size, self.train_number_epochs = batch_s, n_epochs

# Paths are relative to the notebook's working directory.
conf = Config(
    training_dir="./data/memes/training/",
    testing_dir="./data/memes/testing/",
    dir_training="./data/spb_training.csv",
    dir_testing="./data/spb_testing.csv",
    batch_s=16,
    n_epochs=150,
)

## Dataset class

In [None]:
class SiameseNetworkDataset(Dataset):
    """Triplet dataset that groups images by how similar their ``favorites`` counts are.

    Each item is a random ``(anchor, positive, negative)`` triplet plus the three
    favorites counts. The ``index`` supplied by the DataLoader is ignored, so
    ``shuffle`` has no effect and ``len(dataset)`` only fixes how many triplets
    make up one epoch.

    Closeness is measured against a tolerance ``tol = int(fav0 / 10)`` (roughly
    10 % of the anchor's favorites):

    * if :meth:`check_fav` is true for the anchor, the positive lies within
      ``tol`` of the anchor and the negative does not;
    * otherwise the positive has clearly more favorites (``fav1 > fav0 + tol``)
      and the negative is any image with ``fav2 <= fav0 + tol``.

    Args:
        imageFolderDataset: object exposing ``imgs``, a list of
            ``(path, class_index)`` tuples (e.g. ``torchvision.datasets.ImageFolder``).
        dset_csv: DataFrame indexed by the integer id taken from each image's
            file name (see :meth:`obtain_id`), with a ``favorites`` column.
        transform: optional callable applied to every RGB ``PIL.Image``.

    Raises:
        ValueError: if ``dset_csv`` has no rows.

    NOTE(review): sampling is rejection-based and assumes every CSV row has a
    matching image in the folder (and vice versa); otherwise a triplet search
    may never terminate, or ``.loc`` may raise ``KeyError``.
    """

    def __init__(self, imageFolderDataset, dset_csv, transform = None):
        self.imageFolderDataset = imageFolderDataset
        self.dset_csv = dset_csv
        self.transform = transform
        # Sorted favorites of every CSV row; check_fav inspects sorted neighbours.
        self.favs_list = sorted(self.dset_csv['favorites'].to_list())
        if not self.favs_list:
            raise ValueError("dset_csv must contain at least one row")
        self.max_fav = self.favs_list[-1]
        self.min_fav = self.favs_list[0]

    def __getitem__(self, index):
        """Return ``(img0, img1, img2, fav0, fav1, fav2)`` for a random triplet.

        ``index`` is unused; every call samples a new triplet.
        """
        # Anchor: never the global min/max, and only accepted if a positive can exist.
        while True:
            img0_tuple = self._random_img()
            fav0 = self._fav_of(img0_tuple)
            if fav0 == self.max_fav or fav0 == self.min_fav:
                continue
            c = self.check_fav(fav0)
            # Without a close neighbour the positive needs fav1 > fav0 + tol; if no
            # such value exists, the positive search below could never finish.
            if c or self.max_fav > fav0 + self._tolerance(fav0):
                break
        # Positive: any other image that classify() accepts.
        while True:
            img1_tuple = self._random_img(img0_tuple)
            fav1 = self._fav_of(img1_tuple)
            if self.classify(fav0, fav1, c):
                break
        # Negative: any image distinct from both that classify() rejects.
        while True:
            img2_tuple = self._random_img(img0_tuple, img1_tuple)
            fav2 = self._fav_of(img2_tuple)
            if not self.classify(fav0, fav2, c):
                break
        img0 = self._load(img0_tuple[0])
        img1 = self._load(img1_tuple[0])
        img2 = self._load(img2_tuple[0])
        return img0, img1, img2, fav0, fav1, fav2

    def __len__(self):
        return len(self.imageFolderDataset.imgs)

    def obtain_id(self, img_route):
        """Return the integer id encoded in the file name (text before the first '.')."""
        # basename handles the platform's path separator instead of assuming '/'.
        return int(os.path.basename(img_route).split(".")[0])

    def check_fav(self, fav):
        """Return True if ``fav`` has a close neighbour *and* a far-away value exists.

        "Close" means an adjacent entry of the sorted favorites list differs from
        ``fav`` by less than ``int(fav / 10)``. "Far" means the global min lies
        strictly below ``fav - int(fav / 10)`` or the global max strictly above
        ``fav + int(fav / 10)``, so a negative can be found. The result is always
        False for the global min and max.

        Raises:
            ValueError: if ``fav`` is not one of the favorites in ``dset_csv``.
        """
        if fav == self.min_fav or fav == self.max_fav:
            return False
        # favs_list is sorted, so bisect_left finds the first occurrence in
        # O(log n), the same position list.index() would return.
        fav_index = bisect.bisect_left(self.favs_list, fav)
        if fav_index == len(self.favs_list) or self.favs_list[fav_index] != fav:
            raise ValueError("{} is not in list".format(fav))
        tol = self._tolerance(fav)
        close = (abs(self.favs_list[fav_index - 1] - fav) < tol
                 or abs(self.favs_list[fav_index + 1] - fav) < tol)
        far = self.min_fav < fav - tol or self.max_fav > fav + tol
        return bool(close and far)

    def classify(self, fav0, fav1, condition):
        """Return True if ``fav1`` counts as a positive for the anchor ``fav0``.

        With ``condition`` true, a positive lies strictly within ``int(fav0 / 10)``
        of ``fav0``. Otherwise it must be strictly greater than
        ``fav0 + int(fav0 / 10)``. Every other case returns False.
        """
        tol = self._tolerance(fav0)
        dif = fav0 - fav1
        if condition:
            return bool(abs(dif) < tol)
        return bool(dif < -tol)

    @staticmethod
    def _tolerance(fav):
        """Return the closeness tolerance for ``fav``: ``int(fav / 10)``."""
        return int(fav / 10)

    def _random_img(self, *exclude):
        """Draw a random ``(path, class_index)`` entry that is not in ``exclude``."""
        img_tuple = random.choice(self.imageFolderDataset.imgs)
        while img_tuple in exclude:
            img_tuple = random.choice(self.imageFolderDataset.imgs)
        return img_tuple

    def _fav_of(self, img_tuple):
        """Look up the favorites count of an image through its file-name id."""
        return self.dset_csv.loc[self.obtain_id(img_tuple[0]), 'favorites']

    def _load(self, path):
        """Open ``path`` as RGB, release the file handle, then apply ``transform``."""
        with Image.open(path) as img:
            # convert() returns a new, fully loaded image, so closing the file is safe.
            img_rgb = img.convert("RGB")
        if self.transform is not None:
            img_rgb = self.transform(img_rgb)
        return img_rgb

In [None]:
folder_dataset = dset.ImageFolder(conf.training_dir)
# Standard ImageNet channel mean/std used with torchvision's pretrained models.
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
siamese_dataset = SiameseNetworkDataset(
    imageFolderDataset=folder_dataset,
    dset_csv=conf.dset_training,
    transform=transforms.Compose([
        transforms.Resize((224, 224)),  # VGG16 input resolution
        transforms.ToTensor(),
        normalize,
    ]),
)

In [None]:
# NOTE: shuffle has no practical effect here, because the dataset's __getitem__
# samples a random triplet and never uses the index it receives.
train_dataloader = DataLoader(
    dataset=siamese_dataset,
    batch_size=conf.train_batch_size,
    shuffle=True,
    num_workers=2,
)

## Siamese Neural Network definition

In [None]:
class SiameseNetwork(nn.Module):
    """Triplet network built on one VGG16 backbone shared by all three inputs.

    Args:
        use_pretrained: forwarded to ``models.vgg16`` as ``pretrained``.
        feature_extracting: if true, every backbone parameter is frozen before the
            last classifier layer is replaced, so only that new layer is trainable.
        num_classes: output size of the replacement layer (the embedding size).
    """

    def __init__(self, use_pretrained, feature_extracting, num_classes):
        super().__init__()
        backbone = models.vgg16(pretrained=use_pretrained)
        self.set_parameter_requires_grad(backbone, feature_extracting)
        # The replacement Linear is created after freezing, so it stays trainable.
        last_layer_inputs = backbone.classifier[6].in_features
        backbone.classifier[6] = nn.Linear(last_layer_inputs, num_classes)
        self.model_ft = backbone

    def set_parameter_requires_grad(self, model, feature_extracting):
        """Disable gradients for every parameter of ``model`` when ``feature_extracting``."""
        if not feature_extracting:
            return
        for weight in model.parameters():
            weight.requires_grad = False

    def forward_once(self, x):
        """Embed a single batch with the shared backbone."""
        return self.model_ft(x)

    def forward(self, input1, input2, input3):
        """Embed the three batches in order and return the three embeddings as a tuple."""
        return tuple(self.forward_once(batch) for batch in (input1, input2, input3))

## Training (cosine distance, i.e. 1 − cosine similarity)

In [None]:
net = SiameseNetwork(use_pretrained=True, feature_extracting=True, num_classes=512).cuda()
# Triplet loss over cosine distance (1 - cosine similarity), with margin sqrt(3)/2.
criterion = nn.TripletMarginWithDistanceLoss(
    distance_function=lambda u, v: 1.0 - F.cosine_similarity(u, v),
    margin=sqrt(3) / 2,
).cuda()
optimizer = optim.Adam(params=net.parameters(), lr=1e-4)
net.train()

In [None]:
# List the parameters that will receive gradients; with a frozen backbone this
# should be only the replaced final classifier layer.
for name, param in net.named_parameters():
    if param.requires_grad:
        print(name)

In [None]:
# Reset the per-epoch history used by the loss plot.
counter, loss_history, iteration_number = [], [], 0

In [None]:
# Train with the cosine-distance triplet loss, recording the mean loss of each epoch.
for epoch in range(conf.train_number_epochs):
    running_loss = 0.0
    n_batches = 0
    for img0, img1, img2, *_ in train_dataloader:  # trailing favorites counts are unused
        img0, img1, img2 = img0.cuda(), img1.cuda(), img2.cuda()
        optimizer.zero_grad()
        output1, output2, output3 = net(img0, img1, img2)
        loss = criterion(output1, output2, output3)
        loss.backward()
        optimizer.step()
        # detach() accumulates on the device without keeping the graph alive,
        # so the host sync happens once per epoch instead of once per batch.
        running_loss += loss.detach()
        n_batches += 1
    # Report the epoch mean rather than the last batch's (noisy) loss; an empty
    # loader yields NaN instead of a NameError or a stale value.
    epoch_loss = float(running_loss) / n_batches if n_batches else float("nan")
    print("Epoch number {}\n Mean loss {}\n".format(epoch, epoch_loss))
    iteration_number += 1
    counter.append(iteration_number)
    loss_history.append(epoch_loss)

In [None]:
# Set the font size before creating anything: a text artist's size is fixed when
# the artist is created, so calling plt.rc after title/xlabel/ylabel leaves them
# at the default size. The setting is global and also applies to later figures.
plt.rc('font', size = 15)
fig = plt.figure(figsize = (20, 10))
plt.plot(counter, loss_history, 'r')
plt.title('VGG16')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.grid()
plt.show()

## Saving model

In [None]:
# torch.save does not create parent directories; make sure ./models exists so
# the weights from a long training run are not lost at the final step.
os.makedirs("./models", exist_ok = True)
torch.save(net.state_dict(), "./models/vgg16_simcos.zip")
print('Saved.')

## Training (Euclidean distance)

In [None]:
net = SiameseNetwork(use_pretrained=True, feature_extracting=True, num_classes=512).cuda()
# TripletMarginLoss with its default pairwise distance (p=2, i.e. Euclidean) and margin 1.0.
criterion = nn.TripletMarginLoss(margin=1.0).cuda()
optimizer = optim.Adam(params=net.parameters(), lr=1e-4)
net.train()

In [None]:
# List the parameters that will receive gradients; with a frozen backbone this
# should be only the replaced final classifier layer.
for name, param in net.named_parameters():
    if param.requires_grad:
        print(name)

In [None]:
# Reset the per-epoch history used by the loss plot.
counter, loss_history, iteration_number = [], [], 0

In [None]:
# Train with the Euclidean triplet loss, recording the mean loss of each epoch.
for epoch in range(conf.train_number_epochs):
    running_loss = 0.0
    n_batches = 0
    for img0, img1, img2, *_ in train_dataloader:  # trailing favorites counts are unused
        img0, img1, img2 = img0.cuda(), img1.cuda(), img2.cuda()
        optimizer.zero_grad()
        output1, output2, output3 = net(img0, img1, img2)
        loss = criterion(output1, output2, output3)
        loss.backward()
        optimizer.step()
        # detach() accumulates on the device without keeping the graph alive,
        # so the host sync happens once per epoch instead of once per batch.
        running_loss += loss.detach()
        n_batches += 1
    # Report the epoch mean rather than the last batch's (noisy) loss; an empty
    # loader yields NaN instead of a NameError or a stale value.
    epoch_loss = float(running_loss) / n_batches if n_batches else float("nan")
    print("Epoch number {}\n Mean loss {}\n".format(epoch, epoch_loss))
    iteration_number += 1
    counter.append(iteration_number)
    loss_history.append(epoch_loss)

In [None]:
# Set the font size before creating anything: a text artist's size is fixed when
# the artist is created, so calling plt.rc after title/xlabel/ylabel leaves them
# at the default size. The setting is global and also applies to later figures.
plt.rc('font', size = 15)
fig = plt.figure(figsize = (20, 10))
plt.plot(counter, loss_history, 'r')
plt.title('VGG16')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.grid()
plt.show()

## Saving model

In [None]:
# torch.save does not create parent directories; make sure ./models exists so
# the weights from a long training run are not lost at the final step.
os.makedirs("./models", exist_ok = True)
torch.save(net.state_dict(), "./models/vgg16_euc.zip")
print('Saved.')