In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="1"

from PIL import Image
# from torchinfo import summary
import torch
import os
import warnings
warnings.filterwarnings("ignore")
from typing import Tuple

from PIL import Image
import torch
import torch.nn as nn
# import torch.nn.functional as F
import torchvision.transforms as T
from torchvision.transforms import functional as F
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split

from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np

from datasets import load_dataset
import pandas as pd

import random

# import wandb
# wandb.login()  # picks the key up from the WANDB_API_KEY environment variable; never commit API keys

torch.cuda.empty_cache()


In [None]:
# Fetch the MedViT sources; a later cell imports the model classes from `MedViT.MedViT`.
!git clone https://github.com/Omid-Nejati/MedViT.git


In [None]:
# Install requirements
!pip install -r MedViT/requirements.txt

In [9]:
# ls ../mnt/local/data/kalexu97/DR_grading/DR_grading

In [2]:
# Build the evaluation table: one row per image with its file path and label.
root_dir = '../mnt/local/data/kalexu97/DR_grading/DR_grading'
labelsTable = pd.read_csv('../mnt/local/data/kalexu97/DR_grading.csv')

# `id_code` values are joined onto `root_dir` as-is to form each image path.
labelsTable['image_path'] = labelsTable['id_code'].map(lambda code: os.path.join(root_dir, code))
labelsTable['label'] = labelsTable['diagnosis']

# Drop the raw source columns now that `image_path` and `label` replace them.
test_dataset = labelsTable.drop(columns=['id_code', 'diagnosis'])

# test_dataset = pd.read_csv('test_dataset.csv')

In [26]:
from datasets import Dataset
from transformers import AutoImageProcessor

# Checkpoint directory that supplies the image-preprocessing configuration.
# NOTE(review): this is the "stage4(2) ... lr2e5" directory, whereas the model weights
# are loaded from "stage5(2) ... lr1e5" in another cell — confirm the mismatch is intended.
model_name_or_path = "./saved_models/MedViT512_tr35_stage4(2)_Spot2HTrvlAug_fastvitprepr_lr2e5"

image_processor = AutoImageProcessor.from_pretrained(model_name_or_path)
# Override the stored target size so both sides are 512 (height is assigned first, then width).
image_processor.size['height'] = image_processor.size['width'] = 512

import torch.nn.functional as F
import torchvision
import torch

