In [242]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [243]:
import spyx
import spyx.nn as snn

import jax
import jax.numpy as jnp

import haiku as hk

from lif import RecurrentLIF, compute_loss_gradient, eval3
from utils import exp_convolve, shift_by_one_time_step


In [322]:
n_in = 3
n_LIF = 4#10
n_ALIF = 4#0 #4
n_rec = n_ALIF + n_LIF

dt = 1  # ms
tau_v = 20  # ms
tau_a = 500  # ms
T = 10  # ms
f0 = 100  # Hz

thr = 0.62#0.62
beta = 0.07 * jnp.concatenate([jnp.zeros(n_LIF), jnp.ones(n_ALIF)])
dampening_factor = 0.3
n_ref = 3
decay_out = jnp.exp(-dt / tau_v)

In [323]:
key = jax.random.PRNGKey(3)
# key = jax.random.PRNGKey(4)

# 10 neurons and 4 key

# inputs = jnp.array([[[1.0], [1.0], [1.0], [1.0], [1.0]]])
# inputs = jnp.ones((1, T, n_in))
inputs = (jax.random.uniform(key, shape=(1, T, n_in)) < f0 * dt / 1000).astype(float)
print(inputs.shape, inputs)

(1, 10, 3) [[[0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 1.]
  [1. 0. 0.]
  [0. 0. 0.]
  [1. 0. 0.]
  [0. 1. 0.]
  [0. 1. 0.]]]


In [324]:
def lsnn2(x, state=None, batch_size=1):
    core = hk.DeepRNN([
        # hk.Linear(n_rec), # otherwise not need Win
        RecurrentLIF(
            n_in,
            n_rec,
            tau=tau_v,
            thr=thr,
            dt=dt,
            dtype=jnp.float32,
            dampening_factor=dampening_factor,
            tau_adaptation=tau_a,
            beta=beta,
            tag='',
            stop_gradients=True,
            w_in_init=None,
            w_rec_init=None,
            n_refractory=n_ref,
            rec=True,
        )
    ])
    # spikes, V = hk.dynamic_unroll(core, x, core.initial_state(x.shape[0]), time_major=False, unroll=T)
    # return spikes, V
    # print(x.shape[1])
    if state is None:
        state = core.initial_state(batch_size)
    spikes, hiddens = core(x, state)#core.initial_state(1))
    return spikes, hiddens

lsnn2_hk = hk.without_apply_rng(hk.transform(lsnn2))

In [325]:
i0 = jnp.stack([inputs[:,0], inputs[:,0], inputs[:,0],inputs[:,0], inputs[:,0]], axis=0)
print(i0.shape)
params = lsnn2_hk.init(rng=key, x=i0, batch_size=5)
print(params)

