In [None]:
!pip install antspyx
!pip install nibabel
!pip install tf-models-official
!pip install tensorflow-addons

In [None]:
import os
import csv
import glob
import ants
import time
import json
import numpy as np
import pandas as pd
import nibabel as nib
import nibabel.processing

from scipy import ndimage
from skimage import morphology

import matplotlib.pyplot as plt
from matplotlib import cm
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from scipy.ndimage import gaussian_gradient_magnitude

In [None]:
from IPython.display import display, HTML

In [None]:
def image_registration(img, template, type_of_transform='Translation', interpolator='bSpline', aff_sampling=36):
    """Register ``img`` onto ``template`` with ANTs and return the resampled result.

    Parameters
    ----------
    img : ANTs image
        Moving image; it is transformed into the space of ``template``.
    template : ANTs image
        Fixed image that defines the target space.
    type_of_transform : str, optional
        Forwarded to ``ants.registration``. Defaults to ``'Translation'``,
        the value that used to be hard-coded.
    interpolator : str, optional
        Forwarded to ``ants.apply_transforms``. Defaults to ``'bSpline'``,
        the value that used to be hard-coded.
        NOTE(review): the notebook passes a label atlas as ``img``; a spline
        interpolator presumably blends neighbouring label values, so a
        label-preserving choice such as ``'nearestNeighbor'`` may be more
        appropriate there - confirm against the ANTsPy documentation.
    aff_sampling : int, optional
        Forwarded to ``ants.registration``. Defaults to 36.

    Returns
    -------
    The transformed moving image, converted with ``ants.to_nibabel``.
    """
    registration = ants.registration(
        fixed=template,
        moving=img,
        aff_sampling=aff_sampling,
        type_of_transform=type_of_transform,
    )
    transformed_ants = ants.apply_transforms(
        fixed=template,
        moving=img,
        interpolator=interpolator,
        transformlist=registration['fwdtransforms'],
    )
    return ants.to_nibabel(transformed_ants)

def postprocessing_segmentation(segmentation_mask, voxel_threshold=100, smoothing_sigma=1, expansion_distance=10):
    """Clean, smooth and expand a segmentation mask.

    Steps, in order:
      1. binarise the input at 0.5;
      2. drop connected components smaller than ``voxel_threshold`` voxels
         (``connectivity=1``);
      3. Gaussian-smooth the mask with ``smoothing_sigma`` and re-binarise at 0.5;
      4. dilate the result ``expansion_distance`` times.

    Parameters
    ----------
    segmentation_mask : array-like supporting ``> 0.5``
        Input mask or probability map.
    voxel_threshold : int, optional
        Minimum component size kept by ``remove_small_objects``.
    smoothing_sigma : float, optional
        Sigma of the Gaussian filter.
    expansion_distance : int, optional
        Number of binary-dilation iterations. Values below 1 skip the
        dilation step.

    Returns
    -------
    numpy.ndarray of bool
        The post-processed mask.
    """
    binary_mask = np.asarray(segmentation_mask > 0.5, dtype=np.bool_)
    cleaned_mask = morphology.remove_small_objects(binary_mask, min_size=voxel_threshold, connectivity=1)
    smoothed_mask = ndimage.gaussian_filter(cleaned_mask.astype(float), sigma=smoothing_sigma) > 0.5

    # binary_dilation treats iterations < 1 as "repeat until nothing changes",
    # which would grow the mask without bound; "no expansion" must therefore
    # bypass the call entirely.
    if expansion_distance < 1:
        return smoothed_mask
    return ndimage.binary_dilation(smoothed_mask, iterations=expansion_distance)

def _min_max_normalize(image):
    """Linearly rescale ``image`` to the range [0, 1].

    A constant image (max == min) has no intensity range to rescale; it is
    mapped to all zeros instead of dividing by zero and producing NaNs.
    """
    image = np.asarray(image)
    low = np.min(image)
    span = np.max(image) - low
    if span == 0:
        return np.zeros(image.shape, dtype=float)
    return (image - low) / span


def _axis_gradients(image):
    """Return the finite-difference gradients of a 3-D array along axes 0, 1 and 2."""
    return tuple(np.gradient(image, axis=axis) for axis in range(3))


def _max_planar_angle(grad_x, grad_y, grad_z):
    """Element-wise maximum of the gradient angles measured in the xy, xz and yz planes."""
    angle_xy = np.arctan2(grad_y, grad_x)
    angle_xz = np.arctan2(grad_z, grad_x)
    angle_yz = np.arctan2(grad_z, grad_y)
    return np.maximum(np.maximum(angle_xy, angle_xz), angle_yz)


