# Validation neutralization assays versus `polyclonal` fits
Compare actual measured neutralization values for specific mutants to the `polyclonal` fits.

Import Python modules:

In [1]:
import os
import pickle

import altair as alt

import pandas as pd
import numpy as np

import yaml

from scipy import stats

import warnings
warnings.simplefilter("ignore")

# Base eight-colour palette.
palette = [
    "#999999",
    "#0072B2",
    "#E69F00",
    "#F0E442",
    "#009E73",
    "#56B4E9",
    "#D55E00",
    "#CC79A7",
]

# Base palette followed by one additional colour (a new list, not an alias).
extended_palette = [*palette, "#9F0162"]

# Separate twelve-colour palette.
long_palette = [
    "#9F0162",
    "#009F81",
    "#FF5AAF",
    "#8400CD",
    "#008DF9",
    "#00C2F9",
    "#FFB2FD",
    "#A40122",
    "#E20134",
    "#FF6E3A",
    "#FFC33B",
    "#00FCCF",
]

# Extended palette followed by one more colour; passed as the colour-scale range
# of the charts in the plotting cells.
figure_palette = [*extended_palette, "#8400CD"]

Read configuration and validation assay measurements:

In [2]:
# Load the pipeline configuration; it supplies the path of the validation assay CSV.
with open("config.yaml") as f:
    config = yaml.safe_load(f)

# Turn off NA detection so that an empty `aa_substitutions` field (the unmutated
# virus) is read as the empty string '' rather than NaN; later cells select those
# rows with `aa_substitutions == ''`. `na_filter` is a boolean option, so pass
# `False` explicitly instead of `None`.
validation_ic50s = pd.read_csv(config["validation_ics"], na_filter=False)

validation_ic50s

Unnamed: 0,antibody,virus,aa_substitutions,measured IC50,measured IC80,date
0,10-1074,BF520,,0.064166,0.355822,6-24
1,10-1074,BF520,N332L,20.0,20.0,6-24
2,10-1074,BF520,S140D,6.412878,20.0,6-24
3,10-1074,TRO11,,0.03121,0.129595,6-24
4,10-1074,TRO11,H330E,20.0,20.0,6-24
5,10-1074,TRO11,D325R,20.0,20.0,6-24
6,10-1074,TRO11,T415Q,0.060501,0.251278,6-24
7,3BNC117,TRO11,,0.256718,1.031087,6-24
8,3BNC117,TRO11,R304G,0.671586,2.192752,6-24
9,3BNC117,TRO11,Y318E,0.841764,2.646716,6-24


Now get the predictions from the averaged `polyclonal` model fits:

In [3]:
# Directory holding the averaged `polyclonal` model pickles for each virus.
averages_dir_by_virus = {
    "TRO11": "results/antibody_escape/averages/",
    "BF520": "../HIV_Envelope_BF520_DMS/results/antibody_escape/averages/",
}

validation_vs_prediction = []
for virus, virus_df in validation_ic50s.groupby("virus"):
    # Fail loudly on an unexpected virus instead of hitting a NameError or silently
    # reusing the directory left over from the previous loop iteration.
    if virus not in averages_dir_by_virus:
        raise ValueError(f"no averaged model directory configured for virus {virus!r}")
    virus_data_path = averages_dir_by_virus[virus]
    for antibody, antibody_df in virus_df.groupby("antibody"):
        model_file = os.path.join(virus_data_path, f"{antibody}_polyclonal_model.pickle")
        # NOTE: unpickling can execute arbitrary code; only load model files
        # that come from a trusted source.
        with open(model_file, "rb") as f:
            model = pickle.load(f)
        # Model predictions for the validated variants. The predicted values are
        # left unclipped (no upper bound is applied to them here).
        df = model.icXX(antibody_df)
        # Attach `times_seen` for each substitution; the unmutated row ('') has no
        # match in the left merge and therefore gets NaN.
        df = df.merge(
            model.mut_escape_df
            .rename(columns={'mutation': 'aa_substitutions'})
            [['aa_substitutions', 'times_seen']],
            how='left',
            on='aa_substitutions',
        )
        validation_vs_prediction.append(df)

