In [None]:
import numpy as np
import xarray as xr
import pickle
import pandas as pd
import time
import os

from collections import OrderedDict

import datetime
from sys import getsizeof,path

import matplotlib.pyplot as plt
import matplotlib

path.append("../src")
from Splines import CentripetalCatmullRomSpline_splitControls,Spline
import Cases
from Population import ImportPopulation
from ModelParams import ObservedData,ModelParam

import pymc3 as pm
import theano
import theano.tensor as tt


theano.config.gcc_cxxflags = "-Wno-c++11-narrowing"

import arviz as az

# Simple Example - recover control Params

In [None]:
# Simple example: generate noisy samples from a spline with known control
# values, then try to recover those control values with a PyMC3 model.
old = False

# NOTE: the 4th positional parameter of np.linspace is `endpoint`, not `dtype`,
# so the dtype has to be passed by keyword (the generated grid is unchanged).
coords = {"cpx":np.array([1,2,3,4,5,6,7],"float64"),"space":np.linspace(2.01,6.99,128,dtype="float64")}
test_control = np.array([1,3,2,1,2,2,3],"float64")
test_spline = Spline(coords["cpx"],tt.cast(test_control,"float64"))

# Synthetic observations: spline evaluated on the "space" grid plus Gaussian
# noise (std 0.1). No random seed is set, so the noise differs between runs.
t1 = time.time()
test_values = test_spline.EvaluateAt(coords["space"],old=old).eval()
test_values += np.random.normal(0, .1, coords["space"].shape[0])
t2 = time.time()
print("Test values in %.3fs"%(t2-t1))

plt.plot(coords["space"],test_values,color="tab:red")
plt.plot(coords["cpx"],test_control,"rx")

with pm.Model(coords=coords) as model:

    # Hyper-prior for the step size of the control-point walk.
    sigma_control = pm.Lognormal("sigma_control",mu=tt.log(2),sigma=1.)

    # Control values are modelled as exp(cumulative sum of Normal increments),
    # which keeps them strictly positive. `dims` is given as an explicit tuple.
    est_control = pm.Normal("est_control",mu=0,sigma=sigma_control,dims=("cpx",))
    est_control = tt.exp(tt.cumsum(est_control))

    est_spline = Spline(coords["cpx"],est_control)
    est_values = est_spline.EvaluateAt(coords["space"],old=old)

    pm.Deterministic("est_values",est_values)

    # Student-T likelihood whose scale grows with sqrt(|est_values + 1|).
    sigma_obs = pm.HalfCauchy(name="sigma_obs",beta=2)
    pm.StudentT("est",nu=4,sigma=tt.abs_(est_values+1.)**.5 * sigma_obs,mu=est_values,observed=test_values)

    trace = pm.sample(init="advi",return_inferencedata=True,tune=200,draws=200,cores=4,chains=4,max_treedepth=12,target_accept=.95)

# t2 was taken before the plotting above, so this interval also includes the
# two plt.plot calls and the model construction.
t3 = time.time()
print("Compiled and Sampled in %.3f"%(t3-t2))

# Posterior median and 5%/95% quantiles of the reconstructed spline values.
q = trace.posterior["est_values"].quantile((0.5,.05,.95),["chain","draw"])
plt.plot(coords["space"],q[0],color="tab:blue")
plt.fill_between(coords["space"],q[1],q[2],alpha=.1,color="tab:blue")

In [None]:
# Trace plot of the raw "est_control" variable (the Normal increments, before cumsum/exp).
az.plot_trace(trace.posterior["est_control"])

In [None]:
# Trace plot of "sigma_control", the scale of the "est_control" prior.
az.plot_trace(trace.posterior["sigma_control"])

In [None]:
# Trace plot of "sigma_obs", the factor scaling the Student-T observation noise.
az.plot_trace(trace.posterior["sigma_obs"])

# SEIR Model with R_eff from Spline

