In [2]:
import numpy as np
from scipy.integrate import solve_ivp
import plotly.graph_objects as go
import pandas as pd
import plotly.io as pio

# Import local python style
import os, sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
import relative_pathing
import plotly_styles.walkintheforest_styles

In [3]:
def sir_model(t, y, r0, gamma, n):
    """ Right-hand side of the SIR equations in the fun(t, y, *args) form used by solve_ivp

    Arguments
        t: Current time (accepted for the solver's call signature, not used)
        y: Current state, unpacked as (S, I, R)
        r0: Reproduction number; the infection rate is derived as beta = r0 * gamma
        gamma: Recovery rate constant
        n: Total number in the population

    Returns:
        ds: Change in S
        di: Change in I
        dr: Change in R

    """
    susceptible, infected, _recovered = y
    beta = r0 * gamma

    # Flow out of S into I, and flow out of I into R
    new_infections = beta * susceptible * infected / n
    recoveries = gamma * infected

    return -new_infections, new_infections - recoveries, recoveries

In [4]:
# Required constants and T
num_tsteps = 365
# NOTE: linspace includes both endpoints, so 365 samples over [0, 365] are
# spaced 365/364 days apart rather than exactly one day apart.
t = np.linspace(0, 365, num_tsteps)
pop_size = 1000000
y0 = [pop_size - 1, 1, 0] # Start with a single infected individual
gamma = 0.1 # Use this for a 10 day recovery period (Modeling SARS-CoV-2)

# Explicit definitions to help generate slider steps
r0_start = 1
r0_stop = 5
r0_step = 0.1
# Round the interval count before adding the endpoint: plain int() truncates,
# so a quotient that lands just below an integer (e.g. 39.999...) would
# silently drop the last R0 value.
num_r0 = int(round((r0_stop - r0_start) / r0_step)) + 1
r0_list = np.linspace(r0_start, r0_stop, num_r0)

pd_column_names = ['t', 'R0', 'S', 'I', 'R']

## Data structure
# Empty frame with the result columns; the data-generation cell fills it.
data = pd.DataFrame(columns=pd_column_names)

In [5]:
# Pre-generate all of the data
# Solve the SIR system once per R0 value. The per-R0 frames are collected in a
# list and concatenated a single time: growing a DataFrame with pd.concat
# inside the loop copies all previous rows on every iteration, and re-running
# the cell would append duplicate rows to the existing `data`.
t_span = (t[0], t[-1])  # integrate over exactly the range that is sampled
r0_frames = []
for r0 in r0_list:
    sir_solver = solve_ivp(sir_model, t_span, y0, args=(r0, gamma, pop_size), dense_output=True)
    # The dense-output interpolant evaluated at t has one row per state variable
    s_values, i_values, r_values = sir_solver.sol(t)
    r0_frames.append(
        pd.DataFrame(
            {'t': t, 'R0': r0, 'S': s_values, 'I': i_values, 'R': r_values},
            columns=pd_column_names,
        )
    )

# Rebuilding `data` from scratch keeps this cell idempotent
data = pd.concat(r0_frames, ignore_index=True)

In [6]:
# Generate the list of all of the traces
def _traces_for(column, label):
    """ Build one initially hidden trace per R0 value for a single data column

    Arguments
        column: Name of the column in `data` to plot on the y-axis
        label: Legend name shared by every trace in the returned list

    Returns:
        List of go.Scatter traces, in order of first appearance of each R0 in `data`
    """
    return [
        go.Scatter(x=t, y=data.loc[data['R0'] == r0, column], visible=False, name=label)
        for r0 in data['R0'].unique()
    ]

s_trace_list = _traces_for('S', 'Susceptible')
i_trace_list = _traces_for('I', 'Infected')
r_trace_list = _traces_for('R', 'Recovered')

# Choose the starting visible trace
# Pick the grid point closest to the requested R0 instead of testing floats for
# exact equality, which raises an opaque IndexError whenever the value is not
# reproduced bit-for-bit by linspace. int() gives a plain Python index.
starting_r0_value = 5
starting_r0 = int(np.argmin(np.abs(r0_list - starting_r0_value)))
for trace_list in (s_trace_list, i_trace_list, r_trace_list):
    trace_list[starting_r0]['visible'] = True

In [20]:
# Add all of the data to a figure
# Trace order is all S traces, then all I traces, then all R traces, so the
# three traces for one R0 sit one group-length apart.
fig = go.Figure(s_trace_list + i_trace_list + r_trace_list)
traces_per_group = len(s_trace_list)


def _r0_title(r0_value):
    """ Figure title for a given R0 value (rounded to 2 decimals) """
    return "SIR Model: R<sub>0</sub> = " + str(round(r0_value, 2))


# Generating all of the steps
steps = []
for step_idx, r0_value in enumerate(r0_list):
    # Show only the S, I and R traces that belong to this R0 value
    visible = [False] * len(fig.data)
    for group_offset in (0, traces_per_group, 2 * traces_per_group):
        visible[step_idx + group_offset] = True

    steps.append(dict(
        method = 'update',
        label = str(round(r0_value, 2)),
        args = [{'visible': visible},
                {"title" : _r0_title(r0_value)}]
    ))

# Generate slider
sliders = [dict(steps = steps,
                active = starting_r0,
                currentvalue={'visible' : False},
                pad = {"t" : 70})]

# Final Figure generation
# The initial title is tied to the initially visible traces (starting_r0), not
# to whatever value the step loop variable happened to end on.
fig.update_layout(sliders=sliders,
                  title = _r0_title(r0_list[starting_r0]),
                  template='walkintheforest-dark', autosize=True,
                  legend = {"orientation": "h", "xanchor": "center", "yanchor": "top", "x": 0.5, "y": 1.1})
fig.update_yaxes(title='Number of Individuals',
                 showgrid=True)
fig.update_xaxes(title='Days',
                 showgrid=True)

# Figure preview
fig.show()

In [None]:
# Write figure to HTML
# Create the output directory first so a fresh checkout does not fail with
# FileNotFoundError when "figures/" is missing.
html_path = "figures/basic-sir-r0.html"
os.makedirs(os.path.dirname(html_path), exist_ok=True)
pio.write_html(fig, html_path, auto_open=False)