## Tests of the Markov random field model for reconstructing 2D synthetic data

In [None]:
import warnings as wn
wn.filterwarnings("ignore")

import numpy as np
import fuller
import matplotlib.pyplot as plt
from mpes import analysis as aly
import matplotlib as mpl
import matplotlib.gridspec as gridspec
import os
%matplotlib inline

In [None]:
# Global matplotlib styling shared by every figure in this notebook:
# Arial sans-serif text, 2-pt axes spines, and font type 42 (TrueType)
# embedding so text in exported PDF/PS files remains editable.
mpl.rcParams['font.family'] = 'sans-serif'
mpl.rcParams['font.sans-serif'] = 'Arial'
mpl.rcParams['axes.linewidth'] = 2
mpl.rcParams['pdf.fonttype'] = 42
mpl.rcParams['ps.fonttype'] = 42

# Create the figure output folder if needed. os.makedirs also creates a
# missing '../results' parent (os.mkdir would raise FileNotFoundError), and
# exist_ok=True replaces the racy exists-then-create check.
os.makedirs('../results/figures', exist_ok=True)

## 2D single sinusoidal band

### Data generation

In [None]:
# Generate synthetic data: a single sinusoidal band sampled on a uniform k grid.
kx = np.arange(-1, 1, 0.01)
band_sin2d = 3 * np.sin(13 * kx) + 2 * np.cos(12 * kx) - 4
b2d_min = band_sin2d.min()
b2d_max = band_sin2d.max()

# Energy axis padded by 2 units below and above the band's range.
Evals = np.arange(b2d_min - 2, b2d_max + 2, 0.01)

# Voigt line centred on the band at every k. Evals[:, None] puts energy along
# the rows, giving an (energy, k) array as the transpose in the
# reconstruction step expects.
voigt_params = {
    'amp': 1,
    'xvar': Evals[:, None],
    'ctr': band_sin2d,
    'sig': 1,
    'gam': 0.3,
}
pes_data_2d = aly.voigt(feval=True, vardict=voigt_params)

# Preview: rows are flipped so the highest energy is drawn at the top.
plt.imshow(pes_data_2d[::-1, :], aspect=0.1,
           extent=[-1, 1, b2d_min - 2, b2d_max + 2], cmap='Blues')

In [None]:
# Construct initialization: a flat guess at E = 0 for every k, drawn against
# the ground-truth band for comparison.
init = np.zeros_like(kx)
for curve, colour, name in ((band_sin2d, 'r', 'ground truth'),
                            (init, 'b', 'initialization')):
    plt.plot(kx, curve, c=colour, label=name)
plt.legend(loc='lower left', fontsize=12)

### Reconstruction

In [None]:
# Intensity normalised to a peak value of 1 and transposed to (k, E).
I = pes_data_2d.T / pes_data_2d.max()

# MRF model on a single ky = 0 slice: intensity laid out as (kx, 1, E) and the
# flat initial guess as a (kx, 1) column.
mrf = fuller.mrfRec.MrfRec(
    E=Evals,
    kx=kx,
    ky=np.array([0.]),
    I=I.reshape(len(kx), 1, len(Evals)),
    eta=1,
    E0=init[:, None],
)

In [None]:
# Run 200 steps of MrfRec.iter_seq, then keep column 0 of getEb(), the only
# ky slice, as the 1D reconstructed band over kx. (getEb() is presumably
# indexed (kx, ky) like E0; confirm against fuller.)
mrf.iter_seq(200)
eb_grid = mrf.getEb()
recon = eb_grid[:, 0]

### Supplementary Figure 6a

In [None]:
# Summary plot for Supplementary Figure 6a. Left: synthetic intensity with the
# ground-truth band overlaid. Right: ground truth, initialization and MRF
# reconstruction drawn on the same energy window.

# Energy window taken from the sampled axis itself (as the band-crossing
# summary plot does), so the image extent matches the data exactly.
emin, emax = Evals.min(), Evals.max()

