# Quality-control counts and use them to compute fraction infectivity
This notebook is designed to be run using `snakemake`, and analyzes a plate of sequencing-based neutralization assays.

## Setup
Import Python modules:

In [None]:
import altair as alt

import neutcurve

import pandas as pd

# Lift Altair's default cap on the number of data rows embedded in a chart, since the
# charts below embed per-barcode / per-sample rows; `_ =` suppresses the cell's output.
_ = alt.data_transformers.disable_max_rows()

Get the variables passed by `snakemake`:

In [None]:
# Inputs, outputs, params, and wildcards injected into the notebook by `snakemake`.
count_csvs = snakemake.input.count_csvs
fate_csvs = snakemake.input.fate_csvs
viral_library_csv = snakemake.input.viral_library_csv
neut_standard_set_csv = snakemake.input.neut_standard_set_csv
frac_infectivity_csv = snakemake.output.frac_infectivity_csv
qc_failures_file = snakemake.output.qc_failures
samples = snakemake.params.samples
plate_params = snakemake.params.plate_params
plate = snakemake.wildcards.plate

# per-sample information for this plate
samples_df = plate_params["samples"]

# every sample must have exactly one counts CSV and one fates CSV
n_samples = len(samples)
assert n_samples == len(count_csvs) == len(fate_csvs) == len(samples_df)

print(f"Processing {plate=}")

# names of QC checks that fail for this plate; later cells add to this set
qc_failures = set()

# QC thresholds used throughout the notebook
qc_thresholds = plate_params["process_counts_qc_thresholds"]
display(pd.Series(qc_thresholds))

# optionally remove the wells (samples) listed in `wells_to_drop` before any analysis
wells_to_drop = plate_params["wells_to_drop"]
if wells_to_drop:
    known_wells = set(samples_df["well"])
    if not set(wells_to_drop) <= known_wells:
        raise ValueError(f"{wells_to_drop=} not all in `samples_df`")
    print("Dropping the following wells (samples):")
    display(samples_df.query("well in @wells_to_drop"))
    samples_df = samples_df.query("well not in @wells_to_drop")

## Statistics on barcode-parsing for each sample
Make interactive chart of the "fates" of the sequencing reads parsed for each sample on the plate.

If most sequencing reads are not "valid barcodes", this could potentially indicate some problem in the sequencing or barcode set you are parsing.

Potential fates are:
 - *valid barcode*: barcode that matches a known virus or neutralization standard, we hope most reads are this.
 - *invalid barcode*: a barcode with proper flanking sequences that does not match a known virus or neutralization standard. If you have a lot of reads of this type, it is probably a good idea to look at the invalid barcode CSVs (in the `./results/barcode_invalid/` subdirectory created by the pipeline) to see what these invalid barcodes are.
 - *unparseable barcode*: could not parse a barcode from this read as there was not a sequence of the correct length with the appropriate flanking sequence.
 - *low quality barcode*: low-quality or `N` nucleotides in barcode, could indicate problem with sequencing.
 - *failed chastity filter*: reads that failed the Illumina chastity filter, if these are reported in the FASTQ (they may not be).

Also, if the number of reads per sample is very uneven, that could indicate that you did not do a good job of balancing the different samples in the Illumina sequencing.

In [None]:
# Read the read-fate counts for every sample and annotate them with sample metadata.
fates_per_sample = [
    pd.read_csv(fate_csv).assign(sample=sample)
    for fate_csv, sample in zip(fate_csvs, samples)
]
fates_merged = pd.concat(fates_per_sample).merge(
    samples_df, validate="many_to_one", on="sample"
)

# keep only fates observed at least once anywhere on the plate
fate_is_observed = fates_merged.groupby("fate")["count"].transform("sum") > 0
fates = fates_merged[fate_is_observed][
    ["fate", "count", "well", "serum", "sample_noplate", "dilution_factor"]
]

assert not fates.duplicated().any()

# dropdown that filters charts by serum; the `None` option (labeled "all") shows every serum
serum_names = sorted(fates["serum"].unique().tolist())
serum_selection = alt.selection_point(
    fields=["serum"],
    bind=alt.binding_select(
        options=[None, *serum_names],
        labels=["all", *[str(name) for name in serum_names]],
        name="serum",
    ),
)

