# Explore GISAID data


In [None]:
import os
import math
import pickle
import numpy as np
import torch
from collections import Counter
from pyrocov import mutrans, pangolin
import matplotlib
import matplotlib.pyplot as plt

# Notebook-wide plot defaults: high-resolution figures with gray axes borders.
matplotlib.rcParams.update({
    "figure.dpi": 200,
    "axes.edgecolor": "gray",
})

## Explore columns

In [None]:
# Load the pickled table of GISAID columns (a mapping of column name -> values).
# NOTE: unpickling can execute arbitrary code, so only load files produced by
# a trusted pipeline.
with open("results/gisaid.columns.pkl", "rb") as columns_file:
    columns = pickle.load(columns_file)
num_rows = len(columns["day"])
print("loaded {} rows".format(num_rows))
print(list(columns.keys()))

In [None]:
# Tally samples per lineage and list the 20 most common ones, showing the
# compressed alias next to the full name whenever the two differ.
strain_counts = Counter(columns["lineage"])
for strain, count in strain_counts.most_common(20):
    short = pangolin.compress(strain)
    long = pangolin.decompress(strain)
    # The lineage column is expected to already hold decompressed names.
    assert strain == long, (strain, long)
    alias_note = "" if short == long else f" (aka {long})"
    print(f"{count: >10d} {short}{alias_note}")

In [None]:
# Sample counts per lineage, ordered from most to least common, on a log scale.
ranked_lineage_counts = [count for _, count in strain_counts.most_common()]
plt.figure(figsize=(6, 3))
plt.plot(ranked_lineage_counts)
plt.yscale("log")
plt.title("Distribution of samples among lineages")

In [None]:
# Count samples per region, keyed by the first (up to) three "/"-separated
# components of the location string. Locations with fewer than two components
# are ignored.
counts = Counter()
for location in columns["location"]:
    parts = location.split("/")
    if len(parts) >= 2:
        parts = tuple(map(str.strip, parts[:3]))
        counts[parts] += 1

In [None]:
# Sample counts per region, ordered from most to least common, on a log scale.
ranked_region_counts = [count for _, count in counts.most_common()]
plt.figure(figsize=(6, 3))
plt.plot(ranked_region_counts)
plt.yscale("log")
plt.title("Distribution of samples among regions")

In [None]:
# Second location component of every region with at least 5000 samples,
# de-duplicated and sorted alphabetically.
fine_countries = sorted(
    {parts[1] for parts, count in counts.items() if count >= 5000}
)
print("\n".join(fine_countries))

In [None]:
# List every raw location string mentioning the United Kingdom, most common
# first, as "<count>\t<location>" lines.
locations = Counter(columns["location"])
uk_lines = [
    f"{c}\t{p}"
    for p, c in locations.most_common()
    if "United Kingdom" in p
]
print("\n".join(uk_lines))

In [None]:
# Number of samples whose location string mentions Europe, and the complement.
print("Europe:", sum("Europe" in loc for loc in columns["location"]))
print("World - Europe:", sum("Europe" not in loc for loc in columns["location"]))

## When were lineages born?

In [None]:
# Collect the day value of every sample assigned to lineage B.1.1.7.
days = [
    day
    for lineage, day in zip(columns["lineage"], columns["day"])
    if lineage == "B.1.1.7"
]

In [None]:
# Histogram of B.1.1.7 sample days on a log-count axis. The trailing semicolon
# suppresses the notebook's echo of the last expression.
day_values = np.array(days)
plt.hist(day_values, bins=50)
plt.yscale("log")
plt.title("Distribution of dates of B.1.1.7");

In [None]:
# Empirical distribution of B.1.1.7 observations over days: accumulate one unit
# per observation into its day slot, normalize to a pdf, then integrate to a cdf.
# Assumes `days` holds non-negative integers usable as indices -- TODO confirm.
day_index = torch.tensor(days)
unit_weights = torch.ones(len(days))
pdf = torch.zeros(max(days) + 1).scatter_add(0, day_index, unit_weights)
pdf /= pdf.sum()
cdf = pdf.cumsum(0)

In [None]:
# Plot the cumulative fraction of B.1.1.7 observations by day index; the log
# y-axis makes the earliest (small-fraction) observations visible.
plt.plot(cdf)
plt.yscale("log")
plt.title("CDF of B.1.1.7 observations")

## Explore count data

In [None]:
def load_data():
    """Return the mutrans dataset, using an on-disk cache for the GISAID part.

    The GISAID data is read from ``results/mutrans.data.single.pt`` when that
    file exists; otherwise it is built with ``mutrans.load_gisaid_data()`` and
    saved there. The JHU data is never cached: it is loaded on every call and
    merged into the returned dict.
    """
    cache_path = "results/mutrans.data.single.pt"
    if not os.path.exists(cache_path):
        dataset = mutrans.load_gisaid_data()
        torch.save(dataset, cache_path)
    else:
        dataset = torch.load(cache_path)
    dataset.update(mutrans.load_jhu_data(dataset))
    return dataset

