In [None]:
import os

# Move from the notebook's own folder up to the repository root so that the
# `examples` package and `./lightning_logs` used by later cells resolve.
# Guarded on the presence of `examples/` so that re-running this cell does not
# climb another three directories each time.
if not os.path.isdir('examples'):
    os.chdir('../../..')

In [None]:
# `display` and `HTML` are imported from the public IPython.display module;
# importing `display` from IPython.core.display is deprecated.
from IPython.display import display, HTML

# Inject CSS that stretches the notebook container to the full browser width.
display(HTML("<style>.container { width:100% !important; }</style>"))

In [None]:
ll examples/

In [None]:
from examples.lorenz.dataset import LorenzBaseDataSet
from torch.utils.data import DataLoader

In [None]:
# Build the Lorenz example dataset using the constructor's default arguments.
dataset = LorenzBaseDataSet()

In [None]:
# Number of samples per batch pulled from the dataset.
batch_size = 1000
# Use up to 24 worker processes, but never more than the host has CPUs
# (os.cpu_count() can return None, hence the fallback to 1).
dataloader = DataLoader(dataset, batch_size=batch_size,
                        num_workers=min(24, os.cpu_count() or 1))

In [None]:
from sindy_autoencoder_cps.lightning_module import SINDyAutoencoder
import yaml

# Name of the run folder under ./lightning_logs to inspect.
# MODEL_VERSION = 'freq_and_phase'
MODEL_VERSION = 'version_7'

# Hyper-parameters that were stored alongside the run.
hparams_path = f'./lightning_logs/{MODEL_VERSION}/hparams.yaml'
with open(hparams_path, 'r') as stream:
    hparam_dct = yaml.safe_load(stream)

# os.listdir returns entries in arbitrary order, so indexing with [-1] does not
# reliably select any particular checkpoint. Pick the most recently modified
# file in the checkpoint folder instead.
ckpt_dir = f'./lightning_logs/{MODEL_VERSION}/checkpoints/'
ckpt_file_name = max(
    os.listdir(ckpt_dir),
    key=lambda name: os.path.getmtime(os.path.join(ckpt_dir, name)),
)
ckpt_file_path = f'./lightning_logs/{MODEL_VERSION}/checkpoints/{ckpt_file_name}'

# Restore the model from the chosen checkpoint and display its repr.
model = SINDyAutoencoder.load_from_checkpoint(ckpt_file_path)
model

In [None]:
# Show the `device` attribute of the model's SINDy library.
model.SINDyLibrary.device

In [None]:
# Create an iterator over the DataLoader so batches can be pulled one at a time.
batches = iter(dataloader)

In [None]:
# Pull the next batch from the iterator. The built-in next() is used because a
# `.next()` method is not part of the Python 3 iterator protocol.
# Each batch unpacks into the inputs, their derivatives and a third item
# `idxs` that later cells use to index dataset.df.
x, xdot, idxs = next(batches)
print(f'shape x: {x.shape},   shape xdot: {xdot.shape}')

In [None]:
# Display the current batch of inputs and derivatives as a tuple.
x, xdot

In [None]:
import torch

# Ratio of the norm of the whole x batch to the norm of the whole xdot batch
# (default norm, taken over every element of each tensor).
x.norm() / xdot.norm()

In [None]:
# Norm of the xdot batch on its own, using torch.norm's default arguments.
torch.norm(xdot)

In [None]:
# Run the batch through the model on the GPU when one is available, otherwise
# on the CPU, so the notebook also works on hosts without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# The forward call returns five tensors: reconstructions of x and xdot, the
# latent z, and two versions of the latent derivative (zdot and zdot_hat).
x_hat, xdot_hat, z, zdot, zdot_hat = model.to(device)(x.to(device), xdot.to(device))

In [None]:
# Select one sample of the batch and pull the original and reconstructed
# quantities for it out as NumPy arrays.
sample = 0
# sample = 477
x_sample = x[sample].detach().cpu().numpy()
xhat_sample = x_hat[sample].detach().cpu().numpy()
z_sample = z[sample].detach().cpu().numpy()
xdot_sample = xdot[sample].detach().cpu().numpy()
xdot_hat_sample = xdot_hat[sample].detach().cpu().numpy()
zdot_sample = zdot[sample].detach().cpu().numpy()
zdot_hat_sample = zdot_hat[sample].detach().cpu().numpy()
# Entry of `idxs` that belongs to the selected sample.
idx_sample = idxs[sample]

In [None]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import numpy as np

# 3x3 grid of scatter plots comparing model outputs with their targets:
# one row per quantity (x, xdot, zdot) and one column per component 0..2.
fig = make_subplots(rows=3, cols=3, shared_xaxes=False)

