## Create a Model Module for Training

In [1]:
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision import transforms as tt
import matplotlib.pyplot as plt
from monai import transforms as mT ## Breaks with numpy > 2.0
from monai.utils import set_determinism
import timm

In [2]:
import os
from dotenv import load_dotenv
from pathlib import PosixPath, Path
import json
import numpy as np
import yaml
from typing import List, Dict, Tuple, Optional, Union, Any
from tqdm.notebook import tqdm

In [3]:
## Load machine-specific paths from the env file and fix the random seeds.
load_dotenv("../envs/mednist.env")
# Index os.environ directly: a missing variable then fails with a KeyError that
# names it, instead of Path(None) raising an opaque TypeError.
root_dir = Path(os.environ["DATASET_DIR"])
data_dir = Path(os.environ["DATA_DIR"])
set_determinism(seed=42)

In [4]:
# Read the hyperparameter YAML in one shot and parse it with the safe loader.
hparams_path = data_dir / 'hyperparam_mps_timm.yml'
hparams_dict = yaml.safe_load(hparams_path.read_text())

In [5]:
hparams_dict

{'device': 'cpu',
 'epochs': 4,
 'finetune_frac': 0.1,
 'ftune_batchsize': 16,
 'in_channels': 1,
 'loss': 'CrossEntropyLoss',
 'lr': 1e-05,
 'num_workers': 2,
 'optimizer': 'AdamW',
 'out_channels': 7,
 'spatial_dims': 2,
 'test_frac': 0.1,
 'train_batchsize': 16,
 'val_interval': 1,
 'torch_device': 'mps',
 'model_name': 'resnet34',
 'pretrained': True}

In [6]:
# Load the stored split definition (image paths and labels per split).
split_path = data_dir / "random_split.json"
data_split = json.loads(split_path.read_text())

In [7]:
def replace_header(path: str, pattern: str, replace_str: str) -> str:
    """Return ``path`` with every occurrence of ``pattern`` replaced by ``replace_str``.

    Args:
        path: The string to rewrite.
        pattern: The literal substring to look for.
        replace_str: The text substituted for each occurrence of ``pattern``.

    Returns:
        The rewritten string; ``path`` unchanged when ``pattern`` does not occur.
    """
    substituted = path.replace(pattern, replace_str)
    return substituted

## Preprocessing
# Swap the "<DATASET_DIR>" placeholder in every stored image path for the
# dataset root configured on this machine.
dataset_root = str(root_dir)
for split_type in ("train", "ftune", "test"):
    split_images = data_split[split_type]['image']
    data_split[split_type]['image'] = [
        replace_header(path=img_path, pattern="<DATASET_DIR>", replace_str=dataset_root)
        for img_path in split_images
    ]

In [8]:
## Define all relevant transforms!
# Training pipeline: load -> channel-first -> intensity scaling -> three random
# augmentations (each applied with prob=0.5) -> tensor conversion.
train_transforms = mT.Compose([
    mT.LoadImage(image_only=True,),
    mT.EnsureChannelFirst(), ## Put the channel dimension first (presumably adds one for single-channel images — TODO confirm)
    mT.ScaleIntensity(),
    mT.RandRotate(range_x=np.pi / 12, prob=0.5, keep_size=True),  # range_x is in radians (pi/12 = 15 degrees)
    mT.RandFlip(spatial_axis=0, prob=0.5),
    mT.RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5),
    mT.ToTensor(),
    ])

# Fine-tune/validation pipeline: same loading and scaling steps as above, but
# without the random augmentations and without the final ToTensor step.
ftune_transforms = mT.Compose([
    mT.LoadImage(image_only=True),
    mT.EnsureChannelFirst(), ## Put the channel dimension first (presumably adds one for single-channel images — TODO confirm)
    mT.ScaleIntensity(),
])

# Post-processing for model outputs: softmax activation.
pred_transform = mT.Compose([
    mT.Activations(softmax=True)])

# Post-processing for labels: one-hot encoding with hparams_dict['out_channels'] classes.
label_transform = mT.Compose([mT.AsDiscrete(to_onehot=hparams_dict['out_channels'])])

In [9]:
type(train_transforms)

monai.transforms.compose.Compose

In [10]:
hparams_dict

{'device': 'cpu',
 'epochs': 4,
 'finetune_frac': 0.1,
 'ftune_batchsize': 16,
 'in_channels': 1,
 'loss': 'CrossEntropyLoss',
 'lr': 1e-05,
 'num_workers': 2,
 'optimizer': 'AdamW',
 'out_channels': 7,
 'spatial_dims': 2,
 'test_frac': 0.1,
 'train_batchsize': 16,
 'val_interval': 1,
 'torch_device': 'mps',
 'model_name': 'resnet34',
 'pretrained': True}

