In [109]:
# Copyright 2020 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pandas as pd
import time
import math
from ipywidgets import *
from bqplot import *
import datetime
from data_processor import DataProcessor
import bisect
from ipywidgets import Image as im
from toggle_buttons import Toggle_Buttons
from pretty_breaks import PrettyBreaks
import itertools
from collections import Counter
import base64
from bqplot.market_map import MarketMap
import ipyleaflet as ipl
import os
import json
import branca.colormap as cm
import warnings

# Silence RuntimeWarnings for the whole kernel session (presumably numpy's
# "invalid value" / "divide by zero" warnings raised while normalising series).
# NOTE(review): this hides *every* RuntimeWarning, not only numpy's — confirm intended.
warnings.filterwarnings("ignore", category=RuntimeWarning)

# import IPython.core.display as ipd
# ipd.display(ipd.HTML("<style>.container { width:100% !important; }</style>"))


In [110]:
class SelectMultipleWidget:
    """Searchable multi-select list built from one Checkbox per option.

    ``self.res`` is the displayable widget: a three-row grid holding a search
    bar with a reset button, "Select All"/"Clear" buttons, and a scrollable
    column of checkboxes.  ``self.value`` mirrors the labels of the checked
    boxes (kept up to date by the ``on_check`` observer).

    Callbacks are registered with :meth:`all_observe`, which attaches them to
    every checkbox and records them in ``self.actions`` keyed by the
    callback's ``__name__`` so they can be re-attached when the option list
    is replaced.
    """

    def __init__(self, options, value, width, height):
        """Build the widget.

        Args:
            options: option labels; one checkbox is created per label, in order.
            value: labels that should start checked.
            width: CSS width of the whole widget.
            height: CSS height of the whole widget.
        """
        self.options = options
        self.value = value
        self.search_bar = Text(
            "",
            placeholder="Search...",
            layout=Layout(
                width="auto", height="32px", max_width="160px", margin="0px 0px 0px 0px"
            ),
        )
        self.reset_text = Button(
            icon="times",
            button_style="",
            layout=Layout(width="32px", height="32px", margin="0px 0px 0px 0px"),
        )
        self.reset_text.on_click(self.clear_search_bar)
        self.options_dict = {opt: self._make_checkbox(opt) for opt in self.options}
        self.options_list = [self.options_dict[opt] for opt in self.options]
        self.options_widget = VBox(self.options_list, layout=Layout(overflow="auto"))
        self.search_bar.observe(self.on_text_change, names="value")
        self.select_all = Button(
            description="Select All", layout=Layout(width="50%", height="auto")
        )
        self.select_all.on_click(self.select_all_click)
        self.clear_all = Button(
            description="Clear", layout=Layout(width="50%", height="auto")
        )
        self.clear_all.on_click(self.clear_all_click)
        self.res = GridspecLayout(
            3,
            1,
            layout=Layout(
                width=width,
                height=height,
                overflow="auto",
                border="",
                min_height="150px",
            ),
        )
        top = GridspecLayout(1, 2)
        top[0, 0] = self.search_bar
        top[0, 1] = self.reset_text
        top.layout.grid_template_columns = "auto 35px"
        top.layout.width = width
        top.layout.overflow = "hidden"
        self.res[0, 0] = top
        self.res[1, 0] = HBox([self.select_all, self.clear_all])
        self.res[2, 0] = self.options_widget
        self.res.layout.grid_template_rows = "35px 35px auto"
        self.actions = {}
        self.all_observe(self.on_check)
        self.set_value(self.value)

    @staticmethod
    def _make_checkbox(label):
        """Create one unchecked, non-indented option checkbox labelled ``label``."""
        return Checkbox(
            description=label,
            value=False,
            indent=False,
            layout=Layout(
                width="auto", margin="0px 0px 0px 0px", padding="0px", height="auto"
            ),
        )

    def _suspend_external_observers(self):
        """Detach every registered callback except ``on_check``.

        ``on_check`` stays attached so ``self.value`` keeps tracking the boxes
        during bulk updates.  Returns the detached callbacks, in registration
        order, for :meth:`_resume_observers`.
        """
        suspended = []
        for name, action in list(self.actions.items()):
            if name != self.on_check.__name__:
                self.all_unobserve(action)
                suspended.append(action)
        return suspended

    def _resume_observers(self, suspended):
        """Re-attach callbacks returned by :meth:`_suspend_external_observers`."""
        for action in suspended:
            self.all_observe(action)

    def set_option(self, new_options):
        """Replace the option list; every new option starts unchecked.

        Callbacks registered through :meth:`all_observe` are attached to the
        new checkboxes so they keep firing after the swap.
        """
        self.options = new_options
        self.value = []
        self.options_dict = {opt: self._make_checkbox(opt) for opt in self.options}
        self.options_list = [self.options_dict[opt] for opt in self.options]
        self.options_widget.children = self.options_list
        # The new checkboxes have no observers yet, so they only need attaching.
        for action in list(self.actions.values()):
            self.all_observe(action)

    def on_text_change(self, new_text):
        """Filter the visible checkboxes by the search-bar text.

        ``new_text`` is a traitlets change dict whose ``"new"`` entry is the
        current text.  An empty text restores the full list in its original
        order.  Otherwise options whose label contains the text (case-insensitive,
        surrounding whitespace ignored) are shown, checked ones first, and the
        reset button turns red.
        """
        search_input = new_text["new"]
        if search_input == "":
            new_options = [self.options_dict[opt] for opt in self.options]
            self.reset_text.button_style = ""
        else:
            self.reset_text.button_style = "danger"
            # .strip() (not .strip("") which strips nothing) so stray spaces
            # around the query do not prevent matches.
            needle = search_input.strip().lower()
            matches = {label for label in self.options_dict if needle in label.lower()}
            # sorted() is stable, so options keep their order within each group.
            new_options = sorted(
                (chk for chk in self.options_list if chk.description in matches),
                key=lambda chk: chk.value,
                reverse=True,
            )
        self.options_widget.children = new_options

    def select_all_click(self, *args):
        """Check every option, notifying external callbacks only once.

        External callbacks are detached while the boxes are ticked and
        re-attached before the last box is checked.  The last box is forced
        to False first so the final assignment is a real change, and the
        callbacks therefore fire once with the complete selection.
        Does nothing when there are no options.
        """
        if not self.options_list:
            return
        suspended = self._suspend_external_observers()
        for chk in self.options_list[:-1]:
            chk.value = True
        last = self.options_list[-1]
        if last.value:
            last.value = False
        self._resume_observers(suspended)
        last.value = True

    def clear_all_click(self, *args):
        """Uncheck every option, notifying external callbacks only once.

        Mirror image of :meth:`select_all_click`: the last box is forced to
        True while callbacks are detached, then cleared after they are
        re-attached.  Does nothing when there are no options.
        """
        if not self.options_list:
            return
        suspended = self._suspend_external_observers()
        for chk in self.options_list[:-1]:
            chk.value = False
        last = self.options_list[-1]
        if not last.value:
            last.value = True
        self._resume_observers(suspended)
        last.value = False

    def clear_search_bar(self, *args):
        """Empty the search bar, which restores the full option list."""
        self.search_bar.value = ""

    def set_value(self, value):
        """Check exactly the options listed in ``value``.

        An empty ``value`` clears everything.  Otherwise external callbacks
        are notified once, through the checkbox of ``value[-1]``, which must
        be one of the current options (KeyError otherwise).
        """
        if len(value) == 0:
            self.clear_all_click()
        else:
            suspended = self._suspend_external_observers()
            for opt in self.options:
                self.options_dict[opt].value = opt in value
            anchor = self.options_dict[value[-1]]
            anchor.value = False
            self._resume_observers(suspended)
            anchor.value = True
        self.value = value

    def get_value(self):
        """Return the labels of the checked options, in option order."""
        return [opt.description for opt in self.options_list if opt.value]

    def on_check(self, *args):
        """Observer keeping ``self.value`` in sync with the checkboxes."""
        self.value = self.get_value()

    def all_observe(self, action):
        """Attach ``action`` to every checkbox's ``value`` and record it."""
        for chk in self.options_list:
            chk.observe(action, "value")
        self.actions[action.__name__] = action

    def all_unobserve(self, action):
        """Detach ``action`` from every checkbox and forget it."""
        for chk in self.options_list:
            chk.unobserve(action, "value")
        self.actions.pop(action.__name__, None)


In [111]:
# Shared PrettyBreaks helper built from ../data/WORLD/area_lat.csv.
# NOTE(review): presumably used to compute "nice" legend/colour breaks for the maps — confirm.
pretty_breaks = PrettyBreaks("../data/WORLD/area_lat.csv")

In [112]:
class Help_Page:
    """Paged user guide: one PNG screenshot plus an HTML caption per page.

    Pages are navigated with the "Previous"/"Next" buttons; the assembled
    grid is returned by :meth:`get_page`.  ``skip_button`` ("Go to
    application") gets no handler here — presumably wired by the caller.
    """

    def __init__(self, images_url, images_description):
        """Build the guide.

        Args:
            images_url: paths of the PNG screenshots, one per page (must not be empty).
            images_description: HTML caption per page, index-aligned with ``images_url``.
        """
        help_page = GridspecLayout(
            4, 3, layout=Layout(width="100%", height="100%", justify_items="center")
        )
        self.next_button = Button(
            description="Next",
            button_style="success",
            layout=Layout(height="50px", margin="0px 0px 0px 10px"),
        )
        self.next_button.add_class("hb")
        self.previous_button = Button(
            description="Previous",
            button_style="success",
            layout=Layout(
                visibility="hidden", height="50px", margin="0px 10px 0px 0px"
            ),
        )
        self.previous_button.add_class("hb")
        self.skip_button = Button(
            description="Go to application",
            button_style="warning",
            layout=Layout(height="50px", margin="0px 0px 0px 0px", width="auto"),
        )
        self.skip_button.add_class("hb")
        # Read every screenshot up front; files are closed after reading.
        self.images = [self._read_bytes(path) for path in images_url]
        self.images_description = images_description
        self.index_img = 0
        if len(self.images) <= 1:
            # A single page has nowhere to go: "Next" would run off the end.
            self.next_button.layout.visibility = "hidden"
        self.img_container = Box(
            children=[self._make_image(self.images[self.index_img])],
            layout=Layout(
                width="auto",
                height="auto",
                min_height="100px",
                overflow_x="hidden",
                margin="10px 0px 10px 0px",
            ),
        )

        help_page[3, 1] = self.img_container
        self.header = HTML(
            self.images_description[self.index_img],
            layout=Layout(margin="5px 15px 5px 15px"),
        )

        self.title = HTML(
            "<h1 style='text-align: center; color: #4db254'> User Guide </h1>"
        )

        help_page[0, 1] = VBox([self.skip_button], layout=Layout(align_items="center"))
        help_page[1, 1] = VBox([self.title], layout=Layout(align_items="center"))
        description_box = Box(
            [self.header],
            layout=Layout(
                width="75%",
                height="auto",
                border="solid #4db254",
                overflow="auto",
                justify_content="center",
                max_height="145px",
            ),
        )
        help_page[2, :] = HBox(
            [self.previous_button, description_box, self.next_button],
            layout=Layout(
                width="75%",
                height="auto",
                overflow="auto",
                justify_content="center",
                align_items="center",
                min_height="70px",
            ),
        )
        help_page.layout.grid_template_columns = "10% 80% 10%"
        help_page.layout.grid_template_rows = "50px 70px auto auto"
        self.next_button.on_click(self.next_click)
        self.previous_button.on_click(self.prev_click)
        self.help_page = help_page

    @staticmethod
    def _read_bytes(path):
        """Return the raw bytes of ``path``, closing the file afterwards."""
        with open(path, "rb") as handle:
            return handle.read()

    @staticmethod
    def _make_image(data):
        """Wrap raw PNG bytes in a scaled-down, margin-less Image widget."""
        # margin belongs on the Layout: passed to the Image itself it is not a
        # trait and traitlets drops it (with a DeprecationWarning).
        return im(
            value=data,
            format="png",
            layout=Layout(object_fit="scale-down", margin="0 0 0 0"),
        )

    def _show(self, index):
        """Display page ``index`` and show only the navigation buttons that apply."""
        self.index_img = index
        last = len(self.images) - 1
        self.previous_button.layout.visibility = "hidden" if index == 0 else "visible"
        self.next_button.layout.visibility = "hidden" if index >= last else "visible"
        self.header.value = self.images_description[index]
        self.img_container.children = [self._make_image(self.images[index])]

    def next_click(self, *args):
        """Go to the next page; ignored on the last page."""
        if self.index_img < len(self.images) - 1:
            self._show(self.index_img + 1)

    def prev_click(self, *args):
        """Go to the previous page; ignored on the first page."""
        if self.index_img > 0:
            self._show(self.index_img - 1)

    def get_page(self):
        """Return the assembled guide widget."""
        return self.help_page


In [113]:
def help_button(tooltip):
    """Return a small green "i" tile that shows ``tooltip`` on hover.

    ``tooltip`` is an HTML fragment (a str).  It is wrapped in a styled
    green ``<p>`` and used as the tooltip widget of a one-cell ``MarketMap``
    labelled "i".
    """
    accent = "#4caf50"
    paragraph_open = (
        "<p style='padding:15px; margin:0px; line-height:1.8em; color:#ffffff; "
        "background-color:#4caf50;\n"
        "        border-radius:8px; box-sizing:border-box; "
        "box-shadow:0 1px 8px rgba(0,0,0,0.5);'>"
    )
    tooltip_widget = HTML(
        value=paragraph_open + tooltip + "</p>",
        layout=Layout(width="auto", height="auto", overflow="hidden", margin="0 0 0 0"),
    )
    return MarketMap(
        names=["i"],
        tooltip_widget=tooltip_widget,
        group_stroke="black",
        stroke=accent,
        hovered_stroke=accent,
        selected_stroke=accent,
        map_margin=dict(top=0, bottom=0, left=0, right=0),
        layout=Layout(width="36px", height="36px", margin="0px 0px 0px 0px"),
        colors=[accent],
        font_style={"font-size": "28px", "font-weight": "bold", "fill": "white"},
        min_aspect_ratio=1,
        max_aspect_ratio=1,
    )


In [114]:
# Screenshots for the user-guide pages, in display order (one per caption).
FOLDER_IMG = "../screenshots/user_guide/"
image_names = [
    # Tab 1 — Infection Maps
    "Tab1_tabs.PNG",
    "Tab1_download.PNG",
    "Tab1_fullscreen.PNG",
    "Tab1_typepicker.PNG",
    "Tab1_date_sel.PNG",
    "Tab1_maps.PNG",
    "Tab1_map_control.PNG",
    "Tab1_legend.PNG",
    "Tab1_table.PNG",
    # Tab 2 — Time Track
    "Tab2_controls.PNG",
    "Tab2_2.PNG",
    "Tab2_3.PNG",
    "Tab2_4.PNG",
    "Tab2_5.PNG",
    # Tab 3 — Heat Map
    "Tab3_1.PNG",
    "Tab3_2.PNG",
    # Tab 4 — Custom Graph
    "Tab4_1.PNG",
    "Tab4_2.PNG",
    "Tab4_3.PNG",
]
# Full relative paths (FOLDER_IMG already ends with a slash).
images = [FOLDER_IMG + file_name for file_name in image_names]
# One caption per user-guide page, index-aligned with `images` (both lists
# hold 19 entries). Captions are wrapped in a styled <p> below and may contain
# inline HTML such as <br/> and &emsp;.
# NOTE: a trailing backslash continues the string literal onto the next line,
# so that line's leading spaces become part of the text (harmless once
# rendered as HTML, where runs of whitespace collapse).
images_description = [
    "Use the top left tabs to use different applications of the dashboard \
    (Infection Maps/Time Track/Heat Map/Custom Graph/User Guide)",
    "For a larger figure click on 'Hide Controls'.<br/>You can download the\
    figure you created using the app and the corresponding data by using the\
    top right-hand side buttons: 'Download Figure' and 'Download Data'.",
    "Go full screen with one click on the full screen button",
    "Each tab has a type picker which is a block of different sets of buttons\
    (here Data, Type, Norm and Scale + Moving average). This type picker allows\
    you to choose a unique combination of data and display it in the corresponding\
    figure in the tab.",
    "Select a date for the map by using the 'Date' slider. Click Play to animate.",
    "Use maps buttons to display one of the 3 available maps (World/US States/US Counties)",
    "Use +/- buttons to zoom in and out. Go full screen with one click on the full screen\
    button. Use the search bar tool to find a location <br/>By hovering over a\
    country/state/county in the map a data summary table will be displayed.\
    By clicking on a country/state/county you can change the right-hand side graph",
    "On the right-hand side graph you can select a sub-period by using the\
    bottom range slider and select a subset of curves by clicking on the figure legend",
    "Navigate between table and graph view using the tabs on the right-hand side",
    "In tab 2 (Time Track) you have 3 blocks of controls:<br/> &emsp;- Type picker\
    like in tab 1 (green frame)<br/> &emsp;- Threshold buttons to choose which\
    condition to apply to get the initial point (here 1000 cases) (blue frame)<br/>\
    &emsp;- Country picker in a list or map format (red frame)",
    'You can remove the threshold rule by clicking on the "Calendar Time" Checkbox.\
    Select a sub-period of time by using the bottom left range slider',
    "Use predefined selections to pick countries from preset lists",
    "Use predefined scenarios for examples",
    'By clicking on the country picker buttons you can select countries/states by\
    clicking on a world or US states map. Click on "Collapse" to go back to the previous view',
    "Tab 3 (Heat Map) has a type picker and a country picker. It works the same way\
    as the previous tabs. Each row is assigned to a country/state. Time runs horizontally.",
    '"Scale by max" button divides each row by its maximum. It allows to detect the\
    maximum value of each entity. Without normalizing the values of all countries/states\
    are compared, ranked and coloured',
    "Tab 4 (Custom Graph) allows you to select a data type combination for both x and y axis.",
    'By clicking on "Last data point" checkbox you can display only the most recent\
    value of each selected entity',
    "You can pan by clicking and dragging. You can zoom by using your mouse wheel",
]
# Every caption is rendered as a justified, bold, white paragraph.
style = (
    '<p style="text-align: justify; color: white; font-size: 15px;'
    'font-weight: bold; padding: 0px; margin: 2px;">'
)
descriptions = ["".join((style, caption, "</p>")) for caption in images_description]

# Build the guide and grab its top-level widget.
hp = Help_Page(images, descriptions)
help_page = hp.get_page()

# The user guide is, at this point, the only tab of the dashboard.
# NOTE(review): `_titles` is the ipywidgets 7 (private) way to name tabs;
# ipywidgets 8 replaced it with `titles` — revisit when upgrading.
outer_tab = Tab(children=[help_page], _titles={0: "USER GUIDE"}, layout=Layout())


In [115]:
# Collect World Data

# World CSV snapshots are read from FOLDER_WORLD; the commented URLs record
# the upstream sources (JHU CSSE time series, OWID testing and vaccinations).
# BASE_URL = "https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/csse_covid_19_data/csse_covid_19_time_series/"
FOLDER_WORLD = "../data/WORLD/"
# URL_TESTS = "https://raw.githubusercontent.com/owid/covid-19-data/master/public/data/testing/covid-testing-all-observations.csv"
# ISO ``YYYY-MM-DD`` date format, used for strptime parsing of selected dates
# and for the graphs' DateScale.
dateformat = "%Y-%m-%d"
# URL_vaccine = "https://raw.githubusercontent.com/owid/covid-19-data/master/public/data/vaccinations/vaccinations.csv"


def collect_World_data(FOLDER_WORLD):
    """Load the World CSV snapshots and wrap them in a ``DataProcessor``.

    Args:
        FOLDER_WORLD: directory containing ``Cases.csv``, ``Deaths.csv``,
            ``Tests.csv``, ``Vaccine.csv``, ``pops.csv`` and
            ``world_map_codes.csv``. The trailing path separator is optional.

    Returns:
        Tuple ``(datasets_World, STDT, ENDT, countries_to_codes,
        codes_to_countries)``: the World ``DataProcessor``, the first and last
        index labels of the Cases table, the country-name -> map-code dict
        and its inverse.
    """
    data_names = ["Cases", "Deaths", "Tests", "Vaccine"]
    # os.path.join works whether or not FOLDER_WORLD ends with a separator;
    # plain concatenation silently built a wrong path without one.
    dataframes = {
        name: pd.read_csv(os.path.join(FOLDER_WORLD, name + ".csv"), index_col=0)
        for name in data_names
    }
    pops = pd.read_csv(os.path.join(FOLDER_WORLD, "pops.csv"), index_col=0)

    # The time span comes from the Cases table; assumes its index is sorted
    # chronologically — TODO confirm the other tables share the same span.
    STDT = dataframes["Cases"].index[0]
    ENDT = dataframes["Cases"].index[-1]
    datasets_World = DataProcessor(
        dataframes, pops, "World", add_world_data=True
    )  # World data processor

    country_codes = pd.read_csv(
        os.path.join(FOLDER_WORLD, "world_map_codes.csv"), header=0, index_col=0
    )
    countries_to_codes = country_codes["ISOA3"].to_dict()
    # Manual additions / overrides on top of world_map_codes.csv.
    countries_to_codes.update(
        {
            "World": "WLD",
            "Diamond Princess": "DPS",
            "Taiwan": "TWN",
            "West Bank and Gaza": "PSE",
            "Kosovo": "XKX",
            "MS Zaandam": "MSZ",
        }
    )
    # If two names share a code, the later one wins in the reverse lookup.
    codes_to_countries = {code: name for name, code in countries_to_codes.items()}

    return datasets_World, STDT, ENDT, countries_to_codes, codes_to_countries


In [116]:
# US state / territory names (spelled as in the US data — TODO confirm against
# the CSVs) mapped to two-letter postal abbreviations. Includes DC and the
# territories AS, GU, MP, PR and VI.
states_to_codes = {
    "Alabama": "AL",
    "Alaska": "AK",
    "American Samoa": "AS",
    "Arizona": "AZ",
    "Arkansas": "AR",
    "California": "CA",
    "Colorado": "CO",
    "Connecticut": "CT",
    "Delaware": "DE",
    "District of Columbia": "DC",
    "Florida": "FL",
    "Georgia": "GA",
    "Guam": "GU",
    "Hawaii": "HI",
    "Idaho": "ID",
    "Illinois": "IL",
    "Indiana": "IN",
    "Iowa": "IA",
    "Kansas": "KS",
    "Kentucky": "KY",
    "Louisiana": "LA",
    "Maine": "ME",
    "Maryland": "MD",
    "Massachusetts": "MA",
    "Michigan": "MI",
    "Minnesota": "MN",
    "Mississippi": "MS",
    "Missouri": "MO",
    "Montana": "MT",
    "Nebraska": "NE",
    "Nevada": "NV",
    "New Hampshire": "NH",
    "New Jersey": "NJ",
    "New Mexico": "NM",
    "New York": "NY",
    "North Carolina": "NC",
    "North Dakota": "ND",
    "Northern Mariana Islands": "MP",
    "Ohio": "OH",
    "Oklahoma": "OK",
    "Oregon": "OR",
    "Pennsylvania": "PA",
    "Puerto Rico": "PR",
    "Rhode Island": "RI",
    "South Carolina": "SC",
    "South Dakota": "SD",
    "Tennessee": "TN",
    "Texas": "TX",
    "Utah": "UT",
    "Vermont": "VT",
    "Virgin Islands": "VI",
    "Virginia": "VA",
    "Washington": "WA",
    "West Virginia": "WV",
    "Wisconsin": "WI",
    "Wyoming": "WY",
}

# Reverse lookup: postal code -> state/territory name (codes are unique).
codes_to_states = {v: k for k, v in states_to_codes.items()}


In [117]:
# US CSV snapshots (state and county level) are read from FOLDER_US; the
# commented URLs record the upstream sources (JHU CSSE time series, COVID
# Tracking Project tests, GovEx vaccine timeline).
FOLDER_US = "../data/USA/"
# BASE_URL_US = "https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/csse_covid_19_data/csse_covid_19_time_series/"
# URL_TESTS_US = "https://raw.githubusercontent.com/COVID19Tracking/covid-tracking-data/master/data/states_daily_4pm_et.csv"
# URL_TESTS_US = "https://api.covidtracking.com/v1/states/daily.csv"
# URL_US_States_vaccine = 'https://raw.githubusercontent.com/govex/COVID-19/master/\
# data_tables/vaccine_data/raw_data/vaccine_data_us_state_timeline.csv'

