In [None]:
# J. Ryu, Electron Microscopy and Spectroscopy Lab., Seoul National University
import os
import time
import tkinter.filedialog as tkf

import hyperspy.api as hys
import ipywidgets as pyw
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.widgets import RectangleSelector
from scipy import optimize
from scipy import stats
from scipy.signal import periodogram
from skimage.metrics import mean_squared_error
from skimage.metrics import structural_similarity as ssim
from sklearn.cluster import DBSCAN
plt.rcParams['font.family'] = 'Times New Roman'

In [None]:
# refer to https://scipy-cookbook.readthedocs.io/items/FittingData.html

def gaussian(height, center_x, center_y, width_x, width_y):
    """Build a callable evaluating a 2D Gaussian with the given parameters.

    The returned function takes ``(x, y)`` (scalars or arrays) and returns
    ``height * exp(-(((center_x - x) / width_x)**2
                     + ((center_y - y) / width_y)**2) / 2)``.
    """
    sigma_x = float(width_x)
    sigma_y = float(width_y)

    def evaluate(x, y):
        dx = (center_x - x) / sigma_x
        dy = (center_y - y) / sigma_y
        return height * np.exp(-(dx**2 + dy**2) / 2)

    return evaluate

def moments(data):
    """Returns (height, x, y, width_x, width_y)
    the gaussian parameters of a 2D distribution by calculating its
    moments.

    ``x`` / ``width_x`` refer to the row axis (axis 0) and ``y`` / ``width_y``
    to the column axis (axis 1). The widths are the intensity-weighted RMS
    distances from the centroid, measured along the column and the row that
    pass through the (truncated) centroid position.
    """
    total = data.sum()
    X, Y = np.indices(data.shape) # row, col
    x = (X*data).sum()/total # row centroid
    y = (Y*data).sum()/total # col centroid
    # The profile data[:, int(y)] runs along the row axis, so its spread must
    # be measured around the row centroid x (the original used y here).
    col = data[:, int(y)]
    width_x = np.sqrt(np.abs((np.arange(col.size)-x)**2*col).sum()/col.sum()) # row
    # Likewise the profile data[int(x), :] runs along the column axis and is
    # centred on the column centroid y.
    row = data[int(x), :]
    width_y = np.sqrt(np.abs((np.arange(row.size)-y)**2*row).sum()/row.sum()) # col
    height = data.max()
    return height, x, y, width_x, width_y

def fitgaussian(data):
    """Least-squares fit of a 2D Gaussian to ``data``.

    The starting guess comes from ``moments(data)``; the residual between the
    model built by ``gaussian`` and ``data`` is minimised with
    ``scipy.optimize.leastsq``. Returns the fitted parameter vector
    (height, x, y, width_x, width_y); the solver status flag is discarded.
    """
    initial_guess = moments(data)
    grid = np.indices(data.shape)

    def residuals(candidate):
        model = gaussian(*candidate)(*grid)
        return np.ravel(model - data)

    fitted, _status = optimize.leastsq(residuals, initial_guess)
    return fitted

# refer to "github.com/mkolopanis/python/blob/master/radialProfile.py"

