In [None]:
# Jinseok Ryu, PhD
# jinseuk56@gmail.com

import os
import sys
import time
import tkinter.filedialog as tkf

import matplotlib.pyplot as plt
import numpy as np
import tifffile

sys.path.append("")
from FDSTEM_process import *
# Render all figure text with the Cambria font family.
# NOTE(review): Cambria is presumably only available on machines with Microsoft
# fonts installed — confirm the fallback behaviour elsewhere.
plt.rcParams['font.family'] = 'Cambria'

# load data

In [None]:
# Let the user pick the raw data file through a file dialog.
raw_adr = tkf.askopenfilename()
# A cancelled dialog yields an empty value; stop here with a clear message
# instead of handing an empty path to the loader in the next cell.
if not raw_adr:
    raise FileNotFoundError("No data file was selected.")
print(raw_adr)

In [None]:
# Build the processing object for the selected file. FourDSTEM_process is not
# imported by name, so it is provided by the star import of FDSTEM_process.
fd = FourDSTEM_process(raw_adr)

In [None]:
# Run the spike-removal step with percent_thresh=0.01 in "lower" mode.
# NOTE(review): apply_remove=True presumably rewrites the stack stored on `fd`
# in place — confirm in FDSTEM_process before re-running this cell.
fd.spike_remove(percent_thresh=0.01, mode="lower", apply_remove=True)

In [None]:
# Switch to the interactive widget backend before opening the viewer.
%matplotlib widget
# Open the 4D viewer on the stack held by the processing object.
fd.show_4d_viewer(fd.original_stack)

In [None]:
# Return to static inline figures for the plots that follow.
%matplotlib inline

In [None]:
# Locate the centre of the diffraction pattern (cbox_edge=15) and print it.
c_pos = fd.find_center(cbox_edge=15)
print(c_pos)
# Mean diffraction pattern kept on the processing object.
# NOTE(review): presumably populated by find_center() — keep this line after the call.
mean_dp = fd.original_mean_dp

In [None]:
# Radial profile of the mean diffraction pattern around the detected centre.
mean_dp_radial_avg, _ = radial_stats(mean_dp, center=c_pos, var=False)

# Obtain the variance map depending on k-vector:
#   (<I^2> - <I>^2) / <I>^2, with both averages taken over the two scan axes.
# NOTE(review): np.square keeps the dtype of fd.original_stack; if that is an
# integer type the squares could wrap around — confirm the stack is floating point.
avg_square = np.square(np.mean(fd.original_stack, axis=(0, 1)))
square_avg = np.mean(np.square(fd.original_stack), axis=(0, 1))

# Denominator with zeros replaced by one so empty pixels do not divide by zero.
mask = avg_square.copy()
mask[avg_square == 0] = 1.0
var_map = (square_avg - avg_square) / mask

_, mean_dp_radial_var = radial_stats(var_map, center=c_pos, var=True)

# Left to right: log mean pattern, its radial profile, variance map, its radial profile.
fig, ax = plt.subplots(1, 4, figsize=(12, 3))
ax[0].imshow(np.log(mean_dp), cmap="gray")
ax[2].imshow(var_map, cmap="inferno")
for image_axis in (ax[0], ax[2]):
    image_axis.axis("off")
for profile_axis, profile in ((ax[1], mean_dp_radial_avg), (ax[3], mean_dp_radial_var)):
    profile_axis.plot(profile, "k-")
    profile_axis.grid()
fig.tight_layout()
plt.show()

# rotational average & variance profile

In [None]:
# Compute the rotational average with rot_variance=True; the cells below read
# the results from fd.radial_avg_stack and fd.radial_var_stack.
fd.rotational_average(rot_variance=True)

In [None]:
# Profiles summed over the two leading axes: radial average on top, radial
# variance below.
fig, ax = plt.subplots(2, 1, figsize=(10, 10))
for profile_axis, radial_stack in zip(ax, (fd.radial_avg_stack, fd.radial_var_stack)):
    profile_axis.plot(np.sum(radial_stack, axis=(0, 1)), "k-")
fig.tight_layout()
plt.show()

