In [1]:
import dash
import dash_core_components as dcc
import dash_html_components as html
import pandas as pd
import numpy as np
import json
from dash.dependencies import Input, Output
from plotly import graph_objs as go
from plotly.graph_objs import *
from datetime import datetime as dt
import requests
import plotly.express as px
from scipy.integrate import odeint
from scipy.optimize import minimize,curve_fit
import os
from flask import send_from_directory
global glob_data
import os

from datetime import datetime, timedelta
from copy import deepcopy

# If the current working directory is a folder named 'notebooks', move one
# level up so that the relative paths used later resolve from its parent.
if os.path.split(os.getcwd())[-1]=='notebooks':
    os.chdir("../")

# Bare expression: the notebook echoes the name of the resulting working directory.
'Base path is at: '+os.path.split(os.getcwd())[-1]

'Base path is at: ads_covid-19'

In [None]:
# Create the Dash application; the layout and callbacks below are attached to it.
app = dash.Dash(__name__)
# Expose the underlying Flask server object under a module-level name
# (presumably for a WSGI server to import - TODO confirm how this is deployed).
server = app.server


#to get the data for the choropleth map
def get_data(timeout=10):
    """Download the per-country COVID-19 snapshot used by the choropleth map.

    Parameters
    ----------
    timeout : float, optional
        Seconds to wait for the remote API before giving up (default 10).
        Without a timeout ``requests.get`` can block forever and hang the
        whole application at start-up.

    Returns
    -------
    pandas.DataFrame
        One row per country with the columns ``Code`` (the ``iso3`` value
        reported by the API), ``Country``, ``Confirmed``, ``Recovered``,
        ``Active``, ``Deaths``, ``Critical`` and ``Population``, sorted by
        ``Confirmed`` in descending order.

    Raises
    ------
    requests.RequestException
        If the request times out, fails, or returns an HTTP error status.
    """
    url="https://corona.lmao.ninja/v2/countries?yesterday&sort"
    response = requests.get(url, timeout=timeout)
    # Fail loudly on 4xx/5xx instead of trying to parse an error page as data.
    response.raise_for_status()

    rows = [
        [
            item['countryInfo']['iso3'],
            item['country'],
            item['cases'],
            item['recovered'],
            item['active'],
            item['deaths'],
            item['critical'],
            item['population'],
        ]
        for item in response.json()
    ]
    columns = ['Code','Country', 'Confirmed', 'Recovered', 'Active', 'Deaths','Critical', 'Population']
    return pd.DataFrame(rows, columns=columns).sort_values(by='Confirmed', ascending=False)

#choropleth figure
def world_status(df):
    """Build the world choropleth figure from the per-country snapshot.

    ``df`` must provide the columns ``Code``, ``Country``, ``Confirmed``,
    ``Active`` and ``Deaths``. Countries are coloured by ``Confirmed``; the
    hover box shows the confirmed, active and death counts.

    Returns a 1024x720 ``plotly.graph_objs.Figure``.
    """
    # Same template body for every country; the country name is prepended per row.
    hover_body = "<extra>Confirmed : %{z}<br>Active : %{text} <br>Deaths : %{hovertext}</extra>"

    country_layer = go.Choropleth(
        locations=df['Code'],
        z=df['Confirmed'],
        text=df.Active,
        colorscale='Oranges',
        autocolorscale=False,
        marker_line_color='darkgray',
        marker_line_width=1.5,
        colorbar_title='Affected',
        hovertext=df.Deaths,
        hovertemplate=df.Country + hover_body,
    )

    world_map = go.Figure(data=country_layer)
    world_map.update_layout(width=1024, height=720)
    return world_map


# Fetch the cumulative time series of a single country from api.covid19api.com