def radial_stats(image, center=None, var=True):
    """Radially averaged statistics of a 2D image around ``center``.

    Parameters
    ----------
    image : 2D array
    center : sequence or ndarray (row, col), optional
        Defaults to the geometric centre of the image. Lists, tuples and
        NumPy arrays are all accepted.
    var : bool
        If True also return variance profiles.

    Returns
    -------
    var=True : (r, radial_avg, radial_var, raw_var)
        ``r`` is the per-pixel radius map, ``radial_avg`` the mean intensity
        per integer-radius bin, ``raw_var`` = <I^2> - <I>^2 per bin and
        ``radial_var`` the same divided by <I>^2 (bins with <I>^2 == 0 are
        divided by 1 instead).
    var=False : (r, radial_avg)

    Notes
    -----
    Bins are formed from differences of cumulative sums taken at the last
    pixel of each rounded radius. As a consequence the smallest and the
    largest populated radius bins are not part of the returned profiles, so
    profile index 0 corresponds to the second-smallest rounded radius.
    """
    y, x = np.indices(image.shape)
    # `not center` is ambiguous for NumPy arrays (raises ValueError), so test
    # explicitly for "no centre given" while still treating an empty
    # sequence like the original truthiness check did.
    if center is None or np.size(center) == 0:
        center = np.array([(y.max()-y.min())/2.0, (x.max()-x.min())/2.0])

    r = np.hypot(y - center[0], x - center[1])

    # Sort pixels by radius
    ind = np.argsort(r.flat)
    r_sorted = r.flat[ind]
    i_sorted = image.flat[ind]

    # Integer radius of every pixel (bin size = 1)
    r_int = np.around(r_sorted)

    # Locate the boundaries between consecutive radius bins.
    deltar = r_int[1:] - r_int[:-1]  # Assumes all radii represented
    rind = np.where(deltar)[0]       # index of the last pixel of each bin
    nr = rind[1:] - rind[:-1]        # number of pixels in each bin

    # Per-bin sums via cumulative-sum differences
    csim = np.cumsum(i_sorted, dtype=float)
    sq_csim = np.cumsum(np.square(i_sorted), dtype=float)
    radial_avg = (csim[rind[1:]] - csim[rind[:-1]]) / nr

    if var:
        avg_square = np.square(radial_avg)
        square_avg = (sq_csim[rind[1:]] - sq_csim[rind[:-1]]) / nr
        # Avoid division by zero where the mean intensity vanishes
        mask = avg_square.copy()
        mask[np.where(avg_square==0)] = 1.0
        radial_var = (square_avg - avg_square) / mask
        return r, radial_avg, radial_var, (square_avg - avg_square)

    else:
        return r, radial_avg
    
def gaussian_center(image, cbox_edge=0):
    """Estimate the (row, col) centre of the central beam in ``image``.

    Parameters
    ----------
    image : 2D array
    cbox_edge : int
        Edge length of the central box used for the Gaussian fit. If 0
        (falsy) no fit is done and the geometric centre is returned.

    Returns
    -------
    ndarray ``[row, col]`` (geometric centre) when ``cbox_edge`` is falsy,
    otherwise a list ``[row, col]`` from ``fitgaussian`` on the central box,
    shifted back into full-image coordinates.
    """
    if not cbox_edge:
        y, x = np.indices(image.shape)
        center = np.array([(y.max()-y.min())/2.0, (x.max()-x.min())/2.0])

    else:
        n_rows, n_cols = image.shape[0], image.shape[1]
        # Margin between the image border and the fit box. Clamp at 0 so a
        # box larger than the image falls back to the full image.
        cbox_outy = max(int(n_rows/2 - cbox_edge/2), 0)
        cbox_outx = max(int(n_cols/2 - cbox_edge/2), 0)
        # Use explicit end indices: a slice `0:-0` would be empty.
        center_box = image[cbox_outy:n_rows-cbox_outy, cbox_outx:n_cols-cbox_outx]
        fit_params = fitgaussian(center_box)
        (_, center_y, center_x, _, _) = fit_params
        center = [center_y+cbox_outy, center_x+cbox_outx]

    return center

def remove_center_beam(image, center=None, cb_rad=0):
    """Zero every pixel whose rounded distance from ``center`` is <= ``cb_rad``.

    Parameters
    ----------
    image : 2D array
        Modified IN PLACE; the same array object is returned.
    center : sequence or ndarray (row, col), optional
        Defaults to the geometric centre of the image.
    cb_rad : number
        Radius (in pixels) of the disk to blank.
    """
    y, x = np.indices(image.shape)
    # Explicit None/empty check: `not center` raises for NumPy arrays.
    if center is None or np.size(center) == 0:
        center = np.array([(y.max()-y.min())/2.0, (x.max()-x.min())/2.0])

    r = np.hypot(y - center[0], x - center[1])
    r = np.around(r)
    ri = np.where(r<=cb_rad)

    image[ri] = 0

    return image

