In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# Jax imports
import jax
import jax.numpy as jnp

In [None]:
# Plotting imports and functions
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable


def plot_image(image, fig, ax, cmap="gray", label=None, **kwargs):
    """Show `image` on `ax` with a colorbar placed to its right.

    Extra keyword arguments are forwarded to `ax.imshow`. When `label` is
    given it becomes the axes title. Returns the `(fig, ax)` pair.
    """
    mappable = ax.imshow(image, cmap=cmap, origin="lower", **kwargs)
    # Append a narrow axes to the right of `ax` to host the colorbar
    colorbar_ax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
    fig.colorbar(mappable, cax=colorbar_ax)
    if label is None:
        return fig, ax
    ax.set(title=label)
    return fig, ax

In [None]:
# CryoJAX imports
from jaxtyping import install_import_hook


with install_import_hook("cryojax", "typeguard.typechecked"):
    import cryojax as cx
    import cryojax.simulator as cxs
    from cryojax.data import generate_starfile, RelionParticleStack
    from cryojax.io import read_atoms_from_pdb
    from cryojax.rotations import SO3

# Generating a starfile

To generate a starfile, we will first create a cryojax `RelionParticleStack`.

In [None]:
from functools import partial

import equinox as eqx
import equinox.internal as eqxi
from jaxtyping import PRNGKeyArray

from cryojax.image import operators as op


@partial(eqx.filter_vmap, in_axes=(0, None), out_axes=eqxi.if_mapped(axis=0))
def make_relion_dataset(
    key: PRNGKeyArray, instrument_config: cxs.InstrumentConfig
) -> RelionParticleStack:
    """Draw random pose and CTF parameters and pack them in a `RelionParticleStack`.

    The function is vmapped over `key` (`in_axes=(0, None)`): pass an array of
    PRNG keys and a single `instrument_config` to get the parameters for one
    particle per key.

    Random draws (each one uses its own subkey):
      - rotation: `SO3.sample_uniform`
      - in-plane offset: uniform in +/-0.45 of the image extent along each axis,
        scaled by the pixel size; the third (out-of-plane) component is zero
      - defocus and astigmatism: uniform in [1000, 1500)
      - astigmatism angle: uniform in [0, pi)

    Fixed values: spherical aberration 2.7, amplitude contrast ratio 0.1, phase
    shift 0.0, envelope B-factor 170.0 and envelope amplitude 1.0.

    Args:
        key: PRNG key for a single particle (before vmapping).
        instrument_config: Supplies the image shape, pixel size and voltage.

    Returns:
        A `RelionParticleStack` built from `instrument_config`, the sampled pose
        and the sampled transfer theory.
    """
    # Derive one independent subkey per random draw. A JAX key must not be
    # consumed by a sampler and then split again, otherwise the draws are
    # correlated; the parent `key` is therefore only ever split, never sampled.
    (
        rotation_key,
        offset_key,
        defocus_key,
        astigmatism_key,
        astigmatism_angle_key,
    ) = jax.random.split(key, 5)

    # Pose
    # ... instantiate rotations
    rotation = SO3.sample_uniform(rotation_key)

    # ... now in-plane translation
    ny, nx = instrument_config.shape
    in_plane_offset_in_angstroms = (
        jax.random.uniform(offset_key, (2,), minval=-0.45, maxval=0.45)
        * jnp.asarray((nx, ny))
        * instrument_config.pixel_size
    )
    # ... convert 2D in-plane translation to 3D, setting the out-of-plane translation to
    # zero
    offset_in_angstroms = jnp.pad(in_plane_offset_in_angstroms, ((0, 1),))
    # ... build the pose
    pose = cxs.EulerAnglePose.from_rotation_and_translation(rotation, offset_in_angstroms)

    # CTF Parameters
    # ... random values
    defocus_in_angstroms = jax.random.uniform(
        defocus_key, (), minval=1000, maxval=1500
    )
    astigmatism_in_angstroms = jax.random.uniform(
        astigmatism_key, (), minval=1000, maxval=1500
    )
    astigmatism_angle = jax.random.uniform(
        astigmatism_angle_key, (), minval=0, maxval=jnp.pi
    )

    # ... non-random values
    spherical_aberration_in_mm = 2.7
    amplitude_contrast_ratio = 0.1
    phase_shift = 0.0
    b_factor = 170.0
    ctf_scale_factor = 1.0

    # ... build the CTF
    transfer_theory = cxs.ContrastTransferTheory(
        ctf=cxs.ContrastTransferFunction(
            defocus_in_angstroms=defocus_in_angstroms,
            astigmatism_in_angstroms=astigmatism_in_angstroms,
            astigmatism_angle=astigmatism_angle,
            voltage_in_kilovolts=instrument_config.voltage_in_kilovolts,
            spherical_aberration_in_mm=spherical_aberration_in_mm,
            amplitude_contrast_ratio=amplitude_contrast_ratio,
            phase_shift=phase_shift,
        ),
        envelope=op.FourierGaussian(b_factor=b_factor, amplitude=ctf_scale_factor),
    )

    return RelionParticleStack(
        instrument_config=instrument_config,
        pose=pose,
        transfer_theory=transfer_theory,
    )

