##### Import packages

In [None]:
# Render matplotlib figures inline, below the cell that creates them.
%matplotlib inline

In [None]:
import os

import h5py
import numpy as np
import pandas as pd
import torch 
import pickle
import nibabel

import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from copy import deepcopy

from collections import OrderedDict, defaultdict

In [None]:
# Reset to matplotlib's stock style first, then layer the overrides on top.
plt.style.use('default')

# Serif fonts, no external LaTeX for text rendering; the pgf.* keys are read by
# matplotlib's pgf backend (LaTeX engine and whether rc font settings are applied).
plt.rcParams['pgf.texsystem'] = 'pdflatex'
plt.rcParams['font.family'] = 'serif'
plt.rcParams['text.usetex'] = False
plt.rcParams['pgf.rcfonts'] = True

In [None]:
# ---------------------------------------------------------------------------
# File paths — edit these to point at your local copies.
# ---------------------------------------------------------------------------

# Input roots, laid out as in the previous notebooks of this analysis.
# data_base_path: directory holding the '<split>_test.csv' subject tables.
data_base_path = '/path/to/data'
# models_base_path: per-split model output directories (contain raw_pred.pkl).
models_base_path = '/path/to/models'
# lrp_base_path: LRP relevance maps (.nii), per split/trial and group averages.
lrp_base_path = '/path/to/lrp'

# MNI152 1 mm brain template; background image for the group-average figure.
mni_brain_path = '/path/to/MNI152_T1_1mm_brain.nii'

## Plotting individual heatmaps

In [None]:
# Voxel grid size along (x, y, z); supplies the full-extent default slices of
# plot_idv_brain. Presumably the MNI152 1 mm grid (cf. mni_brain_path) — TODO confirm.
mri_shape = (182, 218, 182)

In [None]:
def load_nifti(file_path, mask=None, z_factor=None, remove_nan=True):
    """Load a 3D array from a NIFTI file.

    Parameters
    ----------
    file_path : str
        Path to the ``.nii`` / ``.nii.gz`` file.
    mask : array-like, optional
        Multiplied element-wise into the loaded volume (e.g. a 0/1 brain mask).
    z_factor : float or sequence of float, optional
        Zoom factor for ``scipy.ndimage.zoom``; resampled values are rounded
        to whole numbers.
    remove_nan : bool, default True
        Replace NaN / +-inf with finite numbers via ``np.nan_to_num``.

    Returns
    -------
    numpy.ndarray
        A fresh in-memory copy of the voxel data (safe to modify).
    """
    img = nibabel.load(file_path)
    # ``get_data()`` was deprecated in nibabel 3.0 and removed in 5.0.
    # ``np.asanyarray(img.dataobj)`` is its documented replacement and, unlike
    # ``get_fdata()``, keeps the on-disk dtype. ``np.array`` makes a real copy.
    struct_arr = np.array(np.asanyarray(img.dataobj))

    if remove_nan:
        struct_arr = np.nan_to_num(struct_arr)
    if mask is not None:
        # Out-of-place so an integer volume times a float mask upcasts
        # instead of raising a casting error.
        struct_arr = struct_arr * mask
    if z_factor is not None:
        # Imported lazily: only needed for resampling (was never imported before).
        from scipy.ndimage import zoom
        struct_arr = np.around(zoom(struct_arr, z_factor), 0)

    return struct_arr


def save_nifti(file_path, struct_arr, affine=None):
    """Save a 3D array to a NIFTI file.

    Parameters
    ----------
    file_path : str
        Destination path (``.nii`` or ``.nii.gz``).
    struct_arr : numpy.ndarray
        Voxel data to write.
    affine : numpy.ndarray of shape (4, 4), optional
        Voxel-to-world transform stored in the header. Defaults to the
        identity matrix; pass e.g. ``nibabel.load(src).affine`` to keep the
        orientation and voxel size of a source image.
    """
    if affine is None:
        affine = np.eye(4)
    img = nibabel.Nifti1Image(struct_arr, affine)
    nibabel.save(img, file_path)

