## Causal Attention
The main difference is that self-attention allows every token in a sequence to attend to every other token, while causal self-attention restricts each token to attend only to itself and the tokens that precede it. Causal self-attention is used for tasks like language modeling, where a model predicts the next word based only on the preceding ones, which prevents it from "seeing" future information.

In [None]:
import torch

# Toy input: one 3-dimensional embedding per token of the sentence
# "Your journey starts with one step" -> a (6, 3) tensor, one row per token.
inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89], # Your
        [0.55, 0.87, 0.66], # journey
        [0.57, 0.85, 0.64], # starts
        [0.22, 0.58, 0.33], # with
        [0.77, 0.25, 0.10], # one
        [0.05, 0.80, 0.55] # step
    ]
)

In [None]:


import torch.nn as nn
class SelfAttention_v2(nn.Module):
    """Scaled dot-product self-attention (no causal mask) built on ``nn.Linear``.

    No mask is applied, so every token attends to every token in the sequence.

    Args:
        d_in: Size of each input embedding (last dimension of ``x``).
        d_out: Size of the query/key/value projections and of the output.
        qkv_bias: Whether the query/key/value projections have a bias term.
    """

    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        """Return the context vectors for ``x``.

        Args:
            x: Tensor of shape ``(num_tokens, d_in)``, or a batched tensor of
                shape ``(..., num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(..., num_tokens, d_out)``.
        """
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)

        # Swap only the last two axes (tokens <-> features). ``keys.T`` reverses
        # *all* axes, which is correct only for 2-D inputs and breaks (and is
        # deprecated) for batched 3-D inputs.
        attn_scores = queries @ keys.transpose(-2, -1)
        # Scaled dot-product attention: divide by sqrt(d_out), then normalize
        # each row (over the key axis) with softmax.
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        return attn_weights @ values


In [None]:
# Input embedding size comes from the data (3 here); the attention output size is chosen as 2.
d_in = inputs.shape[-1]
d_out = 2


In [None]:
# Non-causal self-attention module; its Linear weights are randomly initialized (no seed is set in this cell).
sa_v2 = SelfAttention_v2(d_in, d_out)

In [None]:
# Recompute the (non-causal) attention weights step by step with the module's
# projection layers, so the causal mask can be applied to them in later cells.
queries = sa_v2.W_query(inputs)
keys = sa_v2.W_key(inputs)
attn_score = queries @ keys.T  # (6, 2) @ (2, 6) -> (6, 6) raw scores, one row per query token
attn_weights = torch.softmax(attn_score/keys.shape[-1]**0.5, dim=-1)  # scale by sqrt(d_out), normalize each row
attn_weights

tensor([[0.1642, 0.1683, 0.1679, 0.1679, 0.1599, 0.1718],
        [0.1669, 0.1667, 0.1664, 0.1678, 0.1616, 0.1706],
        [0.1669, 0.1667, 0.1664, 0.1679, 0.1611, 0.1710],
        [0.1674, 0.1663, 0.1663, 0.1669, 0.1659, 0.1673],
        [0.1660, 0.1674, 0.1666, 0.1694, 0.1541, 0.1765],
        [0.1677, 0.1659, 0.1661, 0.1660, 0.1704, 0.1639]],
       grad_fn=<SoftmaxBackward0>)

In [None]:
## generating a causal mask
# Lower-triangular ones: row i keeps columns 0..i (the token itself and earlier
# tokens) and has zeros for every later column.
context_length = attn_score.shape[0]  # number of tokens (6)
mask = torch.tril(torch.ones(context_length, context_length))
mask

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])

In [None]:
## apply the mask to the weights: element-wise multiplication zeroes every entry above the diagonal
masked_attention_weights = attn_weights * mask
masked_attention_weights

tensor([[0.1642, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1669, 0.1667, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1669, 0.1667, 0.1664, 0.0000, 0.0000, 0.0000],
        [0.1674, 0.1663, 0.1663, 0.1669, 0.0000, 0.0000],
        [0.1660, 0.1674, 0.1666, 0.1694, 0.1541, 0.0000],
        [0.1677, 0.1659, 0.1661, 0.1660, 0.1704, 0.1639]],
       grad_fn=<MulBackward0>)

