In [None]:
from matplotlib import pyplot as plt
import numpy as np
import outset as otst
from outset import patched as otst_patched
import pandas as pd
import seaborn as sns
from tqdm import tqdm

from pylib.auxlib._jitter import jitter
from pylib.analyze_epistasis import (
    assay_epistasis_naive,
    describe_skeletons,
    skeletonize_naive,
)
from pylib.modelsys_explicit import GenomeExplicit
from pylib.modelsys_explicit import (
    GenomeExplicit,
    CalcKnockoutEffectsAdditive,
    CalcKnockoutEffectsEpistasis,
    create_additive_array,
    create_epistasis_matrix_disjoint,
    describe_additive_array,
    describe_epistasis_matrix,
)


In [None]:
# Seed NumPy's legacy global RNG so Restart & Run All is reproducible
# (the genome cell draws from np.random.rand; project helpers presumably
# also use the global RNG -- confirm).
np.random.seed(seed=1234)


## Create Sample Genome


In [None]:
num_sites = 4000


def distn(size):
    """Return `size` additive effect sizes drawn uniformly from [0, 0.7) (mean 0.35)."""
    return 0.7 * np.random.rand(size)


# NOTE(review): 0.05 * num_sites = 200, so if 0.05 is a per-site fraction or
# probability this yields ~200 additive sites, not 50 -- confirm against the
# site-type counts printed further down.
additive_array = create_additive_array(num_sites, 0.05, distn)
# 20 and 8 are passed straight through; their meaning (e.g. number/size of
# disjoint epistatic groups) is not visible here -- TODO confirm.
epistasis_matrix = create_epistasis_matrix_disjoint(num_sites, 20, 8)
genome = GenomeExplicit(
    [
        CalcKnockoutEffectsAdditive(additive_array),
        CalcKnockoutEffectsEpistasis(
            epistasis_matrix,
            effect_size=(0.7, 1.6),
        ),
    ],
)


## Describe and Inspect Genome


In [None]:
dfa = describe_additive_array(additive_array)
dfb = describe_epistasis_matrix(epistasis_matrix)
# Encode the two boolean flags as a 2-bit code (bit 0 = additive,
# bit 1 = epistasis), then translate each code to a readable label.
df_genome = dfa.merge(dfb, on="site").assign(
    **{
        "site type": lambda frame: (
            frame["additive site"].astype(int)
            + 2 * frame["epistasis site"].astype(int)
        ).map(
            {
                0: "neutral",
                1: "additive",
                2: "epistasis",
                3: "both",
            }
        ),
    },
)

df_genome


How many of each kind of site are in the genome?


In [None]:
# Distribution of site categories; log y-axis keeps the rare non-neutral
# categories visible next to the neutral majority.
sns.displot(df_genome["site type"])
plt.yscale("log")
print(df_genome["site type"].value_counts())
print("non-neutral", df_genome["site type"].ne("neutral").sum())


## Perform Skeletonizations


In [None]:
num_skeletonizations = 20
# One skeleton per repetition, stacked row-wise: rows are skeletonizations,
# columns are sites (assuming skeletonize_naive returns one 1-D row -- confirm).
skeletons = np.vstack(
    tuple(
        skeletonize_naive(num_sites, genome.test_knockout)
        for _rep in tqdm(range(num_skeletonizations))
    )
)


Example skeleton: positions of the sites retained in the first skeletonization.


In [None]:
# Skeleton entries are nonzero for knocked-out sites; flip to a boolean mask
# that is True where the first skeleton retained the site.
retained_sites = np.logical_not(skeletons[0].astype(bool))
sns.rugplot(np.flatnonzero(retained_sites), height=0.5)
retained_sites


## Describe Skeletons


Skeleton order vs. skeleton frequency for each site, excluding neutral sites.


