In [337]:
import numpy as np
#!pip install transformers
# !pip install transformer_lens
import torch as t
import torch.nn as nn
import einops
import math
from transformers import GPT2Tokenizer
from tqdm import tqdm
import datasets
from torch.utils.data import DataLoader
from transformer_lens import HookedTransformer
from transformer_lens.utils import tokenize_and_concatenate
import wandb
import json
from typing import Tuple, List, Optional, Dict
from jaxtyping import Float, Int
from torch import Tensor

In [338]:
# Prefer the GPU when PyTorch can see one; otherwise everything runs on the CPU.
if t.cuda.is_available():
    device = t.device("cuda")
else:
    device = t.device("cpu")
device

device(type='cpu')

In [339]:
# Large model hyperparameters, read by the module classes through cfg["..."] lookups.
# The models actually instantiated further down use `train_model_cfg` instead.
cfg = dict(
    d_model=768,     # width of the residual stream
    n_heads=12,      # attention heads per block
    d_vocab=50257,   # number of token ids
    context=2000,    # rows in the position-embedding table
    epsilon=1e-5,    # added to the variance inside LayerNorm
    d_mlp=3072,      # hidden width of each MLP
    n_layers=12,     # number of transformer blocks
    d_head=64,       # per-head query/key/value width
)

In [340]:
class Embedding(nn.Module):
    """Token embedding: a learned table with one d_model-sized row per vocabulary id."""

    def __init__(self, cfg):
        super().__init__()
        table_shape = (cfg["d_vocab"], cfg["d_model"])
        # Allocate uninitialised storage, then fill it in place from a standard normal.
        self.w = nn.Parameter(t.empty(table_shape))
        nn.init.normal_(self.w)

    def forward(self, tokens: Int[Tensor, "batch position"]) -> Float[Tensor, "batch position d_model"]:
        """Look up the embedding row of every token id; a trailing d_model axis is appended."""
        embedded = self.w[tokens]
        return embedded

In [341]:
class PositionEmbedding(nn.Module):
    """Learned absolute position embedding with one row per position up to cfg["context"]."""

    def __init__(self, cfg):
        super().__init__()
        self.w = nn.Parameter(t.empty((cfg["context"], cfg["d_model"])))
        nn.init.normal_(self.w)

    def forward(self, tokens: Int[Tensor, "batch position"]) -> Float[Tensor, "batch position d_model"]:
        """Return the first `position` rows of the table, repeated for every batch element.

        Only the shape of `tokens` is used; the token ids themselves are ignored.

        Raises:
            ValueError: if the sequence is longer than the embedding table. Without this
                check the slice below would silently come back shorter than the sequence
                and the failure would surface later as a shape mismatch.
        """
        batch_size, seq_len = tokens.shape
        max_positions = self.w.shape[0]
        if seq_len > max_positions:
            raise ValueError(
                f"sequence length {seq_len} exceeds the maximum context length {max_positions}"
            )
        return self.w[:seq_len].unsqueeze(0).repeat(batch_size, 1, 1)

In [342]:
class LayerNorm(nn.Module):
    """Layer normalisation over the last axis with a learned per-feature scale and shift."""

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.w = nn.Parameter(t.ones(cfg["d_model"]))
        self.b = nn.Parameter(t.zeros(cfg["d_model"]))

    def forward(self, residual: Float[Tensor, "batch posn d_model"]) -> Float[Tensor, "batch posn d_model"]:
        """Normalise each position's feature vector to zero mean / unit variance, then rescale.

        The variance is the biased (population) estimate, matching the standard definition
        of layer normalisation (and `torch.nn.LayerNorm`). `Tensor.var` defaults to the
        unbiased n-1 estimator, so `unbiased=False` has to be requested explicitly.
        NOTE: weights trained with the unbiased estimate will see normalised activations
        that differ by a factor of roughly sqrt(d_model / (d_model - 1)).
        """
        mean = residual.mean(-1, keepdim=True)
        var = residual.var(-1, keepdim=True, unbiased=False)
        x_hat = (residual - mean) / (var + self.cfg["epsilon"]).sqrt()
        return x_hat * self.w + self.b

