In [None]:
import sympy as sp
import numpy as np
from sympy.physics.quantum import TensorProduct
import matplotlib.pyplot as plt
from matplotlib import cm
import scipy as scp

In [None]:
def make_values_continuous(nnu_vals):
    """Remove integer jumps from phases (in units of 2*pi) along the first axis.

    Phases obtained as ``np.angle(...) / (2*pi)`` are only defined modulo 1, so a
    smooth curve that crosses the branch cut shows up as a jump of about +-1.
    Each entry is shifted by the integer that brings it closest to the previous
    (already corrected) entry along axis 0.  Every other axis (e.g. the columns
    of a 2-D array) is treated as an independent curve.

    Parameters
    ----------
    nnu_vals : numpy.ndarray
        Phases in units of 2*pi, with the axis to unwrap first.  Any ndim >= 1.
        The array is modified in place.

    Returns
    -------
    numpy.ndarray
        The same array object, now continuous along axis 0.  A 0-d input is
        returned unchanged.
    """
    if np.ndim(nnu_vals) == 0:
        return nnu_vals

    for i in range(1, nnu_vals.shape[0]):
        diff = nnu_vals[i] - nnu_vals[i - 1]
        # Subtract the full number of windings rather than a single +-1, so the
        # curve stays continuous even after it has wound more than once (|diff|
        # can then exceed 1.5).  For |diff| < 1.5 this equals the old
        # "|diff| > 0.5 -> subtract sign(diff)" rule.
        # Setitem (instead of -=) so integer input arrays are also accepted.
        nnu_vals[i] = nnu_vals[i] - np.round(diff)

    return nnu_vals

In [None]:
# Symbolic crystal momenta, declared real so SymPy treats their conjugates as themselves;
# ksymbols keeps them in (k_x, k_y, k_z) order.
kx_sym, ky_sym, kz_sym = ksymbols = list(sp.symbols('k_x k_y k_z', real=True))
# Auxiliary real symbol (not used in the cells shown here).
ktemp = sp.symbols('ktemp', real=True)
# Model parameters: alpha enters the sz coefficient of hrtp; gamma_z and lambda_z
# enter the interlayer coupling of H_layered.
alpha = sp.symbols('alpha', real=True, positive=True)
gamma_z, lambda_z = sp.symbols('gamma_z lambda_z', real=True)

In [None]:
# 2x2 identity and Pauli matrices as exact SymPy matrices.
s0, sx, sy, sz = (
    sp.eye(2),
    sp.Matrix([[0, 1], [1, 0]]),
    sp.Matrix([[0, -sp.I], [sp.I, 0]]),
    sp.diag(1, -1),
)

In [None]:
# 2x2 in-plane Bloch Hamiltonian h(kx, ky) = d_x sx + d_y sy + d_z sz with
#   d_x = sin(2 kx),  d_y = sin(kx) sin(ky),  d_z = -(alpha + cos(2 kx) + cos(ky)).
hrtp = (
    sp.sin(2*kx_sym) * sx
    + sp.sin(kx_sym) * sp.sin(ky_sym) * sy
    + (-(alpha + sp.cos(2*kx_sym) + sp.cos(ky_sym))) * sz
)

In [None]:
# 4x4 layered Hamiltonian:
#   sz (x) hrtp                                -> +hrtp / -hrtp on the diagonal blocks,
#   lambda_z sin(kz) sy (x) s0                 -> kz-odd interlayer coupling,
#   (gamma_z + lambda_z cos(kz)) sx (x) s0     -> kz-even interlayer coupling.
# Start from an exact symbolic zero matrix: sp.Matrix(np.zeros(...)) would seed every
# entry with a floating-point Float(0.0), so identically-zero entries end up as 0.0
# instead of the exact 0.
H_layered = sp.zeros(4, 4)
H_layered += TensorProduct(sz, hrtp)
H_layered += TensorProduct(sy, s0) * lambda_z * sp.sin(kz_sym)
H_layered += TensorProduct(sx, s0) * (gamma_z + lambda_z * sp.cos(kz_sym))

In [None]:
# Display the symbolic 4x4 layered Hamiltonian built above (rendered as the cell output).
H_layered

In [None]:
# Number of k-points per direction of the Brillouin-zone grid.
Nx = Ny = Nz = 41
Nbands = 4   # dimension of H_layered
Nocc = 2     # number of lowest bands treated as occupied in the Wilson loops

