In [1]:
import os
import numpy as np
from tqdm import tqdm

import torch
from torch.optim.lr_scheduler import CosineAnnealingLR

import torchio as tio
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module="torchio")

from utils.utils import train_3d, validate_3d
from utils.vis import plot_mri
from utils.dataset import BrainMRIDataset
from utils.loss import DiceCrossEntropyLoss, DiceFocalLoss
from models import UNet3D, AttentionUNet
from monai.networks.nets import UNet, SegResNet, UNETR

## Constants

In [2]:
# Run configuration: everything a reader is likely to tune lives in this cell.
ROOT_DIR = './Data'   # dataset root; 'train' and 'val' sub-directories are joined onto it below
BATCH_SIZE = 1
EPOCHS = 300
NUM_CLASSES = 4       # forwarded to train_3d / validate_3d
NUM_WORKERS = 16      # DataLoader worker processes

# Pick the compute device, preferring Apple MPS, then CUDA, then CPU.
# `torch.backends.mps.is_available()` is the documented availability check;
# `torch.mps.is_available()` is missing on older PyTorch releases and would
# raise AttributeError there.
if torch.backends.mps.is_available():
    DEVICE = 'mps'
elif torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

## Transforms

In [3]:
# TorchIO pipelines.
# Training: random affine augmentation (scales in 0.9-1.1) followed by
# intensity rescaling to [0, 1]. The remaining augmentations are disabled
# but kept here so they can be switched back on quickly.
train_transform = tio.Compose(
    [
        tio.RandomAffine(scales=(0.9, 1.1)),
        # tio.RandomElasticDeformation(num_control_points=(7, 7, 7), max_displacement=(4, 4, 4)),
        # tio.RandomFlip(axes=(0, 1, 2)),
        # tio.RandomBiasField(coefficients=(0.1, 0.3)),
        tio.RescaleIntensity(out_min_max=(0, 1)),
    ]
)

# Validation: no augmentation, only the same [0, 1] intensity rescaling.
val_transform = tio.Compose(
    [
        tio.RescaleIntensity(out_min_max=(0, 1)),
    ]
)

In [4]:
# One dataset per split; each split is read from its own sub-directory of
# ROOT_DIR and gets the matching transform pipeline.
train_dataset = BrainMRIDataset(
    os.path.join(ROOT_DIR, 'train'),
    transform=train_transform,
)
val_dataset = BrainMRIDataset(
    os.path.join(ROOT_DIR, 'val'),
    transform=val_transform,
)

# Loaders: only the training split is shuffled; batch size and worker count
# come from the constants cell.
train_loader = tio.SubjectsLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)
val_loader = tio.SubjectsLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
)

## Models

In [5]:
# Alternative architectures tried earlier; left commented out for reference.
# model = SegResNet(
#     spatial_dims=3, 
#     init_filters=16, 
#     in_channels=1, 
#     out_channels=4, 
#     dropout_prob=None, 
#     act=('RELU', {'inplace': True}), 
#     norm=('GROUP', {'num_groups': 8}), 
#     norm_name='', 
#     num_groups=8, 
#     use_conv_final=True, 
#     blocks_down=(1, 2, 2, 4), 
#     blocks_up=(1, 1, 1)
# )

# model = UNet(
#     spatial_dims=3,
#     in_channels=1,
#     out_channels=4,  
#     channels=(16, 32, 64, 128, 256),
#     strides=(2, 2, 2, 2),
#     num_res_units=4,
# )

# Active model: 3D MONAI UNet with a single input channel.
# The output channel count is taken from NUM_CLASSES (the same constant that
# is passed to train_3d / validate_3d) instead of a separate literal, so the
# network head and the metric code cannot drift apart.
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=NUM_CLASSES,
    channels=(32, 64, 128, 256, 512),
    strides=(2, 2, 2, 2),
    num_res_units=4,
    norm="instance",  # Use Instance Normalization
    dropout=0.2       # Add dropout
)
# model = UNETR(in_channels=1, out_channels=4, img_size=(256,128,258), feature_size=32, norm_name='batch')

# Move the parameters to the device selected in the constants cell.
model = model.to(DEVICE)

## Loss & Optimizer

In [6]:
# Earlier experiment (weighted Dice + cross-entropy), kept for reference:
# class_weights = train_dataset.calculate_class_weights_log(num_classes=4).to(DEVICE)
# criterion = DiceCrossEntropyLoss(dice_weight=1.0, ce_weight=0.0, is_3d=True, class_weights=class_weights)

# Loss: combined Dice + focal loss with one alpha entry per class.
criterion = DiceFocalLoss(
    alpha=[0.05, 0.5, 0.3, 0.3],
    gamma=2,
    is_3d=True,
    ignore_background=False,
)

# Optimiser: Adam over all model parameters.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=0.01,
)

# Schedule: cosine annealing over the full run (T_max=EPOCHS), floor at 1e-7.
scheduler = CosineAnnealingLR(
    optimizer,
    T_max=EPOCHS,
    eta_min=1e-7,
)

## Training

In [None]:
# Train for EPOCHS epochs, validating after each one and checkpointing the
# weights whenever the validation average Dice beats the best seen so far.
best_avg_dice = 0
for epoch in range(EPOCHS):
    # One pass over the training set.
    (
        train_avg_loss,
        train_avg_dice,
        train_csf_dice,
        train_gm_dice,
        train_wm_dice,
    ) = train_3d(
        model, train_loader, criterion, optimizer,
        DEVICE, epoch, EPOCHS, NUM_CLASSES,
    )

    # One pass over the validation set.
    (
        val_avg_loss,
        val_avg_dice,
        val_csf_dice,
        val_gm_dice,
        val_wm_dice,
    ) = validate_3d(
        model, val_loader, criterion,
        DEVICE, epoch, EPOCHS, NUM_CLASSES,
    )

    # The learning-rate schedule advances once per epoch.
    scheduler.step()

    # Strict comparison: ties with the current best do not overwrite the checkpoint.
    if val_avg_dice > best_avg_dice:
        best_avg_dice = val_avg_dice
        torch.save(model.state_dict(), 'best_model_3d.pth')
        print(f'Best model saved with dice score: {best_avg_dice}\n')