In [1]:
import numpy as np
import pandas as pd
import os

from transformers import AutoTokenizer, AutoConfig, AutoModelForSequenceClassification, TrainingArguments, Trainer, ModernBertForSequenceClassification
from datasets import Dataset

import evaluate

metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    # convert the logits to their predicted class
    predictions = np.argmax(logits, axis=-1)
    
    print(f"Avg pred: {sum(predictions)/len(predictions)}, Avg Labels: {sum(labels)/len(labels)}")
    
    return metric.compute(predictions=predictions, references=labels)

df = pd.read_csv( "../data/pairwise-model-data/pairedwise-tweets.tsv", sep = '\t')
df = pd.DataFrame(df)

dataset = Dataset.from_pandas(df)
dataset.shuffle()
split = dataset.train_test_split(test_size=0.2)
train_dataset = split["train"].remove_columns('Unnamed: 0')
eval_dataset = split["test"].remove_columns('Unnamed: 0')

In [2]:
def preprocess_function(examples):
    full_str = "[CLS] "
    full_str += examples["query"]
    
    for key in ["paper1", "paper2"]:
        if examples[key] != None:
            full_str += " [SEP] " + examples[key]
    
    return tokenizer(full_str, truncation=True).to('cuda')


id2label = {0: "0", 1: "1"}
label2id = {"0": 0, "1": 1}

tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
model = ModernBertForSequenceClassification.from_pretrained("answerdotai/ModernBERT-base", num_labels=2, id2label=id2label, label2id=label2id)

Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at answerdotai/ModernBERT-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [3]:
train_dataset_tokenized = train_dataset.map(preprocess_function)

Map:   0%|          | 0/9881 [00:00<?, ? examples/s]

In [4]:
test_dataset_tokenized = eval_dataset.map(preprocess_function)

Map:   0%|          | 0/2471 [00:00<?, ? examples/s]

In [5]:

training_args = TrainingArguments(
    output_dir="../models/pairwise-classifier",
    learning_rate=2e-6,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    num_train_epochs=5,
    weight_decay=0.01,
    eval_strategy="steps",
    eval_steps=1200,
    save_strategy="steps",
    save_steps=1200,
    load_best_model_at_end=True,
    push_to_hub=False,
    fp16=True,
    optim="adamw_torch",
    save_total_limit=2,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset_tokenized,
    eval_dataset=test_dataset_tokenized,
    compute_metrics=compute_metrics
)

# trainer.evaluate()
trainer.train(resume_from_checkpoint=True)

W0504 21:42:14.893000 1736 torch/_inductor/utils.py:1137] [1/0] Not enough SMs to use max_autotune_gemm mode


Step,Training Loss,Validation Loss,Accuracy
9600,1.4889,1.27751,0.478754
10800,1.3121,1.006303,0.478754
12000,0.824,0.878566,0.478754
13200,0.8806,0.750115,0.521246
14400,1.1376,1.571779,0.477944
15600,1.2259,0.950029,0.518414
16800,0.9369,0.754169,0.478754
18000,0.747,0.752941,0.521246
19200,1.13,0.704891,0.477539
20400,1.0312,2.015591,0.544314


Avg pred: 0.0, Avg Labels: 0.5212464589235127
Avg pred: 0.0, Avg Labels: 0.5212464589235127
Avg pred: 0.0, Avg Labels: 0.5212464589235127
Avg pred: 1.0, Avg Labels: 0.5212464589235127
Avg pred: 0.0056657223796034, Avg Labels: 0.5212464589235127
Avg pred: 0.8312424119789559, Avg Labels: 0.5212464589235127
Avg pred: 0.0, Avg Labels: 0.5212464589235127
Avg pred: 1.0, Avg Labels: 0.5212464589235127
Avg pred: 0.0012140833670578712, Avg Labels: 0.5212464589235127
Avg pred: 0.9672197490894374, Avg Labels: 0.5212464589235127
Avg pred: 0.6143261837312829, Avg Labels: 0.5212464589235127
Avg pred: 0.5548360987454471, Avg Labels: 0.5212464589235127
Avg pred: 0.49979765277215704, Avg Labels: 0.5212464589235127
Avg pred: 0.47956292998785915, Avg Labels: 0.5212464589235127
Avg pred: 0.5236746256576285, Avg Labels: 0.5212464589235127
Avg pred: 0.5394577094293809, Avg Labels: 0.5212464589235127
Avg pred: 0.5050586806960745, Avg Labels: 0.5212464589235127
Avg pred: 0.49736948603804126, Avg Labels: 0.521

TrainOutput(global_step=49405, training_loss=0.5742320753811173, metrics={'train_runtime': 11170.0766, 'train_samples_per_second': 4.423, 'train_steps_per_second': 4.423, 'total_flos': 3.190023543444822e+16, 'train_loss': 0.5742320753811173, 'epoch': 5.0})