# The same settings collected in a dict (keys in the order Nx, Ny, Nz, Nbands, Nocc).
params = dict(Nx=Nx, Ny=Ny, Nz=Nz, Nbands=Nbands, Nocc=Nocc)

In [None]:
# Uniform periodic k-grids on [0, 2*pi); endpoint excluded because k = 2*pi is the same
# point as k = 0, so index wrap-around (% N) closes the Wilson loops.
# The 1e-10 offset presumably keeps the grid off exact high-symmetry momenta - TODO confirm.
Kxs, Kys, Kzs = (
    np.linspace(0, 2 * np.pi, n_pts, endpoint=False) + 1e-10
    for n_pts in (Nx, Ny, Nz)
)

In [None]:
# Fix the model parameters (alpha = 1, gamma_z = 0.5, lambda_z = 1.0); only the momenta stay symbolic.
H_hodti_fixalpha = H_layered.subs([(alpha, 1), (gamma_z, 0.5), (lambda_z, 1.0)])

In [None]:
# Numerical Bloch Hamiltonian hfunc(kx, ky, kz) -> 4x4 NumPy array; arguments follow ksymbols order.
hfunc = sp.lambdify(ksymbols, H_hodti_fixalpha, modules="numpy")

In [None]:
# Diagonalise H on the full (Kxs, Kys, Kzs) grid; bands stored in ascending energy order,
# eigenvectors as columns: eigenvectors[i, j, k, :, n] is band n at that k-point.
eigenvalues = np.zeros((Nx, Ny, Nz, Nbands))
eigenvectors = np.zeros((Nx, Ny, Nz, Nbands, Nbands), dtype=np.complex128)

for i, j, k in np.ndindex(Nx, Ny, Nz):
    energies, states = np.linalg.eigh(hfunc(Kxs[i], Kys[j], Kzs[k]))

    # eigh already returns ascending eigenvalues; the explicit sort keeps the ordering
    # guarantee independent of that.
    order = np.argsort(energies)
    eigenvalues[i, j, k] = energies[order]
    eigenvectors[i, j, k] = states[:, order]

In [None]:
# Unitary link matrices of the occupied subspace between neighbouring kz points:
# F_k = closest unitary to the overlap M_mn = <u_m(kz_{k+1}) | u_n(kz_k)>,
# obtained from its SVD M = U S Vh as U @ Vh (the singular values are dropped).
links_occ_z = np.zeros((Nx, Ny, Nz, Nocc, Nocc), dtype=np.complex128)

for i, j, k in np.ndindex(Nx, Ny, Nz):
    occ_here = eigenvectors[i, j, k, :, :Nocc]
    occ_next = eigenvectors[i, j, (k + 1) % Nz, :, :Nocc]  # wraps kz around the BZ

    overlap_occ = occ_next.conj().T @ occ_here
    U_occ, _, Vh_occ = np.linalg.svd(overlap_occ)

    links_occ_z[i, j, k] = U_occ @ Vh_occ

In [None]:
# Wilson loop along kz based at each kz index k: the ordered product of all Nz links,
# starting with the link leaving k and ending with the link arriving back at k
#   W_k = F_{k-1} ... F_{k+1} F_k    (indices mod Nz).
W_occ_z = np.zeros((Nx, Ny, Nz, Nocc, Nocc), dtype=np.complex128)

for i, j, k in np.ndindex(Nx, Ny, Nz):
    loop = np.eye(Nocc, dtype=np.complex128)
    for step in range(Nz):
        loop = links_occ_z[i, j, (k + step) % Nz] @ loop
    W_occ_z[i, j, k] = loop

In [None]:
# Spectrum of the kz Wilson loops: phases nu in (-1/2, 1/2] (units of 2*pi), sorted
# ascending, with the matching eigenvectors as columns of nu_vecs_occ_z[i, j, k].
nu_vals_occ_z = np.zeros((Nx, Ny, Nz, Nocc))
nu_vecs_occ_z = np.zeros((Nx, Ny, Nz, Nocc, Nocc), dtype=np.complex128)

for i in range(Nx):
    for j in range(Ny):
        for k in range(Nz):
            # Complex Schur form W = Z T Z^H.  W is a product of unitary links, hence
            # unitary and normal, so T is diagonal up to round-off and column n of Z
            # is the eigenvector for T[n, n].
            T, Z = scp.linalg.schur(W_occ_z[i, j, k, :, :], output="complex")
            # Take eigenvalues from the diagonal of T: unlike eigvals(T), this
            # guarantees evals[n] pairs with Z[:, n] before both are permuted by idx.
            evals = np.diag(T)
            angles = np.angle(evals) / (2 * np.pi)

            idx = np.argsort(angles)
            nu_vals_occ_z[i, j, k, :] = angles[idx]
            nu_vecs_occ_z[i, j, k, :, :] = Z[:, idx]

