# Data import and Fits of the SARS-CoV-2 D614G model

* Data from UK and NL
* Model fits
* Diagnostics

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import re
import scipy.stats as sts
import xml.etree.ElementTree as ET
import warnings
import pickle
import copy
import csv
import datetime
import json
import scipy.special

from matplotlib.patches import Patch
from matplotlib.lines import Line2D
import matplotlib.ticker as ticker
from matplotlib.transforms import blended_transform_factory


from mpl_toolkits.axes_grid1.inset_locator import inset_axes

import sys, importlib
sys.path.append("..")
from evpytools import evplot
from evpytools import pftools
from evpytools import auxiliary as aux
from evpytools import definitions as defn
for mod in [evplot, pftools, aux, defn]:
    importlib.reload(mod)

In [None]:
plt.rcParams.update({'font.size': 18})

## SARS-CoV-2 data
Mutant frequencies and incidence.

In [None]:
#seq_epi_data_file = "/home/chris/Projects/COVID-19/G614D-mutant/seq_epi-08282020.csv"
#seq_epi_data_file = "/home/chris/Projects/COVID-19/G614D-mutant/new_mutations/seq501_epi-21122020"
seq_epi_data_file = "/home/chris/Projects/COVID-19/G614D-mutant/new_mutations/seq_epi-08282020.csv"
## original B117
#seq_epi_data_file = "/home/chris/Projects/COVID-19/G614D-mutant/new_mutations/seq-pl_epi-07012020.csv"
## longer B117
#seq_epi_data_file = "/home/chris/Projects/COVID-19/G614D-mutant/new_mutations/seq-pl_epi-26022021.csv"
## SA variant
#seq_epi_data_file = "/home/chris/Projects/COVID-19/G614D-mutant/new_mutations/seq-pl-sa_epi-07012020.csv"

## specify variants. Mutant and WT (WT is optional):

##### D614G options

var_keys = ["seq_G", "seq_D"]
aa_pos = 614
total_key = "seq_total"


##### B.1.1.7 options

#var_keys = ["N_mt"]
#aa_pos = 501 ## will be used for file names!!
#total_key = "N"


##### ZA options

#var_keys = ["N_mt"]
#aa_pos = "ZA" ## will be used for file names!!
#total_key = "N"


### IMPORT DATA

## read every row of the combined sequence / epidemiology table as a dict keyed by column name.
## newline='' leaves newline handling to the csv module, as the csv documentation
## asks for any file object handed to a csv reader
with open(seq_epi_data_file, newline='') as f:
    reader = csv.DictReader(f)
    data_dicts = list(reader)


fields = var_keys + [total_key, "deaths", "cases", "recoveries"]

## start counting from Jan-1-2020
day0 = datetime.datetime.strptime("2020-01-01", "%Y-%m-%d")

print("weekday of day zero:", day0.weekday()) ## monday = 0, sunday = 6

for row in data_dicts:
    ## replace the date string by a datetime and add the day number relative to day0
    parsed_date = datetime.datetime.strptime(row["date"], "%Y-%m-%d")
    row["date"] = parsed_date
    row["t"] = (parsed_date - day0).days
    ## every count field becomes an int and gets a companion "<field>_CC" censoring code;
    ## 'NA' entries are stored as 0 and flagged with the missing code
    for k in fields:
        raw = row[k]
        if raw == 'NA':
            row[k], row[k + "_CC"] = 0, defn.missing_code
        else:
            row[k], row[k + "_CC"] = int(raw), defn.uncensored_code


In [None]:
#region = "Netherlands"
region = "United Kingdom"
#region = "South Africa"
sel_data_dicts = [copy.copy(row) for row in data_dicts if row["country"] == region]
   

## check for negative counts...
for dd in sel_data_dicts:
    if dd["deaths"] < 0:
        print(dd["t"], dd["deaths"])
    
## apply corrections
corrections_NL = {
    208 : {"deaths" : 0},
    207 : {"deaths" : 1},
    223 : {"deaths" : 0},
    224 : {"deaths" : 4},
    191 : {"deaths" : 0},
    195 : {"deaths" : 0},
    199 : {"deaths" : 0},
}

if region == "Netherlands":
    for t, cord in corrections_NL.items():
        for dd in sel_data_dicts:
            if dd["t"] == t:
                for k, v in cord.items():
                    dd[k] = v

In [None]:
## take a look at the data

fig, axs = plt.subplots(2, 1, figsize=(14,7), sharex=True)

axs[0].plot([row["t"] for row in sel_data_dicts],
            [row["cases"] for row in sel_data_dicts], color="k")

bx = axs[0].twinx()

bx.plot([row["t"] for row in sel_data_dicts],
        [row["deaths"] for row in sel_data_dicts], color="red")

axs[0].set_ylabel("cases")
bx.set_ylabel("deaths", color='red')

Fms = [row[var_keys[0]] / seq_total if (seq_total := row[total_key]) > 0 
       else np.nan for row in sel_data_dicts]

## CIs
lFms = [sts.beta.ppf(0.025, row[var_keys[0]]+0.5, row[total_key]-row[var_keys[0]]+0.5) 
        for row in sel_data_dicts]
uFms = [sts.beta.ppf(0.975, row[var_keys[0]]+0.5, row[total_key]-row[var_keys[0]]+0.5) 
        for row in sel_data_dicts]

ts = [row["t"] for row in sel_data_dicts]
dates = [day0 + datetime.timedelta(days=t) for t in ts]
date_strs = [date.strftime("%m-%d-%Y") for date in dates]

axs[1].scatter(ts, Fms, color="k", s=5)

for t, l, u in zip(ts, lFms, uFms):
    axs[1].plot([t,t], [l,u], color='k', alpha=0.5)
    
axs[1].set_ylabel("fraction mutant ($F_m$)")

axs[0].set_title(region)