In [343]:
class Attention(nn.Module):
    """Multi-head causal self-attention with separate per-head Q/K/V/O weight tensors."""

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.w_q = nn.Parameter(t.empty((cfg["n_heads"], cfg["d_model"], cfg["d_head"])))
        self.w_k = nn.Parameter(t.empty((cfg["n_heads"], cfg["d_model"], cfg["d_head"])))
        self.w_v = nn.Parameter(t.empty((cfg["n_heads"], cfg["d_model"], cfg["d_head"])))
        self.w_o = nn.Parameter(t.empty((cfg["n_heads"], cfg["d_head"], cfg["d_model"])))
        self.b_q = nn.Parameter(t.zeros((cfg["n_heads"], cfg["d_head"])))
        self.b_k = nn.Parameter(t.zeros((cfg["n_heads"], cfg["d_head"])))
        self.b_v = nn.Parameter(t.zeros((cfg["n_heads"], cfg["d_head"])))
        self.b_o = nn.Parameter(t.zeros((cfg["d_model"])))
        nn.init.normal_(self.w_q)
        nn.init.normal_(self.w_k)
        nn.init.normal_(self.w_v)
        nn.init.normal_(self.w_o)
        # Score written into masked (future) positions before the softmax. It is kept as a
        # persistent buffer so the key stays in the state dict and existing checkpoints
        # still load. No device is given here: buffers follow the module through `.to()`,
        # so the layer does not need to know about any notebook-level device variable.
        self.register_buffer("IGNORE", t.tensor(-1e5, dtype=t.float32))

    def forward(self, normalised_resid_pre: Float[Tensor, "batch posn d_model"]) -> Float[Tensor, "batch posn d_model"]:
        """Apply causal self-attention to the (already normalised) residual stream."""
        # Per-head linear projections of every position.
        k = einops.einsum(normalised_resid_pre, self.w_k, "batch posn d_model, n_heads d_model d_head -> batch posn n_heads d_head") + self.b_k
        q = einops.einsum(normalised_resid_pre, self.w_q, "batch posn d_model, n_heads d_model d_head -> batch posn n_heads d_head") + self.b_q
        v = einops.einsum(normalised_resid_pre, self.w_v, "batch posn d_model, n_heads d_model d_head -> batch posn n_heads d_head") + self.b_v

        # Dot product of every query with every key, per head.
        qk = einops.einsum(q, k, "batch posn_q n_heads d_head, batch posn_k n_heads d_head -> batch n_heads posn_q posn_k")

        # Scale by sqrt(d_head), hide future keys, and normalise over the key axis.
        attention_probs = (self.apply_causal_mask(qk / math.sqrt(self.cfg["d_head"]))).softmax(-1)

        # Mix the value vectors with the attention weights.
        weighted_values = einops.einsum(v, attention_probs, "batch posn_K nheads d_head, batch nheads posn_Q posn_K -> batch posn_Q nheads d_head")

        # Project every head back to d_model and sum the heads.
        out = einops.einsum(weighted_values, self.w_o, "batch posn_Q nheads d_head, nheads d_head d_model -> batch posn_Q d_model") + self.b_o
        return out

    def apply_causal_mask(self, attention_scores: Float[Tensor, "batch n_heads query_pos key_pos"]) -> Float[Tensor, "batch n_heads query_pos key_pos"]:
        """Return a copy of the scores with every key_pos > query_pos entry set to IGNORE.

        The input tensor is left unmodified (out-of-place `masked_fill`).
        """
        all_ones = t.ones(attention_scores.size(-2), attention_scores.size(-1), device=attention_scores.device)
        # Strictly-upper triangle == positions where the key lies in the query's future.
        mask = t.triu(all_ones, diagonal=1).bool()
        return attention_scores.masked_fill(mask, self.IGNORE)