def get_country_data(country, timeout=10):
    """Fetch one country's cumulative history and derive its daily increments.

    Parameters
    ----------
    country : str
        Country identifier inserted verbatim into the API URL.
    timeout : float, optional
        Seconds to wait for the remote API before giving up (default 10).

    Returns
    -------
    pandas.DataFrame
        Columns ``Date`` (first 10 characters of the API's ``Date`` field),
        ``Total_confirmed``, ``Total_recovered``, ``Total_deaths``,
        ``Daily_confirmed``, ``Daily_recovered`` and ``Daily_deaths``.

    Raises
    ------
    requests.RequestException
        If the request times out, fails, or returns an HTTP error status.
    """
    url=f"https://api.covid19api.com/total/country/{country}"
    response = requests.get(url, timeout=timeout)
    # Fail loudly on 4xx/5xx instead of trying to parse an error body as data.
    response.raise_for_status()

    till_date_data = [
        [each['Date'][:10], each['Confirmed'], each['Recovered'], each['Active'], each['Deaths']]
        for each in response.json()
    ]
    country_data = pd.DataFrame(till_date_data, columns=['Date', 'Confirmed', 'Recovered', 'Active', 'Deaths'])

    # Explicit copy: the repair below writes into this frame, which must not be
    # a view of (or raise chained-assignment warnings about) ``country_data``.
    data = country_data[['Confirmed', 'Recovered', 'Deaths']].copy()

    # Day-over-day change of the totals as delivered by the API.
    unrepaired_data = data - data.shift(1)

    # A cumulative total that drops from one day to the next is treated as a
    # reporting glitch: the value is replaced by the previous day's value.
    # Row 0 can never be selected because its difference is NaN.
    for column in ('Deaths', 'Confirmed', 'Recovered'):
        for row in unrepaired_data.index[unrepaired_data[column] < 0]:
            data.at[row, column] = data.at[row - 1, column]

    # Daily increments of the repaired totals; the first day and any remaining
    # negative step are reported as 0.
    daily_data = (data - data.shift(1)).fillna(0)
    daily_data = daily_data.mask(daily_data < 0, 0)

    new_data = pd.concat([country_data[['Date']], data, daily_data], axis=1, sort=False)
    new_data.columns = ['Date', 'Total_confirmed', 'Total_recovered', 'Total_deaths', 'Daily_confirmed','Daily_recovered', 'Daily_deaths']

    return new_data

# Convert a country code (as used by the map API) into the country name expected by the per-country API
def collected_data(data, country_code = 'DEU'):
    """Translate a country code into the country name used for the per-country API.

    Parameters
    ----------
    data : pandas.DataFrame
        Table with at least the columns ``Code`` and ``Country``.
    country_code : str, optional
        Code to look up in ``data['Code']`` (default ``'DEU'``).

    Returns
    -------
    str
        * ``'KOR'`` and ``'USA'`` are special-cased to ``'KOR'`` and
          ``'United States'`` respectively.
        * Otherwise the ``Country`` value of the first row whose ``Code``
          equals ``country_code``.
        * If no row matches (or ``data`` is empty) the code itself is
          returned unchanged, so the caller never silently receives the name
          of an unrelated country.
    """
    if country_code == 'KOR':
        return 'KOR'

    if country_code == 'USA':
        return 'United States'

    matches = data.loc[data['Code'] == country_code, 'Country']
    if matches.empty:
        # Unknown code: hand it back as-is (same convention as 'KOR' above)
        # rather than falling through to an arbitrary row.
        return country_code

    return matches.iloc[0]
        
#to fetch the total world stats
def total_status(timeout=10):
    """Fetch the world-wide totals and format them for display.

    Parameters
    ----------
    timeout : float, optional
        Seconds to wait for the remote API before giving up (default 10).

    Returns
    -------
    tuple of str
        ``(confirmed, recovered, active, deaths)``, each formatted with
        thousands separators. ``active`` is computed as
        ``TotalConfirmed - TotalDeaths - TotalRecovered``.

    Raises
    ------
    requests.RequestException
        If the request times out, fails, or returns an HTTP error status.
    """
    url = 'https://api.covid19api.com/world/total'
    response = requests.get(url, timeout=timeout)
    # Fail loudly on 4xx/5xx instead of trying to parse an error body as data.
    response.raise_for_status()
    total_data = response.json()

    confirmed_count = total_data["TotalConfirmed"]
    deaths_count = total_data['TotalDeaths']
    recovered_count = total_data['TotalRecovered']
    active_count = confirmed_count - deaths_count - recovered_count

    return (
        f"{confirmed_count:,}",
        f"{recovered_count:,}",
        f"{active_count:,}",
        f"{deaths_count:,}",
    )

