In [None]:
# Automatically reload edited project modules (e.g. `src.*`) before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import torch
from sklearn.metrics import classification_report
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from transformers import logging

import src
import src.bert.utils as bert_utils
from src.bert import module
from src.bert import training
from src.bert.dataset import PBertDataset
from src.bert.dataset import strategies

In [None]:
# Only surface errors from the `transformers` library (hides info/warning chatter).
logging.set_verbosity_error()

# --- Optimisation hyper-parameters ------------------------------------------
LR = 4e-6
N_EPOCHS = 13
BATCH_SIZE = 16  # batch size of the *training* loader

# --- Pretrained checkpoints (tokenizer and encoder use the same one) ---------
TOKENIZER = "deepset/gbert-large"
BASE_MODEL = "deepset/gbert-large"

# --- Label construction -------------------------------------------------------
# Strategy applied to every dataset split when it is loaded from disk.
STRATEGY = strategies.MLMin1PopIdeol(output_fmt="single_task")

# --- Hardware -----------------------------------------------------------------
DEVICE = "cuda"

In [None]:
# Load the three pre-made splits (train / validation / test) from their zipped
# CSV files, labelling each one with the same STRATEGY.
train, val, test = (
    PBertDataset.from_disk(src.PATH / relative_path, label_strategy=STRATEGY)
    for relative_path in (
        "data/bert/train.csv.zip",
        "data/bert/validation.csv.zip",
        "data/bert/test.csv.zip",
    )
)

In [None]:
# Number of examples per split, displayed in the order (train, test, validation).
tuple(len(split) for split in (train, test, val))

(5277, 1759, 1759)

In [None]:
# Load the tokenizer for the checkpoint named in TOKENIZER.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=TOKENIZER)

In [None]:
# A single collate function (built from the training split) is shared by all loaders.
collate_fn = train.create_collate_fn(tokenizer)

# Only the training loader shuffles; validation and test iterate in a fixed
# order so evaluation is deterministic.
train_loader = DataLoader(train, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
valid_loader, test_loader = (
    DataLoader(split, batch_size=32, shuffle=False, collate_fn=collate_fn)
    for split in (val, test)
)

In [None]:
SEED = 10

# Seed torch's global RNG *before* constructing the model. Anything randomly
# initialised inside the constructor (presumably the classification head on top
# of BASE_MODEL -- confirm) is otherwise not covered, because seeding afterwards
# cannot affect weights that already exist.
torch.manual_seed(SEED)

model = module.BertSingleTaskMultiLabel(num_labels=train.num_labels, name=BASE_MODEL)
model.train()
model = model.to(DEVICE)
# Keep the project's own seeding helper at its original position so whatever
# it seeds still happens at the same point relative to training.
model.set_seed(seed=SEED)

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=LR,
    amsgrad=False,
    weight_decay=1e-2,
)

# NOTE(review): T_max=15 does not match N_EPOCHS (13). If `training.train_epoch`
# steps the scheduler once per epoch, the LR never reaches eta_min; if it steps
# once per batch, T_max should be expressed in batches. Confirm the intent.
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=15,
    eta_min=1e-9,
)

# Train for N_EPOCHS, evaluating on the validation loader after every epoch.
# `eval_epoch` also returns the per-label thresholds it selected.
for epoch in range(1, N_EPOCHS + 1):
    train_loss = training.train_epoch(model, train_loader, optimizer, lr_scheduler)
    eval_loss, score, thresh = training.eval_epoch(model, valid_loader)
    print(f"{epoch=} {train_loss=:.4f} {eval_loss=:.4f} {score=:.4f}")
    print(thresh)
    print("-" * 30)

  fscores = (2 * precision * recall) / (precision + recall)


Best BF thresh {0: 0.38, 1: 0.22, 2: 0.18, 3: 0.14}
epoch=1 train_loss=0.3415 eval_loss=0.2607 score=0.2269
{0: 0.3766822, 1: 0.22738275, 2: 0.17531393, 3: 0.36036944}
------------------------------


Best BF thresh {0: 0.5700000000000001, 1: 0.3, 2: 0.33, 3: 0.31}
epoch=2 train_loss=0.2033 eval_loss=0.2142 score=0.5799
{0: 0.57342243, 1: 0.2760982, 2: 0.3322879, 3: 0.30530763}
------------------------------


