# MLEM Experiments

## Prerequisites

To run this notebook, you need the [Operator Discretization Library (ODL)](https://github.com/odlgroup/odl), which itself depends on the [Astra toolbox](https://github.com/astra-toolbox/astra-toolbox). You also need the plotting library [HoloViews](http://holoviews.org/). The cells below additionally import NumPy, SciPy, xarray, tqdm, Matplotlib and colorcet.

In [None]:
import odl
import scipy

In [None]:
import numpy as np

In [None]:
import holoviews as hv
hv.extension('bokeh')

## Plotting Functions

Some general plotting functions to display `odl` elements as images with `holoviews`.

In [None]:
import holoviews as hv
import xarray as xr

def coords_from_space(space):
    """
    Collect the grid coordinates of `space`, one array per axis.

    For every per-axis partition in `space.partition.byaxis`, its points
    are requested and the trailing axis of the result is squeezed away.
    The arrays are returned as a list, in axis order.
    """
    per_axis_coords = []
    for axis_partition in space.partition.byaxis:
        points = axis_partition.points()
        per_axis_coords.append(points.squeeze(axis=-1))
    return per_axis_coords

def sanitize_label(axis_label):
    r"""
    Turn a LaTeX-flavoured axis label into plain text.

    All dollar signs are dropped first; afterwards every occurrence of
    ``\varphi`` is replaced by the Unicode character 'φ'.
    """
    without_dollars = ''.join(axis_label.split('$'))
    return without_dollars.replace(r"\varphi", 'φ')

def sanitize_labels(axis_labels):
    """
    Apply `sanitize_label` to each entry of `axis_labels`.

    Returns the sanitized labels as a list, in the original order.
    """
    return list(map(sanitize_label, axis_labels))

def xarray_from_element(element):
    """
    Wrap `element` in an `xarray.DataArray` named 'Intensity'.

    The coordinates come from `coords_from_space(element.space)` and the
    dimension names are the sanitized axis labels of that space.
    """
    space = element.space
    return xr.DataArray(
        element,
        coords=coords_from_space(space),
        dims=sanitize_labels(space.axis_labels),
        name='Intensity',
    )

def show_image(element):
    """
    Build a HoloViews image of `element` using the 'bone' colormap.

    The element is first converted with `xarray_from_element`.
    """
    image = hv.Image(xarray_from_element(element))
    image.opts(cmap='bone')
    return image

## Operator

The forward operator that we will use for the experiments.

In [None]:
# Build the CT forward operator for a square image of the given resolution
def get_ray_trafo(resolution=256):
    """
    Build a 2D parallel-beam ray transform.

    The reconstruction space is the square [-20, 20] x [-20, 20],
    discretized on a `resolution` x `resolution` grid with dtype 'float32'.

    NOTE: the original author remarked that the data space seems too big
    in practice.
    """
    # Parallel beam geometry with flat detector:
    #   angles:   uniform partition of [0, pi] into 90 cells
    #   detector: uniform partition of [-20, 20] into 128 cells
    geometry = odl.tomo.Parallel2dGeometry(
        odl.uniform_partition(0, np.pi, 90),
        odl.uniform_partition(-20, 20, 128),
    )

    image_space = odl.uniform_discr(
        min_pt=[-20, -20],
        max_pt=[20, 20],
        shape=[resolution, resolution],
        dtype='float32',
    )

    # Ray transform (= forward projection).
    return odl.tomo.RayTransform(image_space, geometry)

## Primal function (divergence)

The non-normalised divergence between two nonnegative vectors $u$ and $v$ is
\\[
δ(u||v) = \sum_i \bigl( v_i - u_i + u_i \log(u_i/v_i) \bigr)
\\]
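It is related to the standard (Kullback–Leibler) divergence $D$ between the normalised vectors $\bar{u} := u/M(u)$ and $\bar{v} := v/M(v)$, where $M(u) := \sum_i u_i$ denotes the total mass, by
\\[
δ(u||v) = M(u) \Bigl[ f - 1 - \log (f) + D(\bar{u}||\bar{v}) \Bigr] \qquad f := \frac{M(v)}{M(u)}
\\]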

In [None]:
def divergence(x,y):
    """
    Non-normalised divergence of `x` from `y`, divided by the sum of `x`.

    Computes sum(y - x + x*log(x) - x*log(y)) using `scipy.special.xlogy`
    (so entries with x == 0 contribute no logarithmic term) and divides
    the result by sum(x). If the sum is not finite, the array
    xlogy(x, y) is printed as a debugging aid before returning.
    """
    total_mass = np.sum(x)
    x_log_ratio = scipy.special.xlogy(x, x) - scipy.special.xlogy(x, y)
    unnormalised = np.sum(y - x + x_log_ratio)
    if not np.all(np.isfinite(unnormalised)):
        print(scipy.special.xlogy(x, y))
    return unnormalised / total_mass

## MLEM


A careful implementation of the MLEM algorithm.

In [None]:
def xovery(x, y, eps=1e-50):
    """
    Elementwise quotient x / y, set to 0 wherever |x| <= eps.

    In particular this implements the convention 0 / 0 = 0. Entries with
    |x| > eps are divided as usual, so a zero in `y` at such a position
    still yields inf.

    Parameters
    ----------
    x, y : array-like of the same shape
        Numerator and denominator.
    eps : float
        Numerators with absolute value not exceeding `eps` are treated
        as zero.

    Returns
    -------
    numpy.ndarray
        Array with the shape of `x`. Its dtype is that of `x` when `x` is
        floating point or complex, and float64 otherwise.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    mask = np.abs(x) > eps
    # np.zeros_like(x) would inherit an integer dtype from x and silently
    # truncate the quotients, so only inexact dtypes are kept as-is.
    out_dtype = x.dtype if np.issubdtype(x.dtype, np.inexact) else np.float64
    res = np.zeros(x.shape, dtype=out_dtype)
    res[mask] = x[mask] / y[mask]
    return res

In [None]:
from tqdm import tqdm

In [None]:
def generate_mlem(op, x, data, niter=1000, eps=1e-20):
    """
    Generator producing the MLEM iterates.

    Yields `niter` pairs (x, op(x)), the first one being the initial
    guess itself. When the generator is resumed after a yield, the
    iterate is updated multiplicatively,

        x <- (x / A^T 1) * A^T (data / A x),

    where A is `op` and the quotient data / A x is computed by `xovery`
    with threshold `eps`. Progress is reported through `tqdm`.
    """
    # A^T 1: back-projection of the constant-one element of the data space.
    backprojected_ones = op.adjoint(op.range.one())
    for _ in tqdm(range(niter)):
        projection = op(x)
        yield (x, projection)
        ratio = xovery(data.asarray(), projection.asarray(), eps)
        x = (x / backprojected_ones) * op.adjoint(ratio)

## Phantom

Some phantoms for the experiments.

### Torus

In [None]:
# Forward operator at the default resolution; it is used as a global by the cells below.
op = get_ray_trafo()

In [None]:
def torus(x):
    """
    Ring-shaped phantom: exp(-(r - 15)^2), with r the distance to the origin.

    `x` is indexed as x[0], x[1] for the two coordinates.
    """
    radius = np.sqrt(x[0]**2 + x[1]**2)
    deviation = radius - 15
    # The trailing "+ 0" is kept from the original expression.
    return np.exp(-deviation**2) + 0

In [None]:
# Create an element of the operator's domain from the `torus` function.
phantom = op.domain.element(torus)

In [None]:
# Display the current phantom with a colorbar (width=400).
hv.Image(xarray_from_element(phantom)).opts(hv.opts.Image(colorbar=True, width=400))

### DeRenzo

In [None]:
# Rebind `phantom` to ODL's Derenzo sources phantom on the same domain; this replaces the torus.
phantom = odl.phantom.emission.derenzo_sources(op.domain)

In [None]:
# Display the current phantom with a colorbar (width=400).
hv.Image(xarray_from_element(phantom)).opts(hv.opts.Image(colorbar=True, width=400))

## Data

In [None]:
# Noise-free data (sinogram): the forward operator applied to the current phantom.
data = op(phantom)

## Numerical tests

### Noisy data

In [None]:
def get_noisy_data(data, level):
    """
    Return `level` times ODL Poisson noise generated from `data / level`.
    """
    rescaled_intensity = data / level
    return level * odl.phantom.noise.poisson_noise(rescaled_intensity)

### File saving utilities

Some auxiliary functions to save the resulting figures in relevant subfolders.

In [None]:
from pathlib import Path

In [None]:
def get_destination_template(level, data_dir_name='mlem_data'):
    """
    Create the output folder for a noise level and return a file-name template.

    The folder is ``<data_dir_name>/level_<level>``, where the level is
    written in ``'{:.1e}'`` format. A relative `data_dir_name` is resolved
    against the current working directory. The folder path is printed.

    Parameters
    ----------
    level : float
        Noise level, used to name the subfolder.
    data_dir_name : str or path-like
        Root folder for all results; may be nested (e.g. ``'out/mlem'``).

    Returns
    -------
    pathlib.Path
        ``<folder>/{}.{}``. Fill the two placeholders (file stem and
        extension) by calling ``str.format`` on its string form.
    """
    formatted_level = '{:.1e}'.format(level)
    dir_destination = Path() / data_dir_name / 'level_{}'.format(formatted_level)
    print(dir_destination)
    # parents=True creates the data directory as well, including any missing
    # intermediate folders when `data_dir_name` is nested.
    dir_destination.mkdir(parents=True, exist_ok=True)
    return dir_destination / '{}.{}'

In [None]:
def save_curve(curve, fname):
    """
    Save `curve` to `fname` as a PNG through the bokeh backend.

    The toolbar option of the curve is set to None before saving.
    """
    without_toolbar = curve.opts(toolbar=None)
    hv.save(without_toolbar, filename=fname, fmt='png', backend='bokeh')

In [None]:
def save_fig(func, reco, fname):
    """
    Draw `reco` on a fresh matplotlib figure and save it to `fname`.

    `func(reco, ax)` does the drawing; the axes aspect is then set to 1
    and the figure is written at 300 dpi with a tight bounding box.
    """
    figure, axes = plt.subplots()
    func(reco, axes)
    axes.set_aspect(1)
    figure.tight_layout()
    figure.savefig(fname, bbox_inches='tight', dpi=300)

### MLEM

In [None]:
def run(alpha, niter=2**8):
    """
    Run `niter` MLEM iterations on the data `alpha` with the global `op`.

    The iteration starts from the constant-one element of the domain.
    Returns a pair of tuples (recos, ys): the iterates and their
    forward projections, as yielded by `generate_mlem`.
    """
    iterates = generate_mlem(op, op.domain.one(), alpha, niter=niter)
    recos, ys = zip(*iterates)
    return recos, ys

### Divergences

The divergence to data figure.

In [None]:
def get_div_curve(alpha, ys):
    """
    Plot the divergence between the data `alpha` and each projection in `ys`.

    The first entry of `ys` is skipped. The last divergence value is
    printed and also drawn as a faint horizontal line over the curve.
    """
    divs = np.array([divergence(alpha, y) for y in ys[1:]])
    final_div = divs[-1]
    print ('{:.2e}'.format(final_div))

    curve = hv.Curve(divs).redim.label(x='k', y='divergence to data')
    curve = curve.opts(hv.opts.Curve(width=250, height=200))
    curve = curve.opts(hv.opts.Curve(ylim=(0,None), line_width=4, show_grid=True))

    final_line = hv.HLine(final_div).opts(hv.opts.HLine(alpha=.3, color='Black'))
    return curve * final_line

### Quantiles

The 95th-percentile figure.

In [None]:
def get_quantile_curve(recos):
    """
    Plot the 0.95 quantile of each reconstruction on a logarithmic y axis.

    The quantile is computed for every entry of `recos`; the first value
    is left out of the curve.
    """
    quantiles = np.array([np.quantile(reco, .95) for reco in recos])
    curve_style = hv.opts.Curve(logy=True, line_width=4, show_grid=True)
    quant_curve = hv.Curve(quantiles[1:]).opts(curve_style)
    quant_curve = quant_curve.redim.label(x='k', y='95% percentile')
    quant_curve.opts(width=250, height=200)
    return quant_curve

## Reconstructions

### Smooth

A smoothed out version of the reconstruction.

In [None]:
import scipy as sp

In [None]:
import matplotlib.pyplot as plt

In [None]:
import colorcet as cc

In [None]:
def plot_filter_img(reco, ax, sigma=3):
    """
    Plot a Gaussian-smoothed version of `reco` on the matplotlib axes `ax`.

    Parameters
    ----------
    reco : element of `op.domain`
        Reconstruction to smooth.
    ax : matplotlib axes
        Axes to draw on.
    sigma : scalar or sequence of scalars
        Standard deviation passed to the Gaussian filter.
    """
    # `scipy.ndimage.filters` is a deprecated namespace; the filter lives in
    # `scipy.ndimage`. Importing it here does not rely on `import scipy`
    # having loaded the submodule.
    from scipy.ndimage import gaussian_filter

    filtered = op.domain.element(gaussian_filter(reco, sigma=sigma))
    # Transposed before plotting.
    xarray_from_element(filtered).T.plot(cmap=cc.cm.gray, ax=ax)


### With clim

A limited dynamic range version of the reconstruction.

In [None]:
def plot_clim_img(reco, ax):
    """
    Plot `reco` on `ax` in gray scale with the upper color limit fixed to 1.
    """
    transposed = xarray_from_element(reco).T
    transposed.plot(cmap=cc.cm.gray, vmax=1, ax=ax)

## Dual certificates

We compute dual certificates which make sure that the minimum is strictly positive. This means that the noisy data lies outside the image of the operator, and therefore entails sparsity.

Since we need dual feasibility for $\lambda$, i.e., $A^T\lambda \geq 0$, we add a small constant to $\lambda$ to get a dual feasible variable.

The choice scaling = 1 below ensures $A^T\lambda \geq 0$. One can start from scaling = 0.5, which may already be enough; if it is not, increase it.

In [None]:
from numpy import inf

In [None]:
def dual(alpha, dual):
    """
    Evaluate the dual objective sum(alpha * log(1 - dual)).

    Raises ValueError when the minimum of `op.adjoint(dual)` is negative,
    i.e. when `dual` is not dual feasible.
    """
    backprojection = op.adjoint(dual)
    if np.min(backprojection) < 0:
        raise ValueError("Unfeasible variable, increase scaling")
    log_term = scipy.special.xlogy(alpha, 1-dual)
    return np.sum(log_term)

In [None]:
def certify(reco, alpha, scaling=.5):
    """
    Check whether `reco` yields a dual certificate for the data `alpha`.

    A candidate lambda = 1 - alpha / A(reco) is built with the global
    operator `op`, shifted by a constant scaled with `scaling`, and
    passed to `dual`.

    Parameters
    ----------
    reco : element of `op.domain`
        Reconstruction from which the candidate is computed.
    alpha : element of `op.range`
        The (noisy) data.
    scaling : float
        Factor of the constant shift. If `dual` reports an unfeasible
        variable, increase it.

    Returns
    -------
    bool
        True when the dual value is strictly positive.

    Raises
    ------
    ValueError
        If the shifted candidate is not dual feasible (raised by `dual`),
        or if the dual value is not strictly positive. In the latter
        case, try increasing the number of ML-EM iterations.
    """
    # Candidate for the dual certificate, computed from the `reco` argument
    # rather than from a global list of reconstructions.
    lamda = (op.range.one() - xovery(alpha.asarray(), op(reco).asarray())).asarray()

    # Constant shift: scaling * min(A^T lamda) / min(A^T 1).
    mini = np.min(op.adjoint(lamda))
    mini2 = np.min(op.adjoint(op.range.one()))
    lamda_modified = lamda - scaling*(mini/mini2)

    d = dual(alpha, lamda_modified)
    if d > 0:
        return True
    raise ValueError("bad luck, this is not a dual certificate")

## Batch runs

This allows us to generate all the relevant figures in one function.

In [None]:
def save_all(level):
    """
    Run the whole experiment for one noise level and save its four figures.

    Noisy data is generated from the global `data`, MLEM is run on it, and
    the divergence curve, the quantile curve, the smoothed reconstruction
    and the clipped reconstruction are written as 'div', 'quant', 'smooth'
    and 'clim' PNG files into the folder for `level`.

    Returns the tuple (alpha, recos, ys): the noisy data, the iterates and
    their projections.
    """
    destination_template = get_destination_template(level)

    def png_path(stem):
        # Fill the '{}.{}' template with the given stem and the png extension.
        return destination_template.as_posix().format(stem, 'png')

    # Fix NumPy's global seed before the noise is generated
    # (presumably for reproducible noise realizations).
    np.random.seed(10)
    alpha = get_noisy_data(data, level)
    recos, ys = run(alpha)

    save_curve(get_div_curve(alpha, ys), fname=png_path('div'))
    save_curve(get_quantile_curve(recos), fname=png_path('quant'))

    final_reco = recos[-1]
    save_fig(plot_filter_img, final_reco, fname=png_path('smooth'))
    save_fig(plot_clim_img, final_reco, fname=png_path('clim'))

    return alpha, recos, ys

Here we generate all the figures for various noise levels.

In [None]:
# Noise levels 10^-2, 10^-1, ..., 10^2: save all figures, then try to certify sparsity.
for level in [10**exponent for exponent in range(-2, 3)]:
    alpha, recos, ys = save_all(level)
    try:
        c = certify(recos[-1], alpha=alpha, scaling=1)
    except ValueError as err:
        # certify (or dual) reports failures through ValueError; show the message.
        print(err)
    else:
        if c:
            print("Sparsity is certified!")

# Debugging

In [None]:
# Re-run the full pipeline for the single noise level 1e2 and keep its results for inspection.
alpha, recos, ys = save_all(level=1e2)

In [None]:
# Display the divergence curve for the current `alpha` and `ys`.
get_div_curve(alpha, ys)