In [None]:
# J. Ryu, Electron Microscopy and Spectroscopy Lab., Seoul National University
import numpy as np
from scipy import optimize
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.widgets import RectangleSelector
from matplotlib.colors import ListedColormap
import tkinter.filedialog as tkf
import hyperspy.api as hys
import ipywidgets as pyw

In [None]:
# Inline (static) figures for the set-up cells; later cells switch to qt.
%matplotlib inline

In [None]:
def load_binary_4D_stack(img_adr, datatype, original_shape, final_shape, log_scale=False):
    """Read a raw binary 4D stack from disk and crop it.

    Parameters
    ----------
    img_adr : str or path-like
        File read with ``np.fromfile``.
    datatype : numpy dtype
        Element type stored in the file.
    original_shape : tuple of 4 ints
        Shape the flat file contents are reshaped to.
    final_shape : sequence of (at least) 4 ints
        Number of leading elements kept along each of the four axes.
    log_scale : bool, optional
        If True, the natural log of the cropped stack is returned.

    Returns
    -------
    numpy.ndarray
        Cropped (optionally log-scaled) stack. The reshaped shape and the
        final shape are printed as a side effect.
    """
    raw = np.fromfile(img_adr, dtype=datatype).reshape(original_shape)
    print(raw.shape)

    # keep [:n] along each of the four axes
    crop = tuple(slice(None, final_shape[axis]) for axis in range(4))
    cropped = raw[crop]
    if log_scale:
        cropped = np.log(cropped)

    print(cropped.shape)
    return cropped

In [None]:
# refer to https://scipy-cookbook.readthedocs.io/items/FittingData.html

def gaussian(height, center_x, center_y, width_x, width_y):
    """Build an axis-aligned 2D Gaussian.

    Returns a callable ``f(x, y)`` that evaluates
    ``height * exp(-(((center_x - x)/width_x)**2 + ((center_y - y)/width_y)**2) / 2)``.
    ``x`` and ``y`` may be scalars or arrays (e.g. the grids from ``np.indices``).
    """
    sx = float(width_x)
    sy = float(width_y)

    def _evaluate(x, y):
        dx = (center_x - x) / sx
        dy = (center_y - y) / sy
        return height * np.exp(-(dx ** 2 + dy ** 2) / 2)

    return _evaluate

def moments(data):
    """Estimate 2D Gaussian parameters of ``data`` from its moments.

    Returns (height, x, y, width_x, width_y). ``x`` and ``width_x`` refer to
    the row axis (axis 0); ``y`` and ``width_y`` refer to the column axis
    (axis 1). This matches the ``np.indices`` ordering. The result is used
    as the starting point of the least-squares fit in ``fitgaussian``.
    """
    total = data.sum()
    X, Y = np.indices(data.shape)  # row, col index grids
    x = (X * data).sum() / total  # row centroid
    y = (Y * data).sum() / total  # column centroid

    # Row-axis width: take the column through the centroid. It runs along
    # the row axis, so deviations are measured from the row centroid x.
    col = data[:, int(y)]
    width_x = np.sqrt(np.abs((np.arange(col.size) - x) ** 2 * col).sum() / col.sum())

    # Column-axis width: take the row through the centroid and measure
    # deviations from the column centroid y.
    row = data[int(x), :]
    width_y = np.sqrt(np.abs((np.arange(row.size) - y) ** 2 * row).sum() / row.sum())

    height = data.max()
    return height, x, y, width_x, width_y

def fitgaussian(data):
    """Least-squares fit of a 2D Gaussian to ``data``.

    The starting parameters come from ``moments(data)`` and the model from
    ``gaussian``. Returns the optimized (height, x, y, width_x, width_y)
    array from ``scipy.optimize.leastsq``, with x = row and y = column.
    The integer status flag returned by ``leastsq`` is discarded.
    """
    initial = moments(data)
    grid = np.indices(data.shape)

    def residuals(params):
        return np.ravel(gaussian(*params)(*grid) - data)

    best, _status = optimize.leastsq(residuals, initial)
    return best

In [None]:
def gaussian_center(image, cbox_edge=0):
    """Locate the centre of a 2D image, optionally by a Gaussian fit.

    Parameters
    ----------
    image : 2D array
    cbox_edge : int, optional
        Edge length (pixels) of a box centred on the image in which a 2D
        Gaussian is fitted with ``fitgaussian``. If 0 (or falsy), no fit is
        done and the geometric centre ((rows-1)/2, (cols-1)/2) is returned.
        A box larger than the image falls back to the whole image.

    Returns
    -------
    numpy.ndarray of shape (2,)
        (row, column) centre in full-image pixel coordinates.
    """
    n_rows, n_cols = image.shape
    if not cbox_edge:
        return np.array([(n_rows - 1) / 2.0, (n_cols - 1) / 2.0])

    # Offsets of the fit box; clamp at 0 so an oversized box never produces
    # negative (wrap-around) slice bounds.
    off_r = max(int(n_rows / 2 - cbox_edge / 2), 0)
    off_c = max(int(n_cols / 2 - cbox_edge / 2), 0)
    # Explicit stop indices: ``[off:-off]`` would be empty when off == 0.
    center_box = image[off_r:n_rows - off_r, off_c:n_cols - off_c]

    # fitgaussian returns (height, row, col, width_row, width_col)
    _, center_r, center_c, _, _ = fitgaussian(center_box)
    return np.array([center_r + off_r, center_c + off_c])