def load_binary_4D_stack(img_adr, datatype, original_shape, final_shape, log_scale=False):
    """Read a raw binary file as a 4D stack and crop it to ``final_shape``.

    The file is read with ``np.fromfile`` using ``datatype``, reshaped to
    ``original_shape`` and cropped from the origin along the four axes. With
    ``log_scale=True`` the natural logarithm of the cropped data is returned.
    The shape is printed before and after cropping.
    """
    raw = np.fromfile(img_adr, dtype=datatype).reshape(original_shape)
    print(raw.shape)

    crop = tuple(slice(None, final_shape[axis]) for axis in range(4))
    stack = raw[crop]
    if log_scale:
        stack = np.log(stack)

    print(stack.shape) 
    return stack

def fourd_roll_axis(stack):
    """Swap the two leading axes of a 4D stack with the two trailing ones.

    An input of shape (a, b, c, d) is returned as a view of shape
    (c, d, a, b).
    """
    return np.transpose(stack, (2, 3, 0, 1))

def radial_indices(shape, radial_range, center=None):
    """Binary (0/1 float) mask selecting an annulus or a single ring.

    Parameters
    ----------
    shape : tuple
        Shape of the 2D mask.
    radial_range : sequence
        If it holds more than one distinct value, pixels with
        ``radial_range[0] < r <= radial_range[1]`` are selected. Otherwise
        pixels whose rounded radius equals ``round(radial_range[0])``.
    center : sequence or ndarray (row, col), optional
        Defaults to the geometric centre.
    """
    y, x = np.indices(shape)
    # Explicit None/empty check: `not center` raises for NumPy arrays.
    if center is None or np.size(center) == 0:
        center = np.array([(y.max()-y.min())/2.0, (x.max()-x.min())/2.0])

    r = np.hypot(y - center[0], x - center[1])
    ri = np.ones(r.shape)

    if len(np.unique(radial_range)) > 1:
        # annulus: inner radius exclusive, outer radius inclusive
        ri[np.where(r <= radial_range[0])] = 0
        ri[np.where(r > radial_range[1])] = 0

    else:
        # single ring at the rounded radius
        r = np.round(r)
        ri[np.where(r != round(radial_range[0]))] = 0

    return ri

def indices_at_r(shape, radius, center=None):
    """Pixel indices on the ring of rounded radius ``radius``, sorted by angle.

    Parameters
    ----------
    shape : tuple
        Shape of the 2D grid.
    radius : number
        Ring radius; compared against the rounded pixel radius.
    center : sequence or ndarray (row, col), optional
        Defaults to the geometric centre.

    Returns
    -------
    r_sort : tuple (row_indices, col_indices)
        Ring pixels ordered by increasing angle.
    a_sort : ndarray
        The matching sorted angles in degrees, computed as
        ``atan2(dy, dx) + 180`` and rounded to integers.
    """
    y, x = np.indices(shape)
    # Explicit None/empty check: `not center` raises for NumPy arrays.
    if center is None or np.size(center) == 0:
        center = np.array([(y.max()-y.min())/2.0, (x.max()-x.min())/2.0])
    dy = y - center[0]
    dx = x - center[1]
    r = np.around(np.hypot(dy, dx))

    ri = np.where(r == radius)

    # Vectorised replacement of the per-pixel np.angle(complex(dx, dy)) loop.
    angle_arr = np.arctan2(dy, dx) * (180 / np.pi)
    angle_arr = np.around(angle_arr + 180)

    ai = np.argsort(angle_arr[ri])
    r_sort = (ri[0][ai], ri[1][ai])
    a_sort = np.sort(angle_arr[ri])

    return r_sort, a_sort

def round_step(arr, step_size=1):
    """Round ``arr`` to the nearest multiple of ``step_size``."""
    return np.around(arr / step_size) * step_size

