# Experimenting with SINDy-PI

Here we begin trying to replicate the experiments in [a SINDy-PI
notebook][nb], but with the cartpole problem.

[nb]: https://github.com/dynamicslab/pysindy/blob/master/examples/9_sindypi_with_sympy.ipynb

In [None]:
%load_ext autoreload
%autoreload 2


In [None]:
#| include: false
import os
import re
from concurrent import futures as f
from copy import deepcopy
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pysindy as ps
import sympy as sp
from scipy.integrate import solve_ivp

from cartpole import Cartpole, Problem, ThinnedModel, cache_model, load_model


In [None]:
#| include: false

# Set up logging/debugging: >= WARNING except for our code (INFO).
import logging

# Message-only format; the root threshold is WARNING, while the
# "cartpole" logger is lowered to INFO.
logging.basicConfig(format="%(message)s", level=logging.WARNING)
log = logging.getLogger("cartpole")
log.setLevel(logging.INFO)


In [None]:
#| include: false

# Set up plotting styles
# The eriplots style sheet is optional: when the package is missing we
# log a note and continue without applying it.
try:
    from eriplots import eri_style  # type: ignore
except ImportError:
    log.info("Note: eriplots styles aren't available")
else:
    # Applied outside the try body so that an ImportError raised while
    # applying the style is not misreported as a missing package.
    plt.style.use(eri_style)


## Set up the data sets

Following the SINDy-PI paper, we have three different sets of
samples with different forcing functions: train, val, and test.


In [None]:
# Training data
# Forcing function: f(t) = -0.2 + 0.5 sin(6 t).
train_pars = Cartpole(f=lambda t: -0.2 + 0.5 * np.sin(6 * t))
# NOTE(review): the positional arguments are presumably the initial state
# and the (start, end) time span, and `step` the sample spacing -- confirm
# against cartpole.Problem.
train = Problem(train_pars, [0.3, 0, 1, 0], (0, 16), step=0.01)

train.plot()
plt.suptitle("Training inputs")
plt.tight_layout()

In [None]:
# Validation data
# Forcing function: f(t) = -1 + sin(t) + 3 sin(2 t).
val_pars = Cartpole(f=lambda t: -1 + 1 * np.sin(t) + 3 * np.sin(2 * t))
# NOTE(review): the positional arguments are presumably the initial state
# and the (start, end) time span; no `step` is passed here, so whatever
# default cartpole.Problem defines applies -- confirm.
val = Problem(val_pars, [0.1, 0, 0.1, 0], (0, 2))

val.plot()
plt.suptitle("Validation inputs")
plt.tight_layout()

In [None]:
# Test data
# Forcing function: f(t) = -0.5 + 0.2 sin(t) + 0.3 sin(2 t).
test_pars = Cartpole(f=lambda t: -0.5 + 0.2 * np.sin(t) + 0.3 * np.sin(2 * t))
# NOTE(review): the positional arguments are presumably the initial state
# and the (start, end) time span; no `step` is passed here, so whatever
# default cartpole.Problem defines applies -- confirm.
test = Problem(test_pars, [np.pi, 0, 0, 0], (0, 2))

test.plot()
plt.suptitle("Test inputs")
plt.tight_layout()

## Solve the models

For now, let's only solve the $\phi$ equations. Note that (it
seems) we have to solve this problem as a coupled set of 1st-order
ODEs to get functions of $\dot\phi$. I'd love to not have to do
that, and I'd love to limit the trig functions so that they are
applied only to $\phi$ and never to $\dot\phi$, but that will
take more work.

Instead, we provide the solver with $\phi$ (named `p`) and
$\dot\phi$ (named `q`) and ask for solutions up to first order in
derivatives (`derivative_order = 1`, so $\dot p \equiv \dot\phi$
and $\dot q \equiv \ddot\phi$). I'd rather only put in $\phi$ and
get out $\ddot\phi$ (`derivative_order = 2`), but SINDy doesn't
seem to mix in single-dot terms if we do that. Maybe I'm missing
something.


In [None]:
# Solve the problem for \phi and store the solution

# Cache location. Defined before the fitting branch so that the load at
# the bottom still works when the branch is switched off.
destdir = Path("../../learnedModels/SINDyPI")

