## Generate synthetic multiband photoemission data using DFT calculations

In [None]:
import warnings as wn
wn.filterwarnings("ignore")

import os
import numpy as np
import fuller
from mpes import analysis as aly
import matplotlib.pyplot as plt
from matplotlib.ticker import (MultipleLocator, FormatStrFormatter,
                               AutoMinorLocator)
from tqdm import tqdm_notebook as tqdm
import tifffile as ti
import matplotlib as mpl
from scipy import interpolate
%matplotlib inline

In [None]:
# Global matplotlib style used by every figure in this notebook
mpl.rcParams.update({
    'font.family': 'sans-serif',
    'font.sans-serif': 'Arial',
    'axes.linewidth': 2,
    # Font type 42 (TrueType) is requested for both the PDF and the PostScript backend
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
})

# Create plot folder if needed.
# makedirs also creates a missing '../results' parent directory (os.mkdir would raise
# FileNotFoundError in that case), and exist_ok=True makes re-running this cell a no-op.
os.makedirs('../results/figures', exist_ok=True)

In [None]:
# Number of basis terms to generate (passed as nterms below)
ncfs = 400
# Hexike basis on a 207-pixel grid; the same `bases` object is reused further down for both the
# decomposition of the calculated bands and the reconstruction of the scaled bands
bases = fuller.generator.ppz.hexike_basis(
    nterms=ncfs,
    npix=207,
    vertical=True,
    outside=0,
)

In [None]:
# Compute the polynomial decomposition coefficients
band_file = '../data/theory/bands_1BZ/wse2_lda_bandcuts.h5'
# nan_to_num replaces NaN/inf entries of the loaded 'bands' dataset with finite numbers
bandout = np.nan_to_num(fuller.utils.loadHDF(band_file)['bands'])
ldashift = 0.86813 # For zeroing the energy at K points
# One coefficient vector per band for the first 14 bands, stacked along axis 0
bcfs = np.array([
    fuller.generator.decomposition_hex2d(bandout[iband, ...] + ldashift, bases=bases, baxis=0, ret='coeffs')
    for iband in tqdm(range(14))
])

In [None]:
# Generate Brillouin zone masks (both on a 207-pixel square image)
# Full-size hexagon (hexdiag=207): used below as the mask when edge-padding the calculated bands
bzmsk = fuller.generator.hexmask(
    hexdiag=207,
    imside=207,
    padded=False,
    margins=[1] * 4,
)
# Smaller hexagon (hexdiag=201): used below for the coefficient-scaled bands and the final plot
bzmsk_tight = fuller.generator.hexmask(
    hexdiag=201,
    imside=207,
    padded=True,
    margins=[3] * 4,
)

In [None]:
# Generate photoemission data without padding
nbands = 8  # Number of bands summed into the synthetic spectrum
bshape = (207, 207)  # Shape of a single band image
amps = np.ones(bshape)  # Unit amplitude at every pixel
xs = np.linspace(-4.5, 0.5, 285, endpoint=True)  # Energy axis: 285 samples from -4.5 to 0.5
# Intensity volume with the energy axis first; its shape is derived from xs and bshape so the
# three definitions cannot get out of sync
syndat = np.zeros((xs.size,) + bshape)
gamss = []  # Stays empty in this cell; kept so the notebook-level name is still defined
gams = 0.05  # Constant 'gam' lineshape parameter, the same for every band (hoisted out of the loop)
for i in tqdm(range(nbands)):
    # Add one Voigt profile per band, centred on the energy-shifted band.
    # ldashift (defined in the cell that loads the bands) is used instead of repeating 0.86813.
    syndat += aly.voigt(feval=True, vardict={'amp':amps, 'xvar':xs[:,None,None], 'ctr':(bandout[i,...] + ldashift),
                                        'sig':0.1, 'gam':gams})

In [None]:
# Half width of projected Brillouin zone in pixels (half of the 207-pixel image side)
hwd = 207 / 2

In [None]:
# Generate edge-padded bands
padsize = ((24, 24), (24, 24))  # Passed to hexpad as edgepad: (24, 24) for each of the two image axes
# For each band: shift the energy by ldashift (defined in the cell that loads the bands, replacing
# the repeated literal 0.86813), pad it with hexpad using the full-size mask, then pass the padded
# image through restore(method='cubic').
synfbands = np.asarray([
    fuller.generator.restore(
        fuller.generator.hexpad(bandout[i, ...] + ldashift, cvd=hwd, mask=bzmsk, edgepad=padsize),
        method='cubic',
    )
    for i in tqdm(range(nbands))
])