def point_circle(image, radius, center=None):
    """Sample ``image`` on the ring of rounded radius ``radius``.

    Parameters
    ----------
    image : 2D array
    radius : number
        Ring radius; compared against the rounded pixel radius.
    center : sequence or ndarray (row, col), optional
        Defaults to the geometric centre.

    Returns
    -------
    angle_arr : 2D array
        Angle of every pixel in degrees, ``atan2(dy, dx) + 180`` (not rounded).
    angle_sel : 1D array
        Angles of the ring pixels, in ``np.where`` (row-major) order.
    value_sel : 1D array
        Image values at the ring pixels, same order.
    ri : tuple (row_indices, col_indices)
        Indices of the ring pixels.
    """
    y, x = np.indices(image.shape)
    y = y.astype("float64")
    x = x.astype("float64")

    # Explicit None/empty check: `not center` raises for NumPy arrays.
    if center is None or np.size(center) == 0:
        center = np.array([(y.max()-y.min())/2.0, (x.max()-x.min())/2.0])

    y -= center[0]
    x -= center[1]

    # Vectorised replacement of the per-pixel np.angle(complex(x, y)) loop.
    angle_arr = np.arctan2(y, x) * (180 / np.pi)
    angle_arr = angle_arr + 180

    r = np.around(np.hypot(y, x))
    ri = np.where(r == radius)
    angle_sel = angle_arr[ri]
    value_sel = image[ri]

    return angle_arr, angle_sel, value_sel, ri

# angle unit [degree]
def angular_correlation(angles, values):
    """Angular intensity correlation over all pixel pairs (angles in degrees).

    For every pair (i, j) with j >= i the absolute angle difference is
    rounded to an integer via ``round_step`` and the product
    ``values[i] * values[j]`` is recorded. For each integer angle difference
    present, the mean product minus ``mean(values)**2`` is stored at that
    index of a length-361 array (indices 0..360); other entries stay 0.

    Returns ``(normalised, raw)`` where ``normalised = raw / mean(values)**2``.
    """
    corr_val_total = np.zeros(np.arange(0, 361, 1).shape)
    int_sqavg = np.square(np.mean(values))

    # Upper-triangle pair enumeration: same (i ascending, j >= i) order as a
    # nested loop over i and the slice [i:].
    first, second = np.triu_indices(len(angles))
    angle_diff = round_step(np.abs(angles[second] - angles[first]))
    # Accumulate in float64, as a Python-list round trip would.
    corr_values = (values[second] * values[first]).astype(np.float64)

    for ang in np.unique(angle_diff):
        same_angle = angle_diff == ang
        corr_val_total[int(ang)] = np.mean(corr_values[same_angle]) - int_sqavg

    return (corr_val_total / int_sqavg), corr_val_total

# load data

In [None]:
# Ask for the input file through a Tk file dialog; the returned path is kept
# in `raw_adr` and used by the loading and saving cells.
raw_adr = tkf.askopenfilename()
print(raw_adr)

In [None]:
# Load the data (DM file) selected above and move the last two axes to the front
stack_4d = hys.load(raw_adr).data
print(stack_4d.shape)
stack_4d = fourd_roll_axis(stack_4d)
f_shape = stack_4d.shape
print(f_shape)

# Replace NaNs BEFORE normalising: np.max propagates NaN, so dividing first
# would turn every value into NaN and nan_to_num would then zero the stack.
if np.isnan(np.max(stack_4d)):
    print("NaN exists")
    stack_4d = np.nan_to_num(stack_4d)

# Normalise to a maximum of 1 and clip negative values to 0
stack_4d = stack_4d / np.max(stack_4d)
stack_4d = stack_4d.clip(min=0.0)
print(np.max(stack_4d))
print(np.min(stack_4d))
print(np.mean(stack_4d))

In [None]:
# Work on an independent copy (no cropping is applied here)
stack_4d_cropped = stack_4d.copy()
cr_shape = stack_4d_cropped.shape
print(cr_shape)
for summary in (np.max, np.min, np.mean):
    print(summary(stack_4d_cropped))

# image quick check

In [None]:
%matplotlib inline

In [None]:
# Maximum intensity over the last two axes, for every (axis 0, axis 1) position
max_int = np.max(stack_4d_cropped, axis=(2, 3))
print(max_int.shape)
for summary in (np.max, np.min, np.mean, np.median):
    print(summary(max_int))

# Left: map of the maxima; right: their histogram
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
map_ax, hist_ax = ax
map_ax.imshow(max_int, cmap="afmhot")
map_ax.axis("off")
hist_ax.hist(max_int.flatten(), bins=len(max_int))
hist_ax.grid()
plt.show()