In [11]:
data_split['train'].keys()

dict_keys(['image', 'label'])

In [12]:
## Dataset!
class MedNIST_Dataset(torch.utils.data.Dataset):
    """Map-style dataset over two parallel sequences: image entries and labels.

    ``data_dict[image_key][i]`` is passed through ``transforms`` on access and
    ``data_dict[label_key][i]`` is converted with ``int``.
    """

    def __init__(
            self,
            data_dict: Dict,
            transforms: "mT.Compose",
            image_key: str = "image",
            label_key: str = "label",
            ) -> None:
        """Store the data and validate that images and labels line up.

        Args:
            data_dict: Mapping holding one sequence under ``image_key`` and one
                under ``label_key``.
            transforms: Callable applied to each image entry on access.
            image_key: Key of the image sequence in ``data_dict``.
            label_key: Key of the label sequence in ``data_dict``.

        Raises:
            KeyError: If ``image_key`` or ``label_key`` is missing from ``data_dict``.
            ValueError: If the two sequences differ in length.
        """
        n_images = len(data_dict[image_key])
        n_labels = len(data_dict[label_key])
        # Fail at construction time: a mismatch would otherwise only surface as
        # an IndexError on some later index, or silently ignore extra labels.
        if n_images != n_labels:
            raise ValueError(
                f"'{image_key}' has {n_images} entries but "
                f"'{label_key}' has {n_labels}; they must match.")
        self.data = data_dict
        self.transform = transforms
        self.image_key = image_key
        self.label_key = label_key

    def __len__(self) -> int:
        """Return the number of image entries."""
        return len(self.data[self.image_key])

    def __getitem__(self, index) -> Dict:
        """Return ``{"x": transformed image entry, "y": integer label}`` for ``index``."""
        return {
            "x": self.transform(self.data[self.image_key][index]),
            "y": int(self.data[self.label_key][index]),
        }

In [36]:
# Single-sample subset (first training image and its label).
onebatch = {
    "image": [data_split['train']['image'][0]],
    "label": [data_split['train']['label'][0]],
}

train_ds = MedNIST_Dataset(
    data_dict=onebatch,
    transforms=train_transforms,)

ftune_ds = MedNIST_Dataset(
    data_dict=onebatch,
    transforms=ftune_transforms,)

## Dataloaders!
train_dl = torch.utils.data.DataLoader(
    train_ds,
    batch_size=1,
    num_workers=0)

# Fix: this loader previously wrapped train_ds, so evaluation ran with the
# random training augmentations and ftune_ds was never used.
ftune_dl = torch.utils.data.DataLoader(
    ftune_ds,
    batch_size=1,
    num_workers=0)

## Model Def

In [39]:
# train_batch = next(iter(train_dl))

In [40]:
# from torchmetrics import F1Score, Accuracy
# rocauc = ROCAUCMetric()
# ## Set average to None to get classwise.
# acc = Accuracy(task="multiclass", num_classes=hparams_dict['out_channels'])
# f1 = F1Score(task="multiclass", num_classes=hparams_dict['out_channels']) 

# with torch.no_grad():
#     outs = net(train_batch['x'].to(torch_device))
#     pred = torch.stack([pred_transform(out.cpu()) for out in outs])
#     gt = torch.stack([label_transform(i.cpu()) for i in train_batch['y']]) 
#     y_pred = pred ## Append or cat with multi-batch
#     y_gt = gt

# # acc = torch.eq(torch.stack(y_pred).argmax(dim=1), train_batch['y']).astype(int).mean() # Channel dimension is 1
# out_acc = acc(y_pred.argmax(dim=1), y_gt.argmax(dim=1))
# out_f1 = f1(y_pred, y_gt)

# metric = rocauc(y_pred, y_gt)
# metric = rocauc.aggregate()

In [41]:
from torchmetrics import Accuracy, F1Score, AUROC
from monai.metrics import ROCAUCMetric

# Metric objects shared by the evaluation code, keyed by the short name used
# in the log; the multiclass metrics are sized from the hyperparameter file.
num_classes = hparams_dict['out_channels']
metric_suite = dict(
    rocauc=ROCAUCMetric(),
    acc=Accuracy(task="multiclass", num_classes=num_classes),
    f1=F1Score(task="multiclass", num_classes=num_classes),
)
metric_suite

{'rocauc': <monai.metrics.rocauc.ROCAUCMetric at 0x3543aa6c0>,
 'acc': MulticlassAccuracy(),
 'f1': MulticlassF1Score()}

