In [1]:
import numpy as np
from itertools import product
from numba import njit

# ============================================================
# Numba RNG: xorshift64*
# ============================================================

@njit(cache=True)
def _xorshift64star_next(state):
    """
    Advance a xorshift64* generator by one step.

    state: np.uint64 generator state. It must be nonzero, because 0 is a fixed
           point of the xor-shift update.
    Returns (new_state, output), where output = new_state * 2685821657736338717
    computed in uint64 (wrap-around) arithmetic.
    """
    s = state ^ (state >> np.uint64(12))
    s = s ^ (s << np.uint64(25))
    s = s ^ (s >> np.uint64(27))
    return s, s * np.uint64(2685821657736338717)

@njit(cache=True)
def _u01_from_uint64(u):
    """Map a uint64 to a float in [0, 1) using its top 53 bits (one per mantissa bit)."""
    top53 = (u >> np.uint64(11)) & np.uint64((1 << 53) - 1)
    scale = 1.0 / float(1 << 53)
    return top53 * scale

@njit(cache=True)
def _rand_idx(state, N):
    """
    Draw an integer index in [0, N) from the next xorshift64* output.

    The index is the raw output modulo N. The modulo bias is negligible when N
    is much smaller than 2**64.
    Returns (new_state, idx).
    """
    new_state, raw = _xorshift64star_next(state)
    idx = int(raw % np.uint64(N))
    return new_state, idx

@njit(cache=True)
def _rand_u01(state):
    """Draw a uniform float in [0, 1) from the next xorshift64* output. Returns (new_state, u)."""
    new_state, raw = _xorshift64star_next(state)
    u = _u01_from_uint64(raw)
    return new_state, u

# ============================================================
# Majority-rule configs for 3-spin cells
# ============================================================

_all_spins = np.array([cfg for cfg in product((-1, 1), repeat=3)], dtype=np.int8)
# Three +/-1 spins always sum to an odd number, so "majority +" is sum > 0 and
# "majority -" is sum < 0; ties cannot occur.
plus_configs = _all_spins[_all_spins.sum(axis=1) > 0]    # 4x3, majority +
minus_configs = _all_spins[_all_spins.sum(axis=1) < 0]   # 4x3, majority -

@njit(cache=True)
def logsumexp(values):
    """
    Numerically stable log(sum(exp(values))).

    The maximum m is subtracted before exponentiating, so large entries do not
    overflow. If m is infinite it is returned directly:
      - -inf when every entry is -inf (the sum of exp is 0);
      - +inf when any entry is +inf.
    Without this guard, ``values - m`` would contain inf - inf = NaN and the
    result would be NaN. Callers in this file pass 1-D float64 arrays.
    """
    m = np.max(values)
    if np.isinf(m):
        return m
    return m + np.log(np.sum(np.exp(values - m)))

# ============================================================
# Staggered block geometry (left block at sites 1, 3, 5; right block placed by right_pos_staggered)
# ============================================================

@njit(cache=True)
def right_pos_staggered(r):
    """
    Sites occupied by the right 3-spin block at block separation index r.

    Returns an int64 array [m, m + 2, m + 4], where, with k = (r - 1) // 2:
      m = 2 + 6k for odd r   (r=1 -> [2, 4, 6],  r=3 -> [8, 10, 12])
      m = 7 + 6k for even r  (r=2 -> [7, 9, 11], r=4 -> [13, 15, 17])
    For even r, (r - 1) // 2 equals (r - 2) // 2, so one k serves both cases.
    """
    k = (r - 1) // 2
    base = 2 if r % 2 == 1 else 7
    m = base + 6 * k
    return np.array([m, m + 2, m + 4], dtype=np.int64)

@njit(cache=True)
def r_max(D):
    """
    Largest block separation r whose widest bond still fits within distance D.

    The widest bond at separation r runs from left site 1 to the last
    right-block site, i.e. it has length right_pos_staggered(r)[2] - 1.
    Returns 0 when even r = 1 does not fit (that bond has length 5, so D < 5).
    """
    r = 0
    while right_pos_staggered(r + 1)[2] - 1 <= D:
        r += 1
    return r

# ============================================================
# Pool initialization: pools[d, n] = bond value at physical distance d
# ============================================================

