## Casual Attention
The main difference is that self-attention allows every token in a sequence to attend to every other token, while causal self-attention restricts each token to only attend to tokens that came before it, and itself. Causal self-attention is used for tasks like language modeling, where a model predicts the next word based on only the preceding ones, preventing it from "seeing" future information.

In [1]:
import torch

inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89], # Your
        [0.55, 0.87, 0.66], # Journey
        [0.57, 0.85, 0.64], # Starts
        [0.22, 0.58, 0.33], # with
        [0.77, 0.25, 0.10], # one
        [0.05, 0.80, 0.55] # step
    ]
)

In [2]:


import torch.nn as nn
class SelfAttention_v2(nn.Module):

  def __init__(self, d_in, d_out, qkv_bias = False):
    super().__init__()
    self.W_query = nn.Linear(d_in, d_out, bias = qkv_bias)
    self.W_key = nn.Linear(d_in, d_out, bias = qkv_bias)
    self.W_value = nn.Linear(d_in, d_out, bias = qkv_bias)

  def forward(self, x):
    query = self.W_query(x)
    key = self.W_key(x)
    value = self.W_value(x)

    atten_score = query @ key.T
    attn_weights = torch.softmax(atten_score/key.shape[-1]**0.5, dim=-1)
    context_vec = attn_weights @ value
    return context_vec


In [3]:
d_in = inputs.shape[-1]
d_out = 2


In [4]:
sa_v2 = SelfAttention_v2(d_in, d_out)

In [5]:
queries = sa_v2.W_query(inputs)
keys = sa_v2.W_key(inputs)
attn_score = queries @ keys.T
attn_weights = torch.softmax(attn_score/keys.shape[-1]**0.5, dim=-1)
attn_weights

tensor([[0.1504, 0.1662, 0.1663, 0.1741, 0.1702, 0.1729],
        [0.1419, 0.1670, 0.1668, 0.1783, 0.1667, 0.1794],
        [0.1424, 0.1669, 0.1668, 0.1780, 0.1669, 0.1790],
        [0.1516, 0.1672, 0.1670, 0.1736, 0.1658, 0.1748],
        [0.1592, 0.1661, 0.1662, 0.1699, 0.1701, 0.1684],
        [0.1454, 0.1676, 0.1673, 0.1766, 0.1641, 0.1791]],
       grad_fn=<SoftmaxBackward0>)

In [6]:
## generating a mask
context_length = attn_score.shape[0]
mask = torch.tril(torch.ones(context_length, context_length))
mask

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])

In [7]:
## now adding this in weights
masked_attention_weights = attn_weights * mask
masked_attention_weights

tensor([[0.1504, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1419, 0.1670, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1424, 0.1669, 0.1668, 0.0000, 0.0000, 0.0000],
        [0.1516, 0.1672, 0.1670, 0.1736, 0.0000, 0.0000],
        [0.1592, 0.1661, 0.1662, 0.1699, 0.1701, 0.0000],
        [0.1454, 0.1676, 0.1673, 0.1766, 0.1641, 0.1791]],
       grad_fn=<MulBackward0>)

In [8]:
## normalize the weights
masked_attention_weights = masked_attention_weights / masked_attention_weights.sum(dim=-1, keepdim=True)
masked_attention_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4594, 0.5406, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2992, 0.3506, 0.3502, 0.0000, 0.0000, 0.0000],
        [0.2299, 0.2535, 0.2533, 0.2632, 0.0000, 0.0000],
        [0.1915, 0.1997, 0.1999, 0.2043, 0.2046, 0.0000],
        [0.1454, 0.1676, 0.1673, 0.1766, 0.1641, 0.1791]],
       grad_fn=<DivBackward0>)

In [9]:
## then we will multilpy this with value to generate the context vector

## but this has a problem because we have already normalized the attention scores which leads to the significant involvement of scores which are masked later causing data leakage problem
## other way : attention scores -> upper triangle infinity mask -> normalize(softmax)
attn_score



tensor([[-0.2249, -0.0836, -0.0830, -0.0182, -0.0501, -0.0280],
        [-0.3269, -0.0969, -0.0985, -0.0044, -0.0990,  0.0045],
        [-0.3202, -0.0958, -0.0973, -0.0051, -0.0960,  0.0027],
        [-0.1889, -0.0507, -0.0521,  0.0022, -0.0628,  0.0123],
        [-0.1086, -0.0491, -0.0480, -0.0166, -0.0150, -0.0296],
        [-0.2657, -0.0650, -0.0676,  0.0088, -0.0950,  0.0289]],
       grad_fn=<MmBackward0>)

In [10]:
  ## applying mask
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
masked_attention_score = attn_score.masked_fill(mask.bool(), -torch.inf)
print(masked_attention_score)

