In [5]:
import numpy as np
from scipy.integrate import odeint
from itertools import combinations_with_replacement
from math import factorial
import matplotlib.pyplot as plt
import plotly.express as px
import pandas as pd
import importlib
import SINDy
importlib.reload(SINDy)


# Lorenz system definitions
def lorenz(state, t, sigma=10, rho=28, beta=8/3):
    """Right-hand side of the Lorenz equations, in ``odeint`` calling form.

    Parameters
    ----------
    state : sequence of 3 numbers
        Current point ``(x, y, z)``.
    t : float
        Time. Unused because the system is autonomous, but required by the
        ``func(y, t)`` signature that ``scipy.integrate.odeint`` expects.
    sigma, rho, beta : float
        System parameters (defaults are the classic chaotic values).

    Returns
    -------
    tuple of 3 numbers
        ``(x', y', z')`` evaluated at ``state``.
    """
    x, y, z = state
    return (
        sigma * (y - x),       # x' = sigma (y - x)
        x * (rho - z) - y,     # y' = x (rho - z) - y
        x * y - beta * z,      # z' = x y - beta z
    )

def lorenz_derivative(states, lorenz, dt=1):
    """Evaluate an ODE right-hand side at every row of a state trajectory.

    Parameters
    ----------
    states : array_like, shape (n_samples, n_vars)
        Trajectory with one state vector per row (e.g. the output of ``odeint``).
    lorenz : callable
        Right-hand side with signature ``f(state, t) -> sequence`` returning one
        derivative per state variable. The name is kept for backward
        compatibility; any function with that signature works.
    dt : float, optional
        Forwarded unchanged as the ``t`` argument of ``lorenz`` (it is *not*
        used as a step size). For an autonomous system such as the Lorenz
        equations it has no effect on the result.

    Returns
    -------
    numpy.ndarray, shape (n_samples, n_vars)
        Row ``i`` holds ``lorenz(states[i], dt)``. The result is always 2-D,
        including for single-sample and empty trajectories.
    """
    states = np.asarray(states)
    # Collect rows in a list and build the array once; repeated np.vstack in a
    # loop copies the whole array on every iteration (quadratic cost).
    rows = [lorenz(state, dt) for state in states]
    if not rows:
        # Empty trajectory: return an empty 2-D array instead of failing.
        return np.empty((0, states.shape[1]))
    return np.array(rows)

In [6]:
# Simulation grid (step dt over [0, 10)) and initial state (x, y, z).
dt = 0.002
x_0 = np.array([-8, 8, 27])
t = np.arange(0, 10, dt)

# Integrate the Lorenz ODE, then evaluate its right-hand side at every sample,
# so SINDy is fit against exact derivatives rather than finite differences.
states = odeint(func=lorenz, y0=x_0, t=t)
dx_dt = lorenz_derivative(states, lorenz, dt=dt)

In [7]:
# Tabulate the simulated trajectory and its derivatives for plotting.
df_states = pd.DataFrame(states, columns=['x', 'y', 'z'])
# The derivative columns reuse the state names: column 'x' holds x', 'y' holds y',
# and 'z' holds z'.
df_dxdt = pd.DataFrame(dx_dt, columns=['x', 'y', 'z'])

# 3-D view of the Lorenz attractor, coloured by height z.
plot1 = px.scatter_3d(
    df_states, x='x', y='y', z='z', color='z',
    title='Lorenz attractor: simulated states',
)
plot1.update_traces(marker_size=1.5)

In [8]:
# 3-D view of the derivative values along the trajectory. The frame's columns
# are named x/y/z, so relabel the axes and colour bar to show they are x', y', z'.
plot2 = px.scatter_3d(
    df_dxdt, x='x', y='y', z='z', color='z',
    labels={'x': "x'", 'y': "y'", 'z': "z'"},
    title='Lorenz system: time derivatives along the trajectory',
)
plot2.update_traces(marker_size=1.5)

In [9]:
# Run SINDy and get the output equations.
# Fit the sparse model on the simulated states and their derivatives.
# NOTE(review): judging by the printed shapes below, the candidate library
# appears to be the 20 polynomial terms up to cubic order in (x, y, z), and
# Xi is (20, 3). Confirm against the SINDy module.
# The recovered coefficients should match the parameters used in lorenz():
# sigma=10, rho=28, beta=8/3.
# NOTE(review): the shape/Xi printouts come from inside the SINDy calls
# (debug output); consider removing them in the module.
sindy = SINDy.SINDy() 
sindy.fit(states, dx_dt)
# Last expression: the list of fitted equation strings displayed below.
sindy.equations()

(5000, 3)
(5000, 20)
(20, 3)
Functions:
 ['1' 'x' 'y' 'z' 'xx' 'xy' 'xz' 'yy' 'yz' 'zz' 'xxx' 'xxy' 'xxz' 'xyy'
 'xyz' 'xzz' 'yyy' 'yyz' 'yzz' 'zzz']

Fit coefficients (Xi):
 [[  0.           0.           0.        ]
 [-10.          28.           0.        ]
 [ 10.          -1.           0.        ]
 [  0.           0.          -2.66666667]
 [  0.           0.           0.        ]
 [  0.           0.           1.        ]
 [  0.          -1.           0.        ]
 [  0.           0.           0.        ]
 [  0.           0.           0.        ]
 [  0.           0.           0.        ]
 [  0.           0.           0.        ]
 [  0.           0.           0.        ]
 [  0.           0.           0.        ]
 [  0.           0.           0.        ]
 [  0.           0.           0.        ]
 [  0.           0.           0.        ]
 [  0.           0.           0.        ]
 [  0.           0.           0.        ]
 [  0.           0.           0.        ]
 [  0.           0.        

["x' = -9.999x + 9.9999y",
 "y' = 28.000x + -1.000y + -0.999xz",
 "z' = -2.666z + 1.0xy"]