# Compare strain growth versus titers

In [None]:
# Unpack everything this notebook needs from the `snakemake` object.

# run description assembled from the rule wildcards
desc = (
    f"{snakemake.wildcards.protset}"
    f"_{snakemake.wildcards.mlrfit}"
    f"_{snakemake.wildcards.sera}"
)
print(desc)

# --- input files ---
growth_csv = snakemake.input.growth
titers_csv = snakemake.input.titers
muts_from_mrca_csv = snakemake.input.muts_from_mrca

# --- parameters ---
simple_plot_scale = snakemake.params.simple_plot_scale
corr_titer_cutoff_range = snakemake.params.corr_titer_cutoff_range
corr_titer_cutoff_points = snakemake.params.corr_titer_cutoff_points
nrandom = snakemake.params.nrandom
sera_regex = snakemake.params.sera_regex
pool = snakemake.params.pool

# --- output files ---
chart_html = snakemake.output.chart
simple_corr_chart_html = snakemake.output.simple_corr_chart
simple_cutoff_chart_html = snakemake.output.simple_cutoff_chart
corr_csv = snakemake.output.corr_csv
scatter_csv = snakemake.output.scatter_csv

In [None]:
import collections
import math
import re

import altair as alt

import numpy

import pandas as pd

import scipy
import scipy.stats

_ = alt.data_transformers.disable_max_rows()

Read growth, then get titers just for those strains with growth estimates:

In [None]:
# Read the growth estimates, dropping rows whose `strain` entry is one of the
# two labels below rather than the name of an individual strain.
growth = pd.read_csv(growth_csv)
growth = growth[
    ~growth["strain"].isin(
        {"library strains insufficient counts", "strain not in library"}
    )
]

# Read the titers, using "strain" as the name of the virus column.
titers = (
    pd.read_csv(titers_csv)
    .rename(columns={"virus": "strain"})
    [["strain", "serum", "titer"]]
)
# every strain with a growth estimate must have titers ...
missing_titers = set(growth["strain"]) - set(titers["strain"])
assert not missing_titers, missing_titers
# ... and there must be exactly one titer per strain / serum pair
assert len(titers) == len(titers.groupby(["strain", "serum"]))

print(
    f"Read titers for {titers['strain'].nunique()} strains, of which "
    f"{growth['strain'].nunique()} have growth advantage estimates."
)

# Pool titers are taken from the full titer table, *before* any filtering of
# the sera by `sera_regex` below.
if pool:
    print(f"Parsing pool titers for {pool=}")
    pool_titers = titers.query("serum == @pool")
    if pool_titers["serum"].nunique() == 1:
        print(f"Parsed pool titers for {pool=}")
    else:
        raise ValueError(f"Failed to parse pool titers for {pool=}")
    assert len(pool_titers) == len(pool_titers.groupby(["strain"]))
    assert set(growth["strain"]).issubset(pool_titers["strain"])
else:
    # Normalize any falsy value (e.g. an empty string) to None so that later
    # `pool is not None` checks agree with the `if pool:` test used here.
    pool = None
    print("Not getting any pool titers")

if sera_regex is not None:
    print(f"Parsing {titers['serum'].nunique()} sera for those matching {sera_regex=}")
    titers = titers[titers["serum"].str.contains(sera_regex, regex=True)]
    # fail here with a clear message rather than later inside the correlations
    if titers.empty:
        raise ValueError(f"No sera match {sera_regex=}")

nsera = titers["serum"].nunique()
print(f"Analyzing titers for {nsera} sera")

# keep titers only for strains with growth estimates
titers = titers[titers["strain"].isin(set(growth["strain"]))]
# filtering the sera can remove all titers for a strain; downstream merges
# require titers for every strain with a growth estimate
strains_without_titers = set(growth["strain"]) - set(titers["strain"])
assert not strains_without_titers, strains_without_titers

if muts_from_mrca_csv:
    print(f"Parsing mutations from MRCA from {muts_from_mrca_csv}")
    muts_from_mrca = (
        pd.read_csv(muts_from_mrca_csv)
        [["strain", "nucleotide_mutations", "protein_mutations", "HA1_protein_mutations"]]
    )
    missing_muts = set(growth["strain"]) - set(muts_from_mrca["strain"])
    assert not missing_muts, missing_muts
else:
    muts_from_mrca = None
    print("Not getting mutations from MRCA")

