In [None]:
import os
import tkinter.filedialog as tkf

import ipywidgets as pyw
import matplotlib.pyplot as plt
import numpy as np
import tifffile
from scipy import ndimage
from scipy.integrate import simpson
from scipy.interpolate import interp1d
from scipy.optimize import curve_fit
from sklearn.linear_model import Ridge

from py4DSTEM.process.utils import single_atom_scatter

# Global figure font, plus a palette of 16 named colours
# (not referenced in the cells shown here).
plt.rcParams.update({'font.family': 'Times New Roman'})
color_rep = [
    "black", "orange", "purple", "blue",
    "red", "green", "yellow", "lime",
    "cyan", "magenta", "lightgray", "peru",
    "springgreen", "deepskyblue", "hotpink", "darkgray",
]

In [None]:
# Scattering-vector grid: 1024 points spaced 0.0023322 apart, starting at 0
# (the same step as the k calibration stored in `scale` further down).
k_step = 0.0023322
k_list = k_step * np.arange(1024)

# Per-element scattering factors on k_list from py4DSTEM (Z = 31 and Z = 8).
Fst_AF = single_atom_scatter(elements=[31], composition=[1.0], q_coords=k_list, units='A')
Snd_AF = single_atom_scatter(elements=[8], composition=[1.0], q_coords=k_list, units='A')
Fst_AF.get_scattering_factor()
Snd_AF.get_scattering_factor()
Fst_AFF = Fst_AF.fe
Snd_AFF = Snd_AF.fe

# Atomic fractions of the two elements, in the order above.
composition = [0.5, 0.5]
# <f>^2: square of the composition-weighted mean scattering factor.
AF_mean_square = (composition[0]*Fst_AFF + composition[1]*Snd_AFF)**2
# <f^2>: composition-weighted mean of the squared scattering factors.
AF_square_mean = composition[0]*Fst_AFF**2 + composition[1]*Snd_AFF**2

# Optional normalisation (disabled):
#max_nor = np.max(AF_square_mean)
#AF_mean_square /= max_nor
#AF_square_mean /= max_nor

In [None]:
# Bank of damping filters: row i of `damp` is the filter labelled filter_types[i]
# and is selected by the "filter" widget in the interactive cell.
damp = tifffile.imread("./setting/filters/damping_filter.tif")
print(damp.shape)

filter_types = ["boxcar", "triangular", "trapezoidal", "Happ-Genzel", "3TEM BH", "4TEM BH"]

# Later cells take damp[f_n] as a 1D filter for every selectable label; fail
# here rather than deep inside the widget callback if the file does not match.
if damp.ndim != 2 or damp.shape[0] < len(filter_types):
    raise ValueError(
        f"damping_filter.tif must be 2D with at least {len(filter_types)} rows, got shape {damp.shape}"
    )

In [None]:
# Pick the input TIFF file(s) with a native dialog; raw_adr holds the selected path(s).
raw_adr = tkf.askopenfilenames()
# A cancelled dialog yields an empty value; stop here with a clear message.
if not raw_adr:
    raise RuntimeError("No file selected.")
print(*raw_adr, sep="\n")

In [None]:
# Load the selected data. Later cells treat it as a (y, x, k) stack:
# rot_dp_data[y, x] is one intensity profile sampled on k_list.
rot_dp = tifffile.imread(raw_adr)
rot_dp_data = rot_dp.copy()
n_dim = rot_dp_data.ndim
print(rot_dp_data.shape)
print(n_dim)

# Fail early instead of deep inside the interactive / batch cells.
if n_dim != 3:
    raise ValueError(f"expected a 3D (y, x, k) stack, got shape {rot_dp_data.shape}")
if rot_dp_data.shape[-1] != len(k_list):
    raise ValueError(
        f"last axis has {rot_dp_data.shape[-1]} channels but k_list has {len(k_list)} points"
    )

In [None]:
# Axis metadata for the (y, x, k) data: step size, origin and unit of each axis.
size = rot_dp_data.shape
unit = ['nm', 'nm', '1/A']
origin = [0] * 3
scale = [1.0, 1.0, 0.0023322]

