In [2]:
# Jax and Equinox imports
from functools import partial

import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import PRNGKeyArray

In [3]:
# Plotting imports and functions
from mpl_toolkits.axes_grid1 import make_axes_locatable


def plot_image(image, fig, ax, cmap="gray", label=None, **kwargs):
    im = ax.imshow(image, cmap=cmap, origin="lower", **kwargs)
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    fig.colorbar(im, cax=cax)
    if label is not None:
        ax.set(title=label)
    return fig, ax

In [4]:
# CryoJAX imports

import cryojax.simulator as cxs
from cryojax.data import (
    RelionParticleParameterDataset,
    RelionParticleParameters,
    RelionParticleStackDataset,
    write_simulated_image_stack_from_starfile,
    write_starfile_with_particle_parameters,
)
from cryojax.io import read_atoms_from_pdb
from cryojax.rotations import SO3

  from .autonotebook import tqdm as notebook_tqdm


# Simulating Ensembles and doing Ensemble Reweighting

In this tutorial we will generate a heterogeneous dataset by defining a distribution on multiple atomic structures. We will then compute a likelihood matrix
$$ P_{nm} = p(y_n | x_m) $$

where $y_n$ is a data point and $x_m$ is a structure in the ensemble. We will define the likelihood through one of cryoJAX's distributions, although in principle any distribution works.

# Generate a starfile

First, we will just follow the tutorial `simulate-relion-dataset.ipynb` to generate a starfile. No ensemble stuff yet.

In [5]:
@partial(eqx.filter_vmap, in_axes=(0, None), out_axes=eqx.if_array(0))
def make_particle_parameters(
    key: PRNGKeyArray, instrument_config: cxs.InstrumentConfig
) -> RelionParticleParameters:
    # Generate random parameters

    # Pose
    # ... instantiate rotations

    key, subkey = jax.random.split(key)  # split the key to use for the next random number

    rotation = SO3.sample_uniform(subkey)
    key, subkey = jax.random.split(key)  # do this everytime you use a key!!

    # ... now in-plane translation
    ny, nx = instrument_config.shape
    offset_in_angstroms = (
        jax.random.uniform(subkey, (2,), minval=-0.2, maxval=0.2)
        * jnp.asarray((nx, ny))
        * instrument_config.pixel_size
    )
    # ... build the pose
    pose = cxs.EulerAnglePose.from_rotation_and_translation(rotation, offset_in_angstroms)

    # CTF Parameters
    # ... defocus
    defocus_in_angstroms = jax.random.uniform(subkey, (), minval=1000, maxval=1500)
    key, subkey = jax.random.split(key)

    astigmatism_in_angstroms = jax.random.uniform(subkey, (), minval=0, maxval=100)
    key, subkey = jax.random.split(key)

    astigmatism_angle = jax.random.uniform(subkey, (), minval=0, maxval=jnp.pi)
    key, subkey = jax.random.split(key)

    phase_shift = jax.random.uniform(subkey, (), minval=0, maxval=0)
    # no more random numbers needed

    # now generate your non-random values
    spherical_aberration_in_mm = 2.7
    amplitude_contrast_ratio = 0.1

    # ... build the CTF
    transfer_theory = cxs.ContrastTransferTheory(
        ctf=cxs.AberratedAstigmaticCTF(
            defocus_in_angstroms=defocus_in_angstroms,
            astigmatism_in_angstroms=astigmatism_in_angstroms,
            astigmatism_angle=astigmatism_angle,
            spherical_aberration_in_mm=spherical_aberration_in_mm,
        ),
        amplitude_contrast_ratio=amplitude_contrast_ratio,
        phase_shift=phase_shift,
    )
    relion_particle_parameters = RelionParticleParameters(
        instrument_config=instrument_config,
        pose=pose,
        transfer_theory=transfer_theory,
    )
    return relion_particle_parameters

In [6]:
# Generate instrument config
instrument_config = cxs.InstrumentConfig(
    shape=(128, 128),
    pixel_size=1.5,
    voltage_in_kilovolts=300.0,
    pad_scale=1.0,  # no padding
)

# Generate RNG keys
number_of_images = 100
keys = jax.random.split(jax.random.key(0), number_of_images)

# ... instantiate the RelionParticleDataset
particle_parameters = make_particle_parameters(keys, instrument_config)

In [7]:
# ... generate the starfile
write_starfile_with_particle_parameters(
    particle_parameters,
    "./outputs/heterogeneous_relion_dataset.star",
    mrc_batch_size=50,
    overwrite=True,
)

# Simulating images by choosing a random structure

