In [2]:
%pylab inline
from numpy import *
from numpy.random import randn, rand, randint

from cmath import polar

import numpy as np

# for the following package you might have to install brewer2mpl first, then prettyplotlib
# you can install both with the command
# > sudo pip install brewer2mpl
# > sudo pip install prettyplotlib
import prettyplotlib as ppl
from scipy.ndimage.filters import gaussian_filter1d
%load_ext Cython



Populating the interactive namespace from numpy and matplotlib


# Optimal Learning

First we define the simulator that does all the heavy lifting: it runs the simulation, calculates spike times, integrates currents, changes the weights, etc. If you want to change something about the learning (e.g. include the correction for the mean), this is the place to edit.

The code is written in Cython for speed. Cython converts Python code into C code; the main thing you have to add is the types of the objects you define, and most of the unusual-looking parts of the code are related to this. You will also notice that all loops are written in index notation. This would be horribly slow in plain Python, but it lets Cython turn them into plain C loops. So, inside the Cython cells, do not vectorize!

In [5]:
# NOTE(review): looks like a leftover sanity check; it only prints the xrange object (Python 2).
print(xrange(18))

xrange(18)


In [None]:
%%cython

cimport cython
cimport numpy as np
from libc.math cimport log, exp, sqrt, sin
    
import numpy as np

