# Chapter 3 – Attention Mechanisms
In this part I implement a causal self-attention class and a multi-head self-attention class, following Chapter 3 of the book, and prepare a sample of word embeddings to try them on.

In [10]:
import torch
import torch.nn as nn
import torch.optim as optim
import json

In [None]:
# Fix PyTorch's global RNG seed so weight initialisation and dropout are reproducible.
# The trailing ';' suppresses the Generator object's repr in the cell output.
SEED = 123
torch.manual_seed(SEED);

<torch._C.Generator at 0x10db24ef0>

In [14]:
# Pick the fastest available backend: CUDA GPU, then Apple's Metal (MPS) backend, then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"Success: Using CUDA GPU acceleration ({torch.cuda.get_device_name(0)})")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
    print("Success: Using Apple MPS GPU acceleration")
else:
    device = torch.device("cpu")
    print("Using CPU (Slow)")

Success: Using Apple M2 GPU Acceleration


In [4]:
class CausalAttention(nn.Module):
    """Single-head causal (masked) self-attention with dropout on the attention weights.

    Args:
        d_in: size of the last dimension of the input.
        d_out: size of the query/key/value projections and of the output.
        context_length: maximum number of tokens covered by the causal mask.
        dropout: dropout probability applied to the attention weights.
        qkv_bias: whether the query/key/value projections use a bias term.
    """

    def __init__(self, d_in, d_out, context_length,
                 dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.context_length = context_length
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        # Ones strictly above the diagonal mark the "future" positions a query
        # must not attend to. Registered as a buffer so it moves with .to(device)
        # and is saved in the state_dict, but is not a trainable parameter.
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: tensor of shape (batch, num_tokens, d_in).

        Returns:
            Tensor of shape (batch, num_tokens, d_out).

        Raises:
            ValueError: if num_tokens exceeds context_length, because the mask
                is only defined up to context_length positions.
        """
        b, num_tokens, d_in = x.shape
        if num_tokens > self.context_length:
            raise ValueError(
                f"Input has {num_tokens} tokens but context_length is {self.context_length}"
            )

        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # (b, T, d_out) @ (b, d_out, T) -> (b, T, T); transpose only the last two
        # dims so the batch dimension is preserved.
        attn_scores = queries @ keys.transpose(1, 2)
        # Slice the mask to the actual sequence length so shorter inputs work.
        # Masked positions become -inf and therefore get zero weight after softmax.
        attn_scores.masked_fill_(
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        # Scaled dot-product attention: divide by sqrt of the key dimension.
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )
        attn_weights = self.dropout(attn_weights)

        context_vec = attn_weights @ values
        return context_vec



In [5]:

# Standalone multi-head attention with "weight split": one d_in -> d_out projection
# each for Q/K/V, reshaped into num_heads heads of size d_out // num_heads.
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with dropout on the attention weights.

    The per-head context vectors are concatenated back to ``d_out``; no output
    projection is applied after concatenation.

    Args:
        d_in: size of the last dimension of the input.
        d_out: total output size; must be divisible by ``num_heads``.
        context_length: maximum number of tokens covered by the causal mask.
        dropout: dropout probability applied to the attention weights.
        num_heads: number of attention heads.
        qkv_bias: whether the query/key/value projections use a bias term.
        hidden_dim: unused; accepted only so existing call sites keep working.
    """

    def __init__(
        self,
        d_in,
        d_out,
        context_length,
        dropout,
        num_heads,
        qkv_bias=False,
        hidden_dim=None,
    ):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        self.d_out = d_out
        self.num_heads = num_heads
        self.d_head = d_out // num_heads

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        self.context_length = context_length
        # Boolean causal mask: True strictly above the diagonal marks the future
        # positions to hide. A buffer follows the module across devices.
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1).bool(),
        )

    def _split_heads(self, t, b, num_tokens):
        """Reshape (b, num_tokens, d_out) into (b, num_heads, num_tokens, d_head)."""
        return t.view(b, num_tokens, self.num_heads, self.d_head).transpose(1, 2)

    def forward(self, x):
        """Apply causal multi-head self-attention.

        Args:
            x: tensor of shape (batch, num_tokens, d_in).

        Returns:
            Tensor of shape (batch, num_tokens, d_out).

        Raises:
            ValueError: if num_tokens exceeds context_length, because the mask
                is only defined up to context_length positions.
        """
        b, num_tokens, _ = x.shape
        if num_tokens > self.context_length:
            raise ValueError(
                f"Input has {num_tokens} tokens but context_length is {self.context_length}"
            )

        # Linear projections, split into heads.
        queries = self._split_heads(self.W_query(x), b, num_tokens)
        keys = self._split_heads(self.W_key(x), b, num_tokens)
        values = self._split_heads(self.W_value(x), b, num_tokens)

        # Per-head attention scores: (b, h, T, d_head) @ (b, h, d_head, T) -> (b, h, T, T).
        scores = queries @ keys.transpose(-2, -1)
        scores.masked_fill_(self.mask[:num_tokens, :num_tokens], float("-inf"))
        weights = nn.functional.softmax(scores / self.d_head**0.5, dim=-1)
        weights = self.dropout(weights)

        # Weighted sum of values, then merge heads:
        # (b, h, T, d_head) -> (b, T, h, d_head) -> (b, T, d_out).
        context_vec = weights @ values
        return context_vec.transpose(1, 2).contiguous().view(b, num_tokens, self.d_out)

