In [None]:
import itertools as it

from matplotlib import pyplot as plt
from nbmetalog import nbmetalog as nbm
import numpy as np
import pandas as pd
import seaborn as sns
from teeplot import teeplot as tp

import pylib


In [None]:
# Print notebook/run metadata via nbmetalog so the rendered output records provenance.
nbm.print_metadata()


In [None]:
# Read every CSV matching this glob pattern into one DataFrame
# (pylib.util helper; presumably a row-wise concatenation).
df = pylib.util.concat_dataframes_from_glob(
    "a=prevalence-annotation-by-generation+*+ext=.csv"
)

# Fail fast, with a clear message, if the loaded data lacks a column that the
# plots and detection analysis below index into.
missing_columns = {
    "annotation",
    "fitness-advantage",
    "generation",
    "prevalence",
    "replicate",
} - set(df.columns)
assert not missing_columns, f"missing expected columns: {sorted(missing_columns)}"


In [None]:
# Inspect the loaded data.
df


In [None]:
def lineplot_twiny(x, y1, y2, color=None, y2_lims=None, **kwargs):
    """Plot ``y1`` (blue, left axis) and ``y2`` (red, twin right axis) vs ``x``.

    Designed as a ``FacetGrid.map`` callback: it draws on the current axes
    and ignores ``color`` plus any extra keyword arguments (such as the
    ``label`` FacetGrid passes), because both series use fixed colors.
    Each series is drawn with a percentile-interval band (``errorbar="pi"``).

    The twin axis is attached to the primary axes as ``.twin`` so callers
    can style it after plotting. Its y-label is cleared here.

    Parameters
    ----------
    x, y1, y2 : array-like
        Data for the shared x axis, the left series, and the right series.
    color : ignored
    y2_lims : tuple (low, high), optional
        Fixed limits for the twin axis; left to autoscaling when None.
    """
    left_ax = plt.gca()
    right_ax = left_ax.twinx()
    left_ax.twin = right_ax  # expose the twin for post-hoc styling by callers

    series_specs = ((left_ax, y1, "blue"), (right_ax, y2, "red"))
    for target_ax, series, series_color in series_specs:
        sns.lineplot(
            x=x, y=series, color=series_color, ax=target_ax, errorbar="pi"
        )
    right_ax.set_ylabel("")

    if y2_lims is None:
        return
    right_ax.set_ylim(*y2_lims)


def facet_lineplot_twiny(data, x, y1, y2, col):
    """Draw a row of twin-y-axis line plots, one facet per value of ``col``.

    Each facet shows ``y1`` in blue on its left axis and ``y2`` in red on a
    twin right axis (drawn by ``lineplot_twiny``). The left axes are shared
    by the FacetGrid. The twin axes are created per facet and are not, so
    every twin is pinned to the global min/max of ``data[y2]`` to keep
    facets comparable. Only the right-most facet keeps its y2 ticks and
    y2 axis label.

    Parameters
    ----------
    data : pandas.DataFrame
        Long-form data containing the ``x``, ``y1``, ``y2`` and ``col`` columns.
    x, y1, y2 : str
        Column names for the x axis, the left-axis series and the
        right-axis series.
    col : str
        Column whose distinct values define the facets.
    """
    y2_lims = (data[y2].min(), data[y2].max())

    # sharey is a FacetGrid option. Passing it to g.map() would only forward
    # it to lineplot_twiny, which ignores it via **kwargs.
    g = sns.FacetGrid(data, col=col, sharey=True)
    g.map(lineplot_twiny, x, y1, y2, y2_lims=y2_lims)

    g.set_axis_labels("Generation", "")
    g.set_titles("Fitness Advantage {col_name}")

    # Color-code the axis labels to match the line colors lineplot_twiny uses:
    # the left (blue) label on the first facet, the right (red) on the last.
    g.axes.flat[0].set_ylabel("Gene Copy Count", color="blue")
    g.axes.flat[-1].twin.set_ylabel("Stratum Annotation Bit Count", color="red")

    # All twins share identical limits, so interior right-hand ticks are clutter.
    for ax in g.axes.flat[:-1]:
        ax.twin.set_yticks([])


