In [None]:
# Jax imports
import jax.numpy as jnp
import numpy as np
from jax import config

# Explicitly keep JAX in its default 32-bit mode (64-bit disabled), so the
# cryojax CTF below is computed in single precision.
# NOTE(review): presumably this is why the agreement tolerances at the end of
# the notebook are fairly loose — confirm.
config.update("jax_enable_x64", False)

In [None]:
# Plotting imports
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

In [None]:
def plot_image(image, fig, ax, cmap="gray", **kwargs):
    """Show ``image`` on ``ax`` (origin at the lower-left) with a side colorbar.

    The colorbar lives in a slim axis (5% width, 0.05 pad) appended to the
    right of ``ax``. Extra keyword arguments are forwarded to ``imshow``.

    Returns the same ``(fig, ax)`` pair that was passed in.
    """
    mappable = ax.imshow(image, cmap=cmap, origin="lower", **kwargs)
    colorbar_ax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
    fig.colorbar(mappable, cax=colorbar_ax)
    return fig, ax

def plot_images(images, cmap="gray", labels=None, **kwargs):
    """Plot several 2D images side by side, each with its own colorbar.

    Parameters
    ----------
    images : sequence of 2D array-likes
        Images to display, one per column.
    cmap : str
        Matplotlib colormap passed to ``imshow``.
    labels : sequence of str, optional
        Per-image titles; when given, must have an entry for every image.
    **kwargs
        Extra keyword arguments forwarded to ``imshow``.

    Returns
    -------
    fig, axes
        The figure and a 1D array of its axes, one per image.

    Raises
    ------
    ValueError
        If ``images`` is empty.
    """
    ncols = len(images)
    if ncols == 0:
        raise ValueError("plot_images requires at least one image.")
    # squeeze=False keeps `axes` indexable even for a single image; by default
    # plt.subplots(ncols=1) returns a bare Axes and `axes[0]` would fail.
    fig, axes = plt.subplots(ncols=ncols, figsize=(ncols * 5, 5), squeeze=False)
    axes = axes[0]  # the grid is 1 x ncols; keep the 1D row as before
    for idx, (ax, image) in enumerate(zip(axes, images)):
        # Reuse the single-image helper for imshow + colorbar.
        plot_image(image, fig, ax, cmap=cmap, **kwargs)
        label = None if labels is None else labels[idx]
        ax.set(title=label)
    return fig, axes

def plot_profiles(bins, profiles, labels=None, **kwargs):
    """Overlay several 1D profiles that share the same x-axis ``bins``.

    Parameters
    ----------
    bins : array-like
        x-coordinates shared by every profile.
    profiles : sequence of array-likes
        Curves to plot, each the same length as ``bins``.
    labels : sequence of str, optional
        Per-profile legend labels; a legend is drawn only when given.
    **kwargs
        Axes properties forwarded to ``ax.set`` (e.g. ``xlabel``, ``yscale``).

    Returns
    -------
    fig, ax
        The created figure and its single axes.
    """
    fig, ax = plt.subplots()
    ax.set(**kwargs)
    for idx, profile in enumerate(profiles):
        label = None if labels is None else labels[idx]
        ax.plot(bins, profile, label=label)
    # Without labels there is nothing to put in a legend, and matplotlib would
    # warn "No artists with labels found to put in legend".
    if labels is not None:
        ax.legend(fontsize=12)
    return fig, ax

In [None]:
# cisTEM imports
from pycistem.core import CTF

In [None]:
# cryojax imports
from cryojax.simulator import CTFOptics
from cryojax.utils import make_frequencies, cartesian_to_polar, powerspectrum

In [None]:
# Parameters for the CTF, shared by cryojax and cisTEM below.
# NOTE(review): units presumably follow cisTEM conventions (defocus in
# Angstroms, astigmatism angle in degrees, voltage in kV, Cs in mm, amplitude
# contrast as a fraction) — confirm against both libraries' docs.
defocus1 = 24000
defocus2 = 12000
asti_angle = 30.0
kV = 300.0
cs = 2.7
ac = 0.07

In [None]:
# Frequency coordinates on a square 512 x 512 grid.
pixel_size = 0.9  # presumably Angstroms per pixel — confirm against cryojax docs
shape = (512,) * 2
freqs = make_frequencies(shape, pixel_size=pixel_size)
# Polar form of the grid; with square=True the radial part is presumably the
# squared magnitude (hence `k_sqr`), and `theta` the azimuth.
k_sqr, theta = cartesian_to_polar(freqs, square=True)

In [None]:
# cryojax CTF (no envelope) evaluated on the frequency grid.
optics = CTFOptics(
    defocus_u=defocus1,
    defocus_v=defocus2,
    defocus_angle=asti_angle,
    voltage=kV,
    spherical_aberration=cs,
    amplitude_contrast=ac,
    envelope=None,
)
ctf = np.array(optics(freqs))  # materialize the JAX result as a NumPy array

In [None]:
# cisTEM CTF, evaluated point by point over the same grid.
cisTEM_optics = CTF(kV=kV, cs=cs, ac=ac, defocus1=defocus1, defocus2=defocus2, astig_angle=asti_angle, pixel_size=pixel_size)


def _evaluate_cistem(squared_freq, azimuth):
    """Evaluate the cisTEM CTF at one (squared frequency, azimuth) pair."""
    return cisTEM_optics.Evaluate(squared_freq, azimuth)


# Rescale by pixel_size**2 — presumably converting squared frequency from
# physical units to the per-pixel units cisTEM's Evaluate expects (confirm).
squared_freq_per_pixel = k_sqr.ravel() * pixel_size**2
azimuth = theta.ravel()
cisTEM_ctf = np.vectorize(_evaluate_cistem)(squared_freq_per_pixel, azimuth).reshape(shape)

In [None]:
# Plot the two CTFs side by side. `labels` is reused by the profile plot below.
labels = ["cryojax", "cisTEM"]
ctfs = [ctf, cisTEM_ctf]
fig, axes = plot_images(ctfs, labels=labels)
fig.tight_layout()

In [None]:
# 1D power spectra of both CTFs over the same frequency grid. The bins from
# the second call are the ones kept in `k_bins` (as in the original cell).
spectrum1D, _ = powerspectrum(ctf, freqs, pixel_size=pixel_size)
cisTEM_spectrum1D, k_bins = powerspectrum(cisTEM_ctf, freqs, pixel_size=pixel_size)

In [None]:
# Overlay the 1D power spectra, labelled with `labels` from the CTF plot cell.
profiles = [
    spectrum1D,
    cisTEM_spectrum1D,
]
fig, axes = plot_profiles(k_bins, profiles, labels=labels)

In [None]:
# Check that cryojax and cisTEM agree, both for the 2D CTF and its 1D power
# spectrum. np.testing.assert_allclose reports the number of mismatches and
# the max abs/rel difference on failure, unlike a bare `assert np.allclose`.
# rtol=1e-5 and equal_nan=False reproduce np.allclose's defaults, so the
# pass/fail criterion is unchanged.
np.testing.assert_allclose(
    ctf, cisTEM_ctf, rtol=1e-5, atol=5e-2, equal_nan=False,
    err_msg="cryojax and cisTEM CTFs differ",
)
np.testing.assert_allclose(
    spectrum1D, cisTEM_spectrum1D, rtol=1e-5, atol=5e-3, equal_nan=False,
    err_msg="cryojax and cisTEM 1D power spectra differ",
)