In [1]:
%load_ext autoreload
%autoreload 2
import sys
import os
import time
import types
sys.path.append('../')
import pandas as pd
import numpy as np
import ipywidgets as widgets
from ipywidgets import interact, interactive, interactive_output
from traitlets import traitlets
import matplotlib.pyplot as plt
from fbprophet import Prophet
import yaml
from src.data_downloader import DATA_REPOS, download_from_repo, get_dataframes
import plotly.graph_objects as go
from plotly.graph_objs import Layout
from plotly.subplots import make_subplots
import plotly.express as px
from IPython.display import display, clear_output
# Column groups offered by the widgets are configured in an external YAML file.
with open("columns_names.yaml", 'r') as stream:
    # safe_load: the file only holds plain lists/dicts. A bare yaml.load() without an
    # explicit Loader is unsafe and raises TypeError on PyYAML >= 6.
    out = yaml.safe_load(stream)
    orig_data_columns = out['LABELS']['orig_data_columns']
    extra_data_columns = out['LABELS']['extra_data_columns']
    prov_data_columns = out['LABELS']['prov_data_columns']
    trend_labels = out['LABELS']['trend_labels']
# Fields computed per country from the world time series in the World tabs.
countries_labels = ['confirmed', 'recovered', 'deaths', 'daily_confirmed', 'daily_recovered', 'daily_deaths',
                    '%daily_confirmed', '%daily_recovered', '%daily_deaths']
daily_cols = ['daily_'+col for col in orig_data_columns]
import warnings
# NOTE(review): this silences *all* warnings (including deprecations) for the whole session.
warnings.filterwarnings('ignore')
# Folder the data files are downloaded to and read from.
dest='../data'



In [2]:
# Load every dataset the dashboard uses from the local data folder.
(df_naz, reg, prov, df_reg, df_prov,
 df_world_confirmed, df_world_deaths, df_world_recovered,
 populations, ita_populations, df_comuni_sett) = get_dataframes(dest, npt_rth=5, smooth=True)
# Country names offered by the world widgets.
countries_columns = df_world_confirmed['Country/Region'].unique()

last available date for Italy data 2020-11-18T17:00:00
last available date for World data 11/18/20


In [3]:
# First, set the main page as a Tab container with one page per area.
MAIN_TAB_TITLES = ['Italy', 'World', 'Data']
# Placeholder pages are swapped for the real ones once those are built. They are
# needed here because ipywidgets >= 8 keeps titles in a tuple sized to the
# children, so set_title() on an empty Tab raises IndexError.
main_tab = widgets.Tab(children=[widgets.VBox() for _ in MAIN_TAB_TITLES])
for position, page_title in enumerate(MAIN_TAB_TITLES):
    main_tab.set_title(position, page_title)

In [4]:
# Data tab: a log panel plus a button that re-downloads the source data.
download_tab = widgets.VBox()
download_out = widgets.Output(layout={'border': '1px solid black'})

# @download_out.on_displayed
# def check_creation_date(b=None):
#     file1 = 'dpc-covid19-ita-regioni.csv'
#     file2 = 'time_series_covid19_confirmed_global.csv'
#     file_path1 = os.path.join(dest,file1)
#     file_path2 = os.path.join(dest,file2)
#     (mode, ino, dev, nlink, uid, gid, size, atime, mtime1, ctime) = os.stat(file_path1)
#     (mode, ino, dev, nlink, uid, gid, size, atime, mtime2, ctime) = os.stat(file_path2)
#     download_out.append_display_data("{} last modified: {}".format(file1, time.ctime(mtime1)))
#     download_out.append_display_data("{} last modified: {}".format(file2, time.ctime(mtime2)))


update_button = widgets.Button(description='Update data from repos')
@update_button.on_click
def update_repo(b=None):
    """Download fresh Italian and world data and reload the module-level frames.

    The plotting callbacks read these frames as globals every time they redraw,
    so rebinding them here makes later redraws use the new data. Defaults baked
    into widgets that were already built (date pickers, option lists) are not
    refreshed.

    b: the clicked Button (unused).
    """
    # Without these declarations the reload below only created locals and the
    # downloaded data never reached the rest of the dashboard.
    global df_naz, reg, prov, df_reg, df_prov, df_world_confirmed, df_world_deaths
    global df_world_recovered, populations, ita_populations, df_comuni_sett, countries_columns
    # append_stdout shows plain text (append_display_data on a str shows its quoted repr).
    download_out.append_stdout('downloading Italian data\n')
    download_from_repo(DATA_REPOS['italy']['url'], filenames=DATA_REPOS['italy']['streams'], dest=dest)
    download_out.append_stdout('downloading world data\n')
    download_from_repo(DATA_REPOS['world']['url'], filenames=DATA_REPOS['world']['streams'], dest=dest)
    (df_naz, reg, prov, df_reg, df_prov,
     df_world_confirmed, df_world_deaths, df_world_recovered,
     populations, ita_populations, df_comuni_sett) = get_dataframes(dest, npt_rth=5, smooth=True)
    countries_columns = df_world_confirmed['Country/Region'].unique()
    download_out.append_stdout('data reloaded\n')


