In [1]:
sentence = 'Life is short, eat dessert first'

dc = {s:i for i, s in enumerate(sorted(sentence.replace(',', '').split()))}

dc

{'Life': 0, 'dessert': 1, 'eat': 2, 'first': 3, 'is': 4, 'short': 5}

In [2]:
import torch

sentence_int = torch.tensor(
    [dc[s] for s in sentence.replace(',', ' ').split()]
)
print(sentence_int)

tensor([0, 4, 5, 2, 1, 3])


In [3]:
vocab_size = 50000

torch.manual_seed(123)
embed = torch.nn.Embedding(vocab_size, 3)
embedded_sentence = embed(sentence_int).detach()

print(embedded_sentence)
print(embedded_sentence.shape)

tensor([[ 0.3374, -0.1778, -0.3035],
        [ 0.1794,  1.8951,  0.4954],
        [ 0.2692, -0.0770, -1.0205],
        [-0.2196, -0.3792,  0.7671],
        [-0.5880,  0.3486,  0.6603],
        [-1.1925,  0.6984, -1.4097]])
torch.Size([6, 3])


In [4]:
torch.manual_seed(123)

d = embedded_sentence.shape[1]

d_q, d_k, d_v = 2, 2, 4

W_query = torch.nn.Parameter(torch.rand(d_q, d))
W_key = torch.nn.Parameter(torch.rand(d_k, d))
W_value = torch.nn.Parameter(torch.rand(d_v, d))

In [5]:
W_key.shape, d

(torch.Size([2, 3]), 3)

In [28]:
W_query.shape, x_2.shape

(torch.Size([2, 3]), torch.Size([3]))

In [7]:
x_2 = embedded_sentence[1]
query_2 = W_query @ x_2
key_2 = W_key @ x_2
value_2 = W_value @ x_2

print(query_2.shape)
print(key_2.shape)
print(value_2.shape)

torch.Size([2])
torch.Size([2])
torch.Size([4])


In [8]:
query_2, key_2, value_2

(tensor([1.1568, 0.6930], grad_fn=<MvBackward0>),
 tensor([0.3099, 1.0682], grad_fn=<MvBackward0>),
 tensor([0.5430, 0.7067, 1.7432, 1.7999], grad_fn=<MvBackward0>))

In [9]:
embedded_sentence.shape, W_key.t().shape

(torch.Size([6, 3]), torch.Size([3, 2]))

In [10]:
queries = embedded_sentence @ W_query.t()
keys = embedded_sentence @ W_key.t()
values = embedded_sentence @ W_value.t()

In [11]:
print("queries.shape:", queries.shape)
print("keys.shape:", keys.shape)
print("values.shape:", values.shape)

queries.shape: torch.Size([6, 2])
keys.shape: torch.Size([6, 2])
values.shape: torch.Size([6, 4])


In [12]:
values, values.shape

(tensor([[-0.1055, -0.1367, -0.2476, -0.2113],
         [ 0.5430,  0.7067,  1.7432,  1.7999],
         [-0.3177, -0.7453, -0.8191, -0.8921],
         [ 0.1516,  0.5015,  0.3204,  0.3821],
         [ 0.2330,  0.3515,  0.5692,  0.5221],
         [-0.3989, -1.5627, -1.1977, -1.6479]], grad_fn=<MmBackward0>),
 torch.Size([6, 4]))

In [13]:
omega24 = query_2.dot(keys[4])
omega24

tensor(0.1836, grad_fn=<DotBackward0>)

In [14]:
#Unnormalized attention weights

omega2 = query_2 @ keys.T
omega2, d_v

(tensor([-0.0459,  1.0988, -0.5511,  0.2555,  0.1836, -1.5248],
        grad_fn=<SqueezeBackward4>),
 4)

In [15]:
import math
import torch.nn.functional as F

attention_weights_2 = F.softmax(omega2 / math.sqrt(d_v), dim=0)
print(attention_weights_2)

tensor([0.1585, 0.2809, 0.1231, 0.1842, 0.1777, 0.0756],
       grad_fn=<SoftmaxBackward0>)


In [16]:
context_vector_2 = attention_weights_2 @ values
context_vector_2

tensor([0.1358, 0.1217, 0.4191, 0.4007], grad_fn=<SqueezeBackward4>)

computing normalized attention for all tokens at once - start

In [17]:
queries.shape, keys.shape

(torch.Size([6, 2]), torch.Size([6, 2]))

In [18]:
omega = queries @ keys.T
omega, d_v

(tensor([[ 2.7697e-03, -6.8030e-02,  3.4104e-02, -1.5821e-02, -1.1244e-02,
           9.4798e-02],
         [-4.5924e-02,  1.0988e+00, -5.5110e-01,  2.5554e-01,  1.8359e-01,
          -1.5248e+00],
         [ 1.9828e-02, -8.1986e-01,  4.0784e-01, -1.9043e-01, -1.1284e-01,
           1.2138e+00],
         [-7.5753e-03,  4.9751e-01, -2.4645e-01,  1.1548e-01,  6.1024e-02,
          -7.6006e-01],
         [-8.5906e-03,  2.5962e-01, -1.2969e-01,  6.0342e-02,  3.9600e-02,
          -3.7221e-01],
         [ 4.8586e-02, -2.2343e+00,  1.1102e+00, -5.1887e-01, -2.9842e-01,
           3.3368e+00]], grad_fn=<MmBackward0>),
 4)

