In [None]:
import os
# Move the working directory one level up; the relative paths used later
# ('./logs/...', 'data/...', './notebooks/...') are resolved from there.
# NOTE(review): not idempotent — re-running this cell climbs one more level.
os.chdir('..')

In [None]:
import diag_vae.constants as const
import pandas as pd
from diag_vae.vanilla_tcn_ae import VanillaTcnAE
from diag_vae.diag_tcn_ae import DiagTcnAE
from diag_vae.diag_tcn_ae_predictor import DiagTcnAePredictor
from diag_vae.swat_data_module import SwatDataset
import matplotlib.pyplot as plt
import torch
import numpy as np
import plotly
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import tqdm

In [None]:
# Load the trained VanillaTcnAE from its checkpoint directory.
loggs_dir = './logs/VanillaTcnAE/version_6/checkpoints/'
# os.listdir() returns entries in arbitrary order; sorting makes the choice
# deterministic should the directory ever hold more than one checkpoint.
checkpoint = os.path.join(loggs_dir, sorted(os.listdir(loggs_dir))[0])
# The notebook only runs inference, so leave training mode right away
# (eval() returns the module itself, so nothing extra is displayed).
ae_model = VanillaTcnAE.load_from_checkpoint(checkpoint).eval()

In [None]:
# Disabled cell: would load the DiagTcnAePredictor checkpoint into `diag_predictor_model`.
# NOTE(review): a later cell still calls `diag_predictor_model.encode(...)` / `.decode(...)`;
# while this cell stays commented out that name is never defined in the notebook.
# loggs_dir = './logs/DiagTcnAePredictor/version_0/checkpoints/'
# checkpoint = os.path.join(loggs_dir, f'{os.listdir(loggs_dir)[0]}')
# diag_predictor_model = DiagTcnAePredictor.load_from_checkpoint(checkpoint)

In [None]:
# Load the trained DiagTcnAE from its checkpoint directory.
loggs_dir = './logs/DiagTcnAE/version_8/checkpoints/'
# os.listdir() returns entries in arbitrary order; sorting makes the choice
# deterministic should the directory ever hold more than one checkpoint.
checkpoint = os.path.join(loggs_dir, sorted(os.listdir(loggs_dir))[0])
# The notebook only runs inference, so leave training mode right away
# (eval() returns the module itself, so nothing extra is displayed).
diag_ae_model = DiagTcnAE.load_from_checkpoint(checkpoint).eval()

In [None]:
# Index of the dataset item that is encoded, reconstructed and plotted below.
# SAMPLE_IDX = 10500
SAMPLE_IDX = 500

# Dataset restricted to split='test', built from the SWaT train/val files.
ds = SwatDataset(
    val_data_path=const.SWAT_VAL_PATH,
    train_data_path=const.SWAT_TRAIN_PATH,
    seq_len_x=500,
    seq_len_y=100,
    cols=const.SWAT_SENSOR_COLS,
    symbols_dct=const.SWAT_SYMBOLS_MAP,
    split='test',
    val_ts_start="2015-12-29 19:00:00",
    val_ts_end="2015-12-29 22:00:00",
)
# Subscript access instead of calling the dunder method directly.
x, x_comp_list, y_comp_list = ds[SAMPLE_IDX]

In [None]:
# Encode and decode the sample with the vanilla auto-encoder.
# Inference only: run under no_grad like the predict_step cells further down,
# so no autograd graph is kept alive for these tensors.
with torch.no_grad():
    z_ae = ae_model.encode(x)
    x_ae_hat = ae_model.decode(z_ae)

In [None]:
# Encode and decode the same sample with the DiagTcnAE; decode() yields a
# list of outputs (it is iterated over in the plotting cell below).
# Inference only: run under no_grad like the predict_step cells further down.
with torch.no_grad():
    z_diag_ae = diag_ae_model.encode(x)
    x_diag_ae_hat_ls = diag_ae_model.decode(z_diag_ae)

In [None]:
# Show the latent shapes produced by the two encoders side by side.
z_diag_ae.shape, z_ae.shape

In [None]:
# Overlay each of the 51 input rows with its reconstruction:
# left column = vanilla AE, right column = DiagTcnAE.
x_np = x.numpy()
x_ae_hat_np = x_ae_hat.detach().numpy()
# Stack the per-component DiagTcnAE outputs once; this does not depend on
# the row being plotted, so it is hoisted out of the loop.
x_diag_ae_hat_np = torch.concat(
    [x_.reshape(-1, 500) for x_ in x_diag_ae_hat_ls], dim=0
).detach().numpy()

