# Pre-training of SwinTransformer

Initial training, validation, and testing of a Transformer model for the task of classifying the strain present in images from our all-strain *Proteus mirabilis* dataset. The optimal model obtained was further fine-tuned in a separate notebook for the new task of classifying the relative IPTG concentration (low vs. high) of *P. mirabilis* pLac-*cheW* images. The performance of the fine-tuned model was later evaluated based on its ability to accurately classify test images of pLac-*cheW* colonies grown at temperatures previously unseen during model training and evaluation.

We implement SwinTransformerTiny224, the shallowest of the SwinTransformer classification models, as originally presented by Liu et al. (2021)<sup>1</sup> and made readily available in Keras (TensorFlow) by Shkarupa<sup>2</sup>.

[1] Liu, Z., Y. Lin, Y. Cao, H. Hu, Y. Wei, Z. Zhang, S. Lin, and B. Guo. *Swin transformer: Hierarchical vision transformer using shifted windows*. in *Proceedings of the IEEE/CVF International Conference on Computer Vision*. 2021.

[2] Shkarupa, A. *tfswin*. 2022; Available from: https://github.com/shkarupa-alex/tfswin.

# Imports

In [None]:
! pip install tfswin

In [None]:
from google.colab import drive
import os
import numpy as np
import shutil
import pandas as pd
import random
import matplotlib.pyplot as plt
import math
import inspect
from collections import Counter
import pickle 
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import preprocessing, layers, models, callbacks
from tensorflow.keras.preprocessing import image 
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Sequential
from tfswin import SwinTransformerTiny224, preprocess_input  

In [None]:
# Mount Google Drive at /content/gdrive; every dataset and results path used
# below is built under this mount point.
drive.mount('/content/gdrive')

# Dataset

In [None]:
# Dataset locations on the mounted Drive. Each directory path keeps a trailing
# '/', because later cells append sub-paths with plain string '+'.
drive_classification_dir = '/content/gdrive/MyDrive/Classification_mirabilis/'  # umbrella directory
img_datasets_dir = f"{drive_classification_dir}img_datasets/"  # all image datasets

# The specific dataset used in this run
this_dataset_name = 'img_dataset_curated_split'
this_dataset_dir = f"{img_datasets_dir}{this_dataset_name}/"

In [None]:
# Name the model run & create folders for storing results
run_name = 'SwinTransformerTiny224_curateddataset_pretr_finetu_aug'
run_dir = drive_classification_dir + run_name

saved_models_dir = os.path.join(run_dir, 'saved_models')  # best checkpoint per model
histories_dir = os.path.join(run_dir, 'histories')  # pickled fit() histories
CMs_dir = os.path.join(run_dir, 'confusion_matrices')  # confusion-matrix figures

all_run_dirs = [run_dir, saved_models_dir, histories_dir, CMs_dir]
for run_sub_dir in all_run_dirs:
  # makedirs(exist_ok=True) also creates any missing parent folder and is a
  # no-op when the folder already exists (e.g. when the notebook is re-run).
  os.makedirs(run_sub_dir, exist_ok=True)

In [None]:
# Print dataset info: number of images per class folder in every subset.
# Only sub-directories are treated as subsets/classes. This matches how
# flow_from_directory discovers classes, and it stops stray files (e.g.
# desktop.ini) from making os.listdir() fail. Listings are sorted so the
# printed order is stable between runs.
subsets = sorted(d for d in os.listdir(this_dataset_dir)
                 if os.path.isdir(os.path.join(this_dataset_dir, d)))
for sub in subsets:
  print(f"-----{sub}-----")
  sub_path = os.path.join(this_dataset_dir, sub)
  class_list = sorted(c for c in os.listdir(sub_path)
                      if os.path.isdir(os.path.join(sub_path, c)))
  for cls in class_list:
    cls_path = os.path.join(sub_path, cls)
    img_list = os.listdir(cls_path)
    num_imgs = len(img_list)
    print(f"{cls}: {num_imgs} images total")