In [None]:
def plot_idv_brain(heat_map, brain_img, ref_scale, fig=None, ax=None, contour_areas=None,
                   x_idx=slice(0, mri_shape[0]), y_idx=slice(0, mri_shape[1]), z_idx=slice(0, mri_shape[2]),
                   vmin=90, vmax=99.5, set_nan=True, cmap=None, c=None, mask=None):
    """Draw one 2D slice of a relevance heat map over an optional anatomical image.

    Parameters
    ----------
    heat_map : numpy.ndarray
        3D relevance volume to overlay.
    brain_img : numpy.ndarray or None
        3D background volume, sliced with the same indices and drawn in grey.
        Skipped when None.
    ref_scale : array-like
        Values whose percentiles set the colour limits, so that several
        panels can share one colour scale.
    fig, ax : matplotlib Figure / Axes, optional
        Where to draw. If ``ax`` is None a new 12x12 inch figure is created.
        If only ``fig`` is None it is taken from ``ax.figure``.
    contour_areas, c
        Unused; kept so existing calls keep working.
    x_idx, y_idx, z_idx : int or slice
        Index along each volume axis. One of them should be an int so that
        the result is 2D. The slice is transposed, so the first remaining
        axis runs horizontally and (``origin='lower'``) the second runs upward.
    vmin, vmax : float
        Lower/upper colour limits given as *percentiles* (0-100) of ``ref_scale``.
    set_nan : bool, default True
        If True, voxels where the mask is 0 become NaN and are left undrawn.
    cmap : matplotlib colormap, optional
        Defaults to a map running from transparent through dark red to yellow.
    mask : numpy.ndarray, optional
        Mask (same shape as ``heat_map``) used when ``set_nan`` is True. If
        omitted, a global ``nmm_mask`` is used when one exists (legacy
        behaviour); otherwise ValueError is raised.

    Returns
    -------
    fig, ax, im
        The figure, the axes and the heat map's ``AxesImage`` (e.g. for a colorbar).
    """
    if ax is None:
        fig, ax = plt.subplots(1, figsize=(12, 12))
    elif fig is None:
        fig = ax.figure

    # Slice first, then copy as float. Only the drawn plane is copied, the
    # caller's array is never modified, and NaN can be stored even for integer input.
    heat_slice = np.array(heat_map[x_idx, y_idx, z_idx], dtype=float)
    if set_nan:
        if mask is None:
            try:
                mask = nmm_mask  # legacy fallback: global mask from an earlier notebook
            except NameError:
                raise ValueError(
                    "set_nan=True needs a `mask` argument (or a global `nmm_mask`)"
                ) from None
        heat_slice[np.asarray(mask)[x_idx, y_idx, z_idx] == 0] = np.nan

    if cmap is None:
        # First colour is fully transparent so the background shows through.
        cmap = mcolors.LinearSegmentedColormap.from_list(
            name='alphared',
            colors=[(1, 0, 0, 0), "darkred", "red", "darkorange", "orange", "yellow"],
            N=5000)

    if brain_img is not None:
        ax.imshow(brain_img[x_idx, y_idx, z_idx].T, cmap="Greys", origin='lower')

    # Percentile limits -> data values of the reference map.
    lo, hi = np.percentile(ref_scale, [vmin, vmax])
    im = ax.imshow(heat_slice.T, cmap=cmap,
                   vmin=lo, vmax=hi, interpolation="gaussian", origin='lower')
    ax.axis('off')

    return fig, ax, im

In [None]:
# MNI template brain; anatomical background of the group-average figure below.
mni_brain = load_nifti(mni_brain_path)

## Select patients

In [None]:
# Data split and trial whose test subjects are used for the individual-heatmap comparisons.
split, trial = '02_np914_r47_bal', 0

In [None]:
# Raw model outputs for this split, one entry per trial.
# NOTE: pickle can execute arbitrary code while loading; only use trusted files.
with open('{}/{}/raw_pred.pkl'.format(models_base_path, split), 'rb') as pred_file:
    raw_pred = pickle.load(pred_file)
raw_pred = raw_pred[trial]

# `split[:-4]` drops the last four characters of the split name (the '_bal'
# suffix of the split chosen above) to find the matching test table.
subjects = pd.read_csv('{}/{}_test.csv'.format(data_base_path, split[:-4]))