fig, axs = plt.subplots(51, 2, figsize=(15, 50))
for i in range(51):
    axs[i, 0].plot(x_np[i, :])
    axs[i, 0].plot(x_ae_hat_np[0, i, :])
    axs[i, 0].set_ylim((-15, 15))
    axs[i, 1].plot(x_np[i, :])
    axs[i, 1].plot(x_diag_ae_hat_np[i, :])
    axs[i, 1].set_ylim((-15, 15))
# plt.show() works with the notebook's inline backend; Figure.show() is meant
# for GUI backends and only emits a warning there.
plt.show()

### Get the reconstruction error over a longer period of time

In [None]:
# Raw validation frame, read straight from the parquet file.
val_df = pd.read_parquet(const.SWAT_VAL_PATH)
# val_df.head()

In [None]:
# The index of val_df as a Series: values are the index entries, and the
# Series' own index is the default 0..n-1 positions.
ts_idx_series = pd.Series(val_df.index)
ts_idx_series.head()

In [None]:
# Do not truncate long cell contents when frames are displayed (global pandas option).
pd.set_option('display.max_colwidth', None)
# Label table; the plotting cells read its 'Start Time', 'End Time', 'Attack #'
# and 'Attack Point' columns.
df_label = pd.read_csv(const.SWAT_LABEL_PATH)
df_label.head(25)

# Batch prediction

In [None]:
from diag_vae.swat_data_module import SwatDataModule
# Data module that supplies the loaders for the batched error computation.
dm = SwatDataModule(
    data_path_val=const.SWAT_VAL_PATH,
    data_path_train=const.SWAT_TRAIN_PATH,
    seq_len_x=500,
    seq_len_y=100,
    cols=const.SWAT_SENSOR_COLS,
    symbols_dct=const.SWAT_SYMBOLS_MAP,
    val_ts_start="2015-12-29 19:00:00",
    val_ts_end="2015-12-29 22:00:00",
    batch_size=100,
    dl_workers=8,
)
# Created in the order val, train, test (tuple elements evaluate left to right).
dl_val, dl_train, dl_test = (
    dm.val_dataloader(),
    dm.train_dataloader(),
    dm.test_dataloader(),
)

In [None]:
# Run predict_step on the first item of the test dataset (taken directly from
# the dataset, not from the loader), with gradient tracking disabled.
with torch.no_grad():
    res = ae_model.predict_step(dl_test.dataset[0], batch_idx=None)
res

In [None]:
# NOTE(review): skip_len is not referenced anywhere else in the visible notebook.
skip_len = 100

In [None]:
# Every tenth integer below 100: [0, 10, ..., 90].
list(range(0, 100, 10))

In [None]:
%%time
# One predict_step result per test batch and per model.
ae_model_mse_ls_ls, diag_ae_model_mse_ls_ls = [], []
# Stays empty here: the predictor model is not evaluated in this loop.
diag_pred_model_mse_ls_ls = []
# Ends up holding the number of batches that were processed.
counter = 0
for counter, batch in enumerate(tqdm.tqdm(iter(dl_test)), start=1):
    with torch.no_grad():
        ae_batch_res = ae_model.predict_step(batch, batch_idx=None)
        diag_ae_batch_res = diag_ae_model.predict_step(batch, batch_idx=None)
    ae_model_mse_ls_ls.append(ae_batch_res)
    diag_ae_model_mse_ls_ls.append(diag_ae_batch_res)

In [None]:
# One column per component (6 in total); each column joins that component's
# entry from every batch result.
ae_columns = {}
for comp_idx in range(6):
    per_batch = [batch_res[comp_idx] for batch_res in ae_model_mse_ls_ls]
    ae_columns[f'ae_mse_comp{comp_idx+1}'] = np.concatenate(per_batch)
ae_mse_df = pd.DataFrame(ae_columns)
# Label the rows with the leading index entries of the test dataset's frame.
ae_mse_df.index = dl_test.dataset.df.index[:len(ae_mse_df)]

In [None]:
# One column per component (6 in total); each column joins that component's
# entry from every batch result.
diag_ae_columns = {}
for comp_idx in range(6):
    per_batch = [batch_res[comp_idx] for batch_res in diag_ae_model_mse_ls_ls]
    diag_ae_columns[f'diag_ae_mse_comp{comp_idx+1}'] = np.concatenate(per_batch)
diag_ae_mse_df = pd.DataFrame(diag_ae_columns)
# Label the rows with the leading index entries of the test dataset's frame.
diag_ae_mse_df.index = dl_test.dataset.df.index[:len(diag_ae_mse_df)]

In [None]:
# Both models' error columns side by side, aligned on the index (outer join).
frames_to_join = [ae_mse_df, diag_ae_mse_df]
result_df = pd.concat(frames_to_join, join='outer', axis=1)
# Move every index entry ten minutes forward.
shifted_index = result_df.index + pd.Timedelta(minutes=10)
result_df.index = shifted_index

