In [None]:
import os

# Remember the directory the kernel started in. A relative os.chdir is not
# idempotent: without this guard, re-running the cell would climb another
# three levels every time.
try:
    _NOTEBOOK_START_DIR
except NameError:
    _NOTEBOOK_START_DIR = os.getcwd()

# Work from the directory three levels above the start directory so that the
# `examples.*` imports below resolve (presumably the repository root — TODO confirm).
os.chdir(os.path.join(_NOTEBOOK_START_DIR, '../../..'))


In [None]:
# Import from the public `IPython.display` module; pulling `display` out of
# `IPython.core.display` is deprecated.
from IPython.display import display, HTML

# Inject CSS that forces elements with class `container` to full width.
display(HTML("<style>.container { width:100% !important; }</style>"))

In [None]:
from examples.lorenz.data_gen import LorenzDataGenerator
import pandas as pd
import examples.lorenz.constants as const
import pysindy as ps
import numpy as np
import plotly.express as px
from ipywidgets import interact
import ipywidgets as widgets
from plotly.subplots import make_subplots
import plotly.graph_objects as go
%load_ext autoreload
%autoreload 2

In [None]:
# Load the z-space data frame from the parquet path defined in the constants module.
df_z = pd.read_parquet(const.Z_SPACE_DATA_PATH)

In [None]:
# Peek at the first rows of the z-space frame.
df_z.head()

In [None]:
# Peek at the last rows of the z-space frame.
df_z.tail()

In [None]:
def _create_ts_plot(idx):
    """Return a 3-D scatter of the (x, y, z) columns for one initial state.

    Only the rows of ``df_z`` whose ``const.UID_INITIAL_STATE_COL_NAME`` value
    equals ``idx`` are plotted; markers are coloured by the ``time`` column.
    """
    is_selected = df_z[const.UID_INITIAL_STATE_COL_NAME] == idx
    trajectory = df_z[is_selected]
    return px.scatter_3d(
        x=trajectory.x,
        y=trajectory.y,
        z=trajectory.z,
        color=trajectory.time,
        opacity=.1,
        template='plotly_dark',
    )


# Interactive selector over every initial-state id.
initial_state_ids = list(range(const.NUMBER_INITIAL_STATES))
interact(_create_ts_plot, idx=initial_state_ids)

In [None]:
def _create_ts_plot(idx):
    """Return a 3-D scatter of the (x_dot, y_dot, z_dot) columns for one initial state.

    Only the rows of ``df_z`` whose ``const.UID_INITIAL_STATE_COL_NAME`` value
    equals ``idx`` are plotted; markers are coloured by the ``time`` column.
    Note that this re-uses (and therefore replaces) the name of the plotting
    helper for the undifferentiated states.
    """
    is_selected = df_z[const.UID_INITIAL_STATE_COL_NAME] == idx
    trajectory = df_z[is_selected]
    return px.scatter_3d(
        x=trajectory.x_dot,
        y=trajectory.y_dot,
        z=trajectory.z_dot,
        color=trajectory.time,
        opacity=.1,
        template='plotly_dark',
    )


# Interactive selector over every initial-state id.
initial_state_ids = list(range(const.NUMBER_INITIAL_STATES))
interact(_create_ts_plot, idx=initial_state_ids)

# Let's try the PySINDy exercise in the z-space!

In [None]:
# Candidate functions for the SINDy regression. Each entry here must line up
# positionally with its display name in `library_function_names` below.
library_functions = [
    lambda x : np.exp(x),
    lambda x : 1./x,  # singular at x == 0
    lambda x : x,
    lambda x : np.sin(x),
    lambda x : np.cos(x),
    lambda x,y : np.sin(x+y),
    lambda x,y : np.cos(x+y),
    lambda x,y : np.sign(x-y)*np.sqrt(np.abs(x-y)),  # signed square root of the difference
    lambda x: x**2,
#     lambda x: np.sqrt(x),
    lambda x,y: x*y,
    lambda x,y: x**2*y,
    lambda x,y: x*y**2
]

# Human-readable names used when the fitted model is printed.
library_function_names = [
    lambda x : f'exp({x})',
    lambda x : f'1/{x}',
    lambda x : x,
    lambda x : f'sin({x})',
    lambda x : f'cos({x})',
    lambda x,y : f'sin({x}+{y})',
    lambda x,y : f'cos({x}+{y})',
    # The function takes sqrt(abs(x-y)); the label previously omitted the
    # absolute value and therefore described a different (partly undefined) term.
    lambda x,y : f'sign({x}-{y})*sqrt(|{x}-{y}|)',
    lambda x: f'{x}^2',
#     lambda x: f'sqrt({x})',
    lambda x,y: f'{x}*{y}',
    lambda x,y: f'{x}^2*{y}',
    lambda x,y: f'{x}*{y}^2'
]

