# Process plate counts to get fraction infectivities and fit curves
This notebook is designed to be run using `snakemake`, and analyzes a plate of sequencing-based neutralization assays.

The plots generated by this notebook are interactive, so you can mouseover points for details, use the mouse-scroll to zoom and pan, and use interactive dropdowns at the bottom of the plots.

## Setup
Import Python modules:

In [None]:
import pickle
import sys

import altair as alt

import matplotlib.pyplot as plt

import neutcurve

import numpy

import pandas as pd

import ruamel.yaml as yaml

# Lift Altair's default cap on how many data rows it will embed in a chart, since several
# charts below plot one point per barcode per well.
_ = alt.data_transformers.disable_max_rows()

Get the variables passed by `snakemake`:

In [None]:
# Inputs, outputs, and parameters handed to this notebook by the `snakemake` rule.
sm_input, sm_output, sm_params = snakemake.input, snakemake.output, snakemake.params

count_csvs = sm_input.count_csvs
fate_csvs = sm_input.fate_csvs
viral_library_csv = sm_input.viral_library_csv
neut_standard_set_csv = sm_input.neut_standard_set_csv

qc_drops_yaml = sm_output.qc_drops
frac_infectivity_csv = sm_output.frac_infectivity_csv
fits_csv = sm_output.fits_csv
fits_pickle = sm_output.fits_pickle

samples = sm_params.samples
plate = snakemake.wildcards.plate
plate_params = sm_params.plate_params

# Multi-item drop entries (e.g. a barcode plus a well) arrive as lists; convert them to
# tuples so they are hashable and can be used as keys / looked up in the drop records.
manual_drops = {}
for drop_type, drop_entries in plate_params["manual_drops"].items():
    manual_drops[drop_type] = [
        tuple(entry) if isinstance(entry, list) else entry for entry in drop_entries
    ]

qc_thresholds = plate_params["qc_thresholds"]
curvefit_params = plate_params["curvefit_params"]
curvefit_qc = plate_params["curvefit_qc"]
# same list -> tuple conversion for the entries (presumably barcode / serum-replicate
# pairs) that are exempt from curve-fitting QC
ignore_qc_key = "barcode_serum_replicates_ignore_curvefit_qc"
curvefit_qc[ignore_qc_key] = list(map(tuple, curvefit_qc[ignore_qc_key]))

print(f"Processing {plate=}")

samples_df = pd.DataFrame(plate_params["samples"])
print(f"\nPlate has {len(samples)} samples (wells)")

# every well and sample name must appear exactly once on the plate ...
for unique_col in ["well", "sample", "sample_noplate"]:
    assert samples_df[unique_col].nunique() == len(samples_df)
# ... and each serum-replicate / dilution combination must be its own well
assert len(samples_df.groupby(["serum_replicate", "dilution_factor"])) == len(
    samples_df
)
assert len(samples) == len(count_csvs) == len(fate_csvs) == len(samples_df)

# echo the configuration used to process this plate
config_sections = [
    ("manual_drops", manual_drops, "Data manually specified to drop:"),
    ("qc_thresholds", qc_thresholds, "QC thresholds applied to data:"),
    ("curvefit_params", curvefit_params, "Curve-fitting parameters:"),
    ("curvefit_qc", curvefit_qc, "Curve-fitting QC:"),
]
for key, section, title in config_sections:
    print(f"\n{title}")
    yaml.YAML(typ="rt").dump({key: section}, stream=sys.stdout)

Set up a dictionary to keep track of the wells, barcodes, barcode-wells, barcode/serum-replicates, and serum-replicates that are dropped (and why):

In [None]:
# Record of everything dropped during processing, by type of item dropped. Each inner
# dict maps a dropped item (a well, a barcode, or a tuple such as (barcode, well)) to
# the reason it was dropped.
qc_drops = {
    "wells": {},
    "barcodes": {},
    "barcode_wells": {},
    "barcode_serum_replicates": {},
    "serum_replicates": {},
}

# every manually specified drop type must be one of the types tracked above
unknown_manual_drop_types = set(manual_drops) - set(qc_drops)
assert (
    not unknown_manual_drop_types
), f"{unknown_manual_drop_types=}, {list(qc_drops)=}"

## Statistics on barcode-parsing for each sample
Make interactive chart of the "fates" of the sequencing reads parsed for each sample on the plate.

If most sequencing reads are not "valid barcodes", this could potentially indicate some problem in the sequencing or barcode set you are parsing.

Potential fates are:
 - *valid barcode*: a barcode that matches a known virus or neutralization standard; ideally, most reads are of this type.
 - *invalid barcode*: a barcode with proper flanking sequences that does not match a known virus or neutralization standard. If you have a lot of reads of this type, it is probably a good idea to look at the invalid barcode CSVs (in the `./results/barcode_invalid/` subdirectory created by the pipeline) to see what these invalid barcodes are.
 - *unparseable barcode*: a barcode could not be parsed from this read because it does not contain a sequence of the correct length with the appropriate flanking sequence.
 - *low quality barcode*: low-quality or `N` nucleotides in barcode, could indicate problem with sequencing.
 - *failed chastity filter*: reads that failed the Illumina chastity filter, if these are reported in the FASTQ (they may not be).

Also, if the number of reads per sample is very uneven, that could indicate that you did not do a good job of balancing the different samples in the Illumina sequencing.

In [None]:
# Read the per-sample read "fates", annotate them with the sample (well) metadata, and
# keep only fates that have at least one read somewhere on the plate.
fates = (
    pd.concat([pd.read_csv(f).assign(sample=s) for f, s in zip(fate_csvs, samples)])
    .merge(samples_df, validate="many_to_one", on="sample")
    .assign(
        fate_counts=lambda x: x.groupby("fate")["count"].transform("sum"),
        sample_well=lambda x: x["sample_noplate"] + " (" + x["well"] + ")",
    )
    .query("fate_counts > 0")[  # only keep fates with at least one count
        ["fate", "count", "well", "serum_replicate", "sample_well", "dilution_factor"]
    ]
)

assert len(fates) == len(fates.drop_duplicates())

serum_replicates = sorted(fates["serum_replicate"].unique())