In [None]:
# Attach each subject's raw prediction (assumes raw_pred follows the CSV row
# order — TODO confirm), keep only the columns used below, and sort by ID.
subjects = subjects.assign(RAWPRED=raw_pred)[['SUBJECT', 'GROUP', 'SEX', 'AGE', 'T1', 'RAWPRED']]
subjects_sorted = subjects.sort_values('SUBJECT')

In [None]:
# Show the whole sorted table without truncation, so subjects can be picked by hand.
with pd.option_context('display.max_columns', None,
                       'display.max_rows', None):
    display(subjects_sorted)

## Female vs. male subjects from the same split

In [None]:
# Subject IDs picked from the table above: a younger and an older subject per sex.
subject_young_male, subject_old_male = '009_S_5037', '014_S_4615'
subject_young_female, subject_old_female = '082_S_6690', '027_S_0404'

In [None]:
subjects_plots = [subject_young_male, subject_old_male, subject_young_female, subject_old_female]

# Vectorised membership test instead of a row-wise apply.
df_filtered = subjects[subjects['SUBJECT'].isin(subjects_plots)]
display(df_filtered)

lrp_path = '{}/{}/{}'.format(lrp_base_path, split, trial)
# Average AD-patient map of this split/trial; its percentiles define the
# shared colour scale of every panel below.
mapAD_avg = load_nifti(os.path.join(lrp_path, 'LRP_AD.nii'))


def _t1_path(subject_id):
    """Return the T1 path of the first row in `subjects` whose SUBJECT equals `subject_id`."""
    return subjects.loc[subjects['SUBJECT'] == subject_id, 'T1'].iat[0]


patient_ym_heatmap = load_nifti(os.path.join(lrp_path, subject_young_male + '.nii'))
patient_ym_brain = load_nifti(_t1_path(subject_young_male))

patient_om_heatmap = load_nifti(os.path.join(lrp_path, subject_old_male + '.nii'))
patient_om_brain = load_nifti(_t1_path(subject_old_male))

patient_yf_heatmap = load_nifti(os.path.join(lrp_path, subject_young_female + '.nii'))
patient_yf_brain = load_nifti(_t1_path(subject_young_female))

patient_of_heatmap = load_nifti(os.path.join(lrp_path, subject_old_female + '.nii'))
patient_of_brain = load_nifti(_t1_path(subject_old_female))

# ---------------------------------------------------------------------------
# 2 x 4 grid: rows = younger / older; columns 0-1 female, 2-3 male.
# ---------------------------------------------------------------------------
fig, axes = plt.subplots(2, 4, figsize=(24, 10), sharey=False, sharex=False)
vmin, vmax = 90, 99.5  # colour limits as percentiles of mapAD_avg

# Two views per subject: a slice at fixed y (left panel), then one at fixed x (right panel).
views = [
    dict(x_idx=slice(0, 166), y_idx=120, z_idx=slice(16, 154)),
    dict(x_idx=85, y_idx=slice(16, 206), z_idx=slice(0, 160)),
]
# (heat map, T1 background, grid row, first grid column)
panels = [
    (patient_yf_heatmap, patient_yf_brain, 0, 0),  # young female
    (patient_of_heatmap, patient_of_brain, 1, 0),  # old female
    (patient_ym_heatmap, patient_ym_brain, 0, 2),  # young male
    (patient_om_heatmap, patient_om_brain, 1, 2),  # old male
]
for heat_map, brain, row, col in panels:
    for offset, view in enumerate(views):
        fig, ax, im = plot_idv_brain(heat_map, brain, ref_scale=mapAD_avg,
                                     set_nan=False, vmin=vmin, vmax=vmax,
                                     fig=fig, ax=axes[row, col + offset], **view)

# labels
axes[0,0].text(-15, 45, "younger", rotation="vertical", fontsize=28)
axes[1,0].text(-15, 55, "older", rotation="vertical", fontsize=28)

axes[0,0].text(145, 145, "female", fontsize=28)
axes[0,2].text(150, 145, "male", fontsize=28)

# colorbar
fig.tight_layout()