feature_library = ps.CustomLibrary(
    library_functions=library_functions, function_names=library_function_names
)

In [None]:
# Regression inputs: the z-space states and their derivatives taken from df_z.
z_states = df_z[const.Z_COL_NAMES].values
z_derivatives = df_z[const.Z_DOT_COL_NAMES].values

# STLSQ optimizer with a coefficient threshold of 0.5, fitted on the custom library.
optimizer = ps.STLSQ(threshold=0.5)
model = ps.SINDy(
    optimizer=optimizer,
    feature_names=["x", "y", "z"],
    feature_library=feature_library,
)
model.fit(z_states, x_dot=z_derivatives, t=0.01)

# Show the identified equations.
model.print()

In [None]:
# Summary statistics of every numeric column in the z-space frame.
df_z.describe()

In [None]:
# State columns of the z-space frame as a NumPy array (one row per sample).
x = df_z[const.Z_COL_NAMES].values
x

In [None]:
# Derivative columns of the z-space frame as a NumPy array (one row per sample).
x_dot = df_z[const.Z_DOT_COL_NAMES].values

In [None]:
# Squared Euclidean norm of every row, computed with one vectorised reduction
# instead of a Python-level loop hard-coded to three components.
# Assumes x is 2-D with one sample per row; all columns are summed, which is
# identical to the previous result when x has exactly three columns.
x_2_norm = np.square(x).sum(axis=1)

# Total over all samples.
x_2_norm_sum = x_2_norm.sum()
x_2_norm_sum

In [None]:
# Squared Euclidean norm of every derivative row, computed with one vectorised
# reduction instead of a Python-level loop hard-coded to three components.
# Assumes x_dot is 2-D with one sample per row; all columns are summed, which
# is identical to the previous result when x_dot has exactly three columns.
x_dot_2_norm = np.square(x_dot).sum(axis=1)

# Total over all samples.
x_dot_2_norm_sum = x_dot_2_norm.sum()
x_dot_2_norm_sum

In [None]:
# Ratio of the summed squared state norms to the summed squared derivative norms.
x_2_norm_sum/x_dot_2_norm_sum

In [None]:
# NOTE(review): bare scratch value; presumably kept for comparison with the
# ratio computed above — confirm or remove.
5e-2

## NICE!

In [None]:
# List the generated feature names of the custom library.
feature_library.get_feature_names()

In [None]:
from examples.lorenz.dataset import LorenzBaseDataSet
from torch.utils.data import DataLoader

In [None]:
# Instantiate the Lorenz dataset with its default arguments.
dataset = LorenzBaseDataSet()

In [None]:
# Batches of 100000 samples each.
batch_size = 100000
# NOTE(review): num_workers=24 is hard-coded; confirm the machine running this
# notebook actually has that many cores available.
dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=24)

In [None]:
# Explicit iterator over the loader so single batches can be pulled by hand.
batches = iter(dataloader)

In [None]:
# Pull one batch with the built-in next(); a `.next()` method is not part of
# the Python 3 iterator protocol and is not available on current DataLoader
# iterators.
x, xdot, idx = next(batches)

In [None]:
# Check which container type the loader returns for the state batch.
type(x)

In [None]:
import torch
# Norm of the whole state batch using torch.norm's defaults (p='fro', i.e. the
# Frobenius norm over all entries).
torch.norm(x)

In [None]:
# Same default norm for the derivative batch.
torch.norm(xdot)

In [None]:
# Ratio of the two batch norms (not squared, unlike the z-space ratio of summed squared norms).
torch.norm(x) / torch.norm(xdot)

In [None]:
# Small 2x2 tensor for checking operator semantics.
a = torch.tensor([[1., 2.],[3.,4.]])

In [None]:
# `**` on a tensor squares element-wise (it is not a matrix product).
a**2

In [None]:
# Matrix 2-norm of the state batch. With ord=2 this is the spectral norm
# (largest singular value).
# NOTE(review): this is a different quantity from the Frobenius norm / summed
# squared row norms used in the cells above — confirm which one is intended.
x2norm = torch.linalg.matrix_norm(x, ord=2)

In [None]:
# Matrix 2-norm (spectral norm, largest singular value) of the derivative batch.
xdot2norm = torch.linalg.matrix_norm(xdot, ord=2)