# Order (by serum-replicate, then dilution) in which samples are listed on the y-axis of
# the per-well charts. `fates` has one row per fate for each sample, so de-duplicate
# (keeping first occurrence) so that each sample is listed once.
sample_wells = (
    fates.sort_values(["serum_replicate", "dilution_factor"])["sample_well"]
    .drop_duplicates()
    .tolist()
)


# dropdown (bound below the charts) to show one serum-replicate or all of them
serum_selection = alt.selection_point(
    fields=["serum_replicate"],
    bind=alt.binding_select(
        options=[None] + serum_replicates,
        labels=["all"] + serum_replicates,
        name="serum",
    ),
)

fates_chart = (
    alt.Chart(fates)
    .add_params(serum_selection)
    .transform_filter(serum_selection)
    .encode(
        alt.X("count", scale=alt.Scale(nice=False, padding=3)),
        alt.Y(
            "sample_well",
            title=None,
            sort=sample_wells,
        ),
        alt.Color("fate", sort=sorted(fates["fate"].unique(), reverse=True)),
        alt.Order("fate", sort="descending"),
        tooltip=fates.columns.tolist(),
    )
    .mark_bar(height={"band": 0.85})
    .properties(
        height=alt.Step(10),
        width=200,
        title=f"Barcode parsing for {plate}",
    )
    .configure_axis(grid=False)
)

fates_chart

## Read barcode counts and apply manually specified drops
Read the counts per barcode:

In [None]:
# get barcode counts
counts = (
    pd.concat([pd.read_csv(c).assign(sample=s) for c, s in zip(count_csvs, samples)])
    .merge(samples_df, validate="many_to_one", on="sample")
    .drop(columns=["replicate", "plate", "fastq"])
    .assign(sample_well=lambda x: x["sample_noplate"] + " (" + x["well"] + ")")
)

# classify barcodes as viral or neut standard
barcode_class = pd.concat(
    [
        pd.read_csv(viral_library_csv)[["barcode", "strain"]].assign(
            neut_standard=False,
        ),
        pd.read_csv(neut_standard_set_csv)[["barcode"]].assign(
            neut_standard=True,
            strain=pd.NA,
        ),
    ],
    ignore_index=True,
)

# merge counts and classification of barcodes
assert set(counts["barcode"]) == set(barcode_class["barcode"])
counts = counts.merge(barcode_class, on="barcode", validate="many_to_one")
assert set(sample_wells) == set(counts["sample_well"])
assert set(serum_replicates) == set(counts["serum_replicate"])

Apply any manually specified data drops:

In [None]:
# Columns of `counts` that identify each type of dropped item. The drop types are plural
# while the columns are singular, and the two composite types are keyed by tuples of
# (barcode, other column) values.
manual_drop_cols = {
    "wells": ("well",),
    "barcodes": ("barcode",),
    "serum_replicates": ("serum_replicate",),
    "barcode_wells": ("barcode", "well"),
    "barcode_serum_replicates": ("barcode", "serum_replicate"),
}

for filter_type, filter_drops in manual_drops.items():
    print(f"\nDropping {len(filter_drops)} {filter_type} specified in manual_drops")
    assert filter_type in qc_drops
    # list entries were already converted to (hashable) tuples when reading manual_drops
    qc_drops[filter_type].update({w: "manual_drop" for w in filter_drops})
    drop_cols = manual_drop_cols[filter_type]
    assert set(drop_cols).issubset(counts.columns), f"{drop_cols=}, {counts.columns=}"
    if len(drop_cols) == 1:
        is_dropped = counts[drop_cols[0]].isin(list(qc_drops[filter_type])).to_numpy()
    else:
        # vectorized lookup of each row's (barcode, well / serum_replicate) tuple
        is_dropped = numpy.array(
            [
                key in qc_drops[filter_type]
                for key in zip(*(counts[col] for col in drop_cols))
            ],
            dtype=bool,
        )
    counts = counts[~is_dropped]

## Average counts per barcode in each well

Plot average counts per barcode.
If a sample has inadequate barcode counts, it may not have good enough statistics for accurate analysis, and a QC-threshold is applied:

In [None]:
# Mean count per barcode in each well; wells below the QC threshold are flagged (and
# dropped below) because they may lack the statistics needed for accurate analysis.
avg_barcode_counts = counts.groupby(
    ["well", "serum_replicate", "sample_well"],
    dropna=False,
    as_index=False,
).agg(avg_count=("count", "mean"))
avg_barcode_counts["fails_qc"] = (
    avg_barcode_counts["avg_count"] < qc_thresholds["avg_barcode_counts_per_well"]
)

# show floats with 3 significant digits in the tooltip
avg_barcode_counts_tooltip = [
    alt.Tooltip(col, format=".3g") if avg_barcode_counts[col].dtype == float else col
    for col in avg_barcode_counts.columns
]

avg_barcode_counts_chart = (
    alt.Chart(avg_barcode_counts)
    .add_params(serum_selection)
    .transform_filter(serum_selection)
    .encode(
        alt.X(
            "avg_count",
            title="average barcode counts per well",
            scale=alt.Scale(nice=False, padding=3),
        ),
        alt.Y("sample_well", sort=sample_wells),
        alt.Color(
            "fails_qc",
            title=f"fails {qc_thresholds['avg_barcode_counts_per_well']=}",
            legend=alt.Legend(titleLimit=500),
        ),
        tooltip=avg_barcode_counts_tooltip,
    )
    .mark_bar(height={"band": 0.85})
    .properties(
        height=alt.Step(10),
        width=250,
        title=f"Average barcode counts per well for {plate}",
    )
    .configure_axis(grid=False)
)

display(avg_barcode_counts_chart)

# record and drop the wells failing QC
avg_barcode_counts_per_well_drops = avg_barcode_counts.loc[
    avg_barcode_counts["fails_qc"], "well"
].tolist()
print(
    f"\nDropping {len(avg_barcode_counts_per_well_drops)} wells for failing "
    f"{qc_thresholds['avg_barcode_counts_per_well']=}: "
    + str(avg_barcode_counts_per_well_drops)
)
qc_drops["wells"].update(
    dict.fromkeys(avg_barcode_counts_per_well_drops, "avg_barcode_counts_per_well")
)
counts = counts.loc[~counts["well"].isin(qc_drops["wells"])]

## Fraction of counts from neutralization standard
Determine the fraction of counts from the neutralization standard in each sample, and make sure this fraction passes the QC threshold.

