# In [1]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.nn.parameter import Parameter
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class Basic_RNN(nn.Module):
    """Thin GRU wrapper that returns only the final hidden state.

    Args:
        input_size: Size of the raw (pre-embedding) features. Not used by
            this module; kept so the constructor signature stays stable.
        embedding_size: Feature size of the inputs fed to the GRU.
        hidden_size: Number of features in the GRU hidden state.
        num_layers: Number of stacked GRU layers.
        dropout: Dropout probability applied between stacked GRU layers.
            Only takes effect when ``num_layers > 1``.
    """

    def __init__(self, input_size, embedding_size, hidden_size, num_layers=1,
                 dropout=0.1):
        super().__init__()

        self.num_layers = num_layers
        self.hidden_size = hidden_size
        # nn.GRU applies dropout between layers only; with a single layer it
        # would have no effect and would emit a warning, so pass 0 there.
        inter_layer_dropout = dropout if num_layers > 1 else 0.0
        self.gru = nn.GRU(embedding_size, hidden_size, num_layers=num_layers,
                          batch_first=True, dropout=inter_layer_dropout)

    def forward(self, x):
        """Run the GRU over ``x`` and return its final hidden state.

        Args:
            x: Padded tensor of shape (batch, seq_len, embedding_size), or a
                PackedSequence built from such a tensor.

        Returns:
            h_n of shape (num_layers, batch, hidden_size).
        """
        # The per-step outputs are not needed, only the last hidden state.
        _, hn = self.gru(x)
        return hn

class RNN(nn.Module):
    """GRU sequence classifier with a learned linear input embedding.

    Each time step is projected by ``emb``, the variable-length sequences are
    packed and run through a single-layer GRU, and the final hidden state is
    passed through dropout, a linear layer and a sigmoid.

    Args:
        input_size: Feature size of each raw input time step.
        embedding_size: Feature size after the learned projection.
        hidden_size: Number of features in the GRU hidden state.
        output_size: Number of sigmoid outputs per sequence.
        dropout: Dropout probability applied to the final hidden state.
    """

    def __init__(self, input_size, embedding_size, hidden_size, output_size,
                 dropout=0.1):
        super().__init__()
        self.rnn_module = Basic_RNN(input_size, embedding_size, hidden_size,
                                    dropout=dropout)

        self.o_mat = nn.Linear(hidden_size, output_size)
        self.dropout = nn.Dropout(dropout)
        # Allocated uninitialised; filled in by init_weights() below.
        self.emb = Parameter(torch.empty(input_size, embedding_size))
        self.init_weights()

    def init_weights(self):
        """Initialise the embedding matrix with Kaiming-uniform values."""
        init.kaiming_uniform_(self.emb, a=math.sqrt(5))

    def forward(self, x, lengths):
        """Compute per-sequence sigmoid outputs.

        Args:
            x: Padded batch of shape (batch, seq_len, input_size).
            lengths: Valid length of each sequence in the batch (list or 1-D
                tensor). Sequences do not need to be sorted by length.

        Returns:
            Tensor of shape (batch, output_size) with values in (0, 1).
        """
        batch_size = x.size(0)
        # (batch, seq_len, input_size) -> (batch, seq_len, embedding_size)
        embedded = x @ self.emb

        # pack_padded_sequence requires the lengths to live on the CPU.
        if torch.is_tensor(lengths):
            lengths = lengths.cpu()

        # enforce_sorted=False lets callers pass batches in any order; the GRU
        # returns h_n in the original batch order.
        packed_x = pack_padded_sequence(embedded, lengths, batch_first=True,
                                        enforce_sorted=False)

        # hn: (num_layers, batch, hidden_size)
        hn = self.rnn_module(packed_x)
        # Flatten the layer axis into the features: (batch, num_layers * hidden_size)
        c = hn.transpose(0, 1).reshape(batch_size, -1)

        c = self.dropout(c)
        y = torch.sigmoid(self.o_mat(c))
        return y