In [None]:
# Jinseok Ryu, Ph.D.
# jinseuk56@gmail.com
# 20230927

from abtem import GridScan, PixelatedDetector, Potential, Probe, show_atoms, SMatrix, AnnularDetector
from abtem.detect import PixelatedDetector
from abtem.reconstruct import MixedStatePtychographicOperator, MultislicePtychographicOperator, RegularizedPtychographicOperator
from abtem.measure import Measurement, Calibration, bandlimit, center_of_mass
from abtem.utils import energy2wavelength
from ase.io import read
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Circle
from matplotlib.patches import Rectangle
from matplotlib.widgets import RectangleSelector
import matplotlib.patches as pch
import tifffile
import tkinter.filedialog as tkf

def load_binary_4D_stack(img_adr, datatype, original_shape, final_shape, log_scale=False):
    """Read a headerless binary 4D-STEM file and crop it to ``final_shape``.

    Parameters
    ----------
    img_adr : str
        Path of the raw binary file.
    datatype : str or numpy dtype
        Element type used to interpret the file bytes (e.g. "float32").
    original_shape : sequence of int
        Shape the flat file content is reshaped to before cropping.
    final_shape : sequence of int
        Upper index bound along each of the four axes; only the first four
        entries are used.
    log_scale : bool, optional
        If True, return the natural log of the cropped data (zeros give -inf,
        negative values give nan).

    Returns
    -------
    numpy.ndarray
        The cropped (and optionally log-scaled) stack.
    """
    full = np.fromfile(img_adr, dtype=datatype).reshape(original_shape)
    crop = (slice(None, final_shape[0]), slice(None, final_shape[1]),
            slice(None, final_shape[2]), slice(None, final_shape[3]))
    cropped = full[crop]
    return np.log(cropped) if log_scale else cropped

def spike_remove(data, percent_thresh, mode):
    """Replace abnormal scan positions with the mean diffraction pattern (PACBED).

    The total intensity of every diffraction pattern (sum over the last two
    axes) is compared with the ``percent_thresh``-th percentile of all totals.
    Patterns above (``mode="upper"``) or below (``mode="lower"``) that threshold
    are overwritten in place with the mean pattern over all scan positions.

    Parameters
    ----------
    data : numpy.ndarray
        4D stack whose first two axes index scan positions and whose last two
        axes hold one diffraction pattern; modified in place.
    percent_thresh : float
        Percentile in [0, 100], passed to ``np.percentile``.
    mode : {"upper", "lower"}
        Which side of the threshold is treated as abnormal.

    Returns
    -------
    numpy.ndarray
        ``data`` itself, after the replacement.

    Raises
    ------
    ValueError
        If ``mode`` is neither "upper" nor "lower".
    """
    if mode not in ("upper", "lower"):
        # Fail loudly: the old behaviour (print + return None) made
        # `f_stack = spike_remove(...)` silently throw the data away.
        raise ValueError('mode must be "upper" or "lower", got %r' % (mode,))

    # Mean pattern over all scan positions (spikes included) used as the replacement.
    pacbed = np.mean(data, axis=(0, 1))
    intensity_integration_map = np.sum(data, axis=(2, 3))

    threshold = np.percentile(intensity_integration_map, percent_thresh)
    if mode == "upper":
        spike_ind = np.where(intensity_integration_map > threshold)
    else:
        spike_ind = np.where(intensity_integration_map < threshold)

    print("threshold value = %f" % threshold)
    print("number of abnormal pixels = %d" % len(spike_ind[0]))

    # pacbed broadcasts onto every selected scan position (assignment copies it)
    data[spike_ind] = pacbed

    return data

# Compute backend handed to the abTEM reconstruction operators below.
device = "cpu"  # switch to "gpu" to run on a CUDA device

In [None]:
# Pick the 4D-STEM data file (.raw binary or .tif/.tiff) through a Tk file dialog
# and echo the chosen path.
raw_adr = tkf.askopenfilename()
print(raw_adr)

In [None]:
# Load the selected 4D-STEM stack and clean it up.
datatype = "float32"
f_shape = [256, 256, 128, 128] # the shape of the 4D-STEM data [scanning_y, scanning_x, DP_y, DP_x]
# The binary file stores f_shape[2]+2 rows per diffraction pattern; load_binary_4D_stack
# crops the two extra rows away (presumably detector metadata rows, e.g. EMPAD - confirm).
o_shape = [f_shape[0], f_shape[1], f_shape[2]+2, f_shape[3]]

