# Plots for backtesting mutation growth rate paper

This notebook generates plots for the [paper/](paper/) directory. This assumes you've already run
```sh
make update                       # Downloads and preprocesses data.
python mutrans.py --backtesting-max-day 150,200,250,300,350,400,450,500,550 # Fits models with different data truncations
```
Note that `make update` takes a couple of hours the first time it is run (mostly in sequence alignment), and `mutrans.py` takes about 15 minutes on a GPU (and much longer if no GPU is available).

In [None]:
import datetime
import math
import os
import pickle
import re
import logging
from collections import Counter, OrderedDict, defaultdict
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import pandas as pd
import torch
import pyro.distributions as dist
from pyrocov import mutrans, pangolin, stats
from pyrocov.stats import normal_log10bf
from pyrocov.util import pretty_print, pearson_correlation

In [None]:
# Configure logging: each message is prefixed with %(relativeCreated), i.e. the
# milliseconds elapsed since the logging module was loaded, padded to 9 characters.
logging.basicConfig(format="%(relativeCreated) 9d %(message)s", level=logging.INFO)

In [None]:
# Raise the root logger's verbosity to DEBUG (overrides the INFO level configured above).
logging.getLogger().setLevel(logging.DEBUG)

In [None]:
# Set matplotlib params: all notebook-wide figure defaults are applied in one
# update so the full configuration can be read at a glance.
matplotlib.rcParams.update({
    "figure.dpi": 200,
    'figure.figsize': [8, 8],
    "axes.edgecolor": "gray",
    "savefig.bbox": "tight",
    'font.family': 'sans-serif',
    # Fonts are tried in this order.
    'font.sans-serif': ['Arial', 'Avenir', 'DejaVu Sans'],
})

## Load data

In [None]:
# Load the pickled column data and report how many entries the "lineage" column has.
# NOTE(review): pickle.load can execute arbitrary code; only load a file you generated
# yourself (presumably via the preprocessing step described at the top — confirm).
with open("results/gisaid.columns.pkl", "rb") as f:
    columns = pickle.load(f)
print("Loaded data from {} samples".format(len(columns["lineage"])))

Sanity checking case count time series:

In [None]:
# Disabled sanity-check plots: the `if False` guard keeps this body from running.
# NOTE(review): `weekly_cases` and `weekly_strains` are not defined in any cell
# visible above, so enabling this block as-is would presumably raise NameError.
if False:
    # Weekly confirmed cases on a symlog y-axis (linear below 10).
    plt.figure(figsize=(8, 3))
    plt.plot(weekly_cases, lw=1, alpha=0.5)
    plt.yscale("symlog", linthresh=10)
    plt.ylim(0, None)
    plt.xlim(0, len(weekly_cases) - 1)
    plt.xlabel("week after 2019-12-01")
    plt.ylabel("confirmed cases");

    # Weekly sequenced samples, summed over the last axis of `weekly_strains`.
    plt.figure(figsize=(8, 3))
    plt.plot(weekly_strains.sum(-1), lw=1, alpha=0.5)
    plt.yscale("symlog", linthresh=10)
    plt.ylim(0, None)
    plt.xlim(0, len(weekly_cases) - 1)
    plt.xlabel("week after 2019-12-01")
    plt.ylabel("sequenced samples");

In [None]:
# locations = set(location_id)
# N_usa = sum(1 for k in locations if "/ USA /" in k)
# N_uk = sum(1 for k in locations if "/ United Kingdom /" in k)
# N_other = len(locations) - N_usa - N_uk
# print(N_usa, N_uk, N_other)

We'll account for epidemiological dynamics in the form of random drift on top of our logistic growth model. Since random drift is inversely proportional to the local number of infections, we'll need a new data source for the number of infections in each region. We'll use JHU's confirmed case counts time series as a proxy for the number of total infections in each region.

## Load trained models

In [None]:
# Load the saved fits; map_location="cpu" remaps stored tensors onto the CPU.
# NOTE(review): `fits` is used below as a dict of per-configuration results,
# presumably written by mutrans.py — confirm.
fits = torch.load("results/mutrans.pt", map_location="cpu")
first_key = list(fits.keys())[0]
# Print every fit key, then display the field names of the first fit as the cell output.
for key in fits:
    print(key)
fits[first_key].keys()

In [None]:
# Print the recorded "weekly_strains_shape" entry of each fit.
for fit in fits.values():
    print(fit["weekly_strains_shape"])

