In [None]:
#| label: app:tutorialtwo
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display
from tqdm.notebook import tqdm
import data.fusion_utils as utils
from scipy.sparse import spdiags
from tqdm import tqdm 
import h5py
data = 'Co3O4_Mn3O4.h5'
# Define element names and their atomic weights
elem_names = ['Co', 'Mn', 'O']
elem_weights = [27, 25, 8]


def _as_float_array(raw):
    """Return ``raw`` as an ndarray with a floating-point dtype.

    Floating inputs keep their dtype; any other dtype (e.g. integer) is
    promoted to float64 so the in-place true division used for
    normalization cannot fail with a casting error.
    """
    arr = np.array(raw)
    if not np.issubdtype(arr.dtype, np.floating):
        arr = arr.astype(np.float64)
    return arr


# Parse elastic HAADF data and the inelastic chemical map stored under each
# element name, opening the HDF5 file only once.
xx = np.array([], dtype=np.float32)
with h5py.File(data, 'r') as h5_file:
    HAADF = _as_float_array(h5_file['HAADF'])

    for ee in elem_names:
        # Read chemical map
        chemMap = _as_float_array(h5_file[ee])

        # Check if chemMap has the same dimensions as HAADF
        if chemMap.shape != HAADF.shape:
            raise ValueError(f"The dimensions of {ee} chemical map do not match HAADF dimensions.")

        # Set noise floor to zero and normalize to [0, 1]. A constant map has
        # zero range; leave it at zero instead of dividing 0/0 into NaNs.
        chemMap -= np.min(chemMap)
        peak = np.max(chemMap)
        if peak > 0:
            chemMap /= peak

        # Concatenate chemical map to the variable of interest
        xx = np.concatenate([xx, chemMap.flatten()])
# Make copy of raw measurements for the Poisson maximum-likelihood term
xx0 = xx.copy()

# Incoherent linear imaging for elastic scattering scales with atomic number Z raised to γ ∈ [1.4, 2]
gamma = 1.6

# Image dimensions. Every chemical map was checked against HAADF's shape while
# loading, so take the shape from HAADF rather than the leftover loop variable.
(nx, ny) = HAADF.shape
nPix = nx * ny
nz = len(elem_names)

# C++ TV Min Regularizers
reg = utils.tvlib(nx, ny)

# Data subtraction and normalization to [0, 1]. Non-float data is promoted
# first (in-place true division on an integer array raises), and a constant
# image is left at zero instead of being divided by zero.
if not np.issubdtype(HAADF.dtype, np.floating):
    HAADF = HAADF.astype(np.float64)
HAADF -= np.min(HAADF)
haadf_peak = np.max(HAADF)
if haadf_peak > 0:
    HAADF /= haadf_peak
HAADF = HAADF.flatten()

# Create Summation Matrix
A = utils.create_weighted_measurement_matrix(nx, ny, nz, elem_weights, gamma, 1)

# Convergence Parameters
lambdaHAADF = 1 / nz  # Do not modify this
lambdaChem_default = 1e-3
nIter_default = 30  # Typically 10-15 will suffice
lambdaTV_default = 0.006  # Typically between 0.001 and 1
bkg = 1e-8
regularize = True
nIter_TV_default = 5
# Widgets for the parameters.
# Integer sliders use the 'd' readout format ('.3f' would render iteration
# counts as e.g. "30.000"); float sliders show four decimals so that their
# 0.0001 step (and minimum) is visible in the readout.
def _slider_options(description, readout_format):
    """Return the keyword arguments shared by every parameter slider.

    A fresh ``Layout`` is created per call so sliders do not share one
    layout object.
    """
    return dict(
        description=description,
        style={'description_width': 'initial'},
        layout=widgets.Layout(width='400px'),
        readout_format=readout_format,
    )


lambdaChem_slider = widgets.FloatSlider(value=lambdaChem_default, min=0.0001, max=0.01, step=0.0001, **_slider_options('lambdaChem', '.4f'))
lambdaTV_slider = widgets.FloatSlider(value=lambdaTV_default, min=0.0001, max=0.1, step=0.0001, **_slider_options('lambdaTV', '.4f'))
nIter_slider = widgets.IntSlider(value=nIter_default, min=10, max=50, step=1, **_slider_options('# Cost Function Iterations', 'd'))
nIter_TV_slider = widgets.IntSlider(value=nIter_TV_default, min=1, max=10, step=1, **_slider_options(' # TV Iterations', 'd'))

# Function to update plots

