In [1]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import random
import math

import pickle
import os

In [5]:
from simulator import AELayer, spk_t

def my_stdp(ref_time, run_time, curr_time, spike_to_value,
            t_minus, t_plus, weights, min_w, max_w, learn_rate,
            last_pre_spikes, pre_spikes,
            post_spikes, target_spikes):
    """Supervised STDP-style weight update (used as `stdp['func']` of an AELayer).

    Synapses are pushed up towards target (teaching) spikes and down for the
    post-synaptic spikes that actually occurred, then clipped to [min_w, max_w].

    Parameters
    ----------
    ref_time, run_time, curr_time : forwarded, together with ``1.``, to
        ``spike_to_value``; the scalar it returns scales every change
        (intent: spikes closer to the reference time matter more).
    spike_to_value : callable producing that scalar gain; always invoked.
    t_minus : a pre neuron joins the pairwise term only if its last spike time
        is >= 0 and > ``curr_time - t_minus``.
    t_plus : accepted for interface compatibility; not used here.
    weights : 2-D array indexed (pre, post); modified in place.
    min_w, max_w : clipping bounds for the weights.
    learn_rate : step size of every change.
    last_pre_spikes : column 0 holds each pre neuron's last spike time
        (a negative value presumably means "no spike yet" -- TODO confirm).
    pre_spikes : only its shape is used, to build a column of ones.
    post_spikes, target_spikes : column 0; a value > 0 marks a spike of the
        corresponding post neuron (actual / desired).

    Returns
    -------
    The same ``weights`` object, after the update.
    """
    # Scalar gain for this step (the callback is called even when no update follows).
    gain = spike_to_value(ref_time, run_time, curr_time, 1.)

    # Nothing to learn unless some post neuron fired or was supposed to fire.
    if not ((post_spikes > 0).any() or (target_spikes > 0).any()):
        return weights

    # Pairwise term: only pre neurons that spiked inside the t_minus window.
    in_window = (last_pre_spikes >= 0) & (last_pre_spikes > (curr_time - t_minus))
    pre_idx = np.nonzero(in_window)[0]
    if len(pre_idx) > 0:
        # NOTE(review): the change scales with the *absolute* last pre-spike
        # time rather than its distance to curr_time -- confirm this is intended.
        pre_times = last_pre_spikes[pre_idx, 0]
        weights[pre_idx, :] += learn_rate * np.outer(pre_times, target_spikes[:, 0]) * gain
        weights[pre_idx, :] -= learn_rate * np.outer(pre_times, post_spikes[:, 0]) * gain

    # Pre-independent term: shift every incoming synapse of target / spiking post neurons.
    ones_col = np.ones_like(pre_spikes, dtype=spk_t)
    step = learn_rate * gain
    tgt_idx = np.nonzero(target_spikes > 0)[0]
    weights[:, tgt_idx] += step * np.outer(ones_col, target_spikes[tgt_idx, 0])
    post_idx = np.nonzero(post_spikes > 0)[0]
    weights[:, post_idx] -= step * np.outer(ones_col, post_spikes[post_idx, 0])

    # Keep weights inside [min_w, max_w] (upper bound applied first).
    weights[weights > max_w] = max_w
    weights[weights < min_w] = min_w
    return weights

In [6]:
# Seed every RNG the notebook imports (numpy and the stdlib `random`) up front,
# so Restart & Run All reproduces the same weights and results.
SEED = 10
np.random.seed(SEED)
random.seed(SEED)

In [8]:
import mnist_utils as mu
train_x, train_y = mu.get_train_data()
# Scale pixel intensities to [0, 1] (assumes raw values in 0..255 -- TODO confirm
# against mnist_utils). Dividing out of place also works for integer dtypes and
# never mutates an array that mnist_utils might cache between calls.
train_x = train_x / 255.
v_size = 794   # visible units: 784 pixels + 10 one-hot label units appended later
h_size = 500   # hidden units
assert train_x.shape[1] + 10 == v_size, "v_size must equal #pixels + 10 label units"
run_time = 30  # number of simulation steps for the toy AELayer run below

# Toy 4-5-4 layer exercising simulator.AELayer with my_stdp as learning rule.
# 'in_times' / 'target_times' appear to be per-neuron spike-time lists
# (TODO confirm against simulator.AELayer).
description = dict(
    level=0,
    run_time=run_time,
    sizes={'in': 4, 'hid': 5, 'rcn': 4},
    delays={'in': np.ones(4),
            'hid': np.ones(5),
            'rcn': np.ones(4)},
    in_times=[[1.], [], [2.], []],
    neuron_params=dict(v_thresh=1.,   # membrane potential threshold
                       v_rest=0.,     # resting potential
                       tau_m=20.),
    stdp=dict(func=my_stdp,
              max_w=0.5,
              min_w=-0.5,
              t_plus=20.,
              t_minus=20.,
              learn_rate=0.001,
              target_times=[[5.], [], [7.], []]),
)

