##### Import packages

In [None]:
import os

import pickle
import nibabel
import h5py
import numpy as np
import pandas as pd

import torch
from torch import nn
from innvestigator import InnvestigateModel
from nmm_mask_areas import all_areas

from collections import OrderedDict

In [None]:
####################
#### file paths ####
####################

## INPUT FILE PATHS
# This notebook assumes that both base paths given below contain one subfolder
# per data split, with the folder names given in the list 'splits':
#   - data_base_path/[split] should contain the h5 files (only holdout needed here)
#   - data_base_path should also contain [split]_test.csv (from 1_create_dataset_splits_stratified)
#   - models_base_path/[split] should contain the trained models (from 2_train_models_multiGPU)
data_base_path = '/path/to/data'
models_base_path = '/path/to/models'
# names of the split subfolders to process
splits = ['split_0', 'split_1', 'split_2']
# the NMM mask, which must be rescaled to the input image dimensions
nmm_mask_path_scaled = '/path/to/nmm_mask_rescaled.nii'

## OUTPUT
# base path for the LRP heatmaps
# the files will be created as '[lrp_path]/[split]/[repeat]/[subject id].nii'
lrp_path = '/path/to/lrp'



## Load model

In [None]:
class ClassificationModel3D(nn.Module):
    """3D CNN classifier: four conv stages followed by two dense layers.

    Every stage is Conv3d (kernel 3) -> BatchNorm3d -> ReLU -> MaxPool3d, with
    8, 16, 32 and 64 output channels and pooling sizes 2, 3, 2 and 3.  The
    flattened result must have 2304 features (the input width of `dense_1`);
    the network returns two raw scores per sample (no softmax is applied here).

    Args:
        dropout: dropout probability applied to the flattened conv features.
        dropout2: dropout probability applied after the first dense layer.
    """

    def __init__(self, dropout=0.4, dropout2=0.4):
        super().__init__()
        # The attribute names double as state_dict keys, so they must stay
        # exactly as they are for saved checkpoints to load.
        # stage 1
        self.Conv_1 = nn.Conv3d(1, 8, 3)
        self.Conv_1_bn = nn.BatchNorm3d(8)
        self.Conv_1_mp = nn.MaxPool3d(2)
        # stage 2
        self.Conv_2 = nn.Conv3d(8, 16, 3)
        self.Conv_2_bn = nn.BatchNorm3d(16)
        self.Conv_2_mp = nn.MaxPool3d(3)
        # stage 3
        self.Conv_3 = nn.Conv3d(16, 32, 3)
        self.Conv_3_bn = nn.BatchNorm3d(32)
        self.Conv_3_mp = nn.MaxPool3d(2)
        # stage 4
        self.Conv_4 = nn.Conv3d(32, 64, 3)
        self.Conv_4_bn = nn.BatchNorm3d(64)
        self.Conv_4_mp = nn.MaxPool3d(3)
        # classifier head
        self.dense_1 = nn.Linear(2304, 128)
        self.dense_2 = nn.Linear(128, 2)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout2)

    def forward(self, x):
        """Return the two class scores for a batch `x`."""
        stages = (
            (self.Conv_1, self.Conv_1_bn, self.Conv_1_mp),
            (self.Conv_2, self.Conv_2_bn, self.Conv_2_mp),
            (self.Conv_3, self.Conv_3_bn, self.Conv_3_mp),
            (self.Conv_4, self.Conv_4_bn, self.Conv_4_mp),
        )
        for conv, batch_norm, pool in stages:
            activated = self.relu(batch_norm(conv(x)))
            x = pool(activated)

        # collapse everything but the batch axis
        flat = x.view(x.size(0), -1)
        hidden = self.relu(self.dense_1(self.dropout(flat)))
        return self.dense_2(self.dropout2(hidden))
    

## Load ADNI Data

In [None]:
def min_max_normalization(subset):
    """Rescale every sample of `subset` to the range [0, 1], in place.

    Each entry along the first axis is shifted so that its minimum becomes 0
    and is then divided by its (shifted) maximum.

    Args:
        subset: indexable collection of float arrays (e.g. one ndarray whose
            first axis enumerates the samples). It is modified in place.

    Returns:
        The same `subset` object, normalised.
    """
    # Index-based access on purpose: the augmented assignments must write back
    # into `subset` itself.
    for sample_idx in range(len(subset)):
        subset[sample_idx] -= np.min(subset[sample_idx])
        peak = np.max(subset[sample_idx])
        # A constant sample is all zeros after the shift; dividing by its
        # maximum (0) would turn it into NaNs, so it is left at zero instead.
        if peak > 0:
            subset[sample_idx] /= peak
    return subset
    