# Zoomed-in region (row range, column range) shown in the second figure row.
# Slicing clips these ranges to the image; if none of the window lies inside
# the image, the full map is shown instead of an empty panel.
CROP_ROWS = (40, 100)
CROP_COLS = (50, 110)


def _crop(img):
    """Return the CROP_ROWS x CROP_COLS window of ``img``, or ``img`` if that window is empty."""
    window = img[CROP_ROWS[0]:CROP_ROWS[1], CROP_COLS[0]:CROP_COLS[1]]
    return window if window.size else img


def update_plots(lambdaChem, lambdaTV, nIter, nIter_TV):
    """Run the fused reconstruction with the given parameters and plot the result.

    Invoked through ``widgets.interact`` with the current slider values. Reads
    module-level state: measurements (``xx0``, ``HAADF``, ``A``), geometry
    (``nx``, ``ny``, ``nz``, ``nPix``), fixed settings (``gamma``,
    ``lambdaHAADF``, ``bkg``, ``regularize``) and the TV helper ``reg``.

    Args:
        lambdaChem: weight of the Poisson (chemical-map) gradient term.
        lambdaTV: regularization strength passed to ``reg.fgp_tv``.
        nIter: number of outer gradient-descent iterations.
        nIter_TV: iteration count passed to ``reg.fgp_tv`` for each element.
    """
    # Background noise subtraction for improved convergence: start from the
    # raw measurements with every value below 0.2 set to zero.
    xx = np.where(xx0 < .2, 0, xx0)

    # Auxiliary functions for measuring the cost-function components
    def lsqFun(inData):
        return 0.5 * np.linalg.norm(A.dot(inData**gamma) - HAADF) ** 2

    def poissonFun(inData):
        return np.sum(xx0 * np.log(inData + 1e-8) - inData)

    # Per-iteration history of the three cost-function components
    costHAADF = np.zeros(nIter, dtype=np.float32)
    costChem = np.zeros(nIter, dtype=np.float32)
    costTV = np.zeros(nIter, dtype=np.float32)

    nVox = nz * nx * ny
    for kk in tqdm(range(nIter)):
        # Gradient step on the first two optimization functions $\Psi_1$ and $\Psi_2$.
        # The diagonal matrix carries the chain-rule factor x**(gamma - 1) of
        # d(x**gamma)/dx; operand order matches the original single expression.
        chainRule = spdiags(xx**(gamma - 1), [0], nVox, nVox)
        residual = A.dot(xx**gamma) - HAADF
        gradHAADF = gamma * chainRule * lambdaHAADF * A.transpose() * residual
        gradChem = lambdaChem * (1 - xx0 / (xx + bkg))
        xx -= gradHAADF + gradChem

        # Enforce positivity constraint
        xx[xx < 0] = 0

        # FGP regularization, applied to each element's map separately, if turned on
        if regularize:
            for zz in range(nz):
                elemSlice = slice(zz * nPix, (zz + 1) * nPix)
                xx[elemSlice] = reg.fgp_tv(xx[elemSlice].reshape(nx, ny), lambdaTV, nIter_TV).flatten()

                # Measure TV cost function
                costTV[kk] += reg.tv(xx[elemSlice].reshape(nx, ny))

        # Measure $\Psi_1$ and $\Psi_2$ cost functions
        costHAADF[kk] = lsqFun(xx)
        costChem[kk] = poissonFun(xx)

    # Display cost functions and descent parameters
    utils.plot_convergence(costHAADF, lambdaHAADF, costChem, lambdaChem, costTV, lambdaTV)

    # Show reconstructed signal: full maps on the top row, zoomed crops below
    nElem = len(elem_names)
    _, ax = plt.subplots(2, nElem, figsize=(12, 8))
    ax = ax.flatten()

    for ii, name in enumerate(elem_names):
        elemMap = xx[ii * nPix:(ii + 1) * nPix].reshape(nx, ny)

        ax[ii].imshow(elemMap, cmap='gray')
        ax[ii].set_title(name)
        ax[ii].axis('off')

        ax[ii + nElem].imshow(_crop(elemMap), cmap='gray')
        ax[ii + nElem].set_title(name + ' Cropped')
        ax[ii + nElem].axis('off')

    plt.show()
    
# Launch the interactive UI: interact binds each slider to the same-named
# argument of update_plots and calls it with the current slider values.
controls = {
    'lambdaChem': lambdaChem_slider,
    'lambdaTV': lambdaTV_slider,
    'nIter': nIter_slider,
    'nIter_TV': nIter_TV_slider,
}
widgets.interact(update_plots, **controls)
plt.show()