def collect_US_data(FOLDER_US):
    """Load the US state- and county-level CSV snapshots.

    Args:
        FOLDER_US: directory containing ``<Data>_States.csv`` and
            ``<Data>_Counties.csv`` for Data in Cases/Deaths/Tests/Vaccine,
            plus ``pops_US_States.csv``, ``pops_US_counties.csv`` and
            ``counties_id.csv`` (with a ``FIPS`` column). The trailing path
            separator is optional.

    Returns:
        ``(datasets_US_States, datasets_US_counties)`` — one ``DataProcessor``
        per geographic level.
    """
    data_names = ["Cases", "Deaths", "Tests", "Vaccine"]

    def read_table(file_name):
        # os.path.join tolerates FOLDER_US with or without a trailing separator.
        return pd.read_csv(os.path.join(FOLDER_US, file_name), index_col=0)

    dict_df_US_States = {name: read_table(name + "_States.csv") for name in data_names}
    dict_df_US_counties = {
        name: read_table(name + "_Counties.csv") for name in data_names
    }
    pops_US_States = read_table("pops_US_States.csv")
    pops_US_counties = read_table("pops_US_counties.csv")
    # Index label of counties_id.csv (presumably the county key) -> integer
    # FIPS code; only the FIPS column is converted to int.
    counties_to_ID = read_table("counties_id.csv")["FIPS"].astype(int).to_dict()

    datasets_US_States = DataProcessor(dict_df_US_States, pops_US_States, "US States")
    datasets_US_counties = DataProcessor(
        dict_df_US_counties, pops_US_counties, "US Counties", counties_to_ID
    )
    return datasets_US_States, datasets_US_counties


In [118]:
# Tooltip text for the type-picker buttons and checkboxes.
# In multi-entry lists, entry 0 describes the whole control and the remaining
# entries describe each option in order. Every string carries its own single
# quotes because it is spliced unquoted into HTML attributes (e.g. the Moving
# Average checkbox builds `<tag title=...>`). The text itself must therefore
# never contain an ASCII apostrophe, which would close the attribute early:
# use the typographic ’ instead.
buttons_tooltip = {
    "Data": [
        "'Choose one of the following data category : Cases, Deaths, Vaccine or Tests'",
        "'Cases'",
        "'Deaths'",
        "'Vaccine doses administered'",
        "'Tests'",
    ],
    "Norm": [
        "'Use raw numbers or normalize data by population'",
        "'Raw numbers'",
        "'Divide values by population and multiply by 1M'",
    ],
    "Type": [
        "'Choose cumulative number (Total) or today’s value minus yesterday’s (daily change) or percentage change between today’s and yesterday’s value (daily % change)'",
        "'Cumulative number'",
        "'today’s value minus yesterday’s'",
        "'percentage change between today’s and yesterday’s value'",
    ],
    "Scale": [
        "'Use linear or logarithmic scale on the graph'",
        "'Linear'",
        "'Logarithmic: eases comparison for large numbers'",
    ],
    "Moving average": [
        "'Click to smooth data by applying a moving average. You can select the window size after selecting this option.'"
    ],
    "Maps": ["'Select a map'", "'World map'", "'US States map'", "'US Counties map'"],
    "Scale by max": [
        "'Divide each row by its maximum. It allows to detect the maximum value of each entity. Without normalizing the values of all countries/states are compared, ranked and coloured'"
    ],
    "Last data point": ["'Display only the most recent value of each selected entity'"],
    "Calendar time": [
        "'Display time series data without any rebasing rule. You can select a period of time using the bottom left range slider.'"
    ],
    "Threshold": [
        "'Threshold value applied to the rebasing data. (Example: plot deaths starting from the day the selected country/state exceeded 1000 cases)'"
    ],
    "Threshold data": [
        "'Data to use for the rebasing rule'",
        "'Cases'",
        "'Deaths'",
        "'Vaccine doses administered'",
        "'Tests'",
    ],
    "Threshold norm": [
        "'Normalization to apply for the rebasing rule'",
        "'Raw numbers'",
        "'Divide values by population and multiply by 1M'",
    ],
}


In [119]:
# Per-tab help text keyed by tab label ("Tab 1" .. "Tab 4"). Values are HTML
# fragments whose lines are separated with <br>. NOTE(review): not used in this
# cell; presumably passed to help_button() for each tab — confirm.
help_tooltip = {
    "Tab 1": """- Select a unique combination of data by using the type picker (4 sets of buttons on the top left: Data, Norm, Type, Scale + Moving average).<br>
It will color the selected map and change the graph/tables according to your choice.<br>
- You can hover on the map to display a statistic table.<br>
- Click on a country to show its data on the right-hand side graph.<br>
- Navigate between table and graph view using the tabs on the right-hand side.<br>
- Zoom/pan on the map using the +/- buttons or click and drag.<br>
- Select a subset of curves in the right-hand side graph by clicking on the figure legend.<br>
- Select a date for the map by using the slider on the right-hand side.<br>
- If you need more information hover on buttons to show a description or go to user guide tab.""",
    "Tab 2": """- Select a unique combination of data by using the type picker (4 sets of buttons on the top left: Data, Norm, Type, Scale + Moving Average).<br>
- Select country/states of interest by using the list picker or map picker (single click on a country/state) or predefined selections of countries.<br>
- Display time series data by checking the 'Calendar Time' checkbox. You can select a period of time using the bottom left range slider.<br>
- Choose a rebasing rule by unchecking the 'Calendar Time' checkbox. Select the threshold value applied to the rebasing data by using the 'Threshold slider'.<br>
  Select which data and normalization to use for the rebasing rule.<br>
        Example: plot deaths/1M Population starting from the day the selected country/state exceeded 1000 cases.<br>
        Data = Deaths, Norm = Per million, Type = Total, Threshold = 1000, Threshold Data = Cases, Threshold Norm = Values.<br>
- Use one of the 4 predefined scenarios to display an example.<br>
- If you need more information hover on buttons to show a description or go to user guide tab.""",
    "Tab 3": """- Select a unique combination of data by using the type picker (3 sets of buttons on the top left: Data, Norm, Type + Moving Average).<br>
- Select country/states of interest by using the list picker or map picker (single click on a country/state) or predefined selections of countries.<br>
- Divide each row by its maximum. It allows to detect the maximum value of each entity. Without normalizing the values of all countries/states are compared, ranked and coloured.<br>
- Hover on the figure's cells to get the corresponding value.<br>
- If you need more information hover on buttons to show a description or go to the user guide tab.""",
    "Tab 4": """- For each axis, select a unique combination of data by using the type picker (4 sets of buttons on the top left: Data, Norm, Type, Scale).<br>
- Show only the most recent value of each selected entity by checking the 'Last Data Point' checkbox.<br>
- Moving average is applied to both axes.<br>
- Select country/states of interest by using the list picker or map picker (single click on a country/state) or predefined selections of countries.<br>
- Zoom/pan on the figure using your mouse's wheel or click and drag.<br>
- Use one of the 4 predefined scenarios to display an example.<br>
- If you need more information hover on buttons to show a description or go to user guide tab.""",
}


In [120]:
# Buttons and interactions
# CSS emitted by every Toggle_Buttons group is accumulated in global_style.
global_style = ""

min_button_width_data = "100px"
min_button_width_norm = "85px"
min_button_width_type = "100px"
min_button_width_scale = "60px"
min_description_width_1 = "40px"


def _tab1_toggle_buttons(options, description, min_button_width, tooltip_position):
    """Create a tab-1 Toggle_Buttons group with its first option selected.

    Tooltips come from ``buttons_tooltip[description]``: entry 0 describes the
    whole group and the remaining entries describe each option in order.
    """
    tips = buttons_tooltip[description]
    return Toggle_Buttons(
        options=options,
        value=options[0],
        description=description,
        min_button_width=min_button_width,
        min_description_width=min_description_width_1,
        tooltips=tips[1:],
        description_tooltip=tips[0],
        key="tab1",
        tooltip_position=tooltip_position,
    )


# Data category used to colour the map
data_buttons = _tab1_toggle_buttons(
    ["Cases", "Deaths", "Vaccine", "Tests"], "Data", min_button_width_data, "bottom"
)
global_style += data_buttons.all_styles
# Raw values or per-million normalisation
norm_buttons = _tab1_toggle_buttons(
    ["Values", "Per million"], "Norm", min_button_width_norm, "bottom"
)
global_style += norm_buttons.all_styles
# Linear or logarithmic scale for plots and map
scale_buttons = _tab1_toggle_buttons(
    ["Linear", "Log"], "Scale", min_button_width_scale, "top"
)
global_style += scale_buttons.all_styles
# Cumulative total, daily change or daily percentage change
type_buttons = _tab1_toggle_buttons(
    ["Total", "Daily change", "Daily % change"], "Type", min_button_width_type, "top"
)
global_style += type_buttons.all_styles

# Moving-average switch; its tooltip becomes the title of a <tag> wrapper.
tab_1_ma_ch = Checkbox(
    description="<tag title={}>Moving Average</tag>".format(
        buttons_tooltip["Moving average"][0]
    ),
    value=False,
    style={"description_width": "initial"},
    layout=Layout(min_width="200px", width="200px", overflow="hidden"),
    indent=False,
)
# Smoothing window (1-14 days, default 7). Starts hidden — presumably revealed
# when the Moving Average box is ticked.
tab_1_ma_w = BoundedIntText(
    value=7,
    min=1,
    max=14,
    description="Window size (days)",
    style={"description_width": "initial"},
    layout=Layout(
        min_width="200px",
        width="200px",
        overflow="auto",
        visibility="hidden",
        margin="0 0 0 0",
    ),
)
ma_box_tab_1 = HBox(
    [tab_1_ma_ch, tab_1_ma_w],
    layout=Layout(
        min_width="565px", width="565px", max_width="565px", justify_content="space-between"
    ),
)

# Fixed 565px-wide panel: the four toggle groups (138px block) stacked above
# the moving-average row, 170px tall overall.
cat_tab_1_buttons = VBox(
    [
        VBox(
            [data_buttons, norm_buttons, type_buttons, scale_buttons],
            layout=Layout(
                width="565px", height="138px",
                min_width="565px", max_width="565px",
                min_height="138px", max_height="138px",
                overflow="visible",
            ),
        ),
        ma_box_tab_1,
    ],
    layout=Layout(
        width="565px", height="170px",
        min_width="565px", max_width="565px",
        min_height="170px", max_height="170px",
        overflow="visible",
    ),
)


In [121]:
# Map picker (World / US States / US Counties), World shown first. The group
# has no label, so no description width is reserved.
# NOTE(review): buttons_tooltip["Maps"] exists but is not passed here — confirm
# whether the map buttons should get tooltips too.
map_buttons = Toggle_Buttons(
    options=["World", "US States", "US Counties"],
    value="World",
    description="",
    min_button_width=min_button_width_data,  # same width as the Data buttons
    min_description_width="",
    style="warning",
    key="tab1",
)
global_style = global_style + map_buttons.all_styles


In [122]:
# Main graph

# CSS for the hover table (#stats-table) used by line_tooltip_table below.
# NOTE(review): not used in this cell; presumably injected into the page elsewhere.
css_style = """
<style>
    #stats-table {
        border-collapse: collapse;
        width: 200px;
    }
    #stats-table tr{
        margin: 0px;
        padding-left: 3px;
        width: 50px;
        line-height: 1.5em;
    }
    #stats-table td{
        margin: 0px;
        padding-left: 2px;
    }
</style>
"""

# str.format template for the main graph's hover table: {0} is the date;
# {1}/{3}/{5}/{7} are row labels (format spec "s", so they must be str) and
# {2}/{4}/{6}/{8} their values. Row colours match the four series' curve
# colours (#1E90FF is dodgerblue, #FFBB0E the Tests colour).
line_tooltip_table = """
<table id="stats-table">
<tr>
    <td>Date</td>
    <td>{0:}</td>
</tr>
<tr style="color:#1E90FF">
    <td>{1:s}</td>
    <td>{2:}</td>
</tr>
<tr style="color:#D62728">
    <td>{3:s}</td>
    <td>{4:}</td>
</tr>
<tr style="color:#2CA02C">
    <td>{5:s}</td>
    <td>{6:}</td>
</tr>
<tr style="color:#FFBB0E">
    <td>{7:s}</td>
    <td>{8:}</td>
</tr>
</table>
"""

stats_line_tooltip_table_values = HTML()  # table shown when you hover on Lines markers


In [123]:
# Time (x) axis, ticks labelled month/day. update_main_graph later replaces
# the scale with one bounded by the selected dates.
scale_t = DateScale(dateformat=dateformat)
axis_t = Axis(scale=scale_t, grid_lines="none", tick_format="%m/%d", num_ticks=10)

# Value (y) axis: linear by default, thousands separators in tick labels.
scale_y = LinearScale()
axis_y = Axis(
    scale=scale_y,
    orientation="vertical",
    grid_lines="solid",
    tick_format=",",
)

# Title template with two placeholders (presumably filled in on update).
main_graph_ttl = "{} in {}"

main_graph = Figure(
    animation_duration=1000,
    legend_location="top-left",
    axes=[axis_t, axis_y],
    title=main_graph_ttl,
    fig_margin=dict(top=50, bottom=50, left=65, right=20),
    layout=Layout(width="auto"),
    legend_style={"stroke-width": 0},
)


In [124]:
# Line colour of each data series on the main graph; key order is also the
# order in which the series are plotted.
curve_colors = dict(
    Cases="dodgerblue",
    Deaths="red",
    Vaccine="green",
    Tests="#FFBB0E",
)
# Series shown on the main graph, in legend order (a fresh list, not a view).
selected = list(curve_colors)


def update_main_graph(
    datasets,
    country_name,
    normalization,
    scale,
    data_type,
    ma,
    n,
    date1,
    date2,
    axis_t,
    axis_y,
    figure,
    init=False,
):
    """Redraw ``figure`` with the time series of one region.

    One curve is built per entry of ``selected`` from
    ``datasets.get_ts_plot``. The axes are rescaled to the ``[date1, date2]``
    window, and the figure title is set from the current settings.

    Parameters
    ----------
    datasets : data source exposing ``get_ts_plot(...)`` and ``name``
        (``"World"``, ``"US States"`` or ``"US Counties"``).
    country_name : region key forwarded to ``get_ts_plot``. For the World
        dataset it is looked up in ``codes_to_countries`` for the title.
    normalization, scale, data_type, ma, n : current UI settings forwarded to
        ``get_ts_plot``. ``scale == "Log"`` selects a log y axis; any other
        value selects a linear one.
    date1, date2 : str
        Visible window bounds, parsed with the module-level ``dateformat``.
    axis_t, axis_y : bqplot ``Axis`` objects whose scales are replaced.
    figure : bqplot ``Figure`` to redraw.
    init : bool
        First draw: show every curve except the last one. Otherwise keep the
        per-curve visibility (opacity) currently shown on ``figure``.
    """

    def graph_title(norm, scale, data_type, ma, n):
        # Short description of the plotted quantity, used in the figure title
        res = "Data"
        if norm == "Per million":
            res += "/1M"
        if data_type == "Daily change":
            res += " daily change"
        elif data_type == "Daily % change":
            res += " daily % change"
        if ma:
            res += " ({0:d}days m.a)".format(n)
        if scale == "Log":
            res += "(log scale)"
        return res

    if init:
        curves_subset = list(range(len(selected) - 1))
        opacities = [1] * (len(selected) - 1) + [0]
    else:
        # Preserve the legend toggles of the figure being redrawn
        opacities = list(figure.marks[0].opacities)
        curves_subset = [i for i in range(len(selected)) if opacities[i] == 1]
    line_col = [curve_colors[c] for c in selected]
    data = [
        datasets.get_ts_plot(
            country_name, data_name, normalization, scale, data_type, ma, n, STDT, ENDT,
        )
        for data_name in selected
    ]
    scale_t = DateScale(
        dateformat=dateformat,
        min=datetime.datetime.strptime(date1, dateformat),
        max=datetime.datetime.strptime(date2, dateformat),
    )
    axis_t.scale = scale_t

    # Unknown scale names fall back to linear instead of leaving scale_y unbound
    scale_y = LogScale() if scale == "Log" else LinearScale()

    main_mark = Lines(
        scales={"x": scale_t, "y": scale_y},
        display_legend=True,
        labels=selected,
        colors=line_col,
        marker="circle",
        marker_size=30,
        tooltip=stats_line_tooltip_table_values,
        interactions={"hover": "tooltip"},
        opacities=opacities,
    )
    main_mark.y = data
    axis_y.scale = scale_y
    figure.axes = [axis_t, axis_y]
    main_mark.x = list(
        np.arange(np.datetime64(STDT), np.datetime64(ENDT) + np.timedelta64(1, "D"),)
    )
    ind1 = np.where(main_mark.x == np.datetime64(date1))[0][0]
    ind2 = np.where(main_mark.x == np.datetime64(date2))[0][0]
    # Fit y to the visible curves within the window. When every curve is
    # hidden the window is empty and nanmax/nanmin would raise; keep the
    # scale's default domain instead.
    window = np.array(main_mark.y)[curves_subset, ind1 : ind2 + 1]
    if window.size:
        main_mark.scales["y"].max = 1.1 * float(np.nanmax(window))
        main_mark.scales["y"].min = 0.9 * float(np.nanmin(window))
    main_mark.on_hover(line_hover)
    main_mark.on_legend_click(legend_click_line)
    figure.marks = [main_mark]

    val_name = graph_title(normalization, scale, data_type, ma, n)
    if datasets.name == "World":
        figure.title = main_graph_ttl.format(val_name, codes_to_countries[country_name])
    elif datasets.name == "US States":
        figure.title = main_graph_ttl.format(val_name, country_name + " State")
    elif datasets.name == "US Counties":
        figure.title = main_graph_ttl.format(val_name, country_name + " County")

    update_link_download()
    return


In [125]:
def name_val_on_hover(data, norm, data_type, ma, n):
    """Return the display label of one series (tooltips and table headers).

    The label is ``data`` followed by:
    - "/1M" when ``norm`` is "Per million";
    - a daily-change suffix for the "Daily change" / "Daily % change" types;
    - the moving-average window ``n`` when ``ma`` is truthy.
    """
    per_million = "/1M" if norm == "Per million" else ""
    if data_type == "Daily change":
        change = " daily change"
    elif data_type == "Daily % change":
        change = " daily % change"
    else:
        change = ""
    window = " ({0:d}days m.a)".format(n) if ma else ""
    return data + per_million + change + window


def line_hover(*args):
    """Fill the main-graph tooltip for the hovered marker.

    Registered with ``Lines.on_hover``. ``args[1]["data"]["sub_index"]`` is
    the position of the hovered point along the x axis. The tooltip shows the
    date plus one (label, value) row per curve in ``selected``.
    """
    try:
        index = args[1]["data"]["sub_index"]
        mark = main_graph.marks[0]
        date = str(mark.x[index])
        values = [mark.y[i][index] for i in range(len(selected))]
        labels = [
            name_val_on_hover(
                data,
                norm_buttons.value,
                type_buttons.value,
                tab_1_ma_ch.value,
                tab_1_ma_w.value,
            )
            for data in selected
        ]
        # Raw counts are shown as integers; normalised values and % changes
        # keep two decimals.
        if norm_buttons.value == "Values" and type_buttons.value != "Daily % change":
            value_fmt = "{0:,.0f}"
        else:
            value_fmt = "{0:,.2f}"
        cells = [date]
        for label, value in zip(labels, values):
            cells.extend([label, value_fmt.format(value)])
        stats_line_tooltip_table_values.value = css_style + line_tooltip_table.format(
            *cells
        )
    except Exception:
        # Best-effort: on any failure the previous tooltip content is kept
        pass
    return


def legend_click_line(*args):
    """Toggle one curve's visibility from its legend entry and refit the y axis.

    ``args[1]["data"]["index"]`` is the clicked curve. Its opacity flips
    between 0 and 1. The y range is then refitted (0.9 x min, 1.1 x max) to
    the visible curves over the period chosen on ``date_axis_selector``.
    """
    index = args[1]["data"]["index"]
    ind1, ind2 = date_axis_selector.index
    mark = main_graph.marks[0]
    mark.opacities = [
        1 - opacity if position == index else opacity
        for position, opacity in enumerate(mark.opacities)
    ]
    curves_subset = [i for i in range(len(selected)) if mark.opacities[i] == 1]
    if not curves_subset:
        # Every curve is hidden: nothing to fit, keep the current y range
        return
    try:
        window = np.array(mark.y)[curves_subset, ind1 : ind2 + 1]
        mark.scales["y"].max = 1.1 * float(np.nanmax(window))
        mark.scales["y"].min = 0.9 * float(np.nanmin(window))
    except Exception:
        # Best-effort: keep the current y range if the data cannot be fitted
        pass
    return


In [126]:
# Maps

def map_title(map_name, data, norm, data_type, date, ma, n):
    """Build the heading shown above the choropleth.

    Format: "<map_name> COVID-19 <data>[ Per million][ daily (%) change]
    [ (<n>days m.a)] Map, <date>".
    """
    parts = [map_name, " COVID-19 ", data]
    if norm == "Per million":
        parts.extend([" ", norm])
    if data_type == "Daily change":
        parts.append(" daily change")
    elif data_type == "Daily % change":
        parts.append(" daily % change")
    if ma:
        parts.append(" ({0:d}days m.a)".format(n))
    parts.extend([" Map, ", date])
    return "".join(parts)


In [127]:
# Map hover table used when the selected quantity differs from the raw totals.
# Positional fields: {0} region name, {1} label of the selected quantity,
# {2} its value, {3}-{6} total Cases/Deaths/Vaccine/Tests, {7} population.
table_tmpl_no_duplicate = """
<table id="stats-table">
<thead>
<tr>
<th colspan="2">{0:s}</th>
</tr>
</thead>
<tbody>
<tr>
    <td>{1:s}</td>
    <td>{2:,}</td>
</tr>
<tr style="color:#1E90FF">
    <td>Cases</td>
    <td>{3:,.0f}</td>
</tr>
<tr style="color:#D62728">
    <td>Deaths</td>
    <td>{4:,.0f}</td>
</tr>
<tr style="color:#2CA02C">
    <td>Vaccine</td>
    <td>{5:,.0f}</td>
</tr>
<tr style="color:#FFBB0E">
    <td>Tests</td>
    <td>{6:,.0f}</td>
</tr>
<tr>
    <td>Population</td>
    <td>{7:,.0f}</td>
</tr>
</tbody>
</table>
"""

# Map hover table used when the selected quantity is the raw total (so it would
# duplicate one of the rows). Positional fields: {0} region name,
# {1}-{4} total Cases/Deaths/Vaccine/Tests, {5} population.
table_tmpl_duplicate = """
<table id="stats-table">
<thead>
<tr>
<th colspan="2">{0:s}</th>
</tr>
</thead>
<tbody>
<tr style="color:#1E90FF">
    <td>Cases</td>
    <td>{1:,.0f}</td>
</tr>
<tr style="color:#D62728">
    <td>Deaths</td>
    <td>{2:,.0f}</td>
</tr>
<tr style="color:#2CA02C">
    <td>Vaccine</td>
    <td>{3:,.0f}</td>
</tr>
<tr style="color:#FFBB0E">
    <td>Tests</td>
    <td>{4:,.0f}</td>
</tr>
<tr>
    <td>Population</td>
    <td>{5:,.0f}</td>
</tr>
</tbody>
</table>
"""
# Widget holding the rendered map hover table (filled by map_hovering)
stats_table = HTML()


In [128]:
def current_val_on_hover(data, norm, scale, data_type, ma, n):
    """Label of the selected quantity for the map hover table, or None.

    Returns None for raw totals ("Total" with "Values"), since those are
    already listed in the table. Otherwise ``data`` is decorated like the
    line-chart labels. ``scale`` is accepted but not used.
    """
    if norm == "Values" and data_type == "Total":
        return None
    label = data + ("/1M" if norm == "Per million" else "")
    if data_type == "Daily change":
        label += " daily change"
    elif data_type == "Daily % change":
        label += " daily % change"
    if ma:
        label += " ({0:d}days m.a)".format(n)
    return label


In [129]:
# Eight-step choropleth palettes, from the lowest bin to the highest
red_color_scheme = "#FFF0CC #ffcf76 #ffa600 #ff6e00 #ff4417 #d31522 #8c000e #560410".split()
green_color_scheme = "#e9f6e4 #d1edc6 #b9e3a8 #a1da8a #89d06d #71c74f #5db53a #4e9730".split()
gyr_color_scheme = "#069740 #9acd32 #e4d31b #ffa600 #e95f11 #d31522 #8c000e #560410".split()

# Palette used for each data series (the same list objects are shared)
map_color_scheme = {
    **dict.fromkeys(("Cases", "Deaths"), red_color_scheme),
    **dict.fromkeys(("Vaccine", "Tests"), green_color_scheme),
}


def map_cuts(arr, cuts):
    """Map each value of ``arr`` to the lower edge of its bin in sorted ``cuts``.

    NaN values stay NaN. Values below ``cuts[0]`` are clamped to the first
    bin. Without the clamp, ``bisect_right(...) - 1`` would be -1 and the
    value would wrap around to ``cuts[-1]``, the highest bin.
    """

    def lower_edge(a):
        if np.isnan(a):
            return np.nan
        return cuts[max(bisect.bisect_right(cuts, a) - 1, 0)]

    return [lower_edge(a) for a in arr]


