---
title: Iterative Ptychography
authors: [gvarnavides]
date: 2025-02-01
---

In [1]:
# imports
%matplotlib widget

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar

from IPython.display import display
import ipywidgets
import py4DSTEM

In [2]:
# Seed for the scan-order shuffle so Restart & Run All reproduces the same order.
SEED = 0
DATA_PATH = "data/FCC-slab-dps-25x25x96x96-float32.npy"

# Diffraction intensities reshaped to (25, 25, 96, 96): the leading two axes are
# indexed as the scan grid elsewhere (diffraction_amplitudes[i, j]).
# NOTE(review): the file is read as raw float32 bytes with np.fromfile; if it is
# a genuine .npy file (with a header), np.load would be required — confirm format.
corner_shifted_dps = np.fromfile(DATA_PATH, dtype=np.float32).reshape((25,25,96,96))
# the Fourier projection works on amplitudes, i.e. sqrt of the measured intensity
diffraction_amplitudes = np.sqrt(corner_shifted_dps)

shape = np.array(diffraction_amplitudes.shape)
# first half of the shape is the scan grid, second half the detector grid
scan_gpts, gpts = np.split(shape,2)
unshuffled_order = np.arange(scan_gpts.prod())
# seeded permutation (returns a shuffled copy) instead of an unseeded in-place shuffle
shuffled_order = np.random.default_rng(SEED).permutation(unshuffled_order)

In [3]:
def overlap_projection(potential, probe):
    """Apply the phase-object transmission function to the probe.

    Parameters
    ----------
    potential : array_like
        Projected potential, used as a phase (``exp(1j * potential)``).
    probe : array_like
        Complex probe wavefunction, broadcast-compatible with ``potential``.

    Returns
    -------
    exit_wave : ndarray
        ``probe * exp(1j * potential)``.
    cmplx_potential : ndarray
        The complex transmission ``exp(1j * potential)``.
    """
    transmission = np.exp(1j * potential)
    return probe * transmission, transmission

def fourier_projection(amplitudes,exit_wave):
    """Real-space residual of the Fourier-magnitude projection.

    The exit wave is transformed to Fourier space, its modulus is replaced by
    ``amplitudes`` while its phase is kept, and the difference between the
    projected and the original Fourier wave is brought back to real space.

    Parameters
    ----------
    amplitudes : array_like
        Target Fourier amplitudes, same layout as ``np.fft.fft2(exit_wave)``.
    exit_wave : array_like
        Complex real-space exit wave.

    Returns
    -------
    ndarray
        ``ifft2(amplitudes * exp(1j*angle(F)) - F)`` with ``F = fft2(exit_wave)``.
    """
    psi_k = np.fft.fft2(exit_wave)
    projected_k = amplitudes * np.exp(1j * np.angle(psi_k))
    return np.fft.ifft2(projected_k - psi_k)

def forward_operator(potential, probe, amplitudes):
    """Chain the overlap and Fourier projections for one probe position.

    Returns
    -------
    tuple
        ``(exit_wave, cmplx_potential, gradient)`` where the first two come from
        ``overlap_projection`` and ``gradient`` is ``fourier_projection`` applied
        to the exit wave with the given ``amplitudes``.
    """
    projection = overlap_projection(potential, probe)
    exit_wave = projection[0]
    return exit_wave, projection[1], fourier_projection(amplitudes, exit_wave)

def gradient_descent(potential, probe, amp, step_size):
    """Scaled real-valued update for the potential at one probe position.

    Computes ``Re(-1j * g * conj(O) * conj(P) / norm) * step_size`` where ``g`` is
    the Fourier-projection residual, ``O`` the complex transmission and ``P`` the
    probe. ``norm`` is the notebook-global ``real_space_normalization``.

    Returns
    -------
    ndarray
        Real array, same shape as the potential, to be added to it.
    """
    _, transmission, fourier_residual = forward_operator(potential, probe, amp)
    # keep the exact left-to-right evaluation order of the original expression
    update = -1j*fourier_residual * np.conj(transmission) * np.conj(probe) / real_space_normalization
    return update.real * step_size