In [8]:
# First load the starfile

path_to_mrc_files = "./outputs/relion_dataset_particles/heterogeneous"

parameter_dataset = RelionParticleParameterDataset(
    path_to_starfile="./outputs/heterogeneous_relion_dataset.star",  # starfile we created
    path_to_relion_project=path_to_mrc_files,  # here is where the mrcs will be saved
)

In [9]:
from cryojax.constants import get_tabulated_scattering_factor_parameters


filenames = ["./data/groel_chainA.pdb", "./data/groel_chainA_holo.pdb"]

box_size = parameter_dataset[0].instrument_config.shape[0]

potentials = []
voxel_size = parameter_dataset[0].instrument_config.pixel_size
for filename in filenames:
    # Load the atomic structure and transform into a potential
    atom_positions, atom_identities, bfactors = read_atoms_from_pdb(
        filename, center=True, select="not element H", loads_b_factors=True
    )
    scattering_factor_parameters = get_tabulated_scattering_factor_parameters(
        atom_identities
    )
    atomic_potential = cxs.PengAtomicPotential(
        atom_positions,
        scattering_factor_a=scattering_factor_parameters["a"],
        scattering_factor_b=scattering_factor_parameters["b"],
        b_factors=bfactors,
    )
    # Convert to a real voxel grid
    # This step is optional, you could use the atomic potential directly!
    real_voxel_grid = atomic_potential.as_real_voxel_grid(
        shape=(box_size, box_size, box_size), voxel_size=voxel_size
    )
    potential = cxs.FourierVoxelGridPotential.from_real_voxel_grid(
        real_voxel_grid, voxel_size, pad_scale=2
    )
    potentials.append(potential)

potentials = tuple(potentials)
potential_integrator = cxs.FourierSliceExtraction()

# Use this if using an atomic potential
# potential_integrator = cxs.GaussianMixtureProjection()

In [10]:
from typing import Any, Tuple

from jaxtyping import Array, Float

from cryojax.inference import distributions as dist


def build_distribution_from_relion_particle_parameters(
    potential_id: Float[Array, ""],
    relion_particle_parameters: RelionParticleParameters,
    args: Tuple[
        cxs.AbstractPotentialRepresentation,
        cxs.AbstractPotentialIntegrator,
        Float,
    ],
) -> dist.IndependentGaussianPixels:
    potentials, potential_integrator, variance = args

    structural_ensemble = cxs.DiscreteStructuralEnsemble(
        potentials,
        relion_particle_parameters.pose,
        cxs.DiscreteConformationalVariable(potential_id),
    )

    scattering_theory = cxs.WeakPhaseScatteringTheory(
        structural_ensemble,
        potential_integrator,
        relion_particle_parameters.transfer_theory,
    )
    image_model = cxs.ContrastImageModel(
        relion_particle_parameters.instrument_config, scattering_theory
    )
    distribution = dist.IndependentGaussianPixels(
        image_model,
        variance=variance,
    )
    return distribution

## Figuring out the variance of the noise

Before we generate images, we need to define the variance of the noise based on a given SNR. To do this you can simulate a set of noiseless images and estimate the mean of the variance inside a mask where you know you have signal. We can do this using cryoJAX's circular mask.

In [11]:
from cryojax.image.operators import CircularCosineMask


@eqx.filter_jit
@partial(eqx.filter_vmap, in_axes=(0, eqx.if_array(0), None))
def simulate_noiseless_images(potential_id, particle_parameters, args):
    distribution = build_distribution_from_relion_particle_parameters(
        potential_id, particle_parameters, args
    )
    return distribution.compute_signal()


@eqx.filter_jit
def estimate_signal_variance(
    key, n_images_for_estimation, mask_radius, instrument_config, args, *, batch_size=None
):
    potentials, potential_integrator, ensemble_weights, variance = args

    key, *subkeys = jax.random.split(key, n_images_for_estimation + 1)
    subkeys = jnp.array(subkeys)

    particle_parameters = make_particle_parameters(subkeys, instrument_config)

    # set offset at 0 for simplicity
    particle_parameters = eqx.tree_at(
        lambda d: (d.pose.offset_x_in_angstroms, d.pose.offset_y_in_angstroms),
        particle_parameters,
        replace_fn=lambda x: 0.0 * x,
    )

    key, subkey = jax.random.split(key)
    potential_ids = jax.random.choice(
        subkey, ensemble_weights.shape[0], (n_images_for_estimation,), p=ensemble_weights
    )
    noiseless_images = simulate_noiseless_images(
        potential_ids, particle_parameters, (potentials, potential_integrator, variance)
    )

    # define noise mask
    mask = CircularCosineMask(
        particle_parameters.instrument_config.coordinate_grid_in_pixels,
        radius_in_angstroms_or_pixels=mask_radius,
        rolloff_width_in_angstroms_or_pixels=1.0,
    )

    signal_variance = jnp.var(
        noiseless_images, axis=(1, 2), where=jnp.where(mask.array == 1.0, True, False)
    ).mean()

    return signal_variance

