# Compute the effective reproduction number

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import re
import scipy.stats as sts
import xml.etree.ElementTree as ET
import warnings
import pickle
import copy
import csv
import datetime
import json
import scipy.special

from matplotlib.patches import Patch
from matplotlib.lines import Line2D
import matplotlib.ticker as ticker
from matplotlib.transforms import blended_transform_factory


from mpl_toolkits.axes_grid1.inset_locator import inset_axes

import sys, importlib
sys.path.append("..")
from evpytools import evplot
from evpytools import pftools
from evpytools import auxiliary as aux
from evpytools import definitions as defn
for mod in [evplot, pftools, aux, defn]:
    importlib.reload(mod)

In [None]:
## use a larger default font size for every figure in this notebook
plt.rcParams["font.size"] = 18

### Import SMC results

In [None]:
## which fit to analyse: region and variant label, both part of the output file name
region = "United_Kingdom"
variant = "B.1.1.7"

## SMC output for this region/variant, parsed by pftools.extract_pfilter_data
pfout_file = "../data/out/ipf_result-sars_model_{}_{}.xml".format(region, variant)
pf_data = pftools.extract_pfilter_data(pfout_file)

In [None]:
## extract parameter estimates
##
## fit_dicts[ID]["params"] holds the per-parameter summary built by
## pftools.makeParDict for the r-th entry of pf_data["pfIDs"]; par_ests[ID]
## keeps only the last element of each parameter's "median" sequence
## (presumably the estimate from the final iteration -- confirm).

parnames = pf_data["parnames"]
fit_dicts = {}

for r, ID in enumerate(pf_data["pfIDs"]):
    fit_dicts[ID] = dict(
        params=pftools.makeParDict(r, parnames, pf_data["param_medians"], pf_data["param_ranges"]),
    )

par_ests = {}
for ID in pf_data["pfIDs"]:
    par_ests[ID] = {k: v["median"][-1] for k, v in fit_dicts[ID]["params"].items()}

In [None]:
## reference date: model time t is converted to calendar dates relative to 1 January 2020
day0 = datetime.datetime(2020, 1, 1)

## day number (relative to day0) of the border closing on 21 December 2020
date_close = datetime.datetime(2020, 12, 21)
days_close = (date_close - day0).days

print("day number border closing:", days_close)

In [None]:
## find prevalence (Im + Em) at the border-closing day
##
## For every sampled path, take the recorded state whose time is closest to
## the closing day and store the sum of its Im and Em variables.

## reuse the day number computed above instead of recomputing it, so the two can't drift apart
target_time = days_close

fprevs = [] ## to be filled below...

for path in pf_data["paths"][region]:
    ## extract timeseries
    xs = path.findall("state")
    ts = np.array([float(x.attrib["t"]) for x in xs])
    ## nearest recorded state; NOTE(review): no tolerance check on |t - target_time|
    idx = np.argmin(np.abs(ts - target_time))
    state = xs[idx]
    Im = float(state.find("var_vec[@name='Im']/var").attrib["val"])
    Em = float(state.find("var_vec[@name='Em']/var").attrib["val"])
    fprevs.append(Im + Em)

print(np.mean(fprevs))
## range over paths (displayed as the cell output)
np.percentile(fprevs, [0,100])

In [None]:
## get mutant frequency data
##
## Plot the Fm trajectory of every sampled path (blue), their pointwise mean
## (black), and a thinned copy of the mean (red): every `stride`-th point plus
## the final point. The thinned series is written to file further below.

fig, ax = plt.subplots(1, 1, figsize=(7,5))

fpaths = [] ## to be filled below...
## time grid of the first path; the pointwise mean below assumes every path
## has the same number of states on this grid (grid values are not checked)
ftimes = [float(x.attrib["t"]) for x in pf_data["paths"][region][0].findall("state")]

for path in pf_data["paths"][region]:
    ## extract timeseries
    xs = path.findall("state")
    ts = [float(x.attrib["t"]) for x in xs]
    Xs = [float(x.find("var_vec[@name='Fm']/var").attrib["val"]) for x in xs]
    fpaths.append(Xs)
    ## plot trajectories
    ax.plot(ts, Xs, color='tab:blue', alpha=0.2)

mean_fpath = np.mean(fpaths, axis=0)
ax.plot(ftimes, mean_fpath, color='k')