In [None]:
# Switch to the interactive widget backend before opening the viewer.
%matplotlib widget
# Open the 3D viewer on the radial-average stack.
fd.show_3d_viewer(fd.radial_avg_stack)

In [None]:
# save (radial average, 3D)
# The output is written next to the input file. os.path.splitext strips the
# real extension whatever its length (the former raw_adr[:-4] assumed exactly
# three characters), and imwrite replaces the deprecated imsave alias.
tifffile.imwrite(os.path.splitext(raw_adr)[0]+"_radial_avg.tif", fd.radial_avg_stack)

In [None]:
# save (radial variance, 3D)
# The output is written next to the input file. os.path.splitext strips the
# real extension whatever its length (the former raw_adr[:-4] assumed exactly
# three characters), and imwrite replaces the deprecated imsave alias.
tifffile.imwrite(os.path.splitext(raw_adr)[0]+"_radial_var.tif", fd.radial_var_stack)

# local similarity

In [None]:
# Select static inline figures.
# NOTE(review): the very next cell switches to the widget backend, so only one
# of the two cells should be run.
%matplotlib inline

In [None]:
# Select the interactive widget backend for the figures of this section.
%matplotlib widget

In [None]:
# `stack_4d` and `f_shape` are used by every cell from here on but were never
# assigned in this notebook, so a fresh kernel failed with a NameError.
# They are indexed as stack_4d[scan_row, scan_col, ky, kx] below, which matches
# the 4D stack held by the processing object.
# NOTE(review): assumes fd.original_stack is the intended 4D data — confirm.
stack_4d = fd.original_stack
f_shape = stack_4d.shape

# Private copy of the radial-variance stack plus its profile summed over the
# two leading axes.
radial_var = fd.radial_var_stack.copy()
radial_var_spectrum = np.sum(fd.radial_var_stack, axis=(0, 1))

In [None]:
# Radii whose rings are overlaid on one example diffraction pattern.
tmp_radius = [21, 22, 23]

fig, ax = plt.subplots(1, 1, figsize=(5, 5))
# Log-scaled pattern at scan position (10, 10) as the background.
ax.imshow(np.log(stack_4d[10, 10]))
ax.axis("off")
# One faint overlay per radius, drawn in list order.
for ring_radius in tmp_radius:
    ring_image = radial_indices(f_shape[2:], [ring_radius], center=c_pos)
    ax.imshow(ring_image, alpha=0.1, cmap="gray")
fig.tight_layout()
plt.show()

In [None]:
# Close every open figure before continuing.
plt.close('all')

In [None]:
# Parameters shared by the local-similarity cells below: window edge length,
# step between windows, and the selected radial index.
win_size, stride, k_selected = 3, 1, 22

In [None]:
# Variance map at the selected radial index (copied so the stack stays untouched).
k_var_map = radial_var[:, :, k_selected].copy()

local_avg, local_std, local_dif, bin_shape = local_var_similarity(k_var_map, win_size, stride)
for local_map in (local_avg, local_std, local_dif):
    print(local_map.shape)

# Interior of the map once the window margins are dropped on every side.
inner = slice(int((win_size-1)/2), -int(win_size/2))

mask = np.zeros(k_var_map.shape)
mask[inner, inner] = 1

# The cropped interior shape is printed next to bin_shape for comparison.
print(mask[inner, inner].shape)
print(bin_shape)

In [None]:
# Top to bottom: cropped variance map, local average, negated local
# difference, negated local standard deviation.
inner = slice(int((win_size-1)/2), -int(win_size/2))
panels = (
    (k_var_map[inner, inner], "afmhot"),
    (local_avg, "viridis"),
    (-local_dif, "viridis"),
    (-local_std, "viridis"),
)

fig, ax = plt.subplots(4, 1, figsize=(5, 8))
for panel_axis, (image, colormap) in zip(ax, panels):
    panel_axis.imshow(image, cmap=colormap)
    panel_axis.axis("off")
fig.tight_layout()
plt.show()

In [None]:
# Detector pixels belonging to the selected ring.
k_ind, a_ind = indices_at_r(f_shape[2:], k_selected, c_pos)