# World snapshot, fetched once at start-up; rows with any missing value are dropped.
glob_data = get_data().dropna()

# Countries offered in the two dropdowns: at most the 187 rows with the
# highest confirmed counts (both lists hold the same selection).
comparision_countries_list = glob_data.sort_values('Confirmed', ascending=False).iloc[:187]
sir_simulation_countries_list = comparision_countries_list.iloc[:187]

# Pre-formatted world totals shown in the header cards of the layout.
confirmed, recovered, active, deaths = total_status()


#app layout
def _text_style(font_size, color=None, background=None):
    """Return the centred, upper-case text style shared by titles, labels and values.

    Keys are inserted in the same order as the hand-written dictionaries this
    replaces; ``backgroundColor`` and ``color`` are only present when given.
    """
    style = {'textAlign': 'center'}
    if background is not None:
        style['backgroundColor'] = background
    if color is not None:
        style['color'] = color
    style.update({
        'font-size': font_size,
        'textTransform': 'uppercase',
        'lineHeight': '40px',
        'fontFamily': 'roboto condensed,sans-serif',
        'display': 'block',
    })
    return style


def _card_style(**leading):
    """Return the rounded 'card' style; ``leading`` entries are placed before the shared keys."""
    return {
        **leading,
        'borderRadius': '5px',
        'backgroundColor': '#f9f9f9',
        'margin': '10px',
        'padding': '15px',
        'position': 'relative',
        'boxShadow': '2px 2px 2px lightgrey',
    }


def _section_title(text):
    """Return a full-width grey banner used as a section heading."""
    return html.Div(text, style=_text_style('23px', color='#777', background='#f5f5f5'))


def _summary_cell(label, value, value_color=None):
    """Return one world-total card: a small label above the value as an ``H2``."""
    if value_color is None:
        heading = html.H2(value)
    else:
        heading = html.H2(value, style={'color': value_color})
    return html.Td(
        [html.Div(label, style=_text_style('19px', color='#777')), heading],
        style=_card_style(fontFamily='Arial, Helvetica, sans-serif'),
    )


def _graph_cell(graph):
    """Wrap a graph component in a card-styled table cell."""
    return html.Td(graph, style=_card_style())


def _graph_row(*graph_ids):
    """Return a full-width one-row table holding one empty graph per id."""
    cells = [_graph_cell(dcc.Graph(id=graph_id)) for graph_id in graph_ids]
    return html.Table(html.Tr(cells), style={'width': '100%'})


def _country_stat(label, value_id, value_color=None):
    """Return the [label, value placeholder] pair for one statistic of the selected country."""
    return [
        html.Div(label, style=_text_style('16px', color='#777')),
        html.Div(id=value_id, style=_text_style('23px', color=value_color)),
    ]


def _country_options(countries):
    """Return dropdown options (label = country name, value = country code)."""
    return [
        {'label': country_name, 'value': country_code}
        for country_name, country_code in zip(countries["Country"], countries["Code"])
    ]


