# Backtesting Plots for mutation growth rate paper

This notebook generates the plots in the [paper/backtesting](paper/backtesting) directory. It assumes you have already run
```sh
make update                       # Downloads data (~1hour).
make preprocess-usher             # Preprocesses usher tree
make backtesting-complete         # Fits backtesting models
```

# Initialization

In [None]:
#%load_ext autoreload
#%autoreload 2

In [None]:
import datetime
import math
import os
import pickle
import re
import logging
from collections import Counter, OrderedDict, defaultdict
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import pandas as pd
import torch
import pyro.distributions as dist
from pyrocov import mutrans, pangolin, stats
from pyrocov.stats import normal_log10bf
from pyrocov.util import pretty_print, pearson_correlation
import seaborn as sns
import matplotlib.colors as mcolors
import matplotlib.cm as cm
import numpy as np
import seaborn as sns
import colorcet as cc

In [None]:
# Default figure resolution for every plot below.
matplotlib.rcParams.update({"figure.dpi": 200})

In [None]:
# Root logging at INFO; each record is prefixed with the milliseconds elapsed
# since the logging module was loaded (%(relativeCreated)).
# Adjust later with e.g. logging.getLogger().setLevel(logging.DEBUG).
logging.basicConfig(level=logging.INFO, format="%(relativeCreated) 9d %(message)s")

In [None]:
# Shared matplotlib styling: grey axes, tight bounding boxes on save, sans-serif fonts.
matplotlib.rcParams.update({
    "axes.edgecolor": "gray",
    "savefig.bbox": "tight",
    "font.family": "sans-serif",
    "font.sans-serif": ["Arial", "Avenir", "DejaVu Sans"],
})

## Load input data

In [None]:
# Parameters identifying which preprocessed dataset files to load.
max_num_clades = 3000
min_num_mutations = 1
min_region_size = 50
ambiguous = False
columns_filename = "results/columns.{}.pkl".format(max_num_clades)
features_filename = "results/features.{}.{}.pt".format(max_num_clades, min_num_mutations)

In [None]:
# Load the input dataset on the CPU via mutrans.load_gisaid_data.
input_dataset = mutrans.load_gisaid_data(
    device="cpu",
    columns_filename=columns_filename,
    features_filename=features_filename,
    min_region_size=min_region_size,
)

## Load trained backtesting models

In [None]:
# Load the backtesting fits (a mapping from key to fit result) onto the CPU.
# torch.load unpickles the file, so only load files from a trusted source.
fits = torch.load("results/mutrans.backtesting.pt", map_location="cpu")

In [None]:
# Report how many backtesting fits were loaded.
print(f'We have loaded {len(fits)} models')

In [None]:
# List every fit key with the stored 'weekly_clades_shape' of that fit.
for key in fits:
    print(f'{key} -- {fits[key]["weekly_clades_shape"]}')

Scale `coef` by 1/100 in all results.

In [None]:
# ids of objects already visited by scale_tensors. NOTE: this set is recreated
# every time this cell runs, so re-running the whole cell scales the values again.
ALREADY_SCALED = set()

def scale_tensors(x, names={"coef"}, scale=0.01, prefix="", verbose=True):
    """Multiply, in place, every dict value stored under a key in ``names`` by ``scale``.

    Nested dicts are walked recursively, except for values under the key
    "diagnostics", which are left untouched. Only dicts are descended into;
    other containers are treated as leaves. Every visited object's id() is
    recorded in ALREADY_SCALED, so an object reached twice (e.g. a shared
    sub-dict, or a second call on the same dict) is only scaled once.

    Args:
        x: object to process; only dicts are modified.
        names: keys whose values get scaled.
        scale: multiplicative factor.
        prefix: dotted path of ``x``, used for the verbose printout.
        verbose: print the dotted path of every scaled entry.
    """
    if id(x) in ALREADY_SCALED:
        return
    if isinstance(x, dict):
        # Snapshot the items because values are reassigned during the loop.
        for key, value in list(x.items()):
            path = f"{prefix}.{key}"
            if key in names:
                if verbose:
                    print(path)
                x[key] = value * scale
            elif key != "diagnostics":
                scale_tensors(value, names, scale, path, verbose=verbose)
    ALREADY_SCALED.add(id(x))
                
# Scale every "coef" entry of every fit by 0.01 in place. ALREADY_SCALED is reset at the
# top of this cell, so re-running the whole cell applies the scaling a second time.
scale_tensors(fits, verbose=False)

In [None]:
# Output directory for the forecast figures. Note the trailing slash: the
# f-strings that build filenames from it add another "/", giving "//" (harmless).
forecast_dir_prefix = "paper/backtesting/"

# Forecasting

In [None]:
def weekly_clades_to_lineages(weekly_clades, clade_id_to_lineage_id, n_model_lineages):
    """Sum clade counts into lineage counts along the last axis.

    Args:
        weekly_clades: counts whose last axis indexes clades.
        clade_id_to_lineage_id: integer (long) tensor mapping each clade index to
            its lineage index; broadcast against ``weekly_clades``.
        n_model_lineages: size of the output's last axis.

    Returns:
        A new tensor shaped like ``weekly_clades`` except that the last axis has
        ``n_model_lineages`` entries. Each entry is the sum of the clades mapped to it.
    """
    output_shape = weekly_clades.shape[:-1] + (n_model_lineages,)
    lineage_index = clade_id_to_lineage_id.expand_as(weekly_clades)
    lineage_counts = weekly_clades.new_zeros(output_shape)
    lineage_counts.scatter_add_(-1, lineage_index, weekly_clades)
    return lineage_counts

In [None]:
def plusminus(mean, std):
    """Stack (mean - 1.96*std, mean, mean + 1.96*std) along a new leading axis.

    The interval is the +/- 1.96 standard deviation band (about 95% under a
    normal approximation).
    """
    half_width = 1.96 * std
    lower = mean - half_width
    upper = mean + half_width
    return torch.stack((lower, mean, upper))

In [None]:
from pyrocov.util import (
    pretty_print, pearson_correlation, quotient_central_moments, generate_colors
)

In [None]:
def split(a, n):
    """Lazily split the sliceable ``a`` into ``n`` contiguous chunks.

    Chunk sizes differ by at most one, and the first ``len(a) % n`` chunks get the
    extra element (the same partition as numpy.array_split). Chunks may be
    empty when ``n > len(a)``. Raises ZeroDivisionError immediately if ``n == 0``.
    """
    size, remainder = divmod(len(a), n)

    def _chunks():
        start = 0
        for i in range(n):
            stop = start + size + (1 if i < remainder else 0)
            yield a[start:stop]
            start = stop

    return _chunks()

In [None]:
def select_lineages_for_plot(
        weekly_lineages,
        num_lineages,
        lineage_id_inv,
        location_ids, # location ids
        nbins = 10,
        additional_lineages = (),
    ):
    """Return the sorted names of lineages to show in a forecast plot.

    The time axis (dim 0 of ``weekly_lineages``) is cut into ``nbins`` contiguous
    intervals, using the same partition as ``numpy.array_split``. In each non-empty interval,
    lineages are ranked by their total over that interval and over the locations
    in ``location_ids`` (dim 1). The top ``ceil(num_lineages / nbins)`` of each interval are kept.
    Every name in ``additional_lineages`` is always included.

    Args:
        weekly_lineages: tensor indexed [time, location, lineage] (counts or probabilities).
        num_lineages: approximate total number of lineages to keep across all bins.
        lineage_id_inv: list of lineage names; position i names lineage index i.
        location_ids: integer tensor of location indices to aggregate over.
        nbins: number of time intervals.
        additional_lineages: iterable of lineage names that are always included.

    Returns:
        Sorted list of unique lineage names.

    Raises:
        ValueError: if a name in ``additional_lineages`` is not in ``lineage_id_inv``.
    """
    keep_per_bin = math.ceil(num_lineages / nbins)
    num_periods = weekly_lineages.shape[0]
    selected_ids = []
    for interval in np.array_split(np.arange(num_periods), nbins):
        if len(interval) == 0:
            # More bins than periods: an empty interval has all-zero totals, so
            # ranking it would only add arbitrary lineages.
            continue
        # Intervals are contiguous, so a slice selects the same periods.
        window = weekly_lineages[int(interval[0]):int(interval[-1]) + 1]
        totals = window[:, location_ids].sum([0, 1])
        top = totals.sort(-1, descending=True).indices[:keep_per_bin]
        selected_ids.extend(top.tolist())
    # Keep indices as Python ints. The previous torch.cat with torch.tensor([]) produced
    # float indices when additional_lineages was empty, and list indexing then failed.
    selected_ids.extend(lineage_id_inv.index(name) for name in additional_lineages)

    return sorted({lineage_id_inv[i] for i in selected_ids})

In [None]:
# Display the keys available in the loaded input dataset.
input_dataset.keys()

In [None]:
# Display the keys of the first fit's result dict.
fits[list(fits.keys())[0]].keys()

In [None]:
def generate_rainbow(fit, lineage_names, input_dataset):
    """Assign rainbow colours to lineages, ordered by their fitted rate.

    Each lineage's rate is read from ``fit["mean"]["rate"]`` (averaged over dim 0,
    the places) at the clade given by
    ``input_dataset['clade_id'][input_dataset['lineage_to_clade'][name]]``. The
    lowest-rate lineage gets the start of ``cm.rainbow`` and the highest gets the end.

    Args:
        fit: backtesting fit result.
        lineage_names: sequence of lineage names (non-empty).
        input_dataset: dataset providing 'clade_id' and 'lineage_to_clade'.

    Returns:
        dict with 'colors' (list of RGBA tuples aligned with ``lineage_names``),
        'min'/'max' (float extremes of the rates) and 'rates' (1-D tensor aligned
        with ``lineage_names``).
    """
    clade_id = input_dataset['clade_id']
    lineage_to_clade = input_dataset['lineage_to_clade']

    rate = fit["mean"]["rate"].mean(0)  # Mean over places
    rates = torch.stack([
        rate[clade_id[lineage_to_clade[name]]] for name in lineage_names
    ])
    num_colors = len(lineage_names)
    # A single lineage would otherwise divide by zero; it gets the first rainbow colour.
    denominator = max(num_colors - 1, 1)
    colors = [None] * num_colors
    for rank, index in enumerate(rates.sort(0).indices.tolist()):
        colors[index] = cm.rainbow(rank / denominator)
    return {
        'colors': colors,
        'min': rates.min().item(),
        'max': rates.max().item(),
        'rates': rates,
    }

In [None]:
def _matching_location_ids(location_names, queries):
    """Return a long tensor of the positions whose name contains any of ``queries``."""
    return torch.tensor(
        [i for i, name in enumerate(location_names) if any(q in name for q in queries)],
        dtype=torch.long,
    )


def _add_forecast_legends(ax, row, legend_out, num_lineages, light,
                          show_case_counts, show_second_legend):
    """Add the legend(s) for one row of a plot_forecast2 figure.

    Row 0 gets the per-lineage legend. Row 1 gets a key for the line and marker
    styles, but only when case counts are drawn and show_second_legend is set.
    Other rows get no legend.
    """
    if row == 0:
        if legend_out:
            # Lineage legend to the right of the axes.
            ax.legend(loc="upper left", bbox_to_anchor=(1.01, 1.04), fontsize=10)
        else:
            # Shrink the font as the number of lineages grows; max() guards against zero.
            ax.legend(loc="upper left", fontsize=8 * (10 / max(num_lineages, 1)) ** 0.8)
    elif row == 1 and show_case_counts and show_second_legend:
        # Empty proxy artists that only exist to populate the style key.
        ax.plot([], "--", color=light, lw=1, label="relative #samples")
        ax.plot([], "-", color=light, lw=0.8, label="relative #cases")
        ax.plot([], lw=0, marker='o', markersize=3, color='gray',
                label="observed portion")
        ax.fill_between([], [], [], color='gray', label="predicted portion")
        ax.legend(loc="upper left")


