# Parse the pango dataset

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import csv
import datetime
import scipy.stats as sts
import os

import sys, importlib
sys.path.append("..")
from evpytools import evplot
from evpytools import auxiliary as aux
from evpytools import definitions as defn
for mod in [evplot, aux, defn]:
    importlib.reload(mod)

In [None]:
def parse_date(s):
    """Parse a date string into a ``datetime.datetime``.

    Two layouts are recognised, selected by the separator found in ``s``:

    * ``"-"``: ``%Y-%m-%d`` (e.g. ``"2020-01-31"``)
    * ``"/"``: ``%m/%d/%y`` (e.g. ``"1/31/20"``)

    Raises:
        ValueError: if ``s`` contains neither separator, or does not match
            the layout selected by its separator.
    """
    if "-" in s:
        return datetime.datetime.strptime(s, "%Y-%m-%d")
    if "/" in s:
        return datetime.datetime.strptime(s, "%m/%d/%y")
    ## ValueError matches what strptime raises for malformed input, so a
    ## caller needs only one except clause for every kind of bad date
    raise ValueError(f"unknown date format: {s!r}")

## reference date: the integer "days" values computed in this notebook are
## offsets (in whole days) from this date
date0 = parse_date("2020-01-01")

def cumul_to_daily(xs):
    """Convert a cumulative series into per-step increments.

    The first increment is taken relative to zero, so the result has the
    same length as ``xs`` and its running sum reproduces ``xs``.

    Args:
        xs: cumulative counts; any one-dimensional sequence (list, tuple,
            NumPy array, ...).

    Returns:
        A list with ``xs[0], xs[1] - xs[0], xs[2] - xs[1], ...``.
    """
    ## np.diff with prepend=0 works for any array-like; the former
    ## ``[0] + xs`` only prepended for lists (for an ndarray it added 0
    ## elementwise and silently dropped the first increment)
    return list(np.diff(np.asarray(xs), prepend=0))

def extract_variant_timeseries(country, variants, data_dicts):
    """Collect the sequence-count time series of one country.

    Args:
        country: value that the "Country" field of a row must equal.
        variants: a single variant column name, or an iterable of names.
        data_dicts: rows as dicts (e.g. from ``csv.DictReader``) with the
            keys "Country", "Day", "Total" and one key per variant.

    Returns:
        A dict of equally long lists, in the row order of ``data_dicts``:
        "days" (offset from ``date0`` in days), "dates" (datetimes),
        "Total" and one entry per requested variant (counts as ints).

    Raises:
        KeyError: if a selected row lacks one of the required keys.
        ValueError: if a date or a count cannot be parsed.
    """
    ## a bare string is one variant name; any other iterable (list, tuple,
    ## ...) is taken as a collection of names
    if isinstance(variants, str):
        variants = [variants]
    else:
        variants = list(variants)

    ## filter region
    rows = [dd for dd in data_dicts if dd["Country"] == country]

    dates = [parse_date(dd["Day"]) for dd in rows]
    timeseries_dict = {
        "days": [(date - date0).days for date in dates],
        "dates": dates,
        "Total": [int(dd["Total"]) for dd in rows],
    }
    for variant in variants:
        timeseries_dict[variant] = [int(dd[variant]) for dd in rows]
    return timeseries_dict

In [None]:
filename = "../data/in/pango_2021-05-20.csv"

## read every row of the pango table as a dict keyed by the CSV header;
## newline="" lets the csv module do its own newline handling, as its
## documentation recommends for files passed to a reader
with open(filename, newline="") as f:
    reader = csv.DictReader(f)
    data_dicts = list(reader)

In [None]:
## variant column(s) to extract; uncomment exactly one line.
## Only variants[0] is written to the SMC dataset at the end of the notebook.
#variants = ["B.1.617.2"] ## delta
#variants = ["B.1.1.7"] ## alpha
#variants = ["B.1.351"]
variants = ["R.1"]

## country to analyse; the same name selects rows in the pango table and
## in the JHU deaths table
#country = "United Kingdom"
#country = "Netherlands"
country = "Japan"

## daily sequence counts (total and per variant) for the chosen country
timeseries_dict = extract_variant_timeseries(country, variants, data_dicts)

## import JHU data

In [None]:
## cumulative deaths per region, one column per date (JHU CSSE time series)
filename = os.path.expanduser(
    "~/Repositories/clones/COVID-19/"
    "csse_covid_19_data/csse_covid_19_time_series/time_series_covid19_deaths_global.csv"
)

