Attempt to fine-tune SciBERT (`allenai/scibert_scivocab_uncased`) for multiclass classification of arXiv papers:
- Prepare dataset
- Tokenize
- Load model
- Define metrics
- Define Trainer object and TrainingArguments
- Evaluate predictions
- Error analysis


In [None]:
import pandas as pd
import numpy as np
from datasets import load_from_disk

# Load the preprocessed dataset saved earlier with `save_to_disk`; it is indexed
# by split name ("train", "validation", "test") further down.
all_stream_data = load_from_disk(dataset_path="data/processed/all_stream_data")


In [None]:
# Inspect the loaded dataset splits before they are merged in the next cell.
print(all_stream_data)


In [None]:
from datasets import concatenate_datasets

# Merge every split, in train -> validation -> test order, into one dataset;
# it is used below as the sole training set for the final model.
full_dataset = concatenate_datasets(
    [all_stream_data[split] for split in ("train", "validation", "test")]
)


In [None]:
from transformers import AutoTokenizer
import torch

model_id = "allenai/scibert_scivocab_uncased"
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(model_id)

def tokenize(batch, max_length=256):
    """Tokenize the ``text`` column of a batch, truncating to ``max_length`` tokens.

    No padding is applied here. The Trainer pads each mini-batch dynamically
    (it wraps ``processing_class=tokenizer`` in a ``DataCollatorWithPadding``),
    so padding the whole dataset up front would only add compute spent on pad
    tokens.

    Args:
        batch: mapping with a ``"text"`` entry holding a list of strings
            (as passed by ``Dataset.map(..., batched=True)``).
        max_length: maximum number of tokens kept per example.

    Returns:
        The tokenizer's encoding (``input_ids``, ``attention_mask``, ...).
    """
    return tokenizer(batch["text"], truncation=True, max_length=max_length)


In [None]:
# Tokenize everything in a single batched call (batch_size=None -> one batch).
full_dataset = full_dataset.map(function=tokenize, batched=True, batch_size=None)


In [None]:
# Class names, in label-id order, read from the dataset's "label" feature.
label_feature = full_dataset.features["label"]
labels = label_feature.names
labels


In [None]:
from transformers import AutoModelForSequenceClassification

num_labels = len(labels)
# Record the real class names in the model config. The model is pushed to the
# Hub (push_to_hub=True in the TrainingArguments cell), and without these maps
# its config only carries generic LABEL_<i> names.
id2label = dict(enumerate(labels))
label2id = {name: idx for idx, name in id2label.items()}
model = (AutoModelForSequenceClassification
        .from_pretrained(model_id,
                         num_labels=num_labels,
                         id2label=id2label,
                         label2id=label2id)
        .to(device))


In [None]:
from huggingface_hub import notebook_login
# Interactive Hugging Face login; required because the TrainingArguments cell
# sets push_to_hub=True.
notebook_login()


In [None]:
from transformers import Trainer, TrainingArguments, EarlyStoppingCallback
from sklearn.metrics import balanced_accuracy_score, f1_score