def cuts_to_str(cuts):
    """Format break values compactly for the map legend.

    - Magnitudes of at least one million get an "M" suffix.
    - Magnitudes of at least one thousand get a "K" suffix.
    - Whole numbers lose their decimal part.
    - Everything else is rounded to 3 decimals.
    """

    def fmt(v):
        magnitude = abs(v)
        for threshold, suffix in ((1e6, "M"), (1000, "K")):
            if magnitude >= threshold:
                scaled = v / threshold
                text = str(int(scaled)) if v % threshold == 0 else str(scaled)
                return text + suffix
        if magnitude >= 1 and v % 1 == 0:
            return str(int(v))
        return str(round(v, 3))

    return [fmt(v) for v in cuts]


def cuts_to_range(cuts, pct=False):
    """Turn formatted break labels into legend ranges.

    Consecutive labels become "lo - hi" and the last label becomes "hi+".
    Negative labels are wrapped in parentheses, and every bound gets a "%"
    suffix when ``pct`` is true.
    """
    unit = "%" if pct else ""
    wrapped = ["(" + c + ")" if c[0] == "-" else c for c in cuts]
    ranges = [lo + unit + " - " + hi + unit for lo, hi in zip(wrapped, wrapped[1:])]
    ranges.append(wrapped[-1] + unit + "+")
    return ranges


def map_color(dataset, map_id, data, norm, data_type, ma, n, date):
    """Recolour the choropleth of ``dataset`` for one date and rebuild the legend.

    Values come from ``dataset.get_ts(...)`` at row ``date``; for "Total",
    only positive values are kept. They are re-keyed to the layer's feature
    ids (state codes / county ids for US layers). Features in ``map_id``
    without data are added as NaN. Data-driven breaks from ``pretty_breaks``
    are stored in the module-level ``splits`` (read by the choropleth style
    callback), and the legend is rebuilt. If the breaks cannot be computed,
    the legend is cleared and only the data is updated.
    """
    global splits
    values = (
        dataset.get_ts(None, data, norm, "Linear", data_type, ma, n).loc[date].dropna()
    )
    if data_type == "Total":
        values = values[values > 0]
    if dataset.name == "US States":
        # Filter keys and values together so unknown states cannot shift the
        # remaining values onto the wrong state codes.
        values = pd.Series(
            {
                states_to_codes[c]: v
                for c, v in zip(values.index.values, values.values)
                if c in states_to_codes
            }
        )
    elif dataset.name == "US Counties":
        new_index = [dataset.get_ID(c) for c in values.index.values]
        values = pd.Series(dict(zip(new_index, values.values)))
    missing_val = list(set(map_id) - set(values.index.values))
    to_keep = list(set(map_id).intersection(set(values.index.values)))
    # Series.append was removed in pandas 2.0; pd.concat is the replacement
    values = pd.concat(
        [values[to_keep], pd.Series(np.nan, index=missing_val, dtype=float)]
    )

    selected_index = dict_ID_to_map[dataset.name]
    choropleth = dict_maps[selected_index]["map"]
    n_breaks = 7
    if values.nunique() < 8 and values.nunique() > 0:
        n_breaks = values.nunique() - 1
    elif values.shape[0] == 0:
        map_legend.legends = {}
        choropleth.choro_data = values.to_dict()
        return
    try:
        weights = [2, 1, 4] if dataset.name == "World" else [2, 5, 0]
        q_round_cuts = pretty_breaks.breaker(values.dropna(), n_breaks, weights=weights)
        splits = q_round_cuts
        new_colormap = cm.StepColormap(
            colors=map_color_scheme[data], index=q_round_cuts
        )
        pct = data_type == "Daily % change"
        cuts_str_format = cuts_to_range(cuts_to_str(q_round_cuts), pct)
        map_legend.legends = dict(zip(cuts_str_format, map_color_scheme[data]))
        choropleth.choro_data, choropleth.colormap = values.to_dict(), new_colormap
    except Exception:
        # Breaks could not be computed (e.g. too few distinct values): show the
        # data without a legend.
        map_legend.legends = {}
        choropleth.choro_data = values.to_dict()
    return


In [130]:
def select_date(obj):
    """Observer for the map date: recolour the map, retitle it and refresh the tables.

    ``obj["new"]`` is the newly selected date; the other settings are read
    from the tab-1 controls.
    """
    date = obj["new"]
    norm = norm_buttons.value
    scale = scale_buttons.value
    data_type = type_buttons.value
    data = data_buttons.value
    ma = tab_1_ma_ch.value
    n = tab_1_ma_w.value
    map_index = dict_ID_to_map[map_buttons.value]
    map_info = dict_maps[map_index]
    map_color(
        dict_datasets[map_index], map_info["map_id"], data, norm, data_type, ma, n, date
    )
    heading = map_title(map_info["name"], data, norm, data_type, date, ma, n)
    main_map_title.value = main_map_title_val.format(heading)
    update_tables(data, norm, scale, data_type, ma, n, date)


def map_click(feature, **kwargs):
    """Choropleth click handler: plot the clicked region on the main graph.

    World features are identified by their GeoJSON ``id``; US state/county
    features by ``properties["name"]``. The graph uses the current tab-1
    settings and the period chosen on ``date_axis_selector``.
    """
    try:
        if map_buttons.value == "World":
            name = feature["id"]
        else:
            name = feature["properties"]["name"]
        selected_index = dict_ID_to_map[map_buttons.value]
        datasets = dict_datasets[selected_index]
        normalization, scale, data_type = (
            norm_buttons.value,
            scale_buttons.value,
            type_buttons.value,
        )
        ma, n = tab_1_ma_ch.value, tab_1_ma_w.value
        # Slider labels are %m/%d/%y; update_main_graph expects `dateformat`
        date1, date2 = (
            datetime.datetime.strftime(
                datetime.datetime.strptime(date_axis_selector.options[i], "%m/%d/%y"),
                dateformat,
            )
            for i in date_axis_selector.index
        )
        update_main_graph(
            datasets,
            name,
            normalization,
            scale,
            data_type,
            ma,
            n,
            date1,
            date2,
            axis_t,
            axis_y,
            main_graph,
        )
    except Exception:
        # Best-effort: errors are swallowed so a click on a region without
        # usable data does not raise inside the map callback.
        pass
    return


def map_background_click(*args):
    """Show the world aggregate ("WLD") on the main graph.

    Uses the current tab-1 settings and the period chosen on
    ``date_axis_selector`` (labels formatted %m/%d/%y).
    """
    normalization = norm_buttons.value
    scale = scale_buttons.value
    data_type = type_buttons.value
    ma = tab_1_ma_ch.value
    n = tab_1_ma_w.value
    start, end = [
        datetime.datetime.strftime(
            datetime.datetime.strptime(date_axis_selector.options[pos], "%m/%d/%y"),
            dateformat,
        )
        for pos in date_axis_selector.index
    ]
    update_main_graph(
        dict_datasets[0],
        "WLD",
        normalization,
        scale,
        data_type,
        ma,
        n,
        start,
        end,
        axis_t,
        axis_y,
        main_graph,
    )


def map_hovering(feature, **kwargs):
    """Choropleth hover handler: fill ``stats_table`` for the hovered region.

    The table lists the totals of every series in ``selected`` plus the
    population. When the selected quantity is not a raw total, its value is
    shown in an extra first row. If the data lookup fails, a zeroed table
    with the population is shown. If even that fails, the table is cleared.
    """
    name = feature["properties"]["name"]
    date = date_selector.value
    selected_index = dict_ID_to_map[map_buttons.value]
    datasets = dict_datasets[selected_index]
    # The World layer is keyed by GeoJSON id, the US layers by feature name
    _id = feature["id"] if selected_index == 0 else name
    try:
        norm, scale, data_type = norm_buttons.value, "Linear", type_buttons.value
        ma, n = tab_1_ma_ch.value, tab_1_ma_w.value
        data = data_buttons.value
        current_val = current_val_on_hover(data, norm, scale, data_type, ma, n)
        totals = [
            datasets.get_value(_id, c, "Values", "Linear", "Total", ma, n, date)
            for c in selected
        ]
        population = datasets.get_population(_id)
        if current_val is None:
            stats_table.value = css_style + table_tmpl_duplicate.format(
                name, *totals, population
            )
        else:
            val = datasets.get_value(_id, data, norm, scale, data_type, ma, n, date)
            if not np.isnan(val):
                if data_type == "Daily change" and norm == "Values":
                    val = int(val)
                else:
                    val = round(val, 2)
            stats_table.value = css_style + table_tmpl_no_duplicate.format(
                name, current_val, val, *totals, population
            )
    except Exception:
        try:
            population = int(datasets.get_population(_id))
            # One zero per series so population lands in the Population row
            # (an extra zero previously pushed it out of the template).
            stats_table.value = css_style + table_tmpl_duplicate.format(
                name, *([0] * len(selected)), population
            )
        except Exception:
            stats_table.value = ""
    return


In [131]:
def event(*args):
    """Observer for the tab-1 controls: redraw the map, main graph and tables.

    ``args[0]`` is the traitlets change. A ToggleButton change only acts when
    its new value equals True; any other widget always triggers a redraw.
    The region currently plotted is recovered from the main graph title
    ("... in <name> State", "... in <name> County" or "... in <country>").
    """
    change = args[0]
    from_toggle = type(change["owner"]).__name__ == "ToggleButton"
    if from_toggle and not (change["new"] == True):
        return

    date = date_selector.value
    norm = norm_buttons.value
    scale = scale_buttons.value
    data_type = type_buttons.value
    data = data_buttons.value
    ma, n = tab_1_ma_ch.value, tab_1_ma_w.value
    tab_1_ma_w.layout.visibility = "visible" if ma else "hidden"

    map_index = dict_ID_to_map[map_buttons.value]
    map_color(
        dict_datasets[map_index],
        dict_maps[map_index]["map_id"],
        data,
        norm,
        data_type,
        ma,
        n,
        date,
    )
    main_map_title.value = main_map_title_val.format(
        map_title(dict_maps[map_index]["name"], data, norm, data_type, date, ma, n)
    )

    # Slider labels are %m/%d/%y; update_main_graph expects `dateformat`
    start, end = (
        datetime.datetime.strftime(
            datetime.datetime.strptime(date_axis_selector.options[pos], "%m/%d/%y"),
            dateformat,
        )
        for pos in date_axis_selector.index
    )

    title = main_graph.title
    region = title.split(" in ")[-1]
    last_word = title.split(" ")[-1]
    if last_word == "State":
        target, region_key = dict_datasets[1], region[:-6]  # drop " State"
    elif last_word == "County":
        target, region_key = dict_datasets[2], region[:-7]  # drop " County"
    else:
        target, region_key = dict_datasets[0], countries_to_codes[region]
    update_main_graph(
        target,
        region_key,
        norm,
        scale,
        data_type,
        ma,
        n,
        start,
        end,
        axis_t,
        axis_y,
        main_graph,
    )
    update_tables(data, norm, scale, data_type, ma, n, date)


In [132]:
# Animation driver for the map date: every `interval` ms the position
# advances by `step`; update_index_date maps it to a date_selector option.
play = Play(
    value=0,
    min=0,
    max=1,
    step=14,
    interval=2000,
    description="Press play",
    disabled=False,
    layout=Layout(width="auto", height="auto", min_width="150px"),
)


def update_index_date(*args):
    """Move ``date_selector`` to the option at the new play position.

    Presumably registered as an observer of ``play.value``; ``args[0]["new"]``
    is used as an index into ``date_selector.options``.
    """
    position = args[0]["new"]
    date_selector.value = date_selector.options[position]


In [133]:
# Date shown on the infection map. The two options are placeholders; the real
# date list is presumably assigned once the data is loaded (not in this cell).
date_selector = SelectionSlider(
    options=["2020-01-22", "2020-01-23"],
    description="Date",
    disabled=False,
    value="2020-01-22",
    layout=Layout(width="100%", height="auto", min_width="200px"),
    style={"description_width": "initial"},
    description_tooltip='Select a date',
)

In [134]:
# Row holding the map date slider and its play button (fixed 40px height)
date_tab_1_box = HBox(
    [date_selector, play],
    layout=Layout(
        width="100%",
        min_width="200px",
        height="40px",
        min_height="40px",
        max_height="40px",
        overflow="auto",
    ),
)


In [135]:
# Period shown on the main graph. Callbacks parse its option labels with
# "%m/%d/%y"; the single option here is a placeholder replaced elsewhere.
date_axis_selector = SelectionRangeSlider(
    options=["2020-01-22"], index=(0, 0), layout=Layout(width="100%", height="auto"),
)


def update_zoom(*args):
    """Zoom the main graph to the period chosen on ``date_axis_selector``.

    ``args[0]["new"]`` holds the (start, end) option labels formatted as
    ``%m/%d/%y``. The x limits are set to those dates. The y limits are then
    refitted (0.9 x min, 1.1 x max) to the visible curves (opacity 1) within
    the selector's index range.
    """
    min_x, max_x = args[0]["new"]

    def to_datetime64(label):
        # Slider label (%m/%d/%y) -> module-level `dateformat` -> datetime64
        return np.datetime64(
            datetime.datetime.strftime(
                datetime.datetime.strptime(label, "%m/%d/%y"), dateformat
            )
        )

    main_graph.axes[0].scale.min = to_datetime64(min_x)
    main_graph.axes[0].scale.max = to_datetime64(max_x)

    mark = main_graph.marks[0]
    curves_subset = [i for i in range(len(mark.y)) if mark.opacities[i] == 1]
    if not curves_subset:
        # Every curve is hidden: nanmin/nanmax would raise on an empty window
        return
    start, stop = date_axis_selector.index
    window = np.asarray(mark.y)[curves_subset, start : stop + 1]
    main_graph.axes[1].scale.min = 0.9 * float(np.nanmin(window))
    main_graph.axes[1].scale.max = 1.1 * float(np.nanmax(window))
    return


In [136]:
# Any change of the tab-1 controls redraws map, graph and tables via `event`.
# Toggle_Buttons groups expose add_observe; plain widgets use observe.
for _button_group in (data_buttons, norm_buttons, scale_buttons, type_buttons):
    _button_group.add_observe(event, "value")
for _ma_widget in (tab_1_ma_ch, tab_1_ma_w):
    _ma_widget.observe(event, "value")


In [137]:
def _load_json(path):
    """Parse the JSON file at ``path`` and close the file handle.

    GeoJSON is UTF-8 by specification (RFC 7946), so the encoding is given
    explicitly instead of depending on the platform's locale default.
    """
    with open(path, encoding="utf-8") as handle:
        return json.load(handle)


countries_geojson = _load_json("../data/WORLD/countries.geo.json")
us_states_geojson = _load_json("../data/USA/us-states.json")
us_counties_geojson = _load_json("../data/USA/us-counties.json")
# HTML wrapper for the title above the map; {} receives the map_title(...) text
main_map_title_val = "<p style='text-align:center; font-size:16px; font-weight:bold;\
color: white; width: 100%; overflow: auto; margin: 0px 0px 3px 0px'> {} </p>"
main_map_title = HTML()
main_map = ipl.Map(
    basemap={},
    center=[24.68, 4.39],
    zoom=2,
    zoom_snap=0.5,
    zoom_delta=0.5,
    layout=Layout(width="99%", height="99%", min_width="400px"),
)

# Feature keys of each layer; map_color fills features without data with NaN
countries_map_id = [feature["id"] for feature in countries_geojson["features"]]
us_counties_map_id = [feature["id"] for feature in us_counties_geojson["features"]]
us_counties_map_name = [
    feature["properties"]["name"] for feature in us_counties_geojson["features"]
]
us_states_map_id = [feature["id"] for feature in us_states_geojson["features"]]

# Placeholder bin edges and palette; map_color replaces the module-level
# `splits` with data-driven breaks.
splits = list(range(8))
colormap = cm.StepColormap(colors=list(red_color_scheme), index=splits)

map_legend = ipl.LegendControl(
    {
        "0-1": "#FFF0CC",
        "1-5": "#ffcf76",
        "5-10": "#ffa600",
        "10-20": "#ff6e00",
        "20-40": "#ff4417",
        "40-70": "#d31522",
        "70-100": "#8c000e",
        "100+": "#560410",
    },
    position="bottomright",
    name="",
)

main_map.add_control(map_legend)
main_map.add_control(ipl.FullScreenControl())

# Search control: update_map keeps its url/zoom in sync with the selected
# layer, but it is currently not attached to the map.
sc = ipl.SearchControl(
    position="topleft",
    url="https://nominatim.openstreetmap.org/search?format=json&country={s}",
    zoom=6,
)
#main_map.add_control(sc)

widget_control1 = ipl.WidgetControl(widget=map_buttons, position="topright")
main_map.add_control(widget_control1)

widget_control2 = ipl.WidgetControl(widget=stats_table, position="bottomleft")
main_map.add_control(widget_control2)


def getColor(value, colors, splits):
    """Return the fill colour of ``value`` given sorted bin edges ``splits``.

    NaN maps to grey ("#5a5a5a"). Otherwise the colour of the bin whose lower
    edge is the largest split <= value is returned. Values below ``splits[0]``
    use the first colour; without this they would wrap to ``colors[-1]``, the
    darkest colour. Bins beyond the palette use the last colour instead of
    raising IndexError.
    """
    if np.isnan(value):
        return "#5a5a5a"
    bin_index = bisect.bisect_right(splits, value) - 1
    return colors[min(max(bin_index, 0), len(colors) - 1)]


def style(feature, colormap, choro_data):
    """Choropleth ``style_callback``: fill colour from the module-level breaks.

    ``feature`` and ``colormap`` are ignored. The colour is picked with
    getColor from the palette of the currently selected data series and the
    global ``splits``.
    """
    palette = map_color_scheme[data_buttons.value]
    return {
        "fillColor": getColor(choro_data, palette, splits),
        "fillOpacity": 1,
        "color": "black",
        "weight": 0.7,
    }


def _empty_choropleth(geo_data, feature_ids):
    """Choropleth over ``geo_data`` with every feature's value initialised to NaN."""
    return ipl.Choropleth(
        geo_data=geo_data,
        choro_data=dict.fromkeys(feature_ids, np.nan),
        colormap=colormap,
        hover_style={"color": "white", "weight": 4},
        style_callback=style,
    )


chl_wm = _empty_choropleth(countries_geojson, countries_map_id)
chl_us_states = _empty_choropleth(us_states_geojson, us_states_map_id)
chl_us_counties = _empty_choropleth(us_counties_geojson, us_counties_map_id)

# The world layer is shown first; update_map swaps layers on selection
main_map.add_layer(chl_wm)

for _layer in (chl_wm, chl_us_states, chl_us_counties):
    _layer.on_hover(map_hovering)
    _layer.on_click(map_click)


In [138]:
# Per-layer configuration, indexed 0/1/2 in the order of map_buttons' options:
# choropleth layer, display name, initial view, search-control settings and
# the feature ids the layer is keyed by.
dict_maps = dict(
    enumerate(
        [
            {
                "map": chl_wm,
                "name": "World",
                "center": [24.68, 4.39],
                "map_zoom": 2,
                "url": "https://nominatim.openstreetmap.org/search?format=json&country={s}",
                "search_zoom": 6,
                "map_id": countries_map_id,
            },
            {
                "map": chl_us_states,
                "name": "US States",
                "center": [37.57, -96.67],
                "map_zoom": 4,
                "url": "https://nominatim.openstreetmap.org/search?format=json&country=United States of America&state={s}",
                "search_zoom": 7,
                "map_id": us_states_map_id,
            },
            {
                "map": chl_us_counties,
                "name": "US Counties",
                "center": [37.57, -96.67],
                "map_zoom": 4,
                "url": "https://nominatim.openstreetmap.org/search?format=json&q={s} county",
                "search_zoom": 10,
                "map_id": us_counties_map_id,
            },
        ]
    )
)
# Reverse lookup: display name ("World", "US States", "US Counties") -> index
dict_ID_to_map = {info["name"]: index for index, info in dict_maps.items()}


In [139]:
def update_map(*args):
    """Observer for ``map_buttons``: switch the displayed choropleth layer.

    A ToggleButton change only acts when its new value equals True. The layer
    at ``main_map.layers[1]`` is replaced, the view and search control are
    reset, and the new layer is recoloured and retitled.
    """
    change = args[0]
    if type(change["owner"]).__name__ == "ToggleButton" and not (change["new"] == True):
        return

    map_name = map_buttons.value
    ID = dict_ID_to_map[map_name]
    config = dict_maps[ID]
    main_map.substitute_layer(main_map.layers[1], config["map"])
    main_map.center = config["center"]
    main_map.zoom = config["map_zoom"]
    sc.url = config["url"]
    sc.zoom = config["search_zoom"]

    date = date_selector.value
    norm = norm_buttons.value
    scale = scale_buttons.value
    data_type = type_buttons.value
    data = data_buttons.value
    ma, n = tab_1_ma_ch.value, tab_1_ma_w.value
    map_color(dict_datasets[ID], config["map_id"], data, norm, data_type, ma, n, date)
    main_map_title.value = main_map_title_val.format(
        map_title(map_name, data, norm, data_type, date, ma, n)
    )


map_buttons.add_observe(update_map, "value")


In [140]:
# Tables

# <head> block with the CSS of the ranked tables (id="sorted_table"): orange
# text on alternating dark rows. sorted_table_html prepends it to the table HTML.
css_style_2 = """
<head>
    <meta http-equiv="Content-Type" content="text/html; charset=utf-8">
    <style type="text/css" media="screen">
        #sorted_table {
          width: 100%;
          text-align: left;
          color: #ff8b0e;
          border-collapse: collapse;
        }
        #sorted_table th {
           text-align: center;
           color: #ff8b0e;
           background-color: black;
           border: none;
        } 
        #sorted_table tr:nth-child(odd){
            background-color: #424242;
        }
        #sorted_table tr:nth-child(even){
            background-color: black;
        }
        #sorted_table td{
          border: none;
          padding-left: 5px;
        }
    </style>
</head>
"""


def sorted_table_html(
    dataset,
    data,
    norm,
    scale,
    data_type,
    ma,
    n,
    date,
    ascending,
    K,
    dict_codes,
    column_name,
):
    """Render one ranked table (region, current value) as styled HTML.

    Rows come from ``dataset.get_ts_sort(...)``. The first column is renamed
    to ``column_name`` and optionally translated through ``dict_codes``. The
    value column is named after the current series label. Rows with missing
    values are dropped.

    Formatting:
    - Raw counts (``norm == "Values"``, not a % change) are printed as
      integers.
    - Other values use two decimals, unless their magnitude is at least 1000,
      in which case they are printed as integers.
    """
    df = dataset.get_ts_sort(
        data, norm, scale, data_type, ma, n, date, ascending, K
    ).reset_index()
    current_val = name_val_on_hover(data, norm, data_type, ma, n)

    df.rename(
        {df.columns.values[0]: column_name, date: current_val}, axis=1, inplace=True
    )
    if dict_codes:
        df[column_name] = [dict_codes[c] for c in df[column_name].values]
    df.dropna(inplace=True)

    def formatter(x, integer=True):
        # abs() so large negative values are formatted like large positive ones
        if integer or abs(x) >= 1000:
            return "{:,d}".format(int(x))
        return "{:,.02f}".format(x)

    if norm == "Values" and data_type != "Daily % change":
        df[current_val] = df[current_val].astype(int)
        form = formatter
    else:

        def form(x):
            return formatter(x, integer=False)

    val = df.to_html(
        index=False,
        header=True,
        formatters={current_val: form},
        notebook=True,
        table_id="sorted_table",
    )
    return css_style_2 + val + """</body>"""


def sorted_tables_html(data, norm, scale, data_type, ma, n, date, ascending, K):
    """Return the (countries, US states) ranked tables as an HTML pair."""
    table_specs = (
        (dict_datasets[0], codes_to_countries, "Countries"),
        (dict_datasets[1], None, "US States"),
    )
    return tuple(
        sorted_table_html(
            dataset,
            data,
            norm,
            scale,
            data_type,
            ma,
            n,
            date,
            ascending,
            K,
            codes,
            column,
        )
        for dataset, codes, column in table_specs
    )


def update_tables(data, norm, scale, data_type, ma, n, date):
    """Refresh both ranking tables for ``date`` (largest first, no row limit)."""
    table_1.value, table_2.value = sorted_tables_html(
        data, norm, scale, data_type, ma, n, date, ascending=False, K=None
    )


# HTML containers for the country / US-state ranking tables, filled by
# `update_tables`.
table_1, table_2 = HTML(), HTML()


In [141]:
# Infection Map tab
#
# Left column: help button, map title and map. Right column: control buttons
# on top and a "Graph" / "Table" tab underneath.