In [None]:
# Persist the combined error table (path is relative to the working directory).
result_df.to_parquet('data/results_df_v6_v8.parquet')

In [None]:
# Rolling *median* (despite the variable name). With window=1 each window holds
# a single row, so the values come out unchanged.
rolling_mean_result_df = result_df.rolling(window=1).median()

In [None]:
# Plot a random 10 % of the rows; the fixed seed makes the figure identical
# between runs.
df_plot_results = rolling_mean_result_df.sample(frac=.1, random_state=0).sort_index()
# All label rows are used (the 'Actual Change' filter is disabled).
df_plot_label = df_label[
    :
    # df_label['Actual Change']=='Yes'
].reset_index(drop=True)
# Rows 1-6: one panel per component. Row 7: the labelled attack intervals.
fig = make_subplots(rows=7, cols=1, shared_xaxes=True)
THRESHOLD=2.5

# One (column prefix, colour) pair per plotted model. Pairing them explicitly
# avoids relying on zip() silently dropping a surplus colour.
model_styles = [
    ('ae_mse', 'blue'),
    ('diag_ae_mse', 'red'),
    # ('diag_pred_mse', 'green'),
]

for comp in range(1, 7):
    for prefix, color in model_styles:
        column = f'{prefix}_comp{comp}'
        fig.add_trace(
            go.Scatter(
                x=df_plot_results.index,
                y=df_plot_results[column],
                name=column,
                # mode='markers',
                marker=dict(color=color, opacity=.5),
            ),
            row=comp, col=1,
        )
    # Green band from 0 up to THRESHOLD on this component's panel.
    fig.add_hrect(y0=0, y1=THRESHOLD,
                  fillcolor="green", opacity=0.25, line_width=0,
                  row=comp, col=1)

# Bottom panel: one shaded span plus one short line per label row.
for i in range(len(df_plot_label)):
    start = df_plot_label.loc[i, 'Start Time']
    end = df_plot_label.loc[i, 'End Time']
    attack_no = df_plot_label.loc[i, 'Attack #']
    attack_point = df_plot_label.loc[i, 'Attack Point']

    fig.add_vrect(x0=start, x1=end,
                  annotation_text=f'At. Point {attack_point}',
                  fillcolor="red", opacity=0.25, line_width=0,
                  row=7, col=1)

    fig.add_trace(
        go.Scatter(x=[start, end],
                   y=[1, 1], name=f'Attack #{attack_no}'),
        row=7, col=1,
    )

fig.update_layout(height=800, width=1200, title_text='Model versions: AE 6, Custom 8')
fig.show()

In [None]:
# Export the interactive figure as a standalone HTML file.
fig.write_html('./notebooks/result_v6_v8.html')

In [None]:
from diag_vae.swat_data_module import SwatDataModule

In [None]:
# Second data module: batch_size=1000 and no val_ts_start / val_ts_end.
# This rebinds both `dm` and `dl_val` from the earlier batch-prediction cell.
dm = SwatDataModule(
    data_path_val=const.SWAT_VAL_PATH,
    data_path_train=const.SWAT_TRAIN_PATH,
    cols=const.SWAT_SENSOR_COLS,
    symbols_dct=const.SWAT_SYMBOLS_MAP,
    seq_len_x=500,
    seq_len_y=100,
    dl_workers=8,
    batch_size=1000,
)
dl_val = dm.val_dataloader()

In [None]:
# Pull the first batch from the validation loader; this rebinds `x`.
x, x_comp_ls, y_comp_ls = next(iter(dl_val))

In [None]:
# Shape of the first entry of y_comp_ls for that batch.
y_comp_ls[0].shape

In [None]:
# NOTE(review): the variable is named like a reconstruction, but it holds the return
# value of get_sample_component_recon_error — presumably per-component errors; confirm.
x_comp_hat_ls = diag_ae_model.get_sample_component_recon_error(x, x_comp_ls)

In [None]:
# Shape of the second entry of the result above.
x_comp_hat_ls[1].shape

In [None]:
# Per-batch component errors on the validation loader, for the first batches only.
ae_model_mse_ls_ls = []
diag_ae_model_mse_ls_ls = []
# Stays empty while the predictor call below is commented out.
diag_pred_model_mse_ls_ls = []
counter = 0
for batch in tqdm.tqdm(iter(dl_val)):
    # counter runs 0..10 before the break, i.e. 11 batches are processed.
    if counter > 10:
        break
    x, x_comp_ls, y_comp_ls = batch
    # Inference only: disable gradient tracking, as the test-loader loop does.
    with torch.no_grad():
        ae_model_mse_ls_ls.append(ae_model.get_sample_component_recon_error(x))
        # diag_pred_model_mse_ls_ls.append(diag_predictor_model.get_sample_component_recon_error(x, y_comp_ls))
        diag_ae_model_mse_ls_ls.append(diag_ae_model.get_sample_component_recon_error(x, x_comp_ls))
    counter+=1



