# Figure - Iteration time comparison


In [None]:
import os
# Run the notebook from the PtyRAD workspace root so that the relative
# "data/..." and "output/..." paths used in later cells resolve.
# Raw string: "\w" and "\p" are not valid escape sequences, so a plain literal
# emits a SyntaxWarning on Python 3.12+ (and is slated to become an error).
# The resulting path value is unchanged.
work_dir = r"H:\workspace\ptyrad"
os.chdir(work_dir)
print("Current working dir: ", os.getcwd())

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from ptyrad.data_io import load_hdf5, load_pt
from ptyrad.utils import center_crop
import h5py
from natsort import natsorted
from skimage.metrics import structural_similarity as ssim
from scipy.ndimage import fourier_shift


## Use SSIM with ground truth

In [None]:
# Ground-truth file; its 'gt_phase' dataset is used as the SSIM reference
gt_path = "data/paper/simu_tBL_WSe2/phonon_temporal_spatial_N16384_dp128.hdf5"

# PtyRAD output directory: exactly one `dir_ptyrad` line should be active.
# The commented alternatives are kept as a record of the regularization
# settings that were tried (the leading numbers in the "##" notes are
# presumably the resulting SSIM values -- TODO confirm).

# These two seem too noisy, orblur needs to be stronger, or we need to add sparsity
# dir_ptyrad = "output/paper/simu_tBL_WSe2/20250219_ptyrad_convergence/full_N16384_dp128_flipT001_random16_p12_1obj_6slice_dz2_Adam_plr1e-4_oalr5e-4_oplr5e-4_slr1e-4_ozblur1_oathr0.98_opos_sng1.0_1e6/"
# dir_ptyrad = "output/paper/simu_tBL_WSe2/20250219_ptyrad_convergence/full_N16384_dp128_flipT001_random16_p12_1obj_6slice_dz2_Adam_plr1e-4_oalr5e-4_oplr5e-4_slr1e-4_orblur0.2_ozblur1_oathr0.98_opos_sng1.0_1e6/"
# dir_ptyrad = "output/paper/simu_tBL_WSe2/20250219_ptyrad_convergence/full_N16384_dp128_flipT001_random16_p12_1obj_6slice_dz2_Adam_plr1e-4_oalr5e-4_oplr5e-4_slr1e-4_orblur0.3_ozblur1_oathr0.98_opos_sng1.0_1e6/"

# All the below ones show good SSIM
## 0.918, smooth, good, no sparsity
# dir_ptyrad = "output/paper/simu_tBL_WSe2/20250219_ptyrad_convergence/full_N16384_dp128_flipT001_random16_p12_1obj_6slice_dz2_Adam_plr1e-4_oalr5e-4_oplr5e-4_slr1e-4_orblur0.4_ozblur1_oathr0.98_opos_sng1.0_1e6/"

## 0.93, kink, very good
# dir_ptyrad = "output/paper/simu_tBL_WSe2/20250219_ptyrad_convergence/full_N16384_dp128_flipT001_random16_p12_1obj_6slice_dz2_Adam_plr1e-4_oalr5e-4_oplr5e-4_slr1e-4_orblur0.4_ozblur1_oathr0.98_opos_sng1.0_spr0.01_1e6/"

## 0.924, smooth, very good
dir_ptyrad = "output/paper/simu_tBL_WSe2/20250219_ptyrad_convergence/full_N16384_dp128_flipT001_random16_p12_1obj_6slice_dz2_Adam_plr1e-4_oalr5e-4_oplr5e-4_slr1e-4_orblur0.4_ozblur1_oathr0.98_opos_sng1.0_spr0.03_1e6/"

## 0.907, smooth, good but a bit overlapping
# dir_ptyrad = "output/paper/simu_tBL_WSe2/20250219_ptyrad_convergence/full_N16384_dp128_flipT001_random16_p12_1obj_6slice_dz2_Adam_plr1e-4_oalr5e-4_oplr5e-4_slr1e-4_orblur0.3_ozblur1_oathr0.98_opos_sng1.0_spr0.03_1e6/"

## 0.91, wavy, okay, no positivity
# dir_ptyrad = "output/paper/simu_tBL_WSe2/20250219_ptyrad_convergence/full_N16384_dp128_flipT001_random16_p12_1obj_6slice_dz2_Adam_plr1e-4_oalr5e-4_oplr5e-4_slr1e-4_orblur0.4_ozblur1_oathr0.98_sng1.0_1e6/"

