# Early instability metrics (heuristic)

This notebook computes a few **instability-oriented metrics** on JSONL traces:

- relative latency gaps between consecutive events,
- recovery turn distance (from first `drift_like` to `stability_tag='recovered'`),
- post-correction relapse rate,
- a simple session closure profile.

It mirrors the logic of `scripts/compute_metrics_from_jsonl.py`, but is
kept inline here for interactive exploration.


In [None]:
import json
import statistics
from collections import Counter, defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple

# Directory holding the bundled example traces, relative to the working directory.
DATA_DIR = Path("examples", "synthetic_traces")
print(f"Using synthetic traces from: {DATA_DIR.resolve()}")


In [None]:
@dataclass
class Event:
    """Accessor wrapper around one decoded JSONL trace record.

    Every property tolerates missing or malformed fields and falls back to a
    neutral value (empty string, empty dict, or ``None``) instead of raising.
    """

    # The decoded JSON object for this record, as loaded.
    raw: Dict[str, Any]

    @property
    def trace_id(self) -> str:
        """Return the record's ``trace_id`` as a string ('' when absent)."""
        return str(self.raw.get("trace_id", ""))

    @property
    def event_type(self) -> str:
        """Return the record's ``event_type`` as a string ('' when absent)."""
        return str(self.raw.get("event_type", ""))

    @property
    def component(self) -> str:
        """Return the record's ``component`` as a string ('' when absent)."""
        return str(self.raw.get("component", ""))

    @property
    def payload(self) -> Dict[str, Any]:
        """Return the ``payload`` mapping, or an empty dict if it is not a dict."""
        obj = self.raw.get("payload")
        return obj if isinstance(obj, dict) else {}

    @property
    def latency_ms(self) -> Optional[float]:
        """Return ``payload['latency_ms']`` as a float, or ``None`` if unusable."""
        val = self.payload.get("latency_ms")
        if val is None:
            return None
        try:
            return float(val)
        except (TypeError, ValueError, OverflowError):
            # OverflowError: integers too large to be represented as a float.
            return None

    @property
    def turn(self) -> Optional[int]:
        """Return the turn index, preferring the payload over the top level.

        The first of ``payload['turn']`` and ``raw['turn']`` that is present
        and convertible to ``int`` wins; ``None`` is returned when neither is.
        """
        for source in (self.payload, self.raw):
            val = source.get("turn")
            # Compare against None explicitly: turn 0 is a valid index and
            # must not be mistaken for "missing" by a truthiness test.
            if val is None:
                continue
            try:
                return int(val)
            except (TypeError, ValueError, OverflowError):
                # OverflowError covers int(float("inf")); try the next source.
                continue
        return None