# order samples by serum then dilution, and fates in reverse alphabetical order
fates_sample_order = list(fates.sort_values(["serum", "dilution_factor"])["sample_noplate"])
fate_order = sorted(fates["fate"].unique(), reverse=True)

fates_chart = (
    alt.Chart(fates)
    .add_params(serum_selection)
    .transform_filter(serum_selection)
    .encode(
        alt.X("count", scale=alt.Scale(nice=False, padding=3)),
        alt.Y("sample_noplate", title=None, sort=fates_sample_order),
        alt.Color("fate", sort=fate_order),
        alt.Order("fate", sort="descending"),
        tooltip=fates.columns.tolist(),
    )
    .mark_bar(height={"band": 0.85})
    .properties(
        height=alt.Step(10),
        width=200,
        title=f"Barcode parsing for {plate}",
    )
    .configure_axis(grid=False)
)

fates_chart

## Counts per barcode

First get the counts per barcode and classification information on these barcodes:

In [None]:
# get barcode counts
counts = (
    pd.concat([pd.read_csv(c).assign(sample=s) for c, s in zip(count_csvs, samples)])
    .merge(samples_df, validate="many_to_one", on="sample")
    .drop(columns=["replicate", "plate", "fastq"])
)

# get classification of barcodes as viral or neut standard
barcode_class = pd.concat(
    [
        pd.read_csv(viral_library_csv)[["barcode", "strain"]].assign(
            neut_standard=False,
        ),
        pd.read_csv(neut_standard_set_csv)[["barcode"]].assign(
            neut_standard=True, strain=pd.NA,
        ),
    ],
    ignore_index=True
)

# merge counts and classification of barcodes
assert set(counts["barcode"]) == set(barcode_class["barcode"])
counts = counts.merge(barcode_class, on="barcode", validate="many_to_one")

Drop any barcodes that are specified to drop:

In [None]:
# Optionally drop user-specified barcodes; every one must exist in the counts.
barcodes_to_drop = plate_params["barcodes_to_drop"]

if not len(barcodes_to_drop):
    print("No barcodes specified to drop.")
else:
    print(
        "The following barcodes are specified to drop:\n\t"
        + "\n\t".join(barcodes_to_drop)
    )
    missing_barcodes = set(barcodes_to_drop) - set(counts["barcode"])
    if missing_barcodes:
        raise ValueError(f"Barcodes to drop do not exist: {missing_barcodes}")
    counts = counts.query("barcode not in @barcodes_to_drop")

Plot average counts per barcode, and make sure that these pass the QC threshold.
If a sample has inadequate barcode counts, it may not have good enough statistics for accurate analysis:

In [None]:
# Mean count per barcode in each sample; a sample passes QC if that mean reaches the threshold.
min_avg_barcode_count = qc_thresholds["avg_barcode_counts"]

avg_barcode_counts = (
    counts
    .groupby(
        ["well", "serum", "dilution_factor", "sample_noplate"],
        dropna=False,
        as_index=False,
    )
    .aggregate(avg_count=pd.NamedAgg("count", "mean"))
    .assign(passes_qc=lambda df: df["avg_count"].ge(min_avg_barcode_count))
)

# samples ordered by serum then dilution; float columns get 3 significant digits in tooltips
avg_barcode_counts_sample_order = list(
    avg_barcode_counts.sort_values(["serum", "dilution_factor"])["sample_noplate"]
)
avg_barcode_counts_tooltips = [
    alt.Tooltip(col, format=".3g") if avg_barcode_counts[col].dtype == float else col
    for col in avg_barcode_counts.columns
]

