In [5]:
import pandas as pd
import torch
from sqlalchemy.orm import Session
from sqlalchemy.orm import close_all_sessions
from tqdm.auto import tqdm
from transformers import AutoModel
from transformers import AutoTokenizer
from transformers import logging

import src
import src.db.models.bert_data as bm
import src.db.models.open_discourse as od
from src.bert.dataset import PBertDataset
from src.db.connect import make_engine

In [6]:
DEVICE = "cuda"

In [7]:
# set transformers logging v to error
logging.set_verbosity_error()

engine = make_engine("DB")
bm.Base.metadata.create_all(engine, checkfirst=True)

In [12]:
new_run = True

if new_run:
    close_all_sessions()
    bm.Base.metadata.drop_all(engine, tables=[bm.Prediction.__table__])
    bm.Base.metadata.create_all(engine, tables=[bm.Prediction.__table__])

In [16]:
commit_hash = "0e6cb925223c62c00595cc20d37fdc7ee4c8b1e1" 
tokenizer = AutoTokenizer.from_pretrained("luerhard/PopBERT")
model = AutoModel.from_pretrained("luerhard/PopBERT", trust_remote_code=True, revision=commit_hash)
model = model.to("cuda")

In [17]:
def iter_batches(engine, yield_per=100):
    with Session(engine) as s:
        existing_preds = s.query(bm.Prediction).filter(bm.Prediction.sample_id == bm.Sample.id)

        query = (
            s.query(bm.Sample)
            .filter(~existing_preds.exists())
            .with_entities(
                bm.Sample.id,
                bm.Sample.text,
            )
        ).limit(None)

    cache = []
    for row in tqdm(query.yield_per(yield_per), total=query.count()):
        cache.append(row)
        if len(cache) >= yield_per:
            yield cache
            cache.clear()
    else:
        yield cache

In [18]:
with torch.inference_mode():
    for batch in iter_batches(engine, 500):
        ids, text = list(zip(*batch))
        encodings = tokenizer(text, padding=True, return_tensors="pt").to(DEVICE)

        _, probas = model(**encodings)
        probas = probas.detach().to("cpu").numpy()

        preds = []
        for (
            id_,
            pred,
        ) in zip(ids, probas):
            row = {
                "sample_id": id_,
                "elite": pred[0].astype(float),
                "pplcentr": pred[1].astype(float),
                "left": pred[2].astype(float),
                "right": pred[3].astype(float),
            }
            preds.append(row)

        with Session(engine) as s:
            s.bulk_insert_mappings(bm.Prediction, preds)
            s.commit()

  0%|          | 0/1266245 [00:00<?, ?it/s]

OutOfMemoryError: CUDA out of memory. Tried to allocate 478.00 MiB (GPU 0; 7.75 GiB total capacity; 3.33 GiB already allocated; 203.19 MiB free; 3.92 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF