# Make logoplots showing key sites of escape for each antibody

Get variables from `snakemake`:

In [None]:
# File paths taken from the `snakemake` object for this notebook.
# Inputs: the CSV of escape values and the CSV of phenotypes.
escape_csv, phenotypes_csv = (
    snakemake.input.escape_csv,
    snakemake.input.phenotypes_csv,
)

# Output: directory into which the plots are saved.
logoplot_subdir = snakemake.output.logoplot_subdir

Import Python modules:

In [None]:
import functools
import itertools
import operator
import os

import altair as alt

import dmslogo

import matplotlib
import matplotlib.pyplot as plt

import pandas as pd

# Create the output directory up front; no error if it already exists.
os.makedirs(logoplot_subdir, exist_ok=True)

# Lift altair's limit on the number of rows allowed in a chart's data.
_ = alt.data_transformers.disable_max_rows()

Some configuration options.
**Set configuration here, do not hardcode it in the notebook in later cells**.

In [None]:
# Only show positive escape (negative escape values are set to zero)?
positive_escape_only = True
# Floor for cell-entry effects: values below this are set to this value.
min_cell_entry = -5
# Set escape to zero for mutations with cell entry below `min_cell_entry`?
drop_low_cell_entry = True

# A site is highlighted if it meets ANY of these criteria (an OR operation).
highlight_params = dict(
    n_site_escape=6,  # among the top this many sites ranked by site escape
    n_top_mut_escape=6,  # among the top this many sites ranked by top mutation escape
    min_site_escape=12,  # site escape >= this
    min_top_mut_escape=2.5,  # top mutation escape >= this
)

# Amino acids to plot.
aas = [*dmslogo.colorschemes.AA_CHARGE]

# Antibody groups that each get a combined logo plot, in addition to the
# per-antibody logo plots.
antibody_groups = dict(
    region_3_antibodies=["17C7", "CR4098", "RVA122", "RVC58"],
    region_1_antibodies=["RVC20", "CR57"],
)

# File extensions on saved plots (one file is written per extension).
file_extensions = [".svg", ".pdf"]

Read escape values.
Mutations with missing escape (or with cell entry below the threshold, if those are being dropped) are set to zero.
Setting these to zero is the same as removing them for the logoplots and site-escape sum plots used here, but if you use other plots you should be aware that missing mutations have escape of zero in this dataframe.

In [None]:
# Read the per-mutation escape values and attach the sequential site numbering
# and cell-entry effect of each mutation from the phenotypes file. The left
# merge keeps every escape row; `validate` raises if the phenotypes file has
# duplicate (site, wildtype, mutant) keys. Only mutants in `aas` are retained.
# NOTE(review): an escape row with no match in the phenotypes file ends up with
# missing `sequential_site` and `cell entry` -- confirm the inputs never do this.
escape = (
    pd.read_csv(escape_csv)
    [["antibody", "site", "wildtype", "mutant", "escape"]]
    .merge(
        pd.read_csv(phenotypes_csv)[["site", "sequential_site", "wildtype", "mutant", "cell entry"]],
        on=["site", "wildtype", "mutant"],
        how="left",
        validate="many_to_one",
    )
    .query("mutant in @aas")
)

if drop_low_cell_entry:
    print(f"Setting to zero escape for mutations with cell entry values below {min_cell_entry=}")
    # A missing cell entry compares False here, so its escape is zeroed as well.
    escape["escape"] = escape["escape"].where(escape["cell entry"] >= min_cell_entry, 0)

print(f"Flooring cell entry values to {min_cell_entry=}")
escape["cell entry"] = escape["cell entry"].clip(lower=min_cell_entry)

if positive_escape_only:
    print("Setting negative escape values to zero")
    escape["escape"] = escape["escape"].clip(lower=0)

# Pad missing escape values to zero: build a zero-escape row for every
# combination of sequential site, amino acid, and antibody, and annotate it
# with the site label and wildtype of that sequential site.
antibodies = escape["antibody"].unique().tolist()
print(f"Read escape for the following {len(antibodies)} antibodies:\n{antibodies=}")
escape_fill_zero = (
    pd.DataFrame(
        [[*tup, 0] for tup in itertools.product(escape["sequential_site"].unique(), aas, antibodies)],
        columns=["sequential_site", "mutant", "antibody", "escape"]
    )
    .merge(
        escape[["sequential_site", "site", "wildtype"]].drop_duplicates(),
        on="sequential_site",
        validate="many_to_one",
    )
)

