# This notebook is to prepare a config file to run ptyrex recon

- read metadata 

- Detect BF disc diameter and form VADF image

- estimate CL and probe convergence semi-angle

- Help user decide on region to reconstruct

```yaml
global_template_json:
    value: '/dls_sw/e02/PtyREX_templates/template_MD.json'
    explanation: 'This is the template json file used as the starting point to populate based on the metadata.'
global_write_json: 
    value: '1'
    explanation: 'Leave at 0 if exploring parameters. Set to 1 to write a json file.'
global_json_path: 
    value: '/dls/e02/data/2023/cm33902-1/processing/Merlin/ptyrex_json_files'
    explanation: 'A global path to hold the json files. If None and a json file write is requested, it defaults to the data folder.'
global_tag_name: 
    value: 'full'
    explanation: 'This prefix tag will be added to the json file and the reconstruction.'
    prompt: 'reconstruction prefix'
global_thresholds: 
    value: '0.01, 0.7'
    explanation: 'Provide two values to be used for lower and upper threshold when detecting the BF disc.'
    prompt: 'thresholds'
global_scan_rot:
    value: '-55'
    explanation: 'This is the scan rotation used in the PtyREX json file. It can be determined using CoM_DPC_ABF_ouputs notebook.'
    prompt: 'scan rotation'
global_recon_reg:
    value: '0,1,0,1,1,1'
    explanation: 'The first four values determine the x and y ranges as fractions (0 to 1). The last two values should be integers and correspond to skips.'
    prompt: 'x/y ranges and skips'
remex_hints:
    value: 'standard_science_cluster'
    explanation: 'Hints to the workflow engine regarding cluster processing. Please consult the beamline staff before changing this.'
    prompt: 'remote execution hints'
```

In [None]:
# Leave empty!

In [None]:


# # dataset name
# data_label = 'cam_length_cal_data/20221130_114122'
# # notebook name
# notebook = 'ptyrex_calibrations'
# global_template_json = '/dls/e02/data/2022/mg31702-1/processing/template.json'
# global_write_json = '1'
# global_json_path = '/dls/e02/data/2022/mg31702-1/processing/Merlin/json_files'
# global_tag_name = 'test'
# global_thresholds = '0.01, 0.6'

# BEAMLINE = 'e02'
# YEAR = '2022'
# VISIT = 'cm31101-5'



In [None]:
%%capture --no-display
%matplotlib notebook
import sys
sys.path.append('/dls_sw/apps/pycho/ptyrex_latest')
import ptyrex
from ptyrex import np
from ptyrex import tb
from ptyrex import h5py
import json
sys.path.append('/dls/science/groups/e02/Mohsen/code/Git_Repos/Merlin-Medipix/')
import epsic_tools.api as epsic
import h5py
import os
import py4DSTEM
import hyperspy.api as hs
import matplotlib.pyplot as plt

In [None]:
# This cell to be deleted when running live with bxflow
# BEAMLINE = 'e02'
# YEAR = '2022'
# VISIT = 'cm31101-4'
# data_label = 'JM_RE252/20221006_145211'

In [None]:
# # dataset name
# data_label = 'Au_test/20221011_162517'
# # notebook name
# notebook = 'ptyrex_calibrations'
# global_template_json = '/dls/e02/data/2022/cm31101-4/processing/Merlin/graphene_example/template.json'
# global_write_json = '1'
# global_json_path = 'None'
# global_tag_name = 'test'

# BEAMLINE = 'e02'
# YEAR = '2022'
# VISIT = 'cm31101-4'

In [None]:
# Root of the Merlin processing tree for this visit. The trailing slash here
# plus the '/' used in the joins below gives a doubled separator; kept as-is.
path = f'/dls/{BEAMLINE}/data/{YEAR}/{VISIT}/processing/Merlin/'
# The timestamp is the last '/'-separated component of the dataset label.
timestamp = data_label.rsplit('/', 1)[-1]
# All three files of a dataset sit in the same folder and share the timestamp.
data_dir = f'{path}/{data_label}'
ibf_path = f'{data_dir}/{timestamp}_ibf.hspy'
meta_path = f'{data_dir}/{timestamp}.hdf'
full_path = f'{data_dir}/{timestamp}_data.hdf5'

In [None]:
# Folder containing the dataset files, and its parent folder.
save_path = os.path.split(ibf_path)[0]
base_name = os.path.split(save_path)[0]
# Echo both folders and the current working directory for reference.
for location in (save_path, base_name, os.getcwd()):
    print(location)