(5, 1, 3)
{'RecurrentLIF': {'w_in': Array([[ 0.08481716,  0.32309398,  0.27088046,  0.25027543, -0.27389896,
         0.08703034,  0.41582993,  0.1686924 ],
       [-0.69736767, -0.27332562, -0.13606313, -0.5863417 ,  0.39635217,
        -0.07739068,  0.2911337 , -0.39072135],
       [-0.6545443 , -0.01372574,  0.14147703,  0.08631024, -1.1109992 ,
         0.8299972 ,  0.16202024,  0.08654054]], dtype=float32), 'w_rec': Array([[ 0.09352326,  0.31284395,  0.38958675, -0.39840105, -0.1525931 ,
        -0.07649332, -0.4262588 ,  0.09456014],
       [-0.43706197, -0.5894613 , -0.43154418,  0.17181401,  0.6040146 ,
         0.5193728 , -0.0180139 , -0.28620195],
       [-0.30662236, -0.1718643 ,  0.23925252, -0.37198904, -0.37072596,
         0.50832236, -0.17347042,  0.24324189],
       [ 0.53575355,  0.10347075, -0.19402047, -0.54597175,  0.21587999,
         0.3189115 , -0.2996914 , -0.0842847 ],
       [ 0.01672795, -0.23839958,  0.04137432, -0.3883871 ,  0.37954986,
         0.1609568

In [326]:
loss, y_out, y_target, w_out, spikes, V, variations = eval3(
    lsnn2, inputs, params, params['RecurrentLIF']['w_rec'], None, None, key,
    n_rec, dt, tau_v, T, 1
)


[[0. 0. 0.]] -> [[0. 0. 0. 0. 0. 0. 0. 0.]]
[[0. 0. 0.]] -> [[0. 0. 0. 0. 0. 0. 0. 0.]]
[[0. 0. 0.]] -> [[0. 0. 0. 0. 0. 0. 0. 0.]]
[[0. 0. 0.]] -> [[0. 0. 0. 0. 0. 0. 0. 0.]]
[[0. 0. 1.]] -> [[0. 0. 0. 0. 0. 1. 0. 0.]]
[[1. 0. 0.]] -> [[0. 0. 1. 0. 0. 0. 1. 0.]]
[[0. 0. 0.]] -> [[0. 0. 0. 0. 0. 0. 0. 0.]]
[[1. 0. 0.]] -> [[0. 0. 0. 0. 0. 0. 0. 1.]]
[[0. 1. 0.]] -> [[0. 0. 0. 0. 0. 0. 0. 0.]]
[[0. 1. 0.]] -> [[0. 0. 0. 0. 0. 0. 0. 0.]]
(1, 10, 8)
0.95122945
(1, 10, 8)
(10, 1, 8) (1, 8) 0.95122945
zf [[[0.         0.         0.         0.         0.         0.
   0.         0.        ]
  [0.         0.         0.         0.         0.         0.
   0.         0.        ]
  [0.         0.         0.         0.         0.         0.
   0.         0.        ]
  [0.         0.         0.         0.         0.         0.
   0.         0.        ]
  [0.         0.         0.         0.         0.         0.04877055
   0.         0.        ]
  [0.         0.         0.04877055 0.         0.   

In [327]:
params

{'RecurrentLIF': {'w_in': Array([[ 0.08481716,  0.32309398,  0.27088046,  0.25027543, -0.27389896,
           0.08703034,  0.41582993,  0.1686924 ],
         [-0.69736767, -0.27332562, -0.13606313, -0.5863417 ,  0.39635217,
          -0.07739068,  0.2911337 , -0.39072135],
         [-0.6545443 , -0.01372574,  0.14147703,  0.08631024, -1.1109992 ,
           0.8299972 ,  0.16202024,  0.08654054]], dtype=float32),
  'w_rec': Array([[ 0.09352326,  0.31284395,  0.38958675, -0.39840105, -0.1525931 ,
          -0.07649332, -0.4262588 ,  0.09456014],
         [-0.43706197, -0.5894613 , -0.43154418,  0.17181401,  0.6040146 ,
           0.5193728 , -0.0180139 , -0.28620195],
         [-0.30662236, -0.1718643 ,  0.23925252, -0.37198904, -0.37072596,
           0.50832236, -0.17347042,  0.24324189],
         [ 0.53575355,  0.10347075, -0.19402047, -0.54597175,  0.21587999,
           0.3189115 , -0.2996914 , -0.0842847 ],
         [ 0.01672795, -0.23839958,  0.04137432, -0.3883871 ,  0.37954986,


In [328]:
jnp.savez('lsnn2.npz', loss=loss, y_out=y_out, y_target=y_target, w_out=w_out, spikes=spikes, V=V, variations=variations)
jnp.savez('params.npz', w_out=w_out, w_in=params['RecurrentLIF']['w_in'], w_rec=params['RecurrentLIF']['w_rec'])
jnp.savez('inputs.npz', inputs=inputs)

In [329]:
exp_convolve(spikes, jnp.exp(-1/20))

(10, 1, 8) (1, 8) 0.95122945


Array([[[0.        , 0.        , 0.        , 0.        , 0.        ,
         0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        , 0.        , 0.        ,
         0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        , 0.        , 0.        ,
         0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        , 0.        , 0.        ,
         0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        , 0.        , 0.        ,
         0.04877055, 0.        , 0.        ],
        [0.        , 0.        , 0.04877055, 0.        , 0.        ,
         0.04639198, 0.04877055, 0.        ],
        [0.        , 0.        , 0.04639198, 0.        , 0.        ,
         0.04412942, 0.04639198, 0.        ],
        [0.        , 0.        , 0.04412942, 0.        , 0.        ,
         0.0419772 , 0.04412942, 0.04877055],
        [0.        , 0.        , 0.0419772 , 0.        , 0.        ,
         0.

In [330]:
print(w_out)

[[-0.475056  ]
 [ 1.0589067 ]
 [ 0.9018587 ]
 [-0.03301524]
 [-1.3541821 ]
 [-1.6734178 ]
 [ 1.0711491 ]
 [-0.3165598 ]]


In [331]:
print(y_out)

[[[ 0.        ]
  [ 0.        ]
  [ 0.        ]
  [ 0.        ]
  [-0.0816135 ]
  [ 0.0185915 ]
  [ 0.01768479]
  [ 0.0013835 ]
  [ 0.00131602]
  [ 0.00125184]]]


In [332]:
print(spikes)

[[[0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 1. 0. 0.]
  [0. 0. 1. 0. 0. 0. 1. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 1.]
  [0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0.]]]


In [333]:
learning_signals = jnp.einsum("btk,jk->tj", y_out - y_target, w_out)
print("ls", learning_signals)

pre_synpatic_spike_one_step_before = shift_by_one_time_step(spikes)

# https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html


gradients_eprop, eligibility_traces, _, _ = \
    compute_loss_gradient(learning_signals, pre_synpatic_spike_one_step_before, spikes, V,
                               variations, 
                               dt, thr, tau_a, tau_v, beta, dampening_factor,
                               decay_out, True)
gradients_eprop / jnp.max(jnp.abs(gradients_eprop))

ls [[ 0.08836254 -0.19696137 -0.16774975  0.00614098  0.25188392  0.3112632
  -0.19923852  0.05888154]
 [-0.08375043  0.18668091  0.15899399 -0.00582045 -0.23873676 -0.2950167
   0.1888392  -0.05580819]
 [ 0.20887722 -0.46559036 -0.39653796  0.01451646  0.59541994  0.7357846
  -0.47097322  0.13918808]
 [-0.6203654   1.3828033   1.1777178  -0.04311388 -1.7683972  -2.18528
   1.3987905  -0.41338858]
 [ 0.8468739  -1.887694   -1.6077274   0.05885568  2.4140763   2.9831722
  -1.9095185   0.5643255 ]
 [-0.8978503   2.0013213   1.7045023  -0.06239841 -2.5593884  -3.16274
   2.0244594  -0.5982943 ]
 [-0.10287772  0.22931592  0.19530575 -0.00714975 -0.2932605  -0.36239392
   0.23196714 -0.06855392]
 [ 0.60050607 -1.3385367  -1.1400164   0.04173371  1.7117869   2.1153245
  -1.3540121   0.4001551 ]
 [-0.49737713  1.1086608   0.9442337  -0.0345665  -1.4178102  -1.7520456
   1.1214784  -0.33143377]
 [-0.66783845  1.4886216   1.2678419  -0.04641315 -1.9037228  -2.3525076
   1.5058321  -0.4450229 ]]

Array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.26572645,  0.        ,  0.        ,  0.        ,
        -0.6136005 ,  0.47267583, -0.11118145],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.42904863,  0.75198   ,  0.        ,  0.        ,
         0.        ,  1.        , -0.24369046],
       [ 0.        ,  0.26572645,  0.32906023,  0.        ,  0.        ,
        -0.6136005 ,  0.        , -0.11118145],
       [ 0.        ,  0.15769525,  0.3636678 ,  0.        ,  0.        ,
        -0.4786077 ,  0.5223875 ,  0.        ]], dtype=float32)

In [334]:
ngep = gradients_eprop / jnp.max(jnp.abs(gradients_eprop))

In [335]:
fun = lambda *x: eval3(*x)[0]
def wrapper(w_rec):
    # w_rec = params['RecurrentLIF']['w_rec']
    # w_rec[0, 0] = w0
    return fun(lsnn2, inputs, params, w_rec, None, None, key, n_rec, dt, tau_v, T, 1)
   

In [336]:
jax.test_util.check_grads(
    wrapper,
    (params['RecurrentLIF']['w_rec'],),
    order=1,
    modes=['rev']
)

[[0. 0. 0.]] -> Traced<ConcreteArray([[0. 0. 0. 0. 0. 0. 0. 0.]], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([[0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)
  tangent = Traced<ShapedArray(float32[1,8])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[1,8]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x0000027F6AD3E490>, in_tracers=(Traced<ShapedArray(float32[1,8]):JaxprTrace(level=1/0)>, Traced<ConcreteArray([[False False False False False False False False]], dtype=bool):JaxprTrace(level=1/0)>, Traced<ConcreteArray([[0. 0. 0. 0. 0. 0. 0. 0.]], dtype=float32):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x0000027F6D0866B0; to 'JaxprTracer' at 0x0000027F6D086670>], out_avals=[ShapedArray(float32[1,8])], primitive=pjit, params={'jaxpr': { [34m[22m[1mlambda [39m[22m[22m; a[35m:f32[1,8][39m b[35m:bool[1,8][39m c[35m:f32[1,8][39m. [34m[22m[1mlet
    [39m[22m[22md[35m:f32[1,8][39m = select_n b a c
  [34m[22

AssertionError: 
Not equal to tolerance rtol=0.002, atol=0.002
VJP cotangent projection
Mismatched elements: 1 / 1 (100%)
Max absolute difference: 0.075403
Max relative difference: inf
 x: array(-0.075403, dtype=float32)
 y: array(0., dtype=float32)

In [337]:

surrogate_grad3 = jax.value_and_grad(fun, argnums=3)(lsnn2, inputs, params, params['RecurrentLIF']['w_rec'].copy(), None, None, key,
    n_rec, dt, tau_v, T, 1)

[[0. 0. 0.]] -> Traced<ConcreteArray([[0. 0. 0. 0. 0. 0. 0. 0.]], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([[0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)
  tangent = Traced<ShapedArray(float32[1,8])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[1,8]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x0000027F6621C170>, in_tracers=(Traced<ShapedArray(float32[1,8]):JaxprTrace(level=1/0)>, Traced<ConcreteArray([[False False False False False False False False]], dtype=bool):JaxprTrace(level=1/0)>, Traced<ConcreteArray([[0. 0. 0. 0. 0. 0. 0. 0.]], dtype=float32):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x0000027F66472700; to 'JaxprTracer' at 0x0000027F66472A30>], out_avals=[ShapedArray(float32[1,8])], primitive=pjit, params={'jaxpr': { [34m[22m[1mlambda [39m[22m[22m; a[35m:f32[1,8][39m b[35m:bool[1,8][39m c[35m:f32[1,8][39m. [34m[22m[1mlet
    [39m[22m[22md[35m:f32[1,8][39m = select_n b a c
  [34m[22

In [338]:
surrogate_grad3[1] / jnp.max(jnp.abs(surrogate_grad3[1]))

Array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        -0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        -0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.26572642,  0.        ,  0.        ,  0.        ,
        -0.61360043,  0.47267583, -0.11118145],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        -0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        -0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.42904863,  0.7519799 ,  0.        ,  0.        ,
         0.        ,  1.        , -0.2436904 ],
       [ 0.        ,  0.26572642,  0.32906023,  0.        ,  0.        ,
        -0.61360043,  0.        , -0.11118145],
       [ 0.        ,  0.15769523,  0.36366776,  0.        ,  0.        ,
        -0.4786077 ,  0.5223875 ,  0.        ]], dtype=float32)

In [339]:
nga = surrogate_grad3[1] / jnp.max(jnp.abs(surrogate_grad3[1]))

In [340]:
jnp.allclose(ngep, nga, atol=1e-3)

Array(True, dtype=bool)

In [317]:
params

{'RecurrentLIF': {'w_in': Array([[-0.6301748 ,  0.48754728,  0.60680926,  0.8751213 ],
         [-0.40730554,  0.58812934, -0.8242644 ,  0.71092707],
         [-0.5873246 ,  0.8977012 ,  0.9403405 , -0.55874705]],      dtype=float32),
  'w_rec': Array([[ 0.7171732 , -0.42568645, -0.27384278,  0.33492026],
         [-0.06512911, -0.18751548, -0.68637997,  0.21879737],
         [ 0.3078918 , -0.17575324,  0.08511223, -0.16051175],
         [ 0.64690757, -0.17119652, -0.06764613, -0.5058548 ]],      dtype=float32)}}

In [274]:
from lif import SpikeFunction


vs = jnp.ones((2, 1))
vs = vs.at[0].set(-0.5)
print(vs)

fun2 = lambda *x: SpikeFunction(*x).sum()

jax.value_and_grad(fun2, argnums=0)(vs, dampening_factor) # 0.3 dampening factor

[[-0.5]
 [ 1. ]]


(Array(1., dtype=float32),
 Array([[0.15],
        [0.  ]], dtype=float32))