In [3]:
%matplotlib widget

In [50]:
import spyx
import spyx.nn as snn

import jax
import jax.numpy as jnp

import haiku as hk

# jax.config.update("jax_disable_jit", True)
# import logging
# logging.getLogger("jax").setLevel(logging.WARNING)


In [51]:
# 1. Let's define some parameters
n_in = 3
n_LIF = 4
n_ALIF = 0 #4
n_rec = n_ALIF + n_LIF

dt = 1  # ms
tau_v = 20  # ms
tau_a = 500  # ms
T = 10  # ms
f0 = 100  # Hz

thr = 0.62
beta = 0.07 * jnp.concatenate([jnp.zeros(n_LIF), jnp.ones(n_ALIF)])
dampening_factor = 0.3
n_ref = 3

In [52]:
# jax set random seed
key = jax.random.PRNGKey(2)

inputs = (jax.random.uniform(key, shape=(1, T, n_in)) < f0 * dt / 1000).astype(float)
print(inputs.shape, inputs)

(1, 10, 3) [[[0. 0. 1.]
  [0. 0. 0.]
  [0. 0. 0.]
  [0. 1. 1.]
  [0. 0. 0.]
  [0. 0. 0.]
  [1. 0. 0.]
  [0. 0. 0.]
  [0. 0. 1.]
  [0. 0. 0.]]]


In [53]:
def lsnn(x, state=None):
    core = hk.DeepRNN([
        hk.Linear(4),
        snn.LIF((4,), activation=spyx.axn.superspike(), threshold=0.3),
    ])
    # spikes, V = hk.dynamic_unroll(core, x, core.initial_state(x.shape[0]), time_major=False, unroll=T)
    # return spikes, V
    # print(x.shape[1])
    if state is None:
        state = core.initial_state(1)
    spikes, V = core(x, state)#core.initial_state(1))
    return spikes, V

lsnn_hk = hk.without_apply_rng(hk.transform(lsnn))

In [54]:
print(inputs[:,0].shape)
params = lsnn_hk.init(rng=key, x=inputs[:,0])#, state=jnp.zeros((1,4,3)))
print(params)

(1, 3)
{'linear': {'w': Array([[ 0.7967948 , -0.3821632 , -0.7605332 ,  0.45293623],
       [-0.03456055,  0.65856   ,  0.58331513, -0.10983399],
       [-0.4869853 ,  1.0580422 ,  0.53946483, -0.00187313]],      dtype=float32), 'b': Array([0., 0., 0., 0.], dtype=float32)}, 'LIF': {'beta': Array([0.5494536, 0.6733595, 0.2541612, 0.2463395], dtype=float32)}}


In [55]:
spikes, V = lsnn_hk.apply(params, inputs[:,0])
print(spikes.shape, V[0].shape, spikes, V)

(1, 4) (1, 4) [[0. 0. 0. 0.]] (Array([[-0.4869853 ,  1.0580422 ,  0.53946483, -0.00187313]], dtype=float32),)


In [56]:
state = None
spikes = []
V = []
for t in range(T):
    outs, state = lsnn_hk.apply(params, inputs[:,t], state)
    print(inputs[:,t], "->", outs[0])
    spikes.append(outs[0])
    V.append(outs[1][0])
# print(spikes)
spikes = jnp.stack(spikes, axis=0)
print(spikes.shape)

[[0. 0. 1.]] -> [0. 0. 0. 0.]
[[0. 0. 0.]] -> [0. 1. 1. 0.]
[[0. 0. 0.]] -> [0. 1. 0. 0.]
[[0. 1. 1.]] -> [0. 0. 0. 0.]
[[0. 0. 0.]] -> [0. 1. 1. 0.]
[[0. 0. 0.]] -> [0. 1. 0. 0.]
[[1. 0. 0.]] -> [0. 0. 0. 0.]


[[0. 0. 0.]] -> [1. 0. 0. 1.]
[[0. 0. 1.]] -> [0. 0. 0. 0.]
[[0. 0. 0.]] -> [0. 1. 1. 0.]
(10, 4)


In [57]:
def exp_convolve(tensor, decay):
    '''
    Filters a tensor with an exponential filter.
    :param tensor: a tensor of shape (trial, time, neuron)
    :param decay: a decay constant of the form exp(-dt/tau) with tau the time constant
    :return: the filtered tensor of shape (trial, time, neuron)
    '''
    r_shp = range(len(tensor.shape))
    transpose_perm = [1, 0] + list(r_shp)[2:]

    tensor_time_major = jax.lax.transpose(tensor, permutation=transpose_perm)
    initializer = jnp.zeros_like(tensor_time_major[0])
    _, filtered_tensor = jax.lax.scan(lambda a, x: (a * decay + (1 - decay) * x, a * decay + (1 - decay) * x), xs=tensor_time_major, init=initializer)
    filtered_tensor = jax.lax.transpose(filtered_tensor, permutation=transpose_perm)

    return filtered_tensor

In [58]:
w_out = jax.random.normal(key=key, shape=[n_rec, 1])
decay_out = jnp.exp(-1 / 20)
print(spikes.shape)
z_filtered = exp_convolve(spikes, decay_out)
print(z_filtered.shape)
y_out = jnp.einsum("tj,jk->tk", z_filtered, w_out) # no batch dim
y_target = jax.random.normal(key=key, shape=[T, 1])
loss = 0.5 * jnp.sum((y_out - y_target) ** 2)
print(loss)

