In [1]:
import matplotlib.pyplot as plt
import numpy as np
import pyNN.nest as p
import relu_utils as alg
import spiking_relu as sr
import random
import mnist_utils as mu
import os.path
import sys
import cnn_utils as cnnu
import matplotlib.cm as cm
#USAGE: spiking_dbn.py scaled_weight b10_epoc5

In [2]:
# Base names of the pickled inputs. The next cell checks for '<w_listf>.pkl'
# on disk; dbn_f is handed to alg.load_dict to obtain the network dictionary
# (presumably the trained DBN -- TODO confirm against relu_utils).
w_listf = 'scaled_weight'
dbn_f = 'b10_epoc5'
dbnet = alg.load_dict(dbn_f)

# Neuron parameters shared by the weight-scaling step (sr.w_adjust) and the
# simulation (sr.run_test) further down in this notebook.
cell_params_lif = dict(
    cm=0.25,
    i_offset=0.0,
    tau_m=20.0,
    tau_refrac=1.,
    tau_syn_E=1.0,
    tau_syn_I=1.0,
    v_reset=-70.0,
    v_rest=-65.0,
    v_thresh=-50.0,
)

In [3]:
if os.path.isfile('%s.pkl'%w_listf):
    scaled_w = alg.load_dict(w_listf)
    w = scaled_w['w']
    k = scaled_w['k']
    x0 = scaled_w['x0']
    y0 = scaled_w['y0']
    print 'found w_list file'
else:
    w, k, x0, y0 = sr.w_adjust(dbnet, cell_params_lif)
    scaled_w = {}
    scaled_w['w'] = w
    scaled_w['k'] = k
    scaled_w['x0'] = x0
    scaled_w['y0'] = y0
    alg.save_dict(scaled_w, w_listf)

found w_list file


In [3]:
# Evaluation settings used by the cells below.
#   num_test - number of test samples to present
#   dur_test - length of the counting window per sample
#   silence  - gap added after each window before the next sample starts
# (time values are presumably in milliseconds -- TODO confirm against sr/pyNN)
num_test = 10
dur_test, silence = 1000, 200

# Fixed seed for Python's `random` module.
random.seed(0)

test_x = dbnet['test_x']

In [5]:
offset = 0
test = test_x[offset:(offset+num_test), :]
spike_source_data = sr.gen_spike_source(test)                
spikes = sr.run_test(w, cell_params_lif, spike_source_data)
spike_count = list()

for i in range(w[-1].shape[1]):
    index_i = np.where(spikes[:,0] == i)
    spike_train = spikes[index_i, 1]
    temp = sr.counter(spike_train, range(0, (dur_test+silence)*num_test,dur_test+silence), dur_test)
    spike_count.append(temp)
spike_count = np.array(spike_count)/(dur_test / 1000.)
r = np.argmax(spike_count, axis=0)
correct = np.sum(r == dbnet['test_y'][offset:offset+num_test]).astype(int) - len(np.where(spike_count.max(axis=0)==0)[0])
print correct



9


In [18]:
rs = np.load('result_list.npy')
limit = 9990.
correct = np.where(rs[:limit,1] == 1)[0].shape[0]
wrong = np.where(rs[:limit,1] == 0)[0].shape[0]
noresp = np.where(rs[:limit,1] == -1)[0].shape[0]
print  correct/limit*100.,  wrong/limit*100., noresp/limit*100.

89.8498498498 8.71871871872 1.43143143143


#CNN