In [None]:
def load_events(path: Path) -> List[Event]:
    """Read a JSONL trace file and return its records as ``Event`` objects.

    The file is read as UTF-8, one JSON document per line. Blank lines,
    lines that are not valid JSON, and lines whose JSON value is not an
    object are skipped silently, so a partially corrupt file still loads.

    Args:
        path: Location of the JSONL file.

    Returns:
        The events in file order.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    events: List[Event] = []
    with path.open("r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
            except json.JSONDecodeError:
                continue
            # A line such as `[1, 2]`, `42` or `null` is valid JSON but not a
            # record; Event's accessors call `.get` on it and would crash.
            if not isinstance(obj, dict):
                continue
            events.append(Event(obj))
    return events

def group_by_trace(events: Iterable[Event]) -> Dict[str, List[Event]]:
    """Bucket events by their ``trace_id``, keeping input order within a bucket.

    Returns a ``defaultdict(list)`` keyed by trace id; keys appear in the
    order their first event was encountered.
    """
    by_trace: Dict[str, List[Event]] = defaultdict(list)
    for event in events:
        bucket = by_trace[event.trace_id]
        bucket.append(event)
    return by_trace


In [None]:
def compute_relative_latency_gaps(traces: Dict[str, List[Event]]) -> List[float]:
    """Collect relative latency jumps between neighbouring events of each trace.

    Events without a latency are ignored, so each gap compares two successive
    events that both carry one. A gap is ``|a - b| / max(a, b, 1.0)``; the
    floor of 1.0 keeps the denominator away from zero for tiny latencies.
    Gaps never span two different traces.
    """
    all_gaps: List[float] = []
    for session in traces.values():
        readings = [
            value
            for value in (event.latency_ms for event in session)
            if value is not None
        ]
        # Pair each reading with its successor.
        for earlier, later in zip(readings, readings[1:]):
            scale = max(earlier, later, 1.0)
            all_gaps.append(abs(earlier - later) / scale)
    return all_gaps


In [None]:
def compute_recovery_turn_distances(traces: Dict[str, List[Event]]) -> List[int]:
    """Measure, in turns, how long each drift episode lasted before recovery.

    Within a trace an episode opens at the first ``drift_like`` event seen
    while no episode is open, and closes at the next later event whose
    payload has ``stability_tag == "recovered"``. A missing onset turn counts
    as 0, a missing recovery turn counts as the onset turn, and negative
    distances are clamped to 0. Episodes that never recover are not reported.
    """
    distances: List[int] = []
    for session in traces.values():
        episode_start: Optional[int] = None
        for event in session:
            if episode_start is None:
                # No open episode: only a drift event is of interest.
                if event.event_type == "drift_like":
                    turn = event.turn
                    episode_start = 0 if turn is None else turn
                continue
            if event.payload.get("stability_tag") != "recovered":
                continue
            turn = event.turn
            episode_end = episode_start if turn is None else turn
            distances.append(max(0, episode_end - episode_start))
            episode_start = None
    return distances


In [None]:
def compute_post_correction_relapse_rate(traces: Dict[str, List[Event]]):
    """Count sessions that drifted again after their first correction.

    A session counts as corrected once it contains a ``correction`` or
    ``self_check`` event, and as relapsed if any later event is ``drift_like``.

    Returns:
        A ``(sessions_with_relapse, sessions_with_correction)`` tuple of ints.
    """
    corrected_sessions = 0
    relapsed_sessions = 0
    for session in traces.values():
        # One shared iterator: the first scan stops just past the first
        # correction, so the second scan only sees the events after it.
        kinds = iter(event.event_type for event in session)
        if not any(kind in {"correction", "self_check"} for kind in kinds):
            continue
        corrected_sessions += 1
        if any(kind == "drift_like" for kind in kinds):
            relapsed_sessions += 1
    return relapsed_sessions, corrected_sessions


In [None]:
# Maps a raw status/pattern string found on a session's last event to the
# closure label reported in the profile.
CLOSURE_LABELS = {
    "ok": "natural_completion",
    "completed_after_correction": "completed_after_correction",
    "corrected": "completed_after_correction",
    "incomplete": "incomplete",
    "error": "forced_stop",
}


def classify_session_closure(events: List[Event]) -> str:
    """Label how a session ended, judging only by its last event.

    The last event's payload fields ``status``, ``final_status`` and
    ``pattern`` are checked in that order (stringified and stripped); the
    first one found in ``CLOSURE_LABELS`` decides the label. Otherwise a
    ``session_end`` event yields ``"session_end_generic"``, a last event from
    the ``user`` component yields ``"user_abandonment"``, and anything else,
    including an empty session, yields ``"unknown"``.
    """
    if not events:
        return "unknown"

    final_event = events[-1]
    final_payload = final_event.payload
    for field_name in ("status", "final_status", "pattern"):
        candidate = str(final_payload.get(field_name, "")).strip()
        if candidate and CLOSURE_LABELS.get(candidate):
            return CLOSURE_LABELS[candidate]

    # No recognised status: fall back to structural hints on the last event.
    if final_event.event_type == "session_end":
        return "session_end_generic"
    return "user_abandonment" if final_event.component == "user" else "unknown"

def compute_session_closure_profile(traces: Dict[str, List[Event]]) -> Counter:
    """Tally closure labels across all sessions.

    Each session is classified with ``classify_session_closure``; the result
    maps every label that occurred to the number of sessions that got it.
    """
    return Counter(classify_session_closure(session) for session in traces.values())


In [None]:
def summarize_file(path: Path):
    """Load one JSONL trace file and print a report of all metrics for it.

    A header with the event and session counts is always printed. Each of the
    four metric sections (relative latency gap, recovery turn distance,
    post-correction relapse rate, session closure profile) is printed only
    when it has data. Nothing is returned.
    """
    all_events = load_events(path)
    sessions = group_by_trace(all_events)
    print(
        f"\n=== {path} ===",
        f"events : {len(all_events)}",
        f"sessions: {len(sessions)}\n",
        sep="\n",
    )

    # Every section ends with an empty string so that a blank line follows it.
    latency_gaps = compute_relative_latency_gaps(sessions)
    if latency_gaps:
        print(
            "[relative-latency-gap]",
            f"  samples: {len(latency_gaps)}",
            f"  mean   : {statistics.mean(latency_gaps):.3f}",
            f"  median : {statistics.median(latency_gaps):.3f}",
            "",
            sep="\n",
        )

    recovery_distances = compute_recovery_turn_distances(sessions)
    if recovery_distances:
        print(
            "[recovery-turn-distance]",
            f"  episodes: {len(recovery_distances)}",
            f"  mean    : {statistics.mean(recovery_distances):.2f} turns",
            f"  median  : {statistics.median(recovery_distances):.2f} turns",
            "",
            sep="\n",
        )

    relapsed, corrected = compute_post_correction_relapse_rate(sessions)
    if corrected:
        rate = (relapsed / corrected) * 100
        print(
            "[post-correction-relapse-rate]",
            f"  sessions with correction: {corrected}",
            f"  sessions with relapse   : {relapsed}",
            f"  relapse rate            : {rate:.1f}%",
            "",
            sep="\n",
        )

    closure_counts = compute_session_closure_profile(sessions)
    if closure_counts:
        total = sum(closure_counts.values())
        rows = [
            f"  {label:28s}: {count:3d}  ({(count / total) * 100:4.1f}%)"
            for label, count in closure_counts.most_common()
        ]
        print("[session-closure-profile]", *rows, "", sep="\n")


In [None]:
# Print the metric report for each bundled synthetic trace file in turn.
for name in ("simple_correction_loop.jsonl", "noisy_mixed_sessions.jsonl"):
    path = DATA_DIR.joinpath(name)
    summarize_file(path)
