In [None]:

from itertools import chain

import jax
import torch
from datasets import load_dataset
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from wheeljax.model import TransformerLM
from wheeljax.train import CollatorForCausalLM, LMTrainer

In [None]:
# Show every device JAX has discovered.
print(jax.devices())

# Allocate a tiny array straight away so that the XLA "ptxas < 11.8" error on
# GPU, if present, is caught in this cell.
jax.numpy.zeros(shape=(2, 2))

In [None]:
# Load the pretrained "gpt2" tokenizer; it is used below to encode the titles.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

In [None]:
# Reuse the end-of-sequence token as the padding token: the tokenizer calls
# further down use padding='max_length' and therefore need a pad token.
# NOTE(review): presumably the gpt2 tokenizer has no pad token of its own — confirm.
tokenizer.pad_token = tokenizer.eos_token

In [None]:
# Load the haiku corpus. The CSV has to be named explicitly through
# `data_files`, otherwise `load_dataset` results in an error. The whole file is
# exposed under a single split, labelled 'test' here.
d = load_dataset("huanggab/reddit_haiku", data_files={'test': 'merged_with_keywords.csv'})

# Fraction of rows kept for training; the remaining 1 - train_test_ratio is
# held out and used to compute the test perplexity.
# NOTE(review): 0.1 means only 10% of the rows are trained on and 90% are held
# out — confirm this proportion is intended and not inverted.
train_test_ratio = 0.1

# Fixed seed so the random split is identical after Restart & Run All.
split_seed = 42

d['test'] = d['test'].train_test_split(
    test_size=1 - train_test_ratio,
    seed=split_seed,
)

# Convenience handles for the two resulting splits.
train_dataset = d['test']['train']
test_dataset = d['test']['test']


In [None]:
# Display both splits as the cell output for a quick visual check.
train_dataset, test_dataset

In [None]:
# Distribution of tokenized `processed_title` lengths over both splits.
# `chain` is imported in the notebook's import cell.
lengths = [
    len(tokenizer(row['processed_title'])['input_ids'])
    for row in chain(train_dataset, test_dataset)
]

# Explicit figure/axes with a title and axis labels so the plot is readable
# on its own when the notebook is skimmed.
fig, ax = plt.subplots()
ax.hist(lengths, bins=100)
ax.set_title("Tokenized length of processed_title (train + test)")
ax.set_xlabel("tokens per title")
ax.set_ylabel("number of titles")
plt.show()

In [None]:
# Raw CSV columns that are dropped once the text has been tokenized. Each name
# is listed exactly once (the original list repeated 'processed_title').
raw_columns = ['Unnamed: 0', 'processed_title', 'keywords', 'ups', 'id']


def tokenize_titles(batch):
    """Tokenize the 'processed_title' texts of one batch of rows.

    Every sequence is truncated and padded with padding='max_length'. No
    explicit `max_length` is passed, so the target length is whatever the
    tokenizer defaults to.
    NOTE(review): presumably that default is the model's maximum input length,
    which would be far longer than these titles — confirm, and consider passing
    an explicit `max_length` chosen from the length histogram.

    Args:
        batch: mapping of column name to a list of values, as supplied by
            `Dataset.map(..., batched=True)`.

    Returns:
        The tokenizer output for the batch's titles.
    """
    return tokenizer(batch['processed_title'], padding='max_length', truncation=True)


# Apply the same tokenization to both splits through one shared helper.
train_dataset = train_dataset.map(
    tokenize_titles,
    batched=True,
    remove_columns=raw_columns,
)

test_dataset = test_dataset.map(
    tokenize_titles,
    batched=True,
    remove_columns=raw_columns,
)

In [None]:
# Inspect the first training example after the tokenization step.
train_dataset[0]

In [None]:
random_seed = 42
batch_size = 8

# Seeded generator that drives the shuffle order, so runs are repeatable.
rng = torch.Generator()
rng.manual_seed(random_seed)

# Batch collation is delegated to the project's causal-LM collator.
# NOTE(review): its exact output keys are defined in wheeljax.train — confirm there.
collator = CollatorForCausalLM(tokenizer)

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    # Without shuffle=True the loader iterates sequentially and the seeded
    # generator never influences the sampling order.
    shuffle=True,
    generator=rng,
    collate_fn=collator,
)

In [None]:
# Pull a single batch from the loader; it is reused later as the example batch
# handed to the trainer.
batch = next(iter(train_loader), None)

# Fail loudly on an empty loader instead of silently carrying on with an
# undefined (or stale, from an earlier run) `batch`.
if batch is None:
    raise ValueError("train_loader yielded no batches; check the dataset split and batch_size")

print(list(batch.keys()))

In [None]:
# Size the model's vocabulary from the tokenizer. The tokenized dataset object
# does not expose a `vocab_size` attribute, so reading it from `train_dataset`
# raises AttributeError. len(tokenizer) counts the base vocabulary plus any
# added tokens, i.e. every id the tokenizer can emit.
model = TransformerLM(vocab_size=len(tokenizer))

In [None]:
# Keyword options for the trainer: the example batch pulled from the loader
# and the iteration cap.
# NOTE(review): the exact semantics of `max_iters` are defined by LMTrainer — confirm there.
trainer_options = dict(example_batch=batch, max_iters=101)

trainer = LMTrainer(model, **trainer_options)

In [None]:
# Run training on the training loader.
# NOTE(review): the first argument (5) is presumably the number of epochs —
# confirm against LMTrainer.train.
trainer.train(5, train_loader)