In [None]:
from datasets import load_dataset, DatasetDict, ClassLabel, Dataset
from src.utils import map_category

# Load the auxiliary arXiv dataset by its repository id. Later cells read the
# "title", "abstract" and "primary_category" columns of its "train" split.
AUX_DATASET_ID = "real-jiakai/arxiver-with-category"
aux_data = load_dataset(AUX_DATASET_ID)


In [None]:
# Build the continued-pretraining corpus: one "title\nabstract" string per paper
# plus the label that map_category derives from its primary arXiv category.
#
# `to_pandas()` is used instead of `aux_data.set_format(type="pandas")` so the
# loaded DatasetDict is not mutated as a side effect of preparing this frame.
aux_df = (
    aux_data["train"]
    .to_pandas()
    .assign(
        label=lambda df: df["primary_category"].apply(map_category),
        # Collapse "\n  " sequences in titles (presumably line-wrap artifacts)
        # to a single space. regex=False keeps this a literal match regardless
        # of the pandas version's default for `regex`.
        title=lambda df: df["title"].str.replace("\n  ", " ", regex=False),
        # assign() evaluates keywords in order, so this uses the cleaned title.
        text=lambda df: df["title"] + "\n" + df["abstract"],
    )
    .loc[:, ["text", "label"]]
)

# Fail fast: a missing title or abstract makes the concatenated text NaN, which
# is not usable as tokenizer input further down the notebook.
n_missing_text = int(aux_df["text"].isna().sum())
if n_missing_text:
    raise ValueError(
        f"{n_missing_text} rows have a missing title or abstract; "
        "drop or fill them before building the pretraining corpus."
    )


In [None]:
from sklearn.model_selection import train_test_split

# Hold out 10% of the corpus for validation. Stratifying on the label keeps the
# class proportions aligned between splits; the fixed seed makes it repeatable.
train_df, val_df = train_test_split(
    aux_df,
    stratify=aux_df["label"],
    test_size=0.1,
    random_state=42,
)

# Class names are taken from the training split and sorted so that the
# name -> integer id mapping does not depend on row order.
labels = sorted(train_df["label"].unique())
class_label = ClassLabel(names=labels)

# Convert each frame to a Dataset (without the pandas index), then encode the
# "label" column with the ClassLabel feature built above.
split_frames = {"train": train_df, "validation": val_df}
cpt_data = DatasetDict(
    {
        split_name: Dataset.from_pandas(frame, preserve_index=False)
        for split_name, frame in split_frames.items()
    }
).cast_column("label", class_label)

# Persist the prepared splits.
cpt_data.save_to_disk("data/processed/cpt_data")


In [None]:
from transformers import DistilBertForMaskedLM, DistilBertTokenizer, DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments, EarlyStoppingCallback

# Start from the public DistilBERT checkpoint, loaded with its masked-LM head,
# and the matching tokenizer.
model = DistilBertForMaskedLM.from_pretrained("distilbert-base-uncased")
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

def tokenize(batch):
    """Tokenize a batch of examples for masked-language-model training.

    Args:
        batch: Mapping with a "text" key holding a list of strings, as supplied
            by ``Dataset.map(..., batched=True)``.

    Returns:
        The tokenizer's encoding of ``batch["text"]``, truncated to at most
        512 tokens per example and left unpadded.
    """
    # Deliberately no padding="max_length": padding every example to 512 tokens
    # here inflates the stored dataset and every batch. The language-modeling
    # data collator pads each batch to its own longest sequence instead.
    return tokenizer(batch["text"], truncation=True, max_length=512)

# The integer "label" column is not needed for MLM pretraining, so drop it.
tokenized_cpt_data = cpt_data.map(tokenize, batched=True, remove_columns=["label"])


In [None]:
# Collator that builds MLM batches, masking tokens with probability 0.15.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,
    mlm_probability=0.15,
)

# Training configuration for continued pretraining.
training_args = TrainingArguments(
    # Where checkpoints and logs go.
    output_dir="./models/distilbert-base-uncased-cpt-arxiv",
    overwrite_output_dir=True,
    logging_dir="logs",
    logging_steps=500,
    # Optimization schedule.
    num_train_epochs=10,
    learning_rate=2e-5,
    weight_decay=0.01,
    warmup_steps=1000,
    # Batching and throughput.
    per_device_train_batch_size=32,
    per_device_eval_batch_size=64,
    gradient_accumulation_steps=1,
    fp16=True,
    dataloader_num_workers=4,
    # Evaluate and checkpoint on the same 500-step cadence, keep only the two
    # most recent checkpoints, and reload the best one (by loss) when done.
    eval_strategy="steps",
    eval_steps=500,
    save_steps=500,
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="loss",
)

# Stop early once the tracked metric has failed to improve for 3 evaluations.
early_stopping = EarlyStoppingCallback(early_stopping_patience=3)

# Wire model, data, collator and callback together.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=tokenized_cpt_data["train"],
    eval_dataset=tokenized_cpt_data["validation"],
    callbacks=[early_stopping],
)

# Run continued pretraining.
trainer.train()


In [None]:
# Export the domain-adapted model and its tokenizer into the same directory so
# the pair can be reloaded together from that one path.
adapted_model_dir = "distilbert-arxiv-domain-adapted"
model.save_pretrained(adapted_model_dir)
tokenizer.save_pretrained(adapted_model_dir)
