In [None]:
# Copyright 2020 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pandas as pd
import time
import math
from ipywidgets import *
from bqplot import *
import datetime
from data_processor import DataProcessor
import bisect
from ipywidgets import Image as im
from toggle_buttons import Toggle_Buttons
from pretty_breaks import PrettyBreaks
import itertools
from collections import Counter
import base64

import warnings

warnings.filterwarnings("ignore", category=RuntimeWarning)

# import IPython.core.display as ipd
# ipd.display(ipd.HTML("<style>.container { width:100% !important; }</style>"))


In [None]:
# PrettyBreaks helper initialised from the world table in ../data/WORLD/area_lat.csv.
pretty_breaks = PrettyBreaks(
    "../data/WORLD/area_lat.csv"
)

In [None]:
class Help_Page:
    """Paginated user guide: one screenshot with an HTML caption per page.

    The page is a 5x3 ``GridspecLayout`` whose middle column holds, top to
    bottom: a "Go to application" button, the title, the caption box, the
    screenshot, and the Previous/Next navigation buttons.

    Parameters
    ----------
    images_url : sequence of str
        Paths of the PNG screenshots; every file is read eagerly.
    images_description : sequence of str
        HTML captions, one per screenshot, in the same order.

    Raises
    ------
    ValueError
        If no screenshot is given or the two sequences differ in length.
    """

    def __init__(self, images_url, images_description):
        # Fail here rather than with an IndexError deep inside a click callback.
        if len(images_url) == 0:
            raise ValueError("Help_Page needs at least one image")
        if len(images_url) != len(images_description):
            raise ValueError(
                "images_url and images_description must have the same length "
                "({} != {})".format(len(images_url), len(images_description))
            )

        help_page = GridspecLayout(
            5, 3, layout=Layout(width="100%", height="100%", justify_items="center")
        )
        self.next_button = Button(
            description="Next",
            button_style="warning",
            layout=Layout(height="50px", margin="0px 0px 0px 10px"),
        )
        self.previous_button = Button(
            description="Previous",
            button_style="warning",
            layout=Layout(
                visibility="hidden", height="50px", margin="0px 10px 0px 0px"
            ),
        )
        # No handler is attached here; presumably the caller wires it up — TODO confirm.
        self.skip_button = Button(
            description="Go to application",
            button_style="success",
            layout=Layout(height="50px", margin="0px 0px 0px 0px"),
        )
        # Read every screenshot up front; each file handle is closed right away.
        self.images = [self._read_file(path) for path in images_url]
        self.images_description = images_description
        self.index_img = 0
        self.img_container = Box(
            children=[self._image_widget(self.index_img)],
            layout=Layout(
                width="auto",
                height="auto",
                overflow_x="hidden",
                margin="10px 0px 10px 0px",
            ),
        )

        help_page[3, 1] = self.img_container
        self.action_container = HBox(
            [self.previous_button, self.next_button],
            layout=Layout(
                justify_content="space-between", align_items="center", width="60%"
            ),
        )
        help_page[4, 1] = self.action_container

        self.header = HTML(
            self.images_description[self.index_img],
            layout=Layout(margin="5px 15px 5px 15px"),
        )

        self.title = HTML(
            "<h1 style='text-align: center; color: #ff8b0e'> User Guide </h1>"
        )

        help_page[0, 1] = VBox([self.skip_button], layout=Layout(align_items="center"))
        help_page[1, 1] = VBox([self.title], layout=Layout(align_items="center"))
        help_page[2, 1] = Box(
            [self.header],
            layout=Layout(
                width="75%",
                height="85px",
                border="solid #ff8b0e",
                overflow="auto",
                justify_content="center",
            ),
        )
        help_page.layout.grid_template_columns = "10% 80% 10%"
        help_page.layout.grid_template_rows = "50px 70px 90px auto 50px"
        # Hides "Next" when there is only one page.
        self._sync_buttons()
        self.next_button.on_click(self.next_click)
        self.previous_button.on_click(self.prev_click)
        self.help_page = help_page

    @staticmethod
    def _read_file(path):
        """Return the raw bytes stored at ``path``, closing the file afterwards."""
        with open(path, "rb") as handle:
            return handle.read()

    def _image_widget(self, index):
        """Build the PNG widget showing screenshot number ``index``."""
        return im(
            value=self.images[index],
            format="png",
            # ``margin`` is not an ``Image`` trait, so passing it directly never
            # reached the widget; it has to live in the Layout.
            layout=Layout(object_fit="scale-down", margin="0 0 0 0"),
        )

    def _sync_buttons(self):
        """Show Previous/Next only when a page exists in that direction."""
        last = len(self.images) - 1
        self.previous_button.layout.visibility = (
            "visible" if self.index_img > 0 else "hidden"
        )
        self.next_button.layout.visibility = (
            "visible" if self.index_img < last else "hidden"
        )

    def _show(self, index):
        """Display page ``index``: caption, screenshot and navigation buttons."""
        self.index_img = index
        self.header.value = self.images_description[index]
        self.img_container.children = [self._image_widget(index)]
        self._sync_buttons()

    def next_click(self, *args):
        """Button callback: go to the next page (ignored on the last page)."""
        if self.index_img < len(self.images) - 1:
            self._show(self.index_img + 1)

    def prev_click(self, *args):
        """Button callback: go to the previous page (ignored on the first page)."""
        if self.index_img > 0:
            self._show(self.index_img - 1)

    def get_page(self):
        """Return the assembled ``GridspecLayout`` widget."""
        return self.help_page


In [None]:
def help_button(tooltip):
    """Return a 30x30 green "info" icon button that shows ``tooltip`` on hover."""
    square_layout = Layout(width="30px", height="30px", margin="0px 0px 0px 0px")
    return Button(
        tooltip=tooltip,
        icon="info",
        button_style="success",
        layout=square_layout,
    )


In [None]:
# User-guide pages: screenshot paths and their captions, paired by position.
FOLDER_IMG = "../screenshots/user_guide/"
image_names = [
    "Tab1_height.PNG",
    "Tab1_tabs.PNG",
    "Tab1_download.PNG",
    "Tab1_typepicker.PNG",
    "Tab1_date_sel.PNG",
    "Tab1_maps.PNG",
    "Tab1_hover.PNG",
    "Tab1_legend.PNG",
    "Tab1_table.PNG",
    "Tab2_controls.PNG",
    "Tab2_2.PNG",
    "Tab2_3.PNG",
    "Tab2_4.PNG",
    "Tab2_5_1.PNG",
    "Tab2_5.PNG",
    "Tab3_1.PNG",
    "Tab3_2.PNG",
    "Tab4_1.PNG",
    "Tab4_2.PNG",
    "Tab4_3.PNG",
]
# `name` (not `im`) so the comprehension does not visually shadow the Image alias.
images = [FOLDER_IMG + name for name in image_names]
images_description = [
    "Use the top right slider to modify the dashboard's height to fit your screen",
    "Use the top left tabs to use different applications of the dashboard (Infection Maps/Rebased Graph/Heat Map/Custom Graph/User Guide)",
    "You can download the figure you created using the app and the corresponding data by using the top left buttons: 'Download Figure' and 'Download Data'<br/> For a larger figure click on 'Hide Controls' before.",
    "Each tab has a type picker which is a block of different sets of buttons (here Data, Type, Norm and Scale + Moving average). This type picker allows you to choose a unique combination of data and display it in the corresponding figure in the tab.",
    "Select a date for the map by using the 'Date' slider. Click Play to animate.",
    "Use maps buttons to display one of the 3 available maps (World/US States/US Counties)",
    "By hovering over a country/state/county in the map a data summary table will be displayed. By clicking on a country/state/county you can change the right-hand side graph",
    "On the right-hand side graph you can select a sub-period by using the bottom range slider and select only one of the 5 available curves by clicking on the figure legend",
    "Navigate between table and graph view using the tabs on the right-hand side",
    "In tab 2 (Rebased graph) you have 3 blocks of controls:<br/> &emsp;- Type picker like in tab 1 (red frame)<br/> &emsp;- Threshold buttons to choose which condition to apply to get the initial point (here 1000 cases) (green frame)<br/> &emsp;- Country picker in a list or map format (blue frame). To add/remove a country press Ctrl(\u2318) and click",
    'You can remove the threshold rule by clicking on the "Calendar Time" Checkbox. Select a sub-period of time by using the bottom left range slider',
    "Use predefined selections to pick countries from preset lists",
    "Use predefined scenarios for examples",
    "By clicking on the country picker buttons you can select countries/states by clicking on a world or US states map",
    'Click on "Collapse" to go back to the previous view',
    "Tab 3 (Heat Map) has a type picker and a country picker. It works the same way as the previous tabs. Each row is assigned to a country/state. Time runs horizontally.",
    '"Normalize" button divides each row by its maximum. It allows to detect the maximum value of each entity. Without normalizing the values of all countries/states are compared, ranked and coloured',
    "Tab 4 (Custom Graph) allows you to select a data type combination for both x and y axis.",
    'By clicking on "Last data point" checkbox you can display only the most recent value of each selected entity',
    "You can pan by clicking and dragging. You can zoom by using your mouse wheel",
]
# Pages are looked up in both lists by the same index. A missing caption would
# only surface as an IndexError while paging, and an extra one would be silently
# ignored, so check the pairing here.
assert len(images) == len(images_description), "every screenshot needs exactly one caption"
# Wrap each caption in the bold, white, justified paragraph used by the guide header.
style = '<p style="text-align: justify; color: white; font-size: 15px; font-weight: bold; padding: 0px; margin: 2px;">'
descriptions = [style + d + "</p>" for d in images_description]


In [None]:
# Build the user-guide widget from the screenshots and captions defined above.
hp = Help_Page(images_url=images, images_description=descriptions)
help_page = hp.get_page()

In [None]:
# Top-level tab container; only the user-guide page is attached at this point.
outer_tab = Tab(
    children=[help_page],
    layout=Layout(height="740px"),
    _titles={0: "User guide"},
)

In [None]:
# Collect World Data
# Sources for the World dataset: JHU CSSE global time series, local
# population/map-code tables, and OWID cumulative test counts.
BASE_URL = (
    "https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/"
    "csse_covid_19_data/csse_covid_19_time_series/"
)
FOLDER_WORLD = "../data/WORLD/"
URL_TESTS = (
    "https://raw.githubusercontent.com/owid/covid-19-data/master/"
    "public/data/testing/covid-testing-all-observations.csv"
)
# String format (YYYY-MM-DD) of the date indexes built by the collectors below.
dateformat = "%Y-%m-%d"


def collect_World_data(BASE_URL, FOLDER_WORLD, URL_TESTS):
    """Download JHU global time series plus OWID test counts and build the World dataset.

    Parameters
    ----------
    BASE_URL : str
        Directory URL holding the ``time_series_covid19_*_global.csv`` files.
    FOLDER_WORLD : str
        Local folder containing ``Population.csv`` and ``world_map_codes.csv``.
    URL_TESTS : str
        URL of the OWID ``covid-testing-all-observations.csv`` file.

    Returns
    -------
    tuple
        ``(datasets_World, STDT, ENDT, countries_to_codes, codes_to_countries)``
        where ``STDT``/``ENDT`` are the first/last JHU dates as ``dateformat``
        strings.

    Notes
    -----
    Reads the module-level ``dateformat`` and updates the module-level
    ``progress_bar`` / ``current_action`` widgets, which must exist beforehand.
    """
    url_cases = BASE_URL + "time_series_covid19_confirmed_global.csv"
    url_deaths = BASE_URL + "time_series_covid19_deaths_global.csv"
    url_rec = BASE_URL + "time_series_covid19_recovered_global.csv"
    dataframes = {
        "Cases": pd.read_csv(url_cases, header=0),
        "Deaths": pd.read_csv(url_deaths, header=0),
        "Recovered": pd.read_csv(url_rec, header=0),
    }
    for df in dataframes.values():
        # Promote Hong Kong from a province row to its own entity.
        df.loc[df["Province/State"] == "Hong Kong", "Country/Region"] = "Hong Kong"
        df.set_index("Country/Region", inplace=True)
    pops = pd.read_csv(FOLDER_WORLD + "Population.csv", header=0).dropna()
    pops.set_index("Country Code", inplace=True)
    country_codes = pd.read_csv(
        FOLDER_WORLD + "world_map_codes.csv", header=0, index_col=0
    )
    countries_to_codes = country_codes["ISOA3"].to_dict()
    ID_to_codes = dict(zip(country_codes.ISON3, country_codes.ISOA3))

    new_names = {
        "Antigua and Barbuda": "Antigua & Barbuda",
        "Bosnia and Herzegovina": "Bosnia",
        "Cabo Verde": "Cape Verde",
        "Congo (Brazzaville)": "Congo - Brazzaville",
        "Congo (Kinshasa)": "Congo - Kinshasa",
        "Eswatini": "Swaziland",
        "Holy See": "Vatican City",
        "Korea, South": "South Korea",
        "North Macedonia": "Macedonia",
        "Saint Lucia": "St. Lucia",
        "Saint Vincent and the Grenadines": "St. Vincent & Grenadines",
        "Trinidad and Tobago": "Trinidad & Tobago",
        "US": "United States",
        "Saint Kitts and Nevis": "St. Kitts & Nevis",
        "Burma": "Myanmar",
        "Taiwan*": "Taiwan",
    }
    # Map JHU country names onto the names used in world_map_codes.csv.
    for df in dataframes.values():
        df.rename(index=new_names, inplace=True)

    # Assign dummy codes to a few entities missing from the code table.
    countries_to_codes["World"] = "WLD"
    countries_to_codes["Diamond Princess"] = "DPS"
    countries_to_codes["Taiwan"] = "TWN"
    countries_to_codes["West Bank and Gaza"] = "PSE"
    countries_to_codes["Kosovo"] = "XKX"
    countries_to_codes["MS Zaandam"] = "MSZ"
    ID_to_codes[158] = "TWN"
    ID_to_codes[-2] = "XKX"
    codes_to_countries = {v: k for k, v in countries_to_codes.items()}
    codes_to_ID = {v: k for k, v in ID_to_codes.items()}

    for df in dataframes.values():
        # Entities without a code get None; groupby() below drops those rows.
        df["code"] = [countries_to_codes.get(name) for name in df.index.values]

    old_dateformat = "%m/%d/%y"
    # Columns are now Province/State, Lat, Long, <dates...>, code.
    new_index = [
        datetime.datetime.strptime(d, old_dateformat).strftime(dateformat)
        for d in dataframes["Cases"].columns.values[3:-1]
    ]
    # Sum provinces per country code and turn the result into dates x codes.
    for k, df in dataframes.items():
        df = df.iloc[:, 3:].groupby(["code"]).sum().transpose().reset_index()
        df["Date"] = new_index
        dataframes[k] = df.set_index("Date", drop=True).drop("index", axis=1)

    progress_bar.value = 0.3
    current_action.value = "Collecting country level Test data..."
    df_world_tests = pd.read_csv(
        URL_TESTS,
        index_col=0,
        header=0,
        usecols=["ISO code", "Date", "Cumulative total"],
    )
    df_world_tests = (
        pd.pivot_table(df_world_tests, index="Date", columns="ISO code")
        .loc[:, "Cumulative total"]
        .ffill()
        .fillna(0)
    )
    STDT = dataframes["Cases"].index[0]
    ENDT = dataframes["Cases"].index[-1]

    # Align tests onto exactly the JHU dates and country codes. Dates or codes
    # without a report become 0; cummax() below then carries the last reported
    # total forward. This replaces tail-only padding (which misaligned rows
    # when OWID had gaps or started late) and DataFrame.append (removed in
    # pandas 2.0).
    dataframes["Tests"] = df_world_tests.reindex(
        index=dataframes["Cases"].index,
        columns=dataframes["Cases"].columns,
        fill_value=0,
    )

    # cummax() makes every cumulative series non-decreasing.
    for c in list(dataframes.keys()):
        dataframes[c] = dataframes[c].cummax()

    dataframes["Active Cases"] = (
        dataframes["Cases"]
        - dataframes["Deaths"].fillna(0)
        - dataframes["Recovered"].fillna(0)
    )
    datasets_World = DataProcessor(
        dataframes, pops, "World", codes_to_ID, add_world_data=True
    )  # World data processor
    progress_bar.value = 0.5
    current_action.value = (
        "Collecting US state level Cases/Deaths/Recovered/Test data..."
    )
    return datasets_World, STDT, ENDT, countries_to_codes, codes_to_countries


In [None]:
# US state / territory full names -> two-letter postal codes (56 entries),
# built from (name, code) pairs in alphabetical order.
states_to_codes = dict(
    [
        ("Alabama", "AL"),
        ("Alaska", "AK"),
        ("American Samoa", "AS"),
        ("Arizona", "AZ"),
        ("Arkansas", "AR"),
        ("California", "CA"),
        ("Colorado", "CO"),
        ("Connecticut", "CT"),
        ("Delaware", "DE"),
        ("District of Columbia", "DC"),
        ("Florida", "FL"),
        ("Georgia", "GA"),
        ("Guam", "GU"),
        ("Hawaii", "HI"),
        ("Idaho", "ID"),
        ("Illinois", "IL"),
        ("Indiana", "IN"),
        ("Iowa", "IA"),
        ("Kansas", "KS"),
        ("Kentucky", "KY"),
        ("Louisiana", "LA"),
        ("Maine", "ME"),
        ("Maryland", "MD"),
        ("Massachusetts", "MA"),
        ("Michigan", "MI"),
        ("Minnesota", "MN"),
        ("Mississippi", "MS"),
        ("Missouri", "MO"),
        ("Montana", "MT"),
        ("Nebraska", "NE"),
        ("Nevada", "NV"),
        ("New Hampshire", "NH"),
        ("New Jersey", "NJ"),
        ("New Mexico", "NM"),
        ("New York", "NY"),
        ("North Carolina", "NC"),
        ("North Dakota", "ND"),
        ("Northern Mariana Islands", "MP"),
        ("Ohio", "OH"),
        ("Oklahoma", "OK"),
        ("Oregon", "OR"),
        ("Pennsylvania", "PA"),
        ("Puerto Rico", "PR"),
        ("Rhode Island", "RI"),
        ("South Carolina", "SC"),
        ("South Dakota", "SD"),
        ("Tennessee", "TN"),
        ("Texas", "TX"),
        ("Utah", "UT"),
        ("Vermont", "VT"),
        ("Virgin Islands", "VI"),
        ("Virginia", "VA"),
        ("Washington", "WA"),
        ("West Virginia", "WV"),
        ("Wisconsin", "WI"),
        ("Wyoming", "WY"),
    ]
)

# Reverse lookup: postal code -> full name (codes are unique).
codes_to_states = dict((code, state) for state, code in states_to_codes.items())


In [None]:
# Sources for the US datasets: local map codes, JHU CSSE US time series and
# per-state recovered/test counts from the COVID Tracking Project API
# (which superseded the older covid-tracking-data states_daily_4pm_et.csv).
FOLDER_US = "../data/USA/"
BASE_URL_US = (
    "https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/"
    "csse_covid_19_data/csse_covid_19_time_series/"
)
URL_TESTS_US = "https://api.covidtracking.com/v1/" "states/daily.csv"


def _us_state_totals(df_raw):
    """Sum a JHU US time-series file per state and transpose it.

    Returns a frame with one column per ``Province_State``. Its rows are the
    numeric columns of ``df_raw`` minus the identifier ones: the date columns,
    plus ``Population`` when the file has it. ``numeric_only=True`` keeps text
    columns out of the sum, matching what pandas < 2.0 did implicitly.
    """
    return (
        df_raw.groupby("Province_State")
        .sum(numeric_only=True)
        .transpose()
        .drop(["UID", "code3", "FIPS", "Lat", "Long_"])
    )


def _tracking_by_state(df_tracking, column, dates, all_states):
    """Pivot one COVID Tracking Project column into a ``dates`` x states frame.

    Rows are matched on the actual report date (the ``date`` index level,
    assumed YYYYMMDD as in the v1 CSV — TODO confirm) rather than by position.
    Dates without a report and missing values become 0. Columns are the
    reported states (renamed to full names), followed by any of
    ``all_states`` that were absent.
    """
    by_state = (
        pd.pivot_table(df_tracking, index="date", columns="state")
        .loc[:, column]
        .rename(columns=codes_to_states)
        .fillna(0)
    )
    by_state.index = pd.to_datetime(by_state.index.astype(str)).strftime(dateformat)
    missing_states = [s for s in all_states if s not in by_state.columns]
    return by_state.reindex(
        index=pd.Index(dates, name="Date"),
        columns=list(by_state.columns) + missing_states,
        fill_value=0,
    )


def _us_county_table(df_raw):
    """Reduce a JHU US time-series file to one row per named county.

    Drops identifier columns, rows with any missing value and "Unassigned"
    pseudo-counties, renames ``Admin2`` to ``County`` and indexes rows by
    ``"<County>_<Province_State>"``.
    """
    counties = df_raw.drop(
        columns=[
            "UID",
            "iso2",
            "iso3",
            "code3",
            "Country_Region",
            "Lat",
            "Long_",
            "Combined_Key",
        ]
    ).dropna()
    counties = counties[counties["Admin2"] != "Unassigned"].rename(
        columns={"Admin2": "County"}
    )
    counties["County_name"] = counties["County"] + "_" + counties["Province_State"]
    return counties.set_index("County_name")


def _dates_by_county(counties, dates):
    """Transpose a county table so rows are labelled by ``dates`` and columns are counties."""
    by_date = counties.transpose()
    by_date.index = pd.Index(dates, name="Date")
    return by_date


