In [None]:
# J. Ryu, Electron Microscopy and Spectroscopy Lab., Seoul National University
import os
import tkinter.filedialog as tkf

import hyperspy.api as hys
import ipywidgets as pyw
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.widgets import RectangleSelector
from scipy import optimize
from sklearn.cluster import DBSCAN
from tabulate import tabulate

In [None]:
# The 2D Gaussian fitting helpers below follow the SciPy Cookbook recipe:
# https://scipy-cookbook.readthedocs.io/items/FittingData.html

def gaussian(height, center_x, center_y, width_x, width_y):
    """Build a 2D Gaussian evaluator from its amplitude, centre and widths.

    The returned callable takes coordinates ``(x, y)`` (scalars or arrays)
    and evaluates
    ``height * exp(-(((center_x - x) / width_x)**2
                     + ((center_y - y) / width_y)**2) / 2)``.
    Both widths are converted with ``float()`` before use.
    """
    sigma_x = float(width_x)
    sigma_y = float(width_y)

    def evaluate(x, y):
        # Squared, width-normalised offsets from the centre along each axis.
        term_x = ((center_x - x) / sigma_x) ** 2
        term_y = ((center_y - y) / sigma_y) ** 2
        return height * np.exp(-(term_x + term_y) / 2)

    return evaluate

def moments(data):
    """Estimate 2D Gaussian parameters of ``data`` from its moments.

    Returns ``(height, x, y, width_x, width_y)`` where ``x`` / ``width_x``
    refer to the row axis (axis 0) and ``y`` / ``width_y`` to the column
    axis (axis 1):

    * ``x``, ``y``  -- intensity-weighted centroid,
    * ``width_x``   -- spread of the column through the centroid, measured
      about the row centroid ``x``,
    * ``width_y``   -- spread of the row through the centroid, measured
      about the column centroid ``y``,
    * ``height``    -- maximum of ``data``.
    """
    total = data.sum()
    rows, cols = np.indices(data.shape)
    x = (rows * data).sum() / total  # row centroid
    y = (cols * data).sum() / total  # column centroid

    # The column through the centroid is indexed by row, so its second moment
    # must be taken about the row centroid ``x`` (not ``y``).
    col = data[:, int(y)]
    width_x = np.sqrt(np.abs((np.arange(col.size) - x) ** 2 * col).sum() / col.sum())

    # Likewise the row through the centroid is indexed by column, so its
    # second moment is taken about the column centroid ``y``.
    row = data[int(x), :]
    width_y = np.sqrt(np.abs((np.arange(row.size) - y) ** 2 * row).sum() / row.sum())

    height = data.max()
    return height, x, y, width_x, width_y

def fitgaussian(data):
    """Least-squares fit of a 2D Gaussian to ``data``.

    The starting point comes from ``moments(data)``; the residual between
    the model built by ``gaussian`` and ``data`` is minimised with
    ``scipy.optimize.leastsq``. Returns the fitted parameter array in the
    order ``(height, x, y, width_x, width_y)``.
    """
    initial_guess = moments(data)
    grid = np.indices(data.shape)

    def residuals(p):
        # Flattened difference between the model evaluated on the index grid
        # and the measured data.
        return np.ravel(gaussian(*p)(*grid) - data)

    best_fit, _ = optimize.leastsq(residuals, initial_guess)
    return best_fit

In [None]:
def gaussian_center(image, cbox_edge=0):
    """Locate the centre ``(row, col)`` of a 2D image.

    Parameters
    ----------
    image : 2D array
        Image to analyse.
    cbox_edge : int, optional
        Edge length of the central box used for the Gaussian fit. When it is
        0 (or otherwise falsy) no fit is done and the geometric centre of the
        pixel grid, ``(shape - 1) / 2``, is returned.

    Returns
    -------
    numpy.ndarray or list
        ``[center_row, center_col]`` in image coordinates: an array for the
        geometric centre, a list for the fitted centre.
    """
    n_rows, n_cols = image.shape[0], image.shape[1]
    if not cbox_edge:
        return np.array([(n_rows - 1) / 2.0, (n_cols - 1) / 2.0])

    # Margin left outside the central box on each side. It is clamped at 0 so
    # that a box as large as (or larger than) the image selects the whole
    # image instead of an empty slice.
    cbox_outy = max(int(n_rows / 2 - cbox_edge / 2), 0)
    cbox_outx = max(int(n_cols / 2 - cbox_edge / 2), 0)

    # Explicit end indices: a ``-0`` stop would yield an empty array.
    center_box = image[cbox_outy:n_rows - cbox_outy, cbox_outx:n_cols - cbox_outx]

    _, center_y, center_x, _, _ = fitgaussian(center_box)
    # Shift the fitted position from box coordinates back to image coordinates.
    return [center_y + cbox_outy, center_x + cbox_outx]