In [None]:
# Ask for the data file; raw_adr holds the selected path.
raw_adr = tkf.askopenfilename()
# The dialog returns an empty value when cancelled. Stop here instead of
# failing with a confusing error in the load cells below.
if not raw_adr:
    raise ValueError("No data file selected")
print(raw_adr)

In [None]:
# Raw-binary layout option A: 256 x 256 scan, 130 x 128 frames on disk,
# cropped to 128 x 128 frames by load_binary_4D_stack.
# NOTE(review): the next cell unconditionally overwrites these three values,
# so under Run All this cell has no effect. Run only the one matching the data.
datatype = np.float32
o_shape = (256, 256, 130, 128)
f_shape = (256, 256, 128, 128)

In [None]:
# Raw-binary layout option B: 128 x 128 scan, 130 x 128 frames on disk,
# cropped to 128 x 128 frames by load_binary_4D_stack.
# NOTE(review): this overrides the values set by the previous cell.
datatype = np.float32
o_shape = (128, 128, 130, 128)
f_shape = (128, 128, 128, 128)

In [None]:
# Load the raw binary 4D stack, report its intensity range and normalise it
# so that the maximum is 1.
f_stack = load_binary_4D_stack(raw_adr, datatype, o_shape, f_shape, log_scale=False)
print(f_stack.max())
print(f_stack.min())
print(f_stack.mean())
f_stack = f_stack / f_stack.max()
f_shape = f_stack.shape

In [None]:
# Load through HyperSpy instead; this replaces any stack loaded by the
# raw-binary cell above.
fourd = hys.load(raw_adr)
print(fourd)
# The division already allocates a new array, so fourd.data is left intact
# without an explicit .copy().
f_stack = fourd.data / np.max(fourd.data)
print(f_stack.shape)
f_shape = f_stack.shape
print(f_shape)

In [None]:
# HyperSpy's interactive viewer of the loaded signal; uses the qt GUI backend.
%matplotlib qt
fourd.plot()

In [None]:
# select an interesting area
%matplotlib qt
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
ax.imshow(np.sum(f_stack, axis=(2, 3)), cmap="gray")

def onselect(eclick, erelease):
    print('startposition: (%f, %f)' % (eclick.xdata, eclick.ydata))
    print('endposition  : (%f, %f)' % (erelease.xdata, erelease.ydata))

box = RectangleSelector(ax, onselect)
plt.show()

In [None]:
# Crop the scan axes to the rectangle selected in the previous cell.
# RectangleSelector.extents is (xmin, xmax, ymin, ymax) in data coordinates,
# already sorted. x runs along image columns (axis 1), y along rows (axis 0).
# NOTE: re-running this cell crops the already-cropped stack again.
x_min, x_max, y_min, y_max = (int(v) for v in box.extents)
if x_max <= x_min or y_max <= y_min:
    raise ValueError("Empty selection: draw a rectangle in the previous cell first")
f_stack = f_stack[y_min:y_max, x_min:x_max].copy()

f_shape = f_stack.shape
print(f_stack.shape)

In [None]:
# Trim `magni` pixels from every edge of each frame (axes 2 and 3).
magni = 300
n_det_y, n_det_x = f_stack.shape[2:]
# A slice like `magni:-magni` silently gives an empty stack when magni is
# too large for the frame, or when magni == 0; fail loudly instead.
if magni < 0 or 2 * magni >= min(n_det_y, n_det_x):
    raise ValueError("magni=%d is invalid for a %dx%d detector frame" % (magni, n_det_y, n_det_x))
f_stack = f_stack[:, :, magni:n_det_y - magni, magni:n_det_x - magni]
f_shape = f_stack.shape
print(f_shape)

In [None]:
# Mean pattern over the scan axes (0, 1): the position-averaged CBED (PACBED).
pacbed = f_stack.mean(axis=(0, 1))
print(pacbed.shape)

In [None]:
# Back to inline figures for the static centre-finding plots below.
%matplotlib inline

In [None]:
# Find the centre position: fit the central spot (30-pixel box) of every
# pattern, giving one (row, column) centre per scan position.
center_pos = np.asarray([
    gaussian_center(f_stack[row, col], cbox_edge=30)
    for row in range(f_shape[0])
    for col in range(f_shape[1])
])
center_pos = center_pos.reshape(f_shape[0], f_shape[1], -1)
print(center_pos.shape)
center_mean = center_pos.mean(axis=(0, 1))
print(center_mean)

# center distribution
plt.figure()
plt.hist(center_pos[..., 0].ravel(), bins=100, density=True, color="orange", label="center y position")
plt.hist(center_pos[..., 1].ravel(), bins=100, density=True, color="gray", alpha=0.5, label="center x position")
plt.grid()
plt.legend()
plt.show()

plt.figure()
plt.scatter(center_pos[..., 1], center_pos[..., 0], s=10.0, alpha=0.5)
plt.grid()
plt.scatter(center_mean[1], center_mean[0], s=20, c="red")
plt.xlabel("center x position", fontsize=20)
plt.ylabel("center y position", fontsize=20)
plt.show()

In [None]:
# Use the mean fitted centre (row, column) as the reference centre `ct`.
ct = center_mean.tolist()

In [None]:
# Alternative: use the geometric centre ((rows-1)/2, (cols-1)/2) of the frame.
# NOTE(review): this cell always overwrites the fitted `ct` from the previous
# cell, so under Run All the Gaussian centre fit is discarded. Run only the
# one you intend.
y, x = np.indices(f_shape[2:])
ct = np.array([(y.max()-y.min())/2.0, (x.max()-x.min())/2.0]).tolist()

