# Train a neural network for MRI generation using progressive growing

In this notebook, we will use Nobrainer to train a model for brain MRI generation. Generating brain MRIs is useful for creating synthetic neuroimaging data. We will use a Generative Adversarial Network (GAN) as the generative model and train it with progressive growing, which increases the output resolution step by step to produce high-quality images at higher resolutions.

In the following cells, we will:

1. Get sample T1-weighted MR scans as features.
2. Convert the data to TFRecords format.
3. Instantiate progressive convolutional neural networks for the generator and the discriminator.
4. Create a Dataset of the features.
5. Instantiate a trainer and choose a loss function to use.
6. Train on part of the data in two phases (transition and resolution).
7. Repeat steps 4-6 for each growing resolution.
8. Generate some images using the trained model.
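The transition phase in step 6 follows the fade-in scheme used in progressive growing of GANs: when a new, higher-resolution block is added, its output is blended with an upsampled copy of the previous resolution's output, and the blending weight `alpha` rises from 0 to 1 over training. The sketch below shows only this blending on 3D NumPy arrays. The function names `upsample_nearest` and `fade_in` are hypothetical, and this is not Nobrainer's internal implementation.

```python
import numpy as np


def upsample_nearest(volume: np.ndarray, factor: int = 2) -> np.ndarray:
    """Nearest-neighbour upsampling of a 3D volume along every axis."""
    for axis in range(3):
        volume = np.repeat(volume, factor, axis=axis)
    return volume


def fade_in(low_res_out: np.ndarray, high_res_out: np.ndarray, alpha: float) -> np.ndarray:
    """Blend the upsampled previous output with the new block's output.

    alpha = 0 -> only the (upsampled) previous resolution,
    alpha = 1 -> only the new, higher-resolution block.
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must lie in [0, 1]")
    return (1.0 - alpha) * upsample_nearest(low_res_out) + alpha * high_res_out


# Toy outputs: 8^3 from the previous stage and 16^3 from the newly added block.
low = np.zeros((8, 8, 8), dtype=np.float32)
high = np.ones((16, 16, 16), dtype=np.float32)

# Hypothetical linear ramp of alpha across the transition phase.
n_steps = 5
for step in range(n_steps):
    alpha = step / (n_steps - 1)
    blended = fade_in(low, high, alpha)
    print(f"step {step}: alpha={alpha:.2f}, mean voxel value={blended.mean():.2f}")
```

Once `alpha` reaches 1, the old path no longer contributes and training continues at the new resolution alone, which corresponds to the resolution phase.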

## Google Colaboratory

If you are using Colab, please switch your runtime to GPU. To do this, select `Runtime > Change runtime type` in the top menu. Then select GPU under `Hardware accelerator`. A GPU greatly speeds up training.

In [None]:
!pip install --no-cache-dir nilearn https://github.com/neuronets/nobrainer/archive/refs/heads/enh/api.zip

In [None]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # hide TensorFlow INFO and WARNING logs; must be set before importing tensorflow
import tensorflow as tf

In [None]:
import nobrainer

# Get sample features and labels

We use 9 pairs of volumes for training and 1 pair of volumes for evaluation. Many more volumes would be required to train a model for any useful purpose.
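The split is a simple slice of the list of file paths. The sketch below uses hypothetical file names in place of the list returned by `nobrainer.io.read_csv`, assuming each entry is a (features, labels) pair. It uses `example_` names so it does not overwrite the notebook's own `filepaths` and `train_paths`.

```python
# Hypothetical stand-in for the (features, labels) pairs read from the CSV file.
example_filepaths = [
    (f"sub-{i:02d}_t1.nii.gz", f"sub-{i:02d}_aseg.nii.gz") for i in range(10)
]

example_train = example_filepaths[:9]  # first 9 pairs for training
example_eval = example_filepaths[9:]   # remaining pair held out for evaluation

print(len(example_train), len(example_eval))
```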

In [None]:
csv_of_filepaths = nobrainer.utils.get_data()
filepaths = nobrainer.io.read_csv(csv_of_filepaths)

train_paths = filepaths[:9]

# Convert medical images to TFRecords

Remember how many full volumes are in the TFRecords files. You will need this number to know how many steps make up one training epoch. The default training method needs it because a `tf.data.Dataset` does not always know how many items it contains.
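As a worked example, the number of steps in one epoch is the number of volumes divided by the batch size, rounded up when the last, smaller batch is kept and rounded down when it is dropped (the behaviour controlled by `drop_remainder` in `tf.data.Dataset.batch`). The helper `steps_per_epoch` below is hypothetical and only illustrates the arithmetic.

```python
import math


def steps_per_epoch(n_volumes: int, batch_size: int, drop_remainder: bool = False) -> int:
    """Number of batches needed to pass over every volume once."""
    if drop_remainder:
        # The incomplete final batch is discarded.
        return n_volumes // batch_size
    # The incomplete final batch is kept.
    return math.ceil(n_volumes / batch_size)


print(steps_per_epoch(9, 4))                       # 3 batches: 4 + 4 + 1
print(steps_per_epoch(9, 4, drop_remainder=True))  # 2 batches: 4 + 4
```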

In [None]:
from nobrainer.dataset import write_multi_resolution

In [None]:
datasets = write_multi_resolution(train_paths, 
                                  tfrecdir="data/generate",
                                  n_processes=None)
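`write_multi_resolution` stores each volume at several resolutions (8 through 256 voxels per side, as the example output below shows). One simple way to build such a pyramid is to average non-overlapping 2x2x2 blocks repeatedly. The sketch below shows that idea in NumPy. Nobrainer's own resampling may differ, and `downsample_by_two` and `resolution_pyramid` are hypothetical names.

```python
import numpy as np


def downsample_by_two(volume: np.ndarray) -> np.ndarray:
    """Halve each spatial dimension by averaging non-overlapping 2x2x2 blocks."""
    x, y, z = volume.shape
    if x % 2 or y % 2 or z % 2:
        raise ValueError("every dimension must be even")
    blocks = volume.reshape(x // 2, 2, y // 2, 2, z // 2, 2)
    return blocks.mean(axis=(1, 3, 5))


def resolution_pyramid(volume: np.ndarray, min_size: int = 8) -> dict:
    """Map each side length, from the full size down to min_size, to a cube."""
    pyramid = {volume.shape[0]: volume}
    while volume.shape[0] > min_size:
        volume = downsample_by_two(volume)
        pyramid[volume.shape[0]] = volume
    return pyramid


# A small 32^3 random cube stands in for a 256^3 T1-weighted scan.
rng = np.random.default_rng(0)
vol = rng.random((32, 32, 32), dtype=np.float32)
pyramid = resolution_pyramid(vol)
print(sorted(pyramid))
```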

In [None]:
datasets

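`datasets` maps each resolution to the TFRecords file pattern, batch size, and normalizer used at that resolution. It will look like the following (the file paths depend on where the notebook was run). You can adjust the batch size for each resolution depending on the available compute power, in particular GPU memory.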

```python
datasets = {8: {'file_pattern': '/home/jovyan/temp/nobrainer/guide/data/*res-008.tfrec',
  'batch_size': 1,
  'normalizer': None},
 16: {'file_pattern': '/home/jovyan/temp/nobrainer/guide/data/*res-016.tfrec',
  'batch_size': 1,
  'normalizer': None},
 32: {'file_pattern': '/home/jovyan/temp/nobrainer/guide/data/*res-032.tfrec',
  'batch_size': 1,
  'normalizer': None},
 64: {'file_pattern': '/home/jovyan/temp/nobrainer/guide/data/*res-064.tfrec',
  'batch_size': 1,
  'normalizer': None},
 128: {'file_pattern': '/home/jovyan/temp/nobrainer/guide/data/*res-128.tfrec',
  'batch_size': 1,
  'normalizer': None},
 256: {'file_pattern': '/home/jovyan/temp/nobrainer/guide/data/*res-256.tfrec',
  'batch_size': 1,
  'normalizer': None}}
```

For example, to use larger batches at the lower resolutions:

```python
datasets[8]["batch_size"] = 32
datasets[16]["batch_size"] = 16
datasets[32]["batch_size"] = 8
datasets[64]["batch_size"] = 4
datasets[128]["batch_size"] = 1
datasets[256]["batch_size"] = 1
```
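The batch sizes shrink as the resolution grows because the number of voxels grows with the cube of the side length. As a rough check, assuming single-channel float32 volumes and counting only the input batch (not the much larger activations or the model weights), the input tensor size per batch is computed below. `input_batch_bytes` is a hypothetical helper.

```python
BYTES_PER_FLOAT32 = 4


def input_batch_bytes(resolution: int, batch_size: int) -> int:
    """Bytes needed for one batch of single-channel float32 cubes."""
    return batch_size * resolution ** 3 * BYTES_PER_FLOAT32


# Batch sizes taken from the example above.
batch_sizes = {8: 32, 16: 16, 32: 8, 64: 4, 128: 1, 256: 1}

for resolution, batch_size in batch_sizes.items():
    mib = input_batch_bytes(resolution, batch_size) / 2**20
    print(f"res {resolution:>3}: batch {batch_size:>2} -> {mib:7.2f} MiB")
```

Even with a batch size of 1, a single 256^3 input already takes 64 MiB, so the intermediate feature maps at that resolution usually dominate GPU memory.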

In [None]:
from nobrainer.processing.generation import ProgressiveGeneration
gen = ProgressiveGeneration()

In [None]:
gen.fit(datasets)

In [None]:
from nilearn import plotting

img = gen.generate()
plotting.plot_anat(anat_img=img, draw_cross=False)

We can warm-start training from the current state of the model, but for now a warm start only retrains at the final resolution of the data or higher.
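The warm-start rule above can be read as a filter over the resolutions in `datasets`: only resolutions at or above the last one trained are revisited. This is a hypothetical illustration of the stated behaviour, not Nobrainer's code. `resolutions_to_retrain` and `example_datasets` are invented names, chosen so the notebook's own `datasets` is not overwritten.

```python
def resolutions_to_retrain(datasets: dict, last_trained_resolution: int) -> list:
    """Return the resolutions a warm start would train on, in increasing order."""
    return sorted(r for r in datasets if r >= last_trained_resolution)


# Stand-in for the `datasets` dictionary; only the keys matter here.
example_datasets = {res: {"batch_size": 1} for res in (8, 16, 32, 64, 128, 256)}

print(resolutions_to_retrain(example_datasets, last_trained_resolution=256))
print(resolutions_to_retrain(example_datasets, last_trained_resolution=64))
```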

In [None]:
gen.fit(datasets, warm_start=True)

In [None]:
from nilearn import plotting

img = gen.generate()
plotting.plot_anat(anat_img=img, draw_cross=False)