def plot_forecast2(fit, input_dataset, queries, num_lineages=10, filenames=(),
                   verbose=False, additional_lineages = ['BA.2'], nbins=5,
                   legend_out=False, figsize_x = None, figsize_y = None,
                   auto_select_lineages = True, colors_dict_export = None,
                   show_legend = True, show_case_counts = True, show_second_legend = True):
    """Plot observed and forecast lineage proportions for one backtesting fit.

    One row of axes is drawn per query. Each row pools all locations whose name
    contains the query string and shows:
      * the fit's mean predicted lineage proportions (lines) with a +/-1.96 std band,
      * proportions observed in the fit's own ``weekly_clades`` ('o' markers),
      * proportions observed in ``input_dataset`` after the fit's last period
        ('x' markers),
      * with more than one query and ``show_case_counts``, the relative weekly case
        and sample totals (grey lines).
    The forecast window is shaded grey.

    Args:
        fit: one backtesting result. Reads 'mean'/'std' 'probs', 'weekly_cases',
            'weekly_clades', 'location_id_inv', 'lineage_id_inv'; 'mean' 'rate' is
            also read through generate_rainbow when auto-selecting.
        input_dataset: full dataset. Reads 'weekly_clades', 'clade_id_to_lineage_id',
            'location_id_inv', 'lineage_id_inv' (plus what generate_rainbow reads).
        queries: a location substring or a list of them, one per row.
        num_lineages: approximate number of lineages per auto-selection pass.
        filenames: paths the figure is saved to.
        verbose: print intermediate shapes.
        additional_lineages: names always included when auto-selecting. When
            ``auto_select_lineages`` is False, these are exactly the lineages plotted.
        nbins: number of time bins used by the auto-selection.
        legend_out: put the lineage legend outside the axes.
        figsize_x, figsize_y: figure size; defaults 8 and 0.5 + 2.5 * len(queries).
        auto_select_lineages: if True, choose lineages from the fit and the data and
            colour them by rank of fitted rate. Otherwise plot ``additional_lineages``
            with colours from ``colors_dict_export``.
        colors_dict_export: lineage name -> colour. Required when not auto-selecting.
        show_legend, show_case_counts, show_second_legend: display toggles.

    Returns:
        dict with 'lineages_plotted', 'colors_dict', 'rates' (None unless
        auto-selected; otherwise sorted by decreasing rate), 'ax' (the last axes)
        and 'fig'.

    Raises:
        ValueError: if a requested lineage name is unknown.
    """
    # Accept a single query string as well as a list of them.
    if isinstance(queries, str):
        queries = [queries]

    # Model dimensions: predicted probabilities are indexed (time, place, lineage).
    n_model_periods, n_model_places, n_model_lineages = fit['mean']['probs'].shape
    if verbose:
        print('---')
        print(f'n_model_periods: {n_model_periods}')
        print(f'n_model_places: {n_model_places}')
        print(f'n_model_lineages: {n_model_lineages}')

    # Weekly case counts seen by the fit, indexed (time, place). They end where the forecast starts.
    weekly_cases_fit = fit['weekly_cases']
    n_cases_periods, n_cases_places = weekly_cases_fit.shape
    if verbose:
        print('---')
        print(f'n_cases_periods: {n_cases_periods}')
        print(f'n_cases_places: {n_cases_places}')

    assert n_cases_places == n_model_places
    assert n_model_periods > n_cases_periods

    # Number of periods predicted beyond the data the model was fit on.
    n_forecast_steps = n_model_periods - n_cases_periods
    if verbose:
        print(f'n_forecast_steps: {n_forecast_steps}')

    # Clade counts used by the fit, and the (longer, wider) counts of the full dataset.
    weekly_clades_fit = fit['weekly_clades']
    if verbose:
        print('---')
        print(f'weekly_clades_fit shape: {weekly_clades_fit.shape}')
    weekly_clades_data = input_dataset['weekly_clades']
    if verbose:
        print('---')
        print(f'weekly_clades_data shape: {weekly_clades_data.shape}')

    # The fit does not store its own clade -> lineage map; the dataset's map is assumed to apply to both.
    clade_id_to_lineage_id = input_dataset['clade_id_to_lineage_id']
    if verbose:
        print('---')
        print(f'clade_id_to_lineage_id length: {len(clade_id_to_lineage_id)}')

    weekly_lineages_data = weekly_clades_to_lineages(weekly_clades_data, clade_id_to_lineage_id, n_model_lineages)
    weekly_lineages_fit = weekly_clades_to_lineages(weekly_clades_fit, clade_id_to_lineage_id, n_model_lineages)

    # (lower, mean, upper) probabilities stacked on a new leading axis.
    probs = plusminus(fit['mean']['probs'], fit['std']['probs'])

    # Extend the case counts over the forecast window with their time average, then
    # weight the probabilities by them so that several locations can be pooled below.
    padding = 1 + weekly_cases_fit.mean(0, keepdim=True).expand(n_forecast_steps, -1)
    weekly_cases_fit_ = torch.cat([weekly_cases_fit, padding], 0)
    weekly_cases_fit_.add_(10)
    predicted = probs * weekly_cases_fit_[..., None]

    location_id_inv_data = input_dataset['location_id_inv']
    if verbose:
        print('---')
        print(f'location_id_inv_data length: {len(location_id_inv_data)}')
    location_id_inv_fit = fit['location_id_inv']
    if verbose:
        print('---')
        print(f'location_id_inv_fit length: {len(location_id_inv_fit)}')

    # Lineage indices are shared between fit and data, so their labels must agree.
    lineage_id_inv_fit = fit['lineage_id_inv']
    lineage_id_inv_data = input_dataset['lineage_id_inv']
    assert lineage_id_inv_fit == lineage_id_inv_data
    lineage_name_to_index_map_data = {l: i for i, l in enumerate(lineage_id_inv_data)}
    lineage_name_to_index_map_fit = {l: i for i, l in enumerate(lineage_id_inv_fit)}

    if auto_select_lineages:
        ids_fit = _matching_location_ids(location_id_inv_fit, queries)
        ids_data = _matching_location_ids(location_id_inv_data, queries)
        selection_options = dict(num_lineages=num_lineages, nbins=nbins,
                                 additional_lineages=additional_lineages)
        # Union of lineages prominent in the fit's counts, the fit's predictions and the data.
        selected = set(select_lineages_for_plot(
            weekly_lineages=weekly_lineages_fit, lineage_id_inv=lineage_id_inv_fit,
            location_ids=ids_fit, **selection_options))
        selected |= set(select_lineages_for_plot(
            weekly_lineages=fit['mean']['probs'], lineage_id_inv=lineage_id_inv_fit,
            location_ids=ids_fit, **selection_options))
        selected |= set(select_lineages_for_plot(
            weekly_lineages=weekly_lineages_data, lineage_id_inv=lineage_id_inv_data,
            location_ids=ids_data, **selection_options))
        plot_lineages_ids_inv_joint = sorted(selected)

        colors_dict = generate_rainbow(fit, plot_lineages_ids_inv_joint, input_dataset)
        colors = colors_dict['colors']
        rates = colors_dict['rates']
        # Plot and legend order: highest rate first.
        order_perm = rates.argsort().numpy()[::-1]

        # The union may contain more lineages than requested.
        num_lineages = len(plot_lineages_ids_inv_joint)

        lineage_ids_fit = list(map(lineage_name_to_index_map_fit.get, plot_lineages_ids_inv_joint))
        lineage_ids_data = list(map(lineage_name_to_index_map_data.get, plot_lineages_ids_inv_joint))
        assert lineage_ids_fit == lineage_ids_data

        plot_lineages_ids_inv_joint = np.asarray(plot_lineages_ids_inv_joint)[order_perm]
        lineage_ids_data = np.asarray(lineage_ids_data)[order_perm]
        lineage_ids_fit = np.asarray(lineage_ids_fit)[order_perm]
        colors = np.asarray(colors)[order_perm]
        colors_dict_export = {l: c for l, c in zip(plot_lineages_ids_inv_joint, colors)}
        rates = np.asarray(rates)[order_perm]
    else:
        rates = None
        assert colors_dict_export is not None
        assert additional_lineages is not None

        plot_lineages_ids_inv_joint = additional_lineages
        num_lineages = len(plot_lineages_ids_inv_joint)

        unknown = [l for l in plot_lineages_ids_inv_joint if l not in lineage_name_to_index_map_fit]
        if unknown:
            raise ValueError(f"Unknown lineages requested: {unknown}")
        lineage_ids_fit = list(map(lineage_name_to_index_map_fit.get, plot_lineages_ids_inv_joint))
        lineage_ids_data = list(map(lineage_name_to_index_map_data.get, plot_lineages_ids_inv_joint))
        assert lineage_ids_fit == lineage_ids_data

        # Reuse the colours of an earlier auto-selected plot.
        colors = list(map(colors_dict_export.get, plot_lineages_ids_inv_joint))

    assert len(colors) >= num_lineages
    light = '#bbbbbb'

    if figsize_x is None:
        figsize_x = 8
    if figsize_y is None:
        figsize_y = 0.5 + 2.5 * len(queries)

    fig, axes = plt.subplots(len(queries), figsize=(figsize_x, figsize_y), sharex=True)
    if not isinstance(axes, (list, np.ndarray)):
        axes = [axes]

    dates = matplotlib.dates.date2num(mutrans.date_range(len(fit["mean"]["probs"])))

    for row, (query, ax) in enumerate(zip(queries, axes)):
        # Locations pooled into this row (a query may match several regions).
        ids_fit = _matching_location_ids(location_id_inv_fit, [query])
        if verbose:
            print('---')
            print(f"{query} matched {len(ids_fit)} regions in the fit")
        ids_data = _matching_location_ids(location_id_inv_data, [query])

        if len(axes) > 1 and show_case_counts:
            # Relative weekly case totals (solid) and sample totals (dashed).
            counts = weekly_cases_fit[:, ids_fit].sum(1)
            if verbose:
                print(f"{query}: max {counts.max():g}, total {counts.sum():g}")
            counts /= counts.max()
            ax.plot(dates[:len(counts)], counts, "-", color=light, lw=0.8, zorder=-20)

            counts = weekly_lineages_fit[:, ids_fit].sum([1, 2])
            counts /= counts.max()
            ax.plot(dates[:len(counts)], counts, "--", color=light, lw=1, zorder=-20)

        # Pooled predictions, normalised so that the mean (index 1) sums to 1 over lineages.
        pred = predicted.index_select(-2, ids_fit).sum(-2)
        pred /= pred[1].sum(-1, True).clamp_(min=1e-20)

        # Observed proportions in the fit's data and in the full dataset.
        obs = weekly_lineages_fit[:, ids_fit].sum(1)
        obs /= obs.sum(-1, True).clamp_(min=1e-9)
        obs_data = weekly_lineages_data[:, ids_data].sum(1)
        obs_data /= obs_data.sum(-1, True).clamp(min=1e-9)
        n_obs = len(obs)

        for s, color in zip(lineage_ids_fit, colors):
            lb, mean, ub = pred[..., s]
            ax.fill_between(dates, lb, ub, color=color, alpha=0.2, zorder=-10)
            ax.plot(dates, mean, color=color, lw=1, zorder=-9)
            ax.plot(dates[:n_obs], obs[:, s], color=color, lw=0, marker='o', markersize=3,
                    label=lineage_id_inv_fit[s] if row == 0 else None)

        # Later observations from the full dataset over the forecast window.
        max_time_step = min(n_obs + n_forecast_steps, obs_data.shape[0] - 1)
        for s, color in zip(lineage_ids_data, colors):
            ax.plot(dates[n_obs:max_time_step], obs_data[n_obs:max_time_step, s], label='_nolegend_',
                    color=color, lw=0, marker='x', markersize=2)

        # Mark and shade the forecast window.
        ax.axvline(dates[n_obs], linestyle='--', lw=1, color=(0.5, 0.5, 0.5))
        ax.axvspan(dates[n_obs], dates[n_obs + n_forecast_steps - 1], color=(0.5, 0.5, 0.5), alpha=0.2)

        ax.set_ylim(0, 1)
        ax.set_yticks(())
        ax.set_ylabel(query.replace(" / ", "\n"))
        ax.set_xlim(dates.min(), dates.max())

        if show_legend:
            _add_forecast_legends(ax, row, legend_out, num_lineages, light,
                                  show_case_counts, show_second_legend)

    # Month ticks on the shared date axis (set on the last row).
    ax.xaxis.set_major_locator(matplotlib.dates.MonthLocator())
    ax.xaxis.set_major_formatter(matplotlib.dates.DateFormatter("%b %Y"))
    plt.xticks(rotation=90)
    plt.subplots_adjust(hspace=0)

    for filename in filenames:
        plt.savefig(filename, bbox_inches='tight')

    return {
        'lineages_plotted': plot_lineages_ids_inv_joint,
        'colors_dict': colors_dict_export,
        'rates': rates,
        'ax': ax,
        'fig': fig,
    }