In [None]:
def remove_center_beam(image, center=None, cb_rad=0):
    """Zero every pixel within ``cb_rad`` of ``center``.

    Distances are rounded to the nearest integer before being compared with
    ``cb_rad``. ``image`` is modified in place and also returned.

    Parameters
    ----------
    image : 2D array
        Image to mask (modified in place).
    center : sequence or numpy.ndarray of two floats, optional
        ``(row, col)`` of the disc centre. ``None`` or an empty sequence
        selects the geometric centre of the pixel grid.
    cb_rad : number, optional
        Radius of the disc that is set to 0.
    """
    y, x = np.indices(image.shape)
    # ``center is None`` instead of ``not center``: the truth value of a
    # multi-element numpy array is ambiguous and raises ValueError.
    if center is None or np.size(center) == 0:
        center = np.array([(y.max() - y.min()) / 2.0, (x.max() - x.min()) / 2.0])

    r = np.around(np.hypot(y - center[0], x - center[1]))
    image[r <= cb_rad] = 0

    return image

In [None]:
def load_binary_4D_stack(img_adr, datatype, original_shape, final_shape, log_scale=False):
    """Read a flat binary file as a 4D stack and crop it to ``final_shape``.

    The file is read with ``np.fromfile`` as ``datatype``, reshaped to
    ``original_shape`` and cut down to the leading ``final_shape[k]`` entries
    along each of the four axes. With ``log_scale=True`` the natural
    logarithm of the cropped data is returned instead. The array shape is
    printed before and after cropping.
    """
    raw = np.fromfile(img_adr, dtype=datatype).reshape(original_shape)
    print(raw.shape)

    # One ``:n`` slice per axis, taken from the first four entries of final_shape.
    window = tuple(slice(None, final_shape[axis]) for axis in range(4))
    cropped = raw[window]
    if log_scale:
        cropped = np.log(cropped)

    print(cropped.shape)
    return cropped

In [None]:
def fourd_roll_axis(stack):
    """Return a view of ``stack`` with axes 2 and 3 moved to positions 0 and 1.

    The remaining axes keep their relative order, so a 4D input with axes
    ``(0, 1, 2, 3)`` comes back ordered as ``(2, 3, 0, 1)``.
    """
    return np.moveaxis(stack, (2, 3), (0, 1))

In [None]:
# Pick the input data file through a Tk file-open dialog. The chosen path is
# reused by the loading cells and by the cell that saves the result.
# NOTE(review): presumably an empty string comes back if the dialog is
# cancelled -- confirm before relying on raw_adr downstream.
raw_adr = tkf.askopenfilename()
print(raw_adr)

In [None]:
# Load a data (RAW): read the binary stack, crop it, and normalise it.
datatype = "float32"
# o_shape is the shape stored in the file, f_shape the shape kept after cropping.
# Alternative size kept for reference:
#   o_shape = (128, 128, 130, 128)
#   f_shape = (128, 128, 128, 128)
o_shape = (256, 256, 130, 128)
f_shape = (256, 256, 128, 128)

stack_4d = load_binary_4D_stack(
    raw_adr, datatype, o_shape, f_shape, log_scale=False
)
# Normalise absolutely: divide by the global maximum of the whole stack.
stack_4d = stack_4d / np.max(stack_4d)
# Max, min and mean of the normalised stack, one value per line.
print(np.max(stack_4d), np.min(stack_4d), np.mean(stack_4d), sep="\n")

In [None]:
# Load a data (DM): read the file with HyperSpy and take its data array.
stack_4d = hys.load(raw_adr).data
print(stack_4d.shape)

# Move axes 2 and 3 to the front so that later cells can sum/crop with the
# same axis convention as the RAW loading path.
stack_4d = fourd_roll_axis(stack_4d)
f_shape = stack_4d.shape
print(f_shape)

# Normalise by the global maximum of the whole stack.
stack_4d = stack_4d / np.max(stack_4d)
# Max, min and mean of the normalised stack, one value per line.
print(np.max(stack_4d), np.min(stack_4d), np.mean(stack_4d), sep="\n")

In [None]:
# rotate
# Rotate the stack by 180 degrees (two 90-degree steps) in the plane of its
# first two axes and refresh the stored shape.
stack_4d = np.rot90(stack_4d, 2, (0, 1))
f_shape = stack_4d.shape
# f_shape is already a plain tuple, so print it directly (a tuple has no
# ``.shape`` attribute).
print(f_shape)

