# CartPole System - Lyapunov Function Visualization Demo

This notebook demonstrates Lyapunov function visualization capabilities for the CartPole system:
- Training a simple neural Lyapunov function
- 2D contour plots of V(x) and ΔV(x)
- 3D surface plots of the Lyapunov function
- ROA analysis and metrics
- Trajectory overlays showing V(x) evolution

The CartPole is a classic underactuated system that models an inverted pendulum mounted on a moving cart:
- State: [cart_position, pole_angle, cart_velocity, pole_angular_velocity]
- Control: [horizontal_force]
- Parameters: [cart_mass, pole_mass, pole_length, gravitational_acceleration, friction_coefficient]

In [2]:
# imports

import sys

# Add repository root to Python path
sys.path.append("../")

import torch
import torch.nn as nn
import numpy as np

from neural_lyapunov_training.symbolic_systems import CartPole
from neural_lyapunov_training.symbolic_dynamics import (
    GenericDiscreteTimeSystem,
    IntegrationMethod,
    LinearController,
)
from neural_lyapunov_training.lyapunov_roa_visualization import (
    plot_lyapunov_2d,
    plot_lyapunov_3d_surface,
)
from neural_lyapunov_training.roa_metrics import (
    compute_lyapunov_difference_metrics_qmc_sobol,
    print_lyapunov_difference_metrics,
)

In [3]:
# simple Lyapunov neural network and associated function for training
class SimpleLyapunovNetwork(nn.Module):
    """Demonstration Lyapunov candidate: squared state norm plus a learned bump.

    The value is ``V(x) = ||x||^2 + 0.1 * relu(net(x))`` where ``net`` is a
    two-hidden-layer tanh MLP with a single output.  Because the learned term
    is passed through a ReLU it is never negative, so ``V(x) >= ||x||^2``.

    Args:
        state_dim: Size of the last dimension of the input state tensor.
        hidden_dim: Width of both hidden layers.
    """

    def __init__(self, state_dim, hidden_dim=32):
        super().__init__()
        layers = [
            nn.Linear(state_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1),
        ]
        self.net = nn.Sequential(*layers)

        # Small-gain weights and zero biases keep the learned term tiny at
        # the start, so the initial V is close to the pure quadratic ||x||^2.
        # With zero biases net(0) is exactly 0 (tanh(0) = 0), hence V(0) = 0
        # before any training step changes the biases.
        with torch.no_grad():
            for module in self.net:
                if not isinstance(module, nn.Linear):
                    continue
                nn.init.xavier_normal_(module.weight, gain=0.1)
                nn.init.zeros_(module.bias)

    def forward(self, x):
        """Return V(x) with the same leading shape as ``x`` and a trailing dim of 1."""
        learned_term = torch.relu(self.net(x))
        quadratic_term = x.pow(2).sum(dim=-1, keepdim=True)
        return quadratic_term + 0.1 * learned_term