# Ring intensities for every scan position.
ring_rows, ring_cols = k_ind[0], k_ind[1]
f_flat = stack_4d[:, :, ring_rows, ring_cols]
print(f_flat.shape)

dp_mse, dp_ssim, bin_shape = local_DP_similarity(f_flat, win_size, stride)
for similarity_map in (dp_mse, dp_ssim):
    print(similarity_map.shape)

# Scan-plane mask that is one inside the window margins and zero on the border.
inner = slice(int((win_size-1)/2), -int(win_size/2))
mask = np.zeros(f_flat.shape[:2])
mask[inner, inner] = 1

In [None]:
# Negated MSE on top and SSIM below, so brighter means "more similar" in both.
fig, ax = plt.subplots(2, 1, figsize=(5, 4))
for panel_axis, image in zip(ax, (-dp_mse, dp_ssim)):
    panel_axis.imshow(image, cmap="viridis")
    panel_axis.axis("off")
fig.tight_layout()
plt.show()

In [None]:
# Number of standard deviations used for each threshold, in the order:
# variance, local difference, local std, SSIM, MSE.
th_sigma = [0.6, 0.6, 0.6, 0.6, 0.6]
inner = slice(int((win_size-1)/2), -int(win_size/2))

# "High" maps are clipped from below at mean + n*std; "low" maps are clipped
# from above at mean - n*std.
var_floor = np.mean(k_var_map)+th_sigma[0]*np.std(k_var_map)
high_var = k_var_map.clip(min=var_floor)[inner, inner]
print(high_var.shape)
dif_ceiling = np.mean(local_dif)-th_sigma[1]*np.std(local_dif)
low_dif = local_dif.clip(max=dif_ceiling)
print(low_dif.shape)
std_ceiling = np.mean(local_std)-th_sigma[2]*np.std(local_std)
low_std = local_std.clip(max=std_ceiling)
print(low_std.shape)
ssim_floor = np.mean(dp_ssim)+th_sigma[3]*np.std(dp_ssim)
high_ssim = dp_ssim.clip(min=ssim_floor)
print(high_ssim.shape)
mse_ceiling = np.mean(dp_mse)-th_sigma[4]*np.std(dp_mse)
low_mse = dp_mse.clip(max=mse_ceiling)
print(low_mse.shape)

# The "low" maps are negated for display.
fig, ax = plt.subplots(5, 1, figsize=(5, 10))
for panel_axis, image in zip(ax, (high_var, -low_dif, -low_std, high_ssim, -low_mse)):
    panel_axis.imshow(image, cmap="viridis")
    panel_axis.axis("off")
fig.tight_layout()
plt.show()

In [None]:
# Close every open figure before continuing.
plt.close("all")

In [None]:
# Inputs for the window-size sweep: variance map and ring intensities at the
# selected radial index.
k_selected = 22
k_var_map = radial_var[:, :, k_selected].copy()
k_ind, a_ind = indices_at_r(f_shape[2:], k_selected, c_pos)
f_flat = stack_4d[:, :, k_ind[0], k_ind[1]]

win_sizes = np.array([3, 5, 7, 9, 11])
stride = 1
# Window start positions are limited by the largest window, so every window
# size is evaluated on the same grid.
largest_window = np.max(win_sizes)
rows, cols = (range(0, extent-largest_window+1, stride) for extent in f_shape[:2])

In [None]:
# One similarity map per window size; difference, std and MSE are negated on
# storage.
# Note: the loop leaves local_avg/local_std/local_dif/dp_mse/dp_ssim holding
# the results of the last window size.
var_dif_stack, var_std_stack, dp_ssim_stack, dp_mse_stack = [], [], [], []
for window in win_sizes:
    local_avg, local_std, local_dif, dp_mse, dp_ssim, bin_shape = local_similarity(k_var_map, f_flat, window, rows, cols)

    var_dif_stack.append(-local_dif)
    var_std_stack.append(-local_std)
    dp_ssim_stack.append(dp_ssim)
    dp_mse_stack.append(-dp_mse)

