In [None]:
import numpy as np
import pandas as pd
import time
import math
from ipywidgets import *
from bqplot import *
import datetime
from data_processor import DataProcessor
import bisect

In [None]:
# Collect World Data
#
# JHU CSSE global time series are fetched straight from GitHub; to work offline,
# uncomment the FOLDER_WORLD variants below so the cached CSVs are used instead.

BASE_URL = "https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/csse_covid_19_data/csse_covid_19_time_series/"

url_cases = BASE_URL + "time_series_covid19_confirmed_global.csv"
url_deaths = BASE_URL + "time_series_covid19_deaths_global.csv"
url_rec = BASE_URL + "time_series_covid19_recovered_global.csv"
# you can use the following cached data
FOLDER_WORLD = "DATA/WORLD/"
# url_cases = FOLDER_WORLD + 'time_series_covid19_confirmed_global.csv'
# url_deaths = FOLDER_WORLD + 'time_series_covid19_deaths_global.csv'
# url_rec = FOLDER_WORLD + 'time_series_covid19_recovered_global.csv'

# One raw frame per measure; CSV column 1 (country/region name) becomes the index.
dataframes = {
    label: pd.read_csv(source, header=0, index_col=1)
    for label, source in (
        ("Cases", url_cases),
        ("Deaths", url_deaths),
        ("Recovered", url_rec),
    )
}

# Population table keyed by "Country Code"; rows with any missing field are dropped.
pops = (
    pd.read_csv(FOLDER_WORLD + "Population.csv", header=0)
    .dropna()
    .set_index("Country Code")
)

country_codes = pd.read_csv(FOLDER_WORLD + "world_map_codes.csv", header=0, index_col=0)


In [None]:
# Country name (index of world_map_codes.csv) -> ISO alpha-3 code.
countries_to_codes = country_codes["ISOA3"].to_dict()
# ISO numeric code -> ISO alpha-3 code (on duplicate numeric codes the last row wins).
ID_to_codes = {
    numeric: alpha
    for numeric, alpha in zip(country_codes["ISON3"], country_codes["ISOA3"])
}

In [None]:
# normalize country names
#
# Maps JHU country/region spellings (the index of each raw frame) to the names
# used as keys of countries_to_codes (the index of world_map_codes.csv), so that
# the later name -> ISO-3 lookup succeeds for these entries.

new_names = {
    "Antigua and Barbuda": "Antigua & Barbuda",
    "Bosnia and Herzegovina": "Bosnia",
    "Cabo Verde": "Cape Verde",
    "Congo (Brazzaville)": "Congo - Brazzaville",
    "Congo (Kinshasa)": "Congo - Kinshasa",
    "Eswatini": "Swaziland",
    "Holy See": "Vatican City",
    "Korea, South": "South Korea",
    "North Macedonia": "Macedonia",
    "Saint Lucia": "St. Lucia",
    "Saint Vincent and the Grenadines": "St. Vincent & Grenadines",
    "Trinidad and Tobago": "Trinidad & Tobago",
    "US": "United States",
    "Saint Kitts and Nevis": "St. Kitts & Nevis",
    "Burma": "Myanmar",
    "Taiwan*": "Taiwan",
}
# Rename index labels in place on every raw frame (Cases, Deaths, Recovered).
for df in dataframes.values():
    df.rename(index=new_names, inplace=True)

In [None]:
# assign dummy codes to a few entities
#
# These JHU entries have no usable row in world_map_codes.csv, so they get
# hand-picked (partly invented) alpha-3 codes and, where drawable, numeric IDs.
countries_to_codes.update(
    {
        "Diamond Princess": "DPS",
        "Taiwan": "TWN",
        "West Bank and Gaza": "PSE",
        "Kosovo": "XKX",
        "MS Zaandam": "MSZ",
    }
)
ID_to_codes.update({158: "TWN", -2: "XKX"})

# Reverse lookups; if several keys share a value, the last one wins.
codes_to_countries = dict((code, country) for country, code in countries_to_codes.items())
codes_to_ID = dict((code, numeric) for numeric, code in ID_to_codes.items())

In [None]:
# Tag every row with its ISO alpha-3 code (None when the name is unknown,
# which makes the later groupby on "code" drop that row).
for df in dataframes.values():
    df["code"] = [countries_to_codes.get(cod) for cod in df.index.values]

In [None]:
old_dateformat = "%m/%d/%y"
dateformat = "%Y-%m-%d"
# Date columns sit after the first three columns and before the trailing
# "code" column added above; convert their labels to ISO dates.
new_index = [
    datetime.datetime.strptime(d, old_dateformat).strftime(dateformat)
    for d in dataframes["Cases"].columns.values[3:-1]
]
# set index and dateformat: sum rows per ISO code, then pivot to one row per
# date and one column per code, indexed by the ISO date string.
for k, df in dataframes.items():
    df = (
        df.iloc[:, 3:]
        .groupby(["code"])
        .sum()
        .transpose()
        .reset_index()
        .assign(Date=new_index)
    )
    dataframes[k] = df.set_index("Date", drop=True).drop("index", axis=1)

In [None]:
# First and last available dates (ISO strings) of the world Cases table.
STDT, ENDT = dataframes["Cases"].index[[0, -1]]

In [None]:
# Active = confirmed - deaths - recovered (missing deaths/recovered count as 0).
dataframes["Active Cases"] = (
    dataframes["Cases"]
    .sub(dataframes["Deaths"].fillna(0))
    .sub(dataframes["Recovered"].fillna(0))
)
# World data processor
datasets_World = DataProcessor(dataframes, pops, "World", codes_to_ID)

In [None]:
# Collect US Data
#
# US confirmed/deaths series from the same JHU repository; uncomment the
# FOLDER_US variants to read the cached copies instead.

url_cases_US = BASE_URL + "time_series_covid19_confirmed_US.csv"
url_deaths_US = BASE_URL + "time_series_covid19_deaths_US.csv"
# you can use the following cached data
FOLDER_US = "DATA/USA/"
# url_cases_US = FOLDER_US + 'time_series_covid19_confirmed_US.csv'
# url_deaths_US = FOLDER_US + 'time_series_covid19_deaths_US.csv'
df_cases_US, df_deaths_US = (
    pd.read_csv(source, header=0) for source in (url_cases_US, url_deaths_US)
)

In [None]:
# Daily date columns follow the first 11 metadata columns of the US
# confirmed-cases CSV; convert their labels to ISO dates.
new_index = [
    datetime.datetime.strptime(d, old_dateformat).strftime(dateformat)
    for d in df_cases_US.columns.values[11:]
]
# Aggregate rows per state, then pivot to one row per date / one column per state.
# numeric_only=True restricts the sum to numeric columns, which older pandas did
# implicitly. pandas >= 2.0 defaults to numeric_only=False and would also try to
# aggregate the text metadata columns. That either fails or leaves extra rows,
# so the "Date" assignment below no longer matches in length.
df_cases_US_States = (
    df_cases_US.groupby(["Province_State"])
    .sum(numeric_only=True)
    .transpose()
    .drop(["UID", "code3", "FIPS", "Lat", "Long_"])
    .reset_index()
)
df_cases_US_States["Date"] = new_index
df_cases_US_States = df_cases_US_States.set_index("Date").drop("index", axis=1)

# Same aggregation for deaths. This table also carries a "Population" row,
# which becomes the per-state population table (zeros treated as missing).
df_deaths_US_States = (
    df_deaths_US.groupby(["Province_State"])
    .sum(numeric_only=True)
    .transpose()
    .drop(["UID", "code3", "FIPS", "Lat", "Long_"])
)
pops_US_States = (
    df_deaths_US_States.loc["Population"]
    .to_frame(name="2018")
    .replace({0: math.nan})
    .dropna()
)
df_deaths_US_States.drop(["Population"], inplace=True)
df_deaths_US_States["Date"] = new_index
df_deaths_US_States = (
    df_deaths_US_States.reset_index().set_index("Date").drop("index", axis=1)
)

In [None]:
# No recovered series is loaded for US states: use an all-zero float frame with
# the same labels as the cases table (built from .values, so the index is unnamed).
df_rec_US_States = pd.DataFrame(
    0.0,
    index=df_cases_US_States.index.values,
    columns=df_cases_US_States.columns.values,
)

