# Transformers
In this notebook I implement the masked-attention component of a transformer, step by step, to better understand the architecture involved.

In [1]:
import torch
import torch.nn as nn
from torch.nn.functional import normalize

## Creating a test embedding system

In [2]:
# Toy vocabulary plus a randomly initialised token-embedding table.
EMBEDDING_SIZE = 10  # number of features in every embedding vector
CONTEXT_SIZE = 3  # maximum number of words accepted in one context
words = ["I", "had", "a", "dream", "machines", "learning"]

# wte: one row of uniform random values in [0, 1) per vocabulary word.
wte = torch.rand(len(words), EMBEDDING_SIZE)
wte

tensor([[0.7680, 0.2334, 0.1617, 0.1945, 0.5311, 0.8125, 0.3190, 0.9987, 0.8926,
         0.1501],
        [0.5814, 0.1674, 0.2023, 0.6712, 0.2358, 0.2102, 0.3119, 0.7267, 0.9835,
         0.6880],
        [0.7755, 0.5037, 0.4213, 0.9135, 0.0250, 0.0136, 0.8709, 0.1924, 0.6385,
         0.6570],
        [0.9843, 0.9465, 0.9856, 0.9258, 0.8279, 0.8540, 0.0644, 0.3804, 0.0136,
         0.3435],
        [0.1457, 0.3402, 0.9064, 0.7949, 0.8369, 0.7340, 0.2298, 0.5078, 0.1739,
         0.1443],
        [0.9441, 0.5629, 0.0017, 0.7936, 0.3271, 0.8456, 0.4251, 0.1972, 0.5576,
         0.3225]])

In [3]:
# Stand-in positional-encoding table: one random row per context position
# (random values only, not a sinusoidal or trained encoding).
wpe = torch.rand(CONTEXT_SIZE, EMBEDDING_SIZE)
wpe

tensor([[0.4486, 0.2670, 0.4196, 0.3424, 0.4579, 0.7168, 0.4524, 0.1003, 0.6377,
         0.9505],
        [0.1465, 0.0730, 0.5992, 0.8240, 0.4818, 0.5363, 0.9028, 0.8181, 0.0449,
         0.9999],
        [0.0170, 0.0594, 0.6006, 0.0494, 0.6601, 0.9095, 0.4197, 0.6633, 0.6130,
         0.3377]])

In [4]:
def get_position(word: str) -> int:
    """Return the index of *word* in the module-level ``words`` vocabulary.

    Despite the name, the result is the word's vocabulary index (used by
    ``embed`` to select a row of ``wte``), not its position in a sentence.

    Raises:
        ValueError: if *word* is not in the vocabulary. ``ValueError`` is a
            subclass of ``Exception``, so callers that catch ``Exception``
            keep working.
    """
    try:
        return words.index(word)
    except ValueError as e:
        # Name the offending word and keep the original lookup failure as the
        # explicit cause instead of discarding it.
        raise ValueError(f"Word not in vocab: {word!r}") from e

In [5]:
# Context embedding: token embedding plus positional encoding.
def embed(words_list: "list[str]") -> torch.tensor:
    """Embed a short list of vocabulary words.

    Each word is looked up in ``wte`` through ``get_position`` and the rows
    of ``wpe`` for positions ``0 .. len(words_list) - 1`` are added on top.
    The result has one row per input word.

    An ``AssertionError`` is raised when more than ``CONTEXT_SIZE`` words
    are given; unknown words fail inside ``get_position``.
    """
    n_tokens = len(words_list)
    assert n_tokens <= CONTEXT_SIZE, f"Vector should have at max size {CONTEXT_SIZE}"
    vocab_ids = list(map(get_position, words_list))
    return wte[vocab_ids] + wpe[:n_tokens]

# Demo: embed a two-word context; the displayed result has one row per word.
embed(["had", "machines"])

tensor([[1.0300, 0.4343, 0.6219, 1.0136, 0.6937, 0.9271, 0.7644, 0.8270, 1.6212,
         1.6385],
        [0.2922, 0.4132, 1.5056, 1.6188, 1.3187, 1.2703, 1.1325, 1.3259, 0.2188,
         1.1442]])

In [6]:
def get_prob_distribution(word: str):
    """Return a probability for every vocabulary word, given a single *word*.

    The word is embedded as a one-word context (so the position-0 encoding
    is added), scored against every row of ``wte`` with a dot product, and
    the scores are turned into probabilities with a softmax. The result is
    a flat tensor with one entry per vocabulary word.
    """
    vector = embed([word])
    vocab_scores = vector @ wte.transpose(0, -1)
    return torch.softmax(vocab_scores, dim=1).flatten()

# Demo: print the vocabulary word that receives the highest probability for "had".
probs  = get_prob_distribution("had")
print(words[probs.argmax()])

