In [None]:
import pickle
from collections import Counter

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoNormal
from pyro.optim import ClippedAdam

from pyrophylo.pangolin import find_edges, canonize

In [None]:
# Read the pickled column table (presumably produced by an earlier
# preprocessing step -- confirm where results/gisaid.columns.pkl comes from).
# NOTE: unpickling can execute arbitrary code, so only load trusted files.
with open("results/gisaid.columns.pkl", "rb") as pickle_file:
    columns = pickle.load(pickle_file)
# Show the available column names, then the number of entries in "day".
print(columns.keys(), len(columns["day"]), sep="\n")

In [None]:
# Summarize how often each lineage label occurs and list the 12 most common.
lineages = columns["lineage"]
lineage_counts = Counter(lineages)
# A Counter has one key per distinct label, so its length is the number of
# distinct lineages.
print(f"Top 12 of {len(lineage_counts)} lineages")
print("-" * 30)
for lineage, count in lineage_counts.most_common(12):
    print(f"{count: >10d} {lineage}")

In [None]:
# Tally samples per (week, location, lineage), assigning dense integer ids to
# locations and lineages in order of first appearance.
sparse_data = Counter()
location_id = {}
lineage_id = {}
samples = zip(columns["day"], columns["location"], columns["lineage"])
for day, raw_location, lineage in samples:
    # Normalize whitespace around the " / "-separated location fields.
    fields = [field.strip() for field in raw_location.split(" / ")]
    if len(fields) < 2:
        # Skip locations that have only a single field.
        continue
    place = " / ".join(fields)
    place_idx = location_id.setdefault(place, len(location_id))
    strain_idx = lineage_id.setdefault(lineage, len(lineage_id))
    # Bin days into 7-day weeks.
    sparse_data[day // 7, place_idx, strain_idx] += 1

# Densify into a float tensor of shape (weeks, locations, lineages).
# T is taken over every row's day, including rows skipped above.
T = 1 + max(columns["day"]) // 7
P, S = len(location_id), len(lineage_id)
dense_data = torch.zeros(T, P, S)
for index, count in sparse_data.items():
    # `index` is the (week, location, lineage) key tuple.
    dense_data[index] = count
print(dense_data.shape)

In [None]:
# Ask find_edges for pairs of related lineage names, then translate each name
# to its integer id so the model can index per-lineage tensors with them.
edge_names = find_edges(list(lineage_id))
edges = torch.tensor(
    [[lineage_id[u], lineage_id[v]] for u, v in edge_names],
    dtype=torch.long,
)

In [None]:
def model(dense_data, edges):
    """
    Pyro model of relative strain growth observed as counts.

    Every strain has a single log growth rate; every place has its own
    per-strain initial offset. Strain proportions at each (time, place) are the
    softmax of ``log_init + log_rate * time`` and the counts are observed under
    a Dirichlet-multinomial likelihood.

    :param torch.Tensor dense_data: Counts whose shape is unpacked as
        ``(T, P, S)``: time steps, places, strains.
    :param torch.Tensor edges: Integer tensor of shape ``(num_edges, 2)``; each
        row holds the indices of two strains whose growth rates are tied.
    """
    num_steps, num_places, num_strains = dense_data.shape
    time_plate = pyro.plate("time", num_steps, dim=-2)
    place_plate = pyro.plate("place", num_places, dim=-1)
    # Each time step spans 7 days; express elapsed time in years.
    years = torch.arange(float(num_steps)) * 7 / 365.25

    # Relative growth rate depends on the strain only, not on time or place.
    rate_prior = dist.Normal(0, 1).expand([num_strains]).to_event(1)
    log_rate = pyro.sample("log_rate", rate_prior)

    # Related strains should have similar growth rates. The difference across
    # each edge is scored under a Cauchy, whose heavy tails permit occasional
    # large jumps.
    tree_scale = pyro.sample("tree_scale", dist.LogNormal(-5, 5))
    first, second = edges.unbind(-1)
    with pyro.plate("edges", len(edges), dim=-1):
        rate_gap = log_rate[..., first] - log_rate[..., second]
        pyro.sample("rate_change", dist.Cauchy(0, tree_scale), obs=rate_gap)

    # Places differ only in their initial per-strain offsets.
    # NOTE(review): this log-scale quantity has a positive-only LogNormal
    # prior; a Normal prior may have been intended -- confirm.
    with place_plate:
        init_prior = dist.LogNormal(0, 10).expand([num_strains]).to_event(1)
        log_init = pyro.sample("log_init", init_prior)

    # Observe overdispersed counts.
    # NOTE(review): "dispersion" scales the Dirichlet concentration directly,
    # so larger values mean less overdispersion -- confirm naming is intended.
    dispersion = pyro.sample("dispersion", dist.LogNormal(0, 1))
    with time_plate, place_plate:
        # Broadcasts to (time, place, strain); normalize over strains.
        logits = log_init + log_rate * years[:, None, None]
        base_rate = logits.softmax(dim=-1)
        likelihood = dist.DirichletMultinomial(
            # Largest per-(time, place) total; presumably just an upper bound
            # because totals differ per cell -- TODO confirm it is not used
            # when scoring observations.
            total_count=dense_data.sum(-1).max(),
            concentration=dispersion * base_rate,
            is_sparse=True,  # uses a faster algorithm
        )
        pyro.sample("obs", likelihood, obs=dense_data)

In [None]:
# Fit the model with stochastic variational inference using a mean-field
# AutoNormal guide.
#
# Pyro keeps learned parameters in a global store, so re-running this cell
# without clearing it would silently continue from the previous fit instead of
# retraining. Clear it and fix the seed so each run is fresh and reproducible.
pyro.clear_param_store()
pyro.set_rng_seed(0)

guide = AutoNormal(model)
optim = ClippedAdam({"lr": 0.01})
svi = SVI(model, guide, optim, Trace_ELBO())

num_steps = 501  # the extra step makes the final loss (step 500) get logged
for step in range(num_steps):
    loss = svi.step(dense_data, edges)
    if step % 10 == 0:
        print(f"step {step} loss = {loss:0.3g}")