In [None]:
# Build synthetic R_eff test data: a spline through known control values,
# evaluated on the "space" coordinate, with Gaussian noise added.
coords = {}
if False:
    # Alternative numeric toy grid (disabled). The dtype must be passed by
    # keyword: the 4th positional parameter of np.linspace is `endpoint`.
    coords["cpx"] = np.linspace(0,70,8,dtype="float64")
    coords["space"] = np.linspace(8,51,44,dtype="float64")
else:
    # Calendar grid: semi-monthly control points, daily evaluation points.
    coords["cpx"] = pd.date_range(datetime.date(2020,1,30),datetime.date(2020,5,16),freq="SM")
    coords["space"] = pd.date_range(datetime.date(2020,2,20),datetime.date(2020,4,5),freq="D")

print(coords["cpx"],len(coords["cpx"]))
print(coords["space"],len(coords["space"]))

# Known control values the later model should recover (one per control point).
rtest_control = np.array([3,5,6,2,.9,.9,1.1,1.2],"float64")
rtest_spline = Spline(coords["cpx"],tt.cast(rtest_control,"float64"))


t1 = time.time()
# Keep the symbolic tensor as well; a later cell uses it via tt.zeros_like.
rtest_tvalues = rtest_spline.EvaluateAt(coords["space"])
rtest_values = rtest_tvalues.eval()
# Gaussian noise with std 0.1; no random seed is set, so runs differ.
rtest_values += np.random.normal(0, .1, coords["space"].shape[0])
t2 = time.time()
print("Test values in %.3fs"%(t2-t1))

print(rtest_values)

In [None]:
def tt_lognormal(x, mu, sigma):
    """Normalised log-normal shaped weights evaluated at the positions `x`.

    Parameters
    ----------
    x : tensor-like
        Evaluation positions. Clipped to [1e-9, 1e12] so that `1/x` and
        `log(x)` stay finite (e.g. for a leading 0).
    mu : tensor-like
        Location parameter in log-space (callers pass `log(median)`).
        Negative values are valid, so only its magnitude is bounded.
    sigma : tensor-like
        Scale parameter in log-space. Clipped to [1e-9, 1e12] to avoid a
        division by zero.

    Returns
    -------
    Tensor of the same shape as `x`: the unnormalised density divided by its
    sum over axis 0 (plus 1e-12 to guard against an all-zero sum).
    """
    # Limit to prevent NANs
    x = tt.clip(x,1e-9,1e12)
    sigma = tt.clip(sigma,1e-9,1e12)
    # mu lives in log-space: a lower bound of 1e-9 would silently force the
    # median to be >= 1, so bound the magnitude symmetrically instead.
    mu = tt.clip(mu,-1e12,1e12)

    distr = 1/x * tt.exp( -( (tt.log(x) - mu) ** 2) / (2 * sigma ** 2))
    return distr / (tt.sum(distr, axis=0) + 1e-12)

def SEIR_model(N, imported_t,Reff_t, median_incubation,sigma_incubation,l=32):
    """Iterate a day-by-day infection recursion with `theano.scan`.

    Parameters
    ----------
    N : population size; cast to float64. Also the start value of the pool
        that is reduced by each day's new cases.
    imported_t : per-day imported cases (scanned over its first axis).
    Reff_t : per-day reproduction number (scanned over its first axis).
    median_incubation, sigma_incubation : parameters of the log-normal
        weighting kernel (`tt_lognormal` is called with log(median)).
    l : length of the kernel and of the history of past new cases.

    Returns
    -------
    The outputs of `theano.scan` for the three step results, in order:
    new cases per day, the history vector per day, the remaining pool per day.
    """
    population = tt.cast(N, 'float64')
    kernel = tt_lognormal(tt.arange(l), tt.log(median_incubation), sigma_incubation)

    # A disabled "dirty hack" used to replace the kernel by a one-hot vector
    # at the median to prevent NaNs; it is not needed with better priors.

    reff_seq = tt.as_tensor_variable(Reff_t)
    imported_seq = tt.as_tensor_variable(imported_t)

    def _advance_one_day(reff_today, imported_today, history, remaining, weights, total):
        # Fraction of the pool that is still available (an alternative of a
        # constant factor 1 was tried and disabled).
        available_fraction = remaining / total
        transmitted = tt.dot(history, weights) * reff_today * available_fraction
        cases = tt.clip(imported_today + transmitted, 0, total)

        # Shift the history by one day and store today's cases in slot 0.
        shifted = tt.roll(history, 1, 0)
        next_history = tt.set_subtensor(shifted[:1], cases, inplace=False)
        # Deplete the pool, never going below zero (an unclipped variant was
        # tried and disabled).
        next_remaining = tt.clip(remaining - cases, 0, remaining)
        return cases, next_history, next_remaining

    # First output has no recurrence; history starts at zeros, pool at N.
    scan_outputs, _scan_updates = theano.scan(
        fn=_advance_one_day,
        sequences=[reff_seq, imported_seq],
        outputs_info=[None, np.zeros(l), population],
        non_sequences=[kernel, population],
        profile=False,
    )

    return scan_outputs