In [None]:
# Fraction of each well's total counts that come from neut-standard barcodes; wells
# below the QC minimum are flagged and dropped below.
neut_standard_fracs = (
    counts.assign(
        neut_standard_count=counts["count"] * counts["neut_standard"].astype(int)
    )
    .groupby(
        ["well", "serum_replicate", "sample_well"],
        dropna=False,
        as_index=False,
    )
    .agg(
        total_count=("count", "sum"),
        neut_standard_count=("neut_standard_count", "sum"),
    )
)
neut_standard_fracs["neut_standard_frac"] = (
    neut_standard_fracs["neut_standard_count"] / neut_standard_fracs["total_count"]
)
neut_standard_fracs["fails_qc"] = (
    neut_standard_fracs["neut_standard_frac"]
    < qc_thresholds["min_neut_standard_frac_per_well"]
)

# show floats with 3 significant digits in the tooltip
neut_standard_fracs_tooltip = [
    alt.Tooltip(col, format=".3g") if neut_standard_fracs[col].dtype == float else col
    for col in neut_standard_fracs.columns
]

neut_standard_fracs_chart = (
    alt.Chart(neut_standard_fracs)
    .add_params(serum_selection)
    .transform_filter(serum_selection)
    .encode(
        alt.X(
            "neut_standard_frac",
            title="frac counts from neutralization standard per well",
            scale=alt.Scale(nice=False, padding=3),
        ),
        alt.Y("sample_well", sort=sample_wells),
        alt.Color(
            "fails_qc",
            title=f"fails {qc_thresholds['min_neut_standard_frac_per_well']=}",
            legend=alt.Legend(titleLimit=500),
        ),
        tooltip=neut_standard_fracs_tooltip,
    )
    .mark_bar(height={"band": 0.85})
    .properties(
        height=alt.Step(10),
        width=250,
        title=f"Neutralization-standard fracs per well for {plate}",
    )
    .configure_axis(grid=False)
    .configure_legend(titleLimit=1000)
)

display(neut_standard_fracs_chart)

# record and drop the wells failing QC
min_neut_standard_frac_per_well_drops = neut_standard_fracs.loc[
    neut_standard_fracs["fails_qc"], "well"
].tolist()
print(
    f"\nDropping {len(min_neut_standard_frac_per_well_drops)} wells for failing "
    f"{qc_thresholds['min_neut_standard_frac_per_well']=}: "
    + str(min_neut_standard_frac_per_well_drops)
)
qc_drops["wells"].update(
    dict.fromkeys(
        min_neut_standard_frac_per_well_drops, "min_neut_standard_frac_per_well"
    )
)
counts = counts.loc[~counts["well"].isin(qc_drops["wells"])]

## Consistency and minimum fractions for barcodes
We examine the fraction of counts attributable to each barcode. We do this by splitting the data in two ways:

 1. Looking at all viral (but not neut-standard) barcodes only for the no-serum samples (wells).

 2. Looking at just the neut-standard barcodes for all samples (wells).

The reason is that if the experiment is set up perfectly, then for each barcode these fractions should be the same across all of the samples examined.
(We do not expect viral barcodes to have consistent fractions across samples that contain serum, as they will be neutralized differently depending on strain.)

We plot these fractions in interactive plots (you can mouseover points and zoom) so you can identify barcodes that fail the expected consistency QC thresholds.

We also make sure the barcodes meet specified QC minimum thresholds for all samples, and flag any that do not.

In [None]:
barcode_selection = alt.selection_point(fields=["barcode"], on="mouseover", empty=False)

# One pass per subset of the data:
#  - neut-standard barcodes (`is_neut_standard` True): all samples (wells);
#  - viral barcodes (`is_neut_standard` False): only the no-serum samples (wells).
# `counts` is re-filtered inside the loop, but the groups iterated over come from
# `counts` as it was when the loop started.
for is_neut_standard, df in counts.groupby("neut_standard"):
    if is_neut_standard:
        print(
            f"\n\n{'=' * 89}\nAnalyzing neut-standard barcodes from all samples (wells)"
        )
        qc_name = "per_neut_standard_barcode_filters"
    else:
        print(f"\n\n{'=' * 89}\nAnalyzing all barcodes from no-serum samples (wells)")
        qc_name = "no_serum_per_viral_barcode_filters"
        df = df.query("serum == 'none'")

    # Each barcode's fraction of its sample's counts (within this subset), plus the fold
    # change of that fraction from the barcode's median fraction across samples, taken
    # in whichever direction makes it >= 1 (inf / NaN when a fraction is zero).
    df = df.assign(
        sample_counts=lambda x: x.groupby("sample")["count"].transform("sum"),
        count_frac=lambda x: x["count"] / x["sample_counts"],
        median_count_frac=lambda x: x.groupby("barcode")["count_frac"].transform(
            "median"
        ),
        fold_change_from_median=lambda x: numpy.where(
            x["count_frac"] > x["median_count_frac"],
            x["count_frac"] / x["median_count_frac"],
            x["median_count_frac"] / x["count_frac"],
        ),
    )[
        [
            "barcode",
            "count",
            "well",
            "sample_well",
            "count_frac",
            "median_count_frac",
            "fold_change_from_median",
        ]
        + ([] if is_neut_standard else ["strain"])
    ]

    # A barcode-well fails if its fraction is below `min_frac` or its fold change from
    # the median exceeds `max_fold_change` (a NaN fold change also fails); the barcode
    # fails QC if it fails in at least `max_wells` wells.
    qc = qc_thresholds[qc_name]
    print(f"Apply QC {qc_name}: {qc}\n")
    fails_qc = (
        df.assign(
            fails_qc=lambda x: ~(
                (x["count_frac"] >= qc["min_frac"])
                & (x["fold_change_from_median"] <= qc["max_fold_change"])
            ),
        )
        .groupby("barcode", as_index=False)
        .aggregate(n_wells_fail_qc=pd.NamedAgg("fails_qc", "sum"))
        .assign(fails_qc=lambda x: x["n_wells_fail_qc"] >= qc["max_wells"])[
            ["barcode", "fails_qc"]
        ]
    )
    df = df.merge(fails_qc, on="barcode", validate="many_to_one")

    # make chart
    evenness_chart = (
        alt.Chart(df)
        .add_params(barcode_selection)
        .encode(
            alt.X(
                "count_frac",
                title=(
                    "barcode's fraction of neut standard counts"
                    if is_neut_standard
                    else "barcode's fraction of non-neut standard counts"
                ),
                scale=alt.Scale(nice=False, padding=5),
            ),
            alt.Y("sample_well", sort=sample_wells),
            alt.Fill(
                "fails_qc",
                title=f"fails {qc_name}",
                legend=alt.Legend(titleLimit=500),
            ),
            strokeWidth=alt.condition(barcode_selection, alt.value(2), alt.value(0)),
            size=alt.condition(barcode_selection, alt.value(60), alt.value(35)),
            tooltip=[
                alt.Tooltip(c, format=".2g") if df[c].dtype == float else c
                for c in df.columns
            ],
        )
        .mark_circle(fillOpacity=0.45, stroke="black", strokeOpacity=1)
        .properties(
            height=alt.Step(10),
            width=300,
            title=alt.TitleParams(
                (
                    f"{plate} all samples, neut-standard barcodes"
                    if is_neut_standard
                    else f"{plate} no-serum samples, all barcodes"
                ),
                subtitle="x-axis is zoomable (use mouse scroll/pan)",
            ),
        )
        .configure_axis(grid=False)
        .configure_legend(titleLimit=1000)
        .interactive()
    )

    display(evenness_chart)

    # drop barcodes failing QC, recording the QC filter actually applied as the reason
    barcode_drops = list(fails_qc.query("fails_qc")["barcode"])
    print(
        f"\nDropping {len(barcode_drops)} barcodes for failing {qc_name} {qc=}: "
        f"{barcode_drops}"
    )
    qc_drops["barcodes"].update({bc: qc_name for bc in barcode_drops})
    counts = counts[~counts["barcode"].isin(qc_drops["barcodes"])]

