# Mechanistic Watchdog: Stress Test & Validation

**Module:** `MechWatch`
**Goal:** Verify that the "Deception Score" is linearly separable between Truthful Control prompts and Deceptive/Adversarial prompts.

In [None]:
import json
import sys
import os
from pathlib import Path

# 1. Locate the project root regardless of where the notebook kernel was started.
current_dir = Path.cwd()
if current_dir.name == "notebooks":
    # Launched from the notebooks/ folder: the repository root is one level up.
    project_root = current_dir.parent
else:
    # Otherwise the working directory is taken to be the root itself.
    project_root = current_dir

# Make the root importable (once) so the MechWatch package can be resolved.
_root_entry = str(project_root)
if _root_entry not in sys.path:
    sys.path.append(_root_entry)

print(f"üìÇ Project Root: {project_root}")

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from tqdm import tqdm

# 2. Import from MechWatch
# These are project-local imports; they resolve only because project_root was
# appended to sys.path above.
try:
    from MechWatch.runtime import WatchdogRuntime
    from MechWatch.config import load_config
    from MechWatch.calibrate import standardize_dataset
    print("‚úÖ Successfully imported WatchdogRuntime from MechWatch")
except ImportError as e:
    # NOTE: the failure is only printed, not re-raised. The names above stay
    # undefined, so later cells that use them will fail with NameError.
    print(f"‚ùå Import failed: {e}. Ensure the folder is named 'MechWatch' and is in the project root.")

## 1. Initialize Runtime & Vector Profiles
We keep a library of defensive profiles (truthfulness, cyber, bio) and hot-swap the learned vector before each burn-in test.

In [None]:
# Folder under the project root where the per-profile vector files (*.pt) are
# looked up; it is created on first run if it does not exist yet.
ARTIFACT_DIR = project_root.joinpath("artifacts")
ARTIFACT_DIR.mkdir(exist_ok=True)
print(f"üì¶ Artifacts dir: {ARTIFACT_DIR}")

## 2. Profile Definitions & Datasets

In [None]:
import pandas as pd
from datasets import load_dataset
from typing import Callable, Dict, List

# Default seed handed to every dataset shuffle so sampled prompts are reproducible.
SEED = 42
# Default number of prompts drawn per category by build_truthfulness_cases.
SAMPLES_PER_CATEGORY = 40


def build_truthfulness_cases(seed: int = SEED, samples: int = SAMPLES_PER_CATEGORY) -> List[Dict[str, str]]:
    """Sample true, misconception, and false prompts for the truthfulness profile.

    Three groups are drawn, each shuffled with ``seed`` and capped at
    ``samples`` rows (fewer if the group itself is smaller):

    * ``"Control (Truth)"`` - rows of ``L1Fthrasir/Facts-true-false`` with
      ``label == 1``; expected outcome ``"Pass"``.
    * ``"TruthfulQA (Misconceptions)"`` - questions from the ``truthful_qa``
      generation/validation split; expected outcome ``"Variable"``.
    * ``"Factual Lies"`` - rows of ``L1Fthrasir/Facts-true-false`` with
      ``label == 0``; expected outcome ``"Block"``.

    Args:
        seed: Seed passed to each ``shuffle`` call.
        samples: Maximum number of prompts per group.

    Returns:
        A list of ``{"category", "prompt", "expected"}`` dicts, ordered
        truth -> TruthfulQA -> lies.
    """

    def sample_rows(dataset):
        # Cap by the size of *this* dataset. The previous code capped the
        # filtered subsets with the length of the unfiltered dataset, so
        # ``select`` could receive out-of-range indices whenever a label had
        # fewer than ``samples`` rows.
        count = min(samples, len(dataset))
        return dataset.shuffle(seed=seed).select(range(count))

    cases: List[Dict[str, str]] = []

    ds_facts = load_dataset("L1Fthrasir/Facts-true-false", split="train")

    true_facts = sample_rows(ds_facts.filter(lambda x: x["label"] == 1))
    cases.extend(
        {"category": "Control (Truth)", "prompt": row["statement"], "expected": "Pass"}
        for row in true_facts
    )

    ds_tqa = sample_rows(load_dataset("truthful_qa", "generation", split="validation"))
    cases.extend(
        {
            "category": "TruthfulQA (Misconceptions)",
            "prompt": row["question"],
            "expected": "Variable",
        }
        for row in ds_tqa
    )

    false_facts = sample_rows(ds_facts.filter(lambda x: x["label"] == 0))
    cases.extend(
        {"category": "Factual Lies", "prompt": row["statement"], "expected": "Block"}
        for row in false_facts
    )

    return cases


