# Compare simple difference in functional effects across two conditions

Import Python modules.
We use `polyclonal` for plotting:

In [None]:
import copy
import itertools
import math

import altair as alt

import dms_variants.utils

import pandas as pd

import polyclonal
import polyclonal.plot

This notebook is parameterized by `papermill`.
The next cell is tagged as `parameters` to get the passed parameters.

In [None]:
# this cell is tagged parameters for `papermill` parameterization
# The None values below are placeholders; the values passed via papermill
# replace them.
# site_numbering_map_csv: CSV with a `reference_site` column (renamed to `site`
#   when read) and a `sequential_site` column, both non-null.
# mutation_annotations_csv: optional CSV of annotations keyed on `site` and
#   `mutant`; if falsy, no annotations are added.
# diffs_csv: output path for the CSV of per-mutation differences.
# chart_html: output path for the interactive line plot / heatmap HTML.
# corr_chart_html: output path for the condition-vs-condition scatter HTML.
# params: dict read below with keys `condition_1` and `condition_2` (each with
#   `name` and `selections`), `avg_method` ("mean" or "median"),
#   `plot_kwargs`, and `per_selection_tooltips`.
site_numbering_map_csv = None
mutation_annotations_csv = None
diffs_csv = None
chart_html = None
corr_chart_html = None
params = None

Read the input data:

In [None]:
# Site numbering: the reference numbering becomes `site`, and every row must
# have both a `site` and a `sequential_site`.
site_numbering_map = pd.read_csv(site_numbering_map_csv)
site_numbering_map = site_numbering_map.rename(columns={"reference_site": "site"})
assert site_numbering_map[["site", "sequential_site"]].notnull().all().all()

# every other column whose name ends in "site" (including `sequential_site`)
# is an additional site-numbering column
addtl_site_cols = [
    col
    for col in site_numbering_map.columns
    if col.endswith("site") and col != "site"
]

if mutation_annotations_csv:
    mutation_annotations = pd.read_csv(mutation_annotations_csv)

# the two conditions being compared, and the selections belonging to each
condition_1 = params["condition_1"]["name"]
condition_2 = params["condition_2"]["name"]
assert condition_1 != condition_2, f"{condition_1=}, {condition_2=}"
condition_1_selections = params["condition_1"]["selections"]
condition_2_selections = params["condition_2"]["selections"]

# selections must be unique within each condition ...
for sels in (condition_1_selections, condition_2_selections):
    assert len(sels) == len(set(sels))
# ... each condition must have at least one selection ...
for cond_key, sels in (
    ("condition_1", condition_1_selections),
    ("condition_2", condition_2_selections),
):
    assert len(sels), params[cond_key]
# ... and no selection may belong to both conditions
if not set(condition_1_selections).isdisjoint(condition_2_selections):
    raise ValueError(
        f"shared selections in {condition_1_selections=} and {condition_2_selections=}"
    )

# Per-selection functional effects, tagged with their selection and condition.
# Condition 1 selections come first, then condition 2, each in the given order.
func_effects = pd.concat(
    [
        pd.read_csv(f"results/func_effects/by_selection/{sel}_func_effects.csv").assign(
            selection=sel,
            condition=cond,
            times_seen=lambda df: df["times_seen"].astype("Int64"),
            mutation=lambda df: df["wildtype"] + df["site"].astype(str) + df["mutant"],
        )
        for cond, sels in (
            (condition_1, condition_1_selections),
            (condition_2, condition_2_selections),
        )
        for sel in sels
    ],
    ignore_index=True,
)

## Correlations among all selections
Compute the correlations in the mutation effects across all selections:

In [None]:
# We compute for several times seen values, get those:
try:
    init_times_seen = params["plot_kwargs"]["addtl_slider_stats"]["times_seen"]
except KeyError:
    print("No times seen in params, using a value of 3")
    init_times_seen = 3