gs = gridspec.GridSpec(1, 2, width_ratios=[5, 5])
fig = plt.figure()
axs = [fig.add_subplot(gs[i]) for i in range(2)]

# Left panel: origin='lower' draws row 0 (lowest energy) at the bottom, so no
# manual row flip is needed. imshow's aspect argument also sets the axes
# aspect, so no separate set_aspect call is required.
im = axs[0].imshow(pes_data_2d, aspect=0.15, extent=[-1, 1, emin, emax],
                   cmap='Blues', origin='lower')
axs[0].plot(kx, band_sin2d, 'r')
axs[0].set_ylabel('Energy (a.u.)', fontsize=15)
cax = fig.add_axes([0.94, 0.54, 0.03, 0.2])
cb = fig.colorbar(im, cax=cax, orientation='vertical', ticks=[])
cb.ax.set_ylabel('Intensity', fontsize=15, rotation=-90, labelpad=18)

# Right panel: line comparison on the same energy range as the image.
axs[1].plot(kx, band_sin2d, 'r', label='ground truth')
axs[1].plot(kx, init, 'b', label='initialization')
axs[1].plot(kx, recon, 'g', label='reconstruction')
axs[1].set_xlim([-1, 1])
axs[1].set_ylim([emin, emax])
axs[1].set_aspect(aspect=0.15)
axs[1].set_yticks([])
lg = axs[1].legend(fontsize=15, bbox_to_anchor=(1.04, 0.5), frameon=False,
                   borderpad=0, labelspacing=0.8, handlelength=1.2, handletextpad=0.5)

for ax in axs:
    ax.set_xlabel('$k$ (a.u.)', fontsize=15)
    ax.tick_params(axis='both', length=8, width=2, labelsize=15)

# Operate on the explicit figure rather than whichever one pyplot considers current.
fig.subplots_adjust(wspace=0.1)
fig.savefig('../results/figures/sfig_6a.png', dpi=300, bbox_inches='tight', transparent=True)

## 2D band crossing

### Data generation

In [None]:
# Generate synthetic data: two crossing parabolic bands on a uniform k grid.
kx = np.arange(-1, 1, 0.014)
band_pb2d_up = 5 * kx**2 - kx / 5 - 5             # steep, opens upward
band_pb2d_down = -(0.2 * kx**2 + kx / 4 + 2.5)    # shallow, opens downward
plt.figure(figsize=(5, 4))

# Energy window. NOTE: the lower bound uses band_pb2d_down.min() (about -2.9),
# while band_pb2d_up dips to about -5 near k = 0. The 4-unit bottom margin is
# what keeps the up band inside the window.
b2d_min = band_pb2d_down.min()
b2d_max = band_pb2d_up.max()
Evals = np.arange(b2d_min - 4, b2d_max + 2, 0.012)

# One Voigt line per band, with energy along the rows. The up band uses
# amp 1.6 / gam 0.15; the down band uses amp 1 / gam 0.1.
energy_col = Evals[:, None]
pes_data_2d_up = aly.voigt(feval=True, vardict=dict(
    amp=1.6, xvar=energy_col, ctr=band_pb2d_up, sig=0.07, gam=0.15))
pes_data_2d_down = aly.voigt(feval=True, vardict=dict(
    amp=1, xvar=energy_col, ctr=band_pb2d_down, sig=0.07, gam=0.1))
pes_data_2d = pes_data_2d_up + pes_data_2d_down

# Preview: rows are flipped so the highest energy is drawn at the top.
plt.imshow(pes_data_2d[::-1, :], aspect=0.2,
           extent=[-1, 1, b2d_min - 4, b2d_max + 2], cmap='Blues')

In [None]:
# Construct initialization: a shallower parabola for the up band and a flat
# line at E = -3 for the down band.
band_init2d_up = 3.5 * kx**2 - kx / 20 - 4
band_init2d_down = np.full_like(kx, -3.0)

# Draw each pair 'up' curve first. Only the 'down' curve carries the legend
# label, so each colour appears once in the legend.
for colour, up_curve, down_curve, name in (
        ('r', band_pb2d_up, band_pb2d_down, 'ground truth'),
        ('b', band_init2d_up, band_init2d_down, 'initialization')):
    plt.plot(kx, up_curve, c=colour)
    plt.plot(kx, down_curve, c=colour, label=name)
