# Plots for mutation growth rate paper

This notebook generates plots for the [paper/](paper/) directory. It assumes you've already run:
```sh
make update                       # Downloads data (~1hour).
make preprocess                   # Preprocesses data (~3days on a big machine).
python mutrans.py --vary-holdout  # Fits and crossvalidates model (~1hour GPU).
```

In [None]:
import datetime
import math
import os
import pickle
import re
import logging
from collections import Counter, OrderedDict, defaultdict
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import torch
import pyro.distributions as dist
from pyro.ops.tensor_utils import convolve
from pyrocov import mutrans, pangolin, stats
from pyrocov.stats import normal_log10bf
from pyrocov.util import (
    pretty_print, pearson_correlation, quotient_central_moments, generate_colors
)
from pyrocov.sarscov2 import GENE_TO_POSITION, GENE_STRUCTURE, aa_mutation_to_position

# Log at INFO level, prefixing each message with the elapsed milliseconds.
logging.basicConfig(level=logging.INFO, format="%(relativeCreated) 9d %(message)s")

# Figure defaults shared by every plot in this notebook.
matplotlib.rcParams.update({
    "figure.dpi": 200,
    "axes.edgecolor": "gray",
    "figure.facecolor": "white",
    "savefig.bbox": "tight",
    "savefig.pad_inches": 0.01,
    "font.family": "sans-serif",
    "font.sans-serif": ["Arial", "Avenir", "DejaVu Sans"],
    # 'text.usetex': True,
    "text.latex.preamble": r"\usepackage{amsfonts}",
})

## Load data

In [None]:
# Dataset-selection parameters. Both values are interpolated into the names of
# the result files that the following cells read, so they must match the
# values used when those files were produced.
max_num_clades = 3000
min_num_mutations = 1

In [None]:
%%time
def load_data():
    """Load the preprocessed dataset onto the CPU and merge in the JHU data.

    The file name is derived from the notebook-level ``max_num_clades`` and
    ``min_num_mutations`` settings.
    """
    filename = f"results/mutrans.data.single.{max_num_clades}.{min_num_mutations}.50.None.pt"
    loaded = torch.load(filename, map_location="cpu")
    jhu_entries = mutrans.load_jhu_data(loaded)
    loaded.update(jhu_entries)
    return loaded


dataset = load_data()
# Expose every dataset entry as a notebook-level variable (at cell top level
# locals() is the module namespace), so later cells can use e.g.
# `weekly_clades` directly.
locals().update(dataset)
# Summarize each entry in key order: tensors by shape, everything else by length.
for k in sorted(dataset):
    v = dataset[k]
    if not isinstance(v, torch.Tensor):
        print(f"{k} \t{type(v).__name__} of size {len(v)}")
    else:
        print(f"{k} \t{type(v).__name__} of shape {tuple(v.shape)}")

Create a dense mapping between fine clades and Pango lineages.

In [None]:
# Size of the count tensor and the total it sums to.
clade_counts_shape = weekly_clades.shape
print("{} x {} x {} = {}".format(*clade_counts_shape, clade_counts_shape.numel()))
print(int(weekly_clades.sum()))

# Density: fraction of nonzero cells, then fraction of cells that are nonzero
# at any index along the first axis.
is_observed = weekly_clades.ne(0)
print(is_observed.float().mean().item())
print(is_observed.any(0).float().mean().item())

In [None]:
# NOTE: pickle.load executes arbitrary code; only load files you produced yourself.
with open(f"results/columns.{max_num_clades}.pkl", "rb") as columns_file:
    columns = pickle.load(columns_file)
num_samples = len(columns["lineage"])
print("Loaded data from {} samples".format(num_samples))

In [None]:
# Prefer the nextclade counts file; on any failure fall back to the
# "aaSubstitutions" entry of results/stats.pkl.
# NOTE: pickle.load executes arbitrary code; only load files you produced yourself.
try:
    with open("results/nextclade.counts.pkl", "rb") as counts_file:
        all_mutations = pickle.load(counts_file)
