# In[ ]:
import src
from src.db.connect import make_engine
import src.db.models.bert_data as bm
import src.db.models.open_discourse as od
import pandas as pd

from sqlalchemy.orm import Query


# NOTE(review): make_engine is project-defined; "DB" presumably selects a
# configured connection — confirm in src.db.connect.
engine = make_engine("DB")

# WHERE conditions: electoral term 19 only, exclude politician_id == -1
# (presumably an unknown/unattributed speaker — confirm), and exclude rows
# whose faction abbreviation is "Fraktionslos".
speech_filters = (
    od.Speech.electoral_term.in_([19]),
    od.Speech.politician_id != -1,
    od.Faction.abbreviation != "Fraktionslos",
)

# Columns of the result frame; .label() sets the DataFrame column name.
selected_columns = (
    od.Speech.session,
    od.Speech.electoral_term,
    od.Faction.abbreviation.label("faction"),
    od.Speech.politician_id,
    od.Speech.id.label("speech_id"),
    bm.Sample.id.label("sample_id"),
    bm.Prediction.left,
    bm.Prediction.right,
    bm.Sample.text,
)

# Inner joins to Sample, Prediction and Faction (ON clauses are inferred by
# SQLAlchemy), ordered by speech id and then sample id.
query = (
    Query(od.Speech)
    .join(bm.Sample)
    .join(bm.Prediction)
    .join(od.Faction)
    .filter(*speech_filters)
    .with_entities(*selected_columns)
    .order_by(od.Speech.id.asc(), bm.Sample.id.asc())
)

# Materialise the joined rows as a DataFrame.
with engine.connect() as conn:
    raw = pd.read_sql(query.statement, conn)

# One group per politician per session within an electoral term.
speeches = raw.groupby(["electoral_term", "session", "politician_id"])

# Module-level state for the grouping loop:
#   cache     - DataFrames collected for the final concat
#   group_ix  - last block number issued by group_counter
#   new_block - True when group_counter should start a new block
cache = []
group_ix, new_block = 0, True


def group_counter(gap):
    """Return the block number for one row, given its speech-id gap.

    Reads and updates the module-level globals ``group_ix`` (last block
    number issued) and ``new_block`` (flag requesting a fresh block).
    A new block number is issued when ``gap > 7`` or when ``new_block`` is
    set; in that case ``new_block`` is cleared and ``group_ix`` incremented.
    Otherwise the current ``group_ix`` is returned unchanged.

    A NaN ``gap`` compares False against 7, so for such a row only
    ``new_block`` can start a block.
    """
    global group_ix, new_block
    # Evaluate the gap test first, matching the original short-circuit order.
    if gap > 7 or new_block:
        group_ix += 1
        new_block = False
    return group_ix


# Split each (electoral_term, session, politician_id) group into blocks:
# group_counter starts a new block at the first row of every group (via
# new_block) and whenever speech_id jumps by more than 7 between rows.
# Block numbers come from a single global counter, so they are unique and
# increasing across all groups.
for _, speech in speeches:
    # Signal group_counter that the next row begins a new group.
    new_block = True
    if len(speech) == 1:
        # A lone row always opens its own block; no gap to compute.
        # assign() returns a new frame instead of mutating the groupby chunk.
        cache.append(speech.assign(group=group_counter(10)))
        continue
    # Sort on speech_id and then sample_id. A multi-key sort is stable, so the
    # samples of each speech keep their SQL order. The previous single-key
    # default quicksort could shuffle rows that share a speech_id.
    speech = speech.sort_values(["speech_id", "sample_id"]).reset_index(drop=True)
    speech["prev_id"] = speech["speech_id"].shift(1)
    # The first row gets a NaN gap; group_counter relies on new_block for it.
    speech["gap"] = speech["speech_id"] - speech["prev_id"]
    # group_counter is stateful, so it must run row by row in sorted order.
    speech["group"] = speech["gap"].apply(group_counter)
    # Block numbers are non-decreasing along the sorted rows, so appending the
    # whole frame gives the same rows in the same order as appending each
    # block separately.
    cache.append(speech)

df = pd.concat(cache).reset_index(drop=True)

df.to_csv(src.PATH / "wordfish.csv")