# Outer-merge the measured values onto the zero-padded grid. After the merge
# `escape_x` is the measured escape and `escape_y` the zero padding, so the
# measured value is kept wherever it exists. Missing cell entry is set to 0.
escape = (
    escape
    [["sequential_site", "mutant", "antibody", "cell entry", "escape"]]
    .merge(
        escape_fill_zero,
        on=["sequential_site", "mutant", "antibody"],
        how="outer",
        validate="one_to_one",
    )
    .assign(
        **{
            "cell entry": lambda x: x["cell entry"].fillna(0),
            "escape": lambda x: x["escape_x"].fillna(x["escape_y"]),
        }
    )
    .drop(columns=["escape_x", "escape_y"])
)

To choose which sites to highlight, we get the total magnitude of the site escape and largest magnitude mutation at each site for each antibody.
Then plot to indicate which sites are being shown:

In [None]:
# Summarize escape for each antibody and site in two ways: the summed magnitude
# over all mutations (`site_escape`) and the largest single-mutation magnitude
# (`top_mut_escape`). The result is long-form, with one row per antibody, site,
# and escape type.
per_site_escape = (
    escape
    .groupby(["antibody", "site"], as_index=False)
    .agg(
        site_escape=("escape", lambda vals: vals.abs().sum()),
        top_mut_escape=("escape", lambda vals: vals.abs().max()),
    )
    .melt(id_vars=["antibody", "site"], var_name="escape_type", value_name="escape")
)

# Rank sites within each antibody and escape type (1 = largest escape; tied
# sites share the smallest rank of the tie).
per_site_escape["rank"] = (
    per_site_escape
    .groupby(["antibody", "escape_type"])["escape"]
    .rank(method="min", ascending=False)
    .astype(int)
)

In [None]:
# Flag the rows to highlight. For each escape type, a row of that type is
# flagged if its escape reaches the `min_<type>` cutoff OR its rank is within
# the top `n_<type>`; the per-type masks are then OR-ed together.
per_site_escape["highlight"] = functools.reduce(
    operator.or_,
    (
        per_site_escape["escape_type"].eq(kind)
        & (
            per_site_escape["escape"].ge(highlight_params[f"min_{kind}"])
            | per_site_escape["rank"].le(highlight_params[f"n_{kind}"])
        )
        for kind in ("site_escape", "top_mut_escape")
    ),
)

print("Number of sites to highlight for each antibody:")
(
    per_site_escape[per_site_escape["highlight"]]
    .groupby("antibody")
    .agg(n_sites_to_highlight=("site", "nunique"))
)

In [None]:
# Interactive chart of per-site escape, one column per escape type. Hovering
# over a point thickens the outline of every point for that site.
site_selection = alt.selection_point(fields=["site"], on="mouseover", empty=False)

per_site_escape_chart = (
    alt.Chart(per_site_escape)
    .mark_circle(size=40, strokeOpacity=1, fillOpacity=0.4, stroke="black")
    .encode(
        x=alt.X("escape", title=None),
        y=alt.Y("antibody"),
        # random vertical offset within each antibody's row (see `jitter` below)
        yOffset=alt.YOffset("jitter:Q", scale=alt.Scale(domain=[-1.2, 1.2])),
        column=alt.Column(
            "escape_type",
            title=None,
            header=alt.Header(orient="bottom", labelFontSize=11, labelFontStyle="bold"),
        ),
        color=alt.Color("highlight", title="site to highlight"),
        strokeWidth=alt.condition(site_selection, alt.value(3), alt.value(0.25)),
        tooltip=["site", "rank", "antibody"],
    )
    .transform_calculate(jitter="random() - 0.5")
    .add_params(site_selection)
    .resolve_scale(x="independent")
    .properties(width=275, height=alt.Step(30))
    .configure_axis(grid=False)
)

per_site_escape_chart

Add the sites to highlight for each antibody to the escape data frame, and get the total positive and negative escape at each site:

In [None]:
# Total positive and total negative escape at each site for each antibody,
# repeated on every row of that antibody and site.
escape = escape.assign(
    positive_site_escape=lambda df: (
        df.groupby(["antibody", "site"])["escape"]
        .transform(lambda vals: vals[vals > 0].sum())
    ),
    negative_site_escape=lambda df: (
        df.groupby(["antibody", "site"])["escape"]
        .transform(lambda vals: vals[vals < 0].sum())
    ),
)