### Initial run of the second-to-last model to select lineages and colors

In [None]:
# Second-to-last fit. Its auto-selected lineages and colours are reused by the figures below.
k = list(fits)[-2]
print(k[9])
fit_n = fits[k]
plot_forecast_results = plot_forecast2(
    fit_n,
    input_dataset,
    queries=["England"],
    num_lineages=13,
    verbose=False,
    additional_lineages=['BA.1'],
)

### Barplot for rates

In [None]:
# Bar chart of exp(rate) per plotted lineage, in the colours used by the forecast plots.
xs = list(range(plot_forecast_results['rates'].shape[0]))
ys = np.exp(list(plot_forecast_results['rates']))

fig, ax = plt.subplots()
ax.bar(
    x=xs,
    height=ys,
    color=[plot_forecast_results['colors_dict'].get(name)
           for name in plot_forecast_results['lineages_plotted']],
)
ax.set_xticks(xs)
ax.set_xticklabels(plot_forecast_results['lineages_plotted'].tolist(), rotation=90)
ax.set_ylabel('$R_{lineage}/R_A$')
plt.savefig('paper/backtesting/barplot_rates_inset.pdf')

## Prediction for Fig S6 (AY.4 and BA.1)

In [None]:
# Panel size used by the Fig S6 prediction cells; panel widths are fractions of the x size.
prediction_figsize_x, prediction_figsize_y = 12, 2

In [None]:
# Fig S6 panel: forecast from fit 15, with the lineages and colours chosen in the initial run.
k = list(fits)[15]
print(k[9])
fit_n = fits[k]
plot_forecast2(
    fit_n,
    input_dataset,
    queries=["England"],
    num_lineages=14,
    verbose=False,
    filenames=[
        f'{forecast_dir_prefix}/backtesting_day_{k[9]}_early_prediction_england.{ext}'
        for ext in ('png', 'pdf')
    ],
    figsize_x=360 / 752 * prediction_figsize_x,
    figsize_y=prediction_figsize_y,
    additional_lineages=plot_forecast_results['lineages_plotted'],
    colors_dict_export=plot_forecast_results['colors_dict'],
    auto_select_lineages=False,
    legend_out=True,
);

In [None]:
# Fig S6 panel: forecast from fit 27, with the lineages and colours chosen in the initial run.
k = list(fits)[27]
print(k[9])
fit_n = fits[k]
plot_forecast2(
    fit_n,
    input_dataset,
    queries=["England"],
    num_lineages=13,
    verbose=False,
    filenames=[
        f'{forecast_dir_prefix}/backtesting_day_{k[9]}_early_prediction_england.{ext}'
        for ext in ('png', 'pdf')
    ],
    figsize_x=528 / 752 * prediction_figsize_x,
    figsize_y=prediction_figsize_y,
    additional_lineages=plot_forecast_results['lineages_plotted'],
    colors_dict_export=plot_forecast_results['colors_dict'],
    auto_select_lineages=False,
    legend_out=True,
);

In [None]:
# Fig S6 panel: forecast from the third-to-last fit, with the lineages and colours chosen in the initial run.
k = list(fits)[-3]
print(k[9])
fit_n = fits[k]
plot_forecast2(
    fit_n,
    input_dataset,
    queries=["England"],
    num_lineages=13,
    verbose=False,
    filenames=[
        f'{forecast_dir_prefix}/backtesting_day_{k[9]}_early_prediction_england.{ext}'
        for ext in ('png', 'pdf')
    ],
    figsize_x=738 / 752 * prediction_figsize_x,
    figsize_y=prediction_figsize_y,
    additional_lineages=plot_forecast_results['lineages_plotted'],
    colors_dict_export=plot_forecast_results['colors_dict'],
    auto_select_lineages=False,
    legend_out=True,
);

In [None]:
# Fig S6 panel: forecast from the second-to-last fit at full panel width.
k = list(fits)[-2]
print(k[9])
fit_n = fits[k]
plot_forecast2(
    fit_n,
    input_dataset,
    queries=["England"],
    num_lineages=13,
    verbose=False,
    filenames=[
        f'{forecast_dir_prefix}/backtesting_day_{k[9]}_early_prediction_england.{ext}'
        for ext in ('png', 'pdf')
    ],
    figsize_x=prediction_figsize_x,
    figsize_y=prediction_figsize_y,
    additional_lineages=plot_forecast_results['lineages_plotted'],
    colors_dict_export=plot_forecast_results['colors_dict'],
    auto_select_lineages=False,
    legend_out=True,
);

# Forecast Evaluation

In [None]:
def _first_index_positions(names, wanted):
    """Long tensor with the position of the first occurrence in ``names`` of each name in ``wanted``."""
    first = {}
    for i, name in enumerate(names):
        first.setdefault(name, i)
    return torch.tensor([first[name] for name in wanted], dtype=torch.long)


def evaluate_forecast2(fit, input_dataset, queries, num_lineages=10, filenames=(),
                       verbose=False, data_region = None):
    """Score one backtesting fit's forecast against later observed data.

    The fit's mean lineage probabilities for the forecast window (the periods after
    the last period of ``fit['weekly_cases']``) are compared with the empirical
    lineage proportions in ``input_dataset`` for the same periods. The comparison uses
    every location present in both the fit and the dataset, optionally restricted to
    ``data_region``.

    A location with no samples in a period has all-zero empirical proportions, so it
    contributes roughly sum(probs) = 1 to that period's L1 error.

    Args:
        fit: backtesting result. Reads 'mean' 'probs' (time, place, lineage),
            'weekly_cases', 'location_id_inv', 'lineage_id_inv' (and 'weekly_clades'
            when verbose).
        input_dataset: full dataset. Reads 'weekly_clades', 'clade_id_to_lineage_id',
            'location_id_inv', 'lineage_id_inv'.
        queries, num_lineages, filenames: accepted for backward compatibility; they
            do not affect the result.
        verbose: print intermediate shapes.
        data_region: optional iterable of location names to restrict the evaluation to.

    Returns:
        dict with 'L1_error' and 'L2_error'. Each is a 1-D tensor with one value per
        forecast period that has data in ``input_dataset``:
          L1 = sum over locations and lineages of |p - q|, divided by #locations;
          L2 = sqrt(sum over locations and lineages of (p - q)^2), divided by #locations.
        If no location qualifies, the values are NaN.
    """
    n_model_periods, n_model_places, n_model_lineages = fit['mean']['probs'].shape
    if verbose:
        print('---')
        print(f'n_model_periods: {n_model_periods}')
        print(f'n_model_places: {n_model_places}')
        print(f'n_model_lineages: {n_model_lineages}')

    # Case counts seen by the fit; their length marks where the forecast begins.
    weekly_cases_fit = fit['weekly_cases']
    n_cases_periods, n_cases_places = weekly_cases_fit.shape
    if verbose:
        print('---')
        print(f'n_cases_periods: {n_cases_periods}')
        print(f'n_cases_places: {n_cases_places}')

    assert n_cases_places == n_model_places
    assert n_model_periods > n_cases_periods

    n_forecast_steps = n_model_periods - n_cases_periods
    if verbose:
        print(f'n_forecast_steps: {n_forecast_steps}')
        print('---')
        print(f"weekly_clades_fit shape: {fit['weekly_clades'].shape}")

    weekly_clades_data = input_dataset['weekly_clades']
    clade_id_to_lineage_id = input_dataset['clade_id_to_lineage_id']
    if verbose:
        print('---')
        print(f'weekly_clades_data shape: {weekly_clades_data.shape}')
        print('---')
        print(f'clade_id_to_lineage_id length: {len(clade_id_to_lineage_id)}')

    weekly_lineages_data = weekly_clades_to_lineages(weekly_clades_data, clade_id_to_lineage_id, n_model_lineages)

    location_id_inv_data = input_dataset['location_id_inv']
    location_id_inv_fit = fit['location_id_inv']
    if verbose:
        print('---')
        print(f'location_id_inv_data length: {len(location_id_inv_data)}')
        print('---')
        print(f'location_id_inv_fit length: {len(location_id_inv_fit)}')

    # Lineage axes of fit and data are subtracted below, so their labels must match.
    assert fit['lineage_id_inv'] == input_dataset['lineage_id_inv']

    common_regions = set(location_id_inv_fit) & set(location_id_inv_data)
    if data_region is not None:
        common_regions &= set(data_region)
    # Sorted so the region order (and float summation order) is reproducible.
    common_regions = sorted(common_regions)
    fit_ids = _first_index_positions(location_id_inv_fit, common_regions)
    data_ids = _first_index_positions(location_id_inv_data, common_regions)

    # Predicted probabilities for the forecast window only.
    probs = fit['mean']['probs'][n_cases_periods:].index_select(1, fit_ids)

    # Empirical proportions over the same periods and locations.
    obs_data = weekly_lineages_data[n_cases_periods:n_cases_periods + n_forecast_steps].index_select(1, data_ids)
    empirical_probs = obs_data / obs_data.sum(-1, True).clamp_(min=1e-9)

    # The dataset may end before the forecast window does.
    probs = probs[:empirical_probs.shape[0]]

    n_regions = probs.shape[-2]
    diff = probs - empirical_probs
    l1_error = diff.abs().sum([-1, -2]) / n_regions
    l2_error = diff.pow(2).sum([-1, -2]).sqrt() / n_regions

    return {
        'L1_error': l1_error,
        'L2_error': l2_error,
    }

