This tutorial demonstrates how to use `cryojax` to build a scattering potential from a PDB entry. Specifically, the tutorial will build a voxel grid, take some steps to validate it, then save the result to be used in subsequent tutorials.

The scattering potential will be built using the tabulation of atomic scattering factors from the work of Lian-Mao Peng, which fits the potential of a single atom to a sum of five Gaussians. See the `cryojax` documentation [here](../api/simulator/volume.md#atom-based-volumes) for more information.

*References:*

- Peng, L-M., et al. "Robust parameterization of elastic and absorptive electron atomic scattering factors." Acta Crystallographica Section A: Foundations of Crystallography 52.2 (1996): 257-276.
- Himes, Benjamin, and Nikolaus Grigorieff. "Cryo-TEM simulations of amorphous radiation-sensitive samples using multislice wave propagation." IUCrJ 8.6 (2021): 943-953.
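
To make the five-Gaussian fit concrete, here is a minimal sketch in plain Python. It assumes the common form $f(s) = \sum_i a_i \exp(-b_i s^2)$ with $s = \sin\theta / \lambda$, and a Fourier convention with kernel $e^{2 \pi i \mathbf{g} \cdot \mathbf{r}}$ where $|\mathbf{g}| = 2s$. The coefficients are placeholders chosen for illustration, not values from the Peng tables, and the function names are hypothetical rather than part of `cryojax`.

```python
import math

# Placeholder coefficients for illustration only (NOT values from the Peng tables)
A_COEFFS = (0.1, 0.2, 0.3, 0.2, 0.1)  # amplitudes a_i
B_COEFFS = (0.2, 1.0, 3.0, 10.0, 30.0)  # widths b_i, in angstroms^2


def scattering_factor(s, a=A_COEFFS, b=B_COEFFS):
    """Five-Gaussian fit f(s) = sum_i a_i exp(-b_i s^2), with s = sin(theta) / lambda."""
    return sum(a_i * math.exp(-b_i * s**2) for a_i, b_i in zip(a, b))


def real_space_profile(r, a=A_COEFFS, b=B_COEFFS):
    """Inverse 3D Fourier transform of f(s), omitting physical prefactors.

    Under the assumed convention, each term becomes a Gaussian in r
    with variance b_i / (8 pi^2) and weight a_i (4 pi / b_i)^(3/2).
    """
    total = 0.0
    for a_i, b_i in zip(a, b):
        variance = b_i / (8.0 * math.pi**2)
        weight = a_i * (4.0 * math.pi / b_i) ** 1.5
        total += weight * math.exp(-(r**2) / (2.0 * variance))
    return total
```

Under these assumptions, each $b_i$ maps to a real-space variance and each $a_i$ to a real-space amplitude, which is the sense in which a sum of Gaussians in reciprocal space gives a sum of Gaussians in real space. The exact numerical constants used by `cryojax` are given in its documentation.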

In [1]:
# Plotting imports and function definitions
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable


def plot_image(image, fig, ax, cmap="gray", label=None, **kwargs):
    im = ax.imshow(image, cmap=cmap, origin="lower", **kwargs)
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    fig.colorbar(im, cax=cax)
    if label is not None:
        ax.set(title=label)
    return fig, ax

First, load the atomic positions and identities from a PDB entry. Here, a structure of GroEL (PDB ID 5w0s) is used. This is loaded into a `GaussianMixtureVolume` object.

In [2]:
import cryojax.simulator as cxs
from cryojax.constants import PengScatteringFactorParameters
from cryojax.io import read_atoms_from_pdb


atom_positions, atomic_numbers = read_atoms_from_pdb(
    "./data/5w0s.pdb",
    center=True,
)
parameters = PengScatteringFactorParameters(atomic_numbers)
atom_volume = cxs.GaussianMixtureVolume.from_tabulated_parameters(
    atom_positions,
    parameters,
)
print(atom_volume)

GaussianMixtureVolume(
  positions=f32[54021,3], amplitudes=f32[54021,5], variances=f32[54021,5]
)


The printout above shows the fields of the `GaussianMixtureVolume`:

- `positions` holds the positions of the atoms in angstroms.
- `amplitudes` and `variances` hold the parameters $a_i$ and $b_i$ from Peng et al. (1996), up to numerical constants (see the documentation for details).

Optionally, we can also load PDB B-factors into the `b_factors` field using

```python
from cryojax.io import read_atoms_from_pdb

atom_positions, atom_types, atom_properties = read_atoms_from_pdb(..., loads_properties=True)
b_factors = atom_properties["b_factors"]
```

Next, we can build the voxel grid representation of the potential.

In [3]:
import equinox as eqx
import jax


# Evaluate the potential on a voxel grid
shape = (240, 240, 240)
voxel_size = 1.0


@eqx.filter_jit
def compute_voxels(atom_volume: cxs.GaussianMixtureVolume) -> jax.Array:
    gaussian_render_fn = cxs.GaussianMixtureRenderFn(
        shape, voxel_size, batch_options=dict(batch_size=1)
    )
    return gaussian_render_fn(atom_volume)


real_voxel_grid = compute_voxels(atom_volume)

Now, downsample the voxel array to the desired voxel size. Because the potentials from individual atoms are short-ranged, finite sampling effects can be significant and it is best to first generate a potential at a smaller-than-desired voxel size. 

In [None]:
from cryojax.ndimage import block_reduce_downsample


ds_factor = 3
voxel_size_ds = ds_factor * voxel_size
voxel_grid_ds = block_reduce_downsample(
    real_voxel_grid,
    downsample_factor=ds_factor,
    operation=lambda x, y: (x + y) / ds_factor**3,
)

TypeError: reduce_window got inconsistent dtypes for operands and init_values: got operand dtypes [dtype('float32')] and init_value dtypes [dtype('int32')].
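
Independently of the `cryojax` helper, the downsampling step amounts to averaging non-overlapping blocks of voxels. The following is a minimal NumPy sketch of that idea, assuming every grid dimension is divisible by the downsampling factor. The function name `block_mean_downsample` is hypothetical and this is not how `block_reduce_downsample` is implemented.

```python
import numpy as np


def block_mean_downsample(volume, factor):
    """Downsample a 3D array by averaging non-overlapping factor^3 blocks."""
    nz, ny, nx = volume.shape
    if any(n % factor != 0 for n in (nz, ny, nx)):
        raise ValueError("each dimension must be divisible by the factor")
    # Split every axis into (number of blocks, block size) ...
    blocks = volume.reshape(
        nz // factor, factor, ny // factor, factor, nx // factor, factor
    )
    # ... and average over the three block-size axes
    return blocks.mean(axis=(1, 3, 5))


rng = np.random.default_rng(0)
fine_grid = rng.random((6, 6, 6)).astype(np.float32)
coarse_grid = block_mean_downsample(fine_grid, factor=3)
```

Averaging, rather than summing, keeps the voxel values in the same units as the original grid, so the downsampled array can still be read as a potential sampled at the coarser voxel size.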

!!! info
    In `cryojax`, potentials are built in units of *inverse length squared*,
    $[L]^{-2}$. This rescaled potential is defined to be

    $$U(x, y, z) = \frac{m_0 e}{2 \pi \hbar^2} V(x, y, z),$$

    where $V$ is the electrostatic potential energy, $(x, y, z)$ are positional
    coordinates, $m_0$ is the electron rest mass, and $e$ is the electron charge.

    In the following, we will compute projections of the potential, which we will
    define to be

    $$U_z(x, y) = \int_{-\infty}^z dz' \ U(x, y, z'),$$

    where in practice the integration domain is taken to be between $z'$-planes above and below where the potential has sufficiently decayed. In this tutorial, this integral is computed with Fourier slice extraction.
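
The idea behind Fourier slice extraction can be checked on a toy array. The sketch below uses NumPy on an unrotated random volume, indexed as `[z, y, x]` by assumption, and compares a direct sum along $z$ with the inverse 2D transform of the $k_z = 0$ plane of the 3D transform. It illustrates the projection-slice relationship only and is not the `cryojax` implementation, which also handles rotated poses by interpolation.

```python
import numpy as np

rng = np.random.default_rng(0)
voxel_size = 2.0  # arbitrary grid spacing, in angstroms
volume = rng.random((8, 8, 8))  # stand-in for U, indexed as [z, y, x]

# Direct route: Riemann sum of the volume along z
projection_direct = volume.sum(axis=0) * voxel_size

# Fourier route: the k_z = 0 plane of the 3D transform equals
# the 2D transform of the sum along z
fourier_volume = np.fft.fftn(volume)
central_slice = fourier_volume[0, :, :]
projection_fourier = np.fft.ifft2(central_slice).real * voxel_size
```

The two arrays agree to floating-point precision, which is why a projection can be computed by extracting a single plane from the Fourier transform of the volume.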

Now, let's validate that what we see is reasonable. The first validation step is to compute a few different projections of the potential.

In [None]:
import equinox as eqx


@eqx.filter_jit
def compute_projection(potential, config):
    """Compute a projection of a voxel-based potential."""
    # ... initialize the integration method for the potential
    integrator = cxs.FourierSliceExtraction()
    # ... compute the integrated potential
    integrated_potential = integrator.integrate(
        potential, config, outputs_real_space=True
    )
    return integrated_potential


# Load the voxel grid into a voxel-based potential representation
voxel_potential = cxs.FourierVoxelGridVolume.from_real_voxel_grid(
    voxel_grid_ds,
)
# ... and the configuration of the imaging instrument
config = cxs.BasicImageConfig(
    shape=voxel_potential.shape[0:2],
    pixel_size=voxel_size_ds,
    voltage_in_kilovolts=300.0,
)
# Now, compute the projection integral
integrated_potential = compute_projection(voxel_potential, config)
# ... and plot
fig, ax = plt.subplots(figsize=(3, 3))
plot_image(integrated_potential, fig, ax, label="Integrated potential, $U_z(x, y)$")

We can also inspect a different viewing angle by rotating the `voxel_potential` to a different pose. This involves instantiating a `cryojax` representation of a pose, which here is the `cryojax.simulator.EulerAnglePose`. The three Euler angles in this object are:

- The first Euler angle $\phi$, denoted `phi_angle`
- The second Euler angle $\theta$, denoted `theta_angle`
- The third Euler angle $\psi$, denoted `psi_angle`

The Euler angle convention in `cryojax` is a zyz extrinsic rotation, which follows other standard cryo-EM software, such as RELION and cisTEM.
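
For intuition, a zyz extrinsic rotation can be written as a product of three elementary rotation matrices. The NumPy sketch below assumes the angles are applied in the order $\phi$, $\theta$, $\psi$ about fixed axes and that the matrix acts actively on column vectors. These are assumptions made for illustration; the ordering and sign conventions used internally by `cryojax` may differ, and the function names here are hypothetical.

```python
import numpy as np


def rotation_z(angle_in_degrees):
    """Active rotation about the fixed z-axis."""
    c = np.cos(np.deg2rad(angle_in_degrees))
    s = np.sin(np.deg2rad(angle_in_degrees))
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


def rotation_y(angle_in_degrees):
    """Active rotation about the fixed y-axis."""
    c = np.cos(np.deg2rad(angle_in_degrees))
    s = np.sin(np.deg2rad(angle_in_degrees))
    return np.array([[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]])


def zyz_extrinsic_matrix(phi, theta, psi):
    """Rotate by phi about z, then theta about y, then psi about z (fixed axes)."""
    # The first rotation applied is the right-most factor
    return rotation_z(psi) @ rotation_y(theta) @ rotation_z(phi)


rotation = zyz_extrinsic_matrix(phi=0.0, theta=90.0, psi=90.0)
z_axis_image = rotation @ np.array([0.0, 0.0, 1.0])
```

With these assumed conventions, the original $z$-axis is tilted into the $xy$-plane, so the projection is taken along a direction perpendicular to the one used for the unrotated volume.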

In [None]:
# Instantiate the pose and rotate the potential
# ... angles are in degrees
pose = cxs.EulerAnglePose(phi_angle=0.0, theta_angle=90.0, psi_angle=90.0)
rotated_voxel_potential = voxel_potential.rotate_to_pose(pose)
# ... again compute the projection integral
integrated_potential = compute_projection(rotated_voxel_potential, config)
# ... and again, plot
fig, ax = plt.subplots(figsize=(3, 3))
plot_image(integrated_potential, fig, ax, label="Integrated potential, $U_z(x, y)$")

Another good sanity check is to confirm that the potential is weak compared to a typical incident electron beam energy in cryo-EM. For an electron beam with incident wavenumber $k$, this can be checked with the condition

$$4 \pi U / k^2 \ll 1,$$

where again $U = m_0 e V / 2 \pi \hbar^2$ is the rescaled potential. Below, we consider an incident energy of $300 \ \textrm{keV}$.
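
As a cross-check that does not depend on `cryojax`, the wavenumber can be estimated from the standard relativistic de Broglie relation $\lambda = h / \sqrt{2 m_0 e V_a (1 + e V_a / 2 m_0 c^2)}$, where $V_a$ is the accelerating voltage. The sketch below uses only the Python standard library and CODATA constants. The function name is hypothetical, and it is assumed, not confirmed, that `wavelength_from_kilovolts` implements the same formula.

```python
import math

PLANCK = 6.62607015e-34  # J s
ELECTRON_MASS = 9.1093837015e-31  # kg
ELEMENTARY_CHARGE = 1.602176634e-19  # C
SPEED_OF_LIGHT = 299792458.0  # m / s


def relativistic_wavelength_in_angstroms(voltage_in_kilovolts):
    """Relativistic de Broglie wavelength of an electron, in angstroms."""
    kinetic_energy = ELEMENTARY_CHARGE * voltage_in_kilovolts * 1e3  # J
    rest_energy = ELECTRON_MASS * SPEED_OF_LIGHT**2  # J
    momentum = math.sqrt(
        2.0 * ELECTRON_MASS * kinetic_energy
        * (1.0 + kinetic_energy / (2.0 * rest_energy))
    )
    return PLANCK / momentum * 1e10  # meters to angstroms


wavelength = relativistic_wavelength_in_angstroms(300.0)
wavenumber = 2.0 * math.pi / wavelength  # inverse angstroms
```

At 300 kV the wavelength is roughly 0.02 angstroms, so $k^2$ is of order $10^5$ inverse square angstroms, which sets the scale against which $4 \pi U$ is compared.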

In [None]:
import numpy as np
from cryojax.constants import wavelength_from_kilovolts


# First compute the wavenumber
voltage_in_kilovolts = 300.0
wavelength_in_angstroms = wavelength_from_kilovolts(voltage_in_kilovolts)
wavenumber = 2 * np.pi / wavelength_in_angstroms
# ... now get the maximum value of the potential
potential_maximum = voxel_grid_ds.max()
# ... and compare
print(4 * np.pi * potential_maximum / wavenumber**2)

Looks reasonable! Finally, we can write the voxel grid to disk for later processing.

In [None]:
# from cryojax.io import write_volume_to_mrc


# Uncomment to write the downsampled voxel grid and its voxel size to disk
# write_volume_to_mrc(
#     voxel_grid_ds,
#     voxel_size_ds,
#     "./data/groel_5w0s_scattering_potential.mrc",
#     overwrite=True,
# )