In [2]:
import sympy as sp

# =============================================================================
# Second-order Schrieffer–Wolff check (d-wave)
#
# Long-wavelength conventions used in the appendix:
# - dx,dy are commuting stand-ins for ∂x,∂y (no derivatives acting on A-fields)
# - We keep only what is needed for ΔH^(2) to O((A⊥)^2) and up to two derivatives
#   => we only match the O(A⊥) part of PVQ, since PVQ enters as (PVQ)(QVP).
# =============================================================================

# Short aliases used throughout: the imaginary unit, and the Kronecker product
# that builds 4×4 (sublattice η) ⊗ (spin σ) operators from 2×2 factors.
I, kron = sp.I, sp.kronecker_product

# --- Helpers ---------------------------------------------------------------

def pauli():
    """Return the 2×2 identity and the three Pauli matrices.

    Returns
    -------
    tuple of sympy.Matrix
        ``(s0, sx, sy, sz)`` in that order.
    """
    identity = sp.eye(2)
    sigma_x = sp.Matrix([[0, 1],
                         [1, 0]])
    sigma_y = sp.Matrix([[0, -sp.I],
                         [sp.I, 0]])
    sigma_z = sp.Matrix([[1, 0],
                         [0, -1]])
    return identity, sigma_x, sigma_y, sigma_z

def msimplify(M):
    """Return a new matrix with every entry of M simplified.

    Each entry is first combined over a common denominator with
    ``sp.together`` and then passed through ``sp.simplify``.
    """
    def _simplify_entry(entry):
        return sp.simplify(sp.together(entry))

    return M.applyfunc(_simplify_entry)

def assert_mat_equal(A, B, label=""):
    """Raise AssertionError unless A and B agree entrywise after simplification.

    The difference ``A - B`` is simplified with ``msimplify``. Any entry that
    does not reduce to a literal zero counts as a mismatch. The check is an
    explicit ``raise`` rather than an ``assert`` statement, so it still runs
    under ``python -O``.

    Parameters
    ----------
    A, B : sympy.Matrix
        Matrices of the same shape.
    label : str
        Prefix for the failure message.
    """
    residual = msimplify(A - B)
    if not all(entry == 0 for entry in residual):
        raise AssertionError(f"{label} FAILED.\nDifference =\n{residual}")

def select_total_degree(expr, vars_, degree):
    """
    Return the part of expr with total polynomial degree == degree
    in the variables vars_.

    Parameters
    ----------
    expr : sympy expression
        Must be polynomial in ``vars_``. Coefficients may be arbitrary
        expressions in other symbols, hence ``domain="EX"``.
    vars_ : iterable of sympy.Symbol
        The variables whose total degree is counted. It is materialised once,
        so a generator is accepted.
    degree : int
        Total degree to keep.

    Returns
    -------
    sympy expression
        The simplified sum of every monomial whose exponents in ``vars_``
        add up to ``degree``, or SymPy zero if there is none.
    """
    # Materialise once: the generators are used both to build the Poly and to
    # rebuild each monomial, and a one-shot iterator would be empty the 2nd time.
    gens = tuple(vars_)
    expr = sp.expand(expr)
    if not gens:
        # With no generators sp.Poly would infer them from expr's free symbols,
        # silently changing what "degree" means. With respect to an empty
        # variable list, the whole expression has degree 0.
        return sp.simplify(expr) if degree == 0 else sp.S.Zero
    poly = sp.Poly(expr, *gens, domain="EX")
    out = sp.Add(*[
        sp.Mul(coeff, *(v**p for v, p in zip(gens, mon)))
        for mon, coeff in poly.terms()
        if sum(mon) == degree
    ])
    return sp.simplify(out)

