## Reconstruction of synthetic data using different DFT band structures (PBE, PBEsol, HSE06) as initialization

In [None]:
import warnings as wn
wn.filterwarnings("ignore")

import os
import numpy as np
from numpy import nan_to_num as n2n
import fuller
from mpes import fprocessing as fp, analysis as aly
import matplotlib.pyplot as plt
import matplotlib as mpl
from natsort import natsorted
import glob as g
from mpl_toolkits.axes_grid1 import make_axes_locatable
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from matplotlib.ticker import AutoMinorLocator
%matplotlib inline

In [None]:
# Global figure style: Arial sans-serif text and 2-pt axes lines.
mpl.rcParams['font.family'] = 'sans-serif'
mpl.rcParams['font.sans-serif'] = 'Arial'
mpl.rcParams['axes.linewidth'] = 2
# Font type 42 embeds TrueType fonts in PDF/PS output (instead of Type 3), keeping text editable.
mpl.rcParams['pdf.fonttype'] = 42
mpl.rcParams['ps.fonttype'] = 42

# Create the plot folder if needed. makedirs also creates the parent '../results' when it is
# missing (os.mkdir would raise FileNotFoundError), and exist_ok makes re-running this cell a no-op.
os.makedirs('../results/figures', exist_ok=True)

In [None]:
# Synthetic data set together with its ground-truth (gt) band structure (file name indicates an LDA calculation)
data = fuller.utils.loadHDF(r'../data/synthetic/synth_data_WSe2_LDA_top8.h5')
gtbands, msk = data['bands_padded'], data['mask_tight']  # padded gt bands and the tight evaluation mask
kxvals = data['kx']
kyvals = data['ky']

### Supplementary Figure 9e

In [None]:
# DFT band structures used as initializations, one per exchange-correlation functional (XCF)
xcfs = ['pbe', 'pbesol', 'hse']
# Per-XCF energy offset (eV) added to each band structure so that it is zeroed at the K points
eshifts = {'pbe':0.33063, 'pbesol':0.49865, 'hse':0.10955}

inits = {
    xc: fuller.utils.loadHDF(f'../data/theory/bands_padded/wse2_{xc}_bands_padded.h5')['bands_padded'] + eshifts[xc]
    for xc in xcfs
}

In [None]:
# Calculate errors in initialization (einit) and reconstruction (erec):
# for each XCF, the per-band RMS deviation from the ground truth inside the mask `msk`.
NBANDS = 8  # number of reconstructed bands per XCF (folders mrf_rec_band=00 ... 07)
reconbands, erec, einit = {}, {}, {}
nk = np.sum(~np.isnan(msk))  # number of non-NaN mask pixels, used to normalize the squared error
# NOTE(review): nk counts the full mask, while the squared errors below are passed through
# trim_2d_edge first. If trimming discards mask pixels, the RMS is normalized by too many points. Confirm.

for xc in xcfs:

    # Load the per-band MRF reconstructions that were initialized with this XCF's band structure
    folderstr = xc + '_lda'
    recons = {}
    for ib in range(NBANDS):
        stepstr = str(ib).zfill(2)
        banddir = f'../data/synthetic/{folderstr}/mrf_rec_band={stepstr}'
        files = fuller.utils.findFiles(banddir, fstring=r'*')
        if len(files) == 0:
            # Fail with a readable message instead of an IndexError on files[0]
            raise FileNotFoundError(f'No reconstruction files found in {banddir}')
        recons[stepstr] = np.squeeze(np.array([fuller.utils.loadH5Parts(files[0], ['bands/Eb'], outtype='vals')]))

    # Stack bands in index order ('00', '01', ... were inserted in that order)
    reconbands[xc] = np.asarray(list(recons.values()))

    errinit, errrecon = [], []
    for ib in range(NBANDS):
        ediff = (inits[xc][ib,...] - gtbands[ib,...])**2
        ediffrec = (reconbands[xc][ib,...] - gtbands[ib,...])**2
        # presumably discards a 24-pixel border before averaging (see fuller.utils.trim_2d_edge)
        ediff = fuller.utils.trim_2d_edge(ediff, edges=24)
        ediffrec = fuller.utils.trim_2d_edge(ediffrec, edges=24)
        # NaNs (outside the mask / trimmed region) are zeroed by nan_to_num before summing
        errinit.append(np.sqrt(np.sum(n2n(msk*ediff) / nk)))
        errrecon.append(np.sqrt(np.sum(n2n(msk*ediffrec) / nk)))

    einit[xc] = np.array(errinit)
    erec[xc] = np.array(errrecon)