class CenterCrop(torch.nn.Module):
    """Center-crop an image to a fixed size or to a crop derived from an aspect ratio.

    Args:
        size: Output size forwarded unchanged to torchvision's ``center_crop``.
            When ``None`` (default) the size is derived from ``ratio`` and the
            dimensions of the input image.
        ratio: Target aspect ratio as a ``"W:H"`` string. Only used when
            ``size`` is ``None``.
    """

    def __init__(self, size=None, ratio="1:1"):
        super().__init__()
        self.size = size
        self.ratio = ratio

    def _target_size(self, h, w):
        """Return the ``(height, width)`` crop for an input of height ``h`` and width ``w``.

        The crop keeps one dimension of the image intact and shrinks the other
        so that width / height matches ``self.ratio`` (after integer truncation).
        """
        parts = self.ratio.split(":")
        ratio = float(parts[0]) / float(parts[1])
        # Candidate crops that touch the image edge: full height or full width.
        ratioed_w = int(h * ratio)
        ratioed_h = int(w / ratio)
        if w >= h:
            # Landscape/square input: prefer keeping the full width when it fits.
            if ratioed_h <= h:
                return (ratioed_h, w)
            return (h, ratioed_w)
        # Portrait input: prefer keeping the full height when it fits.
        if ratioed_w <= w:
            return (h, ratioed_w)
        return (ratioed_h, w)

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be cropped.

        Returns:
            PIL Image or Tensor: Cropped image.
        """
        if self.size is not None:
            size = self.size
        else:
            if isinstance(img, torch.Tensor):
                # Tensors carry (..., H, W); PIL images report (W, H).
                h, w = img.shape[-2:]
            else:
                w, h = img.size
            size = self._target_size(h, w)
        return torchvision.transforms.functional.center_crop(img, size)

    def __repr__(self) -> str:
        # Include `ratio`: with size=None it is the only thing that configures the crop.
        return f"{self.__class__.__name__}(size={self.size}, ratio={self.ratio})"

# Evaluation-time transform: a single CenterCrop with its default arguments
# (size=None, ratio="1:1").
_transforms_test = T.Compose([CenterCrop()])

def load_image(path_image, label, mode):
    """Open an image file and apply the transform pipeline selected by ``mode``.

    Args:
        path_image: Path handed to ``PIL.Image.open``.
        label: Accepted for call-site symmetry; not used by this function.
        mode: ``'train'`` selects ``_transforms_train``; any other value
            selects ``_transforms_test``.

    Returns:
        The transformed image.
    """
    image = Image.open(path_image)

    if mode != 'train':
        return _transforms_test(image)

    # NOTE(review): `_transforms_train` is not defined in the visible cells, so
    # this path raises NameError unless it is defined elsewhere — confirm.
    return _transforms_train(image)

def func_transform_test(examples):
    """Batch transform for the evaluation dataset.

    Loads every image listed in ``examples['image_path']`` through
    ``load_image(..., 'test')``, converts it to RGB, runs the batch through
    ``image_processor`` with ``return_tensors='pt'`` and attaches the labels.

    Args:
        examples: Batch mapping with ``'image_path'`` and ``'label'`` entries.

    Returns:
        The ``image_processor`` output with an added ``'label'`` entry.
    """
    rgb_images = [
        load_image(image_path, target, 'test').convert("RGB")
        for image_path, target in zip(examples['image_path'], examples['label'])
    ]
    batch = image_processor(rgb_images, return_tensors='pt')
    batch['label'] = examples['label']
    return batch

# Wrap the pandas table as a `Dataset`, attach the transform that is applied
# on access, then shuffle with a fixed seed.
test_ds = Dataset.from_pandas(test_dataset, preserve_index=False)
prepared_ds_test = test_ds.with_transform(func_transform_test).shuffle(seed=42)

In [14]:
prepared_ds_test 

Dataset({
    features: ['image_path', 'label'],
    num_rows: 12522
})

In [5]:
# Sanity check: how many examples the evaluation set contains.
print("rows in test_dataset: ", len(prepared_ds_test))

# The class names are the integers 0-4, so both lookup tables are identity maps.
labels = [0, 1, 2, 3, 4]
label2id = {name: idx for idx, name in enumerate(labels)}
id2label = {idx: name for idx, name in enumerate(labels)}

print("ID2label: ", id2label)

rows in test_dataset:  12522
ID2label:  {0: 0, 1: 1, 2: 2, 3: 3, 4: 4}


In [6]:
def collate_fn(batch):
    """Collate a list of examples into a single model-input dictionary.

    Args:
        batch: Sequence of mappings, each with a ``'pixel_values'`` tensor and
            a ``'label'`` value.

    Returns:
        dict with ``'pixel_values'`` (the tensors stacked along a new first
        dimension) and ``'labels'`` (a tensor built from the label values).
    """
    pixel_tensors = [sample['pixel_values'] for sample in batch]
    targets = [sample['label'] for sample in batch]
    collated = {}
    collated['pixel_values'] = torch.stack(pixel_tensors)
    collated['labels'] = torch.tensor(targets)
    return collated

In [7]:
from sklearn.metrics import cohen_kappa_score, confusion_matrix
from sklearn.metrics import f1_score #, kappa
# from sklearn import metrics

import evaluate

accuracy = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    """Compute accuracy, quadratic-weighted kappa and weighted F1 for an evaluation pass.

    Args:
        eval_pred: Pair ``(scores, labels)``; the predicted class is the
            argmax of ``scores`` along axis 1.

    Returns:
        dict with the keys ``'accuracy'``, ``'kappa'`` and ``'f1'``.

    Side effects:
        Prints the confusion matrix of labels vs. predictions.
    """
    scores, references = eval_pred
    predicted = np.argmax(scores, axis=1)

    acc = accuracy.compute(predictions=predicted, references=references)['accuracy']
    kappa = cohen_kappa_score(references, predicted, weights="quadratic")
    weighted_f1 = f1_score(references, predicted, average='weighted')

    # Each value goes through np.mean on a one-element list, exactly as before,
    # so the returned values keep the same NumPy scalar type.
    result = {
        'accuracy': np.mean([acc]),
        'kappa': np.mean([kappa]),
        'f1': np.mean([weighted_f1]),
    }

    print(confusion_matrix(references, predicted))

    return result


In [8]:
# Initialise a MedViT class
from transformers import PreTrainedModel
from MedViT.MedViT import MedViT, MedViT_large

# Define configuration
from transformers import PretrainedConfig
from typing import List


class MedViTConfig(PretrainedConfig):
    """Configuration carrying the constructor arguments of a MedViT classifier.

    ``MedViTClassification`` forwards ``stem_chs``, ``depths``, ``path_dropout``,
    ``attn_drop``, ``drop``, ``num_classes``, ``strides``, ``sr_ratios``,
    ``head_dim``, ``mix_block_ratio`` and ``use_checkpoint`` to ``MedViT``, and
    passes ``pretrained_cfg`` to ``torch.load`` as the checkpoint path.
    ``pretrained`` is stored but not read by the code in this notebook.
    Remaining keyword arguments go to ``PretrainedConfig.__init__``.
    """

    model_type = "medvit"

    def __init__(
        self,
        stem_chs: List[int] = [64, 32, 64],
        depths: List[int] = [3, 4, 30, 3],
        path_dropout: float = 0.2,
        attn_drop: int = 0,
        drop: int = 0,
        num_classes: int = 5,
        strides: List[int] = [1, 2, 2, 2],
        sr_ratios: List[int] = [8, 4, 2, 1],
        head_dim: int = 32,
        mix_block_ratio: float = 0.75,
        use_checkpoint: bool = False,
        pretrained: bool = False,
        pretrained_cfg: str = None,
        **kwargs
    ):
        # Store copies of list arguments so that instances never share (and
        # accidentally mutate) the default lists of this signature.
        self.stem_chs = self._copy_if_sequence(stem_chs)
        self.depths = self._copy_if_sequence(depths)
        self.path_dropout = path_dropout
        self.attn_drop = attn_drop
        self.drop = drop
        self.num_classes = num_classes
        self.strides = self._copy_if_sequence(strides)
        self.sr_ratios = self._copy_if_sequence(sr_ratios)
        self.head_dim = head_dim
        self.mix_block_ratio = mix_block_ratio
        self.use_checkpoint = use_checkpoint
        # A stray trailing comma used to store this flag as the 1-tuple
        # `(pretrained,)`, which is always truthy. Configs saved by that code
        # carry the wrapped value, so unwrap single-element sequences on load.
        while isinstance(pretrained, (list, tuple)) and len(pretrained) == 1:
            pretrained = pretrained[0]
        self.pretrained = pretrained
        self.pretrained_cfg = pretrained_cfg
        super().__init__(**kwargs)

    @staticmethod
    def _copy_if_sequence(value):
        """Return a new list for list/tuple input; pass any other value through."""
        if isinstance(value, (list, tuple)):
            return list(value)
        return value

class MedViTClassification(PreTrainedModel):
    """``PreTrainedModel`` wrapper that exposes MedViT as an image classifier.

    Args:
        config: ``MedViTConfig`` providing the architecture settings.
        pretrained: When exactly ``False`` (default), ``MedViT`` is built from
            ``config``. For any other value, ``MedViT_large`` is built, the
            state dict stored under the ``'model'`` key of the file at
            ``config.pretrained_cfg`` is loaded into it, and ``proj_head`` is
            replaced by a new linear layer.
    """

    config_class = MedViTConfig

    def __init__(self, config, pretrained=False):
        super().__init__(config)

        if pretrained is False:
            print('Initialized with random weights:')
            self.model = MedViT(
                stem_chs=config.stem_chs,
                depths=config.depths,
                path_dropout=config.path_dropout,
                attn_drop=config.attn_drop,
                drop=config.drop,
                num_classes=config.num_classes,
                strides=config.strides,
                sr_ratios=config.sr_ratios,
                head_dim=config.head_dim,
                mix_block_ratio=config.mix_block_ratio,
                use_checkpoint=config.use_checkpoint,
            )
        else:
            print('Initialized with pretrained weights:')
            self.model = MedViT_large(use_checkpoint=config.use_checkpoint)
            # map_location="cpu" lets the checkpoint load on machines that lack
            # the device it was saved from; load_state_dict copies the values
            # into the model's own parameters afterwards.
            checkpoint = torch.load(config.pretrained_cfg, map_location="cpu", weights_only=True)
            self.model.load_state_dict(checkpoint['model'])
            # Size the new head from the config (default 5) instead of a
            # hard-coded 5, so both construction paths honour `num_classes`.
            # NOTE(review): 1024 is assumed to be MedViT_large's feature width,
            # and replacing `proj_head` with a bare nn.Linear may change its
            # state-dict keys relative to the other branch — confirm that
            # checkpoints saved from here reload correctly via from_pretrained.
            self.model.proj_head = nn.Linear(1024, config.num_classes)

    def forward(self, pixel_values, labels=None):
        """Run the wrapped model and optionally compute the cross-entropy loss.

        Args:
            pixel_values: Input passed directly to the wrapped model.
            labels: Optional targets; when given, cross-entropy between the
                logits and ``labels`` is added to the output.

        Returns:
            dict with ``'logits'`` and, if ``labels`` was given, ``'loss'``.
        """
        logits = self.model(pixel_values)
        if labels is not None:
            loss = torch.nn.functional.cross_entropy(logits, labels)
            return {"loss": loss, "logits": logits}
        return {"logits": logits}

# Load the saved checkpoint that will be evaluated.
# NOTE(review): this is the "stage5(2) ... lr1e5" directory, while the image processor
# is loaded from "stage4(2) ... lr2e5" — confirm that pairing is intended.
model = MedViTClassification.from_pretrained("./saved_models/MedViT512_tr35_stage5(2)_Spot2HTrvlAug_fastvitprepr_lr1e5")

Initialized with random weights:
initialize_weights...


In [27]:
from transformers import TrainingArguments
from transformers import Trainer

# Only `trainer.evaluate` is called in this notebook, so the training-related
# arguments below are not exercised by the visible cells.
training_args = TrainingArguments(
    output_dir="./MedViT-base_test",
    push_to_hub=False,

    # Data loading. Unused columns are kept because the dataset transform
    # reads `image_path` when an example is accessed.
    remove_unused_columns=False,
    dataloader_num_workers=16,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=4,

    # Step-based logging / evaluation / checkpoint cadence.
    evaluation_strategy="steps",
    logging_steps=50,
    eval_steps=50,
    save_steps=50,
    save_total_limit=3,

    # Optimisation schedule.
    learning_rate=1e-5,
    lr_scheduler_type='constant_with_warmup',
    warmup_ratio=0.02,
    num_train_epochs=2,

    # Model selection: keep the checkpoint with the highest kappa.
    metric_for_best_model="kappa",
    greater_is_better=True,
    load_best_model_at_end=True,
)

# NOTE(review): the evaluation set is also passed as `train_dataset`; calling
# `trainer.train()` would therefore train on the test data.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=prepared_ds_test,
    eval_dataset=prepared_ds_test,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
)

In [24]:
# Split the prepared set into a validation subset of (up to) 100 examples and
# the remainder. A seeded generator makes the split identical after a kernel
# restart, and the size is capped so small datasets do not raise.
# Note: this rebinds `test_ds`, which an earlier cell used for the unsplit dataset.
split_rng = np.random.default_rng(42)
n_examples = len(prepared_ds_test)
sample_ids = split_rng.choice(n_examples, size=min(100, n_examples), replace=False)
inv_sample_ids = np.setdiff1d(np.arange(n_examples), sample_ids)
val_ds = prepared_ds_test.select(sample_ids)
test_ds = prepared_ds_test.select(inv_sample_ids)

In [28]:
# Evaluate on the full prepared set; `compute_metrics` prints the confusion
# matrix, and `log_metrics` prints the formatted metric table.
metrics = trainer.evaluate(prepared_ds_test)
trainer.log_metrics("eval", metrics)

[[5553  449  182   11   71]
 [ 205  291  102    8   24]
 [ 577  450 2199  635  616]
 [   1    5   47  122   61]
 [  58    7   96   61  691]]
***** eval metrics *****
  eval_accuracy           =     0.7072
  eval_f1                 =      0.723
  eval_kappa              =     0.7685
  eval_loss               =     0.9081
  eval_runtime            = 0:08:29.45
  eval_samples_per_second =     24.579
  eval_steps_per_second   =      6.146


***** eval metrics *****
  eval_accuracy           =      0.175
  eval_f1                 =     0.2025
  eval_kappa              =     0.0473
  eval_loss               =     3.1161
  eval_runtime            = 0:08:32.99
  eval_samples_per_second =      24.41
  eval_steps_per_second   =      6.103


In [36]:
# Evaluate again, this time also persisting the metrics with `save_metrics`
# before printing them.
metrics = trainer.evaluate(prepared_ds_test)
trainer.save_metrics("eval", metrics)
trainer.log_metrics("eval", metrics)

[[4679  348  135   10   31]
 [ 188  260   44    1    6]
 [ 167  188  531   90   41]
 [   6    5   46   73   25]
 [   4    1   26   11  110]]
***** eval metrics *****
  eval_accuracy           =     0.8046
  eval_f1                 =     0.8126
  eval_kappa              =     0.7629
  eval_loss               =     0.6082
  eval_runtime            = 0:04:46.89
  eval_samples_per_second =      24.49
  eval_steps_per_second   =      6.124
