## Imports & Installs

In [None]:
!pip install transformers torchtyping 

In [8]:
import sys
import os

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), os.pardir, os.pardir)))

In [69]:
from dataclasses import dataclass

import numpy as np
import pandas as pd
import torch as t
from torch import nn
import transformers
from torchtyping import TensorType
from fancy_einsum import einsum
import einops

#
from arena.w2d2 import utils

## Load weights

In [49]:
GPT2_N_LAYERS = 12
GPT2_N_HEADS = 8
GPT2_VOCAB_SIZE = 50257
GPT2_HIDDEN_SIZE = 768
GPT2_MAX_SEQ_LEN = 1024
GPT2_DROPOUT = 0.1
GPT2_LN_EPS = 1e-05

@dataclass(frozen=True)
class TransformerConfig:
    '''Constants used throughout your decoder-only transformer model.'''
    num_layers: int = GPT2_N_LAYERS
    num_heads: int = GPT2_N_HEADS
    vocab_size: int = GPT2_VOCAB_SIZE
    hidden_size: int = GPT2_HIDDEN_SIZE
    max_seq_len: int = GPT2_MAX_SEQ_LEN
    dropout: float = GPT2_DROPOUT
    layer_norm_epsilon: float = GPT2_LN_EPS

config = TransformerConfig()

In [50]:
gpt = transformers.AutoModelForCausalLM.from_pretrained("gpt2").train()
tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")

In [63]:

def mask(A: TensorType[..., "seq_len", "seq_len"]) -> TensorType[..., "seq_len", "seq_len"]:
    seq_len = A.shape[-1]

    mask = t.triu(t.ones(seq_len, seq_len), diagonal=1).bool()
    return A.masked_fill(mask, -np.inf)

def multihead_masked_attention(
    Q: TensorType["b", "s", "n*h"], 
    K: TensorType["b", "s", "n*h"], 
    num_heads: int
) -> TensorType["b", "n", "s_q", "s_k"]:
    '''
    Should return the results of multihead self-attention (after softmax, before multiplying with V)
    '''
    _Q = einops.rearrange(Q, "b s (n h) -> b n s h", n=num_heads)    
    _K = einops.rearrange(K, "b s (n h) -> b n s h", n=num_heads)    

    d_head = _Q.shape[-1]

    A_pre = mask(
        einsum("b n s_q h, b n s_k h -> b n s_q s_k", _Q, _K)
    ) / np.sqrt(d_head)

    return t.softmax(A_pre, dim=-1)


def multihead_masked_attention_head(
    A: TensorType["b", "n", "s_q", "s_k"], 
    V: TensorType["b", "s", "n*h"],
    num_heads: int
) -> TensorType["batch", "seq", "n_heads*headsize"]:
    _V = einops.rearrange(V, "b s (n h) -> b n s h", n=num_heads)
    AV: TensorType["b", "n", "s_q", "h"] = einsum("b n s_q s_k, b n s_k h -> b n s_q h", A, _V)
    return einops.rearrange(AV, "b n s h -> b s (n h)") 


class GPT2Attention(nn.Module):
    W_QKV: nn.Linear
    W_O: nn.Linear
    dropout: float

    def __init__(self, hidden_size: int, num_heads: int, dropout: float):
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_size = hidden_size // num_heads

        super().__init__()

        self.c_attn = nn.Linear(hidden_size, hidden_size * 3)
        self.c_proj = nn.Linear(hidden_size, hidden_size)
        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)

    def forward(self, x: TensorType["batch", "seq", "hidden_size"]) -> TensorType["batch", "seq", "hidden_size"]:
        '''
        x: shape (batch, seq, hidden_size)

        Return: shape (batch, seq, hidden_size)
        '''
        Q, K, V = self.c_attn(x).chunk(3, dim=-1)        
        A = multihead_masked_attention(Q, K, self.num_heads)
        A = self.attn_dropout(A)
        h = multihead_masked_attention_head(A, V, self.num_heads)
        x = self.c_proj(h)
        return self.resid_dropout(x)

class GPT2MLP(nn.Module):

    def __init__(self, hidden_size: int, dropout: float):
        self.hidden_size = hidden_size

        super().__init__()

        self.c_proj = nn.Linear(hidden_size, hidden_size * 4)
        self.c_fc = nn.Linear(hidden_size * 4, hidden_size)
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: t.Tensor) -> t.Tensor:
        x = self.c_proj(x)
        x = self.gelu(x)
        x = self.c_fc(x)
        x = self.dropout(x)

        return x

class GPT2Block(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int, layer_norm_epsilon: float, dropout: float):
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.layer_norm_epsilon = layer_norm_epsilon
        self.dropout = dropout

        super().__init__()

        self.ln_1 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
        self.attn = GPT2Attention(hidden_size, num_heads, dropout=dropout)
        self.ln_2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
        self.mlp = GPT2MLP(hidden_size, dropout)

    def forward(self, x: t.Tensor) -> t.Tensor:
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))

        return x


