# Backtesting

This notebook generates plots for backtesting. It requires that the script `run_backtesting.sh` has been run prior to execution.

In [None]:
# imports
import datetime
import math
import os
import pickle
import re
import logging
from collections import Counter, OrderedDict, defaultdict
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import pandas as pd
import torch
import pyro.distributions as dist
from pyrocov import mutrans, pangolin, stats
from pyrocov.stats import normal_log10bf
from pyrocov.util import pretty_print, pearson_correlation
import seaborn as sns

In [None]:
# Configure logging: prefix every message with the time in milliseconds since
# the logging module was loaded (space-padded to a width of 9 characters).
logging.basicConfig(format="%(relativeCreated) 9d %(message)s", level=logging.INFO)

In [None]:
# Root logger verbosity. Edit the level here (e.g. logging.DEBUG) and re-run
# this cell to change how much is logged.
logging.getLogger().setLevel(logging.INFO)

In [None]:
# Matplotlib defaults applied to the figures produced by this notebook.
matplotlib.rcParams.update(
    {
        "figure.dpi": 200,
        'figure.figsize': [8, 8],
        "axes.edgecolor": "gray",
        "savefig.bbox": "tight",
        'font.family': 'sans-serif',
        # Fonts are tried in the order listed.
        'font.sans-serif': ['Arial', 'Avenir', 'DejaVu Sans'],
    }
)

## Load data

In [None]:
# Load the pickled `columns` object from the results directory.
# NOTE: unpickling can execute arbitrary code; only load files you produced yourself.
with open("results/gisaid.columns.pkl", "rb") as columns_file:
    columns = pickle.load(columns_file)
num_samples = len(columns["lineage"])
print("Loaded data from {} samples".format(num_samples))

In [None]:
# Show the container type of the loaded `columns` object.
type(columns)

In [None]:
# List the keys available in `columns`.
columns.keys()

In [None]:
# Preview the first two entries stored under every key of `columns`.
for column_name, column_values in columns.items():
    print(column_name, column_values[0:2])

## Load trained models

In [None]:
# Load the saved fits; map_location="cpu" places any stored tensors on the CPU
# so the notebook does not need a GPU.
fits = torch.load("results/mutrans.pt", map_location="cpu")

In [None]:
# Debug aid (disabled by default): print every key of `fits`, then look up the
# keys of the first fit. Change the condition to True to run it.
if False:
    for fit_key in fits:
        print(fit_key)
    first_key = list(fits)[0]
    fits[first_key].keys()

In [None]:
# Debug aid (disabled by default): print the "weekly_strains_shape" entry of
# every fit. Change the condition to True to run it.
if False:
    for single_fit in fits.values():
        print(single_fit["weekly_strains_shape"])

In [None]:
# This is really just the second fit
#selected_fit = list(fits.values())[1]
#pretty_print(selected_fit, max_items=40)

Scale `coef` by 1/100 in all results.

In [None]:
# ids of the dicts that scale_tensors has already processed. This makes
# re-running the cell a no-op instead of scaling the same values twice.
ALREADY_SCALED = set()


def scale_tensors(x, names=frozenset({"coef"}), scale=0.01, prefix="", verbose=True):
    """Recursively multiply selected entries of a nested dict by ``scale``, in place.

    Every value stored under a key contained in ``names`` is replaced by
    ``value * scale``. Values stored under the key ``"diagnostics"`` are left
    untouched; every other value is descended into. Non-dict values are
    leaves and are ignored.

    Each processed dict is remembered by ``id`` in ``ALREADY_SCALED`` so that
    calling the function again on the same objects does nothing.

    NOTE: ``id`` values are only unique among live objects. If the data is
    reloaded, clear ``ALREADY_SCALED`` before calling this function again.

    :param x: nested dict to rescale in place (any other type is a no-op)
    :param names: collection of keys whose values are rescaled
    :param scale: multiplicative factor applied to the selected values
    :param prefix: dotted path of ``x`` inside the root object (for printing)
    :param verbose: if true, print the dotted path of every rescaled entry
    :returns: None
    """
    # Only dicts are ever modified, so only dicts are tracked. Recording the
    # ids of leaf values would bloat the set and raise the chance that a
    # recycled id makes an unrelated dict look already processed.
    if not isinstance(x, dict) or id(x) in ALREADY_SCALED:
        return
    # Iterate over a snapshot because entries are reassigned inside the loop.
    for k, v in list(x.items()):
        if k in names:
            if verbose:
                print(f"{prefix}.{k}")
            x[k] = v * scale
        elif k != "diagnostics":
            scale_tensors(v, names, scale, f"{prefix}.{k}", verbose=verbose)
    ALREADY_SCALED.add(id(x))
                