@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
def c_analytic_opt_simulate(np.ndarray[double, ndim=2] Wf, np.ndarray[double, ndim=1] T, np.ndarray[double, ndim=2] F, 
                        np.ndarray[double, ndim=1] I_int, np.ndarray[double, ndim=1] C_int, np.ndarray[double, ndim=1] V, 
                        double yd, float o, double mu1, double mu2, double inpDt, double recDt, double Rstep, 
                        double Fstep, double Ustep, double t0, unsigned int sim_T):
    '''
       Event-driven simulation of the learning network.

       Spike times are computed analytically and the simulation jumps from one
       spike to the next (or by at most `no_spike_dt` if no neuron would spike).
       Wf, F, I_int, C_int and V are typed buffers and are modified IN PLACE;
       they are also returned.

       PARAMETERS
       ----------
       Wf    - fast recurrent (lateral) connections, shape (N, N)
       T     - thresholds, shape (N,)
       F     - feedforward connections, shape (J, N)
       I_int - the initial value of the integrated input current (vector of size N (# neurons)). Of minor importance if you restart a simulation.
       C_int - same as I_int but for the true control current (vector of size J (# input dimensions))
       V     - initial membrane voltage (size N)
       yd    - leak rate: voltages, rates and x decay as exp(-yd*t)
       o     - std of noise (currently unused: the noise term is commented out)
       mu1   - L1 constraint (not used by this function)
       mu2   - L2 constraint
       inpDt - the input current is constant over a period of time and switches regularly with frequency f = 1/inpDt
       recDt - we collect a couple of variables over the time course of the simulation with a sampling frequency 1/recDt
       Rstep - learning rate of lateral connections
       Fstep - learning rate of feedforward connections
       Ustep - learning rate of the threshold offset Fu (the effective threshold is T - Fu)
       t0    - initial time (important if you restart the simulation for the timestamps)
       sim_T - total simulation time

       RETURNS
       -------
       V, Wf, F, I_int, C_int, t, Wfdata, Fdata, tdata, ratedata
       where the *data lists hold one entry per recording (every recDt).

    '''
    # basic parameters of the network
    cdef unsigned int N = Wf.shape[0]                 # number of neurons
    cdef unsigned int J = F.shape[0]                  # number of inputs
    cdef double t = t0                                # time
    
    # containers for data collections
    Wfdata = list()                                   # container for state of lateral weights
    Fdata = list()                                    # container for state of feedforward weights
    tdata = list()                                    # time-stamps
    ratedata = list()                                 # rates  

    cdef double last_rec_t = t0                       # last time data was collected
    
    # normalization of feedforward weights: every learned column is rescaled to the norm of column 0
    cdef double Fnorm = sqrt(sum(F[:,0]**2))
    
    # spike times
    cdef np.ndarray[double, ndim=1] s = np.empty(N)    # time of spike of neurons (if they would be in isolation)
    cdef double snext                                  # time to next spike in population (minimum of s)

    # input-related variables
    # NOTE(review): the initial input is drawn from |randn| while later inputs use rand()*3.464 -- confirm intended
    cdef np.ndarray[double, ndim=1] c = abs(np.random.randn(J))   # control current (size J)
    cdef np.ndarray[double, ndim=1] Fc = np.dot(F.T,c)            # input current (size N)
    
    # other dynamic variables to track
    cdef np.ndarray[double, ndim=1] rate = np.zeros(N)  # spike counts since the last recording
    cdef np.ndarray[double, ndim=1] r = np.zeros(N)     # exponentially filtered spike trains
    cdef double totr = 0                                # sum of r (tracked incrementally)
    cdef np.ndarray[double, ndim=1] Fu = np.zeros(N)    # threshold offset (effective threshold T - Fu)
    cdef np.ndarray[double, ndim=1] x = np.zeros(J)     # leaky integral of the control current
    cdef double Fx = 0
    cdef double Wr = 0
    
    # auxiliary variables (all loop indices typed so Cython emits plain C loops)
    cdef unsigned int i, j, n, m
    cdef double norm = 0
    cdef double divisor = 0
    cdef double logarg = 0
    cdef double decay = 0                               # exp(-yd*snext), computed once per step
    
    # simulation specific variables
    #   the simulation analytically computes the spike-times and steps to the next neuron that spikes.
    #   If this time is larger than no_spike_dt (e.g. if no neuron would spike because the leak is too strong
    #   and the signal too weak), then no spike is emitted and the simulation steps no_spike_dt forward
    cdef double no_spike_dt = 10
    
    #   connected auxiliary variables
    cdef unsigned int spike_flag = 0    # signals if there was a spike in a simulation step
    cdef double last_input_t = t0       # last time the input changed (t0, so restarts behave like fresh runs)
    
    # counts the number of spikes from beginning of simulation
    cdef unsigned int spikenum = 0
    
    # signal progress (float division: an integer division would give 0 for sim_T < 1000)
    cdef double progress_update_dt = sim_T/1000.
    cdef double last_progress_update_t = t0
    
    while t < sim_T + t0:   
        # compute the next spike-times
        snext, i = no_spike_dt, 0
        spike_flag = 0
                
        for j in range(N):
            # if (for whatever reason) the voltage is above the threshold right now, make the neuron spike.
            # spike_flag must be raised here: with snext = 0 and no spike nothing would change,
            # and the loop would stall on the same state forever.
            if V[j] > T[j] - Fu[j]:
                snext, i = 0, j
                spike_flag = 1
                break

            # compute the time at which the membrane voltage would hit the threshold (spike-time)
            # you have to be careful not to run in numerical issues: clamp |divisor| away from 0
            # while keeping its sign (an exactly-zero divisor is treated as positive)
            divisor = Fc[j] - yd*(T[j] - Fu[j])
            if abs(divisor) < 1e-10:
                if divisor < 0:
                    divisor = -1e-10
                else:
                    divisor = 1e-10
            
            logarg = (Fc[j] - yd*V[j])/divisor
            
            if logarg >= 1:   # neuron would spike in the future
                s[j] = 1/yd*log(logarg)    # spike time of neuron j
                # check if spike time is lower than all previously computed spike times
                if s[j] < snext:
                    snext, i = s[j], j
                    spike_flag = 1      # yes, there is a neuron that would like to spike!

        decay = exp(-yd*snext)
        
        # calculate voltage at time of spike
        for j in range(N):
            V[j] = Fc[j]/yd + decay*(V[j] - Fc[j]/yd)

        # propagate spikes
        if spike_flag == 1:
            for n in range(N):
                V[n] -= Wf[n,i]

        # update x
        for j in range(J):
            x[j] = c[j]/yd + decay*(x[j] - c[j]/yd)
                
        # update filtered rates
        for n in range(N):
            r[n] = decay*r[n]
        totr = decay*totr
            
        if spike_flag == 1:
            r[i] += 1
            totr += 1
            
        # update threshold offset u
        for n in range(N):
            Fx = 0
            for j in range(J):
                Fx += F[j,n]*x[j]
                
            Wr = 0
            for m in range(N):
                Wr += Wf[n,m]*r[m]
                
            Fu[n] += Ustep*(Fx - Wr - Fu[n])

        # integrate membrane currents (with leak)
        for n in range(N):
            I_int[n] = Fc[n]/yd + decay*(I_int[n] - Fc[n]/yd)

        # integrate control currents (without leak)
        for j in range(J):
            C_int[j] += c[j]*snext

        # plasticity (but wait some spikes to avoid init effects)
        if spikenum > 100 and spike_flag == 1:
            # update recurrent weights onto the spiking neuron's column
            for n in range(N):
                if n != i:
                    Wf[n,i] += Rstep*(I_int[n] - Wf[n,i] - mu2*r[n]/totr)

            # update feedforward weights
            for j in range(J):
                F[j,i] += Fstep*(C_int[j] - F[j,i])
                
            # normalize (skip a degenerate all-zero column to avoid dividing by zero)
            norm = 0
            for j in range(J):
                norm += F[j,i]**2
            norm = sqrt(norm)
            
            if norm > 0:
                for j in range(J):
                    F[j,i] *= Fnorm/norm
                
            # recompute product Fc (for given column)
            Fc[i] = 0
            for j in range(J):
                Fc[i] += F[j,i]*c[j]

        # periodically set input
        if t - last_input_t > inpDt:
            
            # randomly draw input
            for j in range(J):
                c[j] = np.random.rand()*3.464

            # recompute Fc
            for n in range(N):
                Fc[n] = 0
                for j in range(J):
                    Fc[n] += F[j,n]*c[j]
            
            last_input_t = t
        
        # set new time
        t = t + snext
        
        if spike_flag == 1:
            # increase the spike count and collect spike-numbers
            spikenum = spikenum + 1
            rate[i] += 1

            # reset current counters after a spike
            for j in range(J):
                C_int[j] = 0
            for n in range(N):
                I_int[n] = 0
        
        # collect data
        if t - last_rec_t > recDt:
            Wfdata.append(Wf.copy())
            Fdata.append(F.copy())            
            tdata.append(t)
            ratedata.append(rate.copy()/float(recDt))
            for n in range(N):
                rate[n] = 0
            last_rec_t = t
            
        # print progress
        if t - last_progress_update_t > progress_update_dt:
            print '\r', np.around((t-t0)/sim_T*100,1), '%',
            last_progress_update_t = t
           
    return V, Wf, F, I_int, C_int, t, Wfdata, Fdata, tdata, ratedata