@njit(cache=True)
def init_pools(J0, a, D, p, N, seed=12345):
    """
    Build the step-0 coupling pools.

    Row d (1 <= d <= D) holds N independent samples of the bond at physical
    distance d. Each sample has magnitude J0 / d**a and sign -1 with
    probability p (u < p for u uniform in [0, 1)), +1 otherwise.
    Row 0 stays zero.

    Returns:
      pools: float64 array (D+1, N), index 0 unused
      rng_state: uint64 xorshift64* state after all draws. Carry it into later
                 steps so they continue the same stream instead of reusing it.
    """
    pools = np.zeros((D + 1, N), dtype=np.float64)

    # XOR the seed with a fixed constant. A xorshift state must not be 0,
    # so fall back to another constant in that single case.
    state = np.uint64(seed) ^ np.uint64(0x9E3779B97F4A7C15)
    if state == np.uint64(0):
        state = np.uint64(0xD1B54A32D192ED03)

    for d in range(1, D + 1):
        magnitude = J0 / (d ** a)
        row = pools[d]
        for n in range(N):
            state, u = _rand_u01(state)
            row[n] = -magnitude if u < p else magnitude

    return pools, state

# ============================================================
# One quenched-disorder decimation sample for a given block separation r
# ============================================================

@njit(cache=True)
def _intracell_energy(s, J2, J4):
    """
    Bond energy sum(J * s_i * s_j) of one 3-spin staggered cell.

    The bonds (0,1) and (1,2) span distance 2 and both use coupling J2.
    The bond (0,2) spans distance 4 and uses J4.
    """
    nearest = s[0] * s[1] + s[1] * s[2]
    return J2 * nearest + J4 * (s[0] * s[2])

@njit(cache=True)
def _draw_bond_from_pools(pools, d, state):
    """Draw one coupling at physical distance d from pools (with replacement).

    Distances outside 1..D (D = pools.shape[0] - 1) are absent bonds: the
    function returns 0.0 without advancing the RNG. Returns (state, J).
    """
    if d < 1 or d > pools.shape[0] - 1:
        return state, 0.0
    state, idx = _rand_idx(state, pools.shape[1])
    return state, pools[d, idx]


@njit(cache=True)
def _cell_energy_bonds(s, J01, J12, J02):
    """Bond energy of a 3-spin cell with its own coupling on each of bonds (0,1), (1,2), (0,2)."""
    return J01 * (s[0] * s[1]) + J12 * (s[1] * s[2]) + J02 * (s[0] * s[2])


@njit(cache=True)
def _log_block_pair_weight(left_cfgs, right_cfgs, JL, JR, Jint):
    """log sum over (sL in left_cfgs, sR in right_cfgs) of exp(E_L(sL) + E_R(sR) + E_int(sL, sR)).

    JL, JR: length-3 intracell couplings for bonds (0,1), (1,2), (0,2).
    Jint:   (3, 3) intercell couplings; Jint[a, b] couples left spin a to right spin b.
    """
    nL = left_cfgs.shape[0]
    nR = right_cfgs.shape[0]

    # Intracell energies depend on one block's state only, so compute each once.
    EL = np.empty(nL, dtype=np.float64)
    for i in range(nL):
        EL[i] = _cell_energy_bonds(left_cfgs[i], JL[0], JL[1], JL[2])
    ER = np.empty(nR, dtype=np.float64)
    for j in range(nR):
        ER[j] = _cell_energy_bonds(right_cfgs[j], JR[0], JR[1], JR[2])

    totals = np.empty(nL * nR, dtype=np.float64)
    k = 0
    for i in range(nL):
        sL = left_cfgs[i]
        for j in range(nR):
            sR = right_cfgs[j]
            E_int = 0.0
            for a in range(3):
                for b in range(3):
                    E_int += Jint[a, b] * sL[a] * sR[b]
            totals[k] = EL[i] + ER[j] + E_int
            k += 1
    return logsumexp(totals)