In [344]:
class MLP(nn.Module):
    """Position-wise feed-forward network: d_model -> d_mlp -> GELU -> d_model."""

    def __init__(self, cfg):
        super().__init__()
        width, hidden = cfg["d_model"], cfg["d_mlp"]
        self.w_in = nn.Parameter(t.empty((width, hidden)))
        self.w_out = nn.Parameter(t.empty((hidden, width)))
        self.b_in = nn.Parameter(t.zeros((hidden)))
        self.b_out = nn.Parameter(t.zeros((width)))
        # Weights are drawn from a standard normal; biases start at zero.
        nn.init.normal_(self.w_in)
        nn.init.normal_(self.w_out)

    def forward(self, normalised_resid_mid: Float[Tensor, "batch posn d_model"]) -> Float[Tensor, "batch posn d_model"]:
        """Expand to the hidden width, apply GELU, and project back to d_model."""
        hidden_pre = einops.einsum(
            normalised_resid_mid, self.w_in,
            "batch posn d_model, d_model d_mlp -> batch posn d_mlp",
        ) + self.b_in
        hidden_post = nn.functional.gelu(hidden_pre)
        projected = einops.einsum(
            hidden_post, self.w_out,
            "batch posn d_mlp, d_mlp d_model -> batch posn d_model",
        )
        return projected + self.b_out

In [345]:
class TransformerBlock(nn.Module):
    """One pre-norm transformer block: attention then MLP, each wrapped in a residual add."""

    def __init__(self, cfg):
        super().__init__()
        self.layernorm1 = LayerNorm(cfg)
        self.attention = Attention(cfg)
        self.layernorm2 = LayerNorm(cfg)
        self.mlp = MLP(cfg)

    def forward(self, resid_pre: Float[Tensor, "batch position d_model"]) -> Float[Tensor, "batch position d_model"]:
        """Add the attention output, then the MLP output, onto the residual stream."""
        # Each sub-layer reads a normalised copy of the stream and writes its result back
        # by addition.
        attn_out = self.attention(self.layernorm1(resid_pre))
        resid_mid = attn_out + resid_pre
        mlp_out = self.mlp(self.layernorm2(resid_mid))
        return mlp_out + resid_mid

In [346]:
class Unembed(nn.Module):
    """Final linear map from the residual stream to one logit per vocabulary entry."""

    def __init__(self, cfg):
        super().__init__()
        self.w = nn.Parameter(t.empty((cfg["d_model"], cfg["d_vocab"])))
        nn.init.normal_(self.w)
        # nn.Parameter defaults to requires_grad=True, so this bias is trained along with
        # the weights.
        self.b = nn.Parameter(t.zeros(cfg["d_vocab"]))

    def forward(self, normalised_resid_final: Float[Tensor, "batch position d_model"]) -> Float[Tensor, "batch position d_vocab"]:
        """Project every position onto the vocabulary and add the bias."""
        logits = einops.einsum(
            normalised_resid_final, self.w,
            "batch position d_model, d_model d_vocab -> batch position d_vocab",
        )
        return logits + self.b

In [347]:
class GPT(nn.Module):
    """Decoder-only transformer: embeddings -> stacked blocks -> final LayerNorm -> unembed."""

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.embed = Embedding(cfg)
        self.pos_embed = PositionEmbedding(cfg)
        blocks = [TransformerBlock(cfg) for _ in range(cfg["n_layers"])]
        self.transformer_blocks = nn.ModuleList(blocks)
        self.layernorm = LayerNorm(cfg)
        self.unembed = Unembed(cfg)

    def forward(self, tokens: Int[Tensor, "batch position"]) -> Float[Tensor, "batch position d_vocab"]:
        """Map token ids to logits over the vocabulary at every position."""
        # The residual stream starts as the sum of token and position embeddings.
        token_part = self.embed(tokens)
        position_part = self.pos_embed(tokens)
        stream = token_part + position_part
        for layer in self.transformer_blocks:
            stream = layer(stream)
        return self.unembed(self.layernorm(stream))

