In [None]:
import warnings

# import math
# from IPython.display import display, HTML, SVG
import pandas as pd
import neutcurve
import altair as alt
import re
import os

# print(f"Using `neutcurve` version {neutcurve.__version__}")
import sys

# Lift Altair's default row limit so large data frames can be embedded in the
# charts below. The return value is bound to `_` only to suppress cell output.
_ = alt.data_transformers.disable_max_rows()

# Make the shared Altair theme module (../../config/theme.py) importable, then
# register and enable it. The membership check keeps this cell idempotent:
# re-running it no longer appends a duplicate entry to sys.path each time.
THEME_DIR = "../../config/"
if THEME_DIR not in sys.path:
    sys.path.append(THEME_DIR)
import theme  # project-local module found via THEME_DIR

alt.themes.register("main_theme", theme.main_theme)
alt.themes.enable("main_theme")

In [None]:
# Input: CSV of neutralization measurements to be fit.
neut_file_path = "../../data/241014_NiVsG_42_sera_combined.csv"

# Outputs, all written under ../../results/.
# NOTE(review): `fitParams_output` is not referenced in the cells shown here;
# confirm whether the fit-parameter table is meant to be saved to it.
fitParams_output, output_png_path = (
    "../../results/fitparams.csv",
    "../../results/ephrin_receptor.png",
)

In [None]:
# Titrant type: selects the x-axis title and legend title used by
# `plot_neut_curve`. Presumably exactly one of these should be True.
receptor_flag, sera_flag, antibody_flag = False, False, True

# Colour grouping: `vary_serum_flag` colours curves by the "serum" column;
# `vary_virus_flag` is presumably meant to colour them by "virus" instead.
vary_serum_flag, vary_virus_flag = True, False

In [None]:
df = pd.read_csv(neut_file_path)

# Columns that `get_neutcurve` passes to neutcurve.CurveFits. Check them all
# up front so a malformed file reports every missing column at once, instead
# of only the first one an if/elif chain happens to hit.
REQUIRED_COLUMNS = ["serum", "virus", "replicate", "concentration", "fraction infectivity"]
missing_columns = [col for col in REQUIRED_COLUMNS if col not in df.columns]
if missing_columns:
    raise ValueError(
        f"required column(s) not found: {missing_columns}; "
        f"available columns: {list(df.columns)}"
    )

display(df.head(10))

In [None]:
def get_neutcurve(df, replicate="average", *, fixbottom=0, ics=(50, 90, 99)):
    """Fit neutralization curves with `neutcurve` and collect plotting data.

    Uses the `curvefits` module of the `neutcurve` package to fit one curve
    per serum/virus pair that actually appears in `df`.

    Args:
        df: Neutralization data with the columns "serum", "virus",
            "replicate", "concentration" and "fraction infectivity".
        replicate: Which replicate's curve to extract; passed unchanged to
            `CurveFits.getCurve`. Defaults to "average".
        fixbottom: Passed to `CurveFits(fixbottom=...)`. The default 0 keeps
            the previously hard-coded behaviour.
        ics: Inhibitory-concentration percentages passed to
            `CurveFits.fitParams(ics=...)`.

    Returns:
        A tuple ``(combined_curve, fitParams)``:
        ``combined_curve`` concatenates ``curve.dataframe()`` for every
        measured serum/virus pair, tagged with "serum" and "virus" columns,
        plus "upper"/"lower" = "measurement" +/- "stderr" (fresh 0..n-1 index);
        ``fitParams`` is the table returned by ``CurveFits.fitParams``.
    """
    fits = neutcurve.curvefits.CurveFits(
        data=df,
        serum_col="serum",
        virus_col="virus",
        replicate_col="replicate",
        conc_col="concentration",
        fracinf_col="fraction infectivity",
        fixbottom=fixbottom,
    )

    fitParams = fits.fitParams(ics=list(ics))

    # Request curves only for serum/virus pairs that were actually measured.
    # The full cross product of unique sera x unique viruses breaks as soon as
    # the experiment is not a complete grid (CurveFits has no fit for an
    # untested pair).
    measured_pairs = (
        df[["serum", "virus"]]
        .drop_duplicates()
        .itertuples(index=False, name=None)
    )
    curves = [
        fits.getCurve(serum=serum, virus=virus, replicate=replicate)
        .dataframe()
        .assign(serum=serum, virus=virus)
        for serum, virus in measured_pairs
    ]

    # ignore_index avoids duplicate row labels from the per-curve frames.
    combined_curve = pd.concat(curves, ignore_index=True)
    combined_curve["upper"] = combined_curve["measurement"] + combined_curve["stderr"]
    combined_curve["lower"] = combined_curve["measurement"] - combined_curve["stderr"]

    return combined_curve, fitParams


