In [None]:
from collections import defaultdict

import numpy as np
import pandas as pd
import torch
from sklearn.metrics import classification_report
from sklearn.model_selection import KFold
from torch.utils.data import DataLoader
from torch.utils.data import SubsetRandomSampler
from transformers import AutoTokenizer
from transformers import logging

import src
from src.bert import module
from src.bert import training
from src.bert.dataset import PBertDataset
from src.bert.dataset import strategies

Best Youden thresh: {0: 0.5013, 1: 0.0728, 2: 0.0852, 3: 0.0292}
Best BF thresh {0: 0.5, 1: 0.5, 2: 0.45, 3: 0.23}
epoch=13 train_loss=0.0351 eval_loss=0.1581 score=0.7092
{0: 0.5013018, 1: 0.5017193, 2: 0.42243505, 3: 0.38281676}

In [None]:
# Only show errors from the transformers library (hides e.g. weight-init warnings).
logging.set_verbosity_error()

# model hyper-parameters
LR = 4e-6
N_EPOCHS = 13
BATCH_SIZE = 16

# Hugging Face hub identifiers of the tokenizer and of the pretrained encoder.
TOKENIZER = "deepset/gbert-large"
BASE_MODEL = "deepset/gbert-large"

# Label strategy handed to the dataset loader; it determines how labels are derived.
STRATEGY = strategies.MLMin1PopIdeol(output_fmt="single_task")

# Use the GPU when one is available; fall back to the CPU instead of failing
# on machines without CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Per-label decision thresholds passed to `model.apply_thresh` in the k-fold cell.
# Presumably keyed by label index in the order of `dataset.labels` — TODO confirm.
# NOTE: the (misspelled) name is kept because the training cell references it.
TRESHOLDS = {0: 0.5013018, 1: 0.5017193, 2: 0.42243505, 3: 0.38281676}

In [None]:
# Read the zipped CSV corpus from the project data folder; labels are built by STRATEGY.
dataset_path = src.PATH / "data/bert/dataset.csv.zip"
dataset = PBertDataset.from_disk(dataset_path, label_strategy=STRATEGY)

In [None]:
# Sanity check: number of items in the loaded dataset (shown as the cell output).
len(dataset)

8795

In [None]:
# Load the pretrained tokenizer named by TOKENIZER.
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)
# Build the batching function from the tokenizer; it is passed to every
# DataLoader below as `collate_fn`.
collate_fn = dataset.create_collate_fn(tokenizer)

In [None]:
# 5-fold cross-validation.
# For every fold a fresh model is fine-tuned for N_EPOCHS on the training indices,
# evaluated on the held-out indices after each epoch, and finally scored twice:
# once with a flat 0.5 cut-off and once with the per-label TRESHOLDS.
# The thresholded report of each fold is collected in `results` and written to CSV.

# Batch size for the held-out fold (kept at the value used so far).
EVAL_BATCH_SIZE = 32

# Create the report folder up front so that a missing directory cannot make
# `to_csv` fail after a whole fold has already been trained.
results_dir = src.PATH / "results/kfold"  # assumes src.PATH is a pathlib.Path — TODO confirm
results_dir.mkdir(parents=True, exist_ok=True)

splits = KFold(n_splits=5, shuffle=True, random_state=42)