In [None]:
# Grid-search an additive offset (ctb * fit_dif) and a matching scale factor so that
# the high-k tail of the position-averaged profile best matches <f>^2 (least L2 residual).
# Presumably used as a starting guess for N / alpha in the interactive cell — confirm.
fit_range = [int(0.9*len(k_list)), len(k_list)-1]
fit_slice = slice(fit_range[0], fit_range[1])  # upper bound exclusive: last channel unused
cor_term_ctb = np.arange(-0.5, 0.5, 0.01)

# Loop-invariant quantities: the mean over the whole scan is computed once,
# not twice per grid point.
mean_profile = np.mean(rot_dp_data, axis=(0, 1))
tail_data = mean_profile[fit_slice]
tail_model = AF_mean_square[fit_slice]
tail_data_mean = np.mean(tail_data)
tail_model_mean = np.mean(tail_model)
fit_dif = tail_data_mean - tail_model_mean

fit_ratios = []
errs = []
for ctb in cor_term_ctb:
    offset = ctb*fit_dif
    fit_ratio = tail_data_mean / (tail_model_mean + offset)
    fit_ratios.append(fit_ratio)
    errs.append(np.linalg.norm(tail_data - (tail_model + offset)*fit_ratio))

opt_ind = int(np.argmin(errs))
print(opt_ind)
print(cor_term_ctb[opt_ind], fit_dif, fit_ratios[opt_ind])

fig, ax = plt.subplots(1, 1, figsize=(8, 6))
ax.plot(cor_term_ctb, errs, 'k*')
ax.set_xlabel("offset coefficient (fraction of fit_dif)")
ax.set_ylabel("tail residual (L2 norm)")
ax.grid()
fig.tight_layout()
plt.show()

In [None]:
# Position-averaged profile against the best-fitting background models.
fig, ax = plt.subplots(1, 1, figsize=(8, 6))
ax.plot(k_list, np.mean(rot_dp_data, axis=(0, 1)), 'k-', label="mean profile")
ax.plot(k_list, (AF_mean_square+cor_term_ctb[opt_ind]*fit_dif)*fit_ratios[opt_ind], 'r-',
        label="scaled <f>^2 + offset")
ax.plot(k_list, AF_square_mean*fit_ratios[opt_ind], 'g+', label="scaled <f^2>")
ax.set_xlabel(f"k ({unit[-1]})")
ax.set_ylabel("intensity")
ax.legend()
ax.grid()
fig.tight_layout()
plt.show()

In [None]:
def rif_to_rdf(rif, r_list, k_list):
    """Sine-transform a reduced intensity function into G(r).

    For each r in ``r_list`` evaluates

        G(r) = 8*pi * integral rif(k) * sin(2*pi*r*k) dk

    with Simpson's rule using the spacing k_list[1] - k_list[0]
    (so ``k_list`` is assumed uniformly spaced, with at least two points).

    Returns a 1D numpy array with one value per entry of ``r_list``.
    """
    dk = k_list[1] - k_list[0]
    return np.asarray([
        8 * np.pi * simpson(rif * np.sin(2 * np.pi * r * k_list), dx=dk)
        for r in r_list
    ])

In [None]:
# Real-space grid for G(r): r_size points spaced r_scale apart, starting at 0.
r_size = 500
r_scale = 0.01 # Angstrom
r_list = np.arange(r_size) * r_scale
# Real-space unit is the denominator of the k unit ('1/A' -> 'A'); taking only the
# last character would turn e.g. '1/nm' into 'm'.
r_unit = unit[-1].split('/')[-1]
print(r_list.shape)
print(r_scale)
print(r_unit)

In [None]:
%matplotlib widget
fig, ax = plt.subplots(2, 2, figsize=(8, 5))
int_img = np.sum(rot_dp_data, axis=2)