def train_simple_lyapunov(
    system,
    controller,
    num_samples=5000,
    num_epochs=100,
    lr=1e-3,
    sample_scale=0.5,
    decrease_margin=0.01,
    magnitude_weight=0.01,
    log_every=20,
    seed=None,
):
    """Train a simple Lyapunov function by sampling and enforcing conditions.

    Each epoch draws fresh Gaussian states, pushes them one step through the
    closed loop ``system(x, controller(x))`` and minimises the sum of:

    * ``V(0)^2``                                  (pin the origin to zero),
    * ``mean(relu(V(x_next) - V(x) + margin))``   (enforce decrease),
    * ``magnitude_weight * mean(V(x))``           (keep V from growing).

    Args:
        system: Callable ``system(x, u) -> x_next`` exposing the state
            dimension as ``system.nx``.
        controller: Callable mapping a batch of states to a batch of controls.
        num_samples: Number of states sampled per epoch (must be >= 1).
        num_epochs: Number of optimisation steps.
        lr: Adam learning rate.
        sample_scale: Standard deviation of the sampled states.
        decrease_margin: Margin added to ``delta_V`` inside the hinge loss.
        magnitude_weight: Weight of the mean-V regulariser.
        log_every: Print progress on the first epoch and every ``log_every``
            epochs; ``None`` or a value <= 0 disables per-epoch logging.
        seed: If given, ``torch.manual_seed(seed)`` is called first so that
            network initialisation and state sampling are repeatable.  Note
            that this reseeds torch's global RNG.

    Returns:
        The trained ``SimpleLyapunovNetwork``.

    Raises:
        ValueError: If ``num_samples`` is smaller than 1.
    """
    if num_samples < 1:
        raise ValueError("num_samples must be at least 1")
    if seed is not None:
        torch.manual_seed(seed)

    print("\nTraining simple Lyapunov function...")

    lyapunov_nn = SimpleLyapunovNetwork(system.nx, hidden_dim=32)
    optimizer = torch.optim.Adam(lyapunov_nn.parameters(), lr=lr)
    origin = torch.zeros(1, system.nx)  # loop-invariant, built once
    logging_enabled = log_every is not None and log_every > 0

    for epoch in range(num_epochs):
        x_samples = torch.randn(num_samples, system.nx) * sample_scale
        V_current = lyapunov_nn(x_samples)
        u_samples = controller(x_samples)
        x_next = system(x_samples, u_samples)
        V_next = lyapunov_nn(x_next)

        V_origin = lyapunov_nn(origin)
        # Reduce to a true scalar; previously the (1, 1) shape of V_origin
        # leaked into the total loss.
        loss_origin = (V_origin**2).sum()
        delta_V = V_next - V_current
        loss_decrease = torch.relu(delta_V + decrease_margin).mean()
        loss_magnitude = V_current.mean() * magnitude_weight
        loss = loss_origin + loss_decrease + loss_magnitude

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if logging_enabled and (epoch == 0 or (epoch + 1) % log_every == 0):
            with torch.no_grad():
                # Counted on the pre-step delta_V, i.e. the same batch the
                # reported loss was computed on.
                violations = (delta_V > 0).sum().item()
                print(
                    f"  Epoch {epoch+1:3d}: Loss={loss.item():.4f}, "
                    f"V(0)={V_origin.item():.6f}, "
                    f"Violations={violations}/{num_samples} "
                    f"({100*violations/num_samples:.1f}%)"
                )

    print("  ✓ Training complete")
    return lyapunov_nn

In [4]:
# Continuous-time CartPole model; print_equations shows its symbolic dynamics.
# NOTE(review): parameter values look like SI units (kg, m, m/s^2) — confirm.
cartpole = CartPole(
    m_cart=1.0,
    m_pole=0.1,
    length=0.5,
    gravity=9.81,
    friction=0.1,
)
cartpole.print_equations(simplify=True)

CartPole
State Variables: [x, theta, x_dot, theta_dot]
Control Variables: [F]
System Order: 2
Dimensions: nx=4, nu=1, ny=2

Dynamics: dx/dt = f(x, u)
  dx/dt = (-F - 0.05*theta_dot**2*sin(theta) + 0.1*x_dot + 0.4905*sin(2*theta))/(0.1*cos(theta)**2 - 1.1)
  dtheta/dt = (-2.0*F*cos(theta) - 0.05*theta_dot**2*sin(2*theta) + 0.2*x_dot*cos(theta) + 21.582*sin(theta))/(0.1*cos(theta)**2 - 1.1)

Output: y = h(x)
  y[0] = x
  y[1] = theta


In [5]:
# Wrap the continuous CartPole model in a discrete-time system that
# integrates with RK4 at a fixed step dt, then print its description.
dt = 0.02
discrete_system = GenericDiscreteTimeSystem(
    cartpole,
    dt,
    integration_method=IntegrationMethod.RK4,
)
discrete_system.print_info()

Discrete-Time System: CartPole

Discretization:
  Time step (dt):        0.02
  Integration method:    RK4
  Position integration:  RK4

Dimensions:
  State dimension (nx):    4
  Control dimension (nu):  1
  Output dimension (ny):   2
  System order:            2
  Generalized coords (nq): 2

Equilibrium:
  x_eq = [0. 0. 0. 0.]
  u_eq = [0.]