download_tab.children = [widgets.HBox(children=[download_out, update_button])]

In [5]:
# Second top-level page: Italy, itself a Tab container (pages are attached further below).
italy_tab = widgets.Tab(children=())

In [6]:
prov_tab = widgets.VBox()
province_out = widgets.Output(layout = {
            'width': '100%',
            'height': '600px',
            'border': '1px solid black'
        })

def get_top_provinces(label, top_prov, date, show_map, show_grid):
    """Show the `top_prov` provinces with the highest `label` on `date`.

    label: column of df_prov to rank by (one of prov_data_columns).
    top_prov: how many provinces to keep after sorting in descending order.
    date: day to show (DatePicker value; None when the picker is cleared).
    show_map: draw a choropleth of the provinces instead of a bar chart.
    show_grid: toggle grid lines (bar chart only).
    """
    with province_out:
        clear_output()
        if date is None:
            print('Pick a date to show the provinces.')
            return
        label = [label]
        df_prov.index = pd.to_datetime(df_prov.index)
        # str(date) is 'YYYY-MM-DD': partial-string indexing selects every row of that day.
        tempdf = (df_prov.loc[str(date)][['sigla_provincia', 'denominazione_provincia', 'lat', 'long'] + label]
                  .sort_values(by=label, ascending=False)[:top_prov]
                  .set_index('sigla_provincia'))

        if show_map:
            fig = px.choropleth(tempdf,
                                geojson='https://raw.githubusercontent.com/openpolis/geojson-italy/master/geojson/limits_IT_provinces.geojson',
                                locations='denominazione_provincia',
                                color=label[0],
                                color_continuous_scale='Reds',
                                featureidkey='properties.prov_name',
                                range_color=(0, tempdf[label[0]].max()))

            fig.update_layout(showlegend=True, title='italian provinces',
                              paper_bgcolor='rgba(0,0,0,0)', font=dict(color='lightgray'), plot_bgcolor='rgba(0,0,0,0)')
            fig.update_geos(resolution=50, showcountries=True, countrycolor="lightgray", showland=True,
                            showsubunits=True, subunitcolor="Blue", #showocean=True, oceancolor="lightblue"
                           )
            fig.update_geos(fitbounds="locations", visible=False)
            fig.update_layout(height=600)
        else:
            fig = px.bar(tempdf[label].reset_index(), x=label[0], y='sigla_provincia', orientation='h')
            fig.update_layout(showlegend=True, title='top {} provinces on day {}'.format(top_prov, date.strftime("%m/%d/%Y")),
                              paper_bgcolor='rgba(0,0,0,0)', font=dict(color='lightgray'), plot_bgcolor='rgba(0,0,0,0)')
            fig.update_xaxes(showgrid=show_grid, gridwidth=1, gridcolor='gray')
            fig.update_yaxes(showgrid=show_grid, gridwidth=1, gridcolor='gray')
        fig.show()

prov_widget = interactive(get_top_provinces,
                          label=widgets.Select(options=prov_data_columns),
                          # interactive() ignores keyword arguments that are not parameters of
                          # get_top_provinces, so continuous_update is set on the slider itself;
                          # otherwise every drag step redraws the map.
                          top_prov=widgets.IntSlider(min=1, max=150, step=1, value=150, continuous_update=False),
                          date=widgets.DatePicker(description='Pick a Date', value=pd.to_datetime(df_prov.index.max())),
                          show_map=widgets.Checkbox(value=True),
                          show_grid=widgets.Checkbox(value=True),
                          )
# The last child is interactive's own Output; plots go to province_out instead.
prov_box = widgets.HBox(prov_widget.children[:-1], layout = widgets.Layout(flex_flow='row wrap'))
prov_tab.children = [prov_box, province_out]

In [7]:
prov_evo_tab = widgets.VBox()
province_evo_out = widgets.Output(layout = {
            'width': '100%',
            'height': '600px',
            'border': '1px solid black'
        })