In [None]:
# First figure out the img size we should use, based on the default value of
# the `pretrain_size` parameter of this specific SwinTransformer constructor.
# Read Parameter.default directly instead of parsing str(Parameter), which
# depends on how the signature happens to be rendered.
pretrain_size_param = inspect.signature(SwinTransformerTiny224).parameters['pretrain_size']
if pretrain_size_param.default is inspect.Parameter.empty:
  raise ValueError('SwinTransformerTiny224 has no default value for pretrain_size')
IMG_SIZE = int(pretrain_size_param.default)
print(IMG_SIZE)

In [None]:
# ImageDataGenerator
# https://www.tensorflow.org/api_docs/python/tf/keras/preprocessing/image/ImageDataGenerator
# The generator has no transformations configured: it only loads, resizes and
# batches the images of each subset directory.

TRAIN_BATCH_SIZE = 4
VAL_BATCH_SIZE = 4
TEST_BATCH_SIZE = 1

train_dir, val_dir, test_dir = (this_dataset_dir + subset for subset in ('train', 'val', 'test'))

datagen = ImageDataGenerator()

# Settings shared by every subset. Bicubic resizing is used throughout because
# SwinTransformerTiny224 is sensitive to the interpolation method.
shared_flow_kwargs = dict(
    class_mode='sparse',  # either 'sparse' or 'categorical' is fine
    target_size=(IMG_SIZE, IMG_SIZE),
    interpolation='bicubic',
    seed=123,
)

# Only the training images are shuffled; val/test are read in a fixed order.
train_generator = datagen.flow_from_directory(
    train_dir, batch_size=TRAIN_BATCH_SIZE, shuffle=True, **shared_flow_kwargs)
val_generator = datagen.flow_from_directory(
    val_dir, batch_size=VAL_BATCH_SIZE, shuffle=False, **shared_flow_kwargs)
test_generator = datagen.flow_from_directory(
    test_dir, batch_size=TEST_BATCH_SIZE, shuffle=False, **shared_flow_kwargs)

In [None]:
# Verify all possible classes: class_indices maps each class-folder name to its
# integer label, separately for every subset.
train_class_names = train_generator.class_indices
val_class_names = val_generator.class_indices
test_class_names = test_generator.class_indices
for subset_class_names in (train_class_names, val_class_names, test_class_names):
  print(subset_class_names)

In [None]:
# All subsets must have the same class names & order (name -> label index).
# Otherwise the same integer label would mean different strains in train
# vs. val/test, and every metric and confusion matrix would be silently wrong.
if not (train_class_names == val_class_names == test_class_names):
  raise ValueError(
      'Class folders differ between subsets: '
      f'train={train_class_names}, val={val_class_names}, test={test_class_names}')
class_names = train_class_names

In [None]:
# Confirm shapes of one batch of imgs & labels (draws a single batch from the
# training generator)
image_batch, labels_batch = next(train_generator)
print(image_batch.shape)
print(labels_batch)
print(labels_batch.shape)  # (batch_size, num_classes) if categorical, (batch_size,) if sparse

In [None]:
# Determine class counts in each data subset (label index -> number of images)
train_counter, val_counter, test_counter = (
    Counter(gen.classes) for gen in (train_generator, val_generator, test_generator))
print(train_counter)
print(val_counter)
print(test_counter)

In [None]:
# B/c the classes are imbalanced, compute class weights for training:
# each class gets (count of the largest class) / (its own count). The largest
# class therefore has weight 1.0 and rarer classes are up-weighted.
max_val = float(max(train_counter.values()))
print(max_val)
class_weights = {}
for class_id, num_images in train_counter.items():
  class_weights[class_id] = max_val / num_images
print(class_weights)

# Model(s): 
- configuration
- training, validation, & testing

In [None]:
# Some hyperparameters (both are passed to the Adam optimizer in build_swintransformer)
LR = 1e-5 # learning rate
EPSILON = 1e-8 # Adam epsilon (numerical-stability term)