## newline="" lets the csv module do its own newline handling, as its
## documentation recommends for files passed to a reader
with open(filename, newline="") as f:
    reader = csv.reader(f)
    ## first row is the header, the rest are the per-region rows
    epi_header, *epi_data = reader

## the first four columns are skipped; the remaining header cells are dates
epi_dates = [parse_date(d) for d in epi_header[4:]]
epi_days = [(date - date0).days for date in epi_dates]

## keep the row of the selected country whose first column is empty
matching_rows = [row for row in epi_data if row[1] == country and row[0] == ""]
if not matching_rows:
    raise ValueError(f"no row for country {country!r} with empty first column in {filename}")
sel_epi_data = matching_rows[0]

death_incidence_cumul = [int(x) for x in sel_epi_data[4:]]
death_incidence = cumul_to_daily(death_incidence_cumul)

## remove outliers

def get_rav(xs, w=3):
    """Return, for every position, the floor-average of its neighbours.

    The neighbours of position ``i`` are the up to ``w`` values on either
    side of it; ``xs[i]`` itself is excluded, so a single spike does not
    inflate its own reference level.

    Args:
        xs: sequence of numbers.
        w: half-width of the window (number of neighbours on each side).

    Returns:
        A list as long as ``xs``. Near the edges the window is truncated
        and the sum is divided by the number of neighbours actually
        present. A position without any neighbour (single-element input
        or ``w == 0``) gets its own value.
    """
    size = len(xs)
    ys = []
    for i in range(size):
        ## clamp the window: a negative start would wrap around to the end
        ## of the sequence and produce an empty slice
        lo = max(0, i - w)
        hi = min(size, i + w + 1)
        num_neighbours = hi - lo - 1
        if num_neighbours == 0:
            ys.append(xs[i])
            continue
        ys.append((np.sum(xs[lo:hi]) - xs[i]) // num_neighbours)
    return ys

## reference level per day: average of the neighbouring days
ra_inc = get_rav(death_incidence)

## a day whose count exceeds three times its reference level is treated as
## an outlier and overwritten (in place) with that reference level
for i, (smoothed, observed) in enumerate(zip(ra_inc, death_incidence)):
    if observed > 3 * smoothed:
        death_incidence[i] = smoothed

In [None]:
## three stacked panels: absolute sequence counts, variant fraction,
## daily death incidence
fig, axs = plt.subplots(3, 1, figsize=(7,7))

days = timeseries_dict["days"]
Ftot = timeseries_dict["Total"]
Fvars = [timeseries_dict[var] for var in variants]

## panel 0: total number of sequences first, then one line per variant
axs[0].plot(days, Ftot)
for Fvar in Fvars:
    axs[0].plot(days, Fvar)

## panel 1: fraction of sequences belonging to each variant; a day without
## any sequence has no defined fraction and is stored as NaN instead of
## raising ZeroDivisionError
fvars = [
    [n/N if N > 0 else float("nan") for n, N in zip(Fvar, Ftot)]
    for Fvar in Fvars
]

for fvar in fvars:
    axs[1].plot(days, fvar)

## panel 2: daily deaths on the time axis of the epidemiological data
axs[2].plot(epi_days, death_incidence)

axs[0].set_ylabel("num seq.")
axs[1].set_ylabel("freq mutant seq.")
axs[2].set_ylabel("death incidence")

fig.savefig("../seq-abs-rel_death.pdf")

## Week-level data

In [None]:
## Weeks run from Monday up to (not including) the next Monday.
## Only Mondays that occur in epi_dates open a week, so dates before the
## first such Monday belong to no week.

mondays = [d for d in epi_dates if d.weekday() == 0]
days_week = [(mon - date0).days for mon in mondays]

## for every Monday, the dates of epi_dates that fall into its week
dates_per_week = {
    mon: [d for d in epi_dates if mon <= d < mon + datetime.timedelta(days=7)]
    for mon in mondays
}

In [None]:
## week-level variant data

def aggregate_counts(ds, xs, dates, date_dict):
    """Sum per-date values over groups of dates.

    Args:
        ds: dates at which a value is available.
        xs: the values, aligned with ``ds``.
        dates: keys of ``date_dict`` for which a sum is wanted.
        date_dict: maps every element of ``dates`` to the dates in its group.

    Returns:
        One ``np.sum`` result per element of ``dates``; a date of a group
        that does not occur in ``ds`` contributes 0.
    """
    value_by_date = dict(zip(ds, xs))
    totals = []
    for group_key in dates:
        members = date_dict[group_key]
        totals.append(np.sum([value_by_date.get(d, 0) for d in members]))
    return totals


## weekly series share the Monday-based time axis
timeseries_dict_week = dict(dates=mondays, days=days_week)

## sum the daily counts of every variant, and of the total, per week
for var in [*variants, "Total"]:
    timeseries_dict_week[var] = aggregate_counts(
        timeseries_dict["dates"], timeseries_dict[var], mondays, dates_per_week
    )

In [None]:
## week-level epi data: daily death incidence (after outlier replacement)
## summed over the same Monday-based weeks as the sequence counts

death_incidence_week = aggregate_counts(epi_dates, death_incidence, 
                                        mondays, dates_per_week)

In [None]:
## Weekly version of the three-panel figure (sequence counts, variant
## fraction with intervals, deaths), with vertical lines marking the time
## window [t0, tend] that is used when writing the SMC dataset.
fig, axs = plt.subplots(3,1, figsize=(25,7))

days = timeseries_dict_week["days"]
Ftot = timeseries_dict_week["Total"]
Fvars = [timeseries_dict_week[var] for var in variants]

## panel 0: weekly total in black, then one line per variant
axs[0].plot(days, Ftot, color='k')
for Fvar in Fvars:
    axs[0].plot(days, Fvar)

## weekly variant fraction n/N
## NOTE(review): weeks without any sequence have N == 0, so the fraction is
## undefined there -- confirm that this is handled as intended
fvars = [[n/N for n, N in zip(Fvar, Ftot)] for Fvar in Fvars]
## central 95% interval of a Beta(n + 0.5, N - n + 0.5) distribution for
## every week, as a (lower, upper) pair
civars = [[sts.beta.interval(0.95, n+0.5, N-n+0.5) 
           for n, N in zip(Fvar, Ftot)] for Fvar in Fvars]


## panel 1: fraction per variant ...
for fvar in fvars:
    axs[1].plot(days, fvar)
    
## ... with each week's interval drawn as a black vertical segment
for civar in civars:
    for t, CI in zip(days, civar):
        axs[1].plot([t,t], CI, color='k')


## panel 2: weekly deaths
axs[2].plot(days_week, death_incidence_week, marker='o', markersize=3)

## Start index (into days_week) and length in weeks of the window, chosen
## per variant; the values for the variant selected above must be active.

## B.1.1.7
#idx0 = 31 
#num_weeks = 29

## B.1.351
#idx0 = 44
#num_weeks = 20

## R.1
idx0 = 43
num_weeks = 18


## window boundaries in days since date0; both are used again when the
## SMC dataset is written
t0 = days_week[idx0]
tend = days_week[idx0 + num_weeks]
print("t0 =", t0, "tend = ", tend)

## mark both window boundaries in every panel
for ax in axs:
    ax.axvline(x=days_week[idx0])
    ax.axvline(x=days_week[idx0 + num_weeks])
    
## one tick per week on the bottom panel, labelled with the day number
axs[-1].set_xticks(days_week)
axs[-1].set_xticklabels(days_week, rotation=90)

## no-op; presumably keeps the notebook from echoing the return value of
## the last call -- TODO confirm
pass

In [None]:
## make dataset for SMC

## only the first selected variant is exported
var = variants[0]

## spaces in the country name are replaced so that it can be used both in
## the file name and as the first column
country_fn = country.replace(" ", "_")

ev = "[RESET_CASES]"
## every observation column is followed by the same censoring code
c = defn.uncensored_code

## one tab-separated line per week with t0 < t <= tend:
## country, day, event, deaths, code, variant count, code, total count, code
with open(f"../data/in/sars2-seq-death-week-{country_fn}-{var}.tsv", 'w') as f:
    weekly_rows = zip(
        days_week,
        death_incidence_week,
        timeseries_dict_week["Total"],
        timeseries_dict_week[var],
    )
    for t, D, N, n in weekly_rows:
        if not (t0 < t <= tend):
            continue
        line = f"{country_fn}\t{t}\t{ev}\t{D}\t{c}\t{n}\t{c}\t{N}\t{c}"
        f.write(line + '\n')