<img src="img/1_attention.png" width="40%" height="40%" style="margin-right: 10px;" />
<br />
<img src="img/2_attention.png" width=700 height=500/>
<br />

In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F

In [2]:
class GPT(nn.Module):
    def __init__(self, vocab_size, block_size, embedding_dim):
        super().__init__()
        self.vocab_size = vocab_size
        self.block_size = block_size

        self.token_embedding = nn.Embedding(vocab_size, embedding_dim)
        self.pos_embedding = nn.Embedding(block_size, embedding_dim)
        # self.linear = nn.Linear(embedding_dim, vocab_size)

    def forward(self, inputs, targets=None):
        B, T = inputs.shape
        token_embd = self.token_embedding(inputs)
        pos_embd = self.pos_embedding(torch.arange(T, dtype=torch.long, device=device))
        embedding = token_embd + pos_embd
        return embedding
        # logits = self.linear(embedding)
        # B,T,C = logits.shape

        # if targets is not None:
        #     logits = logits.view(B*T, C)
        #     targets = targets.view(B*T)
        #     loss = F.cross_entropy(logits, targets)
        # else: # inference
        #     loss = None
        # return logits, loss

    def generate(self, idx, max_tokens=100):
        for i in range(max_tokens):
            idx = idx if idx.size(1) <= self.block_size else idx[:, -self.block_size:]
            logits, loss = self(idx)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx

In [4]:
vocab_size = 129
batch_size = 32
block_size = 32
embedding_dim = 256
device = "cpu"

x = torch.randint(129, (1, 5))
y = torch.randint(129, (1, 5))

gpt = GPT(vocab_size, block_size, embedding_dim)
output = gpt(x, y)
output

tensor([[[ 0.8328,  4.1695,  1.6815,  ...,  0.5243,  0.7766, -1.1049],
         [-0.4937, -1.8306,  0.9328,  ...,  2.0211, -0.8266, -0.3443],
         [-2.8458, -1.3452,  1.0649,  ...,  0.7461,  0.0575, -0.0656],
         [ 0.5501,  0.9549,  0.2227,  ..., -0.8525, -1.1406,  0.3459],
         [ 1.7011, -2.2161, -0.5062,  ..., -0.3507,  0.2010,  1.0671]]],
       grad_fn=<AddBackward0>)

In [5]:
output.shape

torch.Size([1, 5, 256])

In [6]:
WQ = nn.Linear(embedding_dim, embedding_dim)
q = WQ(output)
q.shape

torch.Size([1, 5, 256])

In [7]:
WK = nn.Linear(embedding_dim, embedding_dim)
k = WK(output)

WV = nn.Linear(embedding_dim, embedding_dim)
v = WV(output)

In [18]:
weight = q @ k.transpose(-2, -1)
print(weight.shape)
print(weight)

weight = weight * (256 ** (-0.5))
print(weight.shape)
print(weight)

mask = torch.tril(torch.ones(5,5))
print(mask.shape)
print(mask)

weight = weight.masked_fill(mask == 0, float('-inf'))
print(weight.shape)
print(weight)

torch.Size([1, 5, 5])
tensor([[[  1.5612, -11.0765,  -7.8701,   4.0861,   8.8556],
         [-21.2073, -31.0149,  -1.0455,  -7.0928, -11.9264],
         [ 16.3344,  -5.0218,   7.0764,   1.8973,   7.3178],
         [-26.6540, -11.7751,   4.6650, -12.6990,  -9.9477],
         [-21.0176,  11.0325,  10.4081, -23.2671,  11.5787]]],
       grad_fn=<UnsafeViewBackward0>)
torch.Size([1, 5, 5])
tensor([[[ 0.0976, -0.6923, -0.4919,  0.2554,  0.5535],
         [-1.3255, -1.9384, -0.0653, -0.4433, -0.7454],
         [ 1.0209, -0.3139,  0.4423,  0.1186,  0.4574],
         [-1.6659, -0.7359,  0.2916, -0.7937, -0.6217],
         [-1.3136,  0.6895,  0.6505, -1.4542,  0.7237]]],
       grad_fn=<MulBackward0>)
torch.Size([5, 5])
tensor([[1., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0.],
        [1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1.]])
torch.Size([1, 5, 5])
tensor([[[ 0.0976,    -inf,    -inf,    -inf,    -inf],
         [-1.3255, -1.9384,    -inf,    -inf,    -

In [19]:
weight = F.softmax(weight, dim=-1)
print(weight.shape)
print(weight)

torch.Size([1, 5, 5])
tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.6486, 0.3514, 0.0000, 0.0000, 0.0000],
         [0.5483, 0.1443, 0.3074, 0.0000, 0.0000],
         [0.0769, 0.1948, 0.5444, 0.1839, 0.0000],
         [0.0415, 0.3078, 0.2960, 0.0361, 0.3185]]],
       grad_fn=<SoftmaxBackward0>)


