In [1]:
import os
import sys
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import nibabel as nib
from sklearn.model_selection import train_test_split
from scipy.ndimage.interpolation import zoom
from tabulate import tabulate

sys.path.insert(0,"/analysis/ritter/AD/Budding_Spectral_Analysis/code/nitorch/")
from nitorch.data import load_nifti

In [2]:
# Root directory of the skull-stripped ADNI data; create_dataset joins each
# row's "path" value onto this directory to locate the scan file.
working_dir = "/analysis/ritter/data/ADNI/ADNI_BIDS_skull_stripped"
# This project does not use downsized images: z_factor is defined here but is
# not passed to create_dataset by the calls in this part of the notebook.
z_factor = 0.525

In [3]:
# Subject tables read from CSV. data_MCI comes from ADNI_mci_total.csv and
# data_MCI_conv from ADNI_mci_conv.csv (presumably the converter subset of
# the MCI subjects -- TODO confirm against how the CSVs were produced).
data_MCI = pd.read_csv('/analysis/ritter/projects/AD/Budding_Spectral_Analysis/data/ADNI_mci_total.csv', delimiter = ',')
data_MCI_conv = pd.read_csv('/analysis/ritter/projects/AD/Budding_Spectral_Analysis/data/ADNI_mci_conv.csv', delimiter = ',')

In [4]:
# Preview the first rows of the table loaded from ADNI_mci_conv.csv.
data_MCI_conv.head()

In [5]:
# Preview the first rows of the full MCI table; the "path" and "Conversion"
# columns shown here are the ones create_dataset reads.
data_MCI.head()

Unnamed: 0,SubjectID,Phase,Sex,Weight,Research Group,Visit,Archive Date,Study Date,Age,MMSE Total Score,...,FAQ Total Score,NPI-Q Total Score,Preprocessing,Image ID,session,path,TOTAL11,TOTALMOD,DXCONV,Conversion
0,068_S_0401,ADNI 1,M,112.9,MCI,ADNI Screening,4/02/2009,4/25/2006,64.0,24.0,...,,,N3 + Scaled,140333,screen,068_S_0401/screen/anat/sub-068_S_0401_ses-scre...,7.67,11.67,0.0,0
1,123_S_1300,ADNI 1,F,48.1,MCI,ADNI1/GO Month 12,4/10/2008,3/26/2008,74.6,27.0,...,0.0,2.0,N3 + Scaled,102132,month12,123_S_1300/month12/anat/sub-123_S_1300_ses-mon...,5.67,6.67,0.0,0
2,100_S_0296,ADNI 1,M,70.4,MCI,ADNI Screening,12/13/2006,4/03/2006,79.3,26.0,...,,,N3 + Scaled,33121,screen,100_S_0296/screen/anat/sub-100_S_0296_ses-scre...,15.0,24.0,0.0,0
3,018_S_0142,ADNI 1,M,72.6,MCI,ADNI1/GO Month 24,3/20/2008,3/03/2008,81.4,30.0,...,0.0,0.0,N3 + Scaled,99058,month24,018_S_0142/month24/anat/sub-018_S_0142_ses-mon...,7.33,13.33,0.0,0
4,068_S_0478,ADNI 1,M,80.6,MCI,ADNI Screening,12/12/2008,5/16/2006,56.2,29.0,...,,,N3 + Scaled_2,130209,screen,068_S_0478/screen/anat/sub-068_S_0478_ses-scre...,4.67,7.67,0.0,0