dict_df_US_States = {
    "Cases": df_cases_US_States,
    "Deaths": df_deaths_US_States,
    "Recovered": df_rec_US_States,
    "Active Cases": df_cases_US_States.sub(df_deaths_US_States.fillna(0)).sub(
        df_rec_US_States.fillna(0)
    ),
}

# "ID" column of the state map codes CSV, keyed by its first column
# (presumably state name -> map feature ID; verify against the CSV).
states_to_codes = pd.read_csv(
    FOLDER_US + "USStatesMap_codes.csv", index_col=0, header=0
)["ID"].to_dict()
datasets_US_States = DataProcessor(
    dict_df_US_States, pops_US_States, "USStates", states_to_codes
)

In [None]:
# Keyed by the map Tab's selected_index: 0 -> World, 1 -> US States.
dict_datasets = dict(enumerate((datasets_World, datasets_US_States)))

In [None]:
# Buttons and interactions

style = {"description_width": "60px", "font_weight": "bold", "button_width": "auto"}
style_2 = {"description_width": "60px", "font_weight": "bold", "button_width": "100px"}


def _control_toggle(options, value, description):
    """Build one auto-sized, unstyled ToggleButtons control for the map tab."""
    return ToggleButtons(
        options=options,
        value=value,
        description=description,
        style=style,
        button_style="",
        layout=Layout(width="auto", height="auto"),
    )


# Measure used to colour the map (and rank the tables)
data_buttons = _control_toggle(
    ["Cases", "Deaths", "Recovered", "Active Cases"], "Cases", "Data"
)
# Raw values or values per million inhabitants
norm_buttons = _control_toggle(["Values", "Per 1MPop"], "Values", "Norm")
# Linear or logarithmic scale for plots and map
scale_buttons = _control_toggle(["Linear", "Log"], "Linear", "Scale")
# Cumulative totals or daily change (absolute / percent)
type_buttons = _control_toggle(
    ["Total", "Daily change", "Daily % change"], "Total", "Type"
)


In [None]:
# Main graph

# CSS prepended to hover tables; it styles elements with id="stats_table".
# NOTE(review): it opens <body>, but the hover callbacks never append a closing
# </body> — harmless in practice, but worth confirming.
css_style = """
<head>
    <meta http-equiv="Content-Type" content="text/html; charset=utf-8">
    <style type="text/css" media="screen">
        #stats_table {
          font-family: "Trebuchet MS", Arial, Helvetica, sans-serif;
          font-size: 12px;
          border-collapse: collapse;
          width: auto;
          text-align: left;
        }
        #stats_table td, #stats_table th {
          border: 1px solid #ddd;
          padding: 0px 0px;
        }
        #stats_table tbody td {
          font-size: 12px;
          font-weight: bold;
        }
        #stats_table tr:nth-child(even) {
          background-color: #f2f2f2;
        }
        #stats_table th {
          padding-top: 12px;
          padding-bottom: 12px;
          text-align: center;
          background-color: #003366;
          color: white;
        }
    </style>
</head>
<body>
"""

# Tooltip for main-graph markers. Placeholder {0} is the date, followed by four
# (label, value) pairs, one per curve, filled positionally by line_hover.
line_tooltip_table = """
<table id="stats_table">
<tr>
    <td>Date</td>
    <td>{0:}</td>
</tr>
<tr>
    <td>{1:s}</td>
    <td>{2:}</td>
</tr>
<tr>
    <td>{3:s}</td>
    <td>{4:}</td>
</tr>
<tr>
    <td>{5:s}</td>
    <td>{6:}</td>
</tr>
<tr>
    <td>{7:s}</td>
    <td>{8:}</td>
</tr>
</table>
"""

stats_line_tooltip_table_values = HTML()  # table shown when you hover over Lines markers


In [None]:
# Time axis covering the full data range [STDT, ENDT]; update_main_graph
# installs a fresh DateScale on axis_t for each redraw.
scale_t = DateScale(
    dateformat=dateformat,
    min=datetime.datetime.strptime(STDT, dateformat),
    max=datetime.datetime.strptime(ENDT, dateformat),
)
axis_t = Axis(scale=scale_t, grid_lines="none")
# Initial y scale; replaced by LinearScale/LogScale according to the Scale button.
scale_y = LinearScale()
axis_y = Axis(
    scale=scale_y, orientation="vertical", grid_lines="solid", tick_format=","
)
# Title template: "{}" receives the country name or "<state> State".
main_graph_ttl = "COVID-19 in {}"

# Figure for the four time-series curves; its marks are set by update_main_graph.
main_graph = Figure(
    animation_duration=1000,
    legend_location="top-left",
    axes=[axis_t, axis_y],
    title=main_graph_ttl,
    fig_margin={"top": 50, "bottom": 50, "left": 55, "right": 0},
    layout=Layout(width="auto"),
)


In [None]:
# Line colour of each series drawn in the main graph.
curve_colors = {
    "Cases": "dodgerblue",
    "Deaths": "red",
    "Recovered": "green",
    "Active Cases": "orange",
}


def update_main_graph(
    datasets,
    country_name,
    normalization,
    scale,
    data_type,
    date1,
    date2,
    axis_t,
    axis_y,
    figure,
):
    """Redraw ``figure`` with the four series of one country or state.

    Parameters
    ----------
    datasets : DataProcessor
        Provides ``get_ts_plot``, ``get_len`` and ``name`` ("World" or "USStates").
    country_name : str
        Key passed to ``datasets.get_ts_plot``. For the "World" dataset it must
        also be a key of ``codes_to_countries`` (used for the title).
    normalization, scale, data_type : str
        Current Norm / Scale / Type selections; ``scale`` must be "Linear" or "Log".
    date1, date2 : str
        Inclusive date range in ``dateformat``.
    axis_t, axis_y : Axis
        Axes whose scales are replaced.
    figure : Figure
        Figure whose axes, marks and title are replaced.

    Raises
    ------
    ValueError
        If ``scale`` is neither "Linear" nor "Log".

    Side effects: resets the global ``date_axis_selector`` to its full range
    and attaches the hover/legend callbacks to the new Lines mark.
    """
    selected = ["Cases", "Deaths", "Recovered", "Active Cases"]
    line_col = [curve_colors[c] for c in selected]
    data = [
        datasets.get_ts_plot(
            country_name, data_name, normalization, scale, data_type, date1, date2
        )
        for data_name in selected
    ]
    start = datetime.datetime.strptime(date1, dateformat)
    end = datetime.datetime.strptime(date2, dateformat)
    scale_t = DateScale(dateformat=dateformat, min=start, max=end)
    axis_t.scale = scale_t

    if scale == "Log":
        scale_y = LogScale()
    elif scale == "Linear":
        scale_y = LinearScale()
    else:
        # Without this branch scale_y would be unbound and the failure would
        # surface as an opaque UnboundLocalError below.
        raise ValueError(
            "Unsupported scale {!r}: expected 'Linear' or 'Log'".format(scale)
        )

    main_mark = Lines(
        scales={"x": scale_t, "y": scale_y},
        display_legend=True,
        labels=selected,
        colors=line_col,
        marker="circle",
        marker_size=30,
        tooltip=stats_line_tooltip_table_values,
        interactions={"hover": "tooltip"},
        curves_subset=[i for i in range(4)],
        opacities=[1] * 4,
    )
    # Reset the zoom slider to the full range. The slider's value observer
    # (update_zoom) presumably fires here, while the previous mark is still
    # shown — keep this before the figure's marks are replaced.
    date_axis_selector.index = (0, datasets.get_len() - 1)
    main_mark.y = data
    axis_y.scale = scale_y
    figure.axes = [axis_t, axis_y]
    # One x point per day in [date1, date2], inclusive.
    main_mark.x = list(
        np.arange(
            np.datetime64(start.strftime(dateformat)),
            np.datetime64(end.strftime(dateformat)) + np.timedelta64(1, "D"),
        )
    )
    main_mark.on_hover(line_hover)
    main_mark.on_legend_hover(legend_hover)
    main_mark.on_legend_click(legend_click_line)

    figure.marks = [main_mark]

    # event() parses this title back, so its format must stay in sync there.
    if datasets.name == "World":
        figure.title = main_graph_ttl.format(codes_to_countries[country_name])
    elif datasets.name == "USStates":
        figure.title = main_graph_ttl.format(country_name + " State")


