In [1]:
# Fit neutralization curves to the validation data, extract IC50 values, and compare them with the DMS escape data.

In [None]:
import pandas as pd
import neutcurve
import altair as alt
import re
import httpimport
import numpy as np
import scipy.stats

In [None]:
# Import custom altair theme from remote github using httpimport module
def import_theme_new():
    """Register and enable the remote ``main_theme`` as Altair's "custom_theme".

    The ``main_theme`` module is imported through ``httpimport.github_repo``
    using the arguments "bblarsen-sci", "altair_themes" and "main"
    (presumably user, repository and branch/ref — confirm), so running this
    cell presumably requires network access.

    NOTE(review): "main" looks like a moving ref rather than a fixed commit,
    so the code that gets imported and executed here may change between
    runs — consider pinning it.
    """
    with httpimport.github_repo("bblarsen-sci", "altair_themes", "main"):
        import main_theme

        # Registers the function below under the name "custom_theme";
        # enable=True is passed so that it is also switched on.
        @alt.theme.register("custom_theme", enable=True)
        def custom_theme():
            return main_theme.main_theme()


import_theme_new()

In [None]:
# Load the neutralization measurements. The curve fitting below reads the
# columns "serum", "virus", "replicate", "concentration" and
# "fraction infectivity" from this table.
# NOTE(review): `snakemake` is not defined anywhere in this notebook; it is
# presumably injected by Snakemake when the notebook runs as part of a rule.
neut_df = pd.read_csv(snakemake.input.neut_data)

In [None]:
def fit_neutcurve(df):
    """Fit neutralization curves to the measurements in ``df``.

    ``df`` is handed to ``neutcurve.CurveFits`` together with the names of the
    columns that hold the serum, virus, replicate, concentration and fraction
    infectivity values. The resulting ``CurveFits`` object is returned.
    """
    column_names = {
        "serum_col": "serum",
        "virus_col": "virus",
        "replicate_col": "replicate",
        "conc_col": "concentration",
        "fracinf_col": "fraction infectivity",
    }
    return neutcurve.CurveFits(data=df, **column_names)


fit = fit_neutcurve(neut_df)

In [None]:
# Request the fit parameters with IC50, IC90 and IC99, then keep only the
# identifying columns plus the IC50 and its bound.
fit_params = fit.fitParams(ics=[50, 90, 99])
ic50_fits = fit_params.loc[
    :,
    ["serum", "virus", "ic50", "ic50_bound"],
]

# Show the table in the notebook and write it to the Snakemake output.
display(ic50_fits)
ic50_fits.to_csv(snakemake.output.ic50_data, index=False)

In [None]:
# Collect the replicate-averaged curve for each serum/virus combination and
# label every row with the combination it belongs to.
curves = []
for (serum, virus), _ in neut_df.groupby(["serum", "virus"]):
    curve = fit.getCurve(serum=serum, virus=virus, replicate="average")
    curves.append(curve.dataframe().assign(serum=serum, virus=virus))

# ignore_index=True gives every row a unique label; without it each per-curve
# frame contributes its own index, so labels repeat across curves and any
# index-aligned operation on the combined table becomes fragile.
neutcurve_df = pd.concat(curves, ignore_index=True)

# Error-bar limits: measurement +/- its standard error.
neutcurve_df["upper"] = neutcurve_df["measurement"] + neutcurve_df["stderr"]
neutcurve_df["lower"] = neutcurve_df["measurement"] - neutcurve_df["stderr"]

