# Fine-tuning of SwinTransformer

Subsequent training, validation, and testing of a Transformer model for the new task of classifying the relative IPTG concentration (low vs. high) of *Proteus mirabilis* pLac-*cheW* images. We compared initializing the model with ImageNet weights (as done previously) against initializing it with our recently obtained all-strain *P. mirabilis* weights. The model versions were trained on images of pLac-*cheW* grown at 37 °C and evaluated on unseen test images of the strain grown at 37 °C, 36 °C, and 34 °C.

As before, we use SwinTransformerTiny224, the shallowest of the SwinTransformer classification models, as originally presented by Liu et al. (2021)<sup>1</sup> and made readily available in Keras (TensorFlow) by Shkarupa<sup>2</sup>.

[1] Liu, Z., Y. Lin, Y. Cao, H. Hu, Y. Wei, Z. Zhang, S. Lin, and B. Guo. Swin transformer: Hierarchical vision transformer using shifted windows. in Proceedings of the IEEE/CVF International Conference on Computer Vision. 2021.

[2] Shkarupa, A. tfswin. 2022; Available from: https://github.com/shkarupa-alex/tfswin.

# Imports

In [None]:
! pip install tfswin

In [None]:
from google.colab import drive
import os
import numpy as np
import shutil
import pandas as pd
import random
import matplotlib.pyplot as plt
import math
import inspect
from collections import Counter
import pickle 
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import tensorflow as tf
from tensorflow import keras
from keras import backend as K
from tensorflow.keras import preprocessing, layers, models, callbacks
from tensorflow.keras.preprocessing import image 
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Sequential
from tfswin import SwinTransformerTiny224, preprocess_input  

In [None]:
# mount my Google Drive where datasets and models are stored
# (every dataset and output path used below lives under /content/gdrive)
drive.mount('/content/gdrive')

# Dataset

In [None]:
# Umbrella directory on the mounted Drive holding datasets and run outputs
drive_classification_path = '/content/gdrive/MyDrive/Classification_mirabilis/'
img_datasets_dir = os.path.join(drive_classification_path, 'img_datasets')

# Dataset of strains imaged at different growth temperatures
temperature_folder = 'temperature'
temperature_path = os.path.join(img_datasets_dir, temperature_folder)

# Strain/temperature folder to analyze (pre-split images) and its
# train/val/test split counterpart ('<strain>_split')
strain = 'cheW_37'
strain_path, strain_split = os.path.join(temperature_path, strain), f'{strain}_split'
strain_split_path = os.path.join(temperature_path, strain_split)

In [None]:
# Get per-class image counts (before the train-val-test split was done) – for
# creating a visualization plot.
# Only class *subdirectories* are counted (stray files such as .DS_Store would
# otherwise crash os.listdir below), and they are sorted so the order matches
# the alphanumeric class order that flow_from_directory assigns to indices
# (os.listdir order is arbitrary).
classes = sorted(entry for entry in os.listdir(strain_path)
                 if os.path.isdir(os.path.join(strain_path, entry)))
cls_img_counts = list()
for c in classes:
  class_subfolder = os.path.join(strain_path, c)
  num_imgs = len(os.listdir(class_subfolder))
  cls_img_counts.append(num_imgs)

  print(c)
  print(num_imgs)

class_num = len(classes)  # number of classes found
print(cls_img_counts)

In [None]:
# Use the image size SwinTransformerTiny224 was pretrained at: read the default
# value of its `pretrain_size` parameter directly instead of parsing the
# parameter's string representation (e.g. "pretrain_size=224").
_pretrain_size = inspect.signature(SwinTransformerTiny224).parameters['pretrain_size'].default
if _pretrain_size is inspect.Parameter.empty:
  raise ValueError('SwinTransformerTiny224 has no default value for pretrain_size')
IMG_SIZE = int(_pretrain_size)
print(IMG_SIZE)

