In [None]:
import sympy as sp
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import linear_sum_assignment

In [None]:
def get_U1Gauge(states):
    """Normalised U(1) link variables between neighbouring points of a periodic 2-D k-grid.

    Parameters
    ----------
    states : ndarray, shape (Nx, Ny, Nbands)
        One state vector per grid point (e.g. a single band). The grid is
        treated as periodic: the neighbour of the last row/column is the
        first one.

    Returns
    -------
    ndarray of complex128, shape (Nx, Ny, 2)
        ``[i, j, 0]`` is the unit-modulus phase of <u(i+1, j) | u(i, j)> and
        ``[i, j, 1]`` that of <u(i, j+1) | u(i, j)> (indices taken modulo Nx, Ny).

    Raises
    ------
    ValueError
        If ``states`` is not 3-D, or if any neighbouring overlap is exactly
        zero (the link phase is then undefined; normalising it would silently
        yield NaN and poison every quantity built from the links).

    NOTE(review): the bra is the *shifted* state, so these links are the
    complex conjugates of the Fukui-Hatsugai-Suzuki convention
    U_mu(k) = <u(k)|u(k+mu)>/|...|; a Chern number built from them carries
    the opposite sign. Confirm this orientation is intended.
    """
    states = np.asarray(states)
    (Nx, Ny, _) = states.shape

    # np.roll(..., -1, axis) puts states[(i+1) % N] at index i, i.e. the
    # periodic neighbour of every grid point in one shot.
    shifted_x = np.roll(states, -1, axis=0)
    shifted_y = np.roll(states, -1, axis=1)

    overlaps = np.empty((Nx, Ny, 2), dtype=np.complex128)
    overlaps[..., 0] = np.sum(np.conj(shifted_x) * states, axis=-1)
    overlaps[..., 1] = np.sum(np.conj(shifted_y) * states, axis=-1)

    magnitudes = np.abs(overlaps)
    vanishing = magnitudes == 0
    if np.any(vanishing):
        i, j, mu = (int(v) for v in np.argwhere(vanishing)[0])
        raise ValueError(
            f"vanishing overlap between neighbouring states at grid point ({i}, {j}) "
            f"along direction {'xy'[mu]}; the U(1) link is undefined "
            "(refine the k-grid or check for band degeneracies)"
        )

    return overlaps / magnitudes

def get_numerical_BerryCurvature(states):
    """Lattice Berry curvature (plaquette log) on a periodic 2-D k-grid.

    Parameters
    ----------
    states : ndarray, shape (Nx, Ny, Nbands)
        One state vector per grid point; passed straight to ``get_U1Gauge``.

    Returns
    -------
    ndarray of complex128, shape (Nx, Ny)
        Principal-branch ``np.log`` of the product of the four links around
        the plaquette (i, j) -> (i+1, j) -> (i+1, j+1) -> (i, j+1), indices
        wrapped periodically. Summing and dividing by 2*pi*i gives the lattice
        Chern number in the link orientation used by ``get_U1Gauge``.
    """
    links = get_U1Gauge(states)
    link_x = links[:, :, 0]
    link_y = links[:, :, 1]

    # Entry (i, j) of these rolled arrays is the link at ((i+1) % Nx, j) and
    # (i, (j+1) % Ny) respectively -- the far edges of plaquette (i, j).
    link_y_right = np.roll(link_y, -1, axis=0)
    link_x_top = np.roll(link_x, -1, axis=1)

    plaquette = link_x * link_y_right / link_x_top / link_y
    return np.log(plaquette)

In [None]:
def _lattice_fourier_block(h_symbolic, k_sym, n):
    """Return T_n = (1/2pi) * int_{-pi}^{pi} h(k) exp(-i n k) dk, entrywise, in cos form and simplified."""
    def _entry(entry):
        # An exact integer frequency keeps the integral exact; a float factor
        # such as -1.0 would leak floating-point coefficients into the hoppings.
        integral = sp.integrate(entry * sp.exp(-sp.I * n * k_sym), (k_sym, -sp.pi, sp.pi))
        return sp.simplify((integral / (2 * sp.pi)).rewrite(sp.cos))

    # applyfunc returns a new matrix whether h_symbolic is mutable or immutable.
    return h_symbolic.applyfunc(_entry)


