In [None]:
import numpy as np
from matplotlib import pyplot as plt
import pandas as pd
import polars as pl
import seaborn as sns
from teeplot import teeplot as tp


In [None]:
# Download the dataset from OSF and read it with pyarrow's parquet reader
# (use_pyarrow=True) rather than polars' native one.
df = pl.read_parquet(source="https://osf.io/emh23/download", use_pyarrow=True)


In [None]:
# Inspect the available column names; a bare last expression gets the
# notebook's rich display instead of plain print() text.
df.columns


In [None]:
# Broadcast to every row of each (Treatment, Run ID, Generation Born) group
# whether any row of that group has SLIP_INSERTION_BOOL_MASK set.
df = df.with_columns(
    pl.col("SLIP_INSERTION_BOOL_MASK")
    .any()
    .over("Treatment", "Run ID", "Generation Born")
    .alias("SLIP_INSERTION_BOOL_MASK any"),
)


In [None]:
# Per (Treatment, Run ID, Generation Born, Site) group: True on every row when
# "Is Task Coding Site" is set on any row of that group.
df = df.with_columns(
    pl.col("Is Task Coding Site")
    .any()
    .over("Treatment", "Run ID", "Generation Born", "Site")
    .alias("is any coding site"),
)


In [None]:
# Total of "Is Task Coding Site Delta" over all rows of each
# (Treatment, Run ID, Generation Born) group, broadcast back onto each row.
df = df.with_columns(
    pl.col("Is Task Coding Site Delta")
    .sum()
    .over("Treatment", "Run ID", "Generation Born")
    .alias("is task coding site delta sum"),
)


In [None]:
# Sum of "has task" within each (Treatment, Run ID, Generation Born, Site)
# group. NOTE(review): presumably the number of tasks the genome has, assuming
# one row per task at each site — confirm against the data layout.
df = df.with_columns(
    pl.col("has task")
    .sum()
    .over("Treatment", "Run ID", "Generation Born", "Site")
    .alias("num tasks has"),
)


In [None]:
# Sum of the per-site "is any coding site" flag within each
# (Treatment, Run ID, Generation Born, Task) group. NOTE(review): presumably
# the number of sites coding for any task, assuming one row per site for each
# task — confirm.
df = df.with_columns(
    pl.col("is any coding site")
    .sum()
    .over("Treatment", "Run ID", "Generation Born", "Task")
    .alias("num coding sites"),
)


In [None]:
# Collapse df to one row per unique combination of the grouping keys below and
# label how "delta has task" changed within the group: any value == 1
# contributes 2 (gain), any value == -1 contributes 1 (loss), so the sum
# 0..3 encodes all four combinations and replace_strict maps it to a label.
#
# maintain_order=True makes the group (row) order deterministic. The default
# group_by order is arbitrary, and seaborn takes the level order of string
# columns such as "task change" and "Treatment" from their order of
# appearance, so without it hue colours / facet order could change per run.
dfx = df.group_by(
    [
        "Treatment",
        "Run ID",
        "Generation Born",
        "num tasks has",
        "num coding sites",
        "is task coding site delta sum",
        "SLIP_INSERTION_BOOL_MASK any",
    ],
    maintain_order=True,
).agg(
    (
        (pl.col("delta has task") == 1).any() * 2
        + (pl.col("delta has task") == -1).any()
    )
    .replace_strict(
        {
            0: "No change",
            1: "Task loss",
            2: "Task gain",
            3: "Task gain and loss",
        },
    )
    .alias("task change"),
)


In [None]:
# Display the aggregated per-group table with its "task change" label column.
dfx


In [None]:
# Box plots of "num coding sites" per Treatment, coloured by task-change class
# and faceted on "num tasks has" (three panels per row); teeplot saves the
# figure. The trailing semicolon suppresses the repr of whatever tp.tee
# returns, which would otherwise be printed under the figure.
tp.tee(
    sns.catplot,
    hue="task change",
    y="num coding sites",
    x="Treatment",
    kind="box",
    col="num tasks has",
    data=dfx.to_pandas(),
    col_wrap=3,
    teeplot_outattrs={"mut": "poisson"},
);


In [None]:
# Bar chart (seaborn's default estimator) of "is task coding site delta sum"
# per "num tasks has" value, coloured by task-change class, one panel per
# Treatment. seed pins the bootstrap RNG behind the error bars so re-running
# the notebook redraws an identical figure (1 is an arbitrary fixed value).
saveit, g = tp.tee(
    sns.catplot,
    hue="task change",
    y="is task coding site delta sum",
    col="Treatment",
    kind="bar",
    x="num tasks has",
    data=dfx.to_pandas(),
    col_wrap=3,
    seed=1,
    teeplot_callback=True,
    teeplot_outattrs={"mut": "poisson"},
)
# Zero reference line plus a symmetric-log y axis, which (unlike log) can
# display negative sums.
for ax in g.axes.flat:
    ax.axhline(0, color="black")
    ax.set_yscale("symlog")

# teeplot_callback=True defers saving until the axes are customised.
saveit()