had


## Creating attention layer

In [7]:
# Random square projection matrices for queries, keys and values
# (drawn in the order Wq, Wk, Wv).
Wq, Wk, Wv = (torch.rand(EMBEDDING_SIZE, EMBEDDING_SIZE) for _ in range(3))

In [8]:
class QKV:
    """Query, key and value vectors of a single token.

    The token is flattened to a 1-D vector and multiplied by the
    module-level ``Wq``, ``Wk`` and ``Wv`` matrices; the three products are
    stored as ``q``, ``k`` and ``v``.
    """

    def __init__(self, token):
        flat = token.flatten()
        self.q, self.k, self.v = (weight @ flat for weight in (Wq, Wk, Wv))


In [9]:
def process_attention(toks: "list[torch.Tensor]") -> "list[torch.Tensor]":
    """Unmasked self-attention, computed one query token at a time.

    Every token is projected to a query, key and value through ``QKV``.
    For each query, the dot products with all keys are softmax-normalised
    and used as weights for a weighted sum of the values.

    Args:
        toks: iterable of token embeddings (a list of 1-D tensors, or a 2-D
            tensor that is iterated row by row).

    Returns:
        A list with one attended vector per input token; an empty list when
        *toks* is empty.

    Note:
        The scores are raw dot products; no 1/sqrt(d) scaling is applied.
    """
    qkv = [QKV(tok) for tok in toks]
    if not qkv:
        # torch.stack rejects an empty list; keep the "no tokens -> []" result.
        return []

    # Keys and values are the same for every query, so build them once.
    # torch.stack keeps dtype, device and autograd history, which
    # torch.tensor(list_of_tensors) does not.
    keys = torch.stack([item.k for item in qkv])
    values = torch.stack([item.v for item in qkv])

    res = []
    for item in qkv:
        weights = (keys @ item.q).softmax(0)  # one weight per key
        res.append(weights @ values)  # weighted sum of the value vectors
    return res

# Demo: loop-based attention over a two-word context.
process_attention( embed(["I", "had"]))

[tensor([4.0388, 4.5731, 4.0860, 6.4205, 5.6079, 3.5631, 5.5232, 6.3536, 5.7298,
         5.1143]),
 tensor([4.0390, 4.5730, 4.0860, 6.4208, 5.6080, 3.5630, 5.5233, 6.3535, 5.7299,
         5.1142])]

In [10]:
def process_attention_matrix(toks: torch.Tensor) -> torch.Tensor:
    """Unmasked self-attention written as matrix products.

    Computes ``softmax(Q @ K^T) @ V`` where ``Q``, ``K`` and ``V`` are the
    projections of *toks* through the module-level ``Wq``, ``Wk`` and
    ``Wv``.

    Args:
        toks: token embeddings of shape ``(seq, emb)``. Any number of
            leading batch dimensions is also accepted because the last two
            axes are always addressed from the end.

    Returns:
        A tensor with the same shape as *toks*.

    Note:
        The scores are raw dot products; no 1/sqrt(d) scaling is applied.
    """
    Q = toks @ Wq.transpose(-2, -1)
    K = toks @ Wk.transpose(-2, -1)
    V = toks @ Wv.transpose(-2, -1)

    # Negative axes keep the 2-D behaviour and also work with batch dims.
    scores = Q @ K.transpose(-2, -1)
    weights = scores.softmax(-1)  # normalise over the key axis

    return weights @ V

# Demo: the same two-word context through the matrix version; the displayed
# values match the loop-based process_attention output.
process_attention_matrix(embed(["I", "had"]))

tensor([[4.0388, 4.5731, 4.0860, 6.4205, 5.6079, 3.5631, 5.5232, 6.3536, 5.7298,
         5.1143],
        [4.0390, 4.5730, 4.0860, 6.4208, 5.6080, 3.5630, 5.5233, 6.3535, 5.7299,
         5.1142]])

## Masked attention

In [11]:
def masked_attention(tokens: torch.Tensor) -> torch.Tensor:
    """Causal (masked) self-attention.

    Works like plain attention, except that position ``i`` may only attend
    to positions ``0 .. i``: scores for later positions are set to ``-inf``
    before the softmax, so they get zero weight.

    Args:
        tokens: token embeddings of shape ``(seq, emb)``. Any number of
            leading batch dimensions is also accepted.

    Returns:
        A tensor with the same shape as *tokens*.

    Note:
        The scores are raw dot products; no 1/sqrt(d) scaling is applied.
    """
    Q = tokens @ Wq.transpose(-2, -1)
    K = tokens @ Wk.transpose(-2, -1)
    V = tokens @ Wv.transpose(-2, -1)

    scores = Q @ K.transpose(-2, -1)

    # True strictly above the diagonal marks "future" key positions.
    seq_len = scores.shape[-1]
    future = torch.ones(
        seq_len, seq_len, dtype=torch.bool, device=scores.device
    ).triu(diagonal=1)

    # masked_fill overwrites the score, so a +inf score at a masked position
    # cannot turn into NaN the way ``score + (-inf)`` would.
    masked_scores = scores.masked_fill(future, float("-inf"))

    normal_masked_scores = masked_scores.softmax(-1)

    return normal_masked_scores @ V