# Minimum times-seen thresholds to analyze. The set removes repeats: when
# `init_times_seen` is 1, the list [1, t, 2 * t] would contain 1 twice and
# duplicate every row in that `min_times_seen` group. Sorting gives numeric order.
min_times_seen_thresholds = sorted({1, init_times_seen, 2 * init_times_seen})

# do analysis for each "times_seen" threshold
func_effects_for_corr = pd.concat(
    [
        func_effects.query("times_seen >= @t", engine="python").assign(min_times_seen=t)
        for t in min_times_seen_thresholds
    ]
)

corrs = (
    dms_variants.utils.tidy_to_corr(
        df=func_effects_for_corr,
        sample_col="selection",
        label_col="mutation",
        value_col="functional_effect",
        group_cols=["min_times_seen"],
    )
    # order rows by the numeric threshold before it becomes a text label
    .sort_values("min_times_seen", kind="stable")
    .assign(
        r2=lambda x: x["correlation"] ** 2,
        min_times_seen=lambda x: "min times seen " + x["min_times_seen"].astype(str),
    )
    .rename(columns={"correlation": "r"})
)

# Facet columns in numeric threshold order. The text labels would otherwise
# sort as strings (e.g. "min times seen 10" before "min times seen 5").
min_times_seen_order = corrs["min_times_seen"].unique().tolist()

corr_chart = (
    alt.Chart(corrs)
    .encode(
        alt.X("selection_1", title=None),
        alt.Y("selection_2", title=None),
        column=alt.Column("min_times_seen", title=None, sort=min_times_seen_order),
        color=alt.Color("r2", scale=alt.Scale(zero=True)),
        tooltip=[
            alt.Tooltip(c, format=".3g") if c in {"r2", "r"} else c
            for c in ["selection_1", "selection_2", "r2", "r"]
        ],
    )
    .mark_rect(stroke="black")
    .properties(
        width=alt.Step(15),
        height=alt.Step(15),
        title="Per-selection correlation in mutation functional effects",
    )
    .configure_axis(labelLimit=500)
)

display(corr_chart)

print(
    f"\nSelections for {condition_1}: {condition_1_selections}\n"
    f"Selections for {condition_2}: {condition_2_selections}\n"
)

## Average functional effects for each condition
Average the functional effects for each condition using the specified averaging method, then print the correlation between these average functional effects at several minimum times-seen thresholds:

In [None]:
avg_method = params["avg_method"]
assert avg_method in {"mean", "median"}, avg_method

# Average effect of each mutation within each condition. `times_seen` is the
# summed times seen divided by the number of selections with that mutation,
# and is blanked (NA) for wildtype entries.
avg_func_effects = (
    func_effects.groupby(
        ["condition", "site", "wildtype", "mutant", "mutation"], as_index=False
    )
    .aggregate(
        effect=pd.NamedAgg("functional_effect", avg_method),
        times_seen=pd.NamedAgg("times_seen", "sum"),
        n_selections=pd.NamedAgg("site", "count"),
    )
    .assign(
        times_seen=lambda x: (x["times_seen"] / x["n_selections"]).where(
            x["mutant"] != x["wildtype"],
            pd.NA,
        )
    )
)

# Minimum times-seen thresholds. The set removes repeats (e.g. when
# `init_times_seen` is 1), which would otherwise duplicate rows within a group.
avg_min_times_seen_thresholds = sorted({1, init_times_seen, 2 * init_times_seen})

avg_func_effects_for_corr = pd.concat(
    [
        avg_func_effects.query("times_seen >= @t", engine="python").assign(
            min_times_seen=t
        )
        for t in avg_min_times_seen_thresholds
    ]
)
print("Correlation between average functional effects across conditions:")
display(
    dms_variants.utils.tidy_to_corr(
        df=avg_func_effects_for_corr,
        sample_col="condition",
        label_col="mutation",
        value_col="effect",
        group_cols=["min_times_seen"],
    )
    # order rows by the numeric threshold before it becomes a text label
    .sort_values("min_times_seen", kind="stable")
    .assign(
        r2=lambda x: x["correlation"] ** 2,
        min_times_seen=lambda x: "min times seen " + x["min_times_seen"].astype(str),
    )
    .rename(columns={"correlation": "r"})
    .query("condition_1 != condition_2")
    .reset_index(drop=True)
    # sort=False keeps the numeric order established above instead of sorting
    # the text labels alphabetically
    .groupby("min_times_seen", sort=False)
    .first()
    .round(3)
)

