In [2]:
import pandas as pd
import numpy as np
from plotnine import *
import time
# IC50 column name -> CSV listing, per position, the substitutions considered
# allowed for that strain (read by do_calc via the 'pos' and 'mut1' columns).
neut2codon = {
    neut_col: "allowed_muts/mut1_" + stem + ".csv"
    for neut_col, stem in (
        ("D614G_IC50", "SARSCoV2_WuhanHu1_Spike"),
        ("BA1_IC50", "BA.1_Omicron_baseline_EPI_ISL_10000028"),
        ("BA2_IC50", "BA.2_Omicron_baseline_EPI_ISL_10000005"),
        ("BA2_75_IC50", "BA.2.75_EPI_ISL_13302209"),
        ("BA5_IC50", "BA.4_BA.5_EPI_ISL_11207535"),
    )
}

# IC50 column name -> CSV of per-mutation expression/binding effects
# ('expr_avg', 'bind_avg').  BA.2.75 and BA.5 have no table of their own here
# and start from the BA.2 table.
neut2se = {
    neut_col: "bind_expr/bind_expr_" + background + ".csv"
    for neut_col, background in (
        ("D614G_IC50", "WT"),
        ("BA1_IC50", "BA1"),
        ("BA2_IC50", "BA2"),
        ("BA2_75_IC50", "BA2"),
        ("BA5_IC50", "BA2"),
    )
}

# (site, amino acid) pairs used to re-base the borrowed BA.2 table for each
# strain listed here.
mut_for_bind_expr = dict(
    BA2_75_IC50=[(339, 'H'), (446, 'S'), (460, 'K'), (493, 'Q')],
    BA5_IC50=[(452, 'R'), (486, 'V'), (493, 'Q')],
)

# Re-base the expression/binding table of every strain in mut_for_bind_expr.
# For each listed (site, mutation), that mutation's own expr/bind value is
# subtracted from every row at the same site, so the listed residue becomes
# the zero reference there.  The result is written to
# "mut_approx_<strain>.csv" and neut2se is repointed to that file.
for strain, reference_muts in mut_for_bind_expr.items():
    source_csv = neut2se[strain]
    bind_expr = pd.read_csv(source_csv).assign(bias_e=0.0, bias_b=0.0)
    for site, mut in reference_muts:
        at_site = bind_expr['site'] == site
        reference = bind_expr.loc[at_site & (bind_expr['mutation'] == mut)]
        if len(reference) != 1:
            # Same exception type .item() would raise, but naming the culprit.
            raise ValueError(
                "expected exactly one row for %s%s in %s, found %d"
                % (site, mut, source_csv, len(reference))
            )
        # Offsets accumulate, in case one site is listed more than once.
        bind_expr.loc[at_site, 'bias_e'] += reference['expr_avg'].item()
        bind_expr.loc[at_site, 'bias_b'] += reference['bind_avg'].item()
    bind_expr['expr_avg'] -= bind_expr['bias_e']
    bind_expr['bind_avg'] -= bind_expr['bias_b']
    approx_csv = "mut_approx_" + strain + ".csv"
    bind_expr.drop(columns=['bias_e', 'bias_b']).to_csv(approx_csv, index=False)
    neut2se[strain] = approx_csv

In [5]:
# Load the per-mutation escape scores and the per-antibody IC50 table, tag
# each antibody with a coarse source group, and turn the IC50 text cells into
# clamped floats.
_IC50_COLUMNS = ["D614G_IC50", "BA1_IC50", "BA2_IC50", "BA5_IC50", "BA2_75_IC50"]
_IC50_FLOOR = 0.0005  # smallest IC50 kept; lower values are raised to this
_IC50_CAP = 10.0      # largest IC50 kept; also used for '>...' and 'Inf...' cells

# Checked in order; the first matching prefix wins.
_SOURCE_PREFIXES = (("WT", "WT"), ("SARS", "SARS"), ("BA.2", "BA2"), ("BA.1", "BA1"), ("BA.5", "BA5"))