except Exception:
    with open("results/stats.pkl", "rb") as stats_file:
        stats_payload = pickle.load(stats_file)
    all_mutations = stats_payload["aaSubstitutions"]
print(f"Loaded {len(all_mutations)} mutations")

Sanity checking case count time series:

In [None]:
# Disabled by default; flip the condition to draw the two sanity-check figures.
if False:
    for series, ylabel in (
        (weekly_cases, "confirmed cases"),
        (weekly_clades.sum(-1), "sequenced samples"),
    ):
        plt.figure(figsize=(8, 3))
        plt.plot(series, lw=1, alpha=0.5)
        plt.yscale("symlog", linthresh=10)
        plt.ylim(0, None)
        # Both figures share the x-range of the case-count series.
        plt.xlim(0, len(weekly_cases) - 1)
        plt.xlabel("week after 2019-12-01")
        plt.ylabel(ylabel)

In [None]:
# Count regions by country, based on substrings of the location names.
locations = set(location_id)
N_usa = sum("/ USA /" in name for name in locations)
N_uk = sum("/ United Kingdom /" in name for name in locations)
N_other = len(locations) - (N_usa + N_uk)
print(N_usa, N_uk, N_other)

We'll account for epidemiological dynamics in the form of random drift on top of our logistic growth model. Since random drift is inversely proportional to the local number of infections, we'll need a new data source for the number of infections in each region. We'll use JHU's confirmed case counts time series as a proxy for the number of total infections in each region.

## Load trained models

In [None]:
# Load the fitted results, mapping any saved tensors onto the CPU.
fits = torch.load("results/mutrans.pt", map_location="cpu")

In [None]:
# List the key of every entry in `fits`, one per line.
for key in fits:
    print(key)

In [None]:
# Use the first entry of `fits` (in iteration order) as the fit for the
# remaining cells.
# NOTE(review): assumes the first entry is the preferred model — confirm.
best_fit = list(fits.values())[0]
pretty_print(best_fit, max_items=40)

Scale `coef` by 1/100 in all results.

In [None]:
# Registry of id()s of the dicts that scale_tensors() has already processed.
# It prevents a dict reachable through several parents from being scaled more
# than once. NOTE: re-executing this assignment empties the registry, after
# which a further scale_tensors() call on the same object scales it again.
ALREADY_SCALED = set()

def scale_tensors(x, names=frozenset({"coef"}), scale=0.01, prefix=""):
    """
    Recursively multiply selected entries of a nested dict by ``scale``, in place.

    Every value stored under a key in ``names`` is replaced by ``value * scale``
    and its dotted path is printed. Sub-dicts stored under the key
    ``"diagnostics"`` are skipped entirely. Each dict is processed at most once
    per lifetime of ``ALREADY_SCALED`` (tracked by ``id()``), so calling this
    function again on the same object is a no-op.

    :param x: The object to traverse. Anything that is not a ``dict`` is
        ignored.
    :param names: Collection of keys whose values should be scaled.
    :param float scale: Multiplicative factor.
    :param str prefix: Dotted path of ``x`` within the root, used only for the
        printed messages.
    :returns: ``None``; ``x`` is modified in place.
    """
    if not isinstance(x, dict):
        # Leaves are scaled through their parent dict; there is nothing to
        # traverse here and no reason to record their id().
        return
    if id(x) in ALREADY_SCALED:
        return
    # Iterate over a snapshot because entries are reassigned during the loop.
    for k, v in list(x.items()):
        if k in names:
            print(f"{prefix}.{k}")
            x[k] = v * scale
        elif k == "diagnostics":
            continue
        else:
            scale_tensors(v, names, scale, f"{prefix}.{k}")
    ALREADY_SCALED.add(id(x))
                
# Scale every "coef" entry of `fits` in place.
# NOTE: re-running this cell resets ALREADY_SCALED above, so the same `fits`
# object would be scaled a second time; reload `fits` before re-running.
scale_tensors(fits)