dt = 14
axs[1].set_xticks(ts[::dt])
axs[1].set_xticklabels(date_strs[::dt], rotation=45, ha='right')

fig.savefig(f"../data/out/figures/sars-data-{region}.pdf", bbox_inches="tight")

In [None]:
## compute incidence
def add_incidence(dds, cumul=True):
    """Add the keys "cases_incidence" and "deaths_incidence" to every record, in place.

    Parameters
    ----------
    dds : list of dict
        Records with the keys "cases" and "deaths", in the order in which
        differences have to be taken (presumably one record per consecutive day).
    cumul : bool
        If True, "cases" and "deaths" are treated as cumulative counts and the
        incidence of a record is its count minus the count of the previous
        record; the first record has no predecessor and gets incidence 0.
        If False, the counts are copied unchanged into the incidence keys.

    Returns
    -------
    None
        The records in ``dds`` are modified in place.
    """
    keys = ("cases", "deaths")
    if not cumul:
        ## just copy cases and deaths
        for dd in dds:
            for key in keys:
                dd[f"{key}_incidence"] = dd[key]
        return
    ## convert cumulative incidence to actual daily incidence
    prev = None
    for dd in dds:
        for key in keys:
            dd[f"{key}_incidence"] = 0 if prev is None else dd[key] - prev[key]
        prev = dd
        
#add_incidence(sel_data_dicts) ## for old data
add_incidence(sel_data_dicts, cumul=False) ## for new data
        
## plot incidence

fig, ax = plt.subplots(1, 1, figsize=(14,5))
bx = ax.twinx()

ax.plot(ts, [row["cases_incidence"] for row in sel_data_dicts], color='k')
bx.plot(ts, [row["deaths_incidence"] for row in sel_data_dicts], color='red')

dt = 14
ax.set_xticks(ts[::dt])
ax.set_xticklabels(date_strs[::dt], rotation=45, ha='right')

ax.set_ylabel("cases")
bx.set_ylabel("deaths", color='red')

ax.set_title(region)

fig.savefig(f"../data/out/figures/sars-incidence-data-{region}.pdf", bbox_inches="tight")

### Incidence of cases and deaths per day
* Weekday-weekend pattern
* Delay in death
* Increase in cases in July, but not deaths