In [None]:
# Scan rotation arrives as a string parameter.
scan_rot = float(global_scan_rot)

# Pull the acquisition metadata out of the .hdf file; a few entries are only
# printed for inspection. The mask stored in this file is not loaded, only its
# shape is shown.
with h5py.File(meta_path, 'r') as meta_file:
    print(meta_file['metadata'].keys())
    mag = meta_file['metadata/magnification'][()]
    fov = meta_file['metadata/field_of_view(m)'][()]
    sh = meta_file['metadata/4D_shape'][()]
    print(meta_file['metadata/aperture_size'][()])
    print(meta_file['metadata/nominal_camera_length(m)'][()])
    print(meta_file['data/mask'].shape)
    defocus_nm = meta_file['metadata/defocus(nm)'][()]
    acc_volt = meta_file['metadata/ht_value(V)'][()]
    semi_angle = meta_file['metadata/convergence_semi-angle(rad)'][()]

# Derived quantities
def_val = defocus_nm * 1e-9           # defocus(nm) scaled by 1e-9
step_size = fov / sh[0]               # field of view divided by first 4D_shape entry
e_wave_len = epsic.sim_utils.e_lambda(acc_volt)
probe_conv = semi_angle * 2           # twice the stored convergence semi-angle

In [None]:
# The detector mask is read from the shared mask file, not from the dataset.
mask_file_path = '/dls_sw/e02/medipix_mask/Merlin_12bit_mask.h5'
with h5py.File(mask_file_path, 'r') as mask_file:
    mask = mask_file['data/mask'][()]

In [None]:
# Show the loaded detector mask in a new figure.
plt.figure()
plt.imshow(mask)

In [None]:
# This can be removed once py4DSTEM gets updated in env
import numpy as np
from py4DSTEM.process.utils import get_CoM
# def get_probe_size(DP, thresh_lower=0.01, thresh_upper=0.6, N=100):
#     """
#     Gets the center and radius of the probe in the diffraction plane.
#     The algorithm is as follows:
#     First, create a series of N binary masks, by thresholding the diffraction pattern
#     DP with a linspace of N thresholds from thresh_lower to thresh_upper, measured
#     relative to the maximum intensity in DP.
#     Using the area of each binary mask, calculate the radius r of a circular probe.
#     Because the central disk is typically very intense relative to the rest of the DP, r
#     should change very little over a wide range of intermediate values of the threshold.
#     The range in which r is trustworthy is found by taking the derivative of r(thresh)
#     and identifying where it is small.  The radius is taken to be the mean of
#     these r values. Using the threshold corresponding to this r, a mask is created and
#     the CoM of the DP times this mask is taken.  This is taken to be the origin x0,y0.
#     Args:
#         DP (2D array): the diffraction pattern in which to find the central disk.
#             A position averaged, or shift-corrected and averaged, DP works best.
#         thresh_lower (float, 0 to 1): the lower limit of threshold values
#         thresh_upper (float, 0 to 1): the upper limit of threshold values
#         N (int): the number of thresholds / masks to use
#     Returns:
#         (3-tuple): A 3-tuple containing:
#             * **r**: *(float)* the central disk radius, in pixels
#             * **x0**: *(float)* the x position of the central disk center
#             * **y0**: *(float)* the y position of the central disk center
#     """
#     thresh_vals = np.linspace(thresh_lower, thresh_upper, N)
#     r_vals = np.zeros(N)

#     # Get r for each mask
#     DPmax = np.max(DP)
#     for i in range(len(thresh_vals)):
#         thresh = thresh_vals[i]
#         mask = DP > DPmax * thresh
#         r_vals[i] = np.sqrt(np.sum(mask) / np.pi)

#     # Get derivative and determine trustworthy r-values
#     dr_dtheta = np.gradient(r_vals)
#     mask = (dr_dtheta <= 0) * (dr_dtheta >= 2 * np.median(dr_dtheta))
#     r = np.mean(r_vals[mask])

#     # Get origin
#     thresh = np.mean(thresh_vals[mask])
#     mask = DP > DPmax * thresh
#     x0, y0 = get_CoM(DP * mask)
    
#     return r, x0, y0

