In [None]:
# this cell is tagged as `parameters` for papermill parameterization
# All values are `None` placeholders that papermill is expected to overwrite.

# Input CSVs read with `pd.read_csv` in the cells below.
binding_effects_csv = None  # binding measurements; merged on `site`
site_numbering_map_csv = None  # site annotations; merged on `reference_site`
entry_effects_csv = None  # cell entry effects; its `times_seen` column is dropped on load

# Initial positions of the interactive filter sliders in the charts.
init_min_times_seen = None  # "minimum times seen" slider
init_min_n_libraries = None  # "minimum number of libraries" slider
init_binding_std = None  # "maximum standard deviation between libraries" slider

# Output paths passed to `Chart.save` for the two saved charts.
library_binding_corr = None  # library correlation chart over all mutations
distance_library_binding_corr = None  # library correlation chart faceted by ACE2 distance class

In [None]:
import functools
import itertools
import tempfile
import urllib.request
import math
import operator
import os

import altair as alt

import numpy

import pandas as pd
import polyclonal.pdb_utils

# Lift Altair's row limit on chart datasets so the full mutation tables can be
# plotted; the result is bound to `_` so it is not echoed as the cell output.
_ = alt.data_transformers.disable_max_rows()

In [None]:
# Load the cell entry effects. Their `times_seen` column is discarded so that the
# later merge keeps only the `times_seen` column of the binding data.
entry_effects = pd.read_csv(entry_effects_csv).drop(columns="times_seen")

In [None]:
# Load the binding measurements from the papermill-supplied CSV path.
binding_effects = pd.read_csv(binding_effects_csv)

In [None]:
# Load the site numbering map; it is joined to the binding data on `reference_site`.
site_numbering_map = pd.read_csv(site_numbering_map_csv)

In [None]:
# Annotate every binding measurement with its site-numbering information by
# matching the binding `site` to the map's `reference_site`. A left join keeps
# all binding rows, including those with no entry in the map.
merged = pd.merge(
    binding_effects,
    site_numbering_map,
    how="left",
    left_on="site",
    right_on="reference_site",
)

# Row-wise standard deviation across the two per-library binding columns
# (pandas default, i.e. sample standard deviation with ddof=1).
lib_binding_cols = ["Lib1-250517-monomeric_ACE2", "Lib2-250517-monomeric_ACE2"]
merged["Lib_binding_std"] = merged[lib_binding_cols].std(axis=1)

merged.head()

In [None]:
# Add the cell entry effects for each mutation; a left join keeps binding rows
# that have no matching entry measurement.
entry_merge_keys = ["site", "wildtype", "mutant"]
merged = pd.merge(merged, entry_effects, how="left", on=entry_merge_keys)

In [None]:
# Interactive scatter of Lib-1 versus Lib-2 ACE2 binding over all mutations,
# with sliders to filter the points and a text label giving Pearson's r.

# Columns holding the per-library binding measurements (x and y axes).
lib1_col = "Lib1-250517-monomeric_ACE2"
lib2_col = "Lib2-250517-monomeric_ACE2"

# Hovering over a point selects its `mutation`; nothing is selected initially.
mut_selection = alt.selection_point(fields=["mutation"], on="mouseover", empty=False)

# Filter sliders; their starting positions come from the papermill parameters.
times_seen_slider = alt.param(
    bind=alt.binding_range(
        name="minimum times seen",
        min=1,
        max=min(10, merged["times_seen"].max()),
        step=0.5,
    ),
    value=init_min_times_seen,
)

n_libraries_slider = alt.param(
    bind=alt.binding_range(
        name="minimum number of libraries",
        min=1,
        max=merged["n_models"].max(),
        step=1,
    ),
    value=init_min_n_libraries,
)

lib_std_slider = alt.param(
    bind=alt.binding_range(
        name="maximum standard deviation between libraries",
        min=0,
        max=merged["Lib_binding_std"].max(),
        step=0.05,
    ),
    value=init_binding_std,
)

# Shared base: data, slider-driven filters, axes, tooltip and chart size.
corr_base = alt.Chart(merged)
for row_filter in (
    alt.datum["times_seen"] >= times_seen_slider,
    alt.datum["n_models"] >= n_libraries_slider,
    alt.datum["Lib_binding_std"] <= lib_std_slider,
):
    corr_base = corr_base.transform_filter(row_filter)
