In [None]:
import pandas as pd
from sqlalchemy import func
from sqlalchemy import update
from sqlalchemy.orm import Query
from sqlalchemy.orm import Session

import src
import src.db.models.bert_data as bm
import src.db.models.open_discourse as m
from src.db.connect import make_engine

In [None]:
# Engine used by pandas for the bulk read, plus an ORM session bound to the
# same engine for the batch-number lookup and the UPDATE that marks samples.
engine = make_engine("DB")
session = Session(bind=engine)

In [None]:
# Load every sample that has not been assigned to a batch yet
# (used_in_batch IS NULL), together with the faction abbreviation and the
# electoral term of the speech it belongs to.
# Took about 16 s to download in the original run.
unused_samples_query = (
    Query(bm.Sample)
    .join(bm.Sample.faction)
    .join(bm.Sample.speech)
    # `.is_(None)` renders `IS NULL`; a Python `== None` only works by
    # operator-overloading accident and is flagged by linters (E711).
    .filter(bm.Sample.used_in_batch.is_(None))
    .with_entities(
        bm.Sample.id,
        m.Faction.abbreviation,
        m.Speech.electoral_term,
        bm.Sample.text,
        bm.Sample.pop_dict_score,
    )
)

sample_df = pd.read_sql(unused_samples_query.statement, engine)

## Sample Sentences


In [None]:
# Maximum number of sentences drawn per (electoral term, faction) group,
# separately for rows with pop_dict_score == 1 ("positive") and
# pop_dict_score == 0 ("negative"). Smaller groups are taken in full.
MAX_POS_SAMPLES_PER_GROUP = 100
MAX_NEG_SAMPLES_PER_GROUP = 80
# Fixed seed so that re-running the notebook on the same pool draws the
# same sentences.
RANDOM_STATE = 42


def _cap_group_sizes(df, max_per_group, random_state):
    """Split ``df`` into (electoral_term, abbreviation) groups and down-sample.

    Args:
        df: frame with ``electoral_term`` and ``abbreviation`` columns.
        max_per_group: groups larger than this are randomly sampled down
            to exactly this many rows; smaller groups are kept whole.
        random_state: seed forwarded to ``DataFrame.sample``.

    Returns:
        A list with one frame per group, in sorted group-key order.
    """
    capped_groups = []
    for _, group in df.groupby(["electoral_term", "abbreviation"]):
        if len(group) > max_per_group:
            group = group.sample(max_per_group, random_state=random_state)
        capped_groups.append(group)
    return capped_groups


subgroup_group_samples = _cap_group_sizes(
    sample_df[sample_df.pop_dict_score == 1],
    MAX_POS_SAMPLES_PER_GROUP,
    RANDOM_STATE,
) + _cap_group_sizes(
    sample_df[sample_df.pop_dict_score == 0],
    MAX_NEG_SAMPLES_PER_GROUP,
    RANDOM_STATE,
)

# pd.concat([]) would fail with an opaque "No objects to concatenate".
if not subgroup_group_samples:
    raise ValueError("no unused samples left to draw a new batch from")

final_df = pd.concat(subgroup_group_samples)

In [None]:
# The new batch number is one above the highest batch recorded so far.
# MAX() yields NULL (None) while no sample has been assigned a batch yet,
# in which case numbering starts at 1.
highest_recorded_batch = session.query(func.max(bm.Sample.used_in_batch)).scalar()
current_max_batch = 0 if highest_recorded_batch is None else highest_recorded_batch
new_batch_no = current_max_batch + 1

In [None]:
# Stamp the drawn rows with the new batch number so later runs skip them.
# Only rows that are still unassigned are touched. If `final_df` is stale
# (e.g. the sampling cell was re-run without reloading `sample_df`), rows
# of an earlier batch are not silently moved. The rowcount check then
# aborts instead of committing a partial batch.
sampled_ids = final_df["id"].unique().tolist()
sample_table = bm.Sample.__table__
mark_batch_stmt = (
    update(sample_table)
    .where(sample_table.c.id.in_(sampled_ids))
    .where(sample_table.c.used_in_batch.is_(None))
    .values(used_in_batch=new_batch_no)
)

try:
    mark_batch_result = session.execute(mark_batch_stmt)
    if mark_batch_result.rowcount != len(sampled_ids):
        raise RuntimeError(
            f"expected to mark {len(sampled_ids)} samples with batch "
            f"{new_batch_no}, but {mark_batch_result.rowcount} rows were "
            "updated; some samples may already belong to another batch"
        )
    session.commit()
except Exception:
    # leave the session usable for the following cells
    session.rollback()
    raise

<sqlalchemy.engine.cursor.CursorResult at 0x7f98702d3fd0>

# Summary


In [None]:
# Overall counts of the not-yet-used pool per electoral term, faction and
# pop_dict_score value.
pool_group_keys = ["electoral_term", "abbreviation", "pop_dict_score"]
sample_df.groupby(by=pool_group_keys)["id"].count()

electoral_term  abbreviation  pop_dict_score
18              CDU/CSU       False             225628
                              True                1991
                DIE LINKE.    False             100319
                              True                1529
                Fraktionslos  False                 42
                Grüne         False             108277
                              True                1057
                SPD           False             152581
                              True                1425
19              AfD           False              93448
                              True                2838
                CDU/CSU       False             200950
                              True                1958
                DIE LINKE.    False              75392
                              True                1350
                FDP           False              81677
                              True                 860
                Frak

In [None]:
# Counts of the drawn batch per electoral term, faction and pop_dict_score
# value; each cell is capped by the per-group maxima used when sampling.
batch_group_keys = ["electoral_term", "abbreviation", "pop_dict_score"]
final_df.groupby(by=batch_group_keys)["id"].count()

electoral_term  abbreviation  pop_dict_score
18              CDU/CSU       False              80
                              True              100
                DIE LINKE.    False              80
                              True              100
                Fraktionslos  False              42
                Grüne         False              80
                              True              100
                SPD           False              80
                              True              100
19              AfD           False              80
                              True              100
                CDU/CSU       False              80
                              True              100
                DIE LINKE.    False              80
                              True              100
                FDP           False              80
                              True              100
                Fraktionslos  False              80
                   

In [None]:
# NOTE(review): the path is fixed, so each new batch overwrites the CSV of
# the previous one; the batch membership itself lives in used_in_batch.
sample_csv_path = src.PATH / "data" / "sentence_sample.csv"
final_df.to_csv(sample_csv_path, index=False)