In [None]:
import torch
from sqlalchemy.orm import Session
from sqlalchemy.orm import close_all_sessions
from tqdm.auto import tqdm
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
from transformers import logging

import src.db.models.bert_data as bm
from src.db.connect import make_engine

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU so the
# notebook still runs (slowly) on machines without a GPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Silence transformers' informational/warning output; only errors are reported.
logging.set_verbosity_error()

# Build an engine for the "DB" connection (see src.db.connect.make_engine) and
# create any tables declared on the ORM base that do not exist yet.
engine = make_engine("DB")
bm.Base.metadata.create_all(bind=engine, checkfirst=True)

In [None]:
# True  -> throw away every stored prediction and score all samples again.
# False -> keep existing predictions; iter_batches() only selects samples that
#          do not have a prediction yet, so the run resumes where it stopped.
new_run = True

if new_run:
    # Close every open ORM session first — presumably so that no lingering
    # connection/transaction blocks the DROP TABLE below.
    close_all_sessions()
    prediction_table = bm.Prediction.__table__
    # Recreate only the predictions table; all other tables are left alone.
    bm.Base.metadata.drop_all(engine, tables=[prediction_table])
    bm.Base.metadata.create_all(engine, tables=[prediction_table])

In [None]:
# Load the PopBERT tokenizer and its sequence-classification model from the
# "luerhard/PopBERT" checkpoint, then place the model on the selected device.
tokenizer = AutoTokenizer.from_pretrained("luerhard/PopBERT")
model = AutoModelForSequenceClassification.from_pretrained("luerhard/PopBERT").to(DEVICE)

In [None]:
def iter_batches(engine, yield_per=100):
    """Yield batches of ``(id, text)`` rows for samples that have no prediction yet.

    Selects ``bm.Sample`` rows for which no ``bm.Prediction`` with a matching
    ``sample_id`` exists, reads them in chunks and groups them into lists of at
    most ``yield_per`` rows. A ``tqdm`` progress bar shows how many of the
    selected rows have been read.

    Args:
        engine: SQLAlchemy engine used to open the (read) session.
        yield_per: Maximum number of rows per yielded batch; also passed to
            ``Query.yield_per`` as the ORM fetch chunk size.

    Yields:
        Non-empty lists of ``(id, text)`` rows. The last batch may be shorter
        than ``yield_per``; nothing is yielded when no sample is pending.
    """
    with Session(engine) as s:
        # Correlated subquery: "a Prediction already exists for this Sample".
        existing_preds = s.query(bm.Prediction).filter(bm.Prediction.sample_id == bm.Sample.id)

        query = (
            s.query(bm.Sample)
            .filter(~existing_preds.exists())
            .with_entities(
                bm.Sample.id,
                bm.Sample.text,
            )
        )

        # Iterate inside the `with` block so the session that owns the query is
        # still open while rows are streamed, and is closed once we are done.
        batch = []
        for row in tqdm(query.yield_per(yield_per), total=query.count()):
            batch.append(row)
            if len(batch) >= yield_per:
                yield batch
                # Start a new list rather than clearing the one just yielded,
                # so a caller that keeps a reference to it is not affected.
                batch = []

        # Emit the trailing partial batch, but never an empty one: callers
        # unpack batches with zip(*batch), which fails on an empty list.
        if batch:
            yield batch

In [None]:
# Number of samples scored per forward pass / per database insert.
BATCH_SIZE = 500

# Score every pending sample and store its four label probabilities.
with torch.inference_mode():
    for batch in iter_batches(engine, BATCH_SIZE):
        if not batch:
            # Nothing to score (e.g. an empty trailing batch).
            continue

        ids, texts = zip(*batch)

        # Texts in a batch have different lengths: pad to the longest one and
        # truncate anything beyond the model's maximum input length, otherwise
        # the tokenizer cannot build a single rectangular tensor.
        encodings = tokenizer(
            list(texts),
            padding=True,
            truncation=True,
            return_tensors="pt",
        ).to(DEVICE)

        out = model(**encodings)
        # Element-wise sigmoid (not softmax): each logit becomes its own
        # independent probability. No .detach() needed under inference_mode.
        probas = torch.sigmoid(out.logits).cpu().numpy()

        # Column order of the model output -> Prediction column names.
        preds = [
            {
                "sample_id": id_,
                "elite": float(pred[0]),
                "pplcentr": float(pred[1]),
                "left": float(pred[2]),
                "right": float(pred[3]),
            }
            for id_, pred in zip(ids, probas)
        ]

        # Persist each batch in its own short transaction so progress survives
        # an interruption (iter_batches skips already-predicted samples).
        with Session(engine) as s:
            s.bulk_insert_mappings(bm.Prediction, preds)
            s.commit()