In [None]:
# Potential additional on-the-fly augmentations. Keras random preprocessing
# layers only transform images in training mode; in inference mode they pass
# inputs through unchanged.
# Uses the stable tf.keras.layers.Random* classes; the
# tf.keras.layers.experimental.preprocessing aliases are deprecated.
augs_on_the_fly = tf.keras.Sequential([
    tf.keras.layers.RandomFlip('horizontal', input_shape=(IMG_SIZE, IMG_SIZE, 3)),
    # negative factors zoom in, by a random amount in the range [+2%, +10%]
    tf.keras.layers.RandomZoom((-0.1, -0.02)),
    # contrast factor drawn from [1 - 0.4, 1 + 0.4]
    tf.keras.layers.RandomContrast(0.4),
    # factors are fractions of a full turn: (-0.01, 0.01) is about +/-3.6 degrees
    tf.keras.layers.RandomRotation((-0.01, 0.01), fill_mode='reflect'),
    # horizontal shift of up to 10% of the width; no vertical shift
    tf.keras.layers.RandomTranslation(height_factor=(0.0, 0.0),
                                      width_factor=(-0.1, 0.1),
                                      fill_mode='reflect'),
    ], name='on_fly_augs')

In [None]:
# Function for building model
# finetune='partial','full','none'

def build_swintransformer(pretrained=False, finetune='full', data_augmentation=False):
  """Build and compile a SwinTransformerTiny224 classifier over `class_names`.

  Args:
    pretrained: if True, load 'imagenet' weights into the Swin backbone;
      otherwise the backbone is built with weights=None.
    finetune: which backbone layers are trainable:
      'full'    -> the whole backbone is trainable;
      'none'    -> the whole backbone is frozen (only the head is trained);
      'partial' -> the first third of the backbone's layers is frozen.
    data_augmentation: if True, the shared `augs_on_the_fly` layers are
      applied to the inputs before preprocessing.

  Returns:
    A `tf.keras.Model` compiled with legacy Adam(LR, EPSILON),
    sparse categorical cross-entropy, and top-1 / top-3 sparse categorical
    accuracy metrics.

  Raises:
    ValueError: if `finetune` is not 'full', 'none' or 'partial'.
  """
  if finetune not in ('full', 'none', 'partial'):
    # An unknown value used to be accepted silently (backbone left trainable).
    raise ValueError(
        f"finetune must be 'full', 'none' or 'partial', got {finetune!r}")

  # define input shape
  inputs = tf.keras.Input(shape=(IMG_SIZE, IMG_SIZE, 3))

  # optional on-the-fly training augs
  x = augs_on_the_fly(inputs) if data_augmentation else inputs

  # preprocess layer: tfswin's preprocess_input wrapped as a Keras layer
  preprocess_layers = tf.keras.Sequential([tf.keras.layers.Lambda(preprocess_input)],
                                          name='preprocess')

  # swin backbone, without its own classification top
  swin_model = SwinTransformerTiny224(weights=('imagenet' if pretrained else None),
                                      include_top=False)
  if finetune == 'full':
    swin_model.trainable = True
  elif finetune == 'none':
    swin_model.trainable = False
  else:  # 'partial'
    # Let's take a look to see how many layers are in the base model
    print("Number of layers in the base swin model: ", len(swin_model.layers))
    # Fine-tune from this layer onwards; freeze all the layers before it
    fine_tune_at = len(swin_model.layers) // 3
    for layer in swin_model.layers[:fine_tune_at]:
      layer.trainable = False

  # classification head
  classification_head = tf.keras.Sequential([
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(len(class_names), activation='softmax'),
        ], name='class_head')

  # put it all together
  x = preprocess_layers(x)
  if finetune == 'none':
    # Fully frozen backbone: always run it in inference mode, as in the Keras
    # transfer-learning recipe.
    x = swin_model(x, training=False)
  else:
    # Do not pin `training` to a constant. The old `training=pretrained` fixed
    # the mode when the graph was built, so pretrained models kept any
    # dropout / stochastic-depth layers active even in evaluate()/predict(),
    # and non-pretrained models never enabled them during fit(). Leaving it
    # unset lets Keras pass training=True in fit() and False at inference.
    x = swin_model(x)
  outputs = classification_head(x)
  model = tf.keras.Model(inputs, outputs)

  # compile model
  model.compile(optimizer=tf.keras.optimizers.legacy.Adam(learning_rate=LR, epsilon=EPSILON),
                loss='sparse_categorical_crossentropy',
                metrics=['sparse_categorical_accuracy',
                         tf.keras.metrics.SparseTopKCategoricalAccuracy(k=3, name="sparse_top_3_categorical_accuracy")
                         ],
               )

  return model