validation_vs_prediction = pd.concat(validation_vs_prediction, ignore_index=True)

# Despite its name, `standard_deviation` is std_IC50 divided by mean_IC50,
# i.e. the spread of the IC50 predictions relative to their mean.
validation_vs_prediction = validation_vs_prediction.assign(
    standard_deviation=lambda x: x['std_IC50'] / x['mean_IC50']
)

validation_vs_prediction

Unnamed: 0,antibody,virus,aa_substitutions,measured IC50,measured IC80,date,mean_IC50,median_IC50,std_IC50,frac_models,mean_IC80,median_IC80,std_IC80,n_models,times_seen,standard_deviation
0,10-1074,BF520,,0.064166,0.355822,6-24,3.759731,3.00248,2.943593,1.0,6.170105,5.204661,3.922855,4,,0.782927
1,10-1074,BF520,E325R,20.0,20.0,6-24,81.674267,45.779988,85.592095,1.0,174.544888,81.047761,216.907238,4,5.375,1.047969
2,10-1074,BF520,H330Y,20.0,20.0,6-24,64.860549,51.080535,46.963819,1.0,116.639561,91.248483,86.533741,4,46.833333,0.724074
3,10-1074,BF520,N332L,20.0,20.0,6-24,141.417667,54.889487,186.259636,1.0,310.338548,93.330395,461.674989,4,5.25,1.317089
4,10-1074,BF520,Q328D,20.0,20.0,6-24,30.346077,37.666286,20.453502,1.0,50.59662,58.867887,35.328967,4,3.5,0.674008
5,10-1074,BF520,S140D,6.412878,20.0,6-24,17.912182,19.400321,4.781855,1.0,33.553728,33.37911,14.160502,4,6.25,0.266961
6,10-1074,BF520,T415Q,18.089759,20.0,6-24,12.29858,11.435929,9.920222,1.0,21.104047,17.472009,18.480391,4,10.5,0.806615
7,3BNC117,BF520,,0.035509,0.075204,12-23,2.183123,1.932619,0.617839,1.0,4.534563,3.328993,2.1386,3,,0.283007
8,3BNC117,BF520,G459D,0.039781,0.100754,12-23,3.262641,2.668634,1.675358,1.0,6.938962,4.596802,4.839695,3,22.666667,0.513497
9,3BNC117,BF520,N463S,0.05737,0.144362,12-23,3.350556,3.205052,0.661537,1.0,6.882174,5.520797,2.600513,3,6.0,0.197441


For each virus and antibody, calculate the squared Pearson correlation coefficient ($R^2$) between the IC50s predicted by our models and the IC50s measured in validation assays. We first do this for single mutants only:

In [4]:
# Report R^2 between the median predicted IC50 and the measured IC50 for every
# virus/antibody pair, using only rows with exactly one substitution.
print("Single mutant correlations between DMS predicted and neutralization assay measured IC50s:")
for (virus, antibody), pair_df in validation_vs_prediction.groupby(["virus", "antibody"]):
    substitutions = pair_df["aa_substitutions"]
    # Drop the unmutated row ('') and any entry containing a space.
    single_mutants = pair_df[(substitutions != "") & ~substitutions.str.contains(" ")]
    print(f"{virus}, {antibody}:")
    fit = stats.linregress(
        single_mutants["median_IC50"].astype(float),
        single_mutants["measured IC50"].astype(float),
    )
    print(round(fit.rvalue**2, 3))

Single mutant correlations between DMS predicted and neutralization assay measured IC50s:
BF520, 10-1074:
0.348
BF520, 3BNC117:
0.759
TRO11, 10-1074:
0.873
TRO11, 3BNC117:
0.565


Now plot the results. We plot the **median** prediction across the `polyclonal` fits to the different deep mutational scanning replicates. These are interactive plots that you can mouse over for details:

