# Explore GISAID data


In [2]:
import os
import math
import pickle
import numpy as np
import torch
from pprint import pprint
from collections import Counter, defaultdict
from pyrocov import mutrans, pangolin, geo
import matplotlib
import matplotlib.pyplot as plt

# Notebook-wide matplotlib defaults: high-resolution figures, gray axes
# frames and a white figure background.
matplotlib.rcParams.update(
    {
        "figure.dpi": 200,
        "axes.edgecolor": "gray",
        "figure.facecolor": "white",
    }
)

In [None]:
# Load the per-sample metadata columns.
# NOTE(review): assumed to be a dict of parallel per-sample sequences; the
# keys used below include "day", "lineage", "location" and "clade".
with open("results/usher.columns.pkl", "rb") as f:
    columns = pickle.load(f)
num_rows = len(columns["day"])
print(f"loaded {num_rows} rows")
print(list(columns.keys()))

## Phylogenetic distribution of samples

In [6]:
%%time
# Sample counts per clade: the full tree, plus one entry per pruned tree size.
clade_counts = {"full": Counter(columns["clade"])}
for max_num_clades in (2000, 5000, 10000):
    pruned_path = f"results/columns.{max_num_clades}.pkl"
    with open(pruned_path, "rb") as f:
        counts = Counter(pickle.load(f)["clade"])
    clade_counts[max_num_clades] = counts
    print((max_num_clades, len(counts)))

(5000, 0)
CPU times: user 144 ms, sys: 59.1 ms, total: 203 ms
Wall time: 226 ms


In [7]:
# Rank-size curves: Pango lineage sizes (dashed black) versus cluster sizes
# in the full tree and in each pruned tree.
strain_counts = Counter(columns["lineage"])
plt.figure(figsize=(6, 4))
Y = sorted(strain_counts.values(), reverse=True)
plt.plot(torch.arange(1, len(Y) + 1), Y, "k--", label="Pango lineages")
for k, v in clade_counts.items():
    Y = sorted(v.values(), reverse=True)
    plt.plot(torch.arange(1, len(Y) + 1), Y, label=f"{k} clusters")
plt.xscale("log")
plt.yscale("log")
plt.xlim(1, None)
plt.ylim(0.95, None)  # just below 1 so singleton clusters remain visible
plt.title("Distribution of samples among clusters")
plt.xlabel("Cluster Rank")
plt.ylabel("Number of Samples")
plt.legend(loc="upper right")
plt.tight_layout()
plt.savefig("paper/clade_distribution.png")

NameError: name 'columns' is not defined

In [None]:
# Per-clade sample counts in the full tree vs. the 2000-cluster tree, aligned
# by clade name. A clade absent from the pruned Counter counts as 0, which is
# why symlog (linear near zero) is used on both axes.
keys = sorted(clade_counts["full"])
X, Y = [], []
for key in keys:
    X.append(clade_counts["full"][key])
    Y.append(clade_counts[2000][key])
plt.scatter(X, Y, 10, alpha=0.3, lw=0)
plt.yscale("symlog")
plt.xscale("symlog")
plt.title("Clade sizes in two different trees")
plt.ylabel("Pruned tree with 2000 nodes")
plt.xlabel(f"Full tree with {len(clade_counts['full'])} nodes")
plt.tight_layout()

## Geographic distribution of samples

In [None]:
%%time
# Frequency of each location string after geo.gisaid_normalize, sorted by location.
location_counts = Counter(geo.gisaid_normalize(loc) for loc in columns["location"])
pprint(sorted(location_counts.items()))

In [None]:
# Count samples per raw location tuple: the first (up to) three "/"-separated
# components, whitespace-trimmed. Locations with fewer than two components
# are skipped.
counts = Counter()
for location in columns["location"]:
    parts = location.split("/")
    if len(parts) >= 2:
        counts[tuple(p.strip() for p in parts[:3])] += 1
# Per-tuple sample counts, largest first.
total_counts = sorted(counts.values(), reverse=True)

In [None]:
# After normalization, count samples per "A / B" prefix (countries) and,
# within each, per "A / B / C" prefix (regions).
countries = Counter()
regions = defaultdict(Counter)
for location in columns["location"]:
    parts = geo.gisaid_normalize(location).split(" / ")
    country = " / ".join(parts[:2])
    location = " / ".join(parts[:3])
    countries[country] += 1
    regions[country][location] += 1

# Report: each country, its total, then its per-region counts (largest first).
print("Total Regions")
print("------------------------------------------------------")
for c, count in countries.most_common():
    rs = " ".join(str(n) for n in sorted(regions[c].values(), reverse=True))
    print(f"{c}\n{count}\t{rs}")

In [None]:
plt.figure(figsize=(8, 5))

# Dashed reference curve: all raw location tuples pooled together.
Y = total_counts
X = torch.arange(1, len(Y) + 1)
plt.plot(X, Y, "k--", label="total")

