# HoloML in Stan
### Brian Ward, Bob Carpenter, and David Barmherzig
#### June 13, 2022

This case study is a re-implementation, as a Stan model, of the algorithm described in *David A. Barmherzig and Ju Sun, "Towards practical holographic coherent diffraction imaging via maximum likelihood estimation," Opt. Express 30, 6886-6906 (2022)*.

### Motivating Problem 

The experimental setup is as follows:
A biomolecule is placed some distance away from a known reference pattern. A radiation source, usually an X-ray, is diffracted by both the specimen and the reference, and the resulting photon flux is measured by a far-field detector. This photon flux is approximately the squared magnitude of the Fourier transform of the electric field causing the diffraction. Inverting this to recover an image of the specimen is a problem usually known as *phase retrieval*.

The number of photons each detector receives on average is $N_p$. It is desirable that retrieval be performed when this value is small (< 10), because the radiation damages the biomolecule under observation. Furthermore, to prevent damage to the detectors, the lowest frequencies are removed by a *beamstop*, which obscures further data.

In more general terms, the problem of holographic coherent diffraction imaging is to recover an image from 
1) the magnitude of the Fourier transform of that image concatenated with a reference image 
2) that reference image. 
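
To see why recovering an image from Fourier magnitudes alone is hard, the following toy sketch (small random arrays, not the data used in this study) builds two different images with identical squared Fourier magnitudes. A circular shift changes only the phase of the transform, so a magnitude-only measurement cannot tell the two apart.

```python
import numpy as np

rng = np.random.default_rng(0)
img = rng.random((8, 8))                           # toy "specimen"
shifted = np.roll(img, shift=(2, 3), axis=(0, 1))  # a different image

# squared magnitude of the 2-D FFT, as a detector would record it
mag = np.abs(np.fft.fft2(img)) ** 2
mag_shifted = np.abs(np.fft.fft2(shifted)) ** 2

same_magnitude = np.allclose(mag, mag_shifted)
same_image = np.allclose(img, shifted)
print(same_magnitude, same_image)  # prints: True False
```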

## Simulating Data

We simulate data from the generative model directly. This corresponds to the work done in Barmherzig and Sun, and is based on MATLAB code provided by Barmherzig.

#### Imports and helper code

In [None]:
import numpy as np
from scipy import stats

import matplotlib.image as mpimg
import matplotlib.pyplot as plt

def rgb2gray(rgb):
    """Convert a nxmx3 RGB array to a grayscale nxm array.

    This function uses the same internal coefficients as MATLAB:
    https://www.mathworks.com/help/matlab/ref/rgb2gray.html
    """
    r, g, b = rgb[:, :, 0], rgb[:, :, 1], rgb[:, :, 2]
    gray = 0.2989 * r + 0.5870 * g + 0.1140 * b

    return gray

#### Simulation parameters

To match the figures in the paper (in particular, Figure 9), we use an image of size 256x256, $N_p = 1$ (meaning each detector receives, on average, only one photon), a beamstop of size 25x25 (corresponding to a radius of 13), and a separation `d` equal to the size of the image.

In [None]:
N = 256
d = N
N_p = 1
r = 13

M1 = 2 * N 
M2 = 2 * (2 * N + d)

We can then load the "Truth" image used for these simulations. This is a picture of a giant virus known as a mimivirus.

In [None]:
X_True = rgb2gray(mpimg.imread('mimivirus.png'))
N = X_True.shape[0]
plt.imshow(X_True, cmap='gray', vmin=0, vmax=1)

Additionally, we load in the reference array. 

This pattern is known as a "uniformly redundant array". It has been shown to be an optimal reference image for this kind of work, but other references (including none at all) could be used.
The code to generate it is omitted, but various options exist, such as [cappy](https://github.com/bpops/cappy).

In [None]:
R = np.loadtxt('URA.csv', delimiter=",", dtype=int)
plt.imshow(R, cmap='gray')

We can then create the specimen-reference hybrid image. In the true experiment, this is done by placing the specimen some distance `d` away from the reference, with opaque material between.

In [None]:
X0R = np.concatenate([X_True, np.zeros((N,d)), R], axis=1)
plt.imshow(X0R, cmap='gray')

With this hybrid object, we can simulate the instrument operation by taking the absolute value squared of the 2-dimensional oversampled FFT. 

The oversampled FFT (denoted $\mathcal{F}$ in the paper) corresponds to padding the image in both dimensions with zeros until it is a desired size. For our case, we define `M1` and `M2` to be two times the size of our hybrid image, so the resulting FFT is twice oversampled.

In [None]:
Y = np.abs(np.fft.fft2(X0R, s=(M1, M2))) ** 2
plt.imshow(np.fft.fftshift(np.log(1 + Y)), cmap="viridis")
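
As a small check of the zero-padding description above, the sketch below uses a toy array (not the hybrid image itself) to confirm that passing `s=` to `np.fft.fft2` gives the same result as padding with zeros by hand. It also shows what "twice oversampled" means: every second sample of the padded transform reproduces the ordinary FFT.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.random((4, 6))   # toy stand-in for the hybrid image
m1, m2 = 8, 12           # twice the size in each dimension

# explicit zero padding: x sits in the top-left corner of a larger array
padded = np.zeros((m1, m2))
padded[: x.shape[0], : x.shape[1]] = x

explicit = np.fft.fft2(padded)
implicit = np.fft.fft2(x, s=(m1, m2))  # NumPy pads with trailing zeros

padding_matches = np.allclose(explicit, implicit)
# the even-indexed samples of the oversampled FFT are the plain FFT
subsample_matches = np.allclose(implicit[::2, ::2], np.fft.fft2(x))
```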

We can then simulate the Poisson shot noise model of the physical diffraction problem.

In [None]:
rate = N_p / Y.mean()
Y_tilde = stats.poisson.rvs(rate * Y)
plt.imshow(np.fft.fftshift(np.log(1 + Y_tilde)), cmap="viridis")
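
The scaling by `N_p / Y.mean()` is what makes $N_p$ the average photon count per detector pixel. The sketch below illustrates this with made-up intensities (`Y_toy` and `N_p_toy` are hypothetical stand-ins, not the simulated data): the Poisson rates average exactly to `N_p_toy`, and the sampled counts average to roughly the same value.

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
Y_toy = rng.random((64, 64)) * 100  # hypothetical noiseless intensities
N_p_toy = 1

# per-pixel Poisson rate, normalized so that its mean is N_p_toy
lam = N_p_toy / Y_toy.mean() * Y_toy
counts = stats.poisson.rvs(lam, random_state=rng)

mean_rate = lam.mean()      # equals N_p_toy by construction
mean_count = counts.mean()  # close to N_p_toy, up to sampling noise
```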

Finally, we need to remove the low frequency content of the data. This is caused in the physical experiment by the inclusion of a "beamstop", which protects the instrument used by preventing the strongest parts of the beam from directly shining on the detectors.

The beamstop is represented in the model by $\mathcal{B}$, a matrix of 0s and 1s.

In [None]:
B_cal = np.ones((M1,M2), dtype=int)
B_cal[M1 // 2 - r + 1: M1 // 2 + r, M2 // 2 - r + 1: M2 // 2 + r] = 0
B_cal = np.fft.ifftshift(B_cal)
# Sanity check
assert (M1 * M2 - B_cal.sum()) == (( 2 * r - 1)**2)
plt.imshow(np.fft.fftshift(B_cal), cmap='gray')
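
The Stan model in this case study builds the same mask directly in the unshifted FFT layout with its `pad_corners` function, where the low frequencies sit in the four corners. As a sketch, the hypothetical helper `pad_corners_np` below ports that function to NumPy and checks on a small example that it agrees with the centered-block-then-`ifftshift` construction used above.

```python
import numpy as np

def pad_corners_np(M1, M2, r):
    """NumPy port (hypothetical helper) of the Stan function pad_corners."""
    B = np.ones((M1, M2), dtype=int)
    if r == 0:
        return B
    B[:r, :r] = 0                     # upper left:  r x r
    B[:r, M2 - r + 1:] = 0            # upper right: r x (r - 1)
    B[M1 - r + 1:, :r] = 0            # lower left:  (r - 1) x r
    B[M1 - r + 1:, M2 - r + 1:] = 0   # lower right: (r - 1) x (r - 1)
    return B

def centered_mask(M1, M2, r):
    """Mask built as in the simulation: centered block, then ifftshift."""
    B = np.ones((M1, M2), dtype=int)
    B[M1 // 2 - r + 1 : M1 // 2 + r, M2 // 2 - r + 1 : M2 // 2 + r] = 0
    return np.fft.ifftshift(B)

corner_mask = pad_corners_np(16, 24, 3)
masks_agree = np.array_equal(corner_mask, centered_mask(16, 24, 3))
```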

After removing these elements from the simulated data, we have the final data which is used in our model.

In [None]:
Y_tilde *= B_cal
plt.imshow(np.fft.fftshift(np.log(1 + Y_tilde)), cmap="viridis")

## Stan Model Code

The Stan model code is as follows. This is a direct translation of the log density of the forward model described in the paper and used above, with the following notes:

- The FFT described in the paper is an "oversampled" FFT. This corresponds to embedding the image in a larger array of zeros. We write an overload of the `fft2` function which implements this behavior.
- A prior is added to impose an L2 penalty on adjacent pixels. This is not strictly necessary, and for low values of `sigma` induces a Gaussian blur.
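
To make the second note concrete, here is a minimal NumPy sketch of the adjacent-pixel penalty. The function name `smoothness_log_prior` is hypothetical, and the value is the log density only up to an additive constant. A flat image incurs no penalty, while a checkerboard is penalized at every neighboring pair, and the penalty grows as `sigma` shrinks.

```python
import numpy as np

def smoothness_log_prior(X, sigma):
    """Unnormalized log prior: Gaussian penalty on adjacent-pixel differences."""
    row_diffs = X[:-1, :] - X[1:, :]   # vertical neighbors
    col_diffs = X[:, :-1] - X[:, 1:]   # horizontal neighbors
    return -(np.sum(row_diffs ** 2) + np.sum(col_diffs ** 2)) / (2 * sigma ** 2)

flat = np.full((4, 4), 0.5)
checker = np.indices((4, 4)).sum(axis=0) % 2  # alternating 0/1 pixels

lp_flat = smoothness_log_prior(flat, sigma=1.0)
lp_checker = smoothness_log_prior(checker, sigma=1.0)
lp_checker_tight = smoothness_log_prior(checker, sigma=0.5)
```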

In [None]:
# use https://pypi.org/project/cmdstanjupyter/ to display the model inline
%load_ext cmdstanjupyter

In [None]:
%%stan HoloML_model
/**
 * Preliminary sketch of model (equations 2.2, 2.3) from:
 *
 * David A. Barmherzig and Ju Sun. 2022. Towards practical holographic
 * coherent diffraction imaging via maximum likelihood estimation.
 * arXiv 2105.11512v2.
 */

functions {
  /**
   * Return M1 x M2 matrix of 1 values with blocks in corners set to
   * 0, where the upper left is (r x r), the upper right is (r x r-1),
   * the lower left is (r-1 x r), and the lower right is (r-1 x r-1).
   * This corresponds to zeroing out the lowest-frequency portions of
   * an FFT.
   * @param M1 number of rows
   * @param M2 number of cols
   * @param r block dimension
   * @return matrix of 1 values with 0-padded corners
   */
  matrix pad_corners(int M1, int M2, int r) {
    matrix[M1, M2] B_cal = rep_matrix(1, M1, M2);
    if (r == 0) {
      return B_cal;
    }
    // upper left
    B_cal[1 : r, 1 : r] = rep_matrix(0, r, r);
    // upper right
    B_cal[1 : r, M2 - r + 2 : M2] = rep_matrix(0, r, r - 1);
    // lower left
    B_cal[M1 - r + 2 : M1, 1 : r] = rep_matrix(0, r - 1, r);
    // lower right
    B_cal[M1 - r + 2 : M1, M2 - r + 2 : M2] = rep_matrix(0, r - 1, r - 1);
    return B_cal;
  }

  /**
   * Return the matrix corresponding to the fast Fourier
   * transform of Z after it is padded with zeros to size
   * N by M
   * When N by M is larger than the dimensions of Z,
   * this computes an oversampled FFT.
   *
   * @param Z matrix of values
   * @param N number of rows desired (must be >= rows(Z))
   * @param M number of columns desired (must be >= cols(Z))
   * @return the FFT of Z padded with zeros
   */
  complex_matrix fft2(complex_matrix Z, int N, int M){
    int r = rows(Z);
    int c = cols(Z);
    if (r > N){
      reject("N must be at least rows(Z)");
    }
    if (c > M){
      reject("M must be at least cols(Z)");
    }

    complex_matrix[N, M] pad = rep_matrix(0, N, M);
    pad[1:r, 1:c] = Z;

    return fft2(pad);
  }

}

data {
  int<lower=0> N;                     // image dimension
  matrix<lower=0, upper=1>[N, N] R;   // reference image
  int<lower=0, upper=N> d;            // separation between sample and reference image
  int<lower=N> M1;                    // rows of padded matrices
  int<lower=2 * N + d> M2;            // cols of padded matrices
  int<lower=0, upper=M1> r;           // beamstop radius. replaces omega1, omega2 in paper

  real<lower=0> N_p;                  // avg number of photons per pixel
  array[M1, M2] int<lower=0> Y_tilde; // observed number of photons

  real<lower=0> sigma;                // standard deviation of pixel prior.
}

transformed data {
  matrix[M1, M2] B_cal = pad_corners(M1, M2, r);
  matrix[d, N] separation = rep_matrix(0, d, N);
}

parameters {
  matrix<lower=0, upper=1>[N, N] X;
}

model {
  // prior - penalizing L2 on adjacent pixels
  for (i in 1 : rows(X) - 1) {
    X[i] ~ normal(X[i + 1], sigma);
  }
  for (j in 1 : cols(X) - 1) {
    X[ : , j] ~ normal(X[ : , j + 1], sigma);
  }

  // likelihood

  // object representing specimen and reference together
  matrix[N, 2*N + d] X0R = append_col(X, append_col(separation, R));
  // observed signal - squared magnitude of the (oversampled) FFT
  matrix[M1, M2] Y = abs(fft2(X0R, M1, M2)) .^ 2;

  real N_p_over_Y_bar = N_p / mean(Y);
  matrix[M1, M2] lambda = N_p_over_Y_bar * Y;

  for (m1 in 1 : M1) {
    for (m2 in 1 : M2) {
      if (B_cal[m1, m2]) {
        Y_tilde[m1, m2] ~ poisson(lambda[m1, m2]);
      }
    }
  }
}
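
For readers more comfortable with NumPy, the likelihood part of the model block can be sketched as follows. This is an illustrative mirror of the Stan code, not a replacement for it: the function name `holoml_log_lik` and the toy inputs are hypothetical, and because Stan's `~` statements drop constant terms, the value should only be expected to agree with Stan's up to an additive constant.

```python
import numpy as np
from scipy import stats

def holoml_log_lik(X, R, d, M1, M2, B_cal, N_p, Y_tilde):
    """Poisson log likelihood of the counts outside the beamstop."""
    N = X.shape[0]
    # specimen, opaque separation, and reference side by side
    X0R = np.concatenate([X, np.zeros((N, d)), R], axis=1)
    # squared magnitude of the oversampled FFT
    Y = np.abs(np.fft.fft2(X0R, s=(M1, M2))) ** 2
    lam = N_p / Y.mean() * Y
    logp = stats.poisson.logpmf(Y_tilde, lam)
    # pixels hidden by the beamstop (B_cal == 0) do not contribute
    return logp[B_cal == 1].sum()

# toy inputs, far smaller than the 256x256 problem
rng = np.random.default_rng(0)
N, d = 4, 4
M1, M2 = 2 * N, 2 * (2 * N + d)
X_toy = rng.random((N, N))
R_toy = rng.integers(0, 2, size=(N, N))
B_toy = np.ones((M1, M2), dtype=int)
B_toy[0, 0] = 0                         # block only the zero frequency
Y_toy = rng.poisson(5, size=(M1, M2))   # arbitrary counts
ll = holoml_log_lik(X_toy, R_toy, d, M1, M2, B_toy, 5, Y_toy)
```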

## Optimization

Now that we have our simulated data and our generative model, we can attempt to recover the image. 

The Stan model needs all of the same information the generative model did, except it is supplied with $\tilde{Y}$ instead of the true image $X$, plus a scale parameter for the prior, $\sigma$. Smaller values of $\sigma$ (approaching 0) lead to increasing amounts of blur in the resulting image.

In [None]:
sigma = 1 # prior smoothing
data = {
    "N": N,
    "R": R,
    "d": d,
    "M1": M1,
    "M2": M2,
    "Y_tilde": Y_tilde,
    "r": r,
    "N_p": N_p,
    "sigma": sigma
}

Here we use optimization via L-BFGS, as opposed to the conjugate gradient approach in the paper. This should take a few (1-3) minutes, depending on the machine you are running on.

It is also possible to sample the model using the No-U-Turn Sampler (NUTS), but at this image size it can take a few hours to do so.

In [None]:
%%time
fit = HoloML_model.optimize(data, inits=1)

We can then plot the recovered image alongside the original.

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2)
ax1.imshow(X_True, cmap="gray", vmin=0, vmax=1)
ax1.set_title("True Image")
ax1.set_axis_off()
ax2.imshow(fit.stan_variable("X"), cmap="gray", vmin=0, vmax=1)
ax2.set_title("Recovered Image")
ax2.set_axis_off()

fig.suptitle(f"{N}x{N} - N_p: {N_p} - r: {r} - d: {d} - sigma: {sigma}", y=0.9)
plt.tight_layout()
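
Beyond visual comparison, one simple way to put a number on the recovery is the relative error in the Frobenius norm. This metric is an assumption made here for illustration, not necessarily the one used in the paper, and `relative_error` is a hypothetical helper demonstrated on toy arrays.

```python
import numpy as np

def relative_error(X_true, X_hat):
    """Frobenius-norm error of X_hat relative to the norm of X_true."""
    return np.linalg.norm(X_hat - X_true) / np.linalg.norm(X_true)

rng = np.random.default_rng(0)
truth = rng.random((16, 16))                               # toy "true" image
estimate = truth + 0.01 * rng.standard_normal((16, 16))    # slightly perturbed copy

err_exact = relative_error(truth, truth)
err_noisy = relative_error(truth, estimate)
# In the notebook, this would be called as
# relative_error(X_True, fit.stan_variable("X"))
```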