In [None]:
# Automatically re-import edited modules (e.g. the local `src` package) before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import ast
from collections import defaultdict

import numpy as np
import torch
from sklearn.metrics import classification_report
from sqlalchemy import func
from sqlalchemy.orm import Session
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import AutoTokenizer
from transformers import logging

import src
import src.db
from src.bert import module
from src.bert.dataset import PBertDataset
from src.bert.dataset import strategies
from src.db.models import bert_data as bm
from src.utils.metrics import custom_f1_score

In [None]:
# Database engine, used later with sqlalchemy.orm.Session. The "DB" argument is
# interpreted by the project helper src.db.make_engine (presumably it selects a
# connection config — TODO confirm).
engine = src.db.make_engine("DB")

In [None]:
# Only surface errors from the transformers library; hide info/warning chatter.
logging.set_verbosity_error()

# Pretrained checkpoint; the tokenizer is loaded from the same hub id as the encoder.
BASE_MODEL = "deepset/gbert-large"
TOKENIZER = BASE_MODEL

# Optimisation hyper-parameters.
LR = 1e-5
N_EPOCHS = 15
BATCH_SIZE = 8

# Label strategy handed to every dataset split; "multi_task" output format.
STRATEGY = strategies.MLPopBinIdeol(output_fmt="multi_task")

# Device used for training and inference.
DEVICE = "cuda"

# Passed as `exclude_coders` when loading the datasets (empty: keep everyone).
EXCLUDE_CODERS = []

In [None]:
def _load_split(relative_path):
    """Load one dataset split below ``src.PATH`` with the shared label strategy and coder exclusion."""
    return PBertDataset.from_disk(
        src.PATH / relative_path,
        label_strategy=STRATEGY,
        exclude_coders=EXCLUDE_CODERS,
    )


# The three pre-split CSVs, loaded in the order train -> test -> validation.
train = _load_split("data/bert/train.csv.zip")
test = _load_split("data/bert/test.csv.zip")
val = _load_split("data/bert/validation.csv.zip")

In [None]:
# Number of samples per split: (train, test, validation).
tuple(map(len, (train, test, val)))

(5277, 1759, 1759)

In [None]:
# Tokenizer and model come from the TOKENIZER / BASE_MODEL checkpoints set in the config cell.
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)

# One task per coder in the training data, each predicting train.num_labels labels.
model = module.BertMultiTaskMultiLabel(
    name=BASE_MODEL,
    num_tasks=train.num_coders,
    num_labels=train.num_labels,
)

In [None]:
# A single collate function (built from the training set + tokenizer) is shared by all splits.
collate_fn = train.create_collate_fn(tokenizer)

# Training batches are shuffled; evaluation loaders keep a fixed order and use batches of 64.
train_loader = DataLoader(train, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
valid_loader, test_loader = (
    DataLoader(split, batch_size=64, shuffle=False, collate_fn=collate_fn)
    for split in (val, test)
)

In [None]:
# Names of the coders present in the training data. The model is built with
# num_tasks=train.num_coders, presumably one task per entry here — TODO confirm.
train.coders

['grabsch', 'schadt', 'richter', 'riedel', 'coudry']

In [None]:
# Fine-tune the multi-task model. After every epoch, predict on the test and
# validation splits, score them with model.score and print one table row.
model.train()
model = model.to(DEVICE)
# NOTE(review): the seed is set after the model was instantiated in an earlier
# cell, so head initialisation is presumably not covered by it — TODO confirm.
model.set_seed()

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=LR,
    amsgrad=False,
    weight_decay=1e-2,
)

# Cosine decay from LR towards LR / 10 over N_EPOCHS scheduler steps (stepped once per epoch).
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=N_EPOCHS,
    eta_min=LR / 10,
)


def _collect_predictions(net, loader, split, preds):
    """Append gold votes, gold labels and predicted probabilities for ``loader`` to ``preds``.

    Values are appended under ``y_vote_<split>``, ``y_labels_<split>`` and
    ``y_probas_<split>``: the dict layout that is passed to ``model.score`` below.
    The caller runs this inside ``torch.inference_mode()``.
    """
    for batch in loader:
        encodings = batch["encodings"].to(DEVICE)
        predictions = net.predict_proba(encodings)
        preds[f"y_vote_{split}"].extend(batch["vote"])
        preds[f"y_labels_{split}"].extend(batch["labels"].detach().numpy())
        preds[f"y_probas_{split}"].extend(predictions)


print("epoch" + " " * 6 + "loss" + " " * 8 + "LR" + " " * 9 + "score" + " " * 6 + "score_meta")
print("-" * 65)

