## Let's demonstrate the imaging pipeline for a helical specimen.

In [None]:
# Jax imports
import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
from jax import config

# Keep JAX's 64-bit mode switched off so that floating-point arrays are
# created in single precision by default.
jax.config.update("jax_enable_x64", False)

In [None]:
# Plotting imports and function definitions
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

In [None]:
# Image simulator imports
import cryojax.simulator as cs
from cryojax.image import fftn, irfftn

In [None]:
def plot_image(image, fig, ax, cmap="gray", **kwargs):
    """Draw ``image`` on ``ax`` with a colorbar attached to its right edge.

    The image is drawn with ``origin="lower"``, so its first row appears at
    the bottom of the axes. Any extra keyword arguments are forwarded to
    ``ax.imshow``.

    Returns the same ``(fig, ax)`` pair that was passed in.
    """
    mappable = ax.imshow(image, cmap=cmap, origin="lower", **kwargs)
    # Carve a narrow axis (5% of the width) off the right side for the colorbar.
    colorbar_ax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
    fig.colorbar(mappable, cax=colorbar_ax)
    return fig, ax

def plot_images(images, labels=None, **kwargs):
    """Plot a row of images side by side, each with its own colorbar.

    Parameters
    ----------
    images :
        Sequence of images; one subplot column is created per image.
    labels :
        Optional sequence of subplot titles, matched to ``images`` by
        position. Must have at least ``len(images)`` entries; any extra
        entries are ignored.
    **kwargs :
        Forwarded to ``plot_image`` (and from there to ``ax.imshow``).

    Returns
    -------
    fig, axes :
        The figure and the axes, one per image.

    Raises
    ------
    ValueError
        If ``images`` is empty or ``labels`` has fewer entries than ``images``.
    """
    nimages = len(images)
    # Validate up front so a bad call fails before any figure is created,
    # instead of part-way through drawing the subplots.
    if nimages == 0:
        raise ValueError("plot_images requires at least one image.")
    if labels is not None and len(labels) < nimages:
        raise ValueError(
            f"Got {len(labels)} labels for {nimages} images; "
            "provide one label per image."
        )
    fig, axes = plt.subplots(ncols=nimages, figsize=(4 * nimages, 6))
    # plt.subplots returns a bare Axes (not a sequence) for a single column.
    if nimages == 1:
        axes = [axes]
    titles = labels if labels is not None else [None] * nimages
    for ax, image, title in zip(axes, images, titles):
        plot_image(image, fig, ax, **kwargs)
        if title is not None:
            ax.set(title=title)
    return fig, axes

def plot_net(theta, z, **kwargs):
    """Scatter-plot a helical net: azimuthal angle against axial rise.

    Extra keyword arguments are passed through to ``ax.scatter``.
    Returns the new ``(fig, ax)`` pair.
    """
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.scatter(theta, z, **kwargs)
    ax.set(
        xlabel=r"azimuthal angle, $\theta$",
        ylabel=r"axial rise, $z$",
    )
    return fig, ax

In [None]:
# Paths to the density maps used below: one for a single subunit and one for
# the full assembly. Both are relative to the current working directory.
subunit_filename, assembly_filename = (
    "../tests/data/3j9g_subunit_bfm1_ps4_4.mrc",
    "../tests/data/3j9g_bfm1_ps4_4.mrc",
)

In [None]:
# Helical symmetry parameters
n_start = 6  # number of starts (sub-helices)
n_subunits_per_start = 2  # number of subunits along each sub-helix
rise = 21.94532431  # Angstroms
twist = 29.571584705551697  # Degrees
# Displacement of the initial subunit (3 components); presumably Angstroms
# like `rise` -- TODO confirm units.
r_0 = jnp.array([-88.70895129, 9.75357114, 0.0], dtype=float)

In [None]:
# Center-of-mass pose shared by the assembly and the helix: all offsets and
# view angles are zero.
pose = cs.EulerPose(
    offset_x=0.0, offset_y=0.0, view_phi=0.0, view_theta=0.0, view_psi=0.0
)
# Load both density maps as FourierVoxelGrids (pad_scale=1.2)
subunit_density = cs.FourierVoxelGrid.from_file(subunit_filename, pad_scale=1.2)
assembly_density = cs.FourierVoxelGrid.from_file(assembly_filename, pad_scale=1.2)

# ... wrap each density in an Ensemble. The initial subunit's pose is built by
# unpacking r_0 positionally into EulerPose.
# NOTE(review): confirm EulerPose's positional field order matches the layout
# of r_0. Its third component is 0.0 here, so a mismatch would go unnoticed.
initial_subunit = cs.Ensemble(density=subunit_density, pose=cs.EulerPose(*r_0))
true_assembly = cs.Ensemble(density=assembly_density, pose=pose)

# ... build the Helix from the initial subunit and the helical parameters
helix = cs.Helix(
    subunit=initial_subunit,
    pose=pose,
    rise=rise,
    twist=twist,
    n_start=n_start,
    n_subunits_per_start=n_subunits_per_start,
)

In [None]:
# View the helical net: azimuthal angle vs. axial position of each lattice point
lattice = helix.positions
# assumes `positions` is (n_points, 3) with columns (x, y, z) -- the indexing
# below relies on that layout; TODO confirm against cs.Helix
theta = jnp.arctan2(lattice[:, 1], lattice[:, 0])  # radians, in [-pi, pi]
z = lattice[:, 2]
# Bind the returned (fig, ax) so the notebook does not echo the tuple's repr
# as the cell output.
fig, ax = plot_net(theta, z)

In [None]:
# Image settings: an 80x80 image whose pixel size is taken from the assembly
# map's voxel size, managed with pad_scale=1.4.
shape = (80, 80)
pixel_size = assembly_density.voxel_size
manager = cs.ImageManager(shape=shape, pad_scale=1.4)
# Projection method (Fourier-slice extraction) shared by both pipelines
scattering = cs.FourierSliceExtract(manager, pixel_size=pixel_size)

# ... two image formation models using the same scattering method:
# the reference assembly via ImagePipeline, and the helix's subunits via
# SuperpositionPipeline.
truth = cs.ImagePipeline(scattering=scattering, ensemble=true_assembly)
simulated = cs.SuperpositionPipeline(
    scattering=scattering,
    ensemble=helix.subunits,
)

In [None]:
# Render both models and show them next to their pixel-wise difference
im1 = simulated.render()
im2 = truth.render()
fig, axes = plot_images(
    [im1, im2, im1 - im2],
    labels=["Cryojax Assembly", "Ground truth", "Difference map"],
)
plt.tight_layout()