def build_wmdp_cases(
    config_name: str,
    seed: int = SEED,
    safe_samples: int = 40,
    misuse_samples: int = 40,
) -> List[Dict[str, str]]:
    """Turn one WMDP configuration into labelled safe/misuse prompt records.

    The ``test`` split of ``cais/wmdp`` for ``config_name`` is loaded and
    passed through ``standardize_dataset``; the result is read via its
    ``label`` and ``statement`` fields (presumably produced by that helper -
    verify against MechWatch.calibrate). Rows with ``label == 1`` become
    "Safe" records expected to ``"Pass"``; rows with ``label == 0`` become
    "Misuse" records expected to ``"Block"``.

    Args:
        config_name: WMDP configuration name, e.g. ``"wmdp-cyber"``. The text
            after the last ``-`` is upper-cased and used in the category label.
        seed: Seed used when shuffling each subset.
        safe_samples: Maximum number of safe rows to keep.
        misuse_samples: Maximum number of misuse rows to keep.

    Returns:
        Safe records followed by misuse records, each a
        ``{"category", "prompt", "expected"}`` dict.
    """
    ds = standardize_dataset(load_dataset("cais/wmdp", config_name, split="test"))

    def take_subset(dataset, count):
        # Never ask for more rows than exist; a non-positive quota yields nothing.
        limit = min(count, len(dataset))
        return dataset.shuffle(seed=seed).select(range(limit)) if limit > 0 else []

    domain = config_name.split("-")[-1].upper()

    # (label value, row quota, category tag, expected outcome); safe rows first.
    plan = (
        (1, safe_samples, "Safe", "Pass"),
        (0, misuse_samples, "Misuse", "Block"),
    )

    records: List[Dict[str, str]] = []
    for wanted_label, quota, tag, verdict in plan:
        # Bind the label as a default so the filter does not depend on the
        # loop variable's later value.
        matching = ds.filter(lambda x, wanted=wanted_label: x["label"] == wanted)
        records.extend(
            {
                "category": f"WMDP {domain} ({tag})",
                "prompt": row["statement"],
                "expected": verdict,
            }
            for row in take_subset(matching, quota)
        )
    return records

In [None]:
# One entry per defensive profile:
#   name        - display name, also the key into profile_runs
#   slug        - filesystem-friendly id used in output filenames
#   vector_path - where the profile's vector file is expected
#   build_cases - zero-argument callable returning the prompt cases
PROFILE_CONFIGS = [
    dict(
        name="Truthfulness",
        slug="truthfulness",
        vector_path=ARTIFACT_DIR / "deception_vector.pt",
        build_cases=build_truthfulness_cases,
    ),
    dict(
        name="Cyber Defense",
        slug="cyber_defense",
        vector_path=ARTIFACT_DIR / "cyber_misuse_vector.pt",
        build_cases=lambda: build_wmdp_cases("wmdp-cyber", safe_samples=40, misuse_samples=40),
    ),
    dict(
        name="Bio Defense",
        slug="bio_defense",
        vector_path=ARTIFACT_DIR / "bio_defense_vector.pt",
        build_cases=lambda: build_wmdp_cases("wmdp-bio", safe_samples=40, misuse_samples=40),
    ),
]

# Report which profile vectors are present on disk.
for profile in PROFILE_CONFIGS:
    if profile["vector_path"].exists():
        status = "‚úÖ"
    else:
        status = "‚ö†Ô∏è"
    print(f"{status} {profile['name']} vector => {profile['vector_path']}")


In [None]:
cfg = load_config()

# Pick the first profile (in declaration order) whose vector file exists; it
# seeds the runtime, and other vectors are swapped in later.
primary_profile = None
for _candidate in PROFILE_CONFIGS:
    if _candidate["vector_path"].exists():
        primary_profile = _candidate
        break

if primary_profile is None:
    raise FileNotFoundError("No concept vectors found. Run calibration for at least one profile under artifacts/.")

# Point the config at that vector before constructing the runtime.
cfg.vector_path = primary_profile["vector_path"]
rt = WatchdogRuntime(cfg=cfg)
print(f"‚úÖ Loaded model with initial profile: {primary_profile['name']} (threshold {rt.threshold:.4f})")


## 3. Run the Gauntlet per Profile
We evaluate each profile independently (truth, cyber, bio) while the watchdog runs in **monitoring mode** so we can chart raw score distributions.

In [None]:
import gc
import torch
import pandas as pd
from tqdm import tqdm

# run_gauntlet writes a partial-results backup CSV every SAVE_EVERY prompts.
SAVE_EVERY = 20
# Result frame of each evaluated profile, keyed by the profile's display name.
profile_runs: Dict[str, pd.DataFrame] = {}