In [None]:
# The per-site arrays below (one value per skeleton column) are paired with
# df_genome rows purely by position, so require df_genome to list sites as
# consecutive ascending integers before plotting.
assert (np.diff(df_genome["site"]) == 1).all()
otst_patched.scatterplot(
    pd.DataFrame(
        {
            # per-site mean of the raw skeleton entries across skeletonizations
            "skeleton order": np.mean(skeletons, axis=0),
            # per-site fraction of skeletonizations with a nonzero entry
            # (nonzero appears to mark knocked-out sites -- confirm),
            # jittered with amount=0.01
            "skeleton frequency": jitter(
                np.mean(skeletons.astype(bool), axis=0),
                amount=0.01,
            ),
            "site type": df_genome["site type"],
        },
    ),
    x="skeleton order",
    y="skeleton frequency",
    hue="site type",
    style="site type",
    # "neutral" is deliberately omitted: this view is about non-neutral sites
    hue_order=["additive", "epistasis", "both"],
    alpha=0.5,
)


Including neutral sites, with one panel per site type.


In [None]:
# Skeleton statistics are paired with df_genome rows by position below, so
# the genome table must list sites as consecutive ascending integers.
assert np.all(np.diff(df_genome["site"]) == 1)
og = otst.OutsetGrid(
    data=pd.DataFrame(
        {
            "skeleton order": skeletons.mean(axis=0),
            "skeleton frequency": jitter(
                skeletons.astype(bool).mean(axis=0),
                amount=0.01,
            ),
            "site type": df_genome["site type"],
        },
    ),
    x="skeleton order",
    y="skeleton frequency",
    hue="site type",
    # facet by site type, wrapping at 3 columns
    col="site type",
    col_wrap=3,
)
# per-panel scatters without their own legends; one shared legend is added next
og.map_dataframe(
    sns.scatterplot,
    x="skeleton order",
    y="skeleton frequency",
    legend=False,
    alpha=0.5,
)
og.add_legend(loc="lower right", bbox_to_anchor=(0.9, 0.2))
og.marqueeplot()
plt.show()

# per-site summary table of the skeletonizations
df_skeletons = describe_skeletons(skeletons, genome.test_knockout)
df_skeletons


How many distinct sites are retained in at least one skeleton?


In [None]:
# A zero entry means the site was retained; count sites retained in at least
# one skeletonization.
(~skeletons.astype(bool)).any(axis=0).sum()


## Use Skeleton Jackknifes to Differentiate Epistasis & Small-effect Sites


In [None]:
# Run the naive epistasis assay on the per-site skeleton summary with fixed
# exclusion-frequency and jackknife-severity thresholds; `est` is displayed
# here and its cutoff entries are drawn on the plot that follows.
est = assay_epistasis_naive(
    df_skeletons,
    jackknife_severity_thresh=0.2,
    exclusion_frequency_thresh=0.3,
)
est


In [None]:
# Join ground-truth site types with skeleton statistics, then plot jittered
# exclusion rate against jackknife severity with the assay's cutoffs overlaid.
df_joint = df_genome.merge(df_skeletons, on="site")
ax = sns.scatterplot(
    data={
        "skeleton exclusion rate": jitter(
            df_joint["skeleton outcome frequency, excluded"],
            amount=0.03,
        ),
        "jackknife severity": df_joint["jackknife result"],
        "site type": df_joint["site type"],
    },
    x="skeleton exclusion rate",
    y="jackknife severity",
    hue="site type",
    style="site type",
    alpha=0.5,
)
sns.move_legend(ax, "upper left", bbox_to_anchor=(1, 1))
ax.axvline(est["exclusion frequency cutoff"], ls="--")
ax.axhline(est["jackknife severity cutoff"], ls=":")
# Shade the quadrant above both cutoffs, extending to the current axis limits.
# hatch.color is read when the patch is created, so it must be built inside
# the rc_context.
with plt.rc_context({"hatch.color": "lightblue"}):
    ax.add_patch(
        plt.Rectangle(
            (
                est["exclusion frequency cutoff"],
                est["jackknife severity cutoff"],
            ),
            ax.get_xlim()[1] - est["exclusion frequency cutoff"],
            ax.get_ylim()[1] - est["jackknife severity cutoff"],
            alpha=0.05,
            fill=True,
            hatch="\\",
            zorder=-1,
        )
    )
