In [2]:
# Reload edited project modules (e.g. lif, utils imported below) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [55]:
import spyx
import spyx.nn as snn

import jax
import jax.numpy as jnp

import haiku as hk

from lif import RecurrentLIF, compute_loss_gradient, eval3
from utils import shift_by_one_time_step


In [84]:
# --- Network size: one recurrent layer holding n_LIF plain neurons followed by n_ALIF adaptive ones ---
n_in = 3
n_LIF = 2
n_ALIF = 2
n_rec = n_LIF + n_ALIF

# --- Timing ---
dt = 1       # ms
tau_v = 20   # ms; passed to RecurrentLIF as `tau`
tau_a = 500  # ms; passed to RecurrentLIF as `tau_adaptation`
T = 15       # ms; with dt = 1 this is also the number of input time steps
f0 = 100     # Hz; rate used to draw the random input spikes

# --- Neuron hyperparameters forwarded to RecurrentLIF / compute_loss_gradient ---
thr = 0.62
n_ref = 3
dampening_factor = 0.3
# Per-neuron beta: 0 for the first n_LIF neurons, 0.07 for the n_ALIF neurons.
beta = 0.07 * jnp.concatenate([jnp.zeros(n_LIF), jnp.ones(n_ALIF)])
decay_out = jnp.exp(-dt / tau_v)

In [85]:
key = jax.random.PRNGKey(2)
# Random binary input spikes of shape (1, T, n_in): each entry is 1 with
# probability f0 * dt / 1000 (rate in Hz times step in ms).
spike_prob = f0 * dt / 1000
inputs = (jax.random.uniform(key, shape=(1, T, n_in)) < spike_prob).astype(float)
# NOTE(review): the same `key` is reused later for parameter init and eval3;
# split it if independent randomness is intended.
print(inputs.shape, inputs)

(1, 15, 3) [[[0. 0. 0.]
  [0. 0. 1.]
  [0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]
  [1. 0. 1.]
  [0. 0. 0.]
  [0. 0. 0.]
  [0. 1. 0.]
  [0. 0. 1.]]]


In [86]:
def lsnn2(x, state=None, batch_size=1):
    """Run one step of a single-layer recurrent LIF/ALIF network (for use with ``hk.transform``).

    All hyperparameters (n_in, n_rec, tau_v, thr, dt, ...) are read from notebook
    globals at call time.

    Args:
        x: input for this step, passed straight to the recurrent core.
        state: previous core state; when None, ``core.initial_state(batch_size)`` is used.
        batch_size: batch size for the initial state (only used when ``state`` is None).

    Returns:
        ``(spikes, hiddens)``: the output and next state returned by the ``hk.DeepRNN`` core.
    """
    lif_kwargs = dict(
        tau=tau_v,
        thr=thr,
        dt=dt,
        dtype=jnp.float32,
        dampening_factor=dampening_factor,
        tau_adaptation=tau_a,
        beta=beta,
        tag='',
        stop_gradients=True,
        w_in_init=None,
        w_rec_init=None,
        n_refractory=n_ref,
        rec=True,
    )
    recurrent_layer = RecurrentLIF(n_in, n_rec, **lif_kwargs)
    core = hk.DeepRNN([recurrent_layer])

    start_state = core.initial_state(batch_size) if state is None else state
    spikes, hiddens = core(x, start_state)
    return spikes, hiddens


lsnn2_hk = hk.without_apply_rng(hk.transform(lsnn2))

In [87]:
# Initialise parameters with a dummy batch: the first time step of `inputs`
# repeated `init_batch_size` times. The same constant is passed as `batch_size`
# so the stacked input and the initial-state batch cannot drift apart.
init_batch_size = 5
i0 = jnp.stack([inputs[:, 0]] * init_batch_size, axis=0)
print(i0.shape)
params = lsnn2_hk.init(rng=key, x=i0, batch_size=init_batch_size)
print(params)

(5, 1, 3)
i_in (5, 4)
{'RecurrentLIF': {'w_in': Array([[ 0.7967948 , -0.3821632 , -0.7605332 ,  0.45293623],
       [-0.03456055,  0.65856   ,  0.58331513, -0.10983399],
       [-0.4869853 ,  1.0580422 ,  0.53946483, -0.00187313]],      dtype=float32), 'w_rec': Array([[ 0.2713923 ,  0.86784893,  0.07354291, -0.43423995],
       [ 0.9018898 , -0.1874919 ,  0.18929863,  0.15613334],
       [ 0.2636667 ,  0.32536188,  0.5184128 ,  0.30088994],
       [ 0.22115956, -0.26662493,  0.49525732,  0.11215727]],      dtype=float32)}}