In [None]:
# `best_fit` is simply the second entry of `fits` (index 1); no model
# selection is performed here.
best_fit = list(fits.values())[1]
# Uncomment to inspect the structure of the selected fit:
#pretty_print(best_fit, max_items=40)

Scale `coef` by 1/100 in all results.

In [None]:
# ids of every object that has already been visited by scale_tensors().
ALREADY_SCALED = set()


def scale_tensors(x, names={"coef"}, scale=0.01, prefix=""):
    """Recursively multiply selected entries of a nested dict by ``scale``, in place.

    Args:
        x: Object to process. Only dicts are traversed; anything else is
            merely recorded as visited.
        names: Keys whose values are replaced by ``value * scale``.
        scale: Multiplicative factor applied to matching values.
        prefix: Dotted path of ``x`` within the outermost object; used only
            for the progress line printed for each scaled entry.

    Returns:
        None. Matching entries are rewritten inside ``x``.

    Each visited object's ``id`` is stored in ``ALREADY_SCALED`` so that a
    second call on the same object is a no-op. Values stored under the key
    ``"diagnostics"`` are never traversed.
    """
    marker = id(x)
    if marker in ALREADY_SCALED:
        return
    if isinstance(x, dict):
        # Iterate over a snapshot because entries are reassigned while looping.
        for entry_key, entry_value in list(x.items()):
            if entry_key in names:
                print(f"{prefix}.{entry_key}")
                x[entry_key] = entry_value * scale
                continue
            if entry_key != "diagnostics":
                scale_tensors(entry_value, names, scale, f"{prefix}.{entry_key}")
    ALREADY_SCALED.add(marker)
                
# Apply the 1/100 scaling to every "coef" entry reachable from the loaded fits (in place).
scale_tensors(fits)

In [None]:
# Histogram of the `local_time` parameter of `best_fit`, flattened to 1-D.
plt.figure(figsize=(2,2))
plt.hist(best_fit["params"]["local_time"].reshape(-1).numpy(), bins=50, density=True)
plt.xlabel("local time shift");


## Assess model fitness

In [None]:
def plot_fits():
    """Plot SVI training diagnostics for every entry of the global ``fits``.

    One log-log figure is drawn per fit. It shows the loss curve (dashed
    black) together with every series recorded in ``fit["series"]``:

    * series whose name starts with ``"Guide."``, ends with ``"_centered"``,
      or equals ``"local_time"`` are drawn thin and faint in the background;
    * all remaining series (except ``"loss"``) are drawn as regular lines
      with a white halo underneath for legibility.

    Within each group, series are ordered by decreasing mean of ``log1p``.
    The title reports the median of the last 201 losses divided by the number
    of nonzero entries of ``fit["weekly_strains"]``, plus every one-element
    tensor found in ``fit["median"]`` (falling back to ``fit["mean"]``).
    """
    for key, fit in fits.items():
        weekly_strains = fit['weekly_strains']
        num_nonzero = int(torch.count_nonzero(weekly_strains))
        median = fit.get("median", fit.get("mean", {}))
        losses = fit["losses"]
        plt.figure(figsize=(8, 7))
        # 1-based step axis, so that step 1 is representable on the log x-axis.
        time = np.arange(1, 1 + len(losses))
        # Plot the loss against the same 1-based axis as every other series;
        # without an explicit x it would start at 0, which the log axis cannot
        # show and which shifts the curve by one step.
        plt.plot(time, losses, "k--", label="loss")
        locs = []
        grads = []
        for name, series in fit["series"].items():
            # Negated so that an ascending sort puts the largest series first.
            rankby = -torch.tensor(series).log1p().mean().item()
            if name.startswith("Guide."):
                name = name[len("Guide."):].replace("$$$", ".")
                grads.append((name, series, rankby))
            elif name.endswith("_centered") or name == "local_time":
                grads.append((name, series, rankby))
            elif name != "loss":
                locs.append((name, series, rankby))
        locs.sort(key=lambda x: x[-1])
        grads.sort(key=lambda x: x[-1])
        for name, series, _ in locs:
            plt.plot(time, series, label=name)
        # White halo beneath each foreground line.
        for name, series, _ in locs:
            plt.plot(time, series, color="white", lw=3, alpha=0.3, zorder=-1)
        for name, series, _ in grads:
            plt.plot(time, series, lw=1, alpha=0.3, label=name, zorder=-2)
        plt.yscale("log")
        plt.xscale("log")
        plt.xlim(1, len(losses))
        plt.legend(loc="upper left", fontsize=8)
        plt.xlabel("SVI step (duration = {:0.1f} minutes)".format(fit["walltime"]/60))
        loss = np.median(losses[-201:]) / num_nonzero
        scalars = " ".join([f"L={loss:0.6g}"] + [
            "{}={:0.3g}".format(
                # Abbreviate e.g. "some_param_name" to "SPN".
                "".join(p[0] for p in k.split("_")).upper(),
                # .item() so one-element tensors of any rank can be formatted;
                # only 0-dim tensors accept a numeric format spec directly.
                v.item(),
            )
            for k, v in median.items()
            if v.numel() == 1
        ])
        plt.title("{} ({})\n{}".format(key[0], scalars, key[-1]))
