In [None]:
# Re-import edited project modules (e.g. `src.*`) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
from functools import lru_cache

import numpy as np
import pandas as pd
import torch
from sklearn.metrics import classification_report
from sqlalchemy import func
from sqlalchemy.orm import Query
from sqlalchemy.orm import Session
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from tqdm.auto import tqdm
from transformers import AutoTokenizer
from transformers import logging

import src
import src.db.models.bert_data as bm
from src.bert.dataset import PBertDataset
from src.bert.dataset import strategies

In [None]:
# Show long texts with up to 2048 characters per cell and cap tables at 100 rows.
for _option, _value in (("display.max_colwidth", 2048), ("display.max_rows", 100)):
    pd.set_option(_option, _value)

In [None]:
# Only let the transformers library log errors.
logging.set_verbosity_error()

engine = src.db.connect.make_engine("DB")

# Scratch directory for models and exported batches.
tmpdir = src.PATH / "tmp"
tmpdir.mkdir(exist_ok=True)

# Fall back to CPU when no GPU is available instead of failing later at `.to(DEVICE)`.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Label scheme for the test data; its `labels` name the model outputs in the reports below.
LABEL_STRATEGY = strategies.MLPopBinIdeol(output_fmt="multi_task")

In [None]:
# Load the tokenizer and the fine-tuned model.
# NOTE(review): assumes model_v8.4 was fine-tuned from deepset/gbert-large — confirm.
tokenizer = AutoTokenizer.from_pretrained("deepset/gbert-large")
# torch.load unpickles a whole model object, which can execute arbitrary code:
# only load checkpoints produced by this project. On PyTorch >= 2.6 a full-model
# pickle additionally needs `weights_only=False`.
# map_location places the weights on DEVICE, where the encodings are sent below.
model = torch.load(tmpdir / "model_v8.4.model", map_location=DEVICE)
model = model.eval()

In [None]:
@lru_cache(maxsize=1)
def load_unlabeled_data(engine):
    """Load id and text of every sample not yet assigned to a batch.

    "Not assigned" means ``used_in_batch`` is NULL.

    The result is memoized per ``engine`` via ``lru_cache``. Repeated calls in this
    kernel therefore return the *same* DataFrame object, even after
    ``used_in_batch`` has been updated in the database. Call
    ``load_unlabeled_data.cache_clear()`` to force a fresh query. Callers should
    not mutate the returned frame, because it is shared through the cache.

    Args:
        engine: SQLAlchemy engine used to run the query.

    Returns:
        DataFrame with columns ``id`` and ``text``.
    """
    query = (
        Query(bm.Sample)
        # `.is_(None)` renders `IS NULL` (same SQL as `== None`) without the E711 lint trap.
        .filter(bm.Sample.used_in_batch.is_(None))
        .with_entities(bm.Sample.id, bm.Sample.text)
    )

    with engine.connect() as conn:
        df = pd.read_sql(query.statement, conn)

    return df

In [None]:
class RawDataset(Dataset):
    """Map-style dataset over a DataFrame that yields ``{"id": ..., "text": ...}`` dicts.

    ``df`` must provide ``id`` and ``text`` columns. Items are addressed by
    position (``iloc``), not by index label.
    """

    def __init__(self, df):
        self.df = df

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        return {"id": record["id"], "text": record["text"]}

# Find decision thresholds on the test set

In [None]:
# Held-out test split, loaded with LABEL_STRATEGY. It is iterated in file order
# (shuffle=False) so predictions stay aligned with the gold labels.
test_data = PBertDataset.from_disk(
    src.PATH / "data/bert/test.csv.zip",
    label_strategy=LABEL_STRATEGY,
)
test_collate = test_data.create_collate_fn(tokenizer)
test_loader = DataLoader(
    test_data,
    batch_size=128,
    shuffle=False,
    collate_fn=test_collate,
)

In [None]:
# Score the test set and collect gold labels, texts and predicted probabilities
# in matching order for threshold tuning.
# NOTE(review): presumably `predict_proba` returns one entry per sample — confirm.
y_true, texts, probas = [], [], []

with torch.inference_mode():
    for test_batch in tqdm(test_loader, leave=False):
        test_encoding = test_batch["encodings"].to(DEVICE)
        batch_proba = model.predict_proba(test_encoding)

        texts.extend(test_batch["text"])
        probas.extend(batch_proba)
        y_true.extend(test_batch["vote"])

  0%|          | 0/18 [00:00<?, ?it/s]

