In [None]:
# run initial imports
from __future__ import print_function
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd

from fbprophet import Prophet
from fbprophet.plot import plot_plotly

import matplotlib.pyplot as plt
import plotly.offline as py
import plotly.express as px
import plotly.graph_objs as go

from scipy.integrate import odeint
from scipy.optimize import curve_fit

from sklearn.metrics import r2_score
from sklearn.linear_model import Ridge
from sklearn.preprocessing import PolynomialFeatures
from sklearn.pipeline import make_pipeline

from ipywidgets import interact, interactive, fixed, interact_manual
from ipywidgets import widgets

from datetime import datetime
import os

In [None]:
# Read the clock once so the display stamp and the file-name stamp can never
# disagree (two independent now() calls could straddle a minute boundary).
_RUN_STARTED = datetime.now()
NOW = _RUN_STARTED.strftime("%Y/%m/%d %H:%M")       # human-readable; used in chart titles
NOW_FILE = _RUN_STARTED.strftime("%Y_%m_%d__%H%M")  # filesystem-safe; used in image paths
# Number of forecast days. Should not exceed 5 unless you are very careful in
# tuning and understand how to interpret the results.
pDay = 3
# Root folder of the CSV extracts. Both global observation tables are read
# relative to it, in the order the target names are listed.
input_dir = './data/'
covid_19_national_observations, covid_19_infected_observations = [
    pd.read_csv(input_dir + relative_path)
    for relative_path in (
        'global/covid_19_national_observations.csv',
        'global/covid_19_infected_observations.csv',
    )
]

In [None]:
# Sorted array of the distinct 'Country/Region' values in the national
# observations table; the widget cells use it as their list of choices.
countries = np.sort(covid_19_national_observations['Country/Region'].unique())

In [None]:
def cleanStr(s):
    """Turn a display name into the slug used in data and image paths.

    The name is lower-cased, spaces and parentheses become underscores,
    asterisks are dropped, and any remaining leading/trailing whitespace
    (tabs, newlines, ...) is stripped.
    """
    # Single-pass character mapping; None deletes the character.
    slug_table = str.maketrans({' ': '_', '(': '_', ')': '_', '*': None})
    return s.lower().translate(slug_table).strip()

def getDf(country):
    """Load the per-country case and totals tables.

    Both CSV files are read from ``./data/countries/<slug>/`` where ``<slug>``
    is ``cleanStr(country)``. Their 'Date' columns are parsed to datetimes and
    each table is sorted by date.

    Returns:
        (cases, totals): two DataFrames; ``cases`` keeps only rows whose
        'Confirmed' value is greater than zero.
    """
    c = cleanStr(country)
    input_dir = './data/countries/' + c + '/'

    def _load_by_date(file_name):
        # Parse dates explicitly: some of the JH data does not parse correctly otherwise.
        frame = pd.read_csv(input_dir + file_name)
        frame['Date'] = pd.to_datetime(frame['Date'])
        return frame.sort_values('Date')

    cases = _load_by_date('covid_19_'+ c +'_cases.csv')
    totals = _load_by_date('./covid_19_' + c + '_totals.csv')

    has_confirmed = cases['Confirmed'] > 0
    return cases[has_confirmed], totals

