In [1]:
%matplotlib widget
import numpy as np
import cupy as cp
import cupyx.scipy.fft as cpfft
import matplotlib.pyplot as plt
from matplotlib_scalebar.scalebar import ScaleBar
from mpl_toolkits.axes_grid1 import make_axes_locatable

import time
from argparse import ArgumentParser
import os
from pathlib import Path

from fastssb import data4d as d4
from fastssb.optics import wavelength, get_qx_qy_1D, disk_overlap_function, single_sideband_reconstruction
from fastssb import plotting

import tifffile

# Low default DPI for figures created in this notebook; plots that need more
# detail pass their own `dpi=` to plt.subplots.
plt.rcParams['figure.dpi'] = 50

In [2]:
def _show_virtual_image(image, title, pixel_size_nm):
    """Draw one virtual-detector image in its own figure.

    The figure gets a title, no tick marks, a scale bar in nm and a colorbar
    placed in a dedicated axis appended to the right of the image.
    """
    fig, ax = plt.subplots(dpi=150)
    # plt.cm.get_cmap was removed in Matplotlib 3.9; a colormap name is accepted directly.
    im = ax.imshow(image, cmap='bone')
    ax.set_title(title)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.add_artist(ScaleBar(pixel_size_nm, 'nm'))
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    plt.colorbar(im, cax=cax)
    plt.tight_layout()


def plot_virtual_images(d, metadata, radius, scan_number):
    """Plot ABF, BF and ADF virtual images of a 4D data set, one figure each.

    Parameters
    ----------
    d : object providing ``virtual_annular_image(a, b, center)`` and
        ``frame_dimensions`` (presumably inner radius, outer radius and
        detector centre in pixels -- TODO confirm against fastssb).
    metadata : object whose ``dr[0]`` is the scan pixel size; it is divided
        by 10 for the nm scale bar (the notebook prints ``dr`` in Å).
    radius : bright-field disk radius in detector pixels.
    scan_number : scan identifier, used only in the figure titles.

    The detector is split at ``radius/2`` (BF inside, ABF up to ``radius``);
    ADF runs from ``radius`` to half of the first frame dimension.
    """
    center = d.frame_dimensions/2
    abf = d.virtual_annular_image(radius/2, radius, center)
    bf = d.virtual_annular_image(0, radius/2, center)
    adf = d.virtual_annular_image(radius, d.frame_dimensions[0]/2, center)

    # Replace empty scan positions with the image mean so they do not stretch
    # the colour scale (ADF is left untouched, as before).
    bf[bf==0] = bf.mean()
    abf[abf==0] = abf.mean()

    pixel_size_nm = metadata.dr[0]/10
    for label, image in (('ABF', abf), ('BF', bf), ('ADF', adf)):
        _show_virtual_image(image, f'Scan {scan_number} {label}', pixel_size_nm)

In [None]:
# Load the sparse 4D-STEM data set and the acquisition metadata of one scan.

data_dir = '/global/cfs/projectdirs/ncemhub/distiller'
base_path = Path(data_dir)

scan_number, scan_id = 57, 3323
date = Path('2022.08.30')
sparse_path = base_path / 'counted' / date

adfpath = base_path / 'dm4' / date
results_path = base_path / 'results'

# exist_ok avoids the check-then-create race and makes the cell re-runnable.
# parents is deliberately left off so a wrong base_path still fails here.
results_path.mkdir(exist_ok=True)

filename4d = sparse_path / f'data_scan{scan_number}_id{scan_id}_electrons.h5'
# Older files used other naming schemes:
#   data_scan{scan_number}_th{threshold}_electrons.h5
#   data_scan{scan_number}_electrons.h5
filenameadf = adfpath / f'scan{scan_number}.dm4'

# Multiple of the bright-field disk radius kept when cropping, and used for
# metadata.k_max further down (1.2 has also been used).
alpha_max_factor = 1.05

print('1: data loading')
d = d4.Sparse4DData.from_4Dcamera_file(filename4d) # load the entire data set
# d = d4.Sparse4DData.from_4Dcamera_file(filename4d, (0, 0), 512)  # load a 512 x 512 region of the data set
metadata = d4.Metadata4D.from_dm4_file(filenameadf)