app.layout = html.Div(children=[

    # title for the application
    _section_title('Applied Data Science on COVID-19 data with SIR Simulations'),

    html.Div([

        # world status: confirmed cases, recovered cases, active cases and deaths
        html.Table(
            html.Tr([
                _summary_cell('total', confirmed),
                _summary_cell('Recovered', recovered, '#3CB371'),
                _summary_cell('Active', active, '#696969'),
                _summary_cell('Deaths', deaths, '#B22222'),
            ]),
            style={'width': "100%", 'textAlign': 'center'},
        ),

        # choropleth world map next to the statistics of the selected country
        html.Table(
            html.Tr([
                _graph_cell(dcc.Graph(figure=world_status(glob_data), id='map')),

                # placeholders filled by the ``update_data`` callback
                html.Td(
                    [
                        html.Div(id='country_name', style=_text_style('25px', color='#777')),
                        *_country_stat('Confirmed', 'final_cases'),
                        *_country_stat('Recovered', 'final_recovered', '#3CB371'),
                        *_country_stat('Active', 'final_active', '#696969'),
                        *_country_stat('Deaths', 'final_deaths', '#B22222'),
                    ],
                    style=_card_style(width='550px', textAlign='center'),
                ),
            ]),
            style={'width': '100%'},
        ),

        # individual graphs: cumulative series first, daily series second
        _graph_row('deaths', 'total_recovered', 'total_cases'),
        _graph_row('daily_deaths', 'recovered', 'new_cases'),

        # SIR dynamic simulations
        _section_title('SIR SIMULATIONS'),

        html.Div(
            dcc.Dropdown(
                id='simulation_countries',
                options=_country_options(sir_simulation_countries_list),
                value="DEU",
            )
        ),

        html.Div(
            dcc.Graph(id="SIR_simulations")
        ),

    ]),

    # comparison of several countries
    _section_title('countrywise comparision'),

    html.Table([
        html.Tr([
            html.Td(dcc.Dropdown(
                id='comparision_countries_dd',
                style={'textAlign': 'left'},
                options=_country_options(comparision_countries_list),
                value=["DEU", "IND"],
                multi=True,
            )),

            html.Td(dcc.RadioItems(
                id='comparision_countries_radio',
                options=[
                    {'label': 'Confirmed', 'value': 'Confirmed'},
                    {'label': 'Recovered', 'value': 'Recovered'},
                    {'label': 'Deaths', 'value': 'Deaths'},
                ],
                value='Confirmed',
            )),
        ])
    ], style={'width': '100%', 'textAlign': 'center'}),
    html.Div(dcc.Graph(id='comparision_output')),
    html.Footer(" Course: Enterprise Data Science ", style={'textAlign': 'center'})
])

# callback to update different elements
@app.callback(
    [Output('deaths', 'figure'),
    Output('new_cases', 'figure'),
    Output('recovered', 'figure'),
    Output('daily_deaths', 'figure'),
    Output('total_cases','figure'),
    Output('total_recovered','figure'),
    Output('final_cases','children'),
    Output('final_recovered','children'),
    Output('final_active','children'),
    Output('final_deaths','children'),
    Output('country_name','children')
    ],

    [Input('map', 'clickData')])
def update_data(clickData):
    """Refresh the six country graphs and the side panel after a click on the map.

    Parameters
    ----------
    clickData : dict or None
        Click payload of the ``map`` graph; ``None`` before the first click.

    Returns
    -------
    tuple
        Figures for total deaths, daily new cases, daily recoveries, daily
        deaths, total cases and total recoveries, followed by the formatted
        confirmed / recovered / active / death counts and the country name,
        in the order of the callback's ``Output`` list.

    Notes
    -----
    Germany is shown when nothing has been clicked yet, and also when the
    history of the clicked country cannot be loaded. In the latter case the
    side panel switches to Germany as well, so graphs, numbers and title
    always describe the same country.
    """
    country_code = _clicked_country_code(clickData)
    if country_code == "":
        country_code = "DEU"
        country = "Germany"
    else:
        country = collected_data(glob_data,country_code)

    try:
        new_data = get_country_data(country)
    except Exception:
        # Best-effort fallback; a failure of the Germany request itself propagates.
        country_code = "DEU"
        country = 'Germany'
        new_data = get_country_data(country)

    new_cases = _history_figure(new_data, "Daily_confirmed", " Daily New Cases", 'black', cumulative=False)
    recovered = _history_figure(new_data, "Daily_recovered", " Daily New Recoveries", '#3CB371', cumulative=False)
    daily_deaths = _history_figure(new_data, "Daily_deaths", "Daily Deaths", '#FF0000', cumulative=False)

    deaths = _history_figure(new_data, "Total_deaths", "Total Deaths", '#FF0000', cumulative=True)
    total_cases = _history_figure(new_data, "Total_confirmed", " Total cases", 'black', cumulative=True)
    total_recovered = _history_figure(new_data, "Total_recovered", " Total Recovered", '#3CB371', cumulative=True)

    final_cases, final_recovered, final_active, final_deaths = _country_totals(country_code)

    return deaths,new_cases,recovered,daily_deaths,total_cases,total_recovered,final_cases,final_recovered,final_active,final_deaths,country