In [None]:
def fourd_roll_axis(stack):
    """Move the last two axes of a 4D array to the front.

    An array of shape (a, b, c, d) is returned as a view of shape
    (c, d, a, b), so ``result[k, l, i, j] == stack[i, j, k, l]``.
    """
    return np.transpose(stack, (2, 3, 0, 1))

In [None]:
def max_rad(shape, center=None):
    """Largest pixel distance from ``center`` to any point of a 2D grid.

    Parameters
    ----------
    shape : tuple of 2 ints
        Grid shape (rows, columns).
    center : sequence of 2 floats, optional
        (row, column) reference point. Defaults to the geometric centre
        ((rows-1)/2, (cols-1)/2). A numpy array is accepted too.

    Returns
    -------
    float
        Maximum Euclidean distance in pixels.
    """
    y, x = np.indices(shape)
    # `is None` rather than `not center`: a numpy-array centre has no
    # single truth value and would raise.
    if center is None:
        # (row, column) order, matching how center[0]/center[1] are used below.
        center = np.array([(y.max() - y.min()) / 2.0, (x.max() - x.min()) / 2.0])

    r = np.hypot(y - center[0], x - center[1])
    return np.max(r)

In [None]:
def radial_indices(shape, radial_range, scale, center=None):
    """Annular 0/1 mask on a 2D grid.

    Parameters
    ----------
    shape : tuple of 2 ints
        Grid shape (rows, columns).
    radial_range : sequence of 2 floats
        (inner, outer) limits in scaled units (pixel distance * ``scale``).
        Pixels with ``inner < r <= outer`` are selected. If both values are
        equal, the mask instead selects pixels whose rounded scaled radius
        equals ``round(inner)``.
    scale : float
        Factor converting pixel distances to the units of ``radial_range``
        (e.g. mrad per pixel in this notebook).
    center : sequence of 2 floats, optional
        (row, column) centre; defaults to the geometric grid centre. A numpy
        array is accepted too.

    Returns
    -------
    numpy.ndarray of float
        1.0 inside the selection, 0.0 elsewhere.
    """
    y, x = np.indices(shape)
    # `is None` so that numpy-array centres (e.g. a fitted centre) work.
    if center is None:
        center = np.array([(y.max() - y.min()) / 2.0, (x.max() - x.min()) / 2.0])

    r = np.hypot(y - center[0], x - center[1]) * scale
    ri = np.ones(r.shape)

    if len(np.unique(radial_range)) > 1:
        ri[(r <= radial_range[0]) | (r > radial_range[1])] = 0
    else:
        ri[np.round(r) != round(radial_range[0])] = 0

    return ri

In [None]:
def angle_indices(shape, angle_range, center=None):
    """Angular-sector 0/1 mask on a 2D grid.

    The angle of each pixel is ``degrees(atan2(dy, dx)) + 180``, wrapped
    into [0, 360). Here dy = row - center[0] and dx = column - center[1].

    Parameters
    ----------
    shape : tuple of 2 ints
        Grid shape (rows, columns).
    angle_range : sequence of 2 floats
        (start, stop) in degrees.
        - start < stop selects ``start <= angle < stop``.
        - start > stop selects the sector wrapping through 0/360
          (``angle >= start or angle < stop``).
        - start == stop selects pixels whose rounded angle equals round(start).
    center : sequence of 2 floats, optional
        (row, column) centre; defaults to the geometric grid centre. A numpy
        array is accepted too.

    Returns
    -------
    numpy.ndarray of float
        1.0 inside the sector, 0.0 elsewhere.
    """
    y, x = np.indices(shape)
    y = y.astype("float64")
    x = x.astype("float64")

    if center is None:
        center = np.array([(y.max() - y.min()) / 2.0, (x.max() - x.min()) / 2.0])

    y -= center[0]
    x -= center[1]

    # Vectorised equivalent of np.angle(complex(x, y), deg=True) per pixel.
    angle_arr = np.arctan2(y, x) * (180 / np.pi) + 180
    # atan2 can return exactly +180 deg, which maps to 360. That value would
    # fall outside every half-open [start, stop) sector; 360 and 0 are the
    # same direction.
    angle_arr = np.mod(angle_arr, 360.0)

    start, stop = angle_range[0], angle_range[1]
    ri = np.ones(angle_arr.shape)

    if stop > start:
        ri[(angle_arr < start) | (angle_arr >= stop)] = 0
    elif start > stop:
        # sector wraps through 0/360
        ri[:] = 0
        ri[(angle_arr < stop) | (angle_arr >= start)] = 1
    else:
        ri[np.round(angle_arr) != round(start)] = 0

    return ri

In [None]:
# Detector calibration in mrad per detector pixel (value given for this data).
mrad_per_pixel = 0.27045890769499215
# Integer pixel radii up to the largest distance from `ct`, converted to mrad.
# radii[-1] is therefore the last integer radius below that maximum.
radii = mrad_per_pixel * np.arange(max_rad(f_stack.shape[2:], center=ct))
print(f"maximum angle = {radii[-1]:.2f}")

In [None]:
%matplotlib qt
fig, ax = plt.subplots(3, 3, figsize=(10, 10))

