In [11]:
import sys, torch
print(sys.executable)
print("Torch: ", torch.__version__)

/home/jimson/miniconda3/envs/kws-clean/bin/python
Torch:  2.4.1+cpu


In [12]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchaudio
import torchaudio.transforms as T
from pathlib import Path

#random seed for reproducibility
torch.manual_seed(42)



<torch._C.Generator at 0x7f1ddc1e8b50>

In [None]:
import os
from pathlib import Path

import torch
import torchaudio
import torchaudio.transforms as T
from torch.utils.data import DataLoader

SAMPLE_RATE = 16000
KEYWORDS = ["yes", "no", "stop", "go", "up", "down"]
KW_SET = set(KEYWORDS)
LABEL2IDX = {k: i for i, k in enumerate(KEYWORDS)}

DATA_ROOT = Path("data")
DATA_ROOT.mkdir(parents=True, exist_ok=True)

#converting raw audio data to better representation
MFCC = T.MFCC(
    sample_rate=SAMPLE_RATE,
    #features of each audio being retained
    n_mfcc=40,
    melkwargs={"n_fft": 400, "hop_length": 160, "n_mels": 40, "center": False},
)

#wrapping speechCommands and filtering to key words 
class KeywordCommands(torchaudio.datasets.SPEECHCOMMANDS):
    #calling constructor of parent class "torchaudio.datasets.SPEECHCOMMANDS"
    def __init__(self, subset: str):
        super().__init__(root=str(DATA_ROOT), download=True, subset=subset)
        #fid is full file path : /home/jimson/data/SpeechCommands/speech_commands_v0.02/yes/0a7c2a8d_nohash_0.wav
        def label_from_fileid(fid: str) -> str:
            # Make path relative to dataset root, then take first component as label
            # rel = yes/0a7c2a8d_nohash_0.wav
            rel = os.path.relpath(fid, self._path)
            #splits rel by / or \ and returns yes
            return rel.split(os.sep)[0]

        #self._walker is python list of file paths the dataset will use to load samples
        all_files = list(self._walker)
        #self._walker is list of file paths in the keywords set
        self._walker = [f for f in all_files if label_from_fileid(f) in KW_SET]

        
        print(f"[{subset}] total files before keyword filter: {len(all_files)}")
        print(f"[{subset}] files after keyword filter: {len(self._walker)}   keywords={sorted(KW_SET)}")

    def __getitem__(self, index):
        #calling SPEECHCOMMANDS getitem func from loaded audio file at position index in self._walker
        #usually comes in 5 tuple (wavefrom, sr, label, speaker_id, utterace num), but *_ ignores rest of values in tuple
        waveform, sr, label, *_ = super().__getitem__(index)
        # takes [channels, features, Timeframes] and converts to [features, timeFrames];
        #what audio sounds like at that timeframe
        mfcc = MFCC(waveform).squeeze(0)
        #standardizing MFCC feautres so they have zero mean and unnit variance
        mfcc = (mfcc - mfcc.mean()) / (mfcc.std() + 1e-6)
        #returning mfcc tensor and index to keyword 
        return mfcc, LABEL2IDX[label]
    
#mfcc tensors don't all have the same T(time length)
#spoken words might have different durations
#processing multiple samples together , all tensors need to be of the same shape 
def pad_collate(batch):
    # * is the unpackign operator
    tensors, targets = zip(*batch)
    #getting mac timeFrame from list
    max_len = max(t.shape[1] for t in tensors)
    #padding tensors to be the same time_frame
    padded = [torch.nn.functional.pad(t, (0, max_len - t.shape[1])) for t in tensors]
    #combining all padded [40, max_len] tensors into one big batch tensor and returning: [B, 40, max_len] and [B]
    return torch.stack(padded), torch.tensor(targets)


train_ds = KeywordCommands("training")
val_ds   = KeywordCommands("validation")
test_ds  = KeywordCommands("testing")

for name, ds in [("train", train_ds), ("val", val_ds), ("test", test_ds)]:
    if len(ds) == 0:
        raise RuntimeError(
            f"{name} dataset is empty after filtering.\n"
        )

BATCH_SIZE = 64
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=pad_collate)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, collate_fn=pad_collate)
test_loader  = DataLoader(test_ds,  batch_size=BATCH_SIZE, collate_fn=pad_collate)


[training] total files before keyword filter: 84843
[training] labels present (first 15): ['backward', 'bed', 'bird', 'cat', 'dog', 'down', 'eight', 'five', 'follow', 'forward', 'four', 'go', 'happy', 'house', 'learn'] ...
[training] files after keyword filter: 18657   keywords=['down', 'go', 'no', 'stop', 'up', 'yes']
[validation] total files before keyword filter: 9981
[validation] labels present (first 15): ['backward', 'bed', 'bird', 'cat', 'dog', 'down', 'eight', 'five', 'follow', 'forward', 'four', 'go', 'happy', 'house', 'learn'] ...
[validation] files after keyword filter: 2252   keywords=['down', 'go', 'no', 'stop', 'up', 'yes']
[testing] total files before keyword filter: 11005
[testing] labels present (first 15): ['backward', 'bed', 'bird', 'cat', 'dog', 'down', 'eight', 'five', 'follow', 'forward', 'four', 'go', 'happy', 'house', 'learn'] ...
[testing] files after keyword filter: 2468   keywords=['down', 'go', 'no', 'stop', 'up', 'yes']


In [None]:
import snntorch as snn
from snntorch import surrogate

