In [1]:

import jax
from transformers import AutoTokenizer
from datasets import load_dataset
from matplotlib import pyplot as plt
import torch
from torch.utils.data import DataLoader, Dataset

from optimusjx.model import TransformerLM
from optimusjx.train import CollatorForCausalLM, LMTrainer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
tokenizer = AutoTokenizer.from_pretrained("gpt2")

In [3]:
tokenizer.pad_token = tokenizer.eos_token

In [4]:
# load the dataset
d = load_dataset("huanggab/reddit_haiku", data_files={'test':'merged_with_keywords.csv'})  # use data_files or it will result in error

# we will use test to compute the test perplexity
train_test_ratio = 0.1
d['test'] = d['test'].train_test_split(test_size=1-train_test_ratio)

# Now you have the train and test datasets
train_dataset = d['test']['train']
test_dataset = d['test']['test']


In [5]:
train_dataset, test_dataset

(Dataset({
     features: ['Unnamed: 0', 'id', 'processed_title', 'ups', 'keywords'],
     num_rows: 1528
 }),
 Dataset({
     features: ['Unnamed: 0', 'id', 'processed_title', 'ups', 'keywords'],
     num_rows: 13753
 }))

In [6]:
from itertools import chain
# plot the distribution of the lengths of the sequences
lengths = []

for row in chain(train_dataset, test_dataset):
    lengths.append(len(tokenizer(row['processed_title'])['input_ids']))

plt.hist(lengths, bins=100)
plt.show()

In [7]:
train_dataset = train_dataset.map(
    lambda x: tokenizer(x['processed_title'], padding='max_length', truncation=True), 
    batched=True,
    remove_columns=['Unnamed: 0', 'processed_title', 'keywords', "ups", "id", "processed_title"],
)

test_dataset = test_dataset.map(
    lambda x: tokenizer(x['processed_title'], padding='max_length', truncation=True), 
    batched=True,
    remove_columns=['Unnamed: 0', 'processed_title', 'keywords', "ups", "id", "processed_title"]
)


class RandomIntDataset(Dataset):
    def __init__(
        self, 
        seq_len: int,
        vocab_size: int, 
        n_batches: int = 10,
        seed: int = 42
    ) -> None:
        self.vocab_size = vocab_size
        self.seq_len = seq_len
        self.n_batches = n_batches
        self.rng = jax.random.PRNGKey(seed)
    
    def __getitem__(self, index) -> dict[str, list]:
        if index > self.n_batches - 1:
            raise ValueError("Index larger than length.")
        
        random_vocab = jax.random.randint(
            self.rng, 
            (self.seq_len,), 
            minval=0,
            maxval=self.vocab_size
        ).tolist()

        self.rng, _ = jax.random.split(self.rng, 2)
        
        return {"input_ids": random_vocab}

    def __len__(self) -> int:
        return self.n_batches

train_dataset = RandomIntDataset(3, 4)

In [8]:
train_dataset[0]

{'input_ids': [3, 2, 1]}

In [9]:
random_seed = 42
batch_size = 8

rng = torch.Generator()
rng.manual_seed(random_seed)

collator = CollatorForCausalLM(tokenizer)

train_loader = DataLoader(
    train_dataset, 
    batch_size=batch_size,
    generator=rng,
    collate_fn=collator
)

In [10]:
for batch in train_loader:
    break

print(batch)

{'inputs': Array([[2, 2, 1],
       [1, 3, 2],
       [1, 0, 1],
       [1, 2, 3],
       [0, 3, 3],
       [1, 0, 3],
       [3, 0, 1],
       [2, 1, 1]], dtype=int32), 'labels': Array([[2, 2, 1],
       [1, 3, 2],
       [1, 0, 1],
       [1, 2, 3],
       [0, 3, 3],
       [1, 0, 3],
       [3, 0, 1],
       [2, 1, 1]], dtype=int32), 'lookahead_mask': Array([[  0., -inf, -inf],
       [  0.,   0., -inf],
       [  0.,   0.,   0.]], dtype=float32), 'padding_mask': Array([[0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.]], dtype=float32)}


In [11]:
model = TransformerLM(vocab_size=train_dataset.vocab_size)

In [12]:
trainer = LMTrainer(
    model, 
    example_batch=batch, 
    max_iters=101
)

In [13]:
trainer.train(5, train_loader)

Epoch 5 / 5: 100%|██████████| 5/5 [00:29<00:00,  5.95s/it, loss=1.36]


In [14]:
jax.devices()

[CpuDevice(id=0)]