In [None]:
## renormalize: after masking, the rows no longer sum to 1, so divide each row by its own sum
# (keepdim=True keeps the sum as shape (6, 1) so the division broadcasts across each row)
masked_attention_weights = masked_attention_weights / masked_attention_weights.sum(dim=-1, keepdim=True)
masked_attention_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5004, 0.4996, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3338, 0.3334, 0.3328, 0.0000, 0.0000, 0.0000],
        [0.2510, 0.2494, 0.2493, 0.2503, 0.0000, 0.0000],
        [0.2016, 0.2032, 0.2023, 0.2057, 0.1872, 0.0000],
        [0.1677, 0.1659, 0.1661, 0.1660, 0.1704, 0.1639]],
       grad_fn=<DivBackward0>)

In [None]:
## Next, these masked and renormalized weights would be multiplied with the values to produce the context vectors.

## Note: softmax -> mask -> renormalize does NOT leak information from future tokens. Renormalizing over the
## unmasked entries gives exactly the same weights as a softmax over only those entries (compare the weights
## above with the ones computed below). Its drawback is that it normalizes twice.
## More efficient way: attention scores -> fill the upper triangle with -inf -> softmax (one normalization).
attn_score



tensor([[ 0.0002,  0.0352,  0.0319,  0.0316, -0.0367,  0.0643],
        [-0.0036, -0.0058, -0.0082,  0.0039, -0.0499,  0.0271],
        [-0.0039, -0.0053, -0.0080,  0.0049, -0.0539,  0.0303],
        [-0.0017, -0.0106, -0.0109, -0.0054, -0.0140, -0.0025],
        [-0.0068,  0.0047, -0.0014,  0.0216, -0.1116,  0.0803],
        [ 0.0004, -0.0150, -0.0132, -0.0146,  0.0228, -0.0321]],
       grad_fn=<MmBackward0>)

In [None]:
  ## apply the causal mask to the raw attention scores
# Strictly-upper-triangular ones (diagonal=1) mark the future positions. Note that this
# reassigns `mask` (it previously held the lower-triangular 0/1 mask).
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
# masked_fill returns a new tensor with the future scores replaced by -inf (attn_score is unchanged).
masked_attention_score = attn_score.masked_fill(mask.bool(), -torch.inf)
print(masked_attention_score)

tensor([[ 0.0002,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-0.0036, -0.0058,    -inf,    -inf,    -inf,    -inf],
        [-0.0039, -0.0053, -0.0080,    -inf,    -inf,    -inf],
        [-0.0017, -0.0106, -0.0109, -0.0054,    -inf,    -inf],
        [-0.0068,  0.0047, -0.0014,  0.0216, -0.1116,    -inf],
        [ 0.0004, -0.0150, -0.0132, -0.0146,  0.0228, -0.0321]],
       grad_fn=<MaskedFillBackward0>)


In [None]:
## applying softmax: exp(-inf) = 0, so masked positions get zero weight and each row sums to 1
## over the unmasked positions only -- the same result as the mask-then-renormalize weights above.
# dim = 1 is the last axis of this 2-D score matrix (equivalent to dim=-1).
attn_weights= torch.softmax(masked_attention_score/ keys.shape[-1]**0.5,dim = 1)
attn_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5004, 0.4996, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3338, 0.3334, 0.3328, 0.0000, 0.0000, 0.0000],
        [0.2510, 0.2494, 0.2493, 0.2503, 0.0000, 0.0000],
        [0.2016, 0.2032, 0.2023, 0.2057, 0.1872, 0.0000],
        [0.1677, 0.1659, 0.1661, 0.1660, 0.1704, 0.1639]],
       grad_fn=<SoftmaxBackward0>)

