In [None]:
# Data source (r/FloridaMan post titles): https://www.kaggle.com/datasets/bcruise/reddit-rfloridaman

In [None]:
import typing as t

import pandas as pd
from lightning import Trainer, seed_everything
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from transformers import LlamaTokenizer

from transformer.dataloaders import CausalDataModule
from transformer.decoding import TemperatureSamplingDecoder
from transformer.models import CausalLM
from transformer.params import TemperatureSamplingParams, TransformerParams

In [None]:
# load the post titles as a Series and preview the last few rows
titles = pd.read_csv("data/florida_man.csv")["title"]
titles.tail()

In [None]:
# create data module
class FloridaManDataModule(CausalDataModule):
    """Causal-LM data module over the r/FloridaMan post titles.

    The titles come from the notebook-level ``titles`` Series loaded in an
    earlier cell (not re-read from the CSV). Only titles that contain
    ``title_keyword`` (case-insensitive, literal match) and are at most
    ``max_title_length`` characters long are kept.
    """

    # substring every kept title must contain (matched literally, ignoring case)
    title_keyword: str = "florida"
    # longest title, in characters, that is kept
    max_title_length: int = 200

    def setup(self: t.Self, stage: str) -> None:
        """Populate ``self.data`` with the filtered titles, then run the parent setup.

        Args:
            stage: Lightning stage name, forwarded unchanged to
                ``CausalDataModule.setup``.
        """
        # regex=False: the keyword is a plain substring, not a pattern;
        # na=False: missing titles are explicitly treated as non-matches
        mentions_keyword = titles.str.contains(
            self.title_keyword, case=False, regex=False, na=False
        )
        short_enough = titles.str.len() <= self.max_title_length
        self.data = titles.loc[mentions_keyword & short_enough].to_list()
        super().setup(stage=stage)

In [None]:
# load the pretrained LLaMA tokenizer for causal language modelling:
# - add_eos_token=True: LLaMA does not append an EOS token on its own
# - LLaMA defines no padding token, so "<pad>" is registered as one here
llama_checkpoint = "huggyllama/llama-7b"
tokenizer = LlamaTokenizer.from_pretrained(llama_checkpoint, legacy=False, add_eos_token=True)
special_tokens = {"pad_token": "<pad>"}
tokenizer.add_special_tokens(special_tokens)

In [None]:
# initialize the transformer
# - seed the python/numpy/torch RNGs (and dataloader workers) before building
#   the model so weight initialisation and the later training run are
#   reproducible under Restart & Run All; 1 matches the split's random_state
# - NOTE(review): the tokenizer gained a "<pad>" token in the cell above;
#   confirm CausalLM sizes its embeddings from len(tokenizer) (which includes
#   it) rather than tokenizer.vocab_size
context_length = 64
seed_everything(1, workers=True)
model = CausalLM(
    params=TransformerParams(context_length=context_length),
    tokenizer=tokenizer,
)

In [None]:
# tokenize & encode data and prepare train/test splits
# tokenize/encode the titles and prepare the train/validation/test splits
datamodule = FloridaManDataModule(
    tokenizer=tokenizer,
    context_length=context_length,
    # split settings: presumably fractions of the full dataset, the rest
    # being used for training; random_state presumably seeds the split
    val_size=0.2,
    test_size=0.1,
    random_state=1,
    limit=None,
    # loader settings
    # NOTE(review): num_workers=9 is machine-specific; consider deriving it
    # from os.cpu_count()
    batch_size=32,
    num_workers=9,
    persistent_workers=True,
)

In [None]:
# optional: uncomment the two lines below to monitor training live in TensorBoard
#%load_ext tensorboard
#%tensorboard --logdir lightning_logs/

In [None]:
%%time
# train the model
trainer = Trainer(
    max_epochs=100,
    callbacks=EarlyStopping(monitor="val_loss", mode="min", patience=5),
    accelerator="gpu",
)
trainer.fit(model=model, datamodule=datamodule)

In [None]:
# calculate test metrics
# evaluate the checkpoint the trainer's ModelCheckpoint marks as "best" (the
# lowest-val_loss one when it monitors val_loss, otherwise the latest) rather
# than whatever weights are left in memory after early stopping
# NOTE(review): Lightning restores this checkpoint into `model`, so the
# prediction/decoding cells below should also use these weights — confirm
trainer.test(model=model, datamodule=datamodule, ckpt_path="best")

In [None]:
# view first batch of test set predictions
# note: these are still produced using teacher-forcing, so not purely generated
# collect model predictions; `pred` presumably holds one entry per prediction
# batch, and the next cell shows the first five entries
# - note: these are still produced using teacher-forcing, so not purely generated
# - NOTE(review): the split served by predict_dataloader is chosen inside
#   CausalDataModule (believed to be the test set) — confirm
pred = trainer.predict(model=model, datamodule=datamodule)

In [None]:
# show the first five entries of `pred`
pred[0:5]

In [None]:
# initialize decoder
# configure temperature sampling for free-form generation with the trained model
# - NOTE(review): `k` presumably limits sampling to the top-k tokens — confirm
# - NOTE(review): max_length=200 exceeds context_length=64; confirm whether
#   max_length counts tokens or characters, and that the decoder crops its
#   input to the model's context window
sampling_params = TemperatureSamplingParams(max_length=200, temperature=0.25, k=5)
decoder = TemperatureSamplingDecoder(model=model, params=sampling_params)

In [None]:
# sample a title without a prompt
# NOTE(review): presumably generation starts from the BOS token alone — confirm
decoder.generate()

In [None]:
# sample a continuation of a fixed prompt
prompt = "Florida man"
decoder.generate(prompt)