# Distribution of functional effects
Get input variables

In [None]:
# this cell tagged as `parameters` for `papermill` parameterization
# Every name below is a `None` placeholder; presumably `papermill` supplies the
# real values when the notebook is executed.
#
# Input CSV paths (each one is passed to `pd.read_csv` further down):
fitness_csv = None  # fitness estimates; filtered to gene "S" after loading
xbb15_func_effects_csv = None  # functional effects labeled as strain "XBB.1.5"
ba2_func_effects_csv = None  # functional effects labeled as strain "BA.2"
site_numbering_map_csv = None  # its `reference_site` and `region` columns are used
# Initial positions of the interactive chart sliders:
init_min_times_seen = None  # starting value of the "minimum times seen" slider
init_min_n_libraries = None  # starting value of the "minimum number of libraries" slider
init_expected_count = None  # starting value of the "minimum expected counts" slider
# Output HTML paths (each chart is written there with `.save(...)`):
strain_corr_html = None  # chart comparing BA.2 and XBB.1.5 mutation effects
natural_corr_html = None  # chart comparing DMS effects to natural-sequence fitness
effects_boxplot_html = None  # box plot of mutation effects

In [None]:
import altair as alt

import numpy

import pandas as pd

import scipy.stats

# Turn off Altair's row-count limit for chart data; the call's return value is
# not needed, so it is bound to the throwaway name `_`.
_ = alt.data_transformers.disable_max_rows()

In [None]:
# Read the two per-strain functional-effect tables and the site numbering map.
xbb15_func_effects = pd.read_csv(xbb15_func_effects_csv)
ba2_func_effects = pd.read_csv(ba2_func_effects_csv)
site_numbering_map = pd.read_csv(site_numbering_map_csv)

# Label every row with its strain, then stack both tables into one long frame.
stacked_effects = pd.concat(
    [
        xbb15_func_effects.assign(strain="XBB.1.5"),
        ba2_func_effects.assign(strain="BA.2"),
    ]
)

# Two-column lookup from site to region; `reference_site` is renamed to `site`
# so that it lines up with the site column of the effects tables.
site_regions = site_numbering_map.rename(columns={"reference_site": "site"})
site_regions = site_regions[["site", "region"]]

# No `on=` is passed, so pandas joins on the columns the two frames share.
func_effects = stacked_effects.merge(site_regions)

## Plot distribution of functional effects for different domains

Make box plots:

In [None]:
# Keep only rows describing an actual change: drop wildtype-to-wildtype rows
# and rows whose wildtype residue is a stop codon (`*`).
mutated_rows = func_effects.query("wildtype != mutant").query("wildtype != '*'")

# Add a mutation label (wildtype + site + mutant), a coarse mutation category
# keyed on the mutant character, and an integer `times_seen`.
dist_df = mutated_rows.assign(
    mutation=lambda df: df["wildtype"] + df["site"].astype(str) + df["mutant"],
    mut_type=lambda df: numpy.where(
        df["mutant"] == "*",
        "stop codon",
        numpy.where(df["mutant"] == "-", "deletion", "substitution"),
    ),
    times_seen=lambda df: df["times_seen"].astype(int),
)

# The charts call the selection count "libraries"; keep only the plotted columns.
dist_df = dist_df.rename(columns={"n_selections": "n_libraries"})
dist_df = dist_df[
    ["strain", "mutation", "effect", "mut_type", "region", "times_seen", "n_libraries"]
]

# Slider holding the `times_seen` threshold used by the chart filters. Its upper
# end is 10, or the largest `times_seen` in the data when that is smaller.
times_seen_binding = alt.binding_range(
    name="minimum times seen",
    min=1,
    max=min(10, dist_df["times_seen"].max()),
    step=0.5,
)
times_seen_slider = alt.param(value=init_min_times_seen, bind=times_seen_binding)

# Slider holding the `n_libraries` threshold, ranging up to the largest value
# present in the data.
n_libraries_binding = alt.binding_range(
    name="minimum number of libraries",
    min=1,
    max=dist_df["n_libraries"].max(),
    step=1,
)
n_libraries_slider = alt.param(value=init_min_n_libraries, bind=n_libraries_binding)

