# Deep Metric Learning with FAU's Papyrus Collection

## Customize Notebook and Install Dependencies

In [None]:
!pip install faiss-gpu
!pip install pytorch-metric-learning
!pip install efficientnet_pytorch    <



In [None]:
!pip install PyPDF2
!pip install FPDF

# For Timo's code
from tqdm import tqdm_notebook as tqdm
from skimage import io, transform
from PyPDF2 import PdfFileMerger
from shutil import copyfile
from os.path import isfile, join
import matplotlib.pyplot as plt
from google.colab import drive

from os import listdir
from fpdf import FPDF
import pandas as pd
import numpy as np
import PIL
import time
import toml
import cv2
import os

# For custom dataset
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils, models

# From Kevin Musgrave's GitHub
from efficientnet_pytorch import EfficientNet
from pytorch_metric_learning import distances, losses, miners, reducers, testers, samplers
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator
import torch.nn.functional as F
import torch.optim as optim
import torch.nn as nn
import torch

# For Logging
from scipy.interpolate import make_interp_spline

# Notebook Seetings
tqdm().pandas()
import matplotlib
from matplotlib import rc
rc('text', usetex=True)
matplotlib.rcParams['text.latex.preamble'] = [r'\usepackage{amsmath}']
!apt install texlive-fonts-recommended texlive-fonts-extra cm-super dvipng



import logging

# Shared notebook logger: every message goes both to the cell output and to log.txt.
logger = logging.getLogger('log')
logger.setLevel(logging.DEBUG)

# Re-running this cell must not stack handlers, because each extra handler
# repeats every message. Detach and close whatever a previous run attached.
for stale_handler in list(logger.handlers):
    logger.removeHandler(stale_handler)
    stale_handler.close()

# mode='w' starts from an empty log file and, unlike `rm`, does not fail
# when the file does not exist yet.
fh = logging.FileHandler('log.txt', mode='w')
fh.setLevel(logging.INFO)
# console handler
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
# both handlers emit the bare message without level or timestamp
formatter = logging.Formatter('%(message)s')
ch.setFormatter(formatter)
fh.setFormatter(formatter)
logger.addHandler(ch)
logger.addHandler(fh)
# Keep messages away from the root logger so they are not printed a second time there.
logger.propagate = False



Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


0it [00:00, ?it/s]

Reading package lists... Done
Building dependency tree       
Reading state information... Done
cm-super is already the newest version (0.3.4-11).
dvipng is already the newest version (1.15-1).
texlive-fonts-extra is already the newest version (2017.20180305-2).
texlive-fonts-recommended is already the newest version (2017.20180305-1).
0 upgraded, 0 newly installed, 0 to remove and 37 not upgraded.