# Reconstruction Error

The next simulation only measures the reconstruction error of a given set of parameters (the weights are not changed). It is less commented because it closely mirrors the simulation above, and you should not need to change anything here.

In [None]:
%%cython

cimport cython
cimport numpy as np
from libc.math cimport log, exp, sqrt, sin
    
import numpy as np

@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
def performance(np.ndarray[double, ndim=2] Wf, np.ndarray[double, ndim=1] T, np.ndarray[double, ndim=2] F, 
                double yd, double mu1, double mu2, double inpDt, unsigned int sim_T):
    '''
       Measure the coding performance of a fixed (non-learning) network.

       Wf and F are only read. The readout uses the decoder D = pinv(F).T . Wf;
       the error is integrated analytically between events.

       PARAMETERS
       ----------
       Wf    - lateral connections, shape (N, N)
       T     - thresholds, shape (N,)
       F     - feedforward connections, shape (J, N)
       yd    - leak rate
       mu1   - L1 constraint (not used by this function)
       mu2   - L2 constraint (not used by this function)
       inpDt - the input switches every inpDt
       sim_T - total simulation time

       RETURNS
       -------
       (spikes per neuron per unit time,
        time average of sum_j (x_j - xhat_j)**2,
        time average of sum_n r_n,
        time average of sum_n r_n**2)
    '''
    cdef unsigned int N = Wf.shape[0]                 # number of neurons
    cdef unsigned int J = F.shape[0]                  # number of inputs
    
    cdef double t = 0
    cdef double no_spike_dt = 10
    cdef unsigned int spike_flag = 0
    cdef double divisor = 0
    cdef double logarg = 0
    cdef double last_input_t = 0                  # remember the last time the input changed
    cdef double snext
    cdef double decay                             # exp(-yd*snext), computed once per step
    cdef double err, dx
    cdef unsigned int i, j, n
    
    # float division: an integer division would give 0 for sim_T < 1000
    cdef double progress_update_dt = sim_T/1000.
    cdef double last_progress_update_t = 0
    
    cdef np.ndarray[double, ndim=1] c = abs(np.random.randn(J))
    cdef np.ndarray[double, ndim=1] Fc = np.dot(F.T,c)
    cdef np.ndarray[double, ndim=1] s = np.empty(N)
    # r is read (cumr1/cumr2) before its first update, so it must start at zero
    cdef np.ndarray[double, ndim=1] r = np.zeros(N)
    
    cdef np.ndarray[double, ndim=1] x = np.zeros(J)    
    cdef np.ndarray[double, ndim=1] xhat = np.zeros(J)
    cdef np.ndarray[double, ndim=1] V = np.zeros(N)    
    
    cdef int cumrate = 0
    cdef double cumerror = 0
    cdef double cumr2 = 0
    cdef double cumr1 = 0
    
    # compute optimal decoder
    cdef np.ndarray[double, ndim=2] D = np.dot(np.linalg.pinv(F).T,Wf)
    
    while t < sim_T:   
        # compute the next spike-times; spike_flag is raised when some neuron spikes before no_spike_dt
        snext, i = no_spike_dt, 0
        spike_flag = 0
        
        for j in range(N):
            # clamp |divisor| away from 0 while keeping its sign (consistent with the learning simulator)
            divisor = Fc[j] - yd*T[j]
            if abs(divisor) < 1e-10:
                if divisor < 0:
                    divisor = -1e-10
                else:
                    divisor = 1e-10
            
            logarg = (Fc[j] - yd*V[j])/divisor
            
            if logarg > 1:
                s[j] = 1/yd*log(logarg)
                if s[j] < snext:
                    snext, i = s[j], j
                    spike_flag = 1
        
        # calculate rate
        if spike_flag:
            cumrate += 1

        # closed-form integral of (x - xhat)**2 over [t, t + snext]
        for j in range(J):
            dx = x[j] - xhat[j]
            err = 2*c[j]*(exp(yd*snext)-1)**2*dx*yd
            err += (exp(2*yd*snext) - 1)*dx**2*yd**2
            err += c[j]**2*(4*exp(yd*snext) - 1 + exp(2*yd*snext)*(2*yd*snext - 3))
            cumerror += exp(-2*snext*yd)*err/(2*yd**3)
            
        # closed-form integrals of r and r**2 over the same interval
        for n in range(N):
            cumr2 += (1 - exp(-2*yd*snext))*r[n]**2/(2*yd)
            cumr1 += (1 - exp(-yd*snext))*r[n]/yd

        decay = exp(-yd*snext)
            
        # update x and xhat
        for j in range(J):
            x[j] = c[j]/yd + decay*(x[j] - c[j]/yd)
            xhat[j] = decay*xhat[j]
            
            if spike_flag:
                xhat[j] += D[j,i]
                
        # update rate: decay the filtered spike trains (previously r[n] was overwritten by the decay factor)
        for n in range(N):
            r[n] = decay*r[n]
            
        if spike_flag:
            r[i] += 1
                
        # calculate voltage at time of spike
        for j in range(N):
            V[j] = Fc[j]/yd + decay*(V[j] - Fc[j]/yd)
            
            # avoid that noise pushes a neuron above the threshold
            if V[j] > T[j]:
                V[j] = T[j]
            
        if spike_flag == 1:
            # fix voltage of firing neuron to threshold (if it was changed by noise)
            V[i] = T[i]
            # propagate spike to the network
            for n in range(N):
                V[n] -= Wf[n,i]

        # periodically set input
        if t - last_input_t > inpDt:
            
            # randomly draw input
            for j in range(J):
                c[j] = np.random.rand()*3.464

            # recompute Fc
            for n in range(N):
                Fc[n] = 0
                for j in range(J):
                    Fc[n] += F[j,n]*c[j]
            
            last_input_t = t
        
        # set new time
        t = t + snext
        
        # print progress
        if t - last_progress_update_t > progress_update_dt:
            print '\r', np.around(t/sim_T*100,1), '%',
            last_progress_update_t = t
            
    return cumrate/(float(sim_T)*N), cumerror/float(sim_T), cumr1/float(sim_T), cumr2/float(sim_T)

