In [1]:
# Pure-Python Spiking LCA (S-LCA) demo
# - Minimal, dependency-free (only numpy)
# - Implements the paper's S-LCA with unit-area exponential synapses
# - Includes a tiny FISTA solver for CLASSO (nonnegative) to verify the target solution
#
# You can copy this whole cell into a .py file and run it as a script.

import numpy as np

def normalize_columns(Phi: np.ndarray) -> np.ndarray:
    norms = np.linalg.norm(Phi, axis=0, keepdims=True) + 1e-12
    return Phi / norms

def classo_fista_nonneg(Phi, s, lam, max_iter=5000, tol=1e-9):
    """FISTA for: min_{a >= 0} 0.5||s - Phi a||^2 + lam*||a||_1"""
    Phi = normalize_columns(Phi)
    PhiT = Phi.T
    PhiTPhi = PhiT @ Phi
    PhiTs = PhiT @ s
    L = np.linalg.eigvalsh(PhiTPhi).max() + 1e-12  # Lipschitz constant
    tstep = 1.0 / L

    def prox_nonneg_l1(x, thr):
        return np.maximum(0.0, x - thr)

    N = Phi.shape[1]
    a = np.zeros(N)
    y = a.copy()
    theta = 1.0

    for k in range(max_iter):
        grad = PhiTPhi @ y - PhiTs
        a_next = prox_nonneg_l1(y - tstep * grad, lam * tstep)
        theta_next = 0.5 * (1 + np.sqrt(1 + 4 * theta * theta))
        y = a_next + (theta - 1) / theta_next * (a_next - a)
        if np.linalg.norm(a_next - a) < tol * (np.linalg.norm(a) + 1e-12):
            a = a_next
            break
        a, theta = a_next, theta_next
    return a

def spiking_lca(Phi, s, lam=0.1, dt=1e-4, tau_syn=1e-2, T_sec=1000.0,
                v_th=1.0, v_reset=0.0, seed=None, return_logs=True,
                t0=0.5):
    """
    Spiking LCA with exponential synapses and unit-area scaling.
      r[n]   = decay * r[n-1] + sigma_prev
      mu     = b - W @ (r / tau_syn)
      v     += dt * (mu - lam)
      if v >= v_th: spike and reset

    Returns:
      rates (Hz) from spike counts / T_sec,
      spike_counts,
      meta dict including (b, W, Phi) and the CLASSO-aligned readout:
        - 'Tlam_u': thresholded average current over tail window [t0, T_sec]
        - 'u_tail': average soma currents over tail window
    """
    rng = np.random.default_rng(seed)
    Phi = normalize_columns(Phi)
    b   = Phi.T @ s
    W   = Phi.T @ Phi
    np.fill_diagonal(W, 0.0)
    print(W)
    decay = float(np.exp(-dt / tau_syn))
    steps = int(T_sec / dt)
    N     = Phi.shape[1]

    # States
    v = np.zeros(N)
    r = np.zeros(N)
    sigma_prev = np.zeros(N)
    spike_counts = np.zeros(N, dtype=int)

    # ----- Added: accumulate ∫ μ dt and snapshot at t0 -----
    mu_int = np.zeros(N)        # integral of μ from 0 to current time
    mu_int_at_t0 = None         # snapshot of integral at time t0

    # Optional logs
    if return_logs:
        log_every = max(1, int(1e-3 / dt)) # every 1 ms
        ts, Vlog, Mlog, Rlog = [], [], [], []

    for n in range(steps):
        # Exponential filter
        r = decay * r + sigma_prev
        # Soma current with unit-area scaling
        mu = b - (W @ (r / tau_syn))

        # ----- Added: accumulate integral and snapshot at t0 -----
        mu_int += mu * dt
        if (mu_int_at_t0 is None) and ((n + 1) * dt >= t0):
            mu_int_at_t0 = mu_int.copy()

        # Voltage integration with bias
        v += dt * (mu - lam)
        # Spikes
        sigma = (v >= v_th).astype(float)
        if sigma.any():
            spike_counts += sigma.astype(int)
            v[sigma > 0] = v_reset
        sigma_prev = sigma

        if return_logs and (n % log_every == 0):
            ts.append(n * dt)
            Vlog.append(v.copy())
            Mlog.append(mu.copy())
            Rlog.append(r.copy())

    # Default snapshot if t0 is 0 or sim is too short
    if mu_int_at_t0 is None:
        mu_int_at_t0 = np.zeros_like(mu_int)

    # ----- Added: compute tail-window averages and Tλ(u) readout -----
    denom = max(T_sec - t0, 1e-12)
    u_tail = (mu_int - mu_int_at_t0) / denom
    Tlam_u = np.maximum(0.0, u_tail - lam)   # CLASSO-aligned readout

    rates = spike_counts / T_sec
    meta = {
        "t": np.array(ts) if return_logs else None,
        "v": np.array(Vlog) if return_logs else None,
        "mu": np.array(Mlog) if return_logs else None,
        "r": np.array(Rlog) if return_logs else None,
        "b": b, "W": W, "Phi": Phi,
        "u_tail": u_tail, "Tlam_u": Tlam_u, "t0": t0
    }
    return rates, spike_counts, meta