In [None]:
# select an interesting area
%matplotlib qt
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
ax.imshow(np.sum(stack_4d, axis=(2, 3)), cmap="gray")

def onselect(eclick, erelease):
    print('startposition: (%f, %f)' % (eclick.xdata, eclick.ydata))
    print('endposition  : (%f, %f)' % (erelease.xdata, erelease.ydata))

box = RectangleSelector(ax, onselect)
plt.show()

In [None]:
# crop the data
# ``box.corners`` holds the corner coordinates of the drawn rectangle; entry 1
# is used for the row range (axis 0) and entry 0 for the column range (axis 1).
corner_x, corner_y = box.corners
crop_rows = slice(int(corner_y[0]), int(corner_y[2]))
crop_cols = slice(int(corner_x[0]), int(corner_x[1]))

# Copy so that later in-place edits do not touch the full stack.
stack_4d_cropped = stack_4d[crop_rows, crop_cols].copy()
cr_shape = stack_4d_cropped.shape
print(cr_shape)
# Max, min and mean of the cropped stack, one value per line.
print(np.max(stack_4d_cropped), np.min(stack_4d_cropped), np.mean(stack_4d_cropped), sep="\n")

In [None]:
# Switch matplotlib back from the Qt backend to inline rendering for the
# figures drawn in the following cells.
%matplotlib inline

In [None]:
# maximum intensity distribution
# Largest value over axes 2 and 3 for every position of the first two axes.
max_int = stack_4d_cropped.max(axis=(2, 3))
print(max_int.shape)
# Max, min, mean and median of the map, one value per line.
print(np.max(max_int), np.min(max_int), np.mean(max_int), np.median(max_int), sep="\n")

# Left: the map itself. Right: histogram of its values, with as many bins as
# the map has rows.
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].imshow(max_int, cmap="afmhot")
ax[0].axis("off")
ax[1].hist(max_int.ravel(), bins=max_int.shape[0])
ax[1].grid()
plt.show()

In [None]:
# total intensity distribution
# Sum over axes 2 and 3 for every position of the first two axes.
tot_int = np.sum(stack_4d_cropped, axis=(2, 3))
# Report and bin by tot_int itself so this cell does not depend on the
# max_int variable created by the previous cell.
print(tot_int.shape)
print(np.max(tot_int))
print(np.min(tot_int))
print(np.mean(tot_int))
print(np.median(tot_int))

# Left: the map itself. Right: histogram of its values, with as many bins as
# the map has rows.
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].imshow(tot_int, cmap="afmhot")
ax[0].axis("off")
ax[1].hist(tot_int.flatten(), bins=len(tot_int))
ax[1].grid()
plt.show()

In [None]:
# find center position
# Fit the centre of every pattern of the cropped stack (central box edge 30),
# visiting positions row by row.
center_pos = []
for i, j in np.ndindex(cr_shape[0], cr_shape[1]):
    center_pos.append(gaussian_center(stack_4d_cropped[i, j], cbox_edge=30))

# Back to the scan layout: (rows, cols, [center_y, center_x]).
center_pos = np.asarray(center_pos).reshape(cr_shape[0], cr_shape[1], -1)
print(center_pos.shape)
center_mean = center_pos.mean(axis=(0, 1))
print(center_mean)

# center distribution
fig, ax = plt.subplots(1, 2, figsize=(10, 5))

# Left panel: normalised histograms of the fitted y and x positions.
ax[0].hist(center_pos[..., 0].ravel(), bins=100, density=True,
           color="orange", label="center y position")
ax[0].hist(center_pos[..., 1].ravel(), bins=100, density=True,
           color="gray", alpha=0.5, label="center x position")
ax[0].grid()
ax[0].legend()

# Right panel: every fitted centre as a point, with the mean centre in red.
ax[1].scatter(center_pos[..., 1], center_pos[..., 0], s=10.0, alpha=0.5)
ax[1].grid()
ax[1].scatter(center_mean[1], center_mean[0], s=20, c="red")
ax[1].set_xlabel("center x position", fontsize=10)
ax[1].set_ylabel("center y position", fontsize=10)
fig.tight_layout()
plt.show()

In [None]:
# remove center beam
# Zero a disc of radius ``center_radius`` around the mean fitted centre in
# every pattern. This overwrites stack_4d_cropped in place.
center_radius = 10
for i, j in np.ndindex(cr_shape[0], cr_shape[1]):
    stack_4d_cropped[i, j] = remove_center_beam(
        stack_4d_cropped[i, j],
        center=center_mean.tolist(),
        cb_rad=center_radius,
    )