In [None]:
## Dropout: an additional masking step applied to the attention weights in GPT models,
## demonstrated first on a 6x6 matrix of ones.
# With p=0.5 each element is zeroed with probability 0.5, and the surviving elements
# are scaled by 1 / (1 - 0.5) = 2 (the output contains only 0s and 2s).
torch.manual_seed(123)  # makes the random drop pattern reproducible
dropout= torch.nn.Dropout(0.5)
example = torch.ones(6,6)
print(dropout(example))

tensor([[2., 2., 0., 2., 2., 0.],
        [0., 0., 0., 2., 0., 2.],
        [2., 2., 2., 2., 0., 2.],
        [0., 2., 2., 0., 0., 2.],
        [0., 2., 0., 2., 0., 2.],
        [0., 2., 2., 2., 2., 0.]])


In [None]:
# Apply the same Dropout module to the causal attention weights. Kept weights are doubled
# (e.g. the 1.0 in row 0 becomes 2.0), so rows no longer sum to 1. The seed is not reset
# in this cell, so the drop pattern continues from the RNG state left by the previous cell.
dropout(attn_weights)

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.6669, 0.6656, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.4987, 0.5006, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.3744, 0.0000],
        [0.3354, 0.3318, 0.0000, 0.3319, 0.3408, 0.3278]],
       grad_fn=<MulBackward0>)

In [None]:
## batching the input
batch = torch.stack((inputs, inputs), dim = 0)
batch

tensor([[[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]],

        [[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]]])

In [None]:
batch.shape  # (batch_size, num_tokens, d_in)

torch.Size([2, 6, 3])

In [None]:
import torch.nn as nn
class CausalAttention_v1(nn.Module):
    """Single-head causal self-attention with dropout on the attention weights.

    Each token attends only to itself and earlier tokens: scores strictly above
    the diagonal are set to -inf before the softmax, so their weights are 0.

    Args:
        d_in: Size of each input embedding.
        d_out: Size of the query/key/value projections and of the output.
        context_length: Maximum supported sequence length (size of the mask).
        dropout: Dropout probability applied to the attention weights.
        qkv_bias: Whether the query/key/value projections have a bias term.
    """

    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        # Ones strictly above the diagonal mark "future" positions. Registered as
        # a buffer so it moves with the module (e.g. ``.to(device)``) without
        # being a trainable parameter.
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1),
        )

    def forward(self, x):
        """Return causal context vectors for a batch of sequences.

        Args:
            x: Tensor of shape ``(batch_size, num_tokens, d_in)`` with
                ``num_tokens <= context_length`` (the mask is sliced to size).

        Returns:
            Tensor of shape ``(batch_size, num_tokens, d_out)``.
        """
        _, num_tokens, _ = x.shape

        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)

        # (batch, tokens, d_out) @ (batch, d_out, tokens) -> (batch, tokens, tokens);
        # axis 1 is the token axis and axis 2 is the d_out (feature) axis.
        attn_scores = queries @ keys.transpose(1, 2)
        # In-place masked_fill_: the out-of-place masked_fill returns a new
        # tensor, and discarding it would leave future tokens unmasked.
        attn_scores.masked_fill_(
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf
        )
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)
        return attn_weights @ values


In [None]:
torch.manual_seed(123)  # reproducible weight initialization and dropout pattern
context_length = batch.shape[1]  # number of tokens per sequence (6)
# Pass the computed length rather than repeating the literal 6, so the mask size follows the batch.
# The module is in training mode by default, so dropout (p=0.5) is active for this call.
causal_attn = CausalAttention_v1(d_in, d_out, context_length=context_length, dropout=0.5)
context_vecs = causal_attn(batch)
context_vecs

tensor([[[-0.8158, -0.1411],
         [-0.6920, -0.0972],
         [-0.4050, -0.1201],
         [-0.6902, -0.0969],
         [-0.5199, -0.0440],
         [-0.1417, -0.0505]],

        [[-0.7938, -0.2379],
         [-0.7858, -0.1145],
         [-0.3969,  0.0037],
         [-0.7704, -0.2374],
         [-0.7801, -0.1107],
         [-0.6749, -0.0984]]], grad_fn=<UnsafeViewBackward0>)