grid_1 = GridspecLayout(2, 2)
grid_1.layout.height = "99%"
grid_1.layout.width = "100%"
grid_1.layout.overflow = "auto"

# "Table" tab: the two ranking tables side by side.
table_grid = GridspecLayout(1, 2)
table_grid.layout.width = "100%"
table_grid.layout.height = "100%"
table_1.layout.width = "auto"
table_1.layout.overflow = "auto"
table_2.layout.overflow = "auto"
table_2.layout.width = "auto"
table_grid[0, 0] = table_1
table_grid[0, 1] = table_2

# "Graph" tab: main graph above its date-range slider.
help_date_axis_selector = help_button(
    """- Select a period of time using this slider for the graph above</br>
    - Select a subset of curves in the right-hand side graph by clicking on the labels (Top left)"""
)
center_right_panel = VBox(
    [
        main_graph,
        HBox(
            [help_date_axis_selector, date_axis_selector],
            layout=Layout(height="auto", overflow="hidden"),
        ),
    ]
)
center_right_panel.layout.height = "100%"
tab_graph_table = Tab(
    _titles=dict(zip([0, 1], ["Graph", "Table"])),
    children=[center_right_panel, table_grid],
    layout=Layout(height="99%", min_height="300px"),
)
help_tab1 = help_button(help_tooltip["Tab 1"])
help_tab1.layout.margin = "0px 2px 3px 0px"
grid_1[0, 1] = VBox(
    [cat_tab_1_buttons, date_tab_1_box], layout=Layout(align_items="flex-end")
)
main_graph.layout.height = "99%"

# Left column: help button and map title in a 40px header row, map below.
left_grid = GridspecLayout(2, 2)
left_grid[0, 0] = help_tab1
left_grid[0, 1] = main_map_title
left_grid[1, :] = main_map
left_grid.layout.height = "99%"
left_grid.layout.width = "100%"
left_grid.layout.min_width = "400px"
left_grid.layout.grid_template_rows = "40px auto"
left_grid.layout.grid_template_columns = "40px auto"
grid_1[:, 0] = left_grid
grid_1[1, 1] = tab_graph_table
grid_1.layout.align_items = "stretch"
grid_1.layout.grid_template_rows = "210px auto"
grid_1.layout.grid_template_columns = "auto 565px"


In [142]:
# Rebased Graph Tab

# Buttons and Multiple selectors
# The CSS of each Toggle_Buttons group is appended to `global_style`.

rebased_graph_data_button = Toggle_Buttons(
    options=["Cases", "Deaths", "Vaccine", "Tests"],
    value="Cases",
    description="Data",
    min_button_width=min_button_width_data,
    min_description_width=min_description_width_1,
    tooltips=buttons_tooltip["Data"][1:],
    description_tooltip=buttons_tooltip["Data"][0],
    key="tab2",
    tooltip_position="bottom",
)
global_style += rebased_graph_data_button.all_styles
rebased_graph_norm_button = Toggle_Buttons(
    options=["Values", "Per million"],
    value="Values",
    description="Norm",
    min_button_width=min_button_width_norm,
    min_description_width=min_description_width_1,
    tooltips=buttons_tooltip["Norm"][1:],
    description_tooltip=buttons_tooltip["Norm"][0],
    key="tab2",
    tooltip_position="bottom",
)
global_style += rebased_graph_norm_button.all_styles
rebased_graph_type_button = Toggle_Buttons(
    options=["Total", "Daily change", "Daily % change"],
    value="Total",
    description="Type",
    min_button_width=min_button_width_type,
    min_description_width=min_description_width_1,
    tooltips=buttons_tooltip["Type"][1:],
    description_tooltip=buttons_tooltip["Type"][0],
    key="tab2",
    tooltip_position="top",
)
global_style += rebased_graph_type_button.all_styles
# Threshold value. Starts hidden; update_rebased_graph shows it when
# calendar time is off.
# NOTE(review): the title attribute is unquoted, so a tooltip containing
# spaces would be cut at the first space unless buttons_tooltip already
# includes quotes. Confirm.
thr_val_slider = IntSlider(
    description="<tag title="
    + buttons_tooltip["Threshold"][0]
    + ">"
    + "Threshold"
    + "</tag>",
    value=1000,
    min=100,
    max=50000,
    step=100,
    style={"description_width": "initial"},
    layout=Layout(width="340px", visibility="hidden"),
)

plot_scale_button = Toggle_Buttons(
    options=["Linear", "Log"],
    value="Linear",
    description="Scale",
    min_button_width=min_button_width_scale,
    min_description_width=min_description_width_1,
    tooltips=buttons_tooltip["Scale"][1:],
    description_tooltip=buttons_tooltip["Scale"][0],
    key="tab2",
    tooltip_position="top",
)
global_style += plot_scale_button.all_styles
# Description width for the two "Threshold ..." button groups below.
min_description_width_2 = "125px"
rebased_graph_thr_data_button = Toggle_Buttons(
    options=["Cases", "Deaths", "Vaccine", "Tests"],
    value="Cases",
    description="Threshold Data",
    min_button_width=min_button_width_data,
    min_description_width=min_description_width_2,
    tooltips=buttons_tooltip["Threshold data"][1:],
    description_tooltip=buttons_tooltip["Threshold data"][0],
    key="thrtab2",
    tooltip_position="bottom",
)
global_style += rebased_graph_thr_data_button.all_styles
rebased_graph_thr_norm_button = Toggle_Buttons(
    options=["Values", "Per million"],
    value="Values",
    description="Threshold Norm",
    min_button_width=min_button_width_norm,
    min_description_width=min_description_width_2,
    tooltips=buttons_tooltip["Threshold norm"][1:],
    description_tooltip=buttons_tooltip["Threshold norm"][0],
    key="thrtab2",
    tooltip_position="top",
)
global_style += rebased_graph_thr_norm_button.all_styles
tab_2_ma_ch = Checkbox(
    description="<tag title="
    + buttons_tooltip["Moving average"][0]
    + ">"
    + "Moving Average"
    + "</tag>",
    value=False,
    style={"description_width": "initial"},
    layout=Layout(min_width="200px", width="200px", overflow="hidden"),
)
# Moving-average window. Hidden until the checkbox above is ticked.
tab_2_ma_w = BoundedIntText(
    value=7,
    min=1,
    max=14,
    description="Window size (in days)",
    style={"description_width": "initial"},
    layout=Layout(
        min_width="200px", width="200px", overflow="auto", visibility="hidden"
    ),
)
ma_box_tab_2 = VBox(
    [tab_2_ma_ch, tab_2_ma_w],
    layout=Layout(
        min_width="205px",
        width="205px",
        max_width="205px",
        height="74px",
        min_height="74px",
        max_height="74px",
        overflow="auto",
    ),
)


In [143]:
# Initial (single-entry) selector contents. Presumably replaced with the full
# lists once the data is loaded; TODO confirm in a later cell.
top_countries_names = ["United States"]
states = ["New York"]

countries_selector = SelectMultipleWidget(top_countries_names, [], "auto", "100%")
states_selector = SelectMultipleWidget(states, [], "auto", "100%")
# Checked: plot against calendar dates. Unchecked: rebase on the threshold.
calendar_time = Checkbox(
    description="<tag title="
    + buttons_tooltip["Calendar time"][0]
    + ">"
    + "Calendar Time"
    + "</tag>",
    value=True,
    style={"description_width": "initial"},
    layout=Layout(width="auto", overflow="hidden"),
)


def _fixed_565_layout(height, **extra):
    # Layout pinned to 565px wide and exactly `height` tall.
    return Layout(
        width="565px",
        min_width="565px",
        max_width="565px",
        height=height,
        min_height=height,
        max_height=height,
        **extra
    )


# Data / Norm / Type / Scale button groups.
cat_tab_2_buttons = VBox(
    [
        rebased_graph_data_button,
        rebased_graph_norm_button,
        rebased_graph_type_button,
        plot_scale_button,
    ],
    layout=_fixed_565_layout("138px", overflow="visible"),
)
# Framed row: calendar toggle and threshold slider, next to the moving average.
thr_tab_2_buttons = HBox(
    [VBox([calendar_time, thr_val_slider], layout=Layout(width="auto")), ma_box_tab_2],
    layout=_fixed_565_layout(
        "80px", overflow="auto", border="solid #ff8b0e", margin="10px 0 10px 0"
    ),
)
# Threshold data / norm groups, hidden while calendar time is on.
cat_thr_tab_2_buttons = VBox(
    [rebased_graph_thr_data_button, rebased_graph_thr_norm_button],
    layout=_fixed_565_layout("74px", overflow="visible", visibility="hidden"),
)
del _fixed_565_layout


In [144]:
# Right-hand control column of the Time Track tab: the three button groups
# stacked in a scrollable column of fixed size (565 x 315 px).
_col_w, _col_h = "565px", "315px"
reb_top_right_buttons = VBox(
    [cat_tab_2_buttons, thr_tab_2_buttons, cat_thr_tab_2_buttons],
    layout=Layout(
        overflow="auto",
        width=_col_w,
        min_width=_col_w,
        max_width=_col_w,
        height=_col_h,
        min_height=_col_h,
        max_height=_col_h,
    ),
)
del _col_w, _col_h


In [145]:
# Labels of the predefined-selection buttons: continents first, then the
# Top10 rankings.
presaved_sel_options = [
    "Europe",
    "Asia",
    "North America",
    "South America",
    "Africa",
    "Oceania",
    "Top10 (Cases)",
    "Top10 (Deaths)",
    "Top10 (Cases/1M)",
    "Top10 (Deaths/1M)",
    "Top10 (Cases daily chg)",
    "Top10 (Deaths daily chg)",
]
# Country list behind each predefined selection. All None here; presumably
# filled in once the data is loaded (TODO confirm).
presaved_sel_lists = dict.fromkeys(presaved_sel_options)


def preselection(presaved_sel_options, class_name, action):
    """Create one button per predefined selection.

    Each button is tagged with the CSS class ``class_name`` (read back by
    ``on_click_action`` to find the owning tab) and calls ``action`` on click.

    Returns
    -------
    dict
        Button per option label, in ``presaved_sel_options`` order.
    """
    buttons = dict.fromkeys(presaved_sel_options)
    for label in presaved_sel_options:
        if "Top10" in label:
            # "Top10 (Cases)" -> "Top10 most-affected countries (Cases)"
            tooltip = "Top10 most-affected countries (" + label.split("(")[-1]
        else:
            tooltip = "Countries in " + label
        button = Button(description=label, button_style="", tooltip=tooltip)
        button.add_class(class_name)
        button.on_click(action)
        buttons[label] = button
    return buttons


def on_click_action(val):
    """Apply a predefined country selection on the tab that owns the button.

    ``val`` is the clicked button. Its first CSS class (set by
    ``preselection``) identifies the tab. That tab's state selector is cleared
    and its country selector receives the stored country list. Buttons with
    any other class are ignored.
    """
    countries = presaved_sel_lists[val.description]
    tab_class = val._dom_classes[0]
    if tab_class == "Rebased":
        state_sel, country_sel = states_selector, countries_selector
    elif tab_class == "Heat map":
        state_sel, country_sel = dna_states_selector, dna_countries_selector
    elif tab_class == "Custom":
        state_sel, country_sel = free_states_selector, free_countries_selector
    else:
        return
    state_sel.set_value([])
    country_sel.set_value(countries)


In [146]:
def set_value(widget, new_val, action, unobserve=True):
    """Programmatically set ``widget``'s value, optionally muting ``action``.

    Parameters
    ----------
    widget :
        A ``Toggle_Buttons``, a ``SelectMultipleWidget`` or any widget with a
        ``value`` trait. Dispatch is on the class name.
    new_val :
        Value to apply. ``None`` leaves the widget untouched.
    action : callable
        Observer of the widget's value.
    unobserve : bool
        If True, ``action`` is detached while the value changes (so it does
        not fire) and re-attached afterwards. If False, it fires normally.
    """
    if new_val is None:
        return
    kind = type(widget).__name__
    if kind == "Toggle_Buttons":
        if unobserve:
            widget.del_observe(action, "value")
            widget.set_value(new_val)
            widget.add_observe(action, "value")
        else:
            widget.set_value(new_val)
    elif kind == "SelectMultipleWidget":
        if unobserve:
            widget.all_unobserve(action)
            widget.set_value(new_val)
            widget.all_observe(action)
        else:
            if sorted(widget.value) == sorted(new_val):
                # Presumably, re-setting an identical selection would not
                # notify observers, so switch to a different value first to
                # make the final set_value register as a change.
                widget.set_value(new_val + ["Alabama", "New York"])
            widget.set_value(new_val)
    else:
        if unobserve:
            widget.unobserve(action, "value")
            widget.value = new_val
            widget.observe(action, "value")
        else:
            widget.value = new_val


def on_click_scenario(val):
    """Apply the predefined scenario named by the clicked button to the current tab.

    Every control of the selected outer tab is set from ``dict_scenario``.
    All but the last are updated with their observer detached, so only the
    last update notifies its observer.
    """
    tab = outer_tab.selected_index
    controls = dict_tab_buttons[tab]
    targets = dict_scenario[tab][val.description]
    last = len(controls) - 1
    for pos, (control, target) in enumerate(zip(controls, targets)):
        set_value(control, target, dict_tab_buttons_actions[tab][pos], pos != last)


def sc_buttons(presaved_sc, action, tooltips):
    """Create one scenario button per name, each calling ``action`` on click.

    Names beyond the length of ``tooltips`` remain as keys mapped to None.

    Returns
    -------
    dict
        Button (or None) per name, in ``presaved_sc`` order.
    """
    buttons = dict.fromkeys(presaved_sc)
    for name, hint in zip(presaved_sc, tooltips):
        button = Button(description=name, button_style="", tooltip=hint)
        button.on_click(action)
        buttons[name] = button
    return buttons


In [147]:
scenario_tooltip_tab_2 = [
    "Cases/1M population daily change during the past 30days",
    "Cases (log scale) from 2020-05-25 in 8 US States",
    "Evolution of the total number of deaths since the number of deaths = 1000",
    "Cases daily change (7 days moving average) since the number of cases = 1000",
]

# "Predefined selections" panel: one button per saved selection, four per row.
ps_tab2 = preselection(presaved_sel_options, "Rebased", on_click_action)
hb_tab2 = VBox(
    [
        HBox([ps_tab2[name] for name in presaved_sel_options[r : r + 4]])
        for r in range(0, len(presaved_sel_options), 4)
    ]
)

# "Predefined scenarios" panel: "Scenario 1" ... "Scenario 4", four per row.
ps_tab2_1 = sc_buttons(
    ["Scenario {}".format(k) for k in range(1, 5)],
    on_click_scenario,
    scenario_tooltip_tab_2,
)
_scenario_names = list(ps_tab2_1)
hb_tab2_1 = VBox(
    [
        HBox([ps_tab2_1[name] for name in _scenario_names[r : r + 4]])
        for r in range(0, len(_scenario_names), 4)
    ]
)
del _scenario_names

accordion_tab2 = Accordion(
    children=[hb_tab2, hb_tab2_1],
    selected_index=None,
    layout=Layout(width="auto", height="auto"),
)
for _pos, _title in enumerate(["Predefined selections", "Predefined scenarios"]):
    accordion_tab2.set_title(_pos, _title)
del _pos, _title


In [148]:
# Framed 2x5 grid whose first row holds the predefined selection/scenario
# accordion.
tab_2_box_selectors = GridspecLayout(
    2,
    5,
    layout=Layout(
        border="solid #ff8b0e",
        overflow="auto",
        min_height="160px",
        align_items="stretch",
        width="565px",
    ),
)
tab_2_box_selectors[0, :] = accordion_tab2
# Row height for a collapsed accordion; accordion_adapt_height changes it.
tab_2_box_selectors.layout.grid_template_rows = "80px auto"


def accordion_adapt_height(val):
    """Resize the first grid row of the active tab to fit its accordion.

    Observer of an accordion's ``selected_index``: panel 0 (selections) gets
    230px, panel 1 (scenarios) 160px, and a collapsed accordion 80px. The
    grid is chosen from the outer tab currently selected (tabs 1-3).
    """
    row_templates = {0: "230px auto", 1: "160px auto"}
    grid = {
        1: tab_2_box_selectors,
        2: tab_3_box_selectors,
        3: tab_4_box_selectors,
    }[outer_tab.selected_index]
    grid.layout.grid_template_rows = row_templates.get(val.new, "80px auto")


accordion_tab2.observe(accordion_adapt_height, "selected_index")


In [149]:
# Graph
#
# Time Track figure. plot_rebased_graph later replaces its axes' scales and
# its marks. These scales are only initial placeholders.

scale_reb_t = LinearScale()
scale_reb_y = LogScale()
axis_reb_t = Axis(scale=scale_reb_t, grid_lines="none")
axis_reb_y = Axis(
    scale=scale_reb_y,
    orientation="vertical",
    grid_lines="solid",
    tick_format=",",
)
rebased_graph = Figure(
    title="Time Track",
    axes=[axis_reb_t, axis_reb_y],
    animation_duration=1000,
    fig_margin=dict(top=50, bottom=50, left=65, right=100),
    layout=Layout(min_width="400px"),
)


In [150]:
# Date-range slider under the Time Track graph (calendar-time mode). The
# single placeholder option is presumably replaced once the data is loaded;
# TODO confirm in a later cell.
date_axis_selector_tab_2 = SelectionRangeSlider(
    options=["2020-01-22"], index=(0, 0), layout=Layout(width="auto", height="auto"),
)


def update_zoom_tab_2(*args):
    """Zoom the Time Track graph on the date range picked with the slider.

    Observer-style callback. ``args[0]["new"]`` holds the (start, end) slider
    labels, formatted ``%m/%d/%y``. The function:

    - clips the x scale to that range;
    - fits the y scale to the part of the curves inside the range, padded
      outward by 10% of each bound's magnitude;
    - moves the curve labels to the end date;
    - rebuilds the graph title.

    Assumes ``rebased_graph.marks`` is ``[curves, labels]``, as built by
    ``plot_rebased_graph``.
    """
    min_x, max_x = args[0]["new"]

    def to_datetime64(label):
        # Slider labels are "%m/%d/%y"; the date scale works in `dateformat`.
        return np.datetime64(
            datetime.datetime.strftime(
                datetime.datetime.strptime(label, "%m/%d/%y"), dateformat
            )
        )

    start_date, end_date = to_datetime64(min_x), to_datetime64(max_x)
    rebased_graph.axes[0].scale.min = start_date
    rebased_graph.axes[0].scale.max = end_date

    curves, labels = rebased_graph.marks[0], rebased_graph.marks[1]
    first, last = date_axis_selector_tab_2.index
    ys = curves.y
    if ys.ndim in (1, 2):
        # Time is the last axis for one curve (1-D) and for a stack of curves
        # (2-D), so `...` selects the visible window in both cases.
        visible = ys[..., first : last + 1]
        low = float(np.nanmin(visible))
        high = float(np.nanmax(visible))
        # Pad outward by 10% of each bound's magnitude. This equals
        # 0.9*min / 1.1*max for positive data; for a negative bound, scaling
        # by 0.9 / 1.1 would move it inward and clip the curves.
        rebased_graph.axes[1].scale.min = low - 0.1 * abs(low)
        rebased_graph.axes[1].scale.max = high + 0.1 * abs(high)

        labels.x = [end_date] * len(labels.x)
        labels.y = ys[:, last] if ys.ndim == 2 else [ys[last]]
    (
        data_to_plot,
        data_norm,
        data_type,
        plot_scale,
        thr_data,
        thr_norm,
        ma,
        n,
        threshold,
    ) = [v.value for v in value_buttons[:-2]]
    if calendar_time.value:
        threshold = 0
    rebased_graph.title = title_rebased_graph(
        data_to_plot,
        data_norm,
        data_type,
        threshold,
        plot_scale,
        thr_data,
        thr_norm,
        ma,
        n,
        date1=str(min_x),
        date2=str(max_x),
    )
    return


In [151]:
def title_rebased_graph(
    data_to_plot,
    data_norm,
    data_type,
    threshold,
    plot_scale,
    thr_data,
    thr_norm,
    ma,
    n,
    date1="",
    date2="",
):
    """Build the Time Track graph title from the current selection.

    Example: ``"Cases/1M daily change (7days m.a) (log scale) since number of
    Deaths = 1000"``. With ``threshold <= 0`` the title instead ends with
    ``" from <date1> to <date2>"``.

    ``n`` must be an int (it is formatted with ``{:d}``) when ``ma`` is true.
    """
    res = data_to_plot
    if data_norm != "Values":
        res += "/1M"
    if data_type == "Daily change":
        res += " daily change"
    elif data_type == "Daily % change":
        res += " daily % change"
    if ma:
        res += " ({0:d}days m.a)".format(n)
    if plot_scale == "Log":
        # Leading space like the other suffixes (was glued to the text before).
        res += " (log scale)"
    if threshold > 0:
        res += " since number of " + thr_data
        if thr_norm != "Values":
            res += "/1M"
        res += " = " + str(threshold)
    else:
        res += " from " + date1 + " to " + date2
    return res


def label_index_pos(data):
    """Index at which a curve's end label is anchored.

    If ``data`` has no NaN, this is the last index. Otherwise it is the
    largest ``i >= 1`` such that ``data[i]`` and ``data[i - 1]`` are both
    non-NaN, or 0 when no such pair exists.
    """
    missing = np.isnan(data)
    if not missing.any():
        return len(data) - 1
    both_present = ~missing[1:] & ~missing[:-1]
    hits = np.flatnonzero(both_present)
    return int(hits[-1]) + 1 if hits.size else 0


In [152]:
# Tooltip widget shown when hovering a Time Track curve.
line_hover_tooltip = HTML("")


def line_hover_tab2(*args):
    """Show the hovered curve's name, in the curve's colour, in the tooltip.

    Hover callback of the Time Track ``Lines`` mark. ``args[0]`` is the mark
    and ``args[1]["data"]["index"]`` is the index of the hovered curve.
    """
    mark, event = args[0], args[1]
    curve = event["data"]["index"]
    template = """<h4 style='color:{}; margin:0px; padding:0px; font-weight:bold'> {}</h4>"""
    line_hover_tooltip.value = template.format(
        mark.colors[curve], rebased_graph.marks[1].text[curve]
    )


In [153]:
def persistent_colors(used_colors, used_labels, new_labels, all_colors):
    """Colour ``new_labels`` while keeping existing label colours stable.

    Parameters
    ----------
    used_colors : sequence
        Colours currently in use, paired (cyclically) with ``used_labels``.
    used_labels : sequence
        Labels currently displayed.
    new_labels : sequence
        Labels to display next.
    all_colors : sequence
        Palette used for labels that had no colour yet.

    Returns
    -------
    dict
        Colour for every label of ``used_labels`` still in ``new_labels``,
        plus one for each label of ``new_labels`` that was not in
        ``used_labels``. Each such new label gets the currently least-used
        colour (earliest on ties).
    """
    if len(used_colors):
        previous = dict(zip(used_labels, itertools.cycle(used_colors)))
    else:
        previous = {}
    kept = {label: colour for label, colour in previous.items() if label in new_labels}

    fresh = [label for label in new_labels if label not in used_labels]
    if fresh:
        usage = Counter(kept.values())
        for colour in all_colors:
            usage.setdefault(colour, 0)
        for label in fresh:
            colour = min(usage, key=usage.__getitem__)
            kept[label] = colour
            usage[colour] += 1
    return kept