tensor([[-0.2249,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-0.3269, -0.0969,    -inf,    -inf,    -inf,    -inf],
        [-0.3202, -0.0958, -0.0973,    -inf,    -inf,    -inf],
        [-0.1889, -0.0507, -0.0521,  0.0022,    -inf,    -inf],
        [-0.1086, -0.0491, -0.0480, -0.0166, -0.0150,    -inf],
        [-0.2657, -0.0650, -0.0676,  0.0088, -0.0950,  0.0289]],
       grad_fn=<MaskedFillBackward0>)


In [11]:
## applying softmax
attn_weights= torch.softmax(masked_attention_score/ keys.shape[-1]**0.5,dim = 1)
attn_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4594, 0.5406, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2992, 0.3506, 0.3502, 0.0000, 0.0000, 0.0000],
        [0.2299, 0.2535, 0.2533, 0.2632, 0.0000, 0.0000],
        [0.1915, 0.1997, 0.1999, 0.2043, 0.2046, 0.0000],
        [0.1454, 0.1676, 0.1673, 0.1766, 0.1641, 0.1791]],
       grad_fn=<SoftmaxBackward0>)

In [12]:
## masking addition weights with dropout implemented in GPT models
torch.manual_seed(123)
dropout= torch.nn.Dropout(0.5)
example = torch.ones(6,6)
print(dropout(example))

tensor([[2., 2., 0., 2., 2., 0.],
        [0., 0., 0., 2., 0., 2.],
        [2., 2., 2., 2., 0., 2.],
        [0., 2., 2., 0., 0., 2.],
        [0., 2., 0., 2., 0., 2.],
        [0., 2., 2., 2., 2., 0.]])


In [13]:
dropout(attn_weights)

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.7012, 0.7005, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.5066, 0.5264, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.4092, 0.0000],
        [0.2908, 0.3352, 0.0000, 0.3531, 0.3281, 0.3582]],
       grad_fn=<MulBackward0>)

In [14]:
## batching the input
batch = torch.stack((inputs, inputs), dim = 0)
batch

tensor([[[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]],

        [[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]]])

In [15]:
batch.shape

torch.Size([2, 6, 3])

In [16]:
import torch.nn as nn
class CausalAttention_v1(nn.Module):

  def __init__(self, d_in, d_out, context_length, dropout, qkv_bias = False):
    super().__init__()
    self.d_out = d_out
    self.W_query = nn.Linear(d_in, d_out, bias = qkv_bias)
    self.W_key = nn.Linear(d_in, d_out, bias = qkv_bias)
    self.W_value = nn.Linear(d_in, d_out, bias = qkv_bias)
    self.dropout= nn.Dropout(dropout)
    self.register_buffer('mask', torch.triu(torch.ones(context_length,context_length), diagonal = 1))

  def forward(self, x):
    batch_size, num_tokens, d_in = x.shape ## torch.Size([2, 6, 3])

    query = self.W_query(x)
    key = self.W_key(x)
    value = self.W_value(x)

    attn_score = query @ key.transpose(1,2) ## here 1 num of tokens and 2 is d_in
    attn_score.masked_fill(
        self.mask.bool()[:num_tokens, :num_tokens],
        -torch.inf
    )
    attn_weights = torch.softmax(attn_score/key.shape[-1]**0.5, dim=-1)
    attn_weights = self.dropout(attn_weights)
    context_vec = attn_weights @ value
    return context_vec


In [17]:
torch.manual_seed(123)
context_length = batch.shape[1]
causal_attn = CausalAttention_v1(d_in, d_out, context_length = 6, dropout = 0.5)
context_vecs = causal_attn(batch)
context_vecs

tensor([[[-0.8158, -0.1411],
         [-0.6920, -0.0972],
         [-0.4050, -0.1201],
         [-0.6902, -0.0969],
         [-0.5199, -0.0440],
         [-0.1417, -0.0505]],

        [[-0.7938, -0.2379],
         [-0.7858, -0.1145],
         [-0.3969,  0.0037],
         [-0.7704, -0.2374],
         [-0.7801, -0.1107],
         [-0.6749, -0.0984]]], grad_fn=<UnsafeViewBackward0>)

## Mutli Head Attention
Multi-head attention benefits include the ability to capture different types of relationships and dependencies in the data by attending to different parts of the input in parallel, leading to a more enhanced and contextualized representation. In contrast, single attention forces the model to compress all this information into a single set of weights, limiting its ability to capture complex patterns. The parallel processing in multi-head attention also improves training efficiency.

In [18]:
class MutliHeadAttentionWrapper(nn.Module):

  def __init__(self, d_in, d_out, context_length,dropout,  num_heads, qkv_bias= False):
    super().__init__()
    self.heads = nn.ModuleList(
        [CausalAttention_v1(d_in, d_out, context_length, dropout, qkv_bias) for _ in range(num_heads)]
    )

  def forward(self, x):
    return torch.cat([head(x) for head in self.heads], dim = -1)

