In [None]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments, EvalPrediction
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
import torch
import pandas as pd
import numpy as np
from Model import Model

In [None]:
# Checkpoint identifier handed to the Model wrapper; it is also passed to
# PerformanceSaver.save_performance at the end of the notebook.
# NOTE(review): this names ELECTRA's *generator* checkpoint; the ELECTRA paper
# fine-tunes the discriminator on downstream tasks — confirm this is intentional.
model_name = "google/electra-base-generator"
# `Model` comes from the `Model` module imported above; this notebook uses its
# `data_tensor(...)` and `get_model()` methods further down.
model = Model(model_name=model_name)

In [None]:
# Read the three dataset splits (train / dev / test) from CSV, in that order.
train_data, dev_data, test_data = (
    pd.read_csv(csv_path)
    for csv_path in (
        'data/edu_train.csv',
        'data/edu_dev.csv',
        'data/edu_test.csv',
    )
)

In [None]:
# Hyperparameters and evaluation/checkpoint cadence handed to the Trainer.
training_args = TrainingArguments(
    # Output locations for checkpoints and logs.
    output_dir='./results',
    logging_dir='./logs',
    # Training length and batch sizes.
    num_train_epochs=4,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    # Evaluation and checkpointing share the same 200-step interval; at most
    # two checkpoints are kept and the best one is reloaded when training ends.
    evaluation_strategy="steps",
    eval_steps=200,
    save_steps=200,
    save_total_limit=2,
    load_best_model_at_end=True,
    # Logging interval, in steps.
    logging_steps=50,
)

def compute_metrics(pred):
    """Compute classification metrics for one evaluation pass.

    Args:
        pred: Object exposing ``predictions`` (per-class scores with classes on
            axis 1, or a tuple whose first element is that array) and
            ``label_ids`` (the reference class indices).

    Returns:
        dict with keys ``'accuracy'``, ``'f1'``, ``'precision'`` and
        ``'recall'``. F1, precision and recall are all macro-averaged.
    """
    predictions, labels = pred.predictions, pred.label_ids
    # If the model returns extra outputs, predictions arrive as a tuple; take
    # the first element, which by Hugging Face convention holds the logits.
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    predicted_classes = np.argmax(predictions, axis=1)

    accuracy = accuracy_score(labels, predicted_classes)
    # Macro-average F1 like precision/recall: for single-label multiclass data
    # micro-averaged F1 is numerically identical to accuracy, so reporting it
    # added no information and was inconsistent with the other two metrics.
    # zero_division=0 keeps the default score (0) for classes that are never
    # predicted, but without emitting a warning on every evaluation step.
    f1 = f1_score(labels, predicted_classes, average='macro', zero_division=0)
    precision = precision_score(labels, predicted_classes, average='macro', zero_division=0)
    recall = recall_score(labels, predicted_classes, average='macro', zero_division=0)

    return {
        'accuracy': accuracy,
        'f1': f1,
        'precision': precision,
        'recall': recall
        }

In [None]:
# Convert each split with model.data_tensor, which is unpacked here into
# (input ids, attention mask, labels) per split.
train_input_ids, train_attention_mask, train_labels_tensor = model.data_tensor(train_data)
# The dev tensors must be built from dev_data: building them from train_data
# would make every in-training evaluation (and best-checkpoint selection via
# load_best_model_at_end) score the model on the data it was trained on.
dev_input_ids, dev_attention_mask, dev_labels_tensor = model.data_tensor(dev_data)
test_input_ids, test_attention_mask, test_labels_tensor = model.data_tensor(test_data)

In [None]:
def collate_batch(batch):
    """Stack a list of (input_ids, attention_mask, labels) items into one batch dict."""
    input_ids, attention_mask, labels = zip(*batch)
    return {
        'input_ids': torch.stack(input_ids),
        'attention_mask': torch.stack(attention_mask),
        'labels': torch.stack(labels),
    }


# Wrap the pre-built tensors so each dataset item is an
# (input_ids, attention_mask, labels) tuple, as collate_batch expects.
train_dataset = torch.utils.data.TensorDataset(
    train_input_ids, train_attention_mask, train_labels_tensor
)
eval_dataset = torch.utils.data.TensorDataset(
    dev_input_ids, dev_attention_mask, dev_labels_tensor
)

trainer = Trainer(
    model=model.get_model(),
    args=training_args,
    compute_metrics=compute_metrics,
    data_collator=collate_batch,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

In [None]:
# Run fine-tuning using the Trainer configured in the previous cell.
trainer.train()

In [None]:
# Score the trained model on the held-out test split.
test_dataset = torch.utils.data.TensorDataset(
    test_input_ids, test_attention_mask, test_labels_tensor
)
results = trainer.evaluate(eval_dataset=test_dataset)

In [None]:
# Show the metrics returned by trainer.evaluate for the test split.
print(results)

In [None]:
from performance import PerformanceSaver

# Hand the test-set metrics, keyed by the checkpoint name, to PerformanceSaver.
performance_saver = PerformanceSaver()
performance_saver.save_performance(model_name=model_name, results=results)