diff --git a/app.py b/app.py index ef1b8bc4..86008af0 100644 --- a/app.py +++ b/app.py @@ -46,7 +46,8 @@ # ----------- Figures --------------------- fig1 = make_map(df_tidy, df_tidy_fatalities) -fig2 = make_timeplot(df, df_prediction) +fig2 = make_timeplot(df, df_prediction, countries=['France', 'Italy', 'Spain']) +fig_store = make_timeplot(df, df_prediction) # ------------ Markdown text --------------- # maybe later we can break the text in several parts @@ -97,7 +98,7 @@ ], className="pure-u-1 pure-u-lg-1-2 pure-u-xl-8-24", ), - dcc.Store(id='store', data=[fig2, initial_indices]), + dcc.Store(id='store', data=[fig_store, initial_indices]), html.Div([ dash_table.DataTable( id='table', @@ -165,21 +166,12 @@ # in order to transform the app into static html pages # javascript functions are defined in assets/callbacks.js -app.clientside_callback( - ClientsideFunction( - namespace='clientside2', - function_name='get_store_data' - ), - output=Output('plot', 'figure'), - inputs=[Input('store', 'data')]) - - app.clientside_callback( ClientsideFunction( namespace='clientside', function_name='update_store_data' ), - output=Output('store', 'data'), + output=Output('plot', 'figure'), inputs=[ Input('table', "data"), Input('table', "selected_rows")], @@ -203,6 +195,5 @@ ) - if __name__ == '__main__': app.run_server(debug=debug) diff --git a/assets/callbacks.js b/assets/callbacks.js index 98576bf9..602c14d8 100644 --- a/assets/callbacks.js +++ b/assets/callbacks.js @@ -3,11 +3,6 @@ if (!window.dash_clientside) { } -window.dash_clientside.clientside2 = { - get_store_data: function(data) { - return data[0]; - } -}; window.dash_clientside.clientside3 = { update_table: function(clickdata, selecteddata, table_data, selectedrows, store) { @@ -51,27 +46,26 @@ window.dash_clientside.clientside3 = { window.dash_clientside.clientside = { - update_store_data: function(selectedrows, rows, store) { + update_store_data: function(rows, selectedrows, store) { var fig = store[0]; if (!rows) { throw "Figure data not loaded, aborting update." } - var new_fig = {...fig}; + var new_fig = {}; + new_fig['data'] = []; + new_fig['layout'] = fig['layout']; var countries = []; - for (i = 0; i < rows.length; i++) { - countries.push(selectedrows[rows[i]]["country_region"]); + for (i = 0; i < selectedrows.length; i++) { + countries.push(rows[selectedrows[i]]["country_region"]); } - for (i = 0; i < new_fig['data'].length; i++) { - var name = new_fig['data'][i]['name']; + for (i = 0; i < fig['data'].length; i++) { + var name = fig['data'][i]['name']; if (countries.includes(name) || countries.includes(name.substring(1))){ - new_fig['data'][i]['visible'] = true; - } - else{ - new_fig['data'][i]['visible'] = false; + new_fig['data'].push(fig['data'][i]); } } - return [new_fig, store[1]]; + return new_fig; } }; diff --git a/make_figures.py b/make_figures.py index 8bb4dfb3..201b18cf 100644 --- a/make_figures.py +++ b/make_figures.py @@ -57,7 +57,7 @@ def make_map(df, df_fatalities): return fig -def make_timeplot(df_measure, df_prediction): +def make_timeplot(df_measure, df_prediction, countries=None): """ Build figure showing evolution of number of cases vs. time for all countries. The visibility of traces is set to 0 so that the interactive app will @@ -82,6 +82,8 @@ def make_timeplot(df_measure, df_prediction): hovertemplate_measure = '%{meta}
%{x}
%{y:.0f} per Million' hovertemplate_prediction = '%{meta}
prediction

%{x}
%{y:.0f} per Million' for i, country in enumerate(df_measure_confirmed.columns): + if countries and country[1] not in countries: + continue fig.add_trace(go.Scatter(x=df_measure_confirmed.index, y=df_measure_confirmed[country], name=country[1], mode='markers+lines', @@ -90,7 +92,7 @@ def make_timeplot(df_measure, df_prediction): line_color=colors[i%n_colors], meta=country[1], hovertemplate=hovertemplate_measure, - visible=False)) + visible=True)) prediction = df_prediction['prediction'] upper_bound = df_prediction['upper_bound'] lower_bound = df_prediction['lower_bound'] @@ -101,6 +103,8 @@ def make_timeplot(df_measure, df_prediction): lower_bound = normalize_by_population_wide(lower_bound) lower_bound *= 1e6 for i, country in enumerate(prediction.columns): + if countries and country[1] not in countries: + continue # Do not plot predictions for a country with less than 50 cases if df_measure_confirmed[country][-1] < 50: continue @@ -112,14 +116,14 @@ def make_timeplot(df_measure, df_prediction): showlegend=False, meta=country[1], hovertemplate=hovertemplate_prediction, - visible=False)) + visible=True)) fig.add_trace(go.Scatter(x=upper_bound.index, y=upper_bound[country], name='+' + country[1], mode='lines', line_dash='dot', line_color=colors[i%n_colors], showlegend=False, - visible=False, + visible=True, hoverinfo='skip', line_width=.8)) fig.add_trace(go.Scatter(x=lower_bound.index, @@ -128,7 +132,7 @@ def make_timeplot(df_measure, df_prediction): line_dash='dot', line_color=colors[i%n_colors], showlegend=False, - visible=False, + visible=True, hoverinfo='skip', line_width=.8))