In [None]:
# Mount Google Drive at /content/gdrive (Colab only). The data and config paths
# used below point into ./gdrive/MyDrive — presumably relative to /content; confirm.
drive.mount('/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [None]:
# One-off preprocessing, disabled by `if False`: reads the processed_info.csv under
# data/06_test, rewrites the string column papyPosID by replacing 'v' with '0' and
# 'r' with '1', casts it to int and writes the CSV back to the same path.
# NOTE(review): 'v'/'r' presumably stand for verso/recto — confirm.
if False:
  info_path = './gdrive/MyDrive/mt/data/06_test/processed_info.csv'
  info_frame = pd.read_csv(info_path, index_col=0, dtype={'fnames_raw':str,'fnames':str,'fragmentID':int,'papyID':int,'posinfo':str, 'pixelCentimeter':float, 'Simpleposinfo':str,'papyPosID':str}, header=0)
  info_frame.papyPosID = info_frame.papyPosID.str.replace('v','0')
  info_frame.papyPosID = info_frame.papyPosID.str.replace('r','1')
  info_frame.papyPosID = info_frame.papyPosID.astype(int)
  info_frame.to_csv(info_path)

## Prepare Dataset

In [None]:
def get_info(path):
    """Load, or build and cache, the image metadata table stored at ``<path>/info.csv``.

    If the CSV already exists it is read and returned as-is. Otherwise one row is
    built per ``*.jpg`` file directly inside ``path`` with the columns ``fnames``
    (name up to the first dot), ``papyID`` (text before the first underscore),
    ``posinfo`` (all alphabetic characters of the name) and ``pixelCM`` (the larger
    of the two values returned by ``get_estimated_resulution``), and the table is
    written to ``info.csv``.

    Parameters
    ----------
    path : str
        Directory holding the images and the cached ``info.csv``.

    Returns
    -------
    pandas.DataFrame
        The metadata table.

    Notes
    -----
    ``get_estimated_resulution`` is not defined in this part of the notebook; it is
    assumed to return a ``(y, x)`` pair per file name — confirm against its definition.
    """
    info_path = join(path, 'info.csv')
    if isfile(info_path):
        return pd.read_csv(info_path, index_col=0, dtype={'fnames_raw': str,
                                                         'fnames': str,
                                                         'fragmentID': int,
                                                         'papyID': int,
                                                         'posinfo': str,
                                                         'pixelCentimeter': float,
                                                         'Simpleposinfo': str,
                                                         'papyPosID': int}, header=0)

    # Match on the suffix: a substring test would also accept names such as
    # "scan.jpg.bak".
    fnames = [f.split('.', 1)[0]
              for f in listdir(path)
              if isfile(join(path, f)) and f.endswith('.jpg')]

    info_frame = pd.DataFrame(fnames, columns=['fnames'])
    info_frame['papyID'] = info_frame.fnames.apply(lambda x: x.split('_', 1)[0])
    info_frame['posinfo'] = info_frame.fnames.apply(lambda x: ''.join(filter(str.isalpha, x)))

    # Keep only the larger of the two per-axis estimates for each image.
    estimates = info_frame.fnames.progress_apply(get_estimated_resulution)
    per_axis = pd.DataFrame(estimates.tolist(), columns=['pixelCM_Y', 'pixelCM_X'])
    info_frame['pixelCM'] = per_axis.max(axis=1)

    info_frame.to_csv(info_path)
    # NOTE(review): the pause presumably gives the mounted Drive time to sync the
    # new file — confirm it is still needed.
    time.sleep(10)
    return info_frame

In [None]:
def retrive_size_by_fname(fname, path='./gdrive/MyDrive/mt/data/'):
  """Return the ``pixelCM`` value recorded for ``fname`` in the info table of ``path``.

  Parameters
  ----------
  fname : str
      Value to look up in the ``fnames`` column.
  path : str, optional
      Directory passed to ``get_info``; defaults to the location that was
      previously hard-coded.

  Returns
  -------
  float
      The ``pixelCM`` entry of the single matching row.

  Raises
  ------
  ValueError
      If zero or more than one row matches ``fname``.
  """
  # NOTE: get_info re-reads the CSV on every call; callers that apply this per row
  # pay that cost once per row.
  info_frame = get_info(path=path)
  matches = info_frame.loc[info_frame['fnames'] == fname, 'pixelCM']
  if len(matches) != 1:
    raise ValueError(f'Expected exactly one info row for {fname!r}, found {len(matches)}.')
  # Take the scalar explicitly; float() on a whole Series is deprecated in pandas.
  return float(matches.iloc[0])

In [None]:
def create_processed_info(path, debug=False):
  """Load, or build and cache, the fragment table for the PNG files in ``path``.

  The cache file is ``debug_processed_info.csv`` when ``debug`` is truthy and
  ``processed_info.csv`` otherwise. When the cache is missing, one row is created
  per ``*.png`` file directly inside ``path`` with the columns, in this order,
  ``fnames_raw`` (name after the first underscore), ``fnames`` (name up to the
  first dot), ``fragmentID`` (text before the first underscore), ``papyID``
  (leading token of ``fnames_raw``, as int), ``posinfo`` (alphabetic characters
  of ``fnames_raw``) and ``pixelCentimer`` (from ``retrive_size_by_fname``).

  Parameters
  ----------
  path : str
      Directory with the processed fragment images.
  debug : bool, optional
      Selects the debug cache file.

  Returns
  -------
  pandas.DataFrame
      The fragment table.

  Raises
  ------
  ValueError
      When building the table and a ``papyID`` token is not an integer.
  """
  info_name = 'debug_processed_info.csv' if debug else 'processed_info.csv'
  info_path = join(path, info_name)
  if isfile(info_path):
    return pd.read_csv(info_path, index_col=0, dtype={'fnames':str,'papyID':int,'posinfo':str, 'pixelCentimer':float}, header=0)

  # Suffix test: a substring test would also accept names such as "x.png.tmp".
  fnames = [f.split('.', 1)[0]
            for f in listdir(path)
            if isfile(join(path, f)) and f.endswith('.png')]

  processed_frame = pd.DataFrame({
      'fnames_raw': [f.split('_', 1)[1] for f in fnames],
      'fnames': fnames,
      'fragmentID': [f.split('_', 1)[0] for f in fnames],
  })
  # Cast to int so a freshly built frame has the same papyID dtype as one
  # re-read from the cache (which is parsed with dtype int above).
  processed_frame['papyID'] = processed_frame.fnames_raw.apply(lambda x: x.split('_', 1)[0]).astype(int)
  processed_frame['posinfo'] = processed_frame.fnames_raw.apply(lambda x: ''.join(filter(str.isalpha, x)))
  processed_frame['pixelCentimer'] = processed_frame.fnames_raw.progress_apply(retrive_size_by_fname)
  processed_frame.to_csv(info_path)

  return processed_frame

In [None]:
class FAUPapyrusCollectionDataset(Dataset):
    """FAUPapyrusCollection dataset.

    Yields ``(image, papyID)`` pairs. The image is read from
    ``<root_dir>/<fnames>.png`` for the requested row of ``processed_frame``.
    """

    def __init__(self, root_dir, processed_frame, transform=None):
        """
        Parameters
        ----------
        root_dir : str
            Directory containing the PNG files.
        processed_frame : pandas.DataFrame
            Table with at least the columns ``fnames`` and ``papyID``.
        transform : callable, optional
            Applied to the loaded PIL image before it is returned.
        """
        self.root_dir = root_dir
        self.processed_frame = processed_frame
        self.transform = transform

    def __len__(self):
        """Number of rows in ``processed_frame``."""
        return len(self.processed_frame)

    def __getitem__(self, idx):
        """Return ``(image, papyID)`` for row ``idx``."""
        # Imported here so `PIL.Image` is guaranteed to be loaded; a bare
        # `import PIL` does not import the submodule by itself.
        from PIL import Image

        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Address the columns by name rather than by position so an extra or
        # reordered column in a cached CSV cannot silently swap the fields.
        img_name = os.path.join(self.root_dir,
                                self.processed_frame['fnames'].iloc[idx]) + '.png'

        # The context manager closes the file handle. convert('RGB') returns a
        # fully loaded 3-channel copy, so grayscale/RGBA PNGs also fit the
        # 3-channel normalisation and models used in this notebook.
        with Image.open(img_name) as handle:
            image = handle.convert('RGB')

        if self.transform:
            image = self.transform(image)

        papyID = self.processed_frame['papyID'].iloc[idx]
        return image, papyID

## Helper Functions

In [None]:
### convenient function from pytorch-metric-learning ###
def get_all_embeddings(dataset, model):
    """Run ``model`` over ``dataset`` with a default ``testers.BaseTester``.

    Returns whatever ``BaseTester.get_all_embeddings`` returns; the callers in this
    notebook unpack it as ``(embeddings, labels)``.
    """
    return testers.BaseTester().get_all_embeddings(dataset, model)

In [None]:
def create_output_dir(name, experiment_name, x=1):
  """Create and return the first not-yet-existing output directory for an experiment.

  The candidate name is ``name + '<x>_iteration_' + 'of_experiment_' + experiment_name``
  (the ``<x>_iteration_`` part is left out when ``x`` is 0), stripped of surrounding
  whitespace. ``x`` is increased by one until the candidate does not exist.

  Parameters
  ----------
  name : str
      Path prefix of the directory.
  experiment_name : str
      Experiment label appended to the directory name.
  x : int, optional
      First iteration number to try.

  Returns
  -------
  str
      Path of the directory that was created.
  """
  while True:
    # Compare by value; `is not 0` tests object identity against a literal.
    iteration_part = str(x) + '_iteration_' if x != 0 else ''
    dir_name = (name + iteration_part + 'of_experiment_' + experiment_name).strip()
    if not os.path.exists(dir_name):
      # makedirs also creates missing parent directories of the prefix.
      os.makedirs(dir_name)
      return dir_name
    x = x + 1

In [None]:
def replace_helper(some_list_1, some_list_2):
  """Stringify two parallel iterables and replace every underscore with a space.

  The inputs are walked pairwise, so both results are as long as the shorter input.
  Returns the two cleaned lists as a tuple.
  """
  pairs = list(zip(some_list_1, some_list_2))
  cleaned_1 = [str(left).replace("_", " ") for left, _ in pairs]
  cleaned_2 = [str(right).replace("_", " ") for _, right in pairs]
  return cleaned_1, cleaned_2

## Model Definition

In [None]:
### MNIST code originally from https://github.com/pytorch/examples/blob/master/mnist/main.py ###
class Net(nn.Module):
    """Small CNN that maps a 3-channel image batch to 128-dimensional embeddings.

    Two 3x3 convolutions, a 2x2 max-pool, dropout and one linear layer. The
    feature map is pooled to a fixed 14x14 grid before the linear layer, so the
    network accepts inputs of different spatial sizes; for 32x32 inputs that
    extra pooling step is the identity.
    """

    # Spatial size the linear layer was dimensioned for: 64 * 14 * 14 == 12544.
    _POOLED_SIZE = 14

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        # Not used in forward(); kept so the module's attributes stay unchanged.
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(64 * self._POOLED_SIZE * self._POOLED_SIZE, 128)

    def forward(self, x):
        """Return a ``(batch, 128)`` tensor for an input of shape ``(batch, 3, H, W)``."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        # Without this step only 32x32 inputs match fc1's 12544 input features;
        # a 64x64 input would produce 64 * 30 * 30 features and fail in fc1.
        x = F.adaptive_max_pool2d(x, self._POOLED_SIZE)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        return self.fc1(x)

## Define Training

In [None]:
def train(model, loss_func, mining_func, device, train_loader, optimizer, epoch):
    """Run one training pass over ``train_loader``.

    For each batch the embeddings are computed, ``mining_func`` selects the
    index tuple, ``loss_func`` is evaluated on it and one optimizer step is taken.
    Progress is logged for every 20th batch.

    Parameters
    ----------
    model, loss_func, mining_func, optimizer
        Network, loss, miner and optimizer to use.
    device : torch.device
        Device the batches are moved to.
    train_loader : iterable
        Yields ``(data, labels)`` batches.
    epoch : int
        Current epoch; not used inside this function.

    Returns
    -------
    tuple
        ``(loss, num_triplets)``: the loss of the last batch as a Python float and
        the miner's ``num_triplets`` attribute (``None`` if the miner has none).

    Raises
    ------
    ValueError
        If ``train_loader`` yields no batch.
    """
    model.train()

    loss = None
    for batch_idx, (data, labels) in enumerate(train_loader):
        data, labels = data.to(device), labels.to(device)
        optimizer.zero_grad()

        embeddings = model(data)
        indices_tuple = mining_func(embeddings, labels)
        loss = loss_func(embeddings, labels, indices_tuple)

        loss.backward()
        optimizer.step()

        # Console-Logging
        if batch_idx % 20 == 0:
          logger.info(f' Training:')
          # Read defensively so a miner without this attribute does not raise.
          logger.info(f'  Mined Tiplets {getattr(mining_func, "num_triplets", None)}')
          logger.info(f'  Loss {loss}')

    if loss is None:
        raise ValueError('train_loader yielded no batches; cannot report a training loss.')

    # .item() detaches the value from the autograd graph and from the device, so
    # the caller can collect it in a list and hand it to numpy/matplotlib.
    return loss.item(), getattr(mining_func, 'num_triplets', None)

## Define Validation

In [None]:
def validation(model, loss_func, mining_func, device, eval_loader, optimizer, train_set, eval_set, accuracy_calculator,epoch):
    """Evaluate the model on ``eval_loader`` and compute accuracy metrics.

    The loss is evaluated batch-wise without building autograd graphs and without
    touching parameter gradients. Afterwards embeddings of ``train_set`` and
    ``eval_set`` are computed and passed to ``accuracy_calculator``.

    Parameters
    ----------
    model, loss_func, mining_func
        Network, loss and miner to evaluate with.
    device : torch.device
        Device the batches are moved to.
    eval_loader : iterable
        Yields ``(data, labels)`` batches.
    optimizer
        Kept for interface compatibility; not used.
    train_set, eval_set
        Datasets handed to ``get_all_embeddings``.
    accuracy_calculator
        Object whose ``get_accuracy`` returns a dict with the keys ``AMI``,
        ``NMI``, ``mean_average_precision`` and ``precision_at_1``.
    epoch : int
        Current epoch; not used inside this function.

    Returns
    -------
    tuple
        ``(loss, num_triplets, AMI, NMI, mean_average_precision, precision_at_1)``
        where ``loss`` is the last batch loss as a Python float and
        ``num_triplets`` is ``None`` if the miner has no such attribute.

    Raises
    ------
    ValueError
        If ``eval_loader`` yields no batch.
    """
    model.eval()

    loss = None
    # No gradients are needed for evaluation. The former optimizer.zero_grad()
    # call is dropped on purpose: it cleared the gradients of the last training
    # step, which the gradient plot reads after validation.
    with torch.no_grad():
        for batch_idx, (data, labels) in enumerate(eval_loader):
            data, labels = data.to(device), labels.to(device)
            embeddings = model(data)
            indices_tuple = mining_func(embeddings, labels)
            loss = loss_func(embeddings, labels, indices_tuple)

            # Console-Logging
            if batch_idx % 20 == 0:
              logger.info(f' Validation:')
              logger.info(f'  Mined Tiplets: {getattr(mining_func, "num_triplets", None)}')
              logger.info(f'  Loss{loss}')

    if loss is None:
        raise ValueError('eval_loader yielded no batches; cannot report a validation loss.')

    train_embeddings, train_labels = get_all_embeddings(train_set, model)
    eval_embeddings, eval_labels = get_all_embeddings(eval_set, model)

    train_labels = train_labels.squeeze(1)
    eval_labels = eval_labels.squeeze(1)

    # NOTE(review): the positional argument order of get_accuracy differs between
    # pytorch-metric-learning releases — confirm against the installed version.
    accuracies = accuracy_calculator.get_accuracy(
        eval_embeddings, train_embeddings, eval_labels, train_labels, False
    )
    logger.info(f'  AMI {accuracies["AMI"]}')
    logger.info(f'  NMI {accuracies["NMI"]}')
    logger.info(f'  MAP {accuracies["mean_average_precision"]}')
    logger.info(f'  P@1 {accuracies["precision_at_1"]}')

    return loss.item(), getattr(mining_func, 'num_triplets', None), accuracies["AMI"], accuracies["NMI"], accuracies["mean_average_precision"], accuracies["precision_at_1"]

## Define Testing

In [None]:
### compute accuracy using AccuracyCalculator from pytorch-metric-learning ###
def test(train_set, test_set, model, accuracy_calculator, epoch):

    train_embeddings, train_labels = get_all_embeddings(train_set, model)
    test_embeddings, test_labels = get_all_embeddings(test_set, model)

    train_labels = train_labels.squeeze(1)
    test_labels = test_labels.squeeze(1)
    accuracies = accuracy_calculator.get_accuracy(
        test_embeddings, train_embeddings, test_labels, train_labels, False
    )  

    # Console-Logging
    logger.info(f' Test:')
    logger.info(f'  AMI {accuracies["AMI"]}')
    logger.info(f'  NMI {accuracies["NMI"]}')
    logger.info(f'  MAP {accuracies["mean_average_precision"]}')
    logger.info(f'  P@1 {accuracies["precision_at_1"]}')

    return accuracies["AMI"], accuracies["NMI"], accuracies["mean_average_precision"], accuracies["precision_at_1"]

## Visualization and Logging

In [None]:
def set_size(width, fraction=1, subplots=(1, 1)):
    """Compute figure dimensions (in inches) that need no rescaling in LaTeX.

    Parameters
    ----------
    width: float or string
            Document width in points, or one of the presets 'thesis' / 'beamer'
    fraction: float, optional
            Share of the document width the figure should take up
    subplots: array-like, optional
            (rows, columns) of the subplot grid
    Returns
    -------
    fig_dim: tuple
            (width, height) of the figure in inches; the height follows the
            golden ratio, scaled by rows / columns
    """
    width_pt = 426.79135 if width == 'thesis' else (307.28987 if width == 'beamer' else width)
    n_rows, n_cols = subplots[0], subplots[1]

    # points -> inches (1 in = 72.27 pt)
    fig_width_in = (width_pt * fraction) * (1 / 72.27)
    # golden ratio gives the height for a single panel
    fig_height_in = fig_width_in * ((5**.5 - 1) / 2) * (n_rows / n_cols)

    return (fig_width_in, fig_height_in)

In [None]:
def plot_loss(train_loss_values, val_loss_values, epochs, output_path):
  """Plot training and validation loss per epoch and save it as ``<output_path>/loss.pdf``.

  The raw values are drawn dotted; a spline-interpolated curve is drawn solid on
  top when at least two epochs are available.

  Parameters
  ----------
  train_loss_values, val_loss_values : sequence
      One loss per epoch; Python numbers or single-element tensors.
  epochs : int
      Number of epochs; must equal the length of both sequences.
  output_path : str
      Existing directory the PDF is written to.
  """
  epochs = np.arange(1, epochs + 1)
  # float() accepts plain numbers as well as single-element tensors (also ones
  # that track gradients or live on another device), which np.array cannot convert.
  train_loss_values = np.array([float(v) for v in train_loss_values])
  val_loss_values = np.array([float(v) for v in val_loss_values])

  # The 'seaborn' style sheet was renamed to 'seaborn-v0_8' in newer matplotlib.
  plt.style.use('seaborn' if 'seaborn' in plt.style.available else 'seaborn-v0_8')
  width = 460

  tex_fonts = {
      # Use LaTeX to write all text
      "text.usetex": True,
      "font.family": "serif",
      # Use 10pt font in plots, to match 10pt font in document
      "axes.labelsize": 10,
      "font.size": 10,
      # Make the legend/label fonts a little smaller
      "legend.fontsize": 8,
      "xtick.labelsize": 8,
      "ytick.labelsize": 8,
      "legend.loc":'lower left'
  }
  plt.rcParams.update(tex_fonts)

  fig, ax = plt.subplots(1, 1, figsize=set_size(width))

  # plot original lines
  ax.plot(epochs, train_loss_values, 'b', label='Training Loss', linestyle='dotted')
  ax.plot(epochs, val_loss_values, 'g', label='Validation Loss', linestyle='dotted')

  # plot smoothed lines; a spline of degree k needs at least k + 1 points, so
  # the degree is lowered for short runs (cubic, as before, from 4 epochs on).
  if len(epochs) >= 2:
    spline_degree = min(3, len(epochs) - 1)
    epochs_smooth = np.linspace(epochs.min(), epochs.max(), 300)
    train_loss_smooth = make_interp_spline(epochs, train_loss_values, k=spline_degree)(epochs_smooth)
    val_loss_smooth = make_interp_spline(epochs, val_loss_values, k=spline_degree)(epochs_smooth)
    ax.plot(epochs_smooth, train_loss_smooth, 'b', label='Training Smoothed Loss')
    ax.plot(epochs_smooth, val_loss_smooth, 'g', label='Validation Smoothed Loss')

  ax.set_title('Training and Validation Loss')
  ax.set_xlabel('Epochs')
  ax.set_ylabel('Loss')
  ax.legend()
  # Save and remove excess whitespace
  fig.savefig(output_path + '/loss.pdf', format='pdf', bbox_inches='tight')
  plt.close(fig)

In [None]:
def plot_acc(val_AMI_values, val_NMI_values,val_mean_average_precision_values, val_precision_at_1_values, epochs, output_path):
  """Plot the four validation metrics per epoch and save them as ``<output_path>/acc.pdf``.

  Each metric is drawn as a thin dotted raw line and, when at least two epochs
  are available, as a spline-interpolated solid line in the same colour.

  Parameters
  ----------
  val_AMI_values, val_NMI_values, val_mean_average_precision_values, val_precision_at_1_values : sequence
      One value per epoch for each metric.
  epochs : int
      Number of epochs; must equal the length of every sequence.
  output_path : str
      Existing directory the PDF is written to.
  """
  epochs = np.arange(1, epochs + 1)
  # The 'seaborn' style sheet was renamed to 'seaborn-v0_8' in newer matplotlib.
  plt.style.use('seaborn' if 'seaborn' in plt.style.available else 'seaborn-v0_8')
  width = 460
  tex_fonts = {
      # Use LaTeX to write all text
      "text.usetex": True,
      "font.family": "serif",
      # Use 10pt font in plots, to match 10pt font in document
      "axes.labelsize": 10,
      "font.size": 10,
      # Make the legend/label fonts a little smaller
      "legend.fontsize": 8,
      "xtick.labelsize": 8,
      "ytick.labelsize": 8,
      "legend.loc":'lower left'
  }
  plt.rcParams.update(tex_fonts)

  # (values, colour, raw label, smoothed label) — labels are kept exactly as before.
  series = [
      (val_AMI_values, 'b', ' Val AMI', 'Val AMI Smoothed'),
      (val_NMI_values, 'g', 'Val NMI', 'Val NMI Smoothed'),
      (val_mean_average_precision_values, 'r', 'Val MAP', 'Val MAP Smoothed'),
      (val_precision_at_1_values, 'm', 'Val P@1', 'Val P@1 Smoothed '),
  ]
  series = [(np.array([float(v) for v in values]), colour, raw_label, smooth_label)
            for values, colour, raw_label, smooth_label in series]

  fig, ax = plt.subplots(1, 1, figsize=set_size(width))
  for values, colour, raw_label, _ in series:
    ax.plot(epochs, values, colour, label=raw_label, linestyle='dotted', linewidth=.3)

  # A spline of degree k needs at least k + 1 points, so the degree is lowered
  # for short runs (cubic, as before, from 4 epochs on).
  if len(epochs) >= 2:
    spline_degree = min(3, len(epochs) - 1)
    epochs_smooth = np.linspace(epochs.min(), epochs.max(), 300)
    for values, colour, _, smooth_label in series:
      smoothed = make_interp_spline(epochs, values, k=spline_degree)(epochs_smooth)
      ax.plot(epochs_smooth, smoothed, colour, label=smooth_label, linewidth=.6)

  ax.set_title('Validation Accuracy')
  ax.set_xlabel('Epochs')
  ax.set_ylabel('Accuracy')
  ax.legend()
  fig.savefig(output_path + '/acc.pdf', format='pdf', bbox_inches='tight')
  plt.close(fig)

In [None]:
def _save_table_pdf(mapping, column_label, title, file_path, width):
  """Render the key/value pairs of ``mapping`` as a two-column table and save it as a PDF."""
  names, values = replace_helper(list(mapping.keys()), list(mapping.values()))
  cell_text = np.array([names, values], dtype=str).T
  fig, ax = plt.subplots(1, 1, figsize=set_size(width))
  ax.table(cellText=cell_text, colLabels=[column_label, 'Value'], loc='center', zorder=3, rowLoc='left', cellLoc='left')
  ax.set_title(title)
  ax.set_xticks([])
  ax.set_yticks([])
  fig.savefig(file_path, format='pdf', bbox_inches='tight')
  plt.close(fig)


def plot_table(setting, param, dml_param, output_path):
  """Save the three configuration sections as table PDFs in ``output_path``.

  Writes ``settings.pdf``, ``params.pdf`` and ``dml_params.pdf``. Keys and values
  are stringified and their underscores replaced by spaces (via
  ``replace_helper``) before rendering.

  Parameters
  ----------
  setting, param, dml_param : dict
      Configuration sections to render.
  output_path : str
      Existing directory the PDFs are written to.
  """
  width = 460
  # The 'seaborn-*' style sheets were renamed to 'seaborn-v0_8-*' in newer matplotlib.
  plt.style.use('seaborn-bright' if 'seaborn-bright' in plt.style.available else 'seaborn-v0_8-bright')
  tex_fonts = {
        # Use LaTeX to write all text
        "text.usetex": True,
        "font.family": "serif",
        # Use 10pt font in plots, to match 10pt font in document
        "axes.labelsize": 10,
        "font.size": 10,
        # Make the legend/label fonts a little smaller
        "legend.fontsize": 8,
        "xtick.labelsize": 8,
        "ytick.labelsize": 8
    }
  plt.rcParams.update(tex_fonts)

  _save_table_pdf(setting, 'Setting', 'Experiment Settings', output_path + '/settings.pdf', width)
  # NOTE(review): the title below is misspelled ("Hyperparaeters"); left unchanged
  # here because it is text written into the report.
  _save_table_pdf(param, 'Hyperparameter', 'Hyperparaeters', output_path + '/params.pdf', width)
  _save_table_pdf(dml_param, 'DML Hyperparameter', 'DML Hyperparameters', output_path + '/dml_params.pdf', width)

In [None]:
def gradient_visualization(parameters, results_folder: str):
    """
    Plots the mean absolute gradient of every non-bias parameter and saves the
    figure as ``<results_folder>/gradients.pdf``. Nothing is returned.

    Parameters that do not require gradients, whose name contains "bias", or
    that currently hold no gradient are skipped.

    :param parameters: (name, parameter) pairs of the network
    :type parameters: iterator
    :param results_folder: path to results folder
    :type results_folder: str
    """
    tex_fonts = {
        # LaTeX rendering is switched off for this plot
        "text.usetex": False,
        "font.family": "serif",
        # Use 10pt font in plots, to match 10pt font in document
        "axes.labelsize": 10,
        "font.size": 10,
        # Make the legend/label fonts a little smaller
        "legend.fontsize": 8,
        "xtick.labelsize": 8,
        "ytick.labelsize": 8,
        "legend.loc":'lower left'
    }
    plt.rcParams.update(tex_fonts)

    ave_grads = []
    layers = []
    for n, p in parameters:
        # p.grad is None until a backward pass has reached the parameter.
        if p.requires_grad and ("bias" not in n) and p.grad is not None:
            layers.append(n)
            # .item() yields a plain float, so matplotlib never has to convert
            # a tensor (which fails for tensors that are not on the CPU).
            ave_grads.append(p.grad.abs().mean().item())

    plt.plot(ave_grads, alpha=0.3, color="b")
    plt.hlines(0, 0, len(ave_grads) + 1, linewidth=1, color="k")
    plt.xticks(range(0, len(ave_grads), 1), layers, rotation="vertical")
    plt.xlim(left=0, right=len(ave_grads))
    plt.ylim(bottom=0, top=0.0075)
    plt.xlabel("Layers")
    plt.ylabel("average gradient")
    plt.title("Gradient Visualization")
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(results_folder + "/gradients.pdf")
    plt.close()

In [None]:
def create_logging(setting, param, dml_param, train_loss_values, val_loss_values, val_AMI_values, val_NMI_values,val_mean_average_precision_values, val_precision_at_1_values, epochs, output_dir):
  """Build the experiment report ``<output_dir>/report.pdf``.

  Renders the configuration tables, the loss plot and the accuracy plot, then
  merges them with the histogram and gradient PDFs that are expected to exist in
  ``output_dir`` already, plus a PDF rendering of ``log.txt`` from the working
  directory. Finally ``log.txt`` is copied into ``output_dir``.

  Parameters
  ----------
  setting, param, dml_param : dict
      Configuration sections for the tables.
  train_loss_values, val_loss_values, val_AMI_values, val_NMI_values, val_mean_average_precision_values, val_precision_at_1_values : sequence
      Per-epoch values for the plots.
  epochs : int
      Number of epochs covered by the sequences.
  output_dir : str
      Existing directory holding the partial PDFs and receiving the report.
  """
  plot_table(setting, param, dml_param, output_dir)
  plot_loss(train_loss_values, val_loss_values, epochs, output_dir)
  plot_acc(val_AMI_values, val_NMI_values,val_mean_average_precision_values, val_precision_at_1_values, epochs, output_dir)

  pdfs = ['/loss.pdf', '/acc.pdf', '/params.pdf','/dml_params.pdf', '/settings.pdf','/HistogramFragAfterTrain.pdf','/HistogramFragAfterVal.pdf', '/HistogramFragTest.pdf', '/gradients.pdf']
  # NOTE(review): several bookmark names are misspelled ('Seetings', 'HistrogramVal',
  # 'HistrogramTest'); left unchanged here because they are text written into the report.
  bookmarks = ['Loss', 'Accuracy', 'Hyperparameters','DML Hyperparameters', 'Seetings', 'HistogramTrain', 'HistrogramVal', 'HistrogramTest','Gradients']

  merger = PdfFileMerger()
  try:
    for pdf_name, bookmark in zip(pdfs, bookmarks):
      merger.append(output_dir + pdf_name, bookmark=bookmark)

    # Render the text log as one extra PDF section.
    log_pdf = FPDF()
    log_pdf.add_page()
    log_pdf.set_font("Helvetica", size = 6)
    # The context manager closes log.txt even if rendering a line fails.
    with open("log.txt", "r") as log_file:
      for line in log_file:
        # Drop the line break; every cell already starts a new line (ln = 1).
        log_pdf.cell(200, 6, txt = line.rstrip('\n'), ln = 1, align = 'l')
    log_pdf.output("log.pdf")

    merger.append("log.pdf", bookmark='Log')
    merger.write(output_dir + "/report.pdf")
  finally:
    # Release the merger's open files even when one of the inputs is missing.
    merger.close()

  copyfile('log.txt', output_dir + '/log.txt')

In [None]:
def remove_empty_bins(counts):
  """Drop the zero entries of a histogram.

  Returns ``(ticks, counts)``: the positions of the non-zero entries as strings
  and the non-zero entries themselves, both as lists in their original order.
  """
  kept = [(str(position), value) for position, value in enumerate(counts) if value != 0]
  return [tick for tick, _ in kept], [value for _, value in kept]

In [None]:
def get_cleand_papyri_hist(processed_frame, title, fig_name):
  """Save a bar chart of how many groups contain a given number of distinct fragments.

  Rows are grouped by ``papyPosID`` when that column exists and by ``papyID``
  otherwise. The number of distinct ``fnames`` per group is counted; the chart
  shows, for every count that occurs, how many groups have it.

  Parameters
  ----------
  processed_frame : pandas.DataFrame
      Fragment table with ``fnames`` and ``papyPosID`` or ``papyID``.
  title : str
      Title of the chart.
  fig_name : str
      Path of the PDF to write.
  """
  # Frames built by create_processed_info carry no 'papyPosID' column; fall back
  # to 'papyID' instead of failing with a KeyError.
  # NOTE(review): confirm that grouping by papyID is the intended fallback.
  group_key = 'papyPosID' if 'papyPosID' in processed_frame.columns else 'papyID'
  id_series = processed_frame.groupby(group_key)['fnames'].nunique().sort_values(ascending=False)
  counts = np.bincount(id_series)
  ticks, counts = remove_empty_bins(counts)

  # The 'seaborn' style sheet was renamed to 'seaborn-v0_8' in newer matplotlib.
  plt.style.use('seaborn' if 'seaborn' in plt.style.available else 'seaborn-v0_8')
  width = 460

  tex_fonts = {
      # Use LaTeX to write all text
      "text.usetex": True,
      "font.family": "serif",
      # Use 10pt font in plots, to match 10pt font in document
      "axes.labelsize": 10,
      "font.size": 10,
      # Make the legend/label fonts a little smaller
      "legend.fontsize": 8,
      "xtick.labelsize": 8,
      "ytick.labelsize": 8,
      "legend.loc":'lower left'
  }
  plt.rcParams.update(tex_fonts)

  fig, ax = plt.subplots(1, 1, figsize=set_size(width))

  rects = ax.bar(ticks, counts, width=.5, align='center')
  ax.set(xticks=ticks, xlim=[-1, len(ticks)], title=title, ylabel='Number of papyri',xlabel='Number of fragments per papyri')

  # Write each bar's height just above it.
  for rect in rects:
    height = rect.get_height()
    ax.annotate('{}'.format(height),xy=(rect.get_x() + rect.get_width() / 2, height),xytext=(0, 3), textcoords="offset points",ha='center', va='baseline')

  fig.tight_layout()
  fig.savefig(fig_name, format='pdf', bbox_inches='tight')
  # The figure is only saved; the former plt.show() after closing displayed nothing.
  plt.close(fig)

In [None]:
# Disabled scratch cell (`if False`): loads the experiment config, creates an output
# directory and the processed train frame, then evaluates the number of distinct
# papyID values. Note that enabling it creates a new output directory on disk.
if False:
  config_path = './gdrive/MyDrive/mt/conf/conf.toml'
  config = toml.load(config_path)
  setting = config.get('settings') 
  logger.info(f' Start Experiment {setting["experiment_name"]}')
  param = config.get('params')
  dml_param = config.get('dml_params')
  output_dir = create_output_dir(setting['output'], setting['experiment_name'])
  processed_frame = create_processed_info(setting['path_train'], setting['debug'])
  len(processed_frame.papyID.unique())

# Start Experiment

In [None]:
def train_and_validate(config_path):
  """Run one complete experiment described by a TOML config file.

  Reads the ``settings``, ``params`` and ``dml_params`` sections, creates a
  fresh output directory, builds datasets, samplers, loaders, model, optimizer,
  distance, reducer, loss, miner and accuracy calculator from the config, then
  trains for ``params['num_epochs']`` epochs. After every epoch the model is
  validated, a checkpoint is saved and the gradient plot is written; from
  epoch 4 on the PDF report is regenerated as well.

  Parameters
  ----------
  config_path : str
      Path of the TOML configuration file.

  Raises
  ------
  ValueError
      If the configured architecture, optimizer, distance, reducer, loss or
      miner is not one of the supported names.
  """
  logger.info(f'Initilization  -------------------------')
  config = toml.load(config_path)
  setting = config.get('settings')
  logger.info(f'Experiment:           {setting["experiment_name"]}')
  param = config.get('params')
  dml_param = config.get('dml_params')
  output_dir = create_output_dir(setting['output'], setting['experiment_name'])
  device = torch.device(setting['env'])

  ###### Log ######
  logger.info(f'Debug:                {setting["debug"]}')
  logger.info(f'Loos Function:        {dml_param["loss_function"]}')
  logger.info(f'Margin Miner Margin:  {dml_param["TripletMarginMiner_margin"]}')
  logger.info(f'Triplet Margin Loss:  {dml_param["TripletMarginLoss_margin"]}')
  logger.info(f'Type of Tribles:      {dml_param["type_of_triplets"]}')
  logger.info(f'Miner:                {dml_param["miner"]}')
  logger.info(f'Reducer:              {dml_param["reducer"]}')
  logger.info(f'Archi:                {param["archi"]}')
  logger.info(f'Epochs:               {param["num_epochs"]}')
  logger.info(f'Batch Size:           {param["batch_size"]}')
  logger.info(f'Optimizer:            {param["optimizer"]}')
  logger.info(f'Learning Rate:        {param["lr"]}')
  logger.info(f'Shuffle:              {param["shuffle"]}')
  logger.info(f'Padding Width:        {param["padding_width"]}')
  logger.info(f'Padding Height:       {param["padding_height"]}')
  logger.info(f'Center Crop Width:    {param["center_crop_width"]}')
  logger.info(f'Center Crop Height:   {param["center_crop_height"]}')

  ###### Transformation ######
  # Images are resized to 64x64; the padding / center-crop settings logged above
  # are not applied by this pipeline.
  transform = transforms.Compose(
    [
        transforms.Resize((64,64)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[param['normalize_0'], param['normalize_1'], param['normalize_2']], std=[param['normalize_3'], param['normalize_4'], param['normalize_5']]),
    ])

  ############### Debug or Train Data ###############
  # Only the train and validation frames honour the debug flag; the test frame
  # always comes from the regular processed_info.csv.
  processed_frame_train = create_processed_info(setting['path_train'], setting['debug'])
  processed_frame_val = create_processed_info(setting['path_val'], setting['debug'])
  processed_frame_test = create_processed_info(setting['path_test'])

  ############### Datasets ###############
  dataset1 = FAUPapyrusCollectionDataset(setting['path_train'], processed_frame_train, transform)
  dataset2 = FAUPapyrusCollectionDataset(setting['path_val'], processed_frame_val, transform)
  dataset3 = FAUPapyrusCollectionDataset(setting['path_test'], processed_frame_test, transform)

  ############### Sampler ###############
  # MPerClassSampler needs one label per dataset element (labels[i] belongs to
  # dataset[i]); passing only the unique labels would restrict sampling to the
  # first few rows. Each sampler is built from the frame of its own split.
  train_labels = processed_frame_train.papyID.values
  test_labels = processed_frame_test.papyID.values

  # NOTE(review): with one sample per class a batch holds positive pairs only when
  # a class repeats inside it; triplet mining usually needs m >= 2 — confirm.
  train_sampler = samplers.MPerClassSampler(train_labels, 1, batch_size=None, length_before_new_iter=len(dataset1))
  test_sampler = samplers.MPerClassSampler(test_labels, 1, batch_size=None, length_before_new_iter=len(dataset3))

  ############### Log Distribution ###############
  get_cleand_papyri_hist(processed_frame_train, title = 'Histogram Fragments / Papyri  [Train]', fig_name=output_dir + '/HistogramFragAfterTrain.pdf')
  get_cleand_papyri_hist(processed_frame_val, title = 'Histogram Fragments / Papyri  [Val]', fig_name=output_dir +  '/HistogramFragAfterVal.pdf')
  get_cleand_papyri_hist(processed_frame_test, title = 'Histogram Fragments / Papyri  [Test]', fig_name=output_dir +  '/HistogramFragTest.pdf')

  ############### Dataloader ###############
  # Every loader is paired with the sampler of its own dataset. Sampler indices
  # are positions in the label array, so a sampler built for another split would
  # address the wrong rows. The validation loader iterates its dataset in order.
  train_loader = torch.utils.data.DataLoader(dataset1, batch_size=param['batch_size'], shuffle=False, sampler=train_sampler)
  val_loader = torch.utils.data.DataLoader(dataset2, batch_size=param['batch_size'], drop_last=True)
  # Built for completeness; the epoch loop below does not consume it.
  test_loader = torch.utils.data.DataLoader(dataset3, batch_size=param['batch_size'], drop_last=True, sampler=test_sampler)

  ############### Archi ###############
  if param['archi'] == 'SimpleCNN':
    model = Net().to(device)
  elif param['archi'] == 'efficientnetB0':
    model = EfficientNet.from_name('efficientnet-b0').to(device)
  elif param['archi'] == 'ResNet':
    model = models.resnet18(pretrained=True).to(device)
  else:
    # Fail here with a clear message instead of with a NameError further down.
    logger.error(' Architecture is not supported or you have not specified one.')
    raise ValueError()

  ############### Optimizer ###############
  if param['optimizer'] == 'Adam':
    optimizer = optim.Adam(model.parameters(), lr=param['lr'], weight_decay=0.00005)
  elif param['optimizer'] == 'SGD':
    optimizer = optim.SGD(model.parameters(), lr=param['lr'], momentum=param['sgd_momentum'])
  else:
    logger.error(' Optimizer is not supported or you have not specified one.')
    raise ValueError()

  ###############  Distance ###############
  if dml_param['distance'] == 'CosineSimilarity':
    distance = distances.CosineSimilarity()
  elif dml_param['distance'] == 'LpDistance':
    distance = distances.LpDistance(normalize_embeddings=True, p=2, power=1)
  else:
    logger.error(' Distance is not supported or you have not specified one.')
    raise ValueError()

  ###############  Reducer ###############
  if dml_param['reducer'] == 'ThresholdReducer':
    reducer = reducers.ThresholdReducer(low=dml_param['ThresholdReducer_low'])
  elif dml_param['reducer'] == 'AvgNonZeroReducer':
    reducer = reducers.AvgNonZeroReducer()
  else:
    logger.error(f'Reducer is not supported or you have not specified one.')
    raise ValueError()

  ###############  Loss ###############
  if dml_param['loss_function'] == 'TripletMarginLoss':
    loss_func = losses.TripletMarginLoss(margin=dml_param['TripletMarginLoss_margin'], distance=distance, reducer=reducer)
  elif dml_param['loss_function'] == 'ContrastiveLoss':
    # NOTE(review): this loss receives neither the configured distance nor the
    # reducer, and the margins look like the ones meant for a similarity
    # measure — confirm against the library documentation.
    loss_func = losses.ContrastiveLoss(pos_margin=1, neg_margin=0)
  else:
    logger.error(' DML Loss is not supported or you have not specified one.')
    raise ValueError()

  ############### Mining ###############
  if dml_param['miner'] == 'PairMarginMiner':
    mining_func = miners.PairMarginMiner(pos_margin=0.2, neg_margin=0.8)
  elif dml_param['miner'] == 'TripletMarginMiner':
    mining_func = miners.TripletMarginMiner(
        margin=dml_param['TripletMarginMiner_margin'], distance=distance, type_of_triplets=dml_param['type_of_triplets'])
  elif dml_param['miner'] == 'UniformHistogramMiner':
    mining_func = miners.UniformHistogramMiner(num_bins=100,pos_per_bin=25,neg_per_bin=33,distance=distance)
  else:
    logger.error('DML Miner is not supported or you have not specified one.')
    raise ValueError()

  ############### Accuracy ###############
  accuracy_calculator = AccuracyCalculator(include=(dml_param['metric_1'],
                                                    dml_param['metric_2'],
                                                    dml_param['metric_3'],
                                                    dml_param['metric_4']),
                                           k=dml_param['precision_at_1_k'])

  ############### Trainer ###############
  train_loss_values = []
  val_loss_values = []
  val_AMI_values = []
  val_NMI_values = []
  val_mean_average_precision_values = []
  val_precision_at_1_values = []

  for epoch in range(1, param['num_epochs'] + 1):

      logger.info(f'Epoch {epoch} -------------------------')

      train_loss, train_num_triplets = train(model, loss_func, mining_func, device, train_loader, optimizer, epoch)
      # float() works for plain numbers and for single-element tensors alike and
      # yields values that numpy and matplotlib can plot.
      train_loss = float(train_loss)
      train_loss_values.append(train_loss)

      ############### Validation ###############
      val_loss, val_num_triplets, val_AMI, val_NMI, val_mean_average_precision, val_precision_at_1 = validation(model, loss_func, mining_func, device, val_loader, optimizer, dataset1, dataset2, accuracy_calculator,epoch)
      val_loss_values.append(float(val_loss))
      val_AMI_values.append(val_AMI)
      val_NMI_values.append(val_NMI)
      val_mean_average_precision_values.append(val_mean_average_precision)
      val_precision_at_1_values.append(val_precision_at_1)

      ############### Checkpoint ###############
      torch.save({
                  'epoch': epoch,
                  'model_state_dict': model.state_dict(),
                  'optimizer_state_dict': optimizer.state_dict(),
                  'loss': train_loss,
                  }, output_dir + "/model "+ f'_epoch_{str(epoch)}'+ ".pt")

      ############### Logging ###############
      gradient_visualization(model.named_parameters(), output_dir)
      # The report's smoothed curves are only drawn once four epochs are available.
      if epoch >= 4:
        create_logging(setting, param, dml_param, train_loss_values, val_loss_values, val_AMI_values, val_NMI_values, val_mean_average_precision_values, val_precision_at_1_values, epoch, output_dir)

In [None]:
# Run the experiment described by the TOML config stored on the mounted Drive.
train_and_validate(config_path = './gdrive/MyDrive/mt/conf/conf.toml')

Initilization  -------------------------
Initilization  -------------------------
Experiment:           Debug
Experiment:           Debug
Debug:                False
Debug:                False
Loos Function:        TripletMarginLoss
Loos Function:        TripletMarginLoss
Margin Miner Margin:  0.02
Margin Miner Margin:  0.02
Triplet Margin Loss:  1
Triplet Margin Loss:  1
Type of Tribles:      semihard
Type of Tribles:      semihard
Miner:                TripletMarginMiner
Miner:                TripletMarginMiner
Reducer:              AvgNonZeroReducer
Reducer:              AvgNonZeroReducer
Archi:                ResNet
Archi:                ResNet
Epochs:               4
Epochs:               4
Batch Size:           128
Batch Size:           128
Optimizer:            Adam
Optimizer:            Adam
Learning Rate:        4e-05
Learning Rate:        4e-05
Shuffle:              True
Shuffle:              True
Padding Width:        512
Padding Width:        512
Padding Height:       512