In [348]:
# Load the pretrained 'gpt2' tokenizer through the transformers library.
# NOTE(review): `padding` is normally a call-time tokenizer argument rather than a
# from_pretrained option — confirm that passing it here has any effect.
tokeniser = GPT2Tokenizer.from_pretrained('gpt2', padding=True)

In [349]:
# Encode a one-prompt batch and move the resulting id tensor onto the active device.
tokens = t.tensor(tokeniser(["The best thing about England"])["input_ids"]).to(device)

In [350]:
def logits_to_strings(logits):
    """Sample one next token from the model output and return it decoded as text.

    Only the last position of the first sequence (`[0, -1]`) is used: the logits there are
    turned into probabilities, a single token id is drawn from that categorical
    distribution, and the id is decoded with the notebook-level `tokeniser`.
    """
    last_position_probs = logits.softmax(dim=-1)[0, -1]
    sampler = t.distributions.categorical.Categorical(probs=last_position_probs)
    sampled_id = sampler.sample().item()
    return tokeniser.decode(sampled_id)

In [351]:
def cross_entropy_loss(logits, tokens):
    """Return the log-probability the model assigned to each actual next token.

    Position p of `logits` is scored against token p + 1, so the last position (which has
    no following token) is dropped and the result has one fewer entry along the sequence
    axis. The values are per-token log-probabilities, not a reduced loss: callers negate
    and average them to obtain the cross-entropy.
    """
    next_tokens = tokens[:, 1:].unsqueeze(-1)
    predictive_log_probs = logits.log_softmax(dim=-1)[:, :-1]
    return predictive_log_probs.gather(dim=-1, index=next_tokens).squeeze(-1)

In [352]:
# log_probs = cross_entropy_loss(logits, tokens)
# print(f"Average cross entropy loss: {-log_probs.mean():.4f}")
# print(f"Average probability assigned to correct token: {log_probs.exp().mean():.4f}")

In [353]:
# Hyperparameters of the small model that this notebook actually builds and trains.
# Note that n_heads * d_head equals d_model here.
train_model_cfg = dict(
    d_model=256,
    n_heads=4,
    d_vocab=50257,
    context=256,
    epsilon=1e-5,
    d_mlp=1024,
    n_layers=2,
    d_head=64,
)

In [354]:
# Optimisation and logging settings read by TransformerTrainer.
train_args = dict(
    batch_size=32,
    epochs=1000,
    max_steps_per_epoch=200,
    lr=1e-3,
    weight_decay=1e-2,
    wandb_project="Transformer",
    wandb_name=None,
)

In [355]:
# Load the text corpus and drop the "meta" column.
dataset = datasets.load_dataset("NeelNanda/pile-10k", split="train").remove_columns("meta")

# Tokenise the "text" column and cut it into sequences of the model's context length.
tokenised_dataset = tokenize_and_concatenate(
    dataset,
    tokeniser,
    streaming=False,
    max_length=train_model_cfg["context"],
    column_name="text",
    add_bos_token=True,
    num_proc=4,
)

# Hold out 1000 sequences for evaluation.
dataset_dict = tokenised_dataset.train_test_split(test_size=1000)

# Only the training loader shuffles; both share the batch size and worker settings.
train_loader = DataLoader(
    dataset_dict["train"],
    batch_size=train_args["batch_size"],
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)
test_loader = DataLoader(
    dataset_dict["test"],
    batch_size=train_args["batch_size"],
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)