## Compute fraction infectivity

The fraction infectivity for viral barcode $v_b$ in sample $s$ is computed as:
$$
F_{v_b,s} = \frac{c_{v_b,s} / \left(\sum_{n_b} c_{n_b,s}\right)}{{\rm median}_{s_0}\left[ c_{v_b,s_0} / \left(\sum_{n_b} c_{n_b,s_0}\right)\right]}
$$
where
 - $c_{v_b,s}$ is the counts of viral barcode $v_b$ in sample $s$.
 - $\sum_{n_b} c_{n_b,s}$ is the sum of the counts for all neutralization standard barcodes $n_b$ for sample $s$.
 - $c_{v_b,s_0}$ is the counts of viral barcode $v_b$ in no-serum sample $s_0$.
 - $\sum_{n_b} c_{n_b,s_0}$ is the sum of the counts for all neutralization standard barcodes $n_b$ for no-serum sample $s_0$.
 - ${\rm median}_{s_0}\left[ c_{v_b,s_0} / \left(\sum_{n_b} c_{n_b,s_0}\right)\right]$ is the median taken across all no-serum samples of the counts of viral barcode $v_b$ versus the total counts for all neutralization standard barcodes.

First, compute the total neutralization-standard counts for each sample (well).
Plot these, and drop any wells that do not meet the QC threshold.

In [None]:
# Total neut-standard count in each well (the denominator of each barcode's ratio in the
# fraction-infectivity formula); wells below the QC minimum are flagged and dropped below.
neut_standard_counts = (
    counts.loc[counts["neut_standard"]]
    .groupby(
        ["well", "serum_replicate", "sample_well", "dilution_factor"],
        dropna=False,
        as_index=False,
    )
    .agg(neut_standard_count=("count", "sum"))
)
neut_standard_counts["fails_qc"] = (
    neut_standard_counts["neut_standard_count"]
    < qc_thresholds["min_neut_standard_count_per_well"]
)

# show floats with 3 significant digits in the tooltip
neut_standard_counts_tooltip = [
    alt.Tooltip(col, format=".3g") if neut_standard_counts[col].dtype == float else col
    for col in neut_standard_counts.columns
]

neut_standard_counts_chart = (
    alt.Chart(neut_standard_counts)
    .add_params(serum_selection)
    .transform_filter(serum_selection)
    .encode(
        alt.X(
            "neut_standard_count",
            title="counts from neutralization standard",
            scale=alt.Scale(nice=False, padding=3),
        ),
        alt.Y("sample_well", sort=sample_wells),
        alt.Color(
            "fails_qc",
            title=f"fails {qc_thresholds['min_neut_standard_count_per_well']=}",
            legend=alt.Legend(titleLimit=500),
        ),
        tooltip=neut_standard_counts_tooltip,
    )
    .mark_bar(height={"band": 0.85})
    .properties(
        height=alt.Step(10),
        width=250,
        title=f"Neutralization-standard counts for {plate}",
    )
    .configure_axis(grid=False)
    .configure_legend(titleLimit=1000)
)

display(neut_standard_counts_chart)

# record and drop the wells failing QC (from both the per-well totals and the counts)
min_neut_standard_count_per_well_drops = neut_standard_counts.loc[
    neut_standard_counts["fails_qc"], "well"
].tolist()
print(
    f"\nDropping {len(min_neut_standard_count_per_well_drops)} wells for failing "
    f"{qc_thresholds['min_neut_standard_count_per_well']=}: "
    + str(min_neut_standard_count_per_well_drops)
)
qc_drops["wells"].update(
    dict.fromkeys(
        min_neut_standard_count_per_well_drops, "min_neut_standard_count_per_well"
    )
)
neut_standard_counts = neut_standard_counts.loc[
    ~neut_standard_counts["well"].isin(qc_drops["wells"])
]
counts = counts.loc[~counts["well"].isin(qc_drops["wells"])]

Compute and plot the no-serum sample viral barcode counts and check if they pass the QC filters.

In [None]:
# Viral-barcode counts in the no-serum samples (wells), together with each well's total
# neut-standard count; used below to compute the no-serum baseline ratios.
no_serum_counts = (
    counts.query("serum == 'none'")
    .query("not neut_standard")
    .merge(neut_standard_counts, validate="many_to_one")[
        ["barcode", "strain", "well", "sample_well", "count", "neut_standard_count"]
    ]
    .assign(
        # NOTE(review): this "min" threshold also fails counts *equal* to it (`<=`),
        # whereas the other "min" QC thresholds in this notebook use `<`. One effect is
        # that a threshold of 0 still drops zero-count barcode-wells. Confirm intended.
        fails_qc=lambda x: (
            x["count"] <= qc_thresholds["min_no_serum_count_per_viral_barcode_well"]
        ),
    )
)

