# SpikeCore Lab — PyTorch → TVM BYOC → Neuromorphic Target

This notebook demonstrates the full compilation pipeline from a trained PyTorch model
to a fictional neuromorphic hardware target ("SpikeCore") using TVM's BYOC framework.

**Pipeline stages:**
1. Train a 2-layer MLP on MNIST
2. Export to TVM Relay IR
3. Quantize float32 → int8
4. Register SpikeCore as BYOC target
5. Partition graph (host vs. SpikeCore subgraphs)
6. Code generation → SpikeCore assembly
7. Simulate on SpikeCore hardware model
8. Compare PyTorch vs. SpikeCore outputs

**SpikeCore hardware model:** 128 neuron cores, int8 weights, int32 accumulators,
3 primitive ops (ACC, FIRE, LEAK), event-driven execution.

## Cell 1 — Setup & Imports

In [None]:
import sys
import numpy as np
import matplotlib.pyplot as plt

# Check for TVM availability
# HAS_TVM gates the TVM-specific branches in the export, registration and
# partitioning cells; when TVM is missing those cells print reference output.
try:
    import tvm
    from tvm import relay
    HAS_TVM = True
    print(f"TVM version: {tvm.__version__}")
except ImportError:
    HAS_TVM = False
    print("TVM not available — using standalone compilation path")

import torch
import torch.nn as nn
import torch.nn.functional as F
print(f"PyTorch version: {torch.__version__}")

# SpikeCore modules
# Project-local package: simulator, assembler, codegen, quantization helpers
# and plotting utilities used by the cells below.
from spikecore.hardware_model import SpikeCoreCPU
from spikecore.assembly import assemble, disassemble, Opcode
from spikecore.byoc_codegen import compile_nn_to_spikecore
from spikecore.quantize import manual_quantize_weights, manual_quantize_activations
from spikecore.visualize import (
    plot_spike_raster, plot_compilation_graph,
    plot_weight_distribution, plot_comparison
)

print("All imports successful.")
# A simulator is instantiated with default arguments only to report its class
# name; the core count in the message is a literal, not read from the instance.
print(f"SpikeCore simulator: {SpikeCoreCPU().__class__.__name__} with 128 cores")

## Cell 2 — PyTorch Model: 2-Layer MLP for MNIST

A minimal MLP: 784 → 64 (ReLU) → 10 logits (softmax is applied only inside the cross-entropy loss). Trained for 2 epochs on MNIST.
This is intentionally simple — the focus is the compilation pipeline, not the model.

In [None]:
class MNISTNet(nn.Module):
    """Two-layer perceptron for MNIST digit classification.

    Maps 784 flattened pixel values to 64 hidden units (ReLU) and then to
    10 output logits. No softmax is applied inside the module.
    """

    def __init__(self):
        super().__init__()
        hidden = nn.Linear(784, 64)
        readout = nn.Linear(64, 10)
        # The Linear layers sit at positions 0 and 2 of the Sequential, so
        # their state-dict keys are `layers.0.*` and `layers.2.*`.
        self.layers = nn.Sequential(hidden, nn.ReLU(), readout)

    def forward(self, x):
        """Flatten `x` into rows of 784 values and return 10 logits per row."""
        flat = x.view(-1, 784)
        return self.layers(flat)

model = MNISTNet()
print(model)
n_params = sum(p.numel() for p in model.parameters())
print(f"Parameters: {n_params:,}")

# Train on MNIST
from torchvision import datasets, transforms

# Convert images to tensors, then standardise with the fixed mean/std below.
to_tensor = transforms.ToTensor()
standardise = transforms.Normalize((0.1307,), (0.3081,))
transform = transforms.Compose([to_tensor, standardise])

train_data = datasets.MNIST('./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=256, shuffle=True)
test_data = datasets.MNIST('./data', train=False, transform=transform)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=1000)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

model.train()
for epoch in range(2):
    batch_losses = []
    for batch_x, batch_y in train_loader:
        optimizer.zero_grad()
        logits = model(batch_x)
        loss = F.cross_entropy(logits, batch_y)
        loss.backward()
        optimizer.step()
        batch_losses.append(loss.item())
    # Average of the per-batch losses over the number of batches in the epoch.
    total_loss = sum(batch_losses)
    print(f"Epoch {epoch+1}: avg loss = {total_loss / len(train_loader):.4f}")

# Evaluate
model.eval()
with torch.no_grad():
    correct = sum(
        (model(batch_x).argmax(dim=1) == batch_y).sum().item()
        for batch_x, batch_y in test_loader
    )
print(f"Test accuracy: {correct / len(test_data):.2%}")

## Cell 3 — Export to Relay IR

If TVM is available, we use `relay.frontend.from_pytorch()` to convert the model.
Otherwise, we extract weights directly from the PyTorch state dict.

The Relay IR is TVM's high-level graph representation — analogous to ONNX or
torch.fx but designed for compiler transformations.

In [None]:
# Pull every tensor of the state dict out as a numpy array
state = {}
for key, tensor in model.state_dict().items():
    state[key] = tensor.detach().numpy()

print("Extracted weights:")
for name, arr in state.items():
    print(f"  {name}: {arr.shape} ({arr.dtype})")

if not HAS_TVM:
    # No TVM installed: show a hand-written Relay listing for reference only.
    print("\n--- Equivalent Relay IR (for reference) ---")
    print("""
fn (%input0: Tensor[(1, 1, 28, 28), float32]) {
  %0 = nn.batch_flatten(%input0);
  %1 = nn.dense(%0, meta[relay.Constant][0] /* (64, 784) */);
  %2 = nn.bias_add(%1, meta[relay.Constant][1] /* (64,) */);
  %3 = nn.relu(%2);
  %4 = nn.dense(%3, meta[relay.Constant][2] /* (10, 64) */);
  %5 = nn.bias_add(%4, meta[relay.Constant][3] /* (10,) */);
  %5
}
""")
else:
    # TVM path: trace the model with a random input, then convert to Relay
    input_shape = (1, 1, 28, 28)
    dummy_input = torch.randn(*input_shape)
    traced = torch.jit.trace(model, dummy_input)
    input_info = [("input0", input_shape)]
    relay_mod, relay_params = relay.frontend.from_pytorch(traced, input_info)
    print("\n--- Relay IR ---")
    print(relay_mod["main"])

## Cell 4 — Quantize: float32 → int8

SpikeCore is integer-only: **int8 weights**, **int32 accumulators**.
We quantize the trained weights using symmetric quantization:

$$q = \text{round}\left(\frac{w}{\text{scale}}\right), \quad \text{scale} = \frac{\max|w|}{127}$$

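This per-tensor max-abs scaling corresponds to the `weight_scale="max"` option of TVM's `relay.quantize.qconfig` (the `global_scale` calibration mode, by contrast, uses one fixed scale for activations).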

In [None]:
# Float32 parameters of the two Linear layers, taken from the state dict
w1_fp, b1_fp = state['layers.0.weight'], state['layers.0.bias']  # (64, 784), (64,)
w2_fp, b2_fp = state['layers.2.weight'], state['layers.2.bias']  # (10, 64), (10,)

# Each call returns (quantized array, scale, third value that is ignored here)
w1_int8, w1_scale, _ = manual_quantize_weights(w1_fp)
w2_int8, w2_scale, _ = manual_quantize_weights(w2_fp)
b1_int8, b1_scale, _ = manual_quantize_weights(b1_fp)
b2_int8, b2_scale, _ = manual_quantize_weights(b2_fp)

print("Quantization results:")
weight_report = [(w1_fp, w1_int8, w1_scale), (w2_fp, w2_int8, w2_scale)]
for layer_no, (w_fp, w_q, w_s) in enumerate(weight_report, start=1):
    print(f"  Layer {layer_no} weights: {w_fp.shape} fp32 → int8 (scale={w_s:.6f})")
    print(f"    Range: [{w_q.min()}, {w_q.max()}]")

# Verify quantization error: dequantize layer 1 and compare with the original
w1_reconstructed = w1_int8.astype(np.float32) * w1_scale
abs_error = np.abs(w1_fp - w1_reconstructed)
quant_error = np.mean(abs_error)
mean_abs_w1 = np.mean(np.abs(w1_fp))
print(f"\n  Layer 1 mean quantization error: {quant_error:.6f}")
print(f"  Layer 1 mean weight magnitude:   {mean_abs_w1:.6f}")
print(f"  Relative error: {quant_error / mean_abs_w1:.2%}")

## Cell 5 — Register SpikeCore BYOC Target

In TVM's BYOC (Bring Your Own Codegen) framework, we register:
1. **Pattern table** — which Relay subgraphs map to SpikeCore ops
2. **Codegen callback** — how to emit SpikeCore assembly from matched patterns

Pattern matching rules:
- `nn.dense + bias_add + relu` → `ACC + FIRE` (hidden layer with activation)
- `nn.dense + bias_add` → `ACC` (output layer, no activation)

This mirrors how a hardware vendor would register a real accelerator (for example, a neuromorphic chip such as Intel's Loihi) as a BYOC target in TVM.

In [None]:
if not HAS_TVM:
    # Without TVM, print a reference sketch of the registration instead.
    reference_lines = (
        "BYOC target registration (reference — requires TVM):",
        "",
        "  @tvm.register_func('relay.ext.spikecore')",
        "  def spikecore_compiler(func):",
        "      # Traverse Relay subgraph → emit SpikeCore assembly",
        "      return codegen(func)",
        "",
        "  Pattern table:",
        "    spikecore.dense_relu  — nn.dense + bias_add + relu  → ACC + FIRE",
        "    spikecore.dense_bias  — nn.dense + bias_add         → ACC",
        "    spikecore.qnn_dense_clip — qnn.dense + add + clip   → ACC + FIRE (quantized)",
    )
    for text in reference_lines:
        print(text)
else:
    from spikecore.byoc_codegen import register_spikecore_target, partition_for_spikecore
    from spikecore.relay_patterns import spikecore_pattern_table

    # Register the target
    register_spikecore_target()
    print("Registered 'spikecore' BYOC target in TVM")

    # Show pattern table: one line per (name, pattern) entry
    patterns = spikecore_pattern_table()
    print(f"\nPattern table ({len(patterns)} patterns):")
    for name, pattern in patterns:
        print(f"  {name}")

## Cell 6 — Partition Graph

TVM's partitioning pipeline splits the Relay graph into:
- **SpikeCore subgraphs** — offloaded to the neuromorphic accelerator
- **Host subgraphs** — remain on the CPU

Steps: `MergeComposite` → `AnnotateTarget` → `PartitionGraph`

In [None]:
if not HAS_TVM:
    # Without TVM, print a hand-written partitioned listing for reference.
    print("--- Partitioned Graph (reference) ---")
    print("""
fn (%input0: Tensor[(1, 1, 28, 28), float32]) {
  %0 = nn.batch_flatten(%input0);              // [host]
  %1 = @spikecore_dense_relu_0(%0);            // [spikecore] → ACC + FIRE
  %2 = @spikecore_dense_bias_1(%1);            // [spikecore] → ACC
  %2
}
""")
else:
    partitioned_mod = partition_for_spikecore(relay_mod, relay_params)
    print("--- Partitioned Relay IR ---")
    print(partitioned_mod["main"])

# Visualize the partitioned graph: one (name, target, shape) row per layer
layer_plan = [
    ("flatten", "host", (784, 784)),
    ("dense_64\n+relu", "spikecore", (784, 64)),
    ("dense_10", "spikecore", (64, 10)),
]
layer_names = [row[0] for row in layer_plan]
layer_targets = [row[1] for row in layer_plan]
layer_shapes = [row[2] for row in layer_plan]

fig = plot_compilation_graph(layer_names, layer_targets, layer_shapes)
plt.show()

## Cell 7 — Code Generation: Relay → SpikeCore Assembly

The BYOC codegen callback traverses each partitioned subgraph and emits
SpikeCore assembly instructions. Each output neuron becomes a core with:
- `ACC` — weighted accumulate from input spikes
- `FIRE` — threshold comparison + spike emission (for hidden layers)
- `LEAK` — membrane potential decay

In [None]:
# Compile using the standalone path (works with or without TVM)
weights, biases = [w1_fp, w2_fp], [b1_fp, b2_fp]
activations = ["relu", None]

program, q_weights, q_biases = compile_nn_to_spikecore(weights, biases, activations)

print(f"Compiled program: {len(program)} instructions")
# One summary line per opcode; labels are pre-padded so the counts line up.
opcode_rows = (
    ("ACC: ", Opcode.ACC),
    ("FIRE:", Opcode.FIRE),
    ("LEAK:", Opcode.LEAK),
    ("HALT:", Opcode.HALT),
)
for label, opcode in opcode_rows:
    n_ops = sum(1 for instr in program if instr.opcode == opcode)
    print(f"  {label} {n_ops}")
print()

# Print the first 20 lines of the disassembly and summarise the remainder
listing = disassemble(program)
lines = listing.split('\n')
print("--- SpikeCore Assembly (first 20 instructions) ---")
shown, hidden_count = lines[:20], len(lines) - 20
for line in shown:
    print(line)
if hidden_count > 0:
    print(f"  ... ({hidden_count} more instructions)")

## Cell 8 — Simulate on SpikeCore Hardware Model

The SpikeCore simulator executes the assembly on the virtual hardware:
- 128 neuron cores with local int8 weight memory
- Event-driven: spikes propagate between layers
- Membrane potential accumulates across timesteps

For this short rate-coded inference run (8 timesteps in the cell below), the
output cores' activations are read out and used as the output scores.

In [None]:
# Pick a test sample and flatten it to a 1-D vector
test_sample, test_label = test_data[0]
test_flat = test_sample.view(-1).numpy()  # (784,)

# Initialize SpikeCore CPU
cpu = SpikeCoreCPU(num_cores=128)

# Load quantized weights; each q_weights entry is a (weights, scale) pair
for layer_id, (w_int8, _w_scale) in enumerate(q_weights):
    cpu.load_weights(layer_id, w_int8)
    print(f"Loaded layer {layer_id}: {w_int8.shape} int8 weights onto cores")

# Quantize input symmetrically against its largest absolute value
input_abs_max = np.max(np.abs(test_flat))
input_scale = input_abs_max / 127
scaled_input = np.round(test_flat / input_abs_max * 127)
input_int8 = np.clip(scaled_input, -128, 127).astype(np.int8)
print(f"Input quantized: {test_flat.shape} fp32 → int8 (scale={input_scale:.6f})")

# Run simulation
n_timesteps = 8
cpu.load_program(program)
spike_counts = cpu.run(input_int8, timesteps=n_timesteps)
spike_log = cpu.get_spike_log()

# The second field of each spike-log entry is used as the core identifier
active_cores = {event[1] for event in spike_log}
print("\nSimulation complete:")
print(f"  Timesteps: {n_timesteps}")
print(f"  Total spikes: {len(spike_log)}")
print(f"  Active cores: {len(active_cores)}")

# Read output layer activations (cores 64-73 = layer 2)
output_cores = list(range(64, 74))
sc_outputs = cpu.get_output_activations(output_cores)
sc_pred = np.argmax(sc_outputs)
print(f"\nSpikeCore output scores: {sc_outputs}")
print(f"SpikeCore prediction: {sc_pred}")
print(f"True label: {test_label}")

## Cell 9 — Compare: PyTorch vs. SpikeCore

Run the same test samples through both the original PyTorch model and the
SpikeCore simulator, comparing top-1 predictions. Quantization introduces
some error, but the predictions should match on the vast majority of samples.

In [None]:
def spikecore_quantized_forward(q_weights, x_float):
    """Quantized forward pass matching SpikeCore's integer arithmetic.

    Per layer: quantize input → int matmul (ACC) → dequantize → activation
    (FIRE). ReLU is applied after every layer except the last. Bias terms
    are not applied; only the weight matrices are used.

    Args:
        q_weights: Sequence of ``(w_int8, w_scale)`` pairs, one per layer.
            ``w_int8`` is indexed as ``(out_features, in_features)``.
        x_float: 1-D float input vector for the first layer.

    Returns:
        float32 array holding the last layer's outputs. Its length is always
        the last layer's ``out_features``, including when an intermediate
        activation vector is (numerically) all zeros.
    """
    x = np.asarray(x_float)
    last_idx = len(q_weights) - 1
    for layer_idx, (w_int8, w_scale) in enumerate(q_weights):
        x_abs_max = np.max(np.abs(x))
        if x_abs_max < 1e-10:
            # With no bias, a zero input stays zero through every remaining
            # layer, so the result is a zero vector of the FINAL layer's
            # width (not the current layer's).
            return np.zeros(q_weights[-1][0].shape[0], dtype=np.float32)
        # Symmetric per-vector input quantization: max|x| maps to 127.
        x_scale = x_abs_max / 127.0
        x_q = np.clip(np.round(x / x_scale), -128, 127).astype(np.int32)
        # Integer accumulate in int32, then rescale back to real units.
        acc = x_q @ w_int8.astype(np.int32).T
        x = acc.astype(np.float64) * x_scale * w_scale
        if layer_idx < last_idx:
            x = np.maximum(x, 0)  # ReLU (FIRE equivalent)
    return x.astype(np.float32)

def _softmax(logits):
    """Return softmax probabilities for a 1-D array of logits.

    The maximum logit is subtracted before exponentiating so that large
    logits cannot overflow `np.exp`; the result is mathematically unchanged.
    """
    exps = np.exp(logits - np.max(logits))
    return exps / exps.sum()


n_compare = 100
results = []

model.eval()
for i in range(n_compare):
    sample, label = test_data[i]
    flat = sample.view(-1).numpy()

    # PyTorch prediction: argmax over the raw logits (softmax is monotonic,
    # so probabilities are not needed to pick the top-1 class)
    with torch.no_grad():
        pt_logits = model(sample.unsqueeze(0)).numpy().flatten()
    pt_pred = np.argmax(pt_logits)

    # SpikeCore quantized prediction
    sc_logits = spikecore_quantized_forward(q_weights, flat)
    sc_pred = np.argmax(sc_logits)

    results.append((label, pt_pred, sc_pred, pt_pred == sc_pred))

# Count the samples on which both models picked the same class
matches = sum(1 for _, _, _, agreed in results if agreed)
match_rate = matches / n_compare
print(f"PyTorch vs SpikeCore comparison ({n_compare} samples):")
print(f"  Matches: {matches}/{n_compare} ({match_rate:.1%})")
print(f"  Mismatches: {n_compare - matches}")

mismatches = [(l, pt, sc) for l, pt, sc, m in results if not m]
if mismatches:
    print(f"\n  First mismatch: label={mismatches[0][0]}, "
          f"PyTorch={mismatches[0][1]}, SpikeCore={mismatches[0][2]}")

# Side-by-side visualization for the first sample
sample_0, label_0 = test_data[0]
with torch.no_grad():
    pt_logits_0 = model(sample_0.unsqueeze(0)).numpy().flatten()
pt_probs_0 = _softmax(pt_logits_0)

flat_0 = sample_0.view(-1).numpy()
sc_logits_0 = spikecore_quantized_forward(q_weights, flat_0)
sc_probs_0 = _softmax(sc_logits_0)

fig = plot_comparison(pt_probs_0, sc_probs_0, title=f"Sample 0 (label={label_0})")
plt.show()

## Cell 10 — Visualize: Spike Raster & Weight Distribution

**Spike raster**: Shows which neuron cores fire at each timestep.
In a real neuromorphic chip, this temporal activity pattern encodes information
through spike timing — the basis of spike-timing-dependent plasticity (STDP).

**Weight distribution**: Shows the quantized int8 weight values.
A bell-shaped distribution centered near zero is typical for trained networks.

In [None]:
# Run multi-timestep simulation for richer spike activity
n_timesteps = 16

# Quantize test sample 0 to int8 for the simulator (symmetric, max-abs scale).
# This cell builds its own input so it does not rely on a name from another cell.
raster_flat = test_data[0][0].view(-1).numpy()
raster_abs_max = np.max(np.abs(raster_flat))
if raster_abs_max > 0:
    inp_q_0 = np.clip(
        np.round(raster_flat / raster_abs_max * 127), -128, 127
    ).astype(np.int8)
else:
    # An all-zero image has no scale to normalise by; feed zeros unchanged.
    inp_q_0 = np.zeros(raster_flat.shape, dtype=np.int8)

cpu = SpikeCoreCPU(num_cores=128)
for layer_id, (w_int8, scale) in enumerate(q_weights):
    cpu.load_weights(layer_id, w_int8)

cpu.load_program(program)
spike_counts = cpu.run(inp_q_0, timesteps=n_timesteps)
spike_log = cpu.get_spike_log()

print(f"{n_timesteps}-timestep simulation: {len(spike_log)} total spikes")

# Spike raster
fig1 = plot_spike_raster(spike_log, num_cores=74, num_timesteps=n_timesteps,
                          title="SpikeCore — Spike Raster (MNIST sample)")
plt.show()

# Weight distributions
fig2, axes = plt.subplots(1, 2, figsize=(14, 4))
plot_weight_distribution(q_weights[0][0], "Layer 1 (784→64)", ax=axes[0])
plot_weight_distribution(q_weights[1][0], "Layer 2 (64→10)", ax=axes[1])
fig2.suptitle("Quantized Weight Distributions (int8)", fontsize=13, fontweight="bold")
fig2.tight_layout()
plt.show()

# Summary
n_acc = sum(1 for i in program if i.opcode == Opcode.ACC)
n_fire = sum(1 for i in program if i.opcode == Opcode.FIRE)
n_leak = sum(1 for i in program if i.opcode == Opcode.LEAK)
n_model_params = sum(p.numel() for p in model.parameters())

print("\n" + "="*60)
print("PIPELINE SUMMARY")
print("="*60)
print(f"1. PyTorch model:     MLP 784→64→10 ({n_model_params:,} params)")
print("2. Relay IR:          nn.dense → bias_add → relu → nn.dense → bias_add")
print("3. Quantization:      float32 → int8 (symmetric, scale per tensor)")
print("4. BYOC partition:    2 SpikeCore subgraphs + 1 host (flatten)")
print(f"5. SpikeCore ASM:     {len(program)} instructions ({n_acc} ACC, {n_fire} FIRE, {n_leak} LEAK)")
print(f"6. Simulation:        {len(spike_log)} spikes over {n_timesteps} timesteps")
print(f"7. Accuracy match:    {match_rate:.1%} top-1 agreement on {n_compare} test samples")
print("="*60)
print()
print("Real-world mapping:")
print("  SpikeCore ACC   → Loihi dendritic accumulation")
print("  SpikeCore FIRE  → Loihi axon/spike generation")
print("  SpikeCore LEAK  → Loihi compartment leak current")
print("  BYOC partition  → Lava compiler's graph splitting")
print("  int8 quant      → Loihi's native 1-9 bit weight precision")