def _read_h5_split(h5_path, dtype):
    """Read the 'X' and 'y' datasets of one HDF5 file.

    'X' is converted to `dtype`, gets a new axis inserted at position 1 and is
    min-max normalised per sample; 'y' is returned as a plain array.  The file
    is closed again before returning.
    """
    with h5py.File(h5_path, 'r') as h5_file:
        # np.array copies the data into memory, so closing the file is safe
        X = np.expand_dims(np.array(h5_file['X'], dtype=dtype), 1)
        y = np.array(h5_file['y'])
    return min_max_normalization(X), y


def load_data(skip_train=True, skip_val=True, skip_test=False, dtype=np.float32):
    """Load hdf5 files and extract columns.

    The file paths are taken from the module-level variables `train_h5`,
    `val_h5` and `holdout_h5`; only the ones belonging to non-skipped subsets
    have to be defined.

    Args:
        skip_train: if True, the training set is not loaded.
        skip_val: if True, the validation set is not loaded.
        skip_test: if True, the holdout (test) set is not loaded.
        dtype: dtype the image data is converted to.

    Returns:
        Tuple (X_train, y_train, X_val, y_val, X_holdout, y_holdout).
        Entries of skipped subsets are `np.nan`.
    """
    X_train, y_train, X_val, y_val, X_holdout, y_holdout = (np.nan,) * 6
    # train
    if not skip_train:
        X_train, y_train = _read_h5_split(train_h5, dtype)
        print("Total training set length: {}".format(len(y_train)))
        print("Number of healthy controls: {}".format(np.count_nonzero(y_train == 0.)))
        print("Number of AD patients: {}".format(np.count_nonzero(y_train == 1.)))
    # val
    if not skip_val:
        X_val, y_val = _read_h5_split(val_h5, dtype)
        print("Total validation set length: {}".format(len(y_val)))
    # test
    if not skip_test:
        X_holdout, y_holdout = _read_h5_split(holdout_h5, dtype)
        print("Total test set length: {}".format(len(y_holdout)))

    return X_train, y_train, X_val, y_val, X_holdout, y_holdout


def load_nifti(file_path, mask=None, z_factor=None, remove_nan=True):
    """Load a 3D array from a NIFTI file.

    Args:
        file_path: path of the NIFTI file to read.
        mask: optional array that the image is multiplied with (in place).
        z_factor: optional zoom factor; if given, the image is resampled with
            `scipy.ndimage.zoom` and rounded to whole numbers.
        remove_nan: if True, NaNs are replaced via `np.nan_to_num`.

    Returns:
        The image data as a numpy array.
    """
    img = nibabel.load(file_path)
    # `img.get_data()` is deprecated and no longer available in nibabel >= 5;
    # reading `dataobj` is the documented replacement with the same result.
    struct_arr = np.array(np.asanyarray(img.dataobj))

    if remove_nan:
        struct_arr = np.nan_to_num(struct_arr)
    if mask is not None:
        struct_arr *= mask
    if z_factor is not None:
        # imported here because the file never imports `zoom` at the top
        from scipy.ndimage import zoom
        struct_arr = np.around(zoom(struct_arr, z_factor), 0)

    return struct_arr


def save_nifti(file_path, struct_arr):
    """Write `struct_arr` to `file_path` as a NIFTI image with an identity affine."""
    identity_affine = np.eye(4)
    nibabel.save(nibabel.Nifti1Image(struct_arr, identity_affine), file_path)


## Load mask

In [None]:
nmm_mask = load_nifti(nmm_mask_path_scaled)

mri_shape = (182, 218, 182)

# all_areas maps an area name to a (min, max) tuple: the smallest and the
# largest index in the NMM mask that belong to that area.
# For every area build a binary volume that is 1 inside the area and 0 elsewhere.
area_masks = {}
for name, (min_idx, max_idx) in all_areas.items():
    inside_area = (nmm_mask >= min_idx) & (nmm_mask <= max_idx)
    area_mask = np.zeros(mri_shape)
    area_mask[inside_area] = 1
    area_masks[name] = area_mask

## Evaluate LRP on dataset

In [None]:
def run_LRP(net, image_tensor):
    """Run `net.innvestigate` on `image_tensor` with relevance requested for class 1.

    Returns whatever `net.innvestigate` returns, unchanged.
    """
    lrp_result = net.innvestigate(in_tensor=image_tensor, rel_for_class=1)
    return lrp_result

## Calculate heatmaps

In [None]:
# Size (sum of the binary mask) of every area. It depends only on the masks,
# so it is computed once instead of inside the per-sample loop.
area_sizes = {name: area_masks[name].sum() for name in all_areas}

# key prefix that nn.DataParallel adds to every entry of a saved state dict
dataparallel_prefix = "module."

