## Reconstruction of synthetic data using a scaled theoretical band structure (LDA-DFT) as initialization

In [None]:
import warnings as wn
wn.filterwarnings("ignore")

import os
import numpy as np
from numpy import nan_to_num as n2n
import fuller
from mpes import fprocessing as fp, analysis as aly
import matplotlib.pyplot as plt
import matplotlib as mpl
from natsort import natsorted
import glob as g
from mpl_toolkits.axes_grid1 import make_axes_locatable
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from matplotlib.ticker import AutoMinorLocator
%matplotlib inline

In [None]:
# Global figure styling: Arial sans-serif text and thick axes; fonts are embedded as
# TrueType (Type 42) in PDF/PS output
mpl.rcParams['font.family'] = 'sans-serif'
mpl.rcParams['font.sans-serif'] = 'Arial'
mpl.rcParams['axes.linewidth'] = 2
mpl.rcParams['pdf.fonttype'] = 42
mpl.rcParams['ps.fonttype'] = 42

# Create the plot folder (including any missing parent folders) if needed;
# exist_ok avoids failing when it is already there
os.makedirs('../results/figures', exist_ok=True)

In [None]:
# Read the synthetic WSe2 data file and pull out the ground-truth band structure (gt),
# the kx / ky coordinate arrays and the tight k-space mask
data = fuller.utils.loadHDF(r'../data/synthetic/synth_data_WSe2_LDA_top8.h5')
gtbands = data['bands_padded']
kxvals, kyvals = (data[axis_key] for axis_key in ('kx', 'ky'))
msk = data['mask_tight']

### Supplementary Figure 9d

In [None]:
# Load reconstructions and corresponding initializations.
# The scale factors are kept as strings because they are used directly as keys
# (folder names, entries of `data`, and keys of `recons` / `inits`).
scales = ['0.8', '1.2']

recons, inits = {}, {}
for sc in scales:
    files = fuller.utils.findFiles('../data/synthetic/sc={}_lda'.format(sc), fstring=r'/*')
    # One 'bands/Eb' entry per file found in the folder
    recon = [fuller.utils.loadH5Parts(fname, ['bands/Eb'], outtype='vals') for fname in files]

    # Stack the per-file results and drop singleton axes
    recons[sc] = np.squeeze(np.array(recon))
    inits[sc] = data[sc]

In [None]:
# Calculate errors in initialization (einit) and reconstruction (erec).
# Numeric scale factors are derived from `scales` so the error keys always match
# the loaded reconstructions / initializations.
scale_vals = [float(sc) for sc in scales]
# Edge width used both for trimming the ground truth and as the offset passed to
# abserror (presumably pixels per side of the k-space grid -- confirm)
edge = 24
bands_tight = fuller.utils.trim_2d_edge(gtbands, edges=edge, axes=(1, 2))
erec = fuller.metrics.abserror(recons, bands_tight, scale_vals, ofs=edge, mask=msk, outkeys=scale_vals, ret='dict')
einit = fuller.metrics.abserror(inits, bands_tight, scale_vals, ofs=edge, mask=msk, outkeys=scale_vals, ret='dict')

In [None]:
# Plot comparison between reconstructions using scaled DFT calculations
dt = 0.08  # horizontal jitter amplitude in visualization (to separate overlapping points)
# Per-scale plot settings: (scale key, horizontal shift, colour, init. label, recon. label)
series = [
    ('0.8', -dt, 'b', r'Scaled LDA (0.8$\times$)', r'Recon. with 0.8$\times$'),
    ('1.2', dt, 'k', r'Scaled LDA (1.2$\times$)', r'Recon. with 1.2$\times$'),
]

f, ax = plt.subplots(figsize=(8.5, 5.5))
for i in range(8):
    ax.axvline(x=i+1, ls='--', lw=1, c='g', zorder=0)
    # Attach legend labels only on the last band so each series appears once in the legend
    last = (i == 7)
    for key, shift, col, init_label, rec_label in series:
        init_kw = {'label': init_label} if last else {}
        rec_kw = {'label': rec_label} if last else {}
        # Filled markers: initialization error; open markers: reconstruction error (x1000 -> meV axis)
        ax.scatter(i+1+shift, einit[key][i]*1000, s=100, facecolors=col, edgecolors=col, lw=2, zorder=1,
                   **init_kw)
        ax.scatter(i+1+shift, erec[key][i]*1000, s=100, facecolors='w', edgecolors=col, lw=2, zorder=1,
                   **rec_kw)

