In [1]:
import os
import sys
sys.path.insert(0, os.path.abspath('../modules'))

In [2]:
import own.data.hdi as HDI
import own.data.countries as Con

In [3]:
from bokeh.plotting import figure, ColumnDataSource, curdoc
from bokeh.io import output_file, show, output_notebook, save
from bokeh.palettes import inferno, all_palettes, Turbo6
from bokeh.models import HoverTool, Text, LabelSet, CustomJS, NumeralTickFormatter,PanTool,ResetTool,HoverTool,WheelZoomTool,SaveTool,BoxZoomTool,UndoTool,RedoTool,SaveTool,ZoomInTool,ZoomOutTool
from bokeh.transform import linear_cmap, factor_cmap
from bokeh.layouts import column, row, gridplot
from own.colors import economy_colors as color_map
import own.funny as funny
from bokeh.models.widgets import CheckboxGroup, Select
output_notebook()

In [4]:
import pandas as pd

In [5]:
# Load the dataset that every later cell plots from. Columns read from it further
# down include 'economy', 'HDI_value', 'deaths_total_per_capita' and 'Year'.
# NOTE(review): presumably a pandas DataFrame (it is filtered with boolean masks
# below) — confirm against own.data.hdi.
data = HDI.ecdcWithHDI()

In [6]:
# Column whose distinct values define the colour groups and legend entries.
# ("income_group" is the alternative grouping that was tried here.)
color_based_on = "economy"

# Distinct group names in ascending order.
# (Slice the unique() result with [:5] to limit the groups while experimenting.)
regions = sorted(data[color_based_on].unique())

In [7]:
# Seed the two columns the glyphs are bound to with the initially displayed
# measures; the dropdown callbacks later repoint them on the JavaScript side.
for display_column, initial_measure in (('current_disp_x', 'HDI_value'),
                                        ('current_disp_y', 'deaths_total_per_capita')):
    data[display_column] = data[initial_measure]

In [8]:
def TOOLS():
    """Return a new list of toolbar tool instances for a figure.

    Fresh tool objects are created on every call. Box zoom acts on both
    dimensions while panning is restricted to the horizontal direction; the
    remaining tools use their default configuration. No hover tool is
    included in this list.
    """
    configured = [BoxZoomTool(dimensions='both'), PanTool(dimensions='width')]
    default_tool_classes = (
        ZoomInTool,
        ZoomOutTool,
        ResetTool,
        SaveTool,
        UndoTool,
        RedoTool,
        WheelZoomTool,
    )
    return configured + [tool_class() for tool_class in default_tool_classes]

In [9]:
# Tooltip rows as (label, field spec) pairs; an "@name" spec is filled from the
# column of that name in the hovered glyph's data source.
hover_rows = [
    ("Country", "@Country_name"),
    ("HDI value", "@HDI_value"),
    ("Cases per capita", "@cases_total_per_capita"),
    ("Deaths per capita", "@deaths_total_per_capita"),
    ("Tests per capita", "@tests_per_capita"),
    ("Test unit", "@tests_units"),
    ('Up to date', '@date{%F}'),
]

# '@date' is routed through the datetime formatter so the {%F} spec is applied.
my_hover2 = HoverTool(tooltips=hover_rows, formatters={'@date': 'datetime'})

# The initial x axis label names the largest value found in the data's Year column.
hdi_axis_label = f'Human development index (HDI) from {int(data.Year.max())}'

p2 = figure(
    title="COVID-19 data (cases, deaths, tests) ",
    x_axis_label=hdi_axis_label,
    y_axis_label='Total deaths per capita',
    plot_width=400,
    plot_height=300,
    x_range=(0.4, 1),
    tools=TOOLS(),
)

# Print y tick values as plain decimals instead of scientific notation.
p2.yaxis[0].formatter.use_scientific = False
# Registries keyed by group name: the data source behind each group and the
# renderers drawn from it.
sources = {}
plots = {}
for name in regions:
    group_color = color_map[name]
    group_source = ColumnDataSource(data[data[color_based_on] == name])

    # One small marker per row of the group.
    group_markers = p2.circle(
        x='current_disp_x',
        y='current_disp_y',
        source=group_source,
        size=3,
        color=group_color,
        legend_label=name,
    )

    # The country code drawn at the same position. It shares the marker's
    # legend label, so both renderers sit under one legend entry.
    group_labels = p2.text(
        x='current_disp_x',
        y='current_disp_y',
        text='iso3166_a2',
        source=group_source,
        level='glyph',
        x_offset=0,
        y_offset=0,
        text_font_size='9.5pt',
        text_font_style='bold',
        text_color=group_color,
        legend_label=name,
    )

    sources[name] = group_source
    plots[name] = {'plot': group_markers, 'text': group_labels}

# Clicking a legend entry toggles the visibility of that group's renderers.
p2.legend.location = "top_left"
p2.legend.click_policy = "hide"

