In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
class RNNCell(nn.Module):
    """Single-step Elman RNN cell with a linear read-out layer.

    One call advances the recurrence by one time step::

        H_t = tanh(W_xh @ X_t + b_xh + W_hh @ H_(t-1))
        Y_t = W_hy @ H_t + b_hy
    """

    def __init__(self, input_size, hidden_size, output_size):
        """Create the three linear maps of the cell.

        Args:
            input_size: number of features in each input vector X_t.
            hidden_size: number of features in the hidden state H_t.
            output_size: number of features in the output Y_t.
        """
        super().__init__()
        # nn.Linear computes x @ W.T + b.
        self.input_to_hidden = nn.Linear(input_size, hidden_size, bias=True)
        # No bias on the recurrent map: b_xh above already provides the
        # additive bias of the recurrence, and a second one would simply
        # be summed into it.
        self.hidden_to_hidden = nn.Linear(hidden_size, hidden_size, bias=False)
        self.hidden_to_output = nn.Linear(hidden_size, output_size, bias=True)

    def forward(self, current_input, previous_hidden_state=None):
        """Advance the cell by one time step.

        Args:
            current_input: tensor of shape [batch, input_size] (a 1-D
                [input_size] tensor also works, because nn.Linear accepts it).
            previous_hidden_state: tensor of shape [batch, hidden_size].
                If None, a zero state is used, created with
                current_input's dtype and device. This mirrors the default
                of torch.nn.RNNCell.

        Returns:
            Tuple (next_hidden_state, output):
                next_hidden_state: [batch, hidden_size]
                output: [batch, output_size]
        """
        if previous_hidden_state is None:
            hidden_size = self.hidden_to_hidden.in_features
            previous_hidden_state = current_input.new_zeros(
                (*current_input.shape[:-1], hidden_size)
            )
        # H_t = tanh(W_xh X_t + b_xh + W_hh H_(t-1))
        next_hidden_state = torch.tanh(
            self.input_to_hidden(current_input)
            + self.hidden_to_hidden(previous_hidden_state)
        )
        # Y_t = W_hy H_t + b_hy
        output = self.hidden_to_output(next_hidden_state)
        return next_hidden_state, output