## Create a Torch Dataset for MedNIST

In [2]:
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision import transforms as tt
import matplotlib.pyplot as plt
from monai import transforms as mT ## Breaks with numpy > 2.0
from monai.utils import set_determinism

In [3]:
import json
import os
from pathlib import Path, PosixPath
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import yaml
from dotenv import load_dotenv
from tqdm.notebook import tqdm

In [4]:
# Load dataset locations from the env file (the relative path is resolved
# against the current working directory), then fail fast if either variable
# is unset or empty: otherwise Path(None) raises an opaque TypeError, and an
# empty value would silently point at the current directory.
load_dotenv("../envs/mednist.env")
_missing_env = [name for name in ("DATASET_DIR", "DATA_DIR") if not os.environ.get(name)]
if _missing_env:
    raise RuntimeError(
        f"Environment variable(s) {', '.join(_missing_env)} not set; "
        "check ../envs/mednist.env"
    )
root_dir = Path(os.environ["DATASET_DIR"])  # substituted for "<DATASET_DIR>" in stored image paths
data_dir = Path(os.environ["DATA_DIR"])  # holds hyperparam.yml and random_split.json
set_determinism(seed=42)  # fix global random seeds (MONAI utility) for reproducible runs

In [5]:
# Load the experiment hyperparameters (parsed with safe_load: plain YAML types only).
with (data_dir / 'hyperparam.yml').open('r') as hparams_file:
    hparams_dict = yaml.safe_load(hparams_file)

In [6]:
# Inspect the loaded hyperparameters.
hparams_dict

{'device': 'cpu',
 'epochs': 4,
 'finetune_frac': 0.1,
 'ftune_batchsize': 64,
 'in_channels': 1,
 'loss': 'CrossEntropyLoss',
 'lr': 1e-05,
 'num_workers': 2,
 'optimizer': 'AdamW',
 'out_channels': 5,
 'spatial_dims': 2,
 'test_frac': 0.1,
 'train_batchsize': 64,
 'val_interval': 1}

In [7]:
# Split definition; later cells read its "train"/"ftune"/"test" entries,
# each of which carries "image" (path) and "label" lists.
data_split = json.loads((data_dir / "random_split.json").read_text())

In [8]:
def replace_header(path: str, pattern: str, replace_str: str) -> str:
    """Return ``path`` with ``replace_str`` substituted for ``pattern``.

    Despite the name, every occurrence of ``pattern`` is replaced, not only
    a leading one.
    """
    substituted = path.replace(pattern, replace_str)
    return substituted

## Preprocessing: expand the "<DATASET_DIR>" placeholder in every stored
## image path into this machine's dataset root (root_dir).
for split_type in ("train", "ftune", "test"):
    data_split[split_type]['image'] = list(map(
        lambda img_path: replace_header(img_path, "<DATASET_DIR>", str(root_dir)),
        data_split[split_type]['image'],
    ))

In [9]:
## Define all relevant transforms!
# Training pipeline: load from path -> channel-first -> intensity scaling ->
# random rotate / flip / zoom, each augmentation applied with probability 0.5.
train_transforms = mT.Compose([
    mT.LoadImage(image_only=True),
    mT.EnsureChannelFirst(),  # ensure a leading channel axis (single image, no batch axis yet)
    mT.ScaleIntensity(),
    mT.RandRotate(range_x=np.pi / 12, prob=0.5, keep_size=True),  # pi/12 rad = 15 degrees
    mT.RandFlip(spatial_axis=0, prob=0.5),
    mT.RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5),
    mT.ToTensor(),
])

# Fine-tune pipeline: same loading/normalisation as training, no augmentation.
ftune_transforms = mT.Compose([
    mT.LoadImage(image_only=True),
    mT.EnsureChannelFirst(),  # ensure a leading channel axis (single image, no batch axis yet)
    mT.ScaleIntensity(),
])

# Turn raw model outputs into class probabilities.
pred_transform = mT.Compose([
    mT.Activations(softmax=True),
])

# One-hot encode integer labels into `out_channels` classes. Wrapped in a list
# like the other pipelines rather than relying on Compose auto-wrapping.
label_transform = mT.Compose([
    mT.AsDiscrete(to_onehot=hparams_dict['out_channels']),
])

In [10]:
# Sanity check: the training pipeline is a MONAI Compose object.
type(train_transforms)

monai.transforms.compose.Compose