def virtual_detector(det1_in, det1_out, det2_in, det2_out, det3_in, det3_out):
    for i in range(3):
        for j in range(3):
            ax[i][j].cla()
    
    if det1_in > det1_out:
        det1 = [det1_out, det1_in]
        print("Warning! detector 1 (inner angle > outer angle)")
    else:
        det1 = [det1_in, det1_out]
        
    if det2_in > det2_out:
        det2 = [det2_out, det2_in]
        print("Warning! detector 2 (inner angle > outer angle)")
    else:
        det2 = [det2_in, det2_out]    
        
    if det3_in > det3_out:
        det3 = [det3_out, det3_in]
        print("Warning! detector 3 (inner angle > outer angle)")
    else:
        det3 = [det3_in, det3_out]
    
    det = []
    det.append(det1)
    det.append(det2)
    det.append(det3)
    
    ri = []
    for i in range(3):
        ri.append(radial_indices(f_shape[2:], det[i], mrad_per_pixel, center=ct))
    
    for i in range(3):
        ax[i][0].imshow(ri[i], cmap="gray")
        ax[i][0].axis("off")
        ax[i][0].set_title("(%.2f mrad, %.2f mrad)"%(det[i][0], det[i][1]), fontsize=10)
        
    for i in range(3):
        ax[i][1].imshow(pacbed, cmap="gray")
        ax[i][1].imshow(ri[i], cmap="Reds", alpha=0.3)
        ax[i][1].axis("off")
        
    for i in range(3):
        img_temp = np.sum(np.multiply(f_stack, ri[i]), axis=(2, 3))
        ax[i][2].imshow(img_temp, cmap="afmhot")
        ax[i][2].axis("off")
    
    fig.canvas.draw()
    fig.tight_layout()    

st = {"description_width": "initial"}
d1in = pyw.FloatText(value=1.0, description="D1 inner angle: ", style=st)
d1out = pyw.FloatText(value=5.0, description="D1 outer angle: ", style=st)

d2in = pyw.FloatText(value=5.0, description="D2 inner angle: ", style=st)
d2out = pyw.FloatText(value=10.0, description="D2 outer angle: ", style=st)

d3in = pyw.FloatText(value=10.0, description="D3 inner angle: ", style=st)
d3out = pyw.FloatText(value=15.0, description="D3 outer angle: ", style=st)

pyw.interact(virtual_detector, det1_in=d1in, det1_out=d1out, det2_in=d2in, det2_out=d2out, det3_in=d3in, det3_out=d3out)
fig.show()

In [None]:
# not a virtual segmented detector

%matplotlib qt
fig, ax = plt.subplots(3, 3, figsize=(10, 10))

cm = matplotlib.cm.jet
norm = matplotlib.colors.Normalize()
#norm.autoscale(np.arange(f_shape[2]*f_shape[3]) / (f_shape[2]*f_shape[3]))

sm = matplotlib.cm.ScalarMappable(cmap=cm, norm=norm)
sm.set_array([])

def DPC(inner, outer):
    for i in range(3):
        for j in range(3):
            ax[i][j].cla()
        
    ri = radial_indices(f_shape[2:], [inner, outer], mrad_per_pixel, center=ct)
    
    ax[0][0].imshow(pacbed, cmap="jet")
    ax[0][0].axis("off")
    
    ax[0][1].imshow(pacbed, cmap="gray")
    ax[0][1].imshow(ri, cmap="Reds", alpha=0.3)
    ax[0][1].axis("off")
    
    
    selected = np.multiply(f_stack, ri)
    STEM_img = np.sum(selected, axis=(2, 3))
    ax[0][2].imshow(STEM_img, cmap="gray")
    ax[0][2].axis("off")
    
    
    Y, X = np.indices(f_shape[2:])
    y_vector = np.sum(selected * Y, axis=(2, 3)) / np.sum(selected, axis=(2, 3)) - ct[0]
    x_vector = np.sum(selected * X, axis=(2, 3)) / np.sum(selected, axis=(2, 3)) - ct[1]
    
    vector_length = np.hypot(y_vector, x_vector)
    
    y_vector_flat = y_vector.flatten()
    x_vector_flat = x_vector.flatten()
    dummy = np.zeros(f_shape[:2])
    
    y, x = np.indices(f_shape[:2])
    y = np.flip(y.flatten())
    x = x.flatten()
    
    ax[1][0].imshow(y_vector, cmap="gray")
    ax[1][0].axis("off")
    
    ax[1][1].imshow(x_vector, cmap="gray")
    ax[1][1].axis("off")
    
    ax[1][2].imshow(vector_length, cmap="gray")
    ax[1][2].axis("off")
    
    vector_length_flat = vector_length.flatten()    
    #norm_length = vector_length_flat / np.max(vector_length_flat)
    
    ax[2][0].quiver(x, y, -x_vector_flat, -y_vector_flat, color=cm(norm(vector_length_flat)))
    ax[2][0].axis("off")
    
    gy = np.gradient(y_vector)[0]
    gx = np.gradient(x_vector)[1]
    div_vector = -(gy+gx)
    #print(div_vector)
    
    gradient = ax[2][1].imshow(div_vector, vmin=-np.amax(np.abs(div_vector)),vmax=np.amax(np.abs(div_vector)), cmap="seismic")
    ax[2][1].axis("off")

    fCX = np.fft.fftshift(np.fft.fft2(np.fft.fftshift(x_vector)))
    fCY = np.fft.fftshift(np.fft.fft2(np.fft.fftshift(y_vector)))
    KX = fCX.shape[1]
    KY = fCY.shape[0]
    kxran = np.linspace(-1, 1, KX, endpoint=True)
    kyran = np.linspace(-1, 1, KY, endpoint=True)
    kx, ky = np.meshgrid(kxran, kyran)
    fCKX = fCX * kx
    fCKY = fCY * ky
    fnum = (fCKX + fCKY)
    hpass, lpass = 0.005, 0.0
    fdenom = np.pi * 2 * (0 + 1j) * (hpass + (kx ** 2 + ky ** 2) + lpass * (kx ** 2 + ky ** 2) ** 2)
    fK = np.divide(fnum, fdenom)

    ax[2][2].imshow(np.real(np.fft.ifftshift(np.fft.ifft2(np.fft.ifftshift(fK)))), cmap="jet")
    ax[2][2].axis("off")
    
    
    fig.canvas.draw()
    fig.tight_layout()

