In [1]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader
#from spikingjelly.activation_based import neuron, functional, surrogate, layer
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter




In [8]:
from datasets import load_dataset
import torchaudio
from torch.nn.utils.rnn import pad_sequence
from sklearn.preprocessing import LabelEncoder

#data = load_dataset('google/speech_commands', 'v0.01', split=['train','test'], trust_remote_code=True)
# Dataset location is machine-specific; change DATA_ROOT for your environment.
DATA_ROOT = 'C:/Users/wasse/hello'
train_set = torchaudio.datasets.SPEECHCOMMANDS(root=DATA_ROOT, download=True, subset='training')
test_set = torchaudio.datasets.SPEECHCOMMANDS(root=DATA_ROOT, download=True, subset='testing')

max_length = 16000  # fixed clip length in samples used by collate_fn

# Fit the label encoder on the training-set label strings (item index 2).
# SPEECHCOMMANDS.get_metadata returns the same tuple as __getitem__ but with the
# file path instead of the decoded waveform, so collecting labels does not have
# to load every audio file. Older torchaudio versions lack it; fall back then.
if hasattr(train_set, 'get_metadata'):
    all_labels = [train_set.get_metadata(i)[2] for i in range(len(train_set))]
else:
    all_labels = [item[2] for item in train_set]
label_encoder = LabelEncoder()
encoded_labels = label_encoder.fit_transform(all_labels)


def collate_fn(batch, max_length=16000, encoder=None):
    """Collate SPEECHCOMMANDS items into a fixed-length waveform batch.

    Each item is read as ``item[0]`` (waveform tensor) and ``item[2]``
    (string label). Multi-channel waveforms are averaged to mono. Every clip
    is then truncated or zero-padded to exactly ``max_length`` samples,
    independent of the other clips in the batch. A fixed length matters
    downstream: the spike encoder assumes 64 STFT frames, which only holds
    for 16000-sample inputs.

    Args:
        batch: sequence of dataset items.
        max_length: number of samples in every output row.
        encoder: object with a ``transform`` method mapping label strings to
            integer ids; defaults to the module-level ``label_encoder``.

    Returns:
        (waveforms, labels): a tensor of shape (len(batch), max_length) and a
        long tensor of shape (len(batch),).
    """
    if encoder is None:
        encoder = label_encoder

    dtype = batch[0][0].dtype if len(batch) else torch.float32
    waveforms_padded = torch.zeros(len(batch), max_length, dtype=dtype)
    for row, item in enumerate(batch):
        waveform = item[0]
        # Collapse channel dimensions *before* truncating so the slice below
        # always acts on the time axis.
        while waveform.ndim > 1:
            waveform = waveform.mean(dim=0)
        n = min(waveform.shape[0], max_length)
        waveforms_padded[row, :n] = waveform[:n]

    labels = torch.as_tensor(encoder.transform([item[2] for item in batch]), dtype=torch.long)
    return waveforms_padded, labels


# Training drops the ragged final batch: a final batch of size 1 would make
# BatchNorm1d fail in training mode. Evaluation must see every test clip, so
# the test loader keeps its last partial batch.
train_loader = DataLoader(train_set, batch_size=64, shuffle=True, collate_fn=collate_fn, drop_last=True)
test_loader = DataLoader(test_set, batch_size=64, shuffle=False, collate_fn=collate_fn, drop_last=False)

In [3]:
from scipy.signal.windows import exponential, gaussian
from scipy.signal import square, ShortTimeFFT

sample_rate=16000
g_std = 10      # standard deviation for Gaussian window in samples
win_size = 40   # window size in samples
win_gauss = gaussian(win_size, std=g_std, sym=True)  # symmetric Gaussian wind.
SFT = ShortTimeFFT(win_gauss, hop=2, fs=sample_rate, mfft=2000, scale_to='psd')
batch_size = 64
num_samples = 16000

