# Setup

In [None]:
# Install third-party dependencies. Versions are not pinned (TODO: pin for reproducibility).
# NOTE(review): `classification_models` is not imported by any visible cell of this notebook;
# presumably the local medvit/utils modules need it — confirm, otherwise drop this line.
!pip install medmnist
!pip install git+https://github.com/qubvel/classification_models.git

In [None]:
from google.colab import drive
drive.mount('/content/drive/')
%cd drive/MyDrive/MLHM

In [None]:
import os
import medmnist
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.utils.data as data
from torchvision.transforms import transforms
import torch.optim as optim
from torchsummary import summary
from tqdm import tqdm
from utils import cal_metrics
from medvit import MedViT
import warnings
warnings.filterwarnings('ignore')

In [None]:
def train(model, train_loader, optimizer, n_epochs, device, model_path=None):
    """Train a multi-label classifier and print per-epoch training metrics.

    The loss is BCEWithLogitsLoss on the raw model outputs, so the model output
    must have the same shape as the targets (one logit per label). Targets are
    cast to float32 to match. Per-epoch metrics come from ``cal_metrics`` applied
    to numpy arrays of true labels and sigmoid scores.

    Args:
        model: torch.nn.Module, trained in place.
        train_loader: sized iterable of (inputs, targets) batches.
        optimizer: optimizer already bound to ``model``'s parameters.
        n_epochs: number of passes over ``train_loader``.
        device: device each batch is moved to (the model is expected to live there).
        model_path: if not None, the final state_dict is saved to this path.

    Raises:
        ValueError: if ``train_loader`` has no batches.
    """
    if len(train_loader) == 0:
        raise ValueError('train_loader yields no batches')

    # The criterion holds no state, so build it once instead of once per batch.
    criterion = nn.BCEWithLogitsLoss()

    for epoch in range(n_epochs):
        model.train()
        true_batches = []
        score_batches = []
        running_loss = 0.0

        for inputs, targets in tqdm(train_loader):
            inputs = inputs.to(device)
            targets = targets.to(device=device, dtype=torch.float32)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            # .item() / .detach() so the epoch bookkeeping does not keep every batch's
            # autograd graph (and the tensors it saved) alive until the epoch ends.
            running_loss += loss.item()
            true_batches.append(targets.detach().cpu())
            score_batches.append(outputs.detach().sigmoid().cpu())

        # Concatenate once per epoch instead of growing a tensor batch by batch (O(n^2)).
        y_true = torch.cat(true_batches, 0).numpy()
        y_score = torch.cat(score_batches, 0).numpy()
        epoch_loss = running_loss / len(train_loader)
        train_metrics = cal_metrics(y_true, y_score)

        print(f'Epoch {epoch}; loss: {epoch_loss}; train_acc: {train_metrics[0]}; train_f1: {train_metrics[1]}; train_auc: {train_metrics[2]}')

    if model_path is not None:
        torch.save(model.state_dict(), f=model_path)
        print('Model saved to:', model_path)


def evaluate(model, test_loader, device):
    """Evaluate a multi-label classifier on ``test_loader``.

    Scores are per-label sigmoid probabilities, the same scoring ``train`` uses
    (it optimises BCEWithLogitsLoss). Softmax would be wrong for this task: it
    forces the labels of one sample to compete and sum to 1, so scores rarely
    exceed a 0.5 threshold and the per-label ranking is distorted.

    Side effect: puts ``model`` in eval mode.

    Args:
        model: torch.nn.Module producing one logit per label.
        test_loader: iterable of (inputs, targets) batches.
        device: device the inputs are moved to.

    Returns:
        Whatever ``cal_metrics(y_true, y_score)`` returns. It is called with
        float32 numpy arrays, consistent with ``train``. The callers read
        indices 0/1/2 as acc/f1/auc.
    """
    model.eval()
    true_batches = []
    score_batches = []
    with torch.no_grad():
        for inputs, targets in tqdm(test_loader):
            outputs = model(inputs.to(device))
            true_batches.append(targets.to(torch.float32).cpu())
            score_batches.append(outputs.sigmoid().cpu())

    # Single concatenation instead of growing a tensor per batch.
    y_true = torch.cat(true_batches, 0).numpy()
    y_score = torch.cat(score_batches, 0).numpy()
    return cal_metrics(y_true, y_score)

In [None]:
# Dataset selection and paths; metadata for the chosen MedMNIST subset comes from medmnist.INFO.
data_flag = 'chestmnist'
dataset_folder = f'./{data_flag}'
model_folder = './saved_models'
download = False  # skip downloading; the dataset files are read from dataset_folder

info = medmnist.INFO[data_flag]
task = info['task']
n_samples = info['n_samples']
n_channels = info['n_channels']
label_dict = info['label']
n_classes = len(label_dict)

print('Task:', task)
print('Number of samples:', n_samples)
print('Number of channels:', n_channels)
print('Number of classes:', n_classes)
print('Label Dict:', label_dict)

