In [48]:
from datetime import date, datetime, timedelta
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt 
import matplotlib.dates as dates
import ipywidgets as widgets

%matplotlib inline

In [115]:
# Module-level cache used by get_data(): the most recently built frame and
# the time it was fetched.
df = None
# Back-date the timestamp by one hour so the first get_data() call sees the
# cache as stale and performs a download instead of returning None.
last = datetime.now() - timedelta(hours=1)


def get_data():
    """
    Get the latest data. Just return the cached copy
    if less than one hour has elapsed.

    One CSV per calendar day, from 2020-01-22 through today, is read from
    the CSSEGISandData/COVID-19 daily-reports directory on GitHub. Each
    day's rows get a 'Date' column holding that day's ``datetime.date``,
    and the per-day frames are concatenated without resetting the index
    (row labels therefore repeat across days).

    Post-processing applied to the combined frame:
      * the 'Last Update' column is dropped when present,
      * 'Mainland China' is renamed to 'China' in 'Country/Region',
      * missing 'Province/State' becomes '' and missing
        'Confirmed' / 'Deaths' / 'Recovered' become 0.

    Returns:
        pandas.DataFrame: the combined, cleaned daily reports. The same
        object is stored in the module-level ``df`` cache.

    Raises:
        RuntimeError: if not a single daily report could be loaded. The
        cache is left untouched in that case so a later call retries.
    """
    # Function-scope imports keep this cell self-contained.
    import warnings
    from urllib.error import HTTPError

    global last, df
    now = datetime.now()
    # Only trust the cache when it actually holds a frame.
    if df is not None and now - last < timedelta(hours=1):
        return df

    base_url = ('https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/'
                'csse_covid_19_data/csse_covid_19_daily_reports/')
    day = date(2020, 1, 22)
    end = date.today()

    # Collect per-day frames and concatenate once: DataFrame.append was
    # removed in pandas 2.0 and re-copies the whole frame on every iteration.
    frames = []
    while day <= end:
        url = f'{base_url}{day.month:02d}-{day.day:02d}-{day.year}.csv'
        try:
            df_day = pd.read_csv(url)
        except HTTPError as exc:
            # A 404 just means no report exists for that day (e.g. today's
            # file is not published yet); anything else is worth surfacing.
            if exc.code != 404:
                warnings.warn(f'Skipping {day}: HTTP {exc.code} for {url}')
        except Exception as exc:
            # Stay best-effort (one bad day must not abort the load), but
            # report the problem instead of hiding it.
            warnings.warn(f'Skipping {day}: {exc!r}')
        else:
            df_day['Date'] = day
            frames.append(df_day)
        day += timedelta(days=1)

    if not frames:
        raise RuntimeError(
            'No daily reports could be downloaded; check network access '
            'and the report URL.'
        )

    data = pd.concat(frames)
    # errors='ignore' tolerates reports that lack this column.
    data = data.drop(columns=['Last Update'], errors='ignore')
    data = data.replace({'Country/Region': {'Mainland China': 'China'}})
    data = data.fillna({
        'Province/State': '',
        'Confirmed': 0,
        'Deaths': 0,
        'Recovered': 0,
    })

    # Publish to the cache only after a successful build, so a failed
    # refresh never leaves a half-built frame cached for an hour.
    df, last = data, now
    return df

In [116]:
# Load the combined daily reports (get_data() returns its cached copy when
# less than one hour has elapsed since the last fetch).
df = get_data()
# Show the last few rows as a quick sanity check of the load.
df.tail(5)

Unnamed: 0,Province/State,Country/Region,Confirmed,Deaths,Recovered,Date,Latitude,Longitude
211,Mississippi,US,0.0,0.0,0.0,2020-03-11,32.7416,-89.6787
212,North Dakota,US,0.0,0.0,0.0,2020-03-11,47.5289,-99.784
213,West Virginia,US,0.0,0.0,0.0,2020-03-11,38.4912,-80.9545
214,Wyoming,US,0.0,0.0,0.0,2020-03-11,42.756,-107.3025
215,,occupied Palestinian territory,0.0,0.0,0.0,2020-03-11,31.9522,35.2332