In [None]:
# Imported cases: zero everywhere except the first five entries.
initial = tt.zeros_like(rtest_tvalues)
initial = tt.set_subtensor(initial[:5],tt.cast([6,6,4,2,1],"float64"))

# Run the SEIR recursion on the noisy synthetic R_eff values
# (median_incubation=6, sigma_incubation=.3).
# NOTE(review): 83e8 equals 8.3e9 - confirm this is the intended population
# size (83e6 may have been meant).
new,infected,E_t = SEIR_model(83e8,initial,rtest_values,6,.3)

# Evaluate the new-case series and add Gaussian noise (std 0.3) to obtain the
# synthetic observations used by the inference cell. No random seed is set.
new_infected = new.eval()
new_infected += np.random.normal(0, .3, coords["space"].shape[0])

print(rtest_values)
print(new_infected)

plt.plot(coords["space"],rtest_values)
plt.plot(coords["space"],new_infected)

In [None]:
# Recover the R_eff control values from the noisy synthetic new-case series.
# Relies on `coords`, `initial` and `new_infected` from the previous cells.
old = False
t2 = time.time()

with pm.Model(coords=coords) as model:
    
    # Hyper-prior for the step size of the control-point walk.
    sigma_control = pm.Lognormal("sigma_control",mu=tt.log(2),sigma=1.)
    
    # Control values = exp(cumulative sum of Normal increments), i.e. positive.
    est_control = pm.Normal("est_control",mu=0,sigma=sigma_control,dims=("cpx"))
    est_control = tt.exp(tt.cumsum(est_control))
    
    est_spline = Spline(coords["cpx"],est_control)
    est_values = est_spline.EvaluateAt(coords["space"],old=old)
    
    # Same fixed SEIR arguments as in the data-generating cell.
    est_new,est_infected,est_E_t = SEIR_model(83e8,initial,est_values,6,.3)
    
    # "est_control_sum" stores the transformed (cumsum + exp) control values,
    # whereas the "est_control" variable holds the raw increments.
    pm.Deterministic("est_control_sum",est_control)
    pm.Deterministic("est_values",est_values)
    pm.Deterministic("est_new",est_new)

    # Student-T likelihood whose scale grows with sqrt(|est_new + 1|).
    sigma_obs = pm.HalfCauchy(name="sigma_obs",beta=2)
    pm.StudentT("est",nu=4,sigma=tt.abs_(est_new+1.)**.5 * sigma_obs,mu=est_new,observed=new_infected)
    
    trace = pm.sample(init="advi",return_inferencedata=True,tune=200,draws=200,cores=4,chains=4,max_treedepth=12,target_accept=.95)

# Wall-clock time for model construction plus sampling.
t3 = time.time()
print("Compiled and Sampled in %.3f"%(t3-t2))

In [None]:
# Posterior median (q[0]) and 5%/95% quantiles (q[1], q[2]) of the transformed
# control values, pooled over chains and draws.
q = trace.posterior["est_control_sum"].quantile((0.5,.05,.95),["chain","draw"])

