# Aggregate titers across all sera
Aggregate the titers across all sera.

In [None]:
import pickle

import altair as alt

import neutcurve

import pandas as pd

# Disable Altair's maximum-rows check on chart data. The return value is
# assigned to `_`, presumably so the notebook cell does not echo it.
_ = alt.data_transformers.disable_max_rows()

Get variables from `snakemake`:

In [None]:
# Inputs: titers CSVs and pickled `CurveFits`; the checks further down require
# one of each per entry of `groups_sera`.
input_titers, input_pickles = snakemake.input.titers, snakemake.input.pickles

# Outputs: aggregated titers CSVs and pickles (one of each per group, as
# checked further down) plus the path for the saved titers chart.
output_titers, output_pickles = snakemake.output.titers, snakemake.output.pickles
titers_chart_html = snakemake.output.titers_chart

# Parameters: the groups, the (group, serum) pairs, and the virus plot order.
groups = snakemake.params.groups
groups_sera = snakemake.params.groups_sera
viral_strain_plot_order = snakemake.params.viral_strain_plot_order

Order the groups by the number of sera in each (most sera first), and build the matching ordered list of `group serum` labels:

In [None]:
# Tally the number of sera listed for each group (zero if a group has none).
sera_per_group = {group: 0 for group in groups}
for g, _serum in groups_sera:
    if g in sera_per_group:
        sera_per_group[g] += 1

# Groups with the most sera come first; ties are broken by group name, also
# in descending order, because the whole (count, name) key is reversed.
ordered_groups = sorted(
    groups,
    key=lambda grp: (sera_per_group[grp], grp),
    reverse=True,
)

# Collect each group's sera in sorted order.
sera_by_group = {}
for g, serum in sorted(groups_sera):
    sera_by_group.setdefault(g, []).append(serum)

# "<group> <serum>" labels, following the group order computed above.
ordered_groups_sera = [
    f"{group} {serum}"
    for group in ordered_groups
    for serum in sera_by_group.get(group, [])
]

Merge the per-serum titers and `CurveFits` objects, writing one aggregated titers CSV and one pickled `CurveFits` object per group:

In [None]:
# Inputs come one per (group, serum) pair; outputs come one per group.
n_groups_sera = len(groups_sera)
assert len(input_titers) == n_groups_sera and len(input_pickles) == n_groups_sera
n_groups = len(groups)
assert len(output_titers) == n_groups and len(output_pickles) == n_groups

# Stack all the titers tables into a single data frame.
per_serum_titers = [pd.read_csv(csv_path) for csv_path in input_titers]
titers = pd.concat(per_serum_titers, ignore_index=True)

# Each (group, serum, virus) combination should appear in exactly one row.
n_group_serum_virus = len(titers.groupby(["group", "serum", "virus"]))
assert n_group_serum_virus == len(titers)

# Write each group's rows to that group's titers file.
for group, f in zip(groups, output_titers):
    print(f"Writing aggregated titers for {group=} to {f}")
    group_titers = titers.query("group == @group")
    group_titers.to_csv(f, index=False, float_format="%.4g")

# Map each group to its input pickle files, keeping the input order.
pickles_by_group = {}
for (g, _serum), pickle_f in zip(groups_sera, input_pickles):
    pickles_by_group.setdefault(g, []).append(pickle_f)

# Combine each group's `CurveFits` objects and pickle the combined object.
# NOTE(review): `pickle.load` is only safe on trusted files; these are
# presumably produced by earlier steps of the same pipeline.
for group, f in zip(groups, output_pickles):
    fits_list = []
    for pickle_f in pickles_by_group.get(group, []):
        with open(pickle_f, "rb") as fin:
            fits_list.append(pickle.load(fin))
    curvefits = neutcurve.CurveFits.combineCurveFits(fits_list)
    print(f"Pickling aggregated `CurveFits` for {group=} to {f}")
    with open(f, "wb") as fout:
        pickle.dump(curvefits, fout)

Plot all the titers:

In [None]:
# Order the viruses for the chart's y-axis, following `viral_strain_plot_order`.
# The chart plots `titers`, which holds every group, so the set of viruses is
# taken from there. It is deliberately NOT taken from `curvefits`: after the
# pickling loop that variable only holds the last group's fits, so viruses
# found only in other groups would be left unordered and unchecked.
titer_viruses = set(titers["virus"])
viruses = [v for v in viral_strain_plot_order if v in titer_viruses]

