In [1]:
# Debugging aid, kept commented out: a hand-built stand-in for the `snakemake`
# object, followed by the assignments that would read parameters from it.
# class FakeSnakemake:
#     input = {"counts_by_date": "../results/strain_counts_h3/h3-gisaid-ha1-exact_counts_by_date.csv"}
#     params = {"protset": "h3-gisaid-ha1-exact", "group": "h3"}

# snakemake = FakeSnakemake()

# # Parameters passed by Snakemake
# input_file = snakemake.input["counts_by_date"]
# protset = snakemake.params["protset"]
# group = snakemake.params["group"]

# NOTE(review): the original note here read "Comment to run interactively";
# presumably the stub above is meant to be uncommented for interactive runs
# and commented out otherwise -- confirm.


# Parameters (injected by Papermill or Snakemake shell)
# These None values are placeholders only; the real values are expected to be
# assigned by a parameters cell that runs after this one.
protset = None
group = None
counts_by_date = None



In [2]:
# Parameters
# Concrete values for this run: the protein set name, the group, and the path
# to the counts-by-date CSV that the rest of the notebook reads.
# NOTE(review): looks like the cell Papermill injects at execution time --
# confirm before editing these values by hand.
protset = "h3-gisaid-ha1-within1"
group = "h3"
counts_by_date = "results/strain_counts_h3/h3-gisaid-ha1-within1_counts_by_date.csv"


# Plot strain frequencies

In [3]:
# Import
import datetime
import math
import re
from pathlib import Path

import altair as alt
import matplotlib.pyplot as plt
import numpy
import pandas as pd

# Lift Altair's row limit so large dataframes can be charted without an error
_ = alt.data_transformers.disable_max_rows()

# ---- Settings a reader may want to tune ----

# Color palette handed to the charts
color_scheme = ["#345995", "#03cea4", "#ca1551", "#eac435"]  # blue, teal, red, yellow

# Inclusive date window to analyze, held as pandas Timestamps so the bounds
# can be compared directly with a parsed datetime column
date_start = pd.Timestamp(datetime.date.fromisoformat("2024-04-01"))
date_end = pd.Timestamp(datetime.date.fromisoformat("2025-04-08"))

# Threshold on a strain's total sequences within the window
min_counts = 50

# Number of rows (days) on each side of a date included in the rolling mean
plot_window_frame_days = 10

# Path of the counts-by-date CSV, taken from the `counts_by_date` parameter;
# both names are bound to the same path
counts_by_date_csv = input_file = counts_by_date


Read data. We filter for counts within the specified date range, and then filter for strains that have sufficient counts in that range. At the end of this, strains fall in one of two categories:

* Exact match or within 1 aa mutation of a named library strain (e.g., A/Bhutan/0845/2023)
* Strain not in library: does not match any library strain, nor any sequence within 1 aa mutation of a library strain

In [4]:
# Label that replaces the input file's generic "other" category
NOT_IN_LIBRARY = "strain not in library"

# Load the per-date counts; "date" is parsed to datetimes so it can be
# compared against the Timestamp bounds `date_start` / `date_end`
all_counts_by_date = (
    pd.read_csv(counts_by_date_csv, parse_dates=["date"])
    .sort_values(["date", "variant"])
)

# Refuse to continue if the requested window reaches outside the data
if date_start < all_counts_by_date["date"].min():
    raise ValueError(f"{date_start=} before {all_counts_by_date['date'].min()=}")
if date_end > all_counts_by_date["date"].max():
    # Report the bound that actually failed (the end date, not the start date)
    raise ValueError(f"{date_end=} after {all_counts_by_date['date'].max()=}")

print(f"Trimming counts by date to the range {date_start=} to {date_end=}")
# `between` is inclusive on both ends, i.e. date_start <= date <= date_end
in_window = all_counts_by_date["date"].between(date_start, date_end)
all_counts_by_date = all_counts_by_date[in_window]

# The rename below relies on "other" being present and on the replacement
# label not already being used as a variant name
assert "other" in set(all_counts_by_date["variant"])
assert NOT_IN_LIBRARY not in set(all_counts_by_date["variant"])

# Build a new frame with `.assign` instead of writing a column into the
# filtered slice, which pandas flags with SettingWithCopyWarning
all_counts_by_date = all_counts_by_date.assign(
    variant=lambda df: df["variant"].replace("other", NOT_IN_LIBRARY)
)

Trimming counts by date to the range date_start=Timestamp('2024-04-01 00:00:00') to date_end=Timestamp('2025-04-08 00:00:00')


Get total counts for each variant:

In [5]:
# Sum sequences per variant and flag the variants whose total reaches
# the `min_counts` threshold
total_counts = (
    all_counts_by_date
    .groupby("variant", as_index=False)
    .agg(total_sequences=("sequences", "sum"))
    .assign(sufficient_counts=lambda df: df["total_sequences"].ge(min_counts))
)

# Encoding channels, defined separately to keep the chart expression short
variant_axis = alt.X(
    "variant",
    sort=alt.SortField("total_sequences", order="descending"),
    title=None,
)
sequences_axis = alt.Y(
    "total_sequences",
    scale=alt.Scale(type="symlog", constant=50),
    title="total sequences",
    axis=alt.Axis(grid=False),
)
threshold_fill = alt.Fill(
    "sufficient_counts",
    scale=alt.Scale(range=["gray", "white"], domain=[True, False]),
    title=f"counts >= {min_counts}?",
    legend=alt.Legend(orient="top-right", offset=3),
)

