In [1]:
!pip install torch monai nibabel scikit-learn matplotlib

Defaulting to user installation because normal site-packages is not writeable
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com


Imports

In [2]:
import torch 
import glob 
import os
import nibabel as nib
from sklearn.model_selection import train_test_split
from nibabel.processing import resample_from_to

import monai
from monai.transforms import (
    Compose,
    LoadImaged,
    EnsureChannelFirstd,
    Rand3DElasticd,
    ScaleIntensityd,
    SpatialPadd,
    CenterSpatialCropd,
    RandFlipd,
    ToTensord,
    MapTransform,
    Resize
)
from monai.data import Dataset, DataLoader
from monai.networks.nets import UNet

import time
import numpy as np

from monai.utils import progress_bar

import matplotlib.pyplot as plt
from monai.inferers import sliding_window_inference
from pathlib import Path

Model variables

In [3]:
# Hyper-parameter grid still to be explored (not used by this run):
#   patch_size = (16, 32)
#   batch_size = (32, 64, 128)
#   lr         = (0.0001, 0.001, 0.01)
#   filter_num = (16, 32, 64)
#   depth      = (3, 4)
#   num_conv   = (2, 3)
#   loss_func  = "mae"

# Settings for this run; they are also embedded in the output directory name.
image_size = 256   # volumes are padded / center-cropped to image_size**3 voxels
batch_size = 1     # intended number of volumes per training step
lr = 0.0001        # Adam learning rate
filter = 64        # channels in the first UNet level (note: shadows the builtin `filter`)
depth = 4          # number of UNet levels
loss_func = "mae"  # loss label used in output paths

Create the output directory

In [4]:

# One output folder per hyper-parameter combination, created on first use.
home_dir = Path.home()
output_dir = (
    "/localscratch/output_whole_images/"
    f"image-{image_size}_batch-{batch_size}_LR-{lr}_filter-{filter}_depth-{depth}_loss-{loss_func}/"
)
output_path = Path(output_dir)
if not output_path.exists():
    output_path.mkdir(parents=True)

Compute resources (use the GPU if available)

In [5]:
# Prefer the GPU when one is visible; pinned host memory only helps CUDA transfers.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
pin_memory = use_cuda


  return torch._C._cuda_getDeviceCount() > 0


Collect paired gad / non-gad image paths from the preprocessed data

In [6]:
# Build one {"image", "label", "image_filepath"} dict (all str paths) per subject
# that has both a gad (model input) and a non-gad (target) normalized T1w image.

input_dir = os.path.expanduser("~/graham/scratch/degad_preprocessed_data")
work_dir = os.path.join(input_dir, "work")


def pair_subject_images(subject_dirs, verbose=True):
    """Return paired gad / non-gad image paths for each subject directory.

    Args:
        subject_dirs: iterable of subject directory paths (``.../sub-*``);
            entries that are not directories are skipped.
        verbose: print the files found for each subject.

    Returns:
        A list of dicts with keys ``"image"`` (gad path), ``"label"`` (non-gad
        path) and ``"image_filepath"`` (gad path again), in sorted subject order.
        Subjects missing either image are omitted.

    ``glob`` returns files in arbitrary, filesystem-dependent order. Sorting both
    the subjects and the matched files makes the result, and therefore the seeded
    train/val/test split built from it, reproducible.
    """
    pairs = []
    for sub in sorted(subject_dirs):
        if not os.path.isdir(sub):
            continue
        normalize_dir = os.path.join(sub, "ses-pre", "normalize")

        gad_images = sorted(glob.glob(os.path.join(normalize_dir, "*acq-gad*_T1w.nii.gz")))
        if verbose:
            print("Found gad images:", gad_images)

        nogad_images = sorted(glob.glob(os.path.join(normalize_dir, "*acq-nongad*_T1w.nii.gz")))
        if verbose:
            print("Found nogad images:", nogad_images)

        if gad_images and nogad_images:
            pairs.append({"image": gad_images[0], "label": nogad_images[0], "image_filepath": gad_images[0]})
    return pairs


subject_dirs = sorted(glob.glob(os.path.join(work_dir, "sub-*")))
subjects = [directory for directory in subject_dirs if os.path.isdir(directory)]

data_dicts = pair_subject_images(subjects)
print("Loaded", len(data_dicts), "paired samples.")


