In [None]:
import matplotlib.pyplot as plt
# Global figure font for every plot in this notebook. If Times New Roman is not
# installed, matplotlib warns and falls back to its default font.
plt.rcParams['font.family'] = 'Times New Roman'

# NOTE(review): multiprocessing, typing and tifffile are not referenced in the
# cells shown here — confirm they are needed before removing.
import multiprocessing
import typing
import numpy as np
import tifffile

# NOTE(review): star import. Data4D, WDD and SVD_AC are used below but not defined
# in this notebook, so they presumably come from this module. Importing them
# explicitly would make that visible. Because this runs after the numpy/pyplot
# imports, any same-named export (e.g. `np`, `plt`) could shadow those names.
from STEM4D_modified import *
# NOTE(review): `ps` is not referenced in the cells shown here.
import pixstem.api as ps

In [None]:
# Tab-separated parameter file; the same path is handed to Data4D in the next cell.
parfile = "parameters_wdd.txt"
params = np.genfromtxt(fname=parfile, dtype=str, delimiter="\t")  # every field kept as a string

In [None]:
# Construct the 4D-STEM dataset object from the parameter file path defined above.
data_4D = Data4D(parfile)
# Preprocessing. The return values are discarded, so both calls presumably update
# data_4D in place. Their names suggest they centre the ronchigrams and then estimate
# the aperture size — TODO confirm against STEM4D_modified. Order is kept as-is.
data_4D.center_ronchigrams()
data_4D.estimate_aperture_size()

In [None]:
# Switch to the interactive ipympl ("widget") matplotlib backend for the plots below;
# requires the ipympl package in the kernel environment.
%matplotlib widget

In [None]:
# Visualise the 4D dataset (plotting helper from STEM4D_modified).
data_4D.plot_4D()

In [None]:
# Visual check of the aperture — presumably the one found by estimate_aperture_size() above.
data_4D.plot_aperture()

In [None]:
# Leave only the BF disks: crop each ronchigram down to the BF disk.
# NOTE(review): expansion_ratio receives a bool although its name suggests a
# numeric ratio — confirm the intended argument in STEM4D_modified.
data_4D.truncate_ronchigram(expansion_ratio=True) # crops ronchigram to the BF disk

# Transforming the 4D-STEM dataset into the G-set
![Diagram: transformation of 4D-STEM data into the G-set](image/G_set.jpg "Transformation of 4D-STEM data into the G-set")

In [None]:
# I(u, R) -> G(u, U) Fourier transform: the detector coordinate u is kept and the
# probe position R is transformed to its conjugate variable U. The result is the
# G-set presumably consumed by the trotter, WDD and SVD steps below.
data_4D.apply_FT()

In [None]:
# Inspect the Fourier-transformed data: amplitude component.
data_4D.plot_4D_reciprocal(signal='amplitude')

In [None]:
# Inspect the Fourier-transformed data: phase component.
data_4D.plot_4D_reciprocal(signal='phase')

# Trotters
![Diagram: trotters in the G-set](image/trotters.jpg "Trotters")

In [None]:
# Rotation handed to plot_trotters: adjust by hand until the plotted trotters fit.
# Units are not visible here — TODO confirm (degrees vs radians) in STEM4D_modified.
rotation = 0  # value that fits
data_4D.plot_trotters(rotation, plot_constrains=True, skip=1)

# Wigner distribution deconvolution (WDD)
![Diagram: Wigner distribution deconvolution (WDD)](image/wdd.jpg "Wigner distribution deconvolution")

In [None]:
# Baseline Wigner distribution deconvolution. run() is called without the
# `aberrations=` argument that the corrected reconstruction below receives.
wdd = WDD(data_4D)
wdd.run()

In [None]:
# Show the baseline (uncorrected) WDD reconstruction.
wdd.plot_result()

# Aberration correction
![Diagram: aberration correction](image/aberration_correction.jpg "Aberration correction")

In [None]:
# SVD-based aberration estimation on the Fourier-transformed dataset.
# NOTE(review): trotters_nb=5 is a magic number — consider a named constant with a
# comment explaining why five trotters are used.
svd = SVD_AC(data_4D, trotters_nb=5)
# Assemble the system matrix, then solve it by SVD (order matters: build before solve).
svd.build_omnimatrix()
svd.run_SVD()

In [None]:
# Report the aberration coefficients found by the SVD fit.
svd.print_aberration_coefficients()

In [None]:
# Derive the aberration, aperture and probe functions from the fit. The probe is
# computed last — presumably because it depends on the first two (TODO confirm).
svd.calc_aberrationfunction()
svd.calc_aperturefunction()
svd.calc_probefunction()
# Trotters replotted with the negated fitted coefficients, the same sign used for
# wdd_ac.run(aberrations=...) further down. Presumably the negation applies the
# inverse (correcting) phase — TODO confirm sign convention.
svd.plot_corrected_trotters(data_4D.selected_frames, -svd.aberration_coeffs)

In [None]:
# Inspect the SVD result: magnitude of svd.probe (left) and phase of
# svd.func_transfer (right), drawn over the svd.theta_x / svd.theta_y range.
fig, ax = plt.subplots(1, 2, figsize=(10, 5))

ax[0].imshow(np.abs(svd.probe))
ax[0].set_title('|probe|')

transfer_extent = (svd.theta_x.min(), svd.theta_x.max(),
                   svd.theta_y.min(), svd.theta_y.max())
# np.angle returns the phase in radians, within [-pi, pi].
phase_img = ax[1].imshow(np.angle(svd.func_transfer),
                         extent=transfer_extent,
                         cmap='jet')
ax[1].set_title('transfer function phase')
# Without a colorbar the colour -> phase mapping of the right panel is unreadable.
fig.colorbar(phase_img, ax=ax[1], label='phase (rad)')

# Axes are hidden; `extent` still sets the pixel aspect ratio of the right panel.
for panel in ax:
    panel.axis("off")

fig.tight_layout()
plt.show()

In [None]:
# Second WDD reconstruction on the same dataset, this time handing over the
# negated SVD-fitted aberration coefficients (same sign as plot_corrected_trotters).
wdd_ac = WDD(data_4D)
wdd_ac.run(aberrations=-svd.aberration_coeffs)

In [None]:
# 2x2 comparison. Top row: WDD run with the SVD-fitted aberrations (wdd_ac).
# Bottom row: baseline WDD without aberrations (wdd). Left: phase, right: amplitude.
fig, ax = plt.subplots(2, 2, figsize=(10, 10))
ax[0, 0].imshow(wdd_ac.phase, cmap='inferno')
ax[1, 0].imshow(wdd.phase, cmap='inferno')
ax[0, 1].imshow(wdd_ac.amplitude, cmap='inferno')
ax[1, 1].imshow(wdd.amplitude, cmap='inferno')

# Remove tick marks but keep the axes themselves: ax.axis("off") would also hide
# the row labels set below.
for panel in ax.flat:
    panel.set_yticks([])
    panel.set_xticks([])

# Both rows are reconstructions from the WDD class (not SSB), so label them as such.
ax[0, 0].set_ylabel('aberration-corrected WDD')
ax[1, 0].set_ylabel('uncorrected WDD')
ax[0, 0].set_title('phase')
ax[0, 1].set_title('amplitude')

fig.tight_layout()
plt.show()