# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Attention(nn.Module):
    """Learned soft attention that pools a sequence into a single vector.

    A linear layer maps every timestep's feature vector to one scalar score.
    The scores are softmax-normalised along the sequence axis and used as
    weights for a sum over the timesteps.
    """

    def __init__(self, hidden_dim):
        """Create the scoring layer.

        Args:
            hidden_dim: Size of the last dimension of the sequences to pool.
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        # Produces one scalar relevance score per timestep.
        self.attn = nn.Linear(hidden_dim, 1)

    def forward(self, lstm_output, mask=None):
        """Pool ``lstm_output`` over its sequence axis.

        Args:
            lstm_output: Tensor of shape ``[batch_size, seq_len, hidden_dim]``.
            mask: Optional tensor of shape ``[batch_size, seq_len]`` (or
                ``[batch_size, seq_len, 1]``). Non-zero / ``True`` entries
                mark timesteps that may be attended to; zero / ``False``
                entries (e.g. padding) receive exactly zero weight. When
                omitted, every timestep participates.

        Returns:
            A tuple ``(context, attn_weights)`` where ``context`` has shape
            ``[batch_size, hidden_dim]`` and ``attn_weights`` has shape
            ``[batch_size, seq_len, 1]``. If every timestep of a row is
            masked out, that row's weights and context are all zeros.
        """
        scores = self.attn(lstm_output)  # [batch, seq_len, 1]

        if mask is None:
            attn_weights = F.softmax(scores, dim=1)
        else:
            keep = mask.to(device=scores.device, dtype=torch.bool)
            if keep.dim() == scores.dim() - 1:
                keep = keep.unsqueeze(-1)  # [batch, seq_len, 1]
            # Use the most negative finite value instead of -inf so that a
            # fully masked row yields a finite softmax (and finite gradients)
            # rather than NaN.
            scores = scores.masked_fill(~keep, torch.finfo(scores.dtype).min)
            # Zero masked positions explicitly; this also clears the uniform
            # weights a fully masked row would otherwise get.
            attn_weights = F.softmax(scores, dim=1).masked_fill(~keep, 0.0)

        context = torch.sum(attn_weights * lstm_output, dim=1)  # [batch, hidden_dim]
        return context, attn_weights

class LSTMWithAttention(nn.Module):
    """LSTM encoder whose per-timestep outputs are pooled by attention.

    Pipeline: batch-first ``nn.LSTM`` -> ``Attention`` pooling over the
    sequence axis -> linear projection to ``output_dim``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=1):
        """Build the recurrent encoder, the attention pooler and the head.

        Args:
            input_dim: Number of features per timestep of the input.
            hidden_dim: Hidden size of the LSTM (also the attention width).
            output_dim: Number of features produced by the final linear layer.
            num_layers: Number of stacked LSTM layers.
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )
        self.attention = Attention(hidden_dim)
        self.fc = nn.Linear(in_features=hidden_dim, out_features=output_dim)

    def forward(self, x):
        """Encode, pool and project a batch of sequences.

        Args:
            x: Tensor of shape ``[batch_size, seq_len, input_dim]``.

        Returns:
            A tuple ``(out, attn_weights)``: ``out`` is the projected context
            of shape ``[batch, output_dim]`` and ``attn_weights`` are the
            weights returned by the attention module.
        """
        # Only the per-timestep outputs are needed; the final (h, c) state
        # returned by the LSTM is discarded.
        sequence_states = self.lstm(x)[0]  # [batch, seq_len, hidden_dim]
        pooled, weights = self.attention(sequence_states)
        return self.fc(pooled), weights

# Example usage: run one random batch through the model and report shapes.
batch_size, seq_len = 16, 10
input_dim, hidden_dim, output_dim = 5, 32, 1

model = LSTMWithAttention(
    input_dim=input_dim,
    hidden_dim=hidden_dim,
    output_dim=output_dim,
)
x = torch.randn(batch_size, seq_len, input_dim)
output, attn_weights = model(x)

# Expected sizes: [16, 1] for the output, [16, 10, 1] for the weights.
print("Output shape:", output.shape)
print("Attention weights shape:", attn_weights.shape)
