In [29]:
import jax.numpy as jnp
from jax import random,jit
from jax import nn
from jaxtyping import Array,Float,PyTree
import matplotlib.pyplot as plt
from jax import vmap

In [30]:
KEY = random.PRNGKey(42)

In [60]:
class RNN():
    def __init__(self,key,inp,hidden,out):
        self.inp=inp
        self.hidden=hidden
        self.out=out
        self.W_hh = random.normal(KEY,(hidden,hidden))
        self.W_hi = random.normal(KEY,(hidden,inp))
        self.b_h = random.normal(KEY,(hidden,1))
        
        self.W_oh = random.normal(KEY,(out,hidden))
        self.b_o = random.normal(KEY,(out,1))
        
        self.ycaps=None
        self.ht=None
        
    def printsize(self):
        print(self.W_hh.shape)
        print(self.W_hi.shape)
        print(self.b_h.shape)
        print(self.W_oh.shape)
        print(self.b_o.shape)
    
    def forward(self, x, h):
        ht = nn.tanh(jnp.matmul(self.W_hi, x) + jnp.matmul(self.W_hh, h) + self.b_h) 
        out = jnp.matmul(self.W_oh, ht) + self.b_o
        return ht, out

    jit_forward = jit(forward,static_argnums=[0])
        
    def forwardpass(self, x, h0=None, key=random.PRNGKey(0)):
            N = x.shape[0]

            if h0 is None:
                h0 = random.normal(key, (self.hidden, 1))
            
            self.ycaps = jnp.zeros((N, self.out, 1))  # Initialize output container
            self.ht = jnp.zeros((N + 1, self.hidden, 1))  # Initialize hidden state container

            self.ht = self.ht.at[0].set(h0)  # Set the initial hidden state
            
            for i in range(N):
                # Ensure x[i] is reshaped correctly as a column vector
                h0, out = self.jit_forward(x[i].reshape(self.inp, 1), h0)
                self.ht = self.ht.at[i + 1].set(h0)
                self.ycaps = self.ycaps.at[i].set(out.reshape(self.out, 1))  # Ensure output is reshaped correctly

            return self.ycaps, self.ht  # Return predicted outputs and hidden states
            
        

In [61]:
model = RNN(KEY,1,10,1)

In [62]:
h0 = random.normal(KEY, (10, 1))

In [63]:
x = random.normal(KEY,(10,1))

In [69]:
model.forwardpass(x)[1]

Array([[[-0.3721109 ],
        [ 0.26423115],
        [-0.18252768],
        [-0.7368197 ],
        [-0.44030377],
        [-0.1521442 ],
        [-0.67135346],
        [-0.5908641 ],
        [ 0.73168886],
        [ 0.5673026 ]],

       [[-0.39171445],
        [-0.9976392 ],
        [-0.988538  ],
        [ 0.9884642 ],
        [ 0.9555557 ],
        [-0.97827613],
        [ 0.8731274 ],
        [-0.8141124 ],
        [-0.71396935],
        [ 0.9958272 ]],

       [[-0.9999988 ],
        [-0.27952844],
        [ 0.9998438 ],
        [-0.7050039 ],
        [-0.9865794 ],
        [-0.83568287],
        [ 0.83201253],
        [-0.88802314],
        [ 0.9640635 ],
        [-0.9640202 ]],

       [[ 0.88066024],
        [-0.14389622],
        [ 0.9615624 ],
        [ 0.99947494],
        [ 0.99195856],
        [ 0.94095325],
        [ 0.9606708 ],
        [-0.03690604],
        [-0.9989105 ],
        [ 0.99216336]],

       [[-0.9470001 ],
        [ 0.9934495 ],
        [ 0.34400997],
   