In [49]:
import hashlib
import os
import sys
import zipfile
import torch as t
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data.dataset import TensorDataset
import transformers
from einops import rearrange
from torch.nn import functional as F
from tqdm import tqdm
import requests
import utils

MAIN = __name__ == "__main__"
DATA_FOLDER = "./data"
DATASET = "103"
BASE_URL = "https://s3.amazonaws.com/research.metamind.io/wikitext/"
DATASETS = {"103": "wikitext-103-raw-v1.zip", "2": "wikitext-2-raw-v1.zip"}
TOKENS_FILENAME = os.path.join(DATA_FOLDER, f"wikitext_tokens_{DATASET}.pt")

if not os.path.exists(DATA_FOLDER):
    os.mkdir(DATA_FOLDER)

In [50]:
tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-cased")

In [51]:
def maybe_download(url: str, path: str) -> None:
    """Download the file from url and save it to path. If path already exists, do nothing."""
    if not os.path.exists(path):
        with open(path, "wb") as file:
            data = requests.get(url).content
            file.write(data)

In [52]:
path = os.path.join(DATA_FOLDER, DATASETS[DATASET])
maybe_download(BASE_URL + DATASETS[DATASET], path)
expected_hexdigest = {"103": "0ca3512bd7a238be4a63ce7b434f8935", "2": "f407a2d53283fc4a49bcff21bc5f3770"}
with open(path, "rb") as f:
    actual_hexdigest = hashlib.md5(f.read()).hexdigest()
    assert actual_hexdigest == expected_hexdigest[DATASET]

print(f"Using dataset WikiText-{DATASET} - options are 2 and 103")
tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-cased")

z = zipfile.ZipFile(path)

def decompress(*splits: str) -> str:
    return [
        z.read(f"wikitext-{DATASET}-raw/wiki.{split}.raw").decode("utf-8").splitlines()
        for split in splits
    ]

train_text, val_text, test_text = decompress("train", "valid", "test")

Using dataset WikiText-103 - options are 2 and 103


In [5]:
train_text[100:110]

[' 96 ammunition packing boxes ',
 ' Repaired : ',
 ' 2 @,@ 236 shotguns and rifles ( repaired mostly for troops in service ) ',
 ' 23 pistols ( repaired mostly for troops in service ) ',
 ' Received & Issued : ',
 ' 752 packages of ordnance and ordnance stores received and mostly issued to troops in service . ',
 ' Repaired and painted : ',
 ' 4 gun carriages ',
 ' Performed : ',
 ' Guard , office , and police duties . ']

In [6]:
test_text[100:110]

