# BERT Pretraining

In [2]:
import hashlib
import os
import sys
import zipfile
import torch as t
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data.dataset import TensorDataset
from tqdm.notebook import tqdm_notebook
from IPython.display import display
import transformers
from einops import rearrange
from torch.nn import functional as F
from tqdm import tqdm
import requests
import utils

MAIN = __name__ == "__main__"
DATA_FOLDER = "./data"
DATASET = "2"
BASE_URL = "https://s3.amazonaws.com/research.metamind.io/wikitext/"
DATASETS = {"103": "wikitext-103-raw-v1.zip", "2": "wikitext-2-raw-v1.zip"}
TOKENS_FILENAME = os.path.join(DATA_FOLDER, f"wikitext_tokens_{DATASET}.pt")

if not os.path.exists(DATA_FOLDER):
    os.mkdir(DATA_FOLDER)

In [3]:
def maybe_download(url: str, path: str) -> None:
    '''
    Download the file from url and save it to path. 
    If path already exists, do nothing.
    '''
    if not os.path.exists(path):
        print("Downloading from {} to {}".format(url, path))
        r = requests.get(url, stream=True)
        with open(path, 'wb') as f:
            data = requests.get(url).content
            f.write(data)
    else:
        print("File {} already exists".format(path))

In [4]:
tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-cased")

In [5]:
path = os.path.join(DATA_FOLDER, DATASETS[DATASET])
maybe_download(BASE_URL + DATASETS[DATASET], path)
expected_hexdigest = {"103": "0ca3512bd7a238be4a63ce7b434f8935", "2": "f407a2d53283fc4a49bcff21bc5f3770"}
with open(path, "rb") as f:
    actual_hexdigest = hashlib.md5(f.read()).hexdigest()
    assert actual_hexdigest == expected_hexdigest[DATASET]

print(f"Using dataset WikiText-{DATASET} - options are 2 and 103")
tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-cased")

z = zipfile.ZipFile(path)

def decompress(*splits: str) -> str:
    return [
        z.read(f"wikitext-{DATASET}-raw/wiki.{split}.raw").decode("utf-8").splitlines()
        for split in splits
    ]

train_text, val_text, test_text = decompress("train", "valid", "test")

File ./data/wikitext-2-raw-v1.zip already exists
Using dataset WikiText-2 - options are 2 and 103


In [6]:
train_text_long = "".join(train_text)
tkns = tokenizer.encode(train_text_long, return_tensors='pt')
len(tkns)

Token indices sequence length is longer than the specified maximum sequence length for this model (2378953 > 512). Running this sequence through the model will result in indexing errors


1

In [7]:
tkns.shape

torch.Size([1, 2378953])

In [8]:
tkns = tkns.squeeze(dim=0)
discard = tkns.shape[0] % 128
tkns = tkns[:-discard]
tkns.shape

torch.Size([2378880])

In [9]:
tkns.shape[0] / 128

18585.0

In [10]:
tkns = rearrange(tkns, '(batch seq) -> batch seq', seq=128)

In [11]:
def tokenize_1d(tokenizer, lines: list[str], max_seq: int) -> t.Tensor:
    '''Tokenize text and rearrange into chunks of the maximum length.

    Return (batch, seq) and an integer dtype.
    '''
    lines = ''.join(lines)
    tkns = tokenizer.encode(lines, return_tensors='pt')
    tkns = tkns.squeeze(dim=0)
    discard = tkns.shape[0] % max_seq
    tkns = rearrange(tkns[:-discard], '(batch seq) -> batch seq', seq=max_seq)

    return tkns

if MAIN:
    max_seq = 128
    print("Tokenizing training text...")
    train_data = tokenize_1d(tokenizer, train_text, max_seq)
    print("Training data shape is: ", train_data.shape)
    print("Tokenizing validation text...")
    val_data = tokenize_1d(tokenizer, val_text, max_seq)
    print("Tokenizing test text...")
    test_data = tokenize_1d(tokenizer, test_text, max_seq)
    print("Saving tokens to: ", TOKENS_FILENAME)
    t.save((train_data, val_data, test_data), TOKENS_FILENAME)

Tokenizing training text...
Training data shape is:  torch.Size([18585, 128])
Tokenizing validation text...
Tokenizing test text...
Saving tokens to:  ./data/wikitext_tokens_2.pt


In [12]:
seq_count, sequence_len = tkns.shape
mask_list = []
for _ in range(seq_count):
    permuted_indices = t.randperm(sequence_len)
    mask_list.append(permuted_indices < 0.15 * sequence_len)