In [12]:
var_est_seed = 0
key_var_est = jax.random.key(var_est_seed)
uniform_weights = jnp.array([0.5, 0.5])  # weights for sampling structuures

signal_variance = estimate_signal_variance(
    key_var_est,
    n_images_for_estimation=10,
    mask_radius=box_size // 3,
    instrument_config=instrument_config,
    args=(
        potentials,
        potential_integrator,
        uniform_weights,
        1.0,
    ),  # the last argument is the variance, not needed for this
)

## Simulating the images

In [13]:
def compute_image_with_noise(
    particle_parameters: RelionParticleParameters,
    constant_args,
    per_particle_args,
):
    key_noise, potential_id = per_particle_args
    distribution = build_distribution_from_relion_particle_parameters(
        potential_id, particle_parameters, constant_args
    )
    return distribution.sample(key_noise)

In [None]:
snr = 0.1  # define whatever snr you want
noise_variance = signal_variance / snr

constant_args = (potentials, potential_integrator, noise_variance)

# Generate RNG keys for per-image noise, and per-image conformations
keys_noise = jax.random.split(jax.random.key(0), number_of_images)
key_structure = jax.random.key(1)

# Generate the per-image conformation assignments
ensemble_weights = jnp.array([0.3, 0.7])  # weights for sampling structures
potential_ids = jax.random.choice(
    key_structure, ensemble_weights.shape[0], (number_of_images,), p=ensemble_weights
)

write_simulated_image_stack_from_starfile(
    param_dataset=parameter_dataset,
    compute_image_fn=compute_image_with_noise,
    constant_args=constant_args,
    per_particle_args=(keys_noise, potential_ids),
    is_jittable=True,
    batch_size_per_mrc=10,
    overwrite=True,
    compression=None,
)

# Computing a likelihood Matrix

Now we have a heterogeneous dataset. Let's say we have a new ensemble (I'll use the true one for simplicity), we want to generate the likelihood between each member of the ensemble and each image. This will give us a likelihood matrix, which can be used for ensemble reweighting among other things.

In [32]:
# First load the data
particle_reader = RelionParticleStackDataset(parameter_dataset)

In [33]:
type(particle_reader[0])

cryojax.data._particle_data.ParticleStack

## Setting up a dataloader

Normally, you'll have thousands of images, so loading them all into memory at once is not a good idea. CryoJAX is very flexible, and allows us to use external dataloaders. Here I will use the dataloader implemented in: https://github.com/BirkhoffG/jax-dataloader

*You will need to install the `jax_dataloader` library above to continue with the rest of the tutorial.*

In [34]:
import jax_dataloader as jdl

from cryojax.data import ParticleStack


class CustomJaxDataset(jdl.Dataset):
    def __init__(self, cryojax_dataset: RelionParticleStackDataset):
        self.cryojax_dataset = cryojax_dataset

    def __getitem__(self, index) -> ParticleStack:
        return self.cryojax_dataset[index]

    def __len__(self) -> int:
        return len(self.cryojax_dataset)

In [35]:
dataloader = jdl.DataLoader(
    CustomJaxDataset(
        particle_reader
    ),  # Can be a jdl.Dataset or pytorch or huggingface or tensorflow dataset
    backend="jax",  # Use 'jax' backend for loading data
    batch_size=5,  # Batch size
    shuffle=False,  # Shuffle the dataloader every iteration or not
    drop_last=False,  # Drop the last batch or not
)

# Computing the likelihood

Here we show several ways to compute the likelihood. I will show how to compute it using vmapping, but also jax.lax.map, which is usually more memory friendly. I will also show how to compute the likelihood from a stack of atom_positions, which will be useful for computing gradients for atomic structures.

In all cases we will vmap first over images and then over structures/potentials. This is because computing quantities this way is faster. Think about it this way, it is much more easier to grab one potential and compute all the images required, than to compute a potential for every image.