In [None]:
# Total intensity over the last two axes, for every (axis 0, axis 1) position
tot_int = np.sum(stack_4d_cropped, axis=(2, 3))
# Report and bin using tot_int itself, so this cell does not depend on
# `max_int` from the maximum-intensity cell having been run.
print(tot_int.shape)
print(np.max(tot_int))
print(np.min(tot_int))
print(np.mean(tot_int))
print(np.median(tot_int))

# Left: map of the totals; right: their histogram
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].imshow(tot_int, cmap="afmhot")
ax[0].axis("off")
ax[1].hist(tot_int.flatten(), bins=len(tot_int))
ax[1].grid()
plt.show()

# center beam positions

In [None]:
%matplotlib inline

In [None]:
# Fit the centre of every pattern with a Gaussian on its central 8-pixel box
center_pos = np.asarray([
    gaussian_center(stack_4d_cropped[i, j], cbox_edge=8)
    for i in range(cr_shape[0])
    for j in range(cr_shape[1])
])
center_pos = center_pos.reshape(cr_shape[0], cr_shape[1], -1)
print(center_pos.shape)
center_mean = center_pos.mean(axis=(0, 1))
print(center_mean)

# Centre distribution: histograms (left) and scatter with the mean in red (right)
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
hist_ax, scatter_ax = ax
center_y_all = center_pos[:, :, 0]
center_x_all = center_pos[:, :, 1]

hist_ax.hist(center_y_all.flatten(), bins=100, density=True, color="orange", label="center y position")
hist_ax.hist(center_x_all.flatten(), bins=100, density=True, color="gray", alpha=0.5, label="center x position")
hist_ax.grid()
hist_ax.legend()

scatter_ax.scatter(center_x_all, center_y_all, s=10.0, alpha=0.5)
scatter_ax.grid()
scatter_ax.scatter(center_mean[1], center_mean[0], s=20, c="red")
scatter_ax.set_xlabel("center x position", fontsize=10)
scatter_ax.set_ylabel("center y position", fontsize=10)
fig.tight_layout()
plt.show()

# k-point FEM

In [None]:
%matplotlib qt

In [None]:
%matplotlib inline

In [None]:
# Variance map depending on k-vector: statistics are taken over axes (0, 1)
square_avg = np.square(stack_4d_cropped).mean(axis=(0, 1))
avg_square = np.square(stack_4d_cropped.mean(axis=(0, 1)))
# Divide by 1 wherever the squared mean is exactly zero
mask = avg_square.copy()
mask[avg_square == 0] = 1.0
var_nonor = square_avg - avg_square
var_map = var_nonor / mask

# Radial average profiles of the normalised and non-normalised variance maps
radial_map, kp_var = radial_stats(var_map, center=center_mean.tolist(), var=False)
print(kp_var.shape)
print(np.argmax(kp_var))
_, kp_nonor = radial_stats(var_nonor, center=center_mean.tolist(), var=False)
print(kp_nonor.shape)
print(np.argmax(kp_nonor))

fig, ax = plt.subplots(1, 5, figsize=(15, 3))
for panel, (picture, colormap) in zip(
        (ax[0], ax[1], ax[3]),
        ((radial_map, "gray"), (var_map, "afmhot"), (var_nonor, "afmhot"))):
    panel.imshow(picture, cmap=colormap)
    panel.axis("off")
for panel, profile in zip((ax[2], ax[4]), (kp_var, kp_nonor)):
    panel.plot(profile, "k-")
    panel.grid()
plt.show()

# k-radial FEM

In [None]:
# Normalised radial variance profile (over all angles at each k) per pattern
radial_var_stack = [
    radial_stats(stack_4d_cropped[i, j], center=center_mean.tolist(), var=True)[2]
    for i in range(cr_shape[0])
    for j in range(cr_shape[1])
]
len_profile = [len(profile) for profile in radial_var_stack]

# If the profiles differ in length, truncate all of them to the shortest one
if len(np.unique(len_profile)) > 1:
    print(np.unique(len_profile))
    shortest = np.min(len_profile)
    radial_var_stack = [profile[:shortest] for profile in radial_var_stack]