fig.subplots_adjust(top=0.95, right=0.87, hspace=0.02, wspace=0.02)
cbar_ax = fig.add_axes([0.88, 0.1, 0.02, 0.8])
# Every panel uses the same colour limits, so the last image stands in for all.
# (`shrink` is ignored with an explicit `cax`, and ticks are set just below.)
cbar = fig.colorbar(im, cax=cbar_ax)

# Ticks sit at the data values of the two percentiles but are labelled as percentiles.
vmin_val, vmax_val = np.percentile(mapAD_avg, [vmin, vmax])
cbar.set_ticks([vmin_val, vmax_val])
cbar.ax.set_yticklabels(['{0:.1f}%'.format(vmin), '{0:.1f}%'.format(vmax)],
                        fontsize=20)
cbar.set_label('Percentile of average AD patient values', rotation=270, fontsize=22, labelpad=-30)

fig.savefig('heatmaps_single_subjects.pdf', bbox_inches="tight")

## Group-average heatmaps: female/male, AD/HC

In [None]:
# Group-average LRP maps per diagnosis (AD / HC) and sex (m / f).
ad_m_heatmap = load_nifti('{}/LRP_AD_m.nii'.format(lrp_base_path))
ad_f_heatmap = load_nifti('{}/LRP_AD_f.nii'.format(lrp_base_path))
hc_m_heatmap = load_nifti('{}/LRP_HC_m.nii'.format(lrp_base_path))
hc_f_heatmap = load_nifti('{}/LRP_HC_f.nii'.format(lrp_base_path))

# ---------------------------------------------------------------------------
# 2 x 4 grid on the MNI template: rows = female / male; columns 0-1 AD, 2-3 HC.
# ---------------------------------------------------------------------------
fig, axes = plt.subplots(2, 4, figsize=(24, 10), sharey=False, sharex=False)
vmin, vmax = 90, 99.5  # colour limits as percentiles of ad_f_heatmap

# Two views per group: a slice at fixed y (left panel), then one at fixed x (right panel).
views = [
    dict(x_idx=slice(0, 166), y_idx=120, z_idx=slice(16, 154)),
    dict(x_idx=85, y_idx=slice(16, 206), z_idx=slice(0, 160)),
]
# (heat map, grid row, first grid column)
panels = [
    (ad_f_heatmap, 0, 0),  # AD female
    (ad_m_heatmap, 1, 0),  # AD male
    (hc_f_heatmap, 0, 2),  # HC female
    (hc_m_heatmap, 1, 2),  # HC male
]
# NOTE(review): all four groups share a colour scale anchored to the *female*
# AD average (ad_f_heatmap); the colorbar label below only says "average AD
# patient" — confirm the wording is intended.
for heat_map, row, col in panels:
    for offset, view in enumerate(views):
        fig, ax, im = plot_idv_brain(heat_map, mni_brain, ref_scale=ad_f_heatmap,
                                     set_nan=False, vmin=vmin, vmax=vmax,
                                     fig=fig, ax=axes[row, col + offset], **view)

# labels
axes[0,0].text(-15, 45, "female", rotation="vertical", fontsize=36)
axes[1,0].text(-15, 55, "male", rotation="vertical", fontsize=36)

axes[0,0].text(155, 145, "AD", fontsize=32)
axes[0,2].text(157, 145, "HC", fontsize=32)

# colorbar
fig.tight_layout()

fig.subplots_adjust(top=0.95, right=0.87, hspace=0.02, wspace=0.02)
cbar_ax = fig.add_axes([0.88, 0.1, 0.02, 0.8])
# Every panel uses the same colour limits, so the last image stands in for all.
# (`shrink` is ignored with an explicit `cax`, and ticks are set just below.)
cbar = fig.colorbar(im, cax=cbar_ax)

# Ticks sit at the data values of the two percentiles but are labelled as percentiles.
vmin_val, vmax_val = np.percentile(ad_f_heatmap, [vmin, vmax])
cbar.set_ticks([vmin_val, vmax_val])
cbar.ax.set_yticklabels(['{0:.1f}%'.format(vmin), '{0:.1f}%'.format(vmax)],
                        fontsize=20)
cbar.set_label('Percentile of average AD patient values', rotation=270, fontsize=22, labelpad=-30)

fig.savefig('heatmaps_average.pdf', bbox_inches="tight")