for split in splits:
    print("#############################")
    print("##### split {}".format(split))
    print("#############################\n")

    #######################################
    ##### load test data and subject info
    #######################################
    # load_data() reads the path from the module-level name `holdout_h5`
    holdout_h5 = '{}/{}/ADNI_3T_AD_CN_holdout.h5'.format(data_base_path, split)
    _, _, _, _, X_holdout, y_holdout = load_data()

    # NOTE(review): `split[:-4]` drops the last four characters of the split
    # name (e.g. 'split_0' -> 'spl'), while the path description at the top of
    # the notebook speaks of '[split]_test.csv'. Confirm which name is intended.
    subjects_csv = '{}/{}_test.csv'.format(data_base_path, split[:-4])
    # assumes the CSV rows are in the same order as the samples in the h5 file
    # (rows are looked up by sample index below) -- TODO confirm
    subjects = pd.read_csv(subjects_csv)

    for repeat in range(5):
        print("starting split {}, repeat {}".format(split, repeat))
        ##################
        ##### load model
        ##################
        model_path = '{}/{}/trial_{}_BEST_ITERATION.h5'.format(models_base_path, split, repeat)

        device = 0
        net = ClassificationModel3D()
        net.cuda(device)

        state_dict = torch.load(model_path, map_location='cpu')
        # Remove the "module." prefix (due to nn.DataParallel), but only from
        # keys that actually carry it, so plain checkpoints load as well.
        new_state_dict = OrderedDict(
            (key[len(dataparallel_prefix):] if key.startswith(dataparallel_prefix) else key, value)
            for key, value in state_dict.items()
        )

        net.load_state_dict(new_state_dict)
        net.eval()
        # append a softmax so that the network output is a class probability
        net = torch.nn.Sequential(net, torch.nn.Softmax(dim=1))
        inn_model = InnvestigateModel(net, lrp_exponent=1,
                                      method="b-rule",
                                      beta=0, epsilon=1e-6).cuda(device)
        inn_model.eval()

        #######################
        ##### reset variables
        #######################
        # "AD"/"HC" group by true label, "TP"/"TN"/"FP"/"FN" by prediction outcome
        cases = ["AD", "HC", "TP", "TN", "FP", "FN"]
        mean_maps_LRP = {case: np.zeros(mri_shape) for case in cases}
        rs_per_area_LRP = {case: {k: [] for k in all_areas.keys()} for case in cases}
        counts = {case: 0 for case in cases}

        ##########################
        ##### calculate heatmaps
        ##########################
        heatmaps_path = '{}/{}/{}'.format(lrp_path, split, repeat)
        os.makedirs(heatmaps_path, exist_ok=True)
        num_samples = len(X_holdout)
        ad_score_list = []

        for i, (image, label) in enumerate(zip(X_holdout, y_holdout)):
            # image[None] adds the batch axis
            image_tensor = torch.Tensor(image[None]).cuda(device)
            AD_score, LRP_map = run_LRP(inn_model, image_tensor)
            AD_score = AD_score[0][1].detach().cpu().numpy()
            # .cpu() is a no-op for CPU tensors and required for CUDA tensors
            LRP_map = LRP_map.detach().cpu().numpy().squeeze()
            ad_score_list.append(AD_score)

            # save individual heatmap
            subject = subjects.at[i, 'SUBJECT']
            save_nifti(os.path.join(heatmaps_path, "{}.nii".format(subject)), LRP_map)

            # each sample counts twice: once for its true class and once for
            # its prediction outcome
            predicted_ad = bool(AD_score.round())
            true_case = "AD" if label else "HC"
            if predicted_ad:
                case = "TP" if label else "FP"
            else:
                case = "FN" if label else "TN"

            mean_maps_LRP[case] += LRP_map
            counts[case] += 1
            mean_maps_LRP[true_case] += LRP_map
            counts[true_case] += 1

            for name in all_areas:
                summed_LRP = (LRP_map * area_masks[name]).sum()

                # Keep index in test set for identification
                rs_per_area_LRP[case][name].append((i, summed_LRP))
                rs_per_area_LRP[true_case][name].append((i, summed_LRP))

            print("Completed {0:3.2f}%  \r".format(100*(i+1)/num_samples), end="")

        #####################
        ##### save heatmaps
        #####################
        print("now saving heatmaps, case counts:")
        print(counts)
        for case in cases:
            if counts[case]:
                mean_maps_LRP[case] /= counts[case]
            else:
                # Dividing by a zero count would turn the map into NaNs; the
                # accumulated map is still all zeros and is saved as such.
                print("no samples for case {}, saving an all-zero mean map".format(case))
            save_nifti(os.path.join(heatmaps_path, "LRP_{case}.nii".format(case=case)),
                       mean_maps_LRP[case])
            with open(os.path.join(heatmaps_path, "LRP_area_evdcs_{case}.pkl".format(case=case)), 'wb') as file:
                pickle.dump(rs_per_area_LRP[case], file)

        with open(os.path.join(heatmaps_path, "area_sizes.pkl"), 'wb') as file:
            pickle.dump(area_sizes, file)

        np.savetxt(os.path.join(heatmaps_path, "ad_scores.txt"), ad_score_list)

        print("done with split {}, repeat {}\n\n".format(split, repeat))
        