In [None]:
# Ratio of the squared matrix 2-norms of states and derivatives.
x2norm**2/ xdot2norm**2

# Now in the x-space!

In [None]:
# Load the x-space data frame from the parquet path defined in the constants module.
df_x = pd.read_parquet(const.X_SPACE_DATA_PATH)


def plot_pics(sample_index):
    """Show the first five x-space rows of one initial state as images.

    The rows of ``df_x`` belonging to ``sample_index`` are selected, rows at
    positions 0-4 are reshaped to ``PICTURE_SIZE x PICTURE_SIZE`` and drawn
    as facets labelled with their row position.
    """
    frame_positions = list(range(5))
    side = const.PICTURE_SIZE

    is_sample = df_x[const.UID_INITIAL_STATE_COL_NAME] == sample_index
    sample_rows = df_x[is_sample].copy().reset_index(drop=True)
    pixel_rows = sample_rows[const.X_COL_NAMES].values

    frames = np.array(
        [pixel_rows[pos, :].reshape((side, side)) for pos in frame_positions]
    )
    fig = px.imshow(frames, labels={'facet_col':'sigma'}, facet_col=0)

    # Replace the automatic facet titles with the row position of each frame.
    for slot, pos in enumerate(frame_positions):
        fig.layout.annotations[slot]['text'] = f'Timestamp {pos}'
    fig.show()


interact(plot_pics, sample_index=range(const.NUMBER_INITIAL_STATES))

In [None]:
# Scratch check of which integers a range(10, 15) produces.
[i for i in range(10,15)]

In [None]:
def plot_pics(sample_index):
    """Show x-space derivative rows 20-29 of one initial state as images.

    The rows of ``df_x`` belonging to ``sample_index`` are selected, the
    ``XDOT`` columns of rows at positions 20-29 are reshaped to
    ``PICTURE_SIZE x PICTURE_SIZE`` and drawn as facets labelled with their
    row position.
    """
    frame_positions = list(range(20, 30))
    side = const.PICTURE_SIZE

    is_sample = df_x[const.UID_INITIAL_STATE_COL_NAME] == sample_index
    sample_rows = df_x[is_sample].copy().reset_index(drop=True)
    pixel_rows = sample_rows[const.XDOT_COL_NAMES].values

    frames = np.array(
        [pixel_rows[pos, :].reshape((side, side)) for pos in frame_positions]
    )
    fig = px.imshow(frames, labels={'facet_col':'sigma'}, facet_col=0)

    # Replace the automatic facet titles with the row position of each frame.
    for slot, pos in enumerate(frame_positions):
        fig.layout.annotations[slot]['text'] = f'Timestamp {pos}'
    fig.show()


interact(plot_pics, sample_index=range(const.NUMBER_INITIAL_STATES))

In [None]:
# (Re)load the x-space data frame from the parquet path defined in the constants module.
df_x = pd.read_parquet(const.X_SPACE_DATA_PATH)


def plot_pics(sample_index):
    """Show x-space rows 0, 25, 50, 75 and 100 of one initial state as images.

    The rows of ``df_x`` belonging to ``sample_index`` are selected, the
    chosen rows are reshaped to ``PICTURE_SIZE x PICTURE_SIZE`` and drawn as
    facets labelled with their row position.
    """
    frame_positions = [0, 25, 50, 75, 100]
    side = const.PICTURE_SIZE

    is_sample = df_x[const.UID_INITIAL_STATE_COL_NAME] == sample_index
    sample_rows = df_x[is_sample].copy().reset_index(drop=True)
    pixel_rows = sample_rows[const.X_COL_NAMES].values

    frames = np.array(
        [pixel_rows[pos, :].reshape((side, side)) for pos in frame_positions]
    )
    fig = px.imshow(frames, labels={'facet_col':'sigma'}, facet_col=0)

    # Replace the automatic facet titles with the row position of each frame.
    for slot, pos in enumerate(frame_positions):
        fig.layout.annotations[slot]['text'] = f'Timestamp {pos}'
    fig.show()


interact(plot_pics, sample_index=range(const.NUMBER_INITIAL_STATES))

## What about the derivatives in x?

