In [None]:
import torch
from leap import LeapForCausalLM, LeapConfig
from transformers import (TrainingArguments, Trainer,
                          EarlyStoppingCallback, default_data_collator)

from datasets import load_dataset, Dataset, DatasetDict, concatenate_datasets
from torch.utils.data import Subset

# word level tokenizer as per wikitext modeling
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import WordLevelTrainer
from transformers import PreTrainedTokenizerFast

# gpt model imports
from transformers import GPT2Config, GPT2LMHeadModel

from itertools import chain
import logging
logging.disable(logging.INFO)

In [None]:
# globals

# The `train[:10%]` slice of WikiText-103 plus the complete validation and
# test splits, collected into a DatasetDict keyed by split name.
raw_datasets = DatasetDict(dict(zip(
    ("train", "validation", "test"),
    load_dataset(
        "wikitext",
        "wikitext-103-v1",
        split=["train[:10%]", "validation", "test"],
    ),
)))
subset_datasets = raw_datasets

# Data-scaling constants: the training subset for a model with N non-embedding
# parameters is sized as N**0.74 / param_data_ratio tokens (see subset_data).
total_train_tokens = 10_416_407  # see appendix
max_num_params = 33_800_704  # 448 d_model gpt2 with 14 layers
param_data_ratio = max_num_params ** 0.74 / total_train_tokens

# Default chunk length; run_training overwrites this global with its seq_len.
block_size = 1024

# hyperparameters
training_args = TrainingArguments(
    # bookkeeping: log / evaluate / checkpoint once per epoch, no external reporting
    output_dir="./results",
    report_to="none",
    logging_strategy="epoch",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    # model selection on validation loss
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    # optimisation schedule
    num_train_epochs=20,
    learning_rate=1e-3,
    lr_scheduler_type="cosine",
    warmup_ratio=0.05,
    max_grad_norm=1,
    # batching / precision
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    fp16=True,
)


In [None]:
# make a word level tokenizer
tokenizer = Tokenizer(WordLevel(unk_token="<unk>"))
tokenizer.pre_tokenizer = Whitespace()
# NOTE(review): "<pad>" is not listed in the trainer's special_tokens below, so
# it is not part of the trained vocabulary; pad_id = 0 is presumably the id the
# trainer gives to "<unk>" (the only special token) -- confirm. group_texts also
# pads with 0, but masks those positions out of attention and the loss.
tokenizer.enable_padding(pad_id = 0, pad_token = "<pad>")
# no post processing

# WE USE A SET VOCAB SIZE OF 8,192 FOR SPEED (the oov should only be around 5%)
token_trainer = WordLevelTrainer(vocab_size = 8191, # -1 for pad token
                                 special_tokens = ["<unk>"])

def batch_iterator(batch_size=10000):
    """Yield the training text as lists of at most `batch_size` rows.

    Each batch is sliced straight from the dataset, so the whole "text"
    column is never materialised as a single Python list.
    """
    train_split = raw_datasets["train"]
    for start in range(0, len(train_split), batch_size):
        yield train_split[start : start + batch_size]["text"]

# `length` is the number of rows the iterator will produce in total;
# len(dataset) gives that without loading the text column a second time.
tokenizer.train_from_iterator(batch_iterator(),
                              trainer = token_trainer,
                              length = len(raw_datasets["train"]))

# Wrap the trained tokenizer so it can be called like any transformers tokenizer.
tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer, pad_token = "<pad>")

In [None]:
# tokenize every split with the word-level tokenizer
def tokenize_function(examples):
    """Run the module-level `tokenizer` over a batch's "text" column."""
    return tokenizer(examples["text"])

# Batched map over all splits; the raw "text" column is dropped so only the
# tokenizer's output columns remain.
tokenized_datasets = raw_datasets.map(
    tokenize_function,
    desc="tokenize dataset",
    remove_columns="text",
    batched=True,
)

In [None]:
def subset_data(dataset, num_parameters):
    """Return a copy of `dataset` whose "train" split is cut to a token budget.

    The budget is ``num_parameters**.74 / param_data_ratio`` tokens, using the
    module-level ``param_data_ratio``. Rows are taken from the start of the
    train split until the running count of ``input_ids`` reaches the budget;
    the row that crosses it is included. If the split holds fewer tokens than
    the budget, every row is kept and nothing is printed.

    Args:
        dataset: mapping of split name -> tokenized split; only "train" is
            replaced, and the input mapping itself is not modified.
        num_parameters: parameter count used to size the subset.

    Returns:
        A new DatasetDict with the truncated "train" split.
    """
    dataset = DatasetDict(dataset.copy())
    subset_num_tokens = num_parameters**.74 / param_data_ratio

    # add rows until we meet the subset_num_tokens
    training_set = dataset["train"]
    total_tokens = 0
    num_rows = 0  # counted explicitly so an empty split yields an empty subset
    for row in training_set:
        total_tokens += len(row["input_ids"])
        num_rows += 1

        if total_tokens >= subset_num_tokens:
            print(f'NUMBER OF TOKENS: {total_tokens:,}')
            break

    # select() keeps the first num_rows rows without rebuilding the data from
    # Python lists the way Dataset.from_dict(training_set[:n]) would.
    dataset["train"] = training_set.select(range(num_rows))
    return dataset