In [None]:
def run_cnn_test(w_list, cell_para, spike_source_data):
    """Build a chain of fully connected spiking layers and return output spikes.

    w_list            -- iterable of weight matrices; each must support
                         np.copy, boolean-mask assignment and `.shape[1]`
                         (the number of neurons in the layer it feeds).
    cell_para         -- parameter dict passed to every p.IF_curr_exp population.
    spike_source_data -- indexable by input neuron; entry j becomes the
                         spike_times of input cell j.

    Returns whatever getSpikes(compatible_output=True) yields for the last
    population created in the loop.

    NOTE(review): if w_list is empty, `pop_out` is never bound and the
    `pop_out.record()` call below raises NameError.
    """
    pop_list = []
    p.setup(timestep=1.0, min_delay=1.0, max_delay=3.0)
    # Input layer: a SpikeSourceArray population replaying the given spike times.
    # NOTE(review): only 28 input cells are created here, while the other cells
    # of this notebook work with 28*28 = 784 inputs -- confirm this is intended.
    input_size = 28
    pop_in = p.Population(input_size, p.SpikeSourceArray, {'spike_times' : []})
    for j in range(input_size):
        pop_in[j].spike_times = spike_source_data[j]
    pop_list.append(pop_in)
    
    for w in w_list:        
        # Split the weights by sign: positive part drives the excitatory
        # projection, negative part the inhibitory one.
        pos_w = np.copy(w)
        pos_w[pos_w < 0] = 0
        neg_w = np.copy(w)
        neg_w[neg_w > 0] = 0
        
        output_size = w.shape[1]
        pop_out = p.Population(output_size, p.IF_curr_exp, cell_para)
        p.Projection(pop_in, pop_out, p.AllToAllConnector(weights = pos_w), target='excitatory')
        p.Projection(pop_in, pop_out, p.AllToAllConnector(weights = neg_w), target='inhibitory')
        pop_list.append(pop_out)
        # The layer just built is the input of the next one.
        pop_in = pop_out

    # Only the final layer is recorded.
    pop_out.record()
    # Run time is rounded up to a multiple of 1000.
    # NOTE(review): np.max(...)[0] indexes the result of np.max; this presumably
    # relies on spike_source_data being a list of per-neuron lists -- TODO confirm.
    run_time = np.ceil(np.max(spike_source_data)[0]/1000.)*1000
    p.run(run_time)
    spikes = pop_out.getSpikes(compatible_output=True)
    p.end()
    return spikes

In [15]:
# Read the CNN from the .mat file; by their names w_cnn holds the weights and
# l_cnn the layer descriptions (TODO confirm against cnn_utils.readmat).
w_cnn, l_cnn = cnnu.readmat('cnn_relu.mat')#cnn609.mat softplus 3-5 train.
#r = cnnu.test(w_cnn, l_cnn, test_x[:100,:], False)
# Every test image is rescaled so that its pixel values sum to SUM_rate.
SUM_rate = 2000.
# NOTE(review): num_test and test_x come from earlier cells; the result depends
# on which definition of num_test was executed last in the kernel.
tx = np.zeros((num_test, 784))  # 784 = 28*28 pixels per image
for i in range(num_test):
    tx[i] = test_x[i]/sum(test_x[i])*SUM_rate
# tx has num_test rows, so tx[:100,:] yields at most num_test samples.
w_cnn, a = cnnu.scale_weight(w_cnn, l_cnn, tx[:100,:])

scale:  18.6158395604
scale:  10.211737225
scale:  7.27692308695
scale:  14.3287809602
46.3339326908
scale:  2.70857821288
50.0


In [5]:
# Neuron parameters used for every IF_curr_exp population of the spiking CNN.
# The values repeat the earlier cell_params_lif definition in this notebook.
cell_params_lif = dict(
    cm=0.25,
    i_offset=0.0,
    tau_m=20.0,
    tau_refrac=1.,
    tau_syn_E=1.0,
    tau_syn_I=1.0,
    v_reset=-70.0,
    v_rest=-65.0,
    v_thresh=-50.0,
)
def conv_conn(in_size, out_size, w):
    """Build the connection lists for one 2-D convolution between square maps.

    in_size  -- side length of the input map (neurons are indexed row-major).
    out_size -- side length of the output map.
    w        -- square kernel, indexable as w[a][b], with side
                in_size - out_size + 1.

    Returns (excitatory, inhibitory): two lists of
    (input_index, output_index, weight, 1.) tuples holding the strictly
    positive and strictly negative weights respectively. Zero weights produce
    no connection.
    """
    kernel = in_size - out_size + 1
    flip = kernel - 1
    # The kernel is read with both indices reversed and swapped, i.e. the
    # weight for offset (d_row, d_col) is w[flip - d_col][flip - d_row].
    candidates = [
        ((row + d_row) * in_size + (col + d_col),
         row * out_size + col,
         w[flip - d_col][flip - d_row],
         1.)
        for row in range(out_size)
        for col in range(out_size)
        for d_row in range(kernel)
        for d_col in range(kernel)
    ]
    excitatory = [conn for conn in candidates if conn[2] > 0]
    inhibitory = [conn for conn in candidates if conn[2] < 0]
    return excitatory, inhibitory