In [13]:
mask_list[78]

tensor([ True, False, False, False, False, False, False, False,  True, False,
        False, False, False, False, False,  True, False, False, False, False,
        False, False, False, False, False,  True, False, False, False, False,
        False, False,  True, False, False, False, False, False, False, False,
         True,  True, False, False, False, False, False, False, False,  True,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False,  True, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False,  True, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False,  True,  True, False, False,  True, False,
         True, False,  True, False, False, False,  True, False, False, False,
         True,  True, False, False, False, False,  True,  True])

In [14]:
mask_vals = t.rand(tkns.shape)
mask_vals[mask_vals < 0.10] = 2.0
mask_vals[mask_vals < 0.90] = 99999
mask_vals[mask_vals < 1.0] = t.randint(0, 10000, (1,))
mask_vals = t.where(mask_vals==2, tkns, mask_vals)

In [15]:
mask_vals[0]

tensor([  101., 99999., 99999., 99999., 99999., 17758., 99999.,  6131., 99999.,
         6131.,  1185., 99999.,  3781.,  3464.,  6131., 99999., 99999.,  6131.,
        99999.,  6131.,  6131., 99999., 99999., 99999.,   100., 99999., 99999.,
        99999.,  6131., 99999., 99999., 99999., 99999., 99999., 99999., 99999.,
        99999., 99999.,   114., 99999., 99999., 99999., 99999.,  1112., 99999.,
         3781.,  6131., 99999.,  6131., 99999., 99999.,   117., 99999., 99999.,
        99999., 99999., 99999., 99999., 99999., 99999., 99999.,  6131., 99999.,
        99999., 99999., 99999., 99999.,   119., 99999.,  1111.,  1103.,  6131.,
        99999.,  6131., 99999., 11930.,  6131., 99999., 99999., 99999., 99999.,
        99999., 99999., 99999., 99999., 99999., 99999., 99999., 99999., 99999.,
        99999.,  6131., 99999., 99999., 99999., 99999., 99999., 99999., 99999.,
        99999., 99999.,  1104., 99999., 99999., 99999., 99999.,   118.,  6131.,
        99999., 99999., 99999.,  6131., 

In [58]:

def make_mask_values(input_ids, mask_token_id, vocab_size, msk_pct=0.8, rand_pct=0.1, identical_pct=0.1):
    mask_vals = t.rand(input_ids.shape)
    mask_vals = t.where(mask_vals<rand_pct, input_ids, mask_vals)
    mask_vals[mask_vals < rand_pct + msk_pct] = mask_token_id
    mask_vals[mask_vals < rand_pct + msk_pct + identical_pct] = t.randint(0, vocab_size, (1,))
    
    return mask_vals

def random_mask(
    input_ids: t.Tensor, mask_token_id: int, vocab_size: int, select_frac=0.15, 
    mask_frac=0.8, random_frac=0.1
) -> tuple[t.Tensor, t.Tensor]:
    '''Given a batch of tokens, return a copy with tokens replaced according to Section 3.1 of the paper.

    input_ids: (batch, seq)

    Return: (model_input, was_selected) where:

    model_input: (batch, seq) - a new Tensor with the replacements made, suitable for passing to the 
    BertLanguageModel. Don't modify the original tensor!

    was_selected: (batch, seq) - 1 if the token at this index will contribute to the MLM loss, 0 
    otherwise
    '''
    device = input_ids.device
    seq_count, sequence_len = input_ids.shape
    #flat_input = rearrange(input_ids, 'batch seq -> (batch seq)')
    permuted_indices = t.randperm(seq_count * sequence_len)
    mask = permuted_indices < select_frac * input_ids.numel()
    mask = rearrange(mask, '(batch seq) -> batch seq', seq=sequence_len)

    mask_vals = make_mask_values(input_ids, mask_token_id, vocab_size, mask_frac, random_frac)
    masked_input = t.where(mask, mask_vals, input_ids)

    #alt approach
    flat_input = rearrange(input_ids, 'batch seq -> (batch seq)')
    permuted_indices = t.randperm(input_ids.numel())
    
    mask = permuted_indices < (1 - select_frac) * input_ids.numel()
    mask_tkn = permuted_indices < (1 - select_frac * (mask_frac + random_frac)) * input_ids.numel()
    mask_rand = permuted_indices < (1 - select_frac * random_frac * input_ids.numel())
    print(input_ids.numel())
    print(input_ids.shape)
    print(mask.long().sum())
    print(mask.long().sum() / len(mask))
    print(mask_tkn.long().sum() / len(mask_tkn))
    print(mask_rand.long().sum() / len(mask_rand))

    return masked_input.to(device), mask.long().to(device)

if MAIN:
    utils.test_random_mask(random_mask, input_size=10000, max_seq=max_seq)


Testing empirical frequencies
1280000
torch.Size([10000, 128])
tensor(192000)
tensor(0.1500)
tensor(0.2775)
tensor(0.)
Checking fraction of tokens selected...


AssertionError: Scalars are not close!

Absolute difference: 0.6999999940395355 (up to 1e-05 allowed)
Relative difference: 0.8235294047523948 (up to 0 allowed)

In [59]:
def random_mask(
    input_ids: t.Tensor, mask_token_id: int, vocab_size: int, select_frac=0.15, mask_frac=0.8, random_frac=0.1
) -> tuple[t.Tensor, t.Tensor]:
    '''Given a batch of tokens, return a copy with tokens replaced according to Section 3.1 of the paper.

    input_ids: (batch, seq)

    Return: (model_input, was_selected) where:

    model_input: (batch, seq) - a new Tensor with the replacements made, suitable for passing to the BertLanguageModel. Don't modify the original tensor!

    was_selected: (batch, seq) - 1 if the token at this index will contribute to the MLM loss, 0 otherwise
    '''

    input_ids_modified = input_ids.clone()

    # Create masks
    mask_seed = t.randperm(input_ids.numel()).reshape(input_ids.shape).to(input_ids.device)

    threshold_probabilities = t.tensor([
        0,
        select_frac * mask_frac,
        select_frac * (mask_frac + random_frac),
        select_frac
    ])
    threshold_values = input_ids.numel() * threshold_probabilities

    fill_values = [mask_token_id, input_ids.clone().random_(vocab_size)]
    for threshold_lower, threshold_higher, fill_value in zip(threshold_values[0:2], threshold_values[1:3], fill_values):
        input_ids_modified = t.where(
            (threshold_lower <= mask_seed) & (mask_seed < threshold_higher),
            fill_value,
            input_ids_modified
        )

    return input_ids_modified, mask_seed < threshold_values[-1]

if MAIN:
    utils.test_random_mask(random_mask, input_size=10000, max_seq=max_seq)

Testing empirical frequencies
Checking fraction of tokens selected...
Checking fraction of tokens masked...
Checking fraction of tokens masked OR randomized...


In [21]:
# Find the word frequencies
word_frequencies = t.bincount(train_data.flatten())
# Drop the words with occurrence zero (because these contribute zero to cross entropy)
word_frequencies = word_frequencies[word_frequencies > 0]
# Get probabilities
word_probabilities = word_frequencies / word_frequencies.sum()
# Calculate the cross entropy
cross_entropy = (- word_probabilities * word_probabilities.log()).sum()
print(cross_entropy)
# ==> 7.3446

tensor(7.3446)


In [22]:
def cross_entropy_selected(pred: t.Tensor, target: t.Tensor, was_selected: t.Tensor) -> t.Tensor:
    '''
    pred: (batch, seq, vocab_size) - predictions from the model
    target: (batch, seq, ) - the original (not masked) input ids
    was_selected: (batch, seq) - 1 if the token at this index will contribute to the MLM loss, 0 otherwise

    Out: the mean loss per predicted token
    '''
    target = t.where(was_selected.bool(), target, -100) #-100 will be ignored
    target = rearrange(target, 'batch seq ... -> (batch seq) ...')
    pred = rearrange(pred, 'batch seq ... -> (batch seq) ...')
    loss = F.cross_entropy(pred, target.long(), ignore_index=-100)
    #print(was_selected.sum())
    #loss = loss / was_selected.sum()
    #print(loss)
    #print(loss / was_selected.sum())
    return loss

if MAIN:
    utils.test_cross_entropy_selected(cross_entropy_selected)

    batch_size = 8
    seq_length = 512
    batch = t.randint(0, tokenizer.vocab_size, (batch_size, seq_length))
    pred = t.rand((batch_size, seq_length, tokenizer.vocab_size))
    (masked, was_selected) = random_mask(batch, tokenizer.mask_token_id, tokenizer.vocab_size)
    loss = cross_entropy_selected(pred, batch, was_selected).item()
    print(f"Random MLM loss on random tokens - does this make sense? {loss:.2f}")

Random MLM loss on random tokens - does this make sense? 10.31


In [23]:
import sys 
sys.path.append('../common_modules')

In [24]:
from transformer_modules import Dropout, LayerNorm, MLP, TransformerConfig, Embedding, GELU
from general_modules import Linear
from bert_modules import BERTLanguageModel

In [25]:
tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-cased")

hidden_size = 512
bert_config_tiny = TransformerConfig(
    num_layers = 8,
    num_heads = hidden_size // 64,
    vocab_size = 28996,
    hidden_size = hidden_size,
    max_seq_len = 128,
    dropout = 0.1,
    layer_norm_epsilon = 1e-12
)

config_dict = dict(
    lr=0.0002,
    epochs=40,
    batch_size=128,
    weight_decay=0.01,
    mask_token_id=tokenizer.mask_token_id,
    warmup_step_frac=0.01,
    eps=1e-06,
    max_grad_norm=None,
)

(train_data, val_data, test_data) = t.load("./data/wikitext_tokens_2.pt")
print("Training data size: ", train_data.shape)

train_loader = DataLoader(
    TensorDataset(train_data), shuffle=True, batch_size=config_dict["batch_size"], drop_last=True
)

val_loader = DataLoader(
    TensorDataset(val_data), shuffle=False, batch_size=config_dict["batch_size"], drop_last=True
)

test_loader = DataLoader(
    TensorDataset(test_data), shuffle=False, batch_size=config_dict["batch_size"], drop_last=True
)

Training data size:  torch.Size([18585, 128])


In [26]:
test_dataset = TensorDataset(train_data)
test_dataset.__getitem__(2)[0]

tensor([ 1104,  5094,  8630,  1103,  5444,   119,  1109,  1342,   112,   188,
         2280,  3815,  1108,  7399,  1118,  1318,   112,   183,   119,  1135,
         1899,  1114,  3112,  3813,  1107,  1999,   117,  1105,  1108,  5185,
         1118,  1241,  1983,  1105,  2466,  4217,   119,  1258,  1836,   117,
         1122,  1460,  9133,  1895,  3438,   117,  1373,  1114,  1126,  3631,
         2596,  1107,  1379,  1104,  1115,  1214,   119,  1135,  1108,  1145,
         5546,  1154,  9675,  1105,  1126,  1560,  1888,  8794,  1326,   119,
         4187,  1106,  1822,  3813,  1104, 12226,  3781,  3464, 17758,  1563,
          117, 12226,  3781,  3464, 17758,  2684,  1108,  1136, 25813,   117,
         1133,   170,  5442,  5179, 12173,  1114,  1103,  1342,   112,   188,
         3631,  2596,  1108,  1308,  1107,  1387,   119,  3957,   119, 11724,
         1156,  1862,  1106,  1103,  5801,  1114,  1103,  1718,  1104, 12226,
         3781,  3464,   131,   138, 26395,  4543,  1111,  1103])

In [27]:
def lr_for_step(step: int, max_step: int, max_lr: float, warmup_step_frac: float):
    '''Return the learning rate for use at this step of training.'''
    initial_lr = 0.1 * max_lr
    warmup_steps = max_step * warmup_step_frac
    cooldown_steps = max_step - warmup_steps

    if step < warmup_steps:
        return initial_lr + step * ((max_lr - initial_lr) / warmup_steps)
    else:
        return max_lr - step * ((max_lr - initial_lr) / cooldown_steps)



if MAIN:
    max_step = int(len(train_loader) * config_dict["epochs"])
    lrs = [
        lr_for_step(step, max_step, max_lr=config_dict["lr"], warmup_step_frac=config_dict["warmup_step_frac"])
        for step in range(max_step)
    ]
    import plotly.express as px
    import pandas as pd

    df = pd.DataFrame(dict(
        y = lrs,
        x = [s for s, lr in enumerate(lrs)]
    ))
    fig = px.line(df, x="x", y="y", title="Learning Rates") 
    fig.show()

In [28]:
def make_optimizer(model: BERTLanguageModel, config_dict: dict) -> t.optim.AdamW:
    '''
    Loop over model parameters and form two parameter groups:

    - The first group includes the weights of each Linear layer and uses the weight decay in config_dict
    - The second has all other parameters and uses weight decay of 0
    '''
    param_list = list(model.named_parameters())

    param_groups = [
        {"params":[], "weight_decay":config_dict["weight_decay"]},
        {"params":[], "weight_decay":0.0}
        ]

    for name, param in param_list:
        name = name.split('.')
        #print(name)
        if "weight" in name and ("attn" in name or "mlp" in name or "linear" in name):
            param_groups[0]["params"].append(param)
        else:
            param_groups[1]["params"].append(param)

    return t.optim.AdamW(param_groups, lr=config_dict["lr"], eps=config_dict["eps"])

if MAIN:
    test_config = TransformerConfig(
        num_layers = 3,
        num_heads = 1,
        vocab_size = 28996,
        hidden_size = 1,
        max_seq_len = 4,
        dropout = 0.1,
        layer_norm_epsilon = 1e-12,
    )

    optimizer_test_model = BERTLanguageModel(test_config)
    opt = make_optimizer(
        optimizer_test_model, 
        dict(weight_decay=0.1, lr=0.0001, eps=1e-06)
    )
    expected_num_with_weight_decay = test_config.num_layers * 6 + 1
    wd_group = opt.param_groups[0]
    actual = len(wd_group["params"])
    assert (
        actual == expected_num_with_weight_decay
    ), f"Expected 6 linear weights per layer (4 attn, 2 MLP) plus the final lm_linear weight to have weight decay, got {actual}"
    all_params = set()
    for group in opt.param_groups:
        all_params.update(group["params"])
    assert all_params == set(optimizer_test_model.parameters()), "Not all parameters were passed to optimizer!"

In [29]:
from typing import List


def predict(model, tokenizer, text: str, k=15) -> List[List[str]]:
    '''
    Return a list of k strings for each [MASK] in the input.
    '''
    """
    Return a list of k strings for each [MASK] in the input.
    """
    model.eval()
    tokens = tokenizer.encode(text=text, return_tensors="pt")
    res = model(tokens)
    
    mask_predictions = []
    for n, input_id in enumerate(tokens.squeeze()):
        if input_id == tokenizer.mask_token_id:
            logits = res[0, n]
            top_logits_indices = t.topk(logits, k).indices
            predictions = tokenizer.decode(top_logits_indices)
            mask_predictions.append(predictions)
    
    return mask_predictions

In [30]:
item = next(iter(train_loader))
item[0].shape

torch.Size([128, 128])

In [34]:
device = t.device("cuda:0" if t.cuda.is_available() else "cpu")

def bert_mlm_pretrain(model: BERTLanguageModel, config_dict: dict, train_loader: DataLoader) -> None:
    '''Train using masked language modelling.'''

    optimizer = make_optimizer(model, config_dict)
    loss_fn = cross_entropy_selected
    loss_list = []
    model.train()
    step_count = 0

    for epoch in range(config_dict["epochs"]):
        
        progress_bar = tqdm_notebook(train_loader)
        step_count = 0
        for (x,) in progress_bar:
            step_count += 1
            #print(x.shape)
            x = x.to(device)
            y = x.detach().clone().to(device)
            x, mask = random_mask(x, config_dict["mask_token_id"], bert_config_tiny.vocab_size)
            
            optimizer.zero_grad()
            
            logits = model(x)
            # logits dimensions are (batch, seq, digits), but we care about probabilities for each digit
            # so we need to reshape into (batch * seq, digits)
            #loss = loss_fn(rearrange(logits, "b s d -> (b s) d"), y.flatten(), mask)
            loss = loss_fn(logits, y, mask)
            loss.backward()

            optimizer.step()
            step_lr = lr_for_step(step_count, config_dict["epochs"] * 145, 1e-4, 0.01)
            for g in optimizer.param_groups:
                g['lr'] = step_lr
            
            progress_bar.set_description(f"epoch = {epoch+1}, loss = {loss.item():.4f}")

            loss_list.append(loss.item())
    
    return model

if MAIN:
    model = BERTLanguageModel(bert_config_tiny)
    num_params = sum((p.nelement() for p in model.parameters()))
    print("Number of model parameters: ", num_params)
    bert_mlm_pretrain(model, config_dict, train_loader)

Number of model parameters:  40425284


  0%|          | 0/145 [00:00<?, ?it/s]

torch.Size([128, 128])
torch.Size([128, 128])
torch.Size([128, 128])
torch.Size([128, 128])
torch.Size([128, 128])
torch.Size([128, 128])


KeyboardInterrupt: 

In [None]:
if MAIN:
    model = BERTLanguageModel(bert_config_tiny)
    model.load_state_dict(t.load(config_dict["filename"]))
    your_text = "The Answer to the Ultimate Question of Life, The Universe, and Everything is [MASK]."
    predictions = predict(model, tokenizer, your_text)
    print("Model predicted: \n", "\n".join(map(str, predictions)))