## Tests of the Markov random field model for reconstructing 3D synthetic data

In [None]:
import os
import warnings as wn
wn.filterwarnings("ignore")

import numpy as np
import fuller
import matplotlib.pyplot as plt
from mpes import analysis as aly
import matplotlib as mpl
import matplotlib.gridspec as gridspec
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.tri as mtri
import tifffile as ti
from scipy import io
%matplotlib inline

In [None]:
# Global matplotlib styling: Arial sans-serif text, thicker axes frames, and
# font type 42 (TrueType) embedding for PDF/PS output.
mpl.rcParams['font.family'] = 'sans-serif'
mpl.rcParams['font.sans-serif'] = 'Arial'
mpl.rcParams['axes.linewidth'] = 2
mpl.rcParams['pdf.fonttype'] = 42
mpl.rcParams['ps.fonttype'] = 42

# Create plot folder if needed. makedirs also creates a missing parent
# directory ('../results'), and exist_ok=True makes the call a no-op when the
# folder is already there, so no separate existence check is required.
os.makedirs('../results/figures', exist_ok=True)

## 3D single band: Second-order Griewank function

### Data generation

In [None]:
# Momentum axes: both cover [-6, 6) with a step of 0.04
k_start, k_stop, k_step = -6, 6, 0.04
kx = np.arange(k_start, k_stop, k_step)
ky = np.arange(k_start, k_stop, k_step)
kyy, kxx = np.meshgrid(kx, ky)

# NOTE(review): the builtin sum() iterates over the first axis of the 2D
# argument, so quadratic_term is a 1D array that gets broadcast along the last
# axis of the cosine term. If a plain per-point quadratic term was intended,
# the sum() wrapper should be dropped -- but that changes the generated data
# and every figure derived from it, so confirm before touching it.
quadratic_term = 1/4000*sum((kxx/2)**2 + (kyy/2)**2)
cosine_term = np.cos(2*kxx)*np.cos(2*(np.sqrt(2)/2)*kyy)
band_griewank = quadratic_term - cosine_term - 1.5
b3d_min = band_griewank.min()
b3d_max = band_griewank.max()

# Energy axis: 400 points, padded by 2 beyond the band extrema on each side
Evals = np.linspace(b3d_min-2, b3d_max+2, 400)

# Evaluate a Voigt line shape centred on the band at every energy value
voigt_params = {'amp':1, 'xvar':Evals[:, None, None],
                'ctr':band_griewank, 'sig':1, 'gam':0.3}
pes_data_3d = aly.voigt(feval=True, vardict=voigt_params)

# Quick look at one slice (index 150 along the last axis)
plt.imshow(pes_data_3d[:, :, 150], cmap='Blues')

### Reconstruction

In [None]:
# Create model
# Move the leading axis of the simulated data (the one spanned by Evals) to
# the end, then rescale the intensity so that its maximum equals 1.
I = np.moveaxis(pes_data_3d, 0, -1)
I = I / I.max()

# Initial guess: an all-zero band with the same shape as the ground truth
E0 = np.zeros_like(band_griewank)

model_inputs = dict(E=Evals, kx=kx, ky=ky, I=I, E0=E0, eta=1)
mrf = fuller.mrfRec.MrfRec(**model_inputs)

In [None]:
# Perform reconstruction
# (200 is presumably the number of iterations -- confirm against MrfRec.iter_para)
num_iter = 200
mrf.iter_para(num_iter)

### Supplementary Figure 5d

In [None]:
# Summary plot: ground truth, initialization and reconstruction side by side

recon3d = mrf.getEb()

gs = gridspec.GridSpec(1, 3, width_ratios=[5,5,5])
fig = plt.figure(figsize=(10, 4))
axs = [fig.add_subplot(gs[col]) for col in range(3)]
truth_ax, init_ax, recon_ax = axs

# Display options shared by the three panels
band_style = dict(extent=[-6, 6, -6, 6], vmin=-2.2, vmax=0, cmap='Spectral_r')
tick_style = dict(axis='both', length=8, width=2, labelsize=15)

# Left panel: ground-truth band plus the colorbar shared by the figure
im = truth_ax.imshow(band_griewank, aspect=1, **band_style)
truth_ax.set_yticks(range(-6, 7, 2))
truth_ax.set_ylabel('$k_y$ (a.u.)', fontsize=15)
truth_ax.set_title('Ground truth', fontsize=15)
cax = fig.add_axes([0.93, 0.2, 0.02, 0.2])
cb = fig.colorbar(im, cax=cax, orientation='vertical', ticks=np.arange(-2, 0.1, 1))
cb.ax.set_title('Energy\n(a.u.)', fontsize=15, pad=10)
cb.ax.tick_params(**tick_style)