In [None]:
# Define the model versions to try: (key, build_swintransformer settings).
# Models are built eagerly, one per entry, in the order listed.
model_configs = [
    ('baseline',         dict(pretrained=False, finetune='full',    data_augmentation=False)),
    ('baseline_aug',     dict(pretrained=False, finetune='full',    data_augmentation=True)),
    ('pretr',            dict(pretrained=True,  finetune='full',    data_augmentation=False)),
    ('pretr_aug',        dict(pretrained=True,  finetune='full',    data_augmentation=True)),
    ('pretr_frozen',     dict(pretrained=True,  finetune='none',    data_augmentation=False)),
    ('pretr_frozen_aug', dict(pretrained=True,  finetune='none',    data_augmentation=True)),
    ('pretr_finetu',     dict(pretrained=True,  finetune='partial', data_augmentation=False)),
    ('pretr_finetu_aug', dict(pretrained=True,  finetune='partial', data_augmentation=True)),
]
try_models = {key: build_swintransformer(**kwargs) for key, kwargs in model_configs}

In [None]:
# Print a Keras summary for every candidate model, each preceded by a banner
for model_key in try_models:
  model = try_models[model_key]
  print(f'\n\n Summary of {model_key} model \n\n')
  model.summary()

In [None]:
# Define ES callback: stop training once val_loss has not improved for 4
# consecutive epochs. The same instance is passed to fit() for every model.
early_stopping = callbacks.EarlyStopping(monitor="val_loss", patience=4)

In [None]:
# Dataframe with one row of test-set metrics per model; the training loop
# fills it in and re-pickles it to test_metrics_path after every model.
test_metric_columns = ['ModelName', 'Loss', 'Accuracy', 'Top3Accuracy']
df_test_metrics = pd.DataFrame(columns=test_metric_columns)
test_metrics_path = os.path.join(run_dir, 'all_models_test_metrics.pkl')

In [None]:
# Re-make train & val generators that have batch size 1 and aren't shuffled for testing
# Note: only test images are used for final test metrics
# (but it can also be helpful to see how the model performs on its train & val images)
eval_flow_kwargs = dict(class_mode='sparse',  # 'sparse', 'categorical'
                        target_size=(IMG_SIZE, IMG_SIZE),
                        interpolation='bicubic',
                        batch_size=1,
                        shuffle=False,
                        seed=123)
reset_train_generator = datagen.flow_from_directory(train_dir, **eval_flow_kwargs)
reset_val_generator = datagen.flow_from_directory(val_dir, **eval_flow_kwargs)

In [None]:
# Filenames and true labels of every image, in generator order. All three
# generators below are unshuffled, so this order matches their predictions.
reset_train_filenames, reset_train_labels = reset_train_generator.filenames, reset_train_generator.labels
reset_val_filenames, reset_val_labels = reset_val_generator.filenames, reset_val_generator.labels
test_filenames, test_labels = test_generator.filenames, test_generator.labels