----------------------------------------------------------------------
Continuous-Time Dynamics (before discretization):
----------------------------------------------------------------------
CartPole
State Variables: [x, theta, x_dot, theta_dot]
Control Variables: [F]
System Order: 2
Dimensions: nx=4, nu=1, ny=2

Dynamics: dx/dt = f(x, u)
  dx/dt = (-F - 0.05*theta_dot**2*sin(theta) + 0.1*x_dot + 0.4905*sin(2*theta))/(0.1*cos(theta)**2 - 1.1)
  dtheta/dt = (-2.0*F*cos(theta) - 0.05*theta_dot**2*sin(2*theta) + 0.2*x_dot*cos(theta) + 21.582*sin(theta))/(0.1*cos(theta)**2 - 1.1)

Output: y = h(x)
  y[0] = x
  y[1] = theta

-------------------------

In [6]:
# summary() returns a short text description of the system, in contrast to
# the multi-section report that print_info() prints itself.
system_summary = discrete_system.summary()
print(system_summary)

CartPole (nx=4, nu=1, ny=2, order=2, dt=0.0200, RK4)
  Continuous stable: False, Discrete stable: False


In [7]:
# Two separate questions about the pre-specified operating point:
#   1. is (x_equilibrium, u_equilibrium) actually an equilibrium of the model?
#   2. is that equilibrium stable without any controller (open loop)?
is_eq, max_deriv = cartpole.check_equilibrium(
    cartpole.x_equilibrium,
    cartpole.u_equilibrium,
)

print(f"Equilibrium valid: {is_eq}")
print(f"Open-loop stable: {cartpole.is_stable_equilibrium()}")

Equilibrium valid: True
Open-loop stable: False


In [8]:
# Discrete-time LQR gain for the discretized CartPole.
# Q weights the four state components (the pole angle, second entry, is
# penalised most heavily); R weights the single control input.
state_weights = [10.0, 100.0, 1.0, 1.0]
Q = np.diag(state_weights)
R = np.array([[0.1]])

K, S = discrete_system.dlqr_control(Q, R)
print(f"LQR gain K: {K}")

LQR gain K: [[ -8.65491224 -29.17743136  -8.82003307  -2.54078433]]


In [9]:
# Linear state-feedback controller built from the LQR gain and the
# system's equilibrium point.
controller = LinearController(K, cartpole.x_equilibrium, cartpole.u_equilibrium)

# Cast the controller tensors to float32; the author suspected a
# float/double mismatch further downstream otherwise.
for tensor_attr in ("K", "x_eq", "u_eq"):
    setattr(controller, tensor_attr, getattr(controller, tensor_attr).float())

In [10]:
# training the Lyapunov neural network
# Seed torch's global RNG first: both the network initialisation and the
# per-epoch state samples are drawn from it, so without a seed the numbers
# below change on every Restart & Run All.
torch.manual_seed(0)
lyapunov_nn = train_simple_lyapunov(
    discrete_system, controller, num_samples=10000, num_epochs=1000
)

# Sanity check at the origin; no gradients are needed for a pure evaluation.
with torch.no_grad():
    V_eq = lyapunov_nn(torch.zeros(1, discrete_system.nx))
print(f"\nV(equilibrium) = {V_eq.item():.6f} (should be ≈ 0)")


Training simple Lyapunov function...
  Epoch   1: Loss=0.4683, V(0)=0.000000, Violations=6596/10000 (66.0%)
  Epoch  20: Loss=0.4577, V(0)=0.000000, Violations=6614/10000 (66.1%)
  Epoch  40: Loss=0.4651, V(0)=0.000000, Violations=6602/10000 (66.0%)
  Epoch  60: Loss=0.4495, V(0)=0.000000, Violations=6580/10000 (65.8%)
  Epoch  80: Loss=0.4556, V(0)=0.000000, Violations=6568/10000 (65.7%)
  Epoch 100: Loss=0.4524, V(0)=0.000000, Violations=6602/10000 (66.0%)
  Epoch 120: Loss=0.4468, V(0)=0.000000, Violations=6650/10000 (66.5%)
  Epoch 140: Loss=0.4607, V(0)=0.000000, Violations=6607/10000 (66.1%)
  Epoch 160: Loss=0.4620, V(0)=0.000000, Violations=6615/10000 (66.2%)
  Epoch 180: Loss=0.4579, V(0)=0.000000, Violations=6609/10000 (66.1%)
  Epoch 200: Loss=0.4470, V(0)=0.000000, Violations=6623/10000 (66.2%)
  Epoch 220: Loss=0.4644, V(0)=0.000000, Violations=6694/10000 (66.9%)
  Epoch 240: Loss=0.4641, V(0)=0.000000, Violations=6587/10000 (65.9%)
  Epoch 260: Loss=0.4623, V(0)=0.000000

