In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt  
import datetime as dt
import requests

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import chart_studio.plotly as py

from scipy.integrate import odeint

import os
import warnings
warnings.filterwarnings("ignore")

In [None]:
def sir_rates(y, t, population_size, daily_reproductive_number, daily_recovery_number):
    """Right-hand side of the basic SIR equations, in the form ``odeint`` calls it.

    Parameters
    ----------
    y : sequence of three numbers
        Current state, unpacked as ``(S, I, R)``.
    t : float
        Current time. Unused: these rates do not depend on time.
    population_size : number
        Divisor that turns ``S`` into the susceptible fraction.
    daily_reproductive_number : number
        Per-day transmission rate applied to ``I * S / population_size``.
    daily_recovery_number : number
        Per-day rate at which ``I`` flows into ``R``.

    Returns
    -------
    tuple
        ``(dS/dt, dI/dt, dR/dt)``.
    """
    susceptible, infected, _ = y

    # Both infection terms share the susceptible fraction; both recovery terms
    # share the same flow out of the infected compartment.
    susceptible_fraction = susceptible / population_size
    recovery_flow = infected * daily_recovery_number

    ds_dt = (-daily_reproductive_number * infected) * susceptible_fraction
    di_dt = (infected * daily_reproductive_number) * susceptible_fraction - recovery_flow

    return ds_dt, di_dt, recovery_flow

def sir_model(simulation_length,
              population_size,
              infected_start,
              reproductive_number,
              recovery_time):
    """Simulate a basic SIR epidemic, plot the three curves and print a summary.

    Parameters
    ----------
    simulation_length : int
        Number of daily output points; days are numbered 1..simulation_length
        and the initial state is the state on day 1.
    population_size : number
        Total population; also used as the upper limit of the y-axis.
    infected_start : number
        Infected count on day 1. Everyone else starts susceptible and nobody
        starts recovered.
    reproductive_number : number
        Reproductive number over one infectious period; divided by
        ``recovery_time`` to get the daily transmission rate.
    recovery_time : number
        Length of the infectious period in days. Must be non-zero because it
        is used as a divisor.

    Returns
    -------
    None
        The figure is shown with ``fig.show()`` and two summary lines are printed.
    """
    daily_reproductive_number = reproductive_number / recovery_time
    daily_recovery_number = 1 / recovery_time

    # One output point per day; the same list feeds the solver and the "Day" column.
    days = list(range(1, simulation_length + 1))

    initial_state = (population_size - infected_start, infected_start, 0)

    solution = odeint(sir_rates, initial_state, days,
                      args=(population_size,
                            daily_reproductive_number,
                            daily_recovery_number))
    S, I, R = solution.T

    # Named `results` rather than `sir_model` to avoid shadowing this function's name.
    # astype(int) truncates the solver's floating-point counts towards zero.
    results = pd.DataFrame({'Day': days,
                            'S': S.astype(int),
                            'I': I.astype(int),
                            'R': R.astype(int)})

    # (column, legend name, colour) for each compartment, in trace order.
    curve_specs = (('S', 'Susceptible', 'blue'),
                   ('I', 'Infected', 'red'),
                   ('R', 'Recovered', 'green'))
    traces = [go.Scatter(x=results['Day'],
                         y=results[column],
                         mode='lines',
                         name=label,
                         marker_color=color)
              for column, label, color in curve_specs]

    # Fixed axis ranges so runs with different parameters are visually comparable.
    layout = go.Layout(width=900,
                       height=450,
                       showlegend=False,
                       xaxis=dict(range=[0, simulation_length], autorange=False),
                       yaxis=dict(range=[0, population_size], autorange=False))

    fig = go.Figure(data=traces, layout=layout)

    fig.update_yaxes(title_text="<b>Number of Individuals</b>")
    fig.update_xaxes(title_text="<b>Days</b>")
    fig.update_layout(
        title_text="<b>Reproductive Number: {}</b>".format(reproductive_number)
    )

    fig.show()

    print("Max Infected at One Time: {}".format(results['I'].max()))
    # S only decreases in this model, so its minimum is the final susceptible count.
    print("Total Infected: {}".format(population_size - results['S'].min()))

## Basic SIR Model

In [None]:
# Baseline scenario: 10,000 people, 100 infected on day 1, reproductive
# number 2.5, 10-day recovery time, simulated for 50 days.
model = sir_model(
    simulation_length=50,
    population_size=10000,
    infected_start=100,
    reproductive_number=2.5,
    recovery_time=10,
)