In [356]:
class TransformerTrainer:
    """Small training harness: AdamW updates, Weights & Biases logging, periodic checkpoints.

    Batches are drawn from the notebook-level `dataset_dict` splits and moved to the
    notebook-level `device`. Everything else (learning rate, batch size, step budget,
    wandb settings) is read from the `args` dict handed to the constructor.
    """

    def __init__(self, args, model):
        super().__init__()
        self.model = model
        self.args = args
        self.optimiser = t.optim.AdamW(self.model.parameters(), lr=args["lr"], weight_decay=args["weight_decay"])
        self.step = 0  # global optimiser-step counter, used as the wandb x-axis

    def training_step(self, batch: Dict[str, Int[Tensor, "batch seq"]]) -> Float[Tensor, ""]:
        """Run one optimiser update on `batch["tokens"]`, log the loss, and return it."""
        tokens = batch["tokens"].to(device)
        logits = self.model(tokens)
        # cross_entropy_loss returns per-token log-probabilities; negate and average them.
        loss = -cross_entropy_loss(logits, tokens).mean()
        loss.backward()
        self.optimiser.step()
        self.optimiser.zero_grad()
        self.step += 1
        wandb.log({"train_loss": loss}, step=self.step)
        return loss

    def validation_step(self, batch: Dict[str, Int[Tensor, "batch seq"]]):
        """Return a flat boolean tensor marking which next-token predictions were correct.

        The prediction at each position is the highest-scoring token (argmax), so the
        resulting accuracy is deterministic for a given model. Gradients are disabled
        because nothing is back-propagated during evaluation.
        """
        with t.no_grad():
            tokens = batch["tokens"].to(device)
            # The last position has no following token to compare against.
            logits = self.model(tokens)[:, :-1]
            predicted_tokens = logits.argmax(dim=-1)
            correct_predictions = (predicted_tokens == tokens[:, 1:]).flatten()
        return correct_predictions

    def train(self):
        """Train for `epochs` epochs of at most `max_steps_per_epoch` steps each.

        After every epoch the accuracy on the held-out split is logged, and every 200th
        epoch (starting with epoch 0) a checkpoint is written via `save_model`.
        """
        wandb.init(project=self.args["wandb_project"], name=self.args["wandb_name"], config=self.args)

        max_steps = self.args["max_steps_per_epoch"]
        progress_bar = tqdm(total=max_steps * self.args["epochs"])
        accuracy = np.nan

        try:
            for epoch in range(self.args["epochs"]):
                # zip against range() stops after exactly `max_steps` batches, so the
                # number of updates matches the progress-bar total.
                for _step_in_epoch, batch in zip(range(max_steps), self.train_loader()):
                    loss = self.training_step(batch)
                    progress_bar.update()
                    progress_bar.set_description(f"Epoch {epoch+1}, loss: {loss:.3f}, accuracy: {accuracy:.3f}")

                correct_predictions = t.concat([self.validation_step(batch) for batch in self.test_loader()])
                accuracy = correct_predictions.float().mean().item()
                wandb.log({"accuracy": accuracy}, step=self.step)
                if epoch % 200 == 0:
                    self.save_model(epoch)
        finally:
            # Release the progress bar and close the wandb run even if training is interrupted.
            progress_bar.close()
            wandb.finish()

    def train_loader(self) -> DataLoader:
        """Build a fresh shuffled loader over the training split."""
        return DataLoader(dataset_dict["train"], batch_size=self.args["batch_size"], shuffle=True, num_workers=4, pin_memory=True)

    def test_loader(self) -> DataLoader:
        """Build a fresh, unshuffled loader over the held-out split."""
        return DataLoader(dataset_dict["test"], batch_size=self.args["batch_size"], shuffle=False, num_workers=4, pin_memory=True)

    def save_model(self, i):
        """Write the weights, the model config, and the training args, each tagged with `i`."""
        t.save(self.model.state_dict(), f"gpt_model_weights{i}.pth")
        # Prefer the config the model was actually built with; fall back to the notebook
        # default for models that do not expose a `cfg` attribute.
        model_cfg = getattr(self.model, "cfg", train_model_cfg)
        with open(f"gpt_model_config{i}.json", "w") as f:
            json.dump(model_cfg, f)
        # The args file has historically been written without a ".json" suffix; the name is
        # kept as-is so existing artefacts and any readers of them stay valid.
        with open(f"gpt_train_args{i}", "w") as f:
            json.dump(self.args, f)