In [None]:
# ImageDataGenerator
# https://www.tensorflow.org/api_docs/python/tf/keras/preprocessing/image/ImageDataGenerator

TRAIN_BATCH_SIZE = 4
VAL_BATCH_SIZE = 4
TEST_BATCH_SIZE = 1

train_dir = os.path.join(strain_split_path,'train')
val_dir = os.path.join(strain_split_path,'val')
test_dir_37 = os.path.join(strain_split_path,'test')
test_dir_36 = os.path.join(temperature_path,'cheW_36')
test_dir_34 = os.path.join(temperature_path,'cheW_34')

# No rescaling/augmentation here: tfswin's preprocess_input is applied inside
# the model itself.
datagen = ImageDataGenerator()


def _flow_from(directory, batch_size, shuffle, classes=None):
  """Return a directory iterator with the settings shared by every subset.

  class_mode must be 'categorical' (one-hot labels) because the models are
  compiled with CategoricalCrossentropy / CategoricalAccuracy. Bicubic resizing
  is used because SwinTransformerTiny224 is sensitive to the interpolation
  method. Passing ``classes`` pins the class-name -> index mapping.
  """
  return datagen.flow_from_directory(directory,
                                     classes=classes,
                                     class_mode='categorical',
                                     target_size=(IMG_SIZE, IMG_SIZE),
                                     interpolation='bicubic',
                                     batch_size=batch_size,
                                     shuffle=shuffle,
                                     seed=123)


train_generator = _flow_from(train_dir, TRAIN_BATCH_SIZE, shuffle=True)

# Pin every other subset to the training class order so that label index i
# means the same class in train, val and all three temperature test sets
# (the 36C/34C sets live in separate directories).
train_classes = sorted(train_generator.class_indices,
                       key=train_generator.class_indices.get)

val_generator = _flow_from(val_dir, VAL_BATCH_SIZE, shuffle=False, classes=train_classes)
test_generator_37 = _flow_from(test_dir_37, TEST_BATCH_SIZE, shuffle=False, classes=train_classes)
test_generator_36 = _flow_from(test_dir_36, TEST_BATCH_SIZE, shuffle=False, classes=train_classes)
test_generator_34 = _flow_from(test_dir_34, TEST_BATCH_SIZE, shuffle=False, classes=train_classes)

In [None]:
# Class-name -> label-index mapping inferred by the training generator; it
# defines the label order used everywhere downstream.
class_names = train_generator.class_indices
num_classes = len(class_names)
print(class_names)

In [None]:
# Sanity-check a single training batch. The image batch is expected to be
# (batch, IMG_SIZE, IMG_SIZE, channels); labels are one-hot
# (batch_size, num_classes) for 'categorical', or (batch_size,) if 'sparse'.
image_batch, labels_batch = next(iter(train_generator))
print(image_batch.shape)
print(labels_batch)
print(labels_batch.shape)

In [None]:
# Number of images per integer class index in each data subset
train_counter = Counter(train_generator.classes)
val_counter = Counter(val_generator.classes)
test_counter_37 = Counter(test_generator_37.classes)
test_counter_36 = Counter(test_generator_36.classes)
test_counter_34 = Counter(test_generator_34.classes)
print(train_counter, val_counter, test_counter_37, test_counter_36, test_counter_34, sep='\n')

In [None]:
# The training classes are imbalanced, so weight each class by
# (largest class count / its own count): the majority class gets 1.0 and
# rarer classes proportionally more.
max_val = float(max(train_counter.values()))
print(max_val)
class_weights = {cls_idx: max_val / count for cls_idx, count in train_counter.items()}
print(class_weights)

# Build Model(s) & load in weights

In [None]:
# Optimizer hyperparameters
LR = 1e-5  # initial learning rate
EPSILON = 1e-8  # Adam epsilon