# Rescale the `coef` entries of all loaded fits in place, without printing.
scale_tensors(fits, verbose=False)

# Forecasting

In [None]:
import importlib
from pyrocov import mutrans_helpers

In [None]:
# Optional (disabled by default): change the condition to True to silence
# everything below ERROR level.
if False:
    logging.getLogger().setLevel(logging.ERROR)

## Day 542: Last available

In [None]:
# Optional (disabled by default): change the condition to True to restore
# INFO-level logging.
if False:
    logging.getLogger().setLevel(logging.INFO)

In [None]:
# Pick up any edits made to the helper module since it was imported.
importlib.reload(mutrans_helpers)

# The last entry of `fits` serves as the reference ("latest") fit.
i = len(fits) - 1
key, latest_fit = mutrans_helpers.get_fit_by_index(fits, i)
print(f"Max days of model: {key[8]}")

# Lineages shown in the manuscript figure.
strains_from_manuscript_figure = [
    'B.1.1.7',
    'B.1.617.2',
    'B.1.177',
    'B.1.429',
    'P.1',
    'B.1.1',
    'B.1',
    'B.1.427',
    'B.1.2',
    'B.1.177.4',
]

# Use a wide, short canvas for the forecast plot.
matplotlib.rcParams['figure.figsize'] = [10, 4]
mutrans_helpers.plot_fit_forecasts(
    fit=latest_fit,
    strains_to_show=strains_from_manuscript_figure,
    queries=['England'],
    filename="paper/forecasts/forecast_day_542.png",
)

## Day 346:  B.1.1.7 in UK

In [None]:
# Pick up any edits made to the helper module since it was imported.
importlib.reload(mutrans_helpers)

# Fit used for the day-346 backtest.
i = 14
key, fit_d346 = mutrans_helpers.get_fit_by_index(fits, i)
print(f"Max days of model: {key[8]}")

# Lineages shown in the manuscript figure.
strains_from_manuscript_figure = [
    'B.1.1.7',
    'B.1.617.2',
    'B.1.177',
    'B.1.429',
    'P.1',
    'B.1.1',
    'B.1',
    'B.1.427',
    'B.1.2',
    'B.1.177.4',
]

# Use a wide, short canvas for the forecast plot.
matplotlib.rcParams['figure.figsize'] = [10, 4]
mutrans_helpers.plot_fit_forecasts(
    fit=fit_d346,
    strains_to_show=strains_from_manuscript_figure,
    queries=['England'],
    # The latest fit is supplied alongside the earlier one as `future_fit`.
    future_fit=latest_fit,
    num_strains=2000,
    filename="paper/forecasts/forecast_day_346.png",
)

In [None]:
# Lineages to evaluate for the England query.
q = {"England": ['B.1', 'B.1.1.7', 'B.1.177', 'B.1.177.4']}

# Evaluate the day-346 fit against the latest fit; the result is the cell output.
mutrans_helpers.evaluate_fit_forecast(fit_d346, latest_fit, queries=q)

## Day 514: B.1.617.2 in UK

In [None]:
# Pick up any edits made to the helper module since it was imported.
importlib.reload(mutrans_helpers)

In [None]:
# Fit used for the day-514 backtest.
i = 26
key, fit_d514 = mutrans_helpers.get_fit_by_index(fits, i)
print(f"Max days of model: {key[8]}")

