##### Import packages

In [None]:
import os
import torch 
import pickle
import nibabel

import h5py
import numpy as np
import pandas as pd
    
from collections import OrderedDict, defaultdict

In [None]:
####################
#### file paths ####
####################

## INPUT FILE PATHS
# data_base_path: the same data root used by the previous notebooks.
# lrp_base_path: the output directory written by 4_calculate_LRP_heatmaps.
data_base_path = '/path/to/data'
lrp_base_path = '/path/to/lrp'

# Split names 'split_0', 'split_1', 'split_2' (sub-directories of lrp_base_path).
splits = ['split_{}'.format(idx) for idx in range(3)]

In [None]:
def load_nifti(file_path, mask=None, z_factor=None, remove_nan=True):
    """Load a 3D array from a NIFTI file.

    Args:
        file_path: Path of the NIFTI file to read.
        mask: Optional array; if given, the image is multiplied element-wise
            by it.
        z_factor: Optional zoom factor(s) for ``scipy.ndimage.zoom``; the
            resampled array is rounded to whole numbers.
        remove_nan: If True, pass the data through ``np.nan_to_num`` (NaN
            becomes 0, infinities become large finite values).

    Returns:
        The voxel data as a plain NumPy array (a copy, safe to modify).
    """
    img = nibabel.load(file_path)
    # `get_data()` was deprecated and later removed from nibabel; reading
    # `dataobj` with asanyarray is the documented equivalent. `np.array`
    # makes a plain-ndarray copy, as the original code did.
    struct_arr = np.array(np.asanyarray(img.dataobj))

    if remove_nan:
        struct_arr = np.nan_to_num(struct_arr)
    if mask is not None:
        # Out-of-place so an integer image times a float mask upcasts instead
        # of raising a numpy casting error.
        struct_arr = struct_arr * mask
    if z_factor is not None:
        # `zoom` is not imported at file level; import it only when needed.
        from scipy.ndimage import zoom
        struct_arr = np.around(zoom(struct_arr, z_factor), 0)

    return struct_arr


def save_nifti(file_path, struct_arr, affine=None):
    """Save a 3D array to a NIFTI file.

    Args:
        file_path: Destination path passed to ``nibabel.save``.
        struct_arr: Array of voxel values to write.
        affine: Optional 4x4 voxel-to-world matrix. Defaults to the identity
            (the previous, fixed behaviour). Pass the source image's affine
            to keep its spatial orientation in the saved file.
    """
    if affine is None:
        affine = np.eye(4)
    img = nibabel.Nifti1Image(struct_arr, affine)
    nibabel.save(img, file_path)


### Average the LRP heatmaps over all runs of all splits (per group and sex)

In [None]:
mri_shape = (182, 218, 182)
n_runs = 5

# One float accumulator per case, keyed by row['GROUP'] + row['SEX'].
# The four named arrays are the same objects as the dict values, so in-place
# updates through `mean_maps` are visible under the original names too.
mean_maps = {case: np.zeros(mri_shape) for case in ('ADF', 'ADM', 'CNF', 'CNM')}
mean_map_AD_f, mean_map_AD_m = mean_maps['ADF'], mean_maps['ADM']
mean_map_HC_f, mean_map_HC_m = mean_maps['CNF'], mean_maps['CNM']
# Number of heatmaps seen per case: each subject contributes one per run,
# which matches the number of maps summed into its accumulator.
counts = defaultdict(int)

for split in splits:
    path = os.path.join(lrp_base_path, split)

    # NOTE(review): this used to read '{}_test.csv'.format(split[:-4]). With
    # the split names defined above ('split_0', ...) that truncates to 'spl',
    # so every split loaded the same (wrong) CSV. Use the full split name.
    subjects_csv = '{}/{}_test.csv'.format(data_base_path, split)
    subjects = pd.read_csv(subjects_csv)
    n_subjects = len(subjects)

    for run in range(n_runs):
        print('calculating mean maps, processing split {}, run {}'.format(split, run))
        for i, (_, row) in enumerate(subjects.iterrows(), start=1):
            print("subject {}/{}        ".format(i, n_subjects), end="\r")
            case = row['GROUP'] + row['SEX']

            # Cases outside the four accumulators are counted but not loaded,
            # as before.
            counts[case] += 1
            if case in mean_maps:
                heatmap_path = os.path.join(path, str(run), '{}.nii'.format(row['SUBJECT']))
                mean_maps[case] += load_nifti(heatmap_path)

print("\n", counts)

output_names = {'ADF': 'AD_f', 'ADM': 'AD_m', 'CNF': 'HC_f', 'CNM': 'HC_m'}
for case, name in output_names.items():
    if counts[case] == 0:
        # Dividing by zero would silently write an all-NaN map.
        print('no heatmaps found for case {}; skipping LRP_{}.nii'.format(case, name))
        continue
    mean_maps[case] /= counts[case]
    save_nifti('{}/LRP_{}.nii'.format(lrp_base_path, name), mean_maps[case])