Best BF thresh {0: 0.32, 1: 0.2, 2: 0.27, 3: 0.22}
epoch=3 train_loss=0.2489 eval_loss=0.2164 score=0.5468
{0: 0.32757726, 1: 0.20682135, 2: 0.26907408, 3: 0.22335309}
------------------------------


  fscores = (2 * precision * recall) / (precision + recall)


Best BF thresh {0: 0.31, 1: 0.45, 2: 0.32, 3: 0.17}
epoch=4 train_loss=0.1999 eval_loss=0.1880 score=0.6341
{0: 0.29426068, 1: 0.37834522, 2: 0.9680072, 3: 0.2667202}
------------------------------


Best BF thresh {0: 0.55, 1: 0.3, 2: 0.37, 3: 0.31}
epoch=5 train_loss=0.2789 eval_loss=0.1974 score=0.6610
{0: 0.55211514, 1: 0.3007287, 2: 0.37783322, 3: 0.31490615}
------------------------------


Best BF thresh {0: 0.52, 1: 0.36, 2: 0.63, 3: 0.3}
epoch=6 train_loss=0.2591 eval_loss=0.1893 score=0.6916
{0: 0.5210392, 1: 0.3614637, 2: 0.6365567, 3: 0.29676586}
------------------------------


Best BF thresh {0: 0.58, 1: 0.25, 2: 0.6900000000000001, 3: 0.42}
epoch=7 train_loss=0.0936 eval_loss=0.1956 score=0.7065
{0: 0.584204, 1: 0.23510472, 2: 0.6952419, 3: 0.3797357}
------------------------------


Best BF thresh {0: 0.2, 1: 0.52, 2: 0.46, 3: 0.17}
epoch=8 train_loss=0.1707 eval_loss=0.2102 score=0.6971
{0: 0.20202683, 1: 0.5678805, 2: 0.45491958, 3: 0.16894972}
------------------------------


Best BF thresh {0: 0.27, 1: 0.23, 2: 0.46, 3: 0.16}
epoch=9 train_loss=0.0665 eval_loss=0.2412 score=0.6921
{0: 0.285582, 1: 0.23008464, 2: 0.4613408, 3: 0.16180824}
------------------------------


Best BF thresh {0: 0.58, 1: 0.64, 2: 0.73, 3: 0.49}
epoch=10 train_loss=0.0929 eval_loss=0.2764 score=0.7208
{0: 0.5865942, 1: 0.64210063, 2: 0.71926636, 3: 0.49389726}
------------------------------


Best BF thresh {0: 0.43, 1: 0.2, 2: 0.51, 3: 0.22}
epoch=11 train_loss=0.1282 eval_loss=0.2263 score=0.6981
{0: 0.43724817, 1: 0.19381644, 2: 0.5150236, 3: 0.22553217}
------------------------------


Best BF thresh {0: 0.18, 1: 0.3, 2: 0.67, 3: 0.16}
epoch=12 train_loss=0.1414 eval_loss=0.2495 score=0.7021
{0: 0.18372594, 1: 0.2898114, 2: 0.6724017, 3: 0.1627382}
------------------------------


Best BF thresh {0: 0.45, 1: 0.62, 2: 0.72, 3: 0.39}
epoch=13 train_loss=0.0728 eval_loss=0.2739 score=0.7218
{0: 0.44911858, 1: 0.62757695, 2: 0.7143382, 3: 0.3941237}
------------------------------


In [None]:
# Collect gold labels and predicted probabilities on the validation set, then
# tune one decision threshold per label from them.

# `torch.inference_mode()` only disables autograd tracking; it does not switch
# layers such as dropout to evaluation behaviour, so do that explicitly.
model.eval()

y_true = []
y_pred = []
with torch.inference_mode():
    for batch in valid_loader:
        encodings = batch["encodings"].to(DEVICE)
        preds = model.predict_proba(encodings)
        # Labels are only needed as NumPy on the host; no GPU copy is required.
        y_true.extend(batch["labels"].numpy())
        y_pred.extend(preds)

y_true = np.array(y_true)
y_pred = np.array(y_pred)

