# Siamese CNN for Triplet Loss and Contrastive Loss: Training Notebook

In [1]:
# INTRODUCTION
#  A Siamese CNN is a convolutional neural network that produces embeddings for several images, so that the embeddings can be compared afterwards to determine how similar the images are. 
# We have implemented this in two ways: with the triplet loss and with the contrastive loss. Both of these loss functions can be used to train the model within this notebook.
# The two loss functions require different dataloaders and determine the threshold for similarity in different ways. Since BOTH LOSS FUNCTIONS PERFORMED SIMILARLY WELL, 
# we decided to include them both to show how differently a Siamese CNN can be implemented. 

# TRIPLET LOSS: EXPLANATION
# Our implementation of TripletLoss is based on the paper "FaceNet: A Unified Embedding for Face Recognition and Clustering" (Schroff et al., 2015, http://www.arxiv.org/pdf/1503.03832). 
# TripletLoss is a loss function which encourages similar embeddings to have small distances, while it encourages different embeddings to have large distances. During training it therefore
# needs a triplet of embeddings: An anchor, a positive and a negative. 
# To improve training performance, only triplets which are difficult to the model are chosen. This is done by letting the model compute large batches of images and produce embeddings,
# and only then forming difficult triplets. These can then be passed to the loss function and used to train the model effectively. 

# CONTRASTIVE LOSS: EXPLANATION
# The Contrastive Loss function also produces embeddings and tries to maximise the distance for image pairs of different persons, while minimising the distance for image pairs of the same person. 
# To feed two images to the model, a custom Dataset is implemented which randomly chooses them from the whole data.

# TLDR: We have implemented two loss functions, and one can choose which one to use within the "main" cell below.

## Step 1: Preparation of Google Colab Environment

In [None]:
# IMPORT DATA
# For the Triplet Loss a data split of Training(90%), Validation(5%) and Testing (5%) was used. 
# For the Contrastive Loss a data split of Training(80%), Validation(10%) and Testing (20%) was used. (TODO: these percentages add up to 110% - please verify the actual split.)
# Both models tested and validated their performance on an unseen fraction of the provided synthetic data. 
# Because the ImageFolder Dataset provides the most utility, the data was split beforehand in the filesystem. 
# When trying to use randomSplit or train_test_split, there were difficulties with further handling the data.

In [None]:
# Connect Google Drive
# Mounts Google Drive at /content/drive; the dataset archives are copied from /content/drive/MyDrive further below.
# NOTE: `google.colab` is presumably only importable inside a Google Colab runtime - skip this cell when running elsewhere.
from google.colab import drive
drive.mount('/content/drive')

In [1]:
# Create a folder structure for the training, validation and testing data.
# `-p` creates missing parent directories and does not fail when a directory already exists,
# so this cell can be re-run safely (e.g. after "Restart & Run All").
!mkdir -p /content/datasets
!mkdir -p /content/datasets/train_zip
!mkdir -p /content/datasets/val_zip
!mkdir -p /content/datasets/test_zip
!mkdir -p /content/datasets/train_img
!mkdir -p /content/datasets/val_img
!mkdir -p /content/datasets/test_img

In [4]:
# Copy the training, validation and testing archives from Google Drive into the previously created folders,
# which live in the runtime environment of Google Colab (local disk access is much faster than reading from Drive).
# NOTE(review): the training archive is named `da2_train_aug.zip`; every later step that reads it has to use the same file name.
# Duration: approx 2 minutes
!cp /content/drive/MyDrive/Colab_Notebooks/da2_train_aug.zip /content/datasets/train_zip
!cp /content/drive/MyDrive/Colab_Notebooks/da2_val.zip /content/datasets/val_zip
!cp /content/drive/MyDrive/Colab_Notebooks/da2_test.zip /content/datasets/test_zip

In [None]:
# Unzip all files - Duration: Approx. 3 minutes
# The training archive was copied as `da2_train_aug.zip`, so that is the file that has to be extracted here.
# `-q` suppresses the per-file listing (the archives contain many images),
# `-o` overwrites existing files without prompting so the cell can be re-run without blocking on user input.
!unzip -q -o /content/datasets/train_zip/da2_train_aug.zip -d /content/datasets/train_img
!unzip -q -o /content/datasets/val_zip/da2_val.zip -d /content/datasets/val_img
!unzip -q -o /content/datasets/test_zip/da2_test.zip -d /content/datasets/test_img

## Step 2: Define the actual program elements

### Step 2.1: Specify used imports

In [None]:
# INSTALL
# If not already installed. This library is used to implement triplet loss and online triplet mining.
!pip install pytorch_metric_learning

In [5]:
# FUNCTIONS
import matplotlib.pyplot as plt
import sys
import numpy as np
import pandas as pd
import random
from PIL import Image
import torchvision.datasets as datasets
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
from torchvision.transforms import RandomApply
from torch.utils.data import DataLoader, Dataset, Subset
import torch
from torch.autograd import Variable
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import StepLR
from pytorch_metric_learning import distances, losses, miners, reducers, samplers
from sklearn.metrics import roc_curve, auc, confusion_matrix, ConfusionMatrixDisplay
import torchvision
import itertools
from torch.optim.lr_scheduler import StepLR