# PtyShv and py4DSTEM output directories.
# The PtyShv path uses "/" throughout: the previous "...step128\MLs..." held the
# invalid escape sequence "\M" (SyntaxWarning on Python 3.12+) and only resolved
# on Windows. Forward slashes work on Windows as well.
dir_ptyshv = "data/paper/simu_tBL_WSe2/6/roi6_Ndp128_step128/MLs_L1_p12_g16_pc1_noModel_mm_Ns6_dz2_reg0.1"
dir_py4dstem = "output/paper/simu_tBL_WSe2/20250219_py4dstem_convergence/20250219_N16384_dp128_flipT001_random16_p12_6slice_dz2_update0.5_kzf0.1_1e6/"

# Gather the result file names of each package, sorted.
# PtyRAD and py4DSTEM outputs start with 'model' and are sorted lexicographically.
files_ptyrad = sorted([name for name in os.listdir(dir_ptyrad) if name.startswith('model')])
print(f"Found {len(files_ptyrad)} ptyrad files")

# PtyShv outputs start with 'Niter' and go through natsorted instead of a plain sort.
files_ptyshv = natsorted([name for name in os.listdir(dir_ptyshv) if name.startswith('Niter')])
print(f"Found {len(files_ptyshv)} ptyshv files")

files_py4dstem = sorted([name for name in os.listdir(dir_py4dstem) if name.startswith('model')])
print(f"Found {len(files_py4dstem)} py4dstem files")

In [None]:
# PtyRAD: load every checkpoint and reduce the 'objp' tensor to a 2D-style
# image by squeezing and summing over the first remaining axis.
objs_ptyrad = []
for fname in files_ptyrad:
    ckpt = load_pt(os.path.join(dir_ptyrad, fname))
    objp = ckpt['optimizable_tensors']['objp']
    objs_ptyrad.append(objp.detach().cpu().numpy().squeeze().sum(0))
# The average iteration time is read from the last checkpoint that was loaded
iter_time_ptyrad = ckpt['avg_iter_t']

# PtyShv: the stored 'object' is reinterpreted as complex128, its angle is
# summed over the first axis and the result is transposed.
objs_ptyshv = []
for fname in files_ptyshv:
    with h5py.File(os.path.join(dir_ptyshv, fname), "r") as h5:
        complex_obj = h5['object'][()].view('complex128')
        # Overwritten on every file, so the value from the last file is kept
        iter_time_ptyshv = h5['outputs']['avgTimePerIter'][()].squeeze()[()]
    objs_ptyshv.append(np.angle(complex_obj).sum(0).T)

# py4DSTEM: same reduction with a complex64 view and no transpose; the
# iteration time is the mean of 'iter_times' over its first axis.
objs_py4dstem = []
for fname in files_py4dstem:
    with h5py.File(os.path.join(dir_py4dstem, fname), "r") as h5:
        complex_obj = h5['object'][()].view('complex64')
        iter_time_py4dstem = h5['iter_times'][()].mean(0)
    objs_py4dstem.append(np.angle(complex_obj).sum(0))

In [None]:
# SSIM of every reconstruction against the ground-truth phase.
# Note that the window size and alignment would also impact the SSIM value.
crop_height, crop_width = 170, 170

# Sub-pixel shifts per package, obtained via SIFT registration using imageJ
shift_ptyrad = [-0.205, -0.2675]
shift_ptyshv = [-0.35, -0.416]
shift_py4dstem = [0.32, 0.22]

# Reference image: 'gt_phase' summed over its first axis, then center-cropped
gt_stack = load_hdf5(gt_path, dataset_key='gt_phase')
gt_phase = center_crop(gt_stack.sum(0), crop_height, crop_width, offset=[1, 1])
# The SSIM data range is fixed by the ground truth and is the same for every image
gt_range = gt_phase.max() - gt_phase.min()

# The loop variables img_* are reused on purpose: after each loop they hold the
# last processed image, which the commented-out TIFF export cell refers to.
ssims_ptyrad = []
for img_ptyrad in objs_ptyrad:
    img_ptyrad = center_crop(img_ptyrad, crop_height, crop_width, offset=[0, 0])
    shifted_spectrum = fourier_shift(np.fft.fftn(img_ptyrad), shift_ptyrad)
    img_ptyrad = np.real(np.fft.ifftn(shifted_spectrum))
    img_ptyrad -= img_ptyrad.min()  # shift the minimum to zero

    ssim_ptyrad = ssim(img_ptyrad, gt_phase, data_range=gt_range)
    print(ssim_ptyrad)
    ssims_ptyrad.append(ssim_ptyrad)

