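### Attention mechanism in PyTorch

Employing an attention mechanism in a neural network with PyTorch involves creating a layer that computes attention scores over the inputs and applies them to the inputs, here by using the normalized scores as weights in a weighted sum.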

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F



In [3]:
class Attention(nn.Module):
    """Additive attention pooling that collapses a sequence into one vector.

    Each position ``x`` is scored as ``v . tanh(W x + b)``, where ``W, b`` are the
    parameters of ``attention_layer`` and ``v`` is the learned ``context_vector``.
    The scores are softmax-normalized over the sequence axis and used to take a
    weighted sum of the inputs.

    Args:
        input_dim: Size of the last (feature) dimension of the inputs.
        attention_dim: Size of the hidden projection used for scoring.
    """

    def __init__(self, input_dim, attention_dim):
        super().__init__()
        self.attention_layer = nn.Linear(input_dim, attention_dim)
        # Learned scoring vector, initialized from a standard normal.
        self.context_vector = nn.Parameter(torch.randn(attention_dim))

    def forward(self, inputs, mask=None):
        """Pool ``inputs`` over the sequence axis.

        Args:
            inputs: Tensor of shape ``(..., seq_len, input_dim)``. Any number of
                leading batch dimensions is accepted, including none.
            mask: Optional tensor broadcastable to ``(..., seq_len)``. Positions
                where it is False/0 (e.g. padding) receive zero attention weight.

        Returns:
            ``(weighted_sum, attention_weights)`` with shapes ``(..., input_dim)``
            and ``(..., seq_len)``. A sequence whose positions are all masked gets
            all-zero weights and a zero weighted sum (instead of NaN).
        """
        # Score every position: (..., seq_len, attention_dim) -> (..., seq_len)
        attention_scores = torch.tanh(self.attention_layer(inputs))
        attention_scores = torch.matmul(attention_scores, self.context_vector)

        invalid = None
        if mask is not None:
            invalid = ~mask.bool()
            # -inf scores become exactly zero weight after the softmax.
            attention_scores = attention_scores.masked_fill(invalid, float("-inf"))

        # Normalize over the sequence axis, which is always the last axis of the
        # scores; a fixed dim=1 would only be correct for 3-D inputs.
        attention_weights = F.softmax(attention_scores, dim=-1)

        if invalid is not None:
            # A fully masked row softmaxes over all -inf and yields NaN; zero it.
            # For partially masked rows this is a no-op (weights are already 0).
            attention_weights = attention_weights.masked_fill(invalid, 0.0)

        # (..., 1, seq_len) @ (..., seq_len, input_dim) -> (..., input_dim)
        weighted_sum = torch.matmul(attention_weights.unsqueeze(-2), inputs).squeeze(-2)

        return weighted_sum, attention_weights



In [4]:
# Example usage: a small random batch to exercise the layer.
# Seed first so these inputs, and any random initialization performed after this
# cell (e.g. creating the attention layer), are reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

batch_size = 2
seq_len = 5
input_dim = 10
attention_dim = 8

# Create a batch of input sequences: (batch_size, seq_len, input_dim)
inputs = torch.randn(batch_size, seq_len, input_dim)



In [5]:
# Initialize and apply the attention layer
attention_layer = Attention(input_dim, attention_dim)
weighted_sum, attention_weights = attention_layer(inputs)

# Sanity-check the result: one pooled vector per sequence, and attention weights
# that form a probability distribution over each sequence's positions.
assert weighted_sum.shape == (batch_size, input_dim)
assert attention_weights.shape == (batch_size, seq_len)
assert torch.allclose(attention_weights.sum(dim=-1), torch.ones(batch_size))

print("Weighted sum:", weighted_sum)
print("Attention weights:", attention_weights)

Weighted sum: tensor([[ 0.9345, -0.7757, -0.9824, -0.9015,  0.1070,  0.6261,  0.0676,  0.4808,
          0.6231,  0.4690],
        [ 0.2006, -0.9407,  0.0216,  0.5492, -0.9414, -0.1283, -0.0565, -0.7475,
         -0.1725,  0.0123]], grad_fn=<SqueezeBackward1>)
Attention weights: tensor([[0.4550, 0.0287, 0.1501, 0.2764, 0.0899],
        [0.1424, 0.6182, 0.0385, 0.1193, 0.0817]], grad_fn=<SoftmaxBackward0>)