### Step 2.2: Define the DataLoader for the Contrastive Loss function

In [6]:
# DATALOADER CONTRASTIVE LOSS
# This is the dataloader used for the contrastive loss implementation. Two images are returned,
# and with a 50/50 probability they belong to the same person or not. 

class SiameseNetworkDataset(Dataset):
    """Dataset that yields random image pairs for the contrastive loss.

    Every item is a freshly drawn random pair: with a 50/50 probability both
    images show the same class (person) or two different classes. The ``index``
    passed to ``__getitem__`` is ignored on purpose; it only exists so that a
    ``DataLoader`` can request ``len(dataset)`` pairs per epoch.

    Args:
        image_folder_dataset: Object exposing ``imgs``, a list of
            ``(image_path, class_index)`` entries (e.g. a torchvision
            ``ImageFolder``). The list is indexed once at construction time.
        transform: Optional callable applied to both PIL images. When ``None``
            the RGB PIL images are returned unchanged.
    """

    def __init__(self, image_folder_dataset, transform=None):
        self.image_folder_dataset = image_folder_dataset
        self.transform = transform

        # Group the samples by class once, so a same-class partner can be drawn
        # directly instead of by rejection sampling over the whole dataset
        # (which needs on average "number of classes" tries per pair).
        self._samples_by_class = {}
        for sample in image_folder_dataset.imgs:
            self._samples_by_class.setdefault(sample[1], []).append(sample)

    def _sample_pair(self):
        """Draw two ``(path, class_index)`` entries, same class with probability 0.5."""
        if len(self._samples_by_class) < 2:
            # Without a second class the "different class" search would never terminate.
            raise ValueError("SiameseNetworkDataset needs images of at least two classes to form pairs")

        samples = self.image_folder_dataset.imgs

        # Select a random image from the given image folder
        img0_tuple = random.choice(samples)

        # Decide whether the partner shows the same person or a different person
        should_get_same_class = random.randint(0, 1)
        if should_get_same_class:
            # Uniform over all images of that class (may return the very same image)
            img1_tuple = random.choice(self._samples_by_class[img0_tuple[1]])
        else:
            # Rejection sampling keeps the partner uniform over all images of other classes
            while True:
                img1_tuple = random.choice(samples)
                if img0_tuple[1] != img1_tuple[1]:
                    break
        return img0_tuple, img1_tuple

    @staticmethod
    def _load_image(path):
        """Open ``path`` and return it as RGB image; the file handle is closed again."""
        with Image.open(path) as image:
            # The model expects three input channels, grayscale/RGBA files are converted
            return image.convert("RGB")

    def __getitem__(self, index):
        """Return ``(img0, img1, img0_path, img1_path, label)`` for a random pair.

        ``label`` is a float tensor of shape ``(1,)``: 0 for the same class,
        1 for different classes, i.e. it encodes the desired distance
        (small for similar, large for different images).

        Raises:
            ValueError: If the underlying dataset has fewer than two classes.
        """
        img0_tuple, img1_tuple = self._sample_pair()

        # Extract the image paths and open the images
        img0_path = img0_tuple[0]
        img1_path = img1_tuple[0]
        img0 = self._load_image(img0_path)
        img1 = self._load_image(img1_path)

        # Label: 0 for same class, 1 for different class
        label = torch.tensor([int(img0_tuple[1] != img1_tuple[1])], dtype=torch.float32)

        # Transform images (typically to a tensor)
        if self.transform is not None:
            img0 = self.transform(img0)
            img1 = self.transform(img1)

        # Return tuple, which contains the two chosen images with their corresponding paths and label
        return img0, img1, img0_path, img1_path, label

    def __len__(self):
        """Number of images in the given image folder (= number of pairs drawn per epoch)."""
        return len(self.image_folder_dataset.imgs)

### Step 2.3 Define the CNN architecture

In [8]:
# MODEL ARCHITECTURE
# Used architecture with ~109.000 Parameters - It contains three convolutional layers, with a corresponding pooling layer,
# a global average pooling and one fully connected layer
class FaceRecognitionModel(nn.Module):
    """Compact CNN that maps an RGB image batch to embedding vectors.

    Three 3x3 convolutions (32, 64 and 128 feature maps), each followed by ReLU
    and a shared 2x2 max pooling, then global average pooling and a single
    linear layer that produces ``embedding_size`` values per image.
    """

    def __init__(self, embedding_size=128):
        super().__init__()

        # 3 input channels (RGB) -> 32 feature maps. A 3x3 kernel with stride 1 and
        # padding 1 keeps the spatial size and needs fewer parameters than a 5x5 kernel.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1)

        # Shared max pooling: halves height and width after every convolution, which speeds
        # up the computation and makes the features less sensitive to exact positions.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)

        # Second and third convolution increase the depth (64, then 128 feature maps) with
        # the same size-preserving kernel setup. The depth is capped at 128 to keep the
        # parameter count low.
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1)

        # Global average pooling collapses every feature map to a single value, so the
        # linear layer below only needs 128 inputs regardless of the image size.
        self.gap = nn.AdaptiveAvgPool2d((1, 1))

        # Linear projection of the pooled features to the embedding vector
        self.fc1 = nn.Linear(in_features=128, out_features=embedding_size)

    def forward(self, x):
        """Return the embeddings of shape ``(batch, embedding_size)`` for the image batch ``x``."""
        features = x
        for conv in (self.conv1, self.conv2, self.conv3):
            # convolution -> ReLU -> 2x2 max pooling
            features = self.pool(F.relu(conv(features)))

        pooled = self.gap(features)
        flattened = pooled.view(pooled.size(0), -1)   # (batch, 128)
        return self.fc1(flattened)                    # Output: embedding vector