ssims_ptyshv = []
for img_ptyshv in objs_ptyshv:
    img_ptyshv = center_crop(img_ptyshv, crop_height, crop_width, offset=[-1, -1])
    shifted_spectrum = fourier_shift(np.fft.fftn(img_ptyshv), shift_ptyshv)
    img_ptyshv = np.real(np.fft.ifftn(shifted_spectrum))
    img_ptyshv -= img_ptyshv.min()

    ssim_ptyshv = ssim(img_ptyshv, gt_phase, data_range=gt_range)
    ssims_ptyshv.append(ssim_ptyshv)

ssims_py4dstem = []
for img_py4dstem in objs_py4dstem:
    img_py4dstem = center_crop(img_py4dstem, crop_height, crop_width, offset=[1, 1])
    shifted_spectrum = fourier_shift(np.fft.fftn(img_py4dstem), shift_py4dstem)
    img_py4dstem = np.real(np.fft.ifftn(shifted_spectrum))
    img_py4dstem -= img_py4dstem.min()

    ssim_py4dstem = ssim(img_py4dstem, gt_phase, data_range=gt_range)
    ssims_py4dstem.append(ssim_py4dstem)

In [None]:
# from tifffile import imwrite
# imwrite('ptyrad.tif', np.float32(objs_ptyrad))
# imwrite('ptyshv.tif', np.float32(objs_ptyshv))
# imwrite('py4dstem.tif', np.float32(objs_py4dstem))
# imwrite('stack.tif', np.stack([np.float32(gt_phase), np.float32(img_ptyrad), np.float32(img_ptyshv), np.float32(img_py4dstem)]).astype('float32'))

In [None]:
# Quick look: SSIM against elapsed time (iteration index x average iteration time).
# Iteration indices of the saved results: 1..19, then 20..200 in steps of 10.
iterations = np.array([*range(1, 20), *range(20, 210, 10)])

plt.figure()
for ssim_curve, iter_time in (
    (ssims_ptyrad, iter_time_ptyrad),
    (ssims_ptyshv, iter_time_ptyshv),
    (ssims_py4dstem, iter_time_py4dstem),
):
    # Only as many iteration indices as there are SSIM values for this package
    elapsed = iterations[:len(ssim_curve)] * iter_time
    plt.plot(elapsed, ssim_curve)
plt.xscale('log')
plt.show()

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import matplotlib as mpl
import matplotlib.ticker as tck

# Global tick styling: inward-pointing ticks with the same major/minor
# geometry on both axes.
for tick_axis in ('xtick', 'ytick'):
    mpl.rc(tick_axis, direction='in')
    mpl.rc(f'{tick_axis}.major', width=1, size=3.5)
    mpl.rc(f'{tick_axis}.minor', width=1, size=2)

# Plot Data
# Note that these data are all computed on 20GB slice

# Iteration indices of the saved results: 1..19, then 20..200 in steps of 10
iterations = np.array([*range(1, 20), *range(20, 210, 10)])

# Each timing table holds one row per package, in the order
# PtyRAD / PtyShv / py4DSTEM, and one column per x value (iteration time in sec).
batch_sizes = np.array([16, 32, 64, 128, 256, 512, 1024])
time_ptyrad_batch, time_ptyshv_batch, time_py4dstem_batch = np.array([
    [19.0, 16.45, 15.18, 14.37, 13.91, 13.84, 13.71],
    [84.08, 42.31, 23.47, 17.18, 13.02, 11.87, 11.15],
    [145.57, 75.85, 43.40, 23.20, 15.47, 12.32, 11.51],
])

probe_modes = np.array([1, 3, 6, 12])
time_ptyrad_modes, time_ptyshv_modes, time_py4dstem_modes = np.array([
    [12.87, 14.18, 19.00, 26.46],
    [19.07, 45.46, 84.08, 161.49],
    [29.70, 74.71, 145.57, 282.91],
])

num_slices = np.array([1, 3, 6])
time_ptyrad_slices, time_ptyshv_slices, time_py4dstem_slices = np.array([
    [7.09, 11.34, 19.00],
    [22.24, 44.38, 84.08],
    [24.26, 82.01, 145.57],
])

# Speedup of PtyRAD relative to py4DSTEM: per-iteration time of the SSIM runs,
# smallest batch size, largest number of probe modes, largest number of slices.
speedup_factor_iter = iter_time_py4dstem / iter_time_ptyrad
speedup_factor_batch = time_py4dstem_batch[0] / time_ptyrad_batch[0]
speedup_factor_modes = time_py4dstem_modes[-1] / time_ptyrad_modes[-1]
speedup_factor_slices = time_py4dstem_slices[-1] / time_ptyrad_slices[-1]

