In [None]:
!pip install nibabel
!pip install antspyx
!pip install SimpleITK
!pip install positional-encodings[tensorflow]

In [None]:
import os
import csv
import ants
import time
import json
import numpy as np
import nibabel as nib
import nibabel.processing

from scipy import ndimage
from skimage import morphology

import SimpleITK as sitk
import matplotlib.pyplot as plt
from matplotlib import cm
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from scipy.ndimage import gaussian_gradient_magnitude

In [None]:
def image_registration(img, template, type_of_transform='Translation', interpolator='bSpline'):
    """Register ``img`` onto ``template`` with ANTs and return the resampled image.

    Parameters
    ----------
    img : ANTs image
        Moving image; it is registered to and resampled onto ``template``.
    template : ANTs image
        Fixed image that defines the target space.
    type_of_transform : str, optional
        Forwarded to ``ants.registration``. Defaults to ``'Translation'``.
    interpolator : str, optional
        Forwarded to ``ants.apply_transforms``. Defaults to ``'bSpline'``, which
        suits intensity images. For label images (e.g. an atlas) pass
        ``'nearestNeighbor'`` or ``'genericLabel'`` so that resampling cannot
        invent label values that do not exist in the input.

    Returns
    -------
    nibabel image
        ``img`` resampled into the space of ``template``.
    """
    registration = ants.registration(
        fixed=template,
        moving=img,
        aff_sampling=36,
        type_of_transform=type_of_transform,
    )
    transformed_ants = ants.apply_transforms(
        fixed=template,
        moving=img,
        interpolator=interpolator,
        transformlist=registration['fwdtransforms'],
    )
    return ants.to_nibabel(transformed_ants)

def postprocessing_segmentation(segmentation_mask, voxel_threshold=100, smoothing_sigma=1, expansion_distance=10):
    """Clean, smooth and dilate a segmentation mask.

    Parameters
    ----------
    segmentation_mask : array-like
        Mask values; voxels greater than 0.5 are treated as foreground.
    voxel_threshold : int, optional
        Connected components smaller than this many voxels are removed.
    smoothing_sigma : float, optional
        Sigma of the Gaussian filter applied to the cleaned mask before it is
        re-thresholded at 0.5.
    expansion_distance : int, optional
        Number of binary-dilation iterations. Values below 1 skip the dilation.

    Returns
    -------
    numpy.ndarray of bool
        The post-processed mask, same shape as the input.
    """
    # remove_small_objects only splits the foreground into connected components
    # when it receives a boolean array. An integer 0/1 array is interpreted as an
    # already-labelled image, i.e. all foreground voxels form ONE object and
    # small islands are never removed.
    binary_mask = np.asarray(segmentation_mask) > 0.5
    cleaned_mask = morphology.remove_small_objects(binary_mask, min_size=voxel_threshold, connectivity=1)

    smoothed_mask = ndimage.gaussian_filter(cleaned_mask.astype(float), sigma=smoothing_sigma) > 0.5

    # binary_dilation treats iterations < 1 as "repeat until nothing changes",
    # which would flood the whole volume; a non-positive distance means no growth.
    if expansion_distance < 1:
        return smoothed_mask
    return ndimage.binary_dilation(smoothed_mask, iterations=expansion_distance)

In [None]:
# Root folder of the dataset; the paths in the following cells are joined onto it.
source_dir = "/ngochuynh/f/Dataset/ADNI/"
# Atlas volume, relative to source_dir.
atlas_file = "temp/BN_Atlas_246_1mm.nii"
# Name of the sub-folder of temp/ that holds the input scans and receives the outputs.
status_type = "CN_MCI"

In [None]:
# Load the atlas volume and read its voxel data as a floating-point array.
atlas = nib.load(os.path.join(source_dir, atlas_file))
atlas_data = atlas.get_fdata()

In [None]:
# Load the `transformed_*` MRI volumes of subject 123_S_0106, one file per visit.
# The number in each variable name is the nominal visit month (img0 is month 0).
img0 = nib.load(os.path.join(source_dir, f"temp/{status_type}/transformed_123_S_0106_Month_000.0_MRI_2006-01-20.nii"))
img12 = nib.load(os.path.join(source_dir, f"temp/{status_type}/transformed_123_S_0106_Month_012.3_MRI_2007-01-24.nii"))
img24 = nib.load(os.path.join(source_dir, f"temp/{status_type}/transformed_123_S_0106_Month_024.8_MRI_2008-02-04.nii"))
# Disabled: this path belongs to a different subject (011_S_0021), so img36 is never defined.
# img36 = nib.load(os.path.join(source_dir, f"temp/{status_type}/transformed_011_S_0021_Month_037.0_MRI_2008-10-23.nii"))
img48 = nib.load(os.path.join(source_dir, f"temp/{status_type}/transformed_123_S_0106_Month_048.7_MRI_2010-01-21.nii"))
img60 = nib.load(os.path.join(source_dir, f"temp/{status_type}/transformed_123_S_0106_Month_061.0_MRI_2011-01-24.nii"))
img72 = nib.load(os.path.join(source_dir, f"temp/{status_type}/transformed_123_S_0106_Month_073.2_MRI_2012-01-25.nii"))
img84 = nib.load(os.path.join(source_dir, f"temp/{status_type}/transformed_123_S_0106_Month_085.8_MRI_2013-02-05.nii"))
img96 = nib.load(os.path.join(source_dir, f"temp/{status_type}/transformed_123_S_0106_Month_097.7_MRI_2014-01-28.nii"))
img108 = nib.load(os.path.join(source_dir, f"temp/{status_type}/transformed_123_S_0106_Month_110.1_MRI_2015-02-04.nii"))
# Disabled: also a path of subject 011_S_0021, so img120 is never defined.
# img120 = nib.load(os.path.join(source_dir, f"temp/{status_type}/transformed_011_S_0021_Month_122.5_MRI_2015-11-03.nii"))

In [None]:
# Convert the nibabel images to ANTs images so they can be passed to the ants functions.
img0_ants = ants.from_nibabel(img0)
img12_ants = ants.from_nibabel(img12)
img24_ants = ants.from_nibabel(img24)
# img36_ants = ants.from_nibabel(img36)
img48_ants = ants.from_nibabel(img48)
img60_ants = ants.from_nibabel(img60)
img72_ants = ants.from_nibabel(img72)
img84_ants = ants.from_nibabel(img84)
img96_ants = ants.from_nibabel(img96)
img108_ants = ants.from_nibabel(img108)
# img120_ants = ants.from_nibabel(img120)
atlas_ants = ants.from_nibabel(atlas)

In [None]:
# Register the atlas (moving) onto the baseline scan (fixed) and save the result.
# NOTE(review): image_registration resamples with a B-spline interpolator by default.
# The atlas is used as a label image further down (values are rounded and matched
# against label ids), and B-spline resampling can produce values between labels at
# region borders — confirm this is acceptable or use a label-preserving interpolator.
coreg_atlas0 = image_registration(atlas_ants, img0_ants)
nib.save(coreg_atlas0, os.path.join(source_dir, f'temp/{status_type}/atlas0_coreg.nii'))