In [None]:
# Tune decision thresholds on the test-set probabilities; judging by the printed
# output, the result maps each label index to its threshold.
thresholds = model.find_thresholds(y_true, probas)
print(thresholds)

{0: 0.4, 1: 0.6, 2: 0.3}


In [None]:
# Turn each sample's probabilities into label predictions using the tuned thresholds.
y_pred = []
for sample_proba in probas:
    y_pred.append(model.vote(sample_proba, threshold=thresholds))

In [None]:
test_report = classification_report(
    y_true,
    y_pred,
    target_names=LABEL_STRATEGY.labels,
    zero_division=0,
)
print(test_report)

              precision    recall  f1-score   support

         pop       0.58      0.70      0.64       337
        left       0.56      0.60      0.58        47
       right       0.62      0.72      0.67        36

   micro avg       0.58      0.69      0.63       420
   macro avg       0.59      0.67      0.63       420
weighted avg       0.58      0.69      0.63       420
 samples avg       0.11      0.10      0.10       420



In [None]:
def _constrict(label_vector):
    """Return a copy of an indicator vector with all labels after the first zeroed
    out whenever the first label (pop) is 0.

    The input is copied, so the caller's vectors (the entries of ``y_pred`` /
    ``y_true``) are never modified. Vectors of any length are supported.
    """
    constricted = list(label_vector)
    if constricted[0] == 0:
        constricted[1:] = [0] * (len(constricted) - 1)
    return constricted


# Only count ideology labels (left/right) on populist samples: zero them elsewhere,
# for both predictions and gold labels. The previous version assigned into the
# elements of y_pred in place, which silently changed y_pred itself.
y_pred_constrict = [_constrict(pred) for pred in y_pred]
y_true_constrict = [_constrict(gold) for gold in y_true]

In [None]:
constrict_report = classification_report(
    y_true=y_true_constrict,
    y_pred=y_pred_constrict,
    target_names=LABEL_STRATEGY.labels,
    zero_division=0,
)
print(constrict_report)

              precision    recall  f1-score   support

         pop       0.58      0.70      0.64       337
        left       0.53      0.59      0.56        44
       right       0.64      0.71      0.68        35

   micro avg       0.58      0.69      0.63       416
   macro avg       0.58      0.67      0.62       416
weighted avg       0.58      0.69      0.63       416
 samples avg       0.11      0.10      0.10       416



# Get new samples from the unlabeled pool


In [None]:
# Scoring the full unlabeled pool takes too long, so score a random subset of it.
POOL_SAMPLE_SIZE = 400_000
unlabeled = load_unlabeled_data(engine)
# Clamp so a pool smaller than POOL_SAMPLE_SIZE does not make `.sample` raise.
# A fixed random_state makes the subset reproducible for the same pool.
X_pool = unlabeled.sample(min(POOL_SAMPLE_SIZE, len(unlabeled)), random_state=42)

In [None]:
dataset = RawDataset(X_pool)


def collate_fn(batch):
    """Tokenize a list of ``RawDataset`` items into one padded batch.

    Returns a dict with the batch's ``id`` and ``text`` lists (in item order) and
    the tokenizer ``encodings`` as PyTorch tensors.
    """
    text = [d["text"] for d in batch]
    ids = [d["id"] for d in batch]
    # truncation=True caps inputs at the tokenizer's model_max_length. Without it,
    # one overlong text could exceed a BERT-style model's position limit and fail
    # the whole batch.
    encodings = tokenizer(text, padding=True, truncation=True, return_tensors="pt")

    return {"id": ids, "text": text, "encodings": encodings}


data_loader = DataLoader(dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)

In [None]:
# Score the sampled pool, keeping ids, texts and probabilities aligned by position.
# NOTE(review): each per-sample entry is averaged over its first axis here, while
# the test-set loop feeds raw entries to `find_thresholds` / `vote`. Confirm the
# two representations are meant to differ.
ids, texts, probas = [], [], []

with torch.inference_mode():
    for pool_batch in tqdm(data_loader, leave=False):
        pool_encoding = pool_batch["encodings"].to(DEVICE)
        batch_proba = model.predict_proba(pool_encoding)
        sample_means = [np.mean(entry, axis=0) for entry in batch_proba]

        ids.extend(pool_batch["id"])
        texts.extend(pool_batch["text"])
        probas.extend(sample_means)

  0%|          | 0/6250 [00:00<?, ?it/s]

