In [None]:
import os

# Move to the repository root so the `examples.*` package imports below resolve.
# Guarded on the presence of the `examples` directory so that re-running this
# cell is a no-op instead of climbing three more levels each time.
if not os.path.isdir('examples'):
    os.chdir('../../..')

In [None]:
# `display`/`HTML` live in the public IPython.display module; importing them
# from IPython.core.display is deprecated and emits a DeprecationWarning.
from IPython.display import display, HTML

# Stretch the notebook container to the full browser width (handy for wide plots).
display(HTML("<style>.container { width:100% !important; }</style>"))

In [None]:
from examples.two_tank_system.data_gen import TwoTankDataGenerator
import pandas as pd
import examples.two_tank_system.constants as const
import pysindy as ps
import numpy as np
import plotly.express as px
from ipywidgets import interact
import ipywidgets as widgets
from plotly.subplots import make_subplots
import plotly.graph_objects as go
%load_ext autoreload
%autoreload 2

In [None]:
# Load the simulated two-tank data set used by the plots and SINDy fit below.
# NOTE(review): later cells read const.Z_COL_NAMES / Z_DOT_COL_NAMES from this
# frame, yet it is loaded from X_SPACE_DATA_PATH — confirm this is intended.
df = pd.read_parquet(path=const.X_SPACE_DATA_PATH)

In [None]:
# Peek at the first five rows.
df.head(n=5)

In [None]:
# Summary statistics per column (quartiles are pandas' default percentiles).
df.describe(percentiles=[.25, .5, .75])

In [None]:
def _create_ts_plot(idx):
    """Plot the z-space trajectory of one initial state and its derivative.

    Rows of the global ``df`` whose ``const.UID_INITIAL_STATE_COL_NAME`` equals
    ``idx`` are plotted against ``time``: the ``const.Z_COL_NAMES`` columns in
    the top panel and the ``const.Z_DOT_COL_NAMES`` columns in the bottom panel.

    Parameters
    ----------
    idx : int
        Identifier of the initial state whose trajectory is shown.
    """
    df_plot = df[df[const.UID_INITIAL_STATE_COL_NAME] == idx]
    fig = make_subplots(rows=2, cols=1, shared_xaxes=True)
    # (subplot row, columns to plot, legend-label suffix). Labels are generated
    # per column instead of zipping against a fixed 3-name list, which would
    # silently drop any columns beyond the third.
    panels = (
        (1, const.Z_COL_NAMES, '(t)'),
        (2, const.Z_DOT_COL_NAMES, '_dot(t)'),
    )
    for row, cols, suffix in panels:
        for i, col in enumerate(cols, start=1):
            fig.add_trace(go.Scatter(x=df_plot.time, y=df_plot[col], name=f'h{i}{suffix}',
                                     mode="lines", opacity=1),
                          row=row, col=1)

    # The x-axis is shared, so only the bottom panel needs a title.
    fig.update_xaxes(title_text='time', row=2, col=1)
    # These panels show the z-space columns, so label them z / z_dot.
    fig.update_yaxes(title_text='z', row=1, col=1)
    fig.update_yaxes(title_text='z_dot', row=2, col=1)
    fig.update_layout(title_text="Latent neuron activations vs. hidden states", showlegend=True)
    fig.show()


interact(_create_ts_plot, idx=list(range(const.NUMBER_INITIAL_STATES)))

# Let's try the PySINDy exercise in the z-space!

