In [1]:
%matplotlib widget
import matplotlib.pyplot as plt
import numpy as np
import ipywidgets

# White text, tick marks/labels and axes spines (presumably so the figures read
# on a dark page; the figure background is made transparent further down).
for _rc_key in (
    'text.color',
    'xtick.labelcolor',
    'xtick.color',
    'ytick.labelcolor',
    'ytick.color',
    'axes.labelcolor',
    'axes.edgecolor',
):
    plt.rcParams[_rc_key] = 'white'

In [2]:
import numpy as np
from numpy.typing import NDArray

def electron_wavelength(energy: float) -> float:
    """Relativistic electron wavelength.

    Parameters
    ----------
    energy : float
        Electron kinetic energy in eV (numerically the accelerating voltage in
        volts, e.g. ``300e3`` for a 300 kV beam).

    Returns
    -------
    float
        Wavelength in Angstroms.
    """
    electron_mass = 9.109383e-31  # kg
    elementary_charge = 1.602177e-19  # C
    speed_of_light = 299792458  # m/s
    planck = 6.62607e-34  # Planck constant, kg m^2 / s

    # Non-relativistic momentum sqrt(2 m eV), corrected by sqrt(1 + eV / (2 m c^2)).
    momentum = np.sqrt(2 * electron_mass * elementary_charge * energy)
    correction = np.sqrt(
        1 + elementary_charge * energy / 2 / electron_mass / speed_of_light**2
    )
    wavelength_m = planck / momentum / correction
    return wavelength_m * 1e10  # metres -> Angstroms


def spatial_frequencies(
    gpts: tuple[int, int],
    sampling: tuple[float, float],
    rotation_angle: float | None = None,
) -> tuple[NDArray, NDArray]:
    """ """
    ny, nx = gpts
    sy, sx = sampling

    kx = np.fft.fftfreq(ny, sy)
    ky = np.fft.fftfreq(nx, sx)
    kxa, kya = np.meshgrid(kx, ky, indexing="ij")

    if rotation_angle is not None:
        cos_theta = np.cos(rotation_angle)
        sin_theta = np.sin(rotation_angle)

        kxa_new = cos_theta * kxa - sin_theta * kya
        kya_new = sin_theta * kxa + cos_theta * kya

        kxa, kya = kxa_new, kya_new

    return kxa, kya


def polar_coordinates(kx: NDArray, ky: NDArray) -> tuple[NDArray, NDArray]:
    """Convert Cartesian grids to polar form.

    Returns ``(radius, azimuth)`` with ``radius = sqrt(kx**2 + ky**2)`` and
    ``azimuth = arctan2(ky, kx)`` in radians, in the range ``[-pi, pi]``.
    """
    radius = np.sqrt(kx**2 + ky**2)
    azimuth = np.arctan2(ky, kx)
    return radius, azimuth


def quadratic_aberration_surface(
    alpha: NDArray, phi: NDArray, wavelength: float, aberration_coefs: dict[str, float]
) -> NDArray:
    """Phase error chi of the quadratic aberrations (defocus and 2-fold astigmatism).

    ``chi = pi / wavelength * alpha**2 * (C10 + C12 * cos(2 * (phi - phi12)))``

    Missing keys in ``aberration_coefs`` (``"C10"``, ``"C12"``, ``"phi12"``)
    default to 0. ``alpha`` and ``phi`` are the polar angle and azimuth grids.
    """
    defocus = aberration_coefs.get("C10", 0.0)
    astigmatism = aberration_coefs.get("C12", 0.0)
    astigmatism_angle = aberration_coefs.get("phi12", 0.0)

    angular_term = defocus + astigmatism * np.cos(2.0 * (phi - astigmatism_angle))
    return np.pi / wavelength * alpha**2 * angular_term


