<a href="https://colab.research.google.com/github/jchen8000/DemystifyingLLMs/blob/main/3_Transformer/Scaled_Dot_Product_Attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 3.6 Scaled Dot-Product Attention

In [11]:
import torch
import torch.nn as nn
import numpy as np

class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention: softmax(Q @ K^T / sqrt(d_k)) @ V.

    The module has no learnable parameters. Any leading (batch/head)
    dimensions are handled by ``torch.matmul`` broadcasting; only the last
    two axes of Q, K and V are treated as (sequence, features).
    """

    def __init__(self):
        super().__init__()

    def forward(self, Q, K, V, mask=None):
        """Compute attention output, attention weights and raw scores.

        Args:
            Q: query tensor, expected shape (..., L_q, d_k).
            K: key tensor, expected shape (..., L_k, d_k).
            V: value tensor, expected shape (..., L_k, d_v).
            mask: optional boolean tensor broadcastable to (..., L_q, L_k).
                ``True`` marks a key position the query must NOT attend to.

        Returns:
            Tuple ``(attention, weights, scores)``:
              * attention -- (..., L_q, d_v), ``weights @ V``.
              * weights   -- (..., L_q, L_k), softmax over the key axis.
                A query row whose keys are all masked gets all-zero weights
                (instead of NaN), so its attention output is zero.
              * scores    -- (..., L_q, L_k), scaled logits with masked
                positions set to -inf.
        """
        d_k = K.size(-1)
        scores = torch.matmul(Q, K.transpose(-1, -2)) / d_k ** 0.5
        if mask is not None:
            # Out-of-place fill: clearer intent and no in-place mutation of an
            # autograd intermediate.
            scores = scores.masked_fill(mask, float('-inf'))
        weights = torch.softmax(scores, dim=-1)
        if mask is not None:
            # A row that is -inf everywhere makes softmax compute 0/0 = NaN,
            # which would poison `attention`. Give such rows zero weight.
            fully_masked = mask.all(dim=-1, keepdim=True)
            weights = weights.masked_fill(fully_masked, 0.0)
        attention = torch.matmul(weights, V)
        return attention, weights, scores


In [17]:
seq_len = 4
d_k = 32
sdpa = ScaledDotProductAttention()

# One row per token: draw Q, then K, then V from a standard normal.
Q, K, V = [torch.randn(seq_len, d_k) for _ in range(3)]

# Invoke the attention module (nn.Module.__call__ dispatches to forward()).
attention, weights, scores = sdpa(Q, K, V)

# Show 4 decimal places without scientific notation, then dump the results.
torch.set_printoptions(precision=4, sci_mode=False)
for caption, value in (("Attention Scores:\n", scores),
                       ("\nAttention Weights:\n", weights),
                       ("\nAttention:", attention.size())):
    print(caption, value)
print(attention)

Attention Scores:
 tensor([[-2.0405,  0.6809,  0.2939,  1.6505],
        [ 0.3717,  0.9928,  2.1701, -0.0975],
        [ 0.2492,  0.5965,  0.3470, -0.1436],
        [ 0.1514,  2.2306, -0.3693,  0.6462]])

Attention Weights:
 tensor([[0.0150, 0.2282, 0.1550, 0.6018],
        [0.1050, 0.1954, 0.6340, 0.0657],
        [0.2385, 0.3375, 0.2630, 0.1610],
        [0.0890, 0.7120, 0.0529, 0.1460]])

Attention: torch.Size([4, 32])
tensor([[ 0.2603,  1.1516, -0.1773,  0.9438, -0.9488, -0.7728, -0.8722, -0.9581,
         -1.1464, -0.5910,  0.4530,  0.3121,  0.3881, -0.9655, -0.3270, -0.8187,
         -0.4375, -0.5775, -0.6903,  1.0332,  0.5667,  0.3945,  0.9742,  0.0393,
          0.4837, -0.4920,  0.0027,  0.0325,  0.6299, -0.5945,  0.2930, -0.4423],
        [-0.1012,  0.7945, -0.4145, -0.3559,  0.0408, -0.3640, -1.5331, -0.5489,
         -0.0262, -0.0160, -0.2952,  0.2813,  0.0801, -1.0530, -0.9572, -0.1233,
          0.1993,  0.5157, -0.3686, -1.0353, -0.0099,  0.2089,  0.9106, -0.1853,
      

In [18]:
# Causal (look-ahead) mask: True strictly above the diagonal, so query i
# is blocked from attending to any key j > i.
mask = torch.ones(seq_len, seq_len).triu(diagonal=1).bool()
attention, weights, scores = sdpa(Q, K, V, mask)

for caption, value in (("Mask:\n", mask),
                       ("\nAttention Scores:\n", scores),
                       ("\nAttention Weights:\n", weights),
                       ("\nAttention:", attention.size())):
    print(caption, value)
print(attention)

Mask:
 tensor([[False,  True,  True,  True],
        [False, False,  True,  True],
        [False, False, False,  True],
        [False, False, False, False]])

Attention Scores:
 tensor([[-2.0405,    -inf,    -inf,    -inf],
        [ 0.3717,  0.9928,    -inf,    -inf],
        [ 0.2492,  0.5965,  0.3470,    -inf],
        [ 0.1514,  2.2306, -0.3693,  0.6462]])

Attention Weights:
 tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.3495, 0.6505, 0.0000, 0.0000],
        [0.2842, 0.4023, 0.3135, 0.0000],
        [0.0890, 0.7120, 0.0529, 0.1460]])

Attention: torch.Size([4, 32])
tensor([[     0.5250,      0.8303,      0.3605,      1.4981,      0.6162,
             -0.9184,     -0.3429,     -1.1141,     -0.8505,     -0.5119,
             -0.5255,     -1.2378,      0.0712,     -0.1013,     -1.5522,
             -1.2822,      0.6295,      0.7240,     -0.0341,      0.4258,
             -0.4804,      1.1067,     -2.0425,     -2.2513,     -0.4764,
             -0.2635,     -0.2840,      0.6