In [None]:
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import time, datetime

import nibabel as nib
from sklearn.model_selection import train_test_split
from scipy.stats import ttest_ind
from tabulate import tabulate

In [None]:
####################
#### file paths ####
####################
# All three values are placeholders; replace the "/<path ...>" parts before running.

## INPUT FILES
# participants.csv: the data table that is read into `df` below
data_table_path = "/<path to>/participants.csv"
# roster_CN.csv (from 0_identify_converters_cn_to_mci_or_ad);
# its PTID / CONVERSION columns are used to exclude CN subjects who converted
roster_cn_path = "/<path to>/roster_CN.csv"

## OUTPUT FILE PATH
# directory that receives the CSV tables and the folder with the HDF5 files
output_path = "/<path>"


In [None]:
# set random state seed
# r          : passed as random_state to train_test_split
# numpy_seed : seeds NumPy's global RNG, which np.random.choice uses in the balancing steps
# both values also appear in the names of the output files
r = 81
numpy_seed = 427
np.random.seed(numpy_seed)

In [None]:
# load the participants table; later cells treat each row as one image of a SUBJECT
df = pd.read_csv(data_table_path)

### clean the data table

In [None]:
# number of rows (images) and of distinct subjects for each of these groups,
# printed in the order MCI, EMCI, LMCI, SMC
groups_to_count = ['MCI', 'EMCI', 'LMCI', 'SMC']

print("images")
for group_name in groups_to_count:
    print(len(df.loc[df['GROUP'] == group_name]))

print("\nsubjects")
for group_name in groups_to_count:
    print(len(df.loc[df['GROUP'] == group_name, 'SUBJECT'].unique()))

In [None]:
# remove all rows whose GROUP is one of the MCI variants, SMC or Patient
groups_to_remove = ['MCI', 'EMCI', 'LMCI', 'SMC', 'Patient']
df = df[~df['GROUP'].isin(groups_to_remove)]

### remove CN subjects who converted to MCI or AD during the study

In [None]:
# roster of CN subjects; the PTID and CONVERSION columns are used below
roster_cn = pd.read_csv(roster_cn_path)
roster_cn.head(n=5)

In [None]:
# distinct CN and AD subjects before the converters are removed
print(len(df[df['GROUP']=='CN']['SUBJECT'].unique()))
print(len(df[df['GROUP']=='AD']['SUBJECT'].unique()))

In [None]:
# PTIDs of all roster entries flagged with CONVERSION == True
converters = list(roster_cn[roster_cn['CONVERSION']==True]['PTID'])
# drop every row of those subjects; a vectorised membership test replaces the
# row-wise apply, which scanned the whole list once per row
df = df[~df['SUBJECT'].isin(converters)]

In [None]:
# distinct CN and AD subjects after the converters were removed
print(len(df[df['GROUP']=='CN']['SUBJECT'].unique()))
print(len(df[df['GROUP']=='AD']['SUBJECT'].unique()))

### add column for stratification during data split

