### Settings and functions definition

In [1]:
import os
import pickle

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sn

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split

import random
import json

from sklearn.metrics import confusion_matrix

In [2]:
use_gpu = True
multiple_gpus = False
use_seed = True

In [3]:
if use_gpu:
    if multiple_gpus:
        gpu_sel = 1
        gpu_av = [torch.cuda.is_available() for ii in range(torch.cuda.device_count())]
        if True in gpu_av:
            if gpu_av[gpu_sel]:
                device = torch.device("cuda:"+str(gpu_sel))
            else:
                device = torch.device("cuda:"+str(gpu_av.index(True)))
            torch.cuda.set_per_process_memory_fraction(0.3, device=device)
    else:
        if torch.cuda.is_available():
            device = torch.device("cuda:0")
            torch.cuda.set_per_process_memory_fraction(0.3, device=device) # decrese or drop memory fraction if more is available (the smaller the better)
else:
    device = torch.device("cpu")

In [4]:
if use_seed:
    seed = 42
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

In [5]:
dtype = torch.float

In [6]:
letters = ['Space', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K',
           'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z']

In [7]:
### NOTE: this function simulates classification immediately after reading.
###       To do so, a random sample from data_dict is taken.
###       Once the sample is loaded (or acquired, in the real case), a tensor must be crated with it: torch.tensor(reading[None, :, :],dtype=dtype)
###       Such (3D) tensor can then be fed into the network, which will provide its prediction.

def demo(params, file_name, taxels=None, letter_written=letters):
    
    max_time = int(54*25) #msc
    time_bin_size = int(params['time_bin_size']) # ms
    global time
    time = range(0,max_time,time_bin_size)
    ## Increase max_time to make sure no timestep is cut due to fractional amount of steps
    global time_step
    time_step = time_bin_size*0.001
    global data_steps
    data_steps = len(time)
    
    infile = open(file_name, 'rb')
    data_dict = pickle.load(infile)
    infile.close()
    # Extract data
    data = []
    labels = []
    bins = 1000  # [ms] 
    nchan = len(data_dict[1]['events']) # number of channels/sensors
    global nb_channels
    nb_channels = nchan
    nb_repetitions = 200
    idx = random.randrange(0,len(data_dict)-1)
    dat = data_dict[idx]['events'][:]
    events_array = np.zeros([nchan,round((max_time/time_bin_size)+0.5),2])
    for taxel in range(len(dat)):
        for event_type in range(len(dat[taxel])):
            if dat[taxel][event_type]:
                indx = bins*(np.array(dat[taxel][event_type]))
                indx = np.array((indx/time_bin_size).round(), dtype=int)
                events_array[taxel,indx,event_type] = 1
    if taxels != None:
        events_array = np.reshape(np.transpose(events_array, (1,0,2))[:,taxels,:],(events_array.shape[1],-1))
        selected_chans = 2*len(taxels)
    else:
        events_array = np.reshape(np.transpose(events_array, (1,0,2)),(events_array.shape[1],-1))
        selected_chans = 2*nchan
    reading = events_array
    label = letter_written[letter_written.index(data_dict[idx]['letter'])]
        
    return torch.tensor(reading[None, :, :],dtype=dtype), label

In [8]:
def run_snn(inputs, layers):

    bs = inputs.shape[0]
    h1_from_input = torch.einsum(
        "abc,cd->abd", (inputs.tile((nb_input_copies,)), layers[0]))
    syn = torch.zeros((bs, nb_hidden), device=device, dtype=dtype)
    mem = torch.zeros((bs, nb_hidden), device=device, dtype=dtype)

    out = torch.zeros((bs, nb_hidden), device=device, dtype=dtype)

    # Here we define two lists which we use to record the membrane potentials and output spikes
    mem_rec = []
    spk_rec = []

    # Compute hidden (recurrent) layer activity
    for t in range(nb_steps):
        h1 = h1_from_input[:, t] + torch.einsum("ab,bc->ac", (out, layers[2]))
        mthr = mem-1.0
        out = spike_fn(mthr)
        rst = out.detach()  # We do not want to backprop through the reset

        new_syn = alpha*syn + h1
        new_mem = (beta*mem + syn)*(1.0-rst)

        mem_rec.append(mem)
        spk_rec.append(out)

        mem = new_mem
        syn = new_syn

    # Now we merge the recorded membrane potentials into a single tensor
    mem_rec = torch.stack(mem_rec, dim=1)
    spk_rec = torch.stack(spk_rec, dim=1)

    # Readout layer
    h2 = torch.einsum("abc,cd->abd", (spk_rec, layers[1]))
    flt = torch.zeros((bs, nb_outputs), device=device, dtype=dtype)
    out = torch.zeros((bs, nb_outputs), device=device, dtype=dtype)
    # out is initialized as zeros, so it is fine to start with this
    s_out_rec = [out]
    out_rec = [out]
    for t in range(nb_steps):
        mthr_out = out-1.0
        s_out = spike_fn(mthr_out)
        rst_out = s_out.detach()

        new_flt = alpha*flt + h2[:, t]
        new_out = (beta*out + flt)*(1.0-rst_out)

        flt = new_flt
        out = new_out

        out_rec.append(out)
        s_out_rec.append(s_out)

    out_rec = torch.stack(out_rec, dim=1)
    s_out_rec = torch.stack(s_out_rec, dim=1)
    other_recs = [mem_rec, spk_rec, s_out_rec]
    layers_update = layers

    return out_rec, other_recs, layers_update


In [9]:
def load_layers(file, map_location, requires_grad=True, variable=False):
    
    if variable:
        
        lays = file
        
        for ii in lays:
            ii.requires_grad = requires_grad
    
    else:
        
        lays = torch.load(file, map_location=map_location)
    
        for ii in lays:
            ii.requires_grad = requires_grad
        
    return lays

In [10]:
def build_and_predict(params, x):
    
    x = x.to(device)
    
    global nb_input_copies
    nb_input_copies = params['nb_input_copies']  # Num of spiking neurons used to encode each channel
    global nb_inputs
    nb_inputs  = nb_channels*nb_input_copies
    global nb_hidden
    nb_hidden  = 450
    global nb_outputs
    nb_outputs = len(np.unique(letters))+1
    global nb_steps
    nb_steps = data_steps

    tau_mem = params['tau_mem'] # ms
    tau_syn = tau_mem/params['tau_ratio']
    
    global alpha
    alpha   = float(np.exp(-time_step/tau_syn))
    global beta
    beta    = float(np.exp(-time_step/tau_mem))

    fwd_weight_scale = params['fwd_weight_scale']
    rec_weight_scale = params['weight_scale_factor']*fwd_weight_scale

    # Spiking network
    layers = load_layers('./trained/layers_th1.pt', map_location=device)
    
    # Make predictions
    output, others, _ = run_snn(x,layers)
    
    ### Classification through spikes
    m = torch.sum(others[-1],1) # sum over time
    _, am = torch.max(m,1) # argmax over output units
    #################################
    
    return letters[am.detach().cpu().numpy()[0]], output, others

### Specify data and network parameters

In [11]:
threshold = 1 

file_dir_data = '../../data/reading/'
file_type = 'data'
file_thr = str(threshold)
file_ref = 'Null'
file_name = file_dir_data + file_type + '_th' + file_thr + '_rp' + file_ref

file_dir_params = './net_params/'
param_filename = 'parameters_th1'
file_name_parameters = file_dir_params + param_filename + '.txt'
params = {}
with open(file_name_parameters) as file:
    for line in file:
        (key, value) = line.split()
        if key == 'time_bin_size' or key == 'nb_input_copies':
            params[key] = int(value)
        else:
            params[key] = np.double(value)