## reduce number of points

print("number of time points:", len(ftimes))

stride = 5

ftimes_red_mfr = ftimes[::stride]
mean_fpath_red_mfr = list(mean_fpath[::stride])
## always keep the final time point as a break point, but only append it when
## the strided slice did not already end on it -- otherwise it appeared twice
if (len(ftimes) - 1) % stride != 0:
    ftimes_red_mfr.append(ftimes[-1])
    mean_fpath_red_mfr.append(mean_fpath[-1])

ax.plot(ftimes_red_mfr, mean_fpath_red_mfr, color='red')

print("number of break points:", len(ftimes_red_mfr))

In [None]:
## get relative prevalence data
##
## For every sampled path compute I / N with N = S + I + E + H + R, plot the
## trajectories (blue) and their pointwise mean (black). Then build a thinned
## copy of the mean (every `stride`-th point plus the final point) that is set
## to zero from the border-closing day onwards (red).

fig, ax = plt.subplots(1, 1, figsize=(7,5))

fpaths = [] ## to be filled below...
## time grid of the first path; the pointwise mean assumes all paths share it
ftimes = [float(x.attrib["t"]) for x in pf_data["paths"][region][0].findall("state")]

## compartments summed to obtain the population size N (loop invariant)
varnames = ["S", "I", "E", "H", "R"]

for path in pf_data["paths"][region]:
    ## extract timeseries
    xs = path.findall("state")
    ts = [float(x.attrib["t"]) for x in xs]
    Xss = [[float(x.find(f"var_vec[@name='{X}']/var").attrib["val"]) for x in xs]
           for X in varnames]
    Ns = np.sum(Xss, axis=0)
    Is = Xss[varnames.index("I")]
    relIs = [I/N for I, N in zip(Is, Ns)]
    fpaths.append(relIs)
    ## plot trajectories
    ax.plot(ts, relIs, color='tab:blue', alpha=0.2)

mean_fpath = np.mean(fpaths, axis=0)
ax.plot(ftimes, mean_fpath, color='k')

## reduce number of points

print("number of time points:", len(ftimes))

stride = 5

ftimes_red_prev = ftimes[::stride]
mean_fpath_red_prev = list(mean_fpath[::stride])
## keep the final time point as a break point, but never twice
if (len(ftimes) - 1) % stride != 0:
    ftimes_red_prev.append(ftimes[-1])
    mean_fpath_red_prev.append(mean_fpath[-1])

print("number of break points:", len(ftimes_red_prev))

## implement a closing date: relative prevalence is zero from days_close onwards

mean_fpath_red_prev = [x if t < days_close else 0.0
                       for t, x in zip(ftimes_red_prev, mean_fpath_red_prev)]

ax.plot(ftimes_red_prev, mean_fpath_red_prev, color='tab:red', zorder=2)

ax.axvline(x=days_close, color='k', zorder=1)

In [None]:
## write data to file
##
## Four lines of space-separated values, in this order:
##   break-point times / thinned mean mutant frequency,
##   break-point times / thinned mean relative prevalence.

rows = [
    ftimes_red_mfr, mean_fpath_red_mfr,    ## mut freq
    ftimes_red_prev, mean_fpath_red_prev,  ## rel prevalence
]
with open(f"../external_foi_{region}_{variant}.txt", 'w') as f:
    for row in rows:
        ## print() applies str() to each value and joins them with ' ', ending in '\n'
        print(*row, file=f)


### Compute Re and make a figure

In [None]:
def average_Re(beta, gamma, nu, s, Fm, S, N):
    """Return the effective reproduction number averaged over both strains.

    The base value beta/(gamma+nu) * S/N is multiplied by a mixture weight.
    In that weight, a fraction ``Fm`` of infections has relative
    transmissibility ``1+s`` and the remaining ``1-Fm`` has relative
    transmissibility 1. With ``s = 0`` the weight reduces to 1 (up to
    rounding), giving the single-strain value.
    """
    basic = beta / (gamma + nu)
    weight = (1 - Fm) + Fm * (1 + s)
    return basic * S / N * weight

def Hv(t, u):
    """Smooth (logistic) step function ``expit(t/u)``.

    Rises from 0 (t << 0) to 1 (t >> 0) and equals 1/2 at t = 0; ``u`` sets
    the width of the transition.
    """
    return scipy.special.expit(t/u)