In [None]:
# Region name; only used to build the output file names.
extracted_region = 'Hippocampus'
# subregions = region_dict[extracted_region]
# subregion_ids = [41, 42]
# Atlas label values that together make up the extracted region.
subregion_ids = [215, 216, 217, 218]
# subregion_ids = [490, 317, 386, 488, 489, 487, 352, 455, 381]

In [None]:
# Round the resampled atlas to integer labels and mark the voxels whose label is in subregion_ids.
rounded_image0 = np.round(coreg_atlas0.get_fdata()).astype(int)
extracted_mask0 = np.isin(rounded_image0, subregion_ids)
# Wrap the boolean mask with the baseline scan's affine and header, then write it to disk.
mask_img0 = nib.Nifti1Image(extracted_mask0, img0.affine, header=img0.header)
nib.save(mask_img0, os.path.join(source_dir, f'temp/{status_type}/mask0_{extracted_region}.nii'))

In [None]:
# Clean up the raw region mask: drop components under 75 voxels, smooth with sigma 1, dilate 3 iterations.
final_segmentation = postprocessing_segmentation(mask_img0.get_fdata(), voxel_threshold=75, smoothing_sigma=1, expansion_distance=3)

In [None]:
# Save the post-processed mask, again with the baseline scan's affine and header.
mask_final = nib.Nifti1Image(final_segmentation, img0.affine, header=img0.header)
nib.save(mask_final, os.path.join(source_dir, f'temp/{status_type}/mask_{extracted_region}.nii'))

In [None]:
# Apply the mask to every visit: voxels outside final_segmentation become 0, the others
# keep their intensity. The same mask (derived from the baseline scan) is used for all visits.
segment0 = np.multiply(img0.get_fdata(), final_segmentation)
segment12 = np.multiply(img12.get_fdata(), final_segmentation)
segment24 = np.multiply(img24.get_fdata(), final_segmentation)
# segment36 = np.multiply(img36.get_fdata(), final_segmentation)
segment48 = np.multiply(img48.get_fdata(), final_segmentation)
segment60 = np.multiply(img60.get_fdata(), final_segmentation)
segment72 = np.multiply(img72.get_fdata(), final_segmentation)
segment84 = np.multiply(img84.get_fdata(), final_segmentation)
segment96 = np.multiply(img96.get_fdata(), final_segmentation)
segment108 = np.multiply(img108.get_fdata(), final_segmentation)
# segment120 = np.multiply(img120.get_fdata(), final_segmentation)

In [None]:
# Save the masked baseline volume; saving of the later visits is disabled below.
segment0_img = nib.Nifti1Image(segment0, img0.affine, header=img0.header)
nib.save(segment0_img, os.path.join(source_dir, f'temp/{status_type}/seg0_{extracted_region}.nii'))

# segment12_img = nib.Nifti1Image(segment12, img12.affine, header=img12.header)
# nib.save(segment12_img, os.path.join(source_dir, f'temp/{status_type}/seg12_{extracted_region}.nii'))

# segment24_img = nib.Nifti1Image(segment24, img24.affine, header=img24.header)
# nib.save(segment24_img, os.path.join(source_dir, f'temp/{status_type}/seg24_{extracted_region}.nii'))

# segment36_img = nib.Nifti1Image(segment36, img36.affine, header=img36.header)
# nib.save(segment36_img, os.path.join(source_dir, f'temp/{status_type}/seg36_{extracted_region}.nii'))

# segment48_img = nib.Nifti1Image(segment48, img48.affine, header=img48.header)
# nib.save(segment48_img, os.path.join(source_dir, f'temp/{status_type}/seg48_{extracted_region}.nii'))

# segment60_img = nib.Nifti1Image(segment60, img60.affine, header=img60.header)
# nib.save(segment60_img, os.path.join(source_dir, f'temp/{status_type}/seg60_{extracted_region}.nii'))

In [None]:
def _minmax_normalize(volume):
    """Rescale ``volume`` linearly to [0, 1]; a constant volume maps to all zeros."""
    volume = np.asarray(volume, dtype=float)
    lowest = np.min(volume)
    value_range = np.max(volume) - lowest
    if value_range == 0:
        # A constant volume would give 0/0 = NaN everywhere and poison the plot.
        return np.zeros_like(volume)
    return (volume - lowest) / value_range


def _axis_gradients(volume):
    """Return the gradients of the min-max normalised volume along axes 0, 1 and 2."""
    normalized = _minmax_normalize(volume)
    return tuple(np.gradient(normalized, axis=axis) for axis in range(3))


def _render_saliency(saliency, output_path, elevation_angle, azimuthal_angle, roll_angle):
    """Plot ``saliency`` as one 3D view (scalar angles) or a 2x2 grid (angle sequences)."""
    # matplotlib.cm.get_cmap was deprecated in Matplotlib 3.7 and later removed;
    # pyplot.get_cmap is the stable accessor.
    colormap = plt.get_cmap('jet')
    if np.ndim(elevation_angle) == 0:
        visualize_3d_saliency(saliency, output_path, colormap, elevation_angle, azimuthal_angle, roll_angle)
    else:
        visualize_3d_saliency_grid(saliency, output_path, colormap, elevation_angle, azimuthal_angle, roll_angle)


def compute_gradient_map(image_visit1, image_visit2, output_path, elevation_angle=30, azimuthal_angle=45, roll_angle=0):
    """Plot the absolute difference of the gradient magnitudes of two volumes.

    Both volumes are min-max normalised, their gradient magnitude is computed,
    and ``|magnitude1 - magnitude2|`` is rendered and saved to ``output_path``.

    Parameters
    ----------
    image_visit1, image_visit2 : numpy.ndarray
        Volumes of identical shape.
    output_path : str
        Where the figure is written.
    elevation_angle, azimuthal_angle, roll_angle : number or sequence of 4 numbers
        Camera angles. Scalars render a single view; sequences render a 2x2
        grid with one view per entry.
    """
    gradient1 = np.gradient(_minmax_normalize(image_visit1))
    gradient2 = np.gradient(_minmax_normalize(image_visit2))
    saliency_map1 = np.sqrt(np.sum(np.square(gradient1), axis=0))
    saliency_map2 = np.sqrt(np.sum(np.square(gradient2), axis=0))

    gradient_difference = np.abs(saliency_map1 - saliency_map2)
    _render_saliency(gradient_difference, output_path, elevation_angle, azimuthal_angle, roll_angle)


