In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import spyx
import spyx.nn as snn

import jax
import jax.numpy as jnp

import haiku as hk

from lif import RecurrentLIF, compute_loss_gradient, eval3
from utils import exp_convolve, shift_by_one_time_step


In [3]:
n_in = 3
n_LIF = 4#10
n_ALIF = 0#0 #4
n_rec = n_ALIF + n_LIF

dt = 1  # ms
tau_v = 20  # ms
tau_a = 500  # ms
T = 10  # ms
f0 = 100  # Hz

thr = 0.62#0.62
beta = 0.07 * jnp.concatenate([jnp.zeros(n_LIF), jnp.ones(n_ALIF)])
dampening_factor = 0.3
n_ref = 3
decay_out = jnp.exp(-dt / tau_v)

In [4]:
key = jax.random.PRNGKey(2)
# key = jax.random.PRNGKey(4)

# 10 neurons and 4 key

# inputs = jnp.array([[[1.0], [1.0], [1.0], [1.0], [1.0]]])
# inputs = jnp.ones((1, T, n_in))
inputs = (jax.random.uniform(key, shape=(1, T, n_in)) < f0 * dt / 1000).astype(float)
print(inputs.shape, inputs)

(1, 10, 3) [[[0. 0. 1.]
  [0. 0. 0.]
  [0. 0. 0.]
  [0. 1. 1.]
  [0. 0. 0.]
  [0. 0. 0.]
  [1. 0. 0.]
  [0. 0. 0.]
  [0. 0. 1.]
  [0. 0. 0.]]]


In [5]:
def lsnn2(x, state=None, batch_size=1):
    core = hk.DeepRNN([
        # hk.Linear(n_rec), # otherwise not need Win
        RecurrentLIF(
            n_in,
            n_rec,
            tau=tau_v,
            thr=thr,
            dt=dt,
            dtype=jnp.float32,
            dampening_factor=dampening_factor,
            tau_adaptation=tau_a,
            beta=beta,
            tag='',
            stop_gradients=True,
            w_in_init=None,
            w_rec_init=None,
            n_refractory=n_ref,
            rec=True,
        )
    ])
    # spikes, V = hk.dynamic_unroll(core, x, core.initial_state(x.shape[0]), time_major=False, unroll=T)
    # return spikes, V
    # print(x.shape[1])
    if state is None:
        state = core.initial_state(batch_size)
    spikes, hiddens = core(x, state)#core.initial_state(1))
    return spikes, hiddens

lsnn2_hk = hk.without_apply_rng(hk.transform(lsnn2))

In [6]:
i0 = jnp.stack([inputs[:,0], inputs[:,0], inputs[:,0],inputs[:,0], inputs[:,0]], axis=0)
print(i0.shape)
params = lsnn2_hk.init(rng=key, x=i0, batch_size=5)
print(params)

(5, 1, 3)
{'RecurrentLIF': {'w_in': Array([[ 0.7967948 , -0.3821632 , -0.7605332 ,  0.45293623],
       [-0.03456055,  0.65856   ,  0.58331513, -0.10983399],
       [-0.4869853 ,  1.0580422 ,  0.53946483, -0.00187313]],      dtype=float32), 'w_rec': Array([[ 0.2713923 ,  0.86784893,  0.07354291, -0.43423995],
       [ 0.9018898 , -0.1874919 ,  0.18929863,  0.15613334],
       [ 0.2636667 ,  0.32536188,  0.5184128 ,  0.30088994],
       [ 0.22115956, -0.26662493,  0.49525732,  0.11215727]],      dtype=float32)}}


In [7]:
loss, y_out, y_target, w_out, spikes, V, variations = eval3(
    lsnn2, inputs, params, params['RecurrentLIF']['w_rec'], None, None, key,
    n_rec, dt, tau_v, T, 1
)