st = {"description_width": "initial"}
din = pyw.FloatText(value=2.0, description="inner angle: ", style=st)
dout = pyw.FloatText(value=20.0, description="outer angle: ", style=st)

pyw.interact(DPC, inner=din, outer=dout)
fig.show()

In [None]:
# segmented detector 1

%matplotlib qt
fig, ax = plt.subplots(3, 3, figsize=(10, 10))

cm = matplotlib.cm.jet
norm = matplotlib.colors.Normalize()
#norm.autoscale(np.arange(f_shape[2]*f_shape[3]) / (f_shape[2]*f_shape[3]))

sm = matplotlib.cm.ScalarMappable(cmap=cm, norm=norm)
sm.set_array([])

def DPC(inner, outer, Ai, Af, Bi, Bf, Ci, Cf, Di, Df):
    for i in range(3):
        for j in range(3):
            ax[i][j].cla()
        
    ri = radial_indices(f_shape[2:], [inner, outer], mrad_per_pixel, center=ct)
    rA = angle_indices(f_shape[2:], [Ai, Af], center=ct)
    rB = angle_indices(f_shape[2:], [Bi, Bf], center=ct)
    rC = angle_indices(f_shape[2:], [Ci, Cf], center=ct)
    rD = angle_indices(f_shape[2:], [Di, Df], center=ct)
    
    As = np.multiply(ri, rA)
    Bs = np.multiply(ri, rB)
    Cs = np.multiply(ri, rC)
    Ds = np.multiply(ri, rD)
    
    ax[0][0].imshow(pacbed, cmap="jet")
    ax[0][0].axis("off")
    
    ax[0][1].imshow(pacbed, cmap="gray", alpha=0.7)
    ax[0][1].imshow(As, cmap=ListedColormap(["white", "blue"]), alpha=0.1)
    ax[0][1].imshow(Bs, cmap=ListedColormap(["white", "green"]), alpha=0.1)
    ax[0][1].imshow(Cs, cmap=ListedColormap(["white", "orange"]), alpha=0.1)
    ax[0][1].imshow(Ds, cmap=ListedColormap(["white", "red"]), alpha=0.1)
    ax[0][1].axis("off")
    
    STEM_img = np.sum(np.multiply(f_stack, ri), axis=(2, 3))
    ax[0][2].imshow(STEM_img, cmap="gray")
    ax[0][2].axis("off")
    
    Asum = np.sum(np.multiply(f_stack, As), axis=(2, 3))
    Bsum = np.sum(np.multiply(f_stack, Bs), axis=(2, 3))
    Csum = np.sum(np.multiply(f_stack, Cs), axis=(2, 3))
    Dsum = np.sum(np.multiply(f_stack, Ds), axis=(2, 3))
    
    y_vector = (Asum+Bsum)-(Csum+Dsum)
    x_vector = (Bsum+Csum)-(Asum+Dsum)
    vector_length = np.hypot(y_vector, x_vector)
    
    y_vector_flat = y_vector.flatten()
    x_vector_flat = x_vector.flatten()
    dummy = np.zeros(f_shape[:2])
    
    y, x = np.indices(f_shape[:2])
    y = np.flip(y.flatten())
    x = x.flatten()
    
    ax[1][0].imshow(y_vector, cmap="gray")
    ax[1][0].axis("off")
    
    ax[1][1].imshow(x_vector, cmap="gray")
    ax[1][1].axis("off")
    
    ax[1][2].imshow(vector_length, cmap="gray")
    ax[1][2].axis("off")
    
    vector_length_flat = vector_length.flatten()    
    #norm_length = vector_length_flat / np.max(vector_length_flat)
    
    ax[2][0].quiver(x, y, -x_vector_flat, -y_vector_flat, color=cm(norm(vector_length_flat)))
    ax[2][0].axis("off")
    
    gy = np.gradient(y_vector)[0]
    gx = np.gradient(x_vector)[1]
    div_vector = -(gy+gx)
    #print(div_vector)
    
    gradient = ax[2][1].imshow(div_vector, vmin=-np.amax(np.abs(div_vector)),vmax=np.amax(np.abs(div_vector)), cmap="seismic")
    ax[2][1].axis("off")

    fCX = np.fft.fftshift(np.fft.fft2(np.fft.fftshift(x_vector)))
    fCY = np.fft.fftshift(np.fft.fft2(np.fft.fftshift(y_vector)))
    KX = fCX.shape[1]
    KY = fCY.shape[0]
    kxran = np.linspace(-1, 1, KX, endpoint=True)
    kyran = np.linspace(-1, 1, KY, endpoint=True)
    kx, ky = np.meshgrid(kxran, kyran)
    fCKX = fCX * kx
    fCKY = fCY * ky
    fnum = (fCKX + fCKY)
    hpass, lpass = 0.005, 0.0
    fdenom = np.pi * 2 * (0 + 1j) * (hpass + (kx ** 2 + ky ** 2) + lpass * (kx ** 2 + ky ** 2) ** 2)
    fK = np.divide(fnum, fdenom)

    ax[2][2].imshow(np.real(np.fft.ifftshift(np.fft.ifft2(np.fft.ifftshift(fK)))), cmap="jet")
    ax[2][2].axis("off")
    
    
    fig.canvas.draw()
    fig.tight_layout()

