In [1]:
import os
os.environ["KERAS_BACKEND"] = "jax"
import keras_snn
import numpy as np
import keras

In [2]:
# Seeded synthetic batch: 128 sequences x 32 steps x 34x34 single-channel frames.
# float32 matches Keras' default floatx (np.random.randn would give float64).
SEED = 0
rng = np.random.default_rng(SEED)
data = rng.standard_normal((128, 32, 34, 34, 1), dtype=np.float32)

In [3]:
class SNN(keras.Layer):
    """Apply a layer to every timestep, then run the result through a recurrent cell.

    The wrapped ``layer`` is applied per timestep via
    ``keras.layers.TimeDistributed``; its outputs are fed to ``neuron`` wrapped
    in ``keras.layers.RNN``.

    Args:
        layer: Keras layer applied independently to each timestep.
        neuron: RNN cell passed to ``keras.layers.RNN``.
        return_sequences: Forwarded to ``keras.layers.RNN``.
        return_state: Forwarded to ``keras.layers.RNN``.
        unroll: Forwarded to ``keras.layers.RNN``.
        **kwargs: Standard ``keras.Layer`` arguments (``name``, ``dtype``,
            ``trainable``, ...), forwarded to the base class.
    """

    def __init__(
        self,
        layer,
        neuron,
        return_sequences=False,
        return_state=False,
        unroll=False,
        **kwargs,
    ):
        # Forward kwargs so callers can set name/dtype/trainable like any Keras layer.
        super().__init__(**kwargs)
        self.layer = keras.layers.TimeDistributed(layer)
        self.neuron = keras.layers.RNN(
            neuron,
            return_sequences=return_sequences,
            return_state=return_state,
            unroll=unroll,
        )

    def build(self, input_shape):
        # Build the sublayers explicitly so the RNN sees the per-timestep
        # layer's output shape rather than the raw input shape.
        self.layer.build(input_shape)
        self.neuron.build(self.layer.compute_output_shape(input_shape))
        self.built = True

    def call(self, inputs):
        """Run ``inputs`` through the TimeDistributed layer, then the RNN."""
        x = self.layer(inputs)
        x = self.neuron(x)
        return x

In [4]:
class LICell(keras.Layer):
    """Recurrent cell for ``keras.layers.RNN`` that accumulates its inputs.

    At each step the flat state is reshaped to the per-step input's shape and
    added to the input. The sum is emitted as the step output and stored back,
    in the state's original flat shape, as the new state. There is no decay
    term, so the state is a running sum of the inputs.

    Args:
        units: Size of the flat state vector, exposed as ``state_size`` for
            ``keras.layers.RNN``. It must equal the number of elements in one
            per-step input sample, because the state is reshaped to the input's
            shape (e.g. ``17 * 17 * 32`` for a ``(17, 17, 32)`` feature map).
            Defaults to -1, the value the original no-argument constructor
            hard-coded. NOTE(review): -1 presumably cannot produce a valid
            initial state inside ``keras.layers.RNN``; confirm and pass an
            explicit size.
        **kwargs: Standard ``keras.Layer`` arguments (``name``, ``dtype``, ...).
    """

    def __init__(self, units=-1, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.state_size = units

    def call(self, x, v):
        """Add the previous state to ``x``; return ``(output, [new_state])``."""
        # `v` is the list of states supplied by keras.layers.RNN; this cell has one.
        state = v[0]
        v_shape = state.shape
        # keras.ops.reshape is backend-agnostic, unlike the tensor `.reshape`
        # method, which only exists on some backends' arrays.
        acc = x + keras.ops.reshape(state, x.shape)
        return acc, [keras.ops.reshape(acc, v_shape)]

    def get_config(self):
        """Include ``units`` so the cell can be re-created from its config."""
        return {**super().get_config(), "units": self.units}

In [16]:
# Sizes for the SNN demo: N_STEPS frames of FRAME_SIZE x FRAME_SIZE x 1 per sequence.
N_STEPS = 32
FRAME_SIZE = 34
CONV_FILTERS = 32
CONV_STRIDE = 2
NUM_CLASSES = 10
# With padding="same", a strided conv yields ceil(FRAME_SIZE / CONV_STRIDE)
# positions per spatial axis (34 -> 17). The flat LICell state must therefore
# hold conv_side * conv_side * CONV_FILTERS values (17 * 17 * 32 = 9248).
conv_side = -(-FRAME_SIZE // CONV_STRIDE)
conv_state_size = conv_side * conv_side * CONV_FILTERS

model = keras.Sequential()
model.add(
    keras.layers.TimeDistributed(
        keras.layers.Conv2D(
            CONV_FILTERS,
            (3, 3),
            strides=CONV_STRIDE,
            padding="same",
        )
    )
)
# Accumulate the conv feature maps over time; keep every step for the next stage.
model.add(
    keras.layers.RNN(
        LICell(conv_state_size),
        return_sequences=True,
        unroll=True,
    )
)
model.add(keras.layers.TimeDistributed(keras.layers.Flatten()))
model.add(keras.layers.TimeDistributed(keras.layers.Dense(NUM_CLASSES)))
# Final accumulator over the per-step class scores; returns only the last step.
model.add(keras.layers.RNN(LICell(NUM_CLASSES), unroll=True))


# Seeded input so the displayed forward pass is reproducible on re-run.
model(keras.random.normal((1, N_STEPS, FRAME_SIZE, FRAME_SIZE, 1), seed=0))

Array([[ 13.466108 ,   6.1763525, -21.269682 ,   4.532126 ,  11.807274 ,
         28.237759 ,  10.52721  , -85.31996  , -44.83778  , -17.812338 ]],      dtype=float32)

In [6]:
def model():
    """Trace symbolic shapes through a TimeDistributed conv -> flatten -> dense stack.

    Prints the shape of the symbolic input and then the shape after each stage.
    Returns the final symbolic tensor, which is a KerasTensor and not a
    ``keras.Model``.

    NOTE(review): this rebinds the notebook-level name ``model``, shadowing the
    ``keras.Sequential`` built in the earlier cell.
    """
    tensor = keras.layers.Input(shape=(None, 34, 34, 1))
    stages = [
        keras.layers.TimeDistributed(
            keras.layers.Conv2D(32, kernel_size=3, strides=2, padding="SAME")
        ),
        keras.layers.TimeDistributed(keras.layers.Flatten()),
        keras.layers.TimeDistributed(keras.layers.Dense(10)),
    ]
    print(tensor.shape)
    for stage in stages:
        tensor = stage(tensor)
        print(tensor.shape)
    return tensor


model()

(None, None, 34, 34, 1)
(None, None, 17, 17, 32)
(None, None, 9248)
(None, None, 10)


<KerasTensor shape=(None, None, 10), dtype=float32, sparse=False, name=keras_tensor_9>