# Exponential learning-rate decay. With staircase=False (the default, made
# explicit here) the decay is continuous, not a step drop every DECAY_STEPS:
#   lr(step) = LR * DECAY_RATE ** (step / DECAY_STEPS)
# where one step is one optimizer update (one batch). The rate therefore
# shrinks by a factor of DECAY_RATE over every DECAY_STEPS steps.
DECAY_STEPS = 13
DECAY_RATE = 0.96
# NOTE: ALPHA (minimum LR as a fraction of the initial LR, a CosineDecay
# parameter) is NOT used by ExponentialDecay; kept for reference only.
ALPHA = 0.0
LR_schedule = tf.keras.optimizers.schedules.ExponentialDecay(LR,
                                                             decay_steps=DECAY_STEPS,
                                                             decay_rate=DECAY_RATE,
                                                             staircase=False)

In [None]:
# Optional on-the-fly augmentation pipeline, built once and reused by every
# model that enables it (Keras' random preprocessing layers only transform
# their inputs in training mode):
#   - random horizontal flip
#   - random zoom-in between +2% and +10% (negative factors zoom in)
#   - random contrast change of up to 40%
#   - random rotation of up to +/-1% of a full turn (reflect padding)
#   - random horizontal shift of up to 10% of the width, no vertical shift
augs_on_the_fly = tf.keras.Sequential(
    layers=[
        tf.keras.layers.experimental.preprocessing.RandomFlip(
            mode='horizontal', input_shape=(IMG_SIZE, IMG_SIZE, 3)),
        tf.keras.layers.experimental.preprocessing.RandomZoom(
            height_factor=(-0.1, -0.02)),
        tf.keras.layers.experimental.preprocessing.RandomContrast(
            factor=0.4),
        tf.keras.layers.experimental.preprocessing.RandomRotation(
            factor=(-0.01, 0.01), fill_mode='reflect'),
        tf.keras.layers.experimental.preprocessing.RandomTranslation(
            height_factor=(0, 0), width_factor=(-0.1, 0.1), fill_mode='reflect'),
    ],
    name='on_fly_augs',
)

In [None]:
# Function for building model
# finetune='partial','full','none'

def build_swintransformer(num_classes, pretrained=False, finetune='full', data_augmentation=False):
  """Build and compile a SwinTransformerTiny224 image classifier.

  Graph: Input -> (optional shared `augs_on_the_fly`) -> tfswin
  `preprocess_input` -> Swin backbone (include_top=False) ->
  GlobalAveragePooling2D -> Dense(num_classes, softmax).

  Args:
    num_classes: number of units in the softmax classification head.
    pretrained: if True, initialize the backbone with tfswin's 'imagenet'
      weights; otherwise it is randomly initialized.
    finetune: 'full' trains the whole backbone, 'none' freezes it, and
      'partial' freezes the first third of the backbone's layers.
    data_augmentation: if True, route inputs through the module-level
      `augs_on_the_fly` pipeline.

  Returns:
    A compiled tf.keras.Model (legacy Adam with `LR_schedule`, categorical
    cross-entropy loss, categorical accuracy and AUC metrics).

  Raises:
    ValueError: if `finetune` is not 'full', 'partial' or 'none'.
  """
  if finetune not in ('full', 'partial', 'none'):
    raise ValueError(f"finetune must be 'full', 'partial' or 'none', got {finetune!r}")

  # define input shape
  inputs = tf.keras.Input(shape=(IMG_SIZE, IMG_SIZE, 3), name='input_new')

  # optional on-the-fly training augs
  x = augs_on_the_fly(inputs) if data_augmentation else inputs

  # preprocess layer (tfswin's own input normalization)
  preprocess_layer = tf.keras.Sequential([tf.keras.layers.Lambda(preprocess_input)],
                                         name='preprocess_new')

  # swin backbone
  swin_model = SwinTransformerTiny224(weights=('imagenet' if pretrained else None),
                                      include_top=False)
  if finetune == 'full':
    swin_model.trainable = True
  elif finetune == 'none':
    swin_model.trainable = False
  else:  # 'partial'
    # Let's take a look to see how many layers are in the base model
    print("Number of layers in the base swin model: ", len(swin_model.layers))
    # Fine-tune from this layer onwards
    fine_tune_at = len(swin_model.layers) // 3
    # Freeze all the layers before the `fine_tune_at` layer
    for layer in swin_model.layers[:fine_tune_at]:
      layer.trainable = False

  # classification head
  classification_head = tf.keras.Sequential([
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(num_classes, activation='softmax'),
        ], name='class_head_new')

  # put it all together
  x = preprocess_layer(x)
  # Do NOT hard-code training=True here: a fixed training flag is baked into
  # the functional graph, which would keep dropout/stochastic depth active
  # during evaluate()/predict(). Let Keras propagate the phase instead, and
  # keep a fully frozen backbone permanently in inference mode.
  if finetune == 'none':
    x = swin_model(x, training=False)
  else:
    x = swin_model(x)
  outputs = classification_head(x)
  model = tf.keras.Model(inputs, outputs)

  # compile model
  model.compile(optimizer=tf.keras.optimizers.legacy.Adam(learning_rate=LR_schedule, epsilon=EPSILON),
                loss=tf.keras.losses.CategoricalCrossentropy(),
                metrics=[tf.keras.metrics.CategoricalAccuracy(name="categorical_accuracy"),
                         tf.keras.metrics.AUC(name='AUC', multi_label=False),
                         ],
                )

  return model

