# Notebook to submit a slurm job for ptychography
## PtyREX / abTEM (latest or legacy) / py4DSTEM
### singleslice / mixed-state / multislice (unavailable for abTEM latest) / mixed-state multislice (py4DSTEM only)
Created by Jinseok Ryu (jinseok.ryu@diamond.ac.uk)

In [None]:
import os
import sys
import time
import subprocess
import numpy as np
import matplotlib.pyplot as plt
import h5py as h5
import json
import py4DSTEM
import hyperspy.api as hs

# Report which py4DSTEM version this kernel is running.
print(py4DSTEM.__version__)

# Device passed to the Parallax reconstruction cell further down.
device = "cpu"

# ---- input locations: edit these for each dataset ----
data_path = '/dls/e02/data/2024/cm37231-3/processing/TiO2_jryu/imaging/mg38695-2/Merlin/sample_5/20240803_101410/20240803_101410_data.hdf5'
mask_path = '/dls/science/groups/e02/Ryu/RYU_at_ePSIC/python_ptycho/mask/29042024_12bitmask.h5'

# The metadata file name is derived from the data file name by dropping its
# last 10 characters ("_data.hdf5" for the path above) and appending ".hdf".
# Use  meta_path = ''  instead to skip the metadata cell.
meta_path = data_path[:-len("_data.hdf5")] + ".hdf"

# File name without directory and without anything after the first dot.
data_name = data_path.rsplit("/", 1)[-1].partition(".")[0]

# Results are written next to the data file.
save_dir = os.path.dirname(data_path)

for shown_value in (meta_path, data_path, data_name, mask_path):
    print(shown_value)

In [None]:
def Meta2Config(acc,nCL,aps):
    """Look up the nominal rotation angle and convergence semi-angle.

    Parameters
    ----------
    acc : number
        Acceleration voltage; compared for equality against 80e3, 200e3
        and 300e3.
    nCL : any
        Nominal camera length. Kept for interface compatibility; the lookup
        does not use it.
    aps : number
        Aperture index used to select the convergence semi-angle.

    Returns
    -------
    tuple
        ``(rot_angle, conv_angle)``. ``rot_angle`` is 0 when the voltage is
        not tabulated. ``conv_angle`` is ``None`` when the voltage or the
        aperture is not tabulated, so the caller must supply the semi-angle
        manually in that case (previously this raised UnboundLocalError).
    """
    # voltage -> (rotation angle, {aperture index -> convergence semi-angle})
    calibration = {
        80e3: (238.5, {1: 41.65, 2: 31.74, 3: 24.80, 4: 15.44}),
        200e3: (-75, {1: 37.7, 2: 28.8, 3: 22.4, 4: 14.0, 5: 6.4}),
        300e3: (-85.5, {1: 44.7, 2: 34.1, 3: 26.7, 4: 16.7}),
    }

    # Equality scan (rather than dict lookup) keeps the original `==`
    # semantics for NumPy scalars read from the metadata file.
    entry = next(
        (value for voltage, value in calibration.items() if acc == voltage),
        None,
    )
    if entry is None:
        print('Rotation angle for this acceleration voltage is unknown, please collect calibration data. Rotation angle being set to zero')
        print('Convergence semi angle is unknown as well, please set it manually')
        return 0, None

    rot_angle, aperture_table = entry
    conv_angle = next(
        (angle for index, angle in aperture_table.items() if aps == index),
        None,
    )
    if conv_angle is None:
        print('the aperture being used has unknown convergence semi angle please consult confluence page or collect calibration data')

    return rot_angle, conv_angle


# Read the acquisition metadata (only when a metadata file was given) and
# derive the nominal rotation angle / convergence semi-angle from it.
if meta_path != '':
    _rule = "----------------------------------------------------------"
    with h5.File(meta_path,'r') as meta_file:
        _meta_group = meta_file['metadata']

        def _show_and_read(key):
            """Print a separator, the dataset object and its value; return the value."""
            print(_rule)
            print(_meta_group[key])
            print(_meta_group[key][()])
            return _meta_group[key][()]

        defocus_exp = _show_and_read("defocus(nm)")*10 # Angstrom
        HT = _show_and_read("ht_value(V)")
        scan_step = _show_and_read("step_size(m)") * 1E10 # Angstrom
        camera_length = _show_and_read('nominal_camera_length(m)')
        aperture = _show_and_read('aperture_size')
        print(_rule)

        rotation_angle_exp, semiangle = Meta2Config(HT, camera_length, aperture)

In [None]:
# Load the detector mask (only when a mask file was given), invert it and
# show it. The inverted mask is multiplied into the diffraction patterns in
# the masking cell, so after inversion 1.0 marks pixels that are kept.
if mask_path != '':
    # Two file layouts are supported. The file is opened once; h5py raises
    # KeyError when a group/dataset is missing, which selects the fallback.
    with h5.File(mask_path,'r') as mask_file:
        try:
            mask = mask_file['data']['mask'][()]
        except KeyError:
            mask = mask_file['root']['np.array']['data'][()]

    # boolean -> inverted -> float32 (same post-processing for both layouts)
    mask = np.invert(np.bool_(mask)).astype(np.float32)

    print(type(mask))
    print(mask.dtype)

    fig, ax = plt.subplots(1, 1, figsize=(5, 5))
    ax.imshow(mask)
    fig.tight_layout()
    plt.show()