In [None]:
def get_performance(system,T=1000000):
    '''Test the current weights of `system` with the `performance` simulator.

    T is the duration of the test simulation (not a threshold).
    Returns the (rate, error, r1, r2) tuple produced by `performance`.
    '''
    network_args = (system.Wf, system.T, system.F,
                    system.yd, system.mu1, system.mu2, system.inpDt)
    return performance(*(network_args + (T,)))

In [None]:
def FD_distance(system):
    '''Relative squared distance of the lateral weights from the span of F^T.

    The L2 reset term mu2*I is subtracted from Wf first; the remainder W is then
    compared with its reconstruction F^T (pinv(F)^T W). Returns
    ||W - F^T pinv(F)^T W||^2 / ||W||^2 (squared Frobenius norms).
    '''
    n_neurons = system.Wf.shape[0]
    W = system.Wf - np.eye(n_neurons)*system.mu2
    decoded = np.linalg.pinv(system.F).T.dot(W)     # transpose of the decoder D
    residual = W - system.F.T.dot(decoded)
    return np.sum(residual**2)/np.sum(W**2)

# Simulator

The simulator object essentially stores all relevant variables of the system; think of it as the "network object". It stores the thresholds, weights, collected data, L1/L2 constraints, etc. We first initialize this object with all relevant parameters and then call its run() method, which in turn calls the simulation function above.

