In [1]:
# Jax and Equinox imports
from functools import partial
from typing import Optional, Tuple

import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import Array, Float, PRNGKeyArray

In [2]:
# Plotting imports and functions
from mpl_toolkits.axes_grid1 import make_axes_locatable


def plot_image(image, fig, ax, cmap="gray", label=None, **kwargs):
    """Draw ``image`` on ``ax`` with a colorbar attached on its right-hand side.

    Extra keyword arguments are forwarded to ``ax.imshow``. When ``label`` is
    given it becomes the axes title. Returns the ``(fig, ax)`` pair unchanged.
    """
    mappable = ax.imshow(image, cmap=cmap, origin="lower", **kwargs)
    colorbar_axes = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
    fig.colorbar(mappable, cax=colorbar_axes)
    if label is None:
        return fig, ax
    ax.set(title=label)
    return fig, ax

In [None]:
# CryoJAX imports

import cryojax.image.transform as tf
import cryojax.simulator as cxs
from cryojax.data import (
    RelionParticleParameterFile,
    RelionParticleStackDataset,
    simulate_particle_stack,
)
from cryojax.io import read_atoms_from_pdb
from cryojax.rotations import SO3

  from .autonotebook import tqdm as notebook_tqdm


# Simulating Ensembles and doing Ensemble Reweighting

In this tutorial we will generate a heterogeneous dataset by defining a distribution on multiple atomic structures. We will then compute a likelihood matrix
$$ P_{nm} = p(y_n | x_m) $$

where $y_n$ is a data point and $x_m$ is a structure in the ensemble. We will define the likelihood through one of cryoJAX's distributions, although in principle any distribution works.

# Generate a starfile

First, we simply follow the tutorial `simulate-relion-dataset.ipynb` to generate a starfile. Nothing ensemble-specific happens yet.

In [4]:
@partial(eqx.filter_vmap, in_axes=(0, None))
def make_particle_parameters(
    key: PRNGKeyArray, instrument_config: cxs.InstrumentConfig
):  # -> dict with keys "instrument_config", "pose", "transfer_theory", "metadata"
    """Randomly draw the pose and CTF parameters for one particle.

    Vectorized with `eqx.filter_vmap` over `key` (axis 0) while sharing a single
    `instrument_config`, so passing a stack of keys yields a batch of parameters.

    **Arguments:**

    - `key`: PRNG key for this particle (a stack of keys when called vmapped).
    - `instrument_config`: shared instrument configuration; its `shape` and
      `pixel_size` set the range of the random in-plane translation.

    **Returns:**

    A dict with entries `"instrument_config"`, `"pose"` (an `EulerAnglePose`),
    `"transfer_theory"` (a `ContrastTransferTheory`) and `"metadata"` (an empty
    dict).
    """
    # Generate random parameters

    # Pose
    # ... instantiate rotations
    key, subkey = jax.random.split(key)  # split the key to use for the next random number
    rotation = SO3.sample_uniform(subkey)

    # ... now in-plane translation
    ny, nx = instrument_config.shape

    key, subkey = jax.random.split(key)  # do this every time you use a key!!
    # Offset drawn uniformly within +/-20% of the image extent along each axis,
    # converted from pixels to angstroms via the pixel size
    offset_in_angstroms = (
        jax.random.uniform(subkey, (2,), minval=-0.2, maxval=0.2)
        * jnp.asarray((nx, ny))
        * instrument_config.pixel_size
    )
    # ... build the pose
    pose = cxs.EulerAnglePose.from_rotation_and_translation(rotation, offset_in_angstroms)

    # CTF Parameters
    # ... defocus
    key, subkey = jax.random.split(key)
    defocus_in_angstroms = jax.random.uniform(subkey, (), minval=10000, maxval=15000)

    key, subkey = jax.random.split(key)
    astigmatism_in_angstroms = jax.random.uniform(subkey, (), minval=0, maxval=100)

    key, subkey = jax.random.split(key)
    astigmatism_angle = jax.random.uniform(subkey, (), minval=0, maxval=jnp.pi)

    key, subkey = jax.random.split(key)
    # minval == maxval == 0, so this is always 0.
    # NOTE(review): presumably a placeholder for a randomized phase shift.
    phase_shift = jax.random.uniform(subkey, (), minval=0, maxval=0)
    # no more random numbers needed

    # now generate your non-random values
    spherical_aberration_in_mm = 2.7
    amplitude_contrast_ratio = 0.1

    # ... build the CTF
    transfer_theory = cxs.ContrastTransferTheory(
        ctf=cxs.CTF(
            defocus_in_angstroms=defocus_in_angstroms,
            astigmatism_in_angstroms=astigmatism_in_angstroms,
            astigmatism_angle=astigmatism_angle,
            spherical_aberration_in_mm=spherical_aberration_in_mm,
        ),
        amplitude_contrast_ratio=amplitude_contrast_ratio,
        phase_shift=phase_shift,
    )

    particle_parameters = {
        "instrument_config": instrument_config,
        "pose": pose,
        "transfer_theory": transfer_theory,
        "metadata": {},
    }

    return particle_parameters

