# STDP Basics, 'offline'

Demonstrates how to calculate the weight change dw from the STDP learning rule. The rule is applied to spike processes and to firing records (as can be obtained from the `model.step_time` end-of-timestep callback).

In [1]:
import numpy as np
from spikeflow import firing_to_spike_process, firings_to_spike_processes
from spikeflow import spike_process_delta_times
from spikeflow import STDPParams, stdp_offline_dw, stdp_offline_dw_process, stdp_offline_dw_processes, stdp_offline_dw_firings

  from ._conv import register_converters as _register_converters


# Make some data: create a couple of firing records and spike processes

In [2]:
# Two boolean firing records over the same 8 timesteps: entry t is True when
# the neuron fired at timestep t.
f1 = np.array([True, False, False, True, True, False, False, True])
f2 = np.array([False, False, True, False, False, True, True, False])

# A spike process lists the timesteps at which the firing record is True.
s1 = firing_to_spike_process(f1)
s2 = firing_to_spike_process(f2)

for label, arr in (
    ('firing 1:', f1),
    ('firing 2:', f2),
    ('spike process 1:', s1),
    ('spike process 2:', s2),
):
    print(label, arr.shape, arr)

firing 1: (8,) [ True False False  True  True False False  True]
firing 2: (8,) [False False  True False False  True  True False]
spike process 1: (4,) [0 3 4 7]
spike process 2: (3,) [2 5 6]


# Apply STDP learning rule to calculate change in weight in various ways

In [3]:
# Symmetric STDP parameters: equal amplitudes and equal time constants for
# positive and negative spike-time differences. With these values, a pair of
# spikes dt timesteps apart contributes +/- exp(-|dt| / 10) (see the dw
# values printed below).
stdp_params = STDPParams(
    APlus=1.0,
    AMinus=1.0,
    TauPlus=10.0,
    TauMinus=10.0,
)

## ... calculate from time deltas, to check the math

In [4]:
# One delta per (s1 spike, s2 spike) pair, ordered by s1 spike and then by
# s2 spike; each value is the s2 spike time minus the s1 spike time.
delta_times = spike_process_delta_times(s1, s2)
# Per-pair weight change given by the STDP rule for each of those deltas.
dw = stdp_offline_dw(delta_times, stdp_params)

for title, values in (('Delta times shape', delta_times), ('dw  shape', dw)):
    print(title, values.shape)
    print(values)

Delta times shape (12,)
[ 2  5  6 -1  2  3 -2  1  2 -5 -2 -1]
dw  shape (12,)
[ 0.81873075  0.60653066  0.54881164 -0.90483742  0.81873075  0.74081822
 -0.81873075  0.90483742  0.81873075 -0.60653066 -0.81873075 -0.90483742]


## ... directly from spike processes

In [5]:
# The per-pair dw values from the time-delta cell should add up to the total
# weight change computed directly from the two spike processes.
dw_p = stdp_offline_dw_process(s1, s2, stdp_params)
print('SUM dw from time deltas:', sum(dw))
print('SUM dw from spike processes:', dw_p)

# Make the comparison executable rather than by eye; allow for rounding
# differences from the order of summation.
assert np.isclose(sum(dw), dw_p), 'dw from time deltas and from spike processes disagree'

SUM dw from time deltas: 1.2035231918177673
SUM dw from spike processes: 1.2035231918177673


## ... or as a 1x1 weight matrix

In [6]:
# One input process (s1) connected to one output process (s2). The single
# entry of the resulting matrix should equal the total dw_p computed from the
# spike processes. Use a cell-specific name so the larger dW matrix below
# does not silently replace this one.
dW_single = stdp_offline_dw_processes(np.ones((1, 1), dtype=np.float32), [s1], [s2], stdp_params)
print('dW  shape', dW_single.shape)
print(dW_single)

assert dW_single.shape == (1, 1)
assert np.isclose(dW_single[0, 0], dw_p), '1x1 dW does not match dw from spike processes'

dW  shape (1, 1)
[[1.2035232]]


## ... or as a bigger, sparse weight matrix

Zero entries in `W` are absent connections and receive no weight change.

In [7]:
# Rows of W index the input processes ([s1, s1]); columns index the output
# processes ([s2, s2, s1]). Zero entries are absent connections.
W = np.array([[1, 0, 0], [0, 1, 1]], dtype=np.float32)
dW = stdp_offline_dw_processes(W, [s1, s1], [s2, s2, s1], stdp_params)
print('W  shape', W.shape)
print(W)
print('dW  shape', dW.shape)
print(dW)

# Unconnected entries receive no weight change.
assert np.allclose(dW[W == 0], 0)
# Connected s1 -> s2 entries reproduce the single-pair result.
assert np.allclose([dW[0, 0], dW[1, 1]], dw_p)
# W[1, 2] pairs s1 with itself. With the symmetric parameters above, every
# +dt pair is matched by a -dt pair, so the contributions cancel and the
# total is (numerically) zero even though the connection exists.
assert np.isclose(dW[1, 2], 0, atol=1e-6)

W  shape (2, 3)
[[1. 0. 0.]
 [0. 1. 1.]]
dW  shape (2, 3)
[[1.2035232 0.        0.       ]
 [0.        1.2035232 0.       ]]


## ... same but from firing matrices

(as can be extracted from the end-of-timestep callback of `model.step_time`)

In [8]:
# Firing matrices are laid out (timesteps, neurons): each column is one
# neuron's firing record. These are the same inputs ([f1, f1]) and outputs
# ([f2, f2, f1]) as the spike-process example above.
f_in = np.vstack([f1, f1]).T
f_out = np.vstack([f2, f2, f1]).T
print('f_in  shape', f_in.shape)
print('f_out  shape', f_out.shape)

dWf = stdp_offline_dw_firings(W, f_in, f_out, stdp_params)
print('dWf  shape', dWf.shape)
print(dWf)

# Same connectivity and spikes as the spike-process version, so the two weight
# changes must agree. atol covers entries that are zero up to rounding.
np.testing.assert_allclose(dWf, dW, rtol=1e-5, atol=1e-6)

f_in  shape (8, 2)
f_out  shape (8, 3)
dWf  shape (2, 3)
[[1.2035232 0.        0.       ]
 [0.        1.2035232 0.       ]]
