## This notebook automatically measures the real-space pixel size for SEND

- We assume the au-xgrating sample is imaged at low magnification (below 100 kX) to measure the real-space pixel size.

- This workflow will not work if the thresholding is not done correctly or if the thresholded image contains non-square features.

```yaml
global_ADF_det_range: 
    value: '5, 20'
    explanation: 'Inner / outer angle range for ADF det as a factor of BF disc radius.'
    prompt: 'ADF_det_range'
global_im_threshold:
    value: 'None'
    explanation: 'From the image histogram, this should be the intensity corresponding to the region between the two main peaks. Leave as None on the first run to display the histogram.'
    prompt: 'image_threshold'
global_probe_thresholds: 
    value: '0.01, 0.1'
    explanation: 'Provide two values to be used for lower and upper threshold when detecting the BF disc.'
    prompt: 'probe_thresholds'
global_write_to_file: 
    value: '0'
    explanation: 'Set to 0 at first. Inspect the image histogram output and provide the threshold above as described. Set to 1 once the square detection looks correct to write the pixel-size JSON file next to the dataset folder (with 0 it is only written to the working directory).'
remex_hints:
    value: 'standard_science_cluster'
    explanation: 'Hints to the workflow engine regarding cluster processing. Please consult the beamline staff before changing this. Valid choices are: standard_science_cluster, ptypy_mpi, ptyrex_mpi, local'
    prompt: 'remote execution hints'
```

In [None]:
# This cell will be replaced by bxflow

In [None]:
%%capture --no-display
%matplotlib notebook
import numpy as np
import h5py
import matplotlib.pyplot as plt
import hyperspy.api as hs
import os
from sklearn.neighbors import NearestNeighbors
import pyxem as pxm
import py4DSTEM
print(py4DSTEM.__file__)

In [None]:
# # dataset name
# data_label = 'au_xgrating/20221111_132756'
# # notebook name
# notebook = 'Calibration_real_SED'
# global_write_to_file = '1'
# global_threshold = '144'

# BEAMLINE = 'e02'
# YEAR = '2022'
# VISIT = 'mg32007-1'



In [None]:
# Input locations for this dataset on the DLS file system.
path = f'/dls/{BEAMLINE}/data/{YEAR}/{VISIT}/processing/Merlin/'
timestamp = data_label.rsplit('/', 1)[-1]
dataset_dir = f'{path}/{data_label}'
ibf_path = f'{dataset_dir}/{timestamp}_ibf.hspy'    # in this notebook only its folder is used
meta_path = f'{dataset_dir}/{timestamp}.hdf'          # metadata: magnification, field of view, 4D shape
full_path = f'{dataset_dir}/{timestamp}_data.hdf5'    # raw data, loaded with hyperspy below

In [None]:
# Display the dataset timestamp (last path component of data_label) as cell output.
timestamp

In [None]:
def print_attrs(name, obj):
    """Print an HDF5 node's name followed by its attributes, one per line.

    Used as the callback for ``h5py`` ``visititems``, which calls it with
    each member's path and object. Each attribute line is ``"<key> <value>"``.
    """
    print(name)
    for attr_key, attr_val in obj.attrs.items():
        print(attr_key, attr_val)
# Walk the raw-data file and print every member path with its attributes.
with h5py.File(full_path, mode='r') as h5_file:
    h5_file.visititems(print_attrs)


In [None]:
# Folder containing the dataset files. Its parent directory is where the
# calibration JSON is written when global_write_to_file is '1'.
save_path = os.path.split(ibf_path)[0]
print(save_path)

In [None]:
# Compute the mean diffraction pattern, with pixels flagged in the Merlin
# mask file multiplied by zero.
d = hs.load(full_path)
mask_file = '/dls_sw/e02/medipix_mask/Merlin_12bit_mask.h5'
with h5py.File(mask_file, 'r') as mask_h5:
    mask = mask_h5['data/mask'][()].astype('bool')
d_mask = d * ~mask
d_mean = d_mask.mean()

# Estimate the BF disc radius and centre from the mean pattern, using the
# lower / upper thresholds given in global_probe_thresholds.
probe_thresholds = global_probe_thresholds.split(',')
v_min = float(probe_thresholds[0])
v_max = float(probe_thresholds[1])
probe_semiangle, qx0, qy0 = py4DSTEM.process.calibration.get_probe_size(
    d_mean.data, v_min, v_max
)

# Show the mean pattern with the fitted disc overlaid, report the radius,
# and save the figure in the working directory.
py4DSTEM.visualize.show_circles(d_mean.data, (qx0, qy0), probe_semiangle)
print('Estimated probe radius =', '%.2f' % probe_semiangle, 'pixels')
plt.savefig(f'{os.getcwd()}/deteted_BF_disc_radius_{int(probe_semiangle)}_pix.png')

