In [1]:
%pylab inline

Populating the interactive namespace from numpy and matplotlib


In [2]:
# bi-directional srnn within pkg

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import StepLR,MultiStepLR
import math
import torch.nn.functional as F
from torch.utils import data

from SRNN_layers.spike_dense import *#spike_dense,readout_integrator
from SRNN_layers.spike_neuron import *#output_Neuron
from SRNN_layers.spike_rnn import *# spike_rnn

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('device: ',device)

def normalize(data_set,Vmax,Vmin):
    return (data_set-Vmin)/(Vmax-Vmin)#+1e-6)

train_data = np.load('./f40/train_f40_t100.npy')
test_data = np.load('./f40/test_f40_t100.npy')
valid_data = np.load('./f40/valid_f40_t100.npy')


num_channels = 39
use_channels = 39
Vmax = np.max(train_data[:,:,:use_channels],axis=(0,1))
Vmin = np.min(train_data[:,:,:use_channels],axis=(0,1))
print(train_data.shape,Vmax.shape,b_j0_value)

train_x = normalize(train_data[:,:,:use_channels],Vmax,Vmin)
train_y = train_data[:,:,num_channels:]

test_x = normalize(test_data[:,:,:num_channels],Vmax,Vmin)
test_y = test_data[:,:,num_channels:]

valid_x = normalize(valid_data[:,:,:num_channels],Vmax,Vmin)
valid_y = valid_data[:,:,num_channels:]

print('input dataset shap: ',train_x.shape)
print('output dataset shap: ',train_y.shape)
_,seq_length,input_dim = train_x.shape
_,_,output_dim = train_y.shape

batch_size =16
# spike_neuron.b_j0_value = 1.59

torch.manual_seed(0)
def get_DataLoader(train_x,train_y,batch_size=200):
    train_dataset = data.TensorDataset(torch.Tensor(train_x), torch.Tensor(np.argmax(train_y,axis=-1)))
    train_loader = data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    return train_loader

train_loader = get_DataLoader(train_x,train_y,batch_size=batch_size)
test_loader = get_DataLoader(test_x,test_y,batch_size=batch_size)
valid_loader = get_DataLoader(valid_x,valid_y,batch_size=batch_size)

