# Exploring GISAID alignment-free clustering

This notebook explores a low-dimensional embedding constructed from AMS sketches of k-mers. To run it, first obtain the GISAID data (sign the data-access agreement, set up the data feed, ...), then run
```sh
python preprocess_gisaid.py
```

In [None]:
import os
import torch
import umap
import matplotlib.pyplot as plt

In [None]:
# Load the saved clustering results. The file is presumably written by an
# upstream clustering step — TODO confirm which script produces it.
filename = "results/gisaid.cluster.pt"
clustering = torch.load(filename)

# Pull out the entries used by the cells below (looked up in this order).
clusters, cc_distances, hc_distances = (
    clustering[key] for key in ("clusters", "cc_distances", "hc_distances")
)

In [None]:
# Turn distances into soft cluster assignments via a softmax over negative
# distances along the last axis. Then summarize how spread out each assignment
# is by its perplexity exp(entropy). Perplexity 1 means all mass sits on a
# single cluster; larger values mean the assignment is spread over more clusters.
# NOTE(review): assumes hc_distances is (num_items, num_clusters) — TODO confirm.
beta = 3.5  # inverse temperature: larger values give harder assignments
log_probs = hc_distances.mul(-beta).log_softmax(-1)
probs = log_probs.exp()
# Multiply by log_probs rather than probs.log(): a probability that underflows
# to exactly 0 would give log(0) * 0 = NaN, whereas finite * 0 = 0.
perplexity = log_probs.mul(probs).neg().sum(-1).exp()
best = hc_distances.min(-1).indices  # index of the nearest cluster per item

plt.figure(figsize=(6, 3))
plt.scatter(best, perplexity, lw=0, alpha=0.1)
plt.ylabel("perplexity")
plt.xlabel("cluster rank");

In [None]:
# Scatter consecutive pairs of soft-hash coordinates (0 vs 1, 2 vs 3, ...) for
# a random subsample of up to 10000 rows. Each panel is zoomed to the
# mean +- 3 std window of its two coordinates.
h = clustering["soft_hashes"]
h = h[torch.randperm(len(h))[:10000]]
mean = h.mean(0)
std = h.std(0)
bits = h.size(-1)

num_pairs = bits // 2  # a trailing odd coordinate has no partner and is skipped
num_cols = max(1, (num_pairs + 2) // 3)  # ceil(num_pairs / 3), at least one column
# squeeze=False keeps `axes` 2-D even with a single column, so `.flat` always
# yields Axes objects. With squeeze on, iterating a 1-D axes array yields Axes,
# and nested iteration over an Axes raises TypeError.
fig, axes = plt.subplots(3, num_cols, figsize=(12, 12), dpi=200, squeeze=False)
for k, ax in enumerate(axes.flat):
    i, j = 2 * k, 2 * k + 1
    if j >= bits:
        ax.set_axis_off()  # no coordinate pair left for this panel
        continue
    x, y = h[:, i], h[:, j]
    ax.scatter(x, y, lw=0, alpha=0.01)
    ax.set_xlim(mean[i] - 3 * std[i], mean[i] + 3 * std[i])
    ax.set_ylim(mean[j] - 3 * std[j], mean[j] + 3 * std[j])

In [None]:
%%time
# Embed the soft hashes `h` (the subsample drawn in the previous cell) with
# UMAP using its default settings. %%time is an IPython cell magic that must
# stay on the first line of the cell; it reports how long this cell takes.
u = umap.UMAP().fit_transform(h)

In [None]:
# Plot the first two UMAP coordinates, drawn faintly so dense regions stand out.
emb_x, emb_y = u[:, 0], u[:, 1]
plt.scatter(emb_x, emb_y, lw=0, alpha=0.1);