def compute_gradient_v2_map(image_visit1, image_visit2, output_path, direction, elevation_angle=30, azimuthal_angle=45, roll_angle=0):
    """Plot the absolute difference of in-plane gradient orientations of two volumes.

    ``direction`` selects the plane: ``'xy'``, ``'xz'`` or ``'yz'``. For a plane
    ``'ab'`` the orientation is ``arctan2(gradient_b, gradient_a)``, where x, y, z
    are array axes 0, 1, 2.

    Parameters
    ----------
    image_visit1, image_visit2 : numpy.ndarray
        3D volumes of identical shape.
    output_path : str
        Where the figure is written.
    direction : {'xy', 'xz', 'yz'}
        Plane in which the orientation is measured.
    elevation_angle, azimuthal_angle, roll_angle : number or sequence of 4 numbers
        Camera angles; scalars give a single view, sequences a 2x2 grid.

    Raises
    ------
    ValueError
        If ``direction`` is not one of the supported planes.
    """
    # plane -> (axis used as numerator, axis used as denominator) of arctan2
    plane_axes = {'xy': (1, 0), 'xz': (2, 0), 'yz': (2, 1)}
    if direction not in plane_axes:
        raise ValueError(f"direction must be one of {sorted(plane_axes)}, got {direction!r}")
    numerator_axis, denominator_axis = plane_axes[direction]

    gradients1 = _axis_gradients(image_visit1)
    gradients2 = _axis_gradients(image_visit2)
    saliency_map1 = np.arctan2(gradients1[numerator_axis], gradients1[denominator_axis])
    saliency_map2 = np.arctan2(gradients2[numerator_axis], gradients2[denominator_axis])

    gradient_difference = np.abs(saliency_map2 - saliency_map1)
    _render_saliency(gradient_difference, output_path, elevation_angle, azimuthal_angle, roll_angle)


def compute_gradient_v3_map(image_visit1, image_visit2, output_path, direction, elevation_angle=30, azimuthal_angle=45, roll_angle=0):
    """Plot the absolute difference of the gradient's angle to one axis.

    ``direction`` selects the reference axis: ``'x'``, ``'y'`` or ``'z'`` (array
    axes 0, 1, 2). The angle is ``arctan2(norm of the two other components,
    component along the reference axis)``.

    Parameters
    ----------
    image_visit1, image_visit2 : numpy.ndarray
        3D volumes of identical shape.
    output_path : str
        Where the figure is written.
    direction : {'x', 'y', 'z'}
        Reference axis.
    elevation_angle, azimuthal_angle, roll_angle : number or sequence of 4 numbers
        Camera angles; scalars give a single view, sequences a 2x2 grid.

    Raises
    ------
    ValueError
        If ``direction`` is not one of the supported axes.
    """
    reference_axis_by_name = {'x': 0, 'y': 1, 'z': 2}
    if direction not in reference_axis_by_name:
        raise ValueError(f"direction must be one of {sorted(reference_axis_by_name)}, got {direction!r}")
    reference_axis = reference_axis_by_name[direction]
    first_other, second_other = [axis for axis in range(3) if axis != reference_axis]

    def angle_to_axis(gradients):
        transverse = np.sqrt(gradients[first_other]**2 + gradients[second_other]**2)
        return np.arctan2(transverse, gradients[reference_axis])

    saliency_map1 = angle_to_axis(_axis_gradients(image_visit1))
    saliency_map2 = angle_to_axis(_axis_gradients(image_visit2))

    gradient_difference = np.abs(saliency_map2 - saliency_map1)
    _render_saliency(gradient_difference, output_path, elevation_angle, azimuthal_angle, roll_angle)


def compute_saliency_map(image_visit1, image_visit2, output_path, elevation_angle=30, azimuthal_angle=45, roll_angle=0):
    """Plot the absolute voxel-wise difference of two min-max normalised volumes.

    Parameters
    ----------
    image_visit1, image_visit2 : numpy.ndarray
        Volumes of identical shape.
    output_path : str
        Where the figure is written.
    elevation_angle, azimuthal_angle, roll_angle : number or sequence of 4 numbers
        Camera angles; scalars give a single view, sequences a 2x2 grid.
    """
    difference_image = np.abs(_minmax_normalize(image_visit2) - _minmax_normalize(image_visit1))
    _render_saliency(difference_image, output_path, elevation_angle, azimuthal_angle, roll_angle)
    
def visualize_3d_saliency(saliency_map, output_path, colormap, elevation_angle, azimuthal_angle, roll_angle):
    """Render the non-zero voxels of a 3D saliency map as a scatter plot and save it.

    Parameters
    ----------
    saliency_map : numpy.ndarray
        3D array; every non-zero voxel becomes one point, coloured by its value.
    output_path : str
        Where the figure is written.
    colormap : matplotlib colormap
        Colormap used for the point colours.
    elevation_angle, azimuthal_angle, roll_angle : number
        Camera angles passed to ``ax.view_init``.
    """
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    # Coordinates of the non-zero voxels, one index array per axis.
    xs, ys, zs = np.nonzero(saliency_map)
    ax.scatter(xs, ys, zs, c=saliency_map[xs, ys, zs],
               cmap=colormap, marker='o', alpha=0.5)

    ax.view_init(elevation_angle, azimuthal_angle, roll_angle)
    ax.set_axis_off()

    fig.savefig(output_path, bbox_inches='tight', pad_inches=0)
    plt.show()
    # This function is called once per visit and view; release the figure so
    # repeated calls do not accumulate open figures.
    plt.close(fig)

def visualize_3d_saliency_grid(saliency_map, output_path, colormap, elevation_angle, azimuthal_angle, roll_angle):
    """Render a 3D saliency map from four viewpoints in a 2x2 grid and save it.

    Parameters
    ----------
    saliency_map : numpy.ndarray
        3D array; every non-zero voxel becomes one point, coloured by its value.
    output_path : str
        Where the figure is written.
    colormap : matplotlib colormap
        Colormap used for the point colours.
    elevation_angle, azimuthal_angle, roll_angle : number or sequence of numbers
        Camera angles. A sequence supplies one value per panel (at least four
        entries); a scalar is used for all four panels.

    Raises
    ------
    ValueError
        If an angle sequence has fewer than four entries.
    """
    n_panels = 4

    def per_panel(angle, label):
        # Accept a single number as shorthand for "same angle in every panel".
        if np.ndim(angle) == 0:
            return [angle] * n_panels
        angles = list(angle)
        if len(angles) < n_panels:
            raise ValueError(f"{label} needs {n_panels} entries, got {len(angles)}")
        return angles

    elevations = per_panel(elevation_angle, 'elevation_angle')
    azimuths = per_panel(azimuthal_angle, 'azimuthal_angle')
    rolls = per_panel(roll_angle, 'roll_angle')

    fig, axes = plt.subplots(nrows=2, ncols=2, subplot_kw={'projection': '3d'})
    xs, ys, zs = np.nonzero(saliency_map)
    values = saliency_map[xs, ys, zs]
    for i, ax in enumerate(axes.flat):
        sc = ax.scatter(xs, ys, zs, c=values, cmap=colormap, marker='o', alpha=0.5)
        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')
        ax.view_init(elevations[i], azimuths[i], rolls[i])

    # All panels share the same data, so one colorbar (from the last scatter) serves them all.
    fig.colorbar(sc, ax=axes.ravel().tolist(), shrink=0.8, pad=0.1)
    fig.savefig(output_path, bbox_inches='tight', pad_inches=0)
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

