In [10]:
import jax
from jax import value_and_grad,jit,vmap
import jax.numpy as jnp
import opax

In [9]:
# Root PRNG key for the notebook; fixed seed so results are reproducible.
Prng = jax.random.PRNGKey(seed=42)

In [124]:
from abc import ABC, abstractmethod
import numpy as np

class NN(ABC):
    """Abstract base class for the neural-network models in this notebook.

    Concrete subclasses must provide an implementation of ``forward``.
    """

    def __init__(self) -> None:
        super(NN, self).__init__()

    @abstractmethod
    def forward(self, X):
        """Run forward propagation on ``X``; must be overridden by subclasses."""
        ...
    



In [None]:
class LSTM(NN):
    """Single-layer LSTM cell built from plain JAX arrays.

    Weight matrices are stored in (in_features, hidden_size) layout, so each
    gate pre-activation is computed as ``x @ W_i* + h @ W_h* + b_*``.
    ``forward`` is stateful (it keeps ``h_t`` / ``c_t`` on the instance);
    the static ``step`` and ``loss`` methods take the parameters explicitly
    and are the ones to use under ``jax.jit`` / ``jax.vmap`` / ``jax.grad``.
    """

    def __init__(self, input_size: int, hidden_size: int, rng):
        """
        Constructor for the LSTM class.
        -------------------------------
        Parameters:

        input_size (int): The number of input features (i.e., the dimensionality of the input vector).
        hidden_size (int): The number of units in the hidden state (i.e., the dimensionality of the LSTM's hidden state and cell state).
        rng (jax.random.PRNGKey): The random number generator key used to initialize the weights of the LSTM gates.

        Initializes the LSTM with the following parameters:
        - Weight matrices for the input-to-hidden (input_size, hidden_size) and
          hidden-to-hidden (hidden_size, hidden_size) connections of the four
          gates (input, forget, candidate, output), each drawn from a standard
          normal distribution with its own subkey split from ``rng``.
        - Bias vectors for each gate, initialized to zero.
        - The hidden state and cell state, initialized to zero.
        """
        self.input_size = input_size
        self.hidden_size = hidden_size

        # One independent subkey per weight matrix. Reusing `rng` for every
        # draw would give all four gates identical initial weights.
        (k_ii, k_hi, k_if, k_hf, k_ig, k_hg, k_io, k_ho) = jax.random.split(rng, 8)

        # Input gate weights and biases
        self.Wii = jax.random.normal(k_ii, (input_size, hidden_size))  # Input-to-hidden
        self.Whi = jax.random.normal(k_hi, (hidden_size, hidden_size))  # Hidden-to-hidden
        self.bi = jnp.zeros((hidden_size,))  # Bias for input gate

        # Forget gate weights and biases
        self.Wif = jax.random.normal(k_if, (input_size, hidden_size))
        self.Whf = jax.random.normal(k_hf, (hidden_size, hidden_size))
        self.bf = jnp.zeros((hidden_size,))

        # Candidate (cell update) gate weights and biases
        self.Wig = jax.random.normal(k_ig, (input_size, hidden_size))
        self.Whg = jax.random.normal(k_hg, (hidden_size, hidden_size))
        self.bg = jnp.zeros((hidden_size,))

        # Output gate weights and biases
        self.Wio = jax.random.normal(k_io, (input_size, hidden_size))
        self.Who = jax.random.normal(k_ho, (hidden_size, hidden_size))
        self.bo = jnp.zeros((hidden_size,))

        # Recurrent state carried between calls to `forward`.
        self.h_t = jnp.zeros((hidden_size,))
        self.c_t = jnp.zeros((hidden_size,))

    def forward(self, X):
        """
        Forward pass of the LSTM for a single time step.

        Parameters:
        X: input at the current time step, shape (input_size,).

        Returns:
        (c_t, h_t): the new cell state and hidden state.

        Side effects: stores the new states in ``self.c_t`` / ``self.h_t``.
        Because of that mutation, do not pass this method directly to
        ``jax.vmap`` / ``jax.jit`` (the traced values would be stored on the
        instance); use the pure ``LSTM.step`` instead.
        """
        c_t, h_t = self.step(self.params(), self.h_t, self.c_t, X)
        self.c_t = c_t
        self.h_t = h_t
        return c_t, h_t

    def params(self):
        """
        Return the gate weights and biases as a dict (a JAX pytree).

        The ``'c_t'`` entry is the current cell state; it is kept for
        backward compatibility, and ``step`` / ``loss`` do not read it.
        """
        return {
            'Wii': self.Wii,
            'Whi': self.Whi,
            'bi': self.bi,
            'Wif': self.Wif,
            'Whf': self.Whf,
            'bf': self.bf,
            'Wig': self.Wig,
            'Whg': self.Whg,
            'bg': self.bg,
            'Wio': self.Wio,
            'Who': self.Who,
            'bo': self.bo,
            'c_t': self.c_t
        }

    @staticmethod
    def step(params, h, c, x):
        """
        Pure single-step LSTM update (no side effects; safe under jit/vmap/grad).

        Parameters:
        params (dict): weights and biases as returned by ``params()``.
        h: previous hidden state, shape (hidden_size,).
        c: previous cell state, shape (hidden_size,).
        x: input at this time step, shape (input_size,).

        Returns:
        (c_new, h_new): the updated cell state and hidden state.
        """
        i = jax.nn.sigmoid(x @ params['Wii'] + h @ params['Whi'] + params['bi'])  # input gate
        f = jax.nn.sigmoid(x @ params['Wif'] + h @ params['Whf'] + params['bf'])  # forget gate
        # Candidate values use tanh so the cell state can move in both
        # directions; a sigmoid here could only ever add non-negative amounts.
        g = jnp.tanh(x @ params['Wig'] + h @ params['Whg'] + params['bg'])
        o = jax.nn.sigmoid(x @ params['Wio'] + h @ params['Who'] + params['bo'])  # output gate

        c_new = f * c + i * g
        h_new = o * jnp.tanh(c_new)
        return c_new, h_new

    @staticmethod
    def loss(params, x, y):
        """
        Mean squared error between the LSTM's hidden states and the targets.

        The recurrence runs over the whole sequence ``x`` with
        ``jax.lax.scan``, starting from zero hidden and cell states, so this
        function is pure and can be differentiated with ``jax.grad``.

        Parameters:
        params (dict): weights and biases as returned by ``params()``.
        x: input sequence, shape (T, input_size).
        y: targets, shape (T, hidden_size) or broadcast-compatible with it.

        Returns:
        Scalar mean squared error.
        """
        hidden_size = params['bi'].shape[0]
        init_state = (jnp.zeros((hidden_size,)), jnp.zeros((hidden_size,)))

        def scan_step(state, x_t):
            h, c = state
            c, h = LSTM.step(params, h, c, x_t)
            # Carry the new state forward and emit h as this step's output.
            return (h, c), h

        _, hs = jax.lax.scan(scan_step, init_state, x)
        return jnp.mean((hs - y) ** 2)
        


In [132]:
# One input feature, one hidden unit, initialized from the notebook's root key.
model = LSTM(input_size=1, hidden_size=1, rng=Prng)

In [128]:
# 1000 evenly spaced points on [-10, 10] as a column (one feature per row),
# with the sigmoid of each point as the regression target.
x = jnp.linspace(-10, 10, num=1000)[:, None]
y = jax.nn.sigmoid(x)

In [129]:
# `model.forward` stores the new states on `model.h_t` / `model.c_t`. Passing it
# straight to `vmap` would leave batched tracers stored on the model, so any later
# use of the model fails with an UnexpectedTracerError. Save and restore the
# state around the call: every row of the input is evaluated independently from
# the model's current state (no recurrence across rows), and the model is left
# unchanged afterwards.
def _forward_from_current_state(X_t, _model=model):
    h_saved, c_saved = _model.h_t, _model.c_t
    try:
        return _model.forward(X_t)
    finally:
        _model.h_t, _model.c_t = h_saved, c_saved

fwd = vmap(_forward_from_current_state)

In [130]:
# fwd returns (cell_states, hidden_states); `t` therefore holds the hidden states.
(c, t) = fwd(x)