plt.legend(loc='upper center', fontsize=12)

### Reconstruction

In [None]:
# Reconstruct first band (the 'down' band).
# Intensity normalised to a peak value of 1 and transposed to (k, E). I is
# reused unchanged by the second-band reconstruction.
I = pes_data_2d.T / pes_data_2d.max()
mrf = fuller.mrfRec.MrfRec(
    E=Evals,
    kx=kx,
    ky=np.array([0.]),
    I=I.reshape(len(kx), 1, len(Evals)),   # (kx, single ky, E)
    eta=0.085,
    E0=band_init2d_down[:, None],          # seeded with the flat down-band guess
)
mrf.iter_seq(500)
recon_down = mrf.getEb()[:, 0]

In [None]:
# Reconstruct second band (the 'up' band) from the same normalised intensity I,
# seeded with the parabolic up-band guess and run with eta=0.2.
intensity_3d = I.reshape(len(kx), 1, len(Evals))   # (kx, single ky, E)
mrf = fuller.mrfRec.MrfRec(
    E=Evals,
    kx=kx,
    ky=np.array([0.]),
    I=intensity_3d,
    eta=0.2,
    E0=band_init2d_up[:, None],
)
mrf.iter_seq(500)
recon_up = mrf.getEb()[:, 0]

### Supplementary Figure 6b

In [None]:
# Summary plot for Supplementary Figure 6b. Left: two-band synthetic intensity
# with both ground-truth bands. Right: ground truth, initialization and MRF
# reconstruction for both bands.
emin, emax = Evals.min(), Evals.max()

gs = gridspec.GridSpec(1, 2, width_ratios=[8, 8])
fig = plt.figure()
axs = [fig.add_subplot(spec) for spec in (gs[0], gs[1])]
ax_img, ax_lines = axs

# Left panel: (energy, k) image with row 0 at the bottom; colour scale capped at vmax=2.
im = ax_img.imshow(pes_data_2d, aspect=0.2, extent=[-1, 1, emin, emax],
                   cmap='Blues', origin='lower', vmax=2)
for band in (band_pb2d_up, band_pb2d_down):
    ax_img.plot(kx, band, 'r')
ax_img.set_aspect(aspect=0.2)
ax_img.set_ylabel('Energy (a.u.)', fontsize=15)
cax = fig.add_axes([0.94, 0.54, 0.03, 0.2])
cb = fig.colorbar(im, cax=cax, orientation='vertical', ticks=[])
cb.ax.set_ylabel('Intensity', fontsize=15, rotation=-90, labelpad=18)

# Right panel: each pair is drawn 'up' curve first. Only the 'down' curve
# carries the legend label, so each colour appears once.
curve_sets = (
    ('r', band_pb2d_up, band_pb2d_down, 'ground truth'),
    ('b', band_init2d_up, band_init2d_down, 'initialization'),
    ('g', recon_up, recon_down, 'reconstruction'),
)
for fmt, up_curve, down_curve, name in curve_sets:
    ax_lines.plot(kx, up_curve, fmt)
    ax_lines.plot(kx, down_curve, fmt, label=name)
ax_lines.set_xlim([-1, 1])
ax_lines.set_ylim([emin, emax])
ax_lines.set_aspect(aspect=0.2)
ax_lines.set_yticks([])
lg = ax_lines.legend(fontsize=15, bbox_to_anchor=(1.04, 0.5), frameon=False,
                     borderpad=0, labelspacing=0.8, handlelength=1.2,
                     handletextpad=0.5)

for ax in axs:
    ax.set_xlabel('$k$ (a.u.)', fontsize=15)
    ax.tick_params(axis='both', length=8, width=2, labelsize=15)

plt.subplots_adjust(wspace=0.1)
plt.savefig('../results/figures/sfig_6b.png', dpi=300, bbox_inches='tight', transparent=True)