<a href="https://colab.research.google.com/github/hedrergudene/HViT_classification/blob/main/MedMNIST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 0 - Requirements

In [1]:
!git clone https://benayas1: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@github.com/hedrergudene/HViT_classification.git
!(cd /content/HViT_classification/ && python setup.py bdist_wheel && pip install dist/hvit-0.0.1-py3-none-any.whl) >> /dev/null
!pip install -U tensorflow-addons >> /dev/null
!pip install wandb >> /dev/null

fatal: destination path 'HViT_classification' already exists and is not an empty directory.


In [12]:
import tensorflow as tf
import pandas as pd
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa
from typing import List, Dict
import wandb
# Import model
from hvit.tf.ViT_model import HViT, ViT
#from hvit.tf.train_medmnist import run_WB_experiment
from hvit.tf.info import INFO
from hvit.tf.evaluator import Evaluator
import hvit.tf.dataset_without_pytorch as mdn
import cv2
import numpy as np

import zipfile
from tqdm import tqdm
import os
import re

# Login into W&B
WB_ENTITY = 'ual'
WB_PROJECT = 'hvit_classifier'
WB_KEY = 'ab1f4c380e0a008223b6434a42907bacfd7b4e26'
#WB_KEY = '1bb44e6be47564584868ec55bac8cf468cf0e47f'  # antonio's

tf.config.list_physical_devices('GPU')

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

# 1 - Training loop function

In [27]:

def load_data(dataclass, split, task, size, n_classes, n_channels):
    dataset = dataclass(split=split, download=True)
    x = dataset.imgs
    if size is not None:
        x = np.stack([cv2.resize(img, (size,size), interpolation = cv2.INTER_AREA) for img in x])
    if n_channels == 1:
        x = np.expand_dims(x, 3)
    y = dataset.labels
    if task == 'multi-class':
        y = tf.keras.utils.to_categorical(y, n_classes)
    if task == 'binary-class':
        y = np.squeeze(y, axis=1)
    return x, y

