## Prepare Notebook

### Delete Outdated Log Files

In [1]:
!rm "log.txt"

rm: cannot remove 'log.txt': No such file or directory


### Install Dependencies

In [2]:
!pip install pytorch-metric-learning
!pip install faiss-gpu
!pip install PyPDF2
!pip install FPDF
!pip install efficientnet_pytorch
!pip install optuna
#!apt install texlive-fonts-recommended texlive-fonts-extra cm-super dvipng

Collecting pytorch-metric-learning
  Downloading pytorch_metric_learning-1.1.0-py3-none-any.whl (106 kB)
[?25l[K     |███                             | 10 kB 34.0 MB/s eta 0:00:01[K     |██████▏                         | 20 kB 25.9 MB/s eta 0:00:01[K     |█████████▏                      | 30 kB 12.7 MB/s eta 0:00:01[K     |████████████▎                   | 40 kB 10.2 MB/s eta 0:00:01[K     |███████████████▍                | 51 kB 7.7 MB/s eta 0:00:01[K     |██████████████████▍             | 61 kB 7.7 MB/s eta 0:00:01[K     |█████████████████████▌          | 71 kB 7.9 MB/s eta 0:00:01[K     |████████████████████████▋       | 81 kB 8.8 MB/s eta 0:00:01[K     |███████████████████████████▋    | 92 kB 8.2 MB/s eta 0:00:01[K     |██████████████████████████████▊ | 102 kB 7.6 MB/s eta 0:00:01[K     |████████████████████████████████| 106 kB 7.6 MB/s 
Installing collected packages: pytorch-metric-learning
Successfully installed pytorch-metric-learning-1.1.0
Collecting fais

### Import Dependencies

In [3]:
import logging
import os
from os.path import isfile, join
from shutil import copyfile

import matplotlib
import matplotlib.pyplot as plt
from matplotlib import rc
import numpy as np
import optuna
import pandas as pd
import PIL
import toml
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from efficientnet_pytorch import EfficientNet
from fpdf import FPDF
from google.colab import drive
from PyPDF2 import PdfFileMerger
from pytorch_metric_learning import distances, losses, miners, reducers, testers
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator
from torchvision import datasets, transforms
from torchvision import models
# Render figure text with LaTeX by default (the plot helpers below switch this
# off again through their own rcParams update).
rc('text', usetex=True)
# 'text.latex.preamble' has to be one string: list values were deprecated in
# Matplotlib 3.3 and are not supported by later releases.
matplotlib.rcParams['text.latex.preamble'] = r'\usepackage{amsmath}'

### Mount Google Drive

Expected folder structure on Google Drive:

---



---


* conf
* data
  * 04_train
  * 05_val
  * 06_test
* out
  * experimentName_Iteration

In [4]:
# Attach Google Drive to the Colab file system.
mount_point = '/content/gdrive'
drive.mount(mount_point)

Mounted at /content/gdrive


### Instantiate Logger

In [5]:
# Notebook-wide logger: messages go to the cell output and to log.txt.
logger = logging.getLogger('log')
logger.setLevel(logging.DEBUG)

# logging.getLogger returns the same object on every call, so re-running this
# cell would otherwise stack another pair of handlers and duplicate every
# message. Remove (and close) the handlers of a previous run first.
for stale_handler in list(logger.handlers):
    logger.removeHandler(stale_handler)
    stale_handler.close()

# plain messages without timestamp or level prefix
formatter = logging.Formatter('%(message)s')

# file handler: INFO and above are appended to log.txt
fh = logging.FileHandler('log.txt')
fh.setLevel(logging.INFO)
fh.setFormatter(formatter)

# console handler: INFO and above are shown in the cell output
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
ch.setFormatter(formatter)

# add the handlers to logger
logger.addHandler(ch)
logger.addHandler(fh)
# do not hand records to the root logger (avoids a second copy in the output)
logger.propagate = False

## PyTorch Dataset Class for FAU Papyri Data

In [6]:
class FAUPapyrusCollectionDataset(torch.utils.data.Dataset):
    """Map-style dataset that serves fragment images together with a label.

    Samples are addressed by row position in ``processed_frame``: column 1
    holds the image file name without extension and column 3 the value that
    is returned next to the image.
    """

    def __init__(self, root_dir, processed_frame, transform=None):
        """Keep the image directory, the index table and an optional transform."""
        self.transform = transform
        self.processed_frame = processed_frame
        self.root_dir = root_dir

    def __len__(self):
        """Number of rows in the index table."""
        return len(self.processed_frame)

    def __getitem__(self, idx):
        """Return ``(image, papyID)`` for the row at position ``idx``."""
        row = idx.tolist() if torch.is_tensor(idx) else idx

        # <root_dir>/<file name from column 1>.png
        img_path = os.path.join(self.root_dir, self.processed_frame.iloc[row, 1]) + '.png'
        image = PIL.Image.open(img_path)
        if self.transform:
            image = self.transform(image)

        return image, self.processed_frame.iloc[row, 3]

## PyTorch Network Architectures

### Simple CNN

In [7]:
class Net(nn.Module):
    """Small CNN: two 3x3 convolutions, max-pooling, dropout and a linear
    layer that outputs a 128-dimensional vector per image."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
        self.dropout1 = nn.Dropout2d(p=0.25)
        self.dropout2 = nn.Dropout2d(p=0.5)  # registered, but forward() never applies it
        self.fc1 = nn.Linear(in_features=12544, out_features=128)

    def forward(self, x):
        """conv -> relu -> conv -> relu -> 2x2 max-pool -> dropout -> flatten -> linear."""
        features = F.relu(self.conv2(F.relu(self.conv1(x))))
        pooled = self.dropout1(F.max_pool2d(features, 2))
        return self.fc1(torch.flatten(pooled, 1))

## PyTorch NN Functions

### Training

In [8]:
from numpy.core.fromnumeric import mean
"""
def train(model, loss_func, mining_func, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, labels) in enumerate(train_loader):
      data, labels = data.to(device), labels.to(device)
      optimizer.zero_grad()
      embeddings = model(data)
      indices_tuple = mining_func(embeddings, labels)
      loss = loss_func(embeddings, labels, indices_tuple)
      loss.backward()
      optimizer.step()        
    print(f"Epoch {epoch} After Iteration {batch_idx}: Loss = {loss}")
    return loss
"""
def train(model, loss_func, mining_func, device, train_loader, optimizer, epoch):
    """Run one training epoch on multi-crop batches and return the mean batch loss.

    Every batch is unpacked as ``(bs, ncrops, c, h, w)``. The crops are folded
    into the batch dimension for the forward pass and the embeddings are
    averaged per image before mining and loss computation.

    :param model: network that produces one embedding per input crop
    :param loss_func: callable ``(embeddings, labels, indices_tuple) -> loss``
    :param mining_func: callable ``(embeddings, labels) -> indices_tuple``
    :param device: device the batches are moved to
    :param train_loader: iterable of ``(input_imgs, labels)`` batches
    :param optimizer: optimizer that updates ``model``
    :param epoch: epoch number, only used in the progress message
    :return: 0-dim tensor holding the mean loss over all batches (detached)
    :raises ValueError: if ``train_loader`` yields no batch
    """
    model.train()
    batch_loss_values = []
    for input_imgs, labels in train_loader:
        labels = labels.to(device)
        input_imgs = input_imgs.to(device)
        bs, ncrops, c, h, w = input_imgs.size()
        optimizer.zero_grad()

        # fuse batch size and ncrops for the forward pass, then average the crops
        embeddings = model(input_imgs.view(-1, c, h, w))
        embeddings_avg = embeddings.view(bs, ncrops, -1).mean(1)

        indices_tuple = mining_func(embeddings_avg, labels)
        loss = loss_func(embeddings_avg, labels, indices_tuple)
        loss.backward()
        optimizer.step()
        # detach: keeping the attached tensor would keep the autograd graph of
        # every batch alive until the end of the epoch
        batch_loss_values.append(loss.detach())

    if not batch_loss_values:
        raise ValueError('train_loader yielded no batches; cannot compute a mean loss.')

    mean_loss = torch.mean(torch.stack(batch_loss_values))
    print(f"Epoch {epoch} averg loss from {len(batch_loss_values)} batches: {mean_loss}")
    return mean_loss

### Validation

In [9]:
def val(train_set, test_set, model, accuracy_calculator):
  """Compute retrieval mAP of ``test_set`` against ``train_set`` plus a shuffled-label baseline.

  ``test_set`` provides the queries and ``train_set`` the references. When
  both arguments are the same object the embeddings are computed once and the
  accuracy calculator is told that queries and references share a source, so
  that a sample is not counted as its own nearest neighbour.

  :param train_set: dataset used as reference set
  :param test_set: dataset used as query set
  :param model: network passed to ``get_all_embeddings``
  :param accuracy_calculator: object whose ``get_accuracy`` result contains
      the key ``"mean_average_precision"``
  :return: tuple ``(mAP, mAP with randomly permuted query labels)``
  """
  same_source = train_set is test_set

  train_embeddings, train_labels = get_all_embeddings(train_set, model)
  train_labels = train_labels.squeeze(1)

  if same_source:
    test_embeddings, test_labels = train_embeddings, train_labels
  else:
    test_embeddings, test_labels = get_all_embeddings(test_set, model)
    test_labels = test_labels.squeeze(1)

  print("Computing accuracy")

  accuracies = accuracy_calculator.get_accuracy(
      test_embeddings, train_embeddings, test_labels, train_labels, same_source
  )

  # baseline: same embeddings, but the query labels are randomly permuted
  idx = torch.randperm(test_labels.nelement())
  shuffled_labels = test_labels.view(-1)[idx].view(test_labels.size())

  random_accuracies = accuracy_calculator.get_accuracy(
      test_embeddings, train_embeddings, shuffled_labels, train_labels, same_source
  )

  print("Test set accuracy (mAP) = {}".format(accuracies["mean_average_precision"]))
  return accuracies["mean_average_precision"], random_accuracies["mean_average_precision"]
  

## Python-Helper-Functions

### Deep Metric Learning

In [10]:
from pytorch_metric_learning.testers import GlobalEmbeddingSpaceTester
from pytorch_metric_learning.utils import common_functions as c_f

class CustomTester(GlobalEmbeddingSpaceTester):
    """Tester that embeds multi-crop batches and averages the embeddings of the crops."""

    def get_embeddings_for_eval(self, trunk_model, embedder_model, input_imgs):
        """Embed a ``(bs, ncrops, c, h, w)`` batch and return one averaged embedding per image.

        :param trunk_model: model applied to the crops first
        :param embedder_model: model applied to the output of ``trunk_model``
        :param input_imgs: batch that unpacks into ``(bs, ncrops, c, h, w)``
        :return: tensor with ``bs`` rows, the mean over the crops of each image
        """
        input_imgs = c_f.to_device(
            input_imgs, device=self.data_device, dtype=self.dtype
        )
        # from https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.FiveCrop
        bs, ncrops, c, h, w = input_imgs.size()
        result = embedder_model(trunk_model(input_imgs.view(-1, c, h, w))) # fuse batch size and ncrops
        result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops
        return result_avg


def get_all_embeddings(dataset, model, collate_fn=None, eval=True):
    """Embed every sample of ``dataset`` with ``model`` through a ``CustomTester``.

    :param dataset: dataset whose samples are embedded
    :param model: model handed to the tester as trunk model
    :param collate_fn: forwarded to the tester's ``get_all_embeddings``
    :param eval: forwarded to the tester's ``get_all_embeddings``
    :return: whatever the tester returns (unpacked by the caller as
        ``(embeddings, labels)``)
    """
    tester = CustomTester()
    # forward both options; they used to be accepted here but silently dropped
    return tester.get_all_embeddings(dataset, model, collate_fn=collate_fn, eval=eval)

### Visualization

#### Gradients

In [11]:
def gradient_visualization(parameters, output_path):
    """
    Plot the mean absolute gradient of every non-bias parameter and save the
    figure as ``gradients.pdf`` in ``output_path``.

    Parameters without a gradient (``grad is None``) are skipped.

    :param parameters: ``(name, parameter)`` pairs, e.g. ``model.named_parameters()``
    :type parameters: iterator
    :param output_path: path to the folder the PDF is written to
    :type output_path: str
    """
    tex_fonts = {
        # Plain matplotlib text rendering (no LaTeX)
        "text.usetex": False,
        "font.family": "serif",
        # Use 10pt font in plots, to match 10pt font in document
        "axes.labelsize": 10,
        "font.size": 10,
        # Make the legend/label fonts a little smaller
        "legend.fontsize": 8,
        "xtick.labelsize": 8,
        "ytick.labelsize": 8,
        "legend.loc": 'lower left',
    }
    plt.rcParams.update(tex_fonts)
    ave_grads = []
    layers = []
    for n, p in parameters:
        # a parameter that took no part in the last backward pass has grad None
        if p.requires_grad and "bias" not in n and p.grad is not None:
            layers.append(n)
            # .item() gives a Python float; tensors on the GPU cannot be
            # converted to numpy by matplotlib
            ave_grads.append(p.grad.abs().mean().item())
    plt.plot(ave_grads, alpha=0.3, color="b")
    plt.hlines(0, 0, len(ave_grads) + 1, linewidth=1, color="k")
    plt.xticks(range(0, len(ave_grads), 1), layers, rotation="vertical")
    plt.xlim(left=0, right=len(ave_grads))
    plt.xlabel("Layers")
    plt.ylabel("average gradient")
    plt.title("Gradient Visualization")
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(output_path + "/gradients.pdf")
    plt.close()

#### Accuracy

In [12]:
def plot_acc(map_vals, random_map_vals, epochs, output_path):
  """Plot the 'mAP' and 'random mAP' curves over the epochs and save them as ``acc.pdf``.

  :param map_vals: one mAP value per epoch
  :param random_map_vals: one 'random mAP' value per epoch
  :param epochs: number of epochs; the x axis runs from 1 to ``epochs``
  :param output_path: folder the PDF is written to
  """
  plt.style.use('seaborn-bright')
  plt.rcParams.update({
      # no LaTeX text rendering
      "text.usetex": False,
      "font.family": "serif",
      # 10pt to match the document font
      "axes.labelsize": 10,
      "font.size": 10,
      # slightly smaller legend and tick labels
      "legend.fontsize": 8,
      "xtick.labelsize": 8,
      "ytick.labelsize": 8,
  })

  epoch_axis = np.arange(1, epochs + 1)
  figure, axis = plt.subplots(1, 1, figsize=set_size(460))
  curves = (
      (random_map_vals, 'r', 'random mAP'),
      (map_vals, 'b', 'mAP'),
  )
  for series, fmt, legend_label in curves:
    axis.plot(epoch_axis, series, fmt, label=legend_label)
  axis.set_title('Validation Accuracy')
  axis.set_xlabel('Epochs')
  axis.set_ylabel('Accuracy')
  axis.legend()
  figure.savefig(output_path + '/acc.pdf', format='pdf', bbox_inches='tight')
  plt.close()

#### Loss

In [13]:
def plot_loss(train_loss_values, epochs, output_path):
  """Plot the training loss over the epochs and save the figure as ``loss.pdf``.

  :param train_loss_values: one loss value per epoch; plain numbers or
      0-dim tensors (as returned by ``train``)
  :param epochs: number of epochs; the x axis runs from 1 to ``epochs``
  :param output_path: folder the PDF is written to
  """
  width = 460
  plt.style.use('seaborn-bright')
  tex_fonts = {
        # no LaTeX text rendering
        "text.usetex": False,
        "font.family": "serif",
        # Use 10pt font in plots, to match 10pt font in document
        "axes.labelsize": 10,
        "font.size": 10,
        # Make the legend/label fonts a little smaller
        "legend.fontsize": 8,
        "xtick.labelsize": 8,
        "ytick.labelsize": 8
    }
  plt.rcParams.update(tex_fonts)
  epochs = np.arange(1, epochs + 1)
  # float() also works for tensors that live on the GPU or still track
  # gradients, which np.array() cannot convert directly
  train_loss_values = np.array([float(value) for value in train_loss_values])
  plt.style.use('seaborn')
  fig, ax = plt.subplots(1, 1, figsize=set_size(width))
  ax.plot(epochs, train_loss_values, 'b', label='Training Loss', linestyle='dotted')
  ax.set_title('Training')
  ax.set_xlabel('Epochs')
  ax.set_ylabel('Loss')
  ax.legend()

  fig.savefig(output_path + '/loss.pdf', format='pdf', bbox_inches='tight')
  plt.close()

#### Hyperparameters

In [14]:
def _save_table_pdf(names, values, column_label, title, file_path):
  """Render name/value pairs as a two-column table and save it as a PDF."""
  names, values = replace_helper(names, values)
  cells = np.array([names, values], dtype=str).T
  fig, ax = plt.subplots(1, 1, figsize=set_size(460))
  ax.table(cellText=cells, colLabels=[column_label, 'Value'], loc='center', zorder=3, rowLoc='left', cellLoc='left')
  ax.set_title(title)
  ax.set_xticks([])
  ax.set_yticks([])
  fig.savefig(file_path, format='pdf', bbox_inches='tight')
  plt.close(fig)


def plot_table(setting, param, dml_param, output_path):
  """Write the three configuration dictionaries as table figures.

  Creates ``settings.pdf``, ``params.pdf`` and ``dml_params.pdf`` in
  ``output_path``. Underscores in keys and values are replaced by spaces
  (see ``replace_helper``).

  :param setting: dict of experiment settings
  :param param: dict of hyperparameters
  :param dml_param: dict of deep-metric-learning hyperparameters
  :param output_path: folder the PDFs are written to
  """
  plt.style.use('seaborn-bright')
  tex_fonts = {
        # no LaTeX text rendering
        "text.usetex": False,
        "font.family": "serif",
        # Use 10pt font in plots, to match 10pt font in document
        "axes.labelsize": 10,
        "font.size": 10,
        # Make the legend/label fonts a little smaller
        "legend.fontsize": 8,
        "xtick.labelsize": 8,
        "ytick.labelsize": 8
    }
  plt.rcParams.update(tex_fonts)

  tables = (
      (setting, 'Setting', 'Experiment Settings', '/settings.pdf'),
      (param, 'Hyperparameter', 'Hyperparameters', '/params.pdf'),
      (dml_param, 'DML Hyperparameter', 'DML Hyperparameters', '/dml_params.pdf'),
  )
  for values_by_name, column_label, title, file_name in tables:
    _save_table_pdf(values_by_name.keys(), values_by_name.values(), column_label, title, output_path + file_name)

### Dataloading

In [15]:
def create_processed_info(path, debug=False):
  """Load the index table of the images in ``path``, building and caching it if needed.

  If ``processed_info.csv`` (``debug_processed_info.csv`` when ``debug`` is
  set) exists in ``path`` it is read and returned. Otherwise the table is
  derived from the names of the ``.png`` files in ``path``, written to that
  CSV file and returned.

  Columns, in order: fnames_raw, fnames, fragmentID, papyID, posinfo,
  pixelCentimer.

  :param path: directory that holds the images
  :param debug: selects the debug variant of the cache file
  :return: pandas DataFrame with one row per image
  """
  info_path = join(path, 'debug_processed_info.csv' if debug else 'processed_info.csv')
  if isfile(info_path):
    return pd.read_csv(info_path, index_col=0, dtype={'fnames':str,'papyID':int,'posinfo':str, 'pixelCentimer':float}, header=0)

  # os.listdir: the bare name `listdir` is not imported anywhere in this notebook.
  # sorted() makes the row order independent of the file system.
  fnames = [f for f in sorted(os.listdir(path)) if isfile(join(path, f))]
  fnames = [x for x in fnames if ".png" in x]
  fnames = [f.split('.', 1)[0] for f in fnames]

  # a file stem is split at its first underscore: <fragmentID>_<fnames_raw>
  processed_frame = pd.DataFrame({
      'fnames_raw': [f.split('_', 1)[1] for f in fnames],
      'fnames': fnames,
      'fragmentID': [f.split('_', 1)[0] for f in fnames],
  })
  # cast to int so a freshly built table matches the dtype used when the
  # cached CSV is read back
  processed_frame['papyID'] = processed_frame.fnames_raw.apply(lambda x: x.split('_', 1)[0]).astype(int)
  processed_frame['posinfo'] = processed_frame.fnames_raw.apply(lambda x: ''.join(filter(str.isalpha, x)))
  # plain apply: progress_apply only exists after tqdm.pandas() has been
  # called, which this notebook never does.
  # NOTE(review): retrive_size_by_fname is not defined in the visible part of
  # the notebook - confirm where it comes from.
  processed_frame['pixelCentimer'] = processed_frame.fnames_raw.apply(retrive_size_by_fname)
  processed_frame.to_csv(info_path)

  return processed_frame

### Logging

#### Thesis Settings

In [16]:
def set_size(width, fraction=1, subplots=(1, 1)):
    """Compute a figure size in inches that needs no rescaling in LaTeX.

    Parameters
    ----------
    width: float or string
            Document width in points, or one of the predefined document
            types 'thesis' and 'beamer'
    fraction: float, optional
            Fraction of the width which the figure should occupy
    subplots: array-like, optional
            The number of rows and columns of subplots.
    Returns
    -------
    fig_dim: tuple
            (width, height) of the figure in inches; the height follows the
            golden ratio, scaled by rows / columns
    """
    presets = (('thesis', 426.79135), ('beamer', 307.28987))
    width_pt = next((points for label, points in presets if width == label), width)

    rows, cols = subplots[0], subplots[1]
    pt_to_inch = 1 / 72.27
    golden = (5**.5 - 1) / 2

    width_in = (width_pt * fraction) * pt_to_inch
    height_in = width_in * golden * (rows / cols)
    return (width_in, height_in)

In [17]:
def replace_helper(some_list_1, some_list_2):
  """Return both inputs as lists of strings with '_' replaced by ' '.

  The inputs are walked in parallel, so both results have the length of the
  shorter input.
  """
  def _spaced(item):
    return str(item).replace("_", " ")

  pairs = list(zip(some_list_1, some_list_2))
  return [_spaced(first) for first, _ in pairs], [_spaced(second) for _, second in pairs]

#### Dir-Management

In [18]:
def create_output_dir(name, experiment_name, x=1):
  """Create and return the first unused output directory for an experiment.

  Candidate paths are ``name + '<x>_iteration_of_experiment_' + experiment_name``
  (for ``x == 0`` the ``'<x>_iteration_'`` part is left out). ``x`` is
  increased until a path is found that does not exist yet; that directory is
  created, including missing parent directories.

  :param name: path prefix the directory name is appended to
  :param experiment_name: name of the experiment
  :param x: iteration number to start with
  :return: path of the created directory
  """
  while True:
    # value comparison: `is not 0` tested object identity with a literal
    iteration_part = str(x) + '_iteration_' if x != 0 else ''
    dir_name = (name + iteration_part + 'of_experiment_' + experiment_name).strip()
    if not os.path.exists(dir_name):
      os.makedirs(dir_name)
      return dir_name
    x = x + 1

#### Report-PDF

In [19]:
def create_logging(setting, param, dml_param, train_loss_values, map_vals, random_map_vals, epochs, output_dir, model):
  """Create all report figures and merge them with the log into ``report.pdf``.

  Writes the individual figure PDFs into ``output_dir``, renders ``log.txt``
  (read from the working directory) into ``log.pdf``, merges everything into
  ``<output_dir>/report.pdf`` and copies ``log.txt`` into ``output_dir``.

  :param setting: dict of experiment settings
  :param param: dict of hyperparameters
  :param dml_param: dict of deep-metric-learning hyperparameters
  :param train_loss_values: one training loss per epoch
  :param map_vals: one mAP value per epoch
  :param random_map_vals: one 'random mAP' value per epoch
  :param epochs: number of epochs
  :param output_dir: folder the report is written to
  :param model: network whose gradients are visualised
  """
  plot_table(setting, param, dml_param, output_dir)
  gradient_visualization(model.named_parameters(), output_dir)
  plot_loss(train_loss_values, epochs, output_dir)
  plot_acc(map_vals, random_map_vals, epochs, output_dir)

  # (file inside output_dir, bookmark title) in report order
  sections = [
      ('/loss.pdf', 'Loss'),
      ('/acc.pdf', 'Accuracy'),
      ('/params.pdf', 'Hyperparameters'),
      ('/dml_params.pdf', 'DML Hyperparameters'),
      ('/settings.pdf', 'Settings'),
      ('/gradients.pdf', 'Gradients'),
  ]

  merger = PdfFileMerger()
  try:
    for file_name, bookmark in sections:
      merger.append(output_dir + file_name, bookmark=bookmark)

    # render the text log, one PDF cell per log line
    log_pdf = FPDF()
    log_pdf.add_page()
    log_pdf.set_font("Helvetica", size = 6)
    with open("log.txt", "r") as log_file:
      for line in log_file:
        # drop the line break; ln=1 already moves to the next line
        log_pdf.cell(200, 6, txt = line.rstrip('\n'), ln = 1, align = 'L')
    log_pdf.output("log.pdf")

    merger.append("log.pdf", bookmark='Log')
    merger.write(output_dir + "/report.pdf")
  finally:
    merger.close()

  copyfile('log.txt', output_dir + '/log.txt')

## Initialize

### Settings

In [20]:
# Use the GPU when there is one; fall back to the CPU so that the notebook can
# still be executed (slowly) on a runtime without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# placeholder model; the 'Model Architecture' cell selects the configured one
model = Net().to(device)
config = toml.load('./gdrive/MyDrive/mt/conf/conf.toml')
# Index instead of .get(): a missing section fails right here with a KeyError
# instead of surfacing later as a TypeError on None.
setting = config['settings']
param = config['params']
dml_param = config['dml_params']

### Logging

#### Create Dir

In [21]:
# Create a new, numbered output folder for this run of the experiment.
output_dir = create_output_dir(
    name=setting['output'],
    experiment_name=setting['experiment_name'],
)

### Hyperparameters

In [22]:
# Pull the training hyperparameters out of the config section.
batch_size_train, batch_size_val, lr, num_epochs = (
    param[key] for key in ('batch_size_train', 'batch_size_val', 'lr', 'num_epochs')
)

#### Optimizer

In [23]:
# Map the configured optimizer name to its torch.optim class.
# NOTE(review): this binds the parameters of the `model` that exists when the
# cell runs; the 'Model Architecture' cell assigns a new `model` afterwards.
supported_optimizers = {'Adam': optim.Adam, 'SGD': optim.SGD}

if param['optimizer'] not in supported_optimizers:
  logger.error(' Optimizer is not supported or you have not specified one.')
  raise ValueError()

optimizer = supported_optimizers[param['optimizer']](model.parameters(), lr=lr)

#### Model Architecture

In [24]:
# Build the network selected by param['archi'].
if param['archi'] == 'SimpleCNN':
  model = Net().to(device)

elif param['archi'] == 'efficientnetB0':
  # NOTE(review): unlike the B7 branch, the final `_fc` layer is kept here -
  # confirm that this is intended.
  model = EfficientNet.from_name('efficientnet-b0').to(device)

elif param['archi'] == 'efficientnetB7':
  model = EfficientNet.from_name('efficientnet-b7').to(device)
  # replace the final fully connected layer by a pass-through
  model._fc  = torch.nn.Identity()

elif param['archi'] == 'ResNet':
  model = models.resnet18(pretrained=True).to(device)

else:
  # same handling as the other configuration cells: fail loudly instead of
  # silently keeping the placeholder model
  logger.error(' Architecture is not supported or you have not specified one.')
  raise ValueError()

# The optimizer cell ran before this one and therefore holds the parameters of
# the previous `model`. Re-create it (same class, same learning rate) for the
# network that was just built.
optimizer = type(optimizer)(model.parameters(), lr=lr)

### PyTorch-Metric-Learning Hyperparameters

#### Distance

In [25]:
# Select the distance measure named in the DML config.
distance_name = dml_param['distance']

if distance_name not in ('CosineSimilarity', 'LpDistance'):
  logger.error(' Distance is not supported or you have not specified one.')
  raise ValueError()

distance = (
    distances.CosineSimilarity()
    if distance_name == 'CosineSimilarity'
    else distances.LpDistance(normalize_embeddings=True, p=2, power=1)
)

#### Reducer

In [26]:
# Select the loss reducer named in the DML config.
reducer_name = dml_param['reducer']

if reducer_name not in ('ThresholdReducer', 'AvgNonZeroReducer'):
  logger.error(f'Reducer is not supported or you have not specified one.')
  raise ValueError()

if reducer_name == 'ThresholdReducer':
  # the threshold is only read from the config when this reducer is chosen
  reducer = reducers.ThresholdReducer(low=dml_param['ThresholdReducer_low'])
else:
  reducer = reducers.AvgNonZeroReducer()

#### Loss Function

In [27]:
# Build the metric-learning loss named in the DML config. Both losses use the
# distance and the reducer selected in the cells above.
if dml_param['loss_function'] == 'TripletMarginLoss':
  loss_func = losses.TripletMarginLoss(margin=dml_param['TripletMarginLoss_margin'], distance=distance, reducer=reducer)

elif dml_param['loss_function'] == 'ContrastiveLoss':
  # pos_margin=1 / neg_margin=0 only suits an inverted measure such as
  # CosineSimilarity; for a regular distance (e.g. LpDistance) the margins
  # are the other way round.
  if distance.is_inverted:
    pos_margin, neg_margin = 1, 0
  else:
    pos_margin, neg_margin = 0, 1
  loss_func = losses.ContrastiveLoss(pos_margin=pos_margin, neg_margin=neg_margin, distance=distance, reducer=reducer)

else:
  logger.error(' DML Loss is not supported or you have not specified one.')
  raise ValueError()

#### Mining Function

In [28]:
# Only the triplet margin miner is supported; anything else is a config error.
if dml_param['miner'] != 'TripletMarginMiner':
  logger.error('DML Miner is not supported or you have not specified one.')
  raise ValueError()

mining_func = miners.TripletMarginMiner(
    margin=dml_param['TripletMarginMiner_margin'],
    distance=distance,
    type_of_triplets=dml_param['type_of_triplets'],
)

#### Accuracy Calculator

In [29]:
# The two metrics to compute and k are taken from the DML config.
metrics_to_compute = (dml_param['metric_1'], dml_param['metric_2'])
accuracy_calculator = AccuracyCalculator(
    include=metrics_to_compute,
    k=dml_param['precision_at_1_k'],
)

### Transformations

In [30]:
def _ten_crop_pipeline():
  """Build the pipeline TenCrop -> stacked tensors -> float -> Normalize.

  Crop size and normalisation statistics are read from ``param``. Both splits
  use the ``val_center_crop_*`` entries for the crop size.
  """
  crop_size = (param['val_center_crop_1'], param['val_center_crop_2'])
  channel_mean = (param['normalize_1'], param['normalize_2'], param['normalize_3'])
  channel_std = (param['normalize_4'], param['normalize_5'], param['normalize_6'])

  def _stack_crops(crops):
    # one tensor per crop, stacked along a new leading dimension
    return torch.stack([transforms.PILToTensor()(crop) for crop in crops])

  return transforms.Compose([
      transforms.TenCrop(crop_size),
      transforms.Lambda(_stack_crops),
      transforms.ConvertImageDtype(torch.float),
      transforms.Normalize(channel_mean, channel_std),
  ])


train_transform = _ten_crop_pipeline()
val_transform = _ten_crop_pipeline()

### Data

#### Helpers

In [31]:
# Index tables for the training and the validation images (train first).
processed_frame_train, processed_frame_val = (
    create_processed_info(setting[path_key]) for path_key in ('path_train', 'path_val')
)

#### Dataset

In [32]:
# One dataset per split, each with its own index table and transform.
train_dataset = FAUPapyrusCollectionDataset(
    root_dir=setting['path_train'],
    processed_frame=processed_frame_train,
    transform=train_transform,
)
val_dataset = FAUPapyrusCollectionDataset(
    root_dir=setting['path_val'],
    processed_frame=processed_frame_val,
    transform=val_transform,
)

#### Data Loader

In [33]:
# Honour the configured (and logged) 'shuffle' setting instead of always
# shuffling the training data.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size_train, shuffle=param['shuffle'], drop_last=True)
test_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size_val, drop_last=True)

#### Result Lists

In [34]:
# Four independent, initially empty lists that collect per-epoch results.
loss_vals, val_loss_vals, map_vals, random_map_vals = [], [], [], []

## Training

### Log Hyperparameters

In [35]:
# (label padded to the value column, config section, key) - logged in this order
logged_entries = (
    ('Debug:                ', setting, 'debug'),
    ('Loos Function:        ', dml_param, 'loss_function'),
    ('Margin Miner Margin:  ', dml_param, 'TripletMarginMiner_margin'),
    ('Triplet Margin Loss:  ', dml_param, 'TripletMarginLoss_margin'),
    ('Type of Tribles:      ', dml_param, 'type_of_triplets'),
    ('Miner:                ', dml_param, 'miner'),
    ('Reducer:              ', dml_param, 'reducer'),
    ('Archi:                ', param, 'archi'),
    ('Epochs:               ', param, 'num_epochs'),
    ('Batch Size Train:     ', param, 'batch_size_train'),
    ('Batch Size Val:       ', param, 'batch_size_val'),
    ('Optimizer:            ', param, 'optimizer'),
    ('Learning Rate:        ', param, 'lr'),
    ('Shuffle:              ', param, 'shuffle'),
)
for label, section, key in logged_entries:
  logger.info(f'{label}{section[key]}')

Debug:                False
Loos Function:        TripletMarginLoss
Margin Miner Margin:  0.2
Triplet Margin Loss:  0.2
Type of Tribles:      semihard
Miner:                TripletMarginMiner
Reducer:              ThresholdReducer
Archi:                efficientnetB7
Epochs:               15
Batch Size Train:     64
Batch Size Val:       1
Optimizer:            Adam
Learning Rate:        0.01
Shuffle:              True


### Train

In [36]:
def objective(trial):
  """Optuna objective: train with sampled hyperparameters and return the validation mAP.

  Samples the optimizer class, the learning rate and one margin that is used
  for both the triplet loss and the triplet miner. After every epoch the
  validation mAP is reported to the trial so that it can be pruned.

  :param trial: Optuna trial that provides the hyperparameter suggestions
  :return: validation mAP after the last epoch (``None`` if ``num_epochs`` is 0)
  :raises optuna.exceptions.TrialPruned: if the trial should be pruned
  """
  optimizer_name = trial.suggest_categorical("optimizer", ["AdamW", "Adam", "SGD"])
  lr = trial.suggest_float("lr", 1e-5, 1e-2)
  # NOTE(review): `model` is the notebook-global network and is not
  # re-initialised here, so a trial continues from the weights the previous
  # trial left behind - confirm that this is intended.
  optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)
  margins = trial.suggest_float("TripletMarginMinerMargin", 0.1, 0.3)
  loss_func = losses.TripletMarginLoss(margin=margins, distance=distance, reducer=reducer)
  mining_func = miners.TripletMarginMiner(
    margin=margins,
    distance=distance,
    type_of_triplets=dml_param['type_of_triplets']
    )

  val_map = None
  for epoch in range(1, num_epochs + 1):
      ############### Training ###############
      train(model, loss_func, mining_func, device, train_loader, optimizer, epoch)

      ############### Validation ###############
      # the validation set serves as query and as reference set
      val_map, _ = val(val_dataset, val_dataset, model, accuracy_calculator)

      ############### Report / Prune ###############
      trial.report(val_map, epoch)

      if trial.should_prune():
        raise optuna.exceptions.TrialPruned()

  # Optuna needs the objective value; without a return every trial fails
  return val_map
      
      

In [None]:
# Search for the hyperparameters that maximise the validation mAP.
study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=100)

# optuna.trial.TrialState is the supported location of the enum;
# optuna.structs is deprecated.
pruned_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.PRUNED]
complete_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]

print("Study statistics: ")
print("  Number of finished trials: ", len(study.trials))
print("  Number of pruned trials: ", len(pruned_trials))
print("  Number of complete trials: ", len(complete_trials))

print("Best trial:")
trial = study.best_trial

print("  Value: ", trial.value)

print("  Params: ")
for key, value in trial.params.items():
    print("    {}: {}".format(key, value))

[32m[I 2022-01-10 20:53:29,981][0m A new study created in memory with name: no-name-e4b45d8b-757a-4b52-b40d-c0ec69421a65[0m
