## Digitization of reconstructed bands using hexagonal Zernike polynomials

In [None]:
import warnings as wn
wn.filterwarnings("ignore")

import os
import numpy as np
import fuller
import matplotlib.pyplot as plt
import matplotlib as mpl
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.ticker import (MultipleLocator, FormatStrFormatter,
                               AutoMinorLocator)
import matplotlib.colors as cs
import itertools as it
from tqdm import tqdm_notebook as tqdm
%matplotlib inline

In [None]:
# Global matplotlib styling shared by every figure in this notebook:
# Arial sans-serif text, 2-pt axes spines, and TrueType (type 42) font
# embedding for PDF/PS output.
mpl.rcParams['font.family'] = 'sans-serif'
mpl.rcParams['font.sans-serif'] = 'Arial'
mpl.rcParams['axes.linewidth'] = 2
mpl.rcParams['pdf.fonttype'] = 42
mpl.rcParams['ps.fonttype'] = 42

# Anchor colours for the custom colormap used for the Zernike basis plots:
# grey -> pale cream -> yellow -> orange -> red -> purple, interpolated to 256 levels.
colornames = ['#646464', '#666666', '#6a6a6a', '#6f6f6f', '#737373', '#787878', '#7d7d7d', '#828282', '#878787', '#8d8d8d', '#929292', '#989898', '#9e9e9e', '#a4a4a4', '#aaaaaa', '#b0b0b0', '#b6b6b6', '#bcbcbc', '#c2c2c2', '#c9c9c9', '#cfcfcf', '#d6d6d6', '#dcdcdc', '#e3e3e3', '#eaeaea', '#efefee', '#efeee5', '#efeddc', '#efecd3', '#eeebca', '#eeeac0', '#eee9b7', '#eee8ad', '#ede7a4', '#ede69a', '#ede590', '#ede487', '#ece37d', '#ece273', '#ece069', '#ecdf5f', '#ebde55', '#ebdd4b', '#ebdc41', '#ebdb37', '#ebd333', '#ebc933', '#ecbe32', '#ecb432', '#eda931', '#ee9e31', '#ee9330', '#ef8830', '#ef7d2f', '#f0722f', '#f0672e', '#f15c2e', '#f2512d', '#f2462d', '#f33b2c', '#f3302c', '#f4252b', '#f4192b', '#ef182f', '#e81834', '#e21939', '#db1a3e', '#d51a43', '#ce1b48', '#c71b4d', '#c11c52', '#ba1c58', '#b31d5d', '#ac1d62', '#a61e67', '#9f1e6c', '#981f72', '#911f77', '#8a207c', '#842182']
custom_cmap = mpl.colors.LinearSegmentedColormap.from_list('custom', colornames, N=256)

# Create the plot folder (including any missing parent such as ../results).
# exist_ok=True makes re-running this cell a no-op rather than an error, and
# avoids the exists()/mkdir() race; plain os.mkdir fails if ../results is absent.
os.makedirs('../results/figures', exist_ok=True)

### Main Figure 3a

In [None]:
# Hexagonal Zernike ("hexike") basis used for the Figure 3a illustration:
# the first 100 terms sampled with npix=257, oriented with vertical=True,
# with outside=0 for the region beyond the hexagon.
basis = fuller.generator.ppz.hexike_basis(
    nterms=100,
    npix=257,
    vertical=True,
    outside=0,
)

# Hexagonal mask on a matching 257-pixel image, applied when plotting the
# basis so that only the hexagonal region is shown.
bmask = fuller.generator.hexmask(
    hexdiag=257,
    imside=257,
    padded=False,
    margins=[1, 1, 1, 1],
)

In [None]:
# Plot five selected hexagonal Zernike polynomials stacked in one column.
# Each panel is multiplied by the hexagonal mask, drawn on a shared colour
# scale [-1.8, 1.8] and labelled 1-5 in its upper-left corner.
ff, axs = plt.subplots(5, 1, figsize=(3, 10))
selected_terms = [3, 10, 27, 41, 89]

