In [3]:
%matplotlib widget

In [4]:
import spyx
import spyx.nn as snn

import jax
import jax.numpy as jnp

import haiku as hk

In [5]:
# 1. Let's define some parameters
n_in = 3
n_LIF = 4
n_ALIF = 0 #4
n_rec = n_ALIF + n_LIF

dt = 1  # ms
tau_v = 20  # ms
tau_a = 500  # ms
T = 10  # ms
f0 = 100  # Hz

thr = 0.62
beta = 0.07 * jnp.concatenate([jnp.zeros(n_LIF), jnp.ones(n_ALIF)])
dampening_factor = 0.3
n_ref = 3

In [6]:
# jax set random seed
key = jax.random.PRNGKey(2)

inputs = (jax.random.uniform(key, shape=(1, T, n_in)) < f0 * dt / 1000).astype(float)
print(inputs.shape, inputs)

(1, 10, 3) [[[0. 0. 1.]
  [0. 0. 0.]
  [0. 0. 0.]
  [0. 1. 1.]
  [0. 0. 0.]
  [0. 0. 0.]
  [1. 0. 0.]
  [0. 0. 0.]
  [0. 0. 1.]
  [0. 0. 0.]]]


In [7]:
def lsnn(x, state=None):
    core = hk.DeepRNN([
        hk.Linear(4),
        snn.LIF((4,), activation=spyx.axn.superspike(), threshold=0.3),
    ])
    # spikes, V = hk.dynamic_unroll(core, x, core.initial_state(x.shape[0]), time_major=False, unroll=T)
    # return spikes, V
    # print(x.shape[1])
    if state is None:
        state = core.initial_state(1)
    spikes, V = core(x, state)#core.initial_state(1))
    return spikes, V

lsnn_hk = hk.without_apply_rng(hk.transform(lsnn))

In [8]:
print(inputs[:,0].shape)
params = lsnn_hk.init(rng=key, x=inputs[:,0])#, state=jnp.zeros((1,4,3)))
print(params)

(1, 3)
{'linear': {'w': Array([[ 0.7967948 , -0.3821632 , -0.7605332 ,  0.45293623],
       [-0.03456055,  0.65856   ,  0.58331513, -0.10983399],
       [-0.4869853 ,  1.0580422 ,  0.53946483, -0.00187313]],      dtype=float32), 'b': Array([0., 0., 0., 0.], dtype=float32)}, 'LIF': {'beta': Array([0.5494536, 0.6733595, 0.2541612, 0.2463395], dtype=float32)}}


In [9]:
spikes, V = lsnn_hk.apply(params, inputs[:,0])
print(spikes.shape, V[0].shape, spikes, V)

(1, 4) (1, 4) [[0. 0. 0. 0.]] (Array([[-0.4869853 ,  1.0580422 ,  0.53946483, -0.00187313]], dtype=float32),)


In [10]:
state = None
spikes = []
V = []
for t in range(T):
    outs, state = lsnn_hk.apply(params, inputs[:,t], state)
    print(inputs[:,t], "->", outs[0])
    spikes.append(outs[0])
    V.append(outs[1][0])
# print(spikes)
spikes = jnp.stack(spikes, axis=0)
print(spikes.shape)

[[0. 0. 1.]] -> [0. 0. 0. 0.]


[[0. 0. 0.]] -> [0. 1. 1. 0.]
[[0. 0. 0.]] -> [0. 1. 0. 0.]
[[0. 1. 1.]] -> [0. 0. 0. 0.]
[[0. 0. 0.]] -> [0. 1. 1. 0.]
[[0. 0. 0.]] -> [0. 1. 0. 0.]
[[1. 0. 0.]] -> [0. 0. 0. 0.]
[[0. 0. 0.]] -> [1. 0. 0. 1.]
[[0. 0. 1.]] -> [0. 0. 0. 0.]
[[0. 0. 0.]] -> [0. 1. 1. 0.]
(10, 4)


In [11]:
def exp_convolve(tensor, decay):
    '''
    Filters a tensor with an exponential filter.
    :param tensor: a tensor of shape (trial, time, neuron)
    :param decay: a decay constant of the form exp(-dt/tau) with tau the time constant
    :return: the filtered tensor of shape (trial, time, neuron)
    '''
    r_shp = range(len(tensor.shape))
    transpose_perm = [1, 0] + list(r_shp)[2:]

    tensor_time_major = jax.lax.transpose(tensor, permutation=transpose_perm)
    initializer = jnp.zeros_like(tensor_time_major[0])
    _, filtered_tensor = jax.lax.scan(lambda a, x: (a * decay + (1 - decay) * x, a * decay + (1 - decay) * x), xs=tensor_time_major, init=initializer)
    filtered_tensor = jax.lax.transpose(filtered_tensor, permutation=transpose_perm)

    return filtered_tensor