In [None]:
# One column per component (6 in total); each column joins that component's
# entry from every batch result.
ae_columns = {}
for comp_idx in range(6):
    per_batch = [batch_res[comp_idx] for batch_res in ae_model_mse_ls_ls]
    ae_columns[f'ae_mse_comp{comp_idx+1}'] = np.concatenate(per_batch)
ae_mse_df = pd.DataFrame(ae_columns)
# Label the rows with the leading index entries of the validation dataset's frame.
ae_mse_df.index = dl_val.dataset.df.index[:len(ae_mse_df)]

In [None]:
# diag_pred_model_mse_ls_ls can be empty (the loop that would fill it has the
# predictor call commented out), and np.concatenate([]) raises ValueError.
# Fall back to zero-length columns so the frame still carries the expected
# column names and the later concat/plot cells keep working (values become NaN).
diag_pred_mse_df = pd.DataFrame({
    f'diag_pred_mse_comp{i+1}': (
        np.concatenate([l[i] for l in diag_pred_model_mse_ls_ls])
        if diag_pred_model_mse_ls_ls
        else np.empty(0)
    )
    for i in range(6)
})
# For an empty frame this selects an empty slice of the dataset index.
diag_pred_mse_df.index = dl_val.dataset.df.index[0:len(diag_pred_mse_df)]

In [None]:
# One column per component (6 in total); each column joins that component's
# entry from every batch result.
diag_ae_columns = {}
for comp_idx in range(6):
    per_batch = [batch_res[comp_idx] for batch_res in diag_ae_model_mse_ls_ls]
    diag_ae_columns[f'diag_ae_mse_comp{comp_idx+1}'] = np.concatenate(per_batch)
diag_ae_mse_df = pd.DataFrame(diag_ae_columns)
# Label the rows with the leading index entries of the validation dataset's frame.
diag_ae_mse_df.index = dl_val.dataset.df.index[:len(diag_ae_mse_df)]

In [None]:
# All three models' error columns side by side, aligned on the index (outer join).
frames_to_join = [ae_mse_df, diag_ae_mse_df, diag_pred_mse_df]
result_df = pd.concat(frames_to_join, join='outer', axis=1)
# Move every index entry ten minutes forward.
shifted_index = result_df.index + pd.Timedelta(minutes=10)
result_df.index = shifted_index

In [None]:
# Rolling *median* over 120 rows (the variable name says "mean").
rolling_mean_result_df = result_df.rolling(window=120).median()

In [None]:
# Inspect the last rows of the smoothed table.
rolling_mean_result_df.tail()

In [None]:
# Random 1 % of the smoothed rows, back in index order (unseeded, so it differs per run).
# df_plot_results is reassigned with frac=.1 in the plotting cell below.
df_plot_results = rolling_mean_result_df.sample(frac=.01).sort_index()
df_plot_results.head()

In [None]:
# List the column names available for plotting.
result_df.columns

In [None]:
# Plot a random 10 % of the rows; the fixed seed makes the figure identical
# between runs.
df_plot_results = rolling_mean_result_df.sample(frac=.1, random_state=0).sort_index()
# All label rows are used (the 'Actual Change' filter is disabled).
df_plot_label = df_label[
    :
    # df_label['Actual Change']=='Yes'
].reset_index(drop=True)
# Rows 1-6: one panel per component. Row 7: the labelled attack intervals.
fig = make_subplots(rows=7, cols=1, shared_xaxes=True)
THRESHOLD=2.5

# One (column prefix, colour) pair per model.
model_styles = [
    ('ae_mse', 'blue'),
    ('diag_ae_mse', 'red'),
    ('diag_pred_mse', 'green'),
]

for comp in range(1, 7):
    for prefix, color in model_styles:
        column = f'{prefix}_comp{comp}'
        # A model that was not evaluated has no column; skip it instead of
        # failing with KeyError.
        if column not in df_plot_results.columns:
            continue
        fig.add_trace(
            go.Scatter(
                x=df_plot_results.index,
                y=df_plot_results[column],
                name=column,
                # mode='markers',
                marker=dict(color=color, opacity=.5),
            ),
            row=comp, col=1,
        )
    # Green band from 0 up to THRESHOLD on this component's panel.
    fig.add_hrect(y0=0, y1=THRESHOLD,
                  fillcolor="green", opacity=0.25, line_width=0,
                  row=comp, col=1)