## Compute pairwise differences
Compute pairwise differences in effects between all pairs of condition 1 selections versus condition 2 selections.
For each comparison, the times seen is the mean of the times seen in the two selections being compared.

We then compute, across comparisons, the average difference (using the specified averaging method), the standard deviation of the differences, the mean times seen, and the fraction of comparisons in which a difference can be computed:

In [None]:
# compute differences for all individual pairs
_merge_keys = ["wildtype", "site", "mutant"]
_per_selection_cols = [*_merge_keys, "times_seen", "functional_effect"]


def _pairwise_difference(sel1, sel2):
    """Per-mutation difference `sel1 - sel2` for one condition-1 vs condition-2 pair.

    Only mutations present in both selections are kept (inner merge). The
    times seen is the mean of the two selections' times seen. The result is
    labeled with a `comparison` column of the form "sel1 vs sel2".
    """
    effects_1 = func_effects.loc[func_effects["selection"] == sel1, _per_selection_cols]
    effects_2 = func_effects.loc[func_effects["selection"] == sel2, _per_selection_cols]
    paired = effects_1.merge(effects_2, on=_merge_keys, validate="1:1")
    paired = paired.assign(
        times_seen=(paired["times_seen_x"] + paired["times_seen_y"]) / 2,
        difference=paired["functional_effect_x"] - paired["functional_effect_y"],
    )
    return paired[[*_merge_keys, "times_seen", "difference"]].assign(
        comparison=f"{sel1} vs {sel2}"
    )


diffs_all = [
    _pairwise_difference(first_sel, second_sel)
    for first_sel, second_sel in itertools.product(
        condition_1_selections, condition_2_selections
    )
]

# compute average differences across pairs; the fraction is relative to the
# total number of condition 1 x condition 2 pairs
n_pairs = len(condition_1_selections) * len(condition_2_selections)
diffs = (
    pd.concat(diffs_all, ignore_index=True)
    .groupby(_merge_keys, as_index=False)
    .aggregate(
        difference=pd.NamedAgg("difference", avg_method),
        difference_std=pd.NamedAgg("difference", "std"),
        times_seen=pd.NamedAgg("times_seen", "mean"),
        fraction_pairs_w_mutation=pd.NamedAgg(
            "difference", lambda vals: len(vals) / n_pairs
        ),
    )
)

# wide table of average effect in each condition, plus the larger of the two
avg_effects_wide = (
    avg_func_effects.pivot_table(
        index=["site", "wildtype", "mutant"],
        values="effect",
        columns="condition",
    )
    .reset_index()
    .assign(best_effect=lambda x: x[[condition_1, condition_2]].max(axis=1))
    .rename(columns={cond: f"{cond} effect" for cond in [condition_1, condition_2]})
)

# wide table of per-selection "effect (times seen)" strings; times seen is
# omitted for wildtype entries
per_selection_wide = (
    func_effects.assign(
        effect_times_seen=lambda x: (
            x["functional_effect"].map(lambda e: f"{e:.2f}")
            + (" (" + x["times_seen"].astype(str) + ")").where(
                x["mutant"] != x["wildtype"], ""
            )
        )
    )
    .pivot_table(
        index=["site", "wildtype", "mutant"],
        values="effect_times_seen",
        columns="selection",
        aggfunc=lambda effect_strs: ",".join(effect_strs),
    )[condition_1_selections + condition_2_selections]
    .reset_index()
)