# Draw the faceted twin-axis figure through teeplot, which calls
# facet_lineplot_twiny with these keyword arguments.
# NOTE(review): teeplot presumably also saves the figure under a filename
# derived from these kwargs; confirm before renaming or reordering them.
tp.tee(
    facet_lineplot_twiny,
    data=df,
    x="generation",
    y1="prevalence",
    y2="annotation",
    col="fitness-advantage",
)


In [None]:
# Number of consecutive rows summed by the rolling window. Rows are sorted
# by generation, so this is a window of generations, assuming one row per
# generation per (replicate, fitness-advantage) group (TODO confirm).
ROLLING_WINDOW = 16

# Rolling sum of the annotation column within each (replicate,
# fitness-advantage) group. With pandas' default min_periods (= window), the
# first ROLLING_WINDOW - 1 rows of every group are NaN.
# sort_values returns a new frame, so df itself is left untouched.
rolling = (
    df.sort_values("generation", axis=0)
    .groupby(["replicate", "fitness-advantage"])["annotation"]
    .rolling(ROLLING_WINDOW)
    .sum()
    .reset_index()
)
rolling


In [None]:
# Largest detection threshold swept; thresholds run 0..MAX inclusive.
MAX_DETECTION_THRESHOLD = 500


def detection_records(rolling_df, thresholds):
    """Count, per (threshold, fitness advantage), the replicates detected.

    A replicate counts as detected at a threshold iff at least one of its
    rows has ``annotation >= threshold``. Equivalently, its peak (max)
    annotation reaches the threshold. Computing each replicate's peak once
    avoids re-filtering the whole frame for every (threshold, fitness
    advantage) pair.

    Parameters
    ----------
    rolling_df : pandas.DataFrame
        Must have "fitness-advantage", "replicate" and "annotation" columns.
        NaN annotations (e.g. incomplete rolling windows) never count.
    thresholds : iterable of numbers

    Returns
    -------
    list of dict
        One record per (threshold, fitness advantage) pair with keys
        "threshold", "replicate_count" and "fitness-advantage". Records are
        ordered by threshold, then by first appearance of each fitness
        advantage in ``rolling_df``.
    """
    peak_annotation = rolling_df.groupby(["fitness-advantage", "replicate"])[
        "annotation"
    ].max()
    peaks_by_advantage = {
        advantage: peaks for advantage, peaks in peak_annotation.groupby(level=0)
    }

    records = []
    for threshold, fitness_advantage in it.product(
        thresholds, rolling_df["fitness-advantage"].unique()
    ):
        peaks = peaks_by_advantage.get(fitness_advantage)
        # NaN peaks compare False, matching the original row-filter semantics.
        replicate_count = 0 if peaks is None else int((peaks >= threshold).sum())
        records.append(
            {
                "threshold": threshold,
                "replicate_count": replicate_count,
                "fitness-advantage": fitness_advantage,
            },
        )
    return records


records = detection_records(rolling, range(MAX_DETECTION_THRESHOLD + 1))
above_threshold_df = pd.DataFrame.from_records(records)


def lineplot_detection(data, x, y, hue):
    """Plot detected-replicate counts against detection threshold.

    Draws one line per ``hue`` level on the current matplotlib axes, using
    evenly spaced viridis colors, then labels the axes and the legend.

    Parameters
    ----------
    data : pandas.DataFrame
        Long-form data containing the ``x``, ``y`` and ``hue`` columns.
    x, y, hue : str
        Column names for the threshold, the replicate count and the
        grouping (fitness advantage) variable.
    """
    # Size the palette to the actual number of hue levels. A fixed 3-color
    # list only fits data with exactly three levels; seaborn either raises or
    # cycles colors (with a warning), depending on version, when a list
    # palette's length differs from the number of levels.
    palette = sns.color_palette("viridis", data[hue].nunique())

    ax = sns.lineplot(data, x=x, y=y, hue=hue, palette=palette)

    ax.set_xlabel("Detection Threshold")
    ax.set_ylabel("Number Replicates with Detected Selection")
    ax.legend(title="Fitness Advantage")


# Draw the detection-threshold curves through teeplot, which calls
# lineplot_detection with these keyword arguments.
# NOTE(review): teeplot presumably also saves the figure under a filename
# derived from these kwargs; confirm before renaming or reordering them.
tp.tee(
    lineplot_detection,
    data=above_threshold_df,
    x="threshold",
    y="replicate_count",
    hue="fitness-advantage",
)
