## Construct similarity matrix between theoretical and reconstructed band structures

In [None]:
import warnings as wn
wn.filterwarnings("ignore")

import numpy as np
import fuller
from mpes import analysis as aly
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.colors as cs
from mpl_toolkits.axes_grid1 import make_axes_locatable
import itertools as it
import scipy.spatial.distance as ssd
from numpy.linalg import norm
from tqdm import tqdm_notebook as tqdm
%matplotlib inline

In [None]:
# Create plot folder if needed. makedirs also creates a missing '../results'
# parent; os.mkdir would raise FileNotFoundError in that case.
import os
os.makedirs('../results/figures', exist_ok=True)

In [None]:
# Global figure styling: Arial sans-serif text, 2-pt axes frame lines, and
# fonts embedded as TrueType (Type 42) in PDF/PS output.
mpl.rcParams.update({
    'font.family': 'sans-serif',
    'font.sans-serif': 'Arial',
    'axes.linewidth': 2,
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
})

In [None]:
# Basis used to decompose the theoretical band cuts: 400 terms from
# fuller's `hexike_basis`, generated with npix=207
# (presumably matching the grid size of the theory band cuts — confirm).
bases = fuller.generator.ppz.hexike_basis(
    nterms=400,
    npix=207,
    vertical=True,
    outside=0,
)

In [None]:
# Decompose each theoretical band structure into coefficients of `bases`,
# one band at a time. NaNs in the band cuts are replaced by 0 before decomposition.
nbands = 14  # number of bands decomposed per band structure
theory_files = {
    'pbesol': r'../data/theory/bands_1BZ/wse2_pbesol_bandcuts.h5',
    'pbe': r'../data/theory/bands_1BZ/wse2_pbe_bandcuts.h5',
    'hse': r'../data/theory/bands_1BZ/wse2_hse_bandcuts.h5',
    'lda': r'../data/theory/bands_1BZ/wse2_lda_bandcuts.h5',
}
theory_cfs = {}
for method, fpath in theory_files.items():
    bandout = np.nan_to_num(fuller.utils.loadHDF(fpath)['bands'])
    theory_cfs[method] = np.array([
        fuller.generator.decomposition_hex2d(bandout[i, ...], bases=bases, baxis=0, ret='coeffs')
        for i in tqdm(range(nbands))
    ])

cfs_pbesol = theory_cfs['pbesol']
cfs_pbe = theory_cfs['pbe']
cfs_hse = theory_cfs['hse']
cfs_lda = theory_cfs['lda']

In [None]:
# Basis used to decompose the reconstructed band cuts: same 400-term
# `hexike_basis` as for the theory, but generated with npix=175
# (presumably matching the grid size of the reconstructed band cuts — confirm).
bases_recon = fuller.generator.ppz.hexike_basis(
    nterms=400,
    npix=175,
    vertical=True,
    outside=0,
)

In [None]:
# Decompose each reconstructed band structure (files tagged by DFT label) into
# coefficients of `bases_recon`, one band at a time. NaNs are replaced by 0.
nbands = 14  # number of bands decomposed per band structure
recon_files = {
    'pbe': r'../data/processed/wse2_recon_1BZ/postproc_bandcuts_pbe.npz',
    'pbesol': r'../data/processed/wse2_recon_1BZ/postproc_bandcuts_pbesol.npz',
    'lda': r'../data/processed/wse2_recon_1BZ/postproc_bandcuts_lda.npz',
    'hse': r'../data/processed/wse2_recon_1BZ/postproc_bandcuts_hse.npz',
}
recon_cfs = {}
for method, fpath in recon_files.items():
    # np.load on an .npz returns an NpzFile that keeps the archive open;
    # the context manager closes it once the array has been read.
    with np.load(fpath) as npz:
        bandout = np.nan_to_num(npz['bandcuts'])
    recon_cfs[method] = np.array([
        fuller.generator.decomposition_hex2d(bandout[i, ...], bases=bases_recon, baxis=0, ret='coeffs')
        for i in tqdm(range(nbands))
    ])

cfs_rec_pbe_sym = recon_cfs['pbe']
cfs_rec_pbesol_sym = recon_cfs['pbesol']
cfs_rec_lda_sym = recon_cfs['lda']
cfs_rec_hse_sym = recon_cfs['hse']