def _classify_source(source):
    """Map a free-text 'source' label to its group tag ('???' if unrecognised)."""
    if not isinstance(source, str):
        # Missing source (e.g. NaN): treat as unknown instead of crashing.
        return "???"
    if "mouse" in source:
        return "mouse"
    for prefix, tag in _SOURCE_PREFIXES:
        if source.startswith(prefix):
            return tag
    return "???"


def _parse_ic50(raw):
    """Convert one IC50 cell to a float clamped to [_IC50_FLOOR, _IC50_CAP].

    '>...' and 'Inf...' cells become _IC50_CAP; '--', empty and missing cells
    become NaN.  Cells that pandas already parsed as numbers are accepted too.
    """
    if isinstance(raw, str):
        text = raw.strip()
        if text == '' or text == '--':
            return np.nan
        if text[0] == '>' or text[0:3] == 'Inf':
            return _IC50_CAP
        value = float(text)
    elif pd.isna(raw):
        return np.nan
    else:
        value = float(raw)
    if np.isnan(value):
        # min()/max() would silently turn NaN into the cap; keep it missing.
        return np.nan
    return max(_IC50_FLOOR, min(_IC50_CAP, value))


scores_r = pd.read_csv("use_res_clean.csv")
use_abs = np.unique(scores_r['antibody'])

# Keep only antibodies that have escape scores; 'antibody' is resolved by
# query() against the frame loaded with index_col=0.
data = pd.read_csv("src_neut_data.csv", index_col=0)[
    ["source"] + _IC50_COLUMNS
].query('antibody in @use_abs')

_srcs = data['source'].to_list()
data = data.assign(Usrc=[_classify_source(src) for src in _srcs])

for term in _IC50_COLUMNS:
    data[term] = [_parse_ic50(cell) for cell in data[term].to_list()]

# Antibodies from unrecognised or mouse sources are excluded from the analysis.
data = data.query('not (Usrc == "???" or Usrc == "mouse")')

In [6]:
import logomaker
from matplotlib import rcParams
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
rcParams['pdf.fonttype'] = 42

def plot_res_logo(res, prefix, shownames=None, rownames=None, site_thres=0.0, force_plot_sites = None, force_ylim = None, width=None):
    """Draw per-antibody amino-acid logos into ``prefix + '_aa_logo.pdf'``.

    Parameters
    ----------
    res : DataFrame with columns 'antibody', 'site', 'mutation', 'mut_escape'
        (one row per antibody/site/mutation; it is pivoted on those columns).
    prefix : str, output path prefix.
    shownames : optional dict mapping an antibody to the title to display;
        antibodies not in the dict are titled with their own name.
    rownames : optional sequence of antibodies to plot, in that order.
        Defaults to the sorted unique antibodies that have a selected site.
    site_thres : a site is selected when, for any antibody, its summed
        'mut_escape' exceeds this value.
    force_plot_sites : optional iterable of sites that replaces the automatic
        selection (and its ``< 520`` cutoff).  It is sorted and de-duplicated
        so that tick labels line up with the plotted columns.
    force_ylim : accepted for compatibility; currently has no effect.
    width : figure width in inches (default 30).

    Ten antibodies are drawn per PDF page, one logo row each.
    """
    if shownames is None:
        shownames = {}
    if width is None:
        width = 30

    flat_res = res.pivot(index=['antibody', 'site'], columns='mutation', values='mut_escape').fillna(0)
    sites_total_score = flat_res.sum(axis=1)
    strong_index = sites_total_score[sites_total_score > site_thres].index
    strong_sites = np.unique(np.array(sorted([key[1] for key in strong_index])))
    print(strong_sites)

    # Hard-coded upper site cutoff -- presumably the end of the region of
    # interest; TODO confirm.
    plot_sites = strong_sites[strong_sites < 520].astype(int)
    print(plot_sites)

    if force_plot_sites is not None:
        # Rows are laid out in ascending site order below, so the labels must
        # be ascending and unique as well.
        plot_sites = np.unique(np.asarray(force_plot_sites))

    flat_res = flat_res[flat_res.index.isin(plot_sites, level=1)]

    if rownames is not None:
        Abs = rownames
    else:
        Abs = np.unique([key[0] for key in flat_res.index])
    print(Abs)

    per_page = 10
    # Ceiling division: no trailing empty page when len(Abs) is a multiple of 10.
    n_pages = (len(Abs) + per_page - 1) // per_page

    with PdfPages(prefix+'_aa_logo.pdf') as pdf:
        for page in range(n_pages):
            Abs_p = Abs[page*per_page:min(len(Abs), (page+1)*per_page)]
            fig = plt.figure(figsize=(width, len(Abs_p)*4.6))
            fig.subplots_adjust(wspace=0.2, hspace=0.5)

            for row, ab in enumerate(Abs_p):
                ab_scores = flat_res.query('antibody == @ab').droplevel(0)
                # One row per plotted site, zero-filled where the antibody has
                # no data, then positions 0..n-1 as logomaker expects.
                ab_scores = ab_scores.reindex(plot_sites, fill_value=0.0).reset_index(drop=True)

                ax = fig.add_subplot(len(Abs_p), 1, row+1)
                logo = logomaker.Logo(ab_scores,
                                      ax=ax,
                                      color_scheme='dmslogo_funcgroup',
                                      vpad=.1,
                                      width=.8)
                logo.style_xticks(anchor=1, spacing=1, rotation=90, fontsize=16)
                # Replace the positional tick labels with the real site numbers.
                ax.set_xticklabels(plot_sites)

                ax.set_yticks([])
                ax.tick_params(axis='both', which='both', length=0)
                ax.yaxis.set_tick_params(labelsize=20)
                ax.set_title(shownames[ab] if ab in shownames else ab, fontsize=8, fontweight="bold")
            pdf.savefig(fig)
            plt.close(fig)

