In [1]:
from simulation import  *

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder
from torch.nn.parameter import Parameter
import torchvision
from torchvision import transforms

from SpykeTorch import snn
from SpykeTorch import functional as sf
from SpykeTorch import visualization as vis
from SpykeTorch import utils

import numpy as np
import matplotlib.pyplot as plt
import time
from tqdm import tqdm


In [2]:
length = 3
timesteps = 4 # Resolution for timesteps and weights
window_size = 2
num_neurons = 60 # Number of excitatory neurons in the column
threshold = 8 # Firing threshold for every excitatory neuron
rf_size = window_size*2
inchannels = 1

kwta = 3
inhibition_radius = 0


simulation = Simulation()
corpus = Corpus()
sentences = simulation.construct_sentences()
tokens = corpus.tokenize(sentences)



corpus.dictionary.get_encoding(length,timesteps)
spike_data = SpikeData(tokens, sentences, corpus)
spike_input, spike_output = spike_data.convert_tokens(window_size)
print(np.max(spike_output))

3


In [23]:
class Column(nn.Module):
    def __init__(self, num_neurons, threshold):
        super(Column, self).__init__()
        self.k = num_neurons
        self.thresh = threshold
        self.ec = snn.LocalConvolution(input_size=(rf_size,length),
                                       in_channels=inchannels,
                                       out_channels=self.k,
                                       kernel_size=(rf_size,length),
                                       stride=1)
        self.stdp = snn.ModSTDP(self.ec, 10/128, 10/128, 1/128, 96/128, 4/128, maxweight = timesteps)

    def forward(self, rec_field):
        out = self.ec(rec_field)
        spike, pot = sf.fire(out, self.thresh, True)
        winners = sf.get_k_winners(pot, kwta = kwta, inhibition_radius = inhibition_radius, spikes = spike)
        coef = torch.zeros_like(out[0]).unsqueeze_(0)
        coef[:,winners,:,:] = 1
        return torch.mul(pot, coef).sign()
    
 

In [24]:
temporal_transform = utils.Intensity2Latency(timesteps)

### Column Initialization ###

MyColumn = Column(num_neurons, threshold)

In [25]:
cat_idx = corpus.dictionary.word2idx['cat']
cat_enc = corpus.dictionary.idx2spike[cat_idx]
cat_context = spike_input[np.all(spike_output == cat_enc, axis=1)]
cat_context = torch.from_numpy(cat_context)
cat_context = temporal_transform(cat_context)
cat_context = cat_context.sign()

class DatasetContext(Dataset):
    def __init__(self, context, transform=None):
        self.data = context
        
    def __len__(self):
        return self.data.size(1)
    
    def __getitem__(self, index):
        image = self.data[:,index,:,:].reshape((self.data.size(0),1, self.data.size(2), -1))        
        return image
    
cat = DatasetContext(cat_context)
trainLoader = DataLoader(cat, batch_size=1000, shuffle=True)

In [26]:
dog_idx = corpus.dictionary.word2idx['dog']
dog_enc = corpus.dictionary.idx2spike[dog_idx]
dog_context = spike_input[np.all(spike_output == dog_enc, axis=1)]

In [27]:
for epochs in range(1):
    start = time.time()
    cnt = 0
    for data in tqdm(trainLoader):
        for i in range(1):
            out = MyColumn(data[i])
            MyColumn.stdp(data[i],out)
    end = time.time()
    print("Training done under ", end-start)

100%|██████████| 1/1 [00:00<00:00, 101.44it/s]

tensor([[[[0.]],

         [[0.]],

         [[0.]]]])
Training done under  0.011622905731201172





In [15]:
print spike_input[0]
print spike_input[1]
print spike_input[2]
print spike_input[3]


[[3 0 2]
 [0 1 1]
 [0 3 2]
 [1 0 3]]
[[0 1 1]
 [2 3 3]
 [1 0 3]
 [1 0 0]]
[[2 3 3]
 [0 3 2]
 [1 0 0]
 [1 3 1]]
[[0 3 2]
 [1 0 3]
 [1 3 1]
 [0 3 2]]


In [20]:
cat_idx = corpus.dictionary.word2idx['cat']
dog_idx = corpus.dictionary.word2idx['dog']
cat_enc = corpus.dictionary.idx2spike[cat_idx]
dog_enc = corpus.dictionary.idx2spike[dog_idx]
print dog_enc

[3 3 3]