def plot_rebased_graph(
    data_to_plot,
    data_norm,
    data_type,
    thr,
    countries,
    states,
    plot_scale,
    thr_data,
    thr_norm,
    ma,
    n,
    calendar_time=False,
):
    """Redraw the Time Track graph for the selected countries and US states.

    Parameters
    ----------
    data_to_plot, data_norm, data_type :
        Series selection forwarded to the datasets' ``get_ts_plot``.
    thr : number
        Threshold on the running total of ``thr_data``/``thr_norm``. Ignored
        (treated as 0) when ``calendar_time`` is True.
    countries : list of str
        Country names (keys of ``countries_to_codes``).
    states : list of str
        US state names.
    plot_scale : str
        "Log" for a log y scale, anything else for linear. On a log scale,
        non-positive values are hidden (set to NaN).
    thr_data, thr_norm :
        Series whose "Total" is compared with ``thr``.
    ma, n :
        Moving-average flag and window size.
    calendar_time : bool
        If True, curves are drawn against calendar dates. Otherwise each curve
        starts (x = 0) on the first day its threshold total exceeds ``thr``.

    Replaces the axes scales, marks and title of ``rebased_graph``. When the
    graph already has marks, curve colours are kept stable through
    ``persistent_colors``. Places left without data are not plotted.
    """
    all_colors = [
        "#1f77b4",
        "#ff7f0e",
        "#2ca02c",
        "#d62728",
        "#9467bd",
        "#8c564b",
        "#e377c2",
        "#7f7f7f",
        "#bcbd22",
        "#17becf",
    ]
    update = len(rebased_graph.marks) > 0
    threshold = 0 if calendar_time else thr
    scale_reb_y = LogScale() if plot_scale == "Log" else LinearScale()

    def rebased_series(dataset, key, series_scale):
        # One place's series, cut at the threshold day and masked for log scale.
        yaux = dataset.get_ts_plot(
            key,
            data_to_plot,
            data_norm,
            series_scale,
            data_type,
            ma,
            n,
            dict_datasets[0].STDT,
            dict_datasets[0].ENDT,
        )
        if threshold > 0:
            y_thr = dataset.get_ts(key, thr_data, thr_norm, "Linear", "Total", ma, n)
            # Values start from the index of the threshold. The binary search
            # presumes the total is non-decreasing. Searching a plain array
            # keeps the lookups positional; integer keys on a Series are
            # label-based or deprecated.
            yaux = yaux.iloc[bisect.bisect(np.asarray(y_thr), threshold) :]
        if plot_scale == "Log":
            # Mask non-positive values on a new series instead of writing into
            # the object returned by the dataset.
            yaux = yaux.where(yaux > 0)
        return yaux

    # (dataset, lookup key, scale given to get_ts_plot, label) for each place.
    # NOTE(review): states are fetched with "Linear" whereas countries use
    # `plot_scale`. Kept as is; confirm this difference is intended.
    places = [
        (dict_datasets[0], countries_to_codes[name], plot_scale, name)
        for name in countries
    ] + [(dict_datasets[1], st, "Linear", st) for st in states]

    Ys = []
    label_names = []
    for dataset, key, series_scale, name in places:
        yaux = rebased_series(dataset, key, series_scale)
        if yaux.shape[0] > 0:
            Ys.append(list(yaux.values))
            label_names.append(name)
    Xs = [list(np.arange(0, len(y))) for y in Ys]
    max_len = max((len(y) for y in Ys), default=0)

    label_x = []
    label_y = []
    for i in range(len(Ys)):
        index_lab = label_index_pos(Ys[i])
        label_y.append(Ys[i][index_lab])
        label_x.append(Xs[i][index_lab])
        # NaN-pad so all curves form one rectangular 2-D array for `Lines`.
        Ys[i] = Ys[i] + [math.nan] * (max_len - len(Ys[i]))
        Xs[i] = Xs[i] + [math.nan] * (max_len - len(Xs[i]))
    axis_reb_y.scale = scale_reb_y
    if threshold > 0:
        # x = days since the threshold was crossed.
        scale_reb_t = LinearScale()
    else:  # if thr == 0 Plot the values vs calendar dates
        scale_reb_t = DateScale(
            dateformat=dateformat,
            min=datetime.datetime.strptime(dict_datasets[0].STDT, dateformat),
            max=datetime.datetime.strptime(dict_datasets[0].ENDT, dateformat),
        )
        Xs = list(
            np.arange(
                np.datetime64(
                    datetime.datetime.strptime(STDT, dateformat).strftime(dateformat)
                ),
                np.datetime64(
                    datetime.datetime.strptime(ENDT, dateformat).strftime(dateformat)
                )
                + np.timedelta64(1, "D"),
            )
        )
        # One label per plotted curve, anchored at the end date. Sized from
        # label_names (not countries + states) because places without data
        # were skipped, so its length matches label_y and the label texts.
        label_x = np.array(
            [np.datetime64(ENDT)] * len(label_names), dtype="datetime64[D]"
        )
    axis_reb_t.scale = scale_reb_t
    rebased_graph.axes = [axis_reb_t, axis_reb_y]
    if update:
        color_dict = persistent_colors(
            list(rebased_graph.marks[0].colors),
            list(rebased_graph.marks[1].text),
            label_names,
            all_colors,
        )
        colors = [color_dict[c] for c in label_names]
    else:
        colors = all_colors.copy()
    rebased_curves = Lines(
        scales={"x": scale_reb_t, "y": scale_reb_y},
        colors=colors,
        interactions={"hover": "tooltip"},
        tooltip=line_hover_tooltip,
    )
    rebased_curves.x = np.array(Xs)
    rebased_curves.y = np.array(Ys)
    rebased_labels = Label(
        apply_clip=False,
        scales={"x": scale_reb_t, "y": scale_reb_y},
        default_size=15,
        colors=colors,
    )
    rebased_labels.text = label_names
    rebased_labels.x = label_x
    rebased_labels.y = np.array(label_y)
    rebased_graph.marks = [rebased_curves, rebased_labels]
    rebased_graph.title = title_rebased_graph(
        data_to_plot,
        data_norm,
        data_type,
        threshold,
        plot_scale,
        thr_data,
        thr_norm,
        ma,
        n,
    )
    rebased_curves.on_hover(line_hover_tab2)
    if calendar_time:
        # Re-apply the slider's current date window to the new marks.
        ind1, ind2 = date_axis_selector_tab_2.index
        update_zoom_tab_2(
            {
                "new": [
                    date_axis_selector_tab_2.options[ind1],
                    date_axis_selector_tab_2.options[ind2],
                ]
            }
        )
    return


In [154]:
# Controls read by update_rebased_graph, in the order they are unpacked there.
# The first six are Toggle_Buttons. The last two (country / state selectors)
# are read separately.
value_buttons = [
    rebased_graph_data_button,
    rebased_graph_norm_button,
    rebased_graph_type_button,
    plot_scale_button,
    rebased_graph_thr_data_button,
    rebased_graph_thr_norm_button,
    tab_2_ma_ch,
    tab_2_ma_w,
    thr_val_slider,
    countries_selector,
    states_selector,
]


def update_rebased_graph(*args):
    """Observer: redraw the Time Track graph from the current control values.

    For ToggleButton events, only a button being switched on triggers a
    redraw. Events from any other widget always redraw. The function also
    shows or hides the moving-average window, and switches between the
    threshold controls and the date-range slider depending on calendar time.
    """
    change = args[0]
    if type(change["owner"]).__name__ == "ToggleButton" and not change["new"] == True:
        return
    (
        data_to_plot,
        data_norm,
        data_type,
        plot_scale,
        thr_data,
        thr_norm,
        ma,
        n,
        threshold,
    ) = [control.value for control in value_buttons[:-2]]

    def picked(selector):
        # A selection containing "NONE" counts as empty.
        chosen = list(selector.value)
        return [] if "NONE" in chosen else chosen

    countries, states = picked(countries_selector), picked(states_selector)
    if not countries and not states:
        countries = ["United States"]  # fallback so the graph is never empty

    tab_2_ma_w.layout.visibility = "visible" if ma else "hidden"

    if calendar_time.value:
        threshold_visibility, dates_visibility = "hidden", "visible"
    else:
        threshold_visibility, dates_visibility = "visible", "hidden"
    thr_val_slider.layout.visibility = threshold_visibility
    cat_thr_tab_2_buttons.layout.visibility = threshold_visibility
    date_axis_selector_tab_2.layout.visibility = dates_visibility

    plot_rebased_graph(
        data_to_plot,
        data_norm,
        data_type,
        threshold,
        countries,
        states,
        plot_scale,
        thr_data,
        thr_norm,
        ma,
        n,
        calendar_time.value,
    )
    update_link_download()


thr_val_slider.continuous_update = False  # update on release, not while dragging
for _control in (thr_val_slider, tab_2_ma_ch, tab_2_ma_w, calendar_time):
    _control.observe(update_rebased_graph, "value")
del _control
# Toggle_Buttons groups register observers through add_observe.
for v in value_buttons[:-5]:
    v.add_observe(update_rebased_graph, "value")


In [155]:
# Time Track tab: graph and date slider on the left, controls on the right.
rebased_graph.layout.width = "auto"
rebased_graph.layout.height = "99%"
help_tab2 = help_button(help_tooltip["Tab 2"])

# Left panel: help button and graph on top, date-range slider in a 40px row.
grid_2_left_panel = GridspecLayout(2, 2)
grid_2_left_panel[0, 0] = help_tab2
grid_2_left_panel[0, 1] = rebased_graph
grid_2_left_panel[1, :] = date_axis_selector_tab_2
grid_2_left_panel.layout.grid_template_columns = "32px auto"
grid_2_left_panel.layout.grid_template_rows = "auto 40px"

grid_2 = GridspecLayout(2, 2)
grid_2.layout.overflow = "auto"
grid_2.layout.height = "100%"
grid_2.layout.width = "100%"
grid_2[:, 0] = grid_2_left_panel
# Right column: control buttons (315px) above the selectors.
grid_2[0, 1] = HBox([reb_top_right_buttons])
grid_2[1, 1] = HBox(
    [tab_2_box_selectors], layout=Layout(min_height="210px", overflow="auto")
)
grid_2.layout.align_items = "stretch"
grid_2.layout.grid_template_columns = "auto 590px"
grid_2.layout.grid_template_rows = "315px auto"


In [156]:
# DNA Graph tab3 (heat map of the selected series over time)

dna_countries_selector = SelectMultipleWidget(top_countries_names, [], "auto", "100%")
dna_states_selector = SelectMultipleWidget(states, [], "auto", "100%")

# The CSS of each Toggle_Buttons group is appended to `global_style`.
dna_data_button = Toggle_Buttons(
    options=["Cases", "Deaths", "Vaccine", "Tests"],
    value="Cases",
    description="Data",
    min_button_width=min_button_width_data,
    min_description_width=min_description_width_1,
    tooltips=buttons_tooltip["Data"][1:],
    description_tooltip=buttons_tooltip["Data"][0],
    key="tab3",
    tooltip_position="bottom",
)
global_style += dna_data_button.all_styles
dna_norm_button = Toggle_Buttons(
    options=["Values", "Per million"],
    value="Values",
    description="Norm",
    min_button_width=min_button_width_norm,
    min_description_width=min_description_width_1,
    tooltips=buttons_tooltip["Norm"][1:],
    description_tooltip=buttons_tooltip["Norm"][0],
    key="tab3",
    tooltip_position="bottom",
)
global_style += dna_norm_button.all_styles
dna_type_button = Toggle_Buttons(
    options=["Total", "Daily change", "Daily % change"],
    value="Total",
    description="Type",
    # Same three options as rebased_graph_type_button, so the same width.
    min_button_width=min_button_width_type,
    min_description_width=min_description_width_1,
    tooltips=buttons_tooltip["Type"][1:],
    description_tooltip=buttons_tooltip["Type"][0],
    key="tab3",
    tooltip_position="top",
)
global_style += dna_type_button.all_styles
tab_3_ma_ch = Checkbox(
    description="<tag title="
    + buttons_tooltip["Moving average"][0]
    + ">"
    + "Moving Average"
    + "</tag>",
    value=False,
    style={"description_width": "initial"},
    layout=Layout(min_width="200px", width="100%", overflow="hidden"),
)
tab_3_norm = Checkbox(
    description="<tag title="
    + buttons_tooltip["Scale by max"][0]
    + ">"
    + "Scale by max"
    + "</tag>",
    value=False,
    style={"description_width": "initial"},
    layout=Layout(min_width="100px", width="100%", overflow="auto"),
)
# Moving-average window. Hidden until the checkbox is ticked (see update_dna).
tab_3_ma_w = BoundedIntText(
    value=7,
    min=1,
    max=14,
    description="Window size (in days)",
    style={"description_width": "initial"},
    layout=Layout(
        min_width="200px", width="200px", overflow="auto", visibility="hidden"
    ),
)
ma_box_tab_3 = HBox(
    [tab_3_norm, tab_3_ma_ch, tab_3_ma_w],
    layout=Layout(
        width="565px",
        min_width="565px",
        max_width="565px",
        overflow="auto",
        align_items="stretch",
    ),
)

cat_tab_3_buttons = VBox(
    [dna_data_button, dna_norm_button, dna_type_button],
    layout=Layout(
        width="565px",
        height="105px",
        min_width="565px",
        max_width="565px",
        min_height="105px",
        max_height="105px",
        overflow="visible",
    ),
)

dna_top_buttons = VBox(
    [cat_tab_3_buttons, ma_box_tab_3],
    layout=Layout(
        width="565px",
        min_width="565px",
        max_width="565px",
        height="140px",
        min_height="140px",
        max_height="140px",
        overflow="auto",
    ),
)


In [157]:
scenario_tooltip_tab_3 = [
    "Cases/1M population daily change(7 days moving average)",
    "Cases/1M population daily change(7 days moving average) (US States)",
    "Deaths/1M population daily change(7 days moving average)",
    "Cases daily change (7 days moving average) scaled by max",
]

# "Predefined selections" panel: one button per saved selection, four per row.
ps_tab3 = preselection(presaved_sel_options, "Heat map", on_click_action)
hb_tab3 = VBox(
    [
        HBox([ps_tab3[name] for name in presaved_sel_options[r : r + 4]])
        for r in range(0, len(presaved_sel_options), 4)
    ]
)

# "Predefined scenarios" panel: "Scenario 1" ... "Scenario 4", four per row.
ps_tab3_1 = sc_buttons(
    ["Scenario {}".format(k) for k in range(1, 5)],
    on_click_scenario,
    scenario_tooltip_tab_3,
)
_scenario_names = list(ps_tab3_1)
hb_tab3_1 = VBox(
    [
        HBox([ps_tab3_1[name] for name in _scenario_names[r : r + 4]])
        for r in range(0, len(_scenario_names), 4)
    ]
)
del _scenario_names

accordion_tab3 = Accordion(
    children=[hb_tab3, hb_tab3_1],
    selected_index=None,
    layout=Layout(width="auto", height="auto"),
)
for _pos, _title in enumerate(["Predefined selections", "Predefined scenarios"]):
    accordion_tab3.set_title(_pos, _title)
del _pos, _title
# Resize the selector grid when a panel opens or closes.
accordion_tab3.observe(accordion_adapt_height, "selected_index")


In [158]:
# Heat-map tab counterpart of tab_2_box_selectors (without a border). Row 0
# holds the selection/scenario accordion.
tab_3_box_selectors = GridspecLayout(
    2,
    5,
    layout=Layout(
        min_height="160px", width="565px", overflow="auto", align_items="stretch"
    ),
)
tab_3_box_selectors[0, :] = accordion_tab3
# Row height for a collapsed accordion; accordion_adapt_height changes it.
tab_3_box_selectors.layout.grid_template_rows = "80px auto"


In [159]:
# GridHeatMap, DNA graph
#
# Dates on the x axis, an ordinal y axis drawn on the right, and a
# white-to-dark-red colour scale for the values.

dna_x_scale = DateScale()
dna_y_scale = OrdinalScale(padding_y=0)
dna_color_scale = ColorScale(
    colors=[
        "white",
        "#ffa600",
        "#ff6e00",
        "#ff4417",
        "#d31522",
        "#8c000e",
        "#560410",
    ]
)
dna_x_ax = Axis(scale=dna_x_scale, tick_format="%m/%d")
dna_y_ax = Axis(scale=dna_y_scale, orientation="vertical", side="right")
dna_color_ax = ColorAxis(scale=dna_color_scale, tick_format=",")
dna_axes = [dna_x_ax, dna_y_ax, dna_color_ax]

dna_layout = Layout(width="auto", height="99%", min_width="400px")
dna_fig_margin = dict(top=50, bottom=60, left=0, right=80)

dna_figure = Figure(
    title="Cases",
    axes=dna_axes,
    layout=dna_layout,
    fig_margin=dna_fig_margin,
    min_aspect_ratio=0.0,
    padding_y=0,
    animation_duration=1000,
    null_color="#808080",
)


In [160]:
def _dna_countries_ts(countries, ma, n):
    """Return heat-map time series for the given country names.

    Names are mapped to codes for ``dict_datasets[0].get_ts`` using the current
    Tab 3 button values, and the resulting columns are renamed back to names.
    """
    return (
        dict_datasets[0]
        .get_ts(
            [countries_to_codes[c] for c in countries],
            dna_data_button.value,
            dna_norm_button.value,
            "Linear",
            dna_type_button.value,
            ma,
            n,
        )
        .rename(columns=codes_to_countries)
    )


def _dna_states_ts(states, ma, n):
    """Return heat-map time series for the given US state names."""
    return dict_datasets[1].get_ts(
        states,
        dna_data_button.value,
        dna_norm_button.value,
        "Linear",
        dna_type_button.value,
        ma,
        n,
    )


def _dna_palette(data):
    """Colour ramp for the heat map: greens for Vaccine/Tests, reds otherwise."""
    if data == "Vaccine" or data == "Tests":
        return [
            "white",
            "#C5E8B7",
            "#ABE098",
            "#83D475",
            "#57C84D",
            "#2EB62C",
            "#207f1e",
        ]
    return [
        "white",
        "#ffa600",
        "#ff6e00",
        "#ff4417",
        "#d31522",
        "#8c000e",
        "#560410",
    ]


def update_dna(*args):
    """Observer for the Tab 3 controls: rebuild the heat map ("DNA" graph).

    ``args[0]`` is the traitlets change dict. Changes coming from a
    ``ToggleButton`` are only acted on when that button is switched on.

    Row selection rules:
    - nothing selected -> United States and World;
    - countries only -> those countries; a lone country is paired with the
      United States (or with World when the lone country is the United States);
    - states only -> those states; a lone state is paired with New York (or
      with New Jersey when the lone state is New York);
    - both -> country and state series outer-joined on the date index.

    Also updates the figure title, the colour ramp, the colour-axis visibility
    and the moving-average window visibility, then refreshes the download link.
    """
    change = args[0]
    if type(change["owner"]).__name__ == "ToggleButton" and not (
        change["new"] == True
    ):
        return

    if "NONE" in dna_countries_selector.value:
        countries = []
    else:
        countries = list(dna_countries_selector.value)

    if "NONE" in dna_states_selector.value:
        states = []
    else:
        states = list(dna_states_selector.value)

    ma, n = tab_3_ma_ch.value, tab_3_ma_w.value
    tab_3_ma_w.layout.visibility = "visible" if ma else "hidden"

    if len(countries) + len(states) == 0:
        df_DNA = _dna_countries_ts(["United States", "World"], ma, n)
    elif len(states) == 0:
        if len(countries) == 1:
            if "United States" in countries:
                countries = ["World", "United States"]
            else:
                countries = countries + ["United States"]
        df_DNA = _dna_countries_ts(countries, ma, n)
    elif len(countries) == 0:
        if len(states) == 1:
            if "New York" in states:
                states = ["New Jersey", "New York"]
            else:
                states = states + ["New York"]
        df_DNA = _dna_states_ts(states, ma, n)
    else:
        df_DNA = pd.merge(
            _dna_countries_ts(countries, ma, n),
            _dna_states_ts(states, ma, n),
            left_index=True,
            right_index=True,
            how="outer",
        )

    # Reverse column order (the rows are displayed in reverse selection order).
    df_DNA = df_DNA.iloc[:, ::-1]

    title = dna_data_button.value
    if dna_norm_button.value == "Per million":
        title += "/1M"
    if dna_type_button.value != "Total":
        title += " " + dna_type_button.value
    if ma:
        title += " ({0:d}days m.a)".format(n)
    if tab_3_norm.value:
        # A column whose max is 0 (e.g. a daily-change series that is never
        # positive) would otherwise produce -inf for negative entries and break
        # the colour-scale domain; treat such columns as undefined instead.
        col_max = df_DNA.max(axis=0).replace(0, np.nan)
        df_DNA = 100 * df_DNA.divide(col_max)
        title += ", (scaled by max)"

    # Scaled values are relative (0-100), so the absolute colour legend is hidden.
    dna_color_ax.visible = not tab_3_norm.value

    dna_figure.marks[0].color = df_DNA.values.T
    dna_figure.marks[0].row = list(df_DNA.columns.values)
    dna_color_scale.colors = _dna_palette(dna_data_button.value)
    dna_figure.title = title
    update_link_download()
    return


# Redraw the Tab 3 heat map whenever any Tab 3 control changes.
# Toggle_Buttons groups use their own add_observe helper; plain ipywidgets use
# observe. The country/state selectors are not wired in this cell.
dna_data_button.add_observe(update_dna, "value")
dna_norm_button.add_observe(update_dna, "value")
dna_type_button.add_observe(update_dna, "value")
tab_3_ma_ch.observe(update_dna, "value")
tab_3_norm.observe(update_dna, "value")
tab_3_ma_w.observe(update_dna, "value")


In [161]:
# Hover tooltip for the Tab 3 heat map. Placeholders, filled by
# update_dna_tooltip:
#   {0} date, {1} country/state row name, {2} figure title used as the value
#   label, {3} value formatted with thousands separators.
dna_tooltip_table = """
<table id="stats-table">
<tr>
    <td>Date</td>
    <td>{0:}</td>
</tr>
<tr>
    <td>Country/State</td>
    <td>{1:s}</td>
</tr>
<tr>
    <td>{2:s}</td>
    <td>{3:,}</td>
</tr>
</table>
"""
# The "tootltip" spelling is the name update_dna_tooltip writes to; keep as is.
dna_stats_tootltip = HTML()


def update_dna_tooltip(*args):
    """Render the heat-map hover tooltip into ``dna_stats_tootltip``.

    ``args[1]["data"]`` is expected to hold the hovered cell's ``column_num``,
    ``row`` and ``color`` (the bqplot hover payload; the handler registration
    is not in this cell). Values that cannot be converted are shown as ``nan``.
    """
    hovered = args[1]["data"]
    # Column labels are stringified and the last 19 characters dropped --
    # presumably the time-of-day part of a datetime64 string (TODO confirm).
    date = str(dna_figure.marks[0].column[int(hovered["column_num"])])[:-19]
    row_name = hovered["row"]
    val = hovered["color"]
    title = dna_figure.title
    # Relative quantities keep one decimal; raw counts are shown as integers.
    # update_dna marks per-million data with "/1M" (never "pop") and scaled data
    # with "scaled by max", so those markers are checked as well -- otherwise
    # values such as 0.4 deaths/1M were truncated to 0.
    keep_decimal = any(tag in title for tag in ("pop", "%", "/1M", "scaled"))
    try:
        if not np.isnan(val):
            val = round(val, 1) if keep_decimal else int(val)
    except (TypeError, ValueError, OverflowError):
        # None / non-numeric colour, or an infinite value in integer mode.
        val = np.nan
    dna_stats_tootltip.value = dna_tooltip_table.format(date, row_name, title, val)
    return


In [162]:
# Tab 3 page layout, a 2x2 grid:
#   left column (both rows): help button (32px) + heat-map figure
#   top right (140px):       dna_top_buttons (defined elsewhere)
#   bottom right:            predefined selections/scenarios box
grid_3 = GridspecLayout(2, 2)
grid_3.layout.overflow = "auto"
grid_3.layout.height = "100%"
grid_3.layout.width = "100%"
help_tab3 = help_button(help_tooltip["Tab 3"])
grid_3_left_panel = GridspecLayout(1, 2)
grid_3_left_panel[0, 0] = help_tab3
grid_3_left_panel[0, 1] = dna_figure
grid_3_left_panel.layout.grid_template_columns = "32px auto"
grid_3[:, 0] = grid_3_left_panel
grid_3[0, 1] = HBox([dna_top_buttons])
grid_3[1, 1] = HBox(
    [tab_3_box_selectors], layout=Layout(min_height="210px", overflow="auto")
)
grid_3.layout.align_items = "stretch"
grid_3.layout.grid_template_columns = "auto 590px"
grid_3.layout.grid_template_rows = "140px auto"


In [163]:
# Free graph

# Buttons and interactions


min_button_width_1 = "100px"


def _make_tab4_toggle(
    options, value, description, min_button_width, tooltip_key, key, tooltip_position
):
    """Create one Tab 4 ``Toggle_Buttons`` group and register its CSS.

    Every Tab 4 axis control shares ``min_description_width_1`` and takes its
    tooltips from ``buttons_tooltip[tooltip_key]``: element 0 is the group
    tooltip, the remaining elements are the per-option tooltips. ``key``
    ("xtab4"/"ytab4") is passed straight through to ``Toggle_Buttons``. The
    widget's ``all_styles`` are appended to the notebook-wide ``global_style``.
    """
    global global_style
    button = Toggle_Buttons(
        options=list(options),
        value=value,
        description=description,
        min_button_width=min_button_width,
        min_description_width=min_description_width_1,
        tooltips=buttons_tooltip[tooltip_key][1:],
        description_tooltip=buttons_tooltip[tooltip_key][0],
        key=key,
        tooltip_position=tooltip_position,
    )
    global_style += button.all_styles
    return button


_TAB4_DATA_OPTIONS = ("Cases", "Deaths", "Vaccine", "Tests")
_TAB4_NORM_OPTIONS = ("Values", "Per million")
_TAB4_SCALE_OPTIONS = ("Linear", "Log")
_TAB4_TYPE_OPTIONS = ("Total", "Daily change", "Daily % change")

