In [0]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [0]:
import tensorflow as tf

from tensorflow.keras import Sequential

from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Conv3D
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.layers import MaxPool3D
from tensorflow.keras.layers import Dropout
from tensorflow.keras.layers import AveragePooling3D
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Flatten

In [0]:
data_path = 'data'
labels = pd.read_csv(os.path.join(data_path, 'adni_demographic_master_kaggle.csv'))

In [0]:
def load_datasets(type: str):
    if type not in ['train', 'test', 'valid']: raise Exception('Unsupported dataset type')
    train_valid_test = 0 if type == 'train' else 1 if type == 'valid' else 2
    i = 1
    dataset = np.load(os.path.join(data_path, f'img_array_{type}_6k_{i}.npy'))
    while True:
        try:
            i += 1
            dataset = np.vstack((dataset, np.load(os.path.join(data_path, f'img_array_{type}_6k_{i}.npy'))))
        except FileNotFoundError:
            print(f'Loaded all {type} datasets')
            break
    # dataset = np.expand_dims(dataset, axis=1)
    dataset = np.reshape(dataset, (-1, 62, 96, 96, 1))
    for i in range(dataset.shape[0]):
        dataset[i] = dataset[i]/np.amax(dataset[i])
    return dataset, np.eye(3)[labels[labels.train_valid_test == train_valid_test].diagnosis - 1]

In [0]:
test_data, test_labels = load_datasets('test')
train_data, train_labels = load_datasets('train')
valid_data, valid_data = load_datasets('valid')

Loaded all valid datasets


In [0]:
def create_model():
    kernel_regularizer = tf.keras.regularizers.l2(0.001)
    he_normal = tf.keras.initializers.he_normal(seed=0)
    lecun_normal = tf.keras.initializers.lecun_normal(seed=0)

    input_3d = (62, 96, 96, 1)
    pool_3d = (2, 2, 2)

    model = Sequential()
    
    model.add(tf.keras.layers.InputLayer(input_shape=input_3d))
    model.add(Conv3D(filters=8,
                     kernel_size=3,
                     kernel_regularizer=kernel_regularizer,
                     kernel_initializer=he_normal))
    model.add(BatchNormalization())
    model.add(LeakyReLU())
    model.add(MaxPool3D(pool_size=pool_3d, name='pool1'))

    model.add(Conv3D(filters=8,
                     kernel_size=3,
                     kernel_regularizer=kernel_regularizer,
                     kernel_initializer=he_normal))
    model.add(BatchNormalization())
    model.add(LeakyReLU())
    model.add(MaxPool3D(pool_size=pool_3d, name='pool2'))

    model.add(Conv3D(filters=8,
                     kernel_size=3,
                     kernel_regularizer=kernel_regularizer,
                     kernel_initializer=he_normal))
    model.add(BatchNormalization())
    model.add(LeakyReLU())
    model.add(MaxPool3D(pool_size=pool_3d, name='pool3'))

    model.add(Flatten())
    model.add(Dense(1024, name='dense1', kernel_regularizer=kernel_regularizer, kernel_initializer=he_normal))
    model.add(LeakyReLU())
    model.add(Dropout(0.5, name='dropout1'))

    model.add(Dense(512, activation='relu', name='dense2', kernel_regularizer=kernel_regularizer, kernel_initializer=he_normal))
    model.add(LeakyReLU())
    model.add(Dropout(0.5, name='dropout2'))

    model.add(Dense(3, activation='softmax', name='softmax', kernel_initializer=lecun_normal))

    model.compile(loss='categorical_crossentropy',
                  optimizer=tf.keras.optimizers.Adam(),
                  metrics=[tf.keras.metrics.CategoricalAccuracy()])

    return model

In [0]:
reducelr_callback = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_categorical_accuracy',
                                                         factor=0.2,
                                                         patience=5,
                                                         min_delta=0.01,
                                                         verbose=1)
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(os.path.join('weights', 'best_weights.h5'),
                                                         monitor='val_categorical_accuracy',
                                                         verbose=1,
                                                         save_best_only=True,
                                                         mode='max')