# (label, target tensor, model output tensor) for subplot rows 1, 2 and 3.
comparisons = [
    ('x', x, x_hat),
    ('xdot', xdot, xdot_hat),
    ('zdot', zdot, zdot_hat),
]

for i in range(3):
    for row, (label, actual, predicted) in enumerate(comparisons, start=1):
        # The trace name is derived from the plotted quantity, so the xdot row
        # is labelled as xdot rather than (incorrectly) as z.
        fig.add_trace(
            go.Scatter(x=actual[:, i].cpu().detach().numpy(),
                       y=predicted[:, i].cpu().detach().numpy(),
                       mode='markers',
                       name=f'{label}_hat{i} over {label}{i}',
                       opacity=.7),
            row=row, col=1+i)
fig.show()

In [None]:
import pandas as pd

# Copy the model's XI matrix to the host as a NumPy array.
XI = model.XI.cpu().detach().numpy()

# Wrap it in a DataFrame: one row per SINDy library feature and one column per
# output. Column names are generated from the matrix width rather than being
# hard-coded for exactly three outputs.
df_XI = pd.DataFrame(
    XI,
    index=model.SINDyLibrary.feature_names,
    columns=[f'z{i}_dot_hat' for i in range(XI.shape[1])],
)

In [None]:
import plotly.express as px
# Print the library feature names; these are also the row labels of df_XI.
print(model.SINDyLibrary.feature_names)
# Heatmap of the absolute values of the entries of df_XI.
px.imshow(df_XI.abs())

In [None]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import numpy as np

# 2x3 grid: row 1 shows each component of z against its position in the batch,
# row 2 overlays zdot_hat (blue) and zdot (green) for the same component.
fig = make_subplots(rows=2, cols=3, shared_xaxes=True)

# (tensor, trace label, marker colour, opacity, subplot row), listed in the
# order in which the traces are added to the figure.
series = (
    (z, 'z', 'blue', .5, 1),
    (zdot_hat, 'zdot_hat', 'blue', .4, 2),
    (zdot, 'zdot', 'green', .5, 2),
)

for values, label, color, alpha, row in series:
    for i in range(3):
        fig.add_trace(
            go.Scatter(x=np.array(range(z.shape[0])),
                       y=values[:, i].cpu().detach().numpy(),
                       mode='markers', name=f'{label}{i}',
                       marker_color=color, opacity=alpha),
            row=row, col=1+i)

fig.update_layout(height=600, width=1200, title_text='z over time and zdot')
fig.show()

In [None]:
import examples.lorenz.constants as const

# Rows of dataset.df that belong to the current batch, restricted to the
# columns named in const.Z_COL_NAMES.
z_real = dataset.df.iloc[idxs][const.Z_COL_NAMES]

# Convert once, outside the loops, instead of repeating the device-to-host
# copy of `z` for each of the nine subplots.
z_np = z.detach().cpu().numpy()
z_real_np = z_real.values

# 3x3 grid: column i holds latent component z{i}, row j holds z_real column j.
fig = make_subplots(rows=3, cols=3, shared_xaxes=True)

for i in range(3):
    for j in range(3):
        fig.add_trace(go.Scatter(y=z_np[:, i],
                                 x=z_real_np[:, j],
                                 mode='markers', name=f'z{i} over z_real{j}'),
                      row=1+j, col=1+i)


fig.update_layout(height=600, width=1200, title_text="Real hidden state vs latent activations")
fig.show()

In [None]:
# Rows of dataset.df that belong to the current batch, restricted to the
# columns named in const.Z_DOT_COL_NAMES.
zdot_real = dataset.df.iloc[idxs][const.Z_DOT_COL_NAMES]

# Convert once, outside the loops, instead of repeating the device-to-host
# copy of `zdot` for each of the nine subplots.
zdot_np = zdot.detach().cpu().numpy()
zdot_real_np = zdot_real.values

# 3x3 grid: column i holds zdot{i}, row j holds zdot_real column j.
fig = make_subplots(rows=3, cols=3, shared_xaxes=True)

for i in range(3):
    for j in range(3):
        # Trace names refer to the derivatives that are actually plotted here.
        fig.add_trace(go.Scatter(y=zdot_np[:, i],
                                 x=zdot_real_np[:, j],
                                 mode='markers', name=f'zdot{i} over zdot_real{j}'),
                      row=1+j, col=1+i)


fig.update_layout(height=600, width=1200, title_text="Real hidden states derivatives vs predicted")
fig.show()

In [None]:
# Scratch check: enumerate yields (position, item) pairs.
list(enumerate(['a', 'b', 'c']))

In [None]:
# Display the model's XI_coefficient_mask attribute.
# NOTE(review): presumably the mask applied to the XI coefficients — confirm.
model.XI_coefficient_mask