## This notebook does the following:
- Loads the outcome of segmentation from VADF_series_SED notebook
- Applies a user-defined size threshold to the segmentation and separates them into spatially distinct domains 
- Generates a series of summed diffraction signals and their azimuthal integrations using the masks defined above
- Saves the image / signal outputs in a sub-directory **Segment_sums** of the directory where the data is located

```yaml
global_segmentation_path: 
    value: 'None'
    explanation: 'Path to the hspy file with the desired segmentation. Leave to None if there is only one outcome of the VADF_Series_SED notebook'
global_use_binned_data:
    value: 'True'
    explanation: 'For the thick samples we may opt to bin the data by 2 as the outcome of the Apply_Cal_to_SED_data notebook, if so set to True'
global_DBSCAN_min_threshold: 
    value: '5'
    explanation: 'Min size of the segmentation.'
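global_DBSCAN_eps:
    value: '2.8'
    explanation: 'DBSCAN eps: maximum distance (in scan pixels) between two pixels for them to be treated as neighbours of the same domain.'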
```

In [None]:
# Leave empty!

In [None]:
# # # Testing
# # # dataset name
# data_label = 'pct0_FIB/20230126_133423'
# # data_label = 'pct0_FIB/20230126_131428'
# # notebook name
# # notebook = 'Apply_Cal_to_SED_data'
# global_DBSCAN_min_threshold = '5'
# global_segmentation_path = 'None'
# global_use_binned_data = 'True'
# global_DBSCAN_eps = '2.8'
# # global_cal_json_path = '/dls/e02/data/2023/mg31973-1/processing/Merlin/au_xgrating/calibrations_diff_20230125_093528.json'
# # global_crop_window_size = '0.01'

# BEAMLINE = 'e02'
# YEAR = '2023'
# VISIT = 'mg31973-1'

In [None]:
%%capture --no-display
%matplotlib notebook
import numpy as np
import h5py
import json
import matplotlib.pyplot as plt
import hyperspy.api as hs
import os
import pyxem as pxm
import py4DSTEM
import logging
from sklearn.cluster import DBSCAN
import glob
import gc

In [None]:
# Locate the processed dataset. BEAMLINE, YEAR, VISIT, data_label and the
# global_* parameters are not defined in this cell; they are expected to be
# injected into the (empty) parameters cell above. All parameters arrive as
# strings, hence the comparisons against 'True' / 'None'.
path = f'/dls/{BEAMLINE}/data/{YEAR}/{VISIT}/processing/Merlin/'
timestamp = data_label.split('/')[-1]
ibf_path = f'{path}/{data_label}/{timestamp}_ibf.hspy'
meta_path = f'{path}/{data_label}/{timestamp}.hdf'
full_path = f'{path}/{data_label}/{timestamp}_data.hdf5'

# Load the calibrated 4D dataset lazily (binned-by-2 variant if requested).
if global_use_binned_data == 'True':
    d = hs.load(f'{path}/{data_label}/{timestamp}_calibrated_data_bin2.hspy', lazy=True)
else:
    d = hs.load(f'{path}/{data_label}/{timestamp}_calibrated_data.hspy', lazy=True)

# Check how many segmentation outcomes there are in the dataset.
# glob returns matches in arbitrary order, so sort to make the default
# choice (the last entry) reproducible between runs.
segment_path = sorted(glob.glob(f'{path}/{data_label}/vadf_series_sed.*'))
segmentation_file = 'vadf_series_sed/segmentation_based_on_max_ADF_signal.hspy'

if global_segmentation_path != 'None':
    # An explicitly provided segmentation always takes precedence,
    # regardless of how many outcomes were found on disk.
    d_seg = hs.load(global_segmentation_path)
elif not segment_path:
    # Fail with a clear message instead of an IndexError on segment_path[0].
    raise FileNotFoundError(
        f'No segmentation results matching vadf_series_sed.* found in '
        f'{path}/{data_label}. Run the VADF_series_SED notebook first or '
        f'provide global_segmentation_path.'
    )
else:
    if len(segment_path) > 1:
        print('Multiple segmentations results found. We are using one of the outcomes. You can provide the path to your favourite one as an input parameter!')
    # With a single result, [-1] and [0] are the same entry.
    d_seg = hs.load(f'{segment_path[-1]}/{segmentation_file}')

In [None]:
# Show the segmentation result directories found by the glob above.
print(segment_path)

In [None]:
# Visual check of the loaded segmentation label map.
d_seg.plot()

In [None]:
# Output directory for all figures and signals produced by this notebook.
save_path = f'{path}/{data_label}/Segment_sums'
# exist_ok avoids the check-then-create race of os.path.exists + makedirs.
os.makedirs(save_path, exist_ok=True)
    

In [None]:
# Display the trailing dimensions of the data array (everything after the
# first two axes); the next cell uses this shape to build the BF mask.
d.data.shape[2:]

In [None]:
# Split every segmentation label into spatially distinct domains with DBSCAN
# and, for each domain, save the summed diffraction pattern, its radial
# average, the domain mask and an overview figure into save_path.
eps_val = float(global_DBSCAN_eps)
min_sample = int(global_DBSCAN_min_threshold)


# Masking the BF disc. The mask is only used to pick the colour scale (vmax)
# of the summed pattern below. The original values (centre 515 // 2, half
# width 30) were hard-coded for a 515 px pattern; derive them from the actual
# pattern shape so the mask stays centred when the binned data are used.
# For a 515 px pattern this reproduces the original slice exactly.
det_shape = d.data.shape[2:]
bf_mask = np.ones(det_shape)
centre_y, centre_x = det_shape[0] // 2, det_shape[1] // 2
half_y = max(1, round(30 * det_shape[0] / 515))
half_x = max(1, round(30 * det_shape[1] / 515))
bf_mask[max(centre_y - half_y, 0):centre_y + half_y,
        max(centre_x - half_x, 0):centre_x + half_x] = 0
# plt.figure()
# plt.imshow(bf_mask)

seg_labels = d_seg.data

# "+ 1" so that the highest label value is processed as well
# (np.arange(max) stops one short of it).
for ind in np.arange(int(np.max(seg_labels)) + 1):
    mask = seg_labels == int(ind)
    if not mask.any():
        # Label value not present in the map: DBSCAN would raise on an
        # empty coordinate array, so skip it.
        continue

    # separating into individual domains
    # DBSCAN labels noise as -1 and clusters from 0, so after "+ 1":
    #   0      -> outside this label, or DBSCAN noise
    #   1 .. k -> the k spatially distinct domains
    coords = np.argwhere(mask)
    individual_reg = np.zeros(mask.shape, dtype=int)
    individual_reg[mask] = DBSCAN(eps=eps_val, min_samples=min_sample).fit_predict(coords) + 1

    # save the outcome of DBSCAN
    # NOTE(review): figures created in this loop are never closed; consider
    # plt.close(...) if many domains are produced - confirm it does not
    # affect what the notebook displays.
    plt.figure()
    plt.imshow(individual_reg, cmap = 'turbo')
    plt.savefig(f'{save_path}/label_{int(ind)}_individual_domains.jpg')

    # loop through these domains
    # Start at 1: value 0 is "not a domain" (see above). The index used in
    # the output file names therefore equals the value in the saved
    # individual-domains image.
    n_domains = int(individual_reg.max())
    for clust_ind in range(1, n_domains + 1):
        mask_cluster = hs.signals.Signal2D((individual_reg == clust_ind).astype(int))
        # .T turns the mask image into a navigation-space signal so it
        # multiplies every diffraction pattern by 0 or 1.
        d_mask = d * mask_cluster.T
        d_mask_sum = d_mask.sum()
        # computing sum signal over masked region
        d_mask_sum.compute()
        # radial integration
        d_int = d_mask_sum.radial_average()

        d_int.axes_manager[0].units = d_mask_sum.axes_manager[0].units
        d_int.axes_manager[0].scale = d_mask_sum.axes_manager[0].scale
        d_int.axes_manager[0].name = 'Scattering Angle'

        x_axis_ticks = d_int.axes_manager[0].scale * np.arange(d_int.data.shape[0])

        fig, axs = plt.subplots(1,3, figsize=(9,3))
        axs[0].imshow(mask_cluster.data, cmap = 'binary')
        axs[0].set_xticks([])
        axs[0].set_yticks([])
        axs[0].set_title('mask')
        # vmax is 10% of the brightest pixel outside the masked BF region
        axs[1].imshow(d_mask_sum.data, vmax = 0.1 * np.max(d_mask_sum.data * bf_mask), cmap='inferno')
        axs[1].set_xticks([])
        axs[1].set_yticks([])
        axs[1].set_title('sum signal')
        # the first 30 radial bins are left out of the plot
        axs[2].plot(x_axis_ticks[30:], d_int.isig[30:].data)
        axs[2].set(xlabel=f'Scattering Angle {d_mask_sum.axes_manager[0].units}')
        axs[2].set(yticks = [])
        axs[2].set_title('radial average')
        fig.tight_layout()
        fig.savefig(f'{save_path}/label_{int(ind)}_cluster_ind_{clust_ind}.jpg');

        d_mask_sum.save(f'{save_path}/label_{int(ind)}_cluster_ind_{clust_ind}_sum_diff.hspy', overwrite=True)
        mask_cluster.save(f'{save_path}/label_{int(ind)}_cluster_ind_{clust_ind}_mask.hspy', overwrite=True)

        # drop the per-domain intermediates before the next iteration
        del d_mask, d_mask_sum, d_int
        gc.collect()