def get_prov_data_evolution(label, region, show_grid, cumulated_bars, log, plot_bars):
    """Plot how `label` evolves over time for every province of `region`.

    With cumulated_bars, `label` is ignored and the daily_totale_casi rows of
    each province are drawn as horizontal bars. Otherwise one trace per
    province is drawn against the dates (bars if plot_bars, lines otherwise),
    optionally on a log y axis.
    """
    with province_evo_out:
        clear_output()
        df_prov.index = pd.to_datetime(df_prov.index)
        region_rows = df_prov.groupby('denominazione_regione').get_group(region)
        if cumulated_bars:
            label = 'daily_totale_casi'
            temp = region_rows.set_index('denominazione_provincia')[label].sort_values()
            fig = px.bar(temp.reset_index(), x=label, y='denominazione_provincia', orientation='h')
        else:
            trace_cls = go.Bar if plot_bars else go.Scatter
            fig = go.Figure()
            for province, series in region_rows.groupby('denominazione_provincia')[label]:
                # Pass the dates explicitly so line traces share the date axis with
                # bar traces instead of being drawn against 0..n-1.
                fig.add_traces(trace_cls(x=series.index, y=series, name=province))
            if log:
                fig.update_layout(yaxis_type="log")
        fig.update_layout(showlegend=True, title='province details evolution',
                          paper_bgcolor='rgba(0,0,0,0)', font=dict(color='lightgray'), plot_bgcolor='rgba(0,0,0,0)')
        fig.update_xaxes(showgrid=show_grid, gridwidth=1, gridcolor='lightgray')
        fig.update_yaxes(showgrid=show_grid, gridwidth=1, gridcolor='lightgray')

        fig.show()

prov_evo_widget = interactive(get_prov_data_evolution,
                              label=prov_data_columns,
                              region=list(df_prov.denominazione_regione.unique()),
                              show_grid=False,
                              cumulated_bars=False,
                              log=False,
                              plot_bars=False
                              )
prov_evo_box = widgets.HBox(prov_evo_widget.children[:-1], layout = widgets.Layout(flex_flow='row wrap'))
prov_evo_tab.children = [prov_evo_box, province_evo_out]

In [8]:
forecast_tab = widgets.VBox()
forecast_out = widgets.Output(layout = {
            'width': '100%',
            'height': '600px',
            'border': '1px solid black'
        })

def get_forecast(region, start_fit, end_fit, label, forecast_periods, smoothing):
    """Fit Prophet to log1p(`label`) of an Italian region and plot fit and forecast.

    region: key of df_reg ('Italy' or a region name).
    start_fit, end_fit: date range used for fitting (DatePicker values; None if cleared).
    label: column of df_reg[region] to model (one of trend_labels).
    forecast_periods: number of days predicted after the end of the fitted history.
    smoothing: rolling-mean window applied before fitting (values < 1 are treated as 1).

    The model is fitted on log1p of the data and predictions are mapped back with
    expm1, so the plotted uncertainty band is not symmetric around yhat.
    """
    with forecast_out:
        clear_output()
        if start_fit is None or end_fit is None:
            print('Pick both a start and an end date for the fit.')
            return
        # Date bounds as midnight timestamps (what pandas does implicitly for dates).
        start_fit, end_fit = pd.Timestamp(start_fit), pd.Timestamp(end_fit)
        series = df_reg[region][label].rolling(max(1, int(smoothing))).mean()
        # Rows whose y is NaN (rolling warm-up) are ignored by Prophet's fit().
        # No 'floor' column: Prophet only uses it with growth='logistic'.
        train_data = pd.DataFrame({'ds': pd.to_datetime(series.index),
                                   'y': np.log1p(series.to_numpy())})
        model = Prophet(growth='linear', daily_seasonality=True, weekly_seasonality=True, yearly_seasonality=False)
        model.fit(train_data.set_index('ds').loc[start_fit:end_fit].reset_index())
        forecast = model.predict(model.make_future_dataframe(periods=forecast_periods))
        # Undo the log1p transform on the predictions and on the observed data.
        for col in ('yhat', 'yhat_lower', 'yhat_upper'):
            forecast[col] = np.expm1(forecast[col])
        train_data['y'] = np.expm1(train_data['y'])
        df = pd.merge(left=train_data, right=forecast, on='ds', how='outer').set_index('ds')
        df.index = pd.to_datetime(df.index)
        fig = go.Figure()
        fig.add_trace(go.Scatter(x=df.index, y=df['yhat_lower'], fill=None, mode='lines', line_color='lightgrey', name='confidence_lev_down'))
        fig.add_trace(go.Scatter(x=df.index, y=df['yhat_upper'], fill='tonexty', mode='lines', line_color='lightgrey', name='confidence_lev_up'))
        fig.add_traces(go.Scatter(x=df.index, y=df['y'], name=label, mode='lines+markers', marker=dict(size=5)))
        fig.add_traces(go.Scatter(x=df.loc[start_fit:end_fit].index, y=df['yhat'].loc[start_fit:end_fit], line_color='goldenrod',
                                  mode='lines', name='model fit'))
        fig.add_traces(go.Scatter(x=df.loc[end_fit:].index, y=df['yhat'].loc[end_fit:], line_color='darkblue', mode='markers',
                                  marker=dict(size=2), name='forecast'))
        fig.update_layout(showlegend=True, title={'text': label.replace('_', ' ') + ' for ' + region, 'xanchor': 'left'})

        fig.update_xaxes(showgrid=False, gridwidth=1, gridcolor='gray')
        fig.update_yaxes(showgrid=False, gridwidth=1, gridcolor='gray')
        fig.show()



