In [1]:
# Before running the imports cell below, upload utils.py to the Colab workspace:
# https://github.com/jmsdao/arena-v1/blob/jmsdao/w2d2/utils.py

In [2]:
!pip install transformers fancy-einsum einops

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [3]:
from dataclasses import dataclass

import torch as t
import torch.nn as nn
import transformers

from fancy_einsum import einsum
import math
import utils

In [4]:
def multihead_masked_attention(Q: t.Tensor, K: t.Tensor, V: t.Tensor, num_heads: int) -> t.Tensor:
    '''
    Implements multihead masked (causal) attention on the matrices Q, K and V.

    Every query position attends only to itself and to earlier key positions.

    Q: shape (batch, seq, nheads*headsize)
    K: shape (batch, seq, nheads*headsize)
    V: shape (batch, seq, nheads*headsize)
    num_heads: number of heads the last dimension is split into

    Return: shape (batch, seq, nheads*headsize)
    '''
    batch = Q.shape[0]
    seq_len = Q.shape[1]
    headsize = Q.shape[2] // num_heads

    # Split the last dimension into (num_heads, headsize); head h owns the
    # contiguous slice [h*headsize, (h+1)*headsize) of the original features.
    Q = Q.reshape(batch, seq_len, num_heads, headsize)
    K = K.reshape(batch, seq_len, num_heads, headsize)
    V = V.reshape(batch, seq_len, num_heads, headsize)

    attention_scores = einsum('b sl_Q nh hs, b sl_K nh hs -> b nh sl_Q sl_K', Q, K)

    # One (seq, seq) boolean mask, True strictly above the diagonal (key position
    # later than query position). It broadcasts over batch and heads, so no
    # (batch, nheads, seq, seq) tensor of -inf has to be materialised.
    future_mask = t.triu(
        t.ones(seq_len, seq_len, device=attention_scores.device), diagonal=1
    ).bool()
    attention_scores = attention_scores.masked_fill(future_mask, float('-inf'))

    # Scaled softmax over the key axis; masked entries get probability 0.
    attention_probs = t.softmax(attention_scores / math.sqrt(headsize), dim=-1)

    attention_values = einsum('b nh sl_Q sl_K, b sl_K nh hs -> b sl_Q nh hs', attention_probs, V)
    # Merge the heads back into a single feature dimension.
    return attention_values.reshape(batch, seq_len, num_heads * headsize)


class MultiheadMaskedAttention(nn.Module):
    '''
    Causal multi-head self-attention layer.

    One fused linear layer (W_QKV) produces Q, K and V from the input, masked
    attention is applied per head, and a second linear layer (W_O) projects the
    concatenated heads back to hidden_size.
    '''
    W_QKV: nn.Linear
    W_O: nn.Linear

    def __init__(self, hidden_size: int, num_heads: int):
        '''
        hidden_size: size of the input and output feature dimension
        num_heads: number of attention heads; must divide hidden_size

        Raises AssertionError if hidden_size is not a multiple of num_heads.
        '''
        assert hidden_size % num_heads == 0, "hidden_size should be divisible by num_heads"
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.W_QKV = nn.Linear(hidden_size, 3 * hidden_size)
        self.W_O = nn.Linear(hidden_size, hidden_size)

    def forward(self, x: t.Tensor) -> t.Tensor:
        '''
        x: shape (batch, seq, hidden_size)
        Return: shape (batch, seq, hidden_size)
        '''
        QKV = self.W_QKV(x)
        # The fused projection is laid out as [Q | K | V] along the last dimension,
        # each part hidden_size wide.
        Q, K, V = QKV.split(self.hidden_size, dim=-1)
        attention_values = multihead_masked_attention(Q, K, V, self.num_heads)
        return self.W_O(attention_values)


@dataclass
class TransformerConfig:
    '''Hyperparameters shared by the components of the decoder-only transformer.'''

    # --- required: architecture ---
    num_layers: int    # number of stacked decoder blocks
    num_heads: int     # attention heads per block
    vocab_size: int    # rows of the token embedding table
    hidden_size: int   # width of the residual stream
    max_seq_len: int   # rows of the positional embedding table

    # --- optional: regularisation / numerics ---
    dropout: float = 0.1                # probability passed to nn.Dropout
    layer_norm_epsilon: float = 1e-5    # eps passed to nn.LayerNorm

    # Not read by any code in this notebook.
    print_param_count: bool = False