# Convert each list to an array with the window-size axis first.
var_dif_stack = np.asarray(var_dif_stack)
print(var_dif_stack.shape)
var_std_stack = np.asarray(var_std_stack)
print(var_std_stack.shape)
dp_ssim_stack = np.asarray(dp_ssim_stack)
print(dp_ssim_stack.shape)
dp_mse_stack = np.asarray(dp_mse_stack)
print(dp_mse_stack.shape)

In [None]:
# Move the leading axis to the end: (a, b, c) -> (b, c, a). A single transpose
# gives the same result as the two successive rollaxis calls.
# Note: each name is reassigned to its own permutation, so running this cell a
# second time permutes the axes again.
LAST_AXIS_ORDER = (1, 2, 0)
var_dif_stack = np.transpose(var_dif_stack, LAST_AXIS_ORDER)
print(var_dif_stack.shape)
var_std_stack = np.transpose(var_std_stack, LAST_AXIS_ORDER)
print(var_std_stack.shape)
dp_ssim_stack = np.transpose(dp_ssim_stack, LAST_AXIS_ORDER)
print(dp_ssim_stack.shape)
dp_mse_stack = np.transpose(dp_mse_stack, LAST_AXIS_ORDER)
print(dp_mse_stack.shape)

In [None]:
# Mean of each similarity measure over the two leading axes, plotted against
# window size.
sweep_stacks = (var_dif_stack, var_std_stack, dp_ssim_stack, dp_mse_stack)

fig, ax = plt.subplots(4, 1, figsize=(5, 8))
for sweep_axis, sweep_stack in zip(ax, sweep_stacks):
    sweep_axis.plot(win_sizes, np.mean(sweep_stack, axis=(0, 1)))
    sweep_axis.grid()
fig.tight_layout()
plt.show()

In [None]:
# Plot every ring profile inside one window; the window is centred on scan
# position (5, 5).
local_size = win_size
half_size = int(local_size/2)
row_ = col_ = 5-half_size

selected_region = f_flat[row_:row_+local_size, col_:col_+local_size]
print(selected_region.shape)
# One row per scan position in the window.
selected_region = selected_region.reshape(local_size*local_size, -1)

fig, ax = plt.subplots(local_size, local_size, figsize=(20, 20))
for axs, ring_profile in zip(ax.flat, selected_region):
    axs.plot(ring_profile)
    axs.grid()

plt.show()

In [None]:
# Keep positions whose SSIM exceeds mean + th_sigma * std.
# NOTE(review): at this point dp_ssim is whatever the last cell to assign it
# produced (the window-size sweep overwrites it) — confirm that is intended.
th_sigma = 1.35
ssim_floor = np.mean(dp_ssim)+th_sigma*np.std(dp_ssim)
high_ssim = dp_ssim.clip(min=ssim_floor)
high_ind = np.where(high_ssim > ssim_floor)

# Number of selected positions.
print(len(high_ind[0]))

In [None]:
# For every selected position, compare its ring profile with the profile at
# every scan position and show the resulting MSE and SSIM maps.
# NOTE(review): high_ind indexes dp_ssim, which presumably covers only the
# window-cropped interior of the scan — confirm that using the same indices on
# f_flat needs no offset.
for y_pos, x_pos in zip(high_ind[0], high_ind[1]):
    print(y_pos, x_pos)

    # The normalised reference is constant for the whole scan, so compute it once.
    ref_dp = f_flat[y_pos, x_pos]
    ref_norm = ref_dp/np.max(ref_dp)
    ssim_result = []
    mse_result = []

    # Dedicated index names: the original inner loop reused `i`, shadowing the
    # outer loop variable.
    for scan_row in range(f_shape[0]):
        for scan_col in range(f_shape[1]):
            tmp_dp = f_flat[scan_row, scan_col]
            tmp_norm = tmp_dp/np.max(tmp_dp)
            mse_result.append(mean_squared_error(ref_norm, tmp_norm))
            ssim_result.append(ssim(ref_norm, tmp_norm))

    # Back to the scan grid, each map scaled by its own maximum.
    ssim_result = np.asarray(ssim_result).reshape(f_shape[:2])
    ssim_result = ssim_result / np.max(ssim_result)
    mse_result = np.asarray(mse_result).reshape(f_shape[:2])
    mse_result = mse_result / np.max(mse_result)

    # Overwrite the reference pixel itself with the "least similar" value of
    # each map so the self-comparison does not dominate the colour scale.
    ssim_result[y_pos, x_pos] = 0.0
    mse_result[y_pos, x_pos] = 1.0

    # Negated MSE on the left, SSIM on the right; the red dot marks the reference.
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    ax[0].imshow(-mse_result, cmap="viridis")
    ax[0].scatter(x_pos, y_pos, c="red")
    ax[0].axis("off")
    ax[1].imshow(ssim_result, cmap="viridis")
    ax[1].scatter(x_pos, y_pos, c="red")
    ax[1].axis("off")
    fig.tight_layout()
    plt.show()

