In [1]:
import torch
from torch import nn

In [2]:
from d2l_common import RNNScratch, Module


class BiRNNScratch(Module):
    """Bidirectional RNN assembled from two independent ``RNNScratch`` layers.

    One RNN reads the sequence front-to-back and the other back-to-front.
    At every time step their hidden states are concatenated along the last
    dimension, so each output is twice as wide as a single direction.

    Args:
        num_inputs: Feature size of one time step of the input.
        num_hiddens: Hidden size of *each* directional RNN.
        sigma: Forwarded unchanged to both ``RNNScratch`` constructors.

    Attributes:
        num_hiddens: Width of each output step, i.e. ``2 * num_hiddens`` as
            passed to the constructor. Downstream layers sized from this
            attribute therefore match the concatenated outputs.
    """

    def __init__(self, num_inputs, num_hiddens, sigma=0.01):
        super().__init__()
        self.num_inputs = num_inputs
        self.sigma = sigma
        self.f_rnn = RNNScratch(num_inputs, num_hiddens, sigma)
        self.b_rnn = RNNScratch(num_inputs, num_hiddens, sigma)
        # forward() concatenates one forward and one backward state per step,
        # so the width exposed to consumers is double the per-direction size.
        self.num_hiddens = 2 * num_hiddens

    def forward(self, inputs, Hs=None):
        """Run both directions over ``inputs`` and pair their states per step.

        Args:
            inputs: Tensor whose dim 0 is treated as the time axis. Presumably
                ``(num_steps, batch_size, num_inputs)`` as ``RNNScratch``
                expects (TODO confirm against d2l_common).
            Hs: Optional ``(forward_state, backward_state)`` pair. With
                ``None``, each sub-RNN receives ``None`` and chooses its own
                initial state.

        Returns:
            ``(outputs, (f_H, b_H))``. ``outputs[t]`` is the forward output at
            step ``t`` concatenated (last dim) with the backward output for
            the same step ``t``. ``f_H`` and ``b_H`` are the final states
            returned by the two sub-RNNs.
        """
        f_H, b_H = (None, None) if Hs is None else Hs
        f_outputs, f_H = self.f_rnn(inputs, f_H)
        # For a tensor, reversed() flips dim 0, so the backward RNN reads the
        # sequence from its last step to its first.
        b_outputs, b_H = self.b_rnn(reversed(inputs), b_H)
        # b_outputs is in reversed time order; reverse it back so that zip
        # pairs forward and backward outputs belonging to the same step.
        outputs = [torch.cat((f, b), dim=-1)
                   for f, b in zip(f_outputs, reversed(b_outputs))]
        return outputs, (f_H, b_H)

In [3]:
# Shape check: RNNScratch iterates dim 0 as time steps, so the input must be
# (num_steps, batch_size, num_inputs); a 2-D tensor only "runs" by broadcasting.
num_steps, batch_size, num_inputs, num_hiddens = 128, 4, 10, 32
model = BiRNNScratch(num_inputs, num_hiddens)
X = torch.randn(num_steps, batch_size, num_inputs)
outputs, _ = model(X)
assert len(outputs) == num_steps
# Forward and backward states are concatenated, doubling the feature width.
assert outputs[0].shape == (batch_size, 2 * num_hiddens)
len(outputs), outputs[0].shape

(128, torch.Size([4, 64]))