In [None]:
# Generate instrument config
# (image shape, pixel size and voltage are shared by every particle in the stack)
instrument_config = cxs.InstrumentConfig(
    shape=(128, 128),
    pixel_size=1.5,
    voltage_in_kilovolts=300.0,
    pad_scale=1.0,  # no padding
)

# Generate RNG keys: one key per particle, derived from a fixed seed
number_of_images = 10
keys = jax.random.split(jax.random.PRNGKey(0), number_of_images)

# ... instantiate the RelionParticleStack
# `make_relion_dataset` is vmapped over its first argument, so passing the whole
# array of keys builds the parameters for all `number_of_images` particles at once
relion_particle_stack = make_relion_dataset(keys, instrument_config)

In [None]:
# Display the envelope stored in the generated stack's transfer theory
relion_particle_stack.transfer_theory.envelope

In [None]:
# ... generate the starfile
# NOTE(review): `mrc_batch_size` presumably sets how many particles are assigned to
# each MRC file named in the starfile — confirm against `generate_starfile`
generate_starfile(relion_particle_stack, "relion_dataset.star", mrc_batch_size=10)

# Generating particles from starfile

In the previous step we generated a starfile. To allow for more flexibility in the simulation of images, we have split the generation of starfiles and mrcfiles into two steps. In this step we will define an imaging pipeline template, which will be used to generate the mrcfiles.

In [None]:
# First, load the scattering potential and projection method
# Instantiate the scattering potential
filename = "./data/groel_chainA.pdb"
atom_positions, atom_identities, b_factors = read_atoms_from_pdb(
    filename, assemble=False, get_b_factors=True
)
atomic_potential = cxs.PengAtomicPotential(atom_positions, atom_identities, b_factors)
# Use a cubic voxel grid whose side length and voxel size match the image
# (`shape[0]` is used for all three axes, which assumes a square image)
box_size = instrument_config.shape[0]
voxel_size = instrument_config.pixel_size

# Evaluate the atomic potential on a real-space grid, then build the Fourier-space
# voxel grid potential from it
real_voxel_grid = atomic_potential.as_real_voxel_grid(
    shape=(box_size, box_size, box_size), voxel_size=voxel_size
)
potential = cxs.FourierVoxelGridPotential.from_real_voxel_grid(
    real_voxel_grid, voxel_size, pad_scale=2
)

# Projection method used by the scattering theory defined later in the notebook
potential_integrator = cxs.FourierSliceExtraction(interpolation_order=1)

Now we will create a template for our imaging pipeline. The pose and transfer theory are created with their default values, since they will be replaced by the values read from the starfile we generated in the previous step.

In [None]:
# Placeholder pose and transfer theory built with default arguments; they are
# swapped for the starfile values when the particles are generated
pose = cxs.EulerAnglePose()

