In [None]:
# J. Ryu, Electron Microscopy and Spectroscopy Lab., Seoul National University
# 4D-STEM viewer
# it supports EMPAD data (.raw) and tiff stack (.tif or .tiff)
# 20220414
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets import RectangleSelector
import matplotlib.patches as pch
import tifffile
import tkinter.filedialog as tkf

In [None]:
def load_binary_4D_stack(img_adr, datatype, original_shape, final_shape, log_scale=False):
    """Read a flat binary file as a 4D array and crop it to ``final_shape``.

    Parameters
    ----------
    img_adr : path of the binary file, handed to ``np.fromfile``.
    datatype : dtype used to interpret the file contents.
    original_shape : shape the flat data is reshaped to (as stored on disk).
    final_shape : sequence whose first four entries give the number of
        elements kept, from the start, along each of the four axes.
    log_scale : when True, the natural logarithm of the cropped data is
        returned instead of the cropped data itself.

    Returns
    -------
    The cropped 4D array (or its natural logarithm).
    """
    raw = np.fromfile(img_adr, dtype=datatype).reshape(original_shape)

    # One leading-slice per axis; equivalent to stack[:a, :b, :c, :d].
    keep = tuple(slice(None, final_shape[axis]) for axis in range(4))
    cropped = raw[keep]

    return np.log(cropped) if log_scale else cropped

In [None]:
# Ask for the data file through a Tk "open file" dialog and echo the chosen
# path. `img_adr` is what the loading cell inspects to pick a file reader.
img_adr = tkf.askopenfilename()
print(img_adr)

In [None]:
# Parameters for reading a .raw file (they are only passed to
# load_binary_4D_stack; the tiff reader does not use them).
datatype = "float32"  # dtype used to interpret the binary file
f_shape = [128, 128, 128, 128] # the shape of the 4D-STEM data [scanning_y, scanning_x, DP_y, DP_x]
# Shape of the data as stored in the file: two more rows along DP_y than
# f_shape, which the loader crops away.
# NOTE(review): presumably these are the extra per-frame rows written by the
# EMPAD detector -- confirm against the detector documentation.
o_shape = [f_shape[0], f_shape[1], f_shape[2]+2, f_shape[3]]

In [None]:
# Pick the reader from the file extension (compared case-insensitively).
if img_adr.lower().endswith(".raw"):
    # Binary file: read with the on-disk shape and crop to f_shape.
    stack_4d = load_binary_4D_stack(img_adr, datatype, o_shape, f_shape, log_scale=False)
    # Swap the first two (scan) axes of the loaded array.
    stack_4d = np.transpose(stack_4d, axes=(1, 0, 2, 3))
    # Replace NaN / infinite entries with finite numbers.
    stack_4d = np.nan_to_num(stack_4d)

elif img_adr.lower().endswith((".tif", ".tiff")):
    stack_4d = tifffile.imread(img_adr)

else:
    # Stop here: carrying on would either fail with a NameError or, worse,
    # silently reuse a `stack_4d` left over from an earlier run of this cell.
    raise ValueError("The format of the file is not supported here: %r" % (img_adr,))

print(stack_4d.shape)
print(stack_4d.min(), stack_4d.max())
print(stack_4d.mean())

# Negative values are set to zero; an untouched copy is kept so the crop
# cell can always start again from the full data set.
stack_4d = stack_4d.clip(min=0.0)
stack_4d_original = stack_4d.copy()

In [None]:
# select an interesting area
# The summed intensity map (data summed over its last two axes) is shown in a
# Qt window; a rectangle is then dragged on it to choose the region to keep.
%matplotlib qt
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
ax.imshow(np.sum(stack_4d_original, axis=(2, 3)), cmap="gray")