duration = num_samples / sample_rate

In [4]:
import numpy as np
from scipy.signal import stft
import librosa

"""
def mel_spectrogram(audio, sample_rate, n_mels=128, f_min=0, f_max=None):
  if f_max is None:
    f_max = sample_rate / 2
  _, _, spectrogram = stft(audio, nperseg=512, noverlap=256, fs=sample_rate)
  #print("spectrogram: ", spectrogram.shape)
  # mel_spectrogram = mel(spectrogram, sr=sample_rate, n_mels=n_mels, fmin=f_min, fmax=f_max)
  mel_spectrogram = mel(spectrogram, sr=sample_rate, n_mels=n_mels, fmin=f_min, fmax=f_max)
  return mel_spectrogram
"""

def spectrogram(audio, sample_rate, n_mels=128, f_min=0, f_max=None, nperseg=512, noverlap=256):
  """Complex short-time Fourier transform of ``audio`` (scipy.signal.stft).

  Args:
    audio: 1-D signal accepted by ``scipy.signal.stft``.
    sample_rate: sampling rate in Hz, forwarded as ``fs``.
    n_mels, f_min, f_max: accepted for call compatibility with mel-based
      variants; ignored here.
    nperseg, noverlap: STFT segment length and overlap. The defaults
      (512 / 256) give 257 frequency bins x 64 frames for a 16000-sample
      input, which the 64-step spike encoding relies on.

  Returns:
    Complex array of shape (nperseg // 2 + 1, n_frames).
  """
  _, _, stft_matrix = stft(audio, nperseg=nperseg, noverlap=noverlap, fs=sample_rate)
  return stft_matrix


def mel(spectrogram, sr=44100, n_mels=128, fmin=0, fmax=None):
  """Project a precomputed spectrogram onto ``n_mels`` mel bands via librosa.

  Thin wrapper around ``librosa.feature.melspectrogram(S=...)``. The default
  ``sr`` is 44100, not the 16 kHz used elsewhere in this notebook, so pass
  ``sr`` explicitly.
  NOTE(review): librosa expects ``S`` to be a real, non-negative (power)
  spectrogram; passing a complex STFT directly is presumably unintended —
  confirm at call sites.
  """
  mel_options = {'sr': sr, 'n_mels': n_mels, 'fmin': fmin, 'fmax': fmax}
  return librosa.feature.melspectrogram(S=spectrogram, **mel_options)

class RBFNetwork:
  """Gaussian radial-basis-function layer with fixed, randomly placed centers.

  Centers are drawn uniformly from [0, 1)^input_dim at construction and never
  trained; ``predict`` returns min-max normalized RBF activations.
  """

  def __init__(self, input_dim, num_centers, sigma, rng=None):
    """
    Args:
      input_dim: dimensionality of each input row.
      num_centers: number of RBF units.
      sigma: Gaussian width shared by all units.
      rng: optional ``np.random.Generator`` for reproducible centers; when
        omitted the global NumPy RNG is used.
    """
    if rng is None:
      self.centers = np.random.rand(num_centers, input_dim)
    else:
      self.centers = rng.random((num_centers, input_dim))
    self.sigma = sigma

  def rbf(self, x):
    """Gaussian activations exp(-||x_i - c_j||^2 / (2 * sigma^2)).

    Args:
      x: array of shape (n, input_dim).

    Returns:
      Array of shape (n, num_centers).
    """
    # Broadcast (n, 1, d) - (1, k, d) -> (n, k, d) instead of looping over rows
    # in Python with np.apply_along_axis.
    diffs = np.asarray(x)[:, np.newaxis, :] - self.centers[np.newaxis, :, :]
    norms = np.linalg.norm(diffs, axis=-1)
    return np.exp(- norms ** 2 / (2 * self.sigma ** 2))

  def predict(self, X):
    """RBF activations of ``X``, min-max normalized to [0, 1] per center.

    Normalization runs over the rows of ``X`` (axis 0). If a center's
    activation is identical for every row (zero range, e.g. an all-silent
    clip), a plain min-max would produce 0/0 = NaN. Those columns are set
    to 0 instead.
    """
    activations = self.rbf(X)
    lo = np.min(activations, axis=0)
    span = np.max(activations, axis=0) - lo
    return np.divide(activations - lo, span, out=np.zeros_like(activations), where=span > 0)