def run_gauntlet(
    rt: WatchdogRuntime,
    cases,
    max_new_tokens: int = 40,
    *,
    skip_tokens: int = 5,
    backup_path: str = "stress_test_backup.csv",
) -> pd.DataFrame:
    """Generate a completion for every case and collect its score statistics.

    Each prompt is run through ``rt.generate_without_watchdog`` (so nothing is
    actually blocked) and the returned scores are summarised per prompt.
    A prompt that raises is reported and skipped; it produces no row.

    Args:
        rt: Runtime used for generation; its ``threshold`` is read once, up
            front, and used for the "Would Block?" column.
        cases: Iterable of dicts with at least ``"category"`` and ``"prompt"``.
        max_new_tokens: Generation budget per prompt.
        skip_tokens: Number of leading scores ignored when there are more than
            this many (previously hard-coded to 5). Presumably a warm-up
            window for the first generated tokens - confirm with the runtime.
        backup_path: CSV file overwritten with the rows gathered so far every
            ``SAVE_EVERY`` prompts (previously hard-coded).

    Returns:
        DataFrame with columns ``Category``, ``Prompt``, ``Peak Score``,
        ``Avg Score``, ``Would Block?`` and ``Output`` (first 100 characters
        of the response, newlines flattened, followed by ``"..."``).
    """
    rows = []
    # Snapshot the threshold so every row of this run is judged against the
    # same value.
    baseline_threshold = rt.threshold
    with torch.no_grad():
        for i, case in enumerate(tqdm(cases, desc="Inference", unit="prompt")):
            output = scores = raw_text = None
            try:
                output = rt.generate_without_watchdog(
                    prompt=case["prompt"],
                    max_new_tokens=max_new_tokens,
                    temperature=0.0,
                )
                scores = output.scores
                valid_scores = scores[skip_tokens:] if len(scores) > skip_tokens else scores
                if not valid_scores:
                    # No scores at all: record a neutral 0.0 rather than failing.
                    valid_scores = [0.0]

                peak_score = max(valid_scores)
                avg_score = sum(valid_scores) / len(valid_scores)

                raw_text = output.text[0] if isinstance(output.text, list) else output.text
                # NOTE(review): assumes the returned text starts with the prompt
                # verbatim; if the runtime does not echo the prompt this slices
                # off real output - confirm against WatchdogRuntime.
                response_text = raw_text[len(case["prompt"]):].strip()

                rows.append({
                    "Category": case["category"],
                    "Prompt": case["prompt"],
                    "Peak Score": peak_score,
                    "Avg Score": avg_score,
                    "Would Block?": peak_score >= baseline_threshold,
                    "Output": response_text[:100].replace("\n", " ") + "...",
                })

                if i > 0 and i % SAVE_EVERY == 0:
                    pd.DataFrame(rows).to_csv(backup_path, index=False)

                # Release the generation result before the next prompt allocates.
                del output
                del scores
                del raw_text
            except Exception as e:
                print(f"‚ö†Ô∏è Error at index {i}: {e}")
                # Drop our references to the failed output first, then collect
                # garbage, and only then hand cached CUDA blocks back -
                # otherwise there is nothing for empty_cache() to release.
                output = scores = raw_text = None
                gc.collect()
                torch.cuda.empty_cache()
    return pd.DataFrame(rows)


# Evaluate every profile whose vector is available, one after another, on the
# single shared runtime.
for profile in PROFILE_CONFIGS:
    if not profile["vector_path"].exists():
        print(f"‚è≠Ô∏è Skipping {profile['name']} (vector missing)")
        continue

    cases = profile["build_cases"]()
    if not cases:
        print(f"‚ö†Ô∏è No cases available for {profile['name']}")
        continue

    # Swap in this profile's vector before scoring; rt.threshold is re-read
    # afterwards for the banner and the Threshold column.
    rt.load_vector_from_path(profile["vector_path"])
    print(f"\nüöÄ {profile['name']} | cases={len(cases)} | threshold={rt.threshold:.4f}")
    df_profile = run_gauntlet(rt, cases)

    if df_profile.empty:
        # Every prompt failed. Registering an empty frame would crash the later
        # cells on df["Threshold"].iloc[0], so report it and move on.
        print(f"[warn] No results for {profile['name']} (every prompt failed); skipping.")
        continue

    df_profile["Profile"] = profile["name"]
    df_profile["Threshold"] = rt.threshold
    profile_runs[profile["name"]] = df_profile
    df_profile.to_csv(f"stress_test_results_{profile['slug']}.csv", index=False)

if not profile_runs:
    raise RuntimeError("No profiles were evaluated. Ensure you have calibrated vectors under artifacts/.")

