<a href="https://colab.research.google.com/github/inachenyx/SpeechSNN/blob/main/SpeakerExample.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torchaudio snntorch tqdm



### Speaker Recognition Pipeline

In [None]:
import torch
import torchaudio
import torchaudio.transforms as T
import snntorch as snn
from snntorch import utils
from torch import nn, optim
from torch.utils.data import DataLoader, random_split
from torchaudio.datasets import SPEECHCOMMANDS
import os
import random
import numpy as np
from tqdm import tqdm

# Seed PyTorch, NumPy, and Python's `random` with a fixed value so repeated
# runs start from the same random state.
torch.manual_seed(0)
np.random.seed(0)
random.seed(0)

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


## When there is not enough RAM in Colab

### Setup Dataset with On-the-Fly Rate Encoding

In [None]:
import torch
import torchaudio
import torchaudio.transforms as T
from torch.utils.data import Dataset, DataLoader
from torchaudio.datasets import SPEECHCOMMANDS
import os
from tqdm import tqdm
import random
import numpy as np

# Seed PyTorch, NumPy, and Python's `random` with a fixed value so repeated
# runs start from the same random state.
torch.manual_seed(0)
np.random.seed(0)
random.seed(0)

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# MFCC front end for 16 kHz audio: 13 coefficients per frame, derived from a
# 40-band mel spectrogram with a 400-sample FFT (25 ms at 16 kHz) and a
# 160-sample hop (10 ms at 16 kHz).
mfcc_transform = T.MFCC(
    sample_rate=16000,
    n_mfcc=13,
    melkwargs={"n_fft": 400, "hop_length": 160, "n_mels": 40},
)

# Load SpeechCommands dataset subset
class SubsetSC(SPEECHCOMMANDS):
    """SpeechCommands restricted to one of the official splits.

    Args:
        subset: ``"validation"``, ``"testing"`` or ``"training"``. Any other
            value (including ``None``) keeps the full file list built by the
            parent class.

    The archive is downloaded into the current directory when it is missing.
    ``"training"`` is defined as every file that is not named in
    ``validation_list.txt`` or ``testing_list.txt``.

    NOTE(review): the official lists appear to be partitioned by speaker, so
    a speaker index built from the training split may not contain the
    speakers of the test split -- confirm before looking test labels up in it.
    """

    def __init__(self, subset: str = None):
        super().__init__(".", download=True)

        def load_list(filename):
            """Return the normalised file paths named in a split list file."""
            filepath = os.path.join(self._path, filename)
            with open(filepath) as f:
                # Normalise so the paths compare equal to the parent's walker
                # entries. A purely textual difference (such as a leading
                # "./") would otherwise make the training exclusion below
                # match nothing and leave validation/test clips in the
                # training data.
                return [
                    os.path.normpath(os.path.join(self._path, line.strip()))
                    for line in f
                ]

        if subset == "validation":
            self._walker = load_list("validation_list.txt")
        elif subset == "testing":
            self._walker = load_list("testing_list.txt")
        elif subset == "training":
            excludes = set(
                load_list("validation_list.txt") + load_list("testing_list.txt")
            )
            # Compare normalised forms on both sides; the stored walker
            # entries themselves are kept exactly as the parent produced them.
            self._walker = [
                w for w in self._walker if os.path.normpath(w) not in excludes
            ]

# Instantiate the training and testing subsets. SubsetSC passes
# download=True, so the archive is fetched into the current directory if it
# is not already there.
train_raw = SubsetSC("training")
test_raw = SubsetSC("testing")