Get the correlation of growth with the **geometric** mean titer, median titer, and the fraction of sera below whatever cutoff gives the best correlation.
Do this for the actual data, and randomizations of the data (randomizing the growth rates among strains) to get P-values.
Some notes:
 - Correlations with mean and median titers are with the **log** of these titers, whereas the correlation with the fraction of sera below the cutoff uses the fraction itself (no transformation).
 - The P-values are one-sided, testing if mean and median titers are negatively correlated with growth, and if the fraction of sera below the cutoff is positively correlated with growth.

In [None]:
def get_corrs_by_cutoff(
    growth_df,
    titers_df,
    cutoff_range=corr_titer_cutoff_range,
    cutoff_points=corr_titer_cutoff_points,
):
    """Correlate growth with the fraction of titers at or below each cutoff.

    Cutoffs are `cutoff_points` values spaced logarithmically between
    `cutoff_range[0]` and `cutoff_range[1]`. For each cutoff, the fraction of
    each strain's titers that are `<=` the cutoff is computed from `titers_df`
    (columns "strain" and "titer") and merged with the
    "growth_advantage_median" column of `growth_df`. Cutoffs for which that
    fraction is the same for all strains are skipped.

    Returns a data frame with columns "titer_cutoff" and "correlation"
    (Pearson R), with one row per cutoff that was not skipped.
    """
    lower, upper = cutoff_range[0], cutoff_range[1]
    assert lower < upper
    cutoffs = numpy.logspace(numpy.log10(lower), numpy.log10(upper), cutoff_points)

    strain_growth = growth_df[["strain", "growth_advantage_median"]]

    records = []
    for cutoff in cutoffs:

        def frac_at_or_below(strain_titers):
            return (strain_titers <= cutoff).sum() / len(strain_titers)

        per_strain = (
            titers_df
            .groupby("strain", as_index=False)
            .aggregate(frac=pd.NamedAgg("titer", frac_at_or_below))
            .merge(strain_growth)
        )
        # a correlation is only defined if the fractions are not all identical
        if per_strain["frac"].nunique() != 1:
            pearson_r = scipy.stats.pearsonr(
                per_strain["growth_advantage_median"], per_strain["frac"]
            )[0]
            records.append({"titer_cutoff": cutoff, "correlation": pearson_r})

    return pd.DataFrame(records, columns=["titer_cutoff", "correlation"])

# ---------------------------------------------------------------------------
# Correlations for the actual data
# ---------------------------------------------------------------------------
corr_cutoff_actual = get_corrs_by_cutoff(growth, titers)
if corr_cutoff_actual.empty:
    raise ValueError(
        "No titer cutoff gave a computable correlation; check the titer cutoff range"
    )
# the "best" cutoff is the one with the highest (most positive) correlation
best_row = corr_cutoff_actual.sort_values("correlation", ascending=False).iloc[0]
best_cutoff = float(best_row["titer_cutoff"])
best_cutoff_corr = float(best_row["correlation"])
corr_cutoffs = [corr_cutoff_actual.assign(dataset="actual")]

# per-strain summary statistics of the titers, merged with the growth estimates
growth_actual_scatter = growth.merge(
    titers.groupby("strain", as_index=False).aggregate(
        mean_titer=pd.NamedAgg("titer", lambda s: scipy.stats.gmean(s)),
        median_titer=pd.NamedAgg("titer", "median"),
        frac_below_titer=pd.NamedAgg("titer", lambda s: (s <= best_cutoff).sum() / len(s)),
    ),
    on="strain",
    validate="one_to_one",
)
# `pool_titers` only exists when `pool` is truthy, so test truthiness here too
# (an empty-string pool is not None but has no pool titers)
if pool:
    growth_actual_scatter = growth_actual_scatter.merge(
        pool_titers.rename(columns={"titer": "pool_titer"})[["strain", "pool_titer"]],
        on="strain",
        validate="one_to_one",
    )
if muts_from_mrca is not None:
    growth_actual_scatter = growth_actual_scatter.merge(
        muts_from_mrca, on="strain", validate="one_to_one",
    )

print(f"Writing data for scatter plot to {scatter_csv}")
growth_actual_scatter.to_csv(scatter_csv, index=False, float_format="%.4g")

# ---------------------------------------------------------------------------
# Statistics to correlate with growth, and the scale used for each
# ---------------------------------------------------------------------------
titer_cols = [
    c for c in growth_actual_scatter.columns if c.endswith("_titer") or c.endswith("_mutations")
]
if "frac_below_titer" in titer_cols:
    titer_cols = ["frac_below_titer"] + [c for c in titer_cols if c != "frac_below_titer"]

