In [303]:
from datetime import date, datetime, timedelta
import numpy as np
import pandas as pd
from pandas.plotting import register_matplotlib_converters
import matplotlib.pyplot as plt 
import matplotlib.dates as dates
import ipywidgets as widgets
from scipy.optimize import curve_fit


register_matplotlib_converters()
%matplotlib inline

In [304]:
df = None
last = datetime.now() - timedelta(hours=1)


def get_data():
    """
    Get the latest data. Just return the cached copy
    if less than one hour has elapsed.
    """
    global last, df
    now = datetime.now()
    if now - last < timedelta(hours=1):
        return df

    last = now    
    start = date(2020, 1, 22)
    end = date.today()

    df = pd.DataFrame()
    while start <= end:
        try:
            df_day = pd.read_csv(f'https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/csse_covid_19_data/csse_covid_19_daily_reports/{start.month:02d}-{start.day:02d}-{start.year}.csv')
            df_day['Date'] = start
            df = df.append(df_day)
        except Exception:
            pass
        start += timedelta(days=1)
    
    del df['Last Update']
    df = df.replace({'Country/Region': {'Mainland China': 'China'}})
    df.fillna({'Province/State': ''}, inplace=True)
    for col in ['Confirmed', 'Deaths', 'Recovered']:
        df.fillna({col: 0}, inplace=True)
    return df

In [305]:
df = get_data()
df.tail(5)

Unnamed: 0,Confirmed,Country/Region,Date,Deaths,Latitude,Longitude,Province/State,Recovered
304,0.0,Jersey,2020-03-22,0.0,49.19,-2.11,,0.0
305,0.0,Puerto Rico,2020-03-22,1.0,18.2,-66.5,,0.0
306,0.0,Republic of the Congo,2020-03-22,0.0,-1.44,15.556,,0.0
307,0.0,The Bahamas,2020-03-22,0.0,24.25,-76.0,,0.0
308,0.0,The Gambia,2020-03-22,0.0,13.4667,-16.6,,0.0


In [306]:
def aggregate_for_locations(df, locations, sub_locs=None, do_diffs=True, fields=None):
    if fields is None:
        fields = ['Confirmed', 'Deaths', 'Recovered']
    cols = ['Date']
    cols.extend(fields)
    
    if isinstance(locations, str):
        locations = [locations]
        
    if sub_locs:
        in_loc = df[(df['Country/Region'].isin(locations)) & (df['Province/State'].isin(sub_locs))]
    else:
        in_loc = df[df['Country/Region'].isin(locations)]
        
    result = in_loc[cols].groupby('Date').sum()
    
    if do_diffs:
        for f in fields:
            result[f] = result[f].diff()
            result.fillna({f: 0}, inplace=True)
    return result


In [307]:
def plot_time_series(df, title):
    fig, ax = plt.subplots()
    ax.plot_date(df.index, df, 'v-')
    ax.xaxis.set_minor_locator(dates.WeekdayLocator(byweekday=(1), interval=1))
    ax.xaxis.set_minor_formatter(dates.DateFormatter('%d\n%a'))
    ax.xaxis.grid(True, which="minor")
    ax.yaxis.grid()
    ax.xaxis.set_major_locator(dates.MonthLocator())
    ax.xaxis.set_major_formatter(dates.DateFormatter('\n\n\n%b\n%Y'))
    ax.legend(df.columns, loc='upper left', shadow=True)
    ax.set_title(title)
    plt.tight_layout()
    return plt

In [308]:
locations = sorted(list(df['Country/Region'].unique()))

In [309]:
states = {
    'AL': ['Alabama'],
    'AK': ['Alaska'],
    'AZ': ['Arizona'],
    'AR': ['Arkansas'],
    'CA': ['California'],
    'CO': ['Colorado'],
    'CT': ['Connecticut'],
    'D.C.': ['District of Columbia'],
    'DE': ['Delaware'],
    'FL': ['Florida'],
    'GA': ['Georgia'],
    'HI': ['Hawaii'],
    'ID': ['Idaho'],
    'IL': ['Illinois'],
    'IN': ['Indiana'],
    'IA': ['Iowa'],
    'KS': ['Kansas'],
    'KY': ['Kentucky'],
    'LA': ['Louisiana'],
    'ME': ['Maine'],
    'MD': ['Maryland'],
    'MA': ['Massachusetts'],
    'MI': ['Michigan'],
    'MN': ['Minnesota'],
    'MS': ['Mississippi'],
    'MO': ['Missouri'],
    'MT': ['Montana'],
    'NE': ['Nebraska'],
    'NV': ['Nevada'],
    'NH': ['New Hampshire'],
    'NJ': ['New Jersey'],
    'NM': ['New Mexico'],
    'NY': ['New York'],
    'NC': ['North Carolina'],
    'ND': ['North Dakota'],
    'OH': ['Ohio'],
    'OK': ['Oklahoma'],
    'OR': ['Oregon'],
    'PA': ['Pennsylvania'],
    'PR': ['Puerto Rico'],
    'RI': ['Rhode Island'],
    'SC': ['South Carolina'],
    'SD': ['South Dakota'],
    'TN': ['Tennessee'],
    'TX': ['Texas'],
    'UT': ['Utah'],
    'VT': ['Vermont'],
    'VA': ['Virginia'],
    'WA': ['Washington'],
    'WV': ['West Virginia'],
    'WI': ['Wisconsin'],
    'WY': ['Wyoming'],
    'U.S.': [] # other territories
}