# Build speaker index
def build_speaker_index(dataset):
    """Map every speaker ID in ``dataset`` to a contiguous class index.

    Speaker IDs are read from field 3 of each record and numbered in sorted
    order, so the mapping does not depend on the order of the records.

    When the dataset exposes ``get_metadata(n)`` (as torchaudio's
    SPEECHCOMMANDS does), that accessor is used so the audio itself does not
    have to be decoded just to read the speaker field. Otherwise the dataset
    is iterated directly.

    Args:
        dataset: Sized, indexable collection of 5-field records whose fourth
            field is the speaker ID.

    Returns:
        dict mapping speaker ID -> index in ``range(number_of_speakers)``.
    """
    get_metadata = getattr(dataset, "get_metadata", None)
    if callable(get_metadata):
        speaker_ids = {get_metadata(n)[3] for n in range(len(dataset))}
    else:
        speaker_ids = {record[3] for record in dataset}
    return {sid: i for i, sid in enumerate(sorted(speaker_ids))}

# Map each speaker found in the training subset to a class index; the number
# of distinct speakers is later used as the classifier's output width.
speaker_dict = build_speaker_index(train_raw)
num_classes = len(speaker_dict)


### Custom Dataset with On-the-Fly Rate Encoding

In [None]:
class RateEncodedMFCCDataset(Dataset):
    """Wrap a SpeechCommands-style dataset and rate-encode MFCCs on access.

    Each item is converted to spikes when it is requested, so no encoded copy
    of the dataset is kept in memory. The encoding is random: the same index
    yields a different spike pattern on every access.

    Args:
        dataset: Indexable dataset whose items are 5-field records
            ``(waveform, sample_rate, _, speaker_id, _)``.
        speaker_to_idx: Mapping from speaker ID to integer class label.
        transform: Callable turning a waveform into MFCCs; its output is
            squeezed on dim 0 to ``[n_mfcc, frames]``.
        num_steps: Number of spike time steps generated per item.
        max_len: Number of MFCC frames every item is cropped or padded to.
    """

    def __init__(self, dataset, speaker_to_idx, transform, num_steps=20, max_len=80):
        self.dataset = dataset
        self.speaker_to_idx = speaker_to_idx
        self.transform = transform
        self.num_steps = num_steps
        self.max_len = max_len

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(spikes, label)`` for item ``idx``.

        ``spikes`` is a float tensor of shape ``[num_steps, n_mfcc, max_len]``
        holding 0.0/1.0; ``label`` is ``speaker_to_idx[speaker_id]``.

        Raises:
            KeyError: if the item's speaker is not in ``speaker_to_idx``.
        """
        waveform, _sample_rate, _, speaker_id, _ = self.dataset[idx]
        mfcc = self.transform(waveform).squeeze(0)  # [n_mfcc, frames]
        mfcc = mfcc[:, :self.max_len]

        # Min-max scale the real frames only. Scaling after padding would let
        # the artificial zeros shift the min/max and give the padded frames a
        # non-zero firing probability.
        if mfcc.numel() > 0:
            mfcc = (mfcc - mfcc.min()) / (mfcc.max() - mfcc.min() + 1e-5)

        # Pad short clips with probability 0, i.e. frames that never spike.
        missing = self.max_len - mfcc.shape[1]
        if missing > 0:
            pad = mfcc.new_zeros(mfcc.shape[0], missing)
            mfcc = torch.cat([mfcc, pad], dim=1)

        # Bernoulli sampling: each scaled value is the per-step spike chance.
        spikes = (torch.rand(self.num_steps, *mfcc.shape) < mfcc.unsqueeze(0)).float()

        label = self.speaker_to_idx[speaker_id]
        return spikes, label


### LSTM-based Classifier

In [None]:
import torch.nn as nn

class SpikingLSTM(nn.Module):
    """Single-layer LSTM classifier over rate-coded MFCC spike tensors.

    Args:
        input_size: Feature width fed to the LSTM (size of input dim 2).
        hidden_size: LSTM hidden state width.
        output_size: Number of output classes.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Classify a batch of spike tensors.

        ``x`` is 4-D, ``[B, T, 13, 80]`` for batches of RateEncodedMFCCDataset
        items (T = spike steps, 13 coefficients, 80 MFCC frames). Returns
        logits of shape ``[B, output_size]``.

        NOTE(review): dim 3 is the MFCC frame axis, so averaging it away
        leaves the LSTM stepping over spike steps rather than audio frames
        -- confirm this is intended.
        """
        frame_avg = torch.mean(x, dim=3)  # collapse MFCC frames -> [B, T, 13]
        sequence_out, _state = self.lstm(frame_avg)
        final_step = sequence_out[:, -1, :]  # hidden output at the last step
        return self.fc(final_step)


### Training and Evaluation

In [None]:
def train_epoch(model, loader, optimizer, loss_fn):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Batches are moved to the device that holds the model's parameters, so
    the function does not depend on a module-level ``device`` variable.

    Args:
        model: Network to train; it is put into training mode.
        loader: Iterable of ``(inputs, targets)`` batches.
        optimizer: Optimizer that updates ``model``'s parameters.
        loss_fn: Callable ``loss_fn(output, targets)`` returning a scalar loss.

    Returns:
        Mean of the per-batch losses as a float, or ``0.0`` when ``loader``
        yields no batches (previously a ZeroDivisionError).
    """
    model.train()
    first_param = next(model.parameters(), None)
    total_loss = 0.0
    num_batches = 0
    for xb, yb in loader:
        if first_param is not None:
            xb, yb = xb.to(first_param.device), yb.to(first_param.device)
        output = model(xb)
        loss = loss_fn(output, yb)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        # Count batches directly so loaders without __len__ also work.
        num_batches += 1
    return total_loss / num_batches if num_batches else 0.0