for panel_no, (panel_ax, term) in enumerate(zip(axs, selected_terms), start=1):
    im = panel_ax.imshow(basis[term, ...] * bmask[...], cmap=custom_cmap, vmin=-1.8, vmax=1.8)
    panel_ax.axis('off')
    panel_ax.text(5, 5, str(panel_no), fontsize=15, fontname="Arial")

# Shared horizontal colorbar near the bottom of the figure, with only the
# two end ticks, relabelled 'low' and 'high'.
cax = ff.add_axes([0.36, 0.08, 0.3, 0.02])
cb = plt.colorbar(im, cax=cax, ticks=[-1.8, 1.8], orientation='horizontal')
cb.ax.tick_params(axis='both', length=0)
cb.ax.set_xticklabels(['low', 'high'], fontsize=15)

plt.subplots_adjust(hspace=0.1)
plt.savefig('../results/figures/fig_3a1.png', bbox_inches='tight', transparent=True, dpi=300)

In [None]:
# Calculate the decomposition coefficients for all bands.
# np.load on an .npz archive returns an NpzFile that holds the file open;
# the context manager closes it once 'bandcuts' has been read.
# np.nan_to_num replaces NaNs with 0 (and infs with large finite values)
# before decomposition.
with np.load(r'../data/processed/wse2_recon_1BZ/postproc_bandcuts_lda.npz') as npz_file:
    bandout = np.nan_to_num(npz_file['bandcuts'])

# Hexagonal Zernike basis with 400 terms (npix=175) used for the decomposition
bases_recon = fuller.generator.ppz.hexike_basis(nterms=400, npix=175, vertical=True, outside=0)

# Decompose each of the 14 bands; row i of cfs_rec_lda holds the
# coefficients returned for band i.
cfs_rec_lda = np.array([
    fuller.generator.decomposition_hex2d(bandout[i, ...], bases=bases_recon, baxis=0, ret='coeffs')
    for i in tqdm(range(14))
])

In [None]:
# Flag the coefficients that are large enough to matter for each band:
# binarize the coefficient array against a 1e-2 threshold, mapping to the
# values 0 and 1 (see fuller.utils.binarize for the exact comparison rule).
cfs_large = fuller.utils.binarize(
    cfs_rec_lda,
    threshold=1e-2,
    vals=[0, 1],
)

In [None]:
# Stacked line plot of each band's decomposition coefficients (index 1 onward),
# with band k shifted down by k/3. Below the bands, a bar chart shows, per
# coefficient, the number of bands in which it was flagged by binarize,
# scaled by 1/40.
cl = plt.cm.tab20(np.linspace(0, 1, 14))
f, ax = plt.subplots(figsize=(6, 10))
xs = np.arange(1, 400)

for band, band_color in enumerate(cl):
    offset = -band / 3
    ax.plot(xs, cfs_rec_lda[band, 1:] + offset, lw=1, color=band_color)
    ax.axhline(y=offset, lw=1, color=band_color)
    ax.text(320, offset + 0.08, f'Band #{band + 1}', fontsize=15, fontname="Arial")

# Aggregate "counts" panel drawn on its own baseline below all bands
baseline = -4.9
dark_grey = (0.3, 0.3, 0.3)
usage_counts = np.abs(cfs_large[:, 1:]).sum(axis=0)
ax.bar(xs, usage_counts / 40, bottom=baseline, width=1, color=dark_grey)
ax.axhline(y=baseline, lw=1, color=dark_grey)
ax.text(320, baseline + 0.08, 'All bands', fontsize=15, fontname="Arial")