# Raw string: the mathtext contains backslashes that are not valid Python escapes
ax.set_ylabel(r'Average error $\eta_{\mathrm{avg}}$ wrt ground truth (meV)', fontsize=18)
ax.set_yticks(range(0, 181, 20))
ax.tick_params(axis='both', length=8, width=2, labelsize=15)
ax.set_xticks(range(1, 9))
ax.set_xlabel('Band index', fontsize=18)
ax.set_ylim([0, 180])
ax.set_title('Reconstruction from scaled LDA calculations', fontsize=18)
lg = ax.legend(bbox_transform=ax.transAxes, bbox_to_anchor=(0.45, 0.93), frameon=True, fontsize=15,
               facecolor='w', labelspacing=0.2, handletextpad=0.3)
frame = lg.get_frame()
frame.set_facecolor('w')
frame.set_edgecolor('k')
frame.set_linewidth(2)
plt.savefig(r'../results/figures/sfig_9d.png', bbox_inches='tight', transparent=True, dpi=300)

### Supplementary Figure 9f

In [None]:
# Using known positions of the high-symmetry points (for ways to obtain these,
# see the notebooks in /code/extra/)
G = np.array([127.0, 127.27828129766911])
K = np.array([ 23.83002655, 127.        ])
M = np.array([ 49.38033047, 171.8133136 ])

# Path G-M-K-G with the number of sampling points per segment
pathPoints = np.asarray([G, M, K, G])
nGM, nMK, nKG = 70, 39, 79
segPoints = [nGM, nMK, nKG]
rowInds, colInds, pathInds = aly.points2path(pathPoints[:,0], pathPoints[:,1], npoints=segPoints)
nSegPoints = len(rowInds)


def _path_cut(bands):
    """Sample a band stack along the path given by the global rowInds / colInds.

    The first axis of `bands` is moved to the end and passed to bandpath_map as the
    energy axis (eaxis=2).
    """
    return aly.bandpath_map(np.moveaxis(bands, 0, 2), pathr=rowInds, pathc=colInds, eaxis=2)


pdGT = _path_cut(gtbands)
pdInit = _path_cut(inits['0.8'])
pdMPES = _path_cut(data['mpes_padded'])
Emin, Emax = data['E'].min(), data['E'].max()

# Symmetrize the reconstructed bands: 6-fold rotational symmetrization about the
# grid centre, followed by refsym with nan-aware averaging
symrecbands = []
for i in range(8):
    symmed = fuller.generator.rotosymmetrize(recons['0.8'][i,...], (127.5, 127.5), rotsym=6)[0]
    symrecbands.append(fuller.generator.refsym(symmed[None,...], op='nanmean', pbar=False)[0,...])
symrecbands = np.asarray(symrecbands)
pdRecon = _path_cut(symrecbands)

In [None]:
# Plot comparison between initialization, reconstruction and ground truth along high-symmetry lines
init_offset = 0.12  # vertical shift (eV) applied to the initialization curves for visibility
xaxis = np.arange(rowInds.size)
pos = pathInds.copy()
pos[-1] -= 1  # pull the last vertex index in by one (presumably to stay inside the sampled range)

f, ax = plt.subplots(figsize=(8.3, 6))
imax = ax.imshow(pdMPES, cmap='Blues', origin='lower', extent=[0, nSegPoints, Emin, Emax], aspect=22, vmax=7, zorder=0)
ax.plot(pdGT.T, c='k', zorder=2)
ax.plot(pdRecon.T, c='r', zorder=3)
ax.plot(pdInit.T + init_offset, '--', c='g', zorder=1)

# Re-plot the last band of each kind with a label so each kind appears once in the legend
ax.plot(xaxis, pdGT[-1, :], c='k', zorder=2, label='Ground truth (LDA)')
ax.plot(xaxis, pdInit[-1, :] + init_offset, '--', c='g', zorder=1, label=r'Initial. (0.8$\times$)')
ax.plot(xaxis, pdRecon[-1, :], c='r', zorder=3, label='Reconstruction')

ax.tick_params(axis='y', length=8, width=2, labelsize=15)
ax.tick_params(axis='x', length=0, width=0, labelsize=15, pad=8)
ax.set_xlim([pos[0], pos[-1]])
ax.set_xticks(pos)
# Raw strings: the mathtext contains backslashes that are not valid Python escapes
ax.set_xticklabels([r'$\overline{\Gamma}$', r'$\overline{\mathrm{M}}$',
                    r'$\overline{\mathrm{K}}$', r'$\overline{\Gamma}$'])
ax.set_ylabel('Energy (eV)', fontsize=18)
# Dashed vertical separators at the path vertices (except the final one)
for p in pos[:-1]:
    ax.axvline(x=p, c='k', ls='--', lw=2, dashes=[4, 2])
ax.legend(loc='lower left', frameon=False, fontsize=15,
          facecolor='None', labelspacing=0.2, handletextpad=0.3, borderpad=0)
plt.savefig('../results/figures/sfig_9f.png', bbox_inches='tight', transparent=True, dpi=300)