def evaluate(model, loader):
    """Return the classification accuracy of ``model`` over ``loader``.

    Batches are moved to the device that holds the model's parameters, so
    the function does not depend on a module-level ``device`` variable.

    Args:
        model: Network to evaluate; it is put into eval mode.
        loader: Iterable of ``(inputs, targets)`` batches.

    Returns:
        Fraction of samples whose arg-max prediction equals the target, or
        ``0.0`` when ``loader`` yields no samples (previously a
        ZeroDivisionError).
    """
    model.eval()
    first_param = next(model.parameters(), None)
    correct = total = 0
    with torch.no_grad():
        for xb, yb in loader:
            if first_param is not None:
                xb, yb = xb.to(first_param.device), yb.to(first_param.device)
            pred = model(xb).argmax(1)
            correct += (pred == yb).sum().item()
            total += yb.size(0)
    return correct / total if total else 0.0


In [None]:
# DataLoaders: both wrap the raw subsets in RateEncodedMFCCDataset, which
# produces spikes per item on access, so no pre-encoded copy of the data is
# held in memory. Only the training loader shuffles.
# NOTE(review): test labels are looked up in speaker_dict, which is built from
# train_raw only; a test speaker missing from it would raise KeyError in
# __getitem__ -- confirm the mapping covers the test subset.
train_loader = DataLoader(
    RateEncodedMFCCDataset(train_raw, speaker_dict, mfcc_transform),
    batch_size=64, shuffle=True, num_workers=2, pin_memory=True
)
test_loader = DataLoader(
    RateEncodedMFCCDataset(test_raw, speaker_dict, mfcc_transform),
    batch_size=64, shuffle=False, num_workers=2, pin_memory=True
)