# Middle panel: initial guess handed to the model
init_ax.imshow(E0, **band_style)
init_ax.set_title('Initialization', fontsize=15)
init_ax.tick_params(axis='y', length=0)
init_ax.set_yticks([])

# Right panel: reconstructed band
recon_ax.imshow(recon3d, aspect=1, **band_style)
recon_ax.set_yticks([])
recon_ax.set_title('Reconstruction', fontsize=15)

# Common x-axis formatting
for panel in axs:
    panel.set_xticks(range(-6, 7, 2))
    panel.set_xlabel('$k_x$ (a.u.)', fontsize=15)
    panel.tick_params(**tick_style)

fig.subplots_adjust(wspace=0.15)
plt.savefig('../results/figures/sfig_6d1.png', dpi=300, bbox_inches='tight', transparent=True)

In [None]:
# Difference map: reconstruction minus ground truth, colour scale limited to +/-0.1
f, ax = plt.subplots(figsize=(4, 3))
residual = recon3d - band_griewank
tick_style = dict(axis='both', length=8, width=2, labelsize=15)
k_ticks = range(-6, 7, 2)

im = ax.imshow(residual, cmap='RdBu_r', vmax=0.1, vmin=-0.1, extent=[-6, 6, -6, 6])
ax.tick_params(**tick_style)
ax.set_xticks(k_ticks)
ax.set_xlabel('$k_x$ (a.u.)', fontsize=15)
ax.set_yticks(k_ticks)
ax.set_ylabel('$k_y$ (a.u.)', fontsize=15, rotation=-90, labelpad=20)

# Put the y axis label and ticks on the right-hand side; the colorbar axes
# below are added at the left edge of the figure.
ax.yaxis.set_label_position("right")
ax.yaxis.tick_right()
ax.set_title('Difference', fontsize=15)

cax = f.add_axes([-0.02, 0.53, 0.05, 0.25])
cb = plt.colorbar(im, cax=cax, orientation='vertical')
cb.ax.tick_params(**tick_style)
cb.ax.set_title('Energy\n(a.u.)', fontsize=15, pad=10)
plt.savefig('../results/figures/sfig_6d2.png', dpi=300, bbox_inches='tight', transparent=True)

## 3D band near-crossing: graphene band structure near the Fermi level

### Data generation

In [None]:
# Momentum axes: both cover [-1, 1) with a step of 0.01
k_start, k_stop, k_step = -1, 1, 0.01
kx = np.arange(k_start, k_stop, k_step)
ky = np.arange(k_start, k_stop, k_step)
kyy, kxx = np.meshgrid(kx, ky)

sq3 = np.sqrt(3)
t = 1
a = 2*np.pi / sq3

# Expression under the square root of the two-band dispersion
cos_y = np.cos(sq3 * kyy * a / 2)
cos_x = np.cos(3 * kxx * a / 2)
band_graphene = 1 + 4 * (cos_y ** 2) + 4 * cos_y * cos_x

# Replace any negative entries with a tiny positive number so that the
# square root below stays real
is_negative = band_graphene < 0
band_graphene[is_negative] = 1.e-10

# Upper and lower branches are mirror images: +/- t * sqrt(...)
band_graphene_upper = t*np.sqrt(band_graphene)
band_graphene_lower = - t*np.sqrt(band_graphene)
b3d_max = band_graphene_upper.max()
b3d_min = band_graphene_lower.min()

In [None]:
# 3D surface view of both branches
f = plt.figure(figsize=(6, 5))
ax = f.add_subplot(111, projection='3d')

# The triangulation is built with the coordinates passed in (kyy, kxx) order,
# while the surfaces below are drawn with (kxx, kyy); only the triangle index
# list is reused.
tri = mtri.Triangulation(kyy.flatten(), kxx.flatten())

# Upper branch uses 'Spectral_r', lower branch the reversed map 'Spectral'
branches = (
    (band_graphene_upper, 'Spectral_r'),
    (band_graphene_lower, 'Spectral'),
)
for surface, colormap in branches:
    ax.plot_trisurf(kxx.flatten(), kyy.flatten(), surface.flatten(),
                    triangles=tri.triangles, cmap=colormap, antialiased=False)

ax.set_xlabel('$k_x$', labelpad=15)
ax.set_ylabel('$k_y$', labelpad=15)
ax.set_zlabel('Energy', labelpad=15);