In [12]:
w_out = jax.random.normal(key=key, shape=[n_rec, 1])
decay_out = jnp.exp(-1 / 20)
print(spikes.shape)
z_filtered = exp_convolve(spikes, decay_out)
print(z_filtered.shape)
y_out = jnp.einsum("tj,jk->tk", z_filtered, w_out) # no batch dim
y_target = jax.random.normal(key=key, shape=[T, 1])
loss = 0.5 * jnp.sum((y_out - y_target) ** 2)
print(loss)

(10, 4)
(10, 4)
4.169595


In [13]:
# REMARK : Z as a hidden state!

In [14]:
def eval_fct(inputs):
    state = None
    spikes = []
    V = []
    for t in range(T):
        outs, state = lsnn_hk.apply(params, inputs[:,t], state)
        print(inputs[:,t], "->", outs[0])
        spikes.append(outs[0])
        V.append(outs[1][0])
    # print(spikes)
    spikes = jnp.stack(spikes, axis=0)
    print(spikes.shape)
    w_out = jax.random.normal(key=key, shape=[n_rec, 1])
    decay_out = jnp.exp(-1 / 20)
    print(spikes.shape)
    z_filtered = exp_convolve(spikes, decay_out)
    print(z_filtered.shape)
    y_out = jnp.einsum("tj,jk->tk", z_filtered, w_out) # no batch dim
    y_target = jax.random.normal(key=key, shape=[T, 1])
    loss = 0.5 * jnp.sum((y_out - y_target) ** 2)
    print(loss)

    return loss

In [15]:
surrogate_grad = jax.value_and_grad(eval_fct)(inputs)

Traced<ConcreteArray([[0. 0. 1.]], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([[0., 0., 1.]], dtype=float32)
  tangent = Traced<ShapedArray(float32[1,3])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[1,3]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x0000018803C7A870>, in_tracers=(Traced<ShapedArray(float32[1,1,3]):JaxprTrace(level=1/0)>,), out_tracer_refs=[<weakref at 0x0000018805FE6390; to 'JaxprTracer' at 0x0000018805FE6350>], out_avals=[ShapedArray(float32[1,3])], primitive=squeeze, params={'dimensions': (1,)}, effects=frozenset(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x0000018804DB24F0>, name_stack=NameStack(stack=(Transform(name='jvp'),)))) -> [0. 0. 0. 0.]