In [None]:
# One row per scored pool sample: id and text, then one probability column per label.
pool_meta = pd.DataFrame({"id": ids, "text": texts})
pool_scores = pd.DataFrame(np.array(probas), columns=LABEL_STRATEGY.labels)
df = pd.concat([pool_meta, pool_scores], axis=1)

In [None]:
# Number of pool samples with a left probability above 0.15.
int(df["left"].gt(0.15).sum())

2944

In [None]:
# Order pool samples from lowest to highest left probability.
df = df.sort_values(by="left")

In [None]:
# The five samples with the lowest left probability.
df.head(n=5)

Unnamed: 0,id,text,pop,left,right
84598,1051119,"Ein besseres Beispiel für eine kompetente, überzeugende Sacharbeit der AfD kann es gar nicht geben.",0.042575,0.002275,0.003405
70549,906527,Sie sollten sich auch beim Beschimpfen anderer Parteien zurückhalten.,0.021509,0.002277,0.002582
272722,675347,Wir müssen uns noch stärker für ein funktionierendes Gemeinwesen einsetzen.,0.057713,0.002283,0.003826
152116,1254361,"Kolleginnen und Kollegen aus dem Gesundheitsausschuss, wie Sie wissen, ist diese Entwicklung keineswegs ein Hirngespinst der Linken.",0.027731,0.002293,0.004423
125955,921458,"Ich habe Sie, werte Kollegen von der AfD, überwiegend breit grinsend sitzen sehen.",0.02478,0.002296,0.00307


In [None]:
# Spot-check ten random samples from a narrow band of left probabilities.
left_band = df.loc[df["left"].between(0.30, 0.32)]
left_band.sample(n=10)

Unnamed: 0,id,text,pop,left,right
335091,602753,"Wir kämpfen gegen die Menschen, die für viele Tausend Euros und Dollars die Ärmsten in Lastwagen sperren und über die Grenzen bringen oder Menschen auf Booten über das Mittelmeer schicken.",0.375506,0.311937,0.017949
63346,776220,Die sozialen Sicherungssysteme müssen Armut von Kindern und Jugendlichen ausschließen.,0.275237,0.31897,0.011145
390803,1097953,Gegen solche privaten Konzerngerichte sind viele Menschen in Europa bei Investitionsabkommen wie TTIP oder CETA auf die Straße gegangen.,0.323756,0.316602,0.013021
63244,256147,"Alle diese Krisen haben eine gemeinsame Ursache, nämlich ein System der ruinösen Konkurrenz auf Kosten von Mensch und Natur.",0.392722,0.307234,0.01078
23453,78369,"Es geht vor allem um die Menschen, die ausgebeutet werden, und deshalb muss ganz grundsätzlich das System, das Geschäftsmodell Fleischbranche, kritisiert werden.",0.365471,0.30354,0.006469
200415,1050323,"Die Leute sollen auch im Jobcenter die Solidarität der Gesellschaft erfahren, gerade die, die lange draußen sind.",0.316769,0.310225,0.010979
176000,634602,"Es ist verkommen, dass eine Industrie und die Politik Grenzwerte nicht ernst nehmen und so tun, als ob man sie nicht einhalten müsste.",0.491772,0.310904,0.005109
46906,1107758,"Wenn es beim Kindergeld überhaupt ein Thema gäbe, über das in diesem Hause dringend geredet und bei dem auch endlich mal gehandelt werden müsste, dann das, dass die Ärmsten und Bedürftigsten, die bisher von dieser Leistung ausgeschlossen sind, endlich in den Genuss der Leistung kommen.",0.272386,0.319198,0.012868
334577,1127065,Die Deutsche Umwelthilfe hingegen nutzt den Rechtsstaat einfach nur aus und beschmutzt nebenbei auch noch die seriöse Arbeit mancher sehr guter Umweltverbände.,0.804843,0.300303,0.036129
347874,463354,"Jeder hat eine faire Chance verdient, auch Menschen mit Behinderung.",0.256635,0.319026,0.013598


In [None]:
# Candidates with pop probability in [0.3, 0.99] and right probability in [0.05, 0.45].
# Clamp n so a short candidate list does not make `.sample` raise; seeded for reproducibility.
right_candidates = df[df["pop"].between(0.3, 0.99) & df["right"].between(0.05, 0.45)]
right_sample = right_candidates.sample(min(650, len(right_candidates)), random_state=42)