class RNN_s(nn.Module):
    def __init__(self,criterion,device,delay=0):
        super(RNN_s, self).__init__()
        self.criterion = criterion
        self.delay = delay
        
        #self.network = [input_dim,128,128,256,output_dim]
        self.network = [39,256,256,output_dim]


        self.rnn_fw1 = spike_rnn(self.network[0],self.network[1],
                               tau_initializer='multi_normal',
                               tauM=[20,20,20,20],tauM_inital_std=[1,5,5,5],
                               tauAdp_inital=[200,200,250,200],tauAdp_inital_std=[5,50,100,50],
                               device=device)
        
        self.rnn_bw1 = spike_rnn(self.network[0],self.network[2],
                                tau_initializer='multi_normal',
                                tauM=[20,20,20,20],tauM_inital_std=[5,5,5,5],
                                tauAdp_inital=[200,200,150,200],tauAdp_inital_std=[5,50,30,10],
                                device=device)
        

        self.dense_mean = readout_integrator(self.network[2]+self.network[1],self.network[3],
                                    tauM=3,tauM_inital_std=.1,device=device)
        

    def forward(self, input,labels=None):
        b,s,c = input.shape
        self.rnn_fw1.set_neuron_state(b)
        self.rnn_bw1.set_neuron_state(b)
        self.dense_mean.set_neuron_state(b)
        
        loss = 0
        predictions = []
        fw_spikes = []
        bw_spikes = []
        mean_tensor = 0

        for l in range(s*5):
            input_fw=input[:,l//5,:].float()
            input_bw=input[:,-l//5,:].float()

            mem_layer1, spike_layer1 = self.rnn_fw1.forward(input_fw)
            mem_layer2, spike_layer2 = self.rnn_bw1.forward(input_bw)
            fw_spikes.append(spike_layer1)
            bw_spikes.insert(0,spike_layer2)
        
        for k in range(s*5):
            bw_idx = int(k//5)*5 + (4 - int(k%5))
            second_tensor = bw_spikes[k]#[bw_idx]
            merge_spikes = torch.cat((fw_spikes[k], second_tensor), -1)
            mean_tensor += merge_spikes
            # mem_layer3  = self.dense_mean(mean_tensor/5.)
            if k %5 ==4:
                mem_layer3  = self.dense_mean(mean_tensor/5.)# mean or accumulate
            
                output = F.log_softmax(mem_layer3,dim=-1)#
                predictions.append(output.data.cpu().numpy())
                if labels is not None:
                    loss += self.criterion(output, labels[:, k//5])
                mean_tensor = 0
    
        predictions = torch.tensor(predictions)
        fw_npy  = np.mean(np.array([t.detach().cpu().numpy() for t in fw_spikes]))
        bw_npy  = np.mean(np.array([t.detach().cpu().numpy() for t in bw_spikes]))
        return predictions, [fw_spikes,bw_spikes],(bw_npy+fw_npy)/2.


def test(data_loader,after_num_frames=0,is_fr=1):
    test_acc = 0.
    sum_samples = 0
    fr = []
    for i, (images, labels) in enumerate(data_loader):
        images = images.view(-1, seq_length, input_dim).to(device)
        labels = labels.view((-1,seq_length)).long().to(device)
        predictions, _,fr_ = model(images)
        _, predicted = torch.max(predictions.data, 2)
        labels = labels.cpu()
        predicted = predicted.cpu().t()
        fr.append(fr_)
        
        test_acc += (predicted == labels).sum()
        
        sum_samples = sum_samples + predicted.numel()
    # print(predicted[1],'\n',labels[1])
    if is_fr:
        print('Mean fr: ', np.mean(fr))
    return test_acc.data.cpu().numpy() / sum_samples

def train(model,loader,optimizer,scheduler=None,num_epochs=10):
    best_acc = 0
    path = 'model/'  # .pth'
    acc_list=[]
    print(model.rnn_fw1.b_j0)
    for epoch in range(num_epochs):
        train_acc = 0
        train_loss_sum = 0
        sum_samples = 0
        for i, (images, labels) in enumerate(loader):
            images = images.view(-1, seq_length, input_dim).requires_grad_().to(device)
            labels = labels.view((-1,seq_length)).long().to(device)
            optimizer.zero_grad()
    
            predictions, train_loss,fr_ = model(images, labels)
            _, predicted = torch.max(predictions.data, 2)
            
            train_loss.backward()
            train_loss_sum += train_loss
            optimizer.step()

            labels = labels.cpu()
            predicted = predicted.cpu().t()
            
            train_acc += (predicted == labels).sum()
            sum_samples = sum_samples + predicted.numel()
            torch.cuda.empty_cache()
        if scheduler is not None:
            scheduler.step()
            
        train_acc = train_acc.data.cpu().numpy() / sum_samples
        valid_acc = test(valid_loader)
        
        if valid_acc>best_acc and train_acc>0.30:
            best_acc = valid_acc
            torch.save(model, path+str(best_acc)[:7]+'-bi-srnn-v3_MN-v1.pth')

        acc_list.append(train_acc)
        print('epoch: {:3d}, Train Loss: {:.4f}, Train Acc: {:.4f},Valid Acc: {:.4f}'.format(epoch,
                                                                           train_loss_sum.item()/len(loader)/(seq_length),
                                                                           train_acc,valid_acc), flush=True)
    return acc_list



gradient type:  MG
device:  cuda:0
(13134, 100, 100) (39,) 1.6
input dataset shap:  (13134, 100, 39)
output dataset shap:  (13134, 100, 61)


In [7]:
num_epochs = 200
criterion = nn.NLLLoss()#nn.CrossEntropyLoss()
model = RNN_s(criterion=criterion,device=device)
# model = torch.load('./model/0.66108-bi-srnn-v3_MN-v1.pth')
model = torch.load('./0.66292-bi-srnn-v3_MN-v1.pth')
# model = torch.load('./0.65901-bi-srnn-v3_MN-v1.pth')

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("device:",device)
model.to(device)

device: cuda:0


RNN_s(
  (criterion): NLLLoss()
  (rnn_fw1): spike_rnn(
    (dense): Linear(in_features=39, out_features=512, bias=True)
    (recurrent): Linear(in_features=512, out_features=512, bias=True)
  )
  (rnn_bw1): spike_rnn(
    (dense): Linear(in_features=39, out_features=512, bias=True)
    (recurrent): Linear(in_features=512, out_features=512, bias=True)
  )
  (dense_mean): readout_integrator(
    (dense): Linear(in_features=1024, out_features=61, bias=True)
  )
)

In [11]:
# with sechdual
test_acc = test(valid_loader)
print(test_acc)

Mean fr:  0.04215048685935991
0.6618980963045913


In [10]:
for i, (images, labels) in enumerate(test_loader):
    if i ==0:
        images = images.view(-1, seq_length, input_dim).to(device)
        labels = labels.view((-1,seq_length)).long().to(device)
        predictions, states,fr_ = model(images)
    else:
        break

In [12]:
fw_spike = np.array(states[0])
bw_spike = np.array(states[1])

In [24]:

b,_ = fw_spike[0].detach().cpu().numpy().shape
spike_np = np.zeros((500,b,512+512))
for i in range(500):
    spike_np[i,:,:512] = fw_spike[i].detach().cpu().numpy()
    spike_np[i,:,512:] = bw_spike[i].detach().cpu().numpy()
spike_np.shape

(500, 16, 1024)

In [25]:
spike_count = {'total':[],'fr':[],'per step':[]}
spike_count['total'].append([np.mean(np.sum(spike_np,axis=(0,2))),np.max(np.sum(spike_np,axis=(0,2))),np.min(np.sum(spike_np,axis=(0,2)))])
spike_count['per step'].append([np.mean(np.sum(spike_np,axis=(2))),np.max(np.sum(spike_np,axis=(2))),np.min(np.sum(spike_np,axis=(2)))])
spike_count['fr'].append(np.mean(spike_np))

In [26]:
spike_count

{'total': [[21589.4375, 23229.0, 20550.0]],
 'fr': [0.0421668701171875],
 'per step': [[43.178875, 102.0, 16.0]]}

In [31]:
spike_count = {'total':[],'fr':[],'per step':[]}
for i, (images, labels) in enumerate(test_loader):
    images = images.view(-1, seq_length, input_dim).to(device)
    labels = labels.view((-1,seq_length)).long().to(device)
    predictions, states,fr_ = model(images)

    fw_spike = np.array(states[0])
    bw_spike = np.array(states[1])
    b,_ = fw_spike[0].detach().cpu().numpy().shape
    spike_np = np.zeros((500,b,512+512))
    for i in range(500):
        spike_np[i,:,:512] = fw_spike[i].detach().cpu().numpy()
        spike_np[i,:,512:] = bw_spike[i].detach().cpu().numpy()
    
    spike_count['total'].append([np.mean(np.sum(spike_np,axis=(0,2))),np.max(np.sum(spike_np,axis=(0,2))),np.min(np.sum(spike_np,axis=(0,2)))])
    spike_count['per step'].append([np.mean(np.sum(spike_np,axis=(2))),np.max(np.sum(spike_np,axis=(2))),np.min(np.sum(spike_np,axis=(2)))])
    spike_count['fr'].append(np.mean(spike_np))

In [32]:
spike_total = np.array(spike_count['total'])
spike_total.shape

(43, 3)

In [33]:
np.mean(spike_total[0]),np.max(spike_total[1]),np.min(spike_total[2])

(21241.8125, 22959.0, 20185.0)

In [34]:
spike_per = np.array(spike_count['per step'])
np.mean(spike_per[0]),np.max(spike_per[1]),np.min(spike_per[2])

(49.892291666666665, 105.0, 15.0)

In [35]:
spike_fr = np.array(spike_count['fr'])
np.mean(spike_fr)

0.042145522953003875