In [None]:
%%time
### VERY IMPORTANT VARIABLE ###
# semiangle = 34.1 # mrad
# defocus_exp = 0.0
### ####################### ###

# Load the 4D-STEM stack into `original_stack` (float32). The reader is
# chosen from the file extension of `data_path`.
data_ext = data_path.split(".")[-1]

if data_ext in ("hspy", "dm4"):
    # "hspy": simulated 4DSTEM data made with 'submit_abTEM_4DSTEM_simulation.ipynb'
    # stored in /dls/science/groups/e02/Ryu/RYU_at_ePSIC/multislice_simulation/submit_abtem/submit_abtem_4DSTEM_simulation.ipynb
    # Both formats are read through HyperSpy; collect the axis calibration first.
    original_stack = hs.load(data_path)
    print(original_stack)
    n_dim = len(original_stack.data.shape)
    scale = []
    origin = []
    unit = []
    size = []

    for i in range(n_dim):
        axis = original_stack.axes_manager[i]
        print(axis.scale, axis.offset, axis.units, axis.size)
        scale.append(axis.scale)
        origin.append(axis.offset)
        unit.append(axis.units)
        size.append(axis.size)

    if data_ext == "hspy":
        # NOTE(review): eval() executes whatever text is stored in the file
        # metadata. Only open files you trust, or replace this with
        # ast.literal_eval / float once the stored format is confirmed.
        HT = eval(original_stack.metadata["HT"])
        if HT < 1000:
            # presumably the value was stored in kV - TODO confirm
            HT *= 1000
        try:
            defocus_exp = eval(original_stack.metadata["defocus"])
            semiangle = eval(original_stack.metadata["semiangle"])
        except Exception:
            # Best effort: keep whatever defocus_exp / semiangle were set earlier.
            print("No metadata found")
        scan_step = scale[0]
    else:
        HT = 1000 * original_stack.metadata['Acquisition_instrument']['TEM']['beam_energy']
        # presumably converts the axis scale from nm to Angstrom - TODO confirm
        scan_step = scale[0] * 10

    print("HT: ", HT)
    print("experimental defocus: ", defocus_exp)
    print("semiangle: ", semiangle)
    print("scan step: ", scan_step)
    original_stack = original_stack.data

    if data_ext == "hspy":
        det_name = 'ePSIC_EDX'
        data_key = 'Experiments/__unnamed__/data'
    else:
        # Rescale the dm4 intensities to the range [0, 128].
        # det_name / data_key are not set for dm4 data; the PtyREX JSON
        # writer needs them, the other packages do not.
        original_stack = original_stack.astype(np.float32)
        original_stack -= np.min(original_stack)
        original_stack /= np.max(original_stack)
        original_stack *= 128.0

elif data_ext in ("hdf", "hdf5", "h5"):
    # Two HDF5 layouts are supported. The context manager closes the file
    # even when the first lookup fails; h5py raises KeyError for a missing
    # group/dataset, which selects the second layout.
    with h5.File(data_path,'r') as f:
        print(f)
        try:
            original_stack = f['Experiments']['__unnamed__']['data'][:]
            det_name = 'ePSIC_EDX'
            data_key = 'Experiments/__unnamed__/data'
        except KeyError:
            original_stack = f['data']['frames'][:]
            det_name = 'pty_data'
            data_key = "data/frames"

else:
    # Stop here instead of failing on the next line with a NameError (or,
    # worse, silently reusing a stack left over from an earlier run).
    raise ValueError("Wrong data format!")

# copy=False avoids duplicating a stack that is already float32.
original_stack = original_stack.astype(np.float32, copy=False)
print(original_stack.dtype)
print(original_stack.shape)
print(np.min(original_stack), np.max(original_stack))

In [None]:
%%time
# masking

# Multiply every diffraction pattern by the mask in a single in-place
# broadcast operation (the mask is broadcast over the leading scan axes)
# instead of looping over scan positions in Python.
if mask_path != '' and isinstance(mask, np.ndarray):
    original_stack *= mask

# Left: intensity summed over the detector axes; right: first diffraction pattern.
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].imshow(original_stack.sum(axis=(2,3)), cmap='inferno')
ax[1].imshow(original_stack[0, 0], cmap='jet')
fig.tight_layout()
plt.show()

In [None]:
# scan region cropping
# crop_R switches the crop on in the next cell; crop_R_region holds
# (axis-0 start, axis-0 stop, axis-1 start, axis-1 stop).
crop_R = False
crop_R_region = (16,-16,16,-16)

# Preview of the cropped scan region (shown regardless of crop_R).
scan_preview = original_stack[crop_R_region[0]:crop_R_region[1], crop_R_region[2]:crop_R_region[3]]
print(scan_preview.shape)

