---
title: Pixelated Detector iCOM and Ptychography Contrast Transfer
authors: [Julie Marie Bekkevold, Georgios Varnavides]
date: 2024-09-30
---

In [1]:
# enable interactive matplotlib
%matplotlib widget 

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import ctf # import custom plotting / utils
import cmasher as cmr

from tqdm.notebook import tqdm

import ipywidgets
from IPython.display import display

## 4D STEM Simulation

In [2]:
# parameters
n = 96 # pixels along each side of the object, probe and detector arrays
q_max = 2 # Nyquist frequency of the real-space grid, inverse Angstroms
q_probe = 1 # probe-forming aperture radius, inverse Angstroms
wavelength = 0.019687 # Angstroms, 300kV
sampling = 1 / q_max / 2 # Angstroms
reciprocal_sampling = 2 * q_max / n # inverse Angstroms

scan_step_size = 1 # pixels
sx = sy = n//scan_step_size
phi0 = 1.0

# seed numpy's global RNG (used by the white-noise object and by the
# ptychography batch shuffling) so Restart & Run All is reproducible
random_seed = 0
np.random.seed(random_seed)

cmap = cmr.eclipse
icom_line_color = 'cornflowerblue'
iter_ptycho_line_color = 'mediumvioletred'

### White Noise Potential

