In [None]:
# Jax imports
import jax
import jax.numpy as jnp
import numpy as np
from jax import config

# Pin JAX to single precision: x64 mode is explicitly switched off for this notebook.
jax.config.update("jax_enable_x64", False)

In [None]:
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

In [None]:
def plot_image(image, fig, ax, cmap="gray", **kwargs):
    """Draw ``image`` on ``ax`` and attach a slim colorbar to its right.

    Parameters
    ----------
    image : array-like
        Data passed to ``Axes.imshow``; it is drawn with ``origin="lower"``.
    fig : matplotlib.figure.Figure
        Figure that owns ``ax``; the colorbar is created through it.
    ax : matplotlib.axes.Axes
        Axes to draw into.
    cmap : str, optional
        Colormap name (default ``"gray"``).
    **kwargs
        Extra keyword arguments forwarded unchanged to ``Axes.imshow``.

    Returns
    -------
    tuple
        The ``(fig, ax)`` pair that was passed in (not the image handle).
    """
    mappable = ax.imshow(image, cmap=cmap, origin="lower", **kwargs)
    # Carve a 5%-wide axis off the right edge of `ax` so the colorbar does not
    # shrink the image itself.
    colorbar_axes = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
    fig.colorbar(mappable, cax=colorbar_axes)
    return fig, ax

In [None]:
# Image simulator
import cryojax.simulator as cs
from cryojax.utils import fft, irfft, fftfreqs

In [None]:
# Template volume on disk, and the sampling used when rasterizing it (Angstroms).
# The "ps5_28" suffix in the file name appears to encode the same 5.28 value.
filename = "../tests/data/3jar_13pf_bfm1_ps5_28.mrc"
resolution = 5.28  # Angstroms

In [None]:
# Configure the scattering method and read the template.
# Build only ONE back end: constructing both and discarding one wastes a full
# template read and leaves the reader unsure which configuration is in use.
#   USE_NUFFT = True  -> NufftScattering + ElectronCloud template
#   USE_NUFFT = False -> FourierSliceScattering + ElectronGrid template (padded)
# NOTE(review): later cells use `scattering.freqs`, `scattering.padded_freqs` and
# `scattering.coords`; confirm the NUFFT configuration provides these before switching.
USE_NUFFT = False
IMAGE_SHAPE = (80, 80)  # output image shape in pixels
if USE_NUFFT:
    scattering = cs.NufftScattering(shape=IMAGE_SHAPE, resolution=resolution, pad_scale=1.2, eps=1e-5)
    cloud = cs.ElectronCloud.from_file(filename, atol=1e-8)
else:
    scattering = cs.FourierSliceScattering(shape=IMAGE_SHAPE, resolution=resolution, pad_scale=1.5)
    cloud = cs.ElectronGrid.from_file(filename, pad_scale=1.5)

In [None]:
# Initialize the pipeline state: pose, CTF optics, ice, exposure and detector.
pose = cs.EulerPose(
    offset_x=0.0,
    offset_y=0.0,
    view_phi=-np.pi / 4,
    view_theta=np.pi / 2 + np.pi / 10,
    view_psi=np.pi / 8,
)
optics = cs.CTFOptics(defocus_u=10000, defocus_v=10000, amplitude_contrast=0.07)
ice = cs.ExponentialNoiseIce(key=jax.random.PRNGKey(seed=0), kappa=0.0, xi=1e-3, gamma=0.0)
exposure = cs.UniformExposure(N=1, mu=0)
# Tie the detector pixel size to the template sampling (`resolution`, also encoded as
# "ps5_28" in the template file name) instead of a hard-coded 5.29, which disagrees
# with it. NOTE(review): restore a distinct value if a mismatch was intended.
detector = cs.WhiteNoiseDetector(pixel_size=resolution, key=jax.random.PRNGKey(seed=1234), alpha=1.0)
state = cs.PipelineState(pose=pose, ice=ice, optics=optics, exposure=exposure, detector=detector)

In [None]:
# One image model per stage of the forward pipeline (scattering -> optics -> detector),
# all sharing the same scattering method, specimen and pipeline state.
scattering_model, optics_model, detector_model = (
    image_model(scattering=scattering, specimen=cloud, state=state)
    for image_model in (cs.ScatteringImage, cs.OpticsImage, cs.DetectorImage)
)

In [None]:
# Log-amplitude of the Fourier transform of the scattering-stage image
fig, ax = plt.subplots(figsize=(3.25, 3.25))
log_amplitude = jnp.log(jnp.abs(fft(scattering_model.render())))
im = plot_image(log_amplitude, fig, ax)  # NOTE: plot_image returns (fig, ax), not an image handle
plt.tight_layout()

In [None]:
# Side-by-side images after each stage: scattering | optics | detector
fig, axes = plt.subplots(ncols=3, figsize=(12, 6))
ax1, ax2, ax3 = axes
im1, im2, im3 = [
    plot_image(stage_model(), fig, panel)
    for stage_model, panel in zip((scattering_model, optics_model, detector_model), axes)
]
plt.tight_layout()

In [None]:
# Simulate a large noise-only "micrograph" and build a whitening filter from it
fig, ax = plt.subplots(figsize=(4, 4))
# Frequency grid for an (800, 600) micrograph at the detector pixel size.
micrograph_freqs = fftfreqs((800, 600), pixel_size=detector.pixel_size)
# Noise model: an ice sample modulated by the CTF optics, plus a detector noise sample.
# The result is treated as a Fourier-space array (it is shown later via `irfft`).
micrograph = ice.sample(micrograph_freqs) * optics(micrograph_freqs) + detector.sample(micrograph_freqs)
# Multiplying by the pixel size presumably converts the frequencies to per-pixel units
# expected by WhiteningFilter — TODO confirm against the cryojax API.
whiten = cs.WhiteningFilter(scattering.freqs, detector.pixel_size * micrograph_freqs, micrograph)
# Apply the filter to the detector-stage image; the trailing `;` suppresses the
# (fig, ax) tuple repr that plot_image returns.
plot_image(irfft(whiten(fft(detector_model()))), fig, ax);

