In [None]:
import os
import copy
import numpy as np
import pandas as pd
import cv2 as cv
import matplotlib.pyplot as plt
import seaborn as sn
from tqdm import tqdm
from pathlib import Path
import time

import torch
import torchvision
import torchvision.models as models 
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader,Dataset
from torch import optim
from PIL import Image
from random import choice, shuffle, sample
from torch.autograd import Variable
import random
from PIL import Image
import PIL.ImageOps    
from sklearn.metrics import average_precision_score

import warnings
warnings.filterwarnings("ignore")
pd.options.display.max_columns = None

!pip install livelossplot --quiet
from livelossplot import PlotLosses


# Labels and dataframes

In [None]:
def images_paths(path):
    '''Collects the paths of every ".jpg" file found under `path`, searching sub-folders recursively.

    Each entry is built with os.path.join(folder, file_name) and the list follows os.walk order.
    The extension test is case-sensitive (".JPG" files are not returned).
    '''
    return [
        os.path.join(folder, file_name)
        for folder, _sub_folders, file_names in os.walk(path)
        for file_name in file_names
        if file_name.endswith(".jpg")
    ]

def path_to_label(path, normal = False):
    '''Builds a category label from an image path.

    The path is split on '/' when `normal` is True, otherwise on '\\'. The label is the third
    component joined with '-' to every following folder component (the last component, i.e. the
    file name, is never included; intermediate components ending in '.jpg' are skipped too).
    '''
    separator = '/' if normal else '\\'
    parts = path.split(separator)
    kept_folders = [part for part in parts[3:-1] if not part.endswith('.jpg')]
    return '-'.join([parts[2]] + kept_folders)

## Train, support and query datasets

In [None]:
ROOT = ''  # Prefix prepended to every CSV / log path ('' = relative to the working directory)

# All brands training, evaluation on Givenchy.
# ROOT is applied here as well, consistently with the Versace CSVs and the log paths.
df_all_brands = pd.read_csv(ROOT + 'all_brands_but_givenchy.csv', index_col=False, usecols=['path', 'label'])
support_givenchy = pd.read_csv(ROOT + 'support_givenchy.csv', index_col=False, usecols=['path', 'label'])
query_givenchy = pd.read_csv(ROOT + 'query_givenchy.csv', index_col=False, usecols=['path', 'label'])

# Evaluation on Versace
def local_path(x, drive_prefix='/content/drive/MyDrive/', local_root='Navee Dataset/drive_data/'):
    '''Maps a Google Drive (Colab) image path to its local copy.

    The leading `drive_prefix` of `x` is replaced by `local_root`.

    Args:
        x: path as stored in the CSV files.
        drive_prefix: prefix expected at the start of `x`.
        local_root: folder that replaces the prefix.

    Returns:
        The local path string.

    Raises:
        ValueError: if `x` does not start with `drive_prefix`. Blindly slicing off
        len(drive_prefix) characters would otherwise produce a corrupted path that only
        fails much later, when the image is opened.
    '''
    if not x.startswith(drive_prefix):
        raise ValueError(f"Expected a path starting with {drive_prefix!r}, got {x!r}")
    return local_root + x[len(drive_prefix):]

# Versace train
train_df = pd.read_csv(ROOT+'train_100_categories.csv', index_col=False, usecols=['path', 'label'])
train_df['path'] = train_df['path'].map(local_path)
# NOTE(review): train labels use path_to_label's default mode (split on '\\'), whereas the
# query/support labels below use normal=True (split on '/'). Both go through local_path, so
# confirm the train CSV really stores backslash-separated paths after the Drive prefix.
train_df['label'] = train_df['path'].map(path_to_label)

query_versace = pd.read_csv(ROOT+'query_50_categories.csv', index_col=False, usecols=['path', 'label'])
support_versace = pd.read_csv(ROOT+'support_50_categories.csv', index_col=False, usecols=['path', 'label'])
query_versace['path'] = query_versace['path'].map(local_path)
support_versace['path'] = support_versace['path'].map(local_path)
# normal=True: labels are derived by splitting the paths on '/'
query_versace['label'] = query_versace['path'].map(lambda p: path_to_label(p, True))
support_versace['label'] = support_versace['path'].map(lambda p: path_to_label(p, True))

