In [1]:
!pip install torch monai nibabel scikit-learn matplotlib

Defaulting to user installation because normal site-packages is not writeable
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com


Imports

In [1]:
import torch 
import glob 
import os
import nibabel as nib
from sklearn.model_selection import train_test_split

import monai
from monai.transforms import (
    Compose,
    LoadImaged,
    EnsureChannelFirstd,
    Rand3DElasticd,
    ScaleIntensityd,
    SpatialPadd,
    CenterSpatialCropd,
    RandFlipd,
    ToTensord,
    GridPatchd,
    MapTransform
)
from monai.data import Dataset, DataLoader
from monai.networks.nets import UNet

import time
# from pytorchtools import EarlyStopping
import numpy as np

from monai.utils import progress_bar

import matplotlib.pyplot as plt
from monai.inferers import sliding_window_inference
from pathlib import Path



Model variables

In [2]:
# To test
# patch_size = (16, 32)
# batch_size = (32, 64, 128)
# lr = (0.0001, 0.001, 0.01)
# filter_num = (16, 32, 64)
# depth = (3, 4)
# num_conv = (2, 3)
# loss_func = "mae"

image_size = 256
patch_size = 64
batch_size = 16
lr = 0.0001
filter = 16
depth = 4
loss_func = "mae"

create output directory 

In [3]:
output_dir = f"output/image-{image_size}_patch-{patch_size}_batch-{batch_size}_LR-{lr}_filter-{filter}_depth-{depth}_loss-{loss_func}/"
if not os.path.exists(output_dir):
    os.makedirs(output_dir)

Resources

In [4]:
# use gpu if available 
pin_memory = torch.cuda.is_available()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


  return torch._C._cuda_getDeviceCount() > 0


Move preprocessed data to one input folder 

In [5]:
# creates a dictionary of pairs of image file paths (str)

input_dir = os.path.expanduser("~/graham/scratch/degad_preprocessed_data")

work_dir = os.path.join(input_dir, "work")
subject_dirs = glob.glob(os.path.join(work_dir, "sub-*"))

subjects = []
for directory in subject_dirs:
    if os.path.isdir(directory): 
        subjects.append(directory)

data_dicts = []
for sub in subjects:   
    gad_images = glob.glob(os.path.join(sub, "ses-pre", "normalize", "*acq-gad*_T1w.nii.gz"))
    print("Found gad images:", gad_images)
    
    nogad_images = glob.glob(os.path.join(sub, "ses-pre", "normalize", "*acq-nongad*_T1w.nii.gz"))
    print("Found nogad images:", nogad_images)
    
    if gad_images and nogad_images:
        data_dicts.append({"image": gad_images[0], "label": nogad_images[0], "image_filepath": gad_images[0]})

print("Loaded", len(data_dicts), "paired samples.")