strains = sorted(no_serum_counts["strain"].unique())
strain_selection_dropdown = alt.selection_point(
    fields=["strain"],
    bind=alt.binding_select(
        options=[None] + strains,
        labels=["all"] + strains,
        name="virus strain",
    ),
)

# make chart
no_serum_counts_chart = (
    alt.Chart(no_serum_counts)
    .add_params(barcode_selection, strain_selection_dropdown)
    .transform_filter(strain_selection_dropdown)
    .encode(
        alt.X(
            "count", title="viral barcode count", scale=alt.Scale(nice=False, padding=5)
        ),
        alt.Y("sample_well", sort=sample_wells),
        alt.Fill(
            "fails_qc",
            title=f"fails {qc_thresholds['min_no_serum_count_per_viral_barcode_well']=}",
            legend=alt.Legend(titleLimit=500),
        ),
        strokeWidth=alt.condition(barcode_selection, alt.value(2), alt.value(0)),
        size=alt.condition(barcode_selection, alt.value(60), alt.value(35)),
        tooltip=no_serum_counts.columns.tolist(),
    )
    .mark_circle(fillOpacity=0.6, stroke="black", strokeOpacity=1)
    .properties(
        height=alt.Step(10),
        width=400,
        title=f"{plate} viral barcode counts in no-serum samples",
    )
    .configure_axis(grid=False)
    .configure_legend(titleLimit=1000)
    .interactive()
)

display(no_serum_counts_chart)

# drop barcode / wells failing QC
min_no_serum_count_per_viral_barcode_well_drops = list(
    no_serum_counts.query("fails_qc")[["barcode", "well"]].itertuples(
        index=False, name=None
    )
)
print(
    f"\nDropping {len(min_no_serum_count_per_viral_barcode_well_drops)} barcode-wells for failing "
    f"{qc_thresholds['min_no_serum_count_per_viral_barcode_well']=}: "
    + str(min_no_serum_count_per_viral_barcode_well_drops)
)
qc_drops["barcode_wells"].update(
    {
        w: "min_no_serum_count_per_viral_barcode_well"
        for w in min_no_serum_count_per_viral_barcode_well_drops
    }
)

# Remove dropped (barcode, well) pairs. The pairs are looked up directly in the drop
# record rather than built with a row-wise `apply(axis=1)`, which is slow on large
# frames; the explicit bool array also keeps the mask valid when a frame is empty.
barcode_well_drops = qc_drops["barcode_wells"]
no_serum_counts = no_serum_counts[
    numpy.array(
        [
            (bc, well) not in barcode_well_drops
            for bc, well in zip(no_serum_counts["barcode"], no_serum_counts["well"])
        ],
        dtype=bool,
    )
]
counts = counts[
    numpy.array(
        [
            (bc, well) not in barcode_well_drops
            for bc, well in zip(counts["barcode"], counts["well"])
        ],
        dtype=bool,
    )
]

Compute and plot the median ratio of viral barcode count to neut standard counts across no-serum samples.
If library composition is equal, all of these values should be similar:

In [None]:
# For each viral barcode, the median across no-serum wells of its count relative to the
# well's total neut-standard count (the denominator of the fraction infectivity).
median_no_serum_ratio = (
    no_serum_counts.assign(
        ratio=no_serum_counts["count"] / no_serum_counts["neut_standard_count"]
    )
    .groupby(["barcode", "strain"], as_index=False)
    .agg(median_no_serum_ratio=("ratio", "median"))
)

# highlight all barcodes of a strain on mouseover
strain_selection = alt.selection_point(fields=["strain"], on="mouseover", empty=False)

# show floats with 3 significant digits in the tooltip
median_no_serum_ratio_tooltip = [
    alt.Tooltip(col, format=".3g") if median_no_serum_ratio[col].dtype == float else col
    for col in median_no_serum_ratio.columns
]

median_no_serum_ratio_chart = (
    alt.Chart(median_no_serum_ratio)
    .add_params(strain_selection)
    .encode(
        alt.X(
            "median_no_serum_ratio",
            title="median ratio of counts",
            scale=alt.Scale(nice=False, padding=5),
        ),
        alt.Y(
            "barcode",
            sort=alt.SortField("median_no_serum_ratio", order="descending"),
            axis=alt.Axis(labelFontSize=5),
        ),
        color=alt.condition(strain_selection, alt.value("orange"), alt.value("gray")),
        tooltip=median_no_serum_ratio_tooltip,
    )
    .mark_bar(height={"band": 0.85})
    .properties(
        height=alt.Step(5),
        width=250,
        title=f"{plate} no-serum median ratio viral barcode to neut-standard barcode",
    )
    .configure_axis(grid=False)
    .configure_legend(titleLimit=1000)
)

display(median_no_serum_ratio_chart)

Compute the actual fraction infectivities.
We compute both the raw fraction infectivities and the ones with the ceiling applied:

In [None]:
# Fraction infectivity of each viral barcode in each serum-containing well: its count
# relative to the well's neut-standard count, divided by the median of that same ratio
# across the no-serum wells. Also store a copy clipped at the curve-fitting ceiling.
serum_viral_counts = counts.query("not neut_standard and serum != 'none'")
frac_infectivity = serum_viral_counts.merge(
    median_no_serum_ratio, validate="many_to_one"
).merge(neut_standard_counts, validate="many_to_one")

frac_infectivity["frac_infectivity_raw"] = (
    frac_infectivity["count"] / frac_infectivity["neut_standard_count"]
) / frac_infectivity["median_no_serum_ratio"]
frac_infectivity["frac_infectivity_ceiling"] = frac_infectivity[
    "frac_infectivity_raw"
].clip(upper=curvefit_params["frac_infectivity_ceiling"])
frac_infectivity["concentration"] = 1 / frac_infectivity["dilution_factor"]
frac_infectivity["plate_barcode"] = (
    frac_infectivity["plate_replicate"] + "-" + frac_infectivity["barcode"]
)