transfer_theory = cxs.ContrastTransferTheory(ctf=cxs.ContrastTransferFunction())

structural_ensemble = cxs.SingleStructureEnsemble(potential, pose)

scattering_theory = cxs.WeakPhaseScatteringTheory(
    structural_ensemble, potential_integrator, transfer_theory
)

# NOTE(review): `strength` and `solvent` are defined here but are not passed to the
# scattering theory or the imaging pipeline in this cell — confirm whether the
# solvent model was meant to be used or can be removed
strength = 0.05 * real_voxel_grid.sum(axis=0).max() * voxel_size
solvent = cxs.GaussianIce(
    variance_function=op.Lorenzian(
        amplitude=strength**2, length_scale=2.0 * potential.voxel_size
    )
)

# The pipeline template: instrument configuration plus the scattering theory
imaging_pipeline = cxs.ContrastImagingPipeline(instrument_config, scattering_theory)

Let's check how an image rendered by the imaging pipeline looks.

In [None]:
# Render one image from the template pipeline (default pose and CTF) and plot it
fig, ax = plt.subplots(1, 1, figsize=(5, 5))

# `plot_image` returns the `(fig, ax)` pair, so `im1` holds that tuple
im1 = plot_image(
    imaging_pipeline.render(),
    fig,
    ax,
)

Now let's define a noise distribution

In [None]:
from cryojax.inference import distributions as dist


# Wrap the imaging pipeline in a noise model: the signal is scaled by
# sqrt(number of pixels) and the variance function is the constant 1.0
distribution = dist.IndependentGaussianFourierModes(
    imaging_pipeline,
    signal_scale_factor=jnp.sqrt(instrument_config.n_pixels),
    variance_function=op.Constant(1.0),
)

# Fixed seed so the sampled image is reproducible
key = jax.random.PRNGKey(seed=0)

# Draw one sample from the distribution and plot it
fig, ax1 = plt.subplots(1, 1, figsize=(5, 5))
im1 = plot_image(
    distribution.sample(key),
    fig,
    ax1,
    label="image with noise",
)

Lastly, let's define our vmapping filter, which selects the leaves of the distribution that are batched over the particles.

In [None]:
def batch_filter(distribution: dist.AbstractDistribution):
    """
    Select the parts of `distribution` that vary from particle to particle.

    Returns a tuple holding the pose offsets and view angles, the CTF defocus,
    astigmatism, astigmatism angle and phase shift, and the whole envelope.
    These are the only leaves that have a batch size in the starfile.
    """
    scattering_theory = distribution.imaging_pipeline.scattering_theory
    pose = scattering_theory.structural_ensemble.pose
    transfer_theory = scattering_theory.transfer_theory
    ctf = transfer_theory.ctf
    return (
        pose.offset_x_in_angstroms,
        pose.offset_y_in_angstroms,
        pose.view_phi,
        pose.view_theta,
        pose.view_psi,
        ctf.defocus_in_angstroms,
        ctf.astigmatism_in_angstroms,
        ctf.astigmatism_angle,
        ctf.phase_shift,
        transfer_theory.envelope,
    )

In [None]:
import os
import pathlib
from typing import Optional

from cryojax.data import RelionDataset
from cryojax.io import write_image_stack_to_mrc


@eqx.filter_jit
@eqx.filter_vmap(in_axes=(0, None, 0), out_axes=0)
def compute_noisy_image_stack(
    dist_vmap: dist.AbstractDistribution,
    dist_novmap: dist.AbstractDistribution,
    key: PRNGKeyArray,
):
    """Draw one noisy image from the distribution rebuilt from its two partitions.

    `dist_vmap` and `key` are mapped over their leading axis while `dist_novmap`
    is shared across the batch (see `in_axes`). The two partitions are merged
    with `eqx.combine` before sampling.
    """
    return eqx.combine(dist_vmap, dist_novmap).sample(key)