corr_base = (
    corr_base
    .encode(
        x=alt.X(
            lib1_col,
            title="Lib-1 ACE2 binding",
            scale=alt.Scale(nice=False, padding=4),
        ),
        y=alt.Y(
            lib2_col,
            title="Lib-2 ACE2 binding",
            scale=alt.Scale(nice=False, padding=4),
        ),
        tooltip=list(merged.columns),
    )
    .add_params(times_seen_slider, n_libraries_slider, lib_std_slider, mut_selection)
    .properties(width=170, height=170)
)

# Points: the hovered mutation is drawn opaque, larger and with a red outline.
corr_scatter = corr_base.mark_circle(color="black", stroke="red").encode(
    opacity=alt.condition(mut_selection, alt.value(1), alt.value(0.1)),
    strokeWidth=alt.condition(mut_selection, alt.value(2), alt.value(0)),
    size=alt.condition(mut_selection, alt.value(50), alt.value(25)),
)

# Correlation label: r is the square root of the regression R^2, signed by the
# fitted slope (`coef[1]`), and is drawn at a fixed pixel position.
root_r_squared = alt.expr.sqrt(alt.datum["rSquared"])
corr_r = (
    corr_base
    .transform_regression(lib1_col, lib2_col, params=True)
    .transform_calculate(
        r=alt.expr.if_(alt.datum["coef"][1] > 0, root_r_squared, -root_r_squared),
        r_text="r = " + alt.expr.format(alt.datum["r"], ".2f"),
    )
    .mark_text(size=14, align="left", color="blue")
    .encode(text="r_text:N", x=alt.value(5), y=alt.value(10))
)

corr_chart = alt.layer(corr_scatter, corr_r).configure_axis(grid=False)

print(f"Saving to {library_binding_corr}")
corr_chart.save(library_binding_corr)

corr_chart

In [None]:
# Plotting copy of the merged data in which `region` is collapsed to two labels:
# "RBD" stays as is and every other value becomes "not RBD".
merged_binding = merged.assign(
    region=lambda df: numpy.where(df["region"] == "RBD", "RBD", "not RBD")
)

In [None]:
# Structural annotation: distance of each RBD residue to ACE2.
# A residue counts as ACE2 proximal when its closest CA-CA distance is at most
# this cutoff (presumably in angstroms, the unit of PDB coordinates - confirm).
ace2_proximal_cutoff = 15

# Download PDB entry 6M0J into a temporary file and pull out the CA atom
# coordinates of chain A (ACE2) and chain E (RBD).
pdb_url = "https://files.rcsb.org/download/6M0J.pdb"
with tempfile.NamedTemporaryFile() as pdb_file:
    urllib.request.urlretrieve(pdb_url, pdb_file.name)
    coords_df = polyclonal.pdb_utils.extract_atom_locations(
        pdb_file.name,
        ["A", "E"],
        target_atom="CA",
    )

# Get the closest distance from each residue in chain E (RBD) to any residue in
# chain A (ACE2).
#
# Every RBD residue is paired with every ACE2 residue (cross join), the
# Euclidean distance is computed for each pair, and the minimum per RBD `site`
# is kept. The distance uses whole-column arithmetic rather than a row-wise
# `apply`, which would run a Python function once per residue pair.
dist_df = (
    coords_df
    .query("chain == 'E'")
    [["site", "x", "y", "z"]]
    .merge(
        (
            coords_df
            .query("chain == 'A'")
            [["site", "x", "y", "z"]]
            # prefix the ACE2 columns so they do not clash with the RBD ones
            .rename(columns={c: f"ACE2_{c}" for c in ["site", "x", "y", "z"]})
        ),
        how="cross",
    )
    .assign(
        distance=lambda x: numpy.sqrt(
            sum((x[c] - x[f"ACE2_{c}"]) ** 2 for c in ["x", "y", "z"])
        )
    )
    .groupby("site", as_index=False)
    .aggregate({"distance": "min"})
)

In [None]:
# Keep only rows whose site label consists solely of digits, then store `site`
# as an integer column.
has_numeric_site = merged_binding["site"].map(lambda label: str(label).isdigit())
merged_binding = merged_binding.loc[has_numeric_site].assign(
    site=lambda df: df["site"].astype(int)
)

