In [2]:
from typing import Mapping

import torch as t
from beartype import beartype as typed
from datasets import load_dataset
from language_modeling import explore_batch
from tokenization import dependencies_tokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments


In [None]:
model_name = "roneneldan/TinyStories-8M"
# The Hub namespace (user/org) is kept out of version control in SECRET.txt;
# `with` closes the file handle instead of leaking it.
with open("SECRET.txt", encoding="utf-8") as secret_file:
    repo_name = secret_file.read().strip()
dataset_name = f"{repo_name}/flat"
# Streaming: splits are IterableDatasets read lazily, without a full download.
dataset = load_dataset(dataset_name, streaming=True)
# NOTE(review): dataset_name is always "<repo>/flat" here, so the TinyStories
# branch below (and the matching checks in later cells) is only taken if
# dataset_name is edited to "TinyStories" by hand.
tokenizer = (
    AutoTokenizer.from_pretrained(model_name)
    if dataset_name == "TinyStories"
    else dependencies_tokenizer(vocab_size=500)
)
# Padding needs a pad token; fall back to EOS when the tokenizer defines none.
tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_name)


In [None]:
# Match the embedding matrix to the (possibly much smaller) custom vocabulary,
# then leave only parameters whose names contain "wte" or "wpe" trainable
# (the token / position embedding tables in GPT-2/GPT-Neo naming).
model.resize_token_embeddings(len(tokenizer))
trainable_markers = ("wte", "wpe")
for param_name, weight in model.named_parameters():
    weight.requires_grad = any(marker in param_name for marker in trainable_markers)


In [None]:
# Sanity check: tokenize the first training story and show its length and
# its first ten token ids.
first_train_text = next(iter(dataset["train"]))["text"]
tokens_sample = tokenizer(first_train_text)["input_ids"]
print(len(tokens_sample))
print(tokens_sample[:10])


In [None]:
is_tinystories = dataset_name == "TinyStories"


@typed
def tokenize_function(example: Mapping[str, list]) -> Mapping[str, list[list[int]]]:
    """Tokenize a batch of rows for causal-LM training.

    Used with ``.map(..., batched=True)``, so each value of ``example`` is a
    list holding one entry per row; only the ``"text"`` column is read.

    Returns the tokenizer output (``input_ids``, ``attention_mask``, ...) plus
    ``labels``: a copy of ``input_ids`` in which padding positions
    (``attention_mask == 0``) are replaced by -100 so the LM loss ignores them.
    """
    if is_tinystories:
        # Pad/truncate to a fixed 128 tokens so the default collator can stack
        # rows into one tensor.
        result = tokenizer(
            example["text"], max_length=128, padding="max_length", truncation=True
        )
    else:
        # NOTE(review): no padding/truncation here; assumes rows of the "flat"
        # dataset already share a token length within a batch — confirm.
        result = tokenizer(example["text"])
    masks = result.get("attention_mask")
    if masks is None:
        result["labels"] = [list(ids) for ids in result["input_ids"]]
    else:
        # pad_token may equal eos_token, so mask by attention_mask rather than
        # by token id to keep genuine EOS tokens in the loss.
        result["labels"] = [
            [tok if keep else -100 for tok, keep in zip(ids, mask)]
            for ids, mask in zip(result["input_ids"], masks)
        ]
    return result


train_size = 100000
test_size = 1000


def _tokenized_split(split: str, n_rows: int):
    """Lazily tokenize the streaming ``split`` and keep its first ``n_rows`` rows."""
    return (
        dataset[split]
        .map(tokenize_function, batched=True)
        .remove_columns(["text"])
        .take(n_rows)
    )


tokenized_train = _tokenized_split("train", train_size)
tokenized_test = _tokenized_split(
    "validation" if is_tinystories else "test", test_size
)


In [None]:
@typed
def train(batch_size: int, lr: float) -> None:
    """Fine-tune the global ``model`` on ``tokenized_train``.

    Runs ``train_size // batch_size`` optimizer steps (at least one), i.e.
    roughly one pass over the ``train_size`` streamed rows. The Trainer's
    optimizer only receives parameters with ``requires_grad=True``, so the
    current freeze/unfreeze state of ``model`` decides what is updated.

    Args:
        batch_size: Per-device training batch size; must be positive.
        lr: Learning rate passed to ``TrainingArguments``.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    # The training split is a streaming IterableDataset with no len(), so the
    # Trainer needs an explicit positive step budget. max_steps takes
    # precedence over num_train_epochs.
    steps = max(1, train_size // batch_size)
    training_args = TrainingArguments(
        output_dir="trainer",
        fp16=False,
        per_device_train_batch_size=batch_size,
        torch_compile=False,
        learning_rate=lr,
        logging_steps=10,
        num_train_epochs=1,
        max_steps=steps,
        save_total_limit=1,
        report_to="none",
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_train,
    )
    trainer.train()


In [None]:
# Baseline: run the project's explore_batch helper on the held-out split
# before any fine-tuning has happened.
explore_batch(model, tokenizer, tokenized_test)


In [None]:
# Fine-tuning only embeddings:
# At this point only parameters whose names contain "wte"/"wpe" have
# requires_grad=True, so the Trainer updates just those.
train(batch_size=8, lr=1e-2)


In [None]:
# Re-run explore_batch on the held-out split after the embeddings-only stage.
explore_batch(model, tokenizer, tokenized_test)


In [None]:
# Fine-tuning only embeddings and layernorms:
# Additionally unfreeze every parameter whose name contains "ln" (layernorm
# weights/biases in GPT-Neo naming); embeddings stay trainable from before.
layernorm_params = [
    (param_name, weight)
    for param_name, weight in model.named_parameters()
    if "ln" in param_name
]
for param_name, weight in layernorm_params:
    print(f"{param_name} unfrozen")
    weight.requires_grad = True
train(batch_size=8, lr=2e-3)


In [None]:
# Re-run explore_batch on the held-out split after also training layernorms.
explore_batch(model, tokenizer, tokenized_test)


In [None]:
# Fine-tuning only embeddings, layernorms and the last block:
# Derive the final block's index from the config instead of hard-coding "h.7"
# (which is the last block only for an 8-layer model and silently matches
# nothing otherwise). The surrounding dots keep ".h.1." from matching ".h.10.".
last_block_prefix = f".h.{model.config.num_hidden_layers - 1}."
last_block_params = [
    param for name, param in model.named_parameters() if last_block_prefix in name
]
if not last_block_params:
    raise ValueError(f"No parameters matched {last_block_prefix!r}")
for param in last_block_params:
    param.requires_grad = True
train(batch_size=8, lr=1e-3)


In [None]:
# Re-run explore_batch on the held-out split after training the final block.
explore_batch(model, tokenizer, tokenized_test)


In [None]:
# Interactive Hugging Face Hub login; needed before the push_to_hub calls below.
from huggingface_hub import notebook_login

notebook_login()


In [None]:
# Upload the fine-tuned model and its tokenizer under a repository name typed
# by the user. Reject blank input up front so nothing is pushed to an invalid
# repo id.
name = input("Model name: ").strip()
if not name:
    raise ValueError("A non-empty Hub repository name is required.")
model.push_to_hub(name)
tokenizer.push_to_hub(name)
