# Compare binding to different receptor ligands

In [1]:
import itertools

import altair as alt

import pandas as pd

# Turn off Altair's row-count limit for chart data; the return value is bound to `_`
# so the cell does not display it.
_ = alt.data_transformers.disable_max_rows()

In [2]:
# this cell is tagged parameters for `papermill` parameterization
# Every name below is a `None` placeholder meant to be overridden by injected parameters.
# NOTE(review): apart from `min_times_seen`, none of these names is read by the code cells
# visible in this notebook, which instead use the names assigned in the "# Parameters" cell
# (`min_entry`, `min_entry_std`, `entry_csv`, `binding_csvs`, ...) — confirm which set of
# parameter names the pipeline actually passes.

entry_293T_human_Mxra8 = None
binding_human_Mxra8 = None
binding_mouse_Mxra8 = None
corr_chart_html = None
min_entry_293T_human_Mxra8 = None
min_entry_293T_human_Mxra8_std = None
min_mouse_Mxra8_binding_std = None
min_human_Mxra8_binding_std = None
min_times_seen = None

In [3]:
# Parameters
# NOTE(review): an earlier comment here asked "or should this be -4?", which is the value
# already set — confirm the intended entry cutoff.
min_entry = -4
min_entry_std = 2.25  # applied as an upper bound on `effect_std`, despite the `min_` prefix
entry_csv = "results/func_effects/averages/293T-Mxra8_entry_func_effects.csv"
entry_name = "entry in 293T-Mxra8 cells"
min_times_seen = 2

# The four dicts below share the same ligand keys: display name, input CSV, the prefix of
# the `... binding_median` / `... binding_std` columns in that CSV, and the largest
# binding standard deviation that is kept.
ligands = {"mouse_Mxra8": "mouse Mxra8", "human_Mxra8": "human Mxra8"}
binding_csvs = {
    "human_Mxra8": "results/receptor_affinity/averages/human_Mxra8_mut_effect.csv",
    "mouse_Mxra8": "results/receptor_affinity/averages/mouse_Mxra8_mut_effect.csv",
}
binding_csv_col_names = {"human_Mxra8": "Mxra8", "mouse_Mxra8": "Mxra8"}
max_binding_stds = {"human_Mxra8": 2.5, "mouse_Mxra8": 2.25}

site_numbering_map = "data/site_numbering_map.csv"
addtl_site_annotations = "data/addtl_site_annotations.csv"
# maps column names in the annotations CSV to the names used in this notebook
addtl_site_annotations_cols = {
    "domain": "domain",
    "contacts": "Mxra8 contact",
}

import os

# The relative paths above are presumably rooted one directory up from where the notebook
# starts (TODO confirm). Only move up when the entry CSV is not already reachable: an
# unconditional chdir climbs one more level every time this cell is re-run, after which
# none of the relative paths resolve.
if not os.path.exists(entry_csv):
    os.chdir("../")

## Read the data

In [4]:
# read the data

# Cell-entry effects: keep rows with enough observations (`times_seen`) and an
# `effect_std` no larger than `min_entry_std`, then reduce to the mutation key columns
# plus the effect, renamed to `entry`.
print(f"Reading cell entry from {entry_csv=}")
data_df = (
    pd.read_csv(entry_csv)
    .query("times_seen >= @min_times_seen")
    .query("effect_std <= @min_entry_std")
    # No `mutation` column is built at this point: the column selection on the next line
    # would discard it, and it is created after all merges are done.
    [["site", "wildtype", "mutant", "effect"]]
    .rename(columns={"effect": "entry"})
)

# For each ligand: load its binding CSV, apply the quality filters, build a text label
# of the form "median (value, value, ...)", and left-merge onto the entry data.
for ligand in ligands:
    print(f"Reading binding to {ligand=} from {binding_csvs[ligand]=}")
    max_std = max_binding_stds[ligand]  # referenced as `@max_std` in the query below
    col_name = binding_csv_col_names[ligand]

    ligand_df = pd.read_csv(binding_csvs[ligand])
    ligand_df = ligand_df.query("times_seen >= @min_times_seen")
    ligand_df = ligand_df.query("frac_models == 1")
    ligand_df = ligand_df.query(f"`{col_name} binding_std` <= @max_std")
    ligand_df = ligand_df.rename(columns={f"{col_name} binding_median": ligand})

    # Columns from position 11 onward are listed in parentheses in the label.
    # NOTE(review): presumably these are the per-replicate values; the positional index
    # depends on the CSV layout — TODO confirm.
    replicate_cols = ligand_df.columns[11:].tolist()

    def format_label(row):
        """Return the median to 2 decimals followed by the rounded extra columns."""
        replicate_text = ", ".join(str(round(row[c], 2)) for c in replicate_cols)
        return f"{row[ligand]:.2f} ({replicate_text})"

    label_col = f"{ligand}_label"
    ligand_df = ligand_df.assign(**{label_col: ligand_df.apply(format_label, axis=1)})

    data_df = data_df.merge(
        ligand_df[["site", "wildtype", "mutant", ligand, label_col]],
        how="left",
        on=["site", "mutant", "wildtype"],
        validate="1:1",
    )

# Attach the sequential site number and region for every reference site
# (default inner merge; each site may match at most one row of the map).
print(f"Adding sequential site from {site_numbering_map=}")
site_map_df = pd.read_csv(site_numbering_map).rename(columns={"reference_site": "site"})
data_df = data_df.merge(
    site_map_df[["site", "sequential_site", "region"]],
    on="site",
    validate="many_to_one",
)

# Attach the additional per-site annotations, renamed to their display names;
# the left merge keeps sites that have no annotation row.
print(f"Adding site annotations from {addtl_site_annotations=}")
annotation_cols = ["sequential_site", *addtl_site_annotations_cols]
annotations_df = (
    pd.read_csv(addtl_site_annotations)[annotation_cols]
    .rename(columns=addtl_site_annotations_cols)
)
data_df = data_df.merge(
    annotations_df,
    how="left",
    on="sequential_site",
    validate="many_to_one",
)

# Drop rows where the mutant equals the wildtype, add a `mutation` name
# (wildtype + site + mutant), mark sites without a contact annotation as "no",
# and order rows by sequential site and mutant.
mutated_df = data_df[data_df["wildtype"] != data_df["mutant"]]
mutation_names = (
    mutated_df["wildtype"] + mutated_df["site"].astype(str) + mutated_df["mutant"]
)
data_df = (
    mutated_df
    .assign(
        **{
            "mutation": mutation_names,
            "Mxra8 contact": mutated_df["Mxra8 contact"].fillna("no"),
        }
    )
    .sort_values(["sequential_site", "mutant"])
    .reset_index(drop=True)
)

data_df

Reading cell entry from entry_csv='results/func_effects/averages/293T-Mxra8_entry_func_effects.csv'
Reading binding to ligand='mouse_Mxra8' from binding_csvs[ligand]='results/receptor_affinity/averages/mouse_Mxra8_mut_effect.csv'
Reading binding to ligand='human_Mxra8' from binding_csvs[ligand]='results/receptor_affinity/averages/human_Mxra8_mut_effect.csv'
Adding sequential site from site_numbering_map='data/site_numbering_map.csv'
Adding site annotations from addtl_site_annotations='data/addtl_site_annotations.csv'


Unnamed: 0,site,wildtype,mutant,entry,mouse_Mxra8,mouse_Mxra8_label,human_Mxra8,human_Mxra8_label,sequential_site,region,domain,Mxra8 contact,mutation
0,-1(E3),M,I,-7.5410,,,,,1,E3,,no,M-1(E3)I
1,-1(E3),M,T,-7.5630,,,,,1,E3,,no,M-1(E3)T
2,1(E3),S,A,-1.0250,-0.11910,"-0.12 (-0.06, -0.18)",0.04762,"0.05 (0.06, 0.03)",2,E3,E3,no,S1(E3)A
3,1(E3),S,C,-0.7132,-0.21170,"-0.21 (-0.44, 0.01)",-0.73310,"-0.73 (-0.61, -0.85)",2,E3,E3,no,S1(E3)C
4,1(E3),S,D,0.1852,0.02613,"0.03 (0.02, 0.04)",-0.21540,"-0.22 (-0.21, -0.22)",2,E3,E3,no,S1(E3)D
...,...,...,...,...,...,...,...,...,...,...,...,...,...
18957,439(E1),H,V,-0.4753,,,,,988,E1,E1-cytoplasmic,no,H439(E1)V
18958,439(E1),H,W,-0.2051,0.23070,"0.23 (-0.03, 0.49)",-0.28620,"-0.29 (-0.64, 0.07)",988,E1,E1-cytoplasmic,no,H439(E1)W
18959,439(E1),H,Y,-0.2293,-0.01344,"-0.01 (-0.12, 0.1)",-0.24560,"-0.25 (-0.29, -0.2)",988,E1,E1-cytoplasmic,no,H439(E1)Y
18960,440(E1),*,Q,-3.3990,0.13000,"0.13 (-0.02, 0.28)",-1.51300,"-1.51 (-2.55, -0.48)",989,E1,,no,*440(E1)Q


## Simple correlation of binding to different ligands across all mutations

In [5]:
# plot the data

# Two mouse-over selections: one keyed on the site, one on the single mutation.
hover_kwargs = dict(on="mouseover", empty=False)
site_selection = alt.selection_point(fields=["site"], **hover_kwargs)
mut_selection = alt.selection_point(fields=["mutation"], **hover_kwargs)

# Slider for the minimum entry value shown, ranging from the lowest entry in the
# data up to 0 and starting at `min_entry`.
entry_slider_binding = alt.binding_range(
    min=data_df["entry"].min(),
    max=0,
    name=f"minimum {entry_name}",
)
min_entry_slider = alt.param(
    name="min_entry_slider",
    bind=entry_slider_binding,
    value=min_entry,
)

# The base chart carries only the columns the correlation plots use.
corr_cols = [
    "mutation",
    "entry",
    "site",
    *ligands,
    *(f"{lig}_label" for lig in ligands),
]
mut_corr_base = alt.Chart(data_df[corr_cols])

# One scatter plot per unordered pair of ligands, filtered by the entry slider.
for ligand1, ligand2 in itertools.combinations(ligands, 2):

    # both binding axes share the same scale settings
    x_binding = alt.X(
        ligand1,
        title=f"binding to {ligands[ligand1]}",
        scale=alt.Scale(nice=False, padding=5),
    )
    y_binding = alt.Y(
        ligand2,
        title=f"binding to {ligands[ligand2]}",
        scale=alt.Scale(nice=False, padding=5),
    )

    hover_tooltip = [
        "mutation",
        alt.Tooltip("entry", format=".2f", title=entry_name),
        alt.Tooltip(f"{ligand1}_label", title=ligands[ligand1]),
        alt.Tooltip(f"{ligand2}_label", title=ligands[ligand2]),
    ]

    mut_corr_chart = (
        mut_corr_base
        .add_params(site_selection, mut_selection, min_entry_slider)
        .transform_filter(alt.datum["entry"] >= min_entry_slider)
        .encode(
            x_binding,
            y_binding,
            # points at the hovered site: red, more opaque, larger
            color=alt.condition(site_selection, alt.value("red"), alt.value("gray")),
            opacity=alt.condition(site_selection, alt.value(0.9), alt.value(0.15)),
            size=alt.condition(site_selection, alt.value(55), alt.value(40)),
            # the hovered mutation itself gets a thicker outline
            strokeWidth=alt.condition(mut_selection, alt.value(3), alt.value(0.6)),
            tooltip=hover_tooltip,
        )
        .mark_circle(stroke="black")
        .properties(width=175, height=175)
        .configure_axis(grid=False)
    )

    display(mut_corr_chart)

## Plot site effects on binding
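We pre-filter on the entry cutoff, then sum the positive and the negative mutation effects at each site. For each ligand, both sums are divided by the same factor: the larger of the maximum summed positive effect and the maximum magnitude of the summed negative effect across all sites. The summed absolute effect is scaled separately by its own maximum across sites for that ligand: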

In [6]:
# Keep only mutations at or above the entry cutoff.
data_filtered_df = data_df.query("entry >= @min_entry")

site_key_cols = ["site", "sequential_site", "wildtype", "region", "Mxra8 contact"]

# Long format (one row per mutation and ligand), then per-site sums of the positive,
# negative and absolute binding effects for each ligand.
site_sums_df = (
    data_filtered_df
    .melt(
        id_vars=site_key_cols,
        value_vars=ligands,
        var_name="ligand",
        value_name="effect",
    )
    .groupby(["ligand", *site_key_cols], as_index=False, dropna=False)
    .aggregate(
        positive_effect=pd.NamedAgg("effect", lambda s: s.clip(lower=0).sum()),
        negative_effect=pd.NamedAgg("effect", lambda s: s.clip(upper=0).sum()),
        absolute_effect=pd.NamedAgg("effect", lambda s: s.abs().sum()),
    )
)

# scale by min / max: positive and negative sums share one per-ligand factor, the larger
# of the max positive sum and the magnitude of the min negative sum; absolute sums are
# divided by their own per-ligand maximum
ligand_groups = site_sums_df.groupby("ligand")
effect_norm = pd.concat(
    [
        ligand_groups["positive_effect"].transform("max"),
        -ligand_groups["negative_effect"].transform("min"),
    ],
    axis=1,
).max(axis=1)

site_df = site_sums_df.assign(
    positive_effect=site_sums_df["positive_effect"] / effect_norm,
    negative_effect=site_sums_df["negative_effect"] / effect_norm,
    absolute_effect=(
        site_sums_df["absolute_effect"]
        / ligand_groups["absolute_effect"].transform("max")
    ),
)

In [7]:
chart_width = 600

# x-axis tick labels: every 130th row, starting at row 50, of `site_df` ordered by
# sequential site
site_tick_labels = (
    site_df[["sequential_site", "site"]]
    .sort_values("sequential_site")["site"]
    .iloc[50::130]
)

# fixed colour for each `Mxra8 contact` category
contact_color = alt.Color(
    "Mxra8 contact",
    scale=alt.Scale(
        domain=["no", "wrapped", "intraspike", "interspike"],
        range=["gray", "red", "purple", "orange"],
    ),
)

# one row of bars per ligand, labelled "binding to <ligand display name>"
ligand_row = alt.Row(
    "ligand_name",
    title=None,
    header=alt.Header(labelFontSize=11, labelFontStyle="bold", labelPadding=2),
    spacing=5,
)

site_binding_source = site_df.assign(
    ligand_name="binding to " + site_df["ligand"].map(ligands)
)

# Each bar spans from the site's scaled negative effect to its scaled positive effect.
site_binding_chart = (
    alt.Chart(site_binding_source)
    .encode(
        alt.X(
            "site",
            sort=alt.SortField("sequential_site"),
            axis=alt.Axis(values=site_tick_labels, labelAngle=0, titleFontSize=11),
        ),
        alt.Y("positive_effect", title=None, scale=alt.Scale(nice=False, padding=4)),
        alt.Y2("negative_effect", title=None),
        contact_color,
        ligand_row,
        tooltip=[
            "site",
            "wildtype",
            alt.Tooltip("positive_effect", format=".2f"),
            alt.Tooltip("negative_effect", format=".2f"),
            "Mxra8 contact",
        ],
    )
    .mark_bar(opacity=1, width=1)
    .properties(width=chart_width, height=128)
)

Make overlay bar with regions:

In [8]:
# Coloured strip: one rectangle per sequential site, coloured by its region.
region_df = site_df[["sequential_site", "region"]].drop_duplicates()
region_colors = alt.Scale(range=["AliceBlue", "CadetBlue", "CadetBlue", "AliceBlue"])

region_chart = (
    alt.Chart(region_df)
    .encode(
        alt.X("sequential_site:O", axis=None),
        alt.Color("region", legend=None, scale=region_colors),
    )
    .mark_rect(opacity=0.75, strokeWidth=0)
    .properties(width=chart_width)
)

# Region names are drawn at the mean sequential site of each region.
text_df = site_df.groupby("region", as_index=False).agg(x=("sequential_site", "mean"))

# quantitative x scale spanning the full range of sequential sites
site_span = (site_df["sequential_site"].min(), site_df["sequential_site"].max())

text_chart = (
    alt.Chart(text_df)
    .encode(
        alt.X("x:Q", title=None, scale=alt.Scale(domain=site_span), axis=None),
        alt.Text("region"),
    )
    .mark_text(fontWeight="bold", fontSize=11)
    .properties(width=chart_width, height=13)
)

overlay_chart = region_chart + text_chart

Combine overlay and site chart:

In [9]:
# Stack the region overlay directly above the per-site binding bars.
stacked_chart = alt.vconcat(overlay_chart, site_binding_chart, spacing=1)

# Independent colour scales keep the region colours and the contact colours separate;
# grid lines are removed and every view gets a solid black border.
site_chart = (
    stacked_chart
    .resolve_scale(color="independent")
    .configure_axis(grid=False)
    .configure_view(stroke="black", strokeOpacity=1, strokeWidth=1)
)

site_chart

## Plot correlations in site effects

In [10]:
# Reshape the site effects so each ligand is its own column, with one row per
# site and metric (positive / negative / absolute effect).
site_id_cols = ["site", "wildtype", "region", "Mxra8 contact"]
site_metric_cols = ["positive_effect", "negative_effect", "absolute_effect"]

site_long_df = site_df.melt(
    id_vars=["ligand", *site_id_cols],
    value_vars=site_metric_cols,
    var_name="metric",
    value_name="effect",
)

site_corr_df = site_long_df.pivot_table(
    index=[*site_id_cols, "metric"],
    columns="ligand",
    values="effect",
).reset_index()

site_corr_df

ligand,site,wildtype,region,Mxra8 contact,metric,human_Mxra8,mouse_Mxra8
0,1(6K),A,6K,no,absolute_effect,0.105032,0.074673
1,1(6K),A,6K,no,negative_effect,-0.048771,-0.002014
2,1(6K),A,6K,no,positive_effect,0.056261,0.073745
3,1(E1),Y,E1,no,absolute_effect,0.139541,0.099551
4,1(E1),Y,E1,no,negative_effect,-0.050442,-0.076356
...,...,...,...,...,...,...,...
2797,99(E1),E,E1,no,negative_effect,-0.026468,-0.046023
2798,99(E1),E,E1,no,positive_effect,0.028378,0.036621
2799,99(E2),H,E2,no,absolute_effect,0.000000,0.012704
2800,99(E2),H,E2,no,negative_effect,0.000000,-0.012889