def pool_conn(in_size, out_size, w):
    """Build the connection list for non-overlapping pooling between square maps.

    in_size  -- side length of the input map (neurons are indexed row-major).
    out_size -- side length of the output map.
    w        -- single weight applied to every connection.

    Each output neuron receives the step x step block of input neurons that
    starts at (x_ind*step, y_ind*step), where step = in_size // out_size.

    Returns a list of (input_index, output_index, w, 1.) tuples.
    Raises ZeroDivisionError when out_size is 0.
    """
    # Floor division keeps `step` an integer whatever the division semantics
    # in effect; plain `/` yields a float under true division and range()
    # below would then fail.
    step = in_size // out_size
    conn_list = []
    for x_ind in range(out_size):
        for y_ind in range(out_size):
            out_ind = x_ind * out_size + y_ind
            for kx in range(step):
                for ky in range(step):
                    in_ind = (x_ind*step+kx) * in_size + (y_ind*step+ky)
                    conn_list.append((in_ind, out_ind, w, 1.))
    return conn_list

def out_conn(w):
    """Build the connection lists for a fully connected layer.

    w -- 2-D array with a `.shape`; row index j is used as the target neuron
         and column index i as the source neuron.

    Returns (excitatory, inhibitory): two lists of (i, j, weight, 1.) tuples
    holding the strictly positive and strictly negative weights respectively.
    Zero weights produce no connection.
    """
    conn_list_exci = []
    conn_list_inhi = []
    for j in range(w.shape[0]):
        for i in range(w.shape[1]):
            weight = w[j][i]
            if weight > 0:
                conn_list_exci.append((i, j, weight, 1.))
            elif weight < 0:
                conn_list_inhi.append((i, j, weight, 1.))
    # The convolution loop that used to follow this return was unreachable
    # (and referenced names not defined in this function); it has been removed.
    return conn_list_exci, conn_list_inhi

def conv_pops(pop1, pop2, w):
    """Connect pop1 to pop2 with the convolution kernel w.

    Both populations are treated as square maps whose side is
    int(sqrt(population.size)). Positive kernel weights become an
    'excitatory' projection and negative ones an 'inhibitory' projection;
    a projection is only created when its connection list is non-empty.
    """
    side_in = int(np.sqrt(pop1.size))
    side_out = int(np.sqrt(pop2.size))
    by_sign = zip(conv_conn(side_in, side_out, w), ('excitatory', 'inhibitory'))
    for connections, synapse_type in by_sign:
        if connections:
            p.Projection(pop1, pop2, p.FromListConnector(connections), target=synapse_type)

def pool_pops(pop1, pop2, w):
    """Connect pop1 to pop2 as a pooling stage with the single weight w.

    Both populations are treated as square maps whose side is
    int(sqrt(population.size)). All pooling connections go into one
    'excitatory' projection, created only when there is something to connect.
    """
    side_in, side_out = (int(np.sqrt(pop.size)) for pop in (pop1, pop2))
    pooled = pool_conn(side_in, side_out, w)
    if not pooled:
        return
    p.Projection(pop1, pop2, p.FromListConnector(pooled), target='excitatory')

def out_pops(pop_list, pop2, w_layer):
    """Fully connect every population in pop_list to the output population pop2.

    pop_list -- non-empty list of source populations; the size of the first
                one is used as the column width for all of them.
    pop2     -- target population.
    w_layer  -- 2-D weight array; population i uses the column slice
                [i*in_size, (i+1)*in_size), rows index the target neurons
                (see out_conn).

    Positive weights become 'excitatory' projections and negative weights
    'inhibitory' ones; a projection is only created when its connection list
    is non-empty.
    """
    in_size = pop_list[0].size
    for i, pop in enumerate(pop_list):
        w = w_layer[:, i*in_size:(i+1)*in_size]
        conn_exci, conn_inhi = out_conn(w)
        if len(conn_exci) > 0:
            p.Projection(pop, pop2, p.FromListConnector(conn_exci), target='excitatory')
        if len(conn_inhi) > 0:
            p.Projection(pop, pop2, p.FromListConnector(conn_inhi), target='inhibitory')
    return

