In [None]:
## Common functions etc
import re
from calendar import isleap
from datetime import datetime, date, timedelta
from collections import defaultdict
import matplotlib as mpl
from matplotlib import pyplot as plt
import matplotlib.patheffects as path_effects
from matplotlib.gridspec import GridSpec
from matplotlib.patches import Polygon, PathPatch, Patch, Rectangle, Circle
from matplotlib.path import Path
from matplotlib.colors import to_hex
from matplotlib.transforms import Bbox
import matplotlib.colors as colors
from matplotlib.collections import PatchCollection
from matplotlib.lines import Line2D
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from matplotlib.colors import to_hex
import matplotlib.gridspec as gridspec
import numpy as np
import pandas as pd
import colorsys
import csv
import math
# Render every figure in this notebook at 150 dpi.
plt.rcParams['figure.dpi'] = 150

def numeric_date(dt_string):
    """Convert an ISO "YYYY-MM-DD" string to a decimal year.

    Each day is represented by its midpoint, so 1 January maps to
    year + 0.5/365 (or year + 0.5/366 in a leap year).
    """
    parsed = datetime.strptime(dt_string, "%Y-%m-%d")
    year_length = 366 if isleap(parsed.year) else 365
    day_of_year = parsed.timetuple().tm_yday
    return parsed.year + (day_of_year - 0.5) / year_length

def summarise_dates(df):
    """Print a warning wherever consecutive `date` values are not exactly one day apart.

    `df['date']` is read as ISO "YYYY-MM-DD" strings, in row order. Nothing is
    returned; gaps, repeats or reversals are reported via `print`.
    """
    dates = list(df['date'])
    for earlier, later in zip(dates, dates[1:]):
        expected = (date.fromisoformat(earlier) + timedelta(days=1)).isoformat()
        if later != expected:
            print(f"Jumped from {earlier} to {later} (should be {expected})")


## To do list

* This requires hard-coded commit SHAs, whose versions of the TSV files are then fetched from GitHub. This isn't ideal (but it's easy). We should instead use something like `git log --oneline -- estimates/omicron-countries-split/omicron-countries-split_freq-combined-forecast-GARW.tsv` to list the relevant commits, combined with `git show 4fec4d2:estimates/omicron-countries-split/omicron-countries-split_freq-combined-GARW.tsv` to read each file's contents at that commit into a DataFrame.

In [None]:
## Download the latest model. This is thought of as the source of truth.
MODEL_LATEST_URL = (
    "https://raw.githubusercontent.com/blab/rt-from-frequency-dynamics/master/"
    "estimates/omicron-countries-split/omicron-countries-split_freq-combined-GARW.tsv"
)
model_latest = pd.read_csv(MODEL_LATEST_URL, sep="\t")
model_latest

In [None]:
# Latest sequence counts (which will be converted to frequencies)
SEQ_COUNTS_URL = (
    "https://raw.githubusercontent.com/blab/rt-from-frequency-dynamics/master/"
    "data/omicron-countries-split/omicron-countries-split_location-variant-sequence-counts.tsv"
)
seq_counts = pd.read_csv(SEQ_COUNTS_URL, sep="\t")
seq_counts

In [None]:
# commits which modified the forecast TSV: https://github.com/blab/rt-from-frequency-dynamics/commits/master/estimates/omicron-countries-split/omicron-countries-split_freq-combined-forecast-GARW.tsv
# cd projects/blab/rt-from-frequency-dynamics/estimates/omicron-countries-split
# git log -- omicron-countries-split_freq-combined-forecast-GARW.tsv | grep commit | cut -d ' ' -f 2
#
# The first commit listed is treated elsewhere in this notebook as the latest (most recent) model run.

commits = [
    'b7ff151ae7f54d79da2039304821e3c5a4aecdeb',
    '4fec4d2c039be75b96aab5f2f89cfbf7ab1482cd',
    'd4adf9740209ed0e8c1264dcc12514f27e09a890',
    'a9ebb8ad8a6e83b0e0384d268cf5e98a41d8298f',
    'd0c553297835d0741dba730c99a1a6de33d6f063',
    '295330b9eeeab77bdfd39c2eb12b1e296ccb6b9f',
    'd294df9bf4e2316015b39b33a15e8da576fdce81',
    'cb3e4212b87ba717269055e0fc700b90426fbf4d',
    'e65d2bda5335710c692d8a4db0333eebfc8e4e7a',
    '0c5326fa316cba453e535079fdae10c98166c9d0',
    '6b61079c537c2481e134364488ca34778b2b0aae',
    '0c5062e918920ac9b49d4171dd575f8a7747baf8',
    'bb7d8cd73bec104cebeb5f6b4fb26e790f9c5056',
    '5ef08b50ee1b7c10895da8353a4155c5f88d977d',
    '83a50c388e3ee92c17813bf9f4179a0a2d750f86',
    'da809352020660b4c58be4bdc459ba09e20bea52',
    '871188d411049ce238366c50b48e807a42690650',
    '22626a6ceea4ca82be01be085c05f171cabfad51',
    '759006497f536bfe957ba224e33fa9f8523c58fe',
    '8333654e333cc3c3ec7b88a38901be52da1bf2a9',
    '04a24bf70dafddee9ace4c8aec54f86a48dcbcac',
    '8eb83dd04c51e674ec4a3a64faacfa3882f37dc4',
]