[[0. 0. 1.]] -> [[0. 1. 0. 0.]]
[[0. 0. 0.]] -> [[0. 0. 1. 0.]]
[[0. 0. 0.]] -> [[1. 0. 0. 0.]]
[[0. 1. 1.]] -> [[0. 1. 0. 0.]]
[[0. 0. 0.]] -> [[0. 0. 1. 0.]]
[[0. 0. 0.]] -> [[1. 0. 0. 0.]]
[[1. 0. 0.]] -> [[0. 1. 0. 0.]]
[[0. 0. 0.]] -> [[0. 0. 0. 0.]]
[[0. 0. 1.]] -> [[1. 0. 1. 0.]]
[[0. 0. 0.]] -> [[0. 1. 0. 0.]]
(1, 10, 4)
0.95122945
(1, 10, 4)
(10, 1, 4) (1, 4) 0.95122945
zf [[[0.         0.04877055 0.         0.        ]
  [0.         0.04639198 0.04877055 0.        ]
  [0.04877055 0.04412942 0.04639198 0.        ]
  [0.04639198 0.09074774 0.04412942 0.        ]
  [0.04412942 0.08632193 0.09074774 0.        ]
  [0.09074774 0.08211196 0.08632193 0.        ]
  [0.08632193 0.12687787 0.08211196 0.        ]
  [0.08211196 0.12068997 0.07810732 0.        ]
  [0.12687787 0.11480386 0.12306853 0.        ]
  [0.12068997 0.15797536 0.11706641 0.        ]]]
(10, 1) (10, 1)


In [8]:
params

{'RecurrentLIF': {'w_in': Array([[ 0.7967948 , -0.3821632 , -0.7605332 ,  0.45293623],
         [-0.03456055,  0.65856   ,  0.58331513, -0.10983399],
         [-0.4869853 ,  1.0580422 ,  0.53946483, -0.00187313]],      dtype=float32),
  'w_rec': Array([[ 0.2713923 ,  0.86784893,  0.07354291, -0.43423995],
         [ 0.9018898 , -0.1874919 ,  0.18929863,  0.15613334],
         [ 0.2636667 ,  0.32536188,  0.5184128 ,  0.30088994],
         [ 0.22115956, -0.26662493,  0.49525732,  0.11215727]],      dtype=float32)}}

In [9]:
jnp.savez('lsnn2.npz', loss=loss, y_out=y_out, y_target=y_target, w_out=w_out, spikes=spikes, V=V, variations=variations)
jnp.savez('params.npz', w_out=w_out, w_in=params['RecurrentLIF']['w_in'], w_rec=params['RecurrentLIF']['w_rec'])
jnp.savez('inputs.npz', inputs=inputs)

In [10]:
exp_convolve(spikes, jnp.exp(-1/20))

(10, 1, 4) (1, 4) 0.95122945


Array([[[0.        , 0.04877055, 0.        , 0.        ],
        [0.        , 0.04639198, 0.04877055, 0.        ],
        [0.04877055, 0.04412942, 0.04639198, 0.        ],
        [0.04639198, 0.09074774, 0.04412942, 0.        ],
        [0.04412942, 0.08632193, 0.09074774, 0.        ],
        [0.09074774, 0.08211196, 0.08632193, 0.        ],
        [0.08632193, 0.12687787, 0.08211196, 0.        ],
        [0.08211196, 0.12068997, 0.07810732, 0.        ],
        [0.12687787, 0.11480386, 0.12306853, 0.        ],
        [0.12068997, 0.15797536, 0.11706641, 0.        ]]], dtype=float32)

In [11]:
print(w_out)

[[ 0.42851502]
 [-0.8945591 ]
 [ 0.16313784]
 [-1.6101485 ]]


In [12]:
print(y_out)

[[[-0.04362814]
  [-0.03354404]
  [-0.01100917]
  [-0.05410038]
  [-0.04350556]
  [-0.02048486]
  [-0.06311394]
  [-0.06003585]
  [-0.02825263]
  [-0.07050286]]]


In [13]:
print(spikes)