In [18]:
# Re-display the hyperparameters.
hparams_dict

{'device': 'cpu',
 'epochs': 4,
 'finetune_frac': 0.1,
 'ftune_batchsize': 64,
 'in_channels': 1,
 'loss': 'CrossEntropyLoss',
 'lr': 1e-05,
 'num_workers': 2,
 'optimizer': 'AdamW',
 'out_channels': 5,
 'spatial_dims': 2,
 'test_frac': 0.1,
 'train_batchsize': 64,
 'val_interval': 1}

In [11]:
# Check which keys a split entry provides (the Dataset defaults expect "image" and "label").
data_split['train'].keys()

dict_keys(['image', 'label'])

In [12]:
## Dataset!
class MedNIST_Dataset(Dataset):
    """Map-style dataset over a split dict holding parallel image/label lists.

    Item ``i`` is ``{"x": transforms(data_dict[image_key][i]),
    "y": int(data_dict[label_key][i])}``.

    Args:
        data_dict: Mapping with two equally long sequences stored under
            ``image_key`` and ``label_key``.
        transforms: Callable applied to each image entry, e.g. a MONAI
            ``Compose`` that loads the image from its path.
        image_key: Key of the image sequence in ``data_dict``.
        label_key: Key of the label sequence in ``data_dict``.

    Raises:
        KeyError: If ``image_key`` or ``label_key`` is missing from ``data_dict``.
        ValueError: If the image and label sequences differ in length.
    """

    def __init__(
            self,
            data_dict: Dict,
            transforms: Callable[[Any], Any],
            image_key: str = "image",
            label_key: str = "label",
            ) -> None:
        # Validate up front so a malformed split fails here, not as an
        # IndexError inside a DataLoader worker halfway through an epoch.
        missing = [key for key in (image_key, label_key) if key not in data_dict]
        if missing:
            raise KeyError(f"data_dict is missing key(s): {missing}")
        n_images = len(data_dict[image_key])
        n_labels = len(data_dict[label_key])
        if n_images != n_labels:
            raise ValueError(
                f"'{image_key}' has {n_images} entries but '{label_key}' has {n_labels}"
            )
        self.data = data_dict
        self.transform = transforms
        self.image_key = image_key
        self.label_key = label_key

    def __len__(self) -> int:
        return len(self.data[self.image_key])

    def __getitem__(self, index: int) -> Dict[str, Any]:
        return {
            "x": self.transform(self.data[self.image_key][index]),
            "y": int(self.data[self.label_key][index]),
        }

In [13]:
# Training data goes through the augmenting pipeline; fine-tune data through
# the deterministic one.
train_ds = MedNIST_Dataset(
    data_dict=data_split['train'],
    transforms=train_transforms,
)

ftune_ds = MedNIST_Dataset(
    data_dict=data_split['ftune'],
    transforms=ftune_transforms,
)

## Dataloaders!
# NOTE(review): hyperparam.yml sets num_workers=2, but 0 (load in the main
# process) is used here; presumably deliberate for notebook use, since
# notebook-defined Dataset classes can fail to pickle into workers. Confirm.
train_dl = torch.utils.data.DataLoader(
    train_ds,
    batch_size=hparams_dict['train_batchsize'],
    shuffle=True,  # reshuffle each epoch so batches don't follow split-file order
    num_workers=0,
)

ftune_dl = torch.utils.data.DataLoader(
    ftune_ds,
    batch_size=hparams_dict['ftune_batchsize'],
    num_workers=0,
)

In [14]:
# Number of training samples.
len(train_ds)

47163

In [15]:
# Smoke test: iterate the whole training loader once so every image is loaded
# and augmented; batches are discarded. Note that `idx` and `batch` stay bound
# at module level after the loop.
for idx, batch in enumerate(tqdm(train_dl)):
    continue

  0%|          | 0/737 [00:00<?, ?it/s]

In [16]:
# Smoke test for the fine-tune loader. Afterwards `batch` holds the last
# fine-tune batch, which the next cell inspects.
for idx, batch in enumerate(tqdm(ftune_dl)):
    continue

  0%|          | 0/93 [00:00<?, ?it/s]

In [17]:
# Shape of the last fine-tune batch. It can be smaller than ftune_batchsize
# because the DataLoader keeps the final partial batch (drop_last is not set).
batch['x'].shape

torch.Size([7, 1, 64, 64])