In [None]:
# Enable IPython's autoreload extension in mode 2: all imported modules are
# reloaded before each cell runs, so edits to the project's source files are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from collections import defaultdict

import numpy as np
import torch
from sklearn.metrics import classification_report
from sqlalchemy import func
from sqlalchemy.orm import Session
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import AutoTokenizer
from transformers import logging

import src
import src.bert.utils as bert_utils
import src.db
from src.bert import module
from src.bert.dataset import PBertDataset
from src.bert.dataset import strategies
from src.db.models import bert_data as bm

In [None]:
# Database engine; used further down to open a SQLAlchemy Session.
# NOTE(review): "DB" is presumably a config/env key that src.db.make_engine
# resolves to a connection URL — confirm against src.db.
engine = src.db.make_engine("DB")

In [None]:
# Only let transformers emit log messages at error level.
logging.set_verbosity_error()

# model hyper-parameters
LR = 9e-6  # initial learning rate handed to AdamW; the cosine schedule decays it to LR / 20
N_EPOCHS = 15  # number of training epochs, also used as T_max of the LR schedule
BATCH_SIZE = 16  # batch size of the training DataLoader (evaluation loaders use 64)

# Name passed to AutoTokenizer.from_pretrained.
TOKENIZER = "deepset/gbert-large"
# Name passed to the model class as `name`.
BASE_MODEL = "deepset/gbert-large"

# Label strategy applied to every dataset split when it is loaded from disk.
STRATEGY = strategies.MLMin1PopIdeol(output_fmt="single_task")

# Device the model and every batch of encodings are moved to.
DEVICE = "cuda"

# Passed as `exclude_coders` when loading each split; empty list here.
EXCLUDE_CODERS = []

In [None]:
# Load the train / test / validation splits from disk. All three share the
# same label strategy and coder exclusion list, so those keyword arguments
# are collected once and reused.
_split_options = {
    "label_strategy": STRATEGY,
    "exclude_coders": EXCLUDE_CODERS,
}

train = PBertDataset.from_disk(src.PATH / "data/bert/train.csv.zip", **_split_options)
test = PBertDataset.from_disk(src.PATH / "data/bert/test.csv.zip", **_split_options)
val = PBertDataset.from_disk(src.PATH / "data/bert/validation.csv.zip", **_split_options)

In [None]:
# Sizes of the train, test and validation splits (in that order).
tuple(map(len, (train, test, val)))

(5277, 1759, 1759)

In [None]:
# Load the pretrained tokenizer named by TOKENIZER; it is handed to the
# dataset's collate function when the DataLoaders are built.
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)

In [None]:
# One collate function (built from the training dataset and the tokenizer) is
# shared by all three loaders.
collate_fn = train.create_collate_fn(tokenizer)

# Training batches are shuffled and use the configured batch size.
train_loader = DataLoader(
    train,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
)

# Evaluation loaders keep the dataset order and use a larger, fixed batch size.
_eval_loader_options = {"collate_fn": collate_fn, "batch_size": 64, "shuffle": False}
valid_loader = DataLoader(val, **_eval_loader_options)
test_loader = DataLoader(test, **_eval_loader_options)

In [None]:
# Display the `coders` attribute of the training split.
train.coders

['grabsch', 'schadt', 'richter', 'riedel', 'coudry']

In [None]:
# Number of labels in the training split; the same value is passed to the
# model constructor as `num_labels`.
train.num_labels

4

In [None]:
# Fine-tune the single-task multi-label BERT classifier.
#
# After every epoch the model is switched to eval mode, decision thresholds
# are fitted on the validation split, and the model is scored on the test
# split using those thresholds. One summary row is printed per epoch.
model = module.BertSingleTaskMultiLabel(num_labels=train.num_labels, name=BASE_MODEL)
model.train()
model = model.to(DEVICE)
# NOTE(review): set_seed() runs after the model has been constructed; if it
# seeds the global RNGs, the initialisation of freshly created layers is not
# covered by the seed — confirm against module.BertSingleTaskMultiLabel.
model.set_seed()

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=LR,
    amsgrad=False,
    weight_decay=1e-2,
)

