In [9]:
!pip install torch monai nibabel scikit-learn matplotlib

Collecting monai
  Downloading monai-1.4.0-py3-none-any.whl.metadata (11 kB)
Collecting scikit-learn
  Downloading scikit_learn-1.6.1-cp310-cp310-macosx_10_9_x86_64.whl.metadata (31 kB)
Collecting matplotlib
  Downloading matplotlib-3.10.1-cp310-cp310-macosx_10_12_x86_64.whl.metadata (11 kB)
Collecting scipy>=1.6.0 (from scikit-learn)
  Downloading scipy-1.15.2-cp310-cp310-macosx_14_0_x86_64.whl.metadata (61 kB)
Collecting joblib>=1.2.0 (from scikit-learn)
  Using cached joblib-1.4.2-py3-none-any.whl.metadata (5.4 kB)
Collecting threadpoolctl>=3.1.0 (from scikit-learn)
  Downloading threadpoolctl-3.6.0-py3-none-any.whl.metadata (13 kB)
Collecting contourpy>=1.0.1 (from matplotlib)
  Downloading contourpy-1.3.1-cp310-cp310-macosx_10_9_x86_64.whl.metadata (5.4 kB)
Collecting cycler>=0.10 (from matplotlib)
  Using cached cycler-0.12.1-py3-none-any.whl.metadata (3.8 kB)
Collecting fonttools>=4.22.0 (from matplotlib)
  Downloading fonttools-4.57.0-cp310-cp310-macosx_10_9_x86_64.whl.metadata

Imports

In [None]:
import torch 
import glob 
import os
import nibabel as nib
import numpy as np
from sklearn.model_selection import train_test_split

import monai
from monai.transforms import (
    Compose,
    Rand3DElasticd,
    SpatialPadd,
    RandFlipd,
    RandSpatialCropd
)
from monai.data import Dataset, DataLoader
from monai.networks.nets import UNet

import time
from pytorchtools import EarlyStopping
from monai.losses import SSIMLoss as SSIM

from monai.utils import progress_bar

import matplotlib.pyplot as plt
from monai.inferers import sliding_window_inference
from pathlib import Path

Model variables

In [None]:
# Hyperparameter ranges considered for this model (kept here for reference):
#   patch_size: 16, 32            batch_size: 32, 64, 128
#   lr:         0.0001, 0.001, 0.01, 0.05
#   filter_num: 16, 32, 64        depth: 3, 4        num_conv: 2, 3
#
# Every later cell consumes these settings as single scalars (for example
# `patch_size/2`, `range(depth)`, `float(lr)` and the DataLoader `batch_size`),
# so exactly one value per setting is selected for a run. Edit the values below
# to train a different configuration.
patch_size = 32     # edge length of the cubic patch, in voxels
batch_size = 32     # samples per training batch
lr = 0.0001         # learning rate handed to the optimizer
filter_num = 16     # number of filters at the first level of the network
depth = 3           # number of levels in the network
num_conv = 2        # only used in the output directory name by the cells shown
loss_func = "mae"   # label used in the output directory name

In [None]:
# Results directory for this run; its name encodes the hyperparameters above.
output_dir = f"output/patch-{patch_size}_batch-{batch_size}_LR-{lr}_filter-{filter_num}_depth-{depth}_convs-{num_conv}_loss-{loss_func}/"
# exist_ok=True keeps the cell safe to re-run and avoids the check-then-create race.
os.makedirs(output_dir, exist_ok=True)

Resources

In [None]:
# Select the compute device: CUDA when PyTorch reports one, otherwise the CPU.
# pin_memory mirrors CUDA availability and is handed to the DataLoaders later.
if torch.cuda.is_available():
    pin_memory, device = True, torch.device("cuda")
else:
    pin_memory, device = False, torch.device("cpu")


Collect the paired gad / nogad image paths from the preprocessed input folder.
TODO: change `bids_root` to the path of the preprocessing output.

