In [None]:
import datetime as dt
import warnings
from datetime import datetime
from pathlib import Path

import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.optimize import curve_fit

from ts_viz import TimeSeriesViz

In [None]:
# Alternative two-parameter exponential model, kept for reference. Note that the
# starting guess used by show_fit has three values, so it would need adjusting
# before this model could be fitted.
# func = lambda x, a, b: a * np.exp(b * x)
# func_name = 'exp'


def func(t, K, x, t0):
    """Power-law rise damped by an exponential: ``K * t**x * exp(-t / t0)``.

    ``t`` is the independent variable (scalar or NumPy array); ``K``, ``x`` and
    ``t0`` are the three free parameters of the model.
    """
    rise = np.power(t, x)
    decay = np.exp(-t / t0)
    return K * rise * decay


func_name = 'power'

def show_fit(series, func=False, func_name=False, title=None, pred=7, sma=False, figsize=(10, 6),
             p0=(1, 1, 5)):
    """Plot a daily series as bars, optionally with a fitted curve and a moving average.

    Parameters
    ----------
    series : pandas.Series
        Values to plot. For fitting, the samples are treated as consecutive
        points x = 0, 1, 2, ... and the fitted curve is laid out on a daily
        date range starting at ``series.index.min()``.
    func : callable or False
        Model ``func(x, *params)`` passed to ``scipy.optimize.curve_fit``.
        When falsy, no fit is computed or drawn.
    func_name : str or False
        Prefix for the fit's legend entry (``'<func_name> fit'``); a plain
        ``'fit'`` is used when falsy.
    title : str, optional
        Axes title; omitted when falsy.
    pred : int
        Number of days beyond the end of ``series`` over which the fitted
        curve is extended. ``0`` is allowed.
    sma : int or False
        Window (in samples) of a centred rolling mean to overlay; skipped when falsy.
    figsize : tuple
        Figure size handed to ``plt.subplots``.
    p0 : sequence of float
        Initial parameter guess for ``curve_fit``; its length must match the
        number of parameters ``func`` takes after ``x``. The default suits a
        three-parameter model.

    Returns
    -------
    tuple
        ``((fig, ax), popt, pcov, fit)`` where ``fit`` is the fitted curve as a
        Series over the extended date range. The last three items are ``None``
        when no fit was requested.
    """
    n_obs = len(series)
    x = np.arange(n_obs)
    index = pd.date_range(series.index.min(), periods=(n_obs + pred), freq='D')
    if func:
        with warnings.catch_warnings():
            # Every warning raised while fitting is silenced, so a poor fit
            # does not announce itself here; inspect pcov if in doubt.
            warnings.simplefilter("ignore")
            popt, pcov = curve_fit(func, x, series.values, maxfev=2000, p0=p0)
        y = func(np.arange(len(index)), *popt)
        fit = pd.Series(y, index=index)
        # Take the observed span by position: `fit[:-pred]` is empty when pred == 0.
        fitted = fit.iloc[:n_obs]
        err = np.sqrt(np.sum((fitted - series) ** 2) / n_obs)
        print(f'Fit error: {err:.02f}')
        print(f'Predicted volume error: {1 - fitted.sum() / series.sum():.02%}')

    fig, ax = plt.subplots(figsize=figsize)
    if func:
        fit_label = f'{func_name} fit' if func_name else 'fit'
        ax.plot(index, y, label=fit_label, color='tab:red', lw=2)
        # Vertical marker on the date where the fitted curve peaks.
        ax.axvline(fit.idxmax(), color='tab:red')

    ax.bar(series.index, series, label=series.name)

    if sma:
        # NOTE(review): the moving average uses the same colour and width as
        # the fit line, so the two are hard to tell apart when both are drawn.
        sma_series = series.rolling(sma, center=True).mean()
        ax.plot(sma_series.index, sma_series, label=f'Moving avg ({sma} days)', color='tab:red', lw=2)

    # Major ticks on selected days of each month, minor ticks on every day.
    locator = mdates.MonthLocator(bymonthday=[1, 5, 10, 15, 20, 25])
    ax.xaxis.set_major_locator(locator)
    ax.xaxis.set_minor_locator(mdates.DayLocator(interval=1))
    ax.xaxis.set_major_formatter(mdates.ConciseDateFormatter(locator))
    ax.set_ylim(0, None)
    if title:
        ax.set_title(title)
    # Attach the legend to this axes explicitly rather than to pyplot's "current" one.
    ax.legend(loc='upper left')
    if func:
        return (fig, ax), popt, pcov, fit
    else:
        return (fig, ax), None, None, None