In [None]:
# Generate edge-padded photoemission data
# NOTE: this cell rebinds bshape, amps and xs from the unpadded cell above.
bshape = (255, 255)  # presumably 207 + 2*24 after edge padding -- confirm against synfbands.shape
amps = np.ones(bshape)
xs = np.linspace(-4.5, 0.5, 285, endpoint=True)
synfdat = np.zeros((xs.size,) + bshape)
gamss = []
# A band- and momentum-dependent linewidth, np.abs(synfbands[i,...] - np.nanmean(synfbands[i,...]))/3,
# collected into gamss, is disabled here; a constant value is used for all bands instead.
gams = 0.05
for i in tqdm(range(nbands)):
    # One Voigt profile per band, centred on the padded band
    synfdat += aly.voigt(
        feval=True,
        vardict={'amp': amps, 'xvar': xs[:, None, None], 'ctr': (synfbands[i, ...]),
                 'sig': 0.1, 'gam': gams},
    )

In [None]:
# Rebuild the energy axis under a separate name to inspect it
xss = np.linspace(-4.5, 0.5, num=285)
# Cell output: (spacing between the first two samples, number of samples)
(xss[1] - xss[0], xss.size)

In [None]:
# Quick look at the padded volume: cut at index 80 of the second axis, first (energy) axis drawn upwards
plt.imshow(synfdat[:, 80, :], cmap='terrain_r', origin='lower', aspect=0.8)

In [None]:
# Generate mask for large coefficients (binarized with a threshold of 1e-2 and output values 0/1)
cfmask = fuller.utils.binarize(
    bcfs,
    threshold=1e-2,
    vals=[0, 1],
)
# No rigid shift modulation: always clear the first coefficient of every band
cfmask[:, 0] = 0

In [None]:
# Generate coefficient-scaled data
synfscaled = {}  # Maps str(scale factor) -> array of padded, symmetrized bands
# Finer sweep, currently disabled:
# errs = np.around(np.arange(0.3, 2.01, 0.05), 2)
errs = [0.8, 1.0, 1.2]
bscmod = bcfs.copy()  # Working copy; the rows used below are overwritten for every scale factor

for err in tqdm(errs):

    # synbands stays a notebook-level list: the cells below reuse the bands of the last scale factor
    synbands = []
    for i in range(nbands):

        bscmod[i, 1:] = err*bcfs[i, 1:] # Scale only the dispersion terms (leave out the first offset term)
        bandmod = fuller.generator.reconstruction_hex2d(bscmod[i, :], bases=bases)

        # Sixfold rotational symmetrization, followed by reflections at 0 and 90 degrees
        symmed = fuller.generator.rotosymmetrize(bandmod, center=(hwd, hwd), rotsym=6)[0]
        symmed = fuller.generator.reflectosymmetrize(symmed, center=(hwd, hwd), refangles=[0, 90])
        # cvd uses hwd (was the literal 103.5) so it cannot diverge from the centres used above
        padded = fuller.generator.hexpad(symmed, cvd=hwd, mask=bzmsk_tight, edgepad=padsize)
        synbands.append(fuller.generator.restore(padded, method='nearest'))

    synfscaled[str(err)] = np.asarray(synbands)

In [None]:
# Preview the first band left in synbands by the scaling loop (i.e. from the last scale factor)
first_band = synbands[0]
plt.figure(figsize=(6, 6))
plt.imshow(first_band)

In [None]:
# Calibrate momentum axes
# The corrector is built from the bands currently held in synbands (last scale factor of the loop above)
band_stack = np.asarray(synbands)
mc = aly.MomentumCorrector(band_stack)
# Select the first entry along axis 0 as the 2D slice used for the calibration
mc.selectSlice2D(selector=slice(0, 1), axis=0)

In [None]:
# Detect point features in the selected slice with the 'daofind' method
mc.featureExtract(
    mc.slice,
    method='daofind',
    fwhm=30,
    sigma=20,
)
# Optional visual check of the detected points:
#mc.view(mc.slice, annotated=True, points=mc.features)

