In [None]:
!pip install antspyx
!pip install nilearn
!pip install nibabel
!pip install SimpleITK

### Import libraries

In [1]:
import os
import ants
import time
import numpy as np
import nibabel as nib
import SimpleITK as sitk
from nilearn import image, plotting
from nilearn.image import resample_to_img, new_img_like
import subprocess

In [2]:
import matplotlib.pyplot as plt

### Preprocessing MRI image

#### 1. Pre-defined functions

In [3]:
def Adativehistogram_equalization(ref, alpha, beta):
    """Run SimpleITK adaptive histogram equalization on ``ref``.

    The (misspelled) name is kept because the pipeline cells call it by this name.

    Args:
        ref: SimpleITK image to equalize.
        alpha: value handed to the filter's ``SetAlpha``.
        beta: value handed to the filter's ``SetBeta``.

    Returns:
        The equalized SimpleITK image returned by ``Execute``.
    """
    ahe_filter = sitk.AdaptiveHistogramEqualizationImageFilter()
    ahe_filter.SetAlpha(alpha)
    ahe_filter.SetBeta(beta)
    return ahe_filter.Execute(ref)

def SmoothingRecursiveGaussian(image, sigma):
    """Blur ``image`` with SimpleITK's recursive Gaussian smoothing filter.

    Args:
        image: SimpleITK image to smooth.
        sigma: Gaussian sigma handed to ``SetSigma``. NOTE(review): SimpleITK
            documents this sigma in physical units (scaled by image spacing);
            confirm callers convert their FWHM accordingly.

    Returns:
        The smoothed SimpleITK image returned by ``Execute``.
    """
    smoother = sitk.SmoothingRecursiveGaussianImageFilter()
    smoother.SetSigma(sigma)
    return smoother.Execute(image)

def normalise_zero_one(image):
    """Linearly rescale an array so its values span [0, 1].

    The input is cast to float32 first. If the array is constant (its maximum
    is not greater than its minimum) an all-zero float32 array of the same
    shape is returned instead of dividing by zero.
    """
    arr = image.astype(np.float32)
    lo, hi = np.min(arr), np.max(arr)
    if not hi > lo:
        return arr * 0.
    return (arr - lo) / (hi - lo)

def skull_strip_fsl(input_path, output_path, frac="0.5"):
    """Skull-strip a volume by running FSL BET (``bet``) as an external process.

    Args:
        input_path: path of the volume to strip.
        output_path: output path handed to ``bet``. The pipeline cells read the
            result back from ``output_path + '.gz'``, i.e. they assume FSL writes
            gzipped NIfTI — NOTE(review): confirm FSLOUTPUTTYPE for your setup.
        frac: value for BET's ``-f`` option; a string or a number (converted
            with ``str``).

    Raises:
        FileNotFoundError: if ``bet`` is not found on PATH (FSL not installed or
            its environment not sourced).
        subprocess.CalledProcessError: if ``bet`` exits with a non-zero status.
    """
    import shutil

    if shutil.which("bet") is None:
        raise FileNotFoundError(
            "FSL 'bet' executable not found on PATH; install FSL and source its "
            "setup script (e.g. $FSLDIR/etc/fslconf/fsl.sh) before running."
        )
    bet_command = ["bet", str(input_path), str(output_path), "-f", str(frac), "-g", "0"]
    # check=True so a failed BET run stops here instead of surfacing later as a
    # missing output file when the caller tries to load it.
    subprocess.run(bet_command, check=True)

def registration_to_mni(img, template, type_of_transform='Affine'):
    """Register ``img`` onto ``template`` with ANTs and return the warped image.

    Despite the name, ``template`` may be any fixed image. This notebook also
    uses it to align DTI/PET maps onto a subject's already-registered MRI.

    Args:
        img: moving ANTs image.
        template: fixed ANTs image (the ``fixed`` argument of both ANTs calls).
        type_of_transform: transform model passed to ``ants.registration``
            (default ``'Affine'``).

    Returns:
        nibabel image of ``img`` resampled with the estimated forward transforms.
    """
    registration = ants.registration(fixed=template, moving=img, type_of_transform=type_of_transform)
    # Apply the forward transforms estimated above to bring `img` onto the fixed image.
    transformed_ants = ants.apply_transforms(fixed=template, moving=img, transformlist=registration['fwdtransforms'])
    return ants.to_nibabel(transformed_ants)

