## Let's demonstrate the imaging pipeline for a helical specimen.

In [None]:
# Jax imports
import jax
import jax.numpy as jnp
from jax import config


# Keep JAX in its default 32-bit precision. Reference the flag through the
# `jax` module rather than the bare name `config`: later in this notebook
# `config` is rebound to a `cs.ImageConfig`, so re-running this cell after
# that point would otherwise call `.update` on the wrong object.
jax.config.update("jax_enable_x64", False)

In [None]:
# Plotting imports and function definitions
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

In [None]:
# Image simulator imports
import cryojax.simulator as cs
from cryojax.io import read_array_with_spacing_from_mrc

In [None]:
def plot_image(image, fig, ax, cmap="gray", **kwargs):
    """Draw ``image`` on ``ax`` and attach a colorbar along its right edge.

    Parameters
    ----------
    image :
        Array handed to ``ax.imshow`` (drawn with ``origin="lower"``).
    fig :
        Figure that owns ``ax``; used to create the colorbar.
    ax :
        Axes to draw into.
    cmap :
        Colormap name forwarded to ``imshow``.
    **kwargs :
        Extra keyword arguments forwarded to ``imshow``.

    Returns
    -------
    The ``(fig, ax)`` pair that was passed in.
    """
    mappable = ax.imshow(image, origin="lower", cmap=cmap, **kwargs)
    # Carve a thin axis off the right side of ``ax`` to host the colorbar.
    colorbar_axes = make_axes_locatable(ax).append_axes(
        "right", size="5%", pad=0.05
    )
    fig.colorbar(mappable, cax=colorbar_axes)
    return fig, ax


def plot_images(images, labels=None, **kwargs):
    """Plot every entry of ``images`` side by side in a single row.

    Parameters
    ----------
    images :
        Indexable sequence of arrays; each one is drawn with ``plot_image``.
    labels :
        Optional sequence of titles, indexed in step with ``images``.
    **kwargs :
        Forwarded to ``plot_image`` (and from there to ``imshow``).

    Returns
    -------
    ``(fig, axes)``. ``axes`` is a one-element list when exactly one image is
    given; otherwise it is whatever ``plt.subplots`` returned.
    """
    count = len(images)
    fig, axes = plt.subplots(ncols=count, figsize=(4 * count, 6))
    # With a single column plt.subplots hands back a bare Axes rather than a
    # collection; wrap it so the loop below can iterate uniformly.
    if count == 1:
        axes = [axes]
    for position, ax in enumerate(axes):
        plot_image(images[position], fig, ax, **kwargs)
        if labels is not None:
            ax.set(title=labels[position])
    return fig, axes


def plot_net(theta, z, **kwargs):
    """Scatter-plot a helical net: azimuthal angle against axial rise.

    ``theta`` and ``z`` are the point coordinates; ``kwargs`` are forwarded to
    ``ax.scatter``. Returns the newly created ``(fig, ax)`` pair.
    """
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.scatter(theta, z, **kwargs)
    ax.set_xlabel(r"azimuthal angle, $\theta$")
    ax.set_ylabel(r"axial rise, $z$")
    return fig, ax

In [None]:
# Paths to the MRC volumes read below with read_array_with_spacing_from_mrc.
# Both are relative paths, so they resolve against the working directory.
# ... a single subunit, replicated to build the helix
subunit_filename = "../../tests/data/3j9g_subunit_potential_ps4_4.mrc"
# ... the full assembly, used as the ground-truth specimen
assembly_filename = "../../tests/data/3j9g_potential_ps4_4.mrc"

In [None]:
# Helical symmetry parameters, passed to cs.Helix below
rise = 21.94532431  # axial rise, in Angstroms
twist = 29.571584705551697  # twist, in degrees
n_start = 6  # number of starts (sub-helices)
n_subunits_per_start = 2  # subunits along each sub-helix
# Displacement (x, y, z) of the initial subunit; later unpacked into a pose
initial_displacement = [-88.70895129, 9.75357114, 0.0]
r_0 = jnp.asarray(initial_displacement, dtype=float)
del initial_displacement

