In [1]:
# https://www.kaggle.com/datasets/spsayakpaul/arxiv-paper-abstracts/data
# predict category from title/abstract

In [2]:
import typing as t
from ast import literal_eval

from transformer.models.classifier import ClassifierLM
from transformer.dataloaders.inference import InferenceDataModule
from transformer.params import TransformerParams

import torch
import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from lightning import Trainer
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from transformers import LlamaTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# load and preview data
data = pd.read_csv("data/arxiv.csv")
data.tail()

Unnamed: 0,terms,titles,abstracts
56176,"['cs.CV', 'cs.IR']",Mining Spatio-temporal Data on Industrializati...,Despite the growing availability of big data i...
56177,"['cs.LG', 'cs.AI', 'cs.CL', 'I.2.6; I.2.7']",Wav2Letter: an End-to-End ConvNet-based Speech...,This paper presents a simple end-to-end model ...
56178,['cs.LG'],Deep Reinforcement Learning with Double Q-lear...,The popular Q-learning algorithm is known to o...
56179,"['stat.ML', 'cs.LG', 'math.OC']",Generalized Low Rank Models,Principal components analysis (PCA) is a well-...
56180,"['cs.LG', 'cs.AI', 'stat.ML']",Chi-square Tests Driven Method for Learning th...,SDYNA is a general framework designed to addre...


In [4]:
# get titles and primary category
X = data.titles.to_list()
y = data.terms.apply(literal_eval).str[0].to_numpy()

In [5]:
# encode categories
label_encoder = LabelEncoder()
y = torch.from_numpy(label_encoder.fit_transform(y))

In [6]:
# create data module
class ArxivDataModule(InferenceDataModule):
    def setup(self: t.Self, stage: str) -> None:
        self.X, self.y = X, y
        super().setup(stage=stage)

In [7]:
# initialize pretrained tokenizer
# - llama does not add an EOS token by default, so override this
# - llama also does not use a padding token, so this needs to be added
tokenizer = LlamaTokenizer.from_pretrained(
    "huggyllama/llama-7b", add_eos_token=True, legacy=False
)
tokenizer.add_special_tokens({"pad_token": "<pad>"})

1

In [8]:
# initialize the transformer
context_length = 64
model = ClassifierLM(
    config=TransformerParams(context_length=context_length),
    tokenizer=tokenizer,
    num_classes=len(y.unique())
)

In [14]:
# tokenize & encode data and prepare train/test splits
datamodule = ArxivDataModule(
    tokenizer=tokenizer,
    context_length=context_length,
    batch_size=32,
    val_size=0.2,
    test_size=0.1,
    num_workers=9,
    persistent_workers=True,
    limit=1000,
    random_state=1,
)

In [15]:
# train the model
trainer = Trainer(
    max_epochs=500,
    callbacks=EarlyStopping(monitor="val_loss", mode="min", patience=5),
    accelerator="gpu",
)
trainer.fit(model=model, datamodule=datamodule)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name  | Type       | Params | Mode 
---------------------------------------------
0 | model | ModuleDict | 35.3 M | train
---------------------------------------------
35.3 M    Trainable params
0         Non-trainable params
35.3 M    Total params
141.165   Total estimated model params size (MB)


                                                                           

/Users/edwinonuonga/env/llm-arm64/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py:298: The number of training batches (22) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 3:  73%|███████▎  | 16/22 [00:02<00:00,  7.91it/s, v_num=8, val_loss=0.867, train_loss=0.868]

/Users/edwinonuonga/env/llm-arm64/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


In [None]:
# calculate test metrics
trainer.test(model=model, datamodule=datamodule)

In [None]:
# view first batch of test set predictions
pred = trainer.predict(model=model, datamodule=datamodule)