# Make plots that illustrate the principle of MLR (multinomial logistic regression) for estimating growth rates
This notebook makes interactive charts that illustrate how multinomial logistic regression (MLR) is used to estimate the growth rates of SARS-CoV-2 variants.

In [1]:
import gzip
import json
import urllib.request

import altair as alt

import pandas as pd

In [2]:
# URL of the Nextstrain forecasts-ncov MLR results (gisaid / pango_lineages /
# global) from the 2023-10-02 run. The fetching cell gzip-decompresses the
# payload and parses it as JSON.
bedford_rates_url = (
    "https://data.nextstrain.org/files/workflows/forecasts-ncov/"
    "gisaid/pango_lineages/global/mlr/2023-10-02_results.json"
)

In [3]:
# Location to analyze: rows of the results table are filtered to this value of
# their "location" column, and it also appears in the frequency-axis title.
location: str = "USA"

In [4]:
# Download the results (gzip-compressed JSON) and keep only the rows for
# `location`, with the "location" column dropped and "date" parsed to datetimes.
with urllib.request.urlopen(bedford_rates_url) as response:
    bedford_d = json.loads(gzip.decompress(response.read()))

bedford_usa_data = (
    pd.DataFrame(bedford_d["data"])
    .query("location == @location")
    .drop(columns="location")
    .assign(date=lambda x: pd.to_datetime(x["date"]))
)

# A misspelled or absent location would otherwise produce an empty frame and
# silently blank every downstream table and chart; fail loudly instead.
assert not bedford_usa_data.empty, (
    f"no rows for location {location!r} in {bedford_rates_url}"
)

In [5]:
# Map the summary labels found in the "ps" column onto short column names for
# the fitted frequency and its lower/upper HDI bounds.
fit_column_names = {
    "median": "fit",
    "HDI_95_lower": "fit_lower",
    "HDI_95_upper": "fit_upper",
}

# Fitted frequencies: one row per (variant, date) with columns fit, fit_lower,
# fit_upper taken from the rows whose site is "freq".
freqs_fit = (
    bedford_usa_data[bedford_usa_data["site"] == "freq"]
    .pivot_table(index=["variant", "date"], columns="ps", values="value")
    .rename(columns=fit_column_names)
    .loc[:, list(fit_column_names.values())]
    .reset_index()
)

In [6]:
# Observed frequencies: one row per (variant, date) with a column for each of
# the raw-frequency sites listed here.
raw_freq_sites = ["weekly_raw_freq", "daily_raw_freq"]
freqs_raw = (
    bedford_usa_data.loc[bedford_usa_data["site"].isin(raw_freq_sites)]
    .pivot_table(index=["variant", "date"], columns="site", values="value")
    .reset_index()
)

In [7]:
# Variants are plotted only if their peak weekly raw frequency exceeds this cutoff.
max_weekly_freq = 0.036

# Outer-join observed and fitted frequencies on (variant, date).
freqs_all = freqs_raw.merge(
    freqs_fit,
    how="outer",
    on=["variant", "date"],
    validate="one_to_one",
)

# Tag every row with its variant's peak weekly raw frequency so the cutoff keeps
# or drops whole variants. A NaN peak compares False and is therefore dropped.
peak_weekly_freq = freqs_all.groupby("variant")["weekly_raw_freq"].transform("max")
freqs = freqs_all.assign(max_weekly_freq=peak_weekly_freq).loc[
    lambda df: df["max_weekly_freq"] > max_weekly_freq
]

# Retained variants, in order of first appearance.
variants = freqs["variant"].drop_duplicates().tolist()

# Peak weekly frequency of each retained variant, highest first.
peak_table = freqs.loc[:, ["variant", "max_weekly_freq"]]
display(
    peak_table
    .sort_values("max_weekly_freq", ascending=False)
    .drop_duplicates()
    .reset_index(drop=True)
)

Unnamed: 0,variant,max_weekly_freq
0,XBB.1.5,0.49
1,FL.1.5.1,0.151
2,XBB.1.16,0.113
3,XBB.1.16.6,0.108
4,EG.5.1,0.094
5,EG.5.1.1,0.085
6,HV.1,0.083
7,EG.5.1.3,0.073
8,HK.3,0.05
9,XBB.1.9.1,0.049


In [8]:
# The pivot (reference) variant is taken to be the last entry of the results'
# variant list; it is assigned a growth advantage of exactly 1 below.
pivot_variant = bedford_d["metadata"]["variants"][-1]

# Growth-advantage point estimate and interval bounds for each plotted variant.
growth_advantages = (
    bedford_usa_data
    .query("site == 'ga'")
    .pivot_table(index="variant", values="value", columns="ps")
    .reset_index()
    .rename(
        columns={
            "median": "growth advantage",
            "HDI_95_lower": "lower",
            "HDI_95_upper": "upper",
        }
    )
    [["variant", "growth advantage", "lower", "upper"]]
    .query("variant in @variants")
)

# Add the pivot at exactly 1 (degenerate interval) when it is plotted, unless the
# results already reported a value for it -- that would create a duplicate row.
if pivot_variant in variants and pivot_variant not in set(growth_advantages["variant"]):
    growth_advantages = pd.concat(
        [
            growth_advantages,
            pd.DataFrame({"variant": [pivot_variant], "growth advantage": [1], "upper": [1], "lower": [1]}),
        ]
    )

# Plotted variants without a growth-advantage estimate drop out of `variants`
# below (and so out of the color domain and dropdown built from it); report
# them rather than losing them silently.
missing_growth = sorted(set(variants) - set(growth_advantages["variant"]))
if missing_growth:
    print(f"No growth advantage estimate for: {', '.join(missing_growth)}")