# The fraction below the cutoff and the mutation counts are correlated as-is;
# all other (titer) statistics are correlated on a log10 scale.
linear_stats = {
    c for c in titer_cols if c == "frac_below_titer" or c.endswith("_mutations")
}
# Values entering the correlation for each statistic. The *same* values are
# used for the actual data and for the randomizations so that the null
# distribution is on the same scale as the observed correlation (log10 of a
# mutation count of zero would also be -inf).
stat_values = {
    c: (
        growth_actual_scatter[c].values
        if c in linear_stats
        else numpy.log10(growth_actual_scatter[c].values)
    )
    for c in titer_cols
}

# ---------------------------------------------------------------------------
# Correlations for randomizations (growth permuted among strains)
# ---------------------------------------------------------------------------
# Permuting `growth` is only valid if its rows are in the same strain order as
# `growth_actual_scatter`, whose columns are correlated against it below.
assert growth_actual_scatter["strain"].tolist() == growth["strain"].tolist(), (
    "strains in `growth_actual_scatter` and `growth` differ"
)

randomized_corrs = {c: [] for c in titer_cols}
for irandom in range(nrandom):
    numpy.random.seed(irandom)  # each randomization is reproducible
    growth_random = growth.assign(
        growth_advantage_median=lambda x: numpy.random.permutation(
            x["growth_advantage_median"].values
        )
    )
    corr_cutoffs.append(
        get_corrs_by_cutoff(growth_random, titers).assign(dataset=f"random_{irandom}")
    )
    for col in titer_cols:
        if col == "frac_below_titer":
            continue  # handled below using the best cutoff for this randomization
        randomized_corrs[col].append(
            scipy.stats.pearsonr(
                growth_random["growth_advantage_median"].values,
                stat_values[col],
            )[0]
        )
    # For the actual data the best cutoff is chosen after scanning all cutoffs,
    # so each randomization likewise uses the maximum over its own cutoffs.
    randomized_corrs["frac_below_titer"].append(corr_cutoffs[-1]["correlation"].max())

corr_cutoffs = pd.concat(corr_cutoffs, ignore_index=True)

# ---------------------------------------------------------------------------
# Actual correlations and one-sided P-values
# ---------------------------------------------------------------------------
corrs_and_pvals = []
for var in titer_cols:
    assert len(randomized_corrs[var]) == nrandom
    corr = scipy.stats.pearsonr(
        growth_actual_scatter["growth_advantage_median"].values,
        stat_values[var],
    )[0]
    if var in linear_stats:
        assert (var != "frac_below_titer") or numpy.allclose(corr, best_cutoff_corr), f"{corr=}, {best_cutoff_corr=}"
        # P-value: fraction of randomizations at least as positively correlated
        p = sum(r >= corr for r in randomized_corrs[var]) / nrandom
    else:
        # P-value: fraction of randomizations at least as negatively correlated
        p = sum(r <= corr for r in randomized_corrs[var]) / nrandom
    corrs_and_pvals.append((var, corr, p))

corrs_and_pvals = pd.DataFrame(corrs_and_pvals, columns=["stat", "correlation", "P"]).assign(
    # a P-value of zero is reported as an upper bound set by the number of randomizations
    p_str=lambda x: x["P"].map(
        lambda p: f"P = {p:.2g}" if p else f"P < {1 / nrandom:.2g}"
    ),
    # two-line text label used to annotate the scatter plots
    label=lambda x: x.apply(
        lambda r: [f"R = {r['correlation']:.2f}", r["p_str"]],
        axis=1,
    ),
)

display(corrs_and_pvals)

print(f"Writing correlations and P-values to {corr_csv}")
corrs_and_pvals.drop(columns="label").to_csv(corr_csv, index=False, float_format="%.4g")

Plot correlation of growth with fraction of titers below a cutoff for the actual data and the randomizations:

In [None]:
# Scales at which charts are drawn: 1 for the standard charts, and
# `simple_plot_scale` for the presentation-size ones. Duplicates are dropped
# (keeping order) so that if `simple_plot_scale` is 1 each chart is built only
# once; otherwise later loops over `plotscales` would append every scatter
# panel twice under the same key.
plotscales = list(dict.fromkeys([1, simple_plot_scale]))

# radio button that toggles whether the randomizations are shown
show_random = alt.param(
    value=True,
    name="show_random",
    bind=alt.binding_radio(
        options=[True, False],
        name="show randomizations for titer cutoff",
    ),
)

