# Plotting SwinTransformer Learning Curves

As written, this notebook plots the learning curves from training and validating a SwinTransformer to classify low vs. high IPTG concentration from images of *Proteus mirabilis* pLac-*cheW* colonies grown at 37 °C<sup>1,2</sup>. The SwinTransformer was implemented with various optimization methods, hence the multiple learning curves:
- “PM” = *P. mirabilis*-pretrained (i.e. the base of the SwinTransformer (eleven layers, not including the classification head) was initialized with our previously obtained *P. mirabilis*-specific weights) 
- “IN” = ImageNet-pretrained (i.e. the base was initialized with out-of-the-box ImageNet weights)<sup>2</sup> 
- “FFT” = fully fine-tuned (i.e. no layers were fixed during training)
- “PFT” = partially fine-tuned (i.e. the first three layers were fixed during training)
- “Aug” = on-the-fly augmentations were randomly implemented during training

[1] Liu, Z., Y. Lin, Y. Cao, H. Hu, Y. Wei, Z. Zhang, S. Lin, and B. Guo. *Swin transformer: Hierarchical vision transformer using shifted windows*. in *Proceedings of the IEEE/CVF International Conference on Computer Vision*. 2021.

[2] Shkarupa, A. *tfswin*. 2022; Available from: https://github.com/shkarupa-alex/tfswin.

# Imports

In [None]:
from google.colab import drive
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pickle 
import matplotlib.image as mpimg
import seaborn as sns
sns.set_style('white')
from matplotlib import colors as mcolors

In [None]:
# Mount my Google Drive, where datasets, models, & results are stored.
# The umbrella directory configured further down lives under this mount point.
drive.mount('/content/gdrive')

# Plot learning process

In [None]:
# Umbrella directory on the mounted Drive
drive_classification_dir = '/content/gdrive/MyDrive/Classification_mirabilis/'

# Folder holding the results of the final model implementation
run_name = 'SwinT_adaptLR_chew_TempRobust_ExpDecay_13DC_0pt96DR_pat10'
run_dir = os.path.join(drive_classification_dir, run_name)

# Results subfolders inside the run folder
saved_models_dir, histories_dir, CMs_dir = (
    os.path.join(run_dir, subfolder)
    for subfolder in ('saved_models', 'histories', 'confusion_matrices')
)


In [None]:
# Set seaborn theme
# NOTE(review): this runs after sns.set_style('white') in the imports cell;
# set_theme() applies its own default style, so the earlier call presumably
# has no lasting effect — confirm which look is intended.
sns.set_theme()

In [None]:
# Function for plotting multiple learning curves
def plot_multiple(history, model_name, linecolor):
  """Overlay one model's training/validation curves on the 3x1 subplot grid.

  Draws onto the current pyplot figure: row 1 is loss, row 2 is categorical
  accuracy and row 3 is AUC. Training curves are solid and validation curves
  are dotted; both use `linecolor` and are labelled with `model_name`.

  Args:
    history: mapping with per-epoch sequences under the keys 'loss',
      'val_loss', 'categorical_accuracy', 'val_categorical_accuracy',
      'AUC' and 'val_AUC'.
    model_name: text appended to 'Train ' / 'Val ' in the legend labels.
    linecolor: matplotlib color shared by both curves of this model.
  """
  # Fetch every series up front, as (train, val) pairs in panel order
  panels = [
      (history['loss'], history['val_loss']),
      (history['categorical_accuracy'], history['val_categorical_accuracy']),
      (history['AUC'], history['val_AUC']),
  ]

  # Epochs are numbered from 1; the loss series sets how many there are
  n_epochs = len(panels[0][0])
  x = range(1, n_epochs + 1)

  # One subplot row per metric: loss, accuracy, AUC
  for row, (train_series, val_series) in enumerate(panels, start=1):
    plt.subplot(3, 1, row)
    plt.plot(x, train_series, color=linecolor, linestyle='solid', alpha=1, label=f'Train {model_name}')
    plt.plot(x, val_series, color=linecolor, linestyle=':', alpha=1, label=f'Val {model_name}')


