In [41]:
import nibabel as nib 
import numpy as np
import logging
import matplotlib.pyplot as plt
import os
import pandas as pd
from typing import Callable

In [42]:
import nibabel as nib 
import numpy as np
import logging
import os
from typing import Callable
# TODO: Could be refactor as TFRecord class
def nii_reader(path:str, default_shape:tuple=(256, 256, 166), ignore_shape:bool=True):
    # TODO: Could be extended to multiple formats
    # TODO: Could return objects of Nibel etc.
    # TODO: Optimalize reading, test different methods 
    # https://simpleitk.readthedocs.io
    # https://nipy.org/nibael/
    # https://nilearn.github.io/
    assert os.path.isfile(path)
    img = nib.load(path)
    if img.shape != default_shape:
        logging.warning(f'Unexpected shape {img.shape}, default shape {default_shape}, file {path}')
        if ignore_shape: return None 
    
    return np.squeeze(np.array(img.get_fdata()))

def nii_dir_generator(input_dir:str,
                      fname2label: Callable[[str], str]=None,
                      image_ext:str="nii",
                      default_shape:tuple=(256, 256, 166),
                      ignore_shape:bool=False):
    for (dirpath, dirnames, filenames) in os.walk(input_dir):
        for f in filenames:
            if f.endswith(image_ext):
                f_path = os.path.join(dirpath, f)
                logging.info(f'Read nii file from {f_path}')
                img = nii_reader(f_path, default_shape=default_shape, ignore_shape=ignore_shape)
                if fname2label:
                    yield fname2label(f), img
                else:
                    yield f, img

In [43]:
import re
# ADNI tools
# TODO: Separate
def parse_adni_img_id(adni_img_name):
    assert adni_img_name.startswith("ADNI_")
    adni_id = re.findall(r".*_I([0-9]+)\.nii", adni_img_name)
    if len(adni_id) != 1:
        logging.error("Unknown subject ID: {}".format(adni_img_name))
    return adni_id[0]

def parse_adni_usr_id(adni_img_name):
    assert adni_img_name.startswith("ADNI_")
    adni_id = re.findall(r"ADNI_([0-9]+_S_[0-9]+)_", adni_img_name)
    if len(adni_id) != 1:
        logging.error("Unknown subject ID: {}".format(adni_img_name))
    return adni_id[0]

def get_adni_group(adni_img_name, adni_desc_df): 
    img_id = parse_adni_img_id(adni_img_name)
    adni_ids = adni_desc_df.loc[adni_desc_df['Image Data ID']== 45108, 'Group'].values
    assert adni_ids.shape[0] == 1
    return adni_ids[0]
    

In [44]:
# TODO: Inplace method much more faster 
# TODO: Aware of diffrent scanners 
# MARK: Sklearn does not work od 3D data
# TODO: Registration etc... https://mirtk.github.io
# https://nilearn.github.io
def normalize(data, feature_range=(0,1), method="MinMax", min_data=None, max_data=None, copy=True):
    min_out, max_out = feature_range
    if method == "MinMax":
        if not min_data: np.min(data)
        if not max_data: np.max(data)
            
        scale = (max_out - min_out) / (max_data - min_data)
        if copy: 
            return data * scale + min_out - min_data * scale
        else: 
            data *= scale
            data += min_out - min_data * scale
    else:
        raise Exception("Unknown method {}".format(method))

In [45]:
# Training
import tensorflow as tf
from tensorflow.keras import layers

def get_baseline():
    img_inputs = layers.Input((256, 256, 166, 1))
    conv0 = layers.Conv3D(16,
                          3,
                          strides=(2,2,2),
                          activation='relu')(img_inputs)
    conv1 = layers.Conv3D(32,
                          3,
                          strides=(2,2,2),
                          activation='relu')(conv0)
    conv2 = layers.Conv3D(64,
                          3, 
                          strides=(2,2,2),
                          activation='relu')(conv1)
    conv3 = layers.Conv3D(128,
                          3, 
                          strides=(2,2,2),
                          activation='relu')(conv2)
    conv4 = layers.Conv3D(256,
                          3, 
                          strides=(2,2,2),
                          activation='relu')(conv3)
    flatten = layers.Flatten()(conv4)

    output = layers.Dense(2, activation='softmax')(flatten)

    return(tf.keras.Model(inputs=img_inputs, outputs=output, name='3D_Dense'))

In [49]:
# CONFIG
# TODO: Move to config file
# IMG
# ADNI 
ADNI_DF = pd.read_csv("ADNI1_Complete_1Yr_1.5T_11_21_2019.csv")
# READING
IMG_PATH = 'data'
IMG_EXT = 'nii'
IMG_SHAPE = (256, 256, 166)
IMG_IGNORE_BAD_SHAPE = True
FNAME_TO_LABEL = lambda x: get_adni_group(x, ADNI_DF)
# NORMALIZATION
NORM_METHOD = 'MinMax'
NORM_RANGE = (0, 1)
# AUGMENTATION 