# one line chart of correlation versus titer cutoff per plot scale
corr_cutoffs_charts = {}
for plotscale in plotscales:
    corr_cutoffs_charts[plotscale] = (
        alt.Chart(corr_cutoffs)
        .add_params(show_random)
        # flag the actual data so it can be styled differently and always shown
        .transform_calculate(actual=alt.datum["dataset"] == "actual")
        .transform_filter(f"{show_random.name} | datum.actual")
        .encode(
            alt.X(
                "titer_cutoff",
                scale=alt.Scale(zero=False, nice=False, type="log", padding=5 * plotscale),
                title="neutralization titer cutoff",
                axis=alt.Axis(grid=False, labelOverlap=True, titleFontSize=12 * plotscale, labelFontSize=10 * plotscale),
            ),
            alt.Y("correlation", axis=alt.Axis(grid=False, titleFontSize=12 * plotscale, labelFontSize=10 * plotscale, tickCount=5)),
            alt.Color(
                "actual:N",
                scale=alt.Scale(domain=[True, False], range=["black", "slateblue"]),
                legend=None,
            ),
            alt.Opacity(
                "actual:N",
                scale=alt.Scale(domain=[True, False], range=[1, 0.15]),
            ),
            alt.StrokeWidth(
                "actual:N",
                scale=alt.Scale(domain=[True, False], range=[2 * plotscale, 1 * plotscale]),
            ),
            # one line per dataset (the actual data and each randomization)
            alt.Detail("dataset"),
            tooltip=[
                "dataset",
                alt.Tooltip("titer_cutoff", format=".1f"),
                alt.Tooltip("correlation", format=".2f"),
            ],
        )
        .mark_line()
        .properties(
            width=250 * plotscale,
            height=170 * plotscale,
            title=alt.TitleParams(
                [
                    "correlation of strain growth with",
                    "fraction sera with titers below cutoff",
                ],
                subtitle=[
                    "thick black line is actual data;",
                    f"thin blue lines are {nrandom} randomizations",
                ],
                fontSize=13 * plotscale,
                subtitleFontSize=12 * plotscale,
            ),
        )
    )

corr_cutoffs_charts[1]

Plot the correlation scatter plots:

In [None]:
# Base chart holding the per-strain data, with a preformatted label of the
# growth advantage and its HPD interval for use in tooltips.
corr_base = alt.Chart(
    growth_actual_scatter.assign(
        ga_label=lambda x: x.apply(
            lambda r: f"{r['growth_advantage_median']:.2f} [{r['growth_advantage_hpd_min']:.2f} - {r['growth_advantage_hpd_max']:.2f}]",
            axis=1,
        )
    )
)

# selection of a strain by mousing over one of its points
strain_sel = alt.selection_point(fields=["strain"], on="mouseover", empty=False)

# x-axis titles for the statistics; built once rather than on every iteration
# NOTE(review): the fraction is computed with `titer <= cutoff`, while this
# label displays `<`.
stat_names = {
    "mean_titer": "mean titer",
    "median_titer": "median titer",
    "frac_below_titer": f"fraction sera with titer < {best_cutoff:.0f}",
    "pool_titer": "titer for sera pool",
    "nucleotide_mutations": "HA nucleotide mutations",
    "protein_mutations": "HA protein mutations",
    "HA1_protein_mutations": "HA1 protein mutations",
}