In [None]:
# Boolean mask; its inverse is multiplied into the data below.
mask = mask.astype(bool)
# Load the 4D dataset lazily.
d = hs.load(full_path, lazy=True)
# Reset axes 2 and 3 to zero offset and unit scale.
# NOTE(review): presumably these are the detector axes, so that positions are
# expressed in pixels — confirm.
for axis_index in (2, 3):
    axis = d.axes_manager[axis_index]
    axis.offset = 0.
    axis.scale = 1
d_mask = d * ~mask

In [None]:
# with h5py.File('/dls_sw/e02/medipix_mask/Merlin_12bit_mask_2.h5', 'w') as f:
#     f.create_dataset('data/mask', data=mask)


In [None]:
# The two thresholds arrive as one comma-separated string: lower first, upper second.
input_vals = global_thresholds.split(',')
v_min, v_max = float(input_vals[0]), float(input_vals[1])

In [None]:
# Mean of the masked data, then computed.
d_mean = d_mask.mean()
d_mean.compute()

# Estimate the BF disc radius and its centre coordinates from the mean pattern,
# using the user-supplied lower/upper thresholds.
probe_size_finder = py4DSTEM.process.calibration.get_probe_size
rad, x0, y0 = probe_size_finder(d_mean.data, v_min, v_max)

# Report the values truncated to whole pixels.
print('BF disc radius in pixels:', int(rad))
print(f'Optical axis coordinates are {int(x0)} and {int(y0)}')

In [None]:
# Display the mean pattern computed from the masked data.
d_mean.plot()

In [None]:
# Camera length estimate: BF disc radius (truncated to whole pixels) times the
# pixel pitch, divided by half of probe_conv.
# NOTE(review): 55e-6 is presumably the detector pixel size in metres — confirm.
pixel_pitch = 55e-6
half_probe_conv = probe_conv / 2
cam_len = (pixel_pitch * int(rad)) / half_probe_conv
print(cam_len)

In [None]:
# Plot the mean pattern with the detected BF disc overlaid.
d_mean.plot(vmax=50)
# The ROI centre is passed as cx=y0, cy=x0; all values are truncated to integers.
bf_cx, bf_cy, bf_r = int(y0), int(x0), int(rad)
circ_ROI = hs.roi.CircleROI(cx=bf_cx, cy=bf_cy, r=bf_r)
circ_ROI.interactive(d_mean)
# The image is written to the current working directory (not save_path).
bf_png = f'{os.getcwd()}/BF_disc_detected.png'
plt.savefig(bf_png)

In [None]:
# Annular detector limits in pixels: the inner edge sits 20 px beyond the
# (truncated) BF disc radius, the outer edge is fixed at 250 px.
bf_clearance = 20
min_ang = int(rad) + bf_clearance
max_ang = 250

# Transposed view of the masked data.
d_T = d_mask.T
d_mean.plot(vmax=30)

In [None]:
# Annular ROI with centre cx=y0, cy=x0 (truncated to integers), inner radius
# min_ang and outer radius max_ang.
adf_centre = (int(y0), int(x0))
adf_det = hs.roi.CircleROI(cx=adf_centre[0], cy=adf_centre[1], r=max_ang, r_inner=min_ang)


In [None]:
# Apply the annular ROI to the transposed data, with the mean pattern passed as
# the navigation signal, then grab the current figure for saving.
adf_sig = adf_det.interactive(d_T, navigation_signal=d_mean)
fig = plt.gcf()
# fig.savefig(f'{save_path}/ADF_detector.png')
# The figure is written to the current working directory instead of save_path.
fig.savefig(f'{os.getcwd()}/ADF_detector.png')

In [None]:
# Build the ADF image from the ROI-selected data.
# NaNs have to be zeroed BEFORE the integer cast: an unsigned-integer array can
# never contain NaN, so testing for NaN after the cast never matches, and the
# value produced by casting NaN to an integer is not defined.
adf_data = adf_sig.data
adf_data = np.where(np.isnan(adf_data), 0, adf_data).astype('uint16')
adf_sig = hs.signals.Signal2D(adf_data).as_lazy()
adf_sig.compute()
# Sum the stack into a single image and rescale so its maximum is 255.
adf_im = adf_sig.sum()
adf_im = 255 * adf_im.data / np.max(adf_im.data)
adf_im = hs.signals.Signal2D(adf_im)

In [None]:
# Display the rescaled ADF image.
adf_im.plot()

In [None]:
# plt.savefig(f'{save_path}/ADF_im.png')
# Save the current figure to the current working directory instead of save_path.
plt.savefig(f'{os.getcwd()}/ADF_im.png')