# DO NOT TRAIN ON VERSACE QUERY/SUPPORT IMAGES: drop every label appearing in either split
# (previously only the support labels were excluded).
versace_eval_labels = set(support_versace.label) | set(query_versace.label)
df_all_brands = df_all_brands.loc[~df_all_brands.label.isin(versace_eval_labels)]


# Siamese Network


In [None]:
class SiameseResnet(nn.Module):
    ''' Siamese network to learn images representation.

    Both branches share one ResNet-18 trunk whose classification layer is removed, followed by
    Linear(backbone_dim, 512) + LeakyReLU and Linear(512, 256). A sigmoid squashes each embedding
    coordinate into (0, 1).
    '''
    def __init__(self, pretrained=True):
        '''
        Args:
            pretrained: forwarded to torchvision's resnet18 (True = load pretrained weights,
                as before; False builds a randomly initialised trunk, e.g. for offline use).
        '''
        super().__init__()

        # Loading ResNet
        model = models.resnet18(pretrained=pretrained)
        # Width of the pooled features fed to the original classifier (512 for ResNet-18)
        backbone_dim = model.fc.in_features

        # Removing last fully-connected layer: the trunk now outputs the pooled features
        model.fc = nn.Identity()
        self.extractor = model

        # Fully-connected layers
        self.fc1 = nn.Sequential(nn.Linear(backbone_dim, 512), nn.LeakyReLU())
        self.fc2 = nn.Linear(512, 256)

    def forward_one(self, x):
        '''Embeds one batch of images into a (batch, 256) tensor with values in (0, 1).'''
        features = self.extractor(x)
        hidden = self.fc1(features)
        # torch.sigmoid replaces the deprecated F.sigmoid (same values)
        return torch.sigmoid(self.fc2(hidden))

    def forward(self, x1, x2):
        '''Embeds both inputs with the same shared weights and returns the pair of embeddings.'''
        output_1 = self.forward_one(x1)
        output_2 = self.forward_one(x2)
        return output_1, output_2

# Dataset 

In [None]:
class SiameseDataset(Dataset):
    ''' Implements dataset creation for siamese network.

    Every item is a freshly sampled pair of images:
      * the first image is drawn from the rows whose label is in `chosen_labels`;
      * with probability `p` the second image has the same label (target 0),
        otherwise it has a different label (target 1).
    The `index` argument is ignored, so `length` only sets how many pairs make up one epoch.
    Expects `df_paths_labels` to have `path` and `label` columns.
    '''

    def __init__(self,df_paths_labels,length,chosen_labels=None,transform=None, p=0.5):
        self.df = df_paths_labels
        self.len_df = len(df_paths_labels)
        if chosen_labels is not None:
            self.chosen_labels = chosen_labels
        else:
            self.chosen_labels = df_paths_labels.label.unique()
        self.length = length
        self.transform = transform
        self.fraction_same = p # Proportion of positive pairs fed during training
        # Rows the first image of a pair may come from; filtered once here instead of on every item
        self._anchor_df = df_paths_labels.loc[df_paths_labels.label.isin(self.chosen_labels)]

    @staticmethod
    def _load_rgb(path):
        '''Opens an image as RGB and closes the underlying file.'''
        with Image.open(path) as img:
            return img.convert("RGB")

    def __getitem__(self,index):
        '''Selects first label at random. Second image depends on positive pairs proportion wanted'''
        # Take the anchor's label from the `label` column rather than re-parsing its path, so it
        # always matches the column used by the positive/negative masks below.
        anchor = self._anchor_df.iloc[random.randrange(len(self._anchor_df))]
        path_1, label_1 = anchor['path'], anchor['label']

        # Dataset with a fraction p of positively labeled pairs
        same_label = int(random.random() < self.fraction_same)

        if same_label:
            # Picks image from the same label as the first one
            positives = self.df.loc[(self.df.label == label_1) & (self.df.path != path_1)].path.values
            # A label with a single image has no distinct partner: pair the image with itself
            # instead of crashing on an empty choice.
            path_2 = random.choice(positives) if len(positives) else path_1
            y = torch.from_numpy(np.array([0],dtype=np.float32))
        else:
            # Picks image from a different label
            path_2 = random.choice(self.df.loc[self.df.label != label_1].path.values)
            y = torch.from_numpy(np.array([1],dtype=np.float32))

        img_1 = self._load_rgb(path_1)
        img_2 = self._load_rgb(path_2)

        if self.transform is not None:
            img_1 = self.transform(img_1)
            img_2 = self.transform(img_2)

        return img_1, img_2 , y

    def __len__(self):
        return self.length