plot_fits()

## Forecasting

In [None]:
import importlib
import mutrans_helpers

In [None]:
# Reload the import library helpers
# Re-import mutrans_helpers so that edits made to the module since it was
# first imported are picked up.
importlib.reload(mutrans_helpers)

In [None]:
importlib.reload(mutrans_helpers)
# Compute the forecast of one specific model (the fourth fit) and display its values.
k = list(fits.keys())
# NOTE(review): `max_days` holds the whole dict key of the fourth fit — confirm
# that this is the intended label.
max_days = k[3]
fit = fits[k[3]]

# The comma after `fit=fit` was missing, which made this cell a SyntaxError.
fc1 = mutrans_helpers.generate_forecast(
    fit=fit,
    queries=["England", "USA / California", "Brazil"])

mutrans_helpers.get_forecast_values(forecast=fc1)

In [None]:
# Forecast plot for a single model: the fourth entry of `fits`.
k = list(fits)
# NOTE(review): `max_days` is the full dict key of that fit and is interpolated
# verbatim into the output file name — confirm this is intended.
max_days = k[3]
fit = fits[max_days]

fc1 = mutrans_helpers.generate_forecast(
    queries=["England", "USA / California", "Brazil"],
    fit=fit,
)

# Save the figure with the fitted curves and their intervals overlaid.
mutrans_helpers.plot_forecast(
    forecast=fc1,
    plot_fit=True,
    plot_fit_ci=True,
    filename=f"results/max_days_{max_days}.png",
)

In [None]:
    max_days = k[3]
    fit = fits[k]
    fc1 = generate_forecast(fit=fit, queries=["England", "USA / California", "Brazil"])
    plot_forecast(forecast=fc1, filename= f"results/max_days_{max_days}.png")

In [None]:
# Generate and save a forecast plot for every fitted model.
for k in fits.keys():
    # NOTE(review): assumes element 8 of each fit key identifies the data
    # truncation used for that fit — confirm against mutrans.py.
    max_days = k[8]
    fit = fits[k]
    # The helpers are only reachable through the module: the notebook does
    # `import mutrans_helpers`, so the bare names would raise NameError.
    fc1 = mutrans_helpers.generate_forecast(fit=fit, queries=["England", "USA / California", "Brazil"])
    mutrans_helpers.plot_forecast(forecast=fc1, filename= f"results/max_days_{max_days}.png")

In [None]:
# Forecast from the ninth fit (index 8).
# A dangling `for i in len(fits)` header (no colon, no body, and an int is not
# iterable) used to open this cell and made it a SyntaxError; it was removed.
best_fit = list(fits.values())[8]
forecast1 = mutrans_helpers.generate_forecast(fit=best_fit, queries=["England", "USA / California", "Brazil"])
# NOTE(review): plot_forecast_2 is assumed to live in mutrans_helpers alongside
# generate_forecast; it is not defined in any visible cell — confirm.
mutrans_helpers.plot_forecast_2(forecast=forecast1)

In [None]:
# Select the sixth fit (index 5).
selected_fit = list(fits.values())[5]
# NOTE(review): this bare expression is not the last statement of the cell, so
# its value is not displayed.
selected_fit["mean"]["probs"].shape
# T,P,S shape

# NOTE(review): `plot_forecast` is called here unqualified and with a different
# signature (fit, queries, num_strains, filenames) than the
# `mutrans_helpers.plot_forecast(forecast=..., filename=...)` calls in earlier
# cells; no such function is defined in the visible cells — confirm its origin.
plot_forecast(selected_fit,
              queries=["England", "USA / California", "Brazil"],
              num_strains=10,
              filenames=["paper/forecast.png"])