# STDP Finds the Start of Repeating Patterns in Continuous Spike Trains

In this notebook we reproduce the experiments described in [Masquelier, Guyonneau & Thorpe (2008)](https://www.semanticscholar.org/paper/Spike-Timing-Dependent-Plasticity-Finds-the-Start-Masquelier-Guyonneau/432b5bfa6fc260289fef45544a43ebcd8892915e).

In [None]:
# These imports will be used in the notebook
from __future__ import print_function

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

# Default size (in inches) of every figure drawn in this notebook
plt.rcParams.update({"figure.figsize": (12, 9)})

## LIF neuron model

The LIF neuron model used in this experiment is based on Gerstner's [Spike Response Model](http://lcn.epfl.ch/~gerstner/SPNM/node26.html#SECTION02311000000000000000).

At every time-step, the neuron membrane potential p is given by the formula:

$$p=\eta(t-t_{i})+\sum_{j|t_{j}>t_{i}}w_{j}\varepsilon(t-t_{j})$$

where $\eta(t-t_{i})$ is the membrane response after a spike at time $t_{i}$, scaled by the spiking threshold $T$:

$$\eta(t-t_{i})=T\left(K_{1}exp(-\frac{t-t_{i}}{\tau_{m}})-K_{2}(exp(-\frac{t-t_{i}}{\tau_{m}})-exp(-\frac{t-t_{i}}{\tau_{s}}))\right)$$

and $\varepsilon(t)$ describes the Excitatory Post-Synaptic Potential of each synapse spike at time $t_{j}$:

$$\varepsilon(t-t_{j})=K(exp(-\frac{t-t_{j}}{\tau_{m}})-exp(-\frac{t-t_{j}}{\tau_{s}}))$$

Note that K has to be chosen so that the max of $\varepsilon(t)$ is 1, knowing that $\varepsilon(t)$ is maximum when:
$$t=\frac{\tau_{m}\tau_{s}}{\tau_{m}-\tau_{s}}ln(\frac{\tau_{m}}{\tau_{s}})$$

In this simplified version of the neuron, the synaptic weights $w_{j}$ remain constant.

In [None]:
class LIFNeuron(object):
    """Leaky Integrate-and-Fire neuron built on Gerstner's Spike Response Model.

    The neuron is expressed as a TensorFlow 1.x graph. The op returned by
    `response()` must be run with `new_spikes` (one boolean per synapse, True
    if that synapse spiked during the step) and `dt` (the step duration) fed.
    It ages the stored input spike times, records the new input spikes,
    evaluates the membrane potential and handles firing and the resting
    (refractory) period. In this base class the synaptic weights stay constant.

    Args:
        n_syn: number of synapses.
        W: initial synaptic efficacy weights (one per synapse).
        max_spikes: number of past input spikes remembered per synapse
            (depth of the spike-time ring buffer; 70 when None).
        p_rest: membrane resting potential.
        tau_rest: duration of the resting period after a neuron spike.
        tau_m: membrane time constant.
        tau_s: synaptic time constant.
        T: spiking threshold (n_syn / 4 when None).
        K: EPSP kernel scaling constant.
        K1, K2: spike response (eta) kernel constants.
    """

    # Relative time given to empty spike-memory slots and to synapses that did
    # not spike: far enough in the past to be effectively ignored.
    NO_SPIKE_TIME = 100000.0

    def __init__(self,
                 n_syn, W, max_spikes=None,
                 p_rest=0.0, tau_rest=1.0, tau_m=10.0, tau_s=2.5, T=None,
                 K=2.1, K1=2.0, K2=4.0):

        # Model parameters

        # Membrane resting potential
        self.p_rest = p_rest

        # Duration of the recovery (resting) period
        self.tau_rest = tau_rest

        # Membrane time constant
        self.tau_m = tau_m

        # Synaptic time constant
        self.tau_s = tau_s

        # Spiking threshold
        self.T = n_syn / 4 if T is None else T

        # Model constants
        self.K = K
        self.K1 = K1
        self.K2 = K2

        # The number of synapses
        self.n_syn = n_syn

        # The synapse efficacy weights
        self.w = tf.Variable(W)

        # The incoming spike times memory window (ring buffer depth)
        self.max_spikes = 70 if max_spikes is None else max_spikes

        # Placeholders (ie things that are fed to the graph at runtime)

        # A boolean tensor indicating which synapses have spiked during dt.
        # Sized from n_syn, not from a notebook-level global.
        self.new_spikes = tf.placeholder(shape=[self.n_syn], dtype=tf.bool, name='new_spikes')

        # The time increment since the last update
        self.dt = tf.placeholder(dtype=tf.float32, name='dt')

        # Variables (ie things that are modified by the graph at runtime)

        # The neuron memory of incoming spike times, stored as ages relative to
        # now: one row per ring-buffer slot, one column per synapse
        self.t_spikes = tf.Variable(tf.constant(self.NO_SPIKE_TIME, shape=[self.max_spikes, self.n_syn]),
                                    dtype=tf.float32)

        # The last spike time insertion index (a row of t_spikes). It starts on
        # the last row so that the first update writes row 0, and it always
        # stays within [0, max_spikes).
        self.t_spikes_idx = tf.Variable(self.max_spikes - 1, dtype=tf.int32)

        # The relative time since the last spike (assume it was a very long time ago)
        self.last_spike = tf.Variable(1000.0, dtype=tf.float32, name='last_spike')

        # The membrane potential
        self.p = tf.Variable(self.p_rest, dtype=tf.float32, name='p')

        # The duration remaining in the resting period (between 0 and self.tau_rest)
        self.t_rest = tf.Variable(0.0, dtype=tf.float32, name='t_rest')

    def epsilon_op(self):
        """Return the EPSP kernel evaluated at every stored input spike time."""

        # The stored times are ages (time elapsed since each spike), so the
        # kernel exponents use their negative values
        spikes_t_op = tf.negative(self.t_spikes)

        return self.K * (tf.exp(spikes_t_op / self.tau_m) - tf.exp(spikes_t_op / self.tau_s))

    def eta_op(self):
        """Return the membrane response to the neuron's own last spike."""

        # We only use the negative value of the relative time
        t_op = tf.negative(self.last_spike)

        # Evaluate the spiking positive pulse
        pos_pulse_op = self.K1 * tf.exp(t_op / self.tau_m)

        # Evaluate the negative spike after-potential
        neg_after_op = self.K2 * (tf.exp(t_op / self.tau_m) - tf.exp(t_op / self.tau_s))

        # The response is scaled by the spiking threshold
        return self.T * (pos_pulse_op - neg_after_op)

    def integrating_p_op(self):
        """Return the membrane potential while integrating (t_rest == 0)."""

        # Evaluate synaptic EPSPs. We ignore synaptic spikes older than the last neuron spike
        epsilons_op = tf.where(tf.logical_and(self.t_spikes >= 0, self.t_spikes < self.last_spike),
                               self.epsilon_op(),
                               tf.zeros_like(self.t_spikes))

        # Spike response plus weighted EPSPs (the per-synapse weights broadcast
        # over the history axis of the spike memory)
        return self.eta_op() + tf.reduce_sum(self.w * epsilons_op)

    def resting_p_op(self):
        """Return the membrane potential while resting (t_rest > 0)."""

        # Membrane potential is only impacted by the last post-synaptic spike (ignore EPSPs)
        return self.eta_op()

    def update_spikes_times(self):
        """Age the stored input spikes by dt and record the spikes fed in `new_spikes`.

        Returns the op updating `t_spikes`; building it also advances
        `t_spikes_idx` (modulo max_spikes).
        """

        # Increase the age of all the existing spikes by dt
        old_spikes_op = self.t_spikes.assign_add(tf.ones(tf.shape(self.t_spikes), dtype=tf.float32) * self.dt)

        # Increment last spike index (modulo max_spikes)
        new_idx_op = self.t_spikes_idx.assign(tf.mod(self.t_spikes_idx + 1, self.max_spikes))

        # Coordinates (new_idx, j) for every synapse j
        idx_op = tf.constant(1, shape=[self.n_syn], dtype=tf.int32) * new_idx_op
        coord_op = tf.stack([idx_op, tf.range(self.n_syn)], axis=1)

        # New spike times: 0 for synapses that spiked, "no spike" otherwise
        new_spikes_op = tf.where(self.new_spikes,
                                 tf.constant(0.0, shape=[self.n_syn]),
                                 tf.constant(self.NO_SPIKE_TIME, shape=[self.n_syn]))

        # Overwrite the oldest row of the ring buffer with the new spikes
        return tf.scatter_nd_update(old_spikes_op, coord_op, new_spikes_op)

    def integrating_w_op(self):
        """Return the weights while integrating (constant for the base LIF neuron)."""
        return tf.identity(self.w)

    def firing_w_op(self):
        """Return the weights when firing (constant for the base LIF neuron)."""
        return tf.identity(self.w)

    def response(self):
        """Build the op advancing the neuron by one time step.

        The op must be run with `new_spikes` and `dt` fed. It ages the stored
        input spikes and the last neuron spike by dt, evaluates the new
        membrane potential, updates the weights (firing_w_op when the potential
        exceeds T, integrating_w_op otherwise), resets `last_spike` and starts
        the resting period when firing, and returns the assignment of the new
        potential to `p`.
        """

        # Update our internal memory of the synapse spikes (age older spike, add new ones)
        update_spikes_op = self.update_spikes_times()

        # Increase the relative time of the last spike by the time elapsed
        last_spike_age_op = self.last_spike.assign_add(self.dt)

        # Evaluate the new membrane potential, making sure the synapse spikes and last spike time are updated first
        with tf.control_dependencies([update_spikes_op, last_spike_age_op]):
            p_op = tf.cond(self.t_rest > 0.0,
                           self.resting_p_op,     # The neuron is resting, and synaptic input is ignored
                           self.integrating_p_op) # Integrate both synaptic input and spike dynamics

        # Update weights
        w_op = tf.cond(p_op > self.T,
                       self.firing_w_op,      # The neuron is firing
                       self.integrating_w_op) # Normal behavior

        # Update the time of the last spike, but only once the weights have been updated
        with tf.control_dependencies([w_op]):
            last_spike_op = tf.cond(p_op > self.T,
                                    lambda: self.last_spike.assign(0.0),  # The neuron is firing, the last spike is now
                                    lambda: tf.identity(self.last_spike)) # Nothing to do
        # Update the resting period
        t_rest_op = tf.cond(p_op > self.T,
                            lambda: self.t_rest.assign(self.tau_rest),    # The neuron is firing, start resting period
                            lambda: self.t_rest.assign(tf.maximum(self.t_rest - self.dt, 0.0))) # Decrease resting period

        # We finally update the internal membrane potential after the resting period
        # and last spike times have been updated
        with tf.control_dependencies([t_rest_op, last_spike_op]):
            return self.p.assign(p_op)

## Stimulate neuron with predefined synapse input

We replicate Figure 3 of the original paper by stimulating a LIF neuron with six input spikes on a single synapse.

The neuron has a refractory period of 1 ms and a threshold of 1.


In [None]:
# Test neuron response with constant synaptic weights

# Start from a clean graph so that re-running this cell does not keep piling up
# the variables and ops of previously created neurons (which the initializer
# below would otherwise re-initialise every time)
tf.reset_default_graph()

# Duration of the simulation in ms
T = 80
# Duration of each time step in ms
dt = 1.0
# Number of iterations = T/dt
steps = int(T / dt)
# Number of synapses
m = 1
# Spiking times
spikes = [2.0, 23.0, 44.0, 45.0, 48.0, 61.0]
# We define the base synaptic efficacy as a uniform vector
W = np.full((m), 0.475, dtype=np.float32)
# Output variables: (time, membrane potential) pairs
P = []

with tf.Session() as sess:

    neuron = LIFNeuron(m, W, T=1)

    # Build the per-step update op before initialising the variables
    response = neuron.response()

    sess.run(tf.global_variables_initializer())

    for step in range(steps):

        t = step * dt
        syn_has_spiked = [t in spikes]
        feed = {neuron.new_spikes: syn_has_spiked, neuron.dt: dt}
        p = sess.run(response, feed_dict=feed)
        P.append((t, p))

In [None]:
# Draw the membrane potential trace together with the threshold (solid red),
# the resting potential (dashed yellow) and the input spike times (dashed grey)
plt.figure()
plt.plot(*zip(*P))
for level, colour, style in ((neuron.T, 'r', '-'), (neuron.p_rest, 'y', '--')):
    plt.axhline(y=level, color=colour, linestyle=style)
for spike_time in spikes:
    plt.axvline(x=spike_time, color='gray', linestyle='--')
plt.title('LIF response')
plt.ylabel('Membrane Potential (mV)')
plt.xlabel('Time (msec)')

As in the original paper, we see that because of the leaky nature of the neuron, the stimulating spikes have to be nearly synchronous for the threshold to be reached.

## Simulate neuron response with random input

We feed the neuron with 2000 synapses that generate spikes at random intervals with a frequency of 45 Hz.

The synaptic efficacy weights are arbitrarily set to 0.475 and remain constant throughout the simulation.

In [None]:
# Simulation with constant synaptic weights

# Start from a clean graph so that re-running this cell does not accumulate
# the variables and ops of previously created neurons
tf.reset_default_graph()

# Duration of the simulation in ms
T = 200
# Duration of each time step in ms
dt = 1.0
# Number of iterations = T/dt
steps = int(T / dt)
# Number of synapses
m = 2000
# Spiking probability per ms for each synapse (45 Hz = 0.045 spikes/ms)
f = 4.5e-2
# We need to keep track of input spikes over time
# (builtin bool: the np.bool alias was removed in NumPy 1.24)
spikes = np.zeros((steps, m), dtype=bool)
# We define the base synaptic efficacy as a uniform vector
W = np.full((m), 0.475, dtype=np.float32)
# Output variables: (time, membrane potential) pairs
P = []

with tf.Session() as sess:

    neuron = LIFNeuron(m, W)

    # Build the per-step update op before initialising the variables
    response = neuron.response()

    sess.run(tf.global_variables_initializer())

    for step in range(steps):

        t = step * dt
        r = np.random.uniform(0, 1, size=(m))
        syn_has_spiked = r < f * dt
        spikes[step, :] = syn_has_spiked
        feed = {neuron.new_spikes: syn_has_spiked, neuron.dt: dt}
        p = sess.run(response, feed_dict=feed)
        P.append((t, p))

We draw the neuron membrane response to the 2000 random synaptic spike trains. We can see that the neuron mostly saturates and continuously generates spikes.

In [None]:
# Draw input spikes: one dot per (time step, 1-based synapse index) that spiked
spike_timings, spike_index = np.nonzero(spikes > 0)
spike_index = spike_index + 1
plt.figure()
plt.axis([0, T, 0, m])
plt.title('Synaptic spikes')
plt.ylabel('spikes')
plt.xlabel('Time (msec)')
plt.scatter(spike_timings, spike_index, s=2)
# Draw membrane potential with threshold (solid red) and resting potential (dashed yellow)
plt.figure()
plt.plot(*zip(*P))
for level, colour, style in ((neuron.T, 'r', '-'), (neuron.p_rest, 'y', '--')):
    plt.axhline(y=level, color=colour, linestyle=style)
plt.title('LIF response')
plt.ylabel('Membrane Potential (mV)')
plt.xlabel('Time (msec)')

## Introduce Spike Timing Dependent Plasticity

We extend the LIFNeuron by allowing it to modify its synapse weights using a Spike Timing Dependent Plasticity algorithm.

The STDP algorithm rewards synapses where spikes occurred immediately before a neuron spike, and inflicts penalties on synapses where spikes occur after the neuron spike.

The 'rewards' are called Long Term synaptic Potentiation (LTP), and the penalties Long Term synaptic Depression (LTD).

For each synapse that spiked $\Delta{t}$ before a neuron spike:

$$\Delta{w} = a^{+}exp(-\frac{\Delta{t}}{\tau^{+}})$$

For each synapse that spikes $\Delta{t}$ after a neuron spike:

$$\Delta{w} = -a^{-}exp(-\frac{\Delta{t}}{\tau^{-}})$$ 

In [None]:
class STDPLIFNeuron(LIFNeuron):
    """LIF neuron whose synaptic weights evolve through Spike Timing Dependent Plasticity.

    When the neuron fires, synapses that spiked since its previous spike are
    potentiated (LTP). While integrating, synapses that spike shortly after a
    neuron spike are depressed (LTD). Weights are always clamped to [0, 1].

    Args:
        n_syn, W, max_spikes, p_rest, tau_rest, tau_m, tau_s, T, K, K1, K2:
            forwarded to LIFNeuron.
        a_plus: LTP amplitude.
        a_minus: LTD amplitude.
        tau_plus: LTP time constant.
        tau_minus: LTD time constant.
    """

    # LTD is only applied while the time since the last neuron spike is below
    # LTD_WINDOW_FACTOR * tau_minus. Past that point the penalty has decayed
    # below a_minus * exp(-7). Applying it unconditionally would slowly drive
    # the weights to zero when the neuron stops firing.
    LTD_WINDOW_FACTOR = 7

    def __init__(self,
                 n_syn, W, max_spikes=None,
                 p_rest=0.0, tau_rest=1.0, tau_m=10.0, tau_s=2.5, T=None,
                 K=2.1, K1=2.0, K2=4.0,
                 a_plus=0.03125, a_minus=0.0265625, tau_plus=16.8, tau_minus=33.7):

        # Call the parent constructor
        super(STDPLIFNeuron, self).__init__(n_syn, W, max_spikes,
                                            p_rest, tau_rest, tau_m, tau_s, T,
                                            K, K1, K2)

        self.a_plus = a_plus
        self.tau_plus = tau_plus
        self.a_minus = a_minus
        self.tau_minus = tau_minus

    def LTP_op(self):
        """Return the op applying Long Term synaptic Potentiation (clamped to [0, 1])."""

        # Reward all spikes in our memory that happened before the new spike,
        # but after the previous one. The reward decays exponentially with the
        # spike age. The scalar amplitude broadcasts over the whole memory.
        rewards_op = tf.where(self.t_spikes < self.last_spike,
                              self.a_plus * tf.exp(tf.negative(self.t_spikes / self.tau_plus)),
                              tf.zeros_like(self.t_spikes))

        # Accumulate rewards for each synapse along the history axis
        acc_rewards_op = tf.reduce_sum(rewards_op, 0)

        # Update with new weights clamped to [0,1]
        return self.w.assign(tf.clip_by_value(self.w + acc_rewards_op, 0.0, 1.0))

    def LTD_op(self):
        """Return the op applying Long Term synaptic Depression (clamped to [0, 1])."""

        # Gather all spikes corresponding to the last insertion index
        new_spikes_op = tf.gather(self.t_spikes, self.t_spikes_idx)

        # Penalty for a spike arriving now, decaying exponentially with the
        # time elapsed since the last neuron spike
        penalty_op = self.a_minus * tf.exp(tf.negative(self.last_spike / self.tau_minus))

        # Only synapses that spiked during this step are penalised (older or
        # empty slots at this index have positive times)
        penalties_op = tf.where(new_spikes_op <= 0.0,
                                penalty_op * tf.ones_like(new_spikes_op),
                                tf.zeros_like(new_spikes_op))

        # Update with new weights clamped to [0,1]
        return self.w.assign(tf.clip_by_value(self.w - penalties_op, 0.0, 1.0))

    def stdp_firing_op(self):
        """Return an op applying LTP, starting the resting period and resetting p.

        NOTE(review): this op is not used by `response()`, which goes through
        `firing_w_op`. It also sets `last_spike` to dt, whereas `response()`
        resets it to 0.0. Confirm which is intended before relying on it.
        """

        # Apply long-term synaptic potentiation
        ltp_op = self.LTP_op()

        # Refractory period starts now
        t_rest_op = self.t_rest.assign(self.tau_rest)

        # Reset last spike time
        last_spike_op = tf.assign(self.last_spike, self.dt)

        # Explicitly mark operations not used by p_op as dependencies
        with tf.control_dependencies([ltp_op, t_rest_op, last_spike_op]):
            # Reset membrane potential
            p_op = self.p.assign(self.eta_op())

            return p_op

    def firing_w_op(self):
        """When the neuron fires, potentiate the synapses that contributed."""
        return self.LTP_op()

    def integrating_w_op(self):
        """While integrating, depress synapses spiking shortly after a neuron spike."""
        # Apply long-term synaptic depression only while we are still close to
        # the last spike (see LTD_WINDOW_FACTOR)
        return tf.cond(self.last_spike < self.tau_minus * self.LTD_WINDOW_FACTOR,
                       self.LTD_op,
                       lambda: tf.identity(self.w))

## Test STDP with random input

We apply a random input to an STDP capable LIFNeuron with a limited number of synapses, and draw the resulting rewards (green) and penalties (red).

In [None]:
# Simulation with evolving synaptic weights

# Start from a clean graph so that re-running this cell does not accumulate
# the variables and ops of previously created neurons
tf.reset_default_graph()

# Duration of the simulation in ms
T = 100
# Duration of each time step in ms
dt = 1.0
# Number of iterations = T/dt
steps = int(T / dt)
# Number of synapses
m = 20
# Spiking probability per ms for each synapse (45 Hz = 0.045 spikes/ms)
f = 4.5e-2
# We need to keep track of input spikes over time
spikes = np.zeros((steps, m), dtype=np.float32)
# We define the base synaptic efficacy as a uniform vector
W = np.full((m), 0.475, dtype=np.float32)
# Output variables: (time, membrane potential) pairs and per-step weight changes
P = []
delta_weights = np.zeros((steps, m))

with tf.Session() as sess:

    neuron = STDPLIFNeuron(m, W)

    # Build the per-step update op before initialising the variables
    response = neuron.response()

    sess.run(tf.global_variables_initializer())

    w_prev = W
    for step in range(steps):

        t = step * dt
        r = np.random.uniform(0, 1, size=(m))
        syn_has_spiked = r < f * dt
        spikes[step, :] = syn_has_spiked
        feed = {neuron.new_spikes: syn_has_spiked, neuron.dt: dt}
        p = sess.run(response, feed_dict=feed)
        P.append((t, p))
        # Record how much each weight changed during this step
        # (positive = STDP reward, negative = STDP penalty)
        w_next = sess.run(neuron.w)
        delta_weights[step, :] = w_next - w_prev
        w_prev = w_next

In [None]:
plt.rcParams.update({"figure.figsize": (12, 9)})
# Coordinates (time step, 1-based synapse index) of the input spikes,
# of the rewarded weights and of the penalised weights
spike_timings, spike_index = np.nonzero(spikes > 0)
spike_index = spike_index + 1
rewards_timings, rewards_index = np.nonzero(delta_weights > 0)
rewards_index = rewards_index + 1
penalties_timings, penalties_index = np.nonzero(delta_weights < 0)
penalties_index = penalties_index + 1
# Draw input spikes (big dots), rewards (light green) and penalties (red)
plt.figure()
plt.axis([0, T, 0, m])
plt.title('Synaptic spikes')
plt.ylabel('spikes')
plt.xlabel('Time (msec)')
plt.scatter(spike_timings, spike_index, s=100)
plt.scatter(rewards_timings, rewards_index, color='lightgreen')
plt.scatter(penalties_timings, penalties_index, color='red')
# Draw membrane potential with threshold (solid red) and resting potential (dashed yellow)
plt.figure()
plt.plot(*zip(*P))
for level, colour, style in ((neuron.T, 'r', '-'), (neuron.p_rest, 'y', '--')):
    plt.axhline(y=level, color=colour, linestyle=style)
plt.title('LIF response')
plt.ylabel('Membrane Potential (mV)')
plt.xlabel('Time (msec)')

On the graph above, we verify that the rewards (green dots) are assigned only when the neuron spikes, and that they are assigned to synapses where a spike occurred before the neuron spike (big blue dots).

Note: a reward is assigned even if the synapse spike is not synchronous with the neuron spike, but it will be lower.

We also verify that a penalty (red dot) is inflicted on every synapse where a spike occurs after a neuron spike.

Note: these penalties may later be counter-balanced by a reward if a neuron spike closely follows.

## Generate recurrent spike trains

We don't follow exactly the same procedure as in the original paper, as the evolution of the hardware and software allows us to generate spike trains more easily. The result, however, is equivalent.

We generate 2000 spike trains, of which we force the first 1000 to repeat a 50 ms pattern at random intervals.

We first define a random 50ms sequence, that will be used as input when the pattern is played.

We then generate random spike trains at every time-step: for the whole population if we are outside the pattern, for half of it otherwise.

The start of the next pattern is chosen among the following 50 ms slices, each successive slice being selected with a probability of 0.25 (the slice immediately following a pattern is skipped, to avoid consecutive patterns).

In [None]:
# Simulation with recurrent pattern

# Start from a clean graph so that re-running this cell does not accumulate
# the variables and ops of previously created neurons
tf.reset_default_graph()

# Duration of the simulation in ms
T = 15000
# Duration of each time step in ms
dt = 1.0
# Number of iterations = T/dt
steps = int(T / dt)
# Number of synapses
m = 2000
# Spiking probability per ms for each synapse (45 Hz = 0.045 spikes/ms)
f = 4.5e-2
# Duration of the repeating pattern in ms
pattern_len = 50
# We define the base synaptic efficacy as a uniform vector
W = np.full((m), 0.475, dtype=np.float32)
# Output variables: (time, membrane potential) pairs and pattern start times
P = []
pattern_t = []
# Input spikes at every time step
# (builtin bool: the np.bool alias was removed in NumPy 1.24)
spike_trains = np.zeros((steps, m), dtype=bool)

# First, create the pattern played by the first half of the synapses
n_syn_pattern = int(m / 2)
pattern = np.random.uniform(0, 1, size=(pattern_len, n_syn_pattern)) < f * dt

with tf.Session() as sess:

    neuron = STDPLIFNeuron(m, W)

    # Build the per-step update op before initialising the variables
    response = neuron.response()

    sess.run(tf.global_variables_initializer())

    syn_has_spiked = np.zeros((m), dtype=bool)

    pat_start_time = np.random.randint(25, 75)
    pattern_t.append(pat_start_time)
    for step in range(steps):

        t = int(step * dt)

        # Once the previous pattern and the slice following it are over, pick
        # the start of the next pattern: each successive slice (starting with
        # the current time) is selected with a probability of 1/4.
        # This is decided BEFORE generating the input of this step, so that a
        # pattern starting right now is played in full (including its first row).
        if t >= pat_start_time + 2 * pattern_len:
            pat_start_time = t
            while np.random.uniform(0, 1) >= 0.25:
                pat_start_time += pattern_len
            pattern_t.append(pat_start_time)

        # First population of synapses: replay the pattern or generate random spikes
        if pat_start_time <= t < pat_start_time + pattern_len:
            syn_has_spiked[:n_syn_pattern] = pattern[t - pat_start_time]
        else:
            r = np.random.uniform(0, 1, size=(n_syn_pattern))
            syn_has_spiked[:n_syn_pattern] = r < f * dt
        # Second population of synapses: always random spikes
        r = np.random.uniform(0, 1, size=(m - n_syn_pattern))
        syn_has_spiked[n_syn_pattern:] = r < f * dt
        spike_trains[step, :] = syn_has_spiked
        feed = {neuron.new_spikes: syn_has_spiked, neuron.dt: dt}
        p = sess.run(response, feed_dict=feed)
        P.append((t, p))

In [None]:
plt.rcParams["figure.figsize"] = (12, 9)
# Windows (in ms) at the beginning, middle and end of the simulation
intervals = ([0, 500], [7500, 7999], [14500, 14999])
all_pattern_t = np.array(pattern_t)
# Pattern duration and number of patterned synapses, taken from the pattern itself
pat_len, n_pat_syn = pattern.shape
for start, end in intervals:
    it_spike_trains = spike_trains[start:end, :]
    it_P = P[start:end]
    # Keep every pattern occurrence overlapping the window, including one
    # that started just before it
    it_pattern_t = all_pattern_t[np.logical_and(all_pattern_t + pat_len > start,
                                                all_pattern_t <= end)]
    # Draw input spikes, shading the patterned synapses while the pattern plays
    plt.figure()
    plt.axis([start, end, 0, m])
    plt.title('Synaptic spikes')
    plt.ylabel('spikes')
    plt.xlabel('Time (msec)')
    for pat_t in it_pattern_t:
        plt.fill_between([pat_t, pat_t + pat_len], 0, n_pat_syn, facecolor='gray')
    spike_steps, spike_synapses = np.nonzero(it_spike_trains)
    plt.scatter(spike_steps + start, spike_synapses, s=1)
    # Draw membrane potential with threshold (solid red) and resting potential (dashed yellow)
    plt.figure()
    plt.plot(*zip(*it_P))
    plt.axhline(y=neuron.T, color='r', linestyle='-')
    plt.axhline(y=neuron.p_rest, color='y', linestyle='--')
    plt.title('LIF response')
    plt.ylabel('Membrane Potential (mV)')
    plt.xlabel('Time (msec)')