# Apply pre-trained priors for MRI image reconstruction
**Authors**: [Guanxiong Luo](mailto:guanxiong.luo@med.uni-goettingen.de), [Nick Scholand](mailto:nick.scholand@med.uni-goettingen.de), [Christian Holme](mailto:christian.holme@med.uni-goettingen.de)

**This tutorial takes about 30 minutes to complete.**

**Have fun with it! If you have any questions, don't hesitate to drop us a line.**

## 1. Set up the environment
### a. Download and Compile `bart`
If you are running this notebook in an environment that already has BART installed, you can skip this part.

In [None]:
%%bash

# Install BART's dependencies
apt-get install -y make gcc libfftw3-dev liblapacke-dev libpng-dev libopenblas-dev &> /dev/null

# Download BART (remove an existing copy first)
[ -d /content/bart ] && rm -r /content/bart
git clone https://github.com/mrirecon/bart/ bart
[ -d "bart" ] && echo "BART was downloaded successfully."

cd bart
make &> /dev/null

After compiling BART, we need to set the required environment variable `TOOLBOX_PATH`:

In [None]:
%env TOOLBOX_PATH=/content/bart

Additionally, we add the directory containing the compiled `bart` executable to the `PATH` variable:

In [None]:
import os
os.environ['PATH'] = os.environ['TOOLBOX_PATH'] + ":" + os.environ['PATH']
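Re-running the cell above prepends the same directory to `PATH` each time. A minimal sketch of an idempotent alternative using only the standard library; `prepend_to_path` is a hypothetical helper, not part of BART or spreco:

```python
import os


def prepend_to_path(directory, env=None):
    """Prepend `directory` to PATH once; repeated calls leave PATH unchanged."""
    if env is None:
        env = os.environ
    # split PATH into its entries and drop empty ones
    entries = [e for e in env.get("PATH", "").split(os.pathsep) if e]
    if directory not in entries:
        env["PATH"] = os.pathsep.join([directory] + entries)
    return env["PATH"]


# usage, equivalent to the cell above but safe to re-run:
# prepend_to_path(os.environ["TOOLBOX_PATH"])
```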

Check that the BART toolbox is working properly:

In [None]:
%%bash
bart version

### b. Install `spreco`
Download the `spreco` package and install it with `pip`. This cell also pins `tensorflow-gpu` to version 2.4.1.

In [None]:
%%bash
pip uninstall -y tensorflow-gpu
pip install tensorflow-gpu==2.4.1
git clone https://github.com/mrirecon/spreco.git
cd spreco
pip install .

In [None]:
%%bash
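# download the pre-trained model and the fully sampled k-space from Zenodo (-L follows redirects)
curl -L "https://zenodo.org/record/6521188/files/pre-trained.tar?download=1" --output pre-trained.tar
curl -L "https://zenodo.org/record/6521188/files/full_kspace.npz?download=1" --output full_kspace.npz
mkdir -p spreco/data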
tar xf pre-trained.tar -C spreco/data
mv full_kspace.npz spreco/data

## 2. Execute reconstruction
### a. Import modules 

`ops` contains simple functions for building the forward operator $\mathcal{A}$ of the k-space measurement model $\mathbf{y}=\mathcal{A}\mathbf{x}+\epsilon$.

`sde` is the class for training the reverse transitions $p_\theta(\mathbf{x}_i|\mathbf{x}_{i+1})$.

`posterior_sampler` is the class for drawing samples from $p(\mathbf{x}|\mathbf{y})$ given the learned transitions and the measured k-space.

`sampling_pattern` generates undersampling masks, and `utils` contains utilities for calling BART, loading configuration files, converting between complex and float arrays, and so on.
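To make the measurement model concrete, here is a minimal NumPy sketch of a Cartesian multi-coil operator $\mathcal{A} = \mathbf{M}\mathcal{F}\mathbf{S}$ (coil sensitivities, 2D FFT, sampling mask), its adjoint $\mathcal{A}^H$, and the normal operator $\mathcal{A}^H\mathcal{A}$. The names `fft2c`, `A`, `AH`, and `AHA` are hypothetical; spreco's `ops` functions may differ in FFT centering, normalization, and array layout.

```python
import numpy as np


def fft2c(x):
    """Centered, orthonormal 2D FFT over the first two axes."""
    shifted = np.fft.ifftshift(x, axes=(0, 1))
    return np.fft.fftshift(np.fft.fft2(shifted, axes=(0, 1), norm="ortho"), axes=(0, 1))


def ifft2c(k):
    """Inverse of fft2c (and, being orthonormal, also its adjoint)."""
    shifted = np.fft.ifftshift(k, axes=(0, 1))
    return np.fft.fftshift(np.fft.ifft2(shifted, axes=(0, 1), norm="ortho"), axes=(0, 1))


def A(x, coilsen, mask):
    """Forward model M F S x: image (nx, ny) -> masked k-space (nx, ny, nc)."""
    return mask[..., None] * fft2c(coilsen * x[..., None])


def AH(y, coilsen, mask):
    """Adjoint S^H F^H M y: masked k-space (nx, ny, nc) -> image (nx, ny)."""
    return np.sum(np.conj(coilsen) * ifft2c(mask[..., None] * y), axis=-1)


def AHA(x, coilsen, mask):
    """Normal operator A^H A, used for data-consistency gradients."""
    return AH(A(x, coilsen, mask), coilsen, mask)
```

With a mask of all ones and sensitivities normalized so that $\sum_c |S_c|^2 = 1$ at every pixel, `AHA` reduces to the identity; undersampling is what makes $\mathcal{A}^H\mathcal{A}$ non-invertible and the prior necessary.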

In [None]:
from spreco.common import ops, sampling_pattern, utils
from spreco.model.sde import sde, posterior_sampler

import os
import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()
from functools import partial

### b. Load the configuration for reconstruction

The reconstruction configuration file specifies:

1. which model is used to construct the transition kernel,
2. which sampling pattern is used,
3. how many samples are drawn from the posterior,
4. where to store the results,
5. the values of $K$, $N$, and $\lambda$ in the proposed algorithm,
6. whether to use a burn-in phase and, if so, at which time point,
7. where the k-space data is located.

An example reconstruction config file:

```yaml
cal: 20             # size of the fully sampled calibration region
fx: 1.5             # acceleration along x (used if Poisson sampling is enabled)
fy: 1.5             # acceleration along y (used if Poisson sampling is enabled)
poisson: true       # use a Poisson-disc sampling pattern
sampling_rate: 0.2  # if Poisson sampling is disabled, a Gaussian pattern with this rate is used instead
s_stepsize: 25      # lambda in Algorithm 1
st: 30              # N = 100 - st in Algorithm 1
c_steps: 5          # K in Algorithm 1
nr_samples: 10      # number of posterior samples to draw
burn_in: false      # whether to use a burn-in phase
burn_t: 0.5         # time point for the burn-in phase
disable_z: True
target_snr: 1

model_folder: xxxx
model_name: xxx
ksp_path: xxx       # k-space location
gpu_id: '3'
```

Check your `recon_config` file, especially the paths to the k-space data and the model, and then load the configuration with `utils`:

In [None]:
config_path  ='/content/spreco/scripts/recon_config.yaml'
config       = utils.load_config(config_path)
# the model folder also contains the configuration used for training
model_config = utils.load_config(os.path.join(config['model_folder'], 'config.yaml'))

model_path   = os.path.join(config['model_folder'], config['model_name'])
np.random.seed(model_config['seed'])
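Before preparing the data, it can help to catch missing keys and wrong paths early. A minimal sketch, assuming `config` behaves like a plain dict; `check_recon_config` and `REQUIRED_KEYS` are hypothetical, and the key list is simply what the cells of this notebook read:

```python
import os

# keys read by the reconstruction cells of this notebook
REQUIRED_KEYS = ['cal', 'poisson', 's_stepsize', 'st', 'c_steps', 'nr_samples',
                 'burn_in', 'burn_t', 'target_snr', 'model_folder', 'model_name', 'ksp_path']


def check_recon_config(config, check_paths=True):
    """Return a list of problems found in a reconstruction config (empty if none)."""
    problems = ['missing key: %s' % k for k in REQUIRED_KEYS if k not in config]
    # which sampling-pattern keys are needed depends on `poisson`
    if config.get('poisson', False):
        problems += ['missing key for Poisson sampling: %s' % k for k in ('fx', 'fy') if k not in config]
    elif 'sampling_rate' not in config:
        problems.append('missing key: sampling_rate (needed when poisson is false)')
    if check_paths:
        if 'ksp_path' in config and not os.path.isfile(config['ksp_path']):
            problems.append('k-space file not found: %s' % config['ksp_path'])
        if 'model_folder' in config:
            model_cfg = os.path.join(config['model_folder'], 'config.yaml')
            if not os.path.isfile(model_cfg):
                problems.append('model config not found: %s' % model_cfg)
    return problems


# usage with the config loaded above:
# for problem in check_recon_config(config):
#     print(problem)
```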

### c. Prepare the undersampled k-space data

1. load the fully sampled k-space
2. generate the undersampling mask with `bart` or `sampling_pattern`
3. compute coil sensitivities with `ecalib`
4. build the operator for the k-space measurement.
5. compute the ground truth from the fully sampled k-space and the zero-filled image from the undersampled k-space
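For intuition about step 2, a sketch of a variable-density random mask: a 2D Gaussian weighting favors the k-space center, and a fully sampled `cal` by `cal` block is added for calibration. `gaussian_mask_2d` and its `sigma` parameter are hypothetical; this is neither spreco's `gen_mask_2D` nor BART's `poisson` tool, whose exact densities may differ.

```python
import numpy as np


def gaussian_mask_2d(nx, ny, cal, rate, sigma=0.3, seed=0):
    """Hypothetical variable-density mask: about `rate` of k-space plus a cal x cal center."""
    rng = np.random.default_rng(seed)
    # sampling density: 2D Gaussian over normalized coordinates in [-1, 1]
    gx, gy = np.meshgrid(np.linspace(-1, 1, nx), np.linspace(-1, 1, ny), indexing="ij")
    density = np.exp(-(gx ** 2 + gy ** 2) / (2 * sigma ** 2)).ravel()
    # draw round(rate * nx * ny) distinct k-space locations with that density
    n_samples = int(round(rate * nx * ny))
    idx = rng.choice(nx * ny, size=n_samples, replace=False, p=density / density.sum())
    mask = np.zeros(nx * ny, dtype=np.float32)
    mask[idx] = 1.0
    mask = mask.reshape(nx, ny)
    # fully sampled calibration block in the center (needed by ESPIRiT)
    x0, y0 = nx // 2 - cal // 2, ny // 2 - cal // 2
    mask[x0:x0 + cal, y0:y0 + cal] = 1.0
    return mask
```

The effective acceleration is roughly `1 / rate`, slightly lower because the calibration block adds samples that the random draw may not already cover.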

In [None]:
def prepare_simu(config, mask=None):
    # load the fully sampled multi-coil k-space, shape (nx, ny, nc)
    kspace = np.squeeze(np.load(config['ksp_path'])['kspace'])

    nx, ny, _ = kspace.shape
    img_shape = [nx, ny]

    # coil sensitivities from the fully sampled k-space (ESPIRiT, one set of maps)
    coilsen = np.squeeze(utils.bart(1, 'ecalib -m1 -r20 -c0.001', kspace[np.newaxis, ...]))

    # ground truth: coil images combined with the conjugate sensitivities
    std_coils = ops.mifft2(kspace, img_shape)
    rss = np.sum(np.multiply(std_coils, np.squeeze(np.conj(coilsen))), axis=2)

    if mask is None:
        if not config['poisson']:
            # spreco sampling pattern with a fully sampled center of size config['cal']
            mask = sampling_pattern.gen_mask_2D(nx, ny, center_r=config['cal'], undersampling=config['sampling_rate'])
        else:
            # variable-density Poisson-disc pattern generated with BART
            mask = utils.bart(1, 'poisson -Y %d -Z %d -y %f -z %f -s 1234 -v -C %d' % (nx, ny, config['fx'], config['fy'], config['cal']))
            mask = np.squeeze(mask)

    # retrospectively undersample every coil with the same mask
    und_ksp = kspace * abs(mask[..., np.newaxis])

    # zero-filled image: adjoint operator applied to the undersampled k-space
    x_ = ops.AT_cart(und_ksp, coilsen, mask, img_shape)

    return x_, mask, coilsen, (nx, ny), rss, und_ksp

zero_filled, mask, coilsen, shape, rss, und_ksp = prepare_simu(config)
zero_filled = utils.float2cplx(utils.normalize_with_max(zero_filled)) # [-1, 1]
l1_recon    = utils.bart(1, 'pics -l1 -r 0.01', und_ksp[:,:,np.newaxis,:], coilsen[:,:,np.newaxis,:])

grad_params = {'coilsen': coilsen[np.newaxis, ...], 'mask': mask[np.newaxis, ...], 'shape': shape, 'center': False}
AHA         = partial(ops.AHA, **grad_params)

### d. Run the sampler for $p(\mathbf{x}|\mathbf{y})$

1. create two placeholders for the image $\mathbf{x}$ and noise indices $i$
2. instantiate the neural network
3. restore the pre-trained model
4. run the sampler with the learned transitions and the measured k-space.
5. convert the samples to complex-valued images
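To illustrate how the learned prior and the measured k-space interact in step 4, here is a generic Langevin sketch at a single, fixed noise level. It is not Algorithm 1 from the paper behind spreco: `posterior_langevin_step`, `score_fn`, and `lam` are hypothetical names, and we do not claim how `lam` relates to `s_stepsize`. The gradient of $\log p(\mathbf{x}|\mathbf{y})$ is taken as the prior score plus the gradient of a Gaussian log-likelihood, $-\lambda(\mathcal{A}^H\mathcal{A}\mathbf{x} - \mathcal{A}^H\mathbf{y})$.

```python
import numpy as np


def posterior_langevin_step(x, score_fn, AHA, AHy, step, lam, rng=None):
    """One hypothetical Langevin update targeting p(x|y), proportional to p(y|x) p(x).

    score_fn(x) approximates grad log p(x); the data term is the gradient of a
    Gaussian log-likelihood with weight lam. With rng=None the noise term is
    skipped, which turns the update into plain gradient ascent.
    """
    grad = score_fn(x) - lam * (AHA(x) - AHy)
    x_new = x + step * grad
    if rng is not None:
        # Langevin noise with variance 2 * step (real-valued toy case)
        x_new = x_new + np.sqrt(2.0 * step) * rng.standard_normal(np.shape(x))
    return x_new


# toy check: prior N(0, I) (score -x) and A = diag(m); without noise the
# iteration converges to the posterior mode lam * m * y / (1 + lam * m)
m = np.array([1.0, 0.0, 1.0])
y = np.array([2.0, 5.0, -1.0])
x = np.zeros(3)
for _ in range(2000):
    x = posterior_langevin_step(x, lambda v: -v, lambda v: m * v, m * y, step=0.05, lam=4.0)
print(x)  # close to [1.6, 0.0, -0.8]
```

The unmeasured entry (where `m` is 0) is determined by the prior alone, which is exactly the role the learned score plays for unsampled k-space locations. Passing an `rng` turns the deterministic iteration into a sampler.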

In [None]:
## network
x          = tf.placeholder(tf.float32, shape=[None]+model_config['input_shape']) 
t          = tf.placeholder(tf.float32, shape=[None]) 
ins_sde    = sde(model_config)
_          = ins_sde.net.forward(x, t)
all_params = tf.trainable_variables()
saver      = tf.train.Saver()
sess       = tf.Session()

sess.run(tf.global_variables_initializer())
saver.restore(sess, os.path.join(config['model_folder'], config['model_name']))

ins_sampler = posterior_sampler(ins_sde,
                                steps        = config['c_steps'],
                                target_snr   = config['target_snr'],
                                nr_samples   = config['nr_samples'],
                                burn_in      = config['burn_in'],
                                burn_t       = config['burn_t'],
                                # optional keys fall back to their defaults
                                ode          = config.get('ode', False),
                                ext_iter     = config.get('ext_iter', 0),
                                disable_z    = config.get('disable_z', False),
                                use_pixelcnn = config.get('use_pixelcnn', False))

images = ins_sampler.conditional_ancestral_sampler(x, t, sess, AHA, zero_filled[np.newaxis, ...], config['s_stepsize'], st=config['st'])

if config['burn_in']:
    idx = int(ins_sampler.sde.N*config['burn_t']*config['c_steps']) - config['c_steps']
    images = utils.float2cplx(np.array(images[-idx:]))
else:
    images = np.array(images)
    images = images[...,0]+1.0j*images[...,1]

## 3. Results

1. normalize the reconstructions
2. plot the curves to track PSNR and SSIM over iterations
3. compare reconstruction methods
4. plot samples over the intermediate distributions



### a. Normalize reconstruction

In [None]:
# normalize the results
mag_rets = np.abs(images)/np.max(abs(images), axis=(2,3))[..., np.newaxis, np.newaxis]
mag_rss         = abs(rss)
normalized_zero_filled=abs(zero_filled/np.linalg.norm(zero_filled))
normalized_l1_recon=abs(l1_recon/np.linalg.norm(l1_recon))
normalized_rss  = mag_rss/np.linalg.norm(mag_rss)
normalized_rets = mag_rets/np.linalg.norm(mag_rets, axis=(2,3), keepdims=True)

total_steps = mag_rets.shape[0]
step_size = 1

# posterior-mean estimates from the first k samples (k = 1, 2, 4, 8, and all samples)
normalized_expectation_1  = normalized_rets[:total_steps:step_size,0,...]
normalized_expectation_2  = np.mean(normalized_rets[:total_steps:step_size,0:2,...], axis=1)
normalized_expectation_4  = np.mean(normalized_rets[:total_steps:step_size,0:4,...], axis=1)
normalized_expectation_8  = np.mean(normalized_rets[:total_steps:step_size,0:8,...], axis=1)
normalized_expectation_10 = np.mean(normalized_rets[:total_steps:step_size], axis=1)

# calculate psnrs and ssims

psnrs = []
ssims = []
expectations = [normalized_expectation_1, normalized_expectation_2, normalized_expectation_4,
                normalized_expectation_8, normalized_expectation_10]

for i in range(int(total_steps/step_size)):
    # one column per estimate, in the order 1, 2, 4, 8, 10 samples
    psnrs.append([utils.psnr(e[i], normalized_rss) for e in expectations])
    ssims.append([utils.ssim(e[i], normalized_rss) for e in expectations])

psnrs = np.array(psnrs)
ssims = np.array(ssims)
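`utils.psnr` is not shown in this notebook, so its exact definition (data range, magnitude vs. complex input) is an assumption here. A common definition is $\mathrm{PSNR} = 10\log_{10}(\text{range}^2/\mathrm{MSE})$; the sketch below (`psnr` is a hypothetical stand-in) uses the reference peak as the range by default. Because all images above are normalized to unit norm, the reported values depend on that normalization. For SSIM, `skimage.metrics.structural_similarity` from scikit-image is one reference implementation.

```python
import numpy as np


def psnr(img, ref, data_range=None):
    """PSNR in dB on magnitudes: 10 * log10(data_range**2 / MSE)."""
    img, ref = np.abs(np.asarray(img)), np.abs(np.asarray(ref))
    if data_range is None:
        data_range = ref.max()  # assumption: peak of the reference image
    mse = np.mean((img - ref) ** 2)
    if mse == 0:
        return float("inf")
    return float(10.0 * np.log10(data_range ** 2 / mse))
```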

### b. Plot the curves to track PSNR and SSIM over iterations

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
plt.rcParams['text.usetex'] = False
import math

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15,5.7), gridspec_kw={'width_ratios': [1, 1]})
fontsize = 17
ticksize = 12
for i,c in enumerate([1,2,4,8,10]):
    ax1.plot(psnrs[:,i], label= "{} sample".format(c) if c==1 else "{} samples".format(c))
    ax2.plot(ssims[:,i], label= "{} sample".format(c) if c==1 else "{} samples".format(c))

ax1.set_xlabel('iteration', fontsize=fontsize)
ax1.set_ylabel('PSNR', fontsize=fontsize)
ax1.tick_params(labelsize=ticksize) 
ax1.legend(loc='lower right')

ax2.set_xlabel('iteration', fontsize=fontsize)
ax2.set_ylabel('SSIM', fontsize=fontsize)
ax2.tick_params(labelsize=ticksize) 
ax2.legend(loc='lower right')
plt.tight_layout()

### c. Compare reconstruction methods

In [None]:

def subplot(ax, img, title, cmap, interpolation, vmin, vmax):
    ax.imshow(img, cmap=cmap, interpolation=interpolation, vmin=vmin, vmax=vmax)
    ax.set_title(title)
    ax.axis('off')

rss_max = np.max(normalized_rss)
plot_params = {'cmap': 'gray', 'interpolation': 'none', 'vmin': 0, 'vmax': rss_max}
axplot      = partial(subplot, **plot_params)

fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(20, 7.6), gridspec_kw={'width_ratios': [1, 1, 1, 1]})
axplot(ax1, normalized_zero_filled, title='zero filled')
axplot(ax2, normalized_l1_recon, title='l1-ESPIRiT in wavelet domain')
axplot(ax3, normalized_expectation_10[-1], r'$\mathbf{x}_{\mathrm{MMSE}}$')
axplot(ax4, normalized_rss, 'truth')
plt.tight_layout()
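The $\mathbf{x}_{\mathrm{MMSE}}$ panel averages all posterior samples of the last iteration. Averaging $k$ samples is a Monte Carlo estimate of the posterior mean $\mathbb{E}[\mathbf{x}|\mathbf{y}]$, and this is what the 1, 2, 4, 8, and 10 sample curves in the PSNR/SSIM plot compare. A standalone sketch that builds all of these estimates from the first $k$ samples; `mmse_estimates` is a hypothetical helper:

```python
import numpy as np


def mmse_estimates(samples, ks=(1, 2, 4, 8, 10)):
    """Monte Carlo posterior-mean estimates from the first k samples.

    samples: array of shape (nr_samples, nx, ny); returns {k: mean image}.
    """
    samples = np.asarray(samples)
    if max(ks) > samples.shape[0]:
        raise ValueError("requested k=%d but only %d samples" % (max(ks), samples.shape[0]))
    return {k: samples[:k].mean(axis=0) for k in ks}


# usage with the normalized samples of the final iteration:
# estimates = mmse_estimates(normalized_rets[-1])
```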

### d. Create a grid of samples and generate a GIF of the iterations

In [None]:
samples = normalized_rets[::35,...]

fig, axss = plt.subplots(10, 10, figsize=(15, 15), gridspec_kw={'width_ratios': [1 for _ in range(10)]})
for i in range(10):
    for j in range(10):
        if i==0:
            strs='x_%d'%j
        else:
            strs=''
        axplot(axss[i,j], samples[i,j], title=strs)
plt.tight_layout(pad=.1)