def run_WB_experiment(WB_KEY:str,
                      WB_ENTITY:str,
                      WB_PROJECT:str,
                      WB_GROUP:str,
                      model:tf.keras.Model,
                      data_flag:str,
                      ImageDataGenerator_config:Dict,
                      flow_config:Dict,
                      epochs:int=10,
                      learning_rate:float=0.00005,
                      weight_decay:float=0.0001,
                      label_smoothing:float=.1,
                      es_patience:int=10,
                      verbose:int=1,
                      resize:int = None,
                      ):
    # Check for GPU:
    assert len(tf.config.list_physical_devices('GPU'))>0, f"No GPU available. Check system settings."

    # Download dataset
    info = INFO[data_flag]
    task = info['task']
    n_channels = info['n_channels']
    n_classes = len(info['label'])

    DataClass = getattr(mdn, info['python_class'])
    print(f'Dataset {data_flag} Task {task} n_channels {n_channels} n_classes {n_classes}')

    # load train Data
    x_train, y_train = load_data(DataClass, 'train', task, resize, n_classes, n_channels)
    x_train = x_train[:1000]
    y_train = y_train[:1000]
    print(f'X train {x_train.shape} | Y train {y_train.shape}')

    # load val Data
    x_val, y_val = load_data(DataClass, 'val', task, resize, n_classes, n_channels)
    print(f'X val {x_val.shape} | Y val {y_val.shape}')

    # load test Data
    x_test, y_test = load_data(DataClass, 'test', task, resize, n_classes, n_channels)
    print(f'X test {x_test.shape} | Y test {y_test.shape}')

    # Log in WB
    wandb.login(key=WB_KEY)

    # Generators
    train_datagen = tf.keras.preprocessing.image.ImageDataGenerator(**ImageDataGenerator_config['train'])
    val_datagen = tf.keras.preprocessing.image.ImageDataGenerator(**ImageDataGenerator_config['val'])
    test_datagen = tf.keras.preprocessing.image.ImageDataGenerator(**ImageDataGenerator_config['test'])
    train_generator = train_datagen.flow(x=x_train, 
                                         y=y_train,
                                         batch_size=flow_config['train']['batch_size'],
                                         shuffle=flow_config['train']['shuffle'],
                                         seed=flow_config['train']['seed'],
                                         )
    val_generator = val_datagen.flow(x=x_val,
                                     y=y_val,
                                     batch_size=flow_config['val']['batch_size'],
                                     shuffle=flow_config['val']['shuffle'],
                                     seed=flow_config['val']['seed'],
                                     )
    test_generator = test_datagen.flow(x=x_test,
                                       y=y_test,
                                       batch_size=flow_config['test']['batch_size'],
                                       shuffle=flow_config['test']['shuffle'],
                                       seed=flow_config['test']['seed'],
                                       )
    # Train & validation steps
    train_steps_per_epoch = len(train_generator)
    val_steps_per_epoch = len(val_generator)
    test_steps_per_epoch = len(test_generator)

    # Save initial weights
    #model.load_weights(os.path.join(os.getcwd(), 'model_weights.h5'))

    # Credentials
    wandb.init(project='_'.join([WB_PROJECT, data_flag]), entity=WB_ENTITY, group = WB_GROUP)
    
    # Model compile
    optimizer = tfa.optimizers.AdamW(
        learning_rate=learning_rate, weight_decay=weight_decay
    )

    if task == 'multi-class':
        loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True, label_smoothing = label_smoothing)
        metrics = [tf.keras.metrics.CategoricalAccuracy(name="accuracy"),
                   tf.keras.metrics.AUC(multi_label=True, num_labels=n_classes, from_logits=True, name="AUC"),
                   tfa.metrics.F1Score(num_classes=n_classes, average='macro', name = 'f1_score')
                   ]
    if task == 'binary-class':
        loss=tf.keras.losses.BinaryCrossentropy(from_logits=True, label_smoothing = label_smoothing)
        metrics = [tf.keras.metrics.BinaryAccuracy(name="accuracy"),
                   tf.keras.metrics.AUC(multi_label=False, from_logits=True, name="AUC")]

    model.compile(
        optimizer=optimizer,
        loss=loss,
        metrics=metrics,
    )

    # Callbacks
    reduceLR = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=2, min_lr=learning_rate//10, verbose=1)
    patience = tf.keras.callbacks.EarlyStopping(patience=es_patience)
    checkpoint = tf.keras.callbacks.ModelCheckpoint(os.path.join(os.getcwd(), 'model_best_weights.h5'), save_best_only = True, save_weights_only = True)
    wandb_callback = wandb.keras.WandbCallback(save_weights_only=True)

    # Model fit
    history = model.fit(
        train_generator,
        steps_per_epoch= train_steps_per_epoch,
        epochs = epochs,
        validation_data=val_generator,
        validation_steps = val_steps_per_epoch,
        callbacks=[reduceLR, patience, checkpoint, wandb_callback],
        verbose = verbose,
    )

    # Evaluation
    model.load_weights(os.path.join(os.getcwd(), 'model_best_weights.h5'))
    results = model.evaluate(test_generator, steps = test_steps_per_epoch, verbose = 0)
    print("Test metrics:",{k:v for k,v in zip(model.metrics_names, results)})
    wandb.log({("test_"+k):v for k,v in zip(model.metrics_names, results)})
    wandb.log({"n_parameters":np.round(model.count_params()/1000000, 1)})

    #y_pred = model.predict(test_generator, verbose = 0)
    #evaluator = Evaluator(data_flag, 'test')
    #results = evaluator.evaluate(y_pred)

    #print(f"Test metrics: AUC {results.AUC}, ACC {results.ACC}")
    #wandb.log({"test_ACC":results.ACC, "test_AUC":results.AUC})

    # Clear memory
    tf.keras.backend.clear_session()
    wandb.finish()

# 2 - Global Configuration

In [24]:
# Config
# ['pathmnist','dermamnist'
datasets = ['pathmnist','dermamnist','octmnist','pneumoniamnist','breastmnist','bloodmnist','tissuemnist', 'organamnist', 'organcmnist', 'organsmnist']
batch_size = 32
epochs = 1
es_patience = 10
seed = 123
verbose=1
learning_rate = 0.00005
weight_decay = 0.0001
label_smoothing = .1
img_size = 32

ImageDataGenerator_config = {
    'train':{
        "rescale":1./255,
        "shear_range":.1,
        "rotation_range":.2,
        "zoom_range":.1,
        "horizontal_flip" : True,
        },
    'val':{
        "rescale":1./255,
        },
    'test':{
        "rescale":1./255,
        }
}
flow_config = {
    'train':{
        "batch_size":batch_size,
        "shuffle":True,
        "seed":seed,
        },
    'val':{
        "batch_size":batch_size,
        "shuffle":False,
        "seed":seed,
        },
    'test':{
        "batch_size":batch_size,
        "shuffle":False,
        "seed":seed,
        }
}

