In [None]:
"""predator_prey.ipynb"""
# Cell 1

from __future__ import annotations

import typing

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import MultipleLocator
from scipy.integrate import solve_ivp  # type: ignore

if typing.TYPE_CHECKING:
    from typing import Any

    from matplotlib.axes import Axes
    from numpy.typing import NDArray

%matplotlib widget

# fmt: off
def model(time: float, state_vector: tuple[float, float],
    alpha: float, beta: float, delta: float, gamma: float,
) -> tuple[float, float]:
    """Return the Lotka-Volterra derivatives ``(d_prey/dt, d_pred/dt)``.

    Used as the right-hand side for ``scipy.integrate.solve_ivp``. The
    ``time`` argument is required by the solver's ``fun(t, y, *args)``
    calling convention but is not used, since the equations do not depend
    on time explicitly.

    Args:
        time: Current integration time (unused).
        state_vector: Current ``(prey, predator)`` values.
        alpha: Prey growth coefficient (the ``alpha * prey`` term).
        beta: Prey loss coefficient per prey-predator product.
        delta: Predator gain coefficient per prey-predator product.
        gamma: Predator loss coefficient (the ``gamma * pred`` term).

    Returns:
        The prey and predator rates of change, in that order.
    """
    x, y = state_vector
    # Keep the left-to-right multiplication order (beta * x * y) so the
    # floating-point result matches the textbook form exactly.
    prey_rate = alpha * x - beta * x * y
    predator_rate = delta * x * y - gamma * y
    return (prey_rate, predator_rate)
# fmt: on


def plot(ax: Axes) -> None:
    """Integrate the Lotka-Volterra system and draw both populations on *ax*.

    The right-hand side ``model`` is integrated with
    ``scipy.integrate.solve_ivp`` over the dimensionless interval [0, 20].
    State values are multiplied by 100 so the curves read as percentages
    of the initial-condition scale.

    Args:
        ax: Axes to draw on; the two lines, title, axis labels, legend and
            tick locators are added to it in place.

    Raises:
        RuntimeError: If ``solve_ivp`` reports that the integration failed,
            rather than silently plotting a truncated trajectory.
    """
    # Lotka-Volterra parameters (see the terms in ``model``)
    alpha: float = 2.0  # Prey birth rate
    beta: float = 1.1  # Prey loss per prey-predator encounter (predation)
    delta: float = 1.0  # Predator gain per prey-predator encounter
    gamma: float = 0.9  # Predator death rate

    # Initial conditions as fractions; scaled to percent (x100) when plotted
    prey_initial: float = 1.0
    predator_initial: float = 0.5

    # Model duration (dimensionless)
    time_initial: float = 0.0
    time_final: float = 20.0

    # Integrate the ODE. max_step=0.01 caps the solver step size, so the
    # returned trajectory is densely sampled for plotting.
    sol: Any = solve_ivp(
        model,
        (time_initial, time_final),
        [prey_initial, predator_initial],
        max_step=0.01,
        args=(alpha, beta, delta, gamma),  # documented as a tuple
    )
    if not sol.success:
        raise RuntimeError(f"ODE integration failed: {sol.message}")

    # np.float_ was removed in NumPy 2.0; np.float64 is the portable spelling.
    time_steps: NDArray[np.float64] = np.asarray(sol.t, dtype=np.float64)
    prey: NDArray[np.float64]
    pred: NDArray[np.float64]
    # sol.y has one row per state variable: row 0 = prey, row 1 = predator.
    prey, pred = np.asarray(sol.y, dtype=np.float64) * 100.0

    ax.plot(time_steps, pred, label="predator", color="red", linewidth=2)
    ax.plot(time_steps, prey, label="prey", color="blue", linewidth=2)

    ax.set_title("Predator-Prey Model (Lotka-Volterra)")
    ax.set_xlabel("Time")
    ax.set_ylabel("Population Size (%)")

    ax.legend(loc="upper right")

    ax.xaxis.set_major_locator(MultipleLocator(5))
    ax.xaxis.set_minor_locator(MultipleLocator(1))
    ax.yaxis.set_major_locator(MultipleLocator(50))
    ax.yaxis.set_minor_locator(MultipleLocator(10))


def main() -> None:
    """Draw the predator-prey chart on a fresh figure and display it."""
    # Discard any figures left over from earlier runs before creating a new one.
    plt.close("all")
    # A string ``num`` becomes the figure label/window title; " " keeps it blank.
    figure = plt.figure(" ")
    # plt.figure makes this figure current, so this matches plt.axes().
    plot(figure.add_subplot())
    plt.show()


# Build and display the chart whenever this cell is executed.
main()