In [None]:
def name_val_on_hover(data, norm, data_type):
    """Return the display label for a measure under the current selections.

    ``data`` is the measure name. It gets "/1MPop" when ``norm`` is
    "Per 1MPop", then " daily change" or " daily % change" for the
    corresponding ``data_type``. Any other value adds nothing.
    """
    type_suffix = {
        "Daily change": " daily change",
        "Daily % change": " daily % change",
    }
    label = data
    if norm == "Per 1MPop":
        label = label + "/1MPop"
    suffix = type_suffix.get(data_type)
    if suffix:
        label = label + suffix
    return label


def line_hover(*args):
    """Fill the main-graph tooltip with all four values at the hovered date.

    Registered with ``Lines.on_hover``; ``args[1]["data"]["sub_index"]`` is the
    hovered marker's position along the x array. Values are shown as
    comma-grouped integers for raw counts and with two decimals otherwise
    (per-million or percentage change). NaN values are shown as 0. Any failure
    blanks the tooltip rather than raising inside the widget callback.
    """
    selected = ["Cases", "Deaths", "Recovered", "Active Cases"]
    try:
        index = args[1]["data"]["sub_index"]
        mark = main_graph.marks[0]
        Date = str(mark.x[index])
        # get value of cases, deaths, rec and active at Date (NaN -> 0)
        values = list(np.nan_to_num([mark.y[i][index] for i in range(4)]))
        name_vals = [
            name_val_on_hover(data, norm_buttons.value, type_buttons.value)
            for data in selected
        ]
        if norm_buttons.value == "Values" and type_buttons.value != "Daily % change":
            formatted = ["{0:,d}".format(int(v)) for v in values]
        else:
            formatted = ["{0:,.2f}".format(v) for v in values]
        # Interleave labels and values to match line_tooltip_table's placeholders.
        res = [Date]
        for label, text in zip(name_vals, formatted):
            res.extend([label, text])
        stats_line_tooltip_table_values.value = css_style + line_tooltip_table.format(
            *res
        )
    except Exception:
        # Best effort: a bad hover event must not break the widget; just clear.
        stats_line_tooltip_table_values.value = ""


def legend_hover(*args):
    """Emphasise the curve whose legend entry is hovered.

    ``args[1]["data"]["index"]`` is the curve position within the four-curve
    mark. Normally the other curves are dimmed to opacity 0.3. All curves are
    restored to opacity 1 instead when the hovered curve is not in
    ``curves_subset``, or when the curves are already dimmed and the hovered
    curve is the fully opaque one.
    """
    hovered = args[1]["data"]["index"]
    dimmed = [0.3] * 4
    dimmed[hovered] = 1
    mark = main_graph.marks[0]
    already_highlighted = (
        np.sum(mark.opacities) < 4 and mark.opacities[hovered] == 1
    )
    if already_highlighted or hovered not in mark.curves_subset:
        mark.opacities = [1] * 4
    else:
        mark.opacities = dimmed


def legend_click_line(*args):
    """Change which curves are visible from a legend click, then refit y.

    As implemented:
    * clicking a visible curve while others are visible isolates it
      (only that curve remains in ``curves_subset``);
    * clicking the only visible curve shows all four again;
    * clicking a hidden curve adds it back.
    The y-scale max/min are then set to the NaN-aware extremes of the visible
    curves. If the clicked curve has opacity 1, all opacities reset to 1.
    """
    mark = main_graph.marks[0]
    clicked = args[1]["data"]["index"]
    visible = mark.curves_subset
    if clicked not in visible:
        mark.curves_subset = visible + [clicked]
    elif len(visible) == 1:
        mark.curves_subset = list(range(4))
    else:
        mark.curves_subset = [clicked]

    y_scale = mark.scales["y"]
    y_scale.max = float(np.nanmax(np.array(mark.y)[mark.curves_subset]))
    y_scale.min = float(np.nanmin(np.array(mark.y)[mark.curves_subset]))
    if mark.opacities[clicked] == 1:
        mark.opacities = [1] * 4


In [None]:
# Maps

def map_title(map_name, data, norm, data_type, date):
    """Return e.g. "World COVID-19 Cases Per 1MPop daily change Map, <date>".

    The normalisation is included only for "Per 1MPop", and the type only
    for the two daily-change options.
    """
    words = [map_name, "COVID-19", data]
    if norm == "Per 1MPop":
        words.append(norm)
    if data_type == "Daily change":
        words.append("daily change")
    if data_type == "Daily % change":
        words.append("daily % change")
    words.extend(["Map,", date])
    return " ".join(words)

In [None]:
# Map hover table used when the selected measure differs from the raw totals.
# Placeholders: {0} region name, {1} label of the selected measure, {2} its
# value, {3}-{6} Cases/Deaths/Recovered/Active totals, {7} population.
# Integer fields are comma-grouped. The <tbody> is now closed, matching
# table_tmpl_duplicate.
table_tmpl_no_duplicate = """
<table id="stats_table">
<thead>
<tr>
<th colspan="2">{0:s}</th>
</tr>
</thead>
<tbody>
<tr>
    <td>{1:s}</td>
    <td>{2:}</td>
</tr>
<tr>
    <td>Cases</td>
    <td>{3:,d}</td>
</tr>
<tr>
    <td>Deaths</td>
    <td>{4:,d}</td>
</tr>
<tr>
    <td>Recovered</td>
    <td>{5:,d}</td>
</tr>
<tr>
    <td>Active Cases</td>
    <td>{6:,d}</td>
</tr>
<tr>
    <td>Population</td>
    <td>{7:,d}</td>
</tr>
</tbody>
</table>
"""

# Map hover table used when the selected measure is the raw cumulative value
# (nothing extra to show). Placeholders: {0} region name, {1}-{4}
# Cases/Deaths/Recovered/Active totals, {5} population; integers comma-grouped.
table_tmpl_duplicate = """
<table id="stats_table">
<thead>
<tr>
<th colspan="2">{0:s}</th>
</tr>
</thead>
<tbody>
<tr>
    <td>Cases</td>
    <td>{1:,d}</td>
</tr>
<tr>
    <td>Deaths</td>
    <td>{2:,d}</td>
</tr>
<tr>
    <td>Recovered</td>
    <td>{3:,d}</td>
</tr>
<tr>
    <td>Active Cases</td>
    <td>{4:,d}</td>
</tr>
<tr>
    <td>Population</td>
    <td>{5:,d}</td>
</tr>
</tbody>
</table>
"""
# Shared HTML widget used as the map tooltip; its value is rewritten on hover.
stats_table = HTML()


In [None]:
def current_val_on_hover(data, norm, scale, data_type):
    """Label for the extra hover-table row, or None when none is needed.

    Raw cumulative values ("Values" + "Total") are already listed in the
    table, so None is returned. Otherwise the label is ``data`` plus
    "/1MPop" and/or a daily-change suffix. ``scale`` does not affect the label.
    """
    if norm == "Values" and data_type == "Total":
        return None
    suffixes = {"Daily change": " daily change", "Daily % change": " daily % change"}
    label = data
    if norm == "Per 1MPop":
        label += "/1MPop"
    if data_type in suffixes:
        label += suffixes[data_type]
    return label

In [None]:
def map_color(dataset, data, norm, scale, data_type, date):
    """Return ``{map_id: percentile_rank}`` used to colour a choropleth map.

    For each column of ``dataset``'s ``data`` table, the map ID
    (``get_ID``) and the current value (``get_value``) are fetched. Regions
    for which either lookup raises are omitted, so they keep the map's
    default colour. A value of 0 becomes NaN, which also leaves the region
    uncoloured.

    Returns
    -------
    dict
        Percentile ranks in (0, 1] (``rank(method="min", pct=True)``), with
        NaN kept for zero/missing values. If two regions share an ID, the
        later value wins.
    """
    values_by_id = {}
    for country_name in dataset.get_columns(data):
        try:
            map_id = dataset.get_ID(country_name)
            value = dataset.get_value(country_name, data, norm, scale, data_type, date)
            colour_value = math.nan if value == 0 else value
        except Exception:
            # Region unknown to the map or without data for this date: skip it.
            continue
        # Commit the ID and value together so they can never get out of step.
        values_by_id[map_id] = colour_value
    col_test = pd.DataFrame.from_dict(values_by_id, orient="index", columns=["value"])
    return col_test.rank(method="min", pct=True, na_option="keep").to_dict()["value"]