In [None]:
# Surface plot of the occupied kz-Wilson-loop phases over the (kx, ky) plane for one
# kz base point.
fig, ax = plt.subplots(1,1, subplot_kw={"projection": "3d"}, figsize = (4,3))
# top > 1 pushes the axes past the figure edge; presumably a deliberate zoom on the
# 3D box - TODO confirm.
plt.subplots_adjust(bottom=0.1, right=1, top=2)

Xx = np.linspace(0, 2*np.pi, Nx, endpoint=False)
Yy = np.linspace(0, 2*np.pi, Ny, endpoint=False)
X, Y = np.meshgrid(Xx, Yy)
index = 1  # kz index of the Wilson-loop base point that is displayed

# Shared colour scale computed from the slice that is actually plotted (previously it
# was always taken from kz index 0, regardless of `index`).
nu_slice = nu_vals_occ_z[:, :, index, :]
vmin = nu_slice.min()
vmax = nu_slice.max()

# One surface per occupied Wilson-loop band.  meshgrid returns (Ny, Nx) arrays while the
# data are (Nx, Ny), hence the transposes.
for band in range(Nocc):
    surf = ax.plot_surface(X.T, Y.T, nu_slice[:, :, band], cmap=cm.viridis, vmin=vmin, vmax=vmax,
                           linewidth=0, antialiased=False)

ax.set_xlabel(r"$k_x$", fontsize = 16, labelpad = 7)
ax.set_ylabel(r"$k_y$", fontsize = 16, labelpad = 7)

ax.set_xticks([0, np.pi/2, np.pi, 3*np.pi/2, 2*np.pi], labels=["0", r"$\pi/2$", r"$\pi$", r"$3 \pi /2$", r"$2 \pi$"], fontsize=14)
ax.set_yticks([0, np.pi/2, np.pi, 3*np.pi/2, 2*np.pi], labels=["0", r"$\pi/2$", r"$\pi$", r"$3 \pi /2$", r"$2 \pi$"], fontsize=14)
# NOTE(review): fixed z ticks assume the phases stay within about +-0.1.
ax.set_zticks([-0.1, -0.05, 0, 0.05, 0.1])
ax.set_zticklabels(["-0.10", "-0.05", "0.00", "0.05", "0.10"], fontsize=14)

ax.view_init(elev=10., azim=46)
ax.set_title(r"$\nu^{occ}_z(k_x, k_y)$", fontsize = 16, y = 0.9)
plt.show()

In [None]:
E_ind = 0  # index of the lowest band in the occupied subspace

# Rotate the occupied Bloch states into the eigenbasis of the kz Wilson loop:
#   |w_m(k)> = sum_n |u_{E_ind+n}(k)> * nu_vecs_occ_z[..., n, m]
# as one batched matmul over the (Nx, Ny, Nz) grid.  Unlike the previous
# hand-written two-term sum, this holds for any Nocc.
occ_states = eigenvectors[..., E_ind:E_ind + Nocc]    # (Nx, Ny, Nz, Nbands, Nocc)
wannier_bands_z = occ_states @ nu_vecs_occ_z           # (Nx, Ny, Nz, Nbands, Nocc)

# Wilson-loop phases are sorted ascending, so column 0 is the lower Wannier sector
# and the last column the upper one (identical to column 1 for Nocc = 2).
wb_occ_z_lower = wannier_bands_z[..., 0]
wb_occ_z_upper = wannier_bands_z[..., -1]


In [None]:
# Operator Mz = s0 (x) sz converted from SymPy to a dense complex128 NumPy array.
Mz = np.asarray(TensorProduct(s0, sz), dtype=np.complex128)

In [None]:
# Expectation value <w|Mz|w> of each Wannier sector (0 = lower, 1 = upper) at the first
# kz point (index 0).  Mz is real-diagonal, so only the real part is kept.
mirror_vals_z = np.zeros((Nx, Ny, 2))
for i, j in np.ndindex(Nx, Ny):
    for sector, wb in enumerate((wb_occ_z_lower, wb_occ_z_upper)):
        w = wb[i, j, 0, :]
        mirror_vals_z[i, j, sector] = (w.conj() @ Mz @ w).real