In [None]:
# Number of days to step back from the current date when building the file name;
# raise it if the file for the current date cannot be downloaded yet.
days_back = 0
today = (datetime.now() - dt.timedelta(days=days_back)).strftime('%Y-%m-%d')
# NOTE(review): "disbtribution" is kept exactly as found - presumably it mirrors the
# publisher's actual file name; verify before "correcting" the spelling.
world_url = f'https://www.ecdc.europa.eu/sites/default/files/documents/COVID-19-geographic-disbtribution-worldwide-{today}.xlsx'
print(world_url)


def date_parser(x):
    """Parse a 'YYYY-MM-DD' string into a datetime (not used by the load below).

    Uses the standard-library ``datetime`` class: the ``pd.datetime`` alias
    this helper used to rely on no longer exists in current pandas.
    """
    return datetime.strptime(x, '%Y-%m-%d')


# The first spreadsheet column becomes the index; 'dateRep' is parsed as dates.
df_raw = pd.read_excel(world_url, parse_dates=['dateRep'], index_col=0)
print(df_raw.index.max())
# Keep only the columns used by the notebook, shorten the country column name
# and order the rows by date.
df_raw = (
    df_raw[['cases', 'deaths', 'countriesAndTerritories']]
    .rename(columns={'countriesAndTerritories': 'country'})
    .sort_index()
)
# Sanity check: the most recent rows for one country.
df_raw[df_raw['country'] == 'Hungary'].tail()

In [None]:
# Collapse df_raw to one row per country by summing its remaining columns
# (cases and deaths) over every date.
df = df_raw.groupby('country').sum()

In [None]:
# Write the per-country totals to a dated CSV. The target folder is created
# first so that a fresh checkout without a data/ directory does not fail here.
out_path = Path('data') / f'world-{today}.csv'
out_path.parent.mkdir(parents=True, exist_ok=True)
df.to_csv(out_path)

In [None]:
def show_new(series, title, figsize=(16, 10)):
    """Draw ``series`` as a bar chart with one major x-axis tick per day.

    Parameters
    ----------
    series : pandas.Series
        Values to plot; the index supplies the bar positions and must support
        adding a ``pd.Timedelta``.
    title : str
        Axes title.
    figsize : tuple
        Figure size handed to ``plt.subplots``.

    Nothing is returned; the figure is left to the active matplotlib backend.
    """
    _, axes = plt.subplots(figsize=figsize)
    axes.bar(series.index, series, align='center')
    axes.set_title(title)
    axes.yaxis.grid(True, which='major')

    every_day = mdates.DayLocator(interval=1)
    axes.xaxis.set_major_locator(every_day)
    axes.xaxis.set_major_formatter(mdates.ConciseDateFormatter(every_day))

    # NOTE(review): both limits are shifted half a day to the right, so the
    # left edge sits after the first index value and the first bar appears to
    # fall outside the view - confirm this is intended.
    half_day = pd.Timedelta(days=.5)
    left_edge = series.index.min() + half_day
    right_edge = series.index.max() + half_day
    axes.set_xlim(left_edge, right_edge)