In [11]:
# Analyze in the (θ, θ̇) subspace - state indices (1, 3)
state_indices_analysis = (1, 3)
state_limits_analysis = (
    (-np.pi / 3, np.pi / 3),  # θ bounds
    (-2.0, 2.0),  # θ̇ bounds
)

# Estimate ρ as 0.9 × the smallest V found on the boundary of the analysis
# box.  All four edges are sampled: for each axis held at its lower/upper
# limit the other axis is swept across its range.  (Sampling only the two
# θ-edges, as before, can miss a smaller V on the θ̇-edges.)  The remaining
# state components stay at zero.
num_edge_points = 20
boundary_samples = []
for fixed_axis, sweep_axis in ((0, 1), (1, 0)):
    sweep_values = np.linspace(
        state_limits_analysis[sweep_axis][0],
        state_limits_analysis[sweep_axis][1],
        num_edge_points,
    )
    for fixed_value in state_limits_analysis[fixed_axis]:
        for sweep_value in sweep_values:
            state = torch.zeros(discrete_system.nx, dtype=torch.float32)
            state[state_indices_analysis[fixed_axis]] = fixed_value
            state[state_indices_analysis[sweep_axis]] = sweep_value
            boundary_samples.append(state)

with torch.no_grad():
    V_boundary = lyapunov_nn(torch.stack(boundary_samples))
    rho = V_boundary.min().item() * 0.9

print(f"Estimated ρ = {rho:.4f}")

metrics = compute_lyapunov_difference_metrics_qmc_sobol(
    lyapunov_nn,
    controller,
    discrete_system,
    state_limits_analysis,
    rho,
    num_samples=50000,
    state_indices=state_indices_analysis,
    compute_discrepancy_metric=True,
)

print_lyapunov_difference_metrics(metrics, title="CartPole ROA Analysis")

Estimated ρ = 0.9969





                             CartPole ROA Analysis                              

                                 Configuration                                  
--------------------------------------------------------------------------------
  Method: qmc_sobol
  Total samples: 65,536
  ROA threshold (ρ): 0.996933
  Stability threshold: ΔV ≤ 0.000000
  Domain volume: 8.377580
  Discrepancy: 3.160758e-10 (lower = better uniformity)

                             Region Classifications                             
--------------------------------------------------------------------------------
  ROA (V(x) ≤ ρ):
    Volume: 3.131877
    Coverage: 37.38%
    Samples: 24,500

  Decreasing region (ΔV ≤ 0.0):
    Volume: 3.582356
    Coverage: 42.76%
    Samples: 28,024

  Verified ROA (V ≤ ρ AND ΔV ≤ 0.0):
    Volume: 0.921028
    Coverage: 10.99%
    Samples: 7,205

                      Lyapunov Difference Statistics (ΔV)                       
-------------------------------------------

In [12]:
# simulating trajectories with linear control
initial_conditions = [
    torch.tensor([0.0, np.deg2rad(30), 0.0, 0.0], dtype=torch.float32),
    torch.tensor([0.0, np.deg2rad(-30), 0.0, 0.0], dtype=torch.float32),
    torch.tensor([0.0, np.deg2rad(20), 0.0, 1.0], dtype=torch.float32),
    torch.tensor([0.0, np.deg2rad(-20), 0.0, -1.0], dtype=torch.float32),
]

trajectory_names = ["30° right", "30° left", "20° right, θ̇ = +1", "20° left, θ̇ = -1"]