def collect_US_data(BASE_URL, FOLDER_US, URL_TESTS_US):
    """Build the US-States and US-Counties datasets.

    Cases and deaths come from the JHU CSSE US files under ``BASE_URL``.
    Per-state recovered/test counts come from ``URL_TESTS_US``. Every frame is
    aligned on the JHU US dates, so all frames of a dataset share one index.

    Parameters
    ----------
    BASE_URL : str
        Directory URL holding the ``time_series_covid19_*_US.csv`` files.
    FOLDER_US : str
        Local folder containing ``USStatesMap_codes.csv``.
    URL_TESTS_US : str
        COVID Tracking Project per-state daily CSV.

    Returns
    -------
    tuple
        ``(datasets_US_States, states_to_ID, datasets_US_counties)``, where
        ``states_to_ID`` is the ``ID`` column of ``USStatesMap_codes.csv``
        keyed by its first column.

    Notes
    -----
    Uses the module-level ``dateformat``, ``codes_to_states``,
    ``progress_bar`` and ``current_action``.
    """
    df_cases_US = pd.read_csv(
        BASE_URL + "time_series_covid19_confirmed_US.csv", header=0
    )
    df_deaths_US = pd.read_csv(BASE_URL + "time_series_covid19_deaths_US.csv", header=0)
    old_dateformat = "%m/%d/%y"
    # In the confirmed-cases file the first 11 columns are identifiers; the rest are dates.
    new_index = [
        datetime.datetime.strptime(d, old_dateformat).strftime(dateformat)
        for d in df_cases_US.columns.values[11:]
    ]

    # ---- State level: cases / deaths (JHU) --------------------------------
    df_cases_US_States = _us_state_totals(df_cases_US)
    df_cases_US_States.index = pd.Index(new_index, name="Date")

    df_deaths_US_States = _us_state_totals(df_deaths_US)
    pops_US_States = (
        df_deaths_US_States.loc["Population"]
        .to_frame(name="2018")
        .replace({0: math.nan})
        .dropna()
    )
    df_deaths_US_States = df_deaths_US_States.drop(index="Population")
    df_deaths_US_States.index = pd.Index(new_index, name="Date")

    # ---- State level: recovered / tests (COVID Tracking Project) ----------
    # Aligned by real report date. Before, rows were relabelled positionally
    # and tail-padded with DataFrame.append (removed in pandas 2.0).
    df_rec_tests_US_States = pd.read_csv(
        URL_TESTS_US,
        index_col=[0, 1],
        header=0,
        usecols=["date", "state", "totalTestResults", "recovered"],
    )
    all_states = df_cases_US_States.columns.values
    df_rec_US_States = _tracking_by_state(
        df_rec_tests_US_States, "recovered", new_index, all_states
    )
    df_tests_US_States = _tracking_by_state(
        df_rec_tests_US_States, "totalTestResults", new_index, all_states
    )

    dict_df_US_States = {
        "Cases": df_cases_US_States,
        "Deaths": df_deaths_US_States,
        "Recovered": df_rec_US_States,
        "Tests": df_tests_US_States,
    }
    # cummax() makes every cumulative series non-decreasing; zero-filled days
    # after the last report therefore repeat the last reported total.
    for key in list(dict_df_US_States):
        dict_df_US_States[key] = dict_df_US_States[key].cummax()

    dict_df_US_States["Active Cases"] = (
        dict_df_US_States["Cases"]
        - dict_df_US_States["Deaths"].fillna(0)
        - dict_df_US_States["Recovered"].fillna(0)
    )
    # Named states_to_ID so it does not shadow the module-level states_to_codes.
    states_to_ID = pd.read_csv(
        FOLDER_US + "USStatesMap_codes.csv", index_col=0, header=0
    ).to_dict()["ID"]
    datasets_US_States = DataProcessor(
        dict_df_US_States, pops_US_States, "US States", states_to_ID
    )
    progress_bar.value = 0.7
    current_action.value = "Collecting US county level Cases/Deaths data..."

    # ---- County level: cases / deaths (JHU) -------------------------------
    cases_table = _us_county_table(df_cases_US)
    counties_to_ID = cases_table["FIPS"].astype(int).to_dict()
    df_cases_US_counties = _dates_by_county(
        cases_table.drop(columns=["Province_State", "FIPS", "County"]), new_index
    )

    deaths_table = _us_county_table(df_deaths_US)
    pops_US_counties = (
        deaths_table.loc[:, "Population"]
        .to_frame(name="2018")
        .replace({0: math.nan})
        .dropna()
    )
    df_deaths_US_counties = _dates_by_county(
        deaths_table.drop(columns=["Province_State", "FIPS", "County", "Population"]),
        new_index,
    )

    def _zeros_like_cases():
        # No county-level recovered/test data is collected: all-zero placeholders.
        return pd.DataFrame(
            np.zeros(df_cases_US_counties.shape),
            index=df_cases_US_counties.index.values,
            columns=df_cases_US_counties.columns.values,
        )

    # Apply cummax() before deriving Active Cases, as the World and US-States
    # datasets do, instead of taking a running maximum of the difference.
    cases_cm = df_cases_US_counties.cummax()
    deaths_cm = df_deaths_US_counties.cummax()
    rec_counties = _zeros_like_cases()
    dict_df_US_counties = {
        "Cases": cases_cm,
        "Deaths": deaths_cm,
        "Recovered": rec_counties,
        "Active Cases": cases_cm - deaths_cm.fillna(0) - rec_counties.fillna(0),
        "Tests": _zeros_like_cases(),
    }

    datasets_US_counties = DataProcessor(
        dict_df_US_counties, pops_US_counties, "US Counties", counties_to_ID
    )

    return datasets_US_States, states_to_ID, datasets_US_counties


In [None]:
# Tooltips for the type-picker widgets. In each list, item [0] describes the
# whole group and items [1:] describe the individual options. Item [0] is
# wrapped in single quotes, presumably so it can be dropped into an HTML
# title=... attribute (as done for the Moving average checkbox) — keep them.
buttons_tooltip = {
    "Data": [
        "'Choose one of the following data categories: Cases, Deaths, Recovered, Active Cases or Tests'",
        "Cases",
        "Deaths",
        "Recovered",
        "Active Cases = Cases-Deaths-Recovered",
        "Tests",
    ],
    "Norm": [
        "'Use raw numbers or normalize data by population'",
        "Raw numbers",
        "Divide values by population and multiply by 1M",
    ],
    "Type": [
        "'Choose cumulative number(Total) or today s value minus yesterday\'s (daily change) or percentage change between today\'s and yesterday\'s value (daily % change)'",
        "Cumulative number",
        "today\'s value minus yesterday\'s",
        "percentage change between today\'s and yesterday\'s value",
    ],
    "Scale": [
        "'Use linear or logarithmic scale on the graph'",
        "Linear",
        "Logarithmic",
    ],
    "Moving average": [
        "'Click to smooth data by applying a moving average. You can select the window size after selecting this option.'"
    ],
    "Maps": ["'Select a map'", "World map", "US States map", "US Counties map"],
    "Normalize": [
        "'Divide each row by its maximum. It allows to detect the maximum value of each entity. Without normalizing the values of all countries/states are compared, ranked and coloured'"
    ],
    "Last data point": ["'Display only the most recent value of each selected entity'"],
    "Calendar time": [
        "'Display time series data without any rebasing rule. You can select a period of time using the bottom left range slider.'"
    ],
    "Threshold": [
        "'Threshold value applied to the rebasing data. (Example: plot deaths starting from the day the selected country/state exceeded 1000 cases)'"
    ],
    "Threshold data": [
        "'Data to use for the rebasing rule'",
        "Cases",
        "Deaths",
        "Recovered",
        "Tests",
    ],
    "Threshold norm": [
        "'Normalization to apply for the rebasing rule'",
        "Raw numbers",
        "Divide values by population and multiply by 1M",
    ],
}


In [None]:
# Per-tab help text (plain text, one bullet per line), keyed "Tab 1".."Tab 4".
help_tooltip = dict()

help_tooltip["Tab 1"] = """- Select a unique combination of data by using the type picker (4 sets of buttons on the top left: Data, Norm, Type, Scale + Moving average).
It will color the selected map and change the graph/tables according to your choice.
- You can hover on the map to display a statistic table.
- Click on a country to show its data on the right graph.
- Navigate between table and graph view using the tabs on the right-hand side.
- Zoom/pan on the map's background using your mouse's wheel or click and drag.
- Select one of the 5 curves in the right-hand side graph by clicking on the figure legend. Click on the figure legend again if you want to show all 5 curves.
Display 4 curves(Cases, Deaths, Recovered, Active Cases) by clicking on the figure's background.
- Select a date for the map by using the slider on the right-hand side.
- If you need more information hover on buttons to show a description or go to user guide tab."""

help_tooltip["Tab 2"] = """- Select a unique combination of data by using the type picker (4 sets of buttons on the top left: Data, Norm, Type, Scale + Moving Average).
- Select country/states of interest by using the list picker (Ctrl + click to add/remove a country) or map pickers (single click on a country/state) or predefined selections of countries.
- Display time series data by checking the 'Calendar Time' checkbox. You can select a period of time using the bottom left range slider.
- Choose a rebasing rule by unchecking the 'Calendar Time' checkbox. Select the threshold value applied to the rebasing data by using the 'Threshold slider'.
  Select which data and normalization to use for the rebasing rule.
        Example: plot deaths/1M Population starting from the day the selected country/state exceeded 1000 cases.
        Data = Deaths, Norm = Per million, Type = Total, Threshold = 1000, Threshold Data = Cases, Threshold Norm = Values.
- Use one of the 4 predefined scenarios to display an example.
- If you need more information hover on buttons to show a description or go to user guide tab."""

help_tooltip["Tab 3"] = """- Select a unique combination of data by using the type picker (3 sets of buttons on the top left: Data, Norm, Type + Moving Average).
- Select country/states of interest by using the list picker (Ctrl + click to add/remove a country) or map pickers (single click on a country/state) or predefined selections of countries.
- Divide each row by its maximum. It allows to detect the maximum value of each entity. Without normalizing the values of all countries/states are compared, ranked and coloured.
- Hover on the figure's cells to get the corresponding value.
- If you need more information hover on buttons to show a description or go to the user guide tab."""

help_tooltip["Tab 4"] = """- For each axis, select a unique combination of data by using the type picker (4 sets of buttons on the top left: Data, Norm, Type, Scale).
- Show only the most recent value of each selected entity by checking the 'Last Data Point' checkbox.
- Moving average is applied to both axes.
- Select country/states of interest by using the list picker (Ctrl + click to add/remove a country) or the predefined selections of countries.
- Zoom/pan on the figure using your mouse's wheel or click and drag.
- Use one of the 4 predefined scenarios to display an example.
- If you need more information hover on buttons to show a description or go to user guide tab."""


In [None]:
# Buttons and interactions


# Minimum CSS widths for the tab-1 type-picker buttons and their labels.
min_button_width_data = "100px"
min_button_width_norm = "85px"
min_button_width_type = "100px"
min_button_width_scale = "60px"
min_description_width_1 = "40px"

# Toggle buttons to choose data to color the map
data_buttons = Toggle_Buttons(
    options=["Cases", "Deaths", "Recovered", "Active Cases", "Tests"],
    value="Cases",
    description="Data",
    min_button_width=min_button_width_data,
    min_description_width=min_description_width_1,
    tooltips=buttons_tooltip["Data"][1:],
    description_tooltip=buttons_tooltip["Data"][0],
)

# Toggle buttons to choose normalization of data
norm_buttons = Toggle_Buttons(
    options=["Values", "Per million"],
    value="Values",
    description="Norm",
    min_button_width=min_button_width_norm,
    min_description_width=min_description_width_1,
    tooltips=buttons_tooltip["Norm"][1:],
    description_tooltip=buttons_tooltip["Norm"][0],
)

# Toggle buttons to choose scale of plots and map.
# Per-option tooltips were missing here although every other picker passes
# them; buttons_tooltip["Scale"][1:] holds one entry per option.
scale_buttons = Toggle_Buttons(
    options=["Linear", "Log"],
    value="Linear",
    description="Scale",
    min_button_width=min_button_width_scale,
    min_description_width=min_description_width_1,
    tooltips=buttons_tooltip["Scale"][1:],
    description_tooltip=buttons_tooltip["Scale"][0],
)

# Toggle buttons to choose between cumulative data or daily change
type_buttons = Toggle_Buttons(
    options=["Total", "Daily change", "Daily % change"],
    value="Total",
    description="Type",
    min_button_width=min_button_width_type,
    min_description_width=min_description_width_1,
    tooltips=buttons_tooltip["Type"][1:],
    description_tooltip=buttons_tooltip["Type"][0],
)

# Moving-average controls for tab 1: a checkbox whose label carries the group
# tooltip as an HTML title, and a window-size input that starts hidden.
tab_1_ma_ch = Checkbox(
    value=False,
    description="<tag title={}>Moving Average</tag>".format(
        buttons_tooltip["Moving average"][0]
    ),
    style=dict(description_width="initial"),
    layout=Layout(width="200px", min_width="200px", overflow="hidden"),
)
tab_1_ma_w = BoundedIntText(
    description="Window size (days)",
    value=7,
    min=1,
    max=14,
    style=dict(description_width="initial"),
    layout=Layout(
        width="200px",
        min_width="200px",
        margin="0 0 0 0",
        overflow="auto",
        visibility="hidden",
    ),
)
# Fixed 205x76 px box stacking the two moving-average widgets.
ma_box_tab_1 = VBox(
    children=[tab_1_ma_ch, tab_1_ma_w],
    layout=Layout(
        width="205px",
        min_width="205px",
        max_width="205px",
        height="76px",
        min_height="76px",
        max_height="76px",
        overflow="hidden",
    ),
)

# Fixed 565x138 px stack of the four type-picker rows.
cat_tab_1_buttons = VBox(
    children=[data_buttons, norm_buttons, type_buttons, scale_buttons],
    layout=Layout(
        width="565px",
        min_width="565px",
        max_width="565px",
        height="138px",
        min_height="138px",
        max_height="138px",
        overflow="auto",
    ),
)


In [None]:
# Picker selecting which map tab 1 shows (World / US States / US Counties).
map_buttons = Toggle_Buttons(
    options=["World", "US States", "US Counties"],
    value="World",
    description="Maps",
    style="warning",
    min_button_width=min_button_width_data,
    min_description_width=min_description_width_1,
    description_tooltip=buttons_tooltip["Maps"][0],
    tooltips=buttons_tooltip["Maps"][1:],
)
# Fixed 350x36 px row holding the map picker.
box_map = HBox(
    children=[map_buttons],
    layout=Layout(
        width="350px",
        min_width="350px",
        max_width="350px",
        height="36px",
        min_height="36px",
        max_height="36px",
        overflow="auto",
    ),
)


In [None]:
# 2x2 grid for tab-1 controls: column 0 holds the type-picker stack, column 1
# the moving-average box above the map picker (each spans both rows).
top_right_buttons_tab1 = GridspecLayout(
    n_rows=2,
    n_columns=2,
    layout=Layout(
        width="1000px",
        min_width="1000px",
        max_width="1000px",
        height="138px",
        margin="0 0 0 0",
        overflow="auto",
        align_items="center",
        justify_content="space-between",
        border="",
    ),
)
left_cat_buttons = VBox(
    children=[ma_box_tab_1, box_map],
    layout=Layout(
        height="100%",
        margin="0px 0px 0px 15px",
        align_items="stretch",
        justify_content="space-between",
    ),
)
top_right_buttons_tab1[:, 0] = cat_tab_1_buttons
top_right_buttons_tab1[:, 1] = left_cat_buttons


In [None]:
# Main graph

# Extra CSS for the hover table (currently none).
css_style = ""

# Tooltip widget attached to the main graph's line markers; its value is
# presumably filled from line_tooltip_table by a hover handler — TODO confirm.
stats_line_tooltip_table_values = HTML(value="")

# HTML template for that tooltip: placeholder 0 is the date, then five
# (label, value) placeholder pairs, one coloured row per series.
line_tooltip_table = """
<table id="stats_table">
<tr>
    <td>Date</td>
    <td>{0:}</td>
</tr>
<tr style="color:#1E90FF">
    <td>{1:s}</td>
    <td>{2:}</td>
</tr>
<tr style="color:#D62728">
    <td>{3:s}</td>
    <td>{4:}</td>
</tr>
<tr style="color:#2CA02C">
    <td>{5:s}</td>
    <td>{6:}</td>
</tr>
<tr style="color:#FFBB0E">
    <td>{7:s}</td>
    <td>{8:}</td>
</tr>
<tr style="color:#7F7F7F">
    <td>{9:s}</td>
    <td>{10:}</td>
</tr>
</table>
"""


In [None]:
# Title template (filled with data name and place) and the default axes for
# the main time-series figure.
main_graph_ttl = "COVID-19 {} in {}"
scale_t = DateScale(dateformat=dateformat)
scale_y = LinearScale()
axis_t = Axis(scale=scale_t, grid_lines="none", tick_format="%m/%d", num_ticks=10)
axis_y = Axis(
    scale=scale_y, orientation="vertical", grid_lines="solid", tick_format=","
)

main_graph = Figure(
    title=main_graph_ttl,
    axes=[axis_t, axis_y],
    legend_location="top-left",
    legend_style={"stroke-width": 0},
    animation_duration=1000,
    fig_margin=dict(top=50, bottom=50, left=65, right=20),
    layout=Layout(width="auto"),
)


In [None]:
# Line colour per series in the main graph; `selected` lists the series in
# the same order (a fresh list, independent of the dict).
curve_colors = dict(
    zip(
        ["Cases", "Deaths", "Recovered", "Active Cases", "Tests"],
        ["dodgerblue", "red", "green", "orange", "gray"],
    )
)
selected = list(curve_colors)


def update_main_graph(
    datasets,
    country_name,
    normalization,
    scale,
    data_type,
    ma,
    n,
    date1,
    date2,
    axis_t,
    axis_y,
    figure,
    init=False,
):
    """Rebuild the time-series Lines mark of ``figure`` for one region.

    Parameters
    ----------
    datasets : object
        Provides ``get_ts_plot`` and a ``name`` attribute ("World", "US States"
        or "US Counties") that selects how the title names the region.
    country_name : str
        Region key passed to ``datasets.get_ts_plot``. For the World dataset it
        is a code translated with ``codes_to_countries`` for the title.
    normalization, scale, data_type, ma, n :
        Current Norm / Scale / Type / moving-average control values.
    date1, date2 : str
        Bounds of the visible date window, parsed with ``dateformat``.
    axis_t, axis_y : Axis
        Axes whose scales are replaced.
    figure : Figure
        Figure whose marks, axes and title are updated.
    init : bool
        If True, show every series except the last one in ``selected``.
        Otherwise keep the curves currently visible in ``figure``.

    Raises
    ------
    ValueError
        If ``scale`` is neither "Log" nor "Linear". Nothing is modified in
        that case.
    """
    # Validate the y-scale first so an unknown value cannot leave the figure
    # half-updated.
    if scale == "Log":
        scale_y = LogScale()
    elif scale == "Linear":
        scale_y = LinearScale()
    else:
        raise ValueError("Unsupported scale: {!r}".format(scale))

    if init:
        curves_subset = list(range(len(selected) - 1))
    else:
        # Preserve the user's legend selection on the figure being updated.
        curves_subset = figure.marks[0].curves_subset
    line_col = [curve_colors[c] for c in selected]
    data = [
        datasets.get_ts_plot(
            country_name, data_name, normalization, scale, data_type, ma, n, STDT, ENDT,
        )
        for data_name in selected
    ]
    scale_t = DateScale(
        dateformat=dateformat,
        min=datetime.datetime.strptime(date1, dateformat),
        max=datetime.datetime.strptime(date2, dateformat),
    )
    axis_t.scale = scale_t

    main_mark = Lines(
        scales={"x": scale_t, "y": scale_y},
        display_legend=True,
        labels=selected,
        colors=line_col,
        marker="circle",
        marker_size=30,
        tooltip=stats_line_tooltip_table_values,
        interactions={"hover": "tooltip"},
        curves_subset=curves_subset,
        opacities=[1] * len(selected),
    )
    main_mark.y = data
    axis_y.scale = scale_y
    figure.axes = [axis_t, axis_y]
    # One x point per day from STDT to ENDT inclusive.
    main_mark.x = list(
        np.arange(np.datetime64(STDT), np.datetime64(ENDT) + np.timedelta64(1, "D"),)
    )
    # Fit the y-range to the visible curves inside the selected date window.
    ind1 = np.where(main_mark.x == np.datetime64(date1))[0][0]
    ind2 = np.where(main_mark.x == np.datetime64(date2))[0][0]
    window = np.array(main_mark.y)[curves_subset, ind1 : ind2 + 1]
    main_mark.scales["y"].max = 1.1 * float(np.nanmax(window))
    main_mark.scales["y"].min = 0.9 * float(np.nanmin(window))
    main_mark.on_hover(line_hover)
    main_mark.on_legend_hover(legend_hover)
    main_mark.on_legend_click(legend_click_line)
    main_mark.on_background_click(line_background_click)
    figure.marks = [main_mark]

    def graph_title(norm, scale, data_type, ma, n):
        # Metric description inserted into main_graph_ttl.
        res = ""
        if scale == "Log":
            res += "Log"
        if norm == "Per million":
            res += "/1M population"
        if data_type == "Daily change":
            res += " daily change"
        elif data_type == "Daily % change":
            res += " daily % change"
        if ma:
            res += " ({0:d}days m.a)".format(n)
        return res

    val_name = graph_title(normalization, scale, data_type, ma, n)
    # The " State" / " County" suffixes are parsed back by event().
    if datasets.name == "World":
        figure.title = main_graph_ttl.format(val_name, codes_to_countries[country_name])

    elif datasets.name == "US States":
        figure.title = main_graph_ttl.format(val_name, country_name + " State")

    elif datasets.name == "US Counties":
        figure.title = main_graph_ttl.format(val_name, country_name + " County")

    update_link_download()
    return


In [None]:
def name_val_on_hover(data, norm, data_type, ma, n):
    """Return the display label of a series.

    Example: ``"Cases/1M population daily change (7days m.a)"``.
    """
    parts = [data]
    if norm == "Per million":
        parts.append("/1M population")
    type_suffix = {
        "Daily change": " daily change",
        "Daily % change": " daily % change",
    }.get(data_type)
    if type_suffix is not None:
        parts.append(type_suffix)
    if ma:
        parts.append(" ({0:d}days m.a)".format(n))
    return "".join(parts)


def line_hover(*args):
    """Fill the main-graph tooltip with every series' value at the hovered point.

    Bound with ``Lines.on_hover``. ``args[1]["data"]["sub_index"]`` is the
    position of the hovered marker along x. Raw values ("Values" and not daily
    % change) are shown as integers; everything else uses two decimals. NaNs
    are shown as 0. Hover is best-effort: any error leaves the tooltip as it was.
    """
    try:
        point = args[1]["data"]["sub_index"]
        mark = main_graph.marks[0]
        hovered_date = str(mark.x[point])
        # nan_to_num turns missing values into 0 so they can be formatted.
        raw = np.nan_to_num([mark.y[k][point] for k in range(len(selected))])
        labels = [
            name_val_on_hover(
                series,
                norm_buttons.value,
                type_buttons.value,
                tab_1_ma_ch.value,
                tab_1_ma_w.value,
            )
            for series in selected
        ]
        as_int = (
            norm_buttons.value == "Values" and type_buttons.value != "Daily % change"
        )
        cells = [hovered_date]
        for label, v in zip(labels, raw):
            cells.append(label)
            cells.append("{0:,d}".format(int(v)) if as_int else "{0:,.2f}".format(v))
        stats_line_tooltip_table_values.value = css_style + line_tooltip_table.format(
            *cells
        )
    except Exception:
        # Deliberately ignored: a failed hover must not break the UI.
        pass
    return


def legend_hover(*args):
    """Highlight the hovered legend entry by dimming the other curves to 0.3.

    Full opacity is restored on every curve when the hovered curve is hidden
    (not in curves_subset) or is already the highlighted one.
    """
    n_curves = len(selected)
    dimmed = [0.3] * n_curves
    hovered = args[1]["data"]["index"]
    dimmed[hovered] = 1
    mark = main_graph.marks[0]
    already_highlighted = (
        np.sum(mark.opacities) < n_curves and mark.opacities[hovered] == 1
    )
    if already_highlighted or hovered not in mark.curves_subset:
        mark.opacities = [1] * n_curves
    else:
        mark.opacities = dimmed
    return


def legend_click_line(*args):
    """Toggle curve visibility from a legend click, then refit the y-axis.

    - Clicking a visible curve isolates it (it becomes the only visible curve).
    - Clicking the only visible curve restores all curves in ``selected``.
    - Clicking a hidden curve adds it back.

    The y-range becomes 0.9 x min .. 1.1 x max of the visible curves over the
    window currently chosen in date_axis_selector.
    """
    n_curves = len(selected)
    clicked = args[1]["data"]["index"]
    lo, hi = date_axis_selector.index
    mark = main_graph.marks[0]
    visible = mark.curves_subset
    if clicked not in visible:
        mark.curves_subset = visible + [clicked]
    elif len(visible) == 1:
        mark.curves_subset = list(range(n_curves))
    else:
        mark.curves_subset = [clicked]

    window = np.array(mark.y)[mark.curves_subset, lo : hi + 1]
    mark.scales["y"].max = 1.1 * float(np.nanmax(window))
    mark.scales["y"].min = 0.9 * float(np.nanmin(window))
    return