In [3]:
def white_noise_object_2D(n, phi0, rng=None):
    """Create a 2D real-valued array whose FFT has random phase and constant amplitude.

    A random phase (in units of 2*pi) is drawn for every Fourier component and then
    made antisymmetric, phase(-k) = -phase(k), so the Fourier array is Hermitian and
    its inverse FFT is real. The DC component and, for even ``n``, the whole Nyquist
    row and column are given zero phase instead of a random one. They are therefore
    real coefficients of magnitude ``phi0``; they are NOT removed.

    Parameters
    ----------
    n : int
        Side length of the (n, n) output array.
    phi0 : float
        Modulus of every Fourier coefficient: ``np.fft.fft2`` of the result has
        modulus ``phi0`` at every frequency, including DC.
    rng : numpy.random.Generator, optional
        Source of randomness. If None (default), numpy's global RNG
        (``np.random.randn``) is used, exactly as before.

    Returns
    -------
    numpy.ndarray
        Real-valued (n, n) float array.
    """
    evenQ = n%2 == 0

    # pos_ind[i] and neg_ind[i] index the frequencies k and -k (neg_ind[i] == n - pos_ind[i])
    pos_ind = np.arange(1,(n if evenQ else n+1)//2)
    neg_ind = np.flip(np.arange(n//2+1,n))

    # random phase, in units of 2*pi
    if rng is None:
        arr = np.random.randn(n,n)
    else:
        arr = rng.standard_normal((n,n))

    # enforce phase(-k) = -phase(k) so exp(2j*pi*arr) is Hermitian
    # top-left // bottom-right
    arr[pos_ind[:,None],pos_ind[None,:]] = -arr[neg_ind[:,None],neg_ind[None,:]]
    # bottom-left // top-right
    arr[pos_ind[:,None],neg_ind[None,:]] = -arr[neg_ind[:,None],pos_ind[None,:]]
    # kx=0
    arr[0,pos_ind] = -arr[0,neg_ind]
    # ky=0
    arr[pos_ind,0] = -arr[neg_ind,0]

    # give zero phase (a real coefficient of modulus phi0) to the Nyquist row/column,
    # which are not paired above; this keeps the Fourier array Hermitian
    if evenQ:
        arr[n//2,:] = 0 # zero phase at highest spatial freq
        arr[:,n//2] = 0 # zero phase at highest spatial freq

    arr[0,0] = 0 # zero phase at DC (the DC coefficient is phi0, not 0)

    # fourier-array of constant modulus phi0
    arr = np.exp(2j*np.pi*arr)*phi0

    # inverse FFT; the imaginary part is only floating point round-off
    arr = np.fft.ifft2(arr).real

    return arr

# random-phase potential and its pure-phase (unit-modulus) transmission function
potential = white_noise_object_2D(n=n, phi0=phi0)
complex_obj = np.exp(1j * potential)

### Probe

In [4]:
# The probe is built in Fourier space from a soft-edged circular aperture of radius q_probe.
qx = np.fft.fftfreq(n, sampling)
qy = qx
q2 = np.add.outer(qx**2, qy**2)
q = np.sqrt(q2)

# The edge ramps linearly from 1 to 0 over one reciprocal pixel centred on q_probe
# (clipped to [0, 1]); the square root makes |aperture|**2 follow that linear ramp.
aperture_fourier = np.sqrt(np.clip((q_probe - q) / reciprocal_sampling + 0.5, 0, 1))

In [5]:
def simulate_intensities(defocus, batch_size=n**2, pbar=None):
    """Simulate far-field data for a defocused probe scanned over ``complex_obj``.

    Reads the module-level ``aperture_fourier``, ``q``, ``wavelength``, ``complex_obj``,
    ``positions``, ``row`` and ``col``. It evaluates ``ctf.simulate_data`` one batch of
    scan positions at a time.

    Parameters
    ----------
    defocus : float
        Defocus in Angstroms. It enters the probe phase as
        exp(-i*pi*wavelength*defocus*q**2).
    batch_size : int, optional
        Number of scan positions per call to ``ctf.simulate_data``. It does not need to
        divide the number of positions; the last batch may be smaller.
    pbar : progress bar (e.g. tqdm), optional
        If given, it is reset to the number of batches, advanced once per batch, and
        coloured green when finished.

    Returns
    -------
    list
        ``[amplitudes, probe_array]``. ``amplitudes`` has one (n, n) slab per scan
        position, presumably the diffraction amplitudes produced by
        ``ctf.simulate_data`` (TODO confirm). ``probe_array`` is the (n, n)
        real-space probe.

    Raises
    ------
    ValueError
        If ``batch_size`` is not positive.
    """
    # one entry per scan position (equals n**2 when scan_step_size == 1)
    num_positions = len(positions)
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")
    # ceil division: a final, smaller batch absorbs any remainder
    n_batch = max(1, int(-(-num_positions // batch_size)))
    batches = np.array_split(np.arange(num_positions), n_batch)
    amplitudes = np.zeros((num_positions,n,n))

    if pbar is not None:
        pbar.reset(n_batch)
        pbar.colour = None
        pbar.refresh()

    # defocused probe in Fourier space, normalized to unit total intensity
    probe_array_fourier = aperture_fourier * np.exp(-1j*np.pi*wavelength*defocus*q**2)
    probe_array_fourier /= np.sqrt(np.sum(np.abs(probe_array_fourier)**2))
    # ifft2 scales by 1/n**2; multiplying by n keeps unit total intensity (Parseval)
    probe_array = np.fft.ifft2(probe_array_fourier) * n

    for batch_order in batches:
        amplitudes[batch_order] = ctf.simulate_data(
            complex_obj,
            probe_array,
            row[batch_order],
            col[batch_order],
        )
        if pbar is not None:
            pbar.update(1)

    if pbar is not None:
        pbar.colour = 'green'

    return [amplitudes, probe_array]

In [6]:
# raster of scan positions (pixels) and the patch indices that each probe position covers
x = np.arange(0., n, scan_step_size)
y = x
xx, yy = np.meshgrid(x, y, indexing='ij')
positions = np.column_stack((xx.ravel(), yy.ravel()))
row, col = ctf.return_patch_indices(positions, (n, n), (n, n))

# initial, in-focus simulation
amplitudes_probe = simulate_intensities(defocus=0, batch_size=1024, pbar=None)

# intensities[0]: (sx, sy, n, n) diffraction intensities; intensities[1]: total per position
intensities = [amplitudes_probe[0].reshape((sx, sy, n, n))**2 / n**2, None]
intensities[1] = intensities[0].sum((-1, -2))

## CoM and Ptychography Calculations

In [7]:
def compute_pixelated_com(
    corner_centered_intensities,
    corner_centered_intensities_sum,
    sx,sy,
    kxa,kya,
):
    """Centre of mass of every diffraction pattern, in the units of ``kxa`` / ``kya``.

    Parameters
    ----------
    corner_centered_intensities : array_like
        Diffraction intensities with the two detector axes last, e.g. (sx, sy, nx, ny).
        They are expected in the same (unshifted, corner-centred) layout as kxa/kya.
    corner_centered_intensities_sum : array_like
        Total intensity of each pattern, broadcastable against the scan axes.
    sx, sy : int
        Scan dimensions; accepted for interface compatibility but not used.
    kxa, kya : numpy.ndarray
        (nx, ny) detector frequency grids.

    Returns
    -------
    tuple of numpy.ndarray
        ``(com_x, com_y)``, each with the shape of the scan axes.
    """
    stack = np.asarray(corner_centered_intensities)
    norm = corner_centered_intensities_sum

    # intensity-weighted mean frequency along each detector direction
    com_x_pixelated, com_y_pixelated = (
        np.sum(stack * grid[None, None], axis=(-2, -1)) / norm
        for grid in (kxa, kya)
    )
    return com_x_pixelated, com_y_pixelated

def integrate_com(
    com_x,
    com_y,
    kx_op,
    ky_op,
):
    """Integrate a centre-of-mass vector field into a scalar (iCoM) image.

    The integration is done in Fourier space and returns
    ``Re(ifft2(fft2(com_x) * kx_op + fft2(com_y) * ky_op))``.
    """
    spectrum = np.fft.fft2(com_x) * kx_op
    spectrum = spectrum + np.fft.fft2(com_y) * ky_op
    return np.fft.ifft2(spectrum).real

def ptycho_reconstruction(
    amplitudes,
    row,
    col,
    positions,
    recon_array,
    probe_array,
    pbars,
    batch_size = n**2,
    iterations=64,
    step_size=1.0,
):
    """Iterative mini-batch ptychographic reconstruction of the object.

    Each outer iteration processes the (shuffled) scan positions in batches:
    1. Form the exit waves ``probe * object_patch`` and Fourier transform them.
    2. Replace their moduli by the measured ``amplitudes`` (keeping the phases).
    3. Multiply the real-space difference by conj(probe) and accumulate it onto the
       object with ``ctf.sum_patches``.
    After every batch the object modulus is clipped to [0, 1]. After every outer
    iteration ``update_ptycho_panel`` is called to refresh the figure.

    Parameters
    ----------
    amplitudes : numpy.ndarray
        Measured Fourier amplitudes with scan positions on the first axis, presumably
        (num_positions, nx, ny) (TODO confirm).
    row, col : numpy.ndarray
        Patch index arrays with positions on the first axis, used to gather object
        patches (as returned by ``ctf.return_patch_indices``).
    positions : numpy.ndarray
        Scan positions, indexed per position and passed to ``ctf.sum_patches``.
    recon_array : array_like
        Initial complex object estimate. It is copied and NOT modified in place.
    probe_array : numpy.ndarray
        (nx, ny) probe.
    pbars : tuple or None
        ``(outer_pbar, inner_pbar)`` progress bars (iterations, batches), or None.
    batch_size : int, optional
        Positions per update. It does not need to divide the number of positions.
    iterations : int, optional
        Number of outer passes over all positions.
    step_size : float, optional
        Multiplier applied to every object update.

    Returns
    -------
    numpy.ndarray
        The reconstructed complex object.

    Raises
    ------
    ValueError
        If ``batch_size`` is not positive.

    Notes
    -----
    Batch order is shuffled with numpy's global RNG (``np.random.shuffle``).
    """
    num_positions = amplitudes.shape[0]
    nx, ny = probe_array.shape
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")
    # ceil division: a final, smaller batch absorbs any remainder
    n_batches = max(1, int(-(-num_positions // batch_size)))

    # work on a complex copy so the caller's array is never modified in place
    recon_array = np.asarray(recon_array)
    recon_array = recon_array.astype(np.result_type(recon_array.dtype, np.complex64))

    order = np.arange(num_positions)
    np.random.shuffle(order)

    # normalization
    probe_normalization = np.mean(np.sum(amplitudes**2,0)) / nx / ny
    shifted_probes = probe_array * np.sqrt(probe_normalization)
    conj_probes = np.conj(shifted_probes)  # loop-invariant

    outer_pbar, inner_pbar = pbars if pbars is not None else (None, None)

    if outer_pbar is not None:
        outer_pbar.reset(iterations)
        outer_pbar.colour = None
        outer_pbar.refresh()

    for _ in range(iterations):

        if inner_pbar is not None:
            inner_pbar.reset(n_batches)
            inner_pbar.colour = None
            inner_pbar.refresh()

        for batch_order in np.array_split(order, n_batches):

            batch_amplitudes = amplitudes[batch_order]
            batch_pos = positions[batch_order]
            batch_row = row[batch_order]
            batch_col = col[batch_order]

            # current exit waves and their far field
            obj_patches = recon_array[batch_row,batch_col]
            overlap = shifted_probes * obj_patches
            fourier_overlap = np.fft.fft2(overlap)

            # impose the measured modulus, keep the current phase
            modified_fourier_overlap = batch_amplitudes*np.exp(1j*np.angle(fourier_overlap))
            grad = np.fft.ifft2(modified_fourier_overlap-fourier_overlap)

            update = ctf.sum_patches(
                grad*conj_probes,
                batch_pos,
                (nx,ny),
                (nx,ny),
            ) / probe_normalization

            recon_array += (step_size*update)

            # constrain the object modulus to [0, 1], keeping its phase
            amp = np.abs(recon_array).clip(0.0,1.0)
            recon_array = amp * np.exp(1j*np.angle(recon_array))
            if inner_pbar is not None:
                inner_pbar.update(1)

        np.random.shuffle(order)
        update_ptycho_panel(recon_array)
        if outer_pbar is not None:
            outer_pbar.update(1)

    if inner_pbar is not None:
        inner_pbar.colour='green'
    if outer_pbar is not None:
        outer_pbar.colour='green'
    return recon_array

In [8]:
# detector spatial frequencies (unshifted FFT order), stored as float32 grids
kx = np.fft.fftfreq(n, sampling).astype(np.float32)
ky = kx
kxa, kya = np.meshgrid(kx, ky, indexing='ij')

k2 = kxa**2 + kya**2
k = np.sqrt(k2)
# avoid dividing by zero at DC; the operators below evaluate to 0 there
k2[0, 0] = np.inf

# Fourier-space operators that integrate the CoM vector field (iCoM)
kx_op = -1.0j * kxa / k2
ky_op = -1.0j * kya / k2

In [9]:
# pixelated-detector CoM -> integrated CoM image -> its CTF and radial average
com_x, com_y = compute_pixelated_com(
    intensities[0], intensities[1],
    sx, sy,
    kxa, kya,
)
icom_pixelated = integrate_com(com_x, com_y, kx_op, ky_op)
ctf_pixelated = ctf.compute_ctf(icom_pixelated)
q_bins_pixelated, I_bins_pixelated = ctf.radially_average_ctf(
    ctf_pixelated,
    (sampling, sampling),
)


In [10]:
ptycho_recon = [np.full((n, n), 1.0, dtype=np.complex128)]  # one-element list so widget callbacks can rebind the current object estimate

In [17]:
with plt.ioff():
    dpi=72
    fig, axs = plt.subplots(1,3,figsize=(640/dpi,260/dpi),dpi=dpi)

# pixelated iCOM CTF (fftshifted so zero frequency is centred, then histogram-scaled)
ax_ctf_pixelated_dpc = axs[0]
im_ctf_dpc = ax_ctf_pixelated_dpc.imshow(
    ctf.histogram_scaling(np.fft.fftshift(ctf_pixelated),normalize=False),
    cmap=cmap,
)

# pixelated ptycho CTF: blank placeholder until a reconstruction is run
ax_ctf_pixelated_ptycho = axs[1]
im_ctf_ptycho = ax_ctf_pixelated_ptycho.imshow(np.zeros((n,n)),cmap=cmap,vmin=0,vmax=1)

# radially-averaged CTFs of both methods (the ptycho curve starts at zero)
ax_ctf_rad = axs[2]
plot_ctf_dpc = ax_ctf_rad.plot(q_bins_pixelated, I_bins_pixelated, color=icom_line_color,label='pixelated iCOM')[0]
plot_ctf_ptycho = ax_ctf_rad.plot(q_bins_pixelated, np.zeros_like(I_bins_pixelated), color=iter_ptycho_line_color,label='pixelated ptycho')[0]
# unit-transfer reference line spanning the plotted frequency range [0, q_max]
ax_ctf_rad.hlines(1,0,q_max,ls='--',alpha=.7,color='k')
ax_ctf_rad.legend()

# remove ticks, add titles
for ax, title in zip(
    axs,
    [
        "pixelated iCOM CTF",
        "pixelated ptycho CTF",
        "radially averaged CTFs",
    ]
):
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(title)

# NOTE(review): n//4 pixels * reciprocal_sampling = q_max/2, which equals one q_probe
# only because q_max = 2*q_probe here -- confirm how ctf.add_scalebar labels its length.
for ax in axs[:2].flatten():
    ctf.add_scalebar(ax,length=n//4,sampling=reciprocal_sampling,units=r'$q_{\mathrm{probe}}$')

ax_ctf_rad.set_ylim([0,1.1])
ax_ctf_rad.set_xlim([0,q_max])
ax_ctf_rad.set_xticks([0,q_probe,q_max])
# tick labels in units of q_probe, derived so they stay correct if q_probe / q_max change
ax_ctf_rad.set_xticklabels([f"{tick / q_probe:g}" for tick in (0, q_probe, q_max)])
ax_ctf_rad.set_xlabel(r"spatial frequency, $q/q_{\mathrm{probe}}$")

# fix ipympl canvas from resizing
fig.canvas.resizable = False
fig.canvas.header_visible = False
fig.canvas.footer_visible = False
fig.canvas.toolbar_visible = True
fig.canvas.toolbar_position = 'bottom'
fig.canvas.layout.width = '640px'
fig.canvas.layout.height = '300px'
fig.tight_layout()

In [12]:
def update_ptycho_panel(ptycho_recon):
    """Refresh the ptychography CTF image and radial profile from an object estimate.

    Takes ``ctf.compute_ctf`` of the phase of ``ptycho_recon`` and shows it
    (fftshifted and histogram-scaled) in ``im_ctf_ptycho``. It also updates the
    radial profile ``plot_ctf_ptycho``, makes both artists fully opaque, and
    redraws ``fig``.
    """
    ctf_ptycho = ctf.compute_ctf(np.angle(ptycho_recon))

    scaled_ctf = ctf.histogram_scaling(np.fft.fftshift(ctf_ptycho), normalize=False)
    im_ctf_ptycho.set_data(scaled_ctf)

    _, radial_profile = ctf.radially_average_ctf(ctf_ptycho, (sampling, sampling))
    plot_ctf_ptycho.set_ydata(radial_profile)

    for artist in (im_ctf_ptycho, plot_ctf_ptycho):
        artist.set_alpha(1)

    fig.canvas.draw()
    return None

In [13]:
def compute_ptycho_updates(
    batch_size,
    iterations,
    pbars,
):
    """Continue the ptychographic reconstruction from the estimate in ``ptycho_recon[0]``.

    Runs ``ptycho_reconstruction`` on the currently simulated amplitudes and probe
    (``amplitudes_probe``). The result is stored back into ``ptycho_recon[0]``.
    """
    current_estimate = ptycho_recon[0]
    ptycho_recon[0] = ptycho_reconstruction(
        amplitudes_probe[0],
        row,
        col,
        positions,
        current_estimate,
        amplitudes_probe[1],
        pbars,
        batch_size=batch_size,
        iterations=iterations,
    )
    return None

In [14]:
def update_icom(
    *args,
):
    """Recompute the pixelated iCOM image from ``intensities`` and refresh its panels.

    Positional arguments (e.g. from a widget callback) are ignored. Updates
    ``im_ctf_dpc`` and ``plot_ctf_dpc``, makes them opaque again, and redraws ``fig``.
    """
    com_x, com_y = compute_pixelated_com(intensities[0], intensities[1], sx, sy, kxa, kya)
    icom = integrate_com(com_x, com_y, kx_op, ky_op)

    ctf_dpc = ctf.compute_ctf(icom)
    _, I_bins_dpc = ctf.radially_average_ctf(ctf_dpc, (sampling, sampling))

    im_ctf_dpc.set_data(ctf.histogram_scaling(np.fft.fftshift(ctf_dpc), normalize=False))
    plot_ctf_dpc.set_ydata(I_bins_dpc)

    # undo the fading applied when the inputs went stale
    for artist in (im_ctf_dpc, plot_ctf_dpc):
        artist.set_alpha(1)

    fig.canvas.draw()
    return None

In [15]:
style = {'description_width': 'initial'}
layout = ipywidgets.Layout(width="320px",height="30px")
smaller_layout = ipywidgets.Layout(width="160px",height="30px")
kwargs = {'style':style,'layout':layout,'continuous_update':False}

# Ptychography batch sizes are the exact divisors of the number of scan positions.
# They are found with integer arithmetic instead of float division.
m = len(positions)  # equals n**2 for scan_step_size = 1
candidate_batch_sizes = np.arange(1, m + 1)
batch_sizes = candidate_batch_sizes[m % candidate_batch_sizes == 0]
# Default to the divisor closest to the simulation batch size. For n = 96 this is
# 1024; unlike a fixed index, it cannot go out of range when there are few divisors.
default_batch_size = batch_sizes[np.argmin(np.abs(batch_sizes - 1024))]
batch_size_slider = ipywidgets.SelectionSlider(
    options=batch_sizes,
    value=default_batch_size,
    description= "batch size",
    **kwargs
)

iterations_slider = ipywidgets.IntSlider(
    value = 8, min = 1, max = 64, step = 1,
    description = "(outer loop) iterations",
    **kwargs
)

iterate_button = ipywidgets.Button(
    description="reconstruct (expensive)",
    layout=smaller_layout,
)

reset_button = ipywidgets.Button(
    description="reset object",
    layout=smaller_layout,
)

defocus_slider = ipywidgets.IntSlider(
    value = 0, min = -n, max = n, step = 1,
    description = "negative defocus, $C_{1,0}$ [Å]",
    **kwargs
)

simulate_button = ipywidgets.Button(
    description='simulate (expensive)',
    layout=ipywidgets.Layout(width="160px",height="30px")
)

simulation_pbar = tqdm(total=9,display=False)
simulation_pbar_wrapper = ipywidgets.HBox(simulation_pbar.container.children[:2],layout=ipywidgets.Layout(width="160px"))

def defocus_wrapper(*args):
    """Mark the simulation and reconstruction as stale after a defocus change.

    Resets the ptychographic object, flags the simulate button, fades all CTF
    artists to alpha 0.25, and clears the simulation progress bar.
    """
    reset_wrapper()
    simulate_button.button_style = 'warning'
    for artist in (im_ctf_dpc, im_ctf_ptycho, plot_ctf_dpc, plot_ctf_ptycho):
        artist.set_alpha(0.25)
    simulation_pbar.reset()

defocus_slider.observe(defocus_wrapper,names='value')

def simulate_wrapper(*args):
    """Re-simulate the 4D-STEM data at the selected defocus and refresh the iCOM panels.

    All controls are disabled during the (expensive) simulation and are re-enabled
    even if it raises, so the app cannot get stuck. On success:
    - the reconstruct button is flagged as stale,
    - the reconstruction progress bars are cleared,
    - the simulate button's warning style is removed.
    """
    disable_all(True)
    try:
        amplitudes_probe[:] = simulate_intensities(
            defocus=defocus_slider.value,
            batch_size=1024,
            pbar=simulation_pbar
        )

        intensities[0] = amplitudes_probe[0].reshape((sx,sy,n,n))**2 / n**2
        intensities[1] = intensities[0].sum((-1,-2))

        update_icom()
    finally:
        disable_all(False)
    iterate_button.button_style = 'warning'
    outer_reconstruct_pbar.reset()
    inner_reconstruct_pbar.reset()
    simulate_button.button_style = ''
simulate_button.on_click(simulate_wrapper)

def reset_wrapper(*args):
    """Restore the ptychographic object to a uniform all-ones guess and refresh the view.

    Also flags the reconstruct button as stale and clears both reconstruction
    progress bars.
    """
    fresh_object = np.ones((n,n),dtype=np.complex128)
    ptycho_recon[0] = fresh_object
    update_ptycho_panel(fresh_object)
    iterate_button.button_style = 'warning'
    for reconstruct_pbar in (outer_reconstruct_pbar, inner_reconstruct_pbar):
        reconstruct_pbar.reset()

reset_button.on_click(reset_wrapper)

def disable_all(boolean):
    """Disable (True) or enable (False) every interactive control of the app.

    Parameters
    ----------
    boolean : bool
        Value assigned to each control's ``disabled`` trait.
    """
    # Only widgets that have a ``disabled`` trait are listed. The progress-bar HBox
    # (simulation_pbar_wrapper) has no such trait, so assigning it had no effect.
    for control in (
        batch_size_slider,
        iterations_slider,
        reset_button,
        iterate_button,
        defocus_slider,
        simulate_button,
    ):
        control.disabled = boolean

def click_wrapper(*args):
    """Run the selected number of ptychography iterations (reconstruct button callback).

    Controls are disabled while the reconstruction runs and are re-enabled even if it
    raises, so the app never stays locked. The button's warning style is cleared only
    after a successful run.
    """
    disable_all(True)
    try:
        compute_ptycho_updates(
            batch_size=batch_size_slider.value,
            iterations=iterations_slider.value,
            pbars=(outer_reconstruct_pbar,inner_reconstruct_pbar),
        )
    finally:
        disable_all(False)
    iterate_button.button_style = ''

iterate_button.on_click(click_wrapper)
# Progress bars for the reconstruction: outer = iterations, inner = batches.
# Each is shown as a fixed-width row holding the first two child widgets of its tqdm container.
outer_reconstruct_pbar = tqdm(total=4,display=False)
inner_reconstruct_pbar = tqdm(total=9,display=False)
outer_reconstruct_pbar_wrapper, inner_reconstruct_pbar_wrapper = (
    ipywidgets.HBox(reconstruct_pbar.container.children[:2], layout=ipywidgets.Layout(width="160px"))
    for reconstruct_pbar in (outer_reconstruct_pbar, inner_reconstruct_pbar)
)

In [18]:
#| label: app:pixelated_icom_and_ptycho
# iCOM and Iterative Ptychography with Pixelated Detectors
# controls: simulation row, divider, reconstruction settings, reconstruction actions
app_controls = ipywidgets.VBox(
    [
        ipywidgets.HBox([defocus_slider, simulate_button, simulation_pbar_wrapper]),
        ipywidgets.HTML("<hr>", layout=ipywidgets.Layout(width="640px")),
        ipywidgets.HBox([batch_size_slider, iterations_slider]),
        ipywidgets.HBox(
            [reset_button, iterate_button, outer_reconstruct_pbar_wrapper, inner_reconstruct_pbar_wrapper]
        ),
    ]
)
display(ipywidgets.VBox([app_controls, fig.canvas]))

VBox(children=(VBox(children=(HBox(children=(IntSlider(value=0, continuous_update=False, description='defocus …