In [None]:
def custom_sort_order(array):
    """Return the names in ``array`` as a list ordered for plot legends.

    Names are sorted by the first run of digits they contain (for a mutation
    string such as "A100B" that is 100); names without any digits sort as 0.
    The sort is stable, so names with equal numbers keep their input order.
    If "Unmutated" is present, it is moved to the front of the list.

    ``array`` may be any iterable of strings; the input itself is not modified.
    """

    def extract_number(virus):
        # First run of digits as an integer, or 0 when there are no digits.
        digits = re.search(r"\d+", virus)
        return int(digits.group()) if digits else 0

    ordered = sorted(array, key=extract_number)

    # Make sure "Unmutated" leads the list regardless of where sorting put it.
    if "Unmutated" in ordered:
        ordered.remove("Unmutated")
        ordered.insert(0, "Unmutated")
    return ordered

# Palette of six hex colors used for the plotted viruses. The plotting code
# takes a leading slice of this list and puts "black" in front of it, so that
# the first entry of the color domain is drawn in black.
colors_for_plot = [
    "#d1603d",
    "#e3a857",
    "#7ca982",
    "#49759c",
    "#8b5e83",
    "#e27d60",
]


def custom_sort_key(item):
    """Return the integer directly preceding the first "min" in ``item``.

    Strings that contain no "<digits>min" pattern yield 0.
    """
    found = re.search(r"(\d+)min", item)
    return int(found.group(1)) if found else 0

# Build the color range from the number of distinct viruses
def adjust_colors(merged):
    """Return "black" followed by a leading slice of ``colors_for_plot``.

    The slice length is the number of distinct values in ``merged["virus"]``
    minus one, so the list has one color per virus as long as the palette is
    long enough.
    """
    n_viruses = len(merged["virus"].unique())
    palette_slice = colors_for_plot[: n_viruses - 1]
    return ["black", *palette_slice]


In [None]:
def plot_neut_curve(df):
    """Build the layered neutralization-curve chart colored by virus.

    ``df`` must provide the columns "concentration", "fit", "measurement",
    "lower", "upper" and "virus". Three layers are drawn on a log-scaled
    concentration axis: error bars spanning "lower" to "upper", a line for
    "fit" and circles for "measurement". Only "virus" is encoded as color;
    no other grouping column is used, so all rows share one panel.

    The color domain is the virus names as ordered by ``custom_sort_order``.
    When "Unmutated" is among them it comes first and is drawn in black, with
    the remaining viruses taking colors from ``colors_for_plot``; otherwise
    every virus takes its color from ``colors_for_plot``.

    Returns the layered Altair chart (150 x 100 px).
    """
    color_variable = "virus"
    concentration_title = "Concentration (µg/mL)"
    infectivity_title = "Fraction infectivity"

    # Define the axis and scale for the plot
    x_axis = alt.Axis(format=".0e", tickCount=3)
    scale = alt.Scale(
        type="log",
    )

    # Define the color scale. It is always built, so that data without an
    # "Unmutated" virus can be plotted too instead of failing on an
    # unassigned variable further down.
    domain = custom_sort_order(df["virus"].unique())
    if "Unmutated" in domain:
        # "Unmutated" is first in the domain, so it receives black.
        colors = ["black"] + colors_for_plot[: len(domain) - 1]
    else:
        colors = colors_for_plot[: len(domain)]
    color_scale = alt.Color(
        color_variable,
        scale=alt.Scale(domain=domain, range=colors),
    )

    # Line layer: the inferred (fitted) fraction infectivity.
    line = (
        alt.Chart(df)
        .mark_line(size=2)
        .encode(
            x=(alt.X("concentration:Q").scale(scale).axis(x_axis).title(concentration_title)),
            y=(
                alt.Y("fit:Q")
                .title(infectivity_title)
                .scale(alt.Scale(domain=[0, 1]))
                .axis(alt.Axis(values=[0, 0.5, 1]))
            ),
            color=color_scale,
        )
    )

    # Circle layer: the measured fraction infectivity.
    circle = (
        alt.Chart(df)
        .mark_circle(size=40, opacity=1)
        .encode(
            x=(alt.X("concentration").scale(scale).axis(x_axis).title(concentration_title)),
            y=(alt.Y("measurement:Q").title(infectivity_title)),
            color=color_scale,
        )
    )

    # Error-bar layer: the "lower" to "upper" range around each measurement.
    error = (
        alt.Chart(df)
        .mark_errorbar(opacity=1)
        .encode(
            x="concentration",
            y=alt.Y("lower").title(infectivity_title),
            y2="upper",
            color=color_scale,
        )
    )

    # Combine the layers (error bars at the back, circles on top) and set the size.
    plot = error + line + circle
    plot = plot.properties(width=150, height=100)
    return plot


