# Interactive Complex Probes

## Source Code

In [7]:
import sys
if "pyodide" in sys.modules:
  import micropip
  await micropip.install('ipywidgets')
  await micropip.install('ipympl')

%matplotlib widget
import numpy as np
import matplotlib.pyplot as plt
from colorspacious import cspace_convert
from ipywidgets import widgets, IntSlider, VBox, Layout

In [2]:
# Complex Probes Utilities
def energy2wavelength(energy):
    """Return the relativistic electron wavelength for a given beam energy.

    The physical constants below are SI values; dividing by the elementary
    charge expresses the energies in electron-volts and the final factor of
    ``1e10`` rescales metres to ångströms.

    Parameters
    ----------
    energy : float or array-like
        Electron energy in eV.

    Returns
    -------
    float or ndarray
        Wavelength in Å, with the same shape as ``energy``.
    """
    planck_constant = 6.62607e-34
    light_speed = 299792458.0
    electron_mass = 9.1093856e-31
    elementary_charge = 1.6021766208e-19

    # Twice the electron rest energy, expressed in eV.
    twice_rest_energy = 2 * electron_mass * light_speed**2 / elementary_charge
    # sqrt(E * (2 m c^2 + E)) is the relativistic momentum times c, in eV.
    momentum_term = np.sqrt(energy * (twice_rest_energy + energy))
    wavelength = planck_constant * light_speed / momentum_term / elementary_charge
    return wavelength * 1.0e10