In [None]:
# Setup shared with the angular-correlation cell further down (which assigns
# the same names again); the two preview cells that follow read
# angle_sampling from here.
k_range = np.arange(19, 24)
start_time = time.process_time()
ac_spectra, ac_fft_stack = [], []
angle_sampling = 361
angles = np.arange(angle_sampling)
# One on and above the diagonal, NaN below it (despite the name, this keeps
# the upper triangle).
upper_triangle = np.triu(np.ones((angle_sampling, angle_sampling), dtype=bool), 0)
tril_mask = np.where(upper_triangle, 1.0, np.nan)

In [None]:
# Angular intensity profile of the ring at radial index 22 for the single scan
# position (49, 100).
k_ind, a_ind = indices_at_r(f_shape[2:], 22, c_pos)

value_sel = stack_4d[49, 100, k_ind[0], k_ind[1]]
# Scatter the ring intensities into their integer angle bins; bins that
# receive no pixel stay at zero.
angle_bins = a_ind.astype(int)
values = np.zeros(angle_sampling)
values[angle_bins] = value_sel

In [None]:
# Raw angular profile on top, Gaussian-smoothed (sigma=1.0) version below.
smoothed_values = ndimage.gaussian_filter(values, sigma=1.0)

fig, ax = plt.subplots(2, 1)
for profile_axis, profile in zip(ax, (values, smoothed_values)):
    profile_axis.plot(profile)
fig.tight_layout()
plt.show()

# angular correlation

In [None]:
# Angular autocorrelation of the ring intensity for every scan position and
# every radial index in k_range, together with the magnitude of its FFT.
k_range = np.arange(19, 24, 1)
start_time = time.process_time()
ac_spectra = []
ac_fft_stack = []
angle_sampling = 361
angles = np.arange(angle_sampling)
# One on and above the diagonal, NaN below it (despite the name, this keeps
# the upper triangle).
tril_mask = np.ones((angle_sampling, angle_sampling))
tril_mask = np.triu(tril_mask, 0)
tril_mask[np.where(tril_mask==0)] = np.nan

# shift_index[m, n] = (n - m) mod angle_sampling, so values[shift_index][m] is
# exactly np.roll(values, m). Gathering through this table once per pixel
# replaces the former row-by-row np.vstack construction, which copied the
# growing matrix about 360 times for every scan position.
shift_index = (angles[np.newaxis, :] - angles[:, np.newaxis]) % angle_sampling