### Main Figure 3d

In [None]:
def demean(bscoefs):
    """Return a copy of ``bscoefs`` with the mean of column 0 subtracted from column 0.

    Rows are bands and columns are basis coefficients (as built above), so this
    removes the band-averaged value of the zeroth coefficient while leaving all
    other coefficients untouched. The input array is not modified.
    """
    offset = bscoefs[:, 0].mean()
    shifted = bscoefs.copy()
    shifted[:, 0] -= offset
    return shifted

In [None]:
# Calculate distance metrics with zeroed DFT band structure
# Per-band Euclidean distance between band structures
cfs = [cfs_lda, cfs_pbe, cfs_pbesol, cfs_hse, cfs_rec_lda_sym, cfs_rec_pbe_sym, cfs_rec_pbesol_sym, cfs_rec_hse_sym]
ncfs = len(cfs)
nbands = 14  # number of bands compared

# Demean each band structure once (demean returns a copy) rather than once per pair.
cfs_demeaned = [demean(c)[:nbands] for c in cfs]

dcdcent = np.zeros((ncfs, ncfs))    # mean per-band distance
dcmstdsym = np.zeros((ncfs, ncfs))  # standard error of the per-band distance
for i, j in it.product(range(ncfs), repeat=2):
    # Euclidean norm of the coefficient difference, one value per band
    dnorms = norm(cfs_demeaned[i] - cfs_demeaned[j], axis=1)
    dcdcent[i, j] = dnorms.sum() / nbands
    # NOTE(review): np.std uses ddof=0 (population std); kept to reproduce the published values.
    dcmstdsym[i, j] = np.std(dnorms) / np.sqrt(nbands)

In [None]:
nr, nc = dcdcent.shape
dcm = dcdcent.copy()
dcmstd = dcmstdsym.copy()

# Combine two triangular matrix plots: mean distances on and above the diagonal,
# standard errors strictly below it. NaN cells are left blank by matshow.
# Values are multiplied by 1000 for display in meV (see colorbar labels).
matnan = np.full((nr, nc), np.nan)
ut = np.triu(dcm, k=0) + np.tril(matnan, k=-1)
lt = np.tril(dcmstd, k=-1) + np.triu(matnan, k=0)
f, ax = plt.subplots(figsize=(6, 6))
fup = ax.matshow(ut*1000, cmap='viridis', vmin=0, vmax=250)
flo = ax.matshow(lt*1000, cmap='viridis', vmin=0)

divider = make_axes_locatable(ax)
caxu = divider.append_axes("right", size="5%", pad=0.2)
caxu.tick_params(axis='y', size=8, length=8, width=2, labelsize=15)
caxl = divider.append_axes("bottom", size="5%", pad=0.2)
caxl.tick_params(axis='x', size=8, length=8, width=2, labelsize=15)
cbup = f.colorbar(fup, orientation='vertical', cax=caxu, ticks=np.arange(0, 351, 50))
cblo = f.colorbar(flo, orientation='horizontal', cax=caxl, ticks=np.arange(0, 31, 5))
cbup.ax.set_yticklabels(np.arange(0, 351, 50))
cbup.ax.set_ylabel('Band structure distance (meV/band)', fontsize=15, rotation=-90, labelpad=20)
cblo.ax.set_xlabel('Standard error (meV/band)', fontsize=15, rotation=0, labelpad=5)

meths = ['LDA', 'PBE', 'PBEsol', 'HSE06', 'LDA \nrecon.', 'PBE \nrecon.', 'PBEsol \nrecon.', 'HSE06 \nrecon.']
# Pin one tick per matrix row/column so each label sits on its cell regardless
# of matplotlib's automatic tick locator (replaces the old [''] + meths offset trick).
ax.set_xticks(range(nc))
ax.set_yticks(range(nr))
ax.set_xticklabels(meths, fontsize=15, rotation=90)
ax.set_yticklabels(meths, fontsize=15, rotation=0)
ax.tick_params(axis='both', size=8, width=2)
ax.tick_params(axis='x', bottom=False, pad=8)
ax.tick_params(axis='y', pad=4)

