### This notebook will implement a causal mask in the MultiHeadAttention class

In [3]:
import torch
from torch import nn

In [71]:
class MultiHeadAttention(nn.Module):
    # TODO: Add causal and padding mask functionality
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False, causal_mask = False, padding_mask = False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = (
            d_out // num_heads
        )  # Reduce the projection dim to match desired output dim

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
        self.causal_mask = causal_mask
        self.padding_mask = padding_mask
        self.register_buffer(
            "mask", torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape

        keys = self.W_key(x)  # Shape: (b, num_tokens, d_out)
        queries = self.W_query(x)
        values = self.W_value(x)

        # We implicitly split the matrix by adding a `num_heads` dimension
        # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Compute scaled dot-product attention (aka self-attention) with a causal mask
        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head

        if self.causal_mask:
            # Original mask truncated to the number of tokens and converted to boolean
            mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

            # Use the mask to fill attention scores
            attn_scores.masked_fill_(mask_bool, -torch.inf)
            print(attn_scores)

        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Shape: (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)

        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)  # optional projection

        return context_vec

In [83]:
d_in = 768
d_out = 768 # change this later
context_length = 10
dropout = 0.0
num_heads = 6
qkv_bias = False
causal_mask = True



In [110]:
from pathlib import Path
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
lang = "en"
tokenizer_path = Path(f"tokenizers/tokenizer_{lang}.json")
tokenizer = Tokenizer.from_file(str(tokenizer_path))


In [111]:
tokenizer.encode("[PAD]").ids

[1]

In [105]:
tokenizer

Tokenizer(version="1.0", truncation=None, padding=None, added_tokens=[{"id":0, "content":"[UNK]", "single_word":False, "lstrip":False, "rstrip":False, "normalized":False, "special":True}, {"id":1, "content":"[PAD]", "single_word":False, "lstrip":False, "rstrip":False, "normalized":False, "special":True}, {"id":2, "content":"[SOS]", "single_word":False, "lstrip":False, "rstrip":False, "normalized":False, "special":True}, {"id":3, "content":"[EOS]", "single_word":False, "lstrip":False, "rstrip":False, "normalized":False, "special":True}], normalizer=None, pre_tokenizer=Whitespace(), post_processor=None, decoder=None, model=WordLevel(vocab={"[UNK]":0, "[PAD]":1, "[SOS]":2, "[EOS]":3, ",":4, ".":5, "e":6, "di":7, "che":8, "—":9, "’":10, "la":11, "non":12, "a":13, "il":14, "un":15, "in":16, "per":17, "si":18, ";":19, "con":20, "una":21, "era":22, "le":23, "l":24, "mi":25, "ma":26, "è":27, "da":28, "'":29, "?":30, "del":31, "i":32, "come":33, "più":34, "della":35, "lo":36, "disse":37, "gli":

In [94]:
import tiktoken
tokenizer = tiktoken.get_encoding("gpt2")
pad_token = tokenizer.encode("<pad>")
pad_token

[27, 15636, 29]

In [95]:
print(pad_token)

[27, 15636, 29]


In [88]:
mod = MultiHeadAttention(d_in, d_out, context_length, dropout, num_heads, causal_mask=causal_mask)

In [89]:
x = torch.rand(3,10,d_in) # I think d_in is the dimension of the embedding? 

In [90]:
test = mod(x)

tensor([[[[-0.2853,    -inf,    -inf,  ...,    -inf,    -inf,    -inf],
          [ 0.1148,  0.8548,    -inf,  ...,    -inf,    -inf,    -inf],
          [-0.3284,  0.1110,  0.5591,  ...,    -inf,    -inf,    -inf],
          ...,
          [ 0.2573,  0.4953,  0.6885,  ..., -0.6523,    -inf,    -inf],
          [-0.1536,  0.2995,  0.2587,  ..., -1.3150,  0.1155,    -inf],
          [-1.2021, -0.6269, -0.3970,  ..., -1.7746, -0.3770, -1.1426]],

         [[-1.4818,    -inf,    -inf,  ...,    -inf,    -inf,    -inf],
          [-1.2557, -1.7559,    -inf,  ...,    -inf,    -inf,    -inf],
          [-1.1673, -0.1226, -1.0296,  ...,    -inf,    -inf,    -inf],
          ...,
          [-1.0485, -1.3166, -1.3198,  ..., -1.8844,    -inf,    -inf],
          [-1.1410, -0.0911, -0.8937,  ..., -2.1074, -0.1804,    -inf],
          [-2.0097, -1.2315, -1.4400,  ..., -2.9487, -1.2072, -1.0436]],

         [[-0.4530,    -inf,    -inf,  ...,    -inf,    -inf,    -inf],
          [-1.9940, -0.0998,  

In [51]:
keys = mod.W_key(x)
queries = mod.W_query(x)
values = mod.W_value(x)

# We implicitly split the matrix by adding a `num_heads` dimension
        # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
keys = keys.view(b, num_tokens, mod.num_heads, mod.head_dim)
values = values.view(b, num_tokens, mod.num_heads, mod.head_dim)
queries = queries.view(b, num_tokens, mod.num_heads, mod.head_dim)

keys = keys.transpose(1, 2)
queries = queries.transpose(1, 2)
values = values.transpose(1, 2)

In [57]:
mod(x)[0]

tensor([[ 0.1161, -0.0640,  0.1574,  ...,  0.2085,  0.1012, -0.1956],
        [ 0.1173, -0.0626,  0.1569,  ...,  0.2088,  0.1014, -0.1943],
        [ 0.1169, -0.0623,  0.1561,  ...,  0.2087,  0.1001, -0.1940],
        ...,
        [ 0.1164, -0.0620,  0.1552,  ...,  0.2089,  0.1013, -0.1949],
        [ 0.1185, -0.0630,  0.1575,  ...,  0.2050,  0.1004, -0.1913],
        [ 0.1157, -0.0629,  0.1573,  ...,  0.2075,  0.1015, -0.1941]],
       grad_fn=<SelectBackward0>)

In [53]:
keys

tensor([[[[ 3.9036e-01, -6.1331e-01,  5.6162e-01,  ...,  5.3590e-01,
           -1.5826e-01, -3.4378e-01],
          [ 3.8729e-01, -4.0058e-01,  2.9023e-01,  ...,  4.3854e-01,
           -1.0273e-01, -7.6853e-01],
          [ 4.0646e-01, -5.2343e-01,  3.4935e-01,  ...,  3.8613e-01,
           -2.0635e-01, -6.1847e-01],
          ...,
          [ 4.2612e-01, -3.8096e-01,  1.6191e-01,  ...,  7.0502e-01,
           -3.6284e-01, -5.4460e-01],
          [ 7.0977e-01, -7.2567e-01,  4.4282e-01,  ...,  3.9433e-01,
           -4.5750e-01, -3.1949e-01],
          [ 7.6468e-01, -6.6319e-01,  6.3531e-01,  ...,  3.7257e-01,
            1.2986e-02, -4.6432e-01]],

         [[-1.0813e-01, -1.7516e-01,  3.0420e-01,  ...,  4.0926e-01,
           -1.2339e-01, -2.7586e-01],
          [ 4.2742e-01, -2.7741e-01,  7.5614e-02,  ...,  5.8028e-01,
           -1.4355e-01,  1.4073e-01],
          [ 3.8670e-01,  9.4920e-02,  4.1481e-01,  ...,  2.0570e-01,
            5.8262e-03,  6.7882e-02],
          ...,
     