def analyze_country(country):
    """Show the descriptive bar charts for one country.

    Loads the country's totals table via ``getDf`` and displays four Plotly
    bar charts:

    1. Active / Recovered / Death counts by date (also saved as a PNG
       under ``images/``).
    2. Active cases by date, coloured by 'Death Rate'.
    3. 'New Case PCT Change' by date from 2020-02-20 onwards.
    4. 'New Cases' by date from 2020-02-20 onwards.

    Args:
        country: Country/region name; it is used verbatim in chart titles and
            slugified with ``cleanStr`` for the data and image paths.
    """
    _, totals = getDf(country)

    # Long-format frame (one row per date and status) so px.bar can stack by status.
    totals_state = totals[['Date', 'Active Cases', 'Recovered', 'Death']].melt(
        id_vars=['Date'],
        value_vars=['Active Cases', 'Death', 'Recovered'],
        value_name="Population",
        var_name='Status',
    )
    totals_chart = px.bar(
        totals_state,
        x="Date", y="Population", color="Status",
        title=country+" Active, Recovered and Deaths by Date as of " + str(NOW),
    )
    totals_chart.show()

    # write_image does not create missing folders, so make sure the target exists.
    os.makedirs("images", exist_ok=True)
    # NOTE(review): the file name has no separator before "active" and spells
    # "recoverd"; left unchanged in case other tooling relies on the name.
    totals_chart.write_image("images/" + NOW_FILE + "__" + cleanStr(country) + "active_recoverd_death.png")

    # Active cases only, skipping dates on which there were none.
    active_chart = px.bar(
        totals[totals['Active Cases'] > 0],
        x="Date", y="Active Cases", color="Death Rate",
        title=country+" Active Cases (confirmed - recovered - dead) - Current as of: "+str(NOW),
    )
    active_chart.show()

    # The two change charts share the same date window.
    recent_totals = totals[totals['Date'] >= '2020-02-20']

    pct_change_chart = px.bar(
        recent_totals,
        x="Date", y="New Case PCT Change", color="Death Rate",
        title=country+" New Active Case Change Rate (confirmed - recovered - dead) - Current as of: "+str(NOW),
    )
    pct_change_chart.show()

    new_cases_chart = px.bar(
        recent_totals,
        x="Date", y="New Cases", color="Death Rate",
        title=country+" New Cases - Current as of: "+str(NOW),
    )
    new_cases_chart.show()

def sim_country(country, pop_factor=100, b=0.3, g=1/36, t=90):
    """Integrate a simple SIR model for ``country`` and plot the curves.

    Initial conditions come from the country's data tables (see ``getDf``);
    the ODE system is solved with ``scipy.integrate.odeint`` and the
    susceptible / infected / dead / recovered curves are drawn with
    Matplotlib, in millions of people.

    Args:
        country: Country/region name passed to ``getDf`` and shown in the title.
        pop_factor: Multiplier applied to the maximum 'Confirmed' count to
            obtain the modelled population size N.
        b: Contact (infection) rate, beta in the SIR equations.
        g: Recovery rate, gamma in the SIR equations.
        t: Simulation horizon in days; also the number of sample points, so it
            must be an integer.
    """
    cases, totals = getDf(country)

    per_million = 1000000

    # Modelled population and initial compartments.
    N = cases['Confirmed'].max() * pop_factor
    I0 = totals['Likely Cases 1.8pct'].max()
    # NOTE(review): the 120x multiplier on recovered cases is undocumented - confirm its origin.
    R0 = totals['Recovered'].max()*120
    S0 = N - I0 - R0
    beta, gamma = b, g

    # t samples spread over [0, t]; kept separate from the scalar parameter `t`.
    days = np.linspace(0, t, t)

    def deriv(y, _time, N, beta, gamma):
        """Right-hand side of the SIR equations (time-independent)."""
        S, I, R = y
        dSdt = -beta * S * I / N
        dIdt = beta * S * I / N - gamma * I
        dRdt = gamma * I
        return dSdt, dIdt, dRdt

    S, I, R = odeint(deriv, (S0, I0, R0), days, args=(N, beta, gamma)).T

    # Deaths are derived from the infected curve via the mean observed
    # 'Death Rate', then removed from the recovered curve (both in millions).
    D = (I/per_million)*totals['Death Rate'].mean()
    R = (R/per_million) - D

    fig = plt.figure(figsize=(24,7), facecolor='w')
    ax = fig.add_subplot(111, axisbelow=True)
    ax.plot(days, S/per_million, 'b', alpha=0.5, lw=2, label='Susceptible')
    ax.plot(days, I/per_million, 'r', alpha=0.5, lw=2, label='Infected')
    ax.plot(days, D, 'y', alpha=0.5, lw=2, label='Dead')
    ax.plot(days, R, 'g', alpha=0.5, lw=2, label='Recovered with immunity')
    ax.set_xlabel('Time /days')
    # The infected figure is the peak of the I curve (it must not be read from
    # S, which is the susceptible compartment).
    ax.set_title(country + ' - SIR+d model based on known disease characteristics and response at: '+str(NOW) +
                '\nTotal Infected: '+'{:,.0f}'.format(int(np.max(I)))+
                '\nTotal Recovered: '+'{:,.0f}'.format(int(np.max(R)*per_million))+
                '\nTotal Dead: '+'{:,.0f}'.format(int(np.max(D)*per_million)),
                loc='left'
                )
    ax.yaxis.set_tick_params(length=0)
    ax.xaxis.set_tick_params(length=0)
    # Pass the on/off flag positionally: its keyword was renamed from `b` to
    # `visible` in newer Matplotlib, and the positional form works with both.
    ax.grid(True, which='major', c='w', lw=2, ls='-')
    legend = ax.legend()
    legend.get_frame().set_alpha(0.5)
    for spine in ('top', 'right', 'bottom', 'left'):
        ax.spines[spine].set_visible(False)
    plt.show()

