# Scaled Dot-Product Attention in PyTorch

The basic idea is to compute the attention scores between a set of **queries (Q)** and **keys (K)** and use these scores to weight a set of **values (V)**. 


In [1]:
import torch
import torch.nn.functional as F
import torch.nn as nn

### Scaled Dot-Product Attention

$$
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V
$$

Where:
- $Q$ is the query matrix
- $K$ is the key matrix
- $V$ is the value matrix
- $d_k$ is the dimension of the key vectors (used for scaling)
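The reason for the $\sqrt{d_k}$ factor can be checked numerically. This illustration assumes the entries of the queries and keys are independent with zero mean and unit variance. Under that assumption each dot product is a sum of $d_k$ unit-variance terms, so its variance is about $d_k$. Dividing by $\sqrt{d_k}$ brings the variance back to about 1, which helps keep the softmax out of its saturated region. A minimal sketch (the value `d_k = 64` is an arbitrary choice for the demo):

```python
import torch

torch.manual_seed(0)

d_k = 64            # illustrative key/query dimension (assumption for this demo)
n_samples = 100_000

# Independent query/key vectors with zero-mean, unit-variance entries (assumption)
q = torch.randn(n_samples, d_k)
k = torch.randn(n_samples, d_k)

# Raw dot products: a sum of d_k unit-variance terms, so the variance grows like d_k
raw = (q * k).sum(dim=-1)
# Scaled dot products: dividing by sqrt(d_k) brings the variance back to about 1
scaled = raw / d_k ** 0.5

print(f"variance of raw scores:    {raw.var().item():.1f}")     # roughly 64
print(f"variance of scaled scores: {scaled.var().item():.2f}")  # roughly 1
```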

The softmax normalizes each row of scaled scores into a probability distribution over the keys, and the output is the weighted sum of the values $V$ under those probabilities.
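Concretely, the softmax is applied along the key dimension, so each query's row of scores becomes non-negative weights that sum to 1. The scores below are made-up numbers for illustration: the first row favours the first key, and the second row (all scores equal) spreads its weight evenly.

```python
import torch
import torch.nn.functional as F

# Hypothetical raw scores: 2 queries, each scored against 4 keys
scores = torch.tensor([[2.0, 1.0, 0.1, -1.0],
                       [0.5, 0.5, 0.5, 0.5]])

# Softmax along the last dimension (the keys): exp(s_j) / sum_j exp(s_j)
weights = F.softmax(scores, dim=-1)
# The same computation written out by hand, for comparison
manual = scores.exp() / scores.exp().sum(dim=-1, keepdim=True)

print(weights)
print(weights.sum(dim=-1))  # each row sums to 1 (up to floating-point rounding)
```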

### Steps to Implement
1. **Compute the dot product of queries and keys**: this gives the raw attention scores $QK^T$.
2. **Scale the scores**: divide by $\sqrt{d_k}$.
3. **Apply softmax**: convert the raw scores into probabilities. Optionally, mask out positions first by setting their scores to $-\infty$.
4. **Compute the weighted sum of the values**: multiply the attention weights by the values $V$.
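The four steps can be sketched as a plain function before wrapping them in an `nn.Module`. The helper name `attention_sketch` is hypothetical, it omits masking, and the input sizes reuse the notebook's example values (`batch_size = 2`, `seq_len = 5`, `d_k = d_v = 8`):

```python
import torch
import torch.nn.functional as F

def attention_sketch(Q, K, V):
    """Hypothetical helper: the four steps of scaled dot-product attention (no mask)."""
    d_k = Q.size(-1)
    # 1. Dot product of queries and keys: (batch, seq_len, d_k) @ (batch, d_k, seq_len)
    scores = Q @ K.transpose(-2, -1)      # -> (batch, seq_len, seq_len)
    # 2. Scale by sqrt(d_k)
    scores = scores / d_k ** 0.5
    # 3. Softmax over the key dimension turns each row into probabilities
    weights = F.softmax(scores, dim=-1)
    # 4. Weighted sum of the values
    output = weights @ V                  # -> (batch, seq_len, d_v)
    return output, weights

torch.manual_seed(0)
Q = torch.rand(2, 5, 8)
K = torch.rand(2, 5, 8)
V = torch.rand(2, 5, 8)
out, w = attention_sketch(Q, K, V)
print(out.shape, w.shape)  # torch.Size([2, 5, 8]) torch.Size([2, 5, 5])
```

In PyTorch 2.0 and later, `torch.nn.functional.scaled_dot_product_attention(Q, K, V)` computes the same output (it does not return the weights), which makes it a convenient cross-check for a hand-written version.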


#### Requirements

**batch_size**:
   - **Definition**: The number of sequences (or examples) processed together in one forward pass through the model.
   - Batching takes advantage of vectorization (parallel processing) and makes training more efficient.

**seq_len** (Sequence Length):
   - **Definition**: The length of each input sequence (i.e., the number of tokens or words in each sequence).

**d_k** (Dimension of the Key and Query vectors):
   - **Definition**: The dimensionality of the **key** and **query** vectors in the attention mechanism. Queries and keys must share this dimension so that their dot product is defined.

**d_v** (Dimension of the Value vectors):
   - **Definition**: The dimensionality of the **value** vectors. If **d_v = 128**, each value vector has 128 dimensions.


---



##### Input:
- **Q**: $(\text{batch\_size}, \text{seq\_len}, d_k)$  
- **K**: $(\text{batch\_size}, \text{seq\_len}, d_k)$  
- **V**: $(\text{batch\_size}, \text{seq\_len}, d_v)$  
- **Mask** (optional): $(\text{batch\_size}, \text{seq\_len}, \text{seq\_len})$, with 1 where attention is allowed and 0 where it is blocked
##### Output:
- **Attention Output**: $(\text{batch\_size}, \text{seq\_len}, d_v)$  
- **Attention Weights**: $(\text{batch\_size}, \text{seq\_len}, \text{seq\_len})$



### Performance
##### Time Complexity:
$$
O(\text{batch\_size} \times \text{seq\_len}^2 \times (d_k + d_v))
$$
##### Space Complexity:
$$
O(\text{batch\_size} \times \text{seq\_len}^2 + \text{batch\_size} \times \text{seq\_len} \times d_v)
$$
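Plugging numbers into the time and space formulas makes the quadratic dependence on `seq_len` concrete. The helper below (a hypothetical name) counts multiply-accumulate operations for the two matrix products, plus the number of elements in the attention weights and output. It ignores the cheaper scaling, masking, and softmax steps, and it does not count the inputs.

```python
def attention_cost(batch_size, seq_len, d_k, d_v):
    """Hypothetical helper: multiply-accumulates and stored elements for one attention call."""
    # QK^T: each of the seq_len x seq_len scores is a d_k-length dot product
    qk_macs = batch_size * seq_len * seq_len * d_k
    # weights @ V: each of the seq_len x d_v outputs sums over seq_len weighted values
    av_macs = batch_size * seq_len * seq_len * d_v
    # Memory: attention weights plus the output tensor (inputs not counted)
    weight_elems = batch_size * seq_len * seq_len
    output_elems = batch_size * seq_len * d_v
    return qk_macs + av_macs, weight_elems + output_elems

print(attention_cost(2, 5, 8, 8))   # (800, 130)
print(attention_cost(2, 10, 8, 8))  # (3200, 360)
```

Doubling `seq_len` from 5 to 10 quadruples the multiply-accumulate count (800 to 3200). The attention-weight storage also quadruples (50 to 200 elements), while the output only doubles (80 to 160).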






In [2]:
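class ScaledDotProductAttention(nn.Module):
    def __init__(self, d_k):
        super().__init__()
        # The dimension of the key vectors, used for scaling
        self.d_k = d_k

    def forward(self, Q, K, V, mask=None):
        # Step 1: Calculate dot product of Q and K -> (batch_size, seq_len, seq_len)
        scores = torch.matmul(Q, K.transpose(-2, -1))
        # Step 2: Scale the scores by sqrt(d_k)
        scores = scores / (self.d_k ** 0.5)
        # Step 3: Apply mask (optional): positions where mask == 0 cannot be attended to
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))
        # Step 4: Apply softmax over the key dimension to get attention weights
        attention_weights = F.softmax(scores, dim=-1)
        # Step 5: Compute the weighted sum of values (V) -> (batch_size, seq_len, d_v)
        output = torch.matmul(attention_weights, V)
        return output, attention_weights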

# Preparing Inputs

In [4]:
batch_size = 2
seq_len = 5
d_k = 8  
d_v = 8  

# Generate random Q, K, V tensors as inputs to the attention layer
Q = torch.rand(batch_size, seq_len, d_k)  
K = torch.rand(batch_size, seq_len, d_k) 
V = torch.rand(batch_size, seq_len, d_v) 

print(Q)
print(K)
print(V)

tensor([[[0.7883, 0.9453, 0.2390, 0.1901, 0.9819, 0.0550, 0.4194, 0.5251],
         [0.5439, 0.6290, 0.0464, 0.3899, 0.7046, 0.4846, 0.9321, 0.3274],
         [0.7153, 0.2226, 0.5617, 0.3809, 0.7121, 0.5881, 0.5029, 0.8936],
         [0.5056, 0.2150, 0.2597, 0.1637, 0.6216, 0.4101, 0.7122, 0.0269],
         [0.8801, 0.7427, 0.7813, 0.8621, 0.0157, 0.6323, 0.7689, 0.7215]],

        [[0.9600, 0.5527, 0.8281, 0.0318, 0.5924, 0.4078, 0.0745, 0.1843],
         [0.2264, 0.9531, 0.4066, 0.8108, 0.9573, 0.3994, 0.1344, 0.1221],
         [0.8385, 0.7231, 0.6895, 0.0058, 0.3018, 0.0113, 0.7199, 0.3665],
         [0.7187, 0.4544, 0.8070, 0.8227, 0.3313, 0.6501, 0.3812, 0.9185],
         [0.7793, 0.1678, 0.6792, 0.6064, 0.7477, 0.3379, 0.4711, 0.1097]]])
tensor([[[6.6788e-01, 6.0974e-01, 9.0541e-01, 5.4524e-01, 7.9361e-01,
          5.9153e-01, 7.7298e-01, 9.0083e-01],
         [1.1210e-01, 5.7324e-01, 8.6648e-02, 7.2429e-01, 8.1731e-01,
          3.0482e-01, 5.7503e-01, 4.0397e-01],
         [9.

# Optional Masking
No masking in this example: the mask is all ones, so every query can attend to every key.

In [5]:
mask = torch.ones(batch_size, seq_len, seq_len)  
print(mask)


tensor([[[1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.]],

        [[1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.]]])
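For contrast with the all-ones mask above, here is a mask that actually blocks positions. The example assumes the same convention as this notebook's `mask` (1 = may attend, 0 = blocked). It uses a causal, lower-triangular pattern, which is one common choice in which each position sees only itself and earlier positions. Masked scores are set to $-\infty$ before the softmax so that they receive zero weight.

```python
import torch
import torch.nn.functional as F

batch_size, seq_len, d_k = 2, 5, 8
torch.manual_seed(0)
Q = torch.rand(batch_size, seq_len, d_k)
K = torch.rand(batch_size, seq_len, d_k)

# Causal mask: 1 where query i may attend to key j (j <= i), 0 elsewhere
causal_mask = torch.tril(torch.ones(seq_len, seq_len)).expand(batch_size, seq_len, seq_len)

scores = Q @ K.transpose(-2, -1) / d_k ** 0.5
# Positions where the mask is 0 get -inf, so softmax assigns them zero weight
scores = scores.masked_fill(causal_mask == 0, float("-inf"))
weights = F.softmax(scores, dim=-1)

print(weights[0])  # upper triangle is all zeros; each row still sums to 1
```

Every row keeps at least its diagonal entry, so no row is fully masked. A row of all $-\infty$ would make the softmax produce NaNs.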


# Initialize the attention layer

In [None]:

attention = ScaledDotProductAttention(d_k)


ScaledDotProductAttention()


# Compute the attention output and weights

In [8]:

output, attention_weights = attention(Q, K, V, mask)

# print("Attention Output Shape:", output.shape)
print("Attention Output:", output)
# print("Attention Weights Shape:", attention_weights.shape)
print("Attention Weights:", attention_weights)