def select_date(obj):
    """Recolour the active map and refresh the tables for a newly chosen date.

    Observer of ``date_selector``'s "value"; ``obj["new"]`` is the date string.
    """
    date = obj["new"]
    norm = norm_buttons.value
    scale = scale_buttons.value
    data_type = type_buttons.value
    data = data_buttons.value
    active = tab.selected_index
    color = map_color(dict_datasets[active], data, norm, scale, data_type, date)
    shown = dict_maps[active]
    shown["map_fig"].title = map_title(shown["name"], data, norm, data_type, date)
    shown["map"].color = color
    update_tables(data, norm, scale, data_type, date)


def map_click(obj, value):
    """Load the clicked country/state into the main graph.

    Element-click callback of the maps: ``value["data"]["id"]`` is the map
    feature ID, which ``get_name`` turns into the dataset key. The graph
    always covers the full STDT..ENDT range with the current Norm / Scale /
    Type selections. Any failure, including an ID unknown to the dataset, is
    reported with a printed message instead of raising inside the widget callback.
    """
    _id = value["data"]["id"]
    datasets = dict_datasets[tab.selected_index]
    try:
        name = datasets.get_name(_id)
        update_main_graph(
            datasets,
            name,
            norm_buttons.value,
            scale_buttons.value,
            type_buttons.value,
            STDT,
            ENDT,
            axis_t,
            axis_y,
            main_graph,
        )
    except Exception:
        print("Impossible to update the main graph")


def map_hovering(obj, value):
    """Show the stats tooltip for the hovered country/state.

    Hover callback of the maps: ``value["data"]["id"]`` is the map feature ID.
    The table lists the four cumulative totals and the population at the
    selected date. If the selected measure is not the raw total, it also
    shows that measure's value (integer for absolute daily change, rounded to
    2 decimals otherwise). When the data lookup fails, zeros are shown
    alongside the population.
    """
    _id = value["data"]["id"]
    date = date_selector.value
    datasets = dict_datasets[tab.selected_index]
    try:
        name = datasets.get_name(_id)
    except Exception:
        print(str(_id) + " not in data")
        return
    # World datasets are keyed by ISO-3 code: display the country name instead.
    if tab.selected_index == 0:
        name_table = codes_to_countries[name]
    else:
        name_table = name
    try:
        norm, scale, data_type = (
            norm_buttons.value,
            scale_buttons.value,
            type_buttons.value,
        )
        data = data_buttons.value
        current_val = current_val_on_hover(data, norm, scale, data_type)
        totals = [
            int(datasets.get_value(name, c, "Values", "Linear", "Total", date))
            for c in ("Cases", "Deaths", "Recovered", "Active Cases")
        ]
        population = int(datasets.get_population(name))
        if current_val is None:
            stats_table.value = css_style + table_tmpl_duplicate.format(
                name_table, *totals, population
            )
        else:
            val = float(datasets.get_value(name, data, norm, scale, data_type, date))
            if data_type == "Daily change" and norm == "Values":
                val = int(val)
            else:
                val = round(val, 2)
            stats_table.value = css_style + table_tmpl_no_duplicate.format(
                name_table, current_val, val, *totals, population
            )
    except Exception:
        # No usable data for this region/date: show zeros plus the population.
        population = int(datasets.get_population(name))
        stats_table.value = css_style + table_tmpl_duplicate.format(
            *tuple([name_table] + [0] * 4 + [population])
        )
    dict_maps[tab.selected_index]["map"].tooltip = stats_table


In [None]:
def event(obj):
    """Refresh the active map, the main graph and the tables after a control change.

    Observer of the toggle buttons and of the map Tab's selected_index. The
    main graph's current subject is recovered from its title
    ("COVID-19 in <country>" or "COVID-19 in <state> State"), so this
    depends on the title format written by update_main_graph.
    """
    date = date_selector.value
    norm = norm_buttons.value
    scale = scale_buttons.value
    data_type = type_buttons.value
    data = data_buttons.value
    active = tab.selected_index
    color = map_color(dict_datasets[active], data, norm, scale, data_type, date)
    dict_maps[active]["map_fig"].title = map_title(
        dict_maps[active]["name"], data, norm, data_type, date
    )
    dict_maps[active]["map"].color = color

    subject = main_graph.title.split(" in ")[-1]
    if main_graph.title.split(" ")[-1] == "State":
        # Strip the trailing " State" to get the US dataset key.
        graph_datasets, graph_key = dict_datasets[1], subject[:-6]
    else:
        graph_datasets, graph_key = dict_datasets[0], countries_to_codes[subject]
    update_main_graph(
        graph_datasets,
        graph_key,
        norm,
        scale,
        data_type,
        STDT,
        ENDT,
        axis_t,
        axis_y,
        main_graph,
    )
    update_tables(data, norm, scale, data_type, date)

In [None]:
# Animation control: its value is a position in date_selector.options
# (0 .. get_len() - 1), starting at the last date; update_index_date
# forwards it to the date slider.
play = Play(
    value=datasets_World.get_len() - 1,
    min=0,
    max=datasets_World.get_len() - 1,
    step=1,
    description="Press play",
    disabled=False,
    layout=Layout(width="auto"),
)


def update_index_date(*args):
    """Move the date slider to the option at the Play widget's new position."""
    position = args[0]["new"]
    date_selector.value = date_selector.options[position]

In [None]:
# Slider choosing the day shown by the maps and tables; options come from
# the World dataset's Cases index (presumably ISO date strings — they are
# compared with ENDT), starting at the most recent date.
date_selector = SelectionSlider(
    options=datasets_World.get_index("Cases"),
    description="Date",
    disabled=False,
    value=ENDT,
    layout=Layout(width="auto", height="auto"),
    style={"description_width": "initial"},
)

In [None]:
# Range slider zooming the main graph's time axis. Options are daily numpy
# datetime64 values from the World dataset's STDT to ENDT inclusive; the full
# range is selected initially.
date_axis_selector = SelectionRangeSlider(
    options=list(
        np.arange(
            np.datetime64(
                datetime.datetime.strptime(datasets_World.STDT, dateformat).strftime(
                    dateformat
                )
            ),
            np.datetime64(
                datetime.datetime.strptime(datasets_World.ENDT, dateformat).strftime(
                    dateformat
                )
            )
            + np.timedelta64(1, "D"),
        )
    ),
    index=(0, datasets_World.get_len() - 1),
    layout=Layout(width="auto", height="auto"),
)


def update_zoom(*args):
    """Zoom the main graph to the date range picked on ``date_axis_selector``.

    Observer of the range slider's "value": ``args[0]["new"]`` is the
    (start, end) pair of selected dates, used as the time-scale bounds. The
    y scale is refitted to the NaN-aware min/max of every plotted curve
    (the mark's y is indexed as curves x dates) over the selected positions.
    """
    min_x, max_x = args[0]["new"]
    time_scale = main_graph.axes[0].scale
    time_scale.min = min_x
    time_scale.max = max_x

    first, last = date_axis_selector.index
    # The slider's index pair is inclusive at both ends. Slice to last + 1 so
    # the final selected date is included; a single-day selection then no
    # longer yields an empty slice.
    visible_y = main_graph.marks[0].y[:, first : last + 1]
    y_scale = main_graph.axes[1].scale
    y_scale.min = float(np.nanmin(visible_y))
    y_scale.max = float(np.nanmax(visible_y))

In [None]:
# Re-zoom on range changes; with continuous_update off, the value (and thus
# update_zoom) is only updated when the handle is released, not while dragging.
date_axis_selector.observe(update_zoom, "value")
date_axis_selector.continuous_update = False

