In [1]:
import numpy as np
import theano
import theano.tensor as tt
import pymc3 as pm
import datetime
import pandas as pd

import time
import pickle

import matplotlib.pyplot as plt
import matplotlib
# Theano configuration: use the "fast_run" graph mode and float64 as the default
# float type (the symbolic scalars built below are also declared float64 explicitly).
theano.config.optimizer='fast_run'
theano.config.floatX = "float64"

# Legend-formating for matplotlib
def _format_k(prec):
    """
        format yaxis 10_000 as 10 k.
        _format_k(0)(1200, 1000.0) gives "1 k"
        _format_k(1)(1200, 1000.0) gives "1.2 k"
    """

    def inner(xval, tickpos):
        return f"${xval/1_000:.{prec}f}\,$k"

    return inner

In [2]:
def plain_SIR_model(beta,gamma,N,S0,I0,R0,l=150):
    """Simple classical SIR-Model with time-constant beta,gamma.

    Builds a symbolic Theano graph that iterates a discrete-time SIR recursion
    (one step per index of ``tt.arange(l)``) with ``theano.scan``:
        dS = -beta*S*I/N,  dI = beta*S*I/N - gamma*I,  dR = gamma*I

    Args:
        beta: transmission rate in the term beta*S*I/N (symbolic scalar).
        gamma: removal rate in the term gamma*I (symbolic scalar).
        N: population size dividing the infection term.
        S0, I0, R0: initial susceptible / infected / removed counts
            (R0 is the initial removed count, not a reproduction number).
        l: number of steps; may be a symbolic integer.

    Returns:
        The scan outputs in the order returned by ``next_day``:
        [S_t, I_t, R_t, dS_t, dI_t, dR_t, eff_t], one entry per step.
    """
    def next_day(t,St,It,Rt,oS,oI,oR,eff,beta,gamma,N):
        # Argument order is fixed by theano.scan: the sequence element (t), then
        # one argument per outputs_info entry (S, I, R, dS, dI, dR, eff of the
        # previous step), then the non_sequences (beta, gamma, N).
        # t, oS, oI, oR and the incoming eff are not used; eff is recomputed.
        eff = beta*St/N
        bSIoN = eff * It # beta * (S*I)/N, the number of new infections this step
        gI = gamma*It
        # Calculate differences
        dS = -bSIoN
        dI = bSIoN - gI
        dR = gI
        # Update values
        nxt_St = St + dS
        nxt_It = It + dI
        nxt_Rt = Rt + dR
        
        # Only S is clipped to be non-negative; dS is passed on unclipped.
        nxt_St = tt.clip(nxt_St,0,1e16) # Improves sampling stability
        
        # Pass on to next step
        return nxt_St,nxt_It,nxt_Rt, dS,dI,dR, eff
        
    # Initial internal state (the differences and eff start at zero)
    zero = tt.cast(0.,dtype='float64')
    initial_state = [S0,I0,R0,zero,zero,zero,zero]
    sequences = [tt.arange(l)]
    # Execute the 'Loop'
    state,_ = theano.scan(fn=next_day,
                            sequences=sequences,# time dependent sequences can be passed as first params
                            outputs_info=initial_state,
                            non_sequences=[beta,gamma,N])
    return state


def Gompertz(a,b,c,l=150,o=0):
    """Symbolic Gompertz curve a*exp(-b*exp(-c*t)) on t = o, o+1, ..., o+l-1.

    Args:
        a: value the curve approaches for large t when c > 0 (the final size).
        b: shape/shift parameter of the inner exponential.
        c: growth-rate parameter.
        l: number of time points; may be a symbolic integer.
        o: offset added to every time point (cast to float64).
    """
    shifted_days = tt.arange(l) + tt.cast(o, dtype="float64")
    inner_decay = tt.exp(-c * shifted_days)
    return a * tt.exp(-b * inner_decay)

# Compile plain-number callables for both models so that posterior samples can
# be pushed through them.
# SIR inputs: beta, gamma, N, I0 (float64) and the number of simulated days (int64).
b,g,n,i0 = tt.scalar(dtype="float64"),tt.scalar(dtype="float64"),tt.scalar(dtype="float64"),tt.scalar(dtype="float64")
i = tt.scalar(dtype="int64")
z = tt.constant(0.,dtype='float64')  # nobody is in the removed compartment at the start
# S0 = N: the whole population starts susceptible.
S_t,I_t,R_t,dS,dI,dR,eff = plain_SIR_model(b,g,n,n,i0,z,i)
# sir_f(beta, gamma, N, I0, days) -> [S_t, I_t, R_t, dS]
sir_f = theano.function(inputs=[b,g,n,i0,i],outputs=[S_t,I_t,R_t,dS])

