So this is a working convolutional SNN with the Architecture proposed by that one paper I suggested.
This is neither optimized code nor is it compatible with our data yet but I thought it is a reasonable starting place.

So this Network receives the raw data as inputs not spikes. It seems easiest to just put in the spectogram data (what we called voltage) into the LIF neurons instead of using the poisson encoding first. (poisson encoding was at least useful for us to see if we can generate sensible spikes from the data).  

Right now there is a convolution on the input data itself, I would start out by trying to avoid and only convolve in the spike domain, but we'll see. I'm also unsere whether we should give the input timestep by timestep or not. With the setup right now, it expects it all at once.

In [1]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader
from spikingjelly.activation_based import neuron, functional, surrogate, layer
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

In [2]:
from datasets import load_dataset
import torchaudio

train_set = torchaudio.datasets.SPEECHCOMMANDS(root='./data', download=True, subset='training')
test_set = torchaudio.datasets.SPEECHCOMMANDS(root='./data', download=True, subset='testing')

from torch.nn.utils.rnn import pad_sequence
"""
def collate_fn(batch):
    # batch is a list of tuples (waveform, label)
    waveforms,_, labels,_,_ = zip(*batch)
    # Pad sequences dynamically to match the length of the longest in the batch
    waveforms_padded = pad_sequence(waveforms, batch_first=True, padding_value=0)
    labels = torch.tensor(labels)
    return waveforms_padded, labels
"""
max_length = 16000  # or your desired length

from sklearn.preprocessing import LabelEncoder

# Assume `labels` contains the string labels from your dataset
label_encoder = LabelEncoder()
#encoded_labels = label_encoder.fit_transform(labels)



def collate_fn(batch):
    max_length = 16000
    waveforms, labels = zip(*[(item[0].squeeze()[:max_length], item[2]) for item in batch])
    
    # Ensure each waveform is 1D before padding
    waveforms = [waveform if waveform.ndim == 1 else waveform.mean(dim=0) for waveform in waveforms]
    
    #print("Shapes before padding:", [w.shape for w in waveforms])  # Add this line
    
    waveforms_padded = pad_sequence(waveforms, batch_first=True, padding_value=0)

    labels = label_encoder.fit_transform(labels)  # Encode the labels as integers
    
    labels = torch.tensor(labels)
    #print('done')
    return waveforms_padded, labels



train_loader = DataLoader(train_set, batch_size=64, shuffle=True, collate_fn=collate_fn, drop_last=True)
test_loader = DataLoader(test_set, batch_size=64, shuffle=False, collate_fn=collate_fn, drop_last=True)

In [3]:
from scipy.signal.windows import exponential, gaussian
from scipy.signal import square, ShortTimeFFT


sample_rate=16000
g_std = 10      # standard deviation for Gaussian window in samples
win_size = 40   # window size in samples
win_gauss = gaussian(win_size, std=g_std, sym=True)  # symmetric Gaussian wind.
SFT = ShortTimeFFT(win_gauss, hop=2, fs=sample_rate, mfft=2000, scale_to='psd')
batch_size = 64
num_samples = 16000

duration = num_samples / sample_rate

In [4]:
import numpy as np
from scipy.signal import stft
import librosa

def mel_spectrogram(audio, sample_rate, n_mels=128, f_min=0, f_max=None):
  if f_max is None:
    f_max = sample_rate / 2
  _, _, spectrogram = stft(audio, nperseg=512, noverlap=256, fs=sample_rate)
  #print("spectrogram: ", spectrogram.shape)
  # mel_spectrogram = mel(spectrogram, sr=sample_rate, n_mels=n_mels, fmin=f_min, fmax=f_max)
  mel_spectrogram = mel(spectrogram, sr=sample_rate, n_mels=n_mels, fmin=f_min, fmax=f_max)
  return mel_spectrogram

def mel(spectrogram, sr=44100, n_mels=128, fmin=0, fmax=None):
  return librosa.feature.melspectrogram(S=spectrogram, sr=sr, n_mels=n_mels, fmin=fmin, fmax=fmax)

