In [None]:
# default stuff (display width, dir change, jupyter extentions)
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))
import os
os.chdir('..')
%load_ext autoreload
%autoreload 2

In [None]:
import pandas as pd
import anodeclstmgru.constants as const
import os
from anodeclstmgru.models.lit_module import AutoEncoderLitModule
from anodeclstmgru.data.data_module import SWaTSDataModule
from anodeclstmgru.data.dataset import SWaTSDataset
import yaml
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import plotly.express as px
from ipywidgets import interact
import torch
import numpy as np
from sklearn.metrics import mean_squared_error
from datetime import datetime
from tqdm.notebook import tqdm

# Load model and configs

In [None]:
# Version directory under ./lightning_logs to load, and the value stored under
# the 'test_set_step_size' hyper-parameter key.
MODEL_VERSION = 14
TEST_SET_STEP_SIZE = 10

# Read the hyper-parameters saved for this version.
hparams_path = f'./lightning_logs/version_{MODEL_VERSION}/hparams.yaml'
with open(hparams_path, 'r') as stream:
    hparam_dct = yaml.safe_load(stream)
hparam_dct.update(dict(test_set_step_size=TEST_SET_STEP_SIZE))

# os.listdir returns entries in arbitrary order, so sort before picking the
# first one to make the checkpoint selection deterministic.
ckpt_dir = f'./lightning_logs/version_{MODEL_VERSION}/checkpoints/'
ckpt_file_names = sorted(os.listdir(ckpt_dir))
if not ckpt_file_names:
    raise FileNotFoundError(f'No checkpoint file found in {ckpt_dir}')
ckpt_file_name = ckpt_file_names[0]
ckpt_file_path = f'{ckpt_dir}{ckpt_file_name}'
model = AutoEncoderLitModule.load_from_checkpoint(ckpt_file_path)

# Load training samples and predict output

In [None]:
# Remove the key again. The default keeps the cell safe to re-run: without it
# a second execution raises KeyError because the key is already gone.
hparam_dct.pop('test_set_step_size', None)

In [None]:
# Build the data module from the hyper-parameter dict and run its setup step.
dm = SWaTSDataModule(**hparam_dct)
dm.setup()

In [None]:
# Display the data module's test_set_step_size attribute as the cell output.
dm.test_set_step_size

In [None]:
# Fetch the first batch of the training loader and run it through the model.
train_data_loader = dm.train_dataloader()
# Python 3 iterators are only guaranteed to implement __next__, not a
# `.next()` method; the built-in next() works with every iterator.
batch_in = next(iter(train_data_loader))
batch_out = model(batch_in)

In [None]:
# Run the model on the batch and tabulate the output of sample 0, with the
# columns named by const.SENSOR_COLS.
batch_out = model(batch_in)
first_sample_out = batch_out[0, :, :].detach().numpy()
df_out = pd.DataFrame(first_sample_out, columns=const.SENSOR_COLS)

In [None]:
def get_dfs(idx):
    """Return one batch sample and its model output as data frames.

    Reads the module-level ``batch_in`` and ``batch_out`` objects.

    Args:
        idx: Index of the sample along the first axis of the batch.

    Returns:
        Tuple ``(df_in, df_out)``; both frames use ``const.SENSOR_COLS`` as
        column names.
    """
    arrays = (
        batch_in[idx, :, :].numpy(),
        batch_out[idx, :, :].detach().numpy(),
    )
    frame_in, frame_out = (
        pd.DataFrame(arr, columns=const.SENSOR_COLS) for arr in arrays
    )
    return frame_in, frame_out


def plot_ts_and_reconstruction(signal='AIT203', training_sample_idx=0):
    """Plot one signal of a batch sample together with its reconstruction.

    Args:
        signal: Name of the column to plot from both data frames.
        training_sample_idx: Sample index forwarded to ``get_dfs``.
    """
    original, reconstruction = get_dfs(training_sample_idx)
    fig = make_subplots(rows=1, cols=1, shared_xaxes=True)
    # One trace per frame, both drawn into the single subplot.
    for label, frame in (('orig.', original), ('reconstr.', reconstruction)):
        trace = go.Scatter(x=frame.index, y=frame[signal], name=f'{signal}_{label}')
        fig.add_trace(trace, row=1, col=1)
    fig.update_layout(
        height=600,
        width=800,
        title_text=f'{signal} original and reproduction over time (Sample {training_sample_idx} of first training batch).',
    )
    fig.show()


# Offer every sample index of the loaded batch instead of a hard-coded batch
# size of 32, so the selector stays valid for other batch sizes.
interact(plot_ts_and_reconstruction, signal=const.SENSOR_COLS, training_sample_idx=list(range(batch_in.shape[0])))