# Both dicts are keyed by the 7-character short SHA, in the same order as `commits`.
RAW_GITHUB_BASE = "https://raw.githubusercontent.com/blab/rt-from-frequency-dynamics"
models = {}
forecasts = {}
for position, sha in enumerate(commits, start=1):
    short_sha = sha[0:7]
    print(f"{position}/{len(commits)}. {short_sha}")
    folder = f"{RAW_GITHUB_BASE}/{sha}/estimates/omicron-countries-split/"
    forecasts[short_sha] = pd.read_table(folder + "omicron-countries-split_freq-combined-forecast-GARW.tsv", sep="\t")
    models[short_sha] = pd.read_table(folder + "omicron-countries-split_freq-combined-GARW.tsv", sep="\t")


In [None]:
def prepare_data(location, variant, min_freq=0.001):
    """
    Subset the dataframes and merge them together to represent the desired location + variant.

    Expects the globals `model_latest`, `seq_counts`, `models`, `forecasts` and the helper
    `summarise_dates` to be defined by earlier cells.

    Parameters
    ----------
    location : str
        Matched against the `location` column of every input table.
    variant : str
        Matched against the `variant` column (e.g. "Omicron 22B").
    min_freq : float, default 0.001
        Threshold used to choose `first_day` (0.001 == 0.1%).

    Returns
    -------
    (first_day, df)
        `first_day` is the earliest date on which either the latest model's `median_freq`
        or the raw sequence frequency exceeds `min_freq`. `df` is NOT subsetted to that date;
        it is sorted by date, has a fresh 0..N-1 index, and contains `date`, `median_freq`,
        `raw_freq`, plus `model_*_{sha}` / `forecast_*_{sha}` columns for every commit that
        has data for this location/variant.

    Raises
    ------
    ValueError
        If no date exceeds `min_freq`.
    """

    def curate_df(df, suffix, isModel):
        """Subset one model/forecast table to `location`/`variant` and rename its frequency columns.

        Returns False (rather than an empty frame) when there are no matching rows, e.g. the
        location wasn't in this model run, or the run predates the variant being defined.
        """
        df = df[(df["location"] == location) & (df["variant"] == variant)]
        if df.shape[0] == 0:
            return False

        prefix = "model" if isModel else "forecast"
        ## NOTE. At some point the keys changed to include _forecast_, e.g.
        ## old format: 'freq_upper_80' new format: 'freq_forecast_upper_80'
        substr = "_forecast" if any('_forecast_' in name for name in df.columns) else ""

        # old name -> new name; suffixing with the commit avoids merge collisions and lets
        # later code recover which commit a column came from.
        renames = {f"median_freq{substr}": f"{prefix}_median_{suffix}"}
        for level in (50, 80, 95):
            for bound in ("upper", "lower"):
                renames[f"freq{substr}_{bound}_{level}"] = f"{prefix}_{bound}{level}_{suffix}"

        # Select then rename into a new frame (not inplace on a slice -> no SettingWithCopyWarning).
        return df[["date", *renames]].rename(columns=renames)

    # the latest model data fitted to actual (retrospective) data. (Update: not quite true. Todo.)
    ## this provides the `median_freq` column
    is_target = (model_latest["location"] == location) & (model_latest["variant"] == variant)
    subset_retrospective = model_latest.loc[is_target, ["date", "median_freq"]]

    # the raw frequencies -> `raw_freq` column.
    ## The dates here may have gaps; that's fine, as the outer merge with the model data fills them in.
    counts = seq_counts[seq_counts["location"] == location].pivot(index='date', columns='variant', values='sequences')
    raw = (counts[variant] / counts.sum(axis=1)).rename('raw_freq').reset_index()

    # Per-commit subsets (columns starting `model_` / `forecast_`); commits without rows are dropped.
    subset_models = {k: curate_df(v, k, True) for k, v in models.items()}
    subset_models = {k: v for k, v in subset_models.items() if v is not False}
    subset_forecasts = {k: curate_df(v, k, False) for k, v in forecasts.items()}
    subset_forecasts = {k: v for k, v in subset_forecasts.items() if v is not False}

    # Merge everything on date. Iterate over the union of commits (in commit order) so a commit
    # present in only one of the two dicts neither raises KeyError nor is silently dropped.
    data = pd.merge(subset_retrospective, raw, how='outer', on='date')
    for commit in dict.fromkeys([*subset_models, *subset_forecasts]):
        for subsets in (subset_models, subset_forecasts):
            if commit in subsets:
                data = pd.merge(data, subsets[commit], how='outer', on='date')

    data = data.sort_values(by=['date'], ascending=True)
    summarise_dates(data)  # useful check -- prints warnings if dates aren't completely sequential

    ## first date where freqs get above `min_freq` (0.1% by default)
    above = data.loc[(data['median_freq'] > min_freq) | (data['raw_freq'] > min_freq), 'date']
    if above.empty:
        raise ValueError(f"No date where the frequency of {variant} in {location} exceeds {min_freq}")
    start_date = above.iloc[0]

    data = data.reset_index(drop=True)
    return (start_date, data)


