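Notebook for splitting the data set into separate train, validation and test (holdout) sets. The split is patient-wise: all images of one subject end up in the same subset. We create 10 random splits, each sampled anew from the whole data set. The splits are driven by a fixed range of random states (20-29, i.e. `np.arange(20, 30)`), which can be re-used to reproduce exactly the same splits elsewhere.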

In [1]:
import os
import sys
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import time, datetime
import h5py

import nibabel as nib
from sklearn.model_selection import train_test_split
from scipy.ndimage.interpolation import zoom

sys.path.insert(0,"/analysis/fabiane/phd/nitorch/")
from nitorch.data import load_nifti
from tabulate import tabulate
from data_split import print_df_stats, create_dataset

In [None]:
# Main configuration for this notebook.
# - data_path / ADNI_DIR: input locations (presumably consumed by create_dataset).
# - *_h5: output filename *prefixes*; the split loop appends str(random_state)
#   and ".h5" to each of them.
# - binary_brain_mask: NIfTI mask loaded in the next cell.
# - data_table: CSV listing the scans (one row per image, with PTID/Visit/DX).
# - z_factor: forwarded to create_dataset as `z_factor` (presumably a zoom /
#   resampling factor — TODO confirm in data_split.create_dataset).
settings = dict(
    data_path="/ritter/share/projects/Methods/LRP/data/rieke-copy1/2Node_trial0/beta0/",
    ADNI_DIR="/ritter/share/data/ADNI/ADNI_2Yr/ADNI_2Yr_15T_quick_preprocessed/",
    train_h5="/ritter/share/data/ADNI_HDF5/Splits_Eitel/10xrandom_splits/train_AD_CN_2Yr15T_plus_UniqueScreening_quickprep_(96, 114, 96)_random_state",
    val_h5="/ritter/share/data/ADNI_HDF5/Splits_Eitel/10xrandom_splits/val_AD_CN_2Yr15T_plus_UniqueScreening_quickprep_(96, 114, 96)_random_state",
    holdout_h5="/ritter/share/data/ADNI_HDF5/Splits_Eitel/10xrandom_splits/holdout_AD_CN_2Yr15T_plus_UniqueScreening_quickprep_(96, 114, 96)_random_state",
    binary_brain_mask="binary_brain_mask.nii.gz",
    data_table="/analysis/fabiane/other_code/johannes/cnn-interpretability/data/ADNI/ADNI_tables/customized/DxByImgClean_CompleteAnnual2YearVisitList_1_5T.csv",
    z_factor=0.5,
)

In [2]:
# save the data sets to disk?
save = True

# Binary brain mask used to cut out the skull.
mask = load_nifti(settings["binary_brain_mask"])

# set random state seed
random_states = np.arange(20, 30) # original is 43

## Clean the data table

In [3]:
# Read the scan table (path configured in `settings`); it is cleaned in the
# next cell before any split is made.
df = pd.read_csv(filepath_or_buffer=settings["data_table"])

In [4]:
# Pre-processing occasionally fails; it failed for subject 067_S_0077 at the
# Screening visit, so that row is removed from the table.
failed_idx = df.index[
    ((df["PTID"] == "067_S_0077") & (df["Visit"] == "Screening")).values
].tolist()
df = df.drop(index=failed_idx)

# Drop every MCI row: the splits below only use Dementia (AD) and CN scans.
df = df.loc[df['DX'] != 'MCI']

## Build subsets and save to disk

In [6]:
# Patient-wise train/val/test split, repeated once per random state.
# For each class a fixed number of *subjects* is drawn for the test and the
# validation set; all images of a subject go into the same subset, so no
# patient appears in more than one subset. This is the split that is used in
# the paper to produce the heatmaps.
test_subjects_per_class = 30
val_subjects_per_class = 18

# The candidate subject lists do not depend on the random state, so build them
# once instead of on every iteration (the inputs to train_test_split, and hence
# the resulting splits, are unchanged).
subjects_AD = df[df['DX'] == 'Dementia']['PTID'].unique()
# Subjects that have both a CN and an AD scan should belong to the AD group.
# A set makes the membership test O(1) instead of scanning the array.
_subjects_AD_lookup = set(subjects_AD)
subjects_CN = [p for p in df[df['DX'] == 'CN']['PTID'].unique() if p not in _subjects_AD_lookup]


def _save_h5(path, X, y):
    """Write images `X` and labels `y` to the HDF5 file at `path`.

    Both arrays are stored gzip-compressed (level 9) under the keys 'X' and 'y'.
    The file is opened with mode 'w' (an existing file is overwritten) and is
    closed even if writing raises.
    """
    with h5py.File(path, 'w') as h5:
        h5.create_dataset('X', data=X, compression='gzip', compression_opts=9)
        h5.create_dataset('y', data=y, compression='gzip', compression_opts=9)


for i, r in enumerate(random_states):
    print(f"Iteration {i}")

    # Split each class separately: first carve off the test subjects, then the
    # validation subjects from the remaining ones.
    subjects_AD_train, subjects_AD_test = train_test_split(subjects_AD, test_size=test_subjects_per_class, random_state=r)
    subjects_AD_train, subjects_AD_val = train_test_split(subjects_AD_train, test_size=val_subjects_per_class, random_state=r)
    subjects_CN_train, subjects_CN_test = train_test_split(subjects_CN, test_size=test_subjects_per_class, random_state=r)
    subjects_CN_train, subjects_CN_val = train_test_split(subjects_CN_train, test_size=val_subjects_per_class, random_state=r)

    subjects_train = np.concatenate([subjects_AD_train, subjects_CN_train])
    subjects_val = np.concatenate([subjects_AD_val, subjects_CN_val])
    subjects_test = np.concatenate([subjects_AD_test, subjects_CN_test])

    # Guard against patient leakage between the subsets.
    assert set(subjects_train).isdisjoint(subjects_val), "train/val subjects overlap"
    assert set(subjects_train).isdisjoint(subjects_test), "train/test subjects overlap"
    assert set(subjects_val).isdisjoint(subjects_test), "val/test subjects overlap"

    # Compile the image-level dfs from the subject lists (vectorised lookup
    # instead of a row-wise apply; same rows selected).
    df_train = df[df['PTID'].isin(subjects_train)]
    df_val = df[df['PTID'].isin(subjects_val)]
    df_test = df[df['PTID'].isin(subjects_test)]

    print_df_stats(df, df_train, df_val, df_test)

    print("Starting at " + time.ctime())
    start = time.time()

    print("Train dataset..")
    train_dataset, train_labels = create_dataset(df_train, z_factor=settings["z_factor"], settings=settings, mask=mask)
    print("Time elapsed: " + str(datetime.timedelta(seconds=(time.time()-start))))
    print("Validation dataset..")
    val_dataset, val_labels = create_dataset(df_val, z_factor=settings["z_factor"], settings=settings, mask=mask)
    print("Time elapsed: " + str(datetime.timedelta(seconds=(time.time()-start))))
    print("Holdout dataset..")
    holdout_dataset, holdout_labels = create_dataset(df_test, z_factor=settings["z_factor"], settings=settings, mask=mask)

    end = time.time()
    print("Runtime: " + str(datetime.timedelta(seconds=(end-start))))

    print(train_dataset.shape)
    print(val_dataset.shape)
    print(holdout_dataset.shape)

    if save:
        print("Storing data sets")
        # File names: configured prefix + random state + ".h5".
        _save_h5(settings["train_h5"] + str(r) + ".h5", train_dataset, train_labels)
        _save_h5(settings["val_h5"] + str(r) + ".h5", val_dataset, val_labels)
        _save_h5(settings["holdout_h5"] + str(r) + ".h5", holdout_dataset, holdout_labels)

Iteration 0
         Images    -> AD    -> CN    Patients    -> AD    -> CN
-----  --------  -------  -------  ----------  -------  -------
All         969      475      494         344      193      151
Train       699      361      338         248      145      103
Val         107       49       58          36       18       18
Test        163       65       98          60       30       30

Starting at Thu Jan  7 14:24:54 2021
Train dataset..
Time elapsed: 0:24:09.617346
Validation dataset..
Time elapsed: 0:27:57.935494
Holdout dataset..
Runtime: 0:33:48.587732
(699, 96, 114, 96)
(107, 96, 114, 96)
(163, 96, 114, 96)
Storing data sets
Iteration 1
         Images    -> AD    -> CN    Patients    -> AD    -> CN
-----  --------  -------  -------  ----------  -------  -------
All         969      475      494         344      193      151
Train       698      358      340         248      145      103
Val          95       40       55          36       18       18
Test        176       