In [None]:
# Standard library
import time

# Jax imports
import jax
import jax.numpy as jnp
import numpy as np
from jax import config

# Explicitly disable 64-bit mode so JAX arrays use 32-bit types. Going through
# `jax.config` keeps this independent of the bare name `config`, which is later
# rebound to the ScatteringConfig.
jax.config.update("jax_enable_x64", False)

In [None]:
from matplotlib import pyplot as plt

In [None]:
# IO utils
from jax_2dtm.io import load_grid_as_cloud

In [None]:
# Image simulator
from jax_2dtm.simulator import compute_anti_aliasing_filter
from jax_2dtm.simulator import Cloud, ScatteringConfig
from jax_2dtm.simulator import ScatteringImage, OpticsImage, GaussianImage
from jax_2dtm.simulator import AntiAliasingFilter, WhiteningFilter
from jax_2dtm.simulator import EulerPose, CTFOptics, Intensity, LorenzianNoise, WhiteNoise, ParameterState
from jax_2dtm.utils import fft, ifft, fftfreqs

In [None]:
# Template volume (MRC file) and its pixel size. The "ps5_28" suffix in the
# filename presumably encodes the same 5.28 A pixel size — confirm.
filename = "../tests/data/3jar_13pf_bfm1_ps5_28.mrc"
pixel_size = 5.28  # Angstroms

In [None]:
# Configure an 81 x 81 image at `pixel_size`, then load the template volume
# as a point cloud.
# NOTE: this rebinds `config` (imported above from jax) to the ScatteringConfig.
# NOTE(review): `threshold=1e-4` presumably drops low-density voxels — confirm
# against jax_2dtm.io.load_grid_as_cloud.
image_shape = (81, 81)
config = ScatteringConfig(image_shape, pixel_size, eps=1e-4)
cloud = load_grid_as_cloud(filename, config, threshold=1e-4)

In [None]:
# Compute the scattering image:
#   1. place the point cloud in `pose`;
#   2. project it with `config`;
#   3. multiply by the anti-aliasing filter evaluated on the frequency grid.
# NOTE(review): EulerPose arguments look like two offsets followed by three
# angles in radians — confirm against jax_2dtm.simulator.EulerPose.
freqs = fftfreqs(config.shape, config.pixel_size)
pose = EulerPose(-50.0, -50.0, np.pi / 8, np.pi / 10, np.pi / 4)
transformed_cloud = cloud.view(pose)
projection = transformed_cloud.project(config)
scattering_image = projection * compute_anti_aliasing_filter(freqs, config.pixel_size, cutoff=1.0)

