In [0]:
import torch
from torch import nn
import torch.nn.functional as F

In [0]:
class GRUCell(nn.Module):
  """A single gated recurrent unit cell that keeps its hidden state between calls.

  Each call to ``forward`` consumes one time step and overwrites ``self.hidden``,
  so consecutive calls form a sequence. Call ``reset()`` before starting an
  unrelated sequence.

  The recurrence implemented here is::

      r  = sigmoid(W_r [x, h])        # ``forget`` gate (reset gate)
      z  = sigmoid(W_z [x, h])        # ``update`` gate
      n  = tanh(W_n [x, r * h])       # candidate state (``output``)
      h' = z * h + (1 - z) * n

  Args:
    input_depth: size of the last dimension of each input token.
    hidden_depth: size of the hidden state.
  """

  def __init__(self, input_depth, hidden_depth):
    super().__init__()

    self.input_depth = input_depth
    self.hidden_depth = hidden_depth

    # Every gate sees the input and the previous hidden state side by side.
    concat_depth = input_depth + hidden_depth

    self.forget = nn.Sequential(nn.Linear(concat_depth, hidden_depth), nn.Sigmoid())
    self.update = nn.Sequential(nn.Linear(concat_depth, hidden_depth), nn.Sigmoid())

    self.output = nn.Sequential(nn.Linear(concat_depth, hidden_depth), nn.Tanh())

    # Hidden state of shape (batch, hidden_depth); None means "start from zeros".
    self.hidden = None

  def reset(self):
    """Forget the stored hidden state so the next call starts from zeros."""
    self.hidden = None

  def forward(self, token):
    """Advance the cell by one time step.

    Args:
      token: tensor of shape (batch, input_depth).

    Returns:
      The new hidden state of shape (batch, hidden_depth). It is also stored
      in ``self.hidden`` and used as the previous state on the next call.
    """
    if self.hidden is None:
      # Create the initial state on the input's device and with its dtype;
      # a bare torch.zeros() is always a CPU default-dtype tensor and makes
      # torch.cat fail for GPU inputs.
      self.hidden = token.new_zeros(token.shape[0], self.hidden_depth)

    concat = torch.cat([token, self.hidden], dim=1)

    forget = self.forget(concat)
    update = self.update(concat)

    # The candidate only sees the part of the old state the forget gate lets through.
    forget_concat = torch.cat([token, self.hidden * forget], dim=1)
    output = self.output(forget_concat)

    # An update gate close to 1 keeps the old state; close to 0 adopts the candidate.
    self.hidden = update * self.hidden + (1 - update) * output

    return self.hidden



In [0]:
class GRU(nn.Module):
  """A stack of GRUCell layers unrolled over a whole sequence.

  Layer 0 consumes the input tokens; every later layer consumes the hidden
  state produced by the layer below it at the same time step.

  Args:
    input_depth: size of the last dimension of each input token.
    hidden_depth: hidden state size used by every layer.
    num_layers: number of stacked cells.
  """

  def __init__(self, input_depth, hidden_depth, num_layers):
    super().__init__()

    self.input_depth = input_depth
    self.hidden_depth = hidden_depth
    self.num_layers = num_layers

    # nn.ModuleList (not a plain list) so the cells' parameters are registered:
    # they appear in .parameters()/state_dict() and move with .to(device).
    # Only the first layer sees raw inputs; deeper layers receive hidden states.
    self.layers = nn.ModuleList(
      GRUCell(input_depth if i == 0 else hidden_depth, hidden_depth)
      for i in range(num_layers)
    )

  def reset(self):
    """Clear the stored hidden state of every layer."""
    for layer in self.layers:
      layer.reset()

  def forward(self, inp):
    """Run the stack over a full sequence, starting from a zero hidden state.

    Args:
      inp: tensor of shape (sequence_length, batch_size, input_depth).

    Returns:
      A list with one tensor per time step holding the top layer's hidden
      state, each of shape (batch_size, hidden_depth).
    """
    self.reset()

    hiddens = []
    for token in inp:
      carry = token
      for layer in self.layers:
        carry = layer(carry)
      hiddens.append(carry)

    return hiddens


In [0]:
# Seed before building the model so both the weight initialisation and the
# random input drawn in the next cell are reproducible on Restart & Run All.
torch.manual_seed(0)
l = GRU(input_depth=10, hidden_depth=10, num_layers=2)

In [9]:
# Random demo batch: (sequence_length=3, batch_size=2, input_depth=10).
l(torch.randn(size=(3, 2, 10)))

[tensor([[-0.0041, -0.1195,  0.0676,  0.0983, -0.0706,  0.0321, -0.0816, -0.1101,
          -0.0573,  0.0662],
         [-0.0228, -0.0271,  0.0494,  0.1206, -0.0278,  0.0045, -0.0383, -0.0724,
          -0.0694,  0.1469]], grad_fn=<AddBackward0>),
 tensor([[ 0.0485, -0.1375,  0.0805,  0.1373, -0.0485,  0.0473, -0.1241, -0.1486,
          -0.0717,  0.1793],
         [-0.0096, -0.1022,  0.0942,  0.1735, -0.0580,  0.0479, -0.1317, -0.1305,
          -0.0725,  0.1653]], grad_fn=<AddBackward0>),
 tensor([[ 0.0271, -0.1816,  0.0940,  0.1594, -0.0741,  0.0546, -0.1443, -0.1748,
          -0.1018,  0.2152],
         [-0.0541, -0.0950,  0.1355,  0.1998, -0.0739,  0.0271, -0.1742, -0.1629,
          -0.1016,  0.1918]], grad_fn=<AddBackward0>)]