In [2]:
import os
import polars as pl

## Load all of the label files for configuration 00.
## Paths are relative to the notebook's folder, matching the
## "../data_labeling/raw_results/..." paths used when writing results.
LABELED_DIR = os.path.join("..", "data_labeling", "labeled_results")
KEY_COLS = ["query_text", "reddit_name"]


def load_labels(labeler: str, config: str = "00") -> pl.DataFrame:
    """Read one labeler's parquet file for a configuration, sorted by KEY_COLS.

    Parameters
    ----------
    labeler : str
        Labeler initials used in the file name, e.g. "mo".
    config : str
        Configuration id used in the file name (default "00").

    Returns
    -------
    pl.DataFrame
        Contents of ``labeled_{config}_{labeler}.parquet`` sorted by
        query_text, then reddit_name.
    """
    path = os.path.join(LABELED_DIR, f"labeled_{config}_{labeler}.parquet")
    return pl.read_parquet(path).sort(by=KEY_COLS)


mo_labels_00 = load_labels("mo")
kk_labels_00 = load_labels("kk")
kp_labels_00 = load_labels("kp")
dr_labels_00 = load_labels("dr")
sr_labels_00 = load_labels("sr")

## The other labelers' columns are attached to mo's frame *by position*, so
## every file must list the same (query_text, reddit_name) pairs in the same
## order. Otherwise labels would be silently paired with the wrong result,
## so fail loudly instead.
reference_keys = mo_labels_00.select(KEY_COLS).rows()
for labeler, labels in {
    "kk": kk_labels_00,
    "kp": kp_labels_00,
    "dr": dr_labels_00,
    "sr": sr_labels_00,
}.items():
    assert labels.select(KEY_COLS).rows() == reference_keys, (
        f"labeled_00_{labeler}.parquet is not row-aligned with labeled_00_mo.parquet"
    )

## Gather all the labels into one dataframe
df = mo_labels_00.clone()
df = df.with_columns(
    kk_labels_00["kk_label"],
    kp_labels_00["kp_label"],
    dr_labels_00["dr_label"],
    sr_labels_00["sr_label"],
)

## Keep track of the columns with votes
vote_cols = ["mo_label", "kk_label", "kp_label", "dr_label", "sr_label"]

## Add a column holding the list of votes actually cast, in vote_cols order.
## concat_list puts a null element where a labeler gave no vote; those are
## dropped, so a row nobody labeled ends up with an empty list.
df = df.with_columns(
    pl.concat_list(vote_cols)
    .list.eval(pl.element().drop_nulls())
    .alias("votes")
)

## Set up the voting machine
def voting_machine(row: "pl.Series") -> "int | None":
    """Collapse one result's relevance votes into a single rating by majority.

    Parameters
    ----------
    row : pl.Series or sequence
        The votes cast for one result. ``map_elements`` over the ``votes``
        list column hands each cell in as a ``pl.Series``; a plain list/tuple
        is accepted too.

    Returns
    -------
    int or None
        - The rating (1, 2 or 3) with strictly the most votes.
        - On a two-way tie for first place: the first vote in the list if it
          is one of the tied ratings; otherwise a fixed fallback (2 for 1/2 and
          2/3 ties, 3 for 1/3 ties).
        - On a three-way tie: the first vote in the list.
        - ``None`` when there are no votes at all, so the rating becomes null
          instead of raising ``IndexError``.
    """
    votes = row.to_list() if hasattr(row, "to_list") else list(row)
    if not votes:
        # Nobody voted on this result: leave its rating null.
        return None

    votes_1 = votes.count(1)
    votes_2 = votes.count(2)
    votes_3 = votes.count(3)
    first_vote = votes[0]

    ## Clear winners
    if votes_1 > votes_2 and votes_1 > votes_3:
        return 1
    if votes_2 > votes_1 and votes_2 > votes_3:
        return 2
    if votes_3 > votes_1 and votes_3 > votes_2:
        return 3

    ## Two-way ties for first place: the first vote breaks the tie when it
    ## backs one of the tied ratings, otherwise use a fixed fallback.
    if votes_1 == votes_2 > votes_3:
        return first_vote if first_vote in (1, 2) else 2
    if votes_1 == votes_3 > votes_2:
        return first_vote if first_vote in (1, 3) else 3
    if votes_2 == votes_3 > votes_1:
        return first_vote if first_vote in (2, 3) else 2

    ## Only a three-way tie remains (the cases above are exhaustive otherwise)
    return first_vote

# Turn each row's list of votes into one relevance rating and attach it to df
relevance_ratings = (
    df["votes"]
    .map_elements(voting_machine, return_dtype=int)
    .alias("relevance_rating")
)
df = df.with_columns(relevance_ratings)

In [3]:
## Match the config-02 query results against the results already labeled
## in config 00
labeled_00 = df.clone()
query_results_02 = pl.read_parquet(
    "../data_labeling/raw_results/query_results_config_02_full.parquet"
)

## After a left join, relevance_rating is null for every config-02 result with
## no labeled counterpart; only those still need to be labeled.
joined_results = (
    query_results_02
    .join(labeled_00, on=["reddit_name", "query_text"], how="left")
    .filter(pl.col("relevance_rating").is_null())
)

## Keep only the original query-result columns so the output matches the
## columns expected by the labeling script
columns_to_keep = query_results_02.columns
unlabled_02 = joined_results.select(columns_to_keep)

## Save the results
unlabled_02.write_parquet(
    "../data_labeling/raw_results/query_results_config_02.parquet"
)