Found gad images: ['/home/UWO/msnyde26/graham/scratch/degad_preprocessed_data/work/sub-P003/ses-pre/normalize/sub-P003_ses-pre_acq-gad_run-01_desc-normalize_minmax_T1w.nii.gz']
Found nogad images: ['/home/UWO/msnyde26/graham/scratch/degad_preprocessed_data/work/sub-P003/ses-pre/normalize/sub-P003_ses-pre_acq-nongad_run-01_desc-normalized_zscore_T1w.nii.gz']
Found gad images: ['/home/UWO/msnyde26/graham/scratch/degad_preprocessed_data/work/sub-P004/ses-pre/normalize/sub-P004_ses-pre_acq-gad_run-01_desc-normalize_minmax_T1w.nii.gz']
Found nogad images: ['/home/UWO/msnyde26/graham/scratch/degad_preprocessed_data/work/sub-P004/ses-pre/normalize/sub-P004_ses-pre_acq-nongad_run-01_desc-normalized_zscore_T1w.nii.gz']
Found gad images: ['/home/UWO/msnyde26/graham/scratch/degad_preprocessed_data/work/sub-P005/ses-pre/normalize/sub-P005_ses-pre_acq-gad_run-01_desc-normalize_minmax_T1w.nii.gz']
Found nogad images: ['/home/UWO/msnyde26/graham/scratch/degad_preprocessed_data/work/sub-P005/ses-pre/n

Split into train, validation, and test sets

In [7]:
# Split the list of path dicts into roughly 70% train / 15% val / 15% test.
# First hold out the test set, then carve validation out of the remaining 85%:
# 0.176 * 0.85 ≈ 0.15 of the full data. random_state fixes both splits for a
# given ordering of data_dicts.
TEST_FRACTION = 0.15
VAL_FRACTION_OF_REMAINDER = 0.176

train_val, test = train_test_split(data_dicts, test_size=TEST_FRACTION, random_state=42)
train, val = train_test_split(train_val, test_size=VAL_FRACTION_OF_REMAINDER, random_state=42)

print(f"Train: {len(train)}, Val: {len(val)}, Test: {len(test)}")

Train: 109, Val: 24, Test: 24


In [8]:
class SaveImagePath(MapTransform):
    """Remember each key's source file path before ``LoadImaged`` replaces it.

    For every key in ``keys``, ``data[key]`` (still a file path at this point in
    the pipeline) is copied to ``data[f"{key}_filepath"]``. With
    ``keys=["image"]`` this produces ``data["image_filepath"]``.

    A shallow copy of ``data`` is returned, so the dicts stored in the Dataset
    are not mutated.

    NOTE(review): the data dicts built earlier already carry "image_filepath",
    so with keys=["image"] this transform currently rewrites the same value.
    """

    def __init__(self, keys):
        super().__init__(keys)

    def __call__(self, data):
        d = dict(data)
        # Honour the configured keys instead of a hardcoded 'image'.
        for key in self.keys:
            d[f"{key}_filepath"] = d[key]
        return d




Define transforms 

In [11]:
# Transforms adapted from the original code.
#
# Every volume is padded / center-cropped to a cube of side `image_size`, so
# training, validation and testing all operate on whole images.
dims_tuple = (image_size,) * 3
print("dims_tuple: ", dims_tuple)

# Training: load, make channel-first, rescale the gad input, random elastic
# deformation and flips (applied to both image and label), then pad/crop.
# NOTE(review): only "image" is intensity-rescaled. Judging by the file names, the
# gad inputs are min-max normalized while the non-gad labels are z-score
# normalized. Confirm this mismatch is intended.
train_transforms = Compose([
    LoadImaged(keys=["image", "label"]),           # load image from the file path
    EnsureChannelFirstd(keys=["image", "label"]),  # ensure [C, H, W, (D)]
    ScaleIntensityd(keys=["image"]),               # rescale intensities to 0-1
    Rand3DElasticd(
        keys=("image", "label"),
        sigma_range=(0.5, 1),
        magnitude_range=(0.1, 0.4),
        prob=0.4,
        shear_range=(0.1, -0.05, 0.0, 0.0, 0.0, 0.0),
        scale_range=0.5,
        padding_mode="zeros",
    ),
    RandFlipd(keys=("image", "label"), prob=0.5, spatial_axis=(0, 1, 2)),
    SpatialPadd(keys=("image", "label"), spatial_size=dims_tuple),     # pad up to (1, S, S, S) if too small
    CenterSpatialCropd(keys=("image", "label"), roi_size=dims_tuple),  # crop down to (1, S, S, S) if too big
    ToTensord(keys=["image", "label"]),
])

# Sanity-check the size of one augmented training sample.
sample_train = train_transforms(train[0])
print("Train image shape:", sample_train["image"].shape)
print("Train label shape:", sample_train["label"].shape)

# Validation / test: same deterministic preprocessing without augmentation.
# The gad file path is kept so saved predictions can be named per subject.
val_transforms = Compose([
    SaveImagePath(keys=["image"]),
    LoadImaged(keys=["image", "label"]),
    EnsureChannelFirstd(keys=["image", "label"]),
    ScaleIntensityd(keys=["image"]),
    SpatialPadd(keys=("image", "label"), spatial_size=dims_tuple),     # pad up if too small
    CenterSpatialCropd(keys=("image", "label"), roi_size=dims_tuple),  # crop down if too big
    ToTensord(keys=["image", "label"]),
])

sample_val = val_transforms(val[0])
print("Val image shape:", sample_val["image"].shape)
print("Val label shape:", sample_val["label"].shape)

sample_test = val_transforms(test[0])
print("Test image shape:", sample_test["image"].shape)
print("Test label shape:", sample_test["label"].shape)
print("Image file path:", sample_test["image_filepath"])


dims_tuple:  (256, 256, 256)
Test image shape: torch.Size([1, 256, 256, 256])
Test label shape: torch.Size([1, 256, 256, 256])
Val image shape: torch.Size([1, 256, 256, 256])
Val label shape: torch.Size([1, 256, 256, 256])
Test image shape: torch.Size([1, 256, 256, 256])
Test label shape: torch.Size([1, 256, 256, 256])
Image file path: /home/UWO/msnyde26/graham/scratch/degad_preprocessed_data/work/sub-P140/ses-pre/normalize/sub-P140_ses-pre_acq-gad_run-01_desc-normalize_minmax_T1w.nii.gz


Set up MONAI datasets and data loaders

In [12]:

train_ds = Dataset(data=train, transform=train_transforms)
val_ds = Dataset(data=val, transform=val_transforms)
test_ds = Dataset(data=test, transform=val_transforms)

# Whole image_size^3 volumes are used at every stage, so memory is the limit.
# Training uses the configured `batch_size`. Batching all len(train) volumes
# together would give a single optimizer step per epoch and exhaust GPU memory.
# Validation and testing process one volume at a time.
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=pin_memory)
val_loader = DataLoader(val_ds, batch_size=1, num_workers=2, pin_memory=pin_memory)
test_loader = DataLoader(test_ds, batch_size=1, num_workers=2, pin_memory=pin_memory)

