<a href="https://colab.research.google.com/github/faizankshaikh/evaluating-deeplight-transfer/blob/master/experiments/train_deeplight3D_TF2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 1. Prepare notebook

In [1]:
#@title 1.1 Import required libs and modules

!pip install -q wandb

[K     |████████████████████████████████| 1.8 MB 5.0 MB/s 
[K     |████████████████████████████████| 145 kB 50.1 MB/s 
[K     |████████████████████████████████| 181 kB 48.8 MB/s 
[K     |████████████████████████████████| 63 kB 1.2 MB/s 
[?25h  Building wheel for pathtools (setup.py) ... [?25l[?25hdone


In [2]:
#@title 1.1 Import required libs and modules (contd)

import os
import wandb
import numpy as np
from glob import glob
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import (
    BatchNormalization,
    Conv3D,
    Dense,
    GlobalAveragePooling3D,
    Input,
    SpatialDropout3D,
)
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping

In [22]:
#@title 1.1 Log in to Weights & Biases and start a run

# Authenticate this session with Weights & Biases.
wandb.login()
# Start a run in the "DeepLight" project; the WandbCallback created in the
# training section reports to this run.
wandb.init(project="DeepLight")



VBox(children=(Label(value='10.429 MB of 10.429 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, m…

0,1
accuracy,▁▁▁▁▂▃▃▄▄▄▄▄▅▅▅▆▅▆▆▆▆▆▇▇▇▇█▇▇████
epoch,▁▁▁▂▂▂▂▃▃▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇███
loss,█▇▇▇▆▆▆▅▅▅▅▄▄▄▄▃▄▃▂▃▃▂▂▂▂▂▁▁▁▁▁▁▁
val_accuracy,▁▃▄▅▅▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇███████████
val_loss,█▇▇▇▆▆▆▅▅▅▅▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁

0,1
accuracy,0.55757
best_epoch,31.0
best_val_loss,0.8875
epoch,32.0
loss,0.98971
val_accuracy,0.63003
val_loss,0.88901


In [4]:
#@title 1.2 Connect google drive

from google.colab import drive

# Mount Google Drive at /content/drive; the data archive and the pretrained
# model used by later cells are read from paths under this mount point.
drive.mount("/content/drive")

Mounted at /content/drive


# 2. Setup dataset for training

In [5]:
#@title 2.1 Download preprocessed data

!unzip /content/drive/MyDrive/MAID/CV/data/data.zip

Archive:  /content/drive/MyDrive/MAID/CV/data/data.zip
  inflating: data/sub-100307/func/sub-100307_task-WM_run-LR_space-MNI152NLin6Asym_res-2_desc-tfr.tfrecords  
  inflating: data/sub-100408/func/sub-100408_task-WM_run-LR_space-MNI152NLin6Asym_res-2_desc-tfr.tfrecords  
  inflating: data/sub-101006/func/sub-101006_task-WM_run-LR_space-MNI152NLin6Asym_res-2_desc-tfr.tfrecords  
  inflating: data/sub-101107/func/sub-101107_task-WM_run-LR_space-MNI152NLin6Asym_res-2_desc-tfr.tfrecords  
  inflating: data/sub-101309/func/sub-101309_task-WM_run-LR_space-MNI152NLin6Asym_res-2_desc-tfr.tfrecords  
  inflating: data/sub-101410/func/sub-101410_task-WM_run-LR_space-MNI152NLin6Asym_res-2_desc-tfr.tfrecords  
  inflating: data/sub-101915/func/sub-101915_task-WM_run-LR_space-MNI152NLin6Asym_res-2_desc-tfr.tfrecords  
  inflating: data/sub-102008/func/sub-102008_task-WM_run-LR_space-MNI152NLin6Asym_res-2_desc-tfr.tfrecords  
  inflating: data/sub-102311/func/sub-102311_task-WM_run-LR_space-MNI152N

In [6]:
#@title 2.2 Define data path

# Directory (relative to the working directory) that holds one "sub-<id>"
# folder per subject, as extracted from the archive above.
data_path = "data/"

In [7]:
#@title 2.3 Create training and validation dataset

# Seed for the train/validation split, so that the same subjects end up in
# each set on every Restart & Run All.
split_seed = 42

# Subject IDs are taken from the "sub-<id>" directory names in data_path.
subjects = np.sort(
    np.unique(
        [int(p.split("sub-")[1]) for p in os.listdir(data_path) if p.startswith("sub-")]
    )
)

# Draw three quarters of the subjects (rounded down) for training,
# without replacement, from a dedicated seeded generator.
split_rng = np.random.default_rng(split_seed)
subjects_training = split_rng.choice(
    subjects, (subjects.size * 3) // 4, replace=False
)

# Every remaining subject is used for validation (returned sorted).
subjects_validation = np.setdiff1d(subjects, subjects_training)

print("Training subjects: ", list(subjects_training))

print("Validation subjects: ", list(subjects_validation))

Training subjects:  [101410, 102008, 101915, 105115, 106016, 103111, 100307, 103414, 100408, 101006, 101309, 105216, 101107, 102816, 106319]
Validation subjects:  [102311, 103515, 103818, 104820, 105014]


In [8]:
#@title 2.3 Create training and validation dataset (contd)

def _subject_tfrecords(subject_ids):
    """Return the TFRecord paths of the given subjects.

    Files are looked up under ``<data_path>/sub-<id>/func/`` and are sorted
    within each subject so the resulting order does not depend on the
    file system's listing order.
    """
    paths = []
    for subject in subject_ids:
        pattern = os.path.join(data_path, "sub-" + str(subject), "func", "*.tfrecords")
        paths.extend(sorted(glob(pattern)))
    return paths


train_files = _subject_tfrecords(subjects_training)

validation_files = _subject_tfrecords(subjects_validation)

In [9]:
#@title 2.3 Create training and validation dataset (contd)

# functions to read preprocessed data 
# taken from https://github.com/athms/evaluating-deeplight-transfer


def parse_func_tfr(
    example_proto,
    nx,
    ny,
    nz,
    n_onehot=None,
    onehot_idx=None,
    only_parse_XY=False,
    transpose_xyz=False,
    add_channel_dim=False,
):
    """Parse a single serialized example from a TFR-file.

    Args:
        example_proto: Single example from TFR-file
        nx, ny, nz: Integers indicating the x-/y-/z-dimensions
            of the fMRI data stored in the TFR-files
        n_onehot: Total number of states across tasks, i.e. the length
            of the stored "onehot" feature. Must be provided even though
            the signature defaults it to None.
        onehot_idx: idx that is returned from state-onehot;
            e.g., if state-onehot encoding has 20 values in total,
            but we only want to train with values 5-10,
            onehot_idx can be set to np.arange(4,10).
            Defaults to all n_onehot entries.
        only_parse_XY: Bool indicating whether only volume
            and y onehot encoding should be returned,
            as needed for integration with keras.
        transpose_xyz: Bool; if True the three spatial axes of the
            volume are reversed (perm [2, 1, 0]); a channel axis,
            if present, stays last.
        add_channel_dim: Bool; if True the volume is reshaped to
            [nx, ny, nz, 1] instead of [nx, ny, nz].
    Returns:
        If only_parse_XY is True, the tuple (volume, onehot).
        Otherwise a dict with the keys:
        volume: fMRI volume activations (NaN replaced by 0,
            infinite values replaced by 1e4)
        onehot: One-hot encoding of the states, restricted to onehot_idx
        task_id: Integer ID of the HCP task
        subject_id: Integer ID of the subject
        run_id: Integer ID of the run
        tr: TR of fmri volume (float)
        state: Integer cognitive state of the volume
    Raises:
        ValueError: If n_onehot is None.
    """
    if n_onehot is None:
        # Fail early with a clear message instead of an obscure error from
        # np.arange(None) / FixedLenFeature([None]) further down.
        raise ValueError("n_onehot must be specified to parse the 'onehot' feature")
    if onehot_idx is None:
        onehot_idx = np.arange(n_onehot)

    features = {
        "volume": tf.io.FixedLenFeature([nx * ny * nz], tf.float32),
        "task_id": tf.io.FixedLenFeature([1], tf.int64),
        "subject_id": tf.io.FixedLenFeature([1], tf.int64),
        "run_id": tf.io.FixedLenFeature([1], tf.int64),
        "tr": tf.io.FixedLenFeature([1], tf.float32),
        "state": tf.io.FixedLenFeature([1], tf.int64),
        "onehot": tf.io.FixedLenFeature([n_onehot], tf.int64),
    }
    parsed_features = tf.io.parse_single_example(example_proto, features)

    # Reshape the flat volume once, directly into its final rank.
    # The feature is declared as float32 above, so no cast is needed.
    if add_channel_dim:
        volume_shape = [nx, ny, nz, 1]
        perm = [2, 1, 0, 3]
    else:
        volume_shape = [nx, ny, nz]
        perm = [2, 1, 0]
    volume = tf.reshape(parsed_features["volume"], volume_shape)
    if transpose_xyz:
        volume = tf.transpose(volume, perm=perm)

    # Sanitize non-finite voxels: NaN -> 0, +/-inf -> 1e4.
    volume = tf.where(tf.math.is_nan(volume), tf.zeros_like(volume), volume)
    volume = tf.where(tf.math.is_inf(volume), tf.ones_like(volume) * 1e4, volume)

    # Keep only the requested entries of the stored one-hot vector.
    onehot = tf.cast(tf.gather(parsed_features["onehot"], onehot_idx), tf.int64)

    if only_parse_XY:
        return (volume, onehot)
    return {
        "volume": volume,
        "onehot": onehot,
        "task_id": parsed_features["task_id"],
        "subject_id": parsed_features["subject_id"],
        "run_id": parsed_features["run_id"],
        "tr": parsed_features["tr"],
        "state": parsed_features["state"],
    }


def make_dataset(
    files,
    n_onehot,
    batch_size,
    nx=91,
    ny=109,
    nz=91,
    onehot_idx=None,
    repeat=True,
    shuffle=True,
    only_parse_XY=False,
    n_workers=4,
    shuffle_buffer_size=500,
    scope_name="train",
    transpose_xyz=False,
    add_channel_dim=False,
):
    """Make iteratable dataset from TFR files.

    Args:
        files: Paths of the GZIP-compressed TFR-files to read.
        n_onehot: Length of the one-hot state encoding stored in the files.
        batch_size: Number of examples per batch.
        nx, ny, nz: x-/y-/z-dimensions of the stored fMRI volumes.
        onehot_idx: Indices of the one-hot encoding to keep;
            defaults to all n_onehot entries.
        repeat: If True, the batched dataset is repeated indefinitely.
        shuffle: If True, examples are shuffled before batching.
        only_parse_XY: Forwarded to parse_func_tfr; if True each element
            is a (volume, onehot) tuple instead of a dict.
        n_workers: Number of parallel calls used when parsing examples.
        shuffle_buffer_size: Buffer size used for shuffling.
        scope_name: Not used by this function; kept for interface
            compatibility.
        transpose_xyz: Forwarded to parse_func_tfr.
        add_channel_dim: Forwarded to parse_func_tfr.
    Returns:
        A tf.data.Dataset yielding batches of parsed examples.
    """
    if onehot_idx is None:
        onehot_idx = np.arange(n_onehot)
    dataset = tf.data.TFRecordDataset(files, compression_type='GZIP')
    dataset = dataset.map(
        lambda x: parse_func_tfr(
            x,
            nx=nx,
            ny=ny,
            # Fixed: this was `nz=nx`, which silently ignored the nz argument
            # and only worked because nx == nz for the default volume shape.
            nz=nz,
            n_onehot=n_onehot,
            onehot_idx=onehot_idx,
            only_parse_XY=only_parse_XY,
            transpose_xyz=transpose_xyz,
            add_channel_dim=add_channel_dim,
        ),
        num_parallel_calls=n_workers,
    )
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    if shuffle:
        dataset = dataset.shuffle(buffer_size=shuffle_buffer_size)
    dataset = dataset.batch(batch_size)
    if repeat:
        dataset = dataset.repeat()
    return dataset

In [10]:
#@title 2.3 Create training and validation dataset (contd)

# define parameters
# --- shape of the stored data ---
# (nx, ny, nz) dimensions of a single fMRI volume.
input_shape = (91, 109, 91)
# Length of the one-hot state vector stored with every example.
n_onehot = 20
# Entries of that one-hot vector that are kept as training targets.
onehot_idx = np.arange(16, 20)

# --- input pipeline ---
# Examples per batch.
batch_size = 20
# Size of the shuffle buffer.
shuffle_buffer_size = 50
# Degree of parallelism used when parsing examples.
n_workers = 4

# --- epoch length (the datasets repeat, so steps must be given) ---
training_steps = 100
validation_steps = 100

In [11]:
#@title 2.3 Create training and validation dataset (contd)

# Both datasets are built with exactly the same settings; only the list of
# input files differs, so the shared arguments are collected once.
dataset_kwargs = dict(
    n_onehot=n_onehot,
    onehot_idx=onehot_idx,
    batch_size=batch_size,
    nx=input_shape[0],
    ny=input_shape[1],
    nz=input_shape[2],
    only_parse_XY=True,  # yield (volume, onehot) tuples for keras
    transpose_xyz=True,
    add_channel_dim=True,
    shuffle=True,
    shuffle_buffer_size=shuffle_buffer_size,
    repeat=True,
    n_workers=n_workers,
)

train_dataset = make_dataset(files=train_files, **dataset_kwargs)

validation_dataset = make_dataset(files=validation_files, **dataset_kwargs)

# 3. Train DL model

In [23]:
#@title 3.1 Define hyperparameters

# Number of output classes of the network.
n_classes = 4
# Step size handed to the Adam optimizer.
learning_rate = 1e-3
# Maximum number of training epochs (early stopping may end training sooner).
epochs = 50

In [24]:
#@title 3.2 Create model

def conv_block(x, filters, kernel_size, strides, dropout_rate):
    """Apply a Conv3D -> SpatialDropout3D -> BatchNormalization stack to `x`.

    Args:
        x: Input tensor, channels last.
        filters: Number of convolution filters.
        kernel_size: Convolution kernel size.
        strides: Convolution strides.
        dropout_rate: Rate of the spatial dropout applied after the convolution.
    Returns:
        The output tensor of the batch-normalization layer.
    """
    convolution = Conv3D(
        filters=filters,
        kernel_size=kernel_size,
        strides=strides,
        padding="same",
        data_format="channels_last",
        activation="relu",
    )
    features = convolution(x)
    features = SpatialDropout3D(rate=dropout_rate)(features)
    return BatchNormalization()(features)

In [25]:
#@title 3.2 Create model (contd)

# (filters, strides) of the consecutive convolution blocks; every block uses
# a kernel size of 3 and a dropout rate of 0.2.
conv_plan = [
    (2 ** 3, 1),
    (2 ** 3, 1),
    (2 ** 3, 2),
    (2 ** 3, 1),
    (2 ** 4, 2),
    (2 ** 4, 1),
    (2 ** 5, 2),
    (2 ** 5, 1),
    (2 ** 6, 2),
    (2 ** 6, 1),
    (2 ** 7, 2),
    (2 ** 7, 1),
]

# Single-channel volume input.
inputs = Input(shape=[*input_shape, 1])

x = inputs
for n_filters, stride in conv_plan:
    x = conv_block(
        x, filters=n_filters, kernel_size=3, strides=stride, dropout_rate=0.2
    )

# Classification head: 1x1x1 convolution down to n_classes channels,
# dropout, global average pooling and a softmax dense layer.
x = Conv3D(
    filters=n_classes,
    kernel_size=1,
    strides=1,
    padding="same",
    data_format="channels_last",
    activation=None,
)(x)
x = SpatialDropout3D(rate=0.2)(x)
x = GlobalAveragePooling3D()(x)
x = Dense(n_classes, activation="softmax")(x)

model = Model(inputs, x)

In [26]:
#@title 3.2 Create model (contd)

# Print the layers, output shapes and parameter counts of the new model.
model.summary()

Model: "model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_2 (InputLayer)        [(None, 91, 109, 91, 1)]  0         
                                                                 
 conv3d_13 (Conv3D)          (None, 91, 109, 91, 8)    224       
                                                                 
 spatial_dropout3d_13 (Spati  (None, 91, 109, 91, 8)   0         
 alDropout3D)                                                    
                                                                 
 batch_normalization_12 (Bat  (None, 91, 109, 91, 8)   32        
 chNormalization)                                                
                                                                 
 conv3d_14 (Conv3D)          (None, 91, 109, 91, 8)    1736      
                                                                 
 spatial_dropout3d_14 (Spati  (None, 91, 109, 91, 8)   0   

In [27]:
#@title 3.3 Load pretrained model

# Location of the pretrained 3D DeepLight model on the mounted Google Drive.
pretrained_model_path = (
    "/content/drive/MyDrive/MAID/CV/data/model-3D_DeepLight_desc-pretrained_model.hdf5"
)

pretrained_model = tf.keras.models.load_model(pretrained_model_path)

In [28]:
#@title 3.3 Load pretrained model (contd)

# Print the pretrained model's layers for comparison with the model built above.
pretrained_model.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv3d (Conv3D)             (None, 91, 109, 91, 8)    224       
                                                                 
 spatial_dropout3d (SpatialD  (None, 91, 109, 91, 8)   0         
 ropout3D)                                                       
                                                                 
 batch_normalization (BatchN  (None, 91, 109, 91, 8)   32        
 ormalization)                                                   
                                                                 
 conv3d_1 (Conv3D)           (None, 91, 109, 91, 8)    1736      
                                                                 
 spatial_dropout3d_1 (Spatia  (None, 91, 109, 91, 8)   0         
 lDropout3D)                                                     
                                                        

In [29]:
#@title 3.4 Transfer weights from pretrained

# Copy the pretrained weights layer by layer. model.layers[0] is the input
# layer, which has no counterpart in the pretrained model's layer list, so
# the pairing starts at model.layers[1].
for model_layer, pretrained_model_layer in zip(model.layers[1:], pretrained_model.layers):
    try:
        model_layer.set_weights(pretrained_model_layer.get_weights())
    except ValueError:
        # set_weights raises ValueError when the number or shapes of the
        # weights differ; those layers keep their fresh initialization.
        # Any other error is unexpected and is allowed to propagate.
        print("shapes dont match for", model_layer.name)

shapes dont match for conv3d_25
shapes dont match for dense_1


In [30]:
#@title 3.4 Train model

# Adam optimizer with the configured learning rate; categorical
# cross-entropy matches the one-hot targets produced by the datasets.
model.compile(
    optimizer=Adam(learning_rate=learning_rate),
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)

In [31]:
#@title 3.4 Train model (contd)

# Stop once val_loss has not improved for 10 epochs and roll the model back
# to the weights of its best epoch.
estop = EarlyStopping(
    monitor="val_loss",
    patience=10,
    verbose=0,
    restore_best_weights=True,
)
# Report training metrics to Weights & Biases, tracking val_loss.
wb = wandb.keras.WandbCallback(monitor="val_loss")

In [32]:
#@title 3.4 Train model (contd)

# Both datasets repeat indefinitely, so the number of batches per epoch and
# per validation pass has to be given explicitly.
# `use_multiprocessing` / `workers` were dropped: they only apply to
# generator or keras.utils.Sequence inputs, have no effect on tf.data
# datasets (whose parallelism is configured in make_dataset), and are no
# longer accepted by newer Keras versions.
history = model.fit(
    train_dataset,
    epochs=epochs,
    steps_per_epoch=training_steps,
    validation_data=validation_dataset,
    validation_steps=validation_steps,
    verbose=1,
    callbacks=[estop, wb],
)

Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25/50
Epoch 26/50
Epoch 27/50
Epoch 28/50
Epoch 29/50
Epoch 30/50
Epoch 31/50
Epoch 32/50
Epoch 33/50
Epoch 34/50
Epoch 35/50
Epoch 36/50
Epoch 37/50
Epoch 38/50
Epoch 39/50
Epoch 40/50
Epoch 41/50
Epoch 42/50
Epoch 43/50
Epoch 44/50
Epoch 45/50
Epoch 46/50
Epoch 47/50
Epoch 48/50
Epoch 49/50
Epoch 50/50