In [None]:
def generate_forecast_eval(fits, input_dataset, data_region = None, queries = None, period_length=14):
    """Collect per-period forecast errors of every backtesting fit into a DataFrame.

    Args:
        fits: mapping from fit key to fit result. Element 9 of each key is used as the
            forecast start day (presumably a day offset; TODO confirm its meaning).
        input_dataset: full dataset passed on to evaluate_forecast2.
        data_region: optional iterable of location names to restrict the evaluation to.
        queries: location queries passed on to evaluate_forecast2. Defaults to all
            dataset locations.
        period_length: days per forecast period, used to compute 'day_of_forecast'
            (default 14, the 2-week periods used by the plots).

    Returns:
        DataFrame with one row per (fit, forecast period) and columns
        'forecast_start_days', 'period_forecast_ahead' (1-based), 'l1_error',
        'l2_error' and 'day_of_forecast' (= start day + period * period_length).
    """
    if queries is None or len(queries) == 0:
        queries = input_dataset['location_id_inv']

    forecast_start_days = []
    period_forecast_ahead = []
    l1_error = []
    l2_error = []

    for key, fit_n in fits.items():
        forecast_start_day = key[9]
        forecast_error = evaluate_forecast2(
            fit_n,
            input_dataset,
            queries=queries,
            num_lineages=100,
            data_region=data_region,
            verbose=False,
        )
        l1_values = forecast_error['L1_error'].tolist()
        l2_values = forecast_error['L2_error'].tolist()
        n_periods_forecast = len(l1_values)
        forecast_start_days.extend([forecast_start_day] * n_periods_forecast)
        period_forecast_ahead.extend(range(1, n_periods_forecast + 1))
        l1_error.extend(l1_values)
        l2_error.extend(l2_values)

    df1 = pd.DataFrame({
        'forecast_start_days': forecast_start_days,
        'period_forecast_ahead': period_forecast_ahead,
        'l1_error': l1_error,
        'l2_error': l2_error,
    })
    # Previously hard-coded as 14 while an unused period_length = 14 sat above it.
    df1['day_of_forecast'] = df1['forecast_start_days'] + df1['period_forecast_ahead'] * period_length

    return df1

## England forecasting

In [None]:
# Evaluate forecasts for England only.
regions = ['Europe / United Kingdom / England']

In [None]:
# Per-fit, per-period forecast errors restricted to `regions`.
top_region_forecast = generate_forecast_eval(fits, input_dataset, data_region = regions)

In [None]:
# Reset matplotlib rc parameters through seaborn before drawing the box plots.
# NOTE(review): seaborn restores the rc values captured when it was imported, so this
# likely also discards the dpi/font settings made earlier in this notebook — confirm intended.
sns.reset_orig()

In [None]:
# L1 forecast error by forecast horizon for England; also sets the default size for later figures.
plt.rcParams.update({'figure.figsize': (4, 3)})
ax = sns.boxplot(data=top_region_forecast, x="period_forecast_ahead", y="l1_error", color='gray')
ax.set_xlabel('2-week period forecast ahead')
ax.set_ylabel("L1 Error")
ax.set_ylim([0.0, 2.0])
plt.savefig('paper/backtesting/L1_error_barplot_England.pdf', bbox_inches='tight')

## Brazil forecasting

In [None]:
# All dataset locations whose name contains "Brazil".
regions = [name for name in input_dataset['location_id_inv'] if 'Brazil' in name]

In [None]:
# Per-fit, per-period forecast errors restricted to `regions`.
top_region_forecast = generate_forecast_eval(fits, input_dataset, data_region = regions)

In [None]:
# L1 forecast error by forecast horizon for the Brazil locations.
plt.rcParams.update({'figure.figsize': (4, 3)})
ax = sns.boxplot(data=top_region_forecast, x="period_forecast_ahead", y="l1_error", color='gray')
ax.set_xlabel('2-week period forecast ahead')
ax.set_ylabel("L1 Error")
ax.set_ylim([0.0, 2.0])
plt.savefig('paper/backtesting/L1_error_barplot_Brazil.pdf', bbox_inches='tight')

## Massachusetts forecasting

In [None]:
# All dataset locations whose name contains "Massachusetts".
regions = [name for name in input_dataset['location_id_inv'] if 'Massachusetts' in name]

In [None]:
# Per-fit, per-period forecast errors restricted to `regions`.
top_region_forecast = generate_forecast_eval(fits, input_dataset, data_region = regions)

In [None]:
# L1 forecast error by forecast horizon for the Massachusetts locations.
plt.rcParams.update({'figure.figsize': (4, 3)})
ax = sns.boxplot(data=top_region_forecast, x="period_forecast_ahead", y="l1_error", color='gray')
ax.set_xlabel('2-week period forecast ahead')
ax.set_ylabel("L1 Error")
ax.set_ylim([0.0, 2.0])
plt.savefig('paper/backtesting/L1_error_barplot_Massachusetts.pdf', bbox_inches='tight')

## All region forecasting

In [None]:
# Number of backtesting fits being evaluated.
len(fits)

In [None]:
# No data_region filter: every location shared by a fit and the dataset is scored.
all_region_forecast = generate_forecast_eval(fits, input_dataset)

In [None]:
# L1 forecast error by forecast horizon, computed over all shared regions.
ax = sns.boxplot(data=all_region_forecast, x="period_forecast_ahead", y="l1_error", color='gray')
ax.set_xlabel('2-week period forecast ahead')
ax.set_ylabel("L1 Error")
ax.set_ylim([0.0, 2.0])
plt.savefig('paper/backtesting/L1_error_barplot_all.png')

## Top 100 region forecasting

In [None]:
# Get top covered regions
# The 100 locations with the most samples, summing weekly_clades over dims 0 and 2.
top_region_idx = (
    input_dataset['weekly_clades']
    .sum([0, 2])
    .sort(-1, descending=True)
    .indices[:100]
    .tolist()
)
regions = [input_dataset['location_id_inv'][i] for i in top_region_idx]

In [None]:
# Per-fit, per-period forecast errors restricted to `regions`.
top_region_forecast = generate_forecast_eval(fits, input_dataset, data_region = regions)

In [None]:
# L1 forecast error by forecast horizon for the 100 best-sampled locations.
ax = sns.boxplot(data=top_region_forecast, x="period_forecast_ahead", y="l1_error", color='gray')
ax.set_xlabel('2-week period forecast ahead')
ax.set_ylabel("L1 Error")
ax.set_ylim([0.0, 2.0])
plt.savefig('paper/backtesting/L1_error_barplot_top100.png')

## Top 100-1000 region forecasting

In [None]:
# Get top covered regions
# Locations ranked 101-1000 by total sample count (weekly_clades summed over dims 0 and 2).
top_region_idx = (
    input_dataset['weekly_clades']
    .sum([0, 2])
    .sort(-1, descending=True)
    .indices[100:1000]
    .tolist()
)
regions = [input_dataset['location_id_inv'][i] for i in top_region_idx]

In [None]:
# Per-fit, per-period forecast errors restricted to `regions`.
top_region_forecast = generate_forecast_eval(fits, input_dataset, data_region = regions)

In [None]:
# L1 forecast error by forecast horizon for the locations ranked 101-1000.
ax = sns.boxplot(data=top_region_forecast, x="period_forecast_ahead", y="l1_error", color='grey')
ax.set_xlabel('2-week period forecast ahead')
ax.set_ylabel("L1 Error")
ax.set_ylim([0.0, 2.0])
plt.savefig('paper/backtesting/L1_error_barplot_top100-1000.png')

## Evaluation of forecasting accuracy

 - Goal: for a given region, and across all backtesting models, compute the percentage of cases in which we predict the correct strain *n* intervals ahead.


In [None]:
def evaluate_forecast3(fit, input_dataset, queries, num_lineages=10, verbose=False, data_region=None):
    """Pair a backtesting fit's forecast probabilities with observed lineage proportions.

    Only the forecast horizon is compared. That is the periods after the fit's
    training window, whose length is given by ``fit['weekly_cases']``. Comparison is
    restricted to locations present in both the fit and ``input_dataset``, and
    optionally to ``data_region``.

    Args:
        fit: One backtesting fit. Reads ``fit['mean']['probs']`` (3-d:
            periods x places x lineages), ``'weekly_cases'``, ``'weekly_clades'``,
            ``'location_id_inv'`` and ``'lineage_id_inv'``.
        input_dataset: Full dataset. Reads ``'weekly_clades'``,
            ``'clade_id_to_lineage_id'``, ``'location_id_inv'`` and
            ``'lineage_id_inv'``.
        queries: Location-name substring or list of substrings. Only used to
            rank the top lineages in the verbose report.
        num_lineages: Number of top lineages shown in the verbose report.
        verbose: If True, print shapes and the lineage ranking.
        data_region: Optional iterable of location names. When given, only these
            locations are evaluated.

    Returns:
        dict with ``'probs'`` (predicted) and ``'empirical_probs'`` (observed,
        normalised per location). Both are indexed [forecast period, location,
        lineage] over the same locations and the same number of periods. The last
        observed period is dropped from both, as in the original truncation.
        NOTE(review): presumably because that period may be incompletely
        sampled -- confirm.

    Note:
        A later cell in this notebook redefines ``evaluate_forecast3`` with a
        different signature. Re-run this cell before calling
        ``generate_forecast_eval_percent`` again after that point.
    """
    # Accept a single location string as well as a list of strings.
    if isinstance(queries, str):
        queries = [queries]

    probs = fit['mean']['probs']
    n_model_periods, n_model_places, n_model_lineages = probs.shape
    # weekly_cases only covers the training window, so its length marks where
    # the forecast starts.
    n_cases_periods, n_cases_places = fit['weekly_cases'].shape
    if verbose:
        print('---')
        print(f'n_model_periods: {n_model_periods}')
        print(f'n_model_places: {n_model_places}')
        print(f'n_model_lineages: {n_model_lineages}')
        print('---')
        print(f'n_cases_periods: {n_cases_periods}')
        print(f'n_cases_places: {n_cases_places}')

    assert n_cases_places == n_model_places
    assert n_model_periods > n_cases_periods
    # Number of periods predicted beyond the training window.
    n_forecast_steps = n_model_periods - n_cases_periods

    clade_id_to_lineage_id = input_dataset['clade_id_to_lineage_id']
    location_id_inv_data = input_dataset['location_id_inv']
    location_id_inv_fit = fit['location_id_inv']
    if verbose:
        print(f'n_forecast_steps: {n_forecast_steps}')
        print('---')
        print(f"weekly_clades_fit shape: {fit['weekly_clades'].shape}")
        print('---')
        print(f"weekly_clades_data shape: {input_dataset['weekly_clades'].shape}")
        print('---')
        print(f'clade_id_to_lineage_id length: {len(clade_id_to_lineage_id)}')
        print('---')
        print(f'location_id_inv_data length: {len(location_id_inv_data)}')
        print('---')
        print(f'location_id_inv_fit length: {len(location_id_inv_fit)}')

        # Rank lineages by total fitted counts over the queried locations.
        # This is diagnostic only, so it runs only in verbose mode. It is costly
        # when `queries` lists every location, as generate_forecast_eval_percent does.
        weekly_lineages_fit = weekly_clades_to_lineages(
            fit['weekly_clades'], clade_id_to_lineage_id, n_model_lineages)
        ids_fit = torch.tensor(
            [i for i, name in enumerate(location_id_inv_fit) if any(q in name for q in queries)],
            dtype=torch.long,  # an empty list would otherwise become a float tensor
        )
        lineage_ids_fit = weekly_lineages_fit[:, ids_fit].sum([0, 1]).sort(-1, descending=True).indices
        print('---')
        print(f'lineage_ids_fit shape: {lineage_ids_fit.shape}')
        print(f'top {num_lineages} lineage ids: {lineage_ids_fit[:num_lineages].tolist()}')

    # Lineage labels must agree so that lineage indices mean the same in both.
    assert fit['lineage_id_inv'] == input_dataset['lineage_id_inv']

    # Collapse clade counts to the model's lineages.
    weekly_lineages_data = weekly_clades_to_lineages(
        input_dataset['weekly_clades'], clade_id_to_lineage_id, n_model_lineages)

    # Locations present in both the fit and the data, optionally restricted further.
    common = set(location_id_inv_fit).intersection(location_id_inv_data)
    if data_region is not None:
        common &= set(data_region)
    # Sorted so the location order (and float summation order) is reproducible.
    common_regions = sorted(common)

    # name -> first index (same result as list.index, but O(1) per lookup).
    fit_pos, data_pos = {}, {}
    for i, name in enumerate(location_id_inv_fit):
        fit_pos.setdefault(name, i)
    for i, name in enumerate(location_id_inv_data):
        data_pos.setdefault(name, i)
    fit_idx = [fit_pos[r] for r in common_regions]
    data_idx = [data_pos[r] for r in common_regions]

    # Predicted vs observed proportions over the forecast horizon only.
    forecast_probs = probs[n_cases_periods:, fit_idx, :]
    obs_data = weekly_lineages_data[n_cases_periods:n_cases_periods + n_forecast_steps, data_idx, :]
    # Normalise per location; the clamp avoids 0/0 where nothing was observed.
    empirical_probs = obs_data / obs_data.sum(-1, True).clamp_(min=1e-9)

    # Truncate both to the available observed periods minus the last one.
    # Clamping at 0 avoids a negative slice (which would keep almost everything)
    # when no data exist beyond the training window.
    n_eval = max(empirical_probs.shape[0] - 1, 0)
    return {
        'probs': forecast_probs[:n_eval],
        'empirical_probs': empirical_probs[:n_eval],
    }