def _cut_and_resample_filter(damp_row, low_cut, high_cut, n_points):
    """Copy one damping-filter row, zero it outside the kept band and resample it.

    Samples with index < int(low_cut*len) or >= int(high_cut*len) are set to 0,
    then the row is subsampled (index truncated toward 0) to ``n_points`` samples.
    """
    filt = damp_row.copy()
    filt[:int(low_cut*len(filt))] = 0
    filt[int(high_cut*len(filt)):] = 0
    # np.intp, not np.int16: int16 indices wrap around for rows longer than 32767.
    ind = np.linspace(0, len(filt)-1, n_points).astype(np.intp)
    return filt[ind]


def _cut_index(frac, n):
    """Return int(frac*n) clamped to n-1, so a cut fraction of 1.0 stays in range."""
    return min(int(frac*n), n-1)


def RDF_4DSTEM(yp, xp, aph, ln, f_n, f_lc, f_c, full_check):
    """Redraw the 2x2 preview for one scan position (ipywidgets callback).

    Parameters
    ----------
    yp, xp : int
        Row / column of the scan position in ``rot_dp_data``.
    aph : float
        Additive offset alpha applied to the measured profile.
    ln : float
        Scale factor N multiplying <f>^2 and <f^2>.
    f_n : int
        Row of ``damp`` (labelled by ``filter_types``) used as damping filter.
    f_lc, f_c : float
        Low / high cut of that filter as fractions of its length.
    full_check : bool
        When False, the G(r) panel is zoomed to r in [1.0, 2.5].

    Draws into the module-level ``fig`` / ``ax`` of this cell and prints the
    selected filter name; returns None.
    """
    for a in ax.flat:
        a.cla()

    ax[0][0].imshow(int_img, cmap="gray")
    ax[0][0].axis("off")
    ax[0][0].scatter(xp, yp, c="red", s=15)

    filt = _cut_and_resample_filter(damp[f_n], f_lc, f_c, size[n_dim-1])
    print(filter_types[f_n])
    rot_dp_tmp = rot_dp_data[yp, xp]

    Nfs_tmp = ln*AF_mean_square
    rif_tmp = ((rot_dp_tmp + aph - Nfs_tmp) / (ln*AF_square_mean)) * k_list
    rif_tmp_filtered = rif_tmp * filt

    # Edges of the regions removed by the filter cuts (shaded green below).
    k_lo = k_list[_cut_index(f_lc, len(filt))]
    k_hi = k_list[_cut_index(f_c, len(filt))]

    ax[0][1].plot(k_list, rot_dp_tmp, 'k-')
    ax[0][1].plot(k_list, Nfs_tmp, 'r-')
    ax[0][1].set_ylim(ymin=0.0)
    ax[0][1].fill_between([k_list[0], k_lo], np.max(rot_dp_tmp), alpha=0.5, color="green")
    ax[0][1].fill_between([k_hi, k_list[-1]], np.max(rot_dp_tmp), alpha=0.5, color="green")
    ax[0][1].set_xlabel(f"k ({unit[-1]})")
    ax[0][1].grid()

    ax[1][0].plot(k_list, rif_tmp, 'k-')
    ax[1][0].plot(k_list, rif_tmp_filtered, 'g-')
    ax[1][0].fill_between([k_list[0], k_lo], np.max(rif_tmp), alpha=0.5, color="green")
    ax[1][0].fill_between([k_hi, k_list[-1]], np.max(rif_tmp), alpha=0.5, color="green")
    ax[1][0].set_xlabel(f"k ({unit[-1]})")
    ax[1][0].grid()

    tmp_Gr = rif_to_rdf(rif_tmp_filtered, r_list, k_list)
    ax[1][1].plot(r_list, tmp_Gr, 'k-')
    if not full_check:
        ax[1][1].set_xlim(xmin=1.0, xmax=2.5)
    ax[1][1].set_xlabel(f"r ({r_unit})")
    ax[1][1].grid()

    fig.tight_layout()
    
