## Please select the epsic 3.10 environment for this notebook

This notebook does the following:

- loads the two files generated by the JEOL AnalysisStation (.pts and .APB files)
- measures the drift (frame-to-frame shifts) in the image stack
- saves the EDX stack in zspy format (binned / cropped if needed)
- runs an alignment routine on the EDX stack using the shifts measured on the image stack above
- re-applies the metadata to the aligned sum spectrum
  

In [None]:
%matplotlib widget
import hyperspy.api as hs
import numpy as np
import h5py
import matplotlib.pyplot as plt
from scipy import ndimage
import os
import gc

In [None]:
# Record the HyperSpy version used for this processing.
hs.__version__

In [None]:
# Input files from the JEOL AnalysisStation: the .pts file holds the EDX
# spectrum-image stack, the .APB file the matching image stack.
pts_path = '/dls/e01/data/2023/mg37141-1/EDX/PtNi_1495_graphene/Sample/00_View000/View000_0000012.pts'
apb_path = '/dls/e01/data/2023/mg37141-1/EDX/PtNi_1495_graphene/Sample/00_View000/View000_0000013.APB'

In [None]:
# Load the EDX data lazily and without summing the frames, so each frame can
# be aligned individually further down.
d = hs.load(pts_path, sum_frames=False, lazy=True)

In [None]:
d.axes_manager

In [None]:
def load_apb(apb_filename, frames, offset=16668, image_shape=(128, 128)):
    """Load the image stack stored in a JEOL AnalysisStation ``.APB`` file.

    The file is read as raw ``uint8`` bytes. ``frames`` consecutive images of
    ``image_shape`` pixels (one byte per pixel) are taken, starting at byte
    ``offset``.

    Parameters
    ----------
    apb_filename : str or path-like
        Path to the ``.APB`` file.
    frames : int
        Number of frames to read, e.g. the frame count of the matching
        ``.pts`` EDX stack.
    offset : int, optional
        Byte offset of the first frame. The default, 16668, is the value used
        for the files this notebook was written for.
    image_shape : tuple of int, optional
        ``(rows, columns)`` of each frame. Defaults to ``(128, 128)``.

    Returns
    -------
    hyperspy.signals.Signal2D
        Stack with data of shape ``(frames, rows, columns)`` and dtype float64.

    Raises
    ------
    ValueError
        If the file is too short to contain ``frames`` images.
    """
    frame_size = int(np.prod(image_shape))
    # A context manager guarantees the file handle is closed.
    with open(apb_filename, "rb") as fd:
        raw = np.fromfile(fd, dtype=np.uint8)
    end = offset + frames * frame_size
    if frames and raw.size < end:
        raise ValueError(
            f"{apb_filename!s} holds {raw.size} bytes, but {frames} frames of "
            f"{image_shape} pixels starting at byte {offset} need {end} bytes."
        )
    # Slice all frames at once. Casting to float64 keeps the dtype of the
    # previous np.zeros-based buffer.
    data_array = raw[offset:end].reshape((frames, *image_shape)).astype(np.float64)
    return hs.signals.Signal2D(data_array)

In [None]:
# Number of frames in the EDX stack (navigation axis 2). It is also used below
# as the number of images to read from the .APB file.
d.axes_manager[2].size

In [None]:
# Load the image stack, one image per EDX frame.
# NOTE(review): load_apb assumes 128x128 images at a fixed byte offset -
# confirm for files from other acquisitions.
im_stack = load_apb(apb_path,  d.axes_manager[2].size)

In [None]:
im_stack

In [None]:
im_stack.plot()

In [None]:
d

In [None]:
# Cast the lazy EDX data to uint8.
# NOTE(review): values above 255 would wrap around if the loaded dtype is
# wider than uint8 - confirm the dtype shown by `d` above.
d.change_dtype('uint8')

In [None]:
# Sum of the unaligned images, for comparison with the aligned sum below.
im_stack.sum().plot()

In [None]:
# Optional: keep only the first 600 frames (must match the EDX crop below).
# im_stack_crop = im_stack.inav[:600]

In [None]:
# Estimate the frame-to-frame shifts and align im_stack in place (HyperSpy
# align2D). crop=False keeps the full field of view. The returned shifts are
# reused below to align the EDX frames.
shifts = im_stack.align2D(crop=False)

In [None]:
# Sum of the aligned images.
im_stack.sum().plot()

In [None]:
# Kept for later: appended to the EDX maps as the ADF image.
im_stack_sum = im_stack.sum()

In [None]:
zspy_path = '/dls/e01/data/2023/mg37141-1/processing/EDX/PtNi_1495_graphene/View_00_12_13_bin4/View_00_12_13_bin4.zspy'
# All outputs (zspy stack, figures, summed SI) are written to the zspy file's folder.
output_path = os.path.dirname(zspy_path)
# exist_ok=True replaces the separate exists-check: there is no race, and no
# error if the folder is already there.
if output_path:
    os.makedirs(output_path, exist_ok=True)

