In [None]:
from datetime import date

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import statsmodels.api as sm

In [None]:
# Raw JHU CSSE "Confirmed" time series. The rebinning cell below treats the
# first four columns as metadata (including 'Country/Region') and every
# remaining column as one day's count.
path = 'COVID-19/csse_covid_19_data/csse_covid_19_time_series/time_series_19-covid-Confirmed.csv'
df_in = pd.read_csv(filepath_or_buffer=path)
df_in

In [None]:
# Rebin values for each country: collapse the per-province rows of df_in into
# one column per country and one row per day (row i = i days after the first
# date column).
#
# df_in layout relied on here: the first four columns are metadata (one of
# them is 'Country/Region'); all remaining columns are per-day counts.
date_columns = df_in.columns[4:]

# sorted(): a bare set has hash-randomised iteration order, which made the
# column order of `df` differ from run to run.
countries = sorted(set(df_in['Country/Region']))
xlength = len(countries)
ylength = len(date_columns)

# groupby().sum() replaces the former nested Python loop. That loop used
# chained assignment (df[country][i] += n), which does not write back into
# `df` under pandas Copy-on-Write. Note: sum() treats missing counts as 0.
#
# +1 pseudo-count: preserves the original ones-initialised accumulator, so
# that np.log10 in the plotting cell never sees 0 (which would give -inf).
df = (
    df_in.groupby('Country/Region')[date_columns].sum()
    .T                          # -> rows = days, columns = countries
    .reset_index(drop=True)     # days numbered 0 .. ylength-1
    .rename_axis(columns=None)
    .astype(float)
    + 1.0
)

assert df.shape == (ylength, xlength), "unexpected rebinned shape"

In [None]:
# Trend-fit parameters.
a = -5  # fit each trend line on the last -a points (negative positional offset)
b = 5   # fitted line is drawn up to b - 1 days past the last data point

# Axis limits. y is log10(cases), so [2, 5] spans 100 .. 100,000 cases.
xmin = 30
xmax = df.shape[0] + b - 1
ymin = 2
ymax = 5

countries_to_plot = [
    'Germany',
    'Italy',
    'Austria',
    'Iran',
    'Singapore',
    'Switzerland',
    'Spain',
    'US',
    'China',
]
palette = list(mcolors.TABLEAU_COLORS)

fig, ax = plt.subplots(figsize=(10, 10))

for i, country in enumerate(countries_to_plot):
    # Cycle through the palette. zip() against the 10 TABLEAU colours would
    # silently drop every country past the 10th.
    color = palette[i % len(palette)]

    x = df[country].index
    y = np.log10(df[country])

    # OLS of log10(cases) on day index over the fit window. A straight line
    # in log space corresponds to exponential growth in case counts.
    results = sm.OLS(y.iloc[a:], sm.add_constant(x[a:])).fit()
    # add_constant prepends the constant column, so params = (intercept, slope).
    # np.asarray: params may be a label-indexed Series, where positional [0]/[1]
    # lookup is deprecated in pandas.
    intercept, slope = np.asarray(results.params)

    x_fit = np.arange(x[a], x[-1] + b)
    y_fit = slope * x_fit + intercept

    ax.plot(x, y, label=country, linestyle='--', marker='s', color=color)
    ax.plot(x_fit, y_fit, linestyle='-', color=color)

ax.grid(True)
ax.set_xlim([xmin, xmax])
ax.set_ylim([ymin, ymax])
ax.set_xlabel('days since 2020-01-22')
ax.set_ylabel('log10 of cases')
ax.set_title('Confirmed cases with log-linear fit over the last {} days'.format(-a))
ax.legend(loc='upper left')

fig.savefig('{}.png'.format(date.today()))