In [None]:
# Rebuild the architecture of the best model from the all-strain run (6 output
# classes, one per strain) so its saved weights can be restored into it and
# its backbone weights reused for the new single-strain task.
former_model = build_swintransformer(
    6,
    pretrained=True,
    finetune='full',
    data_augmentation=True,
)

In [None]:
# Locate the saved weights of the best all-strain model and restore them
former_run_name = 'SwinTransformerTiny224_curateddataset_pretr_finetu_aug'
former_best_model_name = 'pretr_aug'
former_run_dir = os.path.join(drive_classification_path, former_run_name)
former_best_model_path = os.path.join(former_run_dir, 'saved_models', former_best_model_name)
former_model.load_weights(former_best_model_path)

In [None]:
# Print the former model's layer stack; the Swin backbone is the layer that
# feeds the classification head.
former_model.summary()

In [None]:
# Extract the P. mirabilis backbone weights from the former model.
# layers[-2] is the Swin backbone (the layer right before the classification
# head) whether or not the augmentation block is present, and it is the same
# index used when these weights are set on the new models; the previous
# hard-coded layers[3] only held for models built with augmentation.
former_swin_layer = former_model.layers[-2]
former_swin_weights = former_swin_layer.get_weights()
# get_weights() returns a list of arrays; report how many there are and the
# shape of the first one, rather than labelling them "the weights"/"the biases".
print(f"Copied {len(former_swin_weights)} weight arrays from layer '{former_swin_layer.name}'")
print(f"Shape of first weight array: {former_swin_weights[0].shape}")

In [None]:
# Create folders for this new run (run root plus subfolders for checkpoints,
# training histories and confusion-matrix figures).
run_name = 'SwinT_adaptLR_chew_TempRobust_ExpDecay_13DC_0pt96DR_pat10'
run_dir = os.path.join(drive_classification_path,run_name)

saved_models_dir = os.path.join(run_dir,'saved_models')
histories_dir = os.path.join(run_dir,'histories')
CMs_dir = os.path.join(run_dir,'confusion_matrices')

all_run_dirs = [run_dir, saved_models_dir, histories_dir, CMs_dir]

# makedirs(exist_ok=True) avoids the check-then-create race of isdir()+mkdir()
# and also creates any missing parent folders.
for run_sub_dir in all_run_dirs:
  os.makedirs(run_sub_dir, exist_ok=True)