# Bottom panel: one shaded span plus one short line per label row.
for i in range(len(df_plot_label)):
    start = df_plot_label.loc[i, 'Start Time']
    end = df_plot_label.loc[i, 'End Time']
    attack_no = df_plot_label.loc[i, 'Attack #']
    attack_point = df_plot_label.loc[i, 'Attack Point']

    fig.add_vrect(x0=start, x1=end,
                  annotation_text=f'At. Point {attack_point}',
                  fillcolor="red", opacity=0.25, line_width=0,
                  row=7, col=1)

    fig.add_trace(
        go.Scatter(x=[start, end],
                   y=[1, 1], name=f'Attack #{attack_no}'),
        row=7, col=1,
    )

fig.update_layout(height=800, width=1200, title_text='')
fig.show()

In [None]:
# Show the first 30 label rows used for the attack spans.
df_plot_label.head(30)

In [None]:
%%time
# get sample input for ts
# Per-item errors for every 10th position of ts_idx_series (minus the last 100 of those):
# the vanilla AE reconstruction error and the predictor's per-component error.
# NOTE(review): SwatDataset is called with `data_path=` / `scale=` here, while the first
# SwatDataset call in this notebook uses `val_data_path=` / `train_data_path=` / `split=`
# — confirm this signature is still supported.
ds_val = SwatDataset( 
    data_path=const.SWAT_VAL_PATH,
    seq_len_x=500,
    seq_len_y=100,
    cols=const.SWAT_SENSOR_COLS,
    symbols_dct=const.SWAT_SYMBOLS_MAP,
    scale=True
)


mse_ls_ae_all = []       # one scalar per item: mean squared error over all elements
mse_ls_ae_signals = []   # one array per item: squared error averaged over axis 2
mse_y_comp_hat_ls = []   # never appended to in this cell

# One list per component (6 in total), filled with one scalar per item.
diag_ae_pred_comp_1_mse = []
diag_ae_pred_comp_2_mse = []
diag_ae_pred_comp_3_mse = []
diag_ae_pred_comp_4_mse = []
diag_ae_pred_comp_5_mse = []
diag_ae_pred_comp_6_mse = []



# NOTE(review): only `import tqdm` is visible in this notebook; `tqdm.notebook` resolves
# only if that submodule has been imported somewhere — confirm.
for idx in tqdm.notebook.tqdm(list(ts_idx_series.index)[0::10][:-100]):
    x, _, y_comp_ls = ds_val.__getitem__(idx)
    z_ae = ae_model.encode(x)
    x_ae_hat = ae_model.decode(z_ae)
    # z_diag = diag_model.encode(x)
    # x_comp_hat_ls = diag_model.decode(z_diag)
    # x_comp_hat = torch.cat(x_comp_hat_ls, 1)
    # NOTE(review): `diag_predictor_model` is not defined in the visible notebook — the
    # cell that loads it is commented out — so this raises NameError on a fresh kernel.
    z_ae_predictor = diag_predictor_model.encode(x)
    x_ae_predictor_hat, y_comp_hat_ls = diag_predictor_model.decode(z_ae_predictor)
    
    mse_ls_ae_all.append(((x.numpy() - x_ae_hat.detach().numpy())**2).mean())
    mse_ls_ae_signals.append(((x.numpy() - x_ae_hat.detach().numpy())**2).mean(axis=2))
    # _, _, y_loss_ls, _ = diag_predictor_model.shared_eval(x.reshape(1, *x.shape), [y.reshape(1, *y.shape) for y in y_comp_list])
    # for y_loss, diag_ae_pred_comp_mse in zip(y_loss_ls, diag_ae_pred_comp_mse_ls):
    #     diag_ae_pred_comp_mse.append(float(y_loss.detach().numpy()))
    # Mean squared difference between each target component and its prediction.
    diag_ae_pred_comp_1_mse.append(((y_comp_ls[0].numpy() - y_comp_hat_ls[0].detach().numpy())**2).mean())
    diag_ae_pred_comp_2_mse.append(((y_comp_ls[1].numpy() - y_comp_hat_ls[1].detach().numpy())**2).mean())
    diag_ae_pred_comp_3_mse.append(((y_comp_ls[2].numpy() - y_comp_hat_ls[2].detach().numpy())**2).mean())
    diag_ae_pred_comp_4_mse.append(((y_comp_ls[3].numpy() - y_comp_hat_ls[3].detach().numpy())**2).mean())
    diag_ae_pred_comp_5_mse.append(((y_comp_ls[4].numpy() - y_comp_hat_ls[4].detach().numpy())**2).mean())
    diag_ae_pred_comp_6_mse.append(((y_comp_ls[5].numpy() - y_comp_hat_ls[5].detach().numpy())**2).mean())
        
        

