In [1]:
import os
import numpy as np
from tqdm import tqdm

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torch.optim.lr_scheduler import CosineAnnealingLR

import torchio as tio

from utils.utils import train, validate
from utils.vis import plot_mri
from utils.metric import dice_score_3d
from utils.dataset import BrainMRIDataset
from utils.loss import DiceCrossEntropyLoss
from models.Unet import UNet3D

## Constants

In [2]:
ROOT_DIR = './Data/'
BATCH_SIZE = 4
EPOCHS = 300
NUM_CLASSES = 4
NUM_WORKERS=8
DEVICE = 'mps' if torch.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu'

## Transforms

In [3]:
# TorchIO transformations for augmentation
train_transform = tio.Compose([
    tio.RandomAffine(scales=(0.9, 1.1), degrees=(10, 10, 10), translation=(5, 5, 5)),
    tio.RandomElasticDeformation(num_control_points=(7, 7, 7), max_displacement=(4, 4, 4)),
    tio.RandomFlip(axes=(0, 1, 2)),
    tio.RandomBiasField(coefficients=(0.1, 0.3)),
    tio.RescaleIntensity((0, 1))  # Normalize intensity to [0, 1]
])

val_transform = tio.Compose([
    tio.RescaleIntensity((0, 1))  # Only normalize intensity for validation
])

In [4]:
train_dataset = BrainMRIDataset(os.path.join(ROOT_DIR, 'train'), transform=None)
val_dataset = BrainMRIDataset(os.path.join(ROOT_DIR, 'val'), transform=None)

In [5]:
# Create datasets
train_dataset = BrainMRIDataset(os.path.join(ROOT_DIR, 'train'), transform=train_transform)
val_dataset = BrainMRIDataset(os.path.join(ROOT_DIR, 'val'), transform=val_transform)

# Create DataLoaders
train_loader = tio.SubjectsLoader(train_dataset, batch_size=2, shuffle=True, num_workers=NUM_WORKERS)
val_loader = tio.SubjectsLoader(val_dataset, batch_size=2, shuffle=False, num_workers=NUM_WORKERS)

## Models

In [6]:
model = UNet3D(in_channels=1, out_channels=NUM_CLASSES)
model = model.to(DEVICE)

## Loss & Optimizer

In [7]:
criterion = DiceCrossEntropyLoss(is_3d=True)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
scheduler = CosineAnnealingLR(optimizer, T_max=10, eta_min=1e-5)

## Training

In [None]:
# Example training loop
for epoch in range(EPOCHS):
    model.train()
    epoch_loss = 0
    epoch_dice = 0
    progress_bar = tqdm(train_loader, total=len(train_loader), desc=f"Epoch {epoch + 1}/{EPOCHS}")
    for batch in progress_bar:
        images, masks = batch["image"]["data"].to(DEVICE), batch["mask"]["data"].long().to(DEVICE)  # Adjust keys if necessary

        # Forward pass
        outputs = model(images)

        # Compute loss
        loss = criterion(outputs, masks)
        epoch_loss += loss.item()

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Dice score
        dice = dice_score_3d(outputs, masks, NUM_CLASSES)
        
        #update the progress bar
        progress_bar.set_postfix({"Loss": loss.item() / len(batch), "Avg Dice": np.mean(list(dice.values())), "WM Dice": dice[1], "GM Dice": dice[2], "CSF Dice": dice[3]})

    print(f"Epoch {epoch + 1}, Loss: {epoch_loss/len(train_loader):.4f}")
    print(f"Epoch {epoch + 1}, Dice: {np.mean(list(dice.values())):.4f}", f"WM Dice: {dice[1]:.4f}", f"GM Dice: {dice[2]:.4f}", f"CSF Dice: {dice[3]:.4f}")

Epoch 1/300: 100%|██████████| 5/5 [00:52<00:00, 10.45s/it, Loss=0.165, Avg Dice=1.06e-6, WM Dice=1.08e-10, GM Dice=1.35e-6, CSF Dice=2.9e-6]    


Epoch 1, Loss: 0.6077
Epoch 1, Dice: 0.0000 WM Dice: 0.0000 GM Dice: 0.0000 CSF Dice: 0.0000


Epoch 2/300: 100%|██████████| 5/5 [00:50<00:00, 10.17s/it, Loss=0.161, Avg Dice=2.71e-11, WM Dice=1.03e-10, GM Dice=1.54e-12, CSF Dice=3.22e-12]


Epoch 2, Loss: 0.3289
Epoch 2, Dice: 0.0000 WM Dice: 0.0000 GM Dice: 0.0000 CSF Dice: 0.0000


Epoch 3/300: 100%|██████████| 5/5 [00:51<00:00, 10.35s/it, Loss=0.162, Avg Dice=1.19e-11, WM Dice=4.39e-11, GM Dice=1.38e-12, CSF Dice=2.33e-12]


Epoch 3, Loss: 0.3147
Epoch 3, Dice: 0.0000 WM Dice: 0.0000 GM Dice: 0.0000 CSF Dice: 0.0000


