## Reconstruction of photoemission band structure using Markov Random Field model
### Model setup

In [None]:
import warnings as wn
wn.filterwarnings("ignore")

import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import fuller
from fuller.mrfRec import MrfRec
from tqdm import tnrange

%matplotlib inline
mpl.rcParams['axes.linewidth'] = 2
mpl.rcParams['pdf.fonttype'] = 42
mpl.rcParams['ps.fonttype'] = 42

In [None]:
# Read the preprocessed photoemission dataset from disk
data_path = '../../data/pes/3_smooth.h5'
data = fuller.utils.loadHDF(data_path)

# Only the first 470 samples along the energy axis are kept; the same cut
# is applied to the energy coordinate and to the last axis of the volume.
energy_cut = slice(None, 470)

kx, ky = data['kx'], data['ky']
E = data['E'][energy_cut]
I = data['V'][..., energy_cut]

In [None]:
# Build the Markov Random Field reconstruction model from the loaded
# energy/momentum axes and the intensity volume.
mrf = MrfRec(
    E=E,
    kx=kx,
    ky=ky,
    I=I,
    eta=0.12,  # starting value only; eta is reassigned per band in the reconstruction loop
)
# NOTE(review): presumably marks the intensity as already normalized so the
# model does not normalize it again -- confirm against the fuller documentation.
mrf.I_normalized = True

### Reconstruction

In [None]:
# Reconstruct every band once per DFT calculation used as initialization.
# Band-structure files and hyperparameter tables are matched by position.
path_dft = ['../../data/theory/WSe2_LDA_bands.mat',
            '../../data/theory/WSe2_PBE_bands.mat',
            '../../data/theory/WSe2_PBEsol_bands.mat',
            '../../data/theory/WSe2_HSE06_bands.mat']
path_hyperparam = ['../../data/hyperparameter/LDA.csv',
                   '../../data/hyperparameter/PBE.csv',
                   '../../data/hyperparameter/PBEsol.csv',
                   '../../data/hyperparameter/HSE06.csv']
num_dft = 1 # Number of DFTs to consider, can be up to 4 here, but set to 1 to save computation
num_bands = 14 # Number of bands reconstructed per DFT (sizes both the result array and the band loop)

# Fail early with a clear message instead of an IndexError part-way through the loop
max_dft = min(len(path_dft), len(path_hyperparam))
if not 0 <= num_dft <= max_dft:
    raise ValueError(f'num_dft must be between 0 and {max_dft}, got {num_dft}')

# Result array indexed as (DFT, band, kx, ky)
recon = np.zeros((num_dft, num_bands, len(kx), len(ky)))

for ind_dft in tnrange(num_dft, desc='Initialization'):
    # Load hyperparameter and DFT
    # skiprows=1 skips the first line of the CSV. Columns read below:
    # 1 -> eta, 2 -> offset, 3 -> kScale (column 0 is not used in this cell).
    hyperparam = np.loadtxt(path_hyperparam[ind_dft], delimiter=',', skiprows=1)
    kx_dft, ky_dft, E_dft = mrf.loadBandsMat(path_dft[ind_dft])

    for ind_band in tnrange(num_bands, desc='Band'):
        # Set eta and initialization
        mrf.eta = hyperparam[ind_band, 1]
        # Only every second entry along the first axis of E_dft is used
        # NOTE(review): presumably the file stores bands in pairs -- confirm.
        # The constant 0.65 is added to every per-band offset from the table.
        mrf.initializeBand(kx=kx_dft, ky=ky_dft, Eb=E_dft[2 * ind_band, ...], kScale=hyperparam[ind_band, 3],
                           offset=hyperparam[ind_band, 2] + 0.65, flipKAxes=True)

        # Perform optimization
        mrf.iter_para(150, disable_tqdm=True)

        # Store result
        recon[ind_dft, ind_band, ...] = mrf.getEb()

### Results

In [None]:
# Plot a constant-ky slice of the data for each DFT, overlaying the
# initialization and the reconstructed band for every band.
dft_name = ['LDA', 'PBE', 'PBEsol', 'HSE06']

# Mask to only plot Brillouin zone
mask = np.load('../../data/processed/WSe2_Brillouin_Zone_Mask.npy')
mrf.I = mrf.I * mask[:, :, None]
ky_val = 0
# Index of the ky grid point closest to the requested slice value
ind_ky = np.argmin(np.abs(mrf.ky - ky_val))

# Loop over initializations and bands
for ind_dft in range(num_dft):
    mrf.plotI(ky=ky_val, cmapName='coolwarm')
    plt.title(dft_name[ind_dft], fontsize=26)
    plt.xlim((-1.35, 1.3))
    kx_dft, ky_dft, E_dft = mrf.loadBandsMat(path_dft[ind_dft])
    # Reload the hyperparameters belonging to THIS DFT. Relying on the
    # `hyperparam` global left over from the reconstruction loop would apply
    # the last DFT's scale/offset to every panel whenever num_dft > 1.
    hyperparam = np.loadtxt(path_hyperparam[ind_dft], delimiter=',', skiprows=1)
    # Band count is taken from the stored reconstructions so both stay in sync
    for ind_band in range(recon.shape[1]):
        # Optional overlay of the DFT band without per-band scale/offset:
        #mrf.initializeBand(kx=kx_dft, ky=ky_dft, Eb=E_dft[2 * ind_band,...], kScale=1,
        #                   offset=0.65, flipKAxes=True)
        #E0 = mrf.E[mrf.indE0[:, ind_ky]]
        #plt.plot(mrf.kx, E0 * mask[:, ind_ky], 'k', linewidth=2.0,
        #         label='DFT' if ind_band==0 else None, zorder=3)
        # Re-create the initialization that was used for this band's reconstruction
        mrf.initializeBand(kx=kx_dft, ky=ky_dft, Eb=E_dft[2 * ind_band, ...], kScale=hyperparam[ind_band, 3],
                           offset=hyperparam[ind_band, 2] + 0.65, flipKAxes=True)
        # Map the per-kx energy indices at the chosen ky to energy values
        E0 = mrf.E[mrf.indE0[:, ind_ky]]
        # Label only the first band so each curve type appears once in the legend
        plt.plot(mrf.kx, E0 * mask[:, ind_ky], 'c--', linewidth=2.0,
                 label='Initialization' if ind_band == 0 else None, zorder=2)
        plt.plot(mrf.kx, recon[ind_dft, ind_band, :, ind_ky] * mask[:, ind_ky], 'r', linewidth=2.0,
                 label='Reconstruction' if ind_band == 0 else None, zorder=1)
    plt.legend(loc=4, prop={'size': 14}, framealpha=1)
        