In [None]:
# Copyright 2020 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pandas as pd
import time
import math
from ipywidgets import *
from bqplot import *
import datetime
from data_processor import DataProcessor
import bisect
from ipywidgets import Image as im
from toggle_buttons import Toggle_Buttons
import itertools
from collections import Counter
import base64

# import IPython.core.display as ipd
# ipd.display(ipd.HTML("<style>.container { width:100% !important; }</style>"))


In [None]:
# Terms of use: green "I agree" button that fills the grid cell it is placed in.
agree = Button(
    layout=Layout(height="100%", width="100%"),
    button_style="success",
    style=dict(description_width="initial", font_weight="bold"),
    description="I agree",
)


In [None]:
# 10 x 10 grid hosting the Terms of Use page (screenshots, ToS text and the
# "I agree" button are assigned to its cells in the next cell).
grid_5 = GridspecLayout(10, 10)
grid_5.layout.height = "100%"
# HTML body of the Terms of Use panel. This string is rendered as-is by an
# HTML widget, so any change to it changes what the user sees.
ToS = """
<h1 style="text-align: center; color: #ff8b0e">Terms of Use of JHU's data:</h1>
<h5 style="text-align: center"><a style='color: #1e90ff' href="https://github.com/CSSEGISandData/COVID-19" target="_blank"> https://github.com/CSSEGISandData/COVID-19 </a></h5>
<ol style="text-align: justify; color: #ff8b0e">
    <li>This website and its contents herein, including all data, mapping, and analysis (Website),
    copyright 2020 Johns Hopkins University, all rights reserved, is provided solely for non-profit public health,
    educational, and academic research purposes. You should not rely on this Website for medical advice or guidance.</li>
    <li>Use of the Website by commercial parties and/or in commerce is strictly prohibited. 
    Redistribution of the Website or the aggregated data set underlying the Website is strictly prohibited.</li>
    <li>When linking to the website, attribute the Website as the COVID-19 Dashboard by the Center
    for Systems Science and Engineering (CSSE) at Johns Hopkins University, or the COVID-19 Data
    Repository by the Center for Systems Science and Engineering (CSSE) at Johns Hopkins University.</li>
    <li>The Website relies upon publicly available data from multiple sources that do not always
    agree. The Johns Hopkins University hereby disclaims any and all representations and warranties
    with respect to the Website, including accuracy, fitness for use, reliability, completeness,
    and non-infringement of third party rights.</li>
    <li>Any use of the Johns Hopkins' names, logos, trademarks, and/or trade dress in a factually
    inaccurate manner or for marketing, promotional or commercial purposes is strictly prohibited.</li>
    <li>These terms and conditions are subject to change. Your use of the Website constitutes your
    acceptance of these terms and conditions and any future modifications thereof.</li>
</ol>
<h5 style='font-weight: bold; color: #ff8b0e'>Bloomberg LP shall not be liable in any way if data use doesn't follow the Terms of Use set by JHU.</h5>
"""


In [None]:
FOLDER_IMG = "../screenshots/"
image_names = ["World_map", "Heatmap", "Rebased_graph", "Custom_graph"]


def _read_screenshot(name):
    """Return the raw bytes of the ``<name>_zoom.PNG`` screenshot in FOLDER_IMG."""
    # The context manager closes the file handle as soon as the bytes are read
    # instead of leaving it to the garbage collector.
    with open(FOLDER_IMG + name + "_zoom.PNG", "rb") as png_file:
        return png_file.read()


def _screenshot_box(png_bytes):
    """Wrap PNG bytes in an Image widget placed inside an auto-sized Box."""
    return Box(
        children=[
            im(
                value=png_bytes,
                format="png",
                layout=Layout(object_fit="scale-down"),
                # NOTE(review): `margin` is handed to the Image widget itself, not
                # to its Layout, exactly as before - confirm whether a Layout
                # margin (with CSS units) was intended.
                margin="0 10 0 10",
            )
        ],
        layout=Layout(width="auto", height="auto", overflow_x="hidden"),
    )


# Raw PNG bytes, in the same order as `image_names`.
images = [_read_screenshot(name) for name in image_names]

# Screenshots occupy the four corners of the grid: left/right columns,
# top/bottom halves.
grid_5[:5, :2] = _screenshot_box(images[0])
grid_5[:5, 8:] = _screenshot_box(images[2])
grid_5[5:, :2] = _screenshot_box(images[1])
grid_5[5:, 8:] = _screenshot_box(images[3])

# Scrollable Terms of Use text in the centre of the grid.
grid_5[1:7, 2:8] = HTML(
    ToS,
    layout=Layout(
        width="auto",
        height="auto",
        overflow_y="scroll",
        margin="0px 40px 0px 40px",
        border="solid #ff8b0e",
    ),
)
# Acceptance button below the text.
grid_5[8:9, 4:6] = agree


In [None]:
# Outer tab container. At this point it holds a single tab, titled
# "Terms of Use", whose content is the Terms of Use grid.
children = [grid_5]
outer_tab = Tab(_titles={0: "Terms of Use"})
outer_tab.children = children
outer_tab.layout.height = "780px"


In [None]:
# Collect World Data


# Folder of the JHU CSSE time-series CSV files (file names are appended to it).
BASE_URL = "https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/csse_covid_19_data/csse_covid_19_time_series/"
# Local folder holding Population.csv and world_map_codes.csv.
FOLDER_WORLD = "../data/WORLD/"
# Our World in Data CSV with the per-country cumulative test counts.
URL_TESTS = "https://raw.githubusercontent.com/owid/covid-19-data/master/public/data/testing/covid-testing-all-observations.csv"
# strftime/strptime pattern used for every date index built in this notebook.
dateformat = "%Y-%m-%d"


def collect_World_data(BASE_URL, FOLDER_WORLD, URL_TESTS):
    """Download and assemble the country-level time series.

    Reads the confirmed / deaths / recovered "global" CSV files found under
    ``BASE_URL``, the test counts from ``URL_TESTS`` and the population and
    country-code tables stored in ``FOLDER_WORLD``, then reshapes everything
    into date-indexed frames with one column per country code.

    The function also updates the module-level ``progress_bar`` and
    ``current_action`` widgets and formats dates with the module-level
    ``dateformat``; those names must exist before it is called.

    Parameters
    ----------
    BASE_URL : str
        Prefix to which the three ``time_series_covid19_*_global.csv`` file
        names are appended.
    FOLDER_WORLD : str
        Folder containing ``Population.csv`` and ``world_map_codes.csv``.
    URL_TESTS : str
        Location of the CSV providing the ``ISO code``, ``Date`` and
        ``Cumulative total`` columns.

    Returns
    -------
    tuple
        ``(datasets_World, STDT, ENDT, countries_to_codes, codes_to_countries)``
        where ``datasets_World`` is the ``DataProcessor`` built from the
        frames, ``STDT`` / ``ENDT`` are the first / last date strings of the
        cases frame and the two dicts map country names to codes and back.
    """
    url_cases = BASE_URL + "time_series_covid19_confirmed_global.csv"
    url_deaths = BASE_URL + "time_series_covid19_deaths_global.csv"
    url_rec = BASE_URL + "time_series_covid19_recovered_global.csv"
    dataframes = {
        "Cases": pd.read_csv(url_cases, header=0),
        "Deaths": pd.read_csv(url_deaths, header=0),
        "Recovered": pd.read_csv(url_rec, header=0),
    }
    for df in dataframes.values():
        # Hong Kong is listed as a province; promote it to its own entity.
        df.loc[df["Province/State"] == "Hong Kong", "Country/Region"] = "Hong Kong"
        df.set_index("Country/Region", inplace=True)
    pops = pd.read_csv(FOLDER_WORLD + "Population.csv", header=0).dropna()
    pops.set_index("Country Code", inplace=True)
    country_codes = pd.read_csv(
        FOLDER_WORLD + "world_map_codes.csv", header=0, index_col=0
    )
    countries_to_codes = country_codes["ISOA3"].to_dict()
    ID_to_codes = dict(zip(country_codes.ISON3, country_codes.ISOA3))

    new_names = {
        "Antigua and Barbuda": "Antigua & Barbuda",
        "Bosnia and Herzegovina": "Bosnia",
        "Cabo Verde": "Cape Verde",
        "Congo (Brazzaville)": "Congo - Brazzaville",
        "Congo (Kinshasa)": "Congo - Kinshasa",
        "Eswatini": "Swaziland",
        "Holy See": "Vatican City",
        "Korea, South": "South Korea",
        "North Macedonia": "Macedonia",
        "Saint Lucia": "St. Lucia",
        "Saint Vincent and the Grenadines": "St. Vincent & Grenadines",
        "Trinidad and Tobago": "Trinidad & Tobago",
        "US": "United States",
        "Saint Kitts and Nevis": "St. Kitts & Nevis",
        "Burma": "Myanmar",
        "Taiwan*": "Taiwan",
    }
    # normalize country names
    for df in dataframes.values():
        df.rename(index=new_names, inplace=True)

    # assign dummy codes to a few entities
    countries_to_codes["World"] = "WLD"
    countries_to_codes["Diamond Princess"] = "DPS"
    countries_to_codes["Taiwan"] = "TWN"
    countries_to_codes["West Bank and Gaza"] = "PSE"
    countries_to_codes["Kosovo"] = "XKX"
    countries_to_codes["MS Zaandam"] = "MSZ"
    ID_to_codes[158] = "TWN"
    ID_to_codes[-2] = "XKX"
    codes_to_countries = {v: k for k, v in countries_to_codes.items()}
    codes_to_ID = {v: k for k, v in ID_to_codes.items()}

    # Names without a known code get None and therefore drop out of the
    # groupby below.
    for df in dataframes.values():
        df["code"] = [countries_to_codes.get(name) for name in df.index.values]

    old_dateformat = "%m/%d/%y"
    # Columns are: three metadata columns, the dates, then the "code" column
    # that was just appended - hence the [3:-1] slice.
    new_index = [
        datetime.datetime.strptime(d, old_dateformat).strftime(dateformat)
        for d in dataframes["Cases"].columns.values[3:-1]
    ]
    # set index and dateformat
    for k, df in list(dataframes.items()):
        df = df.iloc[:, 3:].groupby(["code"]).sum().transpose().reset_index()
        df["Date"] = new_index
        dataframes[k] = df.set_index("Date", drop=True).drop("index", axis=1)

    progress_bar.value = 0.3
    current_action.value = "Collecting country level Test data..."
    df_world_tests = pd.read_csv(
        URL_TESTS,
        index_col=0,
        header=0,
        usecols=["ISO code", "Date", "Cumulative total"],
    )
    # `.ffill()` replaces the deprecated `fillna(method="ffill")`.
    df_world_tests = (
        pd.pivot_table(df_world_tests, index="Date", columns="ISO code")
        .loc[:, "Cumulative total"]
        .ffill()
        .replace({math.nan: 0})
    )
    # Countries without any test data get an all-zero column.
    for c in dataframes["Cases"].columns.values:
        if c not in df_world_tests.columns.values:
            df_world_tests[c] = 0
    STDT = dataframes["Cases"].index[0]
    ENDT = dataframes["Cases"].index[-1]
    df_world_tests = df_world_tests.loc[STDT:]

    if df_world_tests.shape[0] < len(new_index):
        # Test data stops before the case data: pad the tail with zero rows
        # (the cummax below turns them into "last known value").
        missing_indexes = [
            str(d)
            for d in np.arange(
                np.datetime64(df_world_tests.index.values[-1]) + np.timedelta64(1, "D"),
                np.datetime64(ENDT) + np.timedelta64(1, "D"),
            )
        ]
        padding = pd.DataFrame(
            0, index=missing_indexes, columns=df_world_tests.columns.values
        )
        # pd.concat instead of DataFrame.append, which pandas 2.0 removed.
        df_world_tests = pd.concat([df_world_tests, padding])
    elif df_world_tests.shape[0] > len(new_index):
        df_world_tests = df_world_tests.loc[:ENDT]

    dataframes["Tests"] = df_world_tests.loc[:, dataframes["Cases"].columns.values]

    # Force every cumulative series to be non-decreasing.
    for c in list(dataframes.keys()):
        dataframes[c] = dataframes[c].cummax()

    dataframes["Active Cases"] = (
        dataframes["Cases"]
        - dataframes["Deaths"].fillna(0)
        - dataframes["Recovered"].fillna(0)
    )
    datasets_World = DataProcessor(
        dataframes, pops, "World", codes_to_ID, add_world_data=True
    )  # World data processor
    progress_bar.value = 0.5
    current_action.value = (
        "Collecting US state level Cases/Deaths/Recovered/Test data..."
    )
    return datasets_World, STDT, ENDT, countries_to_codes, codes_to_countries


In [None]:
# Full name -> two-letter postal code for the US states, the District of
# Columbia and the territories listed below.
states_to_codes = dict(
    [
        ("Alabama", "AL"),
        ("Alaska", "AK"),
        ("American Samoa", "AS"),
        ("Arizona", "AZ"),
        ("Arkansas", "AR"),
        ("California", "CA"),
        ("Colorado", "CO"),
        ("Connecticut", "CT"),
        ("Delaware", "DE"),
        ("District of Columbia", "DC"),
        ("Florida", "FL"),
        ("Georgia", "GA"),
        ("Guam", "GU"),
        ("Hawaii", "HI"),
        ("Idaho", "ID"),
        ("Illinois", "IL"),
        ("Indiana", "IN"),
        ("Iowa", "IA"),
        ("Kansas", "KS"),
        ("Kentucky", "KY"),
        ("Louisiana", "LA"),
        ("Maine", "ME"),
        ("Maryland", "MD"),
        ("Massachusetts", "MA"),
        ("Michigan", "MI"),
        ("Minnesota", "MN"),
        ("Mississippi", "MS"),
        ("Missouri", "MO"),
        ("Montana", "MT"),
        ("Nebraska", "NE"),
        ("Nevada", "NV"),
        ("New Hampshire", "NH"),
        ("New Jersey", "NJ"),
        ("New Mexico", "NM"),
        ("New York", "NY"),
        ("North Carolina", "NC"),
        ("North Dakota", "ND"),
        ("Northern Mariana Islands", "MP"),
        ("Ohio", "OH"),
        ("Oklahoma", "OK"),
        ("Oregon", "OR"),
        ("Pennsylvania", "PA"),
        ("Puerto Rico", "PR"),
        ("Rhode Island", "RI"),
        ("South Carolina", "SC"),
        ("South Dakota", "SD"),
        ("Tennessee", "TN"),
        ("Texas", "TX"),
        ("Utah", "UT"),
        ("Vermont", "VT"),
        ("Virgin Islands", "VI"),
        ("Virginia", "VA"),
        ("Washington", "WA"),
        ("West Virginia", "WV"),
        ("Wisconsin", "WI"),
        ("Wyoming", "WY"),
    ]
)

# Reverse lookup: postal code -> full name.
codes_to_states = dict(zip(states_to_codes.values(), states_to_codes.keys()))


In [None]:
# Local folder holding USStatesMap_codes.csv.
FOLDER_US = "../data/USA/"
# Folder of the JHU CSSE time-series CSV files (same value as BASE_URL).
BASE_URL_US = "https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/csse_covid_19_data/csse_covid_19_time_series/"
# COVID Tracking Project CSV providing per-state test and recovered counts.
URL_TESTS_US = "https://raw.githubusercontent.com/COVID19Tracking/covid-tracking-data/master/data/states_daily_4pm_et.csv"


def _with_date_index(frame, dates):
    """Return a copy of ``frame`` whose index is ``dates``, named "Date"."""
    dated = frame.copy()
    dated.index = pd.Index(dates, name="Date")
    return dated


def _pad_with_zero_rows(frame, end_date):
    """Append all-zero daily rows after the last row of ``frame`` up to ``end_date``."""
    last_date = frame.index.values[-1]
    if last_date == end_date:
        return frame
    missing_dates = [
        str(day)
        for day in np.arange(
            np.datetime64(last_date) + np.timedelta64(1, "D"),
            np.datetime64(end_date) + np.timedelta64(1, "D"),
        )
    ]
    if not missing_dates:
        return frame
    padding = pd.DataFrame(0, index=missing_dates, columns=frame.columns.values)
    # pd.concat instead of DataFrame.append, which pandas 2.0 removed.
    return pd.concat([frame, padding])


def _state_level_table(pivoted, value_column, dates, all_states):
    """Extract one value column of the per-state pivot as a date-indexed frame."""
    table = (
        pivoted.loc[:, value_column]
        .rename(columns=codes_to_states)
        .replace({math.nan: 0})
    )
    # Never keep more rows than there are dates in the case data.
    table = table.iloc[: len(dates)].copy()
    # NOTE(review): rows are relabelled positionally with the case dates, which
    # presumes the source starts on the same day as the case data - confirm.
    table.index = pd.Index(dates[: table.shape[0]], name="Date")
    # States absent from the source get an all-zero column.
    for state in all_states:
        if state not in table.columns.values:
            table[state] = 0
    return table


def _county_level_table(raw):
    """Drop unused columns and index the county rows by "<County>_<State>"."""
    table = raw.drop(
        [
            "UID",
            "iso2",
            "iso3",
            "code3",
            "Country_Region",
            "Lat",
            "Long_",
            "Combined_Key",
        ],
        axis=1,
    ).dropna()
    table = table[table["Admin2"] != "Unassigned"].rename(
        columns={"Admin2": "County"}
    )
    return table.assign(
        County_name=table["County"] + "_" + table["Province_State"]
    ).set_index("County_name")