# manual mode: the (slow) Prophet fit only runs when the button is pressed.
forecast_widget = interactive(get_forecast, {'manual': True, 'manual_name': 'Launch Forecast'},
                              region=widgets.Dropdown(options=df_reg.keys(), value='Italy'),
                              start_fit=widgets.DatePicker(value=pd.to_datetime(df_naz.index[0])),
                              end_fit=widgets.DatePicker(value=pd.to_datetime(df_naz.index[-10])),
                              label=widgets.Dropdown(options=trend_labels, value='nuovi_positivi'),
                              forecast_periods=widgets.IntSlider(value=50, min=0, max=150),
                              smoothing=widgets.BoundedIntText(value=1, min=1, max=365))
# Drop interactive's own Output (last child); plots go to forecast_out.
forecast_box = widgets.HBox(forecast_widget.children[:-1], layout = widgets.Layout(flex_flow='row wrap'))

forecast_tab.children = [forecast_box, forecast_out]

In [9]:
evo_tab = widgets.VBox()

evo_out = widgets.Output(layout = {
            'width': '100%',
            'height': '600px',
            'border': '1px solid black'
        })

def plt_region(regions, labels, log, relative_dates, cases_per_mln_people, plot_bars, show_grid, aggregate, apply_ma):
    """Plot the time evolution of the selected fields for the selected Italian regions.

    regions: keys of df_reg ('Italy' when empty).
    labels: fields to plot ('Rth' when empty).
    relative_dates: drop the zero rows, re-index from 0 and drop the last point.
    cases_per_mln_people: scale by 1e6 / population (summed when aggregating).
    aggregate: combine the regions into one series per field (sum; mean for Rth).
    apply_ma: rolling-mean window in days (values < 1 are treated as 1).
    """
    with evo_out:
        clear_output()
        labels = list(labels) or ['Rth']
        regions = list(regions) or ['Italy']
        window = max(1, int(apply_ma))
        trace_cls = go.Bar if plot_bars else go.Scatter
        fig = go.Figure()
        mult = 1.
        for item in labels:
            if aggregate:
                if cases_per_mln_people:
                    mult = 1e06/ita_populations.loc[regions, 'Popolazione'].sum()
                temp = df_reg[regions[0]][item].copy()
                for region in regions[1:]:
                    temp = temp.add(df_reg[region][item])
                if item == 'Rth':
                    # Rth is a rate: the aggregate is the mean over *all* selected
                    # regions (the sum divided by len(regions), not len(regions) - 1).
                    temp = temp/len(regions)
                temp = pd.DataFrame(temp)
                temp.index = pd.to_datetime(temp.index)
                if relative_dates:
                    temp = temp.loc[~(temp[item] == 0)].reset_index(drop=True).iloc[:-1]
                temp = temp.rolling(window).mean()
                fig.add_traces(trace_cls(x=temp.index, y=temp[item]*mult, name=item+'_'+'-'.join(regions)))
            else:
                for region in regions:
                    if cases_per_mln_people:
                        mult = 1e06/ita_populations.loc[region, 'Popolazione']
                    # Work on a copy so the shared df_reg frames are not mutated.
                    temp = df_reg[region].copy()
                    temp.index = pd.to_datetime(temp.index)
                    temp[item] = temp[item].rolling(window).mean()
                    if relative_dates:
                        temp = temp.loc[~(temp[item] == 0)].reset_index(drop=True).iloc[:-1]
                    fig.add_traces(trace_cls(x=temp.index, y=temp[item]*mult, name=item+'_'+region))
        fig.update_layout(showlegend=True,
                          legend_orientation="h",
                          title='Regional Evolution')
        fig.update_xaxes(showgrid=show_grid)
        fig.update_yaxes(showgrid=show_grid)
        if log: fig.update_layout(yaxis_type="log")
        fig.show()

evo_widget = interactive(plt_region,
                         regions=widgets.SelectMultiple(description="regions",
                                                        options=list(df_reg.keys()), value=['Italy']),
                         labels=widgets.SelectMultiple(description="fields",
                                                       options=orig_data_columns+extra_data_columns+daily_cols, value=['nuovi_positivi']),
                         log=False,
                         relative_dates=False,
                         cases_per_mln_people=False,
                         plot_bars=True,
                         show_grid=False,
                         aggregate=False,
                         # IntText has no min/max traits (min_value/max_value were ignored);
                         # BoundedIntText keeps the rolling window valid.
                         apply_ma=widgets.BoundedIntText(value=1, min=1, max=150))
evo_box = widgets.HBox(evo_widget.children[:-1], layout = widgets.Layout(flex_flow='row wrap'))

evo_tab.children = [evo_box, evo_out]

In [10]:
ratio_tab = widgets.VBox()

