In [None]:
%matplotlib widget

import numpy as np
import cupy as cp
import cupyx.scipy.fft as cpfft
import scipy.fft as spfft
import matplotlib.pyplot as plt
import matplotlib.colors as colors
from matplotlib_scalebar.scalebar import ScaleBar
import time
from pathlib import Path

from fastssb import data4d as d4
from fastssb.data4d import MetadataEMPAD
from fastssb.optics import wavelength, Q_freq_to_d_spacing, d_spacing_to_Q_freq, repr_d_spacings, get_qx_qy_1D, disk_overlap_function, single_sideband_reconstruction
from fastssb.plotting import fft_doverlap_figure, cx_to_hsv_img, manual_aberration_ui

import tifffile

# Notebook-wide plotting defaults.
plt.rcParams.update({'figure.dpi': 100})
# Figure size (passed as `figsize`) used for the side-by-side 1x2 panels in this notebook.
size12 = (12.8, 4.8)

In [None]:
# Acquisition parameters of the scan -- edit these to match your data set.
meta = MetadataEMPAD()
scan_name = 'my_folder_name'   # folder name of this scan inside `my_dir`
npos = 128                     # number of probe positions along each scan axis
meta.nominal_mag = '28.5Mx'
estimated_FOV = 50.18          # scan field of view (presumably in Å, matching the Å printouts further down)
meta.E_ev = 200e3              # beam energy in eV
meta.alpha_rad = 10.5e-3       # convergence semi-angle in rad

# Quantities derived from the parameters above.
estimated_dr = estimated_FOV / npos              # probe step size
meta.wavelength = wavelength(meta.E_ev)
meta.dr = np.array([estimated_dr] * 2)           # same step size on both scan axes
meta.rotation_deg = 0

# Define the folder that contains the EMPAD data sets (which are also themselves folders)
my_dir = Path('my_path')

data_dir = my_dir / scan_name
fname4d = data_dir / f'scan_x{npos}_y{npos}.raw'

In [None]:
# Load EMPAD data: a flat little-endian float32 file holding one 130-row x 128-column
# frame per probe position.
x_cbed, y_cbed = 128, 130
data = (
    np.fromfile(fname4d, dtype='<f4')
    .reshape((npos, npos, y_cbed, x_cbed))
    .transpose((1, 0, 2, 3))   # swap the two scan axes
)
# Keep detector pixels 2..125 on both detector axes (a 124 x 124 window per frame).
data = data[:, :, 2:126, 2:126]

# Diffraction pattern summed over all probe positions, and total counts per probe position.
sumcbed = data.sum(axis=(0, 1))
cbed_totals = data.sum(axis=(2, 3))

In [None]:
# Overview image: total detector counts at every probe position.
fig, ax = plt.subplots()
im = ax.imshow(cbed_totals)
cb = fig.colorbar(mappable=im, ax=ax)

In [None]:
# Sanity check for EMPAD data: show the first two frames side by side,
# because the first frame often has too much dose somehow.
fig, axs = plt.subplots(1, 2, figsize=size12)
for k in range(axs.size):
    im = axs[k].imshow(data[0, k])
    cb = fig.colorbar(im, ax=axs[k])

In [None]:
# Locate the BF (presumably bright-field) disk in the position-summed diffraction pattern.
# `radius` is later used to calibrate the detector sampling and, together with `center`,
# to shift and crop the frames.
# NOTE(review): the meaning of threshold=0.5 (e.g. fraction of the maximum intensity) is
# defined in fastssb.data4d -- confirm there.
radius, center = d4.locate_BF_disk(sumcbed, threshold=0.5)

In [None]:
# dc is a cupy array, stored on the GPU
# NOTE(review): judging by its name, shift_and_crop recentres every frame on `center` and
# crops it around the disk of size `radius`; the exact cropping rule and the meaning of
# `frame_radius` are defined in fastssb.data4d -- confirm there.
dc, frame_radius = d4.shift_and_crop(data, radius, center, use_gpu=True)

In [None]:
# Check the cropped data: position-summed pattern (left) and total counts per probe position (right).
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=size12)

