# Sound localisation with surrogate gradient descent

In [None]:
import os

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

import torch
import torch.nn as nn

# Floating-point type used for the tensors created in this notebook
dtype = torch.float

# Run on the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Sound localisation stimuli

The following function creates a set of stimuli that can be used for training or testing. It returns two arrays ``ipd`` and ``spikes``. ``ipd`` is an array of length ``num_samples`` that gives the true IPD (interaural phase difference), and ``spikes`` is an array of 0 (no spike) and 1 (spike) of shape ``(num_samples, duration_steps, 2*anf_per_ear)``, where ``anf_per_ear`` is the number of neurons per ear that receive the same signal (but generate independent Poisson noise), and ``duration_steps`` is the number of time steps in the stimulus.

In [None]:
# Unit helpers: times are expressed in seconds, frequencies in hertz
second = 1
ms = 1e-3
Hz = 1

# Stimulus and simulation parameters
dt = 1 * ms              # large time step so the tutorial simulations run faster
anf_per_ear = 100        # number of input neurons per ear, each with independent noise
envelope_power = 4       # higher values give sharper envelopes (an easier task)
rate_max = 1000 * Hz     # maximum Poisson firing rate
f = 20 * Hz              # stimulus frequency
duration = .1 * second   # stimulus duration

# Generate an input signal (spike array) from array of true IPDs
def input_signal(ipd):
    """Draw a random spike array for each IPD in ``ipd``.

    The stimulus is a sinusoid of frequency ``f`` sampled every ``dt`` for
    ``duration``. The first ``anf_per_ear`` channels see the stimulus phase
    plus a per-channel delay spread linearly over [0, pi/2]; the remaining
    ``anf_per_ear`` channels see the stimulus phase plus the sample's IPD
    plus the same delays in reversed order.

    Args:
        ipd: 1-D NumPy array of phase differences, one per sample.

    Returns:
        Boolean array of shape ``(len(ipd), duration_steps, 2*anf_per_ear)``
        that is True wherever a spike was drawn.
    """
    n_steps = int(np.round(duration/dt))
    n_trials = len(ipd)

    # Stimulus phase at every time step, shaped to broadcast over
    # (sample, time, channel)
    times = np.arange(n_steps)*dt
    carrier = (2*np.pi*f*times)[np.newaxis, :, np.newaxis]
    delays = np.linspace(0, np.pi/2, anf_per_ear)

    theta = np.empty((n_trials, n_steps, 2*anf_per_ear))
    theta[:, :, :anf_per_ear] = carrier+delays[np.newaxis, np.newaxis, :]
    theta[:, :, anf_per_ear:] = carrier+ipd[:, np.newaxis, np.newaxis]+delays[np.newaxis, np.newaxis, ::-1]

    # Per-step spike probability: the raised-sine envelope scaled by rate_max*dt
    p_spike = rate_max*dt*(0.5*(1+np.sin(theta)))**envelope_power
    return np.random.rand(n_trials, n_steps, 2*anf_per_ear) < p_spike

# Generate some true IPDs from U(-pi/2, pi/2) and corresponding spike arrays
def random_ipd_input_signal(num_samples):
    """Sample ``num_samples`` random IPDs and build their spike arrays.

    Args:
        num_samples: number of stimuli to generate.

    Returns:
        Tuple ``(ipd, spikes)``: the sampled IPDs (length ``num_samples``)
        and the result of ``input_signal(ipd)``.
    """
    # Scale and shift U(0, 1) samples so they cover (-pi/2, pi/2)
    ipd = np.random.rand(num_samples)*np.pi-np.pi/2
    return ipd, input_signal(ipd)

# Plot a few example stimuli (2 rows x 4 columns) to show what they look like
ipd, spikes = random_ipd_input_signal(8)
plt.figure(figsize=(10, 4), dpi=100)
for i in range(8):
    plt.subplot(2, 4, i+1)
    # Transpose so that time runs along x and input neuron index along y
    raster = spikes[i].T
    plt.imshow(raster, cmap=plt.cm.gray_r, interpolation='nearest', aspect='auto')
    ipd_deg = int(ipd[i]*180/np.pi)
    plt.title(f'True IPD = {ipd_deg} deg')
    # Only label the bottom row and the first column
    if i>=4:
        plt.xlabel('Time (steps)')
    if i%4==0:
        plt.ylabel('Input neuron index')
plt.tight_layout()

Now with this, the aim is to take these input spikes and infer the IPD. We can do this either by discretising and using a classification approach, or with a regression approach. For the moment, let's try it with a classification approach.

## Classification approach

TODO:

* Conversion function
* Parameter of number of classes

## Surrogate gradient descent

First, this is the key part of surrogate gradient descent: a function where we override the computation of the gradient to replace it with a smoothed gradient. You can see that in the forward pass (method ``forward``) it returns the Heaviside function of the input (value 1 if the input is ``>0``, and value 0 otherwise). In the backward pass (method ``backward``), it returns the gradient of a (fast) sigmoid function, ``1/(scale*|input|+1)**2``, instead.