Task: multi-label, binary-class
Number of samples: {'train': 78468, 'val': 11219, 'test': 22433}
Number of channels: 1
Number of classes: 14
Label Dict: {'0': 'atelectasis', '1': 'cardiomegaly', '2': 'effusion', '3': 'infiltration', '4': 'mass', '5': 'nodule', '6': 'pneumonia', '7': 'pneumothorax', '8': 'consolidation', '9': 'edema', '10': 'emphysema', '11': 'fibrosis', '12': 'pleural', '13': 'hernia'}


In [None]:
# Training hyper-parameters and preprocessing shared by every resolution below.
BATCH_SIZE = 128
N_EPOCHS = 20
LR = 0.001
WEIGHT_DECAY = 0.001
SEED = 42
# Fall back to CPU when no GPU is available instead of failing on .to('cuda:0').
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
DataClass = getattr(medmnist, info['python_class'])

# Seed torch's global RNG (weight init, dropout, DataLoader shuffling) so runs are reproducible.
torch.manual_seed(SEED)


def _make_rgb_transform():
    """Build the preprocessing pipeline: PIL image -> 3-channel tensor scaled to [-1, 1]."""
    return transforms.Compose([
        # Replicate the single grey channel into RGB so inputs are (3, H, W), as the summary() calls assume.
        transforms.Lambda(lambda image: image.convert('RGB')),
        transforms.ToTensor(),                      # pixel values -> [0, 1]
        transforms.Normalize(mean=[.5], std=[.5]),  # [0, 1] -> [-1, 1]; the single value broadcasts over channels
    ])


# preprocessing: no augmentation, so train and test pipelines are currently identical
train_transform = _make_rgb_transform()
test_transform = _make_rgb_transform()

# 64x64

In [None]:
# 64x64 train/test splits. Honour the `download` flag from the config cell
# (the original hard-coded download=False, so flipping the flag had no effect).
image_size_64 = 64
train_dataset_64 = DataClass(root=dataset_folder, size=image_size_64, split='train', transform=train_transform, download=download)
test_dataset_64 = DataClass(root=dataset_folder, size=image_size_64, split='test', transform=test_transform, download=download)

