In [1]:
import tensorflow as tf

import numpy as np
import matplotlib.pyplot as plt

import nibabel
import nibabel.processing
import os
from skimage.filters import threshold_otsu
import cc3d
import shutil
import pickle

In [None]:
import sys

In [None]:
# Record the interpreter version this notebook was run with.
for version_item in (sys.version, sys.version_info):
    print(version_item)

In [2]:
#USE SMALL GPU#
# Restrict TensorFlow to a single physical GPU, chosen by index.
# NOTE(review): the original note says GPUs are listed sorted by memory size, largest
# first (index 0), so index 1 is the smaller card; that ordering comes from the host's
# driver configuration — confirm on this machine.
use_gpu = 1
gpus = tf.config.experimental.list_physical_devices('GPU')
if not gpus:
    # Previously this crashed with an opaque IndexError; fall back to the CPU instead.
    print("No GPU detected; TensorFlow will run on the CPU.")
elif use_gpu >= len(gpus):
    raise IndexError(
        "use_gpu={0} but only {1} GPU(s) are visible".format(use_gpu, len(gpus))
    )
else:
    try:
        tf.config.experimental.set_visible_devices(gpus[use_gpu], 'GPU')
    except RuntimeError as err:
        # Visible devices may not be changeable once TensorFlow has initialized the
        # GPUs (e.g. when this cell is re-run); report it rather than abort the run.
        print("Could not change visible GPUs: {0}".format(err))

In [3]:
#Min-max scaling between 0 and 1
def normalize(volume):
    """Min-max scale ``volume`` to the range [0, 1] and return it as float32.

    Parameters
    ----------
    volume : numpy.ndarray
        Array of any shape; the minimum and maximum are taken over all elements.

    Returns
    -------
    numpy.ndarray
        float32 array with the same shape as ``volume``. A constant volume
        (max == min) has no range to scale by, so it is returned as all zeros
        instead of the all-NaN result a division by zero would produce.
    """
    v_min = volume.min()
    v_max = volume.max()
    value_range = v_max - v_min
    if value_range == 0:
        return np.zeros(volume.shape, dtype="float32")
    scaled = (volume - v_min) / value_range
    return scaled.astype("float32")

In [4]:
#Similar to Ding's preprocessing, except that connected-component analysis (CCA) is performed in 3D and only the largest component is kept

def _binarize_otsu(img):
    """Threshold ``img`` at its Otsu level; return an int32 array of 0s and 255s.

    Voxels with intensity >= the Otsu threshold become 255, all others 0.
    """
    thresh = threshold_otsu(img)
    # Compare against the original intensities in one step. The previous two-step
    # in-place version (set < thresh to 0, then >= thresh to 255) re-tested the
    # voxels it had just zeroed, so with thresh <= 0 every voxel became 255.
    return np.where(img >= thresh, 255, 0).astype('int32')


def _largest_label_mask(labels_out, N):
    """Return a 0/1 integer mask selecting the largest non-background label.

    ``labels_out`` is the label volume returned by ``cc3d.connected_components``
    (0 = background, components numbered 1..N). Ties go to the lowest label ID.

    Raises
    ------
    ValueError
        If ``N`` is 0, i.e. there is no foreground component to keep.
    """
    print("This function returns the largest blob of a CCA processed image as a binary mask")
    print("")
    if N == 0:
        raise ValueError("Connected-component analysis found no foreground labels")
    # One pass over the volume yields every label's voxel count; the old version
    # materialized N full-size masks just to sum each of them.
    sizes = np.bincount(labels_out.ravel().astype(np.intp), minlength=N + 1)[1:N + 1]
    print("The image has {0} labels".format(N))
    print("The shape of the labels is: {0}".format(labels_out.shape))
    print("Label sizes: {0}".format(sizes.tolist()))
    a = int(sizes.argmax())  # 0-based position in `sizes`; the label ID is a + 1
    print("The largest label index is: {0}".format(a))
    mask = 1 * (labels_out == a + 1)
    print("The largest label is now a binary mask with shape {0}, size {1}, max value {2} and min value {3}".format((mask.shape),(mask.sum()),(mask.max()),(mask.min())))
    return mask


