# Exploring GISAID alignment-free clustering

This notebook explores a low-dimensional embedding constructed via AMS sketches of k-mers. To run this notebook, first get GISAID data (sign agreement, set up feed, ...), then run:
```sh
python preprocess_gisaid.py
```

In [None]:
import os
import torch
import umap
import matplotlib.pyplot as plt
from pyrophylo.cluster import ClockSketcher

## Sketches

In [None]:
# Load the saved object from disk and pull out its "sketch" entry.
# NOTE(review): torch.load unpickles its input, so only point it at files you
# trust (presumably the output of preprocess_gisaid.py -- confirm).
filename = "results/gisaid.sketch.pt"
clustering = torch.load(filename)
sketch = clustering["sketch"]

In [None]:
# Histogram of the per-sketch counts on a log-scaled y axis; the title
# reports their mean.
counts = sketch.count.float()
plt.hist(counts.numpy(), bins=100)
plt.title(f"mean k-mer count = {counts.mean():0.1f}")
plt.yscale("log");

In [None]:
# Draw a random batch of up to 200 sketches and compare every pair within it:
# estimated set difference (y axis) against difference in counts (x axis).
subset = torch.randperm(len(sketch))[:200]
batch = sketch[subset]
sketcher = ClockSketcher(20)
diffs, std = sketcher.estimate_set_difference(batch, batch)
set_diff = diffs.reshape(-1)
# Broadcasting yields one count difference per ordered pair in the batch.
count_diff = (batch.count[:, None] - batch.count).float()
upper = count_diff.max()

plt.scatter(count_diff, set_diff, lw=0, alpha=0.01)
# Dashed reference line y = x.
plt.plot([0, upper], [0, upper], "k--", lw=1)
plt.xlabel("|x| - |y|")
plt.ylabel(r"|x \ y|");

## Clustering

In [None]:
# Load a saved clustering result and pull out its "clusters" and "weight"
# entries.  This rebinds `clustering`, which was first bound to the sketch file.
# NOTE(review): torch.load unpickles its input, so only load trusted files.
clustering = torch.load("results/gisaid.cluster.200.200.10.pt")
clusters = clustering["clusters"]
weight = clustering["weight"]

In [None]:
# Plot the cluster weights from largest to smallest on a log-scaled y axis.
plt.plot(weight.sort(descending=True).values)
plt.xlabel("cluster rank")
plt.ylabel("cluster size")
plt.yscale("log")

# Perplexity = exp(entropy) of the normalized cluster weights.
p = weight / weight.sum()
# Only positive entries contribute to the entropy.  Keeping zeros would
# evaluate 0 * log(0) = 0 * -inf = nan and turn the whole perplexity into nan
# as soon as a single cluster has zero weight.
support = p[p > 0]
perplexity = support.log().neg().mul(support).sum().exp()
plt.title(f"weight = {weight.sum():0.1f}, perplexity = {perplexity:0.2f}");

## OBSOLETE analysis of soft hashes

In [None]:
# NOTE(review): `hc_distances` is not assigned in any cell visible in this
# notebook, so this obsolete cell raises NameError unless it is defined
# elsewhere -- confirm before running.

# Softmax of -3.5 * distance along the last axis, kept in log space first.
log_probs = hc_distances.mul(-3.5)
log_probs = log_probs - log_probs.logsumexp(-1, True)
probs = log_probs.exp()
# Perplexity along the last axis: exp(-sum(p * log p)).  The log-probabilities
# are reused instead of calling probs.log(): a probability that underflows to
# zero would give log(0) = -inf and 0 * -inf = nan for that entry's row.
perplexity = log_probs.mul(probs).neg().sum(-1).exp()
# Index of the smallest distance along the last axis.
best = hc_distances.min(-1).indices

plt.figure(figsize=(6, 3))
plt.scatter(best, perplexity, lw=0, alpha=0.1)
plt.ylabel("perplexity")
plt.xlabel("cluster rank");

In [None]:
# Scatter consecutive pairs of soft-hash columns (0 vs 1, 2 vs 3, ...) for a
# random subsample of up to 10000 rows, one pair per subplot.
h = clustering["soft_hashes"]
h = h[torch.randperm(len(h))[:10000]]
mean = h.mean(0)
std = h.std(0)
bits = h.size(-1)