In [None]:
# Play drives the date slider, the date slider recolours the active map,
# and any toggle-button change refreshes map, graph and tables via event().
play.observe(update_index_date, "value")
date_selector.observe(select_date, "value")
data_buttons.observe(event, "value")
norm_buttons.observe(event, "value")
scale_buttons.observe(event, "value")
type_buttons.observe(event, "value")

In [None]:
# World Map

sc_geo = Mercator(scale_factor=220, center=(-20, 60))
# Colour ramp over [0, 1], matching the percentile ranks returned by map_color;
# shared with the US map.
sc_c1 = ColorScale(
    colors=["#ffa600", "#ff6e00", "#ff4417", "#d31522", "#8c000e", "#560410"],
    min=0,
    max=1,
)
caxis = ColorAxis(scale=sc_c1, tick_format=",")
# Initial colouring from the current control values at the last World date.
color = map_color(
    datasets_World,
    data_buttons.value,
    norm_buttons.value,
    scale_buttons.value,
    type_buttons.value,
    datasets_World.ENDT,
)
map_tt = stats_table

# Regions absent from `color` are drawn in the default grey.
wm = Map(
    map_data=topo_load("map_data/WorldMap.json"),
    scales={"projection": sc_geo, "color": sc_c1},
    colors={"default_color": "Grey"},
    color=color,
    tooltip=map_tt,
    hovered_styles={"hovered_fill": "White"},
)

# The initial title hard-codes the default control values (Cases / Values / Total).
map_fig = Figure(
    marks=[wm],
    axes=[caxis],
    animation_duration=1000,
    title=map_title("World", "Cases", "Values", "Total", datasets_World.ENDT),
    fig_margin={"top": 50, "bottom": 50, "left": 0, "right": 0},
    layout=Layout(width="100%"),
)


In [None]:
# Hover fills the stats tooltip; click loads the region into the main graph.
wm.on_hover(map_hovering)
wm.on_element_click(map_click)

In [None]:
# US States Map

sc_geo_US = Mercator(scale_factor=1100, center=(-103, 45))
# Initial colouring from the current control values at the last US date.
color_US = map_color(
    datasets_US_States,
    data_buttons.value,
    norm_buttons.value,
    scale_buttons.value,
    type_buttons.value,
    datasets_US_States.ENDT,
)
map_tt = stats_table

# Shares the colour scale sc_c1 and axis caxis with the World map.
wm_US = Map(
    map_data=topo_load("map_data/USStatesMap.json"),
    scales={"projection": sc_geo_US, "color": sc_c1},
    colors={"default_color": "Grey"},
    color=color_US,
    tooltip=map_tt,
    hovered_styles={"hovered_fill": "White"},
)
map_fig_US = Figure(
    marks=[wm_US],
    axes=[caxis],
    animation_duration=1000,
    title=map_title("US States", "Cases", "Values", "Total", datasets_US_States.ENDT),
    fig_margin={"top": 50, "bottom": 50, "left": 0, "right": 0},
    layout=Layout(width="auto"),
)

# Keyed by the map Tab's selected_index (same keys as dict_datasets).
dict_maps = {
    0: {"map": wm, "map_fig": map_fig, "name": "World"},
    1: {"map": wm_US, "map_fig": map_fig_US, "name": "US States"},
}


In [None]:
# Same hover/click behaviour as the World map.
wm_US.on_hover(map_hovering)
wm_US.on_element_click(map_click)

In [None]:
# Tables

# CSS for the ranking tables (id="sorted_table"); sorted_table_html prepends it
# and appends the closing </body>.
css_style_2 = """
<head>
    <meta http-equiv="Content-Type" content="text/html; charset=utf-8">
    <style type="text/css" media="screen">
        #sorted_table {
          font-family: "Trebuchet MS", Arial, Helvetica, sans-serif;
          font-size: 12px;
          border-collapse: collapse;
          width: 100%;
          text-align: left;
        }
        #sorted_table td, #sorted_table th {
          border: 1px solid #ddd;
          padding: 0px 0px;
        }
        #sorted_table tbody td {
          font-size: 12px;
          font-weight: bold;
        }
        #sorted_table tr:nth-child(even) {
          background-color: #f2f2f2;
        }
        #sorted_table th {
          padding-top: 12px;
          padding-bottom: 12px;
          text-align: center;
          background-color: #003366;
          color: white;
        }
    </style>
</head>
<body>
"""


def sorted_table_html(
    dataset, data, norm, scale, data_type, date, ascending, K, dict_codes, column_name
):
    """Render the K top-ranked regions for one measure as a styled HTML table.

    ``dataset.get_ts_sort`` supplies the ranking. After ``reset_index``, its
    first column (the region key) is renamed to ``column_name`` and the
    ``date`` column to the measure label from ``name_val_on_hover``. When
    ``dict_codes`` is truthy, region keys are translated through it. Raw
    counts are cast to int and comma-grouped; other values use two decimals.
    """
    ranking = dataset.get_ts_sort(
        data, norm, scale, data_type, date, ascending, K
    ).reset_index()
    value_label = name_val_on_hover(data, norm, data_type)

    ranking = ranking.rename(
        columns={ranking.columns.values[0]: column_name, date: value_label}
    )
    if dict_codes:
        ranking[column_name] = [dict_codes[key] for key in ranking[column_name].values]

    whole_numbers = norm == "Values" and data_type != "Daily % change"
    if whole_numbers:
        ranking[value_label] = ranking[value_label].astype(int)
        cell_format = "{:,d}"
    else:
        cell_format = "{:,.02f}"

    html = ranking.to_html(
        index=False,
        header=True,
        formatters={value_label: cell_format.format},
        notebook=True,
        table_id="sorted_table",
    )
    return css_style_2 + html + "</body>"


def sorted_tables_html(data, norm, scale, data_type, date, ascending, K):
    """Return the (World, US States) ranking tables as a pair of HTML strings.

    World keys are translated to country names via ``codes_to_countries``;
    US state keys are shown as-is.
    """
    shared = (data, norm, scale, data_type, date, ascending, K)
    return tuple(
        sorted_table_html(source, *shared, codes, heading)
        for source, codes, heading in (
            (datasets_World, codes_to_countries, "Countries"),
            (datasets_US_States, None, "US States"),
        )
    )


def update_tables(data, norm, scale, data_type, date):
    """Refresh both ranking widgets with the top 40 regions, largest first."""
    descending_order, top_k = False, 40
    table_1.value, table_2.value = sorted_tables_html(
        data, norm, scale, data_type, date, descending_order, top_k
    )


# HTML widgets for the World (table_1) and US States (table_2) rankings,
# filled once with the current control values at the last date.
table_1 = HTML()
table_2 = HTML()
update_tables(
    data_buttons.value,
    norm_buttons.value,
    scale_buttons.value,
    type_buttons.value,
    ENDT,
)


In [None]:
# Infection Map tab

# 11 x 10 grid: rows 0-1 hold the controls; rows 2-10 hold the map tabs
# (columns 0-5) and the graph/table tabs (columns 6-9).
grid_1 = GridspecLayout(11, 10)
grid_1.layout.height = "99%"  # "750px"
grid_1.layout.width = "100%"
# NOTE(review): `_titles` is a private ipywidgets attribute (newer ipywidgets
# versions expose `titles` instead) — confirm against the installed version.
tab = Tab(
    _titles=dict(zip([0, 1], ["World", "US States"])), children=[map_fig, map_fig_US]
)
tab.layout.height = "99%"
# World ranking on the left half, US States ranking on the right half.
table_grid = GridspecLayout(1, 4)
table_grid.layout.width = "100%"
table_grid.layout.height = "100%"
table_1.layout.width = "auto"
table_1.layout.height = "200px"
table_2.layout.width = "auto"
table_2.layout.height = "200px"
table_grid[0, :2] = table_1
table_grid[0, 2:] = table_2
# Main graph with its zoom slider underneath.
center_right_panel = VBox([main_graph, date_axis_selector])
center_right_panel.layout.height = "100%"
tab_graph_table = Tab(
    _titles=dict(zip([0, 1], ["Graph", "Table"])),
    children=[center_right_panel, table_grid],
)
tab_graph_table.layout.height = "99%"
grid_1[0:2, 0:3] = VBox([data_buttons, type_buttons])
grid_1[0:2, 3:6] = VBox([norm_buttons, scale_buttons])
play.layout.height = "40%"
grid_1[0:2, 6:9] = date_selector
grid_1[0:2, 9] = play