In [None]:
def find_paired_scans(root):
    """Collect one gad/nogad file pair per subject folder under ``root``.

    Subject folders are matched with ``sub-*`` and scans are looked up in each
    subject's ``anat`` sub-folder.

    Args:
        root: directory that contains the ``sub-*`` subject folders.

    Returns:
        A list of ``{"image": gad_path, "label": nogad_path}`` dicts, ordered by
        subject folder name. Subjects missing either scan are skipped.
    """
    pairs = []
    for sub in sorted(glob.glob(os.path.join(root, "sub-*"))):
        anat_dir = os.path.join(sub, "anat")
        nogad_images = sorted(glob.glob(os.path.join(anat_dir, "*nogad.nii.gz")))
        # "*gad.nii.gz" also matches every "*nogad.nii.gz" file, so those must be
        # removed or a nogad scan could be selected as the gad input.
        gad_images = sorted(
            path
            for path in glob.glob(os.path.join(anat_dir, "*gad.nii.gz"))
            if path not in nogad_images
        )
        if gad_images and nogad_images:
            # Sorted lists make the choice of file deterministic across runs.
            pairs.append({"image": gad_images[0], "label": nogad_images[0]})
    return pairs


bids_root = "/inputs/rigid/"
subjects = sorted(glob.glob(os.path.join(bids_root, "sub-*")))

data_dicts = find_paired_scans(bids_root)

print("Loaded", len(data_dicts), "paired samples.")

In [None]:
# Derives the patch geometry from patch_size and opens the training and
# validation sample arrays as float32 memory maps. Later cells read np_tr,
# np_va and dims_tuple from here.
# NOTE(review): data_dicts entries are built as {"image": ..., "label": ...}
# dicts, yet fname_tr / fname_va are passed straight to os.path.getsize and
# np.memmap — confirm that a file path was intended.
# NOTE(review): this looks carried over from a pipeline that stored
# pre-extracted patches in flat binary files — confirm it is still needed.
fname_tr=data_dicts[0]# training file (first entry of data_dicts)
radius_actual = [int(patch_size/2-1)]*3 # c3d-style patch radius per axis, e.g. a 32^3 patch size gives 15
patch_radius= np.array(radius_actual) # per-axis patch radius as a NumPy array
dims = 1+2*patch_radius # per-axis extent: 2 * radius + 1
dims_tuple = (patch_size,)*3
k = 2  # Number of channels
bps = (4 * k * np.prod(dims)) # Bytes per sample: 4 bytes (float32) * channels * voxels
np_tr = os.path.getsize(fname_tr) // bps  # Number of samples: file size divided by bytes per sample
arr_shape_tr= (int(np_tr),dims[0],dims[1],dims[2], k)
# 'r+' opens the existing file for reading and writing
arr_train = np.memmap(fname_tr,'float32','r+',shape=arr_shape_tr)

fname_va=data_dicts[1] # validation file (second entry of data_dicts)
np_va = os.path.getsize(fname_va) // bps      # Number of samples
arr_shape_va= (int(np_va),dims[0],dims[1],dims[2], k)
arr_val= np.memmap(fname_va,'float32','r+',shape=arr_shape_va)

Split into train, test, validate

In [None]:
# Two-stage split with a fixed seed: hold out 15% of all pairs for testing,
# then take the validation set out of the remaining 85%.
train_val, test = train_test_split(data_dicts, random_state=42, test_size=0.15)

# 0.176 of the remaining 85% is roughly 15% of the full dataset
train, val = train_test_split(train_val, random_state=42, test_size=0.176)

print(f"Train: {len(train)}, Val: {len(val)}, Test: {len(test)}")

Define transforms 

In [None]:
# The datasets hold file paths, so every pipeline must first read the volumes
# from disk and add a channel axis (the network is built with in_channels=1).
# LoadImaged / EnsureChannelFirstd are reached through the already-imported
# `monai` package.

# want to train with patches
train_transforms = Compose([
    monai.transforms.LoadImaged(keys=("image", "label")),
    monai.transforms.EnsureChannelFirstd(keys=("image", "label")),
    # pad so that every volume is at least one patch wide along each axis
    SpatialPadd(keys=("image", "label"), spatial_size=dims_tuple),
    # the same random deformation / flips are applied to image and label
    Rand3DElasticd(keys=("image", "label"), sigma_range=(0.5, 1), magnitude_range=(0.1, 0.4), prob=0.4, shear_range=(0.1, -0.05, 0.0, 0.0, 0.0, 0.0), scale_range=0.5, padding_mode="zeros"),
    RandFlipd(keys=("image", "label"), prob=0.5, spatial_axis=1),
    RandFlipd(keys=("image", "label"), prob=0.5, spatial_axis=0),
    RandFlipd(keys=("image", "label"), prob=0.5, spatial_axis=2),
    # explicit 3-D roi so the crop is a cubic patch of patch_size per axis
    RandSpatialCropd(keys=["image", "label"], roi_size=dims_tuple, random_center=True, random_size=False)
])

