In [None]:
import sympy as sp
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import linear_sum_assignment

from matplotlib.collections import LineCollection
from matplotlib.colors import Normalize

In [None]:
def _simplify_entries(matrix):
    """Rewrite every entry of `matrix` in terms of cos and simplify it; returns a new matrix."""
    # applyfunc always returns a new matrix, whereas Matrix.simplify() on a
    # *mutable* sympy Matrix works in place and returns None.
    return matrix.applyfunc(lambda entry: sp.simplify(entry.rewrite(sp.cos)))


def _fourier_component(h_symbolic, k_sym, n):
    """Return T_n = (1/2pi) * integral_{-pi}^{pi} h(k) exp(-i n k) dk, entry-wise simplified.

    T_n is the coefficient of exp(i n k) in h(k).
    """
    # Exact integer frequency keeps Floats out of the Fourier coefficients.
    integrand = h_symbolic * sp.exp(-sp.I * n * k_sym)
    component = sp.integrate(integrand, (k_sym, -sp.pi, sp.pi)) / (2 * sp.pi)
    return _simplify_entries(component)


def _assemble_slab(blocks, n_sites, n_bands):
    """Place blocks[n] on block (r, r + n) for every site r; blocks falling outside the slab are dropped."""
    slab = sp.zeros(n_sites * n_bands, n_sites * n_bands)
    for row in range(n_sites):
        for n, block in blocks.items():
            col = row + n
            if 0 <= col < n_sites:
                slab[row*n_bands:(row+1)*n_bands, col*n_bands:(col+1)*n_bands] = block
    return slab


def get_surface_hamiltonian(h_symbolic, ksymbols, params, direction = "x", pert = None, max_range = 2):
    """Build a numerical slab (ribbon) Hamiltonian from a symbolic 2D Bloch Hamiltonian.

    The Bloch Hamiltonian is Fourier-decomposed along the open direction k as
    h(k) = sum_n T_n exp(i n k), with T_n = (1/2pi) * integral h(k) exp(-i n k) dk
    over [-pi, pi]. The slab matrix carries the on-site block T_0 on its diagonal
    and T_n on block (r, r + n) for 0 < |n| <= max_range. Hoppings that would
    leave the slab are dropped (open boundaries).

    Parameters
    ----------
    h_symbolic : sympy Matrix
        Bloch Hamiltonian in the symbols ``ksymbols``.
    ksymbols : sequence of two sympy Symbols
        ``(kx, ky)``.
    params : dict
        Must contain ``'Nbands'`` (block size) and ``'Nx'`` (sites, for
        ``direction="x"``) or ``'Ny'`` (sites, for ``direction="y"``).
    direction : {"x", "y"}
        Direction with open boundaries. ``"x"`` yields a function of ky,
        ``"y"`` a function of kx.
    pert : truthy or None
        Only used for ``direction="y"``; ignored for ``"x"``. When truthy, adds
        (1 - cos kx)(1 + sz) to the first site and -(1 - cos kx)(1 - sz) to the
        last site. Uses a 2x2 sz, so it requires Nbands == 2.
    max_range : int, default 2
        Largest hopping distance |n| kept. Any harmonic beyond it stays in T_0
        and therefore still depends on the open-direction momentum.

    Returns
    -------
    callable
        ``k -> numpy array`` of shape (N*Nbands, N*Nbands), produced by
        ``sympy.lambdify``.

    Raises
    ------
    ValueError
        If ``direction`` is not "x" or "y", or ``max_range`` is negative.
    """
    kx_sym, ky_sym = ksymbols
    Nbands = params['Nbands']

    # k_open: momentum along the open direction (Fourier-transformed away);
    # k_free: momentum that remains a good quantum number.
    if direction == "x":
        k_open, k_free, n_sites = kx_sym, ky_sym, params['Nx']
    elif direction == "y":
        k_open, k_free, n_sites = ky_sym, kx_sym, params['Ny']
    else:
        raise ValueError("Invalid direction")

    if max_range < 0:
        raise ValueError("max_range must be non-negative")

    # Hopping blocks T_n for 0 < |n| <= max_range.
    hoppings = {n: _fourier_component(h_symbolic, k_open, n)
                for n in range(-max_range, max_range + 1) if n != 0}

    # On-site block T_0: what remains of h after removing the hopping harmonics.
    harmonics = sp.zeros(*h_symbolic.shape)
    for n, T_n in hoppings.items():
        harmonics += T_n * sp.exp(sp.I * n * k_open)
    onsite = _simplify_entries(h_symbolic - harmonics)
    # Snap any remaining Floats to nearby rationals, entry by entry.
    onsite = onsite.applyfunc(lambda entry: sp.nsimplify(entry, tolerance = 1e-8, rational = True))

    slab = _assemble_slab({0: onsite, **hoppings}, n_sites, Nbands)

    if pert and direction == "y":
        sigma_z = sp.Matrix([[1, 0], [0, -1]])
        weight = 1 - sp.cos(k_free)
        first = slice(0, Nbands)
        last = slice((n_sites - 1) * Nbands, n_sites * Nbands)
        slab[first, first] += weight * (sp.eye(Nbands) + sigma_z)
        slab[last, last] += -weight * (sp.eye(Nbands) - sigma_z)

    return sp.lambdify(k_free, slab, modules = "numpy")