# Manual overrides of values that are not read from the dm4 file.
metadata.alpha_rad = 17.1e-3 # also have used 30.0e-3
metadata.rotation_deg = -93.1 # TEAM 0.5 at 80 kV with -30.0 deg scan rotation in TIA
metadata.wavelength = wavelength(metadata.E_ev)

In [None]:
# Find the bright-field disk, crop the detector around it and show the summed
# diffraction pattern on linear and log scale.
# Alternative (automatic) call:
# center, radius = d.determine_center_and_radius(manual=False, size=200) 
center, radius = d.determine_center_and_radius(manual=True, size=70)
print(f'center: {center}')
print(f'radius: {radius}')
print('2: cropping')
# NOTE(review): the return value is discarded and `d` is used afterwards, so this
# appears to crop `d` in place -- re-running this cell would then crop again. Confirm.
d.crop_symmetric_center_(center, radius*alpha_max_factor)
print('3: sum diffraction pattern')
s = d.sum_diffraction() # stempy
print('4: plotting')

f,ax = plt.subplots(1,2,figsize=(15,8))
imax = ax[0].imshow(s)
ax[0].set_title(f'Scan {scan_number} sum after cropping')
# +1 keeps log10 finite for pixels with zero counts
imax = ax[1].imshow(np.log10(s+1))
ax[1].set_title(f'Scan {scan_number} log10(sum) after cropping')
# `imax` now refers to the log-scale image, so the colorbar belongs to the right panel only
plt.colorbar(imax)
plt.tight_layout()

# Disabled: flip to True to also plot the ABF/BF/ADF virtual images
if False:
    plot_virtual_images(d, metadata, radius, scan_number)

In [None]:
# Print a summary of the acquisition: beam energy, wavelength, sampling and dose.
# presumably seconds per probe position for an 87 kHz frame rate -- TODO confirm
dwell_time = 1/87e3
# Scaling between detected and real fluence; 1 means no correction is applied.
detector_to_real_fluence_80kv = 1 

fluence = d.fluence(metadata.dr[0]) * detector_to_real_fluence_80kv
flux = d.flux(metadata.dr[0], dwell_time) * detector_to_real_fluence_80kv

print(f"E               = {metadata.E_ev/1e3}             keV")
# wavelength * 1e2 is labelled pm, i.e. the wavelength is assumed to be stored in Å
print(f"λ               = {metadata.wavelength * 1e2:2.2}   pm")
print(f"dR              = {metadata.dr} Å")  # real space pixel size
print(f"scan       size = {d.scan_dimensions}")
print(f"detector   size = {d.frame_dimensions}")
print(f"scan       FOV  = {d.scan_dimensions*metadata.dr/10} nm")
print(f"fluence         ~ {fluence} e/Å^2")
print(f"flux            ~ {flux} e/Å^2/s")

In [None]:
# Prepare the SSB input: bin the sparse data into a dense 4D cube, move it to
# the GPU and Fourier transform it along the two scan axes.
dssb = d
metadata.k_max = metadata.alpha_rad * alpha_max_factor / metadata.wavelength
s = dssb.sum_diffraction()

# Optional sanity plot of the summed pattern that goes into SSB.
show_cropped_sum = False
if show_cropped_sum:
    f,ax = plt.subplots(figsize=(4,4))
    imax = ax.imshow(s)
    ax.set_title('Sum after cropping for SSB')
    plt.colorbar(imax)

slic = np.s_[:,:]  # full scan; narrow this slice to reconstruct a sub-region
data = dssb.slice(slic)

ssb_size = np.array([25,25])  # how many pixels in cropped diffraction pattern
# Largest integer binning that still leaves at least ssb_size detector pixels.
bin_factor = int(np.min(np.floor(data.frame_dimensions/ssb_size)))
radius2 = radius/bin_factor  # bright-field radius in binned pixels
meta = metadata
verbose = True

start = time.perf_counter()
dc = d4.sparse_to_dense_datacube(data.indices, data.counts, data.scan_dimensions, data.frame_dimensions, data.frame_dimensions/2, data.frame_dimensions[0]/2, data.frame_dimensions[0]/2, binning=bin_factor, fftshift=False)
print(f"Bin by {bin_factor} for ssb took {time.perf_counter() - start:.3f}s")