class ComplexProbe:
    """Complex-valued probe built from an aperture and polar aberrations.

    In reciprocal space the probe is ``aperture * exp(-1j * chi)`` where
    ``chi`` is the aberration phase assembled from the polar coefficients
    named in ``_polar_symbols``. ``build`` inverse-Fourier-transforms that
    array and normalises it to obtain the real-space probe.

    Parameters
    ----------
    energy
        Beam energy; it is passed unchanged to ``energy2wavelength`` to obtain
        the wavelength (presumably eV in, Å out; confirm against that helper).
    gpts
        Two-element sequence with the number of grid points along each axis.
    sampling
        Two-element sequence with the real-space pixel size along each axis.
    semiangle_cutoff
        Aperture semi-angle in mrad (it is multiplied by ``1e-3`` before being
        compared with the scattering angle).
    soft_aperture
        If true, the aperture edge is a linear ramp instead of a hard step.
    parameters
        Optional mapping of aberration symbol or alias to value. It is copied,
        never modified.
    **kwargs
        Further aberration values; these take precedence over ``parameters``.

    Raises
    ------
    ValueError
        If a key is neither a polar symbol nor a known alias.
    """

    # fmt: off
    _polar_symbols = (
        "C10", "C12", "phi12",
        "C21", "phi21", "C23", "phi23",
        "C30", "C32", "phi32", "C34", "phi34",
        "C41", "phi41", "C43", "phi43", "C45", "phi45",
        "C50", "C52", "phi52", "C54", "phi54", "C56", "phi56",
    )

    _polar_aliases = {
        "defocus": "C10", "astigmatism": "C12", "astigmatism_angle": "phi12",
        "coma": "C21", "coma_angle": "phi21",
        "Cs": "C30",
        "C5": "C50",
    }
    # fmt: on

    def __init__(
        self,
        energy,
        gpts,
        sampling,
        semiangle_cutoff,
        soft_aperture=True,
        parameters=None,
        **kwargs,
    ):
        self._energy = energy
        self._gpts = gpts
        self._sampling = sampling
        self._semiangle_cutoff = semiangle_cutoff
        self._soft_aperture = soft_aperture

        # Every polar coefficient starts at zero.
        self._parameters = dict.fromkeys(self._polar_symbols, 0.0)

        # Merge into a private copy: a shared ``{}`` default (or the caller's
        # own dict) must not accumulate keyword values between instances.
        merged = dict(parameters) if parameters is not None else {}
        merged.update(kwargs)
        self.set_parameters(merged)
        self._wavelength = energy2wavelength(self._energy)

    def set_parameters(self, parameters):
        """Store aberration values given by polar symbol or alias.

        ``"defocus"`` is stored as the negated ``C10``; every other alias maps
        directly onto its polar symbol.

        Returns the mapping that was passed in; raises ``ValueError`` for an
        unrecognised key.
        """
        for symbol, value in parameters.items():
            if symbol in self._parameters:
                self._parameters[symbol] = value

            elif symbol == "defocus":
                # Defocus has the opposite sign to C10.
                self._parameters[self._polar_aliases[symbol]] = -value

            elif symbol in self._polar_aliases:
                self._parameters[self._polar_aliases[symbol]] = value

            else:
                raise ValueError("{} not a recognized parameter".format(symbol))

        return parameters

    def get_spatial_frequencies(self):
        """Return the ``np.fft.fftfreq`` frequency vector of each axis."""
        return tuple(np.fft.fftfreq(n, d) for n, d in zip(self._gpts, self._sampling))

    def get_scattering_angles(self):
        """Return ``(alpha, phi)`` on the 2D grid, in FFT (unshifted) order.

        ``alpha`` is the spatial-frequency magnitude scaled by the wavelength
        and ``phi`` is the azimuth ``arctan2(ky, kx)``.
        """
        kx, ky = self.get_spatial_frequencies()
        kx, ky = kx * self._wavelength, ky * self._wavelength
        alpha = np.sqrt(kx[:, None] ** 2 + ky[None, :] ** 2)
        phi = np.arctan2(ky[None, :], kx[:, None])
        return alpha, phi

    def hard_aperture(self, alpha, semiangle_cutoff):
        """Boolean mask that is true where ``alpha`` is within the cutoff."""
        return alpha <= semiangle_cutoff

    def soft_aperture(self, alpha, semiangle_cutoff, angular_sampling):
        """Aperture in [0, 1] with a linear edge around the cutoff.

        ``angular_sampling`` is scaled by ``1e-3`` here, so it is expected in
        units 1000x finer than ``alpha`` and ``semiangle_cutoff``. The ramp
        width is the diagonal of one angular pixel.
        """
        denominator = (
            np.sqrt(angular_sampling[0] ** 2 + angular_sampling[1] ** 2) * 1e-3
        )
        return np.clip((semiangle_cutoff - alpha) / denominator + 0.5, 0, 1)

    def evaluate_aperture(self, alpha, phi):
        """Evaluate the soft or hard aperture; ``phi`` is accepted but unused."""
        if self._soft_aperture:
            return self.soft_aperture(
                alpha, self._semiangle_cutoff * 1e-3, self.angular_sampling
            )
        else:
            return self.hard_aperture(alpha, self._semiangle_cutoff * 1e-3)

    def evaluate_chi(self, alpha, phi):
        """Return the aberration phase ``chi`` for the given polar grid.

        Orders one to five are summed and the result is multiplied by
        ``2 * pi / wavelength``. An order is skipped when all of its
        coefficients are zero.
        """
        p = self._parameters

        def active(*symbols):
            # True when at least one coefficient of this order is non-zero.
            return any(p[symbol] != 0.0 for symbol in symbols)

        alpha2 = alpha**2

        array = np.zeros_like(alpha)
        if active("C10", "C12", "phi12"):
            array += (
                1 / 2 * alpha2 * (p["C10"] + p["C12"] * np.cos(2 * (phi - p["phi12"])))
            )

        if active("C21", "phi21", "C23", "phi23"):
            array += (
                1
                / 3
                * alpha2
                * alpha
                * (
                    p["C21"] * np.cos(phi - p["phi21"])
                    + p["C23"] * np.cos(3 * (phi - p["phi23"]))
                )
            )

        if active("C30", "C32", "phi32", "C34", "phi34"):
            array += (
                1
                / 4
                * alpha2**2
                * (
                    p["C30"]
                    + p["C32"] * np.cos(2 * (phi - p["phi32"]))
                    + p["C34"] * np.cos(4 * (phi - p["phi34"]))
                )
            )

        # The guard now lists "phi45" (it previously repeated "phi41").
        if active("C41", "phi41", "C43", "phi43", "C45", "phi45"):
            array += (
                1
                / 5
                * alpha2**2
                * alpha
                * (
                    p["C41"] * np.cos((phi - p["phi41"]))
                    + p["C43"] * np.cos(3 * (phi - p["phi43"]))
                    + p["C45"] * np.cos(5 * (phi - p["phi45"]))
                )
            )

        if active("C50", "C52", "phi52", "C54", "phi54", "C56", "phi56"):
            array += (
                1
                / 6
                * alpha2**3
                * (
                    p["C50"]
                    + p["C52"] * np.cos(2 * (phi - p["phi52"]))
                    + p["C54"] * np.cos(4 * (phi - p["phi54"]))
                    + p["C56"] * np.cos(6 * (phi - p["phi56"]))
                )
            )

        array = 2 * np.pi / self._wavelength * array
        return array

    def evaluate_aberrations(self, alpha, phi):
        """Return the complex phase factor ``exp(-1j * chi)``."""
        return np.exp(-1.0j * self.evaluate_chi(alpha, phi))

    def evaluate_ctf(self):
        """Return the aberration phase factor multiplied by the aperture."""
        alpha, phi = self.get_scattering_angles()
        array = self.evaluate_aberrations(alpha, phi)
        array *= self.evaluate_aperture(alpha, phi)
        return array

    def build(self):
        """Compute and store the probe in both spaces; return ``self``.

        ``_array_fourier`` holds the reciprocal-space array and ``_array`` its
        inverse FFT, scaled so that the summed squared magnitude is one.
        """
        self._array_fourier = self.evaluate_ctf()
        array = np.fft.ifft2(self._array_fourier)
        array /= np.sqrt(np.sum(np.abs(array) ** 2))
        self._array = array
        return self

    @property
    def reciprocal_space_sampling(self):
        """Reciprocal-space pixel size ``1 / (n * s)`` along each axis."""
        return tuple(1 / (n * s) for n, s in zip(self._gpts, self._sampling))

    @property
    def angular_sampling(self):
        """Reciprocal-space pixel size times wavelength times ``1e3``, per axis."""
        return tuple(
            dk * self._wavelength * 1e3 for dk in self.reciprocal_space_sampling
        )

