In [None]:
# Jax imports
import jax
import numpy as np

# Keep JAX in 32-bit precision (float32 / int32). This matches JAX's own default
# unless the JAX_ENABLE_X64 environment variable is set; stating it makes the run
# reproducible. Going through `jax.config` (instead of `from jax import config`)
# keeps the global name `config` free for the ScatteringConfig created later.
jax.config.update("jax_enable_x64", False)

In [None]:
from matplotlib import pyplot as plt

In [None]:
# IO utils
from jax_2dtm.io import load_grid_as_cloud

In [None]:
# Image simulator
from jax_2dtm.simulator import compute_anti_aliasing_filter
from jax_2dtm.simulator import Cloud, ScatteringConfig
from jax_2dtm.simulator import ScatteringImage, OpticsImage
from jax_2dtm.simulator import EulerPose, CTFOptics, ParameterState
from jax_2dtm.utils import fft, ifft, fftfreqs

In [None]:
# Template volume: MRC file on disk (path relative to this notebook's directory)
filename = "../tests/data/3jar_13pf_bfm1_ps5_28.mrc"
# Pixel size of that volume in Angstroms; the "ps5_28" suffix in the file name
# appears to encode the same value.
pixel_size = 5.28

In [None]:
# Configure image data and read template as point cloud.
# (60, 60) is presumably the output image shape in pixels (it is read back below as
# `config.shape`); `pixel_size` is read back as `config.pixel_size`.
# NOTE(review): the role of eps=1e-4 is defined inside jax_2dtm — TODO confirm.
config = ScatteringConfig((60, 60), pixel_size, eps=1e-4)
# Load the MRC volume as a point cloud for this configuration.
# NOTE(review): threshold=1e-4 presumably drops voxels with (near-)zero density —
# verify against jax_2dtm.io.load_grid_as_cloud.
cloud = load_grid_as_cloud(filename, config, threshold=1e-4)

In [None]:
# Compute scattering image.
# Frequency grid for the configured image shape and pixel size.
freqs = fftfreqs(config.shape, config.pixel_size)
# NOTE(review): EulerPose's positional arguments are not visible here; the first two
# (50.0, 50.0) look like in-plane offsets and the last three like Euler angles in
# radians — confirm against jax_2dtm.simulator.EulerPose.
pose = EulerPose(50.0, 50.0, np.pi/8, np.pi/10, np.pi/4)
# Apply the pose to the point cloud.
transformed_cloud = cloud.view(pose)
# Project the posed cloud with the current config and multiply by the anti-aliasing
# filter evaluated on the same frequency grid. The result is presumably in Fourier
# space, since it is passed through ifft before plotting.
scattering_image = compute_anti_aliasing_filter(freqs, config.pixel_size) * transformed_cloud.project(config)

In [None]:
# Apply optics model.
# CTFOptics with its default parameters; evaluating it on the frequency grid gives the CTF.
optics = CTFOptics()
ctf = optics(freqs)
# Modulate the scattering image by the CTF (product over the same frequency grid).
optics_image = scattering_image * ctf

In [None]:
# Plot scattering: inverse-FFT of the image before (left) and after (right) the
# CTF is applied, so the effect of the optics model can be compared directly.
fig, axes = plt.subplots(ncols=2)
ax1, ax2 = axes
ax1.imshow(ifft(scattering_image), origin="lower", cmap="gray")
ax1.set_title("Scattering image")
ax2.imshow(ifft(optics_image), origin="lower", cmap="gray")
ax2.set_title("Optics image (with CTF)")
# Trailing ';' suppresses the return-value repr in the cell output.
fig.tight_layout();

In [None]:
# Sanity check that ffts are correct: fft followed by ifft must reproduce the
# real-space image, so this residual should be ~0 everywhere (up to float round-off).
# The real-space image is computed once and reused on both sides of the difference.
real_space_image = ifft(scattering_image)
fft_roundtrip_residual = ifft(fft(real_space_image)) - real_space_image
plt.imshow(fft_roundtrip_residual, origin="lower", cmap="gray")
plt.title("Round-trip residual: ifft(fft(x)) - x")
plt.colorbar();

Now that we have checked each stage of the pipeline step by step, let's show how to compute the same image with the high-level API (`ParameterState` + `OpticsImage`).

In [None]:
# Initialize model, parameters, and compute image through the high-level API.
# The result gets its own name so the step-by-step `optics_image` computed earlier
# is not silently overwritten (otherwise re-running the side-by-side plot cell would
# display this image instead of the manually composed one).
state = ParameterState(pose=pose, optics=optics)
model = OpticsImage(config=config, cloud=cloud, state=state)
api_optics_image = model(state)
plt.imshow(ifft(api_optics_image), origin="lower", cmap="gray")
plt.title("Optics image (OpticsImage API)");

Computing a single image is straightforward, but what we really want is a function that can be transformed by JAX (e.g. with `jax.jit`) and evaluated at any subset of the parameters.

In [None]:
# Subset of parameters at which the model is evaluated below, keyed by name.
# NOTE(review): presumably any parameter not listed keeps the value held by the
# model's `state` — confirm against OpticsImage.__call__.
params = {"view_phi": np.pi, "defocus_u": 9000}

In [None]:
@jax.jit
def compute_image(params):
    """Evaluate the notebook's `model` (an OpticsImage) at `params`, compiled with jit.

    Returns exactly what `model(params)` returns, i.e. the optics image rather than a
    bare scattering image.

    NOTE(review): `model` is a notebook global that is captured when the function is
    first traced. Rebinding `model` afterwards is not seen by later calls with the
    same argument structure until this cell is re-run.
    """
    return model(params)

In [None]:
# Benchmark jitted pipeline
%timeit optics_image = compute_image(params).block_until_ready()

In [None]:
# Benchmark non-jitted pipeline
%timeit optics_image = model(params)