In [7]:
import torch
from torch import nn

In [8]:
class LSTM(nn.Module):
    """Single-layer LSTM assembled from four ``nn.Linear`` gate projections.

    Every gate is an affine map of the concatenation ``[x_t, h_{t-1}]``:

        i_t = sigmoid(W_i [x_t, h_{t-1}] + b_i)      # input gate
        f_t = sigmoid(W_f [x_t, h_{t-1}] + b_f)      # forget gate
        o_t = sigmoid(W_o [x_t, h_{t-1}] + b_o)      # output gate
        g_t = tanh(W_g [x_t, h_{t-1}] + b_g)         # cell candidate
        c_t = f_t * c_{t-1} + i_t * g_t
        h_t = o_t * tanh(c_t)

    Args:
        input_size: number of features in each input step ``x_t``.
        hidden_size: number of features in the hidden and cell states.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        # Attribute names double as state_dict keys; keep them stable.
        self.input_gate = nn.Linear(input_size + hidden_size, hidden_size)
        self.forget_gate = nn.Linear(input_size + hidden_size, hidden_size)
        self.output_gate = nn.Linear(input_size + hidden_size, hidden_size)
        self.cell_gate = nn.Linear(input_size + hidden_size, hidden_size)
        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()

    def forward(self, X, c, h, batch_first=False):
        """Run the LSTM over a whole sequence, one step at a time.

        Args:
            X: input sequence, ``(seq_len, batch, input_size)``, or
                ``(batch, seq_len, input_size)`` when ``batch_first`` is True.
            c: initial cell state, ``(batch, hidden_size)``. Note the order:
                the cell state comes *before* the hidden state.
            h: initial hidden state, ``(batch, hidden_size)``.
            batch_first: whether the batch dimension of ``X`` (and of the
                returned tensor) comes first.

        Returns:
            The hidden state at every step, ``(seq_len, batch, hidden_size)``,
            or ``(batch, seq_len, hidden_size)`` when ``batch_first`` is True.
            The final cell state is not returned.
        """
        if batch_first:
            X = X.transpose(0, 1)
        outputs = []
        for x in X:
            xh = torch.cat((x, h), dim=1)
            ig = self.sigmoid(self.input_gate(xh))
            fg = self.sigmoid(self.forget_gate(xh))
            og = self.sigmoid(self.output_gate(xh))
            cell_candidate = self.tanh(self.cell_gate(xh))
            c = fg * c + ig * cell_candidate
            # h_t = o_t * tanh(c_t): the output gate scales the squashed cell
            # state rather than being applied inside the tanh.
            h = og * self.tanh(c)
            outputs.append(h)
        if outputs:
            outputs = torch.stack(outputs)
        else:
            # torch.stack rejects an empty list; a zero-length input yields a
            # zero-length output with the hidden state's batch/feature dims.
            outputs = h.new_empty((0, *h.shape))
        if batch_first:
            outputs = outputs.transpose(0, 1)
        return outputs

In [9]:
# Smoke test: run the LSTM on a random batch-first sequence and check the shape.
torch.manual_seed(0)  # reproducible weight init and inputs
input_size = 50
hidden_size = 100
batch_size = 8
seq_len = 10
lstm = LSTM(input_size, hidden_size)

# Inputs are batch-first: (batch, seq_len, input_size).
X = torch.randn((batch_size, seq_len, input_size))
h = torch.zeros((batch_size, hidden_size))
c = torch.zeros((batch_size, hidden_size))

# LSTM.forward takes the cell state before the hidden state.
y = lstm(X, c, h, batch_first=True)
assert y.shape == (batch_size, seq_len, hidden_size)
print(y.shape)

torch.Size([8, 10, 100])


In [12]:
class GRU(nn.Module):
    """Single-layer GRU assembled from three ``nn.Linear`` projections.

    The reset gate is applied to ``h_{t-1}`` *before* the candidate projection:

        r_t = sigmoid(W_r [x_t, h_{t-1}] + b_r)          # reset gate
        z_t = sigmoid(W_z [x_t, h_{t-1}] + b_z)          # update gate
        n_t = tanh(W_n [x_t, r_t * h_{t-1}] + b_n)       # candidate state
        h_t = z_t * h_{t-1} + (1 - z_t) * n_t

    Args:
        input_size: number of features in each input step ``x_t``.
        hidden_size: number of features in the hidden state.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        # Attribute names double as state_dict keys, so they are kept even
        # though two are misleading: `upgrade_gate` is the update gate z_t and
        # `output_layer` is the candidate projection for n_t.
        self.reset_gate = nn.Linear(input_size + hidden_size, hidden_size)
        self.upgrade_gate = nn.Linear(input_size + hidden_size, hidden_size)
        self.output_layer = nn.Linear(input_size + hidden_size, hidden_size)
        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()

    def forward(self, X, h, batch_first=False):
        """Run the GRU over a whole sequence, one step at a time.

        Args:
            X: input sequence, ``(seq_len, batch, input_size)``, or
                ``(batch, seq_len, input_size)`` when ``batch_first`` is True.
            h: initial hidden state, ``(batch, hidden_size)``.
            batch_first: whether the batch dimension of ``X`` (and of the
                returned tensor) comes first.

        Returns:
            The hidden state at every step, ``(seq_len, batch, hidden_size)``,
            or ``(batch, seq_len, hidden_size)`` when ``batch_first`` is True.
        """
        if batch_first:
            X = X.transpose(0, 1)
        outputs = []
        for x in X:
            xh = torch.cat((x, h), dim=1)
            rg = self.sigmoid(self.reset_gate(xh))
            ug = self.sigmoid(self.upgrade_gate(xh))
            # Reset gate masks the previous state before it enters the projection.
            h_candidate = self.tanh(self.output_layer(torch.cat((x, h * rg), dim=1)))
            # ug close to 1 keeps the previous state; close to 0 takes the candidate.
            h = ug * h + (1 - ug) * h_candidate
            outputs.append(h)
        if outputs:
            outputs = torch.stack(outputs)
        else:
            # torch.stack rejects an empty list; a zero-length input yields a
            # zero-length output with the hidden state's batch/feature dims.
            outputs = h.new_empty((0, *h.shape))
        if batch_first:
            outputs = outputs.transpose(0, 1)
        return outputs

In [13]:
# Smoke test: run the GRU on a random batch-first sequence and check the shape.
torch.manual_seed(0)  # reproducible weight init and inputs
input_size = 50
hidden_size = 100
batch_size = 8
seq_len = 10
# Named `gru` (not `lstm`) so this cell does not clobber the LSTM instance.
gru = GRU(input_size, hidden_size)

# Inputs are batch-first: (batch, seq_len, input_size).
X = torch.randn((batch_size, seq_len, input_size))
h = torch.zeros((batch_size, hidden_size))

y = gru(X, h, batch_first=True)
assert y.shape == (batch_size, seq_len, hidden_size)
print(y.shape)

torch.Size([8, 10, 100])


: 