In [None]:
# Load the watermark IPython extension so the `%watermark` line magic used
# later in this notebook is available.
%load_ext watermark


In [None]:
import itertools as it

from IPython.display import display
from matplotlib import ticker as mpl_ticker
import pandas as pd
import polars as pl
import seaborn as sns
from teeplot import teeplot as tp

from pylib.munge._calc_fixprobs_from_traits import calc_fixprobs_from_traits


In [None]:
# Emit the watermark report for this session.
# NOTE(review): presumably records date/machine details and the versions of
# imported packages for reproducibility — confirm the exact flag meanings
# against the watermark documentation.
%watermark -diwmuv -iv


In [None]:
# Value forwarded as the `teeplot_subdir` keyword to `tp.teed` in the plotting
# cell below (presumably the output subdirectory for saved figures — confirm
# against the teeplot docs).
teeplot_subdir = "cupy-traits"
teeplot_subdir  # bare expression: echo the value as this cell's output


## Prep Data


In [None]:
# One entry per source dataset:
#   (OSF parquet download URL, "initial conditions" label,
#    "population structure" label).
# The order here is the row order of the concatenated frame.
trait_sources = [
    ("https://osf.io/s67d2/download", "50/50", "1D demes"),
    ("https://osf.io/8q5v6/download", "50/50", "2D demes"),
    ("https://osf.io/duam2/download", "50/50", "well-mixed"),
    ("https://osf.io/seuyf/download", "de novo", "1D demes"),
    ("https://osf.io/ag4ur/download", "de novo", "2D demes"),
    ("https://osf.io/edwbu/download", "de novo", "well-mixed"),
]

# For each source: read the parquet file with pandas, run it through
# `calc_fixprobs_from_traits`, convert the result to polars, and tag it with
# two constant label columns so the treatments remain distinguishable after
# the frames are stacked.
dfxs = pl.concat(
    [
        pl.from_pandas(
            calc_fixprobs_from_traits(pd.read_parquet(url)),
        ).with_columns(
            pl.lit(initial_conditions).alias("initial conditions"),
            pl.lit(population_structure).alias("population structure"),
        )
        for url, initial_conditions, population_structure in trait_sources
    ],
)

# Rich-display summary statistics plus the first and last rows, in that order.
display(dfxs.describe(), dfxs.head(), dfxs.tail())


## Size Fixation Cliffplot


In [None]:
# Keep only hypermutator records, and drop two "1D demes" regimes:
#   * population size above 1679616, and
#   * population size <= 256 with more than 12 available beneficial mutations.
# Newline-containing aliases are added so seaborn's auto-generated legend and
# axis text wraps onto two lines.
data = dfxs.filter(
    pl.col("genotype") == "hypermutator",
    ~(
        (pl.col("population size") > 1679616)
        & (pl.col("population structure") == "1D demes")
    ),
    ~(
        (pl.col("population size") <= 256)
        & (pl.col("available beneficial mutations") > 12)
        & (pl.col("population structure") == "1D demes")
    ),
).with_columns(
    pl.col("fixation probability").alias("fix\nprob"),
    pl.col("population structure").alias("population\nstructure"),
    pl.col("initial conditions").alias("initial\nconditions"),
)

row = "population size"

# Facet bookkeeping that does not change between loop iterations; computed
# once instead of on every pass.
row_values = data[row].to_pandas()
row_order = sorted(row_values.unique(), reverse=True)  # largest on top
n_rows = row_values.nunique()
structures = ("well-mixed", "2D demes", "1D demes")

# One figure per (errorbar style, excluded population structure) combination.
for errorbar, exclude in it.product(
    ["sd", "se", "ci", None],
    ["1D demes", None],
):
    # Optionally leave one population structure out of the plotted data.
    if exclude is None:
        subset = data
    else:
        subset = data.filter(
            pl.col("population structure") != pl.lit(exclude),
        )

    with tp.teed(
        sns.relplot,
        data=subset,
        x="available beneficial mutations",
        y="fix\nprob",
        row=row,
        row_order=row_order,
        hue="population\nstructure",
        hue_order=[name for name in structures if name != exclude],
        style="initial\nconditions",
        style_order=["de novo", "50/50"],
        aspect=10,
        errorbar=errorbar,
        height=0.8,
        kind="line",
        markers=True,
        palette="Dark2",
        seed=1,
        teeplot_subdir=teeplot_subdir,
        teeplot_outattrs=(
            {"exclude": exclude.replace(" ", "-")}
            if exclude is not None
            else {}
        ),
    ) as teed:
        # Blank the per-facet titles; the shared dummy axis added below
        # carries the row variable instead.
        teed.set_titles(
            col_template="",
            row_template="",
        )
        teed.set(ylim=(0, 1), xlim=(1, 40))
        sns.move_legend(
            teed,
            "lower center",
            bbox_to_anchor=(0.4, 0.97),
            frameon=False,
            ncol=7,
            title=None,
            columnspacing=0.7,
        )
        for ax in teed.axes.flat:
            # Dotted reference line at 0.5; only that tick is labelled.
            ax.axhline(0.5, ls=":", c="gray", lw=1)
            ax.set_yticks([0, 0.5, 1])
            ax.set_yticklabels(["", "0.5", ""])
            ax.set_ylabel("")

        teed.figure.subplots_adjust(top=0.85)

        # Add a very thin extra axis to the right of the facets whose
        # log-scaled y axis shows the range of the row variable.
        delta = 0.001  # Small width for the dummy axis
        pos = teed.axes.flat[0].get_position()
        # Use the grid's own figure rather than whichever `ax` the loop above
        # happened to leave behind.
        dummy_ax = teed.figure.add_axes(
            [
                pos.x0 + pos.width,
                # The 1.9 / 2.7 factors are hand-tuned placement constants.
                pos.y0 - pos.height * n_rows * 1.9,
                delta,
                pos.height * n_rows * 2.7,
            ],
        )
        yvals = data[row]
        dummy_ax.set_ylim(yvals.min(), yvals.max())
        dummy_ax.yaxis.set_label_position("right")
        dummy_ax.set_ylabel(row)

        dummy_ax.set_yscale("log")
        dummy_ax.yaxis.set_ticks_position("right")
        dummy_ax.yaxis.set_major_locator(
            mpl_ticker.LogLocator(base=10),
        )
        formatter = mpl_ticker.LogFormatterMathtext(base=10)
        dummy_ax.yaxis.set_major_formatter(formatter)

        # Hide the x axis; only the right-hand y axis is meaningful here
        dummy_ax.get_xaxis().set_visible(False)

        teed.tight_layout()