ratio_out = widgets.Output(layout = {
            'width': '100%',
            'height': '600px',
            'border': '1px solid black'
        })

def plt_ratio(regions, labels1, labels2, log, cases_per_mln_people, plot_bars, show_grid, apply_ma):
    """Plot the ratio labels1 / labels2 over time for each selected Italian region.

    Both series are smoothed with a rolling mean of `apply_ma` days (values < 1
    are treated as 1) before dividing. Divisions by zero become NaN gaps instead
    of +/-inf, which also keeps a log y axis usable. Works when
    labels1 == labels2.
    """
    with ratio_out:
        clear_output()
        regions = list(regions) or ['Italy']
        window = max(1, int(apply_ma))
        trace_cls = go.Bar if plot_bars else go.Scatter
        ratio_name = labels1+'/'+labels2
        fig = go.Figure()
        mult = 1.
        for region in regions:
            if cases_per_mln_people:
                mult = 1e06/ita_populations.loc[region, 'Popolazione']
            region_df = df_reg[region]
            numerator = region_df[labels1].rolling(window).mean()
            denominator = region_df[labels2].rolling(window).mean()
            ratio = (numerator*mult/denominator).replace([np.inf, -np.inf], np.nan)
            ratio.index = pd.to_datetime(ratio.index)
            fig.add_traces(trace_cls(x=ratio.index, y=ratio, name=ratio_name+'_'+region))
        fig.update_layout(showlegend=True,
                          legend_orientation="h",
                          title='Ratio: '+ratio_name)
        fig.update_xaxes(showgrid=show_grid)
        fig.update_yaxes(showgrid=show_grid)
        if log: fig.update_layout(yaxis_type="log")
        fig.show()

ratio_widget = interactive(plt_ratio,
                           regions=widgets.SelectMultiple(description="regions",
                                                          options=list(df_reg.keys()), value=['Italy']),
                           labels1=widgets.Select(description="numerator",
                                                  options=orig_data_columns+daily_cols),
                           labels2=widgets.Select(description="denominator",
                                                  options=orig_data_columns+daily_cols),
                           log=False,
                           cases_per_mln_people=False,
                           plot_bars=False,
                           show_grid=False,
                           apply_ma=widgets.BoundedIntText(value=1, min=1, max=150))
ratio_box = widgets.HBox(ratio_widget.children[:-1], layout = widgets.Layout(flex_flow='row wrap'))

ratio_tab.children = [ratio_box, ratio_out]

In [11]:
daily_tab = widgets.VBox()

daily_out = widgets.Output(layout = {
            'width': '100%',
            'height': '600px',
            'border': '1px solid black'
        })

def get_values_for_day(regions, labels, date, cases_per_mln_people):
    """Horizontal bar chart of the selected fields on one day, one trace per region.

    regions: keys of df_reg ('Italy' when empty).
    labels: fields to show (when empty: the 'daily' fields of orig+extra columns, without '%' ones).
    date: day to show (DatePicker value; None when cleared).
    cases_per_mln_people: scale by 1e6 / regional population.
    """
    with daily_out:
        clear_output()
        if date is None:
            print('Pick a date to show the daily values.')
            return
        data_columns = orig_data_columns+extra_data_columns
        regions = list(regions) or ['Italy']
        labels = list(labels) or [item for item in data_columns if ('daily' in item) and ('%' not in item)]
        day = pd.Timestamp(date)
        mult = 1.
        fig = go.Figure()
        for region in regions:
            if cases_per_mln_people:
                mult = 1e06/ita_populations.loc[region, 'Popolazione']
            # Convert the index once per region, on a copy, instead of rewriting the
            # shared frame's index once per label.
            values = df_reg[region][labels].copy()
            values.index = pd.to_datetime(values.index)
            fig.add_traces(go.Bar(y=labels, x=values.loc[day]*mult, name=region, orientation='h'))
        fig.update_layout(showlegend=True, title='day ' + date.strftime("%m/%d/%Y"))
        fig.show()

# Preselect the ten regions with the highest latest Rth.
latest_rth = pd.Series({name: df_reg[name]['Rth'].iloc[-1] for name in df_reg.keys()})
default_regions = latest_rth.sort_values(ascending=False).iloc[:10].index

daily_widget = interactive(get_values_for_day,
                           regions=widgets.SelectMultiple(description="regions",
                                                          options=list(df_reg.keys()), value=list(default_regions)),
                           labels=widgets.SelectMultiple(description="data",
                                                         options=orig_data_columns+extra_data_columns+daily_cols,
                                                         value=['Rth']),
                           date=widgets.DatePicker(description='Pick a Date',
                                                   value=pd.to_datetime(df_prov.index.max())),
                           cases_per_mln_people=False,
                           )
daily_box = widgets.HBox(daily_widget.children[:-1], layout = widgets.Layout(flex_flow='row wrap'))

daily_tab.children = [daily_box, daily_out]