results = []
for fold, (train_idx, test_idx) in enumerate(splits.split(np.arange(len(dataset))), 1):
    print("#" * 50)
    print(f"{fold=}")

    # create data: both loaders wrap the full dataset and differ only in the
    # index subset they sample from
    train_loader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        collate_fn=collate_fn,
        sampler=SubsetRandomSampler(train_idx),
    )
    test_loader = DataLoader(
        dataset,
        batch_size=EVAL_BATCH_SIZE,
        collate_fn=collate_fn,
        sampler=SubsetRandomSampler(test_idx),
    )

    # fresh model per fold so that no weights leak between folds
    model = module.BertSingleTaskMultiLabel(num_labels=dataset.num_labels, name=BASE_MODEL)
    model.train()
    model = model.to(DEVICE)
    model.set_seed(seed=10)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=LR,
        amsgrad=False,
        weight_decay=1e-2,
    )

    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=N_EPOCHS,
        eta_min=1e-9,
    )

    for epoch in range(1, N_EPOCHS + 1):
        train_loss = training.train_epoch(model, train_loader, optimizer, lr_scheduler)
        eval_loss, score, _ = training.eval_epoch(model, test_loader)
        print(f"{epoch=} {train_loss=:.4f} {eval_loss=:.4f} {score=:.4f}")

    # Final predictions on the held-out fold.
    # eval() is idempotent; it guarantees dropout is disabled here regardless of
    # the mode the last `eval_epoch` call left the model in.
    model.eval()
    with torch.inference_mode():
        y_true = []
        y_pred = []
        for batch in test_loader:
            encodings = batch["encodings"].to(DEVICE)
            # labels are only needed on the host for scoring, so they are not
            # moved to DEVICE
            y_true.extend(batch["labels"].numpy())
            y_pred.extend(model.predict_proba(encodings))
        y_true = np.array(y_true)
        y_pred = np.array(y_pred)
        y_pred_05 = np.where(y_pred > 0.5, 1, 0)
        y_pred_thresh = model.apply_thresh(y_pred, TRESHOLDS)

    # identical report settings for every call below
    report_kwargs = {"target_names": dataset.labels, "zero_division": 0}

    print("THRESH .5")
    print(classification_report(y_true, y_pred_05, **report_kwargs))
    print()
    print("THRESH DICT")
    print(classification_report(y_true, y_pred_thresh, **report_kwargs))

    # machine-readable version of the thresholded report, kept for aggregation
    classification_dict = classification_report(
        y_true, y_pred_thresh, output_dict=True, **report_kwargs
    )
    results.append(classification_dict)
    pd.DataFrame(classification_dict).to_csv(results_dir / f"fold_{fold}.csv")

##################################################
fold=1


Best BF thresh {0: 0.65, 1: 0.25, 2: 0.32, 3: 0.23}
epoch=1 train_loss=0.4465 eval_loss=0.3465 score=0.3847


Best BF thresh {0: 0.45, 1: 0.4, 2: 0.34, 3: 0.18}
epoch=2 train_loss=0.2749 eval_loss=0.2499 score=0.5457


Best BF thresh {0: 0.44, 1: 0.44, 2: 0.44, 3: 0.17}
epoch=3 train_loss=0.2413 eval_loss=0.1475 score=0.6354


Best BF thresh {0: 0.52, 1: 0.37, 2: 0.24, 3: 0.24}
epoch=4 train_loss=0.1540 eval_loss=0.2180 score=0.6863


Best BF thresh {0: 0.3, 1: 0.27, 2: 0.21, 3: 0.22}
epoch=5 train_loss=0.1760 eval_loss=0.2602 score=0.6832


Best BF thresh {0: 0.48, 1: 0.36, 2: 0.55, 3: 0.17}
epoch=6 train_loss=0.2396 eval_loss=0.2021 score=0.7115


Best BF thresh {0: 0.48, 1: 0.42, 2: 0.65, 3: 0.24}
epoch=7 train_loss=0.0629 eval_loss=0.2082 score=0.7247


Best BF thresh {0: 0.8200000000000001, 1: 0.14, 2: 0.31, 3: 0.47000000000000003}
epoch=8 train_loss=0.3206 eval_loss=0.2071 score=0.7038


Best BF thresh {0: 0.41000000000000003, 1: 0.28, 2: 0.27, 3: 0.27}
epoch=9 train_loss=0.0946 eval_loss=0.2505 score=0.7128


Best BF thresh {0: 0.76, 1: 0.35000000000000003, 2: 0.54, 3: 0.48}
epoch=10 train_loss=0.1928 eval_loss=0.1601 score=0.7455


Best BF thresh {0: 0.39, 1: 0.24, 2: 0.33, 3: 0.19}
epoch=11 train_loss=0.1317 eval_loss=0.2162 score=0.7276


