In [None]:
# Jinseok Ryu, Ph.D.
# jinseuk56@gmail.com
# 20230927

import py4DSTEM
import tkinter.filedialog as tkf
import numpy as np
import tifffile
import matplotlib.pyplot as plt
# A bare expression is only displayed when it is the last line of a cell, so print it explicitly.
print("py4DSTEM version:", py4DSTEM.__version__)


def fourd_roll_axis(stack):
    """Move axes 2 and 3 of ``stack`` to the front.

    For a 4D array this maps (a, b, c, d) -> (c, d, a, b). Any trailing axes
    beyond the fourth keep their position. The result is a view of ``stack``,
    not a copy.
    """
    return np.moveaxis(stack, (2, 3), (0, 1))

def spike_remove(data, percent_thresh, mode):
    """Replace abnormal scan positions of a 4D dataset with the mean pattern.

    The integrated intensity of every scan position (sum over axes 2 and 3) is
    compared with the ``percent_thresh``-th percentile of all integrated
    intensities. Positions above it (``mode="upper"``) or below it
    (``mode="lower"``) are flagged. Their patterns are then overwritten with
    the mean pattern of the *unflagged* positions.

    Parameters
    ----------
    data : numpy.ndarray
        4D array. Axes 0-1 are treated as scan positions and axes 2-3 as the
        diffraction pattern. Modified in place.
    percent_thresh : float
        Percentile in [0, 100] passed to ``np.percentile``.
    mode : {"upper", "lower"}
        Flag positions brighter ("upper") or darker ("lower") than the threshold.

    Returns
    -------
    numpy.ndarray
        ``data`` itself, after in-place replacement.

    Raises
    ------
    ValueError
        If ``mode`` is neither "upper" nor "lower". Returning None here would
        silently wipe the data in ``dataset.data = spike_remove(...)``.
    """
    intensity_integration_map = np.sum(data, axis=(2, 3))
    threshold = np.percentile(intensity_integration_map, percent_thresh)

    if mode == "upper":
        spike_mask = intensity_integration_map > threshold
    elif mode == "lower":
        spike_mask = intensity_integration_map < threshold
    else:
        raise ValueError("Wrong mode! Expected 'upper' or 'lower', got %r" % (mode,))

    n_spikes = int(np.count_nonzero(spike_mask))
    print("threshold value = %f"%threshold)
    print("number of abnormal pixels = %d"%n_spikes)

    if n_spikes == 0:
        return data

    # Mean pattern of the normal positions only, so the outliers do not bias the
    # replacement. It is computed as (total - flagged) / n_good to avoid copying
    # the full 4D array. n_good >= 1: the minimum (upper mode) or maximum (lower
    # mode) of the map can never pass the strict comparison against a percentile.
    total = np.sum(data, axis=(0, 1), dtype=np.float64)
    spike_total = np.sum(data[spike_mask], axis=0, dtype=np.float64)
    n_good = spike_mask.size - n_spikes
    pacbed = (total - spike_total) / n_good

    # Broadcasts the (c, d) pattern onto every flagged (k, c, d) position.
    data[spike_mask] = pacbed

    return data


# Compute device passed as `device=` to the parallax and ptychography classes below.
# Switch to the commented-out line to run those reconstructions on a GPU.
device = "cpu"
#device = "gpu"

In [None]:
# select the file you want to load
file_adr = tkf.askopenfilename()
# The dialog returns an empty value when it is cancelled. Fail here with a clear
# message instead of a confusing error in the load cell.
if not file_adr:
    raise RuntimeError("No file selected - re-run this cell and choose a data file.")
print(file_adr)

In [None]:
# load a raw file and specify the calibration info
# Alternative loader for files that py4DSTEM.io.import_file can read. It is kept
# as comments: a bare triple-quoted string at the end of a cell would be echoed
# by Jupyter as the cell's output. To use it, uncomment it and skip the tif cell.
#
# dataset = py4DSTEM.io.import_file(file_adr)
#
# Rx, Ry = 0.30038461538461536, 0.30038461538461536
# R_unit = "A"
# Qx, Qy = 1.07, 1.07
# Q_unit = "mrad"
#
# dataset.calibration.set_Q_pixel_size(Qx)
# dataset.calibration.set_Q_pixel_units(Q_unit)
# dataset.calibration.set_R_pixel_size(Rx)
# dataset.calibration.set_R_pixel_units(R_unit)
#
# print(dataset)
# print(dataset.calibration)
#
# HT = 200E3 # [V]

