In [None]:
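# Dataset: https://www.kaggle.com/datasets/spsayakpaul/arxiv-paper-abstracts/data
# Goal: predict each paper's primary category from its title
# (abstracts are also present in the dataset but are not used as features below)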

In [None]:
import typing as t
from ast import literal_eval

from transformer.models.classifier import ClassifierLM
from transformer.dataloaders.inference import InferenceDataModule
from transformer.params import TransformerParams

import torch
import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from lightning import Trainer
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from transformers import LlamaTokenizer

In [None]:
# read the arXiv CSV, flatten embedded newlines in the text columns,
# and show the last rows as a sanity check
data = pd.read_csv("data/arxiv.csv")
for text_column in ("titles", "abstracts"):
    data[text_column] = data[text_column].str.replace("\n", " ")
data.tail()

In [None]:
# features: the paper titles as a plain Python list
X = data["titles"].to_list()

# labels: parse each `terms` entry as a Python literal and keep its first element
primary_terms = data["terms"].apply(literal_eval).str.get(0)
print(primary_terms.value_counts())
y = primary_terms.to_numpy()

In [None]:
# map category labels to integer ids, then wrap the ids in a torch tensor
# (the fitted encoder is kept so ids can be mapped back to category names)
label_encoder = LabelEncoder()
y_encoded = label_encoder.fit_transform(y)
y = torch.from_numpy(y_encoded)

In [None]:
# data module wiring the notebook's features/labels into InferenceDataModule
class ArxivCategorizationDataModule(InferenceDataModule):
    """Data module that supplies the notebook-level ``X`` and ``y`` to the base class."""

    def setup(self: t.Self, stage: str) -> None:
        """Attach the global ``X`` and ``y`` to this module, then run the parent ``setup``.

        Args:
            stage: Stage identifier, forwarded unchanged to the parent ``setup``.
        """
        # the attributes are set before delegating to the parent, as in the
        # original ordering
        self.X = X
        self.y = y
        super().setup(stage=stage)

In [None]:
# load the pretrained llama tokenizer, overriding two defaults:
# - add_eos_token=True, since llama does not append an EOS token by default
# - a "<pad>" special token is registered, since llama does not use a padding token
tokenizer_checkpoint = "huggyllama/llama-7b"
tokenizer = LlamaTokenizer.from_pretrained(
    tokenizer_checkpoint,
    legacy=False,
    add_eos_token=True,
)
tokenizer.add_special_tokens(dict(pad_token="<pad>"))

In [None]:
# build the classifier: one output class per distinct encoded label in `y`
context_length = 64
transformer_params = TransformerParams(context_length=context_length)
model = ClassifierLM(
    config=transformer_params,
    tokenizer=tokenizer,
    num_classes=y.unique().numel(),
)

In [None]:
# build the data module that tokenizes/encodes the data and prepares the splits
datamodule = ArxivCategorizationDataModule(
    # tokenization (same tokenizer and context length as the model)
    tokenizer=tokenizer,
    context_length=context_length,
    # splitting -- presumably fractions of the full dataset; confirm in InferenceDataModule
    val_size=0.2,
    test_size=0.1,
    random_state=1,
    limit=None,
    # loading
    batch_size=32,
    num_workers=9,
    persistent_workers=True,
)

In [None]:
%%time
# train the model
trainer = Trainer(
    max_epochs=50,
    callbacks=EarlyStopping(monitor="val_loss", mode="min", patience=5),
    accelerator="cpu",
)
trainer.fit(model=model, datamodule=datamodule)

In [None]:
# run the trainer's test loop; the returned metrics are shown as the cell output
trainer.test(datamodule=datamodule, model=model)

In [None]:
# run prediction and preview the first 10 outputs of the first batch
# - `pred` is iterated batch-by-batch when accuracy is computed, so `pred[:10]`
#   would dump ten whole batches rather than a preview of the first one
pred = trainer.predict(model=model, datamodule=datamodule)
pred[0][:10]

In [None]:
# accuracy = mean of per-record matches between element 1 and element 2 of each
# prediction record (presumably predicted vs. true label -- confirm in ClassifierLM)
matches = [
    record[1] == record[2]
    for batch in pred
    for record in batch
]
torch.tensor(matches).float().mean()