<a href="https://colab.research.google.com/github/hedrergudene/HViT_classification/blob/main/MedMNIST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 0 - Requirements

In [1]:
!git clone https://benayas1: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@github.com/hedrergudene/HViT_classification.git
!(cd /content/HViT_classification/ && python setup.py bdist_wheel && pip install dist/hvit-0.0.1-py3-none-any.whl) >> /dev/null
!pip install -U tensorflow-addons >> /dev/null
!pip install wandb >> /dev/null

fatal: destination path 'HViT_classification' already exists and is not an empty directory.


In [1]:
import tensorflow as tf
import pandas as pd
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa
from typing import List, Dict
import wandb
# Import model
from hvit.tf.ViT_model import HViT, ViT
#from hvit.tf.train_medmnist import run_WB_experiment
from hvit.tf.info import INFO
from hvit.tf.evaluator import Evaluator
import hvit.tf.dataset_without_pytorch as mdn
import cv2
import numpy as np

import zipfile
from tqdm import tqdm
import os
import re

# Login into W&B
# The API key is read from the environment (or prompted for) so it is never stored in the notebook.
# NOTE(review): the keys that used to be hardcoded here have been exposed and should be rotated.
from getpass import getpass

WB_ENTITY = 'ual'
WB_PROJECT = 'hvit_classifier'
WB_KEY = os.environ.get('WANDB_API_KEY') or getpass('W&B API key: ')

# Display the GPUs TensorFlow can see.
tf.config.list_physical_devices('GPU')

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

# 1 - Training loop function

In [13]:

def load_data(dataclass, split, task, size, n_classes, n_channels):
    """Build the image and label arrays for one split of a dataset.

    Args:
        dataclass: Callable invoked as ``dataclass(split=split, download=True)``;
            the returned object must expose ``imgs`` and ``labels``.
        split: Split name forwarded to ``dataclass``.
        task: Task name. ``'multi-class'`` one-hot encodes the labels and
            ``'binary-class'`` squeezes axis 1 of the labels; any other value
            leaves the labels untouched.
        size: Target side length in pixels; every image is resized to
            ``(size, size)``. ``None`` disables resizing.
        n_classes: Number of classes used for the one-hot encoding.
        n_channels: When equal to 1, the image array is replicated three times
            along a new trailing axis.

    Returns:
        Tuple ``(images, labels)``.
    """
    dataset = dataclass(split=split, download=True)

    images = dataset.imgs
    if size is not None:
        target_shape = (size, size)
        resized = []
        for image in images:
            resized.append(cv2.resize(image, target_shape, interpolation=cv2.INTER_AREA))
        images = np.stack(resized)
    if n_channels == 1:
        # Replicate the array three times on a new last axis.
        images = np.stack([images] * 3, axis=-1)

    labels = dataset.labels
    if task == 'multi-class':
        labels = tf.keras.utils.to_categorical(labels, n_classes)
    elif task == 'binary-class':
        labels = np.squeeze(labels, axis=1)

    return images, labels