In [None]:
# Function for generating predictions & CMs
def plot_CM(generator,best_model,model_key,labels,subset):
  """Predict on `generator`, plot a row-normalised confusion matrix, save it as EPS.

  Args:
    generator: Keras directory iterator (expected unshuffled); it is reset
      first so the predictions line up with `labels`.
    best_model: trained model; its softmax outputs are argmax-ed into
      predicted class indices.
    model_key: model name, used in the title and the output filename.
    labels: true integer class indices, in the same order as `generator`.
    subset: subset name (e.g. 'train', 'val', 'test') for title and filename.

  The figure is written to `CMs_dir/{model_key}_CM_{subset}.eps`, then shown
  and closed. Closing it stops open figures from piling up across the calls
  made by the training loop.
  """
  generator.reset()
  preds = best_model.predict(generator)
  pred_labels = preds.argmax(axis=1)  # predicted class index per image

  # class_names maps class name -> label index. Order the names by index so
  # the tick labels match the matrix rows/columns. Pass the full label set so
  # a class absent from this subset still gets its own row/column instead of
  # shrinking the matrix below the number of display labels.
  display_names = sorted(class_names, key=class_names.get)
  cm = confusion_matrix(labels, pred_labels,
                        labels=np.arange(len(display_names)),
                        normalize='true')

  fig, ax = plt.subplots(figsize=(10, 10))
  # '\n' (not os.linesep) is matplotlib's line break on every platform
  title = f"Confusion matrix of {model_key} model's\npredictions on {subset} images"
  disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=display_names)
  disp.plot(include_values=True, cmap=plt.cm.Blues, ax=ax,
            xticks_rotation='vertical', values_format=None)
  disp.ax_.set_title(title, fontweight='bold')
  extent = ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted())

  CM_path_eps = os.path.join(CMs_dir, f"{model_key}_CM_{subset}.eps")
  fig.savefig(CM_path_eps, format='eps', bbox_inches=extent.expanded(2.0, 2.0))
  plt.show()
  plt.close(fig)

In [None]:
EPOCHS = 50
all_histories = {}

# For every model variant: train with checkpointing, reload the best
# checkpoint, evaluate on the test set, and plot confusion matrices.
# model_num (1-based) is the row label used in df_test_metrics.
for model_num, (model_key, model) in enumerate(try_models.items(), start=1):
  # finish defining checkpt callback: keep only the lowest-val_loss model
  model_path = os.path.join(saved_models_dir, model_key)  # native tf format
  save_model = callbacks.ModelCheckpoint(model_path, monitor="val_loss",
                                         verbose=1, save_best_only=True,
                                         mode='min')
  # restart every generator at its first batch
  for generator in (train_generator, val_generator, test_generator,
                    reset_train_generator, reset_val_generator):
    generator.reset()

  # train & validate
  print(f'\n\n Training {model_key} \n\n')
  all_histories[model_key] = model.fit(train_generator,
                                       validation_data=val_generator,
                                       epochs=EPOCHS,
                                       steps_per_epoch=len(train_generator),
                                       validation_steps=len(val_generator),
                                       class_weight=class_weights,
                                       callbacks=[early_stopping, save_model],
                                       )
  # save model history (the with-block closes the file even if pickling fails)
  model_history = all_histories[model_key].history
  history_path = os.path.join(histories_dir, model_key + '_history.pckl')
  with open(history_path, 'wb') as file_pi:
    pickle.dump(model_history, file_pi)

  # load in best model (the checkpoint with the lowest val_loss)
  best_model = tf.keras.models.load_model(model_path)

  # test on test set; return_dict=True keys the results by metric name
  # instead of relying on the order of best_model.metrics_names
  test_results = best_model.evaluate(test_generator, return_dict=True)

  # save test metrics (re-pickled after every model, so results of the models
  # finished so far are on disk even if a later model fails)
  test_loss = test_results['loss']
  test_acc = test_results['sparse_categorical_accuracy']
  test_top3acc = test_results['sparse_top_3_categorical_accuracy']
  df_test_metrics.loc[model_num] = [model_key, test_loss, test_acc, test_top3acc]
  df_test_metrics.to_pickle(test_metrics_path)

  # predict on train, val, & test
  plot_CM(reset_train_generator, best_model, model_key, reset_train_labels, 'train')
  plot_CM(reset_val_generator, best_model, model_key, reset_val_labels, 'val')
  plot_CM(test_generator, best_model, model_key, test_labels, 'test')

  # clear session
  tf.keras.backend.clear_session()



In [None]:
# To retrieve a saved model history later (history_path is
# os.path.join(histories_dir, model_key + '_history.pckl')), you would do:
#with open(history_path, 'rb') as file_pi:
#  history = pickle.load(file_pi)