Best BF thresh {0: 0.43, 1: 0.2, 2: 0.44, 3: 0.21}
epoch=12 train_loss=0.1144 eval_loss=0.1189 score=0.7203


Best BF thresh {0: 0.73, 1: 0.47000000000000003, 2: 0.52, 3: 0.22}
epoch=13 train_loss=0.0250 eval_loss=0.4163 score=0.7418


THRESH .5
              precision    recall  f1-score   support

       elite       0.79      0.88      0.84       641
       centr       0.66      0.77      0.71       316
        left       0.72      0.77      0.74       281
       right       0.70      0.66      0.68       177

   micro avg       0.74      0.81      0.77      1415
   macro avg       0.72      0.77      0.74      1415
weighted avg       0.74      0.81      0.77      1415
 samples avg       0.40      0.40      0.39      1415


THRESH DICT
              precision    recall  f1-score   support

       elite       0.79      0.88      0.84       641
       centr       0.66      0.77      0.71       316
        left       0.69      0.77      0.73       281
       right       0.67      0.69      0.68       177

   micro avg       0.73      0.81      0.77      1415
   macro avg       0.71      0.78      0.74      1415
weighted avg       0.73      0.81      0.77      1415
 samples avg       0.40      0.40      0.39      1415


##################################################
fold=2


  fscores = (2 * precision * recall) / (precision + recall)


Best BF thresh {0: 0.42, 1: 0.26, 2: 0.28, 3: 0.17}
epoch=1 train_loss=0.3418 eval_loss=0.2842 score=0.4122


Best BF thresh {0: 0.53, 1: 0.33, 2: 0.37, 3: 0.25}
epoch=2 train_loss=0.2830 eval_loss=0.2622 score=0.6430


Best BF thresh {0: 0.52, 1: 0.44, 2: 0.5700000000000001, 3: 0.23}
epoch=3 train_loss=0.1373 eval_loss=0.1758 score=0.6850


Best BF thresh {0: 0.5700000000000001, 1: 0.28, 2: 0.62, 3: 0.19}
epoch=4 train_loss=0.1721 eval_loss=0.2074 score=0.7145


Best BF thresh {0: 0.5700000000000001, 1: 0.56, 2: 0.67, 3: 0.17}
epoch=5 train_loss=0.2818 eval_loss=0.3019 score=0.7087


Best BF thresh {0: 0.64, 1: 0.26, 2: 0.55, 3: 0.25}
epoch=6 train_loss=0.2129 eval_loss=0.1352 score=0.7179


Best BF thresh {0: 0.59, 1: 0.52, 2: 0.63, 3: 0.22}
epoch=7 train_loss=0.1432 eval_loss=0.1014 score=0.7177


Best BF thresh {0: 0.49, 1: 0.44, 2: 0.75, 3: 0.19}
epoch=8 train_loss=0.3097 eval_loss=0.2905 score=0.7207


Best BF thresh {0: 0.74, 1: 0.46, 2: 0.76, 3: 0.48}
epoch=9 train_loss=0.0923 eval_loss=0.2257 score=0.7313


Best BF thresh {0: 0.14, 1: 0.24, 2: 0.59, 3: 0.09}
epoch=10 train_loss=0.0396 eval_loss=0.2689 score=0.6878


Best BF thresh {0: 0.78, 1: 0.54, 2: 0.81, 3: 0.26}
epoch=11 train_loss=0.0233 eval_loss=0.3828 score=0.7274


Best BF thresh {0: 0.8200000000000001, 1: 0.73, 2: 0.52, 3: 0.6900000000000001}
epoch=12 train_loss=0.0827 eval_loss=0.1668 score=0.7327


Best BF thresh {0: 0.56, 1: 0.52, 2: 0.77, 3: 0.27}
epoch=13 train_loss=0.1143 eval_loss=0.4110 score=0.7272