st = {"description_width": "initial"}
din = pyw.FloatText(value=5.0, description="inner angle: ", style=st)
dout = pyw.FloatText(value=25.0, description="outer angle: ", style=st)

A_i = pyw.FloatText(value=0.0, description="A initial angle: ", style=st)
A_f = pyw.FloatText(value=90.0, description="A final angle: ", style=st)

B_i = pyw.FloatText(value=90.0, description="B initial angle: ", style=st)
B_f = pyw.FloatText(value=180.0, description="B final angle: ", style=st)

C_i = pyw.FloatText(value=180.0, description="C initial angle: ", style=st)
C_f = pyw.FloatText(value=270.0, description="C final angle: ", style=st)

D_i = pyw.FloatText(value=270.0, description="D initial angle: ", style=st)
D_f = pyw.FloatText(value=360.0, description="D final angle: ", style=st)

pyw.interact(DPC, inner=din, outer=dout, Ai=A_i, Af=A_f, Bi=B_i, Bf=B_f, Ci=C_i, Cf=C_f, Di=D_i, Df=D_f)
fig.show()

In [None]:
# segmented detector 2

%matplotlib qt
fig, ax = plt.subplots(3, 3, figsize=(10, 10))

cm = matplotlib.cm.jet
norm = matplotlib.colors.Normalize()
#norm.autoscale(np.arange(f_shape[2]*f_shape[3]) / (f_shape[2]*f_shape[3]))

sm = matplotlib.cm.ScalarMappable(cmap=cm, norm=norm)
sm.set_array([])

def DPC(inner, outer, Ai, Af, Bi, Bf, Ci, Cf, Di, Df):
    for i in range(3):
        for j in range(3):
            ax[i][j].cla()
        
    ri = radial_indices(f_shape[2:], [inner, outer], mrad_per_pixel, center=ct)
    rA = angle_indices(f_shape[2:], [Ai, Af], center=ct)
    rB = angle_indices(f_shape[2:], [Bi, Bf], center=ct)
    rC = angle_indices(f_shape[2:], [Ci, Cf], center=ct)
    rD = angle_indices(f_shape[2:], [Di, Df], center=ct)
    
    As = np.multiply(ri, rA)
    Bs = np.multiply(ri, rB)
    Cs = np.multiply(ri, rC)
    Ds = np.multiply(ri, rD)
    
    ax[0][0].imshow(pacbed, cmap="jet")
    ax[0][0].axis("off")
    
    ax[0][1].imshow(pacbed, cmap="gray", alpha=0.7)
    ax[0][1].imshow(As, cmap=ListedColormap(["white", "blue"]), alpha=0.1)
    ax[0][1].imshow(Bs, cmap=ListedColormap(["white", "green"]), alpha=0.1)
    ax[0][1].imshow(Cs, cmap=ListedColormap(["white", "orange"]), alpha=0.1)
    ax[0][1].imshow(Ds, cmap=ListedColormap(["white", "red"]), alpha=0.1)
    ax[0][1].axis("off")
    
    STEM_img = np.sum(np.multiply(f_stack, ri), axis=(2, 3))
    ax[0][2].imshow(STEM_img, cmap="gray")
    ax[0][2].axis("off")
    
    Asum = np.sum(np.multiply(f_stack, As), axis=(2, 3))
    Bsum = np.sum(np.multiply(f_stack, Bs), axis=(2, 3))
    Csum = np.sum(np.multiply(f_stack, Cs), axis=(2, 3))
    Dsum = np.sum(np.multiply(f_stack, Ds), axis=(2, 3))
    
    y_vector = Asum-Csum
    x_vector = Bsum-Dsum
    vector_length = np.hypot(y_vector, x_vector)
    
    y_vector_flat = y_vector.flatten()
    x_vector_flat = x_vector.flatten()
    dummy = np.zeros(f_shape[:2])
    
    y, x = np.indices(f_shape[:2])
    y = np.flip(y.flatten())
    x = x.flatten()
    
    ax[1][0].imshow(y_vector, cmap="gray")
    ax[1][0].axis("off")
    
    ax[1][1].imshow(x_vector, cmap="gray")
    ax[1][1].axis("off")
    
    ax[1][2].imshow(vector_length, cmap="gray")
    ax[1][2].axis("off")
    
    vector_length_flat = vector_length.flatten()    
    #norm_length = vector_length_flat / np.max(vector_length_flat)
    
    ax[2][0].quiver(x, y, -x_vector_flat, -y_vector_flat, color=cm(norm(vector_length_flat)))
    ax[2][0].axis("off")
    
    gy = np.gradient(y_vector)[0]
    gx = np.gradient(x_vector)[1]
    div_vector = -(gy+gx)
    #print(div_vector)
    
    gradient = ax[2][1].imshow(div_vector, vmin=-np.amax(np.abs(div_vector)),vmax=np.amax(np.abs(div_vector)), cmap="seismic")
    ax[2][1].axis("off")

    fCX = np.fft.fftshift(np.fft.fft2(np.fft.fftshift(x_vector)))
    fCY = np.fft.fftshift(np.fft.fft2(np.fft.fftshift(y_vector)))
    KX = fCX.shape[1]
    KY = fCY.shape[0]
    kxran = np.linspace(-1, 1, KX, endpoint=True)
    kyran = np.linspace(-1, 1, KY, endpoint=True)
    kx, ky = np.meshgrid(kxran, kyran)
    fCKX = fCX * kx
    fCKY = fCY * ky
    fnum = (fCKX + fCKY)
    hpass, lpass = 0.005, 0.0
    fdenom = np.pi * 2 * (0 + 1j) * (hpass + (kx ** 2 + ky ** 2) + lpass * (kx ** 2 + ky ** 2) ** 2)
    fK = np.divide(fnum, fdenom)

    ax[2][2].imshow(np.real(np.fft.ifftshift(np.fft.ifft2(np.fft.ifftshift(fK)))), cmap="jet")
    ax[2][2].axis("off")
    
    
    fig.canvas.draw()
    fig.tight_layout()
    