# RBF networks keyed by (input_dim, num_centers, sigma). Reusing one network
# means every clip is encoded against the same random centers.
_RBF_NETWORKS = {}


def rbf_encode_audio(audio, sample_rate, SFT, n_mels=128, num_rbf=160, sigma=1.0, rbf_network=None):
  """Encode ``audio`` as RBF activations over its STFT magnitude frames.

  Previously a new RBFNetwork (fresh random centers) was built on every
  call. Feature j then meant something different for each clip, and the
  same clip was encoded differently on each call. Now one network per
  configuration is created on first use and reused afterwards.

  Args:
    audio: 1-D signal passed to ``spectrogram``.
    sample_rate: sampling rate in Hz.
    SFT: unused (kept for call compatibility).
    n_mels: forwarded to ``spectrogram`` (which ignores it).
    num_rbf: the network uses ``num_rbf // 10`` centers.
    sigma: RBF width.
    rbf_network: optional explicit network to use instead of the cached one.

  Returns:
    (rbf_activations, spec): activations of shape (n_frames, num_rbf // 10)
    and the STFT magnitude of shape (n_freq, n_frames).
  """
  spec = np.abs(spectrogram(audio, sample_rate, n_mels))
  if rbf_network is None:
    key = (spec.shape[0], num_rbf // 10, sigma)
    rbf_network = _RBF_NETWORKS.get(key)
    if rbf_network is None:
      rbf_network = RBFNetwork(*key)
      _RBF_NETWORKS[key] = rbf_network
  # Transpose so STFT frames become rows (the axis predict normalizes over).
  rbf_activations = rbf_network.predict(spec.T)
  return rbf_activations, spec

In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import snntorch as snn
from snntorch import surrogate
from snntorch import utils as sutils
from snntorch.functional import quant

class customSNet(nn.Module):
    """Spiking CNN: two conv+LIF stages followed by three linear+LIF layers.

    The same input ``x`` is presented at each of ``num_steps`` time steps.
    ``forward`` returns the output layer's spike and membrane records stacked
    along a leading time dimension.
    """

    def __init__(self, num_steps, beta, threshold=1.0, spike_grad=snn.surrogate.fast_sigmoid(slope=25), num_class=10):
        super().__init__()
        self.num_steps = num_steps
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.lif1 = snn.Leaky(beta=beta, spike_grad=spike_grad, threshold=threshold)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.lif2 = snn.Leaky(beta=beta, spike_grad=spike_grad, threshold=threshold)
        # 448 = 16 channels x 2 x 14 after two conv(3x3)+pool(2) stages on a
        # 16 x 64 input. Earlier notes: 896 for other spike/rbf layouts, 6720
        # for real spectrograms.
        self.fc1 = nn.Linear(448, 128)
        self.lif3 = snn.Leaky(beta=beta, spike_grad=spike_grad, threshold=threshold)
        self.fc2 = nn.Linear(128, 64)
        self.lif4 = snn.Leaky(beta=beta, spike_grad=spike_grad, threshold=threshold)
        self.fc3 = nn.Linear(64, num_class)
        self.lif5 = snn.Leaky(beta=beta, spike_grad=spike_grad, threshold=threshold)

    def forward(self, x):
        """Simulate the network for ``self.num_steps`` steps on a static input.

        Args:
            x: input batch; presumably (batch, 1, 16, 64) given fc1's 448
               input features — TODO confirm for other encodings.

        Returns:
            (spk_rec, mem_rec): output spikes and membrane potentials, each of
            shape (num_steps, batch, num_class).
        """
        batch_size_curr = x.shape[0]
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()
        mem4 = self.lif4.init_leaky()
        mem5 = self.lif5.init_leaky()

        # x does not change between time steps, so the first conv + pool is
        # loop-invariant: compute it once instead of num_steps times.
        cur1 = self.pool(self.conv1(x))

        spk5_rec = []
        mem5_rec = []
        for _ in range(self.num_steps):
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.pool(self.conv2(spk1))
            spk2, mem2 = self.lif2(cur2, mem2)
            cur3 = self.fc1(spk2.view(batch_size_curr, -1))
            spk3, mem3 = self.lif3(cur3, mem3)
            cur4 = self.fc2(spk3)
            spk4, mem4 = self.lif4(cur4, mem4)
            cur5 = self.fc3(spk4)
            spk5, mem5 = self.lif5(cur5, mem5)

            spk5_rec.append(spk5)
            mem5_rec.append(mem5)

        return torch.stack(spk5_rec), torch.stack(mem5_rec)

# poisson spikes from rbf activity
def encode_to_spikes(data, sr, tau_m=20.0, R=1.0, V_th=1.0, V_reset=0.0):
    batch_size = data.size(0)
    num_features = 16
    num_time_steps = 64
    spikes = torch.zeros(batch_size, num_features, num_time_steps, device=data.device)
    rbfs = []
    for i in range(batch_size):
        rbf_activations, mel_spec = rbf_encode_audio(data[i], sr, SFT=SFT, num_rbf=160)
        rbfs.append(rbf_activations)
        spike_prob_scale = 1.7
        rbf_activations_traversed = rbf_activations.T
        spik_probs = rbf_activations_traversed / np.max(rbf_activations_traversed, axis=1, keepdims=True) * spike_prob_scale
        spike_trains = np.random.poisson(spik_probs[...] * duration, size=rbf_activations_traversed.shape)
        spike_trains = np.clip(spike_trains, 0, 1)
        spikes[i] = torch.from_numpy(spike_trains)

    return spikes
"""
# pseudo spectogram of rbf activity
def encode_to_spikes(data, sr, tau_m=20.0, R=1.0, V_th=1.0, V_reset=0.0):
    batch_size = data.size(0)
    num_features = 25
    num_time_steps = 64
    rbfs = torch.zeros(batch_size, num_features, num_time_steps, device=data.device)
    for i in range(batch_size):
        rbf_activations, mel_spec = rbf_encode_audio(data[i], sr, SFT=SFT)
        rbf_activations_traversed = rbf_activations.T
        rbfs[i] = torch.from_numpy(rbf_activations_traversed)

    return rbfs
"""

"""
# real spectograms
def encode_to_spikes(data, sr, tau_m=20.0, R=1.0, V_th=1.0, V_reset=0.0):
    batch_size = data.size(0)
    num_features = 128
    num_time_steps = 64
    specs = torch.zeros(batch_size, num_features, num_time_steps, device=data.device)
    for i in range(batch_size):
        #rbf_activations, mel_spec = rbf_encode_audio(data[i], sr, SFT=SFT)
        spec = mel_spectrogram(data[i],sr)
        specs[i] = torch.from_numpy(spec)

    return specs
"""

'\n# real spectograms\ndef encode_to_spikes(data, sr, tau_m=20.0, R=1.0, V_th=1.0, V_reset=0.0):\n    batch_size = data.size(0)\n    num_features = 128\n    num_time_steps = 64\n    specs = torch.zeros(batch_size, num_features, num_time_steps, device=data.device)\n    for i in range(batch_size):\n        #rbf_activations, mel_spec = rbf_encode_audio(data[i], sr, SFT=SFT)\n        spec = mel_spectrogram(data[i],sr)\n        specs[i] = torch.from_numpy(spec)\n\n    return specs\n'

In [10]:
# Old version (with bug fixed); feature dimension set to 16 to match the tuning curves


from tqdm import tqdm
from snntorch import functional as SF


num_classes = 35
num_steps = 20
model = customSNet(num_steps = num_steps, beta = 0.9, threshold=1.0, spike_grad=snn.surrogate.fast_sigmoid(slope=25), num_class=num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = torch.nn.CrossEntropyLoss()
sr = 16000  # audio sampling rate in Hz, passed to the spike encoder
batch_size = 64
train_loss_hist = []
train_accu_hist = []
train_accu_hist_temp = []

n_epochs = 10

iterCount = 0
model.train()
for epoch in range(n_epochs):
    epoch_loss_sum = 0.0
    epoch_batches = 0
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        spikes = encode_to_spikes(inputs, sr)
        spike_input = spikes.unsqueeze(1)  # add the channel dim expected by Conv2d

        spk_rec, mem_rec = model(spike_input)
        labels = labels.long()
        # Sum cross-entropy of the output membrane potential over all time steps.
        loss_val = torch.zeros((1), dtype=torch.float)
        for step in range(num_steps):
            loss_val += loss_fn(mem_rec[step], labels)

        # Gradient calculation + weight update
        loss_val.backward()
        optimizer.step()

        # Store loss / accuracy history for later plotting
        batch_loss = loss_val.item()
        train_loss_hist.append(batch_loss)
        epoch_loss_sum += batch_loss
        epoch_batches += 1
        acc = SF.accuracy_rate(spk_rec, labels)
        acc2 = SF.accuracy_temporal(spk_rec, labels)
        train_accu_hist.append(acc)
        train_accu_hist_temp.append(acc2)
        iterCount += 1

    # Mean per-batch loss over the whole epoch.
    avg_loss = epoch_loss_sum / max(epoch_batches, 1)
    print(f' Epoch: {epoch} | Train Loss: {train_loss_hist[-1]:.3f} | Avg Loss: {avg_loss:.3f} | Accuracy (rate): {train_accu_hist[-1]:.3f} | Accuracy (temporal): {train_accu_hist_temp[-1]:.3f} | Iteration: {iterCount}')

KeyboardInterrupt: 

In [11]:
#updated version with batch norm and decaying learning rate

import torch
import torch.nn as nn
import torch.optim as optim
import snntorch as snn
from snntorch import surrogate
from torch.optim import lr_scheduler

class customSNet(nn.Module):
    """Spiking CNN with batch norm: two conv+BN+LIF stages, then FC layers.

    The same input ``x`` is presented at each of ``num_steps`` time steps.
    ``forward`` returns the output layer's spike and membrane records stacked
    along a leading time dimension.
    """

    def __init__(self, num_steps, beta, threshold=1.0, spike_grad=snn.surrogate.fast_sigmoid(slope=25), num_class=35):
        super().__init__()
        self.num_steps = num_steps
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.bn1 = nn.BatchNorm2d(6)  # Batch Norm after conv1
        self.lif1 = snn.Leaky(beta=beta, spike_grad=spike_grad, threshold=threshold)

        self.pool = nn.MaxPool2d(2, 2)

        self.conv2 = nn.Conv2d(6, 16, 3)
        self.bn2 = nn.BatchNorm2d(16)  # Batch Norm after conv2
        self.lif2 = snn.Leaky(beta=beta, spike_grad=spike_grad, threshold=threshold)

        self.fc1 = nn.Linear(448, 128)  # 448 for rbf activations or spikes
        self.bn3 = nn.BatchNorm1d(128)  # Batch Norm after fc1
        self.lif3 = snn.Leaky(beta=beta, spike_grad=spike_grad, threshold=threshold)

        self.fc2 = nn.Linear(128, 64)
        self.bn4 = nn.BatchNorm1d(64)  # Batch Norm after fc2
        self.lif4 = snn.Leaky(beta=beta, spike_grad=spike_grad, threshold=threshold)

        self.fc3 = nn.Linear(64, num_class)
        self.lif5 = snn.Leaky(beta=beta, spike_grad=spike_grad, threshold=threshold)

    def forward(self, x):
        """Simulate the network for ``self.num_steps`` steps on a static input.

        Args:
            x: input batch; presumably (batch, 1, 16, 64) given fc1's 448
               input features — TODO confirm for other encodings.

        Returns:
            (spk_rec, mem_rec): output spikes and membrane potentials, each of
            shape (num_steps, batch, num_class).
        """
        batch_size_curr = x.shape[0]
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()
        mem4 = self.lif4.init_leaky()
        mem5 = self.lif5.init_leaky()

        # x is identical at every step, so conv1 -> bn1 -> pool is computed once.
        # Calling bn1 inside the loop would apply the same running-statistics
        # update num_steps times per batch in training mode. That makes the
        # eval-mode statistics track almost only the latest batch.
        cur1 = self.pool(self.bn1(self.conv1(x)))

        spk5_rec = []
        mem5_rec = []

        for _ in range(self.num_steps):
            spk1, mem1 = self.lif1(cur1, mem1)

            # bn2..bn4 see different spikes at each step, so they stay per-step.
            cur2 = self.pool(self.bn2(self.conv2(spk1)))
            spk2, mem2 = self.lif2(cur2, mem2)

            cur3 = self.bn3(self.fc1(spk2.view(batch_size_curr, -1)))
            spk3, mem3 = self.lif3(cur3, mem3)

            cur4 = self.bn4(self.fc2(spk3))
            spk4, mem4 = self.lif4(cur4, mem4)

            cur5 = self.fc3(spk4)
            spk5, mem5 = self.lif5(cur5, mem5)

            spk5_rec.append(spk5)
            mem5_rec.append(mem5)

        return torch.stack(spk5_rec), torch.stack(mem5_rec)

# Hyperparameters
num_steps = 20
beta = 0.9
num_class = 35
learning_rate = 0.001
num_epochs = 10  # Adjust based on your needs
sr = 16000  # sampling rate (Hz) for encode_to_spikes; set here so this cell is self-contained

# Initialize model, loss, optimizer, and scheduler
model = customSNet(num_steps=num_steps, beta=beta, num_class=num_class)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
scheduler = lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)


# Training loop with per-epoch validation
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    seen = 0
    for data, labels in train_loader:
        spikes = encode_to_spikes(data, sr)
        spike_input = spikes.unsqueeze(1)

        optimizer.zero_grad()
        output, _ = model(spike_input)
        # Average output spikes over time steps (per-class rate) and use as logits.
        output_mean = output.mean(dim=0)
        labels = labels.long()
        loss = loss_fn(output_mean, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * spike_input.size(0)
        seen += spike_input.size(0)

    # Divide by the samples actually seen: train_loader uses drop_last=True,
    # so dividing by len(dataset) would understate the mean loss.
    epoch_loss = running_loss / max(seen, 1)

    # Validation phase
    model.eval()
    val_correct = 0
    val_total = 0
    with torch.no_grad():
        for data, labels in test_loader:
            spikes = encode_to_spikes(data, sr)
            spike_input = spikes.unsqueeze(1)
            output, _ = model(spike_input)
            output_mean = output.mean(dim=0)
            preds = output_mean.argmax(dim=1)
            labels = labels.long()
            val_correct += (preds == labels).sum().item()
            val_total += labels.size(0)

    val_accuracy = val_correct / val_total if val_total else float('nan')

    # Record the learning rate used this epoch before the scheduler advances it.
    current_lr = optimizer.param_groups[0]['lr']
    scheduler.step()

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Val Accuracy: {val_accuracy:.4f}, Learning Rate: {current_lr:.6f}")

KeyboardInterrupt: 