In [None]:
# Identity and Pauli matrices as exact sympy matrices (orbital space of the two-band model).
s0, sx, sy, sz = (
    sp.eye(2),
    sp.Matrix([[0, 1],
               [1, 0]]),
    sp.Matrix([[0, -sp.I],
               [sp.I, 0]]),
    sp.Matrix([[1, 0],
               [0, -1]]),
)

In [None]:
# Real crystal momenta and model parameters.
kx_sym = sp.Symbol('k_x', real = True)
ky_sym = sp.Symbol('k_y', real = True)
ksymbols = [kx_sym, ky_sym]

alpha = sp.Symbol('alpha', real = True, positive = True)
gamma_z = sp.Symbol('gamma_z', real = True)
lambda_z = sp.Symbol('lambda_z', real = True)

In [None]:
# Two-band Bloch Hamiltonian h(kx, ky) = d_x sx + d_y sy + d_z sz.
d_x = sp.sin(2*kx_sym)
d_y = sp.sin(kx_sym) * sp.sin(ky_sym)
d_z = -(alpha + sp.cos(2*kx_sym) + sp.cos(ky_sym))
hrtp = d_x * sx + d_y * sy + d_z * sz

In [None]:
# Number of k_x samples (also params['Nx']), ribbon width in sites, orbitals per site,
# and occupied bands per site.
Nx, Ny = 100, 21
Nbands, Nocc = 2, 1

params = {
    "Nx": Nx,
    "Ny": Ny,
    "Nz": 0,
    "Nbands": Nbands,
    "Nocc": Nocc,
}

In [None]:
# Uniform momentum grids on [0, 2*pi], both endpoints included.
Kxs, Kys = (np.linspace(0, 2*np.pi, n_points) for n_points in (Nx, Ny))

In [None]:
# Fix alpha = 1 and build a numeric Bloch Hamiltonian h(kx, ky).
H_fixparam = hrtp.subs(alpha, 1)
hfunc = sp.lambdify((kx_sym, ky_sym), H_fixparam, modules = "numpy")

In [None]:
# Ribbon open along y (a function of kx), without the edge perturbation.
slab_hfunc_Y = get_surface_hamiltonian(
    H_fixparam, ksymbols, params,
    direction = "y",
    pert = None,
)

In [None]:
# Diagonalise the ribbon Hamiltonian at every k_x in one batched call.
# np.linalg.eigh already returns eigenvalues in ascending order (eigenvectors as
# columns), so no extra argsort is needed. A non-stable argsort could only
# shuffle eigenvectors within exactly degenerate levels.
slab_stack_y = np.stack([slab_hfunc_Y(kx) for kx in Kxs])
eigenvalues_y, eigenstates_y = np.linalg.eigh(slab_stack_y)
eigenstates_y = eigenstates_y.astype(np.complex128, copy = False)  # complex dtype expected downstream
assert eigenvalues_y.shape == (Nx, Ny*Nbands)

In [None]:
# Band tracking: at each k_x, permute the bands so that each column follows the
# eigenvector with the largest overlap at the previous k_x (optimal assignment).
# Afterwards the columns are no longer ordered by energy.
for k_idx in range(1, Nx):
    prev_states = eigenstates_y[k_idx - 1]
    curr_states = eigenstates_y[k_idx]

    overlap = np.abs(prev_states.conj().T @ curr_states)
    _, perm = linear_sum_assignment(-overlap)

    eigenvalues_y[k_idx] = eigenvalues_y[k_idx][perm]
    eigenstates_y[k_idx] = curr_states[:, perm]

In [None]:
# Ribbon spectrum E(k_x) without the edge perturbation; each column of eigenvalues_y is one band.
fig, ax = plt.subplots(figsize = (4,3))
for band_energies in eigenvalues_y.T:
    ax.plot(Kxs, band_energies, color = "black", alpha = 0.2)