# Simulate each initial condition exactly once and split the result into
# states and controls afterwards.  Running simulate() twice per initial
# condition doubled the work and would pair states with controls from a
# different rollout if the simulation were ever non-deterministic.
rollouts = [
    discrete_system.simulate(ic, controller=controller, horizon=200, return_controls=True)
    for ic in initial_conditions
]
trajectories = [rollout[0] for rollout in rollouts]
controls = [rollout[1] for rollout in rollouts]

In [13]:
# 2D Lyapunov plot over cart position (state index 0) and pole angle (index 1)
state_indices_x_theta = (0, 1)
state_limits_x_theta = (
    (-1.0, 1.0),  # cart position range
    (-np.pi / 3, np.pi / 3),  # pole angle range
)

plot_lyapunov_2d(
    lyapunov_nn,
    controller,
    discrete_system,
    state_limits_x_theta,
    # what to draw
    state_indices=state_indices_x_theta,
    state_names=("Cart Position x [m]", "Pole Angle θ [rad]"),
    rho=rho,
    trajectories=trajectories,
    # how to draw it
    title="CartPole: Lyapunov Function (x-θ Plane)",
    grid_resolution=120,
    colorscale="Plasma",
    trajectory_colorscale="D3",
    show=False,
)

Derivative computation function was unable to infer controller output dimension.
Current assumption is scalar control (u_dim = 1)


In [26]:
# 2D Lyapunov plot over cart position (state index 0) and cart velocity (index 2)
state_indices_x_xdot = (0, 2)
state_limits_x_xdot = (
    (-1.0, 1.0),  # cart position range
    (-1.0, 1.0),  # cart velocity range
)

plot_lyapunov_2d(
    lyapunov_nn,
    controller,
    discrete_system,
    state_limits_x_xdot,
    # what to draw
    state_indices=state_indices_x_xdot,
    state_names=("Cart Position x [m]", "Cart Velocity ẋ [m/s]"),
    rho=rho,
    trajectories=trajectories,
    # how to draw it
    title="CartPole: Lyapunov Function (x-ẋ Plane)",
    grid_resolution=120,
    colorscale="Viridis",
    trajectory_colorscale="D3",
    show=False,
)

Derivative computation function was unable to infer controller output dimension.
Current assumption is scalar control (u_dim = 1)


In [27]:
# 2D Lyapunov plot over pole angle (state index 1) and angular velocity (index 3)
# NOTE(review): the angular-velocity range here is ±π/3, whereas the ROA
# analysis box (state_limits_analysis) uses ±2.0 for the same component —
# confirm the narrower view is intentional.
state_indices_theta_thetadot = (1, 3)
state_limits_theta_thetadot = (
    (-np.pi / 3, np.pi / 3),  # pole angle range
    (-np.pi / 3, np.pi / 3),  # angular velocity range
)

plot_lyapunov_2d(
    lyapunov_nn,
    controller,
    discrete_system,
    state_limits_theta_thetadot,
    # what to draw
    state_indices=state_indices_theta_thetadot,
    state_names=("Pole Angle θ [rad]", "Angular Velocity θ̇  [rad/s]"),
    rho=rho,
    trajectories=trajectories,
    # how to draw it
    title="CartPole: Lyapunov Function (θ-θ̇  Plane)",
    grid_resolution=120,
    colorscale="Plasma",
    trajectory_colorscale="D3",
    show=False,
)

Derivative computation function was unable to infer controller output dimension.
Current assumption is scalar control (u_dim = 1)


In [14]:
# 3D surface of the Lyapunov function over (cart position, pole angle);
# show_derivative is left at its default here.
plot_lyapunov_3d_surface(
    # models
    lyapunov_nn=lyapunov_nn,
    controller_nn=controller,
    dynamics_system=discrete_system,
    # plotted region
    state_indices=state_indices_x_theta,
    state_limits=state_limits_x_theta,
    state_names=("Cart Position x [m]", "Pole Angle θ [rad]"),
    rho=rho,
    trajectories=trajectories,
    # rendering
    title="CartPole: Lyapunov Function (x-θ Plane)",
    grid_resolution=120,
    colorscale="Plasma",
    trajectory_colorscale="D3",
    show=False,
)

