In [None]:
"""epidemiology_sir.ipynb"""
# Cell 1

from __future__ import annotations

import typing

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import AutoMinorLocator, MultipleLocator
from scipy.integrate import solve_ivp  # type: ignore

if typing.TYPE_CHECKING:
    from typing import Any

    from matplotlib.axes import Axes
    from numpy.typing import NDArray

%matplotlib widget


def model(
    time: float, state_vector: tuple[float, float, float], beta: float, delta: float
) -> tuple[float, float, float]:
    """Return the time derivatives of the Kermack-McKendrick (SIR) system.

    The right-hand side evaluated here is::

        dS/dt = -beta * S * I
        dI/dt =  beta * S * I - delta * I
        dR/dt =  delta * I

    Args:
        time: Current time. It is not used (the equations do not depend on
            time explicitly) but is required by the ``fun(t, y, *args)``
            signature that ``solve_ivp`` calls.
        state_vector: The current ``(S, I, R)`` values; must unpack into
            exactly three numbers. When called by ``solve_ivp`` this is
            presumably a length-3 array rather than a tuple.
        beta: Infection rate coefficient.
        delta: Recovery rate.

    Returns:
        ``(dS/dt, dI/dt, dR/dt)`` evaluated at ``state_vector``.
    """
    susceptible, infected, _recovered = state_vector
    # Flow from S into I, and from I into R.
    new_infections: float = beta * susceptible * infected
    recoveries: float = delta * infected
    return -new_infections, new_infections - recoveries, recoveries


def plot(ax: Axes) -> None:
    """Integrate the Kermack-McKendrick SIR model and draw S, I and R on *ax*.

    The model parameters, initial conditions and time span are fixed inside
    this function. The three compartments are plotted against time, and a
    title, axis labels, minor tick locators and a legend are added.

    Args:
        ax: Matplotlib axes to draw on.

    Raises:
        RuntimeError: If ``solve_ivp`` reports that the integration failed.
    """
    # Kermack-McKendrick parameters (presumably per month, matching the
    # time axis below -- NOTE(review): confirm units)
    beta: float = 0.003  # Infection rate
    delta: float = 1.0  # Recovery rate

    # Set initial conditions
    s_initial: float = 1000.0  # Susceptible people
    i_initial: float = 1.0  # Infected people
    r_initial: float = 0.0  # Recovered people

    # Set model duration (months)
    time_initial: float = 0.0
    time_final: float = 10.0

    # Estimate model behavior. No t_eval is given, so sol.t holds the
    # solver's own steps; max_step=0.01 therefore guarantees at least ~1000
    # samples over the window, giving smooth curves.
    sol: Any = solve_ivp(
        model,
        (time_initial, time_final),
        [s_initial, i_initial, r_initial],
        max_step=0.01,
        args=(beta, delta),
    )
    if not sol.success:
        # Surface the failure instead of silently plotting a truncated run.
        raise RuntimeError(f"SIR integration failed: {sol.message}")

    time_steps: NDArray[np.float64] = sol.t
    s: NDArray[np.float64]
    i: NDArray[np.float64]
    r: NDArray[np.float64]
    # sol.y has one row per state variable. np.float_ was removed in
    # NumPy 2.0; np.float64 is the same dtype.
    s, i, r = np.asarray(sol.y, dtype=np.float64)

    ax.plot(time_steps, s, label="Susceptible", linewidth=2)
    ax.plot(time_steps, i, label="Infected", linewidth=2)
    ax.plot(time_steps, r, label="Recovered", linewidth=2)

    ax.set_title("Epidemiology (Kermack-McKendrick)")
    ax.set_xlabel("Time (months)")
    ax.set_ylabel("Population")

    # One minor x tick per time unit; automatic minor ticks on y.
    ax.xaxis.set_minor_locator(MultipleLocator(1))
    ax.yaxis.set_minor_locator(AutoMinorLocator())
    ax.legend()


def main() -> None:
    """Open a fresh figure and draw the SIR simulation into it."""
    # Close any figures that are still open (e.g. from re-running this cell).
    plt.close("all")
    figure = plt.figure(" ")
    # Equivalent to plt.axes(): a single full-figure subplot on the new figure.
    plot(figure.add_subplot())
    plt.show()


if __name__ == "__main__":
    # Run when executed directly (including in a notebook kernel, where
    # __name__ is "__main__"), but not as a side effect of importing.
    main()