In [None]:
# Slip+ runs only: bar chart (seaborn's default estimator) of
# "is task coding site delta sum" per "num tasks has", split by whether the
# group had any SLIP_INSERTION_BOOL_MASK row. seed pins the bootstrap RNG
# behind the error bars so reruns reproduce the figure (arbitrary value 1).
saveit, g = tp.tee(
    sns.catplot,
    hue="SLIP_INSERTION_BOOL_MASK any",
    y="is task coding site delta sum",
    kind="bar",
    x="num tasks has",
    data=dfx.filter(
        pl.col("Treatment") == "Slip+",
    ).to_pandas(),
    seed=1,
    teeplot_callback=True,
    teeplot_outattrs={"mut": "poisson"},
)
# Zero reference line plus symmetric-log y axis (handles negative sums).
for ax in g.axes.flat:
    ax.axhline(0, color="black")
    ax.set_yscale("symlog")

# teeplot_callback=True defers saving until the axes are customised.
saveit()


In [None]:
# Same as the mean bar chart for Slip+ runs, but bars show the median.
# seed pins the bootstrap RNG behind the error bars so reruns reproduce the
# figure (arbitrary value 1).
saveit, g = tp.tee(
    sns.catplot,
    hue="SLIP_INSERTION_BOOL_MASK any",
    y="is task coding site delta sum",
    kind="bar",
    x="num tasks has",
    estimator="median",
    data=dfx.filter(
        pl.col("Treatment") == "Slip+",
    ).to_pandas(),
    seed=1,
    teeplot_callback=True,
    teeplot_outattrs={"mut": "poisson"},
)
# Zero reference line plus symmetric-log y axis (handles negative sums).
for ax in g.axes.flat:
    ax.axhline(0, color="black")
    ax.set_yscale("symlog")

# teeplot_callback=True defers saving until the axes are customised.
saveit()


In [None]:
# Slip+ runs only: individual points of "is task coding site delta sum" per
# "num tasks has", dodged by SLIP_INSERTION_BOOL_MASK any.
# The strip plot's horizontal jitter is random and drawn from numpy's global
# RNG, so seed it here to make the saved figure reproducible (arbitrary 1).
np.random.seed(1)
saveit, g = tp.tee(
    sns.catplot,
    hue="SLIP_INSERTION_BOOL_MASK any",
    y="is task coding site delta sum",
    kind="strip",
    dodge=True,
    x="num tasks has",
    data=dfx.filter(
        pl.col("Treatment") == "Slip+",
    ).to_pandas(),
    teeplot_callback=True,
    teeplot_outattrs={"mut": "poisson"},
)
# Zero reference line plus symmetric-log y axis (handles negative sums).
for ax in g.axes.flat:
    ax.axhline(0, color="black")
    ax.set_yscale("symlog")

# teeplot_callback=True defers saving until the axes are customised.
saveit()


In [None]:
# Slip+ runs only: violin plots of "is task coding site delta sum" for each
# "num tasks has" value, dodged by SLIP_INSERTION_BOOL_MASK any.
saveit, g = tp.tee(
    sns.catplot,
    data=dfx.filter(pl.col("Treatment") == "Slip+").to_pandas(),
    x="num tasks has",
    y="is task coding site delta sum",
    hue="SLIP_INSERTION_BOOL_MASK any",
    kind="violin",
    dodge=True,
    teeplot_callback=True,
    teeplot_outattrs={"mut": "poisson"},
)

# Every facet gets a black zero line and a symmetric-log y axis.
for ax in g.axes.ravel():
    ax.axhline(0, color="black")
    ax.set_yscale("symlog")

# Deferred teeplot save, now that the axes are finished.
saveit()

In [None]:
# Slip+ runs only: count of groups per "num tasks has" value (treated as a
# category), dodged by SLIP_INSERTION_BOOL_MASK any, on a log count axis with
# each bar labelled by its count.
saveit, g = tp.tee(
    sns.displot,
    hue="SLIP_INSERTION_BOOL_MASK any",
    kind="hist",
    x="num tasks has",
    data=dfx.filter(
        pl.col("Treatment") == "Slip+",
    ).to_pandas().astype(
        {"num tasks has": "category"},
    ),
    teeplot_callback=True,
    stat="count",
    multiple="dodge",
    shrink=0.8,
    discrete=True,
    teeplot_outattrs={"mut": "poisson"},
)

for ax in g.axes.flat:
    # Set the log scale per facet through the Axes API; plt.yscale would only
    # touch whichever Axes pyplot considers current.
    ax.set_yscale("log")

    # add count labels
    # adapted from https://stackoverflow.com/a/55319634/17332200
    for p in ax.patches:
        height = p.get_height()
        # Zero-height (empty) bars have no finite position on a log axis and
        # only make matplotlib warn, so leave them unlabelled.
        if not height > 0:
            continue
        ax.annotate(
            text=f"{height:1.0f}",
            xy=(p.get_x() + p.get_width() / 2., height),
            xycoords='data',
            ha='center',
            va='center',
            fontsize=11,
            color='black',
            xytext=(0, 7),
            textcoords='offset points',
            clip_on=True,  # keep labels from spilling outside the axes
        )

# teeplot_callback=True defers saving until the axes are customised.
saveit()
