# Dynamic SIR model for multiple countries

# Prepare the data

In [None]:
## Check the working directory
import os

# Split the current working directory into (parent, leaf name).
_, current_dir_name = os.path.split(os.getcwd())

# NOTE(review): the tail returned by os.path.split never contains a path
# separator, so this comparison with '/notebooks' looks like it can never be
# true and the chdir is effectively dead code. Confirm whether 'notebooks' was
# intended -- but note that the '../data/...' paths used further down rely on
# the working directory NOT being changed here.
if current_dir_name == '/notebooks':
    os.chdir("../")

# Last expression of the cell: displays the leaf name of the working directory.
'Your base path is at: ' + os.path.split(os.getcwd())[-1]

In [None]:
# import pandas as pd
import os
import numpy as np
from datetime import datetime
import pandas as pd
from scipy import optimize
from scipy import integrate
%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
# Use seaborn's "darkgrid" style for the plots created in this notebook.
sns.set(style="darkgrid")

# Default matplotlib figure size (width, height) = (16, 9).
mpl.rcParams['figure.figsize'] = (16, 9)
# Limit pandas to 10 rows when a DataFrame is displayed.
pd.set_option('display.max_rows', 10)

def store_flat_table_JH_data(
        data_path='../data/raw/COVID-19/csse_covid_19_data/csse_covid_19_time_series/time_series_covid19_confirmed_global.csv',
        out_path='../data/processed/COVID_small_flat_table.csv',
        country_list=None):
    """Build a flat per-country table from the Johns Hopkins time-series CSV.

    The raw file has one row per province/country and one column per day
    (starting at the fifth column). For every requested country the rows are
    summed per day, so countries that are split into several provinces end up
    as a single series. The result (one ``date`` column plus one column per
    country) is written as a ';'-separated CSV.

    Parameters
    ----------
    data_path : str
        Location of the raw "confirmed global" time-series CSV.
    out_path : str
        Location the flat table is written to.
    country_list : iterable of str, optional
        Values of the ``Country/Region`` column to extract. Defaults to
        Germany, India, US, Austria and France.

    Returns
    -------
    None
        The table is only written to ``out_path``.
    """
    if country_list is None:
        country_list = ('Germany', 'India', 'US', 'Austria', 'France')

    pd_raw = pd.read_csv(data_path)

    # The first four columns are metadata (province, country, lat, long);
    # every remaining column header is a date in month/day/2-digit-year form.
    df_plot = pd.DataFrame({'date': pd_raw.columns[4:]})

    for each in country_list:
        country_rows = pd_raw[pd_raw['Country/Region'] == each]
        # Sum over all rows of the country to merge its provinces.
        df_plot[each] = np.array(country_rows.iloc[:, 4:].sum(axis=0))

    # Parse the header strings so the CSV contains ISO dates (YYYY-MM-DD).
    df_plot['date'] = pd.to_datetime(df_plot['date'], format='%m/%d/%y')
    df_plot.to_csv(out_path, sep=';', index=False)


if __name__ == '__main__':
    store_flat_table_JH_data()

# Process the data

In [None]:
def _sir_rhs(SIR, t, beta, gamma, N0):
    """Return (dS/dt, dI/dt, dR/dt) of the SIR model for population size N0."""
    S, I, R = SIR
    dS_dt = -beta*S*I/N0
    dI_dt = beta*S*I/N0-gamma*I
    dR_dt = gamma*I
    return dS_dt, dI_dt, dR_dt


def _fit_sir(ydata, N0, maxfev=1000):
    """Fit beta and gamma of the SIR model to one infection curve.

    Returns the optimal parameters, their covariance matrix and the infected
    curve simulated with the optimal parameters.
    """
    t = np.arange(len(ydata))
    I0 = ydata[0]
    initial_state = (N0 - I0, I0, 0)   # (S0, I0, R0)

    def fit_odeint(x, beta, gamma):
        # Integrate on the time grid handed in by curve_fit and keep only the
        # infected compartment, which is what gets compared with the data.
        return integrate.odeint(_sir_rhs, initial_state, x,
                                args=(beta, gamma, N0))[:, 1]

    popt, pcov = optimize.curve_fit(fit_odeint, t, ydata, maxfev=maxfev)
    return popt, pcov, fit_odeint(t, *popt)