def onselect(eclick, erelease):
    """Print the start/end corners and the size of a dragged rectangle.

    ``eclick`` and ``erelease`` are the press and release events; their
    ``xdata`` / ``ydata`` are truncated to integers before being reported.
    """
    x_start, y_start = int(eclick.xdata), int(eclick.ydata)
    print('startposition: (%d, %d)'%(x_start, y_start))

    x_end, y_end = int(erelease.xdata), int(erelease.ydata)
    print('endposition  : (%d, %d)'%(x_end, y_end))

    print('width : %d'%(x_end - x_start))
    print('height : %d'%(y_end - y_start))

# Attach the rectangle selector to the intensity map. `box` is read again by
# the crop cell (through box.corners), so run that cell only after a
# rectangle has been drawn.
box = RectangleSelector(ax, onselect)
fig.tight_layout()
plt.show()

In [None]:
# crop the data
# The first axis is sliced with box.corners[1][0]..box.corners[1][2] and the
# second axis with box.corners[0][0]..box.corners[0][1], all truncated to int.
# NOTE(review): this relies on corners[0] holding the x and corners[1] the y
# coordinates of the selected rectangle -- confirm against the matplotlib docs.
# The crop is always taken from the untouched copy, so this cell can be re-run.
stack_4d = stack_4d_original[int(box.corners[1][0]):int(box.corners[1][2]), 
                            int(box.corners[0][0]):int(box.corners[0][1])].copy()
print(stack_4d.shape)