def _max_axis_angle(grad_x, grad_y, grad_z):
    """Element-wise maximum of the angles between the gradient and each coordinate axis."""
    angle_x = np.arctan2(np.sqrt(grad_y**2 + grad_z**2), grad_x)
    angle_y = np.arctan2(np.sqrt(grad_x**2 + grad_z**2), grad_y)
    angle_z = np.arctan2(np.sqrt(grad_x**2 + grad_y**2), grad_z)
    return np.maximum(np.maximum(angle_x, angle_y), angle_z)


def compute_gradient_map(image_visit1, image_visit2):
    """Absolute difference of the gradient magnitudes of two images.

    Both images are min-max normalised independently, the gradient magnitude
    of each is computed with ``np.gradient`` over all axes, and the absolute
    difference of the two magnitude maps is returned (not re-normalised).
    Inputs are expected to have at least two dimensions and the same shape.
    """
    norm_image_visit1 = _min_max_normalize(image_visit1)
    norm_image_visit2 = _min_max_normalize(image_visit2)

    magnitude1 = np.sqrt(np.sum(np.square(np.gradient(norm_image_visit1)), axis=0))
    magnitude2 = np.sqrt(np.sum(np.square(np.gradient(norm_image_visit2)), axis=0))

    return np.abs(magnitude1 - magnitude2)


def compute_gradient_v2_map(image_visit1, image_visit2):
    """Absolute difference of per-voxel gradient-direction maps (planar angles).

    For each min-max normalised 3-D image, the gradient angle is measured in
    the xy, xz and yz planes with ``arctan2`` and the largest of the three is
    kept per voxel. The absolute difference of the two resulting maps is
    returned (not re-normalised, angles are not wrapped).
    """
    norm_image_visit1 = _min_max_normalize(image_visit1)
    norm_image_visit2 = _min_max_normalize(image_visit2)

    direction_map1 = _max_planar_angle(*_axis_gradients(norm_image_visit1))
    direction_map2 = _max_planar_angle(*_axis_gradients(norm_image_visit2))

    return np.abs(direction_map2 - direction_map1)


def compute_gradient_v3_map(image_visit1, image_visit2):
    """Absolute difference of per-voxel gradient-direction maps (axis angles).

    For each min-max normalised 3-D image, the angle between the gradient and
    each coordinate axis is computed with ``arctan2`` and the largest of the
    three is kept per voxel. The absolute difference of the two resulting
    maps is returned (not re-normalised).
    """
    norm_image_visit1 = _min_max_normalize(image_visit1)
    norm_image_visit2 = _min_max_normalize(image_visit2)

    direction_map1 = _max_axis_angle(*_axis_gradients(norm_image_visit1))
    direction_map2 = _max_axis_angle(*_axis_gradients(norm_image_visit2))

    return np.abs(direction_map2 - direction_map1)


def compute_saliency_map(image_visit1, image_visit2):
    """Absolute intensity difference of two independently min-max normalised images.

    The result is returned as computed; it is not re-normalised.
    """
    norm_image_visit1 = _min_max_normalize(image_visit1)
    norm_image_visit2 = _min_max_normalize(image_visit2)

    return np.abs(norm_image_visit2 - norm_image_visit1)

def crop_region_image(image):
    """Crop a 3-D array to the bounding box of its non-zero voxels.

    Parameters
    ----------
    image : numpy.ndarray
        3-D array.

    Returns
    -------
    crop_image : numpy.ndarray
        The sub-array spanning the bounding box (bounds inclusive).
    bounds : list
        ``[min_x, max_x, min_y, max_y, min_z, max_z]`` index bounds of the box.

    Raises
    ------
    ValueError
        If the array contains no non-zero voxel (there is nothing to crop
        around), or if it is not 3-D.
    """
    x, y, z = np.nonzero(image != 0)
    if x.size == 0:
        raise ValueError("crop_region_image: image contains no non-zero voxels")

    # Array reductions instead of the Python builtins, which would iterate the
    # index arrays element by element.
    min_x, max_x = x.min(), x.max()
    min_y, max_y = y.min(), y.max()
    min_z, max_z = z.min(), z.max()

    crop_image = image[min_x:max_x+1, min_y:max_y+1, min_z:max_z+1]
    return crop_image, [min_x, max_x, min_y, max_y, min_z, max_z]

