In [97]:
import torch
import torch.nn as nn

torch.__version__, torch.cuda.is_available()

('1.4.0', True)

In [232]:
class RNNCell(nn.Module):
    def __init__(self, in_size, out_size):
        super(RNNCell, self).__init__()
        self.linear = nn.Linear(in_size + out_size, out_size * 2)

    def forward(self, x, h):
        inp = torch.cat((x, h), dim=1)
        out = torch.tanh(self.linear(inp))
        y, new_h = torch.chunk(out, 2, dim=1)
        return y, new_h

In [233]:
cell = RNNCell(10, 20)
cell = torch.jit.script(cell)

x = torch.rand(5, 10)
h = torch.rand(5, 20)

y, h = cell(x, h)
y.shape, h.shape, x.shape

(torch.Size([5, 20]), torch.Size([5, 20]), torch.Size([5, 10]))

In [241]:
class RNNLoop(nn.Module):
    def __init__(self, bs, seq_size, in_size, out_size):
        super(RNNLoop, self).__init__()
        self.seq_size = seq_size
        self.bs = bs
        self.in_size = in_size
        self.out_size = out_size
        self.cell = torch.jit.script(RNNCell(in_size, out_size))
        
    def forward(self, xs):
        h = torch.zeros(self.bs, self.out_size)
        y = torch.zeros(self.bs, self.out_size)
        rollout = torch.zeros(self.bs, self.seq_size, self.out_size)

        for i in range(xs.size(0)):
            y, h = self.cell(xs[:, i], h)
            rollout[:, i] = y

        return rollout

In [246]:
bs = 8
seq_size = 12
in_size = 10
out_size = 20

loop = RNNLoop(bs, seq_size, in_size, out_size)
loop = torch.jit.script(loop)
print(loop.code)

xs = torch.rand(bs, seq_size, in_size)
loop(xs).shape

def forward(self,
    xs: Tensor) -> Tensor:
  h = torch.zeros([self.bs, self.out_size], dtype=None, layout=None, device=None, pin_memory=None)
  y = torch.zeros([self.bs, self.out_size], dtype=None, layout=None, device=None, pin_memory=None)
  _0 = [self.bs, self.seq_size, self.out_size]
  rollout = torch.zeros(_0, dtype=None, layout=None, device=None, pin_memory=None)
  h0 = h
  for i in range(torch.size(xs, 0)):
    _1 = torch.slice(xs, 0, 0, 9223372036854775807, 1)
    _2 = (self.cell).forward(torch.select(_1, 1, i), h0, )
    y0, h1, = _2
    _3 = torch.slice(rollout, 0, 0, 9223372036854775807, 1)
    _4 = torch.copy_(torch.select(_3, 1, i), y0, False)
    h0 = h1
  return rollout



torch.Size([8, 12, 20])