## SEQUENTIAL ARCHITECTURES (RNN, LSTM, GRU)    

## Vanilla RNN

In [3]:
import torch
import torch.nn as nn

class RNN_block(nn.Module):
    """One step of a vanilla (Elman) recurrent cell.

    Computes ``h_t = tanh(W_hh h_{t-1} + W_hx x_t)``. Both projections are
    ``nn.Linear`` layers created with their default settings, so each one
    adds its own bias to the pre-activation.
    """

    def __init__(self, input_size: int, hidden_size: int) -> None:
        """Build the two projections that feed the hidden state.

        Args:
            input_size: Size of the last dimension of the input ``x``.
            hidden_size: Size of the last dimension of the hidden state ``h``.
        """
        super().__init__()
        # Recurrent path: previous hidden state -> hidden space.
        self.whh = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        # Input path: current input -> hidden space.
        self.whx = nn.Linear(in_features=input_size, out_features=hidden_size)
        self.tanh = nn.Tanh()

    def forward(self, x: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
        """Advance the hidden state by one time step.

        Args:
            x: Current input; its last dimension must be ``input_size``.
            h: Previous hidden state; its last dimension must be
                ``hidden_size``.

        Returns:
            The new hidden state, ``tanh`` of the summed projections.
        """
        recurrent_term = self.whh(h)
        input_term = self.whx(x)
        pre_activation = input_term + recurrent_term
        return self.tanh(pre_activation)
    
class simpleRNN(nn.Module):
    """Many-to-one vanilla RNN over token-id sequences.

    Token ids are embedded, the recurrent cell is unrolled over the sequence
    dimension, and the final hidden state is projected by a linear head.
    """

    def __init__(self, input_size, hidden_size, output_size, vocab_size):
        """Create the embedding table, recurrent cell and output head.

        Args:
            input_size: Embedding dimension (size of each cell input).
            hidden_size: Size of the recurrent hidden state.
            output_size: Number of features produced by the output head.
            vocab_size: Number of rows in the embedding table.
        """
        super().__init__()
        self.output = nn.Linear(hidden_size, output_size)
        # Not used by forward(); kept so the module's attributes are unchanged.
        self.tanh = nn.Tanh()
        self.embedding = nn.Embedding(vocab_size, input_size)
        self.rnn_block = RNN_block(input_size, hidden_size)
        self.hidden_size = hidden_size

    def forward(self, x):
        """Encode a batch of sequences and project the last hidden state.

        Args:
            x: Integer token ids indexed as (batch, seq_len).

        Returns:
            Tensor of shape (batch, output_size). For an empty sequence the
            head is applied to the all-zero initial state.
        """
        # (batch, seq_len, input_size)
        embedding = self.embedding(x)
        batch_size = embedding.shape[0]
        # new_zeros inherits dtype and device from the embeddings, so the model
        # keeps working after .double()/.half()/.to(device). The initial state
        # is a constant, so it must not be a grad-requiring leaf.
        h = embedding.new_zeros(batch_size, self.hidden_size)
        # Walk the time axis; each x_t is (batch, input_size).
        for x_t in embedding.unbind(dim=1):
            h = self.rnn_block(x_t, h)
        return self.output(h)

## RNN bidirectional

In [4]:
  
class bidirectionalRNN(nn.Module):
    """Many-to-one bidirectional vanilla RNN over token-id sequences.

    Two independent recurrent cells read the embedded sequence, one from
    first to last step and one from last to first. Their final hidden states
    are concatenated and projected by a linear head.
    """

    def __init__(self, input_size, hidden_size, output_size, vocab_size):
        """Create the embedding table, both recurrent cells and the head.

        Args:
            input_size: Embedding dimension (size of each cell input).
            hidden_size: Size of each direction's hidden state.
            output_size: Number of features produced by the output head.
            vocab_size: Number of rows in the embedding table.
        """
        super().__init__()
        # The head consumes both directions' states side by side.
        self.output = nn.Linear(2 * hidden_size, output_size)
        # Not used by forward(); kept so the module's attributes are unchanged.
        self.tanh = nn.Tanh()
        self.embedding = nn.Embedding(vocab_size, input_size)
        self.rnn_block_forward = RNN_block(input_size, hidden_size)
        self.rnn_block_backward = RNN_block(input_size, hidden_size)
        self.hidden_size = hidden_size

    def _scan(self, cell, steps, h):
        """Fold ``cell`` over ``steps`` starting from ``h``; return the last state."""
        for x_t in steps:
            h = cell(x_t, h)
        return h

    def forward(self, x):
        """Encode a batch of sequences in both directions and project.

        Args:
            x: Integer token ids indexed as (batch, seq_len).

        Returns:
            Tensor of shape (batch, output_size).
        """
        # (batch, seq_len, input_size)
        embedding = self.embedding(x)
        batch_size = embedding.shape[0]
        # Tuple of per-step slices, each (batch, input_size).
        steps = embedding.unbind(dim=1)

        # new_zeros inherits dtype and device from the embeddings, so the model
        # keeps working after .double()/.half()/.to(device). The initial state
        # is a constant, so it must not be a grad-requiring leaf.
        h0 = embedding.new_zeros(batch_size, self.hidden_size)

        # Forward direction: first step -> last step.
        h_forward = self._scan(self.rnn_block_forward, steps, h0)
        # Backward direction: last step -> first step.
        h_backward = self._scan(self.rnn_block_backward, reversed(steps), h0)

        # Concatenate the two final states along the feature dimension.
        h = torch.cat((h_forward, h_backward), dim=1)
        return self.output(h)

## LSTM 