In [None]:
# load a tif file and specify the calibration info
_data = tifffile.imread(file_adr)
print(_data.shape)
# Everything downstream (axis swap, DataCube) expects a 4D stack.
assert _data.ndim == 4, f"expected a 4D stack, got shape {_data.shape}"

dataset = py4DSTEM.DataCube(data=_data)

# Only Rx and Qx are written to the calibration below; Ry and Qy are kept for reference.
Rx, Ry = 0.3, 0.3
R_unit = "A"
Qx, Qy = 1.0, 1.0
Q_unit = "mrad"

# Public setters instead of writing into the private `_params` dict.
dataset.calibration.set_Q_pixel_size(Qx)
dataset.calibration.set_Q_pixel_units(Q_unit)
dataset.calibration.set_R_pixel_size(Rx)
dataset.calibration.set_R_pixel_units(R_unit)

print(dataset)
print(dataset.calibration)

HT = 200E3 # beam energy passed as `energy=` to the reconstructions [V]

In [None]:
# (optional) swap the two axis pairs (a, b, c, d) -> (c, d, a, b).
# Presumably needed when the stack's first two axes are not the scan axes - confirm for your data.
# NOTE: this is not idempotent - running the cell a second time swaps the axes back.
# Re-run the load cell before re-running this one.

dataset.data = fourd_roll_axis(dataset.data)

print(dataset)
print(dataset.calibration)

In [None]:
# (optional) clean the data before processing
dataset.data = np.nan_to_num(dataset.data) # NaN -> 0 (np.nan_to_num also maps +/-inf to large finite values)
#dataset.data = spike_remove(dataset.data, percent_thresh=0.01, mode="lower") # remove spike pixels (replace the spike pixels with the pacbed) -> optional stopgap

In [None]:
# Mean and max diffraction patterns over all scan positions. dp_mean is read back
# from the dataset tree below; dp_max is presumably stored the same way but is not
# used later in this notebook.
dataset.get_dp_mean()
dataset.get_dp_max()

py4DSTEM.show(
    dataset.tree('dp_mean'),
    scaling = 'log',
    cmap = 'jet',
)

# Estimate the radius of the BF disk, and the center coordinates
# Get probe radius in pixels (radius, qx0, qy0 are reused by the virtual-detector cell)
probe_radius_pixels, probe_qx0, probe_qy0 = dataset.get_probe_size(plot = False)
print(probe_radius_pixels, probe_qx0, probe_qy0)

In [None]:
# Make a virtual bright field and dark field image
expand_BF = 2.0  # expand radius by 2 pixels to encompass the full center disk

# Detector geometry (pixels), shared by the overlay and both virtual images.
center = (probe_qx0, probe_qy0)
radius_BF = probe_radius_pixels + expand_BF
radii_DF = (radius_BF, 1e3)  # annulus from the BF edge out to a very large outer radius

# Overlay the BF detector on the mean diffraction pattern.
py4DSTEM.show(
    dataset.tree('dp_mean'),
    scaling = 'log',
    cmap = 'gray',
    circle = {
        'center': center,
        'R': radius_BF,
        'alpha': 0.3,
        'fill': True,
    },
)

# Integrate over each detector; results are retrieved from the tree by name below.
for detector_mode, detector_geometry, image_name in (
    ('circle', (center, radius_BF), 'bright_field'),
    ('annulus', (center, radii_DF), 'dark_field'),
):
    dataset.get_virtual_image(
        mode = detector_mode,
        geometry = detector_geometry,
        name = image_name,
        shift_center = False,
    )

# plot the virtual images side by side
py4DSTEM.show(
    [dataset.tree('bright_field'), dataset.tree('dark_field')],
    cmap='viridis',
    ticks = False,
    axsize=(4,4),
    title=['Bright Field','Dark Field'],
)

