In [None]:
import sympy as sp
import numpy as np
from sympy.physics.quantum import TensorProduct
import matplotlib.pyplot as plt
from scipy.optimize import linear_sum_assignment

from matplotlib.collections import LineCollection
from matplotlib.colors import Normalize

In [None]:
def indf(j, k, nz=None, nbands=None):
    """Return the flat matrix index at which the orbital block of site (j, k) starts.

    Sites are flattened with ``j`` as the slow index and ``k`` as the fast
    index; every site owns ``nbands`` consecutive rows/columns.

    Parameters
    ----------
    j, k : int
        Site indices (``j`` along y, ``k`` along z in the wire geometry).
    nz : int, optional
        Number of sites along the fast direction. Falls back to the
        notebook-level ``Nz`` when omitted.
    nbands : int, optional
        Number of orbitals per site. Falls back to the notebook-level
        ``Nbands`` when omitted.

    Returns
    -------
    int
        ``j * nz * nbands + k * nbands``.
    """
    # The global fallback keeps existing two-argument calls working; passing the
    # sizes explicitly removes the dependence on cell execution order.
    if nz is None:
        nz = Nz
    if nbands is None:
        nbands = Nbands
    return j * nz * nbands + k * nbands

In [None]:
def get_wire_hamiltonian(h_symbolic, syms, params, corner_pert = None):
    """Build a numerical wire Hamiltonian H(kx) from a symbolic Bloch Hamiltonian.

    The y and z directions are turned into a finite ``Ny x Nz`` lattice with
    open boundaries while ``kx`` stays a continuous argument. Only the onsite
    part and the first harmonics ``exp(+-i ky)`` / ``exp(+-i kz)`` of
    ``h_symbolic`` are extracted, so the model is expected to contain nothing
    beyond nearest-neighbour hopping along y and z.

    Parameters
    ----------
    h_symbolic : sympy.Matrix
        ``Nbands x Nbands`` Bloch Hamiltonian depending on the symbols in ``syms``.
    syms : sequence of three sympy symbols
        ``(kx, ky, kz)``, in this order.
    params : dict
        Must provide the integer entries ``"Ny"``, ``"Nz"`` and ``"Nbands"``.
        Site indexing uses these values only, never notebook-level globals.
    corner_pert : {None, "detach", "gapout"}, optional
        ``None`` leaves the wire unperturbed. ``"detach"`` scales the onsite
        block of the corner site ``(Ny-1, 0)`` by 0.2. ``"gapout"`` scales the
        onsite blocks of all four corner sites by 0.2 and then adds
        ``(1 -+ cos kx) * (s0 x sy)`` to them; that term is 4x4, so this option
        requires ``Nbands == 4``.

    Returns
    -------
    callable
        ``wire_hfunc(kx)`` created by ``sympy.lambdify`` with the numpy
        backend; it evaluates the ``Ny*Nz*Nbands``-dimensional matrix.

    Raises
    ------
    ValueError
        If ``corner_pert`` is not one of the supported options.
    """
    allowed_perts = (None, "detach", "gapout")
    if corner_pert not in allowed_perts:
        # A misspelt option would otherwise silently return the unperturbed wire.
        raise ValueError(f"corner_pert must be one of {allowed_perts}, got {corner_pert!r}")

    Ny, Nz, Nbands = params["Ny"], params["Nz"], params["Nbands"]
    kx_sym, ky_sym, kz_sym = syms

    def site(j, k):
        """Slice covering the Nbands rows/columns that belong to site (j, k)."""
        start = (j * Nz + k) * Nbands
        return slice(start, start + Nbands)

    def harmonic(k_sym, phase):
        """(1/2pi) * integral of h_symbolic * exp(i * phase * k) over one period."""
        coeff = sp.integrate(
            h_symbolic * sp.exp(sp.I * k_sym * phase), (k_sym, -sp.pi, sp.pi)
        ) / (2 * sp.pi)
        return coeff.rewrite(sp.cos).simplify()

    # "*_pos" multiplies exp(+i k) in h_symbolic, "*_neg" multiplies exp(-i k).
    Ly_nn_pos = harmonic(ky_sym, -1.0)
    Ly_nn_neg = harmonic(ky_sym, 1)
    Lz_nn_pos = harmonic(kz_sym, -1.0)
    Lz_nn_neg = harmonic(kz_sym, 1)

    # Whatever remains after removing the first harmonics is the onsite block.
    H_diag = (
        h_symbolic
        - (Ly_nn_pos * sp.exp(sp.I * ky_sym) + Ly_nn_neg * sp.exp(-sp.I * ky_sym))
        - (Lz_nn_pos * sp.exp(sp.I * kz_sym) + Lz_nn_neg * sp.exp(-sp.I * kz_sym))
    )
    H_diag = H_diag.rewrite(sp.cos).simplify()
    H_diag = sp.nsimplify(H_diag, tolerance = 1e-8)

    dim = Ny * Nz * Nbands
    h = sp.zeros(dim, dim)

    for j in range(Ny):
        for k in range(Nz):
            here = site(j, k)
            h[here, here] = H_diag

            # Column block = this site, row block = neighbouring site.
            if j > 0:
                h[site(j - 1, k), here] = Ly_nn_pos
            if j < Ny - 1:
                h[site(j + 1, k), here] = Ly_nn_neg
            if k > 0:
                h[site(j, k - 1), here] = Lz_nn_pos
            if k < Nz - 1:
                h[site(j, k + 1), here] = Lz_nn_neg

    if corner_pert == "detach":
        corner = site(Ny - 1, 0)
        h[corner, corner] *= 0.2

    if corner_pert == "gapout":
        s0 = sp.eye(2)
        sy = sp.Matrix([[0, -sp.I], [sp.I, 0]])
        mat = TensorProduct(s0, sy)
        pert1 = (1 - sp.cos(kx_sym)) * mat
        pert2 = (1 + sp.cos(kx_sym)) * mat

        corner_terms = [
            (site(0, 0), pert1),
            (site(Ny - 1, 0), pert2),
            (site(0, Nz - 1), pert2),
            (site(Ny - 1, Nz - 1), pert1),
        ]
        # Scale every corner first and add afterwards, so the result is well
        # defined even when corners coincide (Ny == 1 or Nz == 1).
        for corner, _ in corner_terms:
            h[corner, corner] *= 0.2
        for corner, pert in corner_terms:
            h[corner, corner] += pert

    wire_hfunc = sp.lambdify(kx_sym, h, modules="numpy")

    return wire_hfunc