THRESH .5
              precision    recall  f1-score   support

       elite       0.81      0.88      0.85       627
       centr       0.67      0.75      0.71       333
        left       0.64      0.77      0.70       265
       right       0.69      0.62      0.65       152

   micro avg       0.73      0.80      0.76      1377
   macro avg       0.71      0.75      0.73      1377
weighted avg       0.73      0.80      0.76      1377
 samples avg       0.40      0.40      0.39      1377


THRESH DICT
              precision    recall  f1-score   support

       elite       0.81      0.88      0.85       627
       centr       0.67      0.75      0.71       333
        left       0.63      0.78      0.70       265
       right       0.67      0.64      0.66       152

   micro avg       0.73      0.81      0.76      1377
   macro avg       0.70      0.77      0.73      1377
weighted avg       0.73      0.81      0.76      1377
 samples avg       0.40      0.41      0.39      1377


  fscores = (2 * precision * recall) / (precision + recall)


Best BF thresh {0: 0.21, 1: 0.27, 2: 0.19, 3: 0.14}
epoch=1 train_loss=0.3091 eval_loss=0.3590 score=0.3601


Best BF thresh {0: 0.43, 1: 0.24, 2: 0.47000000000000003, 3: 0.26}
epoch=2 train_loss=0.1529 eval_loss=0.2642 score=0.6242


Best BF thresh {0: 0.58, 1: 0.31, 2: 0.36, 3: 0.25}
epoch=3 train_loss=0.1705 eval_loss=0.2076 score=0.6444


Best BF thresh {0: 0.3, 1: 0.4, 2: 0.36, 3: 0.13}
epoch=4 train_loss=0.1839 eval_loss=0.2952 score=0.6645


Best BF thresh {0: 0.29, 1: 0.39, 2: 0.28, 3: 0.29}
epoch=5 train_loss=0.4829 eval_loss=0.2606 score=0.7000


Best BF thresh {0: 0.32, 1: 0.26, 2: 0.19, 3: 0.26}
epoch=6 train_loss=0.1320 eval_loss=0.3012 score=0.6948


Best BF thresh {0: 0.27, 1: 0.36, 2: 0.26, 3: 0.32}
epoch=7 train_loss=0.1286 eval_loss=0.2255 score=0.7084


Best BF thresh {0: 0.36, 1: 0.43, 2: 0.43, 3: 0.19}
epoch=8 train_loss=0.0790 eval_loss=0.3275 score=0.6964


Best BF thresh {0: 0.39, 1: 0.41000000000000003, 2: 0.3, 3: 0.28}
epoch=9 train_loss=0.0806 eval_loss=0.0991 score=0.7159


Best BF thresh {0: 0.23, 1: 0.38, 2: 0.37, 3: 0.32}
epoch=10 train_loss=0.0495 eval_loss=0.2168 score=0.7118


Best BF thresh {0: 0.74, 1: 0.52, 2: 0.46, 3: 0.52}
epoch=11 train_loss=0.1217 eval_loss=0.2544 score=0.7333


Best BF thresh {0: 0.44, 1: 0.39, 2: 0.35000000000000003, 3: 0.13}
epoch=12 train_loss=0.0884 eval_loss=0.4090 score=0.7178


Best BF thresh {0: 0.63, 1: 0.6900000000000001, 2: 0.34, 3: 0.29}
epoch=13 train_loss=0.0309 eval_loss=0.2056 score=0.7277


THRESH .5
              precision    recall  f1-score   support

       elite       0.83      0.88      0.85       638
       centr       0.66      0.75      0.70       310
        left       0.69      0.74      0.71       284
       right       0.62      0.66      0.64       137

   micro avg       0.74      0.80      0.77      1369
   macro avg       0.70      0.75      0.73      1369
weighted avg       0.74      0.80      0.77      1369
 samples avg       0.40      0.40      0.39      1369


THRESH DICT
              precision    recall  f1-score   support

       elite       0.83      0.88      0.85       638
       centr       0.67      0.75      0.71       310
        left       0.68      0.76      0.72       284
       right       0.60      0.72      0.66       137

   micro avg       0.74      0.81      0.77      1369
   macro avg       0.70      0.78      0.73      1369