st = {"description_width": "initial"}    
din = pyw.FloatText(value=5.0, description="inner angle: ", style=st)
dout = pyw.FloatText(value=25.0, description="outer angle: ", style=st)

A_i = pyw.FloatText(value=45.0, description="A initial angle: ", style=st)
A_f = pyw.FloatText(value=135.0, description="A final angle: ", style=st)

B_i = pyw.FloatText(value=135.0, description="B initial angle: ", style=st)
B_f = pyw.FloatText(value=225.0, description="B final angle: ", style=st)

C_i = pyw.FloatText(value=225.0, description="C initial angle: ", style=st)
C_f = pyw.FloatText(value=315.0, description="C final angle: ", style=st)

D_i = pyw.FloatText(value=315.0, description="D initial angle: ", style=st)
D_f = pyw.FloatText(value=45.0, description="D final angle: ", style=st)

pyw.interact(DPC, inner=din, outer=dout, Ai=A_i, Af=A_f, Bi=B_i, Bf=B_f, Ci=C_i, Cf=C_f, Di=D_i, Df=D_f)
fig.show()

In [None]:
%matplotlib qt
mrad_per_pixel = 0.2
fig, ax = plt.subplots(3, 3, figsize=(10, 10))
radii = np.arange(max_rad(f_stack.shape[2:])) * mrad_per_pixel
print("maximum angle = %.2f"%(radii[-1]))

def virtual_detector(det1_in, det1_out, det2_in, det2_out, det3_in, det3_out, ypos, xpos):
    for i in range(3):
        for j in range(3):
            ax[i][j].cla()
    
    if det1_in > det1_out:
        det1 = [det1_out, det1_in]
        print("Warning! detector 1 (inner angle > outer angle)")
    else:
        det1 = [det1_in, det1_out]
        
    if det2_in > det2_out:
        det2 = [det2_out, det2_in]
        print("Warning! detector 2 (inner angle > outer angle)")
    else:
        det2 = [det2_in, det2_out]    
        
    if det3_in > det3_out:
        det3 = [det3_out, det3_in]
        print("Warning! detector 3 (inner angle > outer angle)")
    else:
        det3 = [det3_in, det3_out]
    
    det = []
    det.append(det1)
    det.append(det2)
    det.append(det3)
    
    ri = []
    for i in range(3):
        ri.append(radial_indices(f_shape[2:], det[i], mrad_per_pixel, center=ct))
    
    for i in range(3):
        ax[i][0].imshow(ri[i], cmap="gray")
        ax[i][0].axis("off")
        ax[i][0].set_title("(%.2f mrad, %.2f mrad)"%(det[i][0]*mrad_per_pixel, det[i][1]*mrad_per_pixel), fontsize=20)
        
    for i in range(3):
        ax[i][1].imshow(f_stack[ypos, xpos], cmap="gray")
        ax[i][1].imshow(ri[i], cmap="Reds", alpha=0.3)
        ax[i][1].set_title("(%d, %d)"%(xpos, ypos), fontsize=20)
        ax[i][1].axis("off")
        
    for i in range(3):
        img_temp = np.sum(np.multiply(f_stack, ri[i]), axis=(2, 3))
        ax[i][2].imshow(img_temp, cmap="gray")
        ax[i][2].axis("off")
    
    fig.canvas.draw()
    fig.tight_layout()    
    
d1in = pyw.FloatText(value=0.0, description="D1 inner angle: ")
d1out = pyw.FloatText(value=radii[-1]*0.2, description="D1 outer angle: ")

d2in = pyw.FloatText(value=radii[-1]*0.2, description="D2 inner angle: ")
d2out = pyw.FloatText(value=radii[-1]*0.5, description="D2 outer angle: ")

d3in = pyw.FloatText(value=radii[-1]*0.5, description="D3 inner angle: ")
d3out = pyw.FloatText(value=radii[-1]*0.7, description="D3 outer angle: ")
    
y_selector = pyw.Dropdown(options=np.arange(f_stack.shape[0]), value=0, description="y position: ")
x_selector = pyw.Dropdown(options=np.arange(f_stack.shape[1]), value=0, description="x position: ")

pyw.interact(virtual_detector, det1_in=d1in, det1_out=d1out, det2_in=d2in, det2_out=d2out, det3_in=d3in, det3_out=d3out, ypos=y_selector, xpos=x_selector)
fig.show()