if not raw_adr:
    # the file dialog returns an empty value when it is cancelled
    raise ValueError("No file was selected")

file_name = raw_adr.lower()
if file_name.endswith(".raw"):
    f_stack = load_binary_4D_stack(raw_adr, datatype, o_shape, f_shape, log_scale=False)
    f_stack = np.flip(f_stack, axis=2)  # flip along the DP_y axis
    f_stack = np.nan_to_num(f_stack)

elif file_name.endswith((".tif", ".tiff")):
    # The old check compared raw_adr[:-4] (everything except the extension) with "tiff",
    # so .tiff files were never recognised.
    f_stack = tifffile.imread(raw_adr)
    f_stack = np.nan_to_num(f_stack)

else:
    # Raise instead of printing: otherwise the next line fails with a NameError on f_stack.
    raise ValueError("The format of the file is not supported here: %s" % raw_adr)

print(f_stack.shape)
print(f_stack.min(), f_stack.max())
print(f_stack.mean())

# remove spike pixels (replace the spike pixels with the pacbed) -> optional stopgap
#f_stack = spike_remove(f_stack, percent_thresh=0.01, mode="lower")

# negative values are clipped to zero
f_stack = f_stack.clip(min=0.0)

In [None]:
# Experimental parameters
rotation_angle                    = 76*np.pi/180 # 76 degrees, stored in radians
energy                            = 200E3 # acceleration voltage [V]
semiangle                         = 19.8 # probe semi-angle (used as semiangle_cutoff) [mrad]
step_size_real_space              = (0.408, 0.408) # scan step (x, y) [angstrom]
reciprocal_space_sampling_mrad    = (0.619, 0.619) # angular size of one detector pixel (x, y) [mrad]
center_x = 63.90314659829032 # diffraction center x, [pixel]
center_y = 64.06042187282752 # diffraction center y, [pixel]

In [None]:
# Wrap the stack in an abTEM Measurement with real-space (scan) and angular (detector) calibrations.
# NOTE(review): calibrations are assigned in the order (x, y, alpha_x, alpha_y); confirm this
# matches the axis order of f_stack.
# The unit string was mojibake ("Ã"); it is the angstrom sign.
x_cb_object = Calibration(offset=0, sampling=step_size_real_space[0], units="Å", name="x")
y_cb_object = Calibration(offset=0, sampling=step_size_real_space[1], units="Å", name="y")
# the angular offsets put the diffraction centre (center_x, center_y) at 0 mrad
dx_cb_object = Calibration(offset=-center_x*reciprocal_space_sampling_mrad[0], sampling=reciprocal_space_sampling_mrad[0], units="mrad", name="alpha_x")
dy_cb_object = Calibration(offset=-center_y*reciprocal_space_sampling_mrad[1], sampling=reciprocal_space_sampling_mrad[1], units="mrad", name="alpha_y")


experimental_measurement = Measurement(f_stack, calibrations=[x_cb_object, y_cb_object, dx_cb_object, dy_cb_object])
# virtual annular image from the ring between one and two probe semi-angles
adf = AnnularDetector(inner=semiangle,outer=semiangle*2).integrate(experimental_measurement)
# semi-angle vs. the angle spanned by half the detector along axis 2
print(semiangle, reciprocal_space_sampling_mrad[0]*f_stack.shape[2]/2)
pacbed = experimental_measurement.mean(axis=(0, 1))

fig, (ax1,ax2) = plt.subplots(1, 2, figsize=(8, 6))

adf.show(ax=ax1,title='ADF image')
pacbed.show(power=0.25, ax=ax2, cmap='inferno',title='PACBED pattern')

fig.tight_layout()
plt.show()

In [None]:
# iDPC
icom = center_of_mass(experimental_measurement, return_icom=True)

icom.show()
plt.show()

In [None]:
# Sanity check: array shape, calibration limits and per-axis calibration of the measurement.
print(experimental_measurement.shape)
print(*experimental_measurement.calibration_limits, sep="\n")
for i in range(experimental_measurement.dimensions):
    cal = experimental_measurement.calibrations[i]
    print(cal.name, cal.units, cal.sampling)

