In [None]:
import os

import numpy as np
from matplotlib import pyplot as plt
import pandas as pd
import polars as pl
import scipy
import seaborn as sns
from teeplot import teeplot as tp


In [None]:
if os.environ.get("CI") is None:
    # Full row-level dataset from OSF; too heavy for CI, which instead loads a
    # pre-aggregated extract in the group-by cell further down.
    df = pl.read_parquet("https://osf.io/gk2ty/download", use_pyarrow=True)
    print(df.columns)


In [None]:
if os.environ.get("CI") is None:
    # True on every row of a (Treatment, Run ID, Generation Born) group when
    # SLIP_INSERTION_BOOL_MASK is true on at least one row of that group.
    df = df.with_columns(
        **{
            "SLIP_INSERTION_BOOL_MASK any": pl.col("SLIP_INSERTION_BOOL_MASK")
            .any()
            .over("Treatment", "Run ID", "Generation Born"),
        },
    )


In [None]:
if os.environ.get("CI") is None:
    # A site counts as "any coding site" if it is a task coding site for at
    # least one task row within its (Treatment, Run ID, Generation Born, Site).
    df = df.with_columns(
        **{
            "is any coding site": pl.col("Is Task Coding Site")
            .any()
            .over("Treatment", "Run ID", "Generation Born", "Site"),
        },
    )


In [None]:
if os.environ.get("CI") is None:
    # Total of "Is Task Coding Site Delta" across each
    # (Treatment, Run ID, Generation Born) group, broadcast to all its rows.
    df = df.with_columns(
        **{
            "is task coding site delta sum": pl.col("Is Task Coding Site Delta")
            .sum()
            .over("Treatment", "Run ID", "Generation Born"),
        },
    )


In [None]:
if os.environ.get("CI") is None:
    # Sum of the "has task" flag over the task rows of each
    # (Treatment, Run ID, Generation Born, Site) group.
    df = df.with_columns(
        pl.col("has task")
        .sum()
        .over("Treatment", "Run ID", "Generation Born", "Site")
        .alias("num tasks has"),
    )


In [None]:
if os.environ.get("CI") is None:
    # Per task, count the rows (sites) flagged "is any coding site" within each
    # (Treatment, Run ID, Generation Born) group.
    df = df.with_columns(
        pl.col("is any coding site")
        .sum()
        .over("Treatment", "Run ID", "Generation Born", "Task")
        .alias("num coding sites"),
    )


In [None]:
if os.environ.get("CI") is None:
    # Sign (-1/0/1) of the summed "Is Task Coding Site Cumulative Count" per
    # (Site, Lineage Generation Index, Treatment, Run ID); presumably counts are
    # non-negative, making this a 0/1 "ever coded" flag — TODO confirm.
    df = df.with_columns(
        pl.col("Is Task Coding Site Cumulative Count")
        .sum()
        .over("Site", "Lineage Generation Index", "Treatment", "Run ID")
        .sign()
        .alias("coded for tasks")
    )


In [None]:
if os.environ.get("CI") is None:
    # Per task, total of "coded for tasks" over the rows of each
    # (Treatment, Run ID, Generation Born) group.
    df = df.with_columns(
        **{
            "num coded sites": pl.col("coded for tasks")
            .sum()
            .over("Treatment", "Run ID", "Generation Born", "Task"),
        },
    )


In [None]:
if os.environ.get("CI") is None:
    # Per-task "Components" value, grouped by value here for readability.
    # replace_strict raises if a Task value is missing from this table.
    task_components = {
        "NOT": 1,
        "NAND": 1,
        "AND": 2,
        "ORNOT": 2,
        "OR": 3,
        "ANDNOT": 3,
        "NOR": 4,
        "XOR": 4,
        "EQUALS": 5,
    }
    df = df.with_columns(
        pl.col("Task").replace_strict(task_components).alias("Components"),
    )