ax.set_xlabel(r"$k_x$", fontsize = 16, labelpad = 0)
ax.set_ylabel(r"$E$", fontsize = 16, labelpad = -2)
ax.set_xticks([0, np.pi/2, np.pi, 3*np.pi/2, 2*np.pi], [r"$0$", r"$\pi/2$", r"$\pi$", r"$3\pi/2$", r"$2\pi$"], fontsize = 14)
ax.set_yticks([-2, 0, 2], [-2, 0, 2], fontsize = 14)
fig.tight_layout()
plt.show()

In [None]:
# Same ribbon, now with the edge perturbation on the first and last sites.
slab_hfunc_Y = get_surface_hamiltonian(
    H_fixparam, ksymbols, params,
    direction = "y",
    pert = True,
)

In [None]:
# Diagonalise the perturbed ribbon Hamiltonian at every k_x in one batched call.
# np.linalg.eigh already returns eigenvalues in ascending order (eigenvectors as
# columns), so no extra argsort is needed. A non-stable argsort could only
# shuffle eigenvectors within exactly degenerate levels.
slab_stack_y = np.stack([slab_hfunc_Y(kx) for kx in Kxs])
eigenvalues_y, eigenstates_y = np.linalg.eigh(slab_stack_y)
eigenstates_y = eigenstates_y.astype(np.complex128, copy = False)  # complex dtype expected downstream
assert eigenvalues_y.shape == (Nx, Ny*Nbands)

In [None]:
# Band tracking for the perturbed ribbon: permute bands at each k_x to follow the
# eigenvector with the largest overlap at the previous k_x (optimal assignment).
for k_idx in range(1, Nx):
    prev_states = eigenstates_y[k_idx - 1]
    curr_states = eigenstates_y[k_idx]

    overlap = np.abs(prev_states.conj().T @ curr_states)
    _, perm = linear_sum_assignment(-overlap)

    eigenvalues_y[k_idx] = eigenvalues_y[k_idx][perm]
    eigenstates_y[k_idx] = curr_states[:, perm]

In [None]:
# Ribbon spectrum E(k_x) with the edge perturbation; each column of eigenvalues_y is one band.
fig, ax = plt.subplots(figsize = (4,3))
for band in range(Ny*Nbands):
    ax.plot(Kxs, eigenvalues_y[:, band], color = "black", alpha = 0.2)

ax.set_xlabel(r"$k_x$", fontsize = 14)
ax.set_ylabel(r"$E$", fontsize = 14)
ax.set_xticks([0, np.pi/2, np.pi, 3*np.pi/2, 2*np.pi], [r"$0$", r"$\pi/2$", r"$\pi$", r"$3\pi/2$", r"$2\pi$"], fontsize = 10)
ax.set_yticks([-2, 0, 2], [-2, 0, 2], fontsize = 10)
fig.tight_layout()
plt.show()

In [None]:
# Projectors onto the occupied (P) and unoccupied (Q) ribbon states at every k_x.
# The first Ny*Nocc columns of eigenstates_y are treated as occupied and the rest as unoccupied.
# NOTE(review): after the band-tracking cell, columns follow band continuity rather
# than energy order, so "occupied" means the bands connected to the lowest Ny*Nocc
# bands at Kxs[0] — confirm this is intended.
n_occ_states = Ny * Nocc
occ_states = eigenstates_y[:, :, :n_occ_states]
unocc_states = eigenstates_y[:, :, n_occ_states:]

# sum_j |v_j><v_j| = V V^dagger, evaluated for all k_x at once.
Ps = occ_states @ occ_states.conj().transpose(0, 2, 1)        # occupied subspace
Qs = unocc_states @ unocc_states.conj().transpose(0, 2, 1)    # unoccupied subspace

# The eigenbasis is complete, so P + Q must be the identity at every k_x.
assert np.allclose(Ps + Qs, np.eye(Ny*Nbands))

In [None]:
# Position operator along y for every k_x: site index (1..Ny) times the identity on
# the orbitals, plus a small k_x-dependent sz term on the last site.
# NOTE(review): the edge term presumably splits the orbital degeneracy of y on the
# last site so the projected-position eigenstates there are sz-resolved — confirm.
Y_EDGE_SZ = 0.02

# `sz` is a sympy Matrix. Subtracting it in place from a float64 array goes through
# an object-dtype operand, which numpy's same_kind output casting rejects, so
# convert it to floats first.
sz_num = np.array(sz, dtype = float)

site_positions = np.repeat(np.arange(1, Ny + 1, dtype = float), Nbands)
yops = np.tile(np.diag(site_positions), (Nx, 1, 1))

last_site = slice((Ny - 1) * Nbands, Ny * Nbands)
yops[:, last_site, last_site] -= Y_EDGE_SZ * np.sin(Kxs / 2)[:, None, None] * sz_num

