In [3]:
import hashlib
import os
import sys
import zipfile
import torch as t
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data.dataset import TensorDataset
import transformers
from einops import rearrange
from torch.nn import functional as F
from tqdm import tqdm
import requests
import utils

# True when this file runs as the main program (e.g. the notebook kernel) rather than
# being imported; used below to guard the expensive / side-effecting cells.
MAIN = __name__ == "__main__"
# Directory where the downloaded archive and the cached token tensors are stored.
DATA_FOLDER = "./data"
# Key into DATASETS selecting which WikiText variant to use ("2" or "103").
DATASET = "2"
BASE_URL = "https://s3.amazonaws.com/research.metamind.io/wikitext/"
# Archive file name (appended to BASE_URL) for each supported dataset key.
DATASETS = {"103": "wikitext-103-raw-v1.zip", "2": "wikitext-2-raw-v1.zip"}
# Where the tokenized (train, val, test) tensors are saved.
TOKENS_FILENAME = os.path.join(DATA_FOLDER, f"wikitext_tokens_{DATASET}.pt")

# exist_ok=True makes this idempotent without a separate exists() check (which is racy),
# and makedirs also creates missing parent directories if DATA_FOLDER is ever nested.
os.makedirs(DATA_FOLDER, exist_ok=True)


In [4]:
# Load the pretrained "bert-base-cased" tokenizer; later cells use it for encoding
# text and for its mask_token_id / vocab_size attributes.
tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-cased")

In [5]:
def maybe_download(url: str, path: str) -> None:
    '''
    Download the file from url and save it to path.
    If path already exists, do nothing.

    The response is fetched and checked *before* anything is written, and the
    bytes are first written to a temporary ``<path>.part`` file that is renamed
    into place only on success. A failed or interrupted download therefore never
    leaves a truncated file at ``path`` that later calls would mistake for a
    completed download.

    Args:
        url: HTTP(S) location of the file.
        path: Local destination file path.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.RequestException: on connection problems or timeouts.
        OSError: if the file cannot be written.
    '''
    if os.path.exists(path):
        return

    print(f"Downloading {url} to {path}")
    # A timeout keeps a stalled connection from hanging the notebook forever.
    response = requests.get(url, timeout=60)
    # Without this an HTML error page would silently be saved as the dataset.
    response.raise_for_status()

    partial_path = f"{path}.part"
    try:
        with open(partial_path, "wb") as f:
            f.write(response.content)
        # Atomic rename: `path` only ever appears fully written.
        os.replace(partial_path, path)
    finally:
        # Only still present if writing or renaming failed.
        if os.path.exists(partial_path):
            os.remove(partial_path)


In [6]:
# Fetch the selected WikiText archive (skipped when it is already on disk).
path = os.path.join(DATA_FOLDER, DATASETS[DATASET])
maybe_download(BASE_URL + DATASETS[DATASET], path)

# Known MD5 digests of the two archives, keyed like DATASETS.
expected_hexdigest = {"103": "0ca3512bd7a238be4a63ce7b434f8935", "2": "f407a2d53283fc4a49bcff21bc5f3770"}
with open(path, "rb") as f:
    actual_hexdigest = hashlib.md5(f.read()).hexdigest()
# Explicit error instead of a bare `assert`: it still runs under `python -O` and tells
# the reader which file is bad and how to recover.
if actual_hexdigest != expected_hexdigest[DATASET]:
    raise ValueError(
        f"Checksum mismatch for {path}: expected {expected_hexdigest[DATASET]}, "
        f"got {actual_hexdigest}. Delete the file and re-run to download it again."
    )

print(f"Using dataset WikiText-{DATASET} - options are 2 and 103")
# NOTE(review): this repeats the tokenizer load from an earlier cell; kept so this cell
# still defines `tokenizer` when run on its own.
tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-cased")

# Left open on purpose: decompress() reads members from this archive object.
z = zipfile.ZipFile(path)

def decompress(*splits: str) -> list[list[str]]:
    '''Read the requested dataset splits from the open archive ``z``.

    For each name in ``splits`` the member ``wikitext-{DATASET}-raw/wiki.{split}.raw``
    is read, decoded as UTF-8 and split into lines (line terminators removed).

    Args:
        *splits: Split names used in the archive member names, e.g. "train".

    Returns:
        One list of lines per requested split, in the order given.

    Raises:
        KeyError: if a requested member is not present in the archive.
    '''
    texts = []
    for split in splits:
        member = f"wikitext-{DATASET}-raw/wiki.{split}.raw"
        texts.append(z.read(member).decode("utf-8").splitlines())
    return texts

# Each *_text is a list of raw lines for that split (archive members wiki.train.raw,
# wiki.valid.raw and wiki.test.raw respectively).
train_text, val_text, test_text = decompress("train", "valid", "test")

Downloading https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip to ./data/wikitext-2-raw-v1.zip
Using dataset WikiText-2 - options are 2 and 103


In [7]:
# Quick look at the raw data: number of lines in the training split, then the first
# five lines (shown as the cell's output value).
print(len(train_text))
train_text[:5]