# Creation order is preserved: each call appends to global_style in this order.
x_data_button = _make_tab4_toggle(
    _TAB4_DATA_OPTIONS, "Cases", "Data", min_button_width_1, "Data", "xtab4", "bottom"
)
y_data_button = _make_tab4_toggle(
    _TAB4_DATA_OPTIONS, "Cases", "Data", min_button_width_1, "Data", "ytab4", "bottom"
)

x_norm_button = _make_tab4_toggle(
    _TAB4_NORM_OPTIONS, "Values", "Norm", min_button_width_norm, "Norm", "xtab4", "bottom"
)
y_norm_button = _make_tab4_toggle(
    _TAB4_NORM_OPTIONS, "Values", "Norm", min_button_width_norm, "Norm", "ytab4", "bottom"
)

x_scale_button = _make_tab4_toggle(
    _TAB4_SCALE_OPTIONS, "Log", "Scale", min_button_width_scale, "Scale", "xtab4", "top"
)
y_scale_button = _make_tab4_toggle(
    _TAB4_SCALE_OPTIONS, "Log", "Scale", min_button_width_scale, "Scale", "ytab4", "top"
)

x_type_button = _make_tab4_toggle(
    _TAB4_TYPE_OPTIONS, "Total", "Type", "120px", "Type", "xtab4", "top"
)
y_type_button = _make_tab4_toggle(
    _TAB4_TYPE_OPTIONS, "Daily change", "Type", "120px", "Type", "ytab4", "top"
)
# "Last data point" switches Tab 4 between a scatter of the latest values
# (checked) and full time-series lines (unchecked); see
# update_free_scatter_fig. The description is an HTML snippet, presumably so
# that hovering shows buttons_tooltip text as a native tooltip.
# NOTE(review): the title attribute value below is not quoted, so a tooltip
# containing spaces would be cut at the first space unless the
# buttons_tooltip entries already carry their own quotes -- confirm.
free_checkbox = Checkbox(
    value=False,
    description="<tag title="
    + buttons_tooltip["Last data point"][0]
    + ">"
    + "Last data point"
    + "</tag>",
    disabled=False,
    style={"description_width": "initial"},
    layout=Layout(width="auto", height="auto"),
)

# Moving-average toggle (on by default) and its window size (1-14 days,
# default 7). update_free_scatter_fig hides the window input while the
# moving average is off.
tab_4_ma_ch = Checkbox(
    description="<tag title="
    + buttons_tooltip["Moving average"][0]
    + ">"
    + "Moving Average"
    + "</tag>",
    value=True,
    style={"description_width": "initial"},
    layout=Layout(min_width="200px", width="200px", overflow="hidden"),
)
tab_4_ma_w = BoundedIntText(
    value=7,
    min=1,
    max=14,
    description="Window size (days)",
    style={"description_width": "initial"},
    layout=Layout(
        min_width="195px", width="195px", overflow="auto", visibility="visible"
    ),
)
ma_box_tab_4 = HBox(
    [tab_4_ma_ch, tab_4_ma_w],
    layout=Layout(
        width="100%",
        overflow="hidden",
        align_items="stretch",
        justify_content="space-between",
    ),
)


In [164]:
# Fixed-size (565x138px) stacks of the four X-axis and four Y-axis control
# groups. Each VBox gets its own Layout object, because sharing one Layout
# would make later layout changes affect both boxes.
x_data = VBox(
    [x_data_button, x_norm_button, x_type_button, x_scale_button,],
    layout=Layout(
        width="565px",
        height="138px",
        min_width="565px",
        max_width="565px",
        min_height="138px",
        max_height="138px",
        overflow="visible",
    ),
)

y_data = VBox(
    [y_data_button, y_norm_button, y_type_button, y_scale_button,],
    layout=Layout(
        width="565px",
        height="138px",
        min_width="565px",
        max_width="565px",
        min_height="138px",
        max_height="138px",
        overflow="visible",
    ),
)


In [165]:
# Orange-framed "X DATA" / "Y DATA" panels, each holding a label plus its
# control stack. Only the X panel has a 5px bottom margin, which separates the
# two panels when they are stacked in cat_tab_4_buttons.
x_data_buttons = HBox(
    [
        HTML(
            "<p style='text-align:left; font-size:18px; font-weight:bold; padding: 0; margin: 0; color: #ff8b0e; margin-right: 15px'>X DATA : </p>",
            layout=Layout(width="105px"),
        ),
        x_data,
    ],
    layout=Layout(
        width="670px",
        height="145px",
        align_items="center",
        border="solid #ff8b0e",
        margin="0 0 5px 0",
        overflow="hidden",
    ),
)
y_data_buttons = HBox(
    [
        HTML(
            "<p style='text-align:left; font-size:18px; font-weight:bold; padding: 0; margin: 0; color: #ff8b0e; margin-right: 15px'>Y DATA : </p>",
            layout=Layout(width="105px"),
        ),
        y_data,
    ],
    layout=Layout(
        width="670px",
        height="145px",
        align_items="center",
        border="solid #ff8b0e",
        overflow="hidden",
    ),
)
cat_tab_4_buttons = VBox([x_data_buttons, y_data_buttons])


In [166]:
# Tab 4 multi-selectors over top_countries_names / states (both defined
# elsewhere), with nothing pre-selected. The arguments are
# SelectMultipleWidget(options, value, width, height).
free_countries_selector = SelectMultipleWidget(
    top_countries_names, [], "auto", height="100%"
)
free_states_selector = SelectMultipleWidget(states, [], "auto", height="100%")


In [167]:
# Hover text for the four Tab 4 "Scenario N" buttons, in button order.
# (Fixed user-facing typo: "popultation" -> "population".)
scenario_tooltip_tab_4 = [
    "Cases/1M population daily change(log scale) vs Cases/1M population (log scale) (7days moving average)",
    "Cases/1M population vs Tests/1M population",
    "Vaccine/1M population vs Cases/1M population",
    "Deaths/1M population vs Cases/1M population (last data point)",
]
# Tab 4 accordion: same structure as the Tab 3 one, with "Custom" passed to
# preselection() and the Tab 4 scenario tooltips. Buttons are laid out in rows
# of four.
ps_tab4 = preselection(presaved_sel_options, "Custom", on_click_action)
hb_tab4 = VBox(
    [
        HBox([ps_tab4[v] for v in presaved_sel_options[i : i + 4]])
        for i in range(0, len(presaved_sel_options), 4)
    ]
)
ps_tab4_1 = sc_buttons(
    ["Scenario " + str(i) for i in range(1, 5)],
    on_click_scenario,
    scenario_tooltip_tab_4,
)
hb_tab4_1 = VBox(
    [
        HBox([ps_tab4_1[v] for v in list(ps_tab4_1.keys())[i : i + 4]])
        for i in range(0, len(list(ps_tab4_1.keys())), 4)
    ]
)
# Starts collapsed; accordion_adapt_height (defined elsewhere) is called
# whenever the open panel changes.
accordion_tab4 = Accordion(
    children=[hb_tab4, hb_tab4_1],
    selected_index=None,
    layout=Layout(width="auto", height="auto"),
)
accordion_tab4.set_title(0, "Predefined selections")
accordion_tab4.set_title(1, "Predefined scenarios")
accordion_tab4.observe(accordion_adapt_height, "selected_index")


In [168]:
# Container for the Tab 4 accordion: a 2x5 grid whose first row (80px high)
# is entirely occupied by accordion_tab4. It is wider (670px) than the Tab 3
# equivalent so it lines up with the X/Y DATA panels.
tab_4_box_selectors = GridspecLayout(
    2,
    5,
    layout=Layout(
        width="670px", align_items="stretch", min_height="160px", overflow="auto"
    ),
)
tab_4_box_selectors[0, :] = accordion_tab4
tab_4_box_selectors.layout.grid_template_rows = "80px auto"


In [169]:
# Titled boxes around the Tab 4 selectors. `.res` is presumably the
# SelectMultipleWidget's displayable container widget (it is also used that
# way by CountryPicker) -- confirm in the class definition.
countries_selector_box_free = VBox(
    [HTML('<p style="text-align:center">Countries</p>'), free_countries_selector.res],
    layout=Layout(width="100%", height="100%"),
)
states_selector_box_free = VBox(
    [HTML('<p style="text-align:center">US States</p>'), free_states_selector.res],
    layout=Layout(width="100%", height="100%"),
)


In [170]:
# Initial Tab 4 figure with linear scales and pan/zoom. free_scatter() creates
# fresh Log/Linear scales on every redraw, reassigns them to these axes and
# updates free_pz.scales, so these initial scales are only placeholders.
free_scatter_sc_x = LinearScale()
free_scatter_sc_y = LinearScale()
free_scatter_ax_x = Axis(scale=free_scatter_sc_x, tick_format=",")
free_scatter_ax_y = Axis(
    scale=free_scatter_sc_y, orientation="vertical", tick_format=","
)
free_pz = PanZoom(scales={"x": [free_scatter_sc_x], "y": [free_scatter_sc_y]})
free_scatter_fig = Figure(
    axes=[free_scatter_ax_x, free_scatter_ax_y],
    interaction=free_pz,
    layout=Layout(width="auto", height="99%", min_width="400px"),
    fig_margin={"top": 50, "left": 60, "right": 100, "bottom": 50},
)


In [171]:
def last_index_nan(x_li, y_li):
    """Return the last index at which neither ``x_li`` nor ``y_li`` is NaN.

    Despite the name, this locates the last *valid* pair: trailing positions
    where either coordinate is NaN are skipped. Returns -1 when no such
    position exists (including for empty input). Only ``len(x_li)`` is used to
    bound the search.
    """
    for idx in reversed(range(len(x_li))):
        if not (np.isnan(x_li[idx]) or np.isnan(y_li[idx])):
            return idx
    return -1


def clean_line_plot(xs, ys):
    """Blank out (set to NaN) every point where either coordinate is NaN.

    ``xs`` and ``ys`` are parallel sequences of mutable rows; the rows are
    modified in place, so a line never has an x without a y or vice versa.
    The same two outer objects are returned for convenience.
    """
    for row_idx, x_row in enumerate(xs):
        y_row = ys[row_idx]
        for col_idx in range(len(x_row)):
            if np.isnan(x_row[col_idx]) or np.isnan(y_row[col_idx]):
                x_row[col_idx] = np.nan
                y_row[col_idx] = np.nan
    return xs, ys


def free_scatter(
    selected_countries, datas, norms, scales, data_types, ma, n, last_point=True
):
    """Redraw the Tab 4 free-axis chart for the selected countries and states.

    Parameters
    ----------
    selected_countries : list
        ``[country_codes, state_names]``; entry ``i`` is looked up in
        ``dict_datasets[i]``.
    datas, norms, scales, data_types : list
        Two-element ``[x_setting, y_setting]`` lists of button values (data
        series, normalisation, "Log"/"Linear" scale, total/daily type).
    ma, n : bool, int
        Moving-average flag and window size, forwarded to the dataset getters.
    last_point : bool
        True draws one scatter point per entity, using its value at the
        dataset's ``ENDT``. Entities with a NaN x or y are omitted. False draws
        full time-series lines and stores them in the globals
        ``free_mark_x``/``free_mark_y`` for ``update_free_play``.

    When a chart already exists, colours are kept stable across redraws via
    ``persistent_colors``.
    """
    global free_mark_x, free_mark_y
    Ys = []
    Xs = []
    label_names = []
    colors = []
    all_colors = [
        "#1f77b4",
        "#ff7f0e",
        "#2ca02c",
        "#d62728",
        "#9467bd",
        "#8c564b",
        "#e377c2",
        "#7f7f7f",
        "#bcbd22",
        "#17becf",
    ]
    if len(free_scatter_fig.marks) > 0:
        # Re-use the colours of entities already on the chart.
        update = True
        new_labels = [codes_to_countries[c] for c in selected_countries[0]] + list(
            selected_countries[1]
        )
        color_dict = persistent_colors(
            list(free_scatter_fig.marks[0].colors),
            list(free_scatter_fig.marks[1].text),
            new_labels,
            all_colors,
        )
    else:
        update = False
        colors = all_colors.copy()

    # Fresh scale objects on each call so that a Log/Linear switch takes effect.
    if scales[0] == "Log":
        free_scatter_sc_x = LogScale()
    elif scales[0] == "Linear":
        free_scatter_sc_x = LinearScale()

    if scales[1] == "Log":
        free_scatter_sc_y = LogScale()
    elif scales[1] == "Linear":
        free_scatter_sc_y = LinearScale()

    for i in range(len(selected_countries)):
        dataset = dict_datasets[i]
        for c in selected_countries[i]:
            # World entries are codes and are displayed by country name.
            label = codes_to_countries[c] if dataset.name == "World" else c
            if last_point:
                x = dataset.get_value(
                    c,
                    datas[0],
                    norms[0],
                    scales[0],
                    data_types[0],
                    ma,
                    n,
                    dataset.ENDT,
                )
                y = dataset.get_value(
                    c,
                    datas[1],
                    norms[1],
                    scales[1],
                    data_types[1],
                    ma,
                    n,
                    dataset.ENDT,
                )
                if np.isnan(x) or np.isnan(y):
                    # No point is drawn, so skip the label AND the colour.
                    # Previously a colour was still appended here, which
                    # shifted every subsequent colour onto the wrong point.
                    continue
                Xs.append(x)
                Ys.append(y)
            else:
                Xs.append(
                    list(
                        dataset.get_ts(
                            c, datas[0], norms[0], scales[0], data_types[0], ma, n
                        ).values
                    )
                )
                Ys.append(
                    list(
                        dataset.get_ts(
                            c, datas[1], norms[1], scales[1], data_types[1], ma, n
                        ).values
                    )
                )
            label_names.append(label)
            if update:
                colors.append(color_dict[label])
    free_scatter_ax_x.scale = free_scatter_sc_x
    free_scatter_ax_y.scale = free_scatter_sc_y
    free_scatter_fig.axes = [free_scatter_ax_x, free_scatter_ax_y]

    free_labels = Label(
        apply_clip=False,
        scales={"x": free_scatter_sc_x, "y": free_scatter_sc_y},
        default_size=15,
        x_offset=4,
        colors=colors,
    )
    free_labels.text = label_names
    if last_point:
        free_labels.x = np.array(Xs)
        free_labels.y = np.array(Ys)

        free_mark = Scatter(
            marker="circle",
            stroke="black",
            scales={"x": free_scatter_sc_x, "y": free_scatter_sc_y},
            colors=colors,
            marker_size=30,
        )
        free_mark.x = np.array(Xs)
        free_mark.y = np.array(Ys)

    else:
        Xs, Ys = clean_line_plot(Xs, Ys)
        # Each label sits at the last point where both coordinates exist.
        free_labels.x = np.array(
            [x_li[last_index_nan(x_li, y_li)] for x_li, y_li in zip(Xs, Ys)]
        )
        free_labels.y = np.array(
            [y_li[last_index_nan(x_li, y_li)] for x_li, y_li in zip(Xs, Ys)]
        )

        free_mark = Lines(
            marker="circle",
            stroke="black",
            scales={"x": free_scatter_sc_x, "y": free_scatter_sc_y},
            marker_size=30,
            colors=colors,
        )
        free_mark.x = np.array(Xs)
        free_mark.y = np.array(Ys)
        # Kept for update_free_play, which re-slices these per animation frame.
        free_mark_x = Xs
        free_mark_y = Ys

    free_scatter_fig.marks = [free_mark, free_labels]

    def title_free_scatter(data, norm, scale, data_type, ma, n):
        """Describe one axis, e.g. ' Cases/1M daily change (7days m.a)(log scale)'."""
        res = ""
        res += " " + data
        if norm == "Per million":
            res += "/1M"
        if data_type != "Total":
            res += " " + data_type.lower()
        if ma:
            res += " ({0:d}days m.a)".format(n)
        if scale == "Log":
            res += "(log scale)"
        return res

    # The trailing ENDT is replaced frame-by-frame by update_free_play.
    free_scatter_fig.title = (
        title_free_scatter(datas[1], norms[1], scales[1], data_types[1], False, None)
        + " vs "
        + title_free_scatter(datas[0], norms[0], scales[0], data_types[0], ma, n)
        + ", "
        + ENDT
    )
    free_pz.scales = {"x": [free_scatter_sc_x], "y": [free_scatter_sc_y]}
    return


In [172]:
def update_free_scatter_fig(*args):
    """Observer for the Tab 4 controls: rebuild the free-axis chart.

    ``args[0]`` is the traitlets change dict. For ``ToggleButton`` owners, only
    the button being switched on (``new == True``) triggers a redraw,
    presumably so that the button being switched off does not cause a second
    one. With nothing selected, the chart falls back to "USA".
    """
    change = args[0]
    from_toggle = type(change["owner"]).__name__ == "ToggleButton"
    if from_toggle and not (change["new"] == True):
        return

    picked_countries = free_countries_selector.value
    if "NONE" in picked_countries:
        countries = []
    else:
        countries = [countries_to_codes[name] for name in picked_countries]

    picked_states = free_states_selector.value
    states = [] if "NONE" in picked_states else picked_states

    if len(countries) + len(states) == 0:
        countries = ["USA"]

    ma, n = tab_4_ma_ch.value, tab_4_ma_w.value
    tab_4_ma_w.layout.visibility = "visible" if ma else "hidden"

    last_point = free_checkbox.value
    free_scatter(
        [countries, states],
        [x_data_button.value, y_data_button.value],
        [x_norm_button.value, y_norm_button.value],
        [x_scale_button.value, y_scale_button.value],
        [x_type_button.value, y_type_button.value],
        ma,
        n,
        last_point,
    )
    # The play control only animates the time-series (lines) view.
    free_play.layout.visibility = "hidden" if last_point else "visible"

    update_link_download()


# Redraw the Tab 4 chart whenever any axis control, the "Last data point"
# checkbox or the moving-average controls change. Toggle_Buttons groups use
# their own add_observe helper; plain ipywidgets use observe. The
# country/state selectors are not wired in this cell.
x_data_button.add_observe(update_free_scatter_fig, "value")
y_data_button.add_observe(update_free_scatter_fig, "value")
x_norm_button.add_observe(update_free_scatter_fig, "value")
y_norm_button.add_observe(update_free_scatter_fig, "value")
x_scale_button.add_observe(update_free_scatter_fig, "value")
y_scale_button.add_observe(update_free_scatter_fig, "value")
x_type_button.add_observe(update_free_scatter_fig, "value")
y_type_button.add_observe(update_free_scatter_fig, "value")
free_checkbox.observe(update_free_scatter_fig, "value")
tab_4_ma_ch.observe(update_free_scatter_fig, "value")
tab_4_ma_w.observe(update_free_scatter_fig, "value")


In [173]:
# Animation control for the Tab 4 lines view. Its value is used by
# update_free_play as a frame index into date_selector.options.
# NOTE(review): max=1 looks like a placeholder, presumably raised elsewhere
# to the number of dates -- confirm.
free_play = Play(
    value=0,
    min=0,
    max=1,
    step=1,
    description="Press play",
    disabled=False,
    layout=Layout(width="auto"),
)


def update_free_play(*args):
    """Animate the Tab 4 lines view up to the frame chosen by ``free_play``.

    ``args[0]`` is the traitlets change dict; ``args[0]["new"]`` is the frame
    index into ``date_selector.options``. The stored series
    (``free_mark_x``/``free_mark_y``, written by ``free_scatter``) are cut off
    after that frame. Each label is moved to the last point where both
    coordinates exist, and the 10-character date at the end of the title is
    replaced with the frame's date.
    """
    global free_mark_x, free_mark_y
    frame = args[0]["new"]
    frame_date = date_selector.options[frame]
    xs_upto = np.array(free_mark_x)[:, : frame + 1]
    ys_upto = np.array(free_mark_y)[:, : frame + 1]
    xs_upto, ys_upto = clean_line_plot(list(xs_upto), list(ys_upto))
    last_valid = [
        last_index_nan(x_row, y_row) for x_row, y_row in zip(xs_upto, ys_upto)
    ]

    labels_mark = free_scatter_fig.marks[1]
    labels_mark.x = np.array([x_row[k] for x_row, k in zip(xs_upto, last_valid)])
    labels_mark.y = np.array([y_row[k] for y_row, k in zip(ys_upto, last_valid)])

    lines_mark = free_scatter_fig.marks[0]
    lines_mark.x = np.array(xs_upto)
    lines_mark.y = np.array(ys_upto)

    free_scatter_fig.title = free_scatter_fig.title[:-10] + str(frame_date)


help_tab4 = help_button(help_tooltip["Tab 4"])

# Top row of the Tab 4 control column: the "Last data point" checkbox on the
# left and the play control, pushed to the right edge.
tab_4_top_right_buttons = VBox(
    [
        HBox(
            [free_checkbox, free_play],
            layout=Layout(justify_content="space-between", align_items="stretch"),
        ),
    ],
    layout=Layout(height="auto", width="auto"),
)


In [174]:
# Tab 4 page layout, a 2x2 grid:
#   left column (both rows): help button (32px) + free scatter/lines figure
#   top right (360px):       "Last data point"/play row, X/Y DATA panels,
#                            moving-average controls
#   bottom right:            predefined selections/scenarios box
grid_4 = GridspecLayout(2, 2)
grid_4.layout.overflow = "auto"
grid_4.layout.width = "100%"
grid_4.layout.height = "100%"
tab_4_top_right_buttons.layout.margin = "0px 3px 0px 5px"
grid_4_left_panel = GridspecLayout(1, 2)
grid_4_left_panel[0, 0] = help_tab4
grid_4_left_panel[0, 1] = free_scatter_fig
grid_4_left_panel.layout.grid_template_columns = "32px auto"
grid_4_left_panel.layout.height = "99%"
grid_4_left_panel.layout.overflow = "auto"
grid_4[:, 0] = grid_4_left_panel
grid_4[0, 1] = Box(
    [
        VBox(
            [tab_4_top_right_buttons, cat_tab_4_buttons, ma_box_tab_4],
            layout=Layout(
                height="360px", min_height="360px", max_height="360px", overflow="auto"
            ),
        )
    ]
)
grid_4[1, 1] = HBox(
    [tab_4_box_selectors], layout=Layout(min_height="210px", overflow="auto")
)
grid_4.layout.align_items = "stretch"
grid_4.layout.grid_template_columns = "auto 695px"
grid_4.layout.grid_template_rows = "360px auto"


In [175]:
# Per-tab registries of input widgets and the callback each one triggers,
# paired index by index (tab_N_buttons[k] -> tab_N_buttons_actions[k]).
# Presumably consumed elsewhere to re-trigger a tab's update after a saved
# selection is applied -- confirm against the caller.
tab_2_buttons = [
    rebased_graph_data_button,
    rebased_graph_norm_button,
    rebased_graph_type_button,
    plot_scale_button,
    thr_val_slider,
    rebased_graph_thr_data_button,
    rebased_graph_thr_norm_button,
    calendar_time,
    date_axis_selector_tab_2,
    tab_2_ma_ch,
    tab_2_ma_w,
    countries_selector,
    states_selector,
]
# Index 8 (date_axis_selector_tab_2) only updates the zoom; every other Tab 2
# control rebuilds the rebased graph. Must stay the same length as tab_2_buttons.
tab_2_buttons_actions = (
    [update_rebased_graph] * 8 + [update_zoom_tab_2] + [update_rebased_graph] * 4
)

tab_3_buttons = [
    dna_data_button,
    dna_norm_button,
    dna_type_button,
    tab_3_ma_ch,
    tab_3_ma_w,
    tab_3_norm,
    dna_countries_selector,
    dna_states_selector,
]
tab_3_buttons_actions = [update_dna] * len(tab_3_buttons)

tab_4_buttons = [
    x_data_button,
    x_norm_button,
    x_scale_button,
    x_type_button,
    y_data_button,
    y_norm_button,
    y_scale_button,
    y_type_button,
    free_checkbox,
    tab_4_ma_ch,
    tab_4_ma_w,
    free_countries_selector,
    free_states_selector,
]
tab_4_buttons_actions = [update_free_scatter_fig] * len(tab_4_buttons)

# Keys 1-3 presumably are the notebook tab indices of Tabs 2-4 -- confirm.
dict_tab_buttons = dict(zip([1, 2, 3], [tab_2_buttons, tab_3_buttons, tab_4_buttons]))
dict_tab_buttons_actions = dict(
    zip(
        [1, 2, 3], [tab_2_buttons_actions, tab_3_buttons_actions, tab_4_buttons_actions]
    )
)


In [176]:
def sync_selector_map(selector, chl, names_to_codes, map_id):
    """Mirror a selector's current selection onto a choropleth layer.

    Each selected name that has a code in ``names_to_codes`` and that code is
    in ``map_id`` is marked with 1. Every other id in ``map_id`` is set to NaN
    (unselected). Selected ids come first in the new ``choro_data`` dict,
    followed by the unselected ones in ``map_id`` order.
    """
    selected_ids = [
        names_to_codes[name]
        for name in selector.get_value()
        if name in names_to_codes and names_to_codes[name] in map_id
    ]
    choro = {code: 1 for code in selected_ids}
    unselected = {code: np.nan for code in map_id if code not in choro}
    choro.update(unselected)
    chl.choro_data = choro
    return