def collect_US_data(BASE_URL, FOLDER_US, URL_TESTS_US):
    """Download and assemble the US state-level and county-level time series.

    Besides its arguments the function reads the module-level ``dateformat``,
    ``codes_to_states`` and ``ENDT`` (last date of the world data) and updates
    the module-level ``progress_bar`` and ``current_action`` widgets; those
    names must exist before it is called.

    Parameters
    ----------
    BASE_URL : str
        Prefix to which the ``time_series_covid19_*_US.csv`` file names are
        appended.
    FOLDER_US : str
        Folder containing ``USStatesMap_codes.csv``.
    URL_TESTS_US : str
        Location of the CSV providing the ``date``, ``state``,
        ``totalTestResults`` and ``recovered`` columns.

    Returns
    -------
    tuple
        ``(datasets_US_States, states_to_codes, datasets_US_counties)``: the
        state-level ``DataProcessor``, the mapping read from
        ``USStatesMap_codes.csv`` and the county-level ``DataProcessor``.
    """
    url_cases_US = BASE_URL + "time_series_covid19_confirmed_US.csv"
    url_deaths_US = BASE_URL + "time_series_covid19_deaths_US.csv"

    df_cases_US = pd.read_csv(url_cases_US, header=0)
    df_deaths_US = pd.read_csv(url_deaths_US, header=0)
    old_dateformat = "%m/%d/%y"
    # The first 11 columns of the cases file are metadata; the rest are dates.
    new_index = [
        datetime.datetime.strptime(d, old_dateformat).strftime(dateformat)
        for d in df_cases_US.columns.values[11:]
    ]
    state_metadata_rows = ["UID", "code3", "FIPS", "Lat", "Long_"]

    # ---- state level -------------------------------------------------------
    # numeric_only=True keeps the text columns out of the sum on every pandas
    # version (pandas >= 2 no longer drops them silently).
    df_cases_US_States = _with_date_index(
        df_cases_US.groupby(["Province_State"])
        .sum(numeric_only=True)
        .transpose()
        .drop(state_metadata_rows),
        new_index,
    )

    deaths_by_state = (
        df_deaths_US.groupby(["Province_State"])
        .sum(numeric_only=True)
        .transpose()
        .drop(state_metadata_rows)
    )
    # The deaths file carries a Population column; zero populations are
    # treated as missing.
    pops_US_States = (
        deaths_by_state.loc["Population"]
        .to_frame(name="2018")
        .replace({0: math.nan})
        .dropna()
    )
    df_deaths_US_States = _with_date_index(
        deaths_by_state.drop(["Population"]), new_index
    )

    df_rec_tests_US_States = pd.read_csv(
        URL_TESTS_US,
        index_col=[0, 1],
        header=0,
        usecols=["date", "state", "totalTestResults", "recovered"],
    )
    rec_tests_pivot = pd.pivot_table(
        df_rec_tests_US_States, index="date", columns="state"
    )
    all_states = df_cases_US_States.columns.values
    df_rec_US_States = _state_level_table(
        rec_tests_pivot, "recovered", new_index, all_states
    )
    # Each frame is sliced by its own row count (the tests frame used to be
    # sliced by the recovered frame's).
    df_tests_US_States = _state_level_table(
        rec_tests_pivot, "totalTestResults", new_index, all_states
    )

    # Extend both frames to the last date of the world data.
    df_tests_US_States = _pad_with_zero_rows(df_tests_US_States, ENDT)
    df_rec_US_States = _pad_with_zero_rows(df_rec_US_States, ENDT)

    dict_df_US_States = {
        "Cases": df_cases_US_States,
        "Deaths": df_deaths_US_States,
        "Recovered": df_rec_US_States,
        "Tests": df_tests_US_States,
    }
    # Force every cumulative series to be non-decreasing.
    for c in list(dict_df_US_States.keys()):
        dict_df_US_States[c] = dict_df_US_States[c].cummax()

    dict_df_US_States["Active Cases"] = (
        dict_df_US_States["Cases"]
        - dict_df_US_States["Deaths"].fillna(0)
        - dict_df_US_States["Recovered"].fillna(0)
    )
    states_to_codes = pd.read_csv(
        FOLDER_US + "USStatesMap_codes.csv", index_col=0, header=0
    ).to_dict()["ID"]
    datasets_US_States = DataProcessor(
        dict_df_US_States, pops_US_States, "US States", states_to_codes
    )
    progress_bar.value = 0.7
    current_action.value = "Collecting US county level Cases/Deaths data..."

    # ---- county level ------------------------------------------------------
    cases_by_county = _county_level_table(df_cases_US)
    counties_to_ID = cases_by_county["FIPS"].astype(int).to_dict()
    df_cases_US_counties = _with_date_index(
        cases_by_county.drop(["Province_State", "FIPS", "County"], axis=1).transpose(),
        new_index,
    )

    deaths_by_county = _county_level_table(df_deaths_US)
    pops_US_counties = (
        deaths_by_county.loc[:, "Population"]
        .to_frame(name="2018")
        .replace({0: math.nan})
        .dropna()
    )
    df_deaths_US_counties = _with_date_index(
        deaths_by_county.drop(
            ["Province_State", "FIPS", "County", "Population"], axis=1
        ).transpose(),
        new_index,
    )

    # No recovered / test figures are collected per county: use zeros shaped
    # like the cases frame.
    df_rec_US_counties = pd.DataFrame(
        np.zeros(df_cases_US_counties.shape),
        index=df_cases_US_counties.index.values,
        columns=df_cases_US_counties.columns.values,
    )
    df_tests_US_counties = pd.DataFrame(
        np.zeros(df_cases_US_counties.shape),
        index=df_cases_US_counties.index.values,
        columns=df_cases_US_counties.columns.values,
    )

    cases_cum = df_cases_US_counties.cummax()
    deaths_cum = df_deaths_US_counties.cummax()
    rec_cum = df_rec_US_counties.cummax()
    tests_cum = df_tests_US_counties.cummax()
    # Active cases are derived from the monotone series and are not themselves
    # forced to be monotone, matching the world and state-level computations.
    dict_df_US_counties = {
        "Cases": cases_cum,
        "Deaths": deaths_cum,
        "Recovered": rec_cum,
        "Active Cases": cases_cum - deaths_cum.fillna(0) - rec_cum.fillna(0),
        "Tests": tests_cum,
    }

    datasets_US_counties = DataProcessor(
        dict_df_US_counties, pops_US_counties, "US Counties", counties_to_ID
    )

    return datasets_US_States, states_to_codes, datasets_US_counties


In [None]:
# Buttons and interactions


# Minimum widths used by the toggle-button groups.
min_button_width_data = "100px"
min_button_width_norm = "85px"
min_button_width_type = "100px"
min_button_width_scale = "60px"
min_description_width_1 = "40px"

# One toggle group per display option, each starting on its first choice:
#   Data  - which series colours the map
#   Norm  - raw values or per-million normalisation
#   Scale - linear or log scale for plots and map
#   Type  - cumulative data, daily change or daily % change
data_buttons, norm_buttons, scale_buttons, type_buttons = [
    Toggle_Buttons(
        options=choices,
        value=choices[0],
        description=label,
        min_button_width=button_width,
        min_description_width=min_description_width_1,
    )
    for label, choices, button_width in (
        (
            "Data",
            ["Cases", "Deaths", "Recovered", "Active Cases", "Tests"],
            min_button_width_data,
        ),
        ("Norm", ["Values", "Per million"], min_button_width_norm),
        ("Scale", ["Linear", "Log"], min_button_width_scale),
        ("Type", ["Total", "Daily change", "Daily % change"], min_button_width_type),
    )
]

def _fixed_size_layout(width, height, **extra):
    """Build a Layout whose width and height are pinned (value = min = max)."""
    return Layout(
        width=width,
        min_width=width,
        max_width=width,
        height=height,
        min_height=height,
        max_height=height,
        **extra
    )


# Moving-average switch and its window-size input (created hidden).
tab_1_ma_ch = Checkbox(
    value=False,
    description="Moving Average",
    layout=Layout(width="200px", min_width="200px", overflow="auto"),
    style=dict(description_width="initial"),
)
tab_1_ma_w = BoundedIntText(
    value=7,
    min=1,
    max=14,
    description="Window size (days)",
    layout=Layout(
        width="200px", min_width="200px", overflow="auto", visibility="hidden"
    ),
    style=dict(description_width="initial"),
)
ma_box_tab_1 = VBox(
    children=[tab_1_ma_ch, tab_1_ma_w],
    layout=_fixed_size_layout("205px", "76px", overflow="hide"),
)

# Control row: data/type buttons, norm/scale buttons, moving-average box.
cat_tab_1_buttons = HBox(
    children=[
        VBox(
            children=[data_buttons, type_buttons],
            layout=_fixed_size_layout("565px", "76px", overflow="auto"),
        ),
        VBox(
            children=[norm_buttons, scale_buttons],
            layout=_fixed_size_layout("220px", "76px", overflow="auto"),
        ),
        ma_box_tab_1,
    ],
    layout=_fixed_size_layout(
        "1000px", "76px", overflow="auto", margin="0 0 0 0"
    ),
)


In [None]:
# Selector for the geography shown on the map; starts on "World".
map_buttons = Toggle_Buttons(
    description="Maps",
    options=["World", "US States", "US Counties"],
    value="World",
    style="warning",
    min_description_width=min_description_width_1,
    min_button_width=min_button_width_data,
)
# Fixed-size 565px x 36px container around the map selector.
box_map = HBox(
    children=[map_buttons],
    layout=Layout(
        overflow="auto",
        width="565px",
        min_width="565px",
        max_width="565px",
        height="36px",
        min_height="36px",
        max_height="36px",
    ),
)


In [None]:
# Main graph

# If white theme uncomment this variable

# css_style = """
# <head>
#     <meta http-equiv="Content-Type" content="text/html; charset=utf-8">
#     <style type="text/css" media="screen">
#         #stats_table {
#           font-family: "Trebuchet MS", Arial, Helvetica, sans-serif;
#           font-size: 12px;
#           border-collapse: collapse;
#           border-spacing: 0;
#           width: auto;
#           text-align: left;
#         }
#         #stats_table td, #stats_table th {
#           border: 1px solid #ddd;
#           padding: 0;
#         }
#         #stats_table tbody td {
#           font-size: 12px;
#           font-weight: bold;
#         }
#         #stats_table tr:nth-child(even) {
#           background-color: #f2f2f2;
#         }
#         #stats_table th {
#           padding-top: 0;
#           padding-bottom: 0;
#           text-align: center;
#           background-color: #003366;
#           color: white;
#         }
#     </style>
# </head>
# <body>
# """

# Prefix prepended to the tooltip HTML; empty unless the commented-out
# white-theme stylesheet above is enabled.
css_style = ""

# HTML template of the hover tooltip: placeholder 0 is the date, then five
# (label, value) pairs in placeholders 1-10, one coloured row per pair.
line_tooltip_table = """
<table id="stats_table">
<tr>
    <td>Date</td>
    <td>{0:}</td>
</tr>
<tr style="color:#1E90FF">
    <td>{1:s}</td>
    <td>{2:}</td>
</tr>
<tr style="color:#D62728">
    <td>{3:s}</td>
    <td>{4:}</td>
</tr>
<tr style="color:#2CA02C">
    <td>{5:s}</td>
    <td>{6:}</td>
</tr>
<tr style="color:#FFBB0E">
    <td>{7:s}</td>
    <td>{8:}</td>
</tr>
<tr style="color:#7F7F7F">
    <td>{9:s}</td>
    <td>{10:}</td>
</tr>
</table>
"""

stats_line_tooltip_table_values = HTML()  # table shown when you hover over a Lines marker


In [None]:
# Time (x) axis: date scale, no grid lines, month/day tick labels.
scale_t = DateScale(dateformat=dateformat)
axis_t = Axis(num_ticks=10, tick_format="%m/%d", grid_lines="none", scale=scale_t)

# Value (y) axis: linear scale, solid grid lines, thousands separators.
scale_y = LinearScale()
axis_y = Axis(
    tick_format=",", grid_lines="solid", orientation="vertical", scale=scale_y
)

# Title template: first slot is the value description, second the location.
main_graph_ttl = "COVID-19 {} in {}"

main_graph = Figure(
    title=main_graph_ttl,
    axes=[axis_t, axis_y],
    layout=Layout(width="auto"),
    fig_margin=dict(top=50, bottom=50, left=65, right=20),
    legend_location="top-left",
    legend_style={"stroke-width": 0},
    animation_duration=1000,
)


In [None]:
# Colour of each plotted series, keyed by series name.
curve_colors = dict(
    zip(
        ("Cases", "Deaths", "Recovered", "Active Cases", "Tests"),
        ("dodgerblue", "red", "green", "orange", "gray"),
    )
)
# Series drawn on the main graph, in legend order.
selected = list(curve_colors)


def update_main_graph(
    datasets,
    country_name,
    normalization,
    scale,
    data_type,
    ma,
    n,
    date1,
    date2,
    axis_t,
    axis_y,
    figure,
    init=False,
):
    """Rebuild the Lines mark, axes and title of the main time-series figure.

    Parameters
    ----------
    datasets : object
        Data source exposing ``get_ts_plot(...)`` and a ``name`` attribute
        ("World", "US States" or "US Counties").
    country_name : str
        Entity to plot; for the "World" dataset it is looked up in the
        module-level ``codes_to_countries`` to build the title.
    normalization, scale, data_type : str
        Current "Norm", "Scale" ("Linear" or "Log") and "Type" choices.
    ma : bool
        Whether a moving average is applied.
    n : int
        Moving-average window in days.
    date1, date2 : str
        First and last displayed dates, formatted with ``dateformat``.
    axis_t, axis_y : Axis
        Axes that receive the freshly built scales.
    figure : Figure
        Figure whose axes, marks and title are replaced.
    init : bool, optional
        When true, every series but the last one of ``selected`` is shown;
        otherwise the subset currently displayed on ``main_graph`` is kept.

    Raises
    ------
    ValueError
        If ``scale`` is neither "Linear" nor "Log".
    """
    # Validate first so that an unexpected value fails with a clear message
    # before any axis has been modified.
    if scale == "Log":
        scale_y = LogScale()
    elif scale == "Linear":
        scale_y = LinearScale()
    else:
        raise ValueError(
            "Unsupported scale {!r}; expected 'Linear' or 'Log'".format(scale)
        )

    if init:
        curves_subset = list(range(len(selected) - 1))
    else:
        # NOTE(review): read from the module-level main_graph, not `figure`.
        curves_subset = main_graph.marks[0].curves_subset
    line_col = [curve_colors[c] for c in selected]
    # Series always span the full STDT..ENDT range; date1/date2 only drive
    # the visible window.
    data = [
        datasets.get_ts_plot(
            country_name, data_name, normalization, scale, data_type, ma, n, STDT, ENDT,
        )
        for data_name in selected
    ]
    scale_t = DateScale(
        dateformat=dateformat,
        min=datetime.datetime.strptime(date1, dateformat),
        max=datetime.datetime.strptime(date2, dateformat),
    )
    axis_t.scale = scale_t

    main_mark = Lines(
        scales={"x": scale_t, "y": scale_y},
        display_legend=True,
        labels=selected,
        colors=line_col,
        marker="circle",
        marker_size=30,
        tooltip=stats_line_tooltip_table_values,
        interactions={"hover": "tooltip"},
        curves_subset=curves_subset,
        opacities=[1] * len(selected),
    )
    main_mark.y = data
    axis_y.scale = scale_y
    figure.axes = [axis_t, axis_y]
    main_mark.x = list(
        np.arange(np.datetime64(STDT), np.datetime64(ENDT) + np.timedelta64(1, "D"),)
    )

    # Fit the y range to the visible curves inside [date1, date2], with a 10%
    # margin on both sides.
    ind1 = np.where(main_mark.x == np.datetime64(date1))[0][0]
    ind2 = np.where(main_mark.x == np.datetime64(date2))[0][0]
    visible = np.array(main_mark.y)[curves_subset, ind1 : ind2 + 1]
    main_mark.scales["y"].max = 1.1 * float(np.nanmax(visible))
    main_mark.scales["y"].min = 0.9 * float(np.nanmin(visible))

    main_mark.on_hover(line_hover)
    main_mark.on_legend_hover(legend_hover)
    main_mark.on_legend_click(legend_click_line)
    main_mark.on_background_click(line_background_click)
    figure.marks = [main_mark]

    def graph_title(norm, scale, data_type, ma, n):
        """Describe the plotted quantity for the figure title."""
        res = ""
        if scale == "Log":
            res += "Log"
        if norm == "Per million":
            res += "/1M population"
        if data_type == "Daily change":
            res += " daily change"
        elif data_type == "Daily % change":
            res += " daily % change"
        if ma:
            res += " ({0:d}days m.a)".format(n)
        return res

    val_name = graph_title(normalization, scale, data_type, ma, n)
    if datasets.name == "World":
        figure.title = main_graph_ttl.format(val_name, codes_to_countries[country_name])
    elif datasets.name == "US States":
        figure.title = main_graph_ttl.format(val_name, country_name + " State")
    elif datasets.name == "US Counties":
        figure.title = main_graph_ttl.format(val_name, country_name + " County")

    update_link_download()
    return


In [None]:
def name_val_on_hover(data, norm, data_type, ma, n):
    """Build the tooltip label of a series from the current display options.

    Starts from the series name ``data`` and appends, in this order, the
    per-million marker, the daily-change marker and the moving-average window
    (``n`` days) when the corresponding option is active.
    """
    change_suffixes = {
        "Daily change": " daily change",
        "Daily % change": " daily % change",
    }
    parts = [data]
    if norm == "Per million":
        parts.append("/1M population")
    if data_type in change_suffixes:
        parts.append(change_suffixes[data_type])
    if ma:
        parts.append(" ({0:d}days m.a)".format(n))
    return "".join(parts)


def line_hover(*args):
    """Fill the tooltip table with the values of all series at the hovered marker.

    ``args[1]["data"]["sub_index"]`` is the position of the hovered point
    along the x axis. The tooltip shows that date followed by one
    (label, value) row per entry of ``selected``. Values are printed as
    integers for raw, non-percentage data and with two decimals otherwise.
    """
    try:
        mark = main_graph.marks[0]
        index = args[1]["data"]["sub_index"]
        hovered_date = str(mark.x[index])
        # Values of every series at the hovered date; NaN is shown as 0.
        values = list(
            np.nan_to_num([mark.y[i][index] for i in range(len(selected))])
        )
        name_vals = [
            name_val_on_hover(
                data,
                norm_buttons.value,
                type_buttons.value,
                tab_1_ma_ch.value,
                tab_1_ma_w.value,
            )
            for data in selected
        ]
        if norm_buttons.value == "Values" and type_buttons.value != "Daily % change":
            formatted = ["{0:,d}".format(int(value)) for value in values]
        else:
            formatted = ["{0:,.2f}".format(value) for value in values]

        # Template arguments: date, then label/value alternating.
        res = [hovered_date]
        for label, text in zip(name_vals, formatted):
            res.extend((label, text))
        stats_line_tooltip_table_values.value = css_style + line_tooltip_table.format(
            *res
        )
    except Exception:
        # Best effort: a tooltip that cannot be built is skipped. Unlike a
        # bare `except:`, this lets KeyboardInterrupt and SystemExit propagate.
        pass
    return


def legend_hover(*args):
    """Dim every curve except the one whose legend entry is hovered.

    All opacities go back to 1 when the hovered curve is already the
    highlighted one, or when it is not part of the displayed subset.
    """
    mark = main_graph.marks[0]
    count = len(selected)
    hovered = args[1]["data"]["index"]

    dimmed = [0.3] * count
    dimmed[hovered] = 1

    already_highlighted = (
        np.sum(mark.opacities) < count and mark.opacities[hovered] == 1
    )
    if already_highlighted or hovered not in mark.curves_subset:
        mark.opacities = [1] * count
    else:
        mark.opacities = dimmed
    return


def legend_click_line(*args):
    """Change which curves are displayed when a legend entry is clicked.

    With ``clicked = args[1]["data"]["index"]``:

    * clicked curve not currently shown -> it is added to ``curves_subset``;
    * clicked curve is the only one shown -> all ``len(selected)`` curves are shown;
    * clicked curve shown together with others -> only the clicked curve is kept.

    Afterwards the y scale is refitted to the shown curves over the date
    window of ``date_axis_selector`` (10% margin above the max, 10% below the
    min, ignoring NaNs).
    """
    n_curves = len(selected)
    clicked = args[1]["data"]["index"]
    first, last = date_axis_selector.index
    lines = main_graph.marks[0]
    shown = lines.curves_subset

    if clicked not in shown:
        lines.curves_subset = shown + [clicked]
    elif len(shown) == 1:
        lines.curves_subset = list(range(n_curves))
    else:
        lines.curves_subset = [clicked]

    # rows = displayed curves, columns = selected date window (inclusive)
    window = np.array(lines.y)[lines.curves_subset, first : last + 1]
    lines.scales["y"].max = 1.1 * float(np.nanmax(window))
    lines.scales["y"].min = 0.9 * float(np.nanmin(window))
    return