In [None]:
def _dominant_lineage_match(probs_dict, period_index):
    """Compare predicted vs observed dominant lineage at one forecast period.

    Returns True/False, or None when the period cannot be evaluated. That happens
    when ``period_index`` is beyond the available forecast or observations, or
    when no location has any observed sequences in that period.
    """
    predicted = probs_dict['probs']
    observed = probs_dict['empirical_probs']
    if not (0 <= period_index < min(predicted.shape[0], observed.shape[0])):
        return None
    # Sum over locations (dim -2) to get one score per lineage.
    observed_totals = observed[period_index].sum(-2)
    if not bool((observed_totals > 0).any()):
        # Nothing observed: argmax would default to lineage 0, which is meaningless.
        return None
    predicted_lineage = predicted[period_index].sum(-2).argmax(0).item()
    return predicted_lineage == observed_totals.argmax(0).item()


def generate_forecast_eval_percent(fits, input_dataset, data_region=None, queries=None,
                                   period_index_4wk=1, period_index_8wk=3, evaluate_fn=None):
    """Score each backtesting fit on whether it predicts the dominant lineage.

    For every fit, predicted and observed lineage probabilities are summed over
    the evaluated locations. The arg-max lineages are then compared at two forecast
    periods, given as non-negative indices into the forecast horizon. With 2-week
    periods, indices 1 and 3 end about 4 and 8 weeks after the training window.

    Each horizon is scored independently, so a fit whose forecast or data does not
    reach the later period still contributes to ``match_4wk``. Periods with no
    observed sequences are skipped rather than counted. Previously, an empty
    location set made both arg-maxes default to 0 and was counted as a match.

    Args:
        fits: Mapping from model key to fitted model.
        input_dataset: Full dataset, passed through to ``evaluate_fn``.
        data_region: Optional iterable of location names to restrict evaluation to.
        queries: Location-name substrings forwarded to ``evaluate_fn``. Defaults to
            every name in ``input_dataset['location_id_inv']``.
        period_index_4wk: Forecast-period index scored into ``match_4wk``.
        period_index_8wk: Forecast-period index scored into ``match_8wk``.
        evaluate_fn: Callable returning ``{'probs', 'empirical_probs'}`` for one
            fit. Defaults to ``evaluate_forecast3``, looked up at call time.

    Returns:
        dict with boolean lists ``'match_4wk'`` and ``'match_8wk'``. Each has one
        entry per fit for which that horizon could be evaluated.
    """
    if evaluate_fn is None:
        evaluate_fn = evaluate_forecast3
    if queries is None:
        queries = input_dataset['location_id_inv']

    match_4wk = []
    match_8wk = []
    for fit_n in fits.values():
        probs_dict = evaluate_fn(
            fit_n,
            input_dataset,
            queries=queries,
            num_lineages=101,
            data_region=data_region,
            verbose=False,
        )
        for period_index, matches in ((period_index_4wk, match_4wk), (period_index_8wk, match_8wk)):
            result = _dominant_lineage_match(probs_dict, period_index)
            if result is not None:
                matches.append(result)

    return {
        'match_4wk': match_4wk,
        'match_8wk': match_8wk,
    }


### USA

In [None]:
# Dominant-lineage accuracy over every location whose name contains 'USA'.
query = 'USA'
regions = [name for name in input_dataset['location_id'] if query in name]
selected_region_forecast = generate_forecast_eval_percent(fits, input_dataset, data_region=regions)

In [None]:
# % of evaluated fits whose predicted dominant lineage matched the observed one at forecast period index 1 (~4 weeks); NaN if none could be evaluated.
torch.tensor(selected_region_forecast['match_4wk']).sum() / len(selected_region_forecast['match_4wk']) * 100

In [None]:
# % of evaluated fits whose predicted dominant lineage matched the observed one at forecast period index 3 (~8 weeks); NaN if none could be evaluated.
torch.tensor(selected_region_forecast['match_8wk']).sum() / len(selected_region_forecast['match_8wk']) * 100

### France

In [None]:
# Locations whose name contains 'France'.
query = 'France'
regions = [name for name in input_dataset['location_id'] if query in name]

In [None]:
# Score every backtesting fit on the French locations selected above.
selected_region_forecast = generate_forecast_eval_percent(fits, input_dataset, data_region=regions)

In [None]:
# % of evaluated fits whose predicted dominant lineage matched the observed one at forecast period index 1 (~4 weeks); NaN if none could be evaluated.
torch.tensor(selected_region_forecast['match_4wk']).sum() / len(selected_region_forecast['match_4wk']) * 100

In [None]:
# % of evaluated fits whose predicted dominant lineage matched the observed one at forecast period index 3 (~8 weeks); NaN if none could be evaluated.
torch.tensor(selected_region_forecast['match_8wk']).sum() / len(selected_region_forecast['match_8wk']) * 100

### England

In [None]:
# Locations whose name contains 'England'.
query = 'England'
regions = [name for name in input_dataset['location_id'] if query in name]

In [None]:
# Score every backtesting fit on the English locations selected above.
selected_region_forecast = generate_forecast_eval_percent(fits, input_dataset, data_region=regions)

In [None]:
# % of evaluated fits whose predicted dominant lineage matched the observed one at forecast period index 1 (~4 weeks); NaN if none could be evaluated.
torch.tensor(selected_region_forecast['match_4wk']).sum() / len(selected_region_forecast['match_4wk']) * 100

In [None]:
# % of evaluated fits whose predicted dominant lineage matched the observed one at forecast period index 3 (~8 weeks); NaN if none could be evaluated.
torch.tensor(selected_region_forecast['match_8wk']).sum() / len(selected_region_forecast['match_8wk']) * 100

### Brazil

In [None]:
# Dominant-lineage accuracy over every location whose name contains 'Brazil'.
query = 'Brazil'
regions = [name for name in input_dataset['location_id'] if query in name]
selected_region_forecast = generate_forecast_eval_percent(fits, input_dataset, data_region=regions)

In [None]:
# % of evaluated fits whose predicted dominant lineage matched the observed one at forecast period index 1 (~4 weeks); NaN if none could be evaluated.
torch.tensor(selected_region_forecast['match_4wk']).sum() / len(selected_region_forecast['match_4wk']) * 100

In [None]:
# % of evaluated fits whose predicted dominant lineage matched the observed one at forecast period index 3 (~8 weeks); NaN if none could be evaluated.
torch.tensor(selected_region_forecast['match_8wk']).sum() / len(selected_region_forecast['match_8wk']) * 100

### Australia

In [None]:
# Dominant-lineage accuracy over every location whose name contains 'Australia'.
query = 'Australia'
regions = [name for name in input_dataset['location_id'] if query in name]
selected_region_forecast = generate_forecast_eval_percent(fits, input_dataset, data_region=regions)

In [None]:
# % of evaluated fits whose predicted dominant lineage matched the observed one at forecast period index 1 (~4 weeks); NaN if none could be evaluated.
torch.tensor(selected_region_forecast['match_4wk']).sum() / len(selected_region_forecast['match_4wk']) * 100

In [None]:
# % of evaluated fits whose predicted dominant lineage matched the observed one at forecast period index 3 (~8 weeks); NaN if none could be evaluated.
torch.tensor(selected_region_forecast['match_8wk']).sum() / len(selected_region_forecast['match_8wk']) * 100

### Russia

In [None]:
# Dominant-lineage accuracy over every location whose name contains 'Russia'.
query = 'Russia'
regions = [name for name in input_dataset['location_id'] if query in name]
selected_region_forecast = generate_forecast_eval_percent(fits, input_dataset, data_region=regions)

In [None]:
# % of evaluated fits whose predicted dominant lineage matched the observed one at forecast period index 1 (~4 weeks); NaN if none could be evaluated.
torch.tensor(selected_region_forecast['match_4wk']).sum() / len(selected_region_forecast['match_4wk']) * 100

In [None]:
# % of evaluated fits whose predicted dominant lineage matched the observed one at forecast period index 3 (~8 weeks); NaN if none could be evaluated.
torch.tensor(selected_region_forecast['match_8wk']).sum() / len(selected_region_forecast['match_8wk']) * 100

# Heatmap of forecast evaluation

In [None]:
import matplotlib.dates as mdates
from matplotlib.pyplot import cm
import glob

In [None]:
def generate_forecast_eval_2(fits, input_dataset, data_region = None, queries = None):
    """Tabulate per-period forecast errors for every backtesting fit.

    For each fit, ``evaluate_forecast2`` supplies L1 and L2 errors per forecast
    period. Each error becomes one row, tagged with the 14-day interval in which
    the forecast started and how many periods ahead it is.

    Args:
        fits: Mapping from model key to fit. ``key[9]`` is read as the forecast
            start day.
        input_dataset: Full dataset, passed through to ``evaluate_forecast2``.
        data_region: Optional location names to restrict the evaluation to.
        queries: Location-name substrings. A falsy value means every name in
            ``input_dataset['location_id_inv']``.

    Returns:
        DataFrame with columns ``forecast_start_interval``,
        ``period_forecast_ahead`` (1-based), ``l1_error``, ``l2_error`` and
        ``period_of_forecast`` (start interval + periods ahead).
    """
    if not queries:
        queries = input_dataset['location_id_inv']

    period_length = 14  # days per forecast interval
    rows = {
        'forecast_start_interval': [],
        'period_forecast_ahead': [],
        'l1_error': [],
        'l2_error': [],
    }

    for key, fit_n in fits.items():
        errors = evaluate_forecast2(
            fit_n,
            input_dataset,
            queries=queries,
            num_lineages=100,
            data_region=data_region,
            verbose=False,
        )
        start_interval = key[9] // period_length
        l1 = errors['L1_error'].tolist()
        n_ahead = len(l1)

        rows['forecast_start_interval'].extend([start_interval] * n_ahead)
        rows['period_forecast_ahead'].extend(range(1, n_ahead + 1))
        rows['l1_error'].extend(l1)
        rows['l2_error'].extend(errors['L2_error'].tolist())

    table = pd.DataFrame(rows)
    return table.assign(
        period_of_forecast=table['forecast_start_interval'] + table['period_forecast_ahead']
    )

## Heatmap for England

In [None]:
# Per-fit, per-period forecast errors for locations whose name contains 'England'.
query = 'England'
regions = [name for name in input_dataset['location_id'] if query in name]
region_forecast_info = generate_forecast_eval_2(fits, input_dataset, data_region=regions)

