In [None]:
from pathlib import Path
import random

import torch
import torch.nn as nn
import torch.nn.functional as F

import torchaudio
import torchaudio.transforms as T
import soundfile as sf

import nengo
import nengo_loihi
import numpy as np
from torch.utils.data import Dataset, DataLoader


PROJECT_ROOT = Path.cwd().resolve().parent   # parent of loihi_emulator folder
DATA_ROOT = Path("data") / "SpeechCommands" / "speech_commands_v0.02"
MODEL_DIR = PROJECT_ROOT / "saved_models"

print("PROJECT_ROOT:", PROJECT_ROOT)
print("DATA_ROOT exists:", DATA_ROOT.exists())
print("MODEL_DIR exists:", MODEL_DIR.exists())

print(np.__version__)

device = torch.device("cpu")  
print("Using device:", device)

# Same 6 classes
CLASSES = ["yes", "no", "go", "stop", "down", "up"]
NUM_CLASSES = len(CLASSES)


PROJECT_ROOT: /home/jimson/CS-576-Final-Project
DATA_ROOT exists: True
MODEL_DIR exists: True
2.1.2
Using device: cpu


In [57]:
import snntorch as snn
from snntorch import surrogate

class SNNKeywordNet(nn.Module):
    def __init__(self, num_classes, beta=0.9, num_steps=8):
        super().__init__()
        self.num_steps = num_steps

        self.conv1 = nn.Conv2d(1, 8, kernel_size=5, stride=1, padding=2)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(8, 16, kernel_size=3, stride=1, padding=1)
        self.pool2 = nn.MaxPool2d(2)

        spike_grad = surrogate.fast_sigmoid()
        self.lif1 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        self.lif2 = snn.Leaky(beta=0.95, spike_grad=spike_grad)
        self.lif3 = snn.Leaky(beta=0.98, spike_grad=spike_grad)

        self.fc1 = None
        self.fc2 = None
        self.num_classes = num_classes

    def _build_heads_if_needed(self, x):
        with torch.no_grad():
            z1 = self.pool1(self.conv1(x))
            z2 = self.pool2(self.conv2(z1))
            flat_dim = z2.flatten(start_dim=1).shape[1]
        if self.fc1 is None:
            self.fc1 = nn.Linear(flat_dim, 64).to(x.device)
            self.fc2 = nn.Linear(64, self.num_classes).to(x.device)

    def forward(self, x, return_spikes: bool = False):
        x = x.unsqueeze(1)          # [B,1,40,T]
        self._build_heads_if_needed(x)

        mem1 = mem2 = mem3 = None
        logits_sum = 0.0

        if return_spikes:
            spikes_l1 = spikes_l2 = spikes_l3 = 0.0

        for _ in range(self.num_steps):
            z1 = self.conv1(x)
            spk1, mem1 = self.lif1(z1, mem1) if mem1 is not None else self.lif1(z1)
            p1 = self.pool1(spk1)

            z2 = self.conv2(p1)
            spk2, mem2 = self.lif2(z2, mem2) if mem2 is not None else self.lif2(z2)
            p2 = self.pool2(spk2)

            flat = p2.flatten(start_dim=1)
            z3 = self.fc1(flat)
            spk3, mem3 = self.lif3(z3, mem3) if mem3 is not None else self.lif3(z3)

            logits = self.fc2(spk3)
            logits_sum = logits_sum + logits

            if return_spikes:
                spikes_l1 += spk1.sum().item()
                spikes_l2 += spk2.sum().item()
                spikes_l3 += spk3.sum().item()

        logits_mean = logits_sum / self.num_steps

        if return_spikes:
            spike_dict = {
                "layer1": spikes_l1,
                "layer2": spikes_l2,
                "layer3": spikes_l3,
                "total": spikes_l1 + spikes_l2 + spikes_l3,
            }
            return logits_mean, spike_dict

        return logits_mean


In [None]:
import torch
import torch.nn as nn

snn_ckpt_path = MODEL_DIR / "baseline_snn_kws_vfinal.pt"
print("SNN checkpoint exists:", snn_ckpt_path.exists())

# load full checkpoint
state_dict_snn = torch.load(snn_ckpt_path, map_location=device)

# Build a fresh SNN model
snn_model = SNNKeywordNet(num_classes=NUM_CLASSES, beta=0.9, num_steps=8).to(device)