map_fig.layout.height = "99%"
map_fig_US.layout.height = "99%"
main_graph.layout.height = "99%"
grid_1[2:, :6] = tab
grid_1[2:, 6:] = tab_graph_table
grid_1.align_items = "center"


In [None]:
# Switching map tab refreshes everything; then draw the initial main graph
# ("USA", raw cumulative values, linear scale, full date range).
tab.observe(event, "selected_index")
update_main_graph(
    dict_datasets[0],
    "USA",
    "Values",
    "Linear",
    "Total",
    STDT,
    ENDT,
    axis_t,
    axis_y,
    main_graph,
)

In [None]:
# Rebased Graph Tab

# Buttons and Multiple selectors

# NOTE: this rebinds the module-level `style` / `style_2` names with wider
# (100px) descriptions. Widgets created earlier keep the dicts they were built with.
style = {"description_width": "100px", "font_weight": "bold", "button_width": "auto"}
style_2 = {"description_width": "100px", "font_weight": "bold", "button_width": "100px"}

# Measure plotted on the rebased graph
rebased_graph_data_button = ToggleButtons(
    options=["Cases", "Deaths", "Recovered", "Active Cases"],
    value="Cases",
    description="Data",
    style=style,
    layout=Layout(width="auto", height="auto"),
)

# Normalisation of the plotted measure
rebased_graph_norm_button = ToggleButtons(
    options=["Values", "Per 1MPop"],
    value="Values",
    description="Norm",
    style=style_2,
    layout=Layout(width="auto", height="auto"),
)

# Cumulative or daily-change variant of the plotted measure
rebased_graph_type_button = ToggleButtons(
    options=["Total", "Daily change", "Daily % change"],
    value="Total",
    description="Type",
    style=style_2,
    layout=Layout(width="auto", height="auto"),
)

# Threshold that each curve is rebased on (0 disables rebasing)
thr_val_slider = IntSlider(
    description="Threshold", value=1000, min=0, max=5000, step=100
)

# Y scale of the rebased graph (log by default)
plot_scale_button = ToggleButtons(
    options=["Linear", "Log"],
    value="Log",
    description="Scale",
    style=style_2,
    layout=Layout(width="auto", height="auto"),
)

# Measure the threshold is compared against
rebased_graph_thr_data_button = ToggleButtons(
    options=["Cases", "Deaths", "Recovered"],
    value="Cases",
    description="Threshold Data",
    style=style_2,
    layout=Layout(width="auto", height="auto"),
)

# Normalisation of the threshold measure
rebased_graph_thr_norm_button = ToggleButtons(
    options=["Values", "Per 1MPop"],
    value="Values",
    description="Threshold Norm",
    style=style_2,
    layout=Layout(width="auto", height="auto"),
)


In [None]:
# ISO-3 codes of countries with more than 3000 cumulative cases at ENDT,
# largest first, and their display names.
top_countries_codes = (
    datasets_World.dict_data["Cases"]
    .loc[ENDT][datasets_World.dict_data["Cases"].loc[ENDT] > 3000]
    .sort_values(ascending=False)
    .index.values
)
top_countries_names = [codes_to_countries[c] for c in top_countries_codes]
# States listed in states_to_codes, ranked by cumulative cases at ENDT.
# NOTE(review): .loc raises if any of these states is missing from the data.
states = list(
    datasets_US_States.dict_data["Cases"]
    .loc[ENDT, np.array(list(states_to_codes.keys()))]
    .sort_values(ascending=False)
    .index.values
)

# Multi-selects with "None" as an extra option; the top 6 countries and the
# top 2 states are preselected.
countries_selector = SelectMultiple(
    options=["None"] + top_countries_names,
    description="Countries",
    value=top_countries_names[:6],
    style=style,
    layout=Layout(width="auto", height="100%"),
)

states_selector = SelectMultiple(
    options=["None"] + states,
    description="US States",
    value=states[:2],
    style=style,
    layout=Layout(width="auto", height="100%"),
)


In [None]:
# Graph

# x: linear scale (presumably days since the threshold was crossed — confirm
# in plot_rebased_graph); y: log scale to match the default "Log" button.
scale_reb_t = LinearScale()
axis_reb_t = Axis(scale=scale_reb_t, grid_lines="none")
scale_reb_y = LogScale()
axis_reb_y = Axis(
    scale=scale_reb_y, orientation="vertical", grid_lines="solid", tick_format=","
)

rebased_graph = Figure(
    animation_duration=1000,
    title="Rebased Graph",
    legend_location="bottom-right",
    axes=[axis_reb_t, axis_reb_y],
    fig_margin={"top": 50, "bottom": 50, "left": 55, "right": 80},
)


In [None]:
def title_rebased_graph(
    data_to_plot, data_norm, data_type, threshold, plot_scale, thr_data, thr_norm
):
    """Build the title string shown above the rebased graph.

    Args:
        data_to_plot: plotted series name (e.g. "Cases").
        data_norm: "Values" for raw numbers; any other value appends "/1MPop".
        data_type: "Total", "Daily change" or "Daily % change"; the two daily
            variants append a lowercase suffix.
        threshold: threshold value; when <= 0 the title refers to STDT instead.
        plot_scale: "Log" prefixes the title with "Log"; other scales add nothing.
        thr_data: series the threshold applies to.
        thr_norm: "Values" or a per-population option for the threshold series.

    Returns:
        The title, e.g. "Log Cases since number of Cases = 1000" or, on a linear
        scale, "Cases since number of Cases = 1000" (no leading space).
    """
    parts = []
    if plot_scale == "Log":
        parts.append(plot_scale)

    quantity = data_to_plot
    if data_norm != "Values":
        quantity += "/1MPop"
    if data_type == "Daily change":
        quantity += " daily change"
    elif data_type == "Daily % change":
        quantity += " daily % change"
    parts.append(quantity)

    if threshold > 0:
        thr_label = thr_data + ("/1MPop" if thr_norm != "Values" else "")
        parts.append("since number of " + thr_label + " = " + str(threshold))
    else:
        parts.append("since " + STDT)

    # Joining the parts with single spaces avoids the leading space the title
    # would otherwise get when no "Log" prefix is present.
    return " ".join(parts)


def label_index_pos(data):
    """Return the index at which a curve's text label should be placed.

    If ``data`` contains no NaN, the last index (``len(data) - 1``) is returned.
    Otherwise the result is the largest index ``i >= 1`` such that both
    ``data[i]`` and ``data[i - 1]`` are non-NaN, i.e. the end of the last
    drawable line segment. If no such pair exists, 0 is returned.
    """
    nan_mask = np.isnan(data)
    last = len(data) - 1
    if not nan_mask.any():
        return last
    for pos in range(last, 0, -1):
        if not (nan_mask[pos] or nan_mask[pos - 1]):
            return pos
    return 0


In [None]:
def _rebased_series(
    dataset,
    key,
    data_to_plot,
    data_norm,
    ts_scale,
    data_type,
    threshold,
    thr_data,
    thr_norm,
    plot_scale,
):
    """Return one region's rebased-graph values without mutating dataset data.

    Args:
        dataset: an entry of ``dict_datasets`` (presumably a DataProcessor)
            exposing ``get_ts_plot`` and ``get_ts``.
        key: region identifier understood by ``dataset`` (ISO3 code for
            countries, an entry of ``states`` for US states).
        ts_scale: scale argument forwarded to ``get_ts_plot``.
        threshold: when > 0, leading points are dropped so the curve starts at
            the first position where the "Total" ``thr_data`` series exceeds
            ``threshold`` (found with ``bisect``, which assumes that series is
            non-decreasing).
        plot_scale: when "Log", non-positive values are replaced by NaN.
        data_to_plot, data_norm, data_type, thr_data, thr_norm: forwarded to
            the dataset getters unchanged.
    """
    values = dataset.get_ts_plot(
        key,
        data_to_plot,
        data_norm,
        ts_scale,
        data_type,
        dict_datasets[0].STDT,
        dict_datasets[0].ENDT,
    )
    if threshold > 0:
        thr_series = dataset.get_ts(key, thr_data, thr_norm, "Linear", "Total")
        # Bisect a plain array and slice positionally: positional
        # Series.__getitem__ on a non-integer index is deprecated in pandas.
        start_pos = bisect.bisect(np.asarray(thr_series), threshold)
        values = values.iloc[start_pos:]
    if plot_scale == "Log":
        # Non-mutating mask (log axes cannot show values <= 0); in-place
        # assignment could write NaN back into data owned by the dataset.
        values = values.where(values > 0)
    return values


