In [None]:
from itertools import chain

import torch
from datasets import load_dataset
from transformers import AutoTokenizer
from transformers import TrainingArguments, Trainer, default_data_collator

from leap import LeapForCausalLM, LeapConfig

In [None]:
# t5 tokenzier, warning is nothing to worry about since we will group the texts
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("t5-small")

In [None]:
# Load the WikiText-2 splits and work out which column holds the raw text:
# prefer a column literally named "text", otherwise use the first column.
raw_datasets = load_dataset("wikitext", "wikitext-2-v1")
column_names = raw_datasets["train"].column_names
if "text" in column_names:
    text_column_name = "text"
else:
    text_column_name = column_names[0]

# Number of tokens per grouped example (also used as the model's n_positions).
block_size = 2048

def tokenize_function(examples):
    """Tokenize the text column of a batch.

    Args:
        examples: batch mapping column names to lists of values; the column
            named by the module-level ``text_column_name`` is tokenized.

    Returns:
        Whatever the module-level ``tokenizer`` returns for that column.
    """
    return tokenizer(examples[text_column_name])

def group_texts(examples):
    """Concatenate a batch of tokenized texts and cut it into equal blocks.

    Every column of the batch is flattened into one long token list and then
    split into consecutive chunks of exactly ``block_size`` tokens (a
    module-level constant). The trailing remainder that does not fill a whole
    block is always dropped, so every emitted example has the same length and
    can be stacked by a collator that does not pad. A batch holding fewer than
    ``block_size`` tokens therefore yields no examples at all.

    Args:
        examples: batch mapping column names (must include ``"input_ids"``)
            to lists of per-text token lists.

    Returns:
        Dict with the same columns, each a list of ``block_size``-long chunks,
        plus a ``"labels"`` column that is a copy of ``"input_ids"``.
    """
    # Flatten each column: list of per-text lists -> one long list.
    concatenated_examples = {
        column: list(chain.from_iterable(rows)) for column, rows in examples.items()
    }

    # Round down to a whole number of blocks. This is done unconditionally:
    # the previous `if total_length >= block_size` guard let a short batch
    # through as a single under-sized chunk.
    total_length = len(concatenated_examples["input_ids"])
    total_length = (total_length // block_size) * block_size

    # Split every column into chunks of block_size.
    result = {
        column: [
            tokens[start : start + block_size]
            for start in range(0, total_length, block_size)
        ]
        for column, tokens in concatenated_examples.items()
    }

    # Causal LM: the labels are the inputs themselves.
    result["labels"] = result["input_ids"].copy()
    return result

# Step 1: tokenize every split, discarding the original (raw text) columns.
tokenized_datasets = raw_datasets.map(
    tokenize_function,
    remove_columns=column_names,
    batched=True,
    num_proc=1,
)

# Step 2: regroup the token streams into chunks of block_size.
lm_dataset = tokenized_datasets.map(
    group_texts,
    desc=f"Grouping texts in chunks of {block_size}",
    batched=True,
    num_proc=1,
)

# Have the dataset return PyTorch tensors when indexed.
lm_dataset.set_format("pt")

In [None]:
# hyperparameters
# Logging, evaluation and checkpointing all happen once per epoch so that the
# best checkpoint (lowest eval_loss) can be restored when training finishes.
training_args = TrainingArguments(
    output_dir="./results",
    logging_strategy="epoch",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    report_to="none",
    learning_rate=5e-4,
    num_train_epochs=10,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    max_grad_norm=1,
    # fp16 mixed precision needs a CUDA device; enable it only when one is
    # present so this cell also runs on CPU-only machines.
    fp16=torch.cuda.is_available(),
)

# 1. LEAP Transformer (with windowing)

In [None]:
# Model configuration: vocabulary size comes from the tokenizer and the
# number of positions equals the block size used when grouping the texts.
config = LeapConfig(
    vocab_size=len(tokenizer),
    n_positions=block_size,
    hidden_size=128,
    n_heads=4,
    n_layer=6,
    use_local_att=True,
    window_sizes=None,  # will be set automatically
    hidden_dropout_prob=0.1,
    rescale=10,
)

# Freshly initialised model built from the configuration above.
window_model = LeapForCausalLM(config)

In [None]:
# Count the trainable parameters and display the total in millions.
pytorch_total_params = sum(
    param.numel()
    for param in window_model.parameters()
    if param.requires_grad
)
f'{(pytorch_total_params / 1e6):2.1f}M'

In [None]:
# Train on the grouped "train" split and evaluate on "validation".
# Every example already has the same length, so the default collator
# is used as-is.
window_trainer = Trainer(
    args=training_args,
    model=window_model,
    train_dataset=lm_dataset["train"],
    eval_dataset=lm_dataset["validation"],
    data_collator=default_data_collator,
)

window_trainer.train()

In [None]:
# free gpu memory
# Order matters: drop the Trainer reference, move the model to the CPU, and
# only then ask PyTorch to release its cached CUDA memory.
# NOTE(review): `del` raises NameError if this cell is run a second time.
del window_trainer
window_model.cpu()
torch.cuda.empty_cache()

## Test code to try some text generation

In [None]:
# get a test datapoint
test_number = 0

# Pull each field of the chosen test example, reshape it to a single-row
# batch and make sure it lives on the CPU.
input_ids, attention_mask, labels = (
    lm_dataset["test"][test_number][field].reshape(1, -1).cpu()
    for field in ("input_ids", "attention_mask", "labels")
)

In [None]:
# put model on cpu and switch it to evaluation mode
test_model = window_model.cpu().eval()

In [None]:
# show prompt: decode the first 50 tokens of the test example
tokenizer.batch_decode(input_ids[:, :50])

In [None]:
# generate the text
# The prompt is the first 50 tokens of the test example. The attention mask
# is sliced to the same 50 positions so its shape matches the prompt
# (previously the full-length mask was passed alongside the truncated prompt).
generated_tokens = test_model.generate(
    input_ids[:, :50],
    attention_mask=attention_mask[:, :50],
    do_sample=True,
    max_length=256,
    temperature=.7,
)
tokenizer.batch_decode(generated_tokens)