lvl_0 = AELayer(description)
prev_w = lvl_0._w.copy()  # weights after the latest step, handy for diffing while debugging
out_spk = None
for i in range(run_time):
    if i:
        out_spk[:] = lvl_0.sim(i)   # reuse the buffer obtained at step 0
    else:
        out_spk = lvl_0.sim(i)
    prev_w[:] = lvl_0._w


In [8]:
# ---- LIF neuron parameters ----
v_thresh = 1.    # membrane potential threshold for firing
v_rest = 0.      # resting / reset potential
tau_m = 20.      # membrane time constant, in simulation steps
delay = 5        # synaptic delay, in simulation steps

# ---- STDP configuration ----
tau_stdp = 20    # STDP window length (steps)
eta = 0.001      # learning rate
# Per-lag STDP amplitudes rising from eta to 10*eta (log-spaced). The training
# loop pairs the last entry with the most recent hidden spike, so the effective
# curve decays exponentially with the pre/post time distance.
delta_w = eta * np.logspace(0, 1, tau_stdp)

# ---- Trial / encoding configuration ----
run_len = 50       # length of each trial (steps)
teach_delay = 10   # delay of the teaching signal relative to the input spike
K = 30.            # latency-code range: pixel value x spikes at floor((1 - x) * K)
record_flag = True # record loss and weight history during training

# ---- Tied weights: w maps visible->hidden, w.T maps hidden->visible ----
w_bound = 0.5    # weights are clipped to [-w_bound, w_bound] (0.3 was also tried)
w_init = 0.1     # std of the Gaussian initialisation
w_offset = 0.05  # mean of the Gaussian initialisation
w = np.random.normal(w_offset, w_init, (v_size, h_size))

In [None]:
# # Legacy hand-written toy patterns, superseded by the MNIST-derived `patterns` built in the next cell; kept for reference only
# patterns = list()
# # # patterns.append([3,3,3,3,5,5,5,5,8,8,3,3,3,3,5,5,5,5,8,8])
# # # patterns.append([2,2,2,2,2,8,8,8,8,8,2,2,2,2,2,8,8,8,8,8])
# patterns.append([2,2,2,2,2,2,2,2,2,2,2,2,2,2,16,16,16,16,16,16,16]) #1,1,0 #16,16,16,16,16,16,16   #8,8,8,8,8,8,8
# patterns.append([2,2,2,2,2,2,2,16,16,16,16,16,16,16,2,2,2,2,2,2,2]) #1,0,1
# patterns.append([16,16,16,16,16,16,16,2,2,2,2,2,2,2,2,2,2,2,2,2,2]) #0,1,1

# patterns = np.array(patterns)
# print patterns.shape

In [None]:
num_test = 10  # number of training images used as patterns

# Latency code: x = 1 spikes at t = 0, x = 0 maps to t = K (blanked later, i.e. no spike).
patterns = np.int16(np.floor((1. - train_x[:num_test]) * K))

# One-hot labels, encoded with the same latency code.
patterns_y = np.zeros((num_test, 10))
patterns_y[np.arange(num_test), np.int16(train_y[:num_test])] = 1.
patterns_y = np.int16(np.floor((1. - patterns_y) * K))

# Label units become the last 10 visible units of every pattern.
patterns = np.concatenate((patterns, patterns_y), axis=1)


In [None]:
epoch = 5  # training epochs

# Recording of the neural status
if record_flag:
    w_list = []         # per-pattern snapshots of the label-unit weights w[-10:, :]
    loss_list = []      # per-pattern reconstruction MSE
    predict_list = []   # kept for compatibility; currently not filled

# Spike rasters (units x steps) and membrane potentials, reused for every pattern
h_spike = np.zeros((h_size, run_len))   # hidden-unit spikes
o_spike = np.zeros((v_size, run_len))   # output (reconstruction) spikes
h_mem = np.zeros((h_size, 1))           # hidden membrane potentials
o_mem = np.zeros((v_size, 1))           # output membrane potentials

# Loop invariants. K stays a float (it is used in true divisions below) but must
# be an int when used as a column index: float indices raise in modern NumPy.
no_spike_t = int(K)              # encoded time meaning "pixel is 0 -> no spike"
mem_decay = np.exp(-1. / tau_m)  # per-step membrane leak

