In [1]:
import torch
# Report whether PyTorch can see a CUDA device; the cell output shows the result.
# NOTE(review): the training cell further down sets fp16=True, which presumably
# needs a GPU — confirm before running on a CPU-only machine.
torch.cuda.is_available()

True

In [2]:
import re
from functools import partial

# Matches any run of one or more whitespace characters (spaces, tabs, newlines).
# Written as a raw string: "\s" inside a regular string literal is an invalid
# escape sequence, which emits a DeprecationWarning/SyntaxWarning on modern Python.
REGEX_MULTI_SPACE = re.compile(r"\s+")


def preprocess_text(_re, _regex, s):
    """Build the model input text for a single dataset example.

    Every substring of the example's title that matches ``_regex`` is replaced
    by a single space. Only the title is used; the abstract is not included
    in the output.

    Args:
        _re: Object exposing ``sub(pattern, repl, string)``; this notebook
            passes the ``re`` module.
        _regex: Pattern forwarded to ``_re.sub`` as its first argument.
        s: Mapping with a ``"title"`` entry holding the raw title string.

    Returns:
        A dict with one key, ``"text"``, holding the normalized title.
    """
    raw_title = s["title"]
    normalized_title = _re.sub(_regex, " ", raw_title)
    return {"text": normalized_title}
    
# Pre-bind the `re` module and the compiled whitespace pattern (positionally) so
# the resulting callable only needs the example dict; it is passed to
# `dataset.map` in a later cell.
partial_preprocess_text = partial(preprocess_text, re, REGEX_MULTI_SPACE)

In [3]:
from datasets import load_dataset

# Load the "aalksii/ml-arxiv-papers" dataset and add the normalized "text"
# column produced by `partial_preprocess_text` to every example.
dataset = load_dataset("aalksii/ml-arxiv-papers").map(partial_preprocess_text)
dataset

DatasetDict({
    train: Dataset({
        features: ['title', 'abstract', 'text'],
        num_rows: 105832
    })
    test: Dataset({
        features: ['title', 'abstract', 'text'],
        num_rows: 11760
    })
})

In [4]:
from transformers import AutoTokenizer

# Load the pretrained tokenizer for the "distilgpt2" checkpoint (the same
# checkpoint name the model is loaded from in a later cell).
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")

# def add_tokens(tokenizer, new_tokens):
    # new_tokens = set(new_tokens) - set(tokenizer.vocab.keys())
    # tokenizer.add_tokens(list(new_tokens))
    # return tokenizer

In [5]:
def tokenize_function(tokenizer, examples):
    """Tokenize the ``"text"`` field of a batch of examples.

    Args:
        tokenizer: Callable invoked as ``tokenizer(texts, truncation=True)``.
        examples: Mapping with a ``"text"`` entry; that entry is handed to the
            tokenizer unchanged.

    Returns:
        Whatever the tokenizer call returns.
    """
    texts = examples["text"]
    encoded = tokenizer(texts, truncation=True)
    return encoded



# Bind the tokenizer so the mapped callable only receives the batch of examples.
partial_tokenize_function = partial(tokenize_function, tokenizer)

# Tokenize the dataset in batches using 4 worker processes. The columns listed
# on the train split are removed, leaving only what the tokenizer returns.
tokens = dataset.map(
    function=partial_tokenize_function,
    remove_columns=dataset["train"].column_names,
    batched=True,
    num_proc=4,
)

In [6]:
from transformers import DataCollatorForLanguageModeling

# Reuse the end-of-sequence token as the padding token so the collator can pad
# variable-length examples into a batch.
# NOTE(review): presumably the distilgpt2 tokenizer ships without a pad token — confirm.
tokenizer.pad_token = tokenizer.eos_token
# mlm=False turns off masked-language-modeling collation, i.e. the collator is
# configured for causal language modeling.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

In [7]:
from transformers import AutoModelForCausalLM, TrainingArguments, Trainer

# Load the pretrained "distilgpt2" checkpoint as a causal language model; this
# is the model that gets fine-tuned by the Trainer below.
model = AutoModelForCausalLM.from_pretrained("distilgpt2")

In [8]:
from transformers import EarlyStoppingCallback

# Hyperparameters for fine-tuning the causal LM on the tokenized titles.
# - evaluation_strategy="epoch": the log below shows one eval pass per epoch.
# - save_steps=1000: checkpoint interval (in training steps) under output_dir.
# - fp16=True: mixed-precision training.
#   NOTE(review): presumably requires a CUDA device — confirm.
training_args = TrainingArguments(
    output_dir="./results_causal",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    weight_decay=0.01,
    num_train_epochs=3,
    save_steps=1000,
    optim="adafactor",
    fp16=True,
    # load_best_model_at_end=True,
)

# NOTE: Early stopping is disabled for now because of a HuggingFace issue; will
# fix later. Until then, check manually that eval_loss does not worsen between
# epochs.
# NOTE(review): EarlyStoppingCallback presumably requires
# load_best_model_at_end=True together with matching save/evaluation
# strategies, which this configuration does not set — confirm whether that is
# the "bug" referred to here.
# early_stopping = EarlyStoppingCallback(
#     early_stopping_patience=3, early_stopping_threshold=0.03
# )

# Wire the model, arguments, tokenized splits and collator together. The "test"
# split is used as the evaluation set during training.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokens["train"],
    eval_dataset=tokens["test"],
    data_collator=data_collator,
    # callbacks=[early_stopping],
)

