# Exploring GISAID alignment-free clustering

This notebook explores a low-dimensional embedding of GISAID sequences constructed from AMS sketches of their k-mers. To run it, first obtain the GISAID data (sign the data-access agreement, set up the data feed, etc.), then run
```sh
python preprocess_gisaid.py
```

In [None]:
import os
import datetime
import torch
import umap
import matplotlib.pyplot as plt
from pyrophylo.cluster import ClockSketcher

## Sketches

In [None]:
# Load the precomputed k-mer sketches (presumably produced by the
# preprocessing step described at the top of this notebook).
sketch_path = "results/gisaid.sketch.pt"
result = torch.load(sketch_path)
sketch = result["sketch"]

In [None]:
# Histogram of per-sequence k-mer counts, with the mean shown in the title.
kmer_counts = sketch.count.float()
plt.hist(kmer_counts.numpy(), bins=100)
plt.title(f"mean k-mer count = {kmer_counts.mean():0.1f}")
plt.yscale("log");

In [None]:
# Sanity-check the sketch-based set-difference estimator on a random subsample
# of 200 sequences: plot the estimated |x \ y| against the difference in k-mer
# counts |x| - |y| for every ordered pair. Since |x \ y| >= max(0, |x| - |y|),
# points should lie on or above the dashed reference lines.
batch = sketch[torch.randperm(len(sketch))[:200]]
sketcher = ClockSketcher(20)
diffs, std = sketcher.estimate_set_difference(batch, batch)
# Flatten the estimated and the count-difference pairwise matrices identically
# so that points are paired element-for-element.
est_diff = diffs.reshape(-1)
size_diff = (batch.count[:, None] - batch.count).float().reshape(-1)
lo, hi = float(size_diff.min()), float(size_diff.max())

plt.scatter(size_diff, est_diff, lw=0, alpha=0.01)
plt.plot([0, hi], [0, hi], "k--", lw=1)  # lower bound |x| - |y| when positive
plt.plot([lo, 0], [0, 0], "k--", lw=1)   # lower bound 0 when negative
plt.xlabel("|x| - |y|")
plt.ylabel(r"|x \ y|");

## Clustering

In [None]:
# Load saved clustering results and unpack them into notebook-level variables.
clustering = torch.load("results/gisaid.cluster.pt")
(full_clusters, full_weights, clusters, weights, class_probs) = (
    clustering[key]
    for key in ("full_clusters", "full_weights", "clusters", "weights", "class_probs")
)

In [None]:
# Plot cluster sizes. The title summarizes them by total weight and by
# perplexity, exp(entropy) of the normalized weights, i.e. an effective
# number of clusters. (The x label assumes full_weights is sorted by size;
# NOTE(review): confirm against the clustering script.)
plt.plot(full_weights)
plt.xlabel("cluster rank")
plt.ylabel("cluster size")
plt.yscale("log")
p = full_weights / full_weights.sum()
# Drop empty clusters. Otherwise 0 * log(0) evaluates to nan (0 * -inf),
# whereas its contribution to the entropy is 0 in the limit.
nonzero_p = p[p > 0]
perplexity = nonzero_p.log().neg().mul(nonzero_p).sum().exp()
plt.title(f"weight = {full_weights.sum():0.1f}, perplexity = {perplexity:0.2f}");

In [None]:
# Reuse the sketch results already loaded from results/gisaid.sketch.pt in the
# "Sketches" section instead of reading the (potentially large) file again.
sketch_result = result
# Per-sample day index. The plots below treat it as days after 2019-12-01.
day = torch.tensor(sketch_result["columns"]["day"], dtype=torch.long)

In [None]:
# Distribution of sample days, with counts on a log scale.
plt.hist(x=day.numpy(), bins=200)
plt.yscale(value="log");