# Lineages shown in the manuscript figure.
strains_from_manuscript_figure = [
    'B.1.1.7',
    'B.1.617.2',
    'B.1.177',
    'B.1.429',
    'P.1',
    'B.1.1',
    'B.1',
    'B.1.427',
    'B.1.2',
    'B.1.177.4',
]

# Use a wide, short canvas for the forecast plot.
matplotlib.rcParams['figure.figsize'] = [10, 4]
mutrans_helpers.plot_fit_forecasts(
    fit=fit_d514,
    strains_to_show=strains_from_manuscript_figure,
    queries=['England'],
    # The latest fit is supplied alongside the earlier one as `future_fit`.
    future_fit=latest_fit,
    filename='paper/forecasts/forecast_day_514.png',
    forecast_periods_plot=2,
)

In [None]:
# Lineages to evaluate for the England query.
q = {"England": ['B.1', 'B.1.1', 'B.1.1.7', 'B.1.177', 'B.1.177.4']}

# Evaluate the day-514 fit against the latest fit; the result is the cell output.
mutrans_helpers.evaluate_fit_forecast(fit_d514, latest_fit, queries=q)

## Evaluate future forecast MAE in different windows

In [None]:
def get_fits_forecast_mae(n_fits=7, queries=None, key_eval="England MAE",
                          steps_for_mae_eval=10, n_intervals=3):
    """Collect one forecast score per (fit, skipped-interval count) pair.

    The first ``n_fits`` entries of the notebook-global ``fits`` are each
    evaluated against the notebook-global ``latest_fit`` with
    ``mutrans_helpers.evaluate_fit_forecast``, once for every
    ``n_intervals_skip`` value in ``range(steps_for_mae_eval)``.

    :param n_fits: number of fits to evaluate; indexes ``0 .. n_fits - 1``
        of ``fits`` are used
    :param queries: non-empty mapping passed to ``evaluate_fit_forecast`` as
        ``queries``, e.g. ``{"England": ["B.1", "B.1.1.7"]}``
    :param key_eval: key of the ``evaluate_fit_forecast`` result to keep
    :param steps_for_mae_eval: number of ``n_intervals_skip`` values to try,
        starting at 0
    :param n_intervals: forwarded unchanged to ``evaluate_fit_forecast``
    :returns: ``pandas.DataFrame`` with columns ``fit_days``, ``step`` and
        ``mae``, one row per evaluation
    :raises ValueError: if ``queries`` is missing or empty
    """
    # Validate explicitly: an `assert` would be stripped when running with -O.
    if not queries:
        raise ValueError("queries must be a non-empty mapping")

    # Result columns, filled in lockstep.
    mae = []
    step = []
    fit_days = []

    for fit_index in range(n_fits):
        fit_key, fit_plot = mutrans_helpers.get_fit_by_index(fits, fit_index)

        for steps_index in range(steps_for_mae_eval):
            fit_eval = mutrans_helpers.evaluate_fit_forecast(
                fit_plot,
                latest_fit,
                # Use the caller's argument, not whatever global `q` the
                # notebook happens to hold.
                queries=queries,
                n_intervals=n_intervals,
                n_intervals_skip=steps_index)

            mae.append(fit_eval[key_eval].item())
            step.append(steps_index)
            # Element 8 of the fit key is what the notebook prints as
            # "Max days of model".
            fit_days.append(fit_key[8])

    return pd.DataFrame({'fit_days': fit_days, 'step': step, 'mae': mae})

In [None]:
# Lineages to evaluate for the England query.
q = {"England": ['B.1', 'B.1.1', 'B.1.1.7', 'B.1.177', 'B.1.177.4']}

# Score the first 10 fits.
england_forecasts_eval = get_fits_forecast_mae(queries=q, n_fits=10)

In [None]:
# One line per fit (coloured by `fit_days`): score against the step index.
sns_plot = sns.lineplot(
    data=england_forecasts_eval,
    x="step",
    y="mae",
    hue="fit_days",
)

# Save the figure that owns the returned axes.
fig = sns_plot.get_figure()
fig.savefig('paper/forecasts/England_MAE_forecast.png')