# Gompertz inputs. This rebinds b and i; sir_f above is already compiled.
a,b,c,offset = tt.scalar(dtype="float64"),tt.scalar(dtype="float64"),tt.scalar(dtype="float64"),tt.scalar(dtype="float64")
i = tt.scalar(dtype="int64")
est = Gompertz(a,b,c,l=i,o=offset)
# gomp_f(a, b, c, days, offset) -> [curve]   (a one-element list)
gomp_f = theano.function(inputs=[a,b,c,i,offset],outputs=[est])


In [3]:
# Load the pickled posterior samples of both fits. Presumably one entry per
# estimation run, keyed by the number of data points used (keys are used as
# data[:k] below) — TODO confirm.
# NOTE(review): pickle.load can execute arbitrary code; only load files you created.
with open("estimates/sir_symptomatic_1598942335", "rb") as sir_file:
    sir_est = pickle.load(sir_file)

with open("estimates/gomp_symptomatic_1598946460", "rb") as gomp_file:
    gomp_est = pickle.load(gomp_file)

In [4]:
# Transpose the Dataset for the SIR-Fit:
# {run key -> {parameter -> posterior samples}} becomes
# {quantity -> {"mu" | percentile -> list with one value per run}}.
# NOTE: uses `data` and `Rsquare`, which are only defined in later cells (data
# loading and the R^2 helper) — on a fresh kernel run those cells first.

PERCENTILES = (2.5, 16, 50, 84, 97.5)

t0 = time.time()
estimation_series = {}
# Sorted key order keeps the per-run lists aligned with FindBestSIR (which pairs
# sorted keys with "best_index") and with the chronological x-axis of the plots.
for k in sorted(sir_est):
    v = sir_est[k]
    print(k)

    # Push every posterior sample through the SIR model once.
    sat, Rsq, Repro = [], [], []
    for beta, gamma, N, I0 in zip(v["beta"], v["gamma"], v["N"], v["I0"]):
        *_, dS = sir_f(beta, gamma, N, I0, 300)
        est = -np.cumsum(dS, axis=0)   # cumulative cases over 300 simulated days
        sat.append(est[-1])            # final size ("saturation")
        Repro.append(beta / gamma)
        Rsq.append(Rsquare(data[:k], est[:k]))

    # Mean and percentiles of the sampled parameters and the derived quantities.
    quantities = {p: v[p] for p in ['beta', 'gamma', 'N', 'I0']}
    quantities.update({"saturation": sat, "Rsq": Rsq, "Repro0": Repro})
    for name, samples in quantities.items():
        es = estimation_series.setdefault(name, {"mu": [], **{q: [] for q in PERCENTILES}})
        samples = np.asarray(samples)
        es["mu"].append(np.mean(samples))
        for q, value in zip(PERCENTILES, np.percentile(samples, q=PERCENTILES)):
            es[q].append(value)

    # Index of the sample with the highest R^2. Unlike a search starting at 0,
    # this also works when every R^2 is <= 0; NaN values are never chosen
    # (if all are NaN, sample 0 is used).
    rsq = np.asarray(Rsq, dtype=float)
    best_i = int(np.argmax(np.where(np.isnan(rsq), -np.inf, rsq)))
    estimation_series.setdefault("best_index", []).append(best_i)

t1 = time.time()
print("Recalculate SIR for saturation in %.2fs"%(t1-t0))
sir_est_series = estimation_series

60


NameError: name 'Rsquare' is not defined

In [None]:
# Median (line) and central 95 % band (2.5–97.5 percentiles) of the SIR rates,
# one x position per estimation run.
print(sir_est_series["beta"][50])
l = len(sir_est_series["beta"][50])
r = range(l)
gamma = sir_est_series["gamma"]
beta = sir_est_series["beta"]
for series, colour in ((gamma, "tab:red"), (beta, "tab:blue")):
    plt.plot(r, series[50], color=colour)
    plt.fill_between(r, series[2.5], series[97.5], alpha=.1, color=colour)

In [None]:
# Transpose the Dataset for the Gompertz-Fit:
# {run key -> {parameter -> posterior samples}} becomes
# {quantity -> {"mu" | percentile -> list with one value per run}}.
# NOTE: uses `data` and `Rsquare`, which are only defined in later cells.