# One bar per variant, ordered by descending total
total_counts_chart = (
    alt.Chart(total_counts)
    .mark_bar(stroke="black")
    .encode(
        variant_axis,
        sequences_axis,
        threshold_fill,
        tooltip=list(total_counts.columns),
    )
    .properties(
        height=150,
        width=alt.Step(11),
        title=f"total sequences per strain from {date_start.date()} to {date_end.date()}",
    )
)

total_counts_chart

Now aggregate the counts by variant and day, and pad any missing days in the range with zero counts. (Strains with insufficient counts are not merged into a separate group in this step.)

In [6]:
# Sum sequences per variant per day, where "day" is the integer offset from
# `date_start`; an integer axis makes the missing days easy to enumerate
daily_counts = (
    all_counts_by_date
    .assign(day=lambda df: (df["date"] - date_start).dt.days)
    .groupby(["variant", "day"], as_index=False)
    .agg({"sequences": "sum"})
    .sort_values(["day", "variant"])
)

days = daily_counts["day"].unique()
assert all(days == days.astype(int)), "dates not all rounded to day"

print(f"Padding with zero counts any missing days between {date_start} and {date_end}")

# Every (variant, day) pair across the span of days observed in the data
first_day, last_day = days.min(), days.max()
every_variant_day = pd.DataFrame(
    [
        (variant, day)
        for variant in daily_counts["variant"].unique()
        for day in range(first_day, last_day + 1)
    ],
    columns=["variant", "day"],
)

# The outer merge adds the pairs that had no rows; `validate` guards against
# duplicated (variant, day) keys on either side
padded_counts = daily_counts.merge(
    every_variant_day,
    how="outer",
    validate="one_to_one",
)

# Zero-fill the added rows and convert the day offset back into a date
filtered_counts_by_date = (
    padded_counts
    .assign(
        sequences=lambda df: df["sequences"].fillna(0),
        date=lambda df: df["day"].map(
            lambda offset: date_start + pd.Timedelta(days=offset)
        ),
    )
    .drop(columns="day")
)

Padding with zero counts any missing days between 2024-04-01 00:00:00 and 2025-04-08 00:00:00


Plot the number (and fraction) of sequences in each set of strains as a function of date:



In [7]:
# Collapse the variants into two sets -- everything matched to the library
# versus the NOT_IN_LIBRARY bucket -- and sum sequences per set and date
grouped_counts_by_date = (
    filtered_counts_by_date
    .assign(
        set_of_strains=lambda df: df["variant"].map(
            lambda v: v if v == NOT_IN_LIBRARY else "library strains"
        )
    )
    .groupby(["set_of_strains", "date"], as_index=False)
    .aggregate({"sequences": "sum"})
)


# Configure chart
titleFontSize = 18
labelFontSize = 18

# File name shown in the second title line. `str.strip` is not used here
# because it removes any leading/trailing characters found in its argument
# (a character set, not a prefix) and so can eat into the file name itself.
counts_file_label = Path(counts_by_date_csv).name

# Make chart
grouped_counts_by_date_chart = (
    alt.Chart(grouped_counts_by_date)
    # Rolling mean of sequences within each set of strains, over a frame of
    # `plot_window_frame_days` rows on each side of the current row
    .transform_window(
        count="mean(sequences)",
        groupby=["set_of_strains"],
        frame=[-plot_window_frame_days, plot_window_frame_days],
    )
    # Per-date total of the smoothed counts, used to turn counts into fractions
    .transform_joinaggregate(total_count="sum(count)", groupby=["date"])
    .transform_calculate(fraction=alt.datum.count / alt.datum.total_count)
    # Long format: one row per statistic so each statistic gets its own facet row
    .transform_fold(
        fold=["count", "fraction"],
        as_=["statistic", "count_or_fraction"],
    )
    .encode(
        alt.X("date", title=None, axis=alt.Axis(grid=False, format="%b-%Y", labelAngle=-90,
                                                tickCount=100, tickSize=0,
                                                titleFontSize=titleFontSize,
                                                labelFontSize=labelFontSize)),
        alt.Y(
            "count_or_fraction:Q",
            axis=alt.Axis(grid=False,
                          titleFontSize=titleFontSize,
                          labelFontSize=labelFontSize),
            title=None,
            scale=alt.Scale(nice=False)
        ),
        alt.Fill(
            "set_of_strains",
            title="set of strains",
            legend=alt.Legend(orient="right",
                              labelLimit=500, titleOrient="top"),
            scale=alt.Scale(range=color_scheme)
        ),
        alt.Row(
            "statistic:N",
            title=None,
            header=alt.Header(orient="left", labelFontStyle="bold",
                              titleFontSize=titleFontSize,
                              labelFontSize=labelFontSize)
        ),
        tooltip=[
            "set_of_strains",
            "date",
            "statistic:N",
            alt.Tooltip("count_or_fraction:Q", format=".2f"),
        ],
    )
    .mark_area()
    .configure_legend(titleFontSize=titleFontSize, labelFontSize=labelFontSize)
    .properties(
        width=450,
        height=160,
        title=alt.TitleParams(
            # Two entries render as a two-line title
            [
                "sequences in each set of strains "
                f"(rolling mean +/- {plot_window_frame_days} days)",
                f"from {counts_file_label}",
            ],
            anchor="middle",
            dx=-50,
            fontSize=titleFontSize,
        ),
    )
    # Counts and fractions have different ranges, so each row keeps its own y scale
    .resolve_scale(y="independent")
)

grouped_counts_by_date_chart