# Extend the above with all the counties/cities that were broken out separately
for s in df[df['Country/Region'] == 'US']['Province/State'].unique():
    i = s.find(',')
    if i < 0:
        continue
    st = s[i+2:i+5].strip()
    if st not in states:
        st += '.'
        if st not in states:
            print(f'Failed with {s}/{st}')
            continue
    states[st].append(s)


In [310]:
picker = widgets.SelectMultiple(
                        options=locations,
                        value=['US'],
                        description='Locations'
                )

In [311]:
widgets.interact(lambda Location: plot_time_series(aggregate_for_locations(df, Location, do_diffs=False), 
                                                   title=f'{" ".join(Location)}: Cumulative'),
                 Location=picker)

interactive(children=(SelectMultiple(description='Locations', index=(200,), options=(' Azerbaijan', 'Afghanist…

<function __main__.<lambda>(Location)>

In [312]:
widgets.interact(lambda Location: plot_time_series(aggregate_for_locations(df, Location), 
                                                   title=f'{" ".join(Location)}: New'),
                 Location=picker)

interactive(children=(SelectMultiple(description='Locations', index=(200,), options=(' Azerbaijan', 'Afghanist…

<function __main__.<lambda>(Location)>

In [313]:
picker2 = widgets.Select(
                        options=states.keys(),
                        value='WA',
                        description='States'
                )
widgets.interact(lambda Location: plot_time_series(aggregate_for_locations(df, 'US', sub_locs=states[Location], do_diffs=False), 
                                                   title=f'{Location}: Cumulative'),
                 Location=picker2)

interactive(children=(Select(description='States', index=48, options=('AL', 'AK', 'AZ', 'AR', 'CA', 'CO', 'CT'…

<function __main__.<lambda>(Location)>

In [332]:
def exponential(x, a):
    " Exponential"
    return a ** x


def sigmoid(x, L ,x0, k, b):
    y = L / (1 + np.exp(-k*(x-x0)))+b
    return (y)


def sigmoid_init(x, y):
    return [max(y), np.median(x), 1, min(y)] # initial guess


# It takes on average 5 days to show symptoms. Typically 14 days from that to death.
# So it seems reasonable to say that people infected on day x will die around day x + 19
# So if we extrapolate the death rate by 15 days, we can approximate how many are infected 
# now after adjusting for fatality rate.

def predict_from_death_rate(region=None, fatality_rate=3, death_time=19, fn=exponential, init=None):
    if region is None:
        region = 'US'
        
    deaths = aggregate_for_locations(df, region, do_diffs=False)['Deaths']
    deaths = list(deaths[deaths > 0])
    
    x = range(1, len(deaths)+1)
    y = deaths
    
    if not y:
        print(f"Insuffienct data for {region}")
        return
    
    if init:
        p0 = init(x, y)
        a, pcov = curve_fit(fn, x, y, p0, method='dogbox')
    else:
        a, pcov = curve_fit(fn, x, y)

    #a, pcov = curve_fit(fn, x[:-1], y[:-1])  # useful for seeing if latest data point is still aligned with curve


    plt.figure()
    plt.plot(x, y, 'ko', label=f"{region} Actual Deaths")
    plt.plot(x, fn(x, *a), 'r-', label="Fitted Curve")
    plt.legend()
    plt.show()
    
    forecast_deaths = fn(x[-1] + death_time, *a)
    predicted_infected = forecast_deaths * 100 / fatality_rate
    print(f'Predicted infected {int(predicted_infected)} leading to {int(forecast_deaths)} deaths in {death_time} days')


In [333]:
widgets.interact(lambda Location: predict_from_death_rate(Location),
                 Location=picker)

interactive(children=(SelectMultiple(description='Locations', index=(71,), options=(' Azerbaijan', 'Afghanista…

<function __main__.<lambda>(Location)>

In [334]:
# Sigmoid version should be better for places that are flattening the curve or reaching saturation

widgets.interact(lambda Location: predict_from_death_rate(Location, fn=sigmoid, init=sigmoid_init),
                 Location=picker)

interactive(children=(SelectMultiple(description='Locations', index=(71,), options=(' Azerbaijan', 'Afghanista…

<function __main__.<lambda>(Location)>