In [12]:
class SurrGradSpike(torch.autograd.Function):
    """
    Here we implement our spiking nonlinearity which also implements 
    the surrogate gradient. By subclassing torch.autograd.Function, 
    we will be able to use all of PyTorch's autograd functionality.
    Here we use the normalized negative part of a fast sigmoid 
    as this was done in Zenke & Ganguli (2018).
    """

    scale = params['scale']

    @staticmethod
    def forward(ctx, input):
        """
        In the forward pass we compute a step function of the input Tensor
        and return it. ctx is a context object that we use to stash information which 
        we need to later backpropagate our error signals. To achieve this we use the 
        ctx.save_for_backward method.
        """
        ctx.save_for_backward(input)
        out = torch.zeros_like(input)
        out[input > 0] = 1.0
        return out

    @staticmethod
    def backward(ctx, grad_output):
        """
        In the backward pass we receive a Tensor we need to compute the 
        surrogate gradient of the loss with respect to the input. 
        Here we use the normalized negative part of a fast sigmoid 
        as this was done in Zenke & Ganguli (2018).
        """
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad = grad_input/(SurrGradSpike.scale*torch.abs(input)+1.0)**2
        return grad

spike_fn  = SurrGradSpike.apply

### Use the pre-trained network

In [13]:
layers = load_layers('../../data/trained/layers_th1.pt', map_location=device)

print("Input weights matrix: {}x{}".format(len(layers[0]),len(layers[0][0])))
print("Hidden weights matrix: {}x{}".format(len(layers[2]),len(layers[2][0])))
print("Output weights matrix: {}x{}".format(len(layers[1]),len(layers[1][0])))


Input weights matrix: 96x450
Hidden weights matrix: 450x450
Output weights matrix: 450x28


#### Single read and prediction

In [14]:
### Take an input signal with its label
x,y = demo(params,file_name)
print("Reading:",y)

### And make prediction
pred, output, others = build_and_predict(params, x)
print("Prediction:",pred)

Reading: C
Prediction: C


#### Multiple read and predictions

In [15]:
trials = 1000
count = 0
for ii in range(trials):
    x,y = demo(params,file_name)
    pred, _, _ = build_and_predict(params, x) 
    if pred == y:
        count += 1
        print("Attempt {}: correct! (Reading: {}, Prediction: {})\n".format(str(ii+1),y,pred))
    else:
        print("Attempt {}: wrong (Reading: {}, Prediction: {})\n".format(str(ii+1),y,pred))
acc = np.round(count/trials*100,2)
print("Correct predictions: "+str(acc)+"%")

Attempt 1: correct! (Reading: W, Prediction: W)

Attempt 2: correct! (Reading: V, Prediction: V)

Attempt 3: correct! (Reading: G, Prediction: G)

Attempt 4: correct! (Reading: E, Prediction: E)

Attempt 5: correct! (Reading: R, Prediction: R)

Attempt 6: correct! (Reading: H, Prediction: H)

Attempt 7: correct! (Reading: R, Prediction: R)

Attempt 8: correct! (Reading: U, Prediction: U)

Attempt 9: correct! (Reading: S, Prediction: S)

Attempt 10: correct! (Reading: S, Prediction: S)

Attempt 11: correct! (Reading: Q, Prediction: Q)

Attempt 12: correct! (Reading: M, Prediction: M)

Attempt 13: correct! (Reading: H, Prediction: H)

Attempt 14: correct! (Reading: K, Prediction: K)

Attempt 15: correct! (Reading: M, Prediction: M)

Attempt 16: correct! (Reading: Q, Prediction: Q)

Attempt 17: correct! (Reading: Space, Prediction: Space)

Attempt 18: correct! (Reading: Y, Prediction: Y)

Attempt 19: correct! (Reading: E, Prediction: E)

Attempt 20: correct! (Reading: E, Prediction: E)

A