In [7]:
def do_calc(use_ab_src, use_neut, A_adv = True, A_codon = True, E=1.0, B=1.0, use_log=False, use_max=False, use_norm=False, logo=False, return_df=False):
    """Aggregate antibody escape scores into a weighted, max-normalised profile.

    Reads the notebook globals ``data`` (per-antibody IC50 and ``Usrc``),
    ``scores_r`` (per antibody/site/mutation ``mut_escape``), ``neut2codon``
    and ``neut2se``.

    Parameters
    ----------
    use_ab_src : list of ``Usrc`` tags; only antibodies from these groups are used.
    use_neut : IC50 column of ``data`` used for the neutralisation weight; also
        the key into ``neut2codon`` / ``neut2se``.
    A_adv : if True, weight each mutation by ``exp(expr_avg*E + bind_avg*B)``
        from the ``neut2se`` table; otherwise use 1.0.
    A_codon : if True, zero out mutations not listed in the ``neut2codon``
        table ('pos' + each character of 'mut1'); otherwise use 1.0.
    E, B : coefficients applied to ``expr_avg`` and ``bind_avg``.
    use_log : if True the neutralisation weight is ``max(0, log10(1/min(1, IC50)))``,
        otherwise ``1/IC50``.  Antibodies with a NaN IC50 get weight 0.
    use_max : site score sums, over antibodies, the per-antibody maximum at the
        site instead of summing every mutation.
    use_norm : divide each antibody's ``mut_escape`` by its own maximum first.
    logo : return the per-(site, mutation) table (columns site, mutation,
        mut_escape_adj, antibody) where 'antibody' holds the run description.
    return_df : return the per-site table plus columns describing the settings.

    Returns
    -------
    DataFrame when ``logo`` or ``return_df`` is set, else a plotnine plot of the
    per-site score.  Scores are divided by their maximum in every case.

    Raises
    ------
    KeyError if ``use_neut`` is not configured, or (with ``A_adv``) if a scored
    mutation is missing from the expression/binding table.
    """
    use_codon = pd.read_csv(neut2codon[use_neut])
    neut_data = data[use_neut].to_dict()

    effects = pd.read_csv(neut2se[use_neut])
    effects = effects.assign(coef=np.exp(effects['expr_avg']*E + effects['bind_avg']*B))
    effects.index = effects['site'].astype('str') + effects['mutation']
    single_mut_effects = effects['coef'].to_dict()

    # Keys such as '484K' for every substitution listed as allowed.
    _umuts = set()
    for pos, allowed in zip(use_codon['pos'], use_codon['mut1']):
        if not isinstance(allowed, str):
            # Empty cell (read as NaN): nothing allowed at this position.
            continue
        _umuts.update(str(pos) + aa for aa in allowed)

    _uabs = set(data.query('Usrc in @use_ab_src').index.to_list())

    scores = scores_r.assign(site_mut = lambda x: x['site'].astype(str)+x['mutation']).query('antibody in @_uabs').assign(
        adv_weight = (lambda x: [single_mut_effects[key] for key in x['site_mut']]) if A_adv else 1.0,
        codon_weight = (lambda x: [(1.0 if key in _umuts else 0.0) for key in x['site_mut'].to_list()]) if A_codon else 1.0
    )

    if use_norm:
        per_ab_max = scores.groupby('antibody')['mut_escape'].transform('max')
        scores = scores.assign(mut_escape = scores['mut_escape']/per_ab_max)

    def _neut_weight(ic50):
        # Missing IC50 contributes nothing.
        if np.isnan(ic50):
            return 0.0
        if use_log:
            return max(0.0, np.log10(1.0/min(1.0, ic50)))
        return 1.0/ic50

    scores = scores.assign(neut_weight = lambda x: [_neut_weight(neut_data[ab]) for ab in x['antibody']])

    scores = scores.assign(
        mut_escape_adj = lambda x: x['mut_escape'] * x['neut_weight'] * x['adv_weight'] * x['codon_weight']
    )
    _title = ("src: "+'+'.join(use_ab_src)+
              ' weight: '+use_neut+' expr_bind:'+str(A_adv)+
              ' codon:'+str(A_codon)+' log:'+str(use_log)+
              ' norm:'+str(use_norm)+' max:'+str(use_max)+
              ' Expr:'+str(E)+' Bind:'+str(B))

    # The score column is selected *before* reducing.  Reducing the whole
    # frame depends on the pandas version's numeric_only default: on pandas 2.x
    # it concatenates the string columns and leaks them into the result.
    if logo:
        scores = scores.groupby(['site','mutation'])['mut_escape_adj'].sum().reset_index().assign(antibody=_title)
        scores['mut_escape_adj'] = scores['mut_escape_adj']/scores['mut_escape_adj'].max()
        return scores

    if use_max:
        per_group = scores.groupby(['site', 'antibody'])['mut_escape_adj'].max().reset_index()
    else:
        per_group = scores.groupby(['site', 'mutation'])['mut_escape_adj'].sum().reset_index()
    site_avg = per_group.groupby('site')['mut_escape_adj'].sum().reset_index()
    site_avg['mut_escape_adj'] = site_avg['mut_escape_adj']/site_avg['mut_escape_adj'].max()

    if return_df:
        return site_avg.assign(
            absrc = '+'.join(use_ab_src), weight = use_neut, is_expr_bind = A_adv, is_codon = A_codon,
            is_neut_log = use_log, is_norm = use_norm, is_max = use_max, expr_coef = E, bind_coef = B
        )

    p = (
        ggplot(site_avg, aes('site', 'mut_escape_adj')) +
        geom_line() + geom_point()+ theme_classic() + theme(
            axis_text_y=element_blank(),
            axis_ticks_major_y=element_blank(),figure_size=(12,3),
            axis_text_x=element_text(angle=90)
        )+scale_x_continuous(breaks=range(331,531,2))+
        ylab('weighted escape score')+xlab('RBD residues')+ggtitle(_title)+
        # Label only sites scoring above 0.2 of the maximum.
        geom_text(site_avg.query('mut_escape_adj > 0.2'), aes(label='site'), #nudge_y=0.05,
                                adjust_text={'expand_points': (2, 2), 'arrowprops': {'arrowstyle': '-'}})
    )

    return p

