In [1]:
# Install MedViT
!git clone https://github.com/Omid-Nejati/MedViT.git

Cloning into 'MedViT'...
remote: Enumerating objects: 146, done.[K
remote: Counting objects: 100% (145/145), done.[K
remote: Compressing objects: 100% (80/80), done.[K
remote: Total 146 (delta 70), reused 125 (delta 57), pack-reused 1[K
Receiving objects: 100% (146/146), 804.62 KiB | 3.87 MiB/s, done.
Resolving deltas: 100% (70/70), done.


In [1]:
import os
# os.environ["CUDA_VISIBLE_DEVICES"]="1" # is needed for hachiko

from PIL import Image
import torch
import os
import warnings
warnings.filterwarnings("ignore")
from typing import Tuple

from PIL import Image
import torch
import torch.nn as nn
import torchvision.transforms as T
from torchvision.transforms import functional as F
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split

from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np

from datasets import load_dataset
import pandas as pd

import random
import json

import sys
sys.path.append('..')
from validation_utils import custom_dataset, plot_confusion_matrix, get_trainer, get_compute_metrics

torch.cuda.empty_cache()

## MedViT

In [2]:
pretrained_model_name = 'MedViT512_tr35_stage6(3)_CCropSpot2HTrivAug_fastvitprepr_lr1e5'
dataset_name = 'DDR'
prepared_ds_test = custom_dataset(pretrained_model_name)
compute_metrics = get_compute_metrics(pretrained_model_name, dataset_name)

In [3]:
# Initialise a MedViT class
from transformers import PreTrainedModel
from MedViT.MedViT import MedViT, MedViT_large

# Define configuration
from transformers import PretrainedConfig
from typing import List


class MedViTConfig(PretrainedConfig):
    model_type = "medvit"

    def __init__(
        self,
        stem_chs: List[int] = [64, 32, 64],
        depths: List[int] = [3, 4, 30, 3],
        path_dropout: float = 0.2,
        attn_drop: int = 0,
        drop: int = 0,
        num_classes: int = 5,
        strides: List[int] = [1, 2, 2, 2],
        sr_ratios: List[int] = [8, 4, 2, 1],
        head_dim: int = 32,
        mix_block_ratio: float = 0.75,
        use_checkpoint: bool = False,
        pretrained: bool = False,
        pretrained_cfg: str = None,
        **kwargs
    ):
        self.stem_chs = stem_chs
        self.depths = depths
        self.path_dropout = path_dropout
        self.attn_drop = attn_drop
        self.drop = drop
        self.num_classes = num_classes
        self.strides = strides
        self.sr_ratios = sr_ratios
        self.head_dim = head_dim
        self.mix_block_ratio = mix_block_ratio
        self.use_checkpoint = use_checkpoint
        self.pretrained = pretrained,
        self.pretrained_cfg = pretrained_cfg
        super().__init__(**kwargs)

class MedViTClassification(PreTrainedModel):
    config_class = MedViTConfig

    def __init__(self, config, pretrained=False):
        super().__init__(config)

        if pretrained is False:
          print('Initialized with random weights:')
          self.model = MedViT(
          stem_chs = config.stem_chs,
          depths = config.depths,
          path_dropout = config.path_dropout,
          attn_drop = config.attn_drop,
          drop = config.drop,
          num_classes = config.num_classes,
          strides = config.strides,
          sr_ratios = config.sr_ratios,
          head_dim = config.head_dim,
          mix_block_ratio = config.mix_block_ratio,
          use_checkpoint = config.use_checkpoint)
        else:
          print('Initialized with pretrained weights:')
          self.model = MedViT_large(use_checkpoint = config.use_checkpoint)
          self.model.proj_head = nn.Linear(1024, 5)        

    def forward(self, pixel_values, labels=None):
        logits = self.model(pixel_values)
        if labels is not None:
            loss = torch.nn.functional.cross_entropy(logits, labels)
            return {"loss": loss, "logits": logits}
        return {"logits": logits}

model = MedViTClassification.from_pretrained(f'../saved_models/{pretrained_model_name}')
trainer = get_trainer(model, prepared_ds_test, compute_metrics)

Initialized with random weights:
initialize_weights...


In [4]:
sample_ids = np.random.choice(len(prepared_ds_test), size=100, replace=False)
inv_sample_ids = np.setdiff1d(np.arange(len(prepared_ds_test)), sample_ids)
val_ds = prepared_ds_test.select(sample_ids)
test_ds = prepared_ds_test.select(inv_sample_ids)

In [5]:
metrics = trainer.evaluate(val_ds)
trainer.log_metrics("eval", metrics)
json.dump(metrics, open( f'../results/metrics_{pretrained_model_name}_{dataset_name}.json', 'w' ) )

  0%|          | 0/25 [00:00<?, ?it/s]

ValueError: Target scores need to be probabilities for multiclass roc_auc, i.e. they should sum up to 1.0 over classes

## MedViT with ECA attention

In [None]:
pretrained_model_name = 'MedViT512_tr35_stage6(3)_CCropSpot2HTrivAug_fastvitprepr_lr1e5'
dataset_name = 'DDR'
prepared_ds_test = custom_dataset(pretrained_model_name)
compute_metrics = get_compute_metrics(pretrained_model_name, dataset_name)

In [None]:
# Initialise a MedViT class
import math
from transformers import PreTrainedModel
from MedViT.MedViT import MedViT, MedViT_large

class ECALayer(nn.Module):
    def __init__(self, channels, gamma=2, b=1):
        super().__init__()
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        t = int(abs((math.log(channels, 2) + b) / gamma))
        k = t if t % 2 else t + 1
        self.conv = nn.Conv1d(1, 1, kernel_size=k, padding=(k - 1) // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

        # Initializing the weights with a uniform distribution 
        nn.init.uniform_(self.conv.weight) 

    def forward(self, x):
        y = self.avgpool(x)
        y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
        y = self.sigmoid(y)
        return x * y.expand_as(x)

class MedViTClassification(PreTrainedModel):
    config_class = MedViTConfig

    def __init__(self, config, pretrained=False):
        super().__init__(config)

        if pretrained is False:
          print('Initialized with random weights:')
          self.model = MedViT(
          stem_chs = config.stem_chs,
          depths = config.depths,
          path_dropout = config.path_dropout,
          attn_drop = config.attn_drop,
          drop = config.drop,
          num_classes = config.num_classes,
          strides = config.strides,
          sr_ratios = config.sr_ratios,
          head_dim = config.head_dim,
          mix_block_ratio = config.mix_block_ratio,
          use_checkpoint = config.use_checkpoint)

        else:
          print('Initialized with pretrained weights:')
          self.model = MedViT_large(use_checkpoint = config.use_checkpoint)
          self.model.proj_head = nn.Linear(1024, 5)

          nn.init.uniform_(self.model.proj_head.weight)
        
        self.apply_attention()
        self.check_freezing()
        
    def check_freezing(self):

        # freeze some layers
        for name, child in self.model.named_children():
            for param in child.parameters():
                if param.requires_grad == True:
                    print(name)

    def apply_attention(self):
        self.model.features[6].e_mhsa = ECALayer(192)
        self.model.features[11].e_mhsa = ECALayer(384)
        self.model.features[16].e_mhsa = ECALayer(384)
        self.model.features[21].e_mhsa = ECALayer(384)
        
        self.model.features[26].e_mhsa = ECALayer(384)
        self.model.features[31].e_mhsa = ECALayer(384)
        self.model.features[36].e_mhsa = ECALayer(384)
        self.model.features[39].e_mhsa = ECALayer(768)

    def freeze_layers(self):

        # freeze some layers
        for name, child in self.model.named_children():
            for param in child.parameters():                    
                param.requires_grad = False

        

    def forward(self, pixel_values, labels=None):
        logits = self.model(pixel_values)
        if labels is not None:
            loss = torch.nn.functional.cross_entropy(logits, labels)
            return {"loss": loss, "logits": logits}
        return {"logits": logits}