def pred_country(country, t=90, infectivity_factor=180, gMethod='linear', disp=True):
    """Forecast a country's active cases with Prophet and save the chart.

    Fits a Prophet model to the 'Active Cases' series of the country's totals
    table, predicts ``t`` further periods, and writes the Plotly figure to
    ``./images/<slug>/<NOW_FILE>/<NOW_FILE>__<slug>_forecast_<t>_day.png``.

    Args:
        country: Country/region name (used in the title; slugified for paths).
        t: Number of future periods to forecast.
        infectivity_factor: Currently unused; kept so existing callers keep
            working.
        gMethod: Prophet ``growth`` setting. NOTE(review): Prophet's logistic
            growth presumably needs a 'cap' column, which this function does
            not provide - confirm before using anything but 'linear'.
        disp: When true, also render the figure inline in the notebook.
    """
    _, totals = getDf(country)

    # Prophet expects the columns 'ds' (date) and 'y' (value).
    fb_df = (
        totals[['Date', 'Active Cases']]
        .sort_values('Date')
        .reset_index(drop=True)
        .rename(columns={'Date': 'ds', 'Active Cases': 'y'})
    )
    fb_df['floor'] = 0

    m = Prophet(yearly_seasonality=True, weekly_seasonality=True, daily_seasonality=False, growth=gMethod)
    m.fit(fb_df)
    future = m.make_future_dataframe(periods=t)
    future['floor'] = 0
    forecast = m.predict(future)

    fig = plot_plotly(m, forecast, xlabel='Date', ylabel='Active Cases', uncertainty=True, figsize=(1100,600))
    fig.update_layout(title='Active '+country+' COVID-19 Cases and Forecast ('+str(t)+' day) as of ' + str(NOW))

    c = cleanStr(country)
    out_dir = "./images/" + c + "/" + NOW_FILE
    os.makedirs(out_dir, exist_ok=True)
    fig.write_image(out_dir + "/" + NOW_FILE + "__" + c + "_forecast_" + str(t) + "_day.png")

    if disp:
        # Notebook mode is only needed for inline rendering, so batch runs
        # (disp=False) skip it and emit no notebook output.
        py.init_notebook_mode()
        py.iplot(fig)