radial_var_stack = np.asarray(radial_var_stack).reshape(stack_4d_cropped.shape[0], stack_4d_cropped.shape[1], -1)
print(radial_var_stack.shape)

# Sum of the profiles over axes (0, 1)
radial_var_sum = radial_var_stack.sum(axis=(0, 1))
print(radial_var_sum.shape)

fig, ax = plt.subplots(1, 1, figsize=(10, 5))
ax.plot(radial_var_sum)
fig.tight_layout()
plt.show()

In [None]:
# save (radial variance, 3D); the mean centre position is stored in the metadata notes
radial_var_stack_save = hys.signals.Signal1D(radial_var_stack)
radial_var_stack_save.metadata.General.set_item("notes", center_mean.tolist())
# Strip the extension with splitext rather than a fixed [:-4] slice, which is
# only correct for three-letter extensions.
radial_var_stack_save.save(os.path.splitext(raw_adr)[0]+"_radial_var_radial.hdf5")

In [None]:
# Independent copies of the per-position radial variance stack and of its
# summed profile, under the names used by the cells below.
radial_var = radial_var_stack.copy()
radial_var_spectrum = radial_var_sum.copy()

In [None]:
def local_var_similarity(var_map, w_size, stride):
    """Sliding-window statistics of a 2D map.

    A ``w_size`` x ``w_size`` window is moved over ``var_map`` with step
    ``stride``. Each window is flattened and divided by its own maximum, then
    three numbers are recorded: its mean, its standard deviation, and the mean
    squared difference to the element at flat index ``int(w_size**2 / 2)``.

    Returns ``(surr_avg, surr_std, surr_dif, new_shape)`` where the three
    arrays have shape ``new_shape`` = (number of row steps, number of column
    steps).
    """
    var_map = np.asarray(var_map)
    row_starts = range(0, var_map.shape[0]-w_size+1, stride)
    col_starts = range(0, var_map.shape[1]-w_size+1, stride)
    new_shape = (len(row_starts), len(col_starts))
    mid = int(w_size**2/2)

    def window_summary(top, left):
        patch = var_map[top:top+w_size, left:left+w_size].flatten()
        patch = patch / np.max(patch)
        return np.mean(patch), np.std(patch), np.mean(np.square(patch - patch[mid]))

    summaries = [window_summary(top, left) for top in row_starts for left in col_starts]

    surr_avg = np.asarray([entry[0] for entry in summaries]).reshape(new_shape)
    surr_std = np.asarray([entry[1] for entry in summaries]).reshape(new_shape)
    surr_dif = np.asarray([entry[2] for entry in summaries]).reshape(new_shape)

    return surr_avg, surr_std, surr_dif, new_shape

def local_DP_similarity(f_stack, w_size, stride):
    """Sliding-window similarity of each window's patterns to its centre one.

    A ``w_size`` x ``w_size`` window is moved over the first two axes of
    ``f_stack`` with step ``stride``. Inside each window, every pattern and
    the reference pattern at ``[int(w_size/2), int(w_size/2)]`` are divided by
    their own maximum and compared with ``mean_squared_error`` and ``ssim``;
    the window's result is the mean of each metric over all its patterns
    (the reference itself included).

    Returns ``(dp_mse, dp_ssim, new_shape)``.

    NOTE(review): ``ssim`` is called without ``data_range`` on floating-point
    input - confirm this is accepted by the installed scikit-image version.
    """
    f_stack = np.asarray(f_stack)
    row_starts = range(0, f_stack.shape[0]-w_size+1, stride)
    col_starts = range(0, f_stack.shape[1]-w_size+1, stride)
    new_shape = (len(row_starts), len(col_starts))
    half = int(w_size/2)

    dp_mse = []
    dp_ssim = []
    for top in row_starts:
        for left in col_starts:
            window = f_stack[top:top+w_size, left:left+w_size]
            reference = window[half, half]
            # the normalised reference is the same for every comparison
            reference_norm = reference / np.max(reference)

            mse_values = []
            ssim_values = []
            for pattern in window.reshape(w_size**2, -1):
                pattern_norm = pattern / np.max(pattern)
                mse_values.append(mean_squared_error(reference_norm, pattern_norm))
                ssim_values.append(ssim(reference_norm, pattern_norm))

            dp_mse.append(np.mean(mse_values))
            dp_ssim.append(np.mean(ssim_values))

    dp_mse = np.asarray(dp_mse).reshape(new_shape)
    dp_ssim = np.asarray(dp_ssim).reshape(new_shape)

    return dp_mse, dp_ssim, new_shape

