## This notebook applies diffraction calibrations to experimental data

Here we assume that the experimental data were acquired at the same camera length as the calibration data!

- Apply mask

- Apply an affine transform to correct the roundness (ellipticity) of the diffraction patterns

- Apply reciprocal space pixel size

```yaml
global_align_BF:
    value: '1'
    explanation: 'Leave at 1 to align the BF disc, otherwise set to 0.'
    prompt: 'align BF'
global_cal_json_path: 
    value: 'None'
    explanation: 'If None, the default path is used: /dls/e02/data/<YEAR>/<VISIT>/processing/Merlin/au_xgrating/calibrations_diff.json; otherwise an arbitrary location can be passed here.'
global_crop_window_size:
    value: '0.05'
    explanation: 'This is the relative crop size around the BF disc used for alignment. If the BF disc jitters too much, e.g. in low-magnification data, a larger value should be used. Larger crop sizes take longer!'
    prompt: 'crop window'
```

In [None]:
# This cell will be replaced by bxflow

In [None]:
%%capture --no-display
%matplotlib notebook
import numpy as np
import h5py
import json
import matplotlib.pyplot as plt
import hyperspy.api as hs
import os
import pyxem as pxm
import py4DSTEM
import logging

In [None]:
# Root of the Merlin processing folder for this visit (BEAMLINE, YEAR, VISIT
# and data_label are injected by the bxflow parameter cell).
path = f'/dls/{BEAMLINE}/data/{YEAR}/{VISIT}/processing/Merlin/'
# The last component of data_label is the acquisition timestamp, which
# prefixes every file of the dataset.
timestamp = data_label.rsplit('/', 1)[-1]
ibf_path, meta_path, full_path = (
    f'{path}/{data_label}/{timestamp}{suffix}'
    for suffix in ('_ibf.hspy', '.hdf', '_data.hdf5')
)

In [None]:
# Load the diffraction calibrations (reciprocal-space pixel size and affine
# transform) from JSON. With global_cal_json_path == 'None' the default
# <path>/au_xgrating/calibrations_diff.json is used; otherwise the given path.
if global_cal_json_path == 'None':
    cal_json_file = os.path.join(path, 'au_xgrating/calibrations_diff.json')
    missing_msg = 'au_xgrating/calibrations_diff.json does not exist!'
else:
    cal_json_file = global_cal_json_path
    missing_msg = f'{global_cal_json_path} is not a valid path'

try:
    with open(cal_json_file) as json_file:
        cals = json.load(json_file)
except OSError:
    # Every later cell needs `cals`; stop here with a clear message instead of
    # continuing and failing later with a NameError.
    print(missing_msg)
    raise

In [None]:
# Reciprocal-space pixel size; per its key name it is in 1/A per pixel.
recip_pix = cals['reciprocal_space_pix(1/A)']

In [None]:
# Affine matrix stored in the calibration JSON as nested lists; converted to a
# numpy array for apply_affine_transformation later in the notebook.
affine_matrix = np.array(cals['affine_transform'])
print(affine_matrix)

In [None]:
# Read and report the acquisition metadata. The mask read here is replaced by
# the static detector mask in the next cell.
with h5py.File(meta_path, 'r') as f:
    meta_grp = f['metadata']
    print(meta_grp.keys())
    mag = meta_grp['magnification'][()]
    print(mag)
    fov = meta_grp['field_of_view(m)'][()]
    print(fov)
    sh = meta_grp['4D_shape'][()]
    print(meta_grp['aperture_size'][()])
    print(meta_grp['nominal_camera_length(m)'][()])
    print(f['data/mask'].shape)
    mask = f['data/mask'][()]


In [None]:
# Temporary override: load the static Merlin 12-bit detector mask instead of
# the one stored with the data (should not be needed once that mask is fixed).
with h5py.File('/dls_sw/e02/medipix_mask/Merlin_12bit_mask.h5', 'r') as f:
    print(f.keys())
    mask = f['data/mask'][()]
# Inverting the boolean mask makes pixels that are nonzero in the file False,
# so multiplying the data by this mask zeroes them.
mask = hs.signals.Signal2D(~mask.astype(bool))
mask.plot()

In [None]:
# Load the 4D dataset and apply the detector mask.
d = hs.load(full_path)
# Put axes 2 and 3 (assumed to be the detector/signal axes) into plain pixel
# units: zero offset, unit scale.
for _sig_axis in (2, 3):
    d.axes_manager[_sig_axis].offset = 0.
    d.axes_manager[_sig_axis].scale = 1