# Client-side handler for the y axis dropdown.
#
# For the selected option it points every group's 'current_disp_y' column at the
# matching measure column, notifies the source so the glyphs redraw, and then
# updates the y axis label once. The option text doubles as the axis label.
# An unrecognised option leaves the plot untouched.
callback_cat = CustomJS(args={'jsSources': sources, 'jsYaxis': p2.yaxis[0]},
                    code = """
                            // Dropdown option -> data column shown on the y axis.
                            var columnFor = {
                                "Total cases per capita": "cases_total_per_capita",
                                "Total deaths per capita": "deaths_total_per_capita",
                                "Total tests per capita": "tests_per_capita"
                            };
                            var changedTo = cb_obj.value;
                            if (!columnFor.hasOwnProperty(changedTo)) {
                                return;
                            }
                            var column = columnFor[changedTo];
                            for (var key in jsSources) {
                                jsSources[key].data['current_disp_y'] = jsSources[key].data[column];
                                jsSources[key].change.emit();
                            }
                            jsYaxis.axis_label = changedTo;
                            """
                   )


# Client-side handler for the x axis dropdown.
#
# For the selected option it points every group's 'current_disp_x' column at the
# matching measure column, then updates the x axis label and the visible x range
# once. An unrecognised option leaves the plot untouched.
#
# The HDI label and range are captured from the figure as it is configured right
# now, so switching back to HDI restores exactly the initial view. Previously the
# label hard-coded a year and the range start was 0.04 instead of the figure's 0.4.
callback_cat2 = CustomJS(args={'jsSources': sources,
                               'jsXaxis': p2.xaxis[0],
                               'jsXrange': p2.x_range,
                               'jsHdiLabel': p2.xaxis[0].axis_label,
                               'jsHdiStart': p2.x_range.start,
                               'jsHdiEnd': p2.x_range.end},
                    code = """
                            // Dropdown option -> column, axis label and x range to show.
                            var views = {
                                "Human development index (HDI)": {
                                    column: "HDI_value",
                                    label: jsHdiLabel,
                                    start: jsHdiStart,
                                    end: jsHdiEnd
                                },
                                "Total tests per capita": {
                                    column: "tests_per_capita",
                                    label: "Total tests per capita",
                                    start: -0.025,
                                    end: 0.13
                                }
                            };
                            var changedTo = cb_obj.value;
                            if (!views.hasOwnProperty(changedTo)) {
                                return;
                            }
                            var view = views[changedTo];
                            for (var key in jsSources) {
                                jsSources[key].data['current_disp_x'] = jsSources[key].data[view.column];
                                jsSources[key].change.emit();
                            }
                            jsXaxis.axis_label = view.label;
                            jsXrange.start = view.start;
                            jsXrange.end = view.end;
                            jsXrange.change.emit();
                            jsXaxis.change.emit();
                            """
                   )
# Dropdown options. These exact strings are what the JavaScript callbacks
# compare the selected value against, so they must stay in sync with them.
x_categories = ['Human development index (HDI)', 'Total tests per capita']
y_categories = ['Total cases per capita', 'Total deaths per capita', 'Total tests per capita']

# Initial selections: the first x option and the second y option.
initial_x_choice = x_categories[0]
initial_y_choice = y_categories[1]

# x axis dropdown, handled by callback_cat2.
select2 = Select(title="Select x axis category:", options=x_categories, value=initial_x_choice)
select2.js_on_change('value', callback_cat2)

# y axis dropdown, handled by callback_cat.
select1 = Select(title="Select y axis category:", options=y_categories, value=initial_y_choice)
select1.js_on_change('value', callback_cat)

# Attach the hover tool separately from the tools passed to figure().
p2.add_tools(my_hover2)

# Disabled: would start with every group except '1. Developed region: G7' hidden.
# for region in regions:
#     if region == '1. Developed region: G7':
#         pass
#     else:
#         plots[region]['text'].visible = False
#         plots[region]['plot'].visible = False

# Plot on the first row, both dropdowns stacked below it (x selector on top).
axis_controls = column(select2, select1)
grid = gridplot(
    [[p2], [axis_controls]],
    toolbar_location='above',
    toolbar_options={'logo': None},
    merge_tools=True,
    sizing_mode='scale_both',
)

# Output stage: each branch is switched on by a flag helper from own.funny.

# Pass the finished layout to the project's save-for-later helper under the name 'HDI'.
if funny.is_save_for_later():
    funny.save_for_later('HDI', grid)

# Write the layout to an HTML file; mode='inline' embeds the Bokeh resources in the file.
if funny.is_save_to_public():
    output_file('../docs/plots/hdi-scatter.html', mode='inline')#, mode='relative-dev', root_dir='../lib/bokeh/2.0.1')
    save(grid)

# Render the layout inline in the notebook.
# NOTE(review): output_file() above is never reset, so if both flags are set
# show() may also write/open the HTML file — confirm this is intended.
if funny.is_display():
    output_notebook()
    show(grid)