## Differences with GPT2 and your transformer
- The order of the LayerNorms in the decoder block have changed: they now come before the attention and MLP blocks, rather than after.
- The attention block has two dropout layers: one immediately after the softmax (i.e. before multiplying by V), and one immediately after multiplying with W_O at the very end of the attention block. Note that the dropout layers won't actually affect weight-loading or performance in eval mode (and you should still be able to train your model without them), but all the same it's nice to be able to exactly match GPT's architecture!
- All your linear layers should have biases - even though some of them are called projections (which would seem to suggest not having a bias), this is often how they are implemented.
- GPT-2 uses a learned positional embedding (i.e. nn.Embedding) rather than a sinusoidal one.

In [71]:
import torch as t
import torch.nn as nn
from einops import rearrange, repeat
from fancy_einsum import einsum
from torch import optim
from torch.utils.data import DataLoader, Dataset, random_split
from dataclasses import dataclass 
from collections import OrderedDict
import math
from typing import Optional, Union
import plotly.express as px
import torchinfo
import matplotlib as plt
import os
from tqdm.notebook import tqdm_notebook
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

device = t.device('cuda' if t.cuda.is_available() else 'cpu')

#@title Transformer Modules
@dataclass(frozen=True)
class TransformerConfig:
    '''Constants used throughout your decoder-only transformer model.'''

    num_layers: int
    num_heads: int
    vocab_size: int
    hidden_size: int # also embedding dim or d_model
    masking: str # autoregressive
    max_seq_len: int = 5000 
    dropout: float = 0.1
    layer_norm_epsilon: float = 1e-05
    device = t.device('cuda' if t.cuda.is_available() else 'cpu')
    masking: str # autoregressive

config = TransformerConfig(
    num_layers = 12,
    num_heads = 12,
    vocab_size = 50257,
    hidden_size = 768,
    max_seq_len = 1024,
    dropout = 0.1,
    masking='autoregressive',
    layer_norm_epsilon = 1e-05
)

class MultiLayerPerceptron(nn.Module):  

    def __init__(self, d_in: int, d_out: int):
        super().__init__()
        d_h = d_in * 4
        self.model = nn.Sequential(OrderedDict([
            ('linear1', nn.Linear(d_in, d_h)),
            ('GELU', nn.GELU()),
            ('linear2', nn.Linear(d_h, d_in)),   
            ('dropout', nn.Dropout(p=0.1))
        ]))

    def forward(self, x: t.Tensor):
        return self.model(x)

class MultiheadMaskedAttention(nn.Module):

    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        self.W_QKV = nn.Linear(hidden_size, hidden_size * 3)
        self.W_O = nn.Linear(hidden_size, hidden_size)
        self.num_heads = num_heads
        self.dropout1 = nn.Dropout(p=0.1)
        self.dropout2 = nn.Dropout(p=0.1)

    def forward(self, x: t.Tensor, mask=None) -> t.Tensor:
        '''
        x: shape (batch, seq, hidden_size)
        Return: shape (batch, seq, hidden_size)
        '''
        Q, K, V = self.W_QKV(x).chunk(3, dim=-1)
        att = self.multihead_masked_attention(Q, K, V, self.num_heads)
        return self.W_O(att)

    def multihead_masked_attention(self, Q: t.Tensor, K: t.Tensor, V: t.Tensor, n_heads: int):
        '''
        Q: shape (b, s1, e)
        K: shape (b, s2, e)
        V: shape (b, s2, e)

        e = nheads * h
        b = batch
        s = seq_len
        h = hidden

        Return: shape (b s e)
        '''

        assert Q.shape[-1] % n_heads == 0
        assert K.shape[-1] % n_heads == 0
        assert V.shape[-1] % n_heads == 0
        assert K.shape[-1] == V.shape[-1]

        Q = rearrange(Q, 'b s (nheads h) -> b nheads s h', nheads=n_heads)
        K = rearrange(K, 'b s (nheads h) -> b nheads s h', nheads=n_heads)
        V = rearrange(V, 'b s (nheads h) -> b nheads s h', nheads=n_heads)

        batch, nheads, seq_len, headsize = Q.shape

        scaled_dot_prod = einsum('b nheads s1 h, b nheads s2 h -> b nheads s2 s1', K, Q) / (headsize ** 0.5)
        mask_filter = t.triu(t.full_like(scaled_dot_prod, -t.inf), 1)
        scaled_dot_prod += mask_filter
        attention_probs = scaled_dot_prod.softmax(dim=-1)
        attention_probs = self.dropout1(attention_probs)
        attention_vals = einsum('b nheads s1 s2, b nheads s2 c -> b nheads s1 c', attention_probs, V)
        attention = rearrange(attention_vals, 'b nheads s c -> b s (nheads c)')
        return self.dropout2(attention) 