In [None]:
# Initialize the DPC reconstruction from the calibrated datacube at beam energy
# HT (in V, set in the load cell) and run its preprocessing step.
dpc = py4DSTEM.process.phase.DPCReconstruction(
    datacube=dataset,
    energy = HT,
).preprocess()

In [None]:
# Run 8 DPC iterations from scratch (reset=True), keeping intermediate iterations
# (store_iterations=True) for the iteration grid drawn by visualize().
dpc.reconstruct(
    max_iter=8,
    store_iterations=True,
    reset=True,
).visualize(
    iterations_grid='auto',
    figsize=(12,8)
)

In [None]:
# initialize the reconstruction class
# Parallax reconstruction on `device`, followed by preprocessing without the average-BF plot.
# NOTE(review): edge_blend=8 presumably sets the width (pixels) of an edge taper - confirm in the py4DSTEM docs.
parallax = py4DSTEM.process.phase.ParallaxReconstruction(
    datacube=dataset,
    energy = HT,
    device = device, 
    verbose = True
).preprocess(
    normalize_images=True,
    plot_average_bf=False,
    edge_blend=8,
)

In [None]:
# Run the parallax alignment from scratch (reset=True).
# NOTE(review): min_alignment_bin / max_iter_at_min_bin presumably control the finest
# binning level and how many iterations run there - confirm before tuning.
parallax = parallax.reconstruct(
    reset=True,
    regularizer_matrix_size=(1,1),
    regularize_shifts=True,
    running_average=True,
    min_alignment_bin = 2,
    max_iter_at_min_bin = 6,
)

In [None]:
# Fit aberrations from the parallax shifts (the fitted C1 is read later as
# parallax.aberration_C1) and plot the CTF comparison.
parallax.aberration_fit(
    plot_CTF_compare = True,
)

In [None]:
parallax.aberration_correct(figsize=(10, 10))  # apply the fitted correction and display the result

In [None]:
depth_sections = parallax.depth_section(depth_angstroms=np.arange(-300, 310, 100), figsize=(20, 10))  # depths -300..+300 A in 100 A steps (arange stop 310 is exclusive)

In [None]:
# Get the probe convergence semiangle from the pixel size and estimated radius in pixels
semiangle_cutoff = dataset.calibration.get_Q_pixel_size() * probe_radius_pixels
print('semiangle cutoff estimate = ' + str(np.round(semiangle_cutoff,decimals=1)) + ' mrads')

# Get the estimated defocus from the parallax reconstruction - note that defocus dF has the opposite sign as the C1 aberration!
# `defocus` is passed as `defocus=` to every ptychography class below, so it must already
# carry the dF convention (dF = -C1). Previously C1 itself was passed while -C1 was printed.
defocus = -parallax.aberration_C1
print('estimated defocus         = ' + str(np.round(defocus)) + ' Angstroms')

# Rotation between diffraction and scan coordinates, forced onto the CoM fit below.
rotation_degrees = np.rad2deg(parallax.rotation_Q_to_R_rads)
print('estimated rotation        = ' + str(np.round(rotation_degrees)) + ' deg')

In [None]:
# Single-slice ptychography seeded with the semiangle, defocus and rotation estimated above.
# The CoM rotation is forced to the parallax estimate (no transpose) instead of being fitted.
# NOTE(review): fit_function="constant" presumably selects the CoM background fit model - confirm.
ptycho = py4DSTEM.process.phase.SingleslicePtychographicReconstruction(
    datacube=dataset,
    device = device,
    energy = HT,
    semiangle_cutoff = semiangle_cutoff,
    defocus = defocus,
    object_type='potential',
).preprocess(
    plot_center_of_mass = False, 
    plot_rotation = False, 
    plot_probe_overlaps = True, 
    force_com_rotation = rotation_degrees, 
    force_com_transpose = False,
    fit_function = "constant",
)