In [None]:
def visualize_all_saliency(list_visits, saliency_type='diff', direction=None):
    """Render single-view saliency figures of each visit against the baseline.

    For every visit in ``list_visits`` the notebook-level volume ``segment<visit>``
    is compared with ``segment0`` and four figures are written, one per camera
    view (elevation/azimuth 30/45, 30/90, 30/135 and 210/90).

    Parameters
    ----------
    list_visits : iterable of int
        Visit months; each needs a notebook-level variable ``segment<visit>``.
    saliency_type : {'diff', 'grad', 'grad_v2', 'grad_v3'}
        Which saliency computation to run.
    direction : str, optional
        Plane or axis forwarded to the 'grad_v2' / 'grad_v3' computations; it is
        also part of their output file names.

    Raises
    ------
    ValueError
        If ``saliency_type`` is unknown or a visit has no ``segment<visit>`` volume.
    """
    compute_by_type = {
        'diff': compute_saliency_map,
        'grad': compute_gradient_map,
        'grad_v2': compute_gradient_v2_map,
        'grad_v3': compute_gradient_v3_map,
    }
    if saliency_type not in compute_by_type:
        raise ValueError(f"Unknown saliency_type {saliency_type!r}; expected one of {sorted(compute_by_type)}")
    compute = compute_by_type[saliency_type]

    takes_direction = saliency_type in ('grad_v2', 'grad_v3')
    extra_args = (direction,) if takes_direction else ()
    name_tag = f'{saliency_type}_{direction}' if takes_direction else saliency_type

    # The per-visit volumes live in individual notebook variables (segment12,
    # segment24, ...). Resolve all of them before rendering anything, so that a
    # missing visit fails loudly instead of silently reusing another volume.
    namespace = globals()
    targets = []
    for visit in list_visits:
        variable = f'segment{visit}'
        if variable not in namespace:
            raise ValueError(f"No volume loaded for visit {visit!r} (variable {variable!r} is not defined)")
        targets.append((visit, namespace[variable]))

    views = ((30, 45), (30, 90), (30, 135), (210, 90))
    for visit, target_visit in targets:
        for elevation, azimuth in views:
            output_path = os.path.join(source_dir, f'temp/{status_type}/saliency_{name_tag}_0_{visit}_{extracted_region}_{elevation}_{azimuth}.png')
            compute(segment0, target_visit, output_path, *extra_args,
                    elevation_angle=elevation, azimuthal_angle=azimuth, roll_angle=0)

In [None]:
def visualize_all_saliency_grid(list_visits, saliency_type='diff', direction=None):
    """Render one four-view saliency figure per visit against the baseline.

    For every visit in ``list_visits`` the notebook-level volume ``segment<visit>``
    is compared with ``segment0``. The camera angles are passed as four-element
    lists (one entry per panel): elevation 30/210/30/30, azimuth 90/90/45/135,
    roll 0.

    Parameters
    ----------
    list_visits : iterable of int
        Visit months; each needs a notebook-level variable ``segment<visit>``.
    saliency_type : {'diff', 'grad', 'grad_v2', 'grad_v3'}
        Which saliency computation to run.
    direction : str, optional
        Plane or axis forwarded to the 'grad_v2' / 'grad_v3' computations; it is
        also part of their output file names.

    Raises
    ------
    ValueError
        If ``saliency_type`` is unknown or a visit has no ``segment<visit>`` volume.
    """
    compute_by_type = {
        'diff': compute_saliency_map,
        'grad': compute_gradient_map,
        'grad_v2': compute_gradient_v2_map,
        'grad_v3': compute_gradient_v3_map,
    }
    if saliency_type not in compute_by_type:
        raise ValueError(f"Unknown saliency_type {saliency_type!r}; expected one of {sorted(compute_by_type)}")
    compute = compute_by_type[saliency_type]

    takes_direction = saliency_type in ('grad_v2', 'grad_v3')
    extra_args = (direction,) if takes_direction else ()
    name_tag = f'{saliency_type}_{direction}' if takes_direction else saliency_type

    # The per-visit volumes live in individual notebook variables (segment12,
    # segment24, ...). Resolve all of them before rendering anything, so that a
    # missing visit fails loudly instead of silently reusing another volume.
    namespace = globals()
    targets = []
    for visit in list_visits:
        variable = f'segment{visit}'
        if variable not in namespace:
            raise ValueError(f"No volume loaded for visit {visit!r} (variable {variable!r} is not defined)")
        targets.append((visit, namespace[variable]))

    for visit, target_visit in targets:
        output_path = os.path.join(source_dir, f'temp/{status_type}/saliency_{name_tag}_0_{visit}_{extracted_region}.png')
        compute(segment0, target_visit, output_path, *extra_args,
                elevation_angle=[30, 210, 30, 30], azimuthal_angle=[90, 90, 45, 135], roll_angle=[0, 0, 0, 0])

In [None]:
# Visits to compare with the baseline; 36 and 120 are left out because their volumes are not loaded above.
list_visits = [12,24,48,60,72,84,96,108]
visualize_all_saliency(list_visits, 'diff')
# visualize_all_saliency_grid(list_visits, 'grad')
# visualize_all_saliency_grid(list_visits, 'grad_v2', direction='xy')
# visualize_all_saliency_grid(list_visits, 'grad_v2', direction='xz')
# visualize_all_saliency_grid(list_visits, 'grad_v2', direction='yz')
# visualize_all_saliency_grid(list_visits, 'grad_v3', direction='x')
# visualize_all_saliency_grid(list_visits, 'grad_v3', direction='y')
# visualize_all_saliency_grid(list_visits, 'grad_v3', direction='z')

In [None]:
import os
import glob
import shutil
import pandas as pd

In [None]:
# Names of the files in ./data that end in '0.csv'.
csv_files = [file for file in os.listdir('./data') if file.endswith('0.csv')]

In [None]:
csv_files

In [None]:
# Read every selected CSV once and concatenate them in a single call; growing a
# DataFrame with pd.concat inside a loop copies all previous rows on every iteration.
#df_filtered = df[(df['Min_X'] < 40) & (df['Type'] == 'DG1')]
csv_frames = [pd.read_csv(os.path.join('./data', file)) for file in csv_files]
# pd.concat raises on an empty list, so fall back to an empty frame when no file matched.
merged_data = pd.concat(csv_frames, ignore_index=True) if csv_frames else pd.DataFrame()

In [None]:
# Write the merged table, with a header row and without the index column.
merged_data.to_csv('./data/ADNI_Hippocampus_merged.csv', header=True, index=False)

In [None]:
# From here on source_dir is rebound to the ADNI_saliency folder (it was the dataset root before).
source_dir = '/ngochuynh/f/Dataset/ADNI/ADNI_saliency'