# Max, min and mean of the masked stack, one value per line.
print(np.max(stack_4d_cropped), np.min(stack_4d_cropped), np.mean(stack_4d_cropped), sep="\n")

In [None]:
def radial_indices(shape, radial_range, center=None):
    """Build a 0/1 mask selecting pixels by their distance from ``center``.

    Parameters
    ----------
    shape : tuple of two ints
        Shape of the mask.
    radial_range : sequence of two numbers
        ``[inner, outer]``. If the two values differ, pixels with
        ``inner < r <= outer`` are kept. If they are equal, pixels whose
        distance rounds to ``round(inner)`` are kept.
    center : sequence or numpy.ndarray of two floats, optional
        ``(row, col)`` of the centre. ``None`` or an empty sequence selects
        the geometric centre of the pixel grid.

    Returns
    -------
    numpy.ndarray
        Float array of ``shape`` holding 1 for selected pixels and 0 elsewhere.
    """
    y, x = np.indices(shape)
    # ``center is None`` instead of ``not center``: the truth value of a
    # multi-element numpy array is ambiguous and raises ValueError.
    if center is None or np.size(center) == 0:
        center = np.array([(y.max() - y.min()) / 2.0, (x.max() - x.min()) / 2.0])

    r = np.hypot(y - center[0], x - center[1])
    ri = np.ones(r.shape)

    if len(np.unique(radial_range)) > 1:
        # Annulus: drop everything at or inside the inner radius and
        # everything beyond the outer radius.
        ri[(r <= radial_range[0]) | (r > radial_range[1])] = 0
    else:
        # Single radius: keep the one-pixel ring whose rounded distance matches.
        ri[np.round(r) != round(radial_range[0])] = 0

    return ri

In [None]:
def indices_at_r(shape, radius, center=None):
    """Return the pixels on a ring of given radius, ordered by azimuthal angle.

    Parameters
    ----------
    shape : tuple of two ints
        Shape of the image grid.
    radius : number
        Pixels whose distance from ``center``, rounded to the nearest
        integer, equals ``radius`` are selected.
    center : sequence or numpy.ndarray of two floats, optional
        ``(row, col)`` of the centre. ``None`` or an empty sequence selects
        the geometric centre of the pixel grid.

    Returns
    -------
    r_sort : tuple of two index arrays
        ``(rows, cols)`` of the ring pixels, sorted by angle; usable directly
        as a fancy index into an image of ``shape``.
    a_sort : numpy.ndarray
        The matching angles in degrees, rounded to integers and shifted by
        +180 so that they lie in ``[0, 360]``, in ascending order.
    """
    y, x = np.indices(shape)
    # ``center is None`` instead of ``not center``: the truth value of a
    # multi-element numpy array is ambiguous and raises ValueError.
    if center is None or np.size(center) == 0:
        center = np.array([(y.max() - y.min()) / 2.0, (x.max() - x.min()) / 2.0])

    dy = y - center[0]
    dx = x - center[1]
    ri = np.where(np.around(np.hypot(dy, dx)) == radius)

    # Angle of every pixel about the centre, computed for the whole grid at
    # once (argument of dx + i*dy in degrees), then shifted and rounded.
    angle_arr = np.around(np.angle(dx + 1j * dy, deg=True) + 180)

    ai = np.argsort(angle_arr[ri])
    r_sort = (ri[0][ai], ri[1][ai])
    a_sort = np.sort(angle_arr[ri])

    return r_sort, a_sort

In [None]:
cr_shape = stack_4d_cropped.shape
rs_img = np.sum(stack_4d_cropped, axis=(2, 3))
print(cr_shape)