In [19]:
mha = MutliHeadAttentionWrapper(d_in, d_out, context_length, 0.0, num_heads=2)
context_vecs = mha(batch)
context_vecs

tensor([[[-0.6581,  0.2660,  0.1293, -0.2101],
         [-0.6597,  0.2673,  0.1310, -0.2125],
         [-0.6594,  0.2670,  0.1310, -0.2125],
         [-0.6584,  0.2633,  0.1297, -0.2102],
         [-0.6540,  0.2605,  0.1304, -0.2110],
         [-0.6610,  0.2659,  0.1297, -0.2104]],

        [[-0.6581,  0.2660,  0.1293, -0.2101],
         [-0.6597,  0.2673,  0.1310, -0.2125],
         [-0.6594,  0.2670,  0.1310, -0.2125],
         [-0.6584,  0.2633,  0.1297, -0.2102],
         [-0.6540,  0.2605,  0.1304, -0.2110],
         [-0.6610,  0.2659,  0.1297, -0.2104]]], grad_fn=<CatBackward0>)

## Implementing Multi-Head Attention with Weight splits

In [23]:
class MutliHeadAttention(nn.Module):
  def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias = False):
    super().__init__()
    assert (d_out % num_heads) == 0, \
      "d_out must be divisible by the num_heads"

    self.d_out = d_out
    self.num_heads = num_heads
    self.head_dim = d_out // num_heads ## reducing the projection dim to match the desired outpit dim
    self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.out_proj = nn.Linear(d_out, d_out) ## lear layer to combine the head output
    self.dropout = nn.Dropout(dropout)
    self.register_buffer( ## register_buffer method is used to manage and save tensors that are part of a module's state but are not trainable parameters.When you move your model to a different device (e.g., from CPU to GPU) using model.to(device), any tensors registered as buffers will automatically be moved to that same device
        "mask",
        torch.triu(torch.ones(context_length, context_length), diagonal=1)
    )
  def forward(self, x): ## lets x = "the cat sleeps" and num of head is 2
    b,num_tokens, d_in = x.shape
    query = self.W_query(x)
    key = self.W_key(x)
    value = self.W_value(x)

    ## we implicitly split the matrix by adding a 'num_heads' dimension
    keys = key.view(b, num_tokens, self.num_heads, self.head_dim)
    values = value.view(b, num_tokens, self.num_heads, self.head_dim)
    queries = query.view(b, num_tokens, self.num_heads, self.head_dim)

    ## currently the shape is [1,3,2,3] if the batch size is 1, num of tokens is 3, heads is 2 and head dim is 3
    ## so lets transpose it for index 1,2 to make it [1,2,3,3]
    keys = keys.transpose(1,2)
    values = values.transpose(1,2)
    queries = queries.transpose(1,2)

    ## computing the scaled attention score with causal mask
    attn_score = queries @ keys.transpose(2,3)

    ## original mask truncated to number of tokens and converted to boolean
    mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
    attn_score.masked_fill_(mask_bool, -torch.inf)
    attn_weights = torch.softmax(attn_score/self.head_dim**0.5, dim=-1)
    attn_weights = self.dropout(attn_weights)

    ## shape: (b, num_tokens, num_heads, head_dim)
    context_vecs = (attn_weights @ values).transpose(1,2)
    context_vecs = context_vecs.contiguous().view(b,num_tokens, self.d_out)
    context_vecs= self.out_proj(context_vecs)
    return context_vecs

In [27]:
torch.manual_seed(123)
# Define the tensor with 3 rows and 6 columns
inputs = torch.tensor(
[[0.43, 0.15, 0.89, 0.55, 0.87, 0.66], # Row 1
[0.57, 0.85, 0.64, 0.22, 0.58, 0.33], # Row 2
[0.77, 0.25, 0.10, 0.05, 0.80, 0.55]]
)# Row 3
batch = torch.stack((inputs, inputs), dim=0)
print(batch.shape)
batch_size, context_length, d_in =  batch.shape
d_out = 6
mha = MutliHeadAttention(d_in, d_out, context_length, 0.0, num_heads = 2)
context_vecs = mha(batch)
print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

torch.Size([2, 3, 6])
tensor([[[ 0.1569, -0.0873,  0.0210,  0.0215, -0.3243, -0.2518],
         [ 0.1117, -0.0547,  0.0406, -0.0213, -0.3251, -0.2993],
         [ 0.1196, -0.0491,  0.0318, -0.0635, -0.2788, -0.2578]],

        [[ 0.1569, -0.0873,  0.0210,  0.0215, -0.3243, -0.2518],
         [ 0.1117, -0.0547,  0.0406, -0.0213, -0.3251, -0.2993],
         [ 0.1196, -0.0491,  0.0318, -0.0635, -0.2788, -0.2578]]],
       grad_fn=<ViewBackward0>)
context_vecs.shape: torch.Size([2, 3, 6])
