## Packages

In [1]:
# !pip install numpy tensorflow scikit-learn plotly pandas

# For the UCR dataset we clone the git repo (if in Colab/Kaggle env)
!git clone https://github.com/iMohammad97/anomaly_detection

!pip install kaleido

Cloning into 'anomaly_detection'...
remote: Enumerating objects: 3828, done.[K
remote: Counting objects: 100% (30/30), done.[K
remote: Compressing objects: 100% (26/26), done.[K
remote: Total 3828 (delta 16), reused 8 (delta 4), pack-reused 3798 (from 1)[K
Receiving objects: 100% (3828/3828), 202.52 MiB | 33.48 MiB/s, done.
Resolving deltas: 100% (1623/1623), done.
Updating files: 100% (2721/2721), done.
Collecting kaleido
  Downloading kaleido-0.2.1-py2.py3-none-manylinux1_x86_64.whl.metadata (15 kB)
Downloading kaleido-0.2.1-py2.py3-none-manylinux1_x86_64.whl (79.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m79.9/79.9 MB[0m [31m21.3 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hInstalling collected packages: kaleido
Successfully installed kaleido-0.2.1


In [2]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models, callbacks
from tensorflow.keras.utils import custom_object_scope
import pandas as pd
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import plotly.express as px
from sklearn.metrics import auc, precision_recall_curve, roc_auc_score, precision_score
import glob, os, sys
import kaleido
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm, trange
import shutil
import math

## Metrics

In [None]:
def pointwise_precision(y_true, y_pred):
    """
    Point-level precision over individual timesteps.

    Of all timesteps flagged as anomalous in ``y_pred``, returns the fraction
    that are also labelled anomalous in ``y_true``. When nothing is flagged,
    0 is returned instead of raising (``zero_division=0``).
    """
    score = precision_score(y_true, y_pred, zero_division=0)
    return score


def make_event(y_true, y_pred):
    """
    Convert binary label / prediction sequences into lists of events.

    An event is a maximal run of consecutive 1s, reported as an inclusive
    ``(start, end)`` index pair.

    Parameters
    ----------
    y_true, y_pred : array-like of 0/1 (or bool) values
        Any shape (e.g. ``(N,)`` or ``(N, 1)``); both are flattened before
        events are extracted, so they are treated identically.

    Returns
    -------
    (list of (int, int), list of (int, int))
        Events found in ``y_true`` and ``y_pred`` respectively.
    """
    def _events(seq):
        # Flatten so np.diff works along time even for (N, 1) inputs, and use a
        # signed int dtype so a falling edge is always representable as -1.
        flat = np.asarray(seq).ravel().astype(np.int64)
        starts = np.flatnonzero(np.diff(flat, prepend=0) == 1)
        # A run ending at index i has flat[i] == 1 and flat[i + 1] == 0.
        ends = np.flatnonzero(np.diff(flat, append=0) == -1)
        return [(int(s), int(e)) for s, e in zip(starts, ends)]

    return _events(y_true), _events(y_pred)


def event_wise_recall(y_true_events, y_pred_events):
    """
    Event-level recall.

    A ground-truth event ``(start, end)`` counts as detected when at least one
    predicted event overlaps it in any way (bounds are inclusive).

    Returns the fraction of ground-truth events detected, or 0 when there are
    no ground-truth events.
    """
    if not y_true_events:
        return 0

    def _overlaps(true_event, pred_event):
        (t_start, t_end), (p_start, p_end) = true_event, pred_event
        return p_end >= t_start and p_start <= t_end

    hits = sum(
        1
        for true_event in y_true_events
        if any(_overlaps(true_event, pred_event) for pred_event in y_pred_events)
    )
    return hits / len(y_true_events)


def composite_f_score(y_true, y_pred):
    """
    Composite F-score: the harmonic mean of timepoint-wise precision and
    event-wise recall. Returns 0 when both components are 0.
    """
    precision_t = pointwise_precision(y_true, y_pred)
    recall_e = event_wise_recall(*make_event(y_true, y_pred))
    denom = precision_t + recall_e
    return 0 if denom == 0 else 2 * (precision_t * recall_e) / denom


def custom_auc_with_perfect_point(y_true, anomaly_scores, threshold_steps=100, plot=False):
    """
    Sweep thresholds and compute the area under the composite PR curve.

    ``threshold_steps`` thresholds are spaced linearly from
    ``min(anomaly_scores)`` to just above ``max(anomaly_scores)``. At each
    threshold, timesteps with ``score >= threshold`` are flagged, and the
    timepoint precision and event-wise recall are recorded. The sweep stops
    early once both are exactly 1 (a "perfect point").

    Returns
    -------
    (float, bool)
        Area under the precision-vs-recall curve (``sklearn.metrics.auc``) and
        whether a perfect point was reached.

    NOTE(review): ``auc`` needs at least two points, so a sweep that stops on
    its first threshold (or ``threshold_steps < 2``) will raise.
    """
    thresholds = np.linspace(np.min(anomaly_scores), np.max(anomaly_scores) + 1e-7, threshold_steps)
    precisions, recalls = [], []
    perfect_point_found = False

    for thr in thresholds:
        flagged = (anomaly_scores >= thr).astype(int)
        p = pointwise_precision(y_true, flagged)
        r = event_wise_recall(*make_event(y_true, flagged))
        precisions.append(p)
        recalls.append(r)
        if p == 1 and r == 1:
            perfect_point_found = True
            break

    # Area on the PR plane (recall on x, precision on y)
    area = auc(recalls, precisions)

    if plot:
        plt.figure(figsize=(8, 6))
        plt.plot(recalls, precisions, marker='o', label=f"AUC = {area:.4f}")
        plt.title("Precision-Recall Curve")
        plt.xlabel("Recall")
        plt.ylabel("Precision")
        plt.grid(True)
        plt.legend(loc="best")
        plt.show()

    return area, perfect_point_found


def compute_auc_pr(y_true, anomaly_scores):
    """
    Area under the pointwise precision-recall curve.

    If sklearn raises ``ValueError`` (e.g. when ``y_true`` lacks one of the
    classes), a message is printed and ``np.nan`` is returned.
    """
    try:
        precision, recall, _thresholds = precision_recall_curve(y_true, anomaly_scores)
        return auc(recall, precision)
    except ValueError:
        print("AUC-PR computation failed: Ensure both classes (0 and 1) are present in y_true.")
        return np.nan


def compute_auc_roc(y_true, anomaly_scores):
    """
    Pointwise ROC AUC via ``sklearn.metrics.roc_auc_score``.

    If sklearn raises ``ValueError`` (e.g. when ``y_true`` lacks one of the
    classes), a message is printed and ``np.nan`` is returned.
    """
    try:
        return roc_auc_score(y_true, anomaly_scores)
    except ValueError:
        print("AUC-ROC computation failed: Ensure both classes (0 and 1) are present in y_true.")
        return np.nan

## Utilities

In [None]:
def create_windows(data, window_size: int, step_size: int = 1):
    """
    Slice ``data`` into overlapping windows along its first axis.

    Windows start at indices ``0, step_size, 2 * step_size, ...`` and each
    spans ``window_size`` consecutive rows, so a ``(N, features)`` input yields
    an array of shape ``(M, window_size, features)``. Returns ``None`` when
    ``N < window_size``.
    """
    n_rows = data.shape[0]
    if n_rows < window_size:
        return None
    last_start = n_rows - window_size
    starts = range(0, last_start + 1, step_size)
    return np.stack([data[s:s + window_size] for s in starts], axis=0)

def set_seed(seed):
    """Seed NumPy's global RNG and TensorFlow's global RNG with ``seed``."""
    for seeder in (np.random.seed, tf.random.set_seed):
        seeder(seed)


class UCR(Dataset):
    """
    Windowed PyTorch dataset over the UCR ``.npy`` files in one directory.

    Files are selected by suffix: ``*train.npy`` for training, and
    ``*test.npy`` / ``*labels.npy`` for testing. When ``data_id`` is non-zero,
    only files whose names start with ``f"{data_id}_"`` are used. With
    ``data_id=0``, every matching file is used. Each file is windowed on its
    own (windows never span two files) and all windows are concatenated.

    Training items are ``(window, zeros_like(window))``; test items are
    ``(window, label_window)``.
    """

    def __init__(self, path: str, window_size: int, step_size: int = 1, train: bool = True, data_id: int = 0):
        self.path = path
        self.window_size, self.step_size = window_size, step_size
        if train:
            self.data = self.create_windows('train', data_id)
            # Training data is treated as anomaly-free: all-zero labels.
            self.labels = torch.zeros_like(self.data)
        else:
            self.data = self.create_windows('test', data_id)
            self.labels = self.create_windows('labels', data_id)

    def create_windows(self, tag: str, data_id: int):
        """
        Load every ``*{tag}.npy`` file (filtered by ``data_id``) and return all
        of their sliding windows stacked into one float32 tensor.

        Files are processed in sorted filename order. This makes the result
        independent of ``os.listdir`` ordering. It also lines up the ``test``
        and ``labels`` windows across files, assuming matching test/label
        files share a common name prefix (TODO confirm against the dataset
        layout).
        """
        prefix = f'{data_id}_'
        files = sorted(
            f for f in os.listdir(self.path)
            if f.endswith(f'{tag}.npy') and (data_id == 0 or f.startswith(prefix))
        )
        windows = []
        for file_name in files:
            array = np.load(os.path.join(self.path, file_name))
            for i in range(0, len(array) - self.window_size + 1, self.step_size):
                windows.append(array[i:i + self.window_size])
        # Converting to one ndarray first is much faster than torch.tensor on a list of arrays.
        windows = np.array(windows)
        return torch.tensor(windows, dtype=torch.float32)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if self.labels is not None:
            return self.data[idx], self.labels[idx]
        return self.data[idx]

def get_dataloaders(path: str, window_size: int, batch_size: int, step_size: int = 1, data_id: int = 0, shuffle: bool = False, seed: int = 0):
    """
    Build train and test ``DataLoader``s over the UCR windows in ``path``.

    ``torch.manual_seed(seed)`` is called first. The training set uses
    ``step_size`` between windows and is shuffled only if ``shuffle`` is True.
    The test set always uses stride 1 and is never shuffled.

    Returns ``(train_loader, test_loader)``.
    """
    torch.manual_seed(seed)
    datasets = {
        'train': UCR(path, window_size, step_size=step_size, train=True, data_id=data_id),
        # Test windows always use stride 1 so every timestep is scored.
        'test': UCR(path, window_size, step_size=1, train=False, data_id=data_id),
    }
    train_loader = DataLoader(datasets['train'], batch_size=batch_size, shuffle=shuffle)
    test_loader = DataLoader(datasets['test'], batch_size=batch_size, shuffle=False)
    return train_loader, test_loader

class PositionalEncoding(nn.Module):
    """
    Fixed sinusoidal positional encoding.

    ``forward`` slices the table by ``x.size(0)``, so ``x`` is expected in
    sequence-first layout ``(seq_len, batch, d_model)``. The encoding is
    broadcast over the batch dimension and added to ``x``.
    """

    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer cosine column than sine column,
        # so div_term is trimmed to match (the untrimmed form raised a shape error).
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # A non-persistent buffer follows module.to(device) / .cuda() but stays
        # out of state_dict, so existing checkpoints still load unchanged.
        self.register_buffer('pe', pe.unsqueeze(1), persistent=False)  # (max_len, 1, d_model)

    def forward(self, x):
        seq_len = x.size(0)
        # No-op when the buffer already lives on x's device.
        return x + self.pe[:seq_len, :].to(x.device)

## LSTM AutoEncoder

In [None]:
class LSTMAutoencoder:
    """
    Keras LSTM autoencoder for reconstruction-based time-series anomaly detection.

    The encoder maps each window of ``timesteps`` points to a ``latent_dim``
    vector. The decoder repeats that vector ``timesteps`` times and
    reconstructs the window. Per-window reconstruction errors are spread back
    onto the timesteps each window covers and averaged. A timestep is flagged
    as anomalous when its averaged error exceeds
    ``mean + threshold_sigma * std`` of the training-window errors.
    """

    def __init__(self,
                 train_data,
                 test_data,
                 labels,
                 timesteps: int = 128,
                 features: int = 1,
                 latent_dim: int = 32,
                 lstm_units: int = 64,
                 step_size: int = 1,
                 threshold_sigma=2.0,
                 seed: int = 0):
        """
        Store the data, window it, and build (but do not train) the model.

        ``train_data`` / ``test_data`` are windowed along axis 0 with
        ``create_windows``: the training stride is ``step_size`` and the test
        stride is 1. ``labels`` is only used by ``plot_results``. ``seed`` is
        applied via ``set_seed`` before the model is built.
        """
        self.train_data = train_data
        self.test_data = test_data
        self.labels = labels

        self.timesteps = timesteps
        self.features = features
        self.latent_dim = latent_dim
        self.lstm_units = lstm_units
        self.step_size = step_size
        self.threshold_sigma = threshold_sigma

        # Prepare windowed data (None when a series is shorter than one window)
        self.train_data_window = create_windows(self.train_data, timesteps, step_size)
        self.test_data_window = create_windows(self.test_data, timesteps, 1)

        # Model placeholders
        self.model = None
        self.threshold = 0

        # Arrays to hold predictions
        if self.test_data_window is not None:
            self.predictions_windows = np.zeros(len(self.test_data_window))
        else:
            self.predictions_windows = np.zeros(0)
        self.anomaly_preds = np.zeros(len(self.test_data))
        self.anomaly_errors = np.zeros(len(self.test_data))
        self.predictions = np.zeros(len(self.test_data))

        self.losses = {'train': [], 'valid': []}
        self.name = 'LSTMAutoencoder'  # A name attribute for the class

        set_seed(seed)
        self._build_model()

    def _build_model(self):
        """Build the encoder (LSTM -> LSTM 'latent') and mirrored decoder."""
        inputs = tf.keras.Input(shape=(self.timesteps, self.features), name='input_layer')

        x = layers.LSTM(self.lstm_units, return_sequences=True, name='lstm_1')(inputs)
        x = layers.LSTM(self.latent_dim, return_sequences=False, name='latent')(x)

        # Decoder
        x = layers.RepeatVector(self.timesteps, name='repeat_vector')(x)
        x = layers.LSTM(self.latent_dim, return_sequences=True, name='lstm_3')(x)
        x = layers.LSTM(self.lstm_units, return_sequences=True, name='lstm_4')(x)
        outputs = layers.TimeDistributed(layers.Dense(self.features, name='dense_output'))(x)

        self.model = models.Model(inputs, outputs, name='model')

    @staticmethod
    def _window_errors(windows, reconstructions, loss='mse'):
        """Per-window error: mean squared error for ``'mse'``, otherwise max absolute error."""
        diff = windows - reconstructions
        if loss == 'mse':
            return np.mean(np.square(diff), axis=(1, 2))
        return np.max(np.abs(diff), axis=(1, 2))

    def compute_threshold(self, loss='mse'):
        """
        Set ``self.threshold`` from reconstruction errors on the training windows.

        ``loss`` selects the per-window error and must match the metric used
        to score test windows in ``evaluate``. ``'mse'`` uses the mean squared
        error; anything else uses the maximum absolute error. The threshold is
        ``mean(errors) + threshold_sigma * std(errors)``.
        """
        rec = self.model.predict(self.train_data_window, verbose=0)
        errors = self._window_errors(self.train_data_window, rec, loss)
        self.threshold = np.mean(errors) + self.threshold_sigma * np.std(errors)

    def train(self,
              batch_size=32,
              epochs=50,
              optimizer='adam',
              loss='mse',
              patience=10,
              shuffle: bool = False,
              seed: int = 42):
        """
        Fit the autoencoder to reconstruct the training windows.

        ``loss='mse'`` uses Keras MSE. Any other value uses a per-window
        max-absolute-difference loss. The last 10% of windows
        (``validation_split=0.1``) are held out for validation. Early stopping
        on ``val_loss`` restores the best weights. Per-epoch losses are stored
        in ``self.losses``.
        """
        set_seed(seed)

        # Custom max-diff loss function
        def max_diff_loss(y_true, y_pred):
            return tf.reduce_max(tf.abs(y_true - y_pred), axis=[1, 2])

        # Determine which loss function to use
        loss_function = 'mse' if loss == 'mse' else max_diff_loss

        # Compile the model
        self.model.compile(optimizer=optimizer, loss=loss_function)

        # Early stopping
        early_stopping = callbacks.EarlyStopping(
            monitor='val_loss',
            patience=patience,
            restore_best_weights=True
        )

        # Train the model
        history = self.model.fit(
            self.train_data_window, self.train_data_window,
            batch_size=batch_size,
            shuffle=shuffle,
            validation_split=0.1,
            epochs=epochs,
            verbose=1,
            callbacks=[early_stopping]
        )

        self.losses['train'] = [float(l) for l in history.history['loss']]
        self.losses['valid'] = [float(l) for l in history.history['val_loss']]

    def evaluate(self, batch_size=32, loss='mse'):
        """
        Evaluate the model on ``self.test_data_window``.

        Sets ``self.threshold``, ``self.predictions_windows``,
        ``self.anomaly_errors`` (per-timestep averaged window error),
        ``self.anomaly_preds`` (0/1) and ``self.predictions`` (per-timestep
        averaged reconstruction). ``loss`` selects the error metric, as in
        ``compute_threshold``. The same metric is used for the threshold and
        for the test scores.
        """
        if self.test_data_window is None or len(self.test_data_window) == 0:
            print("No test windows available for evaluation.")
            return

        length = self.test_data.shape[0]
        # The threshold must live on the same scale as the test errors.
        self.compute_threshold(loss=loss)

        # Generate predictions for the test data windows
        self.predictions_windows = self.model.predict(self.test_data_window, batch_size=batch_size)
        errors = self._window_errors(self.test_data_window, self.predictions_windows, loss)

        M = errors.shape[0]
        # One reconstructed value per (window, position). Values are averaged
        # over features; with features == 1 this is the raw model output.
        window_recon = np.asarray(self.predictions_windows).reshape(M, self.timesteps, -1).mean(axis=-1)

        timestep_errors = np.zeros(length)
        predictions = np.zeros(length)
        counts = np.zeros(length)
        # Window i covers timesteps [i, i + timesteps - 1] (test stride is 1),
        # so position j of every window maps onto timesteps j .. j + M - 1.
        for j in range(self.timesteps):
            stop = min(j + M, length)
            n = stop - j
            if n <= 0:
                break
            timestep_errors[j:stop] += errors[:n]
            predictions[j:stop] += window_recon[:n, j]
            counts[j:stop] += 1

        safe_counts = np.where(counts == 0, 1, counts)  # Avoid division by zero
        timestep_errors /= safe_counts  # Average overlapping windows

        self.anomaly_errors = timestep_errors
        self.anomaly_preds = (timestep_errors > self.threshold).astype(int)
        self.predictions = np.nan_to_num(predictions / safe_counts)

    def get_latent(self, x):
        """
        Returns latent representation from the encoder part of the model.
        """
        encoder_model = models.Model(inputs=self.model.input,
                                    outputs=self.model.get_layer('latent').output)
        latent_representations = encoder_model.predict(x)
        return latent_representations

    def save_model(self, model_path: str = "model.h5"):
        """
        Save the Keras model to disk.
        """
        if self.model is not None:
            self.model.save(model_path)
            print(f"Model saved to {model_path}")
        else:
            print("No model to save.")

    def load_model(self, model_path: str, train_path: str, test_path: str, label_path: str):
        """
        Load the Keras model and the train/test/label arrays from disk.

        Windows are rebuilt with the same strides as ``__init__``:
        ``self.step_size`` for training and 1 for test.
        """
        self.model = models.load_model(model_path, compile=False)
        self.model.compile(optimizer='adam', loss='mean_squared_error', metrics=['mean_squared_error'])

        # Load data
        self.train_data = np.load(train_path)
        self.test_data = np.load(test_path)
        self.labels = np.load(label_path)

        # Recreate windows (training stride must match the one used at construction)
        self.train_data_window = create_windows(self.train_data, self.timesteps, self.step_size)
        self.test_data_window = create_windows(self.test_data, self.timesteps, 1)

        print(f"Loaded model from {model_path} and data from {train_path}, {test_path}, {label_path}.")

    def plot_results(self, save_path=None, file_format='html',size=800):
        """
        Plot test data, predictions, anomaly errors, and highlight
        labeled anomalies and predicted anomalies.

        If ``save_path`` is given, the figure is written as HTML
        (``file_format='html'``) or as a static image (needs kaleido).
        """
        # Flatten arrays
        test_data = self.test_data.ravel()
        anomaly_preds = self.anomaly_preds
        anomaly_errors = self.anomaly_errors
        predictions = self.predictions
        labels = self.labels.ravel()

        if not (len(test_data) == len(labels) == len(anomaly_preds) == len(anomaly_errors) == len(predictions)):
            raise ValueError("All input arrays must have the same length.")

        plot_width = max(size, len(test_data) // 10)

        fig = go.Figure()
        # Test Data
        fig.add_trace(go.Scatter(x=list(range(len(test_data))),
                                 y=test_data,
                                 mode='lines',
                                 name='Test Data',
                                 line=dict(color='blue')))
        # Predictions
        fig.add_trace(go.Scatter(x=list(range(len(predictions))),
                                 y=predictions,
                                 mode='lines',
                                 name='Predictions',
                                 line=dict(color='purple')))
        # Anomaly Errors
        fig.add_trace(go.Scatter(x=list(range(len(anomaly_errors))),
                                 y=anomaly_errors,
                                 mode='lines',
                                 name='Anomaly Errors',
                                 line=dict(color='red')))

        # Labeled anomalies
        label_indices = [i for i in range(len(labels)) if labels[i] == 1]
        if label_indices:
            fig.add_trace(go.Scatter(x=label_indices,
                                     y=[test_data[i] for i in label_indices],
                                     mode='markers',
                                     name='Labels on Test Data',
                                     marker=dict(color='orange', size=10)))

        # Predicted anomalies
        anomaly_pred_indices = [i for i in range(len(anomaly_preds)) if anomaly_preds[i] == 1]
        if anomaly_pred_indices:
            fig.add_trace(go.Scatter(x=anomaly_pred_indices,
                                     y=[predictions[i] for i in anomaly_pred_indices],
                                     mode='markers',
                                     name='Anomaly Predictions',
                                     marker=dict(color='green', size=10)))

        fig.update_layout(title='Test Data, Predictions, and Anomalies',
                          xaxis_title='Time Steps',
                          yaxis_title='Value',
                          legend=dict(x=0, y=1, traceorder='normal', orientation='h'),
                          template='plotly',
                          width=plot_width)

        # Optionally save the figure
        if save_path is not None:
            # os.makedirs('') raises, so only create a directory when one is given.
            save_dir = os.path.dirname(save_path)
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)

            if file_format.lower() == 'html':
                # Save as interactive HTML
                fig.write_html(save_path)
            else:
                # Save as static image (requires kaleido or orca)
                fig.write_image(save_path, format=file_format)

            print(f"Plot saved to: {save_path}")

        fig.show()

    def plot_losses(self, save_path=None):
        """
        Plot training and validation losses, optionally saving to ``save_path``.
        """
        plt.figure(figsize=(10, 6))
        plt.plot(self.losses['train'], label='Training Loss')
        plt.plot(self.losses['valid'], label='Validation Loss')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.title('Training Loss Over Epochs')
        plt.legend()
        plt.grid(True)

        if save_path:
            # os.makedirs('') raises, so only create a directory when one is given.
            save_dir = os.path.dirname(save_path)
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)

            # Save the figure
            plt.savefig(save_path, bbox_inches='tight')

        plt.show()

## Stationary LSTM Autoencoder (LSTM-SAE)

In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models
import os
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from tqdm import trange
from tensorflow.keras.utils import custom_object_scope

class StationaryLoss(layers.Layer):
    """
    Identity Keras layer that regularises a latent batch toward zero mean and
    unit standard deviation, computed per latent dimension across axis 0.

    ``call`` returns ``latent`` unchanged. As side effects it registers two
    weighted penalties via ``add_loss``. It also stores the same weighted
    values on the layer as ``self.mse_loss`` and ``self.std_loss``.
    ``StationaryLSTMAutoencoder.train`` reads those attributes after the
    forward pass and adds them to its total loss. That custom loop does not
    read ``model.losses``.
    NOTE(review): if this model is ever trained with ``compile``/``fit``, the
    ``add_loss`` terms would be included by Keras — confirm that is intended.
    """
    def call(self, latent, mean_coef: float = 1.0, std_coef: float = 1.0):
        # Calculate the average of the latent space
        # NOTE(review): despite the name ``mse_loss``, this is the mean
        # *absolute* value of the per-dimension batch mean (an L1 penalty),
        # not a squared error.
        latent_avg = tf.reduce_mean(latent, axis=0)
        mse_loss = tf.reduce_mean(tf.abs(latent_avg))
        self.add_loss(mean_coef * mse_loss)
        
        # Calculate the standard deviation of the latent space
        # Penalty = mean absolute deviation of the per-dimension batch std from 1.
        # With a batch of size 1 the std is 0, so this term is then std_coef * 1.
        latent_std = tf.math.reduce_std(latent, axis=0)
        std_loss = tf.reduce_mean(tf.abs(latent_std - 1.0))
        self.add_loss(std_coef * std_loss)
        
        # Store the losses separately for logging
        # (also consumed as gradient terms by the custom training loop)
        self.mse_loss = mean_coef * mse_loss
        self.std_loss = std_coef * std_loss
        
        return latent


class StationaryLSTMAutoencoder:
    """
    LSTM autoencoder with a "stationary" latent regulariser (LSTM-SAE).

    The encoder is LSTM -> LSTM -> Dense('latent'), followed by
    ``StationaryLoss``, which penalises the latent batch for deviating from
    zero mean and unit std. The decoder mirrors the encoder. Training uses a
    custom ``GradientTape`` loop. A timestep is flagged as anomalous when its
    averaged window MSE exceeds ``mean + threshold_sigma * std`` of the
    training-window MSEs.
    """

    def __init__(self, train_data, test_data, labels, timesteps: int = 128, features: int = 1, latent_dim: int = 32, lstm_units: int = 64, step_size: int = 1, threshold_sigma=2.0, seed: int = 0):
        """
        Store the data, window it, and build (but do not train) the model.

        The training stride is ``step_size`` and the test stride is 1.
        ``labels`` is only used by ``plot_results``.
        """
        self.train_data = train_data
        self.test_data = test_data
        # Kept so load_model can rebuild training windows with the same stride.
        self.step_size = step_size
        self.train_data_window = create_windows(self.train_data, timesteps, step_size)
        self.test_data_window = create_windows(self.test_data, timesteps, 1)
        self.timesteps = timesteps
        self.features = features
        self.latent_dim = latent_dim
        self.lstm_units = lstm_units
        self.model = None  # Model is not built yet.
        self.threshold_sigma = threshold_sigma
        self.threshold = 0
        if self.test_data_window is not None:
            self.predictions_windows = np.zeros(len(self.test_data_window))
        else:
            self.predictions_windows = np.zeros(0)
        self.anomaly_preds = np.zeros(len(self.test_data))
        self.anomaly_errors = np.zeros(len(self.test_data))
        self.predictions = np.zeros(len(self.test_data))
        self.labels = labels
        self.name = "LSTM_SAE"
        self.losses = {"mse": [], "mean": [], "std": []}
        set_seed(seed)
        self._build_model()

    def _build_model(self):
        """Build encoder -> Dense('latent') -> StationaryLoss -> decoder."""
        # Encoder
        inputs = tf.keras.Input(shape=(self.timesteps, self.features))
        x = layers.LSTM(self.lstm_units, return_sequences=True)(inputs)
        x = layers.LSTM(self.latent_dim, return_sequences=False)(x)
        # Named so get_latent() can find it (it looks up get_layer('latent')).
        latent = layers.Dense(self.latent_dim, name='latent')(x)

        # Apply custom loss to the latent space
        latent_with_loss = StationaryLoss()(latent, mean_coef=1.0, std_coef=1.0)

        # Decoder
        x = layers.RepeatVector(self.timesteps)(latent_with_loss)
        x = layers.LSTM(self.latent_dim, return_sequences=True)(x)
        x = layers.LSTM(self.lstm_units, return_sequences=True)(x)
        outputs = layers.TimeDistributed(layers.Dense(self.features))(x)

        # Return only the reconstruction; the latent penalties are side outputs of StationaryLoss.
        self.model = models.Model(inputs, outputs)

    def train(self, batch_size: int = 32, epochs: int = 50, optimizer: str = 'adam', patience: int = 15, seed: int = 42, shuffle: bool = False):
        """
        Train with a custom loop. The loss is reconstruction MSE plus the two
        StationaryLoss terms.

        Early stopping compares the summed total loss over all batches of an
        epoch. It stops after ``patience`` epochs without improvement; the
        best weights are not restored. Per-epoch mean losses are appended to
        ``self.losses``. When ``shuffle`` is True, ``self.train_data_window``
        is shuffled in place.

        Raises ``ValueError`` if ``optimizer`` is neither a name nor a Keras
        optimizer instance.
        """
        set_seed(seed)
        # Ensure the optimizer is set up correctly
        if isinstance(optimizer, str):
            optimizer = tf.keras.optimizers.get(optimizer)  # Get optimizer by name
        elif not isinstance(optimizer, tf.keras.optimizers.Optimizer):
            raise ValueError("Optimizer must be a string or a tf.keras.optimizers.Optimizer instance.")

        # Loss function
        mse_loss_fn = tf.keras.losses.MeanSquaredError()

        # Track losses
        mse_loss_tracker = tf.keras.metrics.Mean(name="mse_loss")
        mean_loss_tracker = tf.keras.metrics.Mean(name="mean_loss")
        std_loss_tracker = tf.keras.metrics.Mean(name="std_loss")

        # Early stopping variables
        best_epoch_loss = float('inf')
        patience_counter = 0

        # Training loop
        for epoch in (pbar := trange(epochs)):
            mse_loss_tracker.reset_state()
            mean_loss_tracker.reset_state()
            std_loss_tracker.reset_state()

            if shuffle:
                np.random.shuffle(self.train_data_window)

            # Accumulated over the whole epoch. It used to be reset inside the
            # batch loop, so only the last batch drove early stopping.
            # TODO: switch early stopping to a validation loss.
            epoch_loss = 0.0
            for step in trange(0, len(self.train_data_window), batch_size, leave=False):
                batch_data = self.train_data_window[step:step + batch_size]
                with tf.GradientTape() as tape:
                    # Forward pass
                    reconstructed = self.model(batch_data, training=True)

                    # Compute reconstruction loss
                    mse_loss = mse_loss_fn(batch_data, reconstructed)

                    # StationaryLoss stores its weighted terms as attributes during the forward pass
                    mean_loss = tf.reduce_mean([layer.mse_loss for layer in self.model.layers if isinstance(layer, StationaryLoss)])
                    std_loss = tf.reduce_mean([layer.std_loss for layer in self.model.layers if isinstance(layer, StationaryLoss)])

                    # Total loss
                    total_loss = mse_loss + mean_loss + std_loss

                # Compute gradients and update weights
                gradients = tape.gradient(total_loss, self.model.trainable_weights)
                optimizer.apply_gradients(zip(gradients, self.model.trainable_weights))

                # Track losses
                mse_loss_tracker.update_state(mse_loss)
                mean_loss_tracker.update_state(mean_loss)
                std_loss_tracker.update_state(std_loss)

                epoch_loss += float(total_loss)

            # Log losses after each epoch
            self.losses['mse'].append(float(mse_loss_tracker.result().numpy()))
            self.losses['mean'].append(float(mean_loss_tracker.result().numpy()))
            self.losses['std'].append(float(std_loss_tracker.result().numpy()))
            pbar.set_description(
                f"MSE Loss = {self.losses['mse'][-1]:.4f}, Mean Loss = {self.losses['mean'][-1]:.4f}, STD Loss = {self.losses['std'][-1]:.4f}"
            )

            # Early stopping logic
            if epoch_loss < best_epoch_loss:
                best_epoch_loss = epoch_loss
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    print(f"Early stopping triggered after {epoch + 1} epochs.")
                    break

    def compute_threshold(self):
        """Set ``self.threshold`` to mean + threshold_sigma * std of training-window MSEs."""
        rec = self.model.predict(self.train_data_window, verbose=0)
        mse = np.mean(np.square(self.train_data_window - rec), axis=(1, 2))
        self.threshold = np.mean(mse) + self.threshold_sigma * np.std(mse)

    def evaluate(self, batch_size=32):
        """
        Score the test windows by reconstruction MSE.

        Sets ``self.threshold``, ``self.predictions_windows``,
        ``self.anomaly_errors``, ``self.anomaly_preds`` and
        ``self.predictions``. All per-timestep arrays are rebuilt from scratch,
        so calling this repeatedly gives the same result.
        """
        if self.test_data_window is None or len(self.test_data_window) == 0:
            print("No test windows available for evaluation.")
            return

        length = self.test_data.shape[0]
        self.compute_threshold()
        # Generate predictions for the test data windows
        self.predictions_windows = self.model.predict(self.test_data_window, batch_size=batch_size)
        mse = np.mean(np.square(self.test_data_window - self.predictions_windows), axis=(1, 2))

        M = mse.shape[0]
        # One reconstructed value per (window, position). Values are averaged
        # over features; with features == 1 this is the raw model output.
        window_recon = np.asarray(self.predictions_windows).reshape(M, self.timesteps, -1).mean(axis=-1)

        timestep_errors = np.zeros(length)
        predictions = np.zeros(length)
        counts = np.zeros(length)
        # Window i covers timesteps [i, i + timesteps - 1] (test stride is 1),
        # so position j of every window maps onto timesteps j .. j + M - 1.
        for j in range(self.timesteps):
            stop = min(j + M, length)
            n = stop - j
            if n <= 0:
                break
            timestep_errors[j:stop] += mse[:n]
            predictions[j:stop] += window_recon[:n, j]
            counts[j:stop] += 1

        safe_counts = np.where(counts == 0, 1, counts)  # Avoid division by zero
        timestep_errors /= safe_counts  # Average overlapping windows

        # Generate anomaly predictions based on the threshold
        self.anomaly_preds = (timestep_errors > self.threshold).astype(int)
        self.anomaly_errors = timestep_errors
        # Fresh array: previous evaluate() results must not leak into this one.
        self.predictions = np.nan_to_num(predictions / safe_counts)

    def get_latent(self, x):
        """Return the output of the Dense 'latent' layer for inputs ``x``."""
        encoder_model = models.Model(inputs=self.model.input, outputs=self.model.get_layer('latent').output)
        latent_representations = encoder_model.predict(x)
        return latent_representations

    def save_model(self, model_path: str = "model.h5"):
        """Save the Keras model to ``model_path`` (prints a message if no model exists)."""
        # Save the Keras model
        if self.model is not None:
            self.model.save(model_path)
            print(f"Model saved to {model_path}")
        else:
            print("No model to save.")

    def load_model(self, model_path: str, train_path: str, test_path: str, label_path: str):
        """
        Load the Keras model (with the StationaryLoss custom layer) and the
        train/test/label arrays. Windows are then rebuilt with the same strides
        as ``__init__``.
        """
        # Use custom_object_scope for the custom layer
        with custom_object_scope({'StationaryLoss': StationaryLoss}):
            self.model = models.load_model(
                model_path,
                compile=False  # Avoid recompiling until the model is fully loaded
            )

        # Compile the model for evaluation or retraining
        self.model.compile(
            optimizer='adam',
            loss='mean_squared_error',
            metrics=['mean_squared_error']
        )

        # Load data
        self.train_data = np.load(train_path)
        self.test_data = np.load(test_path)
        self.labels = np.load(label_path)

        # Recreate the windows with the newly loaded data (training stride must match __init__)
        self.train_data_window = create_windows(self.train_data, self.timesteps, getattr(self, 'step_size', 1))
        self.test_data_window = create_windows(self.test_data, self.timesteps, 1)

    def plot_results(self, size=800, save_path=None, file_format='html'):
        """
        Plot test data, predictions and anomaly errors. Labelled and predicted
        anomalies are shown as markers. If ``save_path`` is given, the figure
        is written as HTML or as a static image (needs kaleido).
        """
        # Flattening arrays to ensure they are 1D
        test_data = self.test_data.ravel()  # Convert to 1D array
        anomaly_preds = self.anomaly_preds  # Already 1D
        anomaly_errors = self.anomaly_errors  # Already 1D
        predictions = self.predictions  # Already 1D
        labels = self.labels.ravel()  # Convert to 1D array

        plot_width = max(size, len(test_data) // 10)  # At least `size` px wide, scales with data length

        # Check if all inputs have the same length
        if not (len(test_data) == len(labels) == len(anomaly_preds) == len(anomaly_errors) == len(predictions)):
            raise ValueError("All input arrays must have the same length.")

        # Create a figure
        fig = go.Figure()

        # Add traces for test data, predictions, and anomaly errors
        fig.add_trace(go.Scatter(x=list(range(len(test_data))),
                                 y=test_data,
                                 mode='lines',
                                 name='Test Data'))

        fig.add_trace(go.Scatter(x=list(range(len(predictions))),
                                 y=predictions,
                                 mode='lines',
                                 name='Predictions'))

        fig.add_trace(go.Scatter(x=list(range(len(anomaly_errors))),
                                 y=anomaly_errors,
                                 mode='lines',
                                 name='Anomaly Errors'))

        # Highlight points in test_data where label is 1
        label_indices = [i for i in range(len(labels)) if labels[i] == 1]
        if label_indices:
            fig.add_trace(go.Scatter(x=label_indices,
                                     y=[test_data[i] for i in label_indices],
                                     mode='markers',
                                     name='Labels on Test Data',
                                     marker=dict(color='orange', size=10)))

        # Highlight predicted anomalies (consistent with LSTMAutoencoder.plot_results)
        anomaly_pred_indices = [i for i in range(len(anomaly_preds)) if anomaly_preds[i] == 1]
        if anomaly_pred_indices:
            fig.add_trace(go.Scatter(x=anomaly_pred_indices,
                                     y=[predictions[i] for i in anomaly_pred_indices],
                                     mode='markers',
                                     name='Anomaly Predictions',
                                     marker=dict(color='green', size=10)))

        # Set the layout
        fig.update_layout(title='Test Data, Predictions, and Anomalies',
                          xaxis_title='Time Steps', yaxis_title='Value',
                          legend=dict(x=0, y=1, traceorder='normal', orientation='h'),
                          template='plotly', width=plot_width)
        # Optionally save the figure
        if save_path is not None:
            # os.makedirs('') raises, so only create a directory when one is given.
            save_dir = os.path.dirname(save_path)
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)

            if file_format.lower() == 'html':
                # Save as interactive HTML
                fig.write_html(save_path)
            else:
                # Save as static image (requires kaleido or orca)
                fig.write_image(save_path, format=file_format)

            print(f"Plot saved to: {save_path}")

        # Show the figure
        fig.show()

    def plot_losses(self, save_path=None):
        """Plot the per-epoch MSE / latent-mean / latent-std losses, optionally saving to ``save_path``."""
        # Plot the loss values
        plt.figure(figsize=(10, 6))
        plt.plot(self.losses['mse'], label='MSE Reconstruction Loss')
        plt.plot(self.losses['mean'], label='Latent Mean Loss')
        plt.plot(self.losses['std'], label='Latent Standard Deviation Loss')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.title('Training Loss Over Epochs')
        plt.legend()
        plt.grid(True)
        if save_path:
            # os.makedirs('') raises, so only create a directory when one is given.
            save_dir = os.path.dirname(save_path)
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)

            # Save the figure
            plt.savefig(save_path, bbox_inches='tight')
        plt.show()
    
def set_seed(seed):
    """
    Seed TensorFlow's and NumPy's global RNGs with ``seed``.

    NOTE(review): this duplicates ``set_seed`` from the Utilities cell with
    identical behaviour. Once this cell runs, this definition shadows the
    earlier one.
    """
    tf.random.set_seed(seed)
    np.random.seed(seed)