# Annotate each cell with its value in meV (rounded to integer); diagonal shows 0 in white.
dcm_merged = np.triu(dcm, k=1) + np.tril(dcmstd, k=-1)
dcm_merged = np.rint(dcm_merged*1000).astype('int')
for i in range(nr):
    for j in range(nc):
        if i == j:
            ax.text(j, i, 0, ha='center', va='center', color='w', fontsize=15, fontweight='bold')
        else:
            ax.text(j, i, dcm_merged[i, j], ha='center', va='center', color='#FF4500', fontsize=15, fontweight='bold')

plt.savefig('../results/figures/fig_3d.png', bbox_inches='tight', transparent=True, dpi=300)

### Supplementary Figure 13e-h

In [None]:
# Construct new colormap 'KRdBu' and 'KRdBu_r' (based on 'RdBu' with black blended into the very end of the red side)
# Construct new colormap 'KRdBu' and 'KRdBu_r' (based on 'RdBu' with black blended into the very end of the red side)
# plt.get_cmap is used because matplotlib.cm.get_cmap was deprecated in 3.7 and removed in 3.9.
cmap_rdbu = plt.get_cmap('RdBu')
cmap_gr = plt.get_cmap('Greys_r')
# Short ramp from near-black (Greys_r at 0.1) to RdBu at 0.1, where the RdBu part below resumes
colors = [cmap_gr(0.1), cmap_rdbu(0.1)]
nk = 13  # number of the 256 entries taken by the black-to-red ramp

KRd = cs.LinearSegmentedColormap.from_list('KRdBu', colors, N=nk)
KRdvals = KRd(np.linspace(0, 1, nk))
RdBuvals = cmap_rdbu(np.linspace(0.1, 1, 256-nk))
KRdBu_vals = np.concatenate((KRdvals, RdBuvals))
KRdBu_r_vals = np.flipud(KRdBu_vals)
KRdBu = cs.ListedColormap(KRdBu_vals, name='KRdBu')
KRdBu_r = cs.ListedColormap(KRdBu_r_vals, name='KRdBu_r')

In [None]:
def similarity_matrix_plot(cmat, title=''):
    """Plot a band-by-band similarity matrix in the manuscript style.

    Parameters
    ----------
    cmat : 2D array
        Similarity values; row/column k is labelled '#k+1'. The color scale is
        fixed to [-1, 1] using the module-level ``KRdBu_r`` colormap.
    title : str
        Figure title.

    Returns
    -------
    matplotlib.axes.Axes
        The axes holding the matrix image.
    """
    nrow, ncol = np.shape(cmat)

    _, ax = plt.subplots(figsize=(6, 6))
    # Cell k spans [k, k+1], so tick marks sit at half-integers (cell centers).
    im = ax.matshow(cmat, cmap=KRdBu_r, extent=[0, ncol, nrow, 0], origin='upper', vmin=-1, vmax=1)
    ax.set_xticks(np.arange(ncol) + 0.5)
    ax.set_yticks(np.arange(nrow) + 0.5)
    ax.set_xticklabels(['#' + str(k + 1) for k in range(ncol)], fontsize=15, rotation=90)
    ax.set_yticklabels(['#' + str(k + 1) for k in range(nrow)], fontsize=15, rotation=0)
    ax.tick_params(axis='both', size=8, width=2, labelsize=15)
    ax.tick_params(axis='x', bottom=False)
    ax.tick_params(axis='x', pad=8)
    ax.set_title(title, fontsize=15, y=1.15)

    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.2)
    cax.tick_params(axis='y', size=8)
    cb = plt.colorbar(im, cax=cax, ticks=np.arange(-1, 1.01, 0.2))
    cb.ax.set_ylabel('Cosine similarity', fontsize=15, rotation=-90, labelpad=20)
    cb.ax.tick_params(axis='both', length=8, width=2, labelsize=15)
    # Corner label shared by both axes
    ax.text(-0.18, 1.08, ' Band\n index', rotation=-45, transform=ax.transAxes, fontsize=15)

    return ax