frac_infectivity = frac_infectivity[
    [
        "barcode",
        "plate_barcode",
        "well",
        "strain",
        "serum",
        "serum_replicate",
        "dilution_factor",
        "concentration",
        "frac_infectivity_raw",
        "frac_infectivity_ceiling",
    ]
]

# one row per serum / plate-barcode / dilution, with no missing values in key columns
assert len(
    frac_infectivity.groupby(["serum", "plate_barcode", "dilution_factor"])
) == len(frac_infectivity)
for required_col in [
    "dilution_factor",
    "frac_infectivity_raw",
    "frac_infectivity_ceiling",
]:
    assert frac_infectivity[required_col].notnull().all()

Plot the fraction infectivities, both the raw values and with the ceiling applied:

In [None]:
# Flag barcode-wells whose raw fraction infectivity exceeds the QC maximum, and reshape
# to long form so raw and ceiling-applied values can be plotted side by side.
frac_infectivity_chart_df = (
    frac_infectivity.assign(
        fails_qc=lambda x: (
            x["frac_infectivity_raw"]
            > qc_thresholds["max_frac_infectivity_per_viral_barcode_well"]
        ),
    )
    .melt(
        id_vars=[
            "barcode",
            "strain",
            "well",
            "serum_replicate",
            "dilution_factor",
            "fails_qc",
        ],
        value_vars=["frac_infectivity_raw", "frac_infectivity_ceiling"],
        var_name="ceiling_applied",
        value_name="frac_infectivity",
    )
    .assign(
        ceiling_applied=lambda x: x["ceiling_applied"].map(
            {
                "frac_infectivity_raw": "raw fraction infectivity",
                "frac_infectivity_ceiling": f"fraction infectivity with ceiling at {curvefit_params['frac_infectivity_ceiling']}",
            }
        )
    )
)

frac_infectivity_chart = (
    alt.Chart(frac_infectivity_chart_df)
    .add_params(strain_selection_dropdown, barcode_selection)
    .transform_filter(strain_selection_dropdown)
    .encode(
        alt.X(
            "dilution_factor",
            title="dilution factor",
            scale=alt.Scale(nice=False, padding=5, type="log"),
        ),
        alt.Y(
            "frac_infectivity",
            title="fraction infectivity",
            scale=alt.Scale(nice=False, padding=5),
        ),
        alt.Column(
            "ceiling_applied",
            sort="descending",
            title=None,
            header=alt.Header(labelFontSize=13, labelFontStyle="bold", labelPadding=2),
        ),
        alt.Row(
            "serum_replicate",
            title=None,
            spacing=3,
            header=alt.Header(labelFontSize=13, labelFontStyle="bold"),
        ),
        alt.Detail("barcode"),
        alt.Shape(
            "fails_qc",
            title=f"fails {qc_thresholds['max_frac_infectivity_per_viral_barcode_well']=}",
            legend=alt.Legend(titleLimit=500, orient="bottom"),
        ),
        color=alt.condition(
            barcode_selection, alt.value("black"), alt.value("MediumBlue")
        ),
        strokeWidth=alt.condition(barcode_selection, alt.value(3), alt.value(1)),
        opacity=alt.condition(barcode_selection, alt.value(1), alt.value(0.25)),
        tooltip=[
            (
                alt.Tooltip(c, format=".3g")
                if frac_infectivity_chart_df[c].dtype == float
                else c
            )
            for c in frac_infectivity_chart_df.columns
        ],
    )
    .mark_line(point=True)
    .properties(
        height=150,
        width=250,
        title=f"Fraction infectivities for {plate}",
    )
    .interactive(bind_x=False)
    .configure_axis(grid=False)
    .configure_legend(titleLimit=1000)
    .configure_point(size=50)
    .resolve_scale(x="independent", y="independent")
)

display(frac_infectivity_chart)

# drop barcode / wells failing QC (each appears twice in the long-form frame, once per
# ceiling setting, hence the de-duplication)
max_frac_infectivity_per_viral_barcode_well_drops = list(
    frac_infectivity_chart_df.query("fails_qc")[["barcode", "well"]]
    .drop_duplicates()
    .itertuples(index=False, name=None)
)
print(
    f"\nDropping {len(max_frac_infectivity_per_viral_barcode_well_drops)} barcode-wells for failing "
    f"{qc_thresholds['max_frac_infectivity_per_viral_barcode_well']=}: "
    + str(max_frac_infectivity_per_viral_barcode_well_drops)
)
qc_drops["barcode_wells"].update(
    {
        w: "max_frac_infectivity_per_viral_barcode_well"
        for w in max_frac_infectivity_per_viral_barcode_well_drops
    }
)

# Remove dropped (barcode, well) pairs via direct lookup in the drop record (faster than
# building tuples with a row-wise `apply(axis=1)`, and safe when the frame is empty).
barcode_well_drops = qc_drops["barcode_wells"]
frac_infectivity = frac_infectivity[
    numpy.array(
        [
            (bc, well) not in barcode_well_drops
            for bc, well in zip(frac_infectivity["barcode"], frac_infectivity["well"])
        ],
        dtype=bool,
    )
]

Check how many dilutions we have per barcode / serum-replicate:

In [None]:
# Number of distinct dilutions remaining for each barcode in each serum-replicate;
# barcode / serum-replicates with fewer than the QC minimum are flagged and dropped below.
n_dilutions = (
    frac_infectivity.groupby(["serum_replicate", "strain", "barcode"], as_index=False)
    .aggregate(**{"number of dilutions": pd.NamedAgg("dilution_factor", "nunique")})
    .assign(
        fails_qc=lambda x: (
            x["number of dilutions"]
            < qc_thresholds["min_dilutions_per_barcode_serum_replicate"]
        ),
    )
)

