<a href="https://colab.research.google.com/github/hananather/AutoText/blob/main/Attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttention_v2(nn.Module):
  """Single-head scaled dot-product self-attention with learned projections.

  Args:
    d_in: Size of each input token embedding.
    d_out: Size of the query/key/value projections (and of each context vector).
    qkv_bias: Whether the query/key/value linear layers use a bias term.
  """

  def __init__(self, d_in, d_out, qkv_bias=False):
    super().__init__()
    # Keep the creation order (query, key, value): it determines which random
    # numbers each layer draws when the RNG is seeded before construction.
    self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

  def forward(self, x):
    """Compute one context vector per token.

    Args:
      x: Tensor of shape (num_tokens, d_in), optionally with leading batch
        dimensions, i.e. (..., num_tokens, d_in).

    Returns:
      Tensor shaped like `x` with the last dimension replaced by d_out.
    """
    keys = self.W_key(x)
    queries = self.W_query(x)
    values = self.W_value(x)
    # Swap only the last two dims (instead of `.T`) so batched inputs work;
    # for a 2-D input this is exactly the plain matrix transpose.
    attn_scores = queries @ keys.transpose(-2, -1)
    # Scale by sqrt(d_out) before the softmax over the key dimension.
    attn_weights = torch.softmax(
        attn_scores / keys.shape[-1]**0.5, dim=-1
    )
    context_vec = attn_weights @ values
    return context_vec

In [None]:
import torch
# Toy embeddings: one 3-dimensional row per token of
# "Your journey starts with one step".
inputs = torch.tensor([
    [0.43, 0.15, 0.89],  # x^1: Your
    [0.55, 0.87, 0.66],  # x^2: journey
    [0.57, 0.85, 0.64],  # x^3: starts
    [0.22, 0.58, 0.33],  # x^4: with
    [0.77, 0.25, 0.10],  # x^5: one
    [0.05, 0.80, 0.55],  # x^6: step
])

In [None]:
# Seed the RNG first so the layer's random initial weights are reproducible.
torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in=3, d_out=2)

In [None]:
# Run the attention layer on the toy embeddings; the cell displays the result.
sa_v2(inputs)

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)

# Causal Attention

In [None]:
# Intuition: rebuild the (unmasked) attention weights step by step,
# reusing the query/key projections of the layer created above.

queries = sa_v2.W_query(inputs)
keys = sa_v2.W_key(inputs)
attn_scores = torch.matmul(queries, keys.T)
# Scale by sqrt(key dimension), then normalize each row with softmax.
attn_weights = (attn_scores / keys.shape[-1] ** 0.5).softmax(dim=-1)
print(attn_weights)

tensor([[0.1921, 0.1646, 0.1652, 0.1550, 0.1721, 0.1510],
        [0.2041, 0.1659, 0.1662, 0.1496, 0.1665, 0.1477],
        [0.2036, 0.1659, 0.1662, 0.1498, 0.1664, 0.1480],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.1661, 0.1564],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.1585],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)


In [None]:
# Number of rows in the attention-score matrix.
context_length = attn_scores.shape[0]

In [None]:
# Lower-triangular mask of ones: every entry above the main diagonal is zero.
context_length = attn_scores.size(0)
mask_simple = torch.ones(context_length, context_length).tril()
mask_simple

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])

In [None]:
# Element-wise product zeroes out the weights above the diagonal.
masked_simple = torch.mul(attn_weights, mask_simple)
masked_simple

tensor([[0.1921, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2041, 0.1659, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2036, 0.1659, 0.1662, 0.0000, 0.0000, 0.0000],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.0000, 0.0000],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<MulBackward0>)

In [None]:
# Re-normalize: divide each row by its sum so the rows add up to 1 again.
row_sums = torch.sum(masked_simple, dim=-1, keepdim=True)
print(row_sums)
masked_simple_norm = torch.div(masked_simple, row_sums)
print(masked_simple_norm)

tensor([[0.1921],
        [0.3700],
        [0.5357],
        [0.6775],
        [0.8415],
        [1.0000]], grad_fn=<SumBackward1>)
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<DivBackward0>)


## Masking Trick