# Model setup: input_size=13 matches n_mfcc in mfcc_transform; one output
# logit per speaker in speaker_dict.
model = SpikingLSTM(input_size=13, hidden_size=128, output_size=num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

# Train for 10 epochs, reporting the mean training loss and the test accuracy
# after every epoch.
for epoch in range(10):
    loss = train_epoch(model, train_loader, optimizer, loss_fn)
    acc = evaluate(model, test_loader)
    print(f"Epoch {epoch+1}: Loss = {loss:.4f}, Test Acc = {acc:.4f}")


Epoch 1: Loss = 7.2427, Test Acc = 0.0224
Epoch 2: Loss = 6.5945, Test Acc = 0.0151
Epoch 3: Loss = 6.3827, Test Acc = 0.0303
Epoch 4: Loss = 6.0944, Test Acc = 0.0192
Epoch 5: Loss = 5.8730, Test Acc = 0.0370
Epoch 6: Loss = 5.7337, Test Acc = 0.0434
Epoch 7: Loss = 5.6463, Test Acc = 0.0362


## When there is enough RAM

### Dataset Setup

In [None]:
# SpeechCommands subclass that selects one of the official splits
class SubsetSC(SPEECHCOMMANDS):
    """SpeechCommands restricted to one of the official splits.

    Args:
        subset: ``"validation"``, ``"testing"`` or ``"training"``. Any other
            value (including ``None``) keeps the full file list built by the
            parent class.

    The archive is downloaded into the current directory when it is missing.
    ``"training"`` is defined as every file that is not named in
    ``validation_list.txt`` or ``testing_list.txt``.
    """

    def __init__(self, subset: str = None):
        super().__init__(".", download=True)

        def load_list(filename):
            """Return the normalised file paths named in a split list file."""
            filepath = os.path.join(self._path, filename)
            with open(filepath) as f:
                # Normalise so the paths compare equal to the parent's walker
                # entries. A purely textual difference (such as a leading
                # "./") would otherwise make the training exclusion below
                # match nothing and leave validation/test clips in the
                # training data.
                return [
                    os.path.normpath(os.path.join(self._path, line.strip()))
                    for line in f
                ]

        if subset == "validation":
            self._walker = load_list("validation_list.txt")
        elif subset == "testing":
            self._walker = load_list("testing_list.txt")
        elif subset == "training":
            excludes = set(
                load_list("validation_list.txt") + load_list("testing_list.txt")
            )
            # Compare normalised forms on both sides; the stored walker
            # entries themselves are kept exactly as the parent produced them.
            self._walker = [
                w for w in self._walker if os.path.normpath(w) not in excludes
            ]

# Instantiate the training and testing subsets. SubsetSC passes
# download=True, so the archive is fetched into the current directory if it
# is not already there.
train_dataset = SubsetSC("training")
test_dataset = SubsetSC("testing")


100%|██████████| 2.26G/2.26G [00:32<00:00, 75.1MB/s]


### Preprocessing: MFCC Extraction

In [None]:
# MFCC front end for 16 kHz audio: 13 coefficients per frame, derived from a
# 40-band mel spectrogram with a 400-sample FFT (25 ms at 16 kHz) and a
# 160-sample hop (10 ms at 16 kHz).
mfcc_transform = T.MFCC(
    sample_rate=16000,
    n_mfcc=13,
    melkwargs={"n_fft": 400, "hop_length": 160, "n_mels": 40},
)

def extract_mfcc(dataset, max_len=80, speaker_to_idx=None):
    """Compute fixed-length MFCC features and speaker labels for a dataset.

    Every waveform is passed through the module-level ``mfcc_transform``,
    then cropped or zero-padded along the frame axis to ``max_len`` frames.

    Args:
        dataset: Iterable of 5-field records
            ``(waveform, sample_rate, _, speaker_id, _)``.
        max_len: Number of MFCC frames kept per sample.
        speaker_to_idx: Optional existing speaker -> label mapping to reuse,
            e.g. the one returned for the training set, so that labels mean
            the same thing across datasets. It is copied, never mutated.
            Speakers missing from it receive new indices after its largest
            existing index.

    Returns:
        ``(features, labels, speaker_to_idx)`` where ``features`` is a tensor
        of shape ``[N, n_mfcc, max_len]``, ``labels`` is an integer tensor of
        shape ``[N]`` and ``speaker_to_idx`` is the (possibly extended)
        mapping.
    """
    features, labels = [], []
    speaker_to_idx = {} if speaker_to_idx is None else dict(speaker_to_idx)
    next_idx = max(speaker_to_idx.values(), default=-1) + 1

    for waveform, _sample_rate, _, speaker_id, _ in tqdm(dataset):
        mfcc = mfcc_transform(waveform).squeeze(0)  # [n_mfcc, time]
        mfcc = mfcc[:, :max_len]  # crop long clips
        if mfcc.shape[1] < max_len:  # zero-pad short clips
            pad = torch.zeros(mfcc.shape[0], max_len - mfcc.shape[1])
            mfcc = torch.cat([mfcc, pad], dim=1)

        if speaker_id not in speaker_to_idx:
            speaker_to_idx[speaker_id] = next_idx
            next_idx += 1

        features.append(mfcc)
        labels.append(speaker_to_idx[speaker_id])

    return torch.stack(features), torch.tensor(labels), speaker_to_idx

X_train, y_train, speaker_dict = extract_mfcc(train_dataset)
# Reuse the training mapping so a test label refers to the same speaker as the
# matching training label. Test-only speakers get indices past the training
# range and can therefore never be predicted correctly.
X_test, y_test, _ = extract_mfcc(test_dataset, speaker_to_idx=speaker_dict)


100%|██████████| 105829/105829 [11:39<00:00, 151.34it/s]
100%|██████████| 11005/11005 [01:18<00:00, 139.89it/s]


### Spiking Input Encoding (Rate Coding)

In [None]:
# Encode MFCC features as spikes: simple rate coding
def rate_encode(x, num_steps=20):
    """Turn a feature tensor into a random spike train via rate coding.

    ``x`` is min-max scaled over all of its elements and every scaled value
    is used as the per-step firing probability.

    Returns:
        Float tensor of shape ``(num_steps, *x.shape)`` holding 0.0/1.0.
    """
    lowest = x.min()
    value_range = x.max() - lowest
    # No epsilon here: a constant ``x`` divides 0 by 0 and yields NaN.
    firing_prob = (x - lowest) / value_range
    uniform = torch.rand((num_steps, *x.shape))
    fired = uniform < firing_prob.unsqueeze(0)
    return fired.float()

# Encode entire dataset
def encode_dataset(X, num_steps=20):
    """Rate-encode every sample of ``X`` and stack the results on a new dim 0.

    ``X`` is iterated under a tqdm progress bar and each element is passed to
    ``rate_encode`` with ``num_steps``.
    """
    encoded = []
    for sample in tqdm(X):
        encoded.append(rate_encode(sample, num_steps))
    return torch.stack(encoded)

# Number of spike time steps generated per sample.
num_steps = 20
# Pre-encode both splits up front. encode_dataset stacks one
# (num_steps, n_mfcc, max_len) spike tensor per sample, so each result has
# shape (N, num_steps, n_mfcc, max_len) and is held in memory as a whole.
X_train_encoded = encode_dataset(X_train, num_steps)
X_test_encoded = encode_dataset(X_test, num_steps)


NameError: name 'X_train' is not defined

### LSTM-Based Spiking Model

In [None]:
class SpikingLSTM(nn.Module):
    """LSTM classifier over rate-coded MFCC spike tensors.

    Args:
        input_size: Feature width fed to the LSTM; must equal the size of
            dim 2 (MFCC coefficients) of the input to ``forward``.
        hidden_size: LSTM hidden state width.
        output_size: Number of output classes.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, input_size, hidden_size, output_size, num_layers=1):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.readout = nn.Linear(hidden_size, output_size)

    def forward(self, spike_input):  # [B, T, C, F]
        """Return class logits of shape ``[B, output_size]``.

        ``spike_input`` is batch-first ``[B, T, C, F]``: the layout obtained
        by slicing the output of ``encode_dataset`` along dim 0, which is how
        ``train`` and ``evaluate`` build their batches (T = spike steps,
        C = MFCC coefficients, F = MFCC frames).

        The frame axis is averaged away so the LSTM receives C features per
        step, matching ``input_size = X_train.shape[1]`` in the training
        script. The earlier ``[T, B, C, F]`` handling averaged over C instead
        and handed the LSTM F-wide features with the batch axis misplaced.

        Raises:
            ValueError: if ``spike_input`` is not 4-dimensional.
        """
        if spike_input.dim() != 4:
            raise ValueError(
                f"expected a 4-D [B, T, C, F] tensor, got {spike_input.dim()} dims"
            )
        x = spike_input.mean(dim=3)  # [B, T, C]
        out, _ = self.lstm(x)
        out = self.readout(out[:, -1, :])  # last time step
        return out


### Training and Evaluation

In [None]:
def train(model, optimizer, criterion, X, y, batch_size=64):
    """Run one shuffled optimisation pass over the in-memory tensors ``X``/``y``.

    Samples are visited in a fresh random order each call. Batches are
    gathered by index from the permutation instead of materialising a
    permuted copy of the whole dataset, which would double peak memory for
    the pre-encoded tensor. Batches are moved to the device that holds the
    model's parameters, so the function does not depend on a module-level
    ``device`` variable.

    Args:
        model: Network to train; it is put into training mode.
        optimizer: Optimizer that updates ``model``'s parameters.
        criterion: Callable ``criterion(output, targets)`` returning a scalar.
        X: Input tensor with samples along dim 0.
        y: Target tensor aligned with ``X`` along dim 0.
        batch_size: Number of samples per optimisation step.

    Returns:
        Mean of the per-batch losses as a float (``0.0`` when ``X`` is empty).
    """
    model.train()
    first_param = next(model.parameters(), None)
    num_samples = X.size(0)
    perm = torch.randperm(num_samples)

    total_loss = 0.0
    num_batches = 0
    for start in range(0, num_samples, batch_size):
        batch_idx = perm[start:start + batch_size]
        xb, yb = X[batch_idx], y[batch_idx]
        if first_param is not None:
            xb, yb = xb.to(first_param.device), yb.to(first_param.device)

        out = model(xb)
        loss = criterion(out, yb)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    return total_loss / num_batches if num_batches else 0.0

def evaluate(model, X, y, batch_size=64):
    """Return the classification accuracy of ``model`` on tensors ``X``/``y``.

    The data is processed in order, ``batch_size`` samples at a time, without
    gradient tracking. Batches are moved to the device that holds the model's
    parameters, so the function does not depend on a module-level ``device``
    variable.

    Args:
        model: Network to evaluate; it is put into eval mode.
        X: Input tensor with samples along dim 0.
        y: Target tensor aligned with ``X`` along dim 0.
        batch_size: Number of samples per forward pass.

    Returns:
        Fraction of samples whose arg-max prediction equals the target, or
        ``0.0`` when ``X`` is empty (previously a ZeroDivisionError).
    """
    model.eval()
    first_param = next(model.parameters(), None)
    correct = 0
    total = 0

    with torch.no_grad():
        for start in range(0, len(X), batch_size):
            xb = X[start:start + batch_size]
            yb = y[start:start + batch_size]
            if first_param is not None:
                xb, yb = xb.to(first_param.device), yb.to(first_param.device)
            out = model(xb)
            pred = torch.argmax(out, dim=1)
            correct += (pred == yb).sum().item()
            total += yb.size(0)

    return correct / total if total else 0.0

In [None]:
# LSTM input width = size of dim 1 of X_train, i.e. the MFCC coefficient axis.
# NOTE(review): SpikingLSTM.forward must hand the LSTM features of this width;
# verify its axis handling against the batches that train/evaluate slice from
# X_train_encoded along dim 0.
input_size = X_train.shape[1]  # MFCCs = 13
hidden_size = 128
# One output logit per speaker in the mapping returned for the training set.
output_size = len(speaker_dict)

model = SpikingLSTM(input_size, hidden_size, output_size).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# Train for 10 epochs on the pre-encoded spikes, reporting test accuracy after
# every epoch.
for epoch in range(10):
    train(model, optimizer, criterion, X_train_encoded, y_train)
    acc = evaluate(model, X_test_encoded, y_test)
    print(f"Epoch {epoch+1}: Test Accuracy = {acc:.4f}")