plt.plot(coords["cpx"],q[0])
plt.fill_between(coords["cpx"],q[1],q[2],alpha=.1)

print(q)
# Red crosses: the known control values used to generate the test data.
plt.plot(coords["cpx"],rtest_control,"rx")

In [None]:
# Trace plot of the raw "est_control" increments of the SEIR inference model.
az.plot_trace(trace.posterior["est_control"])

In [None]:
# Posterior median (q[0]) and 5%/95% quantiles (q[1], q[2]) of the modelled
# new cases, pooled over chains and draws.
q = trace.posterior["est_new"].quantile((0.5,.05,.95),["chain","draw"])

plt.plot(coords["space"],q[0])
plt.fill_between(coords["space"],q[1],q[2],alpha=.1)

print(q)
# Red crosses: the noisy synthetic observations the model was fitted to.
plt.plot(coords["space"],new_infected,"rx")

In [None]:
# Trace plot of "sigma_control", the scale of the "est_control" prior.
az.plot_trace(trace.posterior["sigma_control"])

In [None]:
# Trace plot of "sigma_obs", the factor scaling the Student-T observation noise.
az.plot_trace(trace.posterior["sigma_obs"])

# Hierarchical SurvStat
Run an age- and location-stratified SEIR model and compare it with the weekly reporting from SurvStat

- seasonality (12 monthly values) --> global R_0
- per week: walk representing gradual change in measures
- per BL: monthly diff (measures might be fast)
- per AG: monthly diff

100 weeks * 16 BL * 5 AG = 8k weekly values, 56k entries in R_eff-matrix (daily)

In [None]:
# Unfortunately not feasible, even with HalfCauchy instead of Lognormal for the hyperparameter.
# 4-10x more time is spent outside of sampling than inside once the number of days surpasses ~100. ~1h pre+post

In [None]:
# Set up coordinates and input data for the hierarchical model.
# t0 is read by the model cell to report the elapsed "parsing" time.
t0 = time.time()
old = False

start = datetime.datetime(2020,2,15)
end = datetime.datetime(2021,8,1)

coords = {}

# Named dimensions for the PyMC3 model.
coords["seasons_12month"] = range(1,13)
coords["days"] = pd.date_range(start,end,freq="D")
coords["weeks"] = pd.date_range(start,end,freq="W")
coords["months"] = pd.date_range(start,end,freq="M")
coords["BL"] = range(1,17)
coords["AG"] = np.array([0,20,40,60,80])

print(len(coords["days"]))

# NOTE(review): "210713" is presumably a data snapshot identifier - confirm
# against Cases.ObservedCasesFromSurvstat.
observed_cases = Cases.ObservedCasesFromSurvstat("210713")
observed_cases.RenameAxes({"age":"AG"})
# Population summed over "sex", column "31.12.2019", multiplied by 1000.
# NOTE(review): the factor suggests the source is given in thousands - confirm.
pop = ObservedData("population",ImportPopulation().sum("sex")["31.12.2019"] * 1000.)
pop.RenameAxes({"age":"AG"})