In [None]:
# plt.savefig writes the *current* figure, i.e. the aligned image sum plotted
# above. Run this cell right after that plot, or a different figure is saved.
plt.savefig(os.path.join(output_path, 'sum_images_stack.png'))

In [None]:
# Measured shift of each frame (indexed per frame by the alignment loop below).
shifts

## Adding elements and calibrating

In [None]:
# Elements present in the sample. They are stored in the signal metadata
# (presumably carried into the zspy file and used by add_lines() further down).
d.add_elements(['Pt', 'Ni'])

In [None]:
d.axes_manager

# Align EDX frames

In [None]:
# Bin by 4 along both spatial axes and by 2 along the energy axis; the frame
# axis (scale 1) is left unbinned. The scale order follows d.axes_manager.
# NOTE(review): the image shifts measured above are applied unscaled to this
# binned stack, so its spatial size must equal that of the .APB images
# (128x128) - confirm when changing the binning.
binned_eds = d.rebin(scale=(4,4,1,2))

In [None]:
binned_eds.axes_manager

In [None]:
# Optional: keep only the first 600 frames (must match the image-stack crop).
# binned_eds_crop = binned_eds.inav[:,:,:600]

In [None]:
# Write the binned stack to disk so each frame can be read back lazily.
binned_eds.save(zspy_path)

In [None]:
data_zarr = hs.load(zspy_path, lazy=True)

In [None]:
data_zarr

In [None]:
# Running sum of the aligned frames. Its shape is that of one frame of the
# stack (the same array the alignment loop adds each time). Reading it from
# the data instead of hard-coding it means it follows any change of binning.
si_sum = np.zeros(data_zarr.inav[:, :, 0].data.shape)

In [None]:
from scipy import ndimage
def shift_image(im, shift=0, interpolation_order=1, fill_value=np.nan):
    """Translate a 2D image by ``shift`` pixels.

    A zero shift returns ``im`` itself, not a copy. Whole-pixel shifts use
    order-0 (no) interpolation; a fractional component switches to
    ``interpolation_order``. Pixels moved in from outside the image are set
    to ``fill_value``.
    """
    if np.any(shift):
        # Only fractional shifts need interpolation; whole-pixel ones are exact.
        needs_interp = np.modf(shift)[0].any()
        spline_order = interpolation_order if needs_interp else 0
        return ndimage.shift(im, shift, cval=fill_value, order=spline_order)
    return im
    

def shift_si(si, shift, dtype=None):
    """
    Shift every energy-channel image of an EDX spectrum image by ``shift``.

    si is a hyperspy EDX object whose data is already in memory (the caller
    computes each frame first). It is transposed so that each energy channel
    becomes a 2D image. Each image is shifted with ``shift_image`` and the
    result is transposed back to the original layout.

    Parameters
    ----------
    si : hyperspy EDX signal
        One frame of the spectrum image (two navigation axes, one energy axis).
    shift : array-like
        Shift applied to every channel image, in the (row, column) order of
        the transposed image data.
    dtype : dtype, optional
        dtype of the returned data. ``None`` (default) keeps the dtype that
        ``shift_image`` produces, i.e. the input dtype.

    Returns
    -------
    A HyperSpy signal with the same axes layout as ``si`` (built as a
    Signal2D and transposed back).
    """
    channel_images = si.T.data
    # fill_value=0: pixels shifted in from outside the field of view carry no
    # counts, so they add nothing to a running sum. The NaN default cannot be
    # represented in integer count arrays.
    shifted = np.asarray(
        [shift_image(plane, shift=shift, fill_value=0) for plane in channel_images]
    )
    # Do not force uint8 here: rebinned data holds sums of several raw bins,
    # and counts above 255 would silently wrap around.
    if dtype is not None:
        shifted = shifted.astype(dtype)
    return hs.signals.Signal2D(shifted).T
    


In [None]:
# Unused batching experiment, kept for reference.
# batch_size = 100
# batch_num = binned_eds_crop.data.shape[0] // batch_size

In [None]:
data_zarr.metadata

In [None]:
# add_lines() with no argument presumably adds the X-ray lines of the
# elements already stored in the metadata (HyperSpy behaviour).
# binned_eds_crop.add_elements(elements_list)
# lines_list = ['C', 'Co', 'Mn', 'Ni', 'O']
data_zarr.add_lines()

In [None]:
# Number of frames to align (navigation axis 2).
data_zarr.axes_manager[2].size

In [None]:
n_frames = data_zarr.axes_manager[2].size
# The EDX frames are aligned with the shifts measured on the image stack, so
# both stacks must contain the same number of frames.
if len(shifts) != n_frames:
    raise ValueError(
        f'{len(shifts)} image shifts were measured but the EDX stack has '
        f'{n_frames} frames; crop both stacks the same way.'
    )
for i in range(n_frames):
    # Load one frame (all pixels, all energy channels) into memory.
    si_to_add = data_zarr.inav[:,:,i]
    si_to_add.compute()
    # Opposite sign to the measured shift, as align2D applies it to the images.
    si_aligned = shift_si(si_to_add, -1 * shifts[i])
    # In-place add: no new full-size accumulator is allocated for every frame.
    si_sum += si_aligned.data
    del si_to_add, si_aligned
    gc.collect()
    print(f'Computed the binned version of the EDX stack_batch number {i}')