growth_advantages = growth_advantages.sort_values("growth advantage").reset_index(drop=True)

# Re-order `variants` by ascending growth advantage. The frame is already sorted,
# so reading its order directly keeps `variants` consistent with the table
# (a second, unstable sort could permute ties differently).
variants = growth_advantages["variant"].unique().tolist()

growth_advantages

Unnamed: 0,variant,growth advantage,lower,upper
0,XBB.1.5,1.0,1.0,1.0
1,XBB.1.9.1,1.055,1.048,1.062
2,XBB.1.16.1,1.11,1.103,1.117
3,XBB.1.16,1.133,1.128,1.138
4,XBB.2.3,1.135,1.129,1.142
5,GJ.1.2,1.233,1.223,1.243
6,EG.5.1,1.287,1.279,1.297
7,XBB.1.16.6,1.317,1.308,1.326
8,XBB.1.16.11,1.318,1.302,1.333
9,EG.5.1.3,1.327,1.311,1.343


In [10]:
# Hover-driven selection on the "variant" field; with empty=False nothing is
# highlighted until the pointer is over a mark.
variant_highlight = alt.selection_point(
    empty=False,
    fields=["variant"],
    on="mouseover",
)

# Shared color channel so a variant gets the same color in every panel; the
# color domain follows the order of `variants`.
variant_color = alt.Color(
    "variant",
    scale=alt.Scale(scheme="category20", domain=variants),
)

# One row per variant on the y axis, ordered as in `variants`.
growth_base = alt.Chart(growth_advantages).encode(
    variant_color,
    alt.Y("variant", scale=alt.Scale(domain=variants)),
)

# Point estimates: enlarged and outlined while their variant is hovered.
growth_points = growth_base.mark_circle(opacity=1, size=110, stroke="black").encode(
    alt.X("growth advantage", scale=alt.Scale(zero=False, nice=False, padding=10)),
    size=alt.condition(variant_highlight, alt.value(180), alt.value(100)),
    strokeWidth=alt.condition(variant_highlight, alt.value(4), alt.value(0)),
    tooltip=["variant", alt.Tooltip("growth advantage", format=".2f")],
)

# Interval drawn as a horizontal rule from "lower" to "upper".
growth_interval = growth_base.mark_rule(size=3).encode(
    alt.X("lower", title="growth advantage"),
    alt.X2("upper"),
)

# Intervals underneath, point estimates on top.
growth_chart = alt.layer(growth_interval, growth_points)

In [11]:
# Dropdown bound to a selection on "variant"; the None option is labelled "all"
# (no variant selected).
variant_selection = alt.selection_point(
    bind=alt.binding_select(
        name="variant(s) to show:",
        options=[None, *variants],
        labels=["all", *variants],
    ),
    fields=["variant"],
)

# Checkbox parameter (default on) that gates the fitted-curve layers.
show_fit_line = alt.param(
    bind=alt.binding_checkbox(name="show fit line?"),
    value=True,
)

# Common layer spec: restricted by the dropdown, dated x axis, shared color.
base_chart = (
    alt.Chart(freqs)
    .encode(
        alt.X("date", axis=alt.Axis(format="%b-%d-%Y", labelAngle=-90, tickCount=10)),
        variant_color,
        tooltip=["variant"],
    )
    .transform_filter(variant_selection)
)

# Fitted frequency as one line per variant, thickened on hover.
fits_chart = base_chart.mark_line(opacity=1).encode(
    alt.Y(
        "fit",
        title=f"frequency of variant in {location}",
        axis=alt.Axis(tickCount=5),
    ),
    strokeWidth=alt.condition(variant_highlight, alt.value(5), alt.value(2.5)),
)

# Shaded band between the lower and upper fit bounds.
fits_hpd_chart = base_chart.mark_area(opacity=0.2).encode(
    alt.Y("fit_lower"),
    alt.Y2("fit_upper"),
)

# Observed daily frequencies as points, enlarged and outlined on hover.
raw_chart = base_chart.mark_circle(opacity=0.65, stroke="black", strokeOpacity=1).encode(
    alt.Y("daily_raw_freq"),
    size=alt.condition(variant_highlight, alt.value(50), alt.value(30)),
    strokeWidth=alt.condition(variant_highlight, alt.value(1), alt.value(0)),
)

# Raw points are always drawn; the fit line and band are filtered by the checkbox.
fit_layers = alt.layer(fits_chart, fits_hpd_chart).transform_filter(show_fit_line)
mlr_chart = alt.layer(raw_chart, fit_layers).properties(width=650, height=325)

# Frequency panel side by side with the growth-advantage panel.
mlr_and_growth_chart = alt.hconcat(
    mlr_chart,
    growth_chart.properties(width=240, height=325),
    spacing=50,
)


def _finalize(chart):
    """Attach the interactive params and the shared axis/title/legend styling."""
    with_params = chart.add_params(variant_selection, show_fit_line, variant_highlight)
    return (
        with_params
        .configure_axis(grid=False, labelFontSize=16, titleFontSize=18)
        .configure_title(fontSize=20)
        .configure_legend(
            labelFontSize=16,
            titleFontSize=18,
            rowPadding=5,
            symbolOpacity=1,
            orient="bottom",
            columns=7,
        )
    )


# Display each finished chart and write it to a standalone HTML file.
finalized_charts = [
    (_finalize(mlr_chart), "mlr_chart.html"),
    (_finalize(mlr_and_growth_chart), "mlr_and_growth_chart.html"),
]
for chart, fname in finalized_charts:
    display(chart)
    print(f"Saving to {fname}")
    chart.save(fname)

Saving to mlr_chart.html


Saving to mlr_and_growth_chart.html