In [None]:
if "CI" not in os.environ:
    # Largest "Components" value among the tasks flagged by "has task" within
    # each (Treatment, Run ID, Generation Born) group; unflagged tasks contribute
    # Components * 0 = 0 (assumes "has task" is boolean/0-1 — TODO confirm).
    #
    # The keyword argument is what names the column: polars re-aliases keyword
    # inputs to the keyword, so chaining `.alias(...)` here would be silently
    # ignored. Downstream cells read "MaxComponents".
    df = df.with_columns(
        MaxComponents=(
            pl.col("Components")
            * pl.col("has task")
        ).max().over(
            ["Treatment", "Run ID", "Generation Born"],
        ),
    )


In [None]:
if os.environ.get("CI") is None:
    # 5 is the Components value assigned only to EQUALS in the mapping cell,
    # so this flags groups whose highest-component task is EQUALS.
    df = df.with_columns(
        pl.col("MaxComponents").eq(5).alias("has equal"),
    )


In [None]:
if "CI" not in os.environ:
    # Keep the first row of each (Treatment, Run ID, Generation Born) group.
    # maintain_order=True makes the output row order, and therefore the cached
    # parquet, reproducible across runs; without it polars returns groups in
    # arbitrary order.
    dfx = df.group_by(
        ["Treatment", "Run ID", "Generation Born"],
        maintain_order=True,
    ).first()
    dfx.write_parquet("/tmp/num-coding-sites.pqt")
else:
    # CI: load a pre-computed extract (presumably a saved copy of the frame
    # above) instead of recomputing from the full dataset.
    dfx = pl.read_parquet("https://osf.io/etsfy/download", use_pyarrow=True)


In [None]:
# Map raw treatment identifiers to the short labels used in figures and stats.
# replace_strict is vectorized, unlike a per-element Python callback via
# map_elements, and still raises if a treatment is missing from the mapping.
# The String cast keeps this working should the column be stored as Categorical.
dfx = dfx.with_columns(
    Treatment=pl.col("Treatment").cast(pl.String).replace_strict(
        {
            "Baseline-Treatment": "Baseline",
            "Long-Ancestor-Control-Treatment": "Long-genome",
            "Slip-duplicate": "Slip-duplicate",
        },
        return_dtype=pl.String,
    ),
)


In [None]:
# Coding sites as a fraction of genome length.
dfx = dfx.with_columns(
    **{"frac coding sites": pl.col("num coding sites") / pl.col("Genome Length")},
)


In [None]:
# "num free sites" = "num coded sites" minus "num coding sites".
dfx = dfx.with_columns(
    **{"num free sites": pl.col("num coded sites") - pl.col("num coding sites")},
)


In [None]:
def ensure_combinations(df, group_columns, generation_column, min_gen=0, max_gen=600):
    """Pad ``df`` so every group has one row per generation in ``[min_gen, max_gen]``.

    Parameters
    ----------
    df : pl.DataFrame
        Input frame. ``generation_column`` is cast to Int64.
    group_columns : sequence of str
        Columns whose distinct value combinations define the groups.
    generation_column : str
        Integer generation column to expand.
    min_gen, max_gen : int
        Inclusive generation range to cover (defaults 0 and 600).

    Returns
    -------
    pl.DataFrame
        Every (group, generation) combination, left-joined with ``df``.
        Combinations absent from ``df`` have nulls in the remaining columns.
        Rows of ``df`` outside the range are dropped.
    """
    group_columns = list(group_columns)  # also accept tuples

    df = df.with_columns(
        pl.col(generation_column).cast(pl.Int64),
    )

    # Explicit int64: np.arange's default integer dtype is platform-dependent
    # (int32 on Windows with NumPy < 2), which would mismatch the Int64 join key.
    full_generations = pl.DataFrame(
        {generation_column: np.arange(min_gen, max_gen + 1, dtype=np.int64)},
    )

    # maintain_order keeps the group order deterministic across runs.
    unique_groups = df.select(group_columns).unique(maintain_order=True)

    all_combinations = unique_groups.join(full_generations, how="cross")

    return all_combinations.join(
        df, on=[*group_columns, generation_column], how="left",
    )