In [None]:
# Wrap the accumulated array directly as an EDS spectrum image. The previous
# intermediate Signal2D added nothing and rebound `si_sum` to a signal object,
# which broke re-running the accumulation loop.
edx_hs = hs.signals.EDSTEMSpectrum(si_sum)

In [None]:
# Inspect the aligned sum spectrum image (axes still uncalibrated here).
edx_hs

In [None]:
# Source of the axis calibration that the next cell applies to edx_hs.
binned_eds.axes_manager

In [None]:

# Re-apply the calibration of the binned stack to the aligned sum.
# Mapping (edx_hs axis, binned_eds axis): the two spatial navigation axes map
# 0 -> 0 and 1 -> 1. The energy axis of edx_hs (2) comes from axis 3 of
# binned_eds, because axis 2 there is the frame axis that was summed out.
for dst, src in ((0, 0), (1, 1), (2, 3)):
    src_axis = binned_eds.axes_manager[src]
    dst_axis = edx_hs.axes_manager[dst]
    dst_axis.name = src_axis.name
    dst_axis.scale = src_axis.scale
    dst_axis.offset = src_axis.offset
    dst_axis.units = src_axis.units

# Also carry over the acquisition parameters (e.g. beam energy), so that line
# selection and quantification use the recorded values rather than whatever
# defaults EDSTEMSpectrum fills in. This replaces the whole default
# Acquisition_instrument node with the one from the binned stack.
if binned_eds.metadata.has_item('Acquisition_instrument'):
    edx_hs.metadata.add_dictionary(
        {'Acquisition_instrument':
            binned_eds.metadata.Acquisition_instrument.as_dictionary()}
    )

edx_hs.add_elements(['Pt' ,'Ni'])

edx_hs.save(os.path.join(output_path, 'SI_sum'))
    


In [None]:
edx_hs

In [None]:
edx_hs.axes_manager

In [None]:
edx_hs.metadata

In [None]:
# Plot the total spectrum of the aligned sum. The positional True is
# presumably the EDS plot's xray_lines flag (marks the X-ray lines).
edx_hs.sum().plot(True)

In [None]:
# Saves the current figure, i.e. the spectrum plotted in the previous cell.
plt.savefig(os.path.join(output_path, 'sum_spectrum.png'))

In [None]:
# Integrated intensity maps for the sample's X-ray lines (a list of signals;
# the trailing ';' suppresses the cell output).
eds_maps = edx_hs.get_lines_intensity();

hs.plot.plot_images(eds_maps, axes_decor = 'off', scalebar = 'all',
    tight_layout=True, cmap=  'viridis',
    colorbar='single', 
    scalebar_color='black', suptitle_fontsize=8,
    padding={'top':0.8, 'bottom':0.10, 'left':0.05,
            'right':0.85, 'wspace':0.20, 'hspace':0.20});


In [None]:
# Saves the current figure, i.e. the maps plotted in the previous cell.
plt.savefig(os.path.join(output_path, 'maps.png'))

In [None]:
# Shallow copy of the list of maps, so appending the ADF image below leaves
# eds_maps unchanged.
eds_maps_with_ADF = eds_maps.copy()

In [None]:
im_stack_sum

In [None]:
eds_maps

In [None]:
# Add the summed aligned image stack as the last (ADF) panel.
eds_maps_with_ADF.append(im_stack_sum)

In [None]:
im_stack_sum.plot()

In [None]:
# si_EDS = hs.load("core_shell.hdf5")
# im = si_EDS.get_lines_intensity()
# One colormap per image: the element maps first, then the ADF sum.
# NOTE(review): the three-entry cmap list assumes get_lines_intensity returned
# exactly two maps - adjust it if more lines are present.
hs.plot.plot_images(eds_maps_with_ADF, 
                    axes_decor = 'off', 
                    # scalebar = 'all',
                    tight_layout=True, 
                    cmap=  ['inferno', 'inferno', 'viridis'],
                    # colorbar='single', 
                    scalebar_color='black', 
                    suptitle_fontsize=8,
                    padding={'top':0.8, 'bottom':0.10, 'left':0.05,
                            'right':0.85, 'wspace':0.20, 'hspace':0.20});


In [None]:
# Saves the current figure, i.e. the maps + ADF plotted in the previous cell.
plt.savefig(os.path.join(output_path, 'maps_with_ADF.png'))

In [None]:
# Maps for an explicit choice of X-ray lines: Pt M-alpha and Ni L-alpha.
eds_maps_2 = edx_hs.get_lines_intensity(['Pt_Ma', 'Ni_La']);

hs.plot.plot_images(eds_maps_2, axes_decor = 'off', scalebar = 'all',
    tight_layout=True, cmap=  'viridis',
    colorbar='single', 
    scalebar_color='black', suptitle_fontsize=8,
    padding={'top':0.8, 'bottom':0.10, 'left':0.05,
            'right':0.85, 'wspace':0.20, 'hspace':0.20});