In [110]:
# One normalised per-mutation profile per (antibody source groups, IC50
# column) pair, drawn as stacked logos in "logo_aa_logo.pdf".
# Combinations tried earlier and currently disabled:
#   (['BA2'], 'BA2_IC50'), (['BA5'], 'BA5_IC50'),
#   (['WT','BA1'], 'BA1_IC50'), (['WT','BA1'], 'BA2_IC50'),
#   (['WT','BA1','BA2'], 'BA2_IC50'), (['WT','BA1','BA2'], 'BA5_IC50')
# NOTE(review): the saved cell output also lists BQ1_1_IC50 and XBB_IC50 runs
# that are not in this list -- the output looks stale; re-run to confirm.
logo_runs = [
    (['WT'], 'D614G_IC50'),
    (['WT','BA1','BA2','BA5'], "BA2_75_IC50"),
    (['WT','BA1','BA2','BA5'], "BA5_IC50"),
]

df = []
for use_ab_src, use_neut in logo_runs:
    ts = do_calc(
        use_ab_src, use_neut,
        A_adv=True, A_codon=True,
        E=1.0, B=1.0,
        use_log=True, use_norm=True, use_max=False,
        logo=True,
    )
    # plot_res_logo expects the score column to be called 'mut_escape'.
    ts.columns = ['site', 'mutation', 'mut_escape','antibody']
    df.append(ts)