In [7]:
# Shelter-in-place model: a separate reproductive number applies between shelter_start and shelter_end.

   
def get_shelter_rate(t,
                     daily_reproductive_number,
                     shelter_start,
                     shelter_end,
                     shelter_daily_reproductive_number):
    """Pick the daily transmission rate that applies at time ``t``.

    Returns ``shelter_daily_reproductive_number`` while
    ``shelter_start < t < shelter_end`` (both bounds exclusive) and
    ``daily_reproductive_number`` at every other time.
    """
    sheltering = shelter_start < t < shelter_end
    return shelter_daily_reproductive_number if sheltering else daily_reproductive_number


def shelter_sir_rates(y,
                      t,
                      population_size,
                      daily_reproductive_number,
                      daily_recovery_number,
                      shelter_start,
                      shelter_end,
                      shelter_daily_reproductive_number):
    """SIR right-hand side whose transmission rate depends on the time ``t``.

    The transmission rate is looked up with ``get_shelter_rate``; apart from
    that the equations are the plain SIR ones.

    Parameters
    ----------
    y : sequence of three numbers
        Current state, unpacked as ``(S, I, R)``.
    t : float
        Current time, forwarded to ``get_shelter_rate``.
    population_size : number
        Divisor that turns ``S`` into the susceptible fraction.
    daily_reproductive_number, shelter_start, shelter_end,
    shelter_daily_reproductive_number
        Forwarded unchanged to ``get_shelter_rate``.
    daily_recovery_number : number
        Per-day rate at which ``I`` flows into ``R``.

    Returns
    -------
    tuple
        ``(dS/dt, dI/dt, dR/dt)``.
    """
    susceptible, infected, _ = y

    # Look the rate up once; it is the same value in the S and I equations.
    transmission_rate = get_shelter_rate(t,
                                         daily_reproductive_number,
                                         shelter_start,
                                         shelter_end,
                                         shelter_daily_reproductive_number)

    susceptible_fraction = susceptible / population_size
    recovery_flow = infected * daily_recovery_number

    ds_dt = (-transmission_rate * infected) * susceptible_fraction
    di_dt = (infected * transmission_rate) * susceptible_fraction - recovery_flow

    return ds_dt, di_dt, recovery_flow


