In [None]:
import snntorch as snn
import pandas as pd
import spectrograms
from sklearn.model_selection import train_test_split
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import torch
import numpy as np
from sklearn.preprocessing import LabelEncoder

In [None]:
# Placeholders for the MFCC train/test datasets -- replace before running.
# Later cells expect each to support `dataset[0] -> (spectrogram, label)` and to
# expose `.tensors` (as a TensorDataset does).
mfcc_train, mfcc_test = ..., ...

### Time-based input, Leaky MLP-SNN, 1 Hidden Layer (FNN)

In [None]:
# Bind the active training/testing sets to generic names so the spectrogram
# type used by the rest of the notebook can be swapped in one place.
train, test = mfcc_train, mfcc_test

# Derive input, step and hidden-layer sizes from a single training sample
# (assumes all samples share the same shape and are normalised).
sample, _ = train[0]                    # Sample in form Time x Frequency Bins

num_inputs = sample.shape[1]            # Number of frequency bins (y-axis): one input neuron per bin
num_steps = sample.shape[0]             # Number of time frames (x-axis): one simulation step per frame
num_hidden = max(1, num_inputs // 2)    # Ideally half the number of inputs (originally 1000); never 0
num_hidden_layers = 1                   # 1 hidden layer with num_hidden nodes
num_ouputs = 2                          # Either Music or Non-Music (misspelled name kept for existing users)
num_outputs = num_ouputs                # Correctly spelled alias of num_ouputs

In [None]:
# Set decay rate and threshold:
# Arbitrary threshold, decay rate set close to 1 for reasonable accuracy -- given delta_t << tau in: beta = (1 - delta_t/tau)
beta = 0.95
threshold = 0.75

class Net(nn.Module):
    """Two-layer leaky integrate-and-fire MLP: input -> LIF hidden layer -> LIF output layer.

    Layer sizes come from the module-level ``num_inputs``, ``num_hidden`` and
    ``num_ouputs``. Both LIF layers use the module-level ``beta`` (membrane
    decay) and ``threshold`` (firing threshold).
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Two forward (linear) connections, each feeding a leaky integrate-and-fire layer.
        # `threshold` must be passed explicitly; otherwise snn.Leaky silently uses
        # its own default threshold instead of the configured value above.
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Leaky(beta=beta, threshold=threshold)
        self.fc2 = nn.Linear(num_hidden, num_ouputs)
        self.lif2 = snn.Leaky(beta=beta, threshold=threshold)

    def forward(self, x):
        """Simulate the network over every time step of ``x``.

        Args:
            x: batch of spectrograms, sliced per time step as ``x[:, step]``.
               Presumably (batch, time, frequency bins) with values normalised
               to [0, 1] -- TODO confirm against the dataset.

        Returns:
            Tuple ``(spk1, mem1, spk2, mem2)``: hidden/output spike and membrane
            records, each stacked with time on dim 0.
        """
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()

        spk1_rec, mem1_rec, spk2_rec, mem2_rec = [], [], [], []

        # Walk the input's own time axis (not the global num_steps) so inputs
        # of any length are simulated in full. At each step the data drives the
        # hidden LIF layer, and its spikes drive the output LIF layer.
        for step in range(x.shape[1]):
            cur1 = self.fc1(x[:, step])
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)

            spk1_rec.append(spk1)
            mem1_rec.append(mem1)
            spk2_rec.append(spk2)
            mem2_rec.append(mem2)

        return (torch.stack(spk1_rec, dim=0), torch.stack(mem1_rec, dim=0),
                torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0))

In [None]:
# Training hyperparameters, both chosen arbitrarily:
#   num_epochs -- depends on how quickly the loss converges
#   batch_size -- ideally based on memory utilisation and speed
num_epochs = 20
batch_size = 20

# Global iteration counter plus per-iteration train/test loss histories
counter = 0
loss_hist, test_loss_hist = [], []

In [None]:
# Use the GPU when one is visible, otherwise fall back to the CPU,
# then build a fresh network on that device.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
net = Net().to(device)

In [None]:
# Helper output functions
def print_batch_accuracy(data, targets, train=False, model=None):
    """Print (and return) a model's spike-count accuracy on one minibatch.

    Each sample's predicted class is the output neuron that spiked most over
    all time steps.

    Args:
        data: minibatch of inputs accepted by the model.
        targets: integer class labels, on the same device as the model output.
        train: selects the "Train" vs "Test" wording of the printed message.
        model: network to evaluate; defaults to the module-level ``net``. It must
            return ``(spk1, mem1, spk2, mem2)`` with output spikes stacked as
            (time, batch, classes).

    Returns:
        The minibatch accuracy as a fraction in [0, 1].
    """
    model = net if model is None else model

    # Evaluation only: skip building an autograd graph, which matters when
    # this helper is called outside a torch.no_grad() block.
    with torch.no_grad():
        _, _, output, _ = model(data)
    _, idx = output.sum(dim=0).max(1)
    acc = np.mean((targets == idx).cpu().numpy())

    if train:
        print(f"Train set accuracy for a single minibatch: {acc*100:.2f}%")
    else:
        print(f"Test set accuracy for a single minibatch: {acc*100:.2f}%")
    return acc

def train_printer(epoch, iter_counter, data, targets, test_data, test_targets):
    """Print a progress report: latest train/test losses and minibatch accuracies.

    Reads the most recent entry of the module-level ``loss_hist`` and
    ``test_loss_hist`` histories instead of indexing them with the global
    ``counter``. The report therefore does not depend on that counter staying
    in lock-step with both histories. A history with no entries yet is shown as
    "n/a" instead of raising IndexError.
    """
    def _latest(history):
        # Format the newest recorded loss, or a placeholder if nothing was recorded.
        return f"{history[-1]:.2f}" if history else "n/a"

    print(f"Epoch {epoch}, Iteration {iter_counter}")
    print(f"Train Set Loss: {_latest(loss_hist)}")
    print(f"Test Set Loss: {_latest(test_loss_hist)}")
    print_batch_accuracy(data, targets, train=True)
    print_batch_accuracy(test_data, test_targets, train=False)
    print("\n")

In [None]:
# Build shuffled minibatch loaders of `batch_size` samples for both splits.
# The test loader is shuffled too, so each fresh `iter(test_loader)` yields a random batch.
train_loader, test_loader = (
    DataLoader(dataset=split, batch_size=batch_size, shuffle=True)
    for split in (train, test)
)

In [None]:
# Initialise loss (CE Loss) and Adam optimiser -- learning rate is a hyperparameter
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=5e-4)


def _time_summed_loss(mem_rec, labels):
    """Cross-entropy of the output membrane record, summed over the first num_steps time steps."""
    total = torch.zeros((1), dtype=torch.float32, device=device)
    for step in range(num_steps):
        total += loss(mem_rec[step], labels)
    return total


# Train for num_epochs
for epoch in range(num_epochs):
    iter_counter = 0

    # Loop through batches: data is [batch_size x times x frequencies], paired with targets
    for data, targets in train_loader:
        data = data.to(device)
        targets = targets.to(device)

        # Forward pass in train mode
        net.train()
        spk1_rec, mem1_rec, spk2_rec, mem2_rec = net(data)

        # Sum loss over time (output membrane record per time step against the batch targets)
        loss_val = _time_summed_loss(mem2_rec, targets)

        # Weight update
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # Record the training loss; train_printer reads this history when reporting.
        loss_hist.append(loss_val.item())

        # Evaluate one random test batch for the test-loss history.
        # Performance boost possible by only evaluating when counter % 50 == 0.
        with torch.no_grad():
            net.eval()
            test_data, test_targets = next(iter(test_loader))
            test_data = test_data.to(device)
            test_targets = test_targets.to(device)

            # Forward pass in eval mode
            test_spk1, test_mem1, test_spk2, test_mem2 = net(test_data)

            # Sum loss over time for the test batch
            test_loss = _time_summed_loss(test_mem2, test_targets)
            test_loss_hist.append(test_loss.item())

            # Print train/test loss/accuracy every 50 iterations
            if counter % 50 == 0:
                train_printer(epoch, iter_counter, data, targets, test_data, test_targets)
            counter += 1
            iter_counter += 1

In [None]:
def _dataset_accuracy(inputs, labels):
    """Fraction of samples whose most-spiking output neuron matches the label.

    The whole split goes through the network in one pass, so it runs in eval
    mode under torch.no_grad(). Tracking gradients here would only build an
    unused autograd graph for the entire dataset.
    """
    net.eval()
    with torch.no_grad():
        _, _, spk_out, _ = net(inputs.to(device))
    # Predicted class = output neuron with the highest spike count over time
    _, predicted = spk_out.sum(dim=0).max(1)
    return np.mean((labels.to(device) == predicted).cpu().numpy())


# Get all train and test data (assumes both splits are TensorDatasets)
data, targets = train.tensors[0], train.tensors[1]
test_data, test_targets = test.tensors[0], test.tensors[1]

# Overall train accuracy
acc = _dataset_accuracy(data, targets)
print(acc)

# Overall test accuracy
acc = _dataset_accuracy(test_data, test_targets)
print(acc)