In [None]:
# 20 iterations from scratch, storing intermediate iterations for the grid plot.
# step_size / gaussian_filter_sigma are optional knobs left at their defaults here.
ptycho = ptycho.reconstruct(
    reset = True,
    store_iterations = True,
    max_iter = 20,
    #step_size = 0.5,
    #gaussian_filter_sigma = 0.3,
    normalization_min=1,
).visualize(
    iterations_grid= 'auto',
    figsize= (16,8),
)

In [None]:
ptycho.visualize(figsize=(16, 8))  # final single-slice result

In [None]:
# Full reconstructed object (left, pixel axes) next to its rotated, cropped field of view (right).
rotated_crop = ptycho._crop_rotate_object_fov(ptycho.object)
rotated_shape = rotated_crop.shape

# imshow extent is [left, right, bottom, top]; top = 0 keeps the origin in the
# upper-left corner, like the left panel.
rotate_extent = [0, rotated_shape[1] * ptycho.sampling[1], rotated_shape[0] * ptycho.sampling[0], 0]

fig, ax = plt.subplots(1, 2, figsize=(8, 4))
left_panel, right_panel = ax
left_panel.imshow(np.abs(ptycho.object), cmap='inferno')
right_panel.imshow(np.abs(rotated_crop), cmap='inferno', extent=rotate_extent)
fig.tight_layout()
plt.show()

In [None]:
# Mixed-state ptychography with 2 probe modes, using the same semiangle/defocus
# estimates as the single-slice run (the CoM rotation is not forced here).
ptycho_mix = py4DSTEM.process.phase.MixedstatePtychographicReconstruction(
    datacube=dataset,
    verbose=True,
    energy=HT,
    num_probes=2,
    semiangle_cutoff=semiangle_cutoff,
    defocus=defocus,
    device=device,
    object_type='potential',
).preprocess(
    plot_center_of_mass = False,
    plot_rotation = False,
)

In [None]:
# 20 mixed-state iterations from scratch, storing intermediate iterations for the grid plot.
ptycho_mix = ptycho_mix.reconstruct(
    reset=True,
    store_iterations=True,
    max_iter = 20,
    normalization_min= 1,
    #gaussian_filter_sigma=0.02,
    #step_size=0.5,
).visualize(
    iterations_grid= 'auto',
    figsize= (16,8)
)

In [None]:
ptycho_mix.visualize(figsize=(16, 8))  # final mixed-state result

In [None]:
# Mixed-state object (left, pixel axes) next to its rotated, cropped field of view (right).
rotated_crop = ptycho_mix._crop_rotate_object_fov(ptycho_mix.object)
rotated_shape = rotated_crop.shape

# imshow extent is [left, right, bottom, top]; top = 0 keeps the origin in the
# upper-left corner, like the left panel.
rotate_extent = [0, rotated_shape[1] * ptycho_mix.sampling[1], rotated_shape[0] * ptycho_mix.sampling[0], 0]

fig, ax = plt.subplots(1, 2, figsize=(8, 4))
left_panel, right_panel = ax
left_panel.imshow(np.abs(ptycho_mix.object), cmap='inferno')
right_panel.imshow(np.abs(rotated_crop), cmap='inferno', extent=rotate_extent)
fig.tight_layout()
plt.show()

In [None]:
# Share of the total intensity carried by each mixed-state probe mode, plus the
# Fourier-space intensity of every mode. Works for any num_probes, not just 2.
intensity_arrays = np.array([np.abs(probe)**2 for probe in ptycho_mix.probe])
total_intensity = intensity_arrays.sum()  # hoisted: identical for every mode
probe_ratio = [np.sum(intensity_array) / total_intensity for intensity_array in intensity_arrays]

py4DSTEM.show(
    [np.abs(probe_fourier)**2 for probe_fourier in ptycho_mix.probe_fourier],
    scalebar=True,
    pixelsize=ptycho_mix.angular_sampling[0],
    pixelunits="mrad",
    ticks=False,
    title=[f"Probe {i} intensity: {ratio*100:.1f}%" for i, ratio in enumerate(probe_ratio)],
)

In [None]:
# Multislice setup: number of slices and the thickness of each slice
# (presumably in Angstrom, like the other lengths here - confirm).
num_slice = 8
slice_thickness = 2.23