### Step 2.4: Define contrastive loss function

In [9]:
# CONTRASTIVE LOSS FUNCTION
# Implementation of the contrastive loss function with default margin = 1.0
class ContrastiveLoss(torch.nn.Module):
    """Contrastive loss on the euclidean distance of two embedding batches.

    For a pair with distance ``d`` and label ``y`` (0 = same class,
    1 = different class) the per-pair loss is
    ``(1 - y) * d**2 + y * max(margin - d, 0)**2``; the mean over all pairs
    is returned.

    Args:
        margin: Distance up to which pairs of different classes are penalised.
    """

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        """Return the mean contrastive loss as a scalar tensor.

        Args:
            output1, output2: Embedding batches that ``F.pairwise_distance`` accepts.
            label: One label per pair, given as shape ``(batch,)`` or ``(batch, 1)``.
        """
        # Flatten distance and label to one value per pair. Mixing a (batch, 1)
        # distance with a (batch,) label would otherwise broadcast to a
        # (batch, batch) matrix and silently average over wrong combinations.
        euclidean_distance = F.pairwise_distance(output1, output2).reshape(-1)
        label = label.reshape(-1)

        same_class_term = (1 - label) * torch.pow(euclidean_distance, 2)
        different_class_term = label * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2)

        return torch.mean(same_class_term + different_class_term)

### Step 2.5: Define functions for the training and validation procedure

In [10]:
# CONTRASTIVE LOSS TRAINING
# Method to conduct a training with the given hyperparameters and the contrastive loss
def train_contrastive(model, epoch_number, train_dataloader, val_dataloader, optimizer, criterion_name, device, treshold, path_to_store_model_weights, path_to_store_model):
    """Train ``model`` with the contrastive loss and validate it after every epoch.

    Args:
        model: Network that maps an image batch to embeddings.
        epoch_number: Number of passes over ``train_dataloader``.
        train_dataloader: Yields ``(img0, img1, img0_path, img1_path, labels)`` batches.
        val_dataloader: Loader handed to ``validate_contrastive`` after each epoch.
        optimizer: Optimizer that updates the parameters of ``model``.
        criterion_name: Not used; kept so existing callers keep working.
        device: Device the image and label batches are moved to.
        treshold: Distance threshold passed on to ``validate_contrastive``
            (parameter name kept as is for compatibility).
        path_to_store_model_weights: Directory for the per-epoch ``state_dict`` files.
        path_to_store_model: Directory for the per-epoch full model files.

    Returns:
        Tuple ``(loss_history, batch_iteration_counter, validation_history, error_history)``:
        the running mean training loss after every batch, the number of processed
        batches, and per epoch the accuracy list and the error list returned by
        ``validate_contrastive``.
    """
    print("Start Training")
    loss_history = []
    batch_iteration_counter = 0
    validation_history = []
    error_history = []

    # The loss module holds no state besides the margin, so it is created once
    # instead of once per batch.
    criterion = ContrastiveLoss(margin=1.0)

    for epoch in range(epoch_number):
        # Training phase
        model.train()
        total_train_loss = 0

        for batch_idx, (img0, img1, img0_path, img1_path, labels) in enumerate(train_dataloader):
            img0, img1, labels = img0.to(device), img1.to(device), labels.to(device)

            optimizer.zero_grad()

            output1 = model(img0)
            output2 = model(img1)
            loss = criterion(output1, output2, labels)

            loss.backward()
            optimizer.step()

            total_train_loss += loss.item()
            # Running mean of the loss within the current epoch
            loss_history.append(total_train_loss / (batch_idx + 1))
            batch_iteration_counter += 1

            # Log the loss every 50 batches
            if (batch_idx + 1) % 50 == 0:
                current_loss = total_train_loss / (batch_idx + 1)
                print(f"Batch {batch_idx+1}/{len(train_dataloader)}, Current Loss: {current_loss:.4f}")

        train_loss = total_train_loss / len(train_dataloader)
        print(f"Epoch {epoch+1}/{epoch_number}, Train Loss: {train_loss:.4f}")

        # Full model -> model directory, state_dict -> weights directory. Each file is
        # written in its own try block so that a failure of one does not skip the other;
        # storing is best effort and must not abort the training.
        checkpoints = (
            (model, f'{path_to_store_model}/siamese_model_epoch_{epoch+1}.pth'),
            (model.state_dict(), f'{path_to_store_model_weights}/siamese_model_weights_epoch{epoch+1}.pth'),
        )
        for payload, target_file in checkpoints:
            try:
                torch.save(payload, target_file)
            except Exception as e:
                print(f"An error occured while storing the model or the model weights: {e}")

        validation_process, error_process = validate_contrastive(model, val_dataloader, device, treshold)
        validation_history.append(validation_process)
        error_history.append(error_process)

    return loss_history, batch_iteration_counter, validation_history, error_history


