In [1]:
import torch
from datasets import load_dataset
from transformers import DataCollatorWithPadding
from transformers import Trainer, TrainingArguments

import asag_system.constants as c
from asag_system.models import (
    DistilBertTripletTokenizer,
    SentenceTripletClassifier,
    MostFrequentBaseline,
    compute_metrics,
)
from asag_system.datasets import TripletClassificationDataset

In [2]:
dataset = load_dataset("Atomi/semeval_2013_task_7_beetle_5way")
test = dataset["test"]
unseen_answers = test.filter(lambda example: example['test_set'] == 'unseen-answers')
unseen_questions = test.filter(lambda example: example['test_set'] == 'unseen-questions')
assert len(unseen_answers) + len(unseen_questions) == len(test)

In [3]:
tokenizer = DistilBertTripletTokenizer()
unseen_answers_dataset = TripletClassificationDataset(unseen_answers, tokenizer)
unseen_questions_dataset = TripletClassificationDataset(unseen_questions, tokenizer)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer.tokenizer)
baseline = MostFrequentBaseline()
model = SentenceTripletClassifier()
model_path = c.DATA_DIR / "models" / "sentence-triplet-classifier.pt"
model.load_state_dict(torch.load(model_path))

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


<All keys matched successfully>

### DistilBert ASAG Model

In [4]:
training_args = TrainingArguments(output_dir=c.DATA_DIR)

In [19]:
trainer = Trainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,
    data_collator=data_collator,
)

In [20]:
# Unsween Answers
trainer.evaluate(unseen_answers_dataset)

  0%|          | 0/70 [00:00<?, ?it/s]

{'eval_loss': 1.0527443885803223,
 'eval_macro_f1': 0.6683717803011867,
 'eval_accuracy': 0.7285714285714285,
 'eval_runtime': 2.1523,
 'eval_samples_per_second': 260.188,
 'eval_steps_per_second': 32.524}

In [21]:
# Unseen Questions
trainer.evaluate(unseen_questions_dataset)

  0%|          | 0/117 [00:00<?, ?it/s]

{'eval_loss': 1.3006949424743652,
 'eval_macro_f1': 0.621826324995632,
 'eval_accuracy': 0.6512378902045209,
 'eval_runtime': 2.8966,
 'eval_samples_per_second': 320.719,
 'eval_steps_per_second': 40.392}

### Most Frequent Baseline Model

In [16]:
# Unseen Answers
trainer.model = baseline
trainer.evaluate(unseen_answers_dataset)

  0%|          | 0/70 [00:00<?, ?it/s]

{'eval_loss': 1.5012609958648682,
 'eval_macro_f1': 0.11501272264631043,
 'eval_accuracy': 0.4035714285714286,
 'eval_runtime': 0.4474,
 'eval_samples_per_second': 1251.733,
 'eval_steps_per_second': 156.467}

In [17]:
# Unseen Questions
trainer.evaluate(unseen_answers_dataset)

  0%|          | 0/70 [00:00<?, ?it/s]

{'eval_loss': 1.5012609958648682,
 'eval_macro_f1': 0.11501272264631043,
 'eval_accuracy': 0.4035714285714286,
 'eval_runtime': 0.4269,
 'eval_samples_per_second': 1311.715,
 'eval_steps_per_second': 163.964}