In [None]:
# Plot the cosine similarity matrices for each DFT calculation (indicated in figure title)
similarity_panels = (
    ('LDA', cfs_lda, '../results/figures/sfig_13e.png'),
    ('PBE', cfs_pbe, '../results/figures/sfig_13f.png'),
    ('PBEsol', cfs_pbesol, '../results/figures/sfig_13g.png'),
    ('HSE06', cfs_hse, '../results/figures/sfig_13h.png'),
)
similarity_mats = []
for panel_title, panel_cfs, panel_path in similarity_panels:
    # Pairwise band-to-band matrix using fuller's dcos metric
    similarity_mats.append(fuller.metrics.similarity_matrix(panel_cfs, fmetric=fuller.metrics.dcos))
    similarity_matrix_plot(similarity_mats[-1], title=panel_title)
    plt.savefig(panel_path, bbox_inches='tight', transparent=True, dpi=300)

dcm_lda, dcm_pbe, dcm_pbesol, dcm_hse = similarity_mats

### Supplementary Figure 13a-d

In [None]:
def decomposition_plot(coefs, coef_count, title):
    """Plot stacked per-band coefficient traces above a bar panel.

    Parameters
    ----------
    coefs : 2D array
        Coefficients indexed as (band, coefficient). Coefficient 0 is not
        plotted. The first 14 bands are drawn, each shifted down by 1/3.
    coef_count : 1D array
        Bar heights, one per plotted coefficient (length ``coefs.shape[1] - 1``),
        drawn in the bottom "All bands" row.
    title : str
        Figure title.

    Returns
    -------
    matplotlib.axes.Axes
        The axes holding the plot.
    """
    nbands = 14                 # vertical layout (offsets, ylim, bar baseline) is tuned for 14 bands
    ncoefs = coefs.shape[1]     # previously hard-coded as 400
    xs = np.arange(1, ncoefs)   # coefficient indices 1..ncoefs-1
    label_x = 0.8 * ncoefs      # x position of the right-hand row labels (320 for 400 coefficients)
    bar_base = -4.9             # baseline of the bar panel

    cl = plt.cm.tab20(np.linspace(0, 1, nbands))
    _, ax = plt.subplots(figsize=(6, 10))
    for i in range(nbands):
        yoff = -i/3
        ax.plot(xs, coefs[i, 1:] + yoff, lw=1, color=cl[i])
        ax.axhline(y=yoff, lw=1, color=cl[i])
        ax.text(label_x, yoff + 0.08, 'Band #'+str(i+1), fontsize=15, fontname="Arial")
    ax.bar(xs, coef_count, bottom=bar_base, width=1, color=(0.3, 0.3, 0.3))
    ax.axhline(y=bar_base, lw=1, color=(0.3, 0.3, 0.3))
    ax.set_title(title, fontsize=15, y=0.88)
    ax.text(label_x, bar_base + 0.08, 'All bands', fontsize=15, fontname="Arial")

    ax.set_xticks(list(range(0, ncoefs + 1, 50)))
    ax.set_yticks([])
    ax.set_ylim([-5, 1])
    ax.tick_params(axis='x', length=8, width=2, labelsize=15)
    ax.set_ylabel('Amplitude (a. u.)', fontsize=15)
    ax.set_xlim([0, ncoefs])
    ax.set_xlabel('Coefficient index', fontsize=15)
    for side in ('left', 'right', 'top'):
        ax.spines[side].set_visible(False)

    return ax

In [None]:
# For each DFT calculation: binarize the coefficients with threshold 1e-2
# (fuller.utils.binarize), drop coefficient 0, sum over bands, and scale by 1/40
# for the bar panel (NOTE(review): 1/40 looks like a display scale factor — confirm).
decomposition_panels = (
    ('LDA', cfs_lda, '../results/figures/sfig_13a.png'),
    ('PBE', cfs_pbe, '../results/figures/sfig_13b.png'),
    ('PBEsol', cfs_pbesol, '../results/figures/sfig_13c.png'),
    ('HSE06', cfs_hse, '../results/figures/sfig_13d.png'),
)
large_counts = []
for panel_title, panel_cfs, panel_path in decomposition_panels:
    binarized = fuller.utils.binarize(panel_cfs, threshold=1e-2)
    large_counts.append(binarized[:, 1:].sum(axis=0) / 40)
    decomposition_plot(panel_cfs, large_counts[-1], title=panel_title)
    plt.savefig(panel_path, bbox_inches='tight', transparent=True, dpi=300)

cfs_lda_large, cfs_pbe_large, cfs_pbesol_large, cfs_hse_large = large_counts