# Slider holding the floor applied to effects before plotting. It starts at the
# most negative effect in the data (so nothing is clipped initially) and can be
# raised as far as 0.
lowest_effect = dist_df["effect"].min()
effect_floor_binding = alt.binding_range(
    name="mutation effect floor (clip values < this)",
    min=lowest_effect,
    max=0,
)
effect_floor_slider = alt.param(value=lowest_effect, bind=effect_floor_binding)

# Slider holding the smallest number of mutations a category must contain to be
# drawn; the box plot keeps rows where `n_mutations >= n_mutations_slider`.
# The label previously read "mutation number of mutations ...", which did not
# describe the control; it now says "minimum" to match how the value is used.
n_mutations_slider = alt.param(
    value=1,
    bind=alt.binding_range(
        name="minimum number of mutations to show category",
        min=1,
        max=50,
    ),
)

# Box plot of mutation effects: mutation type on the x-axis, one box per region
# (color and x-offset), and one panel per strain.
dist_boxplot = (
    alt.Chart(dist_df)
    # The slider is labeled "minimum times seen", so a mutation sitting exactly
    # at the threshold must be kept (`>=`), just like the `n_libraries` filter.
    .transform_filter(alt.datum["times_seen"] >= times_seen_slider)
    .transform_filter(alt.datum["n_libraries"] >= n_libraries_slider)
    # Clip effects from below at the slider value.
    .transform_calculate(
        effect_floored=alt.expr.max(alt.datum["effect"], effect_floor_slider),
    )
    # Count the mutations remaining in each (region, strain, mutation type)
    # category, then hide categories with fewer than the slider's count.
    .transform_joinaggregate(
        n_mutations="count()",
        groupby=["region", "strain", "mut_type"],
    )
    .transform_filter(alt.datum["n_mutations"] >= n_mutations_slider)
    .encode(
        x=alt.X(
            "mut_type",
            title=None,
            axis=alt.Axis(labelFontSize=11, labelFontStyle="bold", labelAngle=0),
            scale=alt.Scale(domain=["substitution", "stop codon", "deletion"]),
        ),
        y=alt.Y(
            "effect_floored:Q",
            title="mutation effect on cell entry",
            scale=alt.Scale(nice=False, padding=2),
        ),
        color=alt.Color("region"),
        xOffset=alt.XOffset("region"),
        column=alt.Column(
            "strain",
            title=None,
            sort=["XBB.1.5", "BA.2"],
            header=alt.Header(labelFontSize=13, labelFontStyle="bold", labelPadding=2),
            spacing=2,
        ),
    )
    .mark_boxplot(outliers=False, extent=0.75, size=12)
    .configure_axis(grid=False)
    .add_params(
        times_seen_slider,
        n_libraries_slider,
        effect_floor_slider,
        n_mutations_slider,
    )
    .properties(height=160, width=alt.Step(14))
)

# Write the interactive chart to the configured HTML path.
print(f"Saving to {effects_boxplot_html}")
dist_boxplot.save(effects_boxplot_html)

# Last expression in the cell, so the notebook renders the chart.
dist_boxplot

## Correlation with natural evolution fitness estimates

In [None]:
# Load the fitness estimates, keep gene "S" only, drop columns that are not
# used here, and rename the site column to match `func_effects`.
spike_fitness = pd.read_csv(fitness_csv).query("gene == 'S'")
fitness = spike_fitness.drop(
    columns=["gene", "subset_of_ORF1ab", "aa_differs_among_clade_founders"]
).rename(columns={"aa_site": "site"})

# Two views of the same table, one keyed on the wildtype amino acid and one on
# the mutant amino acid, so each mutation picks up both fitness values.
fitness_as_wildtype = fitness.rename(
    columns={
        "aa": "wildtype",
        "fitness": "wildtype_fitness",
        "expected_count": "wildtype_expected_count",
    }
)
fitness_as_mutant = fitness.rename(
    columns={
        "aa": "mutant",
        "fitness": "mutant_fitness",
        "expected_count": "mutant_expected_count",
    }
)

# `validate="many_to_one"` makes pandas raise if a (site, amino acid) key is
# duplicated on the fitness side of either merge.
effects_with_fitness = func_effects.merge(
    fitness_as_wildtype,
    on=["site", "wildtype"],
    validate="many_to_one",
).merge(
    fitness_as_mutant,
    on=["site", "mutant"],
    validate="many_to_one",
)