class CountryPicker:
    def map_change(self, val):
        if val["new"] == True:
            map_name = val["owner"].description
            chl, center, zoom = self.dict_button_map[map_name]
            self.map_picker.substitute_layer(self.map_picker.layers[1], chl)
            self.map_picker.center = center
            self.map_picker.zoom = zoom
        return

    def style_picker(self, feature, colormap, choro_data):
        """Leaflet style callback: grey fill for unselected (NaN) features, red for selected.

        ``feature`` and ``colormap`` are accepted to match the callback
        signature but are not used. The border is always thin and black.
        """
        fill = "#5a5a5a" if np.isnan(choro_data) else "red"
        return {"fillColor": fill, "fillOpacity": 1, "color": "black", "weight": 0.7}

    def map_click(self, feature, **kwargs):
        try:
            name = feature["properties"]["name"]
            _id = feature["id"]
            if (
                self.map_buttons.value == "World"
                and name in self.class_countries_selector.options
            ):
                new_data = {}
                new_values = []
                for k, v in self.chl_wm.choro_data.items():
                    new_data[k] = v
                    if not np.isnan(v) and k != _id:
                        new_values.append(k)
                if np.isnan(new_data[_id]):
                    new_data[_id] = 1
                    new_values.append(_id)
                else:
                    new_data[_id] = np.nan
                self.chl_wm.choro_data = new_data
                self.class_countries_selector.set_value(
                    [self.codes_to_countries[v] for v in new_values]
                )

            elif (
                self.map_buttons.value == "US States"
                and name in self.class_states_selector.options
            ):
                new_data = {}
                new_values = []
                for k, v in self.chl_us_states.choro_data.items():
                    new_data[k] = v
                    if not np.isnan(v) and k != _id:
                        new_values.append(k)

                if np.isnan(new_data[_id]):
                    new_data[_id] = 1
                    new_values.append(_id)
                else:
                    new_data[_id] = np.nan
                self.chl_us_states.choro_data = new_data
                self.class_states_selector.set_value(
                    [self.codes_to_states[v] for v in new_values]
                )

        except:
            pass

    def map_hover(self, feature, **kwargs):
        try:
            name = feature["properties"]["name"]
            self.tooltip.value = name
        except:
            self.tooltip.value = ""

    def button_update(self, val):
        """Switch the tab between the list-picker and map-picker layouts.

        ``val`` is a traitlets change dict from a toggle button. A button whose
        ``description`` is "List Picker" selects the list layout; any other
        description selects the map layout. Only the button being switched on
        (``val["new"] == True``) acts. Relies on a module-level ``hide_button``
        widget and on ``collapse_button``/``grid_picker``/``box_grid`` attributes,
        none of which are set within this view.
        """
        if val["new"] == True:
            descr = val["owner"].description
            if descr == "List Picker":
                # List layout: selector grid in the picker, left panel restored
                # to the full left column, selector box back in the bottom-right
                # cell.
                self.collapse_button.layout.visibility = "hidden"
                self.grid_picker[0, 1] = self.box_grid
                self.outer_grid[:, 0] = self.left_panel
                self.outer_grid[1, 1] = self.tab_box_selector
                self.box_grid.layout.grid_template_columns = "45% 5% 45%"
                self.outer_grid.layout.grid_template_rows = self.grid_template_rows
                hide_button.value = False
            else:
                # Map layout: map in the picker, selector box spans the whole
                # second row.
                self.collapse_button.layout.visibility = "visible"
                self.grid_picker[0, 1] = self.map_picker
                if self.hide:
                    # Collapsed: an empty HTML placeholder replaces the left
                    # panel in the top-left cell.
                    self.outer_grid[0, 0] = HTML()
                    row_template = self.grid_template_rows
                else:
                    self.outer_grid[0, 0] = self.left_panel
                    row_template = "30% 70%"
                self.outer_grid[1, :] = self.tab_box_selector
                self.outer_grid.layout.grid_template_rows = row_template
                self.tab_box_selector.layout.width = "100%"

            # Applied in both layouts.
            self.outer_grid.layout.grid_template_columns = self.grid_template_columns
            self.grid_picker.layout.grid_template_columns = "110px auto"

    def __init__(
        self,
        countries_selector,
        states_selector,
        codes_to_countries,
        countries_to_codes,
        codes_to_states,
        states_to_codes,
        outer_grid,
        left_panel,
        grid_template_columns,
        grid_template_rows,
        tab_box_selector,
        hide=True,
    ):
        """Build a country / US-state picker with a list view and a map view.

        Two modes are offered, switched by ``self.pick_buttons``:

        * "List Picker" -- the countries and US-states list selectors shown
          side by side (``self.box_grid``);
        * "Map Picker" -- an ipyleaflet map with one choropleth layer per map
          mode ("World" countries or "US States", switched by
          ``self.map_buttons``) whose regions can be clicked and hovered.

        The complete picker widget is exposed as ``self.grid_picker``.

        Parameters
        ----------
        countries_selector, states_selector :
            Selector wrapper objects. Their ``.res`` attribute is the widget
            shown in the list view, and their ``all_observe`` hook is used to
            re-sync the matching map layer whenever the selection changes.
        codes_to_countries, countries_to_codes, codes_to_states, states_to_codes : dict
            Mappings between region codes and display names, stored on the
            instance for the map handlers.
        outer_grid, left_panel, grid_template_columns, grid_template_rows, tab_box_selector :
            Layout pieces of the owning dashboard tab, stored on the instance;
            presumably rearranged by ``self.button_update`` when the mode
            changes -- TODO confirm.
        hide : bool, default True
            Stored as ``self.hide``; presumably decides whether the left panel
            is hidden while the map picker is open -- TODO confirm in
            ``button_update``.
        """
        # NOTE(review): these names are only read here, so the ``global``
        # statement is not required; ``us_states_map_id`` is also read from
        # module scope without being listed.
        global global_style, styling, countries_geojson, countries_map_id, us_states_geojson
        # Keep both the selector wrappers and the widgets they expose.
        self.class_countries_selector = countries_selector
        self.countries_selector = countries_selector.res
        self.class_states_selector = states_selector
        self.states_selector = states_selector.res
        self.codes_to_countries = codes_to_countries
        self.countries_to_codes = countries_to_codes
        self.codes_to_states = codes_to_states
        self.states_to_codes = states_to_codes
        self.outer_grid = outer_grid
        self.left_panel = left_panel
        self.grid_template_columns = grid_template_columns
        self.grid_template_rows = grid_template_rows
        self.tab_box_selector = tab_box_selector
        self.hide = hide
        # Hover read-out, docked in the bottom-left corner of the map.
        self.tooltip = HTML()
        widget_control = ipl.WidgetControl(widget=self.tooltip, position="bottomleft")

        # "World" / "US States" layer switch, docked in the top-right corner
        # of the map and handled by self.map_change.
        self.map_buttons = Toggle_Buttons(
            options=["World", "US States"],
            value="World",
            description="",
            min_button_width="100px",
            min_description_width="",
            style="warning",
            key="none",
        )
        widget_control2 = ipl.WidgetControl(
            widget=self.map_buttons, position="topright"
        )
        self.map_buttons.add_observe(self.map_change, "value")
        # Colormap shared by both layers; per-feature styling is delegated to
        # self.style_picker.
        colormap = cm.StepColormap(colors=["black", "red"], index=[0, 1])

        # World-countries layer. Every region starts with a NaN value
        # (presumably "not selected"; sync_selector_map updates it -- confirm).
        self.chl_wm = ipl.Choropleth(
            geo_data=countries_geojson,
            choro_data=dict(zip(countries_map_id, [np.nan] * len(countries_map_id))),
            colormap=colormap,
            hover_style={"color": "white", "weight": 4},
            style_callback=self.style_picker,
        )

        # US-states layer, built the same way.
        self.chl_us_states = ipl.Choropleth(
            geo_data=us_states_geojson,
            choro_data=dict(zip(us_states_map_id, [np.nan] * len(us_states_map_id))),
            colormap=colormap,
            hover_style={"color": "white", "weight": 4},
            style_callback=self.style_picker,
        )

        # Only the world layer is added initially; dict_button_map below holds
        # the layer, centre and zoom used for each map mode.
        self.map_picker = ipl.Map(
            basemap={},
            center=[24.68, 4.39],
            zoom=2,
            layout=Layout(width="99%", height="auto"),
        )
        self.map_picker.add_layer(self.chl_wm)
        self.map_picker.add_control(widget_control)
        self.map_picker.add_control(widget_control2)

        # Both layers share the same click and hover handlers.
        self.chl_wm.on_click(self.map_click)
        self.chl_us_states.on_click(self.map_click)
        self.chl_wm.on_hover(self.map_hover)
        self.chl_us_states.on_hover(self.map_hover)

        # Re-sync each map layer whenever its list selector changes.
        countries_selector.all_observe(
            lambda *args: sync_selector_map(
                countries_selector,
                self.chl_wm,
                self.countries_to_codes,
                countries_map_id,
            )
        )
        states_selector.all_observe(
            lambda *args: sync_selector_map(
                states_selector,
                self.chl_us_states,
                self.states_to_codes,
                us_states_map_id,
            )
        )
        # List view: a titled box per selector.
        countries_box = VBox(
            [
                HTML(
                    '<p style="text-align:center; padding:0px; margin:0px">Countries</p>'
                ),
                self.countries_selector,
            ],
            layout=Layout(width="100%"),
        )
        states_box = VBox(
            [
                HTML(
                    '<p style="text-align:center; padding:0px; margin:0px">US States</p>'
                ),
                self.states_selector,
            ],
            layout=Layout(width="100%"),
        )
        # Three columns: countries | empty spacer | states.
        box_grid = GridspecLayout(1, 3)
        box_grid.layout = Layout(width="100%")
        box_grid[0, 0] = countries_box
        box_grid[0, 2] = states_box

        self.box_grid = box_grid

        # Vertical "List Picker" / "Map Picker" switch, handled by
        # self.button_update.
        self.pick_buttons = Toggle_Buttons(
            options=["List Picker", "Map Picker"],
            value="List Picker",
            description="",
            min_button_width="100px",
            min_description_width="0px",
            horizontal=False,
            button_width="auto",
            style="warning",
            key="none",
        )
        # Starts hidden; clicking it switches back to the list picker.
        self.collapse_button = Button(
            description="Collapse",
            button_style="danger",
            layout=Layout(width="auto", visibility="hidden", margin="10px 0px 0px 0px"),
        )
        # Picker layout: mode buttons in a 110px left column, content on the right.
        grid_picker = GridspecLayout(1, 2)
        grid_picker.layout = Layout(width="100%", height="100%", min_height="60px")
        grid_picker[0, 0] = VBox([self.pick_buttons, self.collapse_button])
        grid_picker[0, 1] = self.box_grid
        grid_picker.layout.grid_template_columns = "110px auto"
        self.grid_picker = grid_picker
        self.box_grid.layout.grid_template_columns = "45% 5% 45%"
        # Mode -> list-view widget, or [choropleth layer, map centre, zoom]
        # for each map mode.
        self.dict_button_map = {
            "List Picker": self.box_grid,
            "World": [self.chl_wm, [24.68, 4.39], 2],
            "US States": [self.chl_us_states, [37.57, -96.67], 4],
        }
        self.pick_buttons.add_observe(self.button_update, "value")
        self.collapse_button.on_click(
            lambda x: self.pick_buttons.set_value("List Picker")
        )


In [177]:
def initialize_dashboard(dict_datasets):
    """Populate all four dashboard tabs once the datasets are available.

    Parameters
    ----------
    dict_datasets : dict
        Dataset objects keyed by level: ``0`` world countries and ``1`` US
        states. Entry ``2`` (US counties) is not used here.

    Side effects
    ------------
    * Rebinds the module-level ``top_countries_names``, ``states`` and
      ``dict_scenario``, and fills ``presaved_sel_lists`` with the continent
      and "Top10" preset selections.
    * Configures the widgets of grids 1-4, creates their ``CountryPicker``
      instances, draws the initial figures and attaches the widget observers.
    """
    global top_countries_names, states, presaved_sel_lists, dict_scenario

    world = dict_datasets[0]
    us_states = dict_datasets[1]
    last_index = world.get_len() - 1

    # ---- grid_1: infection map, tables and main graph ----------------------

    date_selector.options = world.get_index("Cases")
    date_selector.value = world.ENDT

    date_axis_selector.options = [
        datetime.datetime.strftime(x, "%m/%d/%y")
        for x in pd.date_range(world.STDT, world.ENDT).to_list()
    ]

    date_axis_selector.index = (0, last_index)

    play.max = last_index
    play.value = last_index

    map_color(
        world,
        countries_map_id,
        data_buttons.value,
        norm_buttons.value,
        type_buttons.value,
        tab_1_ma_ch.value,
        tab_1_ma_w.value,
        world.ENDT,
    )

    main_map_title.value = main_map_title_val.format(
        map_title(
            world.name,
            data_buttons.value,
            norm_buttons.value,
            type_buttons.value,
            date_selector.value,
            tab_1_ma_ch.value,
            tab_1_ma_w.value,
        )
    )

    update_tables(
        data_buttons.value,
        norm_buttons.value,
        scale_buttons.value,
        type_buttons.value,
        tab_1_ma_ch.value,
        tab_1_ma_w.value,
        world.ENDT,
    )

    update_main_graph(
        world,
        "USA",
        "Values",
        "Linear",
        "Total",
        tab_1_ma_ch.value,
        tab_1_ma_w.value,
        world.STDT,
        world.ENDT,
        axis_t,
        axis_y,
        main_graph,
        init=True,
    )

    # ---- preset selection lists ---------------------------------------------

    # Computed once instead of once per candidate code.
    case_columns = world.get_columns("Cases")

    def _country_names(codes, exclude=()):
        """Display names of ``codes`` that have a known name and a world
        "Cases" column, in the given order, skipping codes in ``exclude``."""
        return [
            codes_to_countries[c]
            for c in codes
            if c in codes_to_countries and c in case_columns and c not in exclude
        ]

    def _top10_names(values):
        """Names of the entries ranked 2nd to 11th by value (the single
        largest entry is skipped), filtered like ``_country_names``."""
        ranked_codes = values.sort_values(ascending=False).iloc[1:11].index.values
        return _country_names(ranked_codes)

    # pandas parses the continent code "NA" (North America) as a missing
    # value by default; fillna restores it.
    continents = pd.read_csv(
        FOLDER_WORLD + "continents_restricted.csv", sep=";", index_col=1
    ).fillna("NA")
    # EGY is excluded from the Asia and North America lists
    # (NOTE(review): presumably because the CSV lists Egypt under more than
    # one continent -- confirm).
    for list_name, continent_code, excluded in (
        ("Europe", "EU", ()),
        ("Africa", "AF", ()),
        ("Asia", "AS", ("EGY",)),
        ("North America", "NA", ("EGY",)),
        ("South America", "SA", ()),
        ("Oceania", "OC", ()),
    ):
        presaved_sel_lists[list_name] = _country_names(
            continents[continents["CC"] == continent_code].index.values, excluded
        )

    # Countries ranked by total cases on ENDT. The top entry is skipped
    # (presumably the world aggregate, which is offered as "World" below).
    top_countries_codes = (
        world.dict_data["Cases"]
        .loc[ENDT]
        .sort_values(ascending=False)
        .iloc[1:]
        .index.values
    )
    top_countries_names = [codes_to_countries[c] for c in top_countries_codes]
    states = list(
        us_states.dict_data["Cases"]
        .loc[us_states.ENDT, np.array(list(states_to_codes.keys()))]
        .sort_values(ascending=False)
        .index.values
    )
    presaved_sel_lists["Top10 (Cases)"] = _country_names(top_countries_codes[:10])
    presaved_sel_lists["Top10 (Deaths)"] = _top10_names(
        world.dict_data["Deaths"].loc[ENDT]
    )
    presaved_sel_lists["Top10 (Cases/1M)"] = _top10_names(
        world.dict_data_pop["Cases"].loc[ENDT]
    )
    presaved_sel_lists["Top10 (Deaths/1M)"] = _top10_names(
        world.dict_data_pop["Deaths"].loc[ENDT]
    )
    presaved_sel_lists["Top10 (Cases daily chg)"] = _top10_names(
        world.get_value(None, "Cases", "Values", "Linear", "Daily change", False, 7, ENDT)
    )
    presaved_sel_lists["Top10 (Deaths daily chg)"] = _top10_names(
        world.get_value(None, "Deaths", "Values", "Linear", "Daily change", False, 7, ENDT)
    )

    # ---- grid_2: time track ---------------------------------------------------

    date_axis_selector_tab_2.options = date_axis_selector.options
    date_axis_selector_tab_2.index = (0, last_index)
    countries_options = list(np.sort(top_countries_names))
    states_options = list(np.sort(states))
    countries_selector.set_option(["World"] + countries_options)
    states_selector.set_option(states_options)

    cp_tab_2 = CountryPicker(
        countries_selector,
        states_selector,
        codes_to_countries,
        countries_to_codes,
        codes_to_states,
        states_to_codes,
        grid_2,
        grid_2_left_panel,
        "auto 590px",
        "315px auto",
        tab_2_box_selectors,
        hide=False,
    )
    tab_2_box_selectors[1, :] = cp_tab_2.grid_picker
    countries_selector.set_value(top_countries_names[:6])
    states_selector.set_value(states[:2])
    countries_selector.all_observe(update_rebased_graph)
    states_selector.all_observe(update_rebased_graph)
    tab_2_box_selectors.layout.grid_template_rows = "80px auto"
    plot_rebased_graph(
        "Cases",
        "Values",
        "Total",
        1000,
        top_countries_names[:6],
        states[:2],
        "Linear",
        "Cases",
        "Values",
        tab_2_ma_ch.value,
        tab_2_ma_w.value,
        calendar_time.value,
    )

    # ---- grid_3: heat map -------------------------------------------------------

    dna_countries_selector.set_option(["World"] + countries_options)
    dna_states_selector.set_option(states_options)
    cp_tab_3 = CountryPicker(
        dna_countries_selector,
        dna_states_selector,
        codes_to_countries,
        countries_to_codes,
        codes_to_states,
        states_to_codes,
        grid_3,
        grid_3_left_panel,
        "auto 590px",
        "140px auto",
        tab_3_box_selectors,
    )
    tab_3_box_selectors[1, :] = cp_tab_3.grid_picker
    dna_countries_selector.set_value(top_countries_names[1:20])
    dna_states_selector.set_value(states[:4])
    dna_countries_selector.all_observe(update_dna)
    dna_states_selector.all_observe(update_dna)

    # The last two get_ts arguments are the moving-average switch and its
    # window, as for the other tabs (tab_N_ma_ch, tab_N_ma_w).
    data_countries = (
        world.get_ts(
            [countries_to_codes[c] for c in dna_countries_selector.value],
            dna_data_button.value,
            dna_norm_button.value,
            "Linear",
            dna_type_button.value,
            tab_3_ma_ch.value,
            tab_3_ma_w.value,
        )
        .rename(columns=codes_to_countries)
    )
    data_states = us_states.get_ts(
        dna_states_selector.value,
        dna_data_button.value,
        dna_norm_button.value,
        "Linear",
        dna_type_button.value,
        tab_3_ma_ch.value,
        tab_3_ma_w.value,
    )
    df_DNA = pd.merge(
        data_countries, data_states, left_index=True, right_index=True, how="outer"
    ).iloc[:, ::-1]

    column_dna = pd.to_datetime(df_DNA.index.values)
    row_dna = list(df_DNA.columns.values)
    dna_heat_map = GridHeatMap(
        color=df_DNA.values.T,
        column=column_dna,
        row=row_dna,
        column_align="start",
        scales={"column": dna_x_scale, "row": dna_y_scale, "color": dna_color_scale},
        stroke="white",
    )
    dna_figure.marks = [dna_heat_map]
    dna_heat_map.tooltip = dna_stats_tootltip
    dna_heat_map.on_hover(update_dna_tooltip)
    tab_3_box_selectors.layout.grid_template_rows = "80px auto"

    # ---- grid_4: custom scatter graph ------------------------------------------

    free_countries_selector.set_option(["World"] + countries_options)
    free_states_selector.set_option(states_options)
    cp_tab_4 = CountryPicker(
        free_countries_selector,
        free_states_selector,
        codes_to_countries,
        countries_to_codes,
        codes_to_states,
        states_to_codes,
        grid_4,
        grid_4_left_panel,
        "auto 695px",
        "360px auto",
        tab_4_box_selectors,
        hide=False,
    )
    tab_4_box_selectors[1, :] = cp_tab_4.grid_picker
    free_countries_selector.set_value(
        [
            "United States",
            "France",
            "World",
            "Italy",
            "China",
            "Germany",
            "Russia",
            "Brazil",
        ]
    )
    free_states_selector.set_value(["New York", "California"])
    free_countries_selector.all_observe(update_free_scatter_fig)
    free_states_selector.all_observe(update_free_scatter_fig)

    free_play.max = last_index
    free_play.value = last_index
    selected_countries = [
        ["USA", "FRA", "WLD", "ITA", "CHN", "DEU", "RUS", "BRA"],
        ["New York", "California"],
    ]
    datas = ["Cases", "Cases"]
    norms = ["Values", "Values"]
    scales = ["Log", "Log"]
    data_types = ["Total", "Daily change"]
    last_point = False
    free_scatter(
        selected_countries,
        datas,
        norms,
        scales,
        data_types,
        tab_4_ma_ch.value,
        tab_4_ma_w.value,
        last_point,
    )

    # ---- observers and header controls -----------------------------------------

    # Attached only now, so the initial values set above do not trigger them.
    date_axis_selector.observe(update_zoom, "value")
    date_axis_selector.continuous_update = False
    date_axis_selector_tab_2.observe(update_zoom_tab_2, "value")
    date_axis_selector_tab_2.continuous_update = False
    play.observe(update_index_date, "value")
    date_selector.observe(select_date, "value")
    free_play.observe(update_free_play, "value")
    tab_4_box_selectors.layout.grid_template_rows = "80px auto"
    download_button.layout.visibility = "visible"
    hide_button.layout.visibility = "visible"
    update_link_download()

    # Preset scenarios, presumably keyed by outer-tab index. The meaning of
    # each list position is defined by the scenario handlers elsewhere
    # (NOTE(review): confirm there).
    dict_scenario = {
        1: {
            "Scenario 1": [
                "Cases",
                "Per million",
                "Daily change",
                "Linear",
                None,
                None,
                None,
                True,
                # Last 30 days; the explicit "D" unit keeps this a day span
                # whatever datetime64 unit ENDT converts to.
                [np.datetime64(ENDT) - np.timedelta64(30, "D"), np.datetime64(ENDT)],
                True,
                7,
                presaved_sel_lists["Top10 (Cases)"],
                [],
            ],
            "Scenario 2": [
                "Cases",
                "Values",
                "Total",
                "Log",
                None,
                None,
                None,
                True,
                [np.datetime64("2020-05-25"), np.datetime64(ENDT)],
                False,
                None,
                [],
                [
                    "California",
                    "Texas",
                    "Florida",
                    "Georgia",
                    "Arizona",
                    "North Carolina",
                    "Louisiana",
                    "Alabama",
                ],
            ],
            "Scenario 3": [
                "Deaths",
                "Values",
                "Total",
                "Linear",
                1000,
                "Deaths",
                "Values",
                False,
                None,
                False,
                None,
                presaved_sel_lists["Top10 (Deaths)"],
                [],
            ],
            "Scenario 4": [
                "Cases",
                "Values",
                "Daily change",
                "Linear",
                1000,
                "Cases",
                "Values",
                False,
                None,
                True,
                7,
                [
                    "Japan",
                    "Germany",
                    "United Kingdom",
                    "France",
                    "Italy",
                    "South Korea",
                    "Spain",
                    "Belgium",
                    "Netherlands",
                ],
                [],
            ],
        },
        2: {
            "Scenario 3": [
                "Deaths",
                "Per million",
                "Daily change",
                True,
                7,
                False,
                [
                    "United States",
                    "China",
                    "Japan",
                    "Germany",
                    "India",
                    "United Kingdom",
                    "France",
                    "Italy",
                    "Brazil",
                    "Canada",
                    "Russia",
                    "South Korea",
                    "Spain",
                    "Australia",
                    "Mexico",
                    "Indonesia",
                    "Netherlands",
                    "Saudi Arabia",
                    "Turkey",
                    "Switzerland",
                ],
                [],
            ],
            "Scenario 2": [
                "Cases",
                "Per million",
                "Daily change",
                True,
                7,
                False,
                [],
                dna_states_selector.options,
            ],
            "Scenario 1": [
                "Cases",
                "Per million",
                "Daily change",
                True,
                7,
                False,
                [
                    "United States",
                    "China",
                    "Hong Kong",
                    "Japan",
                    "Germany",
                    "India",
                    "United Kingdom",
                    "France",
                    "Italy",
                    "Brazil",
                    "Canada",
                    "Russia",
                    "South Korea",
                    "Spain",
                    "Australia",
                    "Mexico",
                    "Indonesia",
                    "Netherlands",
                    "Saudi Arabia",
                    "Turkey",
                    "Switzerland",
                ],
                [],
            ],
            "Scenario 4": [
                "Cases",
                "Values",
                "Daily change",
                True,
                7,
                True,
                [
                    "China",
                    "Hong Kong",
                    "South Korea",
                    "Iran",
                    "Australia",
                    "Italy",
                    "Germany",
                    "Spain",
                    "United Kingdom",
                    "France",
                    "Turkey",
                    "Japan",
                    "Canada",
                    "Russia",
                    "Saudi Arabia",
                    "United States",
                    "South Africa",
                    "Brazil",
                    "Colombia",
                    "India",
                    "World",
                ],
                [],
            ],
        },
        3: {
            "Scenario 1": [
                "Cases",
                "Per million",
                "Log",
                "Total",
                "Cases",
                "Per million",
                "Log",
                "Daily change",
                False,
                True,
                7,
                [
                    "United States",
                    "Brazil",
                    "France",
                    "Germany",
                    "Russia",
                    "China",
                    "South Korea",
                    "Japan",
                    "Hong Kong",
                    "United Kingdom",
                    "Spain",
                    "Italy",
                ],
                [],
            ],
            "Scenario 2": [
                "Tests",
                "Per million",
                "Linear",
                "Total",
                "Cases",
                "Per million",
                "Linear",
                "Total",
                False,
                True,
                7,
                [
                    "World",
                    "United States",
                    "Japan",
                    "Germany",
                    "India",
                    "United Kingdom",
                    "France",
                    "Italy",
                    "Brazil",
                    "Canada",
                    "Russia",
                    "South Africa",
                    "Iran",
                    "Spain",
                ],
                [],
            ],
            "Scenario 3": [
                "Cases",
                "Per million",
                "Linear",
                "Total",
                "Vaccine",
                "Per million",
                "Linear",
                "Total",
                False,
                True,
                7,
                [
                    "World",
                    "United States",
                    "Japan",
                    "Germany",
                    "India",
                    "United Kingdom",
                    "France",
                    "Italy",
                    "Israel",
                    "Brazil",
                    "Canada",
                    "Russia",
                    "Spain",
                ],
                [],
            ],
            "Scenario 4": [
                "Cases",
                "Per million",
                "Linear",
                "Total",
                "Deaths",
                "Per million",
                "Linear",
                "Total",
                True,
                False,
                None,
                [
                    "Austria",
                    "Azerbaijan",
                    "Belarus",
                    "Belgium",
                    "Bosnia",
                    "Bulgaria",
                    "Croatia",
                    "Cyprus",
                    "Czechia",
                    "Denmark",
                    "Estonia",
                    "Finland",
                    "France",
                    "Georgia",
                    "Germany",
                    "Greece",
                    "Hungary",
                    "Iceland",
                    "Ireland",
                    "Italy",
                    "Kosovo",
                    "Latvia",
                    "Lithuania",
                    "Macedonia",
                    "Malta",
                    "Moldova",
                    "Monaco",
                    "Montenegro",
                    "Netherlands",
                    "Norway",
                    "Poland",
                    "Portugal",
                    "Romania",
                    "Russia",
                    "Serbia",
                    "Slovakia",
                    "Slovenia",
                    "Spain",
                    "Sweden",
                    "Switzerland",
                    "Turkey",
                    "Ukraine",
                    "United Kingdom",
                ],
                [],
            ],
        },
    }