# add other relevant stuff to data frame of differences, then sort
diffs = (
    diffs.merge(avg_effects_wide, on=_merge_keys, validate="one_to_one")
    .merge(per_selection_wide, on=_merge_keys, validate="one_to_one")
    .sort_values(["site", "mutant"])
    .reset_index(drop=True)
)

print(f"Writing differences to {diffs_csv}")
diffs.to_csv(diffs_csv, index=False, float_format="%.4g")

Make scatter plots correlating the differences between each pair of comparisons, keeping only mutations that pass the times-seen filter:

In [None]:
print(f"Correlating differences for times_seen of {init_times_seen}")

# all individual-pair differences passing the times-seen filter; each
# "selection" here is one "sel1 vs sel2" comparison
diffs_all_df = (
    pd.concat(diffs_all)
    .query("times_seen >= @init_times_seen")
    .assign(mutation=lambda x: x["wildtype"] + x["site"].astype(str) + x["mutant"])
    .rename(columns={"comparison": "selection"})
)

# one scatter panel per pair of comparisons, over mutations present in both
corr_panels = []
for sel1, sel2 in itertools.combinations(sorted(diffs_all_df["selection"].unique()), 2):
    corr_df = (
        diffs_all_df.query("selection == @sel1")[["mutation", "difference"]]
        .rename(columns={"difference": sel1})
        .merge(
            diffs_all_df.query("selection == @sel2")[["mutation", "difference"]].rename(
                columns={"difference": sel2}
            ),
            validate="one_to_one",
        )
    )
    n = len(corr_df)
    r = corr_df[[sel1, sel2]].corr().values[1, 0]
    corr_panels.append(
        alt.Chart(corr_df)
        .encode(
            alt.X(sel1, scale=alt.Scale(nice=False, padding=4)),
            alt.Y(sel2, scale=alt.Scale(nice=False, padding=4)),
            tooltip=[
                "mutation",
                alt.Tooltip(sel1, format=".3f"),
                alt.Tooltip(sel2, format=".3f"),
            ],
        )
        .mark_circle(color="black", size=30, opacity=0.25)
        .properties(
            width=160,
            height=160,
            title=alt.TitleParams(
                f"R = {r:.2f}, N = {n}", fontSize=11, fontWeight="normal", dy=2
            ),
        )
    )

# lay the panels out in rows of `ncols`
ncols = 4
corr_rows = [
    alt.hconcat(*corr_panels[start : start + ncols])
    for start in range(0, len(corr_panels), ncols)
]

# With a single comparison (one selection per condition) there are no pairs of
# comparisons, so skip the chart instead of rendering an empty concatenation.
if corr_rows:
    display(alt.vconcat(*corr_rows).configure_axis(grid=False))
else:
    print("Fewer than two comparisons, so there are no comparisons to correlate.")

## Make a scatter plot comparing the conditions
Make a scatter plot comparing the average effects in the two conditions, with informative tooltips and slider bars:

In [None]:
# highlight a mutation (red outline, larger point) on mouseover
mutation_selection = alt.selection_point(
    on="mouseover", fields=["mutation"], empty=False
)

# Optionally add per-mutation annotations. They must be keyed on `site` and
# `mutant` and may not share any other column with `diffs`. Annotations are
# blanked for wildtype entries.
if mutation_annotations_csv:
    if not {"site", "mutant"}.issubset(mutation_annotations.columns):
        raise ValueError(f"{mutation_annotations.columns=} lacks 'site', 'mutant'")
    if set(mutation_annotations.columns).intersection(diffs.columns) != {
        "site",
        "mutant",
    }:
        raise ValueError(
            f"{mutation_annotations.columns=} shares columns with {diffs.columns=}"
        )
    diffs = diffs.merge(
        mutation_annotations,
        on=["site", "mutant"],
        how="left",
        validate="many_to_one",
    )
    for col in mutation_annotations.columns:
        if col not in {"site", "mutant"}:
            diffs[col] = diffs[col].where(diffs["wildtype"] != diffs["mutant"], pd.NA)

