# Mini version of mutrans.py model

In [None]:
import os
import logging
import math
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import pyro.distributions as dist
import pyro
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import ClippedAdam
from pyrocov import mutrans, pangolin, stats

# Notebook-wide setup: log bare messages at INFO level and apply shared
# matplotlib styling in a single rcParams update.
logging.basicConfig(level=logging.INFO, format="%(message)s")
matplotlib.rcParams.update({
    "figure.dpi": 200,
    "axes.edgecolor": "gray",
    "font.family": "sans-serif",
    "font.sans-serif": ["Arial", "Avenir", "DejaVu Sans"],
})

In [None]:
def load_data_subset(*args, **kwargs):
    """Return a subset of the GISAID dataset with JHU data merged in.

    The full dataset is read from a local cache file when that file exists;
    otherwise it is built with ``mutrans.load_gisaid_data()`` and written to
    the cache. All positional and keyword arguments are forwarded to
    ``mutrans.subset_gisaid_data``. The result of ``mutrans.load_jhu_data``
    on the subset is then merged into the subset via ``update``.
    """
    cache_path = "results/mutrans.data.single.pt"
    if not os.path.exists(cache_path):
        # Cache miss: build the full dataset once and persist it.
        full = mutrans.load_gisaid_data()
        torch.save(full, cache_path)
    else:
        full = torch.load(cache_path)
    subset = mutrans.subset_gisaid_data(full, *args, **kwargs)
    jhu = mutrans.load_jhu_data(subset)
    subset.update(jhu)
    return subset

# Load the subset for the United Kingdom with max_strains=20.
dataset = load_data_subset(["United Kingdom"], max_strains=20)
# Expose every dataset entry as a notebook-level variable; the plotting
# functions further down read names such as weekly_strains directly.
locals().update(dataset)
# Summarize each entry in key order: tensors by shape, everything else by len.
for key in sorted(dataset):
    value = dataset[key]
    if not isinstance(value, torch.Tensor):
        print(f"{key} \t{type(value).__name__} of size {len(value)}")
        continue
    print(f"{key} \t{type(value).__name__} of shape {tuple(value.shape)}")

In [None]:
# Fit the model to the subset via mutrans.fit_svi; the returned mapping is
# read by the plotting cells (entries "losses", "series", "median", "walltime").
fit = mutrans.fit_svi(
    dataset,
    # Model / guide selection.
    model_type="overdispersed",
    guide_type="full",
    init_data="",
    # Optimization settings.
    num_steps=2001,
    num_particles=1,
    learning_rate=0.05,
    learning_rate_decay=1,
    clip_norm=10.0,
    # Seeding and progress reporting.
    seed=20210319,
    log_every=100,
)

In [None]:
def plot_fit():
    """Plot the SVI loss and every recorded diagnostic series on log-log axes.

    Reads the notebook-level globals ``fit`` (its "losses", "series",
    "median" and "walltime" entries) and ``weekly_strains``.

    Series whose name contains "Guide." are drawn as thin translucent lines
    behind the others; the series named "loss" is skipped because the loss is
    drawn separately from ``fit["losses"]``.

    The title reports the median of the last 201 losses divided by the number
    of nonzero entries of ``weekly_strains`` (L), followed by the fitted
    medians of "concentration" (C), "mislabel" (M), "feature_scale" (F) and
    "place_scale" (P); missing optional entries fall back to inf / 0.
    """
    num_nonzero = int(torch.count_nonzero(weekly_strains))
    median = fit["median"]
    losses = fit["losses"]
    plt.figure(figsize=(8, 7))
    # Steps are numbered from 1 so every point is representable on a log axis.
    time = np.arange(1, 1 + len(losses))
    # Plot the loss against the same 1-based step axis as the other series.
    # Without an explicit x it would be plotted at 0..N-1: shifted one step
    # relative to the other curves, with its first point at x=0, which the
    # log-scaled x-axis cannot display.
    plt.plot(time, losses, "k--", label="loss")
    locs = []
    grads = []
    for name, series in fit["series"].items():
        # Negated so that the ascending sorts below put the largest series first.
        rankby = -torch.tensor(series).log1p().mean().item()
        if "Guide." in name:
            # Keep only the part after "Guide." and restore dots in the name.
            name = name.split("Guide.")[-1].replace("$$$", ".")
            grads.append((name, series, rankby))
        elif name != "loss":
            locs.append((name, series, rankby))
    locs.sort(key=lambda x: x[-1])
    grads.sort(key=lambda x: x[-1])
    for name, series, _ in locs:
        plt.plot(time, series, label=name)
    # Wide translucent white underlay beneath each of the curves above.
    for name, series, _ in locs:
        plt.plot(time, series, color="white", lw=3, alpha=0.3, zorder=-1)
    for name, series, _ in grads:
        plt.plot(time, series, lw=1, alpha=0.5, label=name, zorder=-2)
    plt.yscale("log")
    plt.xscale("log")
    plt.xlim(1, len(losses))
    plt.legend(loc="upper left", fontsize=8)
    plt.xlabel("SVI step (duration = {:0.1f} minutes)".format(fit["walltime"] / 60))
    plt.title(
        "L={:0.4g} C={:0.3g} M={:0.3g} F={:0.3g} P={:0.3g}".format(
            np.median(losses[-201:]) / num_nonzero,
            float(median.get("concentration", math.inf)),
            float(median.get("mislabel", 0)),
            float(median["feature_scale"]),
            float(median.get("place_scale", 0)),
        )
    )
