# Plots for backtesting mutation growth rate paper

This notebook generates plots for the [paper/](paper/) directory. It assumes you've already run
```sh
make update                       # Downloads and preprocesses data.
source run_backtesting.sh         # Runs backtesting experiments
```
Note that `make update` takes a couple of hours the first time it is run (mostly spent on sequence alignment), and `mutrans.py` takes about 15 minutes on a GPU (much longer if no GPU is available).

In [None]:
import datetime
import math
import os
import pickle
import re
import logging
from collections import Counter, OrderedDict, defaultdict
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import pandas as pd
import torch
import pyro.distributions as dist
from pyrocov import mutrans, pangolin, stats
from pyrocov.stats import normal_log10bf
from pyrocov.util import pretty_print, pearson_correlation

In [None]:
# Configure the root logger at INFO level. Each record is prefixed with the
# milliseconds elapsed since the logging module was loaded (space-padded to
# nine characters), followed by the message.
logging.basicConfig(
    level=logging.INFO,
    format="%(relativeCreated) 9d %(message)s",
)

In [None]:
# Raise the root logger's verbosity to DEBUG (overrides the INFO level passed to basicConfig).
logging.getLogger().setLevel(logging.DEBUG)

In [None]:
# Notebook-wide matplotlib defaults: high-resolution square figures, gray axes,
# tight bounding boxes when saving, and a sans-serif font stack.
matplotlib.rcParams.update(
    {
        "figure.dpi": 200,
        "figure.figsize": [8, 8],
        "axes.edgecolor": "gray",
        "savefig.bbox": "tight",
        "font.family": "sans-serif",
        "font.sans-serif": ["Arial", "Avenir", "DejaVu Sans"],
    }
)

## Load data

In [None]:
# Load the pickled columns from results/gisaid.columns.pkl.
# NOTE(review): unpickling can execute arbitrary code; only load files that
# were generated locally (presumably by `make update` -- confirm).
with open("results/gisaid.columns.pkl", "rb") as f:
    columns = pickle.load(f)

# One entry per sample in the "lineage" column.
n_loaded_samples = len(columns["lineage"])
print("Loaded data from {} samples".format(n_loaded_samples))

In [None]:
# Inspect the container type of the loaded columns.
type(columns)

In [None]:
# List the names of the available columns.
columns.keys()

In [None]:
# Preview each column: print its name followed by its first two entries.
for k in columns.keys():
    print(k, columns[k][0:2])

Sanity checking case count time series:

In [None]:
if False:
    # Disabled sanity check of the weekly time series.
    # NOTE(review): `weekly_cases` and `weekly_strains` are not defined in any
    # earlier cell shown here, so they must be loaded before enabling this.
    panels = [
        (weekly_cases, "confirmed cases"),
        (weekly_strains.sum(-1), "sequenced samples"),
    ]
    for series, ylabel in panels:
        plt.figure(figsize=(8, 3))
        plt.plot(series, lw=1, alpha=0.5)
        plt.yscale("symlog", linthresh=10)
        plt.ylim(0, None)
        # Both panels share the x-range given by the length of `weekly_cases`.
        plt.xlim(0, len(weekly_cases) - 1)
        plt.xlabel("week after 2019-12-01")
        plt.ylabel(ylabel)

In [None]:
# locations = set(location_id)
# N_usa = sum(1 for k in locations if "/ USA /" in k)
# N_uk = sum(1 for k in locations if "/ United Kingdom /" in k)
# N_other = len(locations) - N_usa - N_uk
# print(N_usa, N_uk, N_other)

We'll account for epidemiological dynamics in the form of random drift on top of our logistic growth model. Since random drift is inversely proportional to the local number of infections, we'll need a new data source for the number of infections in each region. We'll use JHU's confirmed case counts time series as a proxy for the number of total infections in each region.

## Load trained models

In [None]:
# Load every saved fit onto the CPU, regardless of the device it was saved from.
fits = torch.load("results/mutrans.pt", map_location="cpu")