def line_background_click(*args):
    """Reset the main graph to its default curves and refit the y-axis.

    The default set is indices 0-3 of ``selected`` (every series but Tests).
    The y-range becomes 0.9 x min .. 1.1 x max of those curves over the window
    chosen in date_axis_selector.
    """
    mark = main_graph.marks[0]
    mark.curves_subset = list(range(4))
    lo, hi = date_axis_selector.index
    window = np.array(mark.y)[mark.curves_subset, lo : hi + 1]
    mark.scales["y"].max = 1.1 * float(np.nanmax(window))
    mark.scales["y"].min = 0.9 * float(np.nanmin(window))
    return


def bqplot_legend_curvesubset_bug():
    """Work around a bqplot Lines ``curves_subset`` glitch.

    Temporarily sets curves_subset to a different single-curve list and then
    back to the current value. This presumably forces bqplot to re-render the
    curves/legend; the original note only says "bug in bqplot Lines
    curvesubset method". Does nothing when no curve is visible.
    """
    mark = main_graph.marks[0]
    curves_subset = mark.curves_subset
    if not curves_subset:
        return
    first = curves_subset[0]
    if len(curves_subset) == 1:
        # Pick a neighbouring curve so the temporary value differs from the real one.
        temporary = [first + 1 if first + 1 < len(selected) else first - 1]
    else:
        temporary = [first]
    mark.curves_subset = temporary
    mark.curves_subset = curves_subset
    return


In [None]:
# Maps

def map_title(map_name, data, norm, data_type, date, ma, n):
    """Build the map figure title.

    Example: ``"World COVID-19 Cases Per million daily change (7days m.a) Map, 2020-04-01"``.
    """
    pieces = [map_name, " COVID-19 ", data]
    if norm == "Per million":
        pieces.append(" " + norm)
    pieces.append(
        {"Daily change": " daily change", "Daily % change": " daily % change"}.get(
            data_type, ""
        )
    )
    if ma:
        pieces.append(" ({0:d}days m.a)".format(n))
    pieces.append(" Map, ")
    pieces.append(date)
    return "".join(pieces)


In [None]:
# Map hover table used when the mapped metric is NOT a plain total.
# {0}: region name; {1}/{2}: label and value of the mapped metric;
# {3}-{7}: totals of Cases, Deaths, Recovered, Active Cases, Tests; {8}: population.
# The integer rows use ",d", so they require ints (map_hovering converts with int()).
table_tmpl_no_duplicate = """
<table id="stats_table">
<thead>
<tr>
<th colspan="2">{0:s}</th>
</tr>
</thead>
<tbody>
<tr>
    <td>{1:s}</td>
    <td>{2:,}</td>
</tr>
<tr style="color:#1E90FF">
    <td>Cases</td>
    <td>{3:,d}</td>
</tr>
<tr style="color:#D62728">
    <td>Deaths</td>
    <td>{4:,d}</td>
</tr>
<tr style="color:#2CA02C">
    <td>Recovered</td>
    <td>{5:,d}</td>
</tr>
<tr style="color:#FFBB0E">
    <td>Active Cases</td>
    <td>{6:,d}</td>
</tr>
<tr style="color:#7F7F7F">
    <td>Tests</td>
    <td>{7:,d}</td>
</tr>
<tr>
    <td>Population</td>
    <td>{8:,d}</td>
</tr>
</tbody>
</table>
"""

# Map hover table used when the mapped metric is a plain total (it would
# duplicate one of the rows). {0}: region name; {1}-{5}: totals of Cases,
# Deaths, Recovered, Active Cases, Tests; {6}: population. All counts are
# formatted with ",d" and must be ints.
table_tmpl_duplicate = """
<table id="stats_table">
<thead>
<tr>
<th colspan="2">{0:s}</th>
</tr>
</thead>
<tbody>
<tr style="color:#1E90FF">
    <td>Cases</td>
    <td>{1:,d}</td>
</tr>
<tr style="color:#D62728">
    <td>Deaths</td>
    <td>{2:,d}</td>
</tr>
<tr style="color:#2CA02C">
    <td>Recovered</td>
    <td>{3:,d}</td>
</tr>
<tr style="color:#FFBB0E">
    <td>Active Cases</td>
    <td>{4:,d}</td>
</tr>
<tr style="color:#7F7F7F">
    <td>Tests</td>
    <td>{5:,d}</td>
</tr>
<tr>
    <td>Population</td>
    <td>{6:,d}</td>
</tr>
</tbody>
</table>
"""
# Shared tooltip widget of all maps; map_hovering rewrites its value.
stats_table = HTML()


In [None]:
def current_val_on_hover(data, norm, scale, data_type, ma, n):
    """Label of the mapped metric for the map hover table.

    Returns None for plain totals (``norm == "Values"`` and
    ``data_type == "Total"``); the caller then uses the table without the
    extra metric row. ``scale`` is unused.
    """
    if norm == "Values" and data_type == "Total":
        return None
    label = data
    if norm == "Per million":
        label = label + "/1M population"
    if data_type == "Daily change":
        label = label + " daily change"
    elif data_type == "Daily % change":
        label = label + " daily % change"
    if ma:
        label = label + " ({0:d}days m.a)".format(n)
    return label


In [None]:
# Eight-colour sequential palettes for the choropleth maps, light to dark.
# map_color uses at most n_breaks + 1 (<= 8) colours from the start of a palette.
red_color_scheme = [
    "#FFF0CC",
    "#ffcf76",
    "#ffa600",
    "#ff6e00",
    "#ff4417",
    "#d31522",
    "#8c000e",
    "#560410",
]
green_color_scheme = [
    "#e9f6e4",
    "#d1edc6",
    "#b9e3a8",
    "#a1da8a",
    "#89d06d",
    "#71c74f",
    "#5db53a",
    "#4e9730",
]
# Green -> yellow -> red palette (used for Active Cases).
gyr_color_scheme = [
    "#069740",
    "#9acd32",
    "#e4d31b",
    "#ffa600",
    "#e95f11",
    "#d31522",
    "#8c000e",
    "#560410",
]

# Palette per data series shown on the map.
map_color_scheme = {
    "Cases": red_color_scheme,
    "Deaths": red_color_scheme,
    "Recovered": green_color_scheme,
    "Active Cases": gyr_color_scheme,
    "Tests": green_color_scheme,
}


def map_cuts(arr, cuts):
    """Map each value of ``arr`` to the lower edge of its bin.

    ``cuts`` are ascending bin edges; a value ``a`` maps to the largest cut
    ``<= a``. Values below ``cuts[0]`` are clamped into the first bin instead
    of wrapping around (via index -1) to the last bin. NaNs map to NaN.
    """
    result = []
    for a in arr:
        if np.isnan(a):
            result.append(np.nan)
        else:
            result.append(cuts[max(bisect.bisect_right(cuts, a) - 1, 0)])
    return result


def cuts_to_str(cuts):
    """Format bin edges compactly: 1500 -> "1.5K", 2e6 -> "2M", 0.12345 -> "0.123"."""

    def _format_one(v):
        for size, suffix in ((1e6, "M"), (1000, "K")):
            if abs(v) >= size:
                scaled = v / size
                text = str(int(scaled)) if v % size == 0 else str(scaled)
                return text + suffix
        if abs(v) >= 1 and v % 1 == 0:
            return str(int(v))
        return str(round(v, 3))

    return [_format_one(v) for v in cuts]


def cuts_to_range(cuts, pct=False):
    """Turn formatted bin edges into legend labels.

    Example: ``["0", "10", "-5"]`` -> ``["0 - 10", "10 - (-5)", "(-5)+"]``.
    Negative edges are parenthesised. With ``pct`` every edge gets a "%".
    """
    unit = "%" if pct else ""
    wrapped = ["(" + c + ")" if c[0] == "-" else c for c in cuts]
    labels = [lo + unit + " - " + hi + unit for lo, hi in zip(wrapped, wrapped[1:])]
    labels.append(wrapped[-1] + unit + "+")
    return labels


def map_color(dataset, data, norm, scale, data_type, ma, n, date):
    """Bucket every region's value on ``date`` into a labelled colour range.

    Parameters
    ----------
    dataset : object
        Provides ``get_ts``, ``name_to_ID``, ``get_ID`` and ``name``. Its
        ``name`` selects the map in ``dict_maps`` whose colour scale is updated.
    data, norm, scale, data_type, ma, n :
        Current control values, forwarded to ``dataset.get_ts``.
    date :
        Row label looked up with ``.loc`` on the time series.

    Returns
    -------
    dict
        Map element ID -> range label (e.g. "10 - 20"). Returns
        ``{-999: "None"}`` when no region has data or the breaks cannot be
        computed.

    Side effects: sets ``domain`` (range labels) and ``colors`` on that map's
    OrdinalColorScale.
    """
    values = (
        dataset.get_ts(None, data, norm, scale, data_type, ma, n).loc[date].dropna()
    )
    if data_type == "Total" and data != "Active Cases":
        # Non-positive totals are dropped, so those regions get no colour entry.
        values = values[values > 0]
    # Keep only regions present on the map, collecting their map IDs in order.
    IDs = []
    to_drop = []
    for c in values.index.values:
        if c in dataset.name_to_ID.keys():
            IDs.append(dataset.get_ID(c))
        else:
            to_drop.append(c)
    values = values.drop(to_drop)
    color_scale = dict_maps[dict_ID_to_map[dataset.name]]["scale"]

    def _no_data():
        # Placeholder legend used when nothing can be coloured.
        color_scale.domain = ["None"]
        color_scale.colors = map_color_scheme[data]
        return {-999: "None"}

    n_unique = values.nunique()
    n_breaks = 7  # up to 8 bins, the length of every colour scheme
    if 0 < n_unique < 8:
        # Fewer distinct values than bins: presumably one bin per distinct value.
        n_breaks = n_unique - 1
    elif values.shape[0] == 0:
        return _no_data()
    try:
        weights = [2, 1, 4] if dataset.name == "World" else [2, 5, 0]
        q_round_cuts = pretty_breaks.breaker(values, n_breaks, weights=weights)
        pct = data_type == "Daily % change"
        cuts_str_format = cuts_to_range(cuts_to_str(q_round_cuts), pct)
        str_cuts = dict(zip(q_round_cuts, cuts_str_format))
        colors = map_cuts(values.values, q_round_cuts)
        colors_to_range = [str_cuts[c] if not np.isnan(c) else np.nan for c in colors]
        color = dict(zip(IDs, colors_to_range))
        color_scale.domain = cuts_str_format
        color_scale.colors = map_color_scheme[data][: n_breaks + 1]
    except Exception:
        # Best effort: show the placeholder legend rather than failing the update.
        return _no_data()
    return color


def select_date(obj):
    """Recolour the active map and refresh the tables for a newly chosen date.

    ``obj["new"]`` is used as the date.
    """
    date = obj["new"]
    map_index = dict_ID_to_map[map_buttons.value]
    data = data_buttons.value
    norm = norm_buttons.value
    scale = scale_buttons.value
    data_type = type_buttons.value
    ma = tab_1_ma_ch.value
    n = tab_1_ma_w.value
    map_entry = dict_maps[map_index]
    color = map_color(dict_datasets[map_index], data, norm, scale, data_type, ma, n, date)
    main_map.title = map_title(map_entry["name"], data, norm, data_type, date, ma, n)
    map_entry["map"].color = color
    update_tables(data, norm, scale, data_type, ma, n, date)
    return


def map_click(obj, value):
    """Plot the clicked country/state/county in the main graph.

    Bound with ``on_element_click`` on the maps. ``value["data"]["id"]`` is the
    clicked map element. Any exception (e.g. from ``datasets.get_name``) is
    swallowed so a failed click does not propagate out of the event handler.
    """
    try:
        _id = value["data"]["id"]
        selected_index = dict_ID_to_map[map_buttons.value]
        datasets = dict_datasets[selected_index]
        name = datasets.get_name(_id)
        normalization, scale, data_type = (
            norm_buttons.value,
            scale_buttons.value,
            type_buttons.value,
        )
        ma, n = tab_1_ma_ch.value, tab_1_ma_w.value
        date_sel_index = date_axis_selector.index
        date1 = str(date_axis_selector.options[date_sel_index[0]])
        date2 = str(date_axis_selector.options[date_sel_index[1]])
        update_main_graph(
            datasets,
            name,
            normalization,
            scale,
            data_type,
            ma,
            n,
            date1,
            date2,
            axis_t,
            axis_y,
            main_graph,
        )
        bqplot_legend_curvesubset_bug()
    except Exception:
        pass
    return


def map_background_click(*args):
    """Reset the main graph to the "WLD" series of dict_datasets[0].

    Uses the current controls and the date window of date_axis_selector.
    """
    lo, hi = date_axis_selector.index
    options = date_axis_selector.options
    update_main_graph(
        dict_datasets[0],
        "WLD",
        norm_buttons.value,
        scale_buttons.value,
        type_buttons.value,
        tab_1_ma_ch.value,
        tab_1_ma_w.value,
        str(options[lo]),
        str(options[hi]),
        axis_t,
        axis_y,
        main_graph,
    )
    bqplot_legend_curvesubset_bug()
    return


def map_hovering(obj, value):
    """Fill the map tooltip with statistics for the hovered region.

    The table shows the totals of every series in ``selected`` plus the
    population at ``date_selector.value``. When the current controls do not
    display plain totals, the mapped metric is added as an extra row. An
    unknown region gives an empty tooltip. If the data lookup fails, zeros are
    shown with the population when it is available.
    """
    _id = value["data"]["id"]
    date = date_selector.value
    selected_index = dict_ID_to_map[map_buttons.value]
    datasets = dict_datasets[selected_index]
    try:
        name = datasets.get_name(_id)
        # World IDs resolve to country codes; US names are displayed as-is.
        name_table = codes_to_countries[name] if selected_index == 0 else name
    except Exception:
        stats_table.value = ""
        main_map.tooltip = stats_table
        return
    try:
        norm, scale, data_type = (
            norm_buttons.value,
            scale_buttons.value,
            type_buttons.value,
        )
        ma, n = tab_1_ma_ch.value, tab_1_ma_w.value
        data = data_buttons.value
        current_val = current_val_on_hover(data, norm, scale, data_type, ma, n)
        totals = [
            int(datasets.get_value(name, c, "Values", "Linear", "Total", ma, n, date))
            for c in selected
        ]
        population = int(datasets.get_population(name))
        if current_val is None:
            stats_table.value = css_style + table_tmpl_duplicate.format(
                name_table, *totals, population
            )
        else:
            val = float(
                datasets.get_value(name, data, norm, scale, data_type, ma, n, date)
            )
            if data_type == "Daily change" and norm == "Values":
                val = int(val)
            else:
                val = round(val, 2)
            stats_table.value = css_style + table_tmpl_no_duplicate.format(
                name_table, current_val, val, *totals, population
            )
    except Exception:
        # Presumably missing data (e.g. NaN that int() rejects): show zeros.
        try:
            population = int(datasets.get_population(name))
            stats_table.value = css_style + table_tmpl_duplicate.format(
                name_table, *([0] * len(selected)), population
            )
        except Exception:
            stats_table.value = ""
    main_map.tooltip = stats_table
    return


In [None]:
def event(*args):
    """Refresh the map, main graph and tables after a tab-1 control change.

    ToggleButton events are ignored unless their new value is True. The region
    currently plotted in the main graph is recovered from its title, which
    update_main_graph builds as "... in <region>[ State| County]".
    """
    change = args[0]
    from_toggle = type(change["owner"]).__name__ == "ToggleButton"
    if from_toggle and change["new"] != True:
        return

    date = date_selector.value
    norm = norm_buttons.value
    scale = scale_buttons.value
    data_type = type_buttons.value
    data = data_buttons.value
    ma = tab_1_ma_ch.value
    n = tab_1_ma_w.value
    # The window-size box is only shown while the moving average is enabled.
    tab_1_ma_w.layout.visibility = "visible" if ma else "hidden"

    map_index = dict_ID_to_map[map_buttons.value]
    color = map_color(dict_datasets[map_index], data, norm, scale, data_type, ma, n, date)
    main_map.title = map_title(
        dict_maps[map_index]["name"], data, norm, data_type, date, ma, n
    )
    dict_maps[map_index]["map"].color = color

    lo, hi = date_axis_selector.index
    date1 = str(date_axis_selector.options[lo])
    date2 = str(date_axis_selector.options[hi])

    title = main_graph.title
    region = title.split(" in ")[-1]
    last_word = title.split(" ")[-1]
    if last_word == "State":
        graph_dataset, graph_region = dict_datasets[1], region[: -len(" State")]
    elif last_word == "County":
        graph_dataset, graph_region = dict_datasets[2], region[: -len(" County")]
    else:
        graph_dataset, graph_region = dict_datasets[0], countries_to_codes[region]
    update_main_graph(
        graph_dataset,
        graph_region,
        norm,
        scale,
        data_type,
        ma,
        n,
        date1,
        date2,
        axis_t,
        axis_y,
        main_graph,
    )
    bqplot_legend_curvesubset_bug()
    update_tables(data, norm, scale, data_type, ma, n, date)
    return


In [None]:
# Play button that steps through the map dates. max=1 is a placeholder,
# presumably reset once the real date range is known — TODO confirm.
play = Play(
    value=0,
    min=0,
    max=1,
    step=1,
    description="Press play",
    disabled=False,
    layout=Layout(width="auto", height="auto", min_width="150px"),
)


def update_index_date(*args):
    """Move date_selector to the option at position ``args[0]["new"]``.

    Presumably observed on the Play widget's value.
    """
    position = args[0]["new"]
    date_selector.value = date_selector.options[position]
    return


In [None]:
# Map date slider. The two options are placeholders, presumably replaced with
# the real dates elsewhere — TODO confirm.
date_selector = SelectionSlider(
    options=["2020-01-22", "2020-01-23"],
    description="Date",
    disabled=False,
    value="2020-01-22",
    layout=Layout(width="100%", height="auto", min_width="200px"),
    style={"description_width": "initial"},
    description_tooltip='Select a date',
)

In [None]:
# Date slider and play button on a single fixed-height (40px) row.
date_tab_1_box = HBox(
    [date_selector, play],
    layout=Layout(
        width="100%",
        min_width="200px",
        height="40px",
        min_height="40px",
        max_height="40px",
        overflow="auto",
    ),
)


In [None]:
# Date-window selector for the main graph. Its .index (start, end) is used by
# the graph handlers to slice the curves; options are a placeholder here.
date_axis_selector = SelectionRangeSlider(
    options=["2020-01-22"], index=(0, 0), layout=Layout(width="100%", height="auto"),
)


def update_zoom(*args):
    """Apply a new date window to the main graph and refit its y-axis.

    ``args[0]["new"]`` holds the (start, end) of date_axis_selector. The y-range
    becomes 0.9 x min .. 1.1 x max of the visible curves inside the window
    given by ``date_axis_selector.index``.
    """
    start, end = args[0]["new"]
    x_scale = main_graph.axes[0].scale
    x_scale.min = start
    x_scale.max = end

    lines = main_graph.marks[0]
    lo, hi = date_axis_selector.index
    window = lines.y[lines.curves_subset, lo : hi + 1]
    y_scale = main_graph.axes[1].scale
    y_scale.min = 0.9 * float(np.nanmin(window))
    y_scale.max = 1.1 * float(np.nanmax(window))
    return


In [None]:
# Re-run `event` whenever a tab-1 control changes. Toggle_Buttons expose
# add_observe; the plain ipywidgets use observe.
data_buttons.add_observe(event, "value")
norm_buttons.add_observe(event, "value")
scale_buttons.add_observe(event, "value")
type_buttons.add_observe(event, "value")
tab_1_ma_ch.observe(event, "value")
tab_1_ma_w.observe(event, "value")


In [None]:
# World Map

# World choropleth. The ordinal colour scale's domain/colors are rewritten by
# map_color; regions without a colour entry use the grey default colour.
sc_geo = Mercator(center=(-20, 50))
sc_c1 = OrdinalColorScale(
    colors=red_color_scheme,
)
caxis = ColorAxis(scale=sc_c1)
map_tt = stats_table  # tooltip widget shared by all maps

wm = Map(
    map_data=topo_load("map_data/WorldMap.json"),
    scales={"projection": sc_geo, "color": sc_c1},
    colors={"default_color": "Grey"},
    tooltip=map_tt,
    hovered_styles={"hovered_fill": "White"},
)


In [None]:
# World map: hover fills the tooltip, click plots the region, and a background
# click resets the graph to the world series.
wm.on_hover(map_hovering)
wm.on_element_click(map_click)
wm.on_background_click(map_background_click)


In [None]:
# US States Map

# US states choropleth, with its own projection and colour scale.
sc_geo_US = AlbersUSA()
sc_c1_US = OrdinalColorScale(
    colors=red_color_scheme,
)
caxis_US = ColorAxis(scale=sc_c1_US)
wm_US = Map(
    map_data=topo_load("map_data/USStatesMap.json"),
    scales={"projection": sc_geo_US, "color": sc_c1_US},
    colors={"default_color": "Grey"},
    tooltip=map_tt,
    hovered_styles={"hovered_fill": "White"},
)


In [None]:
# NOTE(review): unlike the world map, no on_background_click handler is
# registered here — confirm this is intentional.
wm_US.on_hover(map_hovering)
wm_US.on_element_click(map_click)


In [None]:
# US counties choropleth.
sc_geo_US_counties = AlbersUSA()
sc_c1_US_counties = OrdinalColorScale(
    colors=red_color_scheme,
)
caxis_US_counties = ColorAxis(scale=sc_c1_US_counties)

wm_US_counties = Map(
    map_data=topo_load("map_data/USCountiesMap.json"),
    scales={"projection": sc_geo_US_counties, "color": sc_c1_US_counties},
    colors={"default_color": "Grey"},
    tooltip=map_tt,
    hovered_styles={"hovered_fill": "White"},
)

# Registry of the three maps, indexed like dict_datasets. "name" must match the
# map_buttons options and the keys of dict_ID_to_map.
dict_maps = {
    0: {
        "map": wm,
        "name": "World",
        "projection": sc_geo,
        "scale": sc_c1,
        "axis": caxis,
    },
    1: {
        "map": wm_US,
        "name": "US States",
        "projection": sc_geo_US,
        "scale": sc_c1_US,
        "axis": caxis_US,
    },
    2: {
        "map": wm_US_counties,
        "name": "US Counties",
        "projection": sc_geo_US_counties,
        "scale": sc_c1_US_counties,
        "axis": caxis_US_counties,
    },
}
# Map/dataset name -> index into dict_maps and dict_datasets.
dict_ID_to_map = dict(zip(["World", "US States", "US Counties"], [0, 1, 2]))

# NOTE(review): as for wm_US, no background-click handler is registered here.
wm_US_counties.on_hover(map_hovering)
wm_US_counties.on_element_click(map_click)


In [None]:
# Map figure; update_map swaps its single mark and colour axis when the user
# selects a different map.
main_map = Figure(
    marks=[wm],
    axes=[caxis],
    title="",
    fig_margin={"top": 50, "bottom": 50, "left": 0, "right": 0},
    layout=Layout(width="99%", height="99%"),
)