In [None]:
# Plot comparison between reconstruction using different DFT theories
# (filled markers: error of the DFT initialization; open markers: error after reconstruction; values in meV)
dt = 0.12  # horizontal jitter amplitude in visualization (to separate overlapping points)
# (key in einit/erec, marker colour, x-offset of the reconstruction marker, legend name)
xc_styles = [('pbe', 'b', -dt, 'PBE'),
             ('pbesol', 'k', 0, 'PBEsol'),
             ('hse', 'm', dt, 'HSE06')]
nbands_plot = 8

f, ax = plt.subplots(figsize=(7, 10))
for i in range(nbands_plot):
    ax.axvline(x=i+1, ls='--', lw=1, c='g', zorder=0)
    # Only the markers of the last band carry legend labels, so that each legend entry
    # appears once; labels starting with an underscore are skipped by ax.legend().
    is_last = i == nbands_plot - 1
    for xc, color, xoff, name in xc_styles:
        ax.scatter(i+1, einit[xc][i]*1000, s=100, facecolors=color, edgecolors=color, lw=2, zorder=1,
                   label=(name + ' calc.') if is_last else '_nolegend_')
        ax.scatter(i+1+xoff, erec[xc][i]*1000, s=100, facecolors='w', edgecolors=color, lw=2, zorder=1,
                   label=(name + ' recon.') if is_last else '_nolegend_')

# Raw string: '\e' and '\m' are invalid escape sequences in a normal literal (SyntaxWarning on Python >= 3.12)
ax.set_ylabel(r'Average error $\eta_{\mathrm{avg}}$ wrt ground truth (meV)', fontsize=18)
ax.set_yticks(range(0, 551, 50))
ax.tick_params(axis='both', length=8, width=2, labelsize=15)
ax.set_xticks(range(1, 9))
ax.set_xlabel('Band index', fontsize=18)
ax.set_title('Reconstruction from other DFT calculations', fontsize=18)
lg = ax.legend(loc='best', bbox_to_anchor=(0.56, 0.42), bbox_transform=ax.transAxes, frameon=True, fontsize=15,
               facecolor='w', labelspacing=0.2, handletextpad=0.3)
frame = lg.get_frame()
frame.set_facecolor('w')
frame.set_edgecolor('k')
frame.set_linewidth(2)
plt.savefig(r'../results/figures/sfig_9e.png', bbox_inches='tight', transparent=True, dpi=300)

### Supplementary Figure 9g

In [None]:
# High-symmetry points in pixel coordinates, using known positions
# (for ways to obtain these, see the notebooks in /code/extra/)
G = np.array([127.0, 127.27828129766911])
K = np.array([23.83002655, 127.])
M = np.array([49.38033047, 171.8133136])

# Sample the momentum path G -> M -> K -> G with the given number of points per segment
pathPoints = np.asarray([G, M, K, G])
nGM, nMK, nKG = 70, 39, 79
segPoints = [nGM, nMK, nKG]
rowInds, colInds, pathInds = aly.points2path(pathPoints[:, 0], pathPoints[:, 1], npoints=segPoints)
nSegPoints = len(rowInds)

# Values along the path; axis 0 of each stack is moved last, which bandpath_map uses as eaxis=2
pdGT, pdInit, pdMPES = (
    aly.bandpath_map(np.moveaxis(stack, 0, 2), pathr=rowInds, pathc=colInds, eaxis=2)
    for stack in (gtbands, inits['pbe'], data['mpes_padded'])
)
Emin, Emax = data['E'].min(), data['E'].max()