In [None]:
# Candidate models for the single-strain, IPTG-binned task. All start from
# ImageNet weights; the 'PM' variants get their Swin backbone overwritten with
# the P. mirabilis (all-strain) weights afterwards, the 'IN' variants keep
# the ImageNet backbone.
#   FFT = full fine-tuning, PFT = partial fine-tuning, 'Aug' = on-the-fly augs.
# Insertion order: PM FFT, PM FFT Aug, PM PFT, PM PFT Aug, IN FFT, ... IN PFT Aug.
try_models = {
    f'{init} {ft_code}{" Aug" if use_aug else ""}': build_swintransformer(
        num_classes=num_classes,
        pretrained=True,
        finetune=ft_mode,
        data_augmentation=use_aug,
    )
    for init in ('PM', 'IN')
    for ft_code, ft_mode in (('FFT', 'full'), ('PFT', 'partial'))
    for use_aug in (False, True)
}

In [None]:
# Print the architecture of every candidate model
for model_key, model in try_models.items():
  print(f'\n\n Summary of {model_key} model \n\n')
  model.summary()

In [None]:
# Replace the ImageNet backbone of each 'PM' model with the P. mirabilis
# backbone weights; layers[-2] is the Swin backbone (just before the head).
for model_key, model in try_models.items():
  if 'PM' not in model_key:
    continue
  print(model_key)
  model.layers[-2].set_weights(former_swin_weights)

# Training & validation

In [None]:
# Early stopping: halt training once val_loss has not improved for 10 epochs
# ('min' is the mode Keras' 'auto' resolves to for a loss).
early_stopping = callbacks.EarlyStopping(patience=10, monitor="val_loss", mode="min")

In [None]:
# Unshuffled, batch-size-1 copies of the train and val generators so that
# predictions line up one-to-one with .filenames/.labels when building
# confusion matrices. Final test metrics still come only from the test
# generators; these are just for inspecting train/val behaviour.
reset_train_generator = datagen.flow_from_directory(
    train_dir,
    target_size=(IMG_SIZE, IMG_SIZE),
    batch_size=1,
    shuffle=False,
    seed=123,
    class_mode='categorical',
    interpolation='bicubic',
)

reset_val_generator = datagen.flow_from_directory(
    val_dir,
    target_size=(IMG_SIZE, IMG_SIZE),
    batch_size=1,
    shuffle=False,
    seed=123,
    class_mode='categorical',
    interpolation='bicubic',
)

In [None]:
# File names and true integer labels (in generator order) for each subset
reset_train_filenames, reset_train_labels = reset_train_generator.filenames, reset_train_generator.labels
reset_val_filenames, reset_val_labels = reset_val_generator.filenames, reset_val_generator.labels
test_filenames_37, test_labels_37 = test_generator_37.filenames, test_generator_37.labels
test_filenames_36, test_labels_36 = test_generator_36.filenames, test_generator_36.labels
test_filenames_34, test_labels_34 = test_generator_34.filenames, test_generator_34.labels

In [None]:
# function for generating predictions & CMs
def plot_CM(generator,best_model,model_key,labels,subset):
  """Predict on `generator`, plot the confusion matrix and save it as SVG.

  Args:
    generator: Keras directory iterator. It is reset first, so predictions
      line up with `labels` only if it was created with shuffle=False.
    best_model: trained Keras model with one softmax score per class.
    model_key: model label, used in the title and the output file name.
    labels: true integer class indices, in the generator's file order.
    subset: human-readable subset name, used in the title and file name.

  Returns:
    The (num_classes, num_classes) confusion-matrix array that was plotted.
  """
  generator.reset()
  preds = best_model.predict(generator)
  pred_labels = preds.argmax(axis=1)

  # Pin the label set: if a subset lacks a class in both truth and
  # predictions, confusion_matrix would otherwise return a smaller matrix
  # that no longer matches display_labels.
  class_ids = np.arange(len(class_names))
  # class_names maps name -> index; list the names ordered by their index.
  display_names = sorted(class_names, key=class_names.get)
  cm = confusion_matrix(labels, pred_labels, labels=class_ids, normalize=None)

  # Keep SVG text as editable text rather than paths.
  plt.rcParams['svg.fonttype'] = 'none'
  fig, ax = plt.subplots(figsize=(10, 10))
  try:
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=display_names)
    disp.plot(include_values=True, cmap=plt.cm.Blues, ax=ax, xticks_rotation='vertical', values_format=None)
    # "\n" (not os.linesep, which is "\r\n" on Windows) is a text line break.
    title = f"Confusion matrix of {model_key} model's\npredictions on cheW {subset} images"
    disp.ax_.set_title(title, fontweight='bold')
    extent = ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted())

    CM_path = os.path.join(CMs_dir, f"{model_key}_CM_{subset}.svg")
    fig.savefig(CM_path, format='svg', bbox_inches=extent.expanded(2.0, 2.0))
    plt.show()
  finally:
    # Called once per subset per model: close the figure so dozens of
    # figures do not accumulate in memory.
    plt.close(fig)

  return cm

