In [1]:
pip install jax diffrax

Collecting diffrax
  Downloading diffrax-0.7.0-py3-none-any.whl.metadata (17 kB)
Collecting equinox>=0.11.10 (from diffrax)
  Downloading equinox-0.13.0-py3-none-any.whl.metadata (18 kB)
Collecting jaxtyping>=0.2.24 (from diffrax)
  Downloading jaxtyping-0.3.2-py3-none-any.whl.metadata (7.0 kB)
Collecting lineax>=0.0.5 (from diffrax)
  Downloading lineax-0.0.8-py3-none-any.whl.metadata (18 kB)
Collecting optimistix>=0.0.10 (from diffrax)
  Downloading optimistix-0.0.10-py3-none-any.whl.metadata (17 kB)
Collecting wadler-lindig>=0.1.1 (from diffrax)
  Downloading wadler_lindig-0.1.7-py3-none-any.whl.metadata (17 kB)
Downloading diffrax-0.7.0-py3-none-any.whl (193 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m193.2/193.2 kB[0m [31m6.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading equinox-0.13.0-py3-none-any.whl (177 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m177.7/177.7 kB[0m [31m7.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading jaxtypin

In [2]:
#streamlined and faster version

# Core deps
import time
import numpy as np
import jax
import jax.numpy as jnp
from jax import random, jit, grad
import diffrax
import lineax as lx

tau = 2.3    # data-fit scale
lam = 0.5    # constant-modulus weight
kap = 0.5    # QPSK phase bias weight

# Integration defaults
DT_DEFAULT       = 0.09   # None => auto dt = horizon / max_steps
MAX_STEPS_DEF    = 20
HORIZON_DEFAULT  = 1.0
INJ_NOISE_STD    = 0.0    # set >0 only if you truly want SDE noise

# Robust JIT decorator across JAX versions
from functools import partial
try:
    JIT = partial(jax.jit, static_argnames=("inj_noise_std","dt","max_steps","horizon","lock"))
    _ = JIT(lambda H,y,key,inj_noise_std,dt,max_steps,horizon,lock: (H,y))
except TypeError:
    # fall back to positional static args: (H,y,key,inj_noise_std,dt,max_steps,horizon,lock)
    JIT = partial(jax.jit, static_argnums=(3,4,5,6,7))

# ------------------ QPSK helpers ------------------
def qpsk_constellation(axis_aligned: bool = True):
    # Axis-aligned: {1, j, -1, -j}; set False for π/4 rotation if needed
    if axis_aligned:
        return jnp.asarray([1+0j, 1j, -1+0j, -1j], dtype=jnp.complex64)
    base = jnp.exp(1j * (jnp.pi/4 + jnp.arange(4)*jnp.pi/2))
    return jnp.asarray(base, dtype=jnp.complex64)

def hard_decide_qpsk(x: jnp.ndarray, axis_aligned: bool = True):
    const = qpsk_constellation(axis_aligned=axis_aligned)
    d = jnp.abs(const[:, None] - x[None, :])
    return jnp.argmin(d, axis=0)  # indices 0..3

def qpsk_symbols(N: int, key, axis_aligned: bool = True):
    const = qpsk_constellation(axis_aligned=axis_aligned)
    key, ks = random.split(key)
    idx = random.randint(ks, (N,), 0, 4)
    x = const[idx]
    return x.astype(jnp.complex64), idx, key

# ------------------ Channel + noise ------------------
def make_synthetic_H(N: int, key):
    k1, k2 = random.split(key)
    H = (random.normal(k1, (N, N), dtype=jnp.float32)
         + 1j * random.normal(k2, (N, N), dtype=jnp.float32)) / jnp.sqrt(2.0)
    row_norm = jnp.sqrt(jnp.mean(jnp.sum(jnp.abs(H) ** 2, axis=1)))
    return (H / row_norm).astype(jnp.complex64)

def sigma_from_snr_db(H: jnp.ndarray, x_true: jnp.ndarray, snr_db: float) -> float:
    y_clean = H @ x_true
    Psig = jnp.mean(jnp.abs(y_clean) ** 2)
    return jnp.sqrt(Psig * 10.0 ** (-snr_db / 10.0))

def make_y(H, x_true, snr_db, key, compat_noise: bool = False):
    key, k1, k2 = random.split(key, 3)
    if compat_noise:
        sigma = jnp.sqrt(10.0 ** (-snr_db / 10.0))
        n = sigma * (random.normal(k1, x_true.shape) + 1j * random.normal(k1, x_true.shape))
    else:
        sigma = sigma_from_snr_db(H, x_true, snr_db)
        n = sigma * (random.normal(k1, x_true.shape) + 1j * random.normal(k2, x_true.shape)) / jnp.sqrt(2.0)
    return (H @ x_true + n).astype(jnp.complex64), key

def sample_frame(H: jnp.ndarray, snr_db: float, key, compat_noise: bool = False, axis_aligned_qpsk: bool = True):
    x_true, idx_true, key = qpsk_symbols(H.shape[1], key, axis_aligned=axis_aligned_qpsk)
    y, key = make_y(H, x_true, snr_db, key, compat_noise=compat_noise)
    return x_true, idx_true, y, key

# ------------------ Baselines (NumPy) ------------------
def zf_indices_np(H_jnp, y_jnp):
    H = np.asarray(H_jnp); y = np.asarray(y_jnp)
    x_hat, *_ = np.linalg.lstsq(H, y, rcond=None)
    return hard_decide_qpsk(jnp.asarray(x_hat))

def lmmse_indices_np(H_jnp, y_jnp, snr_db: float, x_true_jnp):
    H = np.asarray(H_jnp); y = np.asarray(y_jnp)
    sigma2 = float(sigma_from_snr_db(H_jnp, x_true_jnp, snr_db) ** 2)
    HH = H.conj().T @ H
    A = HH + sigma2 * np.eye(H.shape[1], dtype=H.dtype)
    x_hat = np.linalg.solve(A, H.conj().T @ y)
    return hard_decide_qpsk(jnp.asarray(x_hat))

# ------------------ Flow pieces ------------------
def pack(z):   return jnp.stack([jnp.real(z), jnp.imag(z)], axis=0)  # [2,N]
def unpack(z): return z[0] + 1j * z[1]

@jit
def loss_datafit(x: jnp.ndarray, y: jnp.ndarray, H: jnp.ndarray) -> jnp.ndarray:
    r = H @ x - y
    return jnp.sum(jnp.abs(r) ** 2)

@jit
def loss_const_mod(x: jnp.ndarray) -> jnp.ndarray:
    return jnp.sum((jnp.abs(x) ** 2 - 1.0) ** 2)

# Complex grads (Wirtinger). Conjugate ONCE in drift (your pattern).
dLdx_datafit = jit(grad(loss_datafit, argnums=0, holomorphic=False))
dLdx_const   = jit(grad(loss_const_mod,  argnums=0, holomorphic=False))

@jit
def phase_bias(x: jnp.ndarray, eps: float = 1e-8) -> jnp.ndarray:
    # Your original: unit phasor ^3 (four-fold angular pull)
    ph = x / (jnp.abs(x) + eps)
    return ph ** 3

def make_drift_matching(H: jnp.ndarray, y: jnp.ndarray):
    """Exact same drift as your reference: conjugate once in the drift."""
    H = jnp.asarray(H); y = jnp.asarray(y)
    def drift(t, zpack, args):
        x = unpack(zpack)
        dx = (-tau * dLdx_datafit(x, y, H).conj()
              -lam * dLdx_const(x).conj()
              + kap * phase_bias(x).conj())
        return pack(dx)
    return drift

def make_diffusion(inj_noise_std: float):
    if inj_noise_std == 0.0:
        return None
    def diffusion(t, zpack, args):
        diag = inj_noise_std * jnp.ones_like(zpack)  # [2,N]
        return lx.DiagonalLinearOperator(diag)
    return diffusion

# --- helper: choose t1 and dt based on lock mode ---
def _calc_t1_dt(dt, max_steps, horizon=1.0, lock="steps"):
    """
    lock='steps'   -> enforce exactly max_steps; t1 = dt * max_steps (if dt is None, dt=horizon/max_steps and t1=horizon)
    lock='horizon' -> keep t1=horizon; steps ~= horizon/dt (if dt is None, dt=horizon/max_steps)
    """
    if dt is None:
        dt = horizon / max_steps
    if lock == "steps":
        t1 = dt * max_steps
    else:
        t1 = horizon
    return float(t1), float(dt)

@JIT
def run_flow_fast(H, y, key,
                  inj_noise_std=INJ_NOISE_STD,
                  dt=DT_DEFAULT,
                  max_steps=MAX_STEPS_DEF,
                  horizon=HORIZON_DEFAULT,
                  lock="steps"):
    """
    Same dynamics as your original (tau/lam/kap; single *.conj()).
    Overhead cuts only: final-state only, constant step, JIT-friendly loop.

    lock:
      - 'steps'   => keep exactly max_steps (t1 = dt*max_steps; if dt None, auto dt = horizon/max_steps and t1=horizon)
      - 'horizon' => keep horizon fixed (t1=horizon); steps vary with dt
    """
    N = H.shape[1]
    drift = make_drift_matching(H, y)
    diffusion = make_diffusion(inj_noise_std)

    # Small random complex init (matches your reference)
    key, k1, k2 = random.split(key, 3)
    x0 = 0.1 * (random.normal(k1, (N,)) + 1j * random.normal(k2, (N,)))
    z0 = pack(x0)

    controller = diffrax.ConstantStepSize()
    saveat     = diffrax.SaveAt(t1=True)

    # Solver selection
    if diffusion is None:
        term   = diffrax.ODETerm(drift)
        solver = diffrax.Heun()              # improved Euler (like your ODE path)
    else:
        # build brownian matching the actual integration horizon
        t1_tmp, dt_tmp = _calc_t1_dt(dt, max_steps, horizon=horizon, lock=lock)
        brown  = diffrax.VirtualBrownianTree(t0=0.0, t1=t1_tmp, tol=1e-3, shape=z0.shape, key=key)
        term   = diffrax.MultiTerm(diffrax.ODETerm(drift), diffrax.ControlTerm(diffusion, brown))
        solver = diffrax.EulerHeun()

    t1, dt_eff = _calc_t1_dt(dt, max_steps, horizon=horizon, lock=lock)

    sol = diffrax.diffeqsolve(
        term, solver, t0=0.0, t1=t1, dt0=dt_eff,
        y0=z0, saveat=saveat,
        stepsize_controller=controller,      # constant-dt (no adaptive overhead)
        max_steps=max_steps
    )
    return sol


In [3]:
# ==== Cell B (fast-compat) — flow readout (indices) ====


def _extract_final_pack(sol):
    """Handle sol.ys across diffrax versions; return (2,N) or complex (N,) consistently."""
    y = sol.ys
    if isinstance(y, dict):
        y = y.get('y', y.get('x', next(iter(y.values()))))
    elif isinstance(y, (tuple, list)):
        y = y[0]
    y = jnp.asarray(y)
    if y.ndim == 3 and y.shape[0] == 1:
        y = y[0]
    return y

def flow_indices_fast(H, y, key,
                      inj_noise_std=INJ_NOISE_STD,
                      dt=DT_DEFAULT,
                      max_steps=MAX_STEPS_DEF,
                      horizon=HORIZON_DEFAULT,
                      lock="steps"):
    """Run fast flow and return (qpsk_indices, accepted_steps)."""
    sol = run_flow_fast(H, y, key, inj_noise_std=inj_noise_std,
                        dt=dt, max_steps=max_steps, horizon=horizon, lock=lock)
    pack = _extract_final_pack(sol)
    xT = pack if (jnp.iscomplexobj(pack) and pack.ndim == 1) else (pack[0] + 1j * pack[1])
    idx_hat = hard_decide_qpsk(xT)

    if lock == "steps":
        steps = int(max_steps)  # deterministic
    else:
        steps_attr = getattr(sol.stats, "num_accepted_steps", None)
        steps = int(steps_attr) if steps_attr is not None else int(np.ceil(horizon / (dt or (horizon/max_steps))))
    return idx_hat, steps


In [6]:
# ==== Benchmarking SER, BER & latency (p50/p95) for ZF / LMMSE / Flow ====
import time
import numpy as np
import jax.numpy as jnp
from jax import random
import pandas as pd

# --------- knobs ---------
N_LIST     = [180, 256, 512, 1000]
SNR_LIST   = [2.0, 4.0, 8.0, 12.0, 16.0]
NUM_H      = 12
FRAMES_PER = 1

DT        = DT_DEFAULT        # from Cell A
MAX_STEPS = MAX_STEPS_DEF
HORIZON   = HORIZON_DEFAULT
LOCK_MODE = "steps"           # keep exactly MAX_STEPS by default
INJ_STD   = INJ_NOISE_STD

SEED = 20256969

# ---- bit mapping helpers (axis-aligned QPSK: [1, j, -1, -j]) ----
# Choose a Gray code mapping for indices -> bits
_IDX2BITS = jnp.array([[0,0],   # 1+0j  -> 00
                       [0,1],   # +j    -> 01
                       [1,1],   # -1+0j -> 11
                       [1,0]],  # -j    -> 10
                      dtype=jnp.int32)

def idx_to_bits(idx: jnp.ndarray) -> jnp.ndarray:
    """idx [N] -> bits [N,2] using the fixed Gray map above."""
    return _IDX2BITS[idx]

def ber_from_indices(idx_true: jnp.ndarray, idx_hat: jnp.ndarray) -> float:
    """Compute bit error rate given symbol indices under our fixed map."""
    bt = idx_to_bits(idx_true)  # [N,2]
    bh = idx_to_bits(idx_hat)   # [N,2]
    bit_errs = jnp.sum(jnp.not_equal(bt, bh))
    total_bits = bt.size
    return float(bit_errs) / float(total_bits)

# --------- run ---------
rng = random.PRNGKey(SEED)
records = []

for N in N_LIST:
    rng, kN = random.split(rng)
    H_bank = []
    for _ in range(NUM_H):
        kN, kH = random.split(kN)
        H = make_synthetic_H(N, kH)
        H_bank.append(H)

    for H in H_bank:
        for snr_db in SNR_LIST:
            rng, kpair = random.split(rng)
            for _ in range(FRAMES_PER):
                x_true, idx_true, y, kpair = sample_frame(H, snr_db, kpair, compat_noise=False, axis_aligned_qpsk=True)

                # ZF
                t0 = time.perf_counter()
                idx_hat = zf_indices_np(H, y)
                t1 = time.perf_counter()
                ser = 1.0 - float(jnp.mean((idx_hat == idx_true).astype(jnp.float32)))
                ber = ber_from_indices(idx_true, idx_hat)
                records.append({"N": N, "snr_db": float(snr_db), "method": "ZF", "ser": ser, "ber": ber, "time_s": (t1 - t0)})

                # LMMSE
                t0 = time.perf_counter()
                idx_hat = lmmse_indices_np(H, y, snr_db, x_true)
                t1 = time.perf_counter()
                ser = 1.0 - float(jnp.mean((idx_hat == idx_true).astype(jnp.float32)))
                ber = ber_from_indices(idx_true, idx_hat)
                records.append({"N": N, "snr_db": float(snr_db), "method": "LMMSE", "ser": ser, "ber": ber, "time_s": (t1 - t0)})

                # Flow (fast-compat)
                rng, kf = random.split(kpair)
                t0 = time.perf_counter()
                idx_hat, steps = flow_indices_fast(H, y, kf,
                                                   inj_noise_std=INJ_STD,
                                                   dt=DT, max_steps=MAX_STEPS,
                                                   horizon=HORIZON, lock=LOCK_MODE)
                t1 = time.perf_counter()
                ser = 1.0 - float(jnp.mean((idx_hat == idx_true).astype(jnp.float32)))
                ber = ber_from_indices(idx_true, idx_hat)
                records.append({"N": N, "snr_db": float(snr_db), "method": "Flow(fast)", "ser": ser, "ber": ber, "time_s": (t1 - t0), "steps": steps})

# --------- summarize ---------
df = pd.DataFrame.from_records(records)

def p50(x): return float(np.percentile(x, 50))
def p95(x): return float(np.percentile(x, 95))
def nanmean(x):
    x = np.asarray(list(x))
    return float(np.nanmean(x)) if x.size else float("nan")

summary = (df
    .groupby(["N","snr_db","method"], as_index=False)
    .agg(
        ser_mean=("ser","mean"),
        ber_mean=("ber","mean"),
        time_p50=("time_s", p50),
        time_p95=("time_s", p95),
        steps_mean=("steps", nanmean),
        steps_p95=("steps", lambda x: float(np.nanpercentile(list(x),95)) if len(list(x)) else np.nan),
    )
    .sort_values(["N","snr_db","method"])
)

print("\n=== SER/BER (mean) & latency p50/p95 by N, SNR, method ===")
for (N, snr), grp in summary.groupby(["N","snr_db"]):
    print(f"\nN={N}, SNR={snr:.1f} dB")
    for _, row in grp.iterrows():
        m = row["method"]
        ser = row["ser_mean"]; ber = row["ber_mean"]
        t50 = row["time_p50"]; t95 = row["time_p95"]
        if "Flow" in m:
            smean = row["steps_mean"]; sp95 = row["steps_p95"]
            print(f"  {m:10s}  SER={ser:.4e}  BER={ber:.4e}  t50={t50:.3f}s  t95={t95:.3f}s  steps≈{smean:.1f}/p95={sp95 if not np.isnan(sp95) else '—'}")
        else:
            print(f"  {m:10s}  SER={ser:.4e}  BER={ber:.4e}  t50={t50:.3f}s  t95={t95:.3f}s")



=== SER/BER (mean) & latency p50/p95 by N, SNR, method ===

N=180, SNR=2.0 dB
  Flow(fast)  SER=3.3426e-01  BER=1.8333e-01  t50=0.011s  t95=0.015s  steps≈20.0/p95=20.0
  LMMSE       SER=3.3380e-01  BER=1.8287e-01  t50=0.005s  t95=0.011s
  ZF          SER=7.0972e-01  BER=4.6528e-01  t50=0.016s  t95=0.027s

N=180, SNR=4.0 dB
  Flow(fast)  SER=2.5741e-01  BER=1.3727e-01  t50=0.011s  t95=0.021s  steps≈20.0/p95=20.0
  LMMSE       SER=2.6019e-01  BER=1.3727e-01  t50=0.006s  t95=0.020s
  ZF          SER=7.1389e-01  BER=4.5972e-01  t50=0.017s  t95=0.038s

N=180, SNR=8.0 dB
  Flow(fast)  SER=1.0926e-01  BER=5.6481e-02  t50=0.011s  t95=0.017s  steps≈20.0/p95=20.0
  LMMSE       SER=1.5880e-01  BER=8.1944e-02  t50=0.005s  t95=0.012s
  ZF          SER=6.9954e-01  BER=4.4653e-01  t50=0.016s  t95=0.081s

N=180, SNR=12.0 dB
  Flow(fast)  SER=1.0648e-02  BER=5.3241e-03  t50=0.011s  t95=0.029s  steps≈20.0/p95=20.0
  LMMSE       SER=5.8333e-02  BER=2.9398e-02  t50=0.006s  t95=0.006s
  ZF          SER=6.

  return float(np.nanmean(x)) if x.size else float("nan")
  return _nanquantile_unchecked(