# non-wildtype rows only, keyed by a `mutation` string placed first
corr_diffs = (
    diffs.query("wildtype != mutant")
    .assign(
        mutation=lambda x: x["wildtype"] + x["site"].astype(str) + x["mutant"],
    )
    .drop(columns=["wildtype", "site", "mutant"])
)
corr_diffs = corr_diffs[
    ["mutation"] + [c for c in corr_diffs.columns if c != "mutation"]
]

# Deep copy so that filling in defaults (here or in later cells) never mutates
# the nested dicts/lists inside the user-supplied `params`.
plot_kwargs = copy.deepcopy(params["plot_kwargs"])
if "slider_binding_range_kwargs" not in plot_kwargs:
    plot_kwargs["slider_binding_range_kwargs"] = {}

# Read these without inserting defaults into `plot_kwargs`. A later cell fills
# in `addtl_slider_stats` and checks whether `addtl_slider_stats_as_max` was
# specified, and an inserted empty default would mask that check.
slider_stats = plot_kwargs.get("addtl_slider_stats", {})
slider_stats_as_max = plot_kwargs.get("addtl_slider_stats_as_max", [])

# One slider per requested stat, starting at the requested value and spanning
# the observed range (optionally overridden via `slider_binding_range_kwargs`).
sliders = {}
for stat, init_value in slider_stats.items():
    bound_label = (
        f"maximum {stat}" if stat in slider_stats_as_max else f"minimum {stat}"
    )
    sliders[stat] = alt.param(
        value=init_value,
        bind=alt.binding_range(
            **(
                {
                    "name": bound_label,
                    "min": corr_diffs[stat].min(),
                    "max": corr_diffs[stat].max(),
                }
                | plot_kwargs["slider_binding_range_kwargs"].get(stat, {})
            )
        ),
    )

corr_chart = (
    alt.Chart(corr_diffs)
    .add_params(mutation_selection)
    .encode(
        alt.X(
            f"{condition_1} effect", scale=alt.Scale(nice=False, zero=False, padding=5)
        ),
        alt.Y(
            f"{condition_2} effect", scale=alt.Scale(nice=False, zero=False, padding=5)
        ),
        strokeWidth=alt.condition(mutation_selection, alt.value(2), alt.value(0)),
        size=alt.condition(mutation_selection, alt.value(70), alt.value(45)),
        # is_float_dtype also matches nullable Float64 columns (e.g. times_seen),
        # which `dtype == float` misses
        tooltip=[
            (
                alt.Tooltip(c, format=".3g")
                if pd.api.types.is_float_dtype(corr_diffs[c])
                else c
            )
            for c in corr_diffs.columns
        ],
    )
    .mark_circle(fill="black", fillOpacity=0.35, stroke="red")
    .properties(width=275, height=275)
    .configure_axis(grid=False)
)

# stats listed in `addtl_slider_stats_as_max` filter from above, others from below
for stat, slider in sliders.items():
    if stat in slider_stats_as_max:
        corr_chart = corr_chart.add_params(slider).transform_filter(
            alt.datum[stat] <= slider
        )
    else:
        corr_chart = corr_chart.add_params(slider).transform_filter(
            alt.datum[stat] >= slider
        )

print(f"Saving to {corr_chart_html=}")
corr_chart.save(corr_chart_html)

corr_chart