state_no_fc = {k: v for k, v in state_dict_snn.items() if not k.startswith("fc")}
missing, unexpected = snn_model.load_state_dict(state_no_fc, strict=False)
print("Loaded SNN conv + LIF weights.")
print("Missing keys:", missing)
print("Unexpected keys:", unexpected)

snn_model.eval()

W_snn = state_dict_snn["fc2.weight"].cpu().numpy()  # [6, 64]
b_snn = state_dict_snn["fc2.bias"].cpu().numpy()
print("SNN fc2 weight shape:", W_snn.shape)


SNN checkpoint exists: True
Loaded SNN conv + LIF weights.
Missing keys: []
Unexpected keys: []
SNN fc2 weight shape: (6, 64)


  state_dict_snn = torch.load(snn_ckpt_path, map_location=device)


In [None]:
SAMPLE_RATE = 16000
N_MFCC = 40

mfcc_transform = nn.Sequential(
    T.MFCC(
        sample_rate=SAMPLE_RATE,
        n_mfcc=N_MFCC,
        melkwargs={
            "n_fft": 400,
            "hop_length": 160,
            "n_mels": 40,
            "center": False,
        },
    ),
    T.AmplitudeToDB(),
)

def wav_to_mfcc(path: Path) -> torch.Tensor:
    waveform, sr = sf.read(str(path))
    waveform = torch.tensor(waveform).float()

    if waveform.ndim == 1:
        waveform = waveform.unsqueeze(0)
    elif waveform.ndim == 2 and waveform.shape[1] > waveform.shape[0]:
        waveform = waveform.T  # [C,N]

    if sr != SAMPLE_RATE:
        waveform = torchaudio.functional.resample(waveform, sr, SAMPLE_RATE)

    mfcc = mfcc_transform(waveform).squeeze(0)  # [40,T]
    mfcc = (mfcc - mfcc.mean()) / (mfcc.std() + 1e-6)
    mfcc = torch.clamp(mfcc, -2.0, 2.0)
    return mfcc

class KWS_Dataset(Dataset):
    def __init__(self, files, classes):
        self.files = files
        self.classes = classes

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        path = self.files[idx]
        mfcc = wav_to_mfcc(path)    # [40,T]
        label = path.parent.name
        y = self.classes.index(label)
        return mfcc, y

def pad_collate(batch):
    xs, ys = zip(*batch)
    max_t = max(x.shape[1] for x in xs)
    xs = [F.pad(x, (0, max_t - x.shape[1])) for x in xs]
    xs = torch.stack(xs)   # [B,40,T]
    ys = torch.tensor(ys)
    return xs, ys

file_list = []
for c in CLASSES:
    class_dir = DATA_ROOT / c
    file_list += sorted(class_dir.glob("*.wav"))

print("Total WAV files in 6 classes:", len(file_list))

random.seed(0)
random.shuffle(file_list)

test_files = file_list  

test_dataset = KWS_Dataset(test_files, CLASSES)
test_loader = DataLoader(
    test_dataset,
    batch_size=8,
    shuffle=True,
    collate_fn=pad_collate,
)

print("Loaded test samples:", len(test_dataset))


Total WAV files in 6 classes: 23377
Loaded test samples: 23377


In [None]:
def extract_snn_features(x: torch.Tensor, model: SNNKeywordNet) -> torch.Tensor:
   
    model.eval()
    with torch.no_grad():
        x_ = x.unsqueeze(1)   # [B,1,40,T]
        model._build_heads_if_needed(x_)

        mem1 = mem2 = mem3 = None
        spk3_sum = None

        for _ in range(model.num_steps):
            z1 = model.conv1(x_)
            spk1, mem1 = model.lif1(z1, mem1) if mem1 is not None else model.lif1(z1)
            p1 = model.pool1(spk1)

            z2 = model.conv2(p1)
            spk2, mem2 = model.lif2(z2, mem2) if mem2 is not None else model.lif2(z2)
            p2 = model.pool2(spk2)

            flat = p2.flatten(start_dim=1)
            z3 = model.fc1(flat)
            spk3, mem3 = model.lif3(z3, mem3) if mem3 is not None else model.lif3(z3)

            if spk3_sum is None:
                spk3_sum = spk3
            else:
                spk3_sum = spk3_sum + spk3

        feats = spk3_sum / model.num_steps   # [B,64], average spikes/time
    return feats

example_path = next(iter(test_files))
print("Example file:", example_path)