In [None]:
class CountryViz:
    """Daily cases and deaths of one country, with helpers to chart them.

    The instance keeps the rows of the source frame that belong to ``country``
    and are dated strictly after ``start``.
    """

    def __init__(self, country, start, df_raw, fit_func=False, fit_func_name=False):
        """Select one country's rows from ``df_raw``.

        Parameters
        ----------
        country : str
            Value to match in ``df_raw['country']``.
        start : str or datetime-like
            Rows whose index is greater than this value are kept (the start
            date itself is excluded).
        df_raw : pandas.DataFrame
            Frame with 'country', 'cases' and 'deaths' columns and a date index.
        fit_func : callable or False
            Optional model forwarded to ``show_fit``; no fit when falsy.
        fit_func_name : str or False
            Label for the model, forwarded to ``show_fit``.
        """
        self.country = country
        self.start = start
        self.fit_func = fit_func
        self.fit_func_name = fit_func_name
        selected = (df_raw['country'] == country) & (df_raw.index > start)
        self.df = df_raw[selected].sort_index()
        # Independent copies, so later edits to them do not touch self.df.
        self.cases = self.df['cases'].copy()
        self.deaths = self.df['deaths'].copy()

    def _show_diag(self, series, name):
        """Chart ``series`` via ``show_fit`` and print a short summary.

        ``name`` is used in the plot title and the printed header. Returns the
        ``(fig, ax)`` pair produced by ``show_fit``.

        Raises
        ------
        ValueError
            If ``series`` is empty, i.e. no row matched the country and start
            date given to the constructor.
        """
        if series.empty:
            # Fail early with an actionable message; otherwise the plotting
            # helper would be handed a series without any dates.
            raise ValueError(
                f"No rows for country {self.country!r} after {self.start}; cannot plot {name!r}. "
                "Check the spelling against df_raw['country'] "
                "(names use underscores, e.g. 'United_Kingdom')."
            )
        (fig, ax), popt, pcov, fit = show_fit(series, self.fit_func, self.fit_func_name,
                                              title=f'{name} in {self.country}', sma=7, figsize=(16, 10))
        if self.fit_func:
            # The summary line assumes a model with exactly three parameters.
            K, x, t0 = popt
            print(f'{self.country}: K = {K:.08f}, x = {x:.02f}, t0 = {t0:.02f}')
            print(f'first day: {series.index[0]:%Y-%m-%d}')
            # idxmax returns the date label directly, independent of how
            # np.argmax treats a Series in the installed pandas version.
            print(f'peak as of fit: {fit.idxmax():%Y-%m-%d}')
        print(f'\n{name} data:')
        print(series.tail())
        return fig, ax

    def show_cases(self):
        """Chart the daily new cases; returns ``(fig, ax)``."""
        return self._show_diag(self.cases, 'New daily cases')

    def show_deaths(self):
        """Chart the daily deaths; returns ``(fig, ax)``."""
        return self._show_diag(self.deaths, 'Daily deaths')

In [None]:
# France: daily new cases dated after 2020-02-28, drawn as bars with the 7-day
# centred moving average that CountryViz always overlays. No curve fit is requested.
france = CountryViz('France', '2020-02-28', df_raw)
fig, ax = france.show_cases()
plt.show()

In [None]:
# France: daily deaths for the same selection; the returned (fig, ax) pair is the cell's value.
france.show_deaths()

In [None]:
# Hungary: daily new cases dated after 2020-02-28.
hun = CountryViz('Hungary', '2020-02-28', df_raw)
hun.show_cases()

In [None]:
# Hungary: daily deaths over the same period.
hun.show_deaths()

In [None]:
# Spain: daily new cases dated after 2020-02-28.
spain = CountryViz('Spain', '2020-02-28', df_raw)
fig, ax = spain.show_cases()
# Fix the y-axis to 0..20000; any bar taller than that is cut off at the top of the plot.
ax.set_ylim(0, 20000)
plt.show()

In [None]:
# Spain: daily deaths over the same period (default y-axis range).
spain.show_deaths()