Traced<ConcreteArray([[0. 0. 0.]], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([[0., 0., 0.]], dtype=float32)
  tangent = Traced<ShapedArray(float32[1,3])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(flo

In [16]:
print(surrogate_grad[0].shape, surrogate_grad[1].shape)
print(surrogate_grad[0], surrogate_grad[1])

() (1, 10, 3)
4.169595 [[[-4.1758074e-04  4.3031195e-04  4.4645189e-04]
  [-4.0897168e-05 -4.2305433e-04 -4.6824905e-04]
  [ 6.1749428e-04 -2.5006174e-03 -3.5016711e-03]
  [ 5.9845735e-04 -1.1083717e-03 -1.7980392e-03]
  [ 1.2781026e-03 -1.7627638e-03 -2.8037443e-03]
  [ 2.0294576e-03 -1.9827855e-03 -3.4295523e-03]
  [ 2.0014024e-03  3.7792340e-04  3.1928593e-04]
  [ 1.2110691e-03  6.1571115e-04  5.3813226e-05]
  [-6.9205730e-06  8.0696773e-06  8.3334262e-06]
  [ 0.0000000e+00  0.0000000e+00  0.0000000e+00]]]


In [17]:
learning_signals = jnp.einsum("tk,jk->tj", y_out - y_target, w_out)


In [18]:
def shift_by_one_time_step(tensor, initializer=None):
    '''
    Shift the input on the time dimension by one.
    :param tensor: a tensor of shape (trial, time, neuron)
    :param initializer: pre-prend this as the new first element on the time dimension
    :return: a shifted tensor of shape (trial, time, neuron)
    '''
    tensor = jnp.expand_dims(tensor, axis=0)
    r_shp = range(len(tensor.shape))
    transpose_perm = [1, 0] + list(r_shp)[2:]
    tensor_time_major = jax.lax.transpose(tensor, permutation=transpose_perm)

    if initializer is None:
        initializer = jnp.zeros_like(tensor_time_major[0])

    print(initializer.shape, tensor_time_major[:,:-1].shape)

    shifted_tensor = jnp.concat([initializer[None, :, :], tensor_time_major[:-1]], axis=0)
    # shifted_tensor = tensor_time_major
    # shifted_tensor = jnp.concatenate([initializer[:], tensor_time_major[:,:-1]], axis=1)

    shifted_tensor = jax.lax.transpose(shifted_tensor, permutation=transpose_perm)
    return shifted_tensor

In [19]:
print(spikes.shape)
pre_synpatic_spike_one_step_before = shift_by_one_time_step(spikes)[0]
print(pre_synpatic_spike_one_step_before.shape)

(10, 4)
(1, 4) (10, 0, 4)
(10, 4)


In [71]:
from collections import namedtuple

# https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/04-advanced-autodiff.ipynb


CustomALIFStateTuple = namedtuple('CustomALIFStateTuple', ('s', 'z', 'r'))

    
@jax.custom_gradient
def SpikeFunction(v_scaled, dampening_factor):
    z_ = jnp.greater(v_scaled, 0.)
    z_ = z_.astype(jnp.float32)

    def grad(dy):
        dE_dz = dy
        dz_dv_scaled = jnp.maximum(1 - jnp.abs(v_scaled), 0)
        dz_dv_scaled *= dampening_factor

        dE_dv_scaled = dE_dz * dz_dv_scaled

        return [dE_dv_scaled,
                jnp.zeros_like(dampening_factor)]

    return z_, grad

class RecurrentLIF(hk.RNNCore):
    """
    Leaky Integrate and Fire neuron model inspired by the implementation in
    snnTorch:

    https://snntorch.readthedocs.io/en/latest/snn.neurons_leaky.html
    
    """

    def __init__(self, 
                 n_in, n_rec, tau=20., thr=.615, dt=1., dtype=jnp.float32, dampening_factor=0.3,
                 tau_adaptation=200., beta=.16, tag='',
                 stop_gradients=False, w_in_init=None, w_rec_init=None, n_refractory=1, rec=True,
                 name="RecurrentLIF"):
        super(RecurrentLIF, self).__init__(name=name)

        self.n_refractory = n_refractory
        self.tau_adaptation = tau_adaptation
        self.beta = beta
        self.decay_b = jnp.exp(-dt / tau_adaptation)

        if jnp.isscalar(tau): tau = jnp.ones(n_rec, dtype=dtype) * jnp.mean(tau)
        if jnp.isscalar(thr): thr = jnp.ones(n_rec, dtype=dtype) * jnp.mean(thr)

        tau = jnp.array(tau, dtype=dtype)
        dt = jnp.array(dt, dtype=dtype)
        self.rec = rec

        self.dampening_factor = dampening_factor
        self.stop_gradients = stop_gradients
        self.dt = dt
        self.n_in = n_in
        self.n_rec = n_rec
        self.data_type = dtype

        self._num_units = self.n_rec

        self.tau = tau
        self._decay = jnp.exp(-dt / tau)
        self.thr = thr

        # init_w_in_var = w_in_init if w_in_init is not None else \
        #         (jax.random.uniform(key, shape=(n_in, n_rec)) / jnp.sqrt(n_in)).astype(dtype)
        init_w_in_var = w_in_init if w_in_init is not None else hk.initializers.TruncatedNormal(1./jnp.sqrt(n_in))
        self.w_in_var = hk.get_parameter("w_in" + tag, (n_in, n_rec), dtype, init_w_in_var)
        self.w_in_val = self.w_in_var

        if rec:
            # init_w_rec_var = w_rec_init if w_rec_init is not None else \
            # (jax.random.uniform(key, shape=(n_rec, n_rec)) / jnp.sqrt(n_rec)).astype(dtype)
            init_w_rec_var = w_rec_init if w_rec_init is not None else hk.initializers.TruncatedNormal(1./jnp.sqrt(n_rec))
            self.w_rec_var = hk.get_parameter("w_rec" + tag, (n_rec, n_rec), dtype, init_w_rec_var)

            self.recurrent_disconnect_mask = jnp.diag(jnp.ones(n_rec, dtype=bool))

            # Disconnect autotapse
            self.w_rec_val = jnp.where(self.recurrent_disconnect_mask, jnp.zeros_like(self.w_rec_var), self.w_rec_var)

            dw_val_dw_var_rec = jnp.ones((self._num_units,self._num_units)) - jnp.diag(jnp.ones(self._num_units))
        
        dw_val_dw_var_in = jnp.ones((n_in,self._num_units))

        self.dw_val_dw_var = [dw_val_dw_var_in, dw_val_dw_var_rec] if rec else [dw_val_dw_var_in,]

        self.variable_list = [self.w_in_var, self.w_rec_var] if rec else [self.w_in_var,]
        self.built = True

    def initial_state(self, batch_size, dtype=jnp.float32, n_rec=None):
        if n_rec is None: n_rec = self.n_rec

        s0 = jnp.zeros(shape=(batch_size, n_rec, 2), dtype=dtype)
        z0 = jnp.zeros(shape=(batch_size, n_rec), dtype=dtype)
        r0 = jnp.zeros(shape=(batch_size, n_rec), dtype=dtype)
        return CustomALIFStateTuple(s=s0, z=z0, r=r0)

    
    
    def compute_z(self, v, b):
        adaptive_thr = self.thr + b * self.beta
        v_scaled = (v - adaptive_thr) / self.thr
        z = SpikeFunction(v_scaled, self.dampening_factor)
        z = z * 1 / self.dt
        return z
        
    def __call__(self, inputs, state, scope=None, dtype=jnp.float32):
        decay = self._decay

        z = state.z
        s = state.s
        print("zs", z.shape, s.shape)
        # return

        if self.stop_gradients:
            z = jax.lax.stop_gradient(z)

        if len(self.w_in_val.shape) == 3:
            i_in = jnp.einsum('bi,bij->bj', inputs, self.w_in_val)
        else:
            # print(inputs.shape, self.w_in_val.shape)
            # i_in = jnp.matmul(inputs, self.w_in_val)
            i_in = jnp.einsum('abc,cd->ad', inputs, self.w_in_val)
            print(inputs.shape, self.w_in_val.shape)
            print("i_in", i_in.shape)

        if self.rec:
            if len(self.w_rec_val.shape) == 3:
                i_rec = jnp.einsum('bi,bij->bj', z, self.w_rec_val)
            else:
                print("z wrec", z.shape, self.w_rec_val.shape)
                i_rec = jnp.matmul(z, self.w_rec_val)

            i_t = i_in + i_rec
        else:
            i_t = i_in

        print("i_t", i_t.shape)


        def get_new_v_b(s, i_t):
            v, b = s[..., 0], s[..., 1]
            print("vs", v.shape, b.shape)
            old_z = self.compute_z(v, b)
            new_b = self.decay_b * b + old_z

            I_reset = z * self.thr * self.dt
            print('vii', v.shape, i_t.shape, I_reset.shape)
            new_v = decay * v + i_t  - I_reset

            return new_v, new_b
        
        new_v, new_b = get_new_v_b(s, i_t)
        print("nv nb", new_v.shape, new_b.shape)


        # Spike generation
        is_refractory = jnp.greater(state.r, .1)
        zeros_like_spikes = jnp.zeros_like(state.z)
        new_z = jnp.where(is_refractory, zeros_like_spikes, self.compute_z(new_v, new_b))
        new_r = jnp.clip(state.r + self.n_refractory * new_z - 1,
                                 0., float(self.n_refractory))
        new_s = jnp.stack((new_v, new_b), axis=-1)



        def safe_grad(fun, x):
            # g = jax.grad(fun)(x)
            g = jax.vjp(fun, x)[1]
            if g is None:
                g = jnp.zeros_like(x)
            return g

        # dnew_v_ds = tf.gradients(new_v, s, name='dnew_v_ds')[0]
        # dnew_b_ds = tf.gradients(new_b, s, name='dnew_b_ds')[0]
        # dnew_v_ds = jax.vjp(lambda x: get_new_v_b(x, i_t)[0], s)[1]
        # dnew_b_ds = jax.vjp(lambda x: get_new_v_b(x, i_t)[1], s)[1]

        # print("dnew_v_ds", dnew_v_ds, dnew_b_ds)

        # dnew_v_ds = jax.grad(lambda x: get_new_v_b(x, i_t)[0])(s)
        # dnew_b_ds = jax.grad(lambda x: get_new_v_b(x, i_t)[1])(s)


        # dnew_s_ds = jnp.stack((dnew_v_ds, dnew_b_ds), 2)

        # dnew_z_dnew_v = jnp.where(is_refractory, zeros_like_spikes, safe_grad(new_z, new_v))
        # dnew_z_dnew_b = jnp.where(is_refractory, zeros_like_spikes, safe_grad(new_z, new_b))
        # dnew_z_dnew_s = jnp.stack((dnew_z_dnew_v, dnew_z_dnew_b), axis=-1)

        # diagonal_jacobian = [dnew_s_ds, dnew_z_dnew_s]

        # "in_weights, rec_weights"
        # ds_dW_bias: 2 x n_rec
        # dnew_v_di = safe_grad(new_v,i_t)
        # dnew_b_di = safe_grad(new_b,i_t)
        # dnew_v_di = safe_grad(lambda x: get_new_v_b(s, x)[0], i_t)
        # dnew_b_di = safe_grad(lambda x: get_new_v_b(s, x)[1], i_t)
        # dnew_v_di = safe_grad(lambda x: get_new_v_b(s, x)[0], i_t)
        # dnew_b_di = safe_grad(lambda x: get_new_v_b(s, x)[1], i_t)

        
        # dnew_s_di = jnp.stack([dnew_v_di,dnew_b_di], axis=-1)

        # partials_wrt_biases = [dnew_s_di, dnew_s_di]

        new_state = CustomALIFStateTuple(s=new_s, z=new_z, r=new_r)
        # return [new_z, new_s, diagonal_jacobian, partials_wrt_biases], new_state
        return new_z, new_state

In [72]:
def lsnn2(x, state=None):
    core = hk.DeepRNN([
        # hk.Linear(n_rec), # otherwise not need Win
        RecurrentLIF(
            n_in,
            n_rec,
            tau=tau_v,
            thr=thr,
            dt=dt,
            dtype=jnp.float32,
            dampening_factor=dampening_factor,
            tau_adaptation=tau_a,
            beta=beta,
            tag='',
            stop_gradients=False,
            w_in_init=None,
            w_rec_init=None,
            n_refractory=n_ref,
            rec=True,
        )
    ])
    # spikes, V = hk.dynamic_unroll(core, x, core.initial_state(x.shape[0]), time_major=False, unroll=T)
    # return spikes, V
    # print(x.shape[1])
    if state is None:
        state = core.initial_state(5)
    spikes, hiddens = core(x, state)#core.initial_state(1))
    return spikes, hiddens

lsnn2_hk = hk.without_apply_rng(hk.transform(lsnn2))

In [73]:
# print(inputs[:,0].shape)
# i0 = jnp.expand_dims(inputs[:,0], axis=0)
i0 = jnp.stack([inputs[:,0], inputs[:,0], inputs[:,0],inputs[:,0], inputs[:,0]], axis=0)
print(i0.shape)
params = lsnn2_hk.init(rng=key, x=i0)
print(params)

(5, 1, 3)
zs (5, 4) (5, 4, 2)
(5, 1, 3) (3, 4)
i_in (5, 4)
z wrec (5, 4) (4, 4)
i_t (5, 4)
vs (5, 4) (5, 4)
vii (5, 4) (5, 4) (5, 4)
nv nb (5, 4) (5, 4)
{'RecurrentLIF': {'w_in': Array([[ 0.7967948 , -0.3821632 , -0.7605332 ,  0.45293623],
       [-0.03456055,  0.65856   ,  0.58331513, -0.10983399],
       [-0.4869853 ,  1.0580422 ,  0.53946483, -0.00187313]],      dtype=float32), 'w_rec': Array([[ 0.2713923 ,  0.86784893,  0.07354291, -0.43423995],
       [ 0.9018898 , -0.1874919 ,  0.18929863,  0.15613334],
       [ 0.2636667 ,  0.32536188,  0.5184128 ,  0.30088994],
       [ 0.22115956, -0.26662493,  0.49525732,  0.11215727]],      dtype=float32)}}


zs (5, 4) (5, 4, 2)
(5, 1, 3) (3, 4)
i_in (5, 4)
z wrec (5, 4) (4, 4)
i_t (5, 4)
vs (5, 4) (5, 4)
vii (5, 4) (5, 4) (5, 4)
nv nb (5, 4) (5, 4)
[[0. 0. 1.]] -> [0. 1. 0. 0.]
zs (5, 4) (5, 4, 2)
(5, 1, 3) (3, 4)
i_in (5, 4)
z wrec (5, 4) (4, 4)
i_t (5, 4)
vs (5, 4) (5, 4)
vii (5, 4) (5, 4) (5, 4)
nv nb (5, 4) (5, 4)
[[0. 0. 0.]] -> [0. 0. 1. 0.]
zs (5, 4) (5, 4, 2)
(5, 1, 3) (3, 4)
i_in (5, 4)
z wrec (5, 4) (4, 4)
i_t (5, 4)
vs (5, 4) (5, 4)
vii (5, 4) (5, 4) (5, 4)
nv nb (5, 4) (5, 4)
[[0. 0. 0.]] -> [1. 0. 0. 0.]
zs (5, 4) (5, 4, 2)
(5, 1, 3) (3, 4)
i_in (5, 4)
z wrec (5, 4) (4, 4)
i_t (5, 4)
vs (5, 4) (5, 4)
vii (5, 4) (5, 4) (5, 4)
nv nb (5, 4) (5, 4)
[[0. 1. 1.]] -> [0. 1. 0. 0.]
zs (5, 4) (5, 4, 2)
(5, 1, 3) (3, 4)
i_in (5, 4)
z wrec (5, 4) (4, 4)
i_t (5, 4)
vs (5, 4) (5, 4)
vii (5, 4) (5, 4) (5, 4)
nv nb (5, 4) (5, 4)
[[0. 0. 0.]] -> [0. 0. 1. 0.]
zs (5, 4) (5, 4, 2)
(5, 1, 3) (3, 4)
i_in (5, 4)
z wrec (5, 4) (4, 4)
i_t (5, 4)
vs (5, 4) (5, 4)
vii (5, 4) (5, 4) (5, 4)
nv nb (5, 4)

In [79]:
def eval2(inputs):
    state = None
    spikes = []
    V = []
    for t in range(T):
        it = jnp.stack([inputs[:,t], inputs[:,t], inputs[:,t],inputs[:,t], inputs[:,t]], axis=0)
        outs, state = lsnn2_hk.apply(params, it, state)
        print(inputs[:,t], "->", outs[0])
        spikes.append(outs[0])
        V.append(outs[1][0])
    # print(spikes)
    spikes = jnp.stack(spikes, axis=0)
    print(spikes.shape)
    w_out = jax.random.normal(key=key, shape=[n_rec, 1])
    decay_out = jnp.exp(-1 / 20)
    print(spikes.shape)
    z_filtered = exp_convolve(spikes, decay_out)
    print(z_filtered.shape)
    y_out = jnp.einsum("tj,jk->tk", z_filtered, w_out) # no batch dim
    y_target = jax.random.normal(key=key, shape=[T, 1])
    loss = 0.5 * jnp.sum((y_out - y_target) ** 2)
    print(loss)
    return loss

In [80]:
surrogate_grad = jax.value_and_grad(eval2)(inputs)

zs (5, 4) (5, 4, 2)
(5, 1, 3) (3, 4)
i_in (5, 4)
z wrec (5, 4) (4, 4)
i_t (5, 4)
vs (5, 4) (5, 4)
vii (5, 4) (5, 4) (5, 4)
nv nb (5, 4) (5, 4)
Traced<ConcreteArray([[0. 0. 1.]], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([[0., 0., 1.]], dtype=float32)
  tangent = Traced<ShapedArray(float32[1,3])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[1,3]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x0000018811560CD0>, in_tracers=(Traced<ShapedArray(float32[1,1,3]):JaxprTrace(level=1/0)>,), out_tracer_refs=[<weakref at 0x00000188113E0F40; to 'JaxprTracer' at 0x00000188113E18B0>], out_avals=[ShapedArray(float32[1,3])], primitive=squeeze, params={'dimensions': (1,)}, effects=frozenset(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x00000188115984F0>, name_stack=NameStack(stack=(Transform(name='jvp'),)))) -> Traced<ConcreteArray([0. 1. 0. 0.], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0., 1

TypeError: Custom VJP rule must produce an output with the same container (pytree) structure as the args tuple of the primal function, and in particular must produce a tuple of length equal to the number of arguments to the primal function, but got VJP output structure PyTreeDef([*, *]) for primal input structure PyTreeDef((*, *)).