In [None]:
# Install the notebook's dependencies into the running kernel's environment
# (%pip, unlike !pip, installs into the interpreter the kernel is using).
%pip install -r requirements.txt

In [55]:
import sys
import pandas as pd
import einops
from dataclasses import dataclass
import torch as t
from torch import Tensor
import torch.nn as nn
import numpy as np
import math
from tqdm.notebook import tqdm
from typing import Tuple, List, Optional, Dict
from jaxtyping import Float, Int
from transformers.models.gpt2.tokenization_gpt2_fast import GPT2TokenizerFast
from collections import defaultdict
import datasets
from torch.utils.data import DataLoader
import wandb
from pathlib import Path
from transformer_lens import HookedTransformer
from transformer_lens.utils import gelu_new, tokenize_and_concatenate

In [None]:
@dataclass
class Config:
    '''Architecture hyperparameters for the from-scratch transformer (defaults match GPT-2 small).'''
    d_model: int = 768          # width of the residual stream
    debug: bool = True          # not read by any of the modules defined in this notebook
    layer_norm_eps: float = 1e-5  # added to the variance inside LayerNorm for numerical stability
    d_vocab: int = 50257        # number of token ids (rows of W_E, columns of W_U)
    init_range: float = 0.02    # std of the normal distribution used to initialise weight matrices
    n_ctx: int = 1024           # maximum sequence length (rows of the positional embedding)
    d_head: int = 64            # per-head query/key/value dimension
    d_mlp: int = 3072           # hidden width of the MLP sublayer
    n_heads: int = 12           # attention heads per block
    n_layers: int = 12          # number of transformer blocks

cfg = Config()

# Single device used for the model and every batch in later cells. Defined here so the
# notebook does not depend on a `device` variable left over from an earlier kernel session.
device = t.device("cuda" if t.cuda.is_available() else "cpu")

print(f"Here is the configuration for the model we will be building: {cfg}")

Here is the configuration for the model we will be building: Config(d_model=768, debug=True, layer_norm_eps=1e-05, d_vocab=50257, init_range=0.02, n_ctx=1024, d_head=64, d_mlp=3072, n_heads=12, n_layers=12)


In [None]:
class LayerNorm(nn.Module):
    '''
    Layer normalization over the last (d_model) dimension, followed by a learned
    elementwise scale `w` (initialised to ones) and shift `b` (initialised to zeros).
    '''
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.w = nn.Parameter(t.ones(cfg.d_model))
        self.b = nn.Parameter(t.zeros(cfg.d_model))

    def forward(self, residual: Float[Tensor, "batch posn d_model"]) -> Float[Tensor, "batch posn d_model"]:
        mean = t.mean(residual, dim=-1, keepdim=True)
        # Biased (population) variance: divide by d_model, not d_model - 1.
        var = t.var(residual, unbiased=False, dim=-1, keepdim=True)
        # Use this module's own config rather than the notebook-global `cfg`, so a model built
        # from a different Config normalizes with its own epsilon.
        eps = self.cfg.layer_norm_eps

        y = (residual - mean) / t.sqrt(var + eps)
        return (self.w * y) + self.b
    
class Embed(nn.Module):
    '''Token embedding: maps each token id to a learned d_model-dimensional row of W_E.'''
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        # Draw the (d_vocab, d_model) table from N(0, init_range^2) before wrapping it as a parameter.
        table = t.empty((cfg.d_vocab, cfg.d_model))
        nn.init.normal_(table, std=cfg.init_range)
        self.W_E = nn.Parameter(table)

    def forward(self, tokens: Int[Tensor, "batch position"]) -> Float[Tensor, "batch position d_model"]:
        # Integer-array indexing gathers one row per token id -> shape (*tokens.shape, d_model).
        embeddings = self.W_E[tokens]
        return embeddings
  
