In [1]:
import pandas as pd
import neutcurve
from neutcurve.colorschemes import CBMARKERS, CBPALETTE
from matplotlib import pyplot as plt
import altair as alt
import numpy as np
import sys

# Make the project-local `theme` module importable. The membership check keeps
# sys.path free of duplicate entries when this cell is executed more than once.
ANALYSIS_DIR = '../analysis/'
if ANALYSIS_DIR not in sys.path:
    sys.path.append(ANALYSIS_DIR)
import theme

# Register the project theme under the name 'main_theme' and activate it so
# it applies to the Altair charts built in the cells below.
alt.themes.register('main_theme', theme.main_theme)
alt.themes.enable('main_theme')

# Remove Altair's row limit on the data frames passed to charts.
alt.data_transformers.disable_max_rows()

DataTransformerRegistry.enable('default')

In [2]:
# Notebook-wide display defaults: 3-significant-digit floats and wider tables
# for pandas, and 300 dpi for figures saved through matplotlib.
for option_name, option_value in (
    ('display.float_format', '{:.3g}'.format),
    ('display.max_columns', 20),
    ('display.width', 400),
):
    pd.set_option(option_name, option_value)

plt.rcParams['savefig.dpi'] = 300

In [3]:
# Read both neutralization runs. In each, 'fraction infectivity' values below
# 1e-6 (including zero or negative readings) are raised to 1e-6 before fitting.
data_1, data_2 = (
    pd.read_csv(csv_path).assign(
        **{
            'fraction infectivity':
                lambda run: run['fraction infectivity'].clip(lower=1e-6)
        }
    )
    for csv_path in (
        'data/250417_neutralization.csv',
        'data/250420_neutralization.csv',
    )
)

In [4]:
# Fit neutralization curves for each run; fits_1 / fits_2 correspond to
# data_1 / data_2 respectively.
fits_1, fits_2 = (
    neutcurve.CurveFits(run_data) for run_data in (data_1, data_2)
)

In [5]:
# For each run (first run, then second), list the viruses assayed per serum.
for run_fits in (fits_1, fits_2):
    for serum in run_fits.sera:
        header = f"Viruses measured against {serum}:\n"
        print(header + str(run_fits.viruses[serum]))

Viruses measured against SCH23-y2021-s056:
['unmutated', 'K189E', 'S145N', 'R229I', 'R220T', 'S205Y', 'N165H']
Viruses measured against SCH23-y2016-s037:
['unmutated', 'K189E', 'S145N', 'R229I', 'R220T', 'S205Y', 'N165H']
Viruses measured against SCH23-y2009-s002:
['unmutated', 'K189E', 'S145N', 'R229I', 'R220T', 'S205Y', 'N165H']
Viruses measured against SCH23-y2009-s007:
['unmutated', 'K189E', 'S145N', 'R229I', 'R220T', 'S205Y', 'N165H']
Viruses measured against SCH23-y2021-s056:
['unmutated', 'K140I']
Viruses measured against SCH23-y2016-s037:
['unmutated', 'K140I']
Viruses measured against SCH23-y2009-s002:
['unmutated', 'K140I']
Viruses measured against SCH23-y2009-s007:
['unmutated', 'K140I']


In [6]:
def make_neutcurve_df(fit):
    """Stack the replicate-averaged curve of every serum/virus pair into one table.

    Parameters
    ----------
    fit
        Object exposing ``sera`` (iterable of serum names), ``viruses``
        (mapping of serum name to an iterable of virus names) and
        ``getCurve(serum=..., virus=..., replicate=...)`` whose result offers
        a ``dataframe()`` method. In this notebook it is a
        ``neutcurve.CurveFits`` instance.

    Returns
    -------
    pandas.DataFrame
        Row-wise concatenation of each curve's ``dataframe()`` (every curve
        keeps its own index), with ``serum`` and ``virus`` label columns and
        ``upper`` / ``lower`` columns equal to ``measurement`` plus / minus
        ``stderr``.
    """
    # One labelled frame per (serum, virus) pair, in the order `fit` lists them.
    per_curve_frames = [
        fit.getCurve(serum=serum, virus=virus, replicate="average")
        .dataframe()
        .assign(serum=serum, virus=virus)
        for serum in fit.sera
        for virus in fit.viruses[serum]
    ]

    stacked = pd.concat(per_curve_frames, axis=0)

    # Error-bar bounds used by the plotting code.
    return stacked.assign(
        upper=stacked["measurement"] + stacked["stderr"],
        lower=stacked["measurement"] - stacked["stderr"],
    )