In [12]:
# (page widget, tab title) pairs, in display order.
italy_pages = (
    (prov_tab, 'provinces'),
    (prov_evo_tab, 'prov evolution'),
    (daily_tab, 'daily cases'),
    (evo_tab, 'evolution'),
    (forecast_tab, 'forecast'),
    (ratio_tab, 'ratio'),
)
italy_tab.children = [page_widget for page_widget, _title in italy_pages]
for position, (_page, page_title) in enumerate(italy_pages):
    italy_tab.set_title(position, page_title)

In [13]:
world_tab = widgets.Tab(children=())  # World page; its sub-pages are attached further below.

In [14]:
world_daily_tab = widgets.VBox()

world_daily_out = widgets.Output(layout = {
            'width': '100%',
            'height': '600px',
            'border': '1px solid black'
        })


def get_top_countries(labels, top_prov, date, show_grid, show_map):
    """Rank countries by one world field on one day and show them on a map or as bars.

    labels: a single field name from countries_labels (despite the plural name).
    top_prov: number of countries kept after sorting in descending order.
    date: day to show (DatePicker value; None when cleared).
    show_grid: toggle grid lines (bar chart only).
    show_map: bubble map coloured/sized by the value instead of a bar chart.
    """
    with world_daily_out:
        clear_output()
        if date is None:
            print('Pick a date to show the countries.')
            return
        datecols = list(df_world_confirmed.columns.difference(['Province/State', 'Country/Region', 'Lat', 'Long', 'pop']))
        # Coordinates come from the country-level rows (those without a province).
        df_geo = df_world_confirmed[df_world_confirmed['Province/State'].isna()].groupby('Country/Region').first()

        def by_country(frame):
            # Index.difference() returns the date labels sorted as *strings*
            # ('11/10/20' before '11/2/20'), so sort chronologically before diff()/shift().
            wide = frame.groupby('Country/Region')[datecols].sum().T
            wide.index = pd.to_datetime(wide.index)
            return wide.sort_index()

        df = {'confirmed': by_country(df_world_confirmed),
              'recovered': by_country(df_world_recovered),
              'deaths': by_country(df_world_deaths)}
        for base in ('deaths', 'confirmed', 'recovered'):
            df['daily_'+base] = df[base].diff()
            df['%daily_'+base] = df[base].diff()/df[base].shift()

        day = pd.Timestamp(date)
        if day not in df[labels].index:
            print('No world data for {}'.format(date.strftime("%m/%d/%Y")))
            return
        # One value per country for that day; the index keeps its 'Country/Region' name.
        tempdf = df[labels].loc[day].to_frame(labels)
        tempdf['lat'] = tempdf.index.map(df_geo['Lat'])
        tempdf['long'] = tempdf.index.map(df_geo['Long'])
        tempdf = tempdf.sort_values(by=labels, ascending=False)[:top_prov]
        if show_map:
            fig = px.scatter_mapbox(tempdf.reset_index(),
                        lat='lat', lon='long', color=labels, size=labels,
                        labels = labels,
                        hover_name='Country/Region',
                        zoom=0,  height=800,
                        mapbox_style="open-street-map",
                        title='top {} countries on day {} for {}'.format(top_prov, date.strftime("%m/%d/%Y"), labels),
               )
        else:
            fig = px.bar(tempdf.reset_index(), x=labels, y='Country/Region', orientation='h')
            fig.update_layout(showlegend=True, title='top {} countries on day {}'.format(top_prov, date.strftime("%m/%d/%Y")))
            fig.update_xaxes(showgrid=show_grid)
            fig.update_yaxes(showgrid=show_grid)

        fig.show()

world_daily_widget = interactive(
    get_top_countries,
    labels=widgets.Select(description="countries", options = countries_labels),
    # Redraw only when the slider is released (the map is expensive to rebuild).
    top_prov=widgets.IntSlider(min=1, max=80, step=1, value=10, continuous_update=False),
    date=widgets.DatePicker(description='Pick a Date',
                value=pd.to_datetime([item for item in df_world_confirmed.columns if '/' in item][-1])),
    show_grid=False,
    show_map=True)

world_daily_box = widgets.HBox(world_daily_widget.children[:-1], layout = widgets.Layout(flex_flow='row wrap'))

world_daily_tab.children = [widgets.VBox([world_daily_box, world_daily_out])]

In [15]:
world_evo_tab = widgets.VBox()

world_evo_out = widgets.Output(layout = {
            'width': '100%',
            'height': '600px',
            'border': '1px solid black'
        })

