# Train a neural network for MRI representation learning using a progressive autoencoder

In this notebook, we use Nobrainer to train a model for brain MRI representation learning. Representation learning with autoencoders is a useful unsupervised task, both for compressing medical data and as a starting point for downstream supervised tasks. In particular, Nobrainer can train the encoder network on its own (with the decoder held fixed), which projects the data onto a predefined manifold. This can be used to evaluate the performance of GAN models.

In the following cells we will:
1. Get sample T1-weighted MR scans as features
2. Convert the data to TFRecords format
3. Instantiate a progressive convolutional neural network for encoder and decoder
4. Create a Dataset of the features
5. Instantiate a trainer and choose a loss function to use
6. Define whether the decoder network is fixed
7. When working with a fixed decoder, download pre-trained decoders
8. Train on part of the data in two phases (transition and resolution)
9. Repeat steps 4-8 for each growing resolution
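
The two phases in step 8 follow the usual progressive-growing recipe: in the transition phase a newly added, higher-resolution block is faded in, and in the resolution phase the grown network trains with that block fully active. The sketch below illustrates the idea with a linear blend weight. The names `fade_in_alpha` and `blend` are hypothetical, and the linear schedule is an assumption; Nobrainer's internal schedule may differ.

```python
def fade_in_alpha(step, transition_steps):
    """Blend weight for the new block: 0.0 at the start, 1.0 once the transition ends."""
    if transition_steps <= 0:
        return 1.0
    return min(1.0, max(0.0, step / transition_steps))


def blend(old_output, new_output, alpha):
    """Mix the previous-resolution path with the newly added block."""
    return (1.0 - alpha) * old_output + alpha * new_output


transition_steps = 4  # hypothetical value, chosen only for illustration

# Transition phase: alpha ramps from 0 to 1 as training steps accumulate.
transition_alphas = [
    fade_in_alpha(step, transition_steps) for step in range(transition_steps + 1)
]

# Resolution phase: the new block is fully active, so alpha stays at 1.
resolution_alpha = fade_in_alpha(transition_steps, transition_steps)

print(transition_alphas)
print(resolution_alpha)
```

At `alpha = 0` the output comes entirely from the previous-resolution path, which is why growing the network does not disrupt what was learned at the lower resolution.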


# Google Colaboratory

If you are using Colab, please switch your runtime to GPU. To do this, select `Runtime > Change runtime type` in the top menu. Then select GPU under `Hardware accelerator`. A GPU greatly speeds up training.
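
To confirm that the runtime switch took effect, one option is to ask TensorFlow which GPUs it can see. This is a small sketch rather than part of the original workflow. The helper name `visible_gpus` is hypothetical, and the import guard is there only so the snippet also runs where TensorFlow is not installed.

```python
def visible_gpus():
    """Return the names of the GPUs TensorFlow can see (empty list if none)."""
    try:
        import tensorflow as tf
    except ImportError:
        # TensorFlow is not installed in this environment.
        return []
    return [device.name for device in tf.config.list_physical_devices("GPU")]


gpus = visible_gpus()
print("GPUs visible to TensorFlow:", gpus if gpus else "none")
```

An empty result on Colab usually means the runtime type is still set to CPU.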

In [None]:
!pip install --no-cache-dir nilearn nobrainer

# Imports

In [None]:
import glob
import nobrainer
import tensorflow as tf

# Get sample features and labels

We use 9 pairs of volumes for training and 1 pair of volumes for evaluation. Many more volumes would be required to train a model for any useful purpose.
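
The split described above is plain list slicing. The sketch below uses made-up file names in place of the CSV contents, and `eval_paths` is a hypothetical name, since this notebook only builds `train_paths`.

```python
# Made-up (features, labels) path pairs standing in for the real CSV contents.
filepaths = [
    (f"sub-{i:02d}_t1.nii.gz", f"sub-{i:02d}_labels.nii.gz") for i in range(10)
]

train_paths = filepaths[:9]  # first 9 pairs for training
eval_paths = filepaths[9:]   # remaining pair held out for evaluation

print(len(train_paths), len(eval_paths))
```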

In [None]:
csv_of_filepaths = nobrainer.utils.get_data()
filepaths = nobrainer.io.read_csv(csv_of_filepaths)

train_paths = filepaths[:9]


# Convert medical images to TFRecords

Remember how many full volumes are in the TFRecords files. You need this number to work out how many steps make up one training epoch. The default training method needs it because Datasets don't always know how many items they contain.
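
As a rough illustration of that arithmetic, the number of steps in one epoch is the number of volumes divided by the batch size, rounded up. The helper name `steps_for_one_epoch` is hypothetical and is not a Nobrainer function.

```python
import math


def steps_for_one_epoch(n_volumes, batch_size):
    """Number of batches needed to see every volume once."""
    if n_volumes < 0 or batch_size <= 0:
        raise ValueError("n_volumes must be >= 0 and batch_size must be > 0")
    return math.ceil(n_volumes / batch_size)


n_train_volumes = 9  # the 9 training volumes used in this notebook
print(steps_for_one_epoch(n_train_volumes, batch_size=1))
print(steps_for_one_epoch(n_train_volumes, batch_size=4))
```

Note that the training cell further down derives `steps_per_epoch` from `iterations` instead, so the value here matters mainly when you want one epoch to correspond to one pass over the data.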



In [None]:
!mkdir -p data

In [None]:
resolution_batch_size_map = {8: 1, 16: 1, 32: 1, 64: 1, 128: 1, 256: 1} 
resolutions = sorted(list(resolution_batch_size_map.keys()))

In [None]:
nobrainer.tfrecord.write(
    features_labels=train_paths,
    filename_template='data/data-train_shard-{shard:03d}.tfrec',
    examples_per_shard=3, # change for larger dataset
    multi_resolution=True,
    resolutions=resolutions)

# Set Hyperparameters

In [None]:
latent_size = 1024
e_fmap_base = 2048
d_fmap_base = 2048
# latent_size = 1024  # uncomment when sufficient compute is available
# e_fmap_base = 4096  # uncomment when sufficient compute is available
# d_fmap_base = 4096  # uncomment when sufficient compute is available
num_parallel_calls = 4
iterations = 10
# iterations = int(300e3)  # uncomment when sufficient compute is available
lr = 1e-4

# Creating Logging Directories

In [None]:
from pathlib import Path

save_dir = 'pae'

save_dir = Path(save_dir)
generated_dir = save_dir.joinpath('generated')
model_dir = save_dir.joinpath('saved_models')
log_dir = save_dir.joinpath('logs')

save_dir.mkdir(exist_ok=True)
generated_dir.mkdir(exist_ok=True)
model_dir.mkdir(exist_ok=True)
log_dir.mkdir(exist_ok=True)

# Instantiate a neural network 

In [None]:
encoder, decoder = nobrainer.models.progressiveae(latent_size, e_fmap_base=e_fmap_base, d_fmap_base=d_fmap_base)

# Set pretrained decoder neural network paths

In [None]:
fixed = False

if fixed:
    path = './mypaths/saved_models/'                     # if fixed=True, specify the folder in which *.h5 files are stored
    model_paths = iter(sorted(glob.glob(path+'/*.h5')))

# Training an autoencoder progressively for each resolution
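
The training loop selects the TFRecords for each resolution with a `%03d` glob pattern. The sketch below shows how the shard template and the pattern expand, using only the standard library. The example file names are assumptions made for illustration; the exact names that `nobrainer.tfrecord.write` produces with `multi_resolution=True` may differ.

```python
import fnmatch

# Shard template from the conversion step; {shard:03d} zero-pads the shard index.
filename_template = "data/data-train_shard-{shard:03d}.tfrec"
first_shard = filename_template.format(shard=0)

# Pattern used in the training loop; %03d zero-pads the resolution.
resolution = 8
file_pattern = "data/*res-%03d*.tfrec" % resolution

# Assumed file names, only to show what the pattern does and does not match.
assumed_files = [
    "data/data-train_shard-000_res-008.tfrec",
    "data/data-train_shard-000_res-016.tfrec",
]
matches = [name for name in assumed_files if fnmatch.fnmatchcase(name, file_pattern)]

print(first_shard)
print(file_pattern)
print(matches)
```

Zero-padding to three digits keeps `res-008` from also matching files for other resolutions such as `res-080`, should those ever exist.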

In [None]:
from nobrainer import training
for resolution in resolutions:

    # create a train dataset with features for resolution
    dataset_train = nobrainer.dataset.get_dataset(
        file_pattern="data/*res-%03d*.tfrec"%(resolution),
        batch_size=resolution_batch_size_map[resolution],
        num_parallel_calls=num_parallel_calls,
        volume_shape=(resolution, resolution, resolution),
        n_classes=1, # dummy labels as this is unsupervised training
        scalar_label=True,
        normalizer=None
    )


    # grow the networks by one (2^x) resolution
    encoder.add_resolution()

    if fixed:
        decoder = tf.keras.models.load_model(next(model_paths))
    else:
        decoder.add_resolution()

    # instantiate a progressive training helper
    progressive_ae_trainer = training.ProgressiveAETrainer(
        encoder=encoder,
        decoder=decoder,
        fixed=fixed,
    )

    # compile with optimizers and loss function of choice
    progressive_ae_trainer.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=lr, beta_1=0.0, beta_2=0.99, epsilon=1e-8),
        loss_fn=tf.keras.losses.MeanSquaredError(),
        )

    steps_per_epoch = iterations//resolution_batch_size_map[resolution]
    # save_best_only=False: write a checkpoint every `save_freq` batches regardless of the loss value
    model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(str(model_dir), save_weights_only=True, save_best_only=False, save_freq=10)

    # Train at resolution
    print('Resolution : {}'.format(resolution),flush=True)

    print('Transition phase')
    progressive_ae_trainer.fit(
        dataset_train,
        phase='transition',
        resolution=resolution,
        steps_per_epoch=steps_per_epoch, # necessary for repeat dataset
        callbacks=[model_checkpoint_callback])

    print('Resolution phase')
    progressive_ae_trainer.fit(
        dataset_train,
        phase='resolution',
        resolution=resolution,
        steps_per_epoch=steps_per_epoch,
        callbacks=[model_checkpoint_callback])

    #save the final weights
    #print('Saving')
    #encoder.save(str(model_dir.joinpath('encoder_res_{}'.format(resolution))))
    #if not fixed: decoder.save(str(model_dir.joinpath('decoder_res_{}'.format(resolution))))