In [None]:
# Plotting imports and function definitions
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable


def plot_image(image, fig, ax, cmap="gray", label=None, **kwargs):
    """Draw a 2D image on ``ax`` with a colorbar attached to its right side.

    Parameters
    ----------
    image :
        2D array passed to ``ax.imshow``.
    fig :
        Figure that owns ``ax``; used to create the colorbar.
    ax :
        Axes to draw the image on.
    cmap :
        Colormap forwarded to ``ax.imshow``.
    label :
        Optional title for ``ax``; no title is set when ``None``.
    **kwargs :
        Extra keyword arguments forwarded to ``ax.imshow``. ``origin``
        defaults to ``"lower"`` but may be overridden here.

    Returns
    -------
    fig, ax
        The same figure and axes that were passed in.
    """
    # Use origin="lower" by default, but let callers override it instead of
    # failing with "got multiple values for keyword argument 'origin'".
    imshow_kwargs = {"origin": "lower", **kwargs}
    im = ax.imshow(image, cmap=cmap, **imshow_kwargs)
    # Append a separate axes (5% wide, 0.05 pad) to the right of `ax`
    # to hold the colorbar.
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    fig.colorbar(im, cax=cax)
    if label is not None:
        ax.set(title=label)
    return fig, ax

In [None]:
import cryojax.simulator as cxs

In [None]:
from cryojax.constants import (
    get_tabulated_scattering_factor_parameters,
    peng_element_scattering_factor_parameter_table,
)
from cryojax.data import read_atoms_with_b_factors_from_pdb


# Atomic model to simulate; the filename suggests PDB entry 5W0S.
# Kept as a named constant so the input is easy to find and change.
PDB_PATH = "./data/5w0s.pdb"

atom_positions, atom_identities, atom_b_factors = read_atoms_with_b_factors_from_pdb(
    PDB_PATH
)
# NOTE(review): the `- 1` suggests `atom_identities` are 1-based atomic numbers
# while the Peng parameter table is indexed from 0 — confirm against cryojax docs.
scattering_factor_a, scattering_factor_b = get_tabulated_scattering_factor_parameters(
    atom_identities - 1, peng_element_scattering_factor_parameter_table
)
# Build the atomic potential from the tabulated Peng scattering-factor
# parameters, including the per-atom B-factors read from the PDB file.
atom_potential = cxs.PengTabulatedAtomicPotential(
    atom_positions,
    scattering_factor_a,
    scattering_factor_b,
    atom_b_factors=atom_b_factors,
)

In [None]:
from cryojax.coordinates import make_coordinate_grid
from cryojax.image import downsample_with_fourier_cropping


# Rasterize the atomic potential onto a cubic grid with 40 voxels per side,
# spaced 5.0 Å apart (the coordinate grid is in angstroms).
voxel_size = 5.0
shape = (40,) * 3
coordinate_grid_in_angstroms = make_coordinate_grid(shape, voxel_size)
# Evaluate with batch_size=100 and a progress bar; batching presumably bounds
# peak memory — confirm against cryojax docs.
real_voxel_grid = atom_potential.as_real_voxel_grid(
    coordinate_grid_in_angstroms,
    progress_bar=True,
    batch_size=100,
)

In [None]:
# Downsample once, outside the loop; the original recomputed this identical
# result for every panel.
# NOTE(review): a factor of 1.0 may leave the grid unchanged — confirm
# whether this call is needed.
downsampled_voxel_grid = downsample_with_fourier_cropping(real_voxel_grid, 1.0)

# Show three projections by summing the grid along each array axis.
# Assumes the grid is indexed (z, y, x), so summing over axis 0 projects
# along z — TODO confirm against make_coordinate_grid's convention.
fig, axes = plt.subplots(ncols=3, figsize=(8, 3))
labels = ["z projection", "y projection", "x projection"]
for axis_index, (ax, label) in enumerate(zip(axes, labels)):
    plot_image(
        downsampled_voxel_grid.sum(axis=axis_index),
        fig,
        ax,
        label=label,
    )
plt.tight_layout()