In [12]:
# Peek at a single test batch to confirm the collated tensor shapes.
batch = next(iter(test_loader), None)
if batch is not None:
    print("Image shape:", batch["image"].shape)
    print("Label shape:", batch["label"].shape)


KeyboardInterrupt: 

In [67]:
# def visualize_batch(loader, num_samples=6, title="Batch Samples"):

#     plt.figure(figsize=(12, 2 * num_samples))
    
#     for i, batch in enumerate(loader):
#         images, labels = batch["image"], batch["label"]
        
#         # convert to numpy arrays (remove channel dimension)
#         images = images.numpy()
#         labels = labels.numpy()

#         batch_size = images.shape[0]
#         max_to_show = min(num_samples, batch_size)

#         for idx in range(max_to_show):
#             img = np.squeeze(images[idx])  # shape: (H, W, D) or (D, H, W)
#             lbl = np.squeeze(labels[idx])

#             # pick a slice along the last dimension
#             slice_index = img.shape[-1] // 2

#             # Show image
#             plt.subplot(max_to_show, 2, idx * 2 + 1)
#             plt.imshow(img[..., slice_index], cmap="gray")
#             plt.title(f"Image {idx+1}")
#             plt.axis("off")

#             # Show label
#             plt.subplot(max_to_show, 2, idx * 2 + 2)
#             plt.imshow(lbl[..., slice_index], cmap="gray")
#             plt.title(f"Label {idx+1}")
#             plt.axis("off")

#     plt.suptitle(title)
#     plt.tight_layout()
#     plt.show()

# visualize_batch(train_loader, num_samples=6, title="Train Batch Samples")

Model definition 

In [13]:
# calculate channels and strides based on given parameters
channels = []
for i in range(depth):
    channels.append(filter)
    filter *=2
print("channels: ", channels)
strides = []
for i in range(depth-1):
    strides.append(2)
print("strides: ", strides)

# define model 
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    channels=channels,
    strides=strides,
    num_res_units=2,
    dropout=0.2,
    norm='BATCH'
).apply(monai.networks.normal_init).to(device)

trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)


channels:  [64, 128, 256, 512]
strides:  [2, 2, 2]


Define lr, optimization, epochs, loss

In [14]:
learning_rate = float(lr)

# Adam with beta1 lowered from its default of 0.9 to 0.5.
betas = (0.5, 0.999)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, betas=betas)

best_model_path = os.path.join(output_dir, "best_model.pt")
max_epochs = 1
patience = 22  # epochs without validation improvement before early stopping

loss = torch.nn.L1Loss().to(device)  # mean absolute error

# Per-epoch loss history. Entry i belongs to epoch i + 1, which is the mapping the
# loss plot uses. Starting with an `inf` placeholder would shift every plotted
# epoch by one.
train_losses = []
val_losses = []
best_val_loss = float('inf')
test_loss = 0

Train model

In [22]:
start = time.time()
# Early-stopping counter. It is initialised here so the `else` branch below never
# reads an undefined name (e.g. when the first validation loss is NaN, `nan < inf`
# is False).
epochs_without_improvement = 0

for epoch in range(max_epochs):
    print(f"\nEpoch {epoch + 1}/{max_epochs}")
    model.train()

    # Show the most recent epoch's losses, or N/A before any epoch has finished.
    train_loss_display = f"{train_losses[-1]:.4f}" if train_losses else "N/A"
    val_loss_display = f"{val_losses[-1]:.4f}" if val_losses else "N/A"

    progress_bar(
        index=epoch + 1,
        count=max_epochs,
        desc=f"epoch {epoch + 1}, training mae loss: {train_loss_display}, validation mae metric: {val_loss_display}",
        newline=True
    )

    # ---- training ----
    avg_train_loss = 0

    print("--------Training--------")
    for batch in train_loader:
        # image is the gad volume (input), label is the non-gad volume (target)
        gad_images, nogad_images = batch["image"].to(device), batch["label"].to(device)
        optimizer.zero_grad()  # clear gradients from the previous step
        degad_images = model(gad_images)
        # Trim the prediction to image_size^3 so it lines up with the label.
        degad_images = degad_images[:, :, :image_size, :image_size, :image_size]

        train_loss = loss(degad_images, nogad_images)
        train_loss.backward()  # gradients of the loss w.r.t. each parameter
        optimizer.step()       # update the weights from those gradients
        avg_train_loss += train_loss.item()

    avg_train_loss /= len(train_loader)  # mean loss per training batch this epoch
    train_losses.append(avg_train_loss)
    model.eval()

    print("--------Validation--------")
    # ---- validation: no weight updates, only measure the current model ----
    with torch.no_grad():
        avg_val_loss = 0  # running sum, turned into a mean after the loop
        for batch in val_loader:
            gad_images, nogad_images = batch["image"].to(device), batch["label"].to(device)
            degad_images = model(gad_images)
            degad_images = degad_images[:, :, :image_size, :image_size, :image_size]

            val_loss = loss(degad_images, nogad_images)
            avg_val_loss += val_loss.item()

        avg_val_loss /= len(val_loader)  # mean validation loss for the epoch
        val_losses.append(avg_val_loss)

        # Keep the weights with the lowest validation loss seen so far.
        if avg_val_loss < best_val_loss:
            print(f"Validation loss improved from {best_val_loss:.4f} to {avg_val_loss:.4f}. Saving model.")
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), best_model_path)
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            print(f"No improvement in validation loss. Best remains {best_val_loss:.4f}.")

        if epochs_without_improvement >= patience:
            print(f"Early stopping at epoch {epoch+1} due to no improvement in validation loss.")
            break

end = time.time()
total_time = end - start
print("time for training and validation: ", total_time)


Epoch 1/1

--------Training--------


: 

Save training statistics and plot the loss curves

In [None]:
# Adapted from the original code: write a short run summary, then save the loss curves.
stats_path = f'{output_dir}/model_stats.txt'
with open(stats_path, 'w') as stats_file:
    stats_file.writelines([
        f'Training time: {total_time:.2f} seconds\n',
        f'Number of trainable parameters: {trainable_params}\n',
    ])

    # With enough history, report the training loss `patience` epochs back from the end.
    if len(train_losses) <= patience:
        stats_file.write(f'Training loss (last epoch): {train_losses[-1]:.4f}\n')
    else:
        stats_file.write(f'Training loss (epoch {-patience}): {train_losses[-patience]:.4f}\n')

    stats_file.write(f'Validation loss (best): {best_val_loss:.4f}\n')