In [11]:
# CONTRASTIVE LOSS VALIDATION
# Method to validate the model during the training process
def validate_contrastive(model, val_dataloader, device, threshold):
    """Evaluate ``model`` on image pairs and report the pair classification accuracy.

    A pair is predicted as "same class" (0) when the euclidean distance of its
    two embeddings is smaller than ``threshold`` and as "different class" (1)
    otherwise.

    Args:
        model: Network that maps an image batch to embeddings.
        val_dataloader: Yields ``(img1, img2, img1_path, img2_path, labels)`` batches,
            labels being 0 (same class) or 1 (different class).
        device: Device the batches are moved to.
        threshold: Distance threshold separating "same" from "different".

    Returns:
        Tuple ``(accuracy_process, error_process)`` of two one-element lists:
        the accuracy in percent and the error rate as a fraction in ``[0, 1]``.

    Raises:
        ValueError: If ``val_dataloader`` yields no pairs.
    """
    model.eval()  # Set the model to evaluation mode
    correct = 0
    total = 0

    with torch.no_grad():
        for img1, img2, img0_path, img1_path, labels in val_dataloader:
            img1 = img1.to(device)
            img2 = img2.to(device)
            labels = labels.to(device).view(-1).long()

            # Calculate similarity with the euclidean distance of the two embeddings
            euclidean_distance = F.pairwise_distance(model(img1), model(img2))

            # 0 if the distance is below the threshold, 1 otherwise
            # (written as negation so that NaN distances count as "different", too)
            predicted = (~(euclidean_distance < threshold)).long()

            correct += (predicted == labels).sum().item()
            total += labels.size(0)

    if total == 0:
        raise ValueError("val_dataloader did not yield any image pairs, the accuracy is undefined")

    accuracy = 100 * (correct / total)
    error = 1 - (correct / total)

    accuracy_process = [accuracy]
    error_process = [error]
    print('Validation Accuracy of the network: {:.2f}%'.format(accuracy))

    return accuracy_process, error_process


In [12]:
# TRIPPLET LOSS TRAINING
def train_tripplet(model, train_loader, val_loader, optimizer, mining_func, loss_fn, epoch_number, threshold, device, path_to_store_model_weights, path_to_store_model):
    """Train ``model`` with online triplet mining and validate it after every epoch.

    Args:
        model: Network that maps an image batch to embeddings.
        train_loader: Yields ``(data, labels)`` batches for training.
        val_loader: Loader handed to ``validate_tripplet`` after each epoch.
        optimizer: Optimizer that updates the parameters of ``model``.
        mining_func: Callable ``mining_func(embeddings, labels)`` returning the mined
            triplets; its ``num_triplets`` attribute is read for logging.
        loss_fn: Callable ``loss_fn(embeddings, labels, triplets)`` returning the loss.
        epoch_number: Number of passes over ``train_loader``.
        threshold: Similarity threshold passed on to ``validate_tripplet``.
        device: Device the batches are moved to.
        path_to_store_model_weights: Directory for the per-epoch ``state_dict`` files.
        path_to_store_model: Directory for the per-epoch full model files.
    """
    for epoch in range(epoch_number):
        # TRAINING
        print("Training")
        model.train()
        for batch_idx, (data, labels) in enumerate(train_loader):
            data, labels = data.to(device), labels.to(device)
            optimizer.zero_grad()

            # send all images through the model to produce embeddings
            embeddings = model(data)

            # call mining function to find difficult triplets
            triplets = mining_func(embeddings, labels)

            # call loss function, pass in difficult triplets
            loss = loss_fn(embeddings, labels, triplets)

            loss.backward()
            optimizer.step()

            print(f"Epoch {epoch}, Iteration {batch_idx}: Loss = {loss}, Number of mined triplets = {mining_func.num_triplets}")

        # Full model -> model directory, state_dict -> weights directory. Each file is
        # written in its own try block so that a failure of one does not skip the other;
        # storing is best effort and must not abort the training.
        checkpoints = (
            (model, f'{path_to_store_model}/siamese_model_epoch_{epoch+1}.pth'),
            (model.state_dict(), f'{path_to_store_model_weights}/siamese_model_weights_epoch{epoch+1}.pth'),
        )
        for payload, target_file in checkpoints:
            try:
                torch.save(payload, target_file)
            except Exception as e:
                print(f"An error occured while storing the model or the model weights: {e}")

        validate_tripplet(model, val_loader, device, threshold, loss_fn)