d_mask = d * mask

In [None]:
# This can be removed once py4DSTEM gets updated in env
import numpy as np
from py4DSTEM.process.utils import get_CoM
def get_probe_size(DP, thresh_lower=0.01, thresh_upper=0.99, N=100):
    """
    Estimate the radius and centre of the central (probe) disk in a diffraction pattern.

    The pattern DP is thresholded at N levels, linearly spaced from
    thresh_lower to thresh_upper and expressed relative to max(DP). For each
    binary mask, the radius of a circle with the same area is computed.
    Because the central disk is typically much brighter than the rest of the
    pattern, this radius changes very little over a wide range of thresholds.
    The stable range is identified from the derivative of r(thresh): values
    that are non-positive and no more negative than twice the median. The
    radius is the mean of r over that range. The centre is the centre of mass
    of DP masked at the mean threshold of that range.

    Args:
        DP (2D array): diffraction pattern containing the central disk.
            A position-averaged (ideally shift-corrected) pattern works best.
        thresh_lower (float, 0 to 1): lowest relative threshold
        thresh_upper (float, 0 to 1): highest relative threshold
        N (int): number of thresholds / masks
    Returns:
        (3-tuple): (r, x0, y0)
            * **r**: *(float)* central disk radius in pixels
            * **x0**: *(float)* centre along the first array axis, as reported by get_CoM
            * **y0**: *(float)* centre along the second array axis, as reported by get_CoM
    """
    thresh_vals = np.linspace(thresh_lower, thresh_upper, N)
    dp_max = np.max(DP)

    # Equivalent-circle radius of each thresholded mask.
    r_vals = np.array([np.sqrt(np.sum(DP > dp_max * t) / np.pi) for t in thresh_vals])

    # Keep the thresholds where r is stable (small, non-positive slope).
    dr = np.gradient(r_vals)
    stable = (dr <= 0) & (dr >= 2 * np.median(dr))

    r = np.mean(r_vals[stable])
    thresh = np.mean(thresh_vals[stable])

    # Centre of mass of the pattern restricted to the disk at that threshold.
    x0, y0 = get_CoM(DP * (DP > dp_max * thresh))
    return r, x0, y0