def run_WB_experiment(WB_KEY:str,
                      WB_ENTITY:str,
                      WB_PROJECT:str,
                      WB_GROUP:str,
                      model:tf.keras.Model,
                      data_flag:str,
                      ImageDataGenerator_config:Dict,
                      flow_config:Dict,
                      epochs:int=10,
                      learning_rate:float=0.00005,
                      weight_decay:float=0.0001,
                      label_smoothing:float=.1,
                      es_patience:int=10,
                      verbose:int=1,
                      resize:int = None,
                      max_train_samples:int = 1000,
                      ):
    """Train ``model`` on one dataset, evaluate it on the test split and log to W&B.

    The function loads the train/val/test splits with ``load_data``, wraps them
    in ``ImageDataGenerator`` flows, compiles ``model`` with AdamW and a
    task-dependent loss, fits it, reloads the checkpointed weights, evaluates on
    the test split and logs the test metrics plus the parameter count (in
    millions, rounded to one decimal) to Weights & Biases.

    Args:
        WB_KEY: W&B API key passed to ``wandb.login``.
        WB_ENTITY: W&B entity for the run.
        WB_PROJECT: Project prefix; the run is created in
            ``f"{WB_PROJECT}_{data_flag}"``.
        WB_GROUP: W&B group name for the run.
        model: Keras model to train. It is compiled and trained in place; it
            is expected to output logits (the losses use ``from_logits=True``).
        data_flag: Key into ``INFO`` selecting the dataset.
        ImageDataGenerator_config: Dict with ``'train'``, ``'val'`` and
            ``'test'`` entries, each holding ``ImageDataGenerator`` kwargs.
        flow_config: Dict with ``'train'``, ``'val'`` and ``'test'`` entries,
            each providing ``batch_size``, ``shuffle`` and ``seed``.
        epochs: Maximum number of training epochs.
        learning_rate: Initial AdamW learning rate.
        weight_decay: AdamW weight decay.
        label_smoothing: Label smoothing applied by the loss.
        es_patience: Patience of the early-stopping callback.
        verbose: Verbosity passed to ``model.fit``.
        resize: Side length images are resized to; ``None`` keeps them as-is.
        max_train_samples: Only the first ``max_train_samples`` training
            samples are used; ``None`` uses the whole training split.

    Raises:
        AssertionError: If TensorFlow sees no GPU.
        ValueError: If the dataset's task is neither ``'multi-class'`` nor
            ``'binary-class'``.
    """
    # Check for GPU:
    assert len(tf.config.list_physical_devices('GPU'))>0, f"No GPU available. Check system settings."

    # Dataset description
    info = INFO[data_flag]
    task = info['task']
    n_channels = info['n_channels']
    n_classes = len(info['label'])
    # Two-label problems are modelled with a single output unit.
    n_classes = 1 if n_classes == 2 else n_classes

    # Fail early (before any download or W&B call) for tasks that have no
    # loss/metrics defined below.
    if task not in ('multi-class', 'binary-class'):
        raise ValueError(
            f"Unsupported task '{task}' for dataset '{data_flag}'; "
            "expected 'multi-class' or 'binary-class'."
        )

    DataClass = getattr(mdn, info['python_class'])
    print(f'Dataset {data_flag} Task {task} n_channels {n_channels} n_classes {n_classes}')

    # load train Data
    x_train, y_train = load_data(DataClass, 'train', task, resize, n_classes, n_channels)
    # Optional subsampling of the training split (a None bound keeps everything).
    x_train = x_train[:max_train_samples]
    y_train = y_train[:max_train_samples]
    print(f'X train {x_train.shape} | Y train {y_train.shape}')

    # load val Data
    x_val, y_val = load_data(DataClass, 'val', task, resize, n_classes, n_channels)
    print(f'X val {x_val.shape} | Y val {y_val.shape}')

    # load test Data
    x_test, y_test = load_data(DataClass, 'test', task, resize, n_classes, n_channels)
    print(f'X test {x_test.shape} | Y test {y_test.shape}')

    # Log in WB
    wandb.login(key=WB_KEY)

    # Generators
    train_datagen = tf.keras.preprocessing.image.ImageDataGenerator(**ImageDataGenerator_config['train'])
    val_datagen = tf.keras.preprocessing.image.ImageDataGenerator(**ImageDataGenerator_config['val'])
    test_datagen = tf.keras.preprocessing.image.ImageDataGenerator(**ImageDataGenerator_config['test'])
    train_generator = train_datagen.flow(x=x_train,
                                         y=y_train,
                                         batch_size=flow_config['train']['batch_size'],
                                         shuffle=flow_config['train']['shuffle'],
                                         seed=flow_config['train']['seed'],
                                         )
    val_generator = val_datagen.flow(x=x_val,
                                     y=y_val,
                                     batch_size=flow_config['val']['batch_size'],
                                     shuffle=flow_config['val']['shuffle'],
                                     seed=flow_config['val']['seed'],
                                     )
    test_generator = test_datagen.flow(x=x_test,
                                       y=y_test,
                                       batch_size=flow_config['test']['batch_size'],
                                       shuffle=flow_config['test']['shuffle'],
                                       seed=flow_config['test']['seed'],
                                       )
    # Train, validation & test steps (number of batches per pass)
    train_steps_per_epoch = len(train_generator)
    val_steps_per_epoch = len(val_generator)
    test_steps_per_epoch = len(test_generator)

    # Credentials
    wandb.init(project='_'.join([WB_PROJECT, data_flag]), entity=WB_ENTITY, group = WB_GROUP)

    # Model compile
    optimizer = tfa.optimizers.AdamW(
        learning_rate=learning_rate, weight_decay=weight_decay
    )

    if task == 'multi-class':
        loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True, label_smoothing = label_smoothing)
        metrics = [tf.keras.metrics.CategoricalAccuracy(name="accuracy"),
                   tf.keras.metrics.AUC(multi_label=True, num_labels=n_classes, from_logits=True, name="AUC"),
                   tfa.metrics.F1Score(num_classes=n_classes, average='macro', name = 'f1_score')
                   ]
    else:  # 'binary-class' (validated above)
        loss=tf.keras.losses.BinaryCrossentropy(from_logits=True, label_smoothing = label_smoothing)
        metrics = [tf.keras.metrics.BinaryAccuracy(name="accuracy"),
                   tf.keras.metrics.AUC(multi_label=False, from_logits=True, name="AUC")]

    model.compile(
        optimizer=optimizer,
        loss=loss,
        metrics=metrics,
    )

    # Callbacks
    best_weights_path = os.path.join(os.getcwd(), 'model_best_weights.h5')
    # True division: floor division (`//`) of a small float learning rate
    # yields 0.0, which would silently disable the lower bound.
    reduceLR = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=2, min_lr=learning_rate/10, verbose=1)
    patience = tf.keras.callbacks.EarlyStopping(patience=es_patience)
    checkpoint = tf.keras.callbacks.ModelCheckpoint(best_weights_path, save_best_only = True, save_weights_only = True)
    wandb_callback = wandb.keras.WandbCallback(save_weights_only=True)

    # Model fit
    model.fit(
        train_generator,
        steps_per_epoch= train_steps_per_epoch,
        epochs = epochs,
        validation_data=val_generator,
        validation_steps = val_steps_per_epoch,
        callbacks=[reduceLR, patience, checkpoint, wandb_callback],
        verbose = verbose,
    )

    # Evaluation on the test split, using the checkpointed weights
    model.load_weights(best_weights_path)
    results = model.evaluate(test_generator, steps = test_steps_per_epoch, verbose = 0)
    print("Test metrics:",{k:v for k,v in zip(model.metrics_names, results)})
    wandb.log({("test_"+k):v for k,v in zip(model.metrics_names, results)})
    # Parameter count in millions, one decimal
    wandb.log({"n_parameters":np.round(model.count_params()/1000000, 1)})

    # Clear memory
    tf.keras.backend.clear_session()
    wandb.finish()

