In [None]:
import init_notebook
import numpy as np
import matplotlib.pyplot as plt
from collections import namedtuple
from mpl_toolkits.mplot3d import Axes3D

from scipy.integrate import solve_ivp

from sir_model import R0, model, mu, h

In [None]:
# ---- numerical settings ----
random_state = 12345      # seed for the randomly drawn initial state
t_0, t_end = 0, 1000      # start and end of the integration window
NT = t_end - t_0          # number of time samples handed to the solver
# Both tolerances must stay tight: with looser values the computed
# solution becomes qualitatively (!) wrong.
rtol = atol = 1e-8

# ---- SIR model parameters ----
beta = 11.5
A = 20
d = 0.1
nu = 1
b = 0.01  # try to set this to 0.01, 0.020, ..., 0.022, ..., 0.03
mu0, mu1 = 10, 10.45  # minimum / maximum recovery rate

# ---- information ----
print("Reproduction number R0=", R0(beta, d, nu, mu1))
print('Globally asymptotically stable if beta <=d+nu+mu0. This is', beta <= d+nu+mu0)

# ---- simulation ----
rng = np.random.default_rng(random_state)

# Draw one random initial state (S, I, R) inside the given box.
SIM0 = rng.uniform(low=(190, 0, 1), high=(199, 0.1, 8), size=(3,))
time = np.linspace(t_0, t_end, NT)
sol = solve_ivp(
    model,
    t_span=[time[0], time[-1]],
    y0=SIM0,
    t_eval=time,
    args=(mu0, mu1, beta, A, d, nu, b),
    method='LSODA',
    rtol=rtol,
    atol=atol,
)

# ---- figure: three panels side by side ----
fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(15, 5))

# Panel 0: the three compartments, rescaled so they share one axis.
ax[0].plot(sol.t, sol.y[0] - 0 * sol.y[0][0], label='1E0*susceptible')
ax[0].plot(sol.t, 1e3 * sol.y[1] - 0 * sol.y[1][0], label='1E3*infective')
ax[0].plot(sol.t, 1e1 * sol.y[2] - 0 * sol.y[2][0], label='1E1*removed')
ax[0].set(xlim=[0, 500], xlabel="time", ylabel=r"$S,I,R$")
ax[0].legend()

# Panel 1: recovery rate mu evaluated along the trajectory, next to the
# (rescaled) infective curve.
ax[1].plot(sol.t, mu(b, sol.y[1], mu0, mu1), label='recovery rate')
ax[1].plot(sol.t, 1e2 * sol.y[1], label='1E2*infective')
ax[1].set(xlim=[0, 500], xlabel="time", ylabel=r"$\mu,I$")
ax[1].legend()

# Panel 2: indicator function h over a small range of I, with a dotted
# zero line for reference.
I_h = np.linspace(-0., 0.05, 100)
ax[2].plot(I_h, h(I_h, mu0, mu1, beta, A, d, nu, b))
ax[2].plot(I_h, 0 * I_h, 'r:')
#ax[2].set_ylim([-0.1,0.05])
ax[2].set(title="Indicator function h(I)", xlabel="I", ylabel="h(I)")

fig.tight_layout()

In [None]:
def plot_orbit(initial_value, b):
    """Integrate the SIR model from ``initial_value`` and draw the orbit in 3D.

    The trajectory is drawn as a red line in (S, I, R) space, overlaid with
    points coloured by time, and the starting point is marked in green.

    Parameters
    ----------
    initial_value : sequence of three numbers
        Initial state handed to the solver as ``y0``; its entries are plotted
        on the S, I and R axes respectively.
    b : float
        Forwarded to ``model`` as the last extra argument and shown in the
        figure title.

    Notes
    -----
    Reads the notebook-level names ``t_0``, ``NT``, ``mu0``, ``mu1``,
    ``beta``, ``A``, ``d``, ``nu``, ``rtol`` and ``atol``. The integration
    end time (150000) is fixed inside the function.
    """
    fig = plt.figure(figsize=(5, 5))
    ax = fig.add_subplot(111, projection="3d")
    time = np.linspace(t_0, 150000, NT)

    sol = solve_ivp(
        model,
        t_span=[time[0], time[-1]],
        y0=initial_value,
        t_eval=time,
        args=(mu0, mu1, beta, A, d, nu, b),
        method='DOP853',
        rtol=rtol,
        atol=atol,
    )

    ax.plot(sol.y[0], sol.y[1], sol.y[2], 'r-')
    # Colour by the time stamps the solver actually returned: sol.t always has
    # the same length as sol.y, whereas `time` would be longer than the data
    # if the integration stopped before the end of the interval.
    ax.scatter(sol.y[0], sol.y[1], sol.y[2], s=1, c=sol.t, cmap='bwr')
    # Mark the starting point.
    ax.plot(initial_value[0], initial_value[1], initial_value[2], color='#03fc13', marker='o')

    ax.set_xlabel("S")
    ax.set_ylabel("I")
    ax.set_zlabel("R")

    ax.set_title(f"SIR Orbit (b={b})")
    fig.tight_layout()