In [None]:
# Hierarchical R_eff model: seasonality x exp(weekly walk + age walk + BL walk).
# No variable with `observed=` is created in this cell, so as far as visible
# here pm.sample explores the priors only.
with pm.Model(coords=coords) as model:

    # Seasonality
    # 12 monthly control values, repeated twice to provide 24 spline knots on
    # the month-end dates between 2020-01-01 and 2022-01-01.
    season_sigma = pm.HalfCauchy("season_sigma",beta=1.)
    season_control = pm.Lognormal("season_control",mu=tt.log(1.),sigma=season_sigma,dims=("seasons_12month"))
    season_control24m = tt.concatenate([season_control,season_control])
    season_range24m = pd.date_range(start=datetime.datetime(2020,1,1),end=datetime.datetime(2022,1,1),freq="M")
    season_spline = Spline(season_range24m,season_control24m)
    
    season_daily = season_spline.EvaluateAt(coords["days"],old=old)
    
    # weekly walk
    # Cumulative sum of Normal increments, interpolated to daily values.
    measures_sigma = pm.HalfCauchy("measures_sigma",beta=1.)
    measures_value = pm.Normal("measures_value",mu=0.,sigma=measures_sigma,dims=("weeks"))
    measures_walk = tt.cumsum(measures_value,axis=0)
    measures_spline = Spline(coords["weeks"],measures_walk)
    
    measures_daily = measures_spline.EvaluateAt(coords["days"],old=old)
    pm.Deterministic("measures_daily",measures_daily)
    
    # Monthly Age-Diff
    # `if True:` is a manual on/off switch for this component.
    if True:
        age_sigma = pm.HalfCauchy("age_sigma",beta=1.)
        age_value = pm.Normal("age_value",mu=0.,sigma=age_sigma,dims=("AG","months",))
        age_walk = tt.cumsum(age_value,axis=1)
        age_spline = Spline(coords["months"],age_walk)
    
        # Swap the two axes and append a broadcastable axis
        # (presumably (AG, days) -> (days, AG, 1); confirm against Spline).
        age_daily = age_spline.EvaluateAt(coords["days"],old=old).dimshuffle(1,0,'x')
    else:
        age_daily = tt.cast(0.,"float64").reshape((1,1,))
    pm.Deterministic("age_daily",age_daily)
    
    # Monthly BL-Diff 
    # `if True:` is a manual on/off switch for this component.
    if True:
        BL_sigma = pm.HalfCauchy("BL_sigma",beta=1.)
        BL_value = pm.Normal("BL_value",mu=0.,sigma=BL_sigma,dims=("BL","months",))
        BL_walk = tt.cumsum(BL_value,axis=1)
        BL_spline = Spline(coords["months"],BL_walk)
        # Swap the two axes and insert a broadcastable axis in the middle
        # (presumably (BL, days) -> (days, 1, BL); confirm against Spline).
        BL_daily = BL_spline.EvaluateAt(coords["days"],old=old).dimshuffle(1,'x',0)
    else:
        BL_daily = tt.cast(0,"float64").reshape((1,1,))
    pm.Deterministic("BL_daily",BL_daily)
    
    # Sum of the three modifiers (broadcast against each other), applied
    # multiplicatively via exp() to the daily seasonality.
    modsum = measures_daily.dimshuffle(0,'x','x')+age_daily+BL_daily
    R_eff = season_daily.reshape((len(coords["days"]),1,1,))*tt.exp(modsum)
    pm.Deterministic("R_eff",R_eff)
    
    # Initial Cases
    # The first `initial_length` entries along the first axis are set to the
    # sampled magnitudes; everything else stays zero.
    initial_length = 14
    initial_mag = pm.Lognormal("initial",mu=tt.log(10.),sigma=1.,dims=("AG","BL"))
    initial = tt.zeros_like(R_eff)
    imported = tt.set_subtensor(initial[:initial_length],tt.stack([initial_mag]*initial_length))
    
    # population
    # NOTE(review): the key is named "week" but holds the daily coordinate -
    # confirm what ModelParam / Overlap expect here.
    imp_coord = OrderedDict()
    imp_coord["week"] = coords["days"]
    imp_coord["AG"] = coords["AG"]
    imp_coord["BL"] = coords["BL"]
    imported_mp = ModelParam("imported",imp_coord,imported)
    # Match population age-groups
    
    # Only the first return value is used below; `pop_coord` is not referenced
    # again in this cell.
    popAG,_,pop_coord = pop.Overlap(imported_mp,{"AG":"left"})
    
    
    pm.Deterministic("imported",imported)
    pm.Deterministic("population",popAG)
    
    # Dimensions = time x age x BL
  #  pm.Deterministic("sum",measures_daily+age_daily+BL_daily)
    
    # Time since t0 (set in the setup cell), i.e. setup plus model construction.
    t1 = time.time()
    print("parsing %.2f"%(t1-t0))
    #+adapt_diag
    trace = pm.sample(init="advi+adapt_diag",return_inferencedata=True,tune=400,draws=400,cores=4,chains=4,max_treedepth=12,target_accept=.95)
    t2 = time.time()
    # NOTE(review): the label reads "toal" (total), but t2-t1 covers only the
    # pm.sample call.
    print("toal %.1fs"%(t2-t1))