n_dilutions_chart = (
    alt.Chart(n_dilutions)
    .add_params(barcode_selection)
    .encode(
        alt.X("number of dilutions", scale=alt.Scale(nice=False, padding=4)),
        alt.Y("strain", title=None),
        alt.Column(
            "serum_replicate",
            title=None,
            header=alt.Header(labelFontSize=12, labelFontStyle="bold", labelPadding=0),
        ),
        alt.Fill(
            "fails_qc",
            title=f"fails {qc_thresholds['min_dilutions_per_barcode_serum_replicate']=}",
            legend=alt.Legend(titleLimit=500, orient="bottom"),
        ),
        strokeWidth=alt.condition(barcode_selection, alt.value(2), alt.value(0)),
        size=alt.condition(barcode_selection, alt.value(55), alt.value(35)),
        tooltip=[
            alt.Tooltip(c, format=".3g") if n_dilutions[c].dtype == float else c
            for c in n_dilutions.columns
        ],
    )
    .mark_circle(stroke="black", strokeOpacity=1, fillOpacity=0.45)
    .properties(
        height=alt.Step(10),
        width=120,
        title=alt.TitleParams(
            "number of dilutions for each barcode for each serum-replicate", dy=-2
        ),
    )
    # no grid lines, consistent with the other charts in this notebook
    .configure_axis(grid=False)
)

display(n_dilutions_chart)

# drop barcode / serum-replicates failing QC
min_dilutions_per_barcode_serum_replicate_drops = list(
    n_dilutions.query("fails_qc")[["barcode", "serum_replicate"]].itertuples(
        index=False, name=None
    )
)
print(
    f"\nDropping {len(min_dilutions_per_barcode_serum_replicate_drops)} barcode/serum-replicates for failing "
    f"{qc_thresholds['min_dilutions_per_barcode_serum_replicate']=}: "
    + str(min_dilutions_per_barcode_serum_replicate_drops)
)
qc_drops["barcode_serum_replicates"].update(
    {
        w: "min_dilutions_per_barcode_serum_replicate"
        for w in min_dilutions_per_barcode_serum_replicate_drops
    }
)

# Remove dropped (barcode, serum_replicate) pairs via direct lookup in the drop record
# (faster than a row-wise `apply(axis=1)`, and safe when the frame is empty).
barcode_serum_replicate_drops = qc_drops["barcode_serum_replicates"]
frac_infectivity = frac_infectivity[
    numpy.array(
        [
            (bc, serum_rep) not in barcode_serum_replicate_drops
            for bc, serum_rep in zip(
                frac_infectivity["barcode"], frac_infectivity["serum_replicate"]
            )
        ],
        dtype=bool,
    )
]

## Fit neutralization curves without applying QC to curves
First fit curves to all serum replicates, then we will apply QC on the curve fits.
Note that the fitting is done to the fraction infectivities **with** the ceiling:

In [None]:
# Fit curves to every barcode (treated as a replicate) of every strain within each
# serum-replicate, before any curve-level QC. The ceiling-capped fraction
# infectivities are the values that get fit.
fits_noqc = frac_infectivity.rename(
    columns=dict(
        frac_infectivity_ceiling="fraction infectivity",
        concentration="serum concentration",
    )
).pipe(
    neutcurve.CurveFits,
    conc_col="serum concentration",
    fracinf_col="fraction infectivity",
    virus_col="strain",
    serum_col="serum_replicate",
    replicate_col="barcode",
    fixtop=curvefit_params["fixtop"],
    fixbottom=curvefit_params["fixbottom"],
)

Determine which fits fail the curve-fitting QC, and plot the QC metrics for all fits.
Note that the plots flag as failing QC every barcode / serum-replicate that fails, even those for which QC is specified to be ignored (those will not be removed later):

In [None]:
# For each barcode / serum-replicate, tabulate the curve-fit QC metrics:
# - the maximum (ceiling-capped) fraction infectivity;
# - the fit R2 from `fits_noqc`.
# Then flag which QC checks fail, and whether QC is ignored. QC is ignored when the
# whole serum-replicate, or that specific barcode / serum-replicate, is listed.
fit_params_noqc = (
    frac_infectivity.groupby(["serum_replicate", "barcode"], as_index=False)
    .aggregate(max_frac_infectivity=pd.NamedAgg("frac_infectivity_ceiling", "max"))
    .merge(
        fits_noqc.fitParams(average_only=False, no_average=True)[
            ["serum", "virus", "replicate", "r2"]
        ].rename(columns={"serum": "serum_replicate", "replicate": "barcode"}),
        validate="one_to_one",
    )
    .assign(
        fails_max_frac_infectivity_at_least=lambda x: (
            x["max_frac_infectivity"] < curvefit_qc["max_frac_infectivity_at_least"]
        ),
        fails_min_R2=lambda x: x["r2"] < curvefit_qc["min_R2"],
        fails_qc=lambda x: x["fails_max_frac_infectivity_at_least"] | x["fails_min_R2"],
        # Zip over the columns rather than a row-wise ``apply(axis=1)``: this is
        # faster, and with no rows ``apply`` returns a DataFrame, which breaks
        # ``assign``.
        ignore_qc=lambda x: [
            (serum_replicate in curvefit_qc["serum_replicates_ignore_curvefit_qc"])
            or (
                (barcode, serum_replicate)
                in curvefit_qc["barcode_serum_replicates_ignore_curvefit_qc"]
            )
            for barcode, serum_replicate in zip(x["barcode"], x["serum_replicate"])
        ],
    )
)

print(f"Plotting barcode / serum-replicates that fail {curvefit_qc=}\n")

# One chart per QC metric: (axis title, column in fit_params_noqc).
for prop, col in [
    ("max frac infectivity", "max_frac_infectivity"),
    ("curve fit R2", "r2"),
]:
    fit_params_noqc_chart = (
        alt.Chart(fit_params_noqc)
        .add_params(barcode_selection)
        .encode(
            alt.X(col, title=prop, scale=alt.Scale(nice=False, padding=4)),
            alt.Y("virus", title=None),
            alt.Fill("fails_qc"),
            alt.Column(
                "serum_replicate",
                title=None,
                header=alt.Header(
                    labelFontSize=12, labelFontStyle="bold", labelPadding=0
                ),
            ),
            strokeWidth=alt.condition(barcode_selection, alt.value(2), alt.value(0)),
            size=alt.condition(barcode_selection, alt.value(55), alt.value(35)),
            tooltip=[
                alt.Tooltip(c, format=".3g") if fit_params_noqc[c].dtype == float else c
                for c in fit_params_noqc.columns
            ],
        )
        .mark_circle(stroke="black", strokeOpacity=1, fillOpacity=0.55)
        .properties(
            height=alt.Step(10),
            width=120,
            title=alt.TitleParams(f"{prop} for each barcode serum-replicate", dy=-2),
        )
    )
    display(fit_params_noqc_chart)

