## Reconstruction of synthetic 3D multiband photoemission data

In [None]:
import warnings as wn
wn.filterwarnings("ignore")

import os
import numpy as np
import fuller
from fuller.mrfRec import MrfRec
import matplotlib.pyplot as plt
import matplotlib as mpl
from tqdm import tqdm_notebook as tqdm
%matplotlib inline

# Global matplotlib styling for every figure in this notebook: Arial as the
# sans-serif font, 2-pt axes frames, and Type 42 (TrueType) font embedding
# when figures are exported to PDF/PS.
mpl.rcParams.update({
    'font.family': 'sans-serif',
    'font.sans-serif': 'Arial',
    'axes.linewidth': 2,
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
})

In [None]:
# Load the synthetic WSe2 photoemission data set (nested HDF5 hierarchy)
# together with its energy and momentum axis values.
fdir = r'../../data/synthetic'
synth_file = fdir + r'/synth_data_WSe2_LDA_top8.h5'
data = fuller.utils.loadHDF(synth_file, hierarchy='nested')
axes_params = data['params']
E0, kx, ky = axes_params['E'], axes_params['kx'], axes_params['ky']
# NaNs are replaced by zeros, then the leading axis of the intensity volume is
# moved to the end (presumably energy-first -> energy-last; confirm with the file).
I = np.moveaxis(np.nan_to_num(data['data']['mpes_padded']), 0, -1)
I.shape

In [None]:
# Load the theoretical band structure used to initialize the reconstruction
# (the file name indicates HSE) and display the shape of the band array.
init_file = r'../../data/theory/bands_padded/wse2_hse_bands_padded.h5'
datab = fuller.utils.loadHDF(init_file)
datab['bands_padded'].shape

In [None]:
# Compare the ground-truth bands of the synthetic data (LDA) with the bands
# used to initialize the reconstruction (loaded from the HSE file), along ky.
# NOTE(review): the ground truth is sliced at column 150 but the initialization
# at column 128 -- presumably the same kx position on two differently padded
# grids; confirm against the shapes of both arrays.
gt_bands = data['data']['bands_padded']
init_bands = datab['bands_padded']
n_compare = 6
for i in range(n_compare):
    # Only the last pair is labelled so the legend has one entry per data set;
    # '_nolegend_' makes matplotlib skip the other lines in the legend.
    last = (i == n_compare - 1)
    plt.plot(ky, gt_bands[i, :, 150], c='k',
             label='ground truth (LDA)' if last else '_nolegend_')
    plt.plot(ky, init_bands[i, :, 128], ls='--', c='b',
             label='initialization (HSE)' if last else '_nolegend_')

plt.tick_params(axis='both', length=10, labelsize=15)
plt.ylabel('Energy (eV)', fontsize=15)
plt.legend(bbox_to_anchor=(1,0.2,0.2,0.3), fontsize=15, frameon=False);

In [None]:
# Build the MRF reconstruction model from the intensity volume and its
# energy/momentum axes; the starting eta of 0.12 is overwritten per band later.
mrf = MrfRec(I=I, E=E0, kx=kx, ky=ky, eta=0.12)
# Flag the intensity as not yet normalized
# (presumably consulted by normalizeI -- verify in fuller.mrfRec).
mrf.I_normalized = False

In [None]:
# Normalize the intensity volume with a 20 x 20 x 20 kernel and a 0.01 clip limit.
mrf.normalizeI(clip_limit=0.01, kernel_size=(20,) * 3)

In [None]:
# Pre-tuned per-band hyperparameters; entry k is used for band k:
#   etas -- value assigned to mrf.eta
#   ofs  -- offset passed to mrf.initializeBand
etas = [0.08, 0.10, 0.08, 0.10, 0.10, 0.14, 0.08, 0.08]
ofs = [0.30, 0.10, 0.26, 0.14, 0.30, 0.24, 0.34, 0.14]

In [None]:
# Demonstration: reconstruct a single band (index 1) with its tuned eta and
# offset; the progress output is left on via disable_tqdm=False.
demo_band = 1
mrf.eta, offset = etas[demo_band], ofs[demo_band]
mrf.initializeBand(kx, ky, datab['bands_padded'][demo_band, ...],
                   offset=offset, kScale=1., flipKAxes=False)
# 100 is passed as the first argument (presumably the iteration count -- verify).
mrf.iter_para(100, use_gpu=True, disable_tqdm=False, graph_reset=True)

In [None]:
# Show the outcome (black line = initialization, red line = reconstruction):
# the band overview, then intensity cuts at ky = 0, ky = 0.4, kx = 0 and kx = 0.4.
mrf.plotBands()
for cut in ({'ky': 0}, {'ky': 0.4}, {'kx': 0}, {'kx': 0.4}):
    mrf.plotI(**cut, plotBand=True, plotBandInit=True,
              plotSliceInBand=False, cmapName='coolwarm')

### Reconstruct all bands and save the results

In [None]:
# Make sure the output directory for the reconstructed bands exists.
# makedirs also creates missing parents (os.mkdir would fail if
# ../../results did not exist yet), and exist_ok avoids the race between a
# separate existence check and the creation.
os.makedirs(r'../../results/hse_lda', exist_ok=True)

In [None]:
# Reconstruct band by band and save each result to its own HDF5 file.
# Band k (0-based) is initialized from datab['bands_padded'][k] with the k-th
# tuned eta/offset; the file name and saveBand use the 1-based index k + 1.
if len(etas) != len(ofs):
    # zip() would silently drop the surplus entries and skip those bands.
    raise ValueError('etas and ofs must have the same length, got {} and {}'
                     .format(len(etas), len(ofs)))

out_dir = r'../../results/hse_lda'
for idx, (eta, offset) in enumerate(zip(tqdm(etas), ofs)):
    mrf.eta = eta
    iband = idx + 1
    mrf.initializeBand(kx, ky, datab['bands_padded'][idx, ...],
                       offset=offset, kScale=1., flipKAxes=False)
    mrf.iter_para(100, use_gpu=True, disable_tqdm=True, graph_reset=True)
    # e.g. mrf_rec_band=01_ofs=0.3_eta=0.08.h5 (same name as str()/zfill(2) built).
    fname = 'mrf_rec_band={:02d}_ofs={}_eta={}.h5'.format(iband, offset, eta)
    mrf.saveBand(out_dir + '/' + fname, index=iband)