def _open_boundary_slab(h_symbolic, k_open, n_sites, n_bands, max_hop):
    """Assemble the (n_sites*n_bands)-square SymPy slab matrix with open boundaries along k_open."""
    # Hopping blocks T_{+n} and T_{-n} for every range n = 1..max_hop.
    blocks = {}
    for n in range(1, max_hop + 1):
        blocks[n] = _lattice_fourier_block(h_symbolic, k_open, n)
        blocks[-n] = _lattice_fourier_block(h_symbolic, k_open, -n)

    # On-site block T_0: h(k) minus the resolved harmonics sum_n T_n exp(i n k).
    resolved = sp.zeros(n_bands, n_bands)
    for n, block in blocks.items():
        resolved = resolved + block * sp.exp(sp.I * n * k_open)
    onsite = (h_symbolic - resolved).applyfunc(lambda e: sp.simplify(e.rewrite(sp.cos)))
    onsite = onsite.applyfunc(lambda e: sp.nsimplify(e, tolerance = 1e-8))

    if k_open in onsite.free_symbols:
        # Harmonics beyond max_hop would otherwise stay in the on-site block,
        # leaving k_open as an unbound name inside the lambdified function.
        raise ValueError(
            f"h_symbolic has Fourier components beyond {max_hop} lattice spacings "
            f"along {k_open}; increase max_hop"
        )
    blocks[0] = onsite

    # Block (row j, column j') of the slab is T_{j'-j}; blocks falling
    # outside the slab are dropped, which realises the open boundary.
    size = n_sites * n_bands
    slab = sp.zeros(size, size)
    for col in range(n_sites):
        for offset, block in blocks.items():
            row = col - offset
            if 0 <= row < n_sites:
                slab[row*n_bands:(row+1)*n_bands, col*n_bands:(col+1)*n_bands] = block
    return slab


def get_surface_hamiltonian(h_symbolic, ksymbols, params, direction = "x", max_hop = 2):
    """Build a slab Hamiltonian: open boundaries along one direction, Bloch momentum along the other.

    The Bloch Hamiltonian is expanded in lattice harmonics along the open
    direction, h(k) = sum_n T_n exp(i n k) with
    T_n = (1/2pi) * int_{-pi}^{pi} h(k) exp(-i n k) dk, and block
    (row j, column j') of the slab matrix is T_{j'-j} for |j' - j| <= max_hop.

    Parameters
    ----------
    h_symbolic : sympy Matrix, shape (Nbands, Nbands)
        Bloch Hamiltonian written in the momentum symbols ``ksymbols``.
    ksymbols : (k_x, k_y)
        SymPy momentum symbols used in ``h_symbolic``.
    params : dict
        Must contain 'Nbands', plus 'Nx' (direction "x") or 'Ny'
        (direction "y") giving the number of slab sites.
    direction : {"x", "y"}
        Direction with open boundaries.
    max_hop : int, default 2
        Longest hopping range (lattice spacings) resolved along the open direction.

    Returns
    -------
    callable
        NumPy function of the remaining momentum (k_y for "x", k_x for "y")
        returning the (N*Nbands, N*Nbands) slab matrix.

    Raises
    ------
    ValueError
        Invalid direction; h_symbolic not (Nbands, Nbands); negative max_hop;
        or h_symbolic contains harmonics beyond max_hop along the open direction.
    """
    kx_sym, ky_sym = ksymbols

    if direction == "x":
        k_open, k_keep, n_sites = kx_sym, ky_sym, params['Nx']
    elif direction == "y":
        k_open, k_keep, n_sites = ky_sym, kx_sym, params['Ny']
    else:
        raise ValueError("Invalid direction")

    n_bands = params['Nbands']
    if tuple(h_symbolic.shape) != (n_bands, n_bands):
        raise ValueError(
            f"h_symbolic has shape {tuple(h_symbolic.shape)}, expected ({n_bands}, {n_bands})"
        )
    if max_hop < 0:
        raise ValueError("max_hop must be non-negative")

    slab = _open_boundary_slab(h_symbolic, k_open, n_sites, n_bands, max_hop)
    return sp.lambdify(k_keep, slab, modules = "numpy")

In [None]:
# Identity and Pauli matrices, used as the 2x2 basis for the two-band model below.
s0 = sp.eye(2)
sx, sy, sz = (
    sp.Matrix(2, 2, [0, 1, 1, 0]),
    sp.Matrix(2, 2, [0, -sp.I, sp.I, 0]),
    sp.Matrix(2, 2, [1, 0, 0, -1]),
)

In [None]:
# Crystal momenta (declared real) and the mass parameter alpha (real, positive).
kx_sym = sp.Symbol('k_x', real = True)
ky_sym = sp.Symbol('k_y', real = True)
ksymbols = [kx_sym, ky_sym]
alpha = sp.Symbol('alpha', real = True, positive = True)

In [None]:
# Two-band "dartboard" Bloch Hamiltonian:
#   H(k) = -sin(kx) sin(2ky) sx + sin(2kx) sin(ky) sy + (alpha + cos 2kx + cos 2ky) sz
H_dartboard = (
    - sp.sin(kx_sym) * sp.sin(2*ky_sym) * sx
    + sp.sin(2*kx_sym) * sp.sin(ky_sym) * sy
    + (alpha + sp.cos(2*kx_sym) + sp.cos(2*ky_sym)) * sz
)

