In [1]:
import torch
import torch.nn as nn


In [2]:
# Simplified additive attention example
class Attention(nn.Module):
    """Additive attention that scores a decoder state against every encoder output.

    For each encoder position ``t`` the score is
    ``v . tanh(W [hidden; encoder_outputs[:, t]] + b)``. The scores are then
    normalized with a softmax over the sequence dimension.

    Args:
        hidden_dim: Size of the decoder hidden state and of each encoder output.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        # Projects the concatenated [hidden; encoder_output] pair (2 * hidden_dim)
        # back down to hidden_dim before the tanh non-linearity.
        self.attn = nn.Linear(hidden_dim * 2, hidden_dim)
        # Learned scoring vector that collapses each energy vector to a scalar.
        # Drawn from U[0, 1) via torch.rand.
        self.v = nn.Parameter(torch.rand(hidden_dim))

    def forward(self, hidden, encoder_outputs, mask=None):
        """Return attention weights over ``encoder_outputs``.

        Args:
            hidden: Decoder state, shape (batch, hidden_dim).
            encoder_outputs: Encoder states, shape (batch, seq_len, hidden_dim).
            mask: Optional (batch, seq_len) tensor. Positions where it is
                False/0 (e.g. padding) receive zero weight. If every position
                in a row is masked, that row falls back to uniform weights
                instead of NaN. ``None`` (default) attends to every position.

        Returns:
            Tensor of shape (batch, seq_len) whose rows each sum to 1.
        """
        seq_len = encoder_outputs.size(1)
        # expand() broadcasts the decoder state across time as a view, with no
        # copy. repeat() materialized seq_len copies, and cat() below allocates
        # the combined tensor anyway.
        hidden = hidden.unsqueeze(1).expand(-1, seq_len, -1)
        energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim=2)))
        scores = torch.sum(self.v * energy, dim=2)  # (batch, seq_len)
        if mask is not None:
            # Use the most negative finite value rather than -inf, so that a
            # fully masked row yields uniform weights instead of NaN.
            scores = scores.masked_fill(~mask.bool(), torch.finfo(scores.dtype).min)
        return torch.softmax(scores, dim=1)


In [3]:
# Example run. A fixed seed makes these random inputs, and the layer
# initialization in the next cell, reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

batch_size, seq_len, hidden_dim = 2, 5, 8
# Random stand-ins for real encoder outputs and a decoder hidden state.
encoder_outputs = torch.randn(batch_size, seq_len, hidden_dim)  # (batch, seq_len, hidden_dim)
hidden = torch.randn(batch_size, hidden_dim)  # (batch, hidden_dim)


In [4]:
attn = Attention(hidden_dim)
weights = attn(hidden, encoder_outputs)

# Sanity checks: one weight per encoder position, and each row is a
# probability distribution over the sequence.
assert weights.shape == (batch_size, seq_len)
assert torch.allclose(weights.sum(dim=1), torch.ones(batch_size))

print("Attention Weights:", weights)

Attention Weights: tensor([[0.1273, 0.1194, 0.3624, 0.2655, 0.1255],
        [0.1056, 0.2921, 0.2828, 0.1090, 0.2105]], grad_fn=<SoftmaxBackward0>)