def world_comparison(regions, labels, log, relative_dates, cases_per_mln_people, plot_bars, show_grid, aggregate,
                     hide_below=0):
    """Compare world fields over time across countries.

    regions: countries to plot (six European countries plus the US when empty).
    labels: fields from countries_labels ('confirmed' when empty).
    relative_dates: keep the days above `hide_below`, re-index from 0, drop the last point.
    cases_per_mln_people: scale by 1e6 / population (summed when aggregating).
    aggregate: sum the selected countries into one series per field.
    hide_below: threshold used with relative_dates (the widget supplies it; 0 for direct calls).
    """
    with world_evo_out:
        clear_output()
        labels = list(labels) or countries_labels[:1]
        regions = list(regions) or ['Italy', 'France', 'Spain', 'Germany', 'United Kingdom', 'US']

        def by_country(frame):
            # numeric_only / drop(columns=...) keep this working on pandas 2, where
            # sum() no longer drops string columns and drop() has no positional axis.
            # The CSV's chronological column order is preserved.
            wide = frame.groupby('Country/Region').sum(numeric_only=True).drop(columns=['Lat', 'Long', 'pop']).T
            wide.index = pd.to_datetime(wide.index)
            return wide

        df = {'confirmed': by_country(df_world_confirmed),
              'recovered': by_country(df_world_recovered),
              'deaths': by_country(df_world_deaths)}
        for base in ('deaths', 'confirmed', 'recovered'):
            df['daily_'+base] = df[base].diff()
            df['%daily_'+base] = df[base].diff()/df[base].shift()

        trace_cls = go.Bar if plot_bars else go.Scatter
        mult = 1.
        fig = go.Figure()
        for item in labels:
            if aggregate:
                if cases_per_mln_people:
                    mult = 1e06/populations.loc[regions].sum()
                temp = df[item][regions[0]].copy()
                for region in regions[1:]:
                    temp = temp.add(df[item][region])
                if relative_dates:
                    temp = temp.loc[temp > hide_below].reset_index(drop=True).iloc[:-1]
                fig.add_traces(trace_cls(x=temp.index, y=temp*mult, name=item+'_'+'-'.join(regions)))
            else:
                for region in regions:
                    temp = df[item][region]
                    if cases_per_mln_people:
                        mult = 1e06/populations.loc[region]
                    if relative_dates:
                        temp = temp.loc[temp > hide_below].reset_index(drop=True).iloc[:-1]
                    fig.add_traces(trace_cls(x=temp.index, y=temp*mult, name=item+'_'+region))
        if aggregate:
            fig.update_layout(legend_orientation="h")
        if log: fig.update_layout(yaxis_type="log")
        fig.update_xaxes(showgrid=show_grid, gridwidth=1)
        fig.update_yaxes(showgrid=show_grid, gridwidth=1)
        fig.show()

world_evo_widget = interactive(
    world_comparison,
    regions = widgets.SelectMultiple(description="regions",options=countries_columns),
    labels = widgets.SelectMultiple(description="data",options=countries_labels),
    log=False,
    relative_dates=False,
    cases_per_mln_people=False,
    plot_bars=False,
    show_grid=False,
    aggregate=False,
    hide_below=widgets.BoundedIntText(value=0,min=0,max=100, description='min cases'))

world_evo_box = widgets.HBox(world_evo_widget.children[:-1], layout = widgets.Layout(flex_flow='row wrap'))

world_evo_tab.children = [widgets.VBox([world_evo_box, world_evo_out])]

In [16]:
# Scratch check of the world series: 30-day rolling mean of Italy's confirmed cases.
region = 'Italy'
label = 'confirmed'
world_meta_cols = ['Lat', 'Long', 'pop']
# numeric_only + drop(columns=...) are required on pandas 2 (no positional axis in drop()).
world_frames = {
    name: frame.groupby('Country/Region').sum(numeric_only=True).drop(columns=world_meta_cols).T
    for name, frame in (('confirmed', df_world_confirmed),
                        ('recovered', df_world_recovered),
                        ('deaths', df_world_deaths))
}
for base in ('deaths', 'confirmed', 'recovered'):
    world_frames['daily_'+base] = world_frames[base].diff()
    world_frames['%daily_'+base] = world_frames[base].diff()/world_frames[base].shift()

df = world_frames[label][region].rolling(30).mean()

In [17]:
world_forecast_tab = widgets.VBox()
world_forecast_out = widgets.Output(layout = {
            'width': '100%',
            'height': '600px',
            'border': '1px solid black'
        })