def group_texts(examples):
    """Concatenate a batch of tokenized rows and re-cut it into fixed blocks.

    Every column in `examples` is flattened and split into consecutive chunks
    of the module-level ``block_size``. A "labels" column is added as a copy of
    "input_ids". The final, possibly short, block is right-padded to
    ``block_size``: "input_ids" and "attention_mask" with 0, "labels" with
    -100. Other columns are chunked but their last block is not padded.

    Args:
        examples: mapping of column name -> list of per-row token lists; must
            contain "input_ids" and "attention_mask".

    Returns:
        Dict of column name -> list of blocks. When the batch holds no tokens
        every column (including "labels") is an empty list.
    """
    # Concatenate all texts
    concatenated = {key: list(chain.from_iterable(examples[key])) for key in examples.keys()}
    total_length = len(concatenated[list(examples.keys())[0]])

    # Split by chunks of max_len
    result = {
        key: [tokens[start : start + block_size] for start in range(0, total_length, block_size)]
        for key, tokens in concatenated.items()
    }

    # for language modeling, inputs are labels (they will be shifted inside the model);
    # each block is copied so labels never alias the input_ids lists
    result["labels"] = [list(block) for block in result["input_ids"]]

    # a batch with no tokens produces no blocks, so there is nothing to pad
    if total_length == 0:
        return result

    pad_len = block_size - len(result["input_ids"][-1])

    # pad last block with 0
    result["input_ids"][-1] = result["input_ids"][-1] + [0] * pad_len

    # set attention mask to mask out these tokens
    result["attention_mask"][-1] = result["attention_mask"][-1] + [0] * pad_len

    # set pad labels to -100 so they will be ignored by CrossEntropyLoss
    result["labels"][-1] = result["labels"][-1] + [-100] * pad_len
    return result

In [None]:
def _build_causal_lm(hidden_size, n_layer, n_head, vocab_size, n_positions,
                     gpt = False, rnn = False, for_training = True):
    """Build a freshly initialised GPT-2 or LEAP causal language model.

    `gpt is True` selects GPT2LMHeadModel; otherwise a LeapForCausalLM is
    built, with `rnn = True` forwarded to LeapConfig only when `rnn is True`.
    With `for_training` set, the initializer range (1 / sqrt(hidden_size)) is
    added, plus the LEAP attention options for LEAP models; without it only
    the size-defining arguments are passed.
    """
    config_kwargs = dict(n_layer = n_layer, n_head = n_head,
                         vocab_size = vocab_size, n_positions = n_positions)
    if for_training:
        config_kwargs["initializer_range"] = 1 / hidden_size**.5

    if gpt is True:
        return GPT2LMHeadModel(GPT2Config(n_embd = hidden_size, **config_kwargs))

    if for_training:
        config_kwargs.update(use_local_att = True, window_sizes = None, rescale = 10)
    if rnn is True:
        config_kwargs["rnn"] = True
    return LeapForCausalLM(LeapConfig(hidden_size = hidden_size, **config_kwargs))