### Some data describing the lockdown
[Government response tracker](https://www.bsg.ox.ac.uk/research/research-projects/coronavirus-government-response-tracker)

* This data is currently not used: any ideas?
* Other useful data: simple measure of testing efforts?

**TODO: use the new long-format file from OxCGRT**

In [None]:
with open("../data/in/OxCGRT_latest.csv", 'r') as f:
    reader = csv.DictReader(f)
    ld_table = [row for row in reader]

country_code_dict = {
    "United Kingdom" : "GBR",
    "Netherlands" : "NLD",
    "South Africa" : "ZAF",
}
    
country_code = country_code_dict[region]

## filter only region-specific data from 0xCGRT
sel_ld_table = [rec for rec in ld_table if rec["CountryCode"] == country_code and rec["RegionName"] == '']

ld_dates = [datetime.datetime.strptime(rec["Date"], "%Y%m%d") for rec in sel_ld_table]
ld_ts = [(d - day0).days for d in ld_dates]

ld_scores = [rec["StringencyIndex"] for rec in sel_ld_table]
ld_scores = [float(x) if x != '' else None for x in ld_scores]

## extract testing policy data
ld_testpol_codes = [rec["H2_Testing policy"] for rec in sel_ld_table]
ld_testpol_codes = [float(x) if x != '' else None for x in ld_testpol_codes]

## find periods of testing policies

testpol_ts_dict = {
    cc : [t for t, c in zip(ld_ts, ld_testpol_codes) if c == cc]
    for cc in [0,1,2,3]
}

testpol_period_dict = {
    cc : (np.min(tt), np.max(tt)) for cc, tt in testpol_ts_dict.items()
    if len(tt) > 0
}

fig, ax = plt.subplots(1, 1, figsize=(10,5))

ax.plot(ts, [row["cases_incidence"] for row in sel_data_dicts], color='k')

bx = ax.twinx()
bx.plot(ld_ts, ld_scores, color='blue')

ax.set_ylabel("cases")
bx.set_ylabel("stringency index", color='blue')

ax.set_title(region)

## color test-policy periods

testpol_color_dict = {
    0 : 'tab:gray',
    1 : 'tab:blue',
    2 : 'tab:green',
    3 : 'tab:orange'
}

for c, period in testpol_period_dict.items():
    ax.axvspan(*period, color=testpol_color_dict[c], label=str(c), alpha=0.5)


ax.legend()

In [None]:
## export to a simple data file

region_nospace = region.replace(" ", "_")
with open(f"../data/in/sars2-{region_nospace}.tsv", 'w') as f:
    for row in sel_data_dicts:
        ID = region_nospace
        t = row["t"]
        ev = "[]"
        dI = row['cases_incidence']
        dIc = defn.uncensored_code
        dD = row['deaths_incidence']
        dDc = defn.uncensored_code
        Nm = row[var_keys[0]]
        Nmc = defn.uncensored_code
        N = row[total_key]
        Nc = defn.uncensored_code
        line = f"{ID}\t{t}\t{ev}\t{dI}\t{dIc}\t{dD}\t{dDc}\t{Nm}\t{Nmc}\t{N}\t{Nc}\n"
        f.write(line)

### Get some testing information from OWID database

https://github.com/owid/covid-19-data/tree/master/public/data

In [None]:
with open("../data/in/owid-covid-data.json", 'rb') as f:
    owid_data = json.load(f)

In [None]:
sel_owid_data = owid_data[country_code]["data"]
owid_dates = [datetime.datetime.strptime(dd["date"], "%Y-%m-%d") for dd in sel_owid_data]

owid_ts = [(d - day0).days for d in owid_dates]
#test_key = "new_tests"
test_key = "new_tests_smoothed"
owid_tests = [int(dd[test_key]) if test_key in dd.keys() else None for dd in sel_owid_data]
cases_key = "new_cases"
owid_cases = [int(dd[cases_key]) if cases_key in dd.keys() else None for dd in sel_owid_data]


fig, ax = plt.subplots(1, 1, figsize=(7,5))

ax.plot(owid_ts, owid_tests, color='tab:blue')

bx = ax.twinx()
bx.plot(owid_ts, owid_cases, color='tab:red')

ax.set_ylabel("tests", color='tab:blue')
bx.set_ylabel("positive tests", color='tab:red')


## add testing policy data
for c, period in testpol_period_dict.items():
    ax.axvspan(*period, color=testpol_color_dict[c], label=str(c), alpha=0.5)

ax.legend()
    

### Aggregate case count data to week level
This accounts for the periodic (weekday-weekend) reporting pattern.

In [None]:
## define weeks

## first and last observation dates of the selected region
sel_day0 = sel_data_dicts[0]["date"]
sel_day_last = sel_data_dicts[-1]["date"]
## Monday on or before the first observation (weekday(): Monday = 0, ..., Sunday = 6)
sel_monday0 = sel_day0 - datetime.timedelta(days=sel_day0.weekday())
## (weekday() + 1) % 7 is the number of days since the most recent Sunday, so this
## is the last Sunday on or before the final observation
sel_sunday_last = sel_day_last - datetime.timedelta(days=(sel_day_last.weekday() + 1) % 7)
## historical (misleading) name for the same Sunday, kept for backward compatibility
sel_monday_last = sel_sunday_last
print(sel_monday_last)
## inclusive number of days from the first Monday up to the last Sunday
num_days = (sel_sunday_last - sel_monday0).days + 1

## week boundaries: the Monday starting each week, plus the Monday after the last Sunday
mondays = [sel_monday0 + datetime.timedelta(days=7 * i) for i in range(num_days // 7 + 1)]

## collect incidence between these mondays

sel_weekly_data_dicts = []

for m0, m1 in zip(mondays[:-1], mondays[1:]):
    wdds = [dd for dd in sel_data_dicts if dd["date"] >= m0 and dd["date"] < m1]
    C = np.sum([dd["cases_incidence"] for dd in wdds])
    D = np.sum([dd["deaths_incidence"] for dd in wdds])
    Nmut = np.sum([dd[var_keys[0]] for dd in wdds])
    Ntot = np.sum([dd[total_key] for dd in wdds])
    sel_weekly_data_dicts.append({
        "date" : m1,
        "t" : (m1 - day0).days,
        "cases_incidence" : C,
        "deaths_incidence" : D,
        var_keys[0] : Nmut,
        total_key : Ntot
    })
    
    
fig, (ax, cx) = plt.subplots(2, 1, figsize=(14,10), sharex=True)
bx = ax.twinx()

ws = [row["t"] for row in sel_weekly_data_dicts]

ax.plot(ws, [row["cases_incidence"] for row in sel_weekly_data_dicts], color='k', marker='o')
bx.plot(ws, [row["deaths_incidence"] for row in sel_weekly_data_dicts], color='red', marker='o')

## plot aggregated sequence data

Ntots = [row[total_key] for row in sel_weekly_data_dicts]
Nmuts = [row[var_keys[0]] for row in sel_weekly_data_dicts]
Fms = [x/n if n > 0 else 0.5 for x, n in zip(Nmuts, Ntots)]
cx.scatter(ws, Fms, color='k', marker='_')
for w, Nmut, Ntot in zip(ws, Nmuts, Ntots):
    l, u = sts.beta.interval(0.95, Nmut+0.5, Ntot-Nmut+0.5)
    cx.plot([w,w], [l,u], color='k', alpha=0.5)


dates = [dd["date"] for dd in sel_weekly_data_dicts]
date_strs = [date.strftime("%m-%d-%Y") for date in dates]

dt = 2
cx.set_xticks(ws[::dt])
cx.set_xticklabels(date_strs[::dt], rotation=45, ha='right')

ax.set_ylabel("cases")
bx.set_ylabel("deaths", color='red')

ax.set_title(region)


fig.savefig(f"../data/out/figures/sars-weekly-incidence-data-{region}.pdf", 
            bbox_inches="tight")

In [None]:
## find date of first mutant case

## first weekly record with at least one mutant sequence.
## The None default avoids a StopIteration (and gives no cell output) if the mutant was never observed
next((dd for dd in sel_weekly_data_dicts if dd[var_keys[0]] > 0), None)

In [None]:
## select start and end dates

start_date = datetime.datetime.strptime("2020-02-25", "%Y-%m-%d") ## NL and UK D614G
#start_date = datetime.datetime.strptime("2020-09-01", "%Y-%m-%d") ## NL and UK N501Y
#start_date = datetime.datetime.strptime("2020-09-22", "%Y-%m-%d") ## ZA
## WARNING! make sure this starts on a Tuesday!

if aa_pos == 614:
    end_date = datetime.datetime.strptime("2020-08-24", "%Y-%m-%d") ## NL and UK D614G
elif aa_pos == 501:
    if region == "United Kingdom":
        end_date = datetime.datetime.strptime("2021-02-22", "%Y-%m-%d") ## UK N501Y
    elif region == "Netherlands":
        end_date = datetime.datetime.strptime("2021-02-08", "%Y-%m-%d") ## NL N501Y
    else:
        print("WARNING! no end date specified for region " + region)
        end_date = None
else:
    print("WARNING! no end date specified for variant " + str(aa_pos))
    end_date = None
    
end_date_epi = None

if start_date.weekday() != 1:
    print("WARNING! make sure data stream starts on a Tuesday!")


In [None]:
## make a simple file with the data for estavoir
date_dict = {dd["date"] : dd for dd in sel_weekly_data_dicts}
    
with open(f"../data/in/sars2-weekly-incidence-{region_nospace}-{aa_pos}.tsv", 'w') as f:
    for row in sel_data_dicts:
        d = row["date"]
        if d < start_date or (end_date is not None and d > end_date):
            ## skip dates before a chosen starting point or after chosen end data
            continue
        if d in date_dict.keys():      
            ev = "[RESET_CASES]"
            wdd = date_dict[d]
            if end_date_epi is None or d <= end_date_epi:
                #dI, dIc = wdd['cases_incidence'], defn.uncensored_code
                dI, dIc = 0, defn.missing_code ## TESTING: don't use cases
                dD, dDc = wdd['deaths_incidence'], defn.uncensored_code
            else:
                dI, dIc = 0, defn.missing_code
                dD, dDc = 0, defn.missing_code
        else:
            ev = "[]"
            dI, dIc = 0, defn.missing_code
            dD, dDc = 0, defn.missing_code
        Nm = row[var_keys[0]]
        Nmc = defn.uncensored_code
        N = row[total_key]
        Nc = defn.uncensored_code
        ID = region_nospace
        t = row["t"]
        line = f"{ID}\t{t}\t{ev}\t{dI}\t{dIc}\t{dD}\t{dDc}\t{Nm}\t{Nmc}\t{N}\t{Nc}\n"
        f.write(line)

In [None]:
## make a simple file with the data for estavoir

date_dict = {dd["date"] : dd for dd in sel_weekly_data_dicts}
    
with open(f"../data/in/sars2-seq-death-week-{region_nospace}-{aa_pos}.tsv", 'w') as f:
    for row in sel_weekly_data_dicts:
        d = row["date"]
        if d < start_date or (end_date is not None and d > end_date):
            ## skip dates before a chosen starting point or after chosen end data
            continue
        ev = "[RESET_CASES]"
        if end_date_epi is None or d <= end_date_epi:
            dD, dDc = row['deaths_incidence'], defn.uncensored_code
            dI, dIc = row['cases_incidence'], defn.uncensored_code
        else:
            dD, dDc = 0, defn.missing_code
            dI, dIc = 0, defn.missing_code
        Nm = row[var_keys[0]]
        Nmc = defn.uncensored_code
        N = row[total_key]
        Nc = defn.uncensored_code
        ID = region_nospace
        t = row["t"]
        line = f"{ID}\t{t}\t{ev}\t{dI}\t{dIc}\t{dD}\t{dDc}\t{Nm}\t{Nmc}\t{N}\t{Nc}\n"
        f.write(line)

In [None]:
## compute the day number of border closing

date_close = datetime.datetime.strptime("12-21-2020", "%m-%d-%Y")

print("day number border closing:", (date_close-day0).days)

### Plot the results from SMC

In [None]:
#pfout_file = "../data/out/ipf_result-sars_model.xml"

#pfout_file = "../data/out/wk-seq/NL/D614G/ipf_result-sars_model_NL-614-wk.xml"
#pfout_file = "../data/out/wk-seq/UK/D614G/ipf_result-sars_model_UK-614-wk.xml"

#pfout_file = "../data/out/wk-seq/UK/B117/ipf_result-sars_model_UK-501-wk.xml"
#pfout_file = "../data/out/wk-seq/NL/B117/ipf_result-sars_model_NL-501-wk.xml"

#pfout_file = "../data/out/wk-seq/NL/B117/ipf_result-sars_model_NL-501-wk_sigma=0.7.xml"

#pfout_file = "../data/out/wk-seq/UK/B117/ipf_result-sars_model_UK-501-long-wk.xml"
pfout_file = "../data/out/wk-seq/NL/B117/ipf_result-sars_model_NL-501-long-wk.xml"


#pfout_file = "../data/out/wk-seq/UK/D614G/ipf_result-sars_model_UK-614-wk-long.xml"
#pfout_file = "../data/out/wk-seq/NL/D614G/ipf_result-sars_model_NL-614-wk-long.xml"


### OLD FILES

#pfout_file = "../data/out/ipf_result-sars_model_UK-501_10K.xml"
#pfout_file = "../data/out/ipf_result-sars_model_NL-501_10K.xml"

#pfout_file = "../data/out/ipf_result-sars_model_NL-614_10K_short.xml"
#pfout_file = "../data/out/ipf_result-sars_model_UK-614_10K_short.xml"

#pfout_file = "../data/out/ipf_result-sars_model-NL-fit_sigma_od.xml"
#pfout_file = "../data/out/ipf_result-sars_model-UK-fit_sigma_od.xml"
#pfout_file = "../data/out/ipf_result-sars_model-fit_sigma_od_r-UK.xml"
#pfout_file = "../data/out/ipf_result-sars_model-fit_sigma_od_r-NL.xml"
#pfout_file = "../data/out/ipf_result-sars_model-NLnu.xml"
#pfout_file = "../data/out/ipf_result-sars_model_sigma=0.36.xml"
#pfout_file = "../data/out/ipf_result-sars_model-migration.xml"
#pfout_file = "../data/out/ipf_result-sars_model-UK_N501Y_10K.xml"
#pfout_file = "../data/out/ipf_result-sars_model-UK_D614G_10K.xml"
#pfout_file = "../data/out/ipf_result-sars_model-NL_614_10K.xml"

In [None]:
idx = -1 ## select one of the PF iterations

#pf_data = pftools.extract_pfilter_data(pfout_file, idx=idx, parnames=parnames)
pf_data = pftools.extract_pfilter_data(pfout_file)
parnames = pf_data["parnames"]

In [None]:
fit_dicts = {}  

for r, ID in enumerate(pf_data["pfIDs"]):
    ## make a dictionary
    fit_dicts[ID] = {
        "params" : pftools.makeParDict(r, parnames, pf_data["param_medians"], 
                                       pf_data["param_ranges"])
        ## TODO: other elements?
    }
    
## TODO: extract hyper parameters

In [None]:
stat_dicts = {}

for r, ID in enumerate(pf_data["pfIDs"]):
    ts = [float(pf.attrib["t"]) for pf in pf_data["particle_filters"][ID]]
    Jeffs = [int(pf.attrib["J_eff"]) for pf in pf_data["particle_filters"][ID]]
    Jcofs = [int(pf.attrib["J_coffin"]) for pf in pf_data["particle_filters"][ID]]
    Jsims = [float(pf.attrib["J_inv_simpson"]) for pf in pf_data["particle_filters"][ID]]
    cLLs = [float(pf.attrib["cond_LL_hat"]) for pf in pf_data["particle_filters"][ID]]
    stat_dicts[ID] = {
        "ts" : ts,
        "Jeffs" : Jeffs,
        "Jsims" : Jsims,
        "Jcofs" : Jcofs,
        "cLLs" : cLLs,
    }

In [None]:
## Figure for manuscript

## OPTIONS

INCL_INSET = False


fig, axs = plt.subplots(3, figsize=(10,12), sharex=True)

ID = pf_data["pfIDs"][0] ## select single ID

###### TRAJECTORIES OF LATENT VARIABLES ######

## plot trajectories on a linear scale

varnames = ["Iw", "Im", "H"]
pretty_varnames = ["$I_w$", "$I_m$", "$H$"]

trajcolors = ['tab:orange', 'tab:blue', 'tab:red']

alpha_traj = 0.7

ax = axs[0]
if INCL_INSET:
    axins = inset_axes(ax, width="12%", height="30%", loc=2)
## PLOT FILTERED PATHS
for j, path in enumerate(pf_data["paths"][ID]):
    ## extract timeseries
    xs = path.findall("state")
    ts = [float(x.attrib["t"]) for x in xs]
    for color, X, lab in zip(trajcolors, varnames, pretty_varnames):
        Xs = [float(x.find(f"var_vec[@name='{X}']/var").attrib["val"]) for x in xs]
        ## plot
        kwargs = {"label" : lab} if j == 0 else {}
        ax.plot(ts, Xs, color=color, alpha=alpha_traj, linewidth=0.5, zorder=1, **kwargs)
        if INCL_INSET:
            axins.plot(ts, Xs, color=color, alpha=alpha_traj, linewidth=0.5, zorder=1, **kwargs)
            ## restricy limits of axins
            axins.set_xlim(ts[0], 75)
            axins.set_ylim(-50, 2e3)
            axins.yaxis.set_label_position("right")
            axins.yaxis.tick_right()
            axins.tick_params(axis='both', which='major', labelsize='x-small')
## labels
ax.set_ylabel("population size")
leg = ax.legend(ncol=1, loc=2, fontsize='small')
## the trajectories are drawn thin and semi-transparent; make the legend keys opaque and thicker.
## Legend.legendHandles was renamed to Legend.legend_handles in Matplotlib 3.7 and the
## old name was removed in a later release, so try the new attribute first
handles = getattr(leg, "legend_handles", None)
if handles is None:
    handles = leg.legendHandles
for lh in handles:
    lh.set_alpha(1)
    lh.set_linewidth(1)
ax.yaxis.set_major_formatter(ticker.FuncFormatter(evplot.y_fmt))


###### DATA AND PREDICTIONS ######

trajcolor = ["pink", "deepskyblue"]
varcolor = ['purple', 'tab:blue']
obsvarnames = ['D', 'Fm']
data_markers = ['o', '_']
#legend_locs = [1, 4]
#legend_locs = [1, 1]
legend_locs = [2, 2]
data_colors = ['w', 'lightgray']
obslabels = ["deaths", "mutant frequency"]
yscales = ['linear', 'linear']
alpha=0.7
dt = 1

Ob = len(obsvarnames)


print("sampled", len(pf_data["paths"][ID]), "trajectories for", ID)
## PLOT DATA
# deaths
ax = axs[1]
ws = [row["t"] for row in sel_weekly_data_dicts]
Ds = [row["deaths_incidence"] for row in sel_weekly_data_dicts]
ax.scatter(ws, Ds, color='k', edgecolor='k', zorder=4, label='data', s=20)    
# mutant freq
ax = axs[2]
ts = [row["t"] for row in sel_data_dicts if row[total_key] > 0]
Fms = [row[var_keys[0]] / n for row in sel_data_dicts if (n := row[total_key]) > 0]
## CIs
lFms = [sts.beta.ppf(0.025, row[var_keys[0]]+0.5, n - row[var_keys[0]]+0.5) 
        for row in sel_data_dicts if (n := row[total_key]) > 0]
uFms = [sts.beta.ppf(0.975, row[var_keys[0]]+0.5, n - row[var_keys[0]]+0.5) 
        for row in sel_data_dicts if (n := row[total_key]) > 0]
for t, l, u in zip(ts, lFms, uFms):
    ax.plot([t,t], [l,u], color='k', alpha=0.3)
ax.scatter(ts, Fms, color='k', edgecolor='k', zorder=4, label='data', s=20,  marker='|')
## PLOT FILTERED PATHS
for path in pf_data["paths"][ID]:
    for i, X in enumerate(obsvarnames):
        ## extract timeseries
        xs = path.findall("state")
        ts = [float(x.attrib["t"]) for x in xs]
        Xs = [float(x.find(f"var_vec[@name='{X}']/var").attrib["val"]) for x in xs]
        ## plot
        ax = axs[i+1]
        ax.plot(ts, Xs, color=trajcolor[i], alpha=alpha, linewidth=0.5, zorder=1)
## PLOT PREDICTION RANGES
ts = [float(x.attrib["t"]) for x in pf_data["pred_medians"][ID]]
for i, X in enumerate(obsvarnames):
    if X in ["C", "D", 'Fm']:
        ## mask[j] is True when ts[j] is NOT one of the weekly observation times.
        ## A set gives constant-time membership tests instead of scanning the list per time point
        ## NOTE(review): presumably pfilter_boxplot skips masked time points -- confirm in evplot
        ws = [row["t"] for row in sel_weekly_data_dicts]
        weekly_times = set(ws)
        mask = [t not in weekly_times for t in ts]
    else:
        mask = None
    ax = axs[i+1]
    ## extract the values of observable X from the ranges, prediction medians and filter medians
    rans = pf_data["ranges"][ID]
    Xs_ran = [[float(x.find(f"var_vec[@name='{X}']/var").attrib["val"]) for x in ran]
              for ran in rans]
    Xs_pred = [float(x.find(f"var_vec[@name='{X}']/var").attrib["val"])
               for x in pf_data["pred_medians"][ID]]
    Xs_filt = [float(x.find(f"var_vec[@name='{X}']/var").attrib["val"])
               for x in pf_data["filter_medians"][ID]]
    evplot.pfilter_boxplot(ax, ts, Xs_ran, Xs_pred, Xs_filt, mask=mask,
                           color=varcolor[i], dt=dt)
## lims and labels
for i, X in enumerate(obsvarnames):
    ax = axs[i+1]
    #ax.set_ylim(0, ylims[i])
    ax.set_yscale(yscales[i])
    ## labels
    ax.set_ylabel(obslabels[i])
    ## Legend
    legend_elements = [
        Line2D([0], [0], marker=data_markers[i], color=data_colors[i], label='data', 
               markerfacecolor='k', markeredgecolor='k', markersize=7),
        Line2D([0], [0], color=varcolor[i], label='model'),
    ]
    ax.legend(handles=legend_elements, ncol=1, fontsize='small', loc=legend_locs[i])
    
axs[2].set_ylim(-0.05, 1.05)
axs[0].set_title(region)
axs[-1].set_xlabel(f"days since {day0.strftime('%m-%d-%Y')}")

dx_num = -0.15

axs[0].text(dx_num, 1.0, "A", fontsize=24, transform=axs[0].transAxes)
axs[1].text(dx_num, 1.0, "B", fontsize=24, transform=axs[1].transAxes)
axs[2].text(dx_num, 1.0, "C", fontsize=24, transform=axs[2].transAxes)

#axs[1].set_ylim(0, 400)

#axs[0].set_yscale("log")

tmin = (start_date - day0).days-3
#tmax = 250
tmax = 430
#tmax = 360
axs[-1].set_xlim(tmin, tmax)
   
    
fig.align_ylabels(axs)
    
fig.savefig(f"../data/out/figures/pf-traj-fit-sars_model-{region_nospace}.pdf", 
            bbox_inches="tight")

In [None]:
## plot PF statistics

R = len(pf_data["pfIDs"])

fig, axs = plt.subplots(R, 1, figsize=(10,3*R))

if R == 1:
    axs = np.array([axs])

bxs = [ax.twinx() for ax in axs]

for i, ID in enumerate(pf_data["pfIDs"]):
    stat_dict = stat_dicts[ID]
    ax = axs[i]
    ax.plot(stat_dict["ts"], stat_dict["cLLs"], color='k')
    ## plot PF statistics
    bx = bxs[i]
    bx.plot(stat_dict["ts"], stat_dict["Jeffs"], color='tab:red', alpha=0.7)
    #bx.plot(stat_dict["ts"], stat_dict["Jsims"], color='tab:purple')
    ax.set_ylabel("conditional\nlog-likelihood")
    bx.set_ylabel("effective\nswarm size", color='red')
    
## share axes after the fact
## Joining the Grouper returned by get_shared_y_axes() was deprecated in Matplotlib 3.6
## and stopped working in later releases; Axes.sharey is the supported way to link
## the y-axes of already-created Axes
for twin in bxs[1:]:
    twin.sharey(bxs[0])
bxs[0].autoscale(axis='y')
    

fig.savefig("../data/out/figures/pf-stats-panel-sars_model.pdf",
            bbox_inches="tight", dpi=300)

### SMC Diagnostics
Conditional likelihood and effective number of particles

In [None]:
sel_parnames = parnames

fig, axs = plt.subplots(len(sel_parnames)+1, 1, 
                        figsize=(14,2*len(sel_parnames)), 
                        sharex=True)

## FIXME: plot a single run for the evolution of a parameter

for ID in pf_data["pfIDs"]:
    paramss = fit_dicts[ID]["params"]
    for i, pn in enumerate(sel_parnames):
        ax = axs[i]
        meds = paramss[pn]["median"]
        rans = paramss[pn]["range"]
        ms = range(len(meds))
        evplot.range_plots(ax, ms, *aux.unzip(rans), dt=0.1, zorder=1)
        ax.scatter(ms, meds, color='r', marker='_', zorder=2) 
        ax.set_ylabel(pn)
    
    
## plot log-likelihood

ll_dicts = [xs.find("log_lik").attrib for xs in pf_data["iterf_steps"]]
ll_vals = [float(d["val"]) for d in ll_dicts]
ll_colors = ['k' if d["finite"] == 'true' else 'red' for d in ll_dicts]
ms = range(len(ll_dicts))
axs[-1].scatter(ms, ll_vals, color=ll_colors)
axs[-1].set_ylabel("LL")
axs[-1].set_xlabel("iteration")
#axs[-1].set_ylim(-500, -300)

from scipy.interpolate import UnivariateSpline

if len(ms) > 2:
    ## np.infty was removed in NumPy 2.0; np.inf is the same value
    lb = -np.inf
    ## filter out some mistakes: keep only iterations whose log-likelihood exceeds lb
    ## (NaN is dropped as well, since the comparison is False)
    kept = [(m, ll) for m, ll in zip(ms, ll_vals) if ll > lb]
    fms = [m for m, _ in kept]
    flls = [ll for _, ll in kept]
    ## UnivariateSpline fits a cubic (k=3) by default and needs more than k data points,
    ## so skip the smoothing curve when too few usable iterations remain
    if len(fms) > 3:
        cs = UnivariateSpline(fms, flls, s=1e4)
        xs = np.linspace(ms[0], ms[-1], 1000)
        axs[-1].plot(xs, cs(xs), label='spline', color='red', linewidth=2)

print("final LL:", ll_vals[-1])

fig.savefig("../data/out/figures/traces-panel-sars_model.pdf", 
            bbox_inches="tight", dpi=300)

In [None]:
## Print some parameter estimates...

par_ests = {
    ID: {
        k : v["median"][-1]
        for k, v in fit_dicts[ID]["params"].items()
    } 
    for ID in pf_data["pfIDs"]
}

for ID in par_ests.keys():
    for k, v in par_ests[ID].items():
        print(f"{ID} -> {k}: {v:0.3g}")
    print('-------')

### Get estimate of absolute prevalence of mutant at a particular time

In [None]:
fprevs = [] ## to be filled below...

target_time = (date_close-day0).days

for path in pf_data["paths"][ID]:
    ## extract timeseries
    xs = path.findall("state")
    ts = np.array([float(x.attrib["t"]) for x in xs])
    idx = np.argmin(np.abs(ts - target_time))
    Im = float(xs[idx].find(f"var_vec[@name='Im']/var").attrib["val"])
    Em = float(xs[idx].find(f"var_vec[@name='Em']/var").attrib["val"])
    fprevs.append(Im + Em)

print(np.mean(fprevs))
np.percentile(fprevs, [0,100])

### Approximate cases and mutant frequency with a piecewise linear function

In [None]:
## get mutant frequency data

fig, ax = plt.subplots(1, 1, figsize=(7,5))

fpaths = [] ## to be filled below...
ftimes = [float(x.attrib["t"]) for x in pf_data["paths"][ID][0].findall("state")]

for path in pf_data["paths"][ID]:
    ## extract timeseries
    xs = path.findall("state")
    ts = [float(x.attrib["t"]) for x in xs]
    Xs = [float(x.find(f"var_vec[@name='Fm']/var").attrib["val"]) for x in xs]
    fpaths.append(Xs)
    ## plot trajectories
    ax.plot(ts, Xs, color='tab:blue', alpha=0.2)
    
mean_fpath = np.mean(fpaths, axis=0)
ax.plot(ftimes, mean_fpath, color='k')

## reduce number of points

print("number of time points:", len(ftimes))

stride = 15

ftimes_red = ftimes[::stride] + [ftimes[-1]]
mean_fpath_red = list(mean_fpath[::stride]) + [mean_fpath[-1]]

ax.plot(ftimes_red, mean_fpath_red, color='red')

print("number of break points:", len(ftimes_red))

## write values to file

with open("../Fm_lin_fun.txt", 'w') as f:
    f.write(' '.join(map(str, ftimes_red)) + '\n')
    f.write(' '.join(map(str, mean_fpath_red)) + '\n')

In [None]:
## get relative prevalence data

fig, ax = plt.subplots(1, 1, figsize=(7,5))

fpaths = [] ## to be filled below...
ftimes = [float(x.attrib["t"]) for x in pf_data["paths"][ID][0].findall("state")]

for path in pf_data["paths"][ID]:
    ## extract timeseries
    xs = path.findall("state")
    ts = [float(x.attrib["t"]) for x in xs]
    varnames = ["S", "I", "E", "H", "R"]
    Xss = [[float(x.find(f"var_vec[@name='{X}']/var").attrib["val"]) for x in xs]
           for X in varnames]
    Ns = np.sum(Xss, axis=0)
    Is = Xss[1]
    relIs = [I/N for I, N in zip(Is, Ns)]
    fpaths.append(relIs)
    ## plot trajectories
    ax.plot(ts, relIs, color='tab:blue', alpha=0.2)
    
mean_fpath = np.mean(fpaths, axis=0)
ax.plot(ftimes, mean_fpath, color='k')

## reduce number of points

print("number of time points:", len(ftimes))

stride = 15

ftimes_red = ftimes[::stride] + [ftimes[-1]]
mean_fpath_red = list(mean_fpath[::stride]) + [mean_fpath[-1]]

ax.plot(ftimes_red, mean_fpath_red, color='red')

print("number of break points:", len(ftimes_red))

## write values to file

with open("../I_lin_fun.txt", 'w') as f:
    f.write(' '.join(map(str, ftimes_red)) + '\n')
    f.write(' '.join(map(str, mean_fpath_red)) + '\n')
    
    
## TODO: write both timeseries to one file!

# $R_e$: the effective reproduction number

In [None]:
def average_Re(beta, gamma, nu, s, Fm, S, N):
    """Return beta/(gamma+nu) * S/N * ((1-Fm) + Fm*(1+s)).

    The last factor weights a fraction ``Fm`` by ``1 + s`` and the remaining
    fraction ``1 - Fm`` by 1; with ``s = 0`` it equals 1.
    NOTE(review): presumably beta is the infection rate, gamma + nu the total
    removal rate, S/N the susceptible fraction, Fm the mutant frequency and s
    the mutant's transmission advantage -- confirm against the model definition.
    """
    rate_ratio = beta / (gamma + nu)
    depleted = rate_ratio * S / N
    variant_weight = (1 - Fm) + Fm * (1 + s)
    return depleted * variant_weight

def Hv(t, u):
    """Smooth step: the logistic function expit(t/u), equal to 0.5 at t = 0."""
    scaled = t / u
    return scipy.special.expit(scaled)

def betat(t, bvec, tvec, uvec):
    """Evaluate a smoothed piecewise-constant function of time at ``t``.

    Level ``bvec[i]`` is multiplied by a factor that switches on around
    ``tvec[i-1]`` (fixed to 1 for the first level) and by a factor that
    switches off around ``tvec[i]`` (fixed to 1 for the last level). The
    switches are logistic steps ``Hv`` with widths ``uvec``; the weighted
    levels are summed with ``np.sum``.
    """
    ## one logistic gate per break point, evaluated once and used for both factors
    gates = [Hv(t - t_break, uvec[k]) for k, t_break in enumerate(tvec)]
    switched_on = [1] + gates
    not_yet_off = [1 - g for g in gates] + [1]
    terms = [b * on * off for b, on, off in zip(bvec, switched_on, not_yet_off)]
    return np.sum(terms)


## define parameters
ID = list(par_ests.keys())[0]

num_breakpoints = 3

bvec = [par_ests[ID][f"beta{i}"] for i in range(num_breakpoints+1)]
tvec = [par_ests[ID][f"t{i}"] for i in range(1,num_breakpoints+1)]
s = par_ests[ID]["sigma"]
## FIXED parameters
uvec = [1.189, 1.189, 1.189]
gamma = 0.25
nu = 0.005

fig, (bx, ax) = plt.subplots(2, 1, figsize=(7,5), sharex=True)

fpaths_av = [] ## to be filled below...
fpaths = [] ## to be filled below...
ftimes = [float(x.attrib["t"]) for x in pf_data["paths"][ID][0].findall("state")]

for path in pf_data["paths"][ID]:
    ## extract timeseries
    xs = path.findall("state")
    ts = [float(x.attrib["t"]) for x in xs]
    varnames = ["S", "Ew", "Em", "Iw", "Im", "H", "R"]
    Xss = [[float(x.find(f"var_vec[@name='{X}']/var").attrib["val"]) for x in xs]
           for X in varnames]
    Ns = np.sum(Xss, axis=0)
    Iws, Ims = np.array(Xss[3]), np.array(Xss[4]) ## TODO: add Em and Ew to the equation?
    Fms = Ims / (Iws + Ims)
    Ss = Xss[0]
    bts = [betat(t, bvec, tvec, uvec) for t in ts]
    avRe = [average_Re(bt, gamma, nu, s, Fm, S, N) for (S, Fm, N, bt) in zip(Ss, Fms, Ns, bts)]
    Re = [average_Re(bt, gamma, nu, 0, Fm, S, N) for (S, Fm, N, bt) in zip(Ss, Fms, Ns, bts)]
    fpaths.append(Re)
    fpaths_av.append(avRe)
    ## plot trajectories
    
mean_fpath = np.mean(fpaths, axis=0)
ci_fpath = np.percentile(fpaths, axis=0, q=[2.5,97.5])
ax.plot(ftimes, mean_fpath, color='k')
ax.fill_between(ftimes, *ci_fpath, linewidth=0, color='k', alpha=0.3)

mean_fpath_av = np.mean(fpaths_av, axis=0)
ci_fpath_av = np.percentile(fpaths_av, axis=0, q=[2.5,97.5])
ax.plot(ftimes, mean_fpath_av, color='r')
ax.fill_between(ftimes, *ci_fpath_av, linewidth=0, color='r', alpha=0.3)

ax.axhline(y=1, color='k', linestyle='--')

## add dates to x-axis

xmin, xmax = int(np.min(ftimes)), int(np.max(ftimes))

xmin += 7 ## correction...
dt = 14

xticks = [xmin + dt * i for i in range(0,(xmax-xmin)//dt+2)]

xdates = [day0 + datetime.timedelta(days=d) for d in xticks]
xticklabels = [d.strftime("%b %d") for d in xdates]

ax.set_xticks(xticks)
ax.set_xticklabels(xticklabels, fontsize='x-small', rotation=45, ha='right')

ax.set_ylabel("$R_e$")
ax.set_xlabel("date")

## plot the time-varying infection rate beta(t)

betats = [betat(t, bvec, tvec, uvec) for t in ftimes]
bx.plot(ftimes, betats, color='k', linewidth=2, zorder=2)

def hybrid_trans_x(ax):
    """Return a blended transform for ``ax``: x in data coordinates, y in axes coordinates."""
    x_transform, y_transform = ax.transData, ax.transAxes
    return blended_transform_factory(x_transform, y_transform)

def hybrid_trans_y(ax):
    """Return a blended transform for ``ax``: x in axes coordinates, y in data coordinates."""
    x_transform, y_transform = ax.transAxes, ax.transData
    return blended_transform_factory(x_transform, y_transform)

W = 7

for i, t in enumerate(tvec):
    bx.axvspan(t-W/2, t+W/2, zorder=1, color='k', alpha=0.3, linewidth=0)
    bx.text(t, 1, f"$t_{i+1}$", transform=hybrid_trans_x(bx), ha='center', va='bottom')

for i, b in enumerate(bvec):
    ## label each level at the right edge of the panel (x in axes coordinates, y in data coordinates)
    bx.text(1.01, b, f"$\\beta_{i}$", transform=hybrid_trans_y(bx), va='center', ha='left')
    ## only levels that have a break point with the same index get a dashed guide line
    if i < len(tvec):
        ## the former `... if i < len(tvec) else max(ftimes)` fallback could never be
        ## reached inside this branch and has been removed
        xmin = tvec[i] - W
        xmax = max(ftimes)
        bx.plot([xmin, xmax], [b,b], linestyle='--', linewidth=0.5, color='k', zorder=2)
    
bx.set_ylabel("$\\beta$")
bx.set_ylim(-0.025, 0.425)

dx_num = -0.15

ax.text(dx_num, 1.0, "B", fontsize=24, transform=ax.transAxes)
bx.text(dx_num, 1.0, "A", fontsize=24, transform=bx.transAxes)

fig.savefig("../data/out/figures/B117_UK_Re.pdf", bbox_inches='tight')

In [None]:
plt.axhline?