weighted avg       0.74      0.81      0.77      1369
 samples avg       0.40      0.40      0.39      1369


Best BF thresh {0: 0.21, 1: 0.19, 2: 0.16, 3: 0.06}
epoch=1 train_loss=0.2022 eval_loss=0.3479 score=0.3131


Best BF thresh {0: 0.24, 1: 0.33, 2: 0.25, 3: 0.2}
epoch=2 train_loss=0.1898 eval_loss=0.3101 score=0.5428


Best BF thresh {0: 0.33, 1: 0.45, 2: 0.32, 3: 0.34}
epoch=3 train_loss=0.1873 eval_loss=0.2563 score=0.6725


Best BF thresh {0: 0.39, 1: 0.15, 2: 0.24, 3: 0.15}
epoch=4 train_loss=0.1997 eval_loss=0.2545 score=0.6225


Best BF thresh {0: 0.44, 1: 0.36, 2: 0.32, 3: 0.32}
epoch=5 train_loss=0.2723 eval_loss=0.1902 score=0.7231


Best BF thresh {0: 0.54, 1: 0.28, 2: 0.58, 3: 0.28}
epoch=6 train_loss=0.1590 eval_loss=0.2100 score=0.7123


Best BF thresh {0: 0.35000000000000003, 1: 0.45, 2: 0.46, 3: 0.24}
epoch=7 train_loss=0.1393 eval_loss=0.2593 score=0.7346


Best BF thresh {0: 0.3, 1: 0.14, 2: 0.46, 3: 0.26}
epoch=8 train_loss=0.1012 eval_loss=0.2927 score=0.7383


Best BF thresh {0: 0.15, 1: 0.24, 2: 0.26, 3: 0.43}
epoch=9 train_loss=0.0448 eval_loss=0.2559 score=0.7286


Best BF thresh {0: 0.29, 1: 0.12, 2: 0.29, 3: 0.14}
epoch=10 train_loss=0.0660 eval_loss=0.3918 score=0.7061


Best BF thresh {0: 0.46, 1: 0.36, 2: 0.3, 3: 0.24}
epoch=11 train_loss=0.0625 eval_loss=0.3216 score=0.7407


Best BF thresh {0: 0.18, 1: 0.2, 2: 0.29, 3: 0.11}
epoch=12 train_loss=0.1151 eval_loss=0.1175 score=0.7009


Best BF thresh {0: 0.6, 1: 0.27, 2: 0.71, 3: 0.38}
epoch=13 train_loss=0.0260 eval_loss=0.3800 score=0.7409


THRESH .5
              precision    recall  f1-score   support

       elite       0.81      0.90      0.85       690
       centr       0.69      0.70      0.69       333
        left       0.69      0.72      0.71       302
       right       0.70      0.72      0.71       167

   micro avg       0.75      0.80      0.77      1492
   macro avg       0.72      0.76      0.74      1492
weighted avg       0.75      0.80      0.77      1492
 samples avg       0.42      0.42      0.41      1492


THRESH DICT
              precision    recall  f1-score   support

       elite       0.82      0.90      0.85       690
       centr       0.69      0.70      0.69       333
        left       0.68      0.74      0.70       302
       right       0.69      0.78      0.73       167

   micro avg       0.75      0.81      0.78      1492
   macro avg       0.72      0.78      0.75      1492
weighted avg       0.75      0.81      0.77      1492
 samples avg       0.42      0.43      0.41      1492


  fscores = (2 * precision * recall) / (precision + recall)


Best BF thresh {0: 0.5, 1: 0.27, 2: 0.31, 3: 0.3}
epoch=1 train_loss=0.2916 eval_loss=0.3140 score=0.4148


Best BF thresh {0: 0.44, 1: 0.26, 2: 0.37, 3: 0.2}
epoch=2 train_loss=0.3054 eval_loss=0.2030 score=0.5361


Best BF thresh {0: 0.27, 1: 0.45, 2: 0.43, 3: 0.25}
epoch=3 train_loss=0.2335 eval_loss=0.1689 score=0.6459