In [3]:
# Complex Plotting Utilities
def Complex2RGB(complex_data, vmin=None, vmax=None, power=None, chroma_boost=1):
    """Map a complex array to RGB: magnitude sets lightness, phase sets hue.

    Parameters
    ----------
    complex_data : ndarray
        Complex-valued input.
    vmin, vmax : float, optional
        Magnitude limits. When the magnitude is not constant they are treated
        as fractions (defaults 0.02 and 0.98) that index into the sorted
        magnitudes; when it is constant they are used directly as magnitude
        values (defaults 0 and the maximum).
    power : float, optional
        If given, the magnitude is raised to this power first.
    chroma_boost : float
        Multiplier applied to the chroma before it is capped at 110.

    Returns
    -------
    ndarray
        Array with a trailing axis of length 3, clipped to [0, 1].
    """
    magnitude = np.abs(complex_data)
    angle = np.angle(complex_data)

    if power is not None:
        magnitude = magnitude**power

    if np.isclose(np.max(magnitude), np.min(magnitude)):
        # Flat magnitude: the limits are absolute values.
        vmin = 0 if vmin is None else vmin
        vmax = np.max(magnitude) if vmax is None else vmax
    else:
        # Varying magnitude: turn the fractional limits into order statistics.
        vmin = 0.02 if vmin is None else vmin
        vmax = 0.98 if vmax is None else vmax
        ordered = np.sort(magnitude[~np.isnan(magnitude)])
        last_index = ordered.shape[0] - 1
        low_index = np.max([0, np.round(last_index * vmin).astype("int")])
        high_index = np.min(
            [len(ordered) - 1, np.round(last_index * vmax).astype("int")]
        )
        vmin = ordered[low_index]
        vmax = ordered[high_index]

    magnitude = np.where(magnitude < vmin, vmin, magnitude)
    magnitude = np.where(magnitude > vmax, vmax, magnitude)
    # NOTE(review): this divides by vmax rather than (vmax - vmin), so the
    # brightest value is below 1 whenever vmin > 0 - confirm this is intended.
    magnitude = ((magnitude - vmin) / vmax).clip(1e-16, 1)

    # Lightness is restricted to the monotonic chroma cutoff.
    lightness = magnitude * 61.5
    chroma = np.minimum(chroma_boost * 98 * lightness / 123, 110)
    hue = np.rad2deg(angle) + 180

    jch = np.stack((lightness, chroma, hue), axis=-1)
    return cspace_convert(jch, "JCh", "sRGB1").clip(0, 1)


def show_complex(
    complex_data, figax=None, vmin=None, vmax=None, power=None, ticks=True, chroma_boost=1, **kwargs
):
    """Draw a complex array as a colour image and return ``(ax, im)``.

    Parameters
    ----------
    complex_data : ndarray
        Complex array, converted with ``Complex2RGB``.
    figax : tuple, optional
        Existing ``(fig, ax)`` pair to draw into. When omitted a new figure
        is created with ``plt.subplots``.
    vmin, vmax, power, chroma_boost
        Forwarded to ``Complex2RGB``.
    ticks : bool
        Passing ``False`` removes the x and y ticks.
    **kwargs
        ``figsize`` (default ``(6, 6)``) is only used when a new figure is
        created; everything else is forwarded to ``ax.imshow``.

    Returns
    -------
    tuple
        The axes drawn into and the object returned by ``imshow``.
    """
    colour_image = Complex2RGB(
        complex_data, vmin, vmax, power=power, chroma_boost=chroma_boost
    )

    # figsize is always removed so that it never reaches imshow.
    figsize = kwargs.pop("figsize", (6, 6))
    if figax is not None:
        _, ax = figax
    else:
        _, ax = plt.subplots(figsize=figsize)

    im = ax.imshow(colour_image, **kwargs)
    if ticks is False:
        for clear_ticks in (ax.set_xticks, ax.set_yticks):
            clear_ticks([])
    return ax, im