# Every plotted virus must have a position in the requested plot order.
missing_viruses = titer_viruses.difference(viruses)
assert (
    not missing_viruses
), f"viruses in titers but not in `viral_strain_plot_order`: {missing_viruses}"

# Mousing over a point selects its virus. With `empty=False`, nothing counts
# as selected while no point is hovered.
virus_selection = alt.selection_point(
    on="mouseover",
    fields=["virus"],
    empty=False,
)

# Sera are chosen by clicking entries in the legend. `toggle` is given as the
# string "true" (an expression rather than a bool), presumably so that every
# plain click toggles a serum in or out of the selection.
serum_selection = alt.selection_point(
    bind="legend",
    fields=["group_serum"],
    toggle="true",
)

# Dropdown for choosing one group; the `None` option is labelled "all".
group_dropdown = alt.binding_select(
    name="group",
    options=[None, *ordered_groups],
    labels=["all", *ordered_groups],
)

# The dropdown starts on the first group in `ordered_groups`.
group_selection = alt.selection_point(
    bind=group_dropdown,
    fields=["group"],
    value=ordered_groups[0],
)

# Use at most eight columns for the facet grid and the serum legend.
ncols = min(8, titers["serum"].nunique())

# x-axis: titer on a log scale.
titer_x = alt.X(
    "titer",
    title="neutralization titer",
    scale=alt.Scale(nice=False, padding=4, type="log"),
    axis=alt.Axis(labelOverlap=True),
)

# y-axis: viruses in the requested plot order.
virus_y = alt.Y("virus", sort=viruses)

# One facet per "<group> <serum>" label, laid out in the precomputed order.
serum_facet = alt.Facet(
    "group_serum:N",
    header=alt.Header(
        title=None, labelFontSize=11, labelFontStyle="bold", labelPadding=0
    ),
    spacing=3,
    columns=ncols,
    sort=ordered_groups_sera,
)

# Stroke width is the same (1) for every serum, so this channel does not
# distinguish sera visually; it provides the legend that `serum_selection`
# is bound to.
serum_stroke_width = alt.StrokeWidth(
    "group_serum:N",
    scale=alt.Scale(
        domain=ordered_groups_sera, range=[1] * len(ordered_groups_sera)
    ),
    legend=alt.Legend(
        orient="bottom",
        columns=ncols,
        symbolLimit=0,
        symbolFillColor="black",
        title="serum (click to select)",
    ),
    sort=ordered_groups_sera,
)

# Point shape encodes the `titer_bound` column.
bound_shape = alt.Shape(
    "titer_bound",
    title="titer interpolated or at dilution bounds",
    legend=alt.Legend(orient="top", titleLimit=200, titleOrient="left"),
)

# The virus under the mouse is drawn in red, everything else in black.
virus_color = alt.condition(virus_selection, alt.value("red"), alt.value("black"))

# Show every column in the tooltip, with float columns formatted to three
# significant figures.
tooltips = []
for column in titers.columns:
    if titers[column].dtype == float:
        tooltips.append(alt.Tooltip(column, format=".3g"))
    else:
        tooltips.append(column)

chart_title = alt.TitleParams(
    "Interactive chart of serum neutralization titers",
    subtitle=[
        "Mouseover points for details.",
        "Click serum legend at bottom to select sera to show.",
        "Use dropdown at bottom to select serum groups to show.",
    ],
    fontSize=15,
    dx=100,
    dy=-5,
)

titers_chart = (
    alt.Chart(titers)
    # label used for faceting, the legend, and the serum selection
    .transform_calculate(group_serum=alt.datum["group"] + " " + alt.datum["serum"])
    .add_params(virus_selection, serum_selection, group_selection)
    # keep only the rows matching the group dropdown and the serum legend
    .transform_filter(group_selection)
    .transform_filter(serum_selection)
    .encode(
        titer_x,
        virus_y,
        serum_facet,
        serum_stroke_width,
        bound_shape,
        color=virus_color,
        tooltip=tooltips,
    )
    .mark_line(point=True)
    .configure_axis(grid=False)
    .configure_legend(padding=10, labelOffset=2, columnPadding=8, labelLimit=400)
    .configure_point(size=45)
    .properties(
        height=alt.Step(11),
        width=100,
        title=chart_title,
        autosize=alt.AutoSizeParams(resize=True),
    )
)

# Save the chart to the path in `titers_chart_html`.
print(f"Saving chart to {titers_chart_html}")
titers_chart.save(titers_chart_html)

# Final expression of the cell, so the notebook displays the chart.
titers_chart