# Hand-picked initial states, each a three-element list used as the solver's y0.
# NOTE: this rebinds SIM0, replacing the randomly drawn SIM0 from the simulation cell above.
SIM0 = [195.3, 0.052, 4.4] # what happens with this initial condition when b=0.022? -- the orbit progresses VERY slowly, so the integration end time must be very large.
plot_orbit(SIM0, 0.022)

# SIM1 and SIM2 are only defined here; they serve as starting points for the find_solutions sweeps in later cells.
SIM1 = [195.7, 0.03, 3.92] # what happens with this initial condition when b=0.022?
SIM2 = [193, 0.08, 6.21] # what happens with this initial condition when b=0.022?



In [None]:
def float_range(start: float, end: float, step: float):
    """Yield ``start, start + step, start + 2*step, ...`` while the value is below ``end``.

    Every value is computed as ``start + k * step`` instead of by repeated
    addition, so rounding errors do not accumulate from one value to the
    next. With repeated addition, ``float_range(0, 1, 0.1)`` produced an
    eleventh value ``0.9999999999999999``.

    Parameters
    ----------
    start : float
        First value yielded (if ``start < end``).
    end : float
        Exclusive upper bound.
    step : float
        Increment between consecutive values; must be positive whenever
        ``start < end``.

    Raises
    ------
    ValueError
        If ``start < end`` and ``step <= 0``; the sequence would never reach
        ``end``. Raised when the generator is first advanced.
    """
    if not start < end:
        # Nothing to yield; mirrors the original loop condition failing at once.
        return
    if step <= 0:
        raise ValueError(f"step must be positive when start < end, got {step!r}")

    k = 0
    current = start
    while current < end:
        yield current
        k += 1
        current = start + k * step

# Result container for a sweep over b: four parallel sequences holding, for every
# retained sample, the b value and the matching S, I and R state components.
SIR_Solution = namedtuple('SIR_Solution', ('b', 's', 'i', 'r'))

def find_solutions(initial_value, plot_percentage=0.02, b_start=0.1, b_end=0.3, n_points=1000) -> SIR_Solution:
    """Sweep the parameter ``b`` and collect the tail of each trajectory.

    For every ``b`` produced by ``float_range(b_start, b_end, step)`` with
    ``step = (b_end - b_start) / n_points``, the model is integrated from the
    same ``initial_value`` over ``[t_0, 10000000]`` sampled at ``NT`` points,
    and the last ``plot_percentage`` fraction of the returned samples is
    appended to the result.

    Parameters
    ----------
    initial_value : sequence of three numbers
        Initial state handed to the solver as ``y0`` for every ``b``.
    plot_percentage : float
        Fraction of each trajectory (counted from its end) to keep. Values
        outside ``[0, 1]`` are clamped so that the ``b`` list always has the
        same length as the ``s``/``i``/``r`` lists.
    b_start, b_end : float
        Range of ``b`` values to sweep; ``b_start`` is the first value used.
    n_points : int
        Number of steps the range is divided into.

    Returns
    -------
    SIR_Solution
        Parallel lists ``b``, ``s``, ``i``, ``r`` with one entry per retained
        sample.

    Notes
    -----
    Reads the notebook-level names ``t_0``, ``NT``, ``mu0``, ``mu1``,
    ``beta``, ``A``, ``d``, ``nu``, ``rtol`` and ``atol``.
    """
    time = np.linspace(t_0, 10000000, NT)
    display(f'Solving from b={b_start} to b={b_end}. Sampling {n_points} b values.')
    solution = SIR_Solution(b=[], s=[], i=[], r=[])

    for b in float_range(b_start, b_end, (b_end - b_start) / n_points):
        sol = solve_ivp(
            model,
            t_span=[time[0], time[-1]],
            y0=initial_value,
            t_eval=time,
            args=(mu0, mu1, beta, A, d, nu, b),
            method='DOP853',
            rtol=rtol,
            atol=atol,
        )
        n_samples = len(sol.y[0])
        # Clamp the tail size to [0, n_samples]. Without this, a
        # plot_percentage above 1 made the start index negative, so more b
        # entries were appended than state samples.
        n_keep = min(max(int(n_samples * plot_percentage), 0), n_samples)
        tail = slice(n_samples - n_keep, None)

        solution.b.extend([b] * n_keep)
        solution.s.extend(sol.y[0][tail])
        solution.i.extend(sol.y[1][tail])
        solution.r.extend(sol.y[2][tail])

    return solution