In [None]:
def update_map(*args):
    """Show the map chosen in map_buttons, recoloured for the current controls.

    ToggleButton events are ignored unless their new value is True.
    """
    change = args[0]
    if type(change["owner"]).__name__ == "ToggleButton" and change["new"] != True:
        return
    map_name = map_buttons.value
    map_index = dict_ID_to_map[map_name]
    entry = dict_maps[map_index]
    date = date_selector.value
    norm = norm_buttons.value
    scale = scale_buttons.value
    data_type = type_buttons.value
    data = data_buttons.value
    ma = tab_1_ma_ch.value
    n = tab_1_ma_w.value
    color = map_color(dict_datasets[map_index], data, norm, scale, data_type, ma, n, date)
    main_map.title = map_title(map_name, data, norm, data_type, date, ma, n)
    entry["map"].color = color
    main_map.axes = [entry["axis"]]
    main_map.marks = [entry["map"]]
    return


map_buttons.add_observe(update_map, "value")


In [None]:
# Tables

# Style prefix for the sorted tables (#sorted_table) rendered by
# sorted_table_html; orange text, left-aligned, full width.
css_style_2 = """
<head>
    <meta http-equiv="Content-Type" content="text/html; charset=utf-8">
    <style type="text/css" media="screen">
        #sorted_table {
          width: 100%;
          text-align: left;
          color: #ff8b0e;
        }
        #sorted_table th {
           text-align: left;
           color: #ff8b0e;
        }
    </style>
</head>
"""


def sorted_table_html(
    dataset,
    data,
    norm,
    scale,
    data_type,
    ma,
    n,
    date,
    ascending,
    K,
    dict_codes,
    column_name,
):
    """Render ``dataset``'s regions sorted by the selected metric as HTML.

    Parameters
    ----------
    dataset : object
        Provides ``get_ts_sort``; its first reset-index column is the region.
    data, norm, scale, data_type, ma, n, date, ascending, K :
        Forwarded to ``dataset.get_ts_sort``. The ``date`` column is renamed to
        the metric label.
    dict_codes : dict or None
        Optional mapping region code -> display name. Codes missing from it are
        shown unchanged instead of raising KeyError.
    column_name : str
        Header of the region column.

    Returns
    -------
    str
        ``css_style_2`` followed by a ``<body>``-wrapped ``<table id="sorted_table">``.
        Raw values are formatted as integers; normalised or % values use two decimals.
    """
    df = dataset.get_ts_sort(
        data, norm, scale, data_type, ma, n, date, ascending, K
    ).reset_index()
    current_val = name_val_on_hover(data, norm, data_type, ma, n)

    df = df.rename({df.columns.values[0]: column_name, date: current_val}, axis=1)
    if dict_codes:
        df[column_name] = [dict_codes.get(c, c) for c in df[column_name].values]
    df = df.dropna()
    if norm == "Values" and data_type != "Daily % change":
        df = df.assign(**{current_val: df[current_val].astype(int)})
        formatters = {current_val: lambda x: "{:,d}".format(x)}
    else:
        formatters = {current_val: lambda x: "{:,.02f}".format(x)}

    val = df.to_html(
        index=False,
        header=True,
        formatters=formatters,
        notebook=True,
        table_id="sorted_table",
    )
    return css_style_2 + "<body>" + val + "</body>"


def sorted_tables_html(data, norm, scale, data_type, ma, n, date, ascending, K):
    """Return the (countries table, US states table) HTML for the given controls."""
    specs = [
        (dict_datasets[0], codes_to_countries, "Countries"),
        (dict_datasets[1], None, "US States"),
    ]
    value_table_1, value_table_2 = (
        sorted_table_html(
            ds, data, norm, scale, data_type, ma, n, date, ascending, K, codes, header
        )
        for ds, codes, header in specs
    )
    return value_table_1, value_table_2


def update_tables(data, norm, scale, data_type, ma, n, date):
    """Refresh both sorted tables, in descending order with K=None.

    K=None presumably means no row limit.
    """
    tables = sorted_tables_html(data, norm, scale, data_type, ma, n, date, False, None)
    table_1.value, table_2.value = tables
    return


# HTML widgets shown side by side in the "Table" tab.
table_1 = HTML()
table_2 = HTML()


In [None]:
# Infection Map tab

# Tab 1 layout: controls on the top row, map and graph/table tabs below.
grid_1 = GridspecLayout(2, 2)
grid_1.layout.height = "99%"
grid_1.layout.width = "100%"
grid_1.layout.overflow = "auto"
table_grid = GridspecLayout(1, 2)
table_grid.layout.width = "100%"
table_grid.layout.height = "100%"
table_1.layout.width = "auto"
table_1.layout.overflow = "auto"
table_2.layout.overflow = "auto"
table_2.layout.width = "auto"
table_grid[0, 0] = table_1
table_grid[0, 1] = table_2
help_date_axis_selector = help_button(
    "Select a period of time using this slider for the graph above"
)
center_right_panel = VBox(
    [
        main_graph,
        HBox(
            [help_date_axis_selector, date_axis_selector],
            # `height` was misspelled "hieght", so the setting was silently ignored.
            layout=Layout(height="auto", overflow="hidden"),
        ),
    ]
)
center_right_panel.layout.height = "100%"
tab_graph_table = Tab(
    _titles=dict(zip([0, 1], ["Graph", "Table"])),
    children=[center_right_panel, table_grid],
    layout=Layout(height="99%", min_height="400px"),
)
grid_1[0, 0] = HBox([top_right_buttons_tab1])
help_tab1 = help_button(help_tooltip["Tab 1"])
help_tab1.layout.margin = "0px 2px 3px 0px"
grid_1[0, 1] = VBox([help_tab1, date_tab_1_box], layout=Layout(align_items="flex-end"))
main_graph.layout.height = "99%"
grid_1[1, 0] = main_map
grid_1[1, 1] = tab_graph_table
grid_1.layout.align_items = "stretch"
grid_1.layout.grid_template_rows = "155px auto"
grid_1.layout.grid_template_columns = "60% auto"


In [None]:
# Rebased Graph Tab

# Buttons and Multiple selectors
# Data / Norm / Type selectors of the rebased-graph tab. The description width
# for the threshold selectors (min_description_width_2) is set just before they
# are built; an earlier dead "140px" assignment was removed.
rebased_graph_data_button = Toggle_Buttons(
    options=["Cases", "Deaths", "Recovered", "Active Cases", "Tests"],
    value="Cases",
    description="Data",
    min_button_width=min_button_width_data,
    min_description_width=min_description_width_1,
    tooltips=buttons_tooltip["Data"][1:],
    description_tooltip=buttons_tooltip["Data"][0],
)

rebased_graph_norm_button = Toggle_Buttons(
    options=["Values", "Per million"],
    value="Values",
    description="Norm",
    min_button_width=min_button_width_norm,
    min_description_width=min_description_width_1,
    tooltips=buttons_tooltip["Norm"][1:],
    description_tooltip=buttons_tooltip["Norm"][0],
)

rebased_graph_type_button = Toggle_Buttons(
    options=["Total", "Daily change", "Daily % change"],
    value="Total",
    description="Type",
    min_button_width=min_button_width_type,
    min_description_width=min_description_width_1,
    tooltips=buttons_tooltip["Type"][1:],
    description_tooltip=buttons_tooltip["Type"][0],
)

# The description is HTML markup: the title attribute must be quoted (with any
# double quotes in the tooltip escaped), otherwise a multi-word tooltip is cut
# at its first space.
thr_val_slider = IntSlider(
    description='<tag title="'
    + buttons_tooltip["Threshold"][0].replace('"', "&quot;")
    + '">'
    + "Threshold"
    + "</tag>",
    value=1000,
    min=10,
    max=5000,
    step=100,
    style={"description_width": "initial"},
    layout=Layout(width="340px", visibility="hidden"),
)

# Linear/Log scale selector of the rebased graph (defaults to Log).
plot_scale_button = Toggle_Buttons(
    options=["Linear", "Log"],
    value="Log",
    description="Scale",
    min_button_width=min_button_width_scale,
    min_description_width=min_description_width_1,
    tooltips=buttons_tooltip["Scale"][1:],
    description_tooltip=buttons_tooltip["Scale"][0],
)
# Wider description column for the longer "Threshold ..." labels below.
min_description_width_2 = "125px"
# Series and normalisation used for the threshold. Note that "Active Cases" is
# not offered here.
rebased_graph_thr_data_button = Toggle_Buttons(
    options=["Cases", "Deaths", "Recovered", "Tests"],
    value="Cases",
    description="Threshold Data",
    min_button_width=min_button_width_data,
    min_description_width=min_description_width_2,
    tooltips=buttons_tooltip["Threshold data"][1:],
    description_tooltip=buttons_tooltip["Threshold data"][0],
)

rebased_graph_thr_norm_button = Toggle_Buttons(
    options=["Values", "Per million"],
    value="Values",
    description="Threshold Norm",
    min_button_width=min_button_width_norm,
    min_description_width=min_description_width_2,
    tooltips=buttons_tooltip["Threshold norm"][1:],
    description_tooltip=buttons_tooltip["Threshold norm"][0],
)

# Moving-average toggle for the rebased graph. The HTML title attribute is
# quoted (double quotes escaped) so the whole tooltip text is kept.
tab_2_ma_ch = Checkbox(
    description='<tag title="'
    + buttons_tooltip["Moving average"][0].replace('"', "&quot;")
    + '">'
    + "Moving Average"
    + "</tag>",
    value=False,
    style={"description_width": "initial"},
    layout=Layout(min_width="200px", width="200px", overflow="hidden"),
)
# Moving-average window (1-14 days); hidden until the checkbox is ticked.
tab_2_ma_w = BoundedIntText(
    value=7,
    min=1,
    max=14,
    description="Window size (in days)",
    style={"description_width": "initial"},
    layout=Layout(
        min_width="200px", width="200px", overflow="auto", visibility="hidden"
    ),
)
ma_box_tab_2 = VBox(
    [tab_2_ma_ch, tab_2_ma_w],
    layout=Layout(
        min_width="205px",
        width="205px",
        max_width="205px",
        height="74px",
        min_height="74px",
        max_height="74px",
        overflow="auto",
    ),
)


In [None]:
# Initial selector contents; "None" is the explicit empty choice.
top_countries_names = ["United States"]
states = ["New York"]

countries_selector = SelectMultiple(
    options=["None"] + top_countries_names,
    value=["None"],
    layout=Layout(width="auto", height="100%", min_height="50px"),
)

states_selector = SelectMultiple(
    options=["None"] + states,
    value=["None"],
    layout=Layout(width="auto", height="100%", min_height="50px"),
)

# The HTML title attribute is quoted (double quotes escaped) so a multi-word
# tooltip is not truncated at its first space.
calendar_time = Checkbox(
    description='<tag title="'
    + buttons_tooltip["Calendar time"][0].replace('"', "&quot;")
    + '">'
    + "Calendar Time"
    + "</tag>",
    value=True,
    style={"description_width": "initial"},
    layout=Layout(width="auto", overflow="hidden"),
)

# Rebased-tab control column: category selectors, then the threshold/calendar/MA
# box, then the threshold selectors. The threshold selectors start hidden;
# presumably they are shown when calendar time is turned off — TODO confirm.
cat_tab_2_buttons = VBox(
    [
        rebased_graph_data_button,
        rebased_graph_norm_button,
        rebased_graph_type_button,
        plot_scale_button,
    ],
    layout=Layout(
        width="565px",
        height="138px",
        min_width="565px",
        max_width="565px",
        min_height="138px",
        max_height="138px",
        overflow="auto",
    ),
)
thr_tab_2_buttons = HBox(
    [VBox([thr_val_slider, calendar_time], layout=Layout(width="auto")), ma_box_tab_2],
    layout=Layout(
        width="565px",
        height="80px",
        min_width="565px",
        max_width="565px",
        min_height="80px",
        max_height="80px",
        overflow="auto",
        border="solid #ff8b0e",
        margin="10px 0 10px 0",
    ),
)
cat_thr_tab_2_buttons = VBox(
    [rebased_graph_thr_data_button, rebased_graph_thr_norm_button,],
    layout=Layout(
        width="565px",
        height="74px",
        min_width="565px",
        max_width="565px",
        min_height="74px",
        max_height="74px",
        overflow="auto",
        visibility="hidden",
    ),
)


In [None]:
# Fixed-size (565x315px) column stacking all rebased-tab controls.
reb_top_right_buttons = VBox(
    [cat_tab_2_buttons, thr_tab_2_buttons, cat_thr_tab_2_buttons],
    layout=Layout(
        width="565px",
        min_width="565px",
        max_width="565px",
        overflow="auto",
        height="315px",
        min_height="315px",
        max_height="315px",
    ),
)


In [None]:
# Labels of the preset-selection buttons (continents and top-10 rankings).
presaved_sel_options = [
    "Europe",
    "Asia",
    "North America",
    "South America",
    "Africa",
    "Oceania",
    "Top10 (Cases)",
    "Top10 (Deaths)",
    "Top10 (Cases/1M)",
    "Top10 (Deaths/1M)",
    "Top10 (Cases daily chg)",
    "Top10 (Deaths daily chg)",
]
# Preset label -> country list. Every entry starts as None; presumably filled in
# later in the notebook (on_click_action reads it) — TODO confirm.
presaved_sel_lists = dict.fromkeys(presaved_sel_options)


def preselection(presaved_sel_options, class_name, action):
    """Create one Button per predefined-selection label.

    Args:
        presaved_sel_options: button labels, e.g. "Europe" or "Top10 (Cases)".
        class_name: DOM class added to every button; on_click_action reads the
            first DOM class back to decide which tab's selectors to update.
        action: click callback registered on every button.

    Returns:
        dict mapping each label to its Button, in first-occurrence order.
    """
    buttons = dict.fromkeys(presaved_sel_options)
    for label in presaved_sel_options:
        tooltip = (
            "Top10 most-affected countries (" + label.split("(")[-1]
            if "Top10" in label
            else "Countries in " + label
        )
        button = Button(description=label, button_style="", tooltip=tooltip)
        button.add_class(class_name)
        button.on_click(action)
        buttons[label] = button
    return buttons


def on_click_action(val):
    """Apply a predefined country selection when its button is clicked.

    The button's first DOM class (added by ``preselection``) names the tab it
    belongs to. That tab's state selector is reset to "None" and its country
    selector is set to ``presaved_sel_lists[<button label>]``. Buttons with an
    unknown class do nothing.
    """
    selection = presaved_sel_lists[val.description]
    # Selectors are resolved lazily: some of them (e.g. the heat-map ones) are
    # defined later in the notebook than this function.
    selectors_for = {
        "Rebased": lambda: (states_selector, countries_selector),
        "Heat map": lambda: (dna_states_selector, dna_countries_selector),
        "Custom": lambda: (free_states_selector, free_countries_selector),
    }
    pick = selectors_for.get(val._dom_classes[0])
    if pick is not None:
        states_widget, countries_widget = pick()
        states_widget.value = ["None"]
        countries_widget.value = selection
    return


In [None]:
def set_value(widget, new_val, action, unobserve=True):
    """Set a widget's value, optionally without triggering ``action``.

    Args:
        widget: an ipywidget exposing ``value`` / ``observe`` / ``unobserve``,
            or a project ``Toggle_Buttons`` (identified by class name), which
            uses ``set_value`` / ``add_observe`` / ``del_observe`` instead.
        new_val: value to assign; ``None`` means "leave the widget unchanged".
        action: observer callback registered on the widget's ``value`` trait.
        unobserve: when True, ``action`` is detached while the value is set so
            it does not fire, and is always re-attached afterwards (even if the
            assignment raises). When False, the observer fires normally.
    """
    if new_val is None:
        return
    if type(widget).__name__ == "Toggle_Buttons":
        if unobserve:
            widget.del_observe(action, "value")
            try:
                widget.set_value(new_val)
            finally:
                # Re-attach even if the value is rejected; otherwise the widget
                # would silently stop driving the graph.
                widget.add_observe(action, "value")
        else:
            widget.set_value(new_val)
    elif unobserve:
        widget.unobserve(action, "value")
        try:
            widget.value = new_val
        finally:
            widget.observe(action, "value")
    else:
        if sorted(list(widget.value)) == sorted(new_val):
            # Presumably forces a change notification when the selection is
            # unchanged: the temporary value differs, so the final assignment
            # below fires the observer.
            widget.value = new_val + ["None", "New York"]
        widget.value = new_val
    return


def on_click_scenario(val):
    """Apply the predefined scenario named by the clicked button to the current tab.

    Each control of the active tab (``dict_tab_buttons``) is set to the matching
    scenario value (``dict_scenario``) with its observer detached. The last
    control is set with its observer active, presumably so that a single redraw
    happens once everything is in place.
    """
    tab = outer_tab.selected_index
    scenario = val.description
    widgets = dict_tab_buttons[tab]
    last_pos = len(widgets) - 1
    for pos, (widget, new_val) in enumerate(zip(widgets, dict_scenario[tab][scenario])):
        set_value(widget, new_val, dict_tab_buttons_actions[tab][pos], pos != last_pos)
    return


def sc_buttons(presaved_sc, action):
    """Create one Button per scenario label, all wired to the same click callback.

    Args:
        presaved_sc: scenario labels (duplicates collapse to a single button).
        action: callback registered with ``on_click`` on every button.

    Returns:
        dict mapping each label to its Button, in first-occurrence order.
    """
    buttons = dict.fromkeys(presaved_sc)
    for label in buttons:
        button = Button(description=label, button_style="")
        button.on_click(action)
        buttons[label] = button
    return buttons


In [None]:
ps_tab2 = preselection(presaved_sel_options, "Rebased", on_click_action)
# Predefined selections for the rebased graph, four buttons per row.
hb_tab2 = VBox(
    [
        HBox([ps_tab2[name] for name in presaved_sel_options[start : start + 4]])
        for start in range(0, len(presaved_sel_options), 4)
    ]
)
ps_tab2_1 = sc_buttons(
    ["Scenario {}".format(k) for k in range(1, 5)], on_click_scenario
)
# Scenario buttons, also four per row.
hb_tab2_1 = VBox(
    [
        HBox([ps_tab2_1[name] for name in list(ps_tab2_1)[start : start + 4]])
        for start in range(0, len(ps_tab2_1), 4)
    ]
)
# Collapsed by default; panel 0 holds selections, panel 1 scenarios.
accordion_tab2 = Accordion(
    children=[hb_tab2, hb_tab2_1],
    selected_index=None,
    layout=Layout(width="auto", height="auto"),
)
accordion_tab2.set_title(0, "Predefined selections")
accordion_tab2.set_title(1, "Predefined scenarios")


In [None]:
# Bottom-right panel of tab 2: the accordion occupies row 0, which is 80px
# while collapsed (accordion_adapt_height resizes it when a panel opens).
tab_2_box_selectors = GridspecLayout(
    2,
    5,
    layout=Layout(
        width="555px",
        min_height="60px",
        align_items="stretch",
        overflow="auto",
        border="solid #ff8b0e",
    ),
)
tab_2_box_selectors[0, :] = accordion_tab2
tab_2_box_selectors.layout.grid_template_rows = "80px auto"


def accordion_adapt_height(val):
    """Resize the active tab's selector grid when its accordion opens or closes.

    ``val`` is the traitlets change for an Accordion's ``selected_index``. The
    first grid row becomes 230px when panel 0 (predefined selections) is open,
    160px for panel 1 (predefined scenarios), and 80px when collapsed. The grid
    is picked from the outer tab index (1, 2 or 3).
    """
    rows = {0: "230px auto", 1: "160px auto"}.get(val.new, "80px auto")
    index = outer_tab.selected_index
    tab_selectors = {
        1: tab_2_box_selectors,
        2: tab_3_box_selectors,
        3: tab_4_box_selectors,
    }
    tab_selectors[index].layout.grid_template_rows = rows


accordion_tab2.observe(accordion_adapt_height, "selected_index")


In [None]:
# Graph

# Initial axes: linear time axis and log value axis. plot_rebased_graph
# replaces both scales on every redraw.
scale_reb_t = LinearScale()
axis_reb_t = Axis(grid_lines="none", scale=scale_reb_t)
scale_reb_y = LogScale()
axis_reb_y = Axis(
    orientation="vertical",
    scale=scale_reb_y,
    tick_format=",",
    grid_lines="solid",
)
# The figure starts without marks; plot_rebased_graph adds curves and labels.
rebased_graph = Figure(
    title="Rebased Graph",
    axes=[axis_reb_t, axis_reb_y],
    fig_margin=dict(top=50, bottom=50, left=65, right=100),
    legend_location="bottom-right",
    animation_duration=1000,
)


In [None]:
# Date-range slider under the rebased graph. Only a placeholder date is set
# here; the real options are presumably assigned elsewhere once data dates are known.
date_axis_selector_tab_2 = SelectionRangeSlider(
    index=(0, 0),
    options=["2020-01-22"],
    layout=Layout(width="auto", height="auto"),
)


def _padded_y_range(window):
    """Return (lower, upper) y-axis limits about 10% outside the values of ``window``.

    NaNs are ignored (an all-NaN window yields NaN limits). Padding is applied
    away from zero on each side, so negative extremes (e.g. negative daily
    changes on a linear scale) are not clipped; for non-negative data this is
    exactly ``0.9 * min`` / ``1.1 * max``.
    """
    ymin = float(np.nanmin(window))
    ymax = float(np.nanmax(window))
    lower = ymin * (0.9 if ymin >= 0 else 1.1)
    upper = ymax * (1.1 if ymax >= 0 else 0.9)
    return lower, upper


def update_zoom_tab_2(*args):
    """Zoom the rebased graph to a date range and refresh its title.

    ``args[0]["new"]`` is the ``(min_x, max_x)`` pair, e.g. a change event of
    the date slider or the dict built by plot_rebased_graph. The x-scale is set
    to that range. The y-scale is fitted to the data between the slider's
    indices. The curve labels move to ``max_x`` at the last visible values.
    """
    min_x, max_x = args[0]["new"]
    rebased_graph.axes[0].scale.min = min_x
    rebased_graph.axes[0].scale.max = max_x

    y = rebased_graph.marks[0].y
    start, end = date_axis_selector_tab_2.index
    if len(y.shape) in (1, 2):
        # `...` keeps every series when y is 2-D (series x time) and is a no-op
        # for 1-D y, so one slice selects the visible window in both layouts.
        y_scale = rebased_graph.axes[1].scale
        y_scale.min, y_scale.max = _padded_y_range(y[..., start : end + 1])

        labels = rebased_graph.marks[1]
        labels.x = [max_x] * len(labels.x)
        # One label per series: a column of values for 2-D y, a 1-item list for 1-D.
        labels.y = y[:, end] if len(y.shape) == 2 else [y[end]]

    (
        data_to_plot,
        data_norm,
        data_type,
        plot_scale,
        thr_data,
        thr_norm,
        ma,
        n,
        threshold,
    ) = [widget.value for widget in value_buttons[:-2]]
    if calendar_time.value:
        threshold = 0
    rebased_graph.title = title_rebased_graph(
        data_to_plot,
        data_norm,
        data_type,
        threshold,
        plot_scale,
        thr_data,
        thr_norm,
        ma,
        n,
        date1=str(min_x),
        date2=str(max_x),
    )
    return