## Example output for one location/variant, to help understand the merged table:
first_day, data = prepare_data(location="USA", variant="Omicron 22B")
print(first_day)
data


In [None]:
def colours_via_commit():
    """Map each short (7-char) commit SHA to a colour so every plot uses the same palette.

    The first entry of `commits` (the final / most recent model) is grey; the others take
    evenly spaced viridis colours indexed by their position in `commits`.
    """
    grey = [.66, .66, .66, 1]
    palette = plt.get_cmap('viridis')(np.linspace(0.9, 0.25, len(commits)))
    return {
        sha[0:7]: (grey if position == 0 else palette[position])
        for position, sha in enumerate(commits)
    }
colours = colours_via_commit()

def pangoise(clade):
    """Translate a Nextstrain Omicron clade label (e.g. "Omicron 22B") to its Pango name (e.g. "BA.5").

    Returns "Unknown" for any clade not in the lookup table.
    """
    nextstrain_to_pango = {
        "Omicron 21K": "BA.1",
        "Omicron 21L": "BA.2",
        "Omicron 22A": "BA.4",
        "Omicron 22B": "BA.5",
        "Omicron 22C": "BA.2.12.1",
        "Omicron 22D": "BA.2.75",
    }
    return nextstrain_to_pango.get(clade, "Unknown")


In [None]:
def plot_forecasts(ax, data, title):
    """Draw every commit's forecast (median + 50/80/95% bands) and retrospective fit on `ax`.

    ax    : matplotlib Axes to draw on.
    data  : merged frame from `prepare_data` (optionally date-subsetted), with `date`,
            `median_freq`, `raw_freq` and `forecast_*_{sha}` / `model_*_{sha}` columns.
    title : axis title.
    Reads the globals `colours` and `commits`.
    """
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)

    ## the "truth" as currently known: the most recent model run on all available data (line)
    ## plus the raw sequence frequencies (dots)
    dates = data["date"]
    ax.plot(dates, data["median_freq"], color='k', zorder=2)
    ax.scatter(dates, data["raw_freq"], s=15, color='k', zorder=2)

    latest_sha = commits[0][0:7]  # first commit is the latest (most recent) model
    grey = [.66, .66, .66, 1]
    forecast_cols = [name for name in data.columns if name.startswith('forecast_median_')]

    # CI bands drawn widest first; all at zorder 1, i.e. below the real data (black lines / dots).
    bands = (('upper95', 'lower95', 0.2), ('upper80', 'lower80', 0.2), ('upper50', 'lower50', 0.4))
    for median_col in forecast_cols:
        sha = median_col.split('_')[-1]  # short len=7 hex
        colour = grey if sha == latest_sha else colours[sha]
        ax.plot(dates, data[median_col], color=colour, zorder=4, lw=2)
        hue, lightness, saturation = colorsys.rgb_to_hls(*colour[0:3])
        lightened = colorsys.hls_to_rgb(hue, 1 - 0.5 * (1 - lightness), saturation)
        for upper, lower, alpha in bands:
            ax.fill_between(dates, data[median_col.replace('median', upper)], data[median_col.replace('median', lower)],
                            color=lightened, alpha=alpha, lw=0, zorder=1)

    ### each run's retrospective fit (dashed): the "truth" according to the data available at that time.
    ### The latest run is skipped because it is already drawn in black.
    for median_col in forecast_cols:
        model_col = median_col.replace("forecast_", "model_")
        if latest_sha in model_col:
            continue
        ax.plot(dates, data[model_col], color=colours[model_col.split('_')[-1]], zorder=1, lw=2, dashes=[2, 2])

    #### x-labels: the first tick gets a full date, the 1st of each month a short date, others blank.
    #### NOTE(review): assumes tick positions index rows of `data` (string dates -> presumably a
    #### categorical axis); confirm if the date dtype ever changes.
    def tick_label(position):
        day = date.fromisoformat(data['date'].iloc[position])
        if position == 0:
            return day.strftime("%d %b %Y")
        return day.strftime("%d %b") if day.day == 1 else ''

    ax.set_xticks(ax.get_xticks())
    ax.set_xticklabels([tick_label(x) for x in ax.get_xticks()])
    ax.set_ylabel("Frequency", size='large')
    ax.set_title(title)