# Axis styling: x ticks every 50, no y ticks, only the bottom spine visible
ax.set_xticks(list(range(0, 401, 50)))
ax.set_yticks([])
ax.set_ylim([-5, 1])
ax.tick_params(axis='x', length=8, width=2, labelsize=15)
ax.set_ylabel('Amplitude (a. u.)', fontsize=15)
ax.set_xlim([0, 400])
ax.set_xlabel('Coefficient index', fontsize=15)
for side in ('left', 'right', 'top'):
    ax.spines[side].set_visible(False)
ax.text(-22, -4.5, 'Counts', rotation=90, fontsize=15)

f.savefig('../results/figures/fig_3a2.png', bbox_inches='tight', transparent=True, dpi=300)

### Main Figure 3c

In [None]:
# Pairwise similarity matrix between the coefficient vectors of all 14 bands,
# evaluated with fuller.metrics.dcos for every ordered pair (i, j).
# Coefficient index 0 (presumably the constant/piston term) is left out.
ncfs = 14
ids = list(it.product(range(ncfs), repeat=2))
dcm = np.zeros((ncfs, ncfs))
for row, col in ids:
    dcm[row, col] = fuller.metrics.dcos(cfs_rec_lda[row, 1:], cfs_rec_lda[col, 1:])

In [None]:
# Construct new colormaps 'KRdBu' and 'KRdBu_r': 'RdBu' with a short ramp from
# dark grey (near black) into RdBu's dark-red end prepended, plus its reverse.
# plt.get_cmap is used instead of matplotlib.cm.get_cmap, which was deprecated
# in Matplotlib 3.7 and removed in 3.9.
cmap_rdbu = plt.get_cmap('RdBu')
cmap_gr = plt.get_cmap('Greys_r')
colors = [cmap_gr(0.1), cmap_rdbu(0.1)]
nk = 13  # number of colour entries in the grey-to-red ramp

KRd = cs.LinearSegmentedColormap.from_list('KRdBu', colors, N=nk)
KRdvals = KRd(np.linspace(0, 1, nk))
# The remaining 256 - nk entries sample RdBu from 0.1 upward, so the ramp
# joins RdBu at the same colour it ends on.
RdBuvals = cmap_rdbu(np.linspace(0.1, 1, 256-nk))
KRdBu_vals = np.concatenate((KRdvals, RdBuvals))
KRdBu_r_vals = np.flipud(KRdBu_vals)
KRdBu = cs.ListedColormap(KRdBu_vals)
KRdBu_r = cs.ListedColormap(KRdBu_r_vals)

In [None]:
# Heat map of the band-band cosine-similarity matrix (Figure 3c). Each band
# occupies one unit cell; '#1'..'#14' tick labels sit at the cell centres.
f, ax = plt.subplots(figsize=(6, 6))
im = ax.matshow(dcm, cmap=KRdBu_r, extent=[0, 14, 14, 0], origin='upper', vmin=-1, vmax=1)

tks = list(np.arange(0.5, 14, 1))
band_labels = ['#' + str(int(t + 0.5)) for t in tks]
ax.set_xticks(tks)
ax.set_yticks(tks)
ax.set_xticklabels(band_labels, fontsize=15, rotation=90)
ax.set_yticklabels(band_labels, fontsize=15, rotation=0)
ax.tick_params(axis='both', size=8, width=2, labelsize=15)
# Keep only the top x ticks drawn by matshow, with extra label padding
ax.tick_params(axis='x', bottom=False, pad=8)

# Colorbar in a slim axis appended to the right of the matrix
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.2)
cax.tick_params(axis='y', size=8)
cb = plt.colorbar(im, cax=cax, ticks=np.arange(-1, 1.01, 0.2))
cb.ax.set_ylabel('Cosine similarity', fontsize=15, rotation=-90, labelpad=20)
cb.ax.tick_params(axis='both', length=8, width=2, labelsize=15)

# Diagonal "Band index" caption in the top-left corner
ax.text(-0.18, 1.08, ' Band\n index', rotation=-45, transform=ax.transAxes, fontsize=15)
f.savefig('../results/figures/fig_3c.png', bbox_inches='tight', transparent=True, dpi=300)