---
title: iCOM Contrast Transfer
authors: [Julie Marie Bekkevold, Georgios Varnavides]
date: 2024/09/30
---

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import py4DSTEM
import ipywidgets
from IPython.display import display
from emdfile import tqdmnd

## 4D-STEM simulation of a white-noise object

In [None]:
# Define parameters
n = 128  # grid size (pixels per side) used for the detector and, via scan_step_size, the scan
energy = 300e3  # beam energy passed to electron_wavelength_angstrom (300 kV)
k_max = 2  # highest spatial frequency represented on the grid
sampling = 1 / (2 * k_max)  # real-space pixel size chosen so that Nyquist sits at k_max

# Setting alpha [mrad] = 1e3 * lambda [A] places the aperture edge at k = alpha / lambda = 1,
# so spatial frequencies in this notebook are effectively in units of the aperture radius.
# At 300 kV this is ~20 mrad.
convergence_angle = py4DSTEM.process.utils.electron_wavelength_angstrom(energy) * 1e3
defocus = 0

scan_step_size = 1  # scan step in grid pixels (simulation)
phi0 = 1.0  # Fourier amplitude used by white_noise_object_2D

In [None]:
def white_noise_object_2D(n, phi0, xp=np):
    """Return a real ``n x n`` array whose 2D FFT has modulus ``phi0`` and random phase.

    Phases are drawn as ``2*pi*randn`` samples. They are then made odd under
    k -> -k (phase(-k) = -phase(k)), which is the Hermitian symmetry that makes
    the inverse FFT real. Frequencies that map onto themselves get zero phase:
    the DC term and, for even ``n``, the Nyquist row and column.

    Parameters
    ----------
    n : int
        Side length of the square output.
    phi0 : float
        Constant modulus of every Fourier coefficient.
    xp : module
        Array module providing ``random``, ``fft``, ``exp``, ``ix_`` (default numpy).

    Returns
    -------
    array of shape (n, n)
        Real part of the inverse FFT. The imaginary part is only round-off.
    """
    # Positive frequency indices 1 .. ceil(n/2)-1 and their negative partners n-k.
    # (n + 1) // 2 equals n // 2 for even n, so one expression covers both parities.
    pos = xp.arange(1, (n + 1) // 2)
    neg = n - pos

    phases = xp.random.randn(n, n)

    # Copy the negated "negative-frequency" phases onto their k -> -k partners.
    phases[xp.ix_(pos, pos)] = -phases[xp.ix_(neg, neg)]
    phases[xp.ix_(pos, neg)] = -phases[xp.ix_(neg, pos)]
    phases[0, pos] = -phases[0, neg]
    phases[pos, 0] = -phases[neg, 0]

    # Self-conjugate frequencies must be real, i.e. have zero phase.
    if n % 2 == 0:
        phases[n // 2, :] = 0
        phases[:, n // 2] = 0
    phases[0, 0] = 0

    spectrum = xp.exp(2j * np.pi * phases) * phi0

    # Drop the round-off imaginary part left by the inverse transform.
    return xp.fft.ifft2(spectrum).real

In [None]:
# Scan size: number of probe positions along each axis for the chosen step
sx = n // scan_step_size
sy = sx

# Load the precomputed amplitudes for the white-noise object and square them.
# NOTE(review): `**2` assumes real-valued amplitudes; if the file held complex
# wavefunctions, np.abs(...)**2 would be needed -- confirm against the data file.
amplitudes = np.load('./data/white_noise_amplitudes.npy')
intensities = amplitudes**2

## Define Virtual Detectors

In [None]:
def annular_segmented_detectors(
    energy,
    gpts,
    sampling,
    n_angular_bins,
    rotation_offset = 0,
    inner_radius = 0,
    outer_radius = np.inf,
):
    """Build masks for an annular detector split into equal angular segments.

    Parameters
    ----------
    energy : float
        Beam energy, passed to ``electron_wavelength_angstrom``.
    gpts : tuple of int
        Detector grid size ``(nx, ny)``.
    sampling : tuple of float
        Real-space pixel size ``(sx, sy)``. Together with ``gpts`` it fixes the
        reciprocal-space grid via ``np.fft.fftfreq``.
    n_angular_bins : int
        Number of angular segments.
    rotation_offset : float, optional
        Angle in radians added to every pixel's azimuth before binning.
    inner_radius, outer_radius : float, optional
        Collection-angle limits in mrad. A pixel is kept when
        ``inner <= alpha < outer``.

    Returns
    -------
    list of ndarray
        ``n_angular_bins`` integer (0/1) masks of shape ``(nx, ny)``. Each is
        fftshifted so that zero scattering angle sits at the array centre.
    """
    nx, ny = gpts
    sx, sy = sampling
    wavelength = py4DSTEM.process.utils.electron_wavelength_angstrom(energy)

    # Scattering angles (rad) on the corner-centred FFT grid: alpha = k * lambda
    alpha_x = np.fft.fftfreq(nx, sx) * wavelength
    alpha_y = np.fft.fftfreq(ny, sy) * wavelength

    alpha = np.sqrt(alpha_x[:, None]**2 + alpha_y[None, :]**2)
    # inner/outer radii are given in mrad; convert to rad
    radial_mask = (inner_radius * 1e-3 <= alpha) & (alpha < outer_radius * 1e-3)

    theta = (np.arctan2(alpha_y[None, :], alpha_x[:, None]) + rotation_offset) % (2 * np.pi)
    # Segment index 0 .. n_angular_bins-1. The float modulo (or the scaled ratio)
    # can round up to exactly 2*pi / n_angular_bins. That would give index
    # n_angular_bins and silently drop the pixel from every mask, so clip it.
    segment_index = np.minimum(
        np.floor(n_angular_bins * (theta / (2 * np.pi))),
        n_angular_bins - 1,
    )
    # 1-based segment labels; 0 marks pixels outside the annulus
    labels = (segment_index + 1) * radial_mask.astype("int")

    return [
        np.fft.fftshift((labels == i).astype("int"))
        for i in range(1, n_angular_bins + 1)
    ]

In [None]:
# Slider widgets for the interactive detector app (collection angles in mrad)
style = {
    'description_width': 'initial',
}

layout = ipywidgets.Layout(width="250px",height="30px")

inner_collection_angle_slider = ipywidgets.FloatSlider(
    value = convergence_angle/2, min = 0, max = convergence_angle,
    step = 0.1,
    description = "Inner collection angle",
    style = style,
    layout = layout,
    continuous_update=False,
)

outer_collection_angle_slider = ipywidgets.FloatSlider(
    value = convergence_angle,
    min = 1,
    max = convergence_angle*4,
    step = 0.1,
    description = "Outer collection angle",
    style = style,
    layout = layout,
    continuous_update=False,
)

number_of_segments_slider = ipywidgets.IntSlider(
    value = 4, min = 3, max = 15, step = 1,
    description = "Number of segments",
    style = style,
    layout = layout,
    continuous_update=False,
)

def update_oca(*args):
    """Keep the outer collection angle at least 5% above the inner one.

    Registered as an observer on the inner slider's ``value``. ``*args``
    absorbs the change notification, which is not needed here.
    """
    outer_collection_angle_slider.min = inner_collection_angle_slider.value*1.05

inner_collection_angle_slider.observe(update_oca, 'value')
# observe() only reacts to future changes, so apply the constraint once now.
# Otherwise the outer slider starts with min=1 mrad, below the initial inner
# angle, and the user could select an empty annulus.
update_oca()

## Calculate COM and iCOM

In [None]:
def compute_com_using_virtual_detectors(
    corner_centered_intensities,
    center_centered_masks,
    xp=np
):
    """Estimate the per-probe-position centre of mass (CoM) from segmented detectors.

    Each detector segment is treated as a single pixel located at the mean
    spatial frequency of the pixels it covers. The CoM is the sum of those
    segment positions weighted by the segment signal, normalised by the total
    pattern intensity.

    Parameters
    ----------
    corner_centered_intensities : array
        4D intensities indexed ``[scan_x, scan_y, k_x, k_y]``, with zero
        frequency at the array corner (FFT ordering).
    center_centered_masks : sequence of 2D arrays
        Detector segment masks with zero frequency at the centre. They are
        ifftshifted here to match the intensities.
    xp : module
        Array module (``numpy`` or an API-compatible alternative).

    Returns
    -------
    com_x, com_y : arrays of shape ``corner_centered_intensities.shape[:2]``
    kxa, kya : float32 2D frequency grids of the detector, in FFT ordering,
        in inverse units of ``sampling``.

    Notes
    -----
    The detector pixel size is still read from the notebook-level ``sampling``.
    """
    intensities = xp.asarray(corner_centered_intensities)
    masks = xp.fft.ifftshift(xp.asarray(center_centered_masks), axes=(-1, -2))

    # Take shapes from the data itself rather than the notebook globals sx/sy/n
    scan_shape = intensities.shape[:2]
    det_nx, det_ny = intensities.shape[-2:]

    com_x = xp.zeros(scan_shape)
    com_y = xp.zeros(scan_shape)
    intensities_sum = intensities.sum((-1, -2))

    kx = xp.fft.fftfreq(det_nx, sampling).astype(xp.float32)
    ky = xp.fft.fftfreq(det_ny, sampling).astype(xp.float32)
    kxa, kya = xp.meshgrid(kx, ky, indexing='ij')

    for mask in masks:
        kxa_i, kya_i = xp.where(mask)
        if kxa_i.size == 0:
            # An empty segment has no centroid (mean of an empty selection is
            # NaN) and collects no signal, so it contributes nothing.
            continue
        patches = intensities[:, :, kxa_i, kya_i].sum(-1) / intensities_sum
        com_x += patches * xp.mean(kxa[kxa_i, kya_i])
        com_y += patches * xp.mean(kya[kxa_i, kya_i])

    return com_x, com_y, kxa, kya

def integrate_com_using_virtual_detectors(
    corner_centered_intensities,
    center_centered_masks,
    xp=np
):
    """Integrate the segmented-detector CoM field into an iCoM image.

    Obtains the CoM components and detector frequency grids from
    ``compute_com_using_virtual_detectors``. It then applies the Fourier-space
    operators ``-i*kx/|k|^2`` and ``-i*ky/|k|^2`` to their FFTs and returns the
    real part of the inverse FFT of the sum.

    The frequency grid returned for the detector is reused for the FFT of the
    scan-space CoM maps. This presumes that the scan and detector grids coincide;
    confirm this if ``scan_step_size`` or ``n`` change.
    """
    com_x, com_y, kxa, kya = compute_com_using_virtual_detectors(
        corner_centered_intensities, center_centered_masks, xp=xp
    )

    # 1/|k|^2, with the k = 0 term sent to zero (1/inf) because it is undefined
    k_squared = kxa**2 + kya**2
    k_squared[0, 0] = np.inf
    inv_k_squared = 1 / k_squared

    operator_x = -1.0j * kxa * inv_k_squared
    operator_y = -1.0j * kya * inv_k_squared

    term_x = xp.fft.fft2(com_x) * operator_x
    term_y = xp.fft.fft2(com_y) * operator_y
    reconstruction = xp.fft.ifft2(term_x + term_y)
    return xp.real(reconstruction)

## Compute CTFs

In [None]:
def compute_ctf(icom, xp=np):
    """Return ``|FFT(icom)|`` with the zero-frequency pixel replaced by a crude estimate.

    The DC value is overwritten with the mean of its eight nearest neighbours,
    wrapping around the corner-centred FFT array. Presumably this is because
    the iCoM integration removes the k = 0 term, leaving a value that is not
    representative.
    """
    modulus = xp.abs(xp.fft.fft2(icom))

    # Eight neighbours of (0, 0), listed row by row with periodic wrap-around
    offsets = [
        (dr, dc)
        for dr in (-1, 0, 1)
        for dc in (-1, 0, 1)
        if (dr, dc) != (0, 0)
    ]
    rows = [dr for dr, _ in offsets]
    cols = [dc for _, dc in offsets]

    modulus[0, 0] = modulus[rows, cols].mean()
    return modulus

def radially_average_ctf(ctf):
    """Azimuthally average a 2D CTF (unshifted FFT ordering, as from ``compute_ctf``).

    Delegates to ``py4DSTEM.process.phase.utils.return_1D_profile`` on the CPU,
    using the notebook-level ``sampling`` as the pixel size on both axes.
    Returns the bin positions and averaged values; the third output is dropped.
    """
    pixel_size = (sampling, sampling)
    q_bins, I_bins, _discarded = py4DSTEM.process.phase.utils.return_1D_profile(
        ctf,
        pixel_size=pixel_size,
        device='cpu',
    )
    return q_bins, I_bins

In [None]:
# Compute the inverse error
kx = ky = np.fft.fftfreq(n,sampling).astype(np.float32)
kxa, kya = np.meshgrid(kx, ky, indexing='ij')
k = np.sqrt(kxa**2 + kya**2)
inverse_error = (k*np.pi/np.sqrt(2))

In [None]:
def update_figure(
    inner_collection_angle,
    outer_collection_angle,
    number_of_segments,
    ):
    """Recompute and plot the iCoM CTF for the current detector settings.

    Used as the ``ipywidgets.interactive`` callback. It builds the segmented
    annular masks (angles in mrad) and reconstructs the iCoM image from the
    notebook-level ``intensities``. It then draws three panels:

    1. the detector geometry;
    2. the 2D CTF weighted by ``inverse_error``;
    3. its radial average, with dashed markers at the inner and outer collection
       angles (divided by the convergence angle) and at 2.
    """
    detector_masks = annular_segmented_detectors(
        energy=energy,
        gpts=(n, n),
        sampling=(sampling, sampling),
        n_angular_bins=number_of_segments,
        inner_radius=inner_collection_angle,
        outer_radius=outer_collection_angle,
        rotation_offset=0,
    )

    icom_image = integrate_com_using_virtual_detectors(intensities, detector_masks)
    ctf = compute_ctf(icom_image)
    q_bins, ctf_profile = radially_average_ctf(ctf * inverse_error)

    fig, (ax_detector, ax_ctf_2d, ax_profile) = plt.subplots(1, 3, figsize=(9, 4))

    py4DSTEM.show(
        detector_masks,
        combine_images=True,
        figax=(fig, ax_detector),
        title='Detector geometry',
        ticks=False,
        scalebar=False,
    )
    py4DSTEM.show(
        np.fft.fftshift(ctf * inverse_error),
        figax=(fig, ax_ctf_2d),
        title='2D CTF',
        ticks=False,
        cmap='turbo',
        scalebar=False,
        vmin=0.00, vmax=0.99,
    )

    # Reference lines: collection angles in units of the convergence angle, plus 2
    markers = [
        inner_collection_angle / convergence_angle,
        outer_collection_angle / convergence_angle,
        2,
    ]
    ax_profile.set_title('Radially averaged CTF')
    ax_profile.plot(q_bins, ctf_profile)
    ax_profile.vlines(markers, 0, 1, colors='k', linestyles='--', linewidth=1)
    ax_profile.set_xlim((q_bins[0], 2.1))
    ax_profile.set_aspect("equal")
    ax_profile.set_xlabel('Spatial frequency')


In [None]:
#| label: app:annular_segmented_detectors
# Annular Segmented Detectors
# Each keyword binds an argument of update_figure to the slider that drives it.
w = ipywidgets.interactive(
    update_figure,
    inner_collection_angle=inner_collection_angle_slider,
    outer_collection_angle=outer_collection_angle_slider,
    number_of_segments=number_of_segments_slider,
)

display(w)