def quadratic_aberration_cartesian_gradients(
    alpha: NDArray, phi: NDArray, aberration_coefs: dict[str, float]
) -> tuple[NDArray, NDArray]:
    """Cartesian gradient of the quadratic aberration phase.

    For ``chi = pi / wavelength * alpha**2 * (C10 + C12 cos 2(phi - phi12))``
    with ``alpha = k * wavelength``, this returns the components of the
    gradient of chi with respect to spatial frequency k, along the directions
    ``phi = 0`` and ``phi = pi/2``. Missing coefficient keys default to 0.
    """
    c10 = aberration_coefs.get("C10", 0.0)
    c12 = aberration_coefs.get("C12", 0.0)
    phi12 = aberration_coefs.get("phi12", 0.0)

    twice_rel = 2.0 * (phi - phi12)
    two_pi_alpha = 2 * np.pi * alpha

    # Radial term d(chi)/dk, and azimuthal term (1/k) d(chi)/d(phi).
    radial = two_pi_alpha * (c10 + c12 * np.cos(twice_rel))
    azimuthal = -two_pi_alpha * (c12 * np.sin(twice_rel))

    # Rotate the (radial, azimuthal) components into Cartesian ones.
    cos_phi = np.cos(phi)
    sin_phi = np.sin(phi)
    grad_x = cos_phi * radial - sin_phi * azimuthal
    grad_y = sin_phi * radial + cos_phi * azimuthal
    return grad_x, grad_y