In [None]:
# Inner / outer radii of the virtual ADF detector, as multiples of the BF
# disc radius estimated above. The factors are parsed as floats so that
# non-integer factors (e.g. '2.5, 10') work; integer strings behave as before.
ang_factors = global_ADF_det_range.split(',')
min_ang_factor = float(ang_factors[0])
max_ang_factor = float(ang_factors[1])
min_ang = int(probe_semiangle * min_ang_factor)  # in pix
max_ang = int(probe_semiangle * max_ang_factor)  # in pix

d_mean.plot(vmax=30)
# hyperspy's .T swaps navigation and signal axes of the masked data.
d_T = d_mask.T

# Annular detector centred on the BF disc. hyperspy ROIs take (cx, cy) as
# (column, row); qy0 is passed as cx and qx0 as cy, presumably because
# py4DSTEM's qx0 is the first-axis (row) coordinate -- TODO confirm.
adf_det = hs.roi.CircleROI(cx=int(qy0), cy=int(qx0),
                           r=max_ang,
                           r_inner=min_ang)
adf_sig = adf_det.interactive(d_T, navigation_signal=d_mean)
fig = plt.gcf()
fig.savefig(f'{os.getcwd()}/ADF_detector.png')

In [None]:
# Collapse the annular-ROI output into a single virtual ADF image.
# The ROI output may contain NaN (presumably pixels outside the annulus), so
# zero them *before* the integer cast. Casting NaN to uint16 is undefined
# (recent numpy versions warn "invalid value encountered in cast"), and
# np.isnan on an integer array is always False, so checking after the cast
# never replaced anything.
adf_data = adf_sig.data
adf_data = np.where(np.isnan(adf_data), 0, adf_data).astype('uint16')
adf_sig = hs.signals.Signal2D(adf_data)
# Sum over the navigation axes to get one 2D image.
adf_im = adf_sig.sum()
# Rescale to 0-255 so the histogram threshold is chosen on a fixed scale.
adf_im = 255 * adf_im.data / np.max(adf_im.data)
adf_im = hs.signals.Signal2D(adf_im)
adf_im.plot()

In [None]:
# Read magnification, field of view and 4D shape from the metadata file,
# printing the available keys and the first two values.
with h5py.File(meta_path, 'r') as meta_file:
    meta = meta_file['metadata']
    print(meta.keys())
    mag = meta['magnification'][()]
    print(mag)
    fov = meta['field_of_view(m)'][()]
    print(fov)
    sh = meta['4D_shape'][()]

In [None]:
# Rough expected square size in pixels, used to bound the Hough side-length search.
pitch_nm = 500  # assumed grating repeat in nm; the same 500 nm is used for the final calibration
est_pix = fov * 1e9 / sh[0]  # nm per pixel: field of view (metres) over the first '4D_shape' entry
est_square = int(pitch_nm / est_pix)
print(est_square)

lower_b, upper_b = est_square - 5, est_square + 5
print(lower_b, upper_b)

In [None]:
# Make the functions

from scipy.ndimage import gaussian_filter, maximum_filter

def threshold_image(
    im,
    thresh,
    sigma=0,
    plot_result=False
    ):
    """Binarise an image into -1 / +1 levels.

    Parameters
    ----------
    im : ndarray
        Input image.
    thresh : float
        Pixels strictly above this value (after optional smoothing) become
        +1; all other pixels become -1.
    sigma : float
        Standard deviation of the Gaussian applied before thresholding;
        0 disables smoothing.
    plot_result : bool
        If True, display the binarised image.

    Returns
    -------
    ndarray of float
        Same shape as ``im``, containing only -1.0 and +1.0.
    """
    smoothed = im if sigma == 0 else gaussian_filter(im, sigma)
    im_thresh = np.where(smoothed > thresh, 1.0, -1.0)

    if plot_result:
        _, ax = plt.subplots(1, 1, figsize=(6, 6))
        ax.imshow(im_thresh)
        plt.show()

    return im_thresh