avg_barcode_counts_chart = (
    alt.Chart(avg_barcode_counts)
    .add_params(serum_selection)
    .transform_filter(serum_selection)
    .encode(
        alt.X(
            "avg_count",
            title="average counts per barcode",
            scale=alt.Scale(nice=False, padding=3),
        ),
        alt.Y("sample_noplate", title=None, sort=avg_barcode_counts_sample_order),
        alt.Color(
            "passes_qc",
            title=f"passes QC threshold {min_avg_barcode_count}",
            scale=alt.Scale(domain=[True, False]),
        ),
        tooltip=avg_barcode_counts_tooltips,
    )
    .mark_bar(height={"band": 0.85})
    .properties(
        height=alt.Step(10),
        width=250,
        title=f"Average barcode counts for {plate}",
    )
    .configure_axis(grid=False)
)

display(avg_barcode_counts_chart)

if avg_barcode_counts["passes_qc"].all():
    print(f"\nAll samples passed {qc_thresholds['avg_barcode_counts']=}")
else:
    qc_failures.add("avg_barcode_counts")
    print(f"\nThe following samples failed {qc_thresholds['avg_barcode_counts']=}")
    display(avg_barcode_counts.query("not passes_qc").reset_index(drop=True))

## Fraction of counts from neutralization standard
Determine the fraction of counts from the neutralization standard in each sample, and make sure this fraction passes the QC threshold.

In [None]:
# Fraction of each sample's counts that come from neut-standard barcodes.
# QC: every sample needs at least `min_neut_standard_frac`; no-serum samples additionally
# may have at most `max_neut_standard_frac_no_serum`.
min_neut_standard_frac = qc_thresholds["min_neut_standard_frac"]
max_neut_standard_frac_no_serum = qc_thresholds["max_neut_standard_frac_no_serum"]

neut_standard_fracs = (
    counts
    .assign(neut_standard_count=lambda df: df["count"] * df["neut_standard"].astype(int))
    .groupby(
        ["well", "serum", "dilution_factor", "sample_noplate"],
        dropna=False,
        as_index=False,
    )
    .aggregate(
        total_count=pd.NamedAgg("count", "sum"),
        neut_standard_count=pd.NamedAgg("neut_standard_count", "sum"),
    )
    .assign(neut_standard_frac=lambda df: df["neut_standard_count"] / df["total_count"])
    .assign(
        passes_qc=lambda df: df["neut_standard_frac"].ge(min_neut_standard_frac)
        & (
            df["serum"].ne("none")
            | df["neut_standard_frac"].le(max_neut_standard_frac_no_serum)
        ),
    )
)

neut_standard_qc_desc = (
    f"neut standard frac >= {min_neut_standard_frac}, "
    f"<= {max_neut_standard_frac_no_serum} for no-serum samples"
)

neut_standard_fracs_sample_order = list(
    neut_standard_fracs.sort_values(["serum", "dilution_factor"])["sample_noplate"]
)
neut_standard_fracs_tooltips = [
    alt.Tooltip(col, format=".3g") if neut_standard_fracs[col].dtype == float else col
    for col in neut_standard_fracs.columns
]

neut_standard_fracs_chart = (
    alt.Chart(neut_standard_fracs)
    .add_params(serum_selection)
    .transform_filter(serum_selection)
    .encode(
        alt.X(
            "neut_standard_frac",
            title="fraction of counts from neutralization standard",
            scale=alt.Scale(nice=False, padding=3),
        ),
        alt.Y("sample_noplate", title=None, sort=neut_standard_fracs_sample_order),
        alt.Color(
            "passes_qc",
            title=neut_standard_qc_desc,
            scale=alt.Scale(domain=[True, False]),
        ),
        tooltip=neut_standard_fracs_tooltips,
    )
    .mark_bar(height={"band": 0.85})
    .properties(
        height=alt.Step(10),
        width=250,
        title=f"Neutralization-standard fractions for {plate}",
    )
    .configure_axis(grid=False)
    .configure_legend(titleLimit=1000)
)

display(neut_standard_fracs_chart)

if neut_standard_fracs["passes_qc"].all():
    print(f"\nAll samples passed {neut_standard_qc_desc}")
else:
    qc_failures.add("min_neut_standard_frac or max_neut_standard_frac_no_serum")
    print(f"\nThe following samples failed {neut_standard_qc_desc}")
    display(neut_standard_fracs.query("not passes_qc").reset_index(drop=True))