def _clicked_country_code(clickData):
    """Return the ``location`` of the first clicked point, or "" when there is none."""
    try:
        return clickData['points'][0]['location']
    except (TypeError, KeyError, IndexError):
        return ""


def _history_figure(new_data, column, title, color, cumulative):
    """Return a 470x400 chart of ``column`` over ``Date``: a line if ``cumulative``, else bars."""
    if cumulative:
        figure = px.line(new_data, x="Date", y=column, width=470, height=400)
        figure.update_traces(line_color=color)
    else:
        figure = px.bar(new_data, x="Date", y=column, width=470, height=400)
        figure.update_traces(marker_color=color)
    figure.update_layout(title_text=title, title_x=0.5)
    return figure


def _country_totals(country_code):
    """Return formatted (confirmed, recovered, active, deaths) of ``country_code`` from ``glob_data``.

    The row is looked up by code rather than by name because the names used
    by the two APIs differ for some countries. ``'N/A'`` is returned for each
    value when ``glob_data`` has no row for the code.
    """
    rows = glob_data[glob_data['Code'] == country_code]
    if rows.empty:
        return 'N/A', 'N/A', 'N/A', 'N/A'
    latest = rows.iloc[0]
    return tuple(f"{int(latest[column]):,}" for column in ('Confirmed', 'Recovered', 'Active', 'Deaths'))


# Callback that draws the comparison chart for the selected countries and statistic
@app.callback(
    Output('comparision_output','figure'),
    [Input('comparision_countries_dd', 'value'),
    Input('comparision_countries_radio','value')]
)
def countries_comparision_charts(comparision_countries_dd, comparision_countries_radio) :
    """Plot the chosen cumulative statistic for every selected country.

    Parameters
    ----------
    comparision_countries_dd : list of str or None
        Country codes selected in the multi-select dropdown. ``None`` or an
        empty list (dropdown cleared) yields an empty figure instead of an error.
    comparision_countries_radio : str
        ``'Confirmed'`` or ``'Recovered'``; any other value selects deaths.

    Returns
    -------
    plotly.graph_objs.Figure
        One ``lines+markers`` trace per country, named after its code.
    """
    series_by_choice = {'Confirmed': 'Total_confirmed', 'Recovered': 'Total_recovered'}
    column = series_by_choice.get(comparision_countries_radio, 'Total_deaths')

    fig = go.Figure()
    for country_code in comparision_countries_dd or []:
        history = get_country_data(collected_data(glob_data, country_code))
        fig.add_trace(go.Scatter(x=history['Date'], y=history[column], mode='lines+markers', name=country_code))

    fig.update_layout(
        paper_bgcolor="#f9f9f9",
        height=800
    )
    return fig

