## Scalable tree clustering with Pyro

In [None]:
import torch
import pyro
import pyro.distributions as dist
from opt_einsum import contract as einsum
from pyro.infer.autoguide import AutoDelta, init_to_sample
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
from pyro.distributions.spanning_tree import make_complete_graph
import matplotlib.pyplot as plt

In [None]:
# Load the saved results dict from disk and pull out its "features" entry.
result = torch.load("results/gisaid.align.pt")
features = result["features"]
# Keep only the first 64 feature columns and cast them to float32.
data = features[:, :64].to(torch.float32)
print(data.shape)

In [None]:
# Module-level scratch dict. model() stores the logits of its most recent
# minibatch under the "logits" key, and the training loop reads them back to
# report a perplexity diagnostic.
LOG = {}

def make_tree_dist(probs, temperature=1.):
    """Build a ``dist.SpanningTree`` over the rows of ``probs``.

    Each pair of rows returned by ``make_complete_graph`` is scored by the
    L1 distance between the two rows (absolute differences summed over the
    last dimension). The edge logit is that distance multiplied by
    ``-1 / temperature``, so rows that are closer together get larger logits.

    Args:
        probs: tensor whose first dimension indexes the tree vertices.
        temperature: divisor applied to the distances; smaller values make
            the logits more sharply peaked.

    Returns:
        A ``dist.SpanningTree`` constructed from the edge logits.
    """
    num_vertices = len(probs)
    # The two rows of the complete-graph grid are used as endpoint indices.
    src, dst = make_complete_graph(num_vertices).unbind(0)
    l1_distance = torch.sum(torch.abs(probs[src] - probs[dst]), dim=-1)
    return dist.SpanningTree(l1_distance * (-1 / temperature))

def model(data, num_clusters=100, batch_size=1024, temperature=0.1):
    """Pyro model: cluster rows of ``data`` around tree-structured centers.

    Samples a ``(num_clusters, P)`` matrix ``probs`` from ``Uniform(0, 1)``
    and adds two factor terms to the log density:

    * ``"internal"``: the log partition function of the spanning-tree
      distribution built from ``probs`` by ``make_tree_dist``.
    * ``"leaf"``: for each subsampled data row, the logsumexp over clusters
      of the negative L1 distance to each row of ``probs``, scaled by
      ``1 / temperature``.

    Args:
        data: 2-D tensor, unpacked as ``N, P = data.shape``.
        num_clusters: number of rows ``K`` in ``probs``.
        batch_size: requested minibatch size for the ``"data"`` plate. It is
            clamped to ``N`` so datasets smaller than the batch size are
            processed in full; ``None`` is passed through unchanged.
        temperature: divisor applied to all L1 distances.

    Side effects:
        Stores the detached minibatch logits in ``LOG["logits"]``.
    """
    N, P = data.shape
    K = num_clusters
    probs = pyro.sample("probs", dist.Uniform(0, 1).expand([K, P]).to_event(2))
    tree_dist = make_tree_dist(probs, temperature)
    pyro.factor("internal", tree_dist.log_partition_function)

    # NOTE(review): pyro.plate is understood to reject a subsample_size larger
    # than the plate size, so clamp the request to the number of rows.
    subsample_size = None if batch_size is None else min(batch_size, N)
    with pyro.plate("data", N, subsample_size=subsample_size) as ind:
        # data[ind, None] inserts a cluster axis so the subtraction
        # broadcasts against the (K, P) probs; the sum reduces over P.
        distance = (probs - data[ind, None]).abs().sum(-1)
        logits = distance.mul(-1/temperature)
        # Keep a gradient-free copy for diagnostics outside the model.
        LOG["logits"] = logits.detach()
        pyro.factor("leaf", logits.logsumexp(-1))

# Fit a point estimate of the model's "probs" site with SVI and an AutoDelta
# guide whose initial value is drawn by init_to_sample.
num_clusters = 100
guide = AutoDelta(model, init_loc_fn=init_to_sample)

# Start from an empty parameter store and a fixed seed.
pyro.clear_param_store()
pyro.set_rng_seed(20201223)
svi = SVI(
    model,
    guide,
    Adam({"lr": 0.05}),
    Trace_ELBO(max_plate_nesting=1),
)
losses = []
for step in range(501):
    # Report the loss per element of the data tensor.
    loss = svi.step(data, num_clusters) / data.numel()
    losses.append(loss)
    if step % 50 != 0:
        continue
    with torch.no_grad():
        # Normalize each row of the logged minibatch logits into cluster
        # probabilities, then average over the rows of the minibatch.
        p = (LOG["logits"] - LOG["logits"].logsumexp(-1, keepdim=True)).exp().mean(0)
        # exp(entropy of p); the clamp keeps log(0) = -inf out of the product.
        perplexity = (p * p.log().clamp(min=-100).neg()).sum().exp()
    print(f"step {step: >4d} loss = {loss:0.3g}\tperplexity = {perplexity:0.2f}")
plt.figure(figsize=(8, 2), dpi=300)
plt.plot(losses, linewidth=1);

In [None]:
# Point estimate of the "probs" site from the fitted guide.
with torch.no_grad():
    probs = guide.median()["probs"]
# Use the first two columns of U from a low-rank PCA as 2-D plot coordinates.
# Keep this call before the tree sample below: both may draw random numbers.
U, S, V = torch.pca_lowrank(probs)
x, y = U[:, 0], U[:, 1]
# Draw one spanning tree at a lower temperature (0.01) than the model default.
edges = make_tree_dist(probs, temperature=0.01).sample()

# Figure 1: L1 length of every sampled edge, longest first.
plt.figure(figsize=(8, 2), dpi=300)
d = torch.sum(torch.abs(probs[edges[:, 0]] - probs[edges[:, 1]]), dim=-1)
plt.plot(torch.sort(d, descending=True).values, "bo", markersize=2)
plt.ylabel("edge length")

# Figure 2: the sampled tree drawn in the 2-D projection, each edge labelled
# with its L1 length at the midpoint.
plt.figure(figsize=(8, 6), dpi=300)
xs, ys = [], []
for v1, v2 in edges:
    # A None after each pair of endpoints breaks the line, so every edge is
    # drawn as its own segment by the single plt.plot call below.
    xs += [x[v1], x[v2], None]
    ys += [y[v1], y[v2], None]
    d = torch.sum(torch.abs(probs[v1] - probs[v2]))
    plt.text(
        (x[v1] + x[v2]) / 2,
        (y[v1] + y[v2]) / 2,
        "{:0.1f}".format(d),
        fontsize=6,
        horizontalalignment='center',
        verticalalignment='center',
    )
plt.plot(xs, ys, linewidth=0.5)
plt.scatter(x, y)
# Hide ticks and the frame so that only the tree is shown.
plt.xticks(())
plt.yticks(())
plt.box(False);