# Classification with Comparator Bank Readout

## The Problem

A single thresholded SingleDendrite neuron is brittle:
- If threshold is too high: nothing fires
- If threshold is too low: everything fires
- Small calibration errors → large performance changes
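
The brittleness can be illustrated with an idealized hard comparator. This is only a sketch: the evidence values, the nominal threshold, and the calibration shifts are hypothetical numbers chosen for illustration, and a real SingleDendrite response is not a perfect step.

```python
def single_comparator(evidence, threshold):
    """Idealized hard comparator: 1 if evidence reaches the threshold, else 0."""
    return 1 if evidence >= threshold else 0


evidence_values = [0.18, 0.21, 0.22]   # hypothetical evidence for three inputs
nominal = 0.20                          # hypothetical intended threshold
shifts = (-0.03, 0.0, 0.03)             # hypothetical calibration errors

# Output pattern of the single comparator under each calibration error
results = {
    shift: [single_comparator(e, nominal + shift) for e in evidence_values]
    for shift in shifts
}

for shift, outputs in results.items():
    print(f"threshold={nominal + shift:.2f} -> outputs={outputs}")
```

A shift of 0.03 in either direction is enough to move from "everything fires" to "nothing fires" for these inputs.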

## The Solution: Comparator Bank (Population Coding)

Use **K neurons per class** with **staggered phi_offsets** (effective thresholds):

```
           ┌─► SD(φ=0.15) ─► y₀ ─┐
           ├─► SD(φ=0.18) ─► y₁ ─┤
Hidden ────┼─► SD(φ=0.21) ─► y₂ ─┼──► S = Σyₖ (count)
           ├─► SD(φ=0.24) ─► y₃ ─┤
           └─► SD(φ=0.27) ─► y₄ ─┘
```

**Key Design: SHARED WEIGHTS**

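All comparators receive the **same evidence** (shared weights from the hidden layer) but have **different thresholds** (phi_offsets). Because every comparator sees the same input, the count S rises by one step each time the evidence crosses another threshold, so S is a staircase approximation of the evidence: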

```
                 Independent Weights (BAD)          Shared Weights (GOOD)
                 
Hidden ──┬─► W₁ ──► Comparator(φ=0.15)    Hidden ──► W ──┬─► Comparator(φ=0.15)
         ├─► W₂ ──► Comparator(φ=0.20)              (shared) ├─► Comparator(φ=0.20)  
         └─► W₃ ──► Comparator(φ=0.25)                      └─► Comparator(φ=0.25)
         
Each learns different features!              All see same evidence, different thresholds!
```
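
As a rough numerical illustration of the staircase, the sketch below counts how many idealized comparators fire for a given shared evidence value. The hard `>=` comparison is an assumption standing in for the SingleDendrite response, and `comparator_count` is a hypothetical helper, not part of `soen_toolkit`.

```python
import numpy as np


def comparator_count(evidence, thresholds):
    """Count how many idealized comparators fire for each evidence value."""
    evidence = np.asarray(evidence, dtype=float)
    thresholds = np.asarray(thresholds, dtype=float)
    # Broadcast [N, 1] against [1, K] to get an [N, K] firing pattern
    fires = evidence[:, None] >= thresholds[None, :]
    return fires.sum(axis=1)


# Thresholds from the diagram at the top of this notebook
thresholds = np.array([0.15, 0.18, 0.21, 0.24, 0.27])
evidence = np.array([0.10, 0.16, 0.22, 0.30])

counts = comparator_count(evidence, thresholds)
print(counts)  # counts are 0, 1, 3, 5
```

With shared evidence the count is monotone in the evidence. With independent weights each comparator would compare a different quantity, and the count would lose that interpretation.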

## Hardware Mapping

| Component | Hardware Implementation |
|-----------|------------------------|
| Each comparator | SingleDendrite neuron |
| Shared weights | Single fan-out from hidden layer |
| Staggered thresholds | Different phi_offset per neuron |
| Spike counting | SFQ counter or optical pulse counter |
| Class decision | Compare counts: argmax(S_class) |
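
The last two rows of the table (count spikes, then compare counts) can be sketched in software as follows. The firing pattern is made-up data, and `decide_class` is a hypothetical helper used only for illustration.

```python
import numpy as np


def decide_class(fired):
    """Return (predicted class, per-class counts) from a firing pattern.

    fired: boolean array of shape [n_classes, K], True where a comparator fired.
    """
    counts = np.asarray(fired, dtype=bool).sum(axis=1)
    # np.argmax breaks ties in favour of the lowest class index
    return int(np.argmax(counts)), counts


fired = np.array([
    [True, True, False, False, False],   # class 0: 2 of 5 comparators fired
    [True, True, True, True, False],     # class 1: 4 of 5 comparators fired
])

label, counts = decide_class(fired)
print(label, counts)
```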

In [None]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

from soen_toolkit.core import (
    ConnectionConfig,
    LayerConfig,
    SimulationConfig,
    SOENModelCore,
)

torch.manual_seed(42)
np.random.seed(42)

print(f"PyTorch version: {torch.__version__}")

## 1. Generate Circle-in-Ring Dataset

In [None]:
def generate_circle_ring_data(n_samples=500, inner_radius=0.3, outer_radius_min=0.5, 
                               outer_radius_max=0.8, noise=0.05):
    """
    Generate 2D classification data: circle inside a ring.
    """
    n_each = n_samples // 2
    
    # Class 0: Inner circle
    theta_inner = np.random.uniform(0, 2*np.pi, n_each)
    r_inner = np.random.uniform(0, inner_radius, n_each)
    x_inner = r_inner * np.cos(theta_inner) + np.random.normal(0, noise, n_each)
    y_inner = r_inner * np.sin(theta_inner) + np.random.normal(0, noise, n_each)
    
    # Class 1: Outer ring
    theta_outer = np.random.uniform(0, 2*np.pi, n_each)
    r_outer = np.random.uniform(outer_radius_min, outer_radius_max, n_each)
    x_outer = r_outer * np.cos(theta_outer) + np.random.normal(0, noise, n_each)
    y_outer = r_outer * np.sin(theta_outer) + np.random.normal(0, noise, n_each)
    
    X = np.vstack([
        np.column_stack([x_inner, y_inner]),
        np.column_stack([x_outer, y_outer])
    ])
    y = np.array([0] * n_each + [1] * n_each)
    
    idx = np.random.permutation(len(y))
    X, y = X[idx], y[idx]
    
    # Scale to SOEN operating range: maps [-1, 1] linearly onto [0.025, 0.275]
    X = (X + 1) / 2 * 0.25 + 0.025
    
    return torch.FloatTensor(X), torch.FloatTensor(y)


N_SAMPLES = 500
X_data, y_data = generate_circle_ring_data(N_SAMPLES)

print(f"Dataset shape: X={X_data.shape}, y={y_data.shape}")
print(f"Class distribution: {(y_data == 0).sum().item()} inner, {(y_data == 1).sum().item()} outer")

# Visualize
plt.figure(figsize=(6, 6))
for c, color in enumerate(['blue', 'red']):
    mask = y_data == c
    plt.scatter(X_data[mask, 0], X_data[mask, 1], c=color, alpha=0.6, s=20)
plt.title('Circle vs Ring Dataset')
plt.axis('equal')
plt.grid(True, alpha=0.3)
plt.show()

## 2. Prepare Data for SOEN

In [None]:
SEQ_LEN = 50

X_seq = X_data.unsqueeze(1).expand(-1, SEQ_LEN, -1).clone()
y_labels = y_data.long()  # Class indices for CrossEntropyLoss

print(f"SOEN input shape: {X_seq.shape}")
print(f"Labels shape: {y_labels.shape}")

## 3. Comparator Bank Model Builder

Architecture:
- Input: 2D (x, y)
- Hidden: SingleDendrite layer(s)
- Readout: **K neurons per class** with staggered phi_offsets
- **CRITICAL**: All comparators share the same weights (weight tying)
- Output: Sum of neuron outputs per class → class scores
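
Before reading the builder, it may help to see the readout layout on its own. The sketch below mirrors the layer numbering and `np.linspace` spacing used by the builder; `plan_comparator_layers` is a hypothetical helper and does not touch `soen_toolkit`.

```python
import numpy as np


def plan_comparator_layers(n_hidden_layers, k, phi_offset_range):
    """List (layer_id, phi_offset) for each comparator layer.

    Layer 0 is the input, layers 1..n_hidden_layers are hidden,
    and the K comparator layers follow directly after.
    """
    phi_min, phi_max = phi_offset_range
    phi_offsets = np.linspace(phi_min, phi_max, k)
    first_id = n_hidden_layers + 1
    return [(first_id + i, float(phi)) for i, phi in enumerate(phi_offsets)]


plan = plan_comparator_layers(n_hidden_layers=1, k=5, phi_offset_range=(0.15, 0.30))
for layer_id, phi in plan:
    print(f"layer {layer_id}: phi_offset={phi:.4f}")

# With k=1, np.linspace returns only the lower bound of the range
single = plan_comparator_layers(n_hidden_layers=1, k=1, phi_offset_range=(0.10, 0.35))
print(single)
```

Each comparator layer holds `n_classes` neurons, so the readout has `K * n_classes` neurons in total.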

In [None]:
def build_comparator_bank_classifier(hidden_dims, n_comparators_per_class=5,
                                      phi_offset_range=(0.15, 0.30),
                                      input_dim=2, dt=50.0, n_classes=2):
    """
    Build a SOEN classifier with COMPARATOR BANK readout.
    
    Architecture: input → [hidden] → K*n_classes comparator neurons
    
    Each class has K neurons with staggered phi_offsets.
    All K comparators SHARE WEIGHTS from the hidden layer (weight tying).
    Class score = sum of neuron outputs for that class.
    
    Args:
        hidden_dims: List of hidden layer dimensions
        n_comparators_per_class: K neurons per class
        phi_offset_range: (min, max) for staggered phi_offsets
        n_classes: Number of output classes
    """
    K = n_comparators_per_class
    
    # Generate staggered phi_offsets for each comparator
    phi_min, phi_max = phi_offset_range
    phi_offsets = np.linspace(phi_min, phi_max, K)
    
    print(f"Comparator bank: {K} neurons per class (SHARED WEIGHTS)")
    print(f"Staggered phi_offsets: {phi_offsets}")
    
    sim_cfg = SimulationConfig(
        dt=dt,
        input_type="state",
        track_phi=False,
        track_power=False,
    )
    
    layers = []
    connections = []
    
    # Input layer
    layers.append(LayerConfig(
        layer_id=0,
        layer_type="Input",
        params={"dim": input_dim},
    ))
    
    # Hidden layers
    for i, hidden_dim in enumerate(hidden_dims):
        layer_id = i + 1
        
        layers.append(LayerConfig(
            layer_id=layer_id,
            layer_type="SingleDendrite",
            params={
                "dim": hidden_dim,
                "solver": "FE",
                "source_func": "Heaviside_fit_state_dep",
                "phi_offset": 0.02,
                "bias_current": 1.98,
                "gamma_plus": 0.0005,
                "gamma_minus": 1e-6,
                "learnable_params": {
                    "phi_offset": False,
                    "bias_current": False,
                    "gamma_plus": False,
                    "gamma_minus": False,
                },
            },
        ))
        
        connections.append(ConnectionConfig(
            from_layer=layer_id - 1,
            to_layer=layer_id,
            connection_type="all_to_all",
            learnable=True,
            params={"init": "xavier_uniform"},
        ))
    
    # Output layers: K Comparator layers with different phi_offsets
    # We'll track the connection indices for weight tying
    output_layer_ids = []
    comparator_connection_indices = []  # Track which connections need weight tying
    last_hidden_id = len(hidden_dims)
    
    for k, phi in enumerate(phi_offsets):
        layer_id = len(hidden_dims) + 1 + k
        output_layer_ids.append(layer_id)
        
        layers.append(LayerConfig(
            layer_id=layer_id,
            layer_type="SingleDendrite",
            params={
                "dim": n_classes,  # One neuron per class at this threshold
                "solver": "FE",
                "source_func": "Heaviside_fit_state_dep",
                "phi_offset": float(phi),  # Staggered threshold!
                "bias_current": 1.98,
                "gamma_plus": 0.0005,
                "gamma_minus": 1e-6,
                "learnable_params": {
                    "phi_offset": False,
                    "bias_current": False,
                    "gamma_plus": False,
                    "gamma_minus": False,
                },
            },
        ))
        
        # Connect from last hidden layer to this comparator layer
        conn_idx = len(connections)
        comparator_connection_indices.append(conn_idx)
        
        connections.append(ConnectionConfig(
            from_layer=last_hidden_id,
            to_layer=layer_id,
            connection_type="all_to_all",
            learnable=True,
            params={"init": "xavier_uniform"},
        ))
    
    model = SOENModelCore(
        sim_config=sim_cfg,
        layers_config=layers,
        connections_config=connections,
    )
    
    # Store metadata for later use
    model.comparator_config = {
        'n_comparators_per_class': K,
        'n_classes': n_classes,
        'phi_offsets': phi_offsets,
        'output_layer_ids': output_layer_ids,
        'comparator_connection_indices': comparator_connection_indices,
    }
    
    return model


def tie_comparator_weights(model):
    """
    Enforce weight sharing across all comparator layers.
    
    Copies weights from the first comparator's connection to all others.
    Call this after each optimizer step during training.
    
    Note: model.connections is a ParameterDict with string keys.
    """
    config = model.comparator_config
    conn_indices = config['comparator_connection_indices']
    
    if len(conn_indices) <= 1:
        return  # No tying needed for K=1
    
    # Get the first comparator's weights as the "master"
    # ParameterDict requires string keys
    first_conn_key = str(conn_indices[0])
    master_weights = model.connections[first_conn_key].weight.data
    
    # Copy to all other comparators
    for conn_idx in conn_indices[1:]:
        conn_key = str(conn_idx)
        model.connections[conn_key].weight.data.copy_(master_weights)


def count_params(model):
    """Count learnable parameters."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def count_effective_params(model):
    """
    Count effective parameters (accounting for weight sharing).
    
    With weight tying, all K comparator connections share the same weights,
    so we only count them once.
    """
    config = model.comparator_config
    conn_indices = config['comparator_connection_indices']
    
    # Start from the full learnable count, which includes every tied copy
    total = sum(p.numel() for p in model.parameters() if p.requires_grad)
    
    # Subtract the duplicated comparator weights
    if len(conn_indices) > 1:
        # ParameterDict requires string keys
        first_conn_key = str(conn_indices[0])
        first_conn = model.connections[first_conn_key]
        weight_size = first_conn.weight.numel()
        # We have K copies but only need 1, so subtract (K-1) * weight_size
        duplicates = (len(conn_indices) - 1) * weight_size
        total -= duplicates
    
    return total


# Test builder
print("Testing comparator bank model builder...")
test_model = build_comparator_bank_classifier([8], n_comparators_per_class=5)
print(f"Total parameters (before tying): {count_params(test_model)}")
print(f"Effective parameters (with tying): {count_effective_params(test_model)}")
print(f"Number of layers: {len(test_model.layers)}")
print(f"Layer dimensions: {[l.dim for l in test_model.layers]}")

## 4. Comparator Bank Forward Pass

We need a custom forward function that:
1. Runs the model to get all layer outputs
2. Sums outputs from all comparator layers per class
3. Returns class scores
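
The three steps can be tried in isolation on fake data. This sketch assumes, as the notebook does, that layer states can be indexed by layer id and have shape `[N, T, n_classes]`; the random tensors stand in for real model outputs, and `sum_comparator_outputs` is a hypothetical helper.

```python
import torch


def sum_comparator_outputs(layer_states, output_layer_ids):
    """Sum the final-timestep outputs of all comparator layers per class."""
    finals = [layer_states[i][:, -1, :] for i in output_layer_ids]  # K x [N, C]
    return torch.stack(finals, dim=0).sum(dim=0)                    # [N, C]


torch.manual_seed(0)
n, t, n_classes = 4, 6, 2
comparator_ids = [2, 3, 4]  # hypothetical ids of K=3 comparator layers

# Stand-in for the per-layer histories returned by the model
fake_states = {layer_id: torch.rand(n, t, n_classes) for layer_id in comparator_ids}

scores = sum_comparator_outputs(fake_states, comparator_ids)
preds = scores.argmax(dim=1)
print(scores.shape, preds.shape)
```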

In [None]:
def comparator_bank_forward(model, X):
    """
    Forward pass that sums comparator outputs per class.
    
    Returns:
        class_scores: [N, n_classes] - sum of comparator outputs per class
        all_outputs: List of [N, T, n_classes] for each comparator layer
    """
    config = model.comparator_config
    K = config['n_comparators_per_class']
    n_classes = config['n_classes']
    output_layer_ids = config['output_layer_ids']
    
    # Forward pass - get all layer states
    _, layer_states = model(X)
    
    # Collect outputs from all comparator layers
    # layer_states is a dict: {layer_id: [N, T, dim]}
    comparator_outputs = []
    for layer_id in output_layer_ids:
        # Get final timestep output: [N, n_classes]
        output = layer_states[layer_id][:, -1, :]
        comparator_outputs.append(output)
    
    # Stack and sum across comparators: [K, N, n_classes] -> [N, n_classes]
    stacked = torch.stack(comparator_outputs, dim=0)  # [K, N, n_classes]
    class_scores = stacked.sum(dim=0)  # [N, n_classes]
    
    return class_scores, comparator_outputs


# Test forward pass
test_model.eval()
with torch.no_grad():
    scores, outputs = comparator_bank_forward(test_model, X_seq[:5])
    print(f"Class scores shape: {scores.shape}")
    print(f"Number of comparator layers: {len(outputs)}")
    print(f"Sample class scores:\n{scores}")

## 5. Training with Comparator Bank

In [None]:
def train_comparator_bank(model, X_train, y_train, n_epochs=300, lr=0.02, verbose=False,
                          use_weight_tying=True):
    """
    Train comparator bank classifier.
    
    Uses CrossEntropyLoss on summed class scores.
    With weight tying, all comparators share the same weights.
    
    Args:
        use_weight_tying: If True, enforce shared weights across comparators after each step.
    """
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    
    # Initialize with tied weights
    if use_weight_tying:
        tie_comparator_weights(model)
    
    losses = []
    accuracies = []
    
    for epoch in range(n_epochs):
        optimizer.zero_grad()
        
        # Forward with comparator bank summing
        class_scores, _ = comparator_bank_forward(model, X_train)
        
        # Loss on summed scores
        loss = criterion(class_scores, y_train)
        
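        # Backward
        loss.backward()
        
        # Tied copies must also share gradients; otherwise only the first
        # comparator's gradient survives the copy in tie_comparator_weights.
        conn_keys = [str(i) for i in model.comparator_config['comparator_connection_indices']]
        if use_weight_tying and len(conn_keys) > 1:
            grads = [model.connections[k].weight.grad for k in conn_keys]
            if all(g is not None for g in grads):
                shared_grad = torch.stack(grads).sum(dim=0)
                for g in grads:
                    g.copy_(shared_grad)
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()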
        
        # CRITICAL: Tie weights after optimizer step
        if use_weight_tying:
            tie_comparator_weights(model)
        
        # Metrics
        with torch.no_grad():
            preds = class_scores.argmax(dim=1)
            acc = (preds == y_train).float().mean().item()
        
        losses.append(loss.item())
        accuracies.append(acc)
        
        if verbose and (epoch + 1) % 50 == 0:
            print(f"  Epoch {epoch+1}: Loss={loss.item():.4f}, Acc={acc:.4f}")
    
    return losses, accuracies


def evaluate_comparator_bank(model, X_test, y_test):
    """
    Evaluate comparator bank classifier.
    """
    model.eval()
    with torch.no_grad():
        class_scores, comparator_outputs = comparator_bank_forward(model, X_test)
        preds = class_scores.argmax(dim=1).numpy()
        probs = torch.softmax(class_scores, dim=1)[:, 1].numpy()
    
    y_true = y_test.numpy()
    accuracy = (preds == y_true).mean()
    
    return preds, probs, accuracy, class_scores.numpy(), comparator_outputs

## 6. Experiment: Vary Number of Comparators (WITH WEIGHT TYING)

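Test K = 1, 3, 5, 7, 10 comparators per class. For K = 1, `np.linspace` returns only the lower bound of the range, so the single comparator sits at the minimum phi_offset.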

**Key change**: All comparators now share the same weights!
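
For comparison, weight sharing can also be expressed in plain PyTorch by letting every comparator read one `nn.Parameter`, so autograd sums the gradient contributions of all K comparators into a single tensor. The sketch below is not the `soen_toolkit` implementation: `SharedWeightBank` is a hypothetical module, and a steep sigmoid is assumed as a stand-in for the SingleDendrite response.

```python
import torch
import torch.nn as nn


class SharedWeightBank(nn.Module):
    """K soft comparators reading one shared linear projection."""

    def __init__(self, hidden_dim, n_classes, thresholds, sharpness=20.0):
        super().__init__()
        # One weight matrix, used by every comparator
        self.weight = nn.Parameter(torch.empty(n_classes, hidden_dim))
        nn.init.xavier_uniform_(self.weight)
        # Fixed (non-learnable) staggered thresholds, shape [K]
        self.register_buffer(
            "thresholds", torch.as_tensor(thresholds, dtype=torch.float32)
        )
        self.sharpness = sharpness

    def forward(self, hidden):
        evidence = hidden @ self.weight.t()            # [N, n_classes]
        # [N, n_classes, 1] - [K] broadcasts to [N, n_classes, K]
        margin = evidence.unsqueeze(-1) - self.thresholds
        fired = torch.sigmoid(self.sharpness * margin)  # soft comparator (assumption)
        return fired.sum(dim=-1)                        # class scores [N, n_classes]


torch.manual_seed(0)
bank = SharedWeightBank(hidden_dim=8, n_classes=2, thresholds=[0.10, 0.20, 0.30])
hidden = torch.rand(16, 8)
labels = torch.randint(0, 2, (16,))

loss = nn.functional.cross_entropy(bank(hidden), labels)
loss.backward()

n_trainable = sum(p.numel() for p in bank.parameters() if p.requires_grad)
print(n_trainable)  # 2 * 8 weights, independent of the number of thresholds
```

In this formulation the trainable parameter count does not depend on K, and no copy step is needed after the optimizer update.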

In [None]:
K_VALUES = [1, 3, 5, 7, 10]
HIDDEN_DIM = [8]
N_EPOCHS = 400
LR = 0.02
PHI_RANGE = (0.10, 0.35)  # Range of phi_offsets

results_by_k = {}

print("Training comparator bank models with SHARED WEIGHTS...")
print(f"phi_offset range: {PHI_RANGE}")
print("=" * 60)

for K in K_VALUES:
    print(f"\nK = {K} comparators per class:")
    
    # Reset seed for fair comparison
    torch.manual_seed(42)
    
    model = build_comparator_bank_classifier(
        HIDDEN_DIM, 
        n_comparators_per_class=K,
        phi_offset_range=PHI_RANGE
    )
    n_params = count_params(model)
    n_effective = count_effective_params(model)
    
    # Train WITH weight tying
    losses, accs = train_comparator_bank(model, X_seq, y_labels, n_epochs=N_EPOCHS, lr=LR,
                                         use_weight_tying=True)
    # Note: accuracy is measured on the training set (this notebook has no held-out split)
    _, _, final_acc, _, _ = evaluate_comparator_bank(model, X_seq, y_labels)
    
    results_by_k[K] = {
        'model': model,
        'losses': losses,
        'accuracies': accs,
        'final_acc': final_acc,
        'n_params': n_params,
        'n_effective_params': n_effective,
        'phi_offsets': model.comparator_config['phi_offsets'],
    }
    
    print(f"  Final Accuracy: {final_acc:.4f}")
    print(f"  Params (total/effective): {n_params}/{n_effective}")

print("\n" + "=" * 60)
print("Training complete!")

## 7. Training Curves

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

colors = plt.cm.viridis(np.linspace(0, 1, len(K_VALUES)))

# Loss curves
ax1 = axes[0]
for (K, res), color in zip(results_by_k.items(), colors):
    ax1.plot(res['losses'], label=f'K={K}', color=color, lw=1.5)
ax1.set_xlabel('Epoch')
ax1.set_ylabel('CrossEntropy Loss')
ax1.set_title('Training Loss by Number of Comparators')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Accuracy curves
ax2 = axes[1]
for (K, res), color in zip(results_by_k.items(), colors):
    ax2.plot(res['accuracies'], label=f'K={K}', color=color, lw=1.5)
ax2.axhline(y=0.5, color='gray', linestyle='--', alpha=0.5)
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy')
ax2.set_title('Training Accuracy by Number of Comparators')
ax2.legend()
ax2.grid(True, alpha=0.3)
ax2.set_ylim(0.4, 1.05)

plt.tight_layout()
plt.show()

## 8. Accuracy vs Number of Comparators

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Accuracy vs K
ax1 = axes[0]
K_list = list(results_by_k.keys())
accs = [results_by_k[k]['final_acc'] for k in K_list]
params = [results_by_k[k]['n_params'] for k in K_list]

ax1.plot(K_list, accs, 'o-', markersize=10, lw=2, color='steelblue')
ax1.axhline(y=0.5, color='gray', linestyle='--', alpha=0.5, label='Random')
ax1.set_xlabel('K (Comparators per Class)')
ax1.set_ylabel('Final Accuracy')
ax1.set_title('Accuracy vs Number of Comparators')
ax1.grid(True, alpha=0.3)
ax1.set_ylim(0.4, 1.05)

for k, acc in zip(K_list, accs):
    ax1.annotate(f'{acc:.3f}', (k, acc), textcoords="offset points", 
                 xytext=(0, 10), ha='center')

# Parameters vs K
ax2 = axes[1]
ax2.bar(K_list, params, color='coral', alpha=0.7)
ax2.set_xlabel('K (Comparators per Class)')
ax2.set_ylabel('Total Parameters (before tying)')
ax2.set_title('Model Size vs Number of Comparators')
ax2.grid(True, alpha=0.3, axis='y')

for k, p in zip(K_list, params):
    ax2.annotate(f'{p}', (k, p), textcoords="offset points", 
                 xytext=(0, 5), ha='center')

plt.tight_layout()
plt.show()

## 9. Visualize Comparator Responses

For the best model, visualize how different comparators (different phi_offsets) respond.

In [None]:
# Select best K
best_K = max(results_by_k.keys(), key=lambda k: results_by_k[k]['final_acc'])
best_model = results_by_k[best_K]['model']
phi_offsets = results_by_k[best_K]['phi_offsets']

print(f"Best K = {best_K}, Accuracy = {results_by_k[best_K]['final_acc']:.4f}")
print(f"Phi offsets: {phi_offsets}")

# Get comparator outputs for all samples
best_model.eval()
with torch.no_grad():
    class_scores, comparator_outputs = comparator_bank_forward(best_model, X_seq)

# Visualize: For each comparator, show which samples activate each class neuron
# squeeze=False keeps axes 2D so axes[c, k] also works when best_K == 1
fig, axes = plt.subplots(2, best_K, figsize=(3*best_K, 6), squeeze=False)

for k in range(best_K):
    output = comparator_outputs[k].numpy()  # [N, 2]
    
    for c in range(2):
        ax = axes[c, k]
        
        # Color by true class
        class0_mask = y_labels.numpy() == 0
        class1_mask = y_labels.numpy() == 1
        
        ax.scatter(X_data[class0_mask, 0], X_data[class0_mask, 1], 
                   c=output[class0_mask, c], cmap='Blues', s=10, alpha=0.7,
                   vmin=0, vmax=output[:, c].max())
        ax.scatter(X_data[class1_mask, 0], X_data[class1_mask, 1], 
                   c=output[class1_mask, c], cmap='Reds', s=10, alpha=0.7,
                   vmin=0, vmax=output[:, c].max())
        
        class_name = "Class 0" if c == 0 else "Class 1"
        ax.set_title(f'φ={phi_offsets[k]:.2f}\n{class_name} neuron', fontsize=9)
        ax.set_aspect('equal')
        ax.set_xticks([])
        ax.set_yticks([])

axes[0, 0].set_ylabel('Class 0 Neuron Response')
axes[1, 0].set_ylabel('Class 1 Neuron Response')

plt.suptitle(f'Comparator Bank Responses (K={best_K})\nBrighter = Higher Activation', fontsize=12)
plt.tight_layout()
plt.show()

## 10. Decision Boundary

In [None]:
def plot_comparator_decision_boundary(model, X_data, y_data, ax, title, resolution=100):
    """Plot decision boundary for comparator bank classifier."""
    x_min, x_max = X_data[:, 0].min() - 0.02, X_data[:, 0].max() + 0.02
    y_min, y_max = X_data[:, 1].min() - 0.02, X_data[:, 1].max() + 0.02
    
    xx, yy = np.meshgrid(
        np.linspace(x_min, x_max, resolution),
        np.linspace(y_min, y_max, resolution)
    )
    
    grid_points = np.c_[xx.ravel(), yy.ravel()]
    grid_tensor = torch.FloatTensor(grid_points)
    grid_seq = grid_tensor.unsqueeze(1).expand(-1, SEQ_LEN, -1).clone()
    
    model.eval()
    with torch.no_grad():
        class_scores, _ = comparator_bank_forward(model, grid_seq)
        probs = torch.softmax(class_scores, dim=1)[:, 1].numpy()
    
    Z = probs.reshape(xx.shape)
    
    ax.contourf(xx, yy, Z, levels=50, cmap='RdBu', alpha=0.7)
    ax.contour(xx, yy, Z, levels=[0.5], colors='black', linewidths=2)
    
    for c, color in enumerate(['blue', 'red']):
        mask = y_data == c
        ax.scatter(X_data[mask, 0], X_data[mask, 1], c=color, 
                   s=15, alpha=0.6, edgecolors='white', linewidths=0.3)
    
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_title(title)
    ax.set_aspect('equal')


# Plot decision boundaries for different K values
fig, axes = plt.subplots(1, len(K_VALUES), figsize=(4*len(K_VALUES), 4))

X_np = X_data.numpy()
y_np = y_labels.numpy()

for idx, K in enumerate(K_VALUES):
    model = results_by_k[K]['model']
    acc = results_by_k[K]['final_acc']
    
    plot_comparator_decision_boundary(
        model, X_np, y_np, axes[idx],
        f'K={K}\nAcc={acc:.3f}'
    )

plt.suptitle('Decision Boundaries: Comparator Bank Readout', fontsize=12, y=1.02)
plt.tight_layout()
plt.show()

## 11. Compare with Two-Neuron Approach

Build a simple two-neuron model (K=1 with single phi_offset) for comparison.
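
When reading the comparison, note that for two classes the baseline's loss (binary cross-entropy on the difference s1 - s0) is mathematically the same as cross-entropy on the pair of scores [s0, s1], so differences between the baseline and the K=1 bank come from the model configuration rather than the loss. A small standard-library check of that identity, with made-up scores:

```python
import math


def cross_entropy_two_class(s0, s1, label):
    """Softmax cross-entropy on the score pair [s0, s1]."""
    log_norm = math.log(math.exp(s0) + math.exp(s1))
    return log_norm - (s1 if label == 1 else s0)


def bce_on_difference(s0, s1, label):
    """Binary cross-entropy with the logit taken as s1 - s0."""
    logit = s1 - s0
    p = 1.0 / (1.0 + math.exp(-logit))
    return -math.log(p) if label == 1 else -math.log(1.0 - p)


# (s0, s1, label) triples; the scores are arbitrary illustrative values
examples = [(0.3, 1.2, 1), (0.3, 1.2, 0), (2.0, 0.5, 1)]

gaps = [
    abs(cross_entropy_two_class(*e) - bce_on_difference(*e)) for e in examples
]
print(max(gaps))
```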

In [None]:
def build_two_neuron_classifier(hidden_dims, phi_offset=0.23, input_dim=2, dt=50.0):
    """Build simple two-neuron SingleDendrite readout for comparison."""
    sim_cfg = SimulationConfig(
        dt=dt,
        input_type="state",
        track_phi=False,
        track_power=False,
    )
    
    layers = []
    connections = []
    
    layers.append(LayerConfig(
        layer_id=0,
        layer_type="Input",
        params={"dim": input_dim},
    ))
    
    for i, hidden_dim in enumerate(hidden_dims):
        layer_id = i + 1
        layers.append(LayerConfig(
            layer_id=layer_id,
            layer_type="SingleDendrite",
            params={
                "dim": hidden_dim,
                "solver": "FE",
                "source_func": "Heaviside_fit_state_dep",
                "phi_offset": 0.02,
                "bias_current": 1.98,
                "gamma_plus": 0.0005,
                "gamma_minus": 1e-6,
            },
        ))
        connections.append(ConnectionConfig(
            from_layer=layer_id - 1,
            to_layer=layer_id,
            connection_type="all_to_all",
            learnable=True,
            params={"init": "xavier_uniform"},
        ))
    
    output_layer_id = len(hidden_dims) + 1
    layers.append(LayerConfig(
        layer_id=output_layer_id,
        layer_type="SingleDendrite",
        params={
            "dim": 2,
            "solver": "FE",
            "source_func": "Heaviside_fit_state_dep",
            "phi_offset": phi_offset,
            "bias_current": 1.98,
            "gamma_plus": 0.0005,
            "gamma_minus": 1e-6,
        },
    ))
    connections.append(ConnectionConfig(
        from_layer=output_layer_id - 1,
        to_layer=output_layer_id,
        connection_type="all_to_all",
        learnable=True,
        params={"init": "xavier_uniform"},
    ))
    
    return SOENModelCore(
        sim_config=sim_cfg,
        layers_config=layers,
        connections_config=connections,
    )


def train_two_neuron(model, X_train, y_train, n_epochs=300, lr=0.02):
    """Train two-neuron classifier with BCE on difference."""
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.BCEWithLogitsLoss()
    
    y_target = y_train.float().unsqueeze(1)
    losses, accuracies = [], []
    
    for epoch in range(n_epochs):
        optimizer.zero_grad()
        final_hist, _ = model(X_train)
        output = final_hist[:, -1, :]
        logits = (output[:, 1] - output[:, 0]).unsqueeze(1)
        
        loss = criterion(logits, y_target)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        
        with torch.no_grad():
            preds = (torch.sigmoid(logits) > 0.5).float()
            acc = (preds == y_target).float().mean().item()
        
        losses.append(loss.item())
        accuracies.append(acc)
    
    return losses, accuracies


# Train two-neuron baseline
print("Training two-neuron baseline (BCE on s1-s0)...")
torch.manual_seed(42)
two_neuron_model = build_two_neuron_classifier(HIDDEN_DIM, phi_offset=0.23)
two_neuron_losses, two_neuron_accs = train_two_neuron(two_neuron_model, X_seq, y_labels, n_epochs=N_EPOCHS, lr=LR)
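# Evaluate in eval mode after training, matching evaluate_comparator_bank
two_neuron_model.eval()
with torch.no_grad():
    final_hist, _ = two_neuron_model(X_seq)
    final_out = final_hist[:, -1, :]
    two_neuron_preds = (final_out[:, 1] > final_out[:, 0]).long()
    two_neuron_final_acc = (two_neuron_preds == y_labels).float().mean().item()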
print(f"Two-neuron final accuracy: {two_neuron_final_acc:.4f}")

## 12. Summary Comparison

In [None]:
import pandas as pd

# Build comparison table
comparison_data = []

# Two-neuron baseline
comparison_data.append({
    'Method': 'Two-Neuron (BCE)',
    'K': 1,
    'Params (Total)': count_params(two_neuron_model),
    'Params (Effective)': count_params(two_neuron_model),
    'Final Acc': f'{two_neuron_final_acc:.4f}',
})

# Comparator bank results
for K, res in results_by_k.items():
    comparison_data.append({
        'Method': 'Comparator Bank (Shared W)',
        'K': K,
        'Params (Total)': res['n_params'],
        'Params (Effective)': res['n_effective_params'],
        'Final Acc': f"{res['final_acc']:.4f}",
    })

df = pd.DataFrame(comparison_data)

print("=" * 80)
print("COMPARISON: TWO-NEURON vs COMPARATOR BANK (SHARED WEIGHTS)")
print("=" * 80)
print(f"\nHidden architecture: {HIDDEN_DIM}")
print(f"Comparator phi_offset range: {PHI_RANGE}")
print(f"Training epochs: {N_EPOCHS}")
print("Weight tying: ENABLED (all comparators share same weights)")
print()
print(df.to_string(index=False))
print("=" * 80)

In [None]:
# Final visualization
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Bar chart
ax1 = axes[0]
methods = ['2-Neuron'] + [f'CB K={k}' for k in K_VALUES]
accs = [two_neuron_final_acc] + [results_by_k[k]['final_acc'] for k in K_VALUES]
colors = ['steelblue'] + ['coral'] * len(K_VALUES)

bars = ax1.bar(methods, accs, color=colors, alpha=0.7)
ax1.axhline(y=0.5, color='gray', linestyle='--', alpha=0.5)
ax1.set_ylabel('Accuracy')
ax1.set_title('Final Accuracy: Two-Neuron vs Comparator Bank')
ax1.set_ylim(0.4, 1.05)
ax1.tick_params(axis='x', rotation=30)

for bar, acc in zip(bars, accs):
    ax1.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
             f'{acc:.3f}', ha='center', va='bottom', fontsize=9)

# Training curves comparison
ax2 = axes[1]
ax2.plot(two_neuron_accs, label='Two-Neuron (BCE)', color='steelblue', lw=2)
best_K = max(results_by_k.keys(), key=lambda k: results_by_k[k]['final_acc'])
ax2.plot(results_by_k[best_K]['accuracies'], label=f'Comparator Bank K={best_K}', color='coral', lw=2)
ax2.axhline(y=0.5, color='gray', linestyle='--', alpha=0.5)
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy')
ax2.set_title('Training Curves: Best Methods')
ax2.legend()
ax2.grid(True, alpha=0.3)
ax2.set_ylim(0.4, 1.05)

plt.tight_layout()
plt.show()

## 13. Conclusions

In [None]:
print("=" * 70)
print("CONCLUSIONS")
print("=" * 70)

best_K = max(results_by_k.keys(), key=lambda k: results_by_k[k]['final_acc'])
best_acc = results_by_k[best_K]['final_acc']

print(f"\n1. BEST COMPARATOR BANK: K={best_K} with accuracy {best_acc:.4f}")
print(f"   Two-neuron baseline: {two_neuron_final_acc:.4f}")
print(f"   Improvement: {best_acc - two_neuron_final_acc:+.4f}")

print("\n2. EFFECT OF K (Comparators per Class) - WITH SHARED WEIGHTS:")
for K in K_VALUES:
    acc = results_by_k[K]['final_acc']
    n_eff = results_by_k[K]['n_effective_params']
    print(f"   K={K}: Acc={acc:.4f}, Effective Params={n_eff}")

print("\n3. WEIGHT SHARING BENEFITS:")
print("   - All comparators see the SAME evidence")
print("   - Different thresholds create true staircase approximation")
print("   - Effective parameters stay constant regardless of K")
print("   - More robust than independent weights")

print("\n4. HARDWARE IMPLICATIONS:")
print("   - Shared weights = single fan-out from hidden layer")
print("   - Different phi_offsets = different physical thresholds")
print("   - Spike counting gives graded output")
print("   - No need for precise analog readout")

print("\n5. RECOMMENDATION:")
if best_acc > two_neuron_final_acc:
    print(f"   ✓ Comparator bank (K={best_K}) outperforms two-neuron approach!")
    print("   ✓ Shared weights provide true staircase approximation")
else:
    print("   Two-neuron approach with BCE loss remains competitive")
    print("   Consider comparator bank for robustness to threshold calibration")

print("\n" + "=" * 70)