# Building an L1 Wavelet Reconstruction

In this notebook, we will create an L1 wavelet regularized reconstruction.

## Setup

Let us import the relevant packages and load a brain dataset.

In [None]:
%matplotlib notebook
import numpy as np
import sigpy as sp
import sigpy.mri as mr
import sigpy.plot as pl
import matplotlib.pyplot as plt

In [None]:
img = np.load('data/brain_img.npy')

fig, ax = plt.subplots()
ax.imshow(np.abs(img), cmap='gray')
ax.set_title('Ground Truth')

In [None]:
ksp = np.load('data/brain_ksp.npy')

fig, ax = plt.subplots()
ax.imshow(np.abs(ksp)**0.1, cmap='gray')  # power 0.1 compresses the dynamic range for display
ax.set_title('k-space magnitude (power 0.1)')

# Naive Reconstruction

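Simply apply a centered inverse FFT to the k-space data. Unsampled k-space locations are zero, so this is a zero-filled reconstruction.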

In [None]:
img_naive = np.fft.ifftshift(ksp)
img_naive = np.fft.ifftn(img_naive, norm='ortho')
img_naive = np.fft.fftshift(img_naive)

fig, ax = plt.subplots()
ax.imshow(np.abs(img_naive), cmap='gray')
ax.set_title('Naive Reconstruction')

# L1 Wavelet Reconstruction

The L1 wavelet regularization reconstruction solves the following problem:
$$\min_x \frac{1}{2} \| S F W^H x - y \|_2^2 + \lambda \| x \|_1$$
where $S$ is the sampling operator, $F$ is the Fourier transform operator, $W^H$ is the inverse wavelet transform operator, $x$ denotes the wavelet coefficients, $y$ denotes the acquired k-space measurements, and $\lambda$ is the regularization parameter.

The reconstructed image is then given by $W^H x$.
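
To make the formula concrete, here is a minimal sketch that evaluates the objective for a given forward operator. It assumes the forward model is passed as a plain Python callable `A` standing for $S F W^H$ (in this notebook that would be something like `lambda x: S(F(W.H(x)))`); the helper name `l1_objective` and the toy data are hypothetical.

```python
import numpy as np

def l1_objective(x, A, y, lamda):
    """Evaluate 0.5 * ||A(x) - y||_2^2 + lamda * ||x||_1.

    A is any callable mapping wavelet coefficients to k-space
    (assumed to stand for S F W^H).
    """
    residual = A(x) - y
    data_term = 0.5 * np.linalg.norm(residual.ravel()) ** 2
    # For complex x, np.abs gives the modulus, so this is the complex L1 norm.
    reg_term = lamda * np.sum(np.abs(x))
    return data_term + reg_term

# Toy check with the identity operator (hypothetical data).
x = np.array([1.0, -2.0, 0.0])
y = np.array([1.0, 0.0, 0.0])
print(l1_objective(x, lambda v: v, y, lamda=0.5))  # 3.5
```

The two terms pull in opposite directions: the data term favors coefficients that reproduce the measurements, while the L1 term favors sparse coefficients.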

We will build the L1 wavelet reconstruction in three steps:

- Create the linear operators $S$, $F$, and $W$
- Create the soft-threshold function
- Run proximal gradient descent

## Linear operators (Linop)

In the following, we will create the necessary linear operators $F$, $S$, and $W$.

### $F$ Linop
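To start, we will create a centered FFT linear operator class. It applies an orthonormal FFT (`norm='ortho'`) over all axes, using `ifftshift`/`fftshift` so that the zero frequency sits at the center of the array, and it takes no constructor arguments. We will also apply its adjoint, which is the inverse FFT, to the k-space array `ksp` to verify that it works.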

In [None]:
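class FFTLinop(object):
    """Centered, orthonormal FFT over all axes."""

    def __call__(self, x):
        # Forward: move the center to index 0, FFT, then re-center.
        y = np.fft.ifftshift(x)
        y = np.fft.fftn(y, norm='ortho')
        y = np.fft.fftshift(y)
        return y

    def H(self, x):
        # Adjoint: with norm='ortho' the FFT is unitary, so the adjoint is the inverse FFT.
        y = np.fft.ifftshift(x)
        y = np.fft.ifftn(y, norm='ortho')
        y = np.fft.fftshift(y)
        return y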

F = FFTLinop()

In [None]:
fig, ax = plt.subplots()
ax.imshow(np.abs(F.H(ksp)), cmap='gray')
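
A quick numerical sanity check for `FFTLinop`: with `norm='ortho'` the centered FFT should be unitary, so it preserves the L2 norm and its adjoint undoes it. The sketch below re-creates the same shift/FFT convention as standalone functions (the names `cfft` and `icfft` are illustrative) and checks both properties on a random complex array.

```python
import numpy as np

def cfft(x):
    # Centered, orthonormal FFT over all axes (same convention as FFTLinop.__call__).
    return np.fft.fftshift(np.fft.fftn(np.fft.ifftshift(x), norm='ortho'))

def icfft(x):
    # Centered, orthonormal inverse FFT (same convention as FFTLinop.H).
    return np.fft.fftshift(np.fft.ifftn(np.fft.ifftshift(x), norm='ortho'))

rng = np.random.default_rng(0)
x = rng.standard_normal((64, 64)) + 1j * rng.standard_normal((64, 64))

# Unitary: the norm is preserved and the adjoint is the inverse.
print(np.allclose(np.linalg.norm(cfft(x)), np.linalg.norm(x)))  # True
print(np.allclose(icfft(cfft(x)), x))                           # True
```

Because the shifts are pure permutations, they do not affect unitarity; they only decide where the zero frequency is stored.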

### $S$ Linop

Given the sampling mask, this operator simply multiplies the input array by the mask. We will estimate the sampling mask from the non-zero entries of the k-space array, which assumes that unsampled locations are stored as zeros.

In [None]:
mask = np.abs(ksp) > 0

fig, ax = plt.subplots()
ax.imshow(mask, cmap='gray')

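This can also be wrapped in a linear operator. Because the mask is real and binary, the operator is self-adjoint, so `H` is simply the same function as `__call__`.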

In [None]:
class SamplingLinop(object):

    def __call__(self, x):
        return mask * x

    H = __call__
    
S = SamplingLinop()
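
Since the mask is real and binary, $S$ is self-adjoint ($\langle Sx, y\rangle = \langle x, Sy\rangle$) and idempotent ($SS = S$). The sketch below checks both properties on a synthetic random mask; the mask density and array size are arbitrary choices for illustration, not the brain dataset's actual sampling pattern.

```python
import numpy as np

rng = np.random.default_rng(1)
shape = (32, 32)
# Hypothetical random undersampling mask (about 30% of samples kept).
mask = rng.random(shape) < 0.3

def S(x):
    # Sampling operator: zero out unacquired k-space locations.
    return mask * x

x = rng.standard_normal(shape) + 1j * rng.standard_normal(shape)
y = rng.standard_normal(shape) + 1j * rng.standard_normal(shape)

# Self-adjoint: <S x, y> == <x, S y>  (np.vdot conjugates its first argument).
print(np.isclose(np.vdot(S(x), y), np.vdot(x, S(y))))  # True
# Idempotent: applying the mask twice is the same as applying it once.
print(np.array_equal(S(S(x)), S(x)))  # True
```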

### $W$ Linop

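We will create a wavelet transform operator with PyWavelets (`pywt`): a 3-level Daubechies-4 (`db4`) transform with periodic boundary extension. SigPy also provides a [Wavelet](https://sigpy.readthedocs.io/en/latest/generated/sigpy.linop.Wavelet.html#sigpy.linop.Wavelet) Linop, which takes the input array shape as an argument and uses `db4` by default; here we build the operator directly with `pywt`. Let us first apply the transform to the ground-truth image to see if the result makes sense.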

In [None]:
import pywt

coeff = pywt.wavedecn(img, 'db4', mode='periodic', level=3)
coeff, coeff_slices = pywt.coeffs_to_array(coeff)

fig, ax = plt.subplots()
ax.imshow(np.abs(coeff)**0.1, cmap='gray')

In [None]:
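class WaveletLinop(object):

    def __init__(self, shape, wave_name='db4', level=3):
        self.wave_name = wave_name
        self.level = level
        # Transform a dummy array once to record the slices that map the
        # flattened coefficient array back to the wavelet sub-bands, so that
        # H works even before __call__ has been used.
        dummy = pywt.wavedecn(np.zeros(shape), wave_name, mode='periodic', level=level)
        _, self.coeff_slices = pywt.coeffs_to_array(dummy)

    def __call__(self, x):
        coeff = pywt.wavedecn(x, self.wave_name, mode='periodic', level=self.level)
        coeff, _ = pywt.coeffs_to_array(coeff)
        return coeff

    def H(self, x):
        y = pywt.array_to_coeffs(x, self.coeff_slices, output_format='wavedecn')
        y = pywt.waverecn(y, self.wave_name, mode='periodic')
        return y

W = WaveletLinop(img.shape)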
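
A hand-written `H` method is only a valid adjoint if $\langle W x, c\rangle = \langle x, W^H c\rangle$ for all $x$ and $c$. The standard way to check this is the dot-product test. The sketch below demonstrates it on a hypothetical single-level orthonormal Haar transform written in NumPy, a stand-in for `db4` and not the `pywt` implementation; the helper `dot_product_test` is likewise illustrative. The same helper could be pointed at `W` and `W.H` (with `x_shape=img.shape` and `y_shape=coeff.shape`); whether it passes for `mode='periodic'`, which produces redundant coefficients, is worth checking.

```python
import numpy as np

def haar_fwd(x):
    # Single-level orthonormal Haar transform of an even-length 1D array:
    # first half = scaled pairwise sums, second half = scaled pairwise differences.
    even, odd = x[0::2], x[1::2]
    return np.concatenate([even + odd, even - odd]) / np.sqrt(2)

def haar_adj(c):
    # Adjoint of haar_fwd (equal to its inverse, since the transform is orthonormal).
    n = c.size // 2
    a, d = c[:n], c[n:]
    x = np.empty(2 * n, dtype=np.result_type(c, float))
    x[0::2] = (a + d) / np.sqrt(2)
    x[1::2] = (a - d) / np.sqrt(2)
    return x

def dot_product_test(A, AH, x_shape, y_shape, rng, tol=1e-10):
    # Check <A x, y> == <x, AH y> for random complex x and y.
    x = rng.standard_normal(x_shape) + 1j * rng.standard_normal(x_shape)
    y = rng.standard_normal(y_shape) + 1j * rng.standard_normal(y_shape)
    lhs = np.vdot(A(x), y)
    rhs = np.vdot(x, AH(y))
    return abs(lhs - rhs) <= tol * max(abs(lhs), 1.0)

rng = np.random.default_rng(2)
print(dot_product_test(haar_fwd, haar_adj, (16,), (16,), rng))  # True
```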

# Soft-threshold

Proximal operators are abstracted in SigPy's [Prox](https://sigpy.readthedocs.io/en/latest/generated/sigpy.prox.Prox.html#sigpy.prox.Prox) class. Given a function $g(x)$, a proximal operator takes a scalar $\alpha$ and an input array $y$ and computes:
$$\text{prox}_{\alpha g} (y) = \text{argmin}_x \frac{1}{2} \| x - y \|_2^2 + \alpha g(x)$$

Here, our function $g(x) = \lambda \| x \|_1$ is a scaled L1-norm function. We can use the [L1Reg](https://sigpy.readthedocs.io/en/latest/generated/sigpy.prox.L1Reg.html#sigpy.prox.L1Reg) Prox, which performs a soft-thresholding operation. It takes the input array shape and the regularization parameter $\lambda$ as arguments.
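
For reference, with $g(x) = \lambda \| x \|_1$ the minimization separates across entries, and for each entry the minimizer keeps the phase of $y_i$ while shrinking its magnitude by $\alpha\lambda$. For complex-valued entries (as with MRI data) this gives:

```latex
\left[\text{prox}_{\alpha \lambda \| \cdot \|_1}(y)\right]_i =
\begin{cases}
\dfrac{y_i}{|y_i|} \max\left(|y_i| - \alpha \lambda,\ 0\right), & y_i \neq 0 \\
0, & y_i = 0
\end{cases}
```

For real $y_i$, the factor $y_i / |y_i|$ reduces to $\operatorname{sign}(y_i)$. This is the operation `soft_thresh(y, lamda)` implements, with threshold `lamda` playing the role of $\alpha\lambda$.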

Rather than using `L1Reg`, we will implement soft-thresholding directly in NumPy and apply it to the wavelet coefficients.

In [None]:
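def soft_thresh(y, lamda):
    # Shrink magnitudes by lamda, clamping at zero.
    mag = np.maximum(np.abs(y) - lamda, 0)

    # Unit-magnitude phase (sign) of y. Entries where y == 0 are set to 0
    # explicitly, because np.divide leaves masked-out entries uninitialized.
    sign = np.divide(y, np.abs(y), out=np.zeros_like(y), where=y != 0)
    return mag * sign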
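
A small usage check on hand-picked values (chosen for illustration) shows the expected behavior: magnitudes shrink by `lamda`, entries with magnitude below `lamda` become zero, and complex entries keep their phase. The function is repeated here so the snippet runs on its own. The real-valued call yields 2, 0, 0, and -1 (up to the sign of zero), and the complex entry `3+4j` (magnitude 5) shrinks to magnitude 4 along the same direction, i.e. `2.4+3.2j`.

```python
import numpy as np

def soft_thresh(y, lamda):
    # Shrink magnitudes by lamda, clamping at zero.
    mag = np.maximum(np.abs(y) - lamda, 0)
    # Phase of y, with zeros handled explicitly via out=.
    sign = np.divide(y, np.abs(y), out=np.zeros_like(y), where=y != 0)
    return mag * sign

real_out = soft_thresh(np.array([3.0, -0.5, 0.0, -2.0]), 1.0)
complex_out = soft_thresh(np.array([3 + 4j]), 1.0)
print(real_out)
print(complex_out)
```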

In [None]:
lamda = 0.01
coeff_thresh = soft_thresh(coeff, lamda)

fig, ax = plt.subplots()
ax.imshow(np.abs(coeff_thresh)**0.1, cmap='gray')

In [None]:
img_thresh = W.H(coeff_thresh)

fig, ax = plt.subplots()
ax.imshow(np.abs(img_thresh), cmap='gray')

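# Proximal Gradient Descent

Each iteration takes a gradient step of size `alpha` on the data-consistency term and then applies the soft-threshold with threshold `lamda * alpha`.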

In [None]:
max_iter = 100
alpha = 1  # step size

def gradf(x):
    # Gradient of 0.5 * ||S F W^H x - y||_2^2, i.e. W F^H S^H (S F W^H x - y)
    return W(F.H(S.H(S(F(W.H(x))) - ksp)))
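
The gradient that `gradf` computes follows from the chain rule applied to the data-consistency term $f(x) = \frac{1}{2} \| S F W^H x - y \|_2^2$; reading the operators right to left matches the nested calls in `gradf`. Each iteration then takes a gradient step followed by the soft-threshold, which is the proximal gradient method (also known as ISTA):

```latex
\nabla f(x) = \left(S F W^H\right)^H \left(S F W^H x - y\right) = W F^H S^H \left(S F W^H x - y\right)

x^{k+1} = \text{prox}_{\alpha \lambda \| \cdot \|_1}\left(x^k - \alpha \nabla f(x^k)\right)
```

Convergence is guaranteed for step sizes $\alpha \le 1/L$, where $L = \| S F W^H \|_2^2$. Since $S$ is a binary mask and $F$ is unitary, $L \le \| W^H \|_2^2$, which equals 1 when $W$ is orthonormal; this is consistent with the choice `alpha = 1`.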

In [None]:
from tqdm.auto import tqdm

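# Start from zero, using the complex dtype of the k-space data so that the
# in-place update below can accept the complex-valued gradient.
coeff_hat = np.zeros_like(coeff, dtype=ksp.dtype)

for it in tqdm(range(max_iter)):
    # Gradient step on the data-consistency term
    coeff_hat -= alpha * gradf(coeff_hat)
    # Proximal step: soft-threshold with threshold lamda * alpha
    coeff_hat = soft_thresh(coeff_hat, lamda * alpha)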
    
img_hat = W.H(coeff_hat)
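
To see the same proximal gradient loop on a problem small enough to inspect, the sketch below runs it on a synthetic sparse-recovery problem. The random real matrix `A` is a stand-in for $S F W^H$ (so its adjoint is simply `A.T`), and all sizes, the seed, and `lamda` are arbitrary choices for illustration. With step size $1/L$, the ISTA objective never increases, which the sketch records in `history`.

```python
import numpy as np

def soft_thresh(y, lamda):
    mag = np.maximum(np.abs(y) - lamda, 0)
    sign = np.divide(y, np.abs(y), out=np.zeros_like(y), where=y != 0)
    return mag * sign

rng = np.random.default_rng(0)
m, n = 40, 100
A = rng.standard_normal((m, n)) / np.sqrt(m)    # stand-in for S F W^H
x_true = np.zeros(n)
x_true[rng.choice(n, 5, replace=False)] = rng.standard_normal(5)
y = A @ x_true                                  # synthetic measurements

lamda = 0.01
alpha = 1 / np.linalg.norm(A, 2) ** 2           # step size 1/L with L = ||A||_2^2

def objective(x):
    return 0.5 * np.linalg.norm(A @ x - y) ** 2 + lamda * np.sum(np.abs(x))

x = np.zeros(n)
history = [objective(x)]
for _ in range(500):
    grad = A.T @ (A @ x - y)                          # gradient of the data term
    x = soft_thresh(x - alpha * grad, lamda * alpha)  # gradient step, then prox
    history.append(objective(x))

print(history[0], history[-1])
```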

In [None]:
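fig, (ax0, ax1, ax2) = plt.subplots(3, figsize=(5, 15))
ax0.imshow(np.abs(img), cmap='gray')
ax0.set_title('Ground Truth')
ax1.imshow(np.abs(img_naive), cmap='gray')
ax1.set_title('Naive Reconstruction')
ax2.imshow(np.abs(img_hat), cmap='gray')
ax2.set_title('L1 Wavelet Reconstruction')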
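
Beyond visual comparison, one simple quantitative check is the normalized root-mean-square error (NRMSE) of each reconstruction's magnitude against the ground truth. The helper `nrmse` below is a hypothetical addition, and comparing magnitudes is an assumption that sidesteps any global phase differences. With the notebook's variables it would be called as `nrmse(img_naive, img)` and `nrmse(img_hat, img)`; the toy arrays here are purely illustrative.

```python
import numpy as np

def nrmse(recon, ref):
    # ||abs(recon) - abs(ref)||_2 / ||abs(ref)||_2 over all pixels.
    recon_mag, ref_mag = np.abs(recon), np.abs(ref)
    return np.linalg.norm(recon_mag - ref_mag) / np.linalg.norm(ref_mag)

# Toy example with synthetic arrays (illustrative only).
ref = np.ones((4, 4))
print(nrmse(ref, ref))        # 0.0
print(nrmse(1.1 * ref, ref))  # about 0.1
```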