(10, 4)
(10, 4)
4.169595


In [59]:
# REMARK : Z as a hidden state!

In [60]:
def eval_fct(inputs):
    state = None
    spikes = []
    V = []
    for t in range(T):
        outs, state = lsnn_hk.apply(params, inputs[:,t], state)
        print(inputs[:,t], "->", outs[0])
        spikes.append(outs[0])
        V.append(outs[1][0])
    # print(spikes)
    spikes = jnp.stack(spikes, axis=0)
    print(spikes.shape)
    w_out = jax.random.normal(key=key, shape=[n_rec, 1])
    decay_out = jnp.exp(-1 / 20)
    print(spikes.shape)
    z_filtered = exp_convolve(spikes, decay_out)
    print(z_filtered.shape)
    y_out = jnp.einsum("tj,jk->tk", z_filtered, w_out) # no batch dim
    y_target = jax.random.normal(key=key, shape=[T, 1])
    loss = 0.5 * jnp.sum((y_out - y_target) ** 2)
    print(loss)

    return loss

In [61]:
surrogate_grad = jax.value_and_grad(eval_fct)(inputs)

Traced<ConcreteArray([[0. 0. 1.]], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([[0., 0., 1.]], dtype=float32)
  tangent = Traced<ShapedArray(float32[1,3])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[1,3]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x0000017305586970>, in_tracers=(Traced<ShapedArray(float32[1,1,3]):JaxprTrace(level=1/0)>,), out_tracer_refs=[<weakref at 0x0000017309D1D580; to 'JaxprTracer' at 0x0000017309D1DEA0>], out_avals=[ShapedArray(float32[1,3])], primitive=squeeze, params={'dimensions': (1,)}, effects=frozenset(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x00000173089855B0>, name_stack=NameStack(stack=(Transform(name='jvp'),)))) -> [0. 0. 0. 0.]