In [None]:
# round age down to 5-year steps (e.g. 73.4 -> 70.0)
age_bin = (df['AGE'] // 5) * 5
# stratification label = sex + age bin, e.g. 'F70.0'.
# `assign` returns a new frame instead of writing a column into `df`, which at
# this point is the result of boolean filtering (avoids SettingWithCopyWarning).
df = df.assign(STRAT=df['SEX'] + age_bin.astype('str'))

In [None]:
# inspect the new STRAT label next to the columns it was built from
df[['SUBJECT', 'GROUP', 'SEX', 'AGE', 'STRAT']]

### ensure balanced classes

In [None]:
def demographics(df):
    """Print subject-level age statistics, overall and split by sex.

    Every subject is counted once, using the row with the lowest AGE
    (the age at the first scan when a subject has several rows).

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns 'SUBJECT', 'SEX' and 'AGE'. Rows with
        SEX == 'F' / 'M' form the female / male statistics.

    Returns
    -------
    None
        The summary is printed. The spread is np.std with its default ddof=0.
    """
    # one row per subject: sort by age so that keep='first' retains the youngest row
    baseline = (
        df[['SUBJECT', 'SEX', 'AGE']]
        .sort_values(by='AGE')
        .drop_duplicates(subset='SUBJECT', keep='first')
    )
    ages_all = baseline['AGE']
    ages_female = baseline.loc[baseline['SEX'] == 'F', 'AGE']
    ages_male = baseline.loc[baseline['SEX'] == 'M', 'AGE']

    print("{:4} female subjects, age {:0.1f} +- {:0.1f} years".format(
        len(ages_female), np.mean(ages_female), np.std(ages_female)))
    print("{:4}   male subjects, age {:0.1f} +- {:0.1f} years".format(
        len(ages_male), np.mean(ages_male), np.std(ages_male)))
    print("                          ------------")
    print("        all subjects: age {:0.1f} +- {:0.1f} years".format(
        np.mean(ages_all), np.std(ages_all)))
    
    

def demographics_plot(df, dx="all"):
    """Print subject-level demographics and plot age histograms per sex.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns 'SUBJECT', 'SEX' and 'AGE'.
    dx : str, optional
        Label used in the printed header, e.g. 'AD' or 'CN'.

    Returns
    -------
    None
        Prints the summary via `demographics` and shows one figure with a
        female and a male histogram (5-year bins from 50 to 100).
    """
    print("##############################\ndemographics of {} subjects".format(dx))
    demographics(df)

    # one row per subject; note: use age at first scan if multiple scans per patient
    df_temp = df.sort_values(by='AGE').drop_duplicates(subset='SUBJECT', keep='first')

    # age histograms, sharing the y axis so the two panels are comparable
    fig, (ax1, ax2) = plt.subplots(1, 2, sharey='row', figsize=(12, 5))

    for ax, sex, title in ((ax1, 'F', 'female subjects'), (ax2, 'M', 'male subjects')):
        ax.hist(df_temp[df_temp['SEX'] == sex]['AGE'], bins = range(50, 105, 5))
        ax.set_title(title)
        ax.set_xlabel('age at baseline')
        ax.set_xlim(50, 100)
        ax.grid(which='both')

    plt.show()

In [None]:
# remove subjects with no given age (only rows with a positive AGE are kept)
has_age = df['AGE'] > 0
df = df[has_age]

# distinct subject IDs per group; CN keeps only IDs that do not also occur in AD
subjects_AD = df.loc[df['GROUP'] == 'AD', 'SUBJECT'].unique()
cn_candidates = df.loc[df['GROUP'] == 'CN', 'SUBJECT'].unique()
subjects_CN = [subj for subj in cn_candidates if subj not in subjects_AD]

print("AD subjects: ", len(subjects_AD))
print("CN subjects: ", len(subjects_CN))

In [None]:
# t-test of female vs. male baseline ages per diagnosis group, before balancing.
# Computed inline with ttest_ind: the helper ttest_subject_age is only defined in
# a later cell, so calling it here raises NameError on a fresh kernel.
for label, group_name in (("CN:  ", 'CN'), ("AD:  ", 'AD')):
    # one row per subject, keeping the row with the lowest AGE
    baseline = (
        df[df['GROUP'] == group_name]
        .sort_values(by='AGE')
        .drop_duplicates(subset='SUBJECT', keep='first')
    )
    print(label, ttest_ind(baseline[baseline['SEX'] == 'F']['AGE'],
                           baseline[baseline['SEX'] == 'M']['AGE']))

In [None]:
# demographics and age histograms per diagnosis group, before balancing
demographics_plot(df[df['GROUP'] == 'AD'], 'AD')
demographics_plot(df[df['GROUP'] == 'CN'], 'CN')

#### check: avg scans / subject

In [None]:
scans_male = df[df['SEX'] == 'M']
scans_female = df[df['SEX'] == 'F']

# rows per distinct subject, separately for each sex
for label, scans in (('male:', scans_male), ('female:', scans_female)):
    n_subjects = len(scans.drop_duplicates(subset='SUBJECT', keep='first'))
    print(label, len(scans)/n_subjects, 'scans/subject')

In [None]:
from collections import Counter

# subject -> number of rows (images) of that subject
images_per_subject = Counter(df['SUBJECT'].tolist())
# number of images -> number of subjects that have exactly that many
image_counts = Counter(images_per_subject.values())
# largest image count first
image_counts_sorted = sorted(image_counts.items(), key=lambda pair: pair[0], reverse=True)

for n_images, n_subjects in image_counts_sorted:
    print("{:2} images: {:3} subjects".format(n_images, n_subjects))


In [None]:
# for each (age, dx)-group, choose the same number of male and female subjects
# (i.e. drop excessive subjects from the table)
# example: 30 female subjects, 25 male subjects -> drop 5 female subjects
# Only the 5-year bins from 60 to 89 are kept; all other subjects are dropped.

# one row per subject (row with the lowest AGE); sort_values already returns a new frame
df_copy = df.sort_values(by='AGE').drop_duplicates(subset='SUBJECT', keep='first')
subjects_to_drop = []

# AD/CN subjects younger than 50 or aged 100+ fall into none of the bins below and
# would otherwise stay in the table unbalanced; drop them like the other excluded bins.
# This consumes no random numbers, so the random choices below are unaffected.
out_of_range = df_copy[df_copy['GROUP'].isin(['AD', 'CN'])
                       & ((df_copy['AGE'] < 50) | (df_copy['AGE'] >= 100))]
if len(out_of_range) > 0:
    print("outside 50 to 99 : dropping all")
    subjects_to_drop.extend(out_of_range['SUBJECT'])

for group in ['AD', 'CN']:
    df_group = df_copy[df_copy['GROUP'] == group]
    for age in range(50, 100, 5):
        # subjects of this group whose baseline age lies in [age, age + 5)
        df_temp = df_group[(df_group['AGE'] >= age) & (df_group['AGE'] < age + 5)]

        if age < 60 or age > 89:
            print(age, "to", age+4, ": dropping all")
            subjects_to_drop.extend(df_temp['SUBJECT'])
            continue

        df_temp_f = df_temp[df_temp['SEX'] == 'F']
        df_temp_m = df_temp[df_temp['SEX'] == 'M']
        n_female, n_male = len(df_temp_f), len(df_temp_m)

        print("-----------------------------")
        print(group, age, "to", age + 4)
        print(n_female, "female", n_male, "male")
        # more female than male
        if n_female > n_male:
            diff = n_female - n_male
            # random subset without replacement, drawn from NumPy's global RNG
            drop = np.random.choice(df_temp_f['SUBJECT'], diff, replace=False)
            print("dropping", diff, "female subjects:")
            print(drop)
            subjects_to_drop.extend(drop)

        # more male than female
        elif n_male > n_female:
            diff = n_male - n_female
            drop = np.random.choice(df_temp_m['SUBJECT'], diff, replace=False)
            print("dropping", diff, "male subjects:")
            print(drop)
            subjects_to_drop.extend(drop)

        # else (same number), no subjects are dropped

In [None]:
# remove every row of the subjects selected for dropping; a vectorised membership
# test replaces the row-wise apply, which scanned the list once per row
df = df[~df['SUBJECT'].isin(subjects_to_drop)]
demographics_plot(df[df['GROUP'] == 'AD'], 'AD')
demographics_plot(df[df['GROUP'] == 'CN'], 'CN')

In [None]:
def ttest_subject_age(df):
    """T-test of female vs. male ages with one age per subject.

    Each subject contributes the AGE of its youngest row. Returns the result
    of scipy.stats.ttest_ind(female_ages, male_ages).
    """
    baseline = df.sort_values(by='AGE').drop_duplicates(subset='SUBJECT', keep='first')
    is_female = baseline['SEX'] == 'F'
    is_male = baseline['SEX'] == 'M'
    return ttest_ind(baseline.loc[is_female, 'AGE'], baseline.loc[is_male, 'AGE'])

def ttest_image_age(df):
    """T-test of female vs. male ages over all rows (one age per image).

    Returns the result of scipy.stats.ttest_ind(female_ages, male_ages).
    """
    is_female = df['SEX'] == 'F'
    is_male = df['SEX'] == 'M'
    return ttest_ind(df.loc[is_female, 'AGE'], df.loc[is_male, 'AGE'])

In [None]:
# female vs. male baseline-age t-tests after the subject-level balancing
print("all: ", ttest_subject_age(df))
print("CN:  ", ttest_subject_age(df[df['GROUP'] == 'CN']))
print("AD:  ", ttest_subject_age(df[df['GROUP'] == 'AD']))

In [None]:
# save the balanced table; the file name records both seeds
df.to_csv(os.path.join(output_path, 'np{}_r{}_all.csv'.format(numpy_seed, r)))

## create train-val-test split (subject-wise)

In [None]:
def print_df_stats(df, df_train, df_val, df_test):
    """Print a table with image and patient counts for the full table and each split.

    For every frame the row lists: number of rows, rows with GROUP 'AD',
    rows with GROUP 'CN', then the same three counts for distinct SUBJECT values.
    """
    def count_row(frame):
        # the frame itself followed by its AD and CN subsets
        parts = [frame] + [frame[frame['GROUP'] == dx] for dx in ('AD', 'CN')]
        n_images = [len(part) for part in parts]
        n_patients = [len(part['SUBJECT'].unique()) for part in parts]
        return n_images + n_patients

    column_names = ['Images', '-> AD', '-> CN', 'Patients', '-> AD', '-> CN']
    splits = (('All', df), ('Train', df_train), ('Val', df_val), ('Test', df_test))
    table = [[label] + count_row(frame) for label, frame in splits]

    print(tabulate(table, headers=column_names))
    print()

In [None]:
# patient-wise train-test-split
# one row per subject (row with the lowest AGE), carrying the stratification label
subjects_AD = df[df['GROUP'] == 'AD'][['SUBJECT', 'AGE', 'STRAT']].sort_values(by='AGE').drop_duplicates(subset='SUBJECT', keep='first')
subjects_CN = df[df['GROUP'] == 'CN'][['SUBJECT', 'AGE', 'STRAT']].sort_values(by='AGE').drop_duplicates(subset='SUBJECT', keep='first')

# absolute split size: 15% of the subjects of each group, used for test AND val
test_size_ad = int(0.15*len(subjects_AD))
test_size_cn = int(0.15*len(subjects_CN))

# first split off the test subjects, then split val off the remaining train subjects
subjects_AD_train, subjects_AD_test = train_test_split(subjects_AD, test_size=test_size_ad, stratify=subjects_AD[['STRAT']], random_state=r)
subjects_AD_train, subjects_AD_val = train_test_split(subjects_AD_train, test_size=test_size_ad, stratify=subjects_AD_train[['STRAT']], random_state=r)
subjects_CN_train, subjects_CN_test = train_test_split(subjects_CN, test_size=test_size_cn, stratify=subjects_CN[['STRAT']], random_state=r)
subjects_CN_train, subjects_CN_val = train_test_split(subjects_CN_train, test_size=test_size_cn, stratify=subjects_CN_train[['STRAT']], random_state=r)

# Collect the subject IDs only. Concatenating the whole frames yields a 2-D array
# that also holds the AGE and STRAT values, so a membership test on it would
# compare subject IDs against those cells as well.
subjects_train = np.concatenate([subjects_AD_train['SUBJECT'], subjects_CN_train['SUBJECT']])
subjects_val = np.concatenate([subjects_AD_val['SUBJECT'], subjects_CN_val['SUBJECT']])
subjects_test = np.concatenate([subjects_AD_test['SUBJECT'], subjects_CN_test['SUBJECT']])

# A subject with rows in both groups is split once per group and can therefore
# land in two different sets; report this instead of leaking silently.
leaking_subjects = (
    (set(subjects_train) & set(subjects_val))
    | (set(subjects_train) & set(subjects_test))
    | (set(subjects_val) & set(subjects_test))
)
if leaking_subjects:
    print("WARNING: {} subject(s) occur in more than one split".format(len(leaking_subjects)))

# compile train, val and test dfs based on subjects
df_train = df[df['SUBJECT'].isin(subjects_train)]
df_val = df[df['SUBJECT'].isin(subjects_val)]
df_test = df[df['SUBJECT'].isin(subjects_test)]

# keep only baseline image for test set (train and val keep all images)
df_test = df_test.sort_values(by='AGE').drop_duplicates(subset='SUBJECT', keep='first')

print_df_stats(df, df_train, df_val, df_test)

In [None]:
# subject-level demographics of the full table and of each split
for header, frame in (
    ("all:   ", df),
    ("\n\ntrain: ", df_train),
    ("\n\nval:   ", df_val),
    ("\n\ntest:  ", df_test),
):
    print(header)
    demographics(frame)

In [None]:
# sex-specific views of the test set; they become separate holdout datasets further down
df_test_m = df_test[df_test['SEX'] == 'M']
df_test_f = df_test[df_test['SEX'] == 'F']

print("test images: {} male, {} female".format(len(df_test_m), len(df_test_f)))

In [None]:
def demographics_images(df):
    """Print image-level age statistics, overall and split by sex.

    Every row of `df` counts as one image (no de-duplication per subject).
    `df` must contain the columns 'SUBJECT', 'SEX' and 'AGE'. The spread is
    np.std with its default ddof=0. Returns None.
    """
    images = df[['SUBJECT', 'SEX', 'AGE']]
    ages_all = images['AGE']
    ages_female = images.loc[images['SEX'] == 'F', 'AGE']
    ages_male = images.loc[images['SEX'] == 'M', 'AGE']

    per_sex_lines = (
        ("{:4} female images, age {:0.1f} +- {:0.1f} years", ages_female),
        ("{:4}   male images, age {:0.1f} +- {:0.1f} years", ages_male),
    )
    for template, ages in per_sex_lines:
        print(template.format(len(ages), np.mean(ages), np.std(ages)))
    print("                          ------------")
    print("        all images: age {:0.1f} +- {:0.1f} years".format(np.mean(ages_all), np.std(ages_all)))

In [None]:
# image-level demographics of the full table and of each split
for header, frame in (
    ("all:   ", df),
    ("\n\ntrain: ", df_train),
    ("\n\nval:   ", df_val),
    ("\n\ntest:  ", df_test),
):
    print(header)
    demographics_images(frame)

In [None]:
# female vs. male age t-test over the test-set rows
ttest_image_age(df_test)

In [None]:
# as with subjects, drop images from train and val sets 
# to balance number of images from each sex
def balance_images(df):
    """Return `df` with equally many female and male images per (group, age bin).

    For the groups 'AD' and 'CN' and each 5-year age bin from 60 to 89, surplus
    images of the larger sex are dropped at random (np.random.choice, global
    NumPy RNG). Images of these groups with an age outside 60-89 are dropped
    entirely.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns 'GROUP', 'AGE', 'SEX' and 'IMAGEUID'.
        IMAGEUID is assumed to identify a row uniquely.

    Returns
    -------
    pandas.DataFrame
        The rows of `df` whose IMAGEUID was not selected for dropping.
        `df` itself is not modified.
    """
    images_to_drop = []

    # AD/CN images with an age below 50 or of 100+ fall into none of the bins
    # below and would otherwise be kept unbalanced; drop them like the other
    # excluded bins. This consumes no random numbers.
    out_of_range = df[df['GROUP'].isin(['AD', 'CN'])
                      & ((df['AGE'] < 50) | (df['AGE'] >= 100))]
    if len(out_of_range) > 0:
        print("outside 50 to 99 : dropping all")
        images_to_drop.extend(out_of_range['IMAGEUID'])

    for group in ['AD', 'CN']:
        df_group = df[df['GROUP'] == group]
        for age in range(50, 100, 5):
            # images of this group with an age in [age, age + 5)
            df_temp = df_group[(df_group['AGE'] >= age) & (df_group['AGE'] < age + 5)]

            if age < 60 or age > 89:
                print(age, "to", age+4, ": dropping all")
                images_to_drop.extend(df_temp['IMAGEUID'])
                continue

            df_temp_f = df_temp[df_temp['SEX'] == 'F']
            df_temp_m = df_temp[df_temp['SEX'] == 'M']
            n_female, n_male = len(df_temp_f), len(df_temp_m)

            print(group, age, "to", age + 4, end=";    ")
            print(n_female, "female", n_male, "male", end=";    ")
            # more female than male
            if n_female > n_male:
                diff = n_female - n_male
                drop = np.random.choice(df_temp_f['IMAGEUID'], diff, replace=False)
                print("dropping", diff, "female images")
                images_to_drop.extend(drop)

            # more male than female
            elif n_male > n_female:
                diff = n_male - n_female
                drop = np.random.choice(df_temp_m['IMAGEUID'], diff, replace=False)
                print("dropping", diff, "male images")
                images_to_drop.extend(drop)

            # else (same number), no images are dropped

    # vectorised membership test instead of a row-wise apply over the whole frame
    return df[~df['IMAGEUID'].isin(images_to_drop)]

In [None]:
# balance the number of female and male images in the training and validation
# sets, then print image-level demographics and the age t-test of each result
print("training set")
df_train = balance_images(df_train)
print("train:")
demographics_images(df_train)
print(ttest_image_age(df_train))
print("\n\n\nvalidation set")
df_val = balance_images(df_val)
print("val:")
demographics_images(df_val)
print(ttest_image_age(df_val))

In [None]:
# image / patient counts again, now with the image-balanced train and val sets
print_df_stats(df, df_train, df_val, df_test)

### save test dataset info

In [None]:
# save the test-set table; the file name records both seeds
df_test.to_csv(os.path.join(output_path, 'np{}_r{}_test.csv'.format(numpy_seed, r)))

## create datasets

In [None]:
# load images in matrix
def create_dataset(dataset):
    """Load the image of every row and build the matching label vector.

    Parameters
    ----------
    dataset : pandas.DataFrame
        Must contain the columns 'T1' (path of the image file, loaded with
        nibabel) and 'GROUP'.

    Returns
    -------
    tuple of numpy.ndarray
        (images, labels): the image arrays as float32, stacked in row order, and
        one label per row, 1 where GROUP == 'AD' and 0 otherwise. Both arrays are
        empty for an empty `dataset`.
    """
    data_matrix = []
    labels = []
    for _, row in dataset.iterrows():
        scan = nib.load(row["T1"])
        # np.asanyarray(img.dataobj) is nibabel's documented replacement for the
        # removed get_data(); the np.NAN placeholder that used to precede this
        # line was never read and np.NAN no longer exists in NumPy 2.
        struct_arr = np.asanyarray(scan.dataobj).astype(np.float32)
        data_matrix.append(struct_arr)
        labels.append(int(row['GROUP'] == 'AD'))
    return np.array(data_matrix), np.array(labels)

In [None]:
# load every split into memory and report the cumulative wall-clock time
print("Starting at " + time.ctime())
start = time.time()

print("Train dataset..")
train_dataset, train_labels = create_dataset(df_train)
print("Time elapsed: {}".format(datetime.timedelta(seconds=(time.time()-start))))

print("Validation dataset..")
val_dataset, val_labels = create_dataset(df_val)
print("Time elapsed: {}".format(datetime.timedelta(seconds=(time.time()-start))))

print("Holdout dataset..")
holdout_dataset, holdout_labels = create_dataset(df_test)
print("Time elapsed: {}".format(datetime.timedelta(seconds=(time.time()-start))))

print("Holdout dataset (male)..")
holdout_m_dataset, holdout_m_labels = create_dataset(df_test_m)
print("Time elapsed: {}".format(datetime.timedelta(seconds=(time.time()-start))))

print("Holdout dataset (female)..")
holdout_f_dataset, holdout_f_labels = create_dataset(df_test_f)

end = time.time()
print("Runtime: {}".format(datetime.timedelta(seconds=(end-start))))


In [None]:
# array shapes in the order train, val, holdout, holdout (male), holdout (female)
for image_array in (train_dataset, val_dataset, holdout_dataset,
                    holdout_m_dataset, holdout_f_dataset):
    print(image_array.shape)

In [None]:
import matplotlib.pyplot as plt

In [None]:
# visual sanity check: slice 115 along the last axis of the last training image
plt.figure(figsize=(12, 8))
plt.imshow(train_dataset[-1][:,:,115], cmap='gray')
plt.show()

In [None]:
# visual sanity check: slice 115 along the last axis of male holdout image 6
# NOTE(review): raises IndexError if there are fewer than 7 male holdout images
plt.figure(figsize=(12, 8))
plt.imshow(holdout_m_dataset[6][:,:,115], cmap='gray')
plt.show()

In [None]:
import h5py
# folder for the HDF5 files; its name records both seeds
fpath = os.path.join(output_path, 'np{}_r{}_bal'.format(numpy_seed, r))
# exist_ok=True so that re-running this cell does not fail once the folder exists
os.makedirs(fpath, exist_ok=True)

In [None]:
# training set: images as 'X', labels as 'y';
# the context manager closes the file even if writing a dataset raises
with h5py.File(os.path.join(fpath, 'ADNI_3T_AD_CN_train.h5'), 'w') as h5:
    h5.create_dataset('X', data=train_dataset, compression='gzip', compression_opts=9)
    h5.create_dataset('y', data=train_labels, compression='gzip', compression_opts=9)

In [None]:
# validation set: images as 'X', labels as 'y';
# the context manager closes the file even if writing a dataset raises
with h5py.File(os.path.join(fpath, 'ADNI_3T_AD_CN_val.h5'), 'w') as h5:
    h5.create_dataset('X', data=val_dataset, compression='gzip', compression_opts=9)
    h5.create_dataset('y', data=val_labels, compression='gzip', compression_opts=9)

In [None]:
# holdout (test) set: images as 'X', labels as 'y';
# the context manager closes the file even if writing a dataset raises
with h5py.File(os.path.join(fpath, 'ADNI_3T_AD_CN_holdout.h5'), 'w') as h5:
    h5.create_dataset('X', data=holdout_dataset, compression='gzip', compression_opts=9)
    h5.create_dataset('y', data=holdout_labels, compression='gzip', compression_opts=9)

In [None]:
# male part of the holdout set: images as 'X', labels as 'y';
# the context manager closes the file even if writing a dataset raises
with h5py.File(os.path.join(fpath, 'ADNI_3T_AD_CN_holdout_m.h5'), 'w') as h5:
    h5.create_dataset('X', data=holdout_m_dataset, compression='gzip', compression_opts=9)
    h5.create_dataset('y', data=holdout_m_labels, compression='gzip', compression_opts=9)

In [None]:
# female part of the holdout set: images as 'X', labels as 'y';
# the context manager closes the file even if writing a dataset raises
with h5py.File(os.path.join(fpath, 'ADNI_3T_AD_CN_holdout_f.h5'), 'w') as h5:
    h5.create_dataset('X', data=holdout_f_dataset, compression='gzip', compression_opts=9)
    h5.create_dataset('y', data=holdout_f_labels, compression='gzip', compression_opts=9)