# Make logo plots of top sites of antibody escape

In [None]:
import itertools

import Bio.Seq

import dmslogo

import matplotlib.colors
import matplotlib.pyplot as plt

import pandas as pd

Get data frame of escape for each antibody, and rank sites by their total absolute escape:

In [None]:
dms_summary = pd.read_csv(snakemake.input.summary)
per_antibody_escape = pd.read_csv(snakemake.input.per_antibody_escape)

# Attach the functional measurements from the summary to every escape value.
# A (site, mutant) pair occurs once per antibody in the escape data, so the
# left side may repeat keys; only the summary side must be unique. Hence
# `many_to_one`: `one_to_one` raises `MergeError` as soon as the escape data
# holds more than one antibody.
df = (
    per_antibody_escape
    [["site", "wildtype", "mutant", "antibody", "escape"]]
    .merge(
        dms_summary[["site", "sequential_site", "mutant", "spike mediated entry", "ACE2 binding"]],
        on=["site", "mutant"],
        validate="many_to_one",
    )
    # relabel each site as wildtype residue + site number (e.g. used later as tick label)
    .assign(site=lambda x: x["wildtype"] + x["site"].astype(str))
)

# Rank sites by the sum of absolute escape over all rows for that site
# (rank 1 = largest sum; tied sites share the lowest rank of the tie).
site_ranks = (
    df
    .groupby("site", as_index=False)
    .aggregate(abs_site_escape=pd.NamedAgg("escape", lambda s: s.abs().sum()))
    .assign(site_escape_rank=lambda x: x["abs_site_escape"].rank(method="min", ascending=False))
)

# `site` is the only shared column, so name it explicitly and check that the
# per-site table really has one row per site.
df = (
    df
    .merge(site_ranks, on="site", validate="many_to_one")
    .sort_values("site_escape_rank")
)

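Categorize each mutation by whether it is "adjacent", meaning the mutant amino acid can be reached from the wildtype amino acid by a single-nucleotide change.
Right now we do this using just the overall genetic code, without regard to which codon is actually present in the real nucleotide sequence (using the actual codon would be more accurate if the real nucleotide sequence is known).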

In [None]:
# All 64 nucleotide triplets.
codons = ["".join(triplet) for triplet in itertools.product("ACTG", repeat=3)]

# For each translated codon, collect the translations of every codon that
# differs from it at exactly one nucleotide position. Any codon encoding the
# parent residue counts, not just the one present in the real sequence.
adjacent_aas = {}
for parent_codon in codons:
    parent_aa = str(Bio.Seq.Seq(parent_codon).translate())
    for mut_codon in codons:
        n_mismatches = sum(a != b for a, b in zip(parent_codon, mut_codon))
        if n_mismatches != 1:
            continue
        mut_aa = str(Bio.Seq.Seq(mut_codon).translate())
        adjacent_aas.setdefault(parent_aa, set()).add(mut_aa)

# Flag rows whose mutant residue is one nucleotide change away from wildtype.
df["adjacent_mutation"] = df.apply(
    lambda row: row["mutant"] in adjacent_aas[row["wildtype"]],
    axis=1,
)

Make a column to color the logo plots by functional effect of mutations, clipping entry scores to be in a reasonable range:

In [None]:
color_col = "spike mediated entry"

# Clip the coloring variable to [-3, 0], then drop rows where it is missing.
df = df.assign(color_val=df[color_col].clip(lower=-3, upper=0))
df = df[df["color_val"].notnull()]

# Add color column for logo plots
# Create color palette
def color_gradient_hex(start, end, n):
    """Return a list of `n` hex colors forming a gradient from `start` to `end`.

    Color function from polyclonal.
    """
    gradient = matplotlib.colors.LinearSegmentedColormap.from_list(
        name="_", colors=[start, end], N=n
    )
    # look up each of the n entries of the colormap by integer index
    rgba_rows = gradient(list(range(0, n)))
    return [matplotlib.colors.rgb2hex(rgba) for rgba in rgba_rows]

# 50-step white-to-black palette spanning the observed range of `color_val`.
gray_cmap = matplotlib.colors.ListedColormap(
    color_gradient_hex("white", "black", n=50)
)
color_map = dmslogo.colorschemes.ValueToColorMap(
    minvalue=df["color_val"].min(),
    maxvalue=df["color_val"].max(),
    cmap=gray_cmap,
)

# Per-row color used by the logo plots.
df = df.assign(color=df["color_val"].map(color_map.val_to_color))

# Draw the scale bar for the color mapping.
_ = color_map.scale_bar(orientation="horizontal", label=color_col)

Now make logo plots for all amino-acid mutations and just adjacent ones:

In [None]:
# Two stacked panels per antibody: all mutations, then adjacent mutations only.
nrows = df["antibody"].nunique() * 2
top_n = 12  # show top this many sites
fig, axes = plt.subplots(nrows, 1)
fig.subplots_adjust(hspace=0.75)
fig.set_size_inches(0.6 * (top_n + 1), 3 * nrows)

for i, (antibody, antibody_df) in enumerate(df.groupby("antibody")):
    # keep only the top-ranked sites
    antibody_df = antibody_df.query("site_escape_rank <= @top_n")

    # Antibody i owns rows 2*i and 2*i + 1. Indexing the first panel with
    # plain `i` would draw it over the previous antibody's second panel and
    # leave the later even rows empty.
    panels = [
        (axes[2 * i], antibody_df, "all mutations"),
        (axes[2 * i + 1], antibody_df.query("adjacent_mutation"), "adjacent mutations"),
    ]
    for ax, logo_df, label in panels:
        _ = dmslogo.draw_logo(
            logo_df,
            x_col="sequential_site",
            xtick_col="site",
            letter_col="mutant",
            letter_height_col="escape",
            color_col="color",
            ax=ax,
            title=f"{antibody} {label}",
        )

In [None]:
# Write the figure to the PDF output path supplied by Snakemake.
fig.savefig(snakemake.output.pdf)