## Consistency and minimum fractions for barcodes
We examine the fraction of counts attributable to each barcode, splitting the data in two ways:

 1. Looking at all viral (but not neut-standard) barcodes only for the no-serum samples.

 2. Looking at just the neut-standard barcodes for all samples.

The reason is that if the experiment is set up perfectly, these fractions should be the same across all samples for each barcode.
(We do not examine viral barcodes in samples with serum, because there they are expected to have inconsistent fractions: each strain is neutralized to a different extent.)

We plot these fractions in interactive plots (you can mouseover points and zoom) so you can identify barcodes that fail the expected consistency QC thresholds.

We also make sure the barcodes meet specified QC minimum thresholds for all samples, and flag any that do not.

In [None]:
# Mouseover selection that highlights every point of one barcode; with `empty=False`,
# nothing is highlighted when no barcode is under the cursor.
barcode_selection = alt.selection_point(fields=["barcode"], on="mouseover", empty=False)

# look at all samples for neut standard barcodes, or no-serum samples for all barcodes
# (iterates over the distinct values of the boolean `neut_standard` column:
# False -> viral barcodes, True -> neut-standard barcodes)
for is_neut_standard, df in counts.groupby("neut_standard"):

    # process data frame
    # viral barcodes are only examined in the no-serum samples
    if not is_neut_standard:
        df = df.query("serum == 'none'")
    # sample_counts is summed within this subset only, so count_frac is the barcode's
    # share of the neut-standard (or viral) counts in its sample, not of all counts.
    # fold_change_from_median compares that share to the barcode's median share
    # across the samples in this subset.
    df = (
        df.assign(
            sample_counts=lambda x: x.groupby("sample")["count"].transform("sum"),
            count_frac=lambda x: x["count"] / x["sample_counts"],
            median_count_frac=lambda x: x.groupby("barcode")["count_frac"].transform(
                "median"
            ),
            fold_change_from_median=lambda x: x["count_frac"] / x["median_count_frac"],
        )
        # drop bookkeeping columns; `strain` is NA for neut-standard barcodes (it is set
        # to pd.NA when building `barcode_class`), and `dilution_factor` is dropped for
        # the no-serum subset (presumably not meaningful there -- confirm)
        .drop(
            columns=(
                ["sample", "serum_replicate", "sample_counts", "neut_standard"]
                + (["strain"] if is_neut_standard else ["dilution_factor"])
            ),
        )
    )

    # make chart
    evenness_chart = (
        alt.Chart(df)
        .add_params(barcode_selection)
        .encode(
            alt.X(
                "count_frac",
                title=(
                    "barcode's fraction of neut standard counts"
                    if is_neut_standard
                    else "barcode's fraction of non-neut standard counts"
                ),
                scale=alt.Scale(nice=False, padding=5),
            ),
            # sample order is taken from `neut_standard_fracs`, which has one row per sample
            alt.Y(
                "sample_noplate",
                title=None,
                sort=list(
                    neut_standard_fracs.sort_values(["serum", "dilution_factor"])["sample_noplate"]
                ),
            ),
            alt.Fill("barcode", legend=None),
            # outline and enlarge the points of the moused-over barcode
            strokeWidth=alt.condition(barcode_selection, alt.value(2), alt.value(0)),
            size=alt.condition(barcode_selection, alt.value(60), alt.value(35)),
            tooltip=[
                alt.Tooltip(c, format=".3g") if df[c].dtype == float
                else c
                for c in df.columns
            ],
        )
        .mark_circle(fillOpacity=0.6, stroke="black", strokeOpacity=1)
        .properties(
            height=alt.Step(10),
            width=300,
            title=(
                f"{plate} all samples, neut-standard barcodes"
                if is_neut_standard
                else f"{plate} no-serum samples, all barcodes"
            ),
        )
        .configure_axis(grid=False)
        .configure_legend(titleLimit=1000)
        .interactive()
    )

    # only the neut-standard chart spans multiple sera, so only it gets the serum dropdown
    if is_neut_standard:
        evenness_chart = evenness_chart.add_params(serum_selection).transform_filter(
            serum_selection
        )
        print(f"\n\n{'=' * 89}\nAnalyzing neut-standard barcodes from all samples\n")
    else:
        print(f"\n\n{'=' * 89}\nAnalyzing all barcodes from no-serum samples\n")

    display(evenness_chart)

    # make sure barcode fractions are reasonably consistent when they should be:
    # a fold change from the median at or beyond the threshold in either direction fails.
    # NOTE(review): NaN fold changes (e.g. 0/0) compare False on both sides and so pass.
    excess_fold_change = df[
        (df["fold_change_from_median"] <= 1 / qc_thresholds["barcode_frac_consistency"])
        | (df["fold_change_from_median"] >= qc_thresholds["barcode_frac_consistency"])
    ]
    if len(excess_fold_change):
        print(f"\nFollowing barcodes failed {qc_thresholds['barcode_frac_consistency']=}")
        display(excess_fold_change)
        qc_failures.add("barcode_frac_consistency")
    else:
        print(f"\nPassed {qc_thresholds['barcode_frac_consistency']=}")

    # make sure barcodes have sufficient fraction; neut-standard and viral barcodes
    # use separate minimum thresholds
    if is_neut_standard:
        insufficient_neut_standard_barcode_frac = df[
            df["count_frac"] < qc_thresholds["min_neut_standard_barcode_frac"]
        ]
        if len(insufficient_neut_standard_barcode_frac):
            print(
                "\nFollowing barcodes fail "
                + f"{qc_thresholds['min_neut_standard_barcode_frac']=}"
            )
            display(insufficient_neut_standard_barcode_frac)
            qc_failures.add("min_neut_standard_barcode_frac")
        else:
            print(f"\nPassed {qc_thresholds['min_neut_standard_barcode_frac']=}")
    else:
        insufficient_viral_barcode_frac = df[
            df["count_frac"] < qc_thresholds["min_viral_barcode_frac"]
        ]
        if len(insufficient_viral_barcode_frac):
            print(
                f"\nFollowing barcodes fail {qc_thresholds['min_viral_barcode_frac']=}"
            )
            display(insufficient_viral_barcode_frac)
            qc_failures.add("min_viral_barcode_frac")
        else:
            print(f"\nPassed {qc_thresholds['min_viral_barcode_frac']=}")

