### Escape logo plots
Read in the antibody escape file, identify top escape mutations by region, and plot escape logo plots for these regions.

In [None]:
import pandas as pd

import altair as alt

import numpy as np

import matplotlib

matplotlib.rcParams["svg.fonttype"] = "none"

import dmslogo

from dmslogo.colorschemes import CBPALETTE

from dmslogo.colorschemes import ValueToColorMap

_ = alt.data_transformers.disable_max_rows()


In [None]:
# Fractions that scale the per-antibody maxima when sites are flagged for
# retention further down in the notebook.
MAX_ESCAPE_FRAC = 0.5
SUM_ESCAPE_FRAC = 0.75

# Manual Jupyter run (no snakemake): load the combined escape table directly
# from its path relative to this notebook.
results_df = pd.read_csv(
    "../../results/filtered_data/antibody_escape/combined/"
    "escape_minimum_mutation_distance.csv"
)
display(results_df)

In [None]:
# Largest `escape_mean` observed at each (antibody, site) pair.
grouped_max_site = results_df.groupby(["antibody", "site"]).agg(
    max_escape_site=("escape_mean", "max"),
)
grouped_max_site = grouped_max_site.reset_index()

# Largest of those per-site maxima for each antibody.
grouped_max_total = grouped_max_site.groupby(["antibody"]).agg(
    max_escape_antibody=("max_escape_site", "max"),
)
grouped_max_total = grouped_max_total.reset_index()

# Attach both maxima to every row of the input table; the left joins keep all
# rows of `results_df`.
merged_df = results_df.merge(grouped_max_site, on=["antibody", "site"], how="left")
merged_df = merged_df.merge(grouped_max_total, on=["antibody"], how="left")
display(merged_df.head(5))

In [None]:
# Total `escape_mean` over all rows at each (antibody, site) pair.
grouped_sum = results_df.groupby(["antibody", "site"]).agg(
    sum_escape_site=("escape_mean", "sum"),
)
grouped_sum = grouped_sum.reset_index()

# Largest per-site total for each antibody.
grouped_sum_max = grouped_sum.groupby(["antibody"]).agg(
    sum_escape_antibody=("sum_escape_site", "max"),
)
grouped_sum_max = grouped_sum_max.reset_index()

# Add the per-site sum and the per-antibody maximum sum as columns on the
# running table; left joins keep every existing row.
merged_df = merged_df.merge(grouped_sum, on=["antibody", "site"], how="left")
merged_df = merged_df.merge(grouped_sum_max, on=["antibody"], how="left")
display(merged_df.head(5))


In [None]:
# Minimum escape an individual mutation must exceed, as a fraction of the
# antibody's largest per-site maximum, to be retained. Named here so it sits
# alongside MAX_ESCAPE_FRAC and SUM_ESCAPE_FRAC instead of being an inline 0.1.
MIN_MUTATION_ESCAPE_FRAC = 0.1

# A row is retained when its site is a "top" site for the antibody (by either
# the max-based or the sum-based criterion) AND the mutation itself clears the
# per-mutation escape floor.
merged_df = merged_df.assign(
    retain=lambda x: (
        (
            (x["max_escape_site"] >= x["max_escape_antibody"] * MAX_ESCAPE_FRAC)
            | (x["sum_escape_site"] >= x["sum_escape_antibody"] * SUM_ESCAPE_FRAC)
        )
        & (x["escape_mean"] > x["max_escape_antibody"] * MIN_MUTATION_ESCAPE_FRAC)
    )
)

display(merged_df.query('retain and antibody == "1F2"'))

# Unique retained sites per antibody (Series indexed by antibody).
sites_list = merged_df.query('retain').groupby('antibody')['site'].unique()
# Spot-check a single antibody. Show an empty list rather than raising KeyError
# when that antibody has no retained rows (or is absent from the input).
display(sites_list['12B2'].tolist() if '12B2' in sites_list.index else [])

In [None]:
# x-axis tick label: wild-type identity immediately followed by the site,
# both converted to strings.
merged_df["wildtype_site"] = (
    merged_df["wildtype"].astype(str) + merged_df["site"].astype(str)
)

# Letter colours are driven by `effect`, limited to the interval [-2, 0].
merged_df["clip"] = merged_df["effect"].clip(lower=-2, upper=0)
display(merged_df.head(5))

# The colour scale spans [-2.5, 0] on the "Greens" colormap, so clipped values
# (which stop at -2) never reach the very bottom of the scale.
map1 = ValueToColorMap(minvalue=-2.5, maxvalue=0, cmap="Greens")
merged_df["color"] = merged_df["clip"].map(map1.val_to_color)

display(merged_df.head(5))

# Export only the retained rows for the website.
merged_df.query('retain').to_csv(
    '../../results/for_website/top_antibody_escape_min_mutants.csv',
    index=False,
)

In [None]:
def generate_facet_logo_plot(df, output_file_name=None):
    """Draw a faceted logo plot and optionally save it as an SVG file.

    One facet row is drawn per value of the ``antibody`` column, with sites
    along the x-axis and independent y-limits for each row.

    Args:
        df: Data frame passed to ``dmslogo.facet_plot``. The plot reads the
            columns ``site``, ``antibody``, ``mutant``, ``color``,
            ``wildtype_site`` and ``escape_mean``.
        output_file_name: Path of the SVG file to write. When ``None`` (the
            default) the figure is created but nothing is written to disk.
    """
    draw_logo_kwargs = {
        "letter_col": "mutant",
        "color_col": "color",
        "xtick_col": "wildtype_site",
        "letter_height_col": "escape_mean",
        "xlabel": "",
        "clip_negative_heights": True,
    }
    fig, ax = dmslogo.facet_plot(
        data=df,
        x_col="site",
        gridrow_col="antibody",
        share_ylim_across_rows=False,
        show_col=None,
        draw_logo_kwargs=draw_logo_kwargs,
    )

    # The default of None used to be handed straight to savefig, which cannot
    # write without a target; only save when a destination was supplied.
    if output_file_name is not None:
        fig.savefig(output_file_name, bbox_inches="tight", format="svg")



In [None]:
# For every antibody, write one logo plot of all retained mutations and a
# second one restricted to retained mutations with `min_mutations == 1`.
for antibody in merged_df['antibody'].unique().tolist():
    antibody_df = merged_df.query('antibody == @antibody and retain')
    one_mutant_df = antibody_df.query('min_mutations == 1')

    plots = (
        (antibody_df, f'../../logo_{antibody}.svg'),
        (one_mutant_df, f"../../logo_{antibody}_one_mutant.svg"),
    )
    for plot_df, svg_path in plots:
        # An antibody may have no rows left after filtering (e.g. no retained
        # single-mutation variants); skip it instead of handing an empty frame
        # to the plotting routine, so the remaining antibodies still get plotted.
        if plot_df.empty:
            print(f"No retained mutations for {svg_path}; skipping plot.")
            continue
        generate_facet_logo_plot(plot_df, svg_path)