# .get() copies the reduced GPU array back to host memory for plotting.
im1 = ax1.imshow(dc.sum(axis=(0, 1)).get())
cb1 = fig.colorbar(im1, ax=ax1)

bf_tot_image = dc.sum(axis=(2, 3)).get()
im2 = ax2.imshow(bf_tot_image)
# Clip the colour range to the 1st..99th percentile so that outliers do not dominate the contrast.
im2.set_clim(*np.percentile(bf_tot_image, [1, 99]))
cb2 = fig.colorbar(im2, ax=ax2)

plt.show()

In [None]:
# Calculate FFT of 4D data w.r.t. probe position
#
# Outputs of this block:
#   G    - cupy array, 2D FFT of `dc` over the scan axes (0, 1), scaled by 1/sqrt(n_scan_positions)
#   Gabs - numpy array, |G| summed over the detector axes (2, 3); computed BEFORE the scaling of G
#   sh   - shape of Gabs as a numpy array
#
# NOTE: `dc` is deleted in both branches to free memory, so re-running this cell requires
# re-running the shift_and_crop cell first.

# True: run the FFT with cupy on the GPU. False: run scipy.fft on the CPU and upload the result.
do_fft_on_gpu = False

t1 = time.perf_counter()

if do_fft_on_gpu:
    dc = dc.astype(cp.float32, copy=False)  # I don't think the copy=False here will actually avoid copying the array
    G = cpfft.fft2(dc, axes=(0, 1), overwrite_x=True)
    del dc
    # Reduce on the GPU, then copy only the small 2D result to the host.
    Gabs = cp.sum(cp.abs(G), (2, 3)).get()
else:
    # Download the data, transform on the CPU using all cores (workers=-1), then upload G.
    dc_cpu = cp.asnumpy(dc).astype(np.float32, copy=False)
    del dc
    G_cpu = spfft.fft2(dc_cpu, axes=(0, 1), overwrite_x=True, workers=-1)
    del dc_cpu
    G = cp.asarray(G_cpu)
    Gabs = np.sum(np.abs(G_cpu), (2, 3))
# Scale by 1/sqrt(number of scan positions); Gabs above is not affected by this scaling.
G /= cp.sqrt(np.prod(G.shape[:2]))
sh = np.array(Gabs.shape)

t2 = time.perf_counter()
print(f"FFT along scan coordinate took {t2-t1:.3f}s")