rows = 6
# Enough columns to fit bits // 2 pairs in `rows` rows; at least one so that
# plt.subplots always receives a valid grid shape.
cols = max(1, (bits // 2 + rows - 1) // rows)
# squeeze=False keeps `axes` two-dimensional even for a single-column grid;
# with the default squeeze a (rows, 1) grid comes back 1-D and cannot be
# flattened as a nested sequence.
fig, axes = plt.subplots(rows, cols, figsize=(12, 12), dpi=200, squeeze=False)
axes = list(axes.flat)
# Pair up axes with even column indices i such that j = i + 1 is still a valid
# column; any surplus axes are left empty.
for ax, i in zip(axes, range(0, bits - 1, 2)):
    j = i + 1
    x, y = h[:, i], h[:, j]
    ax.scatter(x, y, lw=0, alpha=0.01)
    # Clip each view to mean +/- 3 standard deviations of the plotted column.
    ax.set_xlim(mean[i] - 3 * std[i], mean[i] + 3 * std[i])
    ax.set_ylim(mean[j] - 3 * std[j], mean[j] + 3 * std[j])

In [None]:
%%time
# Project a random subsample of up to 50000 soft-hash rows with UMAP
# (default settings).
soft_hashes = clustering["soft_hashes"]
subsample = torch.randperm(len(soft_hashes))[:50000]
h = soft_hashes[subsample]
reducer = umap.UMAP()
u = reducer.fit_transform(h)

In [None]:
# Scatter the first two columns of the embedding `u`.
xs, ys = u[:, 0], u[:, 1]
plt.scatter(xs, ys, lw=0, alpha=0.1);

## OBSOLETE GMM model based on sketches

In [None]:
from torch.distributions import constraints
import pyro
import pyro.distributions as dist
from pyro.infer.autoguide import AutoDelta, init_to_sample
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate
from pyro.optim import Adam
from pyro.ops.indexing import Vindex
import pyro.poutine as poutine

# Standardize a private copy of the soft hashes along dim 0: subtract the
# mean, then divide by the standard deviation of the centered values.
data = clustering["soft_hashes"].clone()
data.sub_(data.mean(0))
data.div_(data.std(0))

@config_enumerate
def model(num_clusters, data):
    """Mixture of Normals over the entries of ``data`` with one shared scale.

    Sample sites:
      * ``loc``: Normal(0, 1) prior of shape ``[num_clusters, data.size(-1)]``.
      * ``scale``: a single LogNormal(-1, 1) draw used by every component.
      * ``weights``: Dirichlet prior with concentration 3 per component.
      * ``component``: Categorical assignment for each subsampled datum.
      * ``locs``: the observed data, Normal around the assigned ``loc`` row.

    Args:
        num_clusters: number of mixture components.
        data: tensor indexed by observation along its first dimension and by
            feature along its last dimension.
    """
    num_features = data.size(-1)
    loc_prior = dist.Normal(0, 1).expand([num_clusters, num_features]).to_event(2)
    loc = pyro.sample("loc", loc_prior)
    scale = pyro.sample("scale", dist.LogNormal(-1, 1))
    concentration = torch.full((num_clusters,), 3.)
    weights = pyro.sample("weights", dist.Dirichlet(concentration))
    # Each evaluation sees a subsample of 256 data points; `ind` holds the
    # indices drawn by the plate.
    with pyro.plate("data", len(data), subsample_size=256) as ind:
        assignment = pyro.sample("component", dist.Categorical(weights))
        likelihood = dist.Normal(loc[assignment], scale).to_event(1)
        pyro.sample("locs", likelihood, obs=data[ind])

# Fit the model with SVI for 1001 steps, recording the normalized loss of
# every step in `losses`.
num_clusters = 10
# The guide is built from the model with the "component" site hidden.
guide = AutoDelta(
    poutine.block(model, hide=["component"]),
    init_loc_fn=init_to_sample,
)

pyro.clear_param_store()
pyro.set_rng_seed(20201223)
optimizer = Adam({"lr": 0.1})
elbo = TraceEnum_ELBO(max_plate_nesting=1)
svi = SVI(model, guide, optimizer, elbo)

num_elements = data.numel()
losses = []
for step in range(1001):
    # Report the loss per element of `data`.
    loss = svi.step(num_clusters, data) / num_elements
    losses.append(loss)
    if step % 100 != 0:
        continue
    print(f"step {step: >4d} loss = {loss:0.3g}")

In [None]:
# Plot the per-step values collected in `losses` against the step index.
plt.figure(figsize=(8, 3))
plt.plot(losses);

In [None]:
# Read the guide's median values without tracking gradients, then print the
# scale, the first column of the locations, and the weights in ascending order.
with torch.no_grad():
    median = guide.median()
print(median["scale"])
first_loc_column = median["loc"][:, 0]
print(first_loc_column.data.numpy())
sorted_weights = median["weights"].data.sort(0).values
print(sorted_weights.numpy())

In [None]:
%%time
# Project the fitted "loc" values with UMAP (default settings) and scatter the
# first two embedding columns.  This rebinds both `clusters` and `u`.
clusters = median["loc"].data
reducer = umap.UMAP()
u = reducer.fit_transform(clusters)
xs, ys = u[:, 0], u[:, 1]
plt.scatter(xs, ys, lw=0, alpha=0.5);