def pp(original_image, connectivity=6):
    """Load a scan, resample it, and keep only its largest above-threshold 3-D blob.

    Steps: load with nibabel, conform to a 100x100x90 grid of 2 mm voxels,
    binarize at the Otsu threshold, run 3-D connected-component analysis, and
    multiply the resampled image by the mask of the largest component.

    Parameters
    ----------
    original_image : str or path-like
        Path to an image file readable by ``nibabel.load``.
    connectivity : int, optional
        Neighbourhood passed to ``cc3d.connected_components`` (6, 18 or 26 in 3-D).
        The original code set ``connectivity = 6`` but never passed it, so cc3d
        used its default (26); pass ``connectivity=26`` to reproduce old outputs.

    Returns
    -------
    numpy.ndarray
        The resampled intensities, with every voxel outside the largest
        component set to 0.

    Raises
    ------
    ValueError
        If no voxel survives thresholding (no component to keep).
    """
    input_img = nibabel.load(str(original_image))
    resampled_img = nibabel.processing.conform(
        input_img, out_shape=(100, 100, 90), voxel_size=(2.0, 2.0, 2.0)
    )
    img = resampled_img.get_fdata()

    input_CCA = _binarize_otsu(img)
    labels_out, N = cc3d.connected_components(
        input_CCA, connectivity=connectivity, return_N=True
    )
    mask = _largest_label_mask(labels_out, N)

    return np.multiply(img, mask)

In [5]:
#Wrapper function for normalize and preprocess

def process_scan(path):
    """Return the scan at ``path`` preprocessed by ``pp`` and min-max scaled to [0, 1].

    Scaling is done by ``normalize``, so the result is a float32 array.
    """
    preprocessed = pp(path)
    return normalize(preprocessed)


In [None]:
#Binary (AD vs. control) version of the split cell in Ding_preprocess

In [None]:
#Define base destination folders
NEWDATA_DIR = "/local_mount/space/celer/1/users/notebooks/moises/pdata/newdata"
train = os.path.join(NEWDATA_DIR, "split", "train")
test = os.path.join(NEWDATA_DIR, "split", "test")
val = os.path.join(NEWDATA_DIR, "split", "val")

# Fraction of each class's files that seeds the validation set (and, separately,
# the test set) before the boundary is extended to the end of a subject.
HOLDOUT_FRACTION = 0.1
# Number of leading filename characters compared to decide whether two files
# belong to the same subject.
# NOTE(review): treating this prefix as a subject ID is inferred from the original
# [0:10] comparisons — confirm against the file naming scheme.
SUBJECT_ID_LEN = 10
# Source folder under NEWDATA_DIR -> header printed while it is processed.
# Add "mci": "MCI" here to also produce the 3-class split.
CLASSES = {"ad": "AD", "control": "Control"}


def _extend_to_subject_boundary(names, end, prefix_len=SUBJECT_ID_LEN):
    """Move split point ``end`` forward until it no longer cuts a subject in two.

    ``names`` must be sorted so that one subject's files are adjacent. While the
    first file after the boundary shares its ``prefix_len``-character prefix with
    the last file before it, that file is pulled across. Never exceeds len(names),
    which fixes the IndexError the original loops hit at the end of the list.
    """
    if end <= 0:
        return end
    while end < len(names) and names[end][:prefix_len] == names[end - 1][:prefix_len]:
        end += 1
    return end


def _subject_aware_split(names, fraction=HOLDOUT_FRACTION, prefix_len=SUBJECT_ID_LEN):
    """Split ``names`` into (val, test, train) lists, in sorted order.

    val and test each start with round(fraction * len(names)) files and are then
    extended to the next subject boundary; train receives all remaining files.
    """
    names = sorted(names)
    n_holdout = round(fraction * len(names))
    val_end = _extend_to_subject_boundary(names, min(n_holdout, len(names)), prefix_len)
    test_end = _extend_to_subject_boundary(names, min(val_end + n_holdout, len(names)), prefix_len)
    return names[:val_end], names[val_end:test_end], names[test_end:]


def _copy_class_split(class_name):
    """Copy NEWDATA_DIR/<class_name> into the val, test and train <class_name> folders."""
    src_dir = os.path.join(NEWDATA_DIR, class_name)
    val_names, test_names, train_names = _subject_aware_split(os.listdir(src_dir))
    for split_dir, split_names in ((val, val_names), (test, test_names), (train, train_names)):
        dest = os.path.join(split_dir, class_name)
        # If the class folder is missing, shutil.copy treats `dest` as a file name and
        # every copy overwrites that single file, so make sure the folder exists.
        os.makedirs(dest, exist_ok=True)
        for name in split_names:
            shutil.copy(os.path.join(src_dir, name), dest)
        print("Copied {0} files to {1}".format(len(split_names), dest))


for class_name, header in CLASSES.items():
    print("\n {0} \n".format(header))
    _copy_class_split(class_name)

# The original cell chdir'ed into each class folder and left the working directory in
# the last one; keep that side effect for later cells that navigate relative to it.
os.chdir(os.path.join(NEWDATA_DIR, list(CLASSES)[-1]))
    

In [None]:
#Next, build the lists of paths to the individual scans for each split (train/val/test) and class

In [10]:
cd ..

/local_mount/space/celer/1/users/notebooks/moises