In [178]:
def skip_click(*args):
    """Download every dataset, then swap the loading screen for the dashboard.

    Bound to ``hp.skip_button``; ``*args`` (the clicked button) is ignored.

    Side effects: rebinds the module-level ``STDT``, ``ENDT``,
    ``countries_to_codes``, ``codes_to_countries`` and ``dict_datasets``, and
    publishes ``progress_bar`` / ``current_action`` as globals (presumably so
    the data collectors can report progress -- TODO confirm). Finally it
    installs the six dashboard tabs and selects the first one.
    """
    global STDT, ENDT, countries_to_codes, codes_to_countries, dict_datasets, progress_bar, current_action
    progress_bar = FloatProgress(
        value=0.1,
        min=0,
        max=1.0,
        step=0.01,
        description="Loading data:",
        orientation="horizontal",
    )
    current_action = HTML("Collecting country level Cases/Deaths data...")
    outer_tab.children = [
        VBox(
            [HTML("ONE MOMENT PLEASE..."), progress_bar, current_action],
            layout=Layout(width="auto"),
        )
    ]
    outer_tab._titles = {0: "Downloading..."}
    (
        datasets_World,
        STDT,
        ENDT,
        countries_to_codes,
        codes_to_countries,
    ) = collect_World_data(FOLDER_WORLD)
    datasets_US_States, datasets_US_counties = collect_US_data(FOLDER_US)
    dict_datasets = {0: datasets_World, 1: datasets_US_States, 2: datasets_US_counties}

    # Update the displayed widget; rebinding the global name to a number
    # would leave the bar on screen unchanged. Reported before building the
    # dashboard so the message matches the work in progress.
    progress_bar.value = 0.98
    current_action.value = "Create Dashboard..."
    initialize_dashboard(dict_datasets)

    tab_contents = [
        "INFECTION MAPS",
        "TIME TRACK",
        "HEAT MAP",
        "CUSTOM GRAPH",
        "USER GUIDE",
        "DATA SOURCES",
    ]
    hp.skip_button.layout.visibility = "hidden"
    help_page.layout.grid_template_rows = "0px 70px auto auto"
    children = [
        grid_1,
        grid_2,
        grid_3,
        grid_4,
        help_page,
        data_sources_tab,
    ]
    outer_tab._titles = dict(enumerate(tab_contents))
    outer_tab.children = children
    outer_tab.selected_index = 0


hp.skip_button.on_click(skip_click)


In [179]:
# Per-tab [grid_template_rows, grid_template_columns] applied by hide_buttons
# while the controls are shown.
visible_template = {
    tab: [rows, columns]
    for tab, rows, columns in (
        (1, "210px auto", "auto 565px"),
        (2, "315px auto", "auto 590px"),
        (3, "140px auto", "auto 590px"),
        (4, "360px auto", "auto 695px"),
    )
}

# Same, while the controls are hidden. The first row collapses to 0px on every
# tab. The second column also collapses, except on tab 1, which keeps 565px.
hidden_template = {
    tab: ["0px auto", "auto 565px" if tab == 1 else "auto 0px"]
    for tab in range(1, 5)
}

# Tab number -> grid widget whose layout the templates above are applied to.
grids = {
    tab: grid
    for tab, grid in zip(np.arange(1, 5), (grid_1, grid_2, grid_3, grid_4))
}


def hide_buttons(*args):
    """Toggle the control panels of the four dashboard grids.

    Observer for ``hide_button``'s ``value``: ``args[0]`` is the traitlets
    change dictionary. When its ``"new"`` entry is truthy, every grid switches
    to ``hidden_template`` and the button offers "Show Controls". Otherwise
    ``visible_template`` is restored and the button offers "Hide Controls".
    """
    change = args[0]
    if not change["new"]:
        style, label, templates = "warning", "Hide Controls", visible_template
    else:
        style, label, templates = "danger", "Show Controls", hidden_template

    hide_button.button_style = style
    hide_button.description = label
    for tab in range(1, 5):
        rows, columns = templates[tab]
        grids[tab].layout.grid_template_rows = rows
        grids[tab].layout.grid_template_columns = columns


In [180]:
# Header toggle that collapses or restores the control panels of all four
# grids (see hide_buttons). It starts hidden; initialize_dashboard reveals it.
hide_button = ToggleButton(
    value=False,
    description="Hide Controls",
    button_style="warning",
    layout=Layout(
        visibility="hidden",
        min_height="30px",
        min_width="150px",
        width="150px",
        max_width="150px",
    ),
)

hide_button.observe(hide_buttons, names="value")


In [181]:
# Header button that exports the active tab's figure as a PNG (see
# download_fig). It starts hidden; initialize_dashboard reveals it.
download_button = Button(
    button_style="",
    description="Download Figure",
    layout=Layout(
        visibility="hidden",
        min_height="30px",
        min_width="150px",
        width="150px",
        max_width="150px",
    ),
)


In [182]:
def download_fig(*args):
    """Save the figure of the currently selected dashboard tab as a PNG.

    Bound to ``download_button``; ``*args`` (the clicked button) is ignored.
    Tabs without a figure (user guide, data sources) are a no-op.

    The file name is an optional tab prefix plus the figure title, with "."
    removed and "/" replaced by "Per". This matches the CSV names built by
    ``download_data``, and a "/" cannot appear in a file name.
    """
    selected_index = outer_tab.selected_index
    if selected_index == 0:
        figure, prefix = main_graph, ""
    elif selected_index == 1:
        figure, prefix = rebased_graph, "Time_Track_"
    elif selected_index == 2:
        figure, prefix = dna_figure, "Heat_map_"
    elif selected_index == 3:
        figure, prefix = free_scatter_fig, ""
    else:
        return
    figure.save_png(prefix + figure.title.replace(".", "").replace("/", "Per"))


download_button.on_click(download_fig)


In [183]:
def download_data():
    """Return the data behind the current tab's figure and a CSV file name.

    Returns
    -------
    tuple of (pandas.DataFrame, str)
        The plotted values, transposed so that the figure's x values (dates,
        or integer positions on the time-track graph when calendar time is
        off) form the rows. The second element is a ``*_data.csv`` file name
        derived from the figure title.

    Raises
    ------
    ValueError
        If the selected tab has no exportable figure, or the time-track graph
        currently has no series.
    """
    selected_index = outer_tab.selected_index
    if selected_index == 0:
        data_csv = pd.DataFrame(
            data=main_graph.marks[0].y,
            index=main_graph.marks[0].labels,
            columns=main_graph.marks[0].x,
        ).transpose()
        title = (
            main_graph.title.replace(".", "").replace("/", "Per").replace("Log ", "")
            + "_data.csv"
        )
    elif selected_index == 1:
        # marks[1] holds one text label per plotted series.
        n_series = rebased_graph.marks[1].text.shape[0]
        if n_series == 1:
            # A single series is stored 1-D; reshape it into one row.
            cols = np.arange(0, rebased_graph.marks[0].y.shape[0])
            data = rebased_graph.marks[0].y.reshape(1, -1)
        elif n_series > 1:
            cols = np.arange(0, rebased_graph.marks[0].y.shape[1])
            data = rebased_graph.marks[0].y
        else:
            raise ValueError("The time-track graph has no series to export.")
        if calendar_time.value:
            data_csv = pd.DataFrame(
                data=data,
                index=rebased_graph.marks[1].text,
                columns=rebased_graph.marks[0].x,
            ).transpose()
        else:
            data_csv = pd.DataFrame(
                data=data, index=rebased_graph.marks[1].text, columns=cols
            ).transpose()
        title = (
            "Time_Track_"
            + rebased_graph.title.replace(".", "")
            .replace("/", "Per")
            .replace("Log ", "")
            + "_data.csv"
        )
    elif selected_index == 2:
        # Rows are reversed back, presumably undoing the reversal applied when
        # the heat map was built -- confirm.
        data_csv = (
            pd.DataFrame(
                data=dna_figure.marks[0].color,
                index=dna_figure.marks[0].row,
                columns=dna_figure.marks[0].column,
            )
            .iloc[::-1]
            .transpose()
        )
        title = (
            "Heat_map_"
            + dna_figure.title.replace(".", "").replace("/", "Per")
            + "_data.csv"
        )
    elif selected_index == 3:
        # Two columns per series: (series, x label) and (series, y label).
        arr = []
        ind0 = []
        ind1 = []
        y_label = free_scatter_fig.title.split("vs")[0].strip().replace("Log ", "")
        x_label = (
            free_scatter_fig.title.split("vs")[1]
            .split(",")[0]
            .strip()
            .replace("Log ", "")
        )
        if tab_4_ma_ch.value and "(" in x_label:
            # Copy the x label's parenthesised suffix (presumably the
            # moving-average note) onto the y label.
            y_label += " (" + x_label.split("(")[1]
        for i in range(free_scatter_fig.marks[1].text.shape[0]):
            arr.append(free_scatter_fig.marks[0].x[i])
            arr.append(free_scatter_fig.marks[0].y[i])
            ind0.append(free_scatter_fig.marks[1].text[i])
            ind0.append(free_scatter_fig.marks[1].text[i])
            ind1.append(x_label)
            ind1.append(y_label)
        ind = [np.array(ind0), np.array(ind1)]
        if free_checkbox.value:
            col = [ENDT]
        else:
            col = np.arange(
                np.datetime64(STDT), np.datetime64(ENDT) + np.timedelta64(1, "D")
            )
        data_csv = pd.DataFrame(data=arr, index=ind, columns=col).transpose()
        title = (
            free_scatter_fig.title.replace(".", "")
            .replace("/", "Per")
            .replace("Log ", "")
            + "_data.csv"
        )
    else:
        raise ValueError("Tab {} has no downloadable data.".format(selected_index))
    return data_csv, title


# HTML template for the "Download Data" link. The CSV is embedded as a base64
# data: URI; {visibility}, {filename} and {payload} are filled via str.format.
download_link = (
    '<a style="'
    "background-color: #424242; border: none; color: white; "
    "padding: 1px 22px; visibility: {visibility}; text-align: center; "
    "text-decoration: none; display: inline-block; font-size: 13px; "
    'margin: 0px; cursor: pointer" '
    'download="{filename}" '
    'href="data:text/csv;base64,{payload}" '
    'target="_blank">Download Data</a>'
)
# The data link starts hidden with an empty payload; update_link_download
# fills it in once a figure is available.
download_data_button = HTML(
    value=download_link.format(filename="", payload="", visibility="hidden")
)


def update_link_download():
    """Refresh the "Download Data" link with the current tab's figure data.

    The CSV produced by ``download_data`` is embedded in the link as a base64
    ``data:`` URI. This is best effort: if the export fails (for example on a
    tab without a figure), the link is hidden rather than left pointing at
    data from a previously shown figure, and no exception escapes to the
    calling widget callback.
    """
    try:
        df, filename = download_data()
        payload = base64.b64encode(df.to_csv().encode()).decode()
    except Exception:
        download_data_button.value = download_link.format(
            visibility="hidden", payload="", filename=""
        )
        return
    download_data_button.value = download_link.format(
        visibility="visible", payload=payload, filename=filename
    )


def hide_show_top_buttons(val):
    """Show or hide the header's action buttons after an outer-tab switch.

    ``val`` is the change notification delivered by ``outer_tab.observe``;
    its ``"new"`` entry is the newly selected tab index. Tab index 4 gets no
    controls: the data link is reset to an empty hidden anchor and both the
    download and hide buttons are hidden. Any other tab shows both buttons.
    """
    no_controls = val["new"] == 4
    if no_controls:
        download_data_button.value = download_link.format(
            visibility="hidden", payload="", filename=""
        )
    shown_or_hidden = "hidden" if no_controls else "visible"
    download_button.layout.visibility = shown_or_hidden
    hide_button.layout.visibility = shown_or_hidden


# Both handlers react to tab switches. hide_show_top_buttons is registered
# second, so (NOTE(review): assuming observers fire in registration order) it
# has the final say on the download link's visibility for tab index 4.
outer_tab.observe(lambda _change: update_link_download(), names="selected_index")
outer_tab.observe(hide_show_top_buttons, names="selected_index")


In [184]:
import IPython.display

# Markup (HTML despite the variable name) for the round fullscreen toggle in
# the top-left header cell. It loads the Material Icons font and calls the
# JS function ``fullscreen()`` defined by the dashboard's final cell.
js = """
<head>
<link rel="stylesheet" href="https://fonts.googleapis.com/icon?family=Material+Icons"> 
</head>
<body>
<style>
.round-button {
    display:block;
    width:25px;
    height:30px;
    line-height:15px;
    border: 2px solid #f5f5f5;
    border-radius: 50%;
    color: #ff8b0e;
    text-align:center;
    text-decoration:none;
    background-color: #606060;
    box-shadow: 0 0 3px gray;
    font-size:14px;
    font-weight:bold;
    margin: 0px 0px 3px 0px;
}
.round-button:hover {
    background: white;
}
</style>
<button class="round-button" onclick="fullscreen();"><i class="material-icons">fullscreen</i></button>
"""


In [185]:
# Header row 2: a 1x3 grid holding the fullscreen toggle (left cell) and the
# hide / download / download-data controls (right cell); the middle cell is
# left empty.
grid_header_1 = GridspecLayout(1, 3)
grid_header_1.layout.width = "100%"
grid_header_1.layout.height = "auto"
grid_header_1.layout.margin = "0px 0px 0px 0px"
# Right-aligned strip of header action buttons.
header_controls = HBox(
    [hide_button, download_button, download_data_button],
    layout=Layout(
        align_items="stretch",
        justify_content="flex-end",
        margin="0px 0px 0px 0px",
        overflow="visible",
    ),
)
grid_header_1[0, 0] = HTML(js)
grid_header_1[0, 2] = header_controls
# Header row 1: the centred dashboard title and subtitle.
header_1 = HBox(
    [
        HTML(
            value="<h1 style='color: #ff8b0e; padding:0px; margin:0px; text-align:center; justify-content:center; font-weight: bold; font-size: 34px'>(BETA version) Bloomberg COVID-19 Dashboard</h1><h3 style='color: #ff8b0e; padding:0px; margin:4px 0px 0px 0px; text-align:center; font-size: 18px'>Bloomberg Quant Research</h3>",
            layout=Layout(width="800px", overflow="visible"),
        )
    ],
    layout=Layout(width="100%", height="auto", justify_content="center"),
)
# NOTE(review): these grid overrides are applied after the cell assignments
# above, presumably because assigning cells can make GridspecLayout recompute
# its default grid template. Keep this ordering.
grid_header_1.layout.align_items = "stretch"
# 50px column for the round button, flexible middle, 450px for the controls.
grid_header_1.layout.grid_template_columns = "50px auto 450px"
grid_header_1.layout.overflow = "visible"
# Stack the title row above the button row.
grid_header = VBox([header_1, grid_header_1], layout=Layout(overflow="visible"))


In [186]:
# Extra CSS appended to the notebook-wide stylesheet. It keeps bqplot legend
# entries fully opaque and styles the outer tab headers, slider handles, the
# .hb/.hbb button classes, scrollbars and a few ipyleaflet controls.
global_style = global_style + """
<style>
    .bqplot > svg .g_legend text.legendtext{
        opacity: 1 !important;
    }
    .bqplot > svg .g_legend path.line{
        opacity: 1 !important;
    }
    .bqplot > svg .g_legend path.dot{
        opacity: 1 !important;
    }
    .outer-tab-class li div{
          color: #ff8b0e;
          text-align: center;
          margin: 0px;
          padding: 0px;
          font-size: 14px;
          font-weight: bold;
    }
    .widget-slider .ui-slider .ui-slider-handle{
          background-color: #ff8b0e;
    }
    .hb {
        font-weight: bold;
        font-size: 18px;
        border-radius: 8px;
    }
    .hbb {
        font-weight: bold;
        font-size: 18px;
        border-radius: 20px;
        overflow: visible;
        position: relative;
    } 
    ::-webkit-scrollbar {
      width: 12px;
      border: 5px solid #606060;
      border-radius: 20px;
    }

    ::-webkit-scrollbar-thumb {
      background-color: #ff8b0e;
      background-clip: padding-box;
      border: 3px solid #606060;
      border-radius: 20px;
    }

    ::-webkit-scrollbar-track {
      background-color: #606060;
      border-radius: 20px;
    }
    .leaflet-control-legend.leaflet-control p {
        margin: 0px;
    }
    svg.leaflet-zoom-animated{
        background-color: black;
    }
    
    .leaflet-control-fullscreen-button a{
        background-color: white;
    }
    
</style>
"""
# Invisible carrier widget: it exists only to inject the <style> block above
# into the page, so it takes no space in the layout.
styling = HTML(
    value=global_style,
    layout=Layout(display="none", visibility="hidden"),
)


In [187]:
# Data-source attributions, shown as a single 30px scrollable footer strip.
data_source = """<html><body><p style='color: #ff8b0e; padding:0; margin:0'> Data Sources (Scroll for more): - JHU CSSE COVID-19 Data, licensed by Johns Hopkins University under the <a href='https://creativecommons.org/licenses/by/4.0/' target="_blank" style='color:#1E90FF'>Creative Commons Attribution 4.0 International CC BY 4.0</a>. <a href='https://github.com/CSSEGISandData/COVID-19' target="_blank" style='color:#1E90FF'> https://github.com/CSSEGISandData/COVID-19 </a> 
             <br>- COVID Tracking Data, licensed by Covid Tracking Project under the Apache License 2.0 <a href='https://github.com/COVID19Tracking/covid-tracking-data' target="_blank" style='color:#1E90FF'> https://github.com/COVID19Tracking/covid-tracking-data </a>
             <br>- COVID-19-Data/Testing and COVID-19-Data/Vaccinations, licensed by OWID under the <a href='https://creativecommons.org/licenses/by/4.0/' target="_blank" style='color:#1E90FF'>Creative Commons Attribution 4.0 International CC BY 4.0</a>. <a href='https://github.com/owid/covid-19-data' target="_blank" style='color:#1E90FF'> https://github.com/owid/covid-19-data </a>
             <br>- govex/COVID-19, licensed by JHU under the MIT License <a href='https://github.com/govex/COVID-19' target="_blank" style='color:#1E90FF'> https://github.com/govex/COVID-19 </a></p></body></html>"""
# NOTE(review): not referenced after this point in the notebook; kept in case
# a callback defined earlier uses it.
data_sources_tab = HTML(data_source, layout=Layout(width='auto'))
grid_footer = GridspecLayout(1, 1, layout=Layout(width="100%", align_items="center"))
grid_footer[0, 0] = HBox(
    [
        HTML(
            data_source,
            layout=Layout(
                width="100%",
                height="30px",
                max_height="30px",
                min_height="30px",
                overflow="auto",
            ),
        )
    ],
    layout=Layout(width="auto", height="auto", overflow="auto"),
)
outer_tab.layout.height = "calc(100vh - 205px)"
outer_tab.add_class("outer-tab-class")
hide_button.add_class("hbb")

# Page layout: header, tab area, footer, and the zero-height <style> carrier.
# The grid rows are overridden only after the cells are assigned (same
# ordering as the header grid).
Dashboard = GridspecLayout(4, 1)
Dashboard[0, 0] = grid_header
Dashboard[1, 0] = outer_tab
Dashboard[2, 0] = grid_footer
Dashboard[3, 0] = styling
Dashboard.layout.width = "100%"
Dashboard.layout.height = "99%"
Dashboard.layout.overflow = "hidden"
Dashboard.layout.grid_template_rows = "110px auto 35px 0px"
IPython.display.display(Dashboard)

# Fullscreen toggle invoked by the header button's inline onclick="fullscreen();".
# - It checks the browser's real fullscreen state instead of a click counter,
#   so leaving fullscreen with Esc no longer desynchronises the button (the
#   old counter made the next click a no-op).
# - It is attached to ``window`` because the frontend may evaluate Javascript
#   outputs inside a function scope (NOTE(review): confirm for the target
#   frontend), and inline onclick handlers only see globals.
IPython.display.display(
    IPython.display.Javascript(
        """var elem = document.documentElement;

function isFullscreen() {
  return !!(
    document.fullscreenElement ||
    document.mozFullScreenElement ||
    document.webkitFullscreenElement ||
    document.msFullscreenElement
  );
}

function openFullscreen() {
  if (elem.requestFullscreen) {
    elem.requestFullscreen();
  } else if (elem.mozRequestFullScreen) { /* Firefox */
    elem.mozRequestFullScreen();
  } else if (elem.webkitRequestFullscreen) { /* Chrome, Safari & Opera */
    elem.webkitRequestFullscreen();
  } else if (elem.msRequestFullscreen) { /* IE/Edge */
    elem.msRequestFullscreen();
  }
}

function closeFullscreen() {
  if (document.exitFullscreen) {
    document.exitFullscreen();
  } else if (document.mozCancelFullScreen) {
    document.mozCancelFullScreen();
  } else if (document.webkitExitFullscreen) {
    document.webkitExitFullscreen();
  } else if (document.msExitFullscreen) {
    document.msExitFullscreen();
  }
}

window.fullscreen = function () {
  if (isFullscreen()) {
    closeFullscreen();
  } else {
    openFullscreen();
  }
};"""
    )
)


GridspecLayout(children=(VBox(children=(HBox(children=(HTML(value="<h1 style='color: #ff8b0e; padding:0px; mar…

<IPython.core.display.Javascript object>