def line_background_click(*args):
    """Reset the displayed curves and refit the y scale on a background click.

    The displayed subset is reset to curves 0-3 and the y scale is refitted to
    those curves over the date window of ``date_axis_selector`` (10% margin on
    each side, NaNs ignored).
    """
    lines = main_graph.marks[0]
    # NOTE(review): hard-coded to four curves, whereas legend_click_line resets
    # to range(len(selected)) -- confirm whether a fifth curve is excluded on
    # purpose.
    lines.curves_subset = [0, 1, 2, 3]
    start, stop = date_axis_selector.index
    visible = np.array(lines.y)[lines.curves_subset, start : stop + 1]
    lines.scales["y"].max = 1.1 * float(np.nanmax(visible))
    lines.scales["y"].min = 0.9 * float(np.nanmin(visible))
    return


def bqplot_legend_curvesubset_bug():
    """Re-assign ``curves_subset`` via a temporary, different value.

    Workaround for what the original author describes as a bug in the bqplot
    ``Lines`` ``curves_subset`` handling: the subset is first set to a
    different one-element list and then restored, presumably so that the
    change is propagated to the front end (TODO confirm).

    Raises IndexError if ``curves_subset`` is empty.
    """
    lines = main_graph.marks[0]
    wanted = lines.curves_subset
    if len(wanted) != 1:
        temporary = [wanted[0]]
    elif wanted[0] + 1 < 5:  # 5: hard-coded number of curves
        temporary = [wanted[0] + 1]
    else:
        temporary = [wanted[0] - 1]
    lines.curves_subset = temporary
    lines.curves_subset = wanted
    return


In [None]:
# Maps

def map_title(map_name, data, norm, data_type, date, ma, n):
    """Build the title shown above the map.

    Parameters: ``map_name`` and ``data`` are the map and series names,
    ``norm`` adds itself to the title only when it is "Per million",
    ``data_type`` adds a "daily change" / "daily % change" suffix, ``ma``
    (truthy) appends the ``n``-day moving-average note (``n`` must be an int),
    and ``date`` closes the title.
    """
    parts = [map_name, " COVID-19 ", data]
    if norm == "Per million":
        parts.append(" " + norm)
    if data_type == "Daily change":
        parts.append(" daily change")
    elif data_type == "Daily % change":
        parts.append(" daily % change")
    if ma:
        parts.append(" ({0:d}days m.a)".format(n))
    parts.append(" Map, " + date)
    return "".join(parts)


In [None]:
# Hover-table template used when the metric currently displayed on the map is
# not already one of the "total" rows.
# Placeholders: {0} table header, {1} label of the displayed metric,
# {2} its value, {3}-{7} cases / deaths / recovered / active cases / tests
# (integers), {8} population (integer).
table_tmpl_no_duplicate = """
<table id="stats_table">
<thead>
<tr>
<th colspan="2">{0:s}</th>
</tr>
</thead>
<tbody>
<tr>
    <td>{1:s}</td>
    <td>{2:,}</td>
</tr>
<tr style="color:#1E90FF">
    <td>Cases</td>
    <td>{3:,d}</td>
</tr>
<tr style="color:#D62728">
    <td>Deaths</td>
    <td>{4:,d}</td>
</tr>
<tr style="color:#2CA02C">
    <td>Recovered</td>
    <td>{5:,d}</td>
</tr>
<tr style="color:#FFBB0E">
    <td>Active Cases</td>
    <td>{6:,d}</td>
</tr>
<tr style="color:#7F7F7F">
    <td>Tests</td>
    <td>{7:,d}</td>
</tr>
<tr>
    <td>Population</td>
    <td>{8:,d}</td>
</tr>
</tbody>
</table>
"""

# Hover-table template without the extra "displayed metric" row; map_hovering
# uses it when current_val_on_hover returns None and for its zero-filled
# fallback table.
# Placeholders: {0} table header, {1}-{5} cases / deaths / recovered /
# active cases / tests (integers), {6} population (integer).
table_tmpl_duplicate = """
<table id="stats_table">
<thead>
<tr>
<th colspan="2">{0:s}</th>
</tr>
</thead>
<tbody>
<tr style="color:#1E90FF">
    <td>Cases</td>
    <td>{1:,d}</td>
</tr>
<tr style="color:#D62728">
    <td>Deaths</td>
    <td>{2:,d}</td>
</tr>
<tr style="color:#2CA02C">
    <td>Recovered</td>
    <td>{3:,d}</td>
</tr>
<tr style="color:#FFBB0E">
    <td>Active Cases</td>
    <td>{4:,d}</td>
</tr>
<tr style="color:#7F7F7F">
    <td>Tests</td>
    <td>{5:,d}</td>
</tr>
<tr>
    <td>Population</td>
    <td>{6:,d}</td>
</tr>
</tbody>
</table>
"""
# HTML widget that receives the rendered hover table; it is used as the
# tooltip of the map marks.
stats_table = HTML()


In [None]:
def current_val_on_hover(data, norm, scale, data_type, ma, n):
    """Return the label of the metric currently displayed on the map.

    Returns None for the plain "Total" / "Values" combination, otherwise the
    series name ``data`` followed by the normalisation, the daily-change kind
    and, when ``ma`` is truthy, the ``n``-day moving-average note.
    ``scale`` is accepted but not used.
    """
    if data_type == "Total" and norm == "Values":
        return None

    label = data
    if norm == "Per million":
        label += "/1M population"

    if data_type == "Daily change":
        label += " daily change"
    elif data_type == "Daily % change":
        label += " daily % change"

    return label + " ({0:d}days m.a)".format(n) if ma else label


