<a href="https://colab.research.google.com/github/hedrergudene/HViT_classification/blob/main/MedMNIST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 0 - Requirements

In [1]:
from google.colab import drive
drive.mount('/content/drive')
!mkdir macula && unzip /content/drive/MyDrive/archive.zip -d /content/macula >> /dev/null

Mounted at /content/drive


In [2]:
!git clone https://benayas1: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@github.com/hedrergudene/HViT_classification.git
!(cd /content/HViT_classification/ && python setup.py bdist_wheel && pip install dist/hvit-0.0.1-py3-none-any.whl) >> /dev/null
!pip install -U tensorflow-addons >> /dev/null
!pip install wandb >> /dev/null

!git clone https://github.com/MonashAI/HVT

Cloning into 'HViT_classification'...
remote: Enumerating objects: 867, done.[K
remote: Counting objects: 100% (867/867), done.[K
remote: Compressing objects: 100% (675/675), done.[K
remote: Total 867 (delta 450), reused 403 (delta 140), pack-reused 0[K
Receiving objects: 100% (867/867), 1.01 MiB | 6.34 MiB/s, done.
Resolving deltas: 100% (450/450), done.
Cloning into 'HVT'...
remote: Enumerating objects: 35, done.[K
remote: Counting objects: 100% (35/35), done.[K
remote: Compressing objects: 100% (27/27), done.[K
remote: Total 35 (delta 6), reused 28 (delta 5), pack-reused 0[K
Unpacking objects: 100% (35/35), done.


In [1]:
import tensorflow as tf
import pandas as pd
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa
from typing import List, Dict
import wandb
# Import model
from hvit.tf.ViT_model import HViT, ViT
#from hvit.tf.train_medmnist import run_WB_experiment
from hvit.tf.info import INFO
from hvit.tf.evaluator import Evaluator
import hvit.tf.dataset_without_pytorch as mdn
import cv2
import numpy as np

import zipfile
from tqdm import tqdm
import os
import re

# Login into W&B
WB_ENTITY = 'ual'
WB_PROJECT = 'hvit_benchmark'
WB_KEY = 'ab1f4c380e0a008223b6434a42907bacfd7b4e26'
#WB_KEY = '1bb44e6be47564584868ec55bac8cf468cf0e47f'  # antonio's

tf.config.list_physical_devices('GPU')

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

# 1 - Training loop function

In [2]:

def load_data(dataclass, split, task, size, n_classes, n_channels):
    dataset = dataclass(split=split, download=True)
    x = dataset.imgs
    if size is not None:
        x = np.stack([cv2.resize(img, (size,size), interpolation = cv2.INTER_AREA) for img in x])
    if n_channels == 1:
        #x = np.expand_dims(x, 3)
        x = np.stack([x,x,x], axis=-1)
    y = dataset.labels
    if task == 'multi-class':
        y = tf.keras.utils.to_categorical(y, n_classes)
    if task == 'binary-class':
        y = np.squeeze(y, axis=1)
    return x, y

def run_WB_experiment(WB_KEY:str,
                      WB_ENTITY:str,
                      WB_PROJECT:str,
                      WB_GROUP:str,
                      model:tf.keras.Model,
                      data_flag:str,
                      ImageDataGenerator_config:Dict,
                      flow_config:Dict,
                      epochs:int=10,
                      learning_rate:float=0.00005,
                      weight_decay:float=0.0001,
                      label_smoothing:float=.1,
                      es_patience:int=10,
                      verbose:int=1,
                      resize:int = None,
                      ):
    # Check for GPU:
    assert len(tf.config.list_physical_devices('GPU'))>0, f"No GPU available. Check system settings."

    monitor = 'val_AUC'
    mode = 'max'

    # Generators
    train_datagen = tf.keras.preprocessing.image.ImageDataGenerator(**ImageDataGenerator_config['train'])
    val_datagen = tf.keras.preprocessing.image.ImageDataGenerator(**ImageDataGenerator_config['val'])
    test_datagen = tf.keras.preprocessing.image.ImageDataGenerator(**ImageDataGenerator_config['test'])

    if data_flag == 'macula':
        task = 'multi-class'
        n_classes = 4
        monitor = 'val_loss'
        mode = 'min'
        train_generator = train_datagen.flow_from_directory('/content/macula/OCT2017 /train',
                                                            target_size=(resize, resize),
                                                            color_mode='rgb',
                                                            class_mode='categorical',
                                                            batch_size=flow_config['train']['batch_size'],
                                                            shuffle=flow_config['train']['shuffle'],
                                                            seed=flow_config['train']['seed'],
                                                            )
        val_generator = val_datagen.flow_from_directory('/content/macula/OCT2017 /val',
                                                        target_size=(resize, resize),
                                                        color_mode='rgb',
                                                        class_mode='categorical',
                                                        batch_size=flow_config['val']['batch_size'],
                                                        shuffle=flow_config['val']['shuffle'],
                                                        seed=flow_config['val']['seed'],
                                                        )
        test_generator = test_datagen.flow_from_directory('/content/macula/OCT2017 /test',
                                                          target_size=(resize, resize),
                                                          color_mode='rgb',
                                                          class_mode='categorical',
                                                          batch_size=flow_config['test']['batch_size'],
                                                          shuffle=flow_config['test']['shuffle'],
                                                          seed=flow_config['test']['seed'],
                                                          )
    else:
        if data_flag == 'cifar100':
            (x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()
            x_val, y_val = x_test, y_test

        else:
            # Download dataset
            info = INFO[data_flag]
            task = info['task']
            n_channels = info['n_channels']
            n_classes = len(info['label'])
            n_classes = 1 if n_classes == 2 else n_classes

            DataClass = getattr(mdn, info['python_class'])
            print(f'Dataset {data_flag} Task {task} n_channels {n_channels} n_classes {n_classes}')

            # load train Data
            x_train, y_train = load_data(DataClass, 'train', task, resize, n_classes, n_channels)

            # load val Data
            x_val, y_val = load_data(DataClass, 'val', task, resize, n_classes, n_channels)

            # load test Data
            x_test, y_test = load_data(DataClass, 'test', task, resize, n_classes, n_channels)

            print(f'X train {x_train.shape} | Y train {y_train.shape}')
            print(f'X val {x_val.shape} | Y val {y_val.shape}')
            print(f'X test {x_test.shape} | Y test {y_test.shape}')
          
            train_generator = train_datagen.flow(x=x_train, 
                                                y=y_train,
                                                batch_size=flow_config['train']['batch_size'],
                                                shuffle=flow_config['train']['shuffle'],
                                                seed=flow_config['train']['seed'],
                                                )
            val_generator = val_datagen.flow(x=x_val,
                                            y=y_val,
                                            batch_size=flow_config['val']['batch_size'],
                                            shuffle=flow_config['val']['shuffle'],
                                            seed=flow_config['val']['seed'],
                                            )
            test_generator = test_datagen.flow(x=x_test,
                                              y=y_test,
                                              batch_size=flow_config['test']['batch_size'],
                                              shuffle=flow_config['test']['shuffle'],
                                              seed=flow_config['test']['seed'],
                                              )
    # Log in WB
    wandb.login(key=WB_KEY)

    # Train & validation steps
    train_steps_per_epoch = len(train_generator)
    val_steps_per_epoch = len(val_generator)
    test_steps_per_epoch = len(test_generator)

    # Save initial weights
    #model.load_weights(os.path.join(os.getcwd(), 'model_weights.h5'))

    # Credentials
    wandb.init(project='_'.join([WB_PROJECT, data_flag]), entity=WB_ENTITY, group = WB_GROUP)
    
    # Model compile
    optimizer = tfa.optimizers.AdamW(
        learning_rate=learning_rate, weight_decay=weight_decay
    )

    if task == 'multi-class':
        loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True, label_smoothing = label_smoothing)
        metrics = [tf.keras.metrics.CategoricalAccuracy(name="accuracy"),
                   tf.keras.metrics.AUC(multi_label=True, num_labels=n_classes, from_logits=True, name="AUC"),
                   tfa.metrics.F1Score(num_classes=n_classes, average='macro', name = 'f1_score')
                   ]
    if task == 'binary-class':
        loss=tf.keras.losses.BinaryCrossentropy(from_logits=True, label_smoothing = label_smoothing)
        metrics = [tf.keras.metrics.BinaryAccuracy(name="accuracy"),
                   tf.keras.metrics.AUC(multi_label=False, from_logits=True, name="AUC")]

    model.compile(
        optimizer=optimizer,
        loss=loss,
        metrics=metrics,
    )

    # Callbacks
    reduceLR = tf.keras.callbacks.ReduceLROnPlateau(monitor=monitor, mode=mode, factor=0.2, patience=int(es_patience/2), min_lr=learning_rate//100, verbose=1)
    patience = tf.keras.callbacks.EarlyStopping(monitor=monitor, mode=mode, patience=es_patience)
    checkpoint = tf.keras.callbacks.ModelCheckpoint(os.path.join(os.getcwd(), 'model_best_weights.h5'), monitor=monitor, mode=mode, save_best_only = True, save_weights_only = True)
    wandb_callback = wandb.keras.WandbCallback(save_weights_only=True)

    # Model fit
    history = model.fit(
        train_generator,
        steps_per_epoch= train_steps_per_epoch,
        epochs = epochs,
        validation_data=val_generator,
        validation_steps = val_steps_per_epoch,
        callbacks=[reduceLR, patience, checkpoint, wandb_callback],
        verbose = verbose,
    )

    # Evaluation
    model.load_weights(os.path.join(os.getcwd(), 'model_best_weights.h5'))
    results = model.evaluate(test_generator, steps = test_steps_per_epoch, verbose = 0)
    print("Test metrics:",{k:v for k,v in zip(model.metrics_names, results)})
    wandb.log({("test_"+k):v for k,v in zip(model.metrics_names, results)})
    wandb.log({"n_parameters":np.round(model.count_params()/1000000, 1)})

    #y_pred = model.predict(test_generator, verbose = 0)
    #evaluator = Evaluator(data_flag, 'test')
    #results = evaluator.evaluate(y_pred)

    #print(f"Test metrics: AUC {results.AUC}, ACC {results.ACC}")
    #wandb.log({"test_ACC":results.ACC, "test_AUC":results.AUC})

    # Clear memory
    tf.keras.backend.clear_session()
    wandb.finish()

# 2 - Global Configuration

In [9]:
# Config
# 'pneumoniamnist','breastmnist'
datasets = ['octmnist','tissuemnist','pathmnist','dermamnist','bloodmnist', 'organamnist', 'organcmnist', 'organsmnist']
#datasets = ['bloodmnist', 'organamnist', 'organcmnist', 'organsmnist']
#datasets = ['macula']

batch_size = 64
epochs = 100
es_patience = 7
seed = 2785
verbose=1
learning_rate = 0.0001
weight_decay = 0.0001
label_smoothing = .1
img_size = 32

ImageDataGenerator_config = {
    'train':{
        "rescale":1./255,
        "shear_range":.1,
        "rotation_range":.2,
        "zoom_range":.1,
        "horizontal_flip" : True,
        },
    'val':{
        "rescale":1./255,
        },
    'test':{
        "rescale":1./255,
        }
}
flow_config = {
    'train':{
        "batch_size":batch_size,
        "shuffle":True,
        "seed":seed,
        },
    'val':{
        "batch_size":batch_size,
        "shuffle":False,
        "seed":seed,
        },
    'test':{
        "batch_size":batch_size,
        "shuffle":False,
        "seed":seed,
        }
}

# 3 - Experiments

## HViT


In [10]:
import torch
torch.cuda.empty_cache()
tf.keras.backend.clear_session()
WB_GROUP = 'HViT'

for data_flag in datasets:

    hvit_params = { 'img_size':img_size,
                    'patch_size':[2,4,8],
                    'num_channels': 3,
                    'num_heads': 8,
                    'transformer_layers':[4,4,4],
                    'hidden_unit_factor':2,
                    'mlp_head_units': [256, 64],
                    'num_classes':n_classes,
                    'drop_attn':0.2,
                    'drop_proj':0.2,
                    'drop_linear':0.4,
                    'projection_dim' : 48,
                    'resampling_type':"conv",
                    'original_attn':True,
                    }

    if data_flag in INFO:
        info = INFO[data_flag]
        n_classes = len(info['label'])
        n_classes = 1 if n_classes == 2 else n_classes
    else:
        if data_flag == 'cifar100':
            n_classes = 100
        else:
            n_classes = 4
            hvit_params = { 'img_size':128,
                    'patch_size':[8,16,32],
                    'num_channels': 3,
                    'num_heads': 8,
                    'transformer_layers':[4,4,4],
                    'hidden_unit_factor':2,
                    'mlp_head_units': [256, 64],
                    'num_classes':n_classes,
                    'drop_attn':0.2,
                    'drop_proj':0.2,
                    'drop_linear':0.4,
                    'projection_dim' : 768,
                    'resampling_type':"conv",
                    'original_attn':True,
                    }

    # Start running
    with tf.device('/device:GPU:0'):
      # Instance model
      inputs = tf.keras.layers.Input((img_size, img_size, 3))
      outputs = HViT(**hvit_params)(inputs)
      model = tf.keras.Model(inputs, outputs)
      # Run experiment
      run_WB_experiment(WB_KEY,
                        WB_ENTITY,
                        WB_PROJECT,
                        WB_GROUP,
                        model,
                        data_flag,
                        ImageDataGenerator_config,
                        flow_config,
                        epochs=epochs,
                        learning_rate=learning_rate,
                        weight_decay=weight_decay,
                        label_smoothing = label_smoothing,
                        verbose=verbose,
                        resize=img_size,
                        es_patience=es_patience,
                        )

Dataset organamnist Task multi-class n_channels 1 n_classes 11
Using downloaded and verified file: /root/.medmnist/organamnist.npz
Using downloaded and verified file: /root/.medmnist/organamnist.npz
Using downloaded and verified file: /root/.medmnist/organamnist.npz
X train (34581, 32, 32, 3) | Y train (34581, 11)
X val (6491, 32, 32, 3) | Y val (6491, 11)
X test (17778, 32, 32, 3) | Y test (17778, 11)




VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

Epoch 1/100
 37/541 [=>............................] - ETA: 1:17 - loss: 2.3209 - accuracy: 0.1964 - AUC: 0.5748 - f1_score: 0.1451

KeyboardInterrupt: ignored

## ViT

In [None]:
import torch
torch.cuda.empty_cache()
tf.keras.backend.clear_session()

WB_GROUP = 'ViT-small'

for data_flag in datasets:

    vit_params = {'img_size':img_size,
                  'patch_size':4,
                  'num_channels': 3,
                  'num_heads': 8,
                  'transformer_layers':16,
                  'hidden_unit_factor':4,
                  'mlp_head_units': [256, 64],
                  'num_classes':n_classes,
                  'drop_attn':0.2,
                  'drop_proj':0.2,
                  'drop_linear':0.4,
                  'projection_dim' : 3*16
                  }

    if data_flag in INFO:
        info = INFO[data_flag]
        n_classes = len(info['label'])
        n_classes = 1 if n_classes == 2 else n_classes
    else:
        if data_flag == 'cifar100':
            n_classes = 100
        else:
            n_classes = 4
            vit_params = {'img_size':img_size,  # 128
                          'patch_size':16,
                          'num_channels': 3,
                          'num_heads': 8,
                          'transformer_layers':12,
                          'hidden_unit_factor':2,
                          'mlp_head_units': [256, 64],
                          'num_classes':n_classes,
                          'drop_attn':0.2,
                          'drop_proj':0.2,
                          'drop_linear':0.4,
                          'projection_dim' : 768,
                          'resampling_type':"conv",
                          'original_attn':True,
                          }

    # Start running
    with tf.device('/device:GPU:0'):
        # Instance model
        inputs = tf.keras.layers.Input((img_size, img_size, 3))
        outputs = ViT(**vit_params)(inputs)
        model = tf.keras.Model(inputs, outputs)
        # Run experiment
        run_WB_experiment(WB_KEY,
                          WB_ENTITY,
                          WB_PROJECT,
                          WB_GROUP,
                          model,
                          data_flag,
                          ImageDataGenerator_config,
                          flow_config,
                          epochs=epochs,
                          learning_rate=learning_rate,
                          weight_decay=weight_decay,
                          label_smoothing = label_smoothing,
                          verbose=verbose,
                          resize=img_size,
                          es_patience=es_patience,
                          )

Dataset octmnist Task multi-class n_channels 1 n_classes 4
Downloading https://zenodo.org/record/5208230/files/octmnist.npz?download=1 to /root/.medmnist/octmnist.npz


  0%|          | 0/54938180 [00:00<?, ?it/s]

Using downloaded and verified file: /root/.medmnist/octmnist.npz
Using downloaded and verified file: /root/.medmnist/octmnist.npz
X train (97477, 32, 32, 3) | Y train (97477, 4)
X val (10832, 32, 32, 3) | Y val (10832, 4)
X test (1000, 32, 32, 3) | Y test (1000, 4)


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100


## EfficientNetB0

In [4]:
import torch
torch.cuda.empty_cache()
tf.keras.backend.clear_session()

WB_GROUP = 'EfficientNetB0'
mlp_head_units = [256,64]
drop_linear = .2

for data_flag in datasets:

    if data_flag in INFO:
        info = INFO[data_flag]
        n_classes = len(info['label'])
        n_classes = 1 if n_classes == 2 else n_classes
    else:
        if data_flag == 'cifar100':
            n_classes = 100
        else:
            n_classes = 4

    # Start running
    with tf.device('/device:GPU:0'):

      # Instance model
      inputs = tf.keras.layers.Input((img_size, img_size, 3))
      base_model = tf.keras.applications.EfficientNetB0(weights=None, include_top=False)(inputs)
      x = tf.keras.layers.GlobalAveragePooling2D()(base_model)
      for i in mlp_head_units:
          x = tf.keras.layers.Dense(i)(x)
          x = tf.keras.layers.Dropout(drop_linear)(x)
      logits = tf.keras.layers.Dense(n_classes)(x)
      model = tf.keras.Model(inputs, logits)

      # Run experiment
      run_WB_experiment(WB_KEY,
                        WB_ENTITY,
                        WB_PROJECT,
                        WB_GROUP,
                        model,
                        data_flag,
                        ImageDataGenerator_config,
                        flow_config,
                        epochs=epochs,
                        learning_rate=learning_rate,
                        weight_decay=weight_decay,
                        label_smoothing = label_smoothing,
                        verbose=verbose,
                        resize=img_size,
                        es_patience=es_patience,
                        )

Found 83484 images belonging to 4 classes.
Found 32 images belonging to 4 classes.
Found 968 images belonging to 4 classes.


[34m[1mwandb[0m: Currently logged in as: [33mbenayas[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 00007: ReduceLROnPlateau reducing learning rate to 1.9999999494757503e-05.
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 00011: ReduceLROnPlateau reducing learning rate to 3.999999898951501e-06.
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 00014: ReduceLROnPlateau reducing learning rate to 7.999999979801942e-07.
Epoch 15/100
Test metrics: {'loss': 0.46040573716163635, 'accuracy': 0.9462810158729553, 'AUC': 0.9965016841888428, 'f1_score': 0.945526659488678}


VBox(children=(Label(value=' 17.03MB of 17.03MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.…

0,1
AUC,▃▅▆▇██████▇▆▃▁▁
accuracy,▂▅▆▇▇████▇▇▆▂▁▁
epoch,▁▁▂▃▃▃▄▅▅▅▆▇▇▇█
f1_score,▂▄▆▇▇████▇▇▅▂▁▁
loss,▇▅▃▂▂▁▁▁▁▂▂▄▇██
lr,███████▂▂▂▂▁▁▁▁
n_parameters,▁
test_AUC,▁
test_accuracy,▁
test_f1_score,▁

0,1
AUC,0.5004
accuracy,0.44565
best_epoch,7.0
best_val_loss,0.40664
epoch,14.0
f1_score,0.15414
loss,1.29013
lr,0.0
n_parameters,4.4
test_AUC,0.9965


## EfficientNetB4

In [5]:
import torch
torch.cuda.empty_cache()
tf.keras.backend.clear_session()

WB_GROUP = 'EfficientNetB4'
mlp_head_units = [256,64]
drop_linear = .2

for data_flag in datasets:

    if data_flag in INFO:
        info = INFO[data_flag]
        n_classes = len(info['label'])
        n_classes = 1 if n_classes == 2 else n_classes
    else:
        if data_flag == 'cifar100':
            n_classes = 100
        else:
            n_classes = 4

    # Start running
    with tf.device('/device:GPU:0'):

      # Instance model
      inputs = tf.keras.layers.Input((img_size, img_size, 3))
      base_model = tf.keras.applications.EfficientNetB4(weights=None, include_top=False)(inputs)
      x = tf.keras.layers.GlobalAveragePooling2D()(base_model)
      for i in mlp_head_units:
          x = tf.keras.layers.Dense(i)(x)
          x = tf.keras.layers.Dropout(drop_linear)(x)
      logits = tf.keras.layers.Dense(n_classes)(x)
      model = tf.keras.Model(inputs, logits)

      # Run experiment
      run_WB_experiment(WB_KEY,
                        WB_ENTITY,
                        WB_PROJECT,
                        WB_GROUP,
                        model,
                        data_flag,
                        ImageDataGenerator_config,
                        flow_config,
                        epochs=epochs,
                        learning_rate=learning_rate,
                        weight_decay=weight_decay,
                        label_smoothing = label_smoothing,
                        verbose=verbose,
                        resize=img_size,
                        es_patience=es_patience,
                        )

Found 83484 images belonging to 4 classes.
Found 32 images belonging to 4 classes.


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


Found 968 images belonging to 4 classes.


Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 00012: ReduceLROnPlateau reducing learning rate to 1.9999999494757503e-05.
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 00015: ReduceLROnPlateau reducing learning rate to 3.999999898951501e-06.
Epoch 16/100
Test metrics: {'loss': 0.5227959752082825, 'accuracy': 0.9070248007774353, 'AUC': 0.9936429858207703, 'f1_score': 0.9047904014587402}


VBox(children=(Label(value=' 69.74MB of 69.74MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.…

0,1
AUC,▁▃▅▇▇████████▇▆▄
accuracy,▁▃▅▆▇████████▇▆▄
epoch,▁▁▂▂▃▃▄▄▅▅▆▆▇▇██
f1_score,▁▂▄▆▇████████▇▅▃
loss,█▆▄▃▂▁▁▁▁▁▁▁▁▂▃▆
lr,████████████▂▂▂▁
n_parameters,▁
test_AUC,▁
test_accuracy,▁
test_f1_score,▁

0,1
AUC,0.70975
accuracy,0.60402
best_epoch,8.0
best_val_loss,0.44693
epoch,15.0
f1_score,0.35654
loss,1.09524
lr,0.0
n_parameters,18.1
test_AUC,0.99364


## ResNet 150v2

In [6]:
import torch
torch.cuda.empty_cache()
tf.keras.backend.clear_session()

WB_GROUP = "ResNet 152 v2"
mlp_head_units = [256,64]
drop_linear = .2

for data_flag in datasets:

    if data_flag in INFO:
        info = INFO[data_flag]
        n_classes = len(info['label'])
        n_classes = 1 if n_classes == 2 else n_classes
    else:
        if data_flag == 'cifar100':
            n_classes = 100
        else:
            n_classes = 4

    # Start running
    with tf.device('/device:GPU:0'):

      # Instance model
      inputs = tf.keras.layers.Input((img_size, img_size, 3))
      base_model = tf.keras.applications.resnet_v2.ResNet152V2(weights=None, include_top=False)(inputs)
      x = tf.keras.layers.GlobalAveragePooling2D()(base_model)
      for i in mlp_head_units:
          x = tf.keras.layers.Dense(i)(x)
          x = tf.keras.layers.Dropout(drop_linear)(x)
      logits = tf.keras.layers.Dense(n_classes)(x)
      model = tf.keras.Model(inputs, logits)

      # Run experiment
      run_WB_experiment(WB_KEY,
                        WB_ENTITY,
                        WB_PROJECT,
                        WB_GROUP,
                        model,
                        data_flag,
                        ImageDataGenerator_config,
                        flow_config,
                        epochs=epochs,
                        learning_rate=learning_rate,
                        weight_decay=weight_decay,
                        label_smoothing = label_smoothing,
                        verbose=verbose,
                        resize=img_size,
                        es_patience=es_patience,
                        )

Found 83484 images belonging to 4 classes.
Found 32 images belonging to 4 classes.


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


Found 968 images belonging to 4 classes.


Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 00014: ReduceLROnPlateau reducing learning rate to 1.9999999494757503e-05.
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 00017: ReduceLROnPlateau reducing learning rate to 3.999999898951501e-06.
Epoch 18/100
Test metrics: {'loss': 0.39550378918647766, 'accuracy': 0.9814049601554871, 'AUC': 0.9996620416641235, 'f1_score': 0.9815011024475098}


VBox(children=(Label(value=' 225.23MB of 225.23MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=…

0,1
AUC,▁▆▇▇▇▇▇▇██████████
accuracy,▁▅▆▆▇▇▇▇▇▇▇▇▇▇████
epoch,▁▁▂▂▃▃▃▄▄▅▅▆▆▆▇▇██
f1_score,▁▅▆▆▇▇▇▇▇▇▇▇▇▇████
loss,█▄▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁
lr,██████████████▂▂▂▁
n_parameters,▁
test_AUC,▁
test_accuracy,▁
test_f1_score,▁

0,1
AUC,0.99354
accuracy,0.96732
best_epoch,16.0
best_val_loss,0.35762
epoch,17.0
f1_score,0.9519
loss,0.42233
lr,0.0
n_parameters,58.9
test_AUC,0.99966


## Conv Mixer

In [None]:
def activation_block(x, dropout=.2):
    x = tf.keras.layers.Activation("gelu")(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.Dropout(dropout)(x)
    return x


def conv_stem(x, filters: int, patch_size: int, dropout: float):
    x = tf.keras.layers.Conv2D(filters, kernel_size=patch_size, strides=patch_size)(x)
    return activation_block(x, dropout)


def conv_mixer_block(x, filters: int, kernel_size: int, dropout: float):
    # Depthwise convolution.
    x0 = x
    x = tf.keras.layers.DepthwiseConv2D(kernel_size=kernel_size, padding="same")(x)
    x = tf.keras.layers.Add()([activation_block(x, dropout), x0])  # Residual.

    # Pointwise convolution.
    x = tf.keras.layers.Conv2D(filters, kernel_size=1)(x)
    x = activation_block(x, dropout)

    return x


def get_conv_mixer_256_8(
    image_size=32, filters=256, depth=12, kernel_size=5, patch_size=4, mlp_head_units:List[int]=[256,64], drop_enc:float=.2, drop_linear:float=.2, num_classes=10,
):
    """ConvMixer-256/8: https://openreview.net/pdf?id=TVHS5Y4dNvM.
    The hyperparameter values are taken from the paper.
    """
    inputs = tf.keras.Input((image_size, image_size, 3))
    x = tf.keras.layers.Rescaling(scale=1.0 / 255)(inputs)

    # Extract patch embeddings.
    x = conv_stem(x, filters, patch_size, drop_enc)
    # ConvMixer blocks.
    for _ in range(depth):
        x = conv_mixer_block(x, filters, kernel_size, drop_enc)

    # Classification block.
    x = tf.keras.layers.GlobalAvgPool2D()(x)
    for i in mlp_head_units:
        x = tf.keras.layers.Dense(i)(x)
        x = tf.keras.layers.Dropout(drop_linear)(x)
    logits = tf.keras.layers.Dense(num_classes)(x)
    return tf.keras.Model(inputs, logits)

In [None]:
import torch
torch.cuda.empty_cache()
tf.keras.backend.clear_session()

WB_GROUP = "ConvMixer"
mlp_head_units = [256,64]
drop_linear = .2

for data_flag in datasets:

    if data_flag in INFO:
        info = INFO[data_flag]
        n_classes = len(info['label'])
        n_classes = 1 if n_classes == 2 else n_classes
    else:
        if data_flag == 'cifar100':
            n_classes = 100
        else:
            n_classes = 4

    # Start running
    with tf.device('/device:GPU:0'):

      # Instance model
      model = get_conv_mixer_256_8(patch_size = 4, num_classes = n_classes)

      # Run experiment
      run_WB_experiment(WB_KEY,
                        WB_ENTITY,
                        WB_PROJECT,
                        WB_GROUP,
                        model,
                        data_flag,
                        ImageDataGenerator_config,
                        flow_config,
                        epochs=epochs,
                        learning_rate=learning_rate,
                        weight_decay=weight_decay,
                        label_smoothing = label_smoothing,
                        verbose=verbose,
                        resize=img_size,
                        es_patience=es_patience,
                        )

## Inception ResNet v2

In [None]:
import torch
torch.cuda.empty_cache()
tf.keras.backend.clear_session()

WB_GROUP = "Inception ResNet v2"
mlp_head_units = [256,64]
drop_linear = .2

img_size=128
for data_flag in datasets:

    info = INFO[data_flag]
    n_classes = len(info['label'])
    n_classes = 1 if n_classes == 2 else n_classes

    # Start running
    with tf.device('/device:GPU:0'):

      # Instance model
      inputs = tf.keras.layers.Input((img_size, img_size, 3))
      x = tf.keras.applications.inception_resnet_v2.preprocess_input(inputs)
      base_model = tf.keras.applications.InceptionResNetV2(weights=None, include_top=False)(x)
      x = tf.keras.layers.GlobalAveragePooling2D()(base_model)
      for i in mlp_head_units:
          x = tf.keras.layers.Dense(i)(x)
          x = tf.keras.layers.Dropout(drop_linear)(x)
      logits = tf.keras.layers.Dense(n_classes)(x)
      model = tf.keras.Model(inputs, logits)

      # Run experiment
      run_WB_experiment(WB_KEY,
                        WB_ENTITY,
                        WB_PROJECT,
                        WB_GROUP,
                        model,
                        data_flag,
                        ImageDataGenerator_config,
                        flow_config,
                        epochs=epochs,
                        learning_rate=learning_rate,
                        weight_decay=weight_decay,
                        label_smoothing = label_smoothing,
                        verbose=verbose,
                        resize=img_size,
                        es_patience=es_patience,
                        )

Dataset octmnist Task multi-class n_channels 1 n_classes 4
Using downloaded and verified file: /root/.medmnist/octmnist.npz
Using downloaded and verified file: /root/.medmnist/octmnist.npz
Using downloaded and verified file: /root/.medmnist/octmnist.npz
X train (97477, 128, 128, 3) | Y train (97477, 4)
X val (10832, 128, 128, 3) | Y val (10832, 4)
X test (1000, 128, 128, 3) | Y test (1000, 4)


[34m[1mwandb[0m: Currently logged in as: [33mbenayas[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


## HVT (PyTorch)

In [None]:
import torch
torch.cuda.empty_cache()
tf.keras.backend.clear_session()

from models import Attention, get_attention_flops

WB_GROUP = "HVT"
mlp_head_units = [256,64]
drop_linear = .2

for data_flag in datasets:

    if data_flag in INFO:
        info = INFO[data_flag]
        n_classes = len(info['label'])
        n_classes = 1 if n_classes == 2 else n_classes
    else:
        if data_flag == 'cifar100':
            n_classes = 100
        else:
            n_classes = 4

    # Start running
    with tf.device('/device:GPU:0'):

      # Instance model
      inputs = tf.keras.layers.Input((img_size, img_size, 3))
      base_model = tf.keras.applications.mobilenet_v2.MobileNetV2(weights=None, include_top=False)(inputs)
      x = tf.keras.layers.GlobalAveragePooling2D()(base_model)
      for i in mlp_head_units:
          x = tf.keras.layers.Dense(i)(x)
          x = tf.keras.layers.Dropout(drop_linear)(x)
      logits = tf.keras.layers.Dense(n_classes)(x)
      model = tf.keras.Model(inputs, logits)

      # Run experiment
      run_WB_experiment(WB_KEY,
                        WB_ENTITY,
                        WB_PROJECT,
                        WB_GROUP,
                        model,
                        data_flag,
                        ImageDataGenerator_config,
                        flow_config,
                        epochs=epochs,
                        learning_rate=learning_rate,
                        weight_decay=weight_decay,
                        label_smoothing = label_smoothing,
                        verbose=verbose,
                        resize=img_size,
                        es_patience=es_patience,
                        )

Dataset pathmnist Task multi-class n_channels 3 n_classes 9
Using downloaded and verified file: /root/.medmnist/pathmnist.npz
X train (89996, 32, 32, 3) | Y train (89996, 9)
Using downloaded and verified file: /root/.medmnist/pathmnist.npz
X val (10004, 32, 32, 3) | Y val (10004, 9)
Using downloaded and verified file: /root/.medmnist/pathmnist.npz




X test (7180, 32, 32, 3) | Y test (7180, 9)


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

Epoch 1/100
Epoch 2/100
Epoch 3/100

Epoch 00003: ReduceLROnPlateau reducing learning rate to 0.00020000000949949026.
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100

Epoch 00015: ReduceLROnPlateau reducing learning rate to 4.0000001899898055e-05.
Epoch 16/100
Epoch 17/100

Epoch 00017: ReduceLROnPlateau reducing learning rate to 8.000000525498762e-06.
Epoch 18/100
Test metrics: {'loss': 1.4452080726623535, 'accuracy': 0.6295264363288879, 'AUC': 0.9270712733268738, 'f1_score': 0.5709664821624756}


VBox(children=(Label(value=' 10.15MB of 10.15MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.…

0,1
AUC,▁▅▆▇▇▇████████████
accuracy,▁▄▅▆▇▇▇▇▇▇▇██████▇
epoch,▁▁▂▂▃▃▃▄▄▅▅▆▆▆▇▇██
f1_score,▁▄▅▇▇▇▇▇▇▇▇██████▇
loss,█▅▄▂▂▂▂▂▂▂▂▁▁▁▁▁▁▂
lr,███▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁
n_parameters,▁
test_AUC,▁
test_accuracy,▁
test_f1_score,▁

0,1
AUC,0.98663
accuracy,0.8969
best_epoch,12.0
best_val_loss,1.3953
epoch,17.0
f1_score,0.89577
loss,0.76787
lr,1e-05
n_parameters,2.6
test_AUC,0.92707


Dataset dermamnist Task multi-class n_channels 3 n_classes 7
Using downloaded and verified file: /root/.medmnist/dermamnist.npz


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


X train (7007, 32, 32, 3) | Y train (7007, 7)
Using downloaded and verified file: /root/.medmnist/dermamnist.npz
X val (1003, 32, 32, 3) | Y val (1003, 7)
Using downloaded and verified file: /root/.medmnist/dermamnist.npz
X test (2005, 32, 32, 3) | Y test (2005, 7)


Epoch 1/100
Epoch 2/100
Epoch 3/100

Epoch 00003: ReduceLROnPlateau reducing learning rate to 0.00020000000949949026.
Epoch 4/100
Epoch 5/100

Epoch 00005: ReduceLROnPlateau reducing learning rate to 4.0000001899898055e-05.
Epoch 6/100
Test metrics: {'loss': 1.4138860702514648, 'accuracy': 0.6688279509544373, 'AUC': 0.5, 'f1_score': 0.11450772732496262}


VBox(children=(Label(value=' 10.15MB of 10.15MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.…

0,1
AUC,▁▅▅▇██
accuracy,▁▆▆▇▇█
epoch,▁▂▄▅▇█
f1_score,▁▂▅▆▇█
loss,█▄▃▂▁▁
lr,███▂▂▁
n_parameters,▁
test_AUC,▁
test_accuracy,▁
test_f1_score,▁

0,1
AUC,0.73472
accuracy,0.6906
best_epoch,5.0
best_val_loss,1.41402
epoch,5.0
f1_score,0.22349
loss,1.14615
lr,4e-05
n_parameters,2.6
test_AUC,0.5


Dataset bloodmnist Task multi-class n_channels 3 n_classes 8
Using downloaded and verified file: /root/.medmnist/bloodmnist.npz
X train (11959, 32, 32, 3) | Y train (11959, 8)
Using downloaded and verified file: /root/.medmnist/bloodmnist.npz
X val (1712, 32, 32, 3) | Y val (1712, 8)
Using downloaded and verified file: /root/.medmnist/bloodmnist.npz


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


X test (3421, 32, 32, 3) | Y test (3421, 8)


Epoch 1/100
Epoch 2/100
Epoch 3/100

Epoch 00003: ReduceLROnPlateau reducing learning rate to 0.00020000000949949026.
Epoch 4/100
Epoch 5/100

Epoch 00005: ReduceLROnPlateau reducing learning rate to 4.0000001899898055e-05.
Epoch 6/100
Test metrics: {'loss': 2.040238618850708, 'accuracy': 0.1692487597465515, 'AUC': 0.5, 'f1_score': 0.03618749976158142}


VBox(children=(Label(value=' 10.15MB of 10.15MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.…

0,1
AUC,▁▅▆▇██
accuracy,▁▄▅▇██
epoch,▁▂▄▅▇█
f1_score,▁▄▅▇██
loss,█▅▄▂▁▁
lr,███▂▂▁
n_parameters,▁
test_AUC,▁
test_accuracy,▁
test_f1_score,▁

0,1
AUC,0.95021
accuracy,0.78309
best_epoch,4.0
best_val_loss,2.04031
epoch,5.0
f1_score,0.74107
loss,0.94409
lr,4e-05
n_parameters,2.6
test_AUC,0.5


Dataset organamnist Task multi-class n_channels 1 n_classes 11
Using downloaded and verified file: /root/.medmnist/organamnist.npz
X train (34581, 32, 32, 3) | Y train (34581, 11)
Using downloaded and verified file: /root/.medmnist/organamnist.npz
X val (6491, 32, 32, 3) | Y val (6491, 11)
Using downloaded and verified file: /root/.medmnist/organamnist.npz


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


X test (17778, 32, 32, 3) | Y test (17778, 11)


Epoch 1/100
Epoch 2/100
Epoch 3/100

Epoch 00003: ReduceLROnPlateau reducing learning rate to 0.00020000000949949026.
Epoch 4/100
Epoch 5/100

Epoch 00005: ReduceLROnPlateau reducing learning rate to 4.0000001899898055e-05.
Epoch 6/100
Test metrics: {'loss': 2.359232187271118, 'accuracy': 0.18477894365787506, 'AUC': 0.5, 'f1_score': 0.028356490656733513}


VBox(children=(Label(value=' 10.15MB of 10.15MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.…

0,1
AUC,▁▆▇███
accuracy,▁▅▆▇██
epoch,▁▂▄▅▇█
f1_score,▁▅▆▇██
loss,█▄▃▂▁▁
lr,███▂▂▁
n_parameters,▁
test_AUC,▁
test_accuracy,▁
test_f1_score,▁

0,1
AUC,0.99411
accuracy,0.94295
best_epoch,5.0
best_val_loss,2.36678
epoch,5.0
f1_score,0.94284
loss,0.66353
lr,4e-05
n_parameters,2.6
test_AUC,0.5


Dataset organcmnist Task multi-class n_channels 1 n_classes 11
Using downloaded and verified file: /root/.medmnist/organcmnist.npz
X train (13000, 32, 32, 3) | Y train (13000, 11)
Using downloaded and verified file: /root/.medmnist/organcmnist.npz
X val (2392, 32, 32, 3) | Y val (2392, 11)
Using downloaded and verified file: /root/.medmnist/organcmnist.npz


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


X test (8268, 32, 32, 3) | Y test (8268, 11)


Epoch 1/100
Epoch 2/100
Epoch 3/100

Epoch 00003: ReduceLROnPlateau reducing learning rate to 0.00020000000949949026.
Epoch 4/100
Epoch 5/100

Epoch 00005: ReduceLROnPlateau reducing learning rate to 4.0000001899898055e-05.
Epoch 6/100
Test metrics: {'loss': 2.368472099304199, 'accuracy': 0.22206096351146698, 'AUC': 0.5, 'f1_score': 0.03303822502493858}


VBox(children=(Label(value=' 10.15MB of 10.15MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.…

0,1
AUC,▁▅▆▇██
accuracy,▁▄▅▇▇█
epoch,▁▂▄▅▇█
f1_score,▁▄▆▆▇█
loss,█▄▄▂▂▁
lr,███▂▂▁
n_parameters,▁
test_AUC,▁
test_accuracy,▁
test_f1_score,▁

0,1
AUC,0.94557
accuracy,0.74823
best_epoch,0.0
best_val_loss,2.39317
epoch,5.0
f1_score,0.70828
loss,1.09509
lr,4e-05
n_parameters,2.6
test_AUC,0.5


Dataset organsmnist Task multi-class n_channels 1 n_classes 11
Using downloaded and verified file: /root/.medmnist/organsmnist.npz
X train (13940, 32, 32, 3) | Y train (13940, 11)
Using downloaded and verified file: /root/.medmnist/organsmnist.npz
X val (2452, 32, 32, 3) | Y val (2452, 11)
Using downloaded and verified file: /root/.medmnist/organsmnist.npz


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


X test (8829, 32, 32, 3) | Y test (8829, 11)


Epoch 1/100
Epoch 2/100
Epoch 3/100

Epoch 00003: ReduceLROnPlateau reducing learning rate to 0.00020000000949949026.
Epoch 4/100
Epoch 5/100

Epoch 00005: ReduceLROnPlateau reducing learning rate to 4.0000001899898055e-05.
Epoch 6/100
Test metrics: {'loss': 2.314279079437256, 'accuracy': 0.2353607416152954, 'AUC': 0.5, 'f1_score': 0.03463997319340706}


VBox(children=(Label(value=' 10.15MB of 10.15MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.…

0,1
AUC,▁▅▆▇██
accuracy,▁▄▅▆▇█
epoch,▁▂▄▅▇█
f1_score,▁▄▅▆▇█
loss,█▄▄▂▁▁
lr,███▂▂▁
n_parameters,▁
test_AUC,▁
test_accuracy,▁
test_f1_score,▁

0,1
AUC,0.92171
accuracy,0.59498
best_epoch,1.0
best_val_loss,2.36765
epoch,5.0
f1_score,0.52424
loss,1.3319
lr,4e-05
n_parameters,2.6
test_AUC,0.5


In [None]:
import argparse
import datetime
import numpy as np
import time
import torch
import torch.backends.cudnn as cudnn
import json
import os
from pathlib import Path

from timm.data import Mixup
from timm.models import create_model
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from timm.scheduler import create_scheduler
from timm.optim import create_optimizer
from timm.utils import NativeScaler, get_state_dict, ModelEma

from datasets import build_dataset
from engine import train_one_epoch, evaluate
from losses import DistillationLoss
from samplers import RASampler
from models import Attention, get_attention_flops
import utils
from params import args
from logger import logger


torch.cuda.empty_cache()
tf.keras.backend.clear_session()

WB_GROUP = "HVT"
mlp_head_units = [256,64]
drop_linear = .2

model_params = {
  "model": "hvt_model",
  "batch_size": 128,
  "exp_name": "hvt-s-1",
  "input_size": 224,
  "patch_size": 16,
  "num_heads": 6,
  "head_dim": 64,
  "num_blocks": 12,
  "num_workers": 10,
  "pool_kernel_size": 3,
  "pool_stride": 2,
  "pool_block_width": 12,
  "weight_decay": 0.025
}

for data_flag in datasets:

    info = INFO[data_flag]
    n_classes = len(info['label'])
    n_classes = 1 if n_classes == 2 else n_classes

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    # random.seed(seed)

    logger.info(f"Creating model: {args.model}")
    model = create_model(
        model_params,
        pretrained=False,
        num_classes=n_classes,
        drop_rate=drop_linear,
        drop_path_rate=args.drop_path,
        drop_block_rate=None,
    )

    logger.info(str(model))

    if utils.get_rank() == 0:
        try:
            from ptflops import get_model_complexity_info
            macs, params = get_model_complexity_info(model, (3, args.input_size, args.input_size), as_strings=True,
                                                     print_per_layer_stat=False, verbose=False, custom_modules_hooks={Attention:get_attention_flops})
            # flops = macs
            logger.info('{:<30}  {:<8}'.format('MACs: ', macs))
            logger.info('{:<30}  {:<8}'.format('Number of parameters: ', params))
        except:
            pass

    model.to(device)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info('number of params: ' + str(n_parameters))

    linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0
    args.lr = linear_scaled_lr
    optimizer = create_optimizer(args, model_without_ddp)
    loss_scaler = NativeScaler()

    lr_scheduler, _ = create_scheduler(args, optimizer)

    criterion = LabelSmoothingCrossEntropy(smoothing=label_smoothing)
 
    teacher_model = None


    output_dir = Path(args.output_dir)
    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.resume, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1
            if 'scaler' in checkpoint:
                loss_scaler.load_state_dict(checkpoint['scaler'])

    if args.eval:
        test_stats = evaluate(data_loader_val, model, device)
        logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
        return

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    max_accuracy = 0.0
    for epoch in range(args.start_epoch, args.epochs):
        train_stats = train_one_epoch(
            model, criterion, data_loader_train,
            optimizer, device, epoch, loss_scaler,
            args.clip_grad, mixup_fn,
            set_training_mode=args.finetune == ''  # keep in eval mode during finetuning
        )

        lr_scheduler.step(epoch)
        if args.output_dir:
            checkpoint_paths = [output_dir / 'last_checkpoint.pth']
            for checkpoint_path in checkpoint_paths:
                utils.save_on_master({
                    'model': model_without_ddp.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': lr_scheduler.state_dict(),
                    'epoch': epoch,
                    'scaler': loss_scaler.state_dict(),
                    'args': args,
                }, checkpoint_path)

        test_stats = evaluate(data_loader_val, model, device)
        logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
        if max_accuracy < test_stats["acc1"]:
            utils.save_on_master({
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'epoch': epoch,
                'scaler': loss_scaler.state_dict(),
                'args': args,
            },  os.path.join(args.output_dir, 'best_checkpoint.pth'))

        max_accuracy = max(max_accuracy, test_stats["acc1"])
        logger.info(f'Max accuracy: {max_accuracy:.2f}%')

        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     **{f'test_{k}': v for k, v in test_stats.items()},
                     'epoch': epoch,
                     'n_parameters': n_parameters}

        if args.output_dir and utils.is_main_process():
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info('Training time {}'.format(total_time_str))