In [None]:
# Show the items obtained by iterating best_fit["params"]
# (presumably the parameter names — confirm).
print(list(best_fit["params"]))

In [None]:
# Disabled diagnostic: plot the median "init" against the mean "init_loc".
# The `+ 0 * ...` term presumably broadcasts "init_loc" to the shape of
# "init" — TODO confirm.
if False:
    plt.plot(
        best_fit["mean"]["init_loc"] + 0 * best_fit["median"]["init"],
        best_fit["median"]["init"],
        "k.",
    )


## Assess model fitness

## Interpreting results

In [None]:
def plusminus(mean, std):
    """Stack ``mean - 1.96 * std``, ``mean`` and ``mean + 1.96 * std`` along a new leading axis."""
    half_width = 1.96 * std
    lower = mean - half_width
    upper = mean + half_width
    return torch.stack([lower, mean, upper])

def plot_forecast(fit, queries=None, num_strains=10, filenames=()):
    """
    Plot predicted and observed lineage portions over time, one row per query.

    Reads the notebook-level variables ``weekly_cases``, ``weekly_clades``,
    ``clade_id_to_lineage_id``, ``location_id``, ``location_id_inv`` and
    ``lineage_id_inv``.

    :param dict fit: A fit result; ``fit["mean"]["probs"]`` and
        ``fit["std"]["probs"]`` are read.
    :param queries: ``None`` (one row per entry of ``location_id``), a single
        string, or a list of strings. A region contributes to a row when the
        query is a substring of the region's name.
    :param int num_strains: Number of lineages to draw. Half are the most
        observed lineages before the final 8 time steps and the rest are the
        most observed lineages within the final 8 time steps, both counted
        over all queried regions. The two lists are concatenated without
        deduplication, so a lineage ranking highly in both is drawn twice.
    :param filenames: A path or an iterable of paths to save the figure to.
    :returns: ``None``; the figure is left open as the current pyplot figure.
    """
    if queries is None:
        queries = list(location_id)
    elif isinstance(queries, str):
        queries = [queries]
    if isinstance(filenames, str):
        # Treat a bare path like a one-element list, mirroring `queries`.
        filenames = [filenames]
    fig, axes = plt.subplots(len(queries), figsize=(8, 0.5 + 2.5 * len(queries)), sharex=True)
    if not isinstance(axes, (list, np.ndarray)):
        # plt.subplots returns a bare Axes when there is a single row.
        axes = [axes]
    dates = matplotlib.dates.date2num(mutrans.date_range(len(fit["mean"]["probs"])))
    forecast_steps = len(fit["mean"]["probs"]) - len(weekly_cases)
    assert forecast_steps >= 0
    probs = plusminus(fit["mean"]["probs"], fit["std"]["probs"])  # [3, T, P, L]
    # Extend the observed case counts over the forecast horizon using the
    # per-region mean, so predictions can be weighted by case counts.
    padding = 1 + weekly_cases.mean(0, True).expand(forecast_steps, -1)
    weekly_cases_ = torch.cat([weekly_cases, padding], 0)
    weekly_cases_.add_(1)  # avoid divide by zero
    predicted = probs * weekly_cases_[..., None]
    L = probs.shape[-1]
    # Aggregate clade counts into lineage counts along the last axis.
    weekly_lineages = weekly_clades.new_zeros(weekly_clades.shape[:-1] + (L,)).scatter_add_(
        -1, clade_id_to_lineage_id.expand_as(weekly_clades), weekly_clades
    )
    # Explicit dtype keeps this a valid index tensor even when nothing matches.
    ids = torch.tensor([i for i, name in enumerate(location_id_inv)
                        if any(q in name for q in queries)], dtype=torch.long)

    # Pick lineages to draw: the leaders before the final window plus the
    # leaders within it.
    num_recent_steps = 8
    T = weekly_lineages.shape[0]
    early_strain_ids = weekly_lineages[:(T - num_recent_steps), ids].sum([0, 1]).sort(
        -1, descending=True).indices
    late_strain_ids = weekly_lineages[(T - num_recent_steps):, ids].sum([0, 1]).sort(
        -1, descending=True).indices
    strain_ids = torch.cat((early_strain_ids[:(num_strains//2)],
                            late_strain_ids[:(num_strains - num_strains//2)]))

    colors = generate_colors()
    assert len(colors) >= num_strains
    light = "#bbbbbb"
    for row, (query, ax) in enumerate(zip(queries, axes)):
        ids = torch.tensor([i for i, name in enumerate(location_id_inv) if query in name],
                           dtype=torch.long)
        print(f"{query} matched {len(ids)} regions")
        if len(axes) > 1:
            # Background curves: case counts and sample counts, each
            # normalized by its own maximum.
            counts = weekly_cases[:, ids].sum(1)
            print(f"{query}: max {counts.max():g}, total {counts.sum():g}")
            counts /= counts.max()
            ax.plot(dates[:len(counts)], counts, linestyle="-", color=light, lw=0.8, zorder=-20)
            counts = weekly_lineages[:, ids].sum([1, 2])
            counts /= counts.max()
            ax.plot(dates[:len(counts)], counts, linestyle="--", color=light, lw=1, zorder=-20)
        # Sum over the matched regions, then normalize all three bands by the
        # total of the mean band so the mean portions sum to one.
        pred = predicted.index_select(-2, ids).sum(-2)
        pred /= pred[1].sum(-1, True).clamp_(min=1e-20)
        obs = weekly_lineages[:, ids].sum(1)
        obs /= obs.sum(-1, True).clamp_(min=1e-9)
        for s, color in zip(strain_ids, colors):
            lb, mean, ub = pred[..., s]
            ax.fill_between(dates, lb, ub, color=color, alpha=0.2, zorder=-10)
            ax.plot(dates, mean, color=color, lw=1, zorder=-9)
            lineage = lineage_id_inv[s]
            ax.plot(dates[:len(obs)], obs[:, s], color=color, lw=0, marker='o', markersize=3,
                    label=lineage if row == 0 else None)
        ax.set_ylim(0, 1)
        ax.set_yticks(())
        ax.set_ylabel(query.replace(" / ", "\n"))
        ax.set_xlim(dates.min(), dates.max())
        if row == 0:
            # First row carries the lineage legend.
            ax.legend(loc="upper left", fontsize=8 * (10 / num_strains) ** 0.8)
        elif row == 1:
            # Second row carries the legend explaining the line styles,
            # built from empty proxy artists.
            ax.plot([], linestyle="--", color=light, lw=1, label="relative #samples")
            ax.plot([], linestyle="-", color=light, lw=0.8, label="relative #cases")
            ax.plot([], lw=0, marker='o', markersize=3, color='gray',
                    label="observed portion")
            ax.fill_between([], [], [], color='gray', label="predicted portion")
            ax.legend(loc="upper left",)
    # The x axis is shared, so formatting the last row's axis is sufficient.
    ax.xaxis.set_major_locator(matplotlib.dates.MonthLocator())
    ax.xaxis.set_major_formatter(matplotlib.dates.DateFormatter("%b %Y"))
    plt.xticks(rotation=90)
    plt.subplots_adjust(hspace=0)
    for filename in filenames:
        fig.savefig(filename)

In [None]:
# Render one forecast figure per region and save it under
# paper/per_region_forecasts/.
os.makedirs("paper/per_region_forecasts", exist_ok=True)
for loc in location_id:
    name = loc.replace(' / ','_')
    print(name)
    plot_forecast(best_fit,
        queries=[loc],
        num_strains=15,
        filenames=[f"paper/per_region_forecasts/{name}.png"])
    # Close the figure just created: otherwise one open figure per region
    # accumulates in memory. The figures are saved to disk and are therefore
    # not shown inline.
    plt.close()