In [19]:
omega.shape

torch.Size([6, 6])

In [20]:
import math
import torch.nn.functional as F

attention_weights = F.softmax(omega / math.sqrt(d_v), dim=0)
print(attention_weights)

tensor([[0.1668, 0.1582, 0.1549, 0.1682, 0.1672, 0.1031],
        [0.1627, 0.2835, 0.1156, 0.1926, 0.1843, 0.0459],
        [0.1682, 0.1086, 0.1868, 0.1541, 0.1589, 0.1805],
        [0.1659, 0.2099, 0.1346, 0.1796, 0.1733, 0.0673],
        [0.1658, 0.1863, 0.1427, 0.1747, 0.1715, 0.0817],
        [0.1706, 0.0535, 0.2653, 0.1308, 0.1448, 0.5216]],
       grad_fn=<SoftmaxBackward0>)


In [21]:
attention_weights.shape

torch.Size([6, 6])

In [22]:
context_vectors = attention_weights @ values
context_vectors

tensor([[ 0.0424, -0.0445,  0.1331,  0.0928],
        [ 0.1539,  0.1816,  0.4708,  0.4669],
        [-0.0297, -0.2343, -0.0816, -0.1622],
        [ 0.0945,  0.0712,  0.2901,  0.2708],
        [ 0.0722,  0.0229,  0.2226,  0.1947],
        [-0.2277, -0.8819, -0.6666, -0.9104]], grad_fn=<MmBackward0>)

end

In [23]:
import torch.nn as nn

class SelfAttention(nn.Module):
  def __init__(self, d_in, d_out_kq, d_out_v):
    super().__init__()
    self.d_out_kq = d_out_kq
    self.W_query = nn.Parameter(torch.rand(d_in, d_out_kq))
    self.W_key = nn.Parameter(torch.rand(d_in, d_out_kq))
    self.W_value = nn.Parameter(torch.rand(d_in, d_out_v))

  def forward(self, x):
    keys = x @ self.W_key
    queries = x @ self.W_query
    values = x @ self.W_value

    attn_scores = queries @ keys.T # unnormalized attention score
    attn_weights = torch.softmax(
            attn_scores / self.d_out_kq**0.5, dim=-1
    )

    context_vec = attn_weights @ values
    return context_vec

In [24]:
torch.manual_seed(123)

d_in, d_out_kq, d_out_v = 3, 2, 4

sa = SelfAttention(d_in, d_out_kq, d_out_v)
print(sa(embedded_sentence))

tensor([[-0.1564,  0.1028, -0.0763, -0.0764],
        [ 0.5313,  1.3607,  0.7891,  1.3110],
        [-0.3542, -0.1234, -0.2626, -0.3706],
        [ 0.0071,  0.3345,  0.0969,  0.1998],
        [ 0.1008,  0.4780,  0.2021,  0.3674],
        [-0.5296, -0.2799, -0.4107, -0.6006]], grad_fn=<MmBackward0>)


In [25]:
class MultiHeadAttentionWrapper(nn.Module):
  def __init__(self, d_in, d_out_kq, d_out_v, num_heads):
    super().__init__()
    self.heads = nn.ModuleList(
        [SelfAttention(d_in, d_out_kq, d_out_v)
        for _ in range(num_heads)]
    )

  def forward(self, x):
    return torch.cat([head(x) for head in self.heads], dim=-1)

In [26]:
torch.manual_seed(123)

d_in, d_out_kq, d_out_v = 3, 2, 1

sa = SelfAttention(d_in, d_out_kq, d_out_v)
print(sa(embedded_sentence))

tensor([[-0.0185],
        [ 0.4003],
        [-0.1103],
        [ 0.0668],
        [ 0.1180],
        [-0.1827]], grad_fn=<MmBackward0>)


In [27]:
torch.manual_seed(123)

block_size = embedded_sentence.shape[1]
mha = MultiHeadAttentionWrapper(
    d_in, d_out_kq, d_out_v, num_heads=4
)

context_vecs = mha(embedded_sentence)

print(context_vecs)
print(context_vecs.shape)

tensor([[-0.0185,  0.0170,  0.1999, -0.0860],
        [ 0.4003,  1.7137,  1.3981,  1.0497],
        [-0.1103, -0.1609,  0.0079, -0.2416],
        [ 0.0668,  0.3534,  0.2322,  0.1008],
        [ 0.1180,  0.6949,  0.3157,  0.2807],
        [-0.1827, -0.2060, -0.2393, -0.3167]], grad_fn=<CatBackward0>)
torch.Size([6, 4])


cross-attention

