# Titers for a serum
Analyze titers for a serum, aggregating replicates that may be spread across multiple plates.

In [None]:
import pickle

import altair as alt

import neutcurve

import pandas as pd

# disable Altair's maximum-row limit on chart data; assigning to `_`
# keeps the notebook from echoing the return value of this last statement
_ = alt.data_transformers.disable_max_rows()

Get variables from `snakemake`:

In [None]:
# inputs: per-plate titer CSVs and pickled curve fits
plate_fits = snakemake.input.plate_fits
pickle_fits = snakemake.input.pickles

# parameters controlling plotting and QC
viral_strain_plot_order = snakemake.params.viral_strain_plot_order
qc_thresholds = snakemake.params.qc_thresholds
qc_exclusions = snakemake.params.qc_exclusions

# the serum this notebook instance analyzes
serum = snakemake.wildcards.serum

# output files written by this notebook
per_rep_titers_csv = snakemake.output.per_rep_titers
median_titers_csv = snakemake.output.median_titers
curves_pdf = snakemake.output.curves_pdf
output_pickle = snakemake.output.pickle
qc_failures_file = snakemake.output.qc_failures

Process the QC exclusions:

In [None]:
print(f"The QC thresholds are:\n{qc_thresholds}")

# viruses whose exclusion settings contain `ignore_qc` set exactly to True
viruses_ignore_qc = list(
    filter(
        lambda name: "ignore_qc" in qc_exclusions[name]
        and qc_exclusions[name]["ignore_qc"] is True,
        qc_exclusions,
    )
)
if viruses_ignore_qc:
    print("\nIgnoring QC for these viruses:\n\t" + "\n\t".join(viruses_ignore_qc))

# (virus, replicate) pairs listed under `replicates_to_drop`; viruses
# without that key contribute nothing
virus_replicates_to_drop = [
    (name, replicate)
    for name, opts in qc_exclusions.items()
    if "replicates_to_drop" in opts
    for replicate in opts["replicates_to_drop"]
]
if virus_replicates_to_drop:
    print(
        "\nDropping these virus-replicates:\n\t"
        + "\n\t".join(map(str, virus_replicates_to_drop))
    )

Read all titers for this serum, dropping any virus-replicates specified for exclusion:

In [None]:
print(f"Reading titers for {serum=}")

assert len(plate_fits)

# concatenate the fits from every plate, then keep only this serum's rows
serum_fits = pd.concat([pd.read_csv(f) for f in plate_fits]).query("serum == @serum")

# remove virus-replicates to drop
print(f"Read a total of {len(serum_fits)} titers")
# `tup` is a temporary (virus, replicate) key column; make sure it cannot
# clobber a real column of the same name
assert "tup" not in set(serum_fits.columns)
serum_fits = (
    serum_fits.assign(
        tup=lambda x: list(
            x[["virus", "replicate"]].itertuples(index=False, name=None)
        ),
    )
    .query("tup not in @virus_replicates_to_drop")
    .drop(columns="tup")
)
print(f"Retained {len(serum_fits)} titers after dropping specified ones")

assert len(serum_fits), f"no titers for {serum=}"
# each virus-replicate must occur only once (e.g., not measured on two
# plates); the message (only built on failure) lists the offending pairs
assert len(serum_fits) == len(serum_fits.groupby(["replicate", "virus"])), (
    f"some virus-replicates occur more than once for {serum=}:\n"
    + serum_fits[serum_fits.duplicated(["virus", "replicate"], keep=False)][
        ["virus", "replicate"]
    ].to_string(index=False)
)
print(f"Using {len(serum_fits)} titers for {serum=}")

# get viruses in the order to plot them
viruses = sorted(serum_fits["virus"].unique())
if viral_strain_plot_order is not None:
    if not set(viruses).issubset(viral_strain_plot_order):
        raise ValueError(
            "`viral_strain_plot_order` lacks some viruses with titers:\n"
            + str(set(viruses) - set(viral_strain_plot_order))
        )
    viruses = [v for v in viral_strain_plot_order if v in viruses]
print(f"{serum=} has titers for a total of {len(viruses)} viruses")

Get the NT50s per replicate (NT50 = 1 / IC50).
Because the NT50 is the reciprocal of the IC50, a lower bound on the IC50 is an upper bound on the NT50, and vice versa:

In [None]:
# NT50 is the reciprocal of IC50, so an IC50 upper bound is an NT50 lower
# bound and vice versa; interpolated values stay interpolated
ic50_to_nt50_bound = {"interpolated": "interpolated", "upper": "lower", "lower": "upper"}
per_rep_titers = serum_fits.assign(
    nt50=1 / serum_fits["ic50"],
    nt50_bound=serum_fits["ic50_bound"].map(ic50_to_nt50_bound),
)[["serum", "replicate", "virus", "nt50", "nt50_bound"]]

# every value must be present (catches unmapped `ic50_bound` labels)
assert not per_rep_titers.isna().to_numpy().any()

print(f"Saving {len(per_rep_titers)} per-replicate titers to {per_rep_titers_csv}")
per_rep_titers.to_csv(per_rep_titers_csv, index=False, float_format="%.4g")

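Get the median titers and the standard error of the mean. The bound of the median (whether it is interpolated or lies at a bound of the dilution series) is defined as follows. With an odd number of measurements, it is the bound of the median measurement. With an even number of measurements, it depends on the two middle measurements. If their bounds agree, the median gets that bound (so it is interpolated only if both middle values are interpolated). If exactly one is interpolated, the median gets the other's bound. If they are conflicting bounds (one lower, one upper), the median is labeled "inconsistent":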

In [None]:
def get_median_nt50_bound(s):
    """Return the ``nt50_bound`` label to assign to a group's median NT50.

    ``s`` is an iterable of per-replicate bound labels ordered by increasing
    NT50. For an odd count the middle label is returned. For an even count the
    two middle labels decide: if they are equal that label is returned; if one
    of them is "interpolated" the other label is returned; otherwise
    "inconsistent" is returned.
    """
    labels = list(s)
    half, odd = divmod(len(labels), 2)
    if odd:
        return labels[half]

    middle_pair = labels[half - 1 : half + 1]
    assert len(middle_pair) == 2
    left, right = middle_pair
    if len({left, right}) == 1:
        return left
    if "interpolated" not in middle_pair:
        return "inconsistent"
    # exactly one of the pair is "interpolated": report the other one
    return right if left == "interpolated" else left


# order rows by NT50 so each group's bound labels reach
# `get_median_nt50_bound` in increasing-NT50 order
titers_by_nt50 = per_rep_titers.sort_values("nt50")
median_titers = titers_by_nt50.groupby(["serum", "virus"], as_index=False).aggregate(
    nt50=("nt50", "median"),
    n_replicates=("replicate", "count"),
    nt50_stderr=("nt50", "sem"),
    nt50_bound=("nt50_bound", get_median_nt50_bound),
)

print(f"Saving {len(median_titers)} median titers to {median_titers_csv}")
median_titers.to_csv(median_titers_csv, index=False, float_format="%.4g")

Plot the per-replicate and median titers:

In [None]:
virus_selection = alt.selection_point(fields=["virus"], on="mouseover", empty=False)


def _titer_point_chart(titers_df, **mark_kwargs):
    """Return a log-scale point chart of ``nt50`` per virus for ``titers_df``.

    Viruses are ordered on the y-axis by ``viruses``, point shape encodes
    ``nt50_bound``, points of the virus selected on mouseover get a stroke
    width of 2 (0 otherwise), and the tooltip lists every column except
    ``serum`` (float columns formatted to three significant digits).
    ``mark_kwargs`` are passed to ``mark_point``.
    """
    tooltip = []
    for col in titers_df:
        if col == "serum":
            continue
        if titers_df[col].dtype == float:
            tooltip.append(alt.Tooltip(col, format=".3g"))
        else:
            tooltip.append(col)
    log_x = alt.Scale(nice=False, padding=5, type="log")
    return (
        alt.Chart(titers_df)
        .encode(
            alt.X("nt50", title="neutralization titer", scale=log_x),
            alt.Y("virus", sort=viruses),
            alt.Shape("nt50_bound", title="is titer bound?"),
            strokeWidth=alt.condition(virus_selection, alt.value(2), alt.value(0)),
            tooltip=tooltip,
        )
        .mark_point(**mark_kwargs)
    )


# per-replicate titers: smaller, semi-transparent blue points
per_rep_chart = _titer_point_chart(
    per_rep_titers,
    size=45,
    filled=True,
    fillOpacity=0.5,
    strokeOpacity=1,
    stroke="black",
    color="#56B4E9",
)

# median titers: larger, mostly opaque orange points
median_chart = _titer_point_chart(
    median_titers,
    color="#E69F00",
    size=85,
    filled=True,
    fillOpacity=0.9,
    strokeOpacity=1,
    stroke="black",
)

titer_chart = (
    (per_rep_chart + median_chart)
    .add_params(virus_selection)
    .properties(
        height=alt.Step(14),
        width=250,
        title=f"{serum} median (orange) and per-replicate (blue) titers",
    )
    .configure_axis(grid=False)
)

titer_chart

Combine all the `CurveFits` for this serum, dropping the excluded virus-replicates:

In [None]:
def _load_curvefits(path):
    """Unpickle and return the object stored at ``path``."""
    with open(path, "rb") as handle:
        return pickle.load(handle)


fits_to_combine = [_load_curvefits(path) for path in pickle_fits]

# combineCurveFits takes (serum, virus, replicate) triples to drop
serum_virus_reps_to_drop = [
    (serum, virus, replicate) for virus, replicate in virus_replicates_to_drop
]
serum_curvefits = neutcurve.CurveFits.combineCurveFits(
    fits_to_combine,
    sera=[serum],
    serum_virus_replicates_to_drop=serum_virus_reps_to_drop,
)

Identify any QC failures (skipping viruses for which QC is set to be ignored) and write them to a file:

In [None]:
qc_failures = []


def _sort_by_plot_order(df):
    """Return ``df`` sorted by the plotting order of its ``virus`` column, index reset."""
    return df.sort_values(
        "virus", key=lambda s: s.map(lambda v: viruses.index(v))
    ).reset_index(drop=True)


# Viruses (not exempt from QC) with fewer than the minimum number of replicates.
# The filter is evaluated on the *sorted* frame itself. Indexing the sorted,
# re-indexed frame with a mask built from the unsorted `median_titers` would
# align the mask by index label. That flags the wrong viruses whenever the plot
# order differs from the (alphabetical) groupby order of `median_titers`.
min_replicates = qc_thresholds["min_replicates"]
insufficient_replicates = (
    _sort_by_plot_order(median_titers)
    .query("(n_replicates < @min_replicates) and (virus not in @viruses_ignore_qc)")
    .drop(columns="serum")
)
if len(insufficient_replicates):
    print(f"The following viruses fail {qc_thresholds['min_replicates']=}")
    display(insufficient_replicates)
    qc_failures.append("min_replicates")
else:
    print(f"All viruses pass {qc_thresholds['min_replicates']=}")

# Replicates (of viruses not exempt from QC) whose titer differs from the
# virus's median titer by more than the allowed fold change in either direction.
max_fold_change_from_median = qc_thresholds["max_fold_change_from_median"]
assert max_fold_change_from_median > 1, max_fold_change_from_median
excess_fold_change_from_median = _sort_by_plot_order(
    median_titers.rename(columns={"nt50": "median_nt50"})[["virus", "median_nt50"]]
    .merge(per_rep_titers, validate="one_to_many")
    .assign(
        fold_change_from_median=lambda x: x["nt50"] / x["median_nt50"],
        excess_fold_change=lambda x: (
            (x["fold_change_from_median"] > max_fold_change_from_median)
            | (x["fold_change_from_median"] < 1 / max_fold_change_from_median)
        ),
    )
).query("excess_fold_change and (virus not in @viruses_ignore_qc)")[
    ["virus", "replicate", "median_nt50", "nt50", "fold_change_from_median"]
]
if len(excess_fold_change_from_median):
    print(f"\nThese replicates fail {qc_thresholds['max_fold_change_from_median']=}")
    display(excess_fold_change_from_median)
    qc_failures.append("max_fold_change_from_median")
else:
    print(f"\nAll viruses pass {qc_thresholds['max_fold_change_from_median']=}")

# one failed-check name per line; an empty file means no QC failures
qc_failures = "\n".join(qc_failures)
if qc_failures:
    print(f"\nEncountered the following QC failures:\n{qc_failures}")
else:
    print("\nNo QC failures")
print(f"\nLogging QC failures to {qc_failures_file}")
with open(qc_failures_file, "w") as f:
    f.write(qc_failures)

Plot the individual neutralization curves:

In [None]:
# one panel per virus (in plotting order), titled by virus name, 4 per row
replicate_plot_kwargs = dict(
    viruses=viruses,
    subplot_titles="{virus}",
    ncol=4,
    widthscale=1.25,
    heightscale=1.25,
    legendfontsize=8,
    attempt_shared_legend=False,
)
fig, _ = serum_curvefits.plotReplicates(**replicate_plot_kwargs)

# figure-level title above the grid of panels
_ = fig.suptitle(
    f"neutralization curves for {serum}",
    fontsize=18,
    fontweight="bold",
    y=1,
)

fig.tight_layout()

Save the curves to a file:

In [None]:
# write the neutralization-curve figure to the `curves_pdf` output path
print(f"Saving to {curves_pdf}")
fig.savefig(curves_pdf)

Save the `CurveFits` for the serum to a pickle file:

In [None]:
# persist the combined curve fits for this serum as a binary pickle
with open(output_pickle, mode="wb") as pickle_out:
    pickle.dump(serum_curvefits, pickle_out)