In [None]:
import os
import numpy as np
from glob import glob

import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import ( 
    Input, Conv3D, SpatialDropout3D, BatchNormalization, 
    GlobalAveragePooling3D, Dense
)

from tensorflow.keras.optimizers import Adam

In [None]:
from google.colab import drive

# Mount Google Drive; the dataset archive and the pretrained model are read from it.
drive_mountpoint = '/content/drive'
drive.mount(drive_mountpoint)

Mounted at /content/drive


In [None]:
!unzip /content/drive/MyDrive/MAID/CV/data/data.zip

Archive:  /content/drive/MyDrive/MAID/CV/data/data.zip
  inflating: data/sub-100307/func/sub-100307_task-WM_run-LR_space-MNI152NLin6Asym_res-2_desc-tfr.tfrecords  
  inflating: data/sub-100408/func/sub-100408_task-WM_run-LR_space-MNI152NLin6Asym_res-2_desc-tfr.tfrecords  
  inflating: data/sub-101006/func/sub-101006_task-WM_run-LR_space-MNI152NLin6Asym_res-2_desc-tfr.tfrecords  
  inflating: data/sub-101107/func/sub-101107_task-WM_run-LR_space-MNI152NLin6Asym_res-2_desc-tfr.tfrecords  
  inflating: data/sub-101309/func/sub-101309_task-WM_run-LR_space-MNI152NLin6Asym_res-2_desc-tfr.tfrecords  
  inflating: data/sub-101410/func/sub-101410_task-WM_run-LR_space-MNI152NLin6Asym_res-2_desc-tfr.tfrecords  
  inflating: data/sub-101915/func/sub-101915_task-WM_run-LR_space-MNI152NLin6Asym_res-2_desc-tfr.tfrecords  
  inflating: data/sub-102008/func/sub-102008_task-WM_run-LR_space-MNI152NLin6Asym_res-2_desc-tfr.tfrecords  
  inflating: data/sub-102311/func/sub-102311_task-WM_run-LR_space-MNI152N

In [None]:
# Root of the extracted dataset, and the number of target states.
# NOTE(review): n_states_training is not referenced by the visible cells; the
# model head uses n_classes instead.
data_path, n_states_training = 'data/', 4

In [None]:
# Subject IDs are parsed from the 'sub-<id>' directory names under data_path.
# np.unique already returns its result in sorted order.
subjects = np.unique([
    int(entry.split('sub-')[1])
    for entry in os.listdir(data_path)
    if entry.startswith('sub-')
])

In [None]:
# Display the sorted subject IDs found under data_path.
subjects

array([100307, 100408, 101006, 101107, 101309, 101410, 101915, 102008,
       102311, 102816, 103111, 103414, 103515, 103818, 104820, 105014,
       105115, 105216, 106016, 106319])

In [None]:
# Subject-level 75/25 train/validation split, so no subject contributes volumes
# to both sets. A fixed seed makes the split (and every downstream result)
# reproducible across kernel restarts.
split_seed = 42
split_rng = np.random.default_rng(split_seed)

subjects_training = split_rng.choice(
    subjects,
    int(subjects.size * 3 / 4.),
    replace=False
)

# All subjects not drawn for training; setdiff1d returns them sorted.
subjects_validation = np.setdiff1d(subjects, subjects_training)

subjects_training, subjects_validation

(array([106016, 100408, 105014, 103111, 101309, 101410, 106319, 105216,
        102311, 103414, 101107, 104820, 103515, 101006, 102816]),
 array([100307, 101915, 102008, 103818, 105115]))

In [None]:
# Collect each training subject's functional TFRecord files. Paths are built
# from data_path rather than a second hard-coded 'data/'. glob results are sorted
# because glob returns matches in arbitrary filesystem order.
train_files = []
for subject in subjects_training:
    pattern = os.path.join(data_path, f'sub-{subject}', 'func', '*.tfrecords')
    train_files.extend(sorted(glob(pattern)))

train_files