def select_degree_matrix(M, vars_, degree):
    """Apply ``select_total_degree`` to every entry of the SymPy matrix M.

    Parameters
    ----------
    M : sympy.Matrix
    vars_ : iterable of sympy.Symbol
        Materialised once up front. Otherwise a one-shot iterator would be
        exhausted by the first entry and later entries would receive an empty
        variable list.
    degree : int
        Total degree in ``vars_`` to keep.

    Returns
    -------
    sympy.Matrix
        Same shape as M, containing only the degree-``degree`` part of each
        entry.
    """
    gens = tuple(vars_)
    return M.applyfunc(lambda e: select_total_degree(e, gens, degree))

def tau_decompose_4x4(M, tau0, taux, tauy, tauz):
    """Project a 4×4 matrix M onto the embedded τ basis.

    Returns a dict with keys "a0", "ax", "ay", "az" (in that order). Each value
    is Tr(τ_i M)/2 for τ_i = tau0, taux, tauy, tauz, passed through
    ``sp.simplify``.

    The factor 1/2 reflects Tr(τ_i τ_j) = 2 δ_ij on the rank-2 P subspace.
    The expansion M = Σ a_i τ_i is only exact when M is supported on that
    subspace.
    """
    components = (("a0", tau0), ("ax", taux), ("ay", tauy), ("az", tauz))
    return {
        name: sp.simplify((tau * M).trace() / 2)
        for name, tau in components
    }

# --- Symbols ---------------------------------------------------------------

# Real parameters: C_x and J enter H0; m and Kz_d enter the quadratic perturbation.
Cx, J = sp.symbols("C_x J", real=True)
m, Kzd = sp.symbols("m Kz_d", real=True)

# Commuting stand-ins for ∂x, ∂y (long-wavelength convention: they never act
# on the gauge fields).
dx, dy = sp.symbols("d_x d_y", commutative=True)

# Longitudinal gauge components (they multiply sz in the covariant derivative).
Aparx, Apary = sp.symbols("Aparx Apary", commutative=True)

# Transverse "texture" gauge components (they multiply sx, sy). Aperp_vars is
# the variable list used when selecting a fixed order in A⊥.
Aperp_vars = list(sp.symbols(
    "Aperp_x1 Aperp_x2 Aperp_y1 Aperp_y2", commutative=True
))
Apx1, Apx2, Apy1, Apy2 = Aperp_vars

# Spin (σ) and sublattice (η) Pauli matrices share the same 2×2 representation.
s0, sx, sy, sz = pauli()
eta0, etax, etay, etaz = s0, sx, sy, sz
Id4 = sp.eye(4)

# --- Unperturbed 4×4 Hamiltonian and projector -----------------------------

# kron(etax, s0) and kron(etaz, sz) anticommute, so H0**2 = E0**2 * Id4.
# H0 therefore has eigenvalues ±E0. Since H0 is traceless, each is doubly
# degenerate, and P below projects onto the rank-2 -E0 eigenspace.
H0 = Cx*kron(etax, s0) - J*kron(etaz, sz)
E0 = sp.sqrt(Cx**2 + J**2)

P = (Id4 - H0/E0)/2
Q = Id4 - P

# Sanity checks on the projector pair (idempotent and mutually orthogonal).
assert_mat_equal(msimplify(P*P), P, "P^2 = P")
assert_mat_equal(msimplify(P*Q), sp.zeros(4), "P Q = 0")

# Normalised components of H0: sin_th = C_x/E0, cos_th = J/E0.
sin_th, cos_th = (sp.simplify(component/E0) for component in (Cx, J))


def _project_to_P(op):
    """Return P*op*P, simplified entrywise."""
    return msimplify(P*op*P)


# Each 4×4 operator below commutes with H0, so it maps the P subspace into
# itself. P*op*P is its restriction there; these act as τ0, τx, τy, τz.
tau0 = _project_to_P(kron(eta0, s0))
taux = _project_to_P(kron(etax, sx))
tauy = _project_to_P(kron(etax, sy))
tauz = _project_to_P(kron(eta0, sz))

# --- Full spin covariant derivatives D = D|| + T ----------------------------
# Each D_j is a 2×2 spin matrix: the derivative stand-in d_j times s0, plus
# i/2 times the gauge field contracted with Pauli matrices. The longitudinal
# part D|| uses sz; the transverse ("texture") part T uses sx and sy.