# Widgets driving RDF_4DSTEM; each keyword below matches one of its parameters.
st = {"description_width": "initial"}
y_wg = pyw.BoundedIntText(value=0, min=0, max=size[0]-1, description="y position", style=st)
x_wg = pyw.BoundedIntText(value=0, min=0, max=size[1]-1, description="x position", style=st)
N_wg = pyw.FloatText(value=1.0, step=0.1, description="N", style=st)
alpha_wg = pyw.FloatText(value=0.0, step=0.01, description="alpha", style=st)
filter_wg = pyw.BoundedIntText(value=0, min=0, max=len(filter_types)-1, description="filter", style=st)
filter_low_wg = pyw.BoundedFloatText(value=0.05, min=0.0, max=1.0, step=0.01, description="filter low cut", style=st)
filter_cut_wg = pyw.BoundedFloatText(value=0.8, min=0.0, max=1.0, step=0.01, description="filter high cut", style=st)
full_wg = pyw.Checkbox(value=True, description="show the full range of RDF")

pyw.interact(RDF_4DSTEM, yp=y_wg, xp=x_wg, aph=alpha_wg,
             ln=N_wg, f_n=filter_wg, f_lc=filter_low_wg, f_c=filter_cut_wg, full_check=full_wg)
plt.show()

In [None]:
# Freeze the parameters chosen in the interactive preview for the batch run below.
alpha = alpha_wg.value
N = N_wg.value
filter_select = filter_wg.value
low_cut = filter_low_wg.value
high_cut = filter_cut_wg.value
print(f"alpha={alpha}, N={N}, filter={filter_types[filter_select]} ({filter_select}), "
      f"low cut={low_cut}, high cut={high_cut}")

In [None]:
# Batch RIF for every scan position with the frozen parameters. The damping filter
# is the same for all positions, so it is built once, with the same low and high
# cuts the interactive preview applied.
low_cut = filter_low_wg.value
filt = damp[filter_select].copy()
filt[:int(low_cut*len(filt))] = 0
filt[int(high_cut*len(filt)):] = 0
# np.intp, not np.int16: int16 indices wrap around for filters longer than 32767.
ind = np.linspace(0, len(filt)-1, size[n_dim-1]).astype(np.intp)
filt = filt[ind]

# Broadcast along the k axis; the (y, x) axes keep rot_dp_data's order, so the
# result has shape (size[0], size[1], n_k) for square and non-square scans alike.
RIF_data = (((rot_dp_data + alpha - N*AF_mean_square) / (N*AF_square_mean)) * k_list) * filt
assert RIF_data.shape == rot_dp_data.shape

In [None]:
# Switch from the interactive widget backend (used by the preview) to static inline figures.
%matplotlib inline

In [None]:
# G(r) for every scan position: rif_to_rdf applied along the k (last) axis of RIF_data.
# The output keeps RIF_data's own (y, x) axes, so non-square scans are handled;
# the result has shape (*RIF_data.shape[:2], len(r_list)).
Gr = np.apply_along_axis(rif_to_rdf, -1, RIF_data, r_list, k_list)
print(Gr.shape)

In [None]:
# G(r) of every scan position in the first scan row, zoomed to r in [1, 3].
# Separate names so the preview's global fig/ax (used by the live widgets) are not replaced.
fig_gr, ax_gr = plt.subplots(1, 1, figsize=(8, 6))
for gr_profile in Gr[0]:
    ax_gr.plot(r_list, gr_profile)
ax_gr.set_xlabel(f"r ({r_unit})")
ax_gr.set_ylabel("G(r)")
ax_gr.grid()
ax_gr.set_xlim(xmin=1.0, xmax=3.0)
fig_gr.tight_layout()
plt.show()

In [None]:
# Save G(r) next to the input: the output name starts with the (first) selected
# path minus its extension. raw_adr is a tuple of paths, so it cannot be sliced as a string.
src_path = raw_adr if isinstance(raw_adr, str) else raw_adr[0]
tifffile.imwrite(os.path.splitext(src_path)[0]+"_Gr_upto_%dA_size_%d_2.tif"%(int(r_size*r_scale), r_size), Gr)

In [None]:
# Save the RIF stack next to the input: the output name starts with the (first) selected
# path minus its extension. raw_adr is a tuple of paths, so it cannot be sliced as a string.
src_path = raw_adr if isinstance(raw_adr, str) else raw_adr[0]
tifffile.imwrite(os.path.splitext(src_path)[0]+"_RIF.tif", RIF_data)