# Mask out some parts of the FFT which are prone to artifacts or noisy background.
# The sector mask is built around sh // 2 (arguments 5 and (0, 360)), fftshift-ed so that it
# lines up with the unshifted Gabs, and inverted.
mask = ~np.array(np.fft.fftshift(d4.sector_mask(sh, sh // 2, 5, (0, 360))))
# Optionally mask individual columns of the FFT as well:
# mask[:,-1] = 0
# mask[:,0] = 0
# mask[:,1] = 0

# Mask covering about one half of the frequency plane (built in the centred layout, then
# un-shifted) -- presumably so that each +/- frequency pair is selected only once.
m_center = sh // 2
mask_half = np.zeros(sh)
# Rows up to and including the centre row: columns up to and including the centre column.
mask_half[:m_center[0] + 1, :m_center[1] + 1] = 1
# Rows below the centre row: columns left of the centre column, skipping column 0 for even widths.
mask_half[m_center[0] + 1:, (1 if sh[1] % 2 == 0 else 0):m_center[1]] = 1
mask_half = np.fft.ifftshift(mask_half)

# log10 magnitude for display; masked-out pixels are set to the mean so they do not stand out.
gg = np.log10(Gabs)
gg[~mask] = gg.mean()
fft_show = np.fft.fftshift(gg)

fig, ax = plt.subplots(1, 2, figsize=size12)
ax[0].imshow(np.fft.fftshift(mask))
ax[0].set_title('FFT mask')
ax[1].imshow(fft_show, norm=colors.LogNorm(vmin=np.percentile(fft_show, 10)))
ax[1].set_title('Masked absolute values of G')
plt.show()

In [None]:
# Summary of the real-space / reciprocal-space sampling of the scan and of the cropped CBED.
ny, nx, nky, nkx = G.shape

# Spatial spacings and frequencies, assuming everything is square
# scanned positions
rp_min = meta.dr[0]      # probe step size
rp_max = rp_min * ny     # this is just FOV
qp_max = 1/rp_min/2      # Nyquist limit
qp_min = 1/rp_max        # lowest non-zero scan frequency (one period over the FOV)
# cropped/binned CBED
# One detector pixel in spatial frequency: alpha/lambda spread over `radius`
# (assumes `radius` is the BF-disk radius in detector pixels -- TODO confirm).
kf_min = meta.alpha_rad/meta.wavelength / radius
rf_max = 1/kf_min        # largest "spatial period" in cropped CBED
rf_min = rf_max/nky      # real space step size corresponding to how much the CBED was cropped
kf_max = 1/rf_min/2      # Nyquist limit
kf_ssb = meta.alpha_rad/meta.wavelength * 2   # twice alpha/lambda
rf_ssb = 1/kf_ssb

print('Spatial frequencies, probe positions:')
print(f'Rp: {rp_min:.3f} .. {rp_max:.3f} Å')
# Fixed: the lower bound of the Qp range is qp_min (previously rp_min was printed here).
print(f'Qp: {qp_max:.3f} .. {qp_min:.3f} Å^-1')
print('Spatial frequencies, cropped/binned CBED:')
print(f'Rf: ({rf_ssb:.3f}) .. {rf_min:.3f} .. {rf_max:.3f} Å')
print(f'Kf: ({kf_ssb:.3f}) .. {kf_max:.3f} .. {kf_min:.3f} Å^-1')

# Pick the strongest object spatial frequencies from the scan FFT and show their disk-overlap
# (double-overlap) regions. Relies on rf_max, nkx, nky, nx, ny from the lines above in this cell
# and on G, Gabs, mask, mask_half from the FFT cell.

# Real-space step sizes as (x, y) pairs: detector-derived (rf) and probe scan (rp).
rf_min_arr = np.array([rf_max/nkx, rf_max/nky])
rp_min_arr = np.array(meta.dr)  # np.array makes a copy, so the scaling below does not touch meta.dr

# Manual calibration scale factors for the two samplings; the values 1 and 1.0 leave them unchanged.
rf_min_arr *= 1
rp_min_arr *= 1.0
print(rf_min_arr, rp_min_arr)
# 1D frequency axes for the detector (fft_shifted=True) and for the scan (fft_shifted=False).
# NOTE(review): the results are used with .get() below, so they appear to be cupy arrays -- confirm in fastssb.optics.
Kx, Ky = get_qx_qy_1D([nkx, nky], rf_min_arr, dtype=np.float32, fft_shifted=True)
Qx1d, Qy1d = get_qx_qy_1D([nx, ny], rp_min_arr, dtype=np.float32, fft_shifted=False)

n_fit=25  # number of strongest spatial frequencies to extract
best_angle = meta.rotation_deg  # rotation angle in degrees (converted to rad where it is used)
aberrations = cp.zeros((12))  # 12 aberration coefficients, all zero initially

# Masked |G| sum; pixels that are zero after masking are replaced by the mean of the masked array.
gg = Gabs * mask
gg[gg==0] = gg.mean()

# Restrict the search to the region selected by mask_half.
gg_half = gg * mask_half

# Flat indices sorted by descending strength, then converted to (row, col) indices of the scan FFT.
inds = cp.flip(cp.argsort(gg_half.ravel()))
strongest_inds = np.unravel_index(inds[:n_fit], G.shape[:2])
G_max = G[strongest_inds]  # detector-plane slices of G at the selected scan frequencies

# Scan-frequency coordinates of the selected peaks (row index -> Qy, column index -> Qx).
Qy_max1d = Qy1d[strongest_inds[0]]
Qx_max1d = Qx1d[strongest_inds[1]]

dtest, angletest = Q_freq_to_d_spacing(Qx_max1d.get(), Qy_max1d.get())
print('strongest object frequencies')
print(repr_d_spacings(dtest, angletest))
# print(gg[strongest_inds])

# Disk overlap function for the selected frequencies, using the current (zero) aberrations.
Gamma = disk_overlap_function(Qx_max1d, Qy_max1d, Kx, Ky, aberrations, best_angle * np.pi/180, meta.alpha_rad, meta.wavelength)

# Overview figure; the returned objects are discarded. GPU arrays are copied to the host with .get().
_, _ = fft_doverlap_figure(np.fft.fftshift(gg), (Qx_max1d.get(), Qy_max1d.get()), gg[strongest_inds], 
    G_max.get() * Gamma.conjugate().get(), G_max.get(), Gamma.get(), figsize=(10,6.5))

In [None]:
# Choose (by rank) which of the strongest spatial frequencies are handed to the
# interactive aberration UI: here the nine strongest.
selected_inds = [*range(9)]
# Alternative, hand-picked selection:
# order11, order12, order13, order2 = [0, 4], [1, 8], [2, 3], [10, 12, 13]
# selected_inds = order11 + [order2[0]] + order12 + [order2[1]] + order13 + [order2[2]]
G_sel, Qx_sel, Qy_sel = (arr[selected_inds] for arr in (G_max, Qx_max1d, Qy_max1d))

In [None]:
# allocate output arrays for SSB: three independent complex64 cupy arrays of shape (ny, nx)
# (combined result, left sideband, right sideband), passed to single_sideband_reconstruction.
Psi_Qp, Psi_Qp_left_sb, Psi_Qp_right_sb = (
    cp.zeros((ny, nx), dtype=np.complex64) for _ in range(3)
)

# Small constant passed to single_sideband_reconstruction.
eps = 1e-3

In [None]:
# SSB with some defocus value
#
# Runs the single-sideband reconstruction with the current `aberrations`, transforms the three
# results back to real space and shows the phase of the left sideband next to its (cropped) FFT.

aberrations[0] = 0.0
print(f'defocus: {aberrations[0]}')

start = time.perf_counter()

# The three Psi_Qp* arrays are passed in as output buffers and read back below.
single_sideband_reconstruction(
    G,
    Qx1d, Qy1d, Kx, Ky,
    aberrations,
    best_angle * np.pi/180,   # rotation angle in rad
    meta.alpha_rad,
    Psi_Qp, Psi_Qp_left_sb, Psi_Qp_right_sb,
    eps,
    meta.wavelength,
)

# Back to real space (probe positions).
Psi_Rp_left_sb = cpfft.ifft2(Psi_Qp_left_sb, norm="ortho")
Psi_Rp_right_sb = cpfft.ifft2(Psi_Qp_right_sb, norm="ortho")
Psi_Rp = cpfft.ifft2(Psi_Qp, norm="ortho")

# Copy the results to host memory.
ssb_defocal = Psi_Rp.get()
ssb_defocal_right = Psi_Rp_right_sb.get()
ssb_defocal_left = Psi_Rp_left_sb.get()

print(f"SSB took {time.perf_counter() - start}")

my_ssb_img = np.angle(ssb_defocal_left)
my_fft = np.fft.fftshift(np.fft.fft2(my_ssb_img))
# Crop 1/6 of the image height from every edge of the FFT for display. Explicit end indices keep
# the full array when f_crop is 0 (a `-0` end index would give an empty slice).
f_crop = my_fft.shape[0]//6
n_rows, n_cols = my_fft.shape
fft_show = np.abs(my_fft[f_crop:n_rows - f_crop, f_crop:n_cols - f_crop])

fig, ax = plt.subplots(1,2,figsize=size12)
# Pass the colormap by name: plt.cm.get_cmap was removed in Matplotlib 3.9.
im1 = ax[0].imshow(my_ssb_img, cmap='bone')
ax[0].set_title('Scan SSB ptychography')
ax[0].set_xticks([])
ax[0].set_yticks([])
# fig.colorbar(im1, ax=ax[0])
# meta.dr / 10 is used as the pixel size in nm.
ax[0].add_artist(ScaleBar(meta.dr[0]/10,'nm'))

ax[1].imshow(fft_show, norm=colors.LogNorm(vmin=np.percentile(fft_show, 30)))
plt.show()

In [None]:
from ipywidgets import AppLayout

# State shared with the interactive aberration UI. The arrays are passed to the UI by reference
# and printed again in the next cell.
C_gui = np.zeros(12)   # numpy array of 12 values
C = cp.zeros(12)       # cupy array of 12 aberration coefficients, used for the reconstruction afterwards
C_exp = np.array([2, 1, 3, 3, 6, 5, 5])  # what power of 10 to scale the aberration coefficients by

# Using a numpy array because this somehow allows it to be modified from the UI
dp_angle_arr = np.array([best_angle/180 * np.pi])

# If the UI elements don't fit well, try changing the size of your window or the figure DPI
# plt.rcParams['figure.dpi'] = 90

# The AppLayout must stay the last expression of the cell so that the widget is displayed.
AppLayout(
    center=manual_aberration_ui(
        G, G_sel, C, C_gui, dp_angle_arr,
        Psi_Qp, Psi_Qp_left_sb, Psi_Qp_right_sb,
        Qx1d, Qy1d, Qx_sel, Qy_sel, Kx, Ky, eps,
        aberrations, meta.alpha_rad, meta.wavelength, meta.dr[0],
        C_exp=C_exp,
    )
)

In [None]:
# Report the state left by the UI: GUI values, their power-of-ten exponents,
# the coefficient array, and the rotation angle converted from rad to degrees.
print(C_gui, C_exp, C, sep='\n')
print(f'{dp_angle_arr[0] * 180 / np.pi:.1f}')

In [None]:
# Do reconstruction with previously selected aberrations
#
# Uses the coefficients `C` and the angle `dp_angle_arr[0]` (in rad) set in the aberration UI,
# then shows the phase of the right sideband next to the magnitude of its FFT.

start = time.perf_counter()
eps = 1e-3
# The three Psi_Qp* arrays are passed in as output buffers and read back below.
single_sideband_reconstruction(
    G,
    Qx1d,
    Qy1d,
    Kx,
    Ky,
    C,
    dp_angle_arr[0],
    meta.alpha_rad,
    Psi_Qp,
    Psi_Qp_left_sb,
    Psi_Qp_right_sb,
    eps,
    meta.wavelength,
)

# Back to real space (probe positions).
Psi_Rp_left_sb = cpfft.ifft2(Psi_Qp_left_sb, norm="ortho")
Psi_Rp_right_sb = cpfft.ifft2(Psi_Qp_right_sb, norm="ortho")
Psi_Rp = cpfft.ifft2(Psi_Qp, norm="ortho")

# Copy the results to host memory.
ssb_defocal = Psi_Rp.get()
ssb_defocal_right = Psi_Rp_right_sb.get()
ssb_defocal_left = Psi_Rp_left_sb.get()

# Report the elapsed time in the same way as the first SSB cell (the timer was started but never read).
print(f"SSB took {time.perf_counter() - start}")

# Compute the phase image once and reuse it for the plot and for the FFT.
my_ssb_image = np.angle(ssb_defocal_right)
my_fft = np.fft.fftshift(np.fft.fft2(my_ssb_image))
fft_show = np.abs(my_fft)

fig, axs = plt.subplots(1,2, figsize=size12)
# Pass the colormap by name: plt.cm.get_cmap was removed in Matplotlib 3.9.
im1 = axs[0].imshow(my_ssb_image, cmap='bone')
axs[0].set_title('SSB ptychography')
axs[0].set_xticks([])
axs[0].set_yticks([])
fig.colorbar(im1, ax=axs[0])
# meta.dr / 10 is used as the pixel size in nm.
axs[0].add_artist(ScaleBar(meta.dr[0]/10,'nm'))

im2 = axs[1].imshow(fft_show, norm=colors.LogNorm(vmin=np.percentile(fft_show, 50)))

plt.tight_layout()
plt.show()

In [None]:
# Save the phase of the right-sideband reconstruction as a float32 ImageJ TIFF in the scan folder.
# meta.dr / 10 is used as the pixel size in nm, so `resolution` is given in pixels per nm.
tifffile.imwrite(
    data_dir / f'{scan_name}_ssb_ptycho_best_right.tif',
    np.angle(ssb_defocal_right).astype('float32'),
    imagej=True,
    resolution=tuple(1. / (step / 10) for step in meta.dr[:2]),
    metadata={'spacing': 1 / 10, 'unit': 'nm', 'axes': 'YX'},
)