PERCENTILES = (2.5, 16, 50, 84, 97.5)

estimation_series = {}
# Sorted key order keeps the per-run lists aligned with FindBestGompertz (which
# pairs sorted keys with "best_index") and with the chronological x-axis.
for k in sorted(gomp_est):
    v = gomp_est[k]
    print(k)

    # R^2 of every posterior sample against the data the run was fitted to.
    Rsq = []
    for a, b, c in zip(v["a"], v["b"], v["c"]):
        est = gomp_f(a, b, c, 300, 0)[0]
        Rsq.append(Rsquare(data[:k], est[:k]))

    # Mean and percentiles of the sampled parameters and of R^2.
    quantities = {p: v[p] for p in ['a', 'b', 'c']}
    quantities["Rsq"] = Rsq
    for name, samples in quantities.items():
        es = estimation_series.setdefault(name, {"mu": [], **{q: [] for q in PERCENTILES}})
        samples = np.asarray(samples)
        es["mu"].append(np.mean(samples))
        for q, value in zip(PERCENTILES, np.percentile(samples, q=PERCENTILES)):
            es[q].append(value)

    # Index of the sample with the highest R^2 (also when all R^2 <= 0);
    # NaN values are never chosen (if all are NaN, sample 0 is used).
    rsq = np.asarray(Rsq, dtype=float)
    best_i = int(np.argmax(np.where(np.isnan(rsq), -np.inf, rsq)))
    estimation_series.setdefault("best_index", []).append(best_i)

gomp_est_series = estimation_series

In [None]:
# Median and central 95 % band of the Gompertz asymptote `a`, one x position per run.
a = gomp_est_series["a"]
print(a[50])
l = len(a[50])
r = range(l)
plt.plot(r, a[50], color="tab:red")
plt.fill_between(r, a[2.5], a[97.5], alpha=.1, color="tab:red")

In [None]:
print(sir_est.keys())

def FindBestSIR(estimates,series,horizon=300):
    """Collect, per estimation run, the SIR posterior sample with the highest R^2.

    Args:
        estimates: {run key -> {parameter -> samples}} as loaded from the pickle.
            Run keys are used as the slice bound ``data[:k]``.
        series: transposed summary whose ``"best_index"`` list holds, for every run
            in sorted key order, the index of the best sample.
        horizon: number of simulated days used to obtain the final size.

    Returns:
        dict of numpy arrays with one entry per run (sorted key order): beta,
        gamma, N, I0 of the best sample plus its saturation, Rsq and
        Repro0 (= beta/gamma).

    Raises:
        ValueError: if ``series["best_index"]`` does not have one entry per run
            (zip would otherwise silently truncate the result).

    Uses the notebook globals ``sir_f``, ``Rsquare`` and ``data``.
    """
    runs = sorted(estimates.keys())
    best_index = series["best_index"]
    if len(best_index) != len(runs):
        raise ValueError(
            "series has %d best indices but estimates has %d runs"
            % (len(best_index), len(runs)))

    best = {}
    for p in ["beta", "gamma", "N", "I0"]:
        best[p] = np.array([estimates[run][p][i] for run, i in zip(runs, best_index)])

    sat, Rsq, Repro = [], [], []
    for beta, gamma, N, I0, k in zip(best["beta"], best["gamma"], best["N"], best["I0"], runs):
        *_, dS = sir_f(beta, gamma, N, I0, horizon)
        est = -np.cumsum(dS, axis=0)   # cumulative cases
        sat.append(est[-1])
        Repro.append(beta/gamma)
        Rsq.append(Rsquare(data[:k], est[:k]))

    best["saturation"] = np.array(sat)
    best["Rsq"] = np.array(Rsq)
    best["Repro0"] = np.array(Repro)
    return best