# Symmetrize each reconstructed band: 6-fold rotational symmetrization about (127.5, 127.5),
# then fuller.generator.refsym (presumably reflection symmetrization) with NaN-aware averaging
symrecbands = np.asarray([
    fuller.generator.refsym(
        fuller.generator.rotosymmetrize(reconbands['pbe'][ib, ...], (127.5, 127.5), rotsym=6)[0][None, ...],
        op='nanmean', pbar=False)[0, ...]
    for ib in range(8)
])
pdRecon = aly.bandpath_map(np.moveaxis(symrecbands, 0, 2), pathr=rowInds, pathc=colInds, eaxis=2)

In [None]:
# Plot comparison between initialization, reconstruction and ground truth along high-symmetry lines
xaxis = np.arange(rowInds.size)
pos = pathInds.copy()
pos[-1] -= 1  # shift the last high-symmetry tick one point left; it also sets the right x-limit

f, ax = plt.subplots(figsize=(8.3, 6))
imax = ax.imshow(pdMPES, cmap='Blues', origin='lower', extent=[0, nSegPoints, Emin, Emax], aspect=22, vmax=7, zorder=0)
# All curves of each set, without legend labels
ax.plot(pdGT.T, c='k', zorder=2)
ax.plot(pdRecon.T, c='r', zorder=3)
# inits['pbe'] was loaded with eshifts['pbe'] added; subtract it again for display
ax.plot(pdInit.T - eshifts['pbe'], '--', c='g', zorder=1)

# Re-draw one curve of each set (the last row) with a label, so each set appears once in the legend
ax.plot(xaxis, pdGT[-1, :], c='k', zorder=2, label='Ground truth (LDA)')
ax.plot(xaxis, pdInit[-1, :] - eshifts['pbe'], '--', c='g', zorder=1, label='Initial. (PBE)')
ax.plot(xaxis, pdRecon[-1, :], c='r', zorder=3, label='Reconstruction')

ax.tick_params(which='both', axis='y', length=8, width=2, labelsize=15)
ax.tick_params(axis='x', length=0, width=0, labelsize=15, pad=8)
ax.set_xlim([pos[0], pos[-1]])
ax.set_xticks(pos)
# Raw strings: '\o', '\G', '\m' are invalid escape sequences in normal literals (SyntaxWarning on Python >= 3.12)
ax.set_xticklabels([r'$\overline{\Gamma}$', r'$\overline{\mathrm{M}}$',
                    r'$\overline{\mathrm{K}}$', r'$\overline{\Gamma}$'])
ax.set_ylabel('Energy (eV)', fontsize=18)
ax.set_yticks(np.arange(-4, 1, 1))
ax.yaxis.set_minor_locator(AutoMinorLocator(2))
# Dashed vertical lines at the high-symmetry points (except the closing one at the right edge)
for p in pos[:-1]:
    ax.axvline(x=p, c='k', ls='--', lw=2, dashes=[4, 2])
ax.legend(loc='lower left', frameon=False, fontsize=15, borderpad=0,
          facecolor='None', labelspacing=0.2, handletextpad=0.3)
cax = inset_axes(ax, width="3%", height="30%", bbox_to_anchor=(70, -30, 440, 200))
cb = plt.colorbar(imax, cax=cax, ticks=[])
cb.ax.set_ylabel('Intensity', fontsize=15, rotation=-90, labelpad=17)
plt.savefig(r'../results/figures/sfig_9g.png', bbox_inches='tight', transparent=True, dpi=300)

### Supplementary Figure 9h

In [None]:
# Circular mask on a 255 x 255 grid, applied to the band maps of Supplementary Figure 9h.
# NOTE(review): the positional arguments are presumably centre (127.5, 127.5) and radius 115 px,
# and sign='xnan' presumably gives NaN outside the circle (1 inside). Confirm against mpes.analysis.circmask.
cmsk = aly.circmask(np.ones((255, 255)), 127.5, 127.5, 115, sign='xnan', method='algebraic')

