## Libraries

In [13]:
%pip install pynrrd numpy torch torchvision monai tensorflow wandb nibabel

Defaulting to user installation because normal site-packages is not writeable
Collecting nibabel
  Downloading nibabel-5.3.2-py3-none-any.whl.metadata (9.1 kB)
Collecting wrapt<1.15,>=1.11.0 (from tensorflow)
  Downloading wrapt-1.14.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)
Collecting importlib-resources>=5.12 (from nibabel)
  Downloading importlib_resources-6.5.2-py3-none-any.whl.metadata (3.9 kB)
Downloading nibabel-5.3.2-py3-none-any.whl (3.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.3/3.3 MB[0m [31m82.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading importlib_resources-6.5.2-py3-none-any.whl (37 kB)
Downloading wrapt-1.14.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (77 kB)
Installing collected packages: wrapt, importlib-resources, nibabel
[0m[31mERROR: pip's dependency resolver does not currently take into account all th

In [1]:
import monai
import os
import csv
import numpy as np
import nrrd
import torch
import PIL
import IPython.display
from tqdm import tqdm
import matplotlib.pyplot as plt
from monai.transforms import (
    LoadImage,
    LoadImaged)
from monai.inferers import sliding_window_inference

2025-03-31 15:51:34.484841: I tensorflow/core/util/port.cc:111] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-03-31 15:51:34.522603: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-03-31 15:51:34.522623: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-03-31 15:51:34.522641: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-03-31 15:51:34.530173: I tensorflow/core/platform/cpu_feature_g

In [5]:
def build_dict_ASOCA(data_path, mode="train"):
    # test if mode is correct
    if mode not in ["train", "validation", "test"]:
        raise ValueError(f"Please choose a mode in ['train', 'validation', 'test']. Current mode is {mode}.")

    # create empty dictionary
    dicts = list()

    for clazz in ["Diseased", "Normal"]:
        if mode == "train":
            for index in range(1,17):
                image_path = os.path.join(data_path, clazz, "CTCA", f"{clazz}_{index}.nrrd")
                mask_path = os.path.join(data_path, clazz, "Annotations", f"{clazz}_{index}.nrrd")
                dicts.append({"img": image_path, "mask": mask_path})
        if mode == "validation":
            for index in range(17,21):
                image_path = os.path.join(data_path, clazz, "CTCA", f"{clazz}_{index}.nrrd")
                mask_path = os.path.join(data_path, clazz, "Annotations", f"{clazz}_{index}.nrrd")
                dicts.append({"img": image_path, "mask": mask_path})
        if mode == "test":
            if clazz == "Diseased":
                for index in range(10,20):
                    image_path = os.path.join(data_path, clazz, f"Testset_Disease", f"{index}.nrrd")
                    dicts.append({"img": image_path})
            else:
                for index in range(10):
                    image_path = os.path.join(data_path, clazz, f"Testset_{clazz}", f"{index}.nrrd")
                    dicts.append({"img": image_path})
    return dicts

In [2]:
def build_dict_ASOCA_secret(data_path):
    # create empty dictionary
    dicts = list()

    for index in range(1,21):
        image_path = os.path.join(data_path, f"{index}.img.nii")
        dicts.append({"img": image_path})
    return dicts

In [6]:
# Some examples of image sizes and voxel dimensions

loader = LoadImage(dtype=np.float32, image_only=True)
image = loader(build_dict_ASOCA("ASOCA", mode="train")[0]["img"])
# print(f"input: {train_data_dicts[0]['image']}")
print(f"image shape: {image.shape}")
print(f"image voxel dim: {image.pixdim}")
image = loader(build_dict_ASOCA("ASOCA", mode="validation")[2]["img"])
# print(f"input: {train_data_dicts[0]['image']}")
print(f"image shape: {image.shape}")
print(f"image voxel dim: {image.pixdim}")

image shape: torch.Size([512, 512, 224])
image voxel dim: tensor([0.3867, 0.3867, 0.6250], dtype=torch.float64)
image shape: torch.Size([512, 512, 168])
image voxel dim: tensor([0.4316, 0.4316, 0.6250], dtype=torch.float64)


In [3]:
loader = LoadImage(dtype=np.float32, image_only=True, reader="NibabelReader")
image = loader(build_dict_ASOCA_secret("forDLMIA")[2]["img"])
print(f"image shape: {image.shape}")
print(f"image voxel dim: {image.pixdim}")

image shape: torch.Size([512, 512, 275])
image voxel dim: tensor([0.3613, 0.3613, 0.5000], dtype=torch.float64)


In [13]:
# adjust cache_rate based on how much memory you have. test_dataset is only used for visualization / at end, so dont need to cache it
# Look at the values for pixdim

# Training transforms
train_transform = monai.transforms.Compose([
    # Load the image with monai's image loader:
    monai.transforms.LoadImaged(keys=("img", "mask"), image_only=False),
    # Add channel since transforms expect a channel dimension:
    monai.transforms.EnsureChannelFirstd(keys=['img', 'mask'], channel_dim="no_channel"),
    # Create uniform voxel spacing:
    monai.transforms.Spacingd(keys=["img", "mask"], pixdim=(0.4, 0.4, 0.5), mode=("bilinear", "nearest")),
    # Scale the intensities:
    monai.transforms.ScaleIntensityd(keys=['img'],minv=0, maxv=1),
    # Random flip and rotate:
    monai.transforms.RandFlipd(keys=['img', 'mask'], prob=0.5, spatial_axis=1),
    monai.transforms.RandRotated(keys=['img', 'mask'], range_x=np.pi/4, prob=0.5, mode=['bilinear', 'nearest']),
    # Crop to 128×128×128
    monai.transforms.RandSpatialCropd(keys=['img', 'mask'], roi_size=[256,256,128], random_size=False)
])

# Validation transforms
val_transform = monai.transforms.Compose([
    # Load the image with monai's image loader:
    monai.transforms.LoadImaged(keys=("img", "mask"), image_only=False),
    # Add channel since transforms expect a channel dimension:
    monai.transforms.EnsureChannelFirstd(keys=['img', 'mask'], channel_dim="no_channel"),
    # Create uniform voxel spacing:
    monai.transforms.Spacingd(keys=["img", "mask"], pixdim=(0.4, 0.4, 0.5), mode=("bilinear", "nearest")),
    # Scale the intensities:
    monai.transforms.ScaleIntensityd(keys=['img'],minv=0, maxv=1),
    #monai.transforms.DivisiblePadd(keys=['img', 'mask'], k=64)
])

# Test transforms
test_transform = monai.transforms.Compose([
    # Load the image with monai's image loader:
    monai.transforms.LoadImaged(keys=("img", "mask"), image_only=False),
    # Add channel since transforms expect a channel dimension:
    monai.transforms.EnsureChannelFirstd(keys=['img', 'mask'], channel_dim="no_channel"),
    # Create uniform voxel spacing:
    monai.transforms.Spacingd(keys=["img", "mask"], pixdim=(0.4, 0.4, 0.5), mode=("bilinear", "nearest")),
    # Scale the intensities:
    monai.transforms.ScaleIntensityd(keys=['img'],minv=0, maxv=1),
])


train_dataset = monai.data.CacheDataset(build_dict_ASOCA("ASOCA", mode="train"), transform=train_transform, num_workers = 8)
validation_dataset = monai.data.CacheDataset(build_dict_ASOCA("ASOCA", mode="validation"), transform=val_transform, num_workers =8)
#test_dataset = monai.data.CacheDataset(build_dict_ASOCA("ASOCA", mode="test"), transform=test_transform, cache_rate=0)

Loading dataset: 100%|██████████| 32/32 [01:26<00:00,  2.70s/it]
Loading dataset: 100%|██████████| 8/8 [00:23<00:00,  2.97s/it]


In [11]:
# Secret test transforms
secret_test_transform = monai.transforms.Compose([
    # Load the image with monai's image loader:
    monai.transforms.LoadImaged(keys=("img"), image_only=True),
    # Add channel since transforms expect a channel dimension:
    monai.transforms.EnsureChannelFirstd(keys=['img'], channel_dim="no_channel"),
    # Create uniform voxel spacing:
    monai.transforms.Spacingd(keys=["img"], pixdim=(0.4, 0.4, 0.5), mode=("bilinear")),
    # Scale the intensities:
    monai.transforms.ScaleIntensityd(keys=['img'],minv=0, maxv=1),
])

secret_test_dataset = monai.data.CacheDataset(build_dict_ASOCA_secret("forDLMIA"), transform=secret_test_transform, cache_rate=0, num_workers=8)

In [14]:
# Check if the image size and voxel dimension have correctly changed
sample_dict = train_dataset[20]
print("Size of the image", sample_dict["img"].shape)
print("Voxel dim of the image", sample_dict["img"].pixdim)
print("Size of the mask", sample_dict["mask"].shape)
print("Voxel dim of the mask", sample_dict["mask"].pixdim)

Size of the image torch.Size([1, 256, 256, 128])
Voxel dim of the image tensor([0.4000, 0.4000, 0.5000], dtype=torch.float64)
Size of the mask torch.Size([1, 256, 256, 128])
Voxel dim of the mask tensor([0.4000, 0.4000, 0.5000], dtype=torch.float64)


In [12]:
sample_dict = secret_test_dataset[5]
print("Size of the image", sample_dict["img"].shape)
print("Voxel dim of the image", sample_dict["img"].pixdim)

Size of the image torch.Size([1, 453, 453, 275])
Voxel dim of the image tensor([0.4000, 0.4000, 0.5000], dtype=torch.float64)


In [28]:
# Visualization options
color_actual = [0, 0.5, 0]
color_predicted = [0, 0, 0.5]

def RGB_mask(mask, color):
    result = np.zeros((*mask.shape,3))
    for i in range(3):
        result[...,i]=color[i]*mask
    return result

def RGB_image(image):
    result = image-image.min() # [a,b] -> [0, b-a]
    result = result/result.max() # [0, b-a] -> [0,1] -> [0,255]
    return np.repeat(np.reshape(result[:,:,:], [image.shape[0],image.shape[1],image.shape[2],1]),3,axis=3)

def visualize_3d_gif(image=None, actual_mask=None, predicted_mask=None, name_gif="array.gif"):
    if image is not None:
        image = image.numpy().squeeze()
        rgb_image = RGB_image(image)
        result = rgb_image
    for i in range(2):
        mask = [actual_mask, predicted_mask][i]
        if mask is not None:
            mask = mask.numpy().squeeze()
            rgb_mask = RGB_mask(mask, [color_actual, color_predicted][i])
            if result is not None:
                result += rgb_mask
            else:
                result = rgb_mask

    result = result/np.max(result)*255
    result = result.astype(np.uint8)
    images = [PIL.Image.fromarray(result[:,:,index,:]) for index in range(image.shape[2])]
    images[0].save(name_gif, save_all=True, append_images=images[1:],loop=0)

def visualize_3d_masks(actual_mask=None, predicted_mask=None):
    fig=plt.figure()
    ax = fig.add_subplot(111, projection="3d")

    for i in range(2):
        mask = [actual_mask, predicted_mask][i]
        if mask is not None:
            mask = mask.numpy().squeeze()
            pos = np.where(mask==1)
            ax.scatter(pos[0],pos[1],pos[2], color=[color_actual, color_predicted][i])

    # multiple figures at different angles?
    #ax.view_init(45, 0)
    plt.show()    

def visualize_histogram(image, mask):
    image = image.numpy().squeeze()
    mask = mask.numpy().squeeze()
    fig, axs = plt.subplots(1, 2, tight_layout=True) 
    axs[0].hist(image[mask==0], bins=20)
    axs[0].set_title("Background class")
    axs[1].hist(image[mask==1], bins=20)
    axs[1].set_title("Vessel class")   
    plt.show()

In [None]:
# visualize_3d_gif(image=validation_dataset[0]["img"], actual_mask=validation_dataset[0]["mask"], name_gif="val_array.gif")
# display(IPython.display.Image(data=open("val_array.gif",'rb').read(), format='png'))
# visualize_3d_masks(actual_mask=validation_dataset[0]["mask"])
visualize_histogram(validation_dataset[0]["img"],validation_dataset[0]["mask"])

In [31]:
def visualize_histogram_secret(original, secret):
    original = original.numpy().squeeze().flatten()
    secret = secret.numpy().squeeze().flatten()
    fig, axs = plt.subplots(1, 2, tight_layout=True) 
    axs[0].hist(original, bins=20)
    axs[0].set_title("Sample from original dataset")
    axs[1].hist(secret, bins=20)
    axs[1].set_title("Sample from secret test dataset")   
    plt.show()

In [None]:
# visualize_3d_gif(image=secret_test_dataset[0]["img"])
# display(IPython.display.Image(data=open("array.gif",'rb').read(), format='png'))
visualize_histogram_secret(validation_dataset[0]["img"], secret_test_dataset[0]["img"])

In [11]:
import wandb
wandb.login()

train_loader = monai.data.DataLoader(train_dataset, batch_size=4, num_workers=8,pin_memory=torch.cuda.is_available())
validation_loader = monai.data.DataLoader(validation_dataset, batch_size=1, num_workers=8,pin_memory=torch.cuda.is_available())
device = torch.device("cuda:1")
model = monai.networks.nets.UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(8, 16, 32, 64, 128),
    strides=(2, 2, 2, 2),
    num_res_units=2
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_function = monai.losses.DiceLoss(to_onehot_y=True, softmax=True)

In [None]:
run = wandb.init(
    project='ASOCAproject',
    name='Test_crop256_v3',
    config={
        'loss function': str(loss_function), 
        'lr': optimizer.param_groups[0]["lr"],
        'batch_size': train_loader.batch_size,
    }
)
# Do not hesitate to enrich this list of settings to be able to correctly keep track of your experiments!
# For example you should add information on your model...

run_id = run.id # We remember here the run ID to be able to write the evaluation metrics

def log_to_wandb(epoch, train_loss, val_loss, batch_data):
    """ Function that logs ongoing training variables to W&B """
    # This part is not yet working
    # def log_to_wandb(epoch, train_loss, val_loss, batch_data, outputs):
    # Create list of images that have segmentation masks for model output and ground truth
    #log_imgs = [wandb.Image(img, masks=wandb_masks(mask_output, mask_gt)) for img, mask_output,
    #            mask_gt in zip(batch_data['img'], outputs, batch_data['mask'])]

    # Send epoch, losses and images to W&B
    #wandb.log({'epoch': epoch, 'train_loss': train_loss, 'val_loss': val_loss, 'results': log_imgs})
    wandb.log({'epoch': epoch, 'train_loss': train_loss, 'val_loss': val_loss})

for epoch in tqdm(range(1000)):
    # training
    model.train()
    epoch_loss=0
    step=0
    for batch in tqdm(train_loader, desc="Training Step", leave=False, ncols=100):  # Nested tqdm for training steps
        print(step)
        step += 1
        optimizer.zero_grad()
        inputs = batch["img"].to(device)
        labels = batch["mask"].to(device)
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item() 
    train_loss = epoch_loss / step
    
    # validation
    step = 0
    val_loss = 0
    with torch.no_grad():
        for batch in tqdm(validation_loader, desc="Validation Step", leave=False, ncols=100):  # Nested tqdm for training steps
            print(step)
            step += 1
            model.eval()
            inputs = batch["img"].to(device)
            labels = batch["mask"].to(device)
            outputs = sliding_window_inference(inputs, (256, 256, 128), 4, model)
            #outputs = model(inputs)
            loss = loss_function(outputs, labels)
            val_loss+= loss.item()
        val_loss = val_loss / step
    
    log_to_wandb(epoch, train_loss, val_loss, batch)
    print(f"Epoch {epoch+1}, Train loss: {train_loss:.4f}, Validation loss: {val_loss:.4f}")
    
    torch.save(model.state_dict(),os.path.join(f"Trained_crop256_v3", f"trainedUNet_epoch{epoch}.pt"))
run.finish()

In [12]:
# Option to load a previous trained model, check path!
model.load_state_dict(torch.load(r'trainedUNet_epoch543.pt'))

RuntimeError: Attempting to deserialize object on CUDA device 7 but torch.cuda.device_count() is 2. Please use torch.load with map_location to map your storages to an existing device.

In [None]:
def visual_evaluation(sample, model):
    """
    Allow the visual inspection of one sample by plotting the X-ray image, the ground truth (green)
    and the segmentation map produced by the network (red).
    
    Args:
        sample (Dict[str, torch.Tensor]): sample composed of an X-ray ('img') and a mask ('mask').
        model (torch.nn.Module): trained model to evaluate.
    """
    model.eval()
    inferer = monai.inferers.SlidingWindowInferer(roi_size=[256, 256, 128])
    discrete_transform = monai.transforms.AsDiscrete(logit_thresh=0.5, threshold_values=True)
    Softmax = torch.nn.Softmax()
    with torch.no_grad():
        print(sample['img'].shape)
        #output = discrete_transform(Softmax(model(sample['img'].to('cuda:1'))).cpu()).squeeze()
        output = discrete_transform(Softmax(inferer(sample['img'].to('cuda:6'), network=model).cpu())).squeeze()
        output = np.squeeze(output[1, :, :, :])
        print(output.shape)
    
    fig, ax =plt.subplots(1,2, subplot_kw={"projection":"3d"})

    actual_mask = np.squeeze(sample['mask'])
    pos = np.where(actual_mask==1)
    print(len(pos[0]))
    ax[0].scatter(pos[0],pos[1],pos[2], color=color_actual)
    
    pos2 = np.where(output == 1)
    print(len(pos2[0]))
    ax[1].scatter(pos2[0],pos2[1],pos2[2], color=color_predicted)
   
    plt.show()

In [None]:
for sample in validation_loader:
    visual_evaluation(sample, model.to('cuda:6'))

In [2]:
# Check GPU memory
!nvidia-smi

Mon Mar 31 15:15:23 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 555.42.06              Driver Version: 555.42.06      CUDA Version: 12.5     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A16                     Off |   00000000:1B:00.0 Off |                    0 |
|  0%   41C    P0             35W /   62W |    1677MiB /  15356MiB |     81%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA A16                     Off |   00