def build_probes(
    semiangle_cutoff,
    defocus,
    astigmatism,
    astigmatism_angle,
    energy=300e3,
    gpts=(128, 128),
    sampling=(0.1, 0.1),
):
    """Build a probe and return its centred real- and reciprocal-space arrays.

    Parameters
    ----------
    semiangle_cutoff, defocus, astigmatism
        Forwarded unchanged to ``ComplexProbe``.
    astigmatism_angle
        Astigmatism angle in degrees; converted to radians before use.
    energy, gpts, sampling
        Beam energy, grid shape and pixel size forwarded to ``ComplexProbe``.
        The defaults reproduce the values this notebook uses.

    Returns
    -------
    tuple of ndarray
        ``(real_space, reciprocal_space)``, each passed through
        ``np.fft.fftshift`` so the zero index sits at the array centre.
    """
    probe = ComplexProbe(
        energy=energy,
        gpts=gpts,
        sampling=sampling,
        semiangle_cutoff=semiangle_cutoff,
        defocus=defocus,
        astigmatism=astigmatism,
        astigmatism_angle=np.deg2rad(astigmatism_angle),
    ).build()
    return np.fft.fftshift(probe._array), np.fft.fftshift(probe._array_fourier)

# Close any figures left from earlier runs and switch interactive mode off
# before creating the side-by-side figure.
plt.close("all")
plt.ioff()
fig, (ax_real, ax_fourier) = plt.subplots(nrows=1, ncols=2, figsize=(7, 3.5))

# Initial probe; the values equal the sliders' starting positions.
probe_real, probe_fourier = build_probes(
    semiangle_cutoff=20,
    defocus=100,
    astigmatism=0,
    astigmatism_angle=45,
)

# figsize has no effect in these calls because an existing (fig, ax) is given.
ax_real, im_real = show_complex(
    probe_real, figax=(fig, ax_real), ticks=False, figsize=(3.5, 3.5)
)
ax_fourier, im_fourier = show_complex(
    probe_fourier, figax=(fig, ax_fourier), ticks=False, figsize=(3.5, 3.5)
)

ax_real.set_title("Real-space complex probe")
ax_fourier.set_title("Reciprocal-space complex probe")
fig.tight_layout()

# Turn off the canvas toolbar, header, footer and resize handle.
fig.canvas.toolbar_visible = False
fig.canvas.header_visible = False
fig.canvas.footer_visible = False
fig.canvas.resizable = False

In [4]:
def update_probes(semiangle_cutoff, defocus, astigmatism, astigmatism_angle):
    """Rebuild both probes for the given settings and refresh the figure.

    The four arguments are forwarded to ``build_probes``. The resulting
    arrays are converted with ``Complex2RGB`` and written into the existing
    ``im_real`` and ``im_fourier`` images, after which a redraw of ``fig``
    is requested. Returns ``None``.
    """
    probes = build_probes(semiangle_cutoff, defocus, astigmatism, astigmatism_angle)

    # Convert both arrays before touching either image.
    rgb_real, rgb_fourier = [Complex2RGB(probe) for probe in probes]

    im_real.set_data(rgb_real)
    im_fourier.set_data(rgb_fourier)
    fig.canvas.draw_idle()
    
# One integer slider per probe setting, with starting value and range.
semiangle_cutoff = IntSlider(value=20, min=15, max=40, description='CSA [mrad]')
defocus = IntSlider(value=100, min=-200, max=200, description=r'$C_1$ [Å]')
astigmatism = IntSlider(value=0, min=0, max=200, description=r'$C_{12}$ [Å]')
astigmatism_angle = IntSlider(
    value=45, min=-180, max=180, description=r'$\phi_{12}$ [°]'
)

# The keyword names match update_probes' parameters, so each slider feeds the
# argument of the same name.
controls = widgets.interactive(
    update_probes,
    semiangle_cutoff=semiangle_cutoff,
    defocus=defocus,
    astigmatism=astigmatism,
    astigmatism_angle=astigmatism_angle,
)

## Visualization

In [8]:
#| label: app:complex_probes
# Stack the sliders above the figure canvas and centre both horizontally.
VBox(
    children=[controls, fig.canvas],
    layout=Layout(display='flex', align_items='center'),
)

VBox(children=(interactive(children=(IntSlider(value=20, description='CSA [mrad]', max=40, min=15), IntSlider(…