train_loader_64 = data.DataLoader(dataset=train_dataset_64, batch_size=BATCH_SIZE, shuffle=True)
test_loader_64 = data.DataLoader(dataset=test_dataset_64, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# MedViT with n_classes outputs, moved to DEVICE; summary() reports layer shapes for a 3x64x64 input.
medvit_64 = MedViT(
    stem_chs=[64, 32, 64],
    depths=[3, 4, 30, 3],
    path_dropout=0.1,
    num_classes=n_classes,
).to(DEVICE)
summary(model=medvit_64, input_size=(3, image_size_64, image_size_64))

In [None]:
# Snapshot of the freshly initialised (untrained) weights, written to its OWN file.
# Saving it to medvit_64_path (as before) left an untrained checkpoint at the exact
# path the evaluation cell loads, so skipping or interrupting training silently
# evaluated random weights.
os.makedirs(model_folder, exist_ok=True)
medvit_64_path = f'{model_folder}/medvit_64.pth'  # trained checkpoint, written by train()
medvit_64_init_path = f'{model_folder}/medvit_64_init.pth'
torch.save(medvit_64.state_dict(), f=medvit_64_init_path)

In [None]:
# Train the 64x64 model with Adam; train() saves the final weights to medvit_64_path.
medvit_64_path = f'{model_folder}/medvit_64.pth'
# Create the checkpoint folder up front: otherwise a missing folder only surfaces
# when train() calls torch.save, after all N_EPOCHS have run.
os.makedirs(model_folder, exist_ok=True)
optimizer = optim.Adam(medvit_64.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
train(medvit_64, train_loader_64, optimizer, n_epochs=N_EPOCHS, device=DEVICE, model_path=medvit_64_path)

In [None]:
# Reload the trained checkpoint and report test metrics (acc / f1 / auc).
# map_location lets a checkpoint saved on one device load on whatever DEVICE is now.
medvit_64.load_state_dict(torch.load(medvit_64_path, map_location=DEVICE, weights_only=True))
medvit_64.eval()

test_metrics = evaluate(medvit_64, test_loader_64, device=DEVICE)
print(f'test_acc: {test_metrics[0]}; test_f1: {test_metrics[1]}; test_auc: {test_metrics[2]}')

100%|██████████| 176/176 [00:17<00:00,  9.93it/s]


test_acc: 0.9474307620788253; test_f1: 0.0; test_auc: 0.5066876700316763


# 128x128

In [None]:
# 128x128 train/test splits. Honour the `download` flag from the config cell
# (the original hard-coded download=False, so flipping the flag had no effect).
image_size_128 = 128
train_dataset_128 = DataClass(root=dataset_folder, size=image_size_128, split='train', transform=train_transform, download=download)
test_dataset_128 = DataClass(root=dataset_folder, size=image_size_128, split='test', transform=test_transform, download=download)

train_loader_128 = data.DataLoader(dataset=train_dataset_128, batch_size=BATCH_SIZE, shuffle=True)
test_loader_128 = data.DataLoader(dataset=test_dataset_128, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# MedViT with n_classes outputs, moved to DEVICE; summary() reports layer shapes for a 3x128x128 input.
medvit_128 = MedViT(
    stem_chs=[64, 32, 64],
    depths=[3, 4, 30, 3],
    path_dropout=0.1,
    num_classes=n_classes,
).to(DEVICE)
summary(model=medvit_128, input_size=(3, image_size_128, image_size_128))

In [None]:
# Snapshot of the freshly initialised (untrained) weights, written to its OWN file.
# Saving it to medvit_128_path (as before) left an untrained checkpoint at the exact
# path the evaluation cell loads, so skipping or interrupting training silently
# evaluated random weights.
os.makedirs(model_folder, exist_ok=True)
medvit_128_path = f'{model_folder}/medvit_128.pth'  # trained checkpoint, written by train()
medvit_128_init_path = f'{model_folder}/medvit_128_init.pth'
torch.save(medvit_128.state_dict(), f=medvit_128_init_path)

In [None]:
# Train the 128x128 model with Adam; train() saves the final weights to medvit_128_path.
medvit_128_path = f'{model_folder}/medvit_128.pth'
# Create the checkpoint folder up front: otherwise a missing folder only surfaces
# when train() calls torch.save, after all N_EPOCHS have run.
os.makedirs(model_folder, exist_ok=True)
optimizer = optim.Adam(medvit_128.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
train(medvit_128, train_loader_128, optimizer, n_epochs=N_EPOCHS, device=DEVICE, model_path=medvit_128_path)

In [None]:
# Reload the trained checkpoint and report test metrics (acc / f1 / auc).
# map_location lets a checkpoint saved on one device load on whatever DEVICE is now.
medvit_128.load_state_dict(torch.load(medvit_128_path, map_location=DEVICE, weights_only=True))
medvit_128.eval()

test_metrics = evaluate(medvit_128, test_loader_128, device=DEVICE)
print(f'test_acc: {test_metrics[0]}; test_f1: {test_metrics[1]}; test_auc: {test_metrics[2]}')

# 224x224

In [None]:
# 224x224 train/test splits. Honour the `download` flag from the config cell
# (the original hard-coded download=False, so flipping the flag had no effect).
image_size_224 = 224
train_dataset_224 = DataClass(root=dataset_folder, size=image_size_224, split='train', transform=train_transform, download=download)
test_dataset_224 = DataClass(root=dataset_folder, size=image_size_224, split='test', transform=test_transform, download=download)

train_loader_224 = data.DataLoader(dataset=train_dataset_224, batch_size=BATCH_SIZE, shuffle=True)
test_loader_224 = data.DataLoader(dataset=test_dataset_224, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# MedViT with n_classes outputs, moved to DEVICE; summary() reports layer shapes for a 3x224x224 input.
medvit_224 = MedViT(
    stem_chs=[64, 32, 64],
    depths=[3, 4, 30, 3],
    path_dropout=0.1,
    num_classes=n_classes,
).to(DEVICE)
summary(model=medvit_224, input_size=(3, image_size_224, image_size_224))

In [None]:
# Snapshot of the freshly initialised (untrained) weights, written to its OWN file.
# Saving it to medvit_224_path (as before) left an untrained checkpoint at the exact
# path the evaluation cell loads, so skipping or interrupting training silently
# evaluated random weights.
os.makedirs(model_folder, exist_ok=True)
medvit_224_path = f'{model_folder}/medvit_224.pth'  # trained checkpoint, written by train()
medvit_224_init_path = f'{model_folder}/medvit_224_init.pth'
torch.save(medvit_224.state_dict(), f=medvit_224_init_path)

In [None]:
# Train the 224x224 model with Adam; train() saves the final weights to medvit_224_path.
medvit_224_path = f'{model_folder}/medvit_224.pth'
# Create the checkpoint folder up front: otherwise a missing folder only surfaces
# when train() calls torch.save, after all N_EPOCHS have run.
os.makedirs(model_folder, exist_ok=True)
optimizer = optim.Adam(medvit_224.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
train(medvit_224, train_loader_224, optimizer, n_epochs=N_EPOCHS, device=DEVICE, model_path=medvit_224_path)

In [None]:
# Reload the trained checkpoint and report test metrics (acc / f1 / auc).
# map_location lets a checkpoint saved on one device load on whatever DEVICE is now.
medvit_224.load_state_dict(torch.load(medvit_224_path, map_location=DEVICE, weights_only=True))
medvit_224.eval()

test_metrics = evaluate(medvit_224, test_loader_224, device=DEVICE)
print(f'test_acc: {test_metrics[0]}; test_f1: {test_metrics[1]}; test_auc: {test_metrics[2]}')