In [1]:
# https://www.kaggle.com/datasets/spsayakpaul/arxiv-paper-abstracts/data
# generate a paper's title from its abstract (seq2seq summarization: abstract -> title)

In [2]:
import typing as t
from ast import literal_eval

import numpy as np
import pandas as pd
import torch
from lightning import Trainer, seed_everything
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from transformers import LlamaTokenizer

from transformer.models.seq2seq import Seq2SeqLM
from transformer.dataloaders.seq2seq import Seq2SeqDataModule
from transformer.params.transformer import TransformerParams

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# load the arXiv dataset and turn embedded line breaks in the title/abstract text into
# spaces so each title and abstract is a single line; then preview the last rows
data = pd.read_csv("data/arxiv.csv")
data = data.assign(
    titles=data["titles"].str.replace("\n", " "),
    abstracts=data["abstracts"].str.replace("\n", " "),
)
data.tail()

Unnamed: 0,terms,titles,abstracts
56176,"['cs.CV', 'cs.IR']",Mining Spatio-temporal Data on Industrializati...,Despite the growing availability of big data i...
56177,"['cs.LG', 'cs.AI', 'cs.CL', 'I.2.6; I.2.7']",Wav2Letter: an End-to-End ConvNet-based Speech...,This paper presents a simple end-to-end model ...
56178,['cs.LG'],Deep Reinforcement Learning with Double Q-lear...,The popular Q-learning algorithm is known to o...
56179,"['stat.ML', 'cs.LG', 'math.OC']",Generalized Low Rank Models,Principal components analysis (PCA) is a well-...
56180,"['cs.LG', 'cs.AI', 'stat.ML']",Chi-square Tests Driven Method for Learning th...,SDYNA is a general framework designed to addre...


In [4]:
# create data module
class ArxivSummarizationDataModule(Seq2SeqDataModule):
    """Seq2seq data module that pairs arXiv abstracts (inputs) with titles (targets).

    NOTE(review): rows are read from the notebook-level ``data`` DataFrame created by the
    load cell, so that cell must have run before the Trainer calls ``setup``.
    """

    def setup(self: t.Self, stage: str) -> None:
        """Populate ``self.data`` with (abstract, title) pairs, then defer to the parent.

        :param stage: stage name passed in by the Trainer (e.g. "fit", "test"); forwarded
            unchanged to ``Seq2SeqDataModule.setup``.
        """
        # pandas reads empty CSV fields as NaN (a float), which the tokenizer would
        # reject, so keep only rows where both the abstract and the title are present.
        pairs = data[["abstracts", "titles"]].dropna()
        # presumably the parent's setup encodes and splits self.data — confirm
        self.data = pairs.to_numpy()
        super().setup(stage=stage)

In [5]:
# load the pretrained llama tokenizer and adapt it for seq2seq training:
# - add_eos_token=True, because llama does not append an EOS token by default
# - register a "<pad>" special token, because llama ships without a padding token;
#   the cell displays how many tokens add_special_tokens added to the vocabulary
tokenizer = LlamaTokenizer.from_pretrained(
    "huggyllama/llama-7b",
    legacy=False,
    add_eos_token=True,
)
tokenizer.add_special_tokens({"pad_token": "<pad>"})

1

In [6]:
# initialize the transformer - note that for this seq2seq task, it is appropriate to use the same tokenizer for input and output
# - seed python/numpy/torch (and dataloader workers) before building the model so that
#   weight initialisation and batch shuffling are reproducible on Restart & Run All;
#   the seed matches the random_state used for the data split
seed_everything(1, workers=True)
context_length = 512
model = Seq2SeqLM(
    params=TransformerParams(context_length=context_length, model_dim=64),
    input_tokenizer=tokenizer,
    output_tokenizer=tokenizer,
)

In [7]:
# tokenize & encode data and prepare train/validation/test splits
# NOTE(review): val_size/test_size look like dataset fractions and limit=None like
# "use every row" — confirm against Seq2SeqDataModule
datamodule = ArxivSummarizationDataModule(
    # both sides share the llama tokenizer and the model's context length
    input_tokenizer=tokenizer,
    output_tokenizer=tokenizer,
    context_length=context_length,
    # split configuration; random_state presumably seeds the split
    val_size=0.2,
    test_size=0.1,
    random_state=1,
    limit=None,
    # dataloader configuration
    batch_size=8,
    num_workers=9,
    persistent_workers=True,
)

In [8]:
%%time
# train the model
trainer = Trainer(
    max_epochs=5,
    callbacks=EarlyStopping(monitor="val_loss", mode="min", patience=5),
    accelerator="gpu",
)
trainer.fit(model=model, datamodule=datamodule)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name  | Type       | Params | Mode 
---------------------------------------------
0 | model | ModuleDict | 7.6 M  | train
---------------------------------------------
7.6 M     Trainable params
0         Non-trainable params
7.6 M     Total params
30.268    Total estimated model params size (MB)


                                                                           

  preds = preds.flatten(end_dim=1)[output_masks]


Epoch 4: 100%|██████████| 4916/4916 [39:22<00:00,  2.08it/s, v_num=16, val_loss=9.020, train_loss=9.200]

`Trainer.fit` stopped: `max_epochs=5` reached.


Epoch 4: 100%|██████████| 4916/4916 [39:23<00:00,  2.08it/s, v_num=16, val_loss=9.020, train_loss=9.200]
CPU times: user 3h 5min 59s, sys: 3h 16min 39s, total: 6h 22min 38s
Wall time: 3h 19min 37s


In [None]:
# evaluate the model's current weights on the datamodule's test split and report the
# logged test metrics
trainer.test(model, datamodule=datamodule)

In [None]:
# run inference with the datamodule's predict dataloader and preview the first batch
# NOTE(review): presumably the predict dataloader serves the held-out test split —
# confirm in Seq2SeqDataModule
pred = trainer.predict(model=model, datamodule=datamodule)
pred[0]

In [None]:
# exact-match accuracy: the fraction of prediction items whose element 1 equals element 2
# NOTE(review): presumably element 1 is the reference title and element 2 the generated
# one — confirm against Seq2SeqLM's predict step
torch.tensor(
    [item[1] == item[2] for batch_items in pred for item in batch_items]
).float().mean()