In [18]:
import pandas as pd
import SimpleITK as sitk

from aidream_registration import constants
from aidream_registration.dataloaders import AtlasImagingNiftiLoader
from aidream_registration.utils.cohort_utils import get_perfusion_patients
from sklearn.utils import resample

from torch.optim.lr_scheduler import ReduceLROnPlateau

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, Dataset
from sklearn.model_selection import train_test_split


import numpy as np
import ants


In [2]:
torch.cuda.empty_cache()
torch.cuda.is_available()

True

In [3]:
list_patients = get_perfusion_patients()
print("Number of patients:", len(list_patients))


Number of patients: 186


In [4]:
# Load the atlas imaging nifti loader :
atlas_loader = AtlasImagingNiftiLoader(source_mri="PIPELINE_SS")


In [17]:
# First, load the hammersmith atlas, and create a mask :
path_hammersmith = constants.DIR_DATA / "hammersmith" / "T1w_ICBM_skullstripped.nii.gz"
ants_hammersmith = ants.image_read(str(path_hammersmith))

sitk_hammersmith = sitk.ReadImage(str(path_hammersmith))
mask_hammersmith = ants_hammersmith > 0
print(f"Hammersmith mask volume: {mask_hammersmith.sum() / 1e3:.2f} cm3")


Hammersmith mask volume: 1886.57 cm3


In [6]:
# Load all segmentations in a dictionary :
dict_segmentations = {
    path_seg.name.removesuffix(".nii.gz"): ants.image_read(str(path_seg)) > 0
    for path_seg in (constants.DIR_DATA / "hammersmith").glob("*.nii.gz")
    if path_seg.name != "T1w_ICBM_skullstripped.nii.gz"
}

print("Number of segmentations:", len(dict_segmentations))

for segmentation, mask_segmentation in dict_segmentations.items():
    print(fr"{segmentation} volume: {mask_segmentation.sum() / 1e3:.2f} cm3")


Number of segmentations: 21
Brainstem volume: 26.94 cm3
Corpus_callosum volume: 21.05 cm3
L.frontal_lobe volume: 225.46 cm3
L.Grey_nuclei volume: 18.22 cm3
L.Limbic_lobe volume: 24.96 cm3
L.Occipital_lobe volume: 75.31 cm3
L.Parietal_lobe volume: 142.35 cm3
L.Post_fossea volume: 90.63 cm3
L.Temporal_lobe volume: 115.72 cm3
L.Ventricles volume: 6.72 cm3
L_insula volume: 16.62 cm3
R.frontal_lobe volume: 227.50 cm3
R.Grey_nuclei volume: 18.20 cm3
R.Limbic_lobe volume: 24.28 cm3
R.Occipital_lobe volume: 76.46 cm3
R.Parietal_lobe volume: 141.59 cm3
R.Post_fossea volume: 88.57 cm3
R.Temporal_lobe volume: 121.76 cm3
R.Ventricles volume: 6.11 cm3
R_insula volume: 16.48 cm3
third_ventricle volume: 0.63 cm3


In [7]:
# Create mask_1 for the ventricle voxels :
mask_1 = dict_segmentations["L.Ventricles"] + dict_segmentations["R.Ventricles"] + dict_segmentations["third_ventricle"]
mask_1 = mask_1 > 0

print(f"Ventricles volume: {mask_1.sum() / 1e3:.2f} cm3")


Ventricles volume: 13.47 cm3


In [100]:
# Create mask82 for the segmented voxels that are not ventricles :
mask_2 = None
for seg, mask_segmentation in dict_segmentations.items():

    if seg not in ["L.Ventricles", "R.Ventricles", "third_ventricle"]:
        if mask_2 is None:
            mask_2 = mask_segmentation
        else:
            mask_2 += mask_segmentation

mask_2 = mask_2 > 0

print(f"Brain volume: {mask_2.sum() / 1e3:.2f} cm3")