# List the key of each fit.
for key in fits.keys():
    print(key)

# Show which entries a single fit provides, using the first one as an example.
first_key = list(fits)[0]
fits[first_key].keys()

In [None]:
# Print the recorded `weekly_strains_shape` entry of every fit, one per line.
for fit in fits.values():
    print(fit["weekly_strains_shape"])

In [None]:
# Pick the second fit (in the iteration order of `fits`) for closer inspection.
selected_fit = list(fits.values())[1]
# Uncomment to dump its contents:
# pretty_print(selected_fit, max_items=40)

Scale `coef` by 1/100 in all results.

In [None]:
# ids of every object scale_tensors has finished visiting.
ALREADY_SCALED = set()


def scale_tensors(x, names={"coef"}, scale=0.01, prefix=""):
    """Rescale selected entries of a nested dict, in place.

    Walks ``x`` recursively. For each dict entry whose key is in ``names``,
    the value is replaced by ``value * scale`` and the entry's dotted path is
    printed. An entry keyed ``"diagnostics"`` (and not in ``names``) is
    skipped without descending into it. Non-dict values are never descended
    into. Every visited object's ``id`` is recorded in ``ALREADY_SCALED``;
    an object whose id is already recorded is skipped, so a dict reachable
    through several paths is scaled only once.

    Args:
        x: Object to process; only dicts are modified.
        names: Keys whose values get multiplied by ``scale``.
        scale: Multiplicative factor.
        prefix: Path accumulated so far, used only for the printed messages.

    Returns:
        None. ``x`` is modified in place.
    """
    obj_id = id(x)
    if obj_id in ALREADY_SCALED:
        return

    if isinstance(x, dict):
        # Snapshot the items because entries are reassigned during the walk.
        for key, value in list(x.items()):
            if key in names:
                print(f"{prefix}.{key}")
                x[key] = value * scale
                continue
            if key != "diagnostics":
                scale_tensors(value, names, scale, f"{prefix}.{key}")

    ALREADY_SCALED.add(obj_id)
                
# Multiply every `coef` entry in `fits` by 0.01, in place.
# NOTE(review): re-running this cell re-creates an empty ALREADY_SCALED, so the
# already-scaled `coef` values would be scaled again; reload `fits` first.
scale_tensors(fits)

In [None]:
# plt.figure(figsize=(2,2))
# plt.hist(selected_fit["params"]["local_time"].reshape(-1).numpy(), bins=50, density=True)
# plt.xlabel("local time shift");

## Forecasting

In [None]:
import importlib
from pyrocov import mutrans_helpers

In [None]:
# Reload mutrans_helpers so that edits to the module take effect without restarting the kernel.
importlib.reload(mutrans_helpers)

In [None]:
# Restrict the root logger to ERROR and above for the forecasting cells that follow.
logging.getLogger().setLevel(logging.ERROR)

## Day 525 (the latest prediction, used for comparison)

In [None]:
# Pick up any edits made to the helper module since it was imported.
importlib.reload(mutrans_helpers)

# This fit serves as the "latest" reference; later cells pass it as `future_fit`.
i = 15
key, latest_fit = mutrans_helpers.get_fit_by_index(fits, i)
print(f"Max days of model: {key[8]}")

# Lineages to draw (named after the manuscript figure).
strains_from_manuscript_figure = [
    "B.1.1.7", "B.1.617.2", "B.1.177", "B.1.429", "P.1",
    "B.1.1", "B.1", "B.1.427", "B.1.2", "B.1.177.4",
]

# Wide, short figures for the forecast panels; this changes the notebook-wide default.
matplotlib.rcParams["figure.figsize"] = [10, 4]
mutrans_helpers.plot_fit_forecasts(
    fit=latest_fit,
    strains_to_show=strains_from_manuscript_figure,
    queries=["England"],
)

## Day 225

In [None]:
# Pick up any edits made to the helper module since it was imported.
importlib.reload(mutrans_helpers)