class GPT2DecoderBlock(nn.Module):

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.layernorm1 = nn.LayerNorm(config.hidden_size)
        self.attention = MultiheadMaskedAttention(
            hidden_size=config.hidden_size,
            num_heads=config.num_heads
        )
        self.layernorm2 = nn.LayerNorm(config.hidden_size)
        self.mlp = MultiLayerPerceptron(config.hidden_size, config.hidden_size)
    
    def forward(self, x: t.Tensor):
        x = x + self.attention(self.layernorm1(x))
        x = x + self.mlp(self.layernorm2(x))
        return x

class GPT2(nn.Module):

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.embed = nn.Embedding(config.vocab_size, config.hidden_size)
        self.positional_encoding = nn.Embedding(config.max_seq_len, config.hidden_size)
        decoders = [GPT2DecoderBlock(config) for i in range(config.num_layers)]
        names = ['decoder' + str(i) for i in range(config.num_layers)]
        self.decoderlayer = nn.Sequential(OrderedDict(zip(names, decoders)))
        self.dropout = nn.Dropout(p=config.dropout)
        self.layernorm = nn.LayerNorm(config.hidden_size)

    def forward(self, tokens):
        if len(tokens.shape) == 1:
            tokens = rearrange(tokens, "seq -> 1 seq")
        embedding = self.embed(tokens) # (b, seq_len) -> (b, seq_len, embedding)
        pos_enc = self.positional_encoding(t.arange(tokens.shape[1], device=tokens.device)) # (seq_len)
        a = self.dropout(embedding + pos_enc) # (b, seq_len, embedding)
        b = self.decoderlayer(a) # (b, seq_len, embedding)
        c = self.layernorm(b) @ self.embed.weight.T # (b, seq_len, embedding) @ (embedding, vocab_size) -> (b, seq_len, vocab_size)
        return c

gpt2 = GPT2(config)

In [72]:
import transformers
transformers.AutoModelForCausalLM.from_pretrained("gpt2")
print('loaded')

loaded


In [73]:
import utils

my_gpt = GPT2(config).train()
gpt = transformers.AutoModelForCausalLM.from_pretrained("gpt2").train()

utils.print_param_count(my_gpt, gpt)

Model 1, total params = 124439808


Unnamed: 0,name_1,shape_1,num_params_1
0,embed.weight,"(50257, 768)",38597376
1,positional_encoding.weight,"(1024, 768)",786432
2,decoderlayer.decoder0.layernorm1.weight,"(768,)",768
3,decoderlayer.decoder0.layernorm1.bias,"(768,)",768
4,decoderlayer.decoder0.attention.W_QKV.weight,"(2304, 768)",1769472
...,...,...,...
143,decoderlayer.decoder11.mlp.model.linear1.bias,"(3072,)",3072
144,decoderlayer.decoder11.mlp.model.linear2.weight,"(768, 3072)",2359296
145,decoderlayer.decoder11.mlp.model.linear2.bias,"(768,)",768
146,layernorm.weight,"(768,)",768


Model 2, total params = 124439808