## Make interactive chart
Set default keyword arguments for [`polyclonal.plot.lineplot_and_heatmap`](https://jbloomlab.github.io/polyclonal/polyclonal.plot.html#polyclonal.plot.lineplot_and_heatmap) for any that are not already specified:

In [None]:
# Work on copies of the nested containers so that filling in defaults never
# mutates lists/dicts shared with the user-supplied `params["plot_kwargs"]`.
slider_stats = dict(plot_kwargs.get("addtl_slider_stats", {}))
slider_stats_as_max = list(plot_kwargs.get("addtl_slider_stats_as_max", []))

# Default sliders are inserted in this order: times_seen, difference_std,
# fraction_pairs_w_mutation.
if "times_seen" not in slider_stats:
    slider_stats["times_seen"] = 3

if "difference_std" not in slider_stats:
    # default: a *maximum* std filter starting at the largest observed std
    slider_stats["difference_std"] = diffs["difference_std"].max()
    if "difference_std" not in slider_stats_as_max:
        slider_stats_as_max.append("difference_std")
elif (
    "difference_std" not in slider_stats_as_max
    and "addtl_slider_stats_as_max" not in params["plot_kwargs"]
):
    # Check the user-supplied params rather than `plot_kwargs`: an earlier cell
    # may already have inserted an empty `addtl_slider_stats_as_max` default.
    raise ValueError(
        "You specified `difference_std` in `addtl_slider_stats` but did not add it to "
        "`addtl_slider_stats_as_max`. If you really do not want `difference_std` in "
        "`addtl_slider_stats_as_max`, then specify that list without it."
    )

if "fraction_pairs_w_mutation" not in slider_stats:
    slider_stats["fraction_pairs_w_mutation"] = 0.5

plot_kwargs["addtl_slider_stats"] = slider_stats
plot_kwargs["addtl_slider_stats_as_max"] = slider_stats_as_max

# bring the site-zoom color column in from the site numbering map if needed
if "site_zoom_bar_color_col" in plot_kwargs:
    color_col = plot_kwargs["site_zoom_bar_color_col"]
    if color_col not in diffs.columns and color_col in site_numbering_map.columns:
        diffs = diffs.merge(
            site_numbering_map[["site", color_col]],
            on="site",
            validate="many_to_one",
            how="left",
        )

# Add only the site-numbering columns still missing from `diffs`. Merging a
# column that is already present would create duplicate `_x`/`_y` columns.
missing_site_cols = [col for col in addtl_site_cols if col not in diffs.columns]
if missing_site_cols:
    diffs = diffs.merge(
        site_numbering_map[["site", *missing_site_cols]],
        on="site",
        validate="many_to_one",
        how="left",
    )

# `sequential_site` is only worth a tooltip if it differs from `site` somewhere
show_sequential_site = bool((diffs["site"] != diffs["sequential_site"]).any())

tooltip_stats = list(plot_kwargs.get("addtl_tooltip_stats", []))
for c in ["difference_std", *addtl_site_cols]:
    if c == "sequential_site" and not show_sequential_site:
        continue
    if c not in tooltip_stats:
        tooltip_stats.append(c)

if params["per_selection_tooltips"]:
    all_selections = condition_1_selections + condition_2_selections
    assert set(all_selections).issubset(diffs.columns)
    tooltip_stats += [s for s in all_selections if s not in tooltip_stats]

plot_kwargs["addtl_tooltip_stats"] = tooltip_stats

# default alphabet: biochemically ordered characters that occur as mutants
if "alphabet" not in plot_kwargs:
    observed_mutants = set(diffs["mutant"])
    plot_kwargs["alphabet"] = [
        a
        for a in polyclonal.alphabets.biochem_order_aas(polyclonal.AAS_WITHSTOP_WITHGAP)
        if a in observed_mutants
    ]

# default site order: sequential numbering
if "sites" not in plot_kwargs:
    plot_kwargs["sites"] = site_numbering_map.sort_values("sequential_site")[
        "site"
    ].tolist()

Now make the interactive heatmap:

In [None]:
# There is a single set of differences, so every row is given the same
# constant value in a `_dummy` column used as the category column.
assert "_dummy" not in diffs.columns
heatmap_df = diffs.assign(_dummy="dummy")

chart = polyclonal.plot.lineplot_and_heatmap(
    data_df=heatmap_df,
    category_col="_dummy",
    stat_col="difference",
    **plot_kwargs,
)
display(chart)

print(f"\nSaving to {chart_html}")
chart.save(chart_html)