## This notebook does the following:
- loads an SED dataset (lazily)
- User defines diffraction calibration, min angle / max angle / virtual detector width 
- We generate a series of virtual ADF images (summing in diff plane over ranges defined by above params)
- Saves the images to a `VADF_series_images` sub-directory of the folder where the data is located

```yaml
global_min_angle: 
    value: '50'
global_max_angle: 
    value: '200'
global_detector_width:
    value: '10'
global_estimated_probe_radius_px:
    value: '8'
cal_json_path:
    value: '/dls/e02/data/2022/mg31953-1/processing/Merlin/Au_grating/calibrations_diff_CL_0.4.json'
```

In [None]:
# # dataset name
# data_label = 'Winchcombe_site9/20221205_172448'
# # notebook name
# notebook = 'VADF_series'
# global_min_angle = '50'
# global_max_angle = '200'
# global_detector_width = '10'
# global_estimated_probe_radius_px = '8'
# cal_json_path = '/dls/e02/data/2022/mg31953-1/processing/Merlin/Au_grating/calibrations_diff_CL_0.4.json'

# BEAMLINE = 'e02'
# YEAR = '2022'
# VISIT = 'mg31953-1'

In [None]:
%%capture --no-display
%matplotlib notebook
import numpy as np
import h5py
import json
import matplotlib.pyplot as plt
import hyperspy.api as hs
import os
import pyxem as pxm
import logging
import py4DSTEM 


In [None]:
# Root of the processed Merlin data for this visit. BEAMLINE, YEAR, VISIT,
# data_label and cal_json_path are not defined in this cell; presumably they
# are injected as notebook parameters -- confirm against the launcher.
path = f'/dls/{BEAMLINE}/data/{YEAR}/{VISIT}/processing/Merlin/'
# The last component of the dataset label is used as the file-name prefix.
timestamp = data_label.split('/')[-1]
ibf_path = f'{path}/{data_label}/{timestamp}_ibf.hspy'
meta_path = f'{path}/{data_label}/{timestamp}.hdf'
full_path = f'{path}/{data_label}/{timestamp}_data.hdf5'
cal_data_path = f'{path}/{data_label}/{timestamp}_calibrated_data.hspy'

# cal_json_path is already a complete path, so it is opened directly; the
# encoding is stated explicitly rather than left to the locale default.
with open(cal_json_path, encoding='utf-8') as json_file:
    cals = json.load(json_file)
# Calibration value used later to scale the detector radii given in pixels.
recip_pix = cals['reciprocal_space_pix(1/A)']
print(recip_pix)


In [None]:
# Open the calibrated dataset with lazy=True and print its summary.
d = hs.load(cal_data_path, lazy=True)
print(d)

In [None]:
# The three detector settings are supplied as strings; convert each one to an
# integer number of pixels (inner limit, outer limit, annulus width).
min_ang, max_ang, detector_width = map(
    int,
    (global_min_angle, global_max_angle, global_detector_width),
)

In [None]:
# Display the axes of the loaded dataset as the cell output.
d.axes_manager

In [None]:
# Mean of the dataset; it is plotted later and passed as the navigation_signal
# for the ROIs. NOTE(review): presumably this is the mean diffraction pattern
# (mean over the scan positions) -- confirm.
d_mean = d.mean()

In [None]:
# d was loaded with lazy=True, so the mean has to be evaluated explicitly.
d_mean.compute()

In [None]:
# Number of whole detector-width annuli between the inner and outer limits;
# any fractional remainder is truncated.
angular_span = max_ang - min_ang
step_num = int(angular_span / detector_width)
print(step_num)

In [None]:
# Folder of this dataset inside the visit's processed Merlin data; shown as
# the cell output.
data_path = f'/dls/{BEAMLINE}/data/{YEAR}/{VISIT}/processing/Merlin/{data_label}'
data_path

In [None]:
# Output folder for the VADF images, located inside the dataset folder.
save_dir = os.path.join(data_path, 'VADF_series_images')

# exist_ok=True avoids the check-then-create race and makes re-runs a no-op.
os.makedirs(save_dir, exist_ok=True)

In [None]:
import gc

# Plot the mean signal; it is also handed to every ROI below as its
# navigation_signal.
d_mean.plot(vmax=0.1, norm='log')
# The ROIs are applied to the transposed dataset rather than to d itself.
d_T = d.T
# intensities = []
adf_images = []
index_key = []