In [15]:
# 3D surface over (cart position, pole angle) with show_derivative enabled
# and a semi-transparent surface.
plot_lyapunov_3d_surface(
    # models
    lyapunov_nn=lyapunov_nn,
    controller_nn=controller,
    dynamics_system=discrete_system,
    # plotted region
    state_indices=state_indices_x_theta,
    state_limits=state_limits_x_theta,
    state_names=("Cart Position x [m]", "Pole Angle θ [rad]"),
    rho=rho,
    trajectories=trajectories,
    # rendering
    title="CartPole: Lyapunov Function (x-θ Plane)",
    grid_resolution=120,
    colorscale="Plasma",
    trajectory_colorscale="D3",
    show_derivative=True,
    surface_opacity=0.75,
    show=False,
)

Derivative computation function was unable to infer controller output dimension.
Current assumption is scalar control (u_dim = 1)
Derivative computation function was unable to infer controller output dimension.
Current assumption is scalar control (u_dim = 1)
Derivative computation function was unable to infer controller output dimension.
Current assumption is scalar control (u_dim = 1)
Derivative computation function was unable to infer controller output dimension.
Current assumption is scalar control (u_dim = 1)
Derivative computation function was unable to infer controller output dimension.
Current assumption is scalar control (u_dim = 1)


In [28]:
# 3D surface over (pole angle, angular velocity) with show_derivative enabled
# and a semi-transparent surface.
plot_lyapunov_3d_surface(
    # models
    lyapunov_nn=lyapunov_nn,
    controller_nn=controller,
    dynamics_system=discrete_system,
    # plotted region
    state_indices=state_indices_theta_thetadot,
    state_limits=state_limits_theta_thetadot,
    state_names=("Pole Angle θ [rad]", "Angular Velocity θ̇  [rad/s]"),
    rho=rho,
    trajectories=trajectories,
    # rendering
    title="CartPole: Lyapunov Function (θ-θ̇  Plane)",
    grid_resolution=120,
    colorscale="Plasma",
    trajectory_colorscale="D3",
    show_derivative=True,
    surface_opacity=0.75,
    show=False,
)

Derivative computation function was unable to infer controller output dimension.
Current assumption is scalar control (u_dim = 1)
Derivative computation function was unable to infer controller output dimension.
Current assumption is scalar control (u_dim = 1)
Derivative computation function was unable to infer controller output dimension.
Current assumption is scalar control (u_dim = 1)
Derivative computation function was unable to infer controller output dimension.
Current assumption is scalar control (u_dim = 1)
Derivative computation function was unable to infer controller output dimension.
Current assumption is scalar control (u_dim = 1)


In [16]:
# Per-trajectory report: evaluate V along each simulated trajectory and
# summarise how it changes from one step to the next.
for number, (traj, name) in enumerate(zip(trajectories, trajectory_names), start=1):
    with torch.no_grad():
        V_traj = lyapunov_nn(traj).squeeze()
    # step-to-step change of V along the trajectory
    delta_V_traj = V_traj[1:] - V_traj[:-1]

    v_initial = V_traj[0].item()
    v_final = V_traj[-1].item()
    inside_roa = (V_traj <= rho).all()
    at_equilibrium = torch.allclose(traj[-1], discrete_system.x_equilibrium)

    print(f"\n  Trajectory {number}: {name}")
    print(f"    Initial V(x₀)           = {v_initial:.4f}")
    print(f"    Final V(x_f)            = {v_final:.4f}")
    print(f"    Decrease                = {v_initial - v_final:.4f}")
    print(f"    Min ΔV                  = {delta_V_traj.min().item():.6f}")
    print(f"    Max ΔV                  = {delta_V_traj.max().item():.6f}")
    print(f"    Mean ΔV                 = {delta_V_traj.mean().item():.6f}")
    print(f"    Stays in ROA (below ρ)  = {inside_roa}")
    print(f"    End at equilibrium      = {at_equilibrium}")


  Trajectory 1: 30° right
    Initial V(x₀)           = 0.2742
    Final V(x_f)            = 0.0000
    Decrease                = 0.2741
    Min ΔV                  = -0.663968
    Max ΔV                  = 1.393335
    Mean ΔV                 = -0.001371
    Stays in ROA (below ρ)  = False
    End at equilibrium      = False

  Trajectory 2: 30° left
    Initial V(x₀)           = 0.2742
    Final V(x_f)            = 0.0000
    Decrease                = 0.2741
    Min ΔV                  = -0.663968
    Max ΔV                  = 1.393335
    Mean ΔV                 = -0.001371
    Stays in ROA (below ρ)  = False
    End at equilibrium      = False

  Trajectory 3: 20° right, θ̇ = +1
    Initial V(x₀)           = 1.1218
    Final V(x_f)            = 0.0000
    Decrease                = 1.1218
    Min ΔV                  = -0.777924
    Max ΔV                  = 0.472926
    Mean ΔV                 = -0.005609
    Stays in ROA (below ρ)  = False
    End at equilibrium      = False

  Tr