['data/sub-106016/func/sub-106016_task-WM_run-LR_space-MNI152NLin6Asym_res-2_desc-tfr.tfrecords',
 'data/sub-100408/func/sub-100408_task-WM_run-LR_space-MNI152NLin6Asym_res-2_desc-tfr.tfrecords',
 'data/sub-105014/func/sub-105014_task-WM_run-LR_space-MNI152NLin6Asym_res-2_desc-tfr.tfrecords',
 'data/sub-103111/func/sub-103111_task-WM_run-LR_space-MNI152NLin6Asym_res-2_desc-tfr.tfrecords',
 'data/sub-101309/func/sub-101309_task-WM_run-LR_space-MNI152NLin6Asym_res-2_desc-tfr.tfrecords',
 'data/sub-101410/func/sub-101410_task-WM_run-LR_space-MNI152NLin6Asym_res-2_desc-tfr.tfrecords',
 'data/sub-106319/func/sub-106319_task-WM_run-LR_space-MNI152NLin6Asym_res-2_desc-tfr.tfrecords',
 'data/sub-105216/func/sub-105216_task-WM_run-LR_space-MNI152NLin6Asym_res-2_desc-tfr.tfrecords',
 'data/sub-102311/func/sub-102311_task-WM_run-LR_space-MNI152NLin6Asym_res-2_desc-tfr.tfrecords',
 'data/sub-103414/func/sub-103414_task-WM_run-LR_space-MNI152NLin6Asym_res-2_desc-tfr.tfrecords',
 'data/sub-101107/fu

In [None]:
# Collect each validation subject's functional TFRecord files. Paths are built
# from data_path rather than a second hard-coded 'data/'. glob results are sorted
# because glob returns matches in arbitrary filesystem order.
validation_files = []
for subject in subjects_validation:
    pattern = os.path.join(data_path, f'sub-{subject}', 'func', '*.tfrecords')
    validation_files.extend(sorted(glob(pattern)))

validation_files

['data/sub-100307/func/sub-100307_task-WM_run-LR_space-MNI152NLin6Asym_res-2_desc-tfr.tfrecords',
 'data/sub-101915/func/sub-101915_task-WM_run-LR_space-MNI152NLin6Asym_res-2_desc-tfr.tfrecords',
 'data/sub-102008/func/sub-102008_task-WM_run-LR_space-MNI152NLin6Asym_res-2_desc-tfr.tfrecords',
 'data/sub-103818/func/sub-103818_task-WM_run-LR_space-MNI152NLin6Asym_res-2_desc-tfr.tfrecords',
 'data/sub-105115/func/sub-105115_task-WM_run-LR_space-MNI152NLin6Asym_res-2_desc-tfr.tfrecords']

In [None]:
def parse_func_tfr(
    example_proto,
    nx, ny, nz,
    n_onehot=None,
    onehot_idx=None,
    only_parse_XY=False,
    transpose_xyz=False,
    add_channel_dim=False):
    """Parse a single serialized example from a functional TFRecord file.

    Args:
        example_proto: Single serialized example from a TFR-file.
        nx, ny, nz: Integers indicating the x-/y-/z-dimensions
            of the fMRI data stored in the TFR-files.
        n_onehot: Total number of states across tasks, i.e. the length of
            the stored 'onehot' feature. Required.
        onehot_idx: Indices selected from the stored state one-hot vector;
            e.g., if the encoding has 20 values in total but we only want
            to train with values 5-10, set onehot_idx=np.arange(4, 10).
            Defaults to all n_onehot entries.
        only_parse_XY: If True, return only (volume, onehot), as needed for
            integration with keras. If False, return a dict (see Returns).
        transpose_xyz: If True, reverse the spatial axis order to (z, y, x).
        add_channel_dim: If True, append a trailing channel axis of size 1.

    Returns:
        If only_parse_XY is True: the tuple (volume, onehot).
        Otherwise a dict with keys:
            volume: float32 tensor of fMRI volume activations. NaN voxels
                are replaced by 0 and infinite voxels by 1e4.
            onehot: int64 tensor of the selected one-hot entries.
            task_id, subject_id, run_id, state: int64 tensors of shape [1].
            tr: float32 tensor of shape [1], the TR of the volume.

    Raises:
        ValueError: If n_onehot is not given.
    """
    if n_onehot is None:
        raise ValueError('n_onehot must be provided to parse the onehot feature')
    if onehot_idx is None:
        onehot_idx = np.arange(n_onehot)

    features = {'volume': tf.io.FixedLenFeature([nx*ny*nz], tf.float32),
                'task_id': tf.io.FixedLenFeature([1], tf.int64),
                'subject_id': tf.io.FixedLenFeature([1], tf.int64),
                'run_id': tf.io.FixedLenFeature([1], tf.int64),
                'tr': tf.io.FixedLenFeature([1], tf.float32),
                'state': tf.io.FixedLenFeature([1], tf.int64),
                'onehot': tf.io.FixedLenFeature([n_onehot], tf.int64)}
    parsed_features = tf.io.parse_single_example(example_proto, features)

    # Reshape once, adding the channel axis directly when requested. The
    # feature is already declared float32, so no cast is needed.
    shape = [nx, ny, nz, 1] if add_channel_dim else [nx, ny, nz]
    volume = tf.reshape(parsed_features["volume"], shape)
    if transpose_xyz:
        # Reverse only the spatial axes; a channel axis (if any) stays last.
        perm = [2, 1, 0, 3] if add_channel_dim else [2, 1, 0]
        volume = tf.transpose(volume, perm=perm)

    # Sanitise non-finite voxels. NOTE(review): is_inf is true for both +inf
    # and -inf, so -inf also becomes +1e4 -- confirm this is intended.
    volume = tf.where(tf.math.is_nan(volume), tf.zeros_like(volume), volume)
    volume = tf.where(tf.math.is_inf(volume), tf.ones_like(volume)*1e4, volume)

    onehot = tf.cast(tf.gather(parsed_features["onehot"], onehot_idx), tf.int64)
    if only_parse_XY:
        return (volume, onehot)
    return {"volume": volume,
            "onehot": onehot,
            "task_id": parsed_features["task_id"],
            "subject_id": parsed_features["subject_id"],
            "run_id": parsed_features["run_id"],
            "tr": parsed_features["tr"],
            "state": parsed_features["state"]}

def make_dataset(
    files,
    n_onehot,
    batch_size,
    nx=91, ny=109, nz=91,
    onehot_idx=None,
    repeat=True,
    shuffle=True,
    only_parse_XY=False,
    n_workers=4,
    shuffle_buffer_size=500,
    scope_name='train',
    transpose_xyz=False,
    add_channel_dim=False):
    """Build a batched tf.data pipeline from functional TFRecord files.

    Args:
        files: List of TFRecord file paths.
        n_onehot: Length of the stored state one-hot vector.
        batch_size: Number of examples per batch.
        nx, ny, nz: Spatial dimensions of the stored fMRI volumes.
        onehot_idx: Indices of the one-hot vector to keep; defaults to all.
        repeat: If True, repeat the batched dataset indefinitely.
        shuffle: If True, shuffle individual examples before batching.
        only_parse_XY: Passed to parse_func_tfr; if True the dataset yields
            (volume, onehot) tuples suitable for keras' fit().
        n_workers: Number of parallel calls used when parsing examples.
        shuffle_buffer_size: Size of the example shuffle buffer.
        scope_name: Unused; kept for backward compatibility.
        transpose_xyz: Passed to parse_func_tfr.
        add_channel_dim: Passed to parse_func_tfr.

    Returns:
        A tf.data.Dataset of batched, parsed examples.
    """
    if onehot_idx is None:
        onehot_idx = np.arange(n_onehot)
    dataset = tf.data.TFRecordDataset(files)
    dataset = dataset.map(
        lambda x: parse_func_tfr(
            x,
            # nz must be forwarded as nz (not nx) so non-cubic volumes parse correctly.
            nx=nx, ny=ny, nz=nz,
            n_onehot=n_onehot,
            onehot_idx=onehot_idx,
            only_parse_XY=only_parse_XY,
            transpose_xyz=transpose_xyz,
            add_channel_dim=add_channel_dim),
        num_parallel_calls=n_workers)
    if shuffle:
        dataset = dataset.shuffle(buffer_size=shuffle_buffer_size)
    dataset = dataset.batch(batch_size)
    if repeat:
        dataset = dataset.repeat()
    # Prefetch as the final stage so whole batches are prepared while the
    # model consumes the previous one.
    return dataset.prefetch(tf.data.experimental.AUTOTUNE)

In [None]:
# Data pipeline settings.
n_onehot = 20             # length of the stored state one-hot vector
n_workers = 4             # parallel parse calls in make_dataset
batch_size = 16
shuffle_buffer_size = 50
# Stored volume shape (nx, ny, nz). NOTE: with transpose_xyz=True the parsed
# volume is (nz, ny, nx); the model's Input uses input_shape as-is, which only
# agrees because nx == nz here.
input_shape = (91, 109, 91)
# One-hot entries used as training targets (presumably the task-WM states --
# TODO confirm the index-to-state mapping).
onehot_idx = np.array([16, 17, 18, 19])

# Model / training settings.
blocks = 5  # NOTE(review): not referenced by the visible cells
epochs = 50
# Derived from the selected one-hot entries so the model head and the labels
# cannot drift apart.
n_classes = len(onehot_idx)
training_steps = 100
validation_steps = 100

learning_rate = 0.0001

In [None]:
# Training pipeline: shuffled, infinitely repeating (volume, onehot) batches
# whose volumes carry a trailing channel axis for the Conv3D input.
train_dataset = make_dataset(
    train_files,
    n_onehot,
    batch_size,
    nx=input_shape[0], ny=input_shape[1], nz=input_shape[2],
    onehot_idx=onehot_idx,
    repeat=True,
    shuffle=True,
    only_parse_XY=True,
    n_workers=n_workers,
    shuffle_buffer_size=shuffle_buffer_size,
    transpose_xyz=True,
    add_channel_dim=True,
)

In [None]:
# Validation pipeline, configured like the training one but over the held-out
# subjects. NOTE(review): shuffling is not needed for validation metrics;
# consider shuffle=False (and a finite pass) if full coverage of the validation
# subjects per epoch matters.
validation_dataset = make_dataset(
    validation_files,
    n_onehot,
    batch_size,
    nx=input_shape[0], ny=input_shape[1], nz=input_shape[2],
    onehot_idx=onehot_idx,
    repeat=True,
    shuffle=True,
    only_parse_XY=True,
    n_workers=n_workers,
    shuffle_buffer_size=shuffle_buffer_size,
    transpose_xyz=True,
    add_channel_dim=True,
)

In [None]:
def conv_block(x, filters, kernel_size, strides, dropout_rate):
    """Apply Conv3D (ReLU) -> SpatialDropout3D -> BatchNormalization to `x`.

    This order matches the pretrained model's summary (conv3d,
    spatial_dropout3d, batch_normalization), which the positional weight
    transfer relies on.

    Args:
        x: Input Keras tensor (channels-last).
        filters: Number of convolution filters.
        kernel_size: Convolution kernel size.
        strides: Convolution strides; 'same' padding is used.
        dropout_rate: Rate passed to SpatialDropout3D.

    Returns:
        The block's output Keras tensor.
    """
    stages = (
        Conv3D(
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding='same',
            data_format='channels_last',
            activation='relu',
        ),
        SpatialDropout3D(rate=dropout_rate),
        BatchNormalization(),
    )
    for stage in stages:
        x = stage(x)
    return x

In [None]:
# Create the model: a stack of conv blocks (kernel_size=3, dropout_rate=0.2),
# then a 1x1x1 convolution to n_classes maps, global average pooling and a
# softmax head. Each (filters, strides) pair is one conv_block. Layers are
# created in this exact order, so Keras layer names and model.layers match the
# unrolled definition.
# NOTE(review): the weight-transfer output reports mismatches starting at the
# 10th Conv3D, so this stack seems to diverge from the pretrained model from
# there on -- confirm the intended architecture.
conv_block_specs = [
    (2**3, 1),
    (2**3, 1),
    (2**3, 2),
    (2**3, 1),
    (2**4, 2),
    (2**4, 1),
    (2**5, 2),
    (2**5, 1),
    (2**6, 1),
    (2**7, 2),
    (2**7, 1),
]

inputs = Input(shape=[*input_shape, 1])

x = inputs
for block_filters, block_strides in conv_block_specs:
    x = conv_block(
        x,
        filters=block_filters,
        kernel_size=3,
        strides=block_strides,
        dropout_rate=0.2,
    )

x = Conv3D(
    filters=n_classes,
    kernel_size=1,
    strides=1,
    padding='same',
    data_format='channels_last',
    activation=None,
)(x)
x = SpatialDropout3D(rate=0.2)(x)
x = GlobalAveragePooling3D()(x)
x = Dense(n_classes, activation='softmax')(x)

model = Model(inputs, x)

In [None]:
# Print the layer-by-layer architecture of the freshly built model.
model.summary()

Model: "model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_3 (InputLayer)        [(None, 91, 109, 91, 1)]  0         
                                                                 
 conv3d_12 (Conv3D)          (None, 91, 109, 91, 8)    224       
                                                                 
 spatial_dropout3d_12 (Spati  (None, 91, 109, 91, 8)   0         
 alDropout3D)                                                    
                                                                 
 batch_normalization_11 (Bat  (None, 91, 109, 91, 8)   32        
 chNormalization)                                                
                                                                 
 conv3d_13 (Conv3D)          (None, 91, 109, 91, 8)    1736      
                                                                 
 spatial_dropout3d_13 (Spati  (None, 91, 109, 91, 8)   0   

In [None]:
# Load the pretrained model only as a source of weights. compile=False skips
# restoring its optimizer/loss configuration, which is not needed here.
pretrained_model_path = '/content/drive/MyDrive/MAID/CV/data/model-3D_DeepLight_desc-pretrained_model.hdf5'
stored_model = tf.keras.models.load_model(pretrained_model_path, compile=False)

In [None]:
# Print the pretrained model's architecture, for comparison with `model` before copying weights.
stored_model.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv3d (Conv3D)             (None, 91, 109, 91, 8)    224       
                                                                 
 spatial_dropout3d (SpatialD  (None, 91, 109, 91, 8)   0         
 ropout3D)                                                       
                                                                 
 batch_normalization (BatchN  (None, 91, 109, 91, 8)   32        
 ormalization)                                                   
                                                                 
 conv3d_1 (Conv3D)           (None, 91, 109, 91, 8)    1736      
                                                                 
 spatial_dropout3d_1 (Spatia  (None, 91, 109, 91, 8)   0         
 lDropout3D)                                                     
                                                        

In [None]:
# Copy pretrained weights layer by layer. model.layers[0] is the InputLayer,
# which the Sequential stored_model does not list, hence the [1:] offset.
# Pairing is positional, so it is only meaningful while both models stack the
# same layer types in the same order.
n_model_layers = len(model.layers) - 1
if n_model_layers != len(stored_model.layers):
    # zip() stops at the shorter list, so surplus layers would be skipped silently.
    print(
        'layer count differs:', n_model_layers, 'in model vs',
        len(stored_model.layers), 'in stored_model'
    )

for model_layer, stored_model_layer in zip(model.layers[1:], stored_model.layers):
    try:
        model_layer.set_weights(stored_model_layer.get_weights())
    except ValueError:
        # Keras raises ValueError when the number or shapes of the weight
        # arrays differ; any other exception is a real error and propagates.
        print(
            "shapes don't match for", model_layer.name,
            '<-', stored_model_layer.name,
            [w.shape for w in model_layer.get_weights()],
            'vs',
            [w.shape for w in stored_model_layer.get_weights()],
        )

shapes dont match for conv3d_21
shapes dont match for batch_normalization_20
shapes dont match for conv3d_22
shapes dont match for conv3d_23
shapes dont match for global_average_pooling3d_1
shapes dont match for dense_1


In [None]:
# Labels are one-hot vectors (parse_func_tfr's `onehot`), hence categorical_crossentropy.
model.compile(
    loss='categorical_crossentropy',
    optimizer=Adam(learning_rate=learning_rate),
    metrics=['accuracy'],
)

In [None]:
# Train on the repeating tf.data pipelines. Both datasets repeat indefinitely,
# so steps_per_epoch / validation_steps bound each epoch. Input parallelism is
# configured inside make_dataset (parallel map + prefetch); fit()'s
# `workers`/`use_multiprocessing` only apply to generator/Sequence inputs, so
# they are not passed.
history = model.fit(
    train_dataset,
    epochs=epochs,
    steps_per_epoch=training_steps,
    validation_data=validation_dataset,
    validation_steps=validation_steps,
    verbose=1,
)

Epoch 1/40
Epoch 2/40
Epoch 3/40
Epoch 4/40
Epoch 5/40
Epoch 6/40
Epoch 7/40
Epoch 8/40
Epoch 9/40
Epoch 10/40