In [None]:
# Map of <Mz> for the lower Wannier sector (mirror_vals_z[..., 0]) over the (kx, ky) plane.
# The colorbar label previously read M_x although the plotted quantity is the Mz expectation.
fig, ax = plt.subplots(figsize = (4,3))
# Data are indexed (kx, ky); transpose so kx runs along the horizontal axis.
im = ax.imshow(mirror_vals_z[:,:,0].T, origin = "lower", cmap = "viridis", vmin = -1, vmax = 1, extent = [0, 2*np.pi, 0, 2*np.pi])
cbar = fig.colorbar(im, ax=ax)
cbar.set_ticks([-1, 1])
cbar.set_ticklabels(['-1', '1'], fontsize = 12)
cbar.set_label(r'$\left< M_z \right> $', rotation=90, labelpad=-10, fontsize = 16)
ax.set_xlabel(r"$k_x$", fontsize = 16, labelpad = -2)
ax.set_ylabel(r"$k_y$", fontsize = 16, labelpad = -12)
ax.set_xticks([0, np.pi/2, np.pi, 3*np.pi/2, 2*np.pi], ["0", r"$\pi/2$" ,r"$\pi$", r"$3 \pi /2 $", r"$2 \pi$"], fontsize = 14)
ax.set_yticks([0, np.pi/2, np.pi, 3*np.pi/2, 2*np.pi], ["0", r"$\pi/2$" ,r"$\pi$", r"$3 \pi /2 $", r"$2 \pi$"], fontsize = 14)
plt.show()

In [None]:
nnuvals_occ_lower_zy = np.zeros((Nx, Nz))
nnuvals_occ_upper_zy = np.zeros((Nx, Nz))

# Nested Wilson loop along ky for each (single-state) Wannier sector: multiply the
# phase-normalised overlaps <w(ky_{j+1}) | w(ky_j)> once around the ky circle, for every
# (kx, kz), and store the resulting phase in units of 2*pi.
for wb, nnu_out in ((wb_occ_z_lower, nnuvals_occ_lower_zy),
                    (wb_occ_z_upper, nnuvals_occ_upper_zy)):
    for i, k in np.ndindex(Nx, Nz):
        loop_phase = 1.0
        for j in range(Ny):
            ov = wb[i, (j + 1) % Ny, k, :].conj() @ wb[i, j, k, :]
            loop_phase = ov * loop_phase / np.abs(ov)
        nnu_out[i, k] = np.angle(loop_phase) / (2 * np.pi)

In [None]:
# Unwrap the nested Wilson-loop phases along kx (axis 0); each kz column is treated as an
# independent curve and the arrays are modified in place.
nnuvals_occ_lower_zy = make_values_continuous(nnuvals_occ_lower_zy)
nnuvals_occ_upper_zy = make_values_continuous(nnuvals_occ_upper_zy)

In [None]:
# Plotting grid for kx (no 1e-10 offset).  Uses its own name so that the shifted Kxs grid
# used for diagonalisation is not silently replaced when cells are re-run.
kx_plot = np.linspace(0, 2*np.pi, Nx, endpoint = False)

fig, axs = plt.subplots(1,1, figsize = (4,3))
plt.subplots_adjust(bottom=0.1, right=1, wspace=0.5)

# Average the unwrapped nested Wilson-loop phases over the kz index: one value per kx.
pd_lower_zy = nnuvals_occ_lower_zy.mean(axis=1)
pd_upper_zy = nnuvals_occ_upper_zy.mean(axis=1)

axs.plot(kx_plot, pd_lower_zy, "-", label = r"$\nu^{\nu^{-}_z}_y$",  lw = 4)
axs.plot(kx_plot, pd_upper_zy, "-", label = r"$\nu^{\nu^{+}_z}_y$", lw = 4)

axs.set_xlabel(r"$k_x$", fontsize = 16, labelpad = -3)
axs.set_ylabel(r"$\nu^{\nu_z}_{y}$", fontsize = 16, labelpad = -5)
axs.set_xticks([0, np.pi/2, np.pi, 3*np.pi/2, 2*np.pi], ["0", r"$\pi/2$", r"$\pi$", r"$3 \pi/2$", r"$2\pi$"], fontsize = 14)
axs.set_yticks([-1, 0, 1], ["-1", "0", "1"], fontsize = 14)

axs.legend(fontsize = 14, loc = "upper right")
plt.show()