In [1]:
from utils.helpers import *

import ipywidgets as widgets
from ipywidgets import interact, interact_manual, fixed

import warnings
warnings.filterwarnings('ignore')

# Bayesian Modelling of Covid-19
---

Here we consider two simple models for the growth of the number of cases in various countries: exponential vs sigmoidal. 

### Exponential

An exponential model grows like:

$\alpha e^{\beta  t}$ 

and is unbounded. As time $t$ increases, the number of cases grows ever faster. This is a crude approximation of what happens during an outbreak, but it is fairly accurate early on, especially if no measures are implemented to slow the contagion. This exponential take-off is what we hope to slow down through social distancing, hygiene measures, and other efforts.

### Sigmoid

A sigmoid, or plateauing model, looks exponential initially before leveling off. The model looks roughly like:

$\frac{1}{1+ e^{-\beta t}}$

This is a slightly more accurate depiction of what may happen during an outbreak: there are only so many people who can become infected before eventually you reach (in theory) the entire population. 


### Comparison

The exponential curve continues to increase with no slowdown, while a sigmoidal curve will eventually plateau. By comparing these two simple models we can see if countries are mitigating the effect of Covid-19 within their borders. 
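
To make the difference concrete, here is a minimal sketch of the two curves in plain Python. The `cap` and `t_mid` arguments are illustrative additions (they are not part of the formulas above); they scale and shift the sigmoid so that both curves start from roughly the same value.

```python
import math


def exponential(t, alpha=1.0, beta=0.2):
    """Unbounded exponential growth: alpha * exp(beta * t)."""
    return alpha * math.exp(beta * t)


def sigmoid(t, beta=0.2, cap=1.0, t_mid=0.0):
    """Logistic curve that levels off at `cap`.

    `cap` and `t_mid` are hypothetical extras for illustration; with the
    defaults this reduces to 1 / (1 + exp(-beta * t)).
    """
    return cap / (1.0 + math.exp(-beta * (t - t_mid)))


# Early on the two curves track each other; later only the sigmoid flattens.
for t in (0, 10, 20, 40, 80):
    exp_value = exponential(t)
    sig_value = sigmoid(t, cap=1000.0, t_mid=35.0)
    print(f"t={t:3d}  exponential={exp_value:14.2f}  sigmoid={sig_value:8.2f}")
```

With few data points, both curves can fit the early part of an outbreak about equally well, which is one reason the forecasts carry large uncertainties.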


To look at different models, simply:

1. Choose a country 
2. Choose a number of days to forecast (default of 7)
3. Click `Run Interact`


Note that for some models with very few data points, the uncertainties on the curves are going to be massive! We need more time (and data) for these curves to settle in. 

In [2]:
country_list = ['Italy', 'Germany', 'China', 'Spain', 'Canada', 'US',
                    'Canada British Columbia', 'Canada Alberta', 'Canada Ontario',
                    'China Hong Kong', 'Korea South', 'Singapore']
country_list.sort()

country_widget = widgets.Dropdown(
    options=country_list,
    value='Italy',
    description='Country:',
    disabled=False,
)

# IntText has no orientation/readout options (those belong to sliders),
# so only the arguments it actually supports are passed here.
num_days_widget = widgets.IntText(
    value=7,
    step=1,
    description='Number of Future Days:',
    disabled=False,
    continuous_update=False,
)

ymax_widget = widgets.IntText(
    value=-1,
    description='Y Max:',
    disabled=False
)

out = interact_manual(plot_country, country=country_widget, num_days=num_days_widget, ymax=ymax_widget)
#plot_country('Italy')


## SIR Model
---

Although a sigmoidal model does a relatively good job of approximating the number of cases during an outbreak, we can still do better. One thing the sigmoid model does not take into account is that people can recover and become resistant to the virus. As more people become resistant, it becomes less likely that a sick (infected) person will interact with and successfully infect a healthy person.

To take into account (S)usceptible, (I)nfected, and (R)esistant members of the population, we can look at a SIR model. 

The basic setup of a SIR model is as follows:

$ \frac{dS}{dt} = -\beta I(t)S(t) $

$ \frac{dI}{dt} = +\beta I(t)S(t) - \gamma I(t) $

$ \frac{dR}{dt} = \gamma I(t) $

$ S + I + R = 1$

Where we have normalized to the population of interest.
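
As a rough illustration, the equations above can be stepped forward in time with a simple forward Euler scheme. This is only a sketch: the step size, the initial infected fraction, and the function name `simulate_sir` are assumptions made for this example, and the notebook's own helpers may integrate the model differently.

```python
def simulate_sir(beta, gamma, s0, i0, days, dt=0.1):
    """Integrate the normalized SIR equations with forward Euler steps.

    Returns one (day, S, I, R) tuple per day, starting at day 0.
    """
    s, i, r = s0, i0, 1.0 - s0 - i0
    history = [(0.0, s, i, r)]
    steps_per_day = round(1 / dt)
    for day in range(1, days + 1):
        for _ in range(steps_per_day):
            new_infections = beta * i * s * dt   # flow from S to I
            new_recoveries = gamma * i * dt      # flow from I to R
            s -= new_infections
            i += new_infections - new_recoveries
            r += new_recoveries
        history.append((float(day), s, i, r))
    return history


# Illustrative parameters: beta = 1/3 and gamma = 1/30 (per day).
history = simulate_sir(beta=1 / 3, gamma=1 / 30, s0=0.999, i0=0.001, days=200)
peak_day, _, peak_i, _ = max(history, key=lambda row: row[2])
print(f"peak infected fraction {peak_i:.3f} on day {peak_day:.0f}")
```

Because every person leaving one compartment enters another, $S + I + R$ stays equal to 1 at every step.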

Essentially: infected people recover over time at rate $\gamma$, while they infect susceptible people over time at rate $\beta$. A parameter of interest is the combined ratio:

$ R_0 = \frac{\beta}{\gamma} $ 

This tells us roughly how many new people may get infected by a single infected person when almost everyone is still susceptible. $R_0 > 1$ implies the outbreak can grow into an epidemic.
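
To see where the threshold at 1 comes from, factor the equation for the infected fraction:

$ \frac{dI}{dt} = \left(\beta S(t) - \gamma\right) I(t) = \gamma \left(R_0 S(t) - 1\right) I(t) $

At the start of an outbreak almost everyone is susceptible ($S \approx 1$), so the number of infected people grows when $R_0 > 1$ and shrinks when $R_0 < 1$. As $S$ falls, growth stops once $R_0 S(t)$ drops below 1.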

### Flatten the Curve!!!

So what can we do about all this? We can flatten that curve! We can't really change how likely we are to recover if we get sick ($\gamma$). But if society as a whole can reduce the probability of getting sick in the first place ($\beta$), then we can reduce the overall stress on hospitals and other services by reducing the number of new cases on any given day. Too many cases at once will overload the system, leading to worse care for those who need it.

We can do this mathematically by adding a new parameter, $\delta$, that allows us to have $\beta$ decrease over time. For example, if everyone starts to practice social distancing, $\delta$ will increase. If those who can work from home do, $\delta$ will increase. If kids stay home from school, $\delta$ will increase. You get the idea!

Let's check this out ... increase $\delta$ and watch that curve flatten!
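
For a feel of the numbers, here is a small sketch of the idea. It ASSUMES that $\beta$ decays exponentially, $\beta(t) = \beta_0 e^{-\delta t}$; the exact form used by `sir_plot_static` is not shown in this notebook and may differ. The function name `peak_infected`, the step size, and the initial infected fraction are likewise made up for this example. It uses $R_0 = 10$ and $\gamma = 1/30$, the values fixed in the interactive plot.

```python
import math


def peak_infected(r0, gamma, delta, days=500, dt=0.1, i0=0.001):
    """Peak infected fraction for an SIR run whose contact rate decays over time.

    ASSUMPTION: beta(t) = r0 * gamma * exp(-delta * t). This is an
    illustrative choice, not necessarily the notebook's own model.
    """
    s, i = 1.0 - i0, i0
    peak = i
    for step in range(round(days / dt)):
        beta_t = r0 * gamma * math.exp(-delta * step * dt)
        new_infections = beta_t * i * s * dt
        new_recoveries = gamma * i * dt
        s -= new_infections
        i += new_infections - new_recoveries
        peak = max(peak, i)
    return peak


# Same delta values as the slider: 0 to 0.02 in steps of 0.005.
for delta in (0.0, 0.005, 0.01, 0.015, 0.02):
    peak = peak_infected(10, 1 / 30, delta)
    print(f"delta={delta:.3f}  peak infected fraction={peak:.3f}")
```

Under this assumption, a larger $\delta$ lowers the peak of the infected curve, which is the "flattening" we are after.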

In [3]:


delta_widget = widgets.FloatSlider(
    value=0.0,
    min=0,
    max=0.02,
    step=0.005,
    description='Delta: ',
    disabled=False,
    continuous_update=True,
    orientation='horizontal',
    readout=True,           # show the current value next to the slider
    readout_format='.3f',   # '.1f' would display every step as 0.0
)

out = interact(sir_plot_static, R0=fixed(10), gamma=fixed(1/30), delta=delta_widget)


### SIR In The Real World
---

So what does this look like in the real world? We can use the same Bayesian techniques to estimate best fits for model parameters.
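
The plots for each country summarize many posterior predictive samples as a mean curve with a one-standard-deviation band, and derive total cases as $I + R = 1 - S$ and daily new cases as the slope of that total. The sketch below mimics that post-processing on made-up numbers in plain Python; the helper names `summarize` and `daily_new_cases` are hypothetical, and the real plotting function does the equivalent with NumPy arrays.

```python
from statistics import mean, pstdev


def summarize(samples):
    """Per-day mean and (population) standard deviation across posterior draws."""
    per_day = list(zip(*samples))  # transpose: one tuple of draws per day
    return [mean(day) for day in per_day], [pstdev(day) for day in per_day]


def daily_new_cases(total_cases):
    """Slope of cumulative cases: central differences inside, one-sided at the ends.

    Needs at least two points.
    """
    n = len(total_cases)
    slopes = []
    for k in range(n):
        lo, hi = max(k - 1, 0), min(k + 1, n - 1)
        slopes.append((total_cases[hi] - total_cases[lo]) / (hi - lo))
    return slopes


# Made-up susceptible fractions: three posterior draws over four days.
susceptible_draws = [
    [1.00, 0.98, 0.94, 0.90],
    [1.00, 0.97, 0.92, 0.86],
    [1.00, 0.99, 0.96, 0.92],
]

# Everyone who has left S counts as a case (infected or resistant).
total_case_draws = [[1.0 - s for s in draw] for draw in susceptible_draws]
total_mean, total_std = summarize(total_case_draws)

new_case_draws = [daily_new_cases(draw) for draw in total_case_draws]
new_mean, new_std = summarize(new_case_draws)

for day, (m, sd) in enumerate(zip(total_mean, total_std)):
    print(f"day {day}: total cases = {m:.3f} +/- {sd:.3f}")
```

The band widens wherever the posterior draws disagree, which is why forecasts far beyond the observed data look so uncertain.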


In [4]:
country_sir_list = [
    'Italy', 
    #'Germany', 
    #'China', 
    #'Spain', 
    #'Canada', 
    #'US',
    'Canada British Columbia', 
    #'Canada Alberta', 
    #'Canada Ontario',         
    #'China Hong Kong', 
    #'Korea South', 
    #'Singapore'
]
country_sir_list.sort()

country_sir_widget = widgets.Dropdown(
    options=country_sir_list,
    value='Canada British Columbia',
    description='Country:',
    disabled=False,
)

out = interact(sir_bayes_plot_v2, country=country_sir_widget, num_days=fixed(500), delta_case=True)


In [35]:
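import os

import joblib
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns


def sir_bayes_plot_v2(country, num_days, delta_case=False):
    """Plot posterior-predictive SIR curves, total and daily new cases, and
    parameter posteriors for `country`, using traces saved under traces/<country>/."""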
    # country = 'Canada British Columbia'
    # dates, x, sus, inf = get_country_sir(country, min_cases=1)

    if delta_case:
        second_dir = 'sir_delta'
    else:
        second_dir = 'sir'

    tr_path = os.path.join('traces', country.lower().replace(' ', '_'))

    dates = joblib.load(os.path.join(tr_path, 'dates_{:s}.pkl'.format(second_dir)))
    x = joblib.load(os.path.join(tr_path, 'x_{:s}.pkl'.format(second_dir)))
    sus = joblib.load(os.path.join(tr_path, 'sus_{:s}.pkl'.format(second_dir)))
    inf = joblib.load(os.path.join(tr_path, 'inf_{:s}.pkl'.format(second_dir)))

    # sus and inf are already normalized to the population;
    # hold out the most recent observation as a test point
    x_train = x[:-1]
    x_test = x[-1:]

    sus_train = sus[:-1]
    sus_test = sus[-1:]

    inf_train = inf[:-1]
    inf_test = inf[-1:]

    # make single array
    y_train = np.hstack((sus_train.reshape(-1, 1), inf_train.reshape(-1, 1)))
    y_test = np.hstack((sus_test.reshape(-1, 1), inf_test.reshape(-1, 1)))

    y0 = [y_train[0][0], y_train[0][1]]

    last = len(x)
    extend = np.arange(last, last + num_days)
    x_updated = np.append(x, extend)
    y_updated = np.empty((x_updated.shape[0], y_train.shape[1]))

    # sir = sir_model(x_updated, y_updated, y0)
    # posterior_predictive, trace = predict_model_from_file(sir, os.path.join(tr_path, 'sir'), 1000)

    posterior_predictive = joblib.load(os.path.join(tr_path, second_dir+'_y_predict.pkl'))
    all_y = posterior_predictive['y_obs'][:, :len(x_updated), :]

    y0_array = all_y[:, :, 0]
    y1_array = all_y[:, :, 1]
    y2_array = 1 - y0_array - y1_array
    total_cases = y1_array + y2_array
    new_cases = np.gradient(total_cases, axis=1)

    y0_mean = np.nanmean(y0_array, axis=0)
    y0_std = 1 * np.nanstd(y0_array, axis=0)

    y1_mean = np.nanmean(y1_array, axis=0)
    y1_std = 1 * np.nanstd(y1_array, axis=0)

    y2_mean = np.nanmean(y2_array, axis=0)
    y2_std = 1 * np.nanstd(y2_array, axis=0)

    total_cases_mean = np.nanmean(total_cases, axis=0)
    total_cases_std = 1 * np.nanstd(total_cases, axis=0)

    new_cases_mean = np.nanmean(new_cases, axis=0)
    new_cases_std = 1 * np.nanstd(new_cases, axis=0)

    # SIR Curves
    fig, ax = plt.subplots(figsize=(16, 6))
    # plt.sca(ax[0])
    plt.fill_between(x_updated, y0_mean + y0_std, y0_mean - y0_std, alpha=0.5, color='g')
    plt.plot(x_train, sus_train, c='g', label='susceptible')
    plt.scatter(x_test, sus_test, color='g')
    plt.plot(x_updated, y0_mean, '--g', alpha=0.7)
    #plt.show()

    plt.fill_between(x_updated, y1_mean + y1_std, y1_mean - y1_std, alpha=0.5, color='r')
    plt.plot(x_train, inf_train, c='r', label='infected')
    plt.scatter(x_test, inf_test, color='r')
    plt.plot(x_updated, y1_mean, '--r', alpha=0.7)
    #plt.show()

    plt.fill_between(x_updated, y2_mean + y2_std, y2_mean - y2_std, alpha=0.5, color='b')
    plt.plot(x_train, 1 - sus_train - inf_train, c='b', label='resistant')
    plt.scatter(x_test, 1 - sus_test - inf_test, color='b')
    plt.plot(x_updated, y2_mean, '--b', alpha=0.7)

    plt.xlabel('Days')
    plt.ylim([0, 1.01])
    plt.xlim([x_updated[0], x_updated[-1]])
    plt.title('Infection Rates')
    plt.legend()
    #plt.yscale('log')
    plt.show()

    fig, ax = plt.subplots(1, 2, figsize=(16, 6))
    # Total Cases

    plt.sca(ax[0])
    plt.plot(x_updated, total_cases_mean, '--')
    plt.fill_between(x_updated, total_cases_mean + total_cases_std, total_cases_mean - total_cases_std, alpha=0.5)
    plt.xlabel('Days')
    plt.title('Total Cases')
    plt.ylim([0, 1.01])
    plt.xlim([x_updated[0], x_updated[-1]])

    # New Cases
    plt.sca(ax[1])
    plt.plot(x_updated, new_cases_mean, '--')
    plt.fill_between(x_updated, new_cases_mean + new_cases_std, new_cases_mean - new_cases_std, alpha=0.5)
    plt.xlabel('Days')
    plt.title('Number of New DAILY Cases')
    plt.ylim([0, 1.01*np.max(new_cases_mean+new_cases_std)])
    plt.xlim([x_updated[0], x_updated[-1]])
    plt.show()

    # Parameters
    trace = joblib.load(os.path.join(tr_path, second_dir+'_params.pkl'))
    if not delta_case:
        vars_list = ['R0', 'lambda', 'beta']
    else:
        vars_list = ['R0', 'lambda', 'beta', 'delta']
    fig, ax = plt.subplots(1, len(vars_list), figsize=(16, 6))
    for idx, var in enumerate(vars_list):
        plt.sca(ax[idx])
        # `fill=True` replaces the deprecated `shade=True` (seaborn >= 0.11)
        sns.kdeplot(trace[var], fill=True)
        plt.title('{:s} = {:0.3f}'.format(var, np.mean(trace[var])))
    plt.show()