# Explore GISAID data


In [11]:
import os
import pickle
import torch
from collections import Counter
from pyrocov import mutrans, pangolin
import matplotlib
import matplotlib.pyplot as plt

# Set matplotlib's default figure resolution to 200 dpi.
matplotlib.rcParams["figure.dpi"] = 200

## Explore columns

In [12]:
# NOTE(review): pickle.load can execute arbitrary code while unpickling, so
# this file should only ever come from a trusted pipeline run — confirm.
with open("results/gisaid.columns.pkl", "rb") as f:
    columns = pickle.load(f)
# Row count is taken from the "day" column; then list the column names.
print(f"loaded {len(columns['day'])} rows")
print(list(columns))

loaded 1851054 rows
['lineage', 'virus_name', 'accession_id', 'collection_date', 'location', 'add_location', 'day']


In [13]:
# Print the 20 most common lineages with their counts, appending the full
# name in parentheses whenever pangolin.compress() yields a different string.
strain_counts = Counter(columns["lineage"])
for strain, count in strain_counts.most_common(20):
    short, long = pangolin.compress(strain), pangolin.decompress(strain)
    # The stored lineage must already equal its decompressed form.
    assert strain == long, (strain, long)
    print(f"{count: >10d} {short}" + ("" if short == long else f" (aka {long})"))

    832417 B.1.1.7
     88799 B.1.2
     80377 B.1
     70042 B.1.177
     46762 B.1.1
     31619 B.1.617.2
     31437 B.1.429
     25715 P.1 (aka B.1.1.28.1)
     23204 B.1.160
     21723 B.1.351
     21602 B.1.526
     18691 B.1.1.519
     16986 B.1.1.214
     13143 B.1.427
     13025 B.1.221
     13004 B.1.177.21
     12857 B.1.258
     12332 D.2 (aka B.1.1.25.2)
     10623 B.1.243
      9975 B.1.526.2


In [3]:
# Tally rows by the first (up to) three "/"-separated components of the
# location string, whitespace-stripped. Locations with fewer than two
# components are left out of the tally.
counts = Counter(
    tuple(piece.strip() for piece in pieces[:3])
    for pieces in (loc.split("/") for loc in columns["location"])
    if len(pieces) >= 2
)

In [4]:
# Collect the second location component (index 1) of every truncated location
# key whose tally is at least 5000, de-duplicated and sorted.
fine_countries = sorted(
    {key[1] for key, tally in counts.items() if tally >= 5000}
)
print("\n".join(fine_countries))

Australia
Brazil
Canada
Denmark
Finland
France
Germany
Iceland
Ireland
Italy
Japan
Luxembourg
Netherlands
Portugal
Spain
Sweden
Switzerland
USA
United Kingdom


In [7]:
# Count the raw (untruncated) location strings and list those mentioning
# "United Kingdom", most frequent first, as "<count>\t<location>".
locations = Counter(columns["location"])
print(
    "\n".join(
        f"{n}\t{place}"
        for place, n in locations.most_common()
        if "United Kingdom" in place
    )
)

346691	Europe / United Kingdom / England
43553	Europe / United Kingdom / Scotland
34153	Europe / United Kingdom / Wales
6933	Europe / United Kingdom / Northern Ireland
171	Europe / United Kingdom / England / South Yorkshire
107	Europe / United Kingdom / England / London
11	Europe / United Kingdom / England / Derbyshire
2	Europe / United Kingdom / England / Yorkshire / Sheffield
2	Europe / United Kingdom
1	Europe / United Kingdom / England / Northamtonshire
1	Europe / United Kingdom / England / Nottinghamhisre
1	Europe / United Kingdom / Wales / Cardiff
1	Europe / United Kingdom / England / Warwickshire
1	Europe / United Kingdom / London
1	Europe / United Kingdom / Scotland / Fraserburg


## Explore count data

In [None]:
def load_data():
    """Return the dataset dict, using an on-disk cache when available.

    If ``results/mutrans.data.single.pt`` exists it is read with
    ``torch.load``; otherwise the dataset is built by
    ``mutrans.load_gisaid_data()`` and written to that path with
    ``torch.save``. In both cases the result of
    ``mutrans.load_jhu_data(dataset)`` is then merged into the dataset, so
    those entries are recomputed on every call and never stored in the cache.
    """
    cache_path = "results/mutrans.data.single.pt"
    if not os.path.exists(cache_path):
        # Cache miss: build from scratch and persist before adding JHU data.
        dataset = mutrans.load_gisaid_data()
        torch.save(dataset, cache_path)
    else:
        dataset = torch.load(cache_path)
    jhu_entries = mutrans.load_jhu_data(dataset)
    dataset.update(jhu_entries)
    return dataset

dataset = load_data()
# Expose every dataset entry as a notebook-level variable; later cells refer
# to these entries by bare name.
locals().update(dataset)
# Summarize each entry in key order: tensors by shape, everything else by len().
for k in sorted(dataset):
    v = dataset[k]
    print(
        f"{k} \t{type(v).__name__} of "
        + (f"shape {tuple(v.shape)}" if isinstance(v, torch.Tensor) else f"size {len(v)}")
    )

In [None]:
# NOTE(review): daily_cases is never assigned by name in the visible cells; it
# is presumably one of the dataset entries injected by locals().update(dataset)
# — confirm.
daily_cases.shape

In [None]:
# Smooth the counts by adding 1 / (size of the last axis) to every entry, then
# normalize in place so each slice along the last axis sums to one.
probs = weekly_strains + 1 / weekly_strains.size(-1)
probs.div_(probs.sum(dim=-1, keepdim=True))
# Take logs and center them in place on their median along the last axis.
logits = probs.log()
logits.sub_(logits.median(dim=-1, keepdim=True).values)
# Histogram of all centered logits; symlog y-axis keeps sparse bins visible.
plt.hist(logits.flatten().numpy(), bins=100)
plt.yscale("symlog");

In [None]:
# Re-center the existing logits in place on their mean along the last axis
# (they were median-centered when first computed) and redraw the histogram.
logits.sub_(logits.mean(dim=-1, keepdim=True))
plt.hist(logits.flatten().numpy(), bins=100)
plt.yscale("symlog");