def plot_rebased_graph(
    data_to_plot,
    data_norm,
    data_type,
    threshold,
    countries,
    states,
    plot_scale,
    thr_data,
    thr_norm,
):
    """Redraw the module-level ``rebased_graph`` for countries and US states.

    Each region becomes one line.
    - threshold > 0: every curve starts where its ``thr_data``/``thr_norm``
      series first exceeds ``threshold``, and x is the point index from that
      start.
    - Otherwise: curves are drawn against calendar dates from STDT to ENDT, and
      labels sit at ENDT.

    States are skipped when either ``data_to_plot`` or ``thr_data`` is
    "Recovered" (no recovered data for states). Curves that end up empty are
    dropped. Each remaining curve gets a text label at the end of its last
    drawable segment (see ``label_index_pos``).

    Side effects: replaces the scales of ``axis_reb_t``/``axis_reb_y`` and the
    axes, marks and title of ``rebased_graph``. Returns None.
    """
    scale_reb_y = LogScale() if plot_scale == "Log" else LinearScale()

    # (dataset, key, label, scale passed to get_ts_plot) for each wanted curve.
    # NOTE(review): countries forward plot_scale to get_ts_plot while states
    # always request "Linear"; kept as-is -- confirm this is intentional.
    requests = [
        (dict_datasets[0], countries_to_codes[name], name, plot_scale)
        for name in countries
    ]
    if data_to_plot != "Recovered" and thr_data != "Recovered":
        # No recovered data for states.
        requests += [(dict_datasets[1], st, st, "Linear") for st in states]

    Ys = []
    label_names = []
    for dataset, key, name, ts_scale in requests:
        values = _rebased_series(
            dataset,
            key,
            data_to_plot,
            data_norm,
            ts_scale,
            data_type,
            threshold,
            thr_data,
            thr_norm,
            plot_scale,
        )
        # A curve can be empty when the region never reaches the threshold.
        if values.shape[0] > 0:
            Ys.append(list(values.values))
            label_names.append(name)
    Xs = [list(np.arange(0, len(ys))) for ys in Ys]

    # Place labels, then pad every curve with NaN to a common length so Ys/Xs
    # become rectangular 2-D arrays for the Lines mark.
    max_len = max((len(ys) for ys in Ys), default=0)
    label_x = []
    label_y = []
    for i, ys in enumerate(Ys):
        index_lab = label_index_pos(ys)
        label_y.append(ys[index_lab])
        label_x.append(Xs[i][index_lab])
        Ys[i] = ys + [math.nan] * (max_len - len(ys))
        Xs[i] = Xs[i] + [math.nan] * (max_len - len(Xs[i]))

    axis_reb_y.scale = scale_reb_y
    if threshold > 0:
        scale_reb_t = LinearScale()
    else:  # if thr == 0 Plot the values vs calendar dates
        start_dt = datetime.datetime.strptime(STDT, dateformat)
        end_dt = datetime.datetime.strptime(ENDT, dateformat)
        scale_reb_t = DateScale(dateformat=dateformat, min=start_dt, max=end_dt)
        # Day-resolution datetime64 built from the parsed dates, so this works
        # whatever ``dateformat`` is (numpy only parses ISO-like strings).
        first_day = np.datetime64(start_dt.date())
        last_day = np.datetime64(end_dt.date())
        Xs = list(np.arange(first_day, last_day + np.timedelta64(1, "D")))
        # One x per drawn label: skipped states and empty curves mean
        # len(countries) + len(states) can exceed the number of labels.
        label_x = [last_day] * len(label_names)
    axis_reb_t.scale = scale_reb_t
    rebased_graph.axes = [axis_reb_t, axis_reb_y]

    rebased_curves = Lines(scales={"x": scale_reb_t, "y": scale_reb_y})
    rebased_curves.x = np.array(Xs)
    rebased_curves.y = np.array(Ys)
    rebased_labels = Label(
        apply_clip=False, scales={"x": scale_reb_t, "y": scale_reb_y}, default_size=15
    )
    rebased_labels.text = label_names
    rebased_labels.x = np.array(label_x)
    rebased_labels.y = np.array(label_y)
    rebased_graph.marks = [rebased_curves, rebased_labels]
    rebased_graph.title = title_rebased_graph(
        data_to_plot, data_norm, data_type, threshold, plot_scale, thr_data, thr_norm
    )


# Initial draw of the rebased graph with the same regions the selectors start with.
plot_rebased_graph(
    data_to_plot="Cases",
    data_norm="Values",
    data_type="Total",
    threshold=1000,
    countries=top_countries_names[:6],
    states=states[:2],
    plot_scale="Log",
    thr_data="Cases",
    thr_norm="Values",
)


In [None]:
# Every control whose "value" change should redraw the rebased graph:
# the seven setting widgets, followed by the two region selectors.
value_buttons = [
    rebased_graph_data_button,
    rebased_graph_norm_button,
    rebased_graph_type_button,
    thr_val_slider,
    plot_scale_button,
    rebased_graph_thr_data_button,
    rebased_graph_thr_norm_button,
] + [countries_selector, states_selector]


def update_rebased_graph(*args):
    """Observer callback: redraw the rebased graph from the current widget state.

    ``*args`` receives the widget change notification and is ignored. Choosing
    "None" in a region list clears that list. If both lists end up empty,
    "United States" is plotted so the graph is never blank.
    """
    chosen_countries = list(countries_selector.value)
    chosen_states = list(states_selector.value)
    if "None" in chosen_countries:
        chosen_countries = []
    if "None" in chosen_states:
        chosen_states = []
    if not (chosen_countries or chosen_states):
        chosen_countries = ["United States"]

    # Widgets are read by name rather than by position in value_buttons.
    plot_rebased_graph(
        rebased_graph_data_button.value,
        rebased_graph_norm_button.value,
        rebased_graph_type_button.value,
        thr_val_slider.value,
        chosen_countries,
        chosen_states,
        plot_scale_button.value,
        rebased_graph_thr_data_button.value,
        rebased_graph_thr_norm_button.value,
    )


# The slider only reports its value once released, not while being dragged;
# every control in value_buttons then triggers a redraw on "value" changes.
thr_val_slider.continuous_update = False
for v in value_buttons:
    v.observe(handler=update_rebased_graph, names="value")


In [None]:
# Rebased-graph tab: a 10x10 grid with the figure in the left six columns and
# the controls in the right four columns.
# NOTE(review): statement order (layout tweaks vs. grid placement) is kept as-is.
grid_2 = GridspecLayout(10, 10)
grid_2.layout.height = "100%"
grid_2.layout.width = "auto"
rebased_graph.layout.width = "100%"
rebased_graph.layout.height = "99%"
grid_2[:, 0:6] = rebased_graph
# Data / norm / type / scale toggles stacked vertically.
reb_top_right1 = VBox(
    [
        rebased_graph_data_button,
        rebased_graph_norm_button,
        rebased_graph_type_button,
        plot_scale_button,
    ]
)
# Threshold data / threshold norm toggles stacked vertically.
reb_top_right2 = VBox([rebased_graph_thr_data_button, rebased_graph_thr_norm_button])
thr_val_slider.layout.width = "auto"
thr_val_slider.style.description_width = "100px"
# Top margin separates the slider from the toggles above it.
thr_val_slider.layout.margin = "15px 0px 0px 0px"
countries_selector.layout.width = "auto"
countries_selector.layout.height = "90%"
states_selector.layout.height = "90%"
# Right-hand side: rows 0-1 main toggles, row 2 threshold slider,
# rows 3-4 threshold toggles, rows 5-9 the two region selectors side by side.
grid_2[0:2, 6:] = reb_top_right1
grid_2[2, 6:] = thr_val_slider
grid_2[3:5, 6:] = reb_top_right2
grid_2[5:, 6:8] = countries_selector
grid_2[5:, 8:] = states_selector
grid_2.align_items = "center"