In [5]:
# Shared imaging setup: a 128 x 128 detector with 1.5 angstrom pixels at 300 kV
instrument_config = cxs.InstrumentConfig(
    shape=(128, 128),
    pixel_size=1.5,
    voltage_in_kilovolts=300.0,
    pad_scale=1.0,  # 1.0 means no padding
)

# One PRNG key per simulated particle, all derived from seed 0
number_of_images = 100
base_key = jax.random.key(0)
keys = jax.random.split(base_key, number_of_images)

# Draw the per-particle pose / CTF parameters (batched by the vmap decorator)
particle_parameters = make_particle_parameters(keys, instrument_config)

In [6]:
# Write the sampled parameters to a new RELION starfile
starfile_path = "./outputs/heterogeneous_relion_dataset.star"
new_parameters_file = RelionParticleParameterFile(
    path_to_starfile=starfile_path,
    mode="w",  # open for writing
    exists_ok=True,  # tolerate a starfile left over from a previous run
)
new_parameters_file.append(particle_parameters)
new_parameters_file.save(overwrite=True)

# Simulating images by choosing a random structure

In [7]:
# Wrap the starfile in a particle-stack dataset that will write the .mrcs files
path_to_mrc_files = "./outputs/relion_dataset_particles/heterogeneous"

mrcfile_settings = {"overwrite": True}  # customize your .mrcs !
particle_dataset = RelionParticleStackDataset(
    new_parameters_file,
    path_to_relion_project=path_to_mrc_files,
    mode="w",
    mrcfile_settings=mrcfile_settings,
)

In [8]:
from cryojax.constants import get_tabulated_scattering_factor_parameters


filenames = ["./data/groel_chainA.pdb", "./data/groel_chainA_holo.pdb"]

# All particles share one instrument config; read it once from the starfile
reference_config = new_parameters_file[0]["instrument_config"]
box_size = reference_config.shape[0]
voxel_size = reference_config.pixel_size

potentials = []
for filename in filenames:
    # Read the non-hydrogen atoms together with their B-factors
    atom_positions, atom_identities, bfactors = read_atoms_from_pdb(
        filename, center=True, select="not element H", loads_b_factors=True
    )
    scattering_factor_parameters = get_tabulated_scattering_factor_parameters(
        atom_identities
    )
    atomic_potential = cxs.PengAtomicPotential(
        atom_positions,
        scattering_factor_a=scattering_factor_parameters["a"],
        scattering_factor_b=scattering_factor_parameters["b"],
        b_factors=bfactors,
    )
    # Rasterize onto a cubic real-space grid, then store it in Fourier space.
    # This step is optional: the atomic potential could be used directly.
    real_voxel_grid = atomic_potential.as_real_voxel_grid(
        shape=(box_size,) * 3, voxel_size=voxel_size
    )
    potential = cxs.FourierVoxelGridPotential.from_real_voxel_grid(
        real_voxel_grid, voxel_size, pad_scale=2
    )
    potentials.append(potential)