In [None]:
# Grid of "round" numbers used for the map legend boundaries: every product
# m * 10**e with mantissa m in `a` and exponent e in `p`, de-duplicated and
# sorted in ascending order (the result includes 0 and negative values).
# a = [-9, -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
a = [-5, -2, -1, 0, 1, 2, 5]
p = [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
a_ps = np.sort(list(set([i[0] * 10 ** i[1] for i in itertools.product(a, p)])))
# Seven-colour palettes: one colour per legend bin (map_color computes seven
# cut points).
red_color_scheme = [
            "#FFF0CC",
            "#ffa600",
            "#ff6e00",
            "#ff4417",
            "#d31522",
            "#8c000e",
            "#560410",
        ]
green_color_scheme = [
            "white",
            "#C5E8B7",
            "#ABE098",
            "#83D475",
            "#57C84D",
            "#2EB62C",
            "#207f1e",
        ]
gyr_color_scheme = [
            "#069740",
            "#9acd32",
            "#e4d31b",
            "#ffa600",
            "#d31522",
            "#8c000e",
            "#560410",
        ]

# Palette applied to the map's colour scale for each data series.
map_color_scheme = {'Cases': red_color_scheme,
                    'Deaths': red_color_scheme,
                    'Recovered': green_color_scheme,
                    'Active Cases': gyr_color_scheme,
                    'Tests': green_color_scheme,
                   }

def my_round(v, how="down"):
    """Round ``v`` onto the grid of round numbers ``a_ps``.

    Parameters
    ----------
    v : number
        Value to round. NaN is returned unchanged (as NaN).
    how : {"down", "up", "nearest"}
        * "down": largest grid value <= ``v`` (NaN if ``v`` is below the grid);
        * "up": smallest grid value strictly greater than ``v`` (NaN if ``v``
          is at or above the top of the grid);
        * "nearest": whichever of the two neighbours is closer, ties going to
          the upper one. When only one neighbour exists, that one is returned.

    Raises
    ------
    AssertionError
        If ``how`` is not one of the three accepted values.
    """
    assert how in ["nearest", "up", "down"]

    # bisect cannot order NaN, so handle it before searching the grid
    if np.isnan(v):
        return np.nan

    # a_ps is sorted, so the neighbours of v are found by binary search:
    # pos = number of grid values <= v
    pos = bisect.bisect_right(a_ps, v)
    lower = a_ps[pos - 1] if pos > 0 else np.nan
    upper = a_ps[pos] if pos < len(a_ps) else np.nan

    if how == "down":
        return lower
    if how == "up":
        return upper
    # how == "nearest": fall back to the existing neighbour at the grid edges
    if np.isnan(lower):
        return upper
    if np.isnan(upper):
        return lower
    return lower if (v - lower) < (upper - v) else upper


# Element-wise version of my_round (uses the default how="down" unless given).
my_round_vec = np.vectorize(my_round)


def expand_roud_quantiles(q_round_cuts):
    """Make a sequence of rounded cut points strictly increasing.

    The first cut is kept. Each later cut that is not greater than the
    previously emitted one is replaced by the grid value of ``a_ps`` that
    follows the previously emitted cut; other cuts are kept unchanged.

    Raises IndexError for an empty input, or when a replacement is needed
    beyond the last value of ``a_ps``.
    """
    expanded = [q_round_cuts[0]]
    for cut in q_round_cuts[1:]:
        previous = expanded[-1]
        if cut <= previous:
            # duplicate or out-of-order cut: step to the next grid value
            expanded.append(a_ps[bisect.bisect_right(a_ps, previous)])
        else:
            expanded.append(cut)
    return expanded


def map_cuts(arr, cuts):
    """Replace each value of ``arr`` by the lower bound of its bin.

    ``cuts`` must be sorted ascending; a value maps to the largest cut that is
    <= the value. NaN values stay NaN.

    NOTE(review): a value smaller than ``cuts[0]`` maps to ``cuts[-1]``
    (index -1 wraps around); callers appear to avoid this by deriving
    ``cuts[0]`` from the minimum of the data -- confirm.
    """
    binned = []
    for value in arr:
        if np.isnan(value):
            binned.append(np.nan)
        else:
            position = bisect.bisect_right(cuts, value) - 1
            binned.append(cuts[position])
    return binned


def cuts_to_str(cuts):
    """Format cut points as short strings for the map legend.

    Magnitudes >= 1e6 are shown in millions ("M") and >= 1000 in thousands
    ("K"), both truncated to an integer; magnitudes >= 1 are shown as integers;
    smaller ones are rounded to 1, 2 or 3 decimals depending on their size.
    """

    def label(v):
        size = abs(v)
        if size >= 1e6:
            return str(int(v / 1e6)) + "M"
        if size >= 1000:
            return str(int(v / 1000)) + "K"
        if size >= 1:
            return str(int(v))
        if size >= 0.1:
            return str(round(v, 1))
        if size >= 0.01:
            return str(round(v, 2))
        return str(round(v, 3))

    return [label(v) for v in cuts]


def cuts_to_range(cuts, pct=False):
    """Turn formatted cut points into legend range labels.

    ``cuts`` is a list of non-empty strings. Consecutive cuts become
    "low - high" labels and the last cut becomes an open-ended "last+" label;
    negative cuts are wrapped in parentheses. With ``pct`` truthy a "%" sign
    follows every bound.

    Raises IndexError if ``cuts`` is empty.
    """
    shown = ["(" + c + ")" if c[0] == "-" else c for c in cuts]
    pairs = zip(shown, shown[1:])
    if pct:
        labels = [low + "% - " + high + "%" for low, high in pairs]
        labels.append(shown[-1] + "%+")
    else:
        labels = [low + " - " + high for low, high in pairs]
        labels.append(shown[-1] + "+")
    return labels


def map_color(dataset, data, norm, scale, data_type, ma, n, date):
    """Compute the legend bin of every region of a map for one date.

    The series selected by ``data``/``norm``/``scale``/``data_type``/``ma``/``n``
    is read from ``dataset`` at ``date``. Regions whose ID or value cannot be
    looked up are skipped; zeros are treated as missing (NaN). Seven cut
    points are taken at fixed quantiles of the remaining values, rounded down
    onto the grid of round numbers and made strictly increasing.

    Side effect: the ordinal colour scale of the map associated with
    ``dataset.name`` gets the range labels as its domain and the palette of
    ``data`` as its colours.

    Returns
    -------
    dict
        Region ID -> range label (string), or NaN for regions without a value.
    """
    IDs = []
    colors = []
    values = dataset.get_ts(None, data, norm, scale, data_type, ma, n).loc[date]
    for country_name in dataset.get_columns(data):
        try:
            region_id = dataset.get_ID(country_name)
            value = values[country_name]
            color_value = math.nan if value == 0 else value
        except Exception:
            # region without an ID or without data: leave it uncoloured
            continue
        # append only once both lookups succeeded so the lists stay aligned
        IDs.append(region_id)
        colors.append(color_value)

    # uneven quantiles: finer resolution towards the top of the distribution
    q_cuts = np.nanquantile(colors, np.array([0, 0.2, 0.4, 0.55, 0.7, 0.85, 0.95]))
    q_round_cuts = expand_roud_quantiles(my_round_vec(q_cuts))
    pct = data_type == "Daily % change"
    cuts_str_format = cuts_to_range(cuts_to_str(q_round_cuts), pct)
    str_cuts = dict(zip(q_round_cuts, cuts_str_format))
    colors = map_cuts(colors, q_round_cuts)
    colors_to_range = [str_cuts[c] if not np.isnan(c) else np.nan for c in colors]
    color = dict(zip(IDs, colors_to_range))
    selected_index = dict_ID_to_map[dataset.name]
    dict_maps[selected_index]["scale"].domain = cuts_str_format
    dict_maps[selected_index]["scale"].colors = map_color_scheme[data]
    return color


def select_date(obj):
    """Recolour the current map and refresh the tables for a newly chosen date.

    ``obj`` is a change dictionary; the new date is read from ``obj["new"]``.
    All other settings are read from the tab-1 controls.
    """
    date = obj["new"]
    norm = norm_buttons.value
    scale = scale_buttons.value
    data_type = type_buttons.value
    data = data_buttons.value
    ma = tab_1_ma_ch.value
    n = tab_1_ma_w.value

    map_id = dict_ID_to_map[map_buttons.value]
    color = map_color(dict_datasets[map_id], data, norm, scale, data_type, ma, n, date)
    main_map.title = map_title(
        dict_maps[map_id]["name"], data, norm, data_type, date, ma, n
    )
    dict_maps[map_id]["map"].color = color
    update_tables(data, norm, scale, data_type, ma, n, date)
    return


def map_click(obj, value):
    """Show the clicked country/state/county in the main graph.

    ``value["data"]["id"]`` is the ID of the clicked map element; it is
    translated to a name through the dataset of the currently selected map,
    and the main graph is redrawn for that name with the current tab-1
    settings and date window.

    The update is best-effort: any ordinary exception (for instance an element
    without data) leaves the graph unchanged.
    """
    try:
        _id = value["data"]["id"]
        selected_index = dict_ID_to_map[map_buttons.value]
        datasets = dict_datasets[selected_index]
        name = datasets.get_name(_id)
        normalization, scale, data_type = (
            norm_buttons.value,
            scale_buttons.value,
            type_buttons.value,
        )
        ma, n = tab_1_ma_ch.value, tab_1_ma_w.value
        date_sel_index = date_axis_selector.index
        date1 = str(date_axis_selector.options[date_sel_index[0]])
        date2 = str(date_axis_selector.options[date_sel_index[1]])
        update_main_graph(
            datasets,
            name,
            normalization,
            scale,
            data_type,
            ma,
            n,
            date1,
            date2,
            axis_t,
            axis_y,
            main_graph,
        )
        bqplot_legend_curvesubset_bug()
    except Exception:
        # Deliberate best-effort: ignore clicks that cannot be resolved, but
        # do not swallow KeyboardInterrupt/SystemExit.
        pass
    return


def map_background_click(*args):
    """Reset the main graph to the "WLD" series of the first dataset.

    Uses the current tab-1 settings and the date window of
    ``date_axis_selector``; the callback arguments are ignored.
    """
    normalization = norm_buttons.value
    scale = scale_buttons.value
    data_type = type_buttons.value
    ma = tab_1_ma_ch.value
    n = tab_1_ma_w.value

    window = date_axis_selector.index
    date1 = str(date_axis_selector.options[window[0]])
    date2 = str(date_axis_selector.options[window[1]])

    update_main_graph(
        dict_datasets[0], "WLD", normalization, scale, data_type,
        ma, n, date1, date2, axis_t, axis_y, main_graph,
    )
    bqplot_legend_curvesubset_bug()
    return


def map_hovering(obj, value):
    """Show the statistics table of the hovered country/state/county.

    ``value["data"]["id"]`` is the ID of the hovered map element. The table
    lists the totals of every series in ``selected`` plus the population at
    the date of ``date_selector``; when the map displays something other than
    plain totals, an extra first row shows the displayed metric.

    Fallbacks, in order:
    * the element cannot be resolved to a name -> empty tooltip;
    * a value cannot be read -> table filled with zeros and the population;
    * the population cannot be read either -> empty tooltip.
    """
    _id = value["data"]["id"]
    date = date_selector.value
    selected_index = dict_ID_to_map[map_buttons.value]
    datasets = dict_datasets[selected_index]
    try:
        name = datasets.get_name(_id)
        # the world map is keyed by codes; show the country name instead
        name_table = codes_to_countries[name] if selected_index == 0 else name
    except Exception:
        stats_table.value = ""
    else:
        try:
            norm, scale, data_type = (
                norm_buttons.value,
                scale_buttons.value,
                type_buttons.value,
            )
            ma, n = tab_1_ma_ch.value, tab_1_ma_w.value
            data = data_buttons.value
            current_val = current_val_on_hover(data, norm, scale, data_type, ma, n)
            output_vals = [name_table]
            for c in selected:
                output_vals.append(
                    int(
                        datasets.get_value(
                            name, c, "Values", "Linear", "Total", ma, n, date
                        )
                    )
                )
            output_vals.append(int(datasets.get_population(name)))
            if current_val is None:
                stats_table.value = css_style + table_tmpl_duplicate.format(
                    *tuple(output_vals)
                )
            else:
                val = float(
                    datasets.get_value(name, data, norm, scale, data_type, ma, n, date)
                )
                if data_type == "Daily change" and norm == "Values":
                    val = int(val)
                else:
                    val = round(val, 2)
                # insert the displayed metric right after the header
                output_vals = [output_vals[0]] + [current_val, val] + output_vals[1:]
                stats_table.value = css_style + table_tmpl_no_duplicate.format(
                    *tuple(output_vals)
                )
        except Exception:
            # e.g. missing/NaN value: fall back to a zero-filled table
            try:
                population = int(datasets.get_population(name))
                stats_table.value = css_style + table_tmpl_duplicate.format(
                    *tuple([name_table] + [0] * 5 + [population])
                )
            except Exception:
                stats_table.value = ""
    main_map.tooltip = stats_table
    return


In [None]:
def event(*args):
    """Refresh map, main graph and tables after a tab-1 control changed.

    ``args[0]`` is the change dictionary. Changes coming from a
    ``ToggleButton`` are only handled when the new value is True; changes from
    any other widget are always handled.

    The region shown in the main graph is recovered from the graph title: a
    title ending in "State" or "County" selects the second or third dataset
    and the text after " in " (minus that suffix) as the region; any other
    title is looked up in ``countries_to_codes``.
    """
    change = args[0]
    if type(change["owner"]).__name__ == "ToggleButton" and not (
        change["new"] == True
    ):
        return

    date = date_selector.value
    norm = norm_buttons.value
    scale = scale_buttons.value
    data_type = type_buttons.value
    data = data_buttons.value
    ma, n = tab_1_ma_ch.value, tab_1_ma_w.value

    # the window-size box is only shown while the moving average is enabled
    tab_1_ma_w.layout.visibility = "visible" if ma else "hidden"

    selected_index = dict_ID_to_map[map_buttons.value]
    datasets = dict_datasets[selected_index]
    color = map_color(datasets, data, norm, scale, data_type, ma, n, date)
    main_map.title = map_title(
        dict_maps[selected_index]["name"], data, norm, data_type, date, ma, n
    )
    dict_maps[selected_index]["map"].color = color

    date_sel_index = date_axis_selector.index
    date1 = str(date_axis_selector.options[date_sel_index[0]])
    date2 = str(date_axis_selector.options[date_sel_index[1]])

    title = main_graph.title
    last_word = title.split(" ")[-1]
    if last_word == "State":
        graph_dataset = dict_datasets[1]
        region = title.split(" in ")[-1][:-6]  # strip trailing " State"
    elif last_word == "County":
        graph_dataset = dict_datasets[2]
        region = title.split(" in ")[-1][:-7]  # strip trailing " County"
    else:
        graph_dataset = dict_datasets[0]
        region = countries_to_codes[title.split(" in ")[-1]]

    update_main_graph(
        graph_dataset, region, norm, scale, data_type,
        ma, n, date1, date2, axis_t, axis_y, main_graph,
    )
    bqplot_legend_curvesubset_bug()
    update_tables(data, norm, scale, data_type, ma, n, date)
    return


In [None]:
# Play control for stepping through integer positions.
# NOTE(review): max=1 looks like a placeholder that is presumably raised once
# the list of dates is known -- confirm where this happens.
play = Play(
    value=0,
    min=0,
    max=1,
    step=1,
    description="Press play",
    disabled=False,
    layout=Layout(width="auto", height="auto", min_width="150px"),
)


def update_index_date(*args):
    """Move ``date_selector`` to the option at the position in ``args[0]["new"]``.

    ``args[0]`` is a change dictionary whose "new" entry is an index into
    ``date_selector.options`` (presumably emitted by the ``play`` widget --
    TODO confirm against the observer registration).
    """
    position = args[0]["new"]
    date_selector.value = date_selector.options[position]
    return


In [None]:
# Slider selecting the date displayed on the map and in the tables.
# NOTE(review): the two options look like placeholders, presumably replaced
# once the data is loaded -- confirm.
date_selector = SelectionSlider(
    options=["2020-01-22", "2020-01-23"],
    description="Date",
    disabled=False,
    value="2020-01-22",
    layout=Layout(width="100%", height="auto", min_width="200px"),
    style={"description_width": "initial"},
)


In [None]:
# Fixed-height (40px) row holding the date slider and the play control.
date_tab_1_box = HBox(
    [date_selector, play],
    layout=Layout(
        width="100%",
        min_width="200px",
        height="40px",
        min_height="40px",
        max_height="40px",
        overflow="auto",
    ),
)


In [None]:
# Range slider whose (start, end) index pair defines the date window used to
# draw and rescale the main graph.
# NOTE(review): the single option looks like a placeholder -- confirm where
# the real dates are assigned.
date_axis_selector = SelectionRangeSlider(
    options=["2020-01-22"], index=(0, 0), layout=Layout(width="auto", height="auto"),
)


def update_zoom(*args):
    """Apply a new x range to the main graph and refit its y scale.

    ``args[0]["new"]`` is a (min, max) pair assigned to the scale of the first
    axis. The scale of the second axis is then fitted to the displayed curves
    over the index window of ``date_axis_selector`` (10% margin on each side,
    NaNs ignored). Assumes the mark's ``y`` supports 2-D numpy indexing.
    """
    min_x, max_x = args[0]["new"]
    x_scale = main_graph.axes[0].scale
    x_scale.min = min_x
    x_scale.max = max_x

    lines = main_graph.marks[0]
    first = date_axis_selector.index[0]
    last = date_axis_selector.index[1]
    # rows = displayed curves, columns = selected date window (inclusive)
    window = lines.y[lines.curves_subset, first : last + 1]

    main_graph.axes[1].scale.min = 0.9*float(np.nanmin(window))
    main_graph.axes[1].scale.max = 1.1*float(np.nanmax(window))
    return


In [None]:
# Re-run `event` whenever one of the tab-1 controls (data, normalisation,
# scale, type, moving-average checkbox, window size) changes value.
data_buttons.add_observe(event, "value")
norm_buttons.add_observe(event, "value")
scale_buttons.add_observe(event, "value")
type_buttons.add_observe(event, "value")
tab_1_ma_ch.observe(event, "value")
tab_1_ma_w.observe(event, "value")


In [None]:
# World Map

# Projection, ordinal colour scale (initially the red palette; map_color
# replaces domain and colours) and colour axis of the world map.
sc_geo = Mercator(center=(-20, 50))
sc_c1 = OrdinalColorScale(
    colors=red_color_scheme,
)
caxis = ColorAxis(scale=sc_c1)
# All maps share the same tooltip widget.
map_tt = stats_table

# World map mark; regions without a colour value are drawn grey.
wm = Map(
    map_data=topo_load("map_data/WorldMap.json"),
    scales={"projection": sc_geo, "color": sc_c1},
    colors={"default_color": "Grey"},
    tooltip=map_tt,
    hovered_styles={"hovered_fill": "White"},
)


In [None]:
# World map interactions: hover -> statistics tooltip, click on a country ->
# main graph for that country, click on the background -> reset the graph.
wm.on_hover(map_hovering)
wm.on_element_click(map_click)
wm.on_background_click(map_background_click)


In [None]:
# US States Map

# Projection, colour scale, colour axis and mark of the US states map; same
# configuration as the world map apart from the projection and map data.
sc_geo_US = AlbersUSA()
sc_c1_US = OrdinalColorScale(
    colors=red_color_scheme,
)
caxis_US = ColorAxis(scale=sc_c1_US)
wm_US = Map(
    map_data=topo_load("map_data/USStatesMap.json"),
    scales={"projection": sc_geo_US, "color": sc_c1_US},
    colors={"default_color": "Grey"},
    tooltip=map_tt,
    hovered_styles={"hovered_fill": "White"},
)


In [None]:
# US states map interactions (no background-click handler is registered here).
wm_US.on_hover(map_hovering)
wm_US.on_element_click(map_click)


In [None]:
# Projection, colour scale, colour axis and mark of the US counties map.
sc_geo_US_counties = AlbersUSA()
sc_c1_US_counties = OrdinalColorScale(
    colors=red_color_scheme,
)
caxis_US_counties = ColorAxis(scale=sc_c1_US_counties)

wm_US_counties = Map(
    map_data=topo_load("map_data/USCountiesMap.json"),
    scales={"projection": sc_geo_US_counties, "color": sc_c1_US_counties},
    colors={"default_color": "Grey"},
    tooltip=map_tt,
    hovered_styles={"hovered_fill": "White"},
)

# Registry of the three maps, keyed by integer index; the same index is used
# to look up the matching dataset in dict_datasets.
dict_maps = {
    0: {
        "map": wm,
        "name": "World",
        "projection": sc_geo,
        "scale": sc_c1,
        "axis": caxis,
    },
    1: {
        "map": wm_US,
        "name": "US States",
        "projection": sc_geo_US,
        "scale": sc_c1_US,
        "axis": caxis_US,
    },
    2: {
        "map": wm_US_counties,
        "name": "US Counties",
        "projection": sc_geo_US_counties,
        "scale": sc_c1_US_counties,
        "axis": caxis_US_counties,
    },
}
# Map name (value of map_buttons / dataset.name) -> index into dict_maps.
dict_ID_to_map = dict(zip(["World", "US States", "US Counties"], [0, 1, 2]))

# US counties map interactions (no background-click handler is registered here).
wm_US_counties.on_hover(map_hovering)
wm_US_counties.on_element_click(map_click)


In [None]:
# Figure hosting the currently selected map; it starts with the world map and
# update_map swaps its marks and colour axis when another map is chosen.
main_map = Figure(
    marks=[wm],
    axes=[caxis],
    title="",
    fig_margin={"top": 50, "bottom": 50, "left": 0, "right": 0},
    layout=Layout(width="99%", height="99%"),
)


In [None]:
def update_map(*args):
    """Switch the displayed map after the map selector changed.

    ``args[0]`` is the change dictionary. Changes coming from a
    ``ToggleButton`` are only handled when the new value is True; changes from
    any other widget are always handled. The newly selected map is coloured
    for the current settings and date, then installed in ``main_map`` together
    with its colour axis and an updated title.
    """
    change = args[0]
    if type(change["owner"]).__name__ == "ToggleButton" and not (
        change["new"] == True
    ):
        return

    map_name = map_buttons.value
    ID = dict_ID_to_map[map_name]
    mark = dict_maps[ID]["map"]
    date = date_selector.value
    norm = norm_buttons.value
    scale = scale_buttons.value
    data_type = type_buttons.value
    data = data_buttons.value
    ma, n = tab_1_ma_ch.value, tab_1_ma_w.value

    color = map_color(dict_datasets[ID], data, norm, scale, data_type, ma, n, date)
    main_map.title = map_title(map_name, data, norm, data_type, date, ma, n)
    mark.color = color
    main_map.axes = [dict_maps[ID]["axis"]]
    main_map.marks = [mark]
    return


map_buttons.add_observe(update_map, "value")


In [None]:
# Tables

# Uncomment if using white theme
# css_style_2 = """
# <head>
#     <meta http-equiv="Content-Type" content="text/html; charset=utf-8">
#     <style type="text/css" media="screen">
#         #sorted_table {
#           font-family: "Trebuchet MS", Arial, Helvetica, sans-serif;
#           font-size: 12px;
#           border-collapse: collapse;
#           border-spacing: 0;
#           width: 100%;
#           text-align: left;
#         }
#         #sorted_table td, #sorted_table th {
#           border: 1px solid #ddd;
#           padding: 0;
#         }
#         #sorted_table tbody td {
#           font-size: 12px;
#           font-weight: bold;
#         }
#         #sorted_table tr:nth-child(even) {
#           background-color: #f2f2f2;
#         }
#         #sorted_table th {
#           padding-top: 1px;
#           padding-bottom: 1px;
#           text-align: center;
#           background-color: #003366;
#           color: white;
#         }
#     </style>
# </head>
# <body>
# """

# Style sheet prepended to the sorted tables (element id "sorted_table").
# NOTE(review): unlike the commented-out white-theme variant above, this
# version does not open a <body>, yet sorted_table_html appends "</body>".
css_style_2 = """
<head>
    <meta http-equiv="Content-Type" content="text/html; charset=utf-8">
    <style type="text/css" media="screen">
        #sorted_table {
          width: 100%;
          text-align: left;
          color: #ff8b0e;
        }
        #sorted_table th {
           text-align: left;
           color: #ff8b0e;
        }
    </style>
</head>
"""


def sorted_table_html(
    dataset,
    data,
    norm,
    scale,
    data_type,
    ma,
    n,
    date,
    ascending,
    K,
    dict_codes,
    column_name,
):
    """Render the sorted ranking of one dataset at ``date`` as an HTML table.

    The ranking comes from ``dataset.get_ts_sort``. Its first column is
    renamed to ``column_name`` and the ``date`` column to the label returned
    by ``name_val_on_hover``. When ``dict_codes`` is given (and non-empty) the
    first column is translated through it (KeyError for unknown codes). Rows
    with missing values are dropped. Values are shown as thousands-separated
    integers for raw, non-percentage data and with two decimals otherwise.

    Returns the style sheet, the table (id "sorted_table") and a closing
    ``</body>`` tag as one string.
    """
    ranking = dataset.get_ts_sort(
        data, norm, scale, data_type, ma, n, date, ascending, K
    ).reset_index()
    value_column = name_val_on_hover(data, norm, data_type, ma, n)

    first_column = ranking.columns.values[0]
    ranking.rename(
        {first_column: column_name, date: value_column}, axis=1, inplace=True
    )
    if dict_codes:
        ranking[column_name] = [
            dict_codes[code] for code in ranking[column_name].values
        ]
    ranking.dropna(inplace=True)

    show_as_integers = norm == "Values" and data_type != "Daily % change"
    if show_as_integers:
        ranking[value_column] = ranking[value_column].astype(int)
        cell_format = "{:,d}"
    else:
        cell_format = "{:,.02f}"

    table_html = ranking.to_html(
        index=False,
        header=True,
        formatters={value_column: cell_format.format},
        notebook=True,
        table_id="sorted_table",
    )
    return css_style_2 + table_html + """</body>"""


def sorted_tables_html(data, norm, scale, data_type, ma, n, date, ascending, K):
    """Render the sorted tables of the first two datasets.

    Returns a pair of HTML strings: the first dataset with its codes
    translated through ``codes_to_countries`` under the heading "Countries",
    and the second dataset, untranslated, under the heading "US States".
    """
    shared = (data, norm, scale, data_type, ma, n, date, ascending, K)
    countries_html = sorted_table_html(
        dict_datasets[0], *shared, codes_to_countries, "Countries"
    )
    states_html = sorted_table_html(dict_datasets[1], *shared, None, "US States")
    return countries_html, states_html


def update_tables(data, norm, scale, data_type, ma, n, date):
    """Refresh both sorted tables for the given settings and date.

    The tables are always sorted in descending order (``ascending=False``)
    and ``K=None`` is passed through to the ranking.
    """
    descending_order, no_limit = False, None
    table_1.value, table_2.value = sorted_tables_html(
        data, norm, scale, data_type, ma, n, date, descending_order, no_limit
    )
    return


# HTML widgets filled by update_tables: table_1 receives the countries table,
# table_2 the US states table.
table_1 = HTML()
table_2 = HTML()


In [None]:
# Infection Map tab

# 2x2 grid of the tab: controls on the top row (95px high), map on the bottom
# left (60% of the width) and a Graph/Table tab on the bottom right.
grid_1 = GridspecLayout(2, 2)
grid_1.layout.height = "99%"
grid_1.layout.width = "100%"
grid_1.layout.overflow = "auto"
# The two sorted tables sit side by side in a 1x2 grid.
table_grid = GridspecLayout(1, 2)
table_grid.layout.width = "100%"
table_grid.layout.height = "100%"
table_1.layout.width = "auto"
table_1.layout.overflow = "auto"
table_2.layout.overflow = "auto"
table_2.layout.width = "auto"
table_grid[0, 0] = table_1
table_grid[0, 1] = table_2
# Main graph stacked above its date-range slider.
center_right_panel = VBox([main_graph, date_axis_selector])
center_right_panel.layout.height = "100%"
tab_graph_table = Tab(
    _titles=dict(zip([0, 1], ["Graph", "Table"])),
    children=[center_right_panel, table_grid],
    layout=Layout(height="99%", min_height="400px"),
)
grid_1[0, 0] = HBox([cat_tab_1_buttons])
grid_1[0, 1] = HBox([date_tab_1_box])
main_graph.layout.height = "99%"
grid_1[1, 0] = VBox([box_map, main_map])
grid_1[1, 1] = tab_graph_table
grid_1.layout.align_items = "stretch"
grid_1.layout.grid_template_rows = "95px auto"
grid_1.layout.grid_template_columns = "60% auto"


In [None]:
# Rebased Graph Tab

# Buttons and Multiple selectors
# NOTE(review): this value is overwritten with "125px" further down before
# any use visible in this cell, so this first assignment looks dead.
min_description_width_2 = "140px"

# Series plotted on the rebased graph.
rebased_graph_data_button = Toggle_Buttons(
    options=["Cases", "Deaths", "Recovered", "Active Cases", "Tests"],
    value="Cases",
    description="Data",
    min_button_width=min_button_width_data,
    min_description_width=min_description_width_1,
)

# Raw values or values per million inhabitants.
rebased_graph_norm_button = Toggle_Buttons(
    options=["Values", "Per million"],
    value="Values",
    description="Norm",
    min_button_width=min_button_width_norm,
    min_description_width=min_description_width_1,
)

# Cumulative total, daily change or daily percentage change.
rebased_graph_type_button = Toggle_Buttons(
    options=["Total", "Daily change", "Daily % change"],
    value="Total",
    description="Type",
    min_button_width=min_button_width_type,
    min_description_width=min_description_width_1,
)

# Threshold slider; created hidden.
thr_val_slider = IntSlider(
    description="Threshold",
    value=1000,
    min=10,
    max=5000,
    step=100,
    style={"description_width": "initial"},
    layout=Layout(width="340px", visibility="hidden"),
)

# y-axis scale of the rebased graph; defaults to logarithmic.
plot_scale_button = Toggle_Buttons(
    options=["Linear", "Log"],
    value="Log",
    description="Scale",
    min_button_width=min_button_width_scale,
    min_description_width=min_description_width_1,
)
# Wider description column for the "Threshold ..." button rows below.
min_description_width_2 = "125px"
# Series the threshold applies to ("Active Cases" is not offered here).
rebased_graph_thr_data_button = Toggle_Buttons(
    options=["Cases", "Deaths", "Recovered", "Tests"],
    value="Cases",
    description="Threshold Data",
    min_button_width=min_button_width_data,
    min_description_width=min_description_width_2,
)

# Normalisation the threshold applies to.
rebased_graph_thr_norm_button = Toggle_Buttons(
    options=["Values", "Per million"],
    value="Values",
    description="Threshold Norm",
    min_button_width=min_button_width_norm,
    min_description_width=min_description_width_2,
)

# Moving-average switch and its window size (1-14 days, default 7); the
# window box is created hidden.
tab_2_ma_ch = Checkbox(
    description="Moving Average",
    value=False,
    style={"description_width": "initial"},
    layout=Layout(min_width="200px", width="200px", overflow="auto"),
)
tab_2_ma_w = BoundedIntText(
    value=7,
    min=1,
    max=14,
    description="Window size (in days)",
    style={"description_width": "initial"},
    layout=Layout(
        min_width="200px", width="200px", overflow="auto", visibility="hidden"
    ),
)
# Fixed-size column holding the two moving-average controls.
ma_box_tab_2 = VBox(
    [tab_2_ma_ch, tab_2_ma_w],
    layout=Layout(
        min_width="205px",
        width="205px",
        max_width="205px",
        height="74px",
        min_height="74px",
        max_height="74px",
        overflow="auto",
    ),
)


In [None]:
# Initial option lists of the two selectors.
# NOTE(review): these single-entry lists look like placeholders, presumably
# replaced once the data is loaded -- confirm.
top_countries_names = ["United States"]
states = ["New York"]

# Multi-selection lists of countries and US states; "None" is the initial
# selection of both.
countries_selector = SelectMultiple(
    options=["None"] + top_countries_names,
    value=["None"],
    layout=Layout(width="auto", height="100%", min_height="50px"),
)

states_selector = SelectMultiple(
    options=["None"] + states,
    value=["None"],
    layout=Layout(width="auto", height="100%", min_height="50px"),
)

# "Calendar Time" switch, checked by default.
calendar_time = Checkbox(
    description="Calendar Time",
    value=True,
    style={"description_width": "initial"},
    layout=Layout(width="auto", overflow="auto"),
)

# Fixed-size (565px x 138px) column with the data / norm / type / scale rows.
cat_tab_2_buttons = VBox(
    [
        rebased_graph_data_button,
        rebased_graph_norm_button,
        rebased_graph_type_button,
        plot_scale_button,
    ],
    layout=Layout(
        width="565px",
        height="138px",
        min_width="565px",
        max_width="565px",
        min_height="138px",
        max_height="138px",
        overflow="auto",
    ),
)
# Bordered row: threshold slider and calendar-time switch on the left,
# moving-average controls on the right.
thr_tab_2_buttons = HBox(
    [VBox([thr_val_slider, calendar_time], layout=Layout(width="auto")), ma_box_tab_2],
    layout=Layout(
        width="565px",
        height="80px",
        min_width="565px",
        max_width="565px",
        min_height="80px",
        max_height="80px",
        overflow="auto",
        border="solid #ff8b0e",
        margin="10px 0 10px 0",
    ),
)
# Threshold data / norm rows; created hidden.
cat_thr_tab_2_buttons = VBox(
    [rebased_graph_thr_data_button, rebased_graph_thr_norm_button,],
    layout=Layout(
        width="565px",
        height="74px",
        min_width="565px",
        max_width="565px",
        min_height="74px",
        max_height="74px",
        overflow="auto",
        visibility="hidden",
    ),
)


In [None]:
# Fixed-size (565px x 315px) column stacking the three control groups of the
# rebased-graph tab.
reb_top_right_buttons = VBox(
    [cat_tab_2_buttons, thr_tab_2_buttons, cat_thr_tab_2_buttons],
    layout=Layout(
        width="565px",
        min_width="565px",
        max_width="565px",
        overflow="auto",
        height="315px",
        min_height="315px",
        max_height="315px",
    ),
)


In [None]:
# Continent lookup table, indexed by its second column; the index values of
# the rows matching each "CC" code form the per-continent lists below.
# fillna("NA") restores the "NA" (North America) code, which pandas parses as
# a missing value by default.
df = pd.read_csv(FOLDER_WORLD + "continents.csv", sep=";", index_col=1).fillna("NA")
europe = list(df[df["CC"] == "EU"].index.values)
africa = list(df[df["CC"] == "AF"].index.values)
asia = list(df[df["CC"] == "AS"].index.values)
# "EGY" is excluded from the Asian list; list.remove raises ValueError if the
# file does not list it under "AS".
asia.remove("EGY")
na = list(df[df["CC"] == "NA"].index.values)
sa = list(df[df["CC"] == "SA"].index.values)
oc = list(df[df["CC"] == "OC"].index.values)


In [None]:
# Labels of the predefined-selection buttons, in display order: the six
# continents first, then the "Top10" rankings.
presaved_sel_options = [
    "Europe",
    "Asia",
    "North America",
    "South America",
    "Africa",
    "Oceania",
    "Top10 (Cases)",
    "Top10 (Deaths)",
    "Top10 (Cases/1M)",
    "Top10 (Deaths/1M)",
    "Top10 (Cases daily chg)",
    "Top10 (Deaths daily chg)",
]
# Button label -> list of codes to select. Every option starts with its own
# empty list so that no button is left without an entry (the previous zip
# stopped after ten lists and silently dropped the two "daily chg" options,
# which made on_click_action raise KeyError for them); the continent entries
# are then filled in.
presaved_sel_lists = {option: [] for option in presaved_sel_options}
presaved_sel_lists.update(
    zip(presaved_sel_options, [europe, asia, na, sa, africa, oc])
)


def preselection(presaved_sel_options, class_name, action):
    """Create one button per predefined selection.

    Each button is labelled with its option, gets a tooltip ("Top10
    most-affected countries (...)" for Top10 options, "Countries in ..."
    otherwise), is tagged with the DOM class ``class_name`` and calls
    ``action`` when clicked.

    Returns a dict mapping each option to its button, in option order.
    """
    buttons = {}
    for option in presaved_sel_options:
        if "Top10" in option:
            tooltip = "Top10 most-affected countries (" + option.split("(")[-1]
        else:
            tooltip = "Countries in " + option
        button = Button(description=option, button_style="", tooltip=tooltip)
        button.add_class(class_name)
        button.on_click(action)
        buttons[option] = button
    return buttons


def on_click_action(val):
    """Apply a predefined selection to the selectors of the clicked button's tab.

    Args:
        val: the clicked button. Its ``description`` is the key into
            ``presaved_sel_lists`` and its DOM classes ("Rebased", "Heat map"
            or "Custom") identify the tab whose selectors are updated.

    The country selector of that tab receives the names of all listed codes
    that are known to ``codes_to_countries`` and present in the "Cases" columns
    of ``dict_datasets[0]``; the states selector is reset to ``["None"]``.
    """
    # An option without a stored list selects nothing instead of raising.
    codes = presaved_sel_lists.get(val.description, [])
    selection = []
    if codes:
        # Loop-invariant: look the available columns up once.
        available = dict_datasets[0].get_columns("Cases")
        selection = [
            codes_to_countries[c]
            for c in codes
            if c in codes_to_countries and c in available
        ]
    classes = val._dom_classes
    if "Rebased" in classes:
        states_selector.value = ["None"]
        countries_selector.value = selection
    elif "Heat map" in classes:
        dna_states_selector.value = ["None"]
        dna_countries_selector.value = selection
    elif "Custom" in classes:
        free_states_selector.value = ["None"]
        free_countries_selector.value = selection
    return


In [None]:
# Predefined-selection buttons of the rebased tab; the "Rebased" class is what
# on_click_action uses to route a click to this tab's selectors.
ps_tab2 = preselection(presaved_sel_options, "Rebased", on_click_action)
# Buttons are laid out in rows of four.
hb_tab2 = VBox(
    [
        HBox([ps_tab2[v] for v in presaved_sel_options[i : i + 4]])
        for i in range(0, len(presaved_sel_options), 4)
    ]
)
# selected_index=None: the accordion starts collapsed.
accordion_tab2 = Accordion(
    children=[hb_tab2], selected_index=None, layout=Layout(width="auto", height="auto")
)
accordion_tab2.set_title(0, "Predefined selections")


In [None]:
# Titled country / US-state selectors of the rebased tab.
reb_countries_selector_title = HTML('<p style="text-align:center">Countries</p>')
reb_states_selector_title = HTML('<p style="text-align:center">US States</p>')
countries_selector_box = VBox(
    [reb_countries_selector_title, countries_selector], layout=Layout(width="100%")
)
states_selector_box = VBox(
    [reb_states_selector_title, states_selector], layout=Layout(width="100%")
)
# 2 x 5 grid: the accordion spans the whole first row, the two selectors sit in
# columns 1 and 3 of the second row.
tab_2_box_selectors = GridspecLayout(
    2,
    5,
    layout=Layout(
        width="555px",
        align_items="stretch",
        min_height="60px",
        overflow="auto",
        border="solid #ff8b0e",
    ),
)
tab_2_box_selectors[0, :] = accordion_tab2
tab_2_box_selectors[1, 1] = countries_selector_box
tab_2_box_selectors[1, 3] = states_selector_box
# Row heights for the collapsed accordion; accordion_adapt_height changes the
# first value when the accordion is opened.
tab_2_box_selectors.layout.grid_template_rows = "50px auto"


def accordion_adapt_height(val):
    """Resize the first grid row of the active tab's selector box.

    Registered as an observer of an accordion's ``selected_index``. The tab is
    taken from ``outer_tab.selected_index`` (1, 2 or 3); any other index is
    ignored. ``val.new == 0`` means the accordion panel is open.
    """
    expanded = val.new == 0
    active_tab = outer_tab.selected_index
    if active_tab == 1:
        rows = "175px auto" if expanded else "50px auto"
        tab_2_box_selectors.layout.grid_template_rows = rows
    elif active_tab == 2:
        rows = "175px auto" if expanded else "50px auto"
        tab_3_box_selectors.layout.grid_template_rows = rows
    elif active_tab == 3:
        rows = "250px auto" if expanded else "60px auto"
        tab_4_box_selectors.layout.grid_template_rows = rows


# Resize the selector grid whenever the tab-2 accordion is opened or closed.
accordion_tab2.observe(accordion_adapt_height, "selected_index")


In [None]:
# Rebased graph figure (tab 2)

# Initial scales and axes; plot_rebased_graph swaps the scales attached to the
# axes and to the pan/zoom interaction on every redraw.
scale_reb_t = LinearScale()
axis_reb_t = Axis(scale=scale_reb_t, grid_lines="none")
scale_reb_y = LogScale()
axis_reb_y = Axis(
    scale=scale_reb_y, orientation="vertical", grid_lines="solid", tick_format=","
)
reb_pz = PanZoom(scales={"x": [scale_reb_t], "y": [scale_reb_y]})
# The figure is created without marks; plot_rebased_graph attaches them.
rebased_graph = Figure(
    animation_duration=1000,
    title="Rebased Graph",
    legend_location="bottom-right",
    interaction=reb_pz,
    axes=[axis_reb_t, axis_reb_y],
    fig_margin={"top": 50, "bottom": 50, "left": 65, "right": 100},
)


In [None]:
def title_rebased_graph(
    data_to_plot, data_norm, data_type, threshold, plot_scale, thr_data, thr_norm, ma, n
):
    """Build the title of the rebased graph from the current options.

    Args:
        data_to_plot: name of the plotted quantity (used verbatim).
        data_norm: anything other than "Values" adds "/1M population".
        data_type: "Daily change" / "Daily % change" add the matching suffix.
        threshold: when > 0 the title ends with the threshold description,
            otherwise with " since " + the module-level ``STDT``.
        plot_scale: "Log" prefixes the title with "Log ".
        thr_data: name of the threshold quantity.
        thr_norm: anything other than "Values" adds "/1M population" to the
            threshold description.
        ma: truthy when a moving average is applied.
        n: moving-average window in days (formatted with ``d``, so an int).

    Returns:
        The title string, without leading whitespace.
    """
    res = data_to_plot
    if plot_scale == "Log":
        # Only the log scale is announced in the title.
        res = plot_scale + " " + res
    if data_norm != "Values":
        res += "/1M population"
    if data_type == "Daily change":
        res += " daily change"
    elif data_type == "Daily % change":
        res += " daily % change"
    if ma:
        res += " ({0:d}days m.a)".format(n)
    if threshold > 0:
        res += " since number of " + thr_data
        if thr_norm != "Values":
            res += "/1M population"
        res += " = " + str(threshold)
    else:
        res += " since " + STDT
    return res


def label_index_pos(data):
    """Return the index at which a series' label should be anchored.

    Without any NaN the last index is returned. Otherwise the result is the
    highest index whose value and whose predecessor are both non-NaN, or 0 when
    no such pair exists.
    """
    missing = np.isnan(data)
    last = len(data) - 1
    if not missing.any():
        return last
    # Walk backwards looking for two consecutive valid points.
    for pos in range(last, 0, -1):
        if not missing[pos] and not missing[pos - 1]:
            return pos
    return 0


In [None]:
def persistent_colors(used_colors, used_labels, new_labels, all_colors):
    """Assign colours to labels while keeping already-assigned colours.

    Args:
        used_colors: colours currently in use; cycled when shorter than
            ``used_labels``. May be empty.
        used_labels: labels currently shown, positionally paired with
            ``used_colors``.
        new_labels: labels that must have a colour after the call.
        all_colors: palette to draw from for labels without a colour.

    Returns:
        dict mapping every used label (that could be paired with a colour) and
        every new label to a colour. Each label without a colour receives the
        least-used colour so far, ties resolved in favour of colours already in
        use and then palette order.
    """
    used_colors = list(used_colors)
    if used_colors:
        # Repeat the colour list often enough to cover every used label.
        repeats = len(used_labels) // len(used_colors) + 1
        color_dict = dict(zip(used_labels, used_colors * repeats))
    else:
        # Nothing to carry over; avoids dividing by zero above.
        color_dict = {}
    non_assigned = [label for label in new_labels if label not in color_dict]
    if non_assigned:
        usage = Counter(color_dict.values())
        for color in all_colors:
            if color not in usage:
                usage[color] = 0
        for label in non_assigned:
            # min() returns the first least-used colour in insertion order.
            new_color = min(usage, key=usage.get)
            color_dict[label] = new_color
            usage[new_color] += 1
    return color_dict


def plot_rebased_graph(
    data_to_plot,
    data_norm,
    data_type,
    thr,
    countries,
    states,
    plot_scale,
    thr_data,
    thr_norm,
    ma,
    n,
    calendar_time=False,
):
    """Rebuild the curves, labels, scales and title of ``rebased_graph``.

    One line is drawn per selected country (from ``dict_datasets[0]``) and per
    selected US state (from ``dict_datasets[1]``). With a positive threshold
    every series is cut at ``bisect.bisect(threshold_series, thr)`` and plotted
    against an integer index; otherwise the series are plotted against the
    calendar dates between the module-level ``STDT`` and ``ENDT``.

    Args:
        data_to_plot, data_norm, data_type: forwarded to ``get_ts_plot`` and
            used for the title.
        thr: threshold value; ignored (treated as 0) when ``calendar_time``.
        countries: list of country names (keys of ``countries_to_codes``).
        states: list of US state names.
        plot_scale: "Log" selects a log y-scale and blanks non-positive
            values; anything else selects a linear y-scale.
        thr_data, thr_norm: quantity / normalisation of the threshold series.
        ma, n: moving-average flag and window, forwarded to the data source.
        calendar_time: plot against calendar dates instead of rebasing.

    Series that are empty after the threshold cut are skipped, and a colour is
    only recorded for series that are actually plotted so that line colours,
    labels and the colours carried over to the next redraw stay aligned.
    """
    all_colors = [
        "#1f77b4",
        "#ff7f0e",
        "#2ca02c",
        "#d62728",
        "#9467bd",
        "#8c564b",
        "#e377c2",
        "#7f7f7f",
        "#bcbd22",
        "#17becf",
    ]
    update = len(rebased_graph.marks) > 0
    if update:
        # Keep the colours of series that are already displayed.
        color_dict = persistent_colors(
            list(rebased_graph.marks[0].colors),
            list(rebased_graph.marks[1].text),
            countries + states,
            all_colors,
        )
        colors = []
    else:
        colors = all_colors.copy()

    threshold = 0 if calendar_time else thr
    if plot_scale == "Log":
        scale_reb_y = LogScale()
    else:
        scale_reb_y = LinearScale()

    Ys = []
    Xs = []
    label_names = []

    def _append_series(label, yaux):
        """Queue one series for plotting; empty series are skipped entirely."""
        if yaux.shape[0] == 0:
            return
        if plot_scale == "Log":
            # Non-positive values cannot be drawn on a log scale.
            yaux[yaux <= 0] = math.nan
        Ys.append(list(yaux.values))
        Xs.append(list(np.arange(0, len(Ys[-1]))))
        label_names.append(label)
        if update:
            colors.append(color_dict[label])

    for country_name in countries:
        code = countries_to_codes[country_name]
        yaux = dict_datasets[0].get_ts_plot(
            code,
            data_to_plot,
            data_norm,
            plot_scale,
            data_type,
            ma,
            n,
            dict_datasets[0].STDT,
            dict_datasets[0].ENDT,
        )
        if threshold > 0:
            y_thr = dict_datasets[0].get_ts(
                code, thr_data, thr_norm, "Linear", "Total", ma, n,
            )
            # values start from the index of the threshold
            # (bisect assumes the threshold series is sorted -- TODO confirm)
            yaux = yaux[bisect.bisect(y_thr, threshold) :]
        _append_series(country_name, yaux)

    for st in states:
        # NOTE(review): states always request "Linear" here while countries
        # pass plot_scale -- confirm this difference is intended.
        yaux = dict_datasets[1].get_ts_plot(
            st,
            data_to_plot,
            data_norm,
            "Linear",
            data_type,
            ma,
            n,
            dict_datasets[0].STDT,
            dict_datasets[0].ENDT,
        )
        if threshold > 0:
            y_thr = dict_datasets[1].get_ts(
                st, thr_data, thr_norm, "Linear", "Total", ma, n,
            )
            yaux = yaux[bisect.bisect(y_thr, threshold) :]
        _append_series(st, yaux)

    if update and not colors:
        # Nothing was plotted: keep a non-empty palette on the mark so the next
        # redraw still has colours to carry over.
        colors = all_colors.copy()

    max_len = max((len(y) for y in Ys), default=0)
    label_x = []
    label_y = []
    for i in range(len(Ys)):
        index_lab = label_index_pos(Ys[i])
        label_y.append(Ys[i][index_lab])
        label_x.append(Xs[i][index_lab])
        # pad with NaN so that all series form a rectangular 2d array
        Ys[i] = Ys[i] + [math.nan] * (max_len - len(Ys[i]))
        Xs[i] = Xs[i] + [math.nan] * (max_len - len(Xs[i]))
    axis_reb_y.scale = scale_reb_y
    if threshold > 0:
        scale_reb_t = LinearScale()
    else:  # no threshold: plot the values vs calendar dates
        first_day = np.datetime64(
            datetime.datetime.strptime(STDT, dateformat).strftime(dateformat)
        )
        last_day = np.datetime64(
            datetime.datetime.strptime(ENDT, dateformat).strftime(dateformat)
        )
        scale_reb_t = DateScale(
            dateformat=dateformat,
            min=datetime.datetime.strptime(STDT, dateformat),
            max=datetime.datetime.strptime(ENDT, dateformat),
        )
        # A single shared date axis, end date included.
        Xs = list(np.arange(first_day, last_day + np.timedelta64(1, "D")))
        # One label position per plotted series (skipped series have none).
        label_x = [last_day] * len(label_y)
    axis_reb_t.scale = scale_reb_t
    rebased_graph.axes = [axis_reb_t, axis_reb_y]
    rebased_curves = Lines(scales={"x": scale_reb_t, "y": scale_reb_y}, colors=colors)
    rebased_curves.x = np.array(Xs)
    rebased_curves.y = np.array(Ys)
    rebased_labels = Label(
        apply_clip=False,
        scales={"x": scale_reb_t, "y": scale_reb_y},
        default_size=15,
        colors=colors,
    )
    rebased_labels.text = label_names
    rebased_labels.x = np.array(label_x)
    rebased_labels.y = np.array(label_y)
    rebased_graph.marks = [rebased_curves, rebased_labels]
    rebased_graph.title = title_rebased_graph(
        data_to_plot,
        data_norm,
        data_type,
        threshold,
        plot_scale,
        thr_data,
        thr_norm,
        ma,
        n,
    )
    # Pan/zoom must follow the freshly created scales.
    reb_pz.scales = {"x": [scale_reb_t], "y": [scale_reb_y]}
    return


In [None]:
# Widgets that drive the rebased graph. The order is significant:
# update_rebased_graph unpacks the values of the first nine entries
# positionally, and the observer wiring further down slices this list.
value_buttons = [
    rebased_graph_data_button,
    rebased_graph_norm_button,
    rebased_graph_type_button,
    plot_scale_button,
    rebased_graph_thr_data_button,
    rebased_graph_thr_norm_button,
    tab_2_ma_ch,
    tab_2_ma_w,
    thr_val_slider,
    countries_selector,
    states_selector,
]


def update_rebased_graph(*args):
    """Observer callback: redraw the rebased graph from the widget states.

    ``args[0]`` is the change notification. Notifications whose owner is a
    ``ToggleButton`` are only processed when the new value is ``True``; every
    other owner always triggers a redraw.
    """
    change = args[0]
    from_toggle = type(change["owner"]).__name__ == "ToggleButton"
    if from_toggle and not (change["new"] == True):
        return

    # Positional unpacking follows the order of value_buttons.
    values = [widget.value for widget in value_buttons[:-2]]
    (
        data_to_plot,
        data_norm,
        data_type,
        plot_scale,
        thr_data,
        thr_norm,
        ma,
        n,
        threshold,
    ) = values
    countries = list(countries_selector.value)
    states = list(states_selector.value)

    if "None" in countries:
        countries = []
    if "None" in states:
        states = []
    if not countries and not states:
        # Fall back to a default series when nothing is selected.
        countries = ["United States"]

    # The window-size input is only shown while the moving average is on.
    tab_2_ma_w.layout.visibility = "visible" if ma else "hidden"

    # Threshold controls are hidden in calendar-time mode.
    thr_visibility = "hidden" if calendar_time.value else "visible"
    thr_val_slider.layout.visibility = thr_visibility
    cat_thr_tab_2_buttons.layout.visibility = thr_visibility

    plot_rebased_graph(
        data_to_plot,
        data_norm,
        data_type,
        threshold,
        countries,
        states,
        plot_scale,
        thr_data,
        thr_norm,
        ma,
        n,
        calendar_time.value,
    )
    update_link_download()
    return


# Redraw only when the slider is released, not on every intermediate value.
thr_val_slider.continuous_update = False
thr_val_slider.observe(update_rebased_graph, "value")
tab_2_ma_ch.observe(update_rebased_graph, "value")
tab_2_ma_w.observe(update_rebased_graph, "value")
calendar_time.observe(update_rebased_graph, "value")
# value_buttons[:-5] is the first six entries of the list; they are wired
# through add_observe instead of observe.
for v in value_buttons[:-5]:
    v.add_observe(update_rebased_graph, "value")


In [None]:
# Page layout of the rebased tab: the graph fills the whole left column, the
# option buttons and the selectors share the fixed-width right column.
grid_2 = GridspecLayout(2, 2)
grid_2.layout.overflow = "hidden"
grid_2.layout.height = "100%"
grid_2.layout.width = "100%"
rebased_graph.layout.width = "100%"
rebased_graph.layout.height = "99%"
grid_2[:, 0] = rebased_graph
grid_2[0, 1] = HBox([reb_top_right_buttons])
grid_2[1, 1] = HBox([tab_2_box_selectors])
grid_2.layout.align_items = "stretch"
# Right column is 566px wide; top-right row is 315px high.
grid_2.layout.grid_template_columns = "auto 566px"
grid_2.layout.grid_template_rows = "315px auto"


In [None]:
# DNA graph (heat map), tab 3: selectors and option widgets

# Multi-selects start on the "None" entry, which update_dna treats as
# "nothing selected".
dna_countries_selector = SelectMultiple(
    options=["None"] + top_countries_names,
    value=["None"],
    layout=Layout(width="auto", height="100%", min_height="50px"),
)

dna_states_selector = SelectMultiple(
    options=["None"] + states,
    value=["None"],
    layout=Layout(width="auto", height="100%", min_height="50px"),
)


# Quantity shown in the heat map.
dna_data_button = Toggle_Buttons(
    options=["Cases", "Deaths", "Recovered", "Active Cases", "Tests"],
    value="Cases",
    description="Data",
    min_button_width=min_button_width_data,
    min_description_width=min_description_width_1,
)

# Raw values or values per million.
dna_norm_button = Toggle_Buttons(
    options=["Values", "Per million"],
    value="Values",
    description="Norm",
    min_button_width=min_button_width_norm,
    min_description_width=min_description_width_1,
)

# Cumulative total or daily (absolute / percentage) change.
dna_type_button = Toggle_Buttons(
    options=["Total", "Daily change", "Daily % change"],
    value="Total",
    description="Type",
    min_button_width=min_button_width_scale,
    min_description_width=min_description_width_1,
)

# Moving-average controls; the window input starts hidden and update_dna
# toggles its visibility with the checkbox.
tab_3_ma_ch = Checkbox(
    description="Moving Average",
    value=False,
    style={"description_width": "initial"},
    layout=Layout(min_width="200px", width="100%", overflow="auto"),
)
tab_3_ma_w = BoundedIntText(
    value=7,
    min=1,
    max=14,
    description="Window size (in days)",
    style={"description_width": "initial"},
    layout=Layout(
        min_width="200px", width="200px", overflow="auto", visibility="hidden"
    ),
)
ma_box_tab_3 = HBox(
    [tab_3_ma_ch, tab_3_ma_w],
    layout=Layout(
        width="555px",
        min_width="555px",
        max_width="555px",
        overflow="auto",
        align_items="stretch",
    ),
)

# Fixed-size container for the three option rows.
cat_tab_3_buttons = VBox(
    [dna_data_button, dna_norm_button, dna_type_button],
    layout=Layout(
        width="565px",
        height="105px",
        min_width="565px",
        max_width="565px",
        min_height="105px",
        max_height="105px",
        overflow="auto",
    ),
)

# Top-right block of tab 3: option rows above the moving-average controls.
dna_top_buttons = VBox(
    [cat_tab_3_buttons, ma_box_tab_3],
    layout=Layout(
        width="565px",
        min_width="565px",
        max_width="565px",
        height="140px",
        min_height="140px",
        max_height="140px",
        overflow="auto",
    ),
)


In [None]:
# Predefined-selection buttons of the heat-map tab; the "Heat map" class routes
# clicks to the dna_* selectors in on_click_action.
ps_tab3 = preselection(presaved_sel_options, "Heat map", on_click_action)
# Buttons are laid out in rows of four.
hb_tab3 = VBox(
    [
        HBox([ps_tab3[v] for v in presaved_sel_options[i : i + 4]])
        for i in range(0, len(presaved_sel_options), 4)
    ]
)
# selected_index=None: the accordion starts collapsed.
accordion_tab3 = Accordion(
    children=[hb_tab3], selected_index=None, layout=Layout(width="auto", height="auto")
)
accordion_tab3.set_title(0, "Predefined selections")
# Resize the selector grid when the accordion is opened or closed.
accordion_tab3.observe(accordion_adapt_height, "selected_index")


In [None]:
# Titled country / US-state selectors of the heat-map tab.
countries_selector_box_dna = VBox(
    [HTML('<p style="text-align:center">Countries</p>'), dna_countries_selector],
    layout=Layout(width="100%"),
)
states_selector_box_dna = VBox(
    [HTML('<p style="text-align:center">US States</p>'), dna_states_selector],
    layout=Layout(width="100%"),
)
# 2 x 5 grid: the accordion spans the whole first row, the two selectors sit in
# columns 1 and 3 of the second row.
tab_3_box_selectors = GridspecLayout(
    2,
    5,
    layout=Layout(
        width="555px", align_items="stretch", min_height="60px", overflow="auto"
    ),
)
tab_3_box_selectors[0, :] = accordion_tab3
tab_3_box_selectors[1, 1] = countries_selector_box_dna
tab_3_box_selectors[1, 3] = states_selector_box_dna
# Row heights for the collapsed accordion; accordion_adapt_height changes the
# first value when the accordion is opened.
tab_3_box_selectors.layout.grid_template_rows = "50px auto"


In [None]:
# GridHeatMap, DNA graph

# Scales: dates on x, one ordinal row per country/state on y, colour for the
# value. update_dna replaces the colour palette depending on the quantity.
dna_x_scale = DateScale()
dna_y_scale = OrdinalScale(padding_y=0)
dna_color_scale = ColorScale(
    colors=["white", "#ffa600", "#ff6e00", "#ff4417", "#d31522", "#8c000e", "#560410"]
)
dna_x_ax = Axis(scale=dna_x_scale, tick_format="%m/%d")
dna_y_ax = Axis(scale=dna_y_scale, orientation="vertical", side="right")
dna_color_ax = ColorAxis(scale=dna_color_scale, tick_format=",")
dna_axes = [dna_x_ax, dna_y_ax, dna_color_ax]

dna_layout = Layout(width="100%", height="99%")
dna_fig_margin = {"top": 60, "bottom": 60, "left": 0, "right": 80}

# The figure is created without marks here; update_dna writes to
# dna_figure.marks[0], so a mark is expected to be attached before it runs.
dna_figure = Figure(
    axes=dna_axes,
    fig_margin=dna_fig_margin,
    layout=dna_layout,
    min_aspect_ratio=0.0,
    title="Cases",
    padding_y=0,
    animation_duration=1000,
    null_color="#808080",
)


In [None]:
def update_dna(*args):
    """Observer callback: refresh the heat map from the tab-3 widget states.

    ``args[0]`` is the change notification. Notifications whose owner is a
    ``ToggleButton`` are only processed when the new value is ``True``; every
    other owner always triggers a refresh.

    At least two rows are always shown: an empty selection falls back to
    "United States" and "World", and a single selected country / state is
    paired with a default companion.
    """
    change = args[0]
    from_toggle = type(change["owner"]).__name__ == "ToggleButton"
    if from_toggle and not (change["new"] == True):
        return

    picked_countries = dna_countries_selector.value
    countries = [] if "None" in picked_countries else list(picked_countries)
    picked_states = dna_states_selector.value
    states = [] if "None" in picked_states else list(picked_states)

    ma, n = tab_3_ma_ch.value, tab_3_ma_w.value
    # The window-size input is only shown while the moving average is on.
    tab_3_ma_w.layout.visibility = "visible" if ma else "hidden"

    def _country_frame(names):
        """Time series of the given countries, columns renamed to country names."""
        frame = dict_datasets[0].get_ts(
            [countries_to_codes[c] for c in names],
            dna_data_button.value,
            dna_norm_button.value,
            "Linear",
            dna_type_button.value,
            ma,
            n,
        )
        return frame.rename(columns=codes_to_countries)

    def _state_frame(names):
        """Time series of the given US states."""
        return dict_datasets[1].get_ts(
            names,
            dna_data_button.value,
            dna_norm_button.value,
            "Linear",
            dna_type_button.value,
            ma,
            n,
        )

    if not countries and not states:
        countries = ["United States", "World"]
        df_DNA = _country_frame(countries)
    elif not states:
        if len(countries) == 1:
            if "United States" in countries:
                countries = ["World", "United States"]
            else:
                countries = countries + ["United States"]
        df_DNA = _country_frame(countries)
    elif not countries:
        if len(states) == 1:
            if "New York" in states:
                states = ["New Jersey", "New York"]
            else:
                states = states + ["New York"]
        df_DNA = _state_frame(states)
    else:
        # Countries and states together: align both frames on their index.
        df_DNA = pd.merge(
            _country_frame(countries),
            _state_frame(states),
            left_index=True,
            right_index=True,
            how="outer",
        )
    # Reverse the column order before handing the rows to the heat map.
    df_DNA = df_DNA.iloc[:, ::-1]

    data_kind = dna_data_button.value
    title = data_kind
    if dna_norm_button.value == "Per million":
        title += "/1M population"
    if dna_type_button.value != "Total":
        title += " " + dna_type_button.value
    if ma:
        title += " ({0:d}days m.a)".format(n)

    heat_map = dna_figure.marks[0]
    heat_map.color = df_DNA.values.T
    heat_map.row = list(df_DNA.columns.values)

    # Palette depends on the displayed quantity.
    if data_kind in ("Recovered", "Tests"):
        dna_color_scale.colors = [
            "white",
            "#C5E8B7",
            "#ABE098",
            "#83D475",
            "#57C84D",
            "#2EB62C",
            "#207f1e",
        ]
    elif data_kind == "Active Cases":
        dna_color_scale.colors = [
            "#069740",
            "#9acd32",
            "#e4d31b",
            "#ffa600",
            "#d31522",
            "#8c000e",
            "#560410",
        ]
    else:
        dna_color_scale.colors = [
            "white",
            "#ffa600",
            "#ff6e00",
            "#ff4417",
            "#d31522",
            "#8c000e",
            "#560410",
        ]
    dna_figure.title = title
    update_link_download()
    return


# Refresh the heat map whenever one of the tab-3 option widgets changes.
dna_data_button.add_observe(update_dna, "value")
dna_norm_button.add_observe(update_dna, "value")
dna_type_button.add_observe(update_dna, "value")
tab_3_ma_ch.observe(update_dna, "value")
tab_3_ma_w.observe(update_dna, "value")


In [None]:
# HTML template of the heat-map tooltip, filled by update_dna_tooltip:
# {0} date, {1} row label, {2} figure title, {3} value (thousands separator).
dna_tooltip_table = """
<table id="stats_table">
<tr>
    <td>Date</td>
    <td>{0:}</td>
</tr>
<tr>
    <td>Country/State</td>
    <td>{1:s}</td>
</tr>
<tr>
    <td>{2:s}</td>
    <td>{3:,}</td>
</tr>
</table>
"""
# Widget that displays the rendered tooltip.
# NOTE(review): the "tootltip" spelling is part of the variable name; renaming
# it would require updating every reference.
dna_stats_tootltip = HTML()


def update_dna_tooltip(*args):
    """Render the heat-map tooltip for the cell described by ``args[1]``.

    ``args[1]["data"]`` supplies ``column_num`` (index into the mark's
    ``column`` values), ``row`` (row label) and ``color`` (cell value).
    The value is rounded to one decimal when the figure title contains "pop"
    or "%", and truncated to an int otherwise; a value that cannot be converted
    (e.g. ``None`` or NaN for the int case) is shown as 0.
    """
    cell = args[1]["data"]
    # Drop the last 19 characters of the timestamp text to keep only the date
    # part (assumes a nanosecond-precision datetime64 string -- TODO confirm).
    Date = str(dna_figure.marks[0].column[int(cell["column_num"])])[:-19]
    Country = cell["row"]
    val = cell["color"]
    try:
        if "pop" in dna_figure.title or "%" in dna_figure.title:
            val = round(val, 1)
        else:
            val = int(val)
    except (TypeError, ValueError, OverflowError):
        # Only conversion failures are tolerated; anything else propagates.
        val = 0
    dna_stats_tootltip.value = dna_tooltip_table.format(
        Date, Country, dna_figure.title, val
    )
    return


In [None]:
# Page layout of the heat-map tab: the figure fills the whole left column, the
# option buttons and the selectors share the fixed-width right column.
grid_3 = GridspecLayout(2, 2)
grid_3.layout.overflow = "hidden"
grid_3.layout.height = "100%"
grid_3.layout.width = "100%"
grid_3[:, 0] = dna_figure
grid_3[0, 1] = HBox([dna_top_buttons])
grid_3[1, 1] = HBox([tab_3_box_selectors])
grid_3.layout.align_items = "stretch"
# Right column is 566px wide; top-right row is 140px high.
grid_3.layout.grid_template_columns = "auto 566px"
grid_3.layout.grid_template_rows = "140px auto"


In [None]:
# Free graph (tab 4)

# Buttons and interactions: every option exists twice, once for the x axis
# (horizontal button rows) and once for the y axis (horizontal=False).


# NOTE(review): x_data_button below spells out "100px" instead of using this
# constant; both currently hold the same value.
min_button_width_1 = "100px"
x_data_button = Toggle_Buttons(
    options=["Cases", "Deaths", "Recovered", "Active Cases", "Tests"],
    value="Cases",
    description="Data",
    min_button_width="100px",
    min_description_width=min_description_width_1,
)

y_data_button = Toggle_Buttons(
    options=["Cases", "Deaths", "Recovered", "Active Cases", "Tests"],
    value="Cases",
    description="Data",
    min_button_width=min_button_width_1,
    min_description_width=min_description_width_1,
    horizontal=False,
    button_width="auto",
)


x_norm_button = Toggle_Buttons(
    options=["Values", "Per million"],
    value="Values",
    description="Norm",
    min_button_width=min_button_width_norm,
    min_description_width=min_description_width_1,
)

y_norm_button = Toggle_Buttons(
    options=["Values", "Per million"],
    value="Values",
    description="Norm",
    min_button_width=min_button_width_1,
    min_description_width=min_description_width_1,
    horizontal=False,
    button_width="auto",
)


# Both axes start on a log scale.
x_scale_button = Toggle_Buttons(
    options=["Linear", "Log"],
    value="Log",
    description="Scale",
    min_button_width=min_button_width_scale,
    min_description_width=min_description_width_1,
)

y_scale_button = Toggle_Buttons(
    options=["Linear", "Log"],
    value="Log",
    description="Scale",
    min_button_width=min_button_width_1,
    min_description_width=min_description_width_1,
    horizontal=False,
    button_width="auto",
)


# Default view: total on x against daily change on y.
x_type_button = Toggle_Buttons(
    options=["Total", "Daily change", "Daily % change"],
    value="Total",
    description="Type",
    min_button_width="120px",
    min_description_width=min_description_width_1,
)

y_type_button = Toggle_Buttons(
    options=["Total", "Daily change", "Daily % change"],
    value="Daily change",
    description="Type",
    min_button_width=min_button_width_1,
    min_description_width=min_description_width_1,
    horizontal=False,
    button_width="auto",
)

# "Last data point" toggle, unchecked by default.
free_checkbox = Checkbox(
    value=False,
    description="Last data point",
    disabled=False,
    style={"description_width": "initial"},
    layout=Layout(width="auto", height="auto"),
)

# Moving average is enabled by default on this tab, so the window input starts
# visible.
tab_4_ma_ch = Checkbox(
    description="Moving Average",
    value=True,
    style={"description_width": "initial"},
    layout=Layout(min_width="200px", width="200px", overflow="auto"),
)
tab_4_ma_w = BoundedIntText(
    value=7,
    min=1,
    max=14,
    description="Window size (days)",
    style={"description_width": "initial"},
    layout=Layout(
        min_width="195px", width="195px", overflow="auto", visibility="visible"
    ),
)
ma_box_tab_4 = VBox(
    [tab_4_ma_ch, tab_4_ma_w],
    layout=Layout(min_width="205px", width="205px", overflow="auto"),
)

# Vertical, bordered panel with the y-axis options (fixed 130px x 580px).
y_data = VBox(
    [
        VBox(
            [
                HTML(
                    "<p style='text-align:center; font-size:20px; font-weight:bold; padding: 0; margin:0; color: #ff8b0e'>Y DATA :</p>"
                ),
                y_data_button,
                y_norm_button,
                y_scale_button,
                y_type_button,
            ],
            layout=Layout(
                width="130px",
                min_width="130px",
                max_width="130px",
                height="580px",
                min_height="580px",
                max_height="580px",
                overflow="auto",
                align_items="stretch",
                border="solid #ff8b0e",
            ),
        )
    ],
    layout=Layout(width="auto", height="auto"),
)
# Horizontal, bordered panel with the x-axis options (fixed 1500px x 40px).
x_data = HBox(
    [
        HBox(
            [
                HTML(
                    "<p style='text-align:left; font-size:20px; font-weight:bold; padding: 0; margin: 0; color: #ff8b0e; margin-right: 15px'>X DATA : </p>"
                ),
                x_data_button,
                x_norm_button,
                x_scale_button,
                x_type_button,
            ],
            layout=Layout(
                width="1500px",
                min_width="1500px",
                max_width="1500px",
                height="40px",
                min_height="40px",
                max_height="40px",
                overflow="auto",
                align_items="stretch",
                border="solid #ff8b0e",
            ),
        )
    ],
    layout=Layout(width="auto", height="auto"),
)


In [None]:
# Country / US-state multi-selects of the free-graph tab; both start on the
# "None" entry.
free_countries_selector = SelectMultiple(
    options=["None"] + top_countries_names,
    value=["None"],
    layout=Layout(width="auto", height="100%", min_height="50px"),
)

free_states_selector = SelectMultiple(
    options=["None"] + states,
    value=["None"],
    layout=Layout(width="auto", height="100%", min_height="50px"),
)


In [None]:
# Predefined-selection buttons of the free-graph tab; the "Custom" class routes
# clicks to the free_* selectors in on_click_action.
ps_tab4 = preselection(presaved_sel_options, "Custom", on_click_action)
# Unlike tabs 2 and 3, the buttons are stacked in a single column.
hb_tab4 = VBox([ps_tab4[v] for v in presaved_sel_options])
# selected_index=None: the accordion starts collapsed.
accordion_tab4 = Accordion(
    children=[hb_tab4], selected_index=None, layout=Layout(width="auto", height="auto")
)
accordion_tab4.set_title(0, "Predefined selections")
# Resize the selector grid when the accordion is opened or closed.
accordion_tab4.observe(accordion_adapt_height, "selected_index")


In [None]:
# Titled country / US-state selectors of the free-graph tab.
countries_selector_box_free = VBox(
    [HTML('<p style="text-align:center">Countries</p>'), free_countries_selector],
    layout=Layout(width="100%", height="100%"),
)
states_selector_box_free = VBox(
    [HTML('<p style="text-align:center">US States</p>'), free_states_selector],
    layout=Layout(width="100%", height="100%"),
)
# Single-column grid: accordion, countries, states from top to bottom.
tab_4_box_selectors = GridspecLayout(
    3,
    1,
    layout=Layout(
        width="205px",
        min_width="205px",
        max_width="205px",
        height="100%",
        min_height="100px",
        overflow="auto",
    ),
)
tab_4_box_selectors[0, 0] = accordion_tab4
tab_4_box_selectors[1, 0] = countries_selector_box_free
tab_4_box_selectors[2, 0] = states_selector_box_free
# NOTE(review): only two row sizes are given for a three-row grid, so the third
# row is not sized explicitly here -- confirm this is intended.
tab_4_box_selectors.layout.grid_template_rows = "60px auto"


In [None]:
# Free-graph figure: initial linear scales, axes with thousands separators and
# a pan/zoom interaction bound to both scales. The figure starts without marks.
free_scatter_sc_x = LinearScale()
free_scatter_sc_y = LinearScale()
free_scatter_ax_x = Axis(scale=free_scatter_sc_x, tick_format=",")
free_scatter_ax_y = Axis(
    scale=free_scatter_sc_y, orientation="vertical", tick_format=","
)
free_pz = PanZoom(scales={"x": [free_scatter_sc_x], "y": [free_scatter_sc_y]})
free_scatter_fig = Figure(
    axes=[free_scatter_ax_x, free_scatter_ax_y],
    interaction=free_pz,
    layout=Layout(width="100%", height="100%"),
    fig_margin={"top": 50, "left": 60, "right": 100, "bottom": 40},
)


In [None]:
def last_index_nan(x_li, y_li):
    """Return the last index at which both sequences hold a non-NaN value.

    The scan covers ``len(x_li)`` positions from the end; -1 is returned when
    no position qualifies (including empty input).
    """
    for pos in range(len(x_li) - 1, -1, -1):
        if not (np.isnan(x_li[pos]) or np.isnan(y_li[pos])):
            return pos
    return -1


def clean_line_plot(xs, ys):
    """Blank out, in place, every point where either coordinate is NaN.

    ``xs`` and ``ys`` are parallel sequences of mutable sequences. Wherever
    ``xs[i][j]`` or ``ys[i][j]`` is NaN, both are set to NaN.

    Returns:
        The same ``xs`` and ``ys`` objects, after modification.
    """
    for row_x, row_y in ((xs[i], ys[i]) for i in range(len(xs))):
        for j, x_val in enumerate(row_x):
            if np.isnan(x_val) or np.isnan(row_y[j]):
                row_x[j] = np.nan
                row_y[j] = np.nan
    return xs, ys


def free_scatter(
    selected_countries, datas, norms, scales, data_types, ma, n, last_point=True
):
    """Rebuild the marks, scales and title of ``free_scatter_fig``.

    Parameters
    ----------
    selected_countries : sequence of sequences
        ``selected_countries[i]`` holds the entities read from
        ``dict_datasets[i]`` (callers pass ``[country codes, US states]``).
    datas, norms, scales, data_types : two-element sequences
        Options for the x axis (index 0) and the y axis (index 1). ``scales``
        entries are ``"Log"`` or ``"Linear"``.
    ma, n
        Moving-average flag and window, forwarded to the dataset accessors and
        used in the x part of the title.
    last_point : bool
        ``True`` plots one ``Scatter`` point per entity (value at the
        dataset's ``ENDT``), skipping entities whose x or y value is NaN.
        ``False`` plots one ``Lines`` series per entity and stores the cleaned
        series in the globals ``free_mark_x`` / ``free_mark_y``.

    Colors: on the first call the default palette is used. On later calls each
    plotted entity gets the color returned for its label by
    ``persistent_colors`` (presumably keeping colors stable across updates —
    confirm against that helper). A color is appended only for entities that
    are actually plotted, so ``colors`` stays index-aligned with the points
    and labels even when NaN points are skipped in last-point mode.
    """
    global free_mark_x, free_mark_y
    Ys = []
    Xs = []
    label_names = []
    colors = []
    all_colors = [
        "#1f77b4",
        "#ff7f0e",
        "#2ca02c",
        "#d62728",
        "#9467bd",
        "#8c564b",
        "#e377c2",
        "#7f7f7f",
        "#bcbd22",
        "#17becf",
    ]
    if len(free_scatter_fig.marks) > 0:
        # The figure already shows marks: derive the label -> color mapping
        # from what is currently displayed.
        update = True
        new_labels = [codes_to_countries[c] for c in selected_countries[0]] + list(
            selected_countries[1]
        )
        color_dict = persistent_colors(
            list(free_scatter_fig.marks[0].colors),
            list(free_scatter_fig.marks[1].text),
            new_labels,
            all_colors,
        )
    else:
        update = False
        colors = all_colors.copy()

    # These names are local to the function (no ``global`` declaration); the
    # module-level axes and pan-zoom are re-pointed at the new scales below.
    if scales[0] == "Log":
        free_scatter_sc_x = LogScale()
    elif scales[0] == "Linear":
        free_scatter_sc_x = LinearScale()

    if scales[1] == "Log":
        free_scatter_sc_y = LogScale()
    elif scales[1] == "Linear":
        free_scatter_sc_y = LinearScale()

    for i in range(len(selected_countries)):
        dataset = dict_datasets[i]
        for c in selected_countries[i]:
            # World entities are identified by code and labelled by name;
            # entities of the other datasets are labelled by their own key.
            label = codes_to_countries[c] if dataset.name == "World" else c
            if last_point:
                x = dataset.get_value(
                    c,
                    datas[0],
                    norms[0],
                    scales[0],
                    data_types[0],
                    ma,
                    n,
                    dataset.ENDT,
                )
                y = dataset.get_value(
                    c,
                    datas[1],
                    norms[1],
                    scales[1],
                    data_types[1],
                    ma,
                    n,
                    dataset.ENDT,
                )
                if np.isnan(x) or np.isnan(y):
                    # Nothing is plotted for this entity, so it must not
                    # consume a color slot either.
                    continue
                Xs.append(x)
                Ys.append(y)
            else:
                Xs.append(
                    list(
                        dataset.get_ts(
                            c, datas[0], norms[0], scales[0], data_types[0], ma, n
                        ).values
                    )
                )
                Ys.append(
                    list(
                        dataset.get_ts(
                            c, datas[1], norms[1], scales[1], data_types[1], ma, n
                        ).values
                    )
                )
            label_names.append(label)
            if update:
                colors.append(color_dict[label])
    free_scatter_ax_x.scale = free_scatter_sc_x
    free_scatter_ax_y.scale = free_scatter_sc_y
    free_scatter_fig.axes = [free_scatter_ax_x, free_scatter_ax_y]

    free_labels = Label(
        apply_clip=False,
        scales={"x": free_scatter_sc_x, "y": free_scatter_sc_y},
        default_size=15,
        x_offset=4,
        colors=colors,
    )
    free_labels.text = label_names
    if last_point:
        free_labels.x = np.array(Xs)
        free_labels.y = np.array(Ys)

        free_mark = Scatter(
            marker="circle",
            stroke="black",
            scales={"x": free_scatter_sc_x, "y": free_scatter_sc_y},
            colors=colors,
            marker_size=30,
        )
        free_mark.x = np.array(Xs)
        free_mark.y = np.array(Ys)

    else:
        # Blank out half-defined points, then anchor every label at the last
        # fully defined point of its series.
        Xs, Ys = clean_line_plot(Xs, Ys)
        free_labels.x = np.array(
            [x_li[last_index_nan(x_li, y_li)] for x_li, y_li in zip(Xs, Ys)]
        )
        free_labels.y = np.array(
            [y_li[last_index_nan(x_li, y_li)] for x_li, y_li in zip(Xs, Ys)]
        )

        free_mark = Lines(
            marker="circle",
            stroke="black",
            scales={"x": free_scatter_sc_x, "y": free_scatter_sc_y},
            marker_size=30,
            colors=colors,
        )
        free_mark.x = np.array(Xs)
        free_mark.y = np.array(Ys)
        # Keep the full series so update_free_play can slice them per frame.
        free_mark_x = Xs
        free_mark_y = Ys

    free_scatter_fig.marks = [free_mark, free_labels]

    def title_free_scatter(data, norm, scale, data_type, ma, n):
        """Build the textual description of one axis for the figure title."""
        res = ""
        if scale == "Log":
            res += scale
        res += " " + data
        if norm == "Per million":
            res += "/1M population"
        if data_type != "Total":
            res += " " + data_type.lower()
        if ma:
            res += " ({0:d}days m.a)".format(n)
        return res

    # Title layout: "<y description> vs <x description>, <ENDT>". The
    # moving-average suffix is only written on the x part.
    free_scatter_fig.title = (
        title_free_scatter(datas[1], norms[1], scales[1], data_types[1], False, None)
        + " vs "
        + title_free_scatter(datas[0], norms[0], scales[0], data_types[0], ma, n)
        + ", "
        + ENDT
    )
    free_pz.scales = {"x": [free_scatter_sc_x], "y": [free_scatter_sc_y]}
    return


In [None]:
def update_free_scatter_fig(*args):
    """Observer callback: redraw the custom graph from the current widget state.

    ``args[0]`` is the change dictionary passed by the observed widget.
    Changes coming from a ``ToggleButton`` are ignored unless the button was
    switched on. Otherwise the selections and options are read from the tab-4
    widgets, ``free_scatter`` is called, the moving-average window and the
    play control are shown or hidden, and the data download link is refreshed.
    """
    change = args[0]
    from_toggle = type(change["owner"]).__name__ == "ToggleButton"
    if from_toggle and not (change["new"] == True):
        return

    # "None" in a selector means "no entity of that kind".
    countries = (
        []
        if "None" in free_countries_selector.value
        else [countries_to_codes[name] for name in free_countries_selector.value]
    )
    states = [] if "None" in free_states_selector.value else free_states_selector.value

    # Never draw an empty figure: fall back to the USA.
    if not countries and not states:
        countries = ["USA"]

    ma = tab_4_ma_ch.value
    n = tab_4_ma_w.value
    tab_4_ma_w.layout.visibility = "visible" if ma else "hidden"

    last_point = free_checkbox.value
    free_scatter(
        [countries, states],
        [x_data_button.value, y_data_button.value],
        [x_norm_button.value, y_norm_button.value],
        [x_scale_button.value, y_scale_button.value],
        [x_type_button.value, y_type_button.value],
        ma,
        n,
        last_point,
    )
    # The play control is only shown when full series are plotted.
    free_play.layout.visibility = "hidden" if last_point else "visible"

    update_link_download()
    return


# Redraw the custom graph whenever one of its option controls changes.
# The x_*/y_* controls register through add_observe (presumably the project's
# Toggle_Buttons API — confirm); the remaining widgets use the standard
# ipywidgets observe().
x_data_button.add_observe(update_free_scatter_fig, "value")
y_data_button.add_observe(update_free_scatter_fig, "value")
x_norm_button.add_observe(update_free_scatter_fig, "value")
y_norm_button.add_observe(update_free_scatter_fig, "value")
x_scale_button.add_observe(update_free_scatter_fig, "value")
y_scale_button.add_observe(update_free_scatter_fig, "value")
x_type_button.add_observe(update_free_scatter_fig, "value")
y_type_button.add_observe(update_free_scatter_fig, "value")
free_checkbox.observe(update_free_scatter_fig, "value")
tab_4_ma_ch.observe(update_free_scatter_fig, "value")
tab_4_ma_w.observe(update_free_scatter_fig, "value")


In [None]:
# Animation control of the custom graph. max is a placeholder here;
# initialize_dashboard() sets max and value once the data length is known.
free_play = Play(
    value=0,
    min=0,
    max=1,
    step=1,
    description="Press play",
    disabled=False,
    layout=Layout(width="auto"),
)


def update_free_play(*args):
    """Observer callback of ``free_play``: show the series up to one frame.

    ``args[0]["new"]`` is the frame index. The series stored in the globals
    ``free_mark_x`` / ``free_mark_y`` are cut after that index, cleaned, and
    written to the figure's first mark; the labels (second mark) move to the
    last fully defined point of each cut series. The last ten characters of
    the title are replaced by the date at that index of ``date_selector``.
    """
    global free_mark_x, free_mark_y
    frame = args[0]["new"]
    date = date_selector.options[frame]

    # Slice copies of the stored series so the full data stays untouched.
    x_rows = list(np.array(free_mark_x)[:, : frame + 1])
    y_rows = list(np.array(free_mark_y)[:, : frame + 1])
    Xs, Ys = clean_line_plot(x_rows, y_rows)

    last_valid = [last_index_nan(x_li, y_li) for x_li, y_li in zip(Xs, Ys)]
    series_mark = free_scatter_fig.marks[0]
    labels_mark = free_scatter_fig.marks[1]
    labels_mark.x = np.array([x_li[k] for x_li, k in zip(Xs, last_valid)])
    labels_mark.y = np.array([y_li[k] for y_li, k in zip(Ys, last_valid)])
    series_mark.x = np.array(Xs)
    series_mark.y = np.array(Ys)

    free_scatter_fig.title = free_scatter_fig.title[:-10] + str(date)
    return


# Top-right control panel of the custom graph: the free_checkbox option, the
# play control and the moving-average box, inside a fixed 205px x 130px
# container that scrolls on overflow.
tab_4_top_right_buttons = VBox(
    [
        VBox(
            [free_checkbox, free_play, ma_box_tab_4],
            layout=Layout(
                width="205px",
                min_width="205px",
                max_width="205px",
                height="130px",
                min_height="130px",
                max_height="130px",
                overflow="auto",
                align_items="stretch",
            ),
        )
    ],
    layout=Layout(height="auto", width="auto"),
)


In [None]:
# Layout of the "Custom Graph" tab (3 x 3 grid):
#   left column  : y-axis options (rows 0-1)
#   bottom row   : x-axis options (all columns)
#   right column : display controls (row 0) and entity selectors (row 1)
#   centre       : the figure (rows 0-1)
grid_4 = GridspecLayout(3, 3)
grid_4.layout.overflow = "hidden"
grid_4.layout.width = "100%"
grid_4.layout.height = "100%"
grid_4[0:2, 0] = y_data
grid_4[2, :] = x_data
grid_4[0, 2] = tab_4_top_right_buttons
grid_4[1, 2] = HBox([tab_4_box_selectors])
grid_4[0:2, 1] = free_scatter_fig
grid_4.layout.align_items = "stretch"
# These track sizes are the same strings as visible_template[4], which
# hide_buttons() uses to restore the layout.
grid_4.layout.grid_template_columns = "135px auto 210px"
grid_4.layout.grid_template_rows = "135px auto 60px"


In [None]:
def initialize_dashboard(dict_datasets):
    """Populate the widgets of the four tabs once the data has been loaded.

    Parameters
    ----------
    dict_datasets : mapping
        ``dict_datasets[0]`` is the world dataset and ``dict_datasets[1]`` the
        US states dataset (as built by ``accept_ToS``).

    Side effects
    ------------
    * Sets the globals ``top_countries_names`` and ``states`` and fills
      ``presaved_sel_lists`` with the "Top10" selections.
    * Sets options/values of the selectors, draws the initial content of each
      tab, and only then attaches the observers, so the programmatic
      initialisation does not fire the update callbacks.
    * Makes the download and hide buttons visible and refreshes the data
      download link.

    The function also reads the module-level ``ENDT`` set by ``accept_ToS``.
    """
    global top_countries_names, states, presaved_sel_lists

    # Initialize grid_1

    date_selector.options = dict_datasets[0].get_index("Cases")
    date_selector.value = dict_datasets[0].ENDT

    # One option per day from STDT to ENDT inclusive.
    date_axis_selector.options = list(
        np.arange(
            np.datetime64(
                datetime.datetime.strptime(dict_datasets[0].STDT, dateformat).strftime(
                    dateformat
                )
            ),
            np.datetime64(
                datetime.datetime.strptime(dict_datasets[0].ENDT, dateformat).strftime(
                    dateformat
                )
            )
            + np.timedelta64(1, "D"),
        )
    )

    date_axis_selector.index = (0, dict_datasets[0].get_len() - 1)

    play.max = dict_datasets[0].get_len() - 1
    play.value = dict_datasets[0].get_len() - 1

    wm.color = map_color(
        dict_datasets[0],
        data_buttons.value,
        norm_buttons.value,
        scale_buttons.value,
        type_buttons.value,
        tab_1_ma_ch.value,
        tab_1_ma_w.value,
        dict_datasets[0].ENDT,
    )

    main_map.title = map_title(
        "World",
        "Cases",
        "Values",
        "Total",
        dict_datasets[0].ENDT,
        tab_1_ma_ch.value,
        tab_1_ma_w.value,
    )

    update_tables(
        data_buttons.value,
        norm_buttons.value,
        scale_buttons.value,
        type_buttons.value,
        tab_1_ma_ch.value,
        tab_1_ma_w.value,
        dict_datasets[0].ENDT,
    )

    update_main_graph(
        dict_datasets[0],
        "USA",
        "Values",
        "Linear",
        "Total",
        tab_1_ma_ch.value,
        tab_1_ma_w.value,
        dict_datasets[0].STDT,
        dict_datasets[0].ENDT,
        axis_t,
        axis_y,
        main_graph,
        init=True,
    )

    # initialize grid_2

    # Countries sorted by cases at ENDT, largest first. The [1:] slices below
    # drop the first entry (presumably the "World" aggregate — confirm).
    top_countries_codes = (
        dict_datasets[0]
        .dict_data["Cases"]
        .loc[ENDT]
        .sort_values(ascending=False)[1:]
        .index.values
    )
    top_countries_names = [codes_to_countries[c] for c in top_countries_codes]
    states = list(
        dict_datasets[1]
        .dict_data["Cases"]
        .loc[dict_datasets[1].ENDT, np.array(list(states_to_codes.keys()))]
        .sort_values(ascending=False)
        .index.values
    )
    presaved_sel_lists["Top10 (Cases)"] = top_countries_codes[:10]
    presaved_sel_lists["Top10 (Deaths)"] = (
        dict_datasets[0]
        .dict_data["Deaths"]
        .loc[ENDT]
        .sort_values(ascending=False)[1:11]
        .index.values
    )
    presaved_sel_lists["Top10 (Cases/1M)"] = (
        dict_datasets[0]
        .dict_data_pop["Cases"]
        .loc[ENDT]
        .sort_values(ascending=False)[1:11]
        .index.values
    )
    presaved_sel_lists["Top10 (Deaths/1M)"] = (
        dict_datasets[0]
        .dict_data_pop["Deaths"]
        .loc[ENDT]
        .sort_values(ascending=False)[1:11]
        .index.values
    )
    presaved_sel_lists["Top10 (Cases daily chg)"] = (
        dict_datasets[0]
        .get_value(None, "Cases", "Values", "Linear", "Daily change", False, 7, ENDT)
        .sort_values(ascending=False)[1:11]
        .index.values
    )
    presaved_sel_lists["Top10 (Deaths daily chg)"] = (
        dict_datasets[0]
        .get_value(None, "Deaths", "Values", "Linear", "Daily change", False, 7, ENDT)
        .sort_values(ascending=False)[1:11]
        .index.values
    )
    countries_options = list(np.sort(top_countries_names))
    states_options = list(np.sort(states))
    countries_selector.options = ["None", "World"] + countries_options
    states_selector.options = ["None"] + states_options
    countries_selector.value = top_countries_names[:6]
    states_selector.value = states[:2]
    countries_selector.observe(update_rebased_graph, "value")
    states_selector.observe(update_rebased_graph, "value")
    tab_2_box_selectors.layout.grid_template_columns = "15% 27% 16% 27% 15%"
    plot_rebased_graph(
        "Cases",
        "Values",
        "Total",
        1000,
        top_countries_names[:6],
        states[:2],
        "Log",
        "Cases",
        "Values",
        tab_2_ma_ch.value,
        tab_2_ma_w.value,
        calendar_time.value,
    )

    # initialize grid_3

    dna_countries_selector.options = ["None", "World"] + countries_options
    dna_states_selector.options = ["None"] + states_options
    dna_countries_selector.value = top_countries_names[1:20]
    dna_states_selector.value = states[:4]
    dna_countries_selector.observe(update_dna, "value")
    dna_states_selector.observe(update_dna, "value")

    # The last two get_ts arguments are the moving-average flag and window.
    # The window comes from tab_3_ma_w, as for tab_1/2/4_ma_w elsewhere in this
    # function; the checkbox value was previously passed as the window too.
    data_countries = (
        dict_datasets[0]
        .get_ts(
            [countries_to_codes[c] for c in dna_countries_selector.value],
            dna_data_button.value,
            dna_norm_button.value,
            "Linear",
            dna_type_button.value,
            tab_3_ma_ch.value,
            tab_3_ma_w.value,
        )
        .rename(columns=codes_to_countries)
    )
    data_states = dict_datasets[1].get_ts(
        dna_states_selector.value,
        dna_data_button.value,
        dna_norm_button.value,
        "Linear",
        dna_type_button.value,
        tab_3_ma_ch.value,
        tab_3_ma_w.value,
    )
    # Outer-join countries and states on the date index, then reverse the
    # column order.
    df_DNA = pd.merge(
        data_countries, data_states, left_index=True, right_index=True, how="outer"
    ).iloc[:, ::-1]

    column_dna = pd.to_datetime(df_DNA.index.values)
    row_dna = list(df_DNA.columns.values)
    dna_heat_map = GridHeatMap(
        color=df_DNA.values.T,
        column=column_dna,
        row=row_dna,
        column_align="start",
        scales={"column": dna_x_scale, "row": dna_y_scale, "color": dna_color_scale},
        stroke="white",
    )
    dna_figure.marks = [dna_heat_map]
    dna_heat_map.tooltip = dna_stats_tootltip
    dna_heat_map.on_hover(update_dna_tooltip)
    tab_3_box_selectors.layout.grid_template_columns = "15% 27% 16% 27% 15%"

    # initialize grid_4
    free_countries_selector.options = ["None", "World"] + countries_options
    free_states_selector.options = ["None"] + states_options
    free_countries_selector.value = [
        "United States",
        "France",
        "World",
        "Italy",
        "China",
        "Germany",
        "Russia",
        "Brazil",
    ]
    free_states_selector.value = ["New York", "California"]
    free_countries_selector.observe(update_free_scatter_fig, "value")
    free_states_selector.observe(update_free_scatter_fig, "value")

    free_play.max = dict_datasets[0].get_len() - 1
    free_play.value = dict_datasets[0].get_len() - 1
    # Codes matching the selector values set just above.
    selected_countries = [
        ["USA", "FRA", "WLD", "ITA", "CHN", "DEU", "RUS", "BRA"],
        ["New York", "California"],
    ]
    datas = ["Cases", "Cases"]
    norms = ["Values", "Values"]
    scales = ["Log", "Log"]
    data_types = ["Total", "Daily change"]
    last_point = False
    free_scatter(
        selected_countries,
        datas,
        norms,
        scales,
        data_types,
        tab_4_ma_ch.value,
        tab_4_ma_w.value,
        last_point,
    )
    # Attach the remaining observers only now that every widget has its
    # initial value.
    date_axis_selector.observe(update_zoom, "value")
    date_axis_selector.continuous_update = False
    play.observe(update_index_date, "value")
    date_selector.observe(select_date, "value")
    free_play.observe(update_free_play, "value")
    download_button.layout.visibility = "visible"
    hide_button.layout.visibility = "visible"
    update_link_download()
    return


In [None]:
def accept_ToS(*args):
    """Click handler of the "I agree" button: load the data and show the dashboard.

    Steps:
      1. Replace the tab content with a progress bar and a status line.
      2. Download the world and US datasets and publish them, together with
         the date bounds and code/name mappings, as module-level globals.
      3. Initialise every tab via ``initialize_dashboard``.
      4. Install the four grids as the children of ``outer_tab``.

    ``progress_bar`` and ``current_action`` are globals (presumably so the
    data-collection helpers can report progress — confirm).
    """
    global STDT, ENDT, countries_to_codes, codes_to_countries, states_to_codes, dict_datasets, progress_bar, current_action
    progress_bar = FloatProgress(
        value=0.1,
        min=0,
        max=1.0,
        step=0.01,
        description="Loading data:",
        orientation="horizontal",
    )
    current_action = HTML("Collecting country level Cases/Deaths/Recovered data...")
    outer_tab.children = [
        VBox(
            [HTML("ONE MOMENT PLEASE..."), progress_bar, current_action],
            layout=Layout(width="auto"),
        )
    ]
    (
        datasets_World,
        STDT,
        ENDT,
        countries_to_codes,
        codes_to_countries,
    ) = collect_World_data(BASE_URL, FOLDER_WORLD, URL_TESTS)
    datasets_US_States, states_to_codes, datasets_US_counties = collect_US_data(
        BASE_URL_US, FOLDER_US, URL_TESTS_US
    )
    dict_datasets = {0: datasets_World, 1: datasets_US_States, 2: datasets_US_counties}
    initialize_dashboard(dict_datasets)
    # Update the widget's value; assigning to the bare name would replace the
    # global FloatProgress with a float and leave the displayed bar unchanged.
    progress_bar.value = 0.98
    current_action.value = "Create Dashboard..."
    tab_contents = [
        "Infection Maps",
        "Rebased Graph",
        "Heat Map",
        "Custom Graph",
    ]
    children = [
        grid_1,
        grid_2,
        grid_3,
        grid_4,
    ]
    outer_tab._titles = dict(zip(np.arange(0, 4), tab_contents))
    outer_tab.children = children
    outer_tab.selected_index = 0


agree.on_click(accept_ToS)


In [None]:
# Grid templates per tab number (1-4), each as
# [grid_template_rows, grid_template_columns], used by hide_buttons().
# visible_template is applied when the controls are shown.
visible_template = {
    1: ["95px auto", "60% auto"],
    2: ["315px auto", "auto 566px"],
    3: ["140px auto", "auto 566px"],
    4: ["135px auto 60px", "135px auto 210px"],
}

# hidden_template is applied when the controls are hidden; it sets several
# tracks to 0px.
hidden_template = {
    1: ["0px auto", "60% auto"],
    2: ["0px auto", "auto 0px"],
    3: ["0px auto", "auto 0px"],
    4: ["0px auto 0px", "0px auto 0px"],
}
# Tab number -> grid widget of that tab.
grids = dict(zip(np.arange(1, 5), [grid_1, grid_2, grid_3, grid_4]))


def hide_buttons(*args):
    """Observer callback of ``hide_button``: hide or show the tab controls.

    When the toggle is on (``args[0]["new"]`` truthy) the button turns red,
    reads "Show Controls" and every grid gets its ``hidden_template``;
    otherwise the button is reset and the ``visible_template`` is restored.
    """
    if args[0]["new"]:
        style, caption, templates = "danger", "Show Controls", hidden_template
    else:
        style, caption, templates = "", "Hide Controls", visible_template

    hide_button.button_style = style
    hide_button.description = caption
    for tab_number in range(1, 5):
        rows, columns = templates[tab_number]
        grid_layout = grids[tab_number].layout
        grid_layout.grid_template_rows = rows
        grid_layout.grid_template_columns = columns
    return


In [None]:
# Toggle that hides/shows the controls of every tab. It is created hidden;
# initialize_dashboard() makes it visible.
hide_button = ToggleButton(
    description="Hide Controls",
    value=False,
    button_style="",
    layout=Layout(
        min_width="150px", width="150px", max_width="150px", min_height="30px",
        visibility='hidden',
    ),
)
hide_button.observe(hide_buttons, "value")


In [None]:
# Button that saves the figure of the selected tab as PNG (see download_fig).
# It is created hidden; initialize_dashboard() makes it visible.
download_button = Button(
    description="Download Figure",
    button_style="",
    layout=Layout(
        min_width="150px",
        width="150px",
        max_width="150px",
        min_height="30px",
        visibility="hidden",
    ),
)


In [None]:
def download_fig(*args):
    """Click handler: save the figure of the selected tab as a PNG.

    The file name is the figure title with dots removed, prefixed with
    "Rebased_graph_" or "Heat_map_" for the second and third tab. Nothing
    happens when the selected index matches none of the four tabs.
    """
    # tab index -> (figure, file-name prefix)
    exports = {
        0: (main_graph, ""),
        1: (rebased_graph, "Rebased_graph_"),
        2: (dna_figure, "Heat_map_"),
        3: (free_scatter_fig, ""),
    }
    target = exports.get(outer_tab.selected_index)
    if target is None:
        return
    figure, prefix = target
    figure.save_png(prefix + figure.title.replace('.', ''))
    return


download_button.on_click(download_fig)


In [None]:
def download_data():
    """Collect the data shown in the selected tab as a DataFrame.

    Returns
    -------
    (data_csv, title) : (pandas.DataFrame, str)
        The plotted data and the file name for its CSV export. The file name
        is derived from the figure title (dots removed, "/" replaced by
        "Per", and "Log " removed for every tab except the heat map).

    The data is read back from the marks of the figure of the selected tab
    (``outer_tab.selected_index`` 0 to 3). For any other index, nothing is
    assigned and the final ``return`` raises ``UnboundLocalError``.
    """
    selected_index = outer_tab.selected_index
    if selected_index == 0:
        data_csv = pd.DataFrame(
            data=main_graph.marks[0].y,
            index=main_graph.marks[0].labels,
            columns=main_graph.marks[0].x,
        )
        title = main_graph.title.replace(".", "").replace("/", "Per").replace("Log ", '') + "_data.csv"
    elif selected_index == 1:
        # With a single label the y data is one-dimensional and is reshaped
        # into a single row.
        if rebased_graph.marks[1].text.shape[0] == 1:
            cols = np.arange(0, rebased_graph.marks[0].y.shape[0])
            data = rebased_graph.marks[0].y.reshape(1, -1)
        elif rebased_graph.marks[1].text.shape[0] > 1:
            cols = np.arange(0, rebased_graph.marks[0].y.shape[1])
            data = rebased_graph.marks[0].y
        if calendar_time.value:
            # Calendar time: columns are the x values of the mark.
            data_csv = pd.DataFrame(
                data=data,
                index=rebased_graph.marks[1].text,
                columns=rebased_graph.marks[0].x,
            )
        else:
            # Otherwise: columns are integer positions 0, 1, 2, ...
            data_csv = pd.DataFrame(
                data=data, index=rebased_graph.marks[1].text, columns=cols
            )
        title = (
            "Rebased_graph_"
            + rebased_graph.title.replace(".", "").replace("/", "Per").replace("Log ", '')
            + "_data.csv"
        )
    elif selected_index == 2:
        data_csv = pd.DataFrame(
            data=dna_figure.marks[0].color,
            index=dna_figure.marks[0].row,
            columns=dna_figure.marks[0].column,
        ).iloc[::-1]
        title = (
            "Heat_map_"
            + dna_figure.title.replace(".", "").replace("/", "Per")
            + "_data.csv"
        )
    elif selected_index == 3:
        arr = []
        ind0 = []
        ind1 = []
        # The title reads "<y description> vs <x description>, <date>".
        Y_label = free_scatter_fig.title.split("vs")[0].strip().replace("Log ", '')
        X_label = free_scatter_fig.title.split("vs")[1].split(",")[0].strip().replace("Log ", '')
        if tab_4_ma_ch.value:
            # The moving-average suffix is only written on the x description;
            # copy it to the y label as well.
            Y_label += " (" + X_label.split("(")[1]
        # Two rows per entity (x values, then y values), indexed by
        # (entity label, axis label).
        for i in range(free_scatter_fig.marks[1].text.shape[0]):
            arr.append(free_scatter_fig.marks[0].x[i])
            arr.append(free_scatter_fig.marks[0].y[i])
            ind0.append(free_scatter_fig.marks[1].text[i])
            ind0.append(free_scatter_fig.marks[1].text[i])
            ind1.append(X_label)
            ind1.append(Y_label)
        ind = [np.array(ind0), np.array(ind1)]
        if free_checkbox.value:
            col = [ENDT]
        else:
            col = np.arange(
                np.datetime64(STDT), np.datetime64(ENDT) + np.timedelta64(1, "D")
            )
            # update_free_play may have cut the plotted series after the
            # current animation frame; keep only as many dates as there are
            # plotted points so the frame can still be built.
            if arr:
                col = col[: len(arr[0])]
        data_csv = pd.DataFrame(data=arr, index=ind, columns=col)
        title = (
            free_scatter_fig.title.replace(".", "").replace("/", "Per").replace("Log ", '') + "_data.csv"
        )
    return data_csv, title


# HTML template of the "Download Data" link. The CSV is embedded in the href
# as a base64 data URI; {visibility}, {filename} and {payload} are filled in by
# update_link_download().
download_link = '<a style="background-color: #424242; border: none; color: white; padding: 1px 22px; visibility: {visibility}; text-align: center; text-decoration: none; display: inline-block; font-size: 13px; margin: 0px; cursor: pointer" download="{filename}" href="data:text/csv;base64,{payload}" target="_blank">Download Data</a>'
# The link starts hidden and empty.
download_data_button = HTML(
    download_link.format(visibility="hidden", payload="", filename="")
)


def update_link_download():
    """Rebuild the "Download Data" link for the currently selected tab.

    The DataFrame returned by ``download_data`` is serialised to CSV,
    base64-encoded and embedded in the link, which is made visible.
    """
    frame, filename = download_data()
    payload = base64.b64encode(frame.to_csv().encode()).decode()
    download_data_button.value = download_link.format(
        visibility="visible", payload=payload, filename=filename
    )
    return


# Refresh the link every time another tab is selected.
outer_tab.observe(lambda _change: update_link_download(), "selected_index")


In [None]:
# Slider controlling the height of the dashboard's tab area, in pixels
# (600-1500, step 10, initially 780).
zoom_slider = IntSlider(
    value=780,
    min=600,
    max=1500,
    step=10,
    description="Dashboard's height (in px)",
    style={"description_width": "initial"},
    layout=Layout(width="auto", min_width="300px", max_width="500px"),
)


def update_height(*args):
    """Observer callback: set the height of ``outer_tab`` to the new slider value (px)."""
    new_height = args[0]["new"]
    outer_tab.layout.height = "{!s}px".format(new_height)
    return


# Resize while the slider is being dragged, not only on release.
zoom_slider.continuous_update = True
zoom_slider.observe(update_height, "value")


In [None]:
# Dashboard header: the title spans the first row; the second row holds the
# hide toggle, the two download controls and the height slider.
grid_header = GridspecLayout(2, 5)
grid_header.layout.width = "100%"
grid_header.layout.height = "auto"

grid_header[1, 0] = hide_button
grid_header[1, 1] = download_button
grid_header[1, 2] = download_data_button
# Column 3 is left empty; its "auto" track sits between the buttons and the slider.
grid_header[1, 4] = zoom_slider
grid_header[0, :] = HTML(
    value="<p style='color: #ff8b0e; padding:0; margin:0; text-align:center; font-weight: bold; font-size: 42px'>COVIZ</p>"
)
grid_header.layout.align_items = "stretch"
grid_header.layout.grid_template_columns = "155px 155px 155px auto 450px"
grid_header.layout.overflow = "auto"


In [None]:
# Footer content: attribution text (right) and data-source links (left).
logo = """
<p style='color: #ff8b0e; font-weight: bold; font-size: 16px; padding:0; margin:0'> Bloomberg Quant Research </p>
"""
data_source = """<html><body><p style='color: #ff8b0e; padding:0; margin:0'> Data Source : - JHU <a href='https://github.com/CSSEGISandData/COVID-19' target="_blank" style='color:#1E90FF'> https://github.com/CSSEGISandData/COVID-19 </a> 
             - COVID Tracking Project <a href='https://github.com/COVID19Tracking/covid-tracking-data' target="_blank" style='color:#1E90FF'> https://github.com/COVID19Tracking/covid-tracking-data </a>
             - Our World in Data <a href='https://github.com/owid/covid-19-data' target="_blank" style='color:#1E90FF'> https://github.com/owid/covid-19-data </a></p></body></html>"""
grid_footer = GridspecLayout(1, 2, layout=Layout(width="100%", align_items="center"))
# The data-source text is held at a fixed 30px height and scrolls on overflow.
grid_footer[0, 0] = HBox(
    [
        HTML(
            data_source,
            layout=Layout(
                width="100%",
                height="30px",
                max_height="30px",
                min_height="30px",
                overflow="auto",
            ),
        )
    ],
    layout=Layout(width="auto", height="auto", overflow="auto"),
)
grid_footer[0, 1] = HTML(logo, layout=Layout(width="229px"))
grid_footer.layout.grid_template_columns = "auto 235px"
# Top-level widget: header, tab area and footer stacked vertically.
Dashboard = VBox(
    [grid_header, outer_tab, grid_footer,], layout=Layout(align_items="stretch")
)


In [None]:
# Last expression of the cell: renders the dashboard widget as the cell output.
Dashboard