for epoch in range(1, N_EPOCHS + 1):
    epoch_loss = 0.0
    # LR in effect for this epoch (the scheduler only steps after the epoch).
    current_lr = optimizer.param_groups[0]["lr"]
    for batch in tqdm(train_loader, leave=False, desc=f"Epoch {epoch}"):
        encodings = batch["encodings"].to(DEVICE)
        labels = batch["labels"].to(DEVICE)
        loss, _ = model(**encodings, labels=labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    lr_scheduler.step()

    model.eval()
    preds = defaultdict(list)
    with torch.inference_mode():
        # NOTE(review): the test split is scored every epoch too; if these scores
        # are used for model/threshold selection that leaks test data — TODO confirm.
        _collect_predictions(model, test_loader, "test", preds)
        _collect_predictions(model, valid_loader, "val", preds)

    score = model.score(preds)
    check_score = model.score(preds, threshold_method="per_label")
    print(check_score)
    # len(train_loader) is the number of batches, so this is the mean per-batch loss.
    # max(..., 1) guards against an empty loader.
    epoch_loss /= max(len(train_loader), 1)
    print(
        f"{epoch:<10} {epoch_loss:<11.3f} {current_lr:<10.6f} {score['score']:<10.4f} {score['score_meta']:<10}"
    )
    model.train()

epoch      loss        LR         score      score_meta
-----------------------------------------------------------------


Epoch 1:   0%|          | 0/660 [00:00<?, ?it/s]

1          10.972      0.000010   0.4127     {0: {0: 0.3, 1: 0.3, 2: 0.15}, 1: {0: 0.3, 1: 0.3, 2: 0.2}, 2: {0: 0.5, 1: 0.4, 2: 0.15}, 3: {0: 0.55, 1: 0.45, 2: 0.25}, 4: {0: 0.45, 1: 0.4, 2: 0.2}}


Epoch 2:   0%|          | 0/660 [00:00<?, ?it/s]

2          8.407       0.000010   0.5486     {0: {0: 0.4, 1: 0.3, 2: 0.25}, 1: {0: 0.35, 1: 0.35, 2: 0.2}, 2: {0: 0.4, 1: 0.4, 2: 0.15}, 3: {0: 0.5, 1: 0.35, 2: 0.15}, 4: {0: 0.5, 1: 0.4, 2: 0.15}}


Epoch 3:   0%|          | 0/660 [00:00<?, ?it/s]

3          6.833       0.000010   0.5465     {0: {0: 0.4, 1: 0.2, 2: 0.2}, 1: {0: 0.3, 1: 0.4, 2: 0.4}, 2: {0: 0.35, 1: 0.3, 2: 0.25}, 3: {0: 0.35, 1: 0.4, 2: 0.3}, 4: {0: 0.3, 1: 0.4, 2: 0.25}}


Epoch 4:   0%|          | 0/660 [00:00<?, ?it/s]

4          5.879       0.000009   0.5750     {0: {0: 0.4, 1: 0.15, 2: 0.2}, 1: {0: 0.4, 1: 0.25, 2: 0.25}, 2: {0: 0.55, 1: 0.35, 2: 0.15}, 3: {0: 0.55, 1: 0.55, 2: 0.2}, 4: {0: 0.55, 1: 0.4, 2: 0.15}}


Epoch 5:   0%|          | 0/660 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Predict on the test split using the thresholds from the last completed training
# epoch. NOTE: `score` is left over from the training cell, so that cell must have
# run (at least one full epoch) in this kernel first.
model = model.eval()

# score["score_meta"] holds a dict literal string such as "{0: {0: 0.3, ...}}"
# (see the training log). Parse it as a literal instead of eval()-ing arbitrary
# code; accept an already-parsed dict as well.
thresh = score["score_meta"]
if isinstance(thresh, str):
    thresh = ast.literal_eval(thresh)

with torch.inference_mode():
    y_true = []
    y_proba = []
    for batch in test_loader:
        encodings = batch["encodings"].to(DEVICE)
        predictions = model.predict_proba(encodings)
        y_true.extend(batch["vote"])
        y_proba.extend(predictions)

# Turn each sample's probabilities into a label vector using the parsed thresholds.
y_pred = [model.vote(y, threshold=thresh) for y in y_proba]

In [None]:
# Per-label precision / recall / F1 on the test split; undefined ratios are reported as 0.
report = classification_report(y_true, y_pred, zero_division=0, target_names=train.labels)
print(report)

              precision    recall  f1-score   support

         pop       0.60      0.66      0.63       341
        left       0.42      0.59      0.49        76
       right       0.51      0.53      0.52        34

   micro avg       0.56      0.64      0.60       451
   macro avg       0.51      0.59      0.55       451
weighted avg       0.57      0.64      0.60       451
 samples avg       0.12      0.13      0.12       451



In [None]:
# custom_f1_score restricted to groups of label tuples. Tuple positions presumably
# follow train.labels (pop, left, right), matching the printed group names — TODO confirm.
f1_label_groups = {
    "all labels": [(1, 0, 0), (1, 1, 0), (1, 0, 1), (0, 0, 0)],
    "pop general": [(1, 0, 0)],
    "pop left": [(1, 1, 0)],
    "pop right": [(1, 0, 1)],
}
for group_name, group_labels in f1_label_groups.items():
    group_f1 = custom_f1_score(y_true, y_pred, labels=group_labels)
    print(f"{group_name}: {group_f1:.3f}")

all labels: 0.608
pop general: 0.537
pop left: 0.476
pop right: 0.515


In [None]:
# Largest value of Sample.used_in_batch currently stored in the database.
with Session(engine) as session:
    max_batch = session.query(func.max(bm.Sample.used_in_batch)).scalar()
print(max_batch)

9


In [None]:
# torch.save(model, src.PATH / f"tmp/model_v8.4.model")