Found gad images: ['/home/UWO/msnyde26/graham/scratch/degad_preprocessed_data/work/sub-P003/ses-pre/normalize/sub-P003_ses-pre_acq-gad_run-01_desc-normalize_minmax_T1w.nii.gz']
Found nogad images: ['/home/UWO/msnyde26/graham/scratch/degad_preprocessed_data/work/sub-P003/ses-pre/normalize/sub-P003_ses-pre_acq-nongad_run-01_desc-normalized_zscore_T1w.nii.gz']
Found gad images: ['/home/UWO/msnyde26/graham/scratch/degad_preprocessed_data/work/sub-P004/ses-pre/normalize/sub-P004_ses-pre_acq-gad_run-01_desc-normalize_minmax_T1w.nii.gz']
Found nogad images: ['/home/UWO/msnyde26/graham/scratch/degad_preprocessed_data/work/sub-P004/ses-pre/normalize/sub-P004_ses-pre_acq-nongad_run-01_desc-normalized_zscore_T1w.nii.gz']
Found gad images: ['/home/UWO/msnyde26/graham/scratch/degad_preprocessed_data/work/sub-P005/ses-pre/normalize/sub-P005_ses-pre_acq-gad_run-01_desc-normalize_minmax_T1w.nii.gz']
Found nogad images: ['/home/UWO/msnyde26/graham/scratch/degad_preprocessed_data/work/sub-P005/ses-pre/n

Split into train, test, validate

In [6]:
# create split of the image path strings

# 70% train, 15% val, 15% test 

train_val, test = train_test_split(data_dicts, test_size=0.15, random_state=42)

# 0.176 ≈ 15% of the full data
train, val = train_test_split(train_val, test_size=0.176, random_state=42)

print(f"Train: {len(train)}, Val: {len(val)}, Test: {len(test)}")

Train: 42, Val: 9, Test: 9


In [7]:
class SaveImagePath(MapTransform):
    def __init__(self, keys):
        super().__init__(keys)
        
    def __call__(self, data):
        # Storing the file path separately in the 'image_filepath' key
        data['image_filepath'] = data['image']
        return data

Define transforms 

In [8]:
# using transformations from original code 

# set size of image to patch size (patch_size, patch_size, patch_size)
dims_tuple_image = (image_size,)*3
print("dims_tuple: ", dims_tuple_image)
dims_tuple_patch = (patch_size,)*3
print("dims_tuple: ", dims_tuple_patch)

# train tranforms 
train_transforms = Compose([
    LoadImaged(
        keys=["image", "label"], 
    ),  # load image from the file path 
    EnsureChannelFirstd(keys=["image", "label"]), # ensure this is [C, H, W, (D)]
    ScaleIntensityd(keys=["image"]), # scales the intensity from 0-1
    Rand3DElasticd(
        keys = ("image","label"), 
        sigma_range = (0.5,1), 
        magnitude_range = (0.1, 0.4), 
        prob=0.4, 
        shear_range=(0.1, -0.05, 0.0, 0.0, 0.0, 0.0),
        scale_range=0.5, padding_mode= "zeros"
    ),
    RandFlipd(keys = ("image","label"), prob = 0.5, spatial_axis=(0,1,2)),
    SpatialPadd(keys = ("image","label"), spatial_size=dims_tuple_image), #ensure all images are (1,256,256,256) if too small
    CenterSpatialCropd(keys=("image", "label"), roi_size=dims_tuple_image), # ensure all images are (1,256,256,256) if too big
    GridPatchd(
        keys=("image", "label"),
        patch_size=(dims_tuple_patch),
        offset=(0, 0, 0),
        stride=(dims_tuple_patch) # Non-overlapping
    ),
    ToTensord(keys=["image", "label"])
])

# view size of image and label for training
sample_train = train_transforms(train[0])
print("Test image shape:", sample_train["image"].shape)
print("Test label shape:", sample_train["label"].shape)

# want to validate and test with whole images 
val_transforms = Compose([
    SaveImagePath(keys=["image"]),
    LoadImaged(
        keys=["image", "label"]
    ),  # load image
    EnsureChannelFirstd(keys=["image", "label"]),
    ScaleIntensityd(keys=["image"]),
    SpatialPadd(keys = ("image","label"),spatial_size=dims_tuple_image), # ensure data is the same size
    CenterSpatialCropd(keys=("image", "label"), roi_size=dims_tuple_image), # ensure all images are (1,256,256,256) if too big
    ToTensord(keys=["image", "label"])
])

sample_val = val_transforms(val[0])
print("Val image shape:", sample_val["image"].shape)
print("Val label shape:", sample_val["label"].shape)

sample_test = val_transforms(test[0])
print("Test image shape:", sample_test["image"].shape)
print("Test label shape:", sample_test["label"].shape)
print("Image file path:", sample_test["image_filepath"])


dims_tuple:  (256, 256, 256)
dims_tuple:  (64, 64, 64)
Test image shape: torch.Size([64, 1, 64, 64, 64])
Test label shape: torch.Size([64, 1, 64, 64, 64])
Val image shape: torch.Size([1, 256, 256, 256])
Val label shape: torch.Size([1, 256, 256, 256])
Test image shape: torch.Size([1, 256, 256, 256])
Test label shape: torch.Size([1, 256, 256, 256])
Image file path: /home/UWO/msnyde26/graham/scratch/degad_preprocessed_data/work/sub-P003/ses-pre/normalize/sub-P003_ses-pre_acq-gad_run-01_desc-normalize_minmax_T1w.nii.gz


Set up datasets and data loader with monai

In [9]:

train_ds = Dataset(data=train, transform=train_transforms)
val_ds = Dataset(data=val, transform=val_transforms)
test_ds = Dataset(data=test, transform=val_transforms)

# validating, testing of whole data so use a batch size of 1
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=pin_memory)
val_loader = DataLoader(val_ds, batch_size=1, num_workers=2, pin_memory=pin_memory)
test_loader = DataLoader(test_ds, batch_size=1, num_workers=2, pin_memory=pin_memory)

In [None]:
for batch in train_loader:
    print("Image shape:", batch["image"].shape)
    print("Label shape:", batch["label"].shape)
    break


In [67]:
# def visualize_batch(loader, num_samples=6, title="Batch Samples"):

#     plt.figure(figsize=(12, 2 * num_samples))
    
#     for i, batch in enumerate(loader):
#         images, labels = batch["image"], batch["label"]
        
#         # convert to numpy arrays (remove channel dimension)
#         images = images.numpy()
#         labels = labels.numpy()