# Compute component-wise MSE based on the AE results.
# Consecutive column ranges of width 5, 11, 9, 9, 13 and 4 (51 columns in total) are
# averaged — presumably the signals belonging to components 1-6; TODO confirm the grouping.
ae_comp_1_mse = np.array([mse[:, 0:5].mean(axis=1) for mse in mse_ls_ae_signals])
ae_comp_2_mse = np.array([mse[:, 5:5+11].mean(axis=1) for mse in mse_ls_ae_signals])
ae_comp_3_mse = np.array([mse[:, 5+11:5+11+9].mean(axis=1) for mse in mse_ls_ae_signals])
ae_comp_4_mse = np.array([mse[:, 5+11+9:5+11+9+9].mean(axis=1) for mse in mse_ls_ae_signals])
ae_comp_5_mse = np.array([mse[:, 5+11+9+9:5+11+9+9+13].mean(axis=1) for mse in mse_ls_ae_signals])
ae_comp_6_mse = np.array([mse[:, 5+11+9+9+13:5+11+9+9+13+4].mean(axis=1) for mse in mse_ls_ae_signals])

# get 

In [None]:
# Column order: the six AE component errors first, then the six predictor ones.
results_columns = {
    'ae_comp_1_mse': ae_comp_1_mse.ravel(),
    'ae_comp_2_mse': ae_comp_2_mse.ravel(),
    'ae_comp_3_mse': ae_comp_3_mse.ravel(),
    'ae_comp_4_mse': ae_comp_4_mse.ravel(),
    'ae_comp_5_mse': ae_comp_5_mse.ravel(),
    'ae_comp_6_mse': ae_comp_6_mse.ravel(),
    'diag_ae_pred_comp_1_mse': diag_ae_pred_comp_1_mse,
    'diag_ae_pred_comp_2_mse': diag_ae_pred_comp_2_mse,
    'diag_ae_pred_comp_3_mse': diag_ae_pred_comp_3_mse,
    'diag_ae_pred_comp_4_mse': diag_ae_pred_comp_4_mse,
    'diag_ae_pred_comp_5_mse': diag_ae_pred_comp_5_mse,
    'diag_ae_pred_comp_6_mse': diag_ae_pred_comp_6_mse,
}
# Same selection of entries as the loop that produced the errors: every 10th,
# without the last 100 of those.
df_results = pd.DataFrame(results_columns, index=ts_idx_series[0::10][:-100])
# Index entries are moved ten minutes forward ("due to model inference", per the
# original note).
df_results.index = df_results.index + pd.Timedelta(minutes=10)

In [None]:
# Display the per-item result table (pandas truncates long frames on display).
df_results

In [None]:
# Rolling mean over 100 rows.
# NOTE(review): df_results_rolling is not referenced again in the visible notebook.
df_results_rolling = df_results.rolling(window=100).mean()

In [None]:
# Only label rows whose 'Actual Change' column equals 'Yes' are drawn.
df_plot_label = df_label[
    # :
    df_label['Actual Change']=='Yes'
].reset_index(drop=True)
# Rows 1-6: one panel per component. Row 7: the labelled attack intervals.
fig = make_subplots(rows=7, cols=1, shared_xaxes=True)
ae_comp_mse_ls = [ae_comp_1_mse, ae_comp_2_mse, ae_comp_3_mse, ae_comp_4_mse, ae_comp_5_mse, ae_comp_6_mse]
diag_ae_pred_comp_mse_ls = [diag_ae_pred_comp_1_mse, diag_ae_pred_comp_2_mse, diag_ae_pred_comp_3_mse,
                            diag_ae_pred_comp_4_mse, diag_ae_pred_comp_5_mse, diag_ae_pred_comp_6_mse]

# x values shared by every error trace (same selection as used to compute the errors).
plot_ts = ts_idx_series[0::10][0:-100]

# Vanilla AE errors: red markers.
for row, ae_mse in enumerate(ae_comp_mse_ls, start=1):
    fig.add_trace(
        go.Scatter(x=plot_ts,
                   y=ae_mse.reshape(-1), name=f'AE MSE comp {row}',
                   mode='markers',
                   marker=dict(color='red'),
                   ),
        row=row, col=1,
    )
# Predictor errors: blue. These traces used to carry the AE legend label as
# well, which made the two models indistinguishable in the legend.
for row, diag_ae_mse in enumerate(diag_ae_pred_comp_mse_ls, start=1):
    fig.add_trace(
        go.Scatter(x=plot_ts,
                   y=diag_ae_mse, name=f'Diag AE pred MSE comp {row}',
                   marker=dict(color='blue'),
                   ),
        row=row, col=1,
    )

