Automatic differentiation using JAX is a powerful tool for conducting high-dimensional data analysis. This tutorial demonstrates how automatic differentiation may be leveraged for structure refinement.

The numerical experiment performed here can be described in 3 steps:

1. Simulate a synthetic dataset of GroEL in a holoprotein state
2. Define a image simulation model using GroEL in an apoprotein state
3. Optimize the apo GroEL atom positions using gradient descent to recover the holo conformation

For simplicity, we will only generate images from a single chain of GroEL.

In [1]:
# Plotting imports and functions
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable


def plot_image(image, cmap="gray", **kwargs):
    fig, ax = plt.subplots(figsize=(3, 3))
    im = ax.imshow(image, cmap=cmap, origin="lower", **kwargs)
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    fig.colorbar(im, cax=cax)
    return fig, ax

In [4]:
# Generate synthetic dataset
import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import Array, PRNGKeyArray

import cryojax.simulator as cxs
from cryojax.io import read_atoms_from_pdb
from cryojax.rotations import SO3


# Load atomic model
atom_positions, atom_types = read_atoms_from_pdb("./data/groel_chainA_holo.pdb")
scattering_params = cxs.PengScatteringFactorParameters(atom_types)
volume = cxs.PengIndependentAtomPotential.from_tabulated_parameters(
    atom_positions, scattering_params
)

# Make dataset pose and CTF parameters
N_PARTICLES = 1000
RNG_KEY = 1234
SNR = 0.1
DEFOCUS_RANGE = (10000, 30000)
SPHERICAL_ABERRATION_IN_MM = 2.7
AMPLITUDE_CONTRAST_RATIO = 0.1
PIXEL_SIZE = 1.0
VOLTAGE_IN_KILOVOLTS = 300.0
IMAGE_SHAPE = (200, 200)


@eqx.filter_vmap(in_axes=(0, None))
def make_parameters(
    rng_key: PRNGKeyArray, image_config: cxs.BasicImageConfig
) -> tuple[cxs.EulerAnglePose, cxs.ContrastTransferTheory]:
    key, subkey = jax.random.split(rng_key)
    rotation = SO3.sample_uniform(subkey)
    ny, nx = image_config.shape
    key, subkey = jax.random.split(key)
    offset_in_angstroms = (
        jax.random.uniform(subkey, (2,), minval=-0.1, maxval=0.1)
        * jnp.asarray((nx, ny))
        / 2
        * image_config.pixel_size
    )
    pose = cxs.EulerAnglePose.from_rotation_and_translation(rotation, offset_in_angstroms)
    key, subkey = jax.random.split(key)
    defocus_in_angstroms = jax.random.uniform(
        subkey, (), minval=DEFOCUS_RANGE[0], maxval=DEFOCUS_RANGE[1]
    )
    transfer_theory = cxs.ContrastTransferTheory(
        ctf=cxs.AstigmaticCTF(
            defocus_in_angstroms=defocus_in_angstroms,
            spherical_aberration_in_mm=SPHERICAL_ABERRATION_IN_MM,
        ),
        amplitude_contrast_ratio=AMPLITUDE_CONTRAST_RATIO,
    )

    return pose, transfer_theory


# Define function to simulate images
@eqx.filter_vmap(in_axes=(0, None, None))
def simulate_fn(
    rng_key: PRNGKeyArray,
    volume: cxs.PengIndependentAtomPotential,
    image_config: cxs.BasicImageConfig,
) -> tuple[Array, dict]:
    particle_rng_key, noise_rng_key = jax.random.split(rng_key, num=2)
    pose, transfer_theory = make_parameters(particle_rng_key, image_config)
    signal_region = None
    image_model = cxs.make_image_model(
        volume,
        image_config,
        pose,
        transfer_theory,
        normalizes_signal=True,
        signal_region=signal_region,
    )
    noise_model = cxs.UncorrelatedGaussianNoiseModel(
        image_model, variance=1.0, signal_scale_factor=jnp.sqrt(SNR)
    )
    parameters = dict(
        pose=pose, transfer_theory=transfer_theory, image_config=image_config
    )

    return noise_model.sample(noise_rng_key), parameters


rng_key = jax.random.split(jax.random.key(RNG_KEY), num=N_PARTICLES)
image_config = cxs.BasicImageConfig(IMAGE_SHAPE, PIXEL_SIZE, VOLTAGE_IN_KILOVOLTS)
synthetic_image_stack, particle_parameters = simulate_fn(rng_key, volume, image_config)

ValueError: vmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0 (its shape is ())