## Scalable soft minimax clustering

The results in this notebook require first obtaining GISAID data, then running
```sh
python preprocess_gisaid.py --align
```

In [None]:
from collections import Counter
import logging
import torch
import pyro
import pyro.distributions as dist
from opt_einsum import contract as einsum
from pyro.infer.autoguide import AutoDelta, init_to_sample
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
from pyro.distributions.spanning_tree import make_complete_graph
from pyrophylo.cluster import SoftminimaxClustering
import umap
import matplotlib.pyplot as plt

# Log at INFO level and above. Each record is prefixed with the time elapsed
# (in milliseconds, relative to when the logging module was loaded), padded
# to 9 columns, followed by the message.
logging.basicConfig(
    level=logging.INFO,
    format="%(relativeCreated) 9d %(message)s",
)

In [None]:
# Load the saved results (presumably written by `preprocess_gisaid.py --align`)
# and show which entries are available.
result = torch.load("results/gisaid.align.pt")
print(result.keys())

# Keep only the first 64 feature columns, cast to 32-bit floats, as the
# clustering input.
features = result["features"]
data = features[:, :64].to(torch.float32)
print(data.shape)

First let's sample a diverse set of points, which we'll use for initialization below. We'll use reservoir sampling, weighting by diversity.

In [None]:
%%time
# Build a 100-cluster model and initialize it from the data, logging
# progress every 10000 steps.
clustering = SoftminimaxClustering(p_edge=8, num_clusters=100)
clustering.init(data, log_every=10_000)

In [None]:
def plot_clustering(clustering, data, *, use_umap=False):
    """Plot diagnostics for a clustering of ``data``.

    Draws three matplotlib figures:

    1. Cluster sizes on a log scale, titled with the perplexity of the
       empirical cluster-assignment distribution.
    2. The sorted edge lengths (L1 distance between cluster centers) of a
       spanning tree sampled over the centers, annotated with the total
       tree length.
    3. The spanning tree itself, laid out in two embedding coordinates of
       the centers, with every edge labeled by its length.

    Args:
        clustering: Object exposing ``classify(data)``, which must return a
            1-D integer tensor of cluster indices, and ``mean``, a 2-D tensor
            with one row per cluster center.
        data: Observations passed to ``clustering.classify``.
        use_umap: If true, lay out the centers with UMAP instead of the
            default low-rank PCA.
    """
    # Plot cluster sizes.
    obs = clustering.classify(data)
    cluster_sizes = torch.zeros(obs.max() + 1)
    cluster_sizes.scatter_add_(0, obs, torch.ones(()).expand_as(obs))
    p = cluster_sizes / cluster_sizes.sum()
    # Empty clusters have p == 0, and 0 * log(0) evaluates to NaN, which would
    # poison the whole entropy sum. They contribute nothing, so drop them.
    occupied = p[p > 0]
    perplexity = occupied.log().neg().mul(occupied).sum().exp()
    plt.figure(figsize=(8, 3), dpi=300)
    plt.plot(cluster_sizes)
    plt.yscale("log")
    # NOTE(review): the axis is labeled "rank" but sizes are plotted in
    # cluster-index order -- confirm the clusters are already ordered by size.
    plt.xlabel("rank")
    plt.ylabel("cluster size")
    plt.title(f"perplexity = {perplexity:0.2f}")

    # Compute 2-D embedding coords of the cluster centers.
    # Detach so that plotting (which converts to NumPy) also works when the
    # centers track gradients.
    centers = clustering.mean.detach()
    if use_umap:
        embedding = umap.UMAP().fit_transform(centers.cpu().numpy())
        x, y = torch.from_numpy(embedding).T
    else:
        U, _, _ = torch.pca_lowrank(centers)
        x, y = U[:, 0], U[:, 1]

    # Find a spanning tree: sample one from a distribution whose edge logits
    # are the negated L1 distances between centers.
    v1, v2 = make_complete_graph(len(centers)).unbind(0)
    distance = (centers[v1] - centers[v2]).abs().sum(-1)
    edges = dist.SpanningTree(-distance).sample()

    # Plot distribution of edge lengths.
    plt.figure(figsize=(8, 2), dpi=300)
    edge_lengths = (centers[edges[:, 0]] - centers[edges[:, 1]]).abs().sum(-1)
    plt.plot(edge_lengths.sort(descending=True).values, "bo", markersize=2)
    plt.text(len(edge_lengths) * 0.8, edge_lengths.max() * 0.8,
             f"total length = {edge_lengths.sum():0.1f}",
             horizontalalignment='center', verticalalignment='center')
    plt.ylabel("edge length")

    # Plot a spanning tree.
    plt.figure(figsize=(8, 6), dpi=300)
    xs, ys = [], []
    for src, dst in edges:
        # A trailing None breaks the polyline, so each edge is drawn as a
        # separate segment by the single plt.plot call below.
        xs.extend((x[src], x[dst], None))
        ys.extend((y[src], y[dst], None))
        length = (centers[src] - centers[dst]).abs().sum()
        plt.text((x[src] + x[dst]) / 2, (y[src] + y[dst]) / 2,
                 "{:0.1f}".format(length),
                 fontsize=6, horizontalalignment='center', verticalalignment='center')
    plt.plot(xs, ys, lw=0.5)
    plt.scatter(x, y)
    plt.xticks(())
    plt.yticks(())
    plt.box(False)

In [None]:
# Visualize the clustering as it stands right after initialization.
plot_clustering(clustering=clustering, data=data)

In [None]:
%%time
# Fine-tune the clustering on the data and plot the values it returns
# (stored as `losses`) in a wide, short figure.
losses = clustering.fine_tune(data)
plt.figure(dpi=300, figsize=(8, 2))
plt.plot(losses, linewidth=1);

In [None]:
# Visualize the clustering again, now that it has been fine-tuned.
plot_clustering(clustering=clustering, data=data)