In [None]:
def title_rebased_graph(
    data_to_plot,
    data_norm,
    data_type,
    threshold,
    plot_scale,
    thr_data,
    thr_norm,
    ma,
    n,
    date1="",
    date2="",
):
    """Build the rebased-graph title from the current control values.

    The title is assembled from:
    * an optional "Log" prefix;
    * the plotted series, with "/1M population" when it is not raw values;
    * an optional daily / daily % change suffix;
    * the moving-average window when ``ma`` is set.

    With a positive ``threshold`` it ends with the threshold description.
    Otherwise it ends with the ``date1``..``date2`` range.
    """
    pieces = [plot_scale] if plot_scale == "Log" else []
    pieces.append(" " + data_to_plot)
    if data_norm != "Values":
        pieces.append("/1M population")
    if data_type == "Daily change":
        pieces.append(" daily change")
    elif data_type == "Daily % change":
        pieces.append(" daily % change")
    if ma:
        pieces.append(" ({0:d}days m.a)".format(n))
    if threshold > 0:
        pieces.append(" since number of " + thr_data)
        if thr_norm != "Values":
            pieces.append("/1M population")
        pieces.append(" = " + str(threshold))
    else:
        pieces.append(" from " + date1 + " to " + date2)
    return "".join(pieces)


def label_index_pos(data):
    """Index at which to place a series' end-of-line label.

    Returns ``len(data) - 1`` when ``data`` contains no NaN (so -1 for empty
    input). Otherwise returns the largest index ``i >= 1`` where both
    ``data[i]`` and ``data[i - 1]`` are non-NaN, or 0 if no such pair exists.
    """
    missing = np.isnan(data)
    if not missing.any():
        return len(data) - 1
    present = ~missing
    # Positions k where data[k] and data[k + 1] are both present.
    pairs = np.flatnonzero(present[1:] & present[:-1])
    return int(pairs[-1]) + 1 if pairs.size else 0


