In [54]:
import warnings
from itertools import product
from typing import Tuple

import numpy as np

# ----------------------------------------------------------------------
# 1. Spin configurations for a 3-site cell (global, read-only)
# ----------------------------------------------------------------------
# Every configuration of three ±1 spins, in itertools.product order.
ALL_SPINS = np.array([cfg for cfg in product((-1, 1), repeat=3)], dtype=np.float64)  # (8, 3)
# Majority rule: three ±1 spins sum to an odd number (never 0), so every
# configuration lands in exactly one of the two sets below.
PLUS_SPINS = ALL_SPINS[ALL_SPINS.sum(axis=1) > 0]    # (4, 3): cell spin +1
MINUS_SPINS = ALL_SPINS[ALL_SPINS.sum(axis=1) < 0]   # (4, 3): cell spin -1


def intracell_energy(spins: np.ndarray, J2: float, J4: float) -> np.ndarray:
    """Per-configuration energy term inside one 3-site cell.

    For each row (s0, s1, s2) of ``spins`` returns
    J2 * (s0*s1 + s1*s2) + J4 * s0*s2, i.e. J2 weights the two adjacent
    pairs of the cell and J4 weights the pair formed by its end sites.
    Only the first three columns of ``spins`` are read.
    """
    first, middle, last = (spins[:, col] for col in range(3))
    adjacent_pairs = first * middle + middle * last
    end_pair = first * last
    return J2 * adjacent_pairs + J4 * end_pair


def log_R(
    spinsL: np.ndarray, EL: np.ndarray,
    spinsR: np.ndarray, ER: np.ndarray,
    dists: np.ndarray, Jarr: np.ndarray, H: float = 0.0
) -> float:
    """
    log[ Σ_{σL,σR} exp(term(σL, σR)) ] for fixed cell spins s'L, s'R.

    For left configuration i and right configuration j the term is

        EL[i] + ER[j]
        + Σ_k Jarr[d_k] * spinsL[i, iL_k] * spinsR[j, iR_k]
        + H * (Σ spinsL[i] + Σ spinsR[j])

    Terms enter the exponential with a + sign, so the inputs are presumably
    the reduced Hamiltonian -βH (NOTE(review): confirm sign convention).
    The sum is evaluated with log-sum-exp for numerical stability.

    Parameters
    ----------
    spinsL, spinsR : np.ndarray
        (nL, m) / (nR, m) site-spin configurations of the left / right cell.
    EL, ER : np.ndarray
        (nL,) / (nR,) intra-cell terms matching spinsL / spinsR row by row.
    dists : np.ndarray
        Integer rows (iL, iR, d): site iL of the left cell couples to site iR
        of the right cell with Jarr[d]. Rows with d >= len(Jarr) are ignored;
        an empty array means no inter-cell coupling.
    Jarr : np.ndarray
        Couplings indexed by distance.
    H : float
        Field applied to every site of both cells.

    Returns
    -------
    float
        The log of the restricted sum; -inf if either spin set is empty.
        A +inf or NaN largest term is propagated as-is.
    """
    spinsL = np.asarray(spinsL, dtype=np.float64)
    spinsR = np.asarray(spinsR, dtype=np.float64)
    EL = np.asarray(EL, dtype=np.float64)
    ER = np.asarray(ER, dtype=np.float64)
    Jarr = np.asarray(Jarr)

    if len(spinsL) == 0 or len(spinsR) == 0:
        return -np.inf

    # Normalise to (K, 3) so an empty table (shape (0,)) is handled too,
    # then drop couplings whose distance exceeds the available Jarr range.
    table = np.asarray(dists, dtype=int).reshape(-1, 3)
    table = table[table[:, 2] < len(Jarr)]
    iL, iR, d = table.T

    # Eint[i, j] = Σ_k Jarr[d_k] * spinsL[i, iL_k] * spinsR[j, iR_k]
    Eint = (spinsL[:, iL] * Jarr[d]) @ spinsR[:, iR].T

    magL = spinsL.sum(axis=1)
    magR = spinsR.sum(axis=1)
    terms = EL[:, None] + ER[None, :] + Eint + H * (magL[:, None] + magR[None, :])

    mx = np.max(terms)
    if not np.isfinite(mx):
        # terms - mx would be NaN; +inf / -inf / NaN is the honest result.
        return float(mx)
    return float(np.log(np.sum(np.exp(terms - mx))) + mx)