In [None]:
class simulator(object):
    '''"Network object": weights, thresholds, dynamic state, parameters and recordings.

    Each call of run(Tsim) continues from the stored state (self.t, self.V,
    self.I_int, self.C_int) and appends the recorded data to the rec_* arrays.
    '''
    
    def __init__(self,Wf,T,F,yd,o,inpDt,mode='opt'):
        '''
        Wf    - lateral weights, shape (N, N); kept by reference and updated in place by run()
        T     - thresholds, shape (N,)
        F     - feedforward weights, shape (J, N); kept by reference and updated in place by run()
        yd    - leak rate
        o     - std of noise
        inpDt - interval between input switches
        mode  - accepted for backward compatibility; not used by this class
        '''
        # network structure
        self.Wf = Wf   # lateral
        self.F = F     # feedforward
        self.T = T     # threshold
        self.yd = yd   # leak
        self.mu2 = 0.
        self.mu1 = 0.
        
        # dynamical parameters
        self.o = o     # std of noise
        self.I_int = zeros_like(T)                     # state of integrated input current
        self.C_int = zeros(F.shape[0])                 # state of integrated control current
        self.V = zeros_like(T) + 1e-5*randn(*T.shape)  # state of voltage
        
        # input parameters
        self.inpDt = inpDt
        
        # recorded data; column/slice 0 holds the initial state
        self.rec_t = array([0])
        self.rec_Wf = Wf.copy()
        self.rec_F = F.copy()
        self.rec_rate = zeros_like(T)
        self.rec_ISI = array([0])
        
        # simulation/plasticity parameters
        self.t = 0
        self.Rstep = 1e-5    # default learning step for laterals
        self.Fstep = 1e-6    # default learning step for feedforward
        self.Ustep = 1e-6    # default learning step for the threshold offset
        self.recDt = 100     # default distance between data recordings 
        
    def run(self,Tsim):
        '''Simulate (with learning) for Tsim time units and append the recordings.

        If no recording happened during the run (e.g. Tsim <= recDt), the state is
        still updated but the rec_* arrays are left unchanged.
        '''
        if not hasattr(self, 'rec_Rstep'):
            self.rec_Rstep = array([self.Rstep])
        
        # run simulation (sim_T is a C unsigned int there; make the float -> int truncation explicit)
        V, Wf, F, I_int, C_int, t, Wf_data, F_data, t_data, rate_data = c_analytic_opt_simulate(
                self.Wf, self.T, self.F, self.I_int, self.C_int, self.V, self.yd, self.o,
                self.mu1, self.mu2, self.inpDt, self.recDt, self.Rstep, self.Fstep, self.Ustep,
                self.t, int(Tsim))

        # update internal values for next run
        self.V[:] = V
        self.Wf[:,:] = Wf
        self.F[:,:] = F
        self.I_int[:] = I_int
        self.C_int[:] = C_int
        self.t = t
        
        # nothing recorded: vstack/dstack would fail on empty lists (and leave rec_* inconsistent)
        if len(t_data) == 0:
            return
        
        # update stored values
        self.rec_t = hstack([self.rec_t, array(t_data)])
        self.rec_rate = vstack([self.rec_rate.T, rate_data]).T
        self.rec_Wf = dstack([self.rec_Wf, dstack(Wf_data)])
        self.rec_F = dstack([self.rec_F, dstack(F_data)])