In [4]:
def visualize_nifti(img):
    """Show the middle slice perpendicular to each voxel axis of a NIfTI image.

    Panels follow voxel axes 0, 1, 2 and are titled Sagittal, Coronal, Axial.
    Those labels are right when the voxel axes are ordered like RAS
    (NOTE(review): the orientation is not checked here). Images with more than
    three dimensions (e.g. multi-frame volumes) show only their first volume.

    Args:
        img: nibabel image (anything exposing ``get_fdata()``).
    """
    img_data = img.get_fdata()
    # Reduce 4-D+ data to the first volume so every slice below is 2-D;
    # imshow cannot display a 3-D slice as a grayscale image.
    while img_data.ndim > 3:
        img_data = img_data[..., 0]

    titles = ('Sagittal View', 'Coronal View', 'Axial View')
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    for axis, (ax, title) in enumerate(zip(axes, titles)):
        # Middle slice perpendicular to voxel axis `axis`
        # (axis 0 -> data[mid, :, :], axis 1 -> data[:, mid, :], axis 2 -> data[:, :, mid]).
        mid_slice = np.take(img_data, img_data.shape[axis] // 2, axis=axis)
        ax.imshow(mid_slice, cmap='gray')
        ax.set_title(title)
        ax.axis('off')

    plt.tight_layout()
    plt.show()

#### 2. Test on an image

In [5]:
# Root of the ADNI data. Set the ADNI_DIR environment variable to run on another machine.
source_dir = os.environ.get("ADNI_DIR", "/ngochuynh/f/Dataset/ADNI/")
mri_image = nib.load(os.path.join(source_dir, 'ADNI_renamed', '021_S_2100', '021_S_2100_Month_049.4_MRI_2014-10-20.nii'))
mni_template = nib.load(os.path.join(source_dir, 'atlas', 'MNI152_T1_1mm.nii'))
# Output file of the single-image test cells below.
dest_path = os.path.join(source_dir, 'transformed.nii')

In [None]:
# Middle slices of the raw subject MRI.
visualize_nifti(img=mri_image)

In [None]:
# Same three views of the MNI152 template, for comparison with the subject scan.
visualize_nifti(img=mni_template)

In [None]:
# Mirror the voxel data along axis 0 (a mirror, not a rotation) while reusing the
# original affine, so the mirrored scan can be viewed or registered instead of the original.
# NOTE(review): the affine is not updated for the mirror, so world-space left/right swaps.
img_data = mri_image.get_fdata()
flipped_img_data = img_data[::-1, ...]
flipped_img = nib.Nifti1Image(flipped_img_data, mri_image.affine)

In [None]:
# Middle slices of the axis-0-mirrored MRI.
visualize_nifti(img=flipped_img)

In [None]:
#mri_ants = ants.from_nibabel(mri_image)
#mni_ants = ants.from_nibabel(mni_template)

In [None]:
#mri_ants.get_orientation()

In [None]:
#reorient_mri_img = mri_ants.reorient_image2(orientation='SAR')

In [None]:
#reorient_mri_img.get_orientation()

In [6]:
# Register the subject MRI to the MNI152 template with ANTs and save it to dest_path.
mni_ants = ants.from_nibabel(mni_template)
mri_ants = ants.from_nibabel(mri_image)
# Alternative (mirrored) input: mri_ants = ants.from_nibabel(flipped_img)
transformed_mri_img = registration_to_mni(mri_ants, mni_ants)
nib.save(transformed_mri_img, dest_path)

In [7]:
# Single-image test of the full MRI pipeline: ANTs registration to MNI -> FSL BET
# skull stripping -> [0, 1] normalisation -> adaptive histogram equalisation ->
# Gaussian smoothing. Intermediate stages go to test/, the final image to dest_path.
starttime = time.time()
os.makedirs('test', exist_ok=True)  # every intermediate below is written into test/

# Convert MRI image to ANTs format
mri_ants = ants.from_nibabel(mri_image)
#mri_ants = ants.from_nibabel(flipped_img)
mni_ants = ants.from_nibabel(mni_template)

# Register the MRI to the MNI template using ANTs
transformed_mri_img = registration_to_mni(mri_ants, mni_ants)
nib.save(transformed_mri_img, 'test/transformed.nii')

# Skull stripping with FSL BET. The result is read back with '.gz' appended,
# i.e. BET is assumed to write gzipped NIfTI (FSLOUTPUTTYPE=NIFTI_GZ).
skullstripping_path = 'test/stripped.nii'
skull_strip_fsl('test/transformed.nii', skullstripping_path, frac="0.5")
stripped_mri_img = nib.load(f'{skullstripping_path}.gz')

# Intensity normalization to [0, 1]
normalized_mri_img = normalise_zero_one(stripped_mri_img.get_fdata())
normalized_mri_vol = sitk.GetImageFromArray(normalized_mri_img)
nib.save(nib.Nifti1Image(normalized_mri_img, transformed_mri_img.affine), 'test/normalized.nii')

# Histogram equalization
alpha = 0.8
beta  = 0.8
equalized_mri_vol = Adativehistogram_equalization(normalized_mri_vol, alpha, beta)
nib.save(nib.Nifti1Image(sitk.GetArrayFromImage(equalized_mri_vol), transformed_mri_img.affine), 'test/equalized.nii')

# Gaussian filter smoothing: sigma = FWHM / (2*sqrt(2*ln 2) * voxel size).
# The registered image is on the MNI152_T1_1mm template grid, so the voxel size is
# 1 mm (as in the batch loop). A value of 2 would halve the effective smoothing.
fwhm = 2
voxelsize = 1
sigma = fwhm / (np.sqrt(8 * np.log(2)) * voxelsize)
smoothed_mri_vol = SmoothingRecursiveGaussian(equalized_mri_vol, sigma=sigma)

# Reuse the registered image's affine for the processed output
smoothed_mri_nib = nib.Nifti1Image(sitk.GetArrayFromImage(smoothed_mri_vol), transformed_mri_img.affine)

# Write image
nib.save(smoothed_mri_nib, dest_path)
print(f'Preprocessing time: {time.time()-starttime:.3f} sec')

FileNotFoundError: [Errno 2] No such file or directory: 'bet'

In [None]:
# Middle slices of the MNI-registered MRI.
visualize_nifti(img=transformed_mri_img)

In [None]:
# Load subject 109_S_4499's DTI maps and PET scans from test/.
# `dri_*` are the DTI maps; note that `dri_ma` holds the MD (mean diffusivity) file.
dri_fa, dri_ma, pet_fdg, pet_AV45 = (
    nib.load(path)
    for path in (
        'test/109_S_4499_Month_000.0_DTI_FA_2012-03-08.nii',
        'test/109_S_4499_Month_000.0_DTI_MD_2012-03-08.nii',
        'test/109_S_4499_Month_000.4_PET_FDG_2012-03-19_raw.nii',
        'test/109_S_4499_Month_000.6_PET_AV45_2012-03-26.nii',
    )
)

In [None]:
# ANTs copies of the four loaded maps, in the same order.
dri_fa_ants, dri_ma_ants, pet_fdg_ants, pet_AV45_ants = (
    ants.from_nibabel(nib_img) for nib_img in (dri_fa, dri_ma, pet_fdg, pet_AV45)
)

In [None]:
# Fixed image for the DTI/PET registrations below: the subject's MRI already registered to MNI.
transformed_mri_ants = ants.from_nibabel(transformed_mri_img)

In [None]:
def _register_to_subject_mri(moving_img, moving_ants, fixed_ants, out_dir='test'):
    """Register one modality onto ``fixed_ants`` and save it as
    ``<out_dir>/transformed_<source file name>``; return the nibabel result.

    Deriving the output name from the source file keeps each result's own
    modality, visit month and acquisition date.
    """
    registered = registration_to_mni(moving_ants, fixed_ants)
    source_name = os.path.basename(moving_img.get_filename())
    nib.save(registered, os.path.join(out_dir, 'transformed_' + source_name))
    return registered

# Align each DTI/PET map of subject 109_S_4499 to that subject's MNI-registered MRI.
# Output names come from the inputs, so MD is saved as DTI_MD and each PET scan
# keeps its own month and date (hard-coded names previously got these wrong).
transformed_dri_fa_img = _register_to_subject_mri(dri_fa, dri_fa_ants, transformed_mri_ants)
transformed_dri_ma_img = _register_to_subject_mri(dri_ma, dri_ma_ants, transformed_mri_ants)
transformed_pet_fdg_img = _register_to_subject_mri(pet_fdg, pet_fdg_ants, transformed_mri_ants)
transformed_pet_AV45_img = _register_to_subject_mri(pet_AV45, pet_AV45_ants, transformed_mri_ants)

#### 3. Preprocess all images

In [None]:
import os
import glob

In [None]:
def smallest_distance(numbers, reference):
    """Return the smallest ``abs(n - reference)`` over the non-None ``numbers``.

    ``None`` entries are ignored. If there is nothing to compare (empty input or
    only ``None``), ``float('inf')`` is returned.
    """
    # Seeding with inf mirrors a running minimum that starts at infinity.
    candidates = [float('inf')] + [abs(n - reference) for n in numbers if n is not None]
    return min(candidates)

In [None]:
def structured_image_dic(scrdir, files):
    """Group one subject's scan files into visits binned to the nearest 6 months.

    File names are split on ``_`` and only names with 7 or 8 parts are used,
    e.g. ``<site>_S_<id>_Month_<month>_MRI_<date>.nii`` or
    ``<site>_S_<id>_Month_<month>_PET_AV45_<date>.nii``. Anything else (the
    ``preprocessed`` folder, or names with an extra suffix such as
    ``..._raw.nii``, which have 9 parts) is ignored.
    NOTE(review): confirm that skipping ``_raw`` PET files is intended.

    Each visit key ``round(month / 6) * 6`` maps to::

        {'month':       [MRI, PET_FDG, PET_AV45, DTI_FA, DTI_MD],
         'image_paths': [MRI, PET_FDG, PET_AV45, DTI_FA, DTI_MD]}

    with ``None`` for missing modalities. A visit entry is created for every
    parsed file, even one of an unrecognised type. When several files of one
    modality fall into the same visit, the one whose month is closest to the
    visit month is kept (ties keep the file seen first).

    Args:
        scrdir: directory holding ``files``; joined onto every stored path.
        files: iterable of file names (e.g. ``os.listdir(scrdir)``).

    Returns:
        dict mapping the rounded visit month to the structure above.
    """
    slots = {"MRI": 0, "PET_FDG": 1, "PET_AV45": 2, "DTI_FA": 3, "DTI_MD": 4}
    rounded_months_dict = {}
    for file_name in files:
        parts = file_name.split("_")
        if len(parts) not in (7, 8):
            continue
        image_type = parts[5]
        if image_type in ('PET', 'DTI'):
            image_type = parts[5] + '_' + parts[6]
        month = float(parts[4])
        # Python's round() sends exact halves to the even integer (month 3.0 -> 0, 9.0 -> 12).
        rounded_month = round(month / 6) * 6
        entry = rounded_months_dict.setdefault(
            rounded_month,
            {'month': [None] * 5, 'image_paths': [None] * 5},
        )
        slot = slots.get(image_type)
        if slot is None:
            continue
        # Compare only with the file already stored in THIS modality's slot, so a
        # closer scan is never rejected because another modality of the visit
        # sits nearer the visit month.
        current = entry['month'][slot]
        if current is None or abs(month - rounded_month) < abs(current - rounded_month):
            entry['month'][slot] = month
            entry['image_paths'][slot] = os.path.join(scrdir, file_name)
    return rounded_months_dict

In [None]:
image_dir = '/ngochuynh/f/Dataset/ADNI/ADNI_renamed'
# One sub-directory per subject. Keep only directories (the loops below call
# os.listdir on each entry) and sort them, so the processing order and the
# order of records in extra_list are reproducible.
list_ids = sorted(p for p in glob.glob(os.path.join(image_dir, '*')) if os.path.isdir(p))

In [None]:
# MNI152 1 mm template: the fixed image for the batch registrations below.
MNI_TEMPLATE_FILE = '/ngochuynh/f/Dataset/ADNI/atlas/MNI152_T1_1mm.nii'
mni_template = nib.load(MNI_TEMPLATE_FILE)
mni_ants = ants.from_nibabel(mni_template)

In [None]:
# Subject directories to process; an alias of list_ids (e.g. list_ids[:5] would give a quick trial run).
test_files = list_ids

In [None]:
# Records ({'ptid', 'month', 'img_type', 'orient'}) appended by the batch preprocessing loop
# and later written to 'extra_list_v2' as JSON. Re-running that loop appends to this list.
extra_list = []

In [None]:
# Register every single-volume DTI FA/MD map of each subject to the MNI template.
# Results go to <subject>/preprocessed/DTI/transformed_<file name>. Files whose
# output already exists are skipped, so the cell can be re-run.
for file_path in test_files:
    ptid  = os.path.basename(file_path)
    files = sorted(glob.glob(os.path.join(file_path, "*.nii")))
    # Match on the file name only, not on the directory part of the path.
    filter_files = [f for f in files
                    if "DTI_FA" in os.path.basename(f) or "DTI_MD" in os.path.basename(f)]
    preprocessed_dir = os.path.join(file_path, 'preprocessed')
    dti_dir = os.path.join(preprocessed_dir, 'DTI')
    os.makedirs(preprocessed_dir, exist_ok=True)
    for f in filter_files:
        starttime = time.time()
        bn = os.path.basename(f)
        transformed_path = os.path.join(dti_dir, 'transformed_' + bn)
        if os.path.exists(transformed_path):
            continue
        # Named dti_image so the nilearn `image` module imported at the top is not shadowed.
        dti_image = nib.load(f)
        # Only 3-D volumes (or 4-D with a single frame) are registered.
        if len(dti_image.shape) == 3 or dti_image.shape[3] == 1:
            image_ants = ants.from_nibabel(dti_image)
            transformed_img = registration_to_mni(image_ants, mni_ants)
            # Create preprocessed/DTI on first use; nib.save does not create missing directories.
            os.makedirs(dti_dir, exist_ok=True)
            nib.save(transformed_img, transformed_path)
            print(f'Processed: {ptid} ({image_ants.get_orientation()}) - {f} - {time.time()-starttime:.3f} sec')
print("Finished !!")

In [None]:
def _preprocess_mri(mri_ants, bn, preprocessed_dir, transformed_path):
    """Run the MRI pipeline on one scan, saving every stage in ``preprocessed_dir``.

    Stages: ANTs registration to ``mni_ants`` (saved to ``transformed_path``),
    FSL BET skull stripping (``stripped_<bn>``), [0, 1] normalisation
    (``normalized_<bn>``), adaptive histogram equalisation (``equalized_<bn>``)
    and Gaussian smoothing (``smooth_<bn>``).
    """
    transformed_mri_img = registration_to_mni(mri_ants, mni_ants)
    nib.save(transformed_mri_img, transformed_path)
    affine = transformed_mri_img.affine

    # Skull stripping; BET's output is read back with '.gz' appended (gzipped NIfTI assumed).
    stripped_path = os.path.join(preprocessed_dir, 'stripped_' + bn)
    skull_strip_fsl(transformed_path, stripped_path, frac="0.5")
    stripped_mri_img = nib.load(f'{stripped_path}.gz')

    # Intensity normalization to [0, 1]
    normalized_mri_img = normalise_zero_one(stripped_mri_img.get_fdata())
    normalized_mri_vol = sitk.GetImageFromArray(normalized_mri_img)
    nib.save(nib.Nifti1Image(normalized_mri_img, affine),
             os.path.join(preprocessed_dir, 'normalized_' + bn))

    # Histogram equalization
    equalized_mri_vol = Adativehistogram_equalization(normalized_mri_vol, 0.8, 0.8)
    nib.save(nib.Nifti1Image(sitk.GetArrayFromImage(equalized_mri_vol), affine),
             os.path.join(preprocessed_dir, 'equalized_' + bn))

    # Gaussian smoothing: FWHM (mm) -> sigma; the MNI152_T1_1mm grid has 1 mm voxels.
    fwhm = 2
    voxelsize = 1
    sigma = fwhm / (np.sqrt(8 * np.log(2)) * voxelsize)
    smoothed_mri_vol = SmoothingRecursiveGaussian(equalized_mri_vol, sigma=sigma)
    nib.save(nib.Nifti1Image(sitk.GetArrayFromImage(smoothed_mri_vol), affine),
             os.path.join(preprocessed_dir, 'smooth_' + bn))


# Batch preprocessing of every subject, visit by visit (slots from structured_image_dic):
#   slot 0 (MRI)            -> full pipeline via _preprocess_mri;
#   slots 3, 4 (DTI FA/MD)  -> registration to MNI only, saved in preprocessed/DTI/;
#   slots 1, 2 (PET)        -> not processed here.
# Images whose output already exists are skipped. Each processed image is appended
# to extra_list.
for file_path in test_files:
    ptid = os.path.basename(file_path)
    structured_files = structured_image_dic(file_path, os.listdir(file_path))

    preprocessed_dir = os.path.join(file_path, 'preprocessed')
    os.makedirs(preprocessed_dir, exist_ok=True)
    for m, tp in structured_files.items():
        for i, img in enumerate(tp['image_paths']):
            starttime = time.time()
            if img is None:
                continue
            bn = os.path.basename(img)
            if i == 0:
                transformed_path = os.path.join(preprocessed_dir, 'transformed_' + bn)
                if os.path.exists(transformed_path):
                    continue
                # Convert MRI image to ANTs format
                mri_image = nib.load(img)
                mri_ants = ants.from_nibabel(mri_image)
                if mri_ants.get_orientation() in ['RAI', 'RPI']:
                    # Mirror the voxel data along axis 0 but keep the original affine
                    # (same workaround as the single-image flip test).
                    # NOTE(review): the affine is not updated for the mirror — confirm intended.
                    flipped_img = nib.Nifti1Image(np.flip(mri_image.get_fdata(), axis=0), mri_image.affine)
                    mri_ants = ants.from_nibabel(flipped_img)
                _preprocess_mri(mri_ants, bn, preprocessed_dir, transformed_path)
                orient = mri_ants.get_orientation()
            elif i in (3, 4):
                # DTI results get their own file in preprocessed/DTI/ (as in the DTI-only
                # cell) instead of overwriting the MRI's transformed file.
                dti_dir = os.path.join(preprocessed_dir, 'DTI')
                dti_transformed_path = os.path.join(dti_dir, 'transformed_' + bn)
                if os.path.exists(dti_transformed_path):
                    continue
                dti_image = nib.load(img)
                if not (len(dti_image.shape) == 3 or dti_image.shape[3] == 1):
                    continue  # multi-frame volume: not registered
                image_ants = ants.from_nibabel(dti_image)
                transformed_img = registration_to_mni(image_ants, mni_ants)
                os.makedirs(dti_dir, exist_ok=True)
                nib.save(transformed_img, dti_transformed_path)
                orient = image_ants.get_orientation()
            else:
                continue

            extra_list.append({'ptid': ptid, 'month': m, 'img_type': i, 'orient': orient})
            print(f'Processed: {ptid} ({orient}) - month {m} - image {i} - {time.time()-starttime:.3f} sec')
print("Finished !!")

In [None]:
# Inspect one DTI MD map. Named dti_md_image so it does not shadow the nilearn `image`
# module imported at the top. Its shape shows whether the batch loops' 3-D / single-frame
# check would accept it.
dti_md_image = nib.load("/ngochuynh/f/Dataset/ADNI/ADNI_renamed/003_S_4288/003_S_4288_Month_025.9_DTI_MD_2013-12-05.nii")
dti_md_image.shape

In [None]:
# Voxel grid of the MNI template, which every registration uses as its fixed image.
mni_template.shape

In [None]:
import json

In [None]:
# Persist the per-image processing records as indented JSON.
with open('extra_list_v2', 'w') as fp:
    json.dump(extra_list, fp, indent=4)

In [None]:
# Read the processing records back from 'extra_list_v2'.
with open('extra_list_v2') as log_file:
    json_object = json.loads(log_file.read())