# Loss

In [None]:
def contrastive_batch_loss(output_1, output_2, label, margin=2.0 ):
    ''' Computes the loss for each pair of images in the batch, not returning mean but a tensor with losses to sort them afterwards.

    Label convention (as produced by SiameseDataset): 0 = same class, 1 = different class.
    Per pair: (1 - y) * d**2 + y * max(0, margin - d)**2, with d the Euclidean distance.

    Args:
        output_1, output_2: (batch, dim) embeddings.
        label: (batch, 1) or (batch,) tensor of 0/1 targets.
        margin: distance beyond which dissimilar pairs stop contributing.

    Returns:
        (batch, 1) tensor of per-pair losses.
    '''
    L2_distance = F.pairwise_distance(output_1, output_2, keepdim = True)
    # Align labels with the (batch, 1) distances: a 1-D label would otherwise broadcast
    # into a (batch, batch) matrix and silently mix up pairs.
    label = label.reshape(L2_distance.shape)
    similar_term = (1 - label) * torch.pow(L2_distance, 2)
    dissimilar_term = label * torch.pow(torch.clamp(margin - L2_distance, min=0.0), 2)
    return similar_term + dissimilar_term


class ContrastiveLoss(nn.Module):
    ''' Implements contrastive loss to train Siamese Network: the batch mean of contrastive_batch_loss.'''

    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output_1, output_2, label):
        # Single source of truth for the formula: reuse the per-pair loss used for mining
        return torch.mean(contrastive_batch_loss(output_1, output_2, label, margin=self.margin))

# Seeds