# Simulation

In [None]:
# RANDOM SEED: seeds NumPy's global RNG, which is also used by the Cython simulators,
# so that a Restart & Run All reproduces the same network and inputs
SEED = 0
np.random.seed(SEED)

# NETWORK VARIABLES
N    = 5                     # number of neurons in network
I    = 2                     # dimension of input
yd   = 0.005                 # leak rate (decays enter as exp(-yd*t)); original note: time in milliseconds
beta = 0.05                  # quadratic (L2) constraint
o    = 1e-8                  # std of voltage noise
    
# INIT FEEDFORWARD CONNECTIVITY: random positive columns, each normalized to unit length
F = rand(I, N)
F /= sqrt(sum(F**2, axis=0))[None, :]

F_origin = F.copy()
Fnorm = sqrt(sum(F[:, 0]**2))

# save optimal recurrent weights for later use
opt_W = dot(F.T, F)

# random initial lateral connectivity
Wf = 0.5*rand(*opt_W.shape)*opt_W

# set correct resets (diagonal of Wf)
Wf[np.diag_indices_from(Wf)] = diag(opt_W) + beta

# INIT THRESHOLD
T = diag(opt_W)/2. + beta/2.

# NOTE(review): sim_T is redefined by the training cell and Dt is not used in the visible cells
sim_T = 1000
Dt = 10

In [None]:
# initialize network
system = simulator(Wf.copy(),T.copy(),F.copy(),yd,o,inpDt=10,mode='opt')

# some data collections
errors = []
times = []
objective = []
FD = []

In [None]:
# learning rate of the lateral (recurrent) connections
system.Rstep = 1.0e-5

In [None]:
Tsim = 10000000

# number of training segments; performance is tested before each one and once more at the end
Ntest = 10

# duration of one training segment
sim_T = Tsim/float(Ntest)

# Parameters do not change between segments, so set them once and BEFORE the first test:
# otherwise the first measurement used a different leak/L2 weight than all later ones, and
# re-running this cell gave a different first measurement (hidden state).
system.yd = 0.0005
system.o = 1e-7
system.recDt = sim_T/5
system.Fstep = 1e-6
system.mu1 = 0*beta        # no L1 constraint
system.mu2 = beta


def record_performance(system, T_test=100000):
    '''Test `system` and append error, time, objective and FD distance to the history lists.'''
    rate, error, r1, r2 = get_performance(system, T=T_test)
    errors.append(error)
    times.append(system.t)
    objective.append(error + system.mu1*r1 + system.mu2*r2)
    FD.append(FD_distance(system))


# start simulation
for trial in range(Ntest):
    print(trial)
    
    # compute reconstruction error and distance to FD before this segment
    record_performance(system)
    
    # train system
    system.run(sim_T)
    
    print('finished')

# also evaluate the state reached after the last training segment
record_performance(system)

In [None]:
figure(figsize=(18,10))
subplots_adjust(hspace=0.3)

suptitle('5 neurons / 2 inputs / positive / feedforward / low leak / no L1 / medium L2 / ', fontsize=16)

# color
syscol = cm.Blues(0.9)

labels = 'Optimal'

# number of neurons taken from the simulated system rather than the global N
n_neurons = system.Wf.shape[0]