# Shared line, marker and font sizes for all four panels
linewidth, markersize = 0.8, 3
fontsize_title, fontsize_subtitle = 9, 7
fontsize_label, fontsize_legend = 9, 5
annotate_size = 7

# 2 x 2 grid of panels
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(7, 5), dpi=600)

# Panel a: SSIM vs. reconstruction time (iteration index x average iteration time)
ax_a = axes[0, 0]
# (label, SSIM curve, avg iteration time, marker, extra plot kwargs, color, annotation offset)
curves_a = (
    ('PtyRAD', ssims_ptyrad, iter_time_ptyrad, 'o', {'zorder': 10}, 'C0', (20, 0)),
    ('PtyShv', ssims_ptyshv, iter_time_ptyshv, 's', {}, 'C1', (3, -10)),
    ('py4DSTEM', ssims_py4dstem, iter_time_py4dstem, '^', {}, 'C2', (18, -6)),
)
for label, ssim_curve, iter_time, marker, extra, _, _ in curves_a:
    elapsed = iterations[:len(ssim_curve)] * iter_time
    ax_a.plot(elapsed, ssim_curve, label=label, marker=marker,
              linewidth=linewidth, markersize=markersize, **extra)

ax_a.set_title('SSIM vs. Reconstruction Time', fontsize=fontsize_title)
ax_a.text(0.70, 0.20, '200 iterations, batch size 16 \n12 probes, 6 slices',
          transform=ax_a.transAxes, ha='center', va='top', fontsize=fontsize_subtitle, color='k')
ax_a.set_xlabel('Reconstruction Time (sec)', fontsize=fontsize_label)
ax_a.set_ylabel('SSIM', fontsize=fontsize_label)
ax_a.text(-0.2, 1.08, 'a', transform=ax_a.transAxes, fontsize=16, fontweight='bold')  # panel letter
# axes[0, 0].text(0.58, 0.725, f'{np.round(speedup_factor_iter,1)}x faster', transform=axes[0, 0].transAxes, fontsize=7, fontweight='bold', c='C0')  # Label "24x faster"

# Mark the first point of each curve with the time of a single iteration
for _, ssim_curve, iter_time, _, _, color, offset in curves_a:
    ax_a.annotate(f'{iter_time:.0f} sec', (iter_time, ssim_curve[0]), textcoords="offset points",
                  xytext=offset, ha='center', color=color, fontsize=annotate_size)

ax_a.legend(fontsize=fontsize_legend)
# axes[0, 0].set_ylim(0.2,1.2)
ax_a.set_xscale('log')
ax_a.yaxis.set_minor_locator(tck.AutoMinorLocator())

# Panel b: iteration time vs. batch size
ax_b = axes[0, 1]
# Evenly spaced x positions; the actual batch sizes are shown as tick labels
batch_ticks = np.arange(len(batch_sizes))
# (label, iteration times, marker, color, annotation offset)
series_b = (
    ('PtyRAD', time_ptyrad_batch, 'o', 'C0', (10, 5)),
    ('PtyShv', time_ptyshv_batch, 's', 'C1', (9, 4)),
    ('py4DSTEM', time_py4dstem_batch, '^', 'C2', (10, 5)),
)
for label, times, marker, _, _ in series_b:
    ax_b.plot(batch_ticks, times, label=label, marker=marker, linewidth=linewidth, markersize=markersize)

ax_b.set_title('Iteration Time vs. Batch Sizes', fontsize=fontsize_title)
ax_b.text(0.5, 0.95, '6 probes, 6 slices', transform=ax_b.transAxes,
          ha='center', va='top', fontsize=fontsize_subtitle, color='k')
ax_b.set_xlabel('Batch Sizes', fontsize=fontsize_label)
ax_b.set_ylabel('Iteration Time (sec)', fontsize=fontsize_label)
ax_b.set_xticks(batch_ticks)
ax_b.set_xticklabels([str(int(size)) for size in batch_sizes], fontsize=fontsize_label)
ax_b.text(-0.1, 1.08, 'b', transform=ax_b.transAxes, fontsize=16, fontweight='bold')  # panel letter
# axes[0, 1].text(0.01, 0.08, f'{np.round(speedup_factor_batch,1)}x faster', transform=axes[0, 1].transAxes, fontsize=7, fontweight='bold', c='C0')  # Label "7.7x faster"

# Annotate the first (smallest batch size) point of each curve
for _, times, _, color, offset in series_b:
    ax_b.annotate(f'{times[0]:.0f} sec', (0, times[0]), textcoords="offset points",
                  xytext=offset, ha='center', color=color, fontsize=annotate_size)

ax_b.legend(fontsize=fontsize_legend)
ax_b.set_ylim(-10, 165)
ax_b.yaxis.set_minor_locator(tck.AutoMinorLocator())