# Bottom panel: one shaded span plus one short line per label row.
for i in range(len(df_plot_label)):
    start = df_plot_label.loc[i, 'Start Time']
    end = df_plot_label.loc[i, 'End Time']
    attack_no = df_plot_label.loc[i, 'Attack #']
    attack_point = df_plot_label.loc[i, 'Attack Point']

    fig.add_vrect(x0=start, x1=end,
                  annotation_text=f'At. Point {attack_point}',
                  # annotation_position="top left",
                  # annotation=dict(font_size=8),
                  fillcolor="red", opacity=0.25, line_width=0,
                  row=7, col=1)
    fig.add_trace(
        go.Scatter(x=[start, end],
                   y=[1, 1], name=f'Attack #{attack_no}'),
        row=7, col=1,
    )

fig.update_layout(height=800, width=1800, title_text='')
fig.show()

In [None]:
# All label rows are drawn in this variant.
df_plot_label = df_label.reset_index(drop=True)
# Column 1: vanilla AE errors. Column 2: predictor errors.
# Rows 1-6: components. Row 7: the labelled attack intervals.
fig = make_subplots(rows=7, cols=2, shared_xaxes=True)
ae_comp_mse_ls = [ae_comp_1_mse, ae_comp_2_mse, ae_comp_3_mse, ae_comp_4_mse, ae_comp_5_mse, ae_comp_6_mse]
diag_ae_pred_comp_mse_ls = [diag_ae_pred_comp_1_mse, diag_ae_pred_comp_2_mse, diag_ae_pred_comp_3_mse,
                            diag_ae_pred_comp_4_mse, diag_ae_pred_comp_5_mse, diag_ae_pred_comp_6_mse]

# x values shared by every error trace (same selection as used to compute the errors).
plot_ts = ts_idx_series[0::10][0:-100]

# (subplot column, legend prefix, one y-series per component).
# The predictor column used to reuse the 'AE MSE comp' legend label; it now
# has its own so the two models can be told apart.
panels = [
    (1, 'AE MSE comp', [ae_mse.reshape(-1) for ae_mse in ae_comp_mse_ls]),
    (2, 'Diag AE pred MSE comp', diag_ae_pred_comp_mse_ls),
]

for col, label, comp_series in panels:
    for row, comp_mse in enumerate(comp_series, start=1):
        fig.add_trace(
            go.Scatter(x=plot_ts,
                       y=comp_mse, name=f'{label} {row}'),
            row=row, col=col,
        )

    # Bottom panel of this column: one shaded span plus one short line per label row.
    for i in range(len(df_plot_label)):
        start = df_plot_label.loc[i, 'Start Time']
        end = df_plot_label.loc[i, 'End Time']
        attack_no = df_plot_label.loc[i, 'Attack #']
        attack_point = df_plot_label.loc[i, 'Attack Point']

        fig.add_vrect(x0=start, x1=end,
                      annotation_text=f'At. Point {attack_point}',
                      # annotation_position="top left",
                      # annotation=dict(font_size=8),
                      fillcolor="red", opacity=0.25, line_width=0,
                      row=7, col=col)
        fig.add_trace(
            go.Scatter(x=[start, end],
                       y=[1, 1], name=f'Attack #{attack_no}'),
            row=7, col=col,
        )

fig.update_layout(height=800, width=1800, title_text='')
fig.show()

In [None]:
# get sample input for ts
# Rebinds `ds` (until here the split='test' dataset) to a dataset over the validation file.
# NOTE(review): uses `data_path=` / `seq_len=` here, whereas the first SwatDataset call in
# this notebook uses `val_data_path=` / `train_data_path=` / `seq_len_x=` / `seq_len_y=`
# — confirm this signature is still supported.
ds = SwatDataset( 
    data_path=const.SWAT_VAL_PATH,
    seq_len=1000,
    cols=const.SWAT_SENSOR_COLS,
    symbols_dct=const.SWAT_SYMBOLS_MAP,
)


In [None]:
# Reconstruction error of a single dataset item with the vanilla auto-encoder.
# The lowercase `sample_idx` is not defined anywhere in the notebook, so the
# SAMPLE_IDX constant from the top is used.
# NOTE(review): confirm this is the item that was meant.
x, x_comp_list = ds[SAMPLE_IDX]
# Inference only: no autograd graph needed.
with torch.no_grad():
    z = ae_model.encode(x)
    x_hat = ae_model.decode(z)
# Mean squared error over all elements.
mse = ((x.numpy() - x_hat.detach().numpy())**2).mean()