In [None]:
# Visual check of the ring at radius `tmp_radius` around the mean centre:
# the ring mask is drawn first and one pattern (log scale) is overlaid on it.
plt.figure(figsize=(10, 10))
tmp_radius = 22
# 0/1 mask of the pixels whose rounded radius equals tmp_radius
tmp = radial_indices(cr_shape[2:], [tmp_radius], center=center_mean.tolist())
plt.imshow(tmp)
# pattern at position (3, 3), drawn almost opaque over the mask
plt.imshow(np.log(stack_4d_cropped[3, 3]), alpha=0.9)
# (row, col) indices of the same ring, sorted by angle; only needed by the
# scatter overlay that is commented out below
tmp_ind, _ = indices_at_r(cr_shape[2:], tmp_radius, center=center_mean.tolist())
#plt.scatter(tmp_ind[1], tmp_ind[0], alpha=0.9)
plt.show()

In [None]:
# Parameters for the local-similarity cells below
win_size = 3     # edge length of the sliding window
stride = 1       # step of the sliding window
k_selected = 22  # index into the radial variance profile / ring radius passed to indices_at_r

In [None]:
# Map of the radial variance at the selected k index
k_var_map = radial_var[:, :, k_selected].copy()
#k_var_map = k_var_map.clip(max=np.percentile(k_var_map, 99))

local_avg, local_std, local_dif, bin_shape = local_var_similarity(k_var_map, win_size, stride)
print(local_avg.shape)
print(local_std.shape)
print(local_dif.shape)

# Region of k_var_map covered by window centres. Explicit end indices are
# used because a slice ending in `-0` (win_size == 1) would be empty.
edge_lo = int((win_size-1)/2)
edge_hi = int(win_size/2)
valid_region = (slice(edge_lo, k_var_map.shape[0]-edge_hi),
                slice(edge_lo, k_var_map.shape[1]-edge_hi))

mask = np.zeros(k_var_map.shape)
mask[valid_region] = 1

# with stride == 1 these two shapes should agree
print(mask[valid_region].shape)
print(bin_shape)

fig, ax = plt.subplots(1, 4, figsize=(7, 15))
ax[0].imshow(k_var_map[valid_region], cmap="afmhot")
#ax[0].imshow(mask, cmap="gray", alpha=0.5)
ax[0].axis("off")
ax[1].imshow(local_avg, cmap="viridis")
ax[1].axis("off")
# difference and spread are negated so that "more similar" appears brighter
ax[2].imshow(-local_dif, cmap="viridis")
ax[2].axis("off")
ax[3].imshow(-local_std, cmap="viridis")
ax[3].axis("off")
fig.tight_layout()
plt.show()

In [None]:
# Ring pixels at radius k_selected, sorted by angle: k_ind = (row indices, col indices)
k_ind, a_ind = indices_at_r(cr_shape[2:], k_selected, center_mean.tolist())

# Intensity along the ring for every position. indices_at_r returns
# (row, col) for the pattern axes (2, 3), so rows index axis 2 and columns
# axis 3; the swapped order would sample the ring mirrored about the diagonal.
f_flat = stack_4d_cropped[:, :, k_ind[0], k_ind[1]]
print(f_flat.shape)

dp_mse, dp_ssim, bin_shape = local_DP_similarity(f_flat, win_size, stride)
print(dp_mse.shape)
print(dp_ssim.shape)

mask = np.zeros(f_flat.shape[:2])
mask[int((win_size-1)/2):-int(win_size/2), int((win_size-1)/2):-int(win_size/2)] = 1