for k in k_range:
    k_ind, a_ind = indices_at_r(f_shape[2:], k, c_pos)
    # Integer angle bins depend only on the radius, not on the scan position.
    angle_bins = a_ind.astype(int)
    temp_spectra = []
    temp_fft_stack = []
    for i in range(f_shape[0]):
        for j in range(f_shape[1]):
            # Ring intensities scattered into their angle bins, then smoothed.
            # NOTE(review): gaussian_filter is called with its default boundary
            # mode; the profile is angular, so a periodic mode may be intended
            # — confirm.
            value_sel = stack_4d[i, j, k_ind[0], k_ind[1]]
            values = np.zeros(angle_sampling)
            values[angle_bins] = value_sel
            values = ndimage.gaussian_filter(values, sigma=2.0)

            # Row m holds values shifted by m angle bins, multiplied
            # element-wise by the unshifted profile.
            value_stack = values[shift_index]
            ang_corr = np.multiply(value_stack, values[np.newaxis, :])
            # The NaN mask discards everything below the diagonal, so each row
            # keeps only its non-wrapped products. (A separate np.triu is
            # unnecessary: those entries become NaN either way.)
            ang_corr = np.multiply(ang_corr, tril_mask)

            # Normalise by the squared mean of the raw ring intensities.
            # NOTE(review): a ring whose mean is zero divides by zero here.
            value_avgsq = np.mean(value_sel)**2
            ac_spectrum = np.nanmean(ang_corr, axis=1)
            ac_spectrum = (ac_spectrum / value_avgsq) - 1
            ac_fft = np.abs(np.fft.fft(ac_spectrum))
            
            temp_spectra.append(ac_spectrum)
            temp_fft_stack.append(ac_fft)
            
    # Back to the scan grid: (scan_row, scan_col, angle or frequency).
    temp_spectra = np.asarray(temp_spectra).reshape(f_shape[0], f_shape[1], -1)
    temp_fft_stack = np.asarray(temp_fft_stack).reshape(f_shape[0], f_shape[1], -1)
    ac_spectra.append(temp_spectra)
    ac_fft_stack.append(temp_fft_stack)
    print("%d radius completed"%(k))
    print("%d seconds have passed"%(time.process_time()-start_time))
print("all done")

In [None]:
# Turn the per-radius lists into arrays with the radius axis first, then
# report their shapes.
ac_spectra, ac_fft_stack = (np.asarray(per_radius) for per_radius in (ac_spectra, ac_fft_stack))
print(ac_spectra.shape)
print(ac_fft_stack.shape)

In [None]:
# Scan-averaged results, transposed so the radial index runs along x, then
# reversed along the vertical axis before display.
mean_spectra = np.mean(ac_spectra, axis=(1, 2)).T
# Only FFT components 1..10 are shown.
mean_fft = np.mean(ac_fft_stack, axis=(1, 2)).T[1:11]

fig, ax = plt.subplots(1, 2, figsize=(10, 10))
ax[0].imshow(mean_spectra[::-1], cmap="viridis",
             extent=[k_range[0], k_range[-1], angles[0]/10, angles[-1]/10])
ax[1].imshow(mean_fft[::-1], cmap="viridis",
             extent=[k_range[0], k_range[-1], 0.5, 10.5])
fig.tight_layout()
plt.show()

In [None]:
# Maps of selected FFT components of the angular correlation at one radial
# index, shown next to the variance map at the same index.
k_selected = 19
# Position of the selected radius inside k_range (k_ind is reused here as a
# plain integer).
k_ind = np.where(k_range==k_selected)[0][0]
print(k_ind)
rot_sym = [2, 3, 4, 5, 6, 10]

# Each component map is scaled by its own maximum.
ang_corr_rot = np.asarray([
    ac_fft_stack[k_ind, :, :, order]/np.max(ac_fft_stack[k_ind, :, :, order])
    for order in rot_sym
])
print(ang_corr_rot.shape)

# Common colour limits so the first panel is comparable across orders.
ang_max_val = np.max(ang_corr_rot)
ang_min_val = np.min(ang_corr_rot)

k_var_map = radial_var[:, :, k_selected]

for r, symmetry_map in zip(rot_sym, ang_corr_rot):
    fig, ax = plt.subplots(1, 3, figsize=(15, 5))
    # Left: component map on the shared colour scale.
    ax[0].imshow(symmetry_map, cmap="inferno", vmin=ang_min_val, vmax=ang_max_val)
    ax[0].set_title("rotation symmetry %d"%r)
    # Middle: variance map at the same radial index.
    ax[1].imshow(k_var_map, cmap="inferno")
    # Right: component map with variance contours on top.
    ax[2].imshow(symmetry_map, cmap="inferno", alpha=0.8)
    ax[2].contour(k_var_map, colors="k", alpha=1.0, levels=5)
    for panel_axis in ax:
        panel_axis.axis("off")
    fig.tight_layout()
    plt.show()