In [None]:
# https://www.kaggle.com/datasets/bcruise/reddit-rfloridaman

In [1]:
import typing as t

from transformer.models import CausalLM
from transformer.dataloaders import CausalDataModule
from transformer.params import TransformerParams, TemperatureSamplingParams
from transformer.decoding import TemperatureSamplingDecoder

import pandas as pd
from lightning import Trainer
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from transformers import LlamaTokenizer

  from .autonotebook import tqdm as notebook_tqdm
  warn(


In [2]:
# load and preview data
titles = pd.read_csv("data/florida_man.csv").title
titles.tail()

42768    Florida woman assaults boyfriend after he refu...
42769    Florida Woman Arrested After Dispute Over Moth...
42770    Law firm demands Florida man remove racist ‘co...
42771    Florida Man arrested for assaulting wife with ...
42772    Half of the articles linked in /r/FloridaMan d...
Name: title, dtype: object

In [3]:
# create data module
class FloridaManDataModule(CausalDataModule):
    def setup(self: t.Self, stage: str) -> None:
        # select titles containing the word florida, with 200 or fewer characters
        self.data = titles.loc[
            titles.str.contains("florida", case=False) & 
            (titles.str.len() <= 200)
        ].to_list()
        super().setup(stage=stage)

In [4]:
# initialize pretrained tokenizer for causal language modelling
# - llama does not add an EOS token by default, so override this
# - llama also does not use a padding token, so this needs to be added
tokenizer = LlamaTokenizer.from_pretrained(
    "huggyllama/llama-7b", add_eos_token=True, legacy=False
)
tokenizer.add_special_tokens({"pad_token": "<pad>"})

1

In [5]:
# initialize the transformer
context_length = 64
model = CausalLM(
    params=TransformerParams(context_length=context_length),
    tokenizer=tokenizer,
)

In [6]:
# tokenize & encode data and prepare train/test splits
datamodule = FloridaManDataModule(
    tokenizer=tokenizer,
    context_length=context_length,
    batch_size=32,
    val_size=0.2,
    test_size=0.1,
    num_workers=9,
    persistent_workers=True,
    limit=None,
    random_state=1,
)

In [None]:
#%load_ext tensorboard
#%tensorboard --logdir lightning_logs/

In [91]:
%%time
# train the model
trainer = Trainer(
    max_epochs=500,
    callbacks=EarlyStopping(monitor="val_loss", mode="min", patience=5),
    accelerator="gpu",
    default_root_dir="models/florida_man_generation",
)
trainer.fit(model=model, datamodule=datamodule)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name  | Type       | Params | Mode
--------------------------------------------
0 | model | ModuleDict | 35.3 M | eval
--------------------------------------------
35.3 M    Trainable params
0         Non-trainable params
35.3 M    Total params
141.158   Total estimated model params size (MB)


Epoch 107: 100%|██████████| 861/861 [03:17<00:00,  4.37it/s, v_num=0, val_loss=6.820, train_loss=6.290]
CPU times: user 5h 31min 29s, sys: 11h 37min 13s, total: 17h 8min 42s
Wall time: 5h 57min 39s


In [24]:
# calculate test metrics
trainer.test(model=model, datamodule=datamodule)

Testing DataLoader 0: 100%|██████████| 123/123 [00:10<00:00, 11.38it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_loss           7.4453864097595215
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 7.4453864097595215}]

In [92]:
# view first batch of test set predictions
# note: these are still produced using teacher-forcing, so not purely generated
pred = trainer.predict(model=model, datamodule=datamodule)

Predicting DataLoader 0: 100%|██████████| 123/123 [00:14<00:00,  8.73it/s]


In [95]:
pred[:5]

[[('Florida Woman Womans c of the of the her and police drives her her from the c her',
   'Florida Woman pulls out in front of motorcycle, drives off leaving couple on the street'),
  ('   Florida Man ste  UI  881 8 oting',
   'Authorities: Florida Man Recently Died From 1958 Shooting'),
  ('Florida man with with c with c ch on ch ch on c on on',
   'Florida man cuts neighbor with chainsaw during argument over shrubs'),
  ('Florida unk Florida3--year-year Florida Woman at with UI at',
   'Very drunk 83-year-old Florida Woman charged with DUI'),
  ('Florida Man arrested to to ste1 to to to  to to sex to to J toUI',
   'Florida Man Sentenced to 10 Days in Jail for Missing Jury Duty'),
  ('Florida Man arrested hised for arrested andineinet for sex..',
   "Florida Man with last name 'Cocaine' arrested for drug possession"),
  ('Florida man arrested afterlydouble afterending beer man after aftert late a night to c holiday',
   "Florida man reportedly 'pretending to be a firework' late at n

In [96]:
# initialize decoder
decoder = TemperatureSamplingDecoder(
    params=TemperatureSamplingParams(max_length=200, temperature=0.25, k=5),
    model=model,
)

In [100]:
decoder.generate()

'Florida Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman Woman W

In [98]:
decoder.generate("Florida man buys")

'Florida man buys c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c c...'