In [36]:
@eqx.filter_jit
@partial(eqx.filter_vmap, in_axes=(None, eqx.if_array(0), None))
def compute_likelihood(
    potential_id,
    particle_dataset: ParticleStack,
    args: Any,
) -> Float:
    potentials, potential_integrator, variance = args
    structural_ensemble = cxs.DiscreteStructuralEnsemble(
        potentials,
        particle_dataset.parameters.pose,
        cxs.DiscreteConformationalVariable(potential_id),
    )

    scattering_theory = cxs.WeakPhaseScatteringTheory(
        structural_ensemble,
        potential_integrator,
        particle_dataset.parameters.transfer_theory,
    )
    imaging_pipeline = cxs.ContrastImageModel(
        particle_dataset.parameters.instrument_config, scattering_theory
    )
    distribution = dist.IndependentGaussianPixels(
        imaging_pipeline,
        variance=variance,
    )
    return distribution.log_likelihood(particle_dataset.images)

## Computing with equinox.filter_vmap

This is the simplest way to compute the likelihood matrix. Simply set a batch_size in the dataloader such that you don't get memory errors.

In [37]:
@eqx.filter_jit
@partial(eqx.filter_vmap, in_axes=(0, None, None))
def compute_likelihood_batch(potential_id, relion_particle_stack, args):
    return compute_likelihood(potential_id, relion_particle_stack, args)


def compute_likelihood_matrix(dataloader, args):
    n_potentials = len(args[0])
    likelihood_matrix = []
    for batch in dataloader:
        batch_likelihood = compute_likelihood_batch(
            jnp.arange(n_potentials), batch, args
        ).T
        likelihood_matrix.append(batch_likelihood)
    likelihood_matrix = jnp.concatenate(likelihood_matrix, axis=0)
    return likelihood_matrix

In [38]:
likelihood_matrix = compute_likelihood_matrix(
    dataloader, args=(potentials, potential_integrator, noise_variance)
)

In [39]:
# Let's compute the populations
# They should be around 0.7 and 0.3 (might not be true at low snr)

(
    jnp.sum(jnp.argmin(likelihood_matrix, axis=1) == 0),
    jnp.sum(jnp.argmin(likelihood_matrix, axis=1) == 1),
)

(Array(65, dtype=int32), Array(35, dtype=int32))

## Computing with jax.lax.map

Here we need to use equinox partition, as jax.lax.map does not have utilities such as eqx.if_array (see how we vmapped in the previous example). The filtering is very simple, we just need to get rid of all leaves that are not arrays.

In [40]:
@eqx.filter_jit
def compute_single_likelihood(
    potential_id,
    particle_stack_map: ParticleStack,
    particle_stack_nomap: ParticleStack,
    args: Any,
) -> Float:
    particle_stack = eqx.combine(particle_stack_map, particle_stack_nomap)
    potentials, potential_integrator, variance = args
    structural_ensemble = cxs.DiscreteStructuralEnsemble(
        potentials,
        particle_stack.parameters.pose,
        cxs.DiscreteConformationalVariable(potential_id),
    )

    scattering_theory = cxs.WeakPhaseScatteringTheory(
        structural_ensemble,
        potential_integrator,
        particle_stack.parameters.transfer_theory,
    )
    imaging_pipeline = cxs.ContrastImageModel(
        particle_stack.parameters.instrument_config, scattering_theory
    )
    distribution = dist.IndependentGaussianPixels(
        imaging_pipeline,
        variance=variance,
    )
    return distribution.log_likelihood(particle_stack.images)


@eqx.filter_jit
def compute_likelihood_with_map(potential_id, particle_stack, args, *, batch_size_images):
    """
    Computes one row of the likelihood matrix (all structures, one image)
    """

    stack_map, stack_nomap = eqx.partition(particle_stack, eqx.is_array)

    likelihood_batch = jax.lax.map(
        lambda x: compute_single_likelihood(potential_id, x, stack_nomap, args),
        xs=stack_map,
        batch_size=batch_size_images,  # compute for this many images in parallel
    )
    return likelihood_batch


def compute_likelihood_matrix_with_lax_map(
    dataloader, args, *, batch_size_potentials=None, batch_size_images=None
):
    n_potentials = len(args[0])
    likelihood_matrix = []
    for batch in dataloader:
        batch_likelihood = jax.lax.map(
            lambda x: compute_likelihood_with_map(
                x, batch, args, batch_size_images=batch_size_images
            ),
            xs=jnp.arange(n_potentials),
            batch_size=batch_size_potentials,  # potentials to compute in parallel
        ).T
        likelihood_matrix.append(batch_likelihood)
    likelihood_matrix = jnp.concatenate(likelihood_matrix, axis=0)
    return likelihood_matrix