In [11]:
# load images in matrix
def create_dataset(dataset, to_zoom = None, z_factor = 1, mask=None):
    """Load the scans listed in ``dataset`` and build binary conversion labels.

    Parameters
    ----------
    dataset : pandas.DataFrame
        Must provide a "path" column (scan location relative to the
        module-level ``working_dir``) and a "Conversion" column.
    to_zoom : bool or None, optional
        When truthy, every volume is resampled with ``zoom`` using
        ``z_factor``. ``None`` (the default) and ``False`` both leave the
        volumes untouched; previously ``False`` triggered resampling because
        only ``None`` was checked.
    z_factor : float or sequence of float, optional
        Zoom factor forwarded to ``zoom``; ignored unless ``to_zoom`` is set.
    mask : array-like or None, optional
        If given, multiplied element-wise into each volume before zooming.

    Returns
    -------
    tuple of numpy.ndarray
        ``(volumes, labels)``: the loaded float32 volumes collected with
        ``np.array`` in row order, and one integer label per row that is 1
        where "Conversion" equals 1 and 0 otherwise.

    Notes
    -----
    Errors raised by ``nib.load`` (e.g. for a missing file) propagate to the
    caller; rows are never skipped silently, so volumes and labels always
    stay aligned with ``dataset``.
    All volumes are held in a list and then copied into one array, so peak
    memory is roughly twice the size of the returned volume array.
    """
    volumes = []
    labels = []
    for rel_path, conversion in zip(dataset["path"], dataset["Conversion"]):
        scan = nib.load(os.path.join(working_dir, rel_path))
        # get_data() was deprecated in nibabel 3.0 and removed in 5.0;
        # np.asanyarray(img.dataobj) is its documented replacement.
        # astype() returns a fresh copy, so the in-place masking below never
        # writes into the image's own data object.
        volume = np.asanyarray(scan.dataobj).astype(np.float32)
        if mask is not None:
            volume *= mask
        if to_zoom:
            volume = zoom(volume, z_factor)
        volumes.append(volume)
        labels.append(int(conversion == 1))
    return np.array(volumes), np.array(labels)

In [13]:
import time, datetime

In [17]:
# Load every scan listed in data_MCI and report how long the load took.
print(f"Starting at {time.ctime()}")
start = time.time()
X_mci, y_mci = create_dataset(data_MCI)
end = time.time()
print(f"Runtime: {datetime.timedelta(seconds=end - start)}")

Starting at Thu Sep 26 16:36:38 2019
Runtime: 0:05:42.534885


In [18]:
# Load every scan listed in data_MCI_conv and report how long the load took.
print(f"Starting at {time.ctime()}")
start = time.time()
X_mci_conv, y_mci_conv = create_dataset(data_MCI_conv)
end = time.time()
print(f"Runtime: {datetime.timedelta(seconds=end - start)}")

In [19]:
# Inspect the labels built from data_MCI_conv (1 where "Conversion" == 1, else 0).
print(y_mci_conv)

In [20]:
# Visual sanity check: show the slice at index 60 along the last axis for the
# first two volumes of X_mci_conv, side by side in grayscale.
fig, axs = plt.subplots(1, 2, figsize = (15, 5))
axs[0].imshow(X_mci_conv[0][:,:,60], cmap='gray')
axs[1].imshow(X_mci_conv[1][:,:,60], cmap='gray')

In [21]:
# Shape of the stacked MCI volumes; the recorded run gave (1220, 182, 218, 182),
# i.e. the first axis indexes the scans.
X_mci.shape

(1220, 182, 218, 182)

In [22]:
# Shape of the stacked volumes loaded from data_MCI_conv.
X_mci_conv.shape

In [23]:
import h5py

In [27]:
# Persist the full MCI volumes and labels as gzip-compressed HDF5 datasets.
# The context manager closes the file even when a write raises, so no open
# handle to a partially written file is left behind in the kernel.
with h5py.File('/analysis/ritter/projects/AD/Budding_Spectral_Analysis/data/ADNI_mci_all.h5', 'w') as h5:
    h5.create_dataset('X', data=X_mci, compression='gzip', compression_opts=9)
    h5.create_dataset('y', data=y_mci, compression='gzip', compression_opts=9)

In [28]:
# Persist the volumes and labels loaded from data_MCI_conv.
# This cell previously opened ADNI_mci_all.h5 in 'w' mode, which truncated the
# file written from X_mci / y_mci by the preceding cell. It now writes to its
# own file, named after the source table ADNI_mci_conv.csv.
# NOTE(review): confirm the target file name with downstream readers.
# The context manager closes the file even when a write raises.
with h5py.File('/analysis/ritter/projects/AD/Budding_Spectral_Analysis/data/ADNI_mci_conv.h5', 'w') as h5:
    h5.create_dataset('X', data=X_mci_conv, compression='gzip', compression_opts=9)
    h5.create_dataset('y', data=y_mci_conv, compression='gzip', compression_opts=9)