In [None]:
# A date mapping
# Matplotlib date numbers for period indices. The next cell looks up
# dates[i - 1] for every `period_of_forecast`, so the range must reach its
# maximum. The original "+12 periods past the last start" padding is kept as
# a lower bound.
# NOTE(review): assumes mutrans.date_range(n) returns n period dates -- confirm.
dates = matplotlib.dates.date2num(
    mutrans.date_range(
        max(
            region_forecast_info['forecast_start_interval'].max() + 12,
            region_forecast_info['period_of_forecast'].max(),
        )
    )
)

In [None]:
# Convert 1-based period indices to plotting dates (index i -> dates[i - 1]).
region_forecast_info['forecast_start_interval_time'] = [dates[i - 1] for i in region_forecast_info['forecast_start_interval']]
region_forecast_info['period_of_forecast_time'] = [dates[i - 1] for i in region_forecast_info['period_of_forecast']]

In [None]:
# Rows: forecast start date; columns: date being forecast; values: L1 error.
# DataFrame.pivot's parameters are keyword-only in pandas >= 2.0, where the
# positional form raises TypeError. Keywords work on older versions too.
region_forecast_info_pivot = region_forecast_info.pivot(
    index="forecast_start_interval_time",
    columns="period_of_forecast_time",
    values="l1_error",
)

In [None]:
# Heatmap of L1 forecast error (colour range fixed to [0, 2]). Rows of `data`
# are forecast start dates and columns are the dates being forecast, both taken
# from the pivot table built above.
# NOTE(review): make_axes_locatable is imported but not used in this cell.
from mpl_toolkits.axes_grid1 import make_axes_locatable

_data = region_forecast_info_pivot.to_numpy()
_cns = np.array(region_forecast_info_pivot.columns)       # column dates (date being forecast)
_rns = np.array(list(region_forecast_info_pivot.index))   # row dates (forecast start)

# Prepend one padding column labelled with the first forecast-start date
# (presumably so the column axis starts at the same date as the row axis).
data = np.hstack((np.zeros((_data.shape[0], 1)), _data))
cns = np.asarray([_rns[0]] + _cns.tolist())
rns = _rns
# Blank the padding column and the diagonal cell of each row (NaN cells are left uncoloured).
for j in range(min(data.shape)):
    data[j, 0] = np.nan
    data[j, j] = np.nan


fig, ax = plt.subplots(figsize=(7,7))
# NOTE(review): imshow's extent is (left, right, bottom, top) and data columns
# run horizontally. Here left/right use the row dates (rns) and bottom/top the
# column dates (cns) -- confirm the date axes are not swapped.
im = ax.imshow(data, cmap = cm.plasma, aspect=1, extent = (rns[0], rns[-1], cns[0], cns[-1]), origin='lower', vmin= 0, vmax=2)
#ax.invert_yaxis()

# ax.set_yticks(ticks = list(range(len(rns))))
# ax.set_yticklabels(list(mdates.num2date(x).strftime("%d %b %Y") for x in rns.tolist()), Fontsize = 2)

# ax.set_xticks(ticks = list(range(len(cns))))
# ax.set_xticklabels(list(mdates.num2date(x).strftime("%d %b %Y") for x in cns.tolist()), Fontsize = 6)

# Monthly date ticks on both axes.
ax.yaxis.set_major_locator(matplotlib.dates.MonthLocator())
ax.yaxis.set_major_formatter(matplotlib.dates.DateFormatter("%b %Y"))

ax.xaxis.set_major_locator(matplotlib.dates.MonthLocator())
ax.xaxis.set_major_formatter(matplotlib.dates.DateFormatter("%b %Y"))

plt.colorbar(im, orientation = 'horizontal')

plt.xticks(rotation=90)
# NOTE(review): a negative bottom pushes the axes below the figure edge;
# this presumably relies on bbox_inches='tight' at save time to reclaim the space.
fig.subplots_adjust(bottom = -0.5)

#plt.show()

plt.savefig('paper/backtesting/heatmap_England.pdf',bbox_inches = 'tight')

## Heatmap for USA / Massachusetts

In [None]:
# Locations whose name contains 'USA / Massachusetts' (displayed for inspection).
query = 'USA / Massachusetts'
regions = [name for name in input_dataset['location_id'] if query in name]
regions

In [None]:
# Per-fit, per-period forecast errors for the Massachusetts locations selected above.
region_forecast_info = generate_forecast_eval_2(fits, input_dataset, data_region=regions)

In [None]:
# A date mapping
# Matplotlib date numbers for period indices. The next cell looks up
# dates[i - 1] for every `period_of_forecast`, so the range must reach its
# maximum. The original "+12 periods past the last start" padding is kept as
# a lower bound.
# NOTE(review): assumes mutrans.date_range(n) returns n period dates -- confirm.
dates = matplotlib.dates.date2num(
    mutrans.date_range(
        max(
            region_forecast_info['forecast_start_interval'].max() + 12,
            region_forecast_info['period_of_forecast'].max(),
        )
    )
)

In [None]:
# Convert 1-based period indices to plotting dates (index i -> dates[i - 1]).
region_forecast_info['forecast_start_interval_time'] = [dates[i - 1] for i in region_forecast_info['forecast_start_interval']]
region_forecast_info['period_of_forecast_time'] = [dates[i - 1] for i in region_forecast_info['period_of_forecast']]

In [None]:
# Rows: forecast start date; columns: date being forecast; values: L1 error.
# DataFrame.pivot's parameters are keyword-only in pandas >= 2.0, where the
# positional form raises TypeError. Keywords work on older versions too.
region_forecast_info_pivot = region_forecast_info.pivot(
    index="forecast_start_interval_time",
    columns="period_of_forecast_time",
    values="l1_error",
)

In [None]:
# Heatmap of L1 forecast error for Massachusetts (colour range fixed to [0, 2]).
# Rows of `data` are forecast start dates and columns are the dates being
# forecast, both taken from the pivot table built above.
# NOTE(review): make_axes_locatable is imported but not used in this cell.
from mpl_toolkits.axes_grid1 import make_axes_locatable

_data = region_forecast_info_pivot.to_numpy()
_cns = np.array(region_forecast_info_pivot.columns)       # column dates (date being forecast)
_rns = np.array(list(region_forecast_info_pivot.index))   # row dates (forecast start)

# Prepend one padding column labelled with the first forecast-start date
# (presumably so the column axis starts at the same date as the row axis).
data = np.hstack((np.zeros((_data.shape[0], 1)), _data))
cns = np.asarray([_rns[0]] + _cns.tolist())
rns = _rns
# Blank the padding column and the diagonal cell of each row (NaN cells are left uncoloured).
for j in range(min(data.shape)):
    data[j, 0] = np.nan
    data[j, j] = np.nan


fig, ax = plt.subplots(figsize=(7,7))
# NOTE(review): imshow's extent is (left, right, bottom, top) and data columns
# run horizontally. Here left/right use the row dates (rns) and bottom/top the
# column dates (cns) -- confirm the date axes are not swapped.
im = ax.imshow(data, cmap = cm.plasma, aspect=1, extent = (rns[0], rns[-1], cns[0], cns[-1]), origin='lower', vmin= 0, vmax=2)
#ax.invert_yaxis()

# ax.set_yticks(ticks = list(range(len(rns))))
# ax.set_yticklabels(list(mdates.num2date(x).strftime("%d %b %Y") for x in rns.tolist()), Fontsize = 2)

# ax.set_xticks(ticks = list(range(len(cns))))
# ax.set_xticklabels(list(mdates.num2date(x).strftime("%d %b %Y") for x in cns.tolist()), Fontsize = 6)

# Monthly date ticks on both axes.
ax.yaxis.set_major_locator(matplotlib.dates.MonthLocator())
ax.yaxis.set_major_formatter(matplotlib.dates.DateFormatter("%b %Y"))

ax.xaxis.set_major_locator(matplotlib.dates.MonthLocator())
ax.xaxis.set_major_formatter(matplotlib.dates.DateFormatter("%b %Y"))

plt.colorbar(im, orientation = 'horizontal')

plt.xticks(rotation=90)
# NOTE(review): a negative bottom pushes the axes below the figure edge;
# this presumably relies on bbox_inches='tight' at save time to reclaim the space.
fig.subplots_adjust(bottom = -0.5)

#plt.show()

plt.savefig('paper/backtesting/heatmap_USA_Mass.pdf',bbox_inches = 'tight')

## Heatmap for Brazil

In [None]:
# Per-fit, per-period forecast errors for locations whose name contains 'Brazil'.
query = 'Brazil'
regions = [name for name in input_dataset['location_id'] if query in name]
region_forecast_info = generate_forecast_eval_2(fits, input_dataset, data_region=regions)

In [None]:
# A date mapping
# Matplotlib date numbers for period indices. The next cell looks up
# dates[i - 1] for every `period_of_forecast`, so the range must reach its
# maximum. The original "+12 periods past the last start" padding is kept as
# a lower bound.
# NOTE(review): assumes mutrans.date_range(n) returns n period dates -- confirm.
dates = matplotlib.dates.date2num(
    mutrans.date_range(
        max(
            region_forecast_info['forecast_start_interval'].max() + 12,
            region_forecast_info['period_of_forecast'].max(),
        )
    )
)

In [None]:
# Convert 1-based period indices to plotting dates (index i -> dates[i - 1]).
region_forecast_info['forecast_start_interval_time'] = [dates[i - 1] for i in region_forecast_info['forecast_start_interval']]
region_forecast_info['period_of_forecast_time'] = [dates[i - 1] for i in region_forecast_info['period_of_forecast']]

In [None]:
# Rows: forecast start date; columns: date being forecast; values: L1 error.
# DataFrame.pivot's parameters are keyword-only in pandas >= 2.0, where the
# positional form raises TypeError. Keywords work on older versions too.
region_forecast_info_pivot = region_forecast_info.pivot(
    index="forecast_start_interval_time",
    columns="period_of_forecast_time",
    values="l1_error",
)

In [None]:
# Heatmap of L1 forecast error for Brazil (colour range fixed to [0, 2]).
# Rows of `data` are forecast start dates and columns are the dates being
# forecast, both taken from the pivot table built above.
# NOTE(review): make_axes_locatable is imported but not used in this cell.
from mpl_toolkits.axes_grid1 import make_axes_locatable

_data = region_forecast_info_pivot.to_numpy()
_cns = np.array(region_forecast_info_pivot.columns)       # column dates (date being forecast)
_rns = np.array(list(region_forecast_info_pivot.index))   # row dates (forecast start)

# Prepend one padding column labelled with the first forecast-start date
# (presumably so the column axis starts at the same date as the row axis).
data = np.hstack((np.zeros((_data.shape[0], 1)), _data))
cns = np.asarray([_rns[0]] + _cns.tolist())
rns = _rns
# Blank the padding column and the diagonal cell of each row (NaN cells are left uncoloured).
for j in range(min(data.shape)):
    data[j, 0] = np.nan
    data[j, j] = np.nan


fig, ax = plt.subplots(figsize=(7,7))
# NOTE(review): imshow's extent is (left, right, bottom, top) and data columns
# run horizontally. Here left/right use the row dates (rns) and bottom/top the
# column dates (cns) -- confirm the date axes are not swapped.
im = ax.imshow(data, cmap = cm.plasma, aspect=1, extent = (rns[0], rns[-1], cns[0], cns[-1]), origin='lower', vmin= 0, vmax=2)
#ax.invert_yaxis()

# ax.set_yticks(ticks = list(range(len(rns))))
# ax.set_yticklabels(list(mdates.num2date(x).strftime("%d %b %Y") for x in rns.tolist()), Fontsize = 2)

# ax.set_xticks(ticks = list(range(len(cns))))
# ax.set_xticklabels(list(mdates.num2date(x).strftime("%d %b %Y") for x in cns.tolist()), Fontsize = 6)

