In [None]:
#@title Install arviz
# !pip3 install arviz

In [None]:
import arviz as az
import pystan
import os
# os.environ['STAN_NUM_THREADS'] = "4"
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import networkx as nx
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)
%matplotlib inline

## Select model

In [None]:
# Project-local module with the candidate epidemic models; each model object
# exposes its Stan program as `.stan` and its Stan data dict as `.stan_data`.
import MBS_epidemic_concentration_models as models
# Select model 2 and draw its network diagram (via the model's own plotnetwork()).
model = models.model2()
model.plotnetwork()

## Compile

In [None]:
# Compile the selected model's Stan program (slow; only needs re-running when
# the model code changes).
stanrunmodel = pystan.StanModel(model_code=model.stan)

# Load COVID-19 time series from the JHU CSSE repository



In [None]:
url_confirmed = "https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/csse_covid_19_data/csse_covid_19_time_series/time_series_covid19_confirmed_global.csv"
url_deaths = "https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/csse_covid_19_data/csse_covid_19_time_series/time_series_covid19_deaths_global.csv"
url_recovered = "https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/csse_covid_19_data/csse_covid_19_time_series/time_series_covid19_recovered_global.csv"

# Fetch the three global time-series tables (confirmed, deaths, recovered), in
# that order; the per-date counts live in the columns after the first four.
dfc, dfd, dfr = [pd.read_csv(source) for source in (url_confirmed, url_deaths, url_recovered)]




## Build the JHU data frame for the region of interest (ROI)

### Choose the country (region of interest)

In [None]:
# Region of interest: a 'Country/Region' value from the JHU tables, e.g.
# Austria, Belgium, Denmark, France, Germany, Italy, Norway, Spain, Sweden,
# Switzerland, United Kingdom.
# NOTE: the population cell reads pop[roi], so the chosen country needs a `pop` entry.
roi = "Italy"

In [None]:
# Heuristic for the start of the local epidemic: the first day with at least
# CASE_THRESHOLD active cases, minus EXPOSURE_LAG_DAYS.
CASE_THRESHOLD = 10
EXPOSURE_LAG_DAYS = 5


def _country_rows(frame):
    """Return the rows of a JHU time-series table for `roi` with an empty 'Province/State'."""
    mask = (frame['Country/Region'] == roi) & pd.isnull(frame['Province/State'])
    return frame.loc[mask]


dfc2 = _country_rows(dfc)
dfd2 = _country_rows(dfd)
dfr2 = _country_rows(dfr)
# Exactly one row per table is required; previously 0 rows gave a bare
# IndexError and >1 rows silently used the first.
for _name, _rows in (('confirmed', dfc2), ('deaths', dfd2), ('recovered', dfr2)):
    if len(_rows) != 1:
        raise ValueError(
            f"Expected exactly one {_name} row for {roi!r} with no Province/State, "
            f"found {len(_rows)}")

# Date columns follow the first four (Province/State, Country/Region, Lat, Long).
dates = dfc.columns[4:].values

_confirmed = dfc2[dates].to_numpy()[0]
_deaths = dfd2[dates].to_numpy()[0]
_recovered = dfr2[dates].to_numpy()[0]

# One row per day. 'cases' is the active count: confirmed - (recovered + deaths).
# Built in one shot from arrays rather than appended row by row (quadratic, object dtypes).
DF = pd.DataFrame({
    'date': dates,
    'cases': _confirmed - (_recovered + _deaths),
    'recovered': _recovered,
    'deaths': _deaths,
})
df = DF  # legacy alias: the original bound both names to the same frame

# Population per country, used as n_pop; `roi` must have an entry here.
pop = {
    'Italy': 60500000,
    'United Kingdom': 66440000,  # ~66.4 million (was mistyped as 6,440,000)
}

_above_threshold = np.flatnonzero(DF["cases"].to_numpy() >= CASE_THRESHOLD)
if _above_threshold.size == 0:
    raise ValueError(
        f"{roi!r} never reaches {CASE_THRESHOLD} active cases; cannot estimate t0")
# Estimated day of first exposure (TODO: make this a model parameter).
# Clamped to >= 1 so the [t0:] slices downstream never wrap around to the end
# of the series, and the Stan initial time t0-1 is non-negative.
t0 = max(int(_above_threshold[0]) - EXPOSURE_LAG_DAYS, 1)
model.stan_data['t0'] = t0-1
print("t0 assumed to be: day "+str(t0))

plt.plot(DF["cases"], 'bo', label="cases")
plt.plot(DF["recovered"], 'go', label="recovered")
plt.plot(DF["deaths"], 'ks', label="deaths")
# Dashed line marks the initial-condition day handed to Stan.
plt.axvline(model.stan_data['t0'], color='k', linestyle="dashed")
plt.xlabel("day index (0 = " + str(dates[0]) + ")")
plt.ylabel("count")
plt.title(roi)
plt.legend()


## Format JHU ROI data for Stan

In [None]:
# Keep only observations from day t0 onwards; the initial condition sits one
# step earlier, at stan_data['t0'] = t0-1. Columns of y: cases, recovered, deaths.
model.stan_data.update(
    ts=np.arange(t0, len(dates)),
    y=DF[['cases', 'recovered', 'deaths']].to_numpy().astype(int)[t0:, :],
    n_obs=len(dates) - t0,
)

### Enter population manually

In [None]:
# Population of the region of interest, passed to Stan as n_pop.
# Fail with a clear message (same exception type as before) when `pop` lacks `roi`,
# e.g. for the countries listed next to `roi` that have no population yet.
if roi not in pop:
    raise KeyError(f"No population entry for {roi!r} in `pop`; add one before fitting")