def find_origin_using_cropped_signal(data, rad, x0, y0, crop_window_size = float(global_crop_window_size)):
    '''
    Find the central-beam origin at every scan position, searching only a window around an approximate centre.

    Cropping the detector to a window around the expected disc position keeps
    the search fast and less affected by intensity elsewhere in the pattern.

    Args:
        data (DataCube instance): 4D-STEM dataset for which the origins will be found;
            ``data.data`` is indexed as [scan, scan, det_axis_2, det_axis_3]
        rad (float): approximate radius of the central beam, in pixels
        x0 (float): approximate centre of the central beam along detector axis 2
        y0 (float): approximate centre of the central beam along detector axis 3
        crop_window_size (float, 0 to 1): fraction of the detector size (along axis 2)
            added to ``rad`` to give the half-width of the search window. The default
            is taken from the notebook parameter ``global_crop_window_size`` when this
            function is defined.
    Returns:
        (np.array):
            Array of shape (2, Rx, Ry) with the centre coordinates at each scan
            position, in pixels of the uncropped detector.
    '''
    det_nx, det_ny = data.data.shape[2], data.data.shape[3]

    # Half-width of the search window around the approximate centre.
    central_search_width = int(np.ceil(rad + det_nx * crop_window_size))

    # Window bounds, clamped to the detector. Without clamping, a disc near the
    # edge gives a negative lower bound, which numpy treats as counting from
    # the end, silently producing a wrong (or empty) crop and wrong offsets.
    x0r, y0r = int(x0 // 1), int(y0 // 1)
    x_lower = max(x0r - central_search_width, 0)
    x_upper = min(x0r + central_search_width, det_nx)
    y_lower = max(y0r - central_search_width, 0)
    y_upper = min(y0r + central_search_width, det_ny)

    # Cropped DataCube used for the origin search.
    d_cent = py4DSTEM.io.DataCube(data.data[:, :, x_lower:x_upper, y_lower:y_upper])

    origins = py4DSTEM.process.calibration.origin.get_origin(d_cent, r=rad, rscale=1.1)
    cent_coords = np.asarray(list(origins))

    # Translate the cropped-window coordinates back to the full detector.
    cent_coords += np.array((x_lower, y_lower))[:, None, None]

    return cent_coords


def center_using_pyxem(dp, mask):
    """
    Centre the direct beam of a 4D diffraction signal using pyxem.

    First, the direct-beam position is estimated on the masked mean pattern,
    and every pattern is shifted by the negative of that position
    (``align2D``). Then pyxem's ``center_direct_beam`` refines the centring.

    Args:
        dp: pyxem diffraction signal with two navigation axes. ``align2D`` and
            ``center_direct_beam`` are called on it and their return values are
            not used, so presumably it is modified in place (TODO confirm).
        mask: multiplied into the mean pattern before the beam search.
    Returns:
        (dp, shifts): the signal, and ``[[c1, c0]]`` where (c0, c1) is the
        direct-beam position found on the mean pattern.
    """
    # Masked mean diffraction pattern, used to locate the direct beam once.
    mean_dp = dp.mean()
    mean_dp *= mask

    centre = mean_dp.get_direct_beam_position(method='cross_correlate', radius_start=1, radius_finish=10)
    shifts = [[centre.data[1], centre.data[0]]]

    # Apply the same shift at every scan position: array of shape (Rx, Ry, 2).
    nav_shape = dp.data.shape[:2]
    n_shifts = np.array(shifts * (nav_shape[0] * nav_shape[1])).reshape(nav_shape[0], nav_shape[1], 2)

    dp.align2D(shifts=-n_shifts, crop=False)
    dp.center_direct_beam(method='interpolate', sigma=5, upsample_factor=4, kind='linear', half_square_width=10)
    return dp, shifts

In [None]:
if global_align_BF == '1':
    # Align the bright-field (BF) disc: find the disc centre at every scan
    # position and shift each pattern so the disc sits at the detector centre.
    logging.info("Aligning the BF disc in the data.")
    d_before = d_mask.mean()
    from scipy import ndimage

    def shift_image(im, shift=0, interpolation_order=1, fill_value=0):
        """Return ``im`` shifted by ``shift`` pixels per axis.

        Purely integer shifts use order 0 (no interpolation); fractional
        shifts use ``interpolation_order``. Pixels shifted in from outside
        the image are set to ``fill_value``.
        """
        if not np.any(shift):
            return im
        fractional, _ = np.modf(shift)
        order = interpolation_order if fractional.any() else 0
        return ndimage.shift(im, shift, cval=fill_value, order=order)

    # Getting the origin coordinates
    data = py4DSTEM.io.DataCube(d_mask.data)
    scan_shape = data.data.shape[:2]
    det_shape = data.data.shape[2:]
    n_patterns = scan_shape[0] * scan_shape[1]

    # Estimate the BF disc radius (in pixels) and centre from the first pattern.
    rad, x0, y0 = get_probe_size(data.data[0, 0, :, :])
    print('BF disc radius in pixels:', int(rad))

    cent_coords = find_origin_using_cropped_signal(data, rad, x0, y0)

    # Express the centres as shifts that move each disc to the detector centre
    # (taken from the data instead of a hard-coded 515x515 detector).
    cent_coords = cent_coords.reshape(2, n_patterns)
    det_centre = np.array(det_shape)[:, None] // 2
    shifts = -1 * (cent_coords - det_centre)

    # Apply the shifts pattern by pattern.
    data_resh = np.reshape(data.data, (n_patterns, *det_shape))
    for i in range(n_patterns):
        data_resh[i, :, :] = shift_image(data_resh[i, :, :], shift=[shifts[0][i], shifts[1][i]])

    # replacing d_mask here
    d_mask = pxm.signals.ElectronDiffraction2D(np.reshape(data_resh, (*scan_shape, *det_shape)))
    d_mean = d_mask.mean()

    # Plot the mean pattern before/after alignment and the estimated centre positions.
    fig, axs = plt.subplots(1, 3)
    axs[0].imshow(np.log10(1 + d_before.data), vmax=0.4)
    axs[0].set_title('before align')
    axs[1].plot(cent_coords[1], cent_coords[0])
    axs[1].set_xlim([0, det_shape[1]])
    axs[1].set_ylim([0, det_shape[0]])
    axs[1].set_aspect(1)
    axs[1].invert_yaxis()
    axs[1].set_title('estimated centre pos')
    axs[2].imshow(np.log10(1 + d_mean.data), vmax=0.4)
    axs[2].set_title('after align')
    for ax in axs.flatten():
        ax.set_xticks([])
        ax.set_yticks([])
    plt.savefig(f'{os.getcwd()}/aligning_BF_disc.png')

else:
    logging.info("No BF disc alignment performed.")
    d_mean = d_mask.mean()
    # plt.subplots(1, 1) returns a single Axes (not an array), so it must not
    # be indexed; the original `axs[0]` raised a TypeError here.
    fig, ax = plt.subplots(1, 1)
    ax.imshow(np.log10(1 + d_mean.data), vmax=0.1)
    ax.set_title('without aligning')
    plt.savefig(f'{os.getcwd()}/without_BF_disc_alignment.png')

In [None]:
# Directory holding the raw data; a calibrated copy is saved here later.
base_name = os.path.dirname(full_path)
# Report both the data directory and the working directory, one per line.
print(base_name, os.getcwd(), sep='\n')

In [None]:
# Save the masked (and, if enabled, BF-aligned) signal in the working
# directory before any calibration is applied.
d_mask.save(f'{os.getcwd()}/masked_signal_before_cal')

# Correct the diffraction-pattern roundness with the calibrated affine matrix.
# The return value is not used, so this relies on pyxem applying it in place.
d_mask.apply_affine_transformation(affine_matrix, keep_dtype=True)

# Re-estimate the BF disc from the mean of the corrected data.
d_mask_mean = d_mask.mean()
rad, x0, y0 = get_probe_size(d_mask_mean.data)
print(f'estimate of BF rad: {rad} and position {x0}, {y0}')

# x0 is along the first detector array axis (rows) and y0 along the second
# (columns); this is also how find_origin_using_cropped_signal uses them.
# pyxem applies center[0] to signal_axes[0], which is the column axis, so the
# centre must be given as (y0, x0).
d_mask.set_diffraction_calibration(recip_pix, center=(recip_pix * y0, recip_pix * x0))

In [None]:
# Save the calibrated signal twice: first in the working directory, then next
# to the raw data in the processing folder (overwriting any earlier copy).
for _save_dir, _save_kwargs in ((os.getcwd(), {}), (base_name, {'overwrite': True})):
    d_mask.save(f'{_save_dir}/{timestamp}_calibrated_data', **_save_kwargs)

In [None]:
# d_mask.axes_manager

In [None]:
# # Save a vdf and bdf
# # rad used from before = radius of d beam
# recip_pix = d_mask.axes_manager[3].scale
# cropping_radius = 2*rad*recip_pix # set boundary between bright and dark feild to be 2*direct beam radius

# roi_vbf = hs.roi.CircleROI(cx=0,cy=0, r=cropping_radius)
# roi_vdf = hs.roi.CircleROI(cx=0,cy=0, r_inner=cropping_radius,r=300*recip_pix)

# integrated_vbf = d_mask.get_integrated_intensity(roi_vbf)
# integrated_vdf = d_mask.get_integrated_intensity(roi_vdf)

In [None]:
# from PIL import Image
# arr_vbf = integrated_vbf.data
# arr_vbf = arr_vbf-np.amin(arr_vbf)
# arr_vbf = arr_vbf/np.amax(arr_vbf)
# arr_vbf =arr_vbf*255

# im_vbf = Image.fromarray(arr_vbf)
# im_vbf = im_vbf.convert('RGB')
# im_vbf.save(f'{os.getcwd()}_Integrated_vbf.png',quality=100, subsampling=0)

# arr_vdf = integrated_vdf.data
# arr_vdf = arr_vdf-np.amin(arr_vdf)
# arr_vdf = arr_vdf/np.amax(arr_vdf)
# arr_vdf =arr_vdf*255

# im_vdf = Image.fromarray(arr_vdf)
# im_vdf = im_vdf.convert('RGB')
# im_vdf.save(f'{os.getcwd()}_Integrated_vdf.png',quality=100, subsampling=0)

In [None]:
# scale = [1,1,2,2]
# rebin = d_mask.isig[:-1,:-1].rebin(scale=scale)
# rebin.save(f'{os.getcwd()}/{timestamp}_calibrated_data_rebinned')
