In [None]:
import os, sys
import numpy as np
import sympy as sm

import scipy as sp
from scipy.integrate import solve_ivp

import matplotlib as mp
import matplotlib.pyplot as plt

%matplotlib inline
# enable pretty printing of equations
# sm.init_printing()

In [None]:
# Global matplotlib text settings: serif family, normal weight and a base
# size of FONTSIZE points; tick labels on both axes use the relative size
# 'x-small' (scaled from that base).
FONTSIZE = 20
font = dict(family='serif', weight='normal', size=FONTSIZE)
mp.rc('font', **font)
mp.rcParams.update({'xtick.labelsize': 'x-small',
                    'ytick.labelsize': 'x-small'})
# mp.rc('text', usetex=True)  # uncomment to typeset all text with LaTeX

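Here, we solve the SIR (Susceptible–Infected–Recovered) model, where $\beta$ is the transmission rate and $\alpha$ is the recovery rate: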
$$
\begin{aligned}
    \frac{dS}{dt}&= -\beta S I \\
    \frac{dI}{dt}&=\beta S I - \alpha  I \\
    \frac{dR}{dt}&= \alpha I 
\end{aligned}
$$

In [None]:
class ModelSetup:
    """Configuration for one SIR model run.

    Holds the rate parameters, the initial state, the integration window and
    the output time grid. Every value is a class attribute, so the class
    itself (not an instance) serves as the configuration namespace.
    """

    # ----- rate parameters (commented alternatives: beta 0.4, alpha 0.25) -----
    beta = .01
    alpha = .1

    # ----- initial state: total population N0 split into S0 + I0 + R0 -----
    N0 = 1000
    I0 = 10
    R0 = 0
    S0 = N0 - (I0 + R0)
    init_cond = np.array([S0, I0, R0])  # ordered as [S, I, R]

    # ----- integration window and output times -----
    tspan = (0, 300)
    # Output grid from tspan[0] up to but excluding tspan[1], spacing 0.1.
    t_eval = np.arange(tspan[0], tspan[1], 0.1)
    # NOTE(review): observed data (tdata / ydata) and a tspan derived from it
    # were previously taken from `times` / `data`, which are not defined here.
    

In [None]:
# `MS` is an alias for the ModelSetup class itself (not an instance):
# reading or assigning `MS.<attr>` reads or modifies the shared class
# attributes, so any assignment through MS is visible everywhere.
MS = ModelSetup

In the following cell, we (1) define a function `dydt` that computes the derivatives dS/dt, dI/dt and dR/dt that define the SIR model, and (2) define a function `SolveSIRModel` that solves the model for given input parameters alpha and beta, using the function `solve_ivp`.

In [None]:
# Define model equations in function dydt. It returns the values of the RHS of the ODE at (t, y).
def dydt(t, y, alpha=None, beta=None):
    """Right-hand side of the SIR system.

    Parameters
    ----------
    t : float
        Current time. Unused (the system is autonomous) but required by the
        ``fun(t, y)`` calling convention of ``solve_ivp``.
    y : array_like
        State ``[S, I, R]``; only ``S`` and ``I`` enter the right-hand side.
    alpha, beta : float, optional
        Recovery and transmission rates. When omitted (``None``) they are
        read from the global configuration ``MS.alpha`` / ``MS.beta``, which
        is the original behavior. Passing them explicitly (for example via
        ``solve_ivp(..., args=(alpha, beta))``) avoids depending on mutable
        global state.

    Returns
    -------
    numpy.ndarray
        ``[dS/dt, dI/dt, dR/dt]``.
    """
    if alpha is None:
        alpha = MS.alpha
    if beta is None:
        beta = MS.beta

    S, I = y[0], y[1]  # R does not appear on the right-hand side

    infection = beta * S * I  # flow from S into I
    recovery = alpha * I      # flow from I into R

    return np.array([-infection, infection - recovery, recovery])

# Solve the SIR model using solve_ivp
def SolveSIRModel(alpha, beta, **solve_kwargs):
    """Integrate the SIR model with the given rates over ``MS.tspan``.

    Side effect: stores ``alpha`` and ``beta`` on ``MS`` (``MS.alpha``,
    ``MS.beta``), because ``dydt`` reads the rates from ``MS``.

    Parameters
    ----------
    alpha, beta : float
        Recovery and transmission rates.
    **solve_kwargs
        Extra keyword arguments forwarded to ``scipy.integrate.solve_ivp``,
        for example ``method='LSODA'`` or ``rtol=1e-8, atol=1e-8``. They may
        also override the default ``t_eval=MS.t_eval``.

    Returns
    -------
    OdeResult
        The object returned by ``solve_ivp``. ``.t`` holds the output times,
        and the rows of ``.y`` follow the order of ``MS.init_cond``.
    """
    MS.alpha = alpha
    MS.beta = beta

    # Defaults first, then caller overrides, so passing t_eval is not an error.
    options = {'t_eval': MS.t_eval}
    options.update(solve_kwargs)

    return solve_ivp(dydt, MS.tspan, MS.init_cond, **options)

Now let's plot the (scaled) infected curve $\kappa I(t)$ together with the data by defining a function `plotSolution` that takes the output of `SolveSIRModel` as its input.

In [None]:
def plotSolution(soln, filename='SIRplot.pdf'):
    """Plot the scaled infected curve kappa*I(t), plus observed data if any, and save it.

    Parameters
    ----------
    soln : OdeResult
        Output of ``SolveSIRModel`` / ``solve_ivp``; ``soln.y[1]`` is I(t).
    filename : str, optional
        Path the figure is saved to (default ``'SIRplot.pdf'``).

    Notes
    -----
    ``MS.kappa`` (scale factor applied to I) and ``MS.tdata`` / ``MS.ydata``
    (observed data) may be absent from ``MS``. When they are, kappa defaults
    to 1 and no data points are drawn, instead of raising ``AttributeError``.
    """
    kappa = getattr(MS, 'kappa', 1.0)
    tdata = getattr(MS, 'tdata', None)
    ydata = getattr(MS, 'ydata', None)
    has_data = tdata is not None and ydata is not None and np.size(ydata) > 0

    t = soln.t
    I = soln.y[1] * kappa

    fig, ax = plt.subplots(figsize=(8, 5))

    # Raw string: '\k' is not a valid escape sequence in a normal string literal.
    ax.plot(t, I, c='darkorange', linewidth=2, label=r'$\kappa I(t)$')
    if has_data:
        ax.plot(tdata, ydata, 'ko', label='Data')
    ax.legend(fontsize=16)

    ax.set_xlim(t[0], t[-1])

    # 10% headroom above the tallest curve/data point. Kept as a float:
    # truncating with int() could collapse the upper limit to 0 (below the
    # lower limit) whenever the peak is smaller than 1.
    peak = np.max(I)
    if has_data:
        peak = max(peak, np.max(ydata))
    ymin = 1.e-4  # presumably positive so a log y-scale (commented out below) also works
    ax.set_ylim(ymin, 1.1 * peak)
    # ax.set_yscale('log')

    ax.grid(True, which="both", linestyle='-')
    ax.set_xlabel("Time, $t$", fontsize=16)
    ax.set_ylabel("Number infected", fontsize=16)

    fig.savefig(filename)

<a style='text-decoration:none;line-height:16px;display:flex;color:#5B5B62;padding:10px;justify-content:end;' href='https://deepnote.com?utm_source=created-in-deepnote-cell&projectId=b323b509-75b9-423f-bb04-ac84a2ad5053' target="_blank">
Created in <span style='font-weight:600;margin-left:4px;'>Deepnote</span></a>