def FindBestGompertz(estimates,series):
    """Collect, per estimation run, the Gompertz posterior sample with the highest R^2.

    Args:
        estimates: {run key -> {parameter -> samples}} as loaded from the pickle.
            Run keys are used as the slice bound ``data[:k]``.
        series: transposed summary whose ``"best_index"`` list holds, for every run
            in sorted key order, the index of the best sample.

    Returns:
        dict of numpy arrays with one entry per run (sorted key order):
        a, b, c of the best sample and its Rsq.

    Raises:
        ValueError: if ``series["best_index"]`` does not have one entry per run
            (zip would otherwise silently truncate the result).

    Uses the notebook globals ``gomp_f``, ``Rsquare`` and ``data``.
    """
    runs = sorted(estimates.keys())
    best_index = series["best_index"]
    if len(best_index) != len(runs):
        raise ValueError(
            "series has %d best indices but estimates has %d runs"
            % (len(best_index), len(runs)))

    best = {}
    for p in ["a", "b", "c"]:
        best[p] = np.array([estimates[run][p][i] for run, i in zip(runs, best_index)])

    Rsq = []
    for a, b, c, k in zip(best["a"], best["b"], best["c"], runs):
        est = gomp_f(a, b, c, 300, 0)[0]
        Rsq.append(Rsquare(data[:k], est[:k]))
    best["Rsq"] = np.array(Rsq)
    return best

best_sir = FindBestSIR(sir_est,sir_est_series)
best_gomp = FindBestGompertz(gomp_est,gomp_est_series)

print(best_sir)
print(best_gomp)

In [None]:
# Figure: evolution of the final-size and parameter estimates over the estimation runs.
# Top: final size (SIR saturation vs. Gompertz a); middle: 1-R^2 on a log scale;
# bottom rows: individual Gompertz (left column) and SIR (other columns) parameters.

def _plot_estimate_band(axis, x, series, best, color, label, best_alpha=None):
    """Draw the median line, the 16-84 and 2.5-97.5 percentile bands and best-fit markers.

    Args:
        axis: matplotlib Axes to draw on.
        x: x positions, one per estimation run.
        series: summary dict of one quantity with percentile keys 2.5, 16, 50, 84, 97.5.
        best: value of the highest-R^2 sample per run, drawn as '+' markers.
        color: color of all elements.
        label: legend label of the median line.
        best_alpha: alpha of the '+' markers (None = matplotlib default).
    """
    axis.plot(x, series[50], color=color, label=label)
    axis.fill_between(x, series[16], series[84], alpha=.1, color=color)
    axis.fill_between(x, series[2.5], series[97.5], alpha=.1, color=color)
    axis.plot(x, best, "+", color=color, alpha=best_alpha)


fig = plt.figure(figsize=(7,9))
fig.subplots_adjust(wspace=0.33,hspace=0.27,left=0.125,right=0.96,top=0.95,bottom=0.04)

# One x position per estimation run, starting 30 days after the first data point.
# The length comes from the results instead of a hard-coded 31.
pts = pd.date_range(datetime.date(2020,2,1)+datetime.timedelta(days=30),periods=len(best_sir["saturation"]),freq='D')

gs = fig.add_gridspec(5, 3)
ax = fig.add_subplot(gs[0:2,0:3])

b_alpha = .5

# Final-size estimates: SIR saturation vs. Gompertz asymptote a.
for e,b,pn,k,c in zip([sir_est_series,gomp_est_series],[best_sir,best_gomp],["saturation","a"],["SIR","Gompertz"],["tab:blue","tab:red"]):
    _plot_estimate_band(ax, pts, e[pn], b[pn], c, "%s '%s'"%(k,pn))

ax.set_ylabel("estimated final size",fontsize=13)
ax.legend()

# Goodness of fit as 1-R^2 (percentile bands are mirrored by the 1-x transform).
rx = fig.add_subplot(gs[2:3,0:3])
for d,b,k,c in zip([sir_est_series,gomp_est_series],[best_sir,best_gomp],["SIR","Gompertz"],["blue","red"]):
    e = d["Rsq"]
    rx.semilogy(pts,1-np.array(e[50]),color=c,label="%s"%(k))
    rx.fill_between(pts,1-np.array(e[84]),1-np.array(e[16]),color=c,alpha=.1)
    rx.fill_between(pts,1-np.array(e[2.5]),1-np.array(e[97.5]),color=c,alpha=.1)
    rx.semilogy(pts,1-b["Rsq"],"+",color=c,alpha=b_alpha)

rx.legend()
rx.text(0.4,.72,"lower is better",transform=rx.transAxes)
rx.set_ylabel(r"$1-R^2$")

# SIR rates gamma and beta.
bx = fig.add_subplot(gs[3,1])
for k,c in zip(["gamma","beta"],["tab:green","purple"]):
    _plot_estimate_band(bx, pts, sir_est_series[k], best_sir[k], c, r"$\%s$"%k, b_alpha)
bx.legend()