# Maximum collection angle of the binned detector, scaled from the known
# convergence angle via the bright-field radius.
rmax = dc.shape[-1] // 2
alpha_max = rmax / radius2 * meta.alpha_rad

r_min = meta.wavelength / (2 * alpha_max)
r_min = [r_min, r_min]
k_max = [alpha_max / meta.wavelength, alpha_max / meta.wavelength]
r_min1 = np.array(r_min)
dxy1 = np.array(meta.dr)

# Convert while copying, so only one float32 copy of the cube lives on the GPU.
M = cp.array(dc, dtype=cp.float32)
xp = cp.get_array_module(M)
ny, nx, nky, nkx = M.shape

Qx1d, Qy1d = get_qx_qy_1D([nx, ny], meta.dr, M.dtype, fft_shifted=False)

# CuPy launches kernels asynchronously: synchronise before and after the timed
# region, otherwise the clock only measures the launch overhead.
cp.cuda.Stream.null.synchronize()
start = time.perf_counter()
# overwrite_x lets the FFT reuse M's memory; M must not be read afterwards.
G = cpfft.fft2(M, axes=(0, 1), overwrite_x=True)
G /= cp.sqrt(np.prod(G.shape[:2]))  # orthonormal scaling over the scan axes
cp.cuda.Stream.null.synchronize()
print(f"FFT along scan coordinate took {time.perf_counter() - start:.3f}s")

In [None]:
# Build a mask that removes the low frequencies and three edge columns from
# the power spectrum before the strongest object frequencies are picked.
manual_frequencies = None  # [[20, 62, 490], [454, 12, 57]]