In [None]:
# United Kingdom: daily new cases dated after 2020-02-28. The country value is
# written with an underscore, as it appears in the data's country column.
uk = CountryViz('United_Kingdom', '2020-02-28', df_raw)
uk.show_cases()

In [None]:
# United Kingdom: daily deaths over the same period.
uk.show_deaths()

In [None]:
# United States: daily new cases dated after 2020-02-28.
usa = CountryViz('United_States_of_America', '2020-02-28', df_raw)
usa.show_cases()

In [None]:
# United States: daily deaths over the same period.
usa.show_deaths()

In [None]:
# Italy: daily new cases dated after 2020-02-28.
it = CountryViz('Italy', '2020-02-28', df_raw)
it.show_cases()

In [None]:
# Italy: daily deaths over the same period.
it.show_deaths()

In [None]:
# Sweden: daily new cases dated after 2020-02-28.
sw = CountryViz('Sweden', '2020-02-28', df_raw)
sw.show_cases()

In [None]:
# Sweden: daily deaths over the same period.
sw.show_deaths()

In [None]:
# Germany: daily new cases dated after 2020-02-28.
ger = CountryViz('Germany', '2020-02-28', df_raw)
ger.show_cases()

In [None]:
# Germany: daily deaths over the same period.
ger.show_deaths()

In [None]:
# Iceland: daily new cases dated after 2020-02-28 (no deaths chart follows for this country).
ice = CountryViz('Iceland', '2020-02-28', df_raw)
ice.show_cases()

In [None]:
# Singapore: daily new cases dated after 2020-02-28 (no deaths chart follows for this country).
sing = CountryViz('Singapore', '2020-02-28', df_raw)
sing.show_cases()

In [None]:
# South Korea: daily new cases dated after 2020-02-18, an earlier cut-off than
# the 2020-02-28 used for most other countries in this notebook.
sk = CountryViz('South_Korea', '2020-02-18', df_raw)
sk.show_cases()

In [None]:
# South Korea: daily deaths over the same period.
sk.show_deaths()

In [None]:
# Japan: daily new cases dated after 2020-02-15; the CountryViz object is not kept.
CountryViz('Japan', '2020-02-15', df_raw).show_cases()

In [None]:
# Japan: daily deaths dated after 2020-02-15, built from a fresh CountryViz.
fig, ax = CountryViz('Japan', '2020-02-15', df_raw).show_deaths()
# Fix the y-axis to 0..40; taller bars are cut off at the top of the plot.
ax.set_ylim(0, 40)

In [None]:
# Brazil: daily new cases dated after 2020-02-28.
# `viz` is a scratch name that later cells rebind to other countries.
viz = CountryViz('Brazil', '2020-02-28', df_raw)
viz.show_cases()

In [None]:
# Daily deaths for whichever country `viz` currently holds - Brazil when the cells run in order.
viz.show_deaths()

In [None]:
# Russia: daily new cases dated after 2020-02-28 (rebinds the scratch name `viz`).
viz = CountryViz('Russia', '2020-02-28', df_raw)
viz.show_cases()

In [None]:
# Daily deaths for whichever country `viz` currently holds - Russia when the cells run in order.
viz.show_deaths()

In [None]:
# Romania: daily new cases dated after 2020-02-28 (rebinds the scratch name `viz`).
viz = CountryViz('Romania', '2020-02-28', df_raw)
viz.show_cases()

In [None]:
# Ukraine: daily new cases dated after 2020-02-28 (rebinds the scratch name `viz`).
viz = CountryViz('Ukraine', '2020-02-28', df_raw)
viz.show_cases()

In [None]:
# Iran: daily new cases dated after 2020-02-28.
iran = CountryViz('Iran', '2020-02-28', df_raw)
iran.show_cases()

In [None]:
# Iran: daily deaths over the same period.
iran.show_deaths()

In [None]:
# Czechia: daily new cases dated after 2020-02-28.
cz = CountryViz('Czechia', '2020-02-28', df_raw)
cz.show_cases()