def save_image(image, ref, dest_path, filename):
    """Write ``image`` as a NIfTI file, reusing the affine and header of ``ref``.

    Parameters
    ----------
    image : array-like
        Voxel data to store.
    ref : nibabel image
        Reference image supplying ``affine`` and ``header``.
        NOTE(review): the notebook also passes cropped arrays together with the
        uncropped reference affine, so the saved crop presumably does not carry
        its offset inside the original volume - confirm this is intended.
    dest_path : str
        Output directory; created (including parents) when missing.
    filename : str
        File name inside ``dest_path``.

    Raises
    ------
    OSError
        If the output directory cannot be created. Previously this error was
        printed and the function went on to fail inside ``nib.save``.
    """
    img = nib.Nifti1Image(image, ref.affine, header=ref.header)
    os.makedirs(dest_path, exist_ok=True)
    nib.save(img, os.path.join(dest_path, filename))

In [None]:
# Root directory of the dataset; the input and output paths of the processing loop are built from it.
source_dir = "/ngochuynh/f/Dataset/ADNI/"
# Atlas volume, given relative to source_dir.
atlas_file = "temp/BN_Atlas_246_1mm.nii"
# Region name; used for the per-patient output sub-directory and in the CSV file name.
extracted_region = 'Hippocampus'
# Atlas label values that are merged into the region mask.
subregion_ids = [215, 216, 217, 218]

In [None]:
# Load the atlas once. The ANTs copy is the moving image that is registered
# onto each patient's first visit in the processing loop.
atlas = nib.load(os.path.join(source_dir, atlas_file))
# Voxel array of the atlas (not referenced by the other cells visible here).
atlas_data = atlas.get_fdata()
atlas_ants = ants.from_nibabel(atlas)

In [None]:
# One entry per item directly under ADNI_renamed (one directory per patient).
# glob returns paths in arbitrary, filesystem-dependent order; sort them so
# that positional slicing into batches is reproducible between runs.
list_patietn_id = sorted(glob.glob(os.path.join(source_dir, 'ADNI_renamed', '*'), recursive = True))

In [None]:
# Column layout of the coordinate table. This empty frame is a placeholder: it
# is replaced when the per-pair frames collected in `dfs` are concatenated.
df_coordinate = pd.DataFrame(columns=['PatientID', 'Visit', 'Type', 'Orientation', 'Min_X', 'Max_X', 'Min_Y', 'Max_Y', 'Min_Z', 'Max_Z', 'BL_Path', 'CURR_Path'])
# Accumulates one single-row DataFrame per (visit pair, saliency type).
dfs = []

In [None]:
# Process only the entries from index 1500 onward.
# NOTE(review): the output CSV name ends in "_2000" - confirm that this slice
# matches the intended batch.
sublist = list_patietn_id[1500:]

In [None]:
total = len(sublist)
progress = 0
display(HTML("""
        <div style="width: 100%; background-color: #f0f0f0; border-radius: 5px; padding: 3px; box-shadow: 0 0 5px rgba(0, 0, 0, 0.1);">
            <div id="progress-bar" style="width: 0%; height: 20px; background-color: #4CAF50; border-radius: 5px; transition: width 0.3s ease;"></div>
        </div>
    """))

# Saliency variants produced for every visit pair:
# (value of the 'Type' column, file-name suffix, function computing the map).
saliency_variants = [
    ('AD', 'ad', compute_saliency_map),
    ('MG', 'mg', compute_gradient_map),
    ('DG1', 'dg_1', compute_gradient_v2_map),
    ('DG2', 'dg_2', compute_gradient_v3_map),
]