In [None]:
# Calculate distances
cos30 = np.cos(np.radians(30))
dg = 1.27/cos30  # Reference distance handed to the calibration as `dist`
axes = mc.calibrate(
    mc.slice,
    mc.pouter_ord[0,:],
    mc.pcent,
    dist=dg,
    equiscale=True,
    ret='axes',
)
# Cell output: the reference distance and the first/last value of the first calibrated axis
dg, axes['axes'][0][0], axes['axes'][0][-1]

### Supplementary Figure 9c

In [None]:
# Calibrated axes of the padded image
kaxes = axes['axes']
kx = kaxes[0]
ky = kaxes[1]
# Energy range of the synthetic data
emin = xs.min()
emax = xs.max()
# Drop 24 samples at each end of both axes (the edge-padding width used above)
kxtight = kx[24:-24]
kytight = ky[24:-24]
# Extents of the trimmed axes
kxmin = kxtight.min()
kxmax = kxtight.max()
kymin = kytight.min()
kymax = kytight.max()

In [None]:
islc = 90 # slice index
f, axs = plt.subplots(1, 2, figsize=(10, 6))
bands_tight = bandout + ldashift
n_plot = 8  # Number of bands drawn for every band set


def _plot_band_cuts(bands, color, label, **line_kw):
    """Draw the cuts at index ``islc`` of the first ``n_plot`` bands on both panels.

    The left panel shows ``bands[j, islc, :]`` against ``kxtight`` and the right panel
    ``bands[j, :, islc]`` against ``kytight``. Only the last band receives ``label``, so each
    band set contributes exactly one legend entry and no band has to be drawn twice.
    Extra keyword arguments are forwarded to ``Axes.plot``.
    """
    for j in range(n_plot):
        band_label = label if j == n_plot - 1 else None
        axs[0].plot(kxtight, bands[j, islc, :], c=color, label=band_label, **line_kw)
        axs[1].plot(kytight, bands[j, :, islc], c=color, label=band_label, **line_kw)


def _plot_scaled_cuts(scale, color):
    """Plot the coefficient-scaled bands stored in ``synfscaled`` for one scale factor."""
    key = str(scale)  # synfscaled is keyed by the string form of the scale factor
    # Strip the 24-pixel edge padding, then apply the tight Brillouin-zone mask
    reconbands = bzmsk_tight*(synfscaled[key][:, 24:-24, 24:-24])
    _plot_band_cuts(reconbands, color, 'Scaled LDA ('+key+r'$\times$)')


# Drawing order (and therefore legend order): 0.8x scaled, unscaled calculation, 1.2x scaled
_plot_scaled_cuts(0.8, 'r')
gtband = bzmsk_tight*bands_tight
_plot_band_cuts(gtband, 'k', r'LDA calc. (1.0$\times$)', lw=2)
_plot_scaled_cuts(1.2, 'g')

for ax in axs:
    ax.tick_params(which='major', axis='both', length=8, width=2, labelsize=15)
    ax.tick_params(which='minor', axis='both', length=8, width=1)
    ax.set_xticks(np.arange(-1., 1.1, 1))
    ax.xaxis.set_minor_locator(AutoMinorLocator(2))
    ax.legend(loc='upper left', frameon=False, fontsize=15, ncol=1, labelspacing=0.1, borderpad=0, columnspacing=1)
    ax.set_yticks(np.arange(-4, 2, 1))
    ax.set_ylim([-4.2, 1.3])

axs[0].yaxis.set_minor_locator(AutoMinorLocator(2))
# Raw strings: '\m' and '\A' are not valid Python escape sequences, so the non-raw form triggers
# an invalid-escape warning; the resulting label text is unchanged.
axs[0].set_xlabel(r'$k_x$ $(\mathrm{\AA}^{-1})$', fontsize=18)
axs[0].set_ylabel('Energy (eV)', fontsize=18)
axs[1].set_xlabel(r'$k_y$ $(\mathrm{\AA}^{-1})$', fontsize=18)
axs[1].set_yticks([])  # The right panel shares the energy scale of the left one; hide its ticks
plt.subplots_adjust(wspace=0.1)
plt.savefig('../results/figures/sfig_9c.png', bbox_inches='tight', transparent=True, dpi=300)