In [None]:
# Attach the per-site minimum ACE2 distance (joined on the shared columns, at
# most one distance row per site), then classify every row into one of three
# labels for colouring and faceting.
binding_with_distance = merged_binding.merge(
    dist_df,
    how="left",
    validate="many_to_one",
)
in_rbd = binding_with_distance["region"] == "RBD"
near_ace2 = binding_with_distance["distance"] <= ace2_proximal_cutoff

# RBD rows with a missing distance fail the `<=` comparison and are therefore
# labelled "RBD ACE2 distal".
merged_binding = binding_with_distance.assign(
    ACE2_distance=numpy.where(
        in_rbd,
        numpy.where(near_ace2, "RBD ACE2 proximal", "RBD ACE2 distal"),
        "non-RBD",
    )
)

In [None]:
# Interactive scatter of Lib-1 versus Lib-2 ACE2 binding, with points coloured
# by their ACE2 distance class and a text label giving Pearson's r.

# Hovering over a point selects its `mutation`; nothing is selected initially.
mut_selection = alt.selection_point(fields=["mutation"], on="mouseover", empty=False)

# Filter sliders; their starting positions come from the papermill parameters.
times_seen_slider = alt.param(
    value=init_min_times_seen,
    bind=alt.binding_range(
        name="minimum times seen",
        min=1,
        step=0.5,
        max=min(10, merged_binding["times_seen"].max()),
    ),
)

n_libraries_slider = alt.param(
    value=init_min_n_libraries,
    bind=alt.binding_range(
        name="minimum number of libraries",
        min=1,
        step=1,
        max=merged_binding["n_models"].max(),
    ),
)

lib_std_slider = alt.param(
    # use the papermill parameter, as the other correlation charts do, rather
    # than a hard-coded starting value
    value=init_binding_std,
    bind=alt.binding_range(
        name="maximum standard deviation between libraries",
        min=0,
        max=merged_binding["Lib_binding_std"].max(),
        step=0.05,
    ),
)

# Shared base: data, slider-driven filters, axes, tooltip and chart size.
corr_base = (
    alt.Chart(merged_binding)
    .transform_filter(alt.datum["times_seen"] >= times_seen_slider)
    .transform_filter(alt.datum["n_models"] >= n_libraries_slider)
    .transform_filter(alt.datum["Lib_binding_std"] <= lib_std_slider)
    .encode(
        x=alt.X(
            "Lib1-250517-monomeric_ACE2",
            title="Lib-1 ACE2 binding",
            scale=alt.Scale(nice=False, padding=4),
        ),
        y=alt.Y(
            "Lib2-250517-monomeric_ACE2",
            title="Lib-2 ACE2 binding",
            scale=alt.Scale(nice=False, padding=4),
        ),
        tooltip=merged_binding.columns.tolist(),
    )
    .add_params(
        times_seen_slider,
        n_libraries_slider,
        lib_std_slider,
        mut_selection,
    )
    .properties(width=170, height=170)
)


# Points: coloured by ACE2 distance class; the hovered mutation is emphasised.
corr_scatter = (
    corr_base
    .encode(
        color=alt.Color("ACE2_distance:N", title="ACE2 distance"),
        opacity=alt.condition(mut_selection, alt.value(1), alt.value(0.5)),
        strokeWidth=alt.condition(mut_selection, alt.value(2), alt.value(0)),
        size=alt.condition(mut_selection, alt.value(50), alt.value(25)),
    )
    # no fixed mark colour, so the `color` encoding determines the point colour
    .mark_circle()
)

# Correlation label: r is the square root of the regression R^2, signed by the
# fitted slope (`coef[1]`), and is drawn at a fixed pixel position.
corr_r = (
    corr_base
    .transform_regression("Lib1-250517-monomeric_ACE2", "Lib2-250517-monomeric_ACE2", params=True)
    .transform_calculate(
        r=alt.expr.if_(
            alt.datum["coef"][1] > 0,
            alt.expr.sqrt(alt.datum["rSquared"]),
            -alt.expr.sqrt(alt.datum["rSquared"]),
        ),
        r_text="r = " + alt.expr.format(alt.datum["r"], ".2f"),
    )
    .encode(
        text="r_text:N",
        x=alt.value(5),
        y=alt.value(10),
    )
    .mark_text(size=14, align="left", color="blue")
)

corr_chart = (corr_scatter + corr_r).configure_axis(grid=False)
corr_chart

In [None]:
# Interactive scatter of Lib-1 versus Lib-2 ACE2 binding, faceted into one panel
# per ACE2 distance class, each with its own Pearson's r label. The chart is
# saved to `distance_library_binding_corr`.