# Fit the curves (replicate="average" by default) and keep the parameter table.
neutcurve_df, fitParams = get_neutcurve(df)

# Drop the per-replicate bookkeeping columns before post-processing.
# NOTE(review): `create_ng_columns` is not defined in any cell shown here, so a
# fresh "Restart & Run All" raises NameError on this line unless it is defined
# in a cell not visible here. Define or import it before this cell.
fitParams = create_ng_columns(
    fitParams.drop(columns=["replicate", "nreplicates"])
)

display(fitParams, neutcurve_df.head(10))

In [None]:
def plot_neut_curve(df):
    """Plot fitted neutralization curves with measured points and error bars.

    Appearance is driven by the module-level flags from the configuration
    cell:

    * ``receptor_flag`` / ``sera_flag`` / ``antibody_flag`` choose the x-axis
      title and legend title; exactly one must be True.
    * ``vary_serum_flag`` / ``vary_virus_flag`` choose whether curves are
      coloured by the "serum" or the "virus" column; exactly one must be True.

    Args:
        df: Curve data such as the first value returned by ``get_neutcurve``,
            with columns "concentration", "fit", "measurement", "lower",
            "upper" and the selected colour column.

    Returns:
        A layered Altair chart (fit line + measured points + error bars),
        300 x 200 px.

    Raises:
        ValueError: if the flags do not select exactly one titrant type or
            exactly one colour column. Previously this surfaced as an
            UnboundLocalError, or several True flags silently fell through
            to the last one.
    """
    # (flag, x-axis title, legend title) for each titrant type.
    titrant_options = [
        (receptor_flag, "Concentration (µM)", "Receptor"),
        (sera_flag, "Sera Dilution", "Serum"),
        (antibody_flag, "Concentration (µg/mL)", "Antibody"),
    ]
    selected_titrants = [(x_title, legend) for flag, x_title, legend in titrant_options if flag]
    if len(selected_titrants) != 1:
        raise ValueError(
            "set exactly one of receptor_flag, sera_flag, antibody_flag to True"
        )
    (title, legend_title), = selected_titrants

    # (flag, column used for colour) for each grouping option.
    color_options = [(vary_serum_flag, "serum"), (vary_virus_flag, "virus")]
    selected_colors = [column for flag, column in color_options if flag]
    if len(selected_colors) != 1:
        raise ValueError(
            "set exactly one of vary_serum_flag, vary_virus_flag to True"
        )
    color_variable = selected_colors[0]

    # Every titrant type uses the same log-scaled x axis with scientific
    # notation ticks. Sharing one encoding keeps all three layers (including
    # the error bars) on identical scales and a single merged legend.
    x_encoding = alt.X(
        "concentration:Q",
        scale=alt.Scale(type="log"),
        axis=alt.Axis(format=".0e", tickCount=3),
        title=title,
    )
    color_encoding = alt.Color(color_variable, title=legend_title)
    base = alt.Chart(df).encode(x=x_encoding, color=color_encoding)

    fit_line = base.mark_line(size=1.5).encode(
        y=alt.Y("fit:Q", title="Fraction Infectivity"),
    )
    points = base.mark_circle(size=40, opacity=1).encode(
        y=alt.Y("measurement:Q", title="Fraction Infectivity"),
    )
    error_bars = base.mark_errorbar(opacity=1).encode(
        y=alt.Y("lower:Q", title="Fraction Infectivity"),
        y2="upper",
    )

    return (fit_line + points + error_bars).properties(width=300, height=200)


ephrin_curve = plot_neut_curve(neutcurve_df)
ephrin_curve.display()

# Make sure the output directory exists so the save below does not fail on a
# fresh checkout where ../../results has not been created yet.
os.makedirs(os.path.dirname(output_png_path) or ".", exist_ok=True)
ephrin_curve.save(output_png_path, ppi=300)