thresh_finder = bert_utils.ThresholdFinder(type=model.model_type)
thresholds = thresh_finder.find_thresholds(y_true, y_pred)

Best BF thresh {0: 0.45, 1: 0.62, 2: 0.72, 3: 0.39}


In [None]:
# Per-label thresholds tuned on the validation set.
print(f"{thresholds}")

{0: 0.44911858, 1: 0.62757695, 2: 0.7143382, 3: 0.3941237}


In [None]:
# Turn validation probabilities into 0/1 predictions two ways:
# a fixed 0.5 cut-off, and the tuned per-label thresholds.
y_pred_05 = (y_pred > 0.5).astype(int)
y_pred_thresh = model.apply_thresh(y_pred, thresholds)

In [None]:
# Validation report using the fixed 0.5 cut-off.
print(
    classification_report(
        y_true=y_true,
        y_pred=y_pred_05,
        target_names=train.labels,
        zero_division=0,
    )
)

              precision    recall  f1-score   support

       elite       0.80      0.92      0.85       630
       centr       0.61      0.80      0.69       307
        left       0.63      0.80      0.71       280
       right       0.69      0.59      0.64       155

   micro avg       0.71      0.83      0.76      1372
   macro avg       0.68      0.78      0.72      1372
weighted avg       0.71      0.83      0.76      1372
 samples avg       0.40      0.41      0.40      1372



In [None]:
# Validation report using the tuned per-label thresholds (tuned on this same
# split, so these scores are optimistic).
print(
    classification_report(
        y_true=y_true,
        y_pred=y_pred_thresh,
        target_names=train.labels,
        zero_division=0,
    )
)

              precision    recall  f1-score   support

       elite       0.80      0.92      0.85       630
       centr       0.65      0.76      0.70       307
        left       0.68      0.76      0.72       280
       right       0.68      0.62      0.65       155

   micro avg       0.73      0.82      0.77      1372
   macro avg       0.70      0.77      0.73      1372
weighted avg       0.73      0.82      0.77      1372
 samples avg       0.40      0.40      0.40      1372



In [None]:
# Collect gold labels and predicted probabilities on the held-out test set.

# `torch.inference_mode()` only disables autograd tracking; it does not switch
# layers such as dropout to evaluation behaviour, so do that explicitly.
model.eval()

y_true = []
y_pred = []
with torch.inference_mode():
    for batch in test_loader:
        encodings = batch["encodings"].to(DEVICE)
        preds = model.predict_proba(encodings)
        # Labels are only needed as NumPy on the host; no GPU copy is required.
        y_true.extend(batch["labels"].numpy())
        y_pred.extend(preds)

y_true = np.array(y_true)
y_pred = np.array(y_pred)

In [None]:
# Turn test probabilities into 0/1 predictions: a fixed 0.5 cut-off, and the
# per-label thresholds previously tuned on the validation set.
y_pred_05 = (y_pred > 0.5).astype(int)
y_pred_thresh = model.apply_thresh(y_pred, thresholds)

In [None]:
# Test report using the fixed 0.5 cut-off.
print(
    classification_report(
        y_true=y_true,
        y_pred=y_pred_05,
        target_names=train.labels,
        zero_division=0,
    )
)

              precision    recall  f1-score   support

       elite       0.79      0.89      0.84       625
       centr       0.57      0.81      0.67       302
        left       0.62      0.84      0.72       279
       right       0.69      0.67      0.68       130

   micro avg       0.69      0.84      0.76      1336
   macro avg       0.67      0.80      0.73      1336
weighted avg       0.70      0.84      0.76      1336
 samples avg       0.39      0.40      0.39      1336



In [None]:
# Test report using the thresholds tuned on the validation set.
print(
    classification_report(
        y_true=y_true,
        y_pred=y_pred_thresh,
        target_names=train.labels,
        zero_division=0,
    )
)

              precision    recall  f1-score   support

       elite       0.79      0.89      0.83       625
       centr       0.59      0.77      0.67       302
        left       0.66      0.77      0.72       279
       right       0.64      0.69      0.67       130

   micro avg       0.70      0.82      0.75      1336
   macro avg       0.67      0.78      0.72      1336
weighted avg       0.70      0.82      0.76      1336
 samples avg       0.39      0.40      0.38      1336

