In [1]:

import jax
import torch
from torch.utils.data import DataLoader, Dataset

from optimusjx.model import TransformerLM
from optimusjx.train import CollatorForCausalLM, LMTrainer

In [26]:
class RandomIntDataset(Dataset):
    def __init__(
        self, 
        seq_len: int,
        vocab_size: int, 
        n_batches: int = 10,
        seed: int = 42,
        padding_amount: int | None = None
    ) -> None:
        self.vocab_size = vocab_size
        self.seq_len = seq_len
        self.n_batches = n_batches
        self.rng = jax.random.PRNGKey(seed)
        self.padding_amount = padding_amount
        if self.padding_amount is not None:
            self.pad_token_id = self.vocab_size + 1
            self.vocab_size += 1
        
    def __getitem__(self, index) -> dict[str, list]:
        if index > self.n_batches - 1:
            raise ValueError("Index larger than length.")
        
        random_vocab = jax.random.randint(
            self.rng, 
            (self.seq_len,), 
            minval=0,
            maxval=self.vocab_size
        )

        if self.padding_amount is not None and self.pad_token_id is not None:
            padding = jax.numpy.full((self.padding_amount), self.pad_token_id)
            random_vocab = jax.numpy.concatenate([random_vocab, padding], axis=-1)
        
        self.rng, _ = jax.random.split(self.rng, 2)
        
        return {"input_ids": random_vocab.tolist()}

    def __len__(self) -> int:
        return self.n_batches


class TokenizerStandin:
    """Bare-bones tokenizer substitute that carries only a padding token id."""

    def __init__(self, pad_token_id: int = 0) -> None:
        """Store ``pad_token_id`` (defaults to 0) on the instance."""
        self.pad_token_id = pad_token_id
    

# Tiny toy dataset: 3 random tokens per item followed by 3 padding tokens.
train_dataset = RandomIntDataset(
    seq_len=3,
    vocab_size=4,
    padding_amount=3,
)

In [27]:
# Peek at a single item. Each __getitem__ call advances the dataset's
# internal PRNG key, so re-running this cell shows different tokens.
train_dataset[0]

{'input_ids': [0, 3, 2, 5, 5, 5]}

In [28]:
random_seed = 42
batch_size = 8

# Seeded torch generator handed to the DataLoader.
rng = torch.Generator().manual_seed(random_seed)

# The collator only receives a stand-in tokenizer exposing the dataset's pad id.
collator = CollatorForCausalLM(
    TokenizerStandin(pad_token_id=train_dataset.pad_token_id)
)

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    collate_fn=collator,
    generator=rng,
)

In [29]:
# Fetch only the first batch; it is reused below as the trainer's example batch.
# next(iter(...)) fails right here if the loader is empty, instead of leaving
# `batch` undefined for the following lines.
batch = next(iter(train_loader))

print(batch)

{'inputs': Array([[0, 3, 3, 5, 5, 5],
       [1, 0, 2, 5, 5, 5],
       [0, 4, 2, 5, 5, 5],
       [2, 2, 0, 5, 5, 5],
       [1, 1, 0, 5, 5, 5],
       [3, 0, 0, 5, 5, 5],
       [2, 3, 0, 5, 5, 5],
       [1, 4, 3, 5, 5, 5]], dtype=int32), 'labels': Array([[   0,    3,    3, -100, -100, -100],
       [   1,    0,    2, -100, -100, -100],
       [   0,    4,    2, -100, -100, -100],
       [   2,    2,    0, -100, -100, -100],
       [   1,    1,    0, -100, -100, -100],
       [   3,    0,    0, -100, -100, -100],
       [   2,    3,    0, -100, -100, -100],
       [   1,    4,    3, -100, -100, -100]], dtype=int32), 'lookahead_mask': Array([[  0., -inf, -inf, -inf, -inf, -inf],
       [  0.,   0., -inf, -inf, -inf, -inf],
       [  0.,   0.,   0., -inf, -inf, -inf],
       [  0.,   0.,   0.,   0., -inf, -inf],
       [  0.,   0.,   0.,   0.,   0., -inf],
       [  0.,   0.,   0.,   0.,   0.,   0.]], dtype=float32), 'padding_mask': Array([[  0.,   0.,   0., -inf, -inf, -inf],
       

In [30]:
# Size the model's vocabulary from the dataset; the dataset's vocab_size is
# already enlarged by one when padding is enabled.
model = TransformerLM(vocab_size=train_dataset.vocab_size)

In [31]:
# Build the trainer around the model, passing the batch fetched above as the example batch.
# NOTE(review): report_to="wandb" presumably enables Weights & Biases logging,
# and max_iters presumably caps the number of training steps — confirm against LMTrainer.
trainer = LMTrainer(
    model, 
    example_batch=batch, 
    max_iters=101,
    report_to="wandb"
)

In [32]:
# NOTE(review): the first argument is presumably the number of epochs (the
# recorded progress bar reads "Epoch 5 / 5") — confirm against LMTrainer.train.
trainer.train(5, train_loader)



0,1
train_loss,


Epoch 5 / 5: 100%|██████████| 5/5 [00:27<00:00,  5.58s/it, loss=nan]


In [11]:
# Inspect the recorded loss history; the recorded output has the keys
# 'train_loss' and 'val_loss'.
trainer.history

{'train_loss': [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
 'val_loss': []}