In [None]:
# Pad every (Treatment, Run ID) series to generations 0..600, then carry the
# latest observed values forward *within each series*. A frame-wide forward
# fill would leak the previous series' last values into the leading gap of the
# next one, and which series is "previous" depends on unspecified row order.
# Sorting first makes "forward" mean increasing generation. Leading gaps before
# a series' first observation stay null.
big_df = ensure_combinations(dfx, ["Treatment", "Run ID"], "Generation Born")
big_df = big_df.sort(["Treatment", "Run ID", "Generation Born"]).with_columns(
    pl.exclude("Treatment", "Run ID", "Generation Born")
    .fill_null(strategy="forward")
    .over("Treatment", "Run ID"),
)
big_df


In [None]:
# From here on, plots and stats use the generation-padded, forward-filled frame.
dfx: pl.DataFrame = big_df


In [None]:
# Number of coding sites over time, one line per treatment; only every 16th
# generation is plotted.
with tp.teed(
    sns.lineplot,
    hue="Treatment",
    style="Treatment",
    y="num coding sites",
    x="Generation Born",
    data=dfx.filter(pl.col("Generation Born").mod(16).eq(0)).to_pandas(),
    teeplot_outexclude="style",
    teeplot_postprocess="plt.xlim(0, 600)",
) as ax:
    # Frameless single-row legend above the axes; hide right/top spines.
    sns.move_legend(
        ax,
        loc="lower center",
        bbox_to_anchor=(0.5, 1),
        ncol=3,
        title=None,
        frameon=False,
    )
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)


In [None]:
# Coding vs. coded sites, one panel per treatment, colored by "has equal";
# only every 16th generation is included.
with tp.teed(
    sns.relplot,
    col="Treatment",
    hue="has equal",
    y="num coding sites",
    x="num coded sites",
    data=dfx.filter(pl.col("Generation Born").mod(16).eq(0)).to_pandas(),
    kind="scatter",
) as g:
    ...  # no per-figure tweaks


In [None]:
# Number of coded sites over time, one line per treatment; only every 16th
# generation is plotted.
with tp.teed(
    sns.lineplot,
    hue="Treatment",
    style="Treatment",
    y="num coded sites",
    x="Generation Born",
    data=dfx.filter(pl.col("Generation Born").mod(16).eq(0)).to_pandas(),
    teeplot_outexclude="style",
    teeplot_postprocess="plt.xlim(0, 600)",
) as ax:
    # Frameless single-row legend above the axes; hide right/top spines.
    sns.move_legend(
        ax,
        loc="lower center",
        bbox_to_anchor=(0.5, 1),
        ncol=3,
        title=None,
        frameon=False,
    )
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)


In [None]:
# Free sites (coded minus coding) over time, one line per treatment; only every
# 16th generation is plotted. The column is recomputed on the thinned frame.
with tp.teed(
    sns.lineplot,
    hue="Treatment",
    style="Treatment",
    y="num free sites",
    x="Generation Born",
    data=(
        dfx.filter(pl.col("Generation Born").mod(16).eq(0))
        .with_columns(
            **{
                "num free sites": pl.col("num coded sites")
                - pl.col("num coding sites"),
            },
        )
        .to_pandas()
    ),
    teeplot_outexclude="style",
    teeplot_postprocess="plt.xlim(0, 600)",
) as ax:
    # Frameless single-row legend above the axes; hide right/top spines.
    sns.move_legend(
        ax,
        loc="lower center",
        bbox_to_anchor=(0.5, 1),
        ncol=3,
        title=None,
        frameon=False,
    )
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)


In [None]:
# "num tasks has" over time, one line per treatment; only every 16th
# generation is plotted.
with tp.teed(
    sns.lineplot,
    hue="Treatment",
    style="Treatment",
    y="num tasks has",
    x="Generation Born",
    data=dfx.filter(pl.col("Generation Born").mod(16).eq(0)).to_pandas(),
    teeplot_outexclude="style",
    teeplot_postprocess="plt.xlim(0, 600)",
) as ax:
    # Frameless single-row legend above the axes; hide right/top spines.
    sns.move_legend(
        ax,
        loc="lower center",
        bbox_to_anchor=(0.5, 1),
        ncol=3,
        title=None,
        frameon=False,
    )
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)


