# Plot region-of-interest (ROI) maps from XRD-CT reconstructions

### Load data (the average XRD pattern of the map is computed after masking and plotted further below)

In [None]:
%matplotlib widget
import h5py as h5
import numpy as np
import matplotlib.pyplot as plt
#To import DanMAX from the folder above:
import sys
sys.path.append('../')
import DanMAX as DM

# Apply the DanMAX dark-mode plot style with large figures;
# the return value of DM.darkMode is stored in `style`.
style = DM.darkMode(
    style_dic={'figure.figsize': 'large'},
)

In [None]:
# Define scan location and load the reconstructed maps
fname = DM.findScan()
maps = DM.mapping.load_maps(fname)

# unpack the loaded arrays
Q = maps['Q']                   # selects Q vs. 2theta axis labels later on
x = maps['x_xrd']               # radial axis of the diffraction patterns
recon = maps['xrd_map'].T       # (x,y,radial)
x_map, y_map = maps['x_map'], maps['y_map']

# mean pixel step along the first axis of x_map, scaled by 1e3
# (presumably mm -> um; TODO confirm the units of x_map)
um_per_px = np.diff(x_map[:, 0]).mean() * 1e3

# One radial channel more than `x` means the first "pattern" is the
# diode absorption map, which is split off from the XRD data.
if recon.shape[2] <= x.shape[0]:
    # no absorption data - use a uniform placeholder map
    A_map = np.ones(recon.shape[:2])
else:
    A_map = recon[:, :, 0]
    recon = recon[:, :, 1:]

### Create a pixel mask from a standard-deviation threshold
#### Estimate the threshold from an interactive plot

In [None]:
# Open the interactive masking tool on the reconstruction, using the
# 'std' reduction mode; the selected mask is read back in the next cell.
interactive_mask = DM.InteractiveMask(
    recon,
    reduction_mode='std',
)

In [None]:
# get the result from the interactive plot and convert to a nan-mask:
# pixels below 1 become NaN, the remaining pixels keep their value (1)
mask = interactive_mask.getResult().astype(float)
mask[mask<1.] = np.nan
# Average pattern over all unmasked pixels.
# `recon` is (x,y,radial) and `mask` is (x,y) - the layout also assumed by
# `im*mask` and `recon.transpose(2,0,1)*mask` in the later cells - so the
# mask is broadcast along the radial axis. Multiplying the transposed stack
# (radial,y,x) by an (x,y) mask would apply the mask transposed for square
# maps and fail to broadcast for non-square maps.
I_avg = np.nanmean(recon*mask[:, :, np.newaxis], axis=(0, 1))

### Select regions of interest

In [None]:
# Approximate regions of interest in scattering units (same units as `x`):
#   label : [lower bound, upper bound]
regions = {
    'peak1': [4, 5],
    'peak2': [7.5, 9],
    'peak3': [9.5, 11.5],
}

# Overlay every ROI on the average pattern
plt.figure()
plt.title(DM.getScan_id(fname))
plt.plot(x, I_avg, label='average pattern')
for roi_label, (roi_lo, roi_hi) in regions.items():
    in_roi = (x > roi_lo) & (x < roi_hi)
    plt.plot(x[in_roi], I_avg[in_roi], '.', label=roi_label)
plt.ylabel('I [a.u.]')
plt.xlabel(r'Q [$\AA^{-1}$]' if Q else r'2$\theta$ [$\deg$]')
plt.legend()

#### Plot ROI integrals

In [None]:
# Set the number of columns for the figure
cols = 3
# Plot the integral breadth (ROI integral / ROI maximum) instead of the ROI
# integral; pixels below 5 % of the maximum integral are blanked out
plot_breadth = False

# np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid
trapezoid = np.trapezoid if hasattr(np, 'trapezoid') else np.trapz

# initialize figure with ceil(len(regions)/cols) rows; squeeze=False keeps
# `axes` a 2D array even for a single row or column
rows = int(np.ceil(len(regions)/cols))
fig, axes = plt.subplots(rows, cols, sharex=True, sharey=True, squeeze=False)
#fig.set_size_inches(12,8)
fig.suptitle(DM.getScan_id(fname))
axes = axes.flatten()

# calculate scale bar values: bar length in pixels and offset from the corner
scale_500um = 500./um_per_px
offset = recon.shape[1]*0.05

for ax in axes:
    ax.set_xticks([])
    ax.set_yticks([])

