In [None]:
import jax
import numpy as np
import torch
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
from transformers import AutoTokenizer
from datasets import load_dataset

from wheeljax.model import TransformerLM
from wheeljax.train import CollatorForCausalLM, LMTrainer

In [None]:
# Show which devices JAX will run on.
print(jax.devices())

# Allocating a tiny array forces JAX to compile/dispatch on the default backend
# right away, so toolchain problems (e.g. the XLA "ptxas < 11.8" error on GPU)
# surface here instead of in the middle of training.
jax.numpy.zeros(shape=(2, 2))

In [None]:
# Load the pretrained "gpt2" tokenizer; its vocabulary size is reused when building the model.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="gpt2")

In [None]:
# Tokenization below uses padding='max_length', which needs a pad token. GPT-2
# presumably ships without one (hence this assignment), so EOS is reused for padding.
tokenizer.pad_token = tokenizer.eos_token

In [None]:
# Load the Reddit haiku dataset. Only a single CSV is loaded, exposed under the
# "test" split name (passing `data_files` explicitly avoids a loading error).
d = load_dataset("huanggab/reddit_haiku", data_files={'test': 'merged_with_keywords.csv'})

# Carve a held-out split out of that single file; the held-out part is used to
# compute the test perplexity. A fixed seed makes the split reproducible, so test
# perplexities from different runs are measured on the same examples.
train_test_ratio = 0.8
split_seed = 42
splits = d['test'].train_test_split(test_size=1 - train_test_ratio, seed=split_seed)

train_dataset = splits['train']
test_dataset = splits['test']


In [None]:
# Display both splits (column names and row counts) as a single tuple.
(train_dataset, test_dataset)

In [None]:
# Distribution of token lengths of the training titles; `lengths` is used in the
# next cell to choose `max_length`.
# One batched tokenizer call instead of a per-row Python loop. list(...) is needed
# because newer `datasets` versions return a lazy Column object rather than a list.
train_titles = list(train_dataset['processed_title'])
lengths = [len(ids) for ids in tokenizer(train_titles)['input_ids']]

fig, ax = plt.subplots()
ax.hist(lengths, bins=100, log=True, label=f"train (max={max(lengths)})")
ax.set(
    xlabel="title length (tokens)",
    ylabel="count (log scale)",
    title="Token lengths of training titles",
)
ax.legend()
plt.show()

In [None]:
# Pick a max length that covers the vast majority of the titles; the rare longer
# titles are truncated during tokenization.
quantile = 0.999
# Use `quantile` (not a second hard-coded literal) and round *up*: np.quantile
# interpolates between observed lengths, so rounding down could leave fewer than
# `quantile` of the lengths <= max_length, contradicting the message below.
max_length = int(np.ceil(np.quantile(lengths, quantile)))
print(f"{quantile * 100: .2f}% of lengths are <= {max_length}")

In [None]:
# Raw columns dropped after tokenization so only the tokenizer outputs remain.
columns_to_remove = ['Unnamed: 0', 'processed_title', 'keywords', "ups", "id"]


def tokenize_titles(examples):
    """Tokenize a batch of titles, padded/truncated to exactly `max_length` tokens.

    Both splits go through this one function so that train and test sequences have
    the same length. Without an explicit `max_length`, padding='max_length' pads to
    the tokenizer's `model_max_length` instead of the length chosen above.
    """
    return tokenizer(
        examples['processed_title'],
        padding='max_length',
        truncation=True,
        max_length=max_length,
    )


train_dataset = train_dataset.map(
    tokenize_titles,
    batched=True,
    remove_columns=columns_to_remove,
)

test_dataset = test_dataset.map(
    tokenize_titles,
    batched=True,
    remove_columns=columns_to_remove,
)

In [None]:
# Sanity-check the tokenized data: the remaining fields, and the padded sequence
# length of the first train and test examples (these should both equal max_length).
print(list(train_dataset[0].keys()))
print(len(train_dataset[0]['input_ids']), len(test_dataset[0]['input_ids']))

In [None]:
random_seed = 42
batch_size = 32

# Seeded generator so the training-set shuffling order is reproducible.
rng = torch.Generator()
rng.manual_seed(random_seed)

# Project collator from wheeljax; presumably turns tokenized rows into model
# inputs/targets for causal LM training (TODO confirm in wheeljax.train).
collator = CollatorForCausalLM(tokenizer)

# shuffle=True is required for the seeded generator to affect the sample order:
# DataLoader only builds a RandomSampler (fed with `generator`) when shuffling;
# otherwise it iterates the dataset sequentially.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    generator=rng,
    collate_fn=collator,
)

# Evaluation order does not matter, so keep it sequential and deterministic.
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    collate_fn=collator,
    shuffle=False,
)

In [None]:
# Take one training batch to inspect its fields; it is also passed to the trainer
# below as `example_batch`.
batch = next(iter(train_loader))
print(list(batch.keys()))

In [None]:
# Vocabulary size of the tokenizer; the model's embedding size is set from this.
tokenizer.vocab_size

In [None]:
# Small transformer language model whose vocabulary matches the tokenizer's ids.
# NOTE(review): dim_feedforward (32) is smaller than model_dim (64), which is
# unusual for transformer feed-forward blocks — confirm this is intended.
model = TransformerLM(
    vocab_size=tokenizer.vocab_size,
    model_dim=64,
    num_heads=4,
    num_encoder_layers=4,
    dim_feedforward=32,
)

In [None]:
n_epochs = 2
# Total number of training iterations over the whole run (one per batch per epoch);
# presumably used by LMTrainer for its schedule — TODO confirm.
total_iters = n_epochs * len(train_loader)

trainer = LMTrainer(
    model,
    example_batch=batch,  # a real batch, presumably used to initialize parameters
    max_iters=total_iters,
    report_to="wandb",
)

In [None]:
# Train for `n_epochs` epochs on `train_loader`, passing the held-out split as validation data.
trainer.train(n_epochs, train_loader, val_loader=test_loader)

In [None]:
# Prompt for generation; the leading "<|endoftext|>" is presumably GPT-2's EOS token,
# used here to mark the start of a sequence.
input_text = "<|endoftext|> an orange and an apple walk into a"
input_tokens = tokenizer.encode(input_text, return_tensors="jax")

# Sample a continuation with the trained parameters. The sampling key is derived from
# the notebook-wide `random_seed` rather than a second hard-coded 42, so the seed
# only needs changing in one place.
tokens = model.generate(
    trainer.state.params,
    input_tokens=input_tokens,
    rng_key=jax.random.PRNGKey(random_seed),
)

In [None]:
# Decode each generated sequence back to text, dropping special tokens (e.g. EOS/padding).
[tokenizer.decode(token_ids, skip_special_tokens=True) for token_ids in tokens.tolist()]