# For real mutations only: fitness effect is mutant minus wildtype fitness, and
# the expected count is the smaller of the wildtype and mutant expected counts.
effect_vs_fitness = (
    effects_with_fitness.query("wildtype != mutant")
    .assign(
        fitness_effect=lambda df: df["mutant_fitness"] - df["wildtype_fitness"],
        expected_count=lambda df: df[
            ["wildtype_expected_count", "mutant_expected_count"]
        ].min(axis=1),
        mutation=lambda df: df["wildtype"] + df["site"].astype(str) + df["mutant"],
        times_seen=lambda df: df["times_seen"].astype(int),
    )
    .rename(columns={"n_selections": "n_libraries"})
)
effect_vs_fitness = effect_vs_fitness[
    [
        "strain",
        "mutation",
        "effect",
        "fitness_effect",
        "times_seen",
        "n_libraries",
        "expected_count",
    ]
]

# Last expression in the cell, so the notebook displays the table.
effect_vs_fitness

In [None]:
# Slider holding the expected-count threshold; its upper end is 50, or the
# largest expected count in the data when that is smaller.
expected_count_binding = alt.binding_range(
    name="minimum expected counts (natural sequences)",
    min=1,
    max=min(50, effect_vs_fitness["expected_count"].max()),
    step=1,
)
expected_count_slider = alt.param(value=init_expected_count, bind=expected_count_binding)

# Mouse-over selection keyed on the mutation label; with `empty=False` an empty
# selection matches no points.
mut_selection = alt.selection_point(fields=["mutation"], on="mouseover", empty=False)

# Axis encodings shared by the scatter layer and the correlation-text layer.
dms_effect_x = alt.X(
    "effect_floored:Q",
    title="effect on cell entry (DMS)",
    scale=alt.Scale(nice=False, padding=4),
)
natural_fitness_y = alt.Y(
    "fitness_effect",
    title="fitness from natural sequences",
    scale=alt.Scale(nice=False, padding=4),
)

# Base chart: apply the three slider thresholds, then clip effects from below
# at the effect-floor slider value.
effect_vs_fitness_base = (
    alt.Chart(effect_vs_fitness)
    .transform_filter(alt.datum["times_seen"] >= times_seen_slider)
    .transform_filter(alt.datum["n_libraries"] >= n_libraries_slider)
    .transform_filter(alt.datum["expected_count"] >= expected_count_slider)
    .transform_calculate(
        effect_floored=alt.expr.max(effect_floor_slider, alt.datum["effect"])
    )
    .encode(
        x=dms_effect_x,
        y=natural_fitness_y,
        tooltip=list(effect_vs_fitness.columns),
    )
)

# Scatter layer: the hovered mutation is drawn opaque, larger and outlined.
effect_vs_fitness_scatter = effect_vs_fitness_base.mark_circle(
    color="black", stroke="red"
).encode(
    opacity=alt.condition(mut_selection, alt.value(1), alt.value(0.1)),
    strokeWidth=alt.condition(mut_selection, alt.value(2), alt.value(0)),
    size=alt.condition(mut_selection, alt.value(50), alt.value(25)),
)

# Text layer: fit a regression, then build r as sqrt(rSquared) carrying the sign
# of the fitted slope (`coef[1]`).
signed_r = alt.expr.if_(
    alt.datum["coef"][1] > 0,
    alt.expr.sqrt(alt.datum["rSquared"]),
    -alt.expr.sqrt(alt.datum["rSquared"]),
)
effect_vs_fitness_r = (
    effect_vs_fitness_base
    .transform_regression("effect_floored", "fitness_effect", params=True)
    .transform_calculate(
        r=signed_r,
        r_text="r = " + alt.expr.format(alt.datum["r"], ".2f"),
    )
    .mark_text(size=14, align="left", color="blue")
    .encode(
        text="r_text:N",
        x=alt.value(5),
        y=alt.value(10),
    )
)

# Layer both pieces, attach the interactive parameters, and facet by strain.
strain_facet = alt.Column(
    "strain",
    title=None,
    header=alt.Header(labelFontSize=13, labelFontStyle="bold", labelPadding=2),
)
effect_vs_fitness_chart = (
    alt.layer(effect_vs_fitness_scatter, effect_vs_fitness_r)
    .add_params(
        times_seen_slider,
        n_libraries_slider,
        expected_count_slider,
        effect_floor_slider,
        mut_selection,
    )
    .properties(width=170, height=170)
    .facet(column=strain_facet, spacing=5)
    .configure_axis(grid=False)
)