In [None]:
# Crystal momenta; `ksymbols` bundles them in (x, y, z) order.
ksymbols = list(sp.symbols("k_x k_y k_z", real=True))
kx_sym, ky_sym, kz_sym = ksymbols
ktemp = sp.Symbol("ktemp", real=True)

# Model parameters that get numerical values later via `.subs`.
alpha = sp.Symbol("alpha", real=True, positive=True)
gamma_z, lambda_z = sp.symbols("gamma_z lambda_z", real=True)

In [None]:
# 2x2 identity and the three Pauli matrices.
s0 = sp.eye(2)
sx = sp.Matrix(2, 2, [0, 1, 1, 0])
sy = sp.Matrix(2, 2, [0, -sp.I, sp.I, 0])
sz = sp.diag(1, -1)

In [None]:
# Two-band Bloch Hamiltonian in the (kx, ky) plane, expanded in Pauli matrices.
hrtp = (
    sp.sin(2 * kx_sym) * sx
    + sp.sin(kx_sym) * sp.sin(ky_sym) * sy
    - (alpha + sp.cos(2 * kx_sym) + sp.cos(ky_sym)) * sz
)

In [None]:
# Four-band Hamiltonian: sz (x) hrtp plus the two kz-dependent terms.
# It is built directly from exact symbolic terms; seeding the sum with a
# NumPy zeros array would leave float 0.0 entries in the symbolic matrix.
H_layered = (
    TensorProduct(sz, hrtp)
    + TensorProduct(sy, s0) * lambda_z * sp.sin(kz_sym)
    + TensorProduct(sx, s0) * (gamma_z + lambda_z * sp.cos(kz_sym))
)

In [None]:
# Display the symbolic 4x4 Bloch Hamiltonian as the cell output.
H_layered

In [None]:
# Insert numerical parameter values; only the momenta remain symbolic.
H_hodti_fixalpha = H_layered.subs(
    {alpha: 1.0, gamma_z: 0.5, lambda_z: 1.0}
)

In [None]:
# Numerical (complex) 4x4 array of the operator s0 (x) sz.
Mz = np.array(TensorProduct(s0, sz)).astype(np.complex128)

In [None]:
# Grid and lattice sizes: Nx points along kx, Ny x Nz sites in the wire
# cross-section, Nbands orbitals per site.
Nx, Ny, Nz = 81, 12, 12
Nbands = 4
Nocc = 2  # presumably the number of occupied bands per site -- TODO confirm

# The same values bundled for functions that take a `params` dict.
params = {
    "Nx": Nx,
    "Ny": Ny,
    "Nz": Nz,
    "Nbands": Nbands,
    "Nocc": Nocc,
}

