In [1]:
import torch
from datasets import load_dataset
from transformers import DataCollatorWithPadding
from transformers import Trainer, TrainingArguments

import asag_system.constants as c
from asag_system.models import (
    DistilBertTripletTokenizer,
    SentenceTripletClassifier,
    compute_metrics,
)
from asag_system.datasets import TripletClassificationDataset

In [2]:
dataset = load_dataset("Atomi/semeval_2013_task_7_beetle_5way")
dev = dataset["train"]
split = dev.train_test_split(test_size=0.2, seed=42)
train = split["train"]
val = split["test"]
train, val

(Dataset({
     features: ['question_id', 'question', 'question_qtype', 'question_module', 'question_stype', 'reference_answer', 'reference_answer_quality', 'student_answer', 'label_5way', 'test_set'],
     num_rows: 8536
 }),
 Dataset({
     features: ['question_id', 'question', 'question_qtype', 'question_module', 'question_stype', 'reference_answer', 'reference_answer_quality', 'student_answer', 'label_5way', 'test_set'],
     num_rows: 2134
 }))

In [3]:
tokenizer = DistilBertTripletTokenizer()
train_dataset = TripletClassificationDataset(train, tokenizer)
val_dataset = TripletClassificationDataset(val, tokenizer)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer.tokenizer)
model = SentenceTripletClassifier()

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
training_args = TrainingArguments(
    output_dir=c.DATA_DIR / "results",
    num_train_epochs=3,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [5]:
trainer.train()

  0%|          | 0/1602 [00:00<?, ?it/s]

{'loss': 0.8693, 'grad_norm': 8.332180976867676, 'learning_rate': 3.439450686641698e-05, 'epoch': 0.94}
{'loss': 0.4796, 'grad_norm': 13.950277328491211, 'learning_rate': 1.8789013732833958e-05, 'epoch': 1.87}
{'loss': 0.2636, 'grad_norm': 8.915434837341309, 'learning_rate': 3.1835205992509364e-06, 'epoch': 2.81}
{'train_runtime': 282.5178, 'train_samples_per_second': 90.642, 'train_steps_per_second': 5.67, 'train_loss': 0.5185116685731581, 'epoch': 3.0}


TrainOutput(global_step=1602, training_loss=0.5185116685731581, metrics={'train_runtime': 282.5178, 'train_samples_per_second': 90.642, 'train_steps_per_second': 5.67, 'total_flos': 0.0, 'train_loss': 0.5185116685731581, 'epoch': 3.0})

In [6]:
# Eval Result
trainer.evaluate()

  0%|          | 0/134 [00:00<?, ?it/s]

{'eval_loss': 0.3382340371608734,
 'eval_macro_f1': 0.8924175311693443,
 'eval_accuracy': 0.8983130271790065,
 'eval_runtime': 10.583,
 'eval_samples_per_second': 201.645,
 'eval_steps_per_second': 12.662,
 'epoch': 3.0}

# Retrain on entire development set

In [7]:
dev_dataset = TripletClassificationDataset(dev, tokenizer)
model = SentenceTripletClassifier()

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [9]:
trainer.train()

  0%|          | 0/1602 [00:00<?, ?it/s]

{'loss': 0.8703, 'grad_norm': 8.002219200134277, 'learning_rate': 3.439450686641698e-05, 'epoch': 0.94}
{'loss': 0.4653, 'grad_norm': 7.217555999755859, 'learning_rate': 1.8789013732833958e-05, 'epoch': 1.87}
{'loss': 0.2461, 'grad_norm': 7.48265266418457, 'learning_rate': 3.1835205992509364e-06, 'epoch': 2.81}
{'train_runtime': 271.484, 'train_samples_per_second': 94.326, 'train_steps_per_second': 5.901, 'train_loss': 0.5072842156247105, 'epoch': 3.0}


TrainOutput(global_step=1602, training_loss=0.5072842156247105, metrics={'train_runtime': 271.484, 'train_samples_per_second': 94.326, 'train_steps_per_second': 5.901, 'total_flos': 0.0, 'train_loss': 0.5072842156247105, 'epoch': 3.0})

In [10]:
torch.save(trainer.model.state_dict(), c.DATA_DIR / "models" / "sentence-triplet-classifier.pt")