# Explore GISAID data


In [None]:
import os
import pickle
import torch
from collections import Counter
from pyrocov import mutrans
import matplotlib
import matplotlib.pyplot as plt

# Set the default figure resolution to 200 DPI for every plot drawn below.
matplotlib.rcParams.update({"figure.dpi": 200})

## Explore columns

In [None]:
# Load the pickled column store from the results directory.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only open results/gisaid.columns.pkl when it comes from a trusted pipeline.
with open("results/gisaid.columns.pkl", "rb") as f:
    columns = pickle.load(f)
# The row count is taken from the "day" column; this presumes every column
# has one entry per row -- TODO confirm against the code that writes the file.
print(f"loaded {len(columns['day'])} rows")
print(list(columns.keys()))

In [None]:
# Tally locations by their first (up to) three "/"-separated components,
# each stripped of surrounding whitespace. Locations that have fewer than
# two components are left out of the tally.
counts = Counter(
    tuple(piece.strip() for piece in pieces[:3])
    for pieces in (entry.split("/") for entry in columns["location"])
    if len(pieces) >= 2
)

In [None]:
# Keep the second component of every location tuple that was counted at
# least 5000 times (the country, judging by the variable name), with
# duplicates removed, then list the names alphabetically.
fine_countries = sorted(
    {key[1] for key, total in counts.items() if total >= 5000}
)
print("\n".join(fine_countries))

In [None]:
from pyrocov.mutrans import get_fine_countries

# Print, one per line and in sorted order, what the library helper returns
# for the same columns.
library_fine_countries = sorted(get_fine_countries(columns))
print("\n".join(library_fine_countries))

In [None]:
# Count every raw location string, then list the ones that mention
# "United Kingdom" as "<count><TAB><location>", most frequent first.
locations = Counter(columns["location"])
uk_lines = [
    f"{count}\t{place}"
    for place, count in locations.most_common()
    if "United Kingdom" in place
]
print("\n".join(uk_lines))

## Explore count data

In [None]:
def load_data(filename="results/mutrans.data.single.pt"):
    """Load the mutrans dataset, caching the GISAID part on disk.

    If ``filename`` exists it is read with ``torch.load``. Otherwise the
    dataset is built by ``mutrans.load_gisaid_data()`` and written to
    ``filename`` with ``torch.save``. In both cases the result of
    ``mutrans.load_jhu_data(dataset)`` is merged in afterwards, so those
    entries are recomputed on every call and are never part of the cache.

    Args:
        filename: Path of the cache file. Defaults to the path this
            notebook has always used, so ``load_data()`` is unchanged.

    Returns:
        The dataset mapping, updated in place with the JHU entries.
    """
    if os.path.exists(filename):
        dataset = torch.load(filename)
    else:
        dataset = mutrans.load_gisaid_data()
        # Save before merging the JHU data so the cache holds only the
        # GISAID-derived entries.
        torch.save(dataset, filename)
    dataset.update(mutrans.load_jhu_data(dataset))
    return dataset

dataset = load_data()
# Expose every dataset entry as a notebook-level variable. This relies on
# locals() being the module namespace when called at top level.
locals().update(dataset)
# Summarise each entry in key order: tensors by shape, everything else by
# length.
for k in sorted(dataset):
    v = dataset[k]
    if not isinstance(v, torch.Tensor):
        print(f"{k} \t{type(v).__name__} of size {len(v)}")
        continue
    print(f"{k} \t{type(v).__name__} of shape {tuple(v.shape)}")

In [None]:
# Show the shape of daily_cases. The name is not assigned in this cell;
# presumably it is one of the dataset entries injected into the namespace
# by the preceding cell -- TODO confirm.
daily_cases.shape

In [None]:
# Add a pseudocount of 1/K to every entry (K = size of the last axis), then
# normalise along that axis so each slice sums to one.
probs = weekly_strains + 1 / weekly_strains.shape[-1]
probs = probs / probs.sum(dim=-1, keepdim=True)
# Take logs and centre each slice on its median along the last axis.
logits = probs.log()
logits = logits - logits.median(dim=-1, keepdim=True).values
# Histogram of all centred logits with a symlog y-axis; the trailing
# semicolon is kept from the original cell.
plt.hist(logits.flatten().numpy(), bins=100)
plt.yscale("symlog");

In [None]:
# Re-centre the existing logits on their mean along the last axis and draw
# the same histogram again.
logits = logits - logits.mean(dim=-1, keepdim=True)
plt.hist(logits.flatten().numpy(), bins=100)
plt.yscale("symlog");