In [None]:
# Nx equally spaced kx values in [0, 2*pi); the endpoint 2*pi is excluded.
Kxs_wire = np.linspace(0.0, 2.0 * np.pi, num=Nx, endpoint=False)

In [None]:
# Wire Hamiltonian without any corner perturbation.
wire_hfunc = get_wire_hamiltonian(
    H_hodti_fixalpha,
    ksymbols,
    params,
    corner_pert=None,
)

In [None]:
# Diagonalise the wire Hamiltonian at every kx of the grid.
# wire_eigenvalues[i, n] is the n-th eigenvalue at Kxs_wire[i];
# wire_eigenstates[i, :, n] is the matching eigenvector.
wire_eigenvalues = np.zeros((Nx, Ny*Nz*Nbands))
wire_eigenstates = np.zeros((Nx, Ny*Nz*Nbands, Ny*Nz*Nbands), dtype=np.complex128)

for i, kx in enumerate(Kxs_wire):
    print("kx:", i, kx, end='\r')
    Ham = wire_hfunc(kx).astype(np.complex128)
    if not np.allclose(Ham, Ham.conj().T):
        # Fail loudly: stopping quietly would leave zero-filled rows that the
        # following cells would treat as real data.
        raise ValueError(f"Hamiltonian is not hermitian at kx index {i} (kx = {kx})")

    # eigh already returns the eigenvalues in ascending order with the
    # eigenvectors as matching columns, so no extra sorting is needed.
    evals, evecs = np.linalg.eigh(Ham)
    wire_eigenvalues[i, :] = evals
    wire_eigenstates[i, :, :] = evecs

In [None]:
# Band tracking: permute the states at kx point i+1 so that the total overlap
# with the (already ordered) states at point i is maximal.
for i in range(Nx - 1):
    print("kx:", i, end='\r')
    prev_states = wire_eigenstates[i]
    next_vals = wire_eigenvalues[i + 1]
    next_states = wire_eigenstates[i + 1]

    # |<prev_m | next_n>| for every pair of states.
    overlap = np.abs(prev_states.conj().T @ next_states)
    # Minimising the negated overlaps maximises their sum.
    _, ind = linear_sum_assignment(-overlap)

    wire_eigenvalues[i + 1] = next_vals[ind]
    wire_eigenstates[i + 1] = next_states[:, ind]

In [None]:
# Wire spectrum versus kx.
fig = plt.figure(figsize=(4, 3))
# Passing the 2D array makes plot() draw one line per column, i.e. per band.
plt.plot(Kxs_wire, wire_eigenvalues, "-k", alpha=0.15)

plt.xlabel(r'$k_x$', fontsize=16, labelpad=-1)
plt.ylabel(r'$E$', fontsize=16, labelpad=-3)
plt.xticks(
    [0, np.pi/2, np.pi, 3*np.pi/2, 2*np.pi],
    [r"$0$", r"$\pi/2$", r"$\pi$", r"$3\pi/2$", r"$2\pi$"],
    fontsize=14,
)
plt.yticks(
    [-3, -2, -1, 0, 1, 2, 3],
    ["-3", "-2", "-1", "0", "1", "2", "3"],
    fontsize=14,
)
plt.show()

In [None]:
# Parameter values alpha = 1, gamma_z = 0.5, lambda_z = 1.0 inserted into H_layered.
H_hodti_fixalpha = H_layered.subs(
    {alpha: 1, gamma_z: 0.5, lambda_z: 1.0}
)

In [None]:
# Wire Hamiltonian with the "detach" corner perturbation switched on.
wire_hfunc = get_wire_hamiltonian(
    H_hodti_fixalpha,
    ksymbols,
    params,
    corner_pert="detach",
)

In [None]:
# Diagonalise the current wire Hamiltonian at every kx of the grid.
# wire_eigenvalues[i, n] is the n-th eigenvalue at Kxs_wire[i];
# wire_eigenstates[i, :, n] is the matching eigenvector.
wire_eigenvalues = np.zeros((Nx, Ny*Nz*Nbands))
wire_eigenstates = np.zeros((Nx, Ny*Nz*Nbands, Ny*Nz*Nbands), dtype=np.complex128)

for i, kx in enumerate(Kxs_wire):
    print("kx:", i, kx, end='\r')
    Ham = wire_hfunc(kx).astype(np.complex128)
    if not np.allclose(Ham, Ham.conj().T):
        # Fail loudly: stopping quietly would leave zero-filled rows that the
        # following cells would treat as real data.
        raise ValueError(f"Hamiltonian is not hermitian at kx index {i} (kx = {kx})")

    # eigh already returns the eigenvalues in ascending order with the
    # eigenvectors as matching columns, so no extra sorting is needed.
    evals, evecs = np.linalg.eigh(Ham)
    wire_eigenvalues[i, :] = evals
    wire_eigenstates[i, :, :] = evecs