In [42]:
hparams_dict['out_channels']

7

In [48]:
log_tracker = []
def train_epoch(
        net: Any,
        train_dl: torch.utils.data.DataLoader,
        torch_device: str,
        log_tracker: Dict,
        optimizer: Any,
        criterion: Any
        ):
    """Run one optimisation pass over ``train_dl`` and log the mean sample loss.

    Args:
        net: Model to train; put into training mode by this function.
        train_dl: Loader yielding dict batches with an ``'x'`` (inputs) and a
            ``'y'`` (targets) tensor.
        torch_device: Device the batch tensors are moved to.
        log_tracker: Dict of metric lists; the epoch loss is appended under
            ``'train_loss'`` (the list is created if missing).
        optimizer: Optimizer stepped once per batch.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
            Assumed to return the mean loss over the batch (the default
            reduction of ``torch.nn.CrossEntropyLoss``).

    Returns:
        The same ``log_tracker`` object, updated in place.

    Raises:
        ValueError: If ``train_dl`` yields no samples.
    """
    net.train()
    loss_sum = 0.0
    n_samples = 0

    for batch in tqdm(train_dl):
        imgs = batch['x'].to(torch_device)
        labels = batch['y'].to(torch_device)
        optimizer.zero_grad(set_to_none=True)

        outputs = net(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = len(imgs)
        # loss is a per-batch mean, so weight it by the batch size before
        # dividing by the sample count below. Summing the raw means and
        # dividing by samples under-reports the loss by ~batch_size.
        loss_sum += loss.item() * batch_size
        n_samples += batch_size

    if n_samples == 0:
        raise ValueError("train_dl yielded no samples; cannot compute an epoch loss")

    log_tracker.setdefault('train_loss', []).append(loss_sum / n_samples)

    return log_tracker

In [49]:
def val_epoch(
        net: Any,
        val_dl: torch.utils.data.DataLoader,
        torch_device: str,
        criterion: Any,
        log_tracker: Dict,
        split_type: str = "ftune",
        metric_suite: Dict = metric_suite,
        ):
    """Evaluate ``net`` on ``val_dl`` and log loss, accuracy, F1 and ROC-AUC.

    Model outputs are post-processed with the module-level ``pred_transform``
    and labels with the module-level ``label_transform`` before being scored.

    Args:
        net: Model to evaluate; put into eval mode by this function.
        val_dl: Loader yielding dict batches with an ``'x'`` (inputs) and a
            ``'y'`` (targets) tensor.
        torch_device: Device the batch tensors are moved to.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
            Assumed to return the mean loss over the batch.
        log_tracker: Dict of metric lists; results are appended under
            ``f'{split_type}_loss'``, ``..._acc``, ``..._f1`` and ``..._rocauc``
            (lists are created if missing).
        split_type: Prefix used for the log keys.
        metric_suite: Dict providing the ``'rocauc'``, ``'acc'`` and ``'f1'``
            metric objects.

    Returns:
        Tuple ``(log_tracker, predY, gtY)``: the updated log plus the stacked
        post-processed predictions and labels of every evaluated sample.

    Raises:
        ValueError: If ``val_dl`` yields no samples.
    """
    rocauc_metric = metric_suite['rocauc']
    # Drop anything buffered by an earlier call that never reached its final
    # reset (e.g. an interrupted run); stale samples would skew this aggregate.
    rocauc_metric.reset()

    net.eval()
    loss_sum = 0.0
    n_samples = 0
    y_pred, y_gt = [], []
    with torch.no_grad():
        for batch in tqdm(val_dl):
            imgs = batch['x'].to(torch_device)
            labels = batch['y'].to(torch_device)
            outputs = net(imgs)

            batch_size = len(imgs)
            # Weight the per-batch mean loss by the batch size so the final
            # division by the sample count yields a true per-sample mean.
            loss_sum += criterion(outputs, labels).item() * batch_size
            n_samples += batch_size

            pred = [pred_transform(out.cpu()) for out in outputs]
            gt = [label_transform(i.cpu()) for i in labels]
            y_pred.extend(pred)
            y_gt.extend(gt)
            rocauc_metric(pred, gt)

    if n_samples == 0:
        raise ValueError("val_dl yielded no samples; cannot compute validation metrics")

    predY = torch.stack(y_pred)
    gtY = torch.stack(y_gt)
    # Class indices (argmax over dim 1), shared by accuracy and F1.
    pred_cls = predY.argmax(dim=1)
    gt_cls = gtY.argmax(dim=1)

    rocauc = rocauc_metric.aggregate()
    acc = metric_suite['acc'](pred_cls, gt_cls)
    f1 = metric_suite['f1'](pred_cls, gt_cls)

    log_tracker.setdefault(f'{split_type}_loss', []).append(loss_sum / n_samples)
    log_tracker.setdefault(f'{split_type}_acc', []).append(acc)
    log_tracker.setdefault(f'{split_type}_f1', []).append(f1)
    log_tracker.setdefault(f'{split_type}_rocauc', []).append(rocauc)
    rocauc_metric.reset()
    return log_tracker, predY, gtY

In [50]:
# Start every tracked series ("<split>_<metric>") as an empty list.
log_tracker = {}
for split_type in ("ftune", "train"):
    for key in ("loss", "acc", "rocauc", "f1"):
        log_tracker[f'{split_type}_{key}'] = []

# log_tracker = val_epoch(
#         net=net, 
#         val_dl=ftune_dl,
#         torch_device=torch_device,
#         log_tracker=log_tracker,
#         split_type="ftune",)

In [51]:
from monai.networks import nets as monai_nets

# Target device for the model, read from the hyperparameter file.
torch_device = torch.device(hparams_dict['torch_device'])
# Earlier timm-based alternative, kept for reference:
# net = timm.create_model(
#     'resnet34', 
#     pretrained=hparams_dict['pretrained'], 
#     in_chans=hparams_dict['in_channels'],
#     num_classes=hparams_dict['out_channels'],
#     ).to(torch_device)
# MONAI DenseNet121 sized from the hyperparameter file and moved to torch_device.
# NOTE(review): hparams_dict['model_name'] is 'resnet34' in the displayed config,
# but a DenseNet121 is built here — confirm which model is intended.
net = monai_nets.DenseNet121(
    spatial_dims=hparams_dict['spatial_dims'],
    in_channels=hparams_dict['in_channels'],
    out_channels=hparams_dict['out_channels']).to(torch_device)
# NOTE(review): this gate reads hparams_dict['device'] ('cpu' in the displayed
# config), not hparams_dict['torch_device'] used above — confirm the intended key.
if hparams_dict['device'] == "cuda":
    net = torch.compile(net)

## Training related:
# NOTE(review): the loss and optimizer classes are hard-coded rather than looked up
# from hparams_dict['loss'] / hparams_dict['optimizer'], and lr=1e-3 differs from
# hparams_dict['lr'] (1e-05 in the displayed config) — confirm this is deliberate.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(net.parameters(), lr=1e-3)

In [52]:
## Training:
# Three epochs, each a training pass followed by an evaluation pass on the
# fine-tune loader. Arguments common to both passes are bundled once.
shared_kwargs = dict(net=net, torch_device=torch_device, criterion=criterion)

for epoch in tqdm(range(3)):
    ## train:
    log_tracker = train_epoch(
        train_dl=train_dl,
        log_tracker=log_tracker,
        optimizer=optimizer,
        **shared_kwargs)
    ## validate/finetune
    log_tracker, predY, gtY = val_epoch(
        val_dl=ftune_dl,
        log_tracker=log_tracker,
        split_type="ftune",
        **shared_kwargs)

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/47163 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [35]:
log_tracker

{'ftune_loss': [1.8930829763412476,
  1.8440446853637695,
  1.7182203531265259,
  1.5484941005706787,
  1.3476866483688354,
  1.1381984949111938,
  0.9368500709533691,
  0.7463205456733704,
  0.5812855362892151,
  0.4587387144565582,
  0.35765063762664795,
  0.2618445158004761,
  0.21153660118579865,
  0.14850567281246185,
  0.12456497550010681,
  0.09833212941884995,
  0.07602529227733612,
  0.06082341820001602,
  0.05379243940114975,
  0.04336537420749664,
  0.04009545221924782,
  0.02711883932352066,
  0.025193143635988235,
  0.024645835161209106,
  0.026506297290325165,
  0.02483121119439602,
  0.02100251242518425,
  0.021244702860713005,
  0.015392908826470375,
  0.014979499392211437],
 'ftune_acc': [metatensor(1.),
  metatensor(1.),
  metatensor(1.),
  metatensor(1.),
  metatensor(1.),
  metatensor(1.),
  metatensor(1.),
  metatensor(1.),
  metatensor(1.),
  metatensor(1.),
  metatensor(1.),
  metatensor(1.),
  metatensor(1.),
  metatensor(1.),
  metatensor(1.),
  metatensor(1.),

In [30]:
predY.shape

torch.Size([1, 7])

In [29]:
predY.argmax(), gtY.argmax()

(metatensor(1), metatensor(1))