## Compute fraction infectivity

The fraction infectivity for viral barcode $v_b$ in sample $s$ is computed as:
$$
F_{v_b,s} = \frac{c_{v_b,s} / \left(\sum_{n_b} c_{n_b,s}\right)}{{\rm median}_{s_0}\left[ c_{v_b,s_0} / \left(\sum_{n_b} c_{n_b,s_0}\right)\right]}
$$
where
 - $c_{v_b,s}$ is the counts of viral barcode $v_b$ in sample $s$.
 - $\sum_{n_b} c_{n_b,s}$ is the sum of the counts for all neutralization standard barcodes $n_b$ for sample $s$.
 - $c_{v_b,s_0}$ is the counts of viral barcode $v_b$ in no-serum sample $s_0$.
 - $\sum_{n_b} c_{n_b,s_0}$ is the sum of the counts for all neutralization standard barcodes $n_b$ for no-serum sample $s_0$.
 - ${\rm median}_{s_0}\left[ c_{v_b,s_0} / \left(\sum_{n_b} c_{n_b,s_0}\right)\right]$ is the median taken across all no-serum samples of the counts of viral barcode $v_b$ versus the total counts for all neutralization standard barcodes.

First, compute the total neutralization-standard counts for each sample.
Plot these, and make sure they meet the QC threshold.

In [None]:
# Total neut-standard counts in each sample; a sample passes QC if it has at least the
# minimum. These totals are the denominators of the fraction-infectivity ratios.
min_neut_standard_count = qc_thresholds["min_neut_standard_count"]

neut_standard_counts = (
    counts
    .query("neut_standard")
    .groupby(
        ["well", "serum", "sample_noplate", "dilution_factor"],
        dropna=False,
        as_index=False,
    )
    .aggregate(neut_standard_count=pd.NamedAgg("count", "sum"))
    .assign(passes_qc=lambda df: df["neut_standard_count"].ge(min_neut_standard_count))
)

