In [1]:
import pandas as pd
import numpy as np

from bokeh.plotting import figure
from bokeh.io import show, output_notebook
from bokeh.models import ColumnDataSource, Panel, Tabs, LinearColorMapper, ColorBar, DatetimeTickFormatter
from bokeh.palettes import all_palettes

from covid_dataset import load_covid_df

In [2]:
# Route Bokeh output to the notebook so later show() calls render inline.
output_notebook()

In [3]:
# Load the COVID-19 dataset through the project helper; the recorded output
# below shows it reading the Our World in Data CSV.
df_covid = load_covid_df()


Loading COVID-19 Dataset from  https://covid.ourworldindata.org/data/owid-covid-data.csv ... Done


In [4]:
def fill_gaps(df, col_name):
    """Fill missing values of ``col_name`` separately for every location.

    For each distinct value in ``df["location"]`` the column is linearly
    interpolated, and anything still missing afterwards is set to 0.

    Args:
        df: DataFrame with a ``"location"`` column; it is modified in place.
        col_name: Name of the column whose gaps should be filled.

    Returns:
        The same DataFrame object that was passed in.
    """
    for place in df["location"].unique():
        # Rows belonging to the current location.
        rows_for_place = df["location"] == place
        interpolated = df.loc[rows_for_place, col_name].interpolate(method="linear")
        df.loc[rows_for_place, col_name] = interpolated.fillna(0)
    return df

def generate_line_chart(values_dict, date_dict=None):
    """Build a Bokeh line chart with one line per country.

    Args:
        values_dict: Mapping of country name -> sequence of y values
            (plotted on the "Deaths per Million" axis).
        date_dict: Optional mapping of country name -> sequence of x values
            (dates) aligned with ``values_dict[country]``. When omitted, each
            series is plotted against its positional index.

    Returns:
        The configured Bokeh figure.
    """
    has_dates = date_dict is not None

    p = figure(
        title="Increase of COVID Deaths Over Time: Line Chart",
        height=400,
        width=800,
    )
    # The x axis carries time (or the sample index), not countries.
    p.xaxis.axis_label = "Date" if has_dates else "Observation index"
    p.xaxis.major_label_orientation = np.pi / 2
    p.yaxis.axis_label = "Deaths per Million"

    # Use the predefined Turbo palette of the matching size when one exists;
    # otherwise sample evenly from the 256-colour palette so that any number
    # of series can be drawn without a KeyError.
    turbo_palettes = all_palettes["Turbo"]
    n_series = len(values_dict)
    if n_series in turbo_palettes:
        line_colors = list(turbo_palettes[n_series])
    else:
        full_palette = turbo_palettes[256]
        positions = np.linspace(0, len(full_palette) - 1, max(n_series, 1))
        line_colors = [full_palette[int(round(pos))] for pos in positions][:n_series]

    for i, (country, values) in enumerate(values_dict.items()):
        p.line(
            x=date_dict[country] if has_dates else list(range(len(values))),
            y=values,
            color=line_colors[i],
            legend_label=country,
        )

    if values_dict:
        # A legend only exists once at least one labelled glyph was added.
        p.legend.location = "top_left"

    if has_dates:
        # Only format ticks as dates when x really holds dates; applying this
        # to plain indices would render them as epoch timestamps.
        p.xaxis.formatter = DatetimeTickFormatter(
            hours=["%d %B %Y"],
            days=["%d %B %Y"],
            months=["%d %B %Y"],
            years=["%d %B %Y"],
        )
    return p


def generate_bar_chart(values_dict, date_dict=None):
    """Build a Bokeh stacked-segment bar chart with one bar per country.

    Each country's bar is drawn as a stack of quads: segment ``k`` spans from
    the previous value (0 for the first segment) up to ``values[k]`` and is
    coloured by its date (or by its positional index when no dates are given).

    Args:
        values_dict: Mapping of country name -> sequence of values
            (plotted on the "Deaths per Million" axis).
        date_dict: Optional mapping of country name -> sequence of numeric
            date values aligned with ``values_dict[country]``; used only for
            colouring. When omitted, the positional index is used instead.

    Returns:
        The configured Bokeh figure, including a colour bar on the right.
    """
    use_index = date_dict is None

    p = figure(
        title="Increase of COVID Deaths Over Time: Bar Chart",
        height=400,
        width=800,
    )
    p.xaxis.axis_label = "Countries"
    p.yaxis.axis_label = "Deaths per Million"

    # Range of the colour scale: sample indices, or the overall date span.
    if use_index:
        color_low = 0
        color_high = max([len(values) for values in values_dict.values()])
    else:
        color_low = min([min(dates) for dates in date_dict.values()])
        color_high = max([max(dates) for dates in date_dict.values()])

    cmap = LinearColorMapper(
        palette="Turbo256",
        low=color_low,
        high=color_high,
    )

    for i, (country, values) in enumerate(values_dict.items()):
        tops = list(values)
        n_segments = len(tops)
        # Each segment starts where the previous one ended; the slice keeps
        # the column the same length as the others for an empty series.
        bottoms = ([0] + tops[:-1])[:n_segments]
        cds = ColumnDataSource(
            data=dict(
                left=[i - 0.4] * n_segments,
                right=[i + 0.4] * n_segments,
                top=tops,
                bottom=bottoms,
                index=list(range(n_segments)) if use_index else date_dict[country],
            )
        )
        p.quad(
            source=cds,
            left="left",
            right="right",
            top="top",
            bottom="bottom",
            fill_color={"field": "index", "transform": cmap},
            line_color=None,
            line_width=0,
        )

    # Replace the numeric bar positions on the x axis with country names.
    xaxis_labels = {i: country for i, country in enumerate(values_dict.keys())}
    p.xaxis.ticker = list(xaxis_labels.keys())
    p.xaxis.major_label_overrides = xaxis_labels
    p.xaxis.major_label_orientation = np.pi / 4

    cb = ColorBar(color_mapper=cmap, label_standoff=12)
    p.add_layout(cb, "right")
    return p

# Countries compared in both charts.
countries = [
    "Ghana",
    "Bangladesh",
    "India",
    "Vietnam",
    "Germany",
    "France",
    "Italy",
    "Brazil",
    "United States",
]

# Restrict the dataset to the selected countries, parse the dates and fill
# gaps in the per-million death totals.
df = df_covid.loc[df_covid["location"].isin(countries)].copy()
df["date"] = pd.to_datetime(df["date"])
df = fill_gaps(df, "total_deaths_per_million")

# Per-country series: values, dates as int64 (for the bar chart colour
# mapper) and dates as datetime objects (for the line chart x axis).
deaths_per_million_dict = {}
date_dict = {}
date_dict2 = {}
for country in countries:
    is_country = df["location"] == country
    deaths_per_million_dict[country] = df.loc[
        is_country, "total_deaths_per_million"
    ].tolist()
    date_dict[country] = df.loc[is_country, "date"].astype(np.int64).to_list()
    date_dict2[country] = df.loc[is_country, "date"].to_list()

p_line_chart = generate_line_chart(deaths_per_million_dict, date_dict2)
p_bar_chart = generate_bar_chart(deaths_per_million_dict, date_dict)
show(p_line_chart)
show(p_bar_chart)