# Attach the highlight flag for each antibody and site: True if any row of
# `per_site_escape` for that antibody and site is flagged. An existing
# `highlight` column is dropped first (presumably so the cell can be re-run).
escape = (
    escape
    .drop(columns="highlight", errors="ignore")
    .merge(
        per_site_escape
        .groupby(["antibody", "site"], as_index=False)
        .agg({"highlight": "any"}),
        on=["antibody", "site"],
        validate="many_to_one",
    )
)

Now make line and logo plots for each antibody.

First make color scale to color logos by cell entry:

In [None]:
# Color scale mapping cell-entry effects from `min_cell_entry` up to 0.
colormap = dmslogo.colorschemes.ValueToColorMap(
    minvalue=min_cell_entry,
    maxvalue=0,
    cmap=matplotlib.colors.LinearSegmentedColormap.from_list(
        "white_to_green",
        [(1, 0.985, 0.737), (0.0545, 0.4313, 0.054)],
    ),
)

# Integer ticks spanning the scale; they are the same for both orientations.
# The end labels use ≤ / ≥ because values are clipped to the scale's range
# before being mapped to colors (see the `letter_color` assignment below).
assert colormap.minvalue == int(colormap.minvalue), "code requires integer minvalue for color scale"
assert colormap.maxvalue == int(colormap.maxvalue), "code requires integer maxvalue for color scale"
ticks = list(range(int(colormap.minvalue), int(colormap.maxvalue) + 1))
ticklabels = [f"≤{ticks[0]}", *(str(t) for t in ticks[1: -1]), f"≥{ticks[-1]}"]

# Draw, display, and save a scale bar in each orientation.
for orientation in ["vertical", "horizontal"]:
    scale_fig, scale_ax = colormap.scale_bar(
        orientation=orientation,
        label="effect on cell entry",
    )
    # the ticks go on whichever axis runs along the bar
    if orientation == "horizontal":
        ax = scale_ax.xaxis
    else:
        ax = scale_ax.yaxis
    ax.set_ticks(ticks)
    ax.set_ticklabels(ticklabels)
    display(scale_fig)

    for ext in file_extensions:
        fname = os.path.join(logoplot_subdir, f"scalebar_{orientation}{ext}")
        print(f"Saving to {fname}")
        scale_fig.savefig(fname)

    plt.close(scale_fig)

# Color of each mutation's letter, from its cell-entry effect clipped to the
# range of the color scale.
escape["letter_color"] = colormap.val_to_color(
    escape["cell entry"].clip(lower=colormap.minvalue, upper=colormap.maxvalue)
)

Draw plots for each antibody group and for each individual antibody:

In [None]:
# One figure per entry: every configured antibody group, plus each antibody on
# its own (keyed by the antibody name). Each figure has one row per antibody.
for antibody_group, antibodies_to_plot in {
    **antibody_groups,
    **{a: [a] for a in antibodies},
}.items():
    print(f"\nMaking plot for {antibody_group=} with {antibodies_to_plot=}")

    # Keep only the antibodies in this plot, and show a site for all of them
    # if it is highlighted for any of them.
    df = (
        escape[escape["antibody"].isin(antibodies_to_plot)]
        .assign(highlight=lambda d: d.groupby("site")["highlight"].transform("any"))
    )

    fig, _ = dmslogo.facet_plot(
        df,
        x_col="sequential_site",
        show_col="highlight",
        gridrow_col="antibody",
        # line plot of the per-site escape totals
        draw_line_kwargs={
            "height_col": "positive_site_escape",
            "height_col2": None if positive_escape_only else "negative_site_escape",
            "xtick_col": "site",
            "ylabel": "escape",
            "widthscale": 0.4,
        },
        # logo plot with letter heights given by escape and colors by cell entry
        draw_logo_kwargs={
            "letter_col": "mutant",
            "letter_height_col": "escape",
            "xtick_col": "site",
            "color_col": "letter_color",
            "widthscale": 0.75,
        },
        share_xlabel=True,
        share_ylabel=True,
        height_per_ax=1.9,
        wspace=0.7,
        share_ylim_across_rows=False,
    )

    # save once per configured extension, then show and release the figure
    for ext in file_extensions:
        fname = os.path.join(logoplot_subdir, f"{antibody_group}{ext}")
        print(f"Saving to {fname}")
        fig.savefig(fname)

    display(fig)
    plt.close(fig)