# Monthly date ticks on both axes.
ax.yaxis.set_major_locator(matplotlib.dates.MonthLocator())
ax.yaxis.set_major_formatter(matplotlib.dates.DateFormatter("%b %Y"))

ax.xaxis.set_major_locator(matplotlib.dates.MonthLocator())
ax.xaxis.set_major_formatter(matplotlib.dates.DateFormatter("%b %Y"))

plt.colorbar(im, orientation = 'horizontal')

plt.xticks(rotation=90)
# NOTE(review): a negative bottom pushes the axes below the figure edge;
# this presumably relies on bbox_inches='tight' at save time to reclaim the space.
fig.subplots_adjust(bottom = -0.5)

#plt.show()

plt.savefig('paper/backtesting/heatmap_Brazil.pdf',bbox_inches = 'tight')

## Corresponding Region Plots

### England

In [None]:
# England forecast from the last backtesting fit, with legend. The x-range is
# clipped to May 2020 - Dec 2021 to line up with the heatmap.
# Uses `plot_forecast_results` from an earlier forecast cell; cells further down
# reassign it, so re-running this cell after them changes the lineage set.
k = list(fits)[-1]
print(k[9])
fit_n = fits[k]
england_for_heatmap = plot_forecast2(
    fit_n,
    input_dataset,
    queries=["England"],
    num_lineages=13,
    verbose=False,
    figsize_x=7,
    figsize_y=prediction_figsize_y,
    additional_lineages=plot_forecast_results['lineages_plotted'],
    auto_select_lineages=True,
    legend_out=True,
);
start_date, end_date = (matplotlib.dates.date2num(np.datetime64(day)) for day in ('2020-05-01', '2021-12-01'))
england_for_heatmap['ax'].set_xlim((start_date, end_date))
plt.savefig(f'{forecast_dir_prefix}/backtesting_day_{k[9]}_early_prediction_england_for_heatmap.pdf', bbox_inches='tight')
plt.savefig(f'{forecast_dir_prefix}/backtesting_day_{k[9]}_early_prediction_england_for_heatmap.png', bbox_inches='tight')

In [None]:
# Same England forecast as above but without a legend (for figure panels).
# Depends on `plot_forecast_results` from an earlier forecast cell.
k = list(fits)[-1]
print(k[9])
fit_n = fits[k]
england_for_heatmap = plot_forecast2(
    fit_n,
    input_dataset,
    queries=["England"],
    num_lineages=13,
    verbose=False,
    figsize_x=7,
    figsize_y=prediction_figsize_y,
    additional_lineages=plot_forecast_results['lineages_plotted'],
    auto_select_lineages=True,
    legend_out=True,
    show_legend=False,
);
start_date, end_date = (matplotlib.dates.date2num(np.datetime64(day)) for day in ('2020-05-01', '2021-12-01'))
england_for_heatmap['ax'].set_xlim((start_date, end_date))
plt.savefig(f'{forecast_dir_prefix}/backtesting_day_{k[9]}_early_prediction_england_for_heatmap_nolegend.png', bbox_inches='tight')
plt.savefig(f'{forecast_dir_prefix}/backtesting_day_{k[9]}_early_prediction_england_for_heatmap_nolegend.pdf', bbox_inches='tight')

### Brazil 

In [None]:
# Brazil forecast from the last backtesting fit, with legend, clipped to
# May 2020 - Dec 2021 to line up with the heatmap. The result is named for
# Brazil (it was copy-pasted as `england_for_heatmap`).
# Depends on `plot_forecast_results` from an earlier forecast cell.
k = list(fits.keys())[-1]
print(k[9])
fit_n = fits[k]
brazil_for_heatmap = plot_forecast2(
    fit_n,
    input_dataset,
    queries=["Brazil"],
    num_lineages=13,
    verbose=False,
    figsize_x=7,
    figsize_y=prediction_figsize_y,
    additional_lineages=plot_forecast_results['lineages_plotted'],
    auto_select_lineages=True,
    legend_out=True,
);
start_date = matplotlib.dates.date2num(np.datetime64('2020-05-01'))
end_date = matplotlib.dates.date2num(np.datetime64('2021-12-01'))
brazil_for_heatmap['ax'].set_xlim((start_date, end_date))
plt.savefig(f'{forecast_dir_prefix}/backtesting_day_{k[9]}_early_prediction_brazil_for_heatmap.pdf', bbox_inches='tight')
plt.savefig(f'{forecast_dir_prefix}/backtesting_day_{k[9]}_early_prediction_brazil_for_heatmap.png', bbox_inches='tight')

In [None]:
# Brazil forecast without a legend (for figure panels). The result is named for
# Brazil (it was copy-pasted as `england_for_heatmap`).
# Depends on `plot_forecast_results` from an earlier forecast cell.
k = list(fits.keys())[-1]
print(k[9])
fit_n = fits[k]
brazil_for_heatmap = plot_forecast2(
    fit_n,
    input_dataset,
    queries=["Brazil"],
    num_lineages=13,
    verbose=False,
    figsize_x=7,
    figsize_y=prediction_figsize_y,
    additional_lineages=plot_forecast_results['lineages_plotted'],
    auto_select_lineages=True,
    legend_out=True,
    show_legend=False,
);
start_date = matplotlib.dates.date2num(np.datetime64('2020-05-01'))
end_date = matplotlib.dates.date2num(np.datetime64('2021-12-01'))
brazil_for_heatmap['ax'].set_xlim((start_date, end_date))
plt.savefig(f'{forecast_dir_prefix}/backtesting_day_{k[9]}_early_prediction_brazil_for_heatmap_nolegend.png', bbox_inches='tight')
plt.savefig(f'{forecast_dir_prefix}/backtesting_day_{k[9]}_early_prediction_brazil_for_heatmap_nolegend.pdf', bbox_inches='tight')

### USA / Mass

In [None]:
# Massachusetts forecast from the last backtesting fit, with legend, clipped to
# May 2020 - Dec 2021 to line up with the heatmap. The result is named for
# Massachusetts (it was copy-pasted as `england_for_heatmap`).
# Depends on `plot_forecast_results` from an earlier forecast cell.
k = list(fits.keys())[-1]
print(k[9])
fit_n = fits[k]
massachusetts_for_heatmap = plot_forecast2(
    fit_n,
    input_dataset,
    queries=["Massachusetts"],
    num_lineages=13,
    verbose=False,
    figsize_x=7,
    figsize_y=prediction_figsize_y,
    additional_lineages=plot_forecast_results['lineages_plotted'],
    auto_select_lineages=True,
    legend_out=True,
);
start_date = matplotlib.dates.date2num(np.datetime64('2020-05-01'))
end_date = matplotlib.dates.date2num(np.datetime64('2021-12-01'))
massachusetts_for_heatmap['ax'].set_xlim((start_date, end_date))
plt.savefig(f'{forecast_dir_prefix}/backtesting_day_{k[9]}_early_prediction_Massachusetts_for_heatmap.pdf', bbox_inches='tight')
plt.savefig(f'{forecast_dir_prefix}/backtesting_day_{k[9]}_early_prediction_Massachusetts_for_heatmap.png', bbox_inches='tight')

In [None]:
# Massachusetts forecast without a legend (for figure panels). The result is
# named for Massachusetts (it was copy-pasted as `england_for_heatmap`).
# Depends on `plot_forecast_results` from an earlier forecast cell.
k = list(fits.keys())[-1]
print(k[9])
fit_n = fits[k]
massachusetts_for_heatmap = plot_forecast2(
    fit_n,
    input_dataset,
    queries=["Massachusetts"],
    num_lineages=13,
    verbose=False,
    figsize_x=7,
    figsize_y=prediction_figsize_y,
    additional_lineages=plot_forecast_results['lineages_plotted'],
    auto_select_lineages=True,
    legend_out=True,
    show_legend=False,
);
start_date = matplotlib.dates.date2num(np.datetime64('2020-05-01'))
end_date = matplotlib.dates.date2num(np.datetime64('2021-12-01'))
massachusetts_for_heatmap['ax'].set_xlim((start_date, end_date))
plt.savefig(f'{forecast_dir_prefix}/backtesting_day_{k[9]}_early_prediction_Massachusetts_for_heatmap_nolegend.png', bbox_inches='tight')
plt.savefig(f'{forecast_dir_prefix}/backtesting_day_{k[9]}_early_prediction_Massachusetts_for_heatmap_nolegend.pdf', bbox_inches='tight')

# Supplementary Figure 3

In [None]:
# Load the main-analysis fits from results/mutrans.pt onto the CPU.
# torch.load unpickles the file, so use it only on trusted, locally produced results.
main_analysis_fits = torch.load("results/mutrans.pt", map_location="cpu")

In [None]:
# Inspect the model keys (indexed like tuples below, e.g. k[9]).
main_analysis_fits.keys()

In [None]:
# USA forecast with num_lineages=1 plus BA.1.1, without case counts.
# Saved under its own file name: the multi-region cell below also writes
# paper/Figure_S3.pdf, which previously overwrote this figure.
k = list(main_analysis_fits.keys())[0]
print(k[9])
fit_n = main_analysis_fits[k]
plot_forecast_results = plot_forecast2(
    fit_n,
    input_dataset,
    queries=["USA"],
    num_lineages=1,
    verbose=False,
    additional_lineages=['BA.1.1'],
    legend_out=True,
    show_case_counts=False,
    filenames=["paper/Figure_S3_USA_BA.1.1.pdf"],
)

In [None]:
# Supplementary Figure 3: forecasts for six countries (13 lineages plus BA.2), without case counts.
k = list(main_analysis_fits)[0]
print(k[9])
fit_n = main_analysis_fits[k]
plot_forecast_results = plot_forecast2(
    fit_n,
    input_dataset,
    queries=["USA", "France", "England", "Brazil", "Australia", "Russia"],
    num_lineages=13,
    verbose=False,
    additional_lineages=['BA.2'],
    legend_out=True,
    show_case_counts=False,
    filenames=["paper/Figure_S3.pdf"],
)

In [None]:
# Variant of Figure S3 with case counts shown, BA.1 and BA.2 added, and no second legend.
k = list(main_analysis_fits)[0]
print(k[9])
fit_n = main_analysis_fits[k]
plot_forecast_results = plot_forecast2(
    fit_n,
    input_dataset,
    queries=["USA", "France", "England", "Brazil", "Australia", "Russia"],
    num_lineages=13,
    verbose=False,
    additional_lineages=['BA.1', 'BA.2'],
    legend_out=True,
    show_case_counts=True,
    filenames=["paper/Figure_S3_withcases.pdf"],
    show_second_legend=False,
)

In [None]:
# Fields returned by plot_forecast2 (the next cell uses 'rates', 'colors_dict' and 'lineages_plotted').
plot_forecast_results.keys()

In [None]:
# Inset bar chart for Figure S3: exp(rate) for each plotted lineage, coloured
# with the same colour mapping as the forecast plot.
xs = list(range(plot_forecast_results['rates'].shape[0]))
ys = np.exp(list(plot_forecast_results['rates']))

fig, ax = plt.subplots()
ax.bar(
    x=xs,
    height=ys,
    color=[plot_forecast_results['colors_dict'].get(lineage) for lineage in plot_forecast_results['lineages_plotted']],
)
ax.set_xticks(xs)
ax.set_xticklabels(plot_forecast_results['lineages_plotted'].tolist(), rotation=90)
ax.set_ylabel('$R_{lineage}/R_A$')
plt.savefig('paper/barplot_rates_inset_for_S3.pdf')

## Plot for Asia for the reviewer response