# CLASS BALANCING

# TRAINING
T_BATCH_SIZE = 2
T_EPOCHS = 10
T_LOGS = 'logs'
T_CHECKPOINT = 'checkpoints'

In [56]:
from logging import info, warning, error
# READ PHASE
info(f'Reading from {IMG_PATH}')

labels = []
images = []
img_generator = nii_dir_generator(input_dir=IMG_PATH,
                                  fname2label=FNAME_TO_LABEL,
                                  image_ext=IMG_EXT,
                                  default_shape=IMG_SHAPE,
                                  ignore_shape=IMG_IGNORE_BAD_SHAPE)
for fname, img in img_generator:
    labels.append(fname)
    images.append(img)
    
images = np.array(images)
labels = np.array(labels)

info('Reading finished')

# NORMALIZATION PHASE
voxel_mean = np.mean(images)
voxel_std = np.std(images)
voxel_max = np.max(images)
voxel_min = np.min(images)

normalize(images, 
          feature_range=(0,1),
          method=NORM_METHOD, 
          min_data=voxel_min,
          max_data=voxel_max, 
          copy=False)

info('Normalization finished')

# DATA AUGMENTATION PHASE
# TODO: Implement

#CLASS BALANCING PHASE 
# TODO: Implement 

#PREPARATION PHASE

assert images.shape[-1] != 1

images = images.reshape((*images.shape,1)).astype('float32')
labels = labels == 'CN'

train_x = images[2:]
train_y = labels[2:]

test_x = images[:2]
test_y = labels[:2]

val_x = test_x
val_y = test_y

info('Preparation finished')

# TRAINING PHASE
# TODO: USE Straka logging name trick 
callbacks = [tf.keras.callbacks.TensorBoard(log_dir=T_LOGS),
             tf.keras.callbacks.ModelCheckpoint(filepath=T_CHECKPOINT, 
                                                verbose=1)
            ]
model = get_baseline()
model.compile(loss='sparse_categorical_crossentropy',
              optimizer=tf.optimizers.Adam(),
              metrics=['accuracy'])
history = model.fit(train_x, train_y,
                    batch_size=T_BATCH_SIZE,
                    epochs=T_EPOCHS,
                    validation_data=(val_x, val_y),
                    callbacks=callbacks)
# EVALUATE PHASE 
print(f'Test')
test_scores = model.evaluate(test_x, test_y, batch_size=T_BATCH_SIZE)
print(f'Test loss: {test_scores[0]}')
print(f'Test accuracy: {test_scores[1]}')

Train on 6 samples, validate on 2 samples
Epoch 1/10
Epoch 00001: saving model to checkpoints
INFO:tensorflow:Assets written to: checkpoints/assets


INFO:tensorflow:Assets written to: checkpoints/assets


Epoch 2/10
Epoch 00002: saving model to checkpoints
INFO:tensorflow:Assets written to: checkpoints/assets


INFO:tensorflow:Assets written to: checkpoints/assets


Epoch 3/10
Epoch 00003: saving model to checkpoints
INFO:tensorflow:Assets written to: checkpoints/assets


INFO:tensorflow:Assets written to: checkpoints/assets


Epoch 4/10
Epoch 00004: saving model to checkpoints
INFO:tensorflow:Assets written to: checkpoints/assets


INFO:tensorflow:Assets written to: checkpoints/assets


Epoch 5/10
Epoch 00005: saving model to checkpoints
INFO:tensorflow:Assets written to: checkpoints/assets


INFO:tensorflow:Assets written to: checkpoints/assets


Epoch 6/10
Epoch 00006: saving model to checkpoints
INFO:tensorflow:Assets written to: checkpoints/assets


INFO:tensorflow:Assets written to: checkpoints/assets


Epoch 7/10
Epoch 00007: saving model to checkpoints
INFO:tensorflow:Assets written to: checkpoints/assets


INFO:tensorflow:Assets written to: checkpoints/assets


Epoch 8/10
Epoch 00008: saving model to checkpoints
INFO:tensorflow:Assets written to: checkpoints/assets


INFO:tensorflow:Assets written to: checkpoints/assets


Epoch 9/10
Epoch 00009: saving model to checkpoints
INFO:tensorflow:Assets written to: checkpoints/assets


INFO:tensorflow:Assets written to: checkpoints/assets


Epoch 10/10
Epoch 00010: saving model to checkpoints
INFO:tensorflow:Assets written to: checkpoints/assets


INFO:tensorflow:Assets written to: checkpoints/assets


Test
Test loss: 0.0
Test accuracy: 1.0


assets	saved_model.pb	variables