In [None]:
# Band tracking: permute the states at kx point i+1 so that the total overlap
# with the (already ordered) states at point i is maximal.
for i in range(Nx - 1):
    print("kx:", i, end='\r')
    prev_states = wire_eigenstates[i]
    next_vals = wire_eigenvalues[i + 1]
    next_states = wire_eigenstates[i + 1]

    # |<prev_m | next_n>| for every pair of states.
    overlap = np.abs(prev_states.conj().T @ next_states)
    # Minimising the negated overlaps maximises their sum.
    _, ind = linear_sum_assignment(-overlap)

    wire_eigenvalues[i + 1] = next_vals[ind]
    wire_eigenstates[i + 1] = next_states[:, ind]

In [None]:
# Wire spectrum versus kx with the tracked band Ny*Nz*2-1 drawn in red.
fig = plt.figure(figsize=(4, 3))
# Passing the 2D array makes plot() draw one line per column, i.e. per band.
plt.plot(Kxs_wire, wire_eigenvalues, "-k", alpha=0.15)

plt.plot(Kxs_wire, wire_eigenvalues[:, Ny*Nz*2-1], "-", alpha=1, color="red")

plt.xlabel(r'$k_x$', fontsize=16, labelpad=-5)
plt.ylabel(r'$E$', fontsize=16, labelpad=-10)
plt.xticks(
    [0, np.pi/2, np.pi, 3*np.pi/2, 2*np.pi],
    [r"$0$", r"$\pi/2$", r"$\pi$", r"$3\pi/2$", r"$2\pi$"],
    fontsize=10,
)
plt.show()

In [None]:
# Site-local operator s0 (x) sz as a complex array. The explicit dtype keeps
# SymPy number objects out of the NumPy array.
Mz = np.array(TensorProduct(s0, sz), dtype=np.complex128)

# The same operator acting on every site of the Ny x Nz cross-section.
# Each site owns a contiguous block of rows/columns, so the full operator is
# block diagonal with one copy of Mz per site.
Mz_wire = np.kron(np.eye(Ny * Nz), Mz)

In [None]:
# Expectation value <v| Mz_wire |v> of the tracked band Ny*Nz*2-1 at every kx.
band_states = wire_eigenstates[:, :, Ny*Nz*2-1]
Mvals = np.array(
    [(state.conj().T @ Mz_wire @ state).real for state in band_states]
)

fig = plt.figure(figsize=(3, 2))
plt.plot(Kxs_wire, Mvals)
plt.show()

In [None]:
fig, ax = plt.subplots(figsize=(5, 3))

# Background: every band in faint black (one line per column of the 2D array).
ax.plot(Kxs_wire, wire_eigenvalues, "-k", alpha=0.15)

# Highlighted band, sampled at every second kx point and coloured by Mvals.
x = Kxs_wire[::2]
y = wire_eigenvalues[::2, Ny*Nz*2-1]
color_values = Mvals[::2]

# Consecutive points are paired into segments of shape (n_segments, 2, 2).
points = np.column_stack([x, y])[:, np.newaxis, :]
segments = np.hstack([points[:-1], points[1:]])

norm = Normalize(vmin=-1, vmax=1)
cmap = plt.get_cmap('viridis')
lc = LineCollection(segments, cmap=cmap, norm=norm, linewidths=2, alpha=1)
lc.set_array(color_values)
ax.add_collection(lc)

ax.set_xlabel(r'$k_x$', fontsize=16, labelpad=-1)
ax.set_ylabel(r'$E$', fontsize=16, labelpad=-3)
plt.xticks(
    [0, np.pi/2, np.pi, 3*np.pi/2, 2*np.pi],
    [r"$0$", r"$\pi/2$", r"$\pi$", r"$3\pi/2$", r"$2\pi$"],
    fontsize=14,
)
plt.yticks(
    [-3, -2, -1, 0, 1, 2, 3],
    ["-3", "-2", "-1", "0", "1", "2", "3"],
    fontsize=14,
)

# Colour bar for the expectation value, labelled only at the two extremes.
cbar = plt.colorbar(lc, ax=ax)
cbar.set_ticks([-1, 1])
cbar.set_ticklabels(['-1', '1'])
cbar.ax.tick_params(labelsize=14)
cbar.set_label(r'$\left< M_x \right> $', rotation=90, labelpad=-12, fontsize=16)
plt.show()