# One translucent rank-size curve per country; only countries with more than
# 100 regions get a legend entry.
for c, count in countries.most_common():
    Y = sorted(regions[c].values(), reverse=True)
    X = torch.arange(1, len(Y) + 1)
    plt.plot(X, Y, alpha=0.5, label=(c if len(Y) > 100 else None))

plt.axhline(50, color="k", linestyle=":", lw=1, label="threshold = 50")
plt.yscale("log")
plt.xscale("log")
plt.xlim(1, 1 + len(total_counts))
plt.ylim(0.95, None)
plt.legend(loc="upper right")
plt.xlabel("Region Rank")
plt.ylabel("Number of samples")
plt.title("Distribution of samples among regions")
plt.tight_layout()
plt.savefig("paper/region_distribution.png")

In [None]:
# Print the 20 most common lineages, with their decompressed alias when it differs.
strain_counts = Counter(columns["lineage"])
for strain, count in strain_counts.most_common(20):
    short = pangolin.compress(strain)
    long = pangolin.decompress(strain)
    # Sanity check: lineage names in the data should already be compressed.
    assert strain == short, (strain, long)
    suffix = "" if short == long else f" (aka {long})"
    print(f"{count: >10d} {short}{suffix}")

In [None]:
# Second location component of every raw location tuple with >= 5000 samples,
# de-duplicated and listed alphabetically.
fine_countries = sorted({parts[1] for parts, count in counts.items() if count >= 5000})
print("\n".join(fine_countries))

In [None]:
# Raw location strings mentioning the United Kingdom, most frequent first.
locations = Counter(columns["location"])
uk_rows = [f"{c}\t{p}" for p, c in locations.most_common() if "United Kingdom" in p]
print("\n".join(uk_rows))

In [None]:
# Number of samples whose location string does / does not contain "Europe".
n_europe = sum("Europe" in loc for loc in columns["location"])
n_elsewhere = sum("Europe" not in loc for loc in columns["location"])
print("Europe:", n_europe)
print("World - Europe:", n_elsewhere)

## When were lineages born?

In [None]:
# Group observation days by lineage: days[lineage] -> list of days.
days = defaultdict(list)
lineage_day_pairs = zip(columns["lineage"], columns["day"])
for name, d in lineage_day_pairs:
    days[name].append(d)

In [None]:
def plot_birth(lineage, bins=50):
    """Plot a log-scale histogram of the days on which ``lineage`` was observed.

    Reads the notebook-global ``days`` mapping (lineage name -> list of
    observation days). An unknown lineage produces an empty plot.

    :param str lineage: Pango lineage name, e.g. ``"B.1.1.7"``.
    :param int bins: number of histogram bins (default 50).
    """
    # Use .get() instead of days[lineage]: indexing the defaultdict would
    # silently insert an empty entry for an unknown or misspelled lineage.
    observed = np.array(days.get(lineage, []))
    plt.figure(figsize=(6, 2))
    plt.hist(observed, bins=bins)
    plt.yscale("log")
    plt.ylabel(lineage)
    plt.xlim(0, None)
    # Lay out last, after all limits and labels are set, so the layout
    # accounts for the final tick labels.
    plt.tight_layout()
# One birth histogram per lineage of interest.
for name in ("A", "B", "B.1", "B.1.1", "B.1.1.7", "B.1.617.2", "AY.4", "AY.4.2"):
    plot_birth(name)

In [None]:
# Empirical PDF and CDF over observation days of lineage B.1.1.7.
# `days` maps lineage name -> list of days. Passing the whole dict breaks:
# max(days) is then a lineage name (a str), so `max(days) + 1` raises.
# Select the B.1.1.7 list first.
b117_days = torch.tensor(days.get("B.1.1.7", []), dtype=torch.long)
# bincount(x)[d] is the number of observations on day d; same result as the
# zeros(max + 1).scatter_add(...) form. Assumes days are non-negative ints.
pdf = torch.bincount(b117_days).float()
pdf /= pdf.sum()
cdf = pdf.cumsum(0)

In [None]:
# Plot the B.1.1.7 CDF against day index on a log y-axis.
ax = plt.gca()
ax.plot(cdf)
ax.set_yscale("log")
ax.set_title("CDF of B.1.1.7 observations")

## Explore count data

In [None]:
def load_data():
    """Load the (disk-cached) GISAID dataset and merge in JHU case data.

    The GISAID part is read from ``results/mutrans.data.single.pt`` if that
    file exists. Otherwise it is built with ``mutrans.load_gisaid_data()`` and
    saved there first. The mapping returned by
    ``mutrans.load_jhu_data(dataset)`` is then merged into the same dict,
    which is returned. Only the GISAID part is written to the cache.
    """
    cache_path = "results/mutrans.data.single.pt"
    if not os.path.exists(cache_path):
        data = mutrans.load_gisaid_data()
        torch.save(data, cache_path)
    else:
        # NOTE(review): recent PyTorch releases default torch.load to
        # weights_only=True, which may reject non-tensor objects stored in
        # this cache. Confirm against the installed torch version.
        data = torch.load(cache_path)
    data.update(mutrans.load_jhu_data(data))
    return data