In [13]:
# TRIPPLET LOSS VALIDATION
def validate_tripplet(model, val_loader, device, threshold, loss_fn):
    """Validate ``model`` on all image pairs that can be formed within each batch.

    For every batch the loss is computed on all embeddings (no mining) and every
    unordered pair of images is classified: "same class" (0) when the cosine
    similarity of the two embeddings is greater than ``threshold``, "different
    class" (1) otherwise. Running means of loss and per-batch accuracy are
    printed after every batch.

    Args:
        model: Network that maps an image batch to embeddings.
        val_loader: Sized iterable of ``(data, labels)`` batches.
        device: Device the batches are moved to.
        threshold: Cosine-similarity threshold separating "same" from "different".
        loss_fn: Callable ``loss_fn(embeddings, labels)`` returning the batch loss.

    Returns:
        Tuple ``(val_loss, val_acc)``: mean batch loss and mean per-batch accuracy
        in percent (both 0.0 if nothing could be evaluated).
    """
    # VALIDATION
    print("Validation")
    model.eval()

    val_loss = 0.0
    val_acc = 0.0
    num_batches = 0
    num_scored_batches = 0   # batches that contained at least one pair
    val_loss_sum = 0.0
    val_acc_sum = 0.0
    num_pairs = 0

    with torch.no_grad():
        for data, labels in val_loader:
            num_batches += 1

            # Reminder: 40 images from two identities are loaded in one batch.
            data, labels = data.to(device), labels.to(device)
            embeddings = model(data)

            # VALIDATION LOSS
            # Calculate the loss based on all embeddings in the batch (no mining is used, therefore the loss function
            # forms triplets itself, see documentation of Pytorch metric Learning)
            batch_loss = loss_fn(embeddings, labels)

            # VALIDATION PREDICTIONS AND VALIDATION ACCURACY
            # Index tensors of all unordered pairs (i < j) in the batch, in the same
            # order as itertools.combinations(range(batch_size), 2)
            batch_size = labels.size(0)
            pair_index = torch.triu_indices(batch_size, batch_size, offset=1, device=labels.device)
            first, second = pair_index[0], pair_index[1]
            pair_count = first.numel()

            # A batch with fewer than two images has no pairs and is skipped for the accuracy
            if pair_count > 0:
                # 0 = same identity, 1 = different identities
                pair_labels = (labels[first] != labels[second]).long()

                # CHOICE OF THRESHOLD - EUCLIDEAN OR COSINE DISTANCE?
                # In the FaceNet paper, a squared euclidean distance threshold was used. Here the
                # cosine similarity is used instead, as higher values represent higher similarity
                # and it can be converted more intuitively into the range of 0 to 1.
                cosine_similarity = F.cosine_similarity(embeddings[first], embeddings[second], dim=1)

                # 0 if the similarity exceeds the threshold, 1 otherwise (NaN counts as "different")
                predictions = (~(cosine_similarity > threshold)).long()

                batch_correct = (predictions == pair_labels).sum().item()
                batch_accuracy = 100 * (batch_correct / pair_count)

                num_pairs += pair_count
                num_scored_batches += 1
                val_acc_sum += batch_accuracy
                val_acc = val_acc_sum / num_scored_batches

            val_loss_sum += float(batch_loss)
            val_loss = val_loss_sum / num_batches

            print("Total Loss.: ", val_loss)
            print("Total Acc.: ", val_acc)
            print("Image Pairs Computed: ", num_pairs)
            print("Batches Remaining: ", len(val_loader) - num_batches)

    return val_loss, val_acc

### Step 2.6 Define functions for making the actual predictions

In [14]:
# CONTRASTIVE LOSS PREDICTION
# Method to test the trained model - in contrast to the validation loop, this method returns a DataFrame (later stored as an evaluation CSV file) which contains the image paths, the true label, the predicted label and the distance between the two embeddings.
def predict_contrastive(model, val_dataloader, device, threshold):
    """Run ``model`` on image pairs and collect the predictions in a DataFrame.

    A pair is predicted as "same class" (0) when the euclidean distance of its
    two embeddings is smaller than ``threshold`` and as "different class" (1)
    otherwise.

    Args:
        model: Network that maps an image batch to embeddings.
        val_dataloader: Yields ``(img1, img2, img1_path, img2_path, labels)`` batches
            (despite its name this is the loader of the set to predict on).
        device: Device the batches are moved to.
        threshold: Distance threshold separating "same" from "different".

    Returns:
        DataFrame with one row per pair and the columns ``Image1 Path``,
        ``Image2 Path``, ``True Label``, ``Predicted Label`` and ``Distance``.

    Raises:
        ValueError: If ``val_dataloader`` yields no pairs.
    """
    model.eval()
    correct = 0
    total = 0
    true_labels = []
    predicted_labels = []
    distance = []
    img1_paths = []
    img2_paths = []

    with torch.no_grad():
        for img1, img2, img1_path, img2_path, labels in val_dataloader:
            img1 = img1.to(device)
            img2 = img2.to(device)
            labels = labels.to(device).view(-1).long()

            euclidean_distance = F.pairwise_distance(model(img1), model(img2))

            # 0 if the distance is below the threshold, 1 otherwise
            # (written as negation so that NaN distances count as "different", too)
            predicted = (~(euclidean_distance < threshold)).long()

            true_labels.extend(labels.cpu().tolist())
            predicted_labels.extend(predicted.cpu().tolist())
            distance.extend(euclidean_distance.cpu().tolist())
            img1_paths.extend(img1_path)
            img2_paths.extend(img2_path)

            correct += (predicted == labels).sum().item()
            total += labels.size(0)

    if total == 0:
        raise ValueError("val_dataloader did not yield any image pairs, the accuracy is undefined")

    accuracy = 100 * (correct / total)
    print('Accuracy of the network on the test set: {:.2f}%'.format(accuracy))

    evaluation_df = pd.DataFrame({
        'Image1 Path': img1_paths,
        'Image2 Path': img2_paths,
        'True Label': true_labels,
        'Predicted Label': predicted_labels,
        'Distance': distance })

    return evaluation_df