# SIR beta/gamma ("Repro0").
gx = fig.add_subplot(gs[3,2])
_plot_estimate_band(gx, pts, sir_est_series["Repro0"], best_sir["Repro0"], "orange", "Repro0", b_alpha)
gx.legend()
gx.set_ylim([0,10])
gx.text(.4,.5,r"$\frac{\beta}{\gamma}$",transform=gx.transAxes,fontsize=16)

# SIR population size N.
cx = fig.add_subplot(gs[4,2])
_plot_estimate_band(cx, pts, sir_est_series["N"], best_sir["N"], "blue", "N", b_alpha)
cx.legend()

# SIR initial infected I0.
dx = fig.add_subplot(gs[4,1])
_plot_estimate_band(dx, pts, sir_est_series["I0"], best_sir["I0"], "red", r"$I_0$", b_alpha)
dx.legend(loc="upper left")
dx.set_ylim([0,30])

# Gompertz shape parameters b and c.
ex = fig.add_subplot(gs[3,0])
fx = fig.add_subplot(gs[4,0])
for tx,k,c in zip([ex,fx],["b","c"],["green","red"]):
    _plot_estimate_band(tx, pts, gomp_est_series[k], best_gomp[k], c, k, b_alpha)
    tx.legend(loc="lower right")

ex.text(.12,.92,"Gompertz",transform=ex.transAxes,fontsize=16)
bx.text(.12,.92,"SIR-Modell",transform=bx.transAxes,fontsize=16)

major_dates = [datetime.date(2020,x,y) for x,y in [(3,1),(3,15),(4,1)]]
for tx in [ax,bx,cx,dx,ex,fx,rx,gx]:
    # Use the thousands formatter only where the axis range makes it readable.
    prec = 1.0 / (np.log10(tx.get_ylim()[1]) - 2.5)
    if 0 <= prec < 2.0:
        tx.yaxis.set_major_formatter(
            matplotlib.ticker.FuncFormatter(_format_k(int(prec)))
        )
    tx.set_xticks(major_dates)
    tx.set_xticklabels(["",datetime.date(2020,3,15),""])
    tx.spines['top'].set_visible(False)
    tx.spines['right'].set_visible(False)

# The two wide panels label all three dates and get a minor tick per run.
for axis in (ax, rx):
    axis.set_xticks(major_dates)
    axis.set_xticklabels(major_dates)
    axis.set_xticks(pts,minor=True)

ax.set_title("Evolution of Parameter estimates for symptomatic cases",fontsize=14)
ax.text(pts[1],220000,"Post ex analysis based on dataset from 23.07.2020",fontsize=13,color="red")
ax.text(pts[10],180000,"Estimates on data starting 01.02.2020\nup to the date shown on the x-axis",fontsize=12)
ax.text(pts[5],78000,r"'+' indicate samples with highest $R^2$ value (best least-square-fit)",fontsize=10)


fig.savefig("SIR-Gomp-Comp.png",dpi=300)

In [None]:
# Data from
#https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/csse_covid_19_data/csse_covid_19_time_series/time_series_covid19_deaths_global.csv
#Germany,51.165690999999995,10.451526,
# Cumulative deaths per day.
jhu_deaths = [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2,2,3,3,7,9,11,17,24,28,44,67,84,94,123,157,206,267,342,433,533,645,775,920,1107,1275,1444,1584,1810,2016,2349,2607,2767,2736,3022,3194,3294,3804,4052,4352,4459,4586,4862,5033,5279,5575,5760,5877,5976,6126,6314,6467,6623,6736,6812,6866,6993,6993,7275,7392,7510,7549,7569,7661,7738,7861,7884,7897,7938,7962,8003,8081,8144,8203,8228,8261,8283,8309,8372,8428,8470,8504,8530,8540,8555,8563,8602,8635,8658,8673,8685,8695,8736,8752,8772,8783,8793,8801,8807,8820,8851,8875,8887,8895,8895,8899,8914,8928,8940,8965,8968,8968,8976,8990,8995,9006,9010,9020,9023,9022,9032,9046,9057,9063,9070,9071,9074,9078,9080,9087,9088,9091,9092,9094,9099,9102,9110,9120,9124,9124,9125,9131,9135,9144,9147,9154,9154,9154,9163,9179,9181,9195,9201,9202,9203,9208,9213,9217,9230,9235,9235,9236,9241,9249,9263,9266,9272,9275,9276,9281,9285,9290,9290,9299]
# NOTE(review): the JHU time-series CSV appears to start on 2020-01-22 — confirm the 2020-01-21 start.
jhu_ts = pd.date_range(datetime.date(2020,1,21),periods=len(jhu_deaths),freq='D')

