<a href="https://colab.research.google.com/github/inachenyx/SpeechSNN/blob/main/Speech_SNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
! pip install snntorch

Collecting snntorch
  Downloading snntorch-0.9.4-py2.py3-none-any.whl.metadata (15 kB)
Downloading snntorch-0.9.4-py2.py3-none-any.whl (125 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/125.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m125.6/125.6 kB[0m [31m7.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: snntorch
Successfully installed snntorch-0.9.4


In [12]:
import os
import io
import numpy as np
import pandas as pd
import torch
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt

import torch.nn as nn
import snntorch as snn
from snntorch import surrogate
from snntorch import functional as SF

### Upload the data files (manual upload required: `ae.train`, `ae.test`, `size_ae.train`, `size_ae.test`)

In [3]:
# Colab-only: opens a browser file picker. Upload ae.train, ae.test, size_ae.train
# and size_ae.test; each uploaded file is saved under its own name in the working directory.
from google.colab import files
uploaded = files.upload()

Saving ae.train to ae.train
Saving ae.test to ae.test
Saving size_ae.train to size_ae.train
Saving size_ae.test to size_ae.test


In [4]:
# List uploaded files
# Sanity check: the four uploaded data files should now appear in the working directory.
os.listdir()

['.config',
 'size_ae.train',
 'size_ae.test',
 'ae.train',
 'ae.test',
 'sample_data']

### Read the raw data files (as `bytes`)

In [6]:
def _read_raw(path):
    """Return the complete contents of the file at ``path`` as raw bytes."""
    with open(path, "rb") as handle:
        return handle.read()


train_bytes = _read_raw("ae.train")
test_bytes = _read_raw("ae.test")
train_counts_bytes = _read_raw("size_ae.train")
test_counts_bytes = _read_raw("size_ae.test")

# Peek at each payload followed by its type (all are <class 'bytes'>): the ae.* files
# hold space-separated floats, the size_ae.* files hold space-separated counts.
for shown in (
    train_bytes[:100], type(train_bytes),
    test_bytes[:100], type(test_bytes),
    train_counts_bytes, type(train_counts_bytes),
    test_counts_bytes, type(test_counts_bytes),
):
    print(shown)

b'1.860936 -0.207383 0.261557 -0.214562 -0.171253 -0.118167 -0.277557 0.025668 0.126701 -0.306756 -0.2'
<class 'bytes'>
b'1.635533 0.024848 0.432087 -0.361914 -0.074776 -0.693481 -0.229621 0.261503 -0.089421 -0.020431 -0.0'
<class 'bytes'>
b'30 30 30 30 30 30 30 30 30\n'
<class 'bytes'>
b'31 35 88 44 29 24 40 50 29 \n'
<class 'bytes'>


### Get train_counts, test_counts (how many utterances belong to each speaker)

In [7]:
def byte_to_list(byte_data):
    """Decode whitespace-separated integer counts into a list of ints.

    Args:
        byte_data (bytes | bytearray | str): contents of size_ae.train / size_ae.test,
            e.g. b'30 30 30\\n'. Already-decoded text is accepted as well.

    Returns:
        List[int]: the decoded counts in file order (empty list for blank input).
    """
    if isinstance(byte_data, (bytes, bytearray)):
        text = byte_data.decode("utf-8")
    else:
        text = byte_data
    # str.split() with no argument already ignores leading/trailing whitespace and newlines.
    return [int(token) for token in text.split()]

# Per-speaker utterance counts for the train and test splits.
train_counts, test_counts = map(byte_to_list, (train_counts_bytes, test_counts_bytes))
print(train_counts)
print(test_counts)

[30, 30, 30, 30, 30, 30, 30, 30, 30]
[31, 35, 88, 44, 29, 24, 40, 50, 29]


### Build latency-encoded spike tensors

In [9]:
def _split_utterances(text):
    """Split raw text into utterances: runs of non-blank lines separated by blank lines.

    Each line is one frame of whitespace-separated floats; returns a list of float64
    arrays, one (T_i, n_features) array per utterance, in file order.
    """
    utterances, current = [], []
    for line in text.splitlines():  # splitlines() also handles CRLF line endings
        if line.strip():
            current.append([float(value) for value in line.split()])
        elif current:  # a blank (or whitespace-only) line closes the current utterance
            utterances.append(np.array(current))
            current = []
    if current:
        utterances.append(np.array(current))
    return utterances


def _latency_encode(utterance, Tmax, slot_len, max_latency):
    """Binarize one (T, F) utterance and latency-code it into an (F, Tmax * slot_len) array."""
    # Threshold every value against the mean of the whole utterance.
    binary = (utterance >= utterance.mean()).astype(np.float32)  # (T, F)
    norms = np.linalg.norm(binary, axis=1)  # L2 norm of each binarized frame
    lo, hi = norms.min(), norms.max()
    if hi == lo:  # all frames have the same norm: avoid dividing by zero
        latency = np.zeros(len(norms), dtype=int)
    else:
        latency = ((norms - lo) / (hi - lo) * max_latency).astype(int)  # in [0, max_latency]

    encoded = np.zeros((binary.shape[1], Tmax * slot_len), dtype=np.float32)
    # Frame t lands inside its own slot; max_latency < slot_len keeps offsets unique
    # and in bounds, so no clipping is needed.
    offsets = np.arange(len(norms)) * slot_len + latency
    encoded[:, offsets] = binary.T
    return encoded


def prepare_spike_data(byte_data, label_counts, Tmax=None, slot_len=100, max_latency=50):
    """Parse raw utterances, assign speaker labels, then binarize and latency-encode them.

    Utterances are runs of non-blank lines separated by blank lines; each line is one
    frame of whitespace-separated floats. The first ``label_counts[0]`` utterances get
    label 1, the next ``label_counts[1]`` get label 2, and so on.

    Each utterance is thresholded at its own mean. Frame ``t`` is written at time step
    ``t * slot_len + latency``, where the latency in [0, max_latency] grows with the L2
    norm of the binarized frame (min-max scaled within the utterance). Utterances
    shorter than ``Tmax`` frames are zero-padded at the end.

    Args:
        byte_data (bytes | bytearray | str): raw contents of ae.train / ae.test.
        label_counts (List[int]): number of consecutive utterances per speaker.
        Tmax (int, optional): number of frame slots to allocate. Defaults to the
            longest utterance; pass the same value for train and test so both get
            the same number of time steps.
        slot_len (int): time steps per frame slot (default 100).
        max_latency (int): largest latency offset inside a slot (default 50);
            must satisfy 0 <= max_latency < slot_len.

    Returns:
        - torch.Tensor: float32 spikes of shape (N, F, Tmax * slot_len), F = values per frame
        - torch.Tensor: int64 speaker labels of shape (N,), starting at 1
        - int: the number of frame slots (Tmax) actually used

    Raises:
        ValueError: if no utterances are found, the counts do not sum to the number of
            utterances, ``Tmax`` is shorter than the longest utterance, or the latency
            range does not fit inside a slot.
    """
    if isinstance(byte_data, (bytes, bytearray)):
        text = byte_data.decode("utf-8")
    else:
        text = byte_data
    utterances = _split_utterances(text)

    if not utterances:
        raise ValueError("no utterances found in the input data")
    if sum(label_counts) != len(utterances):
        raise ValueError(
            f"label_counts sum to {sum(label_counts)} but the data contains "
            f"{len(utterances)} utterances"
        )
    if not 0 <= max_latency < slot_len:
        raise ValueError("max_latency must satisfy 0 <= max_latency < slot_len")

    longest = max(u.shape[0] for u in utterances)
    if Tmax is None:
        Tmax = longest
    elif Tmax < longest:
        raise ValueError(
            f"Tmax={Tmax} is shorter than the longest utterance ({longest} frames)"
        )

    # Speaker ids are 1-based and cover consecutive runs of utterances. The running
    # position in `utterances` is implicit in this flat ordering, so each speaker
    # receives its own utterances rather than re-reading the first ones.
    labels = [
        speaker_id
        for speaker_id, count in enumerate(label_counts, start=1)
        for _ in range(count)
    ]

    encoded = [_latency_encode(u, Tmax, slot_len, max_latency) for u in utterances]
    data_tensor = torch.tensor(np.stack(encoded), dtype=torch.float32)
    label_tensor = torch.tensor(labels, dtype=torch.long)
    return data_tensor, label_tensor, Tmax


In [10]:
X_train, y_train, train_T = prepare_spike_data(train_bytes, train_counts)
X_test, y_test, test_T = prepare_spike_data(test_bytes, test_counts)

# NOTE(review): train_T and test_T differ (26 vs 29 in the recorded output), so the test
# inputs are simulated for more time steps than the training inputs; consider encoding
# both splits with a common maximum length.
print(X_train.shape)  # (270, 12, train_T * 100)
print(y_train.shape)  # (270,)
print(X_test.shape)  # (370, 12, test_T * 100)
print(y_test.shape)  # (370,)
print(train_T) # 26
print(test_T) # 29

torch.Size([270, 12, 2600])
torch.Size([270])
torch.Size([370, 12, 2900])
torch.Size([370])
26
29


### Visualize latency-encoded spike trains

**TODO:** this section has no code yet.

### Training with snnTorch

In [11]:
# All tensors are shaped: (N, 12, Tmax*100)
# Labels are integers from 1 to 9 (convert to 0–8 for classification).
# Guarded and rebound (not `-=` in place) so re-running this cell cannot shift the
# labels twice and produce -1 labels. The train split contains speaker 1, so its
# minimum tells us whether the shift has already been applied.
if int(y_train.min()) >= 1:
    y_train = y_train - 1
    y_test = y_test - 1
assert int(y_train.min()) == 0 and int(y_test.min()) >= 0

In [None]:
import torch
import torch.nn as nn
import snntorch as snn
from snntorch import surrogate
from snntorch import functional as SF

# Network hyper-parameters
num_inputs, num_steps = X_train.shape[1:]  # X_train is (N, 12, Tmax * 100)
beta = 0.9        # membrane decay of the leaky neurons
num_hidden = 100  # size of the hidden spiking layer
num_outputs = 9   # one output neuron per speaker

# Define the network
class SNN(nn.Module):
    """Two-layer fully connected spiking network: Linear -> Leaky -> Linear -> Leaky.

    Layer sizes and membrane decay default to the notebook-level ``num_inputs``,
    ``num_hidden``, ``num_outputs`` and ``beta`` (read when the model is constructed).
    """

    def __init__(self, n_inputs=None, n_hidden=None, n_outputs=None, decay=None):
        super().__init__()
        n_inputs = num_inputs if n_inputs is None else n_inputs
        n_hidden = num_hidden if n_hidden is None else n_hidden
        n_outputs = num_outputs if n_outputs is None else n_outputs
        decay = beta if decay is None else decay

        self.fc1 = nn.Linear(n_inputs, n_hidden)
        self.lif1 = snn.Leaky(beta=decay, spike_grad=surrogate.fast_sigmoid())
        self.fc2 = nn.Linear(n_hidden, n_outputs)
        self.lif2 = snn.Leaky(beta=decay, spike_grad=surrogate.fast_sigmoid())

    def forward(self, x):
        """Simulate the network over every time step of ``x``.

        Args:
            x (torch.Tensor): input spikes of shape (batch, n_inputs, num_steps).

        Returns:
            torch.Tensor: output spikes of shape (num_steps, batch, n_outputs).
        """
        batch_size = x.size(0)
        # Every call starts from zero membrane potential. NOTE(review): in snnTorch 0.9
        # Leaky keeps its membrane as internal state and only overwrites it when `mem` is
        # given, so passing mem=None on the first step appears to resume from the previous
        # batch's membrane (and its autograd graph) — confirm against the installed version.
        mem1 = torch.zeros(batch_size, self.fc1.out_features, dtype=x.dtype, device=x.device)
        mem2 = torch.zeros(batch_size, self.fc2.out_features, dtype=x.dtype, device=x.device)

        spk2_rec = []
        for step in range(x.size(2)):  # loop through time steps
            cur1 = self.fc1(x[:, :, step])
            spk1, mem1 = self.lif1(cur1, mem1)

            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)

            spk2_rec.append(spk2)

        return torch.stack(spk2_rec, dim=0)

# Seed torch's global RNG so the weight initialisation is reproducible on re-runs.
torch.manual_seed(0)
model = SNN()


In [None]:
from torch.utils.data import TensorDataset, DataLoader

NUM_EPOCHS = 10
BATCH_SIZE = 64
LEARNING_RATE = 1e-3

# A dedicated, seeded generator makes the shuffling order reproducible.
train_loader = DataLoader(
    TensorDataset(X_train, y_train),
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(0),
)
test_loader = DataLoader(TensorDataset(X_test, y_test), batch_size=BATCH_SIZE)

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Spike-count MSE loss targeting the correct output neuron firing on every step.
# NOTE(review): a softer target (e.g. correct_rate=0.8, incorrect_rate=0.2) may train
# more easily with such sparse latency-coded input — consider tuning.
loss_fn = SF.mse_count_loss(correct_rate=1.0)

for epoch in range(NUM_EPOCHS):
    model.train()  # evaluation code may have switched the model to eval mode
    running_loss, n_samples = 0.0, 0
    for x, y in train_loader:
        spk_out = model(x)  # (num_steps, batch, num_outputs)
        loss_val = loss_fn(spk_out, y)
        optimizer.zero_grad()
        # NOTE(review): retain_graph=True is only needed if the model's forward pass
        # carries membrane state (and its autograd graph) over from the previous batch;
        # once forward starts from fresh zero membranes it can be dropped to save memory.
        loss_val.backward(retain_graph=True)
        optimizer.step()
        running_loss += loss_val.item() * y.size(0)
        n_samples += y.size(0)
    # Report the sample-weighted average over the epoch, not just the (small) last batch.
    print(f"Epoch {epoch}: loss = {running_loss / n_samples:.4f}")

In [None]:
def evaluate(model, loader):
    """Return the classification accuracy of ``model`` over ``loader``.

    The predicted class is the output neuron with the most spikes summed over time.
    The model's train/eval mode is restored afterwards.

    Args:
        model (torch.nn.Module): maps an input batch to output spikes shaped
            (num_steps, batch, num_classes).
        loader (Iterable[tuple[Tensor, Tensor]]): yields ``(x, y)`` batches with
            integer class labels ``y`` of shape (batch,).

    Returns:
        float: fraction of correctly classified samples, in [0, 1].

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for x, y in loader:
                spk_count = model(x).sum(dim=0)  # sum over time -> (batch, num_classes)
                # NOTE(review): a sample whose output neurons never spike has an all-zero
                # count, and argmax then resolves the tie to class 0.
                preds = spk_count.argmax(dim=1)
                correct += (preds == y).sum().item()
                total += y.size(0)
    finally:
        model.train(was_training)
    if total == 0:
        raise ValueError("evaluate() received a loader that yielded no samples")
    return correct / total

acc = evaluate(model, loader=test_loader)
print("Test Accuracy: {:.2f}%".format(100 * acc))
