## Let's demonstrate the imaging pipeline for a helical specimen.

In [None]:
# Jax imports
import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
from jax import config

config.update("jax_enable_x64", False)

In [None]:
# Plotting imports and function definitions
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

In [None]:
# Image simulator imports
import cryojax.simulator as cs
from cryojax.utils import fftn, irfftn

In [None]:
def plot_image(image, fig, ax, cmap="gray", **kwargs):
    """Draw ``image`` on ``ax`` and attach a colorbar to its right edge.

    Parameters
    ----------
    image : array-like
        Image passed to ``ax.imshow`` (drawn with ``origin="lower"``).
    fig : matplotlib.figure.Figure
        Figure that owns ``ax``; used to create the colorbar.
    ax : matplotlib.axes.Axes
        Axes to draw the image into.
    cmap : str, optional
        Colormap name, ``"gray"`` by default.
    **kwargs
        Extra keyword arguments forwarded to ``ax.imshow``.

    Returns
    -------
    fig, ax
        The same figure and axes that were passed in.
    """
    mappable = ax.imshow(image, cmap=cmap, origin="lower", **kwargs)
    # Split a narrow axes (5% of the width, small gap) off the right side of
    # ``ax`` so the colorbar sits flush next to the image.
    colorbar_ax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
    fig.colorbar(mappable, cax=colorbar_ax)
    return fig, ax

def plot_images(images, **kwargs):
    """Plot a sequence of images side by side, one image per column.

    Parameters
    ----------
    images : sequence of array-like
        Images to draw, left to right. Must support ``len`` and contain at
        least one image.
    **kwargs
        Forwarded unchanged to ``plot_image`` for every image.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The newly created figure.
    axes : list or array of matplotlib.axes.Axes
        One axes per image, in the same order as ``images``.

    Raises
    ------
    ValueError
        If ``images`` is empty.
    """
    nimages = len(images)
    # Validate before creating a figure so an empty input does not leave an
    # orphan, empty figure open.
    if nimages == 0:
        raise ValueError("plot_images requires at least one image")
    fig, axes = plt.subplots(ncols=nimages, figsize=(4 * nimages, 6))
    # With a single column plt.subplots returns a bare Axes rather than an
    # array; wrap it so the loop below can iterate uniformly.
    if nimages == 1:
        axes = [axes]
    for image, ax in zip(images, axes):
        plot_image(image, fig, ax, **kwargs)
    return fig, axes

def plot_net(theta, z, **kwargs):
    """Scatter-plot a helical net: azimuthal angle versus axial rise.

    Parameters
    ----------
    theta : array-like
        Azimuthal angle of each point (x-axis).
    z : array-like
        Axial rise of each point (y-axis).
    **kwargs
        Extra keyword arguments forwarded to ``ax.scatter``.

    Returns
    -------
    fig, ax
        The new 4x4 figure and its single axes.
    """
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.set_xlabel(r"azimuthal angle, $\theta$")
    ax.set_ylabel(r"axial rise, $z$")
    ax.scatter(theta, z, **kwargs)
    return fig, ax

In [None]:
# Path to the monomer density map (.mrc file), relative to the current
# working directory.
filename = "../tests/data/3jar_monomer_bfm1_ps5_28.mrc"

In [None]:
# Helical symmetry parameters (lengths in Angstroms, angles in degrees)
rise = 10.8
twist = -23.8
n_start = 2
radius = 250.0 / 2  # = 125.0 Angstroms
n_subunits = 30

In [None]:
# Build the helical specimen: load the monomer density from disk, wrap it as
# a single-subunit specimen, then replicate it with the helical parameters.
resolution = 5.28  # Angstroms
density = cs.ElectronGrid.from_file(filename, config={"pad_scale": 1.5})
monomer = cs.Specimen(density=density, resolution=resolution)
helix = cs.Helix(
    subunit=monomer,
    rise=rise,
    twist=twist,
    radius=radius,
    n_start=n_start,
    n_subunits=n_subunits,
)

In [None]:
# View the helical net: each lattice point plotted as (azimuthal angle, rise).
lattice = helix.lattice
# Assumes lattice rows are Cartesian (x, y, z) positions - TODO confirm.
theta = jnp.arctan2(lattice[:, 1], lattice[:, 0])  # radians
z = lattice[:, 2]
# Bind the return value so the notebook does not echo the (Figure, Axes)
# tuple repr as cell output.
fig_net, ax_net = plot_net(theta, z)

In [None]:
# Scattering method: Fourier-slice image formation on an 81 x 82 grid,
# with a padding scale of 1.5.
shape, pad_scale = (81, 82), 1.5
scattering = cs.FourierSliceScattering(
    shape=shape,
    pad_scale=pad_scale,
)

In [None]:
# Pipeline state shared by the image models: pose, CTF optics and exposure.
pose = cs.EulerPose(
    offset_x=0.0,
    offset_y=0.0,
    view_phi=0.0,
    view_theta=90.0,
    view_psi=0.0,
)
optics = cs.CTFOptics(
    defocus_u=10000,
    defocus_v=10000,
    amplitude_contrast=0.07,
)
exposure = cs.UniformExposure()
state = cs.PipelineState(pose=pose, optics=optics, exposure=exposure)

In [None]:
# Two image models built from the same scattering method, specimen and state.
scattering_model = cs.ScatteringImage(
    scattering=scattering,
    specimen=helix,
    state=state,
)
optics_model = cs.OpticsImage(
    scattering=scattering,
    specimen=helix,
    state=state,
)

In [None]:
# Wrap both models with eqx.filter_jit; compilation happens on first call.
jitted_scattering, jitted_optics = map(
    eqx.filter_jit, (scattering_model, optics_model)
)

In [None]:
# Render the scattering image and the optics image side by side.
fig, axes = plot_images(
    [jitted_scattering(), jitted_optics()],
)
# plot_images created `fig`, so it is the current figure; lay it out directly.
fig.tight_layout()

In [None]:
%timeit jitted_optics()