In [None]:
# Energy axis: 400 points, padded by 2 beyond the band extrema on each side
Evals = np.linspace(b3d_min-2, b3d_max+2, 400)
energy_column = Evals[:, None, None]

# One Voigt line shape per branch (upper first, then lower), same width
# parameters for both; the total signal is their sum.
pes_data_3d_upper, pes_data_3d_lower = (
    aly.voigt(feval=True, vardict={'amp':1, 'xvar':energy_column,
                                   'ctr':center, 'sig':0.2, 'gam':0.3})
    for center in (band_graphene_upper, band_graphene_lower)
)
pes_data_3d = pes_data_3d_upper + pes_data_3d_lower

# Quick look at one cut through the data (index 90 along the middle axis)
energy_cut = pes_data_3d[:, 90, :]
plt.imshow(energy_cut, aspect=0.15, extent=[-1, 1, b3d_min-2, b3d_max+2], cmap='Blues')
plt.xlabel('$k_x$', fontsize=15)
plt.ylabel('Energy', fontsize=15)
plt.tick_params(axis='both', length=8, width=2, labelsize=15)

### Reconstruction

In [None]:
# Move the leading axis of the simulated data (the one spanned by Evals) to
# the end, then rescale the intensity so that its maximum equals 1.
I = pes_data_3d.transpose(1, 2, 0)
I = I / I.max()

# One slot per branch: index 0 is initialised at +4, index 1 at -4
stack_shape = (2,) + band_graphene.shape
results = np.zeros(stack_shape)
E0 = np.ones(stack_shape) * 4
E0[1, :, :] *= -1

# Reconstruct each branch separately, starting from its own flat initial guess
for band_idx, initial_guess in enumerate(E0):
    mrf = fuller.mrfRec.MrfRec(E=Evals, kx=kx, ky=ky, I=I, E0=initial_guess, eta=0.3)
    mrf.iter_para(200)
    results[band_idx,...] = mrf.getEb()

### Supplementary Figure 5f

In [None]:
# Summary plot: ground truth, initialization and reconstruction for both bands
recon3d_upper = results[0,...]
recon3d_lower = results[1,...]

init_upper = E0[0,...]
init_lower = E0[1,...]

# Raw strings keep the LaTeX backslashes literal. In a normal string '\m' and
# '\A' are invalid escape sequences, which newer Python versions warn about.
kx_label = r'$k_x$ $(\mathrm{\AA^{-1}})$'
ky_label = r'$k_y$ $(\mathrm{\AA^{-1}})$'

# Display options shared between panels
k_extent = [-1, 1, -1, 1]
k_ticks = np.arange(-1, 1.1, 0.5)
tick_style = dict(axis='both', length=8, width=2, labelsize=15)
upper_style = dict(aspect=1, extent=k_extent, vmin=0, vmax=3, cmap='Spectral_r')
lower_style = dict(aspect=1, extent=k_extent, vmin=-3, vmax=0, cmap='Spectral_r')

gs = gridspec.GridSpec(2, 3)
fig = plt.figure(figsize=(9.8, 6.5))
# Row-major layout: axs[0:3] is the top row, axs[3:6] the bottom row
axs = [fig.add_subplot(gs[i]) for i in range(6)]

# --- Top row ---
axs[0].imshow(band_graphene_upper, **upper_style)
axs[0].set_yticks(k_ticks)
axs[0].set_ylabel(ky_label, fontsize=15)
axs[0].set_title('Ground truth', fontsize=15)
axs[0].tick_params(**tick_style)
axs[0].text(0.15, 0.9, 'Conduction Band', fontsize=15, transform=axs[0].transAxes)

axs[1].imshow(init_upper, **upper_style)
axs[1].set_title('Initialization', fontsize=15)
axs[1].tick_params(axis='y', length=0)
axs[1].set_yticks([])

imu = axs[2].imshow(recon3d_upper, **upper_style)
axs[2].set_yticks([])
axs[2].set_title('Reconstruction', fontsize=15)
axs[2].yaxis.set_label_position("right")

# Upper band colorbar
caxu = fig.add_axes([0.94, 0.5, 0.02, 0.12])
cbu = fig.colorbar(imu, cax=caxu, orientation='vertical', ticks=np.arange(0, 3.1, 1))
cbu.ax.set_title('Energy\n(eV)', fontsize=15, pad=10)
cbu.ax.tick_params(**tick_style)