This will take longer, but uses less memory. Play around with the batch sizes. Batch size potentials controls how many potentials are used in a single vmap operation. Batch size images controls how many images are used in a single vmap operation. Atomic potentials are cheap when it comes to memory, but they are slower when comparing against many images. Voxel potentials are more memory expensive, but it's vary fast to compare them against many images.

In [None]:
likelihood_matrix_lax_map = compute_likelihood_matrix_with_lax_map(
    dataloader,
    args=(potentials, potential_integrator, noise_variance),
    batch_size_potentials=2,
    batch_size_images=5,
)

In [None]:
jnp.allclose(likelihood_matrix_lax_map, likelihood_matrix)

Array(True, dtype=bool)

## Computing likelihood matrix from multiple atomic positions

For this approach we need a little trick to be able to jit the generation of the peng atomic potential. Here we will not convert the atomic potential to a voxel potential, as the objective of this tutorial is to be able to create a loss function that allows for the computation of gradients for the atomic positions

In [None]:
@eqx.filter_jit
@partial(eqx.filter_vmap, in_axes=(None, eqx.if_array(0), None))
def compute_likelihood_atomic(
    potential, particle_stack: ParticleStack, variance
) -> Float:
    structural_ensemble = cxs.SingleStructureEnsemble(
        potential,
        particle_stack.parameters.pose,
    )

    scattering_theory = cxs.WeakPhaseScatteringTheory(
        structural_ensemble,
        cxs.GaussianMixtureProjection(use_error_functions=True),
        particle_stack.parameters.transfer_theory,
    )
    imaging_pipeline = cxs.ContrastImageModel(
        particle_stack.parameters.instrument_config, scattering_theory
    )
    distribution = dist.IndependentGaussianPixels(
        imaging_pipeline,
        variance=variance,
    )
    return distribution.log_likelihood(particle_stack.images)


@eqx.filter_jit
@partial(eqx.filter_vmap, in_axes=(0, None, None))
def compute_likelihood_build_potential(atom_positions, particle_stack, args):
    b_factors, parameter_table, variance = args
    atom_potential = cxs.PengAtomicPotential(
        atom_positions,
        scattering_factor_a=parameter_table["a"],
        scattering_factor_b=parameter_table["b"],
        b_factors=b_factors,
    )
    return compute_likelihood_atomic(atom_potential, particle_stack, variance)


def compute_likelihood_matrix_from_atoms(atom_positions, dataloader, args):
    likelihood_matrix = []
    for batch in dataloader:
        batch_likelihood = compute_likelihood_build_potential(
            atom_positions, batch, args
        ).T  # we want something with shape (n_images, n_atom_positions)

        likelihood_matrix.append(batch_likelihood)
    likelihood_matrix = jnp.concatenate(likelihood_matrix, axis=0)
    return likelihood_matrix

### WARNING
Here I am assuming that all atomic structures have the same set of atoms. Generalizing is not difficult, you just need to be careful about how you handle the atom_identities and the b_factors.

In [None]:
import numpy as np


filenames = ["./data/groel_chainA.pdb", "./data/groel_chainA_holo.pdb"]

box_size = parameter_dataset[0].instrument_config.shape[0]
voxel_size = parameter_dataset[0].instrument_config.pixel_size

single_atom_positions, atom_identities, b_factors = read_atoms_from_pdb(
    filenames[0], center=True, select="not element H", loads_b_factors=True
)

atom_positions = np.zeros((len(filenames), *single_atom_positions.shape))
atom_positions[0] = single_atom_positions

for i, filename in enumerate(filenames[1:]):
    # Load the atomic structure and transform into a potential
    atom_positions[i + 1] = read_atoms_from_pdb(
        filename, center=True, select="not element H", loads_b_factors=False
    )[0]  # we are only interested in the positions, the resto does not change

atom_positions = jnp.array(atom_positions)

As with earlier, we need to load the parameter table for the Peng Atomic Potential, otherwise generating it is not jittable (this pre-loads the atomic scattering factors)

In [None]:
parameter_table = get_tabulated_scattering_factor_parameters(atom_identities)

In [None]:
args = (b_factors, parameter_table, noise_variance)

likelihood_matrix_atoms = compute_likelihood_matrix_from_atoms(
    atom_positions, dataloader, args
)

In [None]:
(
    jnp.sum(jnp.argmin(likelihood_matrix_atoms, axis=1) == 0),
    jnp.sum(jnp.argmin(likelihood_matrix_atoms, axis=1) == 1),
)

(Array(65, dtype=int32), Array(35, dtype=int32))