In [14]:
# https://www.kaggle.com/code/bcruise/adventures-of-florida-man/input

In [1]:
import typing as t

from transformer.models.causal import CausalLM
from transformer.dataloaders.teacher_forcing import TeacherForcingDataModule
from transformer.params import TransformerParams

import pandas as pd
from lightning import Trainer
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from transformers import LlamaTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# load and preview data
titles = pd.read_csv("data/florida_man.csv").title
titles.tail()

42768    Florida woman assaults boyfriend after he refu...
42769    Florida Woman Arrested After Dispute Over Moth...
42770    Law firm demands Florida man remove racist ‘co...
42771    Florida Man arrested for assaulting wife with ...
42772    Half of the articles linked in /r/FloridaMan d...
Name: title, dtype: object

In [3]:
# create data module
class FloridaManDataModule(TeacherForcingDataModule):
    def setup(self: t.Self, stage: str) -> None:
        # read titles with 200 or fewer characters from CSV
        self.data = titles.loc[titles.str.len() <= 200].to_list()
        super().setup(stage=stage)

In [4]:
# initialize pretrained tokenizer for causal language modelling
# - llama does not add an EOS token by default, so override this
# - llama also does not use a padding token, so this needs to be added
tokenizer = LlamaTokenizer.from_pretrained(
    "huggyllama/llama-7b", add_eos_token=True, legacy=False
)
tokenizer.add_special_tokens({"pad_token": "<pad>"})

1

In [5]:
# initialize the transformer
context_length = 64
model = CausalLM(
    config=TransformerParams(context_length=context_length),
    tokenizer=tokenizer,
)

In [6]:
# tokenize & encode data and prepare train/test splits
datamodule = FloridaManDataModule(
    tokenizer=tokenizer,
    context_length=context_length,
    batch_size=32,
    val_size=0.2,
    test_size=0.1,
    num_workers=9,
    persistent_workers=True,
    limit=None,
    random_state=1,
)

In [7]:
%load_ext tensorboard
%tensorboard --logdir lightning_logs/

In [8]:
%%time
# train the model
trainer = Trainer(
    max_epochs=100,
    callbacks=EarlyStopping(monitor="val_loss", mode="min", patience=5),
    accelerator="gpu",
)
trainer.fit(model=model, datamodule=datamodule)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name  | Type       | Params | Mode 
---------------------------------------------
0 | model | ModuleDict | 35.3 M | train
---------------------------------------------
35.3 M    Trainable params
0         Non-trainable params
35.3 M    Total params
141.158   Total estimated model params size (MB)


                                                                           

  preds = preds.flatten(end_dim=1)[masks]


Epoch 70: 100%|██████████| 929/929 [03:43<00:00,  4.16it/s, v_num=2, val_loss=7.890, train_loss=7.860]
CPU times: user 4h 1min 32s, sys: 7h 50min 55s, total: 11h 52min 27s
Wall time: 4h 25min 29s


In [None]:
# calculate test metrics
trainer.test(model=model, datamodule=datamodule)

In [9]:
# view first batch of test set predictions
# note: these are still produced using teacher-forcing, so not purely generated
pred = trainer.predict(model=model, datamodule=datamodule)

Predicting DataLoader 0: 100%|██████████| 133/133 [00:16<00:00,  8.17it/s]


In [10]:
pred[:5]

[[('Florida16 Man to,fighters fired for alleg,, toose overagues',
   "6 FloridaMan firefighters fired for allegedly placing noose over black colleague's family photo, officials say"),
  ('Florida Man Woman,117, old daughter on snapchat having sex and,118, old Man',
   'Florida Woman sees 17 year old daughter on snapchat having sex and gets 18 year old Florida Man busted for child porn'),
  ("Florida Man' and for after",
   'Florida Man and friends arrested after high speed chase'),
  ('Florida Man arrested sho Florida at,pper after',
   'Florida man shoots at stripper after she refused to have sex with him.'),
  ('Florida Man, Assaults Woman, Calls "',
   'Florida Man Assaults Woman, Calls Her "Old Snaggletooth Lady".'),
  ('Florida Man, to his friend with a with of nunch',
   'Florida Man attacks his friend with a pair of nunchucks for failing to return his DVDs'),
  ('Florida Florida Man, was for after,, her,-,,ss with her and',
   'A Florida lawyer was arrested after ramming her ex-

In [17]:
model.generate()

't'

In [26]:

model.generate("Florida man")

'man for'