In [1]:
import torch as t
from torch import nn
from torch.utils.data import Dataset, DataLoader
import plotly.express as px
from IPython.display import display
import pandas as pd
import numpy as np
import transformers
from fancy_einsum import einsum
from dataclasses import dataclass
from tqdm.notebook import tqdm_notebook
import matplotlib

from einops import rearrange, reduce, repeat

In [2]:
import sys 
sys.path.append('../common')

import gpt_modules as gpt
import utils

In [3]:
from transformer_modules import Dropout, LayerNorm, MLP, TransformerConfig, Embedding, GELU
from general_modules import Linear

In [4]:
bert = transformers.BertForMaskedLM.from_pretrained("bert-base-cased")

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [5]:
tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-cased")

In [6]:
text = "turn down for what"

In [7]:
tokenizer.encode_plus(text, add_special_tokens=True, max_length=64)

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


{'input_ids': [101, 1885, 1205, 1111, 1184, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1]}

In [9]:
from typing import Optional


class MultiheadMaskedAttention(nn.Module):
    W_QKV: nn.Linear
    W_O: nn.Linear

    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.query_size = int(hidden_size / num_heads)
        self.W_Q = nn.Linear(hidden_size, hidden_size)
        self.W_K = nn.Linear(hidden_size, hidden_size)
        self.W_V = nn.Linear(hidden_size, hidden_size)
        self.ff = Linear(hidden_size, hidden_size)

    def multihead_masked_attention(self, Q: t.Tensor, K: t.Tensor, V: t.Tensor, additive_attention_mask: Optional[t.Tensor], num_heads: int):
        """
        Implements multihead masked attention on the matrices Q, K and V.

        Q: shape (batch, seq, nheads*headsize)
        K: shape (batch, seq, nheads*headsize)
        V: shape (batch, seq, nheads*headsize)

        returns: shape (batch, seq, nheads*headsize)
        """
        Q = rearrange(Q, "B S (nheads headsize) -> B S nheads headsize", nheads=num_heads)
        K = rearrange(K, "B S (nheads headsize) -> B S nheads headsize", nheads=num_heads)
        V = rearrange(V, "B S (nheads headsize) -> B S nheads headsize", nheads=num_heads)

        batch_size, seq_len, nheads, headsize = Q.shape
        scores = einsum("B Qseq nheads headsize, B Kseq nheads headsize -> B nheads Qseq Kseq", Q, K)
        scores /= Q.shape[-1] ** 0.5

        if additive_attention_mask is not None:
            attention_scores = attention_scores + additive_attention_mask

        scores = t.softmax(scores, dim=-1)
        Z = einsum("B nheads Qseq Kseq, B Kseq nheads headsize -> B Qseq nheads headsize", scores, V)
        Z = rearrange(Z, "B Qseq nheads headsize -> B Qseq (nheads headsize)")
        return Z

    def forward(self, x: t.Tensor, additive_attention_mask: Optional[t.Tensor]) -> t.Tensor:
        """
        x: shape (batch, seq, hidden_size)

        Return: shape (batch, seq, hidden_size)
        """
        Q = self.W_Q(x)
        K = self.W_K(x)
        V = self.W_V(x)

        Z = self.multihead_masked_attention(Q, K, V, additive_attention_mask, self.num_heads)
        out = self.ff(Z)
        return out 


class BERTBlock(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.attn = MultiheadMaskedAttention(config.hidden_size, config.num_heads)
        self.lnorm1 = LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = MLP(config.hidden_size, config.dropout)
        self.lnorm2 = LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)

    def forward(self, x: t.Tensor, additive_attention_mask: Optional[t.Tensor] = None) -> t.Tensor:
        '''
        x: shape (batch, seq, hidden_size)
        additive_attention_mask: shape (batch, nheads=1, seqQ=1, seqK)
        '''
        attn = self.attn(x, additive_attention_mask)
        out = self.lnorm1(attn + x)
        mlp = self.mlp(out)
        out = self.lnorm2(mlp + out)
        return out


def make_additive_attention_mask(one_zero_attention_mask: t.Tensor, big_negative_number: float = -10000) -> t.Tensor:
    '''
    one_zero_attention_mask: 
        shape (batch, seq)
        Contains 1 if this is a valid token and 0 if it is a padding token.

    big_negative_number:
        Any negative number large enough in magnitude that exp(big_negative_number) is 0.0 for the floating point precision used.

    Out: 
        shape (batch, nheads=1, seqQ=1, seqK)
        Contains 0 if attention is allowed, big_negative_number if not.
    '''
    mask = 1 - one_zero_attention_mask
    mask = big_negative_number * mask
    return repeat(mask, 'B S -> B 1 1 S')