class GPT2(nn.Module):

    def __init__(self, config: TransformerConfig):
        self.config = config

        super().__init__()

        self.wte = nn.Embedding(config.vocab_size, config.hidden_size)
        self.wpe = nn.Embedding(config.max_seq_len, config.hidden_size)

        self.drop = nn.Dropout(config.dropout)
        self.h = nn.ModuleList([
            GPT2Block(config.hidden_size, config.num_heads, config.layer_norm_epsilon, config.dropout)
            for _ in range(config.num_layers)
        ])
        self.ln_f = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon, elementwise_affine=True)

    def forward(self, x: t.Tensor) -> t.Tensor:
        pos = t.arange(x.shape[1], device=x.device)
        x = self.wte(x) + self.wpe(pos)

        x = self.drop(x)

        for h_i in self.h:
            x = h_i(x)
        
        x = self.ln_f(x)
        x = einsum("batch seq hidden, vocab hidden -> batch seq vocab", x, self.wte.weight)

        return x


In [64]:
gpt

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (1): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dro

In [65]:
my_gpt = GPT2(config).train()
my_gpt

GPT2(
  (wte): Embedding(50257, 768)
  (wpe): Embedding(1024, 768)
  (drop): Dropout(p=0.1, inplace=False)
  (h): ModuleList(
    (0): GPT2Block(
      (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): GPT2Attention(
        (c_attn): Linear(in_features=768, out_features=2304, bias=True)
        (c_proj): Linear(in_features=768, out_features=768, bias=True)
        (attn_dropout): Dropout(p=0.1, inplace=False)
        (resid_dropout): Dropout(p=0.1, inplace=False)
      )
      (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (mlp): GPT2MLP(
        (c_proj): Linear(in_features=768, out_features=3072, bias=True)
        (c_fc): Linear(in_features=3072, out_features=768, bias=True)
        (gelu): GELU(approximate=none)
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
    (1): GPT2Block(
      (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): GPT2Attention(
        (c_attn): Linear(in_features=768, out_

In [66]:
for (name1, tens1), (name2, tens2)  in zip(my_gpt.named_parameters(), gpt.named_parameters()):
    s1, s2 = tens1.shape, tens2.shape
    if s1 != s2:
        print(name1, tens1.shape)
        print(name2, tens2.shape)
        print("\n")

print(len(list(my_gpt.parameters())), len(list(gpt.parameters())))

h.0.attn.c_attn.weight torch.Size([2304, 768])
transformer.h.0.attn.c_attn.weight torch.Size([768, 2304])


h.0.mlp.c_proj.weight torch.Size([3072, 768])
transformer.h.0.mlp.c_fc.weight torch.Size([768, 3072])


h.0.mlp.c_fc.weight torch.Size([768, 3072])
transformer.h.0.mlp.c_proj.weight torch.Size([3072, 768])


h.1.attn.c_attn.weight torch.Size([2304, 768])
transformer.h.1.attn.c_attn.weight torch.Size([768, 2304])


h.1.mlp.c_proj.weight torch.Size([3072, 768])
transformer.h.1.mlp.c_fc.weight torch.Size([768, 3072])


h.1.mlp.c_fc.weight torch.Size([768, 3072])
transformer.h.1.mlp.c_proj.weight torch.Size([3072, 768])


h.2.attn.c_attn.weight torch.Size([2304, 768])
transformer.h.2.attn.c_attn.weight torch.Size([768, 2304])


h.2.mlp.c_proj.weight torch.Size([3072, 768])
transformer.h.2.mlp.c_fc.weight torch.Size([768, 3072])


h.2.mlp.c_fc.weight torch.Size([768, 3072])
transformer.h.2.mlp.c_proj.weight torch.Size([3072, 768])


h.3.attn.c_attn.weight torch.Size([2304, 768])
trans

In [72]:

def print_param_count(*models, display_df=True, use_state_dict=False):
    """
    display_df: bool
        If true, displays styled dataframe
        if false, returns dataframe
    use_state_dict: bool
        If true, uses model.state_dict() to construct dataframe
            This will include buffers, not just params
        If false, uses model.named_parameters() to construct dataframe
            This misses out buffers (more useful for GPT)
    """
    df_list = []
    gmap_list = []
    for i, model in enumerate(models, start=1):
        print(f"Model {i}, total params = {sum([param.numel() for name, param in model.named_parameters()])}")
        iterator = model.state_dict().items() if use_state_dict else model.named_parameters()
        df = pd.DataFrame([
            {f"name_{i}": name, f"shape_{i}": tuple(param.shape), f"num_params_{i}": param.numel()}
            for name, param in iterator
        ]) if (i == 1) else pd.DataFrame([
            {f"num_params_{i}": param.numel(), f"shape_{i}": tuple(param.shape), f"name_{i}": name}
            for name, param in iterator
        ])
        display(df)
        df_list.append(df)
        gmap_list.append(np.log(df[f"num_params_{i}"]))
    df = df_list[0] if len(df_list) == 1 else pd.concat(df_list, axis=1).fillna(0)
    for i in range(1, len(models) + 1):
        df[f"num_params_{i}"] = df[f"num_params_{i}"].astype(int)
    if len(models) > 1:
        param_counts = [df[f"num_params_{i}"].values.tolist() for i in range(1, len(models) + 1)]
        if all([param_counts[0] == param_counts[i] for i in range(1, len(param_counts))]):
            print("All parameter counts match!")
        else:
            print("Parameter counts don't match up exactly.")
    if display_df:
        s = df.style
        for i in range(1, len(models) + 1):
            s = s.background_gradient(cmap="viridis", subset=[f"num_params_{i}"], gmap=gmap_list[i-1])
        with pd.option_context("display.max_rows", 1000):
            display(s)
    else:
        return df


print_param_count(my_gpt, gpt)

Model 1, total params = 124439808


Unnamed: 0,name_1,shape_1,num_params_1
0,wte.weight,"(50257, 768)",38597376
1,wpe.weight,"(1024, 768)",786432
2,h.0.ln_1.weight,"(768,)",768
3,h.0.ln_1.bias,"(768,)",768
4,h.0.attn.c_attn.weight,"(2304, 768)",1769472
...,...,...,...
143,h.11.mlp.c_proj.bias,"(3072,)",3072
144,h.11.mlp.c_fc.weight,"(768, 3072)",2359296
145,h.11.mlp.c_fc.bias,"(768,)",768
146,ln_f.weight,"(768,)",768


Model 2, total params = 124439808


Unnamed: 0,num_params_2,shape_2,name_2
0,38597376,"(50257, 768)",transformer.wte.weight
1,786432,"(1024, 768)",transformer.wpe.weight
2,768,"(768,)",transformer.h.0.ln_1.weight
3,768,"(768,)",transformer.h.0.ln_1.bias
4,1769472,"(768, 2304)",transformer.h.0.attn.c_attn.weight
...,...,...,...
143,3072,"(3072,)",transformer.h.11.mlp.c_fc.bias
144,2359296,"(3072, 768)",transformer.h.11.mlp.c_proj.weight
145,768,"(768,)",transformer.h.11.mlp.c_proj.bias
146,768,"(768,)",transformer.ln_f.weight


All parameter counts match!


Unnamed: 0,name_1,shape_1,num_params_1,num_params_2,shape_2,name_2
0,wte.weight,"(50257, 768)",38597376,38597376,"(50257, 768)",transformer.wte.weight
1,wpe.weight,"(1024, 768)",786432,786432,"(1024, 768)",transformer.wpe.weight
2,h.0.ln_1.weight,"(768,)",768,768,"(768,)",transformer.h.0.ln_1.weight
3,h.0.ln_1.bias,"(768,)",768,768,"(768,)",transformer.h.0.ln_1.bias
4,h.0.attn.c_attn.weight,"(2304, 768)",1769472,1769472,"(768, 2304)",transformer.h.0.attn.c_attn.weight
5,h.0.attn.c_attn.bias,"(2304,)",2304,2304,"(2304,)",transformer.h.0.attn.c_attn.bias
6,h.0.attn.c_proj.weight,"(768, 768)",589824,589824,"(768, 768)",transformer.h.0.attn.c_proj.weight
7,h.0.attn.c_proj.bias,"(768,)",768,768,"(768,)",transformer.h.0.attn.c_proj.bias
8,h.0.ln_2.weight,"(768,)",768,768,"(768,)",transformer.h.0.ln_2.weight
9,h.0.ln_2.bias,"(768,)",768,768,"(768,)",transformer.h.0.ln_2.bias


In [73]:

from typing import Type


def copy_weights(my_gpt: GPT2, gpt: nn.Module) -> GPT2:
    '''Copy over the weights of `gpt` to your gpt implementation.'''

    my_gpt_dict = dict(my_gpt.named_parameters())
    gpt_dict = dict(gpt.named_parameters())
    
    assert len(my_gpt_dict) == len(gpt_dict), "Number of layers is wrong. Have you done the prev step correctly?"
    
    state_dict_to_load = {}
    
    for (my_param_name, my_param), (name, param) in zip(my_gpt_dict.items(), gpt_dict.items()):
        if len(my_param.shape) == 2 and my_param.shape == param.T.shape:
            state_dict_to_load[my_param_name] = param.T
        elif my_param.shape == param.shape:
            state_dict_to_load[my_param_name] = param
        else:
            raise Exception(f"Parameter shapes don't match: {my_param.shape} vs {param.shape}")

    if set(state_dict_to_load.keys()) != set(my_gpt.state_dict().keys()):
        raise Exception("State dicts don't match.")
    
    my_gpt.load_state_dict(state_dict_to_load)
    
    return my_gpt

copy_weights(my_gpt, gpt)


utils.test_load_pretrained_weights(my_gpt, tokenizer)

Prompt:  Former President of the United States of America, George
Your model's top 10 predictions:  [' W', ' Washington', ' Bush', ' H', ' Soros', ',', ' Orwell', ' Mason', ' R', ' Zimmerman']
