In [2]:
import numpy as np
import pandas as pd

from pathlib import Path
data_dir = Path('.')

import os
os.listdir(data_dir)

import plotly.graph_objects as go
import plotly.express as px
import plotly.io as pio
pio.templates.default = "plotly_dark"

# "notebook_connected" keeps plots interactive; uncomment the "png" renderer below to get static images that display in the GitHub repo
pio.renderers.default = "notebook_connected"
#pio.renderers.default = "png"

from plotly.subplots import make_subplots

#from scipy.optimize import curve_fit


In [None]:
def dS_dt(S, I, R_t, T_inf):
    """Susceptible equation: dS/dt = -(R_t / T_inf) * I * S.

    S and I are presumably population fractions (no population size N
    appears here) -- confirm against the caller.
    """
    transmission_rate = R_t / T_inf
    return -transmission_rate * I * S

def dE_dt(S, E, I, R_t, T_inf, T_inc):
    """Exposed equation: dE/dt = (R_t / T_inf) * I * S - E / T_inc.

    The first term is the inflow from the susceptible compartment. The
    second is the outflow into the infected compartment after incubation.
    """
    new_exposures = (R_t / T_inf) * I * S
    incubation_exits = (T_inc**-1) * E
    return new_exposures - incubation_exits

def dI_dt(I, E, T_inc, T_inf):
    """Infected equation: dI/dt = E / T_inc - I / T_inf.

    Exposed individuals enter at rate 1/T_inc. Infected individuals leave
    at rate 1/T_inf.
    """
    inflow = (T_inc**-1) * E
    outflow = (T_inf**-1) * I
    return inflow - outflow

def dR_dt(I, T_inf):
    """Recovered/removed/deceased equation: dR/dt = I / T_inf."""
    removal_rate = T_inf**-1
    return removal_rate * I

def SEIR_model(t, y, R_t, T_inf, T_inc):
    """Right-hand side of the SEIR ODE system.

    The ``(t, y, *params)`` argument order presumably targets an ODE
    integrator such as ``scipy.integrate.solve_ivp``. No integrator is
    called in this notebook section.

    Parameters
    ----------
    t : float
        Current time. Only used when ``R_t`` is callable.
    y : sequence of 4 values
        Current state, unpacked as ``(S, E, I, R)``.
    R_t : float or callable
        Reproduction number: either a constant, or a function of ``t``.
    T_inf : float
        Presumably the mean infectious period (used as the rate 1/T_inf).
    T_inc : float
        Presumably the mean incubation period (used as the rate 1/T_inc).

    Returns
    -------
    list
        ``[dS/dt, dE/dt, dI/dt, dR/dt]`` evaluated at ``(t, y)``.
    """
    # Resolve a time-varying reproduction number before unpacking the state.
    reproduction = R_t(t) if callable(R_t) else R_t

    susceptible, exposed, infected, removed = y

    return [
        dS_dt(susceptible, infected, reproduction, T_inf),
        dE_dt(susceptible, exposed, infected, reproduction, T_inf, T_inc),
        dI_dt(infected, exposed, T_inc, T_inf),
        dR_dt(infected, T_inf),
    ]

In [3]:
# Read the daily per-region COVID-19 table from the configured data directory
# (data_dir is set in the imports/config cell) instead of the process CWD.
world_data = pd.read_csv(data_dir / 'covid_19_clean_complete.csv', parse_dates=['Date'])
world_data.head()

Unnamed: 0,Province/State,Country/Region,Lat,Long,Date,Confirmed,Deaths,Recovered
0,,Afghanistan,33.0,65.0,2020-01-22,0,0,0.0
1,,Albania,41.1533,20.1683,2020-01-22,0,0,0.0
2,,Algeria,28.0339,1.6596,2020-01-22,0,0,0.0
3,,Andorra,42.5063,1.5218,2020-01-22,0,0,0.0
4,,Angola,-11.2027,17.8739,2020-01-22,0,0,0.0


In [4]:
# Give the location columns the shorter names used throughout the notebook.
world_data = world_data.rename(columns={
    'Province/State': 'State',
    'Country/Region': 'Country',
})

# Fill missing counts *before* deriving Active. Recovered appears to contain
# missing values: it displays as float (0.0) while Confirmed/Deaths display as
# int. ``x - NaN`` is NaN, so subtracting first and filling afterwards would
# set Active to 0 on those rows instead of Confirmed - Deaths.
count_columns = ['Confirmed', 'Deaths', 'Recovered']
world_data[count_columns] = world_data[count_columns].fillna(0)
world_data['Active'] = world_data['Confirmed'] - world_data['Deaths'] - world_data['Recovered']

# Merge the legacy 'Mainland China' label into 'China'.
world_data['Country'] = world_data['Country'].replace('Mainland China', 'China')

# Country-level rows have no province/state; use '' rather than NaN.
world_data['State'] = world_data['State'].fillna('')

In [5]:
# Preview the cleaned frame (renamed columns plus the derived Active column).
world_data.head(n=5)

Unnamed: 0,State,Country,Lat,Long,Date,Confirmed,Deaths,Recovered,Active
0,,Afghanistan,33.0,65.0,2020-01-22,0,0,0.0,0.0
1,,Albania,41.1533,20.1683,2020-01-22,0,0,0.0,0.0
2,,Algeria,28.0339,1.6596,2020-01-22,0,0,0.0,0.0
3,,Andorra,42.5063,1.5218,2020-01-22,0,0,0.0,0.0
4,,Angola,-11.2027,17.8739,2020-01-22,0,0,0.0,0.0


In [6]:
def plot_data(country):
    """Plot confirmed cases and deaths over time for one country.

    All rows of the global ``world_data`` frame whose ``Country`` equals
    ``country`` are summed per ``Date``, which aggregates any provinces/states.
    Two line charts are then shown: one for ``Confirmed`` and one for ``Deaths``.

    Parameters
    ----------
    country : str
        Value to match in the ``Country`` column, e.g. ``'US'``.
    """
    country_rows = world_data[world_data['Country'] == country]

    # Select the value columns with a list. Tuple-style selection
    # (``['Date', 'Confirmed', 'Deaths']`` without inner brackets) is
    # deprecated in pandas 1.x and raises in pandas 2. 'Date' is the group
    # key, so it must not be part of the summed columns.
    daily_totals = (
        country_rows
        .groupby('Date')[['Confirmed', 'Deaths']]
        .sum()
        .reset_index()
    )

    charts = (
        ('Confirmed', f"Confirmed Cases in {country} Over Time"),
        ('Deaths', f"Deaths in {country} Over Time"),
    )
    for column, title in charts:
        fig = px.line(daily_totals, x="Date", y=column,
                      title=title,
                      color_discrete_sequence=['#F61067'],
                      height=500)
        fig.show()

In [7]:
# Daily confirmed-case and death curves for the United States.
plot_data(country='US')