In [2]:
import math
from typing import List, Tuple, Dict, Any

# --------- helpers: ages and virtual power fences on [0, t] ----------
def compute_ages(t: int, last_sync_at: List[int]) -> List[int]:
    """Return the age of every worker at time t.

    The age of worker i is ``t - last_sync_at[i]``. No validation is done
    here: a timestamp later than ``t`` simply yields a negative age, so
    callers are expected to guarantee ``t >= last_sync_at[i]``.
    """
    ages: List[int] = []
    for synced_at in last_sync_at:
        ages.append(t - synced_at)
    return ages

def virt_power_edges(t: int, k: int, p: float, g: int = 1) -> List[float]:
    """
    Return all m + 1 virtual bin edges on [0, t], where m = k + g.

    Edge j is x_j = t * (j/m)^p for j = 0..m, so x_0 = 0 and x_m = t.
    Only x_1..x_k act as enforced fences; the trailing g edges are ghost
    headroom on the right.

    Requires p > 0, g >= 1 and t >= 0 (checked with assert).
    """
    assert p > 0
    assert g >= 1
    assert t >= 0
    total_bins = k + g
    edges: List[float] = []
    for j in range(total_bins + 1):
        fraction = j / total_bins
        edges.append(t * fraction ** p)
    return edges

def virt_power_cutoffs(t: int, k: int, p: float, g: int = 1) -> List[float]:
    """Return the k enforced fences x_1..x_k, with x_j = t * (j/(k+g))^p.

    Entry j-1 of the result is the fence for rank j (the j-th youngest
    worker). Unlike virt_power_edges, no argument validation happens here.
    """
    total_bins = k + g
    ranks = range(1, k + 1)
    return [t * (rank / total_bins) ** p for rank in ranks]

# --------- main step: arbitrary-t controller (age 0 after sync) ----------
def step_power_ghosts_zero_age(
    t: int,
    last_sync_at: List[int],
    p: float = 2.0,
    g: int = 1,
    always_fresh: bool = True,
) -> Tuple[List[int], Dict[str, Any]]:
    """
    Enforce power-spaced fences with g ghost bins at *arbitrary* timestep t.
    Ages are computed as age[i] = t - last_sync_at[i]. A sync at time t sets age to 0.

    The j-th youngest worker (rank j, 1-based) must have age <= x_j, where
    x_j = t * (j/(k+g))^p. While any rank violates its fence, the worker at the
    lowest violating rank is synced and the ranking is recomputed.

    The input list is never modified; all updates are made on a copy, which is
    returned.

    Inputs
    ------
    t : current timestep (integer, must be >= 0 and >= every last_sync_at[i])
    last_sync_at : list length k; last_sync_at[i] is the timestep when worker i was last synced.
        May be empty, in which case nothing is synced.
    p : power parameter (>0). p=1 gives uniformly spaced fences.
    g : number of ghost bins on the right (>=1). Default 1.
    always_fresh : if True and no fence sync happened, sync the youngest worker
        unless it is already at age 0, so at least one worker is fresh.

    Returns
    -------
    new_last_sync_at : updated copy of last_sync_at (synced workers get timestamp = t)
    info : dict with telemetry:
        't', 'ages_before', 'ages_after', 'synced_indices',
        'fences' (x_1..x_k), 'edges' (x_0..x_{k+g}), 'sorted_order'

    Raises
    ------
    AssertionError : if a timestamp is not an int, if t is smaller than any
        timestamp, or if p <= 0, g < 1 or t < 0.
    """
    k = len(last_sync_at)
    assert all(isinstance(s, int) for s in last_sync_at), "last_sync_at must be integers"
    assert all(t >= s for s in last_sync_at), "t must be >= every last_sync_at[i]"
    assert p > 0 and g >= 1 and t >= 0

    # Work on a private copy so the caller's list is left untouched.
    new_last_sync_at = list(last_sync_at)

    ages = compute_ages(t, new_last_sync_at)
    ages_before = ages[:]

    fences = virt_power_cutoffs(t, k, p, g)   # x_1..x_k
    edges = virt_power_edges(t, k, p, g)      # x_0..x_{k+g}

    synced: List[int] = []

    # Enforce fences by syncing the *violating rank* (j-th smallest).
    # Fences are >= 0, so a violating age is always positive and becomes 0 when
    # synced; a zero age can never violate, hence at most k iterations.
    while True:
        order = sorted(range(k), key=lambda i: ages[i])   # ascending ages
        viol_rank = next(
            (rank for rank, i in enumerate(order) if ages[i] > fences[rank]),
            None,
        )
        if viol_rank is None:
            break

        i_sync = order[viol_rank]
        # Sync at time t -> timestamp becomes t; age becomes 0
        new_last_sync_at[i_sync] = t
        ages[i_sync] = 0
        synced.append(i_sync)

    # Guarantee at least one fresh worker (only if no fence sync occurred).
    # The k > 0 guard avoids calling min() on an empty range.
    if always_fresh and not synced and k > 0:
        i_youngest = min(range(k), key=lambda i: ages[i])
        if ages[i_youngest] > 0:  # only if not already 0-age
            new_last_sync_at[i_youngest] = t
            ages[i_youngest] = 0
            synced.append(i_youngest)

    info: Dict[str, Any] = {
        "t": t,
        "ages_before": ages_before,
        "ages_after": ages[:],
        "synced_indices": synced,
        "fences": fences,
        "edges": edges,
        "sorted_order": sorted(range(k), key=lambda i: ages[i]),
    }
    return new_last_sync_at, info