In [None]:
# Display the scalar error computed above.
mse

In [None]:
import numpy as np

In [None]:
# Number of positions visited when stepping through ts_idx_series in strides of 100.
len(list(ts_idx_series.index)[0::100])

In [None]:
import tqdm

In [None]:
# Shape of x with one extra leading axis (the -1 resolves to 1).
x.reshape(-1, *x.shape).shape

In [None]:
import torch

In [None]:
# Stack the first element of dataset items 0..2999 into one tensor and look at
# the latent shape the encoder returns for it.
first_items = [ds[idx][0] for idx in range(3000)]
ae_model.encode(torch.stack(first_items)).shape

In [None]:
%%time
# Overall reconstruction error for every 100th position (minus the last 10 of those).
mse_ls = []
# tqdm.tqdm, as in the loader loops above: `tqdm.notebook` is a submodule that
# a bare `import tqdm` does not load.
for idx in tqdm.tqdm(list(ts_idx_series.index)[0::100][:-10]):
    x, x_comp_list = ds[idx]
    # Inference only: no autograd graph needed.
    with torch.no_grad():
        z = ae_model.encode(x)
        x_hat = ae_model.decode(z)

    # Mean squared error over all elements of this item.
    mse_ls.append(((x.numpy() - x_hat.detach().numpy())**2).mean())

In [None]:
# Errors labelled with the matching entries of ts_idx_series (every 100th,
# minus the last 10 of those — the same selection as the loop above).
mse_index = list(ts_idx_series)[0::100][:-10]
mse_series = pd.Series(np.array(mse_ls), index=mse_index)

In [None]:
# Display the error series (pandas truncates long series on display).
mse_series

In [None]:
# Distinct values present in the 'Normal/Attack' column.
val_df['Normal/Attack'].unique()

In [None]:
# Binary column: 0 where 'Normal/Attack' equals 'Normal', 1 for anything else.
# Note that this adds a column to val_df in place.
is_not_normal = val_df['Normal/Attack'] != 'Normal'
val_df['label'] = is_not_normal.astype(int)

In [None]:


# Error value above which an item is flagged in the middle panel.
MSE_ALERT_THRESHOLD = 3

# Three stacked panels on a shared x axis: raw error, thresholded error, and
# the binary label column of val_df.
fig, axs = plt.subplots(3, 1, sharex='all')
axs[0].plot(mse_series)
axs[0].set_title('AE reconstruction MSE')
axs[1].plot(mse_series > MSE_ALERT_THRESHOLD)
axs[1].set_title(f'MSE > {MSE_ALERT_THRESHOLD}')
# axs[1].plot(mse_series>2)
# axs[1].plot(mse_series>1)
axs[2].plot(val_df.label)
axs[2].set_title('label')
fig.tight_layout()

# plt.show() works with the notebook's inline backend; Figure.show() is meant
# for GUI backends and only emits a warning there.
plt.show()

In [None]:
# Shape of the most recently fetched item, as a NumPy array.
x.numpy().shape

In [None]:
# Shape of the raw sensor values for the 1000 consecutive index entries starting at
# position SAMPLE_IDX, transposed so that the sensor columns come first.
# The lowercase `sample_idx` is not defined anywhere in the notebook, so SAMPLE_IDX is used.
# NOTE(review): confirm this is the window that was meant.
val_df[ts_idx_series[SAMPLE_IDX]:ts_idx_series[SAMPLE_IDX+1000-1]][const.SWAT_SENSOR_COLS].values.T.shape

In [None]:
# Compare row 1 of the manually scaled raw window with row 1 of the dataset item
# for the same position.
# The lowercase `sample_idx` is not defined anywhere in the notebook, so SAMPLE_IDX is used.
# NOTE(review): confirm this is the window that was meant.
raw_window = val_df[ts_idx_series[SAMPLE_IDX]:ts_idx_series[SAMPLE_IDX+1000-1]]
plt.plot(
    ds.scaler.transform(raw_window[const.SWAT_SENSOR_COLS].values).T[1,:]
)
# Fetch the item again: by this point `x` holds whatever the last loop
# iteration left behind, not the item at SAMPLE_IDX.
plt.plot(ds[SAMPLE_IDX][0].numpy()[1,:]);

In [None]:
# NOTE(review): this cell only contained the bare name `mod`, which is not defined
# anywhere in the notebook and raised NameError on Run All; it has been removed.

In [None]:
# Index entries of val_df as a plain Python list.
ts_ls = list(val_df.index)

In [None]:
# Number of entries when taking every 100th one.
len(ts_ls[0::100])

In [None]:
# First rows of the validation frame (including the 'label' column added above).
val_df.head()