%matplotlib qt
def flat_k(yp, xp, ki, ko, show):
    """Show one pattern of the cropped stack together with its ring profiles.

    For the position ``(yp, xp)`` a figure with three columns is drawn:
    the summed image ``rs_img`` with the position marked, the log of the
    selected pattern overlaid with the radial band mask, and one
    intensity-versus-angle line plot per integer radius ``ki, ..., ko - 1``.

    Reads the notebook globals ``cr_shape``, ``rs_img``, ``stack_4d_cropped``
    and ``center_mean``.

    Parameters
    ----------
    yp, xp : int
        Row and column of the position in the cropped stack.
    ki, ko : int
        Inner and outer radius; ``ki`` must be smaller than ``ko``.
    show : bool
        When False nothing is drawn.

    Returns
    -------
    None
    """
    if yp < 0 or yp >= cr_shape[0] or xp < 0 or xp >= cr_shape[1]:
        print("Error in selecting the position of a pixel")
        return

    if ki >= ko:
        # The radii ki .. ko-1 form an empty range unless ki < ko.
        print("Inner k must be smaller than outer k")
        return

    if not show:
        return

    num_k = ko - ki
    k_range = np.arange(ki, ko)
    center = center_mean.tolist()
    pattern = stack_4d_cropped[yp, xp]
    # Take the pattern shape from the cropped stack itself (same source as the
    # ring profiles below) rather than from a shape stored by an earlier cell.
    k_donut = radial_indices(cr_shape[2:], [k_range[0], k_range[-1]], center=center)

    figi = plt.figure(figsize=(15, 10))
    # num_k rows; three column groups of width num_k each.
    G = gridspec.GridSpec(num_k, num_k*3)
    ax1 = figi.add_subplot(G[:num_k, :num_k])
    ax2 = figi.add_subplot(G[:num_k, num_k:num_k*2])

    ax1.imshow(rs_img, cmap="gray")
    ax1.scatter(xp, yp, marker="s", facecolor="none", edgecolors="aqua", linewidths=1.0)
    ax1.axis("off")

    ax2.imshow(np.log(pattern), cmap="afmhot")
    ax2.imshow(k_donut, cmap="gray", alpha=0.20)
    ax2.axis("off")

    # One angular profile per radius, stacked in the third column group.
    for i in range(num_k):
        ax_tmp = figi.add_subplot(G[i, num_k*2:num_k*3])
        k_ind, a_ind = indices_at_r(cr_shape[2:], radius=k_range[i], center=center)
        k_line = pattern[k_ind]
        ax_tmp.plot(a_ind, k_line, "k-", linewidth=1)
        ax_tmp.set_title("radial flat at k = %d"%k_range[i])
        ax_tmp.grid()
        ax_tmp.set_yticks([])
        ax_tmp.tick_params(axis="x", labelsize=5)

    figi.tight_layout()

# Input widgets for flat_k; every numeric box shares the same style dict.
st = {"description_width": "initial"}
y_pos_widget = pyw.IntText(description="y position", value=0, style=st)
x_pos_widget = pyw.IntText(description="x position", value=0, style=st)
ki_widget = pyw.IntText(description="inner k", value=20, style=st)
ko_widget = pyw.IntText(description="outer k", value=30, style=st)
show_box = pyw.Checkbox(description="create a new figure", value=True)

# Re-run flat_k whenever one of the widget values changes.
pyw.interact(
    flat_k,
    yp=y_pos_widget,
    xp=x_pos_widget,
    ki=ki_widget,
    ko=ko_widget,
    show=show_box,
)
plt.show()

In [None]:
# Radii to flatten. To take them from the viewer widgets instead, use:
#   fk_range = np.arange(ki_widget.value, ko_widget.value)
fk_range = [16]

n_pixels = cr_shape[0] * cr_shape[1]

# Collect one (n_pixels, n_ring_points) block per radius and join them once at
# the end. The zero-width float64 seed replaces the former dummy zero column:
# it adds no data but keeps the same dtype promotion of the joined array.
flat_blocks = [np.zeros((n_pixels, 0))]

for k in fk_range:
    k_ind, a_ind = indices_at_r(cr_shape[2:], radius=k, center=center_mean.tolist())
    # Fancy-index the ring pixels of every pattern at once; rows are ordered
    # position by position (first axis slowest), as in a nested i/j loop.
    tmp = stack_4d_cropped[:, :, k_ind[0], k_ind[1]].reshape(n_pixels, len(k_ind[0]))
    print(tmp.shape)
    flat_blocks.append(tmp)
    print("k = %d completed"%(k))

k_flat_4d = np.concatenate(flat_blocks, axis=1).reshape(cr_shape[0], cr_shape[1], -1)
print(k_flat_4d.shape)

In [None]:
# Wrap the flattened ring profiles in a HyperSpy 1D signal and save it next to
# the input file, tagged with the first and last radius.
k_flat = hys.signals.Signal1D(k_flat_4d)
# Strip the extension whatever its length instead of assuming 3 characters.
base_adr = os.path.splitext(raw_adr)[0]
k_flat.save(base_adr + "_k_flat_%d_%d.hdf5"%(fk_range[0], fk_range[-1]))

In [None]:
# Plot the flattened signal using the Qt backend.
%matplotlib qt
k_flat.plot()