In [None]:
from skimage.transform import rotate

# Rotate the ADF image by the negated scan rotation (canvas resized to fit) and
# mirror it along axis 1, then show it with a fractional grid to help choose
# the reconstruction region.
adf_rot = np.flip(rotate(adf_im.data, -1 * scan_rot, resize=True), axis=1)
fig, axs = plt.subplots(1, 1)
axs.imshow(adf_rot, cmap='gray')

# Six evenly spaced major ticks per axis, labelled as fractions of the extent.
# Rows (shape[0]) drive the y axis and columns (shape[1]) the x axis, so a
# non-square image is gridded correctly. linspace always yields exactly one
# tick per label, whereas an integer-step arange can come up one tick short
# (or have a zero step for very small images).
fractions = [0., 0.2, 0.4, 0.6, 0.8, 1.]
n_rows, n_cols = adf_rot.shape
axs.set_xticks(np.linspace(0, n_cols - 1, len(fractions)))
axs.set_yticks(np.linspace(0, n_rows - 1, len(fractions)))

# Labels for major ticks
axs.set_xticklabels(fractions)
axs.set_yticklabels(fractions)

# Gridlines follow the major ticks
axs.grid(which='major', color='w', linestyle='-', linewidth=1)
# plt.savefig(f'{save_path}/ADF_im_rotated_grid.png')
plt.savefig(f'{os.getcwd()}/ADF_im_rotated_grid.png')

In [None]:
# global_recon_reg is a comma-separated string: four float range values
# (x0, x1, y0, y1) followed by two integer skips (x, y). Split it once.
region_fields = global_recon_reg.split(',')
recon_reg_x0 = float(region_fields[0])
recon_reg_x1 = float(region_fields[1])
recon_reg_y0 = float(region_fields[2])
recon_reg_y1 = float(region_fields[3])
recon_reg_skipx = int(region_fields[4])
recon_reg_skipy = int(region_fields[5])


In [None]:
# Make sure the reconstruction output folder exists. exist_ok avoids the
# check-then-create race when the folder appears between the two calls.
os.makedirs(f'{save_path}/ptyrex_recon', exist_ok=True)

In [None]:
# Load the template json and overwrite the entries that depend on this dataset.
# Only a failure to open/read the template file is reported; in that case
# pty_expt is left undefined.
try:
    with open(global_template_json, 'r') as template_file:
        pty_expt = json.load(template_file)
except OSError:
    print('no valid template path')
else:
    recon_dir = f'{save_path}/ptyrex_recon'
    experiment = pty_expt['experiment']
    process = pty_expt['process']
    common = process['common']
    scan = common['scan']
    lens = experiment['optics']['lens']

    # Output locations and input data
    pty_expt['base_dir'] = recon_dir
    process['save_dir'] = recon_dir
    process['save_prefix'] = global_tag_name
    experiment['data']['data_path'] = meta_path

    # Scan geometry
    scan['rotation'] = scan_rot
    scan['N'] = [int(sh[0]), int(sh[1])]
    scan['dR'] = [float(step_size), float(step_size)]
    scan['region'] = [
        recon_reg_x0, recon_reg_x1,
        recon_reg_y0, recon_reg_y1,
        recon_reg_skipx, recon_reg_skipy,
    ]

    # Detector, optics and source
    experiment['detector']['position'] = [0, 0, float(cam_len)]
    lens['alpha'] = float(probe_conv)
    lens['defocus'] = [float(def_val), float(def_val)]
    common['source']['energy'] = int(acc_volt)

In [None]:
# Create the global json folder when one was requested. The parameter uses the
# string 'None' to mean "not set"; it must be compared by value, because an
# identity test against a string literal is not reliable and would treat
# 'None' as a real folder name. A real None is accepted as "not set" too.
if global_json_path not in (None, 'None'):
    os.makedirs(global_json_path, exist_ok=True)

In [None]:
# Write the populated config only when requested ('1').
if global_write_json == '1':
    # The string 'None' (or a real None) means no global json folder was given;
    # compare by value, since an identity test against a string literal is not
    # reliable. In that case the file goes next to the data.
    if global_json_path not in (None, 'None'):
        json_dir = global_json_path
    else:
        json_dir = f'{path}/{data_label}'
    with open(f'{json_dir}/{timestamp}_{global_tag_name}.json', 'w') as f:
        json.dump(pty_expt, f, indent=4)