class PosEmbed(nn.Module):
    '''Learned absolute positional embedding: one d_model vector per position, up to n_ctx.'''
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        table = t.empty((cfg.n_ctx, cfg.d_model))
        nn.init.normal_(table, std=cfg.init_range)
        self.W_pos = nn.Parameter(table)

    def forward(self, tokens: Int[Tensor, "batch position"]) -> Float[Tensor, "batch position d_model"]:
        # Only the shape of `tokens` matters here: every sequence gets the same first
        # seq_len positional rows, copied once per batch element.
        batch_size, seq_len = tokens.shape
        positions = self.W_pos[:seq_len]
        return positions.unsqueeze(0).repeat(batch_size, 1, 1)

In [None]:
class Attention(nn.Module):
    '''
    Multi-head causal self-attention with separate per-head Q/K/V weights of shape
    (n_heads, d_model, d_head) and an output projection of shape (n_heads, d_head, d_model).
    '''
    IGNORE: Float[Tensor, ""]

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_Q = nn.Parameter(t.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        self.W_K = nn.Parameter(t.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        self.W_V = nn.Parameter(t.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        self.W_O = nn.Parameter(t.empty((cfg.n_heads, cfg.d_head, cfg.d_model)))
        self.b_Q = nn.Parameter(t.zeros((cfg.n_heads, cfg.d_head)))
        self.b_K = nn.Parameter(t.zeros((cfg.n_heads, cfg.d_head)))
        self.b_V = nn.Parameter(t.zeros((cfg.n_heads, cfg.d_head)))
        self.b_O = nn.Parameter(t.zeros((cfg.d_model)))
        nn.init.normal_(self.W_Q, std=self.cfg.init_range)
        nn.init.normal_(self.W_K, std=self.cfg.init_range)
        nn.init.normal_(self.W_V, std=self.cfg.init_range)
        nn.init.normal_(self.W_O, std=self.cfg.init_range)
        # Fill value for masked-out scores. As a registered buffer it moves with the module on
        # `.to(device)` and is part of the state dict, so it does not need to be created on a
        # notebook-global `device` here.
        self.register_buffer("IGNORE", t.tensor(-1e5, dtype=t.float32))

    def forward(
        self, normalized_resid_pre: Float[Tensor, "batch posn d_model"]
    ) -> Float[Tensor, "batch posn d_model"]:
        # Project the (already normalized) residual stream into per-head keys, queries and values.
        keys = einops.einsum(
            normalized_resid_pre, self.W_K,
            "batch posn d_model, nheads d_model d_head -> batch posn nheads d_head") + self.b_K

        queries = einops.einsum(
            normalized_resid_pre, self.W_Q,
            "batch posn d_model, nheads d_model d_head -> batch posn nheads d_head") + self.b_Q

        values = einops.einsum(
            normalized_resid_pre, self.W_V,
            "batch posn d_model, nheads d_model d_head -> batch posn nheads d_head") + self.b_V

        # Dot-product scores between every query position and every key position, per head.
        attn_scores = einops.einsum(
            queries, keys,
            "batch posn_Q nheads d_head, batch posn_K nheads d_head -> batch nheads posn_Q posn_K")

        # Scale by sqrt(d_head) from this module's own config (not the notebook-global `cfg`).
        attn_scores = attn_scores / self.cfg.d_head ** 0.5

        attn_scores = self.apply_causal_mask(attn_scores)

        # Normalize over key positions.
        attn_probs = t.softmax(attn_scores, dim=-1)

        weighted = einops.einsum(
            values, attn_probs,
            "batch posn_K nheads d_head, batch nheads posn_Q posn_K -> batch posn_Q nheads d_head",
        )

        # Project each head back to d_model and sum over heads.
        out = einops.einsum(
            weighted, self.W_O,
            "batch posn_Q nheads d_head, nheads d_head d_model -> batch posn_Q d_model",
        ) + self.b_O

        return out

    def apply_causal_mask(
        self, attn_scores: Float[Tensor, "batch n_heads query_pos key_pos"]
    ) -> Float[Tensor, "batch n_heads query_pos key_pos"]:
        '''
        Applies a causal mask to attention scores, and returns masked scores.

        Entries where key_pos > query_pos are replaced with `self.IGNORE`. Returns a new tensor;
        the input is not modified.
        '''
        all_ones = t.ones(attn_scores.size(-2), attn_scores.size(-1), device=attn_scores.device)
        # True strictly above the diagonal, i.e. where a query would attend to a future key.
        mask = t.triu(all_ones, diagonal=1).bool()
        # Out-of-place fill so callers passing their own tensor do not see it mutated.
        return attn_scores.masked_fill(mask, self.IGNORE)

In [None]:
class MLP(nn.Module):
    '''Position-wise feed-forward sublayer: d_model -> d_mlp (gelu_new) -> d_model.'''
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_in = nn.Parameter(t.empty((cfg.d_model, cfg.d_mlp)))
        self.W_out = nn.Parameter(t.empty((cfg.d_mlp, cfg.d_model)))
        self.b_in = nn.Parameter(t.zeros((cfg.d_mlp)))
        self.b_out = nn.Parameter(t.zeros((cfg.d_model)))
        # Initialise weights (in this order) from N(0, init_range^2); biases stay at zero.
        for weight in (self.W_in, self.W_out):
            nn.init.normal_(weight, std=cfg.init_range)

    def forward(
        self, normalized_resid_mid: Float[Tensor, "batch posn d_model"]
    ) -> Float[Tensor, "batch posn d_model"]:
        # Expand to the hidden width and apply the GPT-2 GELU approximation.
        hidden = gelu_new(
            einops.einsum(
                normalized_resid_mid, self.W_in,
                "batch position d_model, d_model d_mlp -> batch position d_mlp",
            ) + self.b_in
        )
        # Project back down to the residual stream width.
        return einops.einsum(
            hidden, self.W_out,
            "batch position d_mlp, d_mlp d_model -> batch position d_model",
        ) + self.b_out

In [None]:
class TransformerBlock(nn.Module):
    '''
    One pre-LayerNorm transformer block: attention and MLP sublayers, each reading a
    normalized copy of the residual stream and adding its output back into it.
    '''
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.ln1 = LayerNorm(cfg)
        self.attn = Attention(cfg)
        self.ln2 = LayerNorm(cfg)
        self.mlp = MLP(cfg)

    def forward(
        self, resid_pre: Float[Tensor, "batch position d_model"]
    ) -> Float[Tensor, "batch position d_model"]:
        attn_out = self.attn(self.ln1(resid_pre))
        resid_mid = attn_out + resid_pre

        mlp_out = self.mlp(self.ln2(resid_mid))
        return mlp_out + resid_mid

In [None]:
class Unembed(nn.Module):
    '''Maps the final (normalized) residual stream to logits over the vocabulary.'''
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_U = nn.Parameter(t.empty((cfg.d_model, cfg.d_vocab)))
        nn.init.normal_(self.W_U, std=self.cfg.init_range)
        # Zero-initialised, trainable bias. nn.Parameter defaults to requires_grad=True, so a
        # requires_grad flag on the wrapped tensor has no effect; to freeze this bias, pass
        # requires_grad=False to nn.Parameter itself.
        self.b_U = nn.Parameter(t.zeros((cfg.d_vocab)))

    def forward(
        self, normalized_resid_final: Float[Tensor, "batch position d_model"]
    ) -> Float[Tensor, "batch position d_vocab"]:
        logits = einops.einsum(
            normalized_resid_final, self.W_U,
            "batch position d_model, d_model d_vocab -> batch position d_vocab",
        )
        return logits + self.b_U

In [None]:
class Transformer(nn.Module):
    '''
    Decoder-only transformer: token + positional embeddings, a stack of n_layers blocks,
    a final LayerNorm, and an unembedding to vocabulary logits.
    '''
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        # Submodules are built in this order so parameter initialisation and state-dict
        # key order stay fixed.
        self.embed = Embed(cfg)
        self.pos_embed = PosEmbed(cfg)
        self.blocks = nn.ModuleList(TransformerBlock(cfg) for _ in range(cfg.n_layers))
        self.ln_final = LayerNorm(cfg)
        self.unembed = Unembed(cfg)

    def forward(self, tokens: Int[Tensor, "batch position"]) -> Float[Tensor, "batch position d_vocab"]:
        residual = self.embed(tokens) + self.pos_embed(tokens)
        for block in self.blocks:
            residual = block(residual)
        return self.unembed(self.ln_final(residual))

In [None]:
# Pretrained GPT-2 small from TransformerLens, loaded with LayerNorm folding and weight
# centering disabled so its parameters can be copied into the Transformer defined above.
reference_gpt2 = HookedTransformer.from_pretrained(
    "gpt2-small",
    fold_ln=False,
    center_unembed=False,
    center_writing_weights=False,
)

In [None]:
demo_gpt2 = Transformer(Config(debug=False)).to(device)

# strict=False tolerates entries present on only one side (presumably extra buffers in the
# TransformerLens model). Report anything our model expected but did not receive, so a
# naming mismatch cannot silently leave weights at their random initialisation.
load_result = demo_gpt2.load_state_dict(reference_gpt2.state_dict(), strict=False)
if load_result.missing_keys:
    print(f"Warning: parameters not loaded from reference_gpt2: {load_result.missing_keys}")

# Create the prompt on the model's device; otherwise the forward pass fails on GPU.
tokens = t.tensor([reference_gpt2.tokenizer.encode("Hello, my name is")], device=device)

with t.inference_mode():
    demo_logits = demo_gpt2(tokens)
demo_logits.shape

In [None]:
@dataclass
class TransformerTrainingArgs:
    '''
    Hyperparameters for TransformerTrainer.

    Every attribute carries a type annotation so that @dataclass treats it as a field: it can
    be overridden in the constructor (e.g. TransformerTrainingArgs(lr=3e-4)) and appears in
    the instance's __dict__, which is what wandb.init records as the run config.
    '''
    batch_size: int = 16
    epochs: int = 20
    max_steps_per_epoch: int = 200  # cap on optimizer steps taken in each epoch
    lr: float = 1e-3
    weight_decay: float = 1e-2
    wandb_project: Optional[str] = "jackson-transformer"
    wandb_name: Optional[str] = None  # passed to wandb.init as the run name

In [None]:
# 10k-document subset of the Pile, used as the training corpus.
dataset = datasets.load_dataset("NeelNanda/pile-10k", split="train").remove_columns("meta")

# Tokenize with the GPT-2 tokenizer and pack into sequences of at most n_ctx tokens (BOS prepended).
tokenized_dataset = tokenize_and_concatenate(dataset, reference_gpt2.tokenizer, streaming=False, max_length=cfg.n_ctx, column_name="text", add_bos_token=True, num_proc=4)

# Hold out 1000 sequences for measuring validation accuracy.
dataset_dict = tokenized_dataset.train_test_split(test_size=1000)

# `args` is only created in the training cell at the end of the notebook, so read the default
# batch size from the args class; this keeps Restart & Run All working on a fresh kernel.
default_batch_size = TransformerTrainingArgs.batch_size
train_loader = DataLoader(dataset_dict["train"], batch_size=default_batch_size, shuffle=True, num_workers=4, pin_memory=True)
test_loader = DataLoader(dataset_dict["test"], batch_size=default_batch_size, shuffle=False, num_workers=4, pin_memory=True)

In [49]:
class TransformerTrainer:
    '''
    Training loop for `Transformer` with AdamW, logging loss and validation accuracy to
    Weights & Biases.

    Reads the notebook-level `dataset_dict` (train/test splits) and `device`.
    '''

    def __init__(self, args: TransformerTrainingArgs, model: Transformer):
        # Batches are moved to `device` in the step methods, so the model must live there too.
        self.model = model.to(device)
        self.args = args
        self.optimizer = t.optim.AdamW(self.model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        self.step = 0

    @staticmethod
    def _next_token_log_probs(
        logits: Float[Tensor, "batch seq d_vocab"], tokens: Int[Tensor, "batch seq"]
    ) -> Float[Tensor, "batch next_pos"]:
        '''Log-probability assigned to each actual next token (logits at position i predict token i+1).'''
        # The last position has no target, so drop it before normalizing.
        log_probs = logits[:, :-1].log_softmax(dim=-1)
        targets = tokens[:, 1:].unsqueeze(-1)
        return log_probs.gather(dim=-1, index=targets).squeeze(-1)

    def training_step(self, batch: Dict[str, Int[Tensor, "batch seq"]]) -> Float[Tensor, ""]:
        '''
        Calculates the loss on the tokens in the batch, performs a gradient update step, and logs the loss.

        `batch` is a dictionary with the single key 'tokens'. Returns the scalar loss tensor.
        '''
        tokens = batch['tokens'].to(device)
        logits = self.model(tokens)
        # Next-token cross-entropy = mean negative log-prob of the true next token.
        loss = -self._next_token_log_probs(logits, tokens).mean()
        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()  # clear gradients so they don't accumulate into the next step
        self.step += 1
        # Log a plain float rather than a graph-attached tensor.
        wandb.log({"train_loss": loss.item()}, step=self.step)
        return loss

    @t.inference_mode()
    def validation_step(self, batch: Dict[str, Int[Tensor, "batch seq"]]):
        '''
        Returns a flat boolean tensor marking, for each predicted position in the batch, whether
        the model's argmax prediction equals the actual next token. Runs without autograd.
        Accuracy is computed and logged in `train` over the whole validation set.
        '''
        tokens = batch['tokens'].to(device)
        # Exclude the final position: it has no next token to compare against.
        logits = self.model(tokens)[:, :-1]
        predictions = logits.argmax(dim=-1)
        return (predictions == tokens[:, 1:]).flatten()

    def train(self):
        '''
        Trains the model for `self.args.epochs` epochs, taking at most
        `self.args.max_steps_per_epoch` optimizer steps per epoch and evaluating accuracy on the
        test split after each epoch. Handles wandb initialisation and always closes the progress
        bar and wandb run, even if training raises.
        '''
        wandb.init(project=self.args.wandb_project, name=self.args.wandb_name, config=self.args)
        accuracy = np.nan
        progress_bar = tqdm(total=self.args.max_steps_per_epoch * self.args.epochs)

        try:
            for epoch in range(self.args.epochs):
                for i, batch in enumerate(self.train_loader()):
                    # Check the cap before stepping so an epoch takes exactly
                    # `max_steps_per_epoch` steps (or fewer if the loader runs out).
                    if i >= self.args.max_steps_per_epoch:
                        break
                    loss = self.training_step(batch)
                    progress_bar.update()
                    progress_bar.set_description(f"Epoch {epoch+1}, loss: {loss.item():.3f}, accuracy: {accuracy:.2f}")

                correct_predictions = t.concat([self.validation_step(batch) for batch in self.test_loader()])
                accuracy = correct_predictions.float().mean().item()
                wandb.log({"accuracy": accuracy}, step=self.step)
        finally:
            progress_bar.close()
            wandb.finish()

    def train_loader(self) -> DataLoader:
        '''Returns a shuffled DataLoader over the train split.'''
        return DataLoader(dataset_dict["train"], batch_size=self.args.batch_size, shuffle=True, num_workers=4, pin_memory=True)

    def test_loader(self) -> DataLoader:
        '''Returns an unshuffled DataLoader over the test split.'''
        return DataLoader(dataset_dict["test"], batch_size=self.args.batch_size, shuffle=False, num_workers=4, pin_memory=True)

In [None]:
# Train a freshly initialised model (not the pretrained weights loaded into demo_gpt2).
# `DemoTransformer` is not defined in this notebook; the model class is `Transformer`.
model = Transformer(cfg).to(device)
args = TransformerTrainingArgs()
trainer = TransformerTrainer(args, model)
trainer.train()