In [None]:
# https://www.kaggle.com/datasets/bcruise/reddit-rfloridaman

In [1]:
import typing as t

from transformer.models import CausalLM
from transformer.dataloaders import CausalDataModule
from transformer.params import TransformerParams, TemperatureSamplingParams
from transformer.decoding import TemperatureSamplingDecoder

import pandas as pd
from lightning import Trainer
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from transformers import LlamaTokenizer

  from .autonotebook import tqdm as notebook_tqdm
  warn(


In [2]:
# load and preview data
titles = pd.read_csv("data/florida_man.csv").title
titles.tail()

42768    Florida woman assaults boyfriend after he refu...
42769    Florida Woman Arrested After Dispute Over Moth...
42770    Law firm demands Florida man remove racist ‘co...
42771    Florida Man arrested for assaulting wife with ...
42772    Half of the articles linked in /r/FloridaMan d...
Name: title, dtype: object

In [3]:
# create data module
class FloridaManDataModule(CausalDataModule):
    def setup(self: t.Self, stage: str) -> None:
        # read titles with 200 or fewer characters from CSV
        self.data = titles.loc[titles.str.contains("florida", case=False) & (titles.str.len() <= 200)].to_list()
        super().setup(stage=stage)

In [4]:
# initialize pretrained tokenizer for causal language modelling
# - llama does not add an EOS token by default, so override this
# - llama also does not use a padding token, so this needs to be added
tokenizer = LlamaTokenizer.from_pretrained(
    "huggyllama/llama-7b", add_eos_token=True, legacy=False
)
tokenizer.add_special_tokens({"pad_token": "<pad>"})

1

In [5]:
# initialize the transformer
context_length = 64
model = CausalLM(
    params=TransformerParams(context_length=context_length),
    tokenizer=tokenizer,
)

In [6]:
# tokenize & encode data and prepare train/test splits
datamodule = FloridaManDataModule(
    tokenizer=tokenizer,
    context_length=context_length,
    batch_size=32,
    val_size=0.2,
    test_size=0.1,
    num_workers=9,
    persistent_workers=True,
    limit=None,
    random_state=1,
)

In [None]:
#%load_ext tensorboard
#%tensorboard --logdir lightning_logs/

In [7]:
%%time
# train the model
trainer = Trainer(
    max_epochs=100,
    callbacks=EarlyStopping(monitor="val_loss", mode="min", patience=5),
    accelerator="gpu",
    default_root_dir="models/florida_man_generation",
)
trainer.fit(model=model, datamodule=datamodule)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Missing logger folder: models/florida_man_generation/lightning_logs

  | Name  | Type       | Params | Mode 
---------------------------------------------
0 | model | ModuleDict | 35.3 M | train
---------------------------------------------
35.3 M    Trainable params
0         Non-trainable params
35.3 M    Total params
141.158   Total estimated model params size (MB)


                                                                           

  preds = preds.flatten(end_dim=1)[masks]


Epoch 4: 100%|██████████| 861/861 [03:29<00:00,  4.10it/s, v_num=0, val_loss=8.620, train_loss=8.830]

`Trainer.fit` stopped: `max_epochs=5` reached.


Epoch 4: 100%|██████████| 861/861 [03:30<00:00,  4.10it/s, v_num=0, val_loss=8.620, train_loss=8.830]
CPU times: user 16min 3s, sys: 28min 23s, total: 44min 27s
Wall time: 17min 12s


In [8]:
# calculate test metrics
trainer.test(model=model, datamodule=datamodule)

Testing DataLoader 0: 100%|██████████| 123/123 [00:17<00:00,  6.92it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_loss            8.626354217529297
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 8.626354217529297}]

In [9]:
# view first batch of test set predictions
# note: these are still produced using teacher-forcing, so not purely generated
pred = trainer.predict(model=model, datamodule=datamodule)

Predicting DataLoader 0: 100%|██████████| 123/123 [00:21<00:00,  5.60it/s]


In [10]:
pred[:5]

[[('Florida Florida in,',
   'Florida Woman pulls out in front of motorcycle, drives off leaving couple on the street'),
  ('Florida Florida Man ',
   'Authorities: Florida Man Recently Died From 1958 Shooting'),
  ('Florida Florida man',
   'Florida man cuts neighbor with chainsaw during argument over shrubs'),
  ('Florida  Florida',
   'Very drunk 83-year-old Florida Woman charged with DUI'),
  ('Florida Florida Man to 0 in',
   'Florida Man Sentenced to 10 Days in Jail for Missing Jury Duty'),
  ('Florida Florida Man',
   "Florida Man with last name 'Cocaine' arrested for drug possession"),
  ('Florida Florida man to a',
   "Florida man reportedly 'pretending to be a firework' late at night ahead of holiday weekend"),
  ('Florida Florida Man $0 of and to to a',
   'Florida Man buys $0.69 worth of gasoline and uses it to set fire to a McDonalds bathroom'),
  ('Florida the of the and',
   'Ex-cop Floridaman finds himself on the otherside of the law for stealing cows... twice... and fo

In [11]:
# initialize decoder
decoder = TemperatureSamplingDecoder(
    params=TemperatureSamplingParams(max_length=200, temperature=0.25, k=5),
    model=model,
)

In [12]:
decoder.generate()

'Florida Florida Florida Florida Florida Florida Florida Florida Florida'

In [22]:
decoder.generate("Florida man")

'Florida man man man man to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to...'