---
title: iCOM Contrast Transfer
authors: [Julie Marie Bekkevold, Georgios Varnavides]
date: 2024/09/30
---

In [1]:
%matplotlib widget

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import ctf # custom plotting / utils

import ipywidgets
from IPython.display import display

## 4D STEM Simulation of white noise object

In [2]:
# --- simulation parameters -------------------------------------------------
n = 96          # grid points per side (object, probe and detector arrays are n x n)
q_probe = 1     # probe aperture cutoff in reciprocal space
q_max = 2       # largest spatial frequency; fixes the real-space sampling below

# real-space pixel size, and the reciprocal-space pixel size of an n-point grid
sampling = 1 / q_max / 2
reciprocal_sampling = 1 / n / sampling

# --- scan grid and object strength -----------------------------------------
scan_step_size = 1
sx = n // scan_step_size    # number of scan positions along x ...
sy = sx                     # ... and along y
phi0 = 1.0                  # Fourier amplitude of the white-noise object

In [3]:
def white_noise_object_2D(n, phi0, rng=None):
    """
    Create a real-valued (n, n) array whose 2D FFT has constant amplitude
    `phi0` and random phase.

    Random phases are drawn on the Fourier grid and anti-symmetrised so that
    phase(-k) = -phase(k). The resulting spectrum is Hermitian, hence its
    inverse FFT is real up to floating-point error.

    Parameters
    ----------
    n : int
        Number of grid points along each axis.
    phi0 : float
        Amplitude assigned to every Fourier coefficient.
    rng : None, int or numpy.random.Generator, optional
        Source of randomness. With the default ``None`` the legacy global
        NumPy random state is used (so ``np.random.seed`` still controls the
        result). Anything else is passed to ``np.random.default_rng``, which
        makes the call reproducible without touching global state.

    Returns
    -------
    numpy.ndarray
        Real array of shape (n, n).
    """

    evenQ = n%2 == 0

    # indices of the positive frequencies and of their -k partners (same order)
    pos_ind = np.arange(1,(n if evenQ else n+1)//2)
    neg_ind = np.flip(np.arange(n//2+1,n))

    # random phase (in units of 2*pi, see the exponential below)
    if rng is None:
        arr = np.random.randn(n,n)
    else:
        arr = np.random.default_rng(rng).standard_normal((n,n))

    # enforce phase(k) = -phase(-k)
    # top-left // bottom-right
    arr[pos_ind[:,None],pos_ind[None,:]] = -arr[neg_ind[:,None],neg_ind[None,:]]
    # bottom-left // top-right
    arr[pos_ind[:,None],neg_ind[None,:]] = -arr[neg_ind[:,None],pos_ind[None,:]]
    # kx=0
    arr[0,pos_ind] = -arr[0,neg_ind]
    # ky=0
    arr[pos_ind,0] = -arr[neg_ind,0]

    # components without a distinct -k partner get zero phase, i.e. they stay
    # real with amplitude phi0 (they are NOT removed from the spectrum)
    if evenQ:
        arr[n//2,:] = 0 # highest spatial freq (Nyquist row)
        arr[:,n//2] = 0 # highest spatial freq (Nyquist column)

    arr[0,0] = 0 # DC component

    # fourier-array: constant amplitude phi0, phase 2*pi*arr
    arr = np.exp(2j*np.pi*arr)*phi0

    # inverse FFT; the imaginary part is only floating-point error, drop it
    arr = np.fft.ifft2(arr).real

    return arr

In [4]:
# potential
# Seed NumPy's global random state right before the only random draw in the
# notebook, so the white-noise object (and everything derived from it) is
# identical on every Restart & Run All.
np.random.seed(0)
potential = white_noise_object_2D(n,phi0)
# pure phase object: exp(i * potential)
complex_obj = np.exp(1j*potential)

In [5]:
# probe: soft-edged circular aperture built in Fourier space
qx = qy = np.fft.fftfreq(n, sampling)
q2 = qx[:, None]**2 + qy[None, :]**2
q = np.sqrt(q2)

# linear ramp one reciprocal pixel wide, centred on q_probe and clipped to
# [0, 1]; the square root turns it into an amplitude
probe_array_fourier = np.clip((q_probe - q)/reciprocal_sampling + 0.5, 0, 1)
probe_array_fourier = np.sqrt(probe_array_fourier)

# normalise so the squared amplitudes sum to one
probe_array_fourier = probe_array_fourier / np.sqrt(np.sum(np.abs(probe_array_fourier)**2))

# forward FFT of the aperture, scaled by 1/n
probe_array = np.fft.fft2(probe_array_fourier) / n

In [6]:
# Scan positions: one probe position every `scan_step_size` pixels along each axis.
x = y = np.arange(0.,n,scan_step_size)
xx, yy = np.meshgrid(x,y,indexing='ij')
# (number of positions, 2) array of (x, y) coordinates
positions = np.stack((xx.ravel(),yy.ravel()),axis=-1)
# NOTE(review): presumably the object-array indices of the (n, n) patch at each
# scan position — confirm against ctf.return_patch_indices
row, col = ctf.return_patch_indices(positions,(n,n),(n,n))

# The simulated output is reshaped to (sx, sy, n, n), squared and divided by n**2.
# NOTE(review): presumably simulate_data returns real diffraction amplitudes,
# so the square is the intensity — confirm against ctf.simulate_data
intensities = ctf.simulate_data(
    complex_obj,
    probe_array,
    row,
    col,
).reshape((sx,sy,n,n))**2 / n**2

# sum over the last two axes: one total per scan position, shape (sx, sy)
intensities_sum = intensities.sum((-1,-2))

## Define Virtual Detectors

In [7]:
def annular_segmented_detectors(
    gpts,
    sampling,
    n_angular_bins,
    rotation_offset = 0,
    inner_radius = 0,
    outer_radius = np.inf,
):
    """
    Build binary masks for an annular detector split into equal angular segments.

    Parameters
    ----------
    gpts : tuple of int
        Number of detector pixels (nx, ny).
    sampling : tuple of float
        Sample spacings (sx, sy) handed to ``np.fft.fftfreq`` to build the
        detector frequency axes.
    n_angular_bins : int
        Number of angular segments.
    rotation_offset : float, optional
        Angle in radians added to the polar angle before binning.
    inner_radius, outer_radius : float, optional
        The annulus keeps pixels with ``inner_radius <= |k| < outer_radius``.

    Returns
    -------
    list of numpy.ndarray
        ``n_angular_bins`` integer 0/1 masks of shape (nx, ny), fftshift-ed so
        that zero frequency sits at the array centre.
    """
    nx, ny = gpts
    dx, dy = sampling

    # frequency axes, already shaped for broadcasting to (nx, ny)
    freq_x = np.fft.fftfreq(nx, dx)[:, None]
    freq_y = np.fft.fftfreq(ny, dy)[None, :]

    radius = np.sqrt(freq_x**2 + freq_y**2)
    in_annulus = (radius >= inner_radius) & (radius < outer_radius)

    # polar angle wrapped to [0, 2*pi), then mapped to segment labels 1..n_angular_bins
    two_pi = 2 * np.pi
    angle = (np.arctan2(freq_y, freq_x) + rotation_offset) % two_pi
    segment = np.floor(n_angular_bins * (angle / two_pi)) + 1

    # label 0 marks pixels outside the annulus
    segment = segment * in_annulus.astype("int")

    masks = []
    for label in range(1, n_angular_bins + 1):
        masks.append(np.fft.fftshift((segment == label).astype("int")))

    return masks

## Calculate COM and iCOM

In [8]:
def compute_com_using_virtual_detectors(
    corner_centered_intensities,
    center_centered_masks,
    corner_centered_intensities_sum,
    sx,sy,
    kxa,kya,
):
    """
    Estimate the centre of mass (COM) of each diffraction pattern from a set
    of virtual detector segments.

    Every segment contributes its share of the total intensity, placed at the
    mean (kx, ky) coordinate of the segment's pixels.

    Parameters
    ----------
    corner_centered_intensities : numpy.ndarray
        Array indexed as [scan_x, scan_y, k_x, k_y] with zero frequency at
        index [0, 0] of the last two axes.
    center_centered_masks : sequence of numpy.ndarray
        Detector masks with zero frequency at the array centre (fftshift-ed);
        non-zero entries mark the pixels of each segment.
    corner_centered_intensities_sum : numpy.ndarray
        Total intensity per scan position, shape (sx, sy).
    sx, sy : int
        Shape of the returned COM maps.
    kxa, kya : numpy.ndarray
        Corner-centred kx / ky coordinate of every detector pixel.

    Returns
    -------
    com_x, com_y : numpy.ndarray
        COM components, each of shape (sx, sy).
    """

    # bring the masks to the same corner-centred layout as the intensities
    masks = np.fft.ifftshift(np.asarray(center_centered_masks),axes=(-1,-2))

    com_x = np.zeros((sx,sy))
    com_y = np.zeros((sx,sy))

    # NOTE: kxa / kya are used as passed in. They were previously rebuilt from
    # the module-level `n` and `sampling`, which silently ignored the
    # arguments and tied this function to notebook globals.
    for mask in masks:
        rows, cols = np.where(mask)
        if rows.size == 0:
            # a segment with no pixels collects no signal; skipping it avoids
            # np.mean([]) -> NaN, which would poison the whole COM map
            continue
        fraction = corner_centered_intensities[:,:,rows,cols].sum(-1) / corner_centered_intensities_sum
        com_x += fraction * np.mean(kxa[rows,cols])
        com_y += fraction * np.mean(kya[rows,cols])

    return com_x, com_y

def integrate_com(
    com_x,
    com_y,
    kx_op,
    ky_op,
):
    """
    Integrate a centre-of-mass vector field into a scalar image (iCOM).

    Each COM component is Fourier transformed and multiplied by its
    Fourier-space operator; the two products are summed and inverse
    transformed.

    Parameters
    ----------
    com_x, com_y : numpy.ndarray
        2D COM components.
    kx_op, ky_op : numpy.ndarray
        Fourier-space operators, broadcastable against the FFT of the COM maps.

    Returns
    -------
    numpy.ndarray
        Real part of the inverse FFT of the combined spectrum.
    """

    spectrum_x = np.fft.fft2(com_x) * kx_op
    spectrum_y = np.fft.fft2(com_y) * ky_op
    return np.fft.ifft2(spectrum_x + spectrum_y).real

## Compute CTFs

In [9]:
# Fourier-space coordinates, integration operators and the inverse-error weight
# (frequency axes are cast to float32, so everything derived from them is single precision)
kx = ky = np.fft.fftfreq(n,sampling).astype(np.float32)
kxa, kya = np.meshgrid(kx, ky, indexing='ij')

k2 = kxa**2 + kya**2
# |k| must be taken BEFORE k2[0, 0] is overwritten, so that k[0, 0] stays 0
k = np.sqrt(k2)
# dividing by inf makes the operators exactly zero at DC instead of 0/0
k2[0, 0] = np.inf
# operators applied to the FFT of com_x / com_y in integrate_com: -i * k / |k|^2
kx_op = -1.0j * kxa / k2
ky_op = -1.0j * kya / k2
# weight used below to turn a CTF into the "SNR" panels: |k| * pi / sqrt(2)
inverse_error = (k*np.pi/np.sqrt(2))

# analytic CTF: inverse FFT of the squared modulus of the FFT of the aperture
ctf_analytic = np.real(
    np.fft.ifft2(
        np.abs(
            np.fft.fft2(
                probe_array_fourier
            )
        )**2
    )
)

# radial profiles of the analytic CTF, without and with the inverse-error weight
q_bins_analytic, I_bins_analytic = ctf.radially_average_ctf(ctf_analytic,(sampling,sampling))
q_bins_analytic_snr, I_bins_analytic_snr = ctf.radially_average_ctf(ctf_analytic*inverse_error,(sampling,sampling))
# centre-centred copy of the analytic CTF
ctf_analytic_array = np.fft.fftshift(ctf_analytic)

In [10]:
# default detector: four segments between q_probe/2 and 1.05*q_probe, no rotation
virtual_masks_annular = annular_segmented_detectors(
    (n, n),
    (sampling, sampling),
    4,
    rotation_offset=0,
    inner_radius=q_probe/2,
    outer_radius=q_probe*1.05,
)

# centre of mass seen through the segmented detector
com_x, com_y = compute_com_using_virtual_detectors(
    corner_centered_intensities=intensities,
    center_centered_masks=virtual_masks_annular,
    corner_centered_intensities_sum=intensities_sum,
    sx=sx, sy=sy,
    kxa=kxa, kya=kya,
)

# integrate to iCOM, then estimate its contrast transfer function
icom_annular = integrate_com(com_x, com_y, kx_op=kx_op, ky_op=ky_op)
ctf_annular = ctf.compute_ctf(icom_annular)

# radial profiles of the CTF, without and with the inverse-error weight
q_bins_annular, I_bins_annular = ctf.radially_average_ctf(ctf_annular, (sampling, sampling))
q_bins_annular_snr, I_bins_annular_snr = ctf.radially_average_ctf(
    ctf_annular*inverse_error, (sampling, sampling)
)

In [85]:
# Build the static figure once; update_figure later only swaps the data of
# the artists created here (im_detector, im_ctf, im_snr, plot_ctf, plot_snr).
# NOTE: the order in which axes are added matters — fig.axes[:5] are the image
# panels (top row) and fig.axes[5:] the radial profiles (bottom row).

# create the figure with interactive mode off; it is shown via fig.canvas in a later cell
with plt.ioff():
    dpi=72
    fig = plt.figure(figsize=(675/dpi,240/dpi),dpi=dpi)

# top row: five image panels; bottom row: radial profiles (3:1 height ratio)
gs = GridSpec(2, 5,height_ratios=[3,1],figure=fig)

# --- top row: image panels -------------------------------------------------
ax_detector = fig.add_subplot(gs[0, 0])
im_detector = ax_detector.imshow(ctf.combined_images_rgb(virtual_masks_annular))
ctf.add_scalebar(ax_detector,length=n//4,sampling=reciprocal_sampling,units=r'$q_{\mathrm{probe}}$')

# analytic panels show the [:n//2, :n//2] quadrant flipped along both axes,
# annular panels the same quadrant flipped along axis 0 only
ax_ctf_analytic = fig.add_subplot(gs[0, 1])
ax_ctf_analytic.imshow(ctf.histogram_scaling(np.flip(ctf_analytic[:n//2,:n//2],axis=None),normalize=True),cmap='turbo')

ax_ctf_annular = fig.add_subplot(gs[0, 2])
im_ctf = ax_ctf_annular.imshow(ctf.histogram_scaling(np.flip(ctf_annular[:n//2,:n//2],axis=0),normalize=True),cmap='turbo')

ax_snr_analytic = fig.add_subplot(gs[0, 3])
ax_snr_analytic.imshow(ctf.histogram_scaling(np.flip((ctf_analytic*inverse_error)[:n//2,:n//2],axis=None),normalize=True),cmap='turbo')

ax_snr_annular = fig.add_subplot(gs[0, 4])
im_snr = ax_snr_annular.imshow(ctf.histogram_scaling(np.flip((ctf_annular*inverse_error)[:n//2,:n//2],axis=0),normalize=True),cmap='turbo')

# --- bottom row: radial profiles -------------------------------------------
# analytic profiles are drawn against -q (axis runs from -q_max to 0),
# annular profiles against +q (axis runs from 0 to q_max)
ax_ctf_analytic_rad = fig.add_subplot(gs[1, 1])
ax_ctf_analytic_rad.plot(-q_bins_analytic,I_bins_analytic,color='red')
ax_ctf_analytic_rad.set_xlim([-q_max,0])
ax_ctf_analytic_rad.set_ylim([0,1])

ax_ctf_annular_rad = fig.add_subplot(gs[1, 2])
plot_ctf = ax_ctf_annular_rad.plot(q_bins_annular,I_bins_annular,color='red')[0]
ax_ctf_annular_rad.set_xlim([0,q_max])
ax_ctf_annular_rad.set_ylim([0,1])

ax_snr_analytic_rad = fig.add_subplot(gs[1, 3])
ax_snr_analytic_rad.plot(-q_bins_analytic_snr,I_bins_analytic_snr,color='red')
ax_snr_analytic_rad.set_xlim([-q_max,0])
ax_snr_analytic_rad.set_ylim([0,1])

ax_snr_annular_rad = fig.add_subplot(gs[1, 4])
plot_snr = ax_snr_annular_rad.plot(q_bins_annular_snr,I_bins_annular_snr,color='red')[0]
ax_snr_annular_rad.set_xlim([0,q_max])
ax_snr_annular_rad.set_ylim([0,1])

# image panels: no ticks, one title each
for ax, title in zip(fig.axes[:5],["detector geometry","analytical CTF","annular CTF","analytical SNR","annular SNR"]):
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(title)

# profile panels: odd index -> annular (+q axis), even index -> analytic (-q axis);
# the dashed vlines mark the inner / outer detector radii and are the first
# collection of each axis, which update_figure removes and redraws
for index, ax in enumerate(fig.axes[5:]):
    ax.set_yticks([])
    if index % 2:
        ax.set_xticks([0,q_probe,q_max])
        ax.set_xticklabels([0,1,2])
        ax.vlines([q_probe/2,q_probe*1.05,],0,2,colors='k',linestyles='--',linewidth=1,)
    else:
        ax.set_xticks([-q_max,-q_probe,0])
        ax.set_xticklabels([2,1,0])
        ax.vlines([-q_probe/2,-q_probe*1.05,],0,2,colors='k',linestyles='--',linewidth=1,)
    ax.set_xlabel(r"spatial frequency, $q/q_{\mathrm{probe}}$")


# ipympl canvas chrome: fixed size, toolbar at the bottom, no header / footer
fig.canvas.resizable = False
fig.canvas.header_visible = False
fig.canvas.footer_visible = False
fig.canvas.toolbar_visible = True
fig.canvas.toolbar_position = 'bottom'
fig.canvas.layout.width = '675px'
fig.canvas.layout.height = '280px'
gs.tight_layout(fig)

In [86]:
def update_figure(
    inner_collection_angle,
    outer_collection_angle,
    number_of_segments,
    rotation_offset,
):
    """
    Recompute the annular-detector iCOM CTF and refresh the existing figure.

    Parameters
    ----------
    inner_collection_angle, outer_collection_angle : float
        Inner / outer detector radius, passed to annular_segmented_detectors.
    number_of_segments : int
        Number of angular detector segments.
    rotation_offset : float
        Detector rotation in degrees (converted to radians here).

    Returns
    -------
    None
        The module-level artists (im_detector, im_ctf, im_snr, plot_ctf,
        plot_snr) and the dashed radius markers are updated in place.
    """
    pixel_size = (sampling, sampling)
    half = n // 2

    # detector -> COM -> iCOM -> CTF
    masks = annular_segmented_detectors(
        gpts=(n, n),
        sampling=pixel_size,
        n_angular_bins=number_of_segments,
        inner_radius=inner_collection_angle,
        outer_radius=outer_collection_angle,
        rotation_offset=np.deg2rad(rotation_offset),
    )
    shift_x, shift_y = compute_com_using_virtual_detectors(
        intensities, masks, intensities_sum, sx, sy, kxa, kya,
    )
    ctf_map = ctf.compute_ctf(integrate_com(shift_x, shift_y, kx_op, ky_op))

    # radial profiles (only the intensities are needed; the line x-data is kept)
    _, ctf_profile = ctf.radially_average_ctf(ctf_map, pixel_size)
    _, snr_profile = ctf.radially_average_ctf(ctf_map*inverse_error, pixel_size)

    # images
    im_detector.set_data(ctf.combined_images_rgb(masks))
    im_ctf.set_data(
        ctf.histogram_scaling(np.flip(ctf_map[:half, :half], axis=0), normalize=True)
    )
    im_snr.set_data(
        ctf.histogram_scaling(np.flip((ctf_map*inverse_error)[:half, :half], axis=0), normalize=True)
    )

    # line plots
    plot_ctf.set_ydata(ctf_profile)
    plot_snr.set_ydata(snr_profile)

    # radius markers: drop the old vlines collection and draw a new one,
    # at +q on the annular (odd) panels and at -q on the analytic (even) panels
    for index, ax in enumerate(fig.axes[5:]):
        ax.collections[0].remove()
        sign = 1 if index % 2 else -1
        ax.vlines(
            [sign*inner_collection_angle, sign*outer_collection_angle],
            0, 2, colors='k', linestyles='--', linewidth=1,
        )

    fig.canvas.draw_idle()
    return None


In [90]:
# appearance and behaviour shared by all four sliders
style = dict(description_width='initial')
layout = ipywidgets.Layout(width="340px", height="30px")
kwargs = dict(style=style, layout=layout, continuous_update=False)

# detector radii
inner_collection_angle_slider = ipywidgets.FloatSlider(
    description = r"Inner collection angle [$q_{\mathrm{probe}}$]",
    value = q_probe/2,
    min = 0,
    max = q_probe,
    step = q_probe/20,
    **kwargs
)

outer_collection_angle_slider = ipywidgets.FloatSlider(
    description = r"Outer collection angle [$q_{\mathrm{probe}}$]",
    value = q_probe*1.05,
    min = q_probe/10,
    max = q_max,
    step = q_probe/20,
    **kwargs
)

# detector segmentation and orientation
number_of_segments_slider = ipywidgets.IntSlider(
    description = "Number of segments",
    value = 4,
    min = 3,
    max = 16,
    step = 1,
    **kwargs
)

rotation_offset_slider = ipywidgets.IntSlider(
    description = "Rotation offset [°]",
    value = 0,
    min = 0,
    max = 180/4,
    step = 1,
    **kwargs
)

def update_outer_collection_angle(change):
    """Set the outer-angle slider's minimum to 1.05 times the new inner angle."""
    outer_collection_angle_slider.min = change['new']*1.05

# run the callback whenever the inner-angle slider value changes
inner_collection_angle_slider.observe(update_outer_collection_angle, names='value')

def update_rotation_offset_range(change):
    """Set the rotation slider's maximum to 180 divided by the new segment count."""
    rotation_offset_slider.max = 180/change['new']

# run the callback whenever the segment-count slider value changes
number_of_segments_slider.observe(update_rotation_offset_range, names='value')

In [91]:
#| label: app:annular_segmented_detectors
# Annular Segmented Detectors
# connect the sliders to update_figure; the returned output widget is not
# displayed because update_figure redraws the existing figure in place
ipywidgets.interactive_output(
    update_figure,
    dict(
        inner_collection_angle=inner_collection_angle_slider,
        outer_collection_angle=outer_collection_angle_slider,
        number_of_segments=number_of_segments_slider,
        rotation_offset=rotation_offset_slider,
    ),
)

# layout: two rows of sliders stacked above the figure canvas
display(
    ipywidgets.VBox([
        ipywidgets.VBox([
            ipywidgets.HBox([inner_collection_angle_slider, outer_collection_angle_slider]),
            ipywidgets.HBox([number_of_segments_slider, rotation_offset_slider]),
        ]),
        fig.canvas,
    ])
)

VBox(children=(VBox(children=(HBox(children=(FloatSlider(value=0.5, continuous_update=False, description='Inne…