# --- Bottom row ---
iml = axs[3].imshow(band_graphene_lower, **lower_style)
axs[3].set_yticks(k_ticks)
axs[3].set_ylabel(ky_label, fontsize=15)
axs[3].text(0.3, 0.9, 'Valence Band', fontsize=15, transform=axs[3].transAxes)

axs[4].imshow(init_lower, **lower_style)
axs[4].tick_params(axis='y', length=0)
axs[4].set_yticks([])

axs[5].imshow(recon3d_lower, **lower_style)
axs[5].set_yticks([])
axs[5].yaxis.set_label_position("right")

# Lower band colorbar
caxl = fig.add_axes([0.94, 0.03, 0.02, 0.12])
cbl = fig.colorbar(iml, cax=caxl, orientation='vertical', ticks=np.arange(-3, 0.1, 1))
cbl.ax.set_title('Energy\n(eV)', fontsize=15, pad=10)
cbl.ax.tick_params(**tick_style)

# Only the bottom row carries x ticks and x labels
for panel in axs[:3]:
    panel.set_xticks([])

for panel in axs[3:]:
    panel.set_xticks(k_ticks)
    panel.set_xlabel(kx_label, fontsize=15)
    panel.tick_params(**tick_style)

plt.subplots_adjust(hspace=0.18, wspace=0.1)
plt.savefig('../results/figures/sfig_6f1.png', dpi=300, bbox_inches='tight', transparent=True)

In [None]:
# Difference map for the upper band: reconstruction minus ground truth,
# colour scale limited to +/-0.1
f, ax = plt.subplots(figsize=(4, 3))
im = ax.imshow(recon3d_upper - band_graphene_upper, cmap='RdBu_r', vmax=0.1, vmin=-0.1, extent=[-1, 1, -1, 1])
ax.tick_params(axis='both', length=8, width=2, labelsize=15)
ax.set_xticks(np.arange(-1, 1.1, 0.5))
# Raw strings keep the LaTeX backslashes literal ('\m' and '\A' are invalid
# escape sequences in a normal string literal).
ax.set_xlabel(r'$k_x$ $(\mathrm{\AA^{-1}})$', fontsize=15)
ax.set_yticks(np.arange(-1, 1.1, 0.5))
ax.set_ylabel(r'$k_y$ $(\mathrm{\AA^{-1}})$', fontsize=15, rotation=-90, labelpad=20)
# Put the y axis label and ticks on the right-hand side; the colorbar axes
# below are added at the left edge of the figure.
ax.yaxis.set_label_position("right")
ax.yaxis.tick_right()
ax.set_title('Difference', fontsize=15)
cax = f.add_axes([-0.02, 0.53, 0.05, 0.25])
cb = plt.colorbar(im, cax=cax, orientation='vertical')
cb.ax.tick_params(axis='both', length=8, width=2, labelsize=15)
cb.ax.set_title('Energy\n(eV)', fontsize=15, pad=10)
plt.savefig('../results/figures/sfig_6f2.png', dpi=300, bbox_inches='tight', transparent=True)

In [None]:
# Difference map for the lower band: reconstruction minus ground truth,
# colour scale limited to +/-0.1
f, ax = plt.subplots(figsize=(4, 3))
im = ax.imshow(recon3d_lower - band_graphene_lower, cmap='RdBu_r', vmax=0.1, vmin=-0.1, extent=[-1, 1, -1, 1])
ax.tick_params(axis='both', length=8, width=2, labelsize=15)
ax.set_xticks(np.arange(-1, 1.1, 0.5))
# Raw strings keep the LaTeX backslashes literal ('\m' and '\A' are invalid
# escape sequences in a normal string literal).
ax.set_xlabel(r'$k_x$ $(\mathrm{\AA^{-1}})$', fontsize=15)
ax.set_yticks(np.arange(-1, 1.1, 0.5))
ax.set_ylabel(r'$k_y$ $(\mathrm{\AA^{-1}})$', fontsize=15, rotation=-90, labelpad=20)
# Put the y axis label and ticks on the right-hand side; the colorbar axes
# below are added at the left edge of the figure.
ax.yaxis.set_label_position("right")
ax.yaxis.tick_right()
ax.set_title('Difference', fontsize=15)
cax = f.add_axes([-0.02, 0.53, 0.05, 0.25])
cb = plt.colorbar(im, cax=cax, orientation='vertical')
cb.ax.tick_params(axis='both', length=8, width=2, labelsize=15)
cb.ax.set_title('Energy\n(eV)', fontsize=15, pad=10)
plt.savefig('../results/figures/sfig_6f3.png', dpi=300, bbox_inches='tight', transparent=True)