<a href="https://colab.research.google.com/github/hoyongjungdev/pda-classification/blob/main/RETAIN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# RETAIN model with GRU
class RETAIN(nn.Module):
    """RETAIN-style two-level attention classifier built on GRUs.

    Each time step is embedded with a linear layer. Two GRUs then run over the
    embedded sequence in reverse time order and produce:

    * a visit-level attention weight (``alpha``, softmax-normalised over time);
    * a variable-level attention vector (``beta``, squashed with ``tanh``).

    Their product re-weights the embeddings. The result is summed over time
    into a context vector, which is classified.

    ``forward`` treats dimension 0 of its input as the time axis: it flips,
    softmaxes and sums along dim 0. The GRUs use the default
    ``batch_first=False``, so dimension 1 is the batch axis. The input is
    therefore expected to be ``(seq_len, batch, num_embeddings)``.

    Args:
        params: configuration dict. Keys read here: ``"device"``,
            ``"num_embeddings"``, ``"embedding_dim"``,
            ``"visit_rnn_hidden_size"``, ``"visit_rnn_output_size"``,
            ``"visit_attn_output_size"``, ``"var_rnn_hidden_size"``,
            ``"var_rnn_output_size"``, ``"var_attn_output_size"``,
            ``"dropout_p"``, ``"output_dropout_p"``,
            ``"embedding_output_size"``, ``"num_class"``.
    """

    def __init__(self, params: dict):
        super().__init__()
        self.device = params["device"]

        # 1. Embedding: linear projection of each time step's features.
        self.emb_layer = nn.Linear(
            in_features=params["num_embeddings"],
            out_features=params["embedding_dim"],
        )

        # 2. Visit-level attention (alpha).
        # The GRU is fed the embeddings, so its input size is embedding_dim.
        # Its hidden size (= output feature size) is what the attention layer
        # reads.
        self.visit_level_rnn = nn.GRU(
            params["embedding_dim"], params["visit_rnn_output_size"]
        ).to(self.device)
        # Kept for backward compatibility. It is not the GRU's hidden size
        # (that is visit_rnn_output_size) and forward() does not use it.
        self.visit_hidden_size = params["visit_rnn_hidden_size"]
        self.visit_level_attention = nn.Linear(
            params["visit_rnn_output_size"], params["visit_attn_output_size"]
        )  # α (scalar per step when visit_attn_output_size == 1)

        # 3. Variable-level attention (beta).
        self.variable_level_rnn = nn.GRU(
            params["embedding_dim"], params["var_rnn_output_size"]
        ).to(self.device)
        # Kept for backward compatibility; see visit_hidden_size above.
        self.var_hidden_size = params["var_rnn_hidden_size"]
        self.variable_level_attention = nn.Linear(
            params["var_rnn_output_size"], params["var_attn_output_size"]
        )  # β (vector)

        # etc
        self.dropout = nn.Dropout(params["dropout_p"])
        self.output_dropout = nn.Dropout(params["output_dropout_p"])
        self.output_layer = nn.Linear(
            params["embedding_output_size"], params["num_class"]
        )
        # num_layers * num_directions of the GRUs. Kept for compatibility;
        # forward() lets nn.GRU build its own zero initial state.
        self.d = 1

    def forward(self, input):
        """Compute class probabilities for a batch of sequences.

        Args:
            input: tensor laid out as ``(seq_len, batch, num_embeddings)``
                (dim 0 is time; see the class docstring).

        Returns:
            ``(batch, num_class)`` tensor of class probabilities; softmax is
            applied here. NOTE(review): ``nn.CrossEntropyLoss`` expects raw
            logits, so confirm the training loop does not apply it to these
            probabilities.
        """
        # 1. Embedding
        v = self.dropout(self.emb_layer(input))

        # RETAIN runs both RNNs backwards in time. Flip once and reuse.
        v_reversed = torch.flip(v, [0])

        # 2. Visit-level attention.
        # No h0 is passed, so nn.GRU starts from zeros shaped
        # (num_layers, batch, hidden), with batch taken from input dim 1.
        # A hand-built h0 sized from input.size(0) would use the time length
        # as the batch size instead.
        visit_rnn_output, _ = self.visit_level_rnn(v_reversed)
        alpha = self.visit_level_attention(torch.flip(visit_rnn_output, [0]))
        visit_attn_w = F.softmax(alpha, dim=0)  # normalise over time

        # 3. Variable-level attention
        var_rnn_output, _ = self.variable_level_rnn(v_reversed)
        beta = self.variable_level_attention(torch.flip(var_rnn_output, [0]))
        var_attn_w = torch.tanh(beta)

        # 4. Context vector: attention-weighted embeddings summed over time.
        attn_w = visit_attn_w * var_attn_w
        c = self.output_dropout(torch.sum(attn_w * v, dim=0))

        # 5. Prediction
        output = self.output_layer(c)
        return F.softmax(output, dim=1)

In [None]:
# parameters
def init_params(params: dict):
    """Fill ``params`` in place with the default RETAIN hyper-parameters.

    Keys that already exist are overwritten; any other keys are left as they
    are. Nothing is returned.

    The ``"device"`` entry is read from a module-level ``device`` name, so that
    name must be defined before this function is called.
    """
    # (key, value) pairs, assigned in exactly this order.
    defaults = (
        # embedding matrix
        ("num_embeddings", 28),  # input dimension
        ("embedding_dim", 128),
        # embedding dropout
        ("dropout_p", 0.5),
        # Alpha (scalar): visit-level attention
        ("visit_rnn_hidden_size", 128),
        ("visit_rnn_output_size", 128),
        ("visit_attn_output_size", 1),
        # Beta (vector): variable-level attention
        ("var_rnn_hidden_size", 128),
        ("var_rnn_output_size", 128),
        ("var_attn_output_size", 128),
        # classifier head
        ("embedding_output_size", 128),
        ("num_class", 2),  # 0 or 1
        ("output_dropout_p", 0.8),
    )
    for key, value in defaults:
        params[key] = value
    # The global is looked up last, at call time.
    params["device"] = device