# Manual toggle: set to False to skip the fit and reuse the cached model.
if True:
    library = ps.PDELibrary(
        library_functions=[
            lambda x: x,
            lambda x: x**2,
            lambda x: np.sin(x),
            lambda x: np.cos(x),
            lambda x: np.sin(x) ** 2,
            lambda x: np.sin(x) * np.cos(x),
        ],
        # One name per library function, in the same order. Products are
        # written by juxtaposition (x**2 -> "pp"); scrub_features later
        # inserts the "*" between adjacent feature names.
        function_names=[
            lambda x: x,
            lambda x: x + x,
            lambda x: f"sin({x})",
            lambda x: f"cos({x})",
            lambda x: f"sin({x})sin({x})",
            lambda x: f"sin({x})cos({x})",
        ],
        derivative_order=1,
        interaction_only=True,
        implicit_terms=True,
        temporal_grid=train.times(),
    )

    # Solving these as coupled 1st order ODEs seems like
    # the only way to get quadratic functions of phi-dot.

    # fit \ddot\phi ~ f(\phi, \dot\phi)
    # Rows 0 and 2 of the simulated data are the ones fed in as "p" and "q".
    model_p = ps.SINDy(ps.SINDyPI(), deepcopy(library), feature_names=["p", "q"])
    model_p.fit(train.simy()[[0, 2]].T, train.times())

    # Cache to save running time
    # mkdir(exist_ok=True) replaces the exists()/makedirs() pair, which
    # could race if the directory appeared between the two calls.
    destdir.mkdir(parents=True, exist_ok=True)
    cache_model(model_p, destdir / "model_phi.pkl")

model_p = load_model(destdir / "model_phi.pkl")


Here's where we do the dance described in the [SINDy-PI
notebook][nb]:

- Convert a solved equation into a set of symbols
- Solve the symbolic equations for $\ddot\phi$
- Cast back into a form we can use downstream

[nb]: https://github.com/dynamicslab/pysindy/blob/master/examples/9_sindypi_with_sympy.ipynb

In [None]:
#| include: false
def setup_equations(symbols: list, coefs: np.ndarray, round: int = 10) -> list:
    """Build one symbolic equation per feature and solve it for that feature.

    For each position ``i`` the equation ``symbols[i] == symbols @ row`` is
    formed, where ``row`` is ``coefs[i]`` rounded to ``round`` decimals, and
    handed to ``sympy.solve`` in a worker process.

    Args:
        symbols: Symbols, one per model feature.
        coefs: Coefficient rows; must have as many rows as ``symbols``.
        round: Number of decimals the coefficients are rounded to.

    Returns:
        The solutions in the same order as ``symbols``. A feature whose
        solve comes back empty is represented by the string ``"None"``.

    Raises:
        ValueError: If ``symbols`` and ``coefs`` differ in length.
        AssertionError: If a solve yields more than one solution.
    """
    log.info(f"Setting up {len(symbols)} equation{'s' * (len(symbols)!=1)}")

    solved = {}
    with f.ProcessPoolExecutor() as pool:
        # Map each future back to the position of the feature it solves for.
        index_of = {
            pool.submit(
                sp.solve, sp.Eq(sym, symbols @ np.around(row, round)), sym
            ): idx
            for idx, (sym, row) in enumerate(zip(symbols, coefs, strict=True))
        }

        for done in f.as_completed(index_of):
            # Empty result: keep a placeholder so positions stay aligned.
            result = done.result() or ["None"]
            assert len(result) == 1, "Expected a single symbolic equation"
            solved[index_of[done]] = result[0]

    # Futures finish in arbitrary order; restore the feature order.
    return [solved[idx] for idx in sorted(solved)]


def scrub_features(eqns: list[str], maps: dict[str, str]) -> list[str]:
    """Convert SINDy feature strings to symbols.

    Each input is converted with ``str`` and then rewritten in three steps:
    names are substituted according to ``maps``, ``_t_t`` is collapsed to
    ``_tt``, and an explicit ``*`` is inserted where two factors are
    written next to each other.

    Args:
        eqns: Equations or feature names (anything ``str`` accepts).
        maps: Substring renames, e.g. ``{"q": "p_t"}``. May be empty.

    Returns:
        One rewritten string per input, in input order.
    """
    neweqs = [str(eq) for eq in eqns]

    # Consolidate dot names, e.g., 'q' = 'p_t' in a set of ODEs.
    # All renames are applied to every equation in a single pass, so the
    # output always has one entry per input (also for an empty or
    # multi-entry ``maps``) and text produced by one rename is never
    # rewritten by another. Longer keys are tried first.
    if maps:
        keys = sorted(maps, key=len, reverse=True)
        rename = re.compile("|".join(re.escape(k) for k in keys))
        neweqs = [rename.sub(lambda m: maps[m.group(0)], eq) for eq in neweqs]

    # The above sometimes makes 'p_t_t' instead of 'p_tt'. Fix it.
    neweqs = [eq.replace("_t_t", "_tt") for eq in neweqs]

    # Convert, e.g., 'pq' to 'p*q' or 'cos(x)x' to 'cos(x)*x'
    RX = re.compile(r"(?<=[spqrt)])([spqrc])")
    neweqs = [RX.sub(r"*\1", eq) for eq in neweqs]

    return neweqs