fig, ax = plt.subplots(1, 1, figsize=(5, 5))
ax.imshow(scan_preview.sum(axis=(2,3)), cmap='inferno')
fig.tight_layout()
plt.show()

In [None]:
# Select the scan region to process: the cropped view when crop_R is set,
# otherwise the full stack.
being_processed = (
    original_stack[crop_R_region[0]:crop_R_region[1], crop_R_region[2]:crop_R_region[3]]
    if crop_R
    else original_stack
)
bp_shape = being_processed.shape

print(bp_shape)

In [None]:
# DP region cropping
# crop_Q switches the crop on in the next cell; crop_Q_region holds
# (axis-2 start, axis-2 stop, axis-3 start, axis-3 stop).
crop_Q = False
crop_Q_region = (64,-31,64,-31)

# Preview of the cropped first diffraction pattern (shown regardless of crop_Q).
dp_preview = original_stack[0, 0, crop_Q_region[0]:crop_Q_region[1], crop_Q_region[2]:crop_Q_region[3]]
print(dp_preview.shape)

fig, ax = plt.subplots(1, 1, figsize=(5, 5))
ax.imshow(dp_preview, cmap='jet')
fig.tight_layout()
plt.show()

In [None]:
# Crop the detector axes when requested; the shape is refreshed either way.
if crop_Q:
    being_processed = being_processed[:, :, crop_Q_region[0]:crop_Q_region[1], crop_Q_region[2]:crop_Q_region[3]]
bp_shape = being_processed.shape

print(bp_shape)

In [None]:
# Q binning information

binsize = 1

# Trim each detector axis to a multiple of binsize. The remainder is computed
# per axis so that non-square diffraction patterns are handled correctly.
keep_qx = bp_shape[2] - bp_shape[2] % binsize
keep_qy = bp_shape[3] - bp_shape[3] % binsize
if (keep_qx, keep_qy) != (bp_shape[2], bp_shape[3]):
    being_processed = being_processed[:, :, :keep_qx, :keep_qy]
    bp_shape = being_processed.shape
print(bp_shape)

In [None]:
# Wrap the selected data in a py4DSTEM DataCube.
dataset = py4DSTEM.DataCube(data=being_processed)
print("original dataset", dataset, sep="\n")

# Drop the notebook-level reference to reduce the memory usage.
del being_processed
# del original_stack # to reduce the memory usage

# Bin the detector axes only when a bin factor above 1 was chosen.
if binsize > 1:
    dataset.bin_Q(binsize)
    print("after binning", dataset, sep="\n")

# Compute the mean and maximum diffraction patterns; the mean is read back
# from the tree below for display.
dataset.get_dp_mean()
dataset.get_dp_max()

fig, ax = plt.subplots(figsize=(5, 5))
ax.imshow(dataset.tree('dp_mean')[:, :], cmap='jet')
fig.tight_layout()
plt.show()

In [None]:
# Estimate the probe radius and centre (in detector pixels) from the data.
probe_radius_pixels, probe_qx0, probe_qy0 = dataset.get_probe_size(thresh_lower=0.065, thresh_upper=0.15, N=100, plot=True)
plt.show()

# Calibrate through the public Calibration setters rather than writing to the
# private `_params` dict.
# Q pixel size: convergence semi-angle (mrad) divided by the probe radius in pixels.
dataset.calibration.set_Q_pixel_size(semiangle / probe_radius_pixels)
dataset.calibration.set_Q_pixel_units("mrad")
# R pixel size: scan step in Angstrom.
dataset.calibration.set_R_pixel_size(scan_step)
dataset.calibration.set_R_pixel_units("A")

print(dataset)
print(dataset.calibration)

In [None]:
# Physical constants (SI units)
light_speed = 299792458 # speed of light [m/s]
m0 = 9.1093837E-31 # mass of an electron [kg]
planck = 6.62607015E-34 # h [m^2*kg/s]
e_volt = 1.602176634E-19 # eV [m^2*kg/s^2]

# Relativistic electron wavelength for accelerating voltage HT, scaled by 1E10 (m -> Angstrom)
wavelength = planck/np.sqrt(2*m0*HT*e_volt*(1+HT*e_volt/(2*m0*light_speed**2)))*1E10

# Real-space / reciprocal-space extents taken from the calibrated datacube
q_pixel = dataset.calibration.Q_pixel_size
r_pixel = dataset.calibration.R_pixel_size
R_extent = r_pixel * dataset.shape[0]
k_extent = q_pixel * dataset.shape[2]
recon_R_pixel_size = wavelength / k_extent * 1000
recon_R_extent = wavelength / q_pixel * 1000

print(f"scan step size: {dataset.calibration.R_pixel_size:f} Å")
print(f"reconstructed pixel size: {recon_R_pixel_size:f} Å")
print(f"scan extent: {R_extent:f} Å")
print(f"reconstructed extent: {recon_R_extent:f} Å")

In [None]:
# Make a virtual bright field and dark field image
center = (probe_qx0, probe_qy0)
radius_BF = probe_radius_pixels
radii_DF = (probe_radius_pixels, int(dataset.Q_Nx/2))

