# Aggregate titers across all sera
Aggregate the titers across all sera, failing if any individual serum has QC failures.

In [None]:
import pickle

import altair as alt

import neutcurve

import pandas as pd

# Lift Altair's cap on how many data rows can be embedded in a chart, so the
# full aggregated titers table can be plotted. The return value is assigned to
# `_` so the notebook cell does not echo it.
_ = alt.data_transformers.disable_max_rows()

Get variables from `snakemake`:

In [None]:
# Inputs supplied by the Snakemake rule: the QC report to check, plus the
# lists of pickled `CurveFits` files and titer CSV files to merge.
qc_failures_file = snakemake.input.qc_serum_titer_failures
input_pickles, input_titers = snakemake.input.pickles, snakemake.input.titers

# Outputs to write: the merged pickle, the merged titers CSV, and the HTML chart.
output_pickle, output_titers, titers_chart_html = (
    snakemake.output.pickle,
    snakemake.output.titers,
    snakemake.output.titers_chart,
)

# Preferred virus order for the chart's y-axis.
viral_strain_plot_order = snakemake.params.viral_strain_plot_order

Check whether any individual serum failed quality control (QC):

In [None]:
# Every non-blank line of the QC report must end with "serum passed all QC";
# otherwise the aggregation stops and the full report is echoed in the error.
# Blank lines (e.g. an extra trailing newline) carry no QC result and are
# skipped, so they cannot trigger a spurious failure.
with open(qc_failures_file, encoding="utf-8") as f:
    qc_failures = f.readlines()
qc_result_lines = [line.strip() for line in qc_failures if line.strip()]
if not all(line.endswith("serum passed all QC") for line in qc_result_lines):
    raise ValueError(
        f"QC failures for some serum titers. See {qc_failures_file}:\n\n"
        + "".join(qc_failures)
    )
print("All serum titers pass the QC filters.")

Merge the titers and `CurveFits` objects from all input files, and write out the aggregated versions:

In [None]:
# Merge the titer CSVs and the pickled `CurveFits` objects, writing each
# aggregate to its output file. Validation raises explicit exceptions (rather
# than using `assert`, which is skipped under `python -O`) with messages that
# say what went wrong.
if len(input_titers) != len(input_pickles):
    raise ValueError(
        f"Expected one pickle per titers file, but got {len(input_titers)} "
        f"titers files and {len(input_pickles)} pickles"
    )

titers = pd.concat([pd.read_csv(f) for f in input_titers], ignore_index=True)

# Every row needs a serum and a virus, and each serum-virus pair may appear
# only once across all the merged files.
serum_virus = titers[["serum", "virus"]]
if serum_virus.isna().any().any():
    raise ValueError("Some aggregated titers are missing a serum or virus name")
is_duplicate = serum_virus.duplicated(keep=False)
if is_duplicate.any():
    raise ValueError(
        "Duplicate serum-virus pairs in the aggregated titers:\n"
        + serum_virus[is_duplicate].drop_duplicates().to_string(index=False)
    )
print(f"Writing aggregated titers to {output_titers}")
titers.to_csv(output_titers, index=False, float_format="%.4g")

# NOTE(review): `pickle.load` can execute arbitrary code stored in the file.
# It is acceptable here only because these pickles are inputs produced for
# this pipeline step; never point it at untrusted files.
fits_list = []
for fname in input_pickles:
    with open(fname, "rb") as f:
        fits_list.append(pickle.load(f))
curvefits = neutcurve.CurveFits.combineCurveFits(fits_list)
print(f"Pickling aggregated `CurveFits` to {output_pickle}")
with open(output_pickle, "wb") as f:
    pickle.dump(curvefits, f)

Plot all the titers:

In [None]:
# Y-axis order: the configured plot order, restricted to viruses that are
# present in the merged fits.
fitted_viruses = curvefits.allviruses
viruses = [virus for virus in viral_strain_plot_order if virus in fitted_viruses]
sera = curvefits.sera

# Number of columns used both for the serum facets and for the serum legend.
ncols = 8

# Hovering over a point highlights that virus (in red) in every facet.
virus_selection = alt.selection_point(fields=["virus"], on="mouseover", empty=False)

# Clicking serum names in the legend chooses which sera are shown. `toggle` is
# the Vega expression string "true", not the boolean True: in Vega-Lite a
# boolean True toggles only while shift is held, whereas an expression that
# evaluates to true toggles on every click.
serum_selection = alt.selection_point(
    fields=["serum"],
    bind="legend",
    toggle="true",
)

x_encoding = alt.X(
    "nt50",
    title="neutralization titer",
    scale=alt.Scale(nice=False, padding=4, type="log"),
    axis=alt.Axis(labelOverlap=True),
)
y_encoding = alt.Y("virus", sort=viruses)
facet_encoding = alt.Facet(
    "serum",
    header=alt.Header(
        title=None, labelFontSize=11, labelFontStyle="bold", labelPadding=0
    ),
    spacing=3,
    columns=ncols,
)
# Every serum maps to the same stroke width (1); this channel exists only to
# produce the clickable serum legend that drives `serum_selection`.
stroke_width_encoding = alt.StrokeWidth(
    "serum:N",
    scale=alt.Scale(domain=sera, range=[1] * len(sera)),
    legend=alt.Legend(
        orient="bottom",
        columns=ncols,
        symbolLimit=0,
        symbolFillColor="black",
        title="serum (click to select)",
    ),
)
highlight_color = alt.condition(virus_selection, alt.value("red"), alt.value("black"))
tooltip_fields = [
    "serum",
    "virus",
    alt.Tooltip("nt50", title="NT50", format=".3g"),
    "n_replicates",
]
chart_title = alt.TitleParams(
    "Interactive chart of serum neutralization titers",
    subtitle="Mouseover points for details, click serum legend at bottom to select sera to show",
    fontSize=15,
    dx=100,
    dy=-5,
)

titers_chart = (
    alt.Chart(titers)
    .add_params(virus_selection, serum_selection)
    .transform_filter(serum_selection)
    .encode(
        x=x_encoding,
        y=y_encoding,
        facet=facet_encoding,
        strokeWidth=stroke_width_encoding,
        color=highlight_color,
        tooltip=tooltip_fields,
    )
    .mark_line(point=True)
    .configure_axis(grid=False)
    .configure_point(size=45)
    .properties(
        height=alt.Step(11),
        width=100,
        title=chart_title,
        autosize=alt.AutoSizeParams(resize=True),
    )
)

print(f"Saving chart to {titers_chart_html}")
titers_chart.save(titers_chart_html)

titers_chart