# Train priors for MRI image reconstruction
**Authors**: [Guanxiong Luo](mailto:guanxiong.luo@med.uni-goettingen.de), [Nick Scholand](mailto:nick.scholand@med.uni-goettingen.de), [Christian Holme](mailto:christian.holme@med.uni-goettingen.de)

**Have fun with it! If you have any questions, don't hesitate to drop us a line.**

## 1. Install `spreco`

Clone the `spreco` repository and install it with `pip`.

```bash
git clone https://github.com/mrirecon/spreco.git
cd spreco
pip install .
```
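A quick way to confirm that the installation worked is to check that Python can locate the package. The helper `is_installed` below is a hypothetical name and uses only the standard library.

```python
import importlib.util


def is_installed(package: str) -> bool:
    """Return True if `package` can be found on the current Python path."""
    return importlib.util.find_spec(package) is not None


print(is_installed("spreco"))
```

If this prints `False`, the notebook kernel is probably using a different Python environment from the one `pip` installed into.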

## 2. Import modules

1. `parts` creates a dictionary of functions that load and preprocess the data.
1. `pipe` creates a data pipe from the functions provided by `parts`.
1. `trainer` trains the prior according to the configuration file.
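The idea behind `parts` and `pipe` can be sketched in plain NumPy. A registry maps step names to functions, and a list of `{func: name, **kwargs}` specs is chained into one callable. The step names mirror those used in the configuration file further down. The implementations, `REGISTRY`, and `build_chain` are hypothetical stand-ins, and spreco's real functions may behave differently. In particular, it is assumed here that `slice_image` both splits the complex image into real/imaginary channels and center-crops it.

```python
import numpy as np


# Hypothetical stand-ins for the named steps in the `parts` config.
def squeeze(x):
    # Drop singleton dimensions, e.g. (1, H, W) -> (H, W).
    return np.squeeze(x)


def normalize_with_max(x):
    # Scale so that the largest magnitude becomes 1 (assumed behaviour).
    return x / np.max(np.abs(x))


def slice_image(x, shape):
    # Assumed: split complex data into (real, imag) channels, then center-crop.
    x = np.stack([x.real, x.imag], axis=-1)
    h, w = shape[0], shape[1]
    top = (x.shape[0] - h) // 2
    left = (x.shape[1] - w) // 2
    return x[top:top + h, left:left + w, :]


REGISTRY = {
    "squeeze": squeeze,
    "normalize_with_max": normalize_with_max,
    "slice_image": slice_image,
}


def build_chain(specs):
    """Turn [{'func': name, **kwargs}, ...] into a single callable."""
    steps = []
    for spec in specs:
        kwargs = {k: v for k, v in spec.items() if k != "func"}
        steps.append((REGISTRY[spec["func"]], kwargs))

    def run(x):
        # Apply each step in order, passing its keyword arguments.
        for fn, kwargs in steps:
            x = fn(x, **kwargs)
        return x

    return run


chain = build_chain([{"func": "squeeze"},
                     {"func": "normalize_with_max"},
                     {"func": "slice_image", "shape": [256, 256, 2]}])
rng = np.random.default_rng(0)
img = rng.standard_normal((1, 320, 320)) + 1j * rng.standard_normal((1, 320, 320))
out = chain(img)
print(out.shape, out.dtype)
```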

```python
import os

from spreco.common import utils, pipe
from spreco.common.options import parts
from spreco.workbench.trainer import trainer
```

### Configuration file for training

The configuration file `config.yaml` has four groups of settings: model, saving, data, and GPU.
The data settings specify where the training and test files are located and how to build a data pipe whose output matches the input of the selected network.

```yaml
# model
model: 'NCSN'
batch_size: 10       # batch size per GPU
input_shape: [256, 256, 2]
data_chns: 'CPLX'    # complex-valued input
lr: 0.0001           # learning rate
begin_sigma: 0.3     # sigma_max
end_sigma: 0.01      # sigma_min
anneal_power: 2.
nr_levels: 10        # number of noise levels N
affine_x: True
nonlinearity: 'elu'  # activation function
nr_filters: 64       # base number of filters

# saving
seed: 1234           # random seed
max_keep: 100
max_epochs: 1000
save_interval: 50    # take a snapshot of the model every 50 epochs
saved_name: test_brain
log_folder:          # directory for saved models and training logs

# data
train_data_path: /home/ague/data/gluo/dataset/brain_mat/train
test_data_path: /home/ague/data/gluo/dataset/brain_mat/test
pattern: "nyu_AXFLAIR_*.npz"    # every file matching this name pattern is loaded for training or testing
num_prepare: 10
print_loss: True
parts:
    # Components used to construct the data pipe:
    # load the array stored under key 'rss' into a complex numpy array, squeeze it,
    # normalize it by its maximum magnitude, represent the complex image
    # (width, height, 1) as a float array (width, height, 2), and crop the
    # float array to the specified shape.
    - [{func: 'npz_loader', key: 'rss'}, {func: 'squeeze'}, {func: 'normalize_with_max'}, {func: 'slice_image', shape: [256, 256, 2]}]
    # generate the noise-level indices
    - [{func: 'randint', nr_levels: 10}]

# gpu
nr_gpu: 2            # number of GPUs
gpu_id: '1,2'        # PCI_BUS_ID
```
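The parameters `begin_sigma`, `end_sigma`, `nr_levels`, and `anneal_power` describe the noise schedule of an NCSN-style score model. The sketch below uses the standard NCSN formulation: a geometric sequence of noise levels, one random level index `h` per sample (the role of the `randint` part), and a denoising score matching loss weighted by $\sigma^{\text{anneal\_power}}$. It is not confirmed that spreco implements exactly this formula. The names `dsm_loss` and `score_fn` are hypothetical.

```python
import numpy as np

# Values copied from the example config above.
begin_sigma, end_sigma, nr_levels, anneal_power = 0.3, 0.01, 10, 2.0

# Geometric sequence from sigma_max down to sigma_min (standard NCSN choice;
# assumed, not confirmed, to match spreco).
sigmas = np.exp(np.linspace(np.log(begin_sigma), np.log(end_sigma), nr_levels))


def dsm_loss(score_fn, x, rng):
    """Annealed denoising score matching loss for one batch `x`."""
    # One noise-level index per sample, like the `randint` part in the config.
    h = rng.integers(0, nr_levels, size=x.shape[0])
    sigma = sigmas[h].reshape(-1, *([1] * (x.ndim - 1)))
    noise = rng.standard_normal(x.shape) * sigma
    x_tilde = x + noise
    # Score of the Gaussian perturbation kernel: -(x_tilde - x) / sigma^2.
    target = -noise / sigma**2
    diff = score_fn(x_tilde, h) - target
    per_sample = 0.5 * np.sum(diff**2, axis=tuple(range(1, x.ndim)))
    # Weight each sample by sigma^anneal_power, then average over the batch.
    return np.mean(per_sample * sigma.reshape(-1) ** anneal_power)


rng = np.random.default_rng(0)
x = np.zeros((4, 8, 8, 2))                          # tiny stand-in batch
zero_score = lambda x_tilde, h: np.zeros_like(x_tilde)  # placeholder "network"
print(sigmas.round(4))
print(dsm_loss(zero_score, x, rng))
```

With `anneal_power = 2` the weighting cancels the $1/\sigma^2$ scale of the target. As a result, every noise level contributes to the loss on a comparable scale.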

## 3. Prepare training files
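Unzip the archive, then set the data-related entries in the configuration file:

1. `train_data_path`: directory containing the training files
2. `test_data_path`: directory containing the test files
3. `pattern`: file-name pattern; every file in these directories that matches it is loaded

`log_folder` is the directory where intermediate snapshots of the model and the model specifications are stored.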
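The `pattern` entry is a shell-style wildcard. A minimal sketch of matching files this way with the standard library's `glob` follows. `find_files_sketch` is a hypothetical helper, and spreco's `utils.find_files` may differ in details such as sorting or recursion.

```python
import glob
import os
import tempfile


def find_files_sketch(folder, pattern):
    """Return sorted paths in `folder` whose file names match the glob `pattern`."""
    return sorted(glob.glob(os.path.join(folder, pattern)))


# Demo on a throw-away directory with two matching files and one non-matching file.
with tempfile.TemporaryDirectory() as tmp:
    for name in ["nyu_AXFLAIR_001.npz", "nyu_AXFLAIR_002.npz", "nyu_AXT2_001.npz"]:
        open(os.path.join(tmp, name), "w").close()
    matched = [os.path.basename(p) for p in find_files_sketch(tmp, "nyu_AXFLAIR_*.npz")]

print(matched)  # ['nyu_AXFLAIR_001.npz', 'nyu_AXFLAIR_002.npz']
```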

```bash
unzip ./data/brain_mnist.zip
```

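```python
# Adjust this path to the location of your own configuration file.
config_path = '/home/gluo/spreco/scripts/train_config.yaml'
config = utils.load_config(config_path)

# Collect the training and test files that match `pattern`.
train_files = utils.find_files(config['train_data_path'], config['pattern'])
test_files = utils.find_files(config['test_data_path'], config['pattern'])

# Turn the `parts` entry of the config into loading/preprocessing functions.
parts_funcs = parts.parse(config['parts'])

# Each pipe yields batches named 'inputs' (images) and 'h' (noise-level indices);
# one batch holds batch_size samples for each GPU.
train_pipe = pipe.create_pipe(parts_funcs,
                              files=train_files,
                              batch_size=config['batch_size'] * config['nr_gpu'],
                              shape_info=[config['input_shape'], [1]],
                              names=['inputs', 'h'])

test_pipe = pipe.create_pipe(parts_funcs,
                             files=test_files,
                             batch_size=config['batch_size'] * config['nr_gpu'],
                             shape_info=[config['input_shape'], [1]],
                             names=['inputs', 'h'])
```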
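Because `batch_size` is multiplied by `nr_gpu`, each batch delivered by the pipe is large enough to give every GPU `batch_size` samples. The sketch below builds a dummy batch with the shapes from `shape_info` and splits it evenly along the batch axis. That the trainer splits the batch this way is an assumption for illustration only.

```python
import numpy as np

batch_size, nr_gpu = 10, 2          # values from the example config
input_shape = [256, 256, 2]

global_batch = batch_size * nr_gpu  # samples the pipe delivers per step

# Dummy batch with the two named tensors from names=['inputs', 'h'].
batch = {
    "inputs": np.zeros([global_batch] + input_shape, dtype=np.float32),
    "h": np.zeros([global_batch, 1], dtype=np.int32),
}

# Assumed: each GPU receives an equal slice of the batch.
per_gpu = {k: np.split(v, nr_gpu, axis=0) for k, v in batch.items()}
print(global_batch, per_gpu["inputs"][0].shape, per_gpu["h"][1].shape)
# 20 (10, 256, 256, 2) (10, 1)
```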

## 4. Start the trainer

1. Create the trainer with the given data pipes and configuration.
2. Record the training files and the start and end times of training in the log directory.
```python
# Create the trainer from the data pipes and the configuration.
go = trainer(train_pipe, test_pipe, config)

# Record which files are used for training.
utils.log_to(os.path.join(go.log_path, 'training files'), train_files, prefix="#")

# Append start and end timestamps to config.yaml in the log directory.
utils.log_to(os.path.join(go.log_path, 'config.yaml'), [utils.get_timestamp(), "The training is starting"], prefix="#")
go.train()
utils.log_to(os.path.join(go.log_path, 'config.yaml'), [utils.get_timestamp(), "The training is ending"], prefix="#")
utils.color_print('TRAINING FINISHED')
```