In [None]:
import re
import altair as alt
import numpy as np
import pandas as pd
import scipy.stats
import httpimport

# Allow more rows for Altair: lift the default row limit on DataFrames passed
# to alt.Chart so the full escape table can be charted. Binding the return
# value to `_` makes this an assignment statement, so the cell displays nothing.
_ = alt.data_transformers.disable_max_rows()


In [None]:
# Import custom altair theme from remote github using httpimport module
def import_theme_new(
    user="bblarsen-sci",
    repo="altair_themes",
    ref="main",
    theme_name="custom_theme",
):
    """Fetch ``main_theme`` from a GitHub repo and register/enable it in Altair.

    Parameters
    ----------
    user, repo : str
        GitHub owner and repository containing a top-level ``main_theme``
        module that exposes a ``main_theme()`` function.
    ref : str
        Git ref passed to ``httpimport.github_repo``. Defaults to the moving
        ``main`` branch; pin a tag or commit to make figures reproducible.
    theme_name : str
        Name under which the theme is registered (and enabled) via
        ``alt.theme.register``.

    Security note: this downloads and EXECUTES Python code from the network
    at call time. Only point it at repositories you trust, and prefer a pinned
    ref over a branch whose contents can change underneath the notebook.

    After the first successful call, ``main_theme`` is cached in
    ``sys.modules``, so later calls in the same kernel reuse that module even
    if a different ``ref`` is passed.
    """
    with httpimport.github_repo(user, repo, ref):
        import main_theme

        # The registered callable closes over the imported module, so it keeps
        # working after the `with` block has removed the remote importer.
        @alt.theme.register(theme_name, enable=True)
        def custom_theme():
            return main_theme.main_theme()


import_theme_new()


In [None]:
# Left-to-right order of antibodies on the x axis of the summary plots.
antibody_order = [
    "12B2",
    "2D3",
    "4H3",
    "1A9",
    "1F2",
    "2B12",
]

# Per-antibody escape table; the columns relied on later in this notebook are
# `antibody`, `escape_mean` and `min_mutations`.
escape_df = pd.read_csv(
    filepath_or_buffer=(
        "../../results/filtered_data/antibody_escape/combined/"
        "escape_minimum_mutation_distance.csv"
    )
)

display(escape_df)

In [None]:
# Flag escape mutations and count them per antibody.
#
# NOTE: despite its name, `above_half_max` is an ABSOLUTE cutoff on
# `escape_mean` (ESCAPE_THRESHOLD), not half of each antibody's maximum; an
# earlier relative variant compared against `max_escape / 4`. The column name
# is kept unchanged so any downstream code that reads it still works.
ESCAPE_THRESHOLD = 0.5
# Rows whose `min_mutations` equals this value can enter the
# `above_min_mut_and_effect` set.
MIN_MUTATIONS = 1

# `assign` evaluates its keyword arguments in order, so each lambda can use
# the columns created before it.
escape_df = escape_df.assign(
    # Per-antibody maximum escape_mean broadcast to every row. Not used by the
    # absolute threshold, but kept as a reference column.
    max_escape=lambda x: x.groupby("antibody")["escape_mean"].transform("max"),
    above_half_max=lambda x: x["escape_mean"] >= ESCAPE_THRESHOLD,
    above_min_mut_and_effect=lambda x: (
        (x["min_mutations"] == MIN_MUTATIONS) & x["above_half_max"]
    ),
)
display(escape_df[escape_df["above_half_max"]])

# Count flagged rows per antibody by summing the boolean flag. Unlike
# filter-then-.size(), this keeps antibodies with no flagged rows as explicit
# 0 counts, so they still appear on the plot instead of silently vanishing.
escape_threshold = (
    escape_df.groupby("antibody")["above_half_max"]
    .sum()
    .reset_index(name="number")
)
escape_and_effect_threshold = (
    escape_df.groupby("antibody")["above_min_mut_and_effect"]
    .sum()
    .reset_index(name="number")
)
display(escape_and_effect_threshold)

In [None]:
def _count_chart(counts, color):
    """Build a point chart of per-antibody counts in a single fixed color.

    ``counts`` must have ``antibody`` and ``number`` columns, as produced by
    the counting cell; antibodies are ordered on the x axis by
    ``antibody_order``.
    """
    return (
        alt.Chart(counts)
        .mark_point(filled=True, opacity=1, size=100, color=color)
        .encode(
            x=alt.X("antibody:N", sort=antibody_order, title="Antibody"),
            y=alt.Y("number:Q", title="Number of escape mutations"),
        )
    )


# The layered plot has no legend, so the colors mean:
#   indianred -> rows flagged `above_half_max`
#   steelblue -> rows flagged `above_min_mut_and_effect` (also min_mutations == 1)
chart1 = _count_chart(escape_threshold, "indianred")
chart3 = _count_chart(escape_and_effect_threshold, "steelblue")

# Overlay both series on one shared y axis so the counts are directly comparable.
combined = alt.layer(chart1, chart3).resolve_scale(y="shared")
combined.display()
combined.save("../../num_escape_mutations.svg")
combined.save("../../num_escape_mutations.png", ppi=300)