callbacks_list = [checkpoint_callback, reducelr_callback]

In [0]:
model = create_model()
model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv3d (Conv3D)              (None, 60, 94, 94, 8)     224       
_________________________________________________________________
batch_normalization (BatchNo (None, 60, 94, 94, 8)     32        
_________________________________________________________________
leaky_re_lu (LeakyReLU)      (None, 60, 94, 94, 8)     0         
_________________________________________________________________
pool1 (MaxPooling3D)         (None, 30, 47, 47, 8)     0         
_________________________________________________________________
conv3d_1 (Conv3D)            (None, 28, 45, 45, 8)     1736      
_________________________________________________________________
batch_normalization_1 (Batch (None, 28, 45, 45, 8)     32        
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 28, 45, 45, 8)     0

In [0]:
history = model.fit(x=train_data, y=train_labels, validation_data=(valid_data, valid_labels), epochs=25, callbacks=callbacks_list, verbose=1)

Epoch 1/25
Epoch 00001: val_categorical_accuracy improved from -inf to 0.21839, saving model to weights/best_weights.h5
Epoch 2/25
Epoch 00002: val_categorical_accuracy improved from 0.21839 to 0.37931, saving model to weights/best_weights.h5
Epoch 3/25
Epoch 00003: val_categorical_accuracy did not improve from 0.37931
Epoch 4/25
Epoch 00004: val_categorical_accuracy improved from 0.37931 to 0.40230, saving model to weights/best_weights.h5
Epoch 5/25





Epoch 00005: val_categorical_accuracy did not improve from 0.40230
Epoch 6/25





Epoch 00006: val_categorical_accuracy improved from 0.40230 to 0.41379, saving model to weights/best_weights.h5
Epoch 7/25





Epoch 00007: val_categorical_accuracy did not improve from 0.41379
Epoch 8/25





Epoch 00008: val_categorical_accuracy did not improve from 0.41379
Epoch 9/25





Epoch 00009: val_categorical_accuracy did not improve from 0.41379
Epoch 10/25





Epoch 00010: val_categorical_accuracy did not improve from 0.41379
Epoch 11/25





Epoch 00011: val_categorical_accuracy did not improve from 0.41379

Epoch 00011: ReduceLROnPlateau reducing learning rate to 0.00020000000949949026.
Epoch 12/25





Epoch 00012: val_categorical_accuracy did not improve from 0.41379
Epoch 13/25





Epoch 00013: val_categorical_accuracy did not improve from 0.41379
Epoch 14/25





Epoch 00014: val_categorical_accuracy did not improve from 0.41379
Epoch 15/25





Epoch 00015: val_categorical_accuracy did not improve from 0.41379
Epoch 16/25





Epoch 00016: val_categorical_accuracy did not improve from 0.41379

Epoch 00016: ReduceLROnPlateau reducing learning rate to 4.0000001899898055e-05.
Epoch 17/25





Epoch 00017: val_categorical_accuracy did not improve from 0.41379
Epoch 18/25





Epoch 00018: val_categorical_accuracy did not improve from 0.41379
Epoch 19/25





Epoch 00019: val_categorical_accuracy did not improve from 0.41379
Epoch 20/25





Epoch 00020: val_categorical_accuracy did not improve from 0.41379
Epoch 21/25





Epoch 00021: val_categorical_accuracy did not improve from 0.41379

Epoch 00021: ReduceLROnPlateau reducing learning rate to 8.000000525498762e-06.
Epoch 22/25





Epoch 00022: val_categorical_accuracy did not improve from 0.41379
Epoch 23/25





Epoch 00023: val_categorical_accuracy did not improve from 0.41379
Epoch 24/25





Epoch 00024: val_categorical_accuracy did not improve from 0.41379
Epoch 25/25





Epoch 00025: val_categorical_accuracy did not improve from 0.41379
