# Total Variation Regularized Reconstruction

This notebook reproduces the total variation (TV) regularized reconstruction experiments described in [Accelerating Non-Cartesian MRI Reconstruction Convergence using k-space Preconditioning](https://arxiv.org/abs/1902.09657).

In [1]:
%matplotlib notebook

import sigpy as sp
import sigpy.mri as mr
import sigpy.plot as pl
import matplotlib.pyplot as plt
import numpy as np

# Best-effort: restrict MKL to a single thread when the `mkl` module is
# available (presumably so CPU timings are comparable across solvers).
# Only a missing module is ignored; any other error is surfaced.
try:
    import mkl
    mkl.set_num_threads(1)
except ImportError:
    pass

## Set parameters and load dataset

In [2]:
max_iter = 30     # PDHG iterations; ADMM cells run max_iter // max_cg_iter outer iterations
max_cg_iter = 5   # inner iterations per ADMM outer iteration (passed as max_cg_iter)
lamda = 0.001     # TV regularization weight passed to TotalVariationRecon

ksp_file = 'data/liver/ksp.npy'
coord_file = 'data/liver/coord.npy'

# Choose computing device.
# Device(-1) specifies CPU, while others specify GPUs.
# GPU requires installing cupy.
# Activation is attempted inside the try as well, so we also fall back to the
# CPU when the GPU device object can be created but not used (presumably the
# case of cupy installed without a usable GPU). `except Exception` rather than
# a bare `except:` keeps KeyboardInterrupt/SystemExit working.
try:
    device = sp.Device(0)
    device.use()
except Exception:
    device = sp.Device(-1)
    device.use()

# Array module matching the device (numpy on CPU, cupy on GPU).
xp = device.xp

# Load datasets.
ksp = xp.load(ksp_file)
coord = xp.load(coord_file)

## Estimate sensitivity maps using JSENSE

Here we use [JSENSE](https://onlinelibrary.wiley.com/doi/full/10.1002/mrm.21245) to estimate sensitivity maps.

In [3]:
# Estimate coil sensitivity maps from the non-Cartesian k-space data and
# its trajectory, running on the selected device.
mps = mr.app.JsenseRecon(
    ksp,
    device=device,
    coord=coord,
).run()

HBox(children=(IntProgress(value=0, description='JsenseRecon', max=10, style=ProgressStyle(description_width='…




## ADMM

In [4]:
# Plain ADMM: max_iter // max_cg_iter outer iterations, each with up to
# max_cg_iter inner iterations; objective values are recorded for the
# convergence plot at the end.
admm_app = mr.app.TotalVariationRecon(
    ksp,
    mps,
    coord=coord,
    lamda=lamda,
    solver='ADMM',
    max_iter=max_iter // max_cg_iter,
    max_cg_iter=max_cg_iter,
    save_objective_values=True,
    device=device,
)
admm_img = admm_app.run()

pl.ImagePlot(admm_img)

HBox(children=(IntProgress(value=0, description='TotalVariationRecon', max=6, style=ProgressStyle(description_…




<IPython.core.display.Javascript object>

<sigpy.plot.ImagePlot at 0x7fbf536c64a8>

## ADMM with circulant preconditioner

In [5]:
rho = 1  # ADMM penalty; also passed as `lamda` to circulant_precond below
img_shape = mps.shape[1:]

# Circulant preconditioner for the data term, computed on the selected device.
circ_precond = mr.circulant_precond(mps, coord=coord, device=device, lamda=rho)

# Fourier-domain response of G^H G: apply it to a Dirac impulse, then FFT.
G = sp.linop.FiniteDifference(img_shape)
g = G.H * G * sp.dirac(img_shape)
g = sp.to_device(sp.fft(g), device=device)

# Fold the finite-difference term into the preconditioner in the Fourier domain.
# NOTE(review): this term is weighted by `lamda` while the ADMM penalty is
# `rho`; confirm this matches the intended x-update operator.
circ_precond = 1 / (1 / circ_precond + lamda * g)

# P = IFFT * diag(circ_precond) * FFT applies the preconditioner in image space.
D = sp.linop.Multiply(img_shape, circ_precond)
P = sp.linop.IFFT(img_shape) * D * sp.linop.FFT(img_shape)

admm_cp_app = mr.app.TotalVariationRecon(
        ksp, mps, lamda=lamda, coord=coord, max_iter=max_iter // max_cg_iter,
        P=P, rho=rho,
        solver='ADMM', max_cg_iter=max_cg_iter, device=device, save_objective_values=True)
admm_cp_img = admm_cp_app.run()

pl.ImagePlot(admm_cp_img)

HBox(children=(IntProgress(value=0, description='TotalVariationRecon', max=6, style=ProgressStyle(description_…




<IPython.core.display.Javascript object>

<sigpy.plot.ImagePlot at 0x7fbf4b3a4240>

## Primal-dual hybrid gradient (PDHG) reconstruction

In [6]:
# Unpreconditioned PDHG for max_iter iterations; objective values are kept
# for the convergence comparison.
pdhg_app = mr.app.TotalVariationRecon(
    ksp,
    mps,
    coord=coord,
    lamda=lamda,
    solver='PrimalDualHybridGradient',
    max_iter=max_iter,
    save_objective_values=True,
    device=device,
)
pdhg_img = pdhg_app.run()

pl.ImagePlot(pdhg_img)

HBox(children=(IntProgress(value=0, description='MaxEig', max=30, style=ProgressStyle(description_width='initi…




HBox(children=(IntProgress(value=0, description='TotalVariationRecon', max=30, style=ProgressStyle(description…




<IPython.core.display.Javascript object>

<sigpy.plot.ImagePlot at 0x7fbf536a2630>

## PDHG with density compensation function (DCF) preconditioner

In [7]:
# Compute preconditioner
# The Pipe-Menon DCF is computed from the trajectory only, so replicate it once
# per coil. Sizing the tile reps from the DCF's own ndim yields exactly
# (num_coils,) + dcf.shape, i.e. the k-space shape, rather than a shape that
# depends on the image dimensionality of mps.
precond_dcf = mr.pipe_menon_dcf(coord, device=device)
precond_dcf = xp.tile(precond_dcf, [len(mps)] + [1] * precond_dcf.ndim)
assert precond_dcf.shape == ksp.shape, (precond_dcf.shape, ksp.shape)

# Dual step sizes for the finite-difference block: 1 / (largest eigenvalue of G^H G).
img_shape = mps.shape[1:]
G = sp.linop.FiniteDifference(img_shape)
max_eig_G = sp.app.MaxEig(G.H * G).run()
sigma2 = xp.ones([sp.prod(img_shape) * len(img_shape)],
                 dtype=ksp.dtype) / max_eig_G

# Flattened dual step sizes: k-space block first, then the finite-difference block.
sigma = xp.concatenate([precond_dcf.ravel(), sigma2.ravel()])

pdhg_dcf_app = mr.app.TotalVariationRecon(
        ksp, mps, lamda=lamda, coord=coord, sigma=sigma, max_iter=max_iter,
        solver='PrimalDualHybridGradient', device=device, save_objective_values=True)
pdhg_dcf_img = pdhg_dcf_app.run()

pl.ImagePlot(pdhg_dcf_img)

HBox(children=(IntProgress(value=0, description='PipeMenonDCF', max=30, style=ProgressStyle(description_width=…




HBox(children=(IntProgress(value=0, description='MaxEig', max=30, style=ProgressStyle(description_width='initi…




HBox(children=(IntProgress(value=0, description='MaxEig', max=30, style=ProgressStyle(description_width='initi…




HBox(children=(IntProgress(value=0, description='TotalVariationRecon', max=30, style=ProgressStyle(description…




<IPython.core.display.Javascript object>

<sigpy.plot.ImagePlot at 0x7fbf4b29ba20>

## PDHG with single-channel k-space preconditioner

In [8]:
# Compute preconditioner
# Single-channel variant: use uniform maps (all ones scaled by 1/sqrt(num_coils))
# in place of the estimated sensitivity maps. xp keeps the array on the same
# device as mps (np.ones_like would target numpy even when mps is on the GPU).
mps_uniform = xp.ones_like(mps)
mps_uniform /= len(mps)**0.5
precond_sc = mr.kspace_precond(mps_uniform, coord=coord, device=device)

# Dual step sizes for the finite-difference block: 1 / (largest eigenvalue of G^H G).
# G is defined here so this cell does not rely on state left by an earlier cell.
img_shape = mps.shape[1:]
G = sp.linop.FiniteDifference(img_shape)
max_eig_G = sp.app.MaxEig(G.H * G).run()
sigma2 = xp.ones([sp.prod(img_shape) * len(img_shape)],
                 dtype=ksp.dtype) / max_eig_G

# Flattened dual step sizes: k-space block first, then the finite-difference block.
# NOTE(review): unlike the DCF and multi-channel cells, sigma is halved here;
# presumably a deliberate step-size choice for this preconditioner — confirm.
sigma = xp.concatenate([precond_sc.ravel(), sigma2.ravel()]) / 2

pdhg_sc_app = mr.app.TotalVariationRecon(
        ksp, mps, lamda=lamda, coord=coord, sigma=sigma, max_iter=max_iter,
        solver='PrimalDualHybridGradient', device=device, save_objective_values=True)
pdhg_sc_img = pdhg_sc_app.run()

pl.ImagePlot(pdhg_sc_img)

HBox(children=(IntProgress(value=0, description='MaxEig', max=30, style=ProgressStyle(description_width='initi…




HBox(children=(IntProgress(value=0, description='MaxEig', max=30, style=ProgressStyle(description_width='initi…




HBox(children=(IntProgress(value=0, description='TotalVariationRecon', max=30, style=ProgressStyle(description…




<IPython.core.display.Javascript object>

<sigpy.plot.ImagePlot at 0x7fbf4b213518>

## PDHG with multi-channel k-space preconditioner

In [9]:
# Compute preconditioner
# Multi-channel variant: computed from the estimated sensitivity maps.
precond_mc = mr.kspace_precond(mps, coord=coord, device=device)

# Dual step sizes for the finite-difference block: 1 / (largest eigenvalue of G^H G).
# G is defined here so this cell does not rely on state left by an earlier cell.
img_shape = mps.shape[1:]
G = sp.linop.FiniteDifference(img_shape)
max_eig_G = sp.app.MaxEig(G.H * G).run()
sigma2 = xp.ones([sp.prod(img_shape) * len(img_shape)],
                 dtype=ksp.dtype) / max_eig_G

# Flattened dual step sizes: k-space block first, then the finite-difference block.
sigma = xp.concatenate([precond_mc.ravel(), sigma2.ravel()])

pdhg_mc_app = mr.app.TotalVariationRecon(
        ksp, mps, lamda=lamda, coord=coord, sigma=sigma, max_iter=max_iter,
        solver='PrimalDualHybridGradient', device=device, save_objective_values=True)
pdhg_mc_img = pdhg_mc_app.run()

pl.ImagePlot(pdhg_mc_img)

HBox(children=(IntProgress(value=0, description='MaxEig', max=30, style=ProgressStyle(description_width='initi…




HBox(children=(IntProgress(value=0, description='MaxEig', max=30, style=ProgressStyle(description_width='initi…




HBox(children=(IntProgress(value=0, description='TotalVariationRecon', max=30, style=ProgressStyle(description…




<IPython.core.display.Javascript object>

<sigpy.plot.ImagePlot at 0x7fbf4b1a1898>

## Convergence curves

In [10]:
# One entry per reconstruction: (app, marker, color, legend label).
# Keeping the label next to its app prevents the legend drifting out of
# sync with the plotted curves.
convergence_curves = [
    (admm_app, 'v', 'C1', 'ADMM'),
    (admm_cp_app, '^', 'C2', 'ADMM w/ circulant precond.'),
    (pdhg_app, '+', 'C3', 'PDHG'),
    (pdhg_dcf_app, 's', 'C4', 'PDHG w/ density comp.'),
    (pdhg_sc_app, '*', 'C5', 'PDHG w/ SC k-space precond.'),
    (pdhg_mc_app, 'x', 'C6', 'PDHG w/ MC k-space precond.'),
]

# Objective value (log scale) against wall-clock time recorded by each app.
fig, ax = plt.subplots(figsize=(8, 3))
for recon_app, marker, color, label in convergence_curves:
    ax.semilogy(recon_app.time, recon_app.objective_values,
                marker=marker, color=color, label=label)
ax.legend()
ax.set_ylabel('Objective Value [a.u.]')
ax.set_xlabel('Time [s]')
ax.set_title('Total Variation Regularized Reconstruction')
fig.tight_layout()
plt.show()

<IPython.core.display.Javascript object>