Dparx = dx*s0 + I*Aparx*sz/2
Dpary = dy*s0 + I*Apary*sz/2

Tx = I*(Apx1*sx + Apx2*sy)/2
Ty = I*(Apy1*sx + Apy2*sy)/2

Dx = Dparx + Tx
Dy = Dpary + Ty

# These are 2×2 matrix products, so operator order matters: D2 = Dx^2 + Dy^2
# and anti_xy is the symmetrised (Weyl-ordered) product.
D2 = sp.expand(Dx*Dx + Dy*Dy)
anti_xy = sp.expand(Dx*Dy + Dy*Dx)   # {Dx, Dy}

# ----------------------------------------------------------------------------
# 4×4 perturbation entering SW at this order:
#
# - isotropic:  k^2/(2m) -> -(D^2)/(2m)
# - d-wave:     (Kz/2!) g2(k) ηz with g2 = kx ky
#               kx ky -> (-iDx)(-iDy) = -Dx Dy
#               Weyl symm: W[DxDy] = 1/2 {Dx, Dy}
#               => term = (Kz/2) * ( - 1/2 {Dx,Dy} ) ηz = -(Kz/4) ηz {Dx,Dy}
#
# kron(eta, X) places the 2×2 spin operator X into the sublattice (η) structure.
# ----------------------------------------------------------------------------
V_quad = -(1/(2*m))*kron(eta0, D2) - (Kzd/4)*kron(etaz, anti_xy)

# --- Analytic flip sector V_f (as in the appendix) --------------------------
# To first order in A⊥:
#   D2      -> 2(dx Tx + dy Ty) = i B·σ
#   anti_xy -> 2(dx Ty + dy Tx) = i C·σ
# The A||–A⊥ cross terms cancel because sz anticommutes with sx and sy.
# Inserting these into V_quad gives V0f and Vzf below. Check 1 compares their
# P/Q-projected form with the full V_quad.

# B = A⊥_x ∂x + A⊥_y ∂y,   C = A⊥_x ∂y + A⊥_y ∂x
B1 = Apx1*dx + Apy1*dy
B2 = Apx2*dx + Apy2*dy
C1 = Apx1*dy + Apy1*dx
C2 = Apx2*dy + Apy2*dx

# Contractions with the transverse Pauli matrices (spin-flip operators).
Bdot_sigma = B1*sx + B2*sy
Cdot_sigma = C1*sx + C2*sy

# V0,f from isotropic k^2/(2m)
V0f = -(I/(2*m))*kron(eta0, Bdot_sigma)

# Vz,f from d-wave Kz term (note the 1/4 from (Kz/2!) + Weyl symm)
Vzf = -(I*Kzd/4)*kron(etaz, Cdot_sigma)

Vf = msimplify(V0f + Vzf)

# ----------------------------------------------------------------------------
# Check 1: flip extraction, restricted to first order in A⊥.
# Both off-diagonal blocks (P V Q and Q V P) of the full quadratic perturbation
# are reduced to their O(A⊥) part and compared with the projected analytic Vf.
#
# NOTE(review): only the O(A⊥) slice is compared.
# - kron(etaz, ·) does not commute with H0 when C_x != 0 (etaz anticommutes
#   with etax), so the Kz_d term may give PVQ nonzero O((A⊥)^0) and
#   O((A⊥)^2) pieces.
# - Their cross terms in (PVQ)(QVP) would be O((A⊥)^2).
# - Confirm the appendix accounts for them.
# ----------------------------------------------------------------------------
PVQ_full = msimplify(P*V_quad*Q)
PVQ_flip = msimplify(P*Vf*Q)  # Vf is already first order in A⊥
PVQ_full_OA = select_degree_matrix(PVQ_full, Aperp_vars, 1)
assert_mat_equal(PVQ_full_OA, PVQ_flip, "Flip extraction: (PVQ)|_{O(A⊥)}")