In [None]:
# Setting of seeds
def enforce_all_seeds(seed):
    '''Seeds every random number generator used in this notebook, for reproducible and consistent results.

    Seeds NumPy's global generator, Python's `random` and torch (plus the current CUDA device when
    CUDA is available), then returns a new NumPy `Generator` created from the same seed.
    '''
    for seed_fn in (np.random.seed, random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    return np.random.default_rng(seed)

# mAP (mean average precision)

In [None]:
def mean_average_precision(net, support_df, query_df, custom_transforms, device=None):
    '''Computes mAP for the given network on support and query datasets.

    Every support and query image is embedded with `net.forward_one`. Each query row then ranks
    the whole support set by negated Euclidean distance to ITS OWN embedding, and gets an average
    precision where support rows sharing its label are the relevant items. The mean over query
    rows is returned.

    Side effects (kept for code that inspects them afterwards): adds an `embedded_images` column
    to both frames and an `AP` column to `query_df`.

    Args:
        net: model exposing `forward_one`; callers are expected to have put it in eval mode.
        support_df, query_df: frames with `path` and `label` columns.
        custom_transforms: dict whose 'val' entry maps an RGB PIL image to a tensor.
        device: device for the forward passes; defaults to the device of the net's parameters
            (falling back to 'cuda', the previously hardcoded device, if it has none).

    Returns:
        The mean AP over query rows (nan if `query_df` is empty).
    '''
    transformer = custom_transforms['val']
    if device is None:
        first_param = next(net.parameters(), None)
        device = first_param.device if first_param is not None else 'cuda'

    def embed(path):
        '''Embeds one image as a (1, embedding_dim) numpy array.'''
        with Image.open(path) as img:
            tensor = transformer(img.convert("RGB"))
        with torch.no_grad():  # evaluation only: no autograd graph needed
            return net.forward_one(tensor.unsqueeze(0).to(device)).cpu().numpy()

    support_df['embedded_images'] = support_df.path.apply(embed)
    query_df['embedded_images'] = query_df.path.apply(embed)

    # Stack the support embeddings once: (n_support, embedding_dim)
    support_matrix = np.vstack(support_df.embedded_images.values) if len(support_df) else None
    support_labels = support_df.label.values

    def average_precision(label, query_embedding):
        '''AP of one query embedding; support rows with the same label are the relevant ones.'''
        y_ground = (support_labels == label).astype(int)
        # Higher score = closer to the query
        y_scores = -np.linalg.norm(support_matrix - query_embedding, axis=1)
        return average_precision_score(y_ground, y_scores)

    # Each query row uses its own embedding (a label-based lookup would reuse the first query
    # of a label for every query sharing that label).
    query_df['AP'] = [average_precision(label, embedding)
                      for label, embedding in zip(query_df.label.values, query_df.embedded_images.values)]
    mAP = np.mean(query_df['AP'].values)
    return mAP

# Training

## Dataloader

In [None]:
# Shared preprocessing settings: fixed 224x224 input and ImageNet channel statistics
IMAGE_SIZE = (224, 224)
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

data_transforms = {
    # Training: resize, random horizontal flip as light augmentation, tensor + normalisation
    'train': transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]),
    # Evaluation: same pipeline without any random augmentation
    'val': transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]),
}

In [None]:
rgen = enforce_all_seeds(42)

batch_size = 250

# Pairs are sampled on the fly from train_df by SiameseDataset; `length` fixes the
# number of pairs that make up one epoch.
data = SiameseDataset(train_df,
                      length = 10000,
                      chosen_labels = None,
                      transform = data_transforms['train'],
                      p=0.5)

# shuffle has no effect on the sampling: each __getitem__ call draws a random pair anyway.
siamese_dataloader = DataLoader(data,
                                shuffle=False,
                                num_workers=0,
                                batch_size=batch_size)

net = SiameseResnet().cuda()

# Training logs
loss_history = []
hard_examples_history = []
mean_loss_history = []
# Ask the loader for its batch count: with drop_last=False it is ceil(length / batch_size),
# so floor division would miss the final partial batch (and with it the per-epoch mAP).
nb_batchs = len(siamese_dataloader)
plt.rcParams['figure.figsize'] = [15, 8]

## Log file names

In [None]:
from datetime import datetime

# Each run gets timestamped file names so successive experiments do not overwrite each other
today = datetime.now().strftime("%d-%m-%Y_%H-%M-%S")
exp_name = 'resnet_full_retrain_'
run_name = f"{exp_name}{today}"
weights_file = f"{run_name}.pt"
CSV_logs_file = f"{run_name}.csv"

# One row per training batch; the mAP columns are only filled on the last batch of each epoch
LOG_COLUMNS = ['epoch', 'batch', 'loss', 'hard_loss', 'versace_mAP', 'givenchy_mAP']
df_logs = pd.DataFrame({column: [] for column in LOG_COLUMNS})

## Training

In [None]:
# Hyperparameters
EPOCHS = 50
top_k = int(batch_size*0.4) # Number of hardest pairs per batch used for the update (40% of the batch)
margin = 4.0 # Contrastive loss margin, used both to mine hard pairs and to train on them

# The margin is passed explicitly: ContrastiveLoss / contrastive_batch_loss default to 2.0,
# which silently ignored the hyperparameter above.
criterion = ContrastiveLoss(margin=margin)
optimizer = optim.Adam(net.parameters(),lr = 0.00005) # Original = 0.00005
groups = {'contrastive loss': ['mean loss', 'hard_examples loss', 'last batch loss'], 'mAP': ['Versace mAP', 'Givenchy mAP']}
liveloss = PlotLosses(groups=groups)