Now get all barcode / serum-replicate pairs that fail any of the QC checks.
Plot curves for just these virus / serum-replicates (we plot all barcodes for a virus even if just one fails QC), and then drop the failing pairs, except those for which QC is specified to be ignored:

In [None]:
# All barcode / serum-replicates failing any curve-fit QC check, including those for
# which QC is ignored.
barcode_serum_replicates_fail_qc = fit_params_noqc.query("fails_qc").reset_index(
    drop=True
)
print(f"Here are barcode / serum-replicates that fail {curvefit_qc=}")
display(barcode_serum_replicates_fail_qc)

if len(barcode_serum_replicates_fail_qc):
    print("\nCurves for viruses and serum-replicates with at least one failed barcode:")
    fig, _ = fits_noqc.plotReplicates(
        sera=sorted(barcode_serum_replicates_fail_qc["serum_replicate"].unique()),
        viruses=sorted(barcode_serum_replicates_fail_qc["virus"].unique()),
        attempt_shared_legend=False,
        legendfontsize=8,
        titlesize=10,
        ticksize=10,
        ncol=6,
    )
    display(fig)
    plt.close(fig)


def _not_qc_dropped(df):
    """Boolean mask that is True for rows of `df` whose (barcode, serum_replicate)
    pair is not in ``qc_drops["barcode_serum_replicates"]``.

    The mask is built by zipping the two columns. A row-wise ``apply(axis=1)`` would
    be slower, and with no rows it returns a DataFrame instead of a Series, which
    breaks the filter.
    """
    dropped = qc_drops["barcode_serum_replicates"]
    return numpy.array(
        [(b, s) not in dropped for b, s in zip(df["barcode"], df["serum_replicate"])],
        dtype=bool,
    )


# drop barcode / serum-replicates failing QC (unless QC is ignored for them).
# fit_params_noqc is pruned after each filter, so a pair that fails several filters
# is recorded and reported only under the first one it fails.
for qc_filter in ["max_frac_infectivity_at_least", "min_R2"]:
    fits_qc_drops = list(
        fit_params_noqc.query(f"fails_{qc_filter} and (not ignore_qc)")[
            ["barcode", "serum_replicate"]
        ].itertuples(index=False, name=None)
    )
    print(
        f"\nDropping {len(fits_qc_drops)} barcode/serum-replicates for failing "
        f"{qc_filter}={curvefit_qc[qc_filter]}: " + str(fits_qc_drops)
    )
    qc_drops["barcode_serum_replicates"].update({w: qc_filter for w in fits_qc_drops})
    fit_params_noqc = fit_params_noqc[_not_qc_dropped(fit_params_noqc)]

# The drops only accumulate across filters, so filtering once against the final drops
# is equivalent to filtering after every QC filter.
frac_infectivity = frac_infectivity[_not_qc_dropped(frac_infectivity)]

## Fit neutralization curves after applying QC
Now we re-fit curves after applying all the QC:

In [None]:
# Re-fit the curves on the QC-filtered, ceiling-capped fraction infectivities. Unlike
# the pre-QC fit, sera come from the `serum` column and replicates from the
# `plate_barcode` column.
fits_qc = frac_infectivity.rename(
    columns=dict(
        frac_infectivity_ceiling="fraction infectivity",
        concentration="serum concentration",
    )
).pipe(
    neutcurve.CurveFits,
    conc_col="serum concentration",
    fracinf_col="fraction infectivity",
    virus_col="strain",
    serum_col="serum",
    replicate_col="plate_barcode",
    fixtop=curvefit_params["fixtop"],
    fixbottom=curvefit_params["fixbottom"],
)

# Per-replicate fit parameters, without adding replicate-averaged rows.
fit_params_qc = fits_qc.fitParams(average_only=False, no_average=True)

# Sanity check: the post-QC fits should not outnumber the pre-QC fits.
assert len(fit_params_qc) <= len(
    fits_noqc.fitParams(no_average=True, average_only=False)
)

Plot all the curves that passed QC:

In [None]:
# Plot every replicate curve that survived QC, one panel per serum / virus, 6 per row.
_ = fits_qc.plotReplicates(
    ncol=6,
    ticksize=10,
    titlesize=10,
    legendfontsize=8,
    attempt_shared_legend=False,
)

## Save results to files

In [None]:
print(f"Writing fraction infectivities to {frac_infectivity_csv}")
(
    frac_infectivity[
        [
            "serum",
            "strain",
            "plate_barcode",
            "dilution_factor",
            "frac_infectivity_raw",
            "frac_infectivity_ceiling",
        ]
    ]
    .sort_values(["serum", "plate_barcode", "dilution_factor"])
    .to_csv(frac_infectivity_csv, index=False, float_format="%.4g")
)

print(f"\nWriting fit parameters to {fits_csv}")
(
    fit_params_qc.drop(columns=["nreplicates", "ic50_str"]).to_csv(
        fits_csv, index=False, float_format="%.4g"
    )
)

# NOTE: pickle files execute code on load; only unpickle this from trusted sources.
print(f"\nPickling neutcurve.CurveFits object for these data to {fits_pickle}")
with open(fits_pickle, "wb") as f:
    pickle.dump(fits_qc, f)

print(f"\nWriting QC drops to {qc_drops_yaml}")


def tup_to_str(x):
    """Turn a tuple key (e.g. a barcode / serum-replicate pair) into a YAML-friendly
    space-separated string; non-tuple keys are returned unchanged.

    Elements are passed through ``str`` so that a non-string element (for instance a
    numeric identifier) does not make ``str.join`` raise ``TypeError``.
    """
    return " ".join(map(str, x)) if isinstance(x, tuple) else x


# qc_drops maps each drop category to a dict of {dropped item: reason}; tuple keys
# are flattened to strings for the YAML output.
qc_drops_for_yaml = {
    key: {tup_to_str(key2): val2 for key2, val2 in val.items()}
    for key, val in qc_drops.items()
}
with open(qc_drops_yaml, "w", encoding="utf-8") as f:
    yaml.YAML(typ="rt").dump(qc_drops_for_yaml, f)
print("\nHere are the QC drops:\n***************************")
yaml.YAML(typ="rt").dump(qc_drops_for_yaml, sys.stdout)