fig, ax = plt.subplots(1, 2, figsize=(5, 15))
# MSE is negated so that "more similar" appears brighter in both panels
ax[0].imshow(-dp_mse, cmap="viridis")
ax[0].axis("off")
ax[1].imshow(dp_ssim, cmap="viridis")
ax[1].axis("off")
fig.tight_layout()
plt.show()

In [None]:
%matplotlib inline

In [None]:
%matplotlib qt

In [None]:
# Clip every map at mean +/- th_sigma * std so only the extreme side keeps contrast
th_sigma = 0.5
high_var = k_var_map.clip(min=(np.mean(k_var_map)+th_sigma*np.std(k_var_map)))[int((win_size-1)/2):-int(win_size/2), int((win_size-1)/2):-int(win_size/2)]
# use th_sigma here as well (was a hard-coded 0.5 that ignored the setting)
low_dif = local_dif.clip(max=(np.mean(local_dif)-th_sigma*np.std(local_dif)))
low_std = local_std.clip(max=(np.mean(local_std)-th_sigma*np.std(local_std)))
high_ssim = dp_ssim.clip(min=(np.mean(dp_ssim)+th_sigma*np.std(dp_ssim)))
low_mse = dp_mse.clip(max=(np.mean(dp_mse)-th_sigma*np.std(dp_mse)))

# "low_*" maps are negated so that the selected regions appear bright
fig, ax = plt.subplots(1, 5, figsize=(5, 15))
for panel, picture in zip(ax, (high_var, -low_dif, -low_std, high_ssim, -low_mse)):
    panel.imshow(picture, cmap="viridis")
    panel.axis("off")
fig.tight_layout()
plt.show()

In [None]:
# Clip every map at mean +/- th_sigma * std so only the extreme side keeps contrast
th_sigma = 0.5
high_var = k_var_map.clip(min=(np.mean(k_var_map)+th_sigma*np.std(k_var_map)))[int((win_size-1)/2):-int(win_size/2), int((win_size-1)/2):-int(win_size/2)]
# use th_sigma here as well (was a hard-coded 0.5 that ignored the setting)
low_dif = local_dif.clip(max=(np.mean(local_dif)-th_sigma*np.std(local_dif)))
low_std = local_std.clip(max=(np.mean(local_std)-th_sigma*np.std(local_std)))
high_ssim = dp_ssim.clip(min=(np.mean(dp_ssim)+th_sigma*np.std(dp_ssim)))
low_mse = dp_mse.clip(max=(np.mean(dp_mse)-th_sigma*np.std(dp_mse)))

# "low_*" maps are negated so that the selected regions appear bright
fig, ax = plt.subplots(1, 5, figsize=(5, 15))
for panel, picture in zip(ax, (high_var, -low_dif, -low_std, high_ssim, -low_mse)):
    panel.imshow(picture, cmap="viridis")
    panel.axis("off")
fig.tight_layout()
plt.show()

In [None]:
# Ring-intensity profiles of the local_size x local_size neighbourhood
# centred on position (83, 2)
local_size = win_size
#local_size = 3
row_, col_ = 83-int(local_size/2), 2-int(local_size/2)

selected_region = f_flat[row_:row_+local_size, col_:col_+local_size]
print(selected_region.shape)
selected_region = selected_region.reshape(local_size**2, -1)

# one panel per neighbour, in row-major order
fig, ax = plt.subplots(local_size, local_size, figsize=(20, 20))
for axs, profile in zip(ax.flat, selected_region):
    axs.plot(profile)
    axs.grid()
    
plt.show()

In [None]:
# Ring-intensity profiles of the local_size x local_size neighbourhood
# centred on position (46, 3)
local_size = win_size
#local_size = 3
row_, col_ = 46-int(local_size/2), 3-int(local_size/2)

selected_region = f_flat[row_:row_+local_size, col_:col_+local_size]
print(selected_region.shape)
selected_region = selected_region.reshape(local_size**2, -1)

# one panel per neighbour, in row-major order
fig, ax = plt.subplots(local_size, local_size, figsize=(20, 20))
for axs, profile in zip(ax.flat, selected_region):
    axs.plot(profile)
    axs.grid()
    
plt.show()