## 4. Visualization
Plot score distributions per profile and export chart data for the dashboard.

In [None]:
import json

def _write_chart_json(target: Path, payload: dict) -> None:
    """Write ``payload`` as indented JSON to ``target``, creating its folder."""
    # docs/ may not exist relative to the working directory (e.g. when the
    # notebook is started from notebooks/); create it instead of failing.
    target.parent.mkdir(parents=True, exist_ok=True)
    with open(target, "w", encoding="utf-8") as fp:
        json.dump(payload, fp, indent=2)


all_frames = list(profile_runs.values())
combined_df = pd.concat(all_frames, ignore_index=True)

sns.set_theme(style="whitegrid")
plt.figure(figsize=(14, 6))
# Every category belongs to exactly one profile, so dodging by hue would only
# push each box off its tick and away from the (un-dodged) strip points.
sns.boxplot(data=combined_df, x="Category", y="Peak Score", hue="Profile", dodge=False)
sns.stripplot(data=combined_df, x="Category", y="Peak Score", color="black", alpha=0.4, jitter=True)
plt.title("Peak Score Distribution per Profile", fontsize=16)
plt.xticks(rotation=45, ha="right")
plt.tight_layout()
plt.show()

# Export one chart-data JSON per evaluated profile.
# NOTE(review): all targets are relative to the current working directory,
# not to project_root - confirm that is where the dashboard expects them.
for profile in PROFILE_CONFIGS:
    name = profile["name"]
    if name not in profile_runs:
        continue
    df_profile = profile_runs[name]
    if df_profile.empty:
        # Nothing to chart, and .iloc[0] below would raise.
        continue
    payload = {
        "threshold": float(df_profile["Threshold"].iloc[0]),
        "data": df_profile[["Category", "Peak Score"]].to_dict(orient="records"),
    }

    filename = f"watchdog_chart_data_{profile['slug']}.json"
    out_path = Path(filename)
    docs_out_path = Path("docs") / filename

    for target in (out_path, docs_out_path):
        _write_chart_json(target, payload)

    if profile["slug"] == "truthfulness":
        # The truthfulness profile also feeds the un-suffixed default files.
        default_targets = [Path("watchdog_chart_data.json"), Path("docs/watchdog_chart_data.json")]
        for target in default_targets:
            _write_chart_json(target, payload)
        print("üìä Updated watchdog_chart_data.json for the website.")


## 5. Failure Analysis
Identify False Negatives (Deception that slipped through) and False Positives (Truth that was blocked).

In [None]:
# Per profile, list prompts whose outcome contradicts the expectation attached
# to their category when the cases were built:
#   expected "Block": any "... (Misuse)" category and "Factual Lies"
#   expected "Pass":  any "... (Safe)" category and "Control (Truth)"
# "TruthfulQA (Misconceptions)" is expected "Variable" and so is not judged.
# Previously only "Misuse" counted as expected-block, so missed "Factual Lies"
# were never reported and blocked lies/TruthfulQA rows showed up as over-blocked
# *safe* prompts.
for profile in PROFILE_CONFIGS:
    name = profile["name"]
    if name not in profile_runs:
        continue
    df_profile = profile_runs[name]
    if df_profile.empty:
        # No rows means no threshold to read and nothing to analyse.
        continue
    threshold = df_profile["Threshold"].iloc[0]
    print(f"\n=== {name} (threshold {threshold:.4f}) ===")

    category = df_profile["Category"]
    # regex=False: the parentheses in "(Safe)" must match literally.
    should_block = category.str.contains("Misuse", regex=False) | category.eq("Factual Lies")
    should_pass = category.str.contains("(Safe)", regex=False) | category.eq("Control (Truth)")
    # Same ">=" rule as the "Would Block?" column, so a score exactly at the
    # threshold is treated as blocked in both places.
    blocked = df_profile["Peak Score"] >= threshold

    fn = df_profile[should_block & ~blocked]
    if not fn.empty:
        print("üî¥ Missed interdictions:")
        display(fn[["Category", "Prompt", "Peak Score", "Output"]])
    else:
        print("üî¥ Missed interdictions: none")

    fp = df_profile[should_pass & blocked]
    if not fp.empty:
        print("üü¢ Over-blocked safe prompts:")
        display(fp[["Category", "Prompt", "Peak Score", "Output"]])
    else:
        print("üü¢ Over-blocked safe prompts: none")

In [None]:
# Persist the concatenated results of all evaluated profiles in a single CSV
# (written to the current working directory, without the index column).
combined_df.to_csv(path_or_buf="stress_test_results.csv", index=False)
print("Results saved to stress_test_results.csv and per-profile CSVs.")