In [None]:
def persistent_colors(used_colors, used_labels, new_labels, all_colors):
    """Assign a colour to every label in ``new_labels``, keeping existing ones stable.

    Args:
        used_colors: colours currently drawn, matched by position to
            ``used_labels`` (cycled if there are fewer colours than labels).
        used_labels: labels currently drawn.
        new_labels: labels about to be drawn.
        all_colors: full palette to draw colours for new labels from.

    Returns:
        dict label -> colour. Labels that stay keep their colour, and dropped
        labels are removed. Each new label gets the palette colour used the
        fewest times so far; ties go to the colour counted first (colours
        already in use, then unused palette entries in palette order).
    """
    palette = list(used_colors)
    if palette:
        cycled = palette * (len(used_labels) // len(palette) + 1)
        assigned = dict(zip(used_labels, cycled))
    else:
        assigned = {}
    assigned = {label: colour for label, colour in assigned.items() if label in new_labels}

    fresh = [label for label in new_labels if label not in used_labels]
    if fresh:
        usage = Counter(assigned.values())
        for colour in all_colors:
            if colour not in usage:
                usage[colour] = 0
        for label in fresh:
            pick = min(usage, key=usage.__getitem__)
            assigned[label] = pick
            usage[pick] += 1
    return assigned


def plot_rebased_graph(
    data_to_plot,
    data_norm,
    data_type,
    thr,
    countries,
    states,
    plot_scale,
    thr_data,
    thr_norm,
    ma,
    n,
    calendar_time=False,
):
    """Redraw the rebased graph for the selected countries and states.

    Series come from ``dict_datasets[0]`` (countries, by code) and
    ``dict_datasets[1]`` (states).

    When ``thr > 0`` and ``calendar_time`` is False, each series is cut to start
    where its ``thr_data`` total passes ``thr``. It is then plotted against
    "days since threshold". Otherwise all series are plotted against calendar
    dates from ``STDT`` to ``ENDT``.

    Args:
        data_to_plot, data_norm, data_type, plot_scale, ma, n: forwarded to the
            dataset's ``get_ts_plot``.
        thr: threshold value; ignored when ``calendar_time`` is True.
        countries: country display names (mapped through ``countries_to_codes``).
        states: state names.
        thr_data, thr_norm: series and normalisation used to locate the threshold.
        calendar_time: force a calendar-date x axis.

    Side effects: replaces ``rebased_graph`` axes, marks and title. Colours of
    labels already on the graph are preserved via ``persistent_colors``.
    """
    all_colors = [
        "#1f77b4",
        "#ff7f0e",
        "#2ca02c",
        "#d62728",
        "#9467bd",
        "#8c564b",
        "#e377c2",
        "#7f7f7f",
        "#bcbd22",
        "#17becf",
    ]
    # A previous plot exists: reuse its colours so curves keep them.
    update = len(rebased_graph.marks) > 0
    if calendar_time:
        threshold = 0
    else:
        threshold = thr
    if plot_scale == "Log":
        scale_reb_y = LogScale()
    else:
        scale_reb_y = LinearScale()
    Ys = []
    Xs = []
    label_names = []
    max_len = 0
    for country_name in countries:
        yaux = dict_datasets[0].get_ts_plot(
            countries_to_codes[country_name],
            data_to_plot,
            data_norm,
            plot_scale,
            data_type,
            ma,
            n,
            dict_datasets[0].STDT,
            dict_datasets[0].ENDT,
        )
        if threshold > 0:
            y_thr = dict_datasets[0].get_ts(
                countries_to_codes[country_name],
                thr_data,
                thr_norm,
                "Linear",
                "Total",
                ma,
                n,
            )
            # values start from the index of the threshold (bisect assumes
            # y_thr is non-decreasing, e.g. cumulative totals)
            yaux = yaux[bisect.bisect(y_thr, threshold) :]
        if yaux.shape[0] > 0:
            if plot_scale == "Log":
                # Non-positive values cannot be drawn on a log scale. `where`
                # returns a new Series instead of assigning in place, so data
                # possibly shared with the dataset is not modified.
                yaux = yaux.where(yaux > 0)
            Ys.append(list(yaux.values))
            Xs.append(list(np.arange(0, len(Ys[-1]))))
            label_names.append(country_name)
            max_len = max(max_len, len(Ys[-1]))

    for st in states:
        # NOTE(review): states are fetched with "Linear" whereas countries pass
        # plot_scale — confirm this is intended.
        yaux = dict_datasets[1].get_ts_plot(
            st,
            data_to_plot,
            data_norm,
            "Linear",
            data_type,
            ma,
            n,
            dict_datasets[0].STDT,
            dict_datasets[0].ENDT,
        )
        if threshold > 0:
            y_thr = dict_datasets[1].get_ts(
                st, thr_data, thr_norm, "Linear", "Total", ma, n,
            )
            yaux = yaux[bisect.bisect(y_thr, threshold) :]
        if yaux.shape[0] > 0:
            if plot_scale == "Log":
                yaux = yaux.where(yaux > 0)
            Ys.append(list(yaux.values))
            Xs.append(list(np.arange(0, len(Ys[-1]))))
            label_names.append(st)
            max_len = max(max_len, len(Ys[-1]))

    label_x = []
    label_y = []
    for i in range(len(Ys)):
        # Label each curve at its last drawable point, then NaN-pad every
        # series to max_len so Ys / Xs form rectangular 2-D arrays.
        index_lab = label_index_pos(Ys[i])
        label_y.append(Ys[i][index_lab])
        label_x.append(Xs[i][index_lab])
        Ys[i] = Ys[i] + [math.nan] * (max_len - len(Ys[i]))
        Xs[i] = Xs[i] + [math.nan] * (max_len - len(Xs[i]))
    axis_reb_y.scale = scale_reb_y
    if threshold > 0:
        scale_reb_t = LinearScale()
    else:  # if thr == 0 Plot the values vs calendar dates
        scale_reb_t = DateScale(
            dateformat=dateformat,
            min=datetime.datetime.strptime(STDT, dateformat),
            max=datetime.datetime.strptime(ENDT, dateformat),
        )
        Xs = list(
            np.arange(
                np.datetime64(
                    datetime.datetime.strptime(STDT, dateformat).strftime(dateformat)
                ),
                np.datetime64(
                    datetime.datetime.strptime(ENDT, dateformat).strftime(dateformat)
                )
                + np.timedelta64(1, "D"),
            )
        )
        # One label per plotted series, all at the end date. Sized by
        # label_names (series actually plotted), not by the requested
        # countries/states, some of which may have returned no data.
        label_x = [
            np.datetime64(
                datetime.datetime.strptime(ENDT, dateformat).strftime(dateformat)
            )
        ] * len(label_names)
    axis_reb_t.scale = scale_reb_t
    rebased_graph.axes = [axis_reb_t, axis_reb_y]
    if update:
        color_dict = persistent_colors(
            list(rebased_graph.marks[0].colors),
            list(rebased_graph.marks[1].text),
            label_names,
            all_colors,
        )
        colors = [color_dict[c] for c in label_names]
    else:
        colors = all_colors.copy()
    rebased_curves = Lines(scales={"x": scale_reb_t, "y": scale_reb_y}, colors=colors)
    rebased_curves.x = np.array(Xs)
    rebased_curves.y = np.array(Ys)
    rebased_labels = Label(
        apply_clip=False,
        scales={"x": scale_reb_t, "y": scale_reb_y},
        default_size=15,
        colors=colors,
    )
    rebased_labels.text = label_names
    rebased_labels.x = np.array(label_x)
    rebased_labels.y = np.array(label_y)
    rebased_graph.marks = [rebased_curves, rebased_labels]
    rebased_graph.title = title_rebased_graph(
        data_to_plot,
        data_norm,
        data_type,
        threshold,
        plot_scale,
        thr_data,
        thr_norm,
        ma,
        n,
    )
    if calendar_time:
        # Re-apply the slider's current date window to the new marks.
        ind1, ind2 = date_axis_selector_tab_2.index
        min_x = date_axis_selector_tab_2.options[ind1]
        max_x = date_axis_selector_tab_2.options[ind2]
        update_zoom_tab_2({"new": [min_x, max_x]})
    return


In [None]:
# Tab-2 (rebased graph) controls. The ORDER is relied on elsewhere:
# - value_buttons[:-2] (first 9) is unpacked positionally as
#   (data_to_plot, data_norm, data_type, plot_scale, thr_data, thr_norm,
#    ma, n, threshold) in update_rebased_graph and update_zoom_tab_2;
# - value_buttons[:-5] (first 6) get observers via add_observe (the
#   Toggle_Buttons observer API);
# - the last two are the country / state selectors.
# Do not reorder without updating those slices.
value_buttons = [
    rebased_graph_data_button,
    rebased_graph_norm_button,
    rebased_graph_type_button,
    plot_scale_button,
    rebased_graph_thr_data_button,
    rebased_graph_thr_norm_button,
    tab_2_ma_ch,
    tab_2_ma_w,
    thr_val_slider,
    countries_selector,
    states_selector,
]


def update_rebased_graph(*args):
    """Observer: re-read every tab-2 control and redraw the rebased graph.

    ``args[0]`` is a traitlets change dict. Events whose owner is a
    ``ToggleButton`` are handled only when the button is switched on
    (``new == True``); all other widgets always trigger a redraw.

    An empty selection (or one containing "None") falls back to "United States".
    The function also toggles visibility of the moving-average window box, and
    switches between the threshold controls and the date slider depending on
    calendar-time mode.
    """
    change = args[0]
    if type(change["owner"]).__name__ == "ToggleButton" and not (
        change["new"] == True
    ):
        return
    (
        data_to_plot,
        data_norm,
        data_type,
        plot_scale,
        thr_data,
        thr_norm,
        ma,
        n,
        threshold,
    ) = [widget.value for widget in value_buttons[:-2]]
    countries = list(countries_selector.value)
    states = list(states_selector.value)
    if "None" in countries:
        countries = []
    if "None" in states:
        states = []
    if not countries and not states:
        countries = ["United States"]

    tab_2_ma_w.layout.visibility = "visible" if ma else "hidden"

    # Calendar-time mode hides the threshold controls and shows the date
    # slider; threshold mode does the opposite.
    in_calendar_mode = calendar_time.value
    threshold_visibility = "hidden" if in_calendar_mode else "visible"
    thr_val_slider.layout.visibility = threshold_visibility
    cat_thr_tab_2_buttons.layout.visibility = threshold_visibility
    date_axis_selector_tab_2.layout.visibility = (
        "visible" if in_calendar_mode else "hidden"
    )

    plot_rebased_graph(
        data_to_plot,
        data_norm,
        data_type,
        threshold,
        countries,
        states,
        plot_scale,
        thr_data,
        thr_norm,
        ma,
        n,
        calendar_time.value,
    )
    update_link_download()
    return


# Redraw the rebased graph whenever a tab-2 control changes. The threshold
# slider only reports its value on release, not while being dragged.
thr_val_slider.continuous_update = False
# Plain ipywidgets register observers with `observe`...
for v in (thr_val_slider, tab_2_ma_ch, tab_2_ma_w, calendar_time):
    v.observe(update_rebased_graph, "value")
# ...while the first six value_buttons use the Toggle_Buttons `add_observe` API.
for v in value_buttons[:-5]:
    v.add_observe(update_rebased_graph, "value")


In [None]:
# Tab 2 layout: a 2x2 grid filling its container.
# Left column (both rows): help button + rebased graph, with the date slider below.
# Right column (590px): control buttons in row 0 (315px), selector panel in row 1.
# Statements are kept in this order because placements add grid children.
grid_2 = GridspecLayout(2, 2)
grid_2.layout.overflow = "hidden"
grid_2.layout.height = "100%"
grid_2.layout.width = "100%"
rebased_graph.layout.width = "auto"
rebased_graph.layout.height = "99%"
help_tab2 = help_button(help_tooltip["Tab 2"])
# Left panel: 32px help column + graph on row 0; date slider spans row 1 (40px).
grid_2_left_panel = GridspecLayout(2, 2)
grid_2_left_panel[0, 0] = help_tab2
grid_2_left_panel[0, 1] = rebased_graph
grid_2_left_panel[1, :] = date_axis_selector_tab_2
grid_2_left_panel.layout.grid_template_columns = "32px auto"
grid_2_left_panel.layout.grid_template_rows = "auto 40px"
grid_2[:, 0] = grid_2_left_panel
grid_2[0, 1] = HBox([reb_top_right_buttons])
grid_2[1, 1] = tab_2_box_selectors
grid_2.layout.align_items = "stretch"
grid_2.layout.grid_template_columns = "auto 590px"
grid_2.layout.grid_template_rows = "315px auto"


In [None]:
# DNA Graph tab3


def _tag_with_title(label, tooltip):
    """Wrap ``label`` in a ``<tag>`` whose quoted ``title`` attribute shows ``tooltip``.

    The attribute value is quoted, with inner double quotes escaped, so a
    tooltip containing spaces is not cut at its first word by the HTML parser.
    Surrounding double quotes already present in ``tooltip`` are stripped
    first, so pre-quoted tooltip strings render exactly as before.
    """
    text = str(tooltip).strip('"').replace('"', "&quot;")
    return '<tag title="' + text + '">' + label + "</tag>"


# Country / state pickers for the heat map; a "None" entry means "no selection".
dna_countries_selector = SelectMultiple(
    options=["None"] + top_countries_names,
    value=["None"],
    layout=Layout(width="auto", height="100%", min_height="50px"),
)

dna_states_selector = SelectMultiple(
    options=["None"] + states,
    value=["None"],
    layout=Layout(width="auto", height="100%", min_height="50px"),
)


dna_data_button = Toggle_Buttons(
    options=["Cases", "Deaths", "Recovered", "Active Cases", "Tests"],
    value="Cases",
    description="Data",
    min_button_width=min_button_width_data,
    min_description_width=min_description_width_1,
    tooltips=buttons_tooltip["Data"][1:],
    description_tooltip=buttons_tooltip["Data"][0],
)

dna_norm_button = Toggle_Buttons(
    options=["Values", "Per million"],
    value="Values",
    description="Norm",
    min_button_width=min_button_width_norm,
    min_description_width=min_description_width_1,
    tooltips=buttons_tooltip["Norm"][1:],
    description_tooltip=buttons_tooltip["Norm"][0],
)

dna_type_button = Toggle_Buttons(
    options=["Total", "Daily change", "Daily % change"],
    value="Total",
    description="Type",
    min_button_width=min_button_width_scale,
    min_description_width=min_description_width_1,
    tooltips=buttons_tooltip["Type"][1:],
    description_tooltip=buttons_tooltip["Type"][0],
)

# Moving-average / normalise switches; their descriptions are HTML so the
# label can carry a hover tooltip.
tab_3_ma_ch = Checkbox(
    description=_tag_with_title("Moving Average", buttons_tooltip["Moving average"][0]),
    value=False,
    style={"description_width": "initial"},
    layout=Layout(min_width="200px", width="100%", overflow="hidden"),
)
tab_3_norm = Checkbox(
    description=_tag_with_title("Normalize", buttons_tooltip["Normalize"][0]),
    value=False,
    style={"description_width": "initial"},
    layout=Layout(min_width="100px", width="100%", overflow="auto"),
)
# Moving-average window (1-14 days); hidden until the moving average is enabled.
tab_3_ma_w = BoundedIntText(
    value=7,
    min=1,
    max=14,
    description="Window size (in days)",
    style={"description_width": "initial"},
    layout=Layout(
        min_width="200px", width="200px", overflow="auto", visibility="hidden"
    ),
)
ma_box_tab_3 = HBox(
    [tab_3_ma_ch, tab_3_norm, tab_3_ma_w],
    layout=Layout(
        width="555px",
        min_width="555px",
        max_width="555px",
        overflow="auto",
        align_items="stretch",
    ),
)

cat_tab_3_buttons = VBox(
    [dna_data_button, dna_norm_button, dna_type_button],
    layout=Layout(
        width="565px",
        height="105px",
        min_width="565px",
        max_width="565px",
        min_height="105px",
        max_height="105px",
        overflow="auto",
    ),
)

# Top-right area of tab 3: category toggles above the moving-average row.
dna_top_buttons = VBox(
    [cat_tab_3_buttons, ma_box_tab_3],
    layout=Layout(
        width="565px",
        min_width="565px",
        max_width="565px",
        height="140px",
        min_height="140px",
        max_height="140px",
        overflow="auto",
    ),
)


In [None]:
ps_tab3 = preselection(presaved_sel_options, "Heat map", on_click_action)
# Predefined selections for the heat map, four buttons per row.
hb_tab3 = VBox(
    [
        HBox([ps_tab3[name] for name in presaved_sel_options[start : start + 4]])
        for start in range(0, len(presaved_sel_options), 4)
    ]
)
ps_tab3_1 = sc_buttons(
    ["Scenario {}".format(k) for k in range(1, 5)], on_click_scenario
)
# Scenario buttons, also four per row.
hb_tab3_1 = VBox(
    [
        HBox([ps_tab3_1[name] for name in list(ps_tab3_1)[start : start + 4]])
        for start in range(0, len(ps_tab3_1), 4)
    ]
)
# Collapsed by default; panel 0 holds selections, panel 1 scenarios.
accordion_tab3 = Accordion(
    children=[hb_tab3, hb_tab3_1],
    selected_index=None,
    layout=Layout(width="auto", height="auto"),
)
accordion_tab3.set_title(0, "Predefined selections")
accordion_tab3.set_title(1, "Predefined scenarios")
# Grow / shrink the selector grid as accordion panels open and close.
accordion_tab3.observe(accordion_adapt_height, "selected_index")


In [None]:
# Bottom-right panel of tab 3: the accordion occupies row 0 (80px while collapsed).
tab_3_box_selectors = GridspecLayout(
    2,
    5,
    layout=Layout(
        width="555px",
        min_height="60px",
        align_items="stretch",
        overflow="auto",
    ),
)
tab_3_box_selectors[0, :] = accordion_tab3
tab_3_box_selectors.layout.grid_template_rows = "80px auto"


In [None]:
# GridHeatMap, DNA graph

# Scales and axes: dates on x, one ordinal row per country/state on y (axis
# drawn on the right), and a colour axis whose palette update_dna swaps per metric.
dna_x_scale = DateScale()
dna_y_scale = OrdinalScale(padding_y=0)
dna_color_scale = ColorScale(
    colors=["white", "#ffa600", "#ff6e00", "#ff4417", "#d31522", "#8c000e", "#560410"]
)
dna_x_ax = Axis(tick_format="%m/%d", scale=dna_x_scale)
dna_y_ax = Axis(side="right", orientation="vertical", scale=dna_y_scale)
dna_color_ax = ColorAxis(tick_format=",", scale=dna_color_scale)
dna_axes = [dna_x_ax, dna_y_ax, dna_color_ax]

# Figure frame: zero left margin, extra room on the right for the y axis.
dna_layout = Layout(height="99%", width="auto")
dna_fig_margin = dict(top=50, bottom=60, left=0, right=80)

# Missing cells are drawn grey (null_color).
dna_figure = Figure(
    title="Cases",
    axes=dna_axes,
    layout=dna_layout,
    fig_margin=dna_fig_margin,
    padding_y=0,
    min_aspect_ratio=0.0,
    null_color="#808080",
    animation_duration=1000,
)


In [None]:
def _dna_country_ts(countries, ma, n):
    """Fetch heat-map series for ``countries`` (display names), columns renamed to country names."""
    return (
        dict_datasets[0]
        .get_ts(
            [countries_to_codes[c] for c in countries],
            dna_data_button.value,
            dna_norm_button.value,
            "Linear",
            dna_type_button.value,
            ma,
            n,
        )
        .rename(columns=codes_to_countries)
    )


def _dna_state_ts(states, ma, n):
    """Fetch heat-map series for the given ``states``."""
    return dict_datasets[1].get_ts(
        states,
        dna_data_button.value,
        dna_norm_button.value,
        "Linear",
        dna_type_button.value,
        ma,
        n,
    )


def _dna_palette(data_name):
    """Colour ramp for the heat map given the selected data series name."""
    if data_name in ("Recovered", "Tests"):
        # White to green.
        return ["white", "#C5E8B7", "#ABE098", "#83D475", "#57C84D", "#2EB62C", "#207f1e"]
    if data_name == "Active Cases":
        # Green through yellow to dark red.
        return ["#069740", "#9acd32", "#e4d31b", "#ffa600", "#d31522", "#8c000e", "#560410"]
    # White through orange to dark red.
    return ["white", "#ffa600", "#ff6e00", "#ff4417", "#d31522", "#8c000e", "#560410"]


def update_dna(*args):
    """Observer: rebuild the heat map from the tab-3 selection and controls.

    ``args[0]`` is a traitlets change dict. Events owned by a ``ToggleButton``
    are handled only when it is switched on (``new == True``).

    Selections containing "None" count as empty. With nothing selected, the
    map shows United States and World. A single country is paired with United
    States (or World, if it is United States). A single state is paired with
    New York (or New Jersey, if it is New York). Countries and states together
    are outer-joined on their index.

    Side effects: updates the heat-map mark's colours and rows, the colour
    scale and axis, the figure title, and the window-size box visibility, then
    calls ``update_link_download``.
    """
    change = args[0]
    if type(change["owner"]).__name__ == "ToggleButton" and not (
        change["new"] == True
    ):
        return

    if "None" in dna_countries_selector.value:
        countries = []
    else:
        countries = list(dna_countries_selector.value)
    if "None" in dna_states_selector.value:
        states = []
    else:
        states = list(dna_states_selector.value)

    ma, n = tab_3_ma_ch.value, tab_3_ma_w.value
    tab_3_ma_w.layout.visibility = "visible" if ma else "hidden"

    # Presumably so the heat map always shows at least two rows: fall back to
    # defaults, or pad a single selection with a reference row.
    if not countries and not states:
        countries = ["United States", "World"]
    elif not states and len(countries) == 1:
        if "United States" in countries:
            countries = ["World", "United States"]
        else:
            countries = countries + ["United States"]
    elif not countries and len(states) == 1:
        if "New York" in states:
            states = ["New Jersey", "New York"]
        else:
            states = states + ["New York"]

    if countries and states:
        df_DNA = pd.merge(
            _dna_country_ts(countries, ma, n),
            _dna_state_ts(states, ma, n),
            left_index=True,
            right_index=True,
            how="outer",
        )
    elif countries:
        df_DNA = _dna_country_ts(countries, ma, n)
    else:
        df_DNA = _dna_state_ts(states, ma, n)

    # Reverse column order before the columns become heat-map rows.
    df_DNA = df_DNA.iloc[:, ::-1]
    title = dna_data_button.value
    if dna_norm_button.value == "Per million":
        title += "/1M population"
    if dna_type_button.value != "Total":
        title += " " + dna_type_button.value
    if ma:
        title += " ({0:d}days m.a)".format(n)
    if tab_3_norm.value:
        # Scale each column to 0-100 relative to its own maximum.
        df_DNA = 100 * df_DNA.divide(df_DNA.max(axis=0))
        title += ", (normalized)"

    # Absolute colour values are meaningless once columns are normalised.
    dna_color_ax.visible = not tab_3_norm.value

    dna_figure.marks[0].color = df_DNA.values.T
    dna_figure.marks[0].row = list(df_DNA.columns.values)
    dna_color_scale.colors = _dna_palette(dna_data_button.value)
    dna_figure.title = title
    update_link_download()
    return


# Redraw the heat map whenever a tab-3 control changes. The project's
# Toggle_Buttons register observers with add_observe; the plain ipywidgets
# (checkboxes and the window-size box) use observe.
for _dna_control in (dna_data_button, dna_norm_button, dna_type_button):
    _dna_control.add_observe(update_dna, "value")
for _dna_control in (tab_3_ma_ch, tab_3_norm, tab_3_ma_w):
    _dna_control.observe(update_dna, "value")
del _dna_control


In [None]:
# HTML template for the heat-map hover tooltip, filled by update_dna_tooltip:
# {0} date, {1} country/state name, {2} metric (the figure title),
# {3} value formatted with thousands separators.
dna_tooltip_table = """
<table id="stats_table">
<tr>
    <td>Date</td>
    <td>{0:}</td>
</tr>
<tr>
    <td>Country/State</td>
    <td>{1:s}</td>
</tr>
<tr>
    <td>{2:s}</td>
    <td>{3:,}</td>
</tr>
</table>
"""
# Widget showing the filled-in tooltip; update_dna_tooltip writes its value.
dna_stats_tootltip = HTML()


def update_dna_tooltip(*args):
    """Hover callback for the Tab 3 heat-map: render date, location and value into the tooltip.

    ``args[1]["data"]`` is the hover payload; its "column_num", "row" and "color"
    entries give the column index, the row label and the cell value.

    The value is rounded to one decimal when the figure title contains "pop"
    (e.g. the "/1M population" suffix) or "%", and truncated to int otherwise.
    Values that cannot be converted (None, NaN for int(), infinity) are shown as 0.
    """
    data = args[1]["data"]
    # NOTE(review): the last 19 characters of the column label string are dropped,
    # presumably the time part of a datetime64[ns] repr ("T00:00:00.000000000") — confirm.
    date = str(dna_figure.marks[0].column[int(data["column_num"])])[:-19]
    country = data["row"]
    val = data["color"]
    try:
        if "pop" in dna_figure.title or "%" in dna_figure.title:
            val = round(val, 1)
        else:
            val = int(val)
    except (TypeError, ValueError, OverflowError):
        # Only conversion failures are expected here; a bare except would also
        # swallow KeyboardInterrupt/SystemExit.
        val = 0
    dna_stats_tootltip.value = dna_tooltip_table.format(
        date, country, dna_figure.title, val
    )
    return


In [None]:
# Tab 3 (heat-map) layout: a 2x2 grid. The left column (both rows) holds the help
# button and the figure; the right column stacks the data buttons (140px row)
# above the country/state selectors.
grid_3 = GridspecLayout(2, 2)
grid_3.layout.overflow = "hidden"
grid_3.layout.height = "100%"
grid_3.layout.width = "100%"
help_tab3 = help_button(help_tooltip["Tab 3"])
# Left panel: a 32px column for the help button, the rest for the figure.
grid_3_left_panel = GridspecLayout(1, 2)
grid_3_left_panel[0, 0] = help_tab3
grid_3_left_panel[0, 1] = dna_figure
grid_3_left_panel.layout.grid_template_columns = "32px auto"
grid_3[:, 0] = grid_3_left_panel
grid_3[0, 1] = HBox([dna_top_buttons])
grid_3[1, 1] = HBox([tab_3_box_selectors])
grid_3.layout.align_items = "stretch"
# The right column is fixed at 590px; the figure column takes the remaining width.
grid_3.layout.grid_template_columns = "auto 590px"
grid_3.layout.grid_template_rows = "140px auto"


In [None]:
# Free graph

# Buttons and interactions


min_button_width_1 = "100px"


def _tab_4_toggle(options, value, description, min_button_width):
    """Build one Tab 4 (free graph) Toggle_Buttons row.

    ``description`` doubles as the key into ``buttons_tooltip``: entry 0 of that
    list is the description tooltip and the remaining entries are the per-option
    tooltips (same convention as the other tabs' buttons).
    """
    return Toggle_Buttons(
        options=list(options),
        value=value,
        description=description,
        min_button_width=min_button_width,
        min_description_width=min_description_width_1,
        tooltips=buttons_tooltip[description][1:],
        description_tooltip=buttons_tooltip[description][0],
    )


def _tab_4_tooltip_label(label, tooltip_key):
    """Return ``label`` wrapped in a tag whose hover title is ``buttons_tooltip[tooltip_key][0]``.

    The attribute value is quoted and HTML-escaped. An unquoted value ends at the
    first space, which cut multi-word tooltips down to their first word.
    """
    import html

    title = html.escape(buttons_tooltip[tooltip_key][0], quote=True)
    return '<tag title="' + title + '">' + label + "</tag>"


_TAB_4_DATA_OPTIONS = ("Cases", "Deaths", "Recovered", "Active Cases", "Tests")
_TAB_4_NORM_OPTIONS = ("Values", "Per million")
_TAB_4_SCALE_OPTIONS = ("Linear", "Log")
_TAB_4_TYPE_OPTIONS = ("Total", "Daily change", "Daily % change")

# Each x_* / y_* pair configures the matching axis of the free scatter graph.
x_data_button = _tab_4_toggle(_TAB_4_DATA_OPTIONS, "Cases", "Data", min_button_width_1)
y_data_button = _tab_4_toggle(_TAB_4_DATA_OPTIONS, "Cases", "Data", min_button_width_1)

x_norm_button = _tab_4_toggle(
    _TAB_4_NORM_OPTIONS, "Values", "Norm", min_button_width_norm
)
y_norm_button = _tab_4_toggle(
    _TAB_4_NORM_OPTIONS, "Values", "Norm", min_button_width_norm
)

x_scale_button = _tab_4_toggle(
    _TAB_4_SCALE_OPTIONS, "Log", "Scale", min_button_width_scale
)
y_scale_button = _tab_4_toggle(
    _TAB_4_SCALE_OPTIONS, "Log", "Scale", min_button_width_scale
)

x_type_button = _tab_4_toggle(_TAB_4_TYPE_OPTIONS, "Total", "Type", "120px")
y_type_button = _tab_4_toggle(_TAB_4_TYPE_OPTIONS, "Daily change", "Type", "120px")

# When checked, only the last data point of each location is drawn.
free_checkbox = Checkbox(
    value=False,
    description=_tab_4_tooltip_label("Last data point", "Last data point"),
    disabled=False,
    style={"description_width": "initial"},
    layout=Layout(width="auto", height="auto"),
)

tab_4_ma_ch = Checkbox(
    description=_tab_4_tooltip_label("Moving Average", "Moving average"),
    value=True,
    style={"description_width": "initial"},
    layout=Layout(min_width="200px", width="200px", overflow="hidden"),
)
# Window size input; starts hidden, and update_free_scatter_fig toggles its
# visibility with tab_4_ma_ch.
tab_4_ma_w = BoundedIntText(
    value=7,
    min=1,
    max=14,
    description="Window size (days)",
    style={"description_width": "initial"},
    layout=Layout(
        min_width="195px", width="195px", overflow="auto", visibility="hidden"
    ),
)
ma_box_tab_4 = HBox(
    [tab_4_ma_ch, tab_4_ma_w],
    layout=Layout(
        width="100%",
        overflow="auto",
        align_items="stretch",
        justify_content="space-between",
    ),
)


In [None]:
# Stacked button rows (data, norm, type, scale) for each axis of the free graph.
# Both panels get their own fixed-size Layout.
x_data, y_data = (
    VBox(
        list(axis_buttons),
        layout=Layout(
            width="565px",
            min_width="565px",
            max_width="565px",
            height="138px",
            min_height="138px",
            max_height="138px",
            overflow="auto",
        ),
    )
    for axis_buttons in (
        (x_data_button, x_norm_button, x_type_button, x_scale_button),
        (y_data_button, y_norm_button, y_type_button, y_scale_button),
    )
)


In [None]:
# Each axis panel gets an orange "X DATA :" / "Y DATA :" heading inside an orange
# border. Only the X panel carries a bottom margin, which separates the two boxes.
x_data_buttons, y_data_buttons = (
    HBox(
        [
            HTML(
                "<p style='text-align:left; font-size:18px; font-weight:bold; padding: 0; margin: 0; color: #ff8b0e; margin-right: 15px'>"
                + axis_name
                + " DATA : </p>"
            ),
            axis_panel,
        ],
        layout=Layout(
            width="670px",
            height="145px",
            align_items="center",
            border="solid #ff8b0e",
            **extra_layout
        ),
    )
    for axis_name, axis_panel, extra_layout in (
        ("X", x_data, {"margin": "0 0 5px 0"}),
        ("Y", y_data, {}),
    )
)
cat_tab_4_buttons = VBox([x_data_buttons, y_data_buttons])


In [None]:
# Multi-select lists for the free graph; "None" is prepended and preselected, so
# initially nothing is explicitly chosen.
free_countries_selector, free_states_selector = (
    SelectMultiple(
        options=["None"] + choices,
        value=["None"],
        layout=Layout(width="auto", height="100%", min_height="50px"),
    )
    for choices in (top_countries_names, states)
)


In [None]:
# Tab 4 accordion: predefined country selections and predefined scenarios,
# each laid out four buttons per row.
ps_tab4 = preselection(presaved_sel_options, "Custom", on_click_action)
hb_tab4 = VBox(
    [
        HBox([ps_tab4[option] for option in presaved_sel_options[row : row + 4]])
        for row in range(0, len(presaved_sel_options), 4)
    ]
)
ps_tab4_1 = sc_buttons(
    ["Scenario {}".format(k) for k in range(1, 5)], on_click_scenario
)
hb_tab4_1 = VBox(
    [
        HBox([ps_tab4_1[key] for key in list(ps_tab4_1)[row : row + 4]])
        for row in range(0, len(list(ps_tab4_1)), 4)
    ]
)
accordion_tab4 = Accordion(
    children=[hb_tab4, hb_tab4_1],
    selected_index=None,
    layout=Layout(width="auto", height="auto"),
)
accordion_tab4.set_title(0, "Predefined selections")
accordion_tab4.set_title(1, "Predefined scenarios")
accordion_tab4.observe(accordion_adapt_height, "selected_index")


In [None]:
# Tab 4 selector area (2 rows x 5 columns). The predefined selections/scenarios
# accordion spans row 0, which is sized at 80px; row 1 takes the remaining height.
# NOTE(review): row 1 is not filled in this cell — presumably populated elsewhere; confirm.
tab_4_box_selectors = GridspecLayout(
    2,
    5,
    layout=Layout(
        width="670px", align_items="stretch", min_height="60px", overflow="auto"
    ),
)
tab_4_box_selectors[0, :] = accordion_tab4
tab_4_box_selectors.layout.grid_template_rows = "80px auto"

In [None]:
# Centered caption above each free-graph selector.
countries_selector_box_free, states_selector_box_free = (
    VBox(
        [HTML('<p style="text-align:center">' + caption + "</p>"), selector],
        layout=Layout(width="100%", height="100%"),
    )
    for caption, selector in (
        ("Countries", free_countries_selector),
        ("US States", free_states_selector),
    )
)

In [None]:
# Initial (linear) scales, axes and pan/zoom for the free graph. free_scatter later
# swaps in fresh scales according to the Scale buttons.
free_scatter_sc_x, free_scatter_sc_y = LinearScale(), LinearScale()
free_scatter_ax_x = Axis(scale=free_scatter_sc_x, tick_format=",")
free_scatter_ax_y = Axis(
    scale=free_scatter_sc_y, tick_format=",", orientation="vertical"
)
free_pz = PanZoom(scales=dict(x=[free_scatter_sc_x], y=[free_scatter_sc_y]))
free_scatter_fig = Figure(
    axes=[free_scatter_ax_x, free_scatter_ax_y],
    interaction=free_pz,
    layout=Layout(width="auto", height="99%"),
    fig_margin=dict(top=50, left=60, right=100, bottom=50),
)


In [None]:
def last_index_nan(x_li, y_li):
    """Return the largest index i where neither x_li[i] nor y_li[i] is NaN, or -1 if none.

    Despite the name, this finds the last *valid* (non-NaN) pair. Callers use the
    result to place a label at the end of each plotted series.
    """
    for idx in range(len(x_li) - 1, -1, -1):
        if not (np.isnan(x_li[idx]) or np.isnan(y_li[idx])):
            return idx
    return -1


def clean_line_plot(xs, ys):
    """Blank out, in place, every point where either coordinate is NaN.

    ``xs`` and ``ys`` are parallel sequences of series. Whenever ``xs[i][j]`` or
    ``ys[i][j]`` is NaN, both are set to NaN, so each point is either fully present
    or fully missing. The same (mutated) objects are returned.
    """
    for i, x_series in enumerate(xs):
        y_series = ys[i]
        for j in range(len(x_series)):
            if np.isnan(x_series[j]) or np.isnan(y_series[j]):
                x_series[j] = y_series[j] = np.nan
    return xs, ys


def free_scatter(
    selected_countries, datas, norms, scales, data_types, ma, n, last_point=True
):
    """Redraw the Tab 4 "free" figure: y = datas[1] plotted against x = datas[0].

    Parameters
    ----------
    selected_countries : [country codes, US state names]; entry i is read from
        ``dict_datasets[i]``.
    datas, norms, scales, data_types : two-element sequences holding the [x, y]
        settings (data kind, "Values"/"Per million", "Linear"/"Log", "Total"/...).
    ma, n : moving-average flag and window, forwarded to the dataset getters.
    last_point : if True, draw one Scatter point per location at the dataset's
        ENDT. Locations whose x or y is NaN are skipped. Otherwise draw the full
        time series as Lines and store them in the globals ``free_mark_x`` /
        ``free_mark_y`` for the play animation.

    When the figure already has marks, colors previously given to a location are
    kept via ``persistent_colors``.
    """
    global free_mark_x, free_mark_y
    Ys = []
    Xs = []
    label_names = []
    colors = []
    all_colors = [
        "#1f77b4",
        "#ff7f0e",
        "#2ca02c",
        "#d62728",
        "#9467bd",
        "#8c564b",
        "#e377c2",
        "#7f7f7f",
        "#bcbd22",
        "#17becf",
    ]
    if len(free_scatter_fig.marks) > 0:
        update = True
        new_labels = [codes_to_countries[c] for c in selected_countries[0]] + list(
            selected_countries[1]
        )
        color_dict = persistent_colors(
            list(free_scatter_fig.marks[0].colors),
            list(free_scatter_fig.marks[1].text),
            new_labels,
            all_colors,
        )
    else:
        update = False
        colors = all_colors.copy()

    # Fresh scales on every redraw. Anything other than "Log" falls back to linear
    # (previously an unexpected value left the scale unbound).
    sc_x = LogScale() if scales[0] == "Log" else LinearScale()
    sc_y = LogScale() if scales[1] == "Log" else LinearScale()

    for i, locations in enumerate(selected_countries):
        dataset = dict_datasets[i]
        for c in locations:
            if last_point:
                x = dataset.get_value(
                    c,
                    datas[0],
                    norms[0],
                    scales[0],
                    data_types[0],
                    ma,
                    n,
                    dataset.ENDT,
                )
                y = dataset.get_value(
                    c,
                    datas[1],
                    norms[1],
                    scales[1],
                    data_types[1],
                    ma,
                    n,
                    dataset.ENDT,
                )
                if np.isnan(x) or np.isnan(y):
                    # Not drawn: skip its label AND its color. Appending the color
                    # anyway would shift every following point's color by one.
                    continue
                Xs.append(x)
                Ys.append(y)
            else:
                Xs.append(
                    list(
                        dataset.get_ts(
                            c, datas[0], norms[0], scales[0], data_types[0], ma, n
                        ).values
                    )
                )
                Ys.append(
                    list(
                        dataset.get_ts(
                            c, datas[1], norms[1], scales[1], data_types[1], ma, n
                        ).values
                    )
                )
            label = codes_to_countries[c] if dataset.name == "World" else c
            label_names.append(label)
            if update:
                colors.append(color_dict[label])

    free_scatter_ax_x.scale = sc_x
    free_scatter_ax_y.scale = sc_y
    free_scatter_fig.axes = [free_scatter_ax_x, free_scatter_ax_y]

    free_labels = Label(
        apply_clip=False,
        scales={"x": sc_x, "y": sc_y},
        default_size=15,
        x_offset=4,
        colors=colors,
    )
    free_labels.text = label_names
    if last_point:
        free_labels.x = np.array(Xs)
        free_labels.y = np.array(Ys)

        free_mark = Scatter(
            marker="circle",
            stroke="black",
            scales={"x": sc_x, "y": sc_y},
            colors=colors,
            marker_size=30,
        )
        free_mark.x = np.array(Xs)
        free_mark.y = np.array(Ys)

    else:
        Xs, Ys = clean_line_plot(Xs, Ys)
        # Labels sit on the last point where both coordinates are defined.
        last = [last_index_nan(x_li, y_li) for x_li, y_li in zip(Xs, Ys)]
        free_labels.x = np.array([x_li[k] for x_li, k in zip(Xs, last)])
        free_labels.y = np.array([y_li[k] for y_li, k in zip(Ys, last)])

        free_mark = Lines(
            marker="circle",
            stroke="black",
            scales={"x": sc_x, "y": sc_y},
            marker_size=30,
            colors=colors,
        )
        free_mark.x = np.array(Xs)
        free_mark.y = np.array(Ys)
        # Kept for update_free_play, which replays these series day by day.
        free_mark_x = Xs
        free_mark_y = Ys

    free_scatter_fig.marks = [free_mark, free_labels]

    def title_free_scatter(data, norm, scale, data_type, ma, n):
        """Describe one axis, e.g. "Log Cases/1M population daily change (7days m.a)"."""
        # No leading space for linear scales (previously " Cases").
        res = "Log " + data if scale == "Log" else data
        if norm == "Per million":
            res += "/1M population"
        if data_type != "Total":
            res += " " + data_type.lower()
        if ma:
            res += " ({0:d}days m.a)".format(n)
        return res

    # The title must end with ", " + ENDT: update_free_play replaces that trailing date.
    free_scatter_fig.title = (
        title_free_scatter(datas[1], norms[1], scales[1], data_types[1], False, None)
        + " vs "
        + title_free_scatter(datas[0], norms[0], scales[0], data_types[0], ma, n)
        + ", "
        + ENDT
    )
    free_pz.scales = {"x": [sc_x], "y": [sc_y]}
    return


In [None]:
def update_free_scatter_fig(*args):
    """Observer callback: rebuild the Tab 4 figure from the current widget state.

    Events from a ToggleButton are only acted on when the button becomes selected
    (``new == True``); change events from any other widget always trigger a redraw.
    With no country or state selected, "USA" is shown.
    """
    change = args[0]
    from_toggle = type(change["owner"]).__name__ == "ToggleButton"
    if from_toggle and not change["new"] == True:
        return

    countries = (
        []
        if "None" in free_countries_selector.value
        else [countries_to_codes[name] for name in free_countries_selector.value]
    )
    states = [] if "None" in free_states_selector.value else free_states_selector.value
    if not countries and not states:
        countries = ["USA"]

    ma, n = tab_4_ma_ch.value, tab_4_ma_w.value
    # The window-size box is only relevant while the moving average is on.
    tab_4_ma_w.layout.visibility = "visible" if ma else "hidden"

    axis_settings = dict(
        datas=[x_data_button.value, y_data_button.value],
        norms=[x_norm_button.value, y_norm_button.value],
        scales=[x_scale_button.value, y_scale_button.value],
        data_types=[x_type_button.value, y_type_button.value],
    )
    last_point = free_checkbox.value
    free_scatter(
        [countries, states],
        axis_settings["datas"],
        axis_settings["norms"],
        axis_settings["scales"],
        axis_settings["data_types"],
        ma,
        n,
        last_point,
    )
    # Animation only makes sense for full time series.
    free_play.layout.visibility = "hidden" if last_point else "visible"

    update_link_download()
    return


for _free_toggle_widget in (
    x_data_button,
    y_data_button,
    x_norm_button,
    y_norm_button,
    x_scale_button,
    y_scale_button,
    x_type_button,
    y_type_button,
):
    _free_toggle_widget.add_observe(update_free_scatter_fig, "value")
for _free_widget in (free_checkbox, tab_4_ma_ch, tab_4_ma_w):
    _free_widget.observe(update_free_scatter_fig, "value")
del _free_toggle_widget, _free_widget


In [None]:
free_play = Play(
    value=0,
    min=0,
    max=1,
    step=1,
    description="Press play",
    disabled=False,
    layout=Layout(width="auto"),
)


def update_free_play(*args):
    """Play-widget callback: truncate the Tab 4 line plot to the first ``index + 1`` days.

    Reads the full series stored in the globals ``free_mark_x`` / ``free_mark_y``
    (written by free_scatter in line mode). It moves the labels to the last point
    where both coordinates are defined, then replaces the date at the end of the
    title with ``date_selector.options[index]``.
    """
    index = args[0]["new"]
    date = date_selector.options[index]
    Xs = np.array(free_mark_x)[:, : index + 1]
    Ys = np.array(free_mark_y)[:, : index + 1]
    Xs, Ys = clean_line_plot(list(Xs), list(Ys))
    last = [last_index_nan(x_li, y_li) for x_li, y_li in zip(Xs, Ys)]
    free_scatter_fig.marks[1].x = np.array([x_li[k] for x_li, k in zip(Xs, last)])
    free_scatter_fig.marks[1].y = np.array([y_li[k] for y_li, k in zip(Ys, last)])
    free_scatter_fig.marks[0].x = np.array(Xs)
    free_scatter_fig.marks[0].y = np.array(Ys)
    # free_scatter ends the title with ", <date>". Replace only the text after the
    # last ", " instead of assuming the date string is exactly 10 characters.
    head, sep, _ = free_scatter_fig.title.rpartition(", ")
    prefix = head + sep if sep else free_scatter_fig.title[:-10]
    free_scatter_fig.title = prefix + str(date)
    return


help_tab4 = help_button(help_tooltip["Tab 4"])
tab_4_top_right_buttons = VBox(
    [
        HBox(
            [free_checkbox, free_play],
            layout=Layout(justify_content="space-between", align_items="stretch"),
        ),
    ],
    layout=Layout(height="auto", width="auto"),
)


In [None]:
# Tab 4 (free graph) layout: the help button and scatter figure sit on the left,
# spanning both rows. On the right, the x/y axis controls (360px row) sit above
# the predefined selections/scenarios box.
grid_4 = GridspecLayout(2, 2)
grid_4.layout.overflow = "hidden"
grid_4.layout.width = "100%"
grid_4.layout.height = "100%"
tab_4_top_right_buttons.layout.margin = "0px 3px 0px 5px"
# Left panel: a 32px column for the help button, the rest for the figure.
grid_4_left_panel = GridspecLayout(1, 2)
grid_4_left_panel[0, 0] = help_tab4
grid_4_left_panel[0, 1] = free_scatter_fig
grid_4_left_panel.layout.grid_template_columns = "32px auto"
grid_4_left_panel.layout.height = "99%"
grid_4_left_panel.layout.overflow = "hidden"
grid_4[:, 0] = grid_4_left_panel
# Right column, top row: checkbox/play, x/y data buttons and moving-average box,
# in a fixed 360px box that scrolls on overflow.
grid_4[0, 1] = Box(
    [
        VBox(
            [tab_4_top_right_buttons, cat_tab_4_buttons, ma_box_tab_4],
            layout=Layout(
                height="360px", min_height="360px", max_height="360px", overflow="auto"
            ),
        )
    ]
)
grid_4[1, 1] = HBox([tab_4_box_selectors])
grid_4.layout.align_items = "stretch"
grid_4.layout.grid_template_columns = "auto 690px"
grid_4.layout.grid_template_rows = "360px auto"


In [None]:
# Widgets of each tab paired with the callback to run for them. The parallel lists
# tab_N_buttons / tab_N_buttons_actions are derived from the pairs, so a widget
# and its callback cannot drift out of alignment.
_tab_2_bindings = [
    (rebased_graph_data_button, update_rebased_graph),
    (rebased_graph_norm_button, update_rebased_graph),
    (rebased_graph_type_button, update_rebased_graph),
    (plot_scale_button, update_rebased_graph),
    (thr_val_slider, update_rebased_graph),
    (rebased_graph_thr_data_button, update_rebased_graph),
    (rebased_graph_thr_norm_button, update_rebased_graph),
    (calendar_time, update_rebased_graph),
    # The date-axis range only changes the zoom, not the data.
    (date_axis_selector_tab_2, update_zoom_tab_2),
    (tab_2_ma_ch, update_rebased_graph),
    (tab_2_ma_w, update_rebased_graph),
    (countries_selector, update_rebased_graph),
    (states_selector, update_rebased_graph),
]
tab_2_buttons = [widget for widget, _ in _tab_2_bindings]
tab_2_buttons_actions = [action for _, action in _tab_2_bindings]
del _tab_2_bindings

tab_3_buttons = [
    dna_data_button,
    dna_norm_button,
    dna_type_button,
    tab_3_ma_ch,
    tab_3_ma_w,
    tab_3_norm,
    dna_countries_selector,
    dna_states_selector,
]
tab_3_buttons_actions = [update_dna for _ in tab_3_buttons]

tab_4_buttons = [
    x_data_button,
    x_norm_button,
    x_scale_button,
    x_type_button,
    y_data_button,
    y_norm_button,
    y_scale_button,
    y_type_button,
    free_checkbox,
    tab_4_ma_ch,
    tab_4_ma_w,
    free_countries_selector,
    free_states_selector,
]
tab_4_buttons_actions = [update_free_scatter_fig for _ in tab_4_buttons]

# Keyed by tab index: 1 -> Tab 2, 2 -> Tab 3, 3 -> Tab 4.
dict_tab_buttons = {1: tab_2_buttons, 2: tab_3_buttons, 3: tab_4_buttons}
dict_tab_buttons_actions = {
    1: tab_2_buttons_actions,
    2: tab_3_buttons_actions,
    3: tab_4_buttons_actions,
}


In [None]:
def map_pick(map_mark, country_id, selector, dataset, dict_names=None):
    """Toggle ``country_id`` in the map selection and mirror the result in ``selector``.

    ``map_mark.color`` maps each selected id to "X". The sentinel ``{-10: ""}``
    stands for "nothing selected". After toggling, ``selector.value`` is set to
    the names of the selected ids (``dataset.get_name``), translated through
    ``dict_names`` when given, or to ``["None"]`` when the selection is empty.

    Failures (e.g. an id ``dataset.get_name`` or ``dict_names`` cannot resolve)
    are ignored so a bad click cannot break the widget callback. In that case the
    map color may already have been updated.
    """
    try:
        keys = list(map_mark.color.keys())
        if country_id in keys:
            keys.remove(country_id)
            res = dict.fromkeys(keys, "X")
            if len(keys) == 0:
                res = {-10: ""}
        else:
            keys.append(country_id)
            if -10 in keys:
                # Leaving the empty-selection state: drop the sentinel.
                keys.remove(-10)
            res = dict.fromkeys(keys, "X")
        map_mark.color = res
        if len(keys) == 0 or -10 in keys:
            selector.value = ["None"]
        elif dict_names:
            selector.value = [dict_names[dataset.get_name(c)] for c in keys]
        else:
            selector.value = [dataset.get_name(c) for c in keys]
    except Exception:
        # Deliberate best effort, but no longer swallows KeyboardInterrupt/SystemExit.
        pass


def sync_selector_map(selector, map_mark, dataset, dict_names=None):
    """Recolor ``map_mark`` so it highlights exactly the entries selected in ``selector``.

    Each selected name is translated through ``dict_names`` (when given) and looked
    up with ``dataset.get_ID``; the resolved ids are mapped to "X". If "None" is
    among the selected values, the map is set to the empty-selection sentinel
    ``{-10: ""}``. Names that fail to resolve are skipped.
    """
    res = {}
    for v in selector.value:
        if v == "None":
            res = {-10: ""}
            break
        try:
            if dict_names:
                country_id = dataset.get_ID(dict_names[v])
            else:
                country_id = dataset.get_ID(v)
            res[country_id] = "X"
        except Exception:
            # Names missing from dict_names or the dataset are simply not highlighted.
            continue
    map_mark.color = res
    return


class CountryPicker:
    """Country / US-state picker offering a list view and two clickable maps.

    The world map (``self.wm``) and the US-states map (``self.wm_US``) are bqplot
    ``Map`` marks kept in sync with ``countries_selector`` / ``states_selector``.
    Clicking a region toggles it in the matching selector (via ``map_pick``), and
    changing a selector recolors the map (via ``sync_selector_map``). The vertical
    ``pick_buttons`` ("List Picker", "World Map", "US Map") swap what is shown in
    ``grid_picker`` and rearrange ``outer_grid``.
    """

    def country_map_pick_update(self, obj, val):
        """World-map click handler: toggle the clicked country if it has a "Cases" column."""
        try:
            country_id = val["data"]["id"]
            if self.country_dataset.get_name(
                country_id
            ) in self.country_dataset.get_columns("Cases"):
                map_pick(
                    self.wm,
                    country_id,
                    self.countries_selector,
                    self.country_dataset,
                    self.codes_to_countries,
                )
        except Exception:
            # Best effort: clicks that cannot be resolved to a known country are ignored.
            pass

    def states_map_pick_update(self, obj, val):
        """US-map click handler: toggle the clicked state in ``states_selector``."""
        try:
            states_id = val["data"]["id"]
            map_pick(self.wm_US, states_id, self.states_selector, self.states_dataset)
        except Exception:
            pass

    def wld_hover(self, obj, val):
        """World-map hover handler: show the country's display name, or "" if unknown."""
        try:
            country_id = val["data"]["id"]
            country_name = self.codes_to_countries[
                self.country_dataset.get_name(country_id)
            ]
            self.tooltip.value = country_name
        except Exception:
            self.tooltip.value = ""

    def us_states_hover(self, obj, val):
        """US-map hover handler: show the state's name, or "" if unknown."""
        try:
            state_id = val["data"]["id"]
            state_name = self.states_dataset.get_name(state_id)
            self.tooltip.value = state_name
        except Exception:
            self.tooltip.value = ""

    def button_update(self, val):
        """React to ``pick_buttons``: show the list picker or the chosen map.

        Only acts when a toggle becomes selected (``val["new"] == True``).
        """
        if val["new"] == True:
            descr = val["owner"].description
            box = self.dict_button_map[descr]
            if descr == "List Picker":
                # Back to the list view: restore the left panel and the selector row.
                self.collapse_button.layout.visibility = "hidden"
                self.grid_picker[0, 1] = box
                self.outer_grid[:, 0] = self.left_panel
                self.outer_grid[1, 1] = self.tab_box_selector
                self.box_grid.layout.grid_template_columns = "45% 5% 45%"
                self.outer_grid.layout.grid_template_rows = self.grid_template_rows
                # NOTE(review): hide_button is a module-level name defined outside
                # this class (presumably a toggle widget) — confirm.
                hide_button.value = False
            else:
                # Map view: show the chosen Map mark, and let the selector area span
                # the whole bottom row. With hide=True the top-left cell is blanked
                # with an empty HTML widget.
                self.collapse_button.layout.visibility = "visible"
                self.map_picker.marks = [box]
                self.grid_picker[0, 1] = self.map_picker
                if self.hide:
                    self.outer_grid[0, 0] = HTML()
                    row_template = self.grid_template_rows
                else:
                    self.outer_grid[0, 0] = self.left_panel
                    row_template = "30% 70%"
                self.outer_grid[1, :] = self.tab_box_selector
                self.outer_grid.layout.grid_template_rows = row_template
                self.tab_box_selector.layout.width = "100%"

            self.outer_grid.layout.grid_template_columns = self.grid_template_columns
            self.grid_picker.layout.grid_template_columns = "110px auto"

    def __init__(
        self,
        countries_selector,
        states_selector,
        country_dataset,
        states_dataset,
        codes_to_countries,
        countries_to_codes,
        outer_grid,
        left_panel,
        grid_template_columns,
        grid_template_rows,
        tab_box_selector,
        hide=True,
    ):
        """Build both maps, wire click/hover/selector callbacks and the picker grid.

        ``outer_grid``, ``left_panel``, ``grid_template_columns``,
        ``grid_template_rows`` and ``tab_box_selector`` describe the tab layout that
        ``button_update`` rearranges. ``hide`` controls whether the left panel is
        blanked while a map is shown.
        """
        self.countries_selector = countries_selector
        self.states_selector = states_selector
        self.country_dataset = country_dataset
        self.states_dataset = states_dataset
        self.codes_to_countries = codes_to_countries
        self.countries_to_codes = countries_to_codes
        self.outer_grid = outer_grid
        self.left_panel = left_panel
        self.grid_template_columns = grid_template_columns
        self.grid_template_rows = grid_template_rows
        self.tab_box_selector = tab_box_selector
        self.hide = hide
        sc_geo = Mercator(center=(-20, 70), scale_factor=150)
        # Selected regions carry the value "X", which this scale paints red.
        sc_c1 = OrdinalColorScale(domain=["X"], colors=["red"],)
        self.tooltip = HTML()
        caxis = ColorAxis(scale=sc_c1, visible=False)

        self.wm = Map(
            map_data=topo_load("map_data/WorldMap.json"),
            scales={"projection": sc_geo, "color": sc_c1},
            color={},
            tooltip=self.tooltip,
            colors={"default_color": "Grey"},
            hovered_styles={"hovered_fill": "White"},
        )

        sc_geo_US = AlbersUSA()
        sc_c1_US = OrdinalColorScale(domain=["X"], colors=["red"],)
        caxis_US = ColorAxis(scale=sc_c1_US, visible=False)
        self.wm_US = Map(
            map_data=topo_load("map_data/USStatesMap.json"),
            scales={"projection": sc_geo_US, "color": sc_c1_US},
            color={},
            tooltip=self.tooltip,
            colors={"default_color": "Grey"},
            hovered_styles={"hovered_fill": "White"},
        )
        self.map_picker = Figure(
            marks=[self.wm],
            axes=[caxis],
            layout=Layout(width="99%", height="auto"),
            fig_margin={"top": 5, "bottom": 5, "left": 0, "right": 0},
        )
        self.wm.on_element_click(self.country_map_pick_update)
        self.wm_US.on_element_click(self.states_map_pick_update)
        self.wm.on_hover(self.wld_hover)
        self.wm_US.on_hover(self.us_states_hover)
        # Keep the maps' highlighted regions in sync with the list selectors.
        self.countries_selector.observe(
            lambda *args: sync_selector_map(
                self.countries_selector,
                self.wm,
                self.country_dataset,
                self.countries_to_codes,
            ),
            "value",
        )
        self.states_selector.observe(
            lambda *args: sync_selector_map(
                self.states_selector, self.wm_US, self.states_dataset
            ),
            "value",
        )
        countries_box = VBox(
            [HTML('<p style="text-align:center">Countries</p>'), countries_selector],
            layout=Layout(width="100%"),
        )
        states_box = VBox(
            [HTML('<p style="text-align:center">US States</p>'), states_selector],
            layout=Layout(width="100%"),
        )
        # List view: countries on the left, states on the right, a gap column between.
        box_grid = GridspecLayout(1, 3)
        box_grid.layout = Layout(width="100%")
        box_grid[0, 0] = countries_box
        box_grid[0, 2] = states_box

        self.box_grid = box_grid

        self.pick_buttons = Toggle_Buttons(
            options=["List Picker", "World Map", "US Map"],
            value="List Picker",
            description="",
            min_button_width="100px",
            min_description_width="0px",
            horizontal=False,
            button_width="auto",
            style="warning",
            tooltips=["List Picker", "World map picker", "US States map picker"],
            description_tooltip=""
        )
        self.collapse_button = Button(
            description="Collapse",
            button_style="danger",
            layout=Layout(width="auto", visibility="hidden", margin="20px 0px 0px 0px"),
        )
        # Picker grid: a 110px column of mode buttons, then the list or map view.
        grid_picker = GridspecLayout(1, 2)
        grid_picker.layout = Layout(width="100%", height="100%")
        grid_picker[0, 0] = VBox([self.pick_buttons, self.collapse_button])
        grid_picker[0, 1] = self.box_grid
        grid_picker.layout.grid_template_columns = "110px auto"
        self.grid_picker = grid_picker
        self.box_grid.layout.grid_template_columns = "45% 5% 45%"
        # Content shown for each pick button (used by button_update).
        self.dict_button_map = {
            "List Picker": self.box_grid,
            "World Map": self.wm,
            "US Map": self.wm_US,
        }
        self.pick_buttons.add_observe(self.button_update, "value")
        # "Collapse" returns to the list view.
        self.collapse_button.on_click(
            lambda x: self.pick_buttons.set_value("List Picker")
        )


In [None]:
def _country_names_for_codes(codes, valid_codes, exclude=()):
    """Translate country codes into display names, preserving input order.

    A code is kept only if it has an entry in ``codes_to_countries``, is a
    member of ``valid_codes`` and is not listed in ``exclude``.
    """
    return [
        codes_to_countries[code]
        for code in codes
        if code in codes_to_countries and code in valid_codes and code not in exclude
    ]


def _top10_codes(series):
    """Index labels ranked 2nd to 11th by value (largest first).

    The top-ranked entry is skipped; presumably it is the "World" aggregate
    (TODO confirm against the data processor).
    """
    return series.sort_values(ascending=False)[1:11].index.values


def _build_dict_scenario():
    """Return the preset scenario configurations used by the dashboard.

    Top-level keys (1-3) presumably match ``outer_tab`` indices (Rebased
    Graph, Heat Map, Custom Graph) -- TODO confirm. Each scenario is a
    positional list whose field layout differs per key and must match the
    code that consumes ``dict_scenario`` elsewhere in the notebook.
    Reads the globals ``ENDT``, ``presaved_sel_lists`` and
    ``dna_states_selector`` at call time.
    """
    end_day = np.datetime64(ENDT)
    return {
        1: {
            "Scenario 1": [
                "Cases", "Per million", "Daily change", "Linear",
                None, None, None,
                True,
                # Last 30 days (explicit day unit).
                [end_day - np.timedelta64(30, "D"), end_day],
                True, 7,
                presaved_sel_lists["Top10 (Cases)"],
                ["None"],
            ],
            "Scenario 2": [
                "Cases", "Values", "Total", "Log",
                None, None, None,
                True,
                [np.datetime64("2020-05-25"), np.datetime64(ENDT)],
                False, None,
                ["None"],
                [
                    "California", "Texas", "Florida", "Georgia",
                    "Arizona", "North Carolina", "Louisiana", "Alabama",
                ],
            ],
            "Scenario 3": [
                "Deaths", "Values", "Total", "Linear",
                1000, "Deaths", "Values",
                False, None,
                False, None,
                presaved_sel_lists["Top10 (Deaths)"],
                ["None"],
            ],
            "Scenario 4": [
                "Cases", "Values", "Daily change", "Linear",
                1000, "Cases", "Values",
                False, None,
                True, 7,
                [
                    "Japan", "Germany", "United Kingdom", "France", "Italy",
                    "South Korea", "Spain", "Belgium", "Netherlands",
                ],
                ["None"],
            ],
        },
        2: {
            "Scenario 3": [
                "Deaths", "Per million", "Daily change",
                True, 7, False,
                [
                    "United States", "China", "Japan", "Germany", "India",
                    "United Kingdom", "France", "Italy", "Brazil", "Canada",
                    "Russia", "South Korea", "Spain", "Australia", "Mexico",
                    "Indonesia", "Netherlands", "Saudi Arabia", "Turkey",
                    "Switzerland",
                ],
                ["None"],
            ],
            "Scenario 2": [
                "Cases", "Per million", "Daily change",
                True, 7, False,
                ["None"],
                dna_states_selector.options[1:],
            ],
            "Scenario 1": [
                "Cases", "Per million", "Daily change",
                True, 7, False,
                [
                    "United States", "China", "Hong Kong", "Japan", "Germany",
                    "India", "United Kingdom", "France", "Italy", "Brazil",
                    "Canada", "Russia", "South Korea", "Spain", "Australia",
                    "Mexico", "Indonesia", "Netherlands", "Saudi Arabia",
                    "Turkey", "Switzerland",
                ],
                ["None"],
            ],
            "Scenario 4": [
                "Cases", "Values", "Daily change",
                True, 7, True,
                [
                    "China", "Hong Kong", "South Korea", "Iran", "Australia",
                    "Italy", "Germany", "Spain", "United Kingdom", "France",
                    "Turkey", "Japan", "Canada", "Russia", "Saudi Arabia",
                    "United States", "South Africa", "Brazil", "Colombia",
                    "India", "World",
                ],
                ["None"],
            ],
        },
        3: {
            "Scenario 1": [
                "Cases", "Per million", "Log", "Total",
                "Cases", "Per million", "Log", "Daily change",
                False, True, 7,
                [
                    "United States", "Brazil", "France", "Germany", "Russia",
                    "China", "South Korea", "Japan", "Hong Kong",
                    "United Kingdom", "Spain", "Italy",
                ],
                ["None"],
            ],
            "Scenario 2": [
                "Tests", "Per million", "Linear", "Total",
                "Cases", "Per million", "Linear", "Total",
                False, True, 7,
                [
                    "World", "United States", "Japan", "Germany", "India",
                    "United Kingdom", "France", "Italy", "Brazil", "Canada",
                    "Russia", "South Africa", "Iran", "Spain",
                ],
                ["None"],
            ],
            "Scenario 3": [
                "Cases", "Per million", "Linear", "Total",
                "Deaths", "Per million", "Linear", "Total",
                False, False, None,
                [
                    "World", "United States", "Japan", "Germany", "India",
                    "United Kingdom", "France", "Italy", "Brazil", "Canada",
                    "Russia", "South Africa", "Iran", "Spain",
                ],
                ["None"],
            ],
            "Scenario 4": [
                "Cases", "Per million", "Linear", "Total",
                "Deaths", "Per million", "Linear", "Total",
                True, False, None,
                [
                    "Austria", "Azerbaijan", "Belarus", "Belgium", "Bosnia",
                    "Bulgaria", "Croatia", "Cyprus", "Czechia", "Denmark",
                    "Estonia", "Finland", "France", "Georgia", "Germany",
                    "Greece", "Hungary", "Iceland", "Ireland", "Italy",
                    "Kosovo", "Latvia", "Lithuania", "Macedonia", "Malta",
                    "Moldova", "Monaco", "Montenegro", "Netherlands", "Norway",
                    "Poland", "Portugal", "Romania", "Russia", "Serbia",
                    "Slovakia", "Slovenia", "Spain", "Sweden", "Switzerland",
                    "Turkey", "Ukraine", "United Kingdom",
                ],
                ["None"],
            ],
        },
    }


def initialize_dashboard(dict_datasets):
    """Populate every dashboard tab once the datasets have been downloaded.

    Args:
        dict_datasets: mapping with ``0`` -> world-level dataset and
            ``1`` -> US-states dataset (as built by ``skip_click``).

    Side effects: fills the module-level widgets of tabs 1-4 with options,
    default selections and initial figures; attaches their observers; makes
    the header buttons visible; refreshes the data-download link; fills
    ``presaved_sel_lists`` and sets the globals ``top_countries_names``,
    ``states`` and ``dict_scenario``.
    """
    global top_countries_names, states, presaved_sel_lists, dict_scenario

    world = dict_datasets[0]
    us_states = dict_datasets[1]
    last_index = world.get_len() - 1

    def _as_day(date_str):
        # Re-render the date through `dateformat` before handing it to numpy.
        return np.datetime64(
            datetime.datetime.strptime(date_str, dateformat).strftime(dateformat)
        )

    # --- Tab 1: infection maps ---------------------------------------------
    date_selector.options = world.get_index("Cases")
    date_selector.value = world.ENDT

    # Every calendar day from STDT to ENDT, inclusive.
    date_axis_selector.options = list(
        np.arange(_as_day(world.STDT), _as_day(world.ENDT) + np.timedelta64(1, "D"))
    )
    date_axis_selector.index = (0, last_index)

    play.max = last_index
    play.value = last_index

    wm.color = map_color(
        world,
        data_buttons.value,
        norm_buttons.value,
        scale_buttons.value,
        type_buttons.value,
        tab_1_ma_ch.value,
        tab_1_ma_w.value,
        world.ENDT,
    )

    main_map.title = map_title(
        "World",
        "Cases",
        "Values",
        "Total",
        world.ENDT,
        tab_1_ma_ch.value,
        tab_1_ma_w.value,
    )

    update_tables(
        data_buttons.value,
        norm_buttons.value,
        scale_buttons.value,
        type_buttons.value,
        tab_1_ma_ch.value,
        tab_1_ma_w.value,
        world.ENDT,
    )

    update_main_graph(
        world,
        "USA",
        "Values",
        "Linear",
        "Total",
        tab_1_ma_ch.value,
        tab_1_ma_w.value,
        world.STDT,
        world.ENDT,
        axis_t,
        axis_y,
        main_graph,
        init=True,
    )

    # --- Tab 2: rebased graph -----------------------------------------------
    # pandas parses the literal "NA" (North America) as missing; restore it.
    continents_df = pd.read_csv(
        FOLDER_WORLD + "continents_restricted.csv", sep=";", index_col=1
    ).fillna("NA")
    valid_codes = set(world.get_columns("Cases"))
    continent_specs = [
        ("Europe", "EU", ()),
        ("Africa", "AF", ()),
        # Egypt is excluded from Asia (presumably listed under both AF and AS).
        ("Asia", "AS", ("EGY",)),
        # NOTE(review): the EGY exclusion for North America looks copy-pasted
        # from Asia; kept to preserve existing behaviour.
        ("North America", "NA", ("EGY",)),
        ("South America", "SA", ()),
        ("Oceania", "OC", ()),
    ]
    for label, continent_code, exclude in continent_specs:
        presaved_sel_lists[label] = _country_names_for_codes(
            continents_df[continents_df["CC"] == continent_code].index.values,
            valid_codes,
            exclude,
        )

    top_countries_codes = (
        world.dict_data["Cases"]
        .loc[ENDT]
        .sort_values(ascending=False)[1:]
        .index.values
    )
    top_countries_names = [codes_to_countries[c] for c in top_countries_codes]
    states = list(
        us_states.dict_data["Cases"]
        .loc[us_states.ENDT, np.array(list(states_to_codes.keys()))]
        .sort_values(ascending=False)
        .index.values
    )
    presaved_sel_lists["Top10 (Cases)"] = _country_names_for_codes(
        top_countries_codes[:10], valid_codes
    )
    presaved_sel_lists["Top10 (Deaths)"] = _country_names_for_codes(
        _top10_codes(world.dict_data["Deaths"].loc[ENDT]), valid_codes
    )
    presaved_sel_lists["Top10 (Cases/1M)"] = _country_names_for_codes(
        _top10_codes(world.dict_data_pop["Cases"].loc[ENDT]), valid_codes
    )
    presaved_sel_lists["Top10 (Deaths/1M)"] = _country_names_for_codes(
        _top10_codes(world.dict_data_pop["Deaths"].loc[ENDT]), valid_codes
    )
    presaved_sel_lists["Top10 (Cases daily chg)"] = _country_names_for_codes(
        _top10_codes(
            world.get_value(
                None, "Cases", "Values", "Linear", "Daily change", False, 7, ENDT
            )
        ),
        valid_codes,
    )
    presaved_sel_lists["Top10 (Deaths daily chg)"] = _country_names_for_codes(
        _top10_codes(
            world.get_value(
                None, "Deaths", "Values", "Linear", "Daily change", False, 7, ENDT
            )
        ),
        valid_codes,
    )

    date_axis_selector_tab_2.options = date_axis_selector.options
    date_axis_selector_tab_2.index = (0, last_index)
    countries_options = list(np.sort(top_countries_names))
    states_options = list(np.sort(states))
    countries_selector.options = ["None", "World"] + countries_options
    states_selector.options = ["None"] + states_options
    cp_tab_2 = CountryPicker(
        countries_selector,
        states_selector,
        world,
        us_states,
        codes_to_countries,
        countries_to_codes,
        grid_2,
        grid_2_left_panel,
        "auto 590px",
        "315px auto",
        tab_2_box_selectors,
        hide=False,
    )
    tab_2_box_selectors[1, :] = cp_tab_2.grid_picker
    countries_selector.value = top_countries_names[:6]
    states_selector.value = states[:2]
    countries_selector.observe(update_rebased_graph, "value")
    states_selector.observe(update_rebased_graph, "value")
    tab_2_box_selectors.layout.grid_template_rows = "80px auto"
    plot_rebased_graph(
        "Cases",
        "Values",
        "Total",
        1000,
        top_countries_names[:6],
        states[:2],
        "Log",
        "Cases",
        "Values",
        tab_2_ma_ch.value,
        tab_2_ma_w.value,
        calendar_time.value,
    )

    # --- Tab 3: heat map ----------------------------------------------------
    dna_countries_selector.options = ["None", "World"] + countries_options
    dna_states_selector.options = ["None"] + states_options
    cp_tab_3 = CountryPicker(
        dna_countries_selector,
        dna_states_selector,
        world,
        us_states,
        codes_to_countries,
        countries_to_codes,
        grid_3,
        grid_3_left_panel,
        "auto 590px",
        "140px auto",
        tab_3_box_selectors,
    )
    tab_3_box_selectors[1, :] = cp_tab_3.grid_picker
    dna_countries_selector.value = top_countries_names[1:20]
    dna_states_selector.value = states[:4]
    dna_countries_selector.observe(update_dna, "value")
    dna_states_selector.observe(update_dna, "value")

    # Moving-average flag then moving-average window, in the same order as
    # the tab 1/2/4 calls (the window was previously given the flag again).
    data_countries = world.get_ts(
        [countries_to_codes[c] for c in dna_countries_selector.value],
        dna_data_button.value,
        dna_norm_button.value,
        "Linear",
        dna_type_button.value,
        tab_3_ma_ch.value,
        tab_3_ma_w.value,
    ).rename(columns=codes_to_countries)
    data_states = us_states.get_ts(
        dna_states_selector.value,
        dna_data_button.value,
        dna_norm_button.value,
        "Linear",
        dna_type_button.value,
        tab_3_ma_ch.value,
        tab_3_ma_w.value,
    )
    # Outer join on the date index, then reverse the column order for display.
    heat_df = pd.merge(
        data_countries, data_states, left_index=True, right_index=True, how="outer"
    ).iloc[:, ::-1]

    dna_heat_map = GridHeatMap(
        color=heat_df.values.T,
        column=pd.to_datetime(heat_df.index.values),
        row=list(heat_df.columns.values),
        column_align="start",
        scales={"column": dna_x_scale, "row": dna_y_scale, "color": dna_color_scale},
        stroke="white",
    )
    dna_figure.marks = [dna_heat_map]
    dna_heat_map.tooltip = dna_stats_tootltip
    dna_heat_map.on_hover(update_dna_tooltip)
    tab_3_box_selectors.layout.grid_template_rows = "80px auto"

    # --- Tab 4: custom (free) scatter --------------------------------------
    free_countries_selector.options = ["None", "World"] + countries_options
    free_states_selector.options = ["None"] + states_options
    cp_tab_4 = CountryPicker(
        free_countries_selector,
        free_states_selector,
        world,
        us_states,
        codes_to_countries,
        countries_to_codes,
        grid_4,
        grid_4_left_panel,
        "auto 690px",
        "360px auto",
        tab_4_box_selectors,
        hide=False,
    )
    tab_4_box_selectors[1, :] = cp_tab_4.grid_picker
    free_countries_selector.value = [
        "United States",
        "France",
        "World",
        "Italy",
        "China",
        "Germany",
        "Russia",
        "Brazil",
    ]
    free_states_selector.value = ["New York", "California"]
    free_countries_selector.observe(update_free_scatter_fig, "value")
    free_states_selector.observe(update_free_scatter_fig, "value")

    free_play.max = last_index
    free_play.value = last_index
    selected_countries = [
        ["USA", "FRA", "WLD", "ITA", "CHN", "DEU", "RUS", "BRA"],
        ["New York", "California"],
    ]
    datas = ["Cases", "Cases"]
    norms = ["Values", "Values"]
    scales = ["Log", "Log"]
    data_types = ["Total", "Daily change"]
    last_point = False
    free_scatter(
        selected_countries,
        datas,
        norms,
        scales,
        data_types,
        tab_4_ma_ch.value,
        tab_4_ma_w.value,
        last_point,
    )

    # --- Shared wiring (attached last so initialisation does not fire it) ---
    date_axis_selector.observe(update_zoom, "value")
    date_axis_selector.continuous_update = False
    date_axis_selector_tab_2.observe(update_zoom_tab_2, "value")
    date_axis_selector_tab_2.continuous_update = False
    play.observe(update_index_date, "value")
    date_selector.observe(select_date, "value")
    free_play.observe(update_free_play, "value")
    tab_4_box_selectors.layout.grid_template_rows = "80px auto"
    download_button.layout.visibility = "visible"
    hide_button.layout.visibility = "visible"
    update_link_download()
    dict_scenario = _build_dict_scenario()
    return


In [None]:
def skip_click(*args):
    """Download all datasets, build the dashboard and swap it into ``outer_tab``.

    Registered below as the click handler of the help page's skip button;
    ``args`` (the clicked button) is ignored.

    While data is loading, ``outer_tab`` shows a single "Downloading..." page
    with a progress bar and a status line. These are stored in the module
    globals ``progress_bar`` and ``current_action``, presumably so the
    collection helpers can update them (TODO confirm). Also sets the globals
    ``STDT``, ``ENDT``, the country/state code mappings and ``dict_datasets``.
    """
    global STDT, ENDT, countries_to_codes, codes_to_countries, states_to_codes
    global dict_datasets, progress_bar, current_action
    progress_bar = FloatProgress(
        value=0.1,
        min=0,
        max=1.0,
        step=0.01,
        description="Loading data:",
        orientation="horizontal",
    )
    current_action = HTML("Collecting country level Cases/Deaths/Recovered data...")
    outer_tab.children = [
        VBox(
            [HTML("ONE MOMENT PLEASE..."), progress_bar, current_action],
            layout=Layout(width="auto"),
        )
    ]
    outer_tab._titles = {0: "Downloading..."}
    (
        datasets_World,
        STDT,
        ENDT,
        countries_to_codes,
        codes_to_countries,
    ) = collect_World_data(BASE_URL, FOLDER_WORLD, URL_TESTS)
    datasets_US_States, states_to_codes, datasets_US_counties = collect_US_data(
        BASE_URL_US, FOLDER_US, URL_TESTS_US
    )
    dict_datasets = {0: datasets_World, 1: datasets_US_States, 2: datasets_US_counties}
    initialize_dashboard(dict_datasets)
    # Advance the existing widget; assigning to the name would replace the
    # FloatProgress with a plain float and leave the bar unchanged.
    progress_bar.value = 0.98
    current_action.value = "Create Dashboard..."
    tab_contents = [
        "Infection Maps",
        "Rebased Graph",
        "Heat Map",
        "Custom Graph",
        "User Guide",
    ]
    hp.skip_button.layout.visibility = "hidden"
    help_page.layout.grid_template_rows = "0px 70px 90px auto 50px"
    children = [
        grid_1,
        grid_2,
        grid_3,
        grid_4,
        help_page,
    ]
    # Plain-int keys (0..4); numpy integer keys are not JSON-serialisable.
    outer_tab._titles = dict(enumerate(tab_contents))
    outer_tab.children = children
    outer_tab.selected_index = 0


hp.skip_button.on_click(skip_click)


In [None]:
# Per-tab CSS grid templates as [grid_template_rows, grid_template_columns],
# keyed 1-4 like `grids`. The "hidden" variants set the first row to 0px and,
# for tabs 2-4, the second column to 0px as well.
visible_template = dict(
    enumerate(
        [
            ["155px auto", "60% auto"],
            ["315px auto", "auto 590px"],
            ["140px auto", "auto 590px"],
            ["360px auto", "auto 690px"],
        ],
        start=1,
    )
)

hidden_template = dict(
    enumerate(
        [
            ["0px auto", "60% auto"],
            ["0px auto", "auto 0px"],
            ["0px auto", "auto 0px"],
            ["0px auto", "auto 0px"],
        ],
        start=1,
    )
)
grids = {
    tab_number: tab_grid
    for tab_number, tab_grid in zip(np.arange(1, 5), (grid_1, grid_2, grid_3, grid_4))
}


def hide_buttons(*args):
    """Observer for ``hide_button``: collapse or restore the control panels.

    ``args[0]`` is the traitlets change dict. A truthy ``"new"`` applies
    ``hidden_template`` to the four tab grids and relabels the button
    "Show Controls". A falsy one restores ``visible_template`` and the
    "Hide Controls" label.
    """
    hiding = args[0]["new"]
    templates = hidden_template if hiding else visible_template
    hide_button.button_style = "danger" if hiding else ""
    hide_button.description = "Show Controls" if hiding else "Hide Controls"
    for tab_number in range(1, 5):
        rows, columns = templates[tab_number]
        grids[tab_number].layout.grid_template_rows = rows
        grids[tab_number].layout.grid_template_columns = columns
    return


In [None]:
# "Hide Controls" toggle in the header. It starts hidden and is made visible
# once the dashboard has been initialised.
hide_button = ToggleButton(
    description="Hide Controls",
    value=False,
    button_style="",
    layout=Layout(
        visibility="hidden",
        min_height="30px",
        min_width="150px",
        width="150px",
        max_width="150px",
    ),
)
hide_button.observe(hide_buttons, names="value")


In [None]:
# "Download Figure" header button. It starts hidden and is made visible once
# the dashboard has been initialised.
download_button = Button(
    description="Download Figure",
    button_style="",
    layout=Layout(
        visibility="hidden",
        min_height="30px",
        min_width="150px",
        width="150px",
        max_width="150px",
    ),
)


In [None]:
def download_fig(*args):
    """Save the figure shown on the active tab as a PNG.

    The file name is the figure title with dots removed. The Rebased Graph
    and Heat Map exports get a distinguishing prefix. For any other tab index
    (e.g. the User Guide) nothing is saved. ``args`` is ignored.
    """
    tab = outer_tab.selected_index
    if tab == 0:
        figure, prefix = main_graph, ""
    elif tab == 1:
        figure, prefix = rebased_graph, "Rebased_graph_"
    elif tab == 2:
        figure, prefix = dna_figure, "Heat_map_"
    elif tab == 3:
        figure, prefix = free_scatter_fig, ""
    else:
        return
    figure.save_png(prefix + figure.title.replace(".", ""))
    return


# Clicking "Download Figure" saves the active tab's figure as a PNG.
download_button.on_click(download_fig)


In [None]:
def _csv_filename(figure_title, prefix="", strip_log=True):
    """Build a CSV file name from a figure title.

    Dots are dropped and "/" becomes "Per". Unless ``strip_log`` is False,
    the "Log " scale marker is removed. ``prefix`` is prepended and
    "_data.csv" appended.
    """
    cleaned = figure_title.replace(".", "").replace("/", "Per")
    if strip_log:
        cleaned = cleaned.replace("Log ", "")
    return prefix + cleaned + "_data.csv"


def _main_graph_data():
    """Table behind the Infection Maps time-series graph (tab 0)."""
    mark = main_graph.marks[0]
    data_csv = pd.DataFrame(data=mark.y, index=mark.labels, columns=mark.x)
    return data_csv, _csv_filename(main_graph.title)


def _rebased_graph_data():
    """Table behind the Rebased Graph (tab 1): one row per plotted series."""
    lines_mark, labels_mark = rebased_graph.marks[0], rebased_graph.marks[1]
    n_series = labels_mark.text.shape[0]
    if n_series == 0:
        raise ValueError("The rebased graph has no series to export.")
    if n_series == 1:
        # A single series appears to be stored as a 1-D array: make it one row.
        cols = np.arange(0, lines_mark.y.shape[0])
        data = lines_mark.y.reshape(1, -1)
    else:
        cols = np.arange(0, lines_mark.y.shape[1])
        data = lines_mark.y
    # Calendar mode keeps the x values; otherwise columns are 0..n-1 offsets.
    columns = lines_mark.x if calendar_time.value else cols
    data_csv = pd.DataFrame(data=data, index=labels_mark.text, columns=columns)
    return data_csv, _csv_filename(rebased_graph.title, prefix="Rebased_graph_")


def _heat_map_data():
    """Table behind the Heat Map (tab 2), rows in reversed display order."""
    heat_map = dna_figure.marks[0]
    data_csv = pd.DataFrame(
        data=heat_map.color, index=heat_map.row, columns=heat_map.column
    ).iloc[::-1]
    return data_csv, _csv_filename(
        dna_figure.title, prefix="Heat_map_", strip_log=False
    )


def _free_scatter_data():
    """Table behind the Custom Graph scatter (tab 3).

    Each plotted entity contributes two rows (its x series and its y series),
    indexed by a (label, axis-name) MultiIndex parsed from the figure title.
    """
    title = free_scatter_fig.title
    y_label = title.split("vs")[0].strip().replace("Log ", "")
    x_label = title.split("vs")[1].split(",")[0].strip().replace("Log ", "")
    if tab_4_ma_ch.value:
        # Carry the moving-average suffix of the x label over to the y label.
        y_label += " (" + x_label.split("(")[1]
    points, labels_mark = free_scatter_fig.marks[0], free_scatter_fig.marks[1]
    rows, entity_level, axis_level = [], [], []
    for i in range(labels_mark.text.shape[0]):
        rows.append(points.x[i])
        rows.append(points.y[i])
        entity_level.append(labels_mark.text[i])
        entity_level.append(labels_mark.text[i])
        axis_level.append(x_label)
        axis_level.append(y_label)
    index = [np.array(entity_level), np.array(axis_level)]
    if free_checkbox.value:
        col = [ENDT]
    else:
        col = np.arange(
            np.datetime64(STDT), np.datetime64(ENDT) + np.timedelta64(1, "D")
        )
    data_csv = pd.DataFrame(data=rows, index=index, columns=col)
    return data_csv, _csv_filename(title)


def download_data():
    """Return ``(data_csv, title)`` for the figure on the active tab.

    ``data_csv`` is a DataFrame of the plotted values and ``title`` a CSV
    file name derived from the figure title.

    Raises:
        ValueError: the active tab has no exportable figure (e.g. the User
            Guide) or the rebased graph has no series.
    """
    exporters = {
        0: _main_graph_data,
        1: _rebased_graph_data,
        2: _heat_map_data,
        3: _free_scatter_data,
    }
    selected_index = outer_tab.selected_index
    try:
        exporter = exporters[selected_index]
    except KeyError:
        raise ValueError(
            "Tab {!r} has no downloadable data.".format(selected_index)
        ) from None
    return exporter()


# HTML anchor styled as the "Download Data" button. The active tab's CSV is
# embedded as a base64 `data:` URI. Placeholders: {visibility}, {filename},
# {payload}.
download_link = '<a style="background-color: #424242; border: none; color: white; padding: 1px 22px; visibility: {visibility}; text-align: center; text-decoration: none; display: inline-block; font-size: 13px; margin: 0px; cursor: pointer" download="{filename}" href="data:text/csv;base64,{payload}" target="_blank">Download Data</a>'
# Starts hidden with an empty payload until data is available.
download_data_button = HTML(
    value=download_link.format(filename="", payload="", visibility="hidden")
)


def update_link_download():
    """Refresh the "Download Data" link with the CSV for the active tab.

    The DataFrame from ``download_data()`` is serialised to CSV,
    base64-encoded and embedded in ``download_link``. If no data can be
    produced (``download_data`` raises, e.g. on the User Guide tab or an
    empty graph), the link is hidden rather than left pointing at a
    previous tab's CSV.
    """
    try:
        df, filename = download_data()
        payload = base64.b64encode(df.to_csv().encode()).decode()
    except Exception:
        # Best effort: nothing exportable here. Hide the link so stale data
        # is not offered (the old bare `except: pass` kept it visible).
        download_data_button.value = download_link.format(
            visibility="hidden", payload="", filename=""
        )
        return
    download_data_button.value = download_link.format(
        visibility="visible", payload=payload, filename=filename
    )
    return

def hide_show_top_buttons(val):
    """Observer on ``outer_tab.selected_index``: toggle the header buttons.

    On tab 4 (the User Guide), the data link is reset to hidden with an
    empty payload, and the figure-download and hide-controls buttons are
    hidden. On any other tab, the two buttons are shown and the data link
    is left as it is.
    """
    on_user_guide = val["new"] == 4
    if on_user_guide:
        download_data_button.value = download_link.format(
            filename="", payload="", visibility="hidden"
        )
    visibility = "hidden" if on_user_guide else "visible"
    for header_widget in (download_button, hide_button):
        header_widget.layout.visibility = visibility
    return

# On every tab switch, first refresh the data link, then show/hide the header
# buttons (registration order kept as originally written).
outer_tab.observe(lambda change: update_link_download(), names="selected_index")
outer_tab.observe(hide_show_top_buttons, names="selected_index")

In [None]:
# Header slider for the dashboard height in pixels (600-1500, step 10).
zoom_slider = IntSlider(
    value=740,
    min=600,
    max=1500,
    step=10,
    description="Dashboard's height (in pixels)",
    style={"description_width": "initial"},
    layout=Layout(min_width="300px", width="auto", max_width="500px"),
)


def update_height(*args):
    """Observer for ``zoom_slider``: set ``outer_tab``'s height to the new value in px."""
    change = args[0]
    outer_tab.layout.height = "{}px".format(change["new"])
    return


# Resize while the slider is being dragged, not only on release.
zoom_slider.continuous_update = True
zoom_slider.observe(update_height, names="value")


In [None]:
# Page header: dashboard title in the centre cell; height slider and the
# Hide/Download buttons stacked in the right-hand cell.
grid_header = GridspecLayout(1, 3)
grid_header.layout.width, grid_header.layout.height = "100%", "auto"
grid_header.layout.margin = "0px 0px 0px 0px"
header_controls = VBox(
    children=[
        zoom_slider,
        HBox(children=[hide_button, download_button, download_data_button]),
    ],
    layout=Layout(
        margin="0px 0px 0px 0px",
        align_items="stretch",
        justify_content="flex-end",
    ),
)
grid_header[0, 2] = header_controls
grid_header[0, 1] = HTML(
    value="<h1 style='color: #ff8b0e; padding:0px; margin:0px; text-align:center; justify-content:center; font-weight: bold; font-size: 38px'>Bloomberg Covid Dashboard</h1><h3 style='color: #ff8b0e; padding:0px; margin:4px 0px 0px 0px; text-align:center; font-size: 18px'>Bloomberg Quant Research</h3>",
    layout=Layout(width="540px"),
)
# NOTE(review): the column template is set after the cells are filled; placing
# a cell appears to re-apply GridspecLayout's default template. Keep this order.
grid_header.layout.justify_items = "center"
grid_header.layout.grid_template_columns = "auto auto 450px"
grid_header.layout.overflow = "auto"


In [None]:
# Footer attribution: HTML credits for the data sources (JHU CSSE, COVID
# Tracking Project, OWID) with their licences and repository links.
data_source = """<html><body><p style='color: #ff8b0e; padding:0; margin:0'> Data Sources : - JHU CSSE COVID-19 Data, licensed by Johns Hopkins University under the <a href='https://creativecommons.org/licenses/by/4.0/' target="_blank" style='color:#1E90FF'>Creative Commons Attribution 4.0 International CC BY 4.0</a>. <a href='https://github.com/CSSEGISandData/COVID-19' target="_blank" style='color:#1E90FF'> https://github.com/CSSEGISandData/COVID-19 </a> 
             </br>- COVID Tracking Data, licensed by Covid Tracking Project under the Apache License 2.0 <a href='https://github.com/COVID19Tracking/covid-tracking-data' target="_blank" style='color:#1E90FF'> https://github.com/COVID19Tracking/covid-tracking-data </a>
             </br>- Covid-19-Data/Testing, licensed by OWID under the <a href='https://creativecommons.org/licenses/by/4.0/' target="_blank" style='color:#1E90FF'>Creative Commons Attribution 4.0 International CC BY 4.0</a>. <a href='https://github.com/owid/covid-19-data' target="_blank" style='color:#1E90FF'> https://github.com/owid/covid-19-data </a></p></body></html>"""
# Footer: a single 30px-high, scrollable strip holding the data-source credits.
grid_footer = GridspecLayout(1, 1, layout=Layout(align_items="center", width="100%"))
grid_footer[0, 0] = HBox(
    children=[
        HTML(
            value=data_source,
            layout=Layout(
                overflow="auto",
                width="100%",
                min_height="30px",
                height="30px",
                max_height="30px",
            ),
        )
    ],
    layout=Layout(overflow="auto", width="auto", height="auto"),
)

# Complete dashboard: header, tabbed body and footer stacked vertically.
Dashboard = VBox(
    children=[grid_header, outer_tab, grid_footer],
    layout=Layout(align_items="stretch"),
)


In [None]:
# Render the assembled dashboard: the cell's last expression is displayed.
Dashboard