# Estimate strain growth advantages using MLR

Import modules including [evofr](https://blab.github.io/evofr/) and get variables from `snakemake`:

In [None]:
import datetime
import math
import re

import altair as alt

import evofr
import evofr.plotting

import matplotlib.pyplot as plt

import numpy

import pandas as pd

# turn off Altair's limit on the number of data rows allowed in a chart
_ = alt.data_transformers.disable_max_rows()

In [None]:
# unpack the settings that `snakemake` supplies for this run

_wildcards = snakemake.wildcards
_inputs = snakemake.input
_outputs = snakemake.output
_params = snakemake.params

# label for this fit (printed here and used in the merged chart title)
desc = f"{_wildcards.protset}_{_wildcards.mlrfit}"
print(desc)

# input file
counts_by_date_csv = _inputs.counts_by_date

# output files
chart_html = _outputs.chart
counts_to_fit_csv = _outputs.counts_to_fit
growth_advantages_csv = _outputs.growth_advantages

# date range: must be supplied as dates, and is then converted to pandas timestamps
date_start, date_end = _params.date_start, _params.date_end
assert isinstance(date_start, datetime.date) and isinstance(date_end, datetime.date)
date_start, date_end = pd.Timestamp(date_start), pd.Timestamp(date_end)
assert date_start < date_end

# threshold on total sequences for a strain to be kept as its own variant
min_counts = _params.min_counts

# whether the two pooled categories are retained
keep_not_in_library = _params.keep_not_in_library
keep_insufficient_counts = _params.keep_insufficient_counts

# plotting and model-fitting settings
plot_window_frame_days = _params.plot_window_frame_days
pivot_strain = _params.pivot_strain
mlr_tau = _params.mlr_tau
num_warmup = _params.num_warmup
num_samples = _params.num_samples
hpd_interval = _params.hpd_interval

## Counts for each strain by date

Read data.
We filter for counts within the specified date range, and then filter for strains that have sufficient counts in that range.
At the end of this, strains fall in one of three categories:
 - a named library strain (e.g., *A/Bhutan/0845/2023*)
 - *strain not in library*: does not match a strain in the library
 - *library strains insufficient counts*: strains in the library with insufficient counts

In [None]:
# labels for the two pooled categories of sequences that are not individual library strains
NOT_IN_LIBRARY = "strain not in library"
INSUFFICIENT_COUNTS = "library strains insufficient counts"

# read the per-date counts, ordered by date and then variant
all_counts_by_date = (
    pd.read_csv(counts_by_date_csv, parse_dates=["date"])
    .sort_values(["date", "variant"])
)

# the requested date range must lie within the dates present in the data
if date_start < all_counts_by_date["date"].min():
    raise ValueError(f"{date_start=} before {all_counts_by_date['date'].min()=}")
if date_end > all_counts_by_date["date"].max():
    raise ValueError(f"{date_end=} after {all_counts_by_date['date'].max()=}")

print(f"Trimming counts by date to the range {date_start=} to {date_end=}")
# take an explicit copy so the relabeling below writes to an independent data frame
# rather than to a slice of the untrimmed one
all_counts_by_date = all_counts_by_date[
    (all_counts_by_date["date"] >= date_start)
    & (all_counts_by_date["date"] <= date_end)
].copy()

# the input labels sequences not matching a library strain as "other"; rename that
# to the more descriptive label, first checking the new label is not already in use
assert "other" in set(all_counts_by_date["variant"])
assert NOT_IN_LIBRARY not in set(all_counts_by_date["variant"])
all_counts_by_date["variant"] = all_counts_by_date["variant"].replace(
    "other", NOT_IN_LIBRARY
)

Get total counts for each variant:

In [None]:
# total sequences per variant over the trimmed date range
total_counts = (
    all_counts_by_date.groupby("variant")["sequences"]
    .sum()
    .rename("total_sequences")
    .reset_index()
)
# flag variants that reach the minimum-count threshold
total_counts["sufficient_counts"] = total_counts["total_sequences"] >= min_counts

# encodings for the bar chart: variants on x ordered from most to fewest sequences,
# totals on a symlog y-axis, and fill indicating whether the threshold is met
_variant_x = alt.X(
    "variant",
    sort=alt.SortField("total_sequences", order="descending"),
    title=None,
)
_sequences_y = alt.Y(
    "total_sequences",
    scale=alt.Scale(type="symlog", constant=50),
    title="total sequences",
    axis=alt.Axis(grid=False),
)
_sufficient_fill = alt.Fill(
    "sufficient_counts",
    scale=alt.Scale(range=["gray", "white"], domain=[True, False]),
    title="sufficient counts?",
    legend=alt.Legend(orient="top-right", offset=3)
)

total_counts_chart = (
    alt.Chart(total_counts)
    .mark_bar(stroke="black")
    .encode(
        _variant_x,
        _sequences_y,
        _sufficient_fill,
        tooltip=list(total_counts.columns),
    )
    .properties(
        height=150,
        width=alt.Step(11),
        title=f"total sequences per strain from {date_start.date()} to {date_end.date()}",
    )
)

total_counts_chart

Now filter the counts by date to group all library strains with insufficient counts.
Also pad any missing dates in the range:

In [None]:
# library strains below the count threshold; the not-in-library category is never
# pooled with them even if its own total is low
_low_count_mask = ~total_counts["sufficient_counts"] & (
    total_counts["variant"] != NOT_IN_LIBRARY
)
strains_w_insufficient_counts = set(total_counts.loc[_low_count_mask, "variant"])

print(f"Grouping {len(strains_w_insufficient_counts)=} to '{INSUFFICIENT_COUNTS}'")

assert INSUFFICIENT_COUNTS not in set(all_counts_by_date["variant"])

# relabel the low-count strains with the pooled label and express each date as an
# integer number of days since `date_start`
_relabeled = all_counts_by_date.assign(
    variant=all_counts_by_date["variant"].where(
        ~all_counts_by_date["variant"].isin(strains_w_insufficient_counts),
        INSUFFICIENT_COUNTS,
    ),
    day=(all_counts_by_date["date"] - date_start).dt.days,
)

# sum the sequences for each variant and day (this merges the pooled strains)
filtered_counts_by_date = (
    _relabeled.groupby(["variant", "day"])["sequences"]
    .sum()
    .reset_index()
    .sort_values(["day", "variant"])
)

days = filtered_counts_by_date["day"].unique()
assert (days == days.astype(int)).all(), "dates not all rounded to day"

print(f"Padding with zero counts any missing days between {date_start} and {date_end}")

# every combination of variant and day from the first to the last observed day
_variant_day_grid = pd.DataFrame(
    [
        (variant, day)
        for variant in filtered_counts_by_date["variant"].unique()
        for day in range(days.min(), days.max() + 1)
    ],
    columns=["variant", "day"],
)

# the outer merge adds rows for variant/day pairs with no observations; give those
# zero sequences, then convert the day number back to a date
_padded = filtered_counts_by_date.merge(
    _variant_day_grid, how="outer", validate="one_to_one"
)
_padded["sequences"] = _padded["sequences"].fillna(0)
_padded["date"] = _padded["day"].map(lambda d: date_start + pd.Timedelta(days=d))
filtered_counts_by_date = _padded.drop(columns="day")

Plot the number of sequences in each set of strains as a function of date:

In [None]:
# collapse all individually named library strains into one "library strains" set,
# keeping the two pooled categories as their own sets, and sum sequences per set and date
grouped_counts_by_date = (
    filtered_counts_by_date
    .assign(
        set_of_strains=lambda x: x["variant"].map(
            lambda v: (
                "library strains"
                if v not in {INSUFFICIENT_COUNTS, NOT_IN_LIBRARY}
                else v
            ),
        )
    )
    .groupby(["set_of_strains", "date"], as_index=False)
    .aggregate({"sequences": "sum"})
)

# stacked area chart with one panel for counts and one for fractions
grouped_counts_by_date_chart = (
    alt.Chart(grouped_counts_by_date)
    # rolling mean of the sequences within each set; the frame is counted in data
    # rows on either side of the current one
    .transform_window(
        count="mean(sequences)",
        groupby=["set_of_strains"],
        frame=[-plot_window_frame_days, plot_window_frame_days],
    )
    # total of the smoothed counts across sets on each date, then each set's fraction
    .transform_joinaggregate(total_count="sum(count)", groupby=["date"])
    .transform_calculate(fraction=alt.datum.count / alt.datum.total_count)
    # fold the two statistics into long form so each can be drawn in its own column
    .transform_fold(
        fold=["count", "fraction"],
        as_=["statistic", "count_or_fraction"],
    )
    .encode(
        alt.X("date", title=None, axis=alt.Axis(grid=False, format="%b-%Y", labelAngle=-90)),
        alt.Y(
            "count_or_fraction:Q",
            axis=alt.Axis(grid=False),
            title=None,
            scale=alt.Scale(nice=False)
        ),
        alt.Fill(
            "set_of_strains",
            title="set of strains",
            legend=alt.Legend(orient="top", labelLimit=500, titleOrient="left"),
        ),
        alt.Column(
            "statistic:N",
            title=None,
            header=alt.Header(orient="left", labelFontStyle="bold", labelFontSize=11)
        ),
        tooltip=[
            "set_of_strains",
            "date",
            "statistic:N",
            alt.Tooltip("count_or_fraction:Q", format=".2f"),
        ],
    )
    .mark_area()
    .properties(
        width=350,
        height=160,
        title=alt.TitleParams(
            (
                "count or fraction of sequences in each set of strains "
                f"(rolling mean +/- {plot_window_frame_days} days)"
            ),
            anchor="middle",
        )
    )
    # counts and fractions are on different scales, so give each panel its own y-axis
    .resolve_scale(y="independent")
)

grouped_counts_by_date_chart

Now make per-strain plots:

In [None]:
# radio buttons choosing which of the two folded statistics is displayed;
# the fraction is shown initially
statistic_selection = alt.selection_point(
    fields=["statistic"],
    bind=alt.binding_radio(
        options=["count", "fraction"],
        name="show count or fraction on y-axis?",
    ),
    value="fraction",
)

# radio buttons toggling whether each pooled category is shown in this chart; the
# initial values are the `keep_*` settings
include_not_in_library = alt.param(
    bind=alt.binding_radio(
        options=[True, False],
        name=f"include {NOT_IN_LIBRARY}?",
    ),
    value=keep_not_in_library,
)

include_insufficient_counts = alt.param(
    bind=alt.binding_radio(
        options=[True, False],
        name=f"include {INSUFFICIENT_COUNTS}?",
    ),
    value=keep_insufficient_counts,
) 

# one small area plot per variant
counts_by_date_chart = (
    alt.Chart(filtered_counts_by_date)
    .add_params(statistic_selection, include_not_in_library, include_insufficient_counts)
    # drop the pooled categories that are toggled off; this happens before the totals
    # are computed, so fractions are relative to the variants that remain
    .transform_filter((alt.datum["variant"] != NOT_IN_LIBRARY) | include_not_in_library)
    .transform_filter((alt.datum["variant"] != INSUFFICIENT_COUNTS) | include_insufficient_counts)
    # rolling mean of the sequences for each variant; the frame is counted in data
    # rows on either side of the current one
    .transform_window(
        count="mean(sequences)",
        groupby=["variant"],
        frame=[-plot_window_frame_days, plot_window_frame_days],
    )
    # total of the smoothed counts across variants on each date, then each variant's fraction
    .transform_joinaggregate(total_count="sum(count)", groupby=["date"])
    .transform_calculate(fraction=alt.datum.count / alt.datum.total_count)
    # fold both statistics into long form, then keep only the selected one
    .transform_fold(
        fold=["count", "fraction"],
        as_=["statistic", "count_or_fraction"],
    )
    .transform_filter(statistic_selection)
    .encode(
        alt.X("date", title=None, axis=alt.Axis(grid=False, format="%b-%Y", labelAngle=-90)),
        alt.Y(
            "count_or_fraction:Q",
            axis=alt.Axis(grid=False),
            title="sequences",
            scale=alt.Scale(nice=False)
        ),
        alt.Facet(
            "variant",
            title=None,
            header=alt.Header(labelFontSize=9, labelPadding=0),
            columns=5,
            spacing=5,
        ),
        tooltip=[
            "variant",
            "date",
            "statistic:N",
            alt.Tooltip("count_or_fraction:Q", format=".2f"),
        ],
    )
    .mark_area(stroke="black", fill="gray")
    .properties(
        width=160,
        height=70,
        title=alt.TitleParams(
            (
                "count or fraction of sequences for each strain "
                f"(rolling mean +/- {plot_window_frame_days} days)"
            ),
            anchor="middle",
        )
    )
)

counts_by_date_chart

## Fit MLR models

Get the counts to fit and write to a file:

In [None]:
# start from the padded counts and remove whichever pooled categories are not kept
counts_to_fit = filtered_counts_by_date

for strainset, keep in (
    (NOT_IN_LIBRARY, keep_not_in_library),
    (INSUFFICIENT_COUNTS, keep_insufficient_counts),
):
    if keep:
        continue
    print(f"Dropping {strainset}")
    counts_to_fit = counts_to_fit.loc[counts_to_fit["variant"] != strainset]

strains_to_fit = sorted(counts_to_fit["variant"].unique())
dates_to_fit = counts_to_fit["date"].unique()

print(f"Fitting counts for {len(strains_to_fit)=} to {len(dates_to_fit)=}")
# the counts must have exactly one row for every strain on every date
_n_expected_rows = len(strains_to_fit) * len(dates_to_fit)
assert len(counts_to_fit) == _n_expected_rows

print(f"Writing the counts to fit to {counts_to_fit_csv}")
counts_to_fit.to_csv(counts_to_fit_csv, index=False, float_format="%.2f")

# the pivot (reference) strain has to be among the strains being fit
if pivot_strain not in strains_to_fit:
    raise ValueError(f"{pivot_strain=} not in {strains_to_fit=}")

Now set up and fit MLR model following [here](https://blab.github.io/evofr/notebooks/example_mlr.html):

In [None]:
# package the counts for evofr, naming the pivot strain and the strains being fit
# NOTE(review): later cells use `strains_to_fit` as the column labels of the posterior
# frequencies, so they depend on how evofr orders the variants (and on whatever it
# does with the list passed as `var_names`) -- confirm against the evofr source
variant_frequencies = evofr.VariantFrequencies(
    counts_to_fit, pivot=pivot_strain, var_names=strains_to_fit
)

# multinomial logistic regression model, with `tau` taken from the `mlr_tau` setting
mlr = evofr.MultinomialLogisticRegression(tau=mlr_tau)

# NUTS sampler configured with the requested numbers of warmup and posterior samples
inference_method = evofr.InferNUTS(num_samples=num_samples, num_warmup=num_warmup)

# fit the model to the counts; the posterior samples are read in the following cells
posterior = inference_method.fit(mlr, variant_frequencies)

Get the frequencies and growth advantages and their intervals (highest posterior density):

In [None]:
# posterior samples of the strain frequencies
freqs = posterior.samples["freq"]

# the samples should be indexed by (posterior sample, date, strain); each check is an
# actual assertion so that a shape mismatch stops the notebook rather than being ignored
assert freqs.shape[0] == num_samples
assert freqs.shape[1] == counts_to_fit["date"].nunique()
assert freqs.shape[2] == len(strains_to_fit)


def hpd(data, hpd_level):
    """Calculate the HPD of a 1D array for a given level.

    The values are sorted, and among all intervals spanning
    ``floor(hpd_level * len(data))`` steps of the sorted values, the narrowest
    one is returned.

    Parameters
    ----------
    data : array-like
        1D collection of values (e.g., posterior samples).
    hpd_level : float
        Fraction of the values the interval should span (e.g., 0.95).

    Returns
    -------
    tuple
        ``(hpd_min, hpd_max)``, the lower and upper ends of the interval.

    Raises
    ------
    ValueError
        If `data` is empty.

    """
    data = numpy.sort(data, axis=0)  # Sort the data
    n = len(data)
    if n == 0:
        raise ValueError("cannot compute an HPD interval from empty data")

    # Number of steps in the sorted data spanned by the interval. It is capped at
    # `n - 1` so that `hpd_level=1` gives the full range of the data rather than
    # leaving no candidate intervals to choose from.
    interval_index = min(int(numpy.floor(hpd_level * n)), n - 1)

    # Width of every candidate interval; the HPD is the narrowest one
    intervals = data[interval_index:] - data[: n - interval_index]
    min_idx = numpy.argmin(intervals)
    hpd_min = data[min_idx]
    hpd_max = data[min_idx + interval_index]
    return hpd_min, hpd_max


def hpd_over_axis(data, hpd_level, axis):
    """Calculate the HPD over a specific axis for a NumPy array.

    Parameters
    ----------
    data : array-like
        Array of values; `hpd` is applied to every 1D slice taken along `axis`.
    hpd_level : float
        Level passed to `hpd` (e.g., 0.95).
    axis : int
        Axis that is reduced by the HPD calculation.

    Returns
    -------
    tuple
        ``(hpd_min, hpd_max)``, two arrays shaped like `data` with `axis` removed.

    """
    # Move the target axis to the last dimension so each 1D slice runs along it
    data_swapped = numpy.moveaxis(data, axis, -1)

    # Compute the HPD once per slice: each call returns a (min, max) pair, so the
    # result has a trailing axis of length 2 that is split into the two bounds
    bounds = numpy.apply_along_axis(
        lambda x: hpd(x, hpd_level=hpd_level), -1, data_swapped
    )

    return bounds[..., 0], bounds[..., 1]

# summarize the posterior frequency samples over the sample axis
freqs_median = numpy.median(freqs, axis=0)
freqs_hpd_min, freqs_hpd_max = hpd_over_axis(freqs, hpd_interval / 100, axis=0)
dates = posterior.data.dates

# all three summaries must be (date, strain) arrays of the same shape
assert freqs_median.shape == freqs_hpd_min.shape == freqs_hpd_max.shape
assert freqs_median.shape == (len(dates), len(strains_to_fit))


def _tidy_freqs(values, value_name):
    """Convert a (date, strain) array to long form with one row per strain and date."""
    wide = pd.DataFrame(values, columns=strains_to_fit, index=dates)
    return (
        wide
        .reset_index(names="date")
        .melt(id_vars="date", var_name="strain", value_name=value_name)
    )


# start with the medians, then add the lower and upper HPD bounds as columns
freqs_df = _tidy_freqs(freqs_median, "freq_median")
for _values, _value_name in [
    (freqs_hpd_min, "freq_hpd_min"),
    (freqs_hpd_max, "freq_hpd_max"),
]:
    freqs_df = freqs_df.merge(
        _tidy_freqs(_values, _value_name),
        on=["strain", "date"],
        validate="one_to_one",
    )

assert len(freqs_df) == len(dates) * len(strains_to_fit)

In [None]:
# posterior samples of the growth advantages: one column per non-pivot strain
gas = posterior.samples["ga"]
assert gas.shape == (num_samples, len(strains_to_fit) - 1)

# summarize over the sample axis
gas_median = numpy.median(gas, axis=0)
gas_hpd_min, gas_hpd_max = hpd_over_axis(gas, hpd_interval / 100, axis=0)

assert gas_median.shape == gas_hpd_min.shape == gas_hpd_max.shape
assert gas_median.shape == (len(strains_to_fit) - 1,)

# NOTE(review): this assumes the columns of `ga` are in the same order as
# `nonpivot_strains` -- confirm against how evofr orders the variants
nonpivot_strains = [s for s in strains_to_fit if s != pivot_strain]
assert len(nonpivot_strains) == len(strains_to_fit) - 1

# the pivot strain is appended as the last row with a growth advantage of exactly 1
# (median and both HPD bounds), as the other values are relative to it
gas_df = pd.DataFrame(
    {
        "growth_advantage_median": list(gas_median) + [1],
        "growth_advantage_hpd_min": list(gas_hpd_min) + [1],
        "growth_advantage_hpd_max": list(gas_hpd_max) + [1],
    },
    index=nonpivot_strains + [pivot_strain],
).reset_index(names="strain")

assert len(gas_df) == len(strains_to_fit)

print(f"Writing the growth advantages to {growth_advantages_csv}")
# NOTE(review): none of the keys in the rename below matches a column of `gas_df`
# (the columns are `growth_advantage_median`, `growth_advantage_hpd_min` and
# `growth_advantage_hpd_max`; the keys use `_mean` and the spelling `advantange`),
# so the rename currently changes nothing and the CSV is written with the `gas_df`
# column names. Confirm which column names the consumers of this file expect before
# correcting the keys.
(
    gas_df
    .rename(
        columns={
            "growth_advantage_mean": "growth_advantage",
            "growth_advantange_hpd_min": f"growth_advantage_hpd{hpd_interval}_min",
            "growth_advantange_hpd_max": f"growth_advantage_hpd{hpd_interval}_max",
        }
    )
    .to_csv(growth_advantages_csv, index=False, float_format="%.2f")
)

Plot the frequencies:

In [None]:
# Build the data frame to plot: the model frequencies together with a label giving
# each strain's growth advantage and the rolling average of the observed fractions.

# observed fraction of each date's sequences belonging to each strain
_strain_counts = counts_to_fit.rename(columns={"variant": "strain"}).sort_values("date")
_strain_counts["frac_sequences"] = (
    _strain_counts["sequences"]
    / _strain_counts.groupby("date")["sequences"].transform("sum")
)

# centered rolling mean of those fractions for each strain; windows at the edges of
# the date range are averaged over whatever rows are available
_rolling_window = plot_window_frame_days * 2 + 1
_strain_counts["frac_sequences"] = (
    _strain_counts
    .groupby("strain")["frac_sequences"]
    .transform(
        lambda s: s.rolling(window=_rolling_window, min_periods=1, center=True).mean()
    )
)

windowed_counts_to_fit = _strain_counts[["strain", "date", "frac_sequences"]]

assert len(freqs_df) == len(windowed_counts_to_fit)

# text label for each strain of the form "median [lower - upper]"
_ga_labels = gas_df[["strain"]].assign(
    growth_advantage=[
        f"{median:.2f} [{lower:.2f} - {upper:.2f}]"
        for median, lower, upper in zip(
            gas_df["growth_advantage_median"],
            gas_df["growth_advantage_hpd_min"],
            gas_df["growth_advantage_hpd_max"],
        )
    ]
)

# attach the labels (one per strain) and the observed fractions (one per strain and date)
plot_freqs_df = (
    freqs_df
    .merge(_ga_labels, on="strain", validate="many_to_one")
    .merge(windowed_counts_to_fit, on=["strain", "date"], validate="one_to_one")
)

In [None]:
# strains in order of first appearance after sorting rows by descending median
# frequency, i.e., ordered by the highest median frequency each strain reaches
max_freq_sorted_strains = (
    plot_freqs_df.sort_values("freq_median", ascending=False)["strain"].unique()
)

# dropdown to show a single strain; the `None` option is labeled "all"
strain_selection = alt.selection_point(
    fields=["strain"],
    bind=alt.binding_select(
        options=[None] + strains_to_fit,
        labels=["all"] + strains_to_fit,
        name="strain(s) to show for MLR fit:",
    )
)

# encodings and properties shared by the three layers defined below
freqs_chart_base = (
    alt.Chart(plot_freqs_df)
    .transform_filter(strain_selection)
    .encode(
        alt.X(
            "date",
            title=None,
            axis=alt.Axis(grid=False, format="%b-%Y", labelAngle=-90),
        ),
        alt.Color(
            "strain",
            scale=alt.Scale(domain=max_freq_sorted_strains, scheme="category20"),
            legend=alt.Legend(
                labelLimit=500,
                # split the legend into columns of at most 12 strains
                columns=int(math.ceil(plot_freqs_df["strain"].nunique() / 12)),
            ),
        ),
        tooltip=[
            "strain",
            "growth_advantage",
            "date",
            alt.Tooltip("freq_median", format=".2f", title="predicted frequency"),
            alt.Tooltip(
                "frac_sequences",
                format=".2f",
                title=f"actual frequency (+/-{plot_window_frame_days} days)",
            ),
        ],
    )
    .properties(
        width=450,
        height=200,
        title=alt.TitleParams(
            "MLR fits and actual frequencies",
            subtitle=f"actual frequencies plotted as rolling mean +/- {plot_window_frame_days} days"
        ),
    )
)

# line for the posterior median frequency; the selection parameter is added to this
# layer only, while all three layers filter on it through the shared base chart
freqs_chart_median = (
    freqs_chart_base
    .add_params(strain_selection)
    .encode(alt.Y("freq_median", title="strain frequency", axis=alt.Axis(grid=False)))
    .mark_line(opacity=1, strokeWidth=2)
)

# shaded band between the lower and upper HPD bounds
freqs_chart_hpd = (
    freqs_chart_base
    .encode(
        alt.Y("freq_hpd_min"),
        alt.Y2("freq_hpd_max"),
    )
    .mark_area(opacity=0.2)
)

# points for the observed (rolling mean) fractions of sequences
freqs_chart_points = (
    freqs_chart_base
    .encode(alt.Y("frac_sequences"))
    .mark_circle(size=8, opacity=0.7)
)

# layer the three charts
freqs_chart = freqs_chart_median + freqs_chart_hpd + freqs_chart_points

freqs_chart

Plot the growth advantages:

In [None]:
# y-axis title given as two strings, the second being the name of the pivot strain
# (presumably so the title is drawn on two lines -- confirm in the rendered chart)
ga_ytitle = ["growth advantage relative to", pivot_strain]

# points at the posterior median growth advantage of each strain
ga_points = (
    alt.Chart(gas_df)
    .encode(
        alt.X("strain", title=None),
        alt.Y(
            "growth_advantage_median",
            title=ga_ytitle,
            scale=alt.Scale(zero=False, nice=False, padding=10),
            axis=alt.Axis(grid=False)
        ),
        tooltip=[
            "strain",
            alt.Tooltip("growth_advantage_median", format=".2f", title="growth advantage"),
            alt.Tooltip("growth_advantage_hpd_min", format=".2f", title=f"lower HPD{hpd_interval}%"),
            alt.Tooltip("growth_advantage_hpd_max", format=".2f", title=f"upper HPD{hpd_interval}%"),
        ],
    )
    .mark_circle(color="black", size=75, opacity=1)
    .properties(
        height=175,
        width=alt.Step(18),
        title=f"estimated strain growth advantages and HPD{hpd_interval}%",
    )
)

# error bars running from the lower to the upper HPD bound of each strain
ga_hpd = (
    alt.Chart(gas_df)
    .encode(
        alt.X("strain"),
        alt.Y("growth_advantage_hpd_min", title=ga_ytitle),
        alt.Y2("growth_advantage_hpd_max"),
    )
    .mark_errorbar(thickness=2)
)

# layer the charts with the points listed after, and so drawn over, the error bars
ga_chart = ga_hpd + ga_points

ga_chart

## Save merged chart

Make a merged chart:

In [None]:
# overall title naming the fit being shown
_merged_title = alt.TitleParams(
    f"sequence counts and MLR fits for {desc}",
    anchor="middle",
    fontSize=15,
    dy=-20,
)

# the individual charts, in the order they are stacked from top to bottom
_stacked_charts = [
    total_counts_chart,
    grouped_counts_by_date_chart,
    counts_by_date_chart,
    freqs_chart,
    ga_chart,
]

# stack the charts vertically, giving each its own fill and color scales
_stacked = alt.vconcat(*_stacked_charts, spacing=35)
_stacked = _stacked.resolve_scale(fill="independent", color="independent")
chart = _stacked.properties(title=_merged_title)

print(f"Saving merged chart to {chart_html}")
chart.save(chart_html)

chart