# Limit to non-zero values, dropping the leading zeros of the list.
# Computed rather than hard-coded so it stays correct if the series is updated.
first_non_zero_index = next(idx for idx, deaths in enumerate(jhu_deaths) if deaths != 0)
jhu_deaths1 = jhu_deaths[first_non_zero_index:]
jhu_ts1 = jhu_ts[first_non_zero_index:]
print("first non zero deaths timestamp",jhu_ts1[0])

# Alternative Deaths from Situation-report graph of actual dates of deaths
rki_deaths_per_day = [1, 1, 1, 1, 5, 2, 3, 4, 11, 13, 12, 21, 40, 24, 42, 42, 67, 76, 92, 104, 107, 114, 172, 165, 162, 204, 199, 223, 240, 240, 254, 260, 257, 240, 246, 239, 243, 245, 227, 239, 241, 191, 203, 209, 172, 194, 157, 163, 129, 141, 132, 125, 128, 121, 106, 76, 92, 78, 93, 82, 69, 54, 63, 64, 63, 42, 54, 66, 47, 29, 40, 44, 45, 47, 31, 33, 27, 30, 31, 23, 21, 18, 19, 17, 18, 24, 15, 18, 18, 14, 8, 9, 8, 6, 11, 11, 14, 9, 6, 9, 7, 7, 9, 6, 3, 4, 4, 6, 7, 7, 7, 4, 7, 8, 1]
rki_deaths = np.cumsum(rki_deaths_per_day)
rki_ts = pd.date_range(datetime.date(2020,3,8),periods=len(rki_deaths),freq='D')

# name -> (values, daily timestamps)
sources = {"jhu_deaths":(jhu_deaths,jhu_ts),"jhu1_deaths":(jhu_deaths1,jhu_ts1),"rki_deaths":(rki_deaths,rki_ts)}
# Include symptomatic cases from the arcgis dump of the RKI case table (dataset of 2020-07-23).
rki_df = pd.read_csv("data/RKI_COVID19_200723.csv")

def rdate(s):
    """Parse a date string from the RKI dump into a ``datetime.date``.

    Accepted forms (anything after the date part is ignored):
        "2020-03-05T00:00:00..."  ISO timestamp
        "2020-03-05 00:00:00"     ISO date, optional time
        "2020/03/05 00:00:00"     year-first with slashes
        "03/05/2020 ..."          month/day/year with slashes

    Raises:
        ValueError: if ``s`` matches none of these formats.
    """
    if "T" in s:
        date_part, fmt = s.split("T")[0], "%Y-%m-%d"
    # elif: an ISO timestamp also contains "-", and must not reach this branch
    # (it would try to parse the "T..." time suffix as part of the date).
    elif "-" in s:
        date_part, fmt = s.split(" ")[0], "%Y-%m-%d"
    else:
        date_part = s.split(" ")[0]
        # Year-first if the first slash-separated field has four digits.
        year_first = len(date_part.split("/")[0]) == 4
        fmt = "%Y/%m/%d" if year_first else "%m/%d/%Y"
    return datetime.datetime.strptime(date_part, fmt).date()

# Convert the date columns that are present to datetime.date objects.
cols = [x for x in ["Meldedatum","Refdatum"] if x in rki_df.columns]
for col in cols:
    rki_df[col] = rki_df[col].apply(rdate)
        
rki_df = rki_df.filter(items=["Meldedatum","Refdatum","NeuerFall","AnzahlFall","IstErkrankungsbeginn","AnzahlTodesfall"])

ts = pd.date_range(datetime.date(2020,1,1),periods=200,freq='D')
symptomatic_timeseries = pd.date_range(datetime.date(2020,2,1),periods=150,freq='D')

# Daily case counts of rows with IstErkrankungsbeginn == 1, grouped by Refdatum
# (presumably the symptom-onset date for those rows — per the RKI data description).
onset_cases = rki_df[rki_df["IstErkrankungsbeginn"] == 1]
cases_per_onset_day = onset_cases.groupby("Refdatum")["AnzahlFall"].sum()
# Refdatum holds datetime.date objects, so days are looked up via t.date():
# comparing them with pandas Timestamps is deprecated and no longer matches in pandas 2.
symptomatic = np.array([cases_per_onset_day.get(t.date(), 0) for t in ts], dtype=float)