In [None]:
# Position operator projected onto the occupied (P y P) and unoccupied (Q y Q) subspaces, per k_x.
Projpos = np.array([P @ Y @ P for P, Y in zip(Ps, yops)], dtype = np.complex128)
Qrojpos = np.array([Q @ Y @ Q for Q, Y in zip(Qs, yops)], dtype = np.complex128)

In [None]:
# Eigen-decomposition of P y P at every k_x. np.linalg.eigh handles the whole stack
# and returns eigenvalues in ascending order with eigenvectors as columns, so the
# previous (redundant, non-stable) argsort is unnecessary.
Projpos_eigenvalues, Projpos_eigenstates = np.linalg.eigh(Projpos)
Projpos_eigenstates = Projpos_eigenstates.astype(np.complex128, copy = False)

In [None]:
# Eigen-decomposition of Q y Q at every k_x. np.linalg.eigh handles the whole stack
# and returns eigenvalues in ascending order with eigenvectors as columns, so the
# previous (redundant, non-stable) argsort is unnecessary.
Qrojpos_eigenvalues, Qrojpos_eigenstates = np.linalg.eigh(Qrojpos)
Qrojpos_eigenstates = Qrojpos_eigenstates.astype(np.complex128, copy = False)

In [None]:
# sz on every site of the ribbon: block-diagonal matrix with sz in each Nbands x Nbands block.
slabMz = np.zeros((Ny*Nbands, Ny*Nbands))
for site in range(Ny):
    block = slice(site*Nbands, (site + 1)*Nbands)
    slabMz[block, block] = sz

In [None]:
# <v| sz |v> for the upper Ny eigenvectors (columns Ny, Ny+1, ...) of P y P and Q y Q at each k_x.
Mz_vals_P = np.zeros((Nx, Ny))
Mz_vals_Q = np.zeros((Nx, Ny))

for k_idx in range(Nx):
    P_upper = Projpos_eigenstates[k_idx, :, Ny:]
    Q_upper = Qrojpos_eigenstates[k_idx, :, Ny:]
    for col in range(Ny):
        vP = P_upper[:, col]
        Mz_vals_P[k_idx, col] = np.dot(vP.conj(), slabMz @ vP).real

        vQ = Q_upper[:, col]
        Mz_vals_Q[k_idx, col] = np.dot(vQ.conj(), slabMz @ vQ).real

In [None]:
fig, axs = plt.subplots(figsize=(3.5,3.5))

skip = 2            # plot every `skip`-th k_x point
x = Kxs[::skip]
# Colour scale for <sz>, shared by every line.
norm = Normalize(vmin=-1, vmax=1)
cmap = plt.get_cmap('viridis')


def _add_colored_line(ax, xs, ys, point_values, color_norm, colormap, segment_stride=1):
    """Draw ys(xs) on `ax` as straight segments coloured by `point_values`.

    Segment i joins points i and i+1 and is coloured by point_values[i]. With
    segment_stride > 1 only every segment_stride-th segment is drawn (dashed look).
    Returns the added LineCollection.
    """
    # Invisible line: appears to exist only so the axes autoscale to this data.
    ax.plot(xs, ys, alpha = 0)
    points = np.array([xs, ys]).T.reshape(-1, 1, 2)
    segments = np.concatenate([points[:-1], points[1:]], axis=1)
    # One colour value per segment: the last point's value has no segment of its own.
    segment_values = np.asarray(point_values)[:-1]
    lc = LineCollection(segments[::segment_stride], cmap=colormap, norm=color_norm, alpha = 1, linewidth = 2)
    lc.set_array(segment_values[::segment_stride])
    ax.add_collection(lc)
    return lc


# Four largest eigenvalues of P y P, drawn solid and coloured by <sz>.
for j in range(Ny-4, Ny):
    _add_colored_line(axs, x, Projpos_eigenvalues[::skip, Ny+j], Mz_vals_P[::skip, j], norm, cmap)

# Three largest eigenvalues of Q y Q, drawn with every other segment (dashed look).
# NOTE(review): four P bands vs three Q bands — confirm this asymmetry is intended.
for j in range(Ny-3, Ny):
    _add_colored_line(axs, x, Qrojpos_eigenvalues[::skip, Ny+j], Mz_vals_Q[::skip, j], norm, cmap, segment_stride=2)

axs.set_xlabel(r"$k_x$", fontsize = 16, labelpad = 0)
axs.set_xticks([0, np.pi/2, np.pi, 3*np.pi/2, 2*np.pi], [r"$0$", r"$\pi/2$", r"$\pi$", r"$3\pi/2$", r"$2\pi$"], fontsize = 14)
axs.set_ylabel(r"$y$", fontsize = 16, labelpad = -7)
axs.set_yticks([Ny,Ny-1],[Ny,Ny-1], fontsize = 14)
plt.tight_layout()
plt.show()