# Titers for a serum
Analyze titers for a serum, aggregating replicates that may span multiple plates.

In [None]:
import pickle
import sys

import altair as alt

import matplotlib.pyplot as plt

import neutcurve

import numpy

import pandas as pd

import ruamel.yaml as yaml

# Lift Altair's limit on how many data rows can be embedded in a chart;
# assigning the result to `_` keeps the notebook from echoing it.
_ = alt.data_transformers.disable_max_rows()

Get variables from `snakemake`:

In [None]:
# Pull inputs, outputs, params, and wildcards off the `snakemake` object that
# is available when this notebook is executed by Snakemake.
pickle_fits = snakemake.input.pickles

per_rep_titers_csv, titers_csv, curves_pdf, output_pickle, qc_drops_file = (
    snakemake.output.per_rep_titers,
    snakemake.output.titers,
    snakemake.output.curves_pdf,
    snakemake.output.pickle,
    snakemake.output.qc_drops,
)

viral_strain_plot_order, serum_titer_as, qc_thresholds = (
    snakemake.params.viral_strain_plot_order,
    snakemake.params.serum_titer_as,
    snakemake.params.qc_thresholds,
)

serum = snakemake.wildcards.serum

print(f"Processing {serum=}")

## Get all titers for this serum
Combine all the pickled `neutcurve.CurveFits` from plates for this serum into a single `neutcurve.CurveFits`:

In [None]:
print(f"Combining the curve fits for {serum=} from {pickle_fits=}")