# (name, detector mode, geometry): a disc of the probe radius for the bright
# field, an annulus from the probe radius out to half the detector width for
# the dark field.
for image_name, detector_mode, detector_geometry in (
    ('bright_field', 'circle', (center, radius_BF)),
    ('dark_field', 'annulus', (center, radii_DF)),
):
    dataset.get_virtual_image(
        mode=detector_mode,
        geometry=detector_geometry,
        name=image_name,
        shift_center=False,
    )

virtual_images = [dataset.tree('bright_field'), dataset.tree('dark_field')]
py4DSTEM.show(virtual_images, cmap='inferno', figsize=(10, 10))
plt.show()

# fig, ax = plt.subplots(1, 2, figsize=(10, 5))
# ax[0].imshow(dataset.tree('bright_field')[:, :], cmap="inferno")
# ax[0].set_title("BF image")
# ax[1].imshow(dataset.tree('dark_field')[:, :], cmap="inferno")
# ax[1].set_title("ADF image [%.1f, %.1f] mrad"%(radii_DF[0]*dataset.Q_pixel_size, radii_DF[1]*dataset.Q_pixel_size))
# fig.tight_layout()
# plt.show()

In [None]:
# DPC object built from the calibrated datacube and preprocessed with the
# default settings.
dpc_settings = {"datacube": dataset, "energy": HT}
dpc = py4DSTEM.process.phase.DPC(**dpc_settings).preprocess()
plt.show()

In [None]:
# Run the DPC reconstruction, keeping every iteration, then plot the stored iterations.
dpc_recon_settings = dict(
    max_iter=8,
    store_iterations=True,
    reset=True,
    gaussian_filter_sigma=0.1,
    gaussian_filter=True,
    q_lowpass=None,
    q_highpass=0.001,
)
dpc.reconstruct(**dpc_recon_settings).visualize(iterations_grid='auto', figsize=(16, 10))
plt.show()

In [None]:
# Second DPC object whose preprocessing is forced to use the rotation found
# by the first DPC run (converted from radians to degrees), without transpose.
dpc_cor_settings = {"datacube": dataset, "energy": HT, "verbose": False}
dpc_cor = py4DSTEM.process.phase.DPC(**dpc_cor_settings).preprocess(
    force_com_rotation=np.rad2deg(dpc._rotation_best_rad),
    force_com_transpose=False,
)
plt.show()

In [None]:
# Reconstruct the rotation-forced DPC object with the same settings as the
# first DPC run, then plot the stored iterations.
dpc_cor_recon_settings = dict(
    max_iter=8,
    store_iterations=True,
    reset=True,
    gaussian_filter_sigma=0.1,
    gaussian_filter=True,
    q_lowpass=None,
    q_highpass=0.001,
)
dpc_cor.reconstruct(**dpc_cor_recon_settings).visualize(iterations_grid='auto', figsize=(16, 10))
plt.show()

In [None]:
# One row of three panels per DPC object: the two normalised centre-of-mass
# components and their magnitude.
# NOTE(review): the panel titled "CoMx" shows `_com_normalized_y` and the one
# titled "CoMy" shows `_com_normalized_x` - presumably a deliberate axis
# convention swap; confirm.
for dpc_result, title_suffix in ((dpc, ""), (dpc_cor, " - rotation corrected")):
    com_panel_0 = dpc_result._com_normalized_y
    com_panel_1 = dpc_result._com_normalized_x

    fig, ax = plt.subplots(1, 3, figsize=(15, 5))
    ax[0].imshow(com_panel_0, cmap="bwr")
    ax[0].set_title("CoMx" + title_suffix)
    ax[1].imshow(com_panel_1, cmap="bwr")
    ax[1].set_title("CoMy" + title_suffix)
    ax[2].imshow(np.sqrt(com_panel_0**2 + com_panel_1**2), cmap="inferno")
    ax[2].set_title("Magnitude of CoM" + title_suffix)
    fig.tight_layout()
    plt.show()

In [None]:
# Side-by-side comparison of the reconstructed phase from both DPC runs.
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
phase_panels = (
    (dpc.object_phase, "iCoM"),
    (dpc_cor.object_phase, "iCoM - rotation corrected"),
)
for panel, (phase_image, panel_title) in zip(ax, phase_panels):
    panel.imshow(phase_image, cmap="inferno")
    panel.set_title(panel_title)
fig.tight_layout()
plt.show()

In [None]:
# Parallax object built on the calibrated datacube and preprocessed.
# `device` is whatever the notebook currently holds ("cpu" from the first cell
# unless the parameter cell further down has already been run).
parallax_settings = {
    "datacube": dataset,
    "energy": HT,
    "device": device,
    "verbose": True,
}
parallax = py4DSTEM.process.phase.Parallax(**parallax_settings).preprocess(
    normalize_images=True,
    plot_average_bf=False,
    edge_blend=8,
)

