# HoloML in Stan
### Brian Ward, Bob Carpenter, and David Barmherzig
#### June 13, 2022

This case study is a re-implementation of the algorithm described in *David A. Barmherzig and Ju Sun, "Towards practical holographic coherent diffraction imaging via maximum likelihood estimation," Opt. Express 30, 6886-6906 (2022)* as a Stan model.

## Motivating Problem

The experimental setup is as follows:
A biomolecule is placed some distance away from a known reference pattern. Radiation from a source, usually X-rays, is diffracted by both the specimen and the reference, and the resulting photon flux is measured by a far-field detector. This photon flux is approximately the squared magnitude of the Fourier transform of the electric field causing the diffraction. Inverting this to recover an image of the specimen is a problem usually known as *phase retrieval*.

The average number of photons each detector pixel receives is $N_p$. It is desirable that retrieval be performed when this value is small (< 10) because of the damage the radiation causes to the biomolecule under observation. Furthermore, to prevent damage to the detectors, the lowest frequencies are removed by a *beamstop*.

In more general terms, the problem of holographic coherent diffraction imaging is to recover an image from 1) the magnitude of the Fourier transform of that image concatenated with a reference image, and 2) the known reference image. 
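
To see why the magnitudes alone are not enough, note that different images can share the same Fourier magnitudes. The following is a minimal sketch using only NumPy; the toy 8 x 8 image and the names `x`, `x_shifted`, and `intensity` are illustrative and do not come from the paper. It shows that a circularly shifted image yields identical squared magnitudes.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.random((8, 8))                             # toy "specimen" image
x_shifted = np.roll(x, shift=(2, 3), axis=(0, 1))  # circularly shifted copy

def intensity(img):
    """Squared magnitude of the 2-D DFT (detector reading, up to scale)."""
    return np.abs(np.fft.fft2(img)) ** 2

# A circular shift only changes the phase of each Fourier coefficient,
# so the two intensity patterns coincide even though the images differ.
same_intensity = np.allclose(intensity(x), intensity(x_shifted))
same_image = np.allclose(x, x_shifted)
print(same_intensity, same_image)
```

In the holographic setup, the known reference image supplies the extra information used to constrain the reconstruction.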

## Stan Model Code

The Stan model code is as follows. This is a direct translation of the log probability described in the paper, with the following notes:

- The FFT described in the paper is an "oversampled" FFT. This corresponds to embedding the image in a larger array of zeros. 
- A prior is added to impose an L2 penalty on adjacent pixels. This is not strictly necessary, and for low values of `sigma` induces a Gaussian blur.
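
The first note can be checked numerically. The sketch below assumes a small random image and an arbitrary output shape chosen for illustration; it compares an FFT of an explicitly zero-padded array with NumPy's `fft2` called with the `s` argument, which zero-pads to the requested shape.

```python
import numpy as np

rng = np.random.default_rng(1)
img = rng.random((4, 6))   # toy image
M1, M2 = 8, 16             # oversampled output shape (illustrative values)

# Explicit embedding: place the image in the upper-left corner of a zero array
padded = np.zeros((M1, M2))
padded[: img.shape[0], : img.shape[1]] = img

explicit = np.fft.fft2(padded)
implicit = np.fft.fft2(img, s=(M1, M2))  # NumPy zero-pads the input to shape s

oversampling_matches = np.allclose(explicit, implicit)
print(oversampling_matches)
```

The Stan model takes the explicit route: its `pad` function builds the zero-padded matrix and `fft2` is applied to the result.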

In [None]:
# use https://pypi.org/project/cmdstanjupyter/ to display the model inline
%load_ext cmdstanjupyter

In [None]:
%%stan HoloML_model
functions {
  /**
   * Return M1 x M2 matrix of 1 values with blocks in corners set to
   * 0, where the upper left is (r x r), the upper right is (r x r-1),
   * the lower left is (r-1 x r), and the lower right is (r-1 x r-1).
   * This corresponds to zeroing out the lowest-frequency portions of
   * an FFT.
   * @param M1 number of rows
   * @param M2 number of cols
   * @param r block dimension
   * @return matrix of 1 values with 0-padded corners
   */
  matrix pad_corners(int M1, int M2, int r) {
    matrix[M1, M2] B_cal = rep_matrix(1, M1, M2);
    if (r == 0) {
      return B_cal;
    }
    // upper left
    B_cal[1 : r, 1 : r] = rep_matrix(0, r, r);
    // upper right
    B_cal[1 : r, M2 - r + 2 : M2] = rep_matrix(0, r, r - 1);
    // lower left
    B_cal[M1 - r + 2 : M1, 1 : r] = rep_matrix(0, r - 1, r);
    // lower right
    B_cal[M1 - r + 2 : M1, M2 - r + 2 : M2] = rep_matrix(0, r - 1, r - 1);
    return B_cal;
  }

  /**
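   * Return result of separating X and R with d columns of 0s and
   * then 0 padding to the right and below up to M1 x M2.  That is,
   * assuming X and R are both N x N and the separating 0 is an
   * N x d matrix of zeros, the result is
   *
   *  [X, 0, R]  | 0
   *  --------------
   *      0      | 0
   *
   * @param X X matrix (N x N specimen image)
   * @param R R matrix (N x N reference image)
   * @param d number of zero columns between X and R
   * @param M1 number of rows of the padded result
   * @param M2 number of cols of the padded result
   * @return 0-padded [X, 0, R] matrix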
   */
  matrix pad(matrix X, matrix R, int d, int M1, int M2) {
    matrix[M1, M2] y = rep_matrix(0, M1, M2);
    int N = rows(X);
    y[1 : N, 1 : N] = X;
    y[1 : N,  N + d + 1 : 2 * N + d] = R;
    return y;
  }
}

data {
  int<lower=0> N;                     // image dimension
  matrix<lower=0, upper=1>[N, N] R;   // reference image
  int<lower=0, upper=N> d;            // separation between sample and reference image
  int<lower=N> M1;                    // rows of padded matrices
  int<lower=2 * N + d> M2;            // cols of padded matrices
  int<lower=0, upper=M1> r;           // beamstop radius. replaces omega1, omega2 in paper

  real<lower=0> N_p;                  // avg number of photons per pixel
  array[M1, M2] int<lower=0> Y_tilde; // observed number of photons

  real<lower=0> sigma;                // standard deviation of pixel prior
}

transformed data {
  matrix[M1, M2] B_cal = pad_corners(M1, M2, r);
}

parameters {
  matrix<lower=0, upper=1>[N, N] X;
}

model {
  matrix[M1, M2] X0R_pad = pad(X, R, d, M1, M2);
  matrix[M1, M2] Y = abs(fft2(X0R_pad)) .^ 2;
  real Y_bar = mean(Y);

  // prior
  for (i in 1 : rows(X) - 1) {
    X[i] ~ normal(X[i + 1], sigma);
  }
  for (j in 1 : cols(X) - 1) {
    X[ : , j] ~ normal(X[ : , j + 1], sigma);
  }

  // likelihood
  real N_p_over_Y_bar = N_p / Y_bar;
  matrix[M1, M2] lambda = N_p_over_Y_bar * Y;

  for (m1 in 1 : M1) {
    for (m2 in 1 : M2) {
      if (B_cal[m1, m2] == 1) {
        Y_tilde[m1, m2] ~ poisson(lambda[m1, m2]);
      }
    }
  }
}
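
For readers more comfortable with NumPy, the following is a rough sketch of the quantity the `model` block accumulates, up to an additive constant. It is not part of the case study: the function names `pad` and `log_density` mirror the Stan code, the toy sizes are arbitrary, the bounds on `X` are not enforced, and the code assumes every Poisson rate is strictly positive.

```python
import numpy as np

def pad(X, R, d, M1, M2):
    """Embed [X, 0, R] in the upper-left corner of an M1 x M2 zero array."""
    N = X.shape[0]
    y = np.zeros((M1, M2))
    y[:N, :N] = X
    y[:N, N + d : 2 * N + d] = R
    return y

def log_density(X, R, d, M1, M2, N_p, Y_tilde, B_cal, sigma):
    """Log density of X up to an additive constant (assumes all rates > 0)."""
    Y = np.abs(np.fft.fft2(pad(X, R, d, M1, M2))) ** 2
    lam = N_p / Y.mean() * Y  # Poisson rates, averaging N_p per pixel
    # Prior: squared differences between vertically and horizontally adjacent pixels
    sq_diffs = np.sum(np.diff(X, axis=0) ** 2) + np.sum(np.diff(X, axis=1) ** 2)
    lp = -0.5 * sq_diffs / sigma**2
    # Likelihood: Poisson terms y * log(lam) - lam on pixels outside the beamstop
    keep = B_cal == 1
    lp += np.sum(Y_tilde[keep] * np.log(lam[keep]) - lam[keep])
    return lp

# Toy problem (sizes chosen for illustration only)
rng = np.random.default_rng(0)
N, d, N_p = 4, 4, 100.0
M1, M2 = 2 * N, 2 * (2 * N + d)
X_true = rng.random((N, N))
R = rng.integers(0, 2, size=(N, N)).astype(float)
Y_true = np.abs(np.fft.fft2(pad(X_true, R, d, M1, M2))) ** 2
Y_tilde = rng.poisson(N_p / Y_true.mean() * Y_true)
B_cal = np.ones((M1, M2), dtype=int)  # no beamstop in this toy example

lp_true = log_density(X_true, R, d, M1, M2, N_p, Y_tilde, B_cal, sigma=1.0)
print(lp_true)
```

A function like this can be handy for spot-checking a Stan fit against an independent calculation on a small problem.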

## Simulating Data

We simulate data from the generative model directly. This corresponds to the work done in Barmherzig and Sun, and is based on MATLAB code provided by Barmherzig.

In [None]:
import numpy as np
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
from scipy import stats

def rgb2gray(rgb):
    """Convert a nxmx3 RGB array to a grayscale nxm array.

    This function uses the same internal coefficients as MATLAB:
    https://www.mathworks.com/help/matlab/ref/rgb2gray.html
    """
    r, g, b = rgb[:, :, 0], rgb[:, :, 1], rgb[:, :, 2]
    gray = 0.2989 * r + 0.5870 * g + 0.1140 * b

    return gray

In [None]:
# Simulation parameters:

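N = 256     # image size (must match the image loaded below)
M1 = 2 * N  # rows of the oversampled FFT
M2 = 6 * N  # cols of the oversampled FFT, twice the width of [X, 0, R] when d = N

N_p = 1     # average number of photons per pixel
r = 12      # beamstop size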

In [None]:
# Load a specimen image
X_True = rgb2gray(mpimg.imread('mimivirus.png'))
N = X_True.shape[0]
plt.imshow(X_True, cmap='gray', vmin=0, vmax=1)

In [None]:
# Load the reference, a "uniformly redundant array"
# Code to generate this omitted; various options such as https://github.com/bpops/cappy
R = np.loadtxt('URA.csv', delimiter=",", dtype=int)
plt.imshow(R, cmap='gray')

In [None]:
# Create specimen-reference hybrid image
X0R = np.concatenate([X_True, np.zeros((N,N)), R], axis=1)
plt.imshow(X0R, cmap='gray')

In [None]:
# Generate data
Y = np.abs(np.fft.fft2(X0R, s=(M1, M2))) ** 2
plt.imshow(np.fft.fftshift(np.log(1 + Y)), cmap="viridis")

In [None]:
# Noise model
rate = N_p/ Y.mean()
Y_tilde = stats.poisson.rvs(rate * Y)
plt.imshow(np.fft.fftshift(np.log(1 + Y_tilde)), cmap="viridis")
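
The scaling by `N_p / Y.mean()` makes the Poisson rates average to `N_p`, which is what lets `N_p` be read as the expected photon count per pixel. A small check on a stand-in intensity array (the names `Y_toy`, `N_p_toy`, and the array shape are illustrative only):

```python
import numpy as np

rng = np.random.default_rng(2)
Y_toy = rng.random((16, 32)) * 1e4    # stand-in for noiseless intensities
N_p_toy = 1.0                         # target average photons per pixel

lam = N_p_toy / Y_toy.mean() * Y_toy  # Poisson rate for each pixel
counts = rng.poisson(lam)             # one simulated exposure

mean_rate = lam.mean()                # equals N_p_toy up to rounding error
print(mean_rate, counts.mean())
```

The sampled counts fluctuate around this average, and at low `N_p` many pixels record zero photons.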

In [None]:
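# Beamstop occlusion: zero a centered (2r - 1) x (2r - 1) block, then
# shift it to the corners so that it matches pad_corners in the Stan model
B_cal = np.ones((M1,M2), dtype=int)
B_cal[M1 // 2 - r + 1: M1 // 2 + r, M2 // 2 - r + 1: M2 // 2 + r] = 0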
B_cal = np.fft.ifftshift(B_cal)
plt.imshow(np.fft.fftshift(B_cal), cmap='gray')

In [None]:
Y_tilde *= B_cal
plt.imshow(np.fft.fftshift(np.log(1 + Y_tilde)), cmap="viridis")

This image shows the final input data of the simulated retrievals with noise and a beamstop occluding the lowest frequencies. 
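
Because the Stan model rebuilds the beamstop mask itself with `pad_corners`, the simulated mask has to use the same corner layout. The sketch below is an assumption-labeled check rather than part of the case study: `pad_corners_np` is a hypothetical 0-indexed NumPy port of the Stan function, and `centered_mask` zeroes a centered block of side `2r - 1` before applying `ifftshift`. The array sizes are toy values.

```python
import numpy as np

def pad_corners_np(M1, M2, r):
    """NumPy port (0-indexed) of the Stan pad_corners function."""
    B = np.ones((M1, M2), dtype=int)
    if r == 0:
        return B
    B[:r, :r] = 0                      # upper left:  r x r
    B[:r, M2 - r + 1 :] = 0            # upper right: r x (r - 1)
    B[M1 - r + 1 :, :r] = 0            # lower left:  (r - 1) x r
    B[M1 - r + 1 :, M2 - r + 1 :] = 0  # lower right: (r - 1) x (r - 1)
    return B

def centered_mask(M1, M2, r):
    """Centered (2r - 1) x (2r - 1) block of zeros, moved to the corners."""
    B = np.ones((M1, M2), dtype=int)
    B[M1 // 2 - r + 1 : M1 // 2 + r, M2 // 2 - r + 1 : M2 // 2 + r] = 0
    return np.fft.ifftshift(B)

masks_agree = np.array_equal(pad_corners_np(16, 48, 3), centered_mask(16, 48, 3))
print(masks_agree)
```

If the two masks disagreed, pixels zeroed in the data but not in the model would be treated as genuine zero-photon observations.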

## Optimization

Now that we have our simulated data and our generative model, we can attempt to recover the image. 

Here we use optimization via L-BFGS, as opposed to the conjugate gradient approach in the paper. 
It is also possible to sample the model using the No-U-Turn Sampler (NUTS), but at this image size it can take a few hours to do so.

In [None]:
data = {
    "N": N,
    "R": R,
    "d": N,
    "M1": M1,
    "M2": M2,
    "Y_tilde": Y_tilde,
    "r": r,
    "N_p": N_p,
    "sigma": 1, # prior smoothing
}

In [None]:
%%time
fit = HoloML_model.optimize(data, inits=1)

In [None]:
plt.imshow(fit.stan_variable("X"), cmap='gray', vmin=0, vmax=1)