#         batch_size = images.shape[0]
#         max_to_show = min(num_samples, batch_size)

#         for idx in range(max_to_show):
#             img = np.squeeze(images[idx])  # shape: (H, W, D) or (D, H, W)
#             lbl = np.squeeze(labels[idx])

#             # pick a slice along the last dimension
#             slice_index = img.shape[-1] // 2

#             # Show image
#             plt.subplot(max_to_show, 2, idx * 2 + 1)
#             plt.imshow(img[..., slice_index], cmap="gray")
#             plt.title(f"Image {idx+1}")
#             plt.axis("off")

#             # Show label
#             plt.subplot(max_to_show, 2, idx * 2 + 2)
#             plt.imshow(lbl[..., slice_index], cmap="gray")
#             plt.title(f"Label {idx+1}")
#             plt.axis("off")

#     plt.suptitle(title)
#     plt.tight_layout()
#     plt.show()

# visualize_batch(train_loader, num_samples=6, title="Train Batch Samples")

Model definition 

In [11]:
# calculate channels and strides based on given parameters

filter = 32

channels = []
for i in range(depth):
    channels.append(filter)
    filter *=2
print("channels: ", channels)
strides = []
for i in range(depth-1):
    strides.append(2)
print("strides: ", strides)

# define model 
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    channels=channels,
    strides=strides,
    num_res_units=2,
    dropout=0.2,
    norm='BATCH'
).apply(monai.networks.normal_init).to(device)

trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)


channels:  [32, 64, 128, 256]
strides:  [2, 2, 2]


Define lr, optimization, epochs, loss

In [12]:
learning_rate = float(lr)

# common defaults
betas = (0.5, 0.999)

optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate, betas=betas)

# patience = 22 # epochs it will take for training to terminate if no improvement
# early_stopping = EarlyStopping(patience=patience, verbose=True, path = f'{output_dir}/checkpoint.pt')
best_model_path = f"{output_dir}/best_model.pt"
max_epochs = 1

loss = torch.nn.L1Loss().to(device)

train_losses = [float('inf')]
val_losses = [float('inf')]
best_val_loss = float('inf')
test_loss = 0

Train model

In [14]:
start = time.time()

for epoch in range(max_epochs):
    print(f"\nEpoch {epoch + 1}/{max_epochs}")
    model.train()

    train_loss_display = f"{train_losses[-1]:.4f}" if train_losses else "N/A"
    val_loss_display = f"{val_losses[-1]:.4f}" if val_losses else "N/A"

    progress_bar(
        index=epoch + 1,
        count=max_epochs,
        desc=f"epoch {epoch + 1}, training mae loss: {train_loss_display}, validation mae metric: {val_loss_display}",
        newline=True
    )

    # training
    avg_train_loss = 0

    print("--------Training--------")
    for batch in train_loader:
        print(f"batch: {batch}")
        # image is gad image, label is nogad image
        gad_images, nogad_images = batch["image"].to(device), batch["label"].to(device)
        B, P, C, D, H, W = gad_images.shape
        optimizer.zero_grad() #resets optimizer to 0
        total_loss = 0

        for patch in range(P):
            print(f"patch: {patch}")
            patch_gad = gad_images[:, patch] 
            patch_nogad = nogad_images[:, patch]
            degad_images = model(patch_gad)
            
            patch_loss = loss(degad_images, patch_nogad)
            total_loss += patch_loss

        total_loss /= P
        total_loss.backward() # computes gradients for each parameter based on loss
        optimizer.step() # updates the model weights using the gradient
        avg_train_loss += total_loss.item() 
    
    avg_train_loss /= len(train_loader) # average loss per current epoch 
    train_losses.append(avg_train_loss) # append total epoch loss divided by the number of training steps in epoch to loss list
    print(f"Epoch {epoch + 1} - Average Training Loss: {avg_train_loss:.4f}")
    model.eval()
    
    print("--------Validation--------")
    # validation 
    with torch.no_grad(): #we do not update weights/biases in validation training, only used to assess current state of model
        avg_val_loss = 0 # will hold sum of all validation losses in epoch and then average
        for batch in val_loader: # iterating through dataloader
            gad_images, nogad_images = batch["image"].to(device), batch["label"].to(device)
            degad_images = model(gad_images)
            degad_images = degad_images[:, :, :255, :255, :255]

            val_loss = loss(degad_images, nogad_images)
            avg_val_loss += val_loss.item() 
        
        avg_val_loss /= len(val_loader) #producing average val loss for this epoch
        val_losses.append(avg_val_loss) 

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), best_model_path)
            print(f"New best model saved with val loss: {best_val_loss:.4f}")
        # early_stopping(avg_val_loss, model) # early stopping keeps track of last best model

    # if early_stopping.early_stop: # stops early if validation loss has not improved for {patience} number of epochs
    #     print("Early stopping, saving model") 
    #     break