In [None]:
EPOCHS = 50
all_histories = {}

# Held-out test sets, keyed by the row label used in each saved metrics table.
test_sets = {
    '37C': test_generator_37,
    '36C': test_generator_36,
    '34C': test_generator_34,
}

for model_key, model in try_models.items():
  # Checkpoint callback: keep only the lowest-val_loss model of this run
  # (native TF format, saved under saved_models_dir/<model_key>).
  model_path = os.path.join(saved_models_dir, model_key)
  save_model = callbacks.ModelCheckpoint(model_path, monitor="val_loss",
                                         verbose=1, save_best_only=True,
                                         mode='min')

  # Rewind every generator before training/evaluating this model.
  for generator in (train_generator, val_generator, *test_sets.values(),
                    reset_train_generator, reset_val_generator):
    generator.reset()

  # train & validate
  print(f'\n\n Training {model_key} \n\n')
  all_histories[model_key] = model.fit(train_generator,
                                       validation_data=val_generator,
                                       epochs=EPOCHS,
                                       steps_per_epoch=len(train_generator),
                                       validation_steps=len(val_generator),
                                       class_weight=class_weights,
                                       callbacks=[early_stopping, save_model],
                                       )

  # Save the per-epoch metric history; `with` closes the file even on error.
  model_history = all_histories[model_key].history
  history_path = os.path.join(histories_dir, f'{model_key}_history.pckl')
  with open(history_path, 'wb') as file_pi:
    pickle.dump(model_history, file_pi)

  # Reload the checkpointed best model (not the final-epoch weights).
  best_model = tf.keras.models.load_model(model_path)

  # Evaluate on each temperature's test set; return_dict=True keys the
  # results by metric name ('loss', 'categorical_accuracy', 'AUC').
  test_rows = {}
  for temp_label, test_gen in test_sets.items():
    results = best_model.evaluate(test_gen, return_dict=True)
    test_rows[temp_label] = [results['loss'],
                             results['categorical_accuracy'],
                             results['AUC']]
  df_test_metrics = pd.DataFrame.from_dict(test_rows, orient='index',
                                           columns=['Loss', 'Accuracy', 'AUC'])
  test_metrics_path = os.path.join(run_dir, f"{model_key}_test_metrics.pkl")
  df_test_metrics.to_pickle(test_metrics_path)

  # Confusion matrices on train, val, & each test set
  plot_CM(reset_train_generator, best_model, model_key, reset_train_labels, '37C train')
  plot_CM(reset_val_generator, best_model, model_key, reset_val_labels, '37C val')
  plot_CM(test_generator_37, best_model, model_key, test_labels_37, '37C test')
  plot_CM(test_generator_36, best_model, model_key, test_labels_36, '36C test')
  plot_CM(test_generator_34, best_model, model_key, test_labels_34, '34C test')

  # clear session
  tf.keras.backend.clear_session()



---



---