[' Du Fu \'s popularity grew to such an extent that it is as hard to measure his influence as that of Shakespeare in England : it was hard for any Chinese poet not to be influenced by him . While there was never another Du Fu , individual poets followed in the traditions of specific aspects of his work : Bai Juyi \'s concern for the poor , Lu You \'s patriotism , and Mei Yaochen \'s reflections on the quotidian are a few examples . More broadly , Du Fu \'s work in transforming the lǜshi from mere word play into " a vehicle for serious poetic utterance " set the stage for every subsequent writer in the genre . ',
 ' In the 20th century , he was the favourite poet of Kenneth Rexroth , who has described him as " the greatest non @-@ epic , non @-@ dramatic poet who has survived in any language " , and commented that , " he has made me a better man , as a moral agent and as a perceiving organism " . ',
 ' ',
 ' = = = Influence on Japanese literature = = = ',
 ' ',
 " Du Fu 's poetry has ma

In [7]:
def tokenize_1d(tokenizer, lines: list[str], max_seq: int) -> t.Tensor:
    '''Tokenize text and rearrange into chunks of the maximum length.

    Return (batch, seq) and an integer dtype.
    '''
    def flatten(l):
        return [item for sublist in l for item in sublist]
    
    input_ids = flatten(tokenizer(lines, 
                truncation = False, 
                padding=False, 
                return_attention_mask=False,
                return_token_type_ids=False)["input_ids"])

    truncation_length = len(input_ids) % max_seq
    input_ids = t.tensor(input_ids[:-truncation_length]).to(t.int)

    input_ids = rearrange(input_ids,'(b s) -> b s', s= max_seq)

    return input_ids

if MAIN:
    max_seq = 128
    print("Tokenizing training text...")
    train_data = tokenize_1d(tokenizer, train_text, max_seq)
    print("Training data shape is: ", train_data.shape)
    print("Tokenizing validation text...")
    val_data = tokenize_1d(tokenizer, val_text, max_seq)
    print("Tokenizing test text...")
    test_data = tokenize_1d(tokenizer, test_text, max_seq)
    print("Saving tokens to: ", TOKENS_FILENAME)
    t.save((train_data, val_data, test_data), TOKENS_FILENAME)

Tokenizing training text...


Token indices sequence length is longer than the specified maximum sequence length for this model (686 > 512). Running this sequence through the model will result in indexing errors


Training data shape is:  torch.Size([19159, 128])
Tokenizing validation text...
Tokenizing test text...
Saving tokens to:  ./data/wikitext_tokens_2.pt


In [66]:
import random

def random_mask(
    input_ids: t.Tensor, mask_token_id: int, vocab_size: int, select_frac=0.15, mask_frac=0.8, random_frac=0.1
) -> tuple[t.Tensor, t.Tensor]:
    '''Given a batch of tokens, return a copy with tokens replaced according to Section 3.1 of the paper.

    input_ids: (batch, seq)

    Return: (model_input, was_selected) where:

    model_input: (batch, seq) - a new Tensor with the replacements made, suitable for passing to the BertLanguageModel. Don't modify the original tensor!

    was_selected: (batch, seq) - 1 if the token at this index will contribute to the MLM loss, 0 otherwise
    '''
    input_ids = input_ids.clone()
    seq_len= input_ids.shape[-1]
    input_ids = rearrange(input_ids, 'b s -> (b s)')
    n = len(input_ids)

    # choose which positions to affect
    masked_positions = random.sample(range(n), k = int(select_frac*n))
    mask_token_positions = masked_positions[:int(mask_frac*len(masked_positions))]
    random_token_positions = masked_positions[int(mask_frac*len(masked_positions)):int((mask_frac+random_frac)*len(masked_positions))]
    leave_token_positions = masked_positions[int((mask_frac+random_frac)*len(masked_positions)):]

    # mask each
    input_ids[mask_token_positions] = mask_token_id
    input_ids[random_token_positions] = t.tensor(random.sample(range(vocab_size), k = len(random_token_positions))).to(t.long)

    # get was selected 
    mask = t.zeros(n)
    mask[masked_positions] = 1

    # rearrange before returning
    input_ids = rearrange(input_ids, '(b s) -> b s', s = seq_len)
    was_selected = rearrange(mask, '(b s) -> b s', s = seq_len)

    return input_ids, was_selected
    
if MAIN:
    utils.test_random_mask(random_mask, input_size=1000, max_seq=max_seq)

Testing empirical frequencies
Checking fraction of tokens selected...
Checking fraction of tokens masked...
Checking fraction of tokens masked OR randomized...
0.7225000262260437 0.7224999999999999


In [14]:
import math 
math.log(28996)

10.274913168420769

In [32]:
def flatten(l):
    return [item for sublist in l for item in sublist]

if MAIN:
    "TODO: YOUR CODE HERE, TO CALCULATE CROSS ENTROPY OF UNIGRAM FREQUENCIES"
    word_frequencies = t.bincount(train_data.flatten())
    word_frequencies = t.bincount(train_data.flatten())
    word_frequencies = word_frequencies[word_frequencies>0]
    word_probabilities= word_frequencies / word_frequencies.sum()
    cross_entropy = (- word_probabilities* word_probabilities.log()).sum()
    print(cross_entropy)

tensor(7.2800)


In [46]:
def cross_entropy_selected(pred: t.Tensor, target: t.Tensor, was_selected: t.Tensor) -> t.Tensor:
    '''
    pred: (batch, seq, vocab_size) - predictions from the model
    target: (batch, seq, ) - the original (not masked) input ids
    was_selected: (batch, seq) - 1 if the token at this index will contribute to the MLM loss, 0 otherwise

    Out: the mean loss per predicted token
    '''
    target = t.where(was_selected.to(t.bool), target, -100)
    entropy = nn.functional.cross_entropy(
        rearrange(pred, "b s ... -> (b s) ..."), 
        rearrange(target, "b s ... -> (b s) ...")
    )
    return entropy

if MAIN:
    utils.test_cross_entropy_selected(cross_entropy_selected)

    batch_size = 8
    seq_length = 512
    batch = t.randint(0, tokenizer.vocab_size, (batch_size, seq_length))
    pred = t.rand((batch_size, seq_length, tokenizer.vocab_size))
    (masked, was_selected) = random_mask(batch, tokenizer.mask_token_id, tokenizer.vocab_size)
    loss = cross_entropy_selected(pred, batch, was_selected).item()
    print(f"Random MLM loss on random tokens - does this make sense? {loss:.2f}")

Random MLM loss on random tokens - does this make sense? 10.33


In [54]:
from src.transformers import TransformerConfig

tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-cased")

hidden_size = 512
bert_config_tiny = TransformerConfig(
    num_layers = 8,
    num_heads = hidden_size // 64,
    vocab_size = 28996,
    hidden_size = hidden_size,
    max_seq_len = 128,
    dropout = 0.1,
    layer_norm_epsilon = 1e-12
)

config_dict = dict(
    lr=0.0002,
    epochs=40,
    batch_size=128,
    weight_decay=0.01,
    mask_token_id=tokenizer.mask_token_id,
    warmup_step_frac=0.01,
    eps=1e-06,
    max_grad_norm=None,
)

(train_data, val_data, test_data) = t.load("./data/wikitext_tokens_2.pt")
print("Training data size: ", train_data.shape)

train_loader = DataLoader(
    TensorDataset(train_data), shuffle=True, batch_size=config_dict["batch_size"], drop_last=True
)

Training data size:  torch.Size([19159, 128])


In [64]:
import plotly.express as px 

def lr_for_step(step: int, max_step: int, max_lr: float, warmup_step_frac: float):
    '''Return the learning rate for use at this step of training.'''
    delta =  max_step*warmup_step_frac
    if step < delta:
        return max_lr*(step/delta) # when step == delta, reach max_lr
    else: 
        return max_lr - max_lr*(step/(max_step-delta)) # when step = max_step-delta, lr = 0


if MAIN:
    max_step = int(len(train_loader) * config_dict["epochs"])
    lrs = [
        lr_for_step(step, max_step, max_lr=config_dict["lr"], warmup_step_frac=config_dict["warmup_step_frac"])
        for step in range(max_step)
    ]
    # TODO: YOUR CODE HERE, PLOT `lrs` AND CHECK IT RESEMBLES THE GRAPH ABOVE
    fig = px.line(lrs)
    fig.add_hline(y= config_dict["lr"], annotation_text = "max lr")
    fig.add_vline(x = max_step, annotation_text="max step")
    fig.show()

In [69]:
from build_bert import BERTLanguageMODEL as BertLanguageModel

def make_optimizer(model: BertLanguageModel, config_dict: dict) -> t.optim.AdamW:
    '''
    Loop over model parameters and form two parameter groups:

    - The first group includes the weights of each Linear layer and uses the weight decay in config_dict
    - The second has all other parameters and uses weight decay of 0
    '''
    weights = {k:v for k,v in model.named_parameters if "weight" in k}
    biases = {k:v for k,v in model.named_parameters if "weight" not in k}
    parameter_groups = [
        {'params': weights, 'weight_decay': config_dict[""]},
        {'params': biases, 'weight_decay': 0.00}
    ]
    

if MAIN:
    test_config = TransformerConfig(
        num_layers = 3,
        num_heads = 1,
        vocab_size = 28996,
        hidden_size = 1,
        max_seq_len = 4,
        dropout = 0.1,
        layer_norm_epsilon = 1e-12,
    )

    optimizer_test_model = BertLanguageModel(test_config)
    opt = make_optimizer(
        optimizer_test_model, 
        dict(weight_decay=0.1, lr=0.0001, eps=1e-06)
    )
    expected_num_with_weight_decay = test_config.num_layers * 6 + 1
    wd_group = opt.param_groups[0]
    actual = len(wd_group["params"])
    assert (
        actual == expected_num_with_weight_decay
    ), f"Expected 6 linear weights per layer (4 attn, 2 MLP) plus the final lm_linear weight to have weight decay, got {actual}"
    all_params = set()
    for group in opt.param_groups:
        all_params.update(group["params"])
    assert all_params == set(optimizer_test_model.parameters()), "Not all parameters were passed to optimizer!"

AttributeError: 'NoneType' object has no attribute 'param_groups'

In [79]:
from build_bert import BERTLanguageMODEL
from torchinfo import summary 

tiny_bert = BERTLanguageMODEL(bert_config_tiny)
#summary(tiny_bert, depth = 10)
for i in dict(tiny_bert.named_parameters()).keys():
    if "weight" in i:
        print(i)


bert.embeddings.word_embeddings.weight
bert.embeddings.position_embeddings.weight
bert.embeddings.token_type_embeddings.weight
bert.embeddings.LayerNorm.weight
bert.encoder.layer.0.attention.query.weight
bert.encoder.layer.0.attention.key.weight
bert.encoder.layer.0.attention.value.weight
bert.encoder.layer.0.attention.output.dense.weight
bert.encoder.layer.0.attention.output.LayerNorm.weight
bert.encoder.layer.0.intermediate.dense.weight
bert.encoder.layer.0.output.dense.weight
bert.encoder.layer.0.output.LayerNorm.weight
bert.encoder.layer.1.attention.query.weight
bert.encoder.layer.1.attention.key.weight
bert.encoder.layer.1.attention.value.weight
bert.encoder.layer.1.attention.output.dense.weight
bert.encoder.layer.1.attention.output.LayerNorm.weight
bert.encoder.layer.1.intermediate.dense.weight
bert.encoder.layer.1.output.dense.weight
bert.encoder.layer.1.output.LayerNorm.weight
bert.encoder.layer.2.attention.query.weight
bert.encoder.layer.2.attention.key.weight
bert.encoder.lay