Gabs = xp.log10(xp.sum(xp.abs(G), (2, 3)))
sh = np.array(Gabs.shape)
# presumably a filled disk of radius 5 around the array centre; fftshift moves it
# to the zero-frequency corner of the unshifted FFT and ~ keeps everything else
mask = ~np.array(np.fft.fftshift(d4.sector_mask(sh, sh // 2, 5, (0, 360))))
# also drop the first two and the last scan-frequency column
mask[:, [0, 1, -1]] = False

# might want to save mask as numpy array

gg = Gabs.get()
gg[~mask] = gg.mean()  # flatten masked pixels so they do not dominate the display


show_mask = True
if show_mask:
    fig, ax = plt.subplots(1,2,figsize=(15,8))
    ax[0].imshow(np.fft.fftshift(mask))
    ax[0].set_title('FFT mask')
    # plt.cm.get_cmap was removed in Matplotlib 3.9; pass the colormap name instead
    ax[1].imshow(np.fft.fftshift(gg), cmap='inferno')
    ax[1].set_title('Masked absolute values of G')

In [None]:
# show double overlap for strongest object frequencies

Gabs = xp.sum(xp.abs(G), (2, 3))
sh = np.array(Gabs.shape)

n_fit=25  # number of strongest scan frequencies to inspect
# meta.rotation_deg = 0.0 # +/-99.0?? 0.0???
best_angle = meta.rotation_deg
aberrations = xp.zeros((12))

# Host copy of the power spectrum with masked pixels set to the mean, so they
# cannot rank among the strongest entries.
gg = Gabs.get() * mask
gg[gg==0] = gg.mean()

# gg is a numpy array, so sort with numpy; calling the CuPy functions on a host
# array only worked through duck typing.
inds = np.argsort(gg.ravel())[::-1]
strongest_inds = np.unravel_index(inds[:n_fit], G.shape[:2])
G_max = G[strongest_inds]

r_min1 = np.array(r_min)
dxy1 = np.array(meta.dr)

# Manual scale knobs for the detector / scan sampling (currently no-ops).
r_min1 *= 1
dxy1 *= 1.0
Kx, Ky = get_qx_qy_1D([nkx, nky], r_min1, G[0, 0, 0, 0].real.dtype, fft_shifted=True)
Qx1d, Qy1d = get_qx_qy_1D([nx, ny], dxy1, G[0, 0, 0, 0].real.dtype, fft_shifted=False)
print(Qx1d.max(), Qy1d.max())
print('strongest object frequencies')
# first index of G is the scan y axis, second the scan x axis (ny, nx = G.shape[:2])
Qy_max1d = Qy1d[strongest_inds[0]]
Qx_max1d = Qx1d[strongest_inds[1]]


def Q_freq_to_d_spacing(qx, qy):
    """Express spatial frequencies as (d-spacing, in-plane angle in degrees).

    The d-spacing is the reciprocal of the frequency magnitude, in the inverse
    of whatever unit qx/qy use; the angle is atan2(qy, qx).
    """
    magnitude = np.sqrt(qx**2 + qy**2)
    angle_deg = np.arctan2(qy, qx) * 180 / np.pi
    return 1/magnitude, angle_deg


def d_spacing_to_Q_freq(ds, angles):
    """Inverse of Q_freq_to_d_spacing: (d-spacing, angle in degrees) -> (qx, qy)."""
    theta = angles * np.pi / 180
    magnitude = 1/ds
    return magnitude * np.cos(theta), magnitude * np.sin(theta)


def repr_d_spacings(ds, angles):
    """Format paired d-spacings and angles as '(d Å, angle°) ' entries on one line.

    Every entry, including the last, is followed by a single space.
    """
    return ''.join(
        f'({spacing:.3f} Å, {angle_deg:.1f}°) '
        for spacing, angle_deg in zip(ds, angles)
    )

# Report the strongest frequencies as d-spacings, then compare the measured
# G(K, Q) with the disk overlap function for each of them.
dtest, angletest = Q_freq_to_d_spacing(Qx_max1d.get(), Qy_max1d.get())
print(repr_d_spacings(dtest, angletest))

Gamma = disk_overlap_function(Qx_max1d, Qy_max1d, Kx, Ky, aberrations, best_angle * np.pi/180, meta.alpha_rad, meta.wavelength)
plt.ion()

fig, ax = plt.subplots(2,3,figsize=(18,12))
ax = ax.flatten()

# Colormaps are passed by name: plt.cm.get_cmap was removed in Matplotlib 3.9.
ax[0].imshow(np.log10(np.fft.fftshift(gg)+1), cmap='bone')
ax[0].set_title(f'Scan {scan_number} fft')
ax[0].set_xticks([])
ax[0].set_yticks([])

# Positions of the selected scan frequencies; y is inverted to match image orientation.
ax[1].plot(Qx_max1d.get(), Qy_max1d.get(), '.')
ax[1].axis('equal')
ax[1].invert_yaxis()

# Strength of each selected frequency relative to the strongest one.
ax[2].plot(np.arange(n_fit), gg[strongest_inds]/gg[strongest_inds][0], '.-')
ax[2].set_title('strength of object frequencies')

ax[3].imshow(plotting.imsave(plotting.mosaic(G_max.get() * Gamma.conjugate().get())), cmap='bone')
ax[3].set_title(f'Scan {scan_number} fft * overlap')
ax[3].set_xticks([])
ax[3].set_yticks([])

ax[4].imshow(plotting.imsave(plotting.mosaic(G_max.get())), cmap='bone')
ax[4].set_title(f'Scan {scan_number} data fft')
ax[4].set_xticks([])
ax[4].set_yticks([])

ax[5].imshow(plotting.imsave(plotting.mosaic(Gamma.get())), cmap='bone')
ax[5].set_title('overlap function')
ax[5].set_xticks([])
ax[5].set_yticks([])

# Nine frequencies shown in the interactive aberration tuner further down
# (use list(range(9)) for simply the nine strongest).
selected_inds = [0, 1, 2, 8, 9, 10, 16, 17, 18]
G_sel = G_max[selected_inds]
Qx_sel = Qx_max1d[selected_inds]
Qy_sel = Qy_max1d[selected_inds]

plt.tight_layout()
# fig.savefig(results_path /f'scan{1}_fft.png')

In [None]:
# First SSB reconstruction with the starting aberrations (all zero unless edited here).
aberrations[0] = 0.0
print(f'defocus: {aberrations[0]}')

# Output buffers in scan-frequency space, filled by single_sideband_reconstruction:
# combined result and the two individual side bands.
Psi_Qp = cp.zeros((ny, nx), dtype=np.complex64)
Psi_Qp_left_sb = cp.zeros((ny, nx), dtype=np.complex64)
Psi_Qp_right_sb = cp.zeros((ny, nx), dtype=np.complex64)

start = time.perf_counter()
eps = 1e-3

single_sideband_reconstruction( # want this function
    G,
    Qx1d,
    Qy1d,
    Kx,
    Ky,
    aberrations,
    best_angle * np.pi/180,
    meta.alpha_rad,
    Psi_Qp,
    Psi_Qp_left_sb,
    Psi_Qp_right_sb,
    eps,
    meta.wavelength,
)

# Back to real space. These arrays are created here (no preallocation needed)
# and are updated in place by the interactive cell below.
Psi_Rp_left_sb = cpfft.ifft2(Psi_Qp_left_sb, norm="ortho")
Psi_Rp_right_sb = cpfft.ifft2(Psi_Qp_right_sb, norm="ortho")
Psi_Rp = cpfft.ifft2(Psi_Qp, norm="ortho")

# .get() copies to the host, so the GPU work is finished before the timer is read.
ssb_defocal = Psi_Rp.get()
ssb_defocal_right = Psi_Rp_right_sb.get()
ssb_defocal_left = Psi_Rp_left_sb.get()

print(f"SSB took {time.perf_counter() - start}")

# Phase of the left side band and its power spectrum; the outer fifth of the
# spectrum is cropped on every side for display.
my_ssb_img = np.angle(ssb_defocal_left)
my_fft = np.fft.fftshift(np.fft.fft2(my_ssb_img))
f_crop = my_fft.shape[0]//5
fft_show = np.log(np.abs(my_fft[f_crop:-f_crop, f_crop:-f_crop])+1)

fig, ax = plt.subplots(1,2,figsize=(18, 9))
# Colormap passed by name: plt.cm.get_cmap was removed in Matplotlib 3.9.
im1 = ax[0].imshow(my_ssb_img, cmap='bone')
ax[0].set_title(f'Scan {scan_number} SSB ptychography')
ax[0].set_xticks([])
ax[0].set_yticks([])
ax[0].add_artist(ScaleBar(metadata.dr[0]/10,'nm'))

ax[1].imshow(fft_show)

plt.show()

In [None]:
# Interactively correcting aberrations with auto-updating single side band reconstruction

# Close every open figure before the widget figures are created.
plt.close('all')
# NOTE(review): these imports are local to this cell; consider moving them to the
# import cell at the top. fft2/ifft2 come from the same module as `cpfft`.
from ipywidgets import AppLayout, FloatSlider, GridspecLayout, VBox, HBox
import ipywidgets as widgets
from cupyx.scipy.fft import fft2, ifft2

# aberration correction second attempt
ab_coeffs  = ['C1', 'C12', 'C21', 'C23', 'C3', 'C32', 'C34']  # aberrations according to Krivanek
ab_uhlemann= ['C1', 'A1', '3 B2*', 'A2', 'C3','4 S3*', 'A3']  # aberrations according to Uhlemann & Haider
ab_isreal  = [True, False, False, False, True, False, False]
C_max      = [20,   10,    20,    10,    20,   30,    30]      # slider range (+/-) per aberration
C_multiplier=[1e1,  1e1,   1e3,   1e3,   1e4,  1e4,   1e4]     # slider value -> coefficient scale

# ab_indices[k]: first slot of aberration k in the flat coefficient vector.
# Real aberrations take one slot, complex ones two (magnitude slider, angle slider).
ab_indices = []
ab_i = 0
for isreal in ab_isreal:
    ab_indices.append(ab_i)
    ab_i += 1 if isreal else 2
# c_indices[slot]: number of the aberration that owns the slot (inverse of ab_indices).
c_indices = [0] * ab_i
for idx, isreal in enumerate(ab_isreal):
    c_indices[ab_indices[idx]] = idx
    if not isreal:
        c_indices[ab_indices[idx]+1] = idx

C_gui = np.zeros((12,))  # raw slider values (host)
C = xp.zeros((12,))      # scaled coefficients handed to the reconstruction
# (ab_indices, c_indices)

# dp_angle = xp.zeros((1,))
dp_angle = best_angle/180 * np.pi

plt.rcParams['figure.dpi'] = 42
plt.ioff()  # figures are shown through the widget layout, not on creation

gs = GridspecLayout(8,9)
# VBox has no `width` trait; the former VBox(width=50) was ignored apart from a
# traitlets deprecation warning, so the default box is equivalent.
Cslider_box = VBox()
# scale_slider_box = VBox()
children= []
sliders =  []

# Read-out showing the most recently changed value.
text = widgets.HTML(
    value="1",
    placeholder='',
    description='',
)

overlaps_output = widgets.Output()
# recon_output = widgets.Output()

Psi_Rp[:] = ifft2(Psi_Qp, norm="ortho")
Psi_Rp_left_sb[:] = ifft2(Psi_Qp_left_sb, norm="ortho")
Psi_Rp_right_sb[:] = ifft2(Psi_Qp_right_sb, norm="ortho")

Gamma = disk_overlap_function(Qx_sel, Qy_sel, Kx, Ky, aberrations, dp_angle, meta.alpha_rad, meta.wavelength)
gg = Gamma.conjugate() * G_sel

# Image artists kept so the slider callbacks can update them in place.
overlap_figure_axes = []
overlap_figure2_axes = []
Gmax_figure_axes = []

with overlaps_output:
    # 3x3 grid: data multiplied by the conjugate overlap function.
    overlap_figure = plt.figure(constrained_layout=True,figsize=(7,7))
    gs1 = overlap_figure.add_gridspec(3, 3, wspace=0.05,hspace=0.05)
    for dd, ggs in zip(gg[:9], gs1):
        f3_ax1 = overlap_figure.add_subplot(ggs)
        imax2 = f3_ax1.imshow(plotting.imsave(dd.get()))
        f3_ax1.set_xticks([])
        f3_ax1.set_yticks([])
        overlap_figure_axes.append(imax2)

    # 3x6 grid: overlap function (even cells) next to the raw data (odd cells).
    overlap_figure2 = plt.figure(constrained_layout=True,figsize=(15,7))
    gs2 = overlap_figure2.add_gridspec(3, 6, wspace=0.05,hspace=0.05)
    for i, (dd, doverlap) in enumerate(zip(Gamma[:9], G_sel[:9])):
        f3_ax1 = overlap_figure2.add_subplot(gs2[2*i])
        imax2 = f3_ax1.imshow(plotting.imsave(dd.get()))
        f3_ax1.set_xticks([])
        f3_ax1.set_yticks([])
        overlap_figure2_axes.append(imax2)

        f3_ax1 = overlap_figure2.add_subplot(gs2[2*i+1])
        imax2 = f3_ax1.imshow(plotting.imsave(doverlap.get()))
        f3_ax1.set_xticks([])
        f3_ax1.set_yticks([])
        Gmax_figure_axes.append(imax2)


plot_box = VBox(children =[overlap_figure.canvas])
plot_box2 = VBox(children =[overlap_figure2.canvas])

recon_fig, recon_axes = plt.subplots(figsize=(8,8))
m = 5  # border pixels trimmed from the displayed reconstruction
img = np.angle(Psi_Rp_left_sb.get()[m:-m,m:-m])
recon_img = recon_axes.imshow(img, cmap=plt.get_cmap('bone'))
recon_axes.set_xticks([])
recon_axes.set_yticks([])
scalebar = ScaleBar(meta.dr[0]/10,'nm')  # scan pixel size, dr divided by 10, labelled nm
recon_axes.add_artist(scalebar)
plt.tight_layout()


def update_everything():
    """Re-run the SSB reconstruction with the current `C` and `dp_angle` and redraw.

    Reads the notebook globals (G, the frequency grids, C, dp_angle, eps, ...),
    refills the Psi_* buffers in place and refreshes the reconstruction image
    and both overlap figures.
    """
    # Clear the frequency-space buffers before the reconstruction writes into them.
    for buffer in (Psi_Qp, Psi_Qp_left_sb, Psi_Qp_right_sb):
        buffer[:] = 0
    single_sideband_reconstruction(
        G, Qx1d, Qy1d, Kx, Ky,
        C, dp_angle, meta.alpha_rad,
        Psi_Qp, Psi_Qp_left_sb, Psi_Qp_right_sb,
        eps, meta.wavelength,
    )

    # Inverse transform every buffer into its real-space counterpart (in place).
    pairs = (
        (Psi_Rp, Psi_Qp),
        (Psi_Rp_left_sb, Psi_Qp_left_sb),
        (Psi_Rp_right_sb, Psi_Qp_right_sb),
    )
    for real_space, freq_space in pairs:
        real_space[:] = ifft2(freq_space, norm="ortho")

    # Show the phase of the left side band, without a 5 pixel border.
    margin = 5
    phase = np.angle(Psi_Rp_left_sb.get()[margin:-margin, margin:-margin])
    recon_img.set_data(phase)
    recon_img.set_clim(phase.min(), phase.max())
    recon_fig.canvas.draw()
    recon_fig.canvas.flush_events()

    # Refresh the overlap panels for the selected frequencies.
    overlap = disk_overlap_function(Qx_sel, Qy_sel, Kx, Ky, C, dp_angle, meta.alpha_rad, meta.wavelength)
    filtered = overlap.conjugate() * G_sel
    for artist, panel in zip(overlap_figure_axes, filtered):
        artist.set_data(plotting.imsave(panel.get()))
    for artist, panel in zip(overlap_figure2_axes, overlap):
        artist.set_data(plotting.imsave(panel.get()))

    canvases = (overlap_figure.canvas, overlap_figure2.canvas)
    for canvas in canvases:
        canvas.draw()
    for canvas in canvases:
        canvas.flush_events()



def create_ab_function(name, i, is_real, multiplier):
    """Build the slider observer for coefficient slot `i`.

    name       -- value assigned to the callback's __name__
    i          -- slot in C_gui / C that the slider controls
    is_real    -- True for a real aberration (one slot); False when the slot
                  is the magnitude or the angle of a complex aberration
    multiplier -- scale from slider units to coefficient units

    The returned callback stores the new slider value, updates `C`, re-runs
    the reconstruction and writes a short summary into `text`.
    """
    def _on_change(change):
        C_gui[i] = change['new']
        if is_real:
            C[i] = C_gui[i] * multiplier
            summary = f'{C_gui[i]}, {C[i]}'
        else:
            # Magnitude lives in the aberration's first slot, the angle (degrees)
            # in the next one, whichever of the two sliders fired.
            first = ab_indices[c_indices[i]]
            cx_val = multiplier * C_gui[first] * np.exp(1j * np.pi / 180 * C_gui[first+1])
            C[first] = cx_val.real
            C[first+1] = cx_val.imag
            summary = f'{C_gui[i]}, {cx_val:.3f}'
        update_everything()
        text.value = summary
    _on_change.__name__ = name
    return _on_change


def create_angle_function():
    """Build the observer for the diffraction-pattern rotation slider.

    The callback converts the slider value from degrees to radians, stores it
    in the global `dp_angle`, re-runs the reconstruction and shows the new
    value (in radians) in `text`.
    """
    def _on_angle_change(change):
        global dp_angle
        degrees = change['new']
        dp_angle = degrees/180 * np.pi
        update_everything()
        text.value = f'{dp_angle}'

    _on_angle_change.__name__ = 'dp_angle_slider_changed'
    return _on_angle_change


def _add_slider(description, index, is_real, multiplier, **slider_kwargs):
    """Create a FloatSlider for coefficient slot `index`, hook it up and append it to `sliders`."""
    slider = FloatSlider(description=description, value=C_gui[index],
                         layout=widgets.Layout(width='90%'), **slider_kwargs)
    slider.observe(create_ab_function(f'ab_slider_changed_{index}', index, is_real, multiplier), names='value')
    sliders.append(slider)


# One slider per real aberration; a magnitude and an angle slider per complex one.
ab_i = 0
for ab_name, isreal, maxs, multiplier in zip(ab_uhlemann, ab_isreal, C_max, C_multiplier):
    _add_slider(ab_name, ab_i, isreal, multiplier, min=-maxs, max=maxs, readout_format='.1f')
    ab_i += 1
    if not isreal:
        _add_slider('angle', ab_i, False, multiplier, min=-180, max=180, step=1.0, readout_format='.0f')
        ab_i += 1


# Rotation between scan and diffraction pattern, in degrees.
sdp = FloatSlider(description='dp angle', value=best_angle, min=-180, max=180, readout_format='.1f',
                  layout=widgets.Layout(width='90%'))
sdp.observe(create_angle_function(), names='value')
sliders.append(sdp)


Cslider_box.children = sliders + [text]

# Layout: sliders on the left, overlaps top-centre, reconstruction top-right,
# overlap/data pairs along the bottom.
gs[:,0:3] = Cslider_box
gs[:4,3:6] = plot_box
gs[4:8,3:] = plot_box2
gs[:4,6:9] = recon_fig.canvas

AppLayout(center=gs)

In [None]:
# Print the aberration state chosen with the sliders so it can be recorded:
# raw slider values, scaled coefficients, and the rotation angle in degrees.
print(C_gui)
print(C)
print(f'{dp_angle * 180 / np.pi:.1f}')

In [None]:
# Do reconstruction with previously selected aberrations

# Fresh frequency-space output buffers: combined result and both side bands.
Psi_Qp = cp.zeros((ny, nx), dtype=np.complex64)
Psi_Qp_left_sb = cp.zeros((ny, nx), dtype=np.complex64)
Psi_Qp_right_sb = cp.zeros((ny, nx), dtype=np.complex64)

start = time.perf_counter()
eps = 1e-3
single_sideband_reconstruction( # want this function
    G,
    Qx1d,
    Qy1d,
    Kx,
    Ky,
    C,
    dp_angle,
    meta.alpha_rad,
    Psi_Qp,
    Psi_Qp_left_sb,
    Psi_Qp_right_sb,
    eps,
    meta.wavelength,
)

# The real-space arrays are created by the inverse FFT; no preallocation is needed.
Psi_Rp_left_sb = cpfft.ifft2(Psi_Qp_left_sb, norm="ortho")
Psi_Rp_right_sb = cpfft.ifft2(Psi_Qp_right_sb, norm="ortho")
Psi_Rp = cpfft.ifft2(Psi_Qp, norm="ortho")

ssb_defocal = Psi_Rp.get()
ssb_defocal_right = Psi_Rp_right_sb.get()
ssb_defocal_left = Psi_Rp_left_sb.get()

fig, ax = plt.subplots(dpi=220) # 300
# Colormap passed by name: plt.cm.get_cmap was removed in Matplotlib 3.9.
im1 = ax.imshow(np.angle(ssb_defocal_left), cmap='bone')
ax.set_title(f'Scan {scan_number} SSB ptychography')
ax.set_xticks([])
ax.set_yticks([])
fig.colorbar(im1, ax=ax)
ax.add_artist(ScaleBar(metadata.dr[0]/10,'nm'))

plt.tight_layout()
plt.show()

In [None]:
# Power spectrum of the right side-band phase image.
# NOTE(review): the previous cell displays the left side band, while this cell and
# the TIFF export use the right one -- confirm which is intended.
my_ssb_image = np.angle(ssb_defocal_right)
my_fft = np.fft.fftshift(np.fft.fft2(my_ssb_image))

fig, ax = plt.subplots(dpi=100)
# +1 keeps the logarithm finite where the spectrum is zero
im1 = ax.imshow(np.log(np.abs(my_fft)+1))
plt.show()

In [28]:
# Save the right side-band phase as a float32 ImageJ TIFF in results_path.
# resolution is given in pixels per nm (dr divided by 10), matching the 'nm' unit.
tifffile.imwrite(results_path /f'scan{scan_number}_ssb_ptycho_best_right.tif',np.angle(ssb_defocal_right).astype('float32'), 
        imagej=True, resolution=(1./(metadata.dr[0]/10), 1./(metadata.dr[1]/10)), 
        metadata={'spacing': 1 / 10, 'unit': 'nm', 'axes': 'YX'})