In [None]:
# Retrieve reconstructions with PBE-DFT as initialization.
# The error-calculation cell above already loaded exactly these files
# (../data/synthetic/pbe_lda/mrf_rec_band=00 ... 07) into reconbands['pbe'], stacked in band order,
# so reuse them instead of re-reading the same data from disk.
brecons = np.array(reconbands['pbe'])  # np.array copies, so later edits cannot alias reconbands

In [None]:
# Supplementary Figure 9h (part 1): ground truth (left half of k-space) vs. PBE-initialized
# reconstruction (right half), one row per band
f, ax = plt.subplots(8, 2, figsize=(3, 25))
kxminl, kxmaxl, kxminr, kxmaxr = kxvals[0], kxvals[127], kxvals[128], kxvals[-1]
kyminl, kymaxl, kyminr, kymaxr = kyvals[0], kyvals[-1], kyvals[0], kyvals[-1]

for i in range(8):

    band_gt = gtbands[i,:,:]*cmsk
    band_rec = brecons[i,:,:]*cmsk
    band_init = inits['pbe'][i,:,:]*cmsk
    # Shared colour range spanning ground truth and initialization (+/- 0.1 eV margin).
    # Use NaN-aware reductions: the masked band presumably contains NaNs (cmsk uses sign='xnan').
    # With NaNs, ndarray.min() returns NaN, and Python's min()/max() then silently ignore it
    # (NaN comparisons are always False), so the initialization range would never be included.
    vmin_gt, vmax_gt = np.nanmin(gtbands[i,:,:]), np.nanmax(gtbands[i,:,:])
    vmin_init, vmax_init = np.nanmin(band_init), np.nanmax(band_init)
    vmin = min(vmin_gt, vmin_init) - 0.1
    vmax = max(vmax_gt, vmax_init) + 0.1

    ax[i, 0].imshow(band_gt[:, :127], cmap='Spectral_r', extent=[kxminl, kxmaxl, kyminl, kymaxl],
                    aspect=1, vmin=vmin, vmax=vmax)
    ax[i, 1].imshow(band_rec[:, 128:], cmap='Spectral_r', extent=[kxminr, kxmaxr, kyminr, kymaxr],
                    aspect=1, vmin=vmin, vmax=vmax)
    ax[i, 0].set_xticks(np.arange(-1.5, 0, 0.5))
    ax[i, 0].set_xticklabels(['', '-1', ''])
    ax[i, 1].set_xticks(np.arange(0.5, 1.6, 0.5))
    ax[i, 1].set_xticklabels(['', '1', ''])
    ax[i, 0].set_yticks(np.arange(-1, 1.1, 1))
    ax[i, 0].yaxis.set_minor_locator(AutoMinorLocator(2))
    ax[i, 1].set_yticks([])
    ax[i, 0].tick_params(axis='both', which='both', length=8, width=2, labelsize=18)
    ax[i, 1].tick_params(axis='both', length=8, width=2, labelsize=18)
    # Raw string: '\m' and '\A' are invalid escape sequences in a normal literal
    ax[i, 0].set_ylabel(r'$k_y$ $(\mathrm{\AA}^{-1})$', fontsize=18)
    ax[i, 0].text(0.1, 0.9, '#'+str(i+1), fontsize=15, transform=ax[i,0].transAxes)

    # Only the bottom row keeps x ticks
    if i < 7:
        ax[i, 0].set_xticks([])
        ax[i, 1].set_xticks([])

ax[0, 0].set_title('Ground\ntruth (LDA)', fontsize=18)
ax[0, 1].set_title('Recon-\nstruction', fontsize=18)
ax[-1, 0].set_xlabel(r'$k_x$ $(\mathrm{\AA}^{-1})$', fontsize=18, x=1)
plt.subplots_adjust(wspace=0, hspace=0.1)
plt.savefig(r'../results/figures/sfig_9h1.png', bbox_inches='tight', transparent=True, dpi=300)