# --------- example usage (consecutive timesteps t = 0..99) ----------
if __name__ == "__main__":
    # Controller configuration: k workers, power exponent p, g ghost bins.
    k, p, g = 6, 2.0, 1
    # Every worker starts out synced at t=0, so all ages are 0 at the first step.
    last_sync = [0 for _ in range(k)]

    for t in range(100):
        last_sync, info = step_power_ghosts_zero_age(
            t, last_sync, p=p, g=g, always_fresh=True
        )
        ages = info["ages_after"]
        # For debugging, info also carries 'ages_before', 'synced_indices',
        # 'fences' (x_1..x_k) and 'edges' (whose last entry is the right edge t).
        print(f"t={t}, sorted_ages={sorted(ages)}")





t=0, sorted_ages=[0, 0, 0, 0, 0, 0]
t=1, sorted_ages=[0, 0, 0, 0, 0, 0]
t=2, sorted_ages=[0, 0, 0, 0, 1, 1]
t=3, sorted_ages=[0, 0, 0, 0, 0, 2]
t=4, sorted_ages=[0, 0, 0, 0, 1, 1]
t=5, sorted_ages=[0, 0, 0, 1, 2, 2]
t=6, sorted_ages=[0, 0, 0, 1, 3, 3]
t=7, sorted_ages=[0, 0, 0, 1, 2, 4]
t=8, sorted_ages=[0, 0, 1, 2, 3, 5]
t=9, sorted_ages=[0, 0, 0, 0, 4, 6]
t=10, sorted_ages=[0, 0, 1, 1, 5, 7]
t=11, sorted_ages=[0, 0, 0, 2, 2, 8]
t=12, sorted_ages=[0, 0, 0, 1, 3, 3]
t=13, sorted_ages=[0, 1, 1, 2, 4, 4]
t=14, sorted_ages=[0, 0, 2, 3, 5, 5]
t=15, sorted_ages=[0, 0, 1, 4, 6, 6]
t=16, sorted_ages=[0, 1, 2, 5, 7, 7]
t=17, sorted_ages=[0, 0, 0, 3, 8, 8]
t=18, sorted_ages=[0, 1, 1, 4, 9, 9]
t=19, sorted_ages=[0, 0, 0, 2, 5, 10]
t=20, sorted_ages=[0, 1, 1, 3, 6, 11]
t=21, sorted_ages=[0, 0, 2, 4, 7, 12]
t=22, sorted_ages=[0, 1, 3, 5, 8, 13]
t=23, sorted_ages=[0, 0, 4, 6, 9, 14]
t=24, sorted_ages=[0, 0, 1, 7, 10, 15]
t=25, sorted_ages=[0, 1, 2, 8, 11, 16]
t=26, sorted_ages=[0, 0, 2, 3, 12, 17]
