In [None]:
import numpy as np
import torch
from sklearn.metrics import classification_report
from torch.utils.data import DataLoader
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer

import src
from src.bert import training
from src.bert.dataset import PBertDataset
from src.bert.dataset.strategies import MLMin1PopIdeol

In [None]:
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Forwarded to PBertDataset.from_disk(exclude_coders=...); empty means nothing is excluded here.
EXCLUDE_CODERS = []

# Per-column decision thresholds consumed by apply_thresh (key = column index of
# the probability matrix). Column order follows the target_names used in the
# reports below: 0=elite, 1=pplcentr, 2=left, 3=right.
# NOTE(review): presumably tuned on held-out data — confirm where these values come from.
THRESHOLDS = {
    0: 0.415961,
    1: 0.295400,
    2: 0.429109,
    3: 0.302714,
}

In [None]:
# Load the tokenizer and the sequence-classification model from the same
# checkpoint id, then place the model on the configured device.
checkpoint = "luerhard/PopBERT"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint).to(DEVICE)

In [None]:
# Read the zipped test split from disk with the MLMin1PopIdeol label strategy.
test_csv = src.PATH / "data/bert/test.csv.zip"
test = PBertDataset.from_disk(
    path=test_csv,
    label_strategy=MLMin1PopIdeol(),
    exclude_coders=EXCLUDE_CODERS,
)

# The dataset builds its own collate function around the tokenizer;
# shuffle=False keeps the batches in dataset order.
collate_fn = test.create_collate_fn(tokenizer)
test_loader = DataLoader(
    test,
    batch_size=64,
    shuffle=False,
    collate_fn=collate_fn,
)

In [None]:
def apply_thresh(y_proba, thresholds: dict):
    """Binarize selected columns of a probability matrix with per-column cut-offs.

    Args:
        y_proba: 2-D NumPy array indexed as ``[:, column]``. It is copied,
            never modified.
        thresholds: Mapping of column index to cut-off value.

    Returns:
        A copy of ``y_proba`` in which every column listed in ``thresholds``
        holds 1 where the value was strictly greater than the cut-off and 0
        otherwise. Columns without an entry are returned unchanged.
    """
    binarized = y_proba.copy()
    for column, cutoff in thresholds.items():
        # Strict comparison: a value exactly at the cut-off maps to 0.
        above_cutoff = binarized[:, column] > cutoff
        binarized[:, column] = above_cutoff
    return binarized

In [None]:
# Score every batch from test_loader with the model and collect, per sample,
# the reference labels and the sigmoid probability of each output logit.

# Explicitly select evaluation mode so the results do not depend on whatever
# mode an earlier (possibly out-of-order) cell left the model in.
model.eval()

y_true = []
y_pred = []
with torch.inference_mode():
    for batch in test_loader:
        encodings = batch["encodings"].to(DEVICE)
        out = model(**encodings)
        preds = torch.sigmoid(out.logits)
        # Labels are only used for scoring, so they are read straight from the
        # batch and never transferred to DEVICE.
        y_true.extend(batch["labels"].numpy())
        y_pred.extend(preds.cpu().numpy())

# Post-processing is plain NumPy and needs no torch context.
y_true = np.array(y_true)
y_proba = np.array(y_pred)  # built once, shared by both binarizations below

# Two binarizations of the same probabilities: a uniform 0.5 cut-off and the
# per-column cut-offs from THRESHOLDS.
y_pred_05 = np.where(y_proba > 0.5, 1, 0)
y_pred_thresh = apply_thresh(y_proba, THRESHOLDS)

In [None]:
# Per-label report for the uniform 0.5 threshold.
# zero_division=0 makes the handling of ill-defined precision/recall explicit:
# the scores are the same 0.0 that the default ("warn") reports, but without
# the `_warn_prf` warnings that otherwise clutter the cell output.
print(
    classification_report(
        y_true,
        y_pred_05,
        target_names=["elite", "pplcentr", "left", "right"],
        zero_division=0,
    )
)

              precision    recall  f1-score   support

       elite       0.82      0.86      0.84       648
    pplcentr       0.77      0.60      0.67       322
        left       0.71      0.75      0.73       279
       right       0.77      0.57      0.65       155

   micro avg       0.78      0.75      0.76      1404
   macro avg       0.77      0.69      0.72      1404
weighted avg       0.78      0.75      0.76      1404
 samples avg       0.39      0.38      0.38      1404



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [None]:
# Per-label report for the per-column thresholds (y_pred_thresh).
# zero_division=0 makes the handling of ill-defined precision/recall explicit:
# the scores are the same 0.0 that the default ("warn") reports, but without
# the `_warn_prf` warnings that otherwise clutter the cell output.
print(
    classification_report(
        y_true,
        y_pred_thresh,
        target_names=["elite", "pplcentr", "left", "right"],
        zero_division=0,
    )
)

              precision    recall  f1-score   support

       elite       0.81      0.88      0.84       648
    pplcentr       0.70      0.73      0.71       322
        left       0.69      0.77      0.73       279
       right       0.68      0.66      0.67       155

   micro avg       0.75      0.80      0.77      1404
   macro avg       0.72      0.76      0.74      1404
weighted avg       0.75      0.80      0.77      1404
 samples avg       0.41      0.40      0.40      1404



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [None]:
# Optional: uncomment to save a local copy of the loaded model.
# model.save_pretrained(src.PATH / "tmp/PopBERT_model")