# Messing with Attention

Just trying some basic tensor operations with PyTorch's attention implementation.

https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html

See also the transformer pseudocode in "Formal Algorithms for Transformers": https://arxiv.org/abs/2207.09238

Other links:

 - [nanoGPT](https://github.com/karpathy/nanoGPT)
 - [minGPT](https://github.com/karpathy/minGPT)

In [1]:
import math
from dataclasses import dataclass

import torch
import torch.nn as nn
from torch.nn import functional as F

In [2]:
# Lookup table holding 9 embedding vectors of 768 dimensions each.
e = torch.nn.Embedding(
    num_embeddings=9,
    embedding_dim=768,
)
e  # last expression of the cell: displays the module repr

Embedding(9, 768)

In [3]:
# Multi-head attention over 768-dimensional inputs, split across 2 heads.
attn = torch.nn.MultiheadAttention(
    embed_dim=768,
    num_heads=2,
)
attn  # last expression of the cell: displays the module repr

MultiheadAttention(
  (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
)

In [4]:
# Embedding a single index yields one 768-dimensional row.
e(torch.IntTensor([1])).size()

torch.Size([1, 768])

In [5]:
# Two indices in, two rows out: the leading dimension follows the number of indices.
e(torch.IntTensor([1, 2])).size()

torch.Size([2, 768])

In [6]:
# Indices need not be contiguous; 8 is the largest valid index for a 9-entry table.
e(torch.IntTensor([0, 5, 8])).size()

torch.Size([3, 768])

In [7]:
# Embed a 3-token sequence; x is the (3, 768) tensor fed to attention in the next cell.
input_sequence = torch.IntTensor([0, 5, 8])
x = e(input_sequence)
x.size()

torch.Size([3, 768])

In [8]:
# Self-attention: the same embedded sequence is used as query, key and value.
q = x
k = x
v = x
# x is 2-D (3, 768), so it is passed as a single unbatched sequence of 3 tokens.
attn_outputs, attn_output_weights = attn(q, k, v)
# The output keeps the input shape; the weights are (3, 3), one row per query position
# (per the PyTorch docs the weights are averaged over the heads by default).
attn_outputs.size(), attn_output_weights.size()

(torch.Size([3, 768]), torch.Size([3, 3]))

In [9]:
# NOTE(review): only the last expression of a cell is displayed, so the size computed on the
# next line is discarded and never shown — move it to its own cell if it is meant to be seen.
attn.out_proj.weight.size()
attn

MultiheadAttention(
  (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
)

In [10]:
# List the instance attributes; parameters are stored inside the '_parameters' entry,
# so in_proj_weight does not appear as a key of its own.
attn.__dict__.keys()

dict_keys(['training', '_parameters', '_buffers', '_non_persistent_buffers_set', '_backward_pre_hooks', '_backward_hooks', '_is_full_backward_hook', '_forward_hooks', '_forward_hooks_with_kwargs', '_forward_pre_hooks', '_forward_pre_hooks_with_kwargs', '_state_dict_hooks', '_state_dict_pre_hooks', '_load_state_dict_pre_hooks', '_load_state_dict_post_hooks', '_modules', 'embed_dim', 'kdim', 'vdim', '_qkv_same_embed_dim', 'num_heads', 'dropout', 'batch_first', 'head_dim', 'bias_k', 'bias_v', 'add_zero_attn'])

In [11]:
# Look up the input projection weight through the private parameter dict.
attn._parameters["in_proj_weight"].size()

torch.Size([2304, 768])

In [12]:
# 2304 rows = 3 x 768: the query, key and value projections are stacked in one matrix.
2304 / 768

3.0

In [13]:
# Same tensor as attn._parameters["in_proj_weight"], reached through normal attribute access.
attn.in_proj_weight.size()

torch.Size([2304, 768])

In [14]:
# head_dim matches embed_dim // num_heads (768 // 2).
attn.head_dim, 768 // 2

(384, 384)

### Decoder-only transformer implementation from Andrej Karpathy

Borrowed and tweaked from here: 
https://colab.research.google.com/drive/1SiF0KZJp75rUeetKOWqpsA8clmHP6jMg?usp=sharing#scrollTo=wW1-8xqswRYg

I suspect this is from minGPT or nanoGPT, but I didn't check. (MIT license)

In [15]:
class CausalSelfAttention(nn.Module):
    """Multi-head self-attention in which each position attends only to itself and earlier positions.

    ``config`` must provide ``n_embd``, ``n_head``, ``bias`` and ``block_size``.
    The input and output of ``forward`` are both shaped (B, T, n_embd).
    """

    def __init__(self, config):
        super().__init__()
        # nn.MultiheadAttention enforces the same constraint and reports it as
        # "embed_dim must be divisible by num_heads".
        assert config.n_embd % config.n_head == 0, "Heads need to cleanly divide into n_embd"

        # One fused linear layer produces query, key and value for every head at once.
        # nn.MultiheadAttention does the same with in_proj_weight when q, k and v share a dimension.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # Projection applied after the heads have been concatenated again.
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)

        self.n_head = config.n_head
        self.n_embd = config.n_embd

        # Lower-triangular matrix of ones, stored with two leading singleton axes so that it
        # broadcasts over (batch, head). Despite its name, "bias" is the causal mask.
        lower_triangle = torch.ones(config.block_size, config.block_size).tril()
        self.register_buffer("bias", lower_triangle[None, None, :, :])

    def forward(self, x):
        n_batch, n_tokens, width = x.size()  # (B, T, C) with C == n_embd
        head_size = width // self.n_head

        def to_heads(t):
            # (B, T, C) -> (B, nh, T, hs): the head axis moves next to the batch axis
            return t.view(n_batch, n_tokens, self.n_head, head_size).transpose(1, 2)

        q, k, v = (to_heads(part) for part in self.c_attn(x).split(self.n_embd, dim=2))

        # Scaled dot-product scores, shape (B, nh, T, T).
        scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        # Positions above the diagonal are in the future: -inf becomes weight 0 after softmax.
        is_future = self.bias[:, :, :n_tokens, :n_tokens] == 0
        weights = F.softmax(scores.masked_fill(is_future, float("-inf")), dim=-1)

        per_head = weights @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        # Put the heads back side by side: (B, nh, T, hs) -> (B, T, C).
        merged = per_head.transpose(1, 2).contiguous().view(n_batch, n_tokens, width)

        return self.c_proj(merged)


class MLP(nn.Module):
    """Feed-forward sub-layer: widen to 4 * n_embd, apply GELU, project back to n_embd.

    ``config`` must provide ``n_embd`` and ``bias``; ``bias`` switches the bias term of
    both linear layers on or off.
    """

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        hidden = 4 * width
        self.c_fc = nn.Linear(width, hidden, bias=config.bias)
        self.c_proj = nn.Linear(hidden, width, bias=config.bias)
        self.nonlin = nn.GELU()

    def forward(self, x):
        # expand -> non-linearity -> contract; the last dimension of x is n_embd on both sides
        return self.c_proj(self.nonlin(self.c_fc(x)))


class TransformerBlock(nn.Module):
    """One transformer layer: causal self-attention followed by an MLP.

    Each sub-layer sees a layer-normalised copy of its input, and its output is added
    back onto the un-normalised input (residual connection).
    """

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        self.ln_1 = nn.LayerNorm(width)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(width)
        self.mlp = MLP(config)

    def forward(self, x):
        after_attention = x + self.attn(self.ln_1(x))
        return after_attention + self.mlp(self.ln_2(after_attention))


@dataclass
class GPTConfig:
    # Hyperparameters read by GPT and its sub-modules; the defaults describe a GPT-2-sized model.
    # NOTE(review): stock GPT-2 is usually described with a 50257-token vocabulary and biased
    # linear layers, so vocab_size=50304 and bias=False look like deliberate tweaks — confirm
    # before treating these as the literal GPT-2 defaults.
    block_size: int = 1024  # maximum sequence length: size of the position embedding and of the causal mask
    vocab_size: int = 50304  # number of token ids: size of the token embedding and of the lm_head output
    n_layer: int = 12  # number of TransformerBlock instances stacked in GPT
    n_head: int = 12  # attention heads per block; must divide n_embd evenly
    n_embd: int = 768  # embedding width used throughout the model
    bias: bool = False  # passed as bias= to the nn.Linear layers in CausalSelfAttention and MLP only


class GPT(nn.Module):
    """Decoder-only transformer language model.

    ``forward`` takes token ids shaped (b, t) and returns the logits for the token that
    follows the last position only, shaped (b, vocab_size).
    """

    def __init__(self, config):
        super().__init__()
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config

        width = config.n_embd
        self.transformer = nn.ModuleDict(
            {
                "wte": nn.Embedding(config.vocab_size, width),  # token embeddings
                "wpe": nn.Embedding(config.block_size, width),  # learned position embeddings
                "h": nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layer)]),
                "ln_f": nn.LayerNorm(width),  # final layer norm before the output head
            }
        )
        self.lm_head = nn.Linear(width, config.vocab_size, bias=False)
        # Share one matrix between the input embedding and the output head.
        self.transformer.wte.weight = self.lm_head.weight  # https://paperswithcode.com/method/weight-tying

        # Default initialisation for every Linear / Embedding in the tree ...
        self.apply(self._init_weights)
        # ... then the scaled initialisation for the residual projections, per GPT-2 paper.
        residual_std = 0.02 / math.sqrt(2 * config.n_layer)
        for param_name, param in self.named_parameters():
            if param_name.endswith("c_proj.weight"):
                torch.nn.init.normal_(param, mean=0.0, std=residual_std)

        # Report the model size.
        n_params = sum(p.numel() for p in self.parameters())
        print("number of parameters: %d" % (n_params,))

    def _init_weights(self, module):
        """Draw Linear and Embedding weights from N(0, 0.02) and zero any Linear bias."""
        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            torch.nn.init.zeros_(module.bias)

    def forward(self, idx):
        b, t = idx.size()
        assert (
            t <= self.config.block_size
        ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"

        # Position ids 0..t-1 with a leading axis so they broadcast over the batch: shape (1, t).
        positions = torch.arange(t, dtype=torch.long, device=idx.device)[None, :]

        # Token embeddings (b, t, n_embd) plus position embeddings (1, t, n_embd).
        hidden = self.transformer.wte(idx) + self.transformer.wpe(positions)
        for layer in self.transformer.h:
            hidden = layer(hidden)
        hidden = self.transformer.ln_f(hidden)

        # Only the last time step is projected to the vocabulary, so the result is 2D (b, vocab_size).
        return self.lm_head(hidden[:, -1, :])

In [78]:
# Seed the RNG before the model is built so the weight initialisation is repeatable.
torch.manual_seed(619)

# vocab size is 2, so we only have two possible tokens: 0,1
vocab_size = 2
# context length is 3, so we take 3 bits to predict the next bit probability
context_length = 3
# NOTE(review): vocab_size and context_length above are never passed to GPTConfig; the model
# below is built with vocab_size=4 and block_size=16 instead — confirm which values are intended.
config = GPTConfig(
    block_size=16,
    vocab_size=4,
    n_layer=4,
    n_head=4,
    n_embd=16,
    bias=False,
)
gpt = GPT(config)
# AdamW over every model parameter.
optimizer = torch.optim.AdamW(gpt.parameters(), lr=1e-3, weight_decay=1e-1)

NameError: name 'config' is not defined

In [None]:
# Run 50 optimisation steps, feeding the same (X, Y) pair on every iteration.
# NOTE(review): X and Y are not defined in the visible cells — presumably input token ids and
# the matching next-token targets; confirm they exist before running this cell.
for i in range(50):
    logits = gpt(X)  # GPT.forward returns last-position logits, shape (b, vocab_size)
    loss = F.cross_entropy(logits, Y)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()  # clear gradients so they do not accumulate into the next step
    print(i, loss.item())