In [17]:
# State and control time histories for all simulated trajectories.
# The per-trajectory tensors are stacked along a new leading dimension.
all_trajs = torch.stack(trajectories)
all_controls = torch.stack(controls)

discrete_system.plot_trajectory(
    all_trajs,
    control_sequence=all_controls,
    trajectory_names=trajectory_names,
    state_names=[
        "Cart Position x [m]",
        "Pole Angle θ [rad]",
        "Cart Velocity ẋ [m/s]",
        "Angular Velocity θ̇  [rad/s]",
    ],
    control_names=["Horizontal Force [N]"],
    title="CartPole: State Evolution",
    colorway="Vivid",
    show=False,
)

Plotting 4 trajectories...


In [18]:
# All trajectories in 3D, using state components 0, 1 and 3 as the axes
discrete_system.plot_trajectory_3d(
    all_trajs,
    state_indices=(0, 1, 3),
    state_names=(
        "Cart Position x [m]",
        "Pole Angle θ [rad]",
        "Angular Velocity θ̇ [rad/s]",
    ),
    trajectory_names=trajectory_names,
    title="CartPole: 3D Trajectories",
    colorway="Vivid",
    show=False,
)

In [19]:
# single 3D trajectories are colored by time
# Only one trajectory is plotted, so pass just its own name; handing over
# the full list of four names risks labelling it with the wrong entry.
single_traj_index = 1
discrete_system.plot_trajectory_3d(
    all_trajs[single_traj_index].squeeze(),
    state_indices=(0, 1, 3),
    state_names=(
        "Cart Position x [m]",
        "Pole Angle θ [rad]",
        "Angular Velocity θ̇ [rad/s]",
    ),
    trajectory_names=[trajectory_names[single_traj_index]],
    title="CartPole: 3D Trajectory",
    colorway="Vivid",
    show=False,
)

In [20]:
# Phase portrait of the pole: angle (state index 1) against angular
# velocity (state index 3) for every simulated trajectory.
discrete_system.plot_phase_portrait_2d(
    all_trajs,
    state_indices=(1, 3),
    state_names=(
        "Pole Angle θ [rad]",
        "Angular Velocity θ̇  [rad/s]",
    ),
    trajectory_names=trajectory_names,
    title="CartPole: Phase Portrait (θ-θ̇)",
    colorway="Vivid",
    show=False,
)

In [21]:
# Phase portrait of the cart: position (state index 0) against velocity
# (state index 2) for every simulated trajectory.
discrete_system.plot_phase_portrait_2d(
    all_trajs,
    state_indices=(0, 2),
    state_names=(
        "Cart Position x [m]",
        "Cart Velocity ẋ [m/s]",
    ),
    trajectory_names=trajectory_names,
    title="CartPole: Phase Portrait (x-ẋ)",
    colorway="Vivid",
    show=False,
)

In [22]:
# 3D phase portrait of all trajectories over state components 0, 1 and 3,
# with time markers requested at an interval of 15
# (effectively the same view as the 3D trajectory plot)
discrete_system.plot_phase_portrait_3d(
    all_trajs,
    state_indices=(0, 1, 3),
    state_names=(
        "Cart Position x [m]",
        "Pole Angle θ [rad]",
        "Angular Velocity θ̇ [rad/s]",
    ),
    trajectory_names=trajectory_names,
    title="CartPole: 3D Phase Portrait",
    colorway="Vivid",
    show_time_markers=True,
    marker_interval=15,
    show=False,
)