cs = np.cumsum(symptomatic)
s = cs[(symptomatic_timeseries[0]-ts[0]).days:]
# Trim to the length of the timestamp index it is paired with below
# (trimming to len(ts) was a no-op and left more values than dates).
s = s[:len(symptomatic_timeseries)]

sources["rki_symptomatic"] = (s,symptomatic_timeseries)
data,ts = sources["rki_symptomatic"]

print("["+",".join(["%d"%x for x in data[:120]])+"]")

In [None]:
# Quick checks of the date indices and the data vector.
for stamp in (ts[51], ts[30], pts[21]):
    print(stamp)
sum(data*data)

In [None]:
# Export the best-fit (highest R^2) parameters of every estimation run as CSV.
# Column names are German: "Erster/Letzter Datenpunkt" = first/last data point.
# NOTE(review): the file name says "Deaths starting 2020-03-08", but the estimates
# loaded above are the symptomatic-case fits — confirm the intended file name.
Dataset = {
    "Erster Datenpunkt": pd.DatetimeIndex([ts[0]] * len(pts)),
    "Letzter Datenpunkt": pts,
    "SIR_beta": best_sir["beta"],
    "SIR_gamma": best_sir["gamma"],
    "SIR_I0": best_sir["I0"],
    "SIR_N": best_sir["N"],
    "SIR_final": best_sir["saturation"],
    "SIR_R^2": best_sir["Rsq"],
    "Gompertz_a": best_gomp["a"],
    "Gompertz_b": best_gomp["b"],
    "Gompertz_c": best_gomp["c"],
    "Gompertz_R^2": best_gomp["Rsq"],
}
dset = pd.DataFrame(Dataset)

dset.to_csv("estimates/Estimates_on_Deaths_starting_2020-03-08.csv")

In [None]:
def Rsquare(data,est):
    """Coefficient of determination R^2 of ``est`` as a fit to ``data``.

    R^2 = 1 - SS_res / SS_tot, with SS_res = sum((data - est)^2) and
    SS_tot = sum((data - mean(data))^2). 1 is a perfect fit; values can be negative.

    Args:
        data: observed values (array-like; lists are accepted).
        est: model values of the same length (array-like).

    Returns:
        R^2 as a numpy float. If ``data`` is constant (SS_tot == 0) the result is
        not finite (nan or -inf) and numpy emits a RuntimeWarning.
    """
    data = np.asarray(data, dtype=float)
    est = np.asarray(est, dtype=float)
    ss_tot = np.sum(np.power(data - data.mean(), 2))
    ss_res = np.sum(np.power(data - est, 2))
    return 1-ss_res/ss_tot


# Compare, for one estimation run, the SIR and Gompertz fits with the data and
# with the Gompertz curve from Homburg's parameters (optionally shifted in time).

# Daily axis from the first data point. It has its own name so that `ts` (data
# index) and `pts` (one date per estimation run) of the other cells stay intact.
days = pd.date_range(datetime.date(2020,2,1),periods=120,freq='D')

fig = plt.figure(figsize=(7,8))
fig.subplots_adjust(wspace=0.33,hspace=0.27,left=0.125,right=0.93,top=0.92,bottom=0.05)
gs = fig.add_gridspec(1, 1)
ax = fig.add_subplot(gs[0,0])

# Homburg's Gompertz parameters [a, b, c, offset in days]; the curve is appended to each entry.
hom = {"1.Versuch":[227646,9.92,.014,0],"2.Versuch":[106397,7.061,.107,0],"2.Ver.Offset":[106397,7.061,.107,-28]}
for k,h in hom.items():
    h.append(gomp_f(h[0],h[1],h[2],len(days),h[3])[0])

n = 22      # index of the estimation run to show (sorted key order)
nx = n+30   # number of plotted days; assumes run n used the first n+30 points — TODO confirm against the run keys

sest = sir_est_series
gest = gomp_est_series

# Percentile bands of the SIR curves over all posterior samples of run n.
v = sir_est[sorted(sir_est.keys())[n]]
cs = []
for beta,gamma,N,I0 in zip(v["beta"],v["gamma"],v["N"],v["I0"]):
    *_, dS = sir_f(beta,gamma,N,I0,nx)
    cs.append(-np.cumsum(dS))