[[[0. 1. 0. 0.]
  [0. 0. 1. 0.]
  [1. 0. 0. 0.]
  [0. 1. 0. 0.]
  [0. 0. 1. 0.]
  [1. 0. 0. 0.]
  [0. 1. 0. 0.]
  [0. 0. 0. 0.]
  [1. 0. 1. 0.]
  [0. 1. 0. 0.]]]


In [14]:
learning_signals = jnp.einsum("btk,jk->tj", y_out - y_target, w_out)
print("ls", learning_signals)

pre_synpatic_spike_one_step_before = shift_by_one_time_step(spikes)

# https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html


gradients_eprop, eligibility_traces, _, _ = \
    compute_loss_gradient(learning_signals, pre_synpatic_spike_one_step_before, spikes, V,
                               variations, 
                               dt, thr, tau_a, tau_v, beta, dampening_factor,
                               decay_out, True)
gradients_eprop / jnp.max(jnp.abs(gradients_eprop))

ls [[ 0.70276517 -1.4670781   0.26754627 -2.6406457 ]
 [-0.279064    0.5825682  -0.10624108  1.0485852 ]
 [-0.04610965  0.09625755 -0.01755418  0.17325738]
 [ 0.6447455  -1.3459575   0.2454579  -2.4226365 ]
 [ 0.3086181  -0.64426476  0.11749247 -1.159635  ]
 [ 0.32380497 -0.6759685   0.12327419 -1.2166997 ]
 [ 0.01503155 -0.03137955  0.00572259 -0.05648116]
 [-0.27900738  0.58245003 -0.10621952  1.0483724 ]
 [-0.4487736   0.9368505  -0.17085038  1.6862704 ]
 [ 0.04163218 -0.08691046  0.01584958 -0.15643322]]