In [None]:
# Build the subunit and assembly potentials, a shared pose, and the helix
# ... read the real-space voxel grid of one subunit and its voxel size
subunit_real_voxel_grid, subunit_voxel_size = read_array_with_spacing_from_mrc(
    subunit_filename
)
# ... convert it into a FourierVoxelGridPotential. The original note here
# said the subunit should sit in the box at the +z direction.
# NOTE(review): nothing in this cell enforces that; confirm it holds for the
# MRC file. pad_scale is presumably a padding factor applied to the grid
# (confirm against cryojax).
subunit_potential = cs.FourierVoxelGridPotential.from_real_voxel_grid(
    subunit_real_voxel_grid,
    subunit_voxel_size,
    pad_scale=1.5,
)
# ... and do the same for the whole assembly (with a smaller pad_scale)
assembly_real_voxel_grid, assembly_voxel_size = read_array_with_spacing_from_mrc(
    assembly_filename
)
assembly_potential = cs.FourierVoxelGridPotential.from_real_voxel_grid(
    assembly_real_voxel_grid,
    assembly_voxel_size,
    pad_scale=1.2,
)
# Pose with zero x/y offsets and zero view angles. It is shared by the
# ground-truth assembly and by the helix below.
pose = cs.EulerAnglePose(
    offset_x_in_angstroms=0.0,
    offset_y_in_angstroms=0.0,
    view_phi=0.0,
    view_theta=0.0,
    view_psi=0.0,
)

# Initialize the Specimen objects; one integrator instance serves both
integrator = cs.FourierSliceExtract()
# NOTE(review): `*r_0` fills EulerAnglePose's fields positionally with the
# three components of r_0. This depends on the class's positional field
# order; confirm against its definition (the keyword form used for `pose`
# above would be less fragile).
initial_subunit = cs.Specimen(subunit_potential, integrator, cs.EulerAnglePose(*r_0))
true_assembly = cs.Specimen(assembly_potential, integrator, pose)

# Initialize the Helix from the initial subunit and the helical parameters,
# with n_subunits_per_start * n_start subunits in total
helix = cs.Helix(
    subunit=initial_subunit,
    pose=pose,
    rise=rise,
    twist=twist,
    n_start=n_start,
    n_subunits=n_subunits_per_start * n_start,
)

In [None]:
# View the helical net: azimuthal angle vs. axial position of each subunit.
# offsets_in_angstroms is indexed as [subunit, component] below; presumably
# shape (n_subunits, 3) holding x, y, z (confirm against cs.Helix).
lattice = helix.offsets_in_angstroms
theta = jnp.arctan2(lattice[:, 1], lattice[:, 0])  # angle in the x-y plane, radians
z = lattice[:, 2]
# Bind the returned figure/axes, as the other plotting cells do, so the
# notebook does not echo the (fig, ax) tuple as cell output.
fig, ax = plot_net(theta, z)

In [None]:
# Configure the image settings and projection method
shape = (80, 80)  # image shape in pixels
# ... use the assembly potential's voxel size as the image pixel size
pixel_size = assembly_potential.voxel_size
# NOTE: this rebinds `config`, which at the top of the notebook is the name
# imported via `from jax import config`. From here on `config` refers to this
# ImageConfig, not to JAX's configuration object.
config = cs.ImageConfig(shape, pixel_size, pad_scale=1.4)

# ... instantiate the image formation models (no instrument is passed here):
# the helix assembled from subunits, and the full-assembly ground truth
simulated = cs.AssemblyPipeline(config, helix)
truth = cs.ImagePipeline(config, true_assembly)

In [None]:
# Render both models and plot them next to their pixel-wise difference
# (simulated minus ground truth)
im1, im2 = simulated.render(), truth.render()
fig, axes = plot_images(
    [im1, im2, im1 - im2],
    labels=["Cryojax Assembly", "Ground truth", "Difference map"],
)
plt.tight_layout()

In [None]:
# Now, create an instrument
# ... CTF with equal u and v defocus (10000 Angstroms each) and an
# amplitude contrast ratio of 0.07
ctf = cs.CTF(
    defocus_u_in_angstroms=10000.0,
    defocus_v_in_angstroms=10000.0,
    amplitude_contrast_ratio=0.07,
)
optics = cs.WeakPhaseOptics(ctf)
# ... dose of 1000 electrons per square Angstrom
dose = cs.ElectronDose(electrons_per_angstrom_squared=1000.0)
# ... GaussianDetector with an ideal DQE where every electron is detected
detector = cs.GaussianDetector(dqe=cs.IdealDQE(fraction_detected_electrons=1.0))
instrument = cs.Instrument(optics, dose, detector)
# ... and their respective pipelines. This rebinds `simulated` and `truth`,
# replacing the instrument-free pipelines built earlier.
simulated = cs.AssemblyPipeline(config, helix, instrument)
truth = cs.ImagePipeline(config, true_assembly, instrument)

In [None]:
# Sample from the instrument models
key = jax.random.PRNGKey(1234)  # fixed seed, so the samples are reproducible
# NOTE(review): the same key is passed to both pipelines. Any noise derived
# from it is therefore presumably the same realization in both images, which
# lets them differ only in signal; confirm this is intended.
im1, im2 = simulated.sample(key), truth.sample(key)
fig, axes = plot_images(
    [im1, im2],
    labels=["Cryojax Assembly", "Ground truth"],
)
plt.tight_layout()