class ADF_image(hs.signals.Signal2D):
    """Signal2D for one virtual ADF image plus a few bookkeeping values.

    Each method stores its result on the instance under the method's own
    name, so after the first call the attribute shadows the method and a
    second call on the same instance fails.

    NOTE(review): these are plain Python attributes; presumably they are not
    written to the .hspy file on save -- confirm if they need to persist.
    """

    def det_position(self, det):
        """Store ``[det.r, det.r_inner]`` as ``self.det_position``."""
        self.det_position = list((det.r, det.r_inner))
#     def bragg_vector_len(self, ap):
#         self.bragg_vector_len = np.sqrt(ap.cx ** 2 + ap.cy ** 2)

    def sum_intensity(self):
        """Store the sum of all data values as ``self.sum_intensity``."""
        self.sum_intensity = np.sum(self.data)

    def label(self):
        """Reset ``self.label`` to None."""
        self.label = None


for i in range(step_num):
    # Inner and outer limits of this annulus in pixels; they are reused for
    # the ROI radii and for the output file names.
    inner_px = min_ang + i * detector_width
    outer_px = min_ang + (i + 1) * detector_width

    adf_det = hs.roi.CircleROI(cx=0.0, cy=0.0,
                               r=outer_px * recip_pix,
                               r_inner=inner_px * recip_pix)
    adf_sig = adf_det.interactive(d_T, navigation_signal=d_mean)

    # Zero the NaNs *before* the integer cast: an unsigned integer array
    # cannot hold NaN, so testing for it after the cast never matches and the
    # NaN -> uint16 conversion itself is undefined.
    annulus = adf_sig.data
    annulus[np.isnan(annulus)] = 0
    annulus = annulus.astype('uint16')
    adf_sig = hs.signals.Signal2D(annulus).as_lazy()

    adf_sig.compute()
    adf_im = adf_sig.sum()
#     intensities.append(int(np.sum(adf_im.data)))
    adf_ = ADF_image(adf_im)
    adf_.det_position(adf_det)
    adf_.label = str(i)
    adf_images.append(adf_)

    # Rescale to 0-255 for the JPEG preview. An all-zero image is kept as
    # zeros instead of dividing by zero and producing NaNs.
    peak = np.max(adf_im.data)
    if peak > 0:
        preview = 255 * adf_im.data / peak
    else:
        preview = np.zeros(adf_im.data.shape)
    adf_im = hs.signals.Signal2D(preview)

    out_stem = f'{save_dir}/vadf_{i}_{inner_px}_to_{outer_px}_px'
    adf_im.save(f'{out_stem}.jpg')
    adf_.save(f'{out_stem}.hspy')

    # Drop the per-annulus arrays before the next iteration to limit peak
    # memory use.
    del adf_sig, adf_im, annulus, preview
    gc.collect()

# Save the current figure into the working directory.
plt.savefig(os.path.join(os.getcwd(), 'mean_diff_pattern_VADF_detectors.png'))

In [None]:
# Flatten each VADF image to 1-D; the cell output is the number of images.
reshaped_adf = [x.data.flatten() for x in adf_images]
len(reshaped_adf)

In [None]:
# One row per VADF image, one column per pixel of the flattened image.
# np.stack replaces the element-by-element Python copy and raises ValueError
# when there are no images; the float dtype matches what np.zeros produced.
data_array = np.stack(reshaped_adf).astype(float)

In [None]:
# For every pixel, the index of the VADF image with the largest value (the
# first one wins on ties). A single argmax over the image axis replaces the
# per-pixel Python loop; the result stays float, as the np.zeros buffer was.
max_index = np.argmax(data_array, axis=0).astype(float)

In [None]:
# max_index has one entry per pixel of the flattened VADF images, so it is
# reshaped back to the shape of those images. Building the shape from
# d.axes_manager[0] / [1] gives (x, y) order, which transposes the row and
# column counts and scrambles the map for non-square scans.
segment = max_index.reshape(adf_images[0].data.shape)
plt.figure()
plt.imshow(segment, cmap = 'turbo_r')
plt.colorbar()
plt.savefig(os.path.join(os.getcwd(), 'segmentation_based_on_max_ADF_signal.png'))

In [None]:
# Wrap the segmentation map in a Signal2D and save it to the working
# directory. Note that this rebinds `segment` from the array to the signal.
segment = hs.signals.Signal2D(segment)
segment.save(os.path.join(os.getcwd(), 'segmentation_based_on_max_ADF_signal.hspy'))