diff --git a/app.py b/app.py
index ef1b8bc4..86008af0 100644
--- a/app.py
+++ b/app.py
@@ -46,7 +46,8 @@
# ----------- Figures ---------------------
fig1 = make_map(df_tidy, df_tidy_fatalities)
-fig2 = make_timeplot(df, df_prediction)
+fig2 = make_timeplot(df, df_prediction, countries=['France', 'Italy', 'Spain'])
+fig_store = make_timeplot(df, df_prediction)
# ------------ Markdown text ---------------
# maybe later we can break the text in several parts
@@ -97,7 +98,7 @@
],
className="pure-u-1 pure-u-lg-1-2 pure-u-xl-8-24",
),
- dcc.Store(id='store', data=[fig2, initial_indices]),
+ dcc.Store(id='store', data=[fig_store, initial_indices]),
html.Div([
dash_table.DataTable(
id='table',
@@ -165,21 +166,12 @@
# in order to transform the app into static html pages
# javascript functions are defined in assets/callbacks.js
-app.clientside_callback(
- ClientsideFunction(
- namespace='clientside2',
- function_name='get_store_data'
- ),
- output=Output('plot', 'figure'),
- inputs=[Input('store', 'data')])
-
-
app.clientside_callback(
ClientsideFunction(
namespace='clientside',
function_name='update_store_data'
),
- output=Output('store', 'data'),
+ output=Output('plot', 'figure'),
inputs=[
Input('table', "data"),
Input('table', "selected_rows")],
@@ -203,6 +195,5 @@
)
-
if __name__ == '__main__':
app.run_server(debug=debug)
diff --git a/assets/callbacks.js b/assets/callbacks.js
index 98576bf9..602c14d8 100644
--- a/assets/callbacks.js
+++ b/assets/callbacks.js
@@ -3,11 +3,6 @@ if (!window.dash_clientside) {
}
-window.dash_clientside.clientside2 = {
- get_store_data: function(data) {
- return data[0];
- }
-};
window.dash_clientside.clientside3 = {
update_table: function(clickdata, selecteddata, table_data, selectedrows, store) {
@@ -51,27 +46,26 @@ window.dash_clientside.clientside3 = {
window.dash_clientside.clientside = {
- update_store_data: function(selectedrows, rows, store) {
+ update_store_data: function(rows, selectedrows, store) {
var fig = store[0];
if (!rows) {
throw "Figure data not loaded, aborting update."
}
- var new_fig = {...fig};
+ var new_fig = {};
+ new_fig['data'] = [];
+ new_fig['layout'] = fig['layout'];
var countries = [];
- for (i = 0; i < rows.length; i++) {
- countries.push(selectedrows[rows[i]]["country_region"]);
+ for (i = 0; i < selectedrows.length; i++) {
+ countries.push(rows[selectedrows[i]]["country_region"]);
}
- for (i = 0; i < new_fig['data'].length; i++) {
- var name = new_fig['data'][i]['name'];
+ for (i = 0; i < fig['data'].length; i++) {
+ var name = fig['data'][i]['name'];
if (countries.includes(name) || countries.includes(name.substring(1))){
- new_fig['data'][i]['visible'] = true;
- }
- else{
- new_fig['data'][i]['visible'] = false;
+ new_fig['data'].push(fig['data'][i]);
}
}
- return [new_fig, store[1]];
+ return new_fig;
}
};
diff --git a/make_figures.py b/make_figures.py
index 8bb4dfb3..201b18cf 100644
--- a/make_figures.py
+++ b/make_figures.py
@@ -57,7 +57,7 @@ def make_map(df, df_fatalities):
return fig
-def make_timeplot(df_measure, df_prediction):
+def make_timeplot(df_measure, df_prediction, countries=None):
"""
Build figure showing evolution of number of cases vs. time for all countries.
The visibility of traces is set to 0 so that the interactive app will
@@ -82,6 +82,8 @@ def make_timeplot(df_measure, df_prediction):
hovertemplate_measure = '%{meta}
%{x}
%{y:.0f} per Million'
hovertemplate_prediction = '%{meta}
prediction
%{x}
%{y:.0f} per Million'
for i, country in enumerate(df_measure_confirmed.columns):
+ if countries and country[1] not in countries:
+ continue
fig.add_trace(go.Scatter(x=df_measure_confirmed.index,
y=df_measure_confirmed[country],
name=country[1], mode='markers+lines',
@@ -90,7 +92,7 @@ def make_timeplot(df_measure, df_prediction):
line_color=colors[i%n_colors],
meta=country[1],
hovertemplate=hovertemplate_measure,
- visible=False))
+ visible=True))
prediction = df_prediction['prediction']
upper_bound = df_prediction['upper_bound']
lower_bound = df_prediction['lower_bound']
@@ -101,6 +103,8 @@ def make_timeplot(df_measure, df_prediction):
lower_bound = normalize_by_population_wide(lower_bound)
lower_bound *= 1e6
for i, country in enumerate(prediction.columns):
+ if countries and country[1] not in countries:
+ continue
# Do not plot predictions for a country with less than 50 cases
if df_measure_confirmed[country][-1] < 50:
continue
@@ -112,14 +116,14 @@ def make_timeplot(df_measure, df_prediction):
showlegend=False,
meta=country[1],
hovertemplate=hovertemplate_prediction,
- visible=False))
+ visible=True))
fig.add_trace(go.Scatter(x=upper_bound.index,
y=upper_bound[country],
name='+' + country[1], mode='lines',
line_dash='dot',
line_color=colors[i%n_colors],
showlegend=False,
- visible=False,
+ visible=True,
hoverinfo='skip',
line_width=.8))
fig.add_trace(go.Scatter(x=lower_bound.index,
@@ -128,7 +132,7 @@ def make_timeplot(df_measure, df_prediction):
line_dash='dot',
line_color=colors[i%n_colors],
showlegend=False,
- visible=False,
+ visible=True,
hoverinfo='skip',
line_width=.8))