class RBFNetwork:
  def __init__(self, input_dim, num_centers, sigma):
    self.centers = np.random.rand(num_centers, input_dim)  # Initialize centers randomly
    #print("centers: ", self.centers.shape)
    self.sigma = sigma

  def rbf(self, x):
    #print("x: ", x.shape)
    #print("centers: ", self.centers.shape)
    # return np.exp(-np.linalg.norm(x - self.centers, axis=1) ** 2 / (2 * self.sigma ** 2))
    def compute_distances(xi):
            # xi - self.centers creates a new array where each center is subtracted from xi
            # np.linalg.norm(..., axis=1) computes the norm along the axis of the centers
            return np.linalg.norm(xi - self.centers, axis=1)
    norms = np.apply_along_axis(compute_distances, 1, x)
    return np.exp(- norms ** 2 / (2 * self.sigma ** 2))
    
  def predict(self, X):
    #print("## predict ##")
    #print("X: ", X.shape)
    y = self.rbf(X)
    #print("y: ", y.shape)
    # normalize to 0 - 1 along the batch dimension
    y = (y - np.min(y, axis=0)) / (np.max(y, axis=0) - np.min(y, axis=0))
    return y

def rbf_encode_audio(audio, sample_rate, SFT, n_mels=128, num_rbf=256, sigma=1.0):
  mel_spec = mel_spectrogram(audio, sample_rate, n_mels)
  mel_spec = np.abs(mel_spec)
  # mel_spec = SFT.spectrogram(audio)
  #print("shape mel_spec: ", mel_spec.shape)
  rbf_network = RBFNetwork(mel_spec.shape[0], num_rbf//10, sigma)
  rbf_activations = rbf_network.predict(mel_spec.T)  # transpose to get the batch dimension first
  return rbf_activations, mel_spec

In [5]:
# old model,ignore and use the next one

import torch
import torch.nn as nn
import spikingjelly as sj
#from spikingjelly import surrogate, neuron

class SpikingCNN(nn.Module):
    def __init__(self, num_classes, spike_grad=surrogate.ATan(), threshold=1.0, num_time_steps=16000):
        super(SpikingCNN, self).__init__()
        self.num_time_steps = num_time_steps
        self.conv1 = nn.Conv1d(25, 16, kernel_size=3, stride=1, padding=1)  # [batch_size, 16, num_time_steps]
        self.lif1 = neuron.LIFNode(surrogate_function=spike_grad, v_threshold=threshold)
        self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2)  # [batch_size, 16, num_time_steps // 2]
        self.lif2 = neuron.LIFNode(surrogate_function=spike_grad, v_threshold=threshold)
        self.conv2 = nn.Conv1d(in_channels=16, out_channels=32, kernel_size=2, stride=2)
        self.pool2 = nn.MaxPool1d(kernel_size=2, stride=2)
        
        self.lif3 = neuron.LIFNode(surrogate_function=spike_grad, v_threshold=threshold)
        self.fc1 = nn.Linear(256, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        # x should be [batch_size, num_features, num_time_steps]
        print('x0', x.shape)
        x = self.conv1(x)
        print('x1', x.shape)
        x = self.lif1(x)  
        print('x2', x.shape)
        x = self.pool1(x)            
        print('x3', x.shape)
        x = self.conv2(x) 
        print('x4', x.shape) #
        x = self.lif2(x)
        print('x5', x.shape)
        x = self.pool2(x) #
        print('x6', x.shape)
        x = x.view(x.size(0), -1)   
        #print(x.shape)
        x = self.lif3(self.fc1(x))
        x = self.fc2(x)
        print(x.shape)
        return x

def encode_to_spikes(data, sr, tau_m=20.0, R=1.0, V_th=1.0, V_reset=0.0):
    batch_size = data.size(0)
    num_features = 25
    num_time_steps = 64
    spikes = torch.zeros(batch_size, num_features, num_time_steps, device=data.device)
    for i in range(batch_size):
        rbf_activations, mel_spec = rbf_encode_audio(data[i], sr, SFT=SFT)
        spike_prob_scale = 1.7
        rbf_activations_traversed = rbf_activations.T
        spik_probs = rbf_activations_traversed / np.max(rbf_activations_traversed, axis=1, keepdims=True) * spike_prob_scale 
        spike_trains = np.random.poisson(spik_probs[...] * duration, size=rbf_activations_traversed.shape)
        spike_trains = np.clip(spike_trains, 0, 1)
        spikes[i] = torch.from_numpy(spike_trains)
        
    return spikes

def train(model, dataloader, optimizer, loss_fn, num_time_steps, num_classes, batch_size,sr):
    model.train()
    count = 0
    for batch in dataloader:
        inputs, labels = batch
        optimizer.zero_grad()

        spikes = encode_to_spikes(inputs, sr)
        outputs = torch.zeros(inputs.size(0), num_classes, device=inputs.device)  # Accumulate outputs
        
        for t in range(num_time_steps):
            spike_input = spikes[:, :, t] # Shape: [batch_size, 1, num_time_steps]
            #print(spike_input.shape, spike_input.unsqueeze(1).shape)
            output = model(spike_input.unsqueeze(2))
            #output = model(spike_input)
            outputs += output  # Aggregate outputs over time

        # Compute loss and backpropagate
        loss = loss_fn(outputs, labels)
        loss.backward()  # Use surrogate gradients
        optimizer.step()

        # Optionally, print or log training progress
        
        if count % 100 == 0:
            print(count, f"Loss: {loss.item()}")
        functional.reset_net(model)
        count += 1

        


model = SpikingCNN(num_classes=35)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = torch.nn.CrossEntropyLoss()
sr = 16000  # Number of time steps to simulate
batch_size = 64
num_time_steps=64

# Assuming train_loader is correctly defined and provides batches of (inputs, labels)
train(model, train_loader, optimizer, loss_fn, num_time_steps, num_classes=35, batch_size=batch_size,sr=sr)

x0 torch.Size([64, 25, 1])
x1 torch.Size([64, 16, 1])
x2 torch.Size([64, 16, 1])


RuntimeError: max_pool1d() Invalid computed output size: 0

In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import snntorch as snn
from snntorch import surrogate
from snntorch import utils as sutils
from snntorch.functional import quant

class customSNet(nn.Module):
    def __init__(self, num_steps, beta, threshold=1.0, spike_grad=snn.surrogate.fast_sigmoid(slope=25), num_class=10):
        super().__init__()
        self.num_steps = num_steps
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.lif1 = snn.Leaky(beta=beta, spike_grad=spike_grad, threshold=threshold)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.lif2 = snn.Leaky(beta=beta, spike_grad=spike_grad, threshold=threshold)
        self.fc1 = nn.Linear(896, 128)  # use 6720 for real spektogram, 896 for spikes or rbf activity
        #self.fc1 = nn.Linear(6720, 128) # use 6720 for real spektogram, 896 for spikes or rbf activity
        self.lif3 = snn.Leaky(beta=beta, spike_grad=spike_grad, threshold=threshold)
        self.fc2 = nn.Linear(128, 64)
        self.lif4 = snn.Leaky(beta=beta, spike_grad=spike_grad, threshold=threshold)
        self.fc3 = nn.Linear(64, num_class)
        self.lif5 = snn.Leaky(beta=beta, spike_grad=spike_grad, threshold=threshold)

    def forward(self, x):
        # Initialize hidden states and outputs at t=0
        batch_size_curr = x.shape[0]
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky() 
        mem3 = self.lif3.init_leaky()
        mem4 = self.lif4.init_leaky()
        mem5 = self.lif5.init_leaky()

        # Record the final layer
        spk5_rec = []
        mem5_rec = []

        for step in range(self.num_steps):
            #print('x0', x.shape)
            cur1 = self.pool(self.conv1(x))
            #cur1 = self.conv1(x)
            #print('x1', x.shape)
            spk1, mem1 = self.lif1(cur1, mem1)
            #print('x2', spk1.shape, mem1.shape)
            cur2 = self.pool(self.conv2(spk1))
            #cur2 = self.conv2(spk1)
            #print('x3', cur2.shape)
            spk2, mem2 = self.lif2(cur2, mem2)
            #print('x4', spk2.shape, mem2.shape)
            cur3 = self.fc1(spk2.view(batch_size_curr, -1))
            #print('x5', cur3.shape)
            spk3, mem3 = self.lif3(cur3, mem3)
            #print('x6', spk3.shape, mem3.shape)
            cur4 = self.fc2(spk3)
            #print('x7', cur4.shape)
            spk4, mem4 = self.lif4(cur4, mem4)
            #print('x8', spk4.shape, mem4.shape)
            cur5 = self.fc3(spk4)
            #print('x9', cur5.shape)
            spk5, mem5 = self.lif5(cur5, mem5)
            #print('x10', spk5.shape, mem5.shape)
            
            spk5_rec.append(spk5)
            mem5_rec.append(mem5)

        return torch.stack(spk5_rec), torch.stack(mem5_rec)
    
# poisson spikes from rbf activity
def encode_to_spikes(data, sr, tau_m=20.0, R=1.0, V_th=1.0, V_reset=0.0):
    batch_size = data.size(0)
    num_features = 25
    num_time_steps = 64
    spikes = torch.zeros(batch_size, num_features, num_time_steps, device=data.device)
    rbfs = []
    for i in range(batch_size):
        rbf_activations, mel_spec = rbf_encode_audio(data[i], sr, SFT=SFT)
        rbfs.append(rbf_activations)
        spike_prob_scale = 1.7
        rbf_activations_traversed = rbf_activations.T
        spik_probs = rbf_activations_traversed / np.max(rbf_activations_traversed, axis=1, keepdims=True) * spike_prob_scale 
        spike_trains = np.random.poisson(spik_probs[...] * duration, size=rbf_activations_traversed.shape)
        spike_trains = np.clip(spike_trains, 0, 1)
        spikes[i] = torch.from_numpy(spike_trains)
        
    return spikes
"""
# pseudo spectogram of rbf activity
def encode_to_spikes(data, sr, tau_m=20.0, R=1.0, V_th=1.0, V_reset=0.0):
    batch_size = data.size(0)
    num_features = 25
    num_time_steps = 64
    rbfs = torch.zeros(batch_size, num_features, num_time_steps, device=data.device)
    for i in range(batch_size):
        rbf_activations, mel_spec = rbf_encode_audio(data[i], sr, SFT=SFT)
        rbf_activations_traversed = rbf_activations.T
        rbfs[i] = torch.from_numpy(rbf_activations_traversed)
        
    return rbfs
"""

"""    
# real spectograms
def encode_to_spikes(data, sr, tau_m=20.0, R=1.0, V_th=1.0, V_reset=0.0):
    batch_size = data.size(0)
    num_features = 128
    num_time_steps = 64
    specs = torch.zeros(batch_size, num_features, num_time_steps, device=data.device)
    for i in range(batch_size):
        #rbf_activations, mel_spec = rbf_encode_audio(data[i], sr, SFT=SFT)
        spec = mel_spectrogram(data[i],sr)
        specs[i] = torch.from_numpy(spec)
        
    return rbfs
"""

'    \n# real spectograms\ndef encode_to_spikes(data, sr, tau_m=20.0, R=1.0, V_th=1.0, V_reset=0.0):\n    batch_size = data.size(0)\n    num_features = 128\n    num_time_steps = 64\n    specs = torch.zeros(batch_size, num_features, num_time_steps, device=data.device)\n    for i in range(batch_size):\n        #rbf_activations, mel_spec = rbf_encode_audio(data[i], sr, SFT=SFT)\n        spec = mel_spectrogram(data[i],sr)\n        specs[i] = torch.from_numpy(spec)\n        \n    return rbfs\n'

In [10]:
from tqdm import tqdm
import torch
import os
from snntorch import functional as SF


num_classes = 35
num_steps = 10
model = customSNet(num_steps = num_steps, beta = 0.9, threshold=1.0, spike_grad=snn.surrogate.fast_sigmoid(slope=25), num_class=num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = torch.nn.CrossEntropyLoss()
sr = 16000  # Number of time steps to simulate
batch_size = 64
train_loss_hist = []
train_accu_hist = []
train_accu_hist_temp = []
#model.train()

n_epochs = 5


for epoch in range(n_epochs):
    running_loss = 0.0
    correct = 0
    total = 0
    iterCount = 0
    for batch in train_loader:
        inputs, labels = batch
        optimizer.zero_grad()

        spikes = encode_to_spikes(inputs, sr)
        #outputs = torch.zeros(inputs.size(0), num_classes, device=inputs.device)  
        spike_input = spikes.unsqueeze(1)

        model.train()
        spk_rec, mem_rec = model(spike_input)

        loss_val = torch.zeros((1), dtype=torch.float)
        for step in range(num_steps):
            loss_val += loss_fn(mem_rec[step], labels)

        # Gradient calculation + weight update
        loss_val.backward()
        optimizer.step()

        # Store loss history for future plotting
        avg_loss = loss_val.item()/len(train_loader)
        train_loss_hist.append(loss_val.item())
        acc = SF.accuracy_rate(spk_rec, labels) 
        acc2 = SF.accuracy_temporal(spk_rec, labels) 
        train_accu_hist.append(acc)
        train_accu_hist_temp.append(acc2)
        iterCount +=1
    print(f' Epoch: {epoch} | Train Loss: {train_loss_hist[-1]:.3f} | Avg Loss: {avg_loss:.3f} | Accuracy: {train_accu_hist[-1]:.3f} | Accuracy: {train_accu_hist_temp[-1]:.3f} | Iteration: {iterCount}')

KeyboardInterrupt: 