# Notes Module
We are going to implement an attention mechanism similar to RETAIN.

Alpha is a scalar attention weight that measures the importance of the notes from a particular day in the final output.
Beta is a vector of attention weights that measures the importance of each feature of the note embedding in the final output.

# Alpha Attention


In [None]:
class NotesAlphaAttention(torch.nn.Module):
    """Scalar (alpha) attention over the time steps of a notes sequence.

    Each hidden state of the alpha-RNN is projected to a single score, and the
    scores are normalized with a softmax over the sequence dimension, so the
    weights of each batch element sum to 1 along that dimension.
    """

    def __init__(self, hidden_dim):
        """
        Define the linear layer `self.a_att` for alpha-attention using `nn.Linear()`.

        Arguments:
            hidden_dim: the hidden dimension of the alpha-RNN output
        """
        super().__init__()
        self.a_att = nn.Linear(hidden_dim, 1)

    def forward(self, g):
        """
        Arguments:
            g: the output tensor from RNN-alpha of shape (batch_size, seq_length, hidden_dim)

        Outputs:
            alpha: the corresponding attention weights of shape (batch_size, seq_length, 1)
        """
        # One raw score per time step: (batch_size, seq_length, 1).
        weights = self.a_att(g)
        # Normalize across the sequence dimension. No mask is applied here, so
        # padded positions also receive weight; any masking has to happen
        # outside this module.
        alpha = torch.softmax(weights, dim=1)
        return alpha

# Beta Attention

In [None]:
class NotesBetaAttention(torch.nn.Module):
    """Vector (beta) attention over the features of each time step.

    Each hidden state of the beta-RNN is passed through a square linear layer
    followed by tanh, giving one weight in [-1, 1] per feature and time step.
    """

    def __init__(self, hidden_dim):
        """
        Define the linear layer `self.b_att` for beta-attention using `nn.Linear()`.

        Arguments:
            hidden_dim: the hidden dimension of the beta-RNN output
        """
        super().__init__()
        self.b_att = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, h):
        """
        Arguments:
            h: the output tensor from RNN-beta of shape (batch_size, seq_length, hidden_dim)

        Outputs:
            beta: the corresponding attention weights of shape (batch_size, seq_length, hidden_dim)
        """
        weights = self.b_att(h)
        # tanh (not softmax): features are weighted independently and may be
        # negative, as in RETAIN.
        beta = torch.tanh(weights)
        return beta

# NotesRNN

In [None]:
class NotesRNN(nn.Module):
    """RETAIN-style attention model over a sequence of note embeddings.

    Two GRUs read the same input sequence. The output of `rnn_a` drives the
    scalar alpha attention (one weight per time step). The output of `rnn_b`
    drives the vector beta attention (one weight per embedding feature). The
    attended inputs are summed over time, projected to `hidden_dim` and passed
    through a sigmoid.
    """

    def __init__(self, hidden_dim=128, notes_emb_size=200):
        """
        Arguments:
            hidden_dim: size of the output vector produced by `self.fc`
            notes_emb_size: size of each note embedding; also the hidden size
                of both GRUs and of both attention layers
        """
        super().__init__()
        self.rnn_a = nn.GRU(notes_emb_size, notes_emb_size, batch_first=True)
        self.rnn_b = nn.GRU(notes_emb_size, notes_emb_size, batch_first=True)
        self.att_a = NotesAlphaAttention(notes_emb_size)
        self.att_b = NotesBetaAttention(notes_emb_size)
        self.fc = nn.Linear(notes_emb_size, hidden_dim)
        self.sigmoid = nn.Sigmoid()

    def attention_sum(self, alpha, beta, x, masks):
        """
        Weight the inputs by both attentions and the mask, then sum over time.

        Arguments:
            alpha: the alpha attention weights of shape (batch_size, seq_length, 1)
            beta: the beta attention weights of shape (batch_size, seq_length, notes_emb_size)
            x: the note embeddings of shape (batch_size, seq_length, notes_emb_size)
            masks: padding mask that must broadcast against `x`, presumably
                (batch_size, seq_length, 1) with 1 for real notes and 0 for
                padding -- TODO confirm against the data loader

        Outputs:
            c: the context vector of shape (batch_size, notes_emb_size)
        """
        return torch.sum(x * alpha * beta * masks, dim=1)

    def forward(self, x, masks):
        """
        Arguments:
            x: note embeddings of shape (batch_size, seq_length, notes_emb_size);
                the GRUs are built with batch_first=True
            masks: padding mask that must broadcast against `x`

        Outputs:
            Sigmoid outputs of shape (batch_size, hidden_dim), or (batch_size,)
            when hidden_dim == 1.
        """
        g, _ = self.rnn_a(x)
        h, _ = self.rnn_b(x)
        alpha = self.att_a(g)
        beta = self.att_b(h)
        c = self.attention_sum(alpha, beta, x, masks)
        logits = self.fc(c)
        probs = self.sigmoid(logits)
        # Squeeze only the trailing feature dim. A bare `.squeeze()` would also
        # drop the batch dim when batch_size == 1.
        return probs.squeeze(-1)