In [None]:
# reconstruction sampling scale / extent
# Real-space pixel size implied by the detector: wavelength / (angular range spanned by the
# detector). The factor 1000 converts the mrad calibration to rad.
experimental_measurement_sampling = tuple(energy2wavelength(energy)*1000/(cal.sampling * pixels)
                                          for cal, pixels
                                          in zip(experimental_measurement.calibrations[-2:],
                                                 experimental_measurement.shape[-2:]))

print(f'pixelated_measurement sampling: {experimental_measurement_sampling}')

# Real-space extent of one detector-sized window: wavelength / angular pixel size.
experimental_measurement_extent = tuple(energy2wavelength(energy)*1000/(cal.sampling)
                                        for cal
                                        in experimental_measurement.calibrations[-2:])

print(f'pixelated_measurement extent: {experimental_measurement_extent}')

# Scanned length along x / y (upper calibration limit) relative to that window.
recon_extent_x_ratio = experimental_measurement.calibration_limits[0][1] / experimental_measurement_extent[0]
recon_extent_y_ratio = experimental_measurement.calibration_limits[1][1] / experimental_measurement_extent[1]
print(recon_extent_x_ratio, recon_extent_y_ratio)

# Start from the detector shape and enlarge an axis when the scan is longer than one window.
# The x ratio comes from calibrations[-2] (axis 2), so it scales shape[2]; the y ratio
# scales shape[3]. The old code had these swapped, which only agrees for a square detector.
dp_shape = experimental_measurement.shape[2:]
recon_shape = list(dp_shape)
if recon_extent_x_ratio > 1.0:
    recon_shape[0] = recon_extent_x_ratio*dp_shape[0]

if recon_extent_y_ratio > 1.0:
    recon_shape[1] = recon_extent_y_ratio*dp_shape[1]

# int64 rather than int16, which overflows above 32767 pixels
recon_shape = np.asarray(recon_shape).astype(np.int64)
print(recon_shape)

In [None]:
# ---- reconstruction settings shared by the cells below ----
n_iter = 5            # iterations for every reconstruction
n_probe = 7           # probe modes: mixed-state (and mixed-state multislice) ptychography
n_slice = 15          # object slices: multislice (and mixed-state multislice) ptychography
slice_thickness = 20  # thickness of each slice for multislice ptychography (presumably Å - confirm)

# abTEM recontruction parameters
alpha = 0.5  # also used in PtyREX
beta = 0.2   # also used in PtyREX
probe_position_correction = False
pre_position_correction_update_steps = 1  # scaled by the number of diffraction patterns before use
position_step_size = 0.5
step_size_damping_rate = 0.995

In [None]:
# rPIE
RPIE_operator = RegularizedPtychographicOperator(experimental_measurement,
                                                               semiangle_cutoff=semiangle,
                                                               energy=energy,
                                                               device=device,
                                                               parameters={'object_px_padding':(0, 0)}).preprocess()

reconstruction_parameters = {}
reconstruction_parameters['alpha'] = alpha
reconstruction_parameters['beta'] = beta
if probe_position_correction:
    reconstruction_parameters['pre_position_correction_update_steps'] = RPIE_operator._num_diffraction_patterns * int(pre_position_correction_update_steps)
    reconstruction_parameters['position_step_size'] = position_step_size
    reconstruction_parameters['step_size_damping_rate'] = step_size_damping_rate

exp_objects, exp_probes, exp_positions, exp_sse  = RPIE_operator.reconstruct(
    max_iterations = n_iter,
    random_seed=1,
    return_iterations=True,
    parameters=reconstruction_parameters)

In [None]:
# Shape and calibration of the final rPIE object
print(exp_objects[-1].shape)
for i in range(exp_objects[-1].dimensions):
    print(exp_objects[-1].calibrations[i].name, exp_objects[-1].calibrations[i].units,
          exp_objects[-1].calibrations[i].sampling)

