# Events Module
We are going to implement an attention mechanism similar to RETAIN (https://arxiv.org/abs/1608.05745).

alpha is a scalar attention weight per date; it measures how much the events of that particular day contribute to the final output.
beta is a vector of attention weights per date; it measures the importance of each event type (feature) in the final output.

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

## Alpha Attention

In [4]:
class AlphaAttention(torch.nn.Module):
    """
    Alpha attention mechanism to compute the attention weights corresponding to each date with events data.

    Each RNN-alpha output is mapped to one scalar score, and the scores are
    normalized with a softmax across the date axis (dim=1).
    """

    def __init__(self, hidden_dim):
        """
        Arguments:
            hidden_dim: the hidden layer dimension
        """
        super().__init__()
        # One scalar score per date; normalized across dates in forward().
        self.a_att = nn.Linear(hidden_dim, 1)

    def forward(self, g, masks=None):
        """
        Arguments:
            g: the output tensor from RNN-alpha of shape (batch_size, seq_length, hidden_dim)
            masks: optional padding masks of shape (batch_size, seq_length) or
                (batch_size, seq_length, 1); nonzero/True marks a date with events data.
                When given, padded dates receive zero weight and the softmax is taken
                over the valid dates only. When omitted (the default), every date
                takes part in the softmax.

        Outputs:
            alpha: the corresponding attention weights of shape (batch_size, seq_length, 1)
        """
        scores = self.a_att(g)
        if masks is not None:
            valid = masks.bool()
            # Per-date masks get a trailing axis so they line up with the (.., 1) scores.
            if valid.dim() < scores.dim():
                valid = valid.unsqueeze(-1)
            # The most negative finite value (rather than -inf) still drives padded
            # weights to exactly zero, but a sequence with no valid dates produces
            # uniform weights instead of NaNs.
            scores = scores.masked_fill(~valid, torch.finfo(scores.dtype).min)
        return F.softmax(scores, dim=1)

## Beta Attention

In [5]:
class BetaAttention(torch.nn.Module):
    """
    Beta attention: one weight per events code for every date.

    Each RNN-beta hidden state is projected linearly onto the events codes and
    passed through tanh, so every weight lies in [-1, 1]. The weights are not
    normalized across codes and may be negative.
    """

    def __init__(self, input_dim, emb_dim):
        """
        Arguments:
            input_dim: the hidden layer dimension (feature size of each RNN-beta output)
            emb_dim: the number of events codes
        """
        super().__init__()
        self.b_att = nn.Linear(input_dim, emb_dim)

    def forward(self, h):
        """
        Arguments:
            h: the output tensor from RNN-beta of shape (batch_size, seq_length, input_dim)

        Outputs:
            beta: the attention weights of shape (batch_size, seq_length, # of events codes)
        """
        projected = self.b_att(h)
        return projected.tanh()

## Events RNN with Attention (similar to RETAIN)

In [9]:
class EventsRNN(nn.Module):
    """
    RETAIN-style encoder for per-date events data.

    Two GRUs read the events sequence. The output of RNN-alpha drives one scalar
    attention weight per date, and the output of RNN-beta drives one attention
    weight per events code. The attention-weighted events are summed over dates
    into a context vector. That vector is projected to `emb_size` features and
    passed through a sigmoid.
    """

    def __init__(self, num_codes, emb_size=128, hidden_dim=128):
        """
        Arguments:
            num_codes: the number of events codes (feature size of each date)
            emb_size: the size of the output vector
            hidden_dim: the hidden size of RNN-alpha and RNN-beta
                (128 by default, the value previously hard-coded)
        """
        super().__init__()

        # Define the RNN-alpha and RNN-beta using `nn.GRU()`; both read the raw events.
        self.rnn_a = nn.GRU(num_codes, hidden_dim, batch_first=True)
        self.rnn_b = nn.GRU(num_codes, hidden_dim, batch_first=True)
        # Define the alpha-attention (one weight per date) using `AlphaAttention()`.
        self.att_a = AlphaAttention(hidden_dim)
        # Define the beta-attention (one weight per events code) using `BetaAttention()`.
        self.att_b = BetaAttention(hidden_dim, num_codes)
        # The context vector has one entry per events code, so the projection starts from num_codes.
        self.fc = nn.Linear(num_codes, emb_size)
        # Define the final activation layer using `nn.Sigmoid()`.
        self.sigmoid = nn.Sigmoid()

    def attention_sum(self, alpha, beta, x, masks):
        """
        Performs the weighted sum of the events data using alpha and beta attention weights.
        It also sets to 0 the positions corresponding to dates without events data using the masks information.

        Arguments:
            alpha: the alpha attention weights of shape (batch_size, # of dates, 1)
            beta: the beta attention weights of shape (batch_size, # of dates, # of events codes)
            x: the events data for each date with shape (batch_size, # of dates, # of events codes)
            masks: the padding masks, either per date with shape (batch_size, # of dates)
                or per entry with the same shape as `x`

        Outputs:
            c: the context vector of shape (batch_size, # of events codes)
        """
        # Per-date masks get a trailing axis so they broadcast over the events codes;
        # masks that already have the same rank as `x` are used as-is.
        if masks.dim() < x.dim():
            masks = masks.unsqueeze(-1)
        # NOTE(review): in forward(), alpha is computed over all dates (padding included),
        # so after masking, the remaining date weights may sum to less than one.
        return torch.sum(beta * x * alpha * masks, dim=1)

    def forward(self, events, masks):
        """
        Arguments:
            events: the events data of shape (batch_size, # of dates, # of events codes)
            masks: the padding masks; see `attention_sum` for the accepted shapes

        Outputs:
            probs: sigmoid outputs of shape (batch_size, emb_size), or (batch_size,)
                when emb_size == 1
        """
        # Pass the events data through RNN-alpha and RNN-beta.
        g, _ = self.rnn_a(events)
        h, _ = self.rnn_b(events)
        # Obtain the alpha and beta attentions.
        alpha = self.att_a(g)
        beta = self.att_b(h)
        # Weighted sum of the events data over the dates that have events data.
        c = self.attention_sum(alpha, beta, events, masks)
        # Pass the context vector through the linear and activation layers.
        logits = self.fc(c)
        probs = self.sigmoid(logits)
        # squeeze(-1) instead of squeeze(): only a size-1 output axis is dropped,
        # so a batch of one keeps its batch dimension.
        return probs.squeeze(-1)