end = time.time()
total_time = end - start
print("time for training and validation: ", total_time)


Epoch 1/1

--------Training--------
batch: {'image': metatensor([[[[[[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
             0.0000e+00, 0.0000e+00],
            [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
             0.0000e+00, 0.0000e+00],
            [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
             0.0000e+00, 0.0000e+00],
            ...,
            [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
             0.0000e+00, 0.0000e+00],
            [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
             0.0000e+00, 0.0000e+00],
            [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
             0.0000e+00, 0.0000e+00]],

           [[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
             0.0000e+00, 0.0000e+00],
            [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
             0.0000e+00, 0.0000e+00],
            [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
             0.0

: 

Plot training metrics 

In [85]:
# from original code 

with open (f'{output_dir}/model_stats.txt', 'w') as file:  
    file.write(f'Training time: {time}\n') 
    file.write(f'Number of trainable parameters: {trainable_params}\n')
    # file.write(f'Training loss: {train_losses[-patience]} \nValidation loss: {early_stopping.val_loss_min}')
    
plt.figure(figsize=(12,5))
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.plot(list(range(len(train_losses))), train_losses, label="Training Loss")
plt.plot(list(range(len(val_losses))),val_losses , label="Validation Loss")
plt.grid(True, "both", "both")
plt.title("Training and Validation Loss")
plt.legend()
plt.savefig(f'{output_dir}/lossfunction.png')
plt.close()

Test model with sliding window 

In [140]:
model.load_state_dict(torch.load(f'{output_dir}/checkpoint.pt'))
model.eval()

output_dir_test = Path(output_dir) / "test"
output_dir_test.mkdir(parents=True, exist_ok=True)

with torch.no_grad():
    for i, batch in enumerate(test_loader):      
        gad_images, nogad_images = batch["image"].to(device), batch["label"].to(device)
        gad_paths = batch["image_filepath"]
        degad_images = sliding_window_inference(gad_images, image_size, 1, model)
        degad_images = degad_images[:, :, :255, :255, :255]

        loss_value = loss(degad_images, nogad_images)

        test_loss += loss_value.item()

        # to save the output files 
        # shape[0] gives number of images 
        for j in range(degad_images.shape[0]):
            gad_path = gad_paths[j] # test dictionary image file name
            print(gad_path)
            gad_nib = nib.load(gad_path)
            sub = Path(gad_path).name.split("_")[0]
            degad_name = f"{sub}_acq-degad_T1w.nii.gz"
            
            degad_nib = nib.Nifti1Image(
                degad_images[j, 0].detach().numpy()*100, 
                affine=gad_nib.affine,
                header=gad_nib.header
            )

            os.makedirs(f'{output_dir_test}/bids/{sub}/ses-pre/anat', exist_ok=True) # save in bids format
            output_path = f'{output_dir_test}/bids/{sub}/ses-pre/anat/{degad_name}'
            nib.save(degad_nib, output_path)
    
print(f"Test Loss: {test_loss / len(test_loader):.4f}")

/home/UWO/msnyde26/graham/scratch/degad_preprocessed_data/work/sub-P003/ses-pre/normalize/sub-P003_ses-pre_acq-gad_run-01_desc-normalize_minmax_T1w.nii.gz
/home/UWO/msnyde26/graham/scratch/degad_preprocessed_data/work/sub-P008/ses-pre/normalize/sub-P008_ses-pre_acq-gad_run-01_desc-normalize_minmax_T1w.nii.gz
/home/UWO/msnyde26/graham/scratch/degad_preprocessed_data/work/sub-P043/ses-pre/normalize/sub-P043_ses-pre_acq-gad_run-01_desc-normalize_minmax_T1w.nii.gz
/home/UWO/msnyde26/graham/scratch/degad_preprocessed_data/work/sub-P054/ses-pre/normalize/sub-P054_ses-pre_acq-gad_run-01_desc-normalize_minmax_T1w.nii.gz
/home/UWO/msnyde26/graham/scratch/degad_preprocessed_data/work/sub-P018/ses-pre/normalize/sub-P018_ses-pre_acq-gad_run-01_desc-normalize_minmax_T1w.nii.gz
/home/UWO/msnyde26/graham/scratch/degad_preprocessed_data/work/sub-P063/ses-pre/normalize/sub-P063_ses-pre_acq-gad_run-01_desc-normalize_minmax_T1w.nii.gz
/home/UWO/msnyde26/graham/scratch/degad_preprocessed_data/work/sub-P03