In [None]:
def plot_pics(sample_index):
    """Show x-space derivative rows 0, 25, 50, 75 and 100 of one initial state.

    The rows of ``df_x`` belonging to ``sample_index`` are selected, the
    ``XDOT`` columns of the chosen rows are reshaped to
    ``PICTURE_SIZE x PICTURE_SIZE`` and drawn as facets labelled with their
    row position.
    """
    frame_positions = [0, 25, 50, 75, 100]
    side = const.PICTURE_SIZE

    is_sample = df_x[const.UID_INITIAL_STATE_COL_NAME] == sample_index
    sample_rows = df_x[is_sample].copy().reset_index(drop=True)
    pixel_rows = sample_rows[const.XDOT_COL_NAMES].values

    frames = np.array(
        [pixel_rows[pos, :].reshape((side, side)) for pos in frame_positions]
    )
    fig = px.imshow(frames, labels={'facet_col':'sigma'}, facet_col=0)

    # Replace the automatic facet titles with the row position of each frame.
    for slot, pos in enumerate(frame_positions):
        fig.layout.annotations[slot]['text'] = f'Timestamp {pos}'
    fig.show()


interact(plot_pics, sample_index=range(const.NUMBER_INITIAL_STATES))

In [None]:
# x-space derivative and state columns as NumPy arrays.
# NOTE(review): this rebinds `x` / `x_dot`, which earlier cells used for the
# z-space arrays and the torch batch.
x_dot = df_x[const.XDOT_COL_NAMES].values
x = df_x[const.X_COL_NAMES].values

In [None]:
# Row position inspected by the following image cells.
i = 10

In [None]:
# Render row `i` of the x-space states as an image. The side length comes from
# const.PICTURE_SIZE, as in the plot_pics cells, instead of a hard-coded 100.
px.imshow(x[i,:].reshape(const.PICTURE_SIZE, const.PICTURE_SIZE))

In [None]:
# Render row `i` of the x-space derivatives as an image. The side length comes
# from const.PICTURE_SIZE, as in the plot_pics cells, instead of a hard-coded 100.
px.imshow(x_dot[i,:].reshape(const.PICTURE_SIZE, const.PICTURE_SIZE))

In [None]:
# All rows belonging to initial state 0.
# NOTE(review): this displays the whole filtered frame; consider .head() to keep the output small.
df_x[df_x[const.UID_INITIAL_STATE_COL_NAME] == 0]

In [None]:
# Column dtypes and memory footprint of the x-space frame.
df_x.info()

In [None]:
# Maximum of every x-space state column, then how often each maximum occurs.
column_maxima = pd.Series([df_x[name].max() for name in const.X_COL_NAMES])
column_maxima.value_counts()

In [None]:
# Rows of initial state 0, kept for the column-wise checks below.
df_test = df_x[df_x[const.UID_INITIAL_STATE_COL_NAME] == 0]

In [None]:
# Per-column standard deviation for initial state 0; keep only the columns
# whose deviation exceeds 0.069.
pixel_std = df_test[const.X_COL_NAMES].std()
pixel_std[pixel_std > 0.069]

In [None]:
x = np.array([[1, 2, 3],[4,5,6],[7,8,9]])
x

In [None]:
x.ravel()

In [None]:
# Scratch arithmetic; presumably the flat index of row 33, column 1 in a
# 100-wide image — TODO confirm or remove.
33*100+1

In [None]:
# Filter the rows of initial state 0 once; the previous version re-filtered
# the whole frame for each of the 1000 columns.
df_first_state = df_x[df_x[const.UID_INITIAL_STATE_COL_NAME] == 0]

# Distinct maxima of the columns x_0 ... x_999 for that initial state.
{df_first_state[f'x_{i}'].max() for i in range(1000)}

In [None]:
def plot_stuff(i=3504):
    """Scatter-plot column ``x_{i}`` over the rows of initial state 0."""
    first_state = df_x[df_x[const.UID_INITIAL_STATE_COL_NAME] == 0]
    px.scatter(first_state[f'x_{i}']).show()


# Interactive selector over the column indices 0 ... 9999.
interact(plot_stuff, i=list(range(10000)))
    

In [None]:
def plot_stuff(i=3504):
    """Scatter-plot column ``xdot_{i}`` over the rows of initial state 0."""
    first_state = df_x[df_x[const.UID_INITIAL_STATE_COL_NAME] == 0]
    px.scatter(first_state[f'xdot_{i}']).show()


# Interactive selector over the column indices 0 ... 9999.
interact(plot_stuff, i=list(range(10000)))

In [None]:
# NOTE(review): this imports from `three_tank_data`, a different package than
# the `examples.lorenz` modules used in the rest of the notebook — confirm it
# belongs here.
from three_tank_data.data_module import ThreeTankDataModule
# The meaning of the three positional arguments is not visible from this
# notebook — TODO pass them as keywords once confirmed against the class.
dm = ThreeTankDataModule(.1, 10, 10)