In [None]:
# Candidate terms for the SINDy feature library. Each entry pairs a term's
# function with the function that renders its printable name. CustomLibrary
# matches functions to names purely by position, so keeping each pair in a
# single tuple stops the two lists from drifting apart when terms are added or
# commented out.
_library_terms = [
    # (lambda x: np.exp(x), lambda x: 'exp(' + x + ')'),
    # 1/x is inf wherever an input is exactly 0 (numpy warns on divide-by-zero).
    (lambda x: 1. / x, lambda x: '1/' + x),
    (lambda x: x, lambda x: x),
    (lambda x: np.sin(x), lambda x: f'sin({x})'),
    (lambda x: np.cos(x), lambda x: f'cos({x})'),
    (lambda x, y: np.sin(x + y), lambda x, y: 'sin(' + x + '+' + y + ')'),
    (lambda x, y: np.cos(x + y), lambda x, y: f'cos({x}+{y})'),
    # Signed square root of the difference; the name includes the |.| that the
    # function applies so printed equations match what was fitted.
    (lambda x, y: np.sign(x - y) * np.sqrt(np.abs(x - y)),
     lambda x, y: f'sign({x}-{y})*sqrt(|{x}-{y}|)'),
    (lambda x: x**2, lambda x: x + '^2'),
    # (lambda x: np.sqrt(x), lambda x: f'sqrt({x})'),
    (lambda x, y: x * y, lambda x, y: f'{x}*{y}'),
    (lambda x, y: x**2 * y, lambda x, y: f'{x}^2*{y}'),
    (lambda x, y: x * y**2, lambda x, y: f'{x}*{y}^2'),
]
library_functions = [func for func, _ in _library_terms]
library_function_names = [name for _, name in _library_terms]
feature_library = ps.CustomLibrary(
    library_functions=library_functions, function_names=library_function_names
)

In [None]:
# Z-space state columns (the state matrix passed to model.fit in the SINDy cell).
const.Z_COL_NAMES

In [None]:
# Z-space derivative columns (passed as x_dot to model.fit in the SINDy cell).
const.Z_DOT_COL_NAMES

In [None]:
# STLSQ sparsity threshold: coefficients with magnitude below it are zeroed.
STLSQ_THRESHOLD = .01
# Sample spacing passed to fit(). x_dot is supplied explicitly, so no numerical
# differentiation is done. NOTE(review): confirm 0.01 matches the data's step.
DT = 0.01

optimizer = ps.STLSQ(threshold=STLSQ_THRESHOLD)
model = ps.SINDy(
    feature_library=feature_library,
    optimizer=optimizer,
    # One name per state column (h1, h2, ...), so the count always matches
    # const.Z_COL_NAMES instead of being hard-coded to two states.
    feature_names=[f"h{i}" for i in range(1, len(const.Z_COL_NAMES) + 1)],
)
model.fit(df[const.Z_COL_NAMES].values, x_dot=df[const.Z_DOT_COL_NAMES].values, t=DT)
model.print()

## Nice! The discovered equations are printed above.

In [None]:
# List the names of the terms generated by the custom feature library.
feature_library.get_feature_names()

In [None]:
# Number of x-space state columns.
len(const.X_COL_NAMES)

In [None]:
# Load the x-space data set for the x-space trajectory plots.
# NOTE(review): this is the same path `df` was loaded from earlier.
df_x = pd.read_parquet(path=const.X_SPACE_DATA_PATH)

In [None]:
# Column dtypes, non-null counts and memory usage of the x-space frame.
df_x.info()

In [None]:
def _create_ts_plot(idx):
    """Plot the x-space trajectory of one initial state and its derivative.

    Rows of ``df_x`` (the frame read from ``const.X_SPACE_DATA_PATH``) whose
    ``const.UID_INITIAL_STATE_COL_NAME`` equals ``idx`` are plotted against
    ``time``: ``const.X_COL_NAMES`` in the top panel and
    ``const.XDOT_COL_NAMES`` in the bottom panel.

    Parameters
    ----------
    idx : int
        Identifier of the initial state whose trajectory is shown.
    """
    # Use df_x, the x-space frame loaded for this plot, rather than `df`.
    df_plot = df_x[df_x[const.UID_INITIAL_STATE_COL_NAME] == idx]
    fig = make_subplots(rows=2, cols=1, shared_xaxes=True)
    for row, cols in ((1, const.X_COL_NAMES), (2, const.XDOT_COL_NAMES)):
        for col in cols:
            fig.add_trace(go.Scatter(x=df_plot.time, y=df_plot[col], name=col,
                                     mode="lines", opacity=1), row=row, col=1)
    # The x-axis is shared, so only the bottom panel needs a title.
    fig.update_xaxes(title_text='time', row=2, col=1)
    fig.update_yaxes(title_text='x', row=1, col=1)
    fig.update_yaxes(title_text='x_dot', row=2, col=1)
    fig.update_layout(title_text=f"x-space trajectory, initial state {idx}", showlegend=True)
    fig.show()


interact(_create_ts_plot, idx=list(range(const.NUMBER_INITIAL_STATES)))