potentials = tuple(potentials)
potential_integrator = cxs.FourierSliceExtraction()

# Use this if using an atomic potential
# potential_integrator = cxs.GaussianMixtureProjection()

!!! info 
See our tutorial on simulating simple datasets for more details on how to generate a dataset with noisy images.

In [9]:
from cryojax.inference.distributions import IndependentGaussianPixels


def compute_image(parameters, constant_args, per_particle_args):
    """Simulate one noisy image of the conformation assigned to this particle.

    `constant_args` is ``(potentials, potential_integrator, mask, snr)`` and is
    shared by every particle; `per_particle_args` is ``(noise_key, potential_id)``.
    The image model is wrapped in an `IndependentGaussianPixels` distribution with
    unit noise variance and a signal scale factor of ``sqrt(snr)``, and a single
    sample is drawn from it.
    """
    potentials, potential_integrator, mask, snr = constant_args
    noise_key, potential_id = per_particle_args

    pose = parameters["pose"]
    transfer_theory = parameters["transfer_theory"]
    instrument_config = parameters["instrument_config"]

    # Select the member of the discrete ensemble given by `potential_id`
    conformation = cxs.DiscreteConformationalVariable(potential_id)
    structural_ensemble = cxs.DiscreteStructuralEnsemble(potentials, pose, conformation)

    image_model = cxs.ContrastImageModel(
        instrument_config,
        cxs.WeakPhaseScatteringTheory(
            structural_ensemble, potential_integrator, transfer_theory
        ),
        mask=mask,
    )

    noise_model = IndependentGaussianPixels(
        image_model,
        variance=1.0,
        signal_scale_factor=jnp.sqrt(snr),
        normalizes_signal=True,
    )
    return noise_model.sample(noise_key, applies_mask=False)

## Simulating the images

In [None]:
snr = 0.1  # define whatever snr you want
mask = tf.CircularCosineMask(
    coordinate_grid=instrument_config.coordinate_grid_in_pixels,
    radius=instrument_config.shape[0] // 2,
    rolloff_width=0.0,
)

constant_args = (potentials, potential_integrator, mask, snr)

# Generate RNG keys for per-image noise
keys_noise = jax.random.split(jax.random.key(0), number_of_images)
# NOTE: not used by the deterministic conformation assignment below
key_structure = jax.random.key(1)

# Deterministic per-image conformation assignment: the first
# round(ensemble_weights[0] * number_of_images) images use potential 0 and all
# remaining images use potential 1.
ensemble_weights = jnp.array([0.3, 0.7])  # target population of each structure
# Use `round` rather than `int`: a floating-point product can land just below the
# intended integer (e.g. 0.29 * 100 == 28.999999999999996 in float64), and `int`
# would then truncate it to one image too few.
n_images_potential_0 = round(float(ensemble_weights[0]) * number_of_images)
potential_ids = jnp.ones((number_of_images,), dtype=int)

# Exactly 30 will come from potential with id 0 (for weights 0.3 / 0.7)
potential_ids = potential_ids.at[:n_images_potential_0].set(0)

simulate_particle_stack(
    particle_dataset,
    compute_image_fn=compute_image,
    constant_args=constant_args,
    per_particle_args=(keys_noise, potential_ids),
    batch_size=10,
    images_per_file=50,
    overwrite=True,
)

# Computing a likelihood Matrix