In [None]:
# Supplementary Figure 9h (part 2): energy-difference maps per band,
# left: PBE initialization - ground truth, right: reconstruction - ground truth
f, ax = plt.subplots(8, 2, figsize=(3, 25))
kxminl, kxmaxl, kxminr, kxmaxr = kxvals[0], kxvals[127], kxvals[128], kxvals[-1]
kyminl, kymaxl, kyminr, kymaxr = kyvals[0], kyvals[-1], kyvals[0], kyvals[-1]

# The full difference stacks do not depend on the band index, so compute them once
# rather than once per loop iteration. nanmin/nanmax skip the NaNs that the circular
# mask cmsk presumably introduces outside the circle.
init_diff_all = (inits['pbe'] - gtbands)*cmsk
band_diff_all = (brecons - gtbands)*cmsk
# One colour range, taken over all bands, shared by every initialization panel
vmin_init, vmax_init = np.nanmin(init_diff_all), np.nanmax(init_diff_all)

for i in range(8):

    init_diff = init_diff_all[i,...]
    band_diff = band_diff_all[i,...]

    imaxl = ax[i, 0].imshow(init_diff[:, :127], cmap='terrain_r', extent=[kxminl, kxmaxl, kyminl, kymaxl],
                            aspect=1, vmin=vmin_init, vmax=vmax_init)
    # Reconstruction errors use a fixed symmetric range of +/- 0.2 eV
    imax = ax[i, 1].imshow(band_diff[:, 128:], cmap='RdBu_r', extent=[kxminr, kxmaxr, kyminr, kymaxr],
                           aspect=1, vmin=-0.2, vmax=0.2)

    ax[i, 0].set_xticks(np.arange(-1.5, 0, 0.5))
    ax[i, 0].set_xticklabels(['', '-1', ''])
    ax[i, 1].set_xticks(np.arange(0.5, 1.6, 0.5))
    ax[i, 1].set_xticklabels(['', '1', ''])
    ax[i, 0].set_yticks([])
    ax[i, 1].set_yticks([])
    ax[i, 0].tick_params(axis='both', length=8, width=2, labelsize=18)
    ax[i, 1].tick_params(axis='both', length=8, width=2, labelsize=18)
    ax[i, 0].text(0.1, 0.9, '#'+str(i+1), fontsize=15, transform=ax[i,0].transAxes)

    # Only the bottom row keeps x ticks
    if i < 7:
        ax[i, 0].set_xticks([])
        ax[i, 1].set_xticks([])

ax[0, 0].set_title('\ninitial.-g.t.', fontsize=18)
ax[0, 1].set_title('\nrecon.-g.t.', fontsize=18)
plt.suptitle('Energy difference', y=0.902, fontsize=18)
# Raw string: '\m' and '\A' are invalid escape sequences in a normal literal
ax[-1, 0].set_xlabel(r'$k_x$ $(\mathrm{\AA}^{-1})$', fontsize=18, x=1)

# Colour bar for the initialization differences (last left-hand image)
caxl = inset_axes(ax[6,0], width="3%", height="30%", bbox_to_anchor=(-130, 122, 350, 400))
cbl = plt.colorbar(imaxl, cax=caxl, ticks=np.arange(0.0, 0.61, 0.1))
cbl.ax.set_title('eV', fontsize=15)
cbl.ax.tick_params(axis='y', width=2, length=6, labelsize=15)

# Colour bar for the reconstruction differences (last right-hand image)
cax = inset_axes(ax[7,0], width="3%", height="30%", bbox_to_anchor=(-130, -50, 350, 400))
cb = plt.colorbar(imax, cax=cax, ticks=np.arange(-0.2, 0.21, 0.1))
cb.ax.set_title('eV', fontsize=15)
cb.ax.tick_params(axis='y', width=2, length=6, labelsize=15)

plt.subplots_adjust(wspace=0, hspace=0.1)
plt.savefig(r'../results/figures/sfig_9h2.png', bbox_inches='tight', transparent=True, dpi=300)