def betat(t, bvec, tvec, uvec):
    """Return the time-varying infection rate beta(t).

    beta(t) is a smoothed piecewise-constant function. It is ``bvec[0]``
    before the first breakpoint ``tvec[0]``, ``bvec[i]`` between ``tvec[i-1]``
    and ``tvec[i]``, and ``bvec[-1]`` after the last breakpoint. Each switch
    is smoothed with ``Hv`` using the matching width ``uvec[i]``.

    Args:
        t: time; a scalar, or a numpy array evaluated elementwise.
        bvec: the n+1 rate levels.
        tvec: the n breakpoint times.
        uvec: at least n transition widths (the first n are used).

    Returns:
        beta(t), with the same shape as ``t``.

    Raises:
        ValueError: if ``len(bvec) != len(tvec) + 1`` or ``uvec`` has fewer
            than ``len(tvec)`` entries.
    """
    n = len(tvec)
    ## zip() below would silently drop levels on a length mismatch
    if len(bvec) != n + 1:
        raise ValueError(f"expected {n + 1} beta values for {n} breakpoints, got {len(bvec)}")
    if len(uvec) < n:
        raise ValueError(f"expected at least {n} transition widths, got {len(uvec)}")
    ## switch-on factor of each level (level 0 is on from the start) ...
    ls = [1] + [Hv(t-ti, uvec[i]) for i, ti in enumerate(tvec)]
    ## ... and its switch-off factor (the last level never switches off)
    rs = [(1-Hv(t-ti, uvec[i])) for i, ti in enumerate(tvec)] + [1]
    ## sum over levels only (axis=0) so an array `t` is not collapsed to a scalar
    return np.sum([b*l*r for b, l, r in zip(bvec, ls, rs)], axis=0)


## define parameters
## (point estimates of the first pfilter ID in par_ests)
ID = list(par_ests.keys())[0]

num_breakpoints = 3

## rate levels beta0..beta3 and breakpoint times t1..t3 taken from the fit
bvec = [par_ests[ID][f"beta{i}"] for i in range(num_breakpoints+1)]
tvec = [par_ests[ID][f"t{i}"] for i in range(1,num_breakpoints+1)]
## fitted "sigma"; used as s in average_Re, i.e. mutant transmissibility factor 1+s
s = par_ests[ID]["sigma"]
## FIXED parameters
## (not estimated here; NOTE(review): presumably these must match the values
## used in the fitted model -- confirm)
## uvec: transition widths of the smoothed steps in betat;
## gamma, nu: enter the reproduction number as beta/(gamma+nu)
uvec = [1.189, 1.189, 1.189]
gamma = 0.25
nu = 0.005

## two stacked panels sharing the time axis: bx = beta(t) (top), ax = Re (bottom)
fig, (bx, ax) = plt.subplots(2, 1, figsize=(7,5), sharex=True)

fpaths_av = [] ## to be filled below...
fpaths = [] ## to be filled below...
## time grid of the first path of this ID; the pointwise statistics below
## assume every path has the same number of states on this grid
ftimes = [float(x.attrib["t"]) for x in pf_data["paths"][ID][0].findall("state")]

for path in pf_data["paths"][ID]:
    ## extract timeseries
    xs = path.findall("state")
    ts = [float(x.attrib["t"]) for x in xs]
    varnames = ["S", "Ew", "Em", "Iw", "Im", "H", "R"]
    Xss = [[float(x.find(f"var_vec[@name='{X}']/var").attrib["val"]) for x in xs]
           for X in varnames]
    ## population size: sum over all compartments listed in varnames
    Ns = np.sum(Xss, axis=0)
    Iws, Ims = np.array(Xss[3]), np.array(Xss[4]) ## TODO: add Em and Ew to the equation?
    ## mutant fraction among Iw + Im; NOTE(review): 0/0 yields NaN (with a
    ## RuntimeWarning) at any time where Iw + Im == 0
    Fms = Ims / (Iws + Ims)
    Ss = Xss[0]
    bts = [betat(t, bvec, tvec, uvec) for t in ts]
    ## avRe: reproduction number averaged over both strains (uses s);
    ## Re: same formula with s = 0, so the mixture weight is 1 (single-strain value)
    avRe = [average_Re(bt, gamma, nu, s, Fm, S, N) for (S, Fm, N, bt) in zip(Ss, Fms, Ns, bts)]
    Re = [average_Re(bt, gamma, nu, 0, Fm, S, N) for (S, Fm, N, bt) in zip(Ss, Fms, Ns, bts)]
    fpaths.append(Re)
    fpaths_av.append(avRe)
    ## plot trajectories
    