In [None]:
# Parallax reconstruction settings; `reconstruct` returns the object that is
# kept as `parallax` for the following cells.
parallax_recon_settings = dict(
    reset=True,
    regularizer_matrix_size=(1,1),
    regularize_shifts=True,
    running_average=True,
    min_alignment_bin=2,
    num_iter_at_min_bin=4,
)
parallax = parallax.reconstruct(**parallax_recon_settings)
plt.show()

In [None]:
# Show the shifts found by the parallax reconstruction.
parallax.show_shifts()
plt.show()

# Subpixel alignment with both comparison plots enabled.
subpixel_settings = {
    # "kde_upsample_factor": 2,
    "kde_sigma_px": 0.125,
    "plot_upsampled_BF_comparison": True,
    "plot_upsampled_FFT_comparison": True,
}
parallax.subpixel_alignment(**subpixel_settings)
plt.show()

In [None]:
# Fit the aberrations (with the CTF comparison plot) ...
parallax.aberration_fit(plot_CTF_comparison=True)
plt.show()

# ... then apply the aberration correction and show the result.
correction_figsize = (5, 5)
parallax.aberration_correct(figsize=correction_figsize)
plt.show()

In [None]:
# Depth sections at 0, 20, ..., 180 Angstrom.
section_depths = np.arange(0, 200, 20)
depth_sections = parallax.depth_section(depth_angstroms=section_depths, figsize=(12, 10))
plt.show()

In [None]:
# Probe convergence semiangle: calibrated pixel size times the fitted probe radius (in pixels)
semiangle_cutoff_estimated = dataset.calibration.get_Q_pixel_size() * probe_radius_pixels
print(f"semiangle cutoff estimate = {np.round(semiangle_cutoff_estimated, decimals=1)!s} mrads")

# Defocus from the parallax reconstruction - note that defocus dF has the opposite sign as the C1 aberration!
defocus_estimated = -parallax.aberration_C1
print(f"estimated defocus         = {np.round(defocus_estimated)!s} Angstroms")

# Rotation between detector and scan coordinates, converted from radians to degrees
rotation_degrees_estimated = np.rad2deg(parallax.rotation_Q_to_R_rads)
print(f"estimated rotation        = {np.round(rotation_degrees_estimated)!s} deg")

# use the calculated calibration information
# (running the next cell afterwards replaces these with the nominal values)
semiangle_cutoff = semiangle_cutoff_estimated
defocus = defocus_estimated
rotation_degrees = rotation_degrees_estimated

In [None]:
# use the nominal calibration information
# Running this cell overrides the estimated values chosen in the previous
# cell; skip it to keep the estimates.
semiangle_cutoff, defocus, rotation_degrees = semiangle, defocus_exp, rotation_angle_exp

In [None]:
# generate the information file
# Please note that
# pyfftw (0.13.1) causes annoying warnings in the reconstruction process using abtem.
# pyfftw (0.12.0) does not cause any warnings -> I use my own python environment and you can access too.
# The reconstructed object and probe will be saved as .hspy files - able to see via DAWN

base_dir = '/dls/science/groups/e02/Ryu/RYU_at_ePSIC/python_ptycho' # directory storing python scripts for slurm job submission
package = "py4dstem" # "ptyrex", "abtem_latest", "abtem_legacy", or "py4dstem"

if package == 'ptyrex':
    device = "gpu" # "cpu" or "gpu"
    gpu_type = "pascal" # "pascal" or "volta"
    gpu_node = 4
else:
    device = "gpu" # "cpu" or "gpu"
    gpu_type = "volta" # "pascal" or "volta"
    gpu_node = 1

# Prefix the data name with a timestamp so repeated submissions do not collide.
data_name = data_path.split("/")[-1].split(".")[0]
data_name = time.strftime("%Y%m%d_%H%M%S") + "_" + data_name

script_path = base_dir + '/python_ptycho.py'

save_path = save_dir + '/%s_ptycho/'%package
if not os.path.isdir(save_path):
    # exist_ok avoids a failure if the directory appears between the check and the call
    os.makedirs(save_path, exist_ok=True)
    print("Save directory created: "+save_path)
if package == "ptyrex":
    cal_name = data_name+"_calibration_info.json"
else:
    cal_name = data_name+"_calibration_info.txt"
sub_name = data_name+"_submit.sh"
log_path = save_path + data_name + "_"

ptycho_type = "multislice" # "singleslice", "mixed-state", "multislice" or "mixed-state-multislice"
num_iteration = 30
num_probe = 4 # for mixed-state ptychography or mixed-state multislice ptychography
num_slice = 12 # for multislice ptychography or mixed-state multislice ptychography
slice_thickness = 20 # for multislice ptychography or mixed-state multislice ptychography

# PtyRex reconstruction parameters
# NOTE: this replaces the nominal camera_length read from the metadata file.
# 0.000055 is presumably the detector pixel size in metres - TODO confirm.
camera_length = (1000 * 0.000055 / (semiangle/(binsize*probe_radius_pixels)))
ptyrex_crop = list(dataset.shape[2:]) # []
shift_radius = 0.5
shift_trial = 3