plot_res_logo(pd.concat(df), "logo", site_thres=0.15, width=8)

[346 348 352 354 356 369 378 406 408 417 420 439 444 445 446 447 448 449
 450 452 455 456 460 462 468 472 473 483 484 485 486 487 490 493 494 496
 499 503]
[346 348 352 354 356 369 378 406 408 417 420 439 444 445 446 447 448 449
 450 452 455 456 460 462 468 472 473 483 484 485 486 487 490 493 494 496
 499 503]
['src: WT weight: D614G_IC50 expr_bind:True codon:True log:True norm:True max:False Expr:1.0 Bind:1.0'
 'src: WT+BA1+BA2+BA5 weight: BA2_75_IC50 expr_bind:True codon:True log:True norm:True max:False Expr:1.0 Bind:1.0'
 'src: WT+BA1+BA2+BA5 weight: BA5_IC50 expr_bind:True codon:True log:True norm:True max:False Expr:1.0 Bind:1.0'
 'src: WT+BA1+BA2+BA5 weight: BQ1_1_IC50 expr_bind:True codon:True log:True norm:True max:False Expr:1.0 Bind:1.0'
 'src: WT+BA1+BA2+BA5 weight: XBB_IC50 expr_bind:True codon:True log:True norm:True max:False Expr:1.0 Bind:1.0']


In [108]:
# Export per-site scores for every (antibody source groups, IC50 column)
# combination to "tmp_data-<xx>.csv"; the plotting itself is done in R.

xx = "sum"  # "max" switches do_calc to use_max=True

_requested_neuts = ["D614G_IC50", "BA1_IC50", "BA2_IC50", "BA2_75_IC50", "BA5_IC50","BQ1_1_IC50","XBB_IC50"]
# do_calc needs each strain in neut2codon, neut2se and as a column of `data`.
# Strains missing any of these would raise KeyError, so they are skipped and
# reported instead of aborting the whole export.
_usable_neuts = [
    neut for neut in _requested_neuts
    if neut in neut2codon and neut in neut2se and neut in data.columns
]
_skipped_neuts = [neut for neut in _requested_neuts if neut not in _usable_neuts]
if _skipped_neuts:
    print("skipping IC50 columns without configured inputs:", _skipped_neuts)

df = []
for use_ab_src in [["WT"], ["WT", "BA1"], ["WT", "BA1", "BA2"], ["WT", "BA1", "BA2", "BA5"], ['BA1'], ["BA2"], ["BA5"], ["BA2", "BA5"]]:
    for use_neut in _usable_neuts:
        df.append(do_calc(use_ab_src, use_neut, A_adv = True, A_codon = True, use_log=True,
                          E=1.0, B=1.0,
                          use_norm=True, use_max=(xx == "max"), return_df=True))

pd.concat(df).to_csv("tmp_data-"+xx+".csv", index=False)