In [None]:
# DNA Graph tab3


def _dna_selector(options, value, description):
    """Auto-sized SelectMultiple for the DNA tab (new Layout per widget)."""
    return SelectMultiple(
        options=options,
        description=description,
        value=value,
        style=style,
        layout=Layout(width="auto", height="auto"),
    )


def _dna_toggle(options, description):
    """Auto-sized ToggleButtons for the DNA tab, initially set to its first option."""
    return ToggleButtons(
        options=options,
        value=options[0],
        description=description,
        style=style,
        layout=Layout(width="auto", height="auto"),
    )


# Region pickers; "None" clears the corresponding list.
dna_countries_selector = _dna_selector(
    ["None"] + top_countries_names, top_countries_names[:6], "Countries"
)
dna_states_selector = _dna_selector(["None"] + states, states[:2], "US States")

# Series, normalisation and transformation shown in the heat map.
dna_data_button = _dna_toggle(["Cases", "Deaths", "Recovered", "Active Cases"], "Data")
dna_norm_button = _dna_toggle(["Values", "Per 1MPop"], "Norm")
dna_type_button = _dna_toggle(["Total", "Daily change", "Daily % change"], "Type")


In [None]:
# GridHeatMap, DNA graph

# Initial heat-map data: one column per selected country / US state, read from
# the DNA tab's own selectors and toggles.
data_countries = (
    dict_datasets[0]
    .get_ts(
        [countries_to_codes[c] for c in dna_countries_selector.value],
        dna_data_button.value,
        dna_norm_button.value,
        "Linear",
        dna_type_button.value,
    )
    .rename(columns=codes_to_countries)
)
# Use the DNA tab's own state selector so the initial rows match its selection.
data_states = dict_datasets[1].get_ts(
    dna_states_selector.value,
    dna_data_button.value,
    dna_norm_button.value,
    "Linear",
    dna_type_button.value,
)
df_DNA = pd.merge(
    data_countries, data_states, left_index=True, right_index=True, how="outer"
)

# Heat map rows are regions (DataFrame columns) and columns are dates (index),
# hence the transposed value matrix.
column_dna = pd.to_datetime(df_DNA.index.values)
row_dna = list(df_DNA.columns.values)
dna_x_scale = DateScale()
dna_y_scale = OrdinalScale(padding_y=0)
dna_color_scale = ColorScale(
    colors=["white", "#ffa600", "#ff6e00", "#ff4417", "#d31522", "#8c000e", "#560410"]
)
dna_heat_map = GridHeatMap(
    color=df_DNA.values.T,
    column=column_dna,
    row=row_dna,
    column_align="start",
    scales={"column": dna_x_scale, "row": dna_y_scale, "color": dna_color_scale},
    stroke="white",
)
dna_x_ax = Axis(scale=dna_x_scale)
dna_y_ax = Axis(scale=dna_y_scale, orientation="vertical", side="right")
dna_color_ax = ColorAxis(scale=dna_color_scale, tick_format=",")
dna_axes = [dna_x_ax, dna_y_ax, dna_color_ax]

dna_layout = Layout(width="100%", height="99%")
dna_fig_margin = {"top": 60, "bottom": 60, "left": 0, "right": 80}
dna_figure = Figure(
    marks=[dna_heat_map],
    axes=dna_axes,
    fig_margin=dna_fig_margin,
    layout=dna_layout,
    min_aspect_ratio=0.0,
    # Title follows the selected data series, as update_dna does.
    title=dna_data_button.value,
    padding_y=0,
    animation_duration=1000,
    null_color="#808080",
)


In [None]:
def _dna_country_frame(countries):
    """Fetch country series for the DNA graph, with columns renamed from ISO3 codes to names."""
    return (
        dict_datasets[0]
        .get_ts(
            [countries_to_codes[c] for c in countries],
            dna_data_button.value,
            dna_norm_button.value,
            "Linear",
            dna_type_button.value,
        )
        .rename(columns=codes_to_countries)
    )


def _dna_state_frame(states):
    """Fetch US-state series for the DNA graph using the current DNA toggles."""
    return dict_datasets[1].get_ts(
        states,
        dna_data_button.value,
        dna_norm_button.value,
        "Linear",
        dna_type_button.value,
    )


def update_dna(*args):
    """Observer callback: rebuild the DNA heat map from the DNA tab's widgets.

    ``*args`` receives the widget change notification and is ignored. Choosing
    "None" clears the corresponding region list. With fewer than two regions
    selected, the map falls back to United States and China (a single selected
    state is dropped). Updates the heat map's color/row/column data and the
    figure title.
    """
    if "None" in dna_countries_selector.value:
        countries = []
    else:
        countries = dna_countries_selector.value
    if "None" in dna_states_selector.value:
        states = []
    else:
        states = dna_states_selector.value

    if len(countries) + len(states) < 2:
        df_DNA = _dna_country_frame(["United States", "China"])
    elif len(states) == 0:
        df_DNA = _dna_country_frame(countries)
    elif len(countries) == 0:
        df_DNA = _dna_state_frame(states)
    else:
        df_DNA = pd.merge(
            _dna_country_frame(countries),
            _dna_state_frame(states),
            left_index=True,
            right_index=True,
            how="outer",
        )

    title = dna_data_button.value
    # The norm toggle offers "Values" / "Per 1MPop"; anything but raw values
    # is a per-population normalisation.
    if dna_norm_button.value != "Values":
        title += "/1M Pop"
    if dna_type_button.value != "Total":
        title += " " + dna_type_button.value

    # Send color, rows and columns together so the frontend never sees a
    # matrix whose shape disagrees with its row/column labels. The date
    # columns are refreshed because the index can differ between selections
    # (e.g. states only).
    with dna_heat_map.hold_sync():
        dna_heat_map.color = df_DNA.values.T
        dna_heat_map.row = list(df_DNA.columns.values)
        dna_heat_map.column = pd.to_datetime(df_DNA.index.values)
    dna_figure.title = title


# Any change to the DNA tab's selectors or toggles rebuilds the heat map.
for _dna_widget in (
    dna_countries_selector,
    dna_states_selector,
    dna_data_button,
    dna_norm_button,
    dna_type_button,
):
    _dna_widget.observe(update_dna, "value")


In [None]:
# DNA tab: a 10x10 grid with the heat-map figure in the left eight columns and
# the controls in the right two columns.
# NOTE(review): statement order (layout tweaks vs. grid placement) is kept as-is.
grid_3 = GridspecLayout(10, 10)
grid_3.layout.height = "100%"
grid_3.layout.width = "100%"
dna_countries_selector.layout.height = "90%"
dna_states_selector.layout.height = "90%"
grid_3[:, :8] = dna_figure
# Rows 0-1: data / norm / type toggles stacked vertically.
grid_3[:2, 8:] = VBox([dna_data_button, dna_norm_button, dna_type_button])
# Rows 2-9: country selector (column 8) beside the state selector (column 9).
grid_3[2:, 8:9] = dna_countries_selector
grid_3[2:, 9:] = dna_states_selector
grid_3.align_items = "center"


In [None]:
# Dashboard

tab_contents = ["Infection Maps", "Rebased Graph", "DNA Graph"]
children = [grid_1, grid_2, grid_3]
outer_tab = Tab()
# Children must be set before titles: in ipywidgets 8 the titles tuple is
# sized from the children.
outer_tab.children = children
# Titles go through the public set_title API, which exists in ipywidgets 7 and 8.
# The private ``_titles`` trait is gone in ipywidgets 8, where passing it
# silently leaves the tabs untitled.
for tab_index, tab_title in enumerate(tab_contents):
    outer_tab.set_title(tab_index, tab_title)
outer_tab.layout.height = "800px"


In [None]:
# Final cell output: the tabbed dashboard with a data-source note below it.
VBox(
    children=[
        outer_tab,
        HTML(value="Data Source : JHU https://github.com/CSSEGISandData/COVID-19"),
    ]
)