qcs = np.percentile(np.array(cs),q=(2.5,16,84,97.5),axis=0)

# Percentile bands of the Gompertz curves over all posterior samples of run n.
v = gomp_est[sorted(gomp_est.keys())[n]]
cs = np.array([gomp_f(at,bt,ct,nx,0)[0] for at,bt,ct in zip(v["a"],v["b"],v["c"])])
gcs = np.percentile(cs,q=(2.5,16,84,97.5),axis=0)

print(sest["beta"][50][n],sest["gamma"][50][n],sest["N"][50][n],sest["I0"][50][n])
print(gest["a"][50][n],gest["b"][50][n],gest["c"][50][n])

# Central curves from the per-parameter posterior medians (not a single sample).
S_t,I_t,R_t,dS = sir_f(sest["beta"][50][n],sest["gamma"][50][n],sest["N"][50][n],sest["I0"][50][n],300)
est_sir = -np.cumsum(dS)
print(est_sir[-1])
est_gomp = gomp_f(gest["a"][50][n],gest["b"][50][n],gest["c"][50][n],len(days),0)[0]

ax.plot(days[:nx],data[:nx],"+",label="data",color="black",alpha=1)

ax.plot(days[:nx],est_sir[:nx],label=r"SIR estimate $\beta=%.3f$ $\gamma=%.3f$ $N=%.1f$k $I_0=%.1f$ final$=%.1fk$"%(sest["beta"][50][n],sest["gamma"][50][n],sest["N"][50][n]/1000,sest["I0"][50][n],est_sir[-1]/1000.),color="tab:blue")
ax.fill_between(days[:nx],qcs[0],qcs[-1],color="tab:blue",alpha=.1)
ax.fill_between(days[:nx],qcs[1],qcs[-2],color="tab:blue",alpha=.1)

ax.plot(days[:nx],est_gomp[:nx],label=r"Gompertz $a=%.1f$k $b=%.2f$ $c=%.4f$"%(gest["a"][50][n]/1000.,gest["b"][50][n],gest["c"][50][n]),color="tab:red")
ax.fill_between(days[:nx],gcs[0],gcs[-1],color="tab:red",alpha=.1)
ax.fill_between(days[:nx],gcs[1],gcs[-2],color="tab:red",alpha=.1)
ax.legend(loc="upper left")

ax.set_ylabel("symptomatic cases",fontsize=16)
ax.set_title("Symptomatic cases up to %s\nFrom Dataset 2020-07-23"%days[nx].date(),fontsize=16)

rsqsir = Rsquare(data[:nx],est_sir[:nx])
rsqgomp = Rsquare(data[:nx],est_gomp[:nx])
rsqhomOffset = Rsquare(data[:nx],hom["2.Ver.Offset"][-1][:nx])
ax.text(.1,.7,r"Gompertz $R^2=%.5f$"%rsqgomp,transform=ax.transAxes,fontsize=15)
ax.text(.1,.65,r"SIR $R^2=%.5f$"%rsqsir,transform=ax.transAxes,fontsize=15)
ax.text(.1,.55,r"Homburg, Offset %d days $R^2=%.5f$"%(hom["2.Ver.Offset"][3],rsqhomOffset),transform=ax.transAxes,fontsize=15)

# Thousands formatter only where the y-range makes it readable.
prec = 1.0 / (np.log10(ax.get_ylim()[1]) - 2.5)
if 0 <= prec < 2.0:
    ax.yaxis.set_major_formatter(
        matplotlib.ticker.FuncFormatter(_format_k(int(prec)))
    )
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

lts = [datetime.date(2020,x,y) for x,y in [(2,1),(2,15),(3,1),(3,15),(3,22)]]
if n>21:
    lts.append(datetime.date(2020,4,1))
ax.set_xticks(lts)
# The second-to-last tick stays unlabeled.
tick_labels = list(lts)
tick_labels[-2] = ""
ax.set_xticklabels(tick_labels)
fig.savefig("Est_comp2_%s.png"%days[nx].date(),dpi=300)

    

In [None]:
# Geometric grid of 21 rates around 0.15 (neighbour ratio 1.15), grown outwards
# from the centre; then print every (alpha, beta) pair of the grid.
r = [0.15]
f = 1.15
for j in range(10):
    r.insert(0, r[0] / f)
    r.append(r[-1] * f)

rate_pairs = [(alpha, beta) for alpha in r for beta in r]
for alpha, beta in rate_pairs:
    print(alpha, beta)