for k, region in enumerate(regions):
    roi = regions[region]
    roi = (x > roi[0]) & (x < roi[1])

    # integral over the ROI for every map pixel -> (x,y)
    im = trapezoid(recon[:, :, roi], x=x[roi], axis=-1)

    if plot_breadth:
        # integral breadth; pixels with a zero ROI maximum would divide by
        # zero, so silence those warnings - weak pixels are blanked anyway
        with np.errstate(divide='ignore', invalid='ignore'):
            breadth = im/np.max(recon[:, :, roi], axis=-1)
        breadth[im < im.max()*0.05] = np.nan
        im = breadth

    # plot heatmap
    ax = axes[k]
    ax.set_title(f'{region}')
    ax.grid(False)

    # clip the colour scale at zero; nanmin so that NaN pixels (e.g. blanked
    # breadth values) do not turn vmin into NaN. Transposed for display.
    vmin = max(np.nanmin(im), 0)
    ax.imshow((im*mask).T, vmin=vmin)

    ax.plot([offset, offset+scale_500um],
            [offset, offset],
            'k',
            lw=10,
            )
    ax.annotate('500 μm',
                (offset+scale_500um/2, offset*2),
                horizontalalignment='center',
                verticalalignment='top',
                )

# delete surplus plots (the axes after the last region)
for ax in axes[len(regions):]:
    fig.delaxes(ax)
fig.tight_layout()
    

### Perform a simple single-peak fit in each region of interest

In [None]:
# flattened shape: one row per map pixel, one column per radial point
flat_shape = (recon.shape[0]*recon.shape[1], recon.shape[2])
# flat reconstruction array; the (x,y) mask is broadcast along the radial
# axis so masked pixels are NaN at every radial point
flat_recon = (recon*mask[:, :, np.newaxis]).reshape(flat_shape)
# fit parameters, in the order they are unpacked from singlePeakFit below
param = ['amplitude', 'position', 'FWHM', 'background']
# results dictionary pre-filled with NaN (pixels that are not fitted stay NaN)
res = {reg: {p: np.full(flat_shape[0], np.nan) for p in param} for reg in regions}
# iterate through ROIs
for k, region in enumerate(regions):
    roi = regions[region]
    roi = (x > roi[0]) & (x < roi[1])

    print(f'Region {k+1} of {len(regions)}')
    # Catch KeyboardInterrupt so a slow fit can be stopped (useful for
    # debugging). This only ends the current region: the results fitted so
    # far are kept and the loop continues with the next region.
    try:
        for i, y in enumerate(flat_recon):
            # only fit pixels with a positive mean ROI intensity; the mean of
            # a masked (NaN) pixel is NaN, which fails this test
            if np.mean(y[roi]) > 0.:
                # the last returned value (the calculated profile) is unused
                *fit_values, _y_calc = DM.fitting.singlePeakFit(x[roi], y[roi], verbose=False)
                for p, r in zip(param, fit_values):
                    res[region][p][i] = r
            print(f'{(i+1)/(flat_shape[0])*100:.2f} %', end='\r')
    except KeyboardInterrupt:
        print('Fitting interrupted')
    # reshape the flat results back to the (x,y) map shape
    for p in param:
        res[region][p] = res[region][p].reshape(recon.shape[:2])

#### Plot peak fit results

In [None]:
# set the active parameter to plot ('amplitude', 'position', 'FWHM', 'background')
active_parameter = 'amplitude'
# Set the number of columns for the figure
cols = 3
# Use a logarithmic colour scale. A log scale cannot show non-positive
# values, so set this to False for parameters that can be <= 0
# (e.g. 'background').
log_scale = True

# initialize figure with ceil(len(regions)/cols) rows; squeeze=False keeps
# `axes` a 2D array even for a single row or column
rows = int(np.ceil(len(regions)/cols))
fig, axes = plt.subplots(rows, cols, sharex=True, sharey=True, squeeze=False)
#fig.set_size_inches(12,8)
fig.suptitle(f'{DM.getScan_id(fname)} - {active_parameter}')
axes = axes.flatten()

# calculate scale bar values: bar length in pixels and offset from the corner
scale_500um = 500./um_per_px
offset = recon.shape[1]*0.025

for ax in axes:
    ax.set_xticks([])
    ax.set_yticks([])

for k, region in enumerate(regions):
    # fitted parameter map, shape (x,y); unfitted/masked pixels are NaN
    im = res[region][active_parameter]
    # plot heatmap
    ax = axes[k]
    ax.set_title(f'{region}')
    ax.grid(False)

    # transpose so the map has the same orientation as the ROI-integral
    # maps, which are displayed as (im*mask).T
    ax.imshow(im.T,
              norm='log' if log_scale else None,
              )

    ax.plot([offset, offset+scale_500um],
            [offset, offset],
            'w',
            lw=10,
            )
    ax.annotate('500 μm',
                (offset+scale_500um/2, offset*2),
                horizontalalignment='center',
                verticalalignment='top',
                )

# delete surplus plots (the axes after the last region)
for ax in axes[len(regions):]:
    fig.delaxes(ax)
fig.tight_layout()
    