plot_fit()

In [None]:
def plot_forecast(queries=None, num_strains=10):
    """Plot predicted vs. observed strain proportions over time, one row per query.

    Reads the notebook-level globals ``fit``, ``dataset``, ``location_id``,
    ``lineage_id_inv``, ``weekly_cases`` and ``weekly_strains``.

    Each row shows, for the locations matching its query:
      * relative case counts (solid light line) and relative sample counts
        (dashed light line), each divided by its own maximum;
      * predicted strain proportions (lines) and observed strain proportions
        (dots) for the selected strains;
      * an "error" in the y-label: the negative mean DirichletMultinomial
        log-probability of the observed counts with concentration 40 * pred.

    Args:
        queries: ``None`` to use every key of ``location_id`` as its own
            query, a single string, or a list of strings. A location matches
            a query when the query is a substring of the location name.
        num_strains: Number of strains to draw, chosen as those with the
            largest total ``weekly_strains`` counts over all locations matched
            by any query. Must not exceed the 100 available colors.
    """
    if queries is None:
        queries = list(location_id)
    elif isinstance(queries, str):
        queries = [queries]
    _, axes = plt.subplots(len(queries), figsize=(8, 1 + 1.5 * len(queries)), sharex=True)
    # plt.subplots returns a bare Axes (not an array) for a single row.
    if not isinstance(axes, (list, np.ndarray)):
        axes = [axes]
    rate = fit["median"]["rate"]
    init = fit["median"]["init"]
    local_time = dataset["local_time"]
    probs = (init + rate * local_time[:, :, None]).softmax(-1)  # [T, P, S]
    predicted = probs * weekly_cases[:, :, None]  # [T, P, S]
    # Rank strains by total sample count across all queried locations.
    ids = torch.tensor([i for name, i in location_id.items()
                        if any(q in name for q in queries)])
    strain_ids = weekly_strains[:, ids].sum([0, 1]).sort(-1, descending=True).indices
    strain_ids = strain_ids[:num_strains]
    colors = [f"C{i}" for i in range(10)] + ["black"] * 90
    assert len(colors) >= num_strains
    light = "#bbbbbb"
    for row, (query, ax) in enumerate(zip(queries, axes)):
        ids = torch.tensor([i for name, i in location_id.items() if query in name])
        print(f"{query} matched {len(ids)} regions")
        # Background curves. The fmt strings carry only the line style: the
        # color comes from the keyword, and also naming a color in fmt makes
        # matplotlib warn about a redundant definition.
        counts = weekly_cases[:, ids].sum(1)
        counts /= counts.max()
        ax.plot(counts, "-", color=light, lw=0.8)
        counts = weekly_strains[:, ids].sum([1, 2])
        counts /= counts.max()
        ax.plot(counts, "--", color=light, lw=1)
        # Predicted proportions, clamped away from zero before normalizing.
        pred = predicted[:, ids].sum(1).clamp_(min=1e-8)
        pred /= pred.sum(-1, True)
        obs = weekly_strains[:, ids].sum(1)
        # Score the raw observed counts before they are normalized below.
        error = -dist.DirichletMultinomial(
            concentration=40*pred, validate_args=False).log_prob(obs).mean()
        obs.clamp_(min=1e-9)
        obs /= obs.sum(-1, True)
        for s, color in zip(strain_ids, colors):
            ax.plot(pred[:, s], color=color)
            strain = lineage_id_inv[s]
            # The lineage named "Q" is displayed as "B.1.1.7"; strain labels
            # are attached on the first row only.
            ax.plot(obs[:, s], color=color, lw=0, marker='o', markersize=3,
                    label={"Q": "B.1.1.7"}.get(strain, strain) if row == 0 else None)
        ax.set_ylim(0, 1)
        ax.set_yticks(())
        ax.set_ylabel("{}\nerror = {:0.1f}".format(query.replace(" / ", "\n"), error))
        ax.set_xlim(0, len(weekly_strains))
        if row == len(axes) - 1:
            ax.set_xlabel("Time (week after 2019-12-01)")
        if row == 0:
            ax.legend(loc="upper left", fontsize=6)
        elif row == 1:
            # Empty proxy artists so the second row carries a key that
            # explains the marker and line styles.
            ax.plot([], lw=0, marker='o', markersize=3, color='gray',
                    label="observed portion")
            ax.plot([], color='gray', label="predicted portion")
            ax.plot([], "-", color=light, lw=0.8, label="relative #cases")
            ax.plot([], "--", color=light, lw=1, label="relative #samples")
            ax.legend(loc="upper left", fontsize=8)
    plt.subplots_adjust(hspace=0)

plot_forecast()