In [None]:
from datetime import date

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import statsmodels.api as sm

In [None]:
# Location of the confirmed-cases time series, relative to the working directory.
path = 'COVID-19/csse_covid_19_data/csse_covid_19_time_series/time_series_19-covid-Confirmed.csv'

# Read the CSV into a DataFrame; the bare name on the last line makes the
# notebook render the frame as this cell's output.
df = pd.read_csv(filepath_or_buffer=path)
df

In [None]:
# Build a nested lookup:
#     countries[country][province] -> {'x': day indices, 'y': log10(counts)}
# Only days with more than 20 counted cases are kept; rows that never exceed
# that threshold are left out entirely.
# NOTE(review): the positional access assumes column 0 is the province,
# column 1 the country and columns 4+ the per-day counts — confirm against
# the CSV header.
countries = {}
for row in df.to_numpy():
    country = row[1]
    # A row without a province is stored under the key None. pd.isna states
    # the "value is missing" intent directly instead of relying on the
    # missing value happening to be a float.
    province = None if pd.isna(row[0]) else row[0]

    counts = np.asarray(row[4:], dtype=float)
    above_threshold = counts > 20
    if not above_threshold.any():
        continue

    countries.setdefault(country, {})[province] = {
        # Positions (0-based, counted from the first count column) of the kept days.
        'x': np.flatnonzero(above_threshold),
        'y': np.log10(counts[above_threshold]),
    }

In [None]:
# Countries to compare, each paired with one of matplotlib's single-letter
# base colours (the first five, in the order BASE_COLORS lists them).
selected_countries = ['Germany', 'Italy', 'Austria', 'Iran', 'Brazil']
plot_colors = list(mcolors.BASE_COLORS)[:5]

# Collect the series stored under the None province key for each country and
# tag it with a display name and colour. The dicts are the same objects held
# in `countries`, so the tags are visible there as well.
data = []
for country, color in zip(selected_countries, plot_colors):
    series = countries[country][None]
    series['name'] = country
    series['color'] = color
    data.append(series)

In [None]:
# Negative start index: only the last 5 points of each series enter the fit.
a = -5

# Fit a straight line y = m * x + n by ordinary least squares to the tail of
# every selected series and store intercept ('n') and slope ('m') on it.
for country in data:
    recent_x = country['x'][a:]
    recent_y = country['y'][a:]

    # add_constant puts the intercept column first, so params come back as
    # (intercept, slope).
    design = sm.add_constant(recent_x)
    model = sm.OLS(recent_y, design)
    results = model.fit()

    intercept, slope = results.params
    country['n'] = intercept
    country['m'] = slope

In [None]:
# Plot each selected series (square markers) together with its fitted line,
# extended past the last observation, then save the figure as '<today>.png'.

# Offset added to the last observed day to form the (exclusive) end of the
# fitted line's x range.
b = 14

# The size is set at creation; `fig` is used directly, no plt.gcf() needed.
fig, ax = plt.subplots(figsize=(10, 10))

for country in data:
    # Start the line at the first point of the slice [a:]. Indexing with
    # country['x'][a] directly raises IndexError when a series has fewer
    # than |a| points, whereas the slice simply returns what is available.
    fitted_x = country['x'][a:]
    x_fit = np.arange(fitted_x[0], country['x'][-1] + b)
    y_fit = country['m'] * x_fit + country['n']

    ax.plot(country['x'], country['y'], country['color'] + 's', label=country['name'])
    ax.plot(x_fit, y_fit, country['color'] + '-')

ax.grid(True)
ax.set_xlabel('days since 2020-01-22')
ax.set_ylabel('log10 of cases')
ax.legend(loc='upper left')

fig.savefig('{}.png'.format(date.today()))