In [32]:
class CrossAttention(nn.Module):
  def __init__(self, d_in, d_out_kq, d_out_v):
    super().__init__()
    self.d_out_kq = d_out_kq
    self.W_query = nn.Parameter(torch.rand(d_in, d_out_kq))
    self.W_key = nn.Parameter(torch.rand(d_in, d_out_kq))
    self.W_value = nn.Parameter(torch.rand(d_in, d_out_v))

  def forward(self, x_1, x_2): #x1-> input from decoder, x2-> input encoder
    queries_1 = x_1 @ self.W_query #comes from decoder
    keys_2 = x_2 @ self.W_key #comes from encoder
    values_2 = x_2 @ self.W_value #comes from encoder

    attn_scores = queries_1 @ keys_2.T
    attn_weights = torch.softmax(
        attn_scores/math.sqrt(self.d_out_kq), dim=-1
    )
    context_vec = attn_weights @ values_2
    return context_vec

In [34]:
torch.manual_seed(123)

d_in, d_out_kq, d_out_v = 3, 2, 4

crossattn = CrossAttention(d_in, d_out_kq, d_out_v)

first_input = embedded_sentence
second_input = torch.rand(8, d_in)

print("First input shape:", first_input.shape)
print("Second input shape:", second_input.shape)

First input shape: torch.Size([6, 3])
Second input shape: torch.Size([8, 3])


In [35]:
context_vectors = crossattn(first_input, second_input)

print(context_vectors)
print("Output Shape: ", context_vectors.shape)

tensor([[0.4231, 0.8665, 0.6503, 1.0042],
        [0.4874, 0.9718, 0.7359, 1.1353],
        [0.4054, 0.8359, 0.6258, 0.9667],
        [0.4357, 0.8886, 0.6678, 1.0311],
        [0.4429, 0.9006, 0.6775, 1.0460],
        [0.3860, 0.8021, 0.5985, 0.9250]], grad_fn=<MmBackward0>)
Output Shape:  torch.Size([6, 4])


causal attention (masked attention)

In [46]:
torch.manual_seed(123)

d_in, d_out_kq, d_out_v = 3, 2, 4

W_query = nn.Parameter(torch.rand(d_in, d_out_kq))
W_keq = nn.Parameter(torch.rand(d_in, d_out_kq))
W_value = nn.Parameter(torch.rand(d_in, d_out_v))

x = embedded_sentence

keys = x @ W_key
queries = x @ W_query
values = x @ W_value

attn_scores = queries @ keys.T

print(attn_scores)
print(attn_scores.shape)

tensor([[ 0.0613, -0.3491,  0.1443, -0.0437, -0.1303,  0.1076],
        [-0.6004,  3.4707, -1.5023,  0.4991,  1.2903, -1.3374],
        [ 0.2432, -1.3934,  0.5869, -0.1851, -0.5191,  0.4730],
        [-0.0794,  0.4487, -0.1807,  0.0518,  0.1677, -0.1197],
        [-0.1510,  0.8626, -0.3597,  0.1112,  0.3216, -0.2787],
        [ 0.4344, -2.5037,  1.0740, -0.3509, -0.9315,  0.9265]],
       grad_fn=<MmBackward0>)
torch.Size([6, 6])


In [48]:
attn_weights = torch.softmax(attn_scores/math.sqrt(d_out_kq), dim=1)
print(attn_weights)

tensor([[0.1772, 0.1326, 0.1879, 0.1645, 0.1547, 0.1831],
        [0.0386, 0.6870, 0.0204, 0.0840, 0.1470, 0.0229],
        [0.1965, 0.0618, 0.2506, 0.1452, 0.1146, 0.2312],
        [0.1505, 0.2187, 0.1401, 0.1651, 0.1793, 0.1463],
        [0.1347, 0.2758, 0.1162, 0.1621, 0.1881, 0.1231],
        [0.1973, 0.0247, 0.3102, 0.1132, 0.0751, 0.2794]],
       grad_fn=<SoftmaxBackward0>)


In [51]:
block_size = attn_scores.shape[0]
mask_simple = torch.tril(torch.ones(block_size, block_size))
print(mask_simple)

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


In [53]:
masked_simple = attn_weights * mask_simple
masked_simple

tensor([[0.1772, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0386, 0.6870, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1965, 0.0618, 0.2506, 0.0000, 0.0000, 0.0000],
        [0.1505, 0.2187, 0.1401, 0.1651, 0.0000, 0.0000],
        [0.1347, 0.2758, 0.1162, 0.1621, 0.1881, 0.0000],
        [0.1973, 0.0247, 0.3102, 0.1132, 0.0751, 0.2794]],
       grad_fn=<MulBackward0>)

In [55]:
row_sums = masked_simple.sum(dim=1, keepdim=True)
masked_simple_norm = masked_simple / row_sums
print(masked_simple_norm)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0532, 0.9468, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3862, 0.1214, 0.4924, 0.0000, 0.0000, 0.0000],
        [0.2232, 0.3242, 0.2078, 0.2449, 0.0000, 0.0000],
        [0.1536, 0.3145, 0.1325, 0.1849, 0.2145, 0.0000],
        [0.1973, 0.0247, 0.3102, 0.1132, 0.0751, 0.2794]],
       grad_fn=<DivBackward0>)