In [None]:
# wrong_data: rows to clean up (its 'PTID' and 'Visit' columns are read in the cells below).
wrong_data = pd.read_csv('./data/wrong_coreg.csv')
# data: the merged table written above.
data = pd.read_csv('./data/ADNI_Hippocampus_merged.csv')

In [None]:
def delete_folder(ptid, parent_dir):
    """Remove ``parent_dir/ptid`` with everything below it; do nothing if it does not exist."""
    target = os.path.join(parent_dir, ptid)
    if not os.path.exists(target):
        return
    shutil.rmtree(target)

In [None]:
def delete_folder_v2(ptid, visit, parent_dir):
    """Delete every visit-pair folder of a subject that involves ``visit``.

    Looks at ``parent_dir/ptid/Hippocampus/sample*/<a>_<b>...``: an entry is removed
    when either of its first two underscore-separated fields equals ``visit``.
    A sample folder left without any sub-folder is removed as well.
    """
    sample_pattern = os.path.join(parent_dir, ptid, 'Hippocampus', 'sample*')
    for sample_path in glob.glob(sample_pattern):
        for pair_path in glob.glob(os.path.join(sample_path, '*')):
            name_fields = os.path.basename(pair_path).split('_')
            first_visit = int(name_fields[0])
            second_visit = int(name_fields[1])
            if visit in (first_visit, second_visit):
                shutil.rmtree(pair_path)

        # Drop the sample folder itself once no visit-pair folder remains in it.
        has_subfolder = any(
            os.path.isdir(os.path.join(sample_path, entry))
            for entry in os.listdir(sample_path)
        )
        if not has_subfolder:
            shutil.rmtree(sample_path)

In [None]:
# For every listed (PTID, Visit) pair, delete the matching visit folders on disk.
# Rows with Visit == 999 are skipped here; the next cell handles them.
for index, row in wrong_data.iterrows():
    if row['Visit'] != 999:
        delete_folder_v2(row['PTID'], row['Visit'], source_dir)

In [None]:
# Rows with Visit == 999 name a whole subject: remove all of that subject's rows from `data`.
# Note that `data` is overwritten in place, so the unfiltered table is gone after this cell.
for index, row in wrong_data.iterrows():
    if row['Visit'] == 999:
        ptid_to_drop = row['PTID']
        # Drop rows where PatientID is equal to ptid_to_drop
        data = data[data['PatientID'] != ptid_to_drop]

In [None]:
# One-off, destructive clean-up: delete the visit-48 folders of subject 002_S_4473 from disk.
delete_folder_v2('002_S_4473', 48, source_dir)

In [None]:
# Load the label table and add an empty SAL_PATHS column that is filled in below.
df = pd.read_csv('./data/ADNIMERGE_modified_labels.csv', low_memory=False)
df['SAL_PATHS'] = None
# Folder containing one sub-folder per subject.
source_dir = '/ngochuynh/f/Dataset/ADNI/ADNI_saliency'

In [None]:
df_new = pd.DataFrame(columns=list(df.columns))

In [None]:
# Names of the sub-folders of source_dir (plain files are ignored).
items = os.listdir(source_dir)
subfolders = [item for item in items if os.path.isdir(os.path.join(source_dir, item))]

In [None]:
# Build the metadata rows for every (subject, sample): the label row(s) of the
# sample's first visit month, with SAL_PATHS set to the sample's visit folders.
# The rows are collected in `dfs` and concatenated once at the end, which also
# makes the cell idempotent (re-running it does not append duplicates to df_new).
dfs = []
for ptid in subfolders:
    pre_num = int(ptid.split("_")[0])
    if pre_num > 126:
        continue
    region_dir = os.path.join(source_dir, ptid, "Hippocampus")
    if not os.path.isdir(region_dir):
        continue
    subdirs = sorted(item for item in os.listdir(region_dir) if os.path.isdir(os.path.join(region_dir, item)))
    for sample in subdirs:
        sample_dir = os.path.join(region_dir, sample)
        # os.listdir returns entries in arbitrary order; sort for a reproducible result.
        visits = sorted(os.listdir(sample_dir))
        if not visits:
            continue
        bl_visit = visits[0].split("_")[0]
        df_temp = df[(df["PTID"]==ptid) & (df["M"]==int(bl_visit))].reset_index(drop=True)
        if df_temp.empty:
            # No label row for this subject/month: nothing to attach the paths to.
            continue
        visit_paths = [os.path.join(sample_dir, v) for v in visits]
        # One copy of the path list per matching row (normally exactly one row).
        df_temp['SAL_PATHS'] = [visit_paths] * len(df_temp)
        dfs.append(df_temp)

# Keep the empty, column-only df_new from the previous cell when nothing was collected.
if dfs:
    df_new = pd.concat(dfs)

In [None]:
df_sorted = df_new.sort_values(by=['RID', 'M'], ascending=[True, True])

In [None]:
# Keep rows where VISMISS is False and PROGRESS_STATE_3Y is present, then renumber the index.
df_sorted = df_sorted[(df_sorted['VISMISS'] == False) & (~df_sorted['PROGRESS_STATE_3Y'].isna())]
df_sorted.reset_index(drop=True, inplace=True)

In [None]:
# Write the table to disk. The RID filter and the CONV_STATE columns in the cells
# below are applied after this point and are therefore not part of the saved file.
df_sorted.to_csv('./data/ADNI_saliency.csv', header=True, index=False)

In [None]:
# Drop rows without an RID.
df_sorted = df_sorted[df_sorted['RID'].notna()]

In [None]:
# Column(s) read by the conversion-state functions and loop below.
check_columns = ['PROGRESS_STATE_5Y']

In [None]:
def determine_conv_state_group(group):
    """Classify a group of rows by the progression states in ``check_columns``.

    Missing values are ignored. The result is

    * ``0`` if every state is ``'CN_CN'`` or ``'MCI_MCI'``,
    * ``1`` if any state is ``'CN_MCI'``, ``'CN_Dementia'`` or ``'MCI_Dementia'``,
    * ``2`` if every state is ``'Dementia_Dementia'``,
    * ``-1`` otherwise, including when the group has no recorded state at all.
    """
    states = {value for value in group[check_columns].values.flatten().tolist() if pd.notna(value)}
    if not states:
        # all() over an empty collection is True, which would label a group
        # without any recorded state as 0; treat it as unknown instead.
        return -1
    if states <= {'CN_CN', 'MCI_MCI'}:
        return 0
    if states & {'CN_MCI', 'CN_Dementia', 'MCI_Dementia'}:
        return 1
    if states == {'Dementia_Dementia'}:
        return 2
    return -1

