# Plot region-of-interest maps from XRD-CT reconstructions

### Load data and compute the average XRD pattern for the map

In [None]:
# Interactive, widget-based matplotlib backend for Jupyter (requires ipympl)
%matplotlib widget
import h5py as h5
import numpy as np
import matplotlib.pyplot as plt
# To import DanMAX from the folder above:
# NOTE: '../' is resolved relative to the current working directory (normally the notebook's folder)
import sys
sys.path.append('../')
import DanMAX as DM
# Apply the DanMAX plotting style (presumably a dark theme, judging by the name) with large default figures
style = DM.darkMode(style_dic={'figure.figsize':'large'})

In [None]:
# Locate the raw scan file to analyse
fname = DM.findScan()

# The reconstruction is stored under process/xrd_ct with a '_recon' suffix
rname = fname.replace('raw', 'process/xrd_ct')
rname = rname.replace('.h5', '_recon.h5')

# Read the gridrec reconstruction, its scattering axis and (if stored) the pixel size
with h5.File(rname, 'r') as f:
    recon = f['/reconstructed/gridrec'][:]  # (n, m, radial)
    # Q is True when the radial axis is stored as q, False when it is stored as 2theta
    Q = 'q' in f.keys()
    x = f['q'][:] if Q else f['2th'][:]
    # pixel size in micrometres; NaN when the file does not provide it
    um_per_px = f['micrometer_per_px'][()] if 'micrometer_per_px' in f.keys() else np.nan

# Mean pattern over all map pixels (the first two axes)
I_avg = recon.mean(axis=(0, 1))

### Create a mask based on a standard-deviation threshold
#### Estimate the threshold from an interactive plot

In [None]:
# Per-pixel standard deviation along the radial axis, normalized to the mean of the whole reconstruction
im_std = recon.std(axis=2) / recon.mean()
# Inspect the distribution interactively to choose a mask threshold (zero values ignored)
DM.interactiveImageHist(im_std, ignore_zero=True)

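#### Create and evaluate the mask
Set `threshold` based on the interactive plot above. Pixels whose normalized standard deviation exceeds the threshold are kept; all other pixels are set to NaN.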

In [None]:
threshold = 2

# Mean reconstructed image, and the per-pixel standard deviation along the radial axis
# normalized to the global mean
im = np.mean(recon, axis=2)
im_std = recon.std(axis=2) / recon.mean()

# Histogram of the non-zero normalized standard deviations, with bin centres for plotting
val, edges = np.histogram(im_std[im_std > 0], bins=im_std.shape[0], density=True)
bins = edges[:-1] + np.diff(edges).mean() / 2

# Mask: 1.0 where the threshold is exceeded, NaN elsewhere (multiplying by it blanks pixels)
mask = np.where(im_std > threshold, 1.0, np.nan)

## initialize figure: a wide histogram on top, four image panels below
gs_kw = {'width_ratios': [1, 1], 'height_ratios': [1, 4, 4]}
fig, axd = plt.subplot_mosaic(
    [['N', 'N'],
     ['W', 'E'],
     ['SW', 'SE']],
    gridspec_kw=gs_kw,
    layout='constrained',
    figsize=(8, 8),
)

hist_ax = axd['N']
hist_ax.set_title('Threshold histogram')
hist_ax.plot(bins, val)
hist_ax.axvline(threshold, ls='--', c='grey')
# hist_ax.set_xscale('log')  # uncomment for a logarithmic threshold axis

# (panel key, title, image, extra imshow keywords)
panels = [
    ('W', 'Original', im, {}),
    ('E', 'Standard deviation', im_std, {}),
    ('SW', 'Masked', im * mask, {}),
    ('SE', 'Mask', mask, {'vmax': 1, 'vmin': 0}),
]
for key, title, image, imshow_kw in panels:
    panel_ax = axd[key]
    panel_ax.set_title(title)
    panel_ax.imshow(image, **imshow_kw)
    panel_ax.set_xticks([])
    panel_ax.set_yticks([])


### Select regions of interest

In [None]:
# Approximate regions of interest in scattering units (same units as x)
#            label    :  [lower, upper]
regions = {'region_1' : [4.8,5.1],
           'region_2' : [9.4,11.4],
          }

# Show the average pattern with the points inside each region highlighted
plt.figure()
plt.title(DM.getScan_id(fname))
plt.plot(x, I_avg, label='average pattern')
for label, (lower, upper) in regions.items():
    in_roi = (x > lower) & (x < upper)
    plt.plot(x[in_roi], I_avg[in_roi], '.', label=label)
plt.ylabel('I [a.u.]')
plt.xlabel(r'Q [$\AA^{-1}$]' if Q else r'2$\theta$ [$\deg$]')
plt.legend()

#### Plot ROI integrals

In [None]:
# Set the number of columns for the figure
cols = 3
# Set to True to map the integral breadth (ROI integral / ROI maximum) instead of the integral
plot_breadth = False

# Trapezoidal integration: np.trapz is deprecated in NumPy 2.0 in favour of np.trapezoid
integrate = getattr(np, 'trapezoid', None) or np.trapz

# Initialize the figure with enough rows for one panel per region (ceiling division).
# squeeze=False keeps `axes` a 2D array for any rows/cols combination.
rows = max(1, -(-len(regions)//cols))
fig, axes = plt.subplots(rows,cols,sharex=True,sharey=True,squeeze=False)
#fig.set_size_inches(12,8)
fig.suptitle(DM.getScan_id(fname))
axes = axes.flatten()

# Scale bar values. The bar is skipped when the pixel size is unknown:
# um_per_px is NaN when the file has no 'micrometer_per_px' entry, and drawing at NaN
# coordinates would only produce warnings.
show_scale_bar = bool(np.isfinite(um_per_px) and um_per_px > 0)
scale_500um = 500./um_per_px if show_scale_bar else np.nan
offset = recon.shape[1]*0.025

for ax in axes:
    ax.set_xticks([])
    ax.set_yticks([])

for k, region in enumerate(regions):
    lower, upper = regions[region]
    roi = (x > lower) & (x < upper)

    # integrated ROI intensity for every map pixel
    im = integrate(recon[:,:,roi], x=x[roi], axis=-1)

    if plot_breadth:
        # Integral breadth. Pixels with an empty ROI give 0/0, which is silenced here;
        # pixels below 5 % of the maximum integral are blanked.
        with np.errstate(divide='ignore', invalid='ignore'):
            breadth = im/np.max(recon[:,:,roi], axis=-1)
        breadth[im < im.max()*0.05] = np.nan
        im = breadth

    # plot heatmap
    ax = axes[k]
    ax.set_title(f'{region}')
    ax.grid(False)

    # Clip the lower colour limit at 0; nanmin ignores blanked (NaN) pixels
    vmin = max(np.nanmin(im), 0)
    ax.imshow(im*mask, vmin=vmin)

    if show_scale_bar:
        ax.plot([offset,offset+scale_500um],
                [offset,offset],
                'w',
                lw=10,
               )
        ax.annotate('500 μm',
                    (offset+scale_500um/2,offset*2),
                    horizontalalignment='center',
                    verticalalignment='top',
                   )

# delete surplus plots (panels beyond the number of regions)
for ax in axes[len(regions):]:
    fig.delaxes(ax)
fig.tight_layout()
    

### Perform a simple single-peak fit

In [None]:
# Flattened shape: one row per map pixel, one column per radial bin
flat_shape = (recon.shape[0]*recon.shape[1],recon.shape[2])
# Apply the mask (1 inside, NaN outside) to every radial slice and flatten the map
flat_recon = (recon*mask[:,:,np.newaxis]).reshape(flat_shape)
# fit parameters
param = ['amplitude', 'position', 'FWHM', 'background']
# Results dictionary; pixels that are skipped keep NaN
res = {reg:{p:np.full(flat_shape[0],np.nan) for p in param} for reg in regions}

n_px = flat_shape[0]
# Refresh the progress read-out about every 1 % instead of for every pixel
progress_step = max(1, n_px//100)

# A single try around all regions: a KeyboardInterrupt stops the whole fit (useful for
# debugging). The results obtained so far are kept and reshaped in `finally`.
try:
    for k,region in enumerate(regions):
        lower, upper = regions[region]
        roi = (x > lower) & (x < upper)

        print(f'Region {k+1} of {len(regions)}')
        for i,y in enumerate(flat_recon):
            # Masked pixels are NaN, so this comparison is False and they are skipped
            if np.mean(y[roi])>0.:
                amplitude, position, FWHM, background, _y_calc = DM.singlePeakFit(x[roi],y[roi],verbose=False)
                for p,r in zip(param,[amplitude, position, FWHM, background]):
                    res[region][p][i]=r
            if (i+1) % progress_step == 0 or i+1 == n_px:
                print(f'{(i+1)/n_px*100:.2f} %',end='\r')
except KeyboardInterrupt:
    print('Fitting interrupted')
finally:
    # Restore the map shape for every region so the plotting cell works after an interruption
    for region in regions:
        for p in param:
            res[region][p]=res[region][p].reshape((recon.shape[0],recon.shape[1]))

#### Plot peak fit results

In [None]:
# set the active parameter to plot ('amplitude', 'position', 'FWHM', 'background')
active_parameter = 'amplitude'
# Use a logarithmic colour scale (falls back to linear for maps without positive values)
log_scale = True
# Set the number of columns for the figure
cols = 3

# Initialize the figure with enough rows for one panel per region (ceiling division).
# squeeze=False keeps `axes` a 2D array for any rows/cols combination.
rows = max(1, -(-len(regions)//cols))
fig, axes = plt.subplots(rows,cols,sharex=True,sharey=True,squeeze=False)
#fig.set_size_inches(12,8)
fig.suptitle(f'{DM.getScan_id(fname)} - {active_parameter}')
axes = axes.flatten()

# Scale bar values. The bar is skipped when the pixel size is unknown
# (um_per_px is NaN when the file has no 'micrometer_per_px' entry).
show_scale_bar = bool(np.isfinite(um_per_px) and um_per_px > 0)
scale_500um = 500./um_per_px if show_scale_bar else np.nan
offset = recon.shape[1]*0.025

for ax in axes:
    ax.set_xticks([])
    ax.set_yticks([])

for k,region in enumerate(regions):
    im = res[region][active_parameter]
    # plot heatmap
    ax = axes[k]
    ax.set_title(f'{region}')
    ax.grid(False)

    # A log norm needs at least one finite, strictly positive value.
    # Skipped or masked pixels are NaN, and e.g. a background map may be <= 0 everywhere.
    finite_values = im[np.isfinite(im)]
    use_log = log_scale and np.any(finite_values > 0)
    ax.imshow(im, norm='log' if use_log else None)

    if show_scale_bar:
        ax.plot([offset,offset+scale_500um],
                [offset,offset],
                'w',
                lw=10,
               )
        ax.annotate('500 μm',
                    (offset+scale_500um/2,offset*2),
                    horizontalalignment='center',
                    verticalalignment='top',
                   )

# delete surplus plots (panels beyond the number of regions)
for ax in axes[len(regions):]:
    fig.delaxes(ax)
fig.tight_layout()
    