# Panel c: iteration time vs. number of probe modes
ax_c = axes[1, 0]
# (label, iteration times, marker, color)
series_c = (
    ('PtyRAD', time_ptyrad_modes, 'o', 'C0'),
    ('PtyShv', time_ptyshv_modes, 's', 'C1'),
    ('py4DSTEM', time_py4dstem_modes, '^', 'C2'),
)
for label, times, marker, _ in series_c:
    ax_c.plot(probe_modes, times, label=label, marker=marker, linewidth=linewidth, markersize=markersize)

ax_c.set_title('Iteration Time vs. Probe Modes', fontsize=fontsize_title)
ax_c.text(0.5, 0.95, 'Batch size 16, 6 slices', transform=ax_c.transAxes,
          ha='center', va='top', fontsize=fontsize_subtitle, color='k')
ax_c.set_xlabel('Number of Probe Modes', fontsize=fontsize_label)
ax_c.set_ylabel('Iteration Time (sec)', fontsize=fontsize_label)
ax_c.set_xticks(probe_modes)
ax_c.set_xticklabels([str(int(mode)) for mode in probe_modes], fontsize=fontsize_label)
ax_c.text(-0.2, 1.08, 'c', transform=ax_c.transAxes, fontsize=16, fontweight='bold')  # panel letter
# axes[1, 0].text(0.75, 0.15, f'{np.round(speedup_factor_modes,1)}x faster', transform=axes[1, 0].transAxes, fontsize=7, fontweight='bold', c='C0')  # Label "10.7x faster"

# Annotate the last two points of every curve; the text offsets are listed
# per point index in the same package order as series_c.
annotation_offsets_c = (
    (-2, ((15, -7), (14, -6), (14, -9))),
    (-1, ((-5, 5), (-5, 4), (-5, 5))),
)
for point, offsets in annotation_offsets_c:
    for (_, times, _, color), offset in zip(series_c, offsets):
        ax_c.annotate(f'{times[point]:.0f} sec', (probe_modes[point], times[point]), textcoords="offset points",
                      xytext=offset, ha='center', color=color, fontsize=annotate_size)

ax_c.legend(fontsize=fontsize_legend)
ax_c.set_ylim(-15, 325)
ax_c.yaxis.set_minor_locator(tck.AutoMinorLocator())

# Panel d: iteration time vs. number of slices
ax_d = axes[1, 1]
# (label, iteration times, marker, color)
series_d = (
    ('PtyRAD', time_ptyrad_slices, 'o', 'C0'),
    ('PtyShv', time_ptyshv_slices, 's', 'C1'),
    ('py4DSTEM', time_py4dstem_slices, '^', 'C2'),
)
for label, times, marker, _ in series_d:
    ax_d.plot(num_slices, times, label=label, marker=marker, linewidth=linewidth, markersize=markersize)

ax_d.set_title('Iteration Time vs. Slices', fontsize=fontsize_title)
ax_d.text(0.5, 0.95, 'Batch size 16, 6 probes', transform=ax_d.transAxes,
          ha='center', va='top', fontsize=fontsize_subtitle, color='k')
ax_d.set_xlabel('Number of Slices', fontsize=fontsize_label)
ax_d.set_ylabel('Iteration Time (sec)', fontsize=fontsize_label)
ax_d.set_xticks(num_slices)
ax_d.set_xticklabels(num_slices, fontsize=fontsize_label)
ax_d.text(-0.1, 1.08, 'd', transform=ax_d.transAxes, fontsize=16, fontweight='bold')  # panel letter

# Annotate the last two points of every curve; the text offsets are listed
# per point index in the same package order as series_d.
annotation_offsets_d = (
    (-2, ((15, -7), (14, -6), (14, -9))),
    (-1, ((-5, 5), (-5, 4), (-5, 5))),
)
for point, offsets in annotation_offsets_d:
    for (_, times, _, color), offset in zip(series_d, offsets):
        ax_d.annotate(f'{times[point]:.0f} sec', (num_slices[point], times[point]), textcoords="offset points",
                      xytext=offset, ha='center', color=color, fontsize=annotate_size)

ax_d.legend(fontsize=fontsize_legend)
ax_d.set_ylim(-10, 165)
ax_d.yaxis.set_minor_locator(tck.AutoMinorLocator())

# Tighten the layout, then write the figure as both PDF and PNG before showing it
plt.tight_layout()
for out_name in ("Fig_03_iter_time_comparison.pdf", "Fig_03_iter_time_comparison.png"):
    plt.savefig(out_name, bbox_inches="tight")

plt.show()