In [None]:
# Czechia: daily deaths over the same period.
cz.show_deaths()

In [None]:
# South Africa: daily new cases dated after 2020-02-15; the CountryViz object is not kept.
CountryViz('South_Africa', '2020-02-15', df_raw).show_cases()

In [None]:
# Chile: daily new cases dated after 2020-02-15.
fig, ax = CountryViz('Chile', '2020-02-15', df_raw).show_cases()
# Only the upper y-limit is set (8000); passing None leaves the lower limit as it is.
ax.set_ylim(None, 8000)

In [None]:
# Australia: daily new cases dated after 2020-02-15.
CountryViz('Australia', '2020-02-15', df_raw).show_cases()

In [None]:
# New Zealand: daily new cases dated after 2020-02-15.
CountryViz('New_Zealand', '2020-02-15', df_raw).show_cases()

In [None]:
# Canada: daily new cases dated after 2020-02-15.
CountryViz('Canada', '2020-02-15', df_raw).show_cases()

In [None]:
# India: daily new cases dated after 2020-02-15.
india = CountryViz('India', '2020-02-15', df_raw)
india.show_cases()

In [None]:
# India: daily deaths over the same period.
india.show_deaths()

In [None]:
# China: daily new cases dated after 2020-01-01, the earliest cut-off used in this notebook.
fig, ax = CountryViz('China', '2020-01-01', df_raw).show_cases()
# Fix the y-axis to 0..5000; taller bars are cut off at the top of the plot.
ax.set_ylim(0, 5000)

In [None]:
# Alphabetical list of every distinct value in the country column - useful for
# finding the exact spelling CountryViz expects (e.g. 'United_Kingdom').
sorted(df_raw['country'].unique())

In [None]:
# NOTE(review): this import sits in a late cell; consider moving it to the imports cell.
from matplotlib.patches import Ellipse

fig, ax = plt.subplots(figsize=(16, 10))

# Rows for one country, dated strictly between `start` and `end`.
country = 'China'
start = '2020-01-15'
end = '2020-03-01'
is_country = df_raw['country'] == country
in_window = (df_raw.index > start) & (df_raw.index < end)
country_df = df_raw[is_country & in_window].sort_index()

# Running total of the daily case counts.
ax.set_title('Total cases of COVID-19 in China')
cases_exp = country_df['cases'].expanding().sum()
ax.plot(cases_exp)

# The two dates the highlight ellipse is built around.
ann_start = '2020-02-12'
ann_stop = '2020-02-13'
ann_start_num = mdates.datestr2num(ann_start)
ann_stop_num = mdates.datestr2num(ann_stop)
total_at_start = cases_exp[ann_start]
total_at_stop = cases_exp[ann_stop]

# Ellipse centred midway between the two points, enlarged by `ann_zoom`.
ann_center_x = (ann_start_num + ann_stop_num) / 2
ann_center_y = (total_at_start + total_at_stop) / 2
ann_zoom = 1.5
ann_height = (total_at_stop - total_at_start) * ann_zoom
timespan = ann_stop_num - ann_start_num
ann_width = timespan * 2 * ann_zoom
ellipse = Ellipse(
    xy=(ann_center_x, ann_center_y),
    width=ann_width,
    height=ann_height,
    edgecolor='r',
    fc='None',
    lw=2,
)
ax.add_patch(ellipse)

# Text label offset from the ellipse centre, joined to the ellipse by a curved arrow.
arrow_style = {
    'arrowstyle': "simple",
    'fc': "0.6",
    'ec': "none",
    'patchB': ellipse,
    'connectionstyle': "arc3,rad=0.3",
}
ax.annotate(
    'China changed\ncase definition',
    xy=(ann_center_x, ann_center_y),
    xycoords='data',
    xytext=(50, -100),
    textcoords='offset points',
    size=15,
    arrowprops=arrow_style,
)

plt.show()