In [1]:
import polars
from polars.functions import col
from label_legends.preprocess import load_test
from label_legends.result import load_predictions
from label_legends.util import RESOURCE

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
base_model = load_predictions("DeBERTa")
base_model.head()

id,prediction
i64,i64
100,1
10005,0
10006,0
10007,0
10008,0


In [3]:
test = load_test().collect()
test.head(100)

id,rewire_id,text,annotator,label_sexist,label_category,label_vector,split,tokens
i64,str,str,i64,str,str,str,str,list[str]
100,"""sexism2022_english-10026""","""It is not insulting, it's **ex…",14,"""sexist""","""2. derogation""","""2.1 descriptive attacks""","""test""","[""it"", ""be"", … "".**""]"
10005,"""sexism2022_english-13""","""and if you buy into the Christ…",6,"""not sexist""","""none""","""none""","""test""","[""and"", ""if"", … ""..""]"
10006,"""sexism2022_english-13""","""and if you buy into the Christ…",4,"""not sexist""","""none""","""none""","""test""","[""and"", ""if"", … ""..""]"
10007,"""sexism2022_english-13""","""and if you buy into the Christ…",8,"""not sexist""","""none""","""none""","""test""","[""and"", ""if"", … ""..""]"
10008,"""sexism2022_english-130""","""Given the sub this is posted i…",10,"""not sexist""","""none""","""none""","""test""","[""give"", ""the"", … ""shit""]"
…,…,…,…,…,…,…,…,…
10410,"""sexism2022_english-1312""","""Yup. The PS Vita is filled wit…",8,"""not sexist""","""none""","""none""","""test""","[""yup"", ""."", … "".""]"
10411,"""sexism2022_english-1312""","""Yup. The PS Vita is filled wit…",9,"""not sexist""","""none""","""none""","""test""","[""yup"", ""."", … "".""]"
10412,"""sexism2022_english-1312""","""Yup. The PS Vita is filled wit…",6,"""not sexist""","""none""","""none""","""test""","[""yup"", ""."", … "".""]"
10425,"""sexism2022_english-13124""","""Did you see how black Viola Da…",5,"""not sexist""","""none""","""none""","""test""","[""do"", ""you"", … ""dress""]"


In [13]:
def assign_type(prediction: int, label: str):
    if prediction == 0:
        if label == "not sexist":
            return "tn"
        return "fn"
    if label == "sexist":
        return "tp"
    return "fp"
    
joined = base_model.join(test, on="id") \
    .with_columns(polars.struct(["prediction", "label_sexist"]) \
    .map_elements(lambda x: assign_type(x['prediction'],x['label_sexist']), return_dtype=polars.String).alias("type")) \
    .select(["id", "rewire_id", "prediction", "label_sexist", "type", "text", "tokens"]) \
    .filter(col("type") == "fn") \
    # .select(["tokens", "text"]).to_dicts()
joined.head()

{'tokens': shape: (1_015,)
 Series: 'tokens' [list[str]]
 [
 	["i", "be", … "that"]
 	["that", "poor", … "."]
 	["that", "create", … "."]
 	["if", "i", … "."]
 	["greek", "authority", … ":"]
 	…
 	["somebody", "be", … "!"]
 	["somebody", "be", … "!"]
 	["you", "misogynist", … "/s"]
 	["shudder", "..", … "."]
 	["shudder", "..", … "."]
 ],
 'text': shape: (1_015,)
 Series: 'text' [str]
 [
 	"Im actually glad they made wom…
 	"That poor bastard. Really scra…
 	"Those create your owns (As in …
 	"If I recall correctly, the Ice…
 	"Greek authorities move 400 'as…
 	…
 	"Somebody be guilt trippin' and…
 	"Somebody be guilt trippin' and…
 	"You misogynist cisgendered scu…
 	"Shudder.. if you had to have s…
 	"Shudder.. if you had to have s…
 ]}

In [5]:
(csv_path := RESOURCE / "analysis" / "deberta_base_predictions.csv").parent.mkdir(exist_ok=True, parents=True)
joined.write_csv(csv_path)

ComputeError: CSV format does not support nested data

In [None]:
joined.group_by("type").len()

In [None]:
joined.group_by("rewire_id", "label_sexist").count().rename({"count": "positive_labels"}).filter(col("label_sexist") == "sexist").group_by("positive_labels").count()

In [None]:
joined.group_by("rewire_id").agg(col("label_sexist").filter(col("label_sexist") == "sexist").len()).group_by("label_sexist").len().sort("label_sexist")

In [None]:
polars.DataFrame()