model.stan_data['n_pop'] = pop[roi]
# Scale factor passed to Stan; the results plot also multiplies the fitted
# trajectories `u` by this value.
model.stan_data['n_scale'] = 50000


### Print data for Stan 

In [None]:
# Show everything that will be passed to Stan as data (sanity check before sampling).
print(model.stan_data)

# (Optional) Load the 1978 English boarding-school influenza data

In [None]:
# Alternative dataset (disabled): the 1978 influenza outbreak at an English
# boarding school, population 763. Uncomment to fit it instead of the JHU data;
# it overwrites t0, n_pop, ts, y and n_obs in model.stan_data, and pads y with a
# zero "deaths" column so it has the same 3-column layout as the JHU data.
# #England 1978 influenza
# cases = [0,8,26,76,225,298,258,233,189,128,150,85,14,4]
# recovered = [0,0,0,0,9,17,105,162,176,166,150,85,47,20]
# plt.plot(cases,'bo', label="cases")
# plt.plot(recovered,'go',label="recovered")
# pop = 763
# model.stan_data['t0'] = 0
# #truncate time series from t0 on (initial is t0-1)
# model.stan_data['n_pop'] = pop 
# model.stan_data['ts'] = np.arange(1,len(cases)+1)  
# Y = np.hstack([np.c_[cases],np.c_[recovered],np.zeros((len(cases),1))]).astype(int)
# model.stan_data['y'] = Y
# model.stan_data['n_obs'] = len(cases)

# plt.plot(cases,'bo', label="cases")
# plt.plot(recovered,'go',label="recovered")

# plt.legend()

# Run Stan 

## Initialize parameters

In [None]:
# Random but feasible starting values for the sampler, passed to pystan as a
# callable `init`. Which variant is defined is decided once, here, from the
# model's parameter count n_theta.
# Fixed alternative: init_par = [{'theta':[0.25,0.01,0.01,0.05,.02],'S0':0.5}]

if model.stan_data['n_theta'] == 2:
    def init_fun():
        """Two-parameter model: theta[0] ~ U(0, 5), theta[1] ~ U(0.2, 0.4), S0 ~ U(0, 1)."""
        theta = [np.random.uniform(0, 5), np.random.uniform(0.2, 0.4)]
        return {'theta': theta, 'S0': np.random.uniform()}
else:
    def init_fun():
        """Five-parameter model: theta[k] ~ U(0, upper[k]) with upper = (5, .01, .01, .1, .1); S0 ~ U(0, 1)."""
        upper_bounds = (5, 0.01, 0.01, 0.1, 0.1)
        theta = [upper * np.random.uniform() for upper in upper_bounds]
        return {'theta': theta, 'S0': np.random.uniform()}

## Fit Stan 

In [None]:
# Presumably the ODE solver's step limit inside the Stan model (passed as data).
model.stan_data['max_num_steps'] = 10000000

SEED = 13219
n_chains = 1
n_warmups = 500
n_iter = 2000   # per pystan 2, `iter` counts warmup iterations too
n_thin = 10

control = {'adapt_delta': 0.9}
# init_fun draws from NumPy's global RNG, which Stan's `seed` does not control;
# seed it as well so the initial values, and hence the whole run, are reproducible.
np.random.seed(SEED)
fit = stanrunmodel.sampling(data=model.stan_data, init=init_fun, control=control,
                            chains=n_chains, warmup=n_warmups, iter=n_iter,
                            thin=n_thin, seed=SEED)



In [None]:
# Text summary of the posterior for every sampled quantity.
print(fit)

In [None]:
# Posterior densities of theta and R_0; both names must exist in the fit's
# posterior (parameters or generated quantities of the Stan model).
#https://arviz-devs.github.io/arviz/generated/arviz.plot_density
az.plot_density(fit,group='posterior',var_names=["theta","R_0"])

In [None]:
# Compare one posterior trajectory of the model compartments with the JHU data.
# TODO: posterior predictive check, e.g.
#   az.plot_ppc(az.from_pystan(posterior=fit, posterior_predictive='y_hat', observed_data=...))

# fit.extract() rebuilds every array on each call, so fetch `u` once.
# u is indexed [draw, time, compartment] (inferred from the indexing below — TODO confirm).
u_draws = fit.extract()['u']
print(np.shape(u_draws))

n_scale = model.stan_data['n_scale']
# extract() permutes draws by default, so [-1] is one arbitrary posterior draw.
u_draw = u_draws[-1]
# Model time index k is taken to correspond to day t0 + k (ts starts at t0).
# The observed Series below are plotted against their DataFrame index (t0, t0+1, ...),
# so the model curves must be shifted by t0 as well to line up with them.
model_days = t0 + np.arange(u_draw.shape[0])

labels = ['C', 'D', 'R', 'I', 'S', 'Z']
for i in range(5):
    plt.plot(model_days, n_scale * u_draw[:, i], label=labels[i])
# 'Z' is 1 - S, the complement of compartment 4.
plt.plot(model_days, n_scale * (1 - u_draw[:, 4]), label=labels[-1])
plt.ylim((0, 35000))

plt.plot(DF["cases"][t0:], 'bo', label="cases")
plt.plot(DF["recovered"][t0:], 'go', label="recovered")
plt.plot(DF["deaths"][t0:], 'ks', label="deaths")

plt.xlabel("day index")
plt.ylabel("count")
plt.legend()