# Hovering over a point selects its `mutation`; nothing is selected initially.
mut_selection = alt.selection_point(fields=["mutation"], on="mouseover", empty=False)

# Filter sliders; their starting positions come from the papermill parameters.
times_seen_slider = alt.param(
    value=init_min_times_seen,
    bind=alt.binding_range(
        name="minimum times seen",
        min=1,
        step=0.5,
        max=min(10, merged_binding["times_seen"].max()),
    ),
)

n_libraries_slider = alt.param(
    value=init_min_n_libraries,
    bind=alt.binding_range(
        name="minimum number of libraries",
        min=1,
        step=1,
        max=merged_binding["n_models"].max(),
    ),
)

lib_std_slider = alt.param(
    value=init_binding_std,
    bind=alt.binding_range(
        name="maximum standard deviation between libraries",
        min=0,
        max=merged_binding["Lib_binding_std"].max(),
        step=0.05,
    ),
)

# Minimum cell entry effect a mutation needs in order to be shown.
init_min_effect = -2
effect_slider = alt.param(
    value=init_min_effect,
    bind=alt.binding_range(
        name="cell entry effect",
        # The lower bound has to reach the initial value (and the most negative
        # measured effect); with a bound of 0 the starting threshold of -2 would
        # lie outside the slider's range and could not be selected again.
        min=min(init_min_effect, merged_binding["effect"].min()),
        max=merged_binding["effect"].max(),
        step=0.05,
    ),
)

# Shared base: data, slider-driven filters, axes, tooltip and chart size.
corr_base = (
    alt.Chart(merged_binding)
    .transform_filter(alt.datum["times_seen"] >= times_seen_slider)
    .transform_filter(alt.datum["n_models"] >= n_libraries_slider)
    .transform_filter(alt.datum["Lib_binding_std"] <= lib_std_slider)
    .transform_filter(alt.datum["effect"] >= effect_slider)
    .encode(
        x=alt.X(
            "Lib1-250517-monomeric_ACE2",
            title="Lib-1 ACE2 binding",
            scale=alt.Scale(nice=False, padding=4),
        ),
        y=alt.Y(
            "Lib2-250517-monomeric_ACE2",
            title="Lib-2 ACE2 binding",
            scale=alt.Scale(nice=False, padding=4),
        ),
        tooltip=merged_binding.columns.tolist(),
    )
    .add_params(
        times_seen_slider,
        n_libraries_slider,
        lib_std_slider,
        mut_selection,
        effect_slider,
    )
    .properties(width=170, height=170)
)


# Points: coloured by ACE2 distance class; the hovered mutation is emphasised.
corr_scatter = (
    corr_base
    .encode(
        color=alt.Color("ACE2_distance:N", title="ACE2 distance"),
        opacity=alt.condition(mut_selection, alt.value(1), alt.value(0.5)),
        strokeWidth=alt.condition(mut_selection, alt.value(2), alt.value(0)),
        size=alt.condition(mut_selection, alt.value(50), alt.value(25)),
    )
    .mark_circle()
)

# Correlation label: r is the square root of the regression R^2, signed by the
# fitted slope (`coef[1]`), and is drawn at a fixed pixel position.
corr_r = (
    corr_base
    .transform_regression("Lib1-250517-monomeric_ACE2", "Lib2-250517-monomeric_ACE2", params=True)
    .transform_calculate(
        r=alt.expr.if_(
            alt.datum["coef"][1] > 0,
            alt.expr.sqrt(alt.datum["rSquared"]),
            -alt.expr.sqrt(alt.datum["rSquared"]),
        ),
        r_text="r = " + alt.expr.format(alt.datum["r"], ".2f"),
    )
    .encode(
        text="r_text:N",
        x=alt.value(5),
        y=alt.value(10),
    )
    .mark_text(size=14, align="left", color="blue")
)

# One column per ACE2 distance class.
corr_chart = (
    (corr_scatter + corr_r)
    .facet(
        column=alt.Column(
            "ACE2_distance:N",
            title=None,
            header=alt.Header(labelFontSize=12, labelFontWeight="bold")
        )
    )
    .configure_axis(grid=False)
    .resolve_scale()
)

print(f"Saving to {distance_library_binding_corr}")
corr_chart.save(distance_library_binding_corr)

corr_chart