# Compare effects of mutations on mouse Mxra8 binding to the effects on cell entry

In [None]:
import itertools

import altair as alt

import pandas as pd

# lift Altair's cap on the number of rows allowed in a chart's dataset, so the
# full mutation table can be plotted; the return value is bound to `_`
# (presumably so the notebook does not echo it as cell output)
_ = alt.data_transformers.disable_max_rows()

In [None]:
# this cell is tagged as `parameters` for papermill parameterization
# every value below is a `None` placeholder; the cells that follow use them as described

data_csv = None  # CSV file read with `pd.read_csv` to get the mutation-effect table
mut_effects_floor = None  # lower bound that the per-cell entry effects are clipped to
min_293T_Mxra8_entry = None  # rows with "entry in 293T_Mxra8 cells" below this are dropped
cells = None  # iterable of cell names; each name `c` maps to column "entry in {c} cells"
corr_chart_html = None  # path the interactive correlation chart is saved to
paper_fig_corr_chart_html = None  # path the per-region paper figure chart is saved to

In [None]:
entry_cols = [f"entry in {c} cells" for c in cells]
bind_col = "binding to mouse Mxra8"
# entry column that the `min_293T_Mxra8_entry` threshold is applied to
filter_entry_col = "entry in 293T_Mxra8 cells"

# read the data and check up front that every column used in this cell exists,
# so a mismatch fails with a clear message instead of inside a filter expression
data = pd.read_csv(data_csv)
required_cols = [*entry_cols, bind_col, filter_entry_col, "sequential_site", "wildtype", "mutant"]
missing_cols = [col for col in required_cols if col not in data.columns]
if missing_cols:
    raise ValueError(f"{missing_cols=} not in {data_csv=}: {entry_cols=}, {data.columns=}")

# keep only mutations (not wildtype identities) that have a binding measurement
# and at least the minimum entry in 293T_Mxra8 cells
data = (
    data
    .sort_values("sequential_site")
    .loc[lambda df: df[bind_col].notnull()]
    .query("wildtype != mutant")
    .loc[lambda df: df[filter_entry_col] >= min_293T_Mxra8_entry]
    # explicit copy so the column assignments below act on an independent frame
    .copy()
)

# floor the effects on cell entry
for entry_col in entry_cols:
    data[entry_col] = data[entry_col].clip(lower=mut_effects_floor)

# for every unordered pair of cells (in the order given by `cells`), add a
# column holding the first cell's entry effect minus the second cell's
diff_cols = []
labelled_entry_cols = list(zip(cells, entry_cols, strict=True))
for (cell_a, col_a), (cell_b, col_b) in itertools.combinations(labelled_entry_cols, 2):
    diff_label = f"{cell_a} minus {cell_b} entry"
    data[diff_label] = data[col_a] - data[col_b]
    diff_cols.append(diff_label)

# Pearson correlation of binding with each entry measure (per-cell entry and
# between-cell differences); the entry effects were floored before this point
effects_long = data.melt(
    id_vars=["site", "mutant", bind_col],
    value_vars=entry_cols + diff_cols,
    var_name="entry type",
    value_name="effect",
)
per_type_corr = (
    effects_long
    .groupby("entry type")[[bind_col, "effect"]]
    .corr(method="pearson")
)
# each group yields a 2x2 matrix; take the binding row's "effect" entry,
# which is the binding-vs-effect correlation for that entry type
corrs = (
    per_type_corr
    .xs(bind_col, level=1)["effect"]
    .rename("correlation")
    .reset_index()
)
print(f"Correlations with {bind_col=}")
corrs

In [None]:
# selections shared by the correlation plots: hovering a point selects that
# (site, mutant), and the region selection is bound to the legend
mut_selection = alt.selection_point(
    fields=["site", "mutant"],
    on="mouseover",
    empty=False,
)

region_selection = alt.selection_point(bind="legend", fields=["region"])

# data-less base chart for the layers: restricted to the selected region(s)
# and to rows whose folded "entry effect" value is valid
entry_effect_is_valid = alt.expr.isValid(alt.datum["entry effect"])
corr_base = (
    alt.Chart()
    .properties(width=240, height=240)
    .transform_filter(region_selection)
    .transform_filter(entry_effect_is_valid)
)

# scatter layer: binding (x) against the folded entry effect (y), colored by
# region, with the hovered mutation drawn larger and outlined
scatter_scale_opts = {"padding": 10, "nice": False, "zero": False}
scatter_tooltip = [
    "site",
    "wildtype",
    "mutant",
    "sequential_site",
    alt.Tooltip("entry effect:Q", format=".2f"),
    alt.Tooltip(bind_col, format=".2f"),
]

