In [None]:
# default stuff (display width, dir change, jupyter extentions)
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))
import os
os.chdir('..')
%load_ext autoreload
%autoreload 2

In [None]:
import os

import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import torch
import yaml
from ipywidgets import interact
from plotly.subplots import make_subplots

import anodeclstmgru.constants as const
from anodeclstmgru.data.data_module import SWaTSDataModule
from anodeclstmgru.models.lit_module import AutoEncoderLitModule

In [None]:
# Which Lightning run to inspect (./lightning_logs/version_<MODEL_VERSION>).
MODEL_VERSION = 1
MODEL_DIR = f'./lightning_logs/version_{MODEL_VERSION}'

hparams_path = f'{MODEL_DIR}/hparams.yaml'
with open(hparams_path, 'r') as stream:
    hparam_dct = yaml.safe_load(stream)

# os.listdir returns entries in arbitrary order, so taking [0] is not a reproducible choice
# when several files exist. Restrict to .ckpt files and take the most recently written one.
ckpt_dir = f'{MODEL_DIR}/checkpoints'
ckpt_file_names = [name for name in os.listdir(ckpt_dir) if name.endswith('.ckpt')]
if not ckpt_file_names:
    raise FileNotFoundError(f'No .ckpt file found in {ckpt_dir}')
ckpt_file_name = max(ckpt_file_names,
                     key=lambda name: os.path.getmtime(os.path.join(ckpt_dir, name)))
ckpt_file_path = f'{ckpt_dir}/{ckpt_file_name}'

In [None]:
# Rebuild the training data module from the hyperparameters loaded from hparams.yaml.
dm_train = SWaTSDataModule(**hparam_dct)
dm_train.setup()
# NOTE(review): the name is misspelled ("loade"); kept as-is because later cells reference it.
train_data_loade = dm_train.train_dataloader()

In [None]:
# Take the first batch. Use the builtin next(): the iterator's .next() method was a
# Python-2-era alias that recent PyTorch DataLoader iterators no longer provide.
batch_in = next(iter(train_data_loade))
# First sample of the batch as a frame with one column per sensor.
# NOTE(review): assumes the batch is a single 3-D tensor (sample, step, sensor) — TODO confirm.
df_in = pd.DataFrame(batch_in[0, :, :].numpy(), columns=const.SENSOR_COLS)

In [None]:
# load_from_checkpoint leaves the module in training mode; switch to eval mode so dropout and
# similar layers behave deterministically. Chaining keeps the module repr out of the cell output.
model = AutoEncoderLitModule.load_from_checkpoint(ckpt_file_path).eval()

In [None]:
# Pure inference: no autograd graph is needed for the reconstruction.
with torch.no_grad():
    batch_out = model(batch_in)
# Reconstruction of the first sample, laid out like df_in. .cpu() is a no-op for CPU tensors.
df_out = pd.DataFrame(batch_out[0, :, :].detach().cpu().numpy(), columns=const.SENSOR_COLS)

In [None]:
# First rows of the model input (first sample of the batch).
df_in.head(n=5)

In [None]:
# Same rows of the model's reconstruction, for comparison with df_in.
df_out.head(n=5)

In [None]:
def plot_ts_and_reconstruction(signal):
    """Overlay the original and reconstructed series of one sensor column.

    Reads the notebook-level frames ``df_in`` (model input, first sample of the batch)
    and ``df_out`` (model output for that sample), which share the same columns.

    Args:
        signal: column name in ``const.SENSOR_COLS`` to plot.
    """
    fig = go.Figure()
    for df, key in zip([df_in, df_out], ['orig.', 'reconstr.']):
        fig.add_trace(go.Scatter(x=df.index, y=df[signal], name=f'{signal}_{key}'))
    fig.update_layout(title=f'{signal}: original vs. reconstruction',
                      xaxis_title='step', yaxis_title=signal)
    fig.show()

# interact() returns the wrapped function; the trailing ';' keeps its repr out of the output.
interact(plot_ts_and_reconstruction, signal=const.SENSOR_COLS);

In [None]:
# Plot one sensor over the frame behind the training data module's dataset.
# NOTE(review): if swats_train is a torch Subset, .dataset.df is the whole underlying frame,
# not only the training split — TODO confirm.
signal = 'LIT101'
train_df = dm_train.swats_train.dataset.df
fig = go.Figure(go.Scatter(x=train_df.index, y=train_df[signal], name=signal))
fig.update_layout(title=f'{signal}: training data', xaxis_title='index', yaxis_title=signal)
fig.show()

In [None]:
# Shallow copy of the training hyperparameters with `normal` switched off.
hparam_att_dct = {**hparam_dct, 'normal': False}

In [None]:
# Data module built from the modified hyperparameters (normal=False).
# NOTE(review): presumably this loads the attack (non-normal) SWaT data — TODO confirm.
dm_test = SWaTSDataModule(**hparam_att_dct)
dm_test.setup()

In [None]:
# Peek at the raw frame behind the test data module's dataset.
dm_test.swats_train.dataset.df.head(n=5)

In [None]:
# Window end positions: window_size, 2*window_size, ... strictly below len(df).
# NOTE(review): .loc is label-based, so this assumes the frame has a default RangeIndex where
# labels equal row positions — TODO confirm. Whether each end index is inside the window or one
# past it depends on how windows are sliced downstream — verify.
end_idxs = list(range(dm_test.window_size,
                 len(dm_test.swats_train.dataset.df),
                 dm_test.window_size))
# 'Timestamp' value at each window end position.
timestamps = dm_test.swats_train.dataset.df.loc[end_idxs, 'Timestamp']
# Accumulator, presumably filled by the code that follows this cell's visible part.
arrays = []