def run_training(hidden_size, n_layer, n_head, seq_len, gpt = False, rnn = False):
    """Train one model on a parameter-scaled data subset and print its test loss.

    Steps:
      1. Build a throw-away model with a single head, vocab_size = 0 and
         n_positions = 0 and count its trainable parameters (reported as the
         non-embedding parameter count).
      2. Cut the tokenized training split to the token budget for that count
         (see subset_data) and regroup every split into blocks of `seq_len`.
      3. Build the real model, train it with the module-level `training_args`
         and early stopping, then evaluate on the test split.
      4. Drop the large objects and empty the CUDA cache.

    Side effect: overwrites the module-level `block_size` with `seq_len`,
    because group_texts reads the block length from that global.

    Args:
        hidden_size: model width (n_embd for GPT-2, hidden_size for LEAP).
        n_layer: number of layers.
        n_head: number of heads for the trained model.
        seq_len: block length for the data and n_positions for the model.
        gpt: if True, train GPT-2 (takes precedence over `rnn`).
        rnn: if True (and `gpt` is not), pass rnn = True to LeapConfig.
    """
    # get number of parameters first
    sizing_model = _build_causal_lm(hidden_size, n_layer, 1, 0, 0,
                                    gpt = gpt, rnn = rnn, for_training = False)
    non_embedding_parameters = sum(p.numel() for p in sizing_model.parameters() if p.requires_grad)
    del sizing_model
    print(f'NON EMBEDDING PARAMETERS: {non_embedding_parameters:,}')

    # subset dataset
    subset_datasets = subset_data(tokenized_datasets, non_embedding_parameters)

    # set globally block size for group texts function
    global block_size
    block_size = seq_len
    lm_dataset = subset_datasets.map(
        group_texts,
        batched=True,
        batch_size=10000,
        desc=f"Grouping texts in chunks of {block_size}"
    )

    # group_texts does not pad this column, so drop it before collation;
    # guarded so a tokenizer that never emitted it does not raise here
    if "token_type_ids" in lm_dataset["train"].column_names:
        lm_dataset = lm_dataset.remove_columns(["token_type_ids"])

    model = _build_causal_lm(hidden_size, n_layer, n_head,
                             len(tokenizer) + 1, seq_len,
                             gpt = gpt, rnn = rnn, for_training = True)

    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=default_data_collator,
        train_dataset=lm_dataset["train"],
        eval_dataset=lm_dataset["validation"],
        # passed as a class; Trainer instantiates it with its default settings
        callbacks = [EarlyStoppingCallback]
    )

    trainer.train()

    print("===============TEST SET CROSS ENTROPY LOSS EVALUATION===============")
    print(trainer.evaluate(lm_dataset["test"]))

    # save gpu memory: release the references first so the cache can be emptied
    del trainer
    del model
    del lm_dataset
    del subset_datasets
    torch.cuda.empty_cache()

# LEAP TRAINING
Each run is done separately in its own cell for easy viewing of logs and in case something goes wrong (OOM errors or training issues).

In [None]:
# LEAP run 1 of 10: width 64, 2 layers, 2 heads, 256-token blocks.
run_training(hidden_size = 64, n_layer = 2, n_head = 2, seq_len = 256, gpt = False)

In [None]:
# LEAP run 2 of 10: width 96, 3 layers, 3 heads, 384-token blocks.
run_training(hidden_size = 96, n_layer = 3, n_head = 3, seq_len = 384, gpt = False)

In [None]:
# LEAP run 3 of 10: width 128, 4 layers, 4 heads, 512-token blocks.
run_training(hidden_size = 128, n_layer = 4, n_head = 4, seq_len = 512, gpt = False)

In [None]:
# LEAP run 4 of 10: width 160, 5 layers, 5 heads, 640-token blocks.
run_training(hidden_size = 160, n_layer = 5, n_head = 5, seq_len = 640, gpt = False)

In [None]:
# LEAP run 5 of 10: width 192, 6 layers, 6 heads, 768-token blocks.
run_training(hidden_size = 192, n_layer = 6, n_head = 6, seq_len = 768, gpt = False)

In [None]:
# LEAP run 6 of 10: width 256, 8 layers, 8 heads, 1024-token blocks.
run_training(hidden_size = 256, n_layer = 8, n_head = 8, seq_len = 1024, gpt = False)

In [None]:
# LEAP run 7 of 10: width 320, 10 layers, 10 heads, 1280-token blocks.
run_training(hidden_size = 320, n_layer = 10, n_head = 10, seq_len = 1280, gpt = False)

In [None]:
# LEAP run 8 of 10: width 384, 12 layers, 12 heads, 1536-token blocks.
run_training(hidden_size = 384, n_layer = 12, n_head = 12, seq_len = 1536, gpt = False)

In [None]:
# LEAP run 9 of 10: width 448, 14 layers, 14 heads, 1792-token blocks.
run_training(hidden_size = 448, n_layer = 14, n_head = 14, seq_len = 1792, gpt = False)

In [None]:
# LEAP run 10 of 10: width 512, 16 layers, 16 heads, 2048-token blocks.
run_training(hidden_size = 512, n_layer = 16, n_head = 16, seq_len = 2048, gpt = False)

# LSTM TRAINING

In [None]:
# Recurrent run 1 of 8 (LeapConfig with rnn = True): width 64, 2 layers, 256-token blocks.
run_training(hidden_size = 64, n_layer = 2, n_head = 2, seq_len = 256, rnn = True)

