In [1]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.nn.parameter import Parameter
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import numpy as np

class Basic_RNN(nn.Module):
    """Thin wrapper around a batch-first ``nn.GRU`` that returns only its outputs.

    Args:
        input_size: Raw feature size. Not used by this module (the GRU consumes
            already-embedded inputs); kept so the constructor signature is stable.
        embedding_size: Feature size of the tensors fed to the GRU.
        hidden_size: GRU hidden-state size.
        num_layers: Number of stacked GRU layers.
        dropout: Dropout probability between stacked GRU layers. ``nn.GRU`` only
            applies dropout between layers, so it takes effect only when
            ``num_layers > 1``.
    """

    def __init__(self, input_size, embedding_size, hidden_size, num_layers=1,
                 dropout=0.1):
        super().__init__()

        self.num_layers = num_layers
        self.hidden_size = hidden_size
        # nn.GRU warns on a non-zero dropout with a single layer (it would have
        # no effect), so only forward the value when there is more than one layer.
        self.gru = nn.GRU(
            embedding_size,
            hidden_size,
            num_layers=num_layers,
            dropout=dropout if num_layers > 1 else 0.0,
            batch_first=True,
        )

    def forward(self, x):
        """Run the GRU and return its per-timestep outputs.

        Args:
            x: Tensor of shape ``(batch, seq_len, embedding_size)`` or a
                ``PackedSequence`` of such inputs.

        Returns:
            The GRU output sequence: a tensor of shape
            ``(batch, seq_len, hidden_size)`` for tensor input, or a
            ``PackedSequence`` for packed input. The final hidden state is
            discarded.
        """
        output, _ = self.gru(x)
        return output

class RETAIN_BKEY(nn.Module):
    """Two-GRU attention model producing sigmoid scores for padded sequences.

    The input is projected through the learned matrix ``emb``, then:

    * outputs of ``rnn_module`` are scored (dot product) against the embedding
      of the first timestep, and softmax-normalised over the valid timesteps to
      give a per-timestep weight ``alpha``;
    * outputs of ``rnn_b_module`` pass through ``b_mat`` and ``tanh`` to give a
      per-timestep, per-dimension gate ``beta``;
    * the context vector ``sum_t alpha_t * (emb_t * beta_t)`` is mapped by
      ``o_mat`` and a sigmoid to ``output_size`` scores.

    Because the embedded input is used both as the query against GRU outputs
    and element-wise with ``beta``, ``embedding_size`` must equal
    ``hidden_size``.

    Args:
        input_size: Feature size of the raw input.
        embedding_size: Size of the learned input embedding.
        hidden_size: GRU hidden size (must equal ``embedding_size``).
        output_size: Number of sigmoid outputs.
        dropout: Dropout probability applied to the embedded input and to the
            context vector.
    """

    def __init__(self, input_size, embedding_size, hidden_size, output_size,
                 dropout=0.1):
        super().__init__()
        self.hidden_size = hidden_size
        self.rnn_module = Basic_RNN(input_size, embedding_size, hidden_size,
                                    dropout=dropout)
        self.rnn_b_module = Basic_RNN(input_size, embedding_size, hidden_size,
                                      dropout=dropout)
        # NOTE: a_mat is not used by forward(); it is kept so the module's
        # parameter set (and therefore its state_dict keys) stays unchanged.
        self.a_mat = nn.Linear(hidden_size, 1)
        self.b_mat = nn.Linear(hidden_size, hidden_size)
        self.o_mat = nn.Linear(hidden_size, output_size)
        self.dropout = nn.Dropout(dropout)
        self.emb = Parameter(torch.empty(input_size, embedding_size))
        self.init_weights()

    def init_weights(self):
        """Draw the embedding matrix from N(0, 0.01**2); other layers keep PyTorch defaults."""
        init.normal_(self.emb, 0, 0.01)

    def forward(self, x, lengths):
        """Compute sigmoid scores for a padded batch of sequences.

        Args:
            x: Padded input of shape ``(batch, seq_len, input_size)``.
            lengths: Number of valid timesteps per sequence, as a list or 1-D
                tensor on any device. Each length must be in ``[1, seq_len]``;
                the lengths do not need to be sorted.

        Returns:
            Tensor of shape ``(batch, output_size)`` with values in ``(0, 1)``.
        """
        total_length = x.size(1)
        # pack_padded_sequence requires lengths as a CPU int64 tensor.
        lengths = torch.as_tensor(lengths, dtype=torch.int64).cpu()

        x = self.dropout(x @ self.emb)  # (batch, seq_len, embedding_size)
        packed_x = pack_padded_sequence(x, lengths, batch_first=True,
                                        enforce_sorted=False)

        # Timestep attention: score each GRU output against the first
        # timestep's embedding.
        alpha_out, _ = pad_packed_sequence(self.rnn_module(packed_x),
                                           batch_first=True,
                                           total_length=total_length)
        query = x[:, :1, :]  # (batch, 1, embedding_size)
        alpha = alpha_out @ query.transpose(1, 2)  # (batch, seq_len, 1)
        # pad_packed_sequence pads with zeros, so padded timesteps would score 0
        # and still take a share of the softmax mass; exclude them explicitly.
        valid = (torch.arange(total_length, device=x.device).unsqueeze(0)
                 < lengths.to(x.device).unsqueeze(1))  # (batch, seq_len)
        alpha = alpha.masked_fill(~valid.unsqueeze(-1), float("-inf"))
        alpha = torch.softmax(alpha, dim=1)

        # Per-dimension gate from the second GRU.
        beta_out, _ = pad_packed_sequence(self.rnn_b_module(packed_x),
                                          batch_first=True,
                                          total_length=total_length)
        beta = torch.tanh(self.b_mat(beta_out))  # (batch, seq_len, hidden_size)

        # Context: alpha-weighted sum over time of the gated embeddings.
        c = ((x * beta).transpose(1, 2) @ alpha).squeeze(-1)
        c = self.dropout(c)
        return torch.sigmoid(self.o_mat(c))