def suppress_nyquist_frequency(array: NDArray):
    """Remove the Nyquist row/column from a 2D array's spectrum.

    The Nyquist frequency only exists along even-length axes (index ``n // 2``
    in FFT order). Those rows/columns are zeroed and the real part of the
    inverse FFT is returned. Odd-length axes are left untouched.
    """
    spectrum = np.fft.fft2(array)
    n_rows, n_cols = spectrum.shape

    if not n_rows % 2:
        spectrum[n_rows // 2] = 0.0
    if not n_cols % 2:
        spectrum[:, n_cols // 2] = 0.0

    return np.fft.ifft2(spectrum).real


def prepare_grouped_phase_flipping_kernel(kernel, shifts_m_upsampled, upsampled_gpts):
    """Group kernel taps from all BF pixels by their destination on the accumulator.

    BF pixel ``m`` places the whole ``kernel`` onto an accumulator of shape
    ``upsampled_gpts`` at offset ``shifts_m_upsampled[m] = (row, col)``, with
    periodic wrap-around. Taps that land on the same accumulator position are
    summed, so each distinct position only needs to be visited once per frame.

    Parameters
    ----------
    kernel : (h, w) array
        Real-space kernel; element ``[dy, dx]`` lands at ``(dy, dx) + shift``.
    shifts_m_upsampled : (M, 2) integer array
        Per-BF-pixel (row, col) shifts on the upsampled grid.
    upsampled_gpts : (Ny, Nx)
        Accumulator shape used for wrapping.

    Returns
    -------
    unique_offsets : (U,) int64
        Sorted distinct row-major positions ``row * Nx + col``, with
        ``0 <= row < Ny`` and ``0 <= col < Nx``.
    grouped_kernel : (U, M) array with ``kernel.dtype``
        ``grouped_kernel[u, m]`` is the total weight BF pixel ``m`` deposits at
        ``unique_offsets[u]``.
    """
    n_rows, n_cols = upsampled_gpts
    h, w = kernel.shape
    n_pixels = shifts_m_upsampled.shape[0]

    # Destination coordinates, shape (M, h, w), wrapped onto the accumulator.
    row_shift = shifts_m_upsampled[:, 0][:, None, None]
    col_shift = shifts_m_upsampled[:, 1][:, None, None]
    dest_rows = (np.arange(h)[None, :, None] + row_shift) % n_rows
    dest_cols = (np.arange(w)[None, None, :] + col_shift) % n_cols
    flat_dest = (dest_rows * n_cols + dest_cols).reshape(-1)

    unique_offsets, group_of_tap = np.unique(flat_dest, return_inverse=True)

    # Per-tap weight and owning BF pixel, in the same (m, dy, dx) order as flat_dest.
    tap_weights = np.tile(kernel.ravel(), n_pixels)
    tap_pixel = np.repeat(np.arange(n_pixels), h * w)

    grouped_kernel = np.zeros((unique_offsets.size, n_pixels), dtype=kernel.dtype)
    np.add.at(grouped_kernel, (group_of_tap, tap_pixel), tap_weights)

    return unique_offsets.astype(np.int64), grouped_kernel

In [3]:
from dataclasses import dataclass

@dataclass(frozen=True)
class PreprocessedGeometry:
    """Geometry bundle returned by `preprocess_geometry`.

    ``frozen=True`` only prevents reassigning attributes; the array and dict
    fields themselves remain mutable objects.
    """

    bf_flat_inds: np.ndarray  # (M,) int32 detector indices, flattened as row * gpts[1] + col
    shifts: np.ndarray  # (M, 2) int32 per-BF-pixel (row, col) shifts in upsampled scan pixels
    wavelength: float  # electron wavelength in Angstroms (from electron_wavelength)
    gpts: tuple[int, int]  # detector shape (shape[-2], shape[-1])
    reciprocal_sampling: tuple[float, float]  # detector pixel size in reciprocal space
    sampling: tuple[float, float]  # 1 / (reciprocal_sampling * gpts), per axis
    upsampled_scan_gpts: tuple[int, int]  # scan shape times upsampling_factor
    upsampled_scan_sampling: tuple[float, float]  # scan step divided by upsampling_factor
    upsampling_factor: int
    aberration_coefs: dict[str, float]  # keys read elsewhere: "C10", "C12", "phi12"

def preprocess_geometry(
    shape: tuple[int, int, int, int],
    scan_sampling: tuple[float, float],
    energy: float,
    semiangle_cutoff: float,
    reciprocal_sampling: tuple[float, float] | None = None,
    angular_sampling: tuple[float, float] | None = None,
    aberration_coefs: dict[str, float] | None = None,
    rotation_angle: float | None = None,
    upsampling_factor: int = 1,
    detector_flip_cols: bool = False,
):
    """Precompute bright-field (BF) detector indices and parallax shifts.

    Parameters
    ----------
    shape
        4D dataset shape ``(scan_rows, scan_cols, det_rows, det_cols)``.
    scan_sampling
        Scan step (rows, cols), in the same length unit as the aberration
        coefficients (Angstroms in this notebook).
    energy
        Beam energy in eV, passed to `electron_wavelength`.
    semiangle_cutoff
        BF cutoff in mrad: detector frequencies with
        ``k * wavelength * 1e3 <= semiangle_cutoff`` are kept.
    reciprocal_sampling, angular_sampling
        Detector pixel size, in 1/Angstrom or in mrad. Exactly one is required.
    aberration_coefs
        Quadratic aberration coefficients (``"C10"``, ``"C12"``, ``"phi12"``);
        missing keys default to 0. The dict is copied into the result.
    rotation_angle
        Rotation (radians) applied to the detector frequency grid.
    upsampling_factor
        Positive integer upsampling of the scan grid.
    detector_flip_cols
        Accepted for API compatibility; currently not used.

    Returns
    -------
    PreprocessedGeometry

    Raises
    ------
    ValueError
        If ``shape`` is not 4D, if both or neither sampling arguments are given,
        or if ``upsampling_factor < 1``.
    """
    # ---- Validation ----
    if len(shape) != 4:
        raise ValueError(f"`shape` must have length 4, not {len(shape)}.")

    if reciprocal_sampling is not None and angular_sampling is not None:
        raise ValueError(
            "Specify only one of `reciprocal_sampling` or `angular_sampling`, not both."
        )

    if reciprocal_sampling is None and angular_sampling is None:
        raise ValueError(
            "One of `reciprocal_sampling` or `angular_sampling` must be specified."
        )

    if upsampling_factor < 1:
        raise ValueError(
            f"`upsampling_factor` must be a positive integer, not {upsampling_factor}."
        )

    # ---- Sampling ----
    wavelength = electron_wavelength(energy)

    # Canonicalize to reciprocal sampling (mrad -> reciprocal length).
    if reciprocal_sampling is None:
        assert angular_sampling is not None
        reciprocal_sampling = (
            angular_sampling[0] / wavelength / 1e3,
            angular_sampling[1] / wavelength / 1e3,
        )

    scan_gpts = (shape[0], shape[1])
    gpts = (shape[-2], shape[-1])

    # Real-space pixel size implied by the detector's reciprocal sampling.
    sampling = (
        1.0 / reciprocal_sampling[0] / gpts[0],
        1.0 / reciprocal_sampling[1] / gpts[1],
    )

    upsampled_scan_gpts = (
        scan_gpts[0] * upsampling_factor,
        scan_gpts[1] * upsampling_factor,
    )

    upsampled_scan_sampling = (
        scan_sampling[0] / upsampling_factor,
        scan_sampling[1] / upsampling_factor,
    )

    # Copy so the frozen result does not alias a dict the caller may mutate later.
    aberration_coefs = {} if aberration_coefs is None else dict(aberration_coefs)

    # ---- Detector frequency grid (unshifted FFT order, DC at [0, 0]) ----
    kxa, kya = spatial_frequencies(
        gpts,
        sampling,
        rotation_angle=rotation_angle,
    )

    k, phi = polar_coordinates(kxa, kya)

    # ---- BF indices ----
    bf_mask = k * wavelength * 1e3 <= semiangle_cutoff
    inds_i, inds_j = np.where(bf_mask)

    # Map FFT-order indices to the detector layout, assumed centered like
    # np.fft.fftshift (DC at index n // 2). fftshift moves index i to
    # (i + n // 2) % n; the form (i - n // 2) % n only agrees with it for even n.
    det_i = (inds_i + gpts[0] // 2) % gpts[0]
    det_j = (inds_j + gpts[1] // 2) % gpts[1]

    # One-pixel corrections for near quarter-turn rotations on even-length axes
    # (presumably compensating the unpaired Nyquist frequency; NOTE(review): confirm).
    if rotation_angle is not None:
        n_rot = int(np.round(rotation_angle / (np.pi / 2))) % 4

        if n_rot == 1:  # +π/2
            if gpts[1] % 2 == 0:
                det_j = (det_j - 1) % gpts[1]
        elif n_rot == 2:  # π
            if gpts[0] % 2 == 0:
                det_i = (det_i - 1) % gpts[0]
            if gpts[1] % 2 == 0:
                det_j = (det_j - 1) % gpts[1]
        elif n_rot == 3:  # -π/2
            if gpts[0] % 2 == 0:
                det_i = (det_i - 1) % gpts[0]

    bf_flat_inds = (det_i * gpts[1] + det_j).astype(np.int32)

    # ---- Parallax shifts ----
    grad_x, grad_y = quadratic_aberration_cartesian_gradients(
        k * wavelength,
        phi,
        aberration_coefs,
    )

    # Gradients are sampled on the FFT-order grid, so index with the FFT-order inds.
    grad_k = np.stack(
        (grad_x[inds_i, inds_j], grad_y[inds_i, inds_j]),
        axis=-1,
    )

    # grad / 2pi is a real-space displacement; express it in upsampled scan pixels.
    shifts = np.round(grad_k / (2 * np.pi) / upsampled_scan_sampling).astype(
        np.int32
    )

    return PreprocessedGeometry(
        bf_flat_inds=bf_flat_inds,
        shifts=shifts,
        wavelength=wavelength,
        gpts=gpts,
        reciprocal_sampling=reciprocal_sampling,
        sampling=sampling,
        upsampled_scan_gpts=upsampled_scan_gpts,
        upsampled_scan_sampling=upsampled_scan_sampling,
        upsampling_factor=upsampling_factor,
        aberration_coefs=aberration_coefs,
    )

In [4]:
import numba

@numba.njit(fastmath=True, nogil=True, cache=True)
def parallax_accumulate_cpu(
    frames,  # (T, det_rows, det_cols), any numeric dtype (presumably uint8 here)
    bf_flat_inds,  # (M,) int32, flattened as row * det_cols + col
    shifts,  # (M, 2) int32 (row, col) shifts
    coords,  # (T, 2) int64 scan positions on the output grid
    out,  # (Ny, Nx) float64 accumulator, modified in place
):
    """Scatter-add mean-subtracted BF intensities at parallax-shifted positions.

    For frame ``t`` the mean over the M BF pixels is subtracted from each BF
    value, which is then added to
    ``out[(coords[t, 0] + shifts[m, 0]) % Ny, (coords[t, 1] + shifts[m, 1]) % Nx]``.
    Returns nothing.
    """
    n_frames = frames.shape[0]
    n_bf = shifts.shape[0]
    out_rows, out_cols = out.shape
    det_cols = frames.shape[2]

    for t in range(n_frames):
        frame = frames[t]
        row0 = coords[t, 0]
        col0 = coords[t, 1]

        # Per-frame mean over the BF pixels.
        total = 0.0
        for m in range(n_bf):
            flat = bf_flat_inds[m]
            total += frame[flat // det_cols, flat % det_cols]
        mean = total / n_bf

        # Deposit each mean-subtracted BF value at its shifted, wrapped position.
        for m in range(n_bf):
            flat = bf_flat_inds[m]
            value = frame[flat // det_cols, flat % det_cols] - mean
            dest_row = (row0 + shifts[m, 0]) % out_rows
            dest_col = (col0 + shifts[m, 1]) % out_cols
            out[dest_row, dest_col] += value

In [5]:
@numba.njit(fastmath=True, nogil=True, cache=True)
def parallax_phase_flip_accumulate_cpu(
    frames, bf_rows, bf_cols, coords, unique_offsets, grouped_kernel, out
):
    """Accumulate the phase-flipped parallax reconstruction for a batch of frames.

    For each frame ``t``, the BF intensities at ``(bf_rows[m], bf_cols[m])`` are
    mean-subtracted, projected onto every grouped kernel row
    (``grouped_kernel @ I_bf``), and scatter-added into ``out`` at scan position
    ``coords[t]`` translated by each entry of ``unique_offsets``, with periodic
    wrap-around applied separately on each axis. ``out`` is modified in place.

    ``unique_offsets`` must be row-major flattened ``row * Nx + col`` positions
    with ``Nx == out.shape[1]`` (as produced by
    ``prepare_grouped_phase_flipping_kernel`` with ``upsampled_gpts == out.shape``).
    """
    T = frames.shape[0]
    M = len(bf_rows)
    U = len(unique_offsets)
    Ny, Nx = out.shape

    # Decode each flattened offset into (row, col) once. Wrapping must happen per
    # axis: adding flat offsets modulo Ny * Nx would let a column overflow spill
    # into the next row instead of wrapping to the start of the same row.
    off_rows = np.empty(U, dtype=np.int64)
    off_cols = np.empty(U, dtype=np.int64)
    for u in range(U):
        off_rows[u] = unique_offsets[u] // Nx
        off_cols[u] = unique_offsets[u] % Nx

    # Reused across frames; fully overwritten each iteration.
    I_bf = np.empty(M, dtype=np.float64)

    for t in range(T):
        # Extract BF pixels and subtract the per-frame mean.
        s = 0.0
        for m in range(M):
            s += frames[t, bf_rows[m], bf_cols[m]]
        mean = s / M
        for m in range(M):
            I_bf[m] = frames[t, bf_rows[m], bf_cols[m]] - mean

        # Contribution per grouped offset, scatter-added with 2D wrap-around.
        yt = coords[t, 0]
        xt = coords[t, 1]
        for u in range(U):
            acc = 0.0
            for m in range(M):
                acc += grouped_kernel[u, m] * I_bf[m]
            out[(yt + off_rows[u]) % Ny, (xt + off_cols[u]) % Nx] += acc

def preprocess_flipped_geometry(
    shape: tuple[int, int, int, int],
    scan_sampling: tuple[float, float],
    energy: float,
    semiangle_cutoff: float,
    reciprocal_sampling: tuple[float, float] | None = None,
    angular_sampling: tuple[float, float] | None = None,
    aberration_coefs: dict[str, float] | None = None,
    rotation_angle: float | None = None,
    upsampling_factor: int = 1,
    detector_flip_cols: bool = False,
):
    """Run `preprocess_geometry` and build the grouped real-space phase-flip kernel.

    The Fourier-space kernel is ``sign(sin(chi(q)))`` on the upsampled scan grid.
    Its Nyquist row/column is removed, it is transformed to real space, and it is
    grouped per BF pixel with `prepare_grouped_phase_flipping_kernel`.

    All parameters are forwarded unchanged to `preprocess_geometry`.

    Returns
    -------
    pre : PreprocessedGeometry
    unique_offsets : (U,) int64
        Flattened accumulator offsets.
    grouped_kernel : (U, M) array
        Summed kernel weights per offset and BF pixel.
    """
    pre = preprocess_geometry(
        shape,
        scan_sampling,
        energy,
        semiangle_cutoff,
        reciprocal_sampling=reciprocal_sampling,
        angular_sampling=angular_sampling,
        aberration_coefs=aberration_coefs,
        rotation_angle=rotation_angle,
        upsampling_factor=upsampling_factor,
        detector_flip_cols=detector_flip_cols,
    )

    # Phase-flip kernel on the (unrotated) upsampled scan grid.
    qxa, qya = spatial_frequencies(pre.upsampled_scan_gpts, pre.upsampled_scan_sampling)
    q, theta = polar_coordinates(qxa, qya)
    aberration_surface = quadratic_aberration_surface(
        q * pre.wavelength,
        theta,
        pre.wavelength,
        aberration_coefs=pre.aberration_coefs,
    )
    fourier_kernel = np.sign(np.sin(aberration_surface))

    # Zero the Nyquist row/column, which exists only for even lengths (same rule as
    # suppress_nyquist_frequency). For odd lengths n // 2 is an ordinary positive
    # frequency; zeroing it alone would break the spectrum's symmetry, and `.real`
    # below would then silently discard part of the kernel.
    n_rows, n_cols = fourier_kernel.shape
    if n_rows % 2 == 0:
        fourier_kernel[n_rows // 2, :] = 0.0
    if n_cols % 2 == 0:
        fourier_kernel[:, n_cols // 2] = 0.0

    realspace_kernel = np.fft.ifft2(fourier_kernel).real

    unique_offsets, grouped_kernel = prepare_grouped_phase_flipping_kernel(
        realspace_kernel, pre.shifts, pre.upsampled_scan_gpts
    )

    return pre, unique_offsets, grouped_kernel

In [6]:
# Raw 4D dataset; the geometry code reads data.shape as
# (scan_rows, scan_cols, det_rows, det_cols).
data = np.load("data/apoF_4mrad_1.5um-df_3A-step_30eA2_binary_uint8.npy")

# Collapse the two scan axes: one detector frame per scan position, row-major.
frames = data.reshape(-1, *data.shape[-2:])
num_pos = len(frames)

In [7]:
# Shared widget styling: descriptions take their natural width.
style = {'description_width': 'initial'}
layout_half = ipywidgets.Layout(width="310px", height="30px")   # two widgets per row
layout_third = ipywidgets.Layout(width="205px", height="30px")  # three widgets per row

# Sliders; continuous_update=False means observers fire on release, not while dragging.
defocus_slider = ipywidgets.FloatSlider(
    value=-1000,
    min=-2000,
    max=2000,
    step=100,
    description='defocus [nm]',
    continuous_update=False,
    style=style,
    layout=layout_half,
)

rotation_slider = ipywidgets.FloatSlider(
    value=0.0,
    min=-30,
    max=30,
    step=0.5,
    description='rotation [deg]',
    continuous_update=False,
    style=style,
    layout=layout_half,
)

# Index of the first frame of the next batch; advances 72 frames per step
# (presumably one scan row — TODO confirm against data.shape[1]).
frame_slider = ipywidgets.IntSlider(
    value=0,
    min=0,
    max=num_pos - 1,
    step=72,
    description='frame',
    continuous_update=False,
    style=style,
    layout=layout_half,
)

# Animation driver for frame_slider (same range/step); interval is in ms.
play_button = ipywidgets.Play(
    interval=10,
    value=0,
    min=0,
    max=num_pos - 1,
    step=72,
    repeat=True,
    show_repeat=False,
    style=style,
    layout=layout_third,
)

reset_button = ipywidgets.Button(
    description="reset reconstruction",
    style=style,
    layout=layout_third,
)

ground_truth_button = ipywidgets.Button(
    description="ground truth aberrations",
    style=style,
    layout=layout_third,
)

def _normalize_for_display(image, quantiles=(0.02, 0.98)):
    """Clip ``image`` to its ``quantiles`` range and rescale it to [0, 1].

    Returns a flat 0.5 image when the range is degenerate (e.g. an all-zero
    accumulator) instead of dividing by zero and producing NaNs.
    """
    vmin, vmax = np.quantile(image, quantiles)
    if vmax <= vmin:
        return np.full(image.shape, 0.5)
    return (image.clip(vmin, vmax) - vmin) / (vmax - vmin)


def update_display(
    change
):
    """Stream the next batch of frames into both reconstructions and redraw them.

    Used as a ``frame_slider`` observer: ``change["new"]`` is the index of the
    first frame in the batch. Accumulates in place into ``inputs[0]`` (parallax)
    and ``inputs[1]`` (phase-flipped parallax), then updates ``im`` and
    ``im_flipped`` with quantile-normalized copies.
    """
    start = change["new"]
    stop = start + 72  # one slider/play step worth of frames
    batch = frames[start:stop]
    coords = inputs[2][start:stop]
    geometry = inputs[3]

    bf_rows, bf_cols = np.unravel_index(np.asarray(geometry.bf_flat_inds), geometry.gpts)

    parallax_accumulate_cpu(
        batch,
        geometry.bf_flat_inds,
        geometry.shifts,
        coords,
        inputs[0],
    )

    parallax_phase_flip_accumulate_cpu(
        batch,
        bf_rows,
        bf_cols,
        coords,
        inputs[4],
        inputs[5],
        inputs[1],
    )

    im.set_array(_normalize_for_display(inputs[0]))
    im_flipped.set_array(_normalize_for_display(inputs[1]))
    return None

def reset_recon(*args):
    """Zero both accumulators and show flat mid-grey placeholders.

    Positional arguments are ignored, so this works directly as a
    ``Button.on_click`` callback.
    """
    for accumulator in (inputs[0], inputs[1]):
        accumulator[:] = 0.0

    ny, nx = inputs[0].shape
    for image in (im, im_flipped):
        image.set_array(np.full((ny, nx), 0.5))
    return None

def update_aberrations_and_rotation(
    change
):
    """Rebuild geometry and phase-flip kernel from the sliders, then reset.

    The defocus slider value (nm) is multiplied by 10 (to Angstroms) and negated
    to give C10; the rotation slider (degrees) is converted to radians.
    Stores the results in ``inputs[3:6]`` and calls `reset_recon`.
    """
    defocus_angstrom = defocus_slider.value * 10
    rotation_rad = np.deg2rad(rotation_slider.value)

    inputs[3], inputs[4], inputs[5] = preprocess_flipped_geometry(
        shape=data.shape,
        scan_sampling=(256/72, 256/72),
        energy=300e3,
        semiangle_cutoff=4.0,
        angular_sampling=(0.307617, 0.307617),
        aberration_coefs={"C10": -defocus_angstrom},
        rotation_angle=rotation_rad,
    )

    reset_recon()
    return None

def set_ground_truth(
    *args
):
    """Move the sliders to the reference values: 1.5e3 nm defocus, -15 deg rotation.

    Defocus is set before rotation; each assignment triggers that slider's
    observers if the value actually changes.
    """
    defocus_slider.value, rotation_slider.value = 1.5e3, -15
    return None

# Browser-side link: the Play widget drives the frame slider.
ipywidgets.jslink((play_button, 'value'), (frame_slider, 'value'))

# Python-side callbacks.
frame_slider.observe(update_display, names='value')
for _slider in (defocus_slider, rotation_slider):
    _slider.observe(update_aberrations_and_rotation, names='value')
reset_button.on_click(reset_recon)
ground_truth_button.on_click(set_ground_truth)

In [8]:
pre, unique_offsets, grouped_kernel = preprocess_flipped_geometry(
    shape=data.shape,
    scan_sampling=(256/72, 256/72),
    energy=300e3,
    semiangle_cutoff=4.0,
    angular_sampling=(0.307617, 0.307617),
    aberration_coefs={"C10": -defocus_slider.value*10},
    rotation_angle=np.deg2rad(rotation_slider.value),
)

# Scan positions in row-major order (matching `frames`), on the upsampled grid.
scan_rows, scan_cols = np.meshgrid(
    np.arange(data.shape[0]), np.arange(data.shape[1]), indexing='ij'
)
coords = np.stack([scan_rows, scan_cols], -1).reshape(-1, 2) * pre.upsampling_factor

# Accumulators for the plain and phase-flipped reconstructions.
out_prlx = np.zeros(pre.upsampled_scan_gpts, dtype=np.float64)
out_prlx_flipped = np.zeros(pre.upsampled_scan_gpts, dtype=np.float64)

# Mutable state shared with the widget callbacks, indexed positionally:
#   0 parallax accumulator, 1 phase-flipped accumulator, 2 scan coords,
#   3 PreprocessedGeometry, 4 unique_offsets, 5 grouped_kernel
inputs = [out_prlx, out_prlx_flipped, coords, pre, unique_offsets, grouped_kernel]

In [12]:
with plt.ioff():
    # Build the figure without auto-displaying it; its canvas is placed in the widget layout.
    dpi = 72
    fig, axs = plt.subplots(1, 2, figsize=(640 / dpi, 365 / dpi), dpi=dpi)

placeholder_shape = inputs[3].upsampled_scan_gpts
panel_titles = ("streamed parallax", "streamed phase-flipped parallax")

# Mid-grey placeholders on a fixed [0, 1] scale; the callbacks write normalized images into them.
im, im_flipped = [
    ax.imshow(np.full(placeholder_shape, 0.5), cmap="gray", vmin=0, vmax=1)
    for ax in axs
]
for _ax, _title in zip(axs, panel_titles):
    _ax.set(xticks=[], yticks=[], title=_title)

fig.tight_layout()

# Fixed-size ipympl canvas without header/footer/toolbar, on a transparent background.
canvas = fig.canvas
canvas.resizable = False
canvas.header_visible = False
canvas.footer_visible = False
canvas.toolbar_visible = False
canvas.layout.height = "360px"
canvas.layout.width = '640px'
fig.patch.set_alpha(0)
None

In [13]:
#| label: app:streaming-parallax
slider_row = ipywidgets.HBox([defocus_slider, rotation_slider])
button_row = ipywidgets.HBox([play_button, reset_button, ground_truth_button])
ipywidgets.VBox([slider_row, button_row, fig.canvas])

VBox(children=(HBox(children=(FloatSlider(value=-1000.0, continuous_update=False, description='defocus [nm]', …