# Vision Transformer on MedMNIST 2D Multi-Class

[MedMNIST v2](https://medmnist.com/) is a collection of biomedical images.
There are 8 datasets with 2D images for a multi-class classification task.
The authors of the dataset collection report baseline performances with ResNets,
and with the AutoML solutions auto-sklearn, AutoKeras, and Google AutoML Vision.

Using a pre-trained [Vision Transformer model](https://huggingface.co/google/vit-base-patch16-224), and fine-tuning it for each task,
we are able to outperform almost all of those baselines.
This notebook contains the full code to run the Vision Transformer experiment.

Sources:
* Dataset: https://medmnist.com/
* Model: https://huggingface.co/google/vit-base-patch16-224
* Fine tuning: https://github.com/NielsRogge/Transformers-Tutorials/blob/master/VisionTransformer/Fine_tuning_the_Vision_Transformer_on_CIFAR_10_with_the_%F0%9F%A4%97_Trainer.ipynb

### Imports

In [1]:
import numpy as np
import torch

from datasets import load_dataset
from transformers import ViTImageProcessor, ViTForImageClassification, TrainingArguments, Trainer, EarlyStoppingCallback

from torchvision.transforms import (CenterCrop, 
                                    Compose, 
                                    Normalize, 
                                    RandomHorizontalFlip,
                                    RandomResizedCrop, 
                                    Resize, 
                                    ToTensor)

from sklearn.metrics import roc_auc_score
import evaluate

### Loading data and transforms

In [2]:
modelname = "google/vit-base-patch16-224"

processor = ViTImageProcessor.from_pretrained(modelname)

image_mean, image_std = processor.image_mean, processor.image_std
size = processor.size["height"]

normalize = Normalize(mean=image_mean, std=image_std)
_train_transforms = Compose(
        [
            RandomResizedCrop(size),
            RandomHorizontalFlip(),
            ToTensor(),
            normalize,
        ]
    )

_val_transforms = Compose(
        [
            Resize(size),
            CenterCrop(size),
            ToTensor(),
            normalize,
        ]
    )

def train_transforms(examples):
    examples['pixel_values'] = [_train_transforms(image.convert("RGB")) for image in examples['image']]
    return examples

def val_transforms(examples):
    examples['pixel_values'] = [_val_transforms(image.convert("RGB")) for image in examples['image']]
    return examples

def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    labels = torch.tensor([example["label"] for example in examples])
    return {"pixel_values": pixel_values, "labels": labels}

In [3]:
# list of datasets to use
dataset_2d_list = ['pathmnist', 'dermamnist', 'octmnist', 'bloodmnist',
                   'tissuemnist', 'organamnist', 'organcmnist', 'organsmnist']

In [4]:
def load_dataset_2d(dataset_name):
    '''
    Loads dataset with the given name and returns train, val, and test datasets,
    and dictionaries to convert between labels and ids.
    '''

    dataset = load_dataset("albertvillanova/medmnist-v2", dataset_name)

    train_ds = dataset['train']
    val_ds = dataset['validation']
    test_ds = dataset['test']

    id2label = {id:label for id, label in enumerate(train_ds.features['label'].names)}
    label2id = {label:id for id,label in id2label.items()}
    id2label

    # Set the transforms
    train_ds.set_transform(train_transforms)
    val_ds.set_transform(val_transforms)
    test_ds.set_transform(val_transforms)

    return train_ds, val_ds, test_ds, id2label, label2id

### Set up training and evaluation

In [5]:
args = TrainingArguments(
    f"medmnist-vit-1",
    save_strategy = "epoch",
    evaluation_strategy = "epoch",
    learning_rate = 2e-5,
    per_device_train_batch_size = 60,
    per_device_eval_batch_size = 40,
    num_train_epochs = 50,
    weight_decay = 0.01,
    load_best_model_at_end = True,
    metric_for_best_model = "roc_auc",
    remove_unused_columns = False,
)

In [6]:
def compute_metrics(eval_preds):
    '''
    Calculate accuracy and roc_auc.
    '''

    # accuracy
    metric_accuracy = evaluate.load("accuracy")
    logits, labels = eval_preds
    predictions = np.argmax(logits, axis=-1)
    accuracy = metric_accuracy.compute(predictions=predictions, references=labels)['accuracy']

    # calculate roc_auc the same way it is calculated for the MedMNIST benchmark:
    # https://github.com/MedMNIST/MedMNIST/blob/main/medmnist/evaluator.py
    auc = 0
    for i in range(logits.shape[1]):
        y_true_binary = (labels == i).astype(float)
        y_score_binary = logits[:, i]
        auc += roc_auc_score(y_true_binary, y_score_binary)
    roc_auc = auc / logits.shape[1]
    
    return {'accuracy': accuracy, 'roc_auc': roc_auc}

In [7]:
def train_and_evaluate(train_ds, val_ds, test_ds, id2label, label2id):
    '''
    Train the model and evaluate on the test set.
    '''

    model = ViTForImageClassification.from_pretrained(modelname, id2label=id2label, label2id=label2id, ignore_mismatched_sizes=True)

    trainer = Trainer(
        model,
        args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        data_collator=collate_fn,
        compute_metrics=compute_metrics,
        tokenizer=processor,
        callbacks = [EarlyStoppingCallback(early_stopping_patience=5)],
    )

    trainer.train()

    # evaluate
    outputs = trainer.predict(test_ds)
    print(outputs.metrics)

### Train and evaluate each dataset

In [8]:
for dataset_name in dataset_2d_list:
    print(dataset_name)
    train_ds, val_ds, test_ds, id2label, label2id = load_dataset_2d(dataset_name)
    train_and_evaluate(train_ds, val_ds, test_ds, id2label, label2id)

pathmnist


Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([9]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([9, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch,Training Loss,Validation Loss,Accuracy,Roc Auc
1,0.1624,0.070312,0.976309,0.999329
2,0.1045,0.04409,0.986505,0.999695
3,0.0816,0.025287,0.991204,0.999861
4,0.0736,0.019002,0.994002,0.999924
5,0.0587,0.017753,0.994302,0.999915
6,0.0534,0.014458,0.995602,0.999942
7,0.0496,0.011532,0.996401,0.999974
8,0.0481,0.006128,0.998301,0.999989
9,0.0399,0.012785,0.996202,0.999962
10,0.0395,0.007764,0.997401,0.999991


{'test_loss': 0.42436450719833374, 'test_accuracy': 0.9143454038997214, 'test_roc_auc': 0.9874672401641509, 'test_runtime': 15.3634, 'test_samples_per_second': 467.343, 'test_steps_per_second': 11.716}
dermamnist


Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([7]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([7, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch,Training Loss,Validation Loss,Accuracy,Roc Auc
1,No log,0.682369,0.762712,0.889226
2,No log,0.62923,0.765703,0.915448
3,No log,0.585388,0.792622,0.930032
4,No log,0.561394,0.79661,0.934195
5,0.678200,0.549953,0.789631,0.939986
6,0.678200,0.536674,0.800598,0.942484
7,0.678200,0.525553,0.80658,0.940341
8,0.678200,0.535341,0.79661,0.944387
9,0.487300,0.588256,0.780658,0.943628
10,0.487300,0.547596,0.793619,0.944605


{'test_loss': 0.5721451640129089, 'test_accuracy': 0.799501246882793, 'test_roc_auc': 0.9420302156915736, 'test_runtime': 5.3648, 'test_samples_per_second': 373.734, 'test_steps_per_second': 9.506}
octmnist


Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([4]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([4, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch,Training Loss,Validation Loss,Accuracy,Roc Auc
1,0.4464,0.274107,0.907958,0.955157
2,0.4085,0.25922,0.914328,0.959604
3,0.381,0.227861,0.922637,0.965638
4,0.3767,0.221341,0.924206,0.963943
5,0.3501,0.216218,0.924945,0.970996
6,0.339,0.21576,0.926699,0.969919
7,0.3279,0.208742,0.927068,0.972758
8,0.3231,0.195524,0.93353,0.974007
9,0.3129,0.198849,0.931407,0.976098
10,0.304,0.189976,0.935561,0.973433


{'test_loss': 1.023380994796753, 'test_accuracy': 0.815, 'test_roc_auc': 0.9650813333333333, 'test_runtime': 3.2888, 'test_samples_per_second': 304.058, 'test_steps_per_second': 7.601}
bloodmnist


Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([8]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([8, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch,Training Loss,Validation Loss,Accuracy,Roc Auc
1,No log,0.17979,0.949182,0.994779
2,No log,0.103714,0.969626,0.997357
3,0.410300,0.099566,0.970794,0.998366
4,0.410300,0.080269,0.969626,0.99895
5,0.205600,0.088416,0.974299,0.998621
6,0.205600,0.07884,0.973715,0.998809
7,0.205600,0.07644,0.973131,0.999143
8,0.172800,0.08058,0.973715,0.998811
9,0.172800,0.068659,0.977804,0.998958
10,0.152400,0.076751,0.973715,0.999176


{'test_loss': 0.07312551885843277, 'test_accuracy': 0.9780765857936276, 'test_roc_auc': 0.9988406302414363, 'test_runtime': 8.1439, 'test_samples_per_second': 420.067, 'test_steps_per_second': 10.56}
tissuemnist


Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([8]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([8, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch,Training Loss,Validation Loss,Accuracy,Roc Auc
1,1.1328,1.000813,0.623942,0.894228
2,1.0708,0.920807,0.663748,0.910742
3,1.0669,0.848068,0.691117,0.911742
4,1.0182,0.864072,0.687352,0.911341
5,1.0065,0.830042,0.702411,0.916719
6,0.9951,0.804011,0.710448,0.921395
7,0.9746,0.792444,0.715905,0.922619
8,0.96,0.802216,0.709602,0.922756
9,0.9576,0.790549,0.718147,0.925099
10,0.9353,0.781085,0.718359,0.926326


{'test_loss': 0.7928290367126465, 'test_accuracy': 0.7190143824027073, 'test_roc_auc': 0.9310531575254984, 'test_runtime': 91.4392, 'test_samples_per_second': 517.065, 'test_steps_per_second': 12.927}
organamnist


Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([11]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([11, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch,Training Loss,Validation Loss,Accuracy,Roc Auc
1,0.6818,0.074366,0.981821,0.99979
2,0.3669,0.04495,0.986289,0.999941
3,0.3014,0.065729,0.980434,0.999958
4,0.2709,0.052615,0.984132,0.999933
5,0.2494,0.035375,0.989986,0.999972
6,0.2305,0.035063,0.989678,0.999961
7,0.2105,0.046509,0.985827,0.999983
8,0.1953,0.070419,0.979202,0.999966
9,0.1908,0.065081,0.980743,0.999958
10,0.1864,0.013113,0.995532,0.99999


{'test_loss': 0.18295544385910034, 'test_accuracy': 0.9517943525705929, 'test_roc_auc': 0.997973759389236, 'test_runtime': 35.5003, 'test_samples_per_second': 500.784, 'test_steps_per_second': 12.535}
organcmnist


Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([11]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([11, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch,Training Loss,Validation Loss,Accuracy,Roc Auc
1,No log,0.243618,0.939799,0.996051
2,No log,0.144212,0.960284,0.998926
3,0.722500,0.103046,0.97408,0.999595
4,0.722500,0.066718,0.983696,0.999655
5,0.388000,0.098959,0.973662,0.999471
6,0.388000,0.071873,0.982023,0.99954
7,0.324400,0.085149,0.976171,0.999523
8,0.324400,0.072209,0.982441,0.99954
9,0.324400,0.073911,0.978679,0.999723
10,0.279800,0.062673,0.984114,0.999622


{'test_loss': 0.2132023572921753, 'test_accuracy': 0.9327527818093856, 'test_roc_auc': 0.9930285838724114, 'test_runtime': 17.1878, 'test_samples_per_second': 481.039, 'test_steps_per_second': 12.043}
organsmnist


Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([11]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([11, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch,Training Loss,Validation Loss,Accuracy,Roc Auc
1,No log,0.463623,0.838091,0.98534
2,No log,0.344777,0.868679,0.989933
3,0.914100,0.296895,0.878874,0.991234
4,0.914100,0.257948,0.886215,0.991494
5,0.574900,0.267272,0.886623,0.991809
6,0.574900,0.248179,0.890701,0.992434
7,0.497000,0.206471,0.909462,0.993265
8,0.497000,0.212156,0.909054,0.992954
9,0.434700,0.206258,0.911909,0.993092
10,0.434700,0.207539,0.910685,0.993504


{'test_loss': 0.5182576775550842, 'test_accuracy': 0.8277268093781855, 'test_roc_auc': 0.9812998686754533, 'test_runtime': 18.6595, 'test_samples_per_second': 473.163, 'test_steps_per_second': 11.844}