In [None]:
# Weekly sample counts per cluster ("strain"), one line per cluster.
# Each sample gets a hard assignment to its most probable cluster. Sampling
# instead via class_probs.multinomial(1) was tried but is extremely slow.
strain = class_probs.max(-1).indices
week = day // 7
num_weeks = 1 + int(week.max())
num_strains = class_probs.size(-1)
# Count samples per (strain, week) pair through a flattened joint index.
# bincount builds the table directly, so nothing depends on reshape() returning
# a view (needed by the earlier in-place scatter_add_ into counts.reshape(-1)).
flat_index = strain * num_weeks + week
counts = (
    torch.bincount(flat_index, minlength=num_strains * num_weeks)
    .reshape(num_strains, num_weeks)
    .float()
)

plt.figure(figsize=(8, 4), dpi=300)
plt.plot(counts.T, lw=1)
plt.yscale("log")
plt.xlabel("week after 2019-12-01")
plt.ylabel("samples / week")
plt.title(f"{num_strains} clusters")
plt.xlim(0, num_weeks);

## Obsolete: Gaussian mixture model (GMM) based on sketches

In [None]:
from torch.distributions import constraints
import pyro
import pyro.distributions as dist
from pyro.infer.autoguide import AutoDelta, init_to_sample
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate
from pyro.optim import Adam
from pyro.ops.indexing import Vindex
import pyro.poutine as poutine

# Standardize each column of the soft hashes to zero mean and unit std.
soft_hashes = clustering["soft_hashes"]
centered = soft_hashes - soft_hashes.mean(0)
data = centered / centered.std(0)

@config_enumerate
def model(num_clusters, data):
    """Gaussian mixture model over the rows of ``data``.

    Latent sites: per-cluster centers ``loc`` (num_clusters x data.size(-1)),
    a single shared isotropic ``scale``, and mixture ``weights``. Each row in a
    random subsample of 256 rows gets a discrete ``component`` assignment,
    which is enumerated because of ``@config_enumerate``. The row is observed
    at site ``locs``.
    """
    dim = data.size(-1)
    loc_prior = dist.Normal(0, 1).expand([num_clusters, dim]).to_event(2)
    loc = pyro.sample("loc", loc_prior)
    scale = pyro.sample("scale", dist.LogNormal(-1, 1))
    concentration = torch.full((num_clusters,), 3.)
    mixture_weights = pyro.sample("weights", dist.Dirichlet(concentration))
    with pyro.plate("data", len(data), subsample_size=256) as ind:
        assignment = pyro.sample("component", dist.Categorical(mixture_weights))
        likelihood = dist.Normal(loc[assignment], scale).to_event(1)
        pyro.sample("locs", likelihood, obs=data[ind])

num_clusters = 10
# MAP-style guide: point estimates for all global latents. The discrete
# "component" site is hidden from the guide so TraceEnum_ELBO enumerates it
# out of the model instead.
guide = AutoDelta(
    poutine.block(model, hide=["component"]),
    init_loc_fn=init_to_sample,
)

pyro.clear_param_store()
pyro.set_rng_seed(20201223)
optim = Adam({"lr": 0.1})
elbo = TraceEnum_ELBO(max_plate_nesting=1)
svi = SVI(model, guide, optim, elbo)

num_steps = 1001
log_every = 100
losses = []
for step in range(num_steps):
    # Normalize by the number of data elements for a scale-free loss.
    loss = svi.step(num_clusters, data) / data.numel()
    losses.append(loss)
    if not step % log_every:
        print(f"step {step: >4d} loss = {loss:0.3g}")

In [None]:
# Training curve: per-element loss recorded at each SVI step.
plt.figure(figsize=(8, 3))
plt.plot(losses)
plt.xlabel("SVI step")
plt.ylabel("loss / data.numel()");

In [None]:
# Point estimates of the global latents from the AutoDelta guide.
with torch.no_grad():
    median = guide.median()
print(median["scale"])
# First coordinate of each cluster center.
print(median["loc"][:, 0].detach().numpy())
# Mixture weights, sorted ascending.
print(median["weights"].detach().sort(0).values.numpy())

In [None]:
%%time
clusters = median["loc"].data
u = umap.UMAP().fit_transform(clusters)
plt.scatter(u[:, 0], u[:, 1], lw=0, alpha=0.5);