In [None]:
# Recurrent run 2 of 8 (LeapConfig with rnn = True): width 96, 3 layers, 384-token blocks.
run_training(hidden_size = 96, n_layer = 3, n_head = 3, seq_len = 384, rnn = True)

In [None]:
# Recurrent run 3 of 8 (LeapConfig with rnn = True): width 128, 4 layers, 512-token blocks.
run_training(hidden_size = 128, n_layer = 4, n_head = 4, seq_len = 512, rnn = True)

In [None]:
# Recurrent run 4 of 8 (LeapConfig with rnn = True): width 160, 5 layers, 640-token blocks.
run_training(hidden_size = 160, n_layer = 5, n_head = 5, seq_len = 640, rnn = True)

In [None]:
# Recurrent run 5 of 8 (LeapConfig with rnn = True): width 192, 6 layers, 768-token blocks.
run_training(hidden_size = 192, n_layer = 6, n_head = 6, seq_len = 768, rnn = True)

In [None]:
# Recurrent run 6 of 8 (LeapConfig with rnn = True): width 256, 8 layers, 1024-token blocks.
run_training(hidden_size = 256, n_layer = 8, n_head = 8, seq_len = 1024, rnn = True)

In [None]:
# Recurrent run 7 of 8 (LeapConfig with rnn = True): width 320, 10 layers, 1280-token blocks.
run_training(hidden_size = 320, n_layer = 10, n_head = 10, seq_len = 1280, rnn = True)

In [None]:
# Recurrent run 8 of 8 (LeapConfig with rnn = True): width 384, 12 layers, 1536-token blocks.
run_training(hidden_size = 384, n_layer = 12, n_head = 12, seq_len = 1536, rnn = True)

# GPT2 TRAINING
Each run is done separately in its own cell for easy viewing of logs and in case something goes wrong (OOM errors or training issues).

In [None]:
# GPT-2 run 1 of 8: width 64, 2 layers, 2 heads, 256-token blocks.
run_training(hidden_size = 64, n_layer = 2, n_head = 2, seq_len = 256, gpt = True)

In [None]:
# GPT-2 run 2 of 8: width 96, 3 layers, 3 heads, 384-token blocks.
run_training(hidden_size = 96, n_layer = 3, n_head = 3, seq_len = 384, gpt = True)

In [None]:
# GPT-2 run 3 of 8: width 128, 4 layers, 4 heads, 512-token blocks.
run_training(hidden_size = 128, n_layer = 4, n_head = 4, seq_len = 512, gpt = True)

In [None]:
# GPT-2 run 4 of 8: width 160, 5 layers, 5 heads, 640-token blocks.
run_training(hidden_size = 160, n_layer = 5, n_head = 5, seq_len = 640, gpt = True)

In [None]:
# GPT-2 run 5 of 8: width 192, 6 layers, 6 heads, 768-token blocks.
run_training(hidden_size = 192, n_layer = 6, n_head = 6, seq_len = 768, gpt = True)

In [None]:
# GPT-2 run 6 of 8: width 256, 8 layers, 8 heads, 1024-token blocks.
run_training(hidden_size = 256, n_layer = 8, n_head = 8, seq_len = 1024, gpt = True)

In [None]:
# GPT-2 run 7 of 8: width 320, 10 layers, 10 heads, 1280-token blocks.
run_training(hidden_size = 320, n_layer = 10, n_head = 10, seq_len = 1280, gpt = True)

In [None]:
# GPT-2 run 8 of 8: width 384, 12 layers, 12 heads, 1536-token blocks.
# NOTE: THIS WAS RUN ON A GOOGLE COLAB P100 WITH THE CODE SHOWN
run_training(hidden_size = 384, n_layer = 12, n_head = 12, seq_len = 1536, gpt = True)

In [None]:
# NOTE: NOT RUN DUE TO MEMORY AND TIME CONSTRAINTS
# run_training(hidden_size = 448, n_layer = 14, n_head = 14, seq_len = 1792, gpt = True)

# Appendix

In [None]:
import re

# to count tokens, comes from https://huggingface.co/docs/tokenizers/components
# (raw string: "\w" / "\s" are not valid escapes in a normal string literal)
whitespace_regex = re.compile(r"\w+|[^\w\s]+")

# get number of tokens
# NOTE(review): `split` returns the pieces *between* matches, so each row adds
# (number of matches + 1) -- i.e. one extra per row, including empty rows.
# Left unchanged because total_train_tokens in the globals cell presumably
# records this cell's output; `findall` would count the matches themselves.
total_tokens = sum(
    len(whitespace_regex.split(row)) for row in raw_datasets["train"]["text"]
)
total_tokens