In [None]:
# Candidates with pop probability in [0.4, 0.90] and left probability in [0.2, 0.60].
# Clamp n so a short candidate list does not make `.sample` raise; seeded for reproducibility.
left_candidates = df[df["pop"].between(0.4, 0.90) & df["left"].between(0.2, 0.60)]
left_sample = left_candidates.sample(min(650, len(left_candidates)), random_state=42)

In [None]:
# Low-to-moderate pop probability (0.03 to 0.3) candidates.
# Clamp n so a short candidate list does not make `.sample` raise; seeded for reproducibility.
pop_candidates = df[df["pop"].between(0.03, 0.3)]
random_pop_sample = pop_candidates.sample(min(200, len(pop_candidates)), random_state=42)

## Select cases

In [None]:
# Stack the three candidate groups; they can overlap, so duplicates are removed later.
selection = pd.concat(objs=[right_sample, left_sample, random_pop_sample], axis=0)

In [None]:
# (rows, columns) of the stacked selection before de-duplication.
len(selection), len(selection.columns)

(1500, 5)

In [None]:
# The three candidate groups can overlap (a row may pass more than one filter), so
# keep one row per sample id, the key written back to the database further down.
# Keying on `id` also avoids comparing the long text and float columns.
selection = selection.drop_duplicates(subset="id")
assert selection["id"].is_unique

In [None]:
# (rows, columns) after de-duplication.
len(selection), len(selection.columns)

(1490, 5)

## Load Gründl dictionary scores (`pop_dict_score`) for comparison


In [None]:
# Fetch the dictionary-based score for every selected sample and attach it to the
# selection (inner join on id).
dict_score_query = (
    Query(bm.Sample)
    .filter(bm.Sample.id.in_(selection.id.tolist()))
    .with_entities(bm.Sample.id, bm.Sample.pop_dict_score)
)

with engine.connect() as conn:
    dict_scores = pd.read_sql_query(dict_score_query.statement, conn)

gruendl = selection.merge(dict_scores, on="id")

In [None]:
# Number of selected samples per pop_dict_score value.
gruendl.groupby(by="pop_dict_score")["id"].count()

pop_dict_score
False    1403
True       87
Name: id, dtype: int64

# Export new batch

In [None]:
# Highest batch number assigned so far (MAX over used_in_batch).
with Session(engine) as session:
    max_batch = session.query(func.max(bm.Sample.used_in_batch)).scalar()
print(max_batch)

8


In [None]:
# `func.max` yields None while no sample belongs to any batch, and `None + 1` would raise.
# NOTE(review): numbering the very first batch 1 is an assumption — confirm the convention.
new_batch = 1 if max_batch is None else max_batch + 1
new_batch

9

In [None]:
# Archive the (unshuffled) selection for this batch before the database is updated.
batch_parquet = tmpdir / f"active_learning_batch_{new_batch}.parquet"
selection.to_parquet(batch_parquet)

In [None]:
# Mark the selected samples with this batch number, so load_unlabeled_data's
# `used_in_batch IS NULL` filter excludes them from now on.
with Session(engine) as s:
    # `.tolist()` hands the DB driver plain Python ints rather than numpy int64
    # scalars, which some DBAPI drivers cannot bind.
    n_updated = (
        s.query(bm.Sample)
        .filter(bm.Sample.id.in_(selection["id"].tolist()))
        .update({"used_in_batch": new_batch})
    )
    # Raising here leaves the block without a commit, so the session rolls back and
    # a partially matched batch is never persisted.
    assert n_updated == len(selection), f"updated {n_updated} rows, expected {len(selection)}"
    s.commit()

In [None]:
# Shuffle so the right/left/random candidate groups (stacked in that order) are
# interleaved in the export; seeded so the order is reproducible.
selection = selection.sample(frac=1, random_state=42)

In [None]:
# Shuffling must not change the (rows, columns) shape.
len(selection), len(selection.columns)

(1490, 5)

In [None]:
# Export as JSON Lines: one {"text", "label", "id"} record per sample, with an empty
# label to be filled in. `assign` adds the column on a copy, so `selection` keeps
# the shape displayed above instead of being mutated in place.
annotation_records = selection.assign(label="")[["text", "label", "id"]]
with open(tmpdir / f"active_learning_batch_{new_batch}.jsonl", "w", encoding="utf-8") as file:
    # force_ascii=False writes non-ASCII characters (e.g. umlauts) as UTF-8, not \u escapes.
    annotation_records.to_json(file, orient="records", lines=True, force_ascii=False)

In [None]:
# Signal that the batch export finished.
print("...", "done!", sep="")

...done!