class SNNKeywordNet(nn.Module):
    def __init__(self, num_classes, beta=0.9, num_steps=8):
        super().__init__()
        self.num_steps = num_steps

        self.conv1 = nn.Conv2d(1, 8, kernel_size=5, stride=1, padding=2)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(8, 16, kernel_size=3, stride=1, padding=1)
        self.pool2 = nn.MaxPool2d(2)

        spike_grad = surrogate.fast_sigmoid()
        self.lif1 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        self.lif2 = snn.Leaky(beta= 0.95, spike_grad=spike_grad)
        self.lif3 = snn.Leaky(beta= 0.98, spike_grad=spike_grad)

        self.fc1 = None
        self.fc2 = None
        self.num_classes = num_classes

    def _build_heads_if_needed(self, x):
        with torch.no_grad():
            z1 = self.pool1(self.conv1(x))
            z2 = self.pool2(self.conv2(z1))
            flat_dim = z2.flatten(start_dim=1).shape[1]
        if self.fc1 is None:
            self.fc1 = nn.Linear(flat_dim, 64).to(x.device)
            self.fc2 = nn.Linear(64, self.num_classes).to(x.device)

    def forward(self, x, return_spikes: bool = False):
        # x: [B, 40, T] -> [B, 1, 40, T]
        x = x.unsqueeze(1)
        self._build_heads_if_needed(x)

        mem1 = mem2 = mem3 = None
        logits_sum = 0.0

        # spike counters
        if return_spikes:
            spikes_l1 = 0.0
            spikes_l2 = 0.0
            spikes_l3 = 0.0

        for _ in range(self.num_steps):
            z1 = self.conv1(x)
            spk1, mem1 = self.lif1(z1, mem1) if mem1 is not None else self.lif1(z1)
            p1 = self.pool1(spk1)

            z2 = self.conv2(p1)
            spk2, mem2 = self.lif2(z2, mem2) if mem2 is not None else self.lif2(z2)
            p2 = self.pool2(spk2)

            flat = p2.flatten(start_dim=1)
            z3 = self.fc1(flat)
            spk3, mem3 = self.lif3(z3, mem3) if mem3 is not None else self.lif3(z3)

            logits = self.fc2(spk3)
            logits_sum = logits_sum + logits

            if return_spikes:
                # sum over batch and all units
                spikes_l1 += spk1.sum().item()
                spikes_l2 += spk2.sum().item()
                spikes_l3 += spk3.sum().item()

        logits_mean = logits_sum / self.num_steps

        if return_spikes:
            spike_dict = {
                "layer1": spikes_l1,
                "layer2": spikes_l2,
                "layer3": spikes_l3,
                "total": spikes_l1 + spikes_l2 + spikes_l3,
            }
            return logits_mean, spike_dict

        return logits_mean


In [15]:
# Prep: make sure these are imported and the model exists
import torch
import torch.nn as nn

# SNNKeywordNet must be defined above (the class you wrote)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = SNNKeywordNet(num_classes=len(KEYWORDS), beta=0.9, num_steps=8).to(device)

# (Optional) sanity-check a forward pass using one batch to confirm shapes
xb, yb = next(iter(train_loader))
with torch.no_grad():
    out = model(xb.to(device))
print("Sanity check → logits shape:", out.shape)  # should be [batch_size, len(KEYWORDS)]


Sanity check → logits shape: torch.Size([64, 6])


In [16]:
import torch.optim as optim
from tqdm import tqdm

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
EPOCHS = 20  # adjust as needed

# Functions for one epoch of training and validation
def train_epoch(loader):
    model.train()
    total_loss, correct, total = 0.0, 0, 0
    for x, y in tqdm(loader, desc="Training", leave=False):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        out = model(x)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * y.size(0)
        correct += (out.argmax(1) == y).sum().item()
        total += y.size(0)
    return total_loss / total, correct / total

@torch.no_grad()
def eval_epoch(loader):
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    for x, y in tqdm(loader, desc="Validating", leave=False):
        x, y = x.to(device), y.to(device)
        out = model(x)
        loss = criterion(out, y)
        total_loss += loss.item() * y.size(0)
        correct += (out.argmax(1) == y).sum().item()
        total += y.size(0)
    return total_loss / total, correct / total

In [17]:

@torch.no_grad()
def count_spikes(model, data_loader, batches=1):
    model.eval()
    total_spikes = 0.0
    by_layer = {"layer1": 0.0, "layer2": 0.0, "layer3": 0.0}
    seen = 0
    for x, _ in data_loader:
        x = x.to(device)
        _, spikes = model(x, return_spikes=True)
        total_spikes += spikes["total"]
        for k in ("layer1", "layer2", "layer3"):
            by_layer[k] += spikes[k]
        seen += 1
        if seen >= batches:
            break
    return total_spikes, by_layer

# Final evaluation on the test set
test_loss, test_acc = eval_epoch(test_loader)
print(f"Test Loss: {test_loss:.3f}, Test Accuracy: {test_acc:.2%}")

total_spk, spk_by_layer = count_spikes(model, test_loader, batches=1)
print(f"Approximate spikes (1 batch): {int(total_spk)}")
print("By layer:", {k: int(v) for k, v in spk_by_layer.items()})


                                                           

Test Loss: 1.793, Test Accuracy: 16.33%
Approximate spikes (1 batch): 622354
By layer: {'layer1': 523184, 'layer2': 99121, 'layer3': 49}