neut_standard_counts_sample_order = list(
    neut_standard_counts.sort_values(["serum", "dilution_factor"])["sample_noplate"]
)
neut_standard_counts_tooltips = [
    alt.Tooltip(col, format=".3g") if neut_standard_counts[col].dtype == float else col
    for col in neut_standard_counts.columns
]

neut_standard_counts_chart = (
    alt.Chart(neut_standard_counts)
    .add_params(serum_selection)
    .transform_filter(serum_selection)
    .encode(
        alt.X(
            "neut_standard_count",
            title="counts from neutralization standard",
            scale=alt.Scale(nice=False, padding=3),
        ),
        alt.Y("sample_noplate", title=None, sort=neut_standard_counts_sample_order),
        alt.Color(
            "passes_qc",
            title=f"at least {min_neut_standard_count} counts",
            scale=alt.Scale(domain=[True, False]),
        ),
        tooltip=neut_standard_counts_tooltips,
    )
    .mark_bar(height={"band": 0.85})
    .properties(
        height=alt.Step(10),
        width=250,
        title=f"Neutralization-standard counts for {plate}",
    )
    .configure_axis(grid=False)
    .configure_legend(titleLimit=1000)
)

display(neut_standard_counts_chart)

if not neut_standard_counts["passes_qc"].all():
    print(f"\nSamples failing {qc_thresholds['min_neut_standard_count']=}")
    display(neut_standard_counts.query("not passes_qc"))
    qc_failures.add("min_neut_standard_count")
else:
    print(f"\nAll samples pass {qc_thresholds['min_neut_standard_count']=}")

Compute and plot the no-serum sample viral barcode counts and check if they pass the QC filters.

In [None]:
# Viral-barcode counts in the no-serum samples, joined to each sample's total
# neut-standard count; QC requires each viral barcode to reach a minimum count.
min_no_serum_viral_barcode_count = qc_thresholds["min_no_serum_viral_barcode_count"]

no_serum_counts = (
    counts
    .query("serum == 'none'")
    .query("not neut_standard")
    # merge on the shared sample columns; the per-sample `passes_qc` is not needed here
    .merge(neut_standard_counts.drop(columns="passes_qc"), validate="many_to_one")
    [["barcode", "strain", "well", "sample_noplate", "count", "neut_standard_count"]]
    .assign(passes_qc=lambda df: df["count"].ge(min_no_serum_viral_barcode_count))
)

no_serum_counts_chart = (
    alt.Chart(no_serum_counts)
    .add_params(barcode_selection)
    .encode(
        alt.X("count", title="viral barcode count", scale=alt.Scale(nice=False, padding=5)),
        alt.Y("sample_noplate", title=None),
        alt.Fill("barcode", legend=None),
        # outline and enlarge the points of the moused-over barcode
        strokeWidth=alt.condition(barcode_selection, alt.value(2), alt.value(0)),
        size=alt.condition(barcode_selection, alt.value(60), alt.value(35)),
        tooltip=no_serum_counts.columns.tolist(),
    )
    .mark_circle(fillOpacity=0.6, stroke="black", strokeOpacity=1)
    .properties(
        height=alt.Step(10),
        width=300,
        title=f"{plate} viral barcode counts in no-serum samples",
    )
    .configure_axis(grid=False)
    .configure_legend(titleLimit=1000)
    .interactive()
)

display(no_serum_counts_chart)

if not no_serum_counts["passes_qc"].all():
    print(f"\nSamples failing {qc_thresholds['min_no_serum_viral_barcode_count']=}")
    display(no_serum_counts.query("not passes_qc"))
    qc_failures.add("min_no_serum_viral_barcode_count")
else:
    print(f"\nAll samples pass {qc_thresholds['min_no_serum_viral_barcode_count']=}")

Compute and plot the median ratio of viral barcode count to neut standard counts across no-serum samples:

