# CTF Simulation
This Jupyter Notebook provides an implementation of the contrast transfer function (CTF) in electron microscopy. It includes an interactive graphical user interface (GUI) that enables users to explore how various factors influence the CTF in both 1-D and 2-D. This approach offers insights into the impact of parameters such as defocus, aberrations, and envelope functions on the resulting CTF.

In [1]:
# import block
import math
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
%matplotlib widget
from IPython.display import display
from collections.abc import Callable
from numpy.typing import ArrayLike, NDArray

## 1. Wavelength of Electron Beam in an Electron Microscope
The wavelength of a moving particle is based on the de Broglie formula:

$$ \lambda = \frac{h}{p} $$

where $h$ is Planck's constant, and $p$ is the momentum of the particle.

The kinetic energy of an accelerated electron is $eV$, where $e$ is the elementary charge and $V$ is the accelerating voltage. Using the non-relativistic relation between kinetic energy and momentum, $eV = \frac{p^2}{2m_e}$, and substituting $p$ into the equation above:

$$ \lambda = \frac{h}{\sqrt{2m_eeV}} $$

where $m_e$ is the mass of a stationary electron.

For electrons in a modern electron microscope, the high accelerating voltage causes their speed to approach the speed of light, necessitating the consideration of relativistic effects. By applying the energy-momentum relation ([wiki](https://en.wikipedia.org/wiki/Energy–momentum_relation)):

$$ E^2 = (pc)^2 + (m_ec^2)^2 $$

where $c$ is the speed of light, $m_ec^2$ is known as the rest energy ($E_{rest}$), and $E$ is the relativistic energy.

The total relativistic energy is the sum of the kinetic and rest energies:

$$ E = E_{kinetic} + E_{rest} = eV + m_ec^2$$

The above two equations allow us to calculate the momentum of the electron:

$$ p=\sqrt{\frac{\left(eV\right)^2}{c^2} + 2m_eeV} $$

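Note that if the kinetic energy $eV$ is negligible compared to the rest energy $m_ec^2$, the first term under the square root is negligible, and the momentum reduces to the non-relativistic $\sqrt{2m_eeV}$, the denominator in the second equation above.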

Plugging this into the original de Broglie equation:

$$\lambda = \frac{h}{ \sqrt{2m_eeV \left( 1+\frac{eV}{2m_ec^2} \right) }}$$

Thus, relativistic effects shorten the wavelength by a factor of $\sqrt{1+\frac{eV}{2m_ec^2}}$ relative to the non-relativistic expression. At 300 kV, where $m_ec^2 \approx 511$ keV, this factor is about 1.14.


In [2]:
# Physical Constants
SPEED_OF_LIGHT = 299792458  # speed of light, in m/s
ELEMENTARY_CHARGE = 1.60217663e-19  # elementary charge, in coulombs
ELECTRON_MASS = 9.1093837e-31  # electron rest mass, in kg
PLANCK_CONSTANT = 6.62607015e-34  # Planck's constant, in J·s

def electron_wavelength(voltage: float) -> float:
    """Calculate the relativistic electron wavelength from the accelerating voltage

    Args:
        voltage (float): accelerating voltage of the electron microscope, in kilovolts

    Returns:
        float: wavelength of the electron beam, in angstrom
    """
    voltage_si = voltage * 1000.0  # kV -> V
    kinetic_energy = ELEMENTARY_CHARGE * voltage_si  # eV, in joule
    relativistic_factor = 1 + kinetic_energy / (2 * ELECTRON_MASS * SPEED_OF_LIGHT ** 2)
    momentum = math.sqrt(2 * ELECTRON_MASS * kinetic_energy * relativistic_factor)
    return PLANCK_CONSTANT / momentum * 1e10  # m -> angstrom

# print(electron_wavelength(300))
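As a quick check of the relativistic correction, the sketch below re-implements the same formula in a self-contained form with CODATA constants and prints the wavelength with and without the $\frac{eV}{2m_ec^2}$ term for a few accelerating voltages. The helper `wavelength_angstrom` and its `relativistic` flag are hypothetical, introduced only for this comparison.

```python
import math

# CODATA constants in SI units
SPEED_OF_LIGHT = 299792458.0          # m/s
ELEMENTARY_CHARGE = 1.602176634e-19   # C
ELECTRON_MASS = 9.1093837015e-31      # kg
PLANCK_CONSTANT = 6.62607015e-34      # J*s


def wavelength_angstrom(voltage_kv: float, relativistic: bool = True) -> float:
    """Electron wavelength in angstrom; drops the eV/(2 m c^2) term if relativistic=False."""
    energy = ELEMENTARY_CHARGE * voltage_kv * 1e3  # kinetic energy eV, in joule
    if relativistic:
        correction = 1 + energy / (2 * ELECTRON_MASS * SPEED_OF_LIGHT ** 2)
    else:
        correction = 1.0
    momentum = math.sqrt(2 * ELECTRON_MASS * energy * correction)
    return PLANCK_CONSTANT / momentum * 1e10  # m -> angstrom


for kv in (80, 120, 200, 300):
    rel = wavelength_angstrom(kv)
    nonrel = wavelength_angstrom(kv, relativistic=False)
    print(f"{kv:>4} kV: {rel:.5f} Å (relativistic), {nonrel:.5f} Å (non-relativistic), "
          f"ratio {nonrel / rel:.4f}")
```

The ratio is exactly $\sqrt{1+\frac{eV}{2m_ec^2}}$, so the correction grows with voltage: about 4% at 80 kV and about 14% at 300 kV.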

## 2. Contrast Transfer Function (CTF)
The contrast transfer function (CTF) in electron microscopy describes how the microscope’s optical system modulates the contrast of spatial frequencies in an image. In its simplest form, it accounts for the phase shifts introduced by the objective lens due to defocus and spherical aberration. The CTF plays a critical role in single-particle analysis: its oscillations enhance, suppress, or invert the contrast of specific frequency components, and these effects must be corrected computationally during image processing.

The phase shifts introduced by the objective lens can be represented by a wave aberration function $\gamma$ as a function of spatial frequency $f$.  
For the 1-D case: 

$$ \gamma(f) = -\frac{\pi}{2} C_s \lambda^3 f^4 + \pi d_f \lambda f^2 + p $$

where:
- $d_f$ is the defocus,
- $C_s$ is the spherical aberration coefficient, and
- $p$ is the additional phase shift caused by a phase plate or other factors. 

Note that this notebook uses the convention that underfocus is positive and overfocus is negative.

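For the 2-D case, axial astigmatism makes the defocus depend on direction. $d_f$ is therefore parameterized by two defocus values, $d_u$ and $d_v$, along the principal axes of astigmatism and by the azimuthal angle $\phi_a$ of the $d_u$ axis; the $f$ in $\gamma$ becomes the magnitude of the 2-D frequency vector:

$$ \mathbf{f} = (f_x, f_y), \quad f = \sqrt{f_x^2 + f_y^2} $$
$$ \phi = \arctan\left(\frac{f_y}{f_x}\right) $$
$$ d_f = \frac{1}{2} \left( d_u + d_v + (d_u - d_v) \cos\left( 2\left( \phi - \phi_a \right)\right)\right) $$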
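The astigmatic defocus formula can be checked directly: along the azimuth $\phi_a$ the defocus equals $d_u$, perpendicular to it the defocus equals $d_v$, and the pattern repeats every 180°. The sketch below is a standalone NumPy version of the formula; the function name `astigmatic_defocus` and the example values are hypothetical.

```python
import numpy as np


def astigmatic_defocus(fx, fy, d_u, d_v, phi_a_deg):
    """Direction-dependent defocus d_f(phi), in the same units as d_u and d_v."""
    phi = np.arctan2(fy, fx)  # azimuth of each frequency vector
    phi_a = np.deg2rad(phi_a_deg)
    return 0.5 * (d_u + d_v + (d_u - d_v) * np.cos(2 * (phi - phi_a)))


# Hypothetical example: 1.2 µm and 0.8 µm defocus (in Å), astigmatism axis at 30°
d_u, d_v, phi_a = 12000.0, 8000.0, 30.0
along = np.deg2rad(phi_a)
across = np.deg2rad(phi_a + 90.0)
print(astigmatic_defocus(np.cos(along), np.sin(along), d_u, d_v, phi_a))    # ~12000 (d_u)
print(astigmatic_defocus(np.cos(across), np.sin(across), d_u, d_v, phi_a))  # ~8000 (d_v)
```

Because $\cos(2(\phi-\phi_a))$ has a period of 180°, $\arctan(f_y/f_x)$ and the quadrant-aware `arctan2` give the same $d_f$, and $d_f$ depends only on the direction of $\mathbf{f}$, not on its magnitude.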

Based on the weak-phase object approximation ([nice derivations by Fred Sigworth](https://cryoemprinciples.yale.edu/sites/default/files/files/2%20Phase%20contrast.pdf)), and taking into account the amplitude contrast due to inelastic scattering:

$$CTF(f) = -A_c \cos\left(\gamma(f)\right) - \sqrt{1 - A_c^2} \sin\left(\gamma(f)\right) = -\sin\left(\gamma(f) + \arcsin(A_c)\right)$$

where $A_c$ is the relative amplitude contrast, a fraction between 0 and 1.

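The CTF is therefore a sine of the phase $\gamma$. Because $\gamma$ grows with $f^2$ (defocus) and $f^4$ (spherical aberration), the CTF oscillates between $-1$ and $1$, with zeros that generally become more closely spaced at higher spatial frequencies.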
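The trigonometric identity in the CTF expression can be verified numerically. The sketch below compares the two-term and single-sine forms over a range of phases for a few assumed amplitude-contrast values. It also shows that with no additional phase shift ($p = 0$), $\gamma(0) = 0$ and the CTF equals $-A_c$, so without amplitude contrast a weak-phase object has no contrast at the lowest frequencies.

```python
import numpy as np

gamma = np.linspace(-10.0, 10.0, 1001)  # arbitrary phase values, in radians
max_diff = {}
for a_c in (0.0, 0.07, 0.1, 0.5):  # example amplitude-contrast fractions
    two_term = -a_c * np.cos(gamma) - np.sqrt(1 - a_c ** 2) * np.sin(gamma)
    one_term = -np.sin(gamma + np.arcsin(a_c))
    max_diff[a_c] = float(np.max(np.abs(two_term - one_term)))
print(max_diff)  # differences at floating-point round-off level

# At gamma = 0 (zero frequency, p = 0) the CTF reduces to -A_c
ctf_at_zero = {a_c: float(-np.sin(0.0 + np.arcsin(a_c))) for a_c in max_diff}
print(ctf_at_zero)
```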

In [3]:
def CTF(
        electron_wavelength: float, 
        spherical_aberration: float, 
        amplitude_contrast: float, 
        defocus: float | tuple[float, float, float], 
        phase_shift: float) -> Callable[[ArrayLike], NDArray]:
    """Set up a 1D or 2D contrast transfer function from parameters

    Args:
        electron_wavelength (float): wavelength of electron beam, in angstrom
        spherical_aberration (float): spherical aberration constant, in millimeter
        amplitude_contrast (float): fraction of amplitude contrast (0 to 1), unitless
        defocus (float | tuple): defocus of the objective lens; for the 1D case, a float in angstrom;
            for the 2D case, a tuple of three floats: defocus_u in angstrom, defocus_v in angstrom,
            and azimuthal angle in degrees
        phase_shift (float): additional phase shift, in degrees

    Returns:
        Callable[[ArrayLike], NDArray]: contrast transfer function that takes an array of frequencies and returns the CTF values

    Note:
        This function works for both 1D and 2D cases, with different defocus types. The returned function
        takes a 1D frequency array, or x and y frequency arrays in the 2D case.
    """
    spherical_aberration_angstrom = spherical_aberration * 1e7 # convert to angstrom
    cs_part = np.pi / 2 * spherical_aberration_angstrom * electron_wavelength ** 3
    amplitude_contrast_correction = math.asin(amplitude_contrast)
    if isinstance(defocus, tuple):  # 2D CTF
        defocus_u, defocus_v, defocus_a = defocus
        delta_df = defocus_u - defocus_v
        tilt_angle = lambda x, y: np.arctan2(y, x) - math.radians(defocus_a)  # np.arctan2 works on all NumPy versions
        df = lambda x, y: 0.5 * (defocus_u + defocus_v + delta_df * np.cos(2 * tilt_angle(x, y)))
        freq_sq = lambda x, y: x ** 2 + y ** 2  # squared spatial frequency |f|^2
        phase = lambda x, y: np.pi * electron_wavelength * df(x, y) * freq_sq(x, y) - cs_part * freq_sq(x, y) ** 2
        return lambda x, y: -np.sin(phase(x, y) + amplitude_contrast_correction + math.radians(phase_shift))
    else:  # 1D CTF
        df_part = np.pi * defocus * electron_wavelength 
        phase = lambda x: df_part * x ** 2 - cs_part * x ** 4
        return lambda x: -np.sin(phase(x) + amplitude_contrast_correction + math.radians(phase_shift))    
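A simple sanity check for the 1D CTF is the position of its zeros. With $C_s = 0$, $A_c = 0$ and $p = 0$, $\gamma(f) = \pi d_f \lambda f^2$, so the CTF vanishes wherever $d_f \lambda f^2 = n$, that is at $f_n = \sqrt{n / (d_f \lambda)}$. The sketch below re-implements the 1D CTF as a standalone function (`ctf_1d` is a hypothetical name, not part of the notebook) and compares numerically located zero crossings with this closed form for an assumed 1 µm underfocus at roughly 300 kV.

```python
import numpy as np


def ctf_1d(f, wavelength, cs_mm, amp_contrast, defocus, phase_shift_deg=0.0):
    """1D CTF with the same sign conventions as CTF() above (Å units, Cs in mm)."""
    cs = cs_mm * 1e7  # mm -> Å
    gamma = np.pi * defocus * wavelength * f ** 2 - 0.5 * np.pi * cs * wavelength ** 3 * f ** 4
    return -np.sin(gamma + np.arcsin(amp_contrast) + np.deg2rad(phase_shift_deg))


wavelength = 0.019687   # Å, roughly 300 kV (assumed)
defocus = 10000.0       # 1 µm underfocus, in Å
f = np.linspace(1e-4, 0.2, 200001)  # spatial frequencies, 1/Å

# Cs = 0 and A_c = 0, so zeros should sit at f_n = sqrt(n / (d_f * lambda))
values = ctf_1d(f, wavelength, 0.0, 0.0, defocus)
crossings = f[:-1][np.sign(values[:-1]) != np.sign(values[1:])]
analytic = np.sqrt(np.arange(1, len(crossings) + 1) / (defocus * wavelength))
print(crossings[:3])
print(analytic[:3])
```

With 1 µm underfocus the first zero lies near 0.071 Å⁻¹ (about 14 Å). With a nonzero $C_s$, the zeros at higher frequencies shift away from this simple formula.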

## 3. Envelope Functions
The equations above assume completely coherent illumination with monochromatic electrons and perfect detection. In practice, the electron beam has finite divergence and energy spread, and the detector is not perfect. These factors dampen the CTF, especially at higher spatial frequencies, and ultimately limit the resolution. Their combined effect is modeled by multiplying the CTF by an envelope function ($E$), which gives the dampened CTF ($dCTF$):

$$dCTF(f) = CTF(f)\times E(f)$$

The envelope function can be further factored into several components, as shown below.  

---

### 3.1 Chromatic Aberration and Temporal Envelope
Chromatic aberration occurs because the objective lens focuses electrons of different energies at slightly different planes.

The resulting **_focus spread_** ($f_s$) depends on the energy spread of the electron source, the instabilities of the electron beam potential, and the instabilities of the objective lens current: 

$$f_s = C_c \sqrt{\left(\frac{\Delta V}{V}\right)^2 + \left(\frac{2 \Delta I}{I}\right)^2 + \left(\frac{\Delta E}{eV}\right)^2}$$

where:
- $C_c$ is the chromatic aberration coefficient, 
- $\frac{\Delta V}{V}$ represents the beam potential instability, 
- $\frac{\Delta I}{I}$ represents the objective lens current instability, and 
- $\Delta E$ is the energy spread of the electron source.

The temporal envelope function ($E_T$) resulting from chromatic aberration is given by:

$$E_T(f) = \exp\left(-\frac{\pi^2 \lambda^2 f_s^2 f^4}{2}\right)$$

where:
- $\lambda$ is the electron wavelength, 
- $f_s$ is the focus spread, and 
- $f$ is the spatial frequency.

In [4]:
def focus_spread(
        chromatic_aberration: float, 
        voltage_stability: float, 
        obj_lens_stability: float, 
        electron_source_spread: float, 
        voltage: float) -> float:
    """Calculate the focus spread resulting from chromatic aberration

    Args:
        chromatic_aberration (float): chromatic aberration constant, in millimeter
        voltage_stability (float): deltaV/V in the equation
        obj_lens_stability (float): deltaI/I in the equation
        electron_source_spread (float): deltaE in the equation, in electron volts
        voltage (float): accelerating voltage of the electron microscope, in kilovolts

    Returns:
        float: focus spread, in angstrom
    """
    voltage_si = voltage * 1000.0
    chromatic_aberration_angstrom = chromatic_aberration * 1e7
    return chromatic_aberration_angstrom * math.sqrt((voltage_stability) ** 2 + 4 * (obj_lens_stability) ** 2 
                                       + (electron_source_spread / voltage_si) ** 2)

def temporal_envelope_function(electron_wavelength: float, focus_spread: float) -> Callable[[ArrayLike], NDArray]:
    """Set up the temporal envelope function from parameters

    Args:
        electron_wavelength (float): wavelength of electron beam, in angstrom
        focus_spread (float): focus spread due to chromatic aberration, in angstrom

    Returns:
        Callable[[ArrayLike], NDArray]: temporal envelope function that takes an array of frequencies and returns the values
    """
    return lambda x: np.exp(-0.5 * (np.pi * electron_wavelength * focus_spread) ** 2 * x ** 4)

# print(focus_spread(3.4, 0.5e-6, 1e-6, 0.3, 300))
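To get a feel for the numbers, the sketch below evaluates the focus spread for the default slider values used in the GUI further down ($C_c$ = 3.4 mm, $\Delta V/V$ = 3.3333e-8, $\Delta I/I$ = 1.6666e-8, $\Delta E$ = 0.7 eV, 300 kV). It then solves $E_T(f) = 1/e$ for $f$, which gives $f = \left(\frac{2}{\pi^2\lambda^2 f_s^2}\right)^{1/4}$. The wavelength is hard-coded as an assumption rather than computed.

```python
import math

import numpy as np

# GUI default inputs: Cc = 3.4 mm (in Å), dV/V, dI/I, dE = 0.7 eV, 300 kV
cc_angstrom = 3.4e7
voltage_v = 300e3
fs = cc_angstrom * math.sqrt((3.3333e-8) ** 2 + (2 * 1.6666e-8) ** 2 + (0.7 / voltage_v) ** 2)

wavelength = 0.019687  # Å at 300 kV (assumed)


def temporal_envelope(f):
    """E_T(f) = exp(-(pi * lambda * f_s)^2 * f^4 / 2)."""
    return np.exp(-0.5 * (np.pi * wavelength * fs) ** 2 * f ** 4)


# E_T(f) = 1/e when (pi * lambda * f_s)^2 * f^4 / 2 = 1
f_1e = (2.0 / (np.pi * wavelength * fs) ** 2) ** 0.25
print(f"focus spread ≈ {fs:.1f} Å")
print(f"E_T = 1/e at f ≈ {f_1e:.3f} 1/Å (≈ {1 / f_1e:.2f} Å)")
```

With these inputs the energy spread of the source dominates the focus spread, and $E_T$ falls to $1/e$ at roughly 1.9 Å.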

### 3.2 Spatial Envelope
The envelope function $E_S$, arising from partially coherent illumination, depends on the electron source, as well as on the spherical aberration and defocus introduced by the objective lens. For an electron source with a Gaussian profile, it is given by:

$$E_S(f) = \exp\left(-\left(\frac{\pi e_a}{\lambda}\right)^2 \left(C_s \lambda^3 f^3 - d_f \lambda f\right)^2\right)$$

where:
- $e_a$ is the electron source angle, which is proportional to the size of the source as it appears in the back focal plane,
- $C_s$ is the spherical aberration coefficient, and
- $d_f$ is the defocus.

The bracketed term equals $-\frac{1}{2\pi}\frac{d\gamma}{df}$, so its signs follow the wave aberration function $\gamma$: with the underfocus-positive convention, underfocus partially compensates spherical aberration, while overfocus adds to it.

Note that in the 2-D case, the spatial envelope may become anisotropic due to objective lens astigmatism. The implementation below returns a one-argument function of $f$ for a scalar defocus, and a two-argument function of $(f_x, f_y)$ for an astigmatic defocus tuple.

In [5]:
def spatial_envelope_function(
        electron_wavelength: float, 
        electron_source_angle: float, 
        spherical_aberration: float, 
        defocus: float | tuple[float, float, float]) -> Callable[[ArrayLike], NDArray]:
    """Set up the spatial envelope function from parameters

    Args:
        electron_wavelength (float): wavelength of electron beam, in angstrom
        electron_source_angle (float): electron source angle, in radians
        spherical_aberration (float): spherical aberration constant, in millimeter
        defocus (float | tuple[float, float, float]): defocus of the objective lens; for the 1D case, a float in angstrom;
            for the 2D case, a tuple of three floats: defocus_u in angstrom, defocus_v in angstrom,
            and azimuthal angle in degrees

    Returns:
        Callable[[ArrayLike], NDArray]: spatial envelope function that takes an array of frequencies and returns the values

    Note:
        This function works for both 1D and 2D cases, with different defocus types. The returned function
        takes a 1D frequency array, or x and y frequency arrays in the 2D case.
    """
    spherical_aberration_angstrom = spherical_aberration * 1e7 # convert spherical aberration to angstrom
    if isinstance(defocus, tuple):  # 2D
        defocus_u, defocus_v, defocus_a = defocus
        delta_df = defocus_u - defocus_v
        tilt_angle = lambda x, y: np.arctan2(y, x) - math.radians(defocus_a)  # np.arctan2 works on all NumPy versions
        df = lambda x, y: 0.5 * (defocus_u + defocus_v + delta_df * np.cos(2 * tilt_angle(x, y)))
        freq = lambda x, y: np.sqrt(x ** 2 + y ** 2)  # radial spatial frequency |f|
        # Cs and defocus terms have opposite signs, matching the slope of the phase gamma
        return lambda x, y: np.exp(-(np.pi * electron_source_angle / electron_wavelength) ** 2
                                   * (spherical_aberration_angstrom * electron_wavelength ** 3 * freq(x, y) ** 3
                                      - df(x, y) * electron_wavelength * freq(x, y)) ** 2)
    else:
        return lambda x: np.exp(-(np.pi * electron_source_angle / electron_wavelength) ** 2
                                * (spherical_aberration_angstrom * electron_wavelength ** 3 * x ** 3
                                   - defocus * electron_wavelength * x) ** 2)
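The spatial envelope can be expressed through the slope of the phase: $E_S(f) = \exp\left(-\left(\frac{e_a}{2\lambda}\right)^2 \left(\frac{d\gamma}{df}\right)^2\right)$, which, with $\gamma$ from Section 2, expands to the $C_s \lambda^3 f^3 - d_f \lambda f$ bracket. The sketch below is a standalone check with assumed values ($\lambda \approx$ 0.0197 Å, $C_s$ = 2.7 mm, $e_a$ = 1e-4 rad). It compares the closed form with a finite-difference slope of $\gamma$ and shows that, at the same $|d_f|$, underfocus damps less than overfocus. The function names are hypothetical.

```python
import numpy as np

wavelength = 0.019687   # Å, roughly 300 kV (assumed)
cs = 2.7e7              # Cs = 2.7 mm, in Å
source_angle = 1e-4     # rad


def gamma(f, defocus):
    """Wave aberration phase without the constant phase shift p."""
    return np.pi * defocus * wavelength * f ** 2 - 0.5 * np.pi * cs * wavelength ** 3 * f ** 4


def spatial_envelope(f, defocus):
    """Closed form; the bracket equals -(1 / (2*pi)) * d(gamma)/df."""
    return np.exp(-(np.pi * source_angle / wavelength) ** 2
                  * (cs * wavelength ** 3 * f ** 3 - defocus * wavelength * f) ** 2)


def spatial_envelope_from_slope(f, defocus, h=1e-6):
    """Same envelope via exp(-(e_a / (2*lambda))^2 * (d gamma/df)^2), slope by central difference."""
    slope = (gamma(f + h, defocus) - gamma(f - h, defocus)) / (2 * h)
    return np.exp(-(source_angle / (2 * wavelength)) ** 2 * slope ** 2)


f = 0.3  # 1/Å
for defocus in (10000.0, -10000.0):  # +1 µm underfocus, -1 µm overfocus (in Å)
    print(defocus, spatial_envelope(f, defocus), spatial_envelope_from_slope(f, defocus))
```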


### 3.3 Detector Envelope
The detector envelope $E_D$ is estimated from the detective quantum efficiency ($DQE$) curve, which is expressed as a function of spatial frequency in fractions of the Nyquist frequency $N$. One way to approximate the DQE curve is to fit a polynomial through measured DQE points of the detector. Using a cubic polynomial as an example:

$${DQE} \left(\frac{f}{N} \right) = a_3 \left(\frac{f}{N} \right)^3 + a_2 \left(\frac{f}{N} \right)^2 + a_1 \left(\frac{f}{N} \right) + a_0$$

where:
- $f$ is the actual spatial frequency, and
- $N$ is the Nyquist frequency, given by:

$$N = \frac{1}{2a_{pix}b}$$ 

Here $a_{pix}$ is the pixel size, and $b$ is the detector binning factor. 

Given the DQE curve, the detector envelope is the DQE at a given frequency divided by its maximum over the frequency range $f$ under consideration:

$$E_D(f) = \frac{DQE \left(\frac{f}{N} \right)}{\max(DQE)}$$

Note that the detector DQE depends on many factors, including the accelerating voltage and electron dose. For simplicity, only a limited set of options is implemented in this notebook.    

In [6]:
def nyquist_frequency(pixel_size: float, binning_factor: float) -> float:
    """Calculate the Nyquist frequency based on pixel size and detector binning factor

    Args:
        pixel_size (float): pixel size, in angstrom
        binning_factor (float): detector binning factor, unitless

    Returns:
        float: Nyquist frequency, in angstrom^-1
    """
    return 1 / (2.0 * pixel_size * binning_factor)

def DQE_function(DQE_X: list[float], DQE_Y: list[float]) -> Callable[[ArrayLike], NDArray]:
    """Estimate the detective quantum efficiency from measured data

    Args:
        DQE_X (list[float]): X coordinates from a measured DQE curve 
        DQE_Y (list[float]): matching Y coordinates from a measured DQE curve

    Returns:
        Callable[[ArrayLike], NDArray]: A function approximating the detective quantum efficiency curve of a detector
    
    Note:
        The returned function may produce unrealistic values from the polynomial fit when evaluated at a frequency beyond Nyquist
    """
    if len(DQE_X) <= 3:  # quadratic for three or fewer measured points
        n = 2
    else:
        n = 3  # cubic is usually enough for a DQE curve
    DQE_polynomial = np.polynomial.Polynomial(np.polyfit(DQE_X, DQE_Y, n)[::-1])
    return lambda x: np.maximum(DQE_polynomial(x), 0)

def DQE_envelope_function(nyquist: float, DQE_function: Callable[[ArrayLike], NDArray]) -> Callable[[ArrayLike], NDArray]:
    """Set up the detector envelope function from parameters

    Args:
        nyquist (float): Nyquist frequency, in angstrom^-1
        DQE_function (Callable[[ArrayLike], NDArray]): DQE function of a detector

    Returns:
        Callable[[ArrayLike], NDArray]: detector envelope function that takes an array of frequencies and returns the values
    """
    return lambda x: DQE_function(x / nyquist) / np.max(DQE_function(x / nyquist))

# Precalculated parameters
# Approximate values read from published DQE curves (x: fraction of Nyquist, y: DQE)
# From Gatan
K3_DQE_X = [0, 0.5, 1]
K3_DQE_Y = [0.95, 0.71, 0.40]

# SO-163
FILM_DQE_X = [0, 0.25, 0.5, 0.75, 1]
FILM_DQE_Y = [0.37, 0.32, 0.33, 0.22, 0.07]

# TVIPS 224
CCD_DQE_X = [0, 0.25, 0.5, 0.75, 1]
CCD_DQE_Y = [0.37, 0.16, 0.13, 0.1, 0.05]

# precalculated DQE functions for three types of detectors
ddd_DQE_function = DQE_function(K3_DQE_X, K3_DQE_Y)
film_DQE_function = DQE_function(FILM_DQE_X, FILM_DQE_Y)
ccd_DQE_function = DQE_function(CCD_DQE_X, CCD_DQE_Y)
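The detector-envelope steps can be traced end to end for the K3 points above. This sketch uses `np.polynomial.Polynomial.fit`, a swap for the `np.polyfit` plus coefficient reversal used in `DQE_function`; both give the same quadratic through three points. The pixel size of 1 Å and binning factor of 1 are assumed example values.

```python
import numpy as np

# K3 points from the cell above (fraction of Nyquist, DQE)
dqe_x = [0, 0.5, 1]
dqe_y = [0.95, 0.71, 0.40]

# Three points -> quadratic, matching the degree choice in DQE_function above
poly = np.polynomial.Polynomial.fit(dqe_x, dqe_y, deg=2)

pixel_size, binning = 1.0, 1.0                 # Å/pixel, unitless (assumed)
nyquist = 1.0 / (2.0 * pixel_size * binning)   # N = 1 / (2 * a_pix * b), in 1/Å

f = np.linspace(0.0, nyquist, 501)             # frequencies up to Nyquist, 1/Å
dqe = np.maximum(poly(f / nyquist), 0.0)       # clip negative values, as DQE_function does
envelope = dqe / dqe.max()                     # E_D(f) = DQE(f/N) / max(DQE)
print(nyquist, envelope[0], envelope[-1])
```

Because the maximum is taken over the sampled frequencies, $E_D$ is normalized relative to whatever frequency range is evaluated, as in `DQE_envelope_function` above. Here the fitted DQE decreases monotonically, so $E_D$ starts at 1 and ends at 0.40/0.95 ≈ 0.42 at Nyquist.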

### 3.4 Total Envelope Function and Dampened CTF
The total envelope function, $E_{total}$, is given as the product of the temporal, spatial and detector envelopes:
$$E_{total}(f) = E_T(f) \times E_S(f) \times E_D(f)$$

Thus, the dampened CTF is:

$$dCTF(f) = CTF(f) \times E_{total}(f)$$

In [7]:
def total_envelope_function(
        temporal_envelope_function: Callable[[ArrayLike], NDArray], 
        spatial_envelope_function: Callable[[ArrayLike], NDArray], 
        DQE_envelope_function: Callable[[ArrayLike], NDArray]) -> Callable[[ArrayLike], NDArray]:
    """Set up the total envelope function from parameters

    Args:
        temporal_envelope_function (Callable[[ArrayLike], NDArray]): temporal envelope function
        spatial_envelope_function (Callable[[ArrayLike], NDArray]): spatial envelope function
        DQE_envelope_function (Callable[[ArrayLike], NDArray]): DQE envelope function

    Returns:
        Callable[[ArrayLike], NDArray]: total envelope function that takes an array of frequencies and returns the values
    """
    return lambda x: temporal_envelope_function(x) * spatial_envelope_function(x) * DQE_envelope_function(x)

def total_envelope_function_2D(
        temporal_envelope_function: Callable[[ArrayLike], NDArray], 
        spatial_envelope_function: Callable[[ArrayLike], NDArray], 
        DQE_envelope_function: Callable[[ArrayLike], NDArray]) -> Callable[[ArrayLike], NDArray]:
    """Set up the total 2D envelope function from parameters

    Args:
        temporal_envelope_function (Callable[[ArrayLike], NDArray]): temporal envelope function
        spatial_envelope_function (Callable[[ArrayLike], NDArray]): 2D spatial envelope function
        DQE_envelope_function (Callable[[ArrayLike], NDArray]): DQE envelope function

    Returns:
        Callable[[ArrayLike], NDArray]: total 2D envelope function that takes x and y frequency arrays and returns the values
    """
    return lambda x, y: temporal_envelope_function(np.sqrt(x ** 2 + y ** 2)) * spatial_envelope_function(x, y) * DQE_envelope_function(np.sqrt(x ** 2 + y ** 2))


def dampened_CTF(CTF: Callable[[ArrayLike], NDArray], envelope_function: Callable[[ArrayLike], NDArray]) -> Callable[[ArrayLike], NDArray]:
    """Set up the dampened 1D CTF from the undampened CTF and the envelope function

    Args:
        CTF (Callable[[ArrayLike], NDArray]): the contrast transfer function
        envelope_function (Callable[[ArrayLike], NDArray]): the envelope function

    Returns:
        Callable[[ArrayLike], NDArray]: dampened CTF that takes an array of frequencies and returns the values
    """
    return lambda x: CTF(x) * envelope_function(x)  

def dampened_CTF_2(CTF: Callable[[ArrayLike], NDArray], envelope_function: Callable[[ArrayLike], NDArray]) -> Callable[[ArrayLike], NDArray]:
    """Set up the squared 1D CTF multiplied by the envelope function

    Args:
        CTF (Callable[[ArrayLike], NDArray]): the contrast transfer function
        envelope_function (Callable[[ArrayLike], NDArray]): the envelope function

    Returns:
        Callable[[ArrayLike], NDArray]: CTF(f)^2 * E(f) (the envelope is not squared), taking an array of frequencies and returning the values
    """
    return lambda x: CTF(x) ** 2 * envelope_function(x)

def dampened_2D_CTF(CTF: Callable[[ArrayLike], NDArray], envelope_function: Callable[[ArrayLike], NDArray]) -> Callable[[ArrayLike], NDArray]:
    """Set up the dampened 2D CTF from the undampened CTF and the envelope function

    Args:
        CTF (Callable[[ArrayLike], NDArray]): the 2D contrast transfer function
        envelope_function (Callable[[ArrayLike], NDArray]): the 2D envelope function

    Returns:
        Callable[[ArrayLike], NDArray]: dampened 2D CTF that takes x and y frequency arrays and returns the values
    """
    return lambda x, y: CTF(x, y) * envelope_function(x, y)
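The composition pattern used in this cell (envelopes multiplied together, then multiplied into the CTF) can be illustrated without the GUI. The sketch below builds a 1D CTF and the temporal and spatial envelopes as plain functions with assumed parameters (300 kV, $C_s$ = 2.7 mm, 1 µm underfocus, $A_c$ = 0.1, $f_s$ = 79 Å, $e_a$ = 1e-4 rad). It omits the detector term for brevity and checks that the total envelope bounds the magnitude of the dampened CTF, since $|CTF(f)| \le 1$.

```python
import numpy as np

# Assumed example parameters (Å units unless noted)
WAVELENGTH = 0.019687   # Å, roughly 300 kV
CS = 2.7e7              # 2.7 mm, in Å
DEFOCUS = 10000.0       # 1 µm underfocus, in Å
AMP_CONTRAST = 0.1
FOCUS_SPREAD = 79.0     # Å
SOURCE_ANGLE = 1e-4     # rad


def ctf(f):
    gamma = np.pi * DEFOCUS * WAVELENGTH * f ** 2 - 0.5 * np.pi * CS * WAVELENGTH ** 3 * f ** 4
    return -np.sin(gamma + np.arcsin(AMP_CONTRAST))


def temporal_envelope(f):
    return np.exp(-0.5 * (np.pi * WAVELENGTH * FOCUS_SPREAD) ** 2 * f ** 4)


def spatial_envelope(f):
    return np.exp(-(np.pi * SOURCE_ANGLE / WAVELENGTH) ** 2
                  * (CS * WAVELENGTH ** 3 * f ** 3 - DEFOCUS * WAVELENGTH * f) ** 2)


def total_envelope(f):
    # Product of the individual envelopes (detector term omitted in this sketch)
    return temporal_envelope(f) * spatial_envelope(f)


def dampened_ctf(f):
    return ctf(f) * total_envelope(f)


f = np.linspace(0.001, 0.5, 5000)  # spatial frequencies, 1/Å
env = total_envelope(f)
dctf = dampened_ctf(f)
print(bool(np.all(np.abs(dctf) <= env)))  # checks that the envelope bounds |dCTF|
print(f"total envelope at 0.3 1/Å: {total_envelope(0.3):.3f}")
```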


## 4. GUI Implementation
The code below implements a GUI that demonstrates how the CTF is affected by various parameters.

In [8]:
class CTFSimGUI:
    """A GUI simulating the contrast transfer function of electron microscopy
    """
    def __init__(self, line_points=10000, image_size=400):
        """constructor

        Args:
            line_points (int, optional): number of sampling points in 1D plot. Defaults to 10000.
            image_size (int, optional): size of 2D plot in pixels. Defaults to 400.
        """
        self._freqs_1d = np.linspace(0.001, 1, line_points)
        freq_x = np.linspace(-0.5, 0.5, image_size)
        freq_y = np.linspace(-0.5, 0.5, image_size)
        self._fx, self._fy = np.meshgrid(freq_x, freq_y, sparse=True)

        self._setup_microscope_widgets()
        self._setup_detector_widgets()
        self._setup_imaging_widgets()
        self._setup_plotting_widgets()

        self._setup_functions()
        self._setup_1D_plot()
        self._setup_2D_plot()
        self._setup_plot_tab()
        self._setup_reset_button()

        self.run_application()

    def generate_layout(self):
        """combine all the widgets defined in setup functions

        Returns:
            widgets.Widget: the GUI
        """
        return widgets.HBox([
            widgets.VBox([
                self._microscope_widgets,
                self._detector_widgets,
                self._imaging_widgets,
                self._plotting_widgets
            ], layout=widgets.Layout(width='50%')),
            widgets.VBox([
                self.plot_tab,
                self.reset_button
            ], layout=widgets.Layout(align_items='center'))
        ])
    
    def run_application(self):
        self._setup_event_handlers()
        display(self.generate_layout())

    # setup functions
    def _setup_microscope_widgets(self):
        self._voltage_slider = widgets.SelectionSlider(
            options=[80, 100, 120, 200, 300, 500, 1000],
            value=300,  # must be one of the options, in kV
            disabled=False,
            continuous_update=False,
            orientation='horizontal',
            readout=True,
        )
        self._voltage_stability_slider = widgets.FloatLogSlider(
            value=3.3333e-8,
            base=10,
            min=-9,
            max=-4,
            step=0.5,
            disabled=False,
            continuous_update=False,
            orientation='horizontal',
            readout=True,
            readout_format='.2e',
        )
        self._electron_source_angle_slider = widgets.FloatLogSlider(
            value=1.0e-4,
            base=10,
            min=-5,
            max=-2,
            step=0.2,
            disabled=False,
            continuous_update=False,
            orientation='horizontal',
            readout=True,
            readout_format='.1e',
        )
        self._electron_source_spread_slider = widgets.FloatSlider(
            value=0.7,
            min=0,
            max=10,
            step=0.1,
            disabled=False,
            continuous_update=False,
            orientation='horizontal',
            readout=True,
            readout_format='.1f',
        )
        self._chromatic_aberration_slider = widgets.FloatSlider(
            value=3.4,
            min=0,
            max=10,
            step=0.1,
            disabled=False,
            continuous_update=False,
            orientation='horizontal',
            readout=True,
            readout_format='.1f',
        )
        self._spherical_aberration_slider = widgets.FloatSlider(
            value=2.7,
            min=0,
            max=10,
            step=0.1,
            disabled=False,
            continuous_update=False,
            orientation='horizontal',
            readout=True,
            readout_format='.1f',
        )
        self._objective_lens_stability_slider = widgets.FloatLogSlider(
            value=1.6666e-8,
            base=10,
            min=-9,
            max=-4,
            step=0.5,
            disabled=False,
            continuous_update=False,
            orientation='horizontal',
            readout=True,
            readout_format='.2e',
        )
        microscope_label_panel = widgets.VBox([
            widgets.Label('Voltage (kV): '),
            widgets.Label('Voltage stability: '),
            widgets.Label('Electron source angle (rad): '),
            widgets.Label('Electron source spread (eV): '),
            widgets.Label('Chromatic aberration (mm): '),
            widgets.Label('Spherical aberration (mm): '),
            widgets.Label('Objective lens stability: '),
        ], layout=widgets.Layout(width='40%'))
        microscope_widget_panel = widgets.VBox([
            self._voltage_slider,
            self._voltage_stability_slider,
            self._electron_source_angle_slider,
            self._electron_source_spread_slider,
            self._chromatic_aberration_slider,
            self._spherical_aberration_slider,
            self._objective_lens_stability_slider
        ])
        self._microscope_widgets = widgets.VBox([
            widgets.HTML('<b> Microscopy Parameters </b>'), 
            widgets.HBox([
                microscope_label_panel, 
                microscope_widget_panel], layout=widgets.Layout(border='solid'))])
        
    def _setup_detector_widgets(self):
        self._detector_dropdown = widgets.Dropdown(
            options=['DDD super resolution counting', 'DDD counting', 'Film', 'CCD'],
            value='DDD counting'
        )
        detector_label_panel = widgets.VBox([widgets.Label(value='Detector: ')], layout=widgets.Layout(width='40%'))
        detector_widget_panel = widgets.VBox([self._detector_dropdown])    
        self._detector_widgets = widgets.VBox([
            widgets.HTML('<b> Detector Parameters </b>'), 
            widgets.HBox([
                detector_label_panel, 
                detector_widget_panel], layout=widgets.Layout(border='solid'))])

    def _setup_imaging_widgets(self):
        self._pixel_size_slider = widgets.FloatSlider(
            value=1.0,
            min=0.2,
            max=5.,
            step=0.1,
            disabled=False,
            continuous_update=False,
            orientation='horizontal',
            readout=True,
            readout_format='.1f',
        )
        self._amplitude_contrast_slider = widgets.FloatSlider(
            value=0.1,
            min=0,
            max=1,
            step=0.01,
            disabled=False,
            continuous_update=False,
            orientation='horizontal',
            readout=True,
            readout_format='.2f',
        )
        self._phase_shift_slider = widgets.FloatSlider(
            value=0.,
            min=0.,
            max=180.,
            step=1,
            disabled=False,
            continuous_update=False,
            orientation='horizontal',
            readout=True,
            readout_format='.0f',
        )
        self._defocus_slider = widgets.FloatSlider(
            value=1.,
            min=-5.,
            max=5.,
            step=0.05,
            disabled=False,
            continuous_update=False,
            orientation='horizontal',
            readout=True,
            readout_format='.2f',
        )
        self._defocus_difference_slider = widgets.FloatSlider(
            value=0.,
            min=-2.5,
            max=2.5,
            step=0.05,
            disabled=False,
            continuous_update=False,
            orientation='horizontal',
            readout=True,
            readout_format='.2f',
        )
        self._defocus_azimuthal_slider = widgets.FloatSlider(
            value=0.,
            min=0.,
            max=180.,
            step=0.1,
            disabled=False,
            continuous_update=False,
            orientation='horizontal',
            readout=True,
            readout_format='.1f',
        )
        imaging_label_panel = widgets.VBox([
            widgets.Label('Pixel size (Å/pixel): '), 
            widgets.Label('Amplitude contrast: '),
            widgets.Label('Additional phase shift (degree): '),
            widgets.Label('Defocus (µm, + for underfocus): '),
            widgets.Label('Defocus difference (µm, 2D): '),
            widgets.Label('Defocus azimuthal (degree, 2D): ')
        ], layout=widgets.Layout(width='40%'))
        imaging_widget_panel = widgets.VBox([
            self._pixel_size_slider, 
            self._amplitude_contrast_slider,
            self._phase_shift_slider,
            self._defocus_slider,
            self._defocus_difference_slider,
            self._defocus_azimuthal_slider
        ])
        self._imaging_widgets = widgets.VBox([
            widgets.HTML('<b> Imaging Parameters </b>'), 
            widgets.HBox([
                imaging_label_panel, 
                imaging_widget_panel], layout=widgets.Layout(border='solid'))]
        )

    def _setup_plotting_widgets(self):
        self._xlim_slider = widgets.FloatSlider(
            value=0.5,
            min=0.1,
            max=1.1,
            step=0.1,
            disabled=False,
            continuous_update=False,
            orientation='horizontal',
            readout=True,
            readout_format='.1f',
        )
        self._temporal_env_checkbox = widgets.Checkbox(
            value=True,
            description='Temporal',
            disabled=False,
            indent=False,
        )
        self._spatial_env_checkbox = widgets.Checkbox(
            value=True,
            description='Spatial',
            disabled=False,
            indent=False,
        )
        self._detector_env_checkbox = widgets.Checkbox(
            value=True,
            description='Detector',
            disabled=False,
            indent=False,
        )
        plotting_label_panel = widgets.VBox([
            widgets.Label('X limit (Å^-1, 1D): '),
            widgets.Label('Envelope function: ')
        ], layout=widgets.Layout(width='40%'))
        plotting_widget_panel = widgets.VBox([
            self._xlim_slider,
            widgets.HBox([
                self._temporal_env_checkbox,
                self._spatial_env_checkbox,
                self._detector_env_checkbox 
            ])
        ], layout=widgets.Layout(width='60%'))          
        self._plotting_widgets = widgets.VBox([
            widgets.HTML('<b> Plotting Parameters </b>'), 
            widgets.HBox([
                plotting_label_panel, 
                plotting_widget_panel], layout=widgets.Layout(border='solid'))])  
    
    def _setup_functions(self):
        lambda_e = electron_wavelength(self._voltage_slider.value)
        fs = focus_spread(
            self._chromatic_aberration_slider.value,
            self._voltage_stability_slider.value,
            self._objective_lens_stability_slider.value,
            self._electron_source_spread_slider.value,
            self._voltage_slider.value)
        dqe_function, binning_factor = self._select_detector()
        nyquist = nyquist_frequency(self._pixel_size_slider.value, binning_factor)
        defocus_u = (self._defocus_slider.value + 0.5 * self._defocus_difference_slider.value) * 10000
        defocus_v = (self._defocus_slider.value - 0.5 * self._defocus_difference_slider.value) * 10000
        defocus_a = self._defocus_azimuthal_slider.value

        if self._temporal_env_checkbox.value:
            self.Et = temporal_envelope_function(lambda_e, fs)
        else:
            self.Et = lambda x: np.ones_like(x)
        if self._spatial_env_checkbox.value:
            self.Es = spatial_envelope_function(
                lambda_e, 
                self._electron_source_angle_slider.value, 
                self._spherical_aberration_slider.value,
                self._defocus_slider.value * 10000)
            self.Es_2d = spatial_envelope_function(
                lambda_e, 
                self._electron_source_angle_slider.value, 
                self._spherical_aberration_slider.value,
                (defocus_u, defocus_v, defocus_a))
        else: 
            self.Es = lambda x: np.ones_like(x)
            self.Es_2d = lambda x, y: np.ones_like(x+y)
        if self._detector_env_checkbox.value:
            self.Ed = DQE_envelope_function(nyquist=nyquist, DQE_function=dqe_function)
        else: 
            self.Ed = lambda x: np.ones_like(x)
        self.Etotal = total_envelope_function(self.Et, self.Es, self.Ed)
        self.Etotal_2d = total_envelope_function_2D(self.Et, self.Es_2d, self.Ed)
        self.ctf_1d = CTF(lambda_e,
                          self._spherical_aberration_slider.value,
                          self._amplitude_contrast_slider.value,
                          self._defocus_slider.value * 10000,
                          self._phase_shift_slider.value)
        self.ctf_2d = CTF(lambda_e,
                          self._spherical_aberration_slider.value,
                          self._amplitude_contrast_slider.value,
                          (defocus_u, defocus_v, defocus_a),
                          self._phase_shift_slider.value)
        self.dampened_ctf_1d = dampened_CTF(self.ctf_1d, self.Etotal)
        self.dampened_ctf_2d = dampened_2D_CTF(self.ctf_2d, self.Etotal_2d)

    def _setup_1D_plot(self):
        with plt.ioff():
            self._fig0, self._ax0 = plt.subplots() 
        self._fig0.canvas.header_visible = False
        self._fig0.canvas.toolbar_position = 'top'

        self._ax0.set_title("1-D Contrast Transfer Function") 
        self._ax0.set_xlim(0, 0.5)
        self._ax0.set_ylim(-1, 1)
        self._ax0.axhline(y=0, color='grey', linestyle='--', alpha=1, linewidth=0.5)
        self._ax0.set_xlabel("Spatial Frequency (Å^-1)")

        self.line_et = self._ax0.plot(self._freqs_1d, self.Et(self._freqs_1d), label="Temporal Envelope", linestyle="dashed")
        self.line_es = self._ax0.plot(self._freqs_1d, self.Es(self._freqs_1d), label="Spatial Envelope", linestyle="dashed")
        self.line_ed = self._ax0.plot(self._freqs_1d, self.Ed(self._freqs_1d), label="Detector Envelope", linestyle="dashed")
        self.line_te = self._ax0.plot(self._freqs_1d, self.Etotal(self._freqs_1d), label="Total Envelope")
        self.line_dc = self._ax0.plot(self._freqs_1d, self.dampened_ctf_1d(self._freqs_1d), label="Microscope CTF")
        self._ax0.legend()

    def _setup_2D_plot(self):
        with plt.ioff():
            self._fig1, self._ax1 = plt.subplots() 
        self._fig1.canvas.header_visible = False
        self._fig1.canvas.toolbar_position = 'top'

        self._ax1.set_title("2-D Contrast Transfer Function") 

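        # imshow autoscales the colour range only once; fixing it to [-1, 1] keeps later set_data() updates on the same scale.
        self.image = self._ax1.imshow(self.dampened_ctf_2d(self._fx, self._fy), cmap='Greys', vmin=-1, vmax=1)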

    def _setup_plot_tab(self):
        self.plot_tab = widgets.Tab([
            self._fig0.canvas,
            self._fig1.canvas
        ], layout=widgets.Layout(width='95%'))
        self.plot_tab.set_title(0, '1D-CTF')
        self.plot_tab.set_title(1, '2D-CTF')

    def _setup_reset_button(self):
        self.reset_button = widgets.Button(description='Reset')

    def _setup_event_handlers(self):
        self._voltage_slider.observe(self.update_ctf, names='value')
        self._voltage_stability_slider.observe(self.update_ctf, names='value')
        self._electron_source_angle_slider.observe(self.update_ctf, names='value')
        self._electron_source_spread_slider.observe(self.update_ctf, names='value')
        self._chromatic_aberration_slider.observe(self.update_ctf, names='value')
        self._spherical_aberration_slider.observe(self.update_ctf, names='value')
        self._objective_lens_stability_slider.observe(self.update_ctf, names='value')
        self._detector_dropdown.observe(self.update_ctf, names='value')
        self._pixel_size_slider.observe(self.update_ctf, names='value')
        self._amplitude_contrast_slider.observe(self.update_ctf, names='value')
        self._phase_shift_slider.observe(self.update_ctf, names='value')
        self._defocus_slider.observe(self.update_ctf, names='value')
        self._temporal_env_checkbox.observe(self.update_ctf, names='value')
        self._spatial_env_checkbox.observe(self.update_ctf, names='value')
        self._detector_env_checkbox.observe(self.update_ctf, names='value')
        self._xlim_slider.observe(self.update_ctf, names='value')
        self._defocus_difference_slider.observe(self.update_ctf, names='value')
        self._defocus_azimuthal_slider.observe(self.update_ctf, names='value')
        self.plot_tab.observe(self.update_ctf, names='selected_index')
        self.reset_button.on_click(self.reset_parameters)

    # Callback function
    def update_ctf(self, change):  
        self._setup_functions()
        if self.plot_tab.selected_index == 0:
            self.line_et[0].set_data(self._freqs_1d, self.Et(self._freqs_1d))
            self.line_es[0].set_data(self._freqs_1d, self.Es(self._freqs_1d))
            self.line_ed[0].set_data(self._freqs_1d, self.Ed(self._freqs_1d))
            self.line_te[0].set_data(self._freqs_1d, self.Etotal(self._freqs_1d))
            self.line_dc[0].set_data(self._freqs_1d, self.dampened_ctf_1d(self._freqs_1d))
            self._ax0.set_xlim(0, self._xlim_slider.value)
            self._fig0.canvas.draw_idle()
        else:
            self.image.set_data(self.dampened_ctf_2d(self._fx, self._fy))
            self._fig1.canvas.draw_idle()

    def reset_parameters(self, change):
        self._voltage_slider.value = 300  # voltage in kV
        self._voltage_stability_slider.value = 3.3333e-8  # 3.3333e-8 s^-1 = 2e-6 min^-1
        self._electron_source_angle_slider.value = 1.0e-4  # rad
        self._electron_source_spread_slider.value = 0.7  # eV
        self._chromatic_aberration_slider.value = 3.4  # chromatic aberration (Cc) in mm
        self._spherical_aberration_slider.value = 2.7  # spherical aberration (Cs) in mm
        self._objective_lens_stability_slider.value = 1.6666e-8  # s^-1
        self._detector_dropdown.value = 'DDD counting'
        self._pixel_size_slider.value = 1.  # pixel size in angstroms.
        self._amplitude_contrast_slider.value = 0.1  # typical value for cryoEM
        self._phase_shift_slider.value = 0.  # additional phase shift in degree, e.g., from phase plate
        self._defocus_slider.value = 1.  # average defocus in um 
        self._defocus_difference_slider.value = 0.  # defocus difference in um
        self._defocus_azimuthal_slider.value = 0.  # azimuthal angle in degree
        self._temporal_env_checkbox.value = True
        self._spatial_env_checkbox.value = True
        self._detector_env_checkbox.value = True
        self._xlim_slider.value = 0.5  # Å^-1
        self.update_ctf(change)  

    # helper functions
    def _select_detector(self) -> tuple[Callable[[ArrayLike], NDArray], float]:
        """A selector for detector parameters

        Returns:
            tuple[Callable[[ArrayLike], NDArray], float]: A tuple containing the DQE function and the detector binning factor  
        """
        if self._detector_dropdown.value == 'DDD super resolution counting':
            return (ddd_DQE_function, 0.5)
        elif self._detector_dropdown.value == 'DDD counting':
            return (ddd_DQE_function, 1.0)
        elif self._detector_dropdown.value == 'Film':
            return (film_DQE_function, 1.0)
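        elif self._detector_dropdown.value == 'CCD':
            return (ccd_DQE_function, 1.0)
        else:
            # Fail loudly instead of implicitly returning None.
            raise ValueError(f'Unknown detector: {self._detector_dropdown.value!r}')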
    


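Before the GUI is instantiated, it may help to see the bookkeeping that `_setup_functions` performs in isolation. The sketch below uses hypothetical standalone helpers that are not part of the notebook: `defocus_pair` reproduces the slider-to-Å conversion (µm × 10000) used for `defocus_u` and `defocus_v`, `unit_envelope` mirrors the `np.ones_like` stand-in used when an envelope checkbox is unticked, and `combine_envelopes` assumes that the total envelope is the pointwise product of its factors. That is the conventional definition, but the notebook's own `total_envelope_function` is defined earlier and is not shown here.

```python
import numpy as np

UM_TO_ANGSTROM = 1.0e4  # 1 µm = 10^4 Å (the "* 10000" factor in _setup_functions)


def defocus_pair(defocus_um, defocus_difference_um):
    """Return (d_u, d_v) in Å from the mean defocus and defocus difference in µm."""
    d_u = (defocus_um + 0.5 * defocus_difference_um) * UM_TO_ANGSTROM
    d_v = (defocus_um - 0.5 * defocus_difference_um) * UM_TO_ANGSTROM
    return d_u, d_v


def unit_envelope(q):
    """Envelope used when a checkbox is unticked: 1 at every frequency."""
    return np.ones_like(np.asarray(q, dtype=float))


def combine_envelopes(*envelopes):
    """Return a callable giving the pointwise product of the envelopes (assumed definition)."""
    def total(q):
        q = np.asarray(q, dtype=float)
        result = np.ones_like(q)
        for envelope in envelopes:
            result = result * envelope(q)
        return result
    return total


def toy_envelope(q):
    """Illustrative Gaussian fall-off, not a physical envelope model."""
    return np.exp(-(np.asarray(q, dtype=float) / 0.3) ** 2)


# 1.0 µm mean defocus with a 0.2 µm defocus difference gives about 11000 Å and 9000 Å.
d_u, d_v = defocus_pair(1.0, 0.2)

# Combining the toy envelope with a disabled (unit) envelope leaves it unchanged.
q = np.linspace(0.0, 0.5, 6)  # spatial frequency in Å^-1
total = combine_envelopes(toy_envelope, unit_envelope)
```

Returning a closure keeps the combined envelope a plain callable of spatial frequency, the same shape as the envelope objects the GUI evaluates on `self._freqs_1d`.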
In [9]:
gui = CTFSimGUI()

Note:
- DDD stands for direct electron detector.
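- "Defocus difference" and "Defocus azimuthal" do not affect the 1-D plot, which uses only the mean defocus.
- In the 2-D case, $d_u$ and $d_v$ are the defoci along the two principal axes of astigmatism: Defocus $= (d_u + d_v) / 2$ and Defocus difference $= d_u - d_v$. "Defocus azimuthal" sets the azimuthal angle of the astigmatism (in degrees).
- "X limit" does not affect the 2-D plot.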
- Please send comments and report bugs to mlzhao@uchicago.edu. 
- Last update: Jan 6, 2025
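The Defocus, Defocus difference, and Defocus azimuthal sliders together describe astigmatism, which makes the defocus depend on direction in the 2-D plot. A minimal sketch, assuming the widely used convention $d(\theta) = \frac{1}{2}\left[d_u + d_v + (d_u - d_v)\cos 2(\theta - \theta_a)\right]$, with $\theta$ measured from the $x$ axis and $d_u$ lying along the azimuth $\theta_a$. The notebook's own `CTF` function is defined earlier, so whether it follows exactly this convention is an assumption; `astigmatic_defocus` is a hypothetical helper, and the small `fx`, `fy` grid is illustrative only.

```python
import numpy as np


def astigmatic_defocus(theta_deg, defocus_um, defocus_difference_um, azimuth_deg):
    """Defocus (µm) along direction theta_deg, in degrees from the x axis.

    Hypothetical helper. Assumed convention:
    d(theta) = [d_u + d_v + (d_u - d_v) * cos(2 * (theta - theta_a))] / 2,
    where d_u lies along the azimuth theta_a and d_v is perpendicular to it.
    """
    # Same decomposition as the GUI: mean defocus +/- half the difference.
    d_u = defocus_um + 0.5 * defocus_difference_um
    d_v = defocus_um - 0.5 * defocus_difference_um
    theta = np.deg2rad(np.asarray(theta_deg, dtype=float))
    theta_a = np.deg2rad(azimuth_deg)
    return 0.5 * (d_u + d_v + (d_u - d_v) * np.cos(2.0 * (theta - theta_a)))


# Direction of every point on a small 2-D frequency grid (illustrative fx, fy).
fx, fy = np.meshgrid(np.linspace(-0.5, 0.5, 5), np.linspace(-0.5, 0.5, 5))
theta_map = np.rad2deg(np.arctan2(fy, fx))
defocus_map = astigmatic_defocus(theta_map, 1.0, 0.2, 30.0)
```

With a defocus difference of 0 the map is constant and equal to the mean defocus, which is why only the mean defocus matters for the rotationally averaged 1-D plot.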