for pid_path in sublist:
    pid = os.path.basename(pid_path)
    img_paths = glob.glob(os.path.join(pid_path,'preprocessed','MRI','smooth_*.nii'))
    if img_paths:
        images = {}
        image_paths = {}
        img_orientations = {}
        visits = []

        for img_path in img_paths:
            base_name = os.path.basename(img_path)
            # The visit value is the 6th underscore-separated field of the file
            # name; it is snapped to the nearest multiple of 6. Images that snap
            # to the same value overwrite each other (the last one listed wins).
            visit_number = float(base_name.split("_")[5])
            rounded_visit = round(visit_number / 6) * 6
            image_paths[rounded_visit] = img_path
            images[rounded_visit] = nib.load(img_path)
            img_orientations[rounded_visit] = ants.from_nibabel(images[rounded_visit]).get_orientation()
            if rounded_visit not in visits:
                visits.append(rounded_visit)
        visits = sorted(visits)
        first_image = images[visits[0]]

        # Bring the atlas into the space of the earliest visit and keep a copy on disk.
        img0_ants = ants.from_nibabel(first_image)
        coreg_atlas0 = image_registration(atlas_ants, img0_ants)
        dest_dir = os.path.join(source_dir,'ADNI_saliency',pid)
        os.makedirs(dest_dir, exist_ok=True)
        nib.save(coreg_atlas0, os.path.join(dest_dir,'atlas0_coreg.nii'))

        # Region mask: atlas labels listed in subregion_ids, then cleaned,
        # smoothed and dilated. The same mask is applied to every visit.
        rounded_image0 = np.round(coreg_atlas0.get_fdata()).astype(int)
        extracted_mask0 = np.isin(rounded_image0, subregion_ids)
        mask_img0 = nib.Nifti1Image(extracted_mask0, first_image.affine, header=first_image.header)
        final_segmentation = postprocessing_segmentation(mask_img0.get_fdata(), voxel_threshold=75, smoothing_sigma=1, expansion_distance=3)
        mask_final = nib.Nifti1Image(final_segmentation, first_image.affine, header=first_image.header)
        region_dir = os.path.join(dest_dir,extracted_region)
        os.makedirs(region_dir, exist_ok=True)
        nib.save(mask_final, os.path.join(region_dir,'region_mask.nii'))

        # Every ordered pair (earlier visit, later visit) is processed; a cap on
        # the gap between visits (> 36) existed previously but was disabled.
        for i in range(len(visits)-1):
            segment0 = np.multiply(images[visits[i]].get_fdata(), final_segmentation)
            path0 = image_paths[visits[i]]
            for j in range(i+1, len(visits)):
                pathnext = image_paths[visits[j]]
                current_image = images[visits[j]]
                segmentnext = np.multiply(current_image.get_fdata(), final_segmentation)

                # Phase 1: compute all maps, then crop all of them, before any
                # row is recorded or any file is written for this pair.
                salmaps = [compute(segment0, segmentnext) for _, _, compute in saliency_variants]
                crops = [crop_region_image(salmap) for salmap in salmaps]

                # Phase 2: one coordinate row per saliency variant.
                visit_pair = f'{visits[i]}_{visits[j]}'
                for (type_label, _, _), (_, coords) in zip(saliency_variants, crops):
                    dfs.append(pd.DataFrame({
                        'PatientID': [pid], 'Visit': [visit_pair], 'Type': [type_label], 'Orientation': [img_orientations[visits[j]]],
                        'Min_X': [coords[0]], 'Max_X': [coords[1]],
                        'Min_Y': [coords[2]], 'Max_Y': [coords[3]],
                        'Min_Z': [coords[4]], 'Max_Z': [coords[5]],
                        'BL_Path': [path0], 'CURR_Path': [pathnext]}))

                # Phase 3: write the full maps, then the cropped maps.
                pair_dir = os.path.join(dest_dir, extracted_region, f'sample_{i:02d}', f'{visits[i]:03d}_{visits[j]:03d}')
                for (_, suffix, _), salmap in zip(saliency_variants, salmaps):
                    save_image(salmap, current_image, pair_dir, f'salmap_{suffix}.nii')
                for (_, suffix, _), (cropped_map, _) in zip(saliency_variants, crops):
                    save_image(cropped_map, current_image, pair_dir, f'crop_salmap_{suffix}.nii')

    progress += 1
    percentage = (progress / total) * 100
    display(HTML("<script>document.getElementById('progress-bar').style.width='{}%';</script>".format(percentage)))

In [None]:
# pd.concat raises on an empty list; when no visit pair was processed, fall
# back to an empty table that still has the expected columns.
coordinate_columns = ['PatientID', 'Visit', 'Type', 'Orientation', 'Min_X', 'Max_X', 'Min_Y', 'Max_Y', 'Min_Z', 'Max_Z', 'BL_Path', 'CURR_Path']
df_coordinate = pd.concat(dfs, ignore_index=True) if dfs else pd.DataFrame(columns=coordinate_columns)

In [None]:
# Create the output folder first so the table produced by the long-running
# loop is not lost to a missing directory.
os.makedirs('./data', exist_ok=True)
df_coordinate.to_csv(f'./data/ADNI_{extracted_region}_coordinate_2000.csv', header=True, index=False)