def pred_province(country, province, t=90, infectivity_factor=180, gMethod='linear', disp=True):
    """Forecast confirmed cases for one province/state and save the chart.

    Fits a Prophet model to the 'Confirmed' series of the selected province,
    predicts ``t`` further periods, and writes the Plotly figure to
    ``./images/<country>/<NOW_FILE>/<NOW_FILE>__<country>_<province>_forecast_<t>_day.png``
    (both names slugified with ``cleanStr``).

    Args:
        country: Country/region name (used in the title; slugified for paths).
        province: Province/state name. Rows are filtered on it only when the
            country has more than one distinct 'Province/State' value. A
            non-string value (e.g. None or NaN) is treated as "no province
            name": it is left out of the title and contributes an empty
            segment to the file name.
        t: Number of future periods to forecast.
        infectivity_factor: Currently unused; kept so existing callers keep
            working.
        gMethod: Prophet ``growth`` setting. NOTE(review): Prophet's logistic
            growth presumably needs a 'cap' column, which this function does
            not provide - confirm before using anything but 'linear'.
        disp: When true, also render the figure inline in the notebook.

    Raises:
        ValueError: If no rows remain for the requested province.
    """
    cases, _ = getDf(country)
    if len(cases['Province/State'].unique()) > 1:
        cases = cases[cases['Province/State'] == province]
    if cases.empty:
        raise ValueError("No confirmed-case rows for province " + repr(province) + " in " + country)

    # Prophet expects the columns 'ds' (date) and 'y' (value).
    fb_df = (
        cases[['Date', 'Confirmed']]
        .sort_values('Date')
        .reset_index(drop=True)
        .rename(columns={'Date': 'ds', 'Confirmed': 'y'})
    )
    fb_df['floor'] = 0

    m = Prophet(yearly_seasonality=True, weekly_seasonality=True, daily_seasonality=False, growth=gMethod)
    m.fit(fb_df)
    future = m.make_future_dataframe(periods=t)
    future['floor'] = 0
    forecast = m.predict(future)

    # A missing province name must not break the string building below.
    has_name = isinstance(province, str) and province != ''
    place = province + ', ' + country if has_name else country

    fig = plot_plotly(m, forecast, xlabel='Date', ylabel='Confirmed Cases', uncertainty=True, figsize=(1100,600))
    fig.update_layout(title=place+' Confirmed COVID-19 Cases and Forecast ('+str(t)+' day) as of ' + str(NOW))

    c = cleanStr(country)
    p = cleanStr(province) if has_name else ''
    out_dir = "./images/" + c + "/" + NOW_FILE
    os.makedirs(out_dir, exist_ok=True)
    fig.write_image(out_dir + "/" + NOW_FILE + "__" + c + "_" + p + "_forecast_" + str(t) + "_day.png")

    if disp:
        # Notebook mode is only needed for inline rendering, so batch runs
        # (disp=False) skip it and emit no notebook output.
        py.init_notebook_mode()
        py.iplot(fig)

In [None]:
def f(Country):
    """Widget callback: draw the descriptive charts for the chosen country."""
    analyze_country(Country)


# Country picker wired to f(); the container's last child is the area the
# charts are rendered into, which is given a fixed size here.
interactive_plot = interactive(f, Country=countries)
output = interactive_plot.children[-1]
output.layout.height, output.layout.width = '800px', '1200px'
interactive_plot

In [None]:
def p(Country, Days):
    """Widget callback: forecast active cases for the chosen country and horizon."""
    pred_country(Country, t=Days, infectivity_factor=180, gMethod='linear')


# Country picker plus a 3-14 day horizon slider (default 5) wired to p();
# the container's last child is the render area and gets a fixed size.
interactive_plot = interactive(
    p,
    Country=countries,
    Days=widgets.IntSlider(value=5, min=3, max=14, step=1),
)
output = interactive_plot.children[-1]
output.layout.height, output.layout.width = '600px', '1200px'
interactive_plot

In [None]:
countryW = widgets.Dropdown(options = countries)
stateW = widgets.Dropdown()
daysW = widgets.IntSlider(min=3, max=14, step=1, value=5)