In [None]:
# For each viral barcode, the median across no-serum samples of
# (viral barcode count / total neut-standard count) -- the denominator of frac infectivity.
no_serum_ratios = no_serum_counts.assign(
    ratio=no_serum_counts["count"] / no_serum_counts["neut_standard_count"]
)
median_no_serum_ratio = no_serum_ratios.groupby(["barcode", "strain"], as_index=False).aggregate(
    median_no_serum_ratio=pd.NamedAgg("ratio", "median")
)

# mouseover highlights all barcodes of the same strain
strain_selection = alt.selection_point(fields=["strain"], on="mouseover", empty=False)

median_no_serum_ratio_tooltips = [
    alt.Tooltip(col, format=".3g") if median_no_serum_ratio[col].dtype == float else col
    for col in median_no_serum_ratio.columns
]

median_no_serum_ratio_chart = (
    alt.Chart(median_no_serum_ratio)
    .add_params(strain_selection)
    .encode(
        alt.X(
            "median_no_serum_ratio",
            title="median ratio of counts",
            scale=alt.Scale(nice=False, padding=5),
        ),
        alt.Y(
            "barcode",
            sort=alt.SortField("median_no_serum_ratio", order="descending"),
            axis=alt.Axis(labelFontSize=5),
        ),
        color=alt.condition(strain_selection, alt.value("orange"), alt.value("gray")),
        tooltip=median_no_serum_ratio_tooltips,
    )
    .mark_bar(height={"band": 0.85})
    .properties(
        height=alt.Step(5),
        width=250,
        title=f"{plate} no-serum median ratio viral barcode to neut-standard barcode",
    )
    .configure_axis(grid=False)
    .configure_legend(titleLimit=1000)
)

display(median_no_serum_ratio_chart)

Compute the actual fraction infectivities, QC check if any are null (from zero counts), and also plot and check if any exceed the `max_frac_infectivity`:

In [None]:
# Fraction infectivity of each viral barcode in each serum-containing sample:
# (viral barcode count / total neut-standard count in that sample), divided by the
# median of that same ratio across the no-serum samples for that barcode.
frac_infectivity = (
    counts
    .query("not neut_standard")
    .query("serum != 'none'")
    # adds `median_no_serum_ratio`, merged on the shared `barcode` and `strain` columns
    .merge(median_no_serum_ratio, validate="many_to_one")
    # adds the sample's total `neut_standard_count`, merged on the shared sample columns
    .merge(
        neut_standard_counts.drop(columns="passes_qc"),
        validate="many_to_one",
    )
    .assign(
        frac_infectivity=lambda x: (
            (x["count"] / x["neut_standard_count"]) / x["median_no_serum_ratio"]
        ),
        # NOTE: null fraction infectivities compare False here, so they are also flagged
        # as failing `max_frac_infectivity` (in addition to the separate null check below)
        passes_qc=lambda x: x["frac_infectivity"] <= qc_thresholds["max_frac_infectivity"],
    )
    [
        [
            "barcode",
            "strain",
            "serum",
            "serum_replicate",
            "plate_replicate",
            "dilution_factor",
            "frac_infectivity",
            "sample_noplate",
            "well",
            "passes_qc",
        ]
    ]
)

# one row per barcode / serum / plate replicate / dilution, and every dilution is defined
assert (
    len(frac_infectivity.groupby(["barcode", "serum", "plate_replicate", "dilution_factor"]))
    == len(frac_infectivity)
)
assert frac_infectivity["dilution_factor"].notnull().all()

frac_infectivity_chart = (
    alt.Chart(frac_infectivity)
    .add_params(serum_selection, barcode_selection)
    .transform_filter(serum_selection)
    .encode(
        alt.X(
            "frac_infectivity",
            title="fraction infectivity",
            scale=alt.Scale(nice=False, padding=3),
        ),
        alt.Y(
            "sample_noplate",
            title=None,
            sort=list(
                neut_standard_counts.sort_values(["serum", "dilution_factor"])["sample_noplate"]
            ),
        ),
        strokeWidth=alt.condition(barcode_selection, alt.value(2), alt.value(0)),
        size=alt.condition(barcode_selection, alt.value(60), alt.value(35)),
        color=alt.Color(
            "passes_qc",
            title=f"frac_infectivity <= {qc_thresholds['max_frac_infectivity']}",
            scale=alt.Scale(domain=[True, False]),
        ),
        tooltip=[
            alt.Tooltip(c, format=".3g") if frac_infectivity[c].dtype == float
            else c
            for c in frac_infectivity.columns
        ],
    )
    .mark_circle(stroke="black", strokeOpacity=1)
    .properties(
        height=alt.Step(10),
        width=250,
        title=f"Fraction infectivities for {plate}",
    )
    .configure_axis(grid=False)
    .configure_legend(titleLimit=1000)
)