In [None]:
# Display the simulated noise micrograph in real space (transposed), using the same
# plot_image helper as the rest of the notebook; `;` suppresses the tuple repr.
fig, ax = plt.subplots()
plot_image(irfft(micrograph).T, fig, ax);

Empirically, the whitening filter looks right! Next, generate an image under a colored-noise model using this whitening filter.

In [None]:
# Show the forward model for a whitened image. Two filters are attached to the model:
# a lowpass (anti-aliasing) filter and a whitening filter built from the simulated
# micrograph, both on the padded frequency grid.
fig, ax = plt.subplots(figsize=(4, 4))
filters = [
    cs.LowpassFilter(scattering.padded_freqs),
    cs.WhiteningFilter(scattering.padded_freqs, detector.pixel_size * micrograph_freqs, micrograph),
]
filtered_model = cs.GaussianImage(scattering=scattering, specimen=cloud, state=state, filters=filters)
# Trailing `;` suppresses the (fig, ax) tuple repr returned by plot_image.
plot_image(filtered_model.render(), fig, ax);

In [None]:
# Inspect the filter arrays: whitening (left, viridis) and anti-aliasing lowpass (right, gray)
antialias, whiten = filters  # same order as the `filters` list built for the model
fig, axes = plt.subplots(ncols=2, figsize=(8, 6))
ax1, ax2 = axes
im1, im2 = [
    plot_image(flt.filter, fig, panel, cmap=colormap)
    for flt, panel, colormap in ((whiten, ax1, "viridis"), (antialias, ax2, "gray"))
]
plt.tight_layout()

Computing an image is straightforward, but what we really want is a function that can be arbitrarily transformed by JAX (e.g. `jax.jit`, `jax.grad`) and evaluated at a subset of the parameters.

In [None]:
# Parameter overrides at which to evaluate the model, plus a jit-compiled detector model.
params = {
    "view_psi": np.pi,
    "defocus_u": 9000.0,
    "alpha": 1.0,
    "N": 1.0,
    "mu": 0.0,
    "pixel_size": 5.0,
}


def _render_detector(p):
    """Return the detector-stage image evaluated at the parameter overrides ``p``."""
    return detector_model(p)


jitted_model = jax.jit(_render_detector)

In [None]:
# Compare the optics-stage image as constructed (left) with the same model evaluated
# at the `params` overrides (right). Ending on tight_layout() keeps the cell from
# echoing the (fig, ax) tuple returned by plot_image, matching the other panels.
fig, axes = plt.subplots(ncols=2, figsize=(8, 6))
ax1, ax2 = axes
plot_image(optics_model(), fig, ax1)
plot_image(optics_model(params), fig, ax2)
plt.tight_layout()

In [None]:
# Benchmark jitted pipeline
# jitted_image = jitted_model(params)

In [None]:
# Benchmark non-jitted pipeline
# image = detector_model(params)

We can also use the model to compute the likelihood. Let's evaluate the likelihood at the simulated data and visualize the residuals.

In [None]:
# Simulate an "observed" image with no masks and no filters, then build a masked
# model that is compared against it, and plot simulated | observed | residuals.
masks = [cs.CircularMask(scattering.coords)]
observation_model = cs.GaussianImage(
    scattering=scattering, specimen=cloud, state=state, masks=[], filters=[]
)
observed = observation_model()
model = cs.GaussianImage(
    scattering=scattering, specimen=cloud, state=state, masks=masks, observed=observed
)
simulated = model.render()
observed = model.observed
residuals = model.residuals()

fig, axes = plt.subplots(ncols=3, figsize=(12, 6))
ax1, ax2, ax3 = axes
for panel, panel_image in zip(axes, (simulated, observed, residuals)):
    plot_image(panel_image, fig, panel)
plt.tight_layout()

In [None]:
# Loss and gradient pipelines (not jit-compiled).
def loss(params):
    """Return ``model(params)``; ``jax.grad`` below requires this to be scalar-valued."""
    return model(params)


grad_loss = jax.grad(loss)

In [None]:
# Benchmark the (un-jitted) loss pipeline. JAX dispatches work asynchronously, so
# block on the result; otherwise %timeit mostly measures dispatch overhead.
%timeit jax.block_until_ready(loss(params))

In [None]:
# Benchmark the (un-jitted) gradient pipeline. The gradient is a pytree, and
# block_until_ready waits on every leaf so the timing includes the actual compute.
%timeit jax.block_until_ready(grad_loss(params))

In [None]:
# Jitted loss-and-gradient pipeline.
# NOTE: this rebinds `grad_loss`. From here on it returns a `(loss_value, gradient)`
# pair instead of the bare gradient returned by the un-jitted version.
def _loss_for_value_and_grad(p):
    return model(p)


grad_loss = jax.jit(jax.value_and_grad(_loss_for_value_and_grad))

In [None]:
# Benchmark the jitted value-and-gradient pipeline. `grad_loss` returns a
# (value, gradient) pytree here; block on all of it so the timing includes device work
# rather than just asynchronous dispatch.
%timeit jax.block_until_ready(grad_loss(params))