In [None]:
# NOTE(review): reloads the same results/mutrans.pt already loaded for Figure S3;
# this presumably keeps the section runnable on its own. torch.load unpickles,
# so use it only on trusted files.
main_analysis_fits = torch.load("results/mutrans.pt", map_location="cpu")

In [None]:
# Use the first main-analysis fit and the names of the locations it covers.
fit_n = main_analysis_fits[list(main_analysis_fits.keys())[0]]
location_keys = fit_n['location_id'].keys()

In [None]:
# Number of fitted locations whose name mentions 'Asia'.
sum(1 for name in location_keys if 'Asia' in name)

In [None]:
# Forecast for locations matching 'Asia' (10 lineages), written as PDF and PNG.
plot_forecast_results = plot_forecast2(
    fit_n,
    input_dataset,
    queries=['Asia'],
    num_lineages=10,
    legend_out=True,
    show_second_legend=False,
    filenames=['paper/forecast_Asia.pdf', 'paper/forecast_Asia.png'],
)

## Generate selected per-region forecasts showing BA.2 rising

In [None]:
# NOTE(review): reloads the same results/mutrans.pt loaded in earlier sections;
# this presumably keeps the section runnable on its own. torch.load unpickles,
# so use it only on trusted files.
main_analysis_fits = torch.load("results/mutrans.pt", map_location="cpu")

In [None]:
# Use the first main-analysis fit and the names of the locations it covers.
fit_n = main_analysis_fits[list(main_analysis_fits.keys())[0]]
location_keys = fit_n['location_id'].keys()

In [None]:
# Forecasts for Denmark, South Africa and India (10 lineages), where BA.2 is rising.
plot_forecast_results = plot_forecast2(
    fit_n,
    input_dataset,
    queries=['Denmark', 'South Africa', 'India'],
    num_lineages=10,
    legend_out=True,
    show_second_legend=False,
    filenames=['paper/forecast_Denmark_SouthAfrica_India.pdf'],
)

In [None]:
# Inset bar chart: exp(rate) for each lineage plotted in the Denmark / South
# Africa / India forecast, coloured with the forecast's colour mapping.
xs = list(range(plot_forecast_results['rates'].shape[0]))
ys = np.exp(list(plot_forecast_results['rates']))

fig, ax = plt.subplots()
ax.bar(
    x=xs,
    height=ys,
    color=[plot_forecast_results['colors_dict'].get(lineage) for lineage in plot_forecast_results['lineages_plotted']],
)
ax.set_xticks(xs)
ax.set_xticklabels(plot_forecast_results['lineages_plotted'].tolist(), rotation=90)
ax.set_ylabel('$R_{lineage}/R_A$')
plt.savefig('paper/barplot_rates_inset_for_forecast_Denmark_SouthAfrica_India.pdf')

## Prediction accuracy vs Region Coverage

In [None]:
import tqdm

In [None]:
def evaluate_forecast3(fit, input_dataset, filenames=None):
    """Score a fit's forecast against observed lineage proportions, per region.

    The forecast interval is every model period after the fit's case window
    (``fit['weekly_cases']``). For those periods, the fit's mean predicted
    lineage probabilities are compared with the empirical lineage proportions
    in ``input_dataset``. Only regions present in both the fit and the data
    are compared.

    Args:
        fit: fitted-model dict. Reads ``fit['mean']['probs']`` (3-D, indexed
            [period, place, lineage]), ``'weekly_cases'`` (2-D, [period, place]),
            ``'location_id'`` (name -> place index) and ``'location_id_inv'``
            (place names).
        input_dataset: dataset dict. Reads ``'weekly_clades'``,
            ``'clade_id_to_lineage_id'``, ``'location_id'`` and
            ``'location_id_inv'``.
        filenames: unused; accepted only for call compatibility.

    Returns:
        dict with
          ``'L1_error'``: T x P L1 distance between predicted and empirical
              lineage proportions.
          ``'L2_error'``: T x P L2 distance.
          ``'common_regions'``: the P region names, sorted, in column order.
          ``'num_observed'``: T x P observed sample counts. Where this is 0
              the empirical proportions are all zero, so the error reflects
              only the prediction; callers may want to drop those cells.
        T is the number of forecast periods covered by the data. It can be
        smaller than the model's forecast horizon, and is 0 when the data ends
        where the fit's case window ends.

    NOTE(review): assumes period index t refers to the same week in both the
    fit and input_dataset (shared time origin) — TODO confirm.
    """
    n_model_periods, _, n_model_lineages = fit['mean']['probs'].shape
    n_cases_periods, _ = fit['weekly_cases'].shape
    n_forecast_steps = n_model_periods - n_cases_periods

    # Aggregate observed clade counts up to the lineages used by the model.
    weekly_lineages_data = weekly_clades_to_lineages(
        input_dataset['weekly_clades'],
        input_dataset['clade_id_to_lineage_id'],
        n_model_lineages,
    )

    location_id_fit = fit['location_id']
    location_id_data = input_dataset['location_id']

    # Sorted so that region (column) order is reproducible: iteration order of
    # a set of strings varies between interpreter runs (hash randomization).
    common_regions = sorted(
        set(fit['location_id_inv']).intersection(input_dataset['location_id_inv'])
    )
    fit_place_idx = [location_id_fit[region] for region in common_regions]
    data_place_idx = [location_id_data[region] for region in common_regions]

    # Predicted probabilities over the forecast interval only.
    probs = fit['mean']['probs'][n_cases_periods:, fit_place_idx, :]

    # Observed counts over the same periods. This may be shorter than the
    # forecast horizon if the data ends first.
    obs_data = weekly_lineages_data[
        n_cases_periods:n_cases_periods + n_forecast_steps, data_place_idx, :
    ]
    num_observed = obs_data.sum(-1, True)
    # The clamp avoids 0/0 for cells with no samples; those get all-zero
    # empirical proportions.
    empirical_probs = obs_data / num_observed.clamp(min=1e-9)

    diff = probs[:empirical_probs.shape[0]] - empirical_probs
    l1_error = diff.abs().sum(-1)
    l2_error = diff.pow(2).sum(-1).sqrt()

    return {
        'L1_error': l1_error,  # T x P
        'L2_error': l2_error,  # T x P
        'common_regions': common_regions,
        'num_observed': num_observed.squeeze(-1),  # T x P
    }

In [None]:
# Collect the first-forecast-period L1 error for every (backtesting fit, region) pair.
regions_col = []
fit_days = []
L1_error = []

for fit_key in tqdm.tqdm(fits.keys()):
    fit_n = fits[fit_key]
    # Element 9 of the fit key is used as the fit's day count — TODO confirm key layout.
    days = fit_key[9]

    forecast_error = evaluate_forecast3(fit_n, input_dataset)
    l1_by_period = forecast_error['L1_error']
    if l1_by_period.shape[0] == 0:
        # The data has no periods after this fit's case window (e.g. the most
        # recent fit), so there is no first forecast period to score.
        logging.info("Skipping fit %s: no observed forecast periods", fit_key)
        continue

    n_regions = len(forecast_error['common_regions'])
    fit_days.extend([days] * n_regions)
    regions_col.extend(forecast_error['common_regions'])
    L1_error.extend(l1_by_period[0].tolist())  # first forecast period (the 2-week forecast)

In [None]:
# One row per (fit, region) forecast error.
error_df = pd.DataFrame(
    {'fit_days': fit_days, 'region': regions_col, 'L1_error': L1_error}
)

# Per-region sample totals from last_fit: weekly_clades summed over time (dim 0)
# and clades (dim 2).
total_region_coverage = pd.DataFrame({
    'region': last_fit['location_id_inv'],
    'total_counts': last_fit['weekly_clades'].sum([0, 2]).tolist(),
})

# Inner-join on region, then normalise those totals by each fit's day count.
merged = error_df.merge(total_region_coverage, on='region').assign(
    total_counts_per_day=lambda frame: frame['total_counts'] / frame['fit_days']
)

In [None]:
# Small offsets so that zero values stay finite on the log-scaled x-axis of the
# histograms below. Despite the "_p1" suffix, the offsets are 1e-2 and 1e-6, not 1.
merged = merged.assign(
    total_counts_per_day_p1=merged["total_counts_per_day"] + 1e-2,
    L1_error_p1=merged["L1_error"] + 1e-6,
)

In [None]:
# 2-D histogram of forecast L1 error vs. daily sample coverage, all regions (x on log scale).
sns.histplot(
    data=merged,
    x='total_counts_per_day_p1',
    y='L1_error_p1',
    bins=10,
    log_scale=(True, False),
    thresh=None,
).set(title='All Data')

In [None]:
# Same error-vs-coverage histogram, limited to regions whose name contains "Asia".
sns.histplot(
    data=merged.loc[merged['region'].str.contains('Asia')],
    x='total_counts_per_day_p1',
    y='L1_error_p1',
    bins=10,
    log_scale=(True, False),
    thresh=None,
).set(title='Asia')

In [None]:
# Same error-vs-coverage histogram, limited to regions whose name contains "USA".
sns.histplot(
    data=merged.loc[merged['region'].str.contains('USA')],
    x='total_counts_per_day_p1',
    y='L1_error_p1',
    bins=10,
    log_scale=(True, False),
    thresh=None,
).set(title='USA')

In [None]:
# Same error-vs-coverage histogram, limited to regions whose name contains "United Kingdom".
sns.histplot(
    data=merged.loc[merged['region'].str.contains('United Kingdom')],
    x='total_counts_per_day_p1',
    y='L1_error_p1',
    bins=10,
    log_scale=(True, False),
    thresh=None,
).set(title='United Kingdom')

In [None]:
# Same error-vs-coverage histogram, limited to regions whose name contains "Europe".
sns.histplot(
    data=merged.loc[merged['region'].str.contains('Europe')],
    x='total_counts_per_day_p1',
    y='L1_error_p1',
    bins=10,
    log_scale=(True, False),
    thresh=None,
).set(title='Europe')

In [None]:
# Same error-vs-coverage histogram, limited to regions whose name contains "Denmark".
sns.histplot(
    data=merged.loc[merged['region'].str.contains('Denmark')],
    x='total_counts_per_day_p1',
    y='L1_error_p1',
    bins=10,
    log_scale=(True, False),
    thresh=None,
).set(title='Denmark')

In [None]:
# Error-vs-coverage histogram for Indian regions only.
# "India" must match a whole " / "-separated component of the location name
# (the separator used by queries such as 'Asia / Japan' in this notebook).
# A plain substring test would also select names like "North America / USA / Indiana".
sns.histplot(
    merged[merged['region'].str.contains(r'(?:^| / )India(?: / |$)', regex=True)],
    x='total_counts_per_day_p1',
    y='L1_error_p1',
    log_scale=(True, False),
    bins=10,
    thresh=None).set(title='India')

In [None]:
# Same error-vs-coverage histogram, limited to regions whose name contains "South Africa".
sns.histplot(
    data=merged.loc[merged['region'].str.contains('South Africa')],
    x='total_counts_per_day_p1',
    y='L1_error_p1',
    bins=10,
    log_scale=(True, False),
    thresh=None,
).set(title='South Africa')

## Looking at Asia sub-regions

In [None]:
# Forecasts for selected Asian sub-regions, using the main-analysis fit (as the
# "Plot for Asia" section does). The fit is selected explicitly because `fit_n`
# is reassigned by the backtesting error loop earlier in the notebook. Reusing it
# here would plot whichever backtesting fit that loop visited last.
plot_forecast_results = plot_forecast2(
    main_analysis_fits[list(main_analysis_fits)[0]],
    input_dataset,
    num_lineages=10,
    legend_out=True,
    show_second_legend=False,
    queries=['Asia / Japan', 'Asia / India', 'Asia / Myanmar', 'Asia / Pakistan / Multan'])

## Model Prediction Accuracy vs. Sample Count