# Callback that computes the dynamic SIR simulation and shows it in the graph
@app.callback(
    Output('SIR_simulations','figure'),
    [Input('simulation_countries', 'value')]
)
def sir_simulations (value):
    """Fit an SIR model window by window and plot it against the reported totals.

    The history is cut into consecutive windows of ``data_size`` days. For
    each window ``beta`` and ``gamma`` are estimated by fitting the model's
    infected compartment to ``Total_confirmed`` (``curve_fit`` for a starting
    point, then ``minimize`` on the RMSE). A final fit on the last
    ``data_size`` days is integrated ``forecast_days`` further and supplies
    the values for all dates not already covered by a full window.

    Parameters
    ----------
    value : str or None
        Country code selected in the dropdown. ``None`` (dropdown cleared)
        yields an empty figure.

    Returns
    -------
    plotly.graph_objs.Figure
        The reported totals as a line ('Actual') and the fitted/forecast
        values as bars ('Simulated'). When fewer than ``data_size`` days of
        history are available only the 'Actual' trace is drawn.
    """
    if not value:
        empty = go.Figure()
        empty.update_layout(height = 700)
        return empty

    country = collected_data(glob_data,value)
    data = get_country_data(country)
    data_size = 8       # days per fitting window
    forecast_days = 7   # days simulated beyond the last reported date
    t = np.arange(data_size)
    N = glob_data[glob_data['Code'] == value]['Population'].values[0]

    def SIR(y, t, beta, gamma):
        # Right-hand side of the SIR equations for a population of size N.
        S = y[0]
        I = y[1]
        R = y[2]
        return -beta*S*I/N, (beta*S*I)/N-(gamma*I), gamma*I

    def fit_window(start, horizon):
        """Fit the window beginning at row ``start``; return ``horizon`` simulated values."""
        train = list(data['Total_confirmed'][start:start+data_size])
        i_0 = train[0]
        r_0 = data['Total_recovered'].values[start]
        s_0 = N - i_0 - r_0

        def fit_odeint(t, beta, gamma):
            # Infected compartment of the integrated model.
            return odeint(SIR,(s_0,i_0,r_0), t, args = (beta,gamma))[:,1]

        def loss(point):
            # Root-mean-square error between simulation and window data.
            return np.sqrt(np.mean((fit_odeint(t, *point) - train)**2))

        params, _ = curve_fit(fit_odeint, t, train)
        beta, gamma = minimize(loss, params).x
        return list(fit_odeint(np.arange(horizon), beta, gamma))

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=data["Date"], y=data['Total_confirmed'],
                        mode='lines+markers',
                        name='Actual'))

    n_days = len(data)
    if n_days < data_size:
        # Not enough history for even a single window: nothing to simulate.
        fig.update_layout(height = 700)
        return fig

    predicted_simulations = []
    for start in range(0, n_days - data_size, data_size):
        predicted_simulations.extend(fit_window(start, data_size))

    # The last window overlaps the days already simulated above; keep only the
    # part of its simulation that lies beyond them.
    last_start = n_days - data_size
    tail = fit_window(last_start, data_size + forecast_days)
    predicted_simulations.extend(tail[len(predicted_simulations) - last_start:])

    # Extend the date axis by the forecast horizon.
    dates = data["Date"].values.tolist()
    last_date = datetime.strptime(dates[-1], "%Y-%m-%d")
    for _ in range(forecast_days):
        last_date += timedelta(days=1)
        dates.append(last_date.strftime("%Y-%m-%d"))

    fig.add_bar(x = dates[:len(predicted_simulations)], y=predicted_simulations, name = "Simulated")
    fig.update_layout(height = 700)
    return fig

# Title shown in the browser tab.
app.title = 'COVID-19 Dashboard(SIR)'


# Favicon file name (assigned to Dash's private ``_favicon`` attribute).
# NOTE(review): the file name is spelled "fevicon.ico" - confirm that it
# matches the actual asset file.
app._favicon = "fevicon.ico"

# Start the development server on localhost:8085 when run as a script/notebook
# cell; the reloader is disabled (presumably because it does not work inside a
# notebook kernel - TODO confirm).
if __name__ == "__main__":
    app.run_server(debug=True,use_reloader=False,host='127.0.0.1',port=8085)


Dash is running on http://127.0.0.1:8085/

 in production, use a production WSGI server like gunicorn instead.

 * Serving Flask app "__main__" (lazy loading)
 * Environment: production
   Use a production WSGI server instead.
 * Debug mode: on



Covariance of the parameters could not be estimated


Covariance of the parameters could not be estimated