Now we have a heterogeneous dataset. Suppose we have a candidate ensemble (for simplicity, I'll reuse the true one). We want to compute the likelihood between each member of the ensemble and each image. This gives a likelihood matrix, which can be used for ensemble reweighting, among other things.

## Setting up a dataloader

Normally, you'll have thousands of images, so loading them all into memory at once is not a good idea. CryoJAX is very flexible, and allows us to use external dataloaders. Here I will use the dataloader implemented in: https://github.com/BirkhoffG/jax-dataloader

In [None]:
# !pip install jax_dataloader (run this!)

In [11]:
import jax_dataloader as jdl


class CustomJaxDataset(jdl.Dataset):
    """Thin `jax_dataloader` adapter around a cryoJAX particle-stack dataset.

    Indexing and ``len`` are forwarded unchanged to the wrapped dataset.
    """

    def __init__(self, cryojax_dataset: RelionParticleStackDataset):
        self.cryojax_dataset = cryojax_dataset

    def __len__(self) -> int:
        return len(self.cryojax_dataset)

    def __getitem__(self, index):
        wrapped = self.cryojax_dataset
        return wrapped[index]

In [12]:
# Stream the particle stack in fixed-order batches of 20 images.
# (`jdl.DataLoader` can also wrap PyTorch, Hugging Face or TensorFlow datasets.)
jax_dataset = CustomJaxDataset(particle_dataset)
dataloader = jdl.DataLoader(
    jax_dataset,
    backend="jax",  # use the 'jax' backend for loading data
    batch_size=20,
    shuffle=False,  # keep image order fixed between passes over the data
    drop_last=False,  # keep a final partial batch, if any
)

# Computing the likelihood

Here we show several ways to compute the likelihood. I will show how to compute it using vmapping, but also jax.lax.map, which is usually more memory friendly. I will also show how to compute the likelihood from a stack of atom_positions, which will be useful for computing gradients for atomic structures.

In all cases, we vmap first over images and then over structures/potentials, because computing quantities in this order is faster. Think about it this way: it is much easier to take one potential and compute all the images it is needed for than to recompute a potential for every image.

In [13]:
@eqx.filter_jit
@partial(eqx.filter_vmap, in_axes=(None, eqx.if_array(0), None))
def compute_likelihood(
    potential_id: int,
    particle_stack,
    args: Tuple[
        Tuple[cxs.AbstractPotentialRepresentation], cxs.AbstractPotentialIntegrator
    ],
) -> Float:
    """Gaussian log-likelihood of each image in a stack under one ensemble member.

    Vectorized over the array leaves of `particle_stack` (one entry per image);
    `potential_id` and `args = (potentials, potential_integrator)` are shared.
    The rendered image is fit to the observed one by least squares with a scale
    and an offset, and the log-likelihood of the residual (unit noise variance,
    additive constant dropped) is returned.
    """
    potentials, potential_integrator = args
    parameters = particle_stack["parameters"]

    ensemble = cxs.DiscreteStructuralEnsemble(
        potentials,
        parameters["pose"],
        cxs.DiscreteConformationalVariable(potential_id),
    )
    theory = cxs.WeakPhaseScatteringTheory(
        ensemble, potential_integrator, parameters["transfer_theory"]
    )
    simulated_image = cxs.ContrastImageModel(
        parameters["instrument_config"], theory
    ).render()
    observed_image = particle_stack["images"]

    # Least-squares scale and offset mapping the simulation onto the data
    # (the scale plays the role of an SNR estimate)
    mean_sim_sq = jnp.mean(simulated_image**2)
    mean_cross = jnp.mean(observed_image * simulated_image)
    mean_sim = jnp.mean(simulated_image)
    mean_obs = jnp.mean(observed_image)

    scale = (mean_cross - mean_sim * mean_obs) / (mean_sim_sq - mean_sim**2)
    bias = mean_obs - scale * mean_sim

    residual = observed_image - scale * simulated_image - bias
    # The simulated data were generated with noise variance 1
    return -jnp.sum(residual**2) / 2.0

## Computing with equinox.filter_vmap

This is the simplest way to compute the likelihood matrix. Simply set a batch_size in the dataloader such that you don't get memory errors.

In [14]:
@eqx.filter_jit
@partial(eqx.filter_vmap, in_axes=(0, None, None))
def compute_likelihood_batch(
    potential_id: int,
    relion_particle_stack,
    args: Tuple[
        Tuple[cxs.AbstractPotentialRepresentation], cxs.AbstractPotentialIntegrator
    ],
):
    """Likelihoods of one image batch against several ensemble members.

    Vectorized over `potential_id` (axis 0); the particle stack and `args` are
    shared. Each call evaluates `compute_likelihood`, which is itself vectorized
    over images, so the result has shape ``(n_potentials, n_images)``.
    """
    likelihoods = compute_likelihood(potential_id, relion_particle_stack, args)
    return likelihoods


def compute_likelihood_matrix(
    dataloader: jdl.DataLoader,
    args: Tuple[
        Tuple[cxs.AbstractPotentialRepresentation], cxs.AbstractPotentialIntegrator
    ],
) -> Float[Array, " n_images n_potentials"]:
    """Assemble the full image-by-potential log-likelihood matrix.

    **Arguments:**

    - `dataloader`: yields particle-stack batches (dicts with `"parameters"` and
      `"images"`); the iteration order defines the rows of the result.
    - `args`: ``(potentials, potential_integrator)``, where `potentials` is a
      tuple with one entry per ensemble member.

    **Returns:**

    Array of shape ``(n_images, n_potentials)`` whose entry ``[n, m]`` is the
    log-likelihood of image ``n`` under potential ``m``.

    **Raises:**

    `ValueError` if the dataloader yields no batches.
    """
    potentials, _ = args
    # Same ids for every batch, so build them once
    potential_ids = jnp.arange(len(potentials))

    likelihood_blocks = []
    for batch in dataloader:
        # compute_likelihood_batch returns (n_potentials, batch_size); transpose
        # so that images run along the rows
        likelihood_blocks.append(compute_likelihood_batch(potential_ids, batch, args).T)

    if not likelihood_blocks:
        raise ValueError(
            "The dataloader yielded no batches; cannot build a likelihood matrix."
        )
    return jnp.concatenate(likelihood_blocks, axis=0)

In [15]:
likelihood_args = (potentials, potential_integrator)
likelihood_matrix = compute_likelihood_matrix(dataloader, args=likelihood_args)

In [16]:
# Estimate the populations by assigning each image to the structure with the
# highest likelihood. They should be around 0.3 and 0.7
# (this will not be true at low SNR).
best_structure_ids = jnp.argmax(likelihood_matrix, axis=1)
for structure_id in (0, 1):
    print(f"Population for id {structure_id}: {jnp.sum(best_structure_ids == structure_id)}")

Population for id 0: 30
Population for id 1: 70


## Computing with jax.lax.map

Here we need to use `eqx.partition`, since `jax.lax.map` does not have utilities such as `eqx.if_array` (see how we vmapped in the previous example). The filtering is simple: we separate the leaves that are not arrays, map only over the array leaves, and recombine the two inside the mapped function.

In [17]:
@eqx.filter_jit
def compute_single_likelihood(
    potential_id: int,
    particle_stack,
    args: Tuple[
        Tuple[cxs.AbstractPotentialRepresentation], cxs.AbstractPotentialIntegrator
    ],
) -> Float:
    """Gaussian log-likelihood of a single image under one ensemble member.

    Unbatched counterpart of `compute_likelihood`, intended to be driven by
    `jax.lax.map`. The simulated image is fit to the observed one with a
    least-squares scale and offset, and ``-sum(residual**2) / 2`` is returned
    (unit noise variance, additive constant dropped).
    """
    potentials, potential_integrator = args
    pose = particle_stack["parameters"]["pose"]
    transfer_theory = particle_stack["parameters"]["transfer_theory"]
    instrument_config = particle_stack["parameters"]["instrument_config"]

    selected = cxs.DiscreteConformationalVariable(potential_id)
    scattering_theory = cxs.WeakPhaseScatteringTheory(
        cxs.DiscreteStructuralEnsemble(potentials, pose, selected),
        potential_integrator,
        transfer_theory,
    )
    model = cxs.ContrastImageModel(instrument_config, scattering_theory)

    simulated_image = model.render()
    observed_image = particle_stack["images"]

    # Moments for the least-squares (scale, bias) fit, which estimates the SNR
    sim_energy = jnp.mean(simulated_image**2)
    cross_term = jnp.mean(observed_image * simulated_image)
    sim_mean = jnp.mean(simulated_image)
    obs_mean = jnp.mean(observed_image)

    fitted_scale = (cross_term - sim_mean * obs_mean) / (sim_energy - sim_mean**2)
    fitted_bias = obs_mean - fitted_scale * sim_mean

    # The simulated data were generated with noise variance 1
    squared_error = jnp.sum(
        (observed_image - fitted_scale * simulated_image - fitted_bias) ** 2
    )
    return -squared_error / 2.0


@eqx.filter_jit
def compute_likelihood_with_map(
    potential_id: int,
    particle_stack,
    args: Tuple[
        Tuple[cxs.AbstractPotentialRepresentation], cxs.AbstractPotentialIntegrator
    ],
    *,
    batch_size_images: Optional[int],
) -> Float[Array, " n_images"]:
    """Compute one column of the likelihood matrix: one structure, every image.

    `jax.lax.map` only maps over arrays, so the particle stack is split with
    `eqx.partition` into its array leaves (mapped along their leading image axis)
    and the remaining leaves, which are recombined inside the mapped function.

    **Arguments:**

    - `potential_id`: index of the ensemble member in `args[0]`.
    - `particle_stack`: batch of particles (dict with `"parameters"` and `"images"`).
    - `args`: ``(potentials, potential_integrator)``.
    - `batch_size_images`: number of images `jax.lax.map` evaluates together in a
      vectorized step; `None` maps over the images one at a time.

    **Returns:**

    Log-likelihoods of every image in the stack, shape ``(n_images,)``.
    """
    stack_arrays, stack_static = eqx.partition(particle_stack, eqx.is_array)

    def likelihood_of_one_image(image_arrays):
        single_particle = eqx.combine(image_arrays, stack_static)
        return compute_single_likelihood(potential_id, single_particle, args)

    return jax.lax.map(
        likelihood_of_one_image, stack_arrays, batch_size=batch_size_images
    )


@eqx.filter_jit
def _lax_map_likelihoods_for_batch(
    batch,
    args,
    batch_size_potentials: Optional[int],
    batch_size_images: Optional[int],
) -> Float[Array, " n_images n_structures"]:
    """Likelihood block for one dataloader batch (all images x all structures).

    Jitted so the `jax.lax.map` bodies are traced and compiled once and reused for
    every batch of the same shape. An eager `jax.lax.map` with a fresh closure over
    `batch` would instead be retraced on every loop iteration. The integer batch
    sizes are non-array arguments, so `eqx.filter_jit` treats them as static.
    """
    n_potentials = len(args[0])
    per_potential = jax.lax.map(
        lambda potential_id: compute_likelihood_with_map(
            potential_id, batch, args, batch_size_images=batch_size_images
        ),
        jnp.arange(n_potentials),
        batch_size=batch_size_potentials,  # potentials to compute in parallel
    )
    # (n_structures, n_images) -> (n_images, n_structures)
    return per_potential.T


def compute_likelihood_matrix_with_lax_map(
    dataloader: jdl.DataLoader,
    args: Tuple[
        Tuple[cxs.AbstractPotentialRepresentation], cxs.AbstractPotentialIntegrator
    ],
    *,
    batch_size_potentials: Optional[int] = None,
    batch_size_images: Optional[int] = None,
) -> Float[Array, " n_images n_structures"]:
    """Assemble the image-by-structure log-likelihood matrix using `jax.lax.map`.

    Trades speed for memory compared with the fully vmapped version.

    **Arguments:**

    - `dataloader`: yields particle-stack batches; iteration order defines rows.
    - `args`: ``(potentials, potential_integrator)``.
    - `batch_size_potentials`: potentials evaluated together per `lax.map` step
      (`None` = one at a time).
    - `batch_size_images`: images evaluated together per `lax.map` step
      (`None` = one at a time).

    **Returns:**

    Array of shape ``(n_images, n_structures)``.

    **Raises:**

    `ValueError` if the dataloader yields no batches.
    """
    likelihood_blocks = [
        _lax_map_likelihoods_for_batch(
            batch, args, batch_size_potentials, batch_size_images
        )
        for batch in dataloader
    ]
    if not likelihood_blocks:
        raise ValueError(
            "The dataloader yielded no batches; cannot build a likelihood matrix."
        )
    return jnp.concatenate(likelihood_blocks, axis=0)

This will take longer, but uses less memory. Play around with the batch sizes: `batch_size_potentials` controls how many potentials are evaluated in a single vectorized operation, and `batch_size_images` controls how many images are. Atomic potentials are cheap in terms of memory, but slower when compared against many images. Voxel potentials are more memory-expensive, but very fast to compare against many images. This might not hold if you need to compute gradients. Always profile your code!

In [18]:
lax_map_batch_sizes = dict(batch_size_potentials=3, batch_size_images=20)
likelihood_matrix_lax_map = compute_likelihood_matrix_with_lax_map(
    dataloader, args=(potentials, potential_integrator), **lax_map_batch_sizes
)

In [19]:
# Both strategies should give the same matrix (up to floating-point tolerance)
jnp.allclose(likelihood_matrix_lax_map, likelihood_matrix, rtol=1e-05, atol=1e-08)

Array(True, dtype=bool)

## Computing likelihood matrix from multiple atomic positions

Here we will not convert the atomic potential to a voxel potential, because the objective of this section is to build a loss function that allows gradients to be computed with respect to the atomic positions.

In [20]:
@eqx.filter_jit
@partial(eqx.filter_vmap, in_axes=(None, eqx.if_array(0)))
def compute_likelihood_atomic(
    potential: cxs.AbstractAtomicPotential, particle_stack
) -> Float:
    """Gaussian log-likelihood of each image under a single atomic potential.

    Vectorized over the array leaves of `particle_stack` (one entry per image)
    with `potential` shared. Images are rendered with `GaussianMixtureProjection`,
    fit to the data with a least-squares scale and offset, and scored with unit
    noise variance (additive constant dropped).
    """
    parameters = particle_stack["parameters"]
    theory = cxs.WeakPhaseScatteringTheory(
        cxs.SingleStructureEnsemble(potential, parameters["pose"]),
        cxs.GaussianMixtureProjection(),
        parameters["transfer_theory"],
    )
    simulated_image = cxs.ContrastImageModel(
        parameters["instrument_config"], theory
    ).render()
    observed_image = particle_stack["images"]

    # Least-squares scale and offset (the scale acts as an SNR estimate)
    sim_sq_mean = jnp.mean(simulated_image**2)
    cross_mean = jnp.mean(observed_image * simulated_image)
    sim_mean = jnp.mean(simulated_image)
    obs_mean = jnp.mean(observed_image)

    scale = (cross_mean - sim_mean * obs_mean) / (sim_sq_mean - sim_mean**2)
    bias = obs_mean - scale * sim_mean

    # The simulated data were generated with noise variance 1
    residual = observed_image - scale * simulated_image - bias
    return -jnp.sum(residual**2) / 2.0


@eqx.filter_jit
@partial(eqx.filter_vmap, in_axes=(0, None, None))
def compute_likelihood_build_potential(
    atom_positions: Float[Array, " n_atoms 3"],
    particle_stack,
    args: Tuple[
        Float[Array, " n_atoms"],
        dict[str, Float[Array, " n_atoms n_gaussians"]],
    ],
):
    """Build a Peng atomic potential from `atom_positions` and score every image.

    Vectorized over `atom_positions` (one structure per entry along axis 0);
    `particle_stack` and `args = (b_factors, parameter_table)` are shared, so all
    structures must contain the same atoms in the same order.
    """
    b_factors, parameter_table = args
    scattering_a, scattering_b = parameter_table["a"], parameter_table["b"]
    return compute_likelihood_atomic(
        cxs.PengAtomicPotential(
            atom_positions,
            scattering_factor_a=scattering_a,
            scattering_factor_b=scattering_b,
            b_factors=b_factors,
        ),
        particle_stack,
    )


def compute_likelihood_matrix_from_atoms(
    batch_atom_positions: Float[Array, " n_structures n_atoms 3"],
    dataloader: jdl.DataLoader,
    args: Tuple[
        Float[Array, " n_atoms"],
        dict[str, Float[Array, " n_atoms n_gaussians"]],
    ],
) -> Float[Array, " n_images n_structures"]:
    """Log-likelihood matrix between every image and every atomic structure.

    For each dataloader batch, `compute_likelihood_build_potential` returns an
    ``(n_structures, batch_size)`` block, which is transposed so that images run
    along the rows. The blocks are stacked in dataloader order.
    """
    blocks_per_batch = [
        compute_likelihood_build_potential(batch_atom_positions, batch, args).T
        for batch in dataloader
    ]
    return jnp.concatenate(blocks_per_batch, axis=0)

### WARNING
Here I am assuming that all atomic structures contain the same set of atoms, in the same order. Generalizing is not difficult; you just need to be careful about how you handle the atom identities and the B-factors.

In [21]:
import numpy as np


filenames = ["./data/groel_chainA.pdb", "./data/groel_chainA_holo.pdb"]

box_size = instrument_config.shape[0]
voxel_size = instrument_config.pixel_size

# The first structure also provides the atom identities and B-factors, which are
# shared by every structure in the ensemble
single_atom_positions, atom_identities, b_factors = read_atoms_from_pdb(
    filenames[0], center=True, select="not element H", loads_b_factors=True
)

# This is needed to define the Peng Atomic Potential
parameter_table = get_tabulated_scattering_factor_parameters(atom_identities)

# For the remaining structures only the positions are needed; the rest does not change
positions_per_structure = [single_atom_positions] + [
    read_atoms_from_pdb(
        filename, center=True, select="not element H", loads_b_factors=False
    )[0]
    for filename in filenames[1:]
]

# Shared identities/B-factors only make sense if every structure has the same
# atoms, so fail with a clear message instead of a numpy broadcasting error
for filename, positions in zip(filenames, positions_per_structure):
    if positions.shape != single_atom_positions.shape:
        raise ValueError(
            f"{filename} has atom positions of shape {positions.shape}, expected "
            f"{single_atom_positions.shape}; all structures must share the same atoms."
        )

# Stack as float64 on the host first (as before), then move to JAX
batch_atom_positions = jnp.array(
    np.stack([np.asarray(positions, dtype=float) for positions in positions_per_structure])
)

In [22]:
args = (b_factors, parameter_table)
likelihood_matrix_atoms = compute_likelihood_matrix_from_atoms(
    batch_atom_positions, dataloader, args=args
)

In [23]:
best_atomic_ids = jnp.argmax(likelihood_matrix_atoms, axis=1)
for structure_id in (0, 1):
    print(f"Population for id {structure_id}: {jnp.sum(best_atomic_ids == structure_id)}")

Population for id 0: 30
Population for id 1: 70
