In [1]:

import jax
import torch
from torch.utils.data import DataLoader, Dataset

from optimusjx.model import TransformerLM
from optimusjx.train import CollatorForCausalLM, LMTrainer

In [2]:
class RandomIntDataset(Dataset):
    def __init__(
        self, 
        seq_len: int,
        vocab_size: int, 
        n_samples: int = 10,
        seed: int = 42,
        padding_amount: int | None = None
    ) -> None:
        self.vocab_size = vocab_size
        self.seq_len = seq_len
        self.n_samples = n_samples
        self.rng = jax.random.PRNGKey(seed)

        # could create samples at __getitem__ call instead
        self._data = jax.random.randint(
            self.rng, (n_samples, seq_len), minval=0, maxval=self.vocab_size
        )

        self.padding_amount = padding_amount            
        if self.padding_amount is not None:
            self.pad_token_id = self.vocab_size
            self.vocab_size += 1
            
            padding = jax.numpy.full((n_samples, self.padding_amount), self.pad_token_id)
            self._data = jax.numpy.concatenate([self._data, padding], axis=-1)
    
    def __getitem__(self, index) -> dict[str, list]:
        if index > self.n_samples - 1:
            raise ValueError("Index larger than length.")
        
        return {"input_ids": self._data[index, :].tolist()}

    def __len__(self) -> int:
        return self.n_samples


class TokenizerStandin:
    """Minimal placeholder used where a tokenizer object is expected.

    It carries nothing but a ``pad_token_id`` attribute. In this notebook it
    is handed to ``CollatorForCausalLM`` in place of a real tokenizer;
    presumably the collator only reads ``pad_token_id`` from it — verify
    against the collator implementation.

    Args:
        pad_token_id: Value exposed as ``self.pad_token_id``. Defaults to 0.
    """

    def __init__(self, pad_token_id: int = 0) -> None:
        # Same attribute name as on real tokenizer objects.
        self.pad_token_id = pad_token_id

In [3]:
# Tiny dataset: 3 random tokens per sample drawn from a 4-token vocabulary,
# then right-padded with 3 pad tokens. The pad id equals the original
# vocab_size (4), and the dataset's vocab_size grows to 5 to include it.
train_dataset = RandomIntDataset(seq_len=3, vocab_size=4, padding_amount=3)
train_dataset[0]

2023-12-19 23:03:24.727405: W external/xla/xla/service/gpu/nvptx_compiler.cc:698] The NVIDIA driver's CUDA version is 11.4 which is older than the ptxas CUDA version (11.8.89). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


{'input_ids': [1, 1, 0, 4, 4, 4]}

In [4]:
random_seed = 42
batch_size = 2

# Seeded torch RNG for the DataLoader (manual_seed returns the generator).
# NOTE(review): with shuffle=False (DataLoader's default) the sampler is
# sequential, so this generator does not change batch order — set
# shuffle=True if seeded shuffling was intended.
rng = torch.Generator().manual_seed(random_seed)

# The collator gets a tokenizer stand-in that carries the dataset's pad id.
collator = CollatorForCausalLM(
    TokenizerStandin(pad_token_id=train_dataset.pad_token_id)
)

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=False,
    generator=rng,
    collate_fn=collator,
)

In [5]:
# Pull just the first collated batch to inspect what the collator produces.
batch = next(iter(train_loader))
print([key for key in batch.keys()])
(batch["inputs"], batch["labels"], batch["special_tokens_mask"])

['inputs', 'labels', 'lookahead_mask', 'padding_mask', 'special_tokens_mask']


(Array([[1, 1, 0, 4, 4, 4],
        [2, 1, 0, 4, 4, 4]], dtype=int32),
 Array([[1, 1, 0, 4, 4, 4],
        [2, 1, 0, 4, 4, 4]], dtype=int32),
 Array([[False, False, False,  True,  True,  True],
        [False, False, False,  True,  True,  True]], dtype=bool))

In [6]:
# Size the model from the dataset: its vocab_size already counts the extra
# pad token appended by RandomIntDataset (5 here: 4 random ids + 1 pad id).
model = TransformerLM(
    vocab_size=train_dataset.vocab_size,
)

In [7]:
trainer = LMTrainer(
    model,
    # Shallow copy of the inspected batch; presumably so the trainer cannot
    # mutate `batch` in place — confirm against LMTrainer.
    example_batch=batch.copy(),
    # NOTE(review): magic number — confirm what max_iters controls in LMTrainer.
    max_iters=101,
    report_to=None,
)

  warn(f"Transformer recieved unknown keyword argument {key} - ignoring")
  warn(f"Transformer recieved unknown keyword argument {key} - ignoring")


In [13]:
# NOTE(review): the execution count jumps from 7 to 13 here — do Restart &
# Run All to confirm no hidden kernel state is involved.
# Presumably the epoch count: the outer progress bar below reads "Epoch k / 5".
n_epochs = 5
trainer.train(n_epochs, train_loader)

Epoch 1 / 5:   0%|          | 0/5 [00:00<?, ?it/s]
  0%|          | 0/5 [00:00<?, ?it/s][A
  0%|          | 0/5 [00:23<?, ?it/s, loss=1.6438621][A
 20%|██        | 1/5 [00:23<01:32, 23.18s/it, loss=1.6438621][A
 20%|██        | 1/5 [00:23<01:32, 23.18s/it, loss=1.5999411][A
 40%|████      | 2/5 [00:23<01:09, 23.18s/it, loss=1.7082672][A
 60%|██████    | 3/5 [00:23<00:46, 23.18s/it, loss=1.5150365][A
 80%|████████  | 4/5 [00:23<00:23, 23.18s/it, loss=1.3049347][A
Epoch 2 / 5:  20%|██        | 1/5 [00:23<01:33, 23.28s/it, loss=1.55, rng=[2144832284 1361675005]]
  0%|          | 0/5 [00:00<?, ?it/s][A
  0%|          | 0/5 [00:00<?, ?it/s, loss=1.6655241][A
 20%|██        | 1/5 [00:00<00:00, 33.31it/s, loss=1.7073815][A
 40%|████      | 2/5 [00:00<00:00, 43.09it/s, loss=1.3548726][A
 60%|██████    | 3/5 [00:00<00:00, 48.26it/s, loss=1.4652253][A
 80%|████████  | 4/5 [00:00<00:00, 50.91it/s, loss=1.5186874][A
Epoch 3 / 5:  40%|████      | 2/5 [00:23<01:09, 23.28s/it, loss=1.54,