def J_prime(rp: int, Jarr: np.ndarray, max_dist: int) -> float:
    """
    Renormalized coupling J'(r') between two cells separated by index rp.

    Geometry: left-cell sites sit at positions 1, 3, 5 and right-cell sites
    at 3*rp+1, 3*rp+3, 3*rp+5; only inter-cell pairs at distance
    <= max_dist contribute. Intra-cell terms use Jarr[2] and Jarr[4]
    (0.0 when Jarr is too short).

    The result is 0.5 * (log R(+,+) - log R(+,-)) at zero field, which relies
    on the spin-flip symmetry R(+,+) = R(-,-), R(+,-) = R(-,+).
    Returns NaN if either log-sum is infinite.
    """
    offset = 3 * rp + 1
    left_sites = (1, 3, 5)
    right_sites = (offset, offset + 2, offset + 4)

    # (left site index, right site index, distance) for every close-enough pair
    pair_table = np.array(
        [
            (a, b, abs(right_sites[b] - left_sites[a]))
            for a in range(3)
            for b in range(3)
            if abs(right_sites[b] - left_sites[a]) <= max_dist
        ],
        dtype=int,
    )

    n_couplings = len(Jarr)
    J2 = Jarr[2] if n_couplings > 2 else 0.0
    J4 = Jarr[4] if n_couplings > 4 else 0.0

    energy_plus = intracell_energy(PLUS_SPINS, J2, J4)
    energy_minus = intracell_energy(MINUS_SPINS, J2, J4)

    log_same = log_R(PLUS_SPINS, energy_plus, PLUS_SPINS, energy_plus,
                     pair_table, Jarr, H=0.0)
    log_opposite = log_R(PLUS_SPINS, energy_plus, MINUS_SPINS, energy_minus,
                         pair_table, Jarr, H=0.0)

    diverged = np.isinf(log_same) or np.isinf(log_opposite)
    return np.nan if diverged else 0.5 * (log_same - log_opposite)


def H_prime(Jarr: np.ndarray, max_dist: int, H: float) -> float:
    """
    Renormalized field H' for nearest-neighbour cells (r' = 1) at field H.

    Left-cell sites sit at 1, 3, 5 and right-cell sites at 4, 6, 8; only
    inter-cell pairs at distance <= max_dist contribute. Returns
    0.25 * (log R(+,+) - log R(-,-)), presumably from a form
    R(s, s') ∝ exp(J' s s' + H' (s + s')) — NOTE(review): confirm.
    Returns NaN if either log-sum is infinite.
    """
    rp = 1
    left_pos = (1, 3, 5)
    right_pos = tuple(3 * rp + 1 + 2 * k for k in range(3))

    table = []
    for a, x_left in enumerate(left_pos):
        for b, x_right in enumerate(right_pos):
            sep = abs(x_right - x_left)
            if sep <= max_dist:
                table.append((a, b, sep))
    table = np.array(table, dtype=int)

    n_couplings = len(Jarr)
    J2 = Jarr[2] if n_couplings > 2 else 0.0
    J4 = Jarr[4] if n_couplings > 4 else 0.0

    energy_up = intracell_energy(PLUS_SPINS, J2, J4)
    energy_down = intracell_energy(MINUS_SPINS, J2, J4)

    log_up = log_R(PLUS_SPINS, energy_up, PLUS_SPINS, energy_up, table, Jarr, H)
    log_down = log_R(MINUS_SPINS, energy_down, MINUS_SPINS, energy_down, table, Jarr, H)

    if np.isinf(log_up) or np.isinf(log_down):
        return np.nan
    return 0.25 * (log_up - log_down)


def dH_dH_at_point(Jarr: np.ndarray, max_dist: int, eps: float = 1e-8) -> float:
    """
    Magnetic scaling factor ∂H'/∂H at H = 0, by central finite difference.

    Evaluates (H'(+eps) - H'(-eps)) / (2 * eps). A central difference is
    second-order accurate in eps without relying on H' being odd in H, and
    it matches the central scheme used for the coupling Jacobian.

    Parameters
    ----------
    Jarr : np.ndarray
        Couplings indexed by distance (passed through to H_prime).
    max_dist : int
        Interaction cut-off (passed through to H_prime).
    eps : float
        Field step; must be > 0.

    Returns
    -------
    float
        The derivative estimate, or NaN if H_prime returned NaN.

    Raises
    ------
    ValueError
        If eps is not strictly positive (a zero step divides by zero).
    """
    if not eps > 0:
        raise ValueError("eps must be > 0")
    H_plus = H_prime(Jarr, max_dist, eps)
    H_minus = H_prime(Jarr, max_dist, -eps)
    if np.isnan(H_plus) or np.isnan(H_minus):
        return np.nan
    return (H_plus - H_minus) / (2 * eps)