display(frac_infectivity_chart)

if not frac_infectivity["passes_qc"].all():
    print(f"\nSome barcode-samples fail {qc_thresholds['max_frac_infectivity']=}")
    display(frac_infectivity.query("not passes_qc"))
    qc_failures.add("max_frac_infectivity")
else:
    # f-string so the threshold value is actually interpolated into the message
    print(f"\nAll barcode-samples pass {qc_thresholds['max_frac_infectivity']=}")

# null fraction infectivities arise from zero counts (e.g. 0/0)
is_null_frac_infectivity = frac_infectivity["frac_infectivity"].isnull()
if is_null_frac_infectivity.any():
    print("\nSome barcodes have undefined fraction infectivity due to zero counts:")
    display(frac_infectivity[is_null_frac_infectivity])
    qc_failures.add("null_frac_infectivity")
else:
    print("\nNo undefined fraction infectivities")

frac_infectivity

Write fraction infectivities to file:

In [None]:
# Write the per-barcode fraction infectivities (without QC/plotting columns) to the CSV.
print(f"\nWriting fraction infectivities to {frac_infectivity_csv}")
frac_infectivity_output_cols = [
    "barcode", "strain", "serum", "plate_replicate", "dilution_factor", "frac_infectivity",
]
frac_infectivity_sort_cols = ["serum", "plate_replicate", "dilution_factor", "barcode"]
frac_infectivity[frac_infectivity_output_cols].sort_values(frac_infectivity_sort_cols).to_csv(
    frac_infectivity_csv, index=False, float_format="%.5g"
)

Make sure we have enough dilutions with non-null fraction infectivities for each serum-replicate:

In [None]:
# Number of distinct dilutions with a defined (non-null) fraction infectivity for each
# serum replicate. Dilutions whose fraction infectivity is null are masked to NaN, and
# `nunique` ignores NaN, so a serum replicate whose values are ALL null is still listed
# (with 0 dilutions) and fails QC rather than silently disappearing.
n_dilutions = (
    frac_infectivity
    .assign(
        dilution_factor=lambda x: x["dilution_factor"].where(
            x["frac_infectivity"].notnull()
        ),
    )
    .groupby("serum_replicate")
    .aggregate(n_dilutions=pd.NamedAgg("dilution_factor", "nunique"))
    .assign(
        # the threshold is a minimum: having exactly that many dilutions passes
        fails_qc=lambda x: (
            x["n_dilutions"] < qc_thresholds["min_dilutions_per_serum_replicate"]
        ),
    )
)

if n_dilutions["fails_qc"].any():
    print(f"Failing {qc_thresholds['min_dilutions_per_serum_replicate']=}:")
    display(n_dilutions.query("fails_qc"))
    qc_failures.add("min_dilutions_per_serum_replicate")
else:
    print(f"Passed {qc_thresholds['min_dilutions_per_serum_replicate']=}:")

Summarize all QC failures and write to file:

In [None]:
# Summarize all QC failures (one check name per line, sorted) and write them to file.
# A new name holds the joined string so that `qc_failures` stays a set: rebinding it
# to a string would make a re-run of this cell sort and join individual characters.
qc_failures_str = "\n".join(sorted(qc_failures))

if qc_failures_str:
    print(f"Encountered the following QC failures:\n{qc_failures_str}")
else:
    print("No QC failures")

# an empty file means no QC failures
print(f"\nLogging QC failures to {qc_failures_file}")
with open(qc_failures_file, "w") as f:
    f.write(qc_failures_str)