# MSSD x RL: Full Experiment Pipeline

This notebook runs the **complete MSSD (Multi-Scale Sequential Shift Detector)** pipeline end-to-end:

1. **Install dependencies** and verify GPU availability
2. **Train RL agents** (Tabular Q-learning for CliffWalking, DQN for CartPole)
3. **Collect reference observations** from the trained agents under nominal conditions
4. **Run the unit tests** to validate all components
5. **Run a smoke test** (quick 2-trial sweep) to verify the pipeline works end-to-end
6. **Run full experiment sweeps** (50 trials x 3 shift types x 3 severities per environment, plus 50 null trials)
7. **Generate figures and LaTeX tables** for the report

---

## Project Overview

MSSD wraps a deployed RL agent with a **runtime monitor** that diagnoses *which aspect* of the observation distribution has shifted:

| Probe | What it detects | Statistic |
|-------|----------------|----------|
| **Body** | Mean/variance shift | MMD (d>4) or max-KS (d<=4) |
| **Tail** | Rare dangerous states | CVaR_0.95 difference |
| **Structure** | Correlation breakdown | Frobenius norm of correlation diff |
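The three probe statistics in the table can be sketched with NumPy alone. This is an illustration, not the `mssd` implementation. The RBF kernel with a median-heuristic bandwidth, the biased MMD² estimator, and the use of each observation's Euclidean norm as the scalar score for CVaR are all assumptions.

```python
import numpy as np

def max_ks(ref, win):
    """Max over dimensions of the two-sample Kolmogorov-Smirnov statistic."""
    stats = []
    for j in range(ref.shape[1]):
        a, b = np.sort(ref[:, j]), np.sort(win[:, j])
        grid = np.concatenate([a, b])
        # Empirical CDFs of both samples evaluated at every pooled point.
        cdf_a = np.searchsorted(a, grid, side="right") / len(a)
        cdf_b = np.searchsorted(b, grid, side="right") / len(b)
        stats.append(np.max(np.abs(cdf_a - cdf_b)))
    return float(max(stats))

def mmd2_rbf(ref, win):
    """Biased MMD^2 estimate with an RBF kernel (median-heuristic bandwidth)."""
    z = np.vstack([ref, win])
    sq = np.sum((z[:, None, :] - z[None, :, :]) ** 2, axis=-1)
    bw = np.median(sq[sq > 0])  # median of pairwise squared distances
    k = np.exp(-sq / bw)
    n = len(ref)
    return float(k[:n, :n].mean() + k[n:, n:].mean() - 2 * k[:n, n:].mean())

def body_stat(ref, win):
    # Switch estimator by dimension, as in the probe table.
    return mmd2_rbf(ref, win) if ref.shape[1] > 4 else max_ks(ref, win)

def cvar(scores, level=0.95):
    """Mean of the scores at or above the `level` quantile (upper tail)."""
    q = np.quantile(scores, level)
    return float(scores[scores >= q].mean())

def tail_stat(ref, win, level=0.95):
    # Assumed per-observation danger score: Euclidean norm.
    return abs(cvar(np.linalg.norm(win, axis=1), level)
               - cvar(np.linalg.norm(ref, axis=1), level))

def structure_stat(ref, win):
    """Frobenius norm of the difference between correlation matrices."""
    diff = np.corrcoef(win, rowvar=False) - np.corrcoef(ref, rowvar=False)
    return float(np.linalg.norm(diff, ord="fro"))
```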

Each probe's test statistic is converted to an **e-value** via block-bootstrap permutation, then accumulated in a **product martingale**. An alarm fires when the combined log-wealth exceeds `log(1/alpha)`, and the probe with the highest log-wealth provides the **diagnosis**.
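Here is a minimal sketch of this evidence pipeline under stated assumptions. A block permutation test shuffles contiguous blocks of the pooled sample, which respects temporal dependence, and yields a p-value. The standard calibrator e(p) = κ·p^(κ−1) with κ = 0.5 then turns that p-value into an e-value, and each probe's log-wealth is the running sum of its log e-values. This notebook does not say how `mssd` combines the probes, so the equal-weight average of per-probe wealths below is an assumption. Averaging e-values, unlike multiplying them across probes, does not require the probes to be independent.

```python
import numpy as np

def block_permutation_pvalue(ref, win, stat_fn, n_perm=200, block_len=10, rng=None):
    """Permutation p-value that shuffles contiguous blocks of the pooled sample."""
    rng = np.random.default_rng() if rng is None else rng
    t_obs = stat_fn(ref, win)
    pooled = np.vstack([ref, win])
    blocks = [pooled[i:i + block_len] for i in range(0, len(pooled), block_len)]
    n_ref = len(ref)
    exceed = 0
    for _ in range(n_perm):
        order = rng.permutation(len(blocks))
        shuffled = np.vstack([blocks[i] for i in order])
        if stat_fn(shuffled[:n_ref], shuffled[n_ref:]) >= t_obs:
            exceed += 1
    return (1 + exceed) / (1 + n_perm)  # never exactly 0

def p_to_e(p, kappa=0.5):
    """Calibrator e = kappa * p**(kappa - 1); it integrates to 1 over p in [0, 1]."""
    return kappa * p ** (kappa - 1)

def run_monitor(e_values, alpha=0.05):
    """e_values: dict probe -> per-step e-values.

    Returns (alarm_step, diagnosed_probe, log_wealth); alarm_step is None if no alarm.
    """
    probes = list(e_values)
    log_w = {p: np.cumsum(np.log(e_values[p])) for p in probes}
    n_steps = len(next(iter(log_w.values())))
    for t in range(n_steps):
        lw = np.array([log_w[p][t] for p in probes])
        # Equal-weight mixture of per-probe wealths, computed in log space.
        combined = np.logaddexp.reduce(lw) - np.log(len(probes))
        if combined >= np.log(1 / alpha):
            return t, probes[int(np.argmax(lw))], log_w
    return None, None, log_w
```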

## 0. Setup

Make sure you are running this notebook from the project root directory, or adjust the paths below.

If you have a **conda environment** with GPU-enabled PyTorch (e.g., `torch_5070`), activate it before launching Jupyter:
```bash
conda activate torch_5070
jupyter notebook
```

In [1]:
import os
import sys
from pathlib import Path

# Set the project root (adjust if needed)
PROJECT_ROOT = Path(os.getcwd()).parent if Path(os.getcwd()).name == "notebooks" else Path(os.getcwd())
os.chdir(PROJECT_ROOT)
sys.path.insert(0, str(PROJECT_ROOT / "src"))

print(f"Project root: {PROJECT_ROOT}")
print(f"Working directory: {os.getcwd()}")

Project root: d:\loic\ETUDE\IMPERIAL\cours\Spring Term\ML for Safety Critical Decision-Making - ELEC70122\Safety Critical Group project\final_project_idea
Working directory: d:\loic\ETUDE\IMPERIAL\cours\Spring Term\ML for Safety Critical Decision-Making - ELEC70122\Safety Critical Group project\final_project_idea


In [2]:
# Install the package in editable mode (run once)
!pip install -e . -q



### Check GPU availability

DQN training for CartPole benefits from GPU acceleration. The `device: "auto"` setting in `configs/defaults.yaml` will automatically use CUDA if available.

In [3]:
import torch

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"CUDA version: {torch.version.cuda}")
else:
    print("No GPU detected - training will use CPU (slower for DQN, fine for Tabular Q).")

PyTorch version: 2.10.0.dev20251204+cu130
CUDA available: True
GPU: NVIDIA GeForce RTX 5070 Ti Laptop GPU
CUDA version: 13.0


### Create output directories

All artifacts (trained agents, reference data, experiment results, figures) are stored under `artifacts/`.

In [None]:
from pathlib import Path

dirs = [
    "artifacts/agents/cliffwalking",
    "artifacts/agents/cartpole",
    "artifacts/reference",
    "artifacts/results",
    "artifacts/figures",
    "artifacts/tables",
    "artifacts/logs",
]
for d in dirs:
    Path(d).mkdir(parents=True, exist_ok=True)

print("Output directories created.")

---

## 1. Train RL Agents

We train two agents:

### 1a. Tabular Q-learning for CliffWalking-v1

- **Environment**: `CliffWalking-v1` wrapped with `CliffWalkingContinuousObs` to produce 3D observations `(row, col, cliff_distance)` instead of a single integer state.
- **Algorithm**: Tabular Q-learning with epsilon-greedy exploration.
- **Episodes**: 5,000 training episodes.
- **Output**: Q-table saved as `.npy` file.
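For intuition, here is a sketch of the two pieces described above. Gymnasium's CliffWalking encodes its 4×12 grid as `state = row * 12 + col`. The `cliff_distance` feature is defined here as the Manhattan distance to the nearest cliff cell, which is an assumption about what `CliffWalkingContinuousObs` computes. The Q-learning step is the standard off-policy TD update; the learning rate and discount are placeholders, not the values in `configs/envs/cliffwalking.yaml`.

```python
import numpy as np

N_ROWS, N_COLS = 4, 12                    # CliffWalking grid; state = row * N_COLS + col
CLIFF = [(3, c) for c in range(1, 11)]    # bottom row between start and goal

def continuous_obs(state):
    """Hypothetical (row, col, cliff_distance) encoding of a discrete state."""
    row, col = divmod(int(state), N_COLS)
    # Assumption: Manhattan distance to the nearest cliff cell.
    dist = min(abs(row - r) + abs(col - c) for r, c in CLIFF)
    return np.array([row, col, dist], dtype=np.float32)

def epsilon_greedy(q_table, state, epsilon, rng):
    """Random action with probability epsilon, otherwise the greedy action."""
    if rng.random() < epsilon:
        return int(rng.integers(q_table.shape[1]))
    return int(np.argmax(q_table[state]))

def q_update(q_table, s, a, r, s_next, terminated, lr=0.1, gamma=0.99):
    """One tabular Q-learning step: Q(s,a) += lr * (TD target - Q(s,a))."""
    target = r if terminated else r + gamma * np.max(q_table[s_next])
    q_table[s, a] += lr * (target - q_table[s, a])
    return q_table
```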

In [None]:
!python scripts/train_agent.py --config configs/envs/cliffwalking.yaml --output artifacts/agents/cliffwalking/q_table.npy --seed 42

### 1b. DQN for CartPole-v1

- **Environment**: `CartPole-v1` with 4D native observations `(x, x_dot, theta, theta_dot)`.
- **Algorithm**: Deep Q-Network with experience replay and target network.
- **Episodes**: 500 training episodes.
- **Device**: Automatically uses GPU if available (set via `--device auto`).
- **Output**: PyTorch checkpoint saved as `.pt` file.
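Experience replay and the target network meet in the TD target. The sketch below uses NumPy instead of PyTorch so it runs anywhere. The buffer API and hyperparameters are illustrative assumptions, and `next_q_target` stands for the target network's Q-values on the sampled next states.

```python
import random
from collections import deque

import numpy as np

class ReplayBuffer:
    """Fixed-size FIFO buffer of (s, a, r, s_next, done) transitions."""

    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)  # oldest transitions are evicted first

    def push(self, s, a, r, s_next, done):
        self.buffer.append((s, a, r, s_next, done))

    def sample(self, batch_size):
        # Uniform sampling without replacement breaks temporal correlation.
        batch = random.sample(list(self.buffer), batch_size)
        s, a, r, s_next, done = map(np.array, zip(*batch))
        return s, a, r, s_next, done

    def __len__(self):
        return len(self.buffer)

def td_targets(rewards, next_q_target, dones, gamma=0.99):
    """y = r + gamma * (1 - done) * max_a' Q_target(s', a').

    next_q_target has shape (batch, n_actions) and comes from the frozen target
    network, which is periodically synced from the online network.
    """
    return rewards + gamma * (1.0 - dones) * next_q_target.max(axis=1)
```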

In [None]:
!python scripts/train_agent.py --config configs/envs/cartpole.yaml --output artifacts/agents/cartpole/dqn_checkpoint.pt --seed 42 --device auto

---

## 2. Collect Reference Observations

We deploy each trained agent in the **nominal (unshifted)** environment and collect observations. These serve as the **reference distribution** against which the MSSD monitor will compare incoming data during deployment.

By default, we collect observations from **100 episodes** per environment.
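The collection step amounts to rolling out the trained policy and stacking the observations. This sketch assumes the Gymnasium `reset`/`step` API, where `step` returns `obs, reward, terminated, truncated, info`. Here `policy` is any callable that maps an observation to an action, and seeding each episode with `seed + ep` is an assumption. Saving the result with `np.savez(path, observations=obs)` would match the `"observations"` key read in the inspection cell below.

```python
import numpy as np

def collect_reference(env, policy, n_episodes, seed=42):
    """Roll out `policy` in a Gymnasium-style env and stack every observation."""
    observations = []
    for ep in range(n_episodes):
        obs, _ = env.reset(seed=seed + ep)  # one distinct seed per episode
        done = False
        while not done:
            observations.append(np.asarray(obs, dtype=np.float32))
            obs, _, terminated, truncated, _ = env.step(policy(obs))
            done = terminated or truncated
    return np.stack(observations)  # shape: (total_steps, obs_dim)
```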

### 2a. CliffWalking reference observations

In [None]:
!python scripts/collect_reference.py --config configs/envs/cliffwalking.yaml --agent artifacts/agents/cliffwalking/q_table.npy --output artifacts/reference/cliffwalking_ref.npz --n-episodes 100 --seed 42

### 2b. CartPole reference observations

In [None]:
!python scripts/collect_reference.py --config configs/envs/cartpole.yaml --agent artifacts/agents/cartpole/dqn_checkpoint.pt --output artifacts/reference/cartpole_ref.npz --n-episodes 100 --seed 42

### Inspect reference data

Let's take a quick look at the reference observations (count, dimensionality, per-feature mean and standard deviation) to make sure everything is correct.

In [None]:
import numpy as np

for env_name in ["cliffwalking", "cartpole"]:
    ref_path = f"artifacts/reference/{env_name}_ref.npz"
    data = np.load(ref_path)
    obs = data["observations"]
    print(f"{env_name}: {obs.shape[0]} observations, {obs.shape[1]}D")
    print(f"  Mean: {obs.mean(axis=0).round(3)}")
    print(f"  Std:  {obs.std(axis=0).round(3)}")
    print()

---

## 3. Run Unit Tests

Before running experiments, let's verify that all components (probes, martingale, metrics, wrappers, shifts) work correctly.

In [None]:
!python -m pytest tests/unit/ -v --tb=short

All 37 unit tests should pass. If any fail, fix the underlying issue before proceeding.

---

## 4. Smoke Test (Quick Pipeline Validation)

Run a **minimal 2-trial sweep** with only body shifts to verify the entire pipeline works end-to-end before committing to the full experiment.

This uses `configs/experiments/quick_smoke.yaml`, which overrides defaults for speed:
- Only 2 trials (instead of 50)
- Only 20 permutations (instead of 200)
- Only 200 monitoring steps (instead of 2,000)
- Only body shifts (instead of all 3 types)
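One common way to implement "overrides defaults" is a recursive dictionary merge applied after both YAML files are loaded. The sketch below assumes that behaviour. The key names are hypothetical and may not match the actual schema of `configs/defaults.yaml` or `quick_smoke.yaml`; the values mirror the numbers listed above.

```python
import copy

def deep_merge(defaults, overrides):
    """Recursively overlay `overrides` on `defaults` without mutating either."""
    merged = copy.deepcopy(defaults)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)  # merge nested sections
        else:
            merged[key] = copy.deepcopy(value)  # scalars and lists are replaced
    return merged

# Hypothetical key names; the real config schema may differ.
defaults = {"n_trials": 50,
            "monitor": {"n_permutations": 200, "n_steps": 2000},
            "shift_types": ["body", "tail", "structure"]}
smoke = {"n_trials": 2,
         "monitor": {"n_permutations": 20, "n_steps": 200},
         "shift_types": ["body"]}
config = deep_merge(defaults, smoke)
```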

### 4a. Smoke test on CliffWalking

In [None]:
!python scripts/run_sweep.py --config configs/experiments/quick_smoke.yaml --agent artifacts/agents/cliffwalking/q_table.npy --reference artifacts/reference/cliffwalking_ref.npz

### 4b. Smoke test on CartPole

In [None]:
!python scripts/run_sweep.py --config configs/experiments/quick_smoke.yaml --agent artifacts/agents/cartpole/dqn_checkpoint.pt --reference artifacts/reference/cartpole_ref.npz

If the smoke tests complete without errors, the pipeline is ready for the full experiments.

---

## 5. Run a Single Trial (Optional)

You can run individual trials for debugging or closer inspection. This is useful to understand what happens step-by-step.

**Parameters**:
- `--shift-type`: one of `body`, `tail`, `structure`, or `none`
- `--severity`: shift magnitude (higher = stronger shift)
- `--trial-id`: integer identifier for this trial
- `--seed`: random seed for reproducibility

In [None]:
# Example: run a single body-shift trial on CliffWalking
!python scripts/run_experiment.py --config configs/envs/cliffwalking.yaml --defaults configs/defaults.yaml --agent artifacts/agents/cliffwalking/q_table.npy --reference artifacts/reference/cliffwalking_ref.npz --shift-type body --severity 0.6 --trial-id 0 --seed 1000 --output artifacts/results/single_trial_cliff_body.npz

In [None]:
# Example: run a single tail-shift trial on CartPole
!python scripts/run_experiment.py --config configs/envs/cartpole.yaml --defaults configs/defaults.yaml --agent artifacts/agents/cartpole/dqn_checkpoint.pt --reference artifacts/reference/cartpole_ref.npz --shift-type tail --severity 1.5 --trial-id 0 --seed 2000 --output artifacts/results/single_trial_cartpole_tail.npz

In [None]:
# Inspect the result
import numpy as np

result = dict(np.load("artifacts/results/single_trial_cliff_body.npz", allow_pickle=True))
print("Single trial result:")
for k, v in result.items():
    print(f"  {k}: {v}")

---

## 6. Full Experiment Sweeps

Now we run the **complete experiments** required for the paper. Each sweep runs:

- **50 trials** per condition
- **3 shift types**: body, tail, structure
- **3 severity levels** per shift type (defined in the env config)
- **50 null (no-shift) trials** for false alarm rate estimation
- **Total per environment**: 50 x 3 x 3 + 50 = **500 trials**
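The trial count follows from enumerating every condition. Below is a sketch of that grid using the CliffWalking severities from section 6a. The tuple layout and the one-seed-per-trial scheme are assumptions, not necessarily how `run_sweep.py` assigns seeds.

```python
from itertools import product

severities = {  # CliffWalking values listed in section 6a
    "body": [0.3, 0.6, 1.0],
    "tail": [1.0, 2.0, 3.0],
    "structure": [0.5, 1.0, 2.0],
}
N_TRIALS = 50

def build_grid(severities, n_trials, base_seed=0):
    """Enumerate (shift_type, severity, trial_id, seed) for shifted and null trials."""
    grid = []
    for shift_type, sevs in severities.items():
        for severity, trial_id in product(sevs, range(n_trials)):
            grid.append((shift_type, severity, trial_id))
    grid += [("none", 0.0, t) for t in range(n_trials)]  # null trials for FAR
    # Hypothetical seeding scheme: one distinct seed per trial.
    return [(s, sev, t, base_seed + i) for i, (s, sev, t) in enumerate(grid)]

grid = build_grid(severities, N_TRIALS)
```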

The sweep runs with **4 parallel workers** by default.

> **Note**: Full sweeps can take a significant amount of time depending on your hardware. The CliffWalking sweep is faster (tabular agent); the CartPole sweep is heavier (DQN inference on every step).

### 6a. CliffWalking full sweep

Shift types and severities (from `configs/envs/cliffwalking.yaml`):
- **Body** (coordinate offset): severities `[0.3, 0.6, 1.0]`
- **Tail** (hazard state injection): severities `[1.0, 2.0, 3.0]`
- **Structure** (feature decorrelation): severities `[0.5, 1.0, 2.0]`
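To make the three shift families concrete, here is one hypothetical way to apply them to `(row, col, cliff_distance)` observation arrays. The project's actual parameterization may differ. The structure shift shown here re-shuffles a severity-dependent fraction of rows within each column. This weakens cross-feature correlation while leaving each column's marginal distribution unchanged, so a per-dimension max-KS body statistic sees no difference.

```python
import numpy as np

# All three shift definitions are hypothetical illustrations.

def body_shift(obs, severity):
    """Coordinate offset: add `severity` to the row and column features."""
    shifted = np.asarray(obs, dtype=float).copy()
    shifted[:, :2] += severity
    return shifted

def tail_shift(obs, severity, frac=0.05, rng=None):
    """Hazard injection: move a small random fraction of states closer to the cliff."""
    rng = np.random.default_rng() if rng is None else rng
    shifted = np.asarray(obs, dtype=float).copy()
    n_inject = max(1, int(frac * len(shifted)))
    idx = rng.choice(len(shifted), size=n_inject, replace=False)
    shifted[idx, 2] = np.maximum(shifted[idx, 2] - severity, 0.0)  # clip at the cliff
    return shifted

def structure_shift(obs, severity, rng=None):
    """Feature decorrelation: permute a subset of rows independently in each column."""
    rng = np.random.default_rng() if rng is None else rng
    shifted = np.asarray(obs, dtype=float).copy()
    frac = severity / (1.0 + severity)  # hypothetical map: severity -> fraction shuffled
    n_shuffle = int(frac * len(shifted))
    for j in range(shifted.shape[1]):
        idx = rng.choice(len(shifted), size=n_shuffle, replace=False)
        # Values move between rows only within column j, so its marginal is preserved.
        shifted[idx, j] = shifted[rng.permutation(idx), j]
    return shifted
```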

In [None]:
!python scripts/run_sweep.py --config configs/experiments/sweep_cliffwalking.yaml --agent artifacts/agents/cliffwalking/q_table.npy --reference artifacts/reference/cliffwalking_ref.npz

### 6b. CartPole full sweep

Shift types and severities (from `configs/envs/cartpole.yaml`):
- **Body** (pole mass drift): severities `[0.1, 0.3, 0.5]`
- **Tail** (extreme-angle injection): severities `[1.0, 1.5, 2.0]`
- **Structure** (x-theta correlation broken): severities `[0.3, 0.6, 1.0]`
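Unlike the other shifts, pole mass drift changes the dynamics rather than the observations directly. Gymnasium's `CartPoleEnv` computes `total_mass` and `polemass_length` once in its constructor and uses them in `step`, so any code that edits `masspole` must refresh both. In this sketch, the `(1 + severity)` scaling of the default pole mass and the linear ramp are assumptions about how severity maps to mass.

```python
def apply_pole_mass_drift(env, severity, base_masspole=0.1):
    """Hypothetical body shift: scale the pole mass by (1 + severity)."""
    core = getattr(env, "unwrapped", env)  # reach the raw CartPoleEnv behind wrappers
    core.masspole = base_masspole * (1.0 + severity)
    # Derived quantities are cached at construction, so recompute them here.
    core.total_mass = core.masspole + core.masscart
    core.polemass_length = core.masspole * core.length
    return core

def drift_severity(step, start, ramp_len, severity):
    """Linear ramp from 0 at `start` to `severity` after `ramp_len` steps."""
    frac = min(max((step - start) / ramp_len, 0.0), 1.0)
    return severity * frac
```

During monitoring, one would call `apply_pole_mass_drift(env, drift_severity(t, injection_step, ramp_len, severity))` at each step `t`, where `injection_step` and `ramp_len` are illustrative names.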

In [None]:
!python scripts/run_sweep.py --config configs/experiments/sweep_cartpole.yaml --agent artifacts/agents/cartpole/dqn_checkpoint.pt --reference artifacts/reference/cartpole_ref.npz

---

## 7. Generate Figures

Once the sweeps are complete, generate the paper figures:

- **Orthogonality heatmap** (3x3 matrix): shows which probe fires for each shift type. A diagonal-dominant heatmap confirms that MSSD correctly diagnoses the shift source.
- **ADD comparison bar chart**: compares Average Detection Delay between MSSD and the global MMD baseline.
- **Log-wealth trajectory plots**: show how the product martingale accumulates evidence over time.
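The orthogonality heatmap is a row-normalized confusion matrix between the injected shift type and the diagnosed probe. This is a sketch only. It excludes missed detections from each row's denominator; whether `generate_figures.py` does the same or counts them is an assumption.

```python
import numpy as np

SHIFTS = ["body", "tail", "structure"]

def orthogonality_matrix(pairs):
    """Rows = true shift type, columns = diagnosed probe; each row sums to 1 (or 0)."""
    counts = np.zeros((len(SHIFTS), len(SHIFTS)))
    for shift_type, probe in pairs:
        # Skip null trials and missed detections (probe is None).
        if shift_type in SHIFTS and probe in SHIFTS:
            counts[SHIFTS.index(shift_type), SHIFTS.index(probe)] += 1
    totals = counts.sum(axis=1, keepdims=True)
    return np.divide(counts, totals, out=np.zeros_like(counts), where=totals > 0)
```

A diagonal-dominant result (large values on `m[i, i]`) is what the summary below expects.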

### 7a. Figures from CliffWalking results

In [None]:
!python scripts/generate_figures.py --results-dir artifacts/results/sweep_cliffwalking --output-dir artifacts/figures

### 7b. Figures from CartPole results

In [None]:
!python scripts/generate_figures.py --results-dir artifacts/results/sweep_cartpole --output-dir artifacts/figures

### Display generated figures

In [None]:
from IPython.display import Image, display
from pathlib import Path

figures_dir = Path("artifacts/figures")
for fig_path in sorted(figures_dir.glob("*.png")):
    print(f"\n--- {fig_path.name} ---")
    display(Image(filename=str(fig_path), width=600))

---

## 8. Generate LaTeX Tables

Generate publication-ready LaTeX tables summarizing the results:
- Average Detection Delay (ADD) for MSSD vs. baseline
- Probe discrimination accuracy
- False alarm rate (FAR)
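As a starting point for adapting the format, here is a sketch that emits a booktabs-style table from per-condition summaries. The column set and the `latex_table` helper are hypothetical. `\toprule`, `\midrule` and `\bottomrule` require `\usepackage{booktabs}`, and the percent sign is escaped because `%` starts a LaTeX comment.

```python
def latex_table(rows, caption="MSSD vs. global MMD baseline"):
    """rows: iterable of (shift_type, severity, add_mssd, add_baseline, accuracy)."""
    lines = [
        r"\begin{table}[t]",
        r"\centering",
        r"\begin{tabular}{lrrrr}",
        r"\toprule",
        r"Shift & Severity & ADD (MSSD) & ADD (MMD) & Diag. acc. \\",
        r"\midrule",
    ]
    for shift, sev, add_m, add_b, acc in rows:
        # "\\\\" in a normal f-string emits the LaTeX row terminator "\\".
        lines.append(f"{shift} & {sev:g} & {add_m:.1f} & {add_b:.1f} & {acc * 100:.0f}\\% \\\\")
    lines += [r"\bottomrule", r"\end{tabular}", rf"\caption{{{caption}}}", r"\end{table}"]
    return "\n".join(lines)
```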

In [None]:
# Generate tables from CliffWalking results
!python scripts/generate_tables.py --results-dir artifacts/results/sweep_cliffwalking --output-dir artifacts/tables

# Generate tables from CartPole results
!python scripts/generate_tables.py --results-dir artifacts/results/sweep_cartpole --output-dir artifacts/tables

In [None]:
# Display the generated LaTeX tables
from pathlib import Path

for tex_path in sorted(Path("artifacts/tables").glob("*.tex")):
    print(f"\n=== {tex_path.name} ===")
    print(tex_path.read_text())

---

## 9. Quick Results Summary (In-Notebook Analysis)

For a quick look at the results directly in the notebook, we can load and analyze them using the built-in evaluation module.

In [None]:
import numpy as np
from pathlib import Path
from mssd.evaluation.metrics import TrialResult, compute_add, compute_far, compute_discrimination_accuracy

def load_results(results_dir):
    """Load all .npz result files into TrialResult objects."""
    results = []
    for npz_path in Path(results_dir).rglob("*.npz"):
        data = dict(np.load(npz_path, allow_pickle=True))
        alarm_step = int(data["mssd_alarm_step"])
        baseline_step = int(data["baseline_alarm_step"])
        results.append(TrialResult(
            env_name=str(data["env_name"]),
            shift_type=str(data["shift_type"]),
            severity=float(data["severity"]),
            trial_id=int(data["trial_id"]),
            seed=int(data["seed"]),
            mssd_alarm_fired=bool(data["mssd_alarm_fired"]),
            mssd_alarm_step=alarm_step if alarm_step >= 0 else None,
            mssd_diagnosed_probe=str(data["mssd_diagnosed_probe"]) if str(data["mssd_diagnosed_probe"]) != "none" else None,
            mssd_log_wealth={},
            baseline_alarm_fired=bool(data["baseline_alarm_fired"]),
            baseline_alarm_step=baseline_step if baseline_step >= 0 else None,
            shift_injection_step=int(data["shift_injection_step"]),
            total_steps=int(data["total_steps"]),
        ))
    return results


for env in ["sweep_cliffwalking", "sweep_cartpole"]:
    results_dir = f"artifacts/results/{env}"
    if not Path(results_dir).exists():
        print(f"No results found for {env} - run the sweep first.")
        continue
    
    results = load_results(results_dir)
    if not results:
        print(f"No results found in {results_dir}")
        continue
    
    shift_results = [r for r in results if r.shift_type != "none"]
    null_results = [r for r in results if r.shift_type == "none"]
    
    print(f"\n{'='*60}")
    print(f"  {env.upper()}")
    print(f"{'='*60}")
    print(f"Total trials: {len(results)} ({len(shift_results)} shifted, {len(null_results)} null)")
    
    # Detection rate
    if shift_results:
        detected = sum(1 for r in shift_results if r.mssd_alarm_fired)
        print(f"Detection rate: {detected}/{len(shift_results)} = {detected/len(shift_results):.1%}")
        
        # ADD
        add = compute_add(shift_results)
        print(f"Average Detection Delay (MSSD): {add:.1f} steps")
    
    # FAR
    if null_results:
        far = compute_far(null_results)
        print(f"False Alarm Rate: {far:.3f} (target: < 0.05)")
    
    # Discrimination accuracy per shift type
    if shift_results:
        for shift_type in ["body", "tail", "structure"]:
            typed = [r for r in shift_results if r.shift_type == shift_type]
            if typed:
                acc = compute_discrimination_accuracy(typed, shift_type)
                print(f"  {shift_type} discrimination accuracy: {acc:.1%}")
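The metric functions above come from `mssd.evaluation.metrics`, and this notebook does not show their exact definitions. The sketch below gives one plausible set of definitions on a hypothetical `Trial` record:

- ADD averages `alarm_step - injection_step` over trials that alarmed at or after the injection.
- FAR is the fraction of null trials that alarm.
- Discrimination accuracy is measured over detected trials only.

`add_by_severity` supports the severity-monotonicity check in the summary below.

```python
from dataclasses import dataclass
from typing import Optional

import numpy as np

@dataclass
class Trial:  # hypothetical stand-in for mssd's TrialResult
    shift_type: str
    severity: float
    alarm_step: Optional[int]
    diagnosed_probe: Optional[str]
    injection_step: int

def add(trials):
    """Mean detection delay over trials that alarmed at or after the shift."""
    delays = [t.alarm_step - t.injection_step for t in trials
              if t.alarm_step is not None and t.alarm_step >= t.injection_step]
    return float(np.mean(delays)) if delays else float("nan")

def far(null_trials):
    """Fraction of no-shift trials that raised any alarm."""
    return sum(t.alarm_step is not None for t in null_trials) / len(null_trials)

def discrimination_accuracy(trials, shift_type):
    """Among detected trials of `shift_type`, fraction diagnosed with the matching probe."""
    detected = [t for t in trials
                if t.shift_type == shift_type and t.diagnosed_probe is not None]
    if not detected:
        return float("nan")
    return sum(t.diagnosed_probe == shift_type for t in detected) / len(detected)

def add_by_severity(trials):
    """ADD per severity level; values should shrink as severity grows."""
    return {s: add([t for t in trials if t.severity == s])
            for s in sorted({t.severity for t in trials})}
```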

---

## Summary

### Expected outcomes

If everything works correctly, you should see:

1. **Orthogonality**: The 3x3 heatmap should be **diagonal-dominant**, meaning each probe primarily fires for its corresponding shift type (body probe detects body shifts, tail probe detects tail shifts, structure probe detects structure shifts).

2. **Low false alarm rate**: FAR should be below the significance level alpha = 0.05.

3. **Competitive detection delay**: MSSD should achieve ADD comparable to or better than the global MMD baseline, while additionally providing a **diagnosis** of the shift type.

4. **Severity monotonicity**: Higher severity shifts should be detected faster (lower ADD).

### Troubleshooting

- **Agent training fails**: Check that gymnasium is installed and the env name is correct (`CliffWalking-v1`, `CartPole-v1`).
- **GPU not detected**: Ensure your conda env has CUDA-enabled PyTorch. Run `torch.cuda.is_available()` to check.
- **Sweep takes too long**: Reduce `n_trials` in the sweep config, or increase `parallel_workers`.
- **Tests fail**: Run `pytest tests/unit/ -v` and fix any failures before proceeding.