## Prepare Notebook

### Delete Outdated Log Files

In [None]:
!rm "log.txt"

rm: cannot remove 'log.txt': No such file or directory


### Install Dependencies

In [None]:
!pip install pytorch-metric-learning
!pip install faiss-gpu
!pip install PyPDF2
!pip install FPDF
!pip install efficientnet_pytorch
!pip install umap-learn
!pip install gpustat
#!apt install texlive-fonts-recommended texlive-fonts-extra cm-super dvipng

Collecting pytorch-metric-learning
  Downloading pytorch_metric_learning-1.1.0-py3-none-any.whl (106 kB)
[?25l[K     |███                             | 10 kB 28.1 MB/s eta 0:00:01[K     |██████▏                         | 20 kB 35.3 MB/s eta 0:00:01[K     |█████████▏                      | 30 kB 21.1 MB/s eta 0:00:01[K     |████████████▎                   | 40 kB 17.2 MB/s eta 0:00:01[K     |███████████████▍                | 51 kB 7.8 MB/s eta 0:00:01[K     |██████████████████▍             | 61 kB 8.3 MB/s eta 0:00:01[K     |█████████████████████▌          | 71 kB 8.0 MB/s eta 0:00:01[K     |████████████████████████▋       | 81 kB 8.9 MB/s eta 0:00:01[K     |███████████████████████████▋    | 92 kB 9.5 MB/s eta 0:00:01[K     |██████████████████████████████▊ | 102 kB 7.5 MB/s eta 0:00:01[K     |████████████████████████████████| 106 kB 7.5 MB/s 
Installing collected packages: pytorch-metric-learning
Successfully installed pytorch-metric-learning-1.1.0
Collecting fais

### Import Dependencies

In [None]:
# Show the attached GPU, driver/CUDA versions and current memory usage.
!nvidia-smi

Wed Jan 19 20:58:08 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.46       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   35C    P0    26W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
from efficientnet_pytorch import EfficientNet
from PyPDF2 import PdfFileMerger
from shutil import copyfile
from fpdf import FPDF
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from pytorch_metric_learning import distances, losses, miners, reducers, testers
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator
import pandas as pd
import numpy as np
import PIL
import os
import toml
from os.path import isfile, join
from google.colab import drive
import matplotlib
from matplotlib import rc
from sklearn.feature_extraction.image import extract_patches_2d
import umap
from skimage import io
from numpy.core.fromnumeric import mean

# Render plot text with Matplotlib's built-in mathtext, not an external LaTeX install.
rc('text', usetex=False)
# `text.latex.preamble` must be a single string: list values were deprecated in
# Matplotlib 3.3 and are no longer accepted by later releases.
matplotlib.rcParams['text.latex.preamble'] = r'\usepackage{amsmath}'

### Mount Google Drive

Expected project folder structure on Google Drive (the config is loaded from `MyDrive/mt/conf/conf.toml`):

---



---


* conf
* data
  * 04_train
  * 05_val
  * 06_test
* out
  * experimentName_Iteration

In [None]:
# Mount Google Drive; the config is later read from ./gdrive/MyDrive/... relative to /content.
GDRIVE_MOUNT_POINT = '/content/gdrive'
drive.mount(GDRIVE_MOUNT_POINT)

Mounted at /content/gdrive


### Instantiate Logger



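Messages are written both to the console and to `log.txt`; the log file is later rendered into the PDF report and copied to the output directory.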



In [None]:
# Logger used throughout the notebook: plain messages go to the console and to
# `log.txt` (which create_logging renders into the PDF report).
logger = logging.getLogger('log')
logger.setLevel(logging.DEBUG)
# Re-running this cell must not stack duplicate handlers (every message would
# then be printed/written several times), so detach and close existing ones.
for old_handler in list(logger.handlers):
  logger.removeHandler(old_handler)
  old_handler.close()
formatter = logging.Formatter('%(message)s')
# Console handler: INFO and above.
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
ch.setFormatter(formatter)
# File handler: same INFO threshold, appends to log.txt.
fh = logging.FileHandler('log.txt')
fh.setLevel(logging.INFO)
fh.setFormatter(formatter)
logger.addHandler(ch)
logger.addHandler(fh)
# Do not forward records to the root logger (would duplicate console output).
logger.propagate = False

## PyTorch Dataset Class for FAU Papyri Data

In [None]:
class FAUPapyrusCollectionDataset(torch.utils.data.Dataset):
    """Map-style dataset over the FAU papyrus fragment images.

    Each row of ``processed_frame`` describes one image. Column 1 holds the
    file name without extension, resolved as ``<root_dir>/<name>.png``.
    Column 3 holds the papyrus ID returned as the label.

    Args:
        root_dir: Directory that contains the PNG files.
        processed_frame: pandas DataFrame with one row per image.
        transform: Optional callable applied to the loaded image array.
    """

    def __init__(self, root_dir, processed_frame, transform=None):
        self.root_dir = root_dir
        self.processed_frame = processed_frame
        self.transform = transform
        # Distinct papyrus IDs present in this split.
        self.targets = processed_frame["papyID"].unique()

    def __len__(self):
        return len(self.processed_frame)

    def __getitem__(self, idx):
        """Return ``(image, papyID)`` for the row at position ``idx``."""
        row = idx.tolist() if torch.is_tensor(idx) else idx

        file_stem = self.processed_frame.iloc[row, 1]
        image_path = os.path.join(self.root_dir, file_stem) + '.png'
        image = io.imread(image_path)

        if self.transform:
            image = self.transform(image)

        label = self.processed_frame.iloc[row, 3]
        return image, label

## PyTorch Network Architectures

### Simple CNN

In [None]:
class Net(nn.Module):
    """Small two-layer CNN that maps an image batch to embedding vectors.

    The defaults reproduce the original network. It takes single-channel input,
    and its linear layer expects 12544 = 64 * 14 * 14 features. That matches
    32x32 inputs: 32 -> 30 -> 28 after the two 3x3 convolutions, then 14 after
    2x2 max-pooling.

    Args:
        in_channels: Number of input image channels (e.g. 3 for RGB crops).
        fc_in_features: Flattened feature size entering ``fc1``. For an
            ``H x W`` input, it must equal ``64 * ((H - 4) // 2) * ((W - 4) // 2)``.
        embedding_size: Dimension of the output embedding.
    """

    def __init__(self, in_channels=1, fc_in_features=12544, embedding_size=128):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        # Not used in forward(); kept so existing references to the attribute still work.
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(fc_in_features, embedding_size)

    def forward(self, x):
        """Embed ``x`` of shape (N, in_channels, H, W) into (N, embedding_size)."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        return self.fc1(x)

## PyTorch NN Functions

### Training

In [None]:
def train(model, loss_func, mining_func, device, train_loader, optimizer, train_set, epoch, accuracy_calculator, scheduler, accumulation_steps):
    """Run one training epoch with gradient accumulation, then report train mAP.

    Batches are unpacked as (bs, ncrops, c, h, w). The crops are embedded
    individually and averaged into one embedding per image before mining and
    computing the metric-learning loss.

    Args:
        model: Embedding network.
        loss_func: Loss called as ``loss_func(embeddings, labels, indices_tuple)``.
        mining_func: Miner called as ``mining_func(embeddings, labels)``.
        device: Device the batches are moved to.
        train_loader: DataLoader yielding ``(images, labels)``.
        optimizer: Optimizer stepped after every ``accumulation_steps`` batches.
        train_set: Dataset used to compute the end-of-epoch train mAP.
        epoch: Epoch number (used for logging only).
        accuracy_calculator: Object providing ``get_accuracy``.
        scheduler: LR scheduler, stepped once per epoch.
        accumulation_steps: Number of batches whose gradients are accumulated
            before each optimizer step (values < 1 are treated as 1).

    Returns:
        (mean per-sample training loss, train mean average precision)
    """
    accumulation_steps = max(1, int(accumulation_steps))
    model.train()
    model.zero_grad()
    epoch_loss = 0.0
    num_samples = 0
    num_batches = len(train_loader)

    for batch_idx, (input_imgs, labels) in enumerate(train_loader):
        labels = labels.to(device)
        input_imgs = input_imgs.to(device)
        bs, ncrops, c, h, w = input_imgs.size()

        # Fuse batch and crop dims for the forward pass, then average over crops.
        embeddings = model(input_imgs.view(-1, c, h, w))
        embeddings_avg = embeddings.view(bs, ncrops, -1).mean(1)

        indices_tuple = mining_func(embeddings_avg, labels)
        loss = loss_func(embeddings_avg, labels, indices_tuple)
        # Scale so the accumulated gradient is the mean over the accumulated batches.
        (loss / accumulation_steps).backward()

        # Step every `accumulation_steps` batches, and also on the last batch so a
        # trailing partial accumulation is not thrown away.
        if (batch_idx + 1) % accumulation_steps == 0 or (batch_idx + 1) == num_batches:
            optimizer.step()
            optimizer.zero_grad()

        # Track the unscaled loss, weighted by batch size.
        epoch_loss += bs * loss.item()
        num_samples += bs

    scheduler.step()
    mean_loss = epoch_loss / num_samples if num_samples else 0.0

    train_embeddings, train_labels = get_all_embeddings(train_set, model)
    train_labels = train_labels.squeeze(1)
    # Query and reference sets are identical, so each sample's trivial
    # self-match must be excluded (embeddings_come_from_same_source=True).
    accuracies = accuracy_calculator.get_accuracy(
        train_embeddings,
        train_embeddings,
        train_labels,
        train_labels,
        True)
    train_map = accuracies["mean_average_precision"]

    logger.info(f"Epoch {epoch} average loss from {num_batches} batches: {mean_loss}")
    logger.info(f"Epoch {epoch} mAP: {train_map}")
    return mean_loss, train_map

### Validation

In [None]:
def val(train_set, test_set, model, accuracy_calculator):
  """Compute retrieval mAP of `test_set` queries against `train_set` references.

  A chance-level baseline is also computed. It re-scores the same embeddings
  after randomly permuting the query labels.

  When the same dataset object is passed for both arguments, the query set is
  the reference set. In that case, embeddings are computed once and each
  query's trivial self-match is excluded (embeddings_come_from_same_source=True).

  Returns:
    (mAP, random-baseline mAP)
  """
  same_source = train_set is test_set

  train_embeddings, train_labels = get_all_embeddings(train_set, model)
  train_labels = train_labels.squeeze(1)
  if same_source:
    test_embeddings, test_labels = train_embeddings, train_labels
  else:
    test_embeddings, test_labels = get_all_embeddings(test_set, model)
    test_labels = test_labels.squeeze(1)

  print("Computing accuracy")
  accuracies = accuracy_calculator.get_accuracy(
      test_embeddings, train_embeddings, test_labels, train_labels, same_source
  )

  # Chance baseline: shuffled query labels no longer match their embeddings.
  # (Advanced indexing copies, so train_labels is never modified.)
  perm = torch.randperm(test_labels.nelement())
  shuffled_test_labels = test_labels.view(-1)[perm].view(test_labels.size())
  random_accuracies = accuracy_calculator.get_accuracy(
      test_embeddings, train_embeddings, shuffled_test_labels, train_labels, same_source
  )

  val_map = accuracies["mean_average_precision"]
  random_map = random_accuracies["mean_average_precision"]
  logger.info(f"Val mAP = {val_map}")
  logger.info(f"Val random mAP = {random_map}")
  return val_map, random_map
  

## Python-Helper-Functions

### Deep Metric Learning

In [None]:
from pytorch_metric_learning.testers import GlobalEmbeddingSpaceTester
from pytorch_metric_learning.utils import common_functions as c_f

class CustomTester(GlobalEmbeddingSpaceTester):
    """Tester that averages embeddings over the crops of each sample.

    Batches are expected as (bs, ncrops, c, h, w), e.g. the stacked crops
    produced by the TenCrop/NCrop transforms in this notebook.
    """

    def get_embeddings_for_eval(self, trunk_model, embedder_model, input_imgs):
        input_imgs = c_f.to_device(
            input_imgs, device=self.data_device, dtype=self.dtype
        )
        # Pattern from https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.FiveCrop
        bs, ncrops, c, h, w = input_imgs.size()
        # Fuse batch size and ncrops for the forward pass ...
        result = embedder_model(trunk_model(input_imgs.view(-1, c, h, w)))
        # ... then average over the crops.
        return result.view(bs, ncrops, -1).mean(1)

def visualizer_hook(umapper, umap_embeddings, labels, split_name, keyname, *args):
    """Scatter-plot 2-D UMAP embeddings with one colour per label.

    Passed to CustomTester as ``visualizer_hook``. ``umapper`` and any extra
    positional arguments are accepted but unused.
    """
    # `cycler` is not imported at module level in this notebook.
    from cycler import cycler

    logging.info(
        "UMAP plot for the {} split and label set {}".format(split_name, keyname)
    )
    label_set = np.unique(labels)
    num_classes = len(label_set)
    plt.figure(figsize=(20, 15))
    plt.gca().set_prop_cycle(
        cycler(
            "color", [plt.cm.nipy_spectral(i) for i in np.linspace(0, 0.9, num_classes)]
        )
    )
    for label in label_set:
        idx = labels == label
        plt.plot(umap_embeddings[idx, 0], umap_embeddings[idx, 1], ".", markersize=1)
    plt.show()

def get_all_embeddings(dataset, model, collate_fn=None, eval=True):
    """Embed every sample of `dataset` with `model` via the crop-averaging CustomTester.

    Returns the ``(embeddings, labels)`` pair produced by the tester's
    ``get_all_embeddings``; `collate_fn` and `eval` are forwarded to it.
    """
    tester = CustomTester(visualizer=umap.UMAP(), visualizer_hook=visualizer_hook)
    return tester.get_all_embeddings(dataset, model, collate_fn=collate_fn, eval=eval)

### Visualization

#### Gradients

In [None]:
def gradient_visualization(parameters, output_path):
    """
    Plot the mean absolute gradient of each weight tensor and save it as PDF.

    Several kinds of parameters are skipped:
    - bias parameters;
    - parameters with ``requires_grad=False``;
    - parameters without a gradient (``p.grad is None``, e.g. layers that were
      not part of the last backward pass).

    :param parameters: ``(name, parameter)`` pairs, e.g. ``model.named_parameters()``
    :type parameters: iterator
    :param output_path: directory the figure is written to as ``gradients.pdf``
    :type output_path: str
    """
    tex_fonts = {
        # Use LaTeX to write all text
        "text.usetex": False,
        "font.family": "serif",
        # Use 10pt font in plots, to match 10pt font in document
        "axes.labelsize": 10,
        "font.size": 10,
        # Make the legend/label fonts a little smaller
        "legend.fontsize": 8,
        "xtick.labelsize": 8,
        "ytick.labelsize": 8,
        "legend.loc": 'lower left'
    }
    plt.rcParams.update(tex_fonts)
    ave_grads = []
    layers = []
    for name, p in parameters:
        if p.requires_grad and "bias" not in name and p.grad is not None:
            layers.append(name)
            # .item() gives a host float, so CUDA gradients can be plotted.
            ave_grads.append(p.grad.abs().mean().item())

    fig, ax = plt.subplots()
    ax.plot(ave_grads, alpha=0.3, color="b")
    ax.hlines(0, 0, len(ave_grads) + 1, linewidth=1, color="k")
    ax.set_xticks(range(len(ave_grads)))
    ax.set_xticklabels(layers, rotation="vertical")
    ax.set_xlim(left=0, right=len(ave_grads))
    ax.set_xlabel("Layers")
    ax.set_ylabel("average gradient")
    ax.set_title("Gradient Visualization")
    ax.grid(True)
    fig.tight_layout()
    fig.savefig(output_path + "/gradients.pdf")
    plt.close(fig)

#### Accuracy

In [None]:
def plot_acc(map_vals, random_map_vals, train_map, epochs, output_path):
  """Plot validation, train and random-baseline mAP per epoch to `<output_path>/acc.pdf`.

  Args:
    map_vals: Validation mAP, one value per epoch.
    random_map_vals: mAP with shuffled labels (chance baseline), one per epoch.
    train_map: Train mAP, one per epoch.
    epochs: Number of epochs; each list must have this many values.
    output_path: Output directory.
  """
  width = 460
  # 'seaborn-bright' was renamed 'seaborn-v0_8-bright' in Matplotlib 3.6 and the
  # old name was later removed; use whichever name this install provides.
  plt.style.use('seaborn-v0_8-bright' if 'seaborn-v0_8-bright' in plt.style.available else 'seaborn-bright')
  tex_fonts = {
        # Use LaTeX to write all text
        "text.usetex": False,
        "font.family": "serif",
        # Use 10pt font in plots, to match 10pt font in document
        "axes.labelsize": 10,
        "font.size": 10,
        # Make the legend/label fonts a little smaller
        "legend.fontsize": 8,
        "xtick.labelsize": 8,
        "ytick.labelsize": 8
    }
  plt.rcParams.update(tex_fonts)

  epoch_axis = np.arange(1, epochs + 1)
  fig, ax = plt.subplots(1, 1, figsize=set_size(width))
  ax.plot(epoch_axis, random_map_vals, 'r', label='random mAP')
  ax.plot(epoch_axis, train_map, 'g', label='train mAP')
  ax.plot(epoch_axis, map_vals, 'b', label='val mAP')
  ax.set_title('Mean Average Precision')
  ax.set_xlabel('Epochs')
  ax.set_ylabel('mAP')
  ax.legend()
  fig.savefig(output_path + '/acc.pdf', format='pdf', bbox_inches='tight')
  plt.close(fig)

#### Loss

In [None]:
def plot_loss(train_loss_values, epochs, output_path):
  """Plot the per-epoch training loss to `<output_path>/loss.pdf`.

  Args:
    train_loss_values: Training loss, one value per epoch.
    epochs: Number of epochs; `train_loss_values` must have this many values.
    output_path: Output directory.
  """
  width = 460
  # Apply the styles first: plt.style.use() overwrites the rcParams it defines,
  # so it must run before the thesis font settings below. The old seaborn style
  # names were renamed 'seaborn-v0_8*' in Matplotlib 3.6; pick what exists.
  plt.style.use('seaborn-v0_8-bright' if 'seaborn-v0_8-bright' in plt.style.available else 'seaborn-bright')
  plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')
  tex_fonts = {
        # Use LaTeX to write all text
        "text.usetex": False,
        "font.family": "serif",
        # Use 10pt font in plots, to match 10pt font in document
        "axes.labelsize": 10,
        "font.size": 10,
        # Make the legend/label fonts a little smaller
        "legend.fontsize": 8,
        "xtick.labelsize": 8,
        "ytick.labelsize": 8
    }
  plt.rcParams.update(tex_fonts)

  epoch_axis = np.arange(1, epochs + 1)
  loss_values = np.array(train_loss_values)
  fig, ax = plt.subplots(1, 1, figsize=set_size(width))
  ax.plot(epoch_axis, loss_values, 'b', label='Training Loss', linestyle='dotted')
  ax.set_title('Training')
  ax.set_xlabel('Epochs')
  ax.set_ylabel('Loss')
  ax.legend()
  fig.savefig(output_path + '/loss.pdf', format='pdf', bbox_inches='tight')
  plt.close(fig)

#### Hyperparameters

In [None]:
def _save_table_pdf(names, values, name_header, title, pdf_path, width):
  """Render parallel name/value sequences as a two-column table figure saved to `pdf_path`."""
  names, values = replace_helper(names, values)
  cells = np.array([names, values], dtype=str).T
  fig, ax = plt.subplots(1, 1, figsize=set_size(width))
  ax.table(cellText=cells, colLabels=[name_header, 'Value'], loc='center', zorder=3, rowLoc='left', cellLoc='left')
  ax.set_title(title)
  ax.set_xticks([])
  ax.set_yticks([])
  fig.savefig(pdf_path, format='pdf', bbox_inches='tight')
  plt.close(fig)


def plot_table(setting, param, dml_param, output_path):
  """Save the settings and hyperparameter dicts as PDF tables in `output_path`.

  Writes `settings.pdf`, `params.pdf` and `dml_params.pdf`. Underscores in
  keys and values are shown as spaces.
  """
  width = 460
  # 'seaborn-bright' was renamed 'seaborn-v0_8-bright' in Matplotlib 3.6.
  plt.style.use('seaborn-v0_8-bright' if 'seaborn-v0_8-bright' in plt.style.available else 'seaborn-bright')
  tex_fonts = {
        # Use LaTeX to write all text
        "text.usetex": False,
        "font.family": "serif",
        # Use 10pt font in plots, to match 10pt font in document
        "axes.labelsize": 10,
        "font.size": 10,
        # Make the legend/label fonts a little smaller
        "legend.fontsize": 8,
        "xtick.labelsize": 8,
        "ytick.labelsize": 8
    }
  plt.rcParams.update(tex_fonts)

  _save_table_pdf(setting.keys(), setting.values(), 'Setting',
                  'Experiment Settings', output_path + '/settings.pdf', width)
  _save_table_pdf(param.keys(), param.values(), 'Hyperparameter',
                  'Hyperparameters', output_path + '/params.pdf', width)
  _save_table_pdf(dml_param.keys(), dml_param.values(), 'DML Hyperparameter',
                  'DML Hyperparameters', output_path + '/dml_params.pdf', width)

### Dataloading

In [None]:
def create_processed_info(path, debug=False):
  """Load or build the per-image metadata table for the PNG files in `path`.

  The table is cached inside `path` as `processed_info.csv`, or as
  `debug_processed_info.csv` when `debug` is true. If the cache exists it is
  read back; otherwise the table is built from the file names and written.

  File names are split as `<fragmentID>_<papyID>[_...].png`. The columns are,
  in this order: fnames_raw, fnames, fragmentID, papyID, posinfo,
  pixelCentimer. FAUPapyrusCollectionDataset reads columns 1 and 3 by position.
  """
  info_name = 'debug_processed_info.csv' if debug else 'processed_info.csv'
  info_path = join(path, info_name)
  if isfile(info_path):
    return pd.read_csv(info_path, index_col=0, dtype={'fnames':str,'papyID':int,'posinfo':str, 'pixelCentimer':float}, header=0)

  # Sorted so the cached table (and therefore dataset indices) is reproducible.
  fnames = sorted(
      f.split('.', 1)[0]
      for f in os.listdir(path)
      if f.endswith('.png') and isfile(join(path, f))
  )
  fnames_raw = [f.split('_', 1)[1] for f in fnames]
  processed_frame = pd.DataFrame({
      'fnames_raw': fnames_raw,
      'fnames': fnames,
      'fragmentID': [f.split('_', 1)[0] for f in fnames],
  })
  # Cast to int so a freshly built table has the same dtype as one read from the cache.
  processed_frame['papyID'] = processed_frame.fnames_raw.apply(lambda x: x.split('_', 1)[0]).astype(int)
  processed_frame['posinfo'] = processed_frame.fnames_raw.apply(lambda x: ''.join(filter(str.isalpha, x)))
  # NOTE(review): `retrive_size_by_fname` is not defined in the visible part of this notebook — confirm it exists.
  processed_frame['pixelCentimer'] = processed_frame.fnames_raw.apply(retrive_size_by_fname)
  processed_frame.to_csv(info_path)
  return processed_frame

### Logging

#### Thesis Settings

In [None]:
def set_size(width, fraction=1, subplots=(1, 1)):
    """Return a (width, height) figure size in inches for LaTeX documents.

    Matching the document's text width avoids rescaling the figure, and its
    fonts, when it is included in LaTeX.

    Parameters
    ----------
    width: float or string
            Document width in points, or one of the presets 'thesis' / 'beamer'.
    fraction: float, optional
            Fraction of the document width the figure should occupy.
    subplots: array-like, optional
            (rows, columns) of the subplot grid; height scales with rows / columns.

    Returns
    -------
    fig_dim: tuple
            (width, height) in inches, with height = width * golden ratio * rows / columns.
    """
    preset_widths_pt = {'thesis': 426.79135, 'beamer': 307.28987}
    width_pt = preset_widths_pt.get(width, width)

    points_per_inch = 72.27
    golden_ratio = (5 ** 0.5 - 1) / 2

    width_in = width_pt * fraction * (1 / points_per_inch)
    height_in = width_in * golden_ratio * (subplots[0] / subplots[1])
    return (width_in, height_in)

In [None]:
def replace_helper(some_list_1, some_list_2):
  """Stringify paired items and show underscores as spaces (for table display).

  The two iterables are zipped, so the output is truncated to the shorter one.
  Returns a tuple of two lists.
  """
  pairs = list(zip(some_list_1, some_list_2))
  firsts = [str(left).replace("_", " ") for left, _ in pairs]
  seconds = [str(right).replace("_", " ") for _, right in pairs]
  return firsts, seconds

#### Dir-Management

In [None]:
def create_output_dir(name, experiment_name, x=1):
  """Create and return a fresh output directory for this experiment.

  The candidate names are tried in turn, starting at index `x` and counting
  up: `<name><x>_iteration__<experiment_name>`, or `<name>_<experiment_name>`
  when the index is 0. The first candidate that does not exist yet is created.
  Parent directories are not created.
  """
  while True:
    iteration = '' if x == 0 else str(x) + '_iteration_'
    dir_name = (name + iteration + '_' + experiment_name).strip()
    try:
      os.mkdir(dir_name)
    except FileExistsError:
      # Already taken (possibly by a concurrent run): try the next index.
      x += 1
    else:
      return dir_name

#### Report-PDF

In [None]:
def create_logging(setting, param, dml_param, train_loss_values, map_vals, random_map_vals, train_map, epochs, output_dir, model):
  """Write all plots, a merged `report.pdf` and a copy of `log.txt` to `output_dir`.

  The report contains the following, in order:
  - loss and accuracy curves;
  - the hyperparameter and settings tables;
  - the gradient plot;
  - the current `log.txt`, rendered to `log.pdf` in the working directory.
  """
  plot_table(setting, param, dml_param, output_dir)
  gradient_visualization(model.named_parameters(), output_dir)
  plot_loss(train_loss_values, epochs, output_dir)
  plot_acc(map_vals, random_map_vals, train_map, epochs, output_dir)

  sections = [
      ('/loss.pdf', 'Loss'),
      ('/acc.pdf', 'Accuracy'),
      ('/params.pdf', 'Hyperparameters'),
      ('/dml_params.pdf', 'DML Hyperparameters'),
      ('/settings.pdf', 'Settings'),
      ('/gradients.pdf', 'Gradients'),
  ]

  # Render the text log to a PDF so it can be appended to the report.
  log_pdf = FPDF()
  log_pdf.add_page()
  log_pdf.set_font("Helvetica", size=6)
  with open("log.txt", "r") as log_file:
    for line in log_file:
      # cell() draws a single line; strip the trailing newline character.
      log_pdf.cell(200, 6, txt=line.rstrip('\n'), ln=1, align='L')
  log_pdf.output("log.pdf")

  merger = PdfFileMerger()
  try:
    for pdf_name, bookmark in sections:
      merger.append(output_dir + pdf_name, bookmark=bookmark)
    merger.append("log.pdf", bookmark='Log')
    merger.write(output_dir + "/report.pdf")
  finally:
    merger.close()

  copyfile('log.txt', output_dir + '/log.txt')

## Initialize

### Settings

In [None]:
# Fall back to the CPU when no GPU is attached so the notebook still runs.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Placeholder model; the "Model Architecture" cell replaces it with the configured one.
model = Net().to(device)
config = toml.load('./gdrive/MyDrive/mt/conf/conf.toml')
setting = config.get('settings')
param = config.get('params')
dml_param = config.get('dml_params')

### Logging

#### Create Dir

In [None]:
# Create a fresh, uniquely numbered output directory for this run.
output_dir = create_output_dir(
    name=setting['output'],
    experiment_name=setting['experiment_name'],
)

### Hyperparameters

In [None]:
# Core training hyperparameters from the `params` table of the config.
batch_size_train, batch_size_val, lr, num_epochs = (
    param[key] for key in ('batch_size_train', 'batch_size_val', 'lr', 'num_epochs')
)

#### Optimizer

In [None]:
# Build the optimizer named by param['optimizer'] over the current `model`.
if param['optimizer'] == 'Adam':
  optimizer = optim.Adam(model.parameters(), lr=lr)

elif param['optimizer'] == 'SGD':
  optimizer = optim.SGD(model.parameters(), lr=lr)

elif param['optimizer'] == 'AdamW':
  optimizer = optim.AdamW(model.parameters(), lr=lr)

else:
  logger.error(' Optimizer is not supported or you have not specified one.')
  raise ValueError(f"Unsupported optimizer: {param['optimizer']!r}")

#### Model Architecture

In [None]:
# Replace the placeholder model with the architecture named by param['archi'].
if param['archi'] == 'SimpleCNN':
  model = Net().to(device)

elif param['archi'] == 'efficientnetB0':
  # NOTE(review): unlike B7, `_fc` is kept here, so embeddings are that layer's output — confirm intended.
  model = EfficientNet.from_name('efficientnet-b0').to(device)

elif param['archi'] == 'efficientnetB7':
  model = EfficientNet.from_name('efficientnet-b7').to(device)
  # Replace the final layer with Identity so the preceding features serve as embeddings.
  model._fc = torch.nn.Identity()

elif param['archi'] == 'densenet201':
  model = torch.hub.load('pytorch/vision:v0.10.0', 'densenet201', pretrained=False).to(device)
  model.classifier = torch.nn.Identity()

elif param['archi'] == 'ResNet':
  # torchvision.models is not imported at the top of the notebook.
  from torchvision import models
  model = models.resnet18(pretrained=True).to(device)

else:
  logger.error(' Architecture is not supported or you have not specified one.')
  raise ValueError(f"Unsupported architecture: {param['archi']!r}")

# The optimizer cell runs before this one and therefore holds the previous
# model's parameters; rebuild it (same class and hyperparameters) for `model`.
optimizer = type(optimizer)(model.parameters(), **optimizer.defaults)

#### Scheduler

In [None]:
# Multiply the learning rate by `gamma` at each milestone epoch. The defaults
# (milestone 12, gamma 0.1) can be overridden via params.lr_milestones / params.lr_gamma.
scheduler = optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=param.get('lr_milestones', [12]),
    gamma=param.get('lr_gamma', 0.1),
)

### PyTorch-Metric-Learning Hyperparameters

#### Distance

In [None]:
# Distance (or similarity) shared by the loss and the miner.
if dml_param['distance'] == 'CosineSimilarity':
  distance = distances.CosineSimilarity()

elif dml_param['distance'] == 'LpDistance':
  # Euclidean distance (p=2, power=1) between L2-normalized embeddings.
  distance = distances.LpDistance(normalize_embeddings=True, p=2, power=1)

else:
  logger.error(' Distance is not supported or you have not specified one.')
  raise ValueError(f"Unsupported distance: {dml_param['distance']!r}")

#### Reducer

In [None]:
# Reducer that turns per-element losses into the scalar loss.
if dml_param['reducer'] == 'ThresholdReducer':
  reducer = reducers.ThresholdReducer(low=dml_param['ThresholdReducer_low'])

elif dml_param['reducer'] == 'AvgNonZeroReducer':
  reducer = reducers.AvgNonZeroReducer()

else:
  logger.error('Reducer is not supported or you have not specified one.')
  raise ValueError(f"Unsupported reducer: {dml_param['reducer']!r}")

#### Loss Function

In [None]:
# Metric-learning loss built on the configured distance and reducer.
if dml_param['loss_function'] == 'TripletMarginLoss':
  loss_func = losses.TripletMarginLoss(margin=dml_param['TripletMarginLoss_margin'], distance=distance, reducer=reducer)

elif dml_param['loss_function'] == 'ContrastiveLoss':
  # The margins depend on the type of `distance`. For a similarity (larger =
  # closer, e.g. CosineSimilarity), positives should exceed 1 and negatives
  # fall below 0. For a true distance (LpDistance) it is the reverse.
  # Config values, if present, take precedence.
  default_pos, default_neg = (1, 0) if distance.is_inverted else (0, 1)
  loss_func = losses.ContrastiveLoss(
      pos_margin=dml_param.get('pos_margin', default_pos),
      neg_margin=dml_param.get('neg_margin', default_neg),
      distance=distance,
      reducer=reducer,
  )

elif dml_param['loss_function'] == 'CircleLoss':
  loss_func = losses.CircleLoss(m=dml_param['m'], gamma=dml_param['gamma'], distance=distance, reducer=reducer)

else:
  logger.error('DML Loss is not supported or you have not specified one.')
  raise ValueError(f"Unsupported loss function: {dml_param['loss_function']!r}")

#### Mining Function

In [None]:
# Miner that selects the triplets fed to the loss each batch.
if dml_param['miner'] == 'TripletMarginMiner':
  mining_func = miners.TripletMarginMiner(
      margin=dml_param['TripletMarginMiner_margin'],
      distance=distance,
      type_of_triplets=dml_param['type_of_triplets'],
  )

else:
  logger.error('DML Miner is not supported or you have not specified one.')
  raise ValueError(f"Unsupported miner: {dml_param['miner']!r}")

#### Accuracy Calculator

In [None]:
# Retrieval metrics to compute (two names from the config) plus the configured `k`.
requested_metrics = (dml_param['metric_1'], dml_param['metric_2'])
accuracy_calculator = AccuracyCalculator(
    include=requested_metrics,
    k=dml_param['precision_at_1_k'],
)

### Transformations

#### Custom Transformation

In [None]:
class NCrop(object):
    """Extract ``n`` random patches of size ``output_size`` from an image array.

    Args:
        output_size (tuple): Patch size ``(height, width)``, passed to
            ``extract_patches_2d`` as ``patch_size``.
        n: Passed to ``extract_patches_2d`` as ``max_patches``.
        random_state: Optional seed forwarded to ``extract_patches_2d``. The
            default ``None`` draws different patches on every call.

    Returns:
        A tensor of shape ``(n_patches, channels, height, width)``. 2-D
        (grayscale) inputs get a single channel.

    NOTE: the patches are laid out as (N, C, H, W). An earlier version used
    ``transpose((0, 3, 2, 1))``, which swapped height and width; checkpoints
    trained that way saw transposed patches.
    """

    def __init__(self, output_size, n, random_state=None):
        self.output_size = output_size
        self.n = n
        self.random_state = random_state

    def __call__(self, sample):
        patches = extract_patches_2d(sample, self.output_size, max_patches=self.n,
                                     random_state=self.random_state)
        if patches.ndim == 3:
            # Grayscale image: (n, h, w) -> (n, h, w, 1).
            patches = patches[..., np.newaxis]
        # (n, h, w, c) -> (n, c, h, w), the layout PyTorch layers expect.
        patches = patches.transpose((0, 3, 1, 2))
        return torch.tensor(patches)

In [None]:
# Build train/val transforms selected by param['transform']. Each produces a
# stacked tensor of crops per image, normalized with the configured mean/std.
normalize = transforms.Normalize(
    (param['normalize_1'], param['normalize_2'], param['normalize_3']),
    (param['normalize_4'], param['normalize_5'], param['normalize_6']),
)
crop_size = (param['crop_1'], param['crop_2'])

if param['transform'] == "TenCrop":
  # The dataset loads numpy arrays (skimage.io), but TenCrop/PILToTensor operate
  # on PIL images, so convert first.
  ten_crop = transforms.Compose([
      transforms.ToPILImage(),
      transforms.TenCrop(crop_size),
      transforms.Lambda(lambda crops: torch.stack([transforms.PILToTensor()(crop) for crop in crops])),
      transforms.ConvertImageDtype(torch.float),
      normalize,
  ])
  train_transform = ten_crop
  val_transform = ten_crop

elif param['transform'] == "Ncrop":
  train_transform = transforms.Compose([
      NCrop(crop_size, n=param['max_patches_train']),
      transforms.ConvertImageDtype(torch.float),
      normalize,
  ])
  # Validation can use its own patch count; defaults to the training one.
  val_transform = transforms.Compose([
      NCrop(crop_size, n=param.get('max_patches_val', param['max_patches_train'])),
      transforms.ConvertImageDtype(torch.float),
      normalize,
  ])

In [None]:
# NOTE(review): this cell always redefines the NCrop pipelines, overriding
# whatever the previous cell selected from param['transform'] — confirm the
# override is intended. Both train and val use param['max_patches_train'].
ncrop_size = (param['crop_1'], param['crop_2'])
ncrop_normalize = transforms.Normalize(
    (param['normalize_1'], param['normalize_2'], param['normalize_3']),
    (param['normalize_4'], param['normalize_5'], param['normalize_6']),
)

train_transform = transforms.Compose([
    NCrop(ncrop_size, n=param['max_patches_train']),
    transforms.ConvertImageDtype(torch.float),
    ncrop_normalize,
])

val_transform = transforms.Compose([
    NCrop(ncrop_size, n=param['max_patches_train']),
    transforms.ConvertImageDtype(torch.float),
    ncrop_normalize,
])


### Data

#### Helpers

In [None]:
# Per-split image metadata tables (cached as CSV inside each split directory).
processed_frame_train, processed_frame_val = (
    create_processed_info(setting[split_key]) for split_key in ('path_train', 'path_val')
)

#### Dataset

In [None]:
# Training and validation datasets, each with its own transform pipeline.
train_dataset = FAUPapyrusCollectionDataset(
    root_dir=setting['path_train'], processed_frame=processed_frame_train, transform=train_transform)
val_dataset = FAUPapyrusCollectionDataset(
    root_dir=setting['path_val'], processed_frame=processed_frame_val, transform=val_transform)

#### Data Loader

In [None]:
# Training loader: shuffling comes from the config; incomplete last batches are dropped.
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size_train, shuffle=param["shuffle"], drop_last=True, num_workers=4)
# Evaluation loader: fixed order and no samples dropped, independent of the shuffle flag.
test_loader = torch.utils.data.DataLoader(
    val_dataset, batch_size=batch_size_val, shuffle=False, drop_last=False, num_workers=4)

#### Result Lists

In [None]:
# Per-epoch histories appended by the training loop (val_loss_vals is not
# referenced in the visible notebook).
(loss_vals, val_loss_vals, map_vals,
 random_map_vals, train_map_vals) = ([] for _ in range(5))

## Training

### Log Hyperparameters

In [None]:
# Record the run configuration via `logger` (console + log.txt). Loss-specific
# keys use .get() so configs for other losses/miners log None instead of failing.
hyperparameter_log = [
    ('Debug',               setting["debug"]),
    ('Loss Function',       dml_param["loss_function"]),
    ('Margin Miner Margin', dml_param.get("TripletMarginMiner_margin")),
    ('Triplet Margin Loss', dml_param.get("TripletMarginLoss_margin")),
    ('Type of Triplets',    dml_param.get("type_of_triplets")),
    ('Miner',               dml_param["miner"]),
    ('Reducer',             dml_param["reducer"]),
    ('Archi',               param["archi"]),
    ('Epochs',              param["num_epochs"]),
    ('Batch Size Train',    param["batch_size_train"]),
    ('Batch Size Val',      param["batch_size_val"]),
    ('Optimizer',           param["optimizer"]),
    ('Learning Rate',       param["lr"]),
    ('Shuffle',             param["shuffle"]),
]
for label, value in hyperparameter_log:
  # Left-align "<label>:" in a 22-character column.
  logger.info(f'{label + ":":<22}{value}')

Debug:                False
Loos Function:        TripletMarginLoss
Margin Miner Margin:  0.2
Triplet Margin Loss:  0.2
Type of Tribles:      semihard
Miner:                TripletMarginMiner
Reducer:              AvgNonZeroReducer
Archi:                efficientnetB7
Epochs:               20
Batch Size Train:     64
Batch Size Val:       1
Optimizer:            SGD
Learning Rate:        0.0001
Shuffle:              True


### Train

In [None]:
if setting["training"]:
  # Best validation mAP so far; model.pt is rewritten whenever it is matched or beaten.
  best_map = 0

  for epoch in range(1, num_epochs + 1):
      ############### Training ###############
      train_loss, train_map = train(
          model,
          loss_func,
          mining_func,
          device,
          train_loader,
          optimizer,
          train_dataset,
          epoch,
          accuracy_calculator,
          scheduler,
          accumulation_steps=param["accumulation_steps"],
      )

      ############### Validation ###############
      # NOTE(review): the validation set serves as both query and reference set.
      val_map, random_map = val(val_dataset, val_dataset, model, accuracy_calculator)

      ############### Fill Lists ###############
      loss_vals.append(train_loss)
      map_vals.append(val_map)
      random_map_vals.append(random_map)
      train_map_vals.append(train_map)

      ############### Checkpoint ###############
      # Compare against the best mAP so far (not just the previous epoch) so
      # model.pt always holds the best model seen.
      if val_map >= best_map:
        best_map = val_map
        torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': train_loss,
                    }, output_dir + "/model.pt")

      ############### Logging ###############
      create_logging(setting, param, dml_param, loss_vals, map_vals, random_map_vals, train_map_vals, epoch, output_dir, model)

RuntimeError: ignored

## Inference

### Imports

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from torchvision import datasets, transforms
import cv2
from pytorch_metric_learning.distances import CosineSimilarity
from pytorch_metric_learning.utils import common_functions as c_f
from pytorch_metric_learning.utils.inference import InferenceModel, MatchFinder
import json
import skimage

### Helpers

In [None]:
def print_decision(is_match):
    """Print whether a pair was judged to belong to the same class.

    Args:
        is_match: truthy when the two inputs were considered a match.
    """
    print("Same class" if is_match else "Different class")


# Per-channel normalisation statistics used by `transform` and undone by
# `inv_normalize`. NOTE(review): `mean` looks dataset-specific while `std`
# equals the common ImageNet values; confirm this mix is intended.
mean = [0.6143, 0.6884, 0.7665]
std = [0.229, 0.224, 0.225]

# Normalize computes (x - m) / s. Applying Normalize(-m / s, 1 / s) to that
# result yields y * s + m == x, i.e. it undoes the normalisation for display.
inv_normalize = transforms.Normalize(
    mean=[-channel_mean / channel_std for channel_mean, channel_std in zip(mean, std)],
    std=[1 / channel_std for channel_std in std],
)
import numpy as np
import cv2
import json
from matplotlib import pyplot as plt


def imshow(img, figsize=(21, 9), boarder=None, get_img = False):
    """Undo the normalisation of an image tensor and display it, optionally framed.

    Args:
        img: normalised image tensor. It is converted from channel-first to
            channel-last via transpose (1, 2, 0), so it presumably has shape
            (C, H, W) with C matching ``inv_normalize``. TODO confirm for all callers.
        figsize: matplotlib figure size used when displaying.
        boarder: ``'green'`` or ``'red'`` draws a coloured frame; any other value
            draws none.
        get_img: if True, return the (H, W, C) float array instead of plotting.

    Returns:
        The (optionally framed) float array when ``get_img`` is True, else None.
    """
    # The de-normalised image is a float array that is scaled by 255 only for
    # display, so frame colours must be in [0, 1] as well. 255-valued colours
    # overflowed when the array was cast to uint8.
    border_colors = {'green': [0, 1, 0], 'red': [1, 0, 0]}

    npimg = inv_normalize(img).numpy()
    transposed = np.transpose(npimg, (1, 2, 0))  # channel-first -> channel-last

    if boarder in border_colors:
        # Frame thickness: 2.5% of the image width, applied on all four sides.
        pad = int(transposed.shape[1] * 0.025)
        boarderized = cv2.copyMakeBorder(
            transposed, pad, pad, pad, pad, cv2.BORDER_CONSTANT,
            value=border_colors[boarder],
        )
    else:
        boarderized = transposed

    if get_img:
        return boarderized

    plt.figure(figsize=figsize)
    # Clip before the uint8 cast: values slightly outside [0, 1] after
    # de-normalisation would otherwise wrap around.
    plt.imshow((np.clip(boarderized, 0, 1) * 255).astype(np.uint8))
    plt.show()

### Transform

In [None]:
# Convert PIL images to tensors, then normalise them with the `mean`/`std`
# statistics defined in the helpers cell.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

### Dataset

In [None]:
class FAUPapyrusCollectionInferenceDataset(torch.utils.data.Dataset):
    """FAUPapyrusCollection dataset for inference.

    Each item is ``(image, papyID)``. The image path is ``root_dir`` joined with
    column 1 of ``processed_frame`` plus a ``.png`` suffix, and the label is
    read from column 3.

    Args:
        root_dir: directory containing the image files.
        processed_frame: pandas DataFrame with one row per fragment and a
            ``papyID`` column. NOTE(review): ``__getitem__`` reads the label
            positionally from column 3, presumably ``papyID``; confirm.
        transform: optional callable applied to the loaded PIL image.
        max_img_size: tensor images whose height or width exceed this are
            centre-cropped so that neither side exceeds it. Defaults to 2048.
    """

    def __init__(self, root_dir, processed_frame, transform=None, max_img_size=2048):
        self.root_dir = root_dir
        self.processed_frame = processed_frame
        self.transform = transform
        self.max_img_size = max_img_size
        self.targets = processed_frame["papyID"].unique()

    def __len__(self):
        return len(self.processed_frame)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = os.path.join(self.root_dir, self.processed_frame.iloc[idx, 1]) + '.png'
        image = PIL.Image.open(img_name)

        if self.transform:
            image = self.transform(image)

        # Size capping needs a tensor with a ``.shape``. Without a transform the
        # item is still a PIL image, which is returned uncropped.
        if torch.is_tensor(image):
            height, width = image.shape[-2], image.shape[-1]
            if height > self.max_img_size or width > self.max_img_size:
                # Crop each side to at most max_img_size. A scalar size would
                # make CenterCrop zero-pad a side that is already smaller.
                image = transforms.CenterCrop(
                    (min(height, self.max_img_size), min(width, self.max_img_size))
                )(image)

        papyID = self.processed_frame.iloc[idx, 3]
        return image, papyID

In [None]:
# Legacy subclass: no longer needed and scheduled for deletion. It is not
# instantiated in this part of the notebook (InferenceModel is used directly).
class MyInferenceModel(InferenceModel):
    """InferenceModel whose dataset path unpacks ``(input, label)`` batches."""

    def get_embeddings_from_tensor_or_dataset(self, inputs, batch_size):
        """Embed ``inputs`` in chunks of ``batch_size`` and concatenate the results.

        Tensors and lists are sliced directly; a Dataset is wrapped in a
        DataLoader and the label half of each batch is ignored.

        Raises:
            TypeError: if ``inputs`` is neither a tensor, list nor Dataset.
        """
        inputs = self.process_if_list(inputs)
        if isinstance(inputs, (torch.Tensor, list)):
            chunks = [
                self.get_embeddings(inputs[start : start + batch_size])
                for start in range(0, len(inputs), batch_size)
            ]
        elif isinstance(inputs, torch.utils.data.Dataset):
            loader = torch.utils.data.DataLoader(inputs, batch_size=batch_size)
            chunks = [self.get_embeddings(batch) for batch, _ in loader]
        else:
            raise TypeError(f"Indexing {type(inputs)} is not supported.")
        return torch.cat(chunks)

In [None]:
# Inference runs on the validation split.
dataset = FAUPapyrusCollectionInferenceDataset(
    root_dir=setting['path_val'],
    processed_frame=processed_frame_val,
    transform=transform,
)

### Apply DML-Helper Functions

In [None]:
def get_labels_to_indices(dataset):
  """Group dataset positions by label.

  Args:
    dataset: iterable of ``(sample, label)`` pairs. Every item is fetched, so
      for an image dataset every image is loaded once.

  Returns:
    dict mapping each label to the list of positions where it occurs, in
    first-occurrence order.
  """
  groups = {}
  for position, (_, label) in enumerate(dataset):
    groups.setdefault(label, []).append(position)
  return groups

In [None]:
# Map each label (papyID) to the positions of its fragments in `dataset`.
# This iterates the whole dataset, so every image is loaded once.
labels_to_indices = get_labels_to_indices(dataset)

### Load Checkpoint

In [None]:
# Rebuild the EfficientNet-B7 trunk with an identity head (the network outputs
# embeddings, not class logits) and restore the checkpoint written by the
# training cell.
model = EfficientNet.from_name('efficientnet-b7').to(device)
model._fc = torch.nn.Identity()
# map_location allows loading a checkpoint saved on a different device.
checkpoint = torch.load(output_dir + "/model.pt", map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
# NOTE(review): the optimizer is not needed for inference; its state is
# restored only for parity with the saved checkpoint.
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
# Switch to evaluation mode (affects layers such as BatchNorm/Dropout). The
# trailing semicolon suppresses the large model repr in the cell output.
model.eval();

### Prepare DML Methods

In [None]:
# Embeddings are compared with cosine similarity. MATCH_THRESHOLD is the
# cut-off MatchFinder uses to decide whether two embeddings "match".
# NOTE(review): for a similarity measure the comparison direction is handled
# by MatchFinder; confirm 0.2 is the intended cut-off.
MATCH_THRESHOLD = 0.2
match_finder = MatchFinder(distance=CosineSimilarity(), threshold=MATCH_THRESHOLD)
inference_model = InferenceModel(model, match_finder=match_finder)

### PapyIDs to Index

### Prepare KNN for Inference on Embeddings

In [None]:
# create faiss index
# Embed every item of `dataset` with the inference model and build the
# nearest-neighbour index that get_nearest_neighbors queries later.
# batch_size=1: presumably because images can differ in size and cannot be
# stacked into larger batches. TODO confirm.
inference_model.train_knn(dataset, batch_size=1)

### Run Inference

In [None]:
# Retrieve the k nearest neighbours for a few query fragments. Keep the
# queries with the lowest and the highest retrieval precision, i.e. the
# fraction of the k neighbours that share the query's papyID.
k = 100

# Start outside [0, 1] so the first query always initialises both extremes.
# With 1 / 0 as starting values, a run in which every query scored exactly 1
# (or exactly 0) left the *_of_lowest_acc (or *_of_highest_acc) variables
# undefined, and the plotting cell failed with a NameError.
lowest_acc = float('inf')
highest_acc = float('-inf')

# Debug limiter: only the first `max_counter` fragments are evaluated.
temp_counter = 0
max_counter = 2


for papyID in labels_to_indices.keys():
  if temp_counter >= max_counter:
    break

  for fragment in labels_to_indices[papyID]:
    if temp_counter >= max_counter:
      break
    temp_counter = temp_counter + 1
    img, org_label = dataset[fragment]
    img = img.unsqueeze(0)  # add a batch dimension for the single query
    # NOTE(review): the query is itself part of the kNN index, so it is
    # presumably returned as one of its own neighbours and counted as a hit.
    distances, indices = inference_model.get_nearest_neighbors(img, k=k)

    # Load each neighbour exactly once. The old extra list comprehension
    # loaded every neighbour a second time and was never used.
    neighbours = []
    labels = []
    for i in indices.cpu()[0]:
      neighbour, label = dataset[i]
      neighbours.append(neighbour)
      labels.append(label)

    # Precision@k: divide by k rather than a hard-coded 100.
    occurrences = labels.count(org_label)
    acc = occurrences / k

    if acc < lowest_acc:
      lowest_acc = acc
      print(f'Found new lowest example with acc {acc}')
      input_img_of_lowest_acc = img
      input_label_of_lowest_acc = org_label
      input_index_of_lowest_acc = fragment
      detected_neighbours_of_lowest_acc = neighbours
      detected_labels_of_lowest_acc = labels
      detected_distances_of_lowest_acc = distances

    if acc > highest_acc:
      highest_acc = acc
      print(f'Found new highest example with acc {acc}')
      input_img_of_highest_acc = img
      input_label_of_highest_acc = org_label
      input_index_of_highest_acc = fragment
      detected_neighbours_of_highest_acc = neighbours
      detected_labels_of_highest_acc = labels
      detected_distances_of_highest_acc = distances

In [None]:
  def get_inference_plot(neighbours, labels, distances, org_label, img, k, lowest, n_crops=10, random_state=None):
    """Show the query image, then a grid of random crops of its neighbours.

    Each of the ``k`` rows shows ``n_crops`` random 32x32 crops of one
    neighbour. A crop is framed green if the neighbour's label equals
    ``org_label`` and red otherwise; the row's y-label shows the label and the
    distance. The grid is saved as a PDF in ``output_dir`` and displayed.

    Args:
        neighbours: neighbour image tensors, presumably (C, H, W). TODO confirm.
        labels: label of each neighbour, in the same order as ``neighbours``.
        distances: per-neighbour distances, indexable by row.
        org_label: label of the query image.
        img: query image tensor, shown via ``imshow``.
        k: number of grid rows; should equal ``len(neighbours)``.
        lowest: True for the lowest-accuracy query (selects message and file name).
        n_crops: crops per neighbour, i.e. grid columns. Defaults to 10.
        random_state: forwarded to ``extract_patches_2d`` for reproducible crops.
    """
    if lowest:
      print(f"query image for lowest acc: {org_label}")
    else:
      print(f"query image for highest acc: {org_label}")

    imshow(torchvision.utils.make_grid(img))

    n_rows = k
    n_cols = n_crops
    # squeeze=False keeps `axs` two-dimensional even when k == 1.
    fig, axs = plt.subplots(n_rows, n_cols, squeeze=False)
    # 3.2 inches per row (320 inches for the default k=100).
    fig.set_figheight(3.2 * n_rows)
    fig.set_figwidth(30)
    fig.suptitle(f'Neighbour Crops of {org_label}')

    for i, neighbour in enumerate(neighbours):
      # extract_patches_2d expects channels last. Reverse all axes, as the old
      # `neighbour.T` did, without the deprecated N-D `.T`.
      reversed_axes = neighbour.permute(*reversed(range(neighbour.ndim))).numpy()
      crops = extract_patches_2d(image=reversed_axes, patch_size=(32, 32), max_patches=n_cols, random_state=random_state)
      crops = torch.tensor(crops.transpose((0, 3, 2, 1)))  # back to channel-first

      # Green frame for a neighbour with the query's label, red otherwise.
      frame_color = [0, 1, 0] if org_label == labels[i] else [1, 0, 0]
      distance = distances[i].cpu().numpy().round(2)
      axs[i, 0].set_ylabel(f"label: {labels[i]} \n distance: {distance}")

      for j in range(n_cols):
        crop = np.transpose(inv_normalize(crops[j]).numpy(), (1, 2, 0))
        # Frame thickness: 5% of the crop width.
        pad = int(crop.shape[1] * 0.05)
        framed = cv2.copyMakeBorder(crop, pad, pad, pad, pad, cv2.BORDER_CONSTANT, value=frame_color)
        axs[i, j].imshow(framed, aspect='auto')

    plt.tight_layout()
    if lowest:
      plt.savefig(output_dir + "/results_for_lowest_acc.pdf", bbox_inches='tight', dpi=100)
    else:
      plt.savefig(output_dir + "/results_for_highest_acc.pdf", bbox_inches='tight', dpi=100)
    plt.show()
    # Release the (very large) figure once it has been saved and shown.
    plt.close(fig)

In [None]:
# Plot the best and worst retrieval examples. The row count is the number of
# neighbours actually retrieved, so it stays in sync if `k` is changed.
get_inference_plot(detected_neighbours_of_highest_acc, detected_labels_of_highest_acc, detected_distances_of_highest_acc[0], input_label_of_highest_acc, input_img_of_highest_acc, k=len(detected_neighbours_of_highest_acc), lowest=False)
get_inference_plot(detected_neighbours_of_lowest_acc, detected_labels_of_lowest_acc, detected_distances_of_lowest_acc[0], input_label_of_lowest_acc, input_img_of_lowest_acc, k=len(detected_neighbours_of_lowest_acc), lowest=True)