In [None]:
def determine_conv_state(row):
    """Classify a single row by the progression state(s) in ``check_columns``.

    Missing values are ignored. The result is

    * ``0`` if every state is ``'CN_CN'`` or ``'MCI_MCI'``,
    * ``1`` if any state is ``'CN_MCI'``, ``'CN_Dementia'`` or ``'MCI_Dementia'``,
    * ``2`` if every state is ``'Dementia_Dementia'``,
    * ``-1`` otherwise, including when no state is recorded.
    """
    # Work on plain Python values: comparing the NumPy array returned by
    # ``.values`` with a string only has a usable truth value while
    # check_columns holds exactly one column.
    states = {value for value in row[check_columns].tolist() if pd.notna(value)}
    if not states:
        return -1
    if states <= {'CN_CN', 'MCI_MCI'}:
        return 0
    if states & {'CN_MCI', 'CN_Dementia', 'MCI_Dementia'}:
        return 1
    if states == {'Dementia_Dementia'}:
        return 2
    return -1

In [None]:
# Row-wise label: each row is classified from its own value(s) in check_columns.
df_sorted['CONV_STATE_5Y'] = df_sorted.apply(determine_conv_state, axis=1)

In [None]:
# Subject-level label: classify each RID from all of its rows with the group-wise
# function, then broadcast the per-subject result back onto the rows by RID.
# (The row-wise determine_conv_state cannot be applied to a whole group, and the
# result of groupby.apply is indexed by RID rather than by row.)
subject_conv_state = df_sorted.groupby('RID')[check_columns].apply(determine_conv_state_group)
df_sorted['CONV_STATE'] = df_sorted['RID'].map(subject_conv_state)

In [None]:
# Per-row list of subject-level labels, built subject by subject in the order of
# df_sorted['RID'].unique(): 0 = only CN_CN/MCI_MCI, 1 = any conversion,
# 2 = only Dementia_Dementia, None = mixed or no recorded state.
conv_state_list = []
for rid in df_sorted['RID'].unique():
    group = df_sorted[df_sorted['RID'] == rid]
    values_list = [value for value in group[check_columns].values.flatten().tolist() if pd.notna(value)]

    if not values_list:
        # all() over an empty list is True; a subject without any recorded
        # state must not be labelled 0.
        state = None
    elif all(visit == 'CN_CN' or visit == 'MCI_MCI' for visit in values_list):
        state = 0
    elif any(visit in ['CN_MCI', 'CN_Dementia', 'MCI_Dementia'] for visit in values_list):
        state = 1
    elif all(visit == 'Dementia_Dementia' for visit in values_list):
        state = 2
    else:
        state = None
    conv_state_list.extend([state] * len(group))

In [None]:
df_sorted['CONV_STATE_5Y'].value_counts()

In [None]:
import os
import glob
import nibabel as nib
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from tqdm.notebook import tqdm

from IPython.display import display

In [None]:
def visualize_slice(image_path, mask_path, slice_index, dimension='axial', output_path=None):
    """Draw one slice of an image with a mask overlaid and save or show it.

    Parameters
    ----------
    image_path : str
        Path of the image volume (loaded with nibabel).
    mask_path : str
        Path of the mask volume; it is indexed with the same slice as the image.
    slice_index : int
        Index of the slice along the chosen axis.
    dimension : {'axial', 'coronal', 'sagittal'}
        'axial' slices array axis 2, 'coronal' axis 1, 'sagittal' axis 0.
    output_path : str, optional
        PNG file to write. When omitted the figure is shown instead of saved.

    Raises
    ------
    ValueError
        If ``dimension`` is not one of the supported names.
    """
    everything = slice(None)
    slicers = {
        'axial': (everything, everything, slice_index),
        'coronal': (everything, slice_index, everything),
        'sagittal': (slice_index, everything, everything),
    }
    # Validate before touching the disk: loading two volumes is the expensive part.
    if dimension not in slicers:
        raise ValueError("Invalid dimension. Use 'axial', 'coronal', or 'sagittal'.")
    slicer = slicers[dimension]

    image = nib.load(image_path).get_fdata()
    mask = nib.load(mask_path).get_fdata()
    slice_data = image[slicer]
    mask_slice = mask[slicer]

    cmap = ListedColormap(['black', 'red'], name='overlay_cmap')

    # Use a dedicated figure instead of whatever figure happens to be current.
    fig, ax = plt.subplots()
    try:
        ax.imshow(slice_data, cmap='gray')
        # Overlay the mask with transparency
        ax.imshow(mask_slice, cmap=cmap, alpha=0.3)
        ax.axis("off")
        if output_path is None:
            # savefig(None) fails, so without a target the figure is displayed.
            plt.show()
        else:
            fig.savefig(output_path, format='png', bbox_inches='tight', pad_inches=0, transparent=True, dpi=300)
    finally:
        plt.close(fig)

In [None]:
# Inputs for a single overlay test: one image, one mask, and the slice to draw.
image_path = "/ngochuynh/f/Dataset/ADNI/smooth_002_S_0295_Month_006.6_MRI_2006-11-02.nii"
mask_path = "/ngochuynh/f/Dataset/ADNI/region_mask.nii"
slice_index = 105
dimension = "coronal"
output_path = "/ngochuynh/f/Dataset/ADNI/overlay.png"

In [None]:
visualize_slice(image_path, mask_path, slice_index, dimension, output_path)

In [None]:
# source_dir is rebound to the dataset root again.
source_dir = "/ngochuynh/f/Dataset/ADNI/"
all_items = os.listdir(os.path.join(source_dir, "ADNI_saliency"))
# Subject folders whose numeric name prefix is greater than 136.
subfolders = [item for item in all_items if (os.path.isdir(os.path.join(source_dir, "ADNI_saliency", item))) and (int(item.split('_')[0])>136)]

In [None]:
slice_index = 105
dimension = "coronal"

In [None]:
# For every selected subject, overlay the subject's region mask on the chosen slice
# of each of its smooth_*.nii images and write one PNG per image next to the mask.
for fname in tqdm(subfolders):
    mask_path = os.path.join(source_dir,"ADNI_saliency",fname,"Hippocampus/region_mask.nii")
    list_images = glob.glob(os.path.join(source_dir,"ADNI_renamed",fname,"preprocessed/MRI/smooth_*.nii"))
    for image_path in list_images:
        basename = os.path.basename(image_path)
        output_path = os.path.join(source_dir,"ADNI_saliency",fname,f"Hippocampus/{dimension}_{slice_index}_{basename}_overlay.png")
        visualize_slice(image_path, mask_path, slice_index, dimension, output_path)

In [None]:
import tensorflow as tf
import layers as CLayers

In [None]:
# Smoke test: run the SCN3D layer once on a random 5-D tensor.
# NOTE(review): num_channels=[16, 23, 64] — 23 may be a typo for 32; confirm.
input_shape = (4, 85, 50, 45, 4)
inputs = tf.random.uniform(input_shape)
scn_layer = CLayers.SCN3D(num_channels=[16, 23, 64], kernel_size=3, pool_size=2, dropout=0.2)
scn_feat = scn_layer(inputs)

In [None]:
# Feed the SCN3D output through the SpatialPyramidPooling3D layer.
aspp_layer = CLayers.SpatialPyramidPooling3D(16, [6,12,18,24], dropout=0.2)
aspp_feat = aspp_layer(scn_feat)