# Demo: causal attention over a three-word context (the CONTEXT_SIZE limit).
masked_attention(embed(['I','had', 'machines']))

tensor([[3.3705, 4.9609, 4.1281, 4.7428, 4.8102, 3.9217, 4.7669, 6.7278, 5.0142,
         5.5472],
        [4.0390, 4.5730, 4.0860, 6.4208, 5.6080, 3.5630, 5.5233, 6.3535, 5.7299,
         5.1142],
        [4.0368, 4.5743, 4.0862, 6.4155, 5.6055, 3.5642, 5.5209, 6.3547, 5.7276,
         5.1156]])

## Simulating a feed forward layer

In [12]:
# Random weights for the three-matrix feed-forward stack:
# EMBEDDING_SIZE -> 4 * EMBEDDING_SIZE -> 4 * EMBEDDING_SIZE -> EMBEDDING_SIZE.
layer1 = torch.rand(EMBEDDING_SIZE * 4, EMBEDDING_SIZE)
layer2 = torch.rand(EMBEDDING_SIZE * 4, EMBEDDING_SIZE * 4)
layer3 = torch.rand(EMBEDDING_SIZE, EMBEDDING_SIZE * 4)

In [13]:
def ff(input: torch.tensor) -> torch.tensor:
    """Push *input* through the three module-level weight matrices.

    The input is transposed so that tokens become columns, multiplied by
    ``layer3 @ layer2 @ layer1``, and transposed back so that the result
    again has one row per token. No activation function is applied between
    the matrices, so the whole stack is a single linear map.
    """
    columns = input.transpose(0, -1)
    # Same left-to-right association as ``layer3 @ layer2 @ layer1 @ x``.
    stacked = layer3 @ layer2
    stacked = stacked @ layer1
    projected = stacked @ columns
    return projected.transpose(0, -1)

# Demo: send the masked-attention output of a three-word context through the
# feed-forward stack.
ff(
    masked_attention(
        embed(["I", "dream", "learning"])
    )
)

tensor([[10260.8418, 10693.3145,  9262.3232,  9490.9883, 10262.4023,  7892.1362,
         10434.8086, 10779.0645,  9501.3799,  9467.4678],
        [12799.6396, 13344.5430, 11554.7031, 11838.2559, 12806.6699,  9844.9766,
         13015.1592, 13447.1670, 11855.3457, 11810.3477],
        [12799.6396, 13344.5430, 11554.7031, 11838.2559, 12806.6699,  9844.9766,
         13015.1592, 13447.1670, 11855.3457, 11810.3477]])

## Putting everything in a decoder block

In [14]:
def decode(tokens):
    """Run one decoder block over *tokens*.

    The block is masked attention, then the feed-forward stack, then an L2
    normalisation along dimension 1 (each row is scaled to unit length).
    There are no residual connections.
    """
    attended = masked_attention(tokens)
    transformed = ff(attended)
    # p=2.0 and dim=1 are normalize()'s defaults, spelled out for clarity.
    return normalize(transformed, p=2.0, dim=1)

# Demo: one decoder block on a random 2 x 10 input (10 == EMBEDDING_SIZE).
decode(torch.rand([2, 10]))

tensor([[0.3297, 0.3437, 0.2977, 0.3050, 0.3298, 0.2537, 0.3354, 0.3465, 0.3054,
         0.3042],
        [0.3298, 0.3438, 0.2977, 0.3050, 0.3299, 0.2536, 0.3353, 0.3464, 0.3054,
         0.3043]])

## Dummy version of the model

In [15]:
def model(words_list, decoder_heads=12):
    """Predict one vocabulary index per position of *words_list*.

    The words are embedded, passed ``decoder_heads`` times through
    ``decode``, and scored against every row of ``wte``. After a softmax,
    the index of the most probable vocabulary word is returned for each
    position.

    ``decoder_heads`` is the number of sequential ``decode`` passes, not a
    number of attention heads. Every pass reuses the same module-level
    weights.
    """
    hidden = embed(words_list)
    for _ in range(decoder_heads):
        hidden = decode(hidden)

    vocab_probs = (hidden @ wte.transpose(0, -1)).softmax(1)
    return vocab_probs.argmax(1)

# Demo: greedy prediction (one vocabulary index per position) for a
# three-word context.
model(
    ["I", "dream", "learning"]
)

tensor([3, 3, 3])