<a href="https://colab.research.google.com/github/blackwithwhitegreen/Transformer/blob/main/Scaled_dot_product_attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Computes the Scaled Dot-Product Attention: softmax(Q K^T / sqrt(d_k)) V.

    Args:
    Q: Query tensor of shape (..., seq_len_q, d_k),
       e.g. (batch_size, num_heads, seq_len, d_k)
    K: Key tensor of shape (..., seq_len_k, d_k)
    V: Value tensor of shape (..., seq_len_k, d_v)
    mask: Optional tensor broadcastable to the score shape
          (..., seq_len_q, seq_len_k). Positions where ``mask == 0``
          (or ``False``) are excluded from attention.

    Returns:
    attention_output: Tensor of shape (..., seq_len_q, d_v).
    attention_weights: Tensor of shape (..., seq_len_q, seq_len_k). Each row
        sums to 1, except rows whose keys are all masked out: those rows are
        all zeros (and so is the matching output row) instead of NaN.
    """
    d_k = Q.shape[-1]  # Size of the query/key feature dimension

    # Scale by sqrt(d_k) with a plain Python scalar so no extra tensor has to
    # be allocated and the dtype/device of the scores follows Q and K.
    scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)

    fully_masked_rows = None
    if mask is not None:
        # Expand the mask to the score shape so that rows with every key
        # blocked can be detected regardless of how the mask was broadcast.
        blocked = torch.broadcast_to(mask == 0, scores.shape)
        scores = scores.masked_fill(blocked, float('-inf'))
        fully_masked_rows = blocked.all(dim=-1, keepdim=True)

    # Normalise the scores over the key dimension
    attention_weights = F.softmax(scores, dim=-1)

    if fully_masked_rows is not None:
        # softmax over a row of only -inf yields NaN; such a row has nothing
        # to attend to, so give it zero weight everywhere.
        attention_weights = attention_weights.masked_fill(fully_masked_rows, 0.0)

    # Weighted sum of the values
    attention_output = torch.matmul(attention_weights, V)

    return attention_output, attention_weights

# Demo: run the attention function once on random inputs and show the result.
batch_size, num_heads, seq_len = 2, 4, 5
d_k = d_v = 8  # Queries/keys and values share the same feature size here

# Random Q, K and V, drawn in that order
Q = torch.rand(batch_size, num_heads, seq_len, d_k)
K = torch.rand(batch_size, num_heads, seq_len, d_k)
V = torch.rand(batch_size, num_heads, seq_len, d_v)

# No mask is passed, so every query position attends to every key position
output, weights = scaled_dot_product_attention(Q=Q, K=K, V=V)
for label, tensor in (("Attention Output:", output), ("Attention Weights:", weights)):
    print(label, tensor)


Attention Output: tensor([[[[0.6575, 0.4717, 0.6834, 0.4144, 0.4921, 0.5801, 0.4806, 0.6286],
          [0.6395, 0.5116, 0.6673, 0.3996, 0.5182, 0.5957, 0.5000, 0.6414],
          [0.6393, 0.5015, 0.6601, 0.4059, 0.5172, 0.6047, 0.5085, 0.6425],
          [0.6362, 0.5159, 0.6602, 0.4004, 0.5258, 0.5979, 0.5063, 0.6440],
          [0.6536, 0.5078, 0.6739, 0.4173, 0.5130, 0.5761, 0.5193, 0.6266]],

         [[0.1816, 0.3636, 0.4597, 0.5507, 0.5844, 0.5570, 0.6119, 0.4455],
          [0.1843, 0.3706, 0.4432, 0.5678, 0.5881, 0.5430, 0.5984, 0.4326],
          [0.1797, 0.3676, 0.4445, 0.5620, 0.5730, 0.5429, 0.5983, 0.4416],
          [0.1716, 0.3562, 0.4660, 0.5344, 0.5626, 0.5581, 0.6177, 0.4634],
          [0.1823, 0.3715, 0.4456, 0.5656, 0.5704, 0.5465, 0.5972, 0.4411]],

         [[0.5588, 0.4885, 0.3971, 0.4725, 0.3751, 0.6071, 0.6803, 0.4354],
          [0.5640, 0.4768, 0.3827, 0.4844, 0.3849, 0.6001, 0.6903, 0.4322],
          [0.5638, 0.4759, 0.3867, 0.4712, 0.3845, 0.5974, 0.6837,