In [None]:
# Ones strictly above the diagonal mark the positions to hide.
mask = torch.ones(context_length, context_length).triu(diagonal=1)
print(mask)
# Fill the marked scores with -inf so softmax turns them into exact zeros.
masked = attn_scores.masked_fill(mask.to(torch.bool), float("-inf"))
print(masked)

tensor([[0., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0.]])
tensor([[0.2899,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.4656, 0.1723,   -inf,   -inf,   -inf,   -inf],
        [0.4594, 0.1703, 0.1731,   -inf,   -inf,   -inf],
        [0.2642, 0.1024, 0.1036, 0.0186,   -inf,   -inf],
        [0.2183, 0.0874, 0.0882, 0.0177, 0.0786,   -inf],
        [0.3408, 0.1270, 0.1290, 0.0198, 0.1290, 0.0078]],
       grad_fn=<MaskedFillBackward0>)


In [None]:
# Softmax over each row of the scaled, masked scores.
attn_weights = (masked / keys.shape[-1] ** 0.5).softmax(dim=1)
attn_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)

# Causal Attention with Dropout

In [None]:
# Build a batch by stacking the same input twice along a new leading dimension:
# two inputs with six tokens each and embedding dim 3.
batch = torch.stack([inputs, inputs])
print(batch)


tensor([[[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]],

        [[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]]])

This results in a three-dimensional tensor consisting of two input texts with six tokens each, where each token is a three-dimensional embedding vector.

In [None]:
# Dropout module with drop probability 0.5, applied to a 6x6 matrix of ones.
dropout = nn.Dropout(p=0.5)
example = torch.ones((6, 6))
dropout(example)

tensor([[0., 0., 2., 2., 0., 2.],
        [0., 0., 0., 0., 2., 0.],
        [0., 2., 0., 0., 2., 2.],
        [2., 0., 2., 2., 2., 0.],
        [0., 2., 2., 2., 2., 2.],
        [2., 0., 0., 0., 2., 2.]])

A buffer is a tensor that:
1. We want to save with our model. When we save our model to a file, the buffer tensors get saved too.
2. We want to move to the GPU automatically when we move our model from CPU to GPU.
3. But we DON'T want to update through training (unlike model weights).



In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalAttention(nn.Module):
  """Single-head causal self-attention with dropout on the attention weights.

  Args:
    d_in: Size of each input token embedding.
    d_out: Size of the query/key/value projections (and of each context vector).
    context_length: Side length of the precomputed causal mask.
    dropout: Drop probability applied to the attention weights.
    qkv_bias: Whether the query/key/value linear layers use a bias term.
  """

  def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
    super().__init__()
    self.d_out = d_out
    self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    self.dropout = nn.Dropout(dropout)
    # Ones strictly above the diagonal mark the positions to hide. Registered
    # as a buffer: stored with the module, but not a trainable parameter.
    self.register_buffer(
        "mask",
        torch.triu(torch.ones(context_length, context_length), diagonal=1),
    )

  def forward(self, x):
    """Compute causally masked context vectors.

    Args:
      x: Tensor of shape (b, num_tokens, d_in).

    Returns:
      Tensor of shape (b, num_tokens, d_out).
    """
    b, num_tokens, d_in = x.shape
    keys = self.W_key(x)
    values = self.W_value(x)
    queries = self.W_query(x)

    # (b, num_tokens, d_out) @ (b, d_out, num_tokens) -> (b, num_tokens, num_tokens)
    attn_scores = queries @ keys.transpose(1, 2)
    # Truncate the mask to the actual sequence length; -inf scores become
    # zero weights after the softmax.
    attn_scores.masked_fill_(
        self.mask.bool()[:num_tokens, :num_tokens], -torch.inf
    )
    attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
    attn_weights = self.dropout(attn_weights)

    context_vec = attn_weights @ values
    return context_vec