In [None]:
# Merge axes 1-3 into a single axis (keeping the last axis), then swap the last two axes.
# Note: aspp_feat is overwritten, so this cell cannot be re-run without re-running the previous one.
aspp_feat = tf.reshape(aspp_feat, [-1, tf.shape(aspp_feat)[1]*tf.shape(aspp_feat)[2]*tf.shape(aspp_feat)[3], tf.shape(aspp_feat)[4]])
aspp_feat = tf.transpose(aspp_feat, perm=[0, 2, 1])

In [None]:
# Feed the reshaped features through the TemporalConvNet layer.
tcn_layer = CLayers.TemporalConvNet(num_channels=[512, 256, 128], kernel_size=3, dropout=0.2)
tcn_feat = tcn_layer(aspp_feat)

In [None]:
tcn_feat

In [None]:
import tensorflow as tf
import numpy as np
from positional_encodings.tf_encodings import TFPositionalEncoding3D, TFSummer

In [None]:
# Compute a 3D positional encoding for a random 5-D tensor and add it to the tensor.
p_enc_3d = TFPositionalEncoding3D(10)
z = tf.convert_to_tensor(np.random.uniform(size=(4,8,6,2,1)), dtype=tf.float32)
pe = p_enc_3d(z)
pe_z = z + pe

In [None]:
pe_z[0,...,0]

In [None]:
from scipy import ndimage
import time

In [None]:
# Random input batch, keyed by the input name the model is configured with below.
input_shape = (30, 85, 50, 45, 9)
inputs = {'AbsDiff':tf.random.uniform(input_shape)}

In [None]:
# Keyword arguments for CLayers.UnimodelS1_CNN_Attention; input_name matches the key of `inputs`.
model_config = dict(
    input_name='AbsDiff',
    num_classes=3,
    num_filters=[8,16,32,64],
    bap_filters=32,
    fc_units=[64, 32],
    kernel_size=3,
    pool_size=2,
    dropout=0.35,
)
model = CLayers.UnimodelS1_CNN_Attention(**model_config)

In [None]:
feat, att_map = model(inputs)

In [None]:
def random_select_and_resize(input_tensor, feature_map):
    """Pick random feature channels and resample them to the input's spatial size.

    ``c`` distinct channels (``c`` = number of channels of ``input_tensor``) are
    drawn at random from ``feature_map``. Each sample is then resized with
    ``ndimage.zoom`` (spline order 2) so that its three spatial axes match those
    of ``input_tensor``.

    Parameters
    ----------
    input_tensor : numpy.ndarray
        Array of shape (batch, height, width, depth, c); only its shape and dtype are used.
    feature_map : numpy.ndarray
        Array of shape (batch, feat_height, feat_width, feat_depth, f) with f >= c.

    Returns
    -------
    float32 tensor with the shape of ``input_tensor``.

    Raises
    ------
    ValueError
        If ``feature_map`` has fewer channels than ``input_tensor``.
    """
    n_samples, in_h, in_w, in_d, n_channels = input_tensor.shape
    _, feat_h, feat_w, feat_d, n_features = feature_map.shape

    if n_features < n_channels:
        raise ValueError("Number of channels in feature map must be greater than or equal to c.")

    chosen = np.random.choice(n_features, n_channels, replace=False)
    picked = feature_map[:, :, :, :, chosen]

    # Per-axis zoom factors for one sample; the channel axis is left untouched.
    zoom_factors = (in_h / feat_h, in_w / feat_w, in_d / feat_d, 1)

    resized = np.zeros_like(input_tensor)
    for sample_index in range(n_samples):
        resized[sample_index] = ndimage.zoom(picked[sample_index], zoom_factors, order=2)

    return tf.convert_to_tensor(resized, dtype=tf.float32)

In [None]:
array = np.random.uniform(size=att_map.shape)

In [None]:
# Time a single ndimage.zoom call over the whole batch.
# NOTE(review): the zoom factors assume `array` has spatial size (10, 6, 5) — confirm against att_map.shape.
start_time = time.time()
resized_array_2 = ndimage.zoom(array, (1, 85/10, 50/6, 45/5, 1), order=2)
print(time.time()-start_time, ' sec')

In [None]:
# Time the per-sample variant on the real model input and attention map.
start_time = time.time()
resized_array_1 = random_select_and_resize(inputs['AbsDiff'].numpy(), att_map.numpy())
print(time.time()-start_time, ' sec')

In [None]:
# `resized_array` is never defined in this notebook; inspect the result of the previous cell.
resized_array_1.shape

In [None]:
resized_array_1

In [None]:
def normalize_minmax(tensor):
    """Min-max normalise `tensor` using its global minimum and maximum.

    Computes (tensor - min) / (max - min) with tf.math.divide_no_nan, so a
    zero denominator (all elements equal) yields zeros instead of NaN/inf.

    Args:
        tensor: a tensor accepted by tf.reduce_min / tf.reduce_max.

    Returns:
        A tensor of the same shape holding the normalised values.
    """
    lowest = tf.reduce_min(tensor)
    highest = tf.reduce_max(tensor)
    shifted = tf.subtract(tensor, lowest)
    value_range = tf.subtract(highest, lowest)
    return tf.math.divide_no_nan(shifted, value_range)

In [None]:
# Min-max normalise the resized feature tensor.
norm_tensor = normalize_minmax(resized_array_1)

In [None]:
# Display the normalised tensor.
norm_tensor

In [None]:
import models
import tensorflow as tf

In [None]:
# Enable on-demand GPU memory allocation ("memory growth") on every visible GPU
# and report how many physical/logical GPUs TensorFlow sees.
# NOTE(review): earlier cells of this notebook already execute TensorFlow ops
# (tf.convert_to_tensor, tf.random.uniform); if those initialise the GPUs, a
# top-to-bottom run will land in the RuntimeError branch below — consider
# moving this cell ahead of the first TensorFlow call.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        # Currently, memory growth needs to be the same across GPUs
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        # Memory growth must be set before GPUs have been initialized
        print(e)

In [None]:
# Instantiate the ProGAN wrapper from the project's `models` module. This
# rebinds `model`, which an earlier cell bound to the CNN-attention model.
# The keyword names (including `dicrim_shape`) are those of models.ProGAN;
# their exact meaning is defined there, not in this notebook.
model = models.ProGAN(input_shape=(189,216,189,1), latent_shape=(1024+128,), dicrim_shape=(84,48,42,1))

In [None]:
# Build the ProGAN sub-networks via its underscore-prefixed builder method.
# NOTE(review): the seven positional arguments are defined by
# models.ProGAN._make_model, which is not visible here — consider passing them
# by keyword once their names are confirmed.
model._make_model(16,16,16,16,0.3,0.3,1024)