# Build the tidy curve table for each run.
fits1_df = make_neutcurve_df(fits_1)
fits2_df = make_neutcurve_df(fits_2)

# Only the last expression of a cell is rendered, so preview once, here.
fits2_df.head()

  (popt, pcov) = scipy.optimize.curve_fit(


Unnamed: 0,concentration,measurement,fit,stderr,serum,virus,upper,lower
0,2.62e-06,,0.993,,SCH23-y2021-s056,unmutated,,
1,2.72e-06,,0.993,,SCH23-y2021-s056,unmutated,,
2,2.82e-06,,0.993,,SCH23-y2021-s056,unmutated,,
3,2.93e-06,,0.992,,SCH23-y2021-s056,unmutated,,
4,3.04e-06,,0.992,,SCH23-y2021-s056,unmutated,,


In [7]:
def plot_neutcurve(df, colormap, n_cols=2):
    """Draw one neutralization-curve panel per serum and tile the panels in a grid.

    Every panel layers, coloured by virus, the fitted curve (``fit``), the
    measured points (``measurement``) and error bars spanning ``lower`` to
    ``upper``, all against ``concentration`` on a log-scaled x axis whose
    domain is that serum's own min/max concentration.

    Parameters
    ----------
    df : pandas.DataFrame
        Must provide the columns ``serum``, ``virus``, ``concentration``,
        ``fit``, ``measurement``, ``lower`` and ``upper``.
    colormap : dict
        Virus name -> colour. The keys become the colour-scale domain and the
        values its range.
    n_cols : int, optional
        Number of panels per grid row. Defaults to 2.

    Returns
    -------
    Altair chart
        Rows built with ``alt.hconcat`` and stacked with ``alt.vconcat``;
        panels are ordered by sorted serum name.

    Raises
    ------
    ValueError
        If ``n_cols`` is smaller than 1.
    """
    if n_cols < 1:
        raise ValueError(f"n_cols must be a positive integer, got {n_cols!r}")

    LINE_WIDTH = 1
    CIRCLE_SIZE = 40
    ERROR_BAR_OPACITY = 1
    PANEL_WIDTH = 175
    PANEL_HEIGHT = 125

    # Encodings that do not depend on the serum are built once, outside the loop.
    y_fit = alt.Y(
        "fit:Q",
        title="Fraction Infectivity",
        scale=alt.Scale(domain=[0, 1.3]),
        axis=alt.Axis(values=[0, 0.5, 1])
    )
    color_enc = alt.Color(
        "virus",
        scale=alt.Scale(domain=list(colormap.keys()), range=list(colormap.values())),
        title='Mutant'
    )

    charts = []
    for serum in sorted(df['serum'].unique()):
        serum_df = df[df['serum'] == serum]

        # Each panel's x domain is clamped to the concentrations of its serum.
        min_conc = serum_df['concentration'].min()
        max_conc = serum_df['concentration'].max()
        x_enc = alt.X(
            "concentration:Q",
            title="serum dilution",
            scale=alt.Scale(type="log", domain=[min_conc, max_conc]),
            axis=alt.Axis(format=".0e", tickCount=3)
        )

        base = alt.Chart(serum_df)

        line = base.mark_line(size=LINE_WIDTH).encode(
            x=x_enc,
            y=y_fit,
            color=color_enc
        )

        circle = base.mark_circle(size=CIRCLE_SIZE, opacity=1).encode(
            x=x_enc,
            y=alt.Y("measurement:Q", title="Fraction Infectivity"),
            color=color_enc
        )

        error = base.mark_errorbar(opacity=ERROR_BAR_OPACITY).encode(
            x=x_enc,
            y=alt.Y("lower", title="Fraction Infectivity"),
            y2="upper",
            color=color_enc
        )

        # Layer order: error bars at the back, then the fit, points on top.
        charts.append(
            (error + line + circle).properties(
                width=PANEL_WIDTH,
                height=PANEL_HEIGHT,
                title=alt.TitleParams(
                    text=serum,
                    fontSize=16,
                    fontWeight='bold',
                    anchor='middle'
                )
            )
        )

    # Tile the panels: n_cols per row, rows stacked vertically.
    rows = [
        alt.hconcat(*charts[start:start + n_cols])
        for start in range(0, len(charts), n_cols)
    ]
    return alt.vconcat(*rows)

In [8]:
# all curves for first run
colors = {
    'unmutated' : '#BAB0AC',
    'S205Y' : '#4E79A7',
    'N165H' : '#F28E2B',
    'R220T' : '#E15759',
    'R229I' : '#76B7B2',
    'S145N' : '#59A14F',
    'K189E' : '#B07AA1'
}

plot_neutcurve(
    fits1_df,
    colors
)

In [9]:
# just destabilizing mutants and unmutated
colors = {
    'unmutated' : '#BAB0AC',
    'S205Y' : '#4E79A7',
    'N165H' : '#F28E2B',
    'R220T' : '#E15759',
    'R229I' : '#76B7B2',
}

plot_neutcurve(
    fits1_df.query('virus in ["unmutated", "R229I", "N165H", "R220T", "S205Y"]'),
    colors
)

In [10]:
# K140I run
colors = {
    'unmutated':'#BAB0AC',
    'K140I': '#EDC948'
}

plot_neutcurve(fits2_df, colors)

In [11]:
def add_log2_fold_change(fit_params, reference_virus="unmutated"):
    """Return ``fit_params`` with an added ``log2_fold_change`` column.

    For every row the new column holds ``log2(ic50 / reference ic50)``, where
    the reference is the IC50 of ``reference_virus`` measured against the same
    serum.

    Parameters
    ----------
    fit_params : pandas.DataFrame
        Needs ``serum``, ``virus`` and ``ic50`` columns.
    reference_virus : str, optional
        Virus whose IC50 serves as the per-serum baseline.

    Returns
    -------
    pandas.DataFrame
        A new frame; ``fit_params`` itself is not modified.

    Raises
    ------
    ValueError
        If some serum has no row for ``reference_virus``.
    """
    is_reference = fit_params["virus"] == reference_virus
    # If a serum has several reference rows, the first one is the baseline.
    reference_ic50 = (
        fit_params.loc[is_reference]
        .drop_duplicates(subset="serum")
        .set_index("serum")["ic50"]
    )

    has_reference = fit_params["serum"].isin(reference_ic50.index)
    if not has_reference.all():
        sera_without_reference = (
            fit_params.loc[~has_reference, "serum"].unique().tolist()
        )
        raise ValueError(
            f"no {reference_virus!r} measurement for sera: {sera_without_reference}"
        )

    baseline_ic50 = fit_params["serum"].map(reference_ic50)
    return fit_params.assign(
        log2_fold_change=np.log2(fit_params["ic50"] / baseline_ic50)
    )


# IC50 fit parameters per run, each with fold change relative to the
# unmutated virus measured in that same run.
params_1 = add_log2_fold_change(fits_1.fitParams(ics=[50]))
params_2 = add_log2_fold_change(fits_2.fitParams(ics=[50]))

params = pd.concat([params_1, params_2], ignore_index=True)
mutations_measured = params.query('virus != "unmutated"')['virus'].drop_duplicates().tolist()

# Escape values restricted to the mutations that were assayed here; the
# mutation label is wildtype + site + mutant (e.g. K189E) to match `virus`.
escape_data = pd.read_csv(
    '../results/summaries/Phenotypes_per_antibody_escape.csv'
).drop(columns=['antibody_set']).rename(
    columns={'antibody' : 'serum'}
).assign(
    mutation=lambda x: x['wildtype'] + x['site'].astype(str) + x['mutant'],
)[['serum', 'site', 'wildtype', 'mutant', 'mutation', 'escape']].query(
    'mutation in @mutations_measured'
)

# Default inner join: (serum, virus) pairs without a matching escape row,
# including every 'unmutated' row, are dropped.
escape_and_params = pd.merge(
    params,
    escape_data,
    left_on=['serum', 'virus'],
    right_on=['serum', 'mutation']
)

escape_and_params.head()

Unnamed: 0,serum,virus,replicate,nreplicates,ic50,ic50_bound,ic50_str,midpoint,midpoint_bound,midpoint_bound_type,slope,top,bottom,r2,log2_fold_change,site,wildtype,mutant,mutation,escape
0,SCH23-y2021-s056,K189E,average,2,0.000523,interpolated,0.000523,0.000523,0.000523,interpolated,1.38,1,0,0.985,2.36,189,K,E,K189E,0.534
1,SCH23-y2021-s056,S145N,average,2,0.000139,interpolated,0.000139,0.000139,0.000139,interpolated,1.89,1,0,0.992,0.447,145,S,N,S145N,0.0838
2,SCH23-y2021-s056,R229I,average,2,0.000108,interpolated,0.000108,0.000108,0.000108,interpolated,1.85,1,0,0.997,0.0816,229,R,I,R229I,-0.153
3,SCH23-y2021-s056,R220T,average,2,6.22e-05,interpolated,6.22e-05,6.22e-05,6.22e-05,interpolated,1.46,1,0,0.99,-0.713,220,R,T,R220T,-0.395
4,SCH23-y2021-s056,S205Y,average,2,9.67e-05,interpolated,9.67e-05,9.67e-05,9.67e-05,interpolated,1.46,1,0,0.968,-0.0774,205,S,Y,S205Y,-0.174


In [12]:
# Palette for the assayed mutants, shared by the fill and stroke channels.
colors = dict(
    S205Y='#4E79A7',
    N165H='#F28E2B',
    R220T='#E15759',
    R229I='#76B7B2',
    S145N='#59A14F',
    K189E='#B07AA1',
    K140I='#EDC948',
)
mutant_scale = alt.Scale(domain=list(colors.keys()), range=list(colors.values()))

# Correlation (pandas default method) between measured fold change and DMS
# escape, pooled over all rows of escape_and_params.
r_value = escape_and_params['log2_fold_change'].corr(escape_and_params['escape'])
r_text = f"r = {r_value:.2f}"

# Dashed gray reference lines through y = 0 and x = 0.
rule_style = dict(color='gray', size=1.25, opacity=1, strokeDash=[5, 5])
hline = alt.Chart().mark_rule(**rule_style).encode(y=alt.Y(datum=0))
vline = alt.Chart().mark_rule(**rule_style).encode(x=alt.X(datum=0))

# One point per (serum, mutation): colour encodes the mutation, shape the serum.
chart = (
    alt.Chart(escape_and_params)
    .mark_point(size=60, opacity=0.8)
    .encode(
        x=alt.X(
            'log2_fold_change:Q',
            title='log₂ IC50 fold change',
            axis=alt.Axis(grid=False),
        ),
        y=alt.Y(
            'escape:Q',
            title='DMS escape effect',
            axis=alt.Axis(grid=False),
        ),
        fill=alt.Fill("mutation:N", scale=mutant_scale, title='Mutant'),
        color=alt.Color("mutation:N", scale=mutant_scale, title='Mutant'),
        shape='serum:N',
    )
)

# Correlation label pinned 5 px from the top-left corner of the plot area.
r_label = (
    alt.Chart(pd.DataFrame({'text': [r_text]}))
    .mark_text(
        align='left',
        baseline='top',
        fontSize=16,
        fontWeight='normal',
        color='black',
    )
    .encode(text='text:N', x=alt.value(5), y=alt.value(5))
)

scatter_plot = alt.layer(vline, hline, r_label, chart)
scatter_plot.properties(width=220, height=220)