# Draw the neutralization curves for every row of `neutcurve_df` and show the
# chart in the notebook; the chart object is reused for the combined figure.
neut_curves_plot = plot_neut_curve(
    neutcurve_df
)
neut_curves_plot.display()

In [None]:
# Correlate the validation IC50 values with the DMS escape values.
# Read the 4H3 escape table, keep the columns needed here, and add a "virus"
# label (wildtype + site + mutant) so it can be joined to the IC50 table.
escape_df = (
    pd.read_csv(snakemake.input.escape_df_4H3)
    .loc[
        :,
        [
            "site",
            "wildtype",
            "mutant",
            "escape_mean",
            "escape_std",
            "times_seen_ab",
            "effect",
        ],
    ]
    .assign(
        virus=lambda tbl: tbl["wildtype"] + tbl["site"].astype(str) + tbl["mutant"]
    )
)

# Left join: every row of the IC50 table is kept, whether or not a matching
# escape row exists.
merged_df = ic50_fits.merge(escape_df, on="virus", how="left")


# The unmutated virus gets an escape value of zero
def set_unmutated_escape_to_zero(merged):
    """Set "escape_mean" to 0 on every row whose "virus" is "Unmutated".

    ``merged`` is modified in place and the same object is returned.
    """
    is_unmutated = merged["virus"] == "Unmutated"
    merged.loc[is_unmutated, ["escape_mean"]] = 0
    return merged

# The helper edits its argument in place and returns it, so `merged_df_4H3`
# refers to the same DataFrame object as `merged_df`.
merged_df_4H3 = set_unmutated_escape_to_zero(merged_df)


# Calculate ic50 values for each mutation relative to the unmutated value
def calculate_relative_ic50(merged_df):
    """Add IC50 fold-change columns relative to the "Unmutated" virus.

    The reference is the "ic50" of the first row whose "virus" is
    "Unmutated". Two columns are written to ``merged_df`` (in place):

    * "relative_ic50": ic50 / reference, rounded to 3 decimals.
    * "log2_relative_ic50": log2 of the unrounded ratio. Using the unrounded
      value keeps full precision and avoids -inf for ratios that would round
      to 0.000.

    Returns the same ``merged_df`` object.

    Raises:
        ValueError: if ``merged_df`` has no row for the "Unmutated" virus.

    NOTE(review): if several "Unmutated" rows are present, only the first one
    is used as the reference — confirm that is intended.
    """
    unmutated_values = merged_df.loc[merged_df["virus"] == "Unmutated", "ic50"]
    if unmutated_values.empty:
        raise ValueError(
            'no "Unmutated" row found; cannot compute IC50 values relative to it'
        )
    unmutated_ic50 = unmutated_values.values[0]
    print(f"unmutated_value: {unmutated_ic50:.3f}")

    ratio = merged_df["ic50"] / unmutated_ic50
    merged_df["relative_ic50"] = ratio.round(3)

    # log2 fold change is taken from the exact ratio, not the rounded column.
    merged_df["log2_relative_ic50"] = np.log2(ratio)
    return merged_df

# Adds the "relative_ic50" and "log2_relative_ic50" columns; the helper writes
# them into the DataFrame it is given and returns that same object.
merged_df_4H3 = calculate_relative_ic50(merged_df_4H3)