corr_scatter = (
    corr_base
    .mark_point(filled=True, fillOpacity=0.5, stroke="black", strokeOpacity=1)
    .encode(
        x=alt.X(bind_col, scale=alt.Scale(**scatter_scale_opts)),
        y=alt.Y("entry effect:Q", title=None, scale=alt.Scale(**scatter_scale_opts)),
        color=alt.Color("region", scale=alt.Scale(domain=data["region"].unique())),
        strokeWidth=alt.condition(mut_selection, alt.value(3), alt.value(0)),
        size=alt.condition(mut_selection, alt.value(80), alt.value(40)),
        tooltip=scatter_tooltip,
    )
)

# text layer: regress the entry effect on binding inside the chart and print
# the correlation as sqrt(rSquared), signed by whether `coef[1]` is negative
root_r_squared = alt.expr.sqrt(alt.datum["rSquared"])
signed_r = alt.expr.if_(alt.datum["coef"][1] >= 0, root_r_squared, -root_r_squared)

corr_text = (
    corr_base
    .transform_regression(on=bind_col, regression="entry effect", params=True)
    # `r` has to be calculated before `label`, which formats it
    .transform_calculate(r=signed_r, label='"r = " + format(datum.r, ".2f")')
    .encode(x=alt.value(5), y=alt.value(9), text=alt.Text("label:N"))
    .mark_text(align="left", color="dimgray", fontWeight=500, fontSize=15, opacity=1)
)

# fold every entry / entry-difference column into ("entry type", "entry effect")
# pairs and facet the scatter + label layers by entry type, each facet with
# its own y scale
folded_cols = entry_cols + diff_cols
corr_layers = alt.layer(corr_scatter, corr_text)
entry_type_facet = alt.Row("entry type:N", title=None)

corr_chart = (
    alt.FacetChart(data, spec=corr_layers, facet=entry_type_facet, columns=3)
    .transform_fold(folded_cols, as_=["entry type", "entry effect"])
    .resolve_scale(y="independent")
    .configure_axis(
        grid=False,
        labelFontSize=11,
        titleFontSize=14,
        labelOverlap="greedy",
    )
    .configure_header(labelOrient="left", labelFontSize=14, labelFontStyle="bold")
    .configure_legend(titleFontSize=14, labelFontSize=14)
    .add_params(mut_selection, region_selection)
)

corr_chart.save(corr_chart_html)

corr_chart

In [None]:
# data for the paper figure: binding versus entry in 293T_Mxra8 cells, with a
# per-region facet label that embeds that region's correlation
entry_293T_col = "entry in 293T_Mxra8 cells"
mouse_bind_col = "binding to mouse Mxra8"

fig_data = data[["site", "wildtype", "mutant", "region", entry_293T_col, mouse_bind_col]]

# one 2x2 correlation matrix per region, flattened so the matrix row label
# ends up in the "level_1" column
region_corr_matrix = (
    fig_data
    .groupby("region")[[entry_293T_col, mouse_bind_col]]
    .corr()
    .reset_index()
)
# the entry row's binding column holds the entry-vs-binding correlation
entry_rows = region_corr_matrix.loc[
    region_corr_matrix["level_1"] == entry_293T_col,
    ["region", mouse_bind_col],
]
fig_corr = entry_rows.assign(
    region_corr=(
        entry_rows["region"]
        + entry_rows[mouse_bind_col].map(lambda r: f" mutations (R = {r:.2f})")
    )
).drop(columns=mouse_bind_col)

fig_data = fig_data.merge(fig_corr, on="region", validate="many_to_one")

# paper figure: one small scatter panel per region, titled by the region label
# that carries the correlation
fig_scale_opts = {"nice": False, "padding": 7}
region_header = alt.Header(labelPadding=1, labelFontWeight="bold", labelFontSize=12)

fig_chart = (
    alt.Chart(fig_data)
    .mark_circle(
        color="gray",
        stroke="black",
        fillOpacity=0.25,
        strokeWidth=0.4,
        strokeOpacity=0.7,
    )
    .encode(
        x=alt.X("binding to mouse Mxra8", scale=alt.Scale(**fig_scale_opts)),
        y=alt.Y(
            "entry in 293T_Mxra8 cells",
            scale=alt.Scale(**fig_scale_opts),
            title="entry in 293T-Mxra8 cells",
        ),
        column=alt.Column("region_corr", title=None, header=region_header),
        tooltip=["site", "wildtype", "mutant"],
    )
    .properties(width=165, height=165)
    .configure_axis(grid=False)
)

print(f"Saving to {paper_fig_corr_chart_html=}")
fig_chart.save(paper_fig_corr_chart_html)

fig_chart