# Timing notes from earlier runs:
# advi 3.97, 103 total, season+measures
# advi+adapt_diag 1.62, 107.8 total, season+measures
# advi 3.75, 51s sampling 166s total, season+measures+age
# advi 1.92, 51s sampling 166s total, season+measures+age
# BL as single Catmull 66.1

In [None]:
# Trace plot of "season_sigma", the scale of the monthly seasonality controls.
az.plot_trace(trace.posterior["season_sigma"])

In [None]:
# Posterior median (q[0]) and 5%/95% quantiles (q[1], q[2]) of the daily
# measures walk, pooled over chains and draws.
q = trace.posterior["measures_daily"].quantile((0.5,.05,.95),["chain","draw"])

plt.plot(coords["days"],q[0])
plt.fill_between(coords["days"],q[1],q[2],alpha=.1)

In [None]:
# Quantiles of "age_daily" for the slice [..., 1, 0], pooled over chains/draws.
q = trace.posterior["age_daily"][...,1,0].quantile((0.5,.05,.95),["chain","draw"])

print(q.shape)

# First 10 draws of chain 0 (solid) and chain 1 (dashed).
# NOTE(review): these traces use the slice [..., 0, 0], while the quantiles
# above use [..., 1, 0] - confirm which index is intended.
for i in range(10):
    plt.plot(coords["days"],trace.posterior["age_daily"][0,i,:,0,0])
    plt.plot(coords["days"],trace.posterior["age_daily"][1,i,:,0,0],"--")

plt.plot(coords["days"],q[0])
plt.fill_between(coords["days"],q[1],q[2],alpha=.1)

In [None]:
# Quantiles of "R_eff" pooled over chains/draws, then the slice [..., 1, 1].
q = trace.posterior["R_eff"].quantile((0.5,.05,.95),["chain","draw"])[...,1,1]
print(q.shape)

# First 10 draws of chain 0 (solid) and chain 1 (dashed), log-scaled y axis.
# NOTE(review): these traces use the slice [..., 0, 0], while the quantiles
# above use [..., 1, 1] - confirm which indices are intended.
for i in range(10):
    plt.semilogy(coords["days"],trace.posterior["R_eff"][0,i,:,0,0])
    plt.semilogy(coords["days"],trace.posterior["R_eff"][1,i,:,0,0],"--")

# `ax` receives the return value of plt.plot and is not used afterwards.
ax = plt.plot(coords["days"],q[0])
#plt.fill_between(coords["days"],q[1],q[2],alpha=.1)

plt.ylim([.01,300])

In [None]:
# Quantiles of "BL_daily" for the slice [..., 0, 0], pooled over chains/draws.
q = trace.posterior["BL_daily"][...,0,0].quantile((0.5,.05,.95),["chain","draw"])

print(q.shape)

# First 10 draws of chain 0 (solid) and chain 1 (dashed) for the same slice.
for i in range(10):
    plt.plot(coords["days"],trace.posterior["BL_daily"][0,i,:,0,0])
    plt.plot(coords["days"],trace.posterior["BL_daily"][1,i,:,0,0],"--")

plt.plot(coords["days"],q[0])
plt.fill_between(coords["days"],q[1],q[2],alpha=.1)

In [None]:
# Print the posterior array shape of every stored deterministic quantity.
for k in ["measures_daily","age_daily","BL_daily","R_eff","imported","population"]:
    print(k,trace.posterior[k].shape)