# want to validate and test with whole images
val_transforms = Compose([
    monai.transforms.LoadImaged(keys=("image", "label")),
    monai.transforms.EnsureChannelFirstd(keys=("image", "label")),
    SpatialPadd(keys=("image", "label"), spatial_size=dims_tuple)
])


Set up datasets and data loader with monai

In [None]:

# Wrap each split in a MONAI Dataset; training uses the augmenting pipeline,
# validation and test use the deterministic one.
train_ds = Dataset(data=train, transform=train_transforms)
val_ds = Dataset(data=val, transform=val_transforms)
test_ds = Dataset(data=test, transform=val_transforms)

# Never request more worker processes than the machine has CPUs.
train_workers = min(32, os.cpu_count() or 1)

# shuffle=True so the sample order differs between epochs.
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=train_workers, pin_memory=pin_memory)

# val and test on whole brain data (one volume per batch, fixed order)
val_loader = DataLoader(val_ds, batch_size=1, num_workers=2, pin_memory=pin_memory)
test_loader = DataLoader(test_ds, batch_size=1, num_workers=2, pin_memory=pin_memory)

Model definition 

In [None]:
# calculate channels and strides based on given depth

# Channel widths: start at filter_num and double at every level. Computed
# without modifying filter_num so the cell can be re-run safely.
channels = [filter_num * 2 ** level for level in range(depth)]
print("channels: ", channels)

# One stride-2 step between consecutive levels. MONAI's UNet expects
# len(strides) == len(channels) - 1, so no extra entry is appended.
strides = [2] * (depth - 1)
print("strides: ", strides)

# define model
model = UNet(
    spatial_dims=3,  # current name of the former `dimensions` argument
    in_channels=1,
    out_channels=1,
    channels=channels,
    strides=strides,
    num_res_units=2,
    dropout=0.2,
    norm='BATCH'
).to(device)

# Count of parameters that receive gradients; written to the stats file later.
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

Define optimizer, early stopping and loss function

In [None]:
# Steps per epoch are the number of batches the loaders actually yield.
training_steps = len(train_loader)  # number of training steps per epoch
validation_steps = len(val_loader)
learning_rate = float(lr)
betas = (0.5, 0.999)
optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate, betas=betas)
patience = 22 # epochs it will take for training to terminate if no improvement
early_stopping = EarlyStopping(patience=patience, verbose=True, path = f'{output_dir}/checkpoint.pt')
max_epochs = 800

# Mean absolute error between the network output and the nogad target.
loss = torch.nn.L1Loss().to(device)

# Index 0 of each history is an inf placeholder so that `[-1]` is valid before
# the first epoch has finished.
train_losses = [float('inf')]
val_losses = [float('inf')]
test_losses = [float('inf')]

Train model

In [None]:
# Local alias for the time module: the name `time` is rebound to the elapsed
# seconds at the end of this cell, which would otherwise break a re-run.
import time as timer

start = timer.time()