## plot one just to see...

fig, ax = plt.subplots(figsize=(8, 5))
first_day, data = prepare_data("USA", "Omicron 22B")
# Vectorised comparison (ISO date strings sort chronologically) instead of a row-wise apply,
# which is slow and returns a DataFrame rather than a boolean mask when `data` is empty.
data = data[data['date'] >= first_day]
plot_forecasts(ax, data, "BA.5 (22B) frequency in USA")
plt.show()

del first_day, data, fig, ax

In [None]:
def prepare_multiple(locations, variant):
    """
    Helper fn to return a list of prepared datasets, one per location (same order as `locations`).

    Every dataset is trimmed to start at the earliest `first_day` across all locations, so
    the start date differs between variants (plot columns) but is shared within one.
    Returns an empty list when `locations` is empty.
    """
    if not locations:
        return []
    prepared = [prepare_data(location, variant) for location in locations]
    # ISO "YYYY-MM-DD" strings compare chronologically, so min() is the earliest date.
    first_day = min(start for start, _ in prepared)
    print(f"Earliest day for {variant} across {len(locations)} locations: {first_day}")
    ## remove data before the first day when any location had non-trivial frequencies
    return [df[df['date'] >= first_day] for _, df in prepared]


In [None]:
### Plot with rows: various countries, and columns variants

locations = ["USA", "United Kingdom", "New Zealand"]
variants = ["Omicron 21L", "Omicron 22B"]
# squeeze=False keeps `axes` 2-D even with a single location or variant, so axes[r, c] always works.
fig, axes = plt.subplots(ncols=len(variants), nrows=len(locations),
                         figsize=(12*len(variants), 6*len(locations)), squeeze=False)
fig.patch.set_facecolor('white')

for cidx, variant in enumerate(variants):
    datasets = prepare_multiple(locations, variant)
    for ridx, location in enumerate(locations):
        plot_forecasts(axes[ridx, cidx], datasets[ridx], f"{variant} ({pangoise(variant)}) frequency in {location}")

fig.savefig('forecast_evaluation.png', format="png", bbox_inches='tight', transparent=False, pad_inches=0)
plt.show()

del datasets, axes, fig, location, locations, variant, variants, cidx, ridx

## Measure the nowcast error from (later known) truth

For each model run's estimate of past and current frequencies (i.e. its retrospective fit, not its forecast), plot the difference between that estimate and the "true" frequency, where "true" is our current estimate from the latest model run using all available data.


In [None]:
def collect_nowcast_error(data):
    """
    Build a table of each model run's nowcast error relative to the latest model ("truth").

    `data` is a frame from `prepare_data` / `prepare_multiple` with a `median_freq` column and
    one `model_median_{sha}` column per commit. For each such column, the rows where that model
    has an estimate are aligned on `t = row index - index of the model's last estimate`, so t=0
    is the final day of that run's output (assuming one row per consecutive day, as
    `summarise_dates` checks).

    Returns a DataFrame sorted by `t` with a `t` column plus one `error_{sha}` column
    (model median minus current `median_freq`) per commit that has any estimates. With no
    usable model columns, returns an empty frame with only a `t` column.
    """

    def extract_model_error(colname):
        """Return the [t, error_{sha}] frame for one model column, or None if it is all-NaN."""
        commit_short = colname.split('_')[-1]
        error_col = f'error_{commit_short}'
        d = data.loc[data[colname].notnull(), ['median_freq', colname]]
        if d.empty:
            return None
        # assign() produces a new frame, avoiding SettingWithCopyWarning on the filtered slice.
        d = d.assign(**{error_col: d[colname] - d['median_freq']})
        d['t'] = d.index - d.index[-1]
        return d[['t', error_col]]

    model_names = [n for n in data.columns if n.startswith('model_median_')]
    model_errors = [e for e in (extract_model_error(c) for c in model_names) if e is not None]
    if not model_errors:
        return pd.DataFrame(columns=['t'])

    # Start from the first frame so one model column works (indexing [1] would raise).
    df_merged = model_errors[0]
    for df in model_errors[1:]:
        df_merged = pd.merge(df_merged, df, how='outer', on='t')
    return df_merged.sort_values(by=['t'], ascending=True)