In [15]:
# TRIPPLET LOSS PREDICTION
def prediction_tripplet(model, test_loader, device, threshold, loss_fn):
    """Predict for all image pairs within each test batch whether they show the same class.

    A pair is predicted as "same class" (0) when the cosine similarity of the two
    embeddings is greater than ``threshold`` and as "different class" (1)
    otherwise. Running means of loss and per-batch accuracy are printed after
    every batch.

    Args:
        model: Network that maps an image batch to embeddings.
        test_loader: Sized iterable of ``(data, labels)`` batches.
        device: Device the batches are moved to.
        threshold: Cosine-similarity threshold separating "same" from "different".
        loss_fn: Callable ``loss_fn(embeddings, labels)`` returning the batch loss.

    Returns:
        DataFrame with one row per evaluated pair and the columns ``True Label``,
        ``Predicted Label`` and ``Distance`` (the cosine similarity of the pair).
    """
    true_labels = []
    predicted_labels = []
    distance = []

    test_loss = 0.0
    test_acc = 0.0

    num_batches = 0
    num_scored_batches = 0   # batches that contained at least one pair
    test_loss_sum = 0.0
    test_acc_sum = 0.0

    num_pairs = 0
    model.eval()
    with torch.no_grad():
        for data, labels in test_loader:
            num_batches += 1
            data, labels = data.to(device), labels.to(device)
            embeddings = model(data)

            # calculate the loss based on all embeddings in the batch
            batch_loss = loss_fn(embeddings, labels)

            # Index tensors of all unordered pairs (i < j) in the batch, in the same
            # order as itertools.combinations(range(batch_size), 2)
            batch_size = labels.size(0)
            pair_index = torch.triu_indices(batch_size, batch_size, offset=1, device=labels.device)
            first, second = pair_index[0], pair_index[1]
            pair_count = first.numel()

            # A batch with fewer than two images has no pairs and is skipped for the accuracy
            if pair_count > 0:
                # 0 = same identity, 1 = different identities
                pair_labels = (labels[first] != labels[second]).long()

                cosine_similarity = F.cosine_similarity(embeddings[first], embeddings[second], dim=1)

                # 0 if the similarity exceeds the threshold, 1 otherwise (NaN counts as "different")
                predictions = (~(cosine_similarity > threshold)).long()

                true_labels.extend(pair_labels.cpu().tolist())
                predicted_labels.extend(predictions.cpu().tolist())
                distance.extend(cosine_similarity.cpu().tolist())

                batch_correct = (predictions == pair_labels).sum().item()
                batch_accuracy = 100 * (batch_correct / pair_count)

                num_pairs += pair_count
                num_scored_batches += 1
                test_acc_sum += batch_accuracy
                test_acc = test_acc_sum / num_scored_batches

            test_loss_sum += float(batch_loss)
            test_loss = test_loss_sum / num_batches

            print("Testing Loss.: ", test_loss)
            print("Testing Accuracy.: ", test_acc)
            print("Image Pairs Computed: ", num_pairs)
            print("Batches Remaining: ", len(test_loader) - num_batches)

    evaluation_df = pd.DataFrame({
        'True Label': true_labels,
        'Predicted Label': predicted_labels,
        'Distance': distance })
    return evaluation_df

## Step 3: Define the program procedure for the corresponding loss function implementations