epochs = list(range(1, len(train_losses) + 1))
fig, ax = plt.subplots(figsize=(12, 5))
ax.plot(epochs, train_losses, label="Training Loss")
ax.plot(epochs, val_losses, label="Validation Loss")
ax.set_xlabel("Epochs")
ax.set_ylabel("Loss")
ax.set_title("Training and Validation Loss")
ax.grid(True)
ax.legend()
fig.tight_layout()
fig.savefig(f'{output_dir}/lossfunction.png')
plt.close(fig)

Test the model with sliding-window inference

In [15]:
test_loss = 0
home_dir = Path.home()

# NOTE(review): this loads a checkpoint from a *different* experiment (LR 0.0005,
# SSIM loss) instead of `best_model_path` from the run above, yet the predictions
# are written under this run's `output_dir`. Confirm this is intended.
checkpoint_path = f'{home_dir}/graham/scratch/mri_degad/output_whole_images_new_data/image-256_batch-1_LR-0.0005_filter-64_depth-4_loss-ssim/best_model.pt'
model.load_state_dict(torch.load(checkpoint_path, map_location=torch.device('cpu')))
model.eval()

output_dir_test = Path(output_dir) / "test"
output_dir_test.mkdir(parents=True, exist_ok=True)

with torch.no_grad():
    for i, batch in enumerate(test_loader):
        gad_images, nogad_images = batch["image"].to(device), batch["label"].to(device)
        gad_paths = batch["image_filepath"]

        # Predict the whole volume using image_size^3 windows, one window at a time.
        degad_images = sliding_window_inference(gad_images, roi_size=image_size, sw_batch_size=1, predictor=model)
        # Score the prediction against the non-gad target, not the gad input.
        loss_value = loss(degad_images, nogad_images)
        test_loss += loss_value.item()

        # Save the prediction, the target and the input for every volume in the batch.
        for j in range(degad_images.shape[0]):
            gad_path = gad_paths[j]  # original gad file path for this volume
            print(gad_path)
            sub = Path(gad_path).name.split("_")[0]  # e.g. "sub-P140"

            # NOTE(review): val_transforms padded / center-cropped the arrays to
            # image_size^3, so the original file's affine may not match them exactly.
            gad_affine = nib.load(gad_path).affine

            # Output directory in BIDS layout.
            subject_output_dir = f'{output_dir_test}/bids/{sub}/ses-pre/anat'
            os.makedirs(subject_output_dir, exist_ok=True)

            outputs = (
                (f"{sub}_acq-degad_T1w.nii.gz", degad_images),
                (f"{sub}_acq-nogad_T1w.nii.gz", nogad_images),
                (f"{sub}_acq-gad_T1w.nii.gz", gad_images),
            )
            for file_name, volumes in outputs:
                volume_np = volumes[j, 0].detach().cpu().numpy()
                nib.save(nib.Nifti1Image(volume_np, gad_affine), os.path.join(subject_output_dir, file_name))

print(f"Test Loss: {test_loss / len(test_loader):.4f}")

/home/UWO/msnyde26/graham/scratch/degad_preprocessed_data/work/sub-P140/ses-pre/normalize/sub-P140_ses-pre_acq-gad_run-01_desc-normalize_minmax_T1w.nii.gz
/home/UWO/msnyde26/graham/scratch/degad_preprocessed_data/work/sub-P054/ses-pre/normalize/sub-P054_ses-pre_acq-gad_run-01_desc-normalize_minmax_T1w.nii.gz
/home/UWO/msnyde26/graham/scratch/degad_preprocessed_data/work/sub-P147/ses-pre/normalize/sub-P147_ses-pre_acq-gad_run-01_desc-normalize_minmax_T1w.nii.gz
/home/UWO/msnyde26/graham/scratch/degad_preprocessed_data/work/sub-P152/ses-pre/normalize/sub-P152_ses-pre_acq-gad_run-01_desc-normalize_minmax_T1w.nii.gz
/home/UWO/msnyde26/graham/scratch/degad_preprocessed_data/work/sub-P125/ses-pre/normalize/sub-P125_ses-pre_acq-gad_run-01_desc-normalize_minmax_T1w.nii.gz
/home/UWO/msnyde26/graham/scratch/degad_preprocessed_data/work/sub-P035/ses-pre/normalize/sub-P035_ses-pre_acq-gad_run-01_desc-normalize_minmax_T1w.nii.gz
/home/UWO/msnyde26/graham/scratch/degad_preprocessed_data/work/sub-P10