## pointwise mean and 2.5-97.5 percentile band over paths: s = 0 version (black) ...
mean_fpath = np.mean(fpaths, axis=0)
ci_fpath = np.percentile(fpaths, axis=0, q=[2.5,97.5])
ax.plot(ftimes, mean_fpath, color='k')
ax.fill_between(ftimes, *ci_fpath, linewidth=0, color='k', alpha=0.3)

## ... and the strain-averaged version (red)
mean_fpath_av = np.mean(fpaths_av, axis=0)
ci_fpath_av = np.percentile(fpaths_av, axis=0, q=[2.5,97.5])
ax.plot(ftimes, mean_fpath_av, color='r')
ax.fill_between(ftimes, *ci_fpath_av, linewidth=0, color='r', alpha=0.3)

## reference line at Re = 1
ax.axhline(y=1, color='k', linestyle='--')

## add dates to x-axis

xmin, xmax = int(np.min(ftimes)), int(np.max(ftimes))

xmin += 7 ## correction...
dt = 14

## one tick every dt days from xmin; the "+2" yields one tick beyond xmax
## (NOTE(review): set_xticks may widen the view to include it -- confirm intended)
xticks = [xmin + dt * i for i in range(0,(xmax-xmin)//dt+2)]

## day numbers (relative to day0) -> abbreviated month + day labels
xdates = [day0 + datetime.timedelta(days=d) for d in xticks]
xticklabels = [d.strftime("%b %d") for d in xdates]

ax.set_xticks(xticks)
ax.set_xticklabels(xticklabels, fontsize='x-small', rotation=45, ha='right')

ax.set_ylabel("$R_e$")
ax.set_xlabel("date")

## plot the time-varying infection rate beta(t)

betats = [betat(t, bvec, tvec, uvec) for t in ftimes]
bx.plot(ftimes, betats, color='k', linewidth=2, zorder=2)

def hybrid_trans_x(ax):
    """Return a blended transform: x in data coordinates, y in axes coordinates of ``ax``."""
    data_x = ax.transData
    axes_y = ax.transAxes
    return blended_transform_factory(data_x, axes_y)

def hybrid_trans_y(ax):
    """Return a blended transform: x in axes coordinates, y in data coordinates of ``ax``."""
    axes_x, data_y = ax.transAxes, ax.transData
    return blended_transform_factory(axes_x, data_y)

W = 7  ## width (in days) of the shaded band drawn around each breakpoint

## shade a band around every breakpoint t_i and label it just above the top panel
for i, t in enumerate(tvec):
    bx.axvspan(t-W/2, t+W/2, zorder=1, color='k', alpha=0.3, linewidth=0)
    ## braced subscript so that multi-digit indices render correctly in mathtext
    bx.text(t, 1, f"$t_{{{i+1}}}$", transform=hybrid_trans_x(bx), ha='center', va='bottom')

## label every level beta_i at the right edge of the top panel. Levels that end
## at a breakpoint (all but the last) also get a dashed guide line, running from
## W days before that breakpoint to the end of the data. Distinct names are used
## so the tick-range variables xmin/xmax of the previous part are not overwritten.
t_last = max(ftimes)
for i, b in enumerate(bvec):
    bx.text(1.01, b, f"$\\beta_{{{i}}}$", transform=hybrid_trans_y(bx), va='center', ha='left')
    if i < len(tvec):
        bx.plot([tvec[i]-W, t_last], [b, b], linestyle='--', linewidth=0.5, color='k', zorder=2)

bx.set_ylabel("$\\beta$")
bx.set_ylim(-0.025, 0.425)

dx_num = -0.15

## panel letters
ax.text(dx_num, 1.0, "B", fontsize=24, transform=ax.transAxes)
bx.text(dx_num, 1.0, "A", fontsize=24, transform=bx.transAxes)

#fig.savefig("../data/out/figures/B117_UK_Re.pdf", bbox_inches='tight')