mask_2.to_file(str(constants.DIR_DATA / "hammersmith" / "custom" / "mask_2.nii.gz"))


Brain volume: 1472.11 cm3


In [34]:
# Apply region growing to the mask_1 to complete the ventricles :
import SimpleITK as sitk

left_ventricle = sitk.ReadImage(str(constants.DIR_DATA / "hammersmith" / "L.Ventricles.nii.gz"))
right_ventricle = sitk.ReadImage(str(constants.DIR_DATA / "hammersmith" / "R.Ventricles.nii.gz"))
third_ventricle = sitk.ReadImage(str(constants.DIR_DATA / "hammersmith" / "third_ventricle.nii.gz"))

# Convert the masks to an integer type (e.g., 8-bit unsigned integer)
left_ventricle = sitk.Cast(left_ventricle, sitk.sitkUInt8)
right_ventricle = sitk.Cast(right_ventricle, sitk.sitkUInt8)
third_ventricle = sitk.Cast(third_ventricle, sitk.sitkUInt8)


In [127]:
def clean_mask(mask_ventricle):
    
    mask_ventricle = sitk.Cast(mask_ventricle > 0, sitk.sitkUInt8)
    cleaned_mask = sitk.RelabelComponent(sitk.ConnectedComponent(mask_ventricle), sortByObjectSize=True) == 1
    
    return cleaned_mask


In [131]:
cleaned_left = clean_mask(left_ventricle)
sitk.WriteImage(cleaned_left, str(constants.DIR_DATA / "hammersmith" / "NEW" / "L.Ventricles_cleaned.nii.gz"))


In [132]:
import random
def select_seed(mask_ventricle):
    
    mask_array = sitk.GetArrayFromImage(mask_ventricle)    
    # Get the indices of the mask 
    indices = np.argwhere(mask_array > 0)            
    seed = random.choice(indices)
    # reordering the seed
    seed = (int(seed[2]), int(seed[1]), int(seed[0]))
    return seed


In [147]:
def grow_region(mask_ventricle, n_seeds=100, dilate=4):    
    
    print("Dilating mask...")
    dilated_mask_ventricle = sitk.BinaryDilate(mask_ventricle, [dilate] * mask_ventricle.GetDimension())
    
    print("Computing thresholds...")
    
    mask_array = sitk.GetArrayFromImage(mask_ventricle)
    hammersmith_array = sitk.GetArrayFromImage(sitk_hammersmith)
    intensities = hammersmith_array[mask_array > 0]
    
    mean_intensity = np.mean(intensities)
    std_intensity = np.std(intensities)
    
    print(
        f"Mean intensity: {mean_intensity:.2f}",
        f"Standard deviation: {std_intensity:.2f}",
    )
    lower_threshold = mean_intensity - 2 * std_intensity
    upper_threshold = mean_intensity + 2 * std_intensity
    
    print(
        f"Lower threshold: {lower_threshold:.2f}",
        f"Upper threshold: {upper_threshold:.2f}"
    )
    
    # print("Selecting seeds...")
    seeds = tuple(map(int, mask_array))
    
    print("Growing region...")
    region_grown = sitk.ConnectedThreshold(sitk_hammersmith, seedList=seeds, lower=lower_threshold, upper=upper_threshold)
    
    return region_grown
    

In [148]:
region_grown_left = grow_region(cleaned_left, n_seeds=1000, dilate=4)
sitk.WriteImage(region_grown_left, str(constants.DIR_DATA / "hammersmith" / "NEW" / "L.Ventricles_region_grown.nii.gz"))
        

Dilating mask...
Computing thresholds...
Mean intensity: 29.68 Standard deviation: 6.53
Lower threshold: 16.62 Upper threshold: 42.74


TypeError: only length-1 arrays can be converted to Python scalars

In [152]:
mask_array = sitk.GetArrayFromImage(left_ventricle)

In [153]:
mask

array([1, 1, 1, ..., 1, 1, 1], dtype=uint8)