In [10]:
make_additive_attention_mask(t.tensor([[1,1,1,0,0,0]]))

tensor([[[[     0,      0,      0, -10000, -10000, -10000]]]])

In [11]:
class BertCommon(nn.Module):

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.emb = Embedding(config.vocab_size, config.hidden_size)
        self.pos_emb = Embedding(config.max_seq_len, config.hidden_size)
        self.tkn_emb = Embedding(2, config.hidden_size)

        self.lnorm = LayerNorm(config.hidden_size)
        self.dropout = Dropout(p=config.dropout)

        decoders = [BERTBlock(config) for l in range(config.num_layers)]
        self.blocks = nn.ModuleList(decoders)
        

    def forward(
        self,
        x: t.Tensor,
        one_zero_attention_mask: Optional[t.Tensor] = None,
        token_type_ids: Optional[t.Tensor] = None,
    ) -> t.Tensor:
        '''
        input_ids: (batch, seq) - the token ids
        one_zero_attention_mask: (batch, seq) - only used in training, passed to `make_additive_attention_mask` and used in the attention blocks.
        token_type_ids: (batch, seq) - only used for NSP, passed to token type embedding.
        '''
        # Embeddings
        pos = t.arange(x.shape[1], device=x.device)
        if not token_type_ids:
            token_type_ids = t.zeros_like(x)

        embedding = self.emb(x) + self.pos_emb(pos) + self.tkn_emb(token_type_ids)

        # Norm & Dropout
        out = self.lnorm(embedding)
        out = self.dropout(out)
        
        # Mask
        if one_zero_attention_mask:
            mask = make_additive_attention_mask(one_zero_attention_mask)
        else:
            mask = None

        for b in self.blocks:
            out = b(out, mask)

        return out

In [12]:
class BERTLanguageModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.common = BertCommon(config)
        self.linear = Linear(config.hidden_size, config.hidden_size)
        self.gelu = GELU()
        self.lnorm = LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.tied_embed_bias = nn.Parameter(t.zeros(config.vocab_size))
    
    def forward(self, x):
        out = self.common(x)
        out = self.gelu(self.linear(out))
        out = self.lnorm(out)
        out = einsum("B S E, V E -> B S V", out, self.common.emb.weight)

        return out

In [13]:
config = TransformerConfig(
    num_layers = 12,
    num_heads = 12,
    vocab_size = 28996,
    hidden_size = 768,
    max_seq_len = 512,
    dropout = 0.1,
    layer_norm_epsilon = 1e-12
)

In [15]:
# Note: The offset between weights in my version of BERT and the original is expected.
# The original BERT uses a different embedding layer. The weights are the same, but the bias is different.
# The function below for copying weights compensates for the difference.
tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-cased")
my_bert = BERTLanguageModel(config).train()
bert = transformers.BertForMaskedLM.from_pretrained("bert-base-cased")

utils.print_param_count(my_bert, bert)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Model 1, total params = 108340804


Unnamed: 0,name_1,shape_1,num_params_1
0,tied_embed_bias,"(28996,)",28996
1,common.emb.weight,"(28996, 768)",22268928
2,common.pos_emb.weight,"(512, 768)",393216
3,common.tkn_emb.weight,"(2, 768)",1536
4,common.lnorm.weight,"(768,)",768
...,...,...,...
197,common.blocks.11.lnorm2.bias,"(768,)",768
198,linear.weight,"(768, 768)",589824
199,linear.bias,"(768,)",768
200,lnorm.weight,"(768,)",768


Model 2, total params = 108340804


Unnamed: 0,num_params_2,shape_2,name_2
0,22268928,"(28996, 768)",bert.embeddings.word_embeddings.weight
1,393216,"(512, 768)",bert.embeddings.position_embeddings.weight
2,1536,"(2, 768)",bert.embeddings.token_type_embeddings.weight
3,768,"(768,)",bert.embeddings.LayerNorm.weight
4,768,"(768,)",bert.embeddings.LayerNorm.bias
...,...,...,...
197,28996,"(28996,)",cls.predictions.bias
198,589824,"(768, 768)",cls.predictions.transform.dense.weight
199,768,"(768,)",cls.predictions.transform.dense.bias
200,768,"(768,)",cls.predictions.transform.LayerNorm.weight


Parameter counts don't match up exactly.