# Run fine-tuning; as the last expression of the cell, the returned TrainOutput
# is displayed below the training log.
trainer.train()

  0%|          | 0/39687 [00:00<?, ?it/s]

{'loss': 4.9543, 'learning_rate': 1.9749540151686952e-05, 'epoch': 0.04}
{'loss': 4.5818, 'learning_rate': 1.949807241666037e-05, 'epoch': 0.08}
{'loss': 4.4567, 'learning_rate': 1.924610073827702e-05, 'epoch': 0.11}
{'loss': 4.3533, 'learning_rate': 1.899412905989367e-05, 'epoch': 0.15}
{'loss': 4.3255, 'learning_rate': 1.8742157381510318e-05, 'epoch': 0.19}
{'loss': 4.2333, 'learning_rate': 1.849018570312697e-05, 'epoch': 0.23}
{'loss': 4.2168, 'learning_rate': 1.823821402474362e-05, 'epoch': 0.26}
{'loss': 4.2227, 'learning_rate': 1.798624234636027e-05, 'epoch': 0.3}
{'loss': 4.192, 'learning_rate': 1.773427066797692e-05, 'epoch': 0.34}
{'loss': 4.1175, 'learning_rate': 1.748229898959357e-05, 'epoch': 0.38}
{'loss': 4.0573, 'learning_rate': 1.7230327311210222e-05, 'epoch': 0.42}
{'loss': 4.0865, 'learning_rate': 1.697835563282687e-05, 'epoch': 0.45}
{'loss': 4.039, 'learning_rate': 1.672638395444352e-05, 'epoch': 0.49}
{'loss': 4.028, 'learning_rate': 1.6474412276060173e-05, 'epoch'

  0%|          | 0/1470 [00:00<?, ?it/s]

{'eval_loss': 3.735496759414673, 'eval_runtime': 13.2368, 'eval_samples_per_second': 888.436, 'eval_steps_per_second': 111.054, 'epoch': 1.0}
{'loss': 3.8386, 'learning_rate': 1.3200292287146925e-05, 'epoch': 1.02}
{'loss': 3.777, 'learning_rate': 1.2948320608763576e-05, 'epoch': 1.06}
{'loss': 3.79, 'learning_rate': 1.2696348930380227e-05, 'epoch': 1.1}
{'loss': 3.8055, 'learning_rate': 1.2444881195353644e-05, 'epoch': 1.13}
{'loss': 3.7811, 'learning_rate': 1.2192909516970292e-05, 'epoch': 1.17}
{'loss': 3.7469, 'learning_rate': 1.1940937838586944e-05, 'epoch': 1.21}
{'loss': 3.7534, 'learning_rate': 1.1688966160203595e-05, 'epoch': 1.25}
{'loss': 3.7403, 'learning_rate': 1.1437498425177012e-05, 'epoch': 1.29}
{'loss': 3.7467, 'learning_rate': 1.1185526746793662e-05, 'epoch': 1.32}
{'loss': 3.732, 'learning_rate': 1.0934059011767077e-05, 'epoch': 1.36}
{'loss': 3.7481, 'learning_rate': 1.0682087333383728e-05, 'epoch': 1.4}
{'loss': 3.7348, 'learning_rate': 1.043011565500038e-05, 'epo

  0%|          | 0/1470 [00:00<?, ?it/s]

{'eval_loss': 3.626753807067871, 'eval_runtime': 13.4091, 'eval_samples_per_second': 877.019, 'eval_steps_per_second': 109.627, 'epoch': 2.0}
{'loss': 3.6955, 'learning_rate': 6.6520523093204325e-06, 'epoch': 2.0}
{'loss': 3.6218, 'learning_rate': 6.400080630937083e-06, 'epoch': 2.04}
{'loss': 3.6337, 'learning_rate': 6.148108952553733e-06, 'epoch': 2.08}
{'loss': 3.6108, 'learning_rate': 5.896137274170384e-06, 'epoch': 2.12}
{'loss': 3.6299, 'learning_rate': 5.644669539143801e-06, 'epoch': 2.15}
{'loss': 3.6417, 'learning_rate': 5.3926978607604504e-06, 'epoch': 2.19}
{'loss': 3.6116, 'learning_rate': 5.140726182377102e-06, 'epoch': 2.23}
{'loss': 3.6321, 'learning_rate': 4.888754503993752e-06, 'epoch': 2.27}
{'loss': 3.6034, 'learning_rate': 4.637286768967168e-06, 'epoch': 2.31}
{'loss': 3.6104, 'learning_rate': 4.3858190339405855e-06, 'epoch': 2.34}
{'loss': 3.6213, 'learning_rate': 4.133847355557236e-06, 'epoch': 2.38}
{'loss': 3.6167, 'learning_rate': 3.8818756771738855e-06, 'epoch

  0%|          | 0/1470 [00:00<?, ?it/s]

{'eval_loss': 3.5998127460479736, 'eval_runtime': 13.135, 'eval_samples_per_second': 895.318, 'eval_steps_per_second': 111.915, 'epoch': 3.0}
{'train_runtime': 2525.5688, 'train_samples_per_second': 125.713, 'train_steps_per_second': 15.714, 'train_loss': 3.8243711187380116, 'epoch': 3.0}


TrainOutput(global_step=39687, training_loss=3.8243711187380116, metrics={'train_runtime': 2525.5688, 'train_samples_per_second': 125.713, 'train_steps_per_second': 15.714, 'train_loss': 3.8243711187380116, 'epoch': 3.0})