## Illustration of approximations to a reconstructed band

In [None]:
import warnings as wn
wn.filterwarnings("ignore")

import os
import numpy as np
import fuller
import matplotlib.pyplot as plt
import matplotlib as mpl
from tqdm import tqdm_notebook as tqdm
from matplotlib.ticker import AutoMinorLocator
%matplotlib inline

In [None]:
# Global figure style: Arial sans-serif text and 2-pt axes spines.
mpl.rcParams['font.family'] = 'sans-serif'
mpl.rcParams['font.sans-serif'] = 'Arial'
mpl.rcParams['axes.linewidth'] = 2
# Embed fonts as TrueType (Type 42) in PDF/PS output so text stays editable.
mpl.rcParams['pdf.fonttype'] = 42
mpl.rcParams['ps.fonttype'] = 42

# Create the plot folder (including any missing parent folders) if needed.
# makedirs(exist_ok=True) also avoids the race between checking and creating.
os.makedirs('../results/figures', exist_ok=True)

In [None]:
# Load the post-processed band cuts. np.load on an .npz returns an NpzFile that
# keeps the archive open; the context manager closes it once the array is read.
with np.load(r'../data/processed/wse2_recon_1BZ/postproc_bandcuts_lda.npz') as bandfile:
    bandcuts = bandfile['bandcuts']
# Quick visual check of band index 3 (the band approximated below).
plt.imshow(bandcuts[3,...])

In [None]:
# Approximations using different numbers of basis terms and in different orders.
# For each term count nt, the band is rebuilt from:
#   - the first nt coefficients in default polynomial order (errseq*), and
#   - the nt coefficients of largest magnitude (errmaj*, "coefficient-ranked").
idx = 3  # index into bandcuts of the band being approximated
indarr = list(range(5, 400, 1))  # numbers of basis terms to evaluate
# errors for summation in default polynomial order (errseq) and in coefficient-ranked order (errmaj)
errseq, errmaj = [], []
# The pixel-averaged versions of errseq and errmaj
errseqavg, errmajavg = [], []
reconms = []  # coefficient-ranked reconstructions, one per entry of indarr

bandref = np.nan_to_num(bandcuts[idx,...])
bcf, bss0 = fuller.generator.decomposition_hex2d(bandref, nterms=400, ret='all')
# Number of pixels within the first Brillouin zone
# (counted where the zeroth basis function equals 1 -- presumably the BZ support; confirm).
npixbz = np.sum(bss0[0,...] == 1)
# Coefficient indices sorted by descending absolute value.
magind = np.argsort(np.abs(bcf))[::-1]

# Loop-invariant normalizations, hoisted out of the sweep.
bandnorm = np.linalg.norm(bandref)
sqrtnpix = np.sqrt(npixbz)

for nt in tqdm(indarr):
    # Keep the first nt coefficients (default order) ...
    currcf = np.zeros_like(bcf)
    currcf[:nt] = bcf[:nt]
    # ... or the nt largest-magnitude coefficients (ranked order).
    topnt = magind[:nt]
    currcfm = np.zeros_like(bcf)
    currcfm[topnt] = bcf[topnt]

    recon = fuller.generator.reconstruction_hex2d(currcf, bss0)
    reconm = fuller.generator.reconstruction_hex2d(currcfm, bss0)
    reconms.append(reconm)

    # Frobenius norms of the residuals, each computed once.
    resseq = np.linalg.norm(recon - bandref)
    resmaj = np.linalg.norm(reconm - bandref)
    # Relative error: residual norm over reference norm.
    errseq.append(resseq/bandnorm)
    errmaj.append(resmaj/bandnorm)
    # Pixel-averaged (RMS-like) error over the npixbz first-BZ pixels.
    errseqavg.append(resseq/sqrtnpix)
    errmajavg.append(resmaj/sqrtnpix)

errseq, errmaj, errseqavg, errmajavg = list(map(np.asarray, [errseq, errmaj, errseqavg, errmajavg]))
reconms = np.asarray(reconms)

### Main Figure 3b

In [None]:
f, ax = plt.subplots(figsize=(5, 3.5))

# Visible axis ranges. ymax_mev also sets the scale of the right-hand
# relative-error axis below, so both axes must share this single value.
ymax_mev = 200
xmax_terms = 100

# Pixel-averaged errors converted to meV (x1000, matching the axis label).
ax.plot(indarr, errmajavg*1000, '-', lw=2, c='#0000FF')  # coefficient-ranked order
ax.plot(indarr, errseqavg*1000, '-', lw=2, c='#CC6600')  # default polynomial order

ax.set_xlabel('Number of terms', fontsize=18)
ax.set_ylabel('Avg. approx. error (meV)', fontsize=18)
ax.set_xticks(range(0, xmax_terms + 1, 20))
ax.set_ylim([0, ymax_mev])
ax.set_xlim([0, xmax_terms])

ax.tick_params(which='major', axis='both', length=8, width=2, labelsize=18)
ax.tick_params(which='minor', axis='both', length=8, width=1, labelsize=18)

# Dashed markers at selected term counts.
for xmark in (5, 15, 45):
    ax.axvline(x=xmark, ls='--', c='k', dashes=(5, 3))

ax.xaxis.set_minor_locator(AutoMinorLocator(4))
ax.yaxis.set_minor_locator(AutoMinorLocator(5))
ax.set_title('Polynomial\n approximation\n to band #4', fontsize=18, x=0.7, y=0.68, transform=ax.transAxes)

# Right axis: relative error. errmaj/(errmajavg*1000) is the same constant for
# every term count (it only depends on the reference norm and npixbz), so
# rescaling the left-axis limit by it makes the two axes read the same curves.
ax2 = ax.twinx()
ax2.set_yticks(np.arange(0, 0.11, 0.02))
ax2.set_ylim([0, ymax_mev*errmaj[0]/(errmajavg[0]*1000)])
ax2.set_ylabel('Rel. approx. error', fontsize=18, rotation=-90, labelpad=25)
ax2.tick_params(which='major', axis='both', length=8, width=2, labelsize=18)
ax2.tick_params(which='minor', axis='both', length=8, width=1, labelsize=18)
ax2.yaxis.set_minor_locator(AutoMinorLocator(2))
# Save via the explicit figure handle rather than the implicit "current" figure.
f.savefig('../results/figures/fig_3b.png', bbox_inches='tight', transparent=True, dpi=300)