def hough_squares(
    im,
    side_lengths = np.arange(50,70,2),
    angle_degrees = np.arange(0,90,2),
    edge_width = 8,
    min_thresh = 0.5,  # relative to maximum filter output
    padding = 64,
    min_dist_maxima = 10,
    plot_result = False,
    return_hough_sig = False,
    ):
    """Detect squares in an image by template matching over size and angle.

    For every rotation angle in ``angle_degrees`` and every side length in
    ``side_lengths`` a square template is built. Pixels in the square's
    interior get a negative weight, and pixels in a band of width
    ``edge_width`` centred on its outline get a positive weight. Each template
    is convolved with the zero-padded image via FFT. At every pixel the best
    (largest) response over all templates is kept, together with the template
    that produced it. Local maxima of the normalised response are reported
    as detections.

    Parameters
    ----------
    im : 2D ndarray
        Image to search, e.g. the +/-1 output of ``threshold_image``. It
        does not need to be square.
    side_lengths : array-like of int
        Candidate square side lengths in pixels.
    angle_degrees : array-like of float
        Candidate rotation angles in degrees.
    edge_width : float
        Width in pixels of the positively weighted band around the outline.
    min_thresh : float
        Minimum normalised response (the strongest response is 1) for a
        pixel to count as a detection.
    padding : int
        Zero padding appended to the end of both axes before the FFTs.
    min_dist_maxima : int
        Window size passed to ``maximum_filter`` for local-maximum detection.
    plot_result : bool
        If True, show the response map with the detections marked.
    return_hough_sig : bool
        If True, also return the normalised response map.

    Returns
    -------
    hough_data : ndarray, shape (N, 4)
        One row per detection: ``[row, col, side_length, angle_degrees]``.
    hough_sig : ndarray, shape ``im.shape``
        Normalised best response; only returned if ``return_hough_sig``.

    Notes
    -----
    Detections whose best template is the *first* entry of ``side_lengths``
    are discarded (``inds_side > 0``).
    """
    # Accept lists / scalars as well as arrays (fancy indexing needs arrays).
    side_lengths = np.atleast_1d(side_lengths)
    angle_degrees = np.atleast_1d(angle_degrees)
    n_rows, n_cols = im.shape

    # Zero-pad the end of both axes so the circular FFT convolution wraps
    # into empty space rather than into the opposite side of the image.
    im_pad = np.pad(
        im,
        (0, padding),
        mode='constant',
        constant_values=0)
    im_fft = np.fft.fft2(im_pad)

    # Integer pixel offsets in FFT order (0, 1, ..., -2, -1), so templates are
    # centred on index (0, 0). Each axis uses its own length; previously both
    # used shape[0], which broke every non-square image.
    x = np.fft.fftfreq(im_pad.shape[0], 1 / im_pad.shape[0])
    y = np.fft.fftfreq(im_pad.shape[1], 1 / im_pad.shape[1])
    ya, xa = np.meshgrid(y, x)

    # Best response so far and the template indices that produced it.
    hough_sig = np.zeros(im.shape, dtype=float)
    inds_side = np.zeros(im.shape, dtype='int')
    inds_angle = np.zeros(im.shape, dtype='int')

    for a0, angle in enumerate(angle_degrees):
        theta = np.deg2rad(angle)

        # Rotated coordinates; rp is the Chebyshev (max-abs) radius, whose
        # level sets are squares.
        xp = xa * np.cos(theta) - ya * np.sin(theta)
        yp = ya * np.cos(theta) + xa * np.sin(theta)
        rp = np.maximum(np.abs(xp), np.abs(yp))

        for a1, s in enumerate(side_lengths):
            inner_edge = s / 2 - edge_width / 2
            outer_edge = s / 2 + edge_width / 2
            sub1 = rp <= inner_edge                                    # interior
            sub2 = np.logical_and(rp > inner_edge, rp <= outer_edge)  # edge band
            kernel = np.zeros(im_pad.shape)
            # NOTE(review): interior weight is -(interior pixel count) and edge
            # weight +(edge pixel count), so the template is not zero-mean.
            # Kept as-is to preserve existing detections.
            kernel[sub1] = -1 * np.sum(sub1)
            kernel[sub2] = np.sum(sub2)

            im_corr = np.real(np.fft.ifft2(
                np.fft.fft2(kernel) * im_fft
            ))[:n_rows, :n_cols]
            better = im_corr > hough_sig
            hough_sig[better] = im_corr[better]
            inds_side[better] = a1
            inds_angle[better] = a0

    # Normalise so the strongest response is 1; skip when nothing responded
    # positively, to avoid 0/0 turning the whole map into NaN.
    peak = np.max(hough_sig)
    if peak > 0:
        hough_sig /= peak

    im_max = np.logical_and.reduce((
        maximum_filter(hough_sig, min_dist_maxima) == hough_sig,
        hough_sig > min_thresh,
        inds_side > 0,
    ))
    xy_all = np.argwhere(im_max)

    sides_all = side_lengths[inds_side[im_max]]
    angles_all = angle_degrees[inds_angle[im_max]]
    hough_data = np.vstack((xy_all.T, sides_all, angles_all)).T

    if plot_result:
        # Show the response map (not the N x 4 result table) with detections.
        fig, ax = plt.subplots(1, 1, figsize=(6, 6))
        ax.imshow(hough_sig)
        ax.plot(xy_all[:, 1], xy_all[:, 0], 'r+')
        plt.show()

    if return_hough_sig:
        return hough_data, hough_sig
    else:
        return hough_data
    
    
