In [1]:
import copy

import numpy as np
import pandas as pd
import torch
from datasets import load_from_disk
from sklearn.metrics import roc_auc_score
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer, EarlyStoppingCallback

# Seed handed to the training configuration so runs are repeatable.
RANDOM_SEED = 31415

# Create training and validation subsets

In [2]:
def get_fold_datasets(ds, fold):
    """Split ``ds`` into a (train, validation) pair for one cross-validation fold.

    Args:
        ds: Dataset-like object exposing ``filter(predicate)``; each row must
            provide a ``"fold"`` entry.
        fold: Fold identifier to hold out for validation.

    Returns:
        Tuple ``(ds_train, ds_val)``: rows whose ``"fold"`` differs from
        ``fold``, followed by rows whose ``"fold"`` equals ``fold``.
    """
    def belongs_to_training(row):
        return row["fold"] != fold

    def belongs_to_validation(row):
        return row["fold"] == fold

    # Tuple elements are evaluated left to right: training subset first.
    return ds.filter(belongs_to_training), ds.filter(belongs_to_validation)

In [3]:
# Load the already-tokenized dataset from disk. Its "fold" column is what
# get_fold_datasets() filters on to build the train/validation subsets.
ds_tokenized = load_from_disk("data/processed_data/ds_tokenized")
# Bare expression: displays the dataset summary (features, row count) as cell output.
ds_tokenized

Dataset({
    features: ['id', 'fold', 'labels', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 159571
})

# Model

In [6]:
# Pretrained backbone to fine-tune. "microsoft/deberta-v3-large" was the other
# checkpoint considered here.
model_name = "microsoft/deberta-v3-base"

# Sequence classifier with 6 output labels, configured as a multi-label problem.
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=6,
    problem_type="multi_label_classification",
)
# Place the weights on the GPU.
model = model.to("cuda")

Some weights of DebertaV2ForSequenceClassification were not initialized from the model checkpoint at microsoft/deberta-v3-base and are newly initialized: ['classifier.bias', 'classifier.weight', 'pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
# model.switch_to_sdp(SdpaConfig("flash_attention_2"))

In [8]:
# model = torch.compile(model)

In [9]:
def compute_metrics(eval_pred):
    """Compute the macro-averaged ROC AUC from raw classifier logits.

    Args:
        eval_pred: Pair ``(logits, labels)``. ``logits`` are the raw,
            pre-sigmoid scores as an array-like; ``labels`` are the matching
            binary targets.

    Returns:
        Dict with a single key ``"roc_auc_macro"`` holding the score.

    Raises:
        Whatever ``roc_auc_score`` raises for invalid input.
    """
    logits, labels = eval_pred

    # Work in float64: low-precision logits (e.g. float16 from mixed-precision
    # evaluation) make a naive sigmoid overflow and collapse many distinct
    # scores into ties, which distorts a ranking metric such as ROC AUC.
    logits = np.asarray(logits, dtype=np.float64)

    # Numerically stable sigmoid: sigmoid(x) = exp(-log(1 + exp(-x))).
    # logaddexp(0, -x) evaluates log(1 + exp(-x)) without overflowing.
    probs = np.exp(-np.logaddexp(0.0, -logits))

    auc = roc_auc_score(labels, probs, average="macro")
    return {"roc_auc_macro": auc}

In [10]:
# args = TrainingArguments(
#     output_dir="checkpoints/deberta_fold0",
#     num_train_epochs=3,
#     per_device_train_batch_size=8,
#     per_device_eval_batch_size=4,
#     learning_rate=1e-5,
#     weight_decay=0.01,
#     evaluation_strategy="epoch",
#     save_strategy="epoch",
#     metric_for_best_model="roc_auc_macro",
#     load_best_model_at_end=True,
#     fp16=True,
#     gradient_checkpointing=False,
#     logging_steps=50,
#     seed=RANDOM_SEED
# )

In [11]:
# Training configuration, grouped by concern.
args = TrainingArguments(
    # Output location and reproducibility.
    output_dir="checkpoints/deberta_fold0",
    seed=RANDOM_SEED,
    # Optimisation: batches of 8 per device, accumulated over 2 steps.
    num_train_epochs=3,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,
    learning_rate=1e-5,
    weight_decay=0.01,
    optim="adamw_torch_fused",
    fp16=True,
    gradient_checkpointing=False,
    # Evaluate and checkpoint on the same 2000-step cadence; the best model is
    # selected by the "roc_auc_macro" key that compute_metrics() returns.
    evaluation_strategy="steps",
    eval_steps=2000,
    per_device_eval_batch_size=8,
    save_strategy="steps",
    save_steps=2000,
    load_best_model_at_end=True,
    metric_for_best_model="roc_auc_macro",
    # Data loading and logging.
    dataloader_num_workers=2,
    dataloader_pin_memory=True,
    logging_steps=50,
)



In [None]:
# Train one model per held-out fold.
# NOTE: range(1, 5) keeps the original fold selection (1-4); fold 0 is not
# trained by this loop.
for fold in range(1, 5):
    train, val = get_fold_datasets(ds_tokenized, fold)
    print(f"Fold {fold}: Train size: {len(train)} ({len(train)/len(ds_tokenized):.2%}), Val size: {len(val)} ({len(val)/len(ds_tokenized):.2%})")

    # Start every fold from fresh pretrained weights. Reusing one model object
    # across folds would carry over weights already fitted on rows that the
    # next fold holds out for validation, inflating its validation score.
    # Rebinding `model` also drops the reference to the previous fold's weights.
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name, num_labels=6, problem_type="multi_label_classification"
    ).to("cuda")

    # Per-fold shallow copy of the shared configuration, so each fold writes
    # its checkpoints and logs to its own directory and `args` stays untouched.
    fold_output_dir = f"checkpoints/deberta_fold{fold}"
    fold_args = copy.copy(args)
    fold_args.output_dir = fold_output_dir
    fold_args.logging_dir = f"{fold_output_dir}/runs"

    trainer = Trainer(
        model=model,
        args=fold_args,
        train_dataset=train,
        eval_dataset=val,
        compute_metrics=compute_metrics,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
    )
    trainer.train()

    # Persist this fold's final model before the next iteration rebinds `trainer`.
    trainer.save_model(f"{fold_output_dir}/model_final")

In [13]:
# NOTE(review): this trains whichever `trainer` is currently bound. If the fold
# loop in the previous cell has just run, that is the last fold's trainer, so
# this call would train it a second time -- confirm this cell is still wanted.
trainer.train()

Step,Training Loss,Validation Loss,Roc Auc Macro
2000,0.042,0.042506,0.979005
4000,0.0424,0.043684,0.982231
6000,0.0358,0.039424,0.982926
8000,0.039,0.040464,0.988242
10000,0.0345,0.03966,0.988469
12000,0.0346,0.040457,0.989537
14000,0.0385,0.040989,0.990535
16000,0.0276,0.038374,0.990595
18000,0.0305,0.03947,0.990515


TrainOutput(global_step=18000, training_loss=0.04024555553992589, metrics={'train_runtime': 9782.4889, 'train_samples_per_second': 39.148, 'train_steps_per_second': 2.447, 'total_flos': 3.7890032173056e+16, 'train_loss': 0.04024555553992589, 'epoch': 2.2560631697687534})

In [14]:
# Save next to the checkpoints of the trainer that is actually bound, rather
# than under a hard-coded fold-0 path that would mislabel (or overwrite) the
# model when `trainer` belongs to a different fold.
trainer.save_model(f"{trainer.args.output_dir}/model_final")