In [None]:
import pandas as pd
import numpy as np
from tqdm import tqdm

import torch
from torchmetrics import ConfusionMatrix, Accuracy, Precision, Recall, F1

from transformers import AutoTokenizer, AutoModelForSequenceClassification

import data_prep 

In [None]:
# data_prep.load_data() is unpacked into three values; only the third one
# (binary_df) is kept. Later cells read its 'sentence' column as the model input
# and its 'formal' column as the gold label.
_, _, binary_df = data_prep.load_data()

In [None]:
# Checkpoint identifier passed to both from_pretrained() calls, so the tokenizer
# and the classifier are guaranteed to come from the same checkpoint.
model_name = 's-nlp/deberta-large-formality-ranker'
# Tokenizer that turns raw sentences into the tensors the model consumes.
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Sequence-classification model whose logits are interpreted in the inference cell.
model = AutoModelForSequenceClassification.from_pretrained(model_name)

In [None]:
# Output-index -> label mapping of the classifier head
# (from the model documentation on Hugging Face).
id2formality = {0: "formal", 1: "informal"}

# Resolve the two logit columns once, instead of rebuilding a score dict per sentence.
formality2id = {label: idx for idx, label in id2formality.items()}
formal_idx = formality2id["formal"]
informal_idx = formality2id["informal"]

batch_size = 4

# Predictions use the DATASET convention: 1 = formal, 0 = informal.
# Note this is the opposite of the model's own output indices in id2formality,
# which is why the logits are compared by name rather than taking a plain argmax.
predicted_labels = []

# Make sure dropout and similar layers are in inference mode.
model.eval()

for i in tqdm(range(0, len(binary_df), batch_size)):
    # .iloc makes the slice explicitly positional, whatever index binary_df carries.
    texts = binary_df['sentence'].iloc[i:i + batch_size].tolist()

    # Prepare the input. Pad only to the longest sentence in the batch: padding to
    # "max_length" would blow every 4-sentence batch up to the model's maximum
    # sequence length and waste most of the forward pass on padding tokens.
    encoding = tokenizer(
        texts,
        add_special_tokens=True,
        return_token_type_ids=True,
        truncation=True,
        padding=True,
        return_tensors="pt",
    )

    # Inference only: disable autograd so no computation graph is kept per batch.
    with torch.no_grad():
        output = model(**encoding)

    # Softmax is monotonic, so comparing raw logits gives the same decision as
    # comparing probabilities. Ties fall to "informal" (0), as with a strict '>'.
    logits = output.logits
    is_formal = logits[:, formal_idx] > logits[:, informal_idx]
    batch_predicted_labels = is_formal.long().tolist()

    predicted_labels.extend(batch_predicted_labels)

In [None]:
# Evaluate the predictions against the gold labels in binary_df['formal'].
# Both use the same convention (1 = formal, 0 = informal); "formal" is treated
# as the positive class for precision / recall / F1.

# Every sentence must have exactly one prediction, otherwise the metrics are meaningless.
assert len(predicted_labels) == len(binary_df), "prediction/label count mismatch"

# Build the tensors once and reuse them for every metric.
preds = torch.tensor(predicted_labels)
target = torch.tensor(binary_df['formal'].values)

# Initialize metrics
conf_matrix_metric = ConfusionMatrix(num_classes=2)
accuracy_metric = Accuracy()

# Compute metrics
conf_matrix = conf_matrix_metric(preds, target)
accuracy = accuracy_metric(preds, target)

# Precision / recall / F1 are computed explicitly for the positive class.
# With integer 0/1 predictions, the default-constructed torchmetrics
# Precision()/Recall()/F1() treat the data as multi-class with micro averaging,
# which makes all three collapse to the accuracy value.
true_pos = ((preds == 1) & (target == 1)).sum()
false_pos = ((preds == 1) & (target == 0)).sum()
false_neg = ((preds == 0) & (target == 1)).sum()


def _safe_ratio(numerator, denominator):
    """Return numerator / denominator as a float tensor, or 0.0 if the denominator is zero."""
    if denominator > 0:
        return numerator.float() / denominator
    return torch.tensor(0.0)


precision = _safe_ratio(true_pos, true_pos + false_pos)
recall = _safe_ratio(true_pos, true_pos + false_neg)
f1 = _safe_ratio(2 * true_pos, 2 * true_pos + false_pos + false_neg)

print("Confusion Matrix:")
print(conf_matrix.numpy())
print("Accuracy:", accuracy.item())
print("Precision:", precision.item())
print("Recall:", recall.item())
print("F1 Score:", f1.item())
