# Exploring alignment tools

See [Bio.SeqIO docs](https://biopython.org/wiki/SeqIO) and [mappy docs](https://pypi.org/project/mappy/).

In [None]:
import re
import json
from collections import Counter, defaultdict
from Bio import SeqIO
import mappy
import matplotlib.pyplot as plt
import numpy as np
import torch

## How long are the sequences?

In [None]:
# Only the first record of the reference FASTA is used. Taking it with
# next() fails right here if the file holds no records, instead of leaving
# `reference` undefined and raising a NameError on the lines below.
reference = next(SeqIO.parse("data/ncbi-reference-sequence.fasta", "fasta"))
print(len(reference.seq))
# Show only the first 1000 bases as a preview.
print(reference.seq[:1000] + "...")

In [None]:
# Index the single reference sequence in memory so that other sequences can
# be aligned against it; "asm10" is one of the minimap2 presets.
reference_sequence = str(reference.seq)
aligner = mappy.Aligner(seq=reference_sequence, preset="asm10")

In [None]:
# Collect the sequences from the first 200 lines of the file; each line is
# one JSON record whose "sequence" field may contain embedded newlines.
examples = []
with open("results/gisaid.000-of-012.json") as f:
    for _, line in zip(range(200), f):
        record = json.loads(line)
        examples.append(record["sequence"].replace("\n", ""))

# Plot sequence lengths from longest to shortest on a log-scaled y axis.
lengths = sorted((len(seq) for seq in examples), reverse=True)
plt.figure(figsize=(8, 2), dpi=300)
plt.plot(lengths)
plt.yscale("log")
plt.xlabel("rank")
plt.ylabel("sequence length");

In [None]:
# Number of distinct characters appearing in each example sequence.
alphabet_sizes = [len(set(seq)) for seq in examples]
print(alphabet_sizes)
# Per-character frequencies for one particular example (index 7).
print(Counter(examples[7]))

## How frequent are aligned features?

In [None]:
# Load the saved alignment results and pull out the two entries used below.
# NOTE(review): torch.load may unpickle arbitrary objects depending on the
# torch version / weights_only setting -- only load trusted files.
result = torch.load("results/gisaid.align.pt")
stats, features = result["stats"], result["features"]
print(result.keys())

In [None]:
# Group the statistics by code. Keys of `stats` are (pos, code, size)
# triples and values are counts; when `size` is a string its length is used.
x = defaultdict(list)  # code -> positions
y = defaultdict(list)  # code -> sizes
s = defaultdict(list)  # code -> counts (used as marker areas)
for key, count in stats.items():
    pos, code, size = key
    if isinstance(size, str):
        size = len(size)
    x[code].append(pos)
    y[code].append(size)
    s[code].append(count)

plt.figure(figsize=(8, 6), dpi=300)

# "X" entries are drawn on the zero line, with marker areas scaled down by 20.
snp_height = torch.zeros(len(x["X"]))
snp_area = torch.tensor(s["X"]) / 20.
plt.scatter(x["X"], snp_height, s=snp_area,
            color="darkblue", alpha=0.2, lw=0)

# "I" entries go above the axis at their size, "D" entries below it.
insertion_height = torch.tensor(y["I"])
plt.scatter(x["I"], insertion_height, s=s["I"],
            color="darkgreen", alpha=0.5, lw=0)
deletion_height = -torch.tensor(y["D"])
plt.scatter(x["D"], deletion_height, s=s["D"],
            color="darkred", alpha=0.5, lw=0)

# symlog keeps zero and negative heights plottable on a log-like axis.
plt.yscale("symlog")
plt.xlabel("position")
plt.ylabel("⟵ deletions    SNPs    insertions ⟶")
# Dashed vertical guides at two fixed positions.
for guide in (2000, 27000):
    plt.axvline(guide, color="black", linestyle="--", alpha=0.5, lw=1, zorder=0)
plt.title("frequency of mutations");

In [None]:
# Sum the feature matrix over samples (dim 0) to get one count per feature.
feature_counts = features.sum(0)
plt.figure(figsize=(8, 4), dpi=300)
# The x axis is the 1-based feature index, labelled "rank" below.
# NOTE(review): presumably features are stored in decreasing-frequency
# order, since no sort is applied here -- confirm.
plt.plot(torch.arange(1, 1 + len(feature_counts)),
         feature_counts, "bo", lw=0, markersize=2)
plt.yscale("log")
plt.xscale("log")
plt.title(f"Frequency of {features.size(1)} mutations among {features.size(0)} samples")
plt.xlabel("rank")
plt.ylabel("# occurrences");

In [None]:
# Draw up to 10000 rows of `features` uniformly at random (without
# replacement) and convert them to floating point.
shuffled_rows = torch.randperm(len(features))
subsample = features[shuffled_rows[:10000]].float()

In [None]:
%%time
from sklearn.cluster import DBSCAN
db = DBSCAN(eps=3, min_samples=10, metric="l1").fit(subsample[:, :20])
print(f"Found {1 + db.labels_.max()} clusters")

In [None]:
# Count how many samples fall into each DBSCAN cluster and plot the sizes.
labels = torch.from_numpy(db.labels_)
# Rows labelled -1 are excluded from the counts and reported as noisy
# samples in the plot title.
ok = labels != -1
labels = labels[ok]
num_clusters = 1 + labels.max().item()
# Start from zeros so each entry is exactly the number of samples carrying
# that label (starting from ones would inflate every cluster size by one).
counts = torch.zeros(num_clusters)
counts.scatter_add_(0, labels, torch.ones(()).expand_as(labels))
plt.figure(figsize=(8, 4), dpi=300)
# Sizes are sorted in decreasing order, so the x axis is the size rank.
plt.plot(torch.arange(1, 1 + len(counts)),
         counts.sort(descending=True).values, "bo", lw=0, markersize=2)
plt.yscale("log")
plt.xscale("log")
plt.title(f"Sizes of {num_clusters} clusters "
          f"(dropped {(~ok).sum():d}/{len(subsample)} noisy samples)")
plt.xlabel("rank")
plt.ylabel("cluster size");