Traced<ConcreteArray([[0. 0. 0.]], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([[0., 0., 0.]], dtype=float32)
  tangent = Traced<ShapedArray(float32[1,3])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(flo

In [62]:
print(surrogate_grad[0].shape, surrogate_grad[1].shape)
print(surrogate_grad[0], surrogate_grad[1])

() (1, 10, 3)
4.169595 [[[-4.1758074e-04  4.3031195e-04  4.4645189e-04]
  [-4.0897168e-05 -4.2305433e-04 -4.6824905e-04]
  [ 6.1749428e-04 -2.5006174e-03 -3.5016711e-03]
  [ 5.9845735e-04 -1.1083717e-03 -1.7980392e-03]
  [ 1.2781026e-03 -1.7627638e-03 -2.8037443e-03]
  [ 2.0294576e-03 -1.9827855e-03 -3.4295523e-03]
  [ 2.0014024e-03  3.7792340e-04  3.1928593e-04]
  [ 1.2110691e-03  6.1571115e-04  5.3813226e-05]
  [-6.9205730e-06  8.0696773e-06  8.3334262e-06]
  [ 0.0000000e+00  0.0000000e+00  0.0000000e+00]]]


In [63]:
learning_signals = jnp.einsum("tk,jk->tj", y_out - y_target, w_out)


In [64]:
def shift_by_one_time_step(tensor, initializer=None):
    '''
    Shift the input on the time dimension by one.
    :param tensor: a tensor of shape (trial, time, neuron)
    :param initializer: pre-prend this as the new first element on the time dimension
    :return: a shifted tensor of shape (trial, time, neuron)
    '''
    if len(tensor.shape) == 2:
        tensor = jnp.expand_dims(tensor, axis=0)
    r_shp = range(len(tensor.shape))
    transpose_perm = [1, 0] + list(r_shp)[2:]
    tensor_time_major = jax.lax.transpose(tensor, permutation=transpose_perm)

    if initializer is None:
        initializer = jnp.zeros_like(tensor_time_major[0])

    # print(initializer.shape, tensor_time_major[:,:-1].shape)

    shifted_tensor = jnp.concat([initializer[None, :, :], tensor_time_major[:-1]], axis=0)
    # shifted_tensor = tensor_time_major
    # shifted_tensor = jnp.concatenate([initializer[:], tensor_time_major[:,:-1]], axis=1)

    shifted_tensor = jax.lax.transpose(shifted_tensor, permutation=transpose_perm)
    return shifted_tensor

In [65]:
print(spikes.shape)
pre_synpatic_spike_one_step_before = shift_by_one_time_step(spikes)[0]
print(pre_synpatic_spike_one_step_before.shape)

(10, 4)
(1, 4) (10, 0, 4)
(10, 4)


In [206]:
from collections import namedtuple

# https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/04-advanced-autodiff.ipynb


CustomALIFStateTuple = namedtuple('CustomALIFStateTuple', ('s', 'z', 'r', 'z_local'))

    
@jax.custom_gradient
def SpikeFunction(v_scaled, dampening_factor):
    z_ = jnp.greater(v_scaled, 0.)
    z_ = z_.astype(jnp.float32)

    def grad(dy):
        dE_dz = dy
        dz_dv_scaled = jnp.maximum(1 - jnp.abs(v_scaled), 0)
        dz_dv_scaled *= dampening_factor

        dE_dv_scaled = dE_dz * dz_dv_scaled

        return (dE_dv_scaled, jnp.zeros_like(dampening_factor).astype(jnp.float32))

    return z_, grad

class RecurrentLIF(hk.RNNCore):
    """
    Leaky Integrate and Fire neuron model inspired by the implementation in
    snnTorch:

    https://snntorch.readthedocs.io/en/latest/snn.neurons_leaky.html
    
    """

    def __init__(self, 
                 n_in, n_rec, tau=20., thr=.615, dt=1., dtype=jnp.float32, dampening_factor=0.3,
                 tau_adaptation=200., beta=.16, tag='',
                 stop_gradients=False, w_in_init=None, w_rec_init=None, n_refractory=1, rec=True,
                 name="RecurrentLIF"):
        super(RecurrentLIF, self).__init__(name=name)

        self.n_refractory = n_refractory
        self.tau_adaptation = tau_adaptation
        self.beta = beta
        self.decay_b = jnp.exp(-dt / tau_adaptation)

        if jnp.isscalar(tau): tau = jnp.ones(n_rec, dtype=dtype) * jnp.mean(tau)
        if jnp.isscalar(thr): thr = jnp.ones(n_rec, dtype=dtype) * jnp.mean(thr)

        tau = jnp.array(tau, dtype=dtype)
        dt = jnp.array(dt, dtype=dtype)
        self.rec = rec

        self.dampening_factor = dampening_factor
        self.stop_gradients = stop_gradients
        self.dt = dt
        self.n_in = n_in
        self.n_rec = n_rec
        self.data_type = dtype

        self._num_units = self.n_rec

        self.tau = tau
        self._decay = jnp.exp(-dt / tau)
        self.thr = thr

        # init_w_in_var = w_in_init if w_in_init is not None else \
        #         (jax.random.uniform(key, shape=(n_in, n_rec)) / jnp.sqrt(n_in)).astype(dtype)
        init_w_in_var = w_in_init if w_in_init is not None else hk.initializers.TruncatedNormal(1./jnp.sqrt(n_in))
        self.w_in_var = hk.get_parameter("w_in" + tag, (n_in, n_rec), dtype, init_w_in_var)
        self.w_in_val = self.w_in_var

        if rec:
            # init_w_rec_var = w_rec_init if w_rec_init is not None else \
            # (jax.random.uniform(key, shape=(n_rec, n_rec)) / jnp.sqrt(n_rec)).astype(dtype)
            init_w_rec_var = w_rec_init if w_rec_init is not None else hk.initializers.TruncatedNormal(1./jnp.sqrt(n_rec))
            self.w_rec_var = hk.get_parameter("w_rec" + tag, (n_rec, n_rec), dtype, init_w_rec_var)

            self.recurrent_disconnect_mask = jnp.diag(jnp.ones(n_rec, dtype=bool))

            # Disconnect autotapse
            self.w_rec_val = jnp.where(self.recurrent_disconnect_mask, jnp.zeros_like(self.w_rec_var), self.w_rec_var)

            # dw_val_dw_var_rec = jnp.ones((self._num_units,self._num_units)) - jnp.diag(jnp.ones(self._num_units))
        
        # dw_val_dw_var_in = jnp.ones((n_in,self._num_units))

        # self.dw_val_dw_var = [dw_val_dw_var_in, dw_val_dw_var_rec] if rec else [dw_val_dw_var_in,]

        self.variable_list = [self.w_in_var, self.w_rec_var] if rec else [self.w_in_var,]
        self.built = True

    def initial_state(self, batch_size, dtype=jnp.float32, n_rec=None):
        if n_rec is None: n_rec = self.n_rec

        s0 = jnp.zeros(shape=(batch_size, n_rec, 2), dtype=dtype)
        z0 = jnp.zeros(shape=(batch_size, n_rec), dtype=dtype)
        z_local0 = jnp.zeros(shape=(batch_size, n_rec), dtype=dtype)
        r0 = jnp.zeros(shape=(batch_size, n_rec), dtype=dtype)
        return CustomALIFStateTuple(s=s0, z=z0, r=r0, z_local=z_local0)
    
    def compute_z(self, v, b):
        adaptive_thr = self.thr + b * self.beta
        v_scaled = (v - adaptive_thr) / self.thr
        z = SpikeFunction(v_scaled, self.dampening_factor)
        z = z * 1 / self.dt
        return z
        
    def __call__(self, inputs, state, scope=None, dtype=jnp.float32):
        decay = self._decay

        z = state.z
        z_local = state.z_local
        s = state.s
        # print("zs", z.shape, s.shape)
        # return

        if self.stop_gradients:
            z = jax.lax.stop_gradient(z)

        if len(self.w_in_val.shape) == 3:
            i_in = jnp.einsum('bi,bij->bj', inputs, self.w_in_val)
        else:
            # print(inputs.shape, self.w_in_val.shape)
            # i_in = jnp.matmul(inputs, self.w_in_val)
            i_in = jnp.einsum('abc,cd->ad', inputs, self.w_in_val)
            # print(inputs.shape, self.w_in_val.shape)
            # print("i_in", i_in.shape)

        if self.rec:
            if len(self.w_rec_val.shape) == 3:
                i_rec = jnp.einsum('bi,bij->bj', z, self.w_rec_val)
            else:
                # print("z wrec", z.shape, self.w_rec_val.shape)
                i_rec = jnp.matmul(z, self.w_rec_val)

            i_t = i_in + i_rec
        else:
            i_t = i_in

        # print("i_t", i_t.shape)


        def get_new_v_b(s, i_t):
            v, b = s[..., 0], s[..., 1]
            # print("vs", v.shape, b.shape)
            # old_z = self.compute_z(v, b)
            new_b = self.decay_b * b + z_local #old_z

            I_reset = z * self.thr * self.dt
            # print('vii', v.shape, i_t.shape, I_reset.shape)
            new_v = decay * v + i_t  - I_reset

            return new_v, new_b
        
        new_v, new_b = get_new_v_b(s, i_t)
        # print("nv nb", new_v.shape, new_b.shape)


        is_refractory = jnp.greater(state.r, .1)
        zeros_like_spikes = jnp.zeros_like(state.z)
        new_z = jnp.where(is_refractory, zeros_like_spikes, self.compute_z(new_v, new_b))
        new_z_local = jnp.where(is_refractory, zeros_like_spikes, self.compute_z(new_v, new_b))

        new_r = jnp.clip(state.r + self.n_refractory * new_z - 1,
                                 0., float(self.n_refractory))
        new_s = jnp.stack((new_v, new_b), axis=-1)       

        new_state = CustomALIFStateTuple(s=new_s, z=new_z, r=new_r, z_local=new_z_local)
        return new_z, new_state

In [207]:
def lsnn2(x, state=None, batch_size=1):
    core = hk.DeepRNN([
        # hk.Linear(n_rec), # otherwise not need Win
        RecurrentLIF(
            n_in,
            n_rec,
            tau=tau_v,
            thr=thr,
            dt=dt,
            dtype=jnp.float32,
            dampening_factor=dampening_factor,
            tau_adaptation=tau_a,
            beta=beta,
            tag='',
            stop_gradients=False,
            w_in_init=None,
            w_rec_init=None,
            n_refractory=n_ref,
            rec=True,
        )
    ])
    # spikes, V = hk.dynamic_unroll(core, x, core.initial_state(x.shape[0]), time_major=False, unroll=T)
    # return spikes, V
    # print(x.shape[1])
    if state is None:
        state = core.initial_state(batch_size)
    spikes, hiddens = core(x, state)#core.initial_state(1))
    return spikes, hiddens

lsnn2_hk = hk.without_apply_rng(hk.transform(lsnn2))

In [208]:
# print(inputs[:,0].shape)
# i0 = jnp.expand_dims(inputs[:,0], axis=0)
i0 = jnp.stack([inputs[:,0], inputs[:,0], inputs[:,0],inputs[:,0], inputs[:,0]], axis=0)
print(i0.shape)
params = lsnn2_hk.init(rng=key, x=i0, batch_size=5)
print(params)

(5, 1, 3)
{'RecurrentLIF': {'w_in': Array([[ 0.7967948 , -0.3821632 , -0.7605332 ,  0.45293623],
       [-0.03456055,  0.65856   ,  0.58331513, -0.10983399],
       [-0.4869853 ,  1.0580422 ,  0.53946483, -0.00187313]],      dtype=float32), 'w_rec': Array([[ 0.2713923 ,  0.86784893,  0.07354291, -0.43423995],
       [ 0.9018898 , -0.1874919 ,  0.18929863,  0.15613334],
       [ 0.2636667 ,  0.32536188,  0.5184128 ,  0.30088994],
       [ 0.22115956, -0.26662493,  0.49525732,  0.11215727]],      dtype=float32)}}


In [209]:
def eval2(inputs, batch_size=5):
    state = None
    spikes = []
    V = []
    variations = []
    for t in range(T):
        input_t = [jnp.copy(inputs[:,t]) for _ in range(batch_size)]
        it = jnp.stack(input_t, axis=0)
        outs, state = lsnn2_hk.apply(params, it, state, batch_size)
        # print(state[0].s.shape)
        print(inputs[:,t], "->", outs)
        spikes.append(outs)
        V.append(state[0].s[...,0])
        variations.append(state[0].s[...,1])
        # print(V[-1].shape)

    if batch_size == 1:
        spikes = jnp.stack([s[0] for s in spikes], axis=0)
        spikes = jnp.expand_dims(spikes, axis=0)#, V = jnp.expand_dims(V, axis=0), variations = jnp.expand_dims(variations, axis=0)
        print(spikes.shape)
    # print(spikes)
    else:
        spikes = jnp.stack(spikes, axis=1)
    V = jnp.stack(V, axis=1)
    variations = jnp.stack(variations, axis=1)
    print(spikes.shape)
    w_out = jax.random.normal(key=key, shape=[n_rec, 1])
    decay_out = jnp.exp(-dt / tau_v)
    print(spikes.shape)
    z_filtered = exp_convolve(spikes, decay_out)
    print(z_filtered.shape)
    y_out = jnp.einsum("btj,jk->btk", z_filtered, w_out) # no batch dim
    y_target = jax.random.normal(key=key, shape=[1, T, 1])
    print(y_out.shape, y_target.shape)
    loss = 0.5 * jnp.sum((y_out - y_target) ** 2)
    return loss, y_out, y_target, w_out, spikes, V, variations

In [210]:
eval2(inputs, batch_size=1);

[[0. 0. 1.]] -> [[0. 1. 0. 0.]]
[[0. 0. 0.]] -> [[0. 0. 1. 0.]]
[[0. 0. 0.]] -> [[1. 0. 0. 0.]]
[[0. 1. 1.]] -> [[0. 1. 0. 0.]]
[[0. 0. 0.]] -> [[0. 0. 1. 0.]]
[[0. 0. 0.]] -> [[1. 0. 0. 0.]]
[[1. 0. 0.]] -> [[0. 1. 0. 0.]]
[[0. 0. 0.]] -> [[0. 0. 0. 0.]]
[[0. 0. 1.]] -> [[1. 0. 1. 0.]]
[[0. 0. 0.]] -> [[0. 1. 0. 0.]]
(1, 10, 4)
(1, 10, 4)
(1, 10, 4)
(1, 10, 4)
(1, 10, 1) (1, 10, 1)


In [211]:
surrogate_grad = jax.value_and_grad(lambda x: eval2(x, 1)[0])(inputs)

Traced<ConcreteArray([[0. 0. 1.]], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([[0., 0., 1.]], dtype=float32)
  tangent = Traced<ShapedArray(float32[1,3])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[1,3]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x0000017310D7E5D0>, in_tracers=(Traced<ShapedArray(float32[1,1,3]):JaxprTrace(level=1/0)>,), out_tracer_refs=[<weakref at 0x0000017310E88CC0; to 'JaxprTracer' at 0x0000017310E88F00>], out_avals=[ShapedArray(float32[1,3])], primitive=squeeze, params={'dimensions': (1,)}, effects=frozenset(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x000001730F961FF0>, name_stack=NameStack(stack=(Transform(name='jvp'),)))) -> Traced<ConcreteArray([[0. 1. 0. 0.]], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([[0., 1., 0., 0.]], dtype=float32)
  tangent = Traced<ShapedArray(float32[1,4])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[1,

In [212]:
print(surrogate_grad)

(Array(4.005104, dtype=float32), Array([[[-4.3032640e-03, -5.6647146e-03, -9.6947327e-03],
        [-4.9892762e-03,  2.2617994e-04,  2.6807082e-03],
        [-5.2042147e-03, -1.9872107e-03,  4.8570749e-03],
        [ 1.6127031e-02, -8.3251214e-03,  8.0612861e-03],
        [ 1.6953882e-02, -8.7519595e-03,  8.4745968e-03],
        [ 1.8808581e-02, -9.4396481e-03,  8.9050224e-03],
        [ 3.7410419e-02, -1.0757166e-02, -1.2092085e-03],
        [ 3.0142812e-02, -9.0812314e-03, -1.2332182e-03],
        [ 1.2657992e-02, -4.7759078e-03, -1.0077114e-03],
        [-8.1238413e-04,  1.9699767e-04,  3.3596320e-06]]], dtype=float32))


In [213]:
print(surrogate_grad)

(Array(4.005104, dtype=float32), Array([[[-4.3032640e-03, -5.6647146e-03, -9.6947327e-03],
        [-4.9892762e-03,  2.2617994e-04,  2.6807082e-03],
        [-5.2042147e-03, -1.9872107e-03,  4.8570749e-03],
        [ 1.6127031e-02, -8.3251214e-03,  8.0612861e-03],
        [ 1.6953882e-02, -8.7519595e-03,  8.4745968e-03],
        [ 1.8808581e-02, -9.4396481e-03,  8.9050224e-03],
        [ 3.7410419e-02, -1.0757166e-02, -1.2092085e-03],
        [ 3.0142812e-02, -9.0812314e-03, -1.2332182e-03],
        [ 1.2657992e-02, -4.7759078e-03, -1.0077114e-03],
        [-8.1238413e-04,  1.9699767e-04,  3.3596320e-06]]], dtype=float32))


In [214]:
pre_synpatic_spike_one_step_before = shift_by_one_time_step(spikes)
print(pre_synpatic_spike_one_step_before.shape)

(1, 4) (10, 0, 4)
(1, 10, 4)


In [215]:
print(beta, thr)

[0. 0. 0. 0.] 0.62


In [216]:
def iterate(fun, xs, init, remove_first=False):
    rets = [init]
    for x in xs:
        rets.append(fun(rets[-1], x))
    return jnp.array(rets if not remove_first else rets[1:])

In [217]:
print(jnp.ones(4)*jnp.exp(-dt/tau_v))

[0.95122945 0.95122945 0.95122945 0.95122945]


In [218]:
def compute_eligibility_traces(v_scaled, z_pre, z_post, is_rec):
    n_neurons = jnp.shape(z_post)[2]
    rho = jnp.exp(-dt / tau_a)
    # beta = beta # defined outside
    alpha = jnp.ones(4)*jnp.exp(-dt/tau_v)
    n_ref = 3 #n_refractory

    # everything should be time major
    # z_pre = tf.transpose(z_pre, perm=[1, 0, 2])
    # v_scaled = tf.transpose(v_scaled, perm=[1, 0, 2])
    # z_post = tf.transpose(z_post, perm=[1, 0, 2])

    z_pre = jax.lax.transpose(z_pre, permutation=[1, 0, 2])
    v_scaled = jax.lax.transpose(v_scaled, permutation=[1, 0, 2])
    z_post = jax.lax.transpose(z_post, permutation=[1, 0, 2])

    psi_no_ref = dampening_factor / thr * jnp.maximum(0., 1. - jnp.abs(v_scaled))

    update_refractory = lambda refractory_count, z_post:\
        jnp.where(z_post > 0,jnp.ones_like(refractory_count) * (n_ref - 1),jnp.maximum(0, refractory_count - 1))

    refractory_count_init = jnp.zeros_like(z_post[0], dtype=jnp.int32)
    
    # refractory_count = [refractory_count_init]
    # for z in z_post[:-1]:
    #     refractory_count.append(update_refractory(refractory_count[-1], z))
    # print(jnp.array(refractory_count).shape)
    refractory_count = iterate(update_refractory, z_post[:-1], refractory_count_init)

    print("refractory_count", refractory_count)
    # refractory_count = tf.scan(update_refractory, z_post[:-1], initializer=refractory_count_init)
    # print(refractory_count_init.shape, z_post[:-1].shape)
    # print(update_refractory(update_refractory(refractory_count_init, z_post[0]), z_post[1]))
    # print(jax.lax.scan(update_refractory, xs=z_post[:-1], init=refractory_count_init))
    # _, refractory_count = jax.lax.scan(update_refractory, xs=z_post[:-1], init=refractory_count_init)
    # refractory_count = jnp.concat([jnp.expand_dims(refractory_count_init, axis=0), refractory_count], axis=0)

    is_refractory = refractory_count > 0
    psi = jnp.where(is_refractory, jnp.zeros_like(psi_no_ref), psi_no_ref)

    print("psi", psi)

    update_epsilon_v = lambda epsilon_v, z_pre: alpha[None, None, :] * epsilon_v + z_pre[:, :, None] #alpha[None, None, :] * epsilon_v + z_pre[:, :, None]
    epsilon_v_zero = jnp.ones((1, 1, n_neurons)) * z_pre[0][:, :, None]
    print("evz", epsilon_v_zero)
    print("zpre", z_pre[1:])
    # epsilon_v = tf.scan(update_epsilon_v, z_pre[1:], initializer=epsilon_v_zero, )
    # _, epsilon_v = jax.lax.scan(update_epsilon_v, xs=z_pre[1:], init=epsilon_v_zero)
    # epsilon_v = jnp.concat([[epsilon_v_zero], epsilon_v], axis=0)
    print(epsilon_v_zero.shape, z_pre[1:].shape)
    epsilon_v = iterate(update_epsilon_v, z_pre[1:], epsilon_v_zero)
    print("ev", epsilon_v)

    update_epsilon_a = lambda epsilon_a, elems:\
            (rho - beta * elems['psi'][:, None, :]) * epsilon_a + elems['psi'][:, None, :] * elems['epsi']

    epsilon_a_zero = jnp.zeros_like(epsilon_v[0])
    # epsilon_a = tf.scan(fn=update_epsilon_a,
    #                     elems={'psi': psi[:-1], 'epsi': epsilon_v[:-1], 'previous_epsi':shift_by_one_time_step(epsilon_v[:-1])},
    #                     initializer=epsilon_a_zero)
    # _, epsilon_a = jax.lax.scan(update_epsilon_a,
    #                             xs=[{'psi': psi[:-1], 'epsi': epsilon_v[:-1], 'previous_epsi':shift_by_one_time_step(epsilon_v[:-1])}],
    #                             init=epsilon_a_zero)

    # epsilon_a = jnp.concat([[epsilon_a_zero], epsilon_a], axis=0)
    # epsilon_a = iterate(update_epsilon_a, [{'psi': psi[:-1], 'epsi': epsilon_v[:-1], 'previous_epsi':shift_by_one_time_step(epsilon_v[:-1])}], epsilon_a_zero)
    previous_epsi = shift_by_one_time_step(epsilon_v[:-1])
    elems = [{'psi': psie, 'epsi': epsie, 'previous_epsi': pe} for psie, epsie, pe in zip(psi[:-1], epsilon_v[:-1], previous_epsi)]
    epsilon_a = iterate(update_epsilon_a, elems, epsilon_a_zero)
    print("ea", epsilon_a)

    e_trace = psi[:, :, None, :] * (epsilon_v - beta * epsilon_a)

    # everything should be time major
    # e_trace = tf.transpose(e_trace, perm=[1, 0, 2, 3])
    # epsilon_v = tf.transpose(epsilon_v, perm=[1, 0, 2, 3])
    # epsilon_a = tf.transpose(epsilon_a, perm=[1, 0, 2, 3])
    # psi = tf.transpose(psi, perm=[1, 0, 2])

    e_trace = jax.lax.transpose(e_trace, permutation=[1, 0, 2, 3])
    epsilon_v = jax.lax.transpose(epsilon_v, permutation=[1, 0, 2, 3])
    epsilon_a = jax.lax.transpose(epsilon_a, permutation=[1, 0, 2, 3])
    psi = jax.lax.transpose(psi, permutation=[1, 0, 2])

    if is_rec:
        identity_diag = jnp.eye(n_neurons)[None, None, :, :]
        e_trace -= identity_diag * e_trace
        epsilon_v -= identity_diag * epsilon_v
        epsilon_a -= identity_diag * epsilon_a

    return e_trace, epsilon_v, epsilon_a, psi

In [219]:
def compute_loss_gradient(learning_signal, z_pre, z_post, v_post, b_post,
                              decay_out=None,zero_on_diagonal=None):
        thr_post = thr + beta * b_post
        v_scaled = (v_post - thr_post) / thr
        print(v_scaled)

        e_trace, epsilon_v, epsilon_a, _ = compute_eligibility_traces(v_scaled, z_pre, z_post, zero_on_diagonal)
        print("evbv", epsilon_v)
        print("eabv", epsilon_a)

        if decay_out is not None:
            e_trace_time_major = jax.lax.transpose(e_trace, permutation=[1, 0, 2, 3])
            filtered_e_zero = jnp.zeros_like(e_trace_time_major[0])
            filtering = lambda filtered_e, e: filtered_e * decay_out + e * (1 - decay_out)
            # filtered_e = tf.scan(filtering, e_trace_time_major, initializer=filtered_e_zero)
            # _, filtered_e = jax.lax.scan(filtering, xs=e_trace_time_major, init=filtered_e_zero)
            filtered_e = iterate(filtering, e_trace_time_major, filtered_e_zero, remove_first=True)
            filtered_e = jax.lax.transpose(filtered_e, permutation=[1, 0, 2, 3])
            e_trace = filtered_e
        print("e_trace", e_trace)

        if len(learning_signal.shape) == 2:
            learning_signal = jnp.expand_dims(learning_signal, axis=0)

        print(e_trace.shape, learning_signal.shape)

        gradient = jnp.einsum('btj,btij->ij', learning_signal, e_trace)
        return gradient, e_trace, epsilon_v, epsilon_a

In [235]:
# loss, y_out, y_target, w_out, spikes, V, variations = eval2(inputs, 1)
loss, y_out, y_target, w_out, spikes, V, variations = eval3(params['RecurrentLIF']['w_rec'], 1)

print(y_out.shape, y_target.shape, w_out.shape, spikes.shape, V.shape, variations.shape)

[[0. 0. 1.]] -> Traced<ConcreteArray([[0. 1. 0. 0.]], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([[0., 1., 0., 0.]], dtype=float32)
  tangent = Traced<ShapedArray(float32[1,4])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[1,4]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x000001731131DDE0>, in_tracers=(Traced<ShapedArray(float32[1,4]):JaxprTrace(level=1/0)>, Traced<ConcreteArray([[False False False False]], dtype=bool):JaxprTrace(level=1/0)>, Traced<ConcreteArray([[0. 0. 0. 0.]], dtype=float32):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x000001731237FB00; to 'JaxprTracer' at 0x000001731237FAC0>], out_avals=[ShapedArray(float32[1,4])], primitive=pjit, params={'jaxpr': { [34m[22m[1mlambda [39m[22m[22m; a[35m:f32[1,4][39m b[35m:bool[1,4][39m c[35m:f32[1,4][39m. [34m[22m[1mlet
    [39m[22m[22md[35m:f32[1,4][39m = select_n b a c
  [34m[22m[1min [39m[22m[22m(d,) }, 'in_shardings': (UnspecifiedValue

In [236]:
learning_signals = jnp.einsum("btk,jk->tj", y_out - y_target, w_out)
print(learning_signals.shape)

(10, 4)


In [237]:
# learning_signals, spikes, V, variations = jnp.load("data.npz").values()


# print(spikes)
pre_synpatic_spike_one_step_before = shift_by_one_time_step(spikes)
# print(pre_synpatic_spike_one_step_before)
# print(learning_signals)
# print(V)
print(pre_synpatic_spike_one_step_before)


gradients_eprop, eligibility_traces, _, _ = \
    compute_loss_gradient(learning_signals, pre_synpatic_spike_one_step_before, spikes, V,
                               variations, decay_out, True)

(1, 4) (10, 0, 4)
Traced<ConcreteArray([[[0. 0. 0. 0.]
  [0. 1. 0. 0.]
  [0. 0. 1. 0.]
  [1. 0. 0. 0.]
  [0. 1. 0. 0.]
  [0. 0. 1. 0.]
  [1. 0. 0. 0.]
  [0. 1. 0. 0.]
  [0. 0. 0. 0.]
  [1. 0. 1. 0.]]], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([[[0., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 0., 0.],
        [1., 0., 1., 0.]]], dtype=float32)
  tangent = Traced<ShapedArray(float32[1,10,4])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[1,10,4]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x000001731131FE60>, in_tracers=(Traced<ShapedArray(float32[10,1,4]):JaxprTrace(level=1/0)>,), out_tracer_refs=[<weakref at 0x000001731238D030; to 'JaxprTracer' at 0x000001731238CFF0>], out_avals=[ShapedArray(float32[1,10,4])], primitive=transpose, params={'permutation': (1, 0, 2)}

In [239]:
# (gradients_eprop / jnp.max(jnp.abs(gradients_eprop))).asarray()

AttributeError: JVPTracer has no attribute asarray

In [240]:
def eval3(weights, batch_size=1):
    state = None
    spikes = []
    V = []
    variations = []
    params['RecurrentLIF']['w_rec'] = weights
    for t in range(T):
        it = inputs[:, t]
        it = jnp.expand_dims(it, axis=0)
        outs, state = lsnn2_hk.apply(params, it, state, batch_size)
        # print(state[0].s.shape)
        print(inputs[:,t], "->", outs)
        spikes.append(outs)
        V.append(state[0].s[...,0])
        variations.append(state[0].s[...,1])
        # print(V[-1].shape)

    spikes = jnp.stack([s[0] for s in spikes], axis=0)
    V = jnp.stack(V, axis=1)
    variations = jnp.stack(variations, axis=1)
    print(spikes.shape)
    w_out = jax.random.normal(key=key, shape=[n_rec, 1])
    decay_out = jnp.exp(-dt / tau_v)
    print(spikes.shape)
    z_filtered = exp_convolve(spikes, decay_out)
    print(z_filtered.shape)
    y_out = jnp.einsum("tj,jk->tk", z_filtered, w_out) # no batch dim
    y_target = jax.random.normal(key=key, shape=[T, 1])
    print(y_out.shape, y_target.shape)
    loss = 0.5 * jnp.sum((y_out - y_target) ** 2)
    y_out = jnp.expand_dims(y_out, axis=0)
    y_target = jnp.expand_dims(y_target, axis=0)
    spikes = jnp.expand_dims(spikes, axis=0)
    return loss, y_out, y_target, w_out, spikes, V, variations

In [241]:
loss, y_out, y_target, w_out, spikes, V, variations = eval3(params['RecurrentLIF']['w_rec'], 1)
print(y_out.shape, y_target.shape, w_out.shape, spikes.shape, V.shape, variations.shape)

[[0. 0. 1.]] -> Traced<ConcreteArray([[0. 1. 0. 0.]], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([[0., 1., 0., 0.]], dtype=float32)
  tangent = Traced<ShapedArray(float32[1,4])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[1,4]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x00000173124F08B0>, in_tracers=(Traced<ShapedArray(float32[1,4]):JaxprTrace(level=1/0)>, Traced<ConcreteArray([[False False False False]], dtype=bool):JaxprTrace(level=1/0)>, Traced<ConcreteArray([[0. 0. 0. 0.]], dtype=float32):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x0000017312589260; to 'JaxprTracer' at 0x0000017312589220>], out_avals=[ShapedArray(float32[1,4])], primitive=pjit, params={'jaxpr': { [34m[22m[1mlambda [39m[22m[22m; a[35m:f32[1,4][39m b[35m:bool[1,4][39m c[35m:f32[1,4][39m. [34m[22m[1mlet
    [39m[22m[22md[35m:f32[1,4][39m = select_n b a c
  [34m[22m[1min [39m[22m[22m(d,) }, 'in_shardings': (UnspecifiedValue

In [242]:
surrogate_grad3 = jax.value_and_grad(lambda x: eval3(x, 1)[0])(params['RecurrentLIF']['w_rec'])

UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type [[ 0.2713923   0.86784893  0.07354291 -0.43423995]
 [ 0.9018898  -0.1874919   0.18929863  0.15613334]
 [ 0.2636667   0.32536188  0.5184128   0.30088994]
 [ 0.22115956 -0.26662493  0.49525732  0.11215727]], dtype=float32 wrapped in a JVPTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.Detail: Tracer from a higher level: Traced<ConcreteArray([[ 0.2713923   0.86784893  0.07354291 -0.43423995]
 [ 0.9018898  -0.1874919   0.18929863  0.15613334]
 [ 0.2636667   0.32536188  0.5184128   0.30088994]
 [ 0.22115956 -0.26662493  0.49525732  0.11215727]], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([[ 0.2713923 ,  0.86784893,  0.07354291, -0.43423995],
       [ 0.9018898 , -0.1874919 ,  0.18929863,  0.15613334],
       [ 0.2636667 ,  0.32536188,  0.5184128 ,  0.30088994],
       [ 0.22115956, -0.26662493,  0.49525732,  0.11215727]],      dtype=float32)
  tangent = Traced<ShapedArray(float32[4,4])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[4,4]), None)
    recipe = LambdaBinding() in trace JVPTrace(level=2/0)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError

In [228]:
surrogate_grad3[1] / jnp.max(jnp.abs(surrogate_grad3[1]))

Array([[ 0.        ,  0.        ,  0.59414923,  0.57911915],
       [-0.08269119,  0.        ,  1.        ,  0.9049598 ],
       [-0.36506352,  0.        ,  0.        ,  0.42960662],
       [ 0.        ,  0.        ,  0.        ,  0.        ]],      dtype=float32)

In [229]:
surrogate_grad3

(Array(3.8879519, dtype=float32),
 Array([[ 0.        ,  0.        ,  0.06295326,  0.06136075],
        [-0.00876157,  0.        ,  0.10595531,  0.0958853 ],
        [-0.03868042,  0.        ,  0.        ,  0.0455191 ],
        [ 0.        ,  0.        ,  0.        ,  0.        ]],      dtype=float32))

In [230]:
res3 = eval3(params['RecurrentLIF']['w_rec'], 1)
loss3, spikes3 = res3[0], res3[4]
res2 = eval2(inputs, 1)
loss2, spikes2 = res2[0], res2[4]
print(loss2, loss3)
print(spikes2 - spikes3)

[[0. 0. 1.]] -> Traced<ConcreteArray([[0. 1. 0. 0.]], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([[0., 1., 0., 0.]], dtype=float32)
  tangent = Traced<ShapedArray(float32[1,4])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[1,4]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x0000017310D7E0E0>, in_tracers=(Traced<ShapedArray(float32[1,4]):JaxprTrace(level=1/0)>, Traced<ConcreteArray([[False False False False]], dtype=bool):JaxprTrace(level=1/0)>, Traced<ConcreteArray([[0. 0. 0. 0.]], dtype=float32):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x0000017310F553A0; to 'JaxprTracer' at 0x0000017310F55860>], out_avals=[ShapedArray(float32[1,4])], primitive=pjit, params={'jaxpr': { [34m[22m[1mlambda [39m[22m[22m; a[35m:f32[1,4][39m b[35m:bool[1,4][39m c[35m:f32[1,4][39m. [34m[22m[1mlet
    [39m[22m[22md[35m:f32[1,4][39m = select_n b a c
  [34m[22m[1min [39m[22m[22m(d,) }, 'in_shardings': (UnspecifiedValue