# Template selection. "mixed-state" is a valid type (the PtyREX JSON writer
# sets the number of probe modes for it and adds MultiSlice entries only for
# the multislice types), so it uses the single-slice template instead of
# being reported as wrong and leaving ptyrex_template undefined.
if ptycho_type in ("multislice", "mixed-state-multislice"):
    ptyrex_template = '/dls/science/groups/e02/Ryu/RYU_at_ePSIC/python_ptycho/ptyrex_template/multislice_example.json'
elif ptycho_type in ("singleslice", "mixed-state"):
    ptyrex_template = '/dls/science/groups/e02/Ryu/RYU_at_ePSIC/python_ptycho/ptyrex_template/singleslice_example.json'
else:
    print("Wrong ptychography type!")

# abTEM reconstruction parameters
alpha = 0.5 # also used in PtyREX
beta = 0.5 # also used in PtyREX
probe_position_correction = False
pre_position_correction_update_steps = 1
position_step_size = 0.5
step_size_damping_rate = 0.995

# py4DSTEM reconstruction parameters
max_batch_size = 256
reconstruction_method = "GD"
reconstruction_parameter = 1.0
normalization_min = 1.0
identical_slices = False
object_positivity = False
tv_denoise = False
tv_denoise_weights = [0.1,0.1]
tv_denoise_inner_iter = 40
tv_denoise_chambolle = False
tv_denoise_weight_chambolle = 0.01

# Expected memory necessary for reconstruction
print("data shape = ", dataset.shape)
# Explicit int64 so the element count cannot overflow a narrower default integer.
num_element = int(np.prod(dataset.shape, dtype=np.int64))
dp_element = int(np.prod(dataset.shape[2:], dtype=np.int64))

# 8 bytes per complex64 element, 16 per complex128 element.
for memory_label, memory_bytes in (
    ("memory for the data array (complex64)", num_element*8),
    ("memory for the data array (complex128)", num_element*16),
    ("memory for the data array (complex64) - number of slices considered", num_element*8*num_slice),
    ("memory for the data array (complex128) - number of slices considered", num_element*16*num_slice),
):
    print("%s = %.3f Bytes"%(memory_label, memory_bytes))
    print("%s = %.3f Gb"%(memory_label, memory_bytes/2**30))
    print()

print("***************************************")
print("abTEM needs more than %.3f Gb memory"%(num_element*8/2**30))
print("py4DSTEM needs more than %.3f Gb memory without max batch"%(num_element*8*num_slice/2**30))
print("py4DSTEM needs more than %.3f Gb memory with max batch"%(max_batch_size*dp_element*8*num_slice/2**30))

In [None]:
# Write the reconstruction settings: a JSON file (filled-in template) for
# PtyREX, a "key : value" text file for the other packages.
if package == "ptyrex":
    def _plain(value):
        """Convert NumPy scalars / 0-d arrays to built-in numbers for json."""
        if isinstance(value, np.generic) or (isinstance(value, np.ndarray) and value.ndim == 0):
            return value.item()
        return value

    with open(ptyrex_template,'r') as f:
        pty_expt = json.load(f)

    # Aliases into the nested template; assignments below modify pty_expt.
    experiment = pty_expt['experiment']
    process = pty_expt['process']
    common = process['common']
    pie = process['PIE']
    use_mask = mask_path != ''

    pty_expt['base_dir'] = save_path
    process['save_dir'] = save_path
    common['scan']['rotation'] = _plain(rotation_degrees)

    experiment['experiment_ID'] = data_name
    experiment['data']['data_path'] = data_path
    experiment['data']['data_key'] = data_key
    experiment['data']['dead_pixel_flag'] = 1 if use_mask else 0
    if use_mask:
        experiment['data']['dead_pixel_path'] = mask_path
    experiment['detector']['position'] = [0, 0, _plain(camera_length)]
    # semiangle_cutoff is scaled by 1E-3 and doubled; defocus and scan_step
    # are scaled by 1E-10 (Angstrom -> m).
    experiment['optics']['lens']['alpha'] = _plain(2*semiangle_cutoff*1E-3)
    experiment['optics']['lens']['defocus'] = [_plain(defocus*1E-10), _plain(defocus*1E-10)]

    common['scan']['N'] = [int(bp_shape[0]), int(bp_shape[1])]
    common['scan']['dR'] = [_plain(scan_step*1E-10), _plain(scan_step*1E-10)]
    common['detector']['mask_flag'] = 1 if use_mask else 0
    common['detector']['bin'] = [binsize, binsize]
    if ptyrex_crop != []:
        common['detector']['crop'] = [int(n) for n in ptyrex_crop]
    else:
        common['detector']['crop'] = [int(bp_shape[2]), int(bp_shape[3])]
    common['detector']['name'] = det_name
    common['probe']['convergence'] = _plain(2*semiangle_cutoff*1E-3)
    common['source']['energy'] = [_plain(HT)]
    process['save_prefix'] = data_name
    pie['iterations'] = num_iteration
    pie['object']['alpha'] = alpha
    pie['probe']['alpha'] = beta
    pie['scan']['shift radius'] = shift_radius
    pie['scan']['shift trials'] = shift_trial
    if ptycho_type == "mixed-state-multislice" or ptycho_type == "mixed-state":
        pie['source']['sx'] = num_probe
    else:
        pie['source']['sx'] = 1

    if ptycho_type == "multislice" or ptycho_type == "mixed-state-multislice":
        pie['MultiSlice']['S_distance'] = slice_thickness * 1E-10
        pie['MultiSlice']['slices'] = num_slice

    with open(save_path+cal_name, 'w') as f:
        json.dump(pty_expt, f, indent=4)