# ----------------------------------------------------------------------
# 2. MAIN: Build recursion matrix at user-specified (a, J0)
# ----------------------------------------------------------------------
def build_recursion_matrix_at_point(
    a: float,
    J0: float,
    max_dist: int,
    eps: float = 1e-8
) -> Tuple[np.ndarray, np.ndarray, float]:
    """
    Build the RG recursion matrix M_ij = ∂J'_i / ∂J_j at given (a, J0).

    Parameters
    ----------
    a : float
        Power-law exponent (e.g. 1.5). Integer values are accepted.
    J0 : float
        Overall coupling strength: J(r) = J0 / r^a
    max_dist : int
        Truncate interactions beyond distance max_dist
    eps : float
        Relative finite-difference step for the Jacobian; must be > 0.

    Returns
    -------
    J_vec : np.ndarray
        Current interaction vector J[1], J[2], ..., J[max_dist] (float64)
    recursion_matrix : np.ndarray
        (max_dist × max_dist) matrix M_ij, by central differences.
        NaN entries appear wherever J_prime returned NaN.
    dHdH : float
        Magnetic scaling: ∂H'/∂H at H=0

    Raises
    ------
    ValueError
        If max_dist < 1, a <= 0, or eps is not strictly positive.
    """
    if max_dist < 1:
        raise ValueError("max_dist must be >= 1")
    if a <= 0:
        raise ValueError("a must be > 0")
    if not eps > 0:
        raise ValueError("eps must be > 0")

    # Float distances: NumPy refuses an integer array raised to a negative
    # integer power, which an integer exponent such as a=2 would trigger.
    r = np.arange(1, max_dist + 1, dtype=np.float64)
    J_vec = J0 * r ** (-a)                  # shape: (max_dist,)
    Jarr = np.zeros(max_dist + 1)
    Jarr[1:] = J_vec                        # Jarr[0] unused

    n = max_dist
    separations = range(1, n + 1)
    Jac = np.zeros((n, n))                  # recursion matrix

    # Central finite difference with respect to each input coupling J_{j+1}
    for j in range(n):
        # Step scales with |J| for large couplings but never drops below eps.
        dJ = eps * max(1.0, abs(Jarr[j + 1]))

        J_plus = Jarr.copy()
        J_minus = Jarr.copy()
        J_plus[j + 1] += dJ
        J_minus[j + 1] -= dJ

        Jp_plus = np.array([J_prime(rp, J_plus, max_dist) for rp in separations])
        Jp_minus = np.array([J_prime(rp, J_minus, max_dist) for rp in separations])

        Jac[:, j] = (Jp_plus - Jp_minus) / (2 * dJ)

    # Magnetic field scaling
    dHdH = dH_dH_at_point(Jarr, max_dist, eps=eps)

    return J_vec, Jac, dHdH

def invert_recursion_matrix(M: np.ndarray, rcond: float = 1e-12) -> np.ndarray:
    """
    Safely invert a square recursion matrix M.

    ``np.linalg.inv`` only raises for exactly singular input; an
    ill-conditioned matrix would silently give huge, meaningless entries.
    M is therefore treated as numerically singular when its smallest
    singular value is <= rcond * its largest. That is the same cut-off
    ``np.linalg.pinv`` applies. In that case the Moore-Penrose
    pseudo-inverse is returned and a RuntimeWarning is issued.

    Parameters
    ----------
    M : np.ndarray
        Square (n × n) recursion matrix from RG transformation.
    rcond : float
        Relative cut-off for small singular values, used both for the
        singularity test and for the pseudo-inverse.

    Returns
    -------
    np.ndarray
        Inverse (or Moore-Penrose pseudo-inverse) of M.

    Raises
    ------
    ValueError
        If M is not a square 2-D array or contains NaN/inf entries.
    """
    M = np.asarray(M)
    if M.ndim != 2 or M.shape[0] != M.shape[1]:
        raise ValueError("Recursion matrix must be square for inversion.")
    if M.shape[0] == 0:
        return np.zeros((0, 0))
    if not np.all(np.isfinite(M)):
        # SVD / inversion of non-finite input fails or yields all-NaN output.
        raise ValueError("Recursion matrix contains NaN or inf entries.")

    singular_values = np.linalg.svd(M, compute_uv=False)  # descending order
    if singular_values[-1] > rcond * singular_values[0]:
        try:
            return np.linalg.inv(M)
        except np.linalg.LinAlgError:
            pass  # fall through to the pseudo-inverse

    warnings.warn(
        "Recursion matrix is numerically singular. Using pseudo-inverse.",
        RuntimeWarning,
        stacklevel=2,
    )
    return np.linalg.pinv(M, rcond=rcond)

def propagate_perturbation_backwards(recursion_matrix: np.ndarray,
                                     delta_J_prime: np.ndarray,
                                     rcond: float = 1e-12) -> np.ndarray:
    """
    Map a target change of the renormalized couplings back to the bare ones.

    Computes δJ = M⁻¹ · δJ', where M⁻¹ comes from
    ``invert_recursion_matrix``. That function falls back to a
    pseudo-inverse when M is singular.

    Parameters
    ----------
    recursion_matrix : np.ndarray
        (n × n) matrix M_ij = ∂J'_i / ∂J_j
    delta_J_prime : np.ndarray
        (n,) vector of desired changes in renormalized couplings
    rcond : float
        Tolerance forwarded to invert_recursion_matrix

    Returns
    -------
    np.ndarray
        Required change δJ in original couplings
    """
    inverse = invert_recursion_matrix(recursion_matrix, rcond=rcond)
    delta_J = np.matmul(inverse, delta_J_prime)
    return delta_J