# Comparison of Replay Strategies in pyCLAD

This notebook compares the anomaly-detection performance (ROC-AUC of a `LocalOutlierFactorAdapter` model) of `ReplayEnhancedStrategy` and `WatchStrategy` from the `pyCLAD` library.
The comparison is performed across four datasets: Energy, NSL-KDD, UNSW, and Wind.
For each dataset, three different data scenarios are used: `random_anomalies`, `clustered_with_closest_assignment`, and `clustered_with_random_assignment`.
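The experiment grid is the Cartesian product of datasets, scenarios, and strategies, so the notebook performs 4 × 3 × 2 = 24 runs. The snippet below is only an illustration of that bookkeeping. It hard-codes the names from the configuration cell and is not used by the notebook itself. The file-name pattern it prints is the one `analyze_results` expects.

```python
from itertools import product

# Names copied from DATASETS, DATASET_TYPES and STRATEGIES in the configuration cell.
datasets = ["energy", "nsl-kdd", "unsw", "wind"]
scenarios = [
    "random_anomalies",
    "clustered_with_closest_assignment",
    "clustered_with_random_assignment",
]
strategies = ["replay_enhanced", "watch_percentile"]

# Every (dataset, scenario, strategy) triple is one independent run.
runs = list(product(datasets, scenarios, strategies))
print(len(runs))

# Each run is paired with one results file named after its triple.
for dataset, scenario, strategy in runs[:2]:
    print(f"{dataset}_{scenario}_{strategy}.json")
```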

The notebook is divided into three main parts:
1.  **Setup**: Imports necessary libraries and defines the configuration for datasets, strategies, and output directories.
2.  **Run Experiments**: Executes the experiments for each combination of dataset, scenario, and strategy, saving the results to JSON files.
3.  **Analyze Results**: Loads the saved results, generates heatmaps for each experiment, and creates comparison heatmaps to visualize the performance difference between the two strategies.

In [1]:
import pathlib
import json
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import os

# Datasets
from pyclad.data.datasets.unsw_dataset import UnswDataset
from pyclad.data.datasets.nsl_kdd_dataset import NslKddDataset
from pyclad.data.datasets.wind_energy_dataset import WindEnergyDataset
from pyclad.data.datasets.energy_plants_dataset import EnergyPlantsDataset

# Scenarios
from pyclad.scenarios.concept_aware import ConceptAwareScenario

# Continual metrics
from pyclad.metrics.continual.average_continual import ContinualAverage
from pyclad.metrics.continual.backward_transfer import BackwardTransfer
from pyclad.metrics.continual.forward_transfer import ForwardTransfer

# Models
from pyclad.models.adapters.pyod_adapters import LocalOutlierFactorAdapter

# Strategies
from pyclad.strategies.replay.replay import ReplayEnhancedStrategy
from pyclad.strategies.replay.watch import WatchStrategy

# Additional imports for replay strategies
from pyclad.strategies.replay.buffers.adaptive_balanced import (
    AdaptiveBalancedReplayBuffer,
)
from pyclad.strategies.replay.selection.random import RandomSelection

# Callbacks, metrics and output
from pyclad.callbacks.evaluation.concept_metric_evaluation import ConceptMetricCallback
from pyclad.callbacks.evaluation.memory_usage import MemoryUsageCallback
from pyclad.callbacks.evaluation.time_evaluation import TimeEvaluationCallback
from pyclad.metrics.base.roc_auc import RocAuc
from pyclad.output.json_writer import JsonOutputWriter
import time

# Configuration
DATASETS = {
    "energy": EnergyPlantsDataset,
    "nsl-kdd": NslKddDataset,
    "unsw": UnswDataset,
    "wind": WindEnergyDataset,
}

DATASET_TYPES = [
    "random_anomalies",
    "clustered_with_closest_assignment",
    "clustered_with_random_assignment",
]

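# Replay-buffer size shared by both strategies (assumed value; adjust as needed)
max_size = 4000

# Each entry is a factory that wraps a freshly created model in a new strategy
STRATEGIES = {
    "replay_enhanced": lambda model: ReplayEnhancedStrategy(
        model,
        AdaptiveBalancedReplayBuffer(selection_method=RandomSelection(), max_size=max_size),
    ),
    # WatchStrategy is imported above; keyword arguments are passed through unchanged
    "watch_percentile": lambda model: WatchStrategy(
        model, max_buffer_size=max_size, threshold_ratio=0.5, warm_up_period=2, threshold_cal_index=2, resize_new_regime=True
    ),
}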

RESULTS_DIR = pathlib.Path("comparison_results")
PLOTS_DIR = pathlib.Path("comparison_plots")
RESULTS_DIR.mkdir(exist_ok=True)
PLOTS_DIR.mkdir(exist_ok=True)

print("Setup complete.")

Setup complete.


### Run Experiments

The following cell runs the experiments. It iterates through each dataset and dataset type and, for each combination, runs both `ReplayEnhancedStrategy` and `WatchStrategy` (registered as `replay_enhanced` and `watch_percentile`). Each run's results are written as a JSON file to the `comparison_results` directory.

**Note:** This process can be time-consuming.
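Because each run is expensive and `analyze_results` only needs the saved files, it can help to check which runs still lack a results file before starting. The following is a minimal sketch with two hypothetical helpers, `result_path` and `pending_runs` (they are not part of pyCLAD). It assumes the `{dataset}_{scenario}_{strategy}.json` naming used in this notebook.

```python
import pathlib
from itertools import product


def result_path(results_dir, dataset_name, dataset_type, strategy_name):
    """Hypothetical helper: path of the JSON file written for one run."""
    return pathlib.Path(results_dir) / f"{dataset_name}_{dataset_type}_{strategy_name}.json"


def pending_runs(results_dir, dataset_names, dataset_types, strategy_names):
    """Hypothetical helper: (dataset, scenario, strategy) triples without a results file yet."""
    return [
        (d, t, s)
        for d, t, s in product(dataset_names, dataset_types, strategy_names)
        if not result_path(results_dir, d, t, s).exists()
    ]
```

Iterating over a dict yields its keys, so `pending_runs(RESULTS_DIR, DATASETS, DATASET_TYPES, STRATEGIES)` could drive a flat loop in place of the nested one, at the cost of loading a dataset once per run rather than once per scenario.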

In [2]:
import time


def run_experiments():
    for dataset_name, dataset_class in DATASETS.items():
        for dataset_type in DATASET_TYPES:
            print(f"Running experiments for {dataset_name} - {dataset_type}")
            try:
                dataset = dataset_class(dataset_type=dataset_type)
            except Exception as e:
                print(f"Could not load dataset {dataset_name} with type {dataset_type}. Error: {e}")
                continue

            for strategy_name, strategy_builder in STRATEGIES.items():
                print(f"  with strategy: {strategy_name}")
                start_time = time.time()
                model = LocalOutlierFactorAdapter()
                strategy = strategy_builder(model)

                callbacks = [
                    ConceptMetricCallback(
                        base_metric=RocAuc(),
                        metrics=[ContinualAverage(), BackwardTransfer(), ForwardTransfer()]
                    ),
                    TimeEvaluationCallback(),
                    MemoryUsageCallback(),
                ]
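                scenario = ConceptAwareScenario(dataset, strategy=strategy, callbacks=callbacks)
                scenario.run()

                # Save the results; analyze_results() reads this file back
                output_writer = JsonOutputWriter(
                    RESULTS_DIR / f"{dataset_name}_{dataset_type}_{strategy_name}.json"
                )
                output_writer.write([model, dataset, strategy, *callbacks])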

                callbacks[0].print_continual_average()
                end_time = time.time()
                print(f"  Time taken: {end_time - start_time:.2f} seconds")

            print("-" * 20)
            
        print(f"Finished experiments for {dataset_name}.")
        print("=" * 40)

run_experiments()
print("All experiments finished.")

Running experiments for wind - random_anomalies
  with strategy: replay_enhanced
Continual Average: 0.9720743220173798
  Time taken: 5.41 seconds
  with strategy: watch_percentile
Fitting model with 2000 samples
Fitting model with 4000 samples
Fitting model with 4001 samples
Fitting model with 4000 samples
Fitting model with 4000 samples
Continual Average: 0.968982689145171
  Time taken: 3.27 seconds
--------------------
Running experiments for wind - clustered_with_closest_assignment
  with strategy: replay_enhanced
Continual Average: 0.9495798783267121
  Time taken: 7.76 seconds
  with strategy: watch_percentile
Fitting model with 2000 samples
Fitting model with 4000 samples
Fitting model with 4001 samples
Fitting model with 4000 samples
Fitting model with 4000 samples
Fitting model with 4004 samples
Fitting model with 4002 samples
Fitting model with 4000 samples
Fitting model with 4007 samples
Fitting model with 4000 samples
Continual Average: 0.9479149027481414
  Time taken: 5.93 s

### Analyze Results and Generate Plots

This cell analyzes the results from the experiments. For each dataset and scenario, it generates:
1.  A heatmap of the ROC-AUC scores for each strategy individually.
2.  A comparison heatmap showing the difference in ROC-AUC scores between `ReplayEnhancedStrategy` and `WatchStrategy` (positive values favor `ReplayEnhancedStrategy`).

All plots are saved in the `comparison_plots` directory.
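The comparison heatmap is an element-wise difference of the two ROC-AUC matrices, with the cells below the diagonal hidden; `center=0` in the plotting call places zero at the neutral color of the `coolwarm` map. The sketch below reproduces only that arithmetic on two small 3×3 matrices filled with made-up toy values (not experiment results), so the sign convention and the mask are easy to check. The concept labels `c0` to `c2` are hypothetical.

```python
import numpy as np
import pandas as pd

concepts = ["c0", "c1", "c2"]
# Toy ROC-AUC matrices (rows: trained concepts, columns: testing concepts).
replay = pd.DataFrame([[0.90, 0.80, 0.70],
                       [0.85, 0.92, 0.75],
                       [0.80, 0.88, 0.95]], index=concepts, columns=concepts)
watch = pd.DataFrame([[0.88, 0.82, 0.70],
                      [0.86, 0.90, 0.70],
                      [0.79, 0.85, 0.96]], index=concepts, columns=concepts)

# Positive cells mean the replay-enhanced run scored higher than the watch run.
diff = replay - watch

# Same mask as analyze_results: keep the diagonal and everything above it.
mask = np.triu(np.ones(diff.shape), k=0).astype(bool)
shown = diff.where(mask)  # cells where mask is False become NaN and are left blank
print(shown.round(2))
```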

In [3]:
def analyze_results():
    for dataset_name in DATASETS.keys():
        for dataset_type in DATASET_TYPES:
            print(f"Analyzing results for {dataset_name} - {dataset_type}")
            
            results_files = {
                strategy_name: RESULTS_DIR / f"{dataset_name}_{dataset_type}_{strategy_name}.json"
                for strategy_name in STRATEGIES.keys()
            }

            if not all(f.exists() for f in results_files.values()):
                print(f"  Skipping, not all result files found for {dataset_name} - {dataset_type}")
                continue

            dataframes = {}
            concepts = None
            for strategy_name, file_path in results_files.items():
                with open(file_path, "r") as f:
                    results = json.load(f)
                
                metric_key = next((k for k in results if k.startswith("concept_metric_callback")), None)
                if not metric_key:
                    print(f"  Could not find metric callback in {file_path}")
                    continue

                metric_data = results[metric_key]
                if concepts is None:
                    concepts = metric_data["concepts_order"]
                matrix = metric_data["metric_matrix"]
                df = pd.DataFrame(matrix, index=concepts, columns=concepts)
                dataframes[strategy_name] = df

                # Plot the individual heatmap, hiding cells below the diagonal
                plt.figure(figsize=(8, 6))
                upper = np.triu(np.ones(df.shape), k=0).astype(bool)
                sns.heatmap(df.where(upper), annot=True, fmt=".2f", cmap="viridis")
                plt.title(f"ROC-AUC: {dataset_name} - {dataset_type}\n({strategy_name})")
                plt.ylabel("Trained Concepts")
                plt.xlabel("Testing Concepts")
                plt.tight_layout()
                plot_filename = PLOTS_DIR / f"{dataset_name}_{dataset_type}_{strategy_name}_heatmap.png"
                plt.savefig(plot_filename)
                plt.close()
                print(f"  Saved heatmap to {plot_filename}")

            if len(dataframes) == 2:
                # Plot comparison heatmap
                df1 = dataframes["replay_enhanced"]
                df2 = dataframes["watch_percentile"]
                # Positive values: ReplayEnhancedStrategy scores higher than WatchStrategy
                comparison_df = df1 - df2

                plt.figure(figsize=(8, 6))
                upper = np.triu(np.ones(comparison_df.shape), k=0).astype(bool)
                sns.heatmap(comparison_df.where(upper), annot=True, fmt=".2f", cmap="coolwarm", center=0)
                plt.title(f"ROC-AUC Comparison: Replay Enhanced vs Watch\n{dataset_name} - {dataset_type}")
                plt.ylabel("Trained Concepts")
                plt.xlabel("Testing Concepts")
                plt.tight_layout()
                comp_plot_filename = PLOTS_DIR / f"{dataset_name}_{dataset_type}_comparison_heatmap.png"
                plt.savefig(comp_plot_filename)
                plt.close()
                print(f"  Saved comparison heatmap to {comp_plot_filename}")

analyze_results()
print("Analysis finished. All plots are saved.")

Analyzing results for energy - random_anomalies
  Saved heatmap to comparison_plots/energy_random_anomalies_replay_enhanced_heatmap.png
  Saved heatmap to comparison_plots/energy_random_anomalies_watch_heatmap.png
  Saved comparison heatmap to comparison_plots/energy_random_anomalies_comparison_heatmap.png
Analyzing results for energy - clustered_with_closest_assignment
  Saved heatmap to comparison_plots/energy_clustered_with_closest_assignment_replay_enhanced_heatmap.png
  Saved heatmap to comparison_plots/energy_clustered_with_closest_assignment_watch_heatmap.png
  Saved comparison heatmap to comparison_plots/energy_clustered_with_closest_assignment_comparison_heatmap.png
Analyzing results for energy - clustered_with_random_assignment
  Saved heatmap to comparison_plots/energy_clustered_with_random_assignment_replay_enhanced_heatmap.png
  Saved heatmap to comparison_plots/energy_clustered_with_random_assignment_watch_heatmap.png
  Saved comparison heatmap to comparison_plots/energy_