In [117]:
def aggregate_for_locations(df, locations, do_diffs = True, fields = None):
    """
    Sum the requested fields per date over the selected locations.

    Parameters:
        df: frame with 'Country/Region' and 'Date' columns plus every
            column named in ``fields``.
        locations: a single 'Country/Region' value, or a list-like of them.
        do_diffs: when True, each field is replaced by its day-over-day
            difference, with the first row set to 0; when False, the
            per-date sums are returned unchanged.
        fields: columns to aggregate; defaults to
            ['Confirmed', 'Deaths', 'Recovered'].

    Returns:
        DataFrame indexed by 'Date' with one column per field.
    """
    selected_fields = ['Confirmed', 'Deaths', 'Recovered'] if fields is None else fields
    # A bare string is treated as a one-element selection.
    wanted = [locations] if isinstance(locations, str) else locations

    matches = df['Country/Region'].isin(wanted)
    totals = df.loc[matches, ['Date', *selected_fields]].groupby('Date').sum()
    if not do_diffs:
        return totals

    for name in selected_fields:
        # diff() leaves NaN in the first row; report it as 0.
        totals[name] = totals[name].diff().fillna(0)
    return totals


In [118]:
def plot_time_series(df, title):
    """
    Plot every column of a date-indexed frame as a marked line.

    Parameters:
        df: frame whose index holds the dates and whose columns are the
            series to draw; the column names become the legend entries.
        title: text shown as the axes title.

    Returns:
        The ``matplotlib.pyplot`` module (kept as-is for existing callers).
    """
    fig, ax = plt.subplots()
    # Axes.plot_date is deprecated (Matplotlib 3.9); it was a thin wrapper
    # that switched the x axis to date mode and then called plot().
    ax.xaxis_date()
    ax.plot(df.index, df, 'v-')

    # Minor ticks: every Tuesday, labelled with day-of-month and weekday.
    ax.xaxis.set_minor_locator(dates.WeekdayLocator(byweekday=dates.TU, interval=1))
    ax.xaxis.set_minor_formatter(dates.DateFormatter('%d\n%a'))
    ax.xaxis.grid(True, which="minor")
    ax.yaxis.grid()

    # Major ticks: one per month; the leading newlines push the label
    # below the two-line minor tick labels.
    ax.xaxis.set_major_locator(dates.MonthLocator())
    ax.xaxis.set_major_formatter(dates.DateFormatter('\n\n\n%b\n%Y'))

    ax.legend(df.columns, loc='upper left', shadow=True)
    ax.set_title(title)
    plt.tight_layout()
    return plt

In [119]:
# Alphabetical list of the distinct countries/regions, used as the widget options.
locations = sorted(df['Country/Region'].unique())

In [120]:
# Interactive chart of cumulative totals for the selected locations
# (do_diffs=False keeps the running sums).
widgets.interact(
    lambda Location: plot_time_series(
        aggregate_for_locations(df, Location, do_diffs=False),
        title='Cumulative',
    ),
    Location=widgets.SelectMultiple(
        options=locations,
        value=['US'],
        description='Locations',
    ),
)

interactive(children=(SelectMultiple(description='Locations', index=(134,), options=(' Azerbaijan', 'Afghanist…

<function __main__.<lambda>(Location)>

In [121]:
# Interactive chart of day-over-day changes for the selected locations
# (aggregate_for_locations differences the series by default).
widgets.interact(
    lambda Location: plot_time_series(
        aggregate_for_locations(df, Location),
        title='New',
    ),
    Location=widgets.SelectMultiple(
        options=locations,
        value=['US'],
        description='Locations',
    ),
)

interactive(children=(SelectMultiple(description='Locations', index=(134,), options=(' Azerbaijan', 'Afghanist…

<function __main__.<lambda>(Location)>