def hough_plot_results(
    im,
    hough_data,
    int_range = np.array((-3,3)),
    save_path = None
    ):
    """Overlay detected squares on an image and optionally save the figure.

    Parameters
    ----------
    im : 2D ndarray
        Background image, displayed after standardising to zero mean and
        unit standard deviation.
    hough_data : ndarray, shape (N, 4)
        Rows of ``[row, col, side_length, angle_degrees]`` as returned by
        ``hough_squares``.
    int_range : sequence of two floats
        Display limits, in units of the image standard deviation.
    save_path : str or None
        If given, the figure is written to this path.
    """
    im_scale = im - np.mean(im)
    im_scale = im_scale / np.std(im_scale)

    fig, ax = plt.subplots(1, 1, figsize=(6, 6))
    ax.imshow(
        im_scale,
        vmin=int_range[0],
        vmax=int_range[1],
        cmap='gray',
    )

    # Closed unit-square outline in the template's rotated frame.
    b = np.array([
        [0.5, 0.5],
        [-0.5, 0.5],
        [-0.5, -0.5],
        [0.5, -0.5],
        [0.5, 0.5],
    ])

    for x0, y0, s, angle in hough_data:
        a = np.deg2rad(angle)

        # Undo the rotation used to build the templates, then scale to side s.
        bx = (b[:, 0] * np.cos(a) + b[:, 1] * np.sin(a)) * s
        by = (b[:, 1] * np.cos(a) - b[:, 0] * np.sin(a)) * s

        # x is the row axis and y the column axis, so plot (col, row).
        ax.plot(
            by + y0,
            bx + x0,
            c='r',
            linewidth=2,
        )
        ax.plot(y0, x0, 'o')

    # Save before showing: on some backends show() closes or replaces the
    # current figure, and saving afterwards wrote an empty image.
    if save_path is not None:
        fig.savefig(save_path)
    plt.show()

In [None]:
if global_im_threshold == 'None':
    # First pass: no threshold chosen yet, so show the ADF image and its
    # intensity histogram to help pick one.
    h = np.histogram(adf_im.data, bins = 'auto')
    fig, axs = plt.subplots(1,2)
    axs[0].plot(h[1][1:], h[0])
    axs[1].imshow(adf_im.data)
    # Separate name so the earlier save_path (used below to find the output
    # directory) is not overwritten if the notebook is re-run in one session.
    hist_path = os.path.join(os.getcwd(), 'adf_and_histogram.png')
    plt.savefig(hist_path)

else:
    try:
        global_im_threshold = int(global_im_threshold)
    except ValueError as err:
        # Stop here: continuing with a string threshold fails later with a
        # much less helpful error inside threshold_image.
        raise ValueError(
            f'This needs to be an integer value! Got {global_im_threshold!r}.'
        ) from err
    im_thresh = threshold_image(
        adf_im.data,
        global_im_threshold,
        sigma=2,
    )
    plt.figure()
    plt.imshow(im_thresh)
    plt.savefig(os.path.join(os.getcwd(), 'adf_threshold_applied.png'))

    hough_data, hough_sig = hough_squares(
        im_thresh,
        side_lengths = np.arange(lower_b,upper_b,2),
        return_hough_sig=True
    )

    hough_plot_results(
        adf_im.data,
        hough_data,
        save_path=os.path.join(os.getcwd(), 'real_space_fit.png')
    )

    # Nearest-neighbour spacing between detected square centres (row, col).
    n_found = hough_data.shape[0]
    if n_found < 2:
        raise RuntimeError(
            f'Only {n_found} square(s) detected; at least 2 are needed to '
            'measure their spacing. Check global_im_threshold.'
        )
    X = hough_data[:, :2]
    nbrs = NearestNeighbors(n_neighbors=2, algorithm='ball_tree').fit(X)
    distances, _indices = nbrs.kneighbors(X)
    # Querying the training points returns each point itself first (distance
    # 0), so column 1 is the distance to its nearest other square.
    dists = np.asarray(distances[:, 1])

    print(np.mean(dists), np.std(dists))

    # 500 nm is taken as the centre-to-centre spacing of neighbouring squares.
    real_pix = 500 / np.mean(dists)
    print(real_pix) # in nm

    base_name = os.path.dirname(save_path)

    import json
    cal_dict = {f'real_space_pix(m)_at_{mag}_MAG': real_pix * 1e-9}

    # With write_to_file == '1', write next to the dataset folder (parent of
    # save_path); otherwise write to the current working directory.
    out_dir = base_name if global_write_to_file == '1' else os.getcwd()
    out_file = os.path.join(out_dir, f'real_space_cal_at_{mag}_MAG_{timestamp}.json')
    with open(out_file, 'w') as fp:
        json.dump(cal_dict, fp)