# Write the interactive chart to the configured HTML path.
print(f"Saving to {natural_corr_html}")
effect_vs_fitness_chart.save(natural_corr_html)

# Last expression in the cell, so the notebook renders the chart.
effect_vs_fitness_chart

## Correlation between BA.2 and XBB.1.5 mutation effects

In [None]:
# Label every real mutation (wildtype differs from mutant).
labeled_mutations = func_effects.query("wildtype != mutant").assign(
    mutation=lambda df: df["wildtype"] + df["site"].astype(str) + df["mutant"],
)

# Give each mutation the smallest `times_seen` and `n_selections` found among
# its rows, so all rows of a mutation share one value of each.
labeled_mutations = labeled_mutations.assign(
    times_seen=lambda df: (
        df.groupby("mutation")["times_seen"].transform("min").astype(int)
    ),
    n_libraries=lambda df: df.groupby("mutation")["n_selections"].transform("min"),
)

# Spread the strains into columns (one effect column per strain), keep only the
# mutations that have a value for every strain, and rename the strain columns
# to names without dots.
effect_by_strain = labeled_mutations.pivot_table(
    index=["mutation", "times_seen", "n_libraries"],
    columns="strain",
    values="effect",
)
strain_corr_df = (
    effect_by_strain.dropna(axis=0)
    .reset_index()
    .rename(columns={"BA.2": "BA_2", "XBB.1.5": "XBB_1_5"})
)

# Sanity check: exactly one row per mutation.
assert strain_corr_df["mutation"].nunique() == len(strain_corr_df)

# Last expression in the cell, so the notebook displays the table.
strain_corr_df

In [None]:
# Axis encodings: XBB.1.5 effects on x, BA.2 effects on y.
xbb15_effect_x = alt.X(
    "XBB_1_5",
    title="effect on cell entry in XBB.1.5",
    scale=alt.Scale(nice=False, padding=4),
)
ba2_effect_y = alt.Y(
    "BA_2",
    title="effect on cell entry in BA.2",
    scale=alt.Scale(nice=False, padding=4),
)

# Base chart shared by both layers: apply the `times_seen` and `n_libraries`
# slider thresholds and attach the interactive parameters.
strain_corr_base = (
    alt.Chart(strain_corr_df)
    .transform_filter(alt.datum["times_seen"] >= times_seen_slider)
    .transform_filter(alt.datum["n_libraries"] >= n_libraries_slider)
    .encode(
        x=xbb15_effect_x,
        y=ba2_effect_y,
        tooltip=list(strain_corr_df.columns),
    )
    .add_params(
        times_seen_slider,
        n_libraries_slider,
        mut_selection,
    )
    .properties(width=170, height=170)
)

# Scatter layer: the hovered mutation is drawn opaque, larger and outlined.
strain_corr_scatter = strain_corr_base.mark_circle(color="black", stroke="red").encode(
    opacity=alt.condition(mut_selection, alt.value(1), alt.value(0.1)),
    strokeWidth=alt.condition(mut_selection, alt.value(2), alt.value(0)),
    size=alt.condition(mut_selection, alt.value(50), alt.value(25)),
)

# Text layer: fit a regression, then build r as sqrt(rSquared) carrying the sign
# of the fitted slope (`coef[1]`).
strain_signed_r = alt.expr.if_(
    alt.datum["coef"][1] > 0,
    alt.expr.sqrt(alt.datum["rSquared"]),
    -alt.expr.sqrt(alt.datum["rSquared"]),
)
strain_corr_r = (
    strain_corr_base
    .transform_regression("BA_2", "XBB_1_5", params=True)
    .transform_calculate(
        r=strain_signed_r,
        r_text="r = " + alt.expr.format(alt.datum["r"], ".2f"),
    )
    .mark_text(size=14, align="left", color="blue")
    .encode(
        text="r_text:N",
        x=alt.value(5),
        y=alt.value(10),
    )
)

# Overlay the correlation text on the scatter and switch off the grid lines.
strain_corr_chart = alt.layer(strain_corr_scatter, strain_corr_r).configure_axis(
    grid=False
)

# Write the interactive chart to the configured HTML path.
print(f"Saving to {strain_corr_html}")
strain_corr_chart.save(strain_corr_html)

# Last expression in the cell, so the notebook renders the chart.
strain_corr_chart