In [None]:
# MaxComponents over time, one line per treatment; only every 16th generation
# is plotted.
with tp.teed(
    sns.lineplot,
    hue="Treatment",
    style="Treatment",
    y="MaxComponents",
    x="Generation Born",
    data=dfx.filter(pl.col("Generation Born").mod(16).eq(0)).to_pandas(),
    teeplot_outexclude="style",
    teeplot_postprocess="plt.xlim(0, 600)",
) as ax:
    # Frameless single-row legend above the axes; hide right/top spines.
    sns.move_legend(
        ax,
        loc="lower center",
        bbox_to_anchor=(0.5, 1),
        ncol=3,
        title=None,
        frameon=False,
    )
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)


In [None]:
# Fraction of the genome that is coding over time, one line per treatment;
# only every 16th generation is plotted.
with tp.teed(
    sns.lineplot,
    hue="Treatment",
    style="Treatment",
    y="frac coding sites",
    x="Generation Born",
    data=dfx.filter(pl.col("Generation Born").mod(16).eq(0)).to_pandas(),
    teeplot_outexclude="style",
    teeplot_postprocess="plt.xlim(0, 600)",
) as ax:
    # Frameless single-row legend above the axes; hide right/top spines.
    sns.move_legend(
        ax,
        loc="lower center",
        bbox_to_anchor=(0.5, 1),
        ncol=3,
        title=None,
        frameon=False,
    )
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)


In [None]:
# Mann-Whitney U test at generation 599: Slip-duplicate vs. Long-genome.
# Treatments were relabelled earlier in the notebook
# ("Long-Ancestor-Control-Treatment" -> "Long-genome"), so the new label must be
# used; the raw name matches no rows and would test an empty sample.
fil = dfx.filter(pl.col("Generation Born") == 599)
for what in ["num coding sites", "num coded sites"]:
    slip = fil.filter(pl.col("Treatment") == "Slip-duplicate")[what].to_numpy()
    long_genome = fil.filter(pl.col("Treatment") == "Long-genome")[what].to_numpy()
    assert slip.size and long_genome.size, (
        f"empty sample for {what!r}; check Treatment labels"
    )
    result = scipy.stats.mannwhitneyu(slip, long_genome)
    print(what, result)


In [None]:
# Mann-Whitney U test at generation 599: Slip-duplicate vs. Baseline.
# Treatments were relabelled earlier in the notebook
# ("Baseline-Treatment" -> "Baseline"), so the new label must be used; the raw
# name matches no rows and would test an empty sample.
fil = dfx.filter(pl.col("Generation Born") == 599)
for what in ["num coding sites", "num coded sites"]:
    slip = fil.filter(pl.col("Treatment") == "Slip-duplicate")[what].to_numpy()
    baseline = fil.filter(pl.col("Treatment") == "Baseline")[what].to_numpy()
    assert slip.size and baseline.size, (
        f"empty sample for {what!r}; check Treatment labels"
    )
    result = scipy.stats.mannwhitneyu(slip, baseline)
    print(what, result)


In [None]:
# Mean and population standard deviation (np.std, ddof=0) per treatment at
# generation 599. Uses the relabelled treatment names; the raw names
# ("Long-Ancestor-Control-Treatment", "Baseline-Treatment") no longer occur in
# dfx and would yield empty arrays (NaN mean/std).
fil = dfx.filter(pl.col("Generation Born") == 599)
for what in ["num coding sites", "num coded sites"]:
    for treatment in ["Slip-duplicate", "Long-genome", "Baseline"]:
        arr = fil.filter(pl.col("Treatment") == treatment)[what].to_numpy()
        assert arr.size, f"no rows for treatment {treatment!r}"
        print(what, treatment, np.mean(arr), np.std(arr))
