In [1]:
import os
from typing import Annotated, Any, Mapping, Optional

import matplotlib.pyplot as plt
import torch as t
import torch.nn as nn
import torch.nn.functional as F
from beartype import beartype as typed
from beartype.door import die_if_unbearable as assert_type
from beartype.typing import Callable, Iterable
from beartype.vale import Is
from datasets import load_dataset
from dvclive.huggingface import DVCLiveCallback
from jaxtyping import Bool, Float, Int
from torch import Tensor as TT
from torch.utils.data import IterableDataset
from transformers import AutoTokenizer, AutoModelForCausalLM

from gpt import DenseGPTConfig, DenseGPTForCausalLM
from utils import explore_batch, fetch_or_ask

%load_ext autoreload
%autoreload 2

In [2]:
config = DenseGPTConfig(
    vocab_size=50257,
    hidden_size=256,
    num_layers=8,
    attention_types=[[["global"], 8]],
    num_heads=16,
    use_dense=True,
)
# model = DenseGPTForCausalLM(config)
# model = DenseGPTForCausalLM.from_pretrained("roneneldan/TinyStories-8M", config=config)
model = AutoModelForCausalLM.from_pretrained("roneneldan/TinyStories-8M")

In [5]:
dataset = load_dataset("roneneldan/TinyStories")
tokenizer = AutoTokenizer.from_pretrained("roneneldan/TinyStories-8M")
tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token

Repo card metadata block was not found. Setting CardData to empty.
Using pad_token, but it is not set yet.


In [47]:
@typed
def tokenize_function(example: Mapping[str, str | int]) -> Mapping[str, list[int]]:
    result = tokenizer(
        example["text"],
        max_length=256,
        padding="max_length",
        truncation=True,
    )
    result["labels"] = result["input_ids"]
    return result


subset_size = 1000
subset = dataset["train"].shuffle().select(range(subset_size))
tokenized = subset.map(tokenize_function, batched=True).remove_columns(["text"])

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [None]:
explore_batch(model, tokenizer, tokenized, n_samples=100)

In [64]:
from utils import get_loss
from tqdm.auto import tqdm

losses = []
for row in tqdm(dataset["train"].shuffle().select(range(100))):
    losses.append(get_loss(model, tokenizer, row["text"]))

  0%|          | 0/100 [00:00<?, ?it/s]

In [65]:
print(len(losses), sum(losses) / len(losses))

100 1.3923834615945816