# NOTE: w and delta_w are modified in place, so re-running this cell continues
# training from the current weights rather than from the initialisation cell.
for iteration in range(epoch):
    print('epoch:%d' % iteration)
    for p_id in range(patterns.shape[0]):
        # Reset the neural status. Rasters hold 0/1 spike flags, so clear them
        # with 0 rather than with the resting potential.
        h_mem[:] = v_rest
        o_mem[:] = v_rest
        h_spike[:] = 0.
        o_spike[:] = 0.

        # Input spikes and teaching signal: one spike per visible unit at its
        # encoded time; the teacher fires teach_delay steps later.
        v_pattern = patterns[p_id, :]
        v_spike = np.zeros((v_size, run_len))
        v_teach = np.zeros((v_size, run_len))
        v_spike[np.arange(v_size), v_pattern] = 1.
        v_spike[:, no_spike_t] = 0.
        v_teach[np.arange(v_size), v_pattern + teach_delay] = 1.

        # Main part for neural status updating (no input can arrive before `delay`)
        for t in range(delay, run_len):
            # hidden units: leak, integrate delayed input spikes, fire, reset
            h_mem *= mem_decay
            h_mem += np.dot(v_spike[:, t - delay], w).reshape(h_mem.shape)
            fired = h_mem[:, 0] > v_thresh
            h_spike[fired, t] = 1.
            h_mem[fired] = v_rest

            # output units: same dynamics through the transposed (tied) weights
            o_mem *= mem_decay
            o_mem += np.dot(h_spike[:, t - delay], w.T).reshape(o_mem.shape)
            fired = o_mem[:, 0] > v_thresh
            o_spike[fired, t] = 1.
            o_mem[fired] = v_rest

            # importance of this step: 1 at t == delay, (delay + 1) / run_len at the end
            impt = float(run_len - t + delay) / float(run_len)

            # weight updates, only when an output or a teaching spike occurs now
            o_now = o_spike[:, t]
            teach_now = v_teach[:, t]
            if (o_now > 0).any() or (teach_now > 0).any():
                # Hidden spikes inside the STDP window ending at t - delay, weighted
                # by delta_w (the most recent step gets delta_w[-1]) and summed per
                # hidden unit. outer(o, trace) equals the former
                # sum(einsum('i,jk->ijk')) without a (v_size, h_size, window) temporary.
                left_bound = max(t - tau_stdp + 1 - delay, 0)
                window = h_spike[:, left_bound:t - delay + 1]
                stdp_trace = (window * delta_w[left_bound - t + delay - 1:]).sum(axis=1)

                # W-: depress synapses of output units that fired (STDP + constant term)
                w -= np.outer(o_now, stdp_trace) * impt
                w[o_now > 0, :] -= (eta * impt * 0.1)
                # W+: potentiate synapses of units the teacher wants to fire
                w += np.outer(teach_now, stdp_trace) * impt
                w[teach_now > 0, :] += (eta * impt * 0.1)

                np.clip(w, -w_bound, w_bound, out=w)

        # Decode: first output spike time minus teach_delay, mapped to (K - t) / K.
        # Units that never spike (argmax -> 0) or whose value comes out negative
        # are set to 0.
        recon = (o_spike.argmax(axis=1) - teach_delay) * 1.
        recon[recon >= 0] = (K - recon[recon >= 0]) / K
        recon[recon < 0] = 0

        loss = (((K - patterns[p_id]) / K - recon) ** 2).mean()
        predict = np.argmax(recon[-10:])   # label units are the last 10 visible units

        if record_flag:
            loss_list.append(loss)
            w_list.append(w[-10:, :].flatten())
            print('%d %0.3f %d %d %.2f' % (p_id, loss, np.int16(train_y[p_id]),
                                          predict, recon[-10 + predict]))
    # anneal the STDP curve every 10 epochs
    if iteration % 10 == 9:
        delta_w *= 0.8

In [None]:
# plt.plot(loss_list)
# Average the per-pattern losses over consecutive groups of avg_num patterns and
# plot them on a log scale.
avg_num = 10
img_num = len(loss_list)
# Floor division: `/` yields a float on Python 3 and breaks reshape. Any
# incomplete trailing group is dropped instead of making reshape fail.
n_groups = img_num // avg_num
loss_plot = np.reshape(np.array(loss_list[:n_groups * avg_num]), (n_groups, avg_num))
fig, ax = plt.subplots()
ax.semilogy(loss_plot.mean(axis=1))
ax.set_title('Loss (MSE)')
ax.set_xlabel('group of %d patterns' % avg_num)
ax.set_ylabel('mean loss');


In [None]:
# Label-unit weight history recorded during training: one row per presented
# pattern, one column per entry of w[-10:, :].flatten().
w_list = np.array(w_list)
fig, ax = plt.subplots()
ax.plot(w_list[:, ::10])   # every 10th trace keeps the plot readable
ax.set(title='Label-unit weights during training',
       xlabel='presented pattern', ylabel='weight')
print(w_list.shape)        # function-call form runs on both Python 2 and 3