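# Fit experimental neutralization curves and correlate IC50s with DMS data

This notebook reads experimentally measured fraction-infectivity data, fits and plots neutralization curves, and correlates the fitted IC50 values with deep mutational scanning (DMS) escape scores.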

In [None]:
# this cell is tagged as parameters for `papermill` parameterization
# Every parameter defaults to None. When `nah1_validation_neut_curves` is still None
# the "running interactively" cell below fills in default input paths instead.
# Keep these as plain `name = value` assignments so papermill can parse them.

# Inputs: Altair theme file and project YAML config
altair_config=None
nipah_config=None

# Inputs: experimental neutralization data and DMS escape scores
neut = None
escape_file = None

# Outputs: file paths the validation charts are written to
nah1_validation_neut_curves = None
IC50_validation_plot = None
combined_ic50_neut_curve_plot = None

In [None]:
import math
import os
import re

import altair as alt

import numpy as np

import pandas as pd

import scipy.stats

import yaml

import neutcurve

import scipy.stats
# Record which neutcurve release produced the fits below.
print("Using `neutcurve` version {}".format(neutcurve.__version__))

In [None]:
# allow more rows for Altair
_ = alt.data_transformers.disable_max_rows()

# Project root that the relative input/output paths in this notebook are resolved
# against. It can be overridden with the NIPAH_PROJECT_DIR environment variable so
# the notebook can run from a checkout other than the original cluster location.
PROJECT_DIR = os.environ.get(
    "NIPAH_PROJECT_DIR",
    "/fh/fast/bloom_j/computational_notebooks/blarsen/2023/Nipah_Malaysia_RBP_DMS/",
)

# os.getcwd() never includes a trailing slash, so a raw string comparison against
# PROJECT_DIR would never match; compare fully resolved paths instead.
if os.path.realpath(os.getcwd()) == os.path.realpath(PROJECT_DIR):
    print("Already in correct directory")
else:
    os.chdir(PROJECT_DIR)
    print("Setup in correct directory")

### Default inputs for running interactively (without papermill)

In [None]:
if nah1_validation_neut_curves is None:
    # Not launched by papermill: fall back to the project's default inputs.

    # Project-wide configuration
    nipah_config = 'nipah_config.yaml'
    altair_config = 'data/custom_analyses_data/theme.py'

    # nAH1.3 antibody validation inputs
    neut = 'data/custom_analyses_data/experimental_data/nAH1_3_mab_validation_neuts.csv'
    escape_file = 'results/antibody_escape/averages/nAH1.3_mut_effect.csv'

    # Ephrin receptor validation inputs (currently referenced only from the
    # commented-out ephrin cells further down, if at all)
    ephrin_binding_neuts_file = 'data/custom_analyses_data/experimental_data/bat_ephrin_neuts.csv'
    ephrin_validation_curves = 'data/custom_analyses_data/experimental_data/binding_single_mutant_validations.csv'
    validation_ic50s_file = 'data/custom_analyses_data/experimental_data/receptor_IC_validations.csv'
    e2_monomeric_binding_file = 'results/receptor_affinity/averages/EFNB2_monomeric_mut_effect.csv'
    e3_dimeric_binding_file = 'results/receptor_affinity/averages/EFNB3_dimeric_mut_effect.csv'

### Read in config files

In [None]:
# Load the Altair theme (if given) and the project-wide YAML settings.
if altair_config:
    # NOTE(review): exec() runs the theme file as Python code in this notebook's
    # global namespace; it must only ever point at the trusted project theme file.
    with open(altair_config, 'r') as file:
        exec(file.read())

# `config` supplies plot settings read later, e.g. neut_curve_height / neut_curve_width.
with open(nipah_config) as f:
    config = yaml.safe_load(f)

### Function for fitting neutralization curves to experimental data

In [None]:
def get_neutcurve(df, replicate="average", ics=(50, 90, 95, 97, 98, 99), fixbottom=0):
    """Fit neutralization curves with `neutcurve` and tabulate the fitted curves.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format neutralization data with columns ``antibody``, ``virus``,
        ``replicate``, ``concentration`` and ``fraction infectivity``.
    replicate : str
        Replicate whose curve is tabulated (default ``"average"``).
    ics : sequence of int
        Inhibitory-concentration percentages passed to ``CurveFits.fitParams``.
    fixbottom : float or bool
        Passed straight through to ``neutcurve.curvefits.CurveFits``
        (this function previously hard-coded 0, which stays the default).

    Returns
    -------
    fitParams : pandas.DataFrame
        Output of ``CurveFits.fitParams(ics=...)``.
    neut_curve_df : pandas.DataFrame
        Concatenated ``curve.dataframe()`` tables for every measured
        antibody/virus pair, with added ``antibody`` and ``virus`` columns and
        ``upper`` / ``lower`` = ``measurement`` +/- ``stderr``.
    """
    fits = neutcurve.curvefits.CurveFits(
        data=df,
        serum_col="antibody",
        virus_col="virus",
        replicate_col="replicate",
        conc_col="concentration",
        fracinf_col="fraction infectivity",
        fixbottom=fixbottom,
    )
    # Neutcurve parameters for the requested IC values
    fitParams = fits.fitParams(ics=list(ics))

    # Iterate only over antibody/virus pairs that actually have data. The full
    # cross product would request curves for pairs that were never measured
    # (e.g. when different antibodies were tested against different viruses).
    measured_pairs = df[["antibody", "virus"]].drop_duplicates()

    curve_list = []
    for antibody, virus in measured_pairs.itertuples(index=False, name=None):
        curve = fits.getCurve(serum=antibody, virus=virus, replicate=replicate)
        neut_df = curve.dataframe()
        neut_df["antibody"] = antibody
        neut_df["virus"] = virus
        curve_list.append(neut_df)

    neut_curve_df = pd.concat(curve_list, ignore_index=True)
    neut_curve_df["upper"] = neut_curve_df["measurement"] + neut_curve_df["stderr"]
    neut_curve_df["lower"] = neut_curve_df["measurement"] - neut_curve_df["stderr"]

    return fitParams, neut_curve_df


### Make neut curve plot

In [None]:
# Sorting function to put 'WT' on top of the legend, followed by numerical order
def custom_sort_order(array):
    """Order virus/mutation labels for plot legends and colour domains.

    Labels are sorted by the first integer they contain (the site number, e.g.
    ``530`` in ``"Q530F"``); labels without any digits sort as 0. Labels at the
    same site are ordered alphabetically, so the result does not depend on the
    input order. ``"WT"`` is moved to the front whenever it is present.

    Parameters
    ----------
    array : iterable of str
        Labels to order, e.g. ``df["virus"].unique()``.

    Returns
    -------
    list of str
        A new sorted list; the input is not modified.
    """

    def site_then_label(label):
        # (site number, full label): the label breaks ties between mutants at one site
        match = re.search(r"\d+", label)
        return (int(match.group()) if match else 0, label)

    ordered = sorted(array, key=site_then_label)

    # Move 'WT' to the beginning of the list
    if "WT" in ordered:
        ordered.remove("WT")
        ordered.insert(0, "WT")
    return ordered


def plot_neut_curves(df, neut_curve_name, color_name):
    """Layer fitted neutralization curves, measured points and error bars.

    Parameters
    ----------
    df : pandas.DataFrame
        Curve table as returned by ``get_neutcurve`` (columns ``concentration``,
        ``fit``, ``measurement``, ``lower``, ``upper`` and ``virus``).
    neut_curve_name : str
        X-axis title (concentration label including units).
    color_name : str
        Legend title for the ``virus`` colour encoding.

    Returns
    -------
    altair.LayerChart
        Fitted lines + measured points + error bars on a log concentration axis,
        sized from ``config["neut_curve_height"]`` / ``config["neut_curve_width"]``.
    """
    # Define the category10 colors manually
    category10_colors = [
        "#4E79A5",
        "#F18F3B",
        "#E0585B",
        "#77B7B2",
        "#5AA155",
        "#EDC958",
        "#AF7AA0",
        "#FE9EA8",
        "#9C7561",
        "#BAB0AC",
    ]

    # Compute the legend order once; the first entry (WT when present, since
    # custom_sort_order moves it to the front) is drawn in black.
    domain = custom_sort_order(df["virus"].unique())
    colors = ["black"] + category10_colors[: len(domain) - 1]

    # One set of shared encodings so every layer -- including the error bars --
    # uses the same explicit colour scale and log concentration axis.
    color = alt.Color(
        "virus",
        title=color_name,
        scale=alt.Scale(domain=domain, range=colors),
    )
    x = alt.X(
        "concentration:Q",
        scale=alt.Scale(type="log"),
        axis=alt.Axis(format=".0e", tickCount=3),
        title=neut_curve_name,
    )
    y_axis = alt.Axis(tickCount=3)

    base = alt.Chart(df)
    line = base.mark_line(size=1.5, opacity=1).encode(
        x=x,
        y=alt.Y("fit:Q", title="Fraction Infectivity", axis=y_axis),
        color=color,
    )
    points = base.mark_circle(size=40, opacity=1).encode(
        x=x,
        y=alt.Y("measurement:Q", title="Fraction Infectivity", axis=y_axis),
        color=color,
    )
    error = base.mark_errorbar(opacity=1).encode(
        x=x,
        y=alt.Y("lower:Q", title="Fraction Infectivity"),
        y2="upper",
        color=color,
    )

    return (line + points + error).properties(
        height=config["neut_curve_height"],
        width=config["neut_curve_width"],
    )


### Now calculate correlations with DMS data and plot

In [None]:
def plot_ic50_correlations(fitparams, name, escape):
    """Scatter fitted IC50s against DMS escape scores and annotate Pearson r.

    Parameters
    ----------
    fitparams : pandas.DataFrame
        Fit parameters (``get_neutcurve`` output) with ``virus``, ``ic50`` and
        ``ic50_bound`` columns. The caller's frame is not modified.
    name : str
        Y-axis title.
    escape : pandas.DataFrame
        DMS escape table with ``mutation`` and ``escape_median`` columns.

    Returns
    -------
    altair.LayerChart
        Points (log-scaled IC50 vs. escape score; shape marks IC50s whose
        ``ic50_bound`` is ``"lower"``) plus a text label with Pearson r.

    Notes
    -----
    Viruses absent from ``escape`` are dropped by the inner merge, except ``WT``,
    which is appended with an escape score of 0. r is computed on linear IC50
    values even though the y axis is log-scaled.
    """
    # Work on a copy so the caller's fit-parameter table is not mutated.
    params = fitparams.copy()
    params["lower_bound"] = params["ic50_bound"] == "lower"
    params["mutation"] = params["virus"]

    # Merge with DMS escape data, then append WT so it has an escape score of 0
    merged = params.merge(escape, on=["mutation"])
    wt_rows = params[params["mutation"] == "WT"].copy()
    wt_rows["escape_median"] = 0
    merged = pd.concat([merged, wt_rows], ignore_index=True)

    # Pearson r of escape score vs. IC50. Rows missing either value are skipped,
    # and r is reported as NaN when no regression is possible (fewer than two
    # points, or all escape scores identical, which makes linregress raise).
    valid = merged[["escape_median", "ic50"]].dropna()
    if len(valid) >= 2 and valid["escape_median"].nunique() > 1:
        r_value = scipy.stats.linregress(valid["escape_median"], valid["ic50"]).rvalue
    else:
        r_value = float("nan")

    # Define the category10 colors manually
    category10_colors = [
        "#4E79A5",
        "#F18F3B",
        "#E0585B",
        "#77B7B2",
        "#5AA155",
        "#EDC958",
        "#AF7AA0",
        "#FE9EA8",
        "#9C7561",
        "#BAB0AC",
    ]

    # Legend order computed once; the first entry (WT when present) is black
    domain = custom_sort_order(merged["mutation"].unique())
    colors = ["black"] + category10_colors[: len(domain) - 1]

    corr_chart = (
        alt.Chart(merged)
        .mark_point(size=125)
        .encode(
            x=alt.X(
                "escape_median", title="DMS escape score", axis=alt.Axis(grid=True)
            ),
            y=alt.Y(
                "ic50",
                title=f"{name}",
                scale=alt.Scale(type="log"),
                axis=alt.Axis(grid=True),
            ),
            color=alt.Color(
                "mutation",
                title="Mutant",
                scale=alt.Scale(domain=domain, range=colors),
            ),
            shape=alt.Shape("lower_bound", title="Lower Bound"),
        )
    )

    # r annotation anchored at the top-left of the data range. A DataFrame (rather
    # than an inline dict) lets Altair sanitise NaN coordinates for empty input.
    label = pd.DataFrame(
        {
            "x": [merged["escape_median"].min()],
            "y": [merged["ic50"].max()],
            "text": [f"r = {r_value:.2f}"],
        }
    )
    text = (
        alt.Chart(label)
        .mark_text(align="left", baseline="top", dx=5)  # Adjust this for position
        .encode(x=alt.X("x:Q"), y=alt.Y("y:Q"), text="text:N")
    )
    return corr_chart + text


### Combine curve fitting, plotting, and DMS correlation in a single function

In [None]:
def get_neut_curve_ic50_correlations(
    raw_data, escape, neut_curve_name, name, color_name
):
    """Fit curves for ``raw_data`` and build the neut-curve and IC50-correlation charts.

    Prints and displays the fitted IC50s (scaled by 1000 into ``ic50_ng``) before
    building the charts.

    Parameters
    ----------
    raw_data : pandas.DataFrame
        Neutralization data in the format expected by ``get_neutcurve``.
    escape : pandas.DataFrame
        DMS escape table passed to ``plot_ic50_correlations``.
    neut_curve_name : str
        X-axis title of the neutralization-curve chart.
    name : str
        Y-axis title of the IC50 correlation chart.
    color_name : str
        Legend title of the neutralization-curve chart.

    Returns
    -------
    tuple
        ``(neut curve chart, IC50 correlation chart, both charts side by side)``.
    """
    params, curves = get_neutcurve(raw_data)

    # x1000 unit change; the column name suggests μg/ml inputs giving ng/ml.
    params["ic50_ng"] = params["ic50"].mul(1000)
    print("Here are the ic50 values:")
    display(params.loc[:, ["serum", "virus", "ic50_ng"]].round(2))

    curve_chart = plot_neut_curves(curves, neut_curve_name, color_name)
    correlation_chart = plot_ic50_correlations(params, name, escape)
    return curve_chart, correlation_chart, curve_chart | correlation_chart


### Run for the nAH1.3 antibody validations

In [None]:
# DMS escape scores
escape = pd.read_csv(escape_file)

# Experimental neutralization data; Y455M is excluded because it is not
# present in the escape data.
neuts = pd.read_csv(neut)
neuts = neuts.loc[neuts["virus"].ne("Y455M")]


def fix_neut_df(df):
    """Keep only the columns `get_neutcurve` needs and rename ``serum`` to ``antibody``."""
    wanted = ["serum", "virus", "replicate", "concentration", "fraction infectivity"]
    return df.loc[:, wanted].rename(columns={"serum": "antibody"})


# Fix naming in dataframe
nAH1_validation_neuts = fix_neut_df(neuts)

# Calculate neut curve, ic50 correlations with DMS data, and combine figures
nAH1_curve_plot, nAH1_ic50_correlations, nAH1_combined = (
    get_neut_curve_ic50_correlations(
        nAH1_validation_neuts,
        escape,
        "nAH1.3 Concentration (μg/ml)",
        "nAH1.3 IC₅₀ (μg/ml)",
        "Mutant",
    )
)

# Display each chart and save it only when an output path was supplied for that
# specific chart, so a missing individual path never reaches chart.save(None).
for chart, out_path in [
    (nAH1_curve_plot, nah1_validation_neut_curves),
    (nAH1_ic50_correlations, IC50_validation_plot),
    (nAH1_combined, combined_ic50_neut_curve_plot),
]:
    chart.display()
    if out_path is not None:
        chart.save(out_path)


### Now do for ephrin neutralization (currently disabled: the cells below are commented out)

In [None]:
#df = pd.read_csv(ephrin_binding_neuts_file)

#df = df[df['serum'] == 'CHO-EFNB3'] # only plot CHO-EFNB3 data
#df = df.rename(columns={'serum':'antibody'}) #rename 
#df['virus'] = df['virus'].replace({'E2-dimeric': 'EFNB2-dimeric', 'E2-monomeric': 'EFNB2-monomeric','E3-dimeric': 'EFNB3-dimeric', 'E3-monomeric': 'EFNB3-monomeric'}) #fix names

# Get neutcurve
#fit_df, neut_df = get_neutcurve(df)
#fit_df['ic50_nM'] = fit_df['ic50'] * 1000
#print('Here are the ic50 values:')
#display(fit_df[['serum','virus','ic50_nM']].round(2))
#
#ephrin_neut_curve = plot_neut_curves(neut_df,'Concentration (μM)','Receptor')
#ephrin_neut_curve.display()

In [None]:
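# NOTE: disabled draft. As written it reuses the nAH1.3 axis labels and displays/saves the
# nAH1_* charts to the nAH1.3 output paths, so update those before re-enabling this cell.
#ephrin_validation = pd.read_csv(ephrin_validation_curves)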
## Fix naming in dataframe
#ephrin_validation_neuts = fix_neut_df(ephrin_validation)
#
## Calculate neut curve, ic50 correlations with DMS data, and combine figures
#ephrin_curve_plot, ephrin_ic50_correlations, ephrin_combined = get_neut_curve_ic50_correlations(ephrin_validation_neuts, escape, 'nAH1.3 Concentration (μg/ml)','nAH1.3 IC₅₀ (μg/ml)','Mutant')
#
#nAH1_curve_plot.display()
#if combined_ic50_neut_curve_plot is not None:
#    nAH1_curve_plot.save(nah1_validation_neut_curves)
#
#nAH1_ic50_correlations.display()
#if combined_ic50_neut_curve_plot is not None:
#    nAH1_ic50_correlations.save(IC50_validation_plot)
#
#nAH1_combined.display()
#if combined_ic50_neut_curve_plot is not None:
#    nAH1_combined.save(combined_ic50_neut_curve_plot)

In [None]:
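# NOTE: disabled draft. `validation_ic50s` is never defined (only `validation_ic50s_file`
# is set), so it must be read in before this cell can be re-enabled.
#e2_monomeric_binding = pd.read_csv(e2_monomeric_binding_file)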
#e3_dimeric_binding = pd.read_csv(e3_dimeric_binding_file)
#
#def make_df(df,name):
#    merged = validation_ic50s.merge(df,on=['mutation'])
#    wt_rows = validation_ic50s[validation_ic50s['mutation'] == 'WT'].copy()
#    wt_rows['Ephrin binding_median'] = 0.00000
#    merged = pd.concat([merged, wt_rows], ignore_index=True)
#    df_tmp = merged[merged['antibody'] == name]
#    return df_tmp
#
#e2_df_out = make_df(e2_monomeric_binding,'EFNB2-monomeric')
#e3_df_out = make_df(e3_dimeric_binding,'EFNB3-dimeric')