QVP_full = msimplify(Q*V_quad*P)
QVP_flip = msimplify(Q*Vf*P)
QVP_full_OA = select_degree_matrix(QVP_full, Aperp_vars, 1)
assert_mat_equal(QVP_full_OA, QVP_flip, "Flip extraction: (QVP)|_{O(A⊥)}")

# ----------------------------------------------------------------------------
# Second-order SW correction:
# ΔH^(2) = -(1/(2E0)) P Vf Q Vf P
# (P is the -E0 and Q the +E0 eigenspace of H0, so the energy denominator
#  E_P - E_Q equals -2E0.)
# ----------------------------------------------------------------------------
dH2 = msimplify(-(1/(2*E0)) * P*Vf*Q*Vf*P)

# Decompose into τ basis and compare to closed form (B32)
coeffs = tau_decompose_4x4(dH2, tau0, taux, tauy, tauz)

BdotB = sp.expand(B1**2 + B2**2)
CdotC = sp.expand(C1**2 + C2**2)
BdotC = sp.expand(B1*C1 + B2*C2)

a0_expected = sp.simplify(cos_th**2/(8*m**2*E0)*BdotB + Kzd**2/(32*E0)*CdotC)
az_expected = sp.simplify(Kzd*cos_th/(8*m*E0)*BdotC)

# Explicit raises instead of bare `assert`: assert statements are stripped
# under `python -O`, which would let the final "OK" messages print without
# any check having run. The labels and residuals also make failures readable.
for _label, _residual in (
    ("ΔH^(2): τx component vanishes", coeffs["ax"]),
    ("ΔH^(2): τy component vanishes", coeffs["ay"]),
    ("ΔH^(2): τ0 component matches (B32)", coeffs["a0"] - a0_expected),
    ("ΔH^(2): τz component matches (B32)", coeffs["az"] - az_expected),
):
    _residual = sp.simplify(_residual)
    if _residual != 0:
        raise AssertionError(f"{_label} FAILED.\nResidual = {_residual}")

# ----------------------------------------------------------------------------
# Dot product identities (B33)–(B35) in terms of g_jk = A⊥_j · A⊥_k
# ----------------------------------------------------------------------------
gxx = Apx1**2 + Apx2**2
gyy = Apy1**2 + Apy2**2
gxy = Apx1*Apy1 + Apx2*Apy2

# dx, dy commute, so (dx*dy + dy*dx) == 2*dx*dy. It is written symmetrised to
# mirror the operator form in the appendix.
BdotB_g = sp.expand(gxx*dx**2 + gyy*dy**2 + gxy*(dx*dy + dy*dx))
CdotC_g = sp.expand(gxx*dy**2 + gyy*dx**2 + gxy*(dx*dy + dy*dx))
BdotC_g = sp.expand(gxy*(dx**2 + dy**2) + sp.Rational(1,2)*(gxx + gyy)*(dx*dy + dy*dx))

# Explicit raises so the identities are still checked under `python -O`.
for _label, _residual in (
    ("(B33) B·B in terms of g_jk", BdotB - BdotB_g),
    ("(B34) C·C in terms of g_jk", CdotC - CdotC_g),
    ("(B35) B·C in terms of g_jk", BdotC - BdotC_g),
):
    _residual = sp.simplify(_residual)
    if _residual != 0:
        raise AssertionError(f"{_label} FAILED.\nResidual = {_residual}")

# Reached only if every check above passed (each one raises on failure).
for _msg in (
    "OK: (PVQ)|_{O(A⊥)} matches analytic Vf (d-wave).",
    "OK: ΔH^(2) matches Eq. (B32) with only τ0 and τz components.",
    "OK: dot identities (B33)–(B35) verified.",
):
    print(_msg)


OK: (PVQ)|_{O(A⊥)} matches analytic Vf (d-wave).
OK: ΔH^(2) matches Eq. (B32) with only τ0 and τz components.
OK: dot identities (B33)–(B35) verified.