# 2 - Global Configuration

In [14]:
# Config
# Datasets to benchmark ('pneumoniamnist' and 'breastmnist' are left out of this list).
datasets = [
    'octmnist', 'tissuemnist', 'pathmnist', 'dermamnist',
    'bloodmnist', 'organamnist', 'organcmnist', 'organsmnist',
]

# Hyper-parameters passed to the experiment cells below.
img_size = 32
batch_size = 32
epochs = 1
es_patience = 10
learning_rate = 0.00005
weight_decay = 0.0001
label_smoothing = .1
seed = 123
verbose = 1

# Keyword arguments for ImageDataGenerator, one entry per split.
# Only the training split is augmented; every split is rescaled by 1/255.
# NOTE(review): Keras documents rotation_range in degrees, so .2 is a very
# small rotation — confirm this is intended.
ImageDataGenerator_config = dict(
    train=dict(
        rescale=1./255,
        shear_range=.1,
        rotation_range=.2,
        zoom_range=.1,
        horizontal_flip=True,
    ),
    val=dict(rescale=1./255),
    test=dict(rescale=1./255),
)

# Keyword arguments for ImageDataGenerator.flow, one entry per split.
# Only the training split is shuffled.
flow_config = {
    split: dict(batch_size=batch_size, shuffle=(split == 'train'), seed=seed)
    for split in ('train', 'val', 'test')
}

