# Fourier Compressed Sensing

In this notebook, we will create an L1 wavelet regularized reconstruction.

## Setup

Let us import relevant packages and load a brain dataset.

In [None]:
%matplotlib notebook
import numpy as np
import matplotlib.pyplot as plt

In [None]:
img = np.load('data/brain_img.npy')

fig, ax = plt.subplots()
ax.imshow(np.abs(img), cmap='gray')
ax.set_title('Ground Truth')

In [None]:
ksp = np.load('data/brain_ksp.npy')

fig, ax = plt.subplots()
ax.imshow(np.abs(ksp)**0.1, cmap='gray')  # power-law compression reveals the weak high frequencies
ax.set_title('k-space magnitude (|ksp|^0.1)')

# Naive Reconstruction

The simplest reconstruction is a centered inverse FFT of the zero-filled k-space.

In [None]:
img_naive = np.fft.ifftshift(ksp)
img_naive = np.fft.ifftn(img_naive, norm='ortho')
img_naive = np.fft.fftshift(img_naive)

fig, ax = plt.subplots()
ax.imshow(np.abs(img_naive), cmap='gray')
ax.set_title('Naive Reconstruction')

# L1 Wavelet Reconstruction

The L1 wavelet regularization reconstruction solves the following problem:
$$\min_x \frac{1}{2} \| S F W x - y \|_2^2 + \lambda \| x \|_1$$
where $S$ is the sampling operator, $F$ is the Fourier transform operator, $W$ is the inverse wavelet transform operator, $x$ are the wavelet coefficients, and $y$ is the acquired k-space data.

The reconstructed image is then $W x$.
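To make the objective concrete, here is a minimal sketch that evaluates it for a given coefficient array. It assumes `A` is any callable implementing $S F W$ (such as the composed operator built later in this notebook); the function name `l1_objective` is hypothetical. Tracking this value across iterations is a simple way to check that a solver is making progress.

```python
import numpy as np

def l1_objective(A, x, y, lamda):
    """Evaluate 0.5 * ||A(x) - y||_2^2 + lamda * ||x||_1.

    A     : callable applying the forward model (here S F W)
    x     : wavelet coefficients (real or complex array)
    y     : measured k-space array
    lamda : regularization weight
    """
    residual = A(x) - y
    data_term = 0.5 * np.linalg.norm(residual.ravel()) ** 2
    reg_term = lamda * np.sum(np.abs(x))
    return data_term + reg_term

# Toy usage with an identity "forward model"
x = np.array([3.0, -4.0])
y = np.array([1.0, 0.0])
print(l1_objective(lambda v: v, x, y, lamda=0.5))
```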

We will build the reconstruction in three steps:

- Create the linear operators $S$, $F$, $W$
- Create the soft-threshold function
- Run proximal gradient descent

## Linear operators (Linop)

In the following, we create the necessary linear operators $F$, $S$, and $W$.

First, we create a generic Linop class that supports the forward operation, the adjoint, and composition:

    A(x)
    A.H(x)
    C = A * B

In [None]:
class Linop(object):
    def __call__(self, x):
        return self._forward(x)
    
    def H(self, x):
        return self._adjoint(x)
    
    def __mul__(self, B):
        return Compose(self, B)
    
    def _forward(self, x):
        raise NotImplementedError

    def _adjoint(self, x):
        raise NotImplementedError

In [None]:
class Compose(Linop):
    def __init__(self, A, B):
        self.A = A
        self.B = B
        
    def _forward(self, x):
        return self.A(self.B(x))
    
    def _adjoint(self, x):
        return self.B.H(self.A.H(x))
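`Compose._adjoint` applies `B.H` after `A.H` because the adjoint of a product reverses the order of its factors, $(AB)^H = B^H A^H$. The short check below confirms this identity numerically, with random complex matrices (`A_mat`, `B_mat`, illustrative only) standing in for the operators.

```python
import numpy as np

rng = np.random.default_rng(0)

def crandn(*shape):
    """Random complex Gaussian array."""
    return rng.standard_normal(shape) + 1j * rng.standard_normal(shape)

A_mat = crandn(4, 3)  # maps C^3 -> C^4
B_mat = crandn(3, 5)  # maps C^5 -> C^3

AB_H = (A_mat @ B_mat).conj().T             # adjoint of the composition
B_H_A_H = B_mat.conj().T @ A_mat.conj().T   # reversed product of the adjoints
print(np.allclose(AB_H, B_H_A_H))
```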

### $F$ Linop
To start, we create an FFT linear operator class. It applies a centered, orthonormal FFT over all axes of its input: `ifftshift` moves the array center to index 0, `fftn` transforms, and `fftshift` moves the zero frequency back to the center. We then apply its adjoint, the inverse FFT, to the k-space array `ksp` to verify that it works.

In [None]:
class FFTLinop(Linop):
    
    def _forward(self, x):
        y = np.fft.ifftshift(x)
        y = np.fft.fftn(y, norm='ortho')
        y = np.fft.fftshift(y)
        return y

    def _adjoint(self, x):
        y = np.fft.ifftshift(x)
        y = np.fft.ifftn(y, norm='ortho')
        y = np.fft.fftshift(y)
        return y

F = FFTLinop()
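Because both directions use `norm='ortho'`, the centered FFT is unitary: it preserves the $\ell_2$ norm (Parseval's theorem) and its inverse equals its adjoint, which is why `_adjoint` can simply call `ifftn`. A quick numerical check on a random complex array, using standalone functions that mirror `FFTLinop` (the names `centered_fft` and `centered_ifft` are illustrative):

```python
import numpy as np

def centered_fft(x):
    """Orthonormal FFT with the zero frequency at the array center."""
    return np.fft.fftshift(np.fft.fftn(np.fft.ifftshift(x), norm='ortho'))

def centered_ifft(k):
    """Inverse (and adjoint) of centered_fft."""
    return np.fft.fftshift(np.fft.ifftn(np.fft.ifftshift(k), norm='ortho'))

rng = np.random.default_rng(0)
x = rng.standard_normal((8, 6)) + 1j * rng.standard_normal((8, 6))
k = centered_fft(x)

print(np.isclose(np.linalg.norm(k), np.linalg.norm(x)))  # energy is preserved
print(np.allclose(centered_ifft(k), x))                   # exact inversion
```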

In [None]:
fig, ax = plt.subplots()
ax.imshow(np.abs(F.H(ksp)), cmap='gray')
ax.set_title('Naive Reconstruction using Linop')

### $S$ Linop

Given the sampling mask, this operator simply multiplies the input array by the mask. We estimate the sampling mask from the nonzero entries of the k-space array.

In [None]:
mask = np.abs(ksp) > 0

fig, ax = plt.subplots()
ax.imshow(mask, cmap='gray')
ax.set_title('Sampling Mask')

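The multiplication by the mask can also be wrapped in a linear operator. Because the mask is real and binary, the operator is self-adjoint, so its adjoint is the forward operation itself.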

In [None]:
class SamplingLinop(Linop):

    def _forward(self, x):
        return mask * x

    _adjoint = _forward
    
S = SamplingLinop()
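Since the mask is binary, $S$ is an orthogonal projection: it is self-adjoint, and applying it twice is the same as applying it once ($S^2 = S$). The fraction of sampled k-space locations also gives the acceleration factor $R$. A small sketch with a made-up mask (`toy_mask` is illustrative, not the dataset's mask):

```python
import numpy as np

# Hypothetical toy mask: keep every 4th row plus two rows at the center
toy_mask = np.zeros((16, 16), dtype=bool)
toy_mask[::4, :] = True
toy_mask[7:9, :] = True

rng = np.random.default_rng(0)
x = rng.standard_normal((16, 16)) + 1j * rng.standard_normal((16, 16))

Sx = toy_mask * x
print(np.array_equal(toy_mask * Sx, Sx))  # idempotent: S(S(x)) == S(x)

# Acceleration factor: total locations / sampled locations
R = toy_mask.size / toy_mask.sum()
print(R)
```

On the notebook's data, `mask.size / mask.sum()` gives the acceleration factor of the acquisition in the same way.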

### $W$ Linop

We will perform wavelet transforms using the `pywt` (PyWavelets) library. The multilevel decomposition function `wavedecn` returns a list: the coarsest approximation array, followed by one dictionary of detail subbands per level. We pack these into a single array with `coeffs_to_array`, which also returns the slices (`coeff_slices`) needed to unpack them again with `array_to_coeffs`.

In [None]:
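import pywt

# 'periodization' makes the db4 transform orthonormal (as many coefficients as
# pixels), so the decomposition is the exact adjoint of the reconstruction.
# Plain 'periodic' mode produces extra boundary coefficients, and the adjoint
# is then only approximate.
coeff = pywt.wavedecn(img, 'db4', mode='periodization', level=3)
coeff, coeff_slices = pywt.coeffs_to_array(coeff)

fig, ax = plt.subplots()
ax.imshow(np.abs(coeff)**0.3, cmap='gray')
ax.set_title('Wavelet Coefficients')

In [None]:
class InverseWaveletLinop(Linop):
    """W: maps packed wavelet coefficients to an image."""

    def _forward(self, x):
        y = pywt.array_to_coeffs(x, coeff_slices, output_format='wavedecn')
        y = pywt.waverecn(y, 'db4', mode='periodization')
        return y

    def _adjoint(self, x):
        # For an orthonormal wavelet, the adjoint of the inverse transform is
        # the forward transform
        coeff = pywt.wavedecn(x, 'db4', mode='periodization', level=3)
        coeff, _ = pywt.coeffs_to_array(coeff)
        return coeff
    
W = InverseWaveletLinop()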
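Using the decomposition as the adjoint of the reconstruction relies on the wavelet transform being orthonormal, as it is for an orthogonal wavelet such as `db4` with periodization boundary handling. To see why, here is a self-contained single-level 2D Haar transform written directly in NumPy. Haar is swapped in for `db4` for brevity, and `haar2d`/`ihaar2d` are illustrative names, not `pywt` functions. Subband names follow "first letter for axis 0, second for axis 1", with `a` for approximation and `d` for detail.

```python
import numpy as np

def haar2d(x):
    """One level of an orthonormal 2D Haar transform (even-sized input)."""
    # Axis 0: scaled pairwise sums (approximation) and differences (detail)
    a0 = (x[0::2, :] + x[1::2, :]) / np.sqrt(2)
    d0 = (x[0::2, :] - x[1::2, :]) / np.sqrt(2)

    def split1(y):
        # Same operation along axis 1
        return ((y[:, 0::2] + y[:, 1::2]) / np.sqrt(2),
                (y[:, 0::2] - y[:, 1::2]) / np.sqrt(2))

    aa, ad = split1(a0)
    da, dd = split1(d0)
    # Pack the four subbands into one array
    return np.block([[aa, ad], [da, dd]])

def ihaar2d(c):
    """Inverse of haar2d; since haar2d is orthonormal, also its adjoint."""
    n0, n1 = c.shape[0] // 2, c.shape[1] // 2
    aa, ad = c[:n0, :n1], c[:n0, n1:]
    da, dd = c[n0:, :n1], c[n0:, n1:]

    def merge1(a, d):
        # Undo the axis-1 split
        y = np.empty((a.shape[0], 2 * a.shape[1]), dtype=np.result_type(a, d))
        y[:, 0::2] = (a + d) / np.sqrt(2)
        y[:, 1::2] = (a - d) / np.sqrt(2)
        return y

    a0, d0 = merge1(aa, ad), merge1(da, dd)
    # Undo the axis-0 split
    x = np.empty((2 * a0.shape[0], a0.shape[1]), dtype=np.result_type(a0, d0))
    x[0::2, :] = (a0 + d0) / np.sqrt(2)
    x[1::2, :] = (a0 - d0) / np.sqrt(2)
    return x

rng = np.random.default_rng(0)
x = rng.standard_normal((8, 8))
c = rng.standard_normal((8, 8))
print(np.allclose(ihaar2d(haar2d(x)), x))                          # perfect reconstruction
print(np.isclose(np.vdot(haar2d(x), c), np.vdot(x, ihaar2d(c))))  # <Hx, c> == <x, H^T c>
```

A constant image has energy only in the `aa` subband, which is why smooth regions give sparse wavelet coefficients.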

In [None]:
A = S * F * W
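Before running an iterative solver, it is worth confirming that `A.H` really is the adjoint of `A`, since an incorrect adjoint silently corrupts the gradient. The standard dot-product test checks $\langle A x, y \rangle = \langle x, A^H y \rangle$ for random $x$ and $y$. A sketch of a generic helper (the name `adjoint_test` is hypothetical), demonstrated on a centered orthonormal FFT so it runs on its own:

```python
import numpy as np

def adjoint_test(forward, adjoint, ishape, oshape, seed=0):
    """Relative mismatch between <forward(x), y> and <x, adjoint(y)>."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(ishape) + 1j * rng.standard_normal(ishape)
    y = rng.standard_normal(oshape) + 1j * rng.standard_normal(oshape)
    lhs = np.vdot(forward(x), y)   # np.vdot conjugates its first argument
    rhs = np.vdot(x, adjoint(y))
    return abs(lhs - rhs) / max(abs(lhs), abs(rhs))

def cfft(x):
    return np.fft.fftshift(np.fft.fftn(np.fft.ifftshift(x), norm='ortho'))

def cifft(y):
    return np.fft.fftshift(np.fft.ifftn(np.fft.ifftshift(y), norm='ortho'))

print(adjoint_test(cfft, cifft, (8, 8), (8, 8)))  # should be near machine precision
```

In this notebook, `adjoint_test(A, A.H, coeff.shape, ksp.shape)` would exercise the full $SFW$ operator in the same way.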

# Soft-threshold

Given a function $g(x)$ and a step size $\alpha > 0$, the proximal operator of $\alpha g$ maps an array $y$ to:
$$\text{prox}_{\alpha g} (y) = \arg\min_x \frac{1}{2} \| x - y \|_2^2 + \alpha g(x)$$

Here, $g(x) = \lambda \| x \|_1$ is a scaled L1 norm, and its proximal operator is the elementwise soft-threshold function:

$$\text{prox}_{\alpha g} (y) = (|y| - \alpha \lambda)_+ \frac{y}{|y|}$$

where $(\cdot)_+ = \max(\cdot, 0)$, and the phase $y / |y|$ is taken to be 0 where $y = 0$.
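As a sanity check on the closed form: for a real scalar $y$, the soft-threshold should coincide with a brute-force minimization of $\frac{1}{2}(x - y)^2 + t|x|$, where $t = \alpha\lambda$. The sketch below compares the two over a fine grid of candidate $x$ values; the grid search and the helper names are for illustration only.

```python
import numpy as np

def soft_thresh_scalar(y, t):
    """Closed-form prox of t*|x| for real y: sign(y) * max(|y| - t, 0)."""
    return np.sign(y) * max(abs(y) - t, 0.0)

def prox_brute_force(y, t, lo=-5.0, hi=5.0, num=100001):
    """Minimize 0.5*(x - y)^2 + t*|x| over a uniform grid of candidates."""
    grid = np.linspace(lo, hi, num)
    cost = 0.5 * (grid - y) ** 2 + t * np.abs(grid)
    return grid[np.argmin(cost)]

t = 1.0
for y in [-3.0, -0.4, 0.0, 0.7, 2.5]:
    print(y, soft_thresh_scalar(y, t), prox_brute_force(y, t))
```

Inputs with $|y| \le t$ are mapped to exactly zero, which is what makes the L1 penalty produce sparse coefficients.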

In [None]:
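def soft_thresh(y, lamda):
    # Shrunken magnitude (|y| - lamda)_+
    mag = np.maximum(np.abs(y) - lamda, 0)

    # Phase y / |y|, set to 0 where y == 0; `out` must be given, otherwise
    # the entries skipped by `where` are left uninitialized
    sign = np.divide(y, np.abs(y), out=np.zeros_like(y), where=y != 0)
    return mag * sign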

As a first test, we soft-threshold the wavelet coefficients of the ground-truth image directly; transforming back gives a wavelet-denoised image.

In [None]:
lamda = 0.05
coeff_thresh = soft_thresh(coeff, lamda)

fig, ax = plt.subplots()
ax.imshow(np.abs(coeff_thresh)**0.3, cmap='gray')
ax.set_title('Soft-thresholded Wavelet Coefficients')

In [None]:
img_thresh = W(coeff_thresh)

fig, ax = plt.subplots()
ax.imshow(np.abs(img_thresh), cmap='gray')
ax.set_title('Wavelet Soft-thresholded Image')

# Proximal Gradient Descent

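We solve the problem with proximal gradient descent, also known as the iterative soft-thresholding algorithm (ISTA). Each iteration takes a gradient step on the data-consistency term $f(x) = \frac{1}{2} \| A x - y \|_2^2$ with $A = S F W$, whose gradient is $\nabla f(x) = A^H (A x - y)$, and then applies the soft-threshold with threshold $\alpha \lambda$. Let us first define the parameters and the gradient function.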

In [None]:
max_iter = 100
alpha = 1

def gradf(x):
    return A.H(A(x) - ksp)
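The fixed step size `alpha = 1` is safe when the Lipschitz constant of the gradient, $L = \|A\|_2^2$ (the largest eigenvalue of $A^H A$), is at most 1, since a step size of at most $1/L$ guarantees that the objective decreases. That holds here if $W$ is orthonormal, because $S$ is a binary mask and $F$ is unitary. For other operators, $L$ can be estimated with power iteration. A minimal sketch (the helper name `max_eig` is hypothetical), tried on a small diagonal matrix:

```python
import numpy as np

def max_eig(normal_op, shape, n_iter=50, seed=0):
    """Estimate the largest eigenvalue of a Hermitian PSD operator such as
    A^H A by power iteration."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(shape) + 1j * rng.standard_normal(shape)
    x /= np.linalg.norm(x)
    eig = 0.0
    for _ in range(n_iter):
        y = normal_op(x)
        eig = np.linalg.norm(y)  # ||A^H A x|| for unit-norm x
        x = y / eig
    return eig

A_mat = np.diag([2.0, 1.0, 0.5])  # A^H A has eigenvalues 4, 1, 0.25
print(max_eig(lambda v: A_mat.conj().T @ (A_mat @ v), (3,)))
```

With the notebook's operator, `max_eig(lambda v: A.H(A(v)), coeff.shape)` would estimate $L$, and `alpha = 1 / L` is then a standard safe choice.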

Then we run the iterations, alternating the gradient step and the soft-threshold:

In [None]:
coeff_hat = np.zeros_like(coeff, dtype=np.complex128)  # the gradient is complex-valued

for it in range(max_iter):
    coeff_hat -= alpha * gradf(coeff_hat)
    coeff_hat = soft_thresh(coeff_hat, lamda * alpha)
    
img_hat = W(coeff_hat)
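The loop above is ISTA. To see it recover a sparse signal from fewer measurements than unknowns, here is a self-contained toy problem with a random Gaussian matrix standing in for $SFW$. The sizes, seed, and $\lambda$ are arbitrary illustrative choices, and the step size is set to $1/\|A\|_2^2$. The recorded objective should decrease at every iteration.

```python
import numpy as np

rng = np.random.default_rng(0)
m, n, k = 40, 100, 5                          # measurements, unknowns, nonzeros
A_toy = rng.standard_normal((m, n)) / np.sqrt(m)
x_true = np.zeros(n)
support = rng.choice(n, size=k, replace=False)
x_true[support] = rng.choice([-1.0, 1.0], size=k)
y_toy = A_toy @ x_true                        # noiseless measurements

lam = 0.01
step = 1.0 / np.linalg.norm(A_toy, 2) ** 2    # 1 / Lipschitz constant of the gradient

def soft(v, t):
    """Real-valued soft-threshold."""
    return np.sign(v) * np.maximum(np.abs(v) - t, 0.0)

def objective(x):
    return 0.5 * np.sum((A_toy @ x - y_toy) ** 2) + lam * np.sum(np.abs(x))

x_hat = np.zeros(n)
objs = [objective(x_hat)]
for _ in range(2000):
    grad = A_toy.T @ (A_toy @ x_hat - y_toy)       # gradient of the data term
    x_hat = soft(x_hat - step * grad, lam * step)  # prox of the L1 term
    objs.append(objective(x_hat))

print('relative error:', np.linalg.norm(x_hat - x_true) / np.linalg.norm(x_true))
```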

I'd like to highlight a very convenient package, `tqdm`, which adds a progress bar to a loop with almost no effort: simply wrap the iterable, as in `tqdm(range(max_iter))`.

In [None]:
from tqdm.auto import tqdm

coeff_hat = np.zeros_like(coeff, dtype=np.complex128)  # the gradient is complex-valued

for it in tqdm(range(max_iter)):
    coeff_hat -= alpha * gradf(coeff_hat)
    coeff_hat = soft_thresh(coeff_hat, lamda * alpha)
    
img_hat = W(coeff_hat)

In [None]:
fig, ax = plt.subplots()
ax.imshow(np.abs(img), cmap='gray', vmax=1)
ax.set_title('Ground Truth')

fig, ax = plt.subplots()
ax.imshow(np.abs(img_naive), cmap='gray', vmax=1)
ax.set_title('Naive Reconstruction')

fig, ax = plt.subplots()
ax.imshow(np.abs(img_hat), cmap='gray', vmax=1)
ax.set_title('L1 Wavelet Regularized Reconstruction')

# Wrapup
I'd like to wrap up with a package that I've been developing, SigPy. SigPy is a package for signal processing with iterative methods. It is built to operate directly on NumPy arrays on CPU and CuPy arrays on GPU. 

SigPy provides three supporting classes for building iterative reconstruction apps ([App](https://sigpy.readthedocs.io/en/latest/generated/sigpy.app.App.html#sigpy.app.App)):

- [Linop](https://sigpy.readthedocs.io/en/latest/generated/sigpy.linop.Linop.html#sigpy.linop.Linop) for linear operators
- [Prox](https://sigpy.readthedocs.io/en/latest/generated/sigpy.prox.Prox.html#sigpy.prox.Prox) for proximal operators
- [Alg](https://sigpy.readthedocs.io/en/latest/generated/sigpy.alg.Alg.html#sigpy.alg.Alg) for iterative algorithms

![architecture](https://sigpy.readthedocs.io/en/latest/_images/architecture.pdf)
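The split between an `Alg` (one update step plus a stopping rule) and an `App` (which runs the `Alg` to completion and post-processes the result) can be illustrated without SigPy. The plain-Python sketch below is a hypothetical miniature of that design, loosely modeled on the structure above, not SigPy's actual implementation; the class names `MiniAlg`, `MiniGradientMethod`, and `MiniApp` are made up.

```python
import numpy as np

class MiniAlg:
    """An iterative algorithm: repeated updates until max_iter is reached."""
    def __init__(self, max_iter):
        self.max_iter = max_iter
        self.iter = 0

    def update(self):
        self._update()
        self.iter += 1

    def done(self):
        return self.iter >= self.max_iter

    def _update(self):
        raise NotImplementedError

class MiniGradientMethod(MiniAlg):
    """Proximal gradient descent that updates x in place."""
    def __init__(self, gradf, x, alpha, proxg=None, max_iter=100):
        self.gradf, self.x, self.alpha, self.proxg = gradf, x, alpha, proxg
        super().__init__(max_iter)

    def _update(self):
        self.x -= self.alpha * self.gradf(self.x)
        if self.proxg is not None:
            # proxg is called as proxg(alpha, x)
            self.x[...] = self.proxg(self.alpha, self.x)

class MiniApp:
    """Runs an Alg to completion and returns a post-processed output."""
    def __init__(self, alg):
        self.alg = alg

    def run(self):
        while not self.alg.done():
            self.alg.update()
        return self._output()

    def _output(self):
        return self.alg.x

# Toy usage: minimize 0.5 * ||x - b||^2, whose gradient is x - b
b = np.array([1.0, -2.0, 3.0])
x0 = np.zeros(3)
result = MiniApp(MiniGradientMethod(lambda x: x - b, x0, alpha=0.5, max_iter=60)).run()
print(result)
```

Keeping the update rule in the `Alg` and the problem setup in the `App` is what lets the `L1WaveletRecon` class below reuse a generic gradient method with a problem-specific `gradf` and `proxg`.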

In [None]:
import sigpy as sp

class L1WaveletRecon(sp.app.App):
    def __init__(self, ksp, lamda, max_iter):
        mask = ksp != 0
        
        F = sp.linop.FFT(ksp.shape)
        S = sp.linop.Multiply(ksp.shape, mask)
        self.W = sp.linop.InverseWavelet(ksp.shape)
        A = S * F * self.W
        
        proxg = sp.prox.L1Reg(A.ishape, lamda)
        
        self.coeff = np.zeros(A.ishape, np.complex128)  # complex wavelet coefficients
        alpha = 1
        def gradf(x):
            return A.H * (A * x - ksp)

        alg = sp.alg.GradientMethod(gradf, self.coeff, alpha, proxg=proxg, 
                                    max_iter=max_iter)
        super().__init__(alg)
        
    def _output(self):
        return self.W(self.coeff)

In [None]:
img_hat = L1WaveletRecon(ksp, lamda, max_iter).run()

In [None]:
fig, ax = plt.subplots()
ax.imshow(np.abs(img_hat), cmap='gray', vmax=1)
ax.set_title('L1 Wavelet Regularized Reconstruction using SigPy')