else:
    # `fill`, `fill_region` and `num_neighbor` are not defined by any other
    # cell of this notebook; treat a missing `fill` as False instead of
    # failing with a NameError. Define them beforehand to enable the option.
    fill_enabled = bool(globals().get("fill", False))

    # All lines are assembled first and written in one go, so an error while
    # formatting cannot leave a truncated calibration file behind.
    cal_lines = [
        "package : "+package,
        "data_path : "+data_path,
        "data_name : "+data_name,
        "mask_path : "+mask_path,
        "save_path : "+save_path,
        "device : "+device,
        "HT : %f"%HT,
        "scan_step : %f"%scan_step,
    ]

    if crop_R:
        cal_lines.append("crop_R : True")
        cal_lines.append("crop_R_region : (%d,%d,%d,%d)"%tuple(crop_R_region[:4]))
    else:
        cal_lines.append("crop_R : False")
    if crop_Q:
        cal_lines.append("crop_Q : True")
        cal_lines.append("crop_Q_region : (%d,%d,%d,%d)"%tuple(crop_Q_region[:4]))
    else:
        cal_lines.append("crop_Q : False")
    if fill_enabled:
        cal_lines.append("fill : True")
        cal_lines.append("fill_region : (%d,%d,%d,%d)"%tuple(fill_region[:4]))
        cal_lines.append("num_neighbor : %d"%num_neighbor)
    else:
        cal_lines.append("fill : False")

    cal_lines += [
        "center : (%f,%f)"%(probe_qx0,probe_qy0),
        "reciprocal_pixel_size : %f"%(semiangle*1E-3/probe_radius_pixels),
        "binsize : %d"%binsize,
        "semiangle : %f"%(semiangle_cutoff*1E-3),
        "defocus : %f"%defocus,
        "rotation : %f"%rotation_degrees,
        "ptycho_type : "+ptycho_type,
        "num_iteration : %d"%num_iteration,
    ]

    # Type-specific entries
    if ptycho_type == "singleslice":
        print("singleslice ptychography")
    elif ptycho_type == "mixed-state":
        print("mixed-state ptychography")
        cal_lines.append("num_probe : %d"%num_probe)
    elif ptycho_type == "multislice":
        print("multislice ptychography")
        cal_lines.append("num_slice : %d"%num_slice)
        cal_lines.append("slice_thickness : %f"%slice_thickness)
    elif ptycho_type == "mixed-state-multislice":
        print("mixed-state-multislice ptychography")
        cal_lines.append("num_probe : %d"%num_probe)
        cal_lines.append("num_slice : %d"%num_slice)
        cal_lines.append("slice_thickness : %f"%slice_thickness)
    else:
        print("Wrong type!")

    # Package-specific entries
    if package == "abtem_latest" or package == "abtem_legacy":
        cal_lines.append("alpha : %f"%alpha)
        cal_lines.append("beta : %f"%beta)
        cal_lines.append("step_size_damping_rate : %f"%step_size_damping_rate)
        if probe_position_correction:
            cal_lines.append("probe_position_correction : True")
            cal_lines.append("pre_position_correction_update_steps : %d"%pre_position_correction_update_steps)
            cal_lines.append("position_step_size : %f"%position_step_size)
        else:
            cal_lines.append("probe_position_correction : False")

    elif package == "py4dstem":
        cal_lines.append("max_batch_size : %d"%max_batch_size)
        cal_lines.append("reconstruction_method : %s"%reconstruction_method)
        cal_lines.append("reconstruction_parameter : %f"%reconstruction_parameter)
        cal_lines.append("normalization_min : %f"%normalization_min)
        cal_lines.append("identical_slices : %s"%bool(identical_slices))
        cal_lines.append("object_positivity : %s"%bool(object_positivity))

        if tv_denoise:
            cal_lines.append("tv_denoise : True")
            cal_lines.append("tv_denoise_weights : [%f,%f]"%(tv_denoise_weights[0], tv_denoise_weights[1]))
            cal_lines.append("tv_denoise_inner_iter : %d"%tv_denoise_inner_iter)
        else:
            cal_lines.append("tv_denoise : False")
            cal_lines.append("tv_denoise_weights : None")
            cal_lines.append("tv_denoise_inner_iter : None")

        if ptycho_type == "multislice" or ptycho_type == "mixed-state-multislice":
            if tv_denoise_chambolle:
                cal_lines.append("tv_denoise_chambolle : True")
                cal_lines.append("tv_denoise_weight_chambolle : %f"%tv_denoise_weight_chambolle)
            else:
                cal_lines.append("tv_denoise_chambolle : False")
                cal_lines.append("tv_denoise_weight_chambolle : None")
    else:
        print("Wrong package!")

    with open(save_path+cal_name, 'w') as f:
        f.write("\n".join(cal_lines) + "\n")

