In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [21]:
# Toy single-layer GRU mapping 10 input features to 8 hidden units.
gru = nn.GRU(input_size=10, hidden_size=8)
# Random input; with the default sequence-first layout this reads as
# (seq_len=5, batch=2, features=10).
inpt = torch.randn((5, 2, 10))
# Random initial hidden state: (num_layers=1, batch=2, hidden=8).
h0 = torch.randn((1, 2, 8))

In [22]:
# Element 0 of the GRU's return value is the per-step output sequence; show its shape.
gru(inpt, h0)[0].shape

torch.Size([5, 2, 8])

In [None]:
class LSTM(nn.Module):
    """Hand-written single-step LSTM cell with a log-softmax classifier head.

    Every gate, the candidate cell state and the classifier all read the
    same vector: the current input concatenated with the previous hidden
    state along dim 1.

    Args:
        in_size: number of features in each input row.
        hidden_size: width of the hidden and cell states.
        output_size: number of classes scored by the classifier head.
    """

    def __init__(self, in_size, hidden_size, output_size):
        super().__init__()
        self.in_size = in_size
        self.hidden_size = hidden_size

        # Width of concat([input, hidden]); shared by all five linear maps.
        joint_size = in_size + hidden_size
        self.i2f = nn.Linear(joint_size, hidden_size)  # forget gate
        self.i2u = nn.Linear(joint_size, hidden_size)  # update (input) gate
        self.i2o = nn.Linear(joint_size, hidden_size)  # output gate
        self.i2c = nn.Linear(joint_size, hidden_size)  # candidate cell state
        self.out_classify = nn.Linear(joint_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden, cell):
        """Run one time step.

        Args:
            input: tensor whose dim 1 has ``in_size`` entries.
            hidden: previous hidden state, dim 1 of width ``hidden_size``.
            cell: previous cell state, same shape as ``hidden``.

        Returns:
            Tuple ``(log_probs, new_hidden, new_cell)``. ``log_probs`` is
            computed from the input and the *previous* hidden state, not
            from ``new_hidden``.
        """
        joint = torch.cat((input, hidden), dim=1)

        forget_gate = self.i2f(joint).sigmoid()
        update_gate = self.i2u(joint).sigmoid()
        output_gate = self.i2o(joint).sigmoid()
        candidate = self.i2c(joint).tanh()

        next_cell = forget_gate * cell + update_gate * candidate
        next_hidden = output_gate * next_cell.tanh()

        log_probs = self.softmax(self.out_classify(joint))
        return log_probs, next_hidden, next_cell

    def initHidden(self):
        """Return an all-zero state of shape ``(1, hidden_size)``."""
        return torch.zeros((1, self.hidden_size))

In [40]:
class EncoderRNN(nn.Module):
    """GRU encoder that embeds dense (e.g. one-hot) input vectors linearly.

    Args:
        input_size: number of features in each input vector.
        hidden_size: width of the GRU hidden state.
        emb_size: width of the linear embedding fed to the GRU. Defaults to
            64, the value that used to be hard-coded, so existing
            ``EncoderRNN(input_size, hidden_size)`` calls are unaffected.
    """

    def __init__(self, input_size, hidden_size, emb_size=64):
        super().__init__()
        self.hidden_size = hidden_size
        self.emb_size = emb_size
        # A Linear layer stands in for an embedding table because inputs are
        # float vectors rather than integer token ids.
        self.embedding = nn.Linear(input_size, self.emb_size)
        self.gru = nn.GRU(input_size=self.emb_size, hidden_size=hidden_size)

    def forward(self, input, hidden):
        """Embed ``input`` and run it through the GRU.

        Args:
            input: tensor whose last dimension is ``input_size``.
            hidden: initial GRU hidden state, e.g. from ``initHidden()``.

        Returns:
            Tuple ``(output, new_hidden)`` exactly as returned by ``nn.GRU``.
        """
        embedded = self.embedding(input)
        output, new_hidden = self.gru(embedded, hidden)
        return output, new_hidden

    def initHidden(self):
        """Return an all-zero hidden state of shape ``(1, hidden_size)``."""
        return torch.zeros(1, self.hidden_size)
        

In [74]:
class DecoderRNN(nn.Module):
    """GRU decoder that emits a probability distribution over output classes.

    Args:
        hidden_size: width of the GRU hidden state.
        output_size: number of output classes; also the width of each
            decoder input vector, so a prediction can be fed back in as the
            next input.
        emb_size: width of the linear embedding fed to the GRU. Defaults to
            64, the value that used to be hard-coded.
    """

    def __init__(self, hidden_size, output_size, emb_size=64):
        super().__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.emb_size = emb_size
        self.embedding = nn.Linear(output_size, self.emb_size)
        self.gru = nn.GRU(input_size=self.emb_size, hidden_size=hidden_size)
        # Projects GRU features onto the output classes. Without it the
        # softmax was taken over the hidden units and the result had width
        # hidden_size instead of output_size.
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, input, hidden):
        """Run the decoder on ``input`` starting from ``hidden``.

        Args:
            input: tensor whose last dimension is ``output_size``.
            hidden: initial GRU hidden state, e.g. from ``initHidden()``.

        Returns:
            Tuple ``(output, new_hidden)``. ``output`` holds class
            probabilities with last dimension ``output_size``;
            ``new_hidden`` is the GRU's updated hidden state.
        """
        embedded = F.relu(self.embedding(input))
        gru_out, new_hidden = self.gru(embedded, hidden)
        # Normalise over the class axis. dim=-1 is correct for both unbatched
        # (L, C) and batched (L, N, C) tensors, whereas dim=1 would normalise
        # across the batch in the batched case.
        output = torch.softmax(self.out(gru_out), dim=-1)
        return output, new_hidden

    def initHidden(self):
        """Return an all-zero hidden state of shape ``(1, hidden_size)``."""
        return torch.zeros(1, self.hidden_size)


In [75]:
# Encoder over 10-feature input vectors with a 25-unit hidden state.
enc = EncoderRNN(input_size=10, hidden_size=25)

In [76]:
# Decoder with a 25-unit hidden state (matching the encoder) and 11 output classes.
dec = DecoderRNN(hidden_size=25, output_size=11)

In [77]:
# Zero-filled initial hidden state of shape (1, hidden_size) for the encoder.
# Note: this rebinds h0, which earlier held the random state for the GRU demo.
h0 = enc.initHidden()

In [78]:
# One-hot encoder input: a single row of width 10 with class index 3 set.
in0 = F.one_hot(torch.tensor([3]), num_classes=10).float()

In [79]:
# The encoder returns (output, new_hidden); keep the output and discard the hidden state.
dec_input, _ = enc(in0, h0)

In [80]:
# Shape of the encoder output kept in dec_input.
dec_input.shape

torch.Size([1, 25])

In [82]:
# One-hot first decoder input over the 11 output classes; index 0 is
# presumably the start-of-sequence token (confirm against the vocabulary).
start = torch.zeros(1, 11)
# Set the bit on `start` itself. This used to write to `in0`, which mutated
# the already-consumed encoder input and left `start` all zeros.
start[:, 0] = 1
# The encoder output is passed as the decoder's initial hidden state.
out = dec(start, dec_input)

In [87]:
# out is the decoder's (output, new_hidden) tuple; index 1 is the updated hidden state.
out[1].shape

torch.Size([1, 25])