# ------------------- Demo with the paper's 3-neuron example -------------------
Phi_demo = np.array([
    [0.3313, 0.8148, 0.4364],
    [0.8835, 0.3621, 0.2182],
    [0.3313, 0.4527, 0.8729],
], dtype=float)
s_demo = np.array([0.5, 1.0, 1.5], dtype=float)
lam_demo = 0.1

# 1) Closed-form target via FISTA (nonnegative CLASSO)
a_star = classo_fista_nonneg(Phi_demo, s_demo, lam_demo)
print("CLASSO (FISTA) solution a* ≈", a_star)

# 2) Spiking LCA run
rates, counts, meta = spiking_lca(
    Phi_demo, s_demo, lam=lam_demo,
    dt=1e-3, tau_syn=1, T_sec=100,
    t0=1,  # tail window start (seconds); tune 5–20% of T_sec
    return_logs=True
)

print("Spiking LCA spike counts over T=3.0s:", counts)
print("Spiking LCA rates (Hz):", rates)
print("Tail-window Tλ(u) readout:", meta["Tlam_u"])
print("Active (by Tλ(u) > 0):", np.where(meta["Tlam_u"] > 0)[0])


CLASSO (FISTA) solution a* ≈ [0.68306058 0.         1.21779048]
[[0.         0.73982169 0.62651876]
 [0.73982169 0.         0.82976598]
 [0.62651876 0.82976598 0.        ]]
Spiking LCA spike counts over T=3.0s: [ 68   1 121]
Spiking LCA rates (Hz): [0.68 0.01 1.21]
Tail-window Tλ(u) readout: [0.68403301 0.         1.21031113]
Active (by Tλ(u) > 0): [0 2]


# Systematic Comparison: Ground Truth vs Fugu Implementation

## Key Differences to Investigate:
1. **Normalization**: Are we double-normalizing Phi?
2. **Edge weights**: Are lateral inhibition weights correct?
3. **Integration method**: Different membrane potential updates?
4. **Readout method**: Different ways to compute final sparse codes?
5. **Parameters**: Different dt, tau, lambda values?

In [2]:
# 1. NORMALIZATION CHECK
print("=== NORMALIZATION ANALYSIS ===")

# Ground truth normalization
Phi_raw = np.array([
    [0.3313, 0.8148, 0.4364],
    [0.8835, 0.3621, 0.2182],
    [0.3313, 0.4527, 0.8729],
], dtype=float)

def normalize_columns_gt(Phi):
    norms = np.linalg.norm(Phi, axis=0, keepdims=True) + 1e-12
    return Phi / norms

Phi_gt_normalized = normalize_columns_gt(Phi_raw)
print("Ground truth normalized Phi:")
print(Phi_gt_normalized)
print("Column norms:", np.linalg.norm(Phi_gt_normalized, axis=0))

# Check what the Fugu backend received
print("\nFugu backend Phi (after compile):")
print("(This will show if there's double normalization)")