batch_size = 32
num_train_epochs = 2
learning_rate = 1e-5
# Log roughly once per epoch. Clamp to 1: under the default "steps" logging
# strategy TrainingArguments rejects logging_steps=0, which a dataset smaller
# than one batch would otherwise produce.
logging_steps = max(1, len(full_dataset) // batch_size)
model_name = "./models/scibert-finetuned-arxiv-final"
training_args = TrainingArguments(output_dir=model_name,
                                  overwrite_output_dir=True,
                                  num_train_epochs=num_train_epochs,
                                  learning_rate=learning_rate,
                                  per_device_train_batch_size=batch_size,
                                  weight_decay=0.01,
                                  warmup_steps=300,
                                  # fp16 AMP needs a GPU; enabling it unconditionally
                                  # makes TrainingArguments fail on the CPU fallback.
                                  fp16=torch.cuda.is_available(),
                                  eval_strategy="no",
                                  save_steps=300,
                                  save_strategy="steps",
                                  save_total_limit=5,
                                  disable_tqdm=False,
                                  logging_steps=logging_steps,
                                  # The best-model settings below only matter once
                                  # evaluation is enabled; inert with eval_strategy="no".
                                  load_best_model_at_end=False,
                                  metric_for_best_model="macro_f1",
                                  greater_is_better=True,
                                  push_to_hub=True,
                                  # Applied by Trainer's built-in loss; a subclass that
                                  # overrides compute_loss must apply it itself.
                                  label_smoothing_factor=0.1,
                                  log_level="error")

def compute_metrics(pred):
    """Score predictions with class-imbalance-aware metrics.

    Args:
        pred: evaluation output exposing ``label_ids`` (gold class ids) and
            ``predictions`` (per-class scores, class axis last).

    Returns:
        dict with ``"macro_f1"`` and ``"balanced_accuracy"``.
    """
    y_true = pred.label_ids
    y_pred = pred.predictions.argmax(axis=-1)
    return {
        "macro_f1": f1_score(y_true, y_pred, average="macro"),
        "balanced_accuracy": balanced_accuracy_score(y_true, y_pred),
    }


In [None]:
from sklearn.utils.class_weight import compute_class_weight

# "Balanced" class weights (n_samples / (n_classes * class_count)), so rarer
# classes weigh more in the loss. Note: computed over full_dataset, i.e. all
# merged splits, which is also the training set here.
train_labels = np.array(full_dataset["label"])
class_weights = compute_class_weight(
    "balanced", classes=np.unique(train_labels), y=train_labels
)

# As a float32 tensor living on the same device as the model.
class_weights = torch.tensor(class_weights, dtype=torch.float32, device=model.device)

class CustomTrainer(Trainer):
    """Trainer whose training loss is class-weighted cross-entropy.

    Args:
        class_weights: optional 1-D float tensor with one weight per class,
            forwarded to ``torch.nn.CrossEntropyLoss(weight=...)``; ``None``
            gives an unweighted loss. It is the first positional parameter,
            so pass it by keyword.
        *args, **kwargs: forwarded unchanged to ``Trainer``.
    """

    def __init__(self, class_weights=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.class_weights = class_weights

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Return the weighted, optionally label-smoothed, cross-entropy loss.

        Overriding ``compute_loss`` bypasses Trainer's built-in label smoother,
        so ``args.label_smoothing_factor`` is applied here explicitly.

        Returns:
            ``loss``, or ``(loss, outputs)`` when ``return_outputs`` is true.
        """
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        # Take the class count from the logits instead of model.config: the
        # model may arrive wrapped (e.g. DataParallel), which hides .config.
        num_classes = logits.size(-1)

        weight = self.class_weights
        if weight is not None:
            # Keep the weights on the same device as the logits.
            weight = weight.to(logits.device)

        loss_fct = torch.nn.CrossEntropyLoss(
            weight=weight,
            label_smoothing=self.args.label_smoothing_factor,
        )
        loss = loss_fct(logits.view(-1, num_classes), labels.view(-1))

        return (loss, outputs) if return_outputs else loss


In [None]:
# Train on every merged split. No eval_dataset is given, so compute_metrics is
# only used by explicit evaluate()/predict() calls. EarlyStoppingCallback would
# need periodic evaluation (eval_strategy is "no"), so it is not attached.
trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=full_dataset,
    eval_dataset=None,
    processing_class=tokenizer,
    compute_metrics=compute_metrics,
    class_weights=class_weights,
)

trainer.train()


In [None]:
# Build a tokenized evaluation set from the validation split.
# NOTE(review): the model was trained on full_dataset, which concatenates the
# train, validation and test splits, so these examples were seen during
# training. The resulting metrics are optimistic, not a held-out estimate.
eval_dataset = all_stream_data["validation"].map(tokenize, batched=True, batch_size=None)
preds_output = trainer.predict(eval_dataset)


In [None]:
# Display the metrics dict that trainer.predict returned with the predictions.
preds_output.metrics


In [None]:
# Predicted class id per example: index of the highest score on the class axis.
y_preds = preds_output.predictions.argmax(axis=1)


In [None]:
from src.utils import plot_confusion_matrix

# Gold labels aligned one-to-one with y_preds (both come from preds_output).
y_valid = preds_output.label_ids
plot_confusion_matrix(y_preds, y_valid, labels)


In [None]:
from sklearn.metrics import classification_report

# Gold labels aligned one-to-one with y_preds (both come from preds_output).
y_valid = preds_output.label_ids
# Pass the full label-id range explicitly so the report still lines up with
# target_names even when some class does not occur in this evaluation set.
print(classification_report(y_valid, y_preds,
                            labels=list(range(len(labels))),
                            target_names=labels))
