In [0]:
import torch
from torch import nn

In [0]:
class RNN_Cell(nn.Module):
  def __init__(self, input_depth, hidden_depth,
               hidden_act = nn.Tanh, output_act = nn.Sigmoid):
    super().__init__()

    self.input_depth = input_depth
    self.hidden_depth = hidden_depth
    self.hidden_act = hidden_act
    self.output_act = output_act

    concat_depth = input_depth + hidden_depth
    self.inp_to_hidden = nn.Sequential(nn.Linear(concat_depth, hidden_depth), 
                                       hidden_act())
    
    self.hidden_to_out = nn.Sequential(nn.Linear(hidden_depth, input_depth), 
                                       output_act())

    self.hidden = None

  def reset(self): 
    self.hidden = None
  
  def forward(self, inp):

    if self.hidden is None:
      self.hidden = torch.zeros(inp.shape[0], self.hidden_depth)

    
    concat = torch.cat([inp, self.hidden], dim=1)

    self.hidden = self.inp_to_hidden(concat)
    return self.hidden_to_out(self.hidden)


In [0]:
class RNN(nn.Module):
  def __init__(self, input_depth, hidden_depth, num_layers): 
    super().__init__()

    self.input_depth = input_depth
    self.hidden_depth = hidden_depth
    self.num_layers = num_layers

    self.layers = [RNN_Cell(input_depth, hidden_depth) for i in range(num_layers)]

  def reset(self):
    for l in self.layers: l.reset()

  # Our input sequence has shape (sequence_length, batch_size, input_depth)
  def forward(self, inp):
    
    self.reset()

    outs = []

    for token in inp:
      carry = token
      for l in self.layers:
        hidden = l(carry)
        carry = hidden
      
      outs.append(carry)

    return outs

In [0]:
rnn = RNN(32, 64, 3)

In [14]:
rnn(torch.randn(3, 4, 32))

concat shape: torch.Size([4, 96])
concat shape: torch.Size([4, 96])
concat shape: torch.Size([4, 96])
concat shape: torch.Size([4, 96])
concat shape: torch.Size([4, 96])
concat shape: torch.Size([4, 96])
concat shape: torch.Size([4, 96])
concat shape: torch.Size([4, 96])
concat shape: torch.Size([4, 96])


[tensor([[0.5104, 0.5181, 0.5090, 0.4615, 0.4776, 0.5083, 0.4941, 0.4791, 0.4758,
          0.4989, 0.5186, 0.5183, 0.4915, 0.4548, 0.4809, 0.4851, 0.4903, 0.5486,
          0.5178, 0.5197, 0.5123, 0.4836, 0.5497, 0.5222, 0.4808, 0.4269, 0.5072,
          0.5192, 0.5272, 0.4310, 0.5393, 0.4524],
         [0.5102, 0.5184, 0.5088, 0.4616, 0.4775, 0.5084, 0.4941, 0.4790, 0.4761,
          0.4988, 0.5188, 0.5183, 0.4916, 0.4549, 0.4808, 0.4848, 0.4904, 0.5485,
          0.5180, 0.5198, 0.5126, 0.4836, 0.5498, 0.5222, 0.4810, 0.4270, 0.5072,
          0.5192, 0.5270, 0.4310, 0.5393, 0.4523],
         [0.5103, 0.5180, 0.5089, 0.4613, 0.4776, 0.5084, 0.4942, 0.4789, 0.4758,
          0.4988, 0.5186, 0.5183, 0.4916, 0.4547, 0.4808, 0.4850, 0.4902, 0.5487,
          0.5180, 0.5200, 0.5126, 0.4835, 0.5498, 0.5223, 0.4810, 0.4269, 0.5072,
          0.5191, 0.5273, 0.4308, 0.5393, 0.4526],
         [0.5103, 0.5183, 0.5089, 0.4617, 0.4773, 0.5083, 0.4941, 0.4790, 0.4760,
          0.4987, 0.5187, 0