# Alright, now what about the reconstruction error?

In [None]:
def get_mse(idx, input_array):
    """Compute the reconstruction MSE of a single sample.

    Sample ``idx`` of ``input_array`` is passed through the module-level
    ``model`` as a batch of size one, and the output is compared with the input.

    Args:
        idx: Index of the sample along the first axis of ``input_array``.
        input_array: Array of samples, indexed as ``[sample, :, :]``.

    Returns:
        The ``mean_squared_error`` between the sample and the model output.
    """
    sample = input_array[idx, :, :]
    # Add a leading batch axis of size one before handing the sample to the model.
    batch = torch.tensor(sample.reshape(-1, input_array.shape[1], input_array.shape[2]))
    # Inference only: skip autograd bookkeeping instead of building a graph
    # for every sample.
    with torch.no_grad():
        out = model(batch).detach().numpy()[0, :, :]
    return mean_squared_error(sample, out)

In [None]:
# only run first time
# Reconstruction error of every test sample, plus the test set's timestamps.
test_samples = dm.swats_test.samples
mse_list = []
for sample_idx in tqdm(list(range(test_samples.shape[0]))):
    mse_list.append(get_mse(sample_idx, test_samples))
timestamps = dm.swats_test.timestamps

In [None]:
# store predictions
# only run first time
df_errors = pd.DataFrame(dict(mse=mse_list, timestamp=timestamps))
# The context manager closes the HDF store even if the write raises.
with pd.HDFStore(const.HDF_STORE_PATH_PREPROC) as store:
    store[f'df_errors_{MODEL_VERSION}'] = df_errors

In [None]:
# Load the stored errors of this model version; the context manager closes
# the HDF store even if the read raises.
with pd.HDFStore(const.HDF_STORE_PATH_PREPROC) as store:
    df_errors = store[f'df_errors_{MODEL_VERSION}']

In [None]:
def create_results_plot(threshold=1):
    """Plot the MSE, the thresholded anomaly prediction and the labelled attacks.

    Draws three stacked subplots with a shared x axis: the MSE over time, the
    boolean ``mse > threshold`` prediction, and one line segment per labelled
    attack spanning its start and end time.

    Side effect: writes the boolean column ``anomaly_predicted`` into the
    module-level ``df_errors`` frame.

    Args:
        threshold: MSE value above which a sample is flagged as an anomaly.
    """
    # load labels data frame from h5 store; the context manager closes the
    # store even if the read raises
    with pd.HDFStore(const.HDF_STORE_PATH_INTERIM) as store:
        df_labels = store['df_labels']
    # filter labels df to the attacks that have an end time attached and
    # transform the end time to a full timestamp (start date + end time)
    df_labels_time = df_labels[df_labels['End Time'].notna()].copy()
    df_labels_time.loc[:, 'End Time'] = [datetime.combine(datetime.date(a), b) for a, b in zip(
        df_labels_time['Start Time'], df_labels_time['End Time'])]
    # ok, lets remove everything smaller than min_date and larger than max date...
    in_range = ((df_labels_time['Start Time'] > const.MIN_DATE) &
                (df_labels_time['Start Time'] < const.MAX_DATE))
    df_plot_label = df_labels_time[in_range]

    # vectorised comparison instead of a per-element Python loop
    df_errors['anomaly_predicted'] = df_errors.mse > threshold
    fig = make_subplots(rows=3, cols=1, shared_xaxes=True)
    fig.add_trace(go.Scatter(x=df_errors.timestamp, y=df_errors.mse, name='mse'), row=1, col=1)
    fig.add_trace(go.Scatter(x=df_errors.timestamp, y=df_errors.anomaly_predicted,
                             name='Anomaly predicted'), row=2, col=1)
    # iterate the label columns directly; no per-iteration index reset needed
    for start, end, attack in zip(df_plot_label['Start Time'],
                                  df_plot_label['End Time'],
                                  df_plot_label['Attack #']):
        fig.add_trace(
            go.Scatter(x=[start, end],
                       y=[1, 1], name=f'Attack #{attack}'),
            row=3, col=1,
        )
    # NOTE(review): 300 and 10 are hard-coded here; presumably window length
    # and step size -- confirm and derive from the configuration.
    title = f'MSE, anomaly predictions and real anomalies over time. Sample window size: {300*10/60} min'
    fig.update_layout(title_text=title)
    fig.show()

In [None]:
# Interactive results plot; the threshold is chosen from 89 evenly spaced
# values between 0 and 2.
interact(create_results_plot, threshold=np.linspace(0,2,89))

In [None]:
# NOTE(review): identical to the interact call in the cell directly above;
# this creates a second, independent instance of the same widget.
interact(create_results_plot, threshold=np.linspace(0,2,89))