@njit(cache=True)
def log_Rpp_Rpm_sample_from_pools(r, pools, state):
    """
    Draw one quenched-disorder sample of (log R_pp, log R_pm) at block separation r.

    Sampling:
      - All couplings are sampled with replacement from ``pools``.
      - Pool rows are indexed by physical distance d = 1..D, with
        D = pools.shape[0] - 1.
      - Bonds longer than D are treated as absent (J = 0).

    Geometry:
      - The left block sits on sites 1, 3, 5.
      - The right block sits on ``right_pos_staggered(r)``.

    Independent draws:
      - Every physical bond gets its own draw, including bonds that share a
        distance.
      - Each block has three intracell bonds: (0,1) and (1,2) at distance 2,
        and (0,2) at distance 4.
      - There are nine intercell bonds.

    Definitions:
      - R_pp is the sum of exp(total bond energy) over microstates with both
        blocks majority +.
      - R_pm is the same sum with the left block majority + and the right
        block majority -.

    Returns:
      (log_pp, log_pm, state), where ``state`` is the advanced RNG state.
    """
    # --- independent intracell couplings for left and right blocks ---
    JL = np.zeros(3, dtype=np.float64)
    JR = np.zeros(3, dtype=np.float64)
    for i in range(3):
        # bonds 0 and 1 are (0,1), (1,2) at distance 2; bond 2 is (0,2) at distance 4
        d_cell = 4 if i == 2 else 2
        state, J = _draw_bond_from_pools(pools, d_cell, state)
        JL[i] = J
        state, J = _draw_bond_from_pools(pools, d_cell, state)
        JR[i] = J

    # --- independent intercell couplings for each of the 9 left/right pairs ---
    left_pos = np.array([1, 3, 5], dtype=np.int64)
    right_pos = right_pos_staggered(r)
    Jint = np.zeros((3, 3), dtype=np.float64)
    for a in range(3):
        for b in range(3):
            state, J = _draw_bond_from_pools(pools, abs(right_pos[b] - left_pos[a]), state)
            Jint[a, b] = J

    # --- exact enumeration over majority-rule block microstates ---
    log_pp = _log_block_pair_weight(plus_configs, plus_configs, JL, JR, Jint)
    log_pm = _log_block_pair_weight(plus_configs, minus_configs, JL, JR, Jint)
    return log_pp, log_pm, state

# ============================================================
# RG step on pools (distributional RG / population dynamics)
# ============================================================

@njit(cache=True)
def rg_step_pools(pools, state):
    """
    One distributional (population-dynamics) RG step.

    Input:
      pools: (D+1, N) array; row d holds N coupling samples at distance d = 1..D
      state: uint64 rng state (advanced and returned)

    Output:
      pools_new: (Dnew+1, N) with Dnew = r_max(D). Each row r = 1..Dnew holds
                 N independent samples of 0.5 * (log R_pp - log R_pm) at block
                 separation r. When Dnew < 1 no separation fits; a (1, N)
                 zero array is returned and the RNG state is unchanged.
      state: updated rng state
    """
    n_samples = pools.shape[1]
    r_top = r_max(pools.shape[0] - 1)

    if r_top < 1:
        return np.zeros((1, n_samples), dtype=np.float64), state

    renormalized = np.zeros((r_top + 1, n_samples), dtype=np.float64)
    for r in range(1, r_top + 1):
        row = renormalized[r]
        for n in range(n_samples):
            log_pp, log_pm, state = log_Rpp_Rpm_sample_from_pools(r, pools, state)
            row[n] = 0.5 * (log_pp - log_pm)

    return renormalized, state

# ============================================================
# RG flow (multiple steps): one RNG stream is carried through all steps, so runs are reproducible from the seed and no random numbers are reused
# ============================================================

def rg_flow_pools(J0, a, D, p, N, n_steps, seed=12345, store=True):
    """
    Run ``n_steps`` pool-based RG steps, starting from ``init_pools``.

    J0, a, D, p, N and seed are forwarded to ``init_pools``. A single RNG state
    is threaded through initialization and every step.

    Returns:
      store=True:  list of n_steps + 1 pool arrays (step 0 first), each a copy
      store=False: only the final pool array
    """
    pools, state = init_pools(J0, a, D, p, N, seed)
    history = [pools.copy()] if store else None

    for _ in range(n_steps):
        pools, state = rg_step_pools(pools, state)
        if store:
            history.append(pools.copy())

    return history if store else pools

In [None]:

# ============================================================
# Example use
# ============================================================
J0    = 1.0     # coupling scale: |J(d)| = J0 / d**a
a     = 1.5     # power-law decay exponent
D     = 200     # largest physical distance in the step-0 pools
p     = 0.5     # probability that a step-0 bond is negative
N     = 2000    # pool size (samples per distance)
steps = 6       # number of RG steps
seed  = 12345   # RNG seed

flow = rg_flow_pools(J0, a, D, p, N, steps, seed=seed, store=True)
for k, pools_k in enumerate(flow):
    print(f"Step {k} pools shape:", pools_k.shape)