Unnamed: 0,num_params_2,shape_2,name_2
0,38597376,"(50257, 768)",transformer.wte.weight
1,786432,"(1024, 768)",transformer.wpe.weight
2,768,"(768,)",transformer.h.0.ln_1.weight
3,768,"(768,)",transformer.h.0.ln_1.bias
4,1769472,"(768, 2304)",transformer.h.0.attn.c_attn.weight
...,...,...,...
143,3072,"(3072,)",transformer.h.11.mlp.c_fc.bias
144,2359296,"(3072, 768)",transformer.h.11.mlp.c_proj.weight
145,768,"(768,)",transformer.h.11.mlp.c_proj.bias
146,768,"(768,)",transformer.ln_f.weight


All parameter counts match!


Unnamed: 0,name_1,shape_1,num_params_1,num_params_2,shape_2,name_2
0,embed.weight,"(50257, 768)",38597376,38597376,"(50257, 768)",transformer.wte.weight
1,positional_encoding.weight,"(1024, 768)",786432,786432,"(1024, 768)",transformer.wpe.weight
2,decoderlayer.decoder0.layernorm1.weight,"(768,)",768,768,"(768,)",transformer.h.0.ln_1.weight
3,decoderlayer.decoder0.layernorm1.bias,"(768,)",768,768,"(768,)",transformer.h.0.ln_1.bias
4,decoderlayer.decoder0.attention.W_QKV.weight,"(2304, 768)",1769472,1769472,"(768, 2304)",transformer.h.0.attn.c_attn.weight
5,decoderlayer.decoder0.attention.W_QKV.bias,"(2304,)",2304,2304,"(2304,)",transformer.h.0.attn.c_attn.bias
6,decoderlayer.decoder0.attention.W_O.weight,"(768, 768)",589824,589824,"(768, 768)",transformer.h.0.attn.c_proj.weight
7,decoderlayer.decoder0.attention.W_O.bias,"(768,)",768,768,"(768,)",transformer.h.0.attn.c_proj.bias
8,decoderlayer.decoder0.layernorm2.weight,"(768,)",768,768,"(768,)",transformer.h.0.ln_2.weight
9,decoderlayer.decoder0.layernorm2.bias,"(768,)",768,768,"(768,)",transformer.h.0.ln_2.bias


In [74]:
# def copy_weights(my_gpt: GPT2, pretrained_gpt) -> GPT2:
#     '''Copy over the weights of `pretrained_resnet` to your resnet.'''

#     my_parameters = list(my_gpt.named_parameters())
#     pretrained_parameters = list(pretrained_gpt.named_parameters())

#     for i, (my_name, my_param) in enumerate(my_parameters):
#         pretrained_name, pretrained_param = pretrained_parameters[i]
#         print(my_param.shape, pretrained_param.shape)
#         print(my_name, pretrained_name)
#         if my_param.shape != pretrained_param.shape:
#             pretrained_param = rearrange(pretrained_param, 'a b -> b a')
#         assert my_param.shape == pretrained_param.shape
#         # my_param.data = pretrained_param.data.clone()
# # 
#     # # Check the number of params/buffers is correct
#     # assert len(mydict) == len(pretraineddict), "Number of layers is wrong. Have you done the prev step correctly?"

#     # # Initialise an empty dictionary to store the correct key-value pairs
#     # state_dict_to_load = {}

#     # for (mykey, myvalue), (pretrainedkey, pretrainedvalue) in zip(mydict.items(), pretraineddict.items()):
#     #     state_dict_to_load[mykey] = pretrainedvalue

#     # myresnet.load_state_dict(state_dict_to_load)

#     # return myresnet

# my_gpt = copy_weights(my_gpt, gpt)