In [16]:
# CONTRASTIVE LOSS PROCEDURE
def run_contrastive(path_to_train_images, path_to_validation_images, path_to_test_images, siamese_model,
                         epoch_number, batch_size_training, batch_size_validation, learning_rate, threshold,
                         optimizer, device, path_to_store_model_weights, path_to_store_model):
    """Complete contrastive-loss procedure: load the data, train, predict and store the evaluation.

    The three image folders are wrapped into pair datasets, the model is trained
    with ``train_contrastive``, afterwards ``predict_contrastive`` is run on the
    test pairs and its DataFrame is written to the file
    ``contrastive_evaluation_df`` in the working directory.

    ``learning_rate`` is not read here; the learning rate is already part of the
    passed ``optimizer``.
    """
    # Images are only converted to tensors, no further augmentation is applied
    to_tensor = transforms.ToTensor()

    # Image folders (one sub folder per identity) for training, validation and testing
    train_folder = datasets.ImageFolder(path_to_train_images)
    val_folder = datasets.ImageFolder(path_to_validation_images)
    test_folder = datasets.ImageFolder(path_to_test_images)

    # Pair datasets on top of the image folders
    train_pairs = SiameseNetworkDataset(train_folder, transform=to_tensor)
    val_pairs = SiameseNetworkDataset(val_folder, transform=to_tensor)
    test_pairs = SiameseNetworkDataset(test_folder, transform=to_tensor)

    # Report the number of trainable parameters of the model
    trainable = [weights.numel() for weights in siamese_model.parameters() if weights.requires_grad]
    total_params = sum(trainable)
    print(f"Number of trainable parameters {total_params}")

    # Data loaders: only the training pairs are shuffled; validation and test share one batch size
    train_dataloader = DataLoader(train_pairs, shuffle=True, batch_size=batch_size_training)
    val_dataloader = DataLoader(val_pairs, shuffle=False, batch_size=batch_size_validation)
    test_dataloader = DataLoader(test_pairs, shuffle=False, batch_size=batch_size_validation)

    # Training loop (the returned histories are not evaluated any further here)
    training_result = train_contrastive(siamese_model, epoch_number, train_dataloader, val_dataloader, optimizer,
                                        "Contrastive", device, threshold, path_to_store_model_weights,
                                        path_to_store_model)
    loss_process, batch_iteration_counter, accuracy_process, error_process = training_result

    # Predictions on the test pairs, stored as comma separated file
    predict_contrastive(siamese_model, test_dataloader, device, threshold).to_csv(
        "contrastive_evaluation_df", sep=",", index=False)

In [17]:
# TRIPPLET LOSS PROCEDURE
def run_tripplet(path_to_training_images, path_to_validation_images, path_to_test_images, siamese_model,
                         epoch_number, batch_size_training, batch_size_validation, learning_rate, threshold,
                         optimizer, device, path_to_store_model_weights, path_to_store_model):
    """Complete triplet-loss procedure: load the data, train, predict and store the evaluation.

    The model is trained with ``train_tripplet`` using semihard online triplet
    mining, afterwards ``prediction_tripplet`` is run on the test set and its
    DataFrame is written to the file ``tripplet_evaluation_df`` in the working
    directory.

    ``learning_rate`` is not read here; the learning rate is already part of the
    passed ``optimizer``.
    """
    def build_loader(dataset, batch_size, samples_per_epoch):
        # The sampler draws 20 images per class; ``samples_per_epoch`` fixes the
        # number of samples drawn before a new iteration starts.
        class_ids = [entry[1] for entry in dataset.imgs]
        sampler = samplers.MPerClassSampler(labels=class_ids, m=20, length_before_new_iter=samples_per_epoch)
        return DataLoader(dataset, batch_size=batch_size, sampler=sampler)

    # Images are only converted to tensors
    to_tensor = transforms.ToTensor()

    # Image folders (one sub folder per identity) for training, validation and testing
    training_dataset = datasets.ImageFolder(path_to_training_images, to_tensor)
    val_dataset = datasets.ImageFolder(path_to_validation_images, to_tensor)
    test_dataset = datasets.ImageFolder(path_to_test_images, to_tensor)

    # L2 distance on normalized embeddings, shared by the miner and the loss
    distance = distances.LpDistance(normalize_embeddings=True, p=2, power=1)
    mining_func = miners.TripletMarginMiner(margin=0.2, type_of_triplets="semihard", distance=distance)
    loss_fn = losses.TripletMarginLoss(
        margin=0.2,
        distance=distance,
        reducer=reducers.ThresholdReducer(low=0),
        swap=False,
        triplets_per_anchor="all",
        smooth_loss=False,
    )

    # Data loaders with class balanced sampling
    train_loader = build_loader(training_dataset, batch_size_training, 180000)
    val_loader = build_loader(val_dataset, batch_size_validation, 100000)
    test_loader = build_loader(test_dataset, batch_size_validation, 100000)

    # Training loop
    train_tripplet(siamese_model, train_loader, val_loader, optimizer, mining_func, loss_fn, epoch_number,
                   threshold, device, path_to_store_model_weights, path_to_store_model)

    # Predictions on the test set, stored as comma separated file
    prediction_tripplet(siamese_model, test_loader, device, threshold, loss_fn).to_csv(
        "tripplet_evaluation_df", sep=",", index=False)

## Step 4: Define main function --> Run the main function to run the program!

In [21]:
### THE ACTUAL PROGRAM STARTS FROM HERE! ###

'''
   HYPERPARAMETERS USED FOR CONTRASTIVE LOSS:
    1. epoch_number = 10
    2. batch_size_training = 32
    3. batch_size_validation = 40
    4. learning_rate = 0.001
    5. threshold = 0.64

   HYPERPARAMETERS USED FOR TRIPLET LOSS:
    1. epoch_number = 10
    2. batch_size_training = 1000 --> A large batch size of 1000 was chosen to ensure that enough "semihard" triplets can be found within each batch during the online triplet mining.
    3. batch_size_validation = 40 --> The batch size 40 was chosen to ensure that only two identities are loaded into one batch during validation. All possible combinations 
                                        between images are computed during validation, and random combinations of images are provided by the sampler.
    4. learning_rate = 0.005
    5. threshold = 0.0625
    6. TRIPLET LOSS: SAMPLER      --> The MPerClassSampler retrieves all 20 images of one identity into a batch. This is to ensure that enough positives are available to train the model.
                                        Setting length_before_new_iter to the number of images in the dataset (180000) determines the length of the DataLoader (180000/1000 = 180). This is 
                                        the number of iterations per epoch. 

    
    COMMENT ON THRESHOLDS: Optimal thresholds were chosen by running the evaluation_df through a ROC curve, which is done in another notebook. 
    COMMENT ON SIMILARITY PROBABILITY: To get a similarity probability of two embeddings, the cosine similarity or euclidean distance needs to be converted. This is also done in the other notebook.
                                        
                                        '''