In [None]:
# Function for formatting the plots
def format_histories_plot(max_epochs):
  """Add titles, axis labels, ticks, the legend and the figure title.

  Operates on the current pyplot figure, selecting each panel of the 3x1 grid
  (loss, accuracy, AUC) with plt.subplot. Only the loss panel gets a legend,
  anchored outside the axes on the right.

  Args:
    max_epochs: largest epoch number to show; x ticks run from 1 up to
      max_epochs in steps of 4.
  """
  # All three panels share the same epoch axis ticks
  epoch_ticks = np.arange(1, max_epochs+1, 4)

  plt.subplot(3, 1, 1)
  plt.title('Categorical Cross-Entropy Loss')
  plt.xlabel('Epoch',fontsize=10)
  plt.ylabel('Loss',fontsize=10)
  plt.xticks(ticks=epoch_ticks)
  plt.yticks(ticks=np.arange(0.0,1.1,0.2))
  plt.ylim(0.0,1.1)
  plt.legend(bbox_to_anchor=(1.3, 1.01))

  plt.subplot(3, 1, 2)
  plt.title('Categorical Accuracy')
  plt.xlabel('Epoch',fontsize=10)
  plt.ylabel('Accuracy',fontsize=10)
  plt.xticks(ticks=epoch_ticks)
  plt.yticks(ticks=np.arange(0.6,1.05,0.1))

  plt.subplot(3, 1, 3)
  plt.title('AUC')
  plt.xlabel('Epoch',fontsize=10)
  plt.ylabel('AUC',fontsize=10)
  plt.xticks(ticks=epoch_ticks)
  plt.yticks(ticks=np.arange(0.6,1.05,0.1))

  # Title the figure that owns the panels above rather than relying on a
  # notebook-global `fig` being defined; a plain '\n' is used for the line
  # break so the title does not depend on the platform's os.linesep.
  plt.gcf().suptitle('Training & validating a SwinTransformer to classify IPTG\n'
                     'from cheW images acquired at 37C', y=0.93, fontweight='bold')

In [None]:
# Get the list of model histories
# NOTE(review): os.listdir makes no promise about ordering, yet the plotting
# loop further down pairs these files with legend names and colors purely by
# position — check the printed order against the legend list before trusting
# the labels.
histories_list = os.listdir(histories_dir)
print(histories_list)

In [None]:
# Legend labels, one per history file; the plotting loop matches them to the
# histories by position
legend_names = [
    # 'PM' prefix: P. mirabilis-pretrained base
    'PM FFT', 'PM FFT Aug', 'PM PFT', 'PM PFT Aug',
    # 'IN' prefix: ImageNet-pretrained base (assumed from the abbreviation)
    'IN FFT', 'IN FFT Aug', 'IN PFT', 'IN PFT Aug',
]

In [None]:
# Plot & save it

all_epochs_completed = list()

# Start from an empty figure: plot_multiple / format_histories_plot create and
# select the three stacked panels themselves via plt.subplot(3, 1, k), so a
# pre-created full-figure Axes would overlap them (and, depending on the
# matplotlib version, stay visible underneath).
fig = plt.figure(figsize = [10, 20])
palette = ['#4363d8','#ff7f0e','#2ca02c','#d62728','#9467bd','#e377c2','#bcbd22','#17becf']

# Histories, legend labels and colors are paired by position. Fail up front
# with a clear message instead of an IndexError part-way through plotting.
if len(histories_list) > min(len(legend_names), len(palette)):
  raise ValueError(
      f'Found {len(histories_list)} history files but only '
      f'{len(legend_names)} legend names and {len(palette)} colors.')

for h, model_name, linecolor in zip(histories_list, legend_names, palette):

  history_path = os.path.join(histories_dir, h)
  # NOTE: pickle.load can execute arbitrary code from the file; only load
  # history files that come from a trusted source.
  with open(history_path, 'rb') as file_pi:
    history = pickle.load(file_pi)

  # Track how long each model trained so the x axis can cover the longest run
  all_epochs_completed.append(len(history['loss']))

  plot_multiple(history, model_name, linecolor)

max_epochs = np.amax(all_epochs_completed)
format_histories_plot(max_epochs)

# Save before showing so the file is written from the fully drawn figure.
# bbox_inches='tight' crops to everything that was drawn, including the legend
# anchored outside the loss panel and the figure title.
plt.rcParams['svg.fonttype'] = 'none'
curves_path = os.path.join(run_dir, 'learning_curves.svg')
fig.savefig(curves_path, format='svg', bbox_inches='tight')
plt.show()