def positivity_constraint(potential):
    """Return a copy of ``potential`` with negative entries set to zero."""
    return np.clip(potential, 0, None)

In [4]:
def mutate_arrays(
    calculate_gradient=True,
    update_potential=True,
    object_positivity=True,
):
    """Recompute the shared notebook state for the current probe position.

    Reads and writes the global ``mutable_arrays`` list, whose slots are:
    0 probe position, 1 scan order, 2 potential, 3 shifted probe,
    4 diffraction amplitudes, 5 scaled gradient.

    Parameters
    ----------
    calculate_gradient : bool
        If True, recompute the scaled gradient into slot 5.
    update_potential : bool
        If True (and ``calculate_gradient`` is True), add the gradient to the
        potential in slot 2 in place.
    object_positivity : bool
        If True, clip the potential in slot 2 to be non-negative.

    Returns
    -------
    None
    """
    probe_xy = mutable_arrays[0]

    # shift probe to the current position
    mutable_arrays[3] = py4DSTEM.process.phase.utils.fft_shift(
        real_space_probe,
        probe_xy,
    )

    # Map the probe pixel position to a scan-grid index. The divisor 4 is
    # presumably detector pixels per scan step — TODO confirm. Pointer
    # coordinates can be slightly negative at the image edge (down to -0.5),
    # which would give index -1 and silently wrap to the far side of the scan,
    # so clamp to the valid scan range.
    i, j = np.clip(
        (np.asarray(probe_xy) // 4).astype("int"),
        0,
        scan_gpts - 1,
    )
    mutable_arrays[4] = diffraction_amplitudes[i, j]

    if calculate_gradient:
        mutable_arrays[5] = gradient_descent(
            mutable_arrays[2], # potential
            mutable_arrays[3], # probe
            mutable_arrays[4], # amp
            step_size
        )

        # update potential
        if update_potential:
            mutable_arrays[2] += mutable_arrays[5]

    if object_positivity:
        mutable_arrays[2] = positivity_constraint(mutable_arrays[2])

    return None

In [5]:
# real-space pixel size (scale bars label it in Å)
sampling = (0.255,0.255)
# reciprocal-space pixel size: 1 / (N * dx), labelled Å^-1 in the scale bars
q_sampling = 1/gpts[1]/sampling[1]
# probe parameters passed to ComplexProbe; presumably eV, mrad and Å
# respectively (py4DSTEM conventions) — confirm against the library docs
energy = 80e3
semiangle = 25
defocus = 150
# multiplier applied to each gradient update
step_size = 1.0
dpi = 72

real_space_probe = py4DSTEM.process.phase.utils.ComplexProbe(
    energy=energy,
    gpts=gpts,
    sampling=sampling,
    semiangle_cutoff=semiangle,
    defocus=defocus,
).build()._array

# Peak probe intensity; divides the gradient step and sets the gradient colour range.
# (The RGB rendering of the probe is produced later from the shifted probe; an
# earlier copy computed here was always overwritten before use.)
real_space_normalization = np.amax(np.abs(real_space_probe)**2)

In [6]:
# Shared state read and written by the interactive callbacks. Slots:
#   0 probe position (row, col in pixels)   1 scan order   2 potential
#   3 shifted probe   4 diffraction amplitudes   5 scaled gradient
mutable_arrays = [gpts // 2, shuffled_order, np.zeros(gpts)] + [None] * 3

# fill slots 3-5 for the starting position without changing the potential
mutate_arrays(update_potential=False, calculate_gradient=True)

In [7]:
# initial images for the probe and diffraction panels
scaled_probe = py4DSTEM.visualize.Complex2RGB(mutable_arrays[3], vmin=0, vmax=1)

# the pattern is stored corner-shifted; fftshift centres it for display
scaled_dp, *_ = py4DSTEM.visualize.return_scaled_histogram_ordering(
    np.fft.fftshift(mutable_arrays[4]**2), normalize=True, vmin=0, vmax=1,
)

In [8]:
def update_panels(
    plot_probe=True,
    plot_pot=True,
    plot_dp=True,
    plot_grad=True,
):
    """Push the current ``mutable_arrays`` contents into the figure's image artists.

    Each flag selects whether that panel's artist receives new data:
    probe (slot 3), potential (slot 2), diffraction intensity (slot 4) and
    gradient (slot 5). A canvas redraw is always requested at the end.

    Returns
    -------
    None
    """
    if plot_probe:
        im_probe.set_data(
            py4DSTEM.visualize.Complex2RGB(mutable_arrays[3], vmin=0, vmax=1)
        )

    if plot_pot:
        potential = mutable_arrays[2]
        # histogram scaling is skipped for an all-zero potential; show it raw
        if not np.abs(potential).max() > 0:
            im_pot.set_data(potential)
        else:
            im_pot.set_data(
                py4DSTEM.visualize.return_scaled_histogram_ordering(
                    potential, normalize=True, vmin=0.005, vmax=0.995,
                )[0]
            )

    if plot_dp:
        im_dp.set_data(
            py4DSTEM.visualize.return_scaled_histogram_ordering(
                np.fft.fftshift(mutable_arrays[4]**2), normalize=True, vmin=0, vmax=1,
            )[0]
        )

    if plot_grad:
        im_grad.set_data(mutable_arrays[5])

    fig.canvas.draw_idle()
    return None

def add_scalebar(ax, length, sampling, units, color="white", size_vertical=1, pad=0.5):
    """Attach a frameless scale bar to the lower-right corner of ``ax``.

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes; the bar is placed in its data coordinates.
    length : float
        Bar length in data (pixel) units.
    sampling : float
        Physical size per data unit; the label reads ``sampling * length``.
    units : str
        Unit string appended to the label.
    color, size_vertical, pad
        Forwarded to ``AnchoredSizeBar``.

    Returns
    -------
    tuple
        ``(ax, bar)`` with the created ``AnchoredSizeBar`` artist.
    """
    scalebar = AnchoredSizeBar(
        transform=ax.transData,
        size=length,
        label=f"{sampling*length:.2f} {units}",
        loc="lower right",
        pad=pad,
        color=color,
        frameon=False,
        label_top=True,
        size_vertical=size_vertical,
    )
    ax.add_artist(scalebar)
    return ax, scalebar

In [9]:
# visualization
with plt.ioff():
    fig,axs = plt.subplots(1,4, figsize=(680/dpi,200/dpi),dpi=dpi)

im_probe = axs[0].imshow(scaled_probe)
im_pot = axs[1].imshow(mutable_arrays[2],cmap='magma',vmin=0,vmax=1)
im_dp = axs[2].imshow(scaled_dp,cmap='magma')
im_grad = axs[3].imshow(
    mutable_arrays[5],
    cmap='PuOr',
    vmin=-real_space_normalization*2,
    vmax=real_space_normalization*2
)

# turn off pot and grad
im_pot.set_visible(False)
im_grad.set_visible(False)
axs[1].set_visible(False)
axs[3].set_visible(False)

titles = [
    "converged electron probe",
    "projected sample potential",
    "diffracted probe intensity",
    "ptychographic gradient",
]

scalebars = [
    {'sampling':sampling[1],'length':5/sampling[1],'units':'Å'},
    {'sampling':sampling[1],'length':5/sampling[1],'units':'Å'},
    {'sampling':q_sampling,'length':1/q_sampling,'units':r'Å$^{-1}$'},
    {'sampling':sampling[1],'length':5/sampling[1],'units':'Å','color':'black'},
]

for ax, title, bar in zip(
    axs,
    titles,
    scalebars
):
    add_scalebar(ax,**bar)
    ax.set_title(title)
    ax.set_xticks([])
    ax.set_yticks([])

fig.tight_layout()

fig.canvas.resizable = False
fig.canvas.header_visible = False
fig.canvas.footer_visible = False
fig.canvas.toolbar_visible = True
fig.canvas.layout.width = '680px'
fig.canvas.layout.height = "225px"
fig.canvas.toolbar_position = 'bottom'

def _refresh_state_and_panels(update_potential):
    """Recompute shared arrays for the current probe position and redraw the panels.

    The gradient is only computed, and the potential/gradient panels only
    redrawn, when their checkboxes are ticked.
    """
    mutate_arrays(
        calculate_gradient = gradient_checkbox.value,
        update_potential = update_potential,
    )
    update_panels(
        plot_pot = potential_checkbox.value,
        plot_grad = gradient_checkbox.value
    )

def onmove(event):
    """Mouse-move handler: move the probe to the pointer and apply one update.

    Events outside any (visible) axes carry no data coordinates and are ignored.
    The potential is only updated while the "show potential" box is ticked.
    """
    if event.xdata is None or event.ydata is None:
        return None
    # probe position is stored as (row, col)
    mutable_arrays[0] = np.array([event.ydata, event.xdata])
    _refresh_state_and_panels(update_potential = potential_checkbox.value)

cid = fig.canvas.mpl_connect('motion_notify_event',onmove)

def reset_potential(*args):
    """Button callback: zero the potential and redraw without updating it."""
    mutable_arrays[2] = np.zeros(gpts)
    _refresh_state_and_panels(update_potential = False)

def reset_probe_position(*args):
    """Button callback: move the probe back to the centre (same as the initial state)."""
    # integer centre, matching the initial ``gpts//2`` position
    mutable_arrays[0] = gpts//2
    _refresh_state_and_panels(update_potential = False)

# shared look for every control (the same Layout object is reused by all of them)
style = {'description_width': 'initial'}
layout = ipywidgets.Layout(width="170px",height="30px")

def _make_checkbox(description):
    """Unticked, non-indented checkbox with the shared style and layout."""
    return ipywidgets.Checkbox(
        value=False,
        description=description,
        style=style,
        layout=layout,
        indent=False,
    )

def _make_button(description, on_click):
    """Button with the shared style and layout, wired to ``on_click``."""
    button = ipywidgets.Button(description=description, style=style, layout=layout)
    button.on_click(on_click)
    return button

potential_checkbox = _make_checkbox("show potential")
gradient_checkbox = _make_checkbox("show gradient")
reset_potential_button = _make_button("clear potential", reset_potential)
reset_probe_button = _make_button("reset probe position", reset_probe_position)

# positivity_checkbox = ipywidgets.Checkbox(
#     value=True,
#     description = "object positivity",
#     style=style,
#     layout=layout,
#     indent=False,
# )

def _set_panel_visibility(image, ax, visible):
    """Show or hide one panel (its image artist and its axes) and request a redraw."""
    image.set_visible(visible)
    ax.set_visible(visible)
    fig.canvas.draw_idle()
    return None

def toggle_potential(change):
    """Checkbox observer: show/hide the projected-potential panel.

    While hidden, the potential image is not redrawn (e.g. after "clear
    potential"), so it is synced with the current array before being shown.
    """
    new = change['new']
    if new:
        update_panels(plot_probe=False, plot_pot=True, plot_dp=False, plot_grad=False)
    return _set_panel_visibility(im_pot, axs[1], new)

potential_checkbox.observe(toggle_potential,names='value')

def toggle_gradient(change):
    """Checkbox observer: show/hide the ptychographic-gradient panel.

    While hidden, the gradient is not computed, so it is recomputed for the
    current probe position (without touching the potential) before being shown.
    """
    new = change['new']
    if new:
        mutate_arrays(calculate_gradient=True, update_potential=False)
        update_panels(plot_probe=False, plot_pot=False, plot_dp=False, plot_grad=True)
    return _set_panel_visibility(im_grad, axs[3], new)

gradient_checkbox.observe(toggle_gradient,names='value')

# def toggle_positivity(change):
#     new = change['new']
#     im_pot.set_cmap("magma" if new else "PiYG")
#     fig.canvas.draw_idle()
#     return None
    
# positivity_checkbox.observe(toggle_positivity,names='value')

In [10]:
#| label: app:iterative_ptychography
# one row of controls under the interactive figure; the VBox is the cell output
controls_row = ipywidgets.HBox(
    [reset_potential_button, reset_probe_button, potential_checkbox, gradient_checkbox]
)
ipywidgets.VBox([fig.canvas, controls_row])

VBox(children=(Canvas(footer_visible=False, header_visible=False, layout=Layout(height='225px', width='680px')…