In [15]:
attention = weight @ v
print(attention.shape)
print(attention)

torch.Size([1, 5, 256])
tensor([[[ 0.5879, -0.0186,  0.2304,  ...,  0.2544,  0.4254, -0.2685],
         [ 0.2464, -0.0304,  0.6489,  ...,  0.3445,  0.2032, -0.2074],
         [ 0.4637, -0.0886,  0.1875,  ...,  0.1926,  0.0564, -0.3371],
         [ 0.0595, -0.0926,  0.5815,  ...,  0.1686, -0.0158, -0.1472],
         [ 0.1444, -0.1478,  0.1892,  ..., -0.0977,  0.0121, -0.1017]]],
       grad_fn=<UnsafeViewBackward0>)


In [22]:
class SingleHeadAttention(nn.Module):
    def __init__(self, block_size, embedding_dim):
        super().__init__()
        self.block_size = block_size
        self.embedding_dim = embedding_dim

        self.WQ = nn.Linear(embedding_dim, embedding_dim)
        self.WK = nn.Linear(embedding_dim, embedding_dim)
        self.WV = nn.Linear(embedding_dim, embedding_dim)

    def forward(self, x):
        B, T, C = x.shape
        q = self.WQ(x)
        k = self.WK(x)
        v = self.WV(x)

        weight = q @ k.transpose(-2, -1)
        weight = weight * (self.embedding_dim ** (-0.5))
        mask = torch.tril(torch.ones(T,T))
        weight = weight.masked_fill(mask == 0, float('-inf'))
        weight = F.softmax(weight, dim=-1)
        attention = weight @ v

        return attention

In [23]:
single_head = SingleHeadAttention(block_size, embedding_dim)

attention = single_head(output)
print(attention.shape)
print(attention)

torch.Size([1, 5, 256])
tensor([[[-0.6034,  1.2071, -0.1523,  ..., -0.3828,  0.3348,  1.1338],
         [-0.8492,  0.8486, -0.0768,  ...,  0.0357,  0.2744,  0.2737],
         [-0.5541,  0.2570, -0.4950,  ..., -0.1103, -0.1412, -0.2519],
         [-0.4894,  0.0768, -0.4241,  ...,  0.1036, -0.3441, -0.6043],
         [-0.0925, -0.0803, -0.4653,  ..., -0.0508, -0.6838, -0.4758]]],
       grad_fn=<UnsafeViewBackward0>)


In [28]:
class MultiHeadAttention(nn.Module):
    def __init__(self, block_size, embedding_dim, n_heads):
        super().__init__()
        self.block_size = block_size
        self.embedding_dim = embedding_dim
        self.n_heads = n_heads
        self.head_size = embedding_dim // n_heads

        self.WQKV = nn.Linear(embedding_dim, embedding_dim * 3)

    def forward(self, x):
        B, T, C = x.shape
        q, k, v = self.WQKV(x).split(embedding_dim, 2)

        q = q.view(B, T, self.n_heads, self.head_size).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.head_size).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.head_size).transpose(1, 2)

        weight = q @ k.transpose(-2, -1)
        weight = weight * (self.embedding_dim ** (-0.5))
        mask = torch.tril(torch.ones(T,T))
        weight = weight.masked_fill(mask == 0, float('-inf'))
        weight = F.softmax(weight, dim=-1)
        attention = weight @ v

        # Gộp lại thành ma trận ban đầu
        attention = attention.transpose(1, 2).contiguous().view(B, T, C)

        return attention

In [29]:
multi_head = MultiHeadAttention(block_size, embedding_dim, 8)

attention = multi_head(output)
print(attention.shape)
print(attention)

torch.Size([1, 5, 256])
tensor([[[ 0.9689,  1.1806,  0.0598,  ..., -1.2010, -0.4726,  0.8283],
         [ 0.8755,  0.6932, -0.0650,  ..., -0.4005, -0.5021,  0.4413],
         [ 0.3998, -0.0140,  0.2414,  ..., -1.0763, -0.7348,  0.7071],
         [ 0.1268,  0.2359,  0.2916,  ..., -0.8905, -0.8345,  0.5315],
         [ 0.1537,  0.2146,  0.2296,  ..., -0.6426, -0.5329,  0.3252]]],
       grad_fn=<ViewBackward0>)