In [357]:
# Baseline: sample 100 tokens from the freshly initialised (untrained) model.
model = GPT(train_model_cfg).to(device)
string2 = "I"
# Sampling never back-propagates, so skip building an autograd graph on every forward pass.
with t.no_grad():
    for i in range(100):
        # Re-encode the whole text so far, then append one newly sampled token to it.
        tokens = tokeniser([string2])["input_ids"]
        tokens = t.tensor(tokens).to(device)
        logits = model(tokens)
        string2 += logits_to_strings(logits)

print(string2)

I Breath helpslake materiallyonte anyway Micha Abbey Sn persists sensibilities softened Chaser ankle152 wi shakes Sn integrated materiallyicycle Falk ansproduct Avasus cobhericalnai Falkoine foremostonianTRUMPwife counterpartposureets cached cells compensated desireispplex abstotonin wallet knockoutullivanTRUMP Conanented grasp anecdotal Sheridan schededom Live donkeyFu materially Kah Bahrain Duffy amber challenges Falk Nassample arcane universally enclave spiders descriptions Finecorruptionprinted materially travellersshaw hauntptives towedptivesLikewise Farmingbusiness neigh lobbiedThrow communicating sects vilenie culturallySentOSEDMaria insaneBright


In [358]:
# Wrap the freshly initialised model in a trainer. The training run itself is left
# commented out; the cells below load checkpoint files instead (presumably written by an
# earlier run of trainer.train()).
trainer = TransformerTrainer(train_args, model)
# trainer.train()

In [375]:
# Load the checkpoint saved as gpt_model_weights600.pth for comparison.
# The state dict is deserialised onto the CPU; load_state_dict then copies the values
# into the parameters of the model, which already lives on `device`.
# NOTE(review): t.load unpickles the file, so only load checkpoints from a trusted source.
model600 = GPT(train_model_cfg)
model600 = model600.to(device)
weights = t.load("gpt_model_weights600.pth", map_location="cpu")
model600.load_state_dict(weights)

<All keys matched successfully>

In [376]:
# Load the checkpoint saved as gpt_model_weights800.pth for comparison.
# The state dict is deserialised onto the CPU; load_state_dict then copies the values
# into the parameters of the model, which already lives on `device`.
# NOTE(review): t.load unpickles the file, so only load checkpoints from a trusted source.
model800 = GPT(train_model_cfg)
model800 = model800.to(device)
weights = t.load("gpt_model_weights800.pth", map_location="cpu")
model800.load_state_dict(weights)

<All keys matched successfully>

In [392]:
# Both checkpoints are prompted with the same starting text so their samples can be compared.
string800 = "You"
string600 = string800
# Sampling never back-propagates, so skip building an autograd graph on every forward pass.
with t.no_grad():
    for i in range(100):
        # Re-encode the whole text so far, then append one newly sampled token to it.
        tokens = tokeniser([string800])["input_ids"]
        tokens = t.tensor(tokens).to(device)
        logits = model800(tokens)
        string800 += logits_to_strings(logits)

print(string800)

You never before those other." " No-11 or get well done with you." "10amance-467." "It was just it thinking he was funny anyway." "I didn't get a quite agree as a sunrise himself with the Steve chandeliers." "What are there best with me?" "?" "It was that frosting nervous." "Dan's seems that it blinded when you?" He said that he came in this liking and an old Thirteenpx in the way." "


In [393]:
# Sample 100 tokens from the 600 checkpoint, continuing the text held in string600.
# Sampling never back-propagates, so skip building an autograd graph on every forward pass.
with t.no_grad():
    for i in range(100):
        # Re-encode the whole text so far, then append one newly sampled token to it.
        tokens = tokeniser([string600])["input_ids"]
        tokens = t.tensor(tokens).to(device)
        logits = model600(tokens)
        string600 += logits_to_strings(logits)

print(string600)

You deserve to meet you haven't all the money, then fled subject?"

<|endoftext|>The port is onboard right?Do you know you and to pay for their pace. Certainly comes in five hours before daylight. Again, but the on your senses all every day." "and tints." "The in the morning." "It is in, then ask Assumed by official sometime." "She connected with her." "[ Mickey 7800" "Ditch towards me." "It was so interesting,