In [None]:
class fourd_viewer:
    """Interactive three-panel browser for a 4D data set.

    Panels (indices into ``ax``):
      0 -- intensity map: ``fdata`` summed over its last two axes, with the
           currently selected position highlighted in red.
      1 -- virtual image: ``fdata`` summed over the ROI of its last two axes.
      2 -- the 2D pattern stored at the selected position, optionally shown
           on a log scale; dragging on this panel defines the ROI.

    The handlers ``on_press`` and ``on_pick`` are not connected here; the
    caller connects them to the figure's key-press and button-press events.

    Parameters
    ----------
    fig : figure that owns the axes; its canvas is redrawn after each change.
    ax : indexable holding (at least) three axes.
    fdata : 4D array indexed as [row, col, pattern_row, pattern_col].
    """

    def __init__(self, fig, ax, fdata):
        self.fig = fig
        self.ax = ax
        self.fdata = fdata
        self.sy, self.sx, self.dsy, self.dsx = fdata.shape

        # Selected position as [row, col]. A platform-sized integer is used
        # because an 8-bit index wraps around once an axis exceeds 127 entries.
        self.ind = np.zeros(2, dtype=np.intp)

        # -1 -> linear display, +1 -> log display (toggled by multiplying by -1).
        self.log_scale = -1
        self.log_scale_message = "False"

        self.int_img = np.sum(fdata, axis=(2, 3))

        # ROI on the pattern axes (top, left, height, width); initially the
        # whole pattern.
        self.by, self.bx, self.height, self.width = 0, 0, self.dsy, self.dsx

        # First draw without the ROI outline, then attach the selector (after
        # the draw, so that clearing the axes cannot remove its artists).
        self._render(draw_roi=False)
        self.box = RectangleSelector(self.ax[2], self.onselect)

    @staticmethod
    def _clamp(value, size):
        """Limit ``value`` to the valid index range ``[0, size - 1]``."""
        return max(0, min(int(value), size - 1))

    def on_press(self, event):
        """Key handler: arrow keys move the selected position, 'l' toggles log scale.

        The position is kept inside the scan area; keys other than the four
        arrows and 'l' are ignored without redrawing.
        """
        row, col = int(self.ind[0]), int(self.ind[1])
        if event.key == "up":
            row -= 1
        elif event.key == "down":
            row += 1
        elif event.key == "right":
            col += 1
        elif event.key == "left":
            col -= 1
        elif event.key == "l":
            self.log_scale *= -1
        else:
            return

        # The last valid index is size - 1; stepping past it would raise an
        # IndexError on the next redraw.
        self.ind[0] = self._clamp(row, self.sy)
        self.ind[1] = self._clamp(col, self.sx)
        self.update()

    def on_pick(self, event):
        """Mouse handler: a click on the intensity map selects that position.

        Returns True (as before) when the click is not on the intensity map.
        """
        if event.inaxes != self.ax[0] or event.xdata is None or event.ydata is None:
            return True

        # Image pixels are centred on integer coordinates, so round to the
        # nearest integer instead of truncating, then keep the result in range.
        self.ind[0] = self._clamp(np.floor(event.ydata + 0.5), self.sy)
        self.ind[1] = self._clamp(np.floor(event.xdata + 0.5), self.sx)
        self.update()

    def onselect(self, eclick, erelease):
        """Selector callback: store the dragged rectangle as the new ROI.

        A rectangle with no extent in either direction is ignored and the
        previous ROI is kept, since summing over an empty ROI gives a blank
        virtual image.
        """
        by, bx = int(eclick.ydata), int(eclick.xdata)
        height = int(erelease.ydata) - by
        width = int(erelease.xdata) - bx
        if height < 1 or width < 1:
            return

        self.by, self.bx = by, bx
        self.height, self.width = height, width
        self.update()

    def _render(self, draw_roi=True):
        """Clear the three panels and draw them again from the current state."""
        for panel in range(3):
            self.ax[panel].cla()

        row, col = int(self.ind[0]), int(self.ind[1])

        # Panel 0: intensity map with the selected position marked in red.
        # The mask doubles as the alpha channel, so only that pixel is opaque.
        self.ax[0].set_title("[x, y]=[%d, %d]"%(col, row))
        mask = np.zeros((self.sy, self.sx))
        mask[row, col] = 1
        self.ax[0].imshow(self.int_img, cmap="gray")
        self.ax[0].imshow(mask, cmap="Reds", alpha=mask)
        self.ax[0].axis("off")

        # Panel 1: virtual image, i.e. the data summed over the ROI.
        self.df_img = np.sum(self.fdata[:, :, self.by:self.by+self.height, self.bx:self.bx+self.width], axis=(2,3))
        self.ax[1].imshow(self.df_img, cmap="gray")
        self.ax[1].axis("off")

        # Panel 2: pattern at the selected position.
        pattern = self.fdata[row, col]
        if self.log_scale == -1:
            self.log_scale_message = "False"
        else:
            self.log_scale_message = "True"
            # Zeros turn into -inf; silence the warning, the values are unchanged.
            with np.errstate(divide="ignore", invalid="ignore"):
                pattern = np.log(pattern)
        self.ax[2].imshow(pattern, cmap="gray")
        self.ax[2].set_title("log scale: %s"%(self.log_scale_message))

        if draw_roi:
            self.ax[2].add_patch(pch.Rectangle((self.bx, self.by), self.width, self.height,
                               linewidth=1, edgecolor="r", facecolor="none"))
        self.ax[2].axis("off")

    def update(self):
        """Redraw all panels (including the ROI outline) and refresh the canvas."""
        self._render(draw_roi=True)
        self.fig.canvas.draw()

In [None]:
# Open the three-panel viewer on the (cropped) data in a Qt window.
%matplotlib qt
fig, ax = plt.subplots(1, 3, figsize=(15, 5))
fig.suptitle("""1st figure (intensity map) : arrow keys or mouse left button to move the position
2nd figure (virtual DF image)
3rd figure (diffraction image) : press 'l' key to turn on or off log-scaling / drag to make a ROI (virtual obj aperture)""")

tracker = fourd_viewer(fig, ax, stack_4d)

# The viewer does not connect its own handlers: key presses go to on_press,
# mouse button presses go to on_pick.
fig.canvas.mpl_connect("key_press_event", tracker.on_press)
fig.canvas.mpl_connect("button_press_event", tracker.on_pick)
fig.tight_layout()
plt.show()

In [None]:
# Write the data to a tiff file whose path is chosen in a Tk "save as" dialog.
# Swap which of the two lines is commented out to save the full data set
# instead of the cropped one.
#tifffile.imwrite(tkf.asksaveasfilename(), stack_4d_original) # save the original data
tifffile.imwrite(tkf.asksaveasfilename(), stack_4d) # save the cropped data