# Cosine decay from LR down to LR / 20 over N_EPOCHS scheduler steps
# (the scheduler is stepped once per epoch below).
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=N_EPOCHS,
    eta_min=LR / 20,
)


def _collect_predictions(loader):
    """Run ``model.predict_proba`` over every batch of ``loader``.

    Returns a ``(y_true, y_pred)`` pair of numpy arrays built from each
    batch's "vote" entry and the corresponding ``predict_proba`` output.
    """
    y_true, y_pred = [], []
    for batch in loader:
        encodings = batch["encodings"].to(DEVICE)
        predictions = model.predict_proba(encodings)
        y_true.extend(batch["vote"])
        y_pred.extend(predictions)
    return np.array(y_true), np.array(y_pred)


print("epoch" + " " * 6 + "loss" + " " * 8 + "LR" + " " * 9 + "score" + " " * 6 + "score_meta")
print("-" * 65)

for epoch in range(1, N_EPOCHS + 1):
    epoch_loss = 0.0
    # Learning rate in effect for this epoch (read before the scheduler steps).
    current_lr = optimizer.param_groups[0]["lr"]
    for batch in tqdm(train_loader, leave=False, desc=f"Epoch {epoch}"):
        encodings = batch["encodings"].to(DEVICE)
        labels = batch["labels"].to(DEVICE)
        loss, _ = model(**encodings, labels=labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    lr_scheduler.step()

    model.eval()
    with torch.inference_mode():
        # Fit the decision thresholds on the validation split ...
        val_true, val_pred = _collect_predictions(valid_loader)
        thresh_finder = bert_utils.ThresholdFinder(
            method=model.threshold_type, type=model.model_type
        )
        thresholds = thresh_finder.find_thresholds(val_true, val_pred)

        # ... and apply them to the test split predictions.
        y_true, y_pred = _collect_predictions(test_loader)

    score = model.score(y_true, y_pred, thresholds=thresholds)
    # len(train_loader) is already the number of batches, so it is the correct
    # divisor for a mean per-batch loss. (Dividing it by BATCH_SIZE again, as
    # was done before, inflated the reported loss roughly BATCH_SIZE-fold and
    # could divide by zero on small datasets.)
    epoch_loss /= max(len(train_loader), 1)
    print(
        f"{epoch:<10} {epoch_loss:<11.3f} {current_lr:<10.6f} {score['score']:<10.4f} {str(thresholds):<10}"
    )
    model.train()

epoch      loss        LR         score      score_meta
-----------------------------------------------------------------


Epoch 1:   0%|          | 0/330 [00:00<?, ?it/s]

1          5.416       0.000009   0.6900     {0: 0.29327637, 1: 0.34079868, 2: 0.31137857, 3: 0.243601}


Epoch 2:   0%|          | 0/330 [00:00<?, ?it/s]

  fscores = (2 * precision * recall) / (precision + recall)


2          3.566       0.000009   0.5550     {0: 0.30472243, 1: 0.39312005, 2: 0.33047506, 3: 0.9591235}


Epoch 3:   0%|          | 0/330 [00:00<?, ?it/s]

  fscores = (2 * precision * recall) / (precision + recall)


3          2.641       0.000009   0.5606     {0: 0.5958861, 1: 0.380949, 2: 0.49105638, 3: 0.98364383}


Epoch 4:   0%|          | 0/330 [00:00<?, ?it/s]

4          1.740       0.000008   0.7290     {0: 0.42223817, 1: 0.45497906, 2: 0.28872433, 3: 0.053682033}


Epoch 5:   0%|          | 0/330 [00:00<?, ?it/s]

5          1.222       0.000008   0.7319     {0: 0.25937536, 1: 0.6706303, 2: 0.5638416, 3: 0.164134}


Epoch 6:   0%|          | 0/330 [00:00<?, ?it/s]

6          0.829       0.000007   0.7436     {0: 0.15712969, 1: 0.5346219, 2: 0.19867392, 3: 0.09746829}


Epoch 7:   0%|          | 0/330 [00:00<?, ?it/s]

7          0.601       0.000006   0.7413     {0: 0.45009318, 1: 0.3139537, 2: 0.24526045, 3: 0.2079964}


Epoch 8:   0%|          | 0/330 [00:00<?, ?it/s]

8          0.403       0.000005   0.7375     {0: 0.23274799, 1: 0.28421813, 2: 0.20090549, 3: 0.28468984}


Epoch 9:   0%|          | 0/330 [00:00<?, ?it/s]

9          0.294       0.000004   0.7337     {0: 0.67199904, 1: 0.33226287, 2: 0.26435283, 3: 0.33514825}


Epoch 10:   0%|          | 0/330 [00:00<?, ?it/s]

10         0.215       0.000003   0.7342     {0: 0.48007447, 1: 0.037389167, 2: 0.06950781, 3: 0.088775896}


Epoch 11:   0%|          | 0/330 [00:00<?, ?it/s]

11         0.170       0.000003   0.7328     {0: 0.2947924, 1: 0.47270703, 2: 0.04146539, 3: 0.055548627}


Epoch 12:   0%|          | 0/330 [00:00<?, ?it/s]

12         0.165       0.000002   0.7368     {0: 0.5641429, 1: 0.21469122, 2: 0.06700835, 3: 0.044898875}


Epoch 13:   0%|          | 0/330 [00:00<?, ?it/s]

13         0.132       0.000001   0.7294     {0: 0.26046485, 1: 0.39254162, 2: 0.045209806, 3: 0.04642019}


Epoch 14:   0%|          | 0/330 [00:00<?, ?it/s]

14         0.111       0.000001   0.7370     {0: 0.40245926, 1: 0.12584877, 2: 0.036415488, 3: 0.08278702}


Epoch 15:   0%|          | 0/330 [00:00<?, ?it/s]

15         0.105       0.000001   0.7370     {0: 0.2621143, 1: 0.20370445, 2: 0.49065122, 3: 0.04641365}


In [None]:
# Final pass over the test split: gather the "vote" entries and the model's
# predict_proba outputs, then turn each output into a prediction with the
# thresholds left over from the last training epoch.
model = model.eval()

y_true, y_proba = [], []
with torch.inference_mode():
    for test_batch in test_loader:
        batch_encodings = test_batch["encodings"].to(DEVICE)
        batch_proba = model.predict_proba(batch_encodings)
        y_true.extend(test_batch["vote"])
        y_proba.extend(batch_proba)

y_pred = []
for sample_proba in y_proba:
    y_pred.append(model.vote(sample_proba, threshold=thresholds))

In [None]:
# Per-label precision / recall / F1 for the test-split predictions computed
# in the previous cell. zero_division=0 makes undefined metrics report 0
# instead of emitting a warning.
print(classification_report(y_true, y_pred, target_names=train.labels, zero_division=0))

              precision    recall  f1-score   support

       elite       0.80      0.89      0.84       656
       centr       0.64      0.75      0.69       335
        left       0.74      0.70      0.72       276
       right       0.63      0.76      0.69       153

   micro avg       0.73      0.81      0.77      1420
   macro avg       0.70      0.78      0.74      1420
weighted avg       0.73      0.81      0.77      1420
 samples avg       0.41      0.41      0.40      1420



In [None]:
# Look up the largest `used_in_batch` value stored on any Sample row and
# print it.
with Session(engine) as session:
    newest_batch_query = session.query(bm.Sample).with_entities(
        func.max(bm.Sample.used_in_batch)
    )
    max_batch = newest_batch_query.scalar()
    print(max_batch)

In [None]:
# Persist the trained model under <src.PATH>/tmp.
# The target directory is created first so that a missing tmp/ folder cannot
# make the save fail after a full training run.
# NOTE: torch.save(model, ...) pickles the whole module object, so loading it
# later requires the same class definitions to be importable.
# NOTE(review): assumes src.PATH is a pathlib.Path (it is combined with "/"
# throughout this notebook) — confirm.
model_path = src.PATH / "tmp/model_min1_popideol_v9.1.model"
model_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model, model_path)