# About five snapshots plus the final iteration. max(1, ...) keeps the step positive:
# int(n_iter/5) is 0 for n_iter < 5, which broke range() and the column count.
plot_every = max(1, n_iter // 5)

fig, axes = plt.subplots(2, int(np.ceil(len(exp_objects) / plot_every))+1, figsize=(20, 8))

# top row: object phase; bottom row: probe intensity
for i, j in enumerate(range(0, len(exp_objects), plot_every)):
    axes[0,i].imshow(np.angle(exp_objects[j].array).T, origin='lower', cmap='gray')
    axes[0,i].set_title('iteration: %d, SSE: %.2e'%(j+1, float(exp_sse[j])))
    axes[1,i].imshow(np.abs(exp_probes[j].array).T ** 2, origin='lower', cmap='gray')
    axes[0,i].axis("off")
    axes[1,i].axis("off")

axes[0,-1].imshow(np.angle(exp_objects[-1].array).T, origin='lower', cmap='gray')
axes[0,-1].set_title('iteration: %d, SSE: %.2e'%(n_iter, float(exp_sse[-1])))
axes[1,-1].imshow(np.abs(exp_probes[-1].array).T ** 2, origin='lower', cmap='gray')
axes[0,-1].axis("off")
axes[1,-1].axis("off")

fig.tight_layout()
plt.show()

# float() accepts NumPy scalars (device="cpu") and CuPy 0-d arrays (device="gpu");
# .get() exists only on CuPy arrays and would raise AttributeError for NumPy values.
errs = [float(exp_sse[i]) for i in range(n_iter)]

fig, ax = plt.subplots(1, 1, figsize=(5, 5))
# SSE from iteration 3 onwards (the first two iterations are left out)
ax.plot(np.arange(2, n_iter)+1, errs[2:], 'k-')
ax.set_xlabel('iteration')
ax.set_ylabel('SSE')
fig.tight_layout()
plt.show()

In [None]:
# Final rPIE object phase, cropped to the scanned area, with the axes in scan-calibration units.
final_phase = np.angle(exp_objects[-1].array).T
crop_y = int(recon_shape[1]*recon_extent_y_ratio)
crop_x = int(recon_shape[0]*recon_extent_x_ratio)
scan_extent = [0, experimental_measurement.calibration_limits[0][1],
               0, experimental_measurement.calibration_limits[1][1]]

fig, ax = plt.subplots(1, 1, figsize=(10, 10))
ax.imshow(final_phase[:crop_y, :crop_x], cmap='inferno', origin="lower", extent=scan_extent)
fig.tight_layout()
plt.show()

In [None]:
# Save every rPIE iteration: cropped object phase and complex probe, as TIFF stacks.
crop_y = int(recon_shape[1]*recon_extent_y_ratio)
crop_x = int(recon_shape[0]*recon_extent_x_ratio)
objects = [np.angle(exp_objects[it].array).T[:crop_y, :crop_x] for it in range(n_iter)]
probes = [exp_probes[it].array for it in range(n_iter)]

tifffile.imwrite("rpie_object.tif", np.asarray(objects))
tifffile.imwrite("rpie_probe.tif", np.asarray(probes).astype(np.complex64))

In [None]:
# Mixed-state ptychography with n_probe probe modes
Mixed_operator = MixedStatePtychographicOperator(
    experimental_measurement,
    num_probes=n_probe,
    semiangle_cutoff=semiangle,
    energy=energy,
    device=device,
    parameters={'object_px_padding': (0, 0)},
).preprocess()

reconstruction_parameters = {'alpha': alpha, 'beta': beta}
if probe_position_correction:
    # the warm-up length is scaled by the number of diffraction patterns
    reconstruction_parameters.update(
        pre_position_correction_update_steps=Mixed_operator._num_diffraction_patterns * int(pre_position_correction_update_steps),
        position_step_size=position_step_size,
        step_size_damping_rate=step_size_damping_rate,
    )

exp_objects, exp_probes, exp_positions, exp_sse = Mixed_operator.reconstruct(
    max_iterations=n_iter,
    random_seed=1,
    return_iterations=True,
    parameters=reconstruction_parameters,
)

In [None]:
# Shape and calibration of the final mixed-state object
print(exp_objects[-1].shape)
for i in range(exp_objects[-1].dimensions):
    print(exp_objects[-1].calibrations[i].name, exp_objects[-1].calibrations[i].units,
          exp_objects[-1].calibrations[i].sampling)

# one object-phase snapshot every 5 iterations, plus the final iteration
plot_every = 5

fig, axes = plt.subplots(1, int(np.ceil(len(exp_objects) / plot_every))+1, figsize=(20, 8))

for i, j in enumerate(range(0, len(exp_objects), plot_every)):
    axes[i].imshow(np.angle(exp_objects[j].array).T, origin='lower', cmap='gray')
    axes[i].set_title('iteration: %d, SSE: %.2e'%(j+1, float(exp_sse[j])))
    axes[i].axis("off")

axes[-1].imshow(np.angle(exp_objects[-1].array).T, origin='lower', cmap='gray')
axes[-1].set_title('iteration: %d, SSE: %.2e'%(n_iter, float(exp_sse[-1])))
axes[-1].axis("off")

fig.tight_layout()
plt.show()

# float() accepts NumPy scalars (device="cpu") and CuPy 0-d arrays (device="gpu");
# .get() exists only on CuPy arrays and would raise AttributeError for NumPy values.
errs = [float(exp_sse[i]) for i in range(n_iter)]

fig, ax = plt.subplots(1, 1, figsize=(5, 5))
# SSE from iteration 3 onwards (the first two iterations are left out)
ax.plot(np.arange(2, n_iter)+1, errs[2:], 'k-')
ax.set_xlabel('iteration')
ax.set_ylabel('SSE')
fig.tight_layout()
plt.show()

In [None]:
# Final mixed-state object phase, cropped to the scanned area
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
ax.imshow(np.angle(exp_objects[-1].array).T[:int(recon_shape[1]*recon_extent_y_ratio), :int(recon_shape[0]*recon_extent_x_ratio)], 
          cmap='inferno', origin="lower",
          extent=[0, experimental_measurement.calibration_limits[0][1], 0, experimental_measurement.calibration_limits[1][1]])
fig.tight_layout()
plt.show()

# Intensity of each probe mode after the last iteration.
# squeeze=False keeps `ax` 2-D, so indexing also works when n_probe == 1.
fig, ax = plt.subplots(1, n_probe, figsize=(5*n_probe, 3), squeeze=False)
for i in range(n_probe):
    ax[0, i].imshow(np.abs(exp_probes[-1][i].array).T**2, origin="lower", cmap="gray")
    ax[0, i].axis("off")
fig.tight_layout()
plt.show()

In [None]:
# Save every mixed-state iteration: cropped object phase and all probe modes, as TIFF stacks.
crop_y = int(recon_shape[1]*recon_extent_y_ratio)
crop_x = int(recon_shape[0]*recon_extent_x_ratio)
objects = [np.angle(exp_objects[it].array).T[:crop_y, :crop_x] for it in range(n_iter)]
probes = [[exp_probes[it][mode].array for mode in range(n_probe)] for it in range(n_iter)]

tifffile.imwrite("mixed_rpie_object.tif", np.asarray(objects))
tifffile.imwrite("mixed_rpie_probe.tif", np.asarray(probes).astype(np.complex64))

In [None]:
# multislice ptychography: n_slice object slices, each slice_thickness thick
Multislice_operator = MultislicePtychographicOperator(experimental_measurement,
                                                        semiangle_cutoff=semiangle,
                                                        energy=energy,
                                                        num_slices = n_slice,
                                                        slice_thicknesses = slice_thickness,
                                                        device=device,
                                                        parameters={'object_px_padding':(0,0)}).preprocess()

reconstruction_parameters = {}
reconstruction_parameters['alpha'] = alpha
reconstruction_parameters['beta'] = beta
if probe_position_correction:
    # Use this cell's own operator. The previous reference to Mixed_operator was a
    # copy-paste slip that raised NameError unless the mixed-state cell had been run.
    reconstruction_parameters['pre_position_correction_update_steps'] = Multislice_operator._num_diffraction_patterns * int(pre_position_correction_update_steps)
    reconstruction_parameters['position_step_size'] = position_step_size
    reconstruction_parameters['step_size_damping_rate'] = step_size_damping_rate

mspie_objects, mspie_probes, mspie_positions, mspie_sse = Multislice_operator.reconstruct(
    max_iterations = n_iter,
    verbose=True,
    random_seed=1,
    return_iterations=True,
    parameters=reconstruction_parameters)

In [None]:
# About five snapshots plus the final iteration. max(1, ...) keeps the step positive:
# int(n_iter/5) is 0 for n_iter < 5, which broke range() and the column count.
plot_every = max(1, n_iter // 5)

fig, axes = plt.subplots(2, int(np.ceil(len(mspie_objects) / plot_every))+1, figsize=(20, 8))

# Top row: phase summed over slices (axis 0).
# Bottom row: probe amplitudes summed over axis 0, then squared.
# NOTE(review): the bottom row squares the summed amplitudes; if the sum of intensities was
# intended, use np.sum(np.abs(...)**2, axis=0) - confirm.
for i, j in enumerate(range(0, len(mspie_objects), plot_every)):
    axes[0,i].imshow(np.sum(np.angle(mspie_objects[j].array), axis=0).T, origin='lower', cmap='gray')
    axes[0,i].set_title('iteration: %d, SSE: %.2e'%(j+1, float(mspie_sse[j])))
    axes[1,i].imshow(np.sum(np.abs(mspie_probes[j].array), axis=0).T**2, origin='lower', cmap='gray')
    axes[0,i].axis("off")
    axes[1,i].axis("off")

axes[0,-1].imshow(np.angle(mspie_objects[-1].array).sum(axis=0).T, origin='lower', cmap='gray')
axes[0,-1].set_title('iteration: %d, SSE: %.2e'%(n_iter, float(mspie_sse[-1])))
axes[1,-1].imshow(np.sum(np.abs(mspie_probes[-1].array), axis=0).T**2, origin='lower', cmap='gray')
axes[0,-1].axis("off")
axes[1,-1].axis("off")

fig.tight_layout()
plt.show()

# float() accepts NumPy scalars (device="cpu") and CuPy 0-d arrays (device="gpu");
# .get() exists only on CuPy arrays and would raise AttributeError for NumPy values.
errs = [float(mspie_sse[i]) for i in range(n_iter)]

fig, ax = plt.subplots(1, 1, figsize=(5, 5))
# SSE from iteration 3 onwards (the first two iterations are left out)
ax.plot(np.arange(2, n_iter)+1, errs[2:], 'k-')
ax.set_xlabel('iteration')
ax.set_ylabel('SSE')
fig.tight_layout()
plt.show()

In [None]:
# Overview of the final multislice result: phase summed over slices, and the
# intensity of the first entry of the final probe stack.
final_sse = float(mspie_sse[-1])
fig, axd = plt.subplots(1, 2, figsize=(10, 5))
projected_phase = mspie_objects[-1].angle().sum(0)
projected_phase.show(ax=axd[0], title=f"SSE = {final_sse:.3e}", cmap='inferno')
first_probe_intensity = mspie_probes[-1][0].intensity()
first_probe_intensity.show(ax=axd[1], cmap="gray", power=0.5)
fig.tight_layout()
plt.show()

In [None]:
# Final multislice phase summed over slices, cropped to the scanned area.
projected_phase_map = np.angle(mspie_objects[-1].array).sum(0).T
crop_y = int(recon_shape[1]*recon_extent_y_ratio)
crop_x = int(recon_shape[0]*recon_extent_x_ratio)
scan_extent = [0, experimental_measurement.calibration_limits[0][1],
               0, experimental_measurement.calibration_limits[1][1]]

fig, ax = plt.subplots(1, 1, figsize=(10, 10))
ax.imshow(projected_phase_map[:crop_y, :crop_x], cmap='inferno', origin="lower", extent=scan_extent)
fig.tight_layout()
plt.show()

In [None]:
# Phase of every slice at every iteration, cropped to the scanned area (one figure per iteration).
# The crop bounds and extent do not change inside the loops, so they are computed once.
crop_y = int(recon_shape[1]*recon_extent_y_ratio)
crop_x = int(recon_shape[0]*recon_extent_x_ratio)
scan_extent = [0, experimental_measurement.calibration_limits[0][1],
               0, experimental_measurement.calibration_limits[1][1]]

for iteration in range(n_iter):
    # squeeze=False keeps `ax` 2-D, so indexing also works when n_slice == 1
    fig, ax = plt.subplots(1, n_slice, figsize=(5*n_slice, 5), squeeze=False)
    for slice_idx in range(n_slice):
        ax[0, slice_idx].imshow(np.angle(mspie_objects[iteration].array[slice_idx]).T[:crop_y, :crop_x],
                                cmap="inferno", origin="lower", extent=scan_extent)

    fig.tight_layout()
    plt.show()

In [None]:
# Save every multislice iteration: cropped phase of each slice and the matching probe entries.
crop_y = int(recon_shape[1]*recon_extent_y_ratio)
crop_x = int(recon_shape[0]*recon_extent_x_ratio)
objects = [[np.angle(mspie_objects[it][s].array).T[:crop_y, :crop_x] for s in range(n_slice)]
           for it in range(n_iter)]
probes = [[mspie_probes[it][s].array for s in range(n_slice)] for it in range(n_iter)]

tifffile.imwrite("multislice_object.tif", np.asarray(objects))
tifffile.imwrite("multislice_probe.tif", np.asarray(probes).astype(np.complex64))