36718


[' ',
 ' = Valkyria Chronicles III = ',
 ' ',
 ' Senjō no Valkyria 3 : Unrecorded Chronicles ( Japanese : 戦場のヴァルキュリア3 , lit . Valkyria of the Battlefield 3 ) , commonly referred to as Valkyria Chronicles III outside Japan , is a tactical role @-@ playing video game developed by Sega and Media.Vision for the PlayStation Portable . Released in January 2011 in Japan , it is the third game in the Valkyria series . Employing the same fusion of tactical and real @-@ time gameplay as its predecessors , the story runs parallel to the first game and follows the " Nameless " , a penal military unit serving the nation of Gallia during the Second Europan War who perform secret black operations and are pitted against the Imperial unit " Calamaty Raven " . ',
 " The game began development in 2010 , carrying over a large portion of the work done on Valkyria Chronicles II . While it retained the standard features of the series , it also underwent multiple adjustments , such as making the game more for

In [8]:
# Call the tokenizer on the list of lines with truncation=False to obtain lists of tokens. 
# These will be of varying length, and you'll notice some are empty due to blank lines.

# Build one large 1D tensor containing all the tokens in sequence
# Reshape the 1D tensor into (batch, sequence).

def tokenize_1d(tokenizer, lines: list[str], max_seq: int) -> t.Tensor:
    '''Tokenize text and rearrange into chunks of the maximum length.

    Every line is encoded with ``tokenizer.encode``; the resulting ids are
    concatenated in order into one stream, the tail that does not fill a whole
    chunk is dropped, and the rest is reshaped into rows of ``max_seq`` tokens.

    Args:
        tokenizer: Any object with an ``encode(str) -> list[int]`` method.
        lines: The text lines to tokenize.
        max_seq: Number of tokens per row of the result.

    Return (batch, seq) and an integer dtype (``t.int``). If there are fewer
    than ``max_seq`` tokens in total, the result is an empty ``(0, max_seq)``
    tensor.
    '''
    # NOTE(review): if `tokenizer` is a Hugging Face tokenizer, encode() adds special
    # tokens by default, so blank lines would not come back empty - confirm whether
    # add_special_tokens=False was intended here.
    flattened = [token for line in lines for token in tokenizer.encode(line)]
    ten = t.tensor(flattened, dtype=t.int)
    # Use an explicit row count: reshape(-1, max_seq) is ambiguous (and raises) when
    # there are zero elements, i.e. when the input holds fewer than max_seq tokens.
    n_chunks = ten.shape[0] // max_seq
    return ten[: n_chunks * max_seq].reshape(n_chunks, max_seq)

# Tokenize all three splits into (batch, max_seq) tensors and cache them on disk so
# later runs can load TOKENS_FILENAME instead of re-tokenizing.
if MAIN:
    # Number of tokens per training sequence; also reused by later test cells.
    max_seq = 128
    print("Tokenizing training text...")
    train_data = tokenize_1d(tokenizer, train_text, max_seq)
    print("Training data shape is: ", train_data.shape)
    print("Tokenizing validation text...")
    val_data = tokenize_1d(tokenizer, val_text, max_seq)
    print("Tokenizing test text...")
    test_data = tokenize_1d(tokenizer, test_text, max_seq)
    print("Saving tokens to: ", TOKENS_FILENAME)
    # Saved as a single (train, val, test) tuple.
    t.save((train_data, val_data, test_data), TOKENS_FILENAME)

Tokenizing training text...


Token indices sequence length is longer than the specified maximum sequence length for this model (686 > 512). Running this sequence through the model will result in indexing errors


Training data shape is:  torch.Size([19159, 128])
Tokenizing validation text...
Tokenizing test text...
Saving tokens to:  ./data/wikitext_tokens_2.pt


In [58]:
def random_mask(
    input_ids: t.Tensor, mask_token_id: int, vocab_size: int, select_frac=0.15, mask_frac=0.8, random_frac=0.1
) -> tuple[t.Tensor, t.Tensor]:
    '''Given a batch of tokens, return a copy with tokens replaced according to Section 3.1 of the paper.

    A fraction ``select_frac`` of all positions in the batch is chosen uniformly at
    random (the count is exact over the whole batch, not per sequence). Of the
    chosen positions:

    - a fraction ``mask_frac`` is replaced with ``mask_token_id``,
    - a fraction ``random_frac`` is replaced with a token id drawn uniformly from
      ``[0, vocab_size)``,
    - the remainder is left unchanged but still counts as selected.

    input_ids: (batch, seq)

    Return: (model_input, was_selected) where:

    model_input: (batch, seq) - a new Tensor with the replacements made, suitable for passing to the BertLanguageModel. The original tensor is not modified; dtype and device match input_ids.

    was_selected: (batch, seq) - 1 if the token at this index will contribute to the MLM loss, 0 otherwise (int64)
    '''
    d_batch, d_seq = input_ids.shape
    n_total = d_batch * d_seq
    device = input_ids.device

    n_masking = select_frac * d_seq * d_batch
    n_mask = int(n_masking * mask_frac)
    n_random = int(n_masking * random_frac)
    # Clamp so that mask_frac + random_frac > 1 cannot yield a negative count.
    n_unchanged = max(int(n_masking * (1 - mask_frac - random_frac)), 0)

    # Scatter the three treatments over a random permutation of all positions.
    # Writing them to the first n positions (without shuffling) would always select
    # the leading tokens of the batch and never the rest.
    order = t.randperm(n_total, device=device)
    # Per-position code: 0 = not selected, 1 = [MASK], 2 = random token, 3 = keep.
    action = t.zeros(n_total, dtype=t.int, device=device)
    action[order[:n_mask]] = 1
    action[order[n_mask:n_mask + n_random]] = 2
    action[order[n_mask + n_random:n_mask + n_random + n_unchanged]] = 3
    action = action.reshape(d_batch, d_seq)

    # Work on a copy so the caller's tensor stays intact.
    model_input = input_ids.clone()
    model_input[action == 1] = mask_token_id
    # Same dtype/device as the input so t.where neither promotes nor fails.
    rand = t.randint(0, vocab_size, (d_batch, d_seq), dtype=input_ids.dtype, device=device)
    model_input = t.where(action == 2, rand, model_input)

    was_selected = (action != 0).to(t.int64)
    return model_input, was_selected
    
# Manual smoke test of random_mask on the first two training sequences, using
# exaggerated fractions (85% selected, mostly random replacements).

# Uncomment the print lines to compare the decoded text before and after masking.
# print(tokenizer.decode(train_data[:2].flatten().tolist()))
out1, mask1 = random_mask(train_data[:2], tokenizer.mask_token_id, tokenizer.vocab_size, select_frac=0.85, mask_frac=0.1, random_frac=0.8)
# print(tokenizer.decode(out1.flatten().tolist()))


# Provided test from the utils module; per its printed output it checks the empirical
# fractions of selected / masked / randomized tokens.
if MAIN:
    utils.test_random_mask(random_mask, input_size=10000, max_seq=max_seq)

Testing empirical frequencies
Checking fraction of tokens selected...
Checking fraction of tokens masked...
Checking fraction of tokens masked OR randomized...


In [59]:
# Find the word frequencies
word_frequencies = t.bincount(train_data.flatten())
# Drop the words with occurrence zero (because these contribute zero to cross entropy)
word_frequencies = word_frequencies[word_frequencies > 0]
# Get probabilities
word_probabilities = word_frequencies / word_frequencies.sum()
# Calculate the cross entropy
cross_entropy = (- word_probabilities * word_probabilities.log()).sum()
print(cross_entropy)
# ==> 7.3446

tensor(7.2800)


In [78]:
def flat(x: t.Tensor) -> t.Tensor:
    """Merge the first two axes (batch, sequence) of ``x`` into one; any trailing axes are kept."""
    pattern = "b s ... -> (b s) ..."
    merged = rearrange(x, pattern)
    return merged

def cross_entropy_selected(pred: t.Tensor, target: t.Tensor, was_selected: t.Tensor) -> t.Tensor:
    '''
    Cross entropy between predictions and targets, averaged over selected positions only.

    pred: (batch, seq, vocab_size) - predictions (logits) from the model
    target: (batch, seq, ) - the original (not masked) input ids; any integer dtype
    was_selected: (batch, seq) - 1 (or True) if the token at this index will contribute to the MLM loss, 0 otherwise

    Out: the mean loss per predicted token (a scalar tensor). If no position is
    selected the mean is taken over zero elements and the result is NaN.
    '''
    ignore_index = -100
    # F.cross_entropy requires int64 class indices; token tensors built with dtype=t.int
    # are int32 and would otherwise be rejected.
    labels = target.long()
    # Positions that were not selected get the ignore label, so they contribute neither
    # to the sum nor to the averaging denominator.
    labels = labels.masked_fill(was_selected != 1, ignore_index)
    entropy = F.cross_entropy(flat(pred), flat(labels), ignore_index=ignore_index)
    return entropy


# Run the provided test for cross_entropy_selected, then a sanity check: the MLM loss
# of random predictions on random tokens.
if MAIN:
    utils.test_cross_entropy_selected(cross_entropy_selected)

    batch_size = 8
    seq_length = 512
    # Random token ids and random (uniform in [0, 1)) scores standing in for logits.
    batch = t.randint(0, tokenizer.vocab_size, (batch_size, seq_length))
    pred = t.rand((batch_size, seq_length, tokenizer.vocab_size))
    # `masked` is not used below; only the selection mask is needed for the loss.
    (masked, was_selected) = random_mask(batch, tokenizer.mask_token_id, tokenizer.vocab_size)
    loss = cross_entropy_selected(pred, batch, was_selected).item()
    # For uninformative predictions the loss should be close to log(vocab_size).
    print(f"Random MLM loss on random tokens - does this make sense? {loss:.2f}")

Random MLM loss on random tokens - does this make sense? 10.32