# 3 - Experiments

## HViT


In [None]:
WB_GROUP = 'HViT'

# Train and evaluate one freshly built HViT model per dataset.
for data_flag in datasets:

    info = INFO[data_flag]
    n_classes = len(info['label'])
    if n_classes == 2:
        # Two-label datasets get a single output unit.
        n_classes = 1

    hvit_params = dict(
        img_size=img_size,
        patch_size=[2,4,8],
        num_channels=3,
        num_heads=8,
        transformer_layers=[4,4,4],
        hidden_unit_factor=2,
        mlp_head_units=[256, 64],
        num_classes=n_classes,
        drop_attn=0.2,
        drop_proj=0.2,
        drop_linear=0.4,
        projection_dim=48,
        resampling_type="conv",
        original_attn=True,
    )

    # Start running
    with tf.device('/device:GPU:0'):
        # Instance model
        inputs = tf.keras.layers.Input((img_size, img_size, 3))
        outputs = HViT(**hvit_params)(inputs)
        model = tf.keras.Model(inputs, outputs)
        # Run experiment
        run_WB_experiment(
            WB_KEY=WB_KEY,
            WB_ENTITY=WB_ENTITY,
            WB_PROJECT=WB_PROJECT,
            WB_GROUP=WB_GROUP,
            model=model,
            data_flag=data_flag,
            ImageDataGenerator_config=ImageDataGenerator_config,
            flow_config=flow_config,
            epochs=epochs,
            learning_rate=learning_rate,
            weight_decay=weight_decay,
            label_smoothing=label_smoothing,
            es_patience=es_patience,
            verbose=verbose,
            resize=img_size,
        )

Dataset octmnist Task multi-class n_channels 1 n_classes 4
Using downloaded and verified file: /root/.medmnist/octmnist.npz
X train (1000, 32, 32, 3) | Y train (1000, 4)
Using downloaded and verified file: /root/.medmnist/octmnist.npz




X val (10832, 32, 32, 3) | Y val (10832, 4)
Using downloaded and verified file: /root/.medmnist/octmnist.npz
X test (1000, 32, 32, 3) | Y test (1000, 4)


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

Test metrics: {'loss': 1.5628355741500854, 'accuracy': 0.24899999797344208, 'AUC': 0.52265465259552, 'f1_score': 0.09967974573373795}