In [None]:
# First draw of the first chain for "imported" (i) and "R_eff" (r).
i = trace.posterior["imported"][0,0]
r = trace.posterior["R_eff"][0,0]

print(i.shape)
print(r.shape)

In [None]:
# NOTE: this cell re-defines `tt_lognormal`, which is already defined in an
# earlier cell of this notebook; consider keeping a single definition.
def tt_lognormal(x, mu, sigma):
    """Normalised log-normal shaped weights evaluated at the positions `x`.

    Parameters
    ----------
    x : tensor-like
        Evaluation positions. Clipped to [1e-9, 1e12] so that `1/x` and
        `log(x)` stay finite (e.g. for a leading 0).
    mu : tensor-like
        Location parameter in log-space (callers pass `log(median)`).
        Negative values are valid, so only its magnitude is bounded.
    sigma : tensor-like
        Scale parameter in log-space. Clipped to [1e-9, 1e12] to avoid a
        division by zero.

    Returns
    -------
    Tensor of the same shape as `x`: the unnormalised density divided by its
    sum over axis 0 (plus 1e-12 to guard against an all-zero sum).
    """
    # Limit to prevent NANs
    x = tt.clip(x,1e-9,1e12)
    sigma = tt.clip(sigma,1e-9,1e12)
    # mu lives in log-space: a lower bound of 1e-9 would silently force the
    # median to be >= 1, so bound the magnitude symmetrically instead.
    mu = tt.clip(mu,-1e12,1e12)

    distr = 1/x * tt.exp( -( (tt.log(x) - mu) ** 2) / (2 * sigma ** 2))
    return distr / (tt.sum(distr, axis=0) + 1e-12)

# NOTE: this cell re-defines `SEIR_model`, which is already defined in an
# earlier cell of this notebook; consider keeping a single definition.
def SEIR_model(N, imported_t,Reff_t, median_incubation,sigma_incubation,l=32):
    """Iterate a day-by-day infection recursion with `theano.scan`.

    Parameters
    ----------
    N : population size; cast to float64. Also the start value of the pool
        that is reduced by each day's new cases.
    imported_t : per-day imported cases (scanned over its first axis).
    Reff_t : per-day reproduction number (scanned over its first axis).
    median_incubation, sigma_incubation : parameters of the log-normal
        weighting kernel (`tt_lognormal` is called with log(median)).
    l : length of the kernel and of the history of past new cases.

    Returns
    -------
    The outputs of `theano.scan` for the three step results, in order:
    new cases per day, the history vector per day, the remaining pool per day.
    """
    population = tt.cast(N, 'float64')
    kernel = tt_lognormal(tt.arange(l), tt.log(median_incubation), sigma_incubation)

    # A disabled "dirty hack" used to replace the kernel by a one-hot vector
    # at the median to prevent NaNs; it is not needed with better priors.

    reff_seq = tt.as_tensor_variable(Reff_t)
    imported_seq = tt.as_tensor_variable(imported_t)

    def _advance_one_day(reff_today, imported_today, history, remaining, weights, total):
        # Fraction of the pool that is still available (an alternative of a
        # constant factor 1 was tried and disabled).
        available_fraction = remaining / total
        transmitted = tt.dot(history, weights) * reff_today * available_fraction
        cases = tt.clip(imported_today + transmitted, 0, total)

        # Shift the history by one day and store today's cases in slot 0.
        shifted = tt.roll(history, 1, 0)
        next_history = tt.set_subtensor(shifted[:1], cases, inplace=False)
        # Deplete the pool, never going below zero (an unclipped variant was
        # tried and disabled).
        next_remaining = tt.clip(remaining - cases, 0, remaining)
        return cases, next_history, next_remaining

    # First output has no recurrence; history starts at zeros, pool at N.
    scan_outputs, _scan_updates = theano.scan(
        fn=_advance_one_day,
        sequences=[reff_seq, imported_seq],
        outputs_info=[None, np.zeros(l), population],
        non_sequences=[kernel, population],
        profile=False,
    )

    return scan_outputs