# Select the fit at index 3 (the "Day 225" fit, per the section heading) and
# plot it with the helper's default arguments.
i = 3
key, fit = mutrans_helpers.get_fit_by_index(fits, i)
print(f"Max days of model: {key[8]}")
mutrans_helpers.plot_fit_forecasts(fit)

In [None]:
# Switch to the fit at index 16; `fit` is reused by the cells that follow.
i = 16
key, fit = mutrans_helpers.get_fit_by_index(fits, i)
print(f"Max days of model: {key[8]}")

In [None]:
# Lineages to draw (named after the manuscript figure).
strains_from_manuscript_figure = [
    "B.1.1.7", "B.1.617.2", "B.1.177", "B.1.429", "P.1",
    "B.1.1", "B.1", "B.1.427", "B.1.2", "B.1.177.4",
]

In [None]:
# Query the helper for strains available in the current `fit` (num_strains=10) and display them.
# NOTE(review): the selection criterion is defined in mutrans_helpers.get_available_strains.
avail_strains = mutrans_helpers.get_available_strains(fit, num_strains=10)
avail_strains

In [None]:
# Plot forecasts for the current `fit`, restricted to the manuscript-figure strains.
mutrans_helpers.plot_fit_forecasts(fit, strains_to_show=strains_from_manuscript_figure)

# Forecasts at the point where B.1.1.7 arises in the UK

## Day 325

In [None]:
# Pick up any edits made to the helper module since it was imported.
importlib.reload(mutrans_helpers)

# Fit for this section (i = 9 was tried previously).
i = 7
key, fit = mutrans_helpers.get_fit_by_index(fits, i)
print(f"Max days of model: {key[8]}")

# Lineages to draw (named after the manuscript figure).
strains_from_manuscript_figure = [
    "B.1.1.7", "B.1.617.2", "B.1.177", "B.1.429", "P.1",
    "B.1.1", "B.1", "B.1.427", "B.1.2", "B.1.177.4",
]

# `latest_fit` must already exist (it is assigned in the Day 525 cell).
# NOTE(review): `future_fit` is presumably used to compare this forecast with
# the later fit -- confirm in mutrans_helpers.plot_fit_forecasts.
matplotlib.rcParams["figure.figsize"] = [10, 4]
mutrans_helpers.plot_fit_forecasts(
    fit=fit,
    strains_to_show=strains_from_manuscript_figure,
    queries=["England"],
    num_strains=1000,
    future_fit=latest_fit,
)

## Day 350

In [None]:
# Pick up any edits made to the helper module since it was imported.
importlib.reload(mutrans_helpers)

# Fit for this section (i = 9 was tried previously).
i = 8
key, fit = mutrans_helpers.get_fit_by_index(fits, i)
print(f"Max days of model: {key[8]}")

# Lineages to draw (named after the manuscript figure).
strains_from_manuscript_figure = [
    "B.1.1.7", "B.1.617.2", "B.1.177", "B.1.429", "P.1",
    "B.1.1", "B.1", "B.1.427", "B.1.2", "B.1.177.4",
]

# `latest_fit` must already exist (it is assigned in the Day 525 cell).
matplotlib.rcParams["figure.figsize"] = [10, 4]
mutrans_helpers.plot_fit_forecasts(
    fit=fit,
    strains_to_show=strains_from_manuscript_figure,
    queries=["England"],
    future_fit=latest_fit,
)

In [None]:
## Day 375

In [None]:
# Show which entries the latest fit provides.
latest_fit.keys()

In [None]:
# Pick up any edits made to the helper module since it was imported.
importlib.reload(mutrans_helpers)

i = 9
key, fit = mutrans_helpers.get_fit_by_index(fits, i)
print(f"Max days of model: {key[8]}")

# Lineages to draw (named after the manuscript figure).
strains_from_manuscript_figure = [
    "B.1.1.7", "B.1.617.2", "B.1.177", "B.1.429", "P.1",
    "B.1.1", "B.1", "B.1.427", "B.1.2", "B.1.177.4",
]