In [None]:
%matplotlib qt
mrad_per_pixel = 0.2
fig, ax = plt.subplots(3, 3, figsize=(10, 10))
radii = np.arange(max_rad(f_stack.shape[2:])) * mrad_per_pixel
print("maximum angle = %.2f"%(radii[-1]))

cm = matplotlib.cm.jet
norm = matplotlib.colors.Normalize()
norm.autoscale(np.arange(f_shape[2]*f_shape[3]) / (f_shape[2]*f_shape[3]))

sm = matplotlib.cm.ScalarMappable(cmap=cm, norm=norm)
sm.set_array([])

def DPC(inner, outer, Ai, Af, Bi, Bf, Ci, Cf, Di, Df, ypos, xpos):
    for i in range(3):
        for j in range(3):
            ax[i][j].cla()
        
    ri = radial_indices(f_shape[2:], [inner, outer], mrad_per_pixel, center=ct)
    rA = angle_indices(f_shape[2:], [Ai, Af], center=ct)
    rB = angle_indices(f_shape[2:], [Bi, Bf], center=ct)
    rC = angle_indices(f_shape[2:], [Ci, Cf], center=ct)
    rD = angle_indices(f_shape[2:], [Di, Df], center=ct)
    
    As = np.multiply(ri, rA)
    Bs = np.multiply(ri, rB)
    Cs = np.multiply(ri, rC)
    Ds = np.multiply(ri, rD)
    
    ax[0][0].imshow(As, cmap=ListedColormap(["lightgray", "blue"]), alpha=0.3)
    ax[0][0].imshow(Bs, cmap=ListedColormap(["lightgray", "green"]), alpha=0.3)
    ax[0][0].imshow(Cs, cmap=ListedColormap(["lightgray", "orange"]), alpha=0.3)
    ax[0][0].imshow(Ds, cmap=ListedColormap(["lightgray", "red"]), alpha=0.3)
    ax[0][0].axis("off")
    
    ax[0][1].imshow(f_stack[ypos, xpos], cmap="gray")
    ax[0][1].imshow(As, cmap=ListedColormap(["white", "blue"]), alpha=0.1)
    ax[0][1].imshow(Bs, cmap=ListedColormap(["white", "green"]), alpha=0.1)
    ax[0][1].imshow(Cs, cmap=ListedColormap(["white", "orange"]), alpha=0.1)
    ax[0][1].imshow(Ds, cmap=ListedColormap(["white", "red"]), alpha=0.1)
    ax[0][1].axis("off")
    
    STEM_img = np.sum(np.multiply(f_stack, ri), axis=(2, 3))
    ax[0][2].imshow(STEM_img, cmap="gray")
    ax[0][2].axis("off")
    
    Asum = np.sum(np.multiply(f_stack, As), axis=(2, 3))
    Bsum = np.sum(np.multiply(f_stack, Bs), axis=(2, 3))
    Csum = np.sum(np.multiply(f_stack, Cs), axis=(2, 3))
    Dsum = np.sum(np.multiply(f_stack, Ds), axis=(2, 3))
    
    y_vector = (Asum+Bsum)-(Csum+Dsum)
    x_vector = (Bsum+Csum)-(Asum+Dsum)
    vector_length = np.hypot(y_vector, x_vector)
    
    y_vector_flat = y_vector.flatten()
    x_vector_flat = x_vector.flatten()
    dummy = np.zeros(f_shape[:2])
    
    y, x = np.indices(f_shape[:2])
    y = np.flip(y.flatten())
    x = x.flatten()
    
    ax[1][0].imshow(y_vector, cmap="gray")
    ax[1][0].axis("off")
    
    ax[1][1].imshow(x_vector, cmap="gray")
    ax[1][1].axis("off")
    
    ax[1][2].imshow(vector_length, cmap="gray")
    ax[1][2].axis("off")
    
    ax[2][0].quiver(x, y, dummy, y_vector_flat)
    ax[2][0].axis("off")
    ax[2][1].quiver(x, y, x_vector_flat, dummy)
    ax[2][1].axis("off")
    
    vector_length_flat = vector_length.flatten()    
    norm_length = vector_length_flat / np.max(vector_length_flat)
    
    fig.canvas.draw()
    fig.tight_layout()
    
din = pyw.FloatText(value=2.0, description="inner angle: ")
dout = pyw.FloatText(value=10.0, description="outer angle: ")

A_i = pyw.FloatText(value=0.0, description="A initial angle: ")
A_f = pyw.FloatText(value=90.0, description="A final angle: ")

B_i = pyw.FloatText(value=90.0, description="B initial angle: ")
B_f = pyw.FloatText(value=180.0, description="B final angle: ")

C_i = pyw.FloatText(value=180.0, description="C initial angle: ")
C_f = pyw.FloatText(value=270.0, description="C final angle: ")

D_i = pyw.FloatText(value=270.0, description="D initial angle: ")
D_f = pyw.FloatText(value=360.0, description="D final angle: ")

y_selector = pyw.Dropdown(options=np.arange(f_stack.shape[0]), value=0, description="y position: ")
x_selector = pyw.Dropdown(options=np.arange(f_stack.shape[1]), value=0, description="x position: ")

pyw.interact(DPC, inner=din, outer=dout, Ai=A_i, Af=A_f, Bi=B_i, Bf=B_f, Ci=C_i, Cf=C_f, Di=D_i, Df=D_f, ypos=y_selector, xpos=x_selector)
fig.colorbar(sm)
fig.show()