# Genre Classification with spiking neural networks

## Install SpikingJelly

In [1]:
!pip install spikingjelly

Collecting spikingjelly
  Downloading spikingjelly-0.0.0.0.6-py3-none-any.whl (177 kB)
[?25l[K     |█▉                              | 10 kB 33.0 MB/s eta 0:00:01[K     |███▊                            | 20 kB 24.7 MB/s eta 0:00:01[K     |█████▌                          | 30 kB 17.8 MB/s eta 0:00:01[K     |███████▍                        | 40 kB 15.8 MB/s eta 0:00:01[K     |█████████▎                      | 51 kB 7.4 MB/s eta 0:00:01[K     |███████████                     | 61 kB 7.5 MB/s eta 0:00:01[K     |█████████████                   | 71 kB 7.3 MB/s eta 0:00:01[K     |██████████████▉                 | 81 kB 8.2 MB/s eta 0:00:01[K     |████████████████▋               | 92 kB 9.1 MB/s eta 0:00:01[K     |██████████████████▌             | 102 kB 7.2 MB/s eta 0:00:01[K     |████████████████████▎           | 112 kB 7.2 MB/s eta 0:00:01[K     |██████████████████████▏         | 122 kB 7.2 MB/s eta 0:00:01[K     |████████████████████████        | 133 kB 7.2 MB/s 

## Get data from Google Drive
Click on the link, and follow the steps to get authenticated

In [2]:
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials

auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)

downloaded = drive.CreateFile({'id':"1C6rKHGr-0A8mbzdTbS8Yn_yhtqYMuFl8"})
downloaded.GetContentFile('paper_spectrograms.zip')
!unzip paper_spectrograms.zip
!rm -rf __MACOSX
!rm -f paper_spectrograms.zip

downloaded = drive.CreateFile({'id':"1rsSEkMksELO073o0x5hl_e7pYjlPdyEv"})
downloaded.GetContentFile('paper_labels.npy')

Archive:  paper_spectrograms.zip
  inflating: genres_numpy.npy        


In [3]:
# import all necessary modules
import torch
import torch.nn as nn
import torch.cuda.amp as amp
import numpy as np
import spikingjelly.clock_driven.neuron as neurons
from spikingjelly.clock_driven import functional
from spikingjelly.clock_driven.layer import SeqToANNContainer
from spikingjelly.clock_driven.surrogate import Sigmoid
from sklearn.utils import shuffle
import matplotlib.pyplot as plt
import gc

In [4]:
# set variables such as device (gpu or cpu), batch_size and timesteps
batch_size = 8
timesteps = 4

dtype = torch.float
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [5]:
X_melspec = np.load("./genres_numpy.npy")
X_melspec = np.expand_dims(np.moveaxis(X_melspec, 2, 1), axis=-1)
X_melspec = np.moveaxis(X_melspec, -1, 1)
X_melspec = torch.log(torch.Tensor(X_melspec) + 1)

y = np.load("./paper_labels.npy").squeeze()
y = torch.Tensor(y).type(torch.long)

X_melspec, y = shuffle(X_melspec, y, random_state=0)

print(X_melspec.shape)    #expected: (999, 1, 647, 128)
print(y.shape)    #expected (999,)

torch.Size([999, 1, 647, 128])
torch.Size([999])


## The Model
![picture](https://drive.google.com/uc?export=view&id=1D6074PAg44CdAyxv1qkXdorhRcQNRxwr)
![picture](https://drive.google.com/uc?export=view&id=1OhdyfepB1TjU06iSpc-8ikkgHGXjoEqo)

In [6]:
alpha = 2
v_thresh = 0.75
tau = 1.5

class ExternalNet(nn.Module):
    def __init__(self, alpha, V_th, tau, neuron_type):
        super().__init__()
        self.num_dense_blocks = 3
        self.num_conv_filters = 32
        self.num_classes = 10

        if neuron_type == 'lif':
            self.lifs = nn.ModuleList([neurons.MultiStepLIFNode(
                surrogate_function = Sigmoid(alpha=alpha, spiking=True),
                tau = tau,
                v_threshold=V_th,
                detach_reset=True
            ) for i in range(self.num_dense_blocks + 2)])
        else:
            self.lifs = nn.ModuleList([neurons.MultiStepIFNode(
                surrogate_function = Sigmoid(alpha=alpha, spiking=True),
                v_threshold=V_th,
                detach_reset=True
            ) for i in range(self.num_dense_blocks + 2)])

        self.initial_layers = SeqToANNContainer(
            nn.Conv2d(1, self.num_conv_filters, 3, padding="same"),
            nn.BatchNorm2d(self.num_conv_filters),
            nn.MaxPool2d((4, 1)),
        )

        self.inception_blocks = nn.ModuleList()
        for i in range(self.num_dense_blocks):
            self.inception_blocks.append(self.get_inception_block(i))
        
        self.final_layers = SeqToANNContainer(
            nn.BatchNorm2d((4*self.num_dense_blocks+1)*self.num_conv_filters),
            nn.Conv2d((4*self.num_dense_blocks+1)*self.num_conv_filters, self.num_conv_filters, 1),
            nn.AvgPool2d(self.num_conv_filters),
            nn.BatchNorm2d(self.num_conv_filters),
        )
        self.avgpool = SeqToANNContainer(nn.AdaptiveAvgPool2d((1, 1)))
        self.final_linear = nn.Linear(self.num_conv_filters, self.num_classes)

    def base_conv_block(self, kernel_size, block_num):
        num_channels = self.num_conv_filters * (4*block_num + 1)
        return SeqToANNContainer(
            nn.BatchNorm2d(num_channels),
            nn.ReLU(),
            nn.Conv2d(num_channels, self.num_conv_filters, kernel_size, padding="same")
        )
    
    def base_conv_block_32(self, kernel_size):
        return SeqToANNContainer(
            nn.BatchNorm2d(self.num_conv_filters),
            nn.ReLU(),
            nn.Conv2d(self.num_conv_filters, self.num_conv_filters, kernel_size, padding="same")
        )
    
    def get_inception_block(self, block_num):
        return nn.ModuleList(
            modules=[
                self.base_conv_block(1, block_num),
                nn.Sequential(
                    self.base_conv_block(1, block_num),
                    self.base_conv_block_32(3),
                ),
                nn.Sequential(
                    self.base_conv_block(1, block_num),
                    self.base_conv_block_32(5),
                ),
                nn.Sequential(
                    SeqToANNContainer(nn.MaxPool2d(3, stride=1, padding=1)),
                    self.base_conv_block(1, block_num)
                )
            ]
        )

    def _forward_impl(self, x):
        x = self.initial_layers(x)
        x = self.lifs[0](x)
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        for (i, block) in enumerate(self.inception_blocks):
            out = torch.cat((block[0](x), block[1](x), block[2](x), block[3](x)), dim=2)
            out = self.lifs[i+1](out)
            x = torch.cat([x, out], dim=2)
            del out
        x = self.final_layers(x)
        x = self.avgpool(x)
        x = self.lifs[-1](x)
        x = torch.flatten(x, start_dim=2).mean(dim=0)
        x = self.final_linear(x)
        return x

    def forward(self, x):
        return self._forward_impl(x)

## Training

In [7]:
def num_correct(model_out, labels):
    y_pred = torch.argmax(model_out, 1).to(device)
    return sum(y_pred == labels).detach().item()

def train_one_epoch(epoch, net, train_loader,
                    loss_fn, optimizer, acc_func, scaler,
                    train_loss_hist, train_acc_hist):
    print(f"======== Epoch: {epoch+1} ========")
    net.train()
    for batch_, (x, y_true) in enumerate(train_loader):
        batch = batch_ + 1
        # stack the input over t timesteps
        x = torch.stack([x for _ in range(timesteps)])
        x, y_true = x.to(device), y_true.to(device)
        with amp.autocast():
            y_pred = net(x)
            train_loss = loss_fn(y_pred, y_true)
        train_acc = acc_func(y_pred, y_true)/batch_size
        del x, y_true, y_pred

        # update weights
        optimizer.zero_grad()
        scaler.scale(train_loss).backward()
        scaler.step(optimizer)
        scaler.update()

        functional.reset_net(net)
        train_loss = train_loss.detach().item()

        if batch % 10 == 0:
            print(f"Batch {batch}: Loss = {train_loss}, accuracy = {train_acc}")
            train_loss_hist.append(train_loss)
            train_acc_hist.append(train_acc)
    gc.collect()
    torch.cuda.empty_cache()

def validation(epoch, net, val_loader, loss_fn, acc_func, val_loss_hist, val_acc_hist):
    print("Validation: ", end="")
    with torch.no_grad():
        val_loss, val_acc, model_preds = 0, 0, torch.Tensor([])
        for x, y_true in val_loader:
            x, y_true = x.to(device, non_blocking=True), y_true.to(device, non_blocking=True)
            x = torch.stack([x for _ in range(timesteps)])
            with amp.autocast():
                y_pred = net(x)
                val_loss_temp = loss_fn(y_pred, y_true)
            
            val_loss += val_loss_temp.detach().item()
            val_acc += acc_func(y_pred, y_true)

            functional.reset_net(net)
            y_pred = y_pred.detach().cpu()
            model_preds = torch.cat((model_preds, torch.argmax(y_pred, dim=1)))

        #last batch containis 4 samples, so has 4/16 = 0.25 weightage
        val_loss /= len(val_loader) - 1 + 4/16
        val_acc /= 119 # 100 samples in validation
        val_loss_hist.append(val_loss)
        val_acc_hist.append(val_acc)
        print(f"Loss = {val_loss}, accuracy = {val_acc}")
    gc.collect()
    torch.cuda.empty_cache()
    return val_loss, val_acc, model_preds

In [8]:
epochs = 100
alpha, V_th, tau, neuron_type = 2.0, 0.75, 1.25, 'lif'

X_train, y_train = X_melspec[:880], y[:880]
train = torch.utils.data.TensorDataset(X_train, y_train)
train_loader = torch.utils.data.DataLoader(train, batch_size=batch_size, shuffle=True, drop_last=False)

X_test, y_test = X_melspec[880:], y[880:]
test = torch.utils.data.TensorDataset(X_test, y_test)
test_loader = torch.utils.data.DataLoader(test, batch_size=batch_size, shuffle=False, drop_last=False)

# define net, loss function, optimiizer, and scheduler
net = ExternalNet(alpha=alpha, V_th=V_th, tau=tau, neuron_type=neuron_type).to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)
#scaler for memory management
scaler = amp.GradScaler()

train_loss_hist, train_acc_hist, val_loss_hist, val_acc_hist = [], [], [], []
max_val_acc, min_val_loss, best_preds = 0, 0, None

for epoch in range(epochs):
    train_one_epoch(epoch, net, train_loader,
                    loss_fn, optimizer, num_correct, scaler,
                    train_loss_hist, train_acc_hist)
    val_loss, val_acc, preds = validation(epoch, net, test_loader, loss_fn, num_correct, val_loss_hist, val_acc_hist)
    if val_acc >= max_val_acc:
        max_val_acc = val_acc
        min_val_loss = val_loss
        best_preds = preds
    scheduler.step(val_loss)

print("Best accuracy achieved:", max_val_acc)
print("Corresponding loss: ", min_val_loss)



  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


Batch 10: Loss = 2.16845703125, accuracy = 0.25
Batch 20: Loss = 2.1492919921875, accuracy = 0.25
Batch 30: Loss = 2.0369873046875, accuracy = 0.25
Batch 40: Loss = 1.6343994140625, accuracy = 0.375
Batch 50: Loss = 1.81158447265625, accuracy = 0.25
Batch 60: Loss = 1.71160888671875, accuracy = 0.25
Batch 70: Loss = 1.59613037109375, accuracy = 0.5
Batch 80: Loss = 1.87249755859375, accuracy = 0.25
Batch 90: Loss = 1.6602783203125, accuracy = 0.375
Batch 100: Loss = 1.63128662109375, accuracy = 0.375
Batch 110: Loss = 1.80682373046875, accuracy = 0.5
Validation: Loss = 1.9081780366730272, accuracy = 0.3865546218487395
Batch 10: Loss = 1.46026611328125, accuracy = 0.5
Batch 20: Loss = 1.6534423828125, accuracy = 0.5
Batch 30: Loss = 1.90679931640625, accuracy = 0.5
Batch 40: Loss = 1.78033447265625, accuracy = 0.625
Batch 50: Loss = 1.3926239013671875, accuracy = 0.375
Batch 60: Loss = 2.3389892578125, accuracy = 0.125
Batch 70: Loss = 1.484588623046875, accuracy = 0.625
Batch 80: Loss 