## Let's demonstrate the imaging pipeline for a helical specimen.

In [None]:
# Jax imports
import jax
import jax.numpy as jnp
import numpy as np
from jax import config

# Leave 64-bit mode switched off so JAX keeps its default 32-bit dtypes
jax.config.update("jax_enable_x64", False)

In [None]:
# Plotting imports and function definitions
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

In [None]:
# Image simulator imports
import cryojax.simulator as cs
from cryojax.utils import fft, irfft

In [None]:
def plot_image(image, fig, ax, cmap="gray", **kwargs):
    """Show ``image`` on ``ax`` and attach a colorbar to the right of it.

    Parameters
    ----------
    image
        Array-like passed straight to ``ax.imshow``.
    fig
        Figure that owns ``ax``; used to create the colorbar.
    ax
        Axes on which the image is drawn.
    cmap
        Colormap name forwarded to ``imshow`` (defaults to ``"gray"``).
    **kwargs
        Any further keyword arguments forwarded to ``imshow``.

    Returns
    -------
    tuple
        The same ``(fig, ax)`` pair that was passed in.
    """
    # origin="lower" places the first row of the array at the bottom of the plot
    mappable = ax.imshow(image, cmap=cmap, origin="lower", **kwargs)
    # Carve a narrow strip off the right of ``ax`` so the colorbar matches its height
    colorbar_axes = make_axes_locatable(ax).append_axes(
        "right", size="5%", pad=0.05
    )
    fig.colorbar(mappable, cax=colorbar_axes)
    return fig, ax

In [None]:
# Path to the MRC volume file that is loaded as the subunit density below
# (a relative path, so it is resolved against the current working directory)
filename = "../tests/data/3jar_monomer_bfm1_ps5_28.mrc"

In [None]:
# Helical parameters (passed to cs.Helix and reused for the helical net plot)
# NOTE(review): rise/twist are presumably the axial shift and rotation between
# consecutive subunits — confirm against the cs.Helix documentation.
rise = 9.42  # Angstroms
twist = 27.71  # Degrees
repeat = 400.0  # Angstroms

In [None]:
# Load the monomer density from file, wrap it in a Specimen, and build the
# helix that uses this monomer as its subunit
resolution = 5.28  # Angstroms
density = cs.ElectronGrid.from_file(filename, config={"pad_scale": 1.5})
monomer = cs.Specimen(density=density, resolution=resolution)
helix = cs.Helix(
    subunit=monomer,
    rise=rise,
    twist=twist,
    repeat=repeat,
)

In [None]:
# Set up the Fourier-slice scattering method with the image shape and pad scale
shape = (81, 82)
pad_scale = 1.5
scattering = cs.FourierSliceScattering(
    shape=shape,
    pad_scale=pad_scale,
)

In [None]:
# Build the pipeline state from a pose (no offset, all view angles zero)
# and a CTF optics model
pose = cs.EulerPose(
    offset_x=0.0,
    offset_y=0.0,
    view_phi=0.0,
    view_theta=0.0,
    view_psi=0.0,
)
optics = cs.CTFOptics(
    defocus_u=10000,
    defocus_v=10000,
    amplitude_contrast=0.07,
)
state = cs.PipelineState(pose=pose, optics=optics)

In [None]:
# Two image models built from the same scattering method, specimen and state:
# one of type ScatteringImage and one of type OpticsImage
scattering_model = cs.ScatteringImage(
    scattering=scattering,
    specimen=helix,
    state=state,
)
optics_model = cs.OpticsImage(
    scattering=scattering,
    specimen=helix,
    state=state,
)

In [None]:
# Render both models side by side: scattering image on the left,
# optics image on the right
fig, axes = plt.subplots(ncols=2, figsize=(8, 6))
ax1, ax2 = axes[0], axes[1]
im1 = plot_image(scattering_model(), fig, ax1)
im2 = plot_image(optics_model(), fig, ax2)
fig.tight_layout()

In [None]:
# Parameters of tobacco mosaic virus
#twist = 22.03  # Degrees
#rise = 1.408  # Angstroms
#repeat = 69  # Angstroms

Let's plot the helical net (azimuthal angle versus height along the helical axis) given by these parameters.

In [None]:
# Quantities derived from rise, twist and repeat
# Helical pitch: axial distance (Angstroms) covered by one full 360-degree turn
pitch = (2 * np.pi * rise) / np.deg2rad(twist)
# Number of full turns contained in one repeat
turns_per_repeat = repeat / pitch
# Number of subunits contained in one repeat
subunits_per_repeat = repeat / rise

In [None]:
# Evaluate repeat / rise as the cell's last expression so its value is displayed
subunits_per_repeat

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Generate the helical lattice: one point per subunit within a single repeat.
# Subunit k is rotated by k * twist about the helical axis and shifted by
# k * rise along it. (The previous point count, turns_per_repeat *
# subunits_per_repeat, multiplied two per-repeat quantities and so placed
# points that did not correspond to subunit positions.)
n_subunits = int(round(subunits_per_repeat))
subunit_index = np.arange(n_subunits)
t = subunit_index * np.deg2rad(twist)  # Cumulative rotation angle (radians)
x = np.cos(t)
y = np.sin(t)
z = pitch * t / (2 * np.pi)  # Axial height (Angstroms); equals subunit_index * rise

# Wrap the cumulative angle into [-pi, pi] to obtain each subunit's azimuth
azimuth = np.arctan2(y, x)

# Create the 2D helical net: azimuth on the horizontal axis, height on the vertical
fig, ax = plt.subplots()
ax.scatter(azimuth, z)

# Label the axes with the quantities that are actually plotted
ax.set_xlabel('Azimuth (radians)')
ax.set_ylabel('z (Angstroms)')

# Show the helical net
plt.grid(True)
plt.show()