dataset = load_data()
# At notebook top level locals() is the global namespace, so this exposes each
# dataset entry as a variable of the same name for the cells below.
locals().update(dataset)
# Summarize every entry: tensors by shape, everything else by length.
for k, v in sorted(dataset.items()):
    extent = (
        f"shape {tuple(v.shape)}"
        if isinstance(v, torch.Tensor)
        else f"size {len(v)}"
    )
    print(f"{k} \t{type(v).__name__} of {extent}")

In [None]:
# Inspect the shape of `daily_cases`; the name is not assigned in this file, so
# it presumably comes from the dataset entries injected via locals().update().
daily_cases.shape

In [None]:
# Smooth the counts with a uniform pseudocount (one extra count in total spread
# over the last axis), normalize to probabilities, and center the log-probs on
# their median along the last axis before plotting their histogram.
num_categories = weekly_strains.size(-1)
probs = weekly_strains + 1 / num_categories
probs /= probs.sum(dim=-1, keepdim=True)
logits = probs.log()
logits -= logits.median(dim=-1, keepdim=True).values
plt.hist(logits.reshape(-1).numpy(), bins=100)
plt.yscale("symlog");

In [None]:
# Re-center the existing logits (in place) on their mean along the last axis
# and plot the histogram again for comparison.
last_axis_mean = logits.mean(dim=-1, keepdim=True)
logits -= last_axis_mean
plt.hist(logits.reshape(-1).numpy(), bins=100)
plt.yscale("symlog");

## How heterogeneous are lineages?

In [None]:
def plot_agreement(pairs):
    """Scatter-plot the feature rows of lineage pairs on a near-square grid.

    Each panel plots ``dataset["features"]`` of the first lineage (x axis)
    against that of the second lineage (y axis) and prints a correlation
    estimate of the two rows in the middle of the panel. Panels beyond the
    number of pairs are left as empty frames.

    Args:
        pairs: iterable of ``(x, y)`` lineage names; both names must be keys
            of ``dataset["lineage_id"]``. An empty iterable draws nothing.
    """
    pairs = list(pairs)
    if not pairs:
        # Nothing to draw; the grid arithmetic below would divide by zero.
        return
    num_rows = int(len(pairs) ** 0.5 + 0.5)
    num_cols = int(math.ceil(len(pairs) / num_rows))
    assert len(pairs) <= num_rows * num_cols
    # squeeze=False keeps `axes` two-dimensional even when the grid has a
    # single row or column, so one or two pairs no longer break iteration.
    fig, axes = plt.subplots(
        num_rows,
        num_cols,
        figsize=(2 * num_cols, 2 * num_rows + 0.5),
        squeeze=False,
    )
    fig.suptitle("Mutation correlation between parent-child lineage pairs", y=0.91)
    remaining = iter(pairs)
    # axes.flat walks the grid row by row, matching the order of `pairs`.
    for ax in axes.flat:
        ax.set_xticks(())
        ax.set_yticks(())
        ax.set_xlim(-0.05, 1.05)
        ax.set_ylim(-0.05, 1.05)
        pair = next(remaining, None)
        if pair is None:
            # Surplus panel: keep the empty frame.
            continue
        x, y = pair
        ax.set_xlabel(x)
        ax.set_ylabel(y)
        X = dataset["features"][dataset["lineage_id"][x]]
        Y = dataset["features"][dataset["lineage_id"][y]]
        ax.scatter(X.numpy(), Y.numpy(), alpha=0.3, lw=0)
        # Standardize both rows; the mean of their product is then the
        # correlation estimate shown in the panel.
        # NOTE(review): if these are torch tensors, std() applies Bessel's
        # correction by default, so the value is Pearson's r scaled by
        # (n - 1) / n -- confirm this is acceptable.
        X = (X - X.mean()) / X.std()
        Y = (Y - Y.mean()) / Y.std()
        ax.text(0.5, 0.5, "{:0.3g}".format((X * Y).mean()),
                va="center", ha="center")
# Compare each listed (parent, child) lineage pair; the group comments below
# label the pairs by variant name.
plot_agreement([
    # Alpha
    ("A", "B"),
    ("B", "B.1"),
    ("B.1", "B.1.1"),
    ("B.1.1", "B.1.1.7"),
    # Beta
    ("B.1", "B.1.351"),
    ("B.1", "B.1.351.2"),
    ("B.1", "B.1.351.3"),
    # Gamma
    ("B.1.1", "P.1"),
    ("P.1", "P.1.1"),
    ("P.1", "P.1.2"),
    # Delta
    # NOTE(review): the direct B.1.617 parent pairs are disabled and B.1 is
    # used as the parent instead; presumably B.1.617 is missing from
    # dataset["lineage_id"] -- confirm.
    # ("B.1.617", "B.1.617.1"),
    # ("B.1.617", "B.1.617.2"),
    # ("B.1.617", "B.1.617.3"),
    ("B.1", "B.1.617.1"),
    ("B.1", "B.1.617.2"),
    ("B.1", "B.1.617.3"),
    ("B.1.617.2", "AY.1"),
    # Epsilon
    ("B.1", "B.1.427"),
])