In [None]:
class SurrGradSpike(torch.autograd.Function):
    """Step nonlinearity with a hand-written smooth gradient.

    The forward pass is a hard threshold at zero; the backward pass
    replaces the threshold's derivative with
    ``1/(scale*|input|+1)**2``.
    """

    # Controls steepness of surrogate gradient
    scale = 100.0

    @staticmethod
    def forward(ctx, input):
        """Return 1 where ``input > 0`` and 0 elsewhere, keeping input's dtype."""
        # The input is needed again in backward to evaluate the surrogate
        ctx.save_for_backward(input)
        return (input > 0).to(input.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        """Scale the incoming gradient by the surrogate derivative."""
        (saved_input,) = ctx.saved_tensors
        denominator = (SurrGradSpike.scale*saved_input.abs()+1.0)**2
        return grad_output/denominator


spike_fn = SurrGradSpike.apply

## Membrane only (no spiking neurons)

In [None]:
# Parameters for training. These aren't optimal, but instead designed
# to give a reasonable result in a small amount of time for the tutorial!
batch_size = 128
n_training_batches = 128
n_testing_batches = 32
num_classes = 180//15  # classes at 15 degree increments
num_samples = batch_size*n_training_batches

# Generate the training data
# spikes has shape (num_samples, duration_steps, 2*anf_per_ear), ipds has shape (num_samples,)
ipds, spikes = random_ipd_input_signal(num_samples)
# Discretise the IPDs: shift from (-pi/2, pi/2) to (0, pi), rescale to
# (0, num_classes) and let the cast to long truncate to an integer class index
ipds = torch.tensor((ipds+np.pi/2)*num_classes/(np.pi), device=device, dtype=torch.long)
spikes = torch.tensor(spikes, device=device, dtype=dtype)

_, nb_steps, input_size = spikes.shape
output_size = num_classes

# Filter parameters: per-step decay factor of the leaky membrane
time_step = 1e-3
tau = 20e-3
alpha = np.exp(-time_step / tau)

# Weights and uniform weight initialisation in [-1/sqrt(fan_in), 1/sqrt(fan_in)]
W = nn.Parameter(torch.empty((input_size, output_size), device=device, dtype=dtype, requires_grad=True))
# NOTE(review): for a 2-D tensor PyTorch's helper appears to take dim 1 as
# fan-in, which here would be output_size rather than input_size — confirm
# that this is the intended initialisation scale.
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(W)
# Use the already-imported numpy for the square root: ``math`` is not
# imported by this notebook, so ``math.sqrt`` raised a NameError here.
bound = 1 / float(np.sqrt(fan_in))
nn.init.uniform_(W, -bound, bound)


def data_generator(ipds, spikes):
    """Yield shuffled ``(spikes, ipds)`` mini-batches of ``batch_size`` samples.

    Both tensors are reordered with the same random permutation along their
    first axis and then cut into consecutive batches. Samples left over
    after the last full batch are not yielded.

    Args:
        ipds: tensor of labels, indexed by sample along axis 0.
        spikes: 3-D tensor of inputs, indexed by sample along axis 0.

    Yields:
        Tuples ``(x_local, y_local)`` holding one batch of inputs and the
        matching labels.
    """
    n = spikes.shape[0]
    perm = torch.randperm(n)
    spikes = spikes[perm, :, :]
    ipds = ipds[perm]
    n_full = (n//batch_size)*batch_size
    for start in range(0, n_full, batch_size):
        stop = start+batch_size
        yield spikes[start:stop, :, :], ipds[start:stop]

        
def snn(input_spikes):
    """Run the membrane-only (non-spiking) readout layer on a batch of inputs.

    The input is projected through the weight matrix ``W`` and low-pass
    filtered one step at a time with decay factor ``alpha``:
    ``mem <- alpha*mem + (1-alpha)*h[:, t, :]``.

    The batch size and number of time steps are read from ``input_spikes``
    itself (rather than from the notebook globals), so the function works
    for any batch size and stimulus duration.

    Args:
        input_spikes: tensor of shape (batch, steps, input_size).

    Returns:
        Tensor of shape (batch, steps, output_size) with the membrane value
        at every time step. The first slice along the time axis is the
        all-zero initial state; the input at the final time step is not
        integrated.
    """
    n_batch, n_steps, _ = input_spikes.shape

    # Batch matrix multiplication for all time steps at once.
    # Equivalent to input_spikes[b, :, :] @ W for every b, but faster.
    h = torch.einsum("abc,cd->abd", input_spikes, W)

    # Initial membrane state, on the same device and dtype as the projected input
    mem = torch.zeros((n_batch, W.shape[1]), device=h.device, dtype=h.dtype)

    # mem_rec stores the membrane at each time step, starting with the initial state
    mem_rec = [mem]

    # Update the membrane one time step at a time
    for t in range(n_steps - 1):
        mem = alpha * mem + (1. - alpha) * h[:, t, :]
        mem_rec.append(mem)

    return torch.stack(mem_rec, dim=1)  # (batch, steps, output_size)

# Training parameters
nb_epochs = 10
lr = 0.01

# Optimiser and loss function: log-softmax followed by negative
# log-likelihood, i.e. a cross-entropy loss over the classes
optimizer = torch.optim.Adam([W], lr=lr)
log_softmax_fn = nn.LogSoftmax(dim=1)
loss_fn = nn.NLLLoss()
# TODO: for the regression approach, use a mean squared error loss instead

print(f"Want loss for epoch 1 to be about {-np.log(1/num_classes):.2f}, multiply m by constant to get this")

loss_hist = []
for e in range(nb_epochs):
    local_loss = []
    for x_local, y_local in data_generator(ipds, spikes):
        # Forward pass through the network
        output = snn(x_local)

        # Sum the membrane over the time dimension; the constant factor
        # rescales the result before the softmax
        m = 0.01*output.sum(dim=1)

        # Cross entropy loss
        loss_ce = loss_fn(log_softmax_fn(m), y_local)

        # Regularisation loss (none is applied at the moment)
        loss_reg = 0
        loss = loss_ce + loss_reg

        # Gradient step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        local_loss.append(loss.item())

    epoch_loss = np.mean(local_loss)
    loss_hist.append(epoch_loss)
    print("Epoch %i: loss=%.5f"%(e+1, epoch_loss))

In [None]:
# Plot the mean training loss recorded for each epoch (x axis is the
# zero-based position in loss_hist)
plt.plot(loss_hist)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.tight_layout()

In [None]:
# Train Accuracy: fraction of correctly classified samples, averaged over batches
accs_train = []
for x_local, y_local in data_generator(ipds, spikes):
    output = snn(x_local)
    m = output.sum(dim=1)  # Sum time dimension
    am = torch.max(m, 1)[1]  # argmax over output units
    correct = (y_local == am).detach().cpu().numpy()  # compare to labels
    tmp = np.mean(correct)
    accs_train.append(tmp)
print(f"Train Accuracy: {100*np.mean(accs_train):.1f}%")

# Test Accuracy: evaluate on freshly generated stimuli.
# random_ipd_input_signal takes only the number of samples; the previous
# ``**params`` argument referred to an undefined name and raised a NameError.
ipds_test, spikes_test = random_ipd_input_signal(batch_size*n_testing_batches)
# Same discretisation into class indices as for the training labels
ipds_test = torch.tensor((ipds_test+np.pi/2)*num_classes/(np.pi), device=device, dtype=torch.long)
spikes_test = torch.tensor(spikes_test, device=device, dtype=dtype)
accs_test = []
ipd_true = []
ipd_est = []
# confusion[j, i] counts samples of true class i that were classified as j
confusion = np.zeros((num_classes, num_classes))
for x_local, y_local in data_generator(ipds_test, spikes_test):
    output = snn(x_local)
    m = torch.sum(output, 1)  # Sum time dimension
    _, am = torch.max(m, 1)  # argmax over output units
    tmp = np.mean((y_local == am).detach().cpu().numpy())  # compare to labels
    for i, j in zip(y_local.detach().cpu().numpy(), am.detach().cpu().numpy()):
        confusion[j, i] += 1
    # Map class indices back to phases (in radians) for the error estimate
    ipd_true.append(y_local.detach().cpu().numpy()/num_classes*np.pi-np.pi/2)
    ipd_est.append(am.detach().cpu().numpy()/num_classes*np.pi-np.pi/2)
    accs_test.append(tmp)
ipd_true = np.hstack(ipd_true)
ipd_est = np.hstack(ipd_est)
abs_errors_deg = abs(ipd_true-ipd_est)*180/np.pi
print(f"Test Accuracy: {100*np.mean(accs_test):.1f}%")
print(f"Chance level: {100*1/num_classes:.1f}%")
print(f"Absolute error: {np.mean(abs_errors_deg):.1f} deg")

plt.figure(figsize=(10, 4), dpi=100)

# Left panel: distribution of true and estimated IPDs, in degrees
plt.subplot(1, 2, 1)
plt.hist(ipd_true*180/np.pi, bins=num_classes, label='True')
plt.hist(ipd_est*180/np.pi, bins=num_classes, label='Estimated')
plt.xlabel("IPD")
plt.yticks([])
plt.legend(loc='best')

# Right panel: confusion matrix, with each column scaled to sum to one
plt.subplot(1, 2, 2)
confusion /= np.sum(confusion, axis=0, keepdims=True)
plt.imshow(confusion, extent=(-90, 90, -90, 90), origin='lower', aspect='auto', interpolation='nearest')
plt.xlabel('True IPD')
plt.ylabel('Estimated IPD')
plt.title('Confusion matrix')
plt.tight_layout()

In [None]:
# Show the weight matrix as an image: one row per input neuron, one column
# per output unit
weights = W.detach().cpu().numpy()
plt.imshow(weights, origin='lower', aspect='auto', interpolation='nearest')