def find_time_slice(t_start,t_end,time_axis):
    '''Return (i_start, i_end) indices into the sorted `time_axis` covering [t_start, t_end].

    t_end == -1 means "up to the last sample": i_end is then None, so that a slice
    [i_start:i_end] keeps the last recording (an end index of -1 would drop it).
    '''
    if t_end == -1:
        i_end = None
    else:
        i_end = searchsorted(time_axis,t_end)
    i_start = searchsorted(time_axis,t_start)
    return i_start, i_end


def smooth(trace):
    '''Gaussian-smooth a trace with sigma = len//50 samples.

    The trace is returned unchanged when sigma would be 0; recent SciPy versions
    cannot build a Gaussian kernel for sigma == 0.
    '''
    sigma = trace.shape[0]//50
    if sigma == 0:
        return trace
    return gaussian_filter1d(trace, sigma)


t_start, t_end = 0, -1

start, end = find_time_slice(t_start,t_end,system.rec_t)
time = system.rec_t/float(1e7)

subplot(231)
title('Distance to optimal recurrent weights')

def get_distance(system):
    '''Relative squared distance of every recorded Wf from F^T F + mu2*I.'''
    n = system.Wf.shape[0]
    Wf_opt = einsum('kit,kjt->ijt',system.rec_F[:,:,start:end],system.rec_F[:,:,start:end]) + eye(n)[:,:,None]*system.mu2
    distance = sum(sum((system.rec_Wf[:,:,start:end] - Wf_opt)**2,axis=0),axis=0)/sum(sum(Wf_opt**2,axis=0),axis=0)
    return distance

distance = get_distance(system)
ppl.plot(time[start:end],distance,linewidth=3, color=syscol)

xlabel('time [1e7]')

subplot(232)
title('Population rate')

def get_pop_rate(system):
    '''Smoothed population-averaged rate; column 0 (the all-zero initial entry) is skipped.'''
    return smooth(mean(system.rec_rate[:,start+1:end],axis=0))

poprate = get_pop_rate(system)
ppl.plot(time[start+1:end],poprate,linewidth=3,c=syscol)

for n in range(n_neurons):
    nrate = smooth(system.rec_rate[n,start+1:end])
    ppl.plot(time[start+1:end],nrate,linewidth=1,c=syscol,alpha=0.3)
        
xlabel('time [1e7]')

subplot(233)
title('Optimal vs Learned weights')

# target weights under their own name, so the setup cell's opt_W is not overwritten
Wf_target = dot(system.F.T,system.F) + system.mu2*eye(system.F.shape[1])
ppl.scatter(Wf_target.flatten(),system.Wf.flatten(),facecolor=syscol)

subplot(234)
title('Reconstruction error')

ppl.plot(array(times[1:])/float(1e7),errors[1:],linewidth=3, color=syscol, label=labels)

xlabel('time [1e7]')
ax = gca()

subplot(235)
title('Distance to FD')

ppl.plot(array(times[1:])/float(1e7),FD[1:],linewidth=3, color=syscol)

xlabel('time [1e7]')

# draw legend
lg = figlegend(*ax.get_legend_handles_labels(),loc='lower right')
lg.set_frame_on(False)

subplot(236,polar=True)
title('Receptive Fields')

for n in range(n_neurons):
    f = system.F[:,n]
    r, theta = polar(f[0] + f[1]*1j)
    ppl.plot([0,theta],[0,r],c=syscol,linewidth=1.5)

# hide all ticks and tick labels; use real booleans ('off' strings are not valid booleans in newer matplotlib)
tick_params(axis='both', which='both',
            bottom=False, top=False, left=False, right=False,
            labelbottom=False, labeltop=False, labelleft=False, labelright=False)

# shift the axis a bit down
bbox=gca().get_position()
gca().set_position([bbox.x0, bbox.y0-0.02, bbox.x1-bbox.x0, bbox.y1-bbox.y0])

#savefig('positive_5_2_feedforward_mediumL2_lowLeak.pdf', bbox_inches='tight')

show()

In [None]:
# NOTE(review): this cell repeats the summary figure of the previous plotting cell
figure(figsize=(18,10))
subplots_adjust(hspace=0.3)

suptitle('5 neurons / 2 inputs / positive / feedforward / low leak / no L1 / medium L2 / ', fontsize=16)