In [5]:
class NewGELUActivation(nn.Module):
    """
    Tanh approximation of GELU, as implemented in the Google BERT repo (identical to OpenAI GPT).
    Reference: the Gaussian Error Linear Units paper, https://arxiv.org/abs/1606.08415
    """

    def forward(self, input: t.Tensor) -> t.Tensor:
        # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))), applied elementwise.
        cubic_term = input + 0.044715 * t.pow(input, 3.0)
        tanh_argument = math.sqrt(2.0 / math.pi) * cubic_term
        gate = 1.0 + t.tanh(tanh_argument)
        return 0.5 * input * gate


class MLP(nn.Module):
    '''Feed-forward sub-layer: expand to 4x hidden_size, apply GELU, project back, then dropout.'''

    def __init__(self, config: TransformerConfig):
        super().__init__()
        width = config.hidden_size
        expanded_width = 4 * width
        # Layers are created in execution order; the Sequential indices (0..3)
        # determine the parameter names and their order.
        layers = [
            nn.Linear(width, expanded_width),
            NewGELUActivation(),  # tanh-approximate GELU rather than nn.GELU()
            nn.Linear(expanded_width, width),
            nn.Dropout(config.dropout),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x: t.Tensor) -> t.Tensor:
        '''Apply the feed-forward stack to x; the last dimension must be hidden_size.'''
        return self.model(x)


class GPTDecoderBlock(nn.Module):
    '''
    One decoder block: LayerNorm -> masked attention and LayerNorm -> MLP,
    each wrapped in a residual connection.
    '''

    def __init__(self, config: TransformerConfig):
        super().__init__()
        width = config.hidden_size
        eps = config.layer_norm_epsilon
        # Submodules are registered in the order ln1, att, ln2, mlp.
        self.ln1 = nn.LayerNorm(width, eps=eps)
        self.att = MultiheadMaskedAttention(width, config.num_heads)
        self.ln2 = nn.LayerNorm(width, eps=eps)
        self.mlp = MLP(config)

    def forward(self, x: t.Tensor) -> t.Tensor:
        '''Return x after the attention and MLP residual updates (same shape as x).'''
        after_attention = x + self.att(self.ln1(x))
        after_mlp = after_attention + self.mlp(self.ln2(after_attention))
        return after_mlp

In [6]:
class MyGPT(nn.Module):
    '''
    Decoder-only transformer: token + learned positional embeddings, a stack of
    GPTDecoderBlocks, a final LayerNorm, and an unembedding that reuses the
    token embedding matrix.
    '''

    def __init__(self, config: TransformerConfig):
        super().__init__()
        width = config.hidden_size
        self.token_embedding = nn.Embedding(config.vocab_size, width)
        self.positional_encoding = nn.Embedding(config.max_seq_len, width)
        self.dropout = nn.Dropout(config.dropout)
        blocks = (GPTDecoderBlock(config) for _ in range(config.num_layers))
        self.decoder_blocks = nn.Sequential(*blocks)
        self.final_ln = nn.LayerNorm(width, eps=config.layer_norm_epsilon)

    def forward(self, x: t.Tensor) -> t.Tensor:
        '''
        x is a tensor of token ids
        x: shape (batch, seq_len); a 1-D (seq_len,) input is treated as batch size 1
        Return: shape (batch, seq_len, vocab_size)
        '''
        # Promote an unbatched sequence to a batch of one.
        token_ids = x.unsqueeze(0) if x.dim() == 1 else x

        positions = t.arange(token_ids.shape[1], device=token_ids.device)
        embedded = self.token_embedding(token_ids) + self.positional_encoding(positions)

        hidden = self.dropout(embedded)
        hidden = self.decoder_blocks(hidden)
        hidden = self.final_ln(hidden)

        # Unembed with the (tied) token embedding matrix to get per-token logits.
        logits = einsum(
            "batch seq hidden, vocab hidden -> batch seq vocab",
            hidden,
            self.token_embedding.weight,
        )
        return logits

In [7]:
def copy_weights(my_gpt: MyGPT, gpt) -> MyGPT:
    '''
    Load the parameters of `gpt` into `my_gpt` and return `my_gpt`.

    Parameters are paired purely by position in `named_parameters()`, so both
    models must register their parameters in the same order. A 2-D parameter
    whose shape equals the transpose of its counterpart is copied transposed;
    this check runs first, so square matrices are always transposed.

    Raises AssertionError if the parameter counts differ, and Exception if a
    pair of shapes is incompatible or the collected keys do not match
    `my_gpt.state_dict()`.
    '''
    # Named parameters (not state dicts) are used on purpose: the buffers of
    # `gpt` are not wanted here (they are believed to be attention masks only).
    target_params = dict(my_gpt.named_parameters())
    source_params = dict(gpt.named_parameters())

    assert len(target_params) == len(source_params), "Number of layers is wrong. Have you done the prev step correctly?"

    collected = {}
    for (my_param_name, my_param), (_, param) in zip(target_params.items(), source_params.items()):
        stored_transposed = my_param.dim() == 2 and my_param.shape == param.T.shape
        if stored_transposed:
            collected[my_param_name] = param.T
            continue
        if my_param.shape != param.shape:
            raise Exception(f"Parameter shapes don't match: {my_param.shape} vs {param.shape}")
        collected[my_param_name] = param

    # Every entry of my_gpt's state dict must have been given a value.
    if set(collected) != set(my_gpt.state_dict()):
        raise Exception("State dicts don't match.")

    my_gpt.load_state_dict(collected)
    return my_gpt

