### Guided Computation of Fused Multi-Modal Electron Microscopy

```{admonition} Step 1: Python Imports
First, we import standard Python packages alongside our custom data fusion module, [fusion_utils](https://github.com/jtschwar/Multi-Modal-2D-Data-Fusion/blob/170fea3292da7e6390bfff7236610eb0c8077ff7/EDX/fusion_utils.py). The fusion_utils module provides three functions and one class of TV functions: save_data, plot_convergence, create_weighted_measurement_matrix, and tvlib. The save_data function saves the fused images as .tif files and all the parameters and matrix data as a .h5 file. The plot_convergence function plots the convergence of the cost functions. The create_weighted_measurement_matrix function creates the summation matrix built from the elemental weights $Z$ and the exponent $\gamma$, as described by the first term in the cost function. Finally, the tvlib class performs the Fast Gradient Projection (FGP) method of TV image denoising, which removes noise from the output images while preserving edges.
```
:::{code-cell} ipython3
from scipy.sparse import spdiags
import matplotlib.pyplot as plt
import fusion_utils as utils
from tqdm import tqdm 
import numpy as np
:::

```{admonition} Step 2: Load your data
Load your inelastic and elastic data. Define the element names and their corresponding atomic numbers (stored in elem_weights). For the sake of this tutorial we use a .npz file; for .dm3, .dm4, .emd, or another EM file format, just extract the 2D image data into arrays. Your elastic data should be stored in the HAADF variable, and each inelastic map (EDX/EELS) is read into the edsMap variable.
```

:::{code-cell} ipython3
data = np.load('PTO_Trilayer_dataset.npz')
# Define element names and their atomic numbers (used as weights)
elem_names=['Sc', 'Dy', 'O']
elem_weights=[21,66,8]
# Parse elastic HAADF data and inelastic chemical maps based on element index from line above
HAADF = data['HAADF']
xx = np.array([],dtype=np.float32)
for ee in elem_names:

	# Read Chemical Map for Element "ee"
	edsMap = data[ee]

	# Set Noise Floor to Zero and Normalize Chemical Maps
	edsMap -= np.min(edsMap); edsMap /= np.max(edsMap)

	# Concatenate Chemical Map to Variable of Interest
	xx = np.concatenate([xx,edsMap.flatten()])
:::
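If your data starts in another format, one option is to pack the extracted 2D arrays into a .npz file with the same keys this tutorial expects. The sketch below is only an illustration: the random arrays stand in for images you would extract yourself, and the image size and the file name `my_dataset.npz` are hypothetical.

```python
import os
import tempfile
import numpy as np

rng = np.random.default_rng(0)
shape = (64, 64)  # hypothetical image size

# Stand-ins for 2D arrays extracted from your own microscope files
haadf_img = rng.random(shape).astype(np.float32)
chem_maps = {name: rng.random(shape).astype(np.float32)
             for name in ['Sc', 'Dy', 'O']}

# One key for the HAADF image plus one key per element name
npz_path = os.path.join(tempfile.mkdtemp(), 'my_dataset.npz')
np.savez(npz_path, HAADF=haadf_img, **chem_maps)

# Reload and confirm the keys match what the loading cell reads
with np.load(npz_path) as demo:
    keys = sorted(demo.files)
    loaded_shape = demo['HAADF'].shape
print(keys, loaded_shape)
```

The key names must match the entries of elem_names, since the loading loop indexes the file with each element name.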

```{admonition} Step 3: Reshape your data
Run the following cell without changing anything. This cell reshapes your data and performs the summation operation outlined in the cost function, which is described in the Mathematical Methods section.
```

```{danger} Danger!
Do not change the code below.
```

:::{code-cell} ipython3
# Make Copy of Raw Measurements for Poisson Maximum Likelihood Term 
xx0 = xx.copy()

gamma = 1.6 # Don't change this parameter from 1.6

# Image Dimensions
(nx, ny) = edsMap.shape; nPix = nx * ny
nz = len(elem_names)

# C++ TV Min Regularizers
reg = utils.tvlib(nx,ny)

# Data Subtraction and Normalization 
HAADF -= np.min(HAADF); HAADF /= np.max(HAADF)
HAADF=HAADF.flatten()

# Create Summation Matrix
A = utils.create_weighted_measurement_matrix(nx,ny,nz,elem_weights,gamma,1)
:::
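To build intuition for what a summation matrix of this kind does, here is a minimal sketch. It is not the fusion_utils implementation: the helper `sketch_summation_matrix` and its normalization (weights divided by their sum) are assumptions made for illustration, and the real function may scale the weights differently.

```python
import numpy as np
from scipy.sparse import identity, hstack

def sketch_summation_matrix(n_pix, z_numbers, gamma):
    """Hypothetical stand-in: one weighted identity block per element."""
    weights = np.asarray(z_numbers, dtype=np.float64) ** gamma
    weights /= weights.sum()  # assumed normalization, for illustration only
    blocks = [w * identity(n_pix, format='csr') for w in weights]
    # Shape (n_pix, n_elements * n_pix): sums the weighted element maps per pixel
    return hstack(blocks, format='csr')

n_pix_demo = 4
A_demo = sketch_summation_matrix(n_pix_demo, [21, 66, 8], 1.6)

# Three flattened element maps stacked end to end, as in the tutorial's xx
x_demo = np.ones(3 * n_pix_demo)
haadf_model = A_demo.dot(x_demo ** 1.6)
print(A_demo.shape, haadf_model)
```

Each row of the matrix picks out the same pixel in every element map and adds the weighted values, which is how the stacked chemical maps are mapped onto a single HAADF-like image.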

```{admonition} Optional
:class: tip 
Plot your raw elastic/inelastic data
```

:::{code-cell} ipython3
fig, ax = plt.subplots(2,len(elem_names)+1,figsize=(12,8))
ax = ax.flatten()
ax[0].imshow(HAADF.reshape(nx,ny),cmap='gray'); ax[0].set_title('HAADF'); ax[0].axis('off')
ax[1+len(elem_names)].imshow(HAADF.reshape(nx,ny)[70:130,25:85],cmap='gray'); ax[1+len(elem_names)].set_title('HAADF Cropped'); ax[1+len(elem_names)].axis('off')

for ii in range(len(elem_names)):
    ax[ii+1].imshow(xx0[ii*(nx*ny):(ii+1)*(nx*ny)].reshape(nx,ny),cmap='gray'); ax[ii+1].set_title(elem_names[ii]); ax[ii+1].axis('off')
    ax[ii+2+len(elem_names)].imshow(xx0[ii*(nx*ny):(ii+1)*(nx*ny)].reshape(nx,ny)[70:130,25:85],cmap='gray'); ax[ii+2+len(elem_names)].set_title(elem_names[ii]+' Cropped'); ax[ii+2+len(elem_names)].axis('off')
plt.show()
:::

```{admonition} Step 4: Fine-tune your weights
Fine-tune the weights for each of the three parts of the cost function. The first weight, lambdaHAADF, does not need tuning: it is set to the inverse of the number of elements. The second weight, lambdaChem, controls the data consistency term; it can range from 0 to 1, and we typically find values between 0.05 and 0.3 to be ideal. The final weight, lambdaTV, dictates the strength of the Total Variation denoising. To preserve fine features while removing noise, we find that a lambdaTV value below 0.2 is ideal. nIter is the number of iterations the algorithm runs for. Although we typically see convergence within 10 iterations, we recommend using 30 iterations to confirm that the cost functions stay converged. The variable bkg represents the background; we use it to perform a minor background subtraction to reduce noise, and we recommend keeping it small. Finally, the regularize flag lets you turn the TV regularization off, and nIter_TV sets the number of iterations the FGP TV algorithm runs for.
```
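Reading the update in Step 5, each gradient step is consistent with minimizing an objective of the following form. This is a sketch inferred from the code rather than a restatement of the Mathematical Methods section; the labels $b_H$ (the normalized HAADF image), $b$ (the normalized chemical maps), and $\varepsilon$ (a small offset, whose role bkg plays in the update) are introduced here for illustration.

$$
\Psi(x) = \frac{\lambda_{HAADF}}{2}\,\lVert A x^{\gamma} - b_H \rVert_2^2
+ \lambda_{Chem} \sum_j \left( x_j - b_j \log(x_j + \varepsilon) \right)
+ \lambda_{TV}\,\lVert x \rVert_{TV}
$$

The gradient of the first two terms matches the expression subtracted from xx in the main loop:

$$
\lambda_{HAADF}\,\gamma\,\mathrm{diag}\!\left(x^{\gamma-1}\right) A^{T}\left(A x^{\gamma} - b_H\right)
+ \lambda_{Chem}\left(1 - \frac{b}{x + \varepsilon}\right)
$$

The TV term is not differentiated; it is handled by the FGP denoising step applied to each element map after the gradient update.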

:::{code-cell} ipython3
# Convergence Parameters
lambdaHAADF = 1/nz # Do not modify this
lambdaChem = 0.08
lambdaTV = 0.15; #Typically between 0.001 and 1
nIter = 30 # Typically 10-15 will suffice
bkg = 2.4e-1

# FGP TV Parameters
regularize = True; nIter_TV = 3; 
:::

```{admonition} Step 5: Run the Fused Multi-Modal algorithm
Run the algorithm and use the output cost functions to determine whether the weights you chose previously allow for convergence. If you do not see convergence for one or more of the cost functions, tune the associated lambda value for that function. For example, if the lambdaChem cost function does not show convergence, try changing the lambdaChem value to something lower. Repeat until all 3 cost functions show convergence.
```

```{danger} Danger!
Do not change the code below.
```

:::{code-cell} ipython3
xx = xx0.copy()

# Auxiliary Functions
lsqFun = lambda inData : 0.5 * np.linalg.norm(A.dot(inData**gamma) - HAADF) **2
poissonFun = lambda inData : np.sum(xx0 * np.log(inData + 1e-8) - inData)

# Main Loop
costHAADF = np.zeros(nIter,dtype=np.float32); costChem = np.zeros(nIter, dtype=np.float32); costTV = np.zeros(nIter, dtype=np.float32);
for kk in tqdm(range(nIter)):
	# Gradient step on the HAADF (least-squares) and chemical (Poisson) terms
	xx -=  gamma * spdiags(xx**(gamma - 1), [0], nz*nx*ny, nz*nx*ny) * lambdaHAADF * A.transpose() * (A.dot(xx**gamma) - HAADF) + lambdaChem * (1 - xx0 / (xx + bkg))
	xx[xx<0] = 0

	# Regularization 
	if regularize:
		for zz in range(nz):
			xx[zz*nPix:(zz+1)*nPix] = reg.fgp_tv( xx[zz*nPix:(zz+1)*nPix].reshape(nx,ny), lambdaTV, nIter_TV).flatten()
			costTV[kk] += reg.tv( xx[zz*nPix:(zz+1)*nPix].reshape(nx,ny) )
			
	# Measure Cost Function
	costHAADF[kk] = lsqFun(xx); costChem[kk] = poissonFun(xx)
:::

```{admonition} Step 6: Assess Convergence
Assess convergence by confirming that all 3 cost functions adequately converge like the example plot shown below. You should see all three curves start high and exponentially decay into a plateau where no change is seen.
```

:::{figure} ./figs/Figure_3_Convergence.png
:name: convergence
:width: 700px
Convergence of 3 subparts of the multi-modal cost function.  The top plot represents the first term that is dictated by $\lambda_{HAADF}$.  The middle plot represents the second term that is dictated by $\lambda_{Chem}$. The bottom plot represents the third TV term that is dictated by $\lambda_{TV}$. 
:::
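Alongside visual inspection, a numerical plateau check can be a useful sanity test. The helper below is a hypothetical sketch, not part of fusion_utils; the window length and tolerance are arbitrary assumptions you would adjust for your own data.

```python
import numpy as np

def has_plateaued(cost, window=5, rel_tol=1e-2):
    """Return True if the last `window` steps each change the cost by
    less than `rel_tol` relative to the value at the start of the window."""
    cost = np.asarray(cost, dtype=np.float64)
    if cost.size < window + 1:
        return False  # not enough iterations to judge
    recent = cost[-(window + 1):]
    scale = max(abs(recent[0]), np.finfo(np.float64).eps)
    return bool(np.max(np.abs(np.diff(recent))) / scale < rel_tol)

# Synthetic curves: one decays into a plateau, one is still falling
decaying = 1.0 + np.exp(-np.arange(30))
still_falling = np.linspace(10.0, 1.0, 30)

plateau_ok = has_plateaued(decaying)
plateau_bad = has_plateaued(still_falling)
print(plateau_ok, plateau_bad)
```

In the tutorial you could apply the same check to costHAADF, costChem, and costTV after the main loop, treating a failed check as a hint to revisit the corresponding lambda value.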

```{admonition} Be careful with $\lambda_{TV}$!
:class: attention
You may see convergence with an over- or under-weighted TV term. Under-weighting the TV term results in noisy reconstructions, while over-weighting results in blurring and loss of features in the image. Below is an example of the output images from under-weighted, over-weighted, and well-chosen TV weights.
```

:::{figure} ./figs/Figure_4_TV.png
:name: TV_weights
:width: 700px
Comparison of TV weighting across different chemistries and HAADF. Too low a TV weight results in noise and artifacts across the images. A proper TV weight preserves fine features like the dumbbell shape of the Dy particles while reducing noise. Too high a TV weight oversmooths the image, resulting in loss of features important for analysis.
:::
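To see why the TV term responds to noise, it can help to compute the total variation of a clean and a noisy image directly. The sketch below uses a plain NumPy isotropic TV with forward differences; it is an illustration only and is not the tvlib implementation, which may use a different discretization.

```python
import numpy as np

def isotropic_tv(img):
    """Sum of gradient magnitudes, using forward differences with a
    replicated last row/column so the output matches the image size."""
    img = np.asarray(img, dtype=np.float64)
    dx = np.diff(img, axis=1, append=img[:, -1:])
    dy = np.diff(img, axis=0, append=img[-1:, :])
    return float(np.sum(np.sqrt(dx ** 2 + dy ** 2)))

rng = np.random.default_rng(0)

# Clean test image: a single sharp vertical edge
clean = np.zeros((32, 32))
clean[:, 16:] = 1.0

# Same image with additive Gaussian noise
noisy = clean + 0.2 * rng.standard_normal(clean.shape)

tv_clean = isotropic_tv(clean)
tv_noisy = isotropic_tv(noisy)
print(tv_clean, tv_noisy)
```

A sharp edge contributes little to the total variation, while pixel-to-pixel noise contributes a lot, which is why penalizing TV suppresses noise before it removes edges. Pushing the weight too high eventually flattens real features as well.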

:::{code-cell} ipython3
# Display Cost Functions and Descent Parameters
utils.plot_convergence(costHAADF, lambdaHAADF, costChem, lambdaChem, costTV, lambdaTV)
# Show Reconstructed Signal
fig, ax = plt.subplots(2,len(elem_names)+1,figsize=(12,8))
ax = ax.flatten()
ax[0].imshow((A.dot(xx**gamma)).reshape(nx,ny),cmap='gray'); ax[0].set_title('HAADF'); ax[0].axis('off')
ax[1+len(elem_names)].imshow((A.dot(xx**gamma)).reshape(nx,ny)[70:130,25:85],cmap='gray'); ax[1+len(elem_names)].set_title('HAADF Cropped'); ax[1+len(elem_names)].axis('off')

for ii in range(len(elem_names)):
    ax[ii+1].imshow(xx[ii*(nx*ny):(ii+1)*(nx*ny)].reshape(nx,ny),cmap='gray'); ax[ii+1].set_title(elem_names[ii]); ax[ii+1].axis('off')
    ax[ii+2+len(elem_names)].imshow(xx[ii*(nx*ny):(ii+1)*(nx*ny)].reshape(nx,ny)[70:130,25:85],cmap='gray'); ax[ii+2+len(elem_names)].set_title(elem_names[ii]+' Cropped'); ax[ii+2+len(elem_names)].axis('off')
plt.show()
:::

```{admonition} Step 7: Save your data
Define the folder you would like to save to. The output images and data will be saved in .tif and .h5 file formats.
```

:::{code-cell} ipython3
save_folder_name='test'
utils.save_data(save_folder_name, xx0, xx, HAADF, A.dot(xx**gamma), elem_names, nx, ny, costHAADF, costChem, costTV, lambdaHAADF, lambdaChem, lambdaTV, gamma)
:::