In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Attention(nn.Module):
    """Additive attention pooling over the time steps of a BiGRU output.

    A single linear layer scores every time step; the scores are normalised
    with a softmax over the sequence dimension and used to form a weighted
    sum of the time-step vectors.
    """

    def __init__(self, hidden_size: int):
        """Create the scoring layer.

        Args:
            hidden_size: Hidden size of ONE direction of the BiGRU. The
                layer expects inputs whose last dimension is
                ``hidden_size * 2``.
        """
        super().__init__()
        # Linear layer that gives a score for each token
        self.attention = nn.Linear(hidden_size * 2, 1)  # *2 because BiGRU is bidirectional

    def forward(self, gru_output: torch.Tensor, mask=None):
        """Pool ``gru_output`` into one context vector per sequence.

        Args:
            gru_output: Tensor of shape ``[batch, seq_len, hidden_size*2]``.
            mask: Optional tensor of shape ``[batch, seq_len]``; non-zero /
                ``True`` marks real tokens, zero / ``False`` marks padding.
                When omitted, every position takes part in the softmax.

        Returns:
            A tuple ``(context_vector, attn_weights)`` where
            ``context_vector`` has shape ``[batch, hidden_size*2]`` and
            ``attn_weights`` has shape ``[batch, seq_len, 1]``. With a mask,
            padded positions have weight exactly 0; a sequence whose
            positions are all padded gets all-zero weights and an all-zero
            context vector.

        Raises:
            ValueError: If ``mask`` is given and its shape is not
                ``[batch, seq_len]``.
        """
        # 1. Compute attention scores for each token
        scores = self.attention(gru_output)   # [batch, seq_len, 1]

        keep = None
        if mask is not None:
            if tuple(mask.shape) != tuple(gru_output.shape[:2]):
                raise ValueError(
                    "mask must have shape [batch, seq_len] = "
                    f"{tuple(gru_output.shape[:2])}, got {tuple(mask.shape)}"
                )
            keep = mask.to(device=scores.device, dtype=torch.bool).unsqueeze(-1)
            # Use the most negative finite value rather than -inf so that a
            # fully padded sequence yields a finite softmax instead of NaN.
            scores = scores.masked_fill(~keep, torch.finfo(scores.dtype).min)

        # 2. Convert scores into probabilities with softmax
        attn_weights = F.softmax(scores, dim=1)  # [batch, seq_len, 1]

        if keep is not None:
            # Force padded weights to exactly zero (this matters only for
            # fully padded sequences, where the softmax above is uniform).
            attn_weights = attn_weights.masked_fill(~keep, 0.0)

        # 3. Weighted sum of BiGRU outputs
        context_vector = torch.sum(attn_weights * gru_output, dim=1)  # [batch, hidden_size*2]

        return context_vector, attn_weights
