# GRUs
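- GRUs (Gated Recurrent Units) came after LSTMs and before Transformers.
- Like LSTMs, they are a gated variant of the basic RNN.
- They have fewer gates (reset and update, versus the LSTM's input, forget and output gates) and therefore fewer tensor operations, so they are usually faster to train.
- You can think of a GRU as an LSTM without a separate cell state: the short-term (hidden) and long-term (cell) memories are merged into a single hidden state, and the update gate decides at each step how much of that state to keep versus overwrite.
- Researchers often try both LSTMs and GRUs to see which works better for their use case.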

In [None]:
import torch
import torch.nn as nn
import numpy as np 

In [None]:
class Cell(nn.Module):
    """A single GRU cell that advances the hidden state by one time step.

    The gate equations match ``torch.nn.GRUCell``::

        r  = sigmoid(W_ir x + b_ir + W_hr h + b_hr)        # reset gate
        z  = sigmoid(W_iz x + b_iz + W_hz h + b_hz)        # update gate
        n  = tanh(W_in x + b_in + r * (W_hn h + b_hn))     # candidate state
        h' = z * h + (1 - z) * n

    Args:
        input_size: Number of features in the input ``x``.
        hidden_size: Number of features in the hidden state ``h``.
        bias: If ``False``, the input and hidden projections have no bias.
    """

    def __init__(self, input_size, hidden_size, bias=True):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias

        # One projection per source produces all three gate pre-activations
        # at once, stacked along the feature dim in (reset, update, new) order.
        self.input = nn.Linear(input_size, 3 * hidden_size, bias=bias)
        self.hidden = nn.Linear(hidden_size, 3 * hidden_size, bias=bias)

        self.reset_parameters()

    def reset_parameters(self):
        """Initialise every weight and bias from U(-1/sqrt(H), 1/sqrt(H))."""
        std = 1.0 / np.sqrt(self.hidden_size)
        # In-place init must not be recorded by autograd.
        with torch.no_grad():
            for w in self.parameters():
                w.uniform_(-std, std)

    def forward(self, input, hx=None):
        """Compute the next hidden state.

        Args:
            input: Tensor of shape ``(batch, input_size)``.
            hx: Previous hidden state of shape ``(batch, hidden_size)``.
                Defaults to zeros with the same device and dtype as ``input``.

        Returns:
            The new hidden state, shape ``(batch, hidden_size)``.
        """
        if hx is None:
            # new_zeros inherits device and dtype from ``input``.
            hx = input.new_zeros(input.size(0), self.hidden_size)

        x_t = self.input(input)
        h_t = self.hidden(hx)

        x_reset, x_upd, x_new = x_t.chunk(3, 1)
        h_reset, h_upd, h_new = h_t.chunk(3, 1)

        reset_gate = torch.sigmoid(x_reset + h_reset)
        update_gate = torch.sigmoid(x_upd + h_upd)
        # The reset gate scales only the hidden contribution to the candidate.
        new_gate = torch.tanh(x_new + reset_gate * h_new)

        # update_gate -> 1 keeps the old state; -> 0 takes the candidate.
        return update_gate * hx + (1 - update_gate) * new_gate

In [None]:
class GRU(nn.Module):
    """Multi-layer GRU built from stacked ``Cell`` modules plus a linear head.

    Args:
        input_size: Number of features per time step of the input.
        hidden_size: Number of features in each layer's hidden state.
        num_layers: Number of stacked cells; must be at least 1.
        bias: Whether the cells' linear projections use a bias.
        output_size: Number of features produced by the final ``fc`` layer.

    Raises:
        ValueError: If ``num_layers`` is less than 1.
    """

    def __init__(self, input_size, hidden_size, num_layers, bias, output_size):
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.output_size = output_size

        # Layer 0 consumes the raw input; each deeper layer consumes the hidden
        # state produced by the layer below it at the same time step.
        self.rnn_cell_list = nn.ModuleList(
            Cell(input_size if layer == 0 else hidden_size, hidden_size, bias)
            for layer in range(num_layers)
        )

        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, input, hx=None):
        """Run the stack over a batch of sequences and project the last state.

        Args:
            input: Tensor of shape ``(batch, seq_len, input_size)``.
            hx: Optional initial hidden states of shape
                ``(num_layers, batch, hidden_size)``. Defaults to zeros with
                the same device and dtype as ``input``.

        Returns:
            ``fc`` applied to the top layer's hidden state after the final
            time step: shape ``(batch, output_size)``, or ``(output_size,)``
            when ``batch == 1``.
        """
        if hx is None:
            # Allocate on the input's device rather than a hard-coded one so
            # the model runs on CPU, CUDA or MPS alike.
            hx = input.new_zeros(self.num_layers, input.size(0), self.hidden_size)

        # Per-layer hidden states, each of shape (batch, hidden_size).
        hidden = [hx[layer] for layer in range(self.num_layers)]

        for t in range(input.size(1)):
            layer_input = input[:, t, :]
            for layer in range(self.num_layers):
                hidden[layer] = self.rnn_cell_list[layer](layer_input, hidden[layer])
                # The freshly updated state feeds the next layer up.
                layer_input = hidden[layer]

        # Top layer's state after the last step. Reading it from ``hidden``
        # (instead of collecting every step) also works when seq_len == 0.
        # squeeze(0) only drops a size-1 batch dim; a bare squeeze() would
        # also collapse the feature dim when hidden_size == 1 and break fc.
        out = hidden[-1].squeeze(0)
        return self.fc(out)