# We need to access the actual compiled Phi from the ground truth demo first
print("\nGround truth b = Phi^T @ s:")
b_gt = Phi_gt_normalized.T @ np.array([0.5, 1.0, 1.5])
print(b_gt)

=== NORMALIZATION ANALYSIS ===
Ground truth normalized Phi:
[[0.33128482 0.81481925 0.43639768]
 [0.88345953 0.36210856 0.21819884]
 [0.33128482 0.4527107  0.87289537]]
Column norms: [1. 1. 1.]

Fugu backend Phi (after compile):
(This will show if there's double normalization)

Ground truth b = Phi^T @ s:
[1.54602917 1.44858423 1.74574074]


In [3]:
# 2. COMPARE FUGU BACKEND PARAMETERS
print("\\n=== FUGU BACKEND ANALYSIS ===")

# Navigate to lca_dev notebook and get backend state
import sys
sys.path.append('/Users/kamerongano/Documents/GitHub/Fugu_dev')

# Import the results from the lca_dev notebook (we'll reconstruct)
# Using the exact same parameters
from fugu.backends import slca_Backend
from fugu.bricks import LCABrick 
from fugu import Scaffold

Phi_demo = np.array([
    [0.3313, 0.8148, 0.4364],
    [0.8835, 0.3621, 0.2182],
    [0.3313, 0.4527, 0.8729],
], dtype=float)
s_demo = np.array([0.5, 1.0, 1.5], dtype=float)
lam_demo = 0.1

scaffold = Scaffold()
scaffold.add_brick(LCABrick(Phi=Phi_demo), output=True)
scaffold.lay_bricks()

backend = slca_Backend()
backend.compile(
    scaffold=scaffold,
    compile_args={
        'y': s_demo,
        'Phi': Phi_demo,
        'lam': lam_demo,
        'dt': 1e-3,        # Ground truth uses dt=1e-3
        'tau_syn': 1,      # Ground truth uses tau=1  
        'T_steps': 1000,
    }
)

print("Fugu backend.Phi shape:", backend.Phi.shape)
print("Fugu backend.Phi column norms:", np.linalg.norm(backend.Phi, axis=0))
print("Fugu backend.b:", backend.b)
print("Fugu backend.W shape:", backend.W.shape)
print("Fugu backend.W[0,1]:", backend.W[0,1])
print("Fugu backend.tau:", backend.tau)
print("Fugu backend.dt:", backend.dt)

\n=== FUGU BACKEND ANALYSIS ===
Fugu backend.Phi shape: (3, 3)
Fugu backend.Phi column norms: [1. 1. 1.]
Fugu backend.b: [1.54602917 1.44858423 1.74574074]
Fugu backend.W shape: (3, 3)
Fugu backend.W[0,1]: 0.7398216888397283
Fugu backend.tau: 1.0
Fugu backend.dt: 0.001


In [4]:
# 3. SIDE-BY-SIDE INTEGRATION COMPARISON
print("\\n=== INTEGRATION METHOD COMPARISON ===")

# Ground truth method (from the notebook above)
print("Ground truth method:")
print("  r[n] = decay * r[n-1] + sigma_prev")
print("  mu = b - W @ (r / tau_syn)  # Note: r/tau scaling")
print("  v += dt * (mu - lam)")

print("\\nFugu backend method (let's check the code):")

# Let's run a few steps manually to compare
print("\\nManual step-by-step comparison:")

# Initialize states for both methods
# Ground truth states
decay_gt = np.exp(-1e-3 / 1.0)  # exp(-dt/tau)
r_gt = np.zeros(3)
v_gt = np.zeros(3) 
sigma_prev_gt = np.zeros(3)

# Fugu states (from backend)
backend.inhibition[:] = 0.0
backend.soma_current[:] = 0.0  
backend.spikes_prev[:] = 0.0
# Initialize neurons to 0
for name, n in backend.nn.nrns.items():
    if "neuron_" in name:
        n.v = 0.0

print(f"Initial states - GT r: {r_gt}, Fugu inhibition: {backend.inhibition}")

# Step 1: No spikes initially, so mu = b in both cases
mu_gt_step1 = backend.b - backend.W @ (r_gt / backend.tau)  # GT method
backend.soma_current = backend.b - backend.W @ backend.inhibition  # Fugu method

print(f"Step 1 - GT mu: {mu_gt_step1}")
print(f"Step 1 - Fugu mu: {backend.soma_current}")
print(f"Should be identical: {np.allclose(mu_gt_step1, backend.soma_current)}")

# Check voltage updates
v_update_gt = backend.dt * (mu_gt_step1 - backend.lam)
print(f"GT voltage update: dv = dt*(mu-lam) = {v_update_gt}")
print(f"Expected new voltages: {v_gt + v_update_gt}")

# The issue might be in how Fugu applies the voltage update!
print("\\n=== CHECKING FUGU VOLTAGE UPDATE METHOD ===")
print("Let's examine what the Fugu slca_step actually does to voltages...")

\n=== INTEGRATION METHOD COMPARISON ===
Ground truth method:
  r[n] = decay * r[n-1] + sigma_prev
  mu = b - W @ (r / tau_syn)  # Note: r/tau scaling
  v += dt * (mu - lam)
\nFugu backend method (let's check the code):
\nManual step-by-step comparison:
Initial states - GT r: [0. 0. 0.], Fugu inhibition: [0. 0. 0.]
Step 1 - GT mu: [1.54602917 1.44858423 1.74574074]
Step 1 - Fugu mu: [1.54602917 1.44858423 1.74574074]
Should be identical: True
GT voltage update: dv = dt*(mu-lam) = [0.00144603 0.00134858 0.00164574]
Expected new voltages: [0.00144603 0.00134858 0.00164574]
\n=== CHECKING FUGU VOLTAGE UPDATE METHOD ===
Let's examine what the Fugu slca_step actually does to voltages...


In [5]:
# 4. TEST CORRECTED FUGU IMPLEMENTATION
print("\\n=== TESTING CORRECTED FUGU IMPLEMENTATION ===")

# Force reload the fixed backend
import importlib
if 'fugu.backends.slca_backend' in sys.modules:
    importlib.reload(sys.modules['fugu.backends.slca_backend'])
from fugu.backends import slca_Backend

# Create corrected backend with EXACT same parameters as ground truth
backend_corrected = slca_Backend()
backend_corrected.compile(
    scaffold=scaffold,
    compile_args={
        'y': s_demo,
        'Phi': Phi_demo,
        'lam': 0.1,
        'dt': 1e-3,      # Same as ground truth
        'tau_syn': 1.0,  # Same as ground truth  
        'T_steps': 100000,  # Same duration as ground truth
        't0_steps': 1000,   # Tail window start
    }
)

result_corrected = backend_corrected.run()

print("\\n=== COMPARISON RESULTS ===")
print(f"Ground Truth CLASSO: [0.683, 0.0, 1.218]")
print(f"Ground Truth S-LCA:  [0.684, 0.0, 1.210]") 
print(f"Fugu Corrected:      {result_corrected['a_tail']}")
print(f"")
print(f"Ground Truth active: [0, 2]")
print(f"Fugu active:         {np.where(result_corrected['a_tail'] > 0.01)[0]}")
print(f"")
print(f"Reconstruction error:")
print(f"  Ground Truth: {np.linalg.norm(Phi_gt_normalized @ np.array([0.684, 0.0, 1.210]) - s_demo):.6f}")
print(f"  Fugu:         {np.linalg.norm(result_corrected['x_hat'] - s_demo):.6f}")
print(f"")
print(f"Success? {np.allclose(result_corrected['a_tail'], [0.684, 0.0, 1.210], atol=0.1)}")

\n=== TESTING CORRECTED FUGU IMPLEMENTATION ===
\n=== COMPARISON RESULTS ===
Ground Truth CLASSO: [0.683, 0.0, 1.218]
Ground Truth S-LCA:  [0.684, 0.0, 1.210]
Fugu Corrected:      [0.68402167 0.         1.21029842]

Ground Truth active: [0, 2]
Fugu active:         [0 2]

Reconstruction error:
  Ground Truth: 0.359666
  Fugu:         0.359571

Success? True