# scatter panels for each plot scale, in the same order as `titer_cols`
corr_charts = collections.defaultdict(list)
ncols = 4
width = height = 150
# `dict.fromkeys` drops repeated scales (keeping order) so that panels are never
# appended twice under the same key
for plotscale in dict.fromkeys(plotscales):
    for istat, stat in enumerate(titer_cols):
        # fall back to the column name for a statistic without a curated title
        stat_name = stat_names.get(stat, stat.replace("_", " "))
        # the fraction below the cutoff and mutation counts get a linear x-axis
        # and a top-left annotation; titers get a log x-axis and a top-right one
        linear_stat = stat == "frac_below_titer" or stat.endswith("mutations")
        # only the first panel of each row gets y-axis title and labels
        first_in_row = istat % ncols == 0

        corr_scatter = (
            corr_base
            .encode(
                alt.X(
                    stat,
                    title=stat_name,
                    scale=alt.Scale(
                        zero=False,
                        nice=False,
                        padding=8 * plotscale,
                        type="linear" if linear_stat else "log",
                    ),
                    axis=alt.Axis(grid=False, labelOverlap=True, titleFontSize=11 * plotscale, labelFontSize=10 * plotscale, tickCount=5),
                ),
                alt.Y(
                    "growth_advantage_median",
                    title="growth advantage" if first_in_row else None,
                    scale=alt.Scale(zero=False, nice=False, padding=8 * plotscale),
                    axis=alt.Axis(grid=False, labels=first_in_row, titleFontSize=11 * plotscale, labelFontSize=10 * plotscale, tickCount=4),
                ),
                stroke=alt.condition(strain_sel, alt.value("red"), alt.value("black")),
                strokeWidth=alt.condition(strain_sel, alt.value(2 * plotscale), alt.value(1 * plotscale)),
                tooltip=[
                    "strain",
                    alt.Tooltip("ga_label", title="growth advantage"),
                    alt.Tooltip(stat, title=stat_name, format=".2f"),
                ],
            )
            .mark_circle(size=45 * plotscale**2, color="black", strokeOpacity=1, fillOpacity=0.5)
            .properties(width=width * plotscale, height=height * plotscale)
        )
        # vertical rule spanning the HPD interval of the growth advantage
        corr_errorbar = (
            corr_base
            .encode(
                alt.X(stat),
                alt.Y("growth_advantage_hpd_min"),
                alt.Y2("growth_advantage_hpd_max"),
            )
            .mark_rule(color="black")
        )
        # correlation and P-value text placed in a top corner of the panel
        corr_text = (
            alt.Chart(corrs_and_pvals[corrs_and_pvals["stat"] == stat])
            .encode(
                x=(
                    alt.value(2 * plotscale)
                    if linear_stat
                    else alt.value((width - 2) * plotscale)
                ),
                y=alt.value(2 * plotscale),
                text="label",
            )
            .mark_text(
                baseline="top",
                align="left" if linear_stat else "right",
                color="blue",
                size=11 * plotscale,
            )
        )

        corr_charts[plotscale].append(corr_scatter + corr_errorbar + corr_text)

# arrange the standard-scale panels in a grid with `ncols` columns
nrows = int(math.ceil(len(corr_charts[1]) / ncols))
chart_rows = []
for irow in range(nrows):
    chart_rows.append(alt.hconcat(*corr_charts[1][irow * ncols: irow * ncols + ncols], spacing=7))
corr_chart = alt.vconcat(*chart_rows, spacing=16).add_params(strain_sel).properties(
    title=alt.TitleParams(
        f"strain growth versus neutralization by {nsera} sera",
        anchor="middle",
    ),
)

corr_chart

Make and save the overall chart:

In [None]:
# Stack the scatter panels above the standard-scale cutoff chart, under a
# title naming this run.
chart_title = alt.TitleParams(
    f"strain growth vs neutralization for {desc}",
    anchor="middle",
    fontSize=15,
    dy=-20,
)
chart = alt.vconcat(
    corr_chart,
    corr_cutoffs_charts[1],
    spacing=25,
).properties(title=chart_title)

# write the chart to the HTML output file
print(f"Saving to {chart_html}")
chart.save(chart_html)

chart

Make bigger charts for presentation size:

In [None]:
# the last plot scale is the one used for the presentation-size charts
biggerscale = plotscales[-1]
bigger_cutoff_chart = corr_cutoffs_charts[biggerscale]

print(f"Saving to {simple_cutoff_chart_html}")
bigger_cutoff_chart.save(simple_cutoff_chart_html)

bigger_cutoff_chart

In [None]:
# Presentation-size scatter panels; these are in the same order as `titer_cols`.
bigger_panels = corr_charts[biggerscale]

frac_below_chart = bigger_panels[titer_cols.index("frac_below_titer")]
simple_panels = [frac_below_chart]

# The pool panel only exists when pool titers were analyzed; without a pool the
# chart shows just the fraction-below-cutoff panel instead of raising an error.
if "pool_titer" in titer_cols:
    pool_chart = bigger_panels[titer_cols.index("pool_titer")]
    simple_panels.append(pool_chart)
else:
    pool_chart = None

# The selection parameter is added once, to the concatenated chart.
# (`add_params` returns a new chart, so calling it on a panel and discarding
# the result has no effect.)
bigger_corr_chart = (
    alt.hconcat(*simple_panels, spacing=8 * biggerscale)
    .add_params(strain_sel)
)

print(f"Saving to {simple_corr_chart_html}")
bigger_corr_chart.save(simple_corr_chart_html)

bigger_corr_chart

In [None]:
# show the path to which the overall chart was saved
chart_html