# HoloML in Stan
### Brian Ward, Bob Carpenter, and David Barmherzig
#### June 13, 2022

This case study is a re-implementation of the algorithm described in [Barmherzig and Sun 2022] as a Stan model.

## Introduction

The HoloML technique is an approach to solving a specific kind of inverse problem inherent to imaging nanoscale specimens with X-ray crystallography. 

To solve this problem in Stan, we write down the forward scientific model given by Barmherzig and Sun, including the Poisson distribution of photon counts and the censoring of data that are inherent to the physical problem.

### Experimental setup 

In traditional coherent diffraction imaging (CDI), radiation from a source, typically an X-ray beam, is directed at a biomolecule or other specimen of interest, which diffracts it. The resulting photon flux is measured by a far-field detector. This photon flux is approximately the squared magnitude of the Fourier transform of the electric field causing the diffraction. Inverting this to recover an image of the specimen is a problem usually known as *phase retrieval*. The phase retrieval problem is highly challenging and often lacks a unique solution [Barnett et al. 2020].

Holographic coherent diffraction imaging (HCDI) is a variant in which the specimen is placed some distance away from a known reference object, and the observed data are the diffraction pattern of the specimen and the reference together. This additional reference information makes the problem identifiable.

<img src="./figure 1.jpg" width=400 />

**TODO: Cite image or replace, ask David**

The idealized version of HCDI is formulated as 

- Given a reference $R$, data $Y = | \mathcal{F}( X + R ) | ^2$
- Recover the source image $X$

However, the real-world setup of these experiments introduces two additional difficulties. First, data are measured from a limited number of photons, and the count at each detector follows a Poisson distribution (referred to in the paper as *Poisson-shot noise*). The expected number of photons per detector, averaged over all detectors, is denoted $N_p$. It is desirable to perform retrieval when this value is small (< 10), because the radiation damages the biomolecule under observation. Second, to prevent damage to the detectors, a *beamstop* blocks the lowest frequencies, which censors the low-frequency observations.

The model presented here is able to recover reasonable images even under a regime featuring low photon counts and a beamstop.

## Simulating Data

We simulate data from the generative model directly. This corresponds to the work done in Barmherzig and Sun, and is based on MATLAB code provided by Barmherzig.

#### Imports and helper code

Generating the data requires a few standard Python numerical libraries, namely SciPy and NumPy. Matplotlib is also used to load the source image and display results.

In [None]:
import numpy as np
from scipy import stats

import matplotlib as mpl
import matplotlib.image as mpimg
import matplotlib.pyplot as plt

# disable axes drawing, since we are showing images
mpl.rc('axes.spines', top=False, bottom=False, left=False, right=False)
mpl.rc('axes', facecolor='white')
mpl.rc("xtick", bottom=False, labelbottom=False)
mpl.rc("ytick", left=False, labelleft=False)
mpl.rc("figure", autolayout=True)

def rgb2gray(rgb):
    """Convert a nxmx3 RGB array to a grayscale nxm array.

    This function uses the same internal coefficients as MATLAB:
    https://www.mathworks.com/help/matlab/ref/rgb2gray.html
    """
    r, g, b = rgb[:, :, 0], rgb[:, :, 1], rgb[:, :, 2]
    gray = 0.2989 * r + 0.5870 * g + 0.1140 * b

    return gray

#### Simulation parameters

To match the figures in the paper (in particular, Figure 9), we use an image of size 256x256, $N_p = 1$ (meaning each detector is expected to receive one photon on average), a beamstop of size 25x25 (corresponding to a radius of 13), and a separation `d` equal to the size of the image.


In [None]:
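N = 256  # the specimen image is N x N pixels
d = N    # separation: columns of zeros between specimen and reference
N_p = 1  # average expected number of photons per detector
r = 13   # beamstop radius; blocks a (2r - 1) x (2r - 1) square

M1 = 2 * N            # rows of the twice-oversampled FFT
M2 = 2 * (2 * N + d)  # columns: twice the width of [X, 0, R]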

We can then load the source image used for these simulations. This is a picture of a [giant virus](https://en.wikipedia.org/wiki/Giant_virus) known as a mimivirus.

In this model, the pixels of $X$ are grayscale values on the interval [0, 1]. The `rgb2gray` function defined above converts the RGB(A) image loaded from disk to grayscale; any alpha channel is ignored.

In [None]:
X_src = rgb2gray(mpimg.imread('mimivirus.png'))
plt.imshow(X_src, cmap='gray', vmin=0, vmax=1)

Additionally, we load in the reference object. 

The pattern used here is known as a *uniformly redundant array* (URA) [Fenimore and Cannon 1978]. It has been shown to be an optimal reference image for this kind of work, but other references (including none at all) could be used.

The code used to generate this grid is omitted from this case study. Various options such as [cappy](https://github.com/bpops/cappy) exist to generate these patterns in Python.

In [None]:
R = np.loadtxt('URA.csv', delimiter=",", dtype=int)
plt.imshow(R, cmap='gray')

We can then create the specimen-reference hybrid image by concatenating the $X$ image, a matrix of zeros, and the reference $R$. In the true experiment, this is done by placing the specimen some distance `d` away from the reference, with opaque material between. Traditionally, this distance is the same as the size of the specimen.

In [None]:
X0R = np.concatenate([X_src, np.zeros((N,d)), R], axis=1)
plt.imshow(X0R, cmap='gray')

We can simulate the diffraction of an X-ray by taking the absolute value squared of the 2-dimensional oversampled FFT of this hybrid object. 

The oversampled FFT (denoted $\mathcal{F}$ in the paper) corresponds to padding the image with zeros in both dimensions until it reaches a desired size. Here we define the size of the padded image, `M1` by `M2`, to be twice the size of our hybrid image in each dimension, so the resulting FFT is twice oversampled. This is the oversampling ratio traditionally used for this problem; however, Barmherzig and Sun also showed that this model can operate with less oversampling.

In [None]:
Y = np.abs(np.fft.fft2(X0R, s=(M1, M2))) ** 2
plt.imshow(np.fft.fftshift(np.log1p(Y)), cmap="viridis")
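To make the zero-padding concrete, the following minimal sketch pads explicitly and checks the result against NumPy's `s=` argument. The helper name `oversampled_fft2` is our own (hypothetical), chosen to mirror the `fft2(Z, N, M)` overload in the Stan model below. It also checks a property of twice oversampling: every second frequency of the padded FFT reproduces the ordinary, unpadded FFT, so the padded transform samples the same spectrum on a grid twice as fine.

```python
import numpy as np


def oversampled_fft2(Z, M1, M2):
    """Embed Z in the top-left corner of an M1 x M2 array of zeros, then take the 2-D FFT."""
    n, m = Z.shape
    if n > M1 or m > M2:
        raise ValueError("padded size must be at least the size of Z")
    pad = np.zeros((M1, M2), dtype=complex)
    pad[:n, :m] = Z
    return np.fft.fft2(pad)


rng = np.random.default_rng(0)
Z = rng.random((4, 6))

F_pad = oversampled_fft2(Z, 8, 12)  # twice oversampled in each dimension

# Same as letting NumPy do the padding through the `s` argument.
print(np.allclose(F_pad, np.fft.fft2(Z, s=(8, 12))))  # True

# Every second frequency of the padded FFT equals the unpadded FFT.
print(np.allclose(F_pad[::2, ::2], np.fft.fft2(Z)))  # True
```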

We can then simulate the photon-counting measurement with a Poisson random number generator.

We fix the seed here to ensure the same fake data is generated each time.

In [None]:
rate = N_p / Y.mean()
Y_tilde = stats.poisson.rvs(rate * Y, random_state=1234)
plt.imshow(np.fft.fftshift(np.log1p(Y_tilde)), cmap="viridis")
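The factor `rate = N_p / Y.mean()` rescales the noiseless intensities so that the Poisson means average exactly $N_p$ over all detectors; the absolute brightness of $Y$ does not matter. The sketch below checks this on a made-up nonnegative array (`Y_demo` and `N_p_demo` are hypothetical stand-ins, not the diffraction pattern computed above).

```python
import numpy as np
from scipy import stats

# Hypothetical stand-in for |F(X0R)|^2: any nonnegative array works here.
rng = np.random.default_rng(0)
Y_demo = rng.gamma(shape=0.5, scale=10.0, size=(64, 64))
N_p_demo = 1.0

lam = (N_p_demo / Y_demo.mean()) * Y_demo  # Poisson means, as in the cell above
counts = stats.poisson.rvs(lam, random_state=1234)

# Rescaling the intensities (e.g. a brighter source) leaves the Poisson means unchanged.
lam_bright = (N_p_demo / (1000 * Y_demo).mean()) * (1000 * Y_demo)

print(lam.mean())                    # N_p_demo, up to floating-point rounding
print(counts.mean())                 # close to N_p_demo; depends on the seed
print(np.allclose(lam, lam_bright))  # True
```

Only the relative intensities of $Y$ reach the data, which is why the Stan model below recomputes the same `N_p / mean(Y)` factor from its current guess of $X$.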

Finally, we need to remove the low-frequency content of the data. In the physical experiment, this is caused by a "beamstop", which protects the instrument by preventing the strongest parts of the beam from shining directly on the detectors.

In our model, the beamstop is represented by $\mathcal{B}$, a matrix of 0s and 1s.

In [None]:
B_cal = np.ones((M1,M2), dtype=int)
B_cal[M1 // 2 - r + 1: M1 // 2 + r, M2 // 2 - r + 1: M2 // 2 + r] = 0
B_cal = np.fft.ifftshift(B_cal)
# Sanity check
assert (M1 * M2 - B_cal.sum()) == (( 2 * r - 1)**2)
plt.imshow(np.fft.fftshift(B_cal), cmap="gray", vmin=0, vmax=1.25)
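The Stan model below builds the same mask directly in the corners with its `pad_corners` function, rather than zeroing a centred square and shifting it. As a sketch, here is a NumPy port of that function (our own 0-based translation) together with a check that both constructions agree, including for odd sizes and the edge cases r = 0 and r = 1. The loop variables carry a trailing underscore so they do not overwrite the notebook's `M1`, `M2`, and `r`.

```python
import numpy as np


def pad_corners(M1, M2, r):
    """NumPy port of the Stan `pad_corners` function (0-based indices)."""
    B = np.ones((M1, M2), dtype=int)
    if r == 0:
        return B
    B[:r, :r] = 0                    # upper left: r x r
    B[:r, M2 - r + 1:] = 0           # upper right: r x (r - 1)
    B[M1 - r + 1:, :r] = 0           # lower left: (r - 1) x r
    B[M1 - r + 1:, M2 - r + 1:] = 0  # lower right: (r - 1) x (r - 1)
    return B


def centered_beamstop(M1, M2, r):
    """The construction used above: zero a centred square, then move it to the corners."""
    B = np.ones((M1, M2), dtype=int)
    B[M1 // 2 - r + 1: M1 // 2 + r, M2 // 2 - r + 1: M2 // 2 + r] = 0
    return np.fft.ifftshift(B)


cases = [(512, 1536, 13), (8, 12, 3), (9, 7, 2), (6, 6, 1), (6, 6, 0)]
for M1_, M2_, r_ in cases:
    same = np.array_equal(pad_corners(M1_, M2_, r_), centered_beamstop(M1_, M2_, r_))
    n_blocked = M1_ * M2_ - pad_corners(M1_, M2_, r_).sum()
    print(M1_, M2_, r_, same, n_blocked == max(2 * r_ - 1, 0) ** 2)
```

The four corner blocks hold $r^2 + 2r(r-1) + (r-1)^2 = (2r-1)^2$ zeros, matching the sanity check in the cell above.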

We can then use this matrix $\mathcal{B}$ as a mask on our simulated data. After removing these elements from the simulated data, we have the final input used in our model.

In [None]:
Y_tilde *= B_cal
plt.imshow(np.fft.fftshift(np.log1p(Y_tilde)), cmap="viridis")

## Stan Model Code

The Stan model code is a direct translation of the log density of the forward model described in the paper and used above, with the following notes:

- The FFT described in the paper is an "oversampled" FFT. This corresponds to embedding the image in a larger array of zeros. We write an overload of the `fft2` function which implements this behavior.
- A prior is added that imposes an L2 penalty on the differences between adjacent pixels. This prior has a smoothing, blur-like effect on the result, and it is not strictly necessary for running the model.
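Written out, the two prior loops in the model block below add $-\frac{1}{2\sigma^2}\sum_{i,j} (X_{i,j} - X_{i+1,j})^2 - \frac{1}{2\sigma^2}\sum_{i,j} (X_{i,j} - X_{i,j+1})^2$ to the log density, plus a constant that does not depend on $X$. The following minimal Python sketch (our own translation, not part of the original code) evaluates the prior term by term with `scipy.stats.norm.logpdf`, mirroring the Stan loops, and confirms it matches the squared-difference form.

```python
import numpy as np
from scipy import stats


def smoothness_log_prior(X, sigma):
    """Sum of the normal log densities from the two Stan prior loops."""
    lp = 0.0
    for i in range(X.shape[0] - 1):  # X[i] ~ normal(X[i + 1], sigma)
        lp += stats.norm.logpdf(X[i], loc=X[i + 1], scale=sigma).sum()
    for j in range(X.shape[1] - 1):  # X[:, j] ~ normal(X[:, j + 1], sigma)
        lp += stats.norm.logpdf(X[:, j], loc=X[:, j + 1], scale=sigma).sum()
    return lp


def squared_difference_penalty(X):
    """Sum of squared differences between vertically and horizontally adjacent pixels."""
    return (np.diff(X, axis=0) ** 2).sum() + (np.diff(X, axis=1) ** 2).sum()


rng = np.random.default_rng(0)
X_demo = rng.random((5, 7))
sigma_demo = 0.3

# Number of normal terms: (rows - 1) * cols vertical pairs + rows * (cols - 1) horizontal pairs.
n_terms = (X_demo.shape[0] - 1) * X_demo.shape[1] + X_demo.shape[0] * (X_demo.shape[1] - 1)
const = -n_terms * np.log(sigma_demo * np.sqrt(2 * np.pi))

print(np.isclose(smoothness_log_prior(X_demo, sigma_demo),
                 -squared_difference_penalty(X_demo) / (2 * sigma_demo**2) + const))  # True
```

Because `sigma` is data, Stan's `~` statements drop the constant, so only the squared-difference term influences optimization.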

In [None]:
# use https://pypi.org/project/cmdstanjupyter/ to display the model inline
%load_ext cmdstanjupyter

### TODO

More walk through of code, comment on more efficient construction, etc.

In [None]:
%%stan HoloML_model
/**
 * Preliminary sketch of model (equations 2.2, 2.3) from:
 *
 * David A. Barmherzig and Ju Sun. 2022. Towards practical holographic
 * coherent diffraction imaging via maximum likelihood estimation.
 * arXiv 2105.11512v2.
 */

functions {
  /**
   * Return M1 x M2 matrix of 1 values with blocks in corners set to
   * 0, where the upper left is (r x r), the upper right is (r x r-1),
   * the lower left is (r-1 x r), and the lower right is (r-1 x r-1).
   * This corresponds to zeroing out the lowest-frequency portions of
   * an FFT.
   * @param M1 number of rows
   * @param M2 number of cols
   * @param r block dimension
   * @return matrix of 1 values with 0-padded corners
   */
  matrix pad_corners(int M1, int M2, int r) {
    matrix[M1, M2] B_cal = rep_matrix(1, M1, M2);
    if (r == 0) {
      return B_cal;
    }
    // upper left
    B_cal[1 : r, 1 : r] = rep_matrix(0, r, r);
    // upper right
    B_cal[1 : r, M2 - r + 2 : M2] = rep_matrix(0, r, r - 1);
    // lower left
    B_cal[M1 - r + 2 : M1, 1 : r] = rep_matrix(0, r - 1, r);
    // lower right
    B_cal[M1 - r + 2 : M1, M2 - r + 2 : M2] = rep_matrix(0, r - 1, r - 1);
    return B_cal;
  }
  
  /**
   * Return the matrix corresponding to the fast Fourier
   * transform of Z after it is padded with zeros to size
   * N by M
   * When N by M is larger than the dimensions of Z,
   * this computes an oversampled FFT.
   *
   * @param Z matrix of values
   * @param N number of rows desired (must be >= rows(Z))
   * @param M number of columns desired (must be >= cols(Z))
   * @return the FFT of Z padded with zeros
   */
  complex_matrix fft2(complex_matrix Z, int N, int M) {
    int r = rows(Z);
    int c = cols(Z);
    if (r > N) {
      reject("N must be at least rows(Z)");
    }
    if (c > M) {
      reject("M must be at least cols(Z)");
    }
    
    complex_matrix[N, M] pad = rep_matrix(0, N, M);
    pad[1 : r, 1 : c] = Z;
    
    return fft2(pad);
  }
}
data {
  int<lower=0> N; // image dimension
  matrix<lower=0, upper=1>[N, N] R; // registration image
  int<lower=0, upper=N> d; // separation between sample and registration image
  int<lower=N> M1; // rows of padded matrices
  int<lower=2 * N + d> M2; // cols of padded matrices
  int<lower=0, upper=M1> r; // beamstop radius. replaces omega1, omega2 in paper
  
  real<lower=0> N_p; // avg number of photons per pixel
  array[M1, M2] int<lower=0> Y_tilde; // observed number of photons
  
  real<lower=0> sigma; // standard deviation of pixel prior.
}
transformed data {
  matrix[M1, M2] B_cal = pad_corners(M1, M2, r);
  matrix[N, d] separation = rep_matrix(0, N, d); // N x d block of zeros between X and R
}
parameters {
  matrix<lower=0, upper=1>[N, N] X;
}
model {
  // prior - penalizing L2 on adjacent pixels
  for (i in 1 : rows(X) - 1) {
    X[i] ~ normal(X[i + 1], sigma);
  }
  for (j in 1 : cols(X) - 1) {
    X[ : , j] ~ normal(X[ : , j + 1], sigma);
  }
  
  // likelihood
  
  // object representing specimen and reference together
  matrix[N, 2 * N + d] X0R = append_col(X, append_col(separation, R));
  // observed signal - squared magnitude of the (oversampled) FFT
  matrix[M1, M2] Y = abs(fft2(X0R, M1, M2)) .^ 2;
  
  real N_p_over_Y_bar = N_p / mean(Y);
  matrix[M1, M2] lambda = N_p_over_Y_bar * Y;
  
  for (m1 in 1 : M1) {
    for (m2 in 1 : M2) {
      if (B_cal[m1, m2]) {
        Y_tilde[m1, m2] ~ poisson(lambda[m1, m2]);
      }
    }
  }
}

### Digression: Efficiency

The above model is written for readability, sticking closely to the mathematical formulation of the process. However, this leads to an inefficient conditional inside the tightest loop of the model, which handles the beamstop occlusion.

In practice, it is possible to avoid this conditional by changing how the data is stored...
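One way to do this, sketched below under our own assumptions (the helper names are hypothetical, and we use NumPy/SciPy rather than Stan), is to precompute the linear indices of the unmasked detectors together with their counts once, so the likelihood becomes a single vectorized Poisson evaluation with no branch. The check confirms that it matches the loop with the conditional.

```python
import numpy as np
from scipy import stats


def masked_loglik_loop(Y_tilde, lam, B_cal):
    """Mirrors the Stan model block: a branch inside the double loop."""
    total = 0.0
    for m1 in range(Y_tilde.shape[0]):
        for m2 in range(Y_tilde.shape[1]):
            if B_cal[m1, m2]:
                total += stats.poisson.logpmf(Y_tilde[m1, m2], lam[m1, m2])
    return total


def flatten_observed(Y_tilde, B_cal):
    """Done once, like Stan's data/transformed data: keep only unmasked detectors."""
    idx = np.flatnonzero(B_cal)  # row-major linear indices where B_cal == 1
    return idx, Y_tilde.ravel()[idx]


def masked_loglik_vectorized(y_obs, lam, idx):
    """One vectorized Poisson log-likelihood over the observed entries, no branching."""
    return stats.poisson.logpmf(y_obs, lam.ravel()[idx]).sum()


rng = np.random.default_rng(0)
lam_demo = rng.gamma(1.0, 1.0, size=(16, 24))
B_demo = np.ones((16, 24), dtype=int)
B_demo[:3, :3] = 0  # hypothetical mask
Y_demo = stats.poisson.rvs(lam_demo, random_state=1) * B_demo

idx, y_obs = flatten_observed(Y_demo, B_demo)
print(np.isclose(masked_loglik_loop(Y_demo, lam_demo, B_demo),
                 masked_loglik_vectorized(y_obs, lam_demo, idx)))  # True
```

If the same idea is carried over to Stan, note that `to_vector` flattens a matrix in column-major order, whereas NumPy's `ravel` defaults to row-major, so the precomputed indices must follow a matching order.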

## Optimization

Now that we have our simulated data and our generative model, we can attempt to recover the image. 

The Stan model needs all of the same information the generative model did, except it is supplied with $\tilde{Y}$ instead of the true image $X$, plus a scale parameter for the prior, $\sigma$. Smaller values of $\sigma$ (approaching 0) lead to increasing amounts of blur in the resulting image.

In [None]:
sigma = 1 # prior smoothing
data = {
    "N": N,
    "R": R,
    "d": d,
    "M1": M1,
    "M2": M2,
    "Y_tilde": Y_tilde,
    "r": r,
    "N_p": N_p,
    "sigma": sigma
}

Here we use optimization via the limited-memory quasi-Newton L-BFGS algorithm. This method uses more curvature information than the conjugate gradient approach, but less than the second-order trust-region method used in the paper. This should take roughly 1-3 minutes, depending on your machine.

It is also possible to sample from the model using the No-U-Turn Sampler (NUTS), but evaluating this is beyond the scope of this case study.

In [None]:
%time fit = HoloML_model.optimize(data, algorithm="LBFGS", inits=1, seed=5678)

We can then plot the recovered image alongside the original.

In [None]:
fig, (ax1, ax2, ax3) = plt.subplots(1, 3)

ax1.imshow(X_src, cmap="gray", vmin=0, vmax=1)
ax1.set_title("True Image")
ax1.set_axis_off()

ax2.imshow(fit.stan_variable("X"), cmap="gray", vmin=0, vmax=1)
ax2.set_title("Recovered Image\n(sigma = 1)")
ax2.set_axis_off()

# blank third image -- used in later comparison
ax3.imshow([[1]], cmap="gray", vmin=0, vmax=1)
ax3.set_axis_off()

plt.tight_layout()

### Prior tuning

The above choice of $\sigma = 1$ has only a slight effect on the output image. It is also interesting to observe the effect of a smaller (i.e. stronger) value such as 0.05. This imposes a greater penalty on adjacent pixels that differ significantly from each other, smoothing out the result.

In [None]:
data_smoothing = data.copy()
data_smoothing['sigma'] = 0.05

%time fit_smooth = HoloML_model.optimize(data_smoothing, algorithm="LBFGS", inits=1, seed=5678)

In [None]:
ax3.imshow(fit_smooth.stan_variable("X"), cmap="gray", vmin=0, vmax=1)
ax3.set_title("Recovered Image\n(sigma = 0.05)")
fig
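The effect of $\sigma$ can be seen in isolation on a toy problem. The sketch below is an illustration under simplifying assumptions, not the HoloML model: it replaces the Poisson/FFT likelihood with a direct Gaussian observation $y \sim \mathcal{N}(x, \tau^2)$ of a 1-D signal (τ = 0.1 is our arbitrary choice), keeps the same adjacent-difference prior, and solves for the MAP estimate in closed form. Shrinking $\sigma$ from 1 to 0.05 raises the weight on the squared differences by a factor of $(1/0.05)^2 = 400$.

```python
import numpy as np


def map_smooth_1d(y, tau, sigma):
    """MAP estimate for y ~ normal(x, tau) with prior x[i] ~ normal(x[i + 1], sigma).

    Setting the gradient of the log posterior to zero gives the linear system
    (I / tau^2 + D^T D / sigma^2) x = y / tau^2, where D takes first differences.
    """
    n = len(y)
    D = np.diff(np.eye(n), axis=0)  # (n - 1) x n: (D @ x)[i] = x[i + 1] - x[i]
    A = np.eye(n) / tau**2 + D.T @ D / sigma**2
    return np.linalg.solve(A, y / tau**2)


def roughness(x):
    """Sum of squared adjacent differences: the quantity the prior penalizes."""
    return float(np.sum(np.diff(x) ** 2))


rng = np.random.default_rng(1)
truth = (np.arange(100) >= 50).astype(float)  # a single sharp edge
y = truth + rng.normal(0.0, 0.1, size=100)    # noisy observations, tau = 0.1

x_weak = map_smooth_1d(y, tau=0.1, sigma=1.0)     # analogous to sigma = 1 above
x_strong = map_smooth_1d(y, tau=0.1, sigma=0.05)  # analogous to sigma = 0.05 above

print(roughness(y), roughness(x_weak), roughness(x_strong))  # strictly decreasing
```

With the stronger prior the noise is averaged away, but the edge at index 50 is also spread over several samples: the same trade-off between noise suppression and sharpness that a smaller $\sigma$ introduces in the image reconstruction.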

## Reproducibility 
The following versions were used to produce this notebook:

In [None]:
%load_ext watermark
%watermark -n -u -v -iv -w -p cmdstanpy,cmdstanjupyter

In [None]:
import cmdstanpy
print("CmdStan:", cmdstanpy.utils.cmdstan_version())

## References

- David A. Barmherzig and Ju Sun, "Towards practical holographic coherent diffraction imaging via maximum likelihood estimation," Opt. Express 30, 6886-6906 (2022)
- A. H. Barnett, C. L. Epstein, L. F. Greengard, and J. F. Magland, "Geometry of the phase retrieval problem," Inverse Probl. 36(9), 094003 (2020)
- E. E. Fenimore and T. M. Cannon, "Coded aperture imaging with uniformly redundant arrays," Appl. Opt. 17, 337-347 (1978)