def main():
    """Configure and run the training, validation and prediction for the selected loss function.

    Only ``loss_function`` (and, if necessary, the paths) has to be edited: the
    hyperparameters documented for the chosen loss function are looked up
    automatically, so that e.g. the triplet loss is not evaluated with the
    threshold that was determined for the contrastive loss.
    """
    import os  # only needed here to create the output directories

    # Step 1: Define paths to the training-, validation- and test set and where the weights and models should be stored
    # NOTE(review): these relative paths differ from the /content/datasets folders used in the
    # Colab preparation cells - adjust them to the location of the unzipped images.
    path_to_training_images = 'dataset/unzipped/generated_images_10Kids_cropped/training'
    path_to_validation_images = 'dataset/unzipped/generated_images_10Kids_cropped/validation'
    path_to_test_images = 'dataset/unzipped/generated_images_10Kids_cropped/testing'
    path_to_store_weights = 'dataset'
    path_to_store_model = 'dataset'   # relative like the weights path (was '/dataset', i.e. the filesystem root)

    # Step 2: Choose the loss function
    loss_function = "Tripplet" # Choose one of "Contrastive" or "Tripplet"

    # Hyperparameters per loss function. The thresholds were determined with a ROC curve;
    # contrastive: euclidean distance threshold, triplet: cosine similarity threshold.
    hyperparameters = {
        "Contrastive": {"epoch_number": 10, "batch_size_training": 32, "batch_size_validation": 40,
                        "learning_rate": 0.001, "threshold": 0.64},
        "Tripplet": {"epoch_number": 10, "batch_size_training": 1000, "batch_size_validation": 40,
                     "learning_rate": 0.005, "threshold": 0.0625},
    }
    procedures = {"Contrastive": run_contrastive, "Tripplet": run_tripplet}

    if loss_function not in procedures:
        print("The defined loss function is not implemented - Please use Contrastive or Tripplet Loss")
        sys.exit(-1)

    settings = hyperparameters[loss_function]
    epoch_number = settings["epoch_number"]
    batch_size_training = settings["batch_size_training"]
    batch_size_validation = settings["batch_size_validation"]
    learning_rate = settings["learning_rate"]
    threshold = settings["threshold"]

    # Step 3: Make sure the output directories exist, otherwise storing the checkpoints
    # after every epoch fails (and is only reported as a printed message)
    for output_directory in (path_to_store_weights, path_to_store_model):
        os.makedirs(output_directory, exist_ok=True)

    # Step 4: Define device for GPU usage if possible
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Step 5: Create initial CNN model instance and its optimizer
    siamese_model = FaceRecognitionModel(embedding_size=128).to(device)
    optimizer = optim.Adam(siamese_model.parameters(), lr=learning_rate)

    # Step 6: Start the program with the corresponding loss function implementation
    procedures[loss_function](path_to_training_images, path_to_validation_images, path_to_test_images, siamese_model,
                              epoch_number, batch_size_training, batch_size_validation, learning_rate, threshold,
                              optimizer, device, path_to_store_weights, path_to_store_model)

if __name__ == '__main__':
  main()

KeyboardInterrupt: 

##  ------------------------------------------------------
## Archive

In [None]:
### ONLY FOR ILLUSTRATION PURPOSES! ###
# Architecture with ~185.000 Parameters - This class is not used anymore, but for illustration purposes still in this notebook
class SiameseNeuralNetwork(nn.Module):
    """Archived architecture (about 185k parameters), kept for illustration only.

    Two convolution stages with batch normalization, ReLU and max pooling,
    followed by global average pooling and a two layer fully connected head
    that produces a 64-dimensional output per image.
    """

    def __init__(self):
        super().__init__()

        # Convolutional stages with batch normalization; the final adaptive average
        # pooling reduces every feature map to a single value
        conv_layers = [
            nn.Conv2d(3, 64, kernel_size=5, stride=1, padding=0),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=1),

            nn.Conv2d(64, 128, kernel_size=4, stride=1, padding=0),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
            nn.AdaptiveAvgPool2d((1, 1)),
        ]
        self.conv = nn.Sequential(*conv_layers)

        # Fully connected head (no batch normalization and no final activation)
        fc_layers = [
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            #nn.Sigmoid()
        ]
        self.fc = nn.Sequential(*fc_layers)

    def forward(self, x):
        """Return the output of shape ``(batch, 64)`` for the image batch ``x``."""
        feature_maps = self.conv(x)
        flattened = feature_maps.view(feature_maps.size(0), -1)
        return self.fc(flattened)