# `latest_fit` must already exist (it is assigned in the Day 525 cell).
matplotlib.rcParams["figure.figsize"] = [10, 4]
mutrans_helpers.plot_fit_forecasts(
    fit=fit,
    strains_to_show=strains_from_manuscript_figure,
    queries=["England"],
    future_fit=latest_fit,
)

In [None]:
# Time point where B.1.177 arises.
# Pick up any edits made to the helper module since it was imported.
importlib.reload(mutrans_helpers)

i = 5
key, fit = mutrans_helpers.get_fit_by_index(fits, i)
print(f"Max days of model: {key[8]}")

# Lineages to draw (named after the manuscript figure).
strains_from_manuscript_figure = [
    "B.1.1.7", "B.1.617.2", "B.1.177", "B.1.429", "P.1",
    "B.1.1", "B.1", "B.1.427", "B.1.2", "B.1.177.4",
]

# `latest_fit` must already exist (it is assigned in the Day 525 cell).
matplotlib.rcParams["figure.figsize"] = [10, 4]
mutrans_helpers.plot_fit_forecasts(
    fit=fit,
    strains_to_show=strains_from_manuscript_figure,
    queries=["England"],
    future_fit=latest_fit,
)

In [None]:
# Pick up any edits made to the helper module since it was imported.
importlib.reload(mutrans_helpers)

i = 6
key, fit = mutrans_helpers.get_fit_by_index(fits, i)
print(f"Max days of model: {key[8]}")

# Lineages to draw (named after the manuscript figure).
strains_from_manuscript_figure = [
    "B.1.1.7", "B.1.617.2", "B.1.177", "B.1.429", "P.1",
    "B.1.1", "B.1", "B.1.427", "B.1.2", "B.1.177.4",
]

# `latest_fit` must already exist (it is assigned in the Day 525 cell).
matplotlib.rcParams["figure.figsize"] = [10, 4]
mutrans_helpers.plot_fit_forecasts(
    fit=fit,
    strains_to_show=strains_from_manuscript_figure,
    queries=["England"],
    future_fit=latest_fit,
)

In [None]:
# Pick up any edits made to the helper module since it was imported.
importlib.reload(mutrans_helpers)

i = 7
key, fit = mutrans_helpers.get_fit_by_index(fits, i)
print(f"Max days of model: {key[8]}")

# Lineages to draw (named after the manuscript figure).
strains_from_manuscript_figure = [
    "B.1.1.7", "B.1.617.2", "B.1.177", "B.1.429", "P.1",
    "B.1.1", "B.1", "B.1.427", "B.1.2", "B.1.177.4",
]

# `latest_fit` must already exist (it is assigned in the Day 525 cell).
matplotlib.rcParams["figure.figsize"] = [10, 4]
mutrans_helpers.plot_fit_forecasts(
    fit=fit,
    strains_to_show=strains_from_manuscript_figure,
    queries=["England"],
    future_fit=latest_fit,
)

In [None]:
# Pick up any edits made to the helper module since it was imported.
importlib.reload(mutrans_helpers)

i = 8
key, fit = mutrans_helpers.get_fit_by_index(fits, i)
print(f"Max days of model: {key[8]}")

# Lineages to draw (named after the manuscript figure).
strains_from_manuscript_figure = [
    "B.1.1.7", "B.1.617.2", "B.1.177", "B.1.429", "P.1",
    "B.1.1", "B.1", "B.1.427", "B.1.2", "B.1.177.4",
]

# `latest_fit` must already exist (it is assigned in the Day 525 cell).
matplotlib.rcParams["figure.figsize"] = [10, 4]
mutrans_helpers.plot_fit_forecasts(
    fit=fit,
    strains_to_show=strains_from_manuscript_figure,
    queries=["England"],
    future_fit=latest_fit,
)

## How often do we identify correct future dominant strain?