## Example: the nowcast-error table for BA.5 (22B) in the USA
collect_nowcast_error(data=prepare_data(location="USA", variant="Omicron 22B")[1])


In [None]:
def plot_nowcast_error(ax, data, title, ymin=False, ymax=False):
    """Plot every commit's nowcast error against `t` on `ax`.

    ax         : matplotlib Axes to draw on.
    data       : output of `collect_nowcast_error` -- a `t` column (rows before the model's final
                 day; expected sorted ascending) plus one `error_{sha}` column per commit.
    title      : axis title.
    ymin, ymax : y-tick range; False means derive it from `data`.
    Reads the global `colours`.
    """
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)

    for error_col in (name for name in data.columns if name != 't'):
        sha = error_col.split('_')[-1]
        ax.plot(data["t"], data[error_col], color=colours[sha], zorder=2)

    ## x-ticks every 7 rows (one week), stepping back from 0 until passing the earliest t.
    ## Relies on `data` being sorted by t so that iloc[0] is the minimum.
    earliest = data['t'].iloc[0]
    week_ticks = [0]
    while week_ticks[-1] > earliest:
        week_ticks.append(week_ticks[-1] - 7)
    ax.set_xticks(week_ticks)
    ax.set_xticklabels([int(tick / 7) for tick in week_ticks])
    ax.set_xlabel('Weeks prior to final day in model output', size='large')

    # y-ticks: 10% steps across [-100%, 100%], kept only within [ymin, ymax].
    if ymax is False:
        ymax = max(data.max().drop('t'))
    if ymin is False:
        ymin = min(data.min().drop('t'))
    yticks = [y for y in np.linspace(-1, 1, 21) if ymin <= y <= ymax]
    ax.set_yticks(yticks)
    ax.set_yticklabels([f'{round(y*100)}%' for y in yticks])
    ax.set_ylabel("∆frequency (from eventual truth)", size='large')

    # Solid line at zero; faint dashed guide lines at every other tick.
    for y in yticks:
        if round(y, 1) == 0:
            ax.axhline(y=0, c='k', zorder=1)
        else:
            ax.axhline(y=y, c='k', alpha=0.1, zorder=1, dashes=[2,6])

    ax.set_title(title, size='x-large')


## plot one just to see...

fig, ax = plt.subplots(figsize=(8, 5))
usa_21l = prepare_data("USA", "Omicron 21L")[1]  # TODO: subset to the first day, as prepare_multiple does?
plot_nowcast_error(ax, collect_nowcast_error(usa_21l), "Nowcast error (USA, 21L)")
plt.show()

del fig, ax, usa_21l



In [None]:
locations = ["USA", "United Kingdom", "New Zealand"]
variant = "Omicron 21L"
datasets = prepare_multiple(locations, variant)
errors = [collect_nowcast_error(df) for df in datasets]
# A shared y-range across all panels keeps them directly comparable.
ymin = min(min(df.min().drop('t')) for df in errors)
ymax = max(max(df.max().drop('t')) for df in errors)

# squeeze=False keeps `axes` a 2-D array even for a single location (nrows=1 would otherwise
# return a bare Axes, which cannot be iterated).
fig, axes = plt.subplots(nrows=len(locations), figsize=(10, 5*len(locations)), squeeze=False)
fig.patch.set_facecolor('white')

for idx, ax in enumerate(axes[:, 0]):
    location = locations[idx]
    plot_nowcast_error(ax, errors[idx], f"Nowcast error ({location}, {variant})", ymin=ymin, ymax=ymax)

fig.savefig(f'nowcast_error.{variant.split()[-1]}.png', format="png", bbox_inches='tight', transparent=False, pad_inches=0)

plt.show()

del datasets, idx, ax, fig, axes, location, locations, variant, errors, ymin, ymax
