## Scalable tree clustering with Pyro

Reproducing the results in this notebook requires first obtaining GISAID data and then running:
```sh
python preprocess_gisaid.py --align
```

In [None]:
from collections import Counter
import logging
import torch
import pyro
import pyro.distributions as dist
from opt_einsum import contract as einsum
from pyro.infer.autoguide import AutoDelta, init_to_sample
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
from pyro.distributions.spanning_tree import make_complete_graph
from pyrophylo.cluster import sample_diverse_clusters
import umap
import matplotlib.pyplot as plt

# Prefix every INFO-or-higher log record with the number of milliseconds
# elapsed since the logging module was loaded, padded to a width of 9.
logging.basicConfig(format="%(relativeCreated) 9d %(message)s", level=logging.INFO)

In [None]:
# Load the preprocessed results file (presumably written by the
# preprocess_gisaid.py step described at the top of this notebook).
# NOTE(review): torch.load unpickles the file, so only open trusted results.
result = torch.load("results/gisaid.align.pt")
print(result.keys())

# Keep only the first 512 feature columns and cast them to float32.
features = result["features"]
data = features[:, :512].to(torch.float32)
print(data.shape)

First let's sample a diverse set of points, which we'll use for initialization below. We'll use reservoir sampling, weighting by diversity.

In [None]:
%%time
# Draw a diverse set of cluster seeds from the data. `init_clusters` is used
# to initialise the model's centers; `weights` is plotted in the next cell.
num_clusters = 100
radius = 10.0
init_clusters, weights = sample_diverse_clusters(
    data,
    num_clusters=num_clusters,
    radius=radius,
    log_every=100000,
)

In [None]:
# Show the weights returned by the cluster sampler on a logarithmic y-axis.
plt.figure(figsize=(8, 3), dpi=300)
plt.semilogy(weights);

This model aims to minimize the sum of squares of mutation distances along a coarse phylogenetic tree consisting of ``K ~ 100`` internal nodes arranged in a spanning tree plus ``N ~ 300000`` leaves connected to the internal nodes. We marginalize over the spanning tree and over leaf-internal connections, and train on minibatches of leaves.

In [None]:
%%time
# Side channel between the model and the training loop: model() stores the
# most recent minibatch logits under LOG["logits"], and the training loop
# reads them back to report perplexity.
LOG = {}

def make_tree_dist(centers, radius=1.):
    """Build a spanning-tree distribution over the given centers.

    Each edge of the graph returned by ``make_complete_graph`` receives a
    logit equal to the negated square of the L1 distance between its two
    endpoint centers, measured in units of ``radius``.

    Args:
        centers: tensor indexed by center along its first dimension; the
            distance is summed over the last dimension.
        radius: length scale that the distances are divided by.

    Returns:
        A ``dist.SpanningTree`` parameterised by the edge logits.
    """
    endpoints = make_complete_graph(len(centers))
    src, dst = endpoints.unbind(0)
    l1_distance = torch.sum(torch.abs(centers[src] - centers[dst]), dim=-1)
    scaled = l1_distance / radius
    return dist.SpanningTree(-torch.square(scaled))

def model(data, num_clusters=100, radius=1., batch_size=1024):
    """Pyro model: cluster centers joined by a tree, with leaves attached.

    Samples a ``(num_clusters, P)`` tensor of centers from ``Uniform(0, 1)``,
    adds the log partition function of the spanning-tree distribution over
    those centers as the factor ``"internal"``, and, for a random minibatch of
    rows of ``data``, adds the factor ``"leaf"``: the logsumexp over centers of
    the negated squared L1 distance (in units of ``radius``) from each row to
    each center.

    Side effect: the minibatch logits are stored, detached from the autograd
    graph, in the module-level ``LOG["logits"]``.

    Args:
        data: 2-D tensor of shape ``(N, P)``; rows are subsampled.
        num_clusters: number of centers ``K``.
        radius: length scale dividing all L1 distances.
        batch_size: requested minibatch size; clamped to ``N``.
    """
    N, P = data.shape
    K = num_clusters
    # Clamp so that datasets with fewer rows than batch_size are processed as
    # a single full batch instead of requesting more rows than exist.
    batch_size = min(batch_size, N)
    centers = pyro.sample("centers", dist.Uniform(0, 1).expand([K, P]).to_event(2))
    tree_dist = make_tree_dist(centers, radius)
    pyro.factor("internal", tree_dist.log_partition_function)

    with pyro.plate("data", N, subsample_size=batch_size) as ind:
        distance = torch.cdist(data[ind], centers, p=1)
        logits = (distance / radius).square().neg()
        # detach() rather than .data: same values, but still tracked by
        # autograd's in-place modification checks.
        LOG["logits"] = logits.detach()
        pyro.factor("leaf", logits.logsumexp(-1))