def solve_for_target(y: str, ftrs: list[str], eqns: list, maps: dict[str, str]) -> list:
    """Solve a set of symbolic equations, individually, for the target.

    Features and equations are first passed through ``scrub_features``.
    For each pair, ``equation - feature`` is solved for the symbol ``y``
    in a worker process.

    Args:
        y: Name of the symbol to solve for.
        ftrs: Feature names, one per equation.
        eqns: Equations, one per feature; entries that stringify to
            ``"None"`` are skipped.
        maps: Renames forwarded to ``scrub_features``.

    Returns:
        The solutions, in the order of the input equations. Equations that
        are skipped or that have no solution for ``y`` contribute nothing.

    Raises:
        ValueError: If the scrubbed features and equations differ in length.
        AssertionError: If an equation has more than one solution.
    """
    neweqs = scrub_features(eqns, maps)
    newftrs = scrub_features(ftrs, maps)
    ysym = sp.symbols(y)

    equations = []

    # Do this in parallel
    with f.ProcessPoolExecutor() as exe:
        jobs = []
        for sym, eqn in zip(newftrs, neweqs, strict=True):
            if eqn == "None":
                continue
            s_eqn = sp.sympify(eqn)
            s_sym = sp.sympify(sym)
            jobs.append(exe.submit(sp.solve, sp.Add(s_eqn, -s_sym), ysym))

        # Collect in submission order so the result is deterministic.
        for job in jobs:
            soln = job.result()
            # solve() hands back an empty list when there is no solution
            # for the target (e.g. the equation does not contain it).
            if not soln:
                continue
            assert len(soln) == 1, "Expected a single symbolc equation"
            equations.append(soln[0])

    return equations


def solve_equations(model: ThinnedModel, target: str, maps: dict[str, str]):
    """Solve equations produced by a SINDy-PI model.

    Reads ``model.features`` and ``model.coefficients``, logs the feature
    names, builds one symbolic equation per feature with
    ``setup_equations`` and solves each of them for ``target`` with
    ``solve_for_target``.

    Args:
        model: Fitted model exposing ``features`` and ``coefficients``.
        target: Name of the symbol to solve for.
        maps: Feature renames forwarded to ``solve_for_target``.

    Returns:
        Whatever ``solve_for_target`` returns for the model's equations.
    """
    names, weights = model.features, model.coefficients

    log.info("Setting up model with features:")
    for name in names:
        log.info(f" - {name}")

    per_feature = setup_equations([sp.symbols(name) for name in names], weights)
    return solve_for_target(target, names, per_feature, maps)


In [None]:
# Solve the fitted model's equations for p_tt, writing q as p_t.
# This takes about 1 minute on my Intel MBP
peqns = solve_equations(model=model_p, target="p_tt", maps={"q": "p_t"})


## Generate solutions from each solved equation

Each equation is a solution for the 2D 1st-order ODE problem we
gave the solver, so we get back one solution for $\dot\phi$ and
another for $\ddot\phi$. We use those to compute ODE solutions
and plot how well they reproduce the expected behavior.

In [None]:
# Integrate every candidate equation on the validation problem and keep
# each solution together with its RMSE against the validation data.
pivps = []
for eqn in peqns:
    fn = sp.lambdify(sp.symbols("p p_t"), eqn, "numpy")

    def wrapper(_t, state):
        # state holds (p, p_t); its time derivative is (p_t, p_tt).
        angle, rate = state
        return np.stack([rate, fn(angle, rate)])

    ivp = solve_ivp(
        fun=wrapper,
        t_span=val.time,
        y0=val.init[::2],
        method="Radau",
        t_eval=val.times(),
        dense_output=True,
        rtol=1e-6,
        atol=1e-6,
    )

    if ivp.y.shape[-1] == val.times().shape[-1]:
        rmse = np.sqrt(np.mean((val.simy()[0] - ivp.y[0]) ** 2))
        log.info(f"{str(eqn)[:60]}: {rmse:.5f}")
        pivps.append((ivp, rmse))
    else:
        # Fewer samples than requested: drop this candidate.
        log.warning("Incorrect shape, skipping solution")


In [None]:
#| echo: false
fig, ax = plt.subplots()
# One thin grey trace per candidate solution: the dense output evaluated
# on the validation time grid, row 0 (the `p` component).
for (ivp, _) in pivps:
    ax.plot(val.times(), ivp.sol(val.times())[0], c="0.4", alpha=0.5, lw=1)
# Row 0 of the validation data, drawn with zorder=0.
ax.plot(val.times(), val.simy()[0], zorder=0)
ax.set_ylim(0, 10)
# NOTE(review): presumably here so the cell does not echo the return
# value of set_ylim -- confirm.
pass