def _load_curvefits(path):
    """Unpickle one plate's ``neutcurve.CurveFits`` object from ``path``.

    NOTE(review): assumes the pickle files are produced by a trusted upstream
    step; never point this at untrusted files.
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)


fits_to_combine = [_load_curvefits(path) for path in pickle_fits]
# Merge the per-plate fits, keeping only this serum.
fits_noqc = neutcurve.CurveFits.combineCurveFits(fits_to_combine, sera=[serum])

Indicate how we are calculating the titer:

In [None]:
print(f"Calculating with {serum_titer_as=}")
# Validate the config value explicitly: an `assert` would be skipped entirely
# when Python runs with optimizations (`-O`).
if serum_titer_as not in {"nt50", "midpoint"}:
    raise ValueError(
        f"`serum_titer_as` must be 'nt50' or 'midpoint', got {serum_titer_as!r}"
    )

Get all the per-replicate fit params with the titers.
We also convert the IC50 to an NT50, and take the inverse of the midpoint to put it on the same scale as the NT50s:

In [None]:
# One row per fitted (virus, replicate) curve. NT50 is 1 / IC50, and the
# midpoint is inverted (1 / midpoint_bound) to put it on the same scale as the
# NT50s. Inverting a value flips the direction of any bound on it, hence the
# lower <-> upper swap when building `titer_bound`.
per_rep_titers = fits_noqc.fitParams(average_only=False, no_average=True).assign(
    nt50=lambda x: 1 / x["ic50"],
    midpoint=lambda x: 1 / x["midpoint_bound"],
    titer=lambda x: x["midpoint"] if serum_titer_as == "midpoint" else x["nt50"],
    titer_bound=lambda x: (
        x["midpoint_bound_type"] if serum_titer_as == "midpoint" else x["ic50_bound"]
    ).map({"lower": "upper", "upper": "lower", "interpolated": "interpolated"}),
    titer_as=serum_titer_as,
)[
    [
        "serum",
        "virus",
        "replicate",
        "titer",
        "titer_bound",
        "titer_as",
        "nt50",
        "midpoint",
        "top",
        "bottom",
        "slope",
    ]
]

# `.map` yields NaN for any bound label not in the mapping above, so check for
# missing values explicitly (an `assert` is skipped under `python -O`) and
# show the offending rows.
if (has_null := per_rep_titers.isnull().any(axis=1)).any():
    raise ValueError(
        f"Missing values in per-replicate titers:\n{per_rep_titers[has_null]}"
    )

if len(invalid_titer_as := per_rep_titers.query("(titer_as == 'nt50') and top <= 0.5")):
    raise ValueError(
        f"There are titers computed as nt50 when curve top <= 0.5:\n{invalid_titer_as}"
    )

# Each virus should have at most one row per replicate. Key on the
# (virus, replicate) pair rather than `replicate` alone, since the same
# replicate label may be shared by several viruses of this serum.
if (dup_rows := per_rep_titers.duplicated(["virus", "replicate"], keep=False)).any():
    raise ValueError(
        "Duplicate (virus, replicate) entries in per-replicate titers:\n"
        f"{per_rep_titers[dup_rows]}"
    )

# Plot order for viruses: alphabetical by default; otherwise follow
# `viral_strain_plot_order`, which must include every virus with titers.
viruses = sorted(per_rep_titers["virus"].unique())
if viral_strain_plot_order is not None:
    missing_from_order = set(viruses) - set(viral_strain_plot_order)
    if missing_from_order:
        raise ValueError(
            "`viral_strain_plot_order` lacks some viruses with titers:\n"
            + str(missing_from_order)
        )
    viruses_present = set(viruses)
    viruses = [v for v in viral_strain_plot_order if v in viruses_present]
print(f"{serum=} has titers for a total of {len(viruses)} viruses")

## Correlate NT50s with midpoints of curves
Plot the correlation of the NT50s with the midpoint (this is an interactive plot, mouse over points for details).
This plot can help you determine if you made the correct choice of `serum_titer_as` when choosing to use the midpoint or NT50 for the titer.
For titers where they are well correlated it should not matter which you chose.
But if there are titers far from the correlation line, you should look at those measurements and curves to make sure you made the correct choice of calculating the titer as the NT50 versus midpoint:

In [None]:
virus_selection = alt.selection_point(fields=["virus"], on="mouseover", empty=False)

# Mouseover tooltip: every column except the constant serum / titer_as ones,
# with float columns shown to 2 significant digits.
nt50_midpoint_tooltips = []
for col in per_rep_titers.columns:
    if col in {"serum", "titer_as"}:
        continue
    if per_rep_titers[col].dtype == float:
        nt50_midpoint_tooltips.append(alt.Tooltip(col, format=".2g"))
    else:
        nt50_midpoint_tooltips.append(col)

midpoint_vs_nt50_chart = (
    alt.Chart(per_rep_titers)
    .add_params(virus_selection)
    .encode(
        alt.X("nt50", scale=alt.Scale(type="log", nice=False, padding=8)),
        alt.Y("midpoint", scale=alt.Scale(type="log", nice=False, padding=8)),
        strokeWidth=alt.condition(virus_selection, alt.value(3), alt.value(0)),
        size=alt.condition(virus_selection, alt.value(100), alt.value(60)),
        tooltip=nt50_midpoint_tooltips,
    )
    .mark_circle(stroke="red", fillOpacity=0.45, color="black")
    .properties(width=350, height=350, title=f"NT50 versus midpoint for serum {serum}")
    .configure_axis(grid=False)
)

midpoint_vs_nt50_chart

Write the individual per-replicate titers to a file; these titers are from before any QC has been applied:

In [None]:
print(f"Writing per-replicate titers (without QC filtering) to {per_rep_titers_csv=}")
# Floats are written to 4 significant digits; the row index is omitted.
per_rep_titers.to_csv(
    per_rep_titers_csv,
    float_format="%.4g",
    index=False,
)

## Plot median titers and determine if they pass QC
Get the median titers for each virus across replicates, then add these median titers to the per-replicate titers and calculate the fold-change in titer between each replicate and its median.
Finally, indicate for each virus whether it passes QC:

In [None]:
# Log the QC thresholds applied in the cells below.
print(f"Using the following {qc_thresholds=}")


def get_median_bound(s):
    """Return the titer bound label that applies to the median replicate.

    ``s`` is an iterable of per-replicate bound labels, expected to be ordered
    by titer (the caller sorts by titer before grouping).

    - Odd count: the label of the middle replicate.
    - Even count (the median averages the two middle values):
        - both middle labels equal: that label;
        - one "interpolated" and one other label: the other label;
        - two different non-interpolated labels: "inconsistent".
    """
    labels = list(s)
    half = len(labels) // 2
    if len(labels) % 2 == 1:
        return labels[half]
    middle = labels[half - 1 : half + 1]
    assert len(middle) == 2
    first, second = middle
    if len(set(middle)) == 1:
        return first
    if "interpolated" in middle:
        return second if first == "interpolated" else first
    return "inconsistent"


# Sort by titer first: groupby keeps the original row order within each
# group, so `get_median_bound` receives each virus's bounds ordered by titer.
titers_sorted_for_median = per_rep_titers.sort_values("titer")
median_titers_noqc = titers_sorted_for_median.groupby(
    ["serum", "virus", "titer_as"], as_index=False
).aggregate(
    titer=("titer", "median"),
    n_replicates=("replicate", "count"),
    titer_sem=("titer", "sem"),
    titer_bound=("titer_bound", get_median_bound),
)

# Attach each virus's median titer to its replicate rows.
median_lookup = median_titers_noqc[["serum", "virus", "titer"]].rename(
    columns={"titer": "median_titer"}
)
per_rep_titers_w_fc = per_rep_titers.merge(
    median_lookup, on=["serum", "virus"], validate="many_to_one"
)

# Fold change of each replicate from its median, expressed as a ratio >= 1
# regardless of whether the replicate is above or below the median.
per_rep_titers_w_fc["fc_from_median"] = numpy.where(
    per_rep_titers_w_fc["titer"] > per_rep_titers_w_fc["median_titer"],
    per_rep_titers_w_fc["titer"] / per_rep_titers_w_fc["median_titer"],
    per_rep_titers_w_fc["median_titer"] / per_rep_titers_w_fc["titer"],
)
per_rep_titers_w_fc = per_rep_titers_w_fc.drop(
    columns=["serum", "titer_as", "median_titer"]
)

# Largest replicate-vs-median fold change observed for each virus.
max_fc_by_virus = per_rep_titers_w_fc.groupby("virus", as_index=False).aggregate(
    max_fc_from_median=pd.NamedAgg("fc_from_median", "max")
)
median_titers_noqc = median_titers_noqc.merge(
    max_fc_by_virus, on="virus", validate="one_to_one"
)

# Flag each QC criterion separately, then combine them into one verdict.
median_titers_noqc["fails_min_reps"] = (
    median_titers_noqc["n_replicates"] < qc_thresholds["min_replicates"]
)
median_titers_noqc["fails_max_fc"] = (
    median_titers_noqc["max_fc_from_median"]
    >= qc_thresholds["max_fold_change_from_median"]
)
median_titers_noqc["fails_qc"] = (
    median_titers_noqc["fails_min_reps"] | median_titers_noqc["fails_max_fc"]
)


def _describe_qc_failure(row):
    """Comma-separated names of the QC thresholds that ``row`` fails."""
    failed = []
    if row["fails_min_reps"]:
        failed.append("min_replicates")
    if row["fails_max_fc"]:
        failed.append("max_fold_change_from_median")
    return ", ".join(failed)


median_titers_noqc["fails_qc_reason"] = median_titers_noqc.apply(
    _describe_qc_failure, axis=1
)

# Map each virus that fails QC to its failure reason(s), in plotting order.
failing_rows = median_titers_noqc[median_titers_noqc["fails_qc"]]
reason_by_virus = dict(zip(failing_rows["virus"], failing_rows["fails_qc_reason"]))
viruses_failing_qc = {
    virus: reason_by_virus[virus] for virus in viruses if virus in reason_by_virus
}

# The per-criterion columns were only needed to build the reasons above.
median_titers_noqc = median_titers_noqc.drop(
    columns=["fails_min_reps", "fails_max_fc", "fails_qc_reason"]
)

# Carry each virus's QC verdict onto its replicate rows (used for plotting).
per_rep_titers_w_fc = per_rep_titers_w_fc.merge(
    median_titers_noqc[["virus", "fails_qc"]], on="virus", validate="many_to_one"
)

Now plot the per-replicate and median titers, indicating any viruses that failed QC.
Note that some of these titers may still be retained if the viruses in question are specified in `viruses_ignore_qc` of `qc_thresholds`.

In [None]:
virus_selection = alt.selection_point(fields=["virus"], on="mouseover", empty=False)


def _titer_tooltips(df):
    """Tooltip spec for every column of ``df``; floats shown to 3 significant digits."""
    return [
        alt.Tooltip(col, format=".3g") if df[col].dtype == float else col for col in df
    ]


def _titer_points(df, fill, *, shape, size, fill_opacity):
    """Point layer of log-scale titer (x) by virus (y, in plot order).

    Points are outlined in black when their virus is moused over.
    """
    return (
        alt.Chart(df)
        .encode(
            alt.X("titer", scale=alt.Scale(nice=False, padding=5, type="log")),
            alt.Y("virus", sort=viruses),
            fill,
            strokeWidth=alt.condition(virus_selection, alt.value(2), alt.value(0)),
            tooltip=_titer_tooltips(df),
        )
        .mark_point(
            shape=shape,
            size=size,
            filled=True,
            fillOpacity=fill_opacity,
            strokeOpacity=1,
            stroke="black",
        )
    )


per_rep_chart = _titer_points(
    per_rep_titers_w_fc,
    alt.Fill(
        "fails_qc",
        title=f"fails {qc_thresholds['min_replicates']=}, {qc_thresholds['max_fold_change_from_median']=}",
        legend=alt.Legend(titleLimit=500),
    ),
    shape="circle",
    size=40,
    fill_opacity=0.5,
)

median_chart = _titer_points(
    median_titers_noqc,
    alt.Fill("fails_qc"),
    shape="square",
    size=75,
    fill_opacity=0.9,
)

titer_chart = (
    (per_rep_chart + median_chart)
    .add_params(virus_selection)
    .properties(
        height=alt.Step(11),
        width=250,
        title=f"{serum} median (square) and per-replicate (small circle) titers",
    )
    .configure_axis(grid=False)
)

titer_chart

## Plot individual curves for any viruses failing QC
Plot individual curves for viruses failing QC.
Note that some of these titers may still be retained if the viruses in question are specified in `viruses_ignore_qc` of `qc_thresholds`.

In [None]:
print(f"Neutralization curves for the {len(viruses_failing_qc)} viruses failing QC:")

if viruses_failing_qc:
    fig, _ = fits_noqc.plotReplicates(
        # Pass the virus names as an explicit list (in plot order) rather than
        # the {virus: reason} dict, matching how retained viruses are plotted.
        viruses=list(viruses_failing_qc),
        attempt_shared_legend=False,
        legendfontsize=8,
        ncol=4,
        heightscale=1.2,
        widthscale=1.2,
        subplot_titles="{virus}",
    )
    _ = fig.suptitle(
        f"neutralization curves for viruses failing QC for {serum}",
        y=1,
        fontsize=18,
        fontweight="bold",
    )
    fig.tight_layout()

## Get the viruses to drop for QC failures
Drop any viruses that fail QC and are not specified in `viruses_ignore_qc` of `qc_thresholds`:

In [None]:
# Viruses exempt from QC dropping. A null/empty config value means "none";
# without the `or []` fallback, a null value would make `in` raise TypeError.
viruses_ignore_qc = set(qc_thresholds["viruses_ignore_qc"] or [])

# Split QC failures in one pass: exempt ones are retained, the rest dropped.
viruses_to_drop = {}
viruses_retained_despite_qc = {}
for v, reason in viruses_failing_qc.items():
    if v in viruses_ignore_qc:
        viruses_retained_despite_qc[v] = reason
    else:
        viruses_to_drop[v] = reason

print(f"Dropping {len(viruses_to_drop)} viruses for failing QC:")
yaml.YAML(typ="rt").dump(viruses_to_drop, sys.stdout)
if nkept := len(viruses_retained_despite_qc):
    print(
        f"\nRetaining {nkept} viruses that fail QC because they are in `viruses_ignore_qc`:"
    )
    print(viruses_retained_despite_qc)

print(f"\nWriting QC drops to {qc_drops_file}")
with open(qc_drops_file, "w") as f:
    yaml.YAML(typ="rt").dump(viruses_to_drop, f)

## Get and plot the neutralization curves for all retained viruses
First, get the `CurveFits` for just those retained viruses (dropping ones that fail QC), and plot:

In [None]:
# Viruses kept after QC, in plot order. This includes QC failures exempted via
# `viruses_ignore_qc`, since those are not in `viruses_to_drop`.
retained_viruses = [v for v in viruses if v not in viruses_to_drop]

fits_qc = neutcurve.CurveFits.combineCurveFits([fits_noqc], viruses=retained_viruses)
# Every virus must be either retained in `fits_qc` or listed as dropped.
assert len(viruses) == len(fits_qc.viruses[serum]) + len(viruses_to_drop)

fig, _ = fits_qc.plotReplicates(
    attempt_shared_legend=False,
    legendfontsize=8,
    ncol=4,
    heightscale=1.2,
    widthscale=1.2,
    subplot_titles="{virus}",
    viruses=retained_viruses,
)
_ = fig.suptitle(
    f"neutralization curves for retained viruses for {serum}",
    y=1,
    fontsize=18,
    fontweight="bold",
)
fig.tight_layout()
# Display explicitly: the figure is closed at the end of this cell.
display(fig)

print(f"Saving plot of curves to {curves_pdf}")
fig.savefig(curves_pdf)
plt.close(fig)

Save the `CurveFits` to a pickle file:

In [None]:
# Persist the QC-filtered CurveFits object.
with open(output_pickle, mode="wb") as pickle_out:
    pickle_out.write(pickle.dumps(fits_qc))

Write the titers (excluding QC dropped viruses) to a CSV:

In [None]:
print(f"Writing titers to {titers_csv}")

titer_output_cols = [
    "serum",
    "virus",
    "titer",
    "titer_bound",
    "titer_sem",
    "n_replicates",
    "titer_as",
]
# Keep only viruses that were not dropped for failing QC.
is_qc_dropped = median_titers_noqc["virus"].isin(list(viruses_to_drop))
median_titers_noqc.loc[~is_qc_dropped, titer_output_cols].to_csv(
    titers_csv, index=False, float_format="%.4g"
)