[[[-1.7854602   0.7065196  -0.12989543 -1.0030212 ]
  [-0.29249185 -0.37670827  0.1329895  -0.7510459 ]
  [ 0.09827143  0.11767072 -0.92226696 -0.27788118]
  [-1.7964947   4.231631    1.003495   -1.1936592 ]
  [-0.30298838  2.9764812   1.2111037  -0.9323862 ]
  [ 0.08828679  3.3073232   0.10326697 -0.4503775 ]
  [ 0.32036343  3.880617   -1.0585889  -0.44702762]
  [ 1.7106297   2.6425867  -0.7504111  -0.22216843]
  [ 0.7929706   4.1714554   0.10752092 -0.26312488]
  [ 0.13079529  5.

Array([[ 0.        ,  0.        , -0.03873479,  0.90009314],
       [ 0.01200352,  0.        , -0.04557128,  1.        ],
       [-0.00351272,  0.        ,  0.        ,  0.73013407],
       [ 0.        ,  0.        ,  0.        ,  0.        ]],      dtype=float32)

In [15]:
ngep = gradients_eprop / jnp.max(jnp.abs(gradients_eprop))

In [16]:
fun = lambda *x: eval3(*x)[0]
def wrapper(w_rec):
    # w_rec = params['RecurrentLIF']['w_rec']
    # w_rec[0, 0] = w0
    return fun(lsnn2, inputs, params, w_rec, None, None, key, n_rec, dt, tau_v, T, 1)
   

In [17]:
jax.test_util.check_grads(
    wrapper,
    (params['RecurrentLIF']['w_rec'],),
    order=1,
    modes=['rev']
)

[[0. 0. 1.]] -> Traced<ConcreteArray([[0. 1. 0. 0.]], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([[0., 1., 0., 0.]], dtype=float32)
  tangent = Traced<ShapedArray(float32[1,4])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[1,4]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x000002749975DEB0>, in_tracers=(Traced<ShapedArray(float32[1,4]):JaxprTrace(level=1/0)>, Traced<ConcreteArray([[False False False False]], dtype=bool):JaxprTrace(level=1/0)>, Traced<ConcreteArray([[0. 0. 0. 0.]], dtype=float32):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x000002749AA5A110; to 'JaxprTracer' at 0x000002749AA5A0D0>], out_avals=[ShapedArray(float32[1,4])], primitive=pjit, params={'jaxpr': { [34m[22m[1mlambda [39m[22m[22m; a[35m:f32[1,4][39m b[35m:bool[1,4][39m c[35m:f32[1,4][39m. [34m[22m[1mlet
    [39m[22m[22md[35m:f32[1,4][39m = select_n b a c
  [34m[22m[1min [39m[22m[22m(d,) }, 'in_shardings': (UnspecifiedValue

AssertionError: 
Not equal to tolerance rtol=0.002, atol=0.002
VJP cotangent projection
Mismatched elements: 1 / 1 (100%)
Max absolute difference: 0.90476847
Max relative difference: inf
 x: array(0.904768, dtype=float32)
 y: array(0., dtype=float32)

In [18]:

surrogate_grad3 = jax.value_and_grad(fun, argnums=3)(lsnn2, inputs, params, params['RecurrentLIF']['w_rec'].copy(), None, None, key,
    n_rec, dt, tau_v, T, 1)

[[0. 0. 1.]] -> Traced<ConcreteArray([[0. 1. 0. 0.]], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([[0., 1., 0., 0.]], dtype=float32)
  tangent = Traced<ShapedArray(float32[1,4])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[1,4]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x000002749FC718A0>, in_tracers=(Traced<ShapedArray(float32[1,4]):JaxprTrace(level=1/0)>, Traced<ConcreteArray([[False False False False]], dtype=bool):JaxprTrace(level=1/0)>, Traced<ConcreteArray([[0. 0. 0. 0.]], dtype=float32):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x000002749FF06160; to 'JaxprTracer' at 0x000002749FF06120>], out_avals=[ShapedArray(float32[1,4])], primitive=pjit, params={'jaxpr': { [34m[22m[1mlambda [39m[22m[22m; a[35m:f32[1,4][39m b[35m:bool[1,4][39m c[35m:f32[1,4][39m. [34m[22m[1mlet
    [39m[22m[22md[35m:f32[1,4][39m = select_n b a c
  [34m[22m[1min [39m[22m[22m(d,) }, 'in_shardings': (UnspecifiedValue

In [19]:
surrogate_grad3[1] / jnp.max(jnp.abs(surrogate_grad3[1]))

Array([[ 0.        ,  0.        , -0.03873478,  0.9000932 ],
       [ 0.0120035 ,  0.        , -0.04557129,  1.        ],
       [-0.00351272,  0.        ,  0.        ,  0.7301342 ],
       [ 0.        ,  0.        ,  0.        ,  0.        ]],      dtype=float32)

In [20]:
nga = surrogate_grad3[1] / jnp.max(jnp.abs(surrogate_grad3[1]))

In [21]:
jnp.allclose(ngep, nga, atol=1e-3)

Array(True, dtype=bool)

In [317]:
params

{'RecurrentLIF': {'w_in': Array([[-0.6301748 ,  0.48754728,  0.60680926,  0.8751213 ],
         [-0.40730554,  0.58812934, -0.8242644 ,  0.71092707],
         [-0.5873246 ,  0.8977012 ,  0.9403405 , -0.55874705]],      dtype=float32),
  'w_rec': Array([[ 0.7171732 , -0.42568645, -0.27384278,  0.33492026],
         [-0.06512911, -0.18751548, -0.68637997,  0.21879737],
         [ 0.3078918 , -0.17575324,  0.08511223, -0.16051175],
         [ 0.64690757, -0.17119652, -0.06764613, -0.5058548 ]],      dtype=float32)}}

In [274]:
from lif import SpikeFunction


vs = jnp.ones((2, 1))
vs = vs.at[0].set(-0.5)
print(vs)

fun2 = lambda *x: SpikeFunction(*x).sum()

jax.value_and_grad(fun2, argnums=0)(vs, dampening_factor) # 0.3 dampening factor

[[-0.5]
 [ 1. ]]


(Array(1., dtype=float32),
 Array([[0.15],
        [0.  ]], dtype=float32))