# Multislice ptychography seeded with the same semiangle/defocus estimates as above.
ms_ptycho = py4DSTEM.process.phase.MultislicePtychographicReconstruction(
    datacube=dataset,
    num_slices=num_slice,
    slice_thicknesses=slice_thickness,
    verbose=True,
    energy=HT,
    defocus=defocus,
    semiangle_cutoff=semiangle_cutoff,
    device=device,
).preprocess(
    plot_center_of_mass = False,
    plot_rotation=False,
)

In [None]:
# 20 multislice iterations from scratch. store_iterations=True keeps the per-iteration
# objects in ms_ptycho.object_iterations, which the slice plot below reads.
ms_ptycho = ms_ptycho.reconstruct(
    reset=True,
    store_iterations=True,
    max_iter = 20,
    normalization_min=1,
).visualize(
    iterations_grid = 'auto'
)

In [None]:
ms_ptycho.visualize(figsize=(16, 8))  # final multislice result

In [None]:
# Phase of each reconstructed slice, followed by the projected (summed) phase.
fig, ax = plt.subplots(1, num_slice+1, figsize=(4*(num_slice+1), 4))

# Most recent stored iteration. A hard-coded index of 29 assumes at least 30
# stored iterations, but the reconstruction above runs max_iter = 20.
slice_object = ms_ptycho.object_iterations[-1]

for i in range(num_slice):
    rotated_crop = ms_ptycho._crop_rotate_object_fov(slice_object[i])
    rotated_shape = rotated_crop.shape

    # imshow extent is [left, right, bottom, top]; top = 0 keeps the origin at the upper left.
    rotate_extent = [
        0,
        ms_ptycho.sampling[1] * rotated_shape[1],
        ms_ptycho.sampling[0] * rotated_shape[0],
        0,
    ]
    ax[i].imshow(np.angle(rotated_crop), cmap='inferno', extent=rotate_extent)
    ax[i].set_title(f"slice {i}")

# The projected phase is the sum of the slice phases. np.angle of the summed complex
# slices is not the same quantity. Assumes a complex-valued object, as the per-slice
# np.angle display above already does.
sum_object = ms_ptycho._crop_rotate_object_fov(np.sum(np.angle(slice_object), axis=0))
# Extent from this panel's own shape instead of reusing the loop's last value.
sum_extent = [
    0,
    ms_ptycho.sampling[1] * sum_object.shape[1],
    ms_ptycho.sampling[0] * sum_object.shape[0],
    0,
]
ax[-1].imshow(sum_object, cmap='inferno', extent=sum_extent)
ax[-1].set_title("sum of slice phases")
fig.tight_layout()
plt.show()

In [None]:
# A separate multislice object with the same setup as ms_ptycho. It is used only for
# the slice-number/thickness sweep below, so ms_ptycho itself is not modified.
ms_ptycho_tune = py4DSTEM.process.phase.MultislicePtychographicReconstruction(
    datacube=dataset,
    num_slices=num_slice,
    slice_thicknesses=slice_thickness,
    verbose=True,
    energy=HT,
    defocus=defocus,
    semiangle_cutoff=semiangle_cutoff,
    device=device,
).preprocess(
    plot_center_of_mass = False,
    plot_rotation=False,
)

In [None]:
# Sweep the number of slices and the slice thickness around the current guesses.
# The result goes into a new name: with return_values=True the method's return
# value should not overwrite the reconstruction object held in ms_ptycho_tune,
# which would also break re-running this cell.
# NOTE(review): thicknesses_step_size (20) is much larger than the thickness guess
# (slice_thickness = 2.23), so the sweep presumably includes non-positive thicknesses.
# Confirm how py4DSTEM builds the grid and consider a step smaller than slice_thickness.
tune_results = ms_ptycho_tune.tune_num_slices_and_thicknesses(
    num_slices_guess=num_slice,
    thicknesses_guess=slice_thickness,
    num_slices_step_size=1,
    thicknesses_step_size=20,
    num_slices_values=3,
    num_thicknesses_values=3,
    update_defocus=False,
    max_iter=5,
    plot_reconstructions=True,
    plot_convergence=True,
    return_values=True,
)