# Estimate strain growth advantages using MLR

Import modules including [evofr](https://blab.github.io/evofr/) and get variables from `snakemake`:

In [None]:
import altair as alt

import evofr

import pandas as pd

# lift Altair's default limit on the number of rows a chart's dataset may have,
# so the full counts tables can be plotted (return value is not needed)
_ = alt.data_transformers.disable_max_rows()

In [None]:
# get variables from `snakemake`
# input: CSV of sequence counts for each variant on each date
counts_by_date_csv = snakemake.input.counts_by_date

# output: HTML file to which the merged counts chart is saved
counts_chart_html = snakemake.output.counts_chart

# description of this run's wildcards, printed here and used in the chart title
wildcards_desc = ", ".join(f"{key}={val}" for (key, val) in snakemake.wildcards.items())
print(f"{wildcards_desc=}")

# minimum total sequences for a library strain to be kept as its own category
min_counts = snakemake.params.min_counts

# Half-width (in rows on either side) of the rolling-mean window used in the
# plots. This comes from the workflow parameters; it was previously overwritten
# by a hard-coded value right after being read, which silently ignored the
# configured setting.
plot_window_frame_days = snakemake.params.plot_window_frame_days

# NOTE(review): the date range is hard-coded here rather than read from
# `snakemake`; confirm whether it should become a workflow parameter.
date_start, date_end = ["2022-01-01", "2024-08-01"]

## Counts for each strain by date

Read data.
We filter for counts within the specified date range, and then filter for strains that have sufficient counts in that range.
At the end of this, strains fall in one of three categories:
 - a named library strain (eg, *A/Bhutan/0845/2023*)
 - *strain not in library*: does not match a strain in library
 - *library strains insufficient counts*: strains in the library whose total counts in the date range are below the minimum threshold

In [None]:
# labels for the two catch-all categories of strains
NOT_IN_LIBRARY = "strain not in library"
INSUFFICIENT_COUNTS = "library strains insufficient counts"

# read the counts, parsing the `date` column as datetimes
all_counts_by_date = (
    pd.read_csv(counts_by_date_csv, parse_dates=["date"])
    .sort_values(["date", "variant"])
)

datetime_start = pd.to_datetime(date_start)
datetime_end = pd.to_datetime(date_end)

# the requested range must lie within the dates covered by the data
if datetime_start < all_counts_by_date["date"].min():
    raise ValueError(f"{datetime_start=} before {all_counts_by_date['date'].min()=}")
if datetime_end > all_counts_by_date["date"].max():
    # report the end of the range here (the bound that was actually violated)
    raise ValueError(f"{datetime_end=} after {all_counts_by_date['date'].max()=}")

print(f"Trimming counts by date to the range {date_start=} to {date_end=}")
# Take an explicit copy of the trimmed rows so that the column assignment below
# acts on an independent data frame rather than on a slice of the untrimmed
# data (which triggers pandas' SettingWithCopyWarning).
all_counts_by_date = all_counts_by_date[
    (all_counts_by_date["date"] >= datetime_start)
    & (all_counts_by_date["date"] <= datetime_end)
].copy()

# relabel the input's "other" category with the more descriptive label, after
# checking that "other" is present and that the new label is not already in use
assert "other" in set(all_counts_by_date["variant"])
assert NOT_IN_LIBRARY not in set(all_counts_by_date["variant"])
all_counts_by_date["variant"] = all_counts_by_date["variant"].replace(
    "other", NOT_IN_LIBRARY
)

Get total counts for each variant:

In [None]:
# sum the sequences over all dates for each variant, then flag which variants
# reach the `min_counts` threshold
total_counts = all_counts_by_date.groupby("variant", as_index=False).agg(
    total_sequences=("sequences", "sum")
)
total_counts["sufficient_counts"] = total_counts["total_sequences"] >= min_counts

# x-axis: one bar per variant, ordered from most to fewest total sequences
variant_axis = alt.X(
    "variant",
    sort=alt.SortField("total_sequences", order="descending"),
    title=None,
)

# y-axis: total sequences on a symlog scale
total_sequences_axis = alt.Y(
    "total_sequences",
    scale=alt.Scale(type="symlog", constant=50),
    title="total sequences",
    axis=alt.Axis(grid=False),
)

# fill: gray bars for variants with sufficient counts, white bars otherwise
sufficient_counts_fill = alt.Fill(
    "sufficient_counts",
    scale=alt.Scale(range=["gray", "white"], domain=[True, False]),
    title="sufficient counts?",
    legend=alt.Legend(orient="top-right", offset=3)
)

total_counts_chart = (
    alt.Chart(total_counts)
    .mark_bar(stroke="black")
    .encode(
        x=variant_axis,
        y=total_sequences_axis,
        fill=sufficient_counts_fill,
        tooltip=list(total_counts.columns),
    )
    .properties(
        title=f"total sequences per strain from {date_start} to {date_end}",
        width=alt.Step(11),
        height=150,
    )
)

total_counts_chart

Now group all library strains with insufficient counts into a single category.
Also pad with zero counts any dates in the range that are missing:

In [None]:
# library strains (i.e., excluding the "not in library" category) whose total
# counts are below the threshold
strains_w_insufficient_counts = set(
    total_counts
    .query("(not sufficient_counts) and variant != @NOT_IN_LIBRARY")
    ["variant"]
)

print(f"Grouping {len(strains_w_insufficient_counts)=} to '{INSUFFICIENT_COUNTS}'")

assert INSUFFICIENT_COUNTS not in set(all_counts_by_date["variant"])

# The integer day offsets computed below discard any time-of-day component, so
# the check that dates are rounded to whole days must be made on the dates
# themselves (a check on the already-truncated offsets can never fail).
assert (
    all_counts_by_date["date"] == all_counts_by_date["date"].dt.normalize()
).all(), "dates not all rounded to day"

# group strains w insufficient counts, indexing each date as integer days
# since the start of the range
filtered_counts_by_date = (
    all_counts_by_date
    .assign(
        variant=lambda x: x["variant"].where(
            ~x["variant"].isin(strains_w_insufficient_counts),
            INSUFFICIENT_COUNTS,
        ),
        day=lambda x: (x["date"] - datetime_start).dt.days,
    )
    .groupby(["variant", "day"], as_index=False)
    .aggregate({"sequences": "sum"})
    .sort_values(["day", "variant"])
)

# Every day from the start to the end of the range (inclusive), not just the
# days that happen to have counts: days on which no variant was observed at all
# must also be padded, otherwise they are silently skipped.
days = pd.RangeIndex((datetime_end - datetime_start).days + 1)

print(f"Padding with zero counts any missing days between {date_start} and {date_end}")
# all (variant, day) combinations that should be present after padding
all_variant_days = pd.MultiIndex.from_product(
    [filtered_counts_by_date["variant"].unique(), days],
    names=["variant", "day"],
).to_frame(index=False)

filtered_counts_by_date = (
    filtered_counts_by_date
    .merge(
        all_variant_days,
        on=["variant", "day"],
        how="outer",
        validate="one_to_one",
    )
    .assign(
        # combinations absent from the data get zero counts
        sequences=lambda x: x["sequences"].fillna(0),
        # convert integer day offsets back to dates
        date=lambda x: datetime_start + pd.to_timedelta(x["day"], unit="D"),
    )
    # the rolling-window transforms in the plots act on consecutive rows, so
    # keep each variant's rows explicitly in date order
    .sort_values(["variant", "day"], ignore_index=True)
    .drop(columns="day")
)

Plot the number and fraction of sequences in each set of strains as a function of date:

In [None]:
# Collapse all named library strains into a single "library strains" set,
# leaving the two catch-all categories as they are, then total the sequences
# for each set on each date.
grouped_counts_by_date = (
    filtered_counts_by_date
    .assign(
        set_of_strains=lambda x: x["variant"].where(
            x["variant"].isin([INSUFFICIENT_COUNTS, NOT_IN_LIBRARY]),
            "library strains",
        )
    )
    .groupby(["set_of_strains", "date"], as_index=False)
    .agg({"sequences": "sum"})
)

# title of the chart, reporting the rolling-window size
grouped_chart_title = alt.TitleParams(
    (
        "count or fraction of sequences in each set of strains "
        f"(rolling mean +/- {plot_window_frame_days} days)"
    ),
    anchor="middle",
)

grouped_counts_by_date_chart = (
    alt.Chart(grouped_counts_by_date)
    # rolling mean of the counts within each set of strains
    .transform_window(
        count="mean(sequences)",
        groupby=["set_of_strains"],
        frame=[-plot_window_frame_days, plot_window_frame_days],
    )
    # fraction of each date's smoothed total that belongs to each set
    .transform_joinaggregate(total_count="sum(count)", groupby=["date"])
    .transform_calculate(fraction=alt.datum.count / alt.datum.total_count)
    # long format: one row per statistic, so each statistic gets its own panel
    .transform_fold(
        fold=["count", "fraction"],
        as_=["statistic", "count_or_fraction"],
    )
    .mark_area()
    .encode(
        x=alt.X(
            "date",
            title=None,
            axis=alt.Axis(grid=False, format="%b-%Y", labelAngle=-90),
        ),
        y=alt.Y(
            "count_or_fraction:Q",
            axis=alt.Axis(grid=False),
            title=None,
            scale=alt.Scale(nice=False)
        ),
        fill=alt.Fill(
            "set_of_strains",
            title="set of strains",
            legend=alt.Legend(orient="top", labelLimit=500, titleOrient="left"),
        ),
        column=alt.Column(
            "statistic:N",
            title=None,
            header=alt.Header(orient="left", labelFontStyle="bold", labelFontSize=11)
        ),
        tooltip=[
            "set_of_strains",
            "date",
            "statistic:N",
            alt.Tooltip("count_or_fraction:Q", format=".2f"),
        ],
    )
    .properties(width=350, height=160, title=grouped_chart_title)
    # counts and fractions are on different scales, so give each its own y-axis
    .resolve_scale(y="independent")
)

grouped_counts_by_date_chart

Now make per-strain plots:

In [None]:
# radio buttons choosing which statistic is shown on the y-axis
statistic_radio = alt.binding_radio(
    options=["count", "fraction"],
    name="show count or fraction on y-axis?",
)
statistic_selection = alt.selection_point(
    fields=["statistic"],
    bind=statistic_radio,
    value="fraction",
)

# radio buttons toggling whether each catch-all category is included
not_in_library_radio = alt.binding_radio(
    options=[True, False],
    name=f"include {NOT_IN_LIBRARY}?",
)
include_not_in_library = alt.param(bind=not_in_library_radio, value=True)

insufficient_counts_radio = alt.binding_radio(
    options=[True, False],
    name=f"include {INSUFFICIENT_COUNTS}?",
)
include_insufficient_counts = alt.param(bind=insufficient_counts_radio, value=True)

# a row is kept if it is not in the catch-all category, or that category is included
keep_not_in_library = (alt.datum["variant"] != NOT_IN_LIBRARY) | include_not_in_library
keep_insufficient_counts = (
    (alt.datum["variant"] != INSUFFICIENT_COUNTS) | include_insufficient_counts
)

# title of the chart, reporting the rolling-window size
per_strain_chart_title = alt.TitleParams(
    (
        "count or fraction of sequences for each strain "
        f"(rolling mean +/- {plot_window_frame_days} days)"
    ),
    anchor="middle",
)

counts_by_date_chart = (
    alt.Chart(filtered_counts_by_date)
    .add_params(statistic_selection, include_not_in_library, include_insufficient_counts)
    # drop excluded categories first so fractions are of the remaining strains
    .transform_filter(keep_not_in_library)
    .transform_filter(keep_insufficient_counts)
    # rolling mean of the counts for each strain
    .transform_window(
        count="mean(sequences)",
        groupby=["variant"],
        frame=[-plot_window_frame_days, plot_window_frame_days],
    )
    # fraction of each date's smoothed total that belongs to each strain
    .transform_joinaggregate(total_count="sum(count)", groupby=["date"])
    .transform_calculate(fraction=alt.datum.count / alt.datum.total_count)
    # long format with one row per statistic, then keep the selected statistic
    .transform_fold(
        fold=["count", "fraction"],
        as_=["statistic", "count_or_fraction"],
    )
    .transform_filter(statistic_selection)
    .mark_area(stroke="black", fill="gray")
    .encode(
        x=alt.X(
            "date",
            title=None,
            axis=alt.Axis(grid=False, format="%b-%Y", labelAngle=-90),
        ),
        y=alt.Y(
            "count_or_fraction:Q",
            axis=alt.Axis(grid=False),
            title="sequences",
            scale=alt.Scale(nice=False)
        ),
        # one small panel per strain
        facet=alt.Facet(
            "variant",
            title=None,
            header=alt.Header(labelFontSize=9, labelPadding=0),
            columns=5,
            spacing=5,
        ),
        tooltip=[
            "variant",
            "date",
            "statistic:N",
            alt.Tooltip("count_or_fraction:Q", format=".2f"),
        ],
    )
    .properties(width=160, height=70, title=per_strain_chart_title)
)

counts_by_date_chart

Make a merged plot:

In [None]:
# overall title for the merged chart, naming the wildcards of this run
merged_chart_title = alt.TitleParams(
    f"sequence counts chart for {wildcards_desc}",
    anchor="middle",
    fontSize=15,
    dy=-20,
)

# stack the three charts vertically; each keeps its own fill scale
stacked_charts = alt.vconcat(
    total_counts_chart,
    grouped_counts_by_date_chart,
    counts_by_date_chart,
    spacing=35,
)
counts_chart = (
    stacked_charts
    .resolve_scale(fill="independent")
    .properties(title=merged_chart_title)
)

print(f"Saving merged chart to {counts_chart_html}")
counts_chart.save(counts_chart_html)

counts_chart