Check with a sample of the data: load the tokenized text, build the vocabulary, and pick a 10-token sample sentence.

In [13]:
# Load the tokenized corpus; `with` closes the file handle deterministically.
# NOTE(review): assumed to be a flat JSON list of token strings — confirm.
with open("output/text_stanza.json", "r", encoding="utf-8") as f:
    text = json.load(f)

# NOTE(review): these indices must match the vocabulary ordering used when the
# word2vec model was trained (assumed to be sorted(set(text))) — confirm.
vocab = sorted(set(text))

# A 10-token window of the corpus used as the demo sentence.
SAMPLE_SLICE = slice(760, 770)
sample_sentence = text[SAMPLE_SLICE]
print(sample_sentence)

word2idx = {token: idx for idx, token in enumerate(vocab)}
idx2word = dict(enumerate(vocab))

['ירדף', 'אויב', 'נפשי', 'ו', 'ישג', 'ו', 'ירמס', 'ל', 'ארץ', 'חיי']


Load the pretrained word2vec embeddings, keeping the 2D weight matrix whose rows match the vocabulary size.

In [15]:
EMBEDDING_PATH = "output/word2vec_model_stanza.pth"

# state_dict keys to try first, in order of preference.
PREFERRED_EMBEDDING_KEYS = (
    "in_embed.weight",
    "out_embed.weight",
    "embeddings.weight",
    "embedding.weight",
    "encoder.weight",
)


def find_embedding_tensor(state, vocab_size, preferred_keys=PREFERRED_EMBEDDING_KEYS):
    """Return the embedding matrix stored in a loaded state_dict, or None.

    A candidate is any 2D ``torch.Tensor`` whose first dimension equals
    ``vocab_size``. Keys in ``preferred_keys`` are tried first, in order. If none
    of them match, the first matching tensor in ``state``'s iteration order is
    returned. Non-dict ``state`` yields None.
    """
    if not isinstance(state, dict):
        return None

    def is_embedding(value):
        return (
            isinstance(value, torch.Tensor)
            and value.dim() == 2
            and value.size(0) == vocab_size
        )

    for key in preferred_keys:
        if is_embedding(state.get(key)):
            return state[key]
    # Fallback: first 2D tensor whose first dim matches the vocabulary size.
    return next((v for v in state.values() if is_embedding(v)), None)


# weights_only=True limits unpickling to tensors and plain containers, which is
# all a state_dict needs, and avoids executing arbitrary pickled code.
state = torch.load(EMBEDDING_PATH, map_location=device, weights_only=True)
VOCAB_SIZE = len(vocab)
emb_tensor = find_embedding_tensor(state, VOCAB_SIZE)

if emb_tensor is None:
    # Report what *was* found so a vocab/model mismatch is easy to diagnose.
    found = (
        {k: tuple(v.shape) for k, v in state.items() if isinstance(v, torch.Tensor)}
        if isinstance(state, dict)
        else type(state).__name__
    )
    raise RuntimeError(
        "Couldn't find a 2D embedding tensor matching VOCAB_SIZE in the saved state_dict. "
        f"VOCAB_SIZE={VOCAB_SIZE}; tensors found: {found}"
    )
emb_tensor = emb_tensor.to(device)


Look up the embedding of each token in the sample sentence.

In [16]:
# Gather the embedding row of every sample token in a single indexing op.
# Out-of-vocabulary tokens are skipped, but reported, because dropping them
# silently would shorten the sequence without warning.
known_tokens = [t for t in sample_sentence if t in word2idx]
missing_tokens = [t for t in sample_sentence if t not in word2idx]
if missing_tokens:
    print(f"Skipping {len(missing_tokens)} token(s) not in vocab: {missing_tokens}")
if not known_tokens:
    raise RuntimeError("No embeddings found for sample_sentence tokens.")

token_ids = torch.tensor([word2idx[t] for t in known_tokens], device=emb_tensor.device)
sample_embeddings = emb_tensor[token_ids].to(device)
print(f"word embeddings for sample sentence: {sample_embeddings.shape}")

word embeddings for sample sentence: torch.Size([10, 100])