# 2 - Experiments

## HViT


In [25]:
WB_GROUP = 'HViT'

for data_flag in datasets:

    info = INFO[data_flag]

    hvit_params = { 'img_size':img_size,
                    'patch_size':[2,4,8],
                    'num_channels': info['n_channels'],
                    'num_heads': 8,
                    'transformer_layers':[4,4,4],
                    'hidden_unit_factor':2,
                    'mlp_head_units': [256, 64],
                    'num_classes':len(info['label']),
                    'drop_attn':0.2,
                    'drop_proj':0.2,
                    'drop_linear':0.4,
                    'projection_dim' : info['n_channels']*36,
                    'resampling_type':"conv",
                    'original_attn':True,
                    }

    # Start running
    with tf.device('/device:GPU:0'):
      # Instance model
      inputs = tf.keras.layers.Input((img_size, img_size, 3))
      outputs = HViT(**hvit_params)(inputs)
      model = tf.keras.Model(inputs, outputs)
      # Run experiment
      run_WB_experiment(WB_KEY,
                        WB_ENTITY,
                        WB_PROJECT,
                        WB_GROUP,
                        model,
                        data_flag,
                        ImageDataGenerator_config,
                        flow_config,
                        epochs=epochs,
                        learning_rate=learning_rate,
                        weight_decay=weight_decay,
                        label_smoothing = label_smoothing,
                        verbose=verbose,
                        resize=img_size,
                        es_patience=es_patience,
                        )

KeyboardInterrupt: ignored

## ViT

In [21]:
import torch
torch.cuda.empty_cache()
tf.keras.backend.clear_session()

WB_GROUP = 'ViT'

for data_flag in datasets:

    info = INFO[data_flag]

    vit_params = {'img_size':img_size,
                  'patch_size':4,
                  'num_channels': info['n_channels'],
                  'num_heads': 8,
                  'transformer_layers':16,
                  'hidden_unit_factor':4,
                  'mlp_head_units': [256, 64],
                  'num_classes':len(info['label']),
                  'drop_attn':0.2,
                  'drop_proj':0.2,
                  'drop_linear':0.4,
                  'projection_dim' : info['n_channels']*64
                  }


    # Start running
    with tf.device('/device:GPU:0'):
      # Instance model
      inputs = tf.keras.layers.Input((img_size, img_size, 3))
      outputs = ViT(**vit_params)(inputs)
      model = tf.keras.Model(inputs, outputs)
      # Run experiment
      run_WB_experiment(WB_KEY,
                        WB_ENTITY,
                        WB_PROJECT,
                        WB_GROUP,
                        model,
                        data_flag,
                        ImageDataGenerator_config,
                        flow_config,
                        epochs=epochs,
                        learning_rate=learning_rate,
                        weight_decay=weight_decay,
                        label_smoothing = label_smoothing,
                        verbose=verbose,
                        resize=img_size,
                        es_patience=es_patience,
                        )

Dataset pathmnist Task multi-class n_channels 3 n_classes 9
Using downloaded and verified file: /root/.medmnist/pathmnist.npz
X train (1000, 32, 32, 3) | Y train (1000, 9)
Using downloaded and verified file: /root/.medmnist/pathmnist.npz
X val (10004, 32, 32, 3) | Y val (10004, 9)
Using downloaded and verified file: /root/.medmnist/pathmnist.npz




X test (7180, 32, 32, 3) | Y test (7180, 9)


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
AUC,▁
accuracy,▁
epoch,▁
f1_score,▁
loss,▁
lr,▁
val_AUC,▁
val_accuracy,▁
val_f1_score,▁
val_loss,▁

0,1
AUC,0.51679
accuracy,0.11
best_epoch,0.0
best_val_loss,2.27069
epoch,0.0
f1_score,0.1028
loss,3.04738
lr,5e-05
val_AUC,0.6304
val_accuracy,0.14054


Test metrics: {'loss': 2.119122266769409, 'accuracy': 0.20083566009998322, 'AUC': 0.7470336556434631, 'f1_score': 0.07146500051021576}