In [23]:
# x-position of the dashed vertical rule drawn on the 10-1074 panels.
# NOTE(review): the measured IC50s tabulated above top out at 20, and the x-domain
# below ends at 1.25x the largest measured value, so confirm that 40 is still intended.
rule_ic50_10_1074 = 40

# One scatter plot per virus/antibody pair: measured IC50 (x) versus the median
# IC50 predicted from the DMS fits (y), both on log scales.
for virus in ["BF520", "TRO11"]:
    for antibody in ['10-1074', '3BNC117']:
        plot_data = validation_vs_prediction.query('virus==@virus').query('antibody==@antibody')
        # Drop entries whose `aa_substitutions` contains a space (presumably
        # space-delimited multi-mutants), and copy so that the dtype cast below does
        # not write into a slice of the full frame.
        plot_data = plot_data[~plot_data['aa_substitutions'].str.contains(" ")].copy()
        plot_data['measured IC50'] = plot_data['measured IC50'].astype(float)
        if plot_data.empty:
            # Nothing to plot for this pair; the axis domains would be NaN.
            continue
        corr_chart = (
            alt.Chart(plot_data)
            .encode(
                x=alt.X(
                    "measured IC50",
                    # Pad the log axis slightly beyond the observed range.
                    scale=alt.Scale(
                        type="log",
                        nice=False,
                        domain=(plot_data["measured IC50"].min()*.75,
                                plot_data["measured IC50"].max()*1.25),
                    ),
                ),
                y=alt.Y(
                    "median_IC50",
                    title="predicted IC50 from DMS",
                    scale=alt.Scale(
                        type="log",
                        nice=False,
                        domain=(plot_data["median_IC50"].min()*.75,
                                plot_data["median_IC50"].max()*1.25),
                    ),
                ),
                color=alt.Color("aa_substitutions",
                                title="Amino acid substitutions",
                                scale=alt.Scale(range=figure_palette)),
                # Format float columns compactly; dtypes are taken from the frame
                # actually plotted so the cast `measured IC50` is formatted too.
                tooltip=[
                    alt.Tooltip(c, format=".3g") if plot_data[c].dtype == float
                    else c
                    for c in plot_data.columns.tolist()
                ],
            )
            .mark_circle(filled=True, size=60, opacity=1)
            .properties(width=150, height=150)
        )
        if antibody == "10-1074":
            # Use the same field name as the scatter layer so that the shared x-axis
            # keeps a single title instead of merging two different field names.
            line = (
                alt.Chart(pd.DataFrame({'measured IC50': [rule_ic50_10_1074]}))
                .mark_rule(strokeDash=[8, 8])
                .encode(x='measured IC50')
            )
            (corr_chart + line).configure_axis(grid=False).display()
        else:
            corr_chart.configure_axis(grid=False).display()
            

Now also calculate the fold changes in IC50 relative to the unmutated virus, using the **median** prediction:

In [31]:
# Predictions with the median across model fits exposed as "predicted IC50".
predictions = validation_vs_prediction.rename(columns={"median_IC50": "predicted IC50"})

# Pair every row with the unmutated (aa_substitutions == '') measurement and
# prediction for the same antibody and virus, then take mutant / unmutated ratios.
# The unmutated rows themselves are kept, so they get a fold change of 1.
# NOTE(review): `validate="many_to_one"` is not enforced on the merge; if an
# antibody/virus pair has more than one unmutated row, every row of that pair is
# duplicated once per unmutated row.
fold_changes = (
    predictions
    [["antibody",
      "virus",
      "aa_substitutions",
      "measured IC50",
      "predicted IC50",
      "times_seen",
      "n_models"]]
    .merge(
        predictions
        .query("aa_substitutions == ''")
        [["antibody", "virus", "measured IC50", "predicted IC50"]],
        on=["antibody", "virus"],
        how="left",
        suffixes=[" mutant", " unmutated"],
    )
    .assign(
        measured_fold_change=lambda x: x["measured IC50 mutant"] / x["measured IC50 unmutated"],
        predicted_fold_change=lambda x: x["predicted IC50 mutant"] / x["predicted IC50 unmutated"],
    )
)