def main(in_path='../data/processed/COVID_small_flat_table.csv',
         out_path='../data/processed/COVID_SIR_final_set.csv',
         N0=10000000, first_row=35, last_row=140):
    """Fit an SIR model per country and store data plus fitted curves.

    Reads the ';'-separated flat table (first column ``date``, one column per
    country), keeps the rows ``first_row:last_row`` after sorting by date,
    fits beta and gamma for every country column and adds the simulated
    infected curve as ``<country>_fit``. The curve is computed from the
    *optimised* parameters (earlier it was computed from a fixed initial
    guess, so the stored curve did not reflect the fit).

    Parameters
    ----------
    in_path : str
        Flat table to analyse.
    out_path : str
        Where the table including the ``*_fit`` columns is written.
    N0 : int or float
        Population size assumed for every country.
    first_row, last_row : int
        Row window (slice bounds) of the date-sorted table used for fitting.

    Raises
    ------
    ValueError
        If the selected row window contains no data.
    """
    df_analyse = pd.read_csv(in_path, sep=';')
    df_analyse = df_analyse.sort_values('date', ascending=True)
    df_analyse = df_analyse[first_row:last_row].reset_index(drop=True)
    if df_analyse.empty:
        raise ValueError('no rows left to fit in the selected row window')

    # Snapshot of the country columns; the *_fit columns added inside the loop
    # are not part of it.
    country_list = df_analyse.columns[1:]
    for each in country_list:
        ydata = np.array(df_analyse[each])
        popt2, pcov, fitted = _fit_sir(ydata, N0)
        perr = np.sqrt(np.diag(pcov))
        print("For {}:".format(each))
        print('standard deviation errors : ', str(perr), ' start infect:', ydata[0])
        print("Optimal parameters: beta =", popt2[0], " and gamma = ", popt2[1])
        df_analyse[each+"_fit"] = fitted

    # Keep the running row number as an explicit 'index' column in the output.
    df_analyse = df_analyse.reset_index()
    print(df_analyse.tail())
    df_analyse.to_csv(out_path, sep=';', index=False)


if __name__ == '__main__':
    main()


In [None]:
# Reload the ';'-separated table that main() wrote, to inspect the result.
df_analyse=pd.read_csv('../data/processed/COVID_SIR_final_set.csv',sep=';')
# Last expression of the cell: shows the final rows of the table.
df_analyse.tail()

# Visualize the result

In [None]:
import pandas as pd
import numpy as np

import dash
dash.__version__
import dash_core_components as dcc
import dash_html_components as html
from dash.dependencies import Input, Output,State

import plotly.graph_objects as go
import random
import os

# Pool of 200 random trace colours as '#rrggbb' hex strings: one two-digit hex
# value per red/green/blue channel. (The previous format wrote five
# three-digit groups, which is not a valid colour string.)
color_list = [
    '#%02x%02x%02x' % (random.randint(0, 255),
                       random.randint(0, 255),
                       random.randint(0, 255))
    for _ in range(200)
]

# Table with the raw curves and the '<country>_fit' curves, as written by main().
df_analyse=pd.read_csv('../data/processed/COVID_SIR_final_set.csv',sep=';')

# Selectable countries: every column that has a matching '<name>_fit' column.
# This replaces a positional slice that silently depended on the exact number
# and order of columns in the file.
country_options = [col for col in df_analyse.columns
                   if col + '_fit' in df_analyse.columns]

# Empty placeholder figure; the callback replaces it on the first update.
fig = go.Figure()

app = dash.Dash()
app.layout = html.Div([

    dcc.Markdown('''
    #  Applied Data Science on COVID-19 data

    Goal of the project is to learn data science by applying a cross industry standard process,
    it covers the full walkthrough of: automated data gathering, data transformations,
    filtering and machine learning to approximating the doubling time, and
    (static) deployment of responsive dashboard.

    '''),

    dcc.Markdown('''
    ## Multi-Select Country for SIR visualization
    '''),


    dcc.Dropdown(
        id='country_drop_down',
        options=[ {'label': each,'value':each} for each in country_options],
        value=['India'], # which are pre-selected
        multi=True
    ),

    dcc.Graph(figure=fig, id='main_window_slope')
])



@app.callback(
    Output('main_window_slope', 'figure'),
    [Input('country_drop_down', 'value')])
def update_figure(country_list):
    """Rebuild the figure for the countries selected in the dropdown.

    For every selected country two traces are produced in the same colour:
    the data column as markers and the '<country>_fit' column as a line.

    Parameters
    ----------
    country_list : list of str or None
        Current value of the 'country_drop_down' component. ``None`` or an
        empty list yields a figure without traces.

    Returns
    -------
    dict
        Figure description with the keys 'data' (list of trace dicts) and
        'layout'.
    """
    my_yaxis={'type':"log",
              'title':'New Population Infected'
          }

    traces = []
    # `country_list or []` guards against a None value from the dropdown.
    for v, each in enumerate(country_list or []):
        colour = color_list[v]

        # Scatter traces take their point styling from 'marker' (singular);
        # the former 'markers' key is not a trace property, so the colour was
        # never applied to the data points.
        traces.append(dict(x=df_analyse.index,
                           y=df_analyse[each],
                           mode='markers', marker=dict(color=colour),
                           opacity=0.9,
                           name=each))
        traces.append(dict(x=df_analyse.index,
                           y=df_analyse[each+"_fit"],
                           mode='lines', line=dict(color=colour),
                           opacity=1.0,
                           name=each+'_Fit'))

    return {
            'data': traces,
            'layout': dict (
                width=1280,
                height=720,

                xaxis={'title':'Timeline',
                        'tickangle':-45,
                        'nticks':20,
                        'tickfont':dict(size=14,color="#444"),
                      },

                yaxis=my_yaxis
        )
    }

# Address and port come from the IP / PORT environment variables when set,
# otherwise the server listens on all interfaces at port 4444.
server_host = os.getenv('IP', '0.0.0.0')
server_port = int(os.getenv('PORT', 4444))

# The reloader is switched off while debug mode stays on.
app.run_server(host=server_host, port=server_port,
               debug=True, use_reloader=False)