## Efficient Multi-head attention

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
  """Causal multi-head self-attention computed with one set of projections.

  The d_out-wide query/key/value projections are split into `num_heads`
  heads of width `head_dim = d_out // num_heads`; all heads are processed in
  a single batched matrix multiplication.

  Args:
    d_in: Size of each input token embedding.
    d_out: Total output width across all heads; must be divisible by num_heads.
    context_length: Side length of the precomputed causal mask.
    dropout: Drop probability applied to the attention weights.
    num_heads: Number of attention heads.
    qkv_bias: Whether the query/key/value linear layers use a bias term.
  """

  def __init__(self, d_in, d_out, context_length, dropout, num_heads,
               qkv_bias=False):
    super().__init__()
    assert (d_out % num_heads == 0), "d_out  must be divisible by num_heads"

    self.d_out = d_out
    self.num_heads = num_heads
    self.head_dim = d_out // num_heads  # width of each individual head
    self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    # Linear layer that combines the concatenated head outputs; its input is
    # the d_out-wide context vector, so it maps d_out -> d_out.
    self.out_proj = nn.Linear(d_out, d_out)

    self.dropout = nn.Dropout(dropout)
    # Ones strictly above the diagonal mark the positions to hide.
    self.register_buffer(
        "mask",
        torch.triu(torch.ones(context_length, context_length), diagonal=1),
    )

  def forward(self, x):
    """Compute causally masked multi-head context vectors.

    Args:
      x: Tensor of shape (b, num_tokens, d_in).

    Returns:
      Tensor of shape (b, num_tokens, d_out).
    """
    b, num_tokens, d_in = x.shape

    # Tensor shape: (b, num_tokens, d_out)
    keys = self.W_key(x)
    values = self.W_value(x)
    queries = self.W_query(x)

    # Split the last dimension into heads:
    # (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
    keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
    queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
    values = values.view(b, num_tokens, self.num_heads, self.head_dim)

    # (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
    keys = keys.transpose(1, 2)
    queries = queries.transpose(1, 2)
    values = values.transpose(1, 2)

    # Dot product per head: (b, num_heads, num_tokens, num_tokens)
    attn_scores = queries @ keys.transpose(2, 3)

    # The mask is built for context_length tokens; truncate it to the actual
    # number of tokens in this input.
    mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
    attn_scores.masked_fill_(mask_bool, -torch.inf)

    # Scale by sqrt(head_dim) (keys are already split per head) and normalize
    # over the key dimension.
    attn_weights = torch.softmax(
        attn_scores / keys.shape[-1]**0.5, dim=-1
    )
    attn_weights = self.dropout(attn_weights)

    # (b, num_heads, num_tokens, head_dim) -> (b, num_tokens, num_heads, head_dim)
    context_vec = (attn_weights @ values).transpose(1, 2)
    # Merge the heads back together: (b, num_tokens, d_out). The transpose
    # makes the tensor non-contiguous, so `contiguous()` is needed before `view`.
    context_vec = context_vec.contiguous().view(
        b, num_tokens, self.d_out
    )
    # Output projection: not strictly necessary but used in many architectures.
    context_vec = self.out_proj(context_vec)
    return context_vec


In [7]:
# Example tensor with shape
# (b, num_heads, num_tokens, head_dim) = (1, 2, 3, 4)
a = torch.tensor([[
    [[0.2745, 0.6584, 0.2775, 0.8573],
     [0.8993, 0.0390, 0.9268, 0.7388],
     [0.7179, 0.7058, 0.9156, 0.4340]],

    [[0.0772, 0.3565, 0.1479, 0.5331],
     [0.4066, 0.2318, 0.4545, 0.9737],
     [0.4606, 0.5159, 0.4220, 0.5786]],
]])

# Swapping the last two dims gives (b, num_heads, head_dim, num_tokens), so
# the batched matrix multiplication yields one (num_tokens, num_tokens)
# matrix per head.
print(torch.matmul(a, a.transpose(-1, -2)))

tensor([[[[1.3208, 1.1631, 1.2879],
          [1.1631, 2.2150, 1.8424],
          [1.2879, 1.8424, 2.0402]],

         [[0.4391, 0.7003, 0.5903],
          [0.7003, 1.3737, 1.0620],
          [0.5903, 1.0620, 0.9912]]]])


In [11]:
# Repeat the multiplication for the first head only: it matches the first
# matrix of the batched result.
first_head = a[0, 0]
print(first_head)
first_res = torch.matmul(first_head, first_head.transpose(0, 1))
print(first_res)

tensor([[0.2745, 0.6584, 0.2775, 0.8573],
        [0.8993, 0.0390, 0.9268, 0.7388],
        [0.7179, 0.7058, 0.9156, 0.4340]])
tensor([[1.3208, 1.1631, 1.2879],
        [1.1631, 2.2150, 1.8424],
        [1.2879, 1.8424, 2.0402]])