# Drop entries whose `aa_substitutions` contains a space (presumably
# space-delimited multi-mutants).
plot_data = fold_changes[~fold_changes['aa_substitutions'].str.contains(" ")].copy()

for virus in ["BF520", "TRO11"]:
    for antibody in ['10-1074', '3BNC117']:
        sub_plot_data = plot_data.query('virus==@virus').query('antibody==@antibody').copy()
        # Label the unmutated row. Compare by value (`==`): identity (`is`) against a
        # string literal depends on interpreter interning and raises a SyntaxWarning.
        sub_plot_data['aa_substitutions'] = [
            f'wildtype {virus}' if x == '' else x
            for x in sub_plot_data['aa_substitutions']
        ]
        fold_change_chart = (
            # `sub_plot_data` is already restricted to this virus and antibody.
            alt.Chart(sub_plot_data)
            .encode(
                x=alt.X(
                    "measured_fold_change",
                    title="measured fold change IC50",
                    # Pad the log axis slightly beyond the observed range.
                    scale=alt.Scale(
                        type="log",
                        nice=False,
                        domain=(sub_plot_data["measured_fold_change"].min()*.75,
                                sub_plot_data["measured_fold_change"].max()*1.25),
                    ),
                ),
                y=alt.Y(
                    "predicted_fold_change",
                    title="predicted fold change IC50",
                    scale=alt.Scale(
                        type="log",
                        nice=False,
                        domain=(sub_plot_data["predicted_fold_change"].min()*.75,
                                sub_plot_data["predicted_fold_change"].max()*1.25),
                    ),
                ),
                color=alt.Color("aa_substitutions",
                                title="Amino acid substitutions",
                                scale=alt.Scale(range=figure_palette),
                                # Fixed legend order shared by all panels.
                                sort=[
                                    'wildtype TRO11',
                                    'wildtype BF520',
                                    'E325R',
                                    'D325R',
                                    'H330Y',
                                    'N332L',
                                    'T415Q',
                                    'S140D',
                                    'Q328D',
                                    'H330E',
                                    'N332T',
                                    'T202P',
                                    'T198D',
                                    'Q203P',
                                    'N276D',
                                    'N279D',
                                    'N295R',
                                    'R304G',
                                    'Y318E',
                                    'K440D',
                                    'G459D',
                                    'N462K',
                                    'N462T',
                                    'N463S',
                                ],
                               ),
                tooltip=[
                    alt.Tooltip(c, format=".3g") if sub_plot_data[c].dtype == float
                    else c
                    for c in sub_plot_data.columns.tolist()
                ],
            )
            .mark_circle(filled=True, size=100, opacity=1)
            .properties(width=150, height=150)
        )

        # R^2 between predicted and measured fold change over the plotted rows. The
        # unmutated row is included because no `aa_substitutions != ''` filter is applied.
        # NOTE(review): the label printed below omits the virus, so the output for
        # the two viruses can only be told apart by its order.
        print(f"{antibody}:")
        fit = stats.linregress(
            sub_plot_data["predicted_fold_change"].astype(float),
            sub_plot_data["measured_fold_change"].astype(float),
        )
        print(f"Predicted fold change correlation (R^2): {round(fit.rvalue**2,3)}")

        if antibody == "10-1074":
            # Dashed vertical rule at the largest measured fold change in this panel.
            line = (
                alt.Chart(pd.DataFrame({'measured_fold_change': [sub_plot_data["measured_fold_change"].max()]}))
                .mark_rule(strokeDash=[8, 8])
                .encode(x='measured_fold_change')
            )
            (fold_change_chart + line).configure_axis(grid=False).display()
        else:
            fold_change_chart.configure_axis(grid=False).display()

10-1074:
Predicted fold change correlation (R^2): 0.599


3BNC117:
Predicted fold change correlation (R^2): 0.706


10-1074:
Predicted fold change correlation (R^2): 0.913


3BNC117:
Predicted fold change correlation (R^2): 0.616