mfcc_ex = wav_to_mfcc(example_path).unsqueeze(0).to(device)  # [1,40,T]
with torch.no_grad():
    feats_ex = extract_snn_features(mfcc_ex, snn_model)
    logits_ex, _ = snn_model(mfcc_ex, return_spikes=True)
    pred_idx = logits_ex.argmax(dim=1).item()

print("SNN feature shape:", feats_ex.shape)
print("SNN predicts:", CLASSES[pred_idx])


Example file: data/SpeechCommands/speech_commands_v0.02/stop/b84f83d2_nohash_0.wav
SNN feature shape: torch.Size([1, 64])
SNN predicts: yes


In [None]:
def run_loihi_for_feature(
    feat_vec: np.ndarray,
    W: np.ndarray,
    sim_time: float = 0.1
) -> np.ndarray:
    
    assert feat_vec.shape == (64,), f"Expected (64,), got {feat_vec.shape}"
    assert W.shape == (NUM_CLASSES, 64), f"W shape {W.shape} unexpected"

    with nengo.Network(seed=0) as net:
        inp = nengo.Node(output=lambda t: feat_vec)

        ens = nengo.Ensemble(
            n_neurons=64,
            dimensions=64,
            neuron_type=nengo.LIF(),
        )

        out = nengo.Node(size_in=NUM_CLASSES)

        nengo.Connection(inp, ens, synapse=None)
        nengo.Connection(
            ens.neurons,
            out,
            transform=W,
            synapse=0.01,
        )

        p_out = nengo.Probe(out, synapse=0.01)


    with nengo.Simulator(net) as sim:
        sim.run(sim_time)
        logits_loihi = sim.data[p_out][-1]

    logits_loihi = np.nan_to_num(logits_loihi)
    return logits_loihi


In [62]:
from typing import Tuple

def eval_loihi_classifier_snn(
    loader,
    snn_model: SNNKeywordNet,
    W: np.ndarray,
    device: torch.device,
    max_samples: int = 100,
    sim_time: float = 0.1,
) -> Tuple[float, float, int]:

    snn_model.eval()
    total = 0
    correct_snn = 0
    correct_loihi = 0

    for mfcc_batch, y_batch in loader:
        mfcc_batch = mfcc_batch.to(device)
        y_batch_np = y_batch.numpy()

        with torch.no_grad():
            feats = extract_snn_features(mfcc_batch, snn_model)   # [B,64]
            logits_snn, _ = snn_model(mfcc_batch, return_spikes=True)
            preds_snn = logits_snn.argmax(dim=1).cpu().numpy()

        batch_size = feats.size(0)
        for i in range(batch_size):
            feat_np = feats[i].cpu().numpy()
            label = int(y_batch_np[i])

            # SNN head prediction (PyTorch)
            if preds_snn[i] == label:
                correct_snn += 1

            # Loihi prediction
            logits_loihi = run_loihi_for_feature(feat_np, W=W, sim_time=sim_time)
            pred_loihi = int(np.argmax(logits_loihi))
            if pred_loihi == label:
                correct_loihi += 1

            total += 1
            if total >= max_samples:
                snn_acc = correct_snn / total
                loihi_acc = correct_loihi / total
                return snn_acc, loihi_acc, total

    snn_acc = correct_snn / max(total, 1)
    loihi_acc = correct_loihi / max(total, 1)
    return snn_acc, loihi_acc, total


In [63]:
max_samples = 50    # or 100 if you want
sim_time = 0.1      # 100 ms per sample

snn_acc, loihi_acc, total = eval_loihi_classifier_snn(
    loader=test_loader,
    snn_model=snn_model,
    W=W_snn,
    device=device,
    max_samples=50,
    sim_time=0.1,
)
print(f"Evaluated on {total} test samples")
print(f"SNN head accuracy:    {snn_acc*100:.2f}%")
print(f"Loihi classifier acc: {loihi_acc*100:.2f}%")

Build finished in 0:00:01.                                                      
Simulation finished in 0:00:01.                                                 
Build finished in 0:00:01.                                                      
Simulation finished in 0:00:01.                                                 
Build finished in 0:00:01.                                                      
Simulation finished in 0:00:01.                                                 
Build finished in 0:00:01.                                                      
Simulation finished in 0:00:01.                                                 
Build finished in 0:00:01.                                                      
Simulation finished in 0:00:01.                                                 
Build finished in 0:00:01.                                                      
Simulation finished in 0:00:01.                                                 
Build finished in 0:00:01.  