def init_inputlayer(input_size, data):
    """Create the input layer: one SpikeSourceArray population of
    input_size*input_size cells.

    The spike times of cell j are taken from entry j of
    sr.gen_spike_source(data). Returns a one-element list holding the
    population, so it can be used like the per-layer lists built by
    construct_layer.
    """
    n_cells = input_size * input_size
    pop = p.Population(n_cells, p.SpikeSourceArray, {'spike_times' : []})
    spike_source_data = sr.gen_spike_source(data)
    for j in range(n_cells):
        pop[j].spike_times = spike_source_data[j]
    return [pop]

In [6]:
def construct_layer(pop_list_in, mode, k_size, w_layer):
    in_num = len(pop_list_in) #populations number in previous layer
    in_size = int(np.sqrt(pop_list_in[0].size)) #in_size*in_size = neuron_num per pop in the previous layer
    pop_layer = []
    if mode > 0: #convoluational layer
        out_num = mode #populations number in current layer
        print in_num, out_num
        out_size = in_size - k_size + 1
        for j in range(out_num):
            pop_layer.append(p.Population(out_size*out_size, p.IF_curr_exp, cell_params_lif))
            for i in range(in_num):
                conv_pops(pop_list_in[i], pop_layer[j], w_layer[i][j])
    elif mode == 0: #pooling layer
        out_num = in_num #populations number in current layer
        print in_num, out_num
        out_size = in_size/k_size
        for j in range(out_num):
            pop_layer.append(p.Population(out_size*out_size, p.IF_curr_exp, cell_params_lif))
            pool_pops(pop_list_in[j], pop_layer[j], w_layer[0][0])
    elif mode == -1: #top layer
        out_size = k_size
        print out_size
        pop_layer.append(p.Population(out_size, p.IF_curr_exp, cell_params_lif))
        out_pops(pop_list_in, pop_layer[0], w_layer)
    return pop_layer

In [7]:
import scipy.io as sio

# Read the .mat file once and take both arrays from the same dictionary
# (it used to be loaded from disk twice).
mnist = sio.loadmat('mnist.mat')

# Move the third axis to the front and flatten each 2-D slice into a row of
# 28*28 values in column-major ('F') order.
# (presumably the third axis indexes the samples -- TODO confirm)
tmp_x = mnist['test_x']
tmp_x = np.transpose(tmp_x, (2, 0, 1))
tmp_x = np.reshape(tmp_x, (tmp_x.shape[0], 28*28), order='F' )

# Reduce along axis 0 to one index per column
# (presumably one-hot label columns -> class ids -- TODO confirm).
tmp_y = mnist['test_y']
tmp_y = np.argmax(tmp_y, axis=0)

# NOTE: this overrides the num_test / offset / test values set by earlier cells.
num_test = 100
offset = 0
test = tmp_x[offset:(offset+num_test), :]

In [16]:
# Build the spiking CNN layer by layer and simulate all test samples.
p.setup(timestep=1.0, min_delay=1.0, max_delay=3.0)
L = l_cnn
input_size = L[0][1]

# pops_list[0] is the input layer; entry l+1 is built from entry l using the
# layer description L[l+1] = (mode, k_size, ...) and the weights w_cnn[l].
pops_list = [init_inputlayer(input_size, test)]
for l in range(5):
    layer_desc = L[l+1]
    pops_list.append(construct_layer(pops_list[-1], layer_desc[0], layer_desc[1], w_cnn[l]))

# Only the first population of the last layer is recorded.
top_pop = pops_list[-1][0]
top_pop.record()
# One (dur_test + silence) slot per test sample.
p.run((dur_test+silence)*num_test)
spikes = top_pop.getSpikes(compatible_output=True)
p.end()

1 6
6 6
6 12
12 12
10


In [17]:
spike_count = []
for i in range(10):
    index_i = np.where(spikes[:,0] == i)
    spike_train = spikes[index_i, 1]
    temp = sr.counter(spike_train, range(0, (dur_test+silence)*num_test,dur_test+silence), dur_test)
    spike_count.append(temp)
spike_count = np.array(spike_count)/(dur_test / 1000.)
r = np.argmax(spike_count, axis=0)
correct = np.sum(r == tmp_y[offset:(offset+num_test)]).astype(int) #- len(np.where(spike_count.max(axis=0)==0)[0])
print correct


80


In [18]:
# Plot the recorded output spikes with the project helper; 'out' is passed as
# its second argument (presumably a title/label -- TODO confirm).
sr.plot_spikes(spikes,'out')