##### calculate R value w log2 ic50 values:
def calculate_r_value_ic50(merged):
    """Return the correlation coefficient of escape vs. log2 relative IC50.

    A linear regression of "log2_relative_ic50" on "escape_mean" is run with
    ``scipy.stats.linregress`` and its r value is printed and returned as a
    float.

    Only rows where both values are present and finite are used: a single
    missing escape value (for example a virus with no matching escape row)
    would otherwise turn the whole result into NaN. The number of excluded
    rows is printed whenever it is non-zero. ``merged`` is not modified.
    """
    columns = ["escape_mean", "log2_relative_ic50"]
    paired = merged[columns].replace([np.inf, -np.inf], np.nan).dropna()

    n_excluded = len(merged) - len(paired)
    if n_excluded:
        print(
            f"Excluded {n_excluded} row(s) without finite escape and IC50 values"
        )

    regression = scipy.stats.linregress(
        paired["escape_mean"], paired["log2_relative_ic50"]
    )
    r_value = float(regression.rvalue)
    print(f"IC50_r_value: {r_value:.2f}")
    return r_value

# Correlation between DMS escape and log2 relative IC50; passed to the scatter
# plot below, which prints it as an "r = ..." label.
r_value_ic50 = calculate_r_value_ic50(merged_df_4H3)

In [None]:
# Scatter plot of DMS escape against validation IC50
def make_altair_chart(merged, r_value, log=False):
    """Build the escape-vs-IC50 scatter plot with an "r = ..." text label.

    ``merged`` supplies "escape_mean" (x), "ic50" (y) and "virus" (color);
    the tooltip additionally shows "log2_relative_ic50". ``r_value`` is
    formatted to two decimals for the label, which is placed at the fixed
    data coordinates x=0.1, y=5. ``log`` switches the IC50 axis from a linear
    to a log scale.

    Returns the layered chart (points plus label).
    """
    y_scale = alt.Scale(type="log" if log else "linear")

    # Same ordering helper and color helper as used for the color legend:
    # domain from custom_sort_order, range from adjust_colors.
    virus_color = alt.Color(
        "virus",
        scale=alt.Scale(
            domain=custom_sort_order(merged["virus"].unique()),
            range=adjust_colors(merged),
        ),
    )

    points = alt.Chart(merged).mark_circle(
        size=120, opacity=1, stroke="black", strokeWidth=1
    )
    points = points.encode(
        x=alt.X(
            "escape_mean",
            title=["Escape measured in", "deep mutational scanning"],
        ),
        y=alt.Y(
            "ic50",
            title=["IC50 in validation", "neutralizations (µg/mL)"],
            scale=y_scale,
        ),
        tooltip=["virus", "escape_mean", "log2_relative_ic50"],
        color=virus_color,
    )

    # Single-row inline dataset carrying the position and text of the label.
    label_data = {"values": [{"x": 0.1, "y": 5, "text": f"r = {r_value:.2f}"}]}
    label = (
        alt.Chart(label_data)
        .mark_text(dx=0, dy=0)
        .encode(x=alt.X("x:Q"), y=alt.Y("y:Q"), text="text:N")
    )

    return alt.layer(points, label)


# Create the escape-vs-IC50 scatter plot for 4H3; log=True puts the IC50 axis
# on a log scale.
ic50_chart_4H3 = make_altair_chart(merged_df_4H3, r_value_ic50, log=True)


display(ic50_chart_4H3)


In [None]:
# Place the neutralization curves and the escape-vs-IC50 scatter plot side by
# side in one figure.
combined_chart = alt.hconcat(neut_curves_plot, ic50_chart_4H3)

display(combined_chart)

In [None]:

# Write the combined figure to both Snakemake outputs; ppi=300 is passed only
# for the PNG output.
# NOTE(review): the format of the second output is presumably chosen from the
# file extension of its path — confirm against the rule's output definition.
combined_chart.save(snakemake.output.escape_validation_plot_png, ppi=300)
combined_chart.save(snakemake.output.escape_validation_plot)