VBox(children=(Label(value=' 91.04MB of 91.04MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.…

0,1
AUC,▁
accuracy,▁
epoch,▁
f1_score,▁
loss,▁
lr,▁
n_parameters,▁
test_AUC,▁
test_accuracy,▁
test_f1_score,▁

0,1
AUC,0.51812
accuracy,0.135
best_epoch,0.0
best_val_loss,2.24864
epoch,0.0
f1_score,0.12551
loss,3.01552
lr,5e-05
n_parameters,23.8
test_AUC,0.74703


Dataset dermamnist Task multi-class n_channels 3 n_classes 7
Using downloaded and verified file: /root/.medmnist/dermamnist.npz


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


X train (1000, 32, 32, 3) | Y train (1000, 7)
Using downloaded and verified file: /root/.medmnist/dermamnist.npz
X val (1003, 32, 32, 3) | Y val (1003, 7)
Using downloaded and verified file: /root/.medmnist/dermamnist.npz
X test (2005, 32, 32, 3) | Y test (2005, 7)


Test metrics: {'loss': 1.4649347066879272, 'accuracy': 0.6688279509544373, 'AUC': 0.6064751744270325, 'f1_score': 0.11450772732496262}


VBox(children=(Label(value=' 91.04MB of 91.04MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.…

0,1
AUC,▁
accuracy,▁
epoch,▁
f1_score,▁
loss,▁
lr,▁
n_parameters,▁
test_AUC,▁
test_accuracy,▁
test_f1_score,▁

0,1
AUC,0.50204
accuracy,0.491
best_epoch,0.0
best_val_loss,1.46767
epoch,0.0
f1_score,0.15705
loss,1.98239
lr,5e-05
n_parameters,23.8
test_AUC,0.60648


Dataset octmnist Task multi-class n_channels 1 n_classes 4
Using downloaded and verified file: /root/.medmnist/octmnist.npz
X train (1000, 32, 32, 1) | Y train (1000, 4)
Using downloaded and verified file: /root/.medmnist/octmnist.npz


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


X val (10832, 32, 32, 1) | Y val (10832, 4)
Using downloaded and verified file: /root/.medmnist/octmnist.npz
X test (1000, 32, 32, 1) | Y test (1000, 4)


InvalidArgumentError: ignored

## EfficientNet

In [None]:
import torch
torch.cuda.empty_cache()
tf.keras.backend.clear_session()

WB_GROUP = 'EfficientNetB0'
mlp_head_units = [256,64]
drop_linear = .2

for data_flag in datasets:

    info = INFO[data_flag]

    vit_params = {'img_size':img_size,
                  'patch_size':4,
                  'num_channels': info['n_channels'],
                  'num_heads': 8,
                  'transformer_layers':16,
                  'hidden_unit_factor':4,
                  'mlp_head_units': [256, 64],
                  'num_classes':len(info['label']),
                  'drop_attn':0.2,
                  'drop_proj':0.2,
                  'drop_linear':0.4,
                  'projection_dim' : 3*64
                  }


    # Start running
    with tf.device('/device:GPU:0'):

      # Instance model
      inputs = tf.keras.layers.Input((img_size, img_size, 3))
      base_model = tf.keras.applications.EfficientNetB0(weights=None, include_top=False)(inputs)
      x = tf.keras.layers.GlobalAveragePooling2D()(base_model)
      for i in mlp_head_units:
          x = tf.keras.layers.Dense(i)(x)
          x = tf.keras.layers.Dropout(drop_linear)(x)
      logits = tf.keras.layers.Dense(len(info['label']))(x)
      model = tf.keras.Model(inputs, logits)

      # Run experiment
      run_WB_experiment(WB_KEY,
                        WB_ENTITY,
                        WB_PROJECT,
                        WB_GROUP,
                        model,
                        data_flag,
                        ImageDataGenerator_config,
                        flow_config,
                        epochs=epochs,
                        learning_rate=learning_rate,
                        weight_decay=weight_decay,
                        label_smoothing = label_smoothing,
                        verbose=verbose,
                        resize=img_size,
                        es_patience=es_patience,
                        )

Dataset pathmnist Task multi-class n_channels 3 n_classes 9
Using downloaded and verified file: /root/.medmnist/pathmnist.npz
X train (1000, 32, 32, 3) | Y train (1000, 9)
Using downloaded and verified file: /root/.medmnist/pathmnist.npz
X val (10004, 32, 32, 3) | Y val (10004, 9)
Using downloaded and verified file: /root/.medmnist/pathmnist.npz




X test (7180, 32, 32, 3) | Y test (7180, 9)


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

Test metrics: {'loss': 2.209102153778076, 'accuracy': 0.05863509699702263, 'AUC': 0.5, 'f1_score': 0.01230832189321518}


VBox(children=(Label(value=' 17.03MB of 17.03MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.…

0,1
AUC,▁
accuracy,▁
epoch,▁
f1_score,▁
loss,▁
lr,▁
n_parameters,▁
test_AUC,▁
test_accuracy,▁
test_f1_score,▁

0,1
AUC,0.51577
accuracy,0.122
best_epoch,0.0
best_val_loss,2.19682
epoch,0.0
f1_score,0.11453
loss,2.90662
lr,5e-05
n_parameters,4.4
test_AUC,0.5


Dataset dermamnist Task multi-class n_channels 3 n_classes 7
Using downloaded and verified file: /root/.medmnist/dermamnist.npz


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


X train (1000, 32, 32, 3) | Y train (1000, 7)
Using downloaded and verified file: /root/.medmnist/dermamnist.npz
X val (1003, 32, 32, 3) | Y val (1003, 7)
Using downloaded and verified file: /root/.medmnist/dermamnist.npz
X test (2005, 32, 32, 3) | Y test (2005, 7)


Test metrics: {'loss': 1.8423728942871094, 'accuracy': 0.6688279509544373, 'AUC': 0.5, 'f1_score': 0.11450772732496262}


VBox(children=(Label(value=' 17.03MB of 17.03MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.…

0,1
AUC,▁
accuracy,▁
epoch,▁
f1_score,▁
loss,▁
lr,▁
n_parameters,▁
test_AUC,▁
test_accuracy,▁
test_f1_score,▁

0,1
AUC,0.50589
accuracy,0.45
best_epoch,0.0
best_val_loss,1.84235
epoch,0.0
f1_score,0.14104
loss,2.04461
lr,5e-05
n_parameters,4.4
test_AUC,0.5


Dataset octmnist Task multi-class n_channels 1 n_classes 4
Using downloaded and verified file: /root/.medmnist/octmnist.npz
X train (1000, 32, 32, 1) | Y train (1000, 4)
Using downloaded and verified file: /root/.medmnist/octmnist.npz


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


X val (10832, 32, 32, 1) | Y val (10832, 4)
Using downloaded and verified file: /root/.medmnist/octmnist.npz
X test (1000, 32, 32, 1) | Y test (1000, 4)


Test metrics: {'loss': 1.3933874368667603, 'accuracy': 0.25, 'AUC': 0.5, 'f1_score': 0.10000000149011612}


VBox(children=(Label(value=' 17.03MB of 17.03MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.…

0,1
AUC,▁
accuracy,▁
epoch,▁
f1_score,▁
loss,▁
lr,▁
n_parameters,▁
test_AUC,▁
test_accuracy,▁
test_f1_score,▁

0,1
AUC,0.51927
accuracy,0.373
best_epoch,0.0
best_val_loss,1.32711
epoch,0.0
f1_score,0.25293
loss,1.83472
lr,5e-05
n_parameters,4.4
test_AUC,0.5


## ResNet 150v2

In [None]:
import torch
torch.cuda.empty_cache()
tf.keras.backend.clear_session()

WB_GROUP = "ResNet 152 v2"
mlp_head_units = [256,64]
drop_linear = .2

for data_flag in datasets:

    info = INFO[data_flag]

    vit_params = {'img_size':img_size,
                  'patch_size':4,
                  'num_channels': info['n_channels'],
                  'num_heads': 8,
                  'transformer_layers':16,
                  'hidden_unit_factor':4,
                  'mlp_head_units': [256, 64],
                  'num_classes':len(info['label']),
                  'drop_attn':0.2,
                  'drop_proj':0.2,
                  'drop_linear':0.4,
                  'projection_dim' : 3*64
                  }


    # Start running
    with tf.device('/device:GPU:0'):

      # Instance model
      inputs = tf.keras.layers.Input((img_size, img_size, 3))
      base_model = tf.keras.applications.resnet_v2.ResNet152V2(weights=None, include_top=False)(inputs)
      x = tf.keras.layers.GlobalAveragePooling2D()(base_model)
      for i in mlp_head_units:
          x = tf.keras.layers.Dense(i)(x)
          x = tf.keras.layers.Dropout(drop_linear)(x)
      logits = tf.keras.layers.Dense(len(info['label']))(x)
      model = tf.keras.Model(inputs, logits)

      # Run experiment
      run_WB_experiment(WB_KEY,
                        WB_ENTITY,
                        WB_PROJECT,
                        WB_GROUP,
                        model,
                        data_flag,
                        ImageDataGenerator_config,
                        flow_config,
                        epochs=epochs,
                        learning_rate=learning_rate,
                        weight_decay=weight_decay,
                        label_smoothing = label_smoothing,
                        verbose=verbose,
                        resize=img_size,
                        es_patience=es_patience,
                        )

## Inception ResNet v2

In [None]:
import torch
torch.cuda.empty_cache()
tf.keras.backend.clear_session()

WB_GROUP = "Inception ResNet v2"
mlp_head_units = [256,64]
drop_linear = .2

for data_flag in datasets:

    info = INFO[data_flag]

    vit_params = {'img_size':img_size,
                  'patch_size':4,
                  'num_channels': info['n_channels'],
                  'num_heads': 8,
                  'transformer_layers':16,
                  'hidden_unit_factor':4,
                  'mlp_head_units': [256, 64],
                  'num_classes':len(info['label']),
                  'drop_attn':0.2,
                  'drop_proj':0.2,
                  'drop_linear':0.4,
                  'projection_dim' : 3*64
                  }


    # Start running
    with tf.device('/device:GPU:0'):

      # Instance model
      inputs = tf.keras.layers.Input((img_size, img_size, 3))
      base_model = tf.keras.applications.InceptionResNetV2(weights=None, include_top=False)(aug)
      x = tf.keras.layers.GlobalAveragePooling2D()(base_model)
      for i in mlp_head_units:
          x = tf.keras.layers.Dense(i)(x)
          x = tf.keras.layers.Dropout(drop_linear)(x)
      logits = tf.keras.layers.Dense(len(info['label']))(x)
      model = tf.keras.Model(inputs, logits)

      # Run experiment
      run_WB_experiment(WB_KEY,
                        WB_ENTITY,
                        WB_PROJECT,
                        WB_GROUP,
                        model,
                        data_flag,
                        ImageDataGenerator_config,
                        flow_config,
                        epochs=epochs,
                        learning_rate=learning_rate,
                        weight_decay=weight_decay,
                        label_smoothing = label_smoothing,
                        verbose=verbose,
                        resize=img_size,
                        es_patience=es_patience,
                        )

## Mobile Net v2

In [None]:
import torch
torch.cuda.empty_cache()
tf.keras.backend.clear_session()

WB_GROUP = "MobileNet v2"
mlp_head_units = [256,64]
drop_linear = .2

for data_flag in datasets:

    info = INFO[data_flag]

    vit_params = {'img_size':img_size,
                  'patch_size':4,
                  'num_channels': info['n_channels'],
                  'num_heads': 8,
                  'transformer_layers':16,
                  'hidden_unit_factor':4,
                  'mlp_head_units': [256, 64],
                  'num_classes':len(info['label']),
                  'drop_attn':0.2,
                  'drop_proj':0.2,
                  'drop_linear':0.4,
                  'projection_dim' : 3*64
                  }


    # Start running
    with tf.device('/device:GPU:0'):

      # Instance model
      inputs = tf.keras.layers.Input((img_size, img_size, 3))
      base_model = tf.keras.applications.mobilenet_v2.MobileNetV2(weights=None, include_top=False)(aug)
      x = tf.keras.layers.GlobalAveragePooling2D()(base_model)
      for i in mlp_head_units:
          x = tf.keras.layers.Dense(i)(x)
          x = tf.keras.layers.Dropout(drop_linear)(x)
      logits = tf.keras.layers.Dense(len(info['label']))(x)
      model = tf.keras.Model(inputs, logits)

      # Run experiment
      run_WB_experiment(WB_KEY,
                        WB_ENTITY,
                        WB_PROJECT,
                        WB_GROUP,
                        model,
                        data_flag,
                        ImageDataGenerator_config,
                        flow_config,
                        epochs=epochs,
                        learning_rate=learning_rate,
                        weight_decay=weight_decay,
                        label_smoothing = label_smoothing,
                        verbose=verbose,
                        resize=img_size,
                        es_patience=es_patience,
                        )