for epoch in range(max_epochs):
    print(f"\nEpoch {epoch + 1}/{max_epochs}")

    progress_bar(
        index=epoch+1, # current epoch number, shown with the latest training and validation loss
        count = max_epochs,
        desc= f"epoch {epoch + 1}, training mae loss: {train_losses[-1]:.4f}, validation mae metric: {val_losses[-1]:.4f}",
        newline = True) # progress bar to display current stage in training

    # training
    model.train()  # re-enable dropout / batch-norm updates after the previous validation pass
    running_train_loss = 0.0
    num_train_batches = 0
    for batch in train_loader:
        # image is gad image, label is nogad image
        gad_images, nogad_images = batch["image"].to(device), batch["label"].to(device)
        optimizer.zero_grad()
        degad_images = model(gad_images)
        train_loss = loss(degad_images, nogad_images)
        train_loss.backward()
        optimizer.step()

        running_train_loss += train_loss.item()
        num_train_batches += 1

    # one entry per epoch: mean loss over the batches actually processed
    avg_train_loss = running_train_loss / max(num_train_batches, 1)
    train_losses.append(avg_train_loss)

    # validation
    model.eval()  # switched once per epoch, after all training batches
    running_val_loss = 0.0
    num_val_batches = 0
    with torch.no_grad(): # no weight updates during validation, only used to assess current state of model
        for batch in val_loader:
            gad_images, nogad_images = batch["image"].to(device), batch["label"].to(device)
            degad_images = model(gad_images)

            val_loss = loss(degad_images, nogad_images)
            running_val_loss += val_loss.item()  # accumulate a float, not a tensor
            num_val_batches += 1

    avg_val_loss = running_val_loss / max(num_val_batches, 1) # average val loss for this epoch
    val_losses.append(avg_val_loss)
    early_stopping(avg_val_loss, model) # early stopping keeps track of last best model

    if early_stopping.early_stop: # stops early if validation loss has not improved for {patience} number of epochs
        print("Early stopping")
        break

end = timer.time()
training_time = end - start
time = training_time  # the stats cell reads the elapsed seconds from `time`
print("time for training and validation: ", time)

    

Plot training metrics 

In [None]:
# Epoch with the lowest validation loss — presumably the state EarlyStopping
# checkpointed (verify against pytorchtools). Index 0 of both histories is the
# inf placeholder. The index is clamped so a shorter training history cannot
# raise IndexError.
best_epoch = int(np.argmin(val_losses))
best_train_loss = train_losses[min(best_epoch, len(train_losses) - 1)]

with open (f'{output_dir}/model_stats.txt', 'w') as file:
    file.write(f'Training time: {time}\n')
    file.write(f'Number of trainable parameters: {trainable_params}\n')
    file.write(f'Training loss: {best_train_loss} \nValidation loss: {early_stopping.val_loss_min}')

# Plot after the stats file has been closed; the figure does not need it open.
fig, ax = plt.subplots(figsize=(12,5))
ax.plot(list(range(len(train_losses))), train_losses, label="Training Loss")
ax.plot(list(range(len(val_losses))), val_losses, label="Validation Loss")
ax.set_title("Training and validation loss")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.grid(True, "both", "both")
ax.legend()
fig.savefig(f'{output_dir}/lossfunction.png')

Test model with sliding window 

In [None]:
# Restore the checkpoint written by early stopping; map_location lets a
# checkpoint saved on a GPU be loaded on a CPU-only machine.
model.load_state_dict(torch.load(f'{output_dir}/checkpoint.pt', map_location=device))
model.eval()

output_dir_test = Path(output_dir) / "test"
output_dir_test.mkdir(parents=True, exist_ok=True)

test_loss = 0.0  # running sum of the per-volume losses
with torch.no_grad():
    for batch in test_loader:
        gad_images, nogad_images = batch["image"].to(device), batch["label"].to(device)
        # run the network patch by patch over the whole volume
        degad_images = sliding_window_inference(gad_images, patch_size, 1, model)

        loss_val = loss(degad_images, nogad_images)

        test_loss += loss_val.item()

        # Source file names: older MONAI versions expose them under
        # "image_meta_dict", newer ones on the loaded tensor's `.meta`.
        if "image_meta_dict" in batch:
            filenames = batch["image_meta_dict"]["filename_or_obj"]
        else:
            filenames = batch["image"].meta["filename_or_obj"]

        # to save the output files
        for i in range(degad_images.shape[0]):
            gad_path = filenames[i]
            gad_nib = nib.load(gad_path)
            sub = Path(gad_path).name.split("_")[0]  # text before the first "_" in the file name
            degad_name = f"{sub}_acq-degad_T1w.nii.gz"

            # NOTE(review): the factor 100 presumably undoes an intensity scaling
            # applied during preprocessing — confirm.
            degad_nib = nib.Nifti1Image(
                degad_images[i, 0].cpu().numpy() * 100,
                affine=gad_nib.affine,
                header=gad_nib.header
            )

            save_path = output_dir_test / degad_name
            nib.save(degad_nib, str(save_path))

# max(..., 1) avoids a ZeroDivisionError when the test split is empty
print(f"Test Loss: {test_loss / max(len(test_loader), 1):.4f}")