In [None]:
def _split_scan_paths(split, label):
    """Return sorted absolute paths of the files in pdata/newdata/split/<split>/<label>.

    The folder is resolved against the current working directory, as before.
    ``os.listdir`` returns entries in arbitrary order, so the listing is sorted to
    make the sample order (and therefore the saved datasets) reproducible.
    """
    folder = os.path.join("pdata/newdata/split", split, label)
    return [os.path.join(os.getcwd(), folder, name) for name in sorted(os.listdir(folder))]


train_control_scan_paths = _split_scan_paths("train", "control")
val_control_scan_paths = _split_scan_paths("val", "control")
test_control_scan_paths = _split_scan_paths("test", "control")

train_ad_scan_paths = _split_scan_paths("train", "ad")
val_ad_scan_paths = _split_scan_paths("val", "ad")
test_ad_scan_paths = _split_scan_paths("test", "ad")

In [None]:
# Read and process the scans (binary AD/control)

# Shape of one preprocessed volume; must match the out_shape passed to
# nibabel.processing.conform in pp().
VOLUME_SHAPE = (100, 100, 90)


def _load_scans(paths):
    """Preprocess every scan in ``paths`` and stack them along a new first axis.

    An empty ``paths`` yields an empty (0, *VOLUME_SHAPE) float32 array so it can
    still be concatenated with the other class; ``np.array([])`` would be 1-D and
    make ``np.concatenate`` fail.
    """
    volumes = [process_scan(path) for path in paths]
    if not volumes:
        return np.empty((0,) + VOLUME_SHAPE, dtype="float32")
    return np.array(volumes)


def _class_labels(scans, label):
    """Integer array holding ``label`` once per scan in ``scans``."""
    return np.full(len(scans), label, dtype=int)


train_ad_scans = _load_scans(train_ad_scan_paths)
val_ad_scans = _load_scans(val_ad_scan_paths)
test_ad_scans = _load_scans(test_ad_scan_paths)

train_control_scans = _load_scans(train_control_scan_paths)
val_control_scans = _load_scans(val_control_scan_paths)
test_control_scans = _load_scans(test_control_scan_paths)

# Labeling samples according to folder architecture: AD = 1, control = 0
train_ad_labels = _class_labels(train_ad_scans, 1)
val_ad_labels = _class_labels(val_ad_scans, 1)
test_ad_labels = _class_labels(test_ad_scans, 1)

train_control_labels = _class_labels(train_control_scans, 0)
val_control_labels = _class_labels(val_control_scans, 0)
test_control_labels = _class_labels(test_control_scans, 0)

#Sets (AD first, then control, in both x and y so samples stay aligned with labels)
x_train = np.concatenate((train_ad_scans, train_control_scans), axis=0)
y_train = np.concatenate((train_ad_labels, train_control_labels), axis=0)

x_val = np.concatenate((val_ad_scans, val_control_scans), axis=0)
y_val = np.concatenate((val_ad_labels, val_control_labels), axis=0)

x_test = np.concatenate((test_ad_scans, test_control_scans), axis=0)
y_test = np.concatenate((test_ad_labels, test_control_labels), axis=0)

for x_split, y_split in ((x_train, y_train), (x_val, y_val), (x_test, y_test)):
    assert len(x_split) == len(y_split), "every volume needs exactly one label"

print(
    "Number of samples in train, validation and test are %d, %d and %d."
    % (x_train.shape[0], x_val.shape[0], x_test.shape[0])
)


In [13]:
# Re-check how many samples ended up in the train and validation arrays.
n_train_samples, n_val_samples = x_train.shape[0], x_val.shape[0]
print("Number of samples in train and validation are %d and %d." % (n_train_samples, n_val_samples))

Number of samples in train and validation are 1681 and 217.


In [14]:
# Wrap the already-preprocessed (volume, label) arrays in one tf.data dataset per split.
gtrain_loader, gvalidation_loader, gtest_loader = (
    tf.data.Dataset.from_tensor_slices(split_arrays)
    for split_arrays in ((x_train, y_train), (x_val, y_val), (x_test, y_test))
)

In [None]:
#Next, save each dataset to its own folder. Also keep the element_spec (pickled below): it is required to reload the datasets later.
#Replace the ###...### placeholders below with the dataset name before running.

In [16]:
# Persist each split under ./datasets/ (fill in the ### placeholders first).
for split_ds, save_path in (
    (gtrain_loader, "./datasets/###TRAIN###"),
    (gvalidation_loader, "./datasets/###VAL###"),
    (gtest_loader, "./datasets/###TEST###"),
):
    tf.data.experimental.save(split_ds, save_path)

In [17]:
# element_spec describes the structure and dtypes of the (volume, label) elements;
# it is needed again to reload the saved datasets, so it is pickled below.
element_spec = gtrain_loader.element_spec
a = element_spec

In [19]:
# Write the training dataset's element_spec next to the notebook for later reloading.
spec_pickle_path = '###Dataset###.pickle'
with open(spec_pickle_path, mode='wb') as spec_file:
    pickle.dump(a, spec_file)