In [None]:
# Random dummy tensors for exercising the ProGAN forward pass:
#   x  - matches the `input_shape` given to ProGAN, with a leading axis of 4
#   pe - 128 values per (sample, step); concatenated onto latent codes below
#   y  - five step-indexed tensors per sample, each matching `dicrim_shape`
#   m  - 4x5 0/1 matrix used below as a boolean mask over the stacked outputs
# NOTE(review): presumably axis 0 is the batch and axis 1 of pe/y/m is the
# visit index, with m marking which visits are valid — confirm.
x  = tf.random.uniform((4,189,216,189,1))
pe = tf.random.uniform((4,5,128))
y  = tf.random.uniform((4,5,84,48,42,1))
m  = tf.constant([[1,1,1,0,0],[1,1,1,1,0],[1,1,0,0,0],[1,0,0,0,0]], dtype=tf.float32)

In [None]:
# Forward pass through the ProGAN sub-networks over the whole step sequence
# (inference mode), followed by masking and discrimination.

# The per-step outputs below are concatenated along axis 0, i.e. in step-major
# order: [step0: sample0..3, step1: sample0..3, ...]. `m` is laid out as
# (sample, step), so it must be transposed before flattening; flattening it
# directly would pair each mask entry with the wrong (sample, step) output
# while still selecting the same number of rows. tf.boolean_mask expects a
# boolean mask, hence the cast to tf.bool.
flat_m = tf.cast(tf.reshape(tf.transpose(m), shape=[-1]), dtype=tf.bool)

# Encode the full-resolution input and build the latent code for step 0.
feat_e1 = model.E1(x, training=False)
mean_e1, logvar_e1 = model.encode(feat_e1)
z_e1 = model.reparameterize(mean_e1, logvar_e1)
latent_e1 = tf.concat((z_e1, pe[:,0,:]), axis=-1)

# Decoder output is not used further; the call only exercises model.De.
out_de = model.De(latent_e1, training=False)

n_steps = y.shape[1]
out_gens  = []
out_reals = []
for i in range(n_steps):
    if i == 0:
        # Step 0 is generated from the encoding of x.
        out_gen = model.G1(latent_e1, training=False)
    else:
        # Step i is generated from the encoding of the previous real step.
        feat_y = y[:,i-1,:,:,:,:]
        feat_e2 = model.E2(feat_y, training=False)
        mean_e2, logvar_e2 = model.encode(feat_e2)
        z_e2 = model.reparameterize(mean_e2, logvar_e2)
        latent_e2 = tf.concat((z_e2, pe[:,i,:]), axis=-1)
        out_gen = model.G1(latent_e2, training=False)
        out_reals.append(feat_y)
    out_gens.append(out_gen)
# The loop appended real steps 0..n_steps-2; add the last one so that
# out_reals[i] lines up with out_gens[i].
out_reals.append(y[:,-1,:,:,:,:])
out_reals = tf.concat(out_reals, axis=0)
out_gens  = tf.concat(out_gens, axis=0)

# Keep only the (sample, step) pairs flagged in the mask.
out_reals = tf.boolean_mask(out_reals, flat_m, axis=0)
out_gens  = tf.boolean_mask(out_gens, flat_m, axis=0)

real_logits = model.Di(out_reals)
fake_logits = model.Di(out_gens)

In [None]:
import os
import ast
import pandas as pd
import glob
import re

In [None]:
# Load the saliency index table. A later cell writes the modified frame back to
# this same path, so the on-disk file is overwritten in place.
df = pd.read_csv('./data/ADNI_saliency.csv', low_memory=False)

In [None]:
# Root folder searched for per-subject MRI files in the loop below.
# NOTE: this rebinds `source_dir`, which the top of the notebook set to the
# parent ADNI folder; it is a hard-coded absolute path specific to one machine.
source_dir = '/ngochuynh/f/Dataset/ADNI/ADNI_renamed'

In [None]:
# For every row of `df`: make the saliency paths relative (starting at the
# 'ADNI_saliency' component) and record, in 'MRI_IMG', the relative path of the
# MRI file whose visit month is closest to the row's 'M' value.

# Visit month as encoded in the MRI file names, e.g. "Month_012.3".
_MONTH_TAG = re.compile(r'Month_(\d+\.\d+)')


def _relative_from(path, anchor):
    """Return `path` cut so that it starts at the directory component `anchor`.

    Raises ValueError if `anchor` is not one of the '/'-separated components.
    """
    parts = path.split('/')
    return '/'.join(parts[parts.index(anchor):])


for i in range(len(df)):
    sal_paths = df.loc[i, 'SAL_PATHS']
    # Straight from the CSV the cell holds the string repr of a list; if this
    # loop already ran in the current kernel it holds the list itself, which
    # ast.literal_eval would reject.
    sp = ast.literal_eval(sal_paths) if isinstance(sal_paths, str) else list(sal_paths)
    df.at[i, 'SAL_PATHS'] = [_relative_from(p, 'ADNI_saliency') for p in sp]

    ptid = df.loc[i, 'PTID']
    src_visit = df.loc[i, 'M']
    mri_paths = glob.glob(os.path.join(source_dir, ptid, 'preprocessed/MRI/smooth_*.nii'))

    # Only files carrying a parsable month tag can be ranked by visit distance.
    month_by_path = {}
    for path in mri_paths:
        tag = _MONTH_TAG.search(path)
        if tag is not None:
            month_by_path[path] = float(tag.group(1))
    if not month_by_path:
        # Fail with the offending subject instead of min()'s "empty sequence" error.
        raise FileNotFoundError(
            f"No smooth_*.nii file with a Month_<x.y> tag found for PTID {ptid!r} under {source_dir!r}"
        )

    # Ties are resolved in glob order (dicts preserve insertion order).
    closest_path = min(month_by_path, key=lambda path: abs(month_by_path[path] - src_visit))
    df.at[i, 'MRI_IMG'] = _relative_from(closest_path, 'ADNI_renamed')

In [None]:
# Persist the rewritten table, overwriting the CSV it was loaded from
# (header row kept, index column omitted).
df.to_csv('./data/ADNI_saliency.csv', header=True, index=False)

In [None]:
import dataloader as DL

In [None]:
# Keyword arguments for DL.InputFunctionS2; their semantics are defined by the
# project's `dataloader` module.
# NOTE(review): presumably `list_rids` restricts loading to these subject ids
# and `source_dir` is prepended to the relative paths in the CSV — confirm.
ds_config = dict(
    source_dir="/ngochuynh/f/Dataset/ADNI",
    filepath="./data/ADNI_saliency.csv",
    list_rids=[4,5,6,8,14,15,16,19,21,23],
    input_name="AbsDiff",  
)

In [None]:
# Build the input function with the configuration above plus shuffle=True.
ds = DL.InputFunctionS2(**ds_config, shuffle=True)

In [None]:
# Pull only the first (feat, label) pair produced by calling `ds` and stop.
# Note this rebinds `feat`, which an earlier cell bound to a model output; if
# `ds()` yields nothing, `feat`/`label` keep whatever values they had before.
for feat, label in ds():
    break

In [None]:
# Display the 'mask_visits' entry of the first batch's label mapping.
label['mask_visits']