def sir_shelter(number_of_days,
                population_size,
                infected_start,
                reproductive_number,
                recovery_time,
                shelter_start,
                shelter_end,
                shelter_reproductive_number):
    """Simulate an SIR epidemic with a shelter-in-place window and animate it.

    Parameters
    ----------
    number_of_days : int
        Number of daily output points; days are numbered 1..number_of_days
        and the initial state is the state on day 1.
    population_size : number
        Total population; also used as the upper limit of the y-axis.
    infected_start : number
        Infected count on day 1. Everyone else starts susceptible and nobody
        starts recovered.
    reproductive_number : number
        Reproductive number outside the shelter window; divided by
        ``recovery_time`` to get a daily rate.
    recovery_time : number
        Length of the infectious period in days. Must be non-zero because it
        is used as a divisor.
    shelter_start, shelter_end : number
        The sheltered rate applies while ``shelter_start < t < shelter_end``.
    shelter_reproductive_number : number
        Reproductive number inside the shelter window; also divided by
        ``recovery_time``.

    Returns
    -------
    None
        The animated figure is shown with ``fig.show()``.
    """
    daily_recovery_number = 1. / recovery_time
    daily_reproductive_number = reproductive_number / recovery_time
    shelter_daily_reproductive_number = shelter_reproductive_number / recovery_time

    # One output point per day; the same list feeds the solver and the "Day" column.
    days = list(range(1, number_of_days + 1))

    initial_state = (population_size - infected_start, infected_start, 0)

    # NOTE(review): the transmission rate jumps at shelter_start and shelter_end.
    # odeint has a `tcrit` argument for such points; consider passing them if the
    # curves look inaccurate around the switch - not changed here, TODO confirm.
    solution = odeint(shelter_sir_rates, initial_state, days,
                      args=(population_size,
                            daily_reproductive_number,
                            daily_recovery_number,
                            shelter_start,
                            shelter_end,
                            shelter_daily_reproductive_number))
    S, I, R = solution.T

    # Named `results` rather than `sir_model` to avoid shadowing the function of
    # that name. astype(int) truncates the solver's floats towards zero.
    results = pd.DataFrame({'Day': days,
                            'S': S.astype(int),
                            'I': I.astype(int),
                            'R': R.astype(int)})

    # (column, legend name, colour) for each compartment, in trace order.
    curve_specs = (('S', 'Susceptible', 'blue'),
                   ('I', 'Infected', 'red'),
                   ('R', 'Recovered', 'green'))

    # The static figure starts with the first day only; the frames extend it.
    traces = [go.Scatter(x=results['Day'].iloc[:1],
                         y=results[column].iloc[:1],
                         mode='lines',
                         name=label,
                         marker_color=color)
              for column, label, color in curve_specs]

    # Frame k shows the first k rows. Slicing is positional (.iloc) and driven by
    # a row count rather than by the 1-based Day labels, so the frames grow from
    # 1 to N points with no duplicated final frame.
    frames = [dict(data=[dict(type='scatter',
                              x=results['Day'].iloc[:k],
                              y=results[column].iloc[:k])
                         for column, _label, _color in curve_specs],
                   traces=[0, 1, 2])
              for k in range(1, len(results) + 1)]

    play_button = dict(label='Play',
                       method='animate',
                       args=[None,
                             dict(frame=dict(duration=3,
                                             redraw=False),
                                  transition=dict(duration=0),
                                  fromcurrent=True,
                                  mode='immediate')])

    # Fixed axis ranges keep the view stable while the animation plays.
    layout = go.Layout(width=900,
                       height=450,
                       showlegend=False,
                       hovermode='closest',
                       updatemenus=[dict(type='buttons',
                                         showactive=False,
                                         y=1.05,
                                         x=1.15,
                                         xanchor='right',
                                         yanchor='bottom',
                                         pad=dict(t=0, r=10),
                                         buttons=[play_button])],
                       xaxis=dict(range=[results['Day'].min(), results['Day'].max()],
                                  autorange=False),
                       yaxis=dict(range=[0, population_size], autorange=False))

    fig = go.Figure(data=traces, frames=frames, layout=layout)

    fig.update_yaxes(title_text="<b>Number of Individuals</b>")
    fig.update_xaxes(title_text="<b>Days</b>")

    fig.show()


## Shelter in Place Model

In [8]:
# Shelter scenario: reproductive number 4.0 normally, 0.0 while sheltering,
# with the shelter window running from just after day 10 up to day 100.
model = sir_shelter(
    number_of_days=100,
    population_size=10000,
    infected_start=100,
    reproductive_number=4.0,
    recovery_time=10,
    shelter_start=10,
    shelter_end=100,
    shelter_reproductive_number=0.0,
)

In [11]:
def get_vaccine(t, vaccine_start, vaccine_rate):
    """Return ``vaccine_rate`` once ``t`` is strictly past ``vaccine_start``, else 0."""
    return vaccine_rate if t > vaccine_start else 0


def vaccine_rates(y, t, population_size, daily_reproductive_number, daily_recovery_number, vaccine_start, vaccine_rate):
    """SIR right-hand side with an extra flow from ``S`` straight into ``R``.

    The extra flow is ``S * get_vaccine(t, vaccine_start, vaccine_rate)``; it
    is subtracted from ``dS/dt`` and added to ``dR/dt``.

    Parameters
    ----------
    y : sequence of three numbers
        Current state, unpacked as ``(S, I, R)``.
    t : float
        Current time, forwarded to ``get_vaccine``.
    population_size : number
        Divisor that turns ``S`` into the susceptible fraction.
    daily_reproductive_number : number
        Per-day transmission rate.
    daily_recovery_number : number
        Per-day rate at which ``I`` flows into ``R``.
    vaccine_start, vaccine_rate
        Forwarded unchanged to ``get_vaccine``.

    Returns
    -------
    tuple
        ``(dS/dt, dI/dt, dR/dt)``.
    """
    susceptible, infected, _ = y

    susceptible_fraction = susceptible / population_size
    recovery_flow = infected * daily_recovery_number
    # Evaluated once: the same amount leaves S and enters R.
    vaccination_flow = susceptible * get_vaccine(t, vaccine_start, vaccine_rate)

    ds_dt = ((-daily_reproductive_number * infected) * susceptible_fraction) - vaccination_flow
    di_dt = (infected * daily_reproductive_number) * susceptible_fraction - recovery_flow
    dr_dt = recovery_flow + vaccination_flow

    return ds_dt, di_dt, dr_dt