In [88]:
# Hard-coded snapshot of the w_in values printed by the init cell.
# NOTE(review): goes stale if the seed/init changes; not referenced by later cells.
w_in_copy = jnp.array(
    [
        [0.7967948, -0.3821632, -0.7605332, 0.45293623],
        [-0.03456055, 0.65856, 0.58331513, -0.10983399],
        [-0.4869853, 1.0580422, 0.53946483, -0.00187313],
    ]
)

In [89]:
# Simulate the network on `inputs` with the initial recurrent weights.
# Everything except `loss` feeds the e-prop gradient computation further down.
loss, y_out, y_target, w_out, spikes, V, variations = eval3(
    lsnn2,
    inputs,
    params,
    params['RecurrentLIF']['w_rec'],
    None,
    None,
    key,
    n_rec,
    dt,
    tau_v,
    T,
    1,
)


i_in (1, 4)
i_in (1, 4)
[[0. 0. 0.]] -> [[0. 0. 0. 0.]]
i_in (1, 4)
[[0. 0. 1.]] -> [[0. 1. 0. 0.]]
i_in (1, 4)
[[0. 0. 0.]] -> [[0. 0. 1. 0.]]
i_in (1, 4)
[[0. 0. 0.]] -> [[1. 0. 0. 0.]]
i_in (1, 4)
[[0. 0. 0.]] -> [[0. 1. 0. 0.]]
i_in (1, 4)
[[0. 0. 0.]] -> [[0. 0. 0. 0.]]
i_in (1, 4)
[[0. 0. 0.]] -> [[1. 0. 0. 0.]]
i_in (1, 4)
[[0. 0. 0.]] -> [[0. 1. 0. 0.]]
i_in (1, 4)
[[0. 0. 0.]] -> [[0. 0. 0. 0.]]
i_in (1, 4)
[[0. 0. 0.]] -> [[1. 0. 0. 0.]]
i_in (1, 4)
[[1. 0. 1.]] -> [[0. 1. 0. 0.]]
i_in (1, 4)
[[0. 0. 0.]] -> [[0. 0. 0. 0.]]
i_in (1, 4)
[[0. 0. 0.]] -> [[1. 0. 0. 0.]]
i_in (1, 4)
[[0. 1. 0.]] -> [[0. 1. 1. 0.]]
i_in (1, 4)
[[0. 0. 1.]] -> [[0. 0. 0. 0.]]
(1, 15, 4)
0.95122945
(1, 15, 4)
(15, 1, 4) (1, 4) 0.95122945
zf [[[0.         0.         0.         0.        ]
  [0.         0.04877055 0.         0.        ]
  [0.         0.04639198 0.04877055 0.        ]
  [0.04877055 0.04412942 0.04639198 0.        ]
  [0.04639198 0.09074774 0.04412942 0.        ]
  [0.04412942 0.0863219

In [90]:
# Learning signal per time step and recurrent neuron: the output error (y_out - y_target)
# projected through w_out. The einsum sums over the batch axis `b` and output axis `k`.
# NOTE(review): summing over `b` mixes samples; fine here because `inputs` has batch size 1.
learning_signals = jnp.einsum("btk,jk->tj", y_out - y_target, w_out)
print("ls", learning_signals)

# Recurrent spikes shifted by one time step (presumably along the time axis, per the helper's name).
presynaptic_spikes_prev_step = shift_by_one_time_step(spikes)

gradients_eprop, eligibility_traces, _, _ = compute_loss_gradient(
    learning_signals,
    presynaptic_spikes_prev_step,
    spikes,
    V,
    variations,
    dt, thr, tau_a, tau_v, beta, dampening_factor,
    decay_out,
    True,
)
# Display the e-prop gradient scaled so its largest-magnitude entry is +/-1.
gradients_eprop / jnp.max(jnp.abs(gradients_eprop))

ls [[-0.22999313  0.4801289  -0.08755955  0.864201  ]
 [-0.4017606   0.83870715 -0.1529523   1.5096186 ]
 [ 0.17614026 -0.36770678  0.06705749 -0.6618484 ]
 [ 0.01730856 -0.03613298  0.00658946 -0.06503703]
 [-0.3687168   0.76972556 -0.14037235  1.3854563 ]
 [ 0.6888039  -1.4379327   0.26223114 -2.588186  ]
 [-0.7305243   1.5250274  -0.2781143   2.7449508 ]
 [ 0.3685563  -0.76939046  0.14031124 -1.3848531 ]
 [ 0.01695852 -0.03540225  0.00645619 -0.06372177]
 [-0.95799553  1.9998916  -0.36471376  3.5996757 ]
 [ 0.06312727 -0.13178319  0.02403287 -0.23720123]
 [-0.4940403   1.0313483  -0.18808365  1.8563604 ]
 [-0.19081761  0.3983469  -0.07264523  0.7169987 ]
 [-0.06530839  0.13633645 -0.02486324  0.2453968 ]
 [-0.20962985  0.43761894 -0.07980715  0.7876858 ]]
[[[-1.         -1.         -1.         -1.        ]
  [-1.7854602   0.7065196  -0.12989543 -1.0030212 ]
  [-0.29249185 -0.37670827  0.1329895  -0.7510459 ]
  [ 0.09827143  0.11767072 -1.0351702  -0.27788118]
  [-0.9552919   1.46291

Array([[ 0.        ,  0.        , -0.33101112,  0.34045532],
       [-0.4847    ,  0.        , -0.4763276 ,  1.        ],
       [-0.23858026,  0.        ,  0.        ,  0.57953703],
       [ 0.        ,  0.        ,  0.        ,  0.        ]],      dtype=float32)

In [91]:
# e-prop gradient normalised so its largest-magnitude entry is +/-1.
ngep = gradients_eprop / jnp.abs(gradients_eprop).max()

In [92]:
def fun(*args):
    """Return only the loss (first element) of ``eval3(*args)`` so it can be differentiated."""
    return eval3(*args)[0]


def wrapper(w_rec):
    """Loss as a function of the recurrent weights alone.

    Every other ``eval3`` argument is taken from the notebook globals used in the
    e-prop run (lsnn2, inputs, params, key, n_rec, dt, tau_v, T).
    """
    return fun(lsnn2, inputs, params, w_rec, None, None, key, n_rec, dt, tau_v, T, 1)
   

In [93]:

# Loss value and autodiff gradient w.r.t. the recurrent weights, for comparison with
# the e-prop gradient; `wrapper` pins every other eval3 argument.
surrogate_grad3 = jax.value_and_grad(wrapper)(params['RecurrentLIF']['w_rec'].copy())

i_in (1, 4)
i_in (1, 4)
[[0. 0. 0.]] -> Traced<ConcreteArray([[0. 0. 0. 0.]], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([[0., 0., 0., 0.]], dtype=float32)
  tangent = Traced<ShapedArray(float32[1,4])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[1,4]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x000001E09C49FF20>, in_tracers=(Traced<ShapedArray(float32[1,4]):JaxprTrace(level=1/0)>, Traced<ConcreteArray([[False False False False]], dtype=bool):JaxprTrace(level=1/0)>, Traced<ConcreteArray([[0. 0. 0. 0.]], dtype=float32):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x000001E09FA82930; to 'JaxprTracer' at 0x000001E09FA82990>], out_avals=[ShapedArray(float32[1,4])], primitive=pjit, params={'jaxpr': { [34m[22m[1mlambda [39m[22m[22m; a[35m:f32[1,4][39m b[35m:bool[1,4][39m c[35m:f32[1,4][39m. [34m[22m[1mlet
    [39m[22m[22md[35m:f32[1,4][39m = select_n b a c
  [34m[22m[1min [39m[22m[22m(d,) }, 'in_shard

In [94]:
# Normalised e-prop gradient (each entry divided by the largest |entry|).
ngep

Array([[ 0.        ,  0.        , -0.33101112,  0.34045532],
       [-0.4847    ,  0.        , -0.4763276 ,  1.        ],
       [-0.23858026,  0.        ,  0.        ,  0.57953703],
       [ 0.        ,  0.        ,  0.        ,  0.        ]],      dtype=float32)

In [95]:
# Autodiff gradient (second element of value_and_grad's result), normalised like ngep.
surrogate_grad3[1] / jnp.abs(surrogate_grad3[1]).max()

Array([[ 0.        ,  0.        , -0.33101112,  0.34045535],
       [-0.48470008,  0.        , -0.47632757,  1.        ],
       [-0.23858027,  0.        ,  0.        ,  0.5795371 ],
       [ 0.        ,  0.        ,  0.        ,  0.        ]],      dtype=float32)

In [96]:
# Keep the normalised autodiff gradient for the comparison with ngep.
nga = surrogate_grad3[1] / jnp.abs(surrogate_grad3[1]).max()

In [97]:
# Check that e-prop and autodiff gradients agree after each is normalised by its own max |entry|.
# NOTE(review): normalisation means this verifies the relative pattern only, not the absolute scale.
grads_agree = jnp.allclose(ngep, nga, atol=1e-3)
assert grads_agree, (
    f"e-prop and autodiff gradients disagree (max abs diff {jnp.max(jnp.abs(ngep - nga))})"
)
grads_agree

Array(True, dtype=bool)