LOG_DIR = ROOT + 'training_logs_Navee/'
os.makedirs(LOG_DIR, exist_ok=True)  # to_csv / torch.save below fail if the folder is missing

n_batches = len(siamese_dataloader)  # counts a partial last batch too
# Defined up front so the per-epoch summary below can never raise a NameError
versace_mAP = givenchy_mAP = np.nan

start = time.time()
for epoch in range(EPOCHS):
    mean_epoch_loss = []
    for i, batch in enumerate(siamese_dataloader):
        net.train()
        img_1, img_2 , label = batch
        img_1, img_2 , label = img_1.cuda(), img_2.cuda() , label.cuda()
        optimizer.zero_grad()

        with torch.no_grad():
            # Scores every pair of the batch without building a graph, then sorts pairs by loss
            temp_output_1, temp_output_2 = net(img_1,img_2)
            temp_loss_contrastive = contrastive_batch_loss(temp_output_1, temp_output_2, label, margin=margin)
            mean_epoch_loss.append(torch.mean(temp_loss_contrastive).item())
            loss, indexes = torch.sort(temp_loss_contrastive, dim=0)

        # Hard-example mining: keep the top_k pairs with the highest loss.
        # Losses are (batch, 1), so the sorted indices are flattened to 1-D; indexing with them
        # keeps the batch dimension (the former `.squeeze()` dropped it when top_k == 1).
        hard_indexes = indexes[-top_k:].flatten()
        input_1, input_2 = img_1[hard_indexes], img_2[hard_indexes]
        mined_labels = label[hard_indexes]

        # Computes forward pass on selected pairs.
        output_1, output_2 = net(input_1,input_2)
        loss_contrastive = criterion(output_1,output_2,mined_labels)
        loss_contrastive.backward()
        optimizer.step()

        hard_loss = loss_contrastive.item()
        batch_loss = mean_epoch_loss[-1]
        print(f"\r Epoch number {epoch+1}/{EPOCHS}, batch number {i+1}/{n_batches}, current hard loss={hard_loss: .5}, batch loss={batch_loss: .5}", end='')

        # Logs update: CSV files and weights are saved at each epoch
        net.eval()
        # Store plain floats rather than CUDA tensors (one of which still carried its autograd node)
        hard_examples_history.append(hard_loss)
        loss_history.append(batch_loss)
        is_last_batch = (i + 1 == n_batches)
        if is_last_batch:
            with torch.no_grad():  # evaluation only
                versace_mAP = mean_average_precision(net, support_versace, query_versace, data_transforms)
                givenchy_mAP = mean_average_precision(net, support_givenchy, query_givenchy, data_transforms)
        log_row = pd.DataFrame({'epoch':[epoch+1], 'batch':[i+1], 'loss':[batch_loss], 'hard_loss':[hard_loss],
                                'versace_mAP':[versace_mAP if is_last_batch else np.nan],
                                'givenchy_mAP':[givenchy_mAP if is_last_batch else np.nan]})
        # DataFrame.append was removed in pandas 2.0; pd.concat is the supported replacement
        df_logs = pd.concat([df_logs, log_row])

    mean_loss_history.append(np.mean(mean_epoch_loss))
    df_logs.to_csv(LOG_DIR + CSV_logs_file, index=False)

    # Livelossplot logs
    liveloss.update({'mean loss': np.mean(mean_epoch_loss),
                     'hard_examples loss': loss_contrastive.item(),
                     'last batch loss': mean_epoch_loss[-1],
                     'Versace mAP': versace_mAP,
                     'Givenchy mAP': givenchy_mAP})
    liveloss.send()
    torch.save(net, LOG_DIR + weights_file)
end = time.time()
print(f'\n Training time {end-start: .5}s, that is {int((end-start)//3600)}h{int((end-start)%3600/60)}min')