# Events Module
We are going to implement an attention mechanism similar to RETAIN.

alpha will be a scalar attention weight that measures the importance of the events of a particular day in the final output.
beta will be a vector of attention weights which will measure the importance of each type of event (feature) in the final output.

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

## Alpha Attention

In [4]:
class AlphaAttention(torch.nn.Module):
    """Scalar (per-time-step) attention head.

    Projects each hidden state to a single score and normalizes the scores
    with a softmax taken along dimension 1 of the input.
    """

    def __init__(self, hidden_dim):
        """Create the scoring layer.

        Arguments:
            hidden_dim: size of the last dimension of the tensors passed to
                `forward`.
        """
        super().__init__()
        # One score per time step: hidden_dim -> 1.
        self.a_att = nn.Linear(hidden_dim, 1)

    def forward(self, g):
        """Compute the alpha attention weights.

        Arguments:
            g: tensor whose last dimension is `hidden_dim`; expected to be the
                RNN-alpha output of shape (batch_size, seq_length, hidden_dim).

        Outputs:
            alpha: attention weights of shape (batch_size, seq_length, 1) that
                sum to 1 along dimension 1.
        """
        scores = self.a_att(g)
        # Normalize across dimension 1 (the sequence axis for the shape above).
        alpha = torch.softmax(scores, dim=1)
        return alpha

## Beta attention

In [5]:
class BetaAttention(torch.nn.Module):
    """Vector (per-feature) attention head.

    Applies a linear projection followed by `tanh`, producing one weight in
    (-1, 1) per output feature for every position of the input.
    """

    def __init__(self, input_dim, emb_dim):
        """Create the projection layer.

        Arguments:
            input_dim: size of the last dimension of the tensors passed to
                `forward`.
            emb_dim: size of the last dimension of the returned weights.
        """
        super().__init__()
        self.b_att = nn.Linear(input_dim, emb_dim)

    def forward(self, h):
        """Compute the beta attention weights.

        Arguments:
            h: tensor whose last dimension is `input_dim`; expected to be the
                RNN-beta output of shape (batch_size, seq_length, input_dim).

        Outputs:
            beta: attention weights of shape (batch_size, seq_length, emb_dim).
        """
        projected = self.b_att(h)
        beta = torch.tanh(projected)
        return beta

## Events RNN with Attention (similar to RETAIN)

In [9]:
class EventsRNN(nn.Module):
    """RETAIN-style attention model over a sequence of event feature vectors.

    Two GRUs read the raw event features. The output of the first feeds a
    scalar (alpha) attention over time steps, the output of the second feeds
    a per-feature (beta) attention. The attention-weighted, mask-filtered
    features are summed over time, passed through a linear layer and a
    sigmoid.
    """

    def __init__(self, num_codes, emb_size=128, hidden_size=128):
        """Build the recurrent, attention and output layers.

        Arguments:
            num_codes: number of features per time step (size of the last
                dimension of `events`).
            emb_size: size of the last dimension of the model output.
            hidden_size: hidden size of both GRUs; defaults to 128, the value
                that was previously hard-coded.
        """
        super().__init__()

        # RNN-alpha and RNN-beta consume the raw features directly (there is
        # no embedding layer in front of them).
        self.rnn_a = nn.GRU(num_codes, hidden_size, batch_first=True)
        self.rnn_b = nn.GRU(num_codes, hidden_size, batch_first=True)
        # Alpha yields one weight per time step; beta yields one weight per
        # input feature so it can be multiplied element-wise with `events`.
        self.att_a = AlphaAttention(hidden_size)
        self.att_b = BetaAttention(hidden_size, num_codes)
        # The context vector lives in feature space (num_codes wide).
        self.fc = nn.Linear(num_codes, emb_size)
        self.sigmoid = nn.Sigmoid()

    def attention_sum(self, alpha, beta, x, masks):
        """Sum the attention-weighted features over the non-padded time steps.

        Arguments:
            alpha: alpha attention weights of shape (batch_size, seq_length, 1)
            beta: beta attention weights, same shape as `x`
            x: input features of shape (batch_size, seq_length, num_features)
            masks: padding mask of shape (batch_size, seq_length); positions
                holding 0 are excluded from the sum

        Outputs:
            c: the context vector of shape (batch_size, num_features)
        """
        # Add a trailing axis so the mask broadcasts across the feature axis.
        masks = masks.unsqueeze(-1)
        return torch.sum(beta * x * alpha * masks, dim=1)

    def forward(self, events, masks):
        """Run the model on a batch of event sequences.

        Arguments:
            events: features of shape (batch_size, seq_length, num_codes)
            masks: padding mask of shape (batch_size, seq_length)

        Outputs:
            Sigmoid activations of shape (batch_size, emb_size); when
            `emb_size` is 1 the trailing axis is dropped, giving (batch_size,).
        """
        # 1. Pass the events through RNN-alpha and RNN-beta separately.
        g, _ = self.rnn_a(events)
        h, _ = self.rnn_b(events)
        # 2. Obtain the alpha and beta attention weights.
        alpha = self.att_a(g)
        beta = self.att_b(h)
        # 3. Collapse the time axis into a single context vector.
        c = self.attention_sum(alpha, beta, events, masks)
        # 4. Pass the context vector through the linear and activation layers.
        logits = self.fc(c)
        probs = self.sigmoid(logits)
        # Squeeze only the last axis: a bare `squeeze()` would also remove the
        # batch axis whenever batch_size == 1.
        return probs.squeeze(-1)