In [75]:
def copy_weights(my_gpt: GPT2, gpt) -> GPT2:
    '''Copy over the weights of `gpt` to your gpt implementation.'''

    # Here we use named params not state dict, because gpt doesn't have any buffers we care about
    # (I think all its buffers are attention masks)
    my_gpt_dict = dict(my_gpt.named_parameters())
    gpt_dict = dict(gpt.named_parameters())
    
    # Check the number of params/buffers is correct
    assert len(my_gpt_dict) == len(gpt_dict), "Number of layers is wrong. Have you done the prev step correctly?"
    
    # Initialise an empty dictionary to store the correct key-value pairs
    state_dict = {}
    
    for (my_param_name, my_param), (name, param) in zip(my_gpt_dict.items(), gpt_dict.items()):
        # Sometimes params are transposed
        if len(my_param.shape) == 2 and my_param.shape == param.T.shape:
            state_dict[my_param_name] = param.T
            # print(f"Copied params.T: {name} -> {my_param_name}")
        elif my_param.shape == param.shape:
            state_dict[my_param_name] = param
            # print(f"Copied params:   {name} -> {my_param_name}")
        else:
            raise Exception(f"Parameter shapes don't match: {my_param.shape} vs {param.shape}")

    if set(state_dict.keys()) != set(my_gpt.state_dict().keys()):
        raise Exception("State dicts don't match.")
    
    my_gpt.load_state_dict(state_dict)
    
    return my_gpt

my_gpt = copy_weights(my_gpt, gpt)

In [76]:
tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
utils.test_load_pretrained_weights(my_gpt, tokenizer)

Prompt:  Former President of the United States of America, George
Your model's top 10 predictions:  [' W', ' H', ' Bush', ' Washington', ' HW', ' Herbert', ' Pat', ' Soros', ' S', ' Wallace']


In [84]:
def beamsearch(n_beams: int, max_len: int, model: GPT2, tokenizer, prompt: str):
    '''Beam search for `model` with `tokenizer` and `prompt`.'''

    enc = tokenizer.encode(prompt)

    # after first iteration, beams will have n_beams
    beams = [(t.tensor(enc).to(device), 0)]
    completed = []

    with t.inference_mode():
        while len(completed) < n_beams:
            # print('current beams')
            # for beam, score in beams:
            #     print(score, tokenizer.decode(beam.tolist()))
            new_beams = []
            for i, (beam, score) in enumerate(beams):
                logits = model(beam) # (1, seq_len, vocab_size)
                logits_last = logits[:,-1,:] # (1, vocab_size)
                logprobs = logits_last.log_softmax(dim=-1) # (1, vocab_size)
                scores, tokens = logprobs.topk(n_beams, dim=-1, sorted=True) # (1, n_beams)
                for new_score, token in zip(scores[0], tokens[0]):
                    # print(new_score, tokenizer.decode(token.tolist()))
                    new_beam = t.cat([beam, t.tensor([token])], dim=0)
                    if len(completed) < n_beams:
                        if token == tokenizer.eos_token_id:
                            completed.append((beam, score + new_score))
                        else:
                            new_beams.append((new_beam, score + new_score))
            # print('new beams')
            # for beam, score in new_beams:
            #     print(score, tokenizer.decode(beam.tolist()))
            new_beams = sorted(new_beams, key=lambda x: x[1], reverse=True)[:n_beams]
            for beam, score in new_beams:
                if len(beam) > max_len:
                    completed.append((beam, score))
            beams = new_beams
            # print('selected')
            # for beam, score in beams:
            #     print(score, tokenizer.decode(beam.tolist()))
    print('completed')
    for beam, score in completed:
        print(score, tokenizer.decode(beam.tolist()))

beamsearch(n_beams=6, max_len=22, model=my_gpt, tokenizer=tokenizer, prompt="I don't want to rule the universe. I just think")



completed
tensor(-13.8475) I don't want to rule the universe. I just think there's a lot of things that need to be done
tensor(-14.6525) I don't want to rule the universe. I just think there's a lot of things that we can do to
tensor(-14.9058) I don't want to rule the universe. I just think there's a lot of things that need to change.
tensor(-15.1158) I don't want to rule the universe. I just think there's a lot of things that are going on that
tensor(-15.1325) I don't want to rule the universe. I just think there's a lot of things we can do to make
tensor(-15.1348) I don't want to rule the universe. I just think there's a lot of things we can do to help