In [None]:
Nx = 100      # kx mesh points; also the slab width when direction="x"
Ny = 30       # ky mesh points; also the number of layers of the direction="y" slab
Nbands = 2    # size of the Bloch Hamiltonian
Nocc = 1      # number of occupied bands

# 2-D model, hence no z extent.
params = dict(Nx = Nx, Ny = Ny, Nz = 0, Nbands = Nbands, Nocc = Nocc)

In [None]:
# Uniform periodic meshes on [0, 2*pi); the endpoint is dropped so 0 and 2*pi are not double-counted.
Kxs, Kys = (np.linspace(0, 2*np.pi, n_pts, endpoint = False) for n_pts in (Nx, Ny))

In [None]:
# Fix alpha = 1 and compile H(kx, ky) into a NumPy-callable.
H_fixparam = H_dartboard.subs(alpha, 1)
hfunc = sp.lambdify([kx_sym, ky_sym], H_fixparam, modules = "numpy")

In [None]:
# Diagonalise the Bloch Hamiltonian at every mesh point, bands in ascending energy;
# eigenstates[i, j, :, b] is the eigenvector of band b.
eigenvalues = np.zeros((Nx, Ny, Nbands))
eigenstates = np.zeros((Nx, Ny, Nbands, Nbands), dtype = np.complex128)

for (i, j) in np.ndindex(Nx, Ny):
    energies, vectors = np.linalg.eigh(hfunc(Kxs[i], Kys[j]))
    order = np.argsort(energies)
    eigenvalues[i, j] = energies[order]
    eigenstates[i, j] = vectors[:, order]

In [None]:
slab_hfunc_Y = get_surface_hamiltonian(H_fixparam, ksymbols, params, direction = "y")  # open along y; callable of kx

In [None]:
# Spectrum of the y-slab at each kx: Ny*Nbands levels per momentum, ascending.
eigenvalues_y = np.zeros((Nx, Ny*Nbands))
eigenstates_y = np.zeros((Nx, Ny*Nbands, Ny*Nbands), dtype = np.complex128)

# Each row/slab of the output arrays is a view, so writing into it fills the arrays in place.
for kx_val, level_row, vector_block in zip(Kxs, eigenvalues_y, eigenstates_y):
    energies, vectors = np.linalg.eigh(slab_hfunc_Y(kx_val))
    order = np.argsort(energies)
    level_row[:] = energies[order]
    vector_block[:] = vectors[:, order]

In [None]:
# Reorder the slab levels at each kx so that column b continues the state in
# column b at the previous kx: maximise total |overlap| via a linear assignment.
for i in range(1, Nx):
    overlap = np.abs(eigenstates_y[i-1].conj().T @ eigenstates_y[i])
    _, perm = linear_sum_assignment(-overlap)
    eigenvalues_y[i] = eigenvalues_y[i][perm]
    eigenstates_y[i] = eigenstates_y[i][:, perm]

In [None]:
# Slab spectrum E(kx) of the direction="y" ribbon: one translucent line per tracked level.
fig, ax = plt.subplots(figsize = (4, 3))
for level in eigenvalues_y.T:
    ax.plot(Kxs, level, color = "black", alpha = 0.3)

ax.set_xlabel(r"$k_x$", fontsize = 16, labelpad = -3)
ax.set_ylabel(r"$E$", fontsize = 16, labelpad = -7)
ax.set_xticks([0, np.pi/2, np.pi, 3*np.pi/2, 2*np.pi])
ax.set_xticklabels([r"$0$", r"$\pi/2$", r"$\pi$", r"$3\pi/2$", r"$2\pi$"], fontsize = 14)
ax.set_yticks([-3, 0, 3])
ax.tick_params(axis = "y", labelsize = 14)
plt.show()

In [None]:
# Lowest band (column 0) on the first half of each mesh, i.e. kx, ky in [0, pi).
# NOTE(review): get_numerical_BerryCurvature wraps indices periodically, so this
# sub-grid is treated as a closed torus -- confirm that is intended for a quarter zone.
Nx2 = Nx // 2
Ny2 = Ny // 2

cornerBZ_states = eigenstates[:Nx2, :Ny2, :, 0]

In [None]:
# Summed plaquette logs over 2*pi*i: the lattice Chern number of the sub-grid.
# On a periodically wrapped grid it is an integer up to floating-point error;
# the displayed value is complex with a near-zero imaginary part.
bc = get_numerical_BerryCurvature(cornerBZ_states)
(bc / (2j * np.pi)).sum()