def vaccine_model(simulation_length,
                  population_size,
                  infected_start,
                  reproductive_number,
                  recovery_time,
                  vaccine_start,
                  vaccine_rate):
    """Simulate an SIR epidemic with vaccination and show it as an animation.

    Parameters
    ----------
    simulation_length : int
        Number of daily output points; days are numbered 1..simulation_length
        and the initial state is the state on day 1.
    population_size : number
        Total population; also used as the upper limit of the y-axis.
    infected_start : number
        Infected count on day 1. Everyone else starts susceptible and nobody
        starts recovered.
    reproductive_number : number
        Reproductive number over one infectious period; divided by
        ``recovery_time`` to get the daily transmission rate.
    recovery_time : number
        Length of the infectious period in days. Must be non-zero because it
        is used as a divisor.
    vaccine_start : number
        Vaccination is applied only at times strictly greater than this.
    vaccine_rate : number
        Multiplier on ``S`` giving the per-day flow from ``S`` into ``R`` once
        vaccination is active.

    Returns
    -------
    None
        The animated figure is shown with ``fig.show()``.
    """
    daily_reproductive_number = reproductive_number / recovery_time
    daily_recovery_number = 1 / recovery_time

    # One output point per day; the same list feeds the solver and the "Day" column.
    days = list(range(1, simulation_length + 1))

    initial_state = (population_size - infected_start, infected_start, 0)

    solution = odeint(vaccine_rates, initial_state, days,
                      args=(population_size,
                            daily_reproductive_number,
                            daily_recovery_number,
                            vaccine_start,
                            vaccine_rate))
    S, I, R = solution.T

    # Named `results` rather than `sir_model` to avoid shadowing the function of
    # that name. astype(int) truncates the solver's floats towards zero.
    results = pd.DataFrame({'Day': days,
                            'S': S.astype(int),
                            'I': I.astype(int),
                            'R': R.astype(int)})

    # (column, legend name, colour) for each compartment, in trace order.
    # The R column receives both recoveries and vaccinations in this model.
    curve_specs = (('S', 'Susceptible', 'blue'),
                   ('I', 'Infected', 'red'),
                   ('R', 'Recovered', 'green'))

    # The static figure starts with the first day only; the frames extend it.
    traces = [go.Scatter(x=results['Day'].iloc[:1],
                         y=results[column].iloc[:1],
                         mode='lines',
                         name=label,
                         marker_color=color)
              for column, label, color in curve_specs]

    # Frame k shows the first k rows. Slicing is positional (.iloc) and driven by
    # a row count rather than by the 1-based Day labels, so the frames grow from
    # 1 to N points with no duplicated final frame.
    frames = [dict(data=[dict(type='scatter',
                              x=results['Day'].iloc[:k],
                              y=results[column].iloc[:k])
                         for column, _label, _color in curve_specs],
                   traces=[0, 1, 2])
              for k in range(1, len(results) + 1)]

    play_button = dict(label='Play',
                       method='animate',
                       args=[None,
                             dict(frame=dict(duration=3,
                                             redraw=False),
                                  transition=dict(duration=0),
                                  fromcurrent=True,
                                  mode='immediate')])

    # Fixed axis ranges keep the view stable while the animation plays.
    layout = go.Layout(width=900,
                       height=450,
                       showlegend=False,
                       hovermode='closest',
                       updatemenus=[dict(type='buttons',
                                         showactive=False,
                                         y=1.05,
                                         x=1.15,
                                         xanchor='right',
                                         yanchor='bottom',
                                         pad=dict(t=0, r=10),
                                         buttons=[play_button])],
                       xaxis=dict(range=[results['Day'].min(), results['Day'].max()],
                                  autorange=False),
                       yaxis=dict(range=[0, population_size], autorange=False))

    fig = go.Figure(data=traces, frames=frames, layout=layout)

    fig.update_yaxes(title_text="<b>Number of Individuals</b>")
    fig.update_xaxes(title_text="<b>Days</b>")

    fig.show()


In [12]:
# Vaccine scenario with vaccine_rate set to 0.0, so no one is moved from S to R
# by vaccination in this run; raise vaccine_rate to see the effect.
vaccine_model(
    simulation_length=100,
    population_size=1000000,
    infected_start=1000,
    reproductive_number=4.0,
    recovery_time=10,
    vaccine_start=0,
    vaccine_rate=0.0,
)