In [2]:
import equinox as eqx
import jax
import jax.numpy as jnp
from dataclasses import dataclass



In [3]:
@dataclass 
class RNNConfig: 
  input_dim: int = 2
  output_dim: int = 1
  hidden_dim: int = 40 

In [25]:
key = jax.random.PRNGKey(446)

class RNN(eqx.Module): 
    config: RNNConfig
    hidden_dim: int
    Wh: eqx.nn.Linear
    Wx: eqx.nn.Linear
    Wy: eqx.nn.Linear
    act: callable
    device: str = 'cpu'
    
    # TODO: check bias 
    def __init__(self, config, device='cpu'):
        self.config = config
        self.hidden_dim = config.hidden_dim
        self.Wh = eqx.nn.Linear(config.hidden_dim, config.hidden_dim, key=key)
        self.Wx = eqx.nn.Linear(config.input_dim, config.hidden_dim, key=key)
        self.Wy = eqx.nn.Linear(config.hidden_dim, config.output_dim, key=key)
        self.act = jax.nn.sigmoid
        self.device = device

    def __call__(self, x): 
        # batch_size, seq_len, _ = x.shape
        # batch_size, seq_len, _  = 1, x.shape
        seq_len, _ = x.shape

        # hidden = torch.nn.Parameter(torch.zeros(batch_size, self.hidden_dim))
        # hidden = jnp.zeros((batch_size, self.hidden_dim))
        hidden = jnp.zeros((self.hidden_dim,))
        outs = []

        for i in range(seq_len):
            # print(self.Wh(hidden).shape)
            # print(hidden.shape)
            hidden = self.act(self.Wh(hidden) + self.Wx(x[i,:]))
            out = self.Wy(hidden)
            outs.append(out)
        
        # output shape = (batch_size, sequence length, output dimension)
        # return torch.stack(outs).permute(1,0,2)
        stacked = jnp.stack(outs, axis=1)
        return jnp.transpose(stacked, (1,0,2))


In [26]:
config = RNNConfig(input_dim=2, hidden_dim=2, output_dim=1)
model = RNN(config)
model

RNN(
  config=RNNConfig(input_dim=2, output_dim=1, hidden_dim=2),
  hidden_dim=2,
  Wh=Linear(
    weight=f32[2,2],
    bias=f32[2],
    in_features=2,
    out_features=2,
    use_bias=True
  ),
  Wx=Linear(
    weight=f32[2,2],
    bias=f32[2],
    in_features=2,
    out_features=2,
    use_bias=True
  ),
  Wy=Linear(
    weight=f32[1,2],
    bias=f32[1],
    in_features=2,
    out_features=1,
    use_bias=True
  ),
  act=<wrapped function sigmoid>,
  device='cpu'
)

In [27]:
batch_size = 128
seq_len = 10
input_dim = 2
# x = jax.random.normal(key, (batch_size, seq_len, input_dim))
x = jax.random.normal(key, (seq_len, input_dim))
model(x)

IndexError: Too many indices for array: array has ndim of 2, but was indexed with 3 non-None/Ellipsis indices.