In [None]:
# generate batch file
# The script text is assembled as a list of pieces and written in one call.

# Slurm constraint line for the selected GPU type (None for any other value).
gpu_constraint_line = {
    "pascal": "#SBATCH --constraint=NVIDIA_Pascal\n",
    "volta": "#SBATCH --constraint=NVIDIA_Volta\n",
}.get(gpu_type)

if package == "ptyrex":
    script_parts = [
        "#!/usr/bin/env bash\n",
        "#SBATCH --partition=cs05r\n",
        "#SBATCH --job-name=ptyrex_recon\n",
        "#SBATCH --nodes=1\n",
        "#SBATCH --ntasks-per-node=4\n",
        "#SBATCH --cpus-per-task=1\n",
        "#SBATCH --time=12:00:00\n",
        "#SBATCH --output=%s%%j.out\n"%log_path,
        "#SBATCH --error=%s%%j.error\n\n"%log_path,
    ]
    if gpu_constraint_line is not None:
        script_parts.append(gpu_constraint_line)
    script_parts += [
        "#SBATCH --gpus-per-node=%d\n"%gpu_node,
        "#SBATCH --mem=0G\n\n",
        "cd /dls/science/groups/e02/Ryu/python_library/ptyrex_temp_5_test/PtyREX\n",
        "module load python/cuda11.7\n",
        "module load hdf5-plugin/1.12\n",
        # $1 is the calibration JSON path given on the sbatch command line
        "mpirun -np %d ptyrex_recon -c $1"%gpu_node,
    ]

else:
    # conda environment per package (None for an unknown package)
    conda_env = {
        "py4dstem": "/dls/science/groups/e02/Ryu/py_env/python_ptycho",
        "abtem_latest": "/dls/science/groups/e02/Ryu/py_env/python_ptycho",
        "abtem_legacy": "/dls/science/groups/e02/Ryu/py_env/abtem_multi",
    }.get(package)

    script_parts = ["#!/usr/bin/env bash\n"]

    # partition / GPU request
    if device == "gpu":
        script_parts.append("#SBATCH --partition=cs05r\n")
        script_parts.append("#SBATCH --gpus-per-node=%d\n"%gpu_node)
        if gpu_constraint_line is None:
            print("Wrong gpu setting!")
        else:
            script_parts.append(gpu_constraint_line)
    elif device == "cpu":
        script_parts.append("#SBATCH --partition=cs04r\n")
    else:
        print("Wrong device!\n")

    # job name
    if conda_env is None:
        print("Wrong package!\n")
    else:
        script_parts.append("#SBATCH --job-name=%s_recon\n"%package)

    script_parts += [
        "#SBATCH --nodes=1\n",
        "#SBATCH --ntasks-per-node=4\n",
        "#SBATCH --cpus-per-task=1\n",
        "#SBATCH --time=48:00:00\n",
        "#SBATCH --mem=0G\n",
        "#SBATCH --output=%s%%j.out\n"%log_path,
        "#SBATCH --error=%s%%j.error\n\n"%log_path,
    ]

    # environment setup
    if conda_env is None:
        print("Wrong package!\n")
    else:
        script_parts.append("module load python/3\n")
        script_parts.append("conda activate %s\n"%conda_env)

    script_parts.append("python %s %s%s"%(script_path, save_path, cal_name))

with open(save_path+sub_name, 'w') as f:
    f.write("".join(script_parts))

In [None]:
# Submit the batch script through an ssh session on the 'wilson' host and
# show what the remote side printed (including sbatch's reply).
import shlex

SSH_TIMEOUT_S = 120  # give up on the ssh session after this many seconds

# Paths are quoted so that spaces or shell metacharacters cannot break the
# remote command line. PtyREX additionally receives the calibration JSON as
# the first script argument ($1 in the batch script).
submit_command = "sbatch " + shlex.quote(save_path+sub_name)
if package == "ptyrex":
    submit_command += " " + shlex.quote(save_path+cal_name)

remote_commands = "echo END\n" + submit_command + "\nuptime\nlogout\n"

sshProcess = subprocess.Popen(['ssh',
                               '-tt',
                               'wilson'],
                               stdin=subprocess.PIPE,
                               stdout=subprocess.PIPE,
                               stderr=subprocess.STDOUT,
                               universal_newlines=True)
try:
    # communicate() sends the commands, closes stdin, drains stdout and waits
    # for ssh to exit; previously the output was piped but never read and the
    # process was never waited on.
    ssh_output, _ = sshProcess.communicate(remote_commands, timeout=SSH_TIMEOUT_S)
except subprocess.TimeoutExpired:
    sshProcess.kill()
    ssh_output, _ = sshProcess.communicate()
    print("ssh session did not finish within %d s and was terminated"%SSH_TIMEOUT_S)

print(ssh_output)
if sshProcess.returncode != 0:
    print("ssh exited with return code %s - check the output above"%sshProcess.returncode)