Unnamed: 0,name_1,shape_1,num_params_1,num_params_2,shape_2,name_2
0,tied_embed_bias,"(28996,)",28996,22268928,"(28996, 768)",bert.embeddings.word_embeddings.weight
1,common.emb.weight,"(28996, 768)",22268928,393216,"(512, 768)",bert.embeddings.position_embeddings.weight
2,common.pos_emb.weight,"(512, 768)",393216,1536,"(2, 768)",bert.embeddings.token_type_embeddings.weight
3,common.tkn_emb.weight,"(2, 768)",1536,768,"(768,)",bert.embeddings.LayerNorm.weight
4,common.lnorm.weight,"(768,)",768,768,"(768,)",bert.embeddings.LayerNorm.bias
5,common.lnorm.bias,"(768,)",768,589824,"(768, 768)",bert.encoder.layer.0.attention.self.query.weight
6,common.blocks.0.attn.W_Q.weight,"(768, 768)",589824,768,"(768,)",bert.encoder.layer.0.attention.self.query.bias
7,common.blocks.0.attn.W_Q.bias,"(768,)",768,589824,"(768, 768)",bert.encoder.layer.0.attention.self.key.weight
8,common.blocks.0.attn.W_K.weight,"(768, 768)",589824,768,"(768,)",bert.encoder.layer.0.attention.self.key.bias
9,common.blocks.0.attn.W_K.bias,"(768,)",768,589824,"(768, 768)",bert.encoder.layer.0.attention.self.value.weight


In [16]:
def copy_weights_from_bert(my_bert: BERTLanguageModel, bert: transformers.models.bert.modeling_bert.BertForMaskedLM) -> BERTLanguageModel:
    '''
    Copy over the weights from BERT to my BERT.
    '''
    my_list = list(my_bert.named_parameters())
    my_list_rearranged = my_list[1:-4] + [my_list[0]] + my_list[-4:]
    pretrained_dict = dict(bert.named_parameters())

    # Initialise an empty dictionary to store the correct key-value pairs
    state_dict = {}

    for (my_name, my_param), (pt_name, pt_param) in zip(my_list_rearranged, pretrained_dict.items()):
        print(f"my name: {my_name} my size: {my_param.shape}")
        print(f"bert name: {pt_name} bert size: {pt_param.shape}")
        state_dict[my_name] = pt_param
        
    my_bert.load_state_dict(state_dict)
    return my_bert

my_bert = copy_weights_from_bert(my_bert, bert)

my name: common.emb.weight my size: torch.Size([28996, 768])
bert name: bert.embeddings.word_embeddings.weight bert size: torch.Size([28996, 768])
my name: common.pos_emb.weight my size: torch.Size([512, 768])
bert name: bert.embeddings.position_embeddings.weight bert size: torch.Size([512, 768])
my name: common.tkn_emb.weight my size: torch.Size([2, 768])
bert name: bert.embeddings.token_type_embeddings.weight bert size: torch.Size([2, 768])
my name: common.lnorm.weight my size: torch.Size([768])
bert name: bert.embeddings.LayerNorm.weight bert size: torch.Size([768])
my name: common.lnorm.bias my size: torch.Size([768])
bert name: bert.embeddings.LayerNorm.bias bert size: torch.Size([768])
my name: common.blocks.0.attn.W_Q.weight my size: torch.Size([768, 768])
bert name: bert.encoder.layer.0.attention.self.query.weight bert size: torch.Size([768, 768])
my name: common.blocks.0.attn.W_Q.bias my size: torch.Size([768])
bert name: bert.encoder.layer.0.attention.self.query.bias bert siz

In [18]:
from typing import List


def predict(model, tokenizer, text: str, k=15) -> List[List[str]]:
    """
    Return a list of k strings for each [MASK] in the input.
    """
    model.eval()
    tokens = tokenizer.encode(text=text, return_tensors="pt")
    res = model(tokens)
    
    mask_predictions = []
    for n, input_id in enumerate(tokens.squeeze()):
        if input_id == tokenizer.mask_token_id:
            logits = res[0, n]
            top_logits_indices = t.topk(logits, k).indices
            predictions = tokenizer.decode(top_logits_indices)
            mask_predictions.append(predictions)
    
    return mask_predictions

def test_bert_prediction(predict, model, tokenizer):

    text = "Former President of the United States of America, George[MASK][MASK]"
    predictions = predict(model, tokenizer, text)
    print(f"Prompt: {text}")
    print("Model predicted: \n", "\n".join(map(str, predictions)))
    assert "Washington" in predictions[0]
    assert "Bush" in predictions[0]

test_bert_prediction(predict, my_bert, tokenizer)

Prompt: Former President of the United States of America, George[MASK][MASK]
Model predicted: 
 W Washington Bush Wallace Dewey Polk Patton H Marshall Buchanan Clinton C G Carter E
;.!? |... Johnson Press Brown Carter Anderson III Smith Jones Thompson