Best BF thresh {0: 0.35000000000000003, 1: 0.44, 2: 0.32, 3: 0.44}
epoch=4 train_loss=0.2216 eval_loss=0.1686 score=0.6828


  fscores = (2 * precision * recall) / (precision + recall)


Best BF thresh {0: 0.41000000000000003, 1: 0.29, 2: 0.56, 3: 0.18}
epoch=5 train_loss=0.1227 eval_loss=0.2036 score=0.6719


  fscores = (2 * precision * recall) / (precision + recall)


Best BF thresh {0: 0.23, 1: 0.44, 2: 0.43, 3: 0.31}
epoch=6 train_loss=0.1258 eval_loss=0.2506 score=0.6927


Best BF thresh {0: 0.46, 1: 0.4, 2: 0.48, 3: 0.38}
epoch=7 train_loss=0.1070 eval_loss=0.1896 score=0.7081


  fscores = (2 * precision * recall) / (precision + recall)


Best BF thresh {0: 0.23, 1: 0.35000000000000003, 2: 0.56, 3: 0.23}
epoch=8 train_loss=0.0803 eval_loss=0.2366 score=0.7045


  fscores = (2 * precision * recall) / (precision + recall)


Best BF thresh {0: 0.2, 1: 0.29, 2: 0.59, 3: 0.23}
epoch=9 train_loss=0.0990 eval_loss=0.2227 score=0.7025


Best BF thresh {0: 0.14, 1: 0.25, 2: 0.4, 3: 0.49}
epoch=10 train_loss=0.1307 eval_loss=0.1921 score=0.7153


Best BF thresh {0: 0.08, 1: 0.27, 2: 0.49, 3: 0.11}
epoch=11 train_loss=0.0519 eval_loss=0.2505 score=0.6882


Best BF thresh {0: 0.25, 1: 0.47000000000000003, 2: 0.67, 3: 0.18}
epoch=12 train_loss=0.1028 eval_loss=0.3465 score=0.7080


Best BF thresh {0: 0.19, 1: 0.28, 2: 0.45, 3: 0.5700000000000001}
epoch=13 train_loss=0.0694 eval_loss=0.3839 score=0.7167


THRESH .5
              precision    recall  f1-score   support

       elite       0.80      0.88      0.84       640
       centr       0.66      0.66      0.66       316
        left       0.65      0.79      0.71       261
       right       0.68      0.63      0.65       140

   micro avg       0.73      0.79      0.76      1357
   macro avg       0.70      0.74      0.72      1357
weighted avg       0.73      0.79      0.75      1357
 samples avg       0.38      0.39      0.38      1357


THRESH DICT
              precision    recall  f1-score   support

       elite       0.80      0.88      0.84       640
       centr       0.66      0.66      0.66       316
        left       0.64      0.81      0.71       261
       right       0.64      0.66      0.65       140

   micro avg       0.72      0.79      0.76      1357
   macro avg       0.69      0.75      0.72      1357
weighted avg       0.72      0.79      0.75      1357
 samples avg       0.38      0.39      0.38      1357


In [None]:
# Gather, for every report row (labels and averages), the F1 score of each fold.
score_dict = defaultdict(list)
for fold_report in results:
    for dim, metrics in fold_report.items():
        score_dict[dim].append(metrics["f1-score"])

# One summary record per row: mean and standard deviation of F1 across the folds.
out_data = [
    {"dim": dim, "mean": np.mean(f1_values), "std": np.std(f1_values)}
    for dim, f1_values in score_dict.items()
]

In [None]:
# Show the cross-fold F1 summary (one row per label / average) as a table.
pd.DataFrame(out_data)

Unnamed: 0,dim,mean,std
0,elite,0.84678,0.007444
1,centr,0.695982,0.0189
2,left,0.713261,0.00989
3,right,0.674466,0.030877
4,micro avg,0.766746,0.006838
5,macro avg,0.732623,0.010373
6,weighted avg,0.766799,0.007123
7,samples avg,0.392949,0.011762