def init_loc_fn(site):
    """Choose the initial value for one sample site of the guide.

    The ``"centers"`` site starts from ``init_clusters`` scaled by 0.99 and
    shifted by 0.005; every other site is delegated to ``init_to_sample``.
    """
    if site["name"] != "centers":
        return init_to_sample(site)
    # The affine map sends [0, 1] to [0.005, 0.995]; presumably this keeps the
    # start strictly inside the Uniform(0, 1) support of "centers" -- this
    # assumes init_clusters lies in [0, 1], TODO confirm.
    return init_clusters * 0.99 + 0.005

# Point-estimate guide whose "centers" site is initialised by init_loc_fn.
guide = AutoDelta(model, init_loc_fn=init_loc_fn)

# Start from an empty parameter store and a fixed seed.
pyro.clear_param_store()
pyro.set_rng_seed(20201223)
optim = Adam({"lr": 0.01, "betas": (0.8, 0.99)})
svi = SVI(model, guide, optim, Trace_ELBO(max_plate_nesting=1))

losses = []
for step in range(501):
    # Record the loss divided by the number of elements in the data tensor.
    loss = svi.step(data, num_clusters, radius) / data.numel()
    losses.append(loss)
    if step % 50 != 0:
        continue
    with torch.no_grad():
        # Normalise the latest minibatch logits over centers and average over
        # the minibatch to get the mean assignment probability per center.
        batch_logits = LOG["logits"]
        log_norm = batch_logits.logsumexp(-1, True)
        p = (batch_logits - log_norm).exp().mean(0)
        # Perplexity = exp(entropy); the clamp guards against log(0) = -inf.
        entropy = (-p.log().clamp(min=-100) * p).sum()
        perplexity = entropy.exp()
    print(f"step {step: >4d} loss = {loss:0.3g}\tperplexity = {perplexity:0.2f}")

plt.figure(figsize=(8, 2), dpi=300)
plt.plot(losses, lw=1);

In [None]:
# Read the fitted centers out of the guide and report how far they moved
# (total L1 distance) from their initial values.
with torch.no_grad():
    centers = guide.median()["centers"]
    drift = (centers - init_clusters).abs().sum()
    print(f"distance from start = {drift:0.1f}")

# 2-D layout of the centers: the first two PCA components by default; set the
# flag to False to use UMAP instead.
use_pca = True
if use_pca:
    U, S, V = torch.pca_lowrank(centers)
    x, y = U[:, :2].unbind(-1)
else:
    x, y = torch.from_numpy(umap.UMAP().fit_transform(centers)).T

# Sample one spanning tree over the centers, using a much smaller radius than
# the one used during training.
edges = make_tree_dist(centers, radius=0.01).sample()

# Figure 1: L1 edge lengths of the sampled tree, longest first.
plt.figure(figsize=(8, 2), dpi=300)
heads, tails = edges.unbind(-1)
d = (centers[heads] - centers[tails]).abs().sum(-1)
plt.plot(torch.sort(d, descending=True).values, "bo", markersize=2)
plt.text(len(d) * 0.8, d.max() * 0.8, f"total length = {d.sum():0.1f}",
         horizontalalignment='center', verticalalignment='center')
plt.ylabel("edge length")

# Figure 2: the tree drawn in the 2-D layout, each edge labelled at its
# midpoint with its L1 length.
plt.figure(figsize=(8, 6), dpi=300)
xs, ys = [], []
for v1, v2 in edges:
    # A trailing None breaks the polyline so edges are drawn as separate segments.
    xs += [x[v1], x[v2], None]
    ys += [y[v1], y[v2], None]
    length = (centers[v1] - centers[v2]).abs().sum()
    mid_x = (x[v1] + x[v2]) / 2
    mid_y = (y[v1] + y[v2]) / 2
    plt.text(mid_x, mid_y, "{:0.1f}".format(length),
             fontsize=6, horizontalalignment='center', verticalalignment='center')
plt.plot(xs, ys, lw=0.5)
plt.scatter(x, y)
plt.xticks(())
plt.yticks(())
plt.box(False);