# Coarse sweep of b between 0.02 and 0.024 starting from SIM0, keeping the last 5% of each trajectory.
sim0_sol = find_solutions(SIM0, plot_percentage=0.05, b_start=0.02, b_end=0.024, n_points=10)

In [None]:
# Fine sweep of b between 0.022002 and 0.022005 starting from SIM1 (about 1000 separate integrations), keeping the last 5% of each trajectory.
sim1_sol = find_solutions(SIM1, plot_percentage=0.05, b_start=0.022002, b_end=0.022005, n_points=1000)

In [None]:
# Same fine sweep of b as for SIM1 (0.022002 to 0.022005, about 1000 integrations), but starting from SIM2.
sim2_sol = find_solutions(SIM2, plot_percentage=0.05, b_start=0.022002, b_end=0.022005, n_points=1000)

In [None]:
def divergence_plots(solution: SIR_Solution):
    """Draw five stacked panels summarising a sweep over ``b``.

    Panels 1-3 scatter S, I and R against ``b``. Panels 4 and 5 scatter the
    samples in 3D (S, I, R) space; the last one has its camera set with
    ``view_init(90, -90)``. Every point is coloured by its ``b`` value using
    the 'hsv' colormap.

    Parameters
    ----------
    solution : SIR_Solution
        Parallel sequences ``b``, ``s``, ``i``, ``r`` as returned by
        ``find_solutions``.
    """
    fig = plt.figure(figsize=(10, 40))

    # Create all five axes up front, top to bottom.
    s_graph, i_graph, r_graph = (fig.add_subplot(5, 1, row) for row in (1, 2, 3))
    sir_graph = fig.add_subplot(5, 1, 4, projection='3d')
    sir_graph_side = fig.add_subplot(5, 1, 5, projection='3d')

    # One 2D scatter per compartment, all against b.
    flat_panels = (
        (s_graph, solution.s, "Susceptible", "Susceptible Bifurcation Graph"),
        (i_graph, solution.i, "Infected", "Infected Bifurcation Graph"),
        (r_graph, solution.r, "Removed", "Removed Bifurcation Graph"),
    )
    for axis, values, y_label, title in flat_panels:
        axis.scatter(solution.b, values, marker=".", c=solution.b, cmap='hsv')
        axis.set_xlabel("b")
        axis.set_ylabel(y_label)
        axis.set_title(title)

    # The two 3D panels show the same data; only title and camera differ.
    space_panels = (
        (sir_graph, "SIR Bifurcation Graph"),
        (sir_graph_side, "SIR Bifurcation Graph Side"),
    )
    for axis, title in space_panels:
        axis.scatter(solution.s, solution.i, solution.r, marker=".", c=solution.b, cmap='hsv')
        axis.set_xlabel("S")
        axis.set_ylabel("I")
        axis.set_zlabel("R")
        axis.set_title(title)

    # Elevation 90, azimuth -90 for the second 3D panel.
    sir_graph_side.view_init(90, -90)

    fig.tight_layout()

# Plot the coarse sweep over b that was started from SIM0.
divergence_plots(sim0_sol)