dataset = load_data()
# At notebook top level, locals() and globals() are the same namespace. This
# exposes every dataset entry (e.g. daily_cases, weekly_clades) as a global
# name for later cells.
globals().update(dataset)

# Summarize each entry: tensors by shape, everything else by length.
for k in sorted(dataset):
    v = dataset[k]
    if isinstance(v, torch.Tensor):
        extent = f"shape {tuple(v.shape)}"
    else:
        extent = f"size {len(v)}"
    print(f"{k} \t{type(v).__name__} of {extent}")

In [None]:
# Inspect the shape of `daily_cases`, a name injected into the namespace from
# `dataset` above (presumably from mutrans.load_jhu_data; confirm).
daily_cases.shape

In [None]:
# Add a pseudocount of 1/K to every cell (K = size of the last dim, presumably
# the number of clades), normalize along that dim to probabilities, then
# histogram the log-probabilities centred on their per-row median.
probs = weekly_clades + 1 / weekly_clades.size(-1)
probs = probs / probs.sum(dim=-1, keepdim=True)
logits = probs.log()
logits = logits - logits.median(dim=-1, keepdim=True).values
plt.hist(logits.flatten().numpy(), bins=100)
plt.yscale("symlog");

In [None]:
# Mean-centre each row in place. The earlier per-row median shift is a
# constant per row, so it cancels out here.
logits -= logits.mean(dim=-1, keepdim=True)
plt.hist(logits.flatten().numpy(), bins=100)
plt.yscale("symlog");

## How heterogeneous are lineages?

In [None]:
def plot_agreement(pairs):
    """Scatter-plot the feature vectors of lineage pairs on a grid of panels.

    Each panel plots the feature row of lineage ``x`` (horizontal axis)
    against that of lineage ``y`` (vertical axis). The Pearson correlation of
    the two rows is printed in the middle of the panel.

    Reads the notebook-global ``dataset``: rows of ``dataset["features"]`` are
    selected via ``dataset["lineage_id"][name]``. Axis limits are fixed to
    [-0.05, 1.05], which presumes feature values lie in [0, 1] (TODO confirm).

    :param pairs: iterable of ``(x, y)`` lineage-name pairs. If empty, nothing
        is plotted.
    """
    pairs = list(pairs)
    if not pairs:
        return  # avoids a 0-row grid and division by zero below
    # Near-square grid: M rows (rounded sqrt) by N columns, enough for all pairs.
    M = int(len(pairs) ** 0.5 + 0.5)
    N = int(math.ceil(len(pairs) / M))
    assert len(pairs) <= M * N
    # squeeze=False keeps `axes` a 2-D (M, N) array even when M or N is 1,
    # so the nested loop below works for any number of pairs.
    fig, axes = plt.subplots(M, N, figsize=(2 * N, 2 * M + 0.5), squeeze=False)
    fig.suptitle("Mutation correlation between parent-child lineage pairs", y=0.91)
    pairs = iter(pairs)
    for row in axes:
        for ax in row:
            ax.set_xticks(())
            ax.set_yticks(())
            ax.set_xlim(-0.05, 1.05)
            ax.set_ylim(-0.05, 1.05)
            try:
                x, y = next(pairs)
            except StopIteration:
                continue  # leftover grid cells stay empty
            ax.set_xlabel(x)
            ax.set_ylabel(y)
            X = dataset["features"][dataset["lineage_id"][x]]
            Y = dataset["features"][dataset["lineage_id"][y]]
            ax.scatter(X.numpy(), Y.numpy(), alpha=0.3, lw=0)
            # Exact Pearson correlation. Averaging z-scores built with the
            # unbiased (n - 1) std would shrink r by a factor of (n - 1) / n.
            Xc = X - X.mean()
            Yc = Y - Y.mean()
            r = (Xc * Yc).sum() / ((Xc * Xc).sum() * (Yc * Yc).sum()).sqrt()
            ax.text(0.5, 0.5, "{:0.3g}".format(float(r)),
                    va="center", ha="center")
# Parent -> child lineage pairs to compare, grouped by the variant they lead to.
lineage_pairs = [
    # Alpha
    ("A", "B"),
    ("B", "B.1"),
    ("B.1", "B.1.1"),
    ("B.1.1", "B.1.1.7"),
    # Beta
    ("B.1", "B.1.351"),
    ("B.1", "B.1.351.2"),
    ("B.1", "B.1.351.3"),
    # Gamma
    ("B.1.1", "P.1"),
    ("P.1", "P.1.1"),
    ("P.1", "P.1.2"),
    # Delta (B.1 is used as the parent; the alternative pairs with
    # "B.1.617" as parent are currently disabled:
    #   ("B.1.617", "B.1.617.1"), ("B.1.617", "B.1.617.2"),
    #   ("B.1.617", "B.1.617.3"))
    ("B.1", "B.1.617.1"),
    ("B.1", "B.1.617.2"),
    ("B.1", "B.1.617.3"),
    ("B.1.617.2", "AY.1"),
    # Epsilon
    ("B.1", "B.1.427"),
]
plot_agreement(lineage_pairs)