# color
syscol = cm.Blues(0.9)

labels = 'Optimal'

# number of neurons taken from the simulated system rather than the global N
n_neurons = system.Wf.shape[0]


def find_time_slice(t_start,t_end,time_axis):
    '''Return (i_start, i_end) indices into the sorted `time_axis` covering [t_start, t_end].

    t_end == -1 means "up to the last sample": i_end is then None, so that a slice
    [i_start:i_end] keeps the last recording (an end index of -1 would drop it).
    '''
    if t_end == -1:
        i_end = None
    else:
        i_end = searchsorted(time_axis,t_end)
    i_start = searchsorted(time_axis,t_start)
    return i_start, i_end


def smooth(trace):
    '''Gaussian-smooth a trace with sigma = len//50 samples.

    The trace is returned unchanged when sigma would be 0; recent SciPy versions
    cannot build a Gaussian kernel for sigma == 0.
    '''
    sigma = trace.shape[0]//50
    if sigma == 0:
        return trace
    return gaussian_filter1d(trace, sigma)


t_start, t_end = 0, -1

start, end = find_time_slice(t_start,t_end,system.rec_t)
time = system.rec_t/float(1e7)

subplot(231)
title('Distance to optimal recurrent weights')

def get_distance(system):
    '''Relative squared distance of every recorded Wf from F^T F + mu2*I.'''
    n = system.Wf.shape[0]
    Wf_opt = einsum('kit,kjt->ijt',system.rec_F[:,:,start:end],system.rec_F[:,:,start:end]) + eye(n)[:,:,None]*system.mu2
    distance = sum(sum((system.rec_Wf[:,:,start:end] - Wf_opt)**2,axis=0),axis=0)/sum(sum(Wf_opt**2,axis=0),axis=0)
    return distance

distance = get_distance(system)
ppl.plot(time[start:end],distance,linewidth=3, color=syscol)

xlabel('time [1e7]')

subplot(232)
title('Population rate')

def get_pop_rate(system):
    '''Smoothed population-averaged rate; column 0 (the all-zero initial entry) is skipped.'''
    return smooth(mean(system.rec_rate[:,start+1:end],axis=0))

poprate = get_pop_rate(system)
ppl.plot(time[start+1:end],poprate,linewidth=3,c=syscol)

for n in range(n_neurons):
    nrate = smooth(system.rec_rate[n,start+1:end])
    ppl.plot(time[start+1:end],nrate,linewidth=1,c=syscol,alpha=0.3)
        
xlabel('time [1e7]')

subplot(233)
title('Optimal vs Learned weights')

# target weights under their own name, so the setup cell's opt_W is not overwritten
Wf_target = dot(system.F.T,system.F) + system.mu2*eye(system.F.shape[1])
ppl.scatter(Wf_target.flatten(),system.Wf.flatten(),facecolor=syscol)

subplot(234)
title('Reconstruction error')

ppl.plot(array(times[1:])/float(1e7),errors[1:],linewidth=3, color=syscol, label=labels)

xlabel('time [1e7]')
ax = gca()

subplot(235)
title('Distance to FD')

ppl.plot(array(times[1:])/float(1e7),FD[1:],linewidth=3, color=syscol)

xlabel('time [1e7]')

# draw legend
lg = figlegend(*ax.get_legend_handles_labels(),loc='lower right')
lg.set_frame_on(False)

subplot(236,polar=True)
title('Receptive Fields')

for n in range(n_neurons):
    f = system.F[:,n]
    r, theta = polar(f[0] + f[1]*1j)
    ppl.plot([0,theta],[0,r],c=syscol,linewidth=1.5)

# hide all ticks and tick labels; use real booleans ('off' strings are not valid booleans in newer matplotlib)
tick_params(axis='both', which='both',
            bottom=False, top=False, left=False, right=False,
            labelbottom=False, labeltop=False, labelleft=False, labelright=False)

# shift the axis a bit down
bbox=gca().get_position()
gca().set_position([bbox.x0, bbox.y0-0.02, bbox.x1-bbox.x0, bbox.y1-bbox.y0])

#savefig('positive_5_2_feedforward_mediumL2_lowLeak.pdf', bbox_inches='tight')

show()