# Train priors for MRI image reconstruction
**Authors**: [Guanxiong Luo](mailto:guanxiong.luo@med.uni-goettingen.de), [Nick Scholand](mailto:nick.scholand@med.uni-goettingen.de), [Christian Holme](mailto:christian.holme@med.uni-goettingen.de)

**Have fun with it! If you have any questions, don't hesitate to drop us a line.**

## 1. Install `spreco`

Clone the `spreco` repository from GitHub and install it with `pip`.

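```bash
# In Jupyter/Colab, run this cell with the %%bash cell magic.
# Optional (commented out): replace the installed TensorFlow with tensorflow-gpu 2.4.1
#pip uninstall tensorflow-gpu
#pip install tensorflow-gpu==2.4.1
git clone https://github.com/mrirecon/spreco.git
cd spreco
pip install .
```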
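Before moving on, you can confirm that the installation worked by checking whether the notebook's Python interpreter can find `spreco` and TensorFlow. This is a minimal sketch that uses only the standard library. The helper name `missing_packages` is hypothetical and is not part of spreco.

```python
import importlib.util


def missing_packages(names):
    """Return the packages in `names` that the current interpreter cannot find."""
    # find_spec only locates top-level packages; it does not import them
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = missing_packages(["spreco", "tensorflow"])
if missing:
    print("not importable:", ", ".join(missing))
else:
    print("spreco and tensorflow are importable")
```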

## 2. Import modules
1. `pipe`: creates the data loader
2. `trainer`: trains the prior according to the configuration file
3. `utils`: helper functions for loading the config, finding files, and logging

```python
import os

from spreco.common import utils, pipe
from spreco.workbench.trainer import trainer
```

### The training configuration file

The training configuration is a YAML file (`scripts/train_config.yaml` in the `spreco` repository) with four parts: model, saving, data, and gpu.


```yaml
# model
model: 'NCSN'
batch_size: 2           # batch size per GPU
input_shape: [256, 256, 2]
data_chns: 'CPLX'       # complex input
lr: 0.0001              # learning rate
begin_sigma: 0.3        # sigma_max, the largest noise level
end_sigma: 0.01         # sigma_min, the smallest noise level
anneal_power: 2.
nr_levels: 10           # N, the number of noise levels
affine_x: True
nonlinearity: 'elu'     # activation function
nr_filters: 64          # base number of filters

# saving
seed: 1234              # random seed
max_keep: 100
max_epochs: 1000
save_interval: 50       # interval for saving model snapshots
saved_name: test_brain
log_folder: /content/logs   # where models and training logs are saved

# data
train_data_path: /content/spreco/data/brain_mnist/train
test_data_path: /content/spreco/data/brain_mnist/test
pattern: "*.npz"        # files matching this pattern are loaded for training or testing
num_prepare: 10
print_loss: True

# gpu
nr_gpu: 2               # number of GPUs
gpu_id: '1,2'           # GPU IDs, in PCI_BUS_ID order
```
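The noise entries define the noise levels used by NCSN. In the original NCSN paper, the N levels form a geometric sequence from `begin_sigma` (sigma_max) down to `end_sigma` (sigma_min). The sketch below assumes spreco spaces them the same way, which we have not verified against its source.

It also shows two related points:

- The number of images per training step. The data pipes are built with `batch_size * nr_gpu`.
- One common way to honor `gpu_id`, through the `CUDA_DEVICE_ORDER` and `CUDA_VISIBLE_DEVICES` environment variables. This is also an assumption about how spreco uses the field.

The helpers `noise_levels` and `select_gpus` are hypothetical.

```python
import os

import numpy as np

# A subset of the configuration shown above, as a plain dict.
config = {
    "batch_size": 2,
    "begin_sigma": 0.3,
    "end_sigma": 0.01,
    "nr_levels": 10,
    "nr_gpu": 2,
    "gpu_id": "1,2",
}


def noise_levels(cfg):
    """Geometric sequence of nr_levels noise levels from begin_sigma down to end_sigma.

    Assumption: NCSN-style spacing; spreco's own schedule may differ.
    """
    return np.exp(np.linspace(np.log(cfg["begin_sigma"]),
                              np.log(cfg["end_sigma"]),
                              cfg["nr_levels"]))


def select_gpus(cfg):
    """Hypothetical helper: expose only the GPUs in gpu_id, numbered by PCI bus ID."""
    ids = [i.strip() for i in str(cfg["gpu_id"]).split(",") if i.strip()]
    if len(ids) != cfg["nr_gpu"]:
        raise ValueError(f"gpu_id lists {len(ids)} GPUs but nr_gpu is {cfg['nr_gpu']}")
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(ids)
    return ids


sigmas = noise_levels(config)
global_batch_size = config["batch_size"] * config["nr_gpu"]  # 2 per GPU x 2 GPUs
print(np.round(sigmas, 4))
print("images per training step:", global_batch_size)
```

The environment variables only take effect if they are set before TensorFlow initializes the GPUs. In a notebook, that means before the first TensorFlow import.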

## 3. Prepare training files
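Download and extract the `brain_mnist` tar archive. Then set the data entries in the config file:

1. `train_data_path`: folder containing the training files
2. `test_data_path`: folder containing the test files
3. `pattern`: file-name pattern of the files to load (`"*.npz"`)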

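```bash
# In Jupyter/Colab, run this cell with the %%bash cell magic.
# download the data (the URL is quoted so the shell does not treat "?" as a glob)
curl -L "https://zenodo.org/record/6521188/files/brain_mnist.tar?download=1" --output brain_mnist.tar
mkdir -p spreco/data
tar xf brain_mnist.tar -C spreco/data
```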
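After extracting the archive, check that the files the loader will read exist and contain the array it expects. The data-loader cell below reads each file with `utils.npz_loader(x, 'rss')`. This sketch uses plain NumPy rather than spreco to look for an `rss` entry in the first matching file. The function name `inspect_npz_folder` is hypothetical.

```python
import glob
import os

import numpy as np


def inspect_npz_folder(folder, pattern="*.npz", key="rss"):
    """Count files in `folder` matching `pattern` and report the shape and dtype
    of the array stored under `key` in the first of them (sorted by name)."""
    files = sorted(glob.glob(os.path.join(folder, pattern)))
    if not files:
        raise FileNotFoundError(f"no files matching {pattern!r} in {folder}")
    with np.load(files[0]) as data:
        if key not in data.files:
            raise KeyError(f"{key!r} not found in {files[0]}; available: {data.files}")
        array = data[key]
    return len(files), array.shape, array.dtype
```

For example, call `inspect_npz_folder('spreco/data/brain_mnist/train')` from the same working directory as the download cell.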



Check that the paths to the config file, the training data, and the log folder match your setup.
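You can check every path the config refers to before building the data pipes. The sketch below takes the config as a plain dict, for example the result of `utils.load_config`, and uses only the standard library. `check_paths` is a hypothetical helper. Its checks follow the config fields described above, not anything spreco itself enforces.

```python
import glob
import os


def check_paths(config_path, cfg):
    """Return a list of problems with the paths used for training (empty if all look fine)."""
    problems = []
    if not os.path.isfile(config_path):
        problems.append(f"config file not found: {config_path}")
    for key in ("train_data_path", "test_data_path"):
        folder = cfg[key]
        if not os.path.isdir(folder):
            problems.append(f"{key} is not a folder: {folder}")
        elif not glob.glob(os.path.join(folder, cfg["pattern"])):
            problems.append(f"no files matching {cfg['pattern']!r} in {key}: {folder}")
    log_folder = cfg["log_folder"]
    parent = os.path.dirname(os.path.abspath(log_folder))
    if os.path.exists(log_folder) and not os.path.isdir(log_folder):
        problems.append(f"log_folder exists but is not a folder: {log_folder}")
    elif not os.path.exists(log_folder) and not os.path.isdir(parent):
        problems.append(f"parent of log_folder does not exist: {parent}")
    return problems
```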

```python
import numpy as np

config_path = '/content/spreco/scripts/train_config.yaml'
config = utils.load_config(config_path)

train_files = utils.find_files(config['train_data_path'], config['pattern'])
test_files = utils.find_files(config['test_data_path'], config['pattern'])

# create the data loader

def npz_loader(x):
    # load the array stored under the key 'rss' from an .npz file
    return utils.npz_loader(x, 'rss')

def squeeze(x):
    # drop singleton dimensions
    return np.squeeze(x)

def normalize(x):
    # normalize the image by its maximum
    return utils.normalize_with_max(x)

def slice_image(x):
    # slice the image to the network's input shape
    return utils.slice_image(x, config['input_shape'])

def randint(x, dtype='int32'):
    # x is a dummy argument; draw one random noise-level index in [0, nr_levels)
    return np.random.randint(0, config['nr_levels'], size=1, dtype=dtype)

# part 1 produces the images ('inputs'), part 2 the noise-level indices ('h')
parts_funcs = [[npz_loader, squeeze, normalize, slice_image], [randint]]

# one batch of batch_size images per GPU
global_batch_size = config['batch_size'] * config['nr_gpu']

train_pipe = pipe.create_pipe(parts_funcs,
                              files=train_files,
                              batch_size=global_batch_size,
                              shape_info=[config['input_shape'], [1]],
                              names=['inputs', 'h'])

test_pipe = pipe.create_pipe(parts_funcs,
                             files=test_files,
                             batch_size=global_batch_size,
                             shape_info=[config['input_shape'], [1]],
                             names=['inputs', 'h'])
```
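`pipe.create_pipe` is spreco's own function, and its internals are not shown in this notebook. As a rough mental model only, the sketch below assumes the following:

- Each entry of `parts_funcs` is a chain of functions applied in order to a file.
- The first chain yields the image stored under `'inputs'`, and the second yields the noise-level index stored under `'h'`.
- The results are stacked into a batch.

It uses NumPy and made-up stand-in functions. `make_batch` is hypothetical and leaves out anything the real pipe may add, such as reshuffling every epoch, background loading, or checks against `shape_info`.

```python
import numpy as np


def apply_chain(funcs, x):
    """Apply the functions of one part to x, in order."""
    for f in funcs:
        x = f(x)
    return x


def make_batch(parts_funcs, files, batch_size, names, rng=None):
    """Hypothetical stand-in for one batch of a data pipe: pick `batch_size` files at
    random, run each through every part, and stack the outputs under `names`."""
    if len(parts_funcs) != len(names):
        raise ValueError("need one name per part")
    if batch_size > len(files):
        raise ValueError("batch_size is larger than the number of files")
    rng = np.random.default_rng() if rng is None else rng
    picked = rng.permutation(len(files))[:batch_size]
    batch = {name: [] for name in names}
    for i in picked:
        for name, funcs in zip(names, parts_funcs):
            batch[name].append(apply_chain(funcs, files[i]))
    return {name: np.stack(values) for name, values in batch.items()}


# Made-up stand-ins: "loading" a file returns a (1, 4, 4, 2) array filled with
# the length of its name, so no real .npz files are needed.
def fake_load(path):
    return np.full((1, 4, 4, 2), float(len(path)))


def normalize(x):
    return x / np.max(np.abs(x))


def noise_level(x, nr_levels=10):
    # x is a dummy argument, as in the randint function above
    return np.random.randint(0, nr_levels, size=1, dtype="int32")


parts = [[fake_load, np.squeeze, normalize], [noise_level]]
batch = make_batch(parts, ["a.npz", "bb.npz", "ccc.npz", "dddd.npz"], batch_size=4,
                   names=["inputs", "h"], rng=np.random.default_rng(0))
print(batch["inputs"].shape, batch["h"].shape)  # → (4, 4, 4, 2) (4, 1)
```

In the real pipes, each batch holds `batch_size * nr_gpu` images. The `h` column presumably gives the index of the noise level the trainer applies to each image.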

## 4. Start the trainer

1. Create the trainer with the data pipes and the configuration.
2. Write the list of training files and the start and end times of training to the log folder.

```python
# create the trainer from the data pipes and the configuration
go = trainer(train_pipe, test_pipe, config)

# record which files are used for training
utils.log_to(os.path.join(go.log_path, 'training files'), train_files, prefix="#")

# log the start time, train, then log the end time
utils.log_to(os.path.join(go.log_path, 'config.yaml'), [utils.get_timestamp(), "The training is starting"], prefix="#")
go.train()
utils.log_to(os.path.join(go.log_path, 'config.yaml'), [utils.get_timestamp(), "The training is ending"], prefix="#")
utils.color_print('TRAINING FINISHED')
```
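`go.train()` runs spreco's training loop, which is not reproduced here. For orientation, the sketch below writes out the annealed denoising score matching objective from the original NCSN paper (Song and Ermon, 2019), the usual way an NCSN prior is trained:

1. Perturb each image with Gaussian noise of level sigma_h.
2. Ask the network for the score at the noisy image.
3. Compare that score with the score of the perturbation kernel, -(x_tilde - x) / sigma^2.
4. Weight the squared error by sigma^anneal_power.

We assume, without having checked spreco's source, that `begin_sigma`, `end_sigma`, `nr_levels`, and `anneal_power` play these roles. `dsm_loss` and the toy score functions are hypothetical.

```python
import numpy as np


def dsm_loss(score_fn, x, labels, sigmas, anneal_power=2.0, rng=None):
    """Annealed denoising score matching loss (NCSN-style), averaged over the batch.

    x:      clean images, shape (B, H, W, C)
    labels: noise-level indices, shape (B, 1), values in [0, len(sigmas))
    sigmas: noise levels, shape (N,)
    """
    rng = np.random.default_rng() if rng is None else rng
    # noise level of each example, shaped to broadcast against x
    used = sigmas[labels.reshape(-1)].reshape((-1,) + (1,) * (x.ndim - 1))
    noise = rng.standard_normal(x.shape)
    x_tilde = x + used * noise
    target = -(x_tilde - x) / used**2  # score of the Gaussian perturbation kernel
    diff = score_fn(x_tilde, labels) - target
    sq_err = np.sum(diff**2, axis=tuple(range(1, x.ndim)))
    weights = used.reshape(-1) ** anneal_power
    return float(np.mean(0.5 * sq_err * weights))


rng = np.random.default_rng(0)
sigmas = np.exp(np.linspace(np.log(0.3), np.log(0.01), 10))
x = rng.standard_normal((4, 8, 8, 2))
labels = rng.integers(0, len(sigmas), size=(4, 1))


def zero_score(x_tilde, h):
    # a "network" that always predicts zero
    return np.zeros_like(x_tilde)


# With anneal_power = 2 the weighted loss of the zero predictor is 0.5 * ||z||^2,
# i.e. about half the number of values per image (8 * 8 * 2 / 2 = 64) on average.
print(dsm_loss(zero_score, x, labels, sigmas, rng=rng))
```

The sigma^2 weighting (anneal_power = 2) keeps the loss on roughly the same scale for every noise level, so no single level dominates training.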