@interact(Country=countryW, State=stateW, Days=daysW)
def updateState(Country, State, Days):
    """Widget callback: keep the State dropdown in sync and forecast the selection.

    The State dropdown is repopulated with the provinces of the selected
    country, and a province-level forecast is drawn for the chosen State.

    Args:
        Country: Selected country/region name.
        State: Selected province/state; None while the dropdown has no options.
        Days: Forecast horizon in days.
    """
    cases, _ = getDf(Country)
    # Missing province names cannot be sorted with, or shown next to, strings.
    provinces = sorted(cases['Province/State'].dropna().unique())

    # Only touch the dropdown when its content really changes, so a province
    # the user picked is not replaced by re-assigning identical options.
    if list(stateW.options) != provinces:
        stateW.options = provinces

    if not provinces:
        print("No Province/State breakdown available for " + Country)
        return

    # NOTE(review): assigning new options makes the dropdown select a value,
    # and that value change is expected to re-invoke this callback with the
    # new State (ipywidgets behaviour - confirm on upgrade). This stale call
    # (State is None or belongs to the previous country) therefore stops here.
    if State != stateW.value:
        return

    pred_province(Country, State, Days, infectivity_factor=180, gMethod='linear')

In [None]:
def s(Country, Susceptible_Population_Factor, Infectivity, Infective_Days, Days):
    """Widget callback: run the SIR simulation with the slider settings.

    The recovery rate handed to the model is the reciprocal of Infective_Days.
    """
    recovery_rate = 1/Infective_Days
    sim_country(Country, Susceptible_Population_Factor, Infectivity, recovery_rate, Days)


# One control per simulation parameter, wired to s(); the container's last
# child is the render area and gets a fixed size.
interactive_plot = interactive(
    s,
    Country=countries,
    Susceptible_Population_Factor=widgets.IntSlider(value=200, min=10, max=1000000, step=1),
    Infectivity=widgets.FloatSlider(value=0.201, min=0.001, max=4.000, step=0.0001),
    Infective_Days=widgets.IntSlider(value=36, min=2, max=48, step=1),
    Days=widgets.IntSlider(value=180, min=30, max=365*4, step=1),
)
output = interactive_plot.children[-1]
output.layout.height, output.layout.width = '600px', '1200px'
interactive_plot

In [None]:
# The batch horizon comes from the pDay config constant, so the button label
# and the generated forecasts always agree.
button = widgets.Button(description="Generate all " + str(pDay) + " Day Province/State")
output = widgets.Output()

display(button, output)

def _try_forecast(label, forecast, *args, **kwargs):
    """Run one forecast call; report a failure instead of propagating it."""
    try:
        forecast(*args, **kwargs)
    except Exception as err:
        # One bad series must not abort the remaining ones; the error is
        # printed so it stays visible in the output area.
        print(label + " forecast failed: " + repr(err))

def on_button_clicked(b):
    """Button handler: generate and save pDay-day forecasts for everything.

    For every country in ``countries`` a country-level forecast is produced,
    followed by one forecast per named province/state. Figures are saved by
    the forecast functions and not displayed; progress and any failures are
    printed into the ``output`` widget.

    Args:
        b: The clicked button (unused).
    """
    with output:
        for country in countries:
            print("Generating forecasts for: " + country)
            _try_forecast("\tcountry", pred_country,
                          country, pDay, infectivity_factor=180, gMethod='linear', disp=False)
            try:
                cases, _ = getDf(country)
            except Exception as err:
                print("\tcould not load province data: " + repr(err))
                continue
            # Missing province names cannot be sorted with, or concatenated to, strings.
            provinces = sorted(cases['Province/State'].dropna().unique())
            for st in provinces:
                print("\t" + str(st))
                _try_forecast("\t\tprovince", pred_province,
                              country, st, pDay, infectivity_factor=180, gMethod='linear', disp=False)


button.on_click(on_button_clicked)