def get_world_forecast(region, start_fit, end_fit, label, forecast_periods, smoothing):
    """Fit Prophet to log1p(`label`) for one country and plot fit and forecast.

    region: country name ('Country/Region' value).
    start_fit, end_fit: date range used for fitting (DatePicker values; None if cleared).
    label: '<field>', 'daily_<field>' or '%daily_<field>' with field in confirmed/recovered/deaths.
    forecast_periods: number of days predicted after the end of the fitted history.
    smoothing: rolling-mean window applied before fitting (values < 1 are treated as 1).

    The model is fitted on log1p of the data and predictions are mapped back with
    expm1, so the plotted uncertainty band is not symmetric around yhat.
    """
    with world_forecast_out:
        clear_output()
        if start_fit is None or end_fit is None:
            print('Pick both a start and an end date for the fit.')
            return
        start_fit, end_fit = pd.Timestamp(start_fit), pd.Timestamp(end_fit)
        sources = {'confirmed': df_world_confirmed, 'recovered': df_world_recovered, 'deaths': df_world_deaths}
        # Cumulative series of the country, in the CSV's chronological column order.
        # numeric_only / drop(columns=...) keep this working on pandas 2.
        cumulative = (sources[label.split('_')[-1]]
                      .groupby('Country/Region').sum(numeric_only=True)
                      .drop(columns=['Lat', 'Long', 'pop']).T[region])
        if label.startswith('%daily_'):
            values = cumulative.diff()/cumulative.shift()
        elif label.startswith('daily_'):
            values = cumulative.diff()
        else:
            values = cumulative
        series = values.rolling(max(1, int(smoothing))).mean()
        # Rows whose y is NaN (diff/rolling warm-up) are ignored by Prophet's fit().
        # No 'floor' column: Prophet only uses it with growth='logistic'.
        train_data = pd.DataFrame({'ds': pd.to_datetime(series.index),
                                   'y': np.log1p(series.to_numpy())})
        model = Prophet(growth='linear', daily_seasonality=True, weekly_seasonality=True, yearly_seasonality=False)
        model.fit(train_data.set_index('ds').loc[start_fit:end_fit].reset_index())
        forecast = model.predict(model.make_future_dataframe(periods=forecast_periods))
        # Undo the log1p transform on the predictions and on the observed data.
        for col in ('yhat', 'yhat_lower', 'yhat_upper'):
            forecast[col] = np.expm1(forecast[col])
        train_data['y'] = np.expm1(train_data['y'])
        df = pd.merge(left=train_data, right=forecast, on='ds', how='outer').set_index('ds')
        df.index = pd.to_datetime(df.index)
        fig = go.Figure()
        fig.add_trace(go.Scatter(x=df.index, y=df['yhat_lower'], fill=None, mode='lines', line_color='lightgrey', name='confidence_lev_down'))
        fig.add_trace(go.Scatter(x=df.index, y=df['yhat_upper'], fill='tonexty', mode='lines', line_color='lightgrey', name='confidence_lev_up'))
        fig.add_traces(go.Scatter(x=df.index, y=df['y'], name=label, mode='lines+markers', marker=dict(size=5)))
        fig.add_traces(go.Scatter(x=df.loc[start_fit:end_fit].index, y=df['yhat'].loc[start_fit:end_fit], line_color='goldenrod',
                                  mode='lines', name='model fit'))
        fig.add_traces(go.Scatter(x=df.loc[end_fit:].index, y=df['yhat'].loc[end_fit:], line_color='darkblue', mode='markers',
                                  marker=dict(size=2), name='forecast'))
        fig.update_layout(showlegend=True, title={'text': label.replace('_', ' ') + ' for ' + region, 'xanchor': 'left'})

        fig.update_xaxes(showgrid=False, gridwidth=1, gridcolor='gray')
        fig.update_yaxes(showgrid=False, gridwidth=1, gridcolor='gray')
        fig.show()



# manual mode: the (slow) Prophet fit only runs when the button is pressed.
world_forecast_widget = interactive(get_world_forecast, {'manual': True, 'manual_name': 'Launch Forecast'},
                                    region=widgets.Dropdown(options=countries_columns, value='Italy'),
                                    start_fit=widgets.DatePicker(value=pd.to_datetime(df_naz.index[0])),
                                    end_fit=widgets.DatePicker(value=pd.to_datetime(df_naz.index[-10])),
                                    label=widgets.Dropdown(options=countries_labels, value='confirmed'),
                                    forecast_periods=widgets.IntSlider(value=50, min=0, max=150),
                                    smoothing=widgets.BoundedIntText(value=1, min=1, max=365))
# Drop interactive's own Output (last child); plots go to world_forecast_out.
world_forecast_box = widgets.HBox(world_forecast_widget.children[:-1], layout = widgets.Layout(flex_flow='row wrap'))

world_forecast_tab.children = [world_forecast_box, world_forecast_out]

In [18]:
# (page widget, tab title) pairs, in display order.
world_pages = (
    (world_daily_tab, 'daily cases'),
    (world_evo_tab, 'country comparison'),
    (world_forecast_tab, 'forecast'),
)
world_tab.children = [page_widget for page_widget, _title in world_pages]
for position, (_page, page_title) in enumerate(world_pages):
    world_tab.set_title(position, title=page_title)

In [19]:
# Attach the finished pages to the top-level tab container, in title order.
main_tab.children = (italy_tab, world_tab, download_tab)

In [20]:
# Render the dashboard: the bare last expression of the cell displays the widget.
main_tab

Tab(children=(Tab(children=(VBox(children=(HBox(children=(Select(description='label', options=('totale_casi', …