In [8]:
# Hyperparameters for the model built below; copy_weights() later shape-checks
# every parameter against the pretrained "gpt2" checkpoint.
config = TransformerConfig(
    num_layers=12,
    num_heads=12,
    vocab_size=50257,
    hidden_size=768,
    max_seq_len=1024,
    dropout=0.1,
    layer_norm_epsilon=1e-5,
    print_param_count=False,
)

In [9]:
# Build the from-scratch model (randomly initialised) and download the reference
# Hugging Face "gpt2" checkpoint. Both are put in eval mode, which turns off the
# nn.Dropout layers.
mygpt = MyGPT(config).eval()
gpt2 = transformers.AutoModelForCausalLM.from_pretrained("gpt2").eval()

In [10]:
# utils.print_param_count(mygpt, gpt2)

In [11]:
# Load the pretrained parameters of gpt2 into mygpt (copy_weights mutates and returns mygpt).
mygpt = copy_weights(mygpt, gpt2)

In [12]:
# Notebook-level tokenizer for the "gpt2" checkpoint; encode() below reads this global.
tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")

In [13]:
def encode(text: str) -> t.Tensor:
    """Tokenize `text` with the notebook-level `tokenizer`.

    Returns the "input_ids" entry of the tokenizer output, a Tensor of shape (batch=1, seq).
    """
    encoded = tokenizer(text, return_tensors="pt")
    token_ids = encoded["input_ids"]
    return token_ids

def get_next_tokens(model, tokenizer, prompt, k: int = 10):
    """Return the `k` most likely next tokens for `prompt`, most likely first.

    model: a module that maps token ids of shape (1, seq) either to a logits
        Tensor of shape (1, seq, vocab) or to an object with a `.logits`
        attribute of that shape. It is switched to eval mode as a side effect.
    tokenizer: callable as `tokenizer(prompt, return_tensors="pt")["input_ids"]`
        and providing `batch_decode`. It is used for BOTH encoding the prompt
        and decoding the predictions, so the two always agree.
    prompt: the text to continue.
    k: how many candidates to return (default 10); clamped to the vocabulary size.

    Returns a list of decoded token strings.
    """
    model.eval()
    # Run on whatever device the model lives on; a parameterless model falls back to CPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else t.device("cpu")

    # Encode with the tokenizer that was passed in rather than a notebook global.
    input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(device)
    with t.inference_mode():
        output = model(input_ids)
        # Logits at the final position of the single batch element.
        logits = output[0, -1] if isinstance(output, t.Tensor) else output.logits[0, -1]

    # torch.topk raises if k exceeds the size of the dimension.
    k = min(k, logits.shape[-1])
    topk = t.topk(logits, k=k).indices
    # Decode each id on its own so the result is one string per candidate token.
    next_tokens = tokenizer.batch_decode(topk.reshape(-1, 1))
    return next_tokens

In [14]:
# Baseline: top-10 next-token predictions from the reference Hugging Face model.
prompt = "Former President of the United States of America, George"
get_next_tokens(gpt2, tokenizer, prompt)

[' W',
 ' H',
 ' Bush',
 ' Washington',
 ' HW',
 ' Herbert',
 ' Pat',
 ' S',
 ' Soros',
 ' Wallace']

In [15]:
# Same prompt through the re-implementation; the list should match the gpt2 output above.
prompt = "Former President of the United States of America, George"
get_next_tokens(mygpt, tokenizer, prompt)

[' W',
 ' H',
 ' Bush',
 ' Washington',
 ' HW',
 ' Herbert',
 ' Pat',
 ' S',
 ' Soros',
 ' Wallace']

In [17]:
# Course-provided check from utils.py (implementation not shown in this notebook);
# judging by its output it prints the model's top-10 predictions for a fixed prompt.
utils.test_load_pretrained_weights(mygpt, tokenizer)

Prompt:  Former President of the United States of America, George
Your model's top 10 predictions:  [' W', ' H', ' Bush', ' Washington', ' HW', ' Herbert', ' Pat', ' S', ' Soros', ' Wallace']