In [None]:
# Apply the CTF, zero one central component, then rescale the image.
N1, N2 = config.shape
optics = CTFOptics()
ctf = optics(freqs)
optics_image = scattering_image * ctf
# NOTE(review): this assumes the zero-frequency component sits at the array
# centre (fftshifted layout) — confirm against jax_2dtm.utils.fft / fftfreqs.
optics_image = optics_image.at[N1 // 2, N2 // 2].set(0.0)
# Divide by sqrt(sum |F|^2) / (N1 * N2). With an unnormalized forward FFT and
# a zero-mean image, this presumably gives unit real-space standard
# deviation — TODO confirm.
spectral_power = jnp.sum(optics_image * jnp.conjugate(optics_image))
optics_image = optics_image / (jnp.sqrt(spectral_power) / (N1 * N2))

In [None]:
# Add one white (Gaussian) noise realization with sigma = 1 to the normalized
# optics image.
noise = WhiteNoise(sigma=1.0)
noise_realization = noise.sample(freqs, config)
noisy_image = optics_image + noise_realization

In [None]:
# Plot the three stages of the step-by-step pipeline in real space, with titles
# so the panels can be told apart.
fig, axes = plt.subplots(ncols=3, figsize=(12, 7))
ax1, ax2, ax3 = axes
pipeline_panels = (
    (ax1, scattering_image, "Scattering"),
    (ax2, optics_image, "Optics (CTF)"),
    (ax3, noisy_image, "Optics + noise"),
)
for ax, image, title in pipeline_panels:
    ax.imshow(ifft(image), origin="lower", cmap="gray")
    ax.set_title(title)
fig.tight_layout()

Now that we have confirmed the pipeline works step by step, let's demonstrate how to do the same thing with the high-level model API.

In [None]:
# Initialize the model parameters and bundle them into a single ParameterState.
# The pose matches the one used in the step-by-step pipeline above.
intensity = Intensity()
noise = WhiteNoise()
optics = CTFOptics()
pose = EulerPose(-50.0, -50.0, np.pi / 8, np.pi / 10, np.pi / 4)
state = ParameterState(
    pose=pose,
    optics=optics,
    noise=noise,
    intensity=intensity,
)


In [None]:
# Build three image models that share config, point cloud and parameter state.
# Judging by the class names, they add optics and then noise on top of the
# scattering image.
model_kwargs = dict(config=config, cloud=cloud, state=state)
scattering_model = ScatteringImage(**model_kwargs)
optics_model = OpticsImage(**model_kwargs)
noisy_model = GaussianImage(**model_kwargs)

In [None]:
# Plot models: render each model with its stored state and compare them in
# real space, one titled panel per model.
fig, axes = plt.subplots(ncols=3, figsize=(12, 6))
ax1, ax2, ax3 = axes
model_panels = (
    (ax1, scattering_model, "ScatteringImage"),
    (ax2, optics_model, "OpticsImage"),
    (ax3, noisy_model, "GaussianImage"),
)
for ax, image_model, title in model_panels:
    ax.imshow(ifft(image_model()), origin="lower", cmap="gray")
    ax.set_title(title)
fig.tight_layout()

Next, let's add image filters (anti-aliasing and whitening) to the model.

In [None]:
# Instantiate image filters. The whitening filter is built from a single
# simulated noisy micrograph. Both filters are passed to a new GaussianImage,
# which presumably applies them in order — confirm against jax_2dtm.
micrograph = noisy_model()
filters = [AntiAliasingFilter(config, freqs), WhiteningFilter(config, freqs, micrograph)]
filtered_model = GaussianImage(config=config, cloud=cloud, state=state, filters=filters)
fig, ax = plt.subplots()
ax.imshow(ifft(filtered_model()), origin="lower", cmap="gray")
ax.set_title("GaussianImage with anti-aliasing + whitening filters")

In [None]:
# Inspect the whitening filter (second entry of `filters`) on a log scale.
# NOTE(review): np.log gives -inf (with a RuntimeWarning) wherever the filter
# is exactly zero.
whitening = filters[1]
fig, ax = plt.subplots()
im = ax.imshow(np.log(whitening.filter), cmap="gray")
ax.set_title("log(whitening filter)")
fig.colorbar(im, ax=ax)

Computing an image is straightforward, but what we really want is a function that JAX can arbitrarily transform (e.g. with `jax.jit` or `jax.grad`) and that can be evaluated at a subset of the parameters.

In [None]:
# Subset of parameters at which to evaluate the model (the remaining
# parameters come from the model's stored state), plus a jitted wrapper
# around the noisy model.
params = {"view_phi": np.pi, "defocus_u": 9000.0, "sigma": 1.0, "N": 1.0, "mu": 10.0}


def _evaluate_noisy_model(p):
    return noisy_model(p)


jitted_noisy_model = jax.jit(_evaluate_noisy_model)

In [None]:
# Benchmark the jitted pipeline.
# The first call traces and compiles the model, so it is timed separately
# from a second (warm) call. JAX dispatches work asynchronously, so
# jax.block_until_ready is needed for the timers to cover the computation
# itself and not only the dispatch.
start = time.perf_counter()
jitted_noisy_image = jax.block_until_ready(jitted_noisy_model(params))
first_call_seconds = time.perf_counter() - start

start = time.perf_counter()
jax.block_until_ready(jitted_noisy_model(params))
warm_call_seconds = time.perf_counter() - start

print(
    f"jitted: first call (incl. compilation) {first_call_seconds * 1e3:.1f} ms, "
    f"warm call {warm_call_seconds * 1e3:.1f} ms"
)

In [None]:
# Benchmark the non-jitted pipeline for comparison with the jitted timings.
# Blocking on the result makes the timer include the asynchronous computation.
# NOTE(review): this rebinds `noisy_image`, which earlier held the
# step-by-step pipeline result.
start = time.perf_counter()
noisy_image = jax.block_until_ready(noisy_model(params))
eager_call_seconds = time.perf_counter() - start
print(f"non-jitted: {eager_call_seconds * 1e3:.1f} ms")

We can also use the model to compute the likelihood. Let's evaluate the likelihood at the simulated data and visualize the residuals.

In [None]:
# Initialize a model conditioned on an observed image (a fresh render of the
# noisy model), then plot simulated, observed and residual images in
# titled panels.
fig, axes = plt.subplots(ncols=3, figsize=(12, 7))
ax1, ax2, ax3 = axes
model = GaussianImage(config=config, cloud=cloud, state=state, observed=ifft(noisy_model()))
simulated, observed, residuals = model.render(), model.observed, model.residuals()
residual_panels = (
    (ax1, simulated, "Simulated"),
    (ax2, observed, "Observed"),
    (ax3, residuals, "Residuals"),
)
for ax, image, title in residual_panels:
    ax.imshow(ifft(image), origin="lower", cmap="gray")
    ax.set_title(title)
fig.tight_layout()


In [None]:
# Compute likelihood: wrap the model call once, then build the jitted loss and
# its jitted gradient with respect to `params` from that same function.
def _model_at(p):
    return model(p)


loss = jax.jit(_model_at)
grad_loss = jax.jit(jax.grad(_model_at))

In [None]:
# Evaluate the jitted loss at `params` (the first call also triggers
# compilation) and display the value.
likelihood = loss(params)
likelihood

In [None]:
# Evaluate the jitted gradient of the loss with respect to each entry of
# `params` and display it.
grad = grad_loss(params)
grad