def generate_particles_from_starfile(
    path_to_starfile: str | pathlib.Path,
    path_to_relion_project: str | pathlib.Path,
    distribution: dist.AbstractDistribution,
    key: PRNGKeyArray,
    overwrite: bool = False,
    compression: Optional[str] = None,
) -> None:
    """Simulate noisy particle images for every MRC file named in a starfile.

    The starfile is read without its images. Its `rlnImageName` column is split
    on "@" and the second part is taken as the MRC file name. For each unique
    file name, the instrument config, pose and transfer theory of `distribution`
    are replaced with the values loaded for that file's particles, one image per
    particle is sampled, and the stack is written to
    `path_to_relion_project/<mrc file name>`.

    Args:
        path_to_starfile: Starfile describing the particles.
        path_to_relion_project: Directory the MRC files are written to; created
            if it does not exist.
        distribution: Template distribution to sample images from.
        key: PRNG key that all per-particle keys are derived from.
        overwrite: Forwarded to `write_image_stack_to_mrc`.
        compression: Forwarded to `write_image_stack_to_mrc`.
    """
    dataset = RelionDataset(
        path_to_starfile=path_to_starfile,
        path_to_relion_project=path_to_relion_project,
        get_image_stack=False,
    )

    particles_fnames = dataset.data_blocks["particles"]["rlnImageName"].str.split(
        "@", expand=True
    )
    mrc_fnames = particles_fnames[1].unique()

    # `exist_ok=True` avoids the check-then-create race of testing for the
    # directory first and creating it afterwards
    os.makedirs(path_to_relion_project, exist_ok=True)

    # Only the second half of the split seeds the per-file keys (this matches
    # the key stream used previously); the first half is not needed
    _, subkey = jax.random.split(key)
    filter_spec = cx.get_filter_spec(distribution, batch_filter)

    for mrc_fname in mrc_fnames:
        # Rows of the particles table that belong to this MRC file
        indices = particles_fnames[particles_fnames[1] == mrc_fname].index.to_numpy()
        relion_particle_stack = dataset[indices]

        # Swap the starfile parameters into the template distribution
        new_distribution = eqx.tree_at(
            lambda d: d.imaging_pipeline.instrument_config,
            distribution,
            relion_particle_stack.instrument_config,
        )

        new_distribution = eqx.tree_at(
            lambda d: d.imaging_pipeline.scattering_theory.structural_ensemble.pose,
            new_distribution,
            relion_particle_stack.pose,
        )

        new_distribution = eqx.tree_at(
            lambda d: d.imaging_pipeline.scattering_theory.transfer_theory,
            new_distribution,
            relion_particle_stack.transfer_theory,
        )

        # One key per particle, plus one extra that is carried to the next file
        # so no key is used twice
        keys = jax.random.split(subkey, len(indices) + 1)
        subkey = keys[-1]

        # Partition into batched / shared leaves to match the vmapped signature
        image_stack = compute_noisy_image_stack(
            *eqx.partition(new_distribution, filter_spec), keys[:-1]
        )

        filename = os.path.join(path_to_relion_project, mrc_fname)
        write_image_stack_to_mrc(
            image_stack,
            pixel_size=relion_particle_stack.instrument_config.pixel_size,
            filename=filename,
            overwrite=overwrite,
            compression=compression,
        )

In [None]:
# Simulate the particle stacks listed in the starfile and write them under
# "relion_project", replacing any files left over from a previous run
key = jax.random.PRNGKey(seed=0)
generate_particles_from_starfile(
    "relion_dataset.star", "relion_project", distribution, key, overwrite=True
)

In [None]:
# Reload the dataset, this time including the image stacks that were just written
dataset = RelionDataset(
    path_to_starfile="relion_dataset.star",
    path_to_relion_project="relion_project/",
    get_image_stack=True,
)

In [None]:
# Plot the first particle image read back from the generated dataset
fig, ax = plt.subplots(1, 1, figsize=(5, 5))

im1 = plot_image(
    dataset[0].image_stack,
    fig,
    ax,
    label="image with noise",
)