VBox(children=(Label(value=' 4.54MB of 4.54MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
AUC,▁
accuracy,▁
epoch,▁
f1_score,▁
loss,▁
lr,▁
n_parameters,▁
test_AUC,▁
test_accuracy,▁
test_f1_score,▁

0,1
AUC,0.47946
accuracy,0.398
best_epoch,0.0
best_val_loss,1.20413
epoch,0.0
f1_score,0.25213
loss,1.33541
lr,5e-05
n_parameters,1.1
test_AUC,0.52265


Dataset tissuemnist Task multi-class n_channels 1 n_classes 8
Using downloaded and verified file: /root/.medmnist/tissuemnist.npz
X train (1000, 32, 32, 3) | Y train (1000, 8)
Using downloaded and verified file: /root/.medmnist/tissuemnist.npz
X val (23640, 32, 32, 3) | Y val (23640, 8)
Using downloaded and verified file: /root/.medmnist/tissuemnist.npz


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


X test (47280, 32, 32, 3) | Y test (47280, 8)


Test metrics: {'loss': 1.8674685955047607, 'accuracy': 0.3207275867462158, 'AUC': 0.5080062747001648, 'f1_score': 0.06073259189724922}


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
AUC,▁
accuracy,▁
epoch,▁
f1_score,▁
loss,▁
lr,▁
n_parameters,▁
test_AUC,▁
test_accuracy,▁
test_f1_score,▁

0,1
AUC,0.48979
accuracy,0.248
best_epoch,0.0
best_val_loss,1.86771
epoch,0.0
f1_score,0.11445
loss,2.03915
lr,5e-05
n_parameters,1.1
test_AUC,0.50801


Dataset pathmnist Task multi-class n_channels 3 n_classes 9
Using downloaded and verified file: /root/.medmnist/pathmnist.npz
X train (1000, 32, 32, 3) | Y train (1000, 9)
Using downloaded and verified file: /root/.medmnist/pathmnist.npz
X val (10004, 32, 32, 3) | Y val (10004, 9)
Using downloaded and verified file: /root/.medmnist/pathmnist.npz


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


X test (7180, 32, 32, 3) | Y test (7180, 9)


Test metrics: {'loss': 2.1997392177581787, 'accuracy': 0.18189415335655212, 'AUC': 0.5587500929832458, 'f1_score': 0.03665540739893913}


VBox(children=(Label(value=' 4.54MB of 4.54MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
AUC,▁
accuracy,▁
epoch,▁
f1_score,▁
loss,▁
lr,▁
n_parameters,▁
test_AUC,▁
test_accuracy,▁
test_f1_score,▁

0,1
AUC,0.51952
accuracy,0.126
best_epoch,0.0
best_val_loss,2.19516
epoch,0.0
f1_score,0.10898
loss,2.25684
lr,5e-05
n_parameters,1.1
test_AUC,0.55875


Dataset dermamnist Task multi-class n_channels 3 n_classes 7
Using downloaded and verified file: /root/.medmnist/dermamnist.npz


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


X train (1000, 32, 32, 3) | Y train (1000, 7)
Using downloaded and verified file: /root/.medmnist/dermamnist.npz
X val (1003, 32, 32, 3) | Y val (1003, 7)
Using downloaded and verified file: /root/.medmnist/dermamnist.npz
X test (2005, 32, 32, 3) | Y test (2005, 7)


Test metrics: {'loss': 1.3240504264831543, 'accuracy': 0.6688279509544373, 'AUC': 0.4881301522254944, 'f1_score': 0.11450772732496262}


VBox(children=(Label(value=' 4.54MB of 4.54MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
AUC,▁
accuracy,▁
epoch,▁
f1_score,▁
loss,▁
lr,▁
n_parameters,▁
test_AUC,▁
test_accuracy,▁
test_f1_score,▁

0,1
AUC,0.49552
accuracy,0.432
best_epoch,0.0
best_val_loss,1.32348
epoch,0.0
f1_score,0.15868
loss,1.67726
lr,5e-05
n_parameters,1.1
test_AUC,0.48813


Dataset bloodmnist Task multi-class n_channels 3 n_classes 8
Using downloaded and verified file: /root/.medmnist/bloodmnist.npz
X train (1000, 32, 32, 3) | Y train (1000, 8)
Using downloaded and verified file: /root/.medmnist/bloodmnist.npz
X val (1712, 32, 32, 3) | Y val (1712, 8)
Using downloaded and verified file: /root/.medmnist/bloodmnist.npz


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


X test (3421, 32, 32, 3) | Y test (3421, 8)


Test metrics: {'loss': 2.0318729877471924, 'accuracy': 0.18094123899936676, 'AUC': 0.5551242828369141, 'f1_score': 0.05428792163729668}


VBox(children=(Label(value=' 4.54MB of 4.54MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
AUC,▁
accuracy,▁
epoch,▁
f1_score,▁
loss,▁
lr,▁
n_parameters,▁
test_AUC,▁
test_accuracy,▁
test_f1_score,▁

0,1
AUC,0.48386
accuracy,0.131
best_epoch,0.0
best_val_loss,2.03271
epoch,0.0
f1_score,0.10714
loss,2.16837
lr,5e-05
n_parameters,1.1
test_AUC,0.55512


Dataset organamnist Task multi-class n_channels 1 n_classes 11
Using downloaded and verified file: /root/.medmnist/organamnist.npz
X train (1000, 32, 32, 3) | Y train (1000, 11)
Using downloaded and verified file: /root/.medmnist/organamnist.npz
X val (6491, 32, 32, 3) | Y val (6491, 11)
Using downloaded and verified file: /root/.medmnist/organamnist.npz


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


X test (17778, 32, 32, 3) | Y test (17778, 11)


Test metrics: {'loss': 2.2576239109039307, 'accuracy': 0.26932162046432495, 'AUC': 0.6804309487342834, 'f1_score': 0.08941315859556198}


VBox(children=(Label(value=' 4.55MB of 4.55MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
AUC,▁
accuracy,▁
epoch,▁
f1_score,▁
loss,▁
lr,▁
n_parameters,▁
test_AUC,▁
test_accuracy,▁
test_f1_score,▁

0,1
AUC,0.51658
accuracy,0.127
best_epoch,0.0
best_val_loss,2.21389
epoch,0.0
f1_score,0.09551
loss,2.41291
lr,5e-05
n_parameters,1.1
test_AUC,0.68043


Dataset organcmnist Task multi-class n_channels 1 n_classes 11
Using downloaded and verified file: /root/.medmnist/organcmnist.npz
X train (1000, 32, 32, 3) | Y train (1000, 11)
Using downloaded and verified file: /root/.medmnist/organcmnist.npz
X val (2392, 32, 32, 3) | Y val (2392, 11)
Using downloaded and verified file: /root/.medmnist/organcmnist.npz


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


X test (8268, 32, 32, 3) | Y test (8268, 11)


Test metrics: {'loss': 2.297333002090454, 'accuracy': 0.22206096351146698, 'AUC': 0.6356149911880493, 'f1_score': 0.033044759184122086}


VBox(children=(Label(value=' 4.55MB of 4.55MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
AUC,▁
accuracy,▁
epoch,▁
f1_score,▁
loss,▁
lr,▁
n_parameters,▁
test_AUC,▁
test_accuracy,▁
test_f1_score,▁

0,1
AUC,0.5131
accuracy,0.177
best_epoch,0.0
best_val_loss,2.36582
epoch,0.0
f1_score,0.08926
loss,2.36764
lr,5e-05
n_parameters,1.1
test_AUC,0.63561


Dataset organsmnist Task multi-class n_channels 1 n_classes 11
Using downloaded and verified file: /root/.medmnist/organsmnist.npz
X train (1000, 32, 32, 3) | Y train (1000, 11)
Using downloaded and verified file: /root/.medmnist/organsmnist.npz
X val (2452, 32, 32, 3) | Y val (2452, 11)
Using downloaded and verified file: /root/.medmnist/organsmnist.npz


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


X test (8829, 32, 32, 3) | Y test (8829, 11)




## ViT

In [None]:
import torch
torch.cuda.empty_cache()
tf.keras.backend.clear_session()

WB_GROUP = 'ViT'

# Train and evaluate one freshly built ViT model per dataset.
for data_flag in datasets:

    info = INFO[data_flag]
    n_classes = len(info['label'])
    if n_classes == 2:
        # Two-label datasets get a single output unit.
        n_classes = 1

    vit_params = dict(
        img_size=img_size,
        patch_size=4,
        num_channels=3,
        num_heads=8,
        transformer_layers=16,
        hidden_unit_factor=4,
        mlp_head_units=[256, 64],
        num_classes=n_classes,
        drop_attn=0.2,
        drop_proj=0.2,
        drop_linear=0.4,
        projection_dim=3*64,
    )

    # Start running
    with tf.device('/device:GPU:0'):
        # Instance model
        inputs = tf.keras.layers.Input((img_size, img_size, 3))
        outputs = ViT(**vit_params)(inputs)
        model = tf.keras.Model(inputs, outputs)
        # Run experiment
        run_WB_experiment(
            WB_KEY=WB_KEY,
            WB_ENTITY=WB_ENTITY,
            WB_PROJECT=WB_PROJECT,
            WB_GROUP=WB_GROUP,
            model=model,
            data_flag=data_flag,
            ImageDataGenerator_config=ImageDataGenerator_config,
            flow_config=flow_config,
            epochs=epochs,
            learning_rate=learning_rate,
            weight_decay=weight_decay,
            label_smoothing=label_smoothing,
            es_patience=es_patience,
            verbose=verbose,
            resize=img_size,
        )

## EfficientNet

In [None]:
# NOTE(review): torch is used in this cell only for empty_cache(); the models
# here are TensorFlow ones, so this call presumably has no effect on them — confirm.
import torch
torch.cuda.empty_cache()
tf.keras.backend.clear_session()

WB_GROUP = 'EfficientNetB0'
mlp_head_units = [256,64]  # widths of the dense layers placed before the logits
drop_linear = .2           # dropout rate applied after each of those dense layers

# Train and evaluate one randomly initialised EfficientNetB0 baseline per dataset.
for data_flag in datasets:

    info = INFO[data_flag]
    n_classes = len(info['label'])
    # Two-label datasets get a single output unit.
    n_classes = 1 if n_classes == 2 else n_classes

    # Start running
    with tf.device('/device:GPU:0'):

      # Instance model: backbone -> global average pooling -> dense head -> logits
      inputs = tf.keras.layers.Input((img_size, img_size, 3))
      base_model = tf.keras.applications.EfficientNetB0(weights=None, include_top=False)(inputs)
      x = tf.keras.layers.GlobalAveragePooling2D()(base_model)
      for units in mlp_head_units:
          x = tf.keras.layers.Dense(units)(x)
          x = tf.keras.layers.Dropout(drop_linear)(x)
      logits = tf.keras.layers.Dense(n_classes)(x)
      model = tf.keras.Model(inputs, logits)

      # Run experiment
      run_WB_experiment(WB_KEY,
                        WB_ENTITY,
                        WB_PROJECT,
                        WB_GROUP,
                        model,
                        data_flag,
                        ImageDataGenerator_config,
                        flow_config,
                        epochs=epochs,
                        learning_rate=learning_rate,
                        weight_decay=weight_decay,
                        label_smoothing = label_smoothing,
                        verbose=verbose,
                        resize=img_size,
                        es_patience=es_patience,
                        )

## ResNet 152 v2

In [None]:
# NOTE(review): torch is used in this cell only for empty_cache(); the models
# here are TensorFlow ones, so this call presumably has no effect on them — confirm.
import torch
torch.cuda.empty_cache()
tf.keras.backend.clear_session()

WB_GROUP = "ResNet 152 v2"
mlp_head_units = [256,64]  # widths of the dense layers placed before the logits
drop_linear = .2           # dropout rate applied after each of those dense layers

# Train and evaluate one randomly initialised ResNet152V2 baseline per dataset.
for data_flag in datasets:

    info = INFO[data_flag]
    n_classes = len(info['label'])
    # Two-label datasets get a single output unit.
    n_classes = 1 if n_classes == 2 else n_classes

    # Start running
    with tf.device('/device:GPU:0'):

      # Instance model: backbone -> global average pooling -> dense head -> logits
      inputs = tf.keras.layers.Input((img_size, img_size, 3))
      base_model = tf.keras.applications.resnet_v2.ResNet152V2(weights=None, include_top=False)(inputs)
      x = tf.keras.layers.GlobalAveragePooling2D()(base_model)
      for units in mlp_head_units:
          x = tf.keras.layers.Dense(units)(x)
          x = tf.keras.layers.Dropout(drop_linear)(x)
      logits = tf.keras.layers.Dense(n_classes)(x)
      model = tf.keras.Model(inputs, logits)

      # Run experiment
      run_WB_experiment(WB_KEY,
                        WB_ENTITY,
                        WB_PROJECT,
                        WB_GROUP,
                        model,
                        data_flag,
                        ImageDataGenerator_config,
                        flow_config,
                        epochs=epochs,
                        learning_rate=learning_rate,
                        weight_decay=weight_decay,
                        label_smoothing = label_smoothing,
                        verbose=verbose,
                        resize=img_size,
                        es_patience=es_patience,
                        )

## Inception ResNet v2

In [None]:
# NOTE(review): torch is used in this cell only for empty_cache(); the models
# here are TensorFlow ones, so this call presumably has no effect on them — confirm.
import torch
torch.cuda.empty_cache()
tf.keras.backend.clear_session()

WB_GROUP = "Inception ResNet v2"
mlp_head_units = [256,64]  # widths of the dense layers placed before the logits
drop_linear = .2           # dropout rate applied after each of those dense layers

# Train and evaluate one randomly initialised InceptionResNetV2 baseline per dataset.
for data_flag in datasets:

    info = INFO[data_flag]
    n_classes = len(info['label'])
    # Two-label datasets get a single output unit.
    n_classes = 1 if n_classes == 2 else n_classes

    # Start running
    with tf.device('/device:GPU:0'):

      # Instance model: backbone -> global average pooling -> dense head -> logits
      inputs = tf.keras.layers.Input((img_size, img_size, 3))
      base_model = tf.keras.applications.InceptionResNetV2(weights=None, include_top=False)(inputs)
      x = tf.keras.layers.GlobalAveragePooling2D()(base_model)
      for units in mlp_head_units:
          x = tf.keras.layers.Dense(units)(x)
          x = tf.keras.layers.Dropout(drop_linear)(x)
      logits = tf.keras.layers.Dense(n_classes)(x)
      model = tf.keras.Model(inputs, logits)

      # Run experiment
      run_WB_experiment(WB_KEY,
                        WB_ENTITY,
                        WB_PROJECT,
                        WB_GROUP,
                        model,
                        data_flag,
                        ImageDataGenerator_config,
                        flow_config,
                        epochs=epochs,
                        learning_rate=learning_rate,
                        weight_decay=weight_decay,
                        label_smoothing = label_smoothing,
                        verbose=verbose,
                        resize=img_size,
                        es_patience=es_patience,
                        )

## MobileNet v2

In [None]:
# NOTE(review): torch is used in this cell only for empty_cache(); the models
# here are TensorFlow ones, so this call presumably has no effect on them — confirm.
import torch
torch.cuda.empty_cache()
tf.keras.backend.clear_session()

WB_GROUP = "MobileNet v2"
mlp_head_units = [256,64]  # widths of the dense layers placed before the logits
drop_linear = .2           # dropout rate applied after each of those dense layers

# Train and evaluate one randomly initialised MobileNetV2 baseline per dataset.
for data_flag in datasets:

    info = INFO[data_flag]
    n_classes = len(info['label'])
    # Two-label datasets get a single output unit.
    n_classes = 1 if n_classes == 2 else n_classes

    # Start running
    with tf.device('/device:GPU:0'):

      # Instance model: backbone -> global average pooling -> dense head -> logits
      inputs = tf.keras.layers.Input((img_size, img_size, 3))
      base_model = tf.keras.applications.mobilenet_v2.MobileNetV2(weights=None, include_top=False)(inputs)
      x = tf.keras.layers.GlobalAveragePooling2D()(base_model)
      for units in mlp_head_units:
          x = tf.keras.layers.Dense(units)(x)
          x = tf.keras.layers.Dropout(drop_linear)(x)
      logits = tf.keras.layers.Dense(n_classes)(x)
      model = tf.keras.Model(inputs, logits)

      # Run experiment
      run_WB_experiment(WB_KEY,
                        WB_ENTITY,
                        WB_PROJECT,
                        WB_GROUP,
                        model,
                        data_flag,
                        ImageDataGenerator_config,
                        flow_config,
                        epochs=epochs,
                        learning_rate=learning_rate,
                        weight_decay=weight_decay,
                        label_smoothing = label_smoothing,
                        verbose=verbose,
                        resize=img_size,
                        es_patience=es_patience,
                        )