# ML Lab 05: Drift Analysis Notebook

This notebook queries Prometheus metrics collected during the drift simulation
and performs post-hoc analysis to understand how drift manifests in model behavior.

**Prerequisites:** Run the drift simulator first (`docker compose --profile simulator up drift-simulator`)
and ensure Prometheus is running at http://localhost:9090.

## Section 1: Query Prometheus Metrics

We'll query the Prometheus HTTP API to pull time-series data for our drift metrics.
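
For orientation, a range query returns a JSON "matrix": one entry per series, each with a `metric` label set and a list of `[timestamp, "value"]` pairs in which the value is a string. The sketch below parses a hand-written payload of that shape, so it runs without a Prometheus server. The label `model="demo"`, the sample values, and the helper name `matrix_to_frame` are made up for illustration.

```python
import pandas as pd

# Hand-written payload in the shape returned by /api/v1/query_range.
# The label and the numbers are illustrative, not real measurements.
sample_response = {
    "status": "success",
    "data": {
        "resultType": "matrix",
        "result": [
            {
                "metric": {"model": "demo"},
                "values": [[1700000000, "0.91"], [1700000015, "0.89"], [1700000030, "0.84"]],
            }
        ],
    },
}


def matrix_to_frame(payload: dict) -> pd.DataFrame:
    """Turn a query_range matrix payload into a DataFrame with one column per series."""
    columns = {}
    for series in payload["data"]["result"]:
        labels = ", ".join(f"{k}={v}" for k, v in sorted(series["metric"].items()))
        index = pd.to_datetime([float(ts) for ts, _ in series["values"]], unit="s")
        # Sample values arrive as strings, so convert them explicitly.
        columns[labels or "value"] = pd.Series([float(v) for _, v in series["values"]], index=index)
    return pd.DataFrame(columns)


sample_frame = matrix_to_frame(sample_response)
print(sample_frame)
```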

In [None]:
import requests
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime, timedelta

PROMETHEUS_URL = "http://localhost:9090"


def query_prometheus_range(query: str, start: str = "now-15m", end: str = "now", step: str = "15s") -> pd.DataFrame:
    """Query the Prometheus range API and return a DataFrame with one column per series.

    `start` accepts a relative offset such as "now-15m", "now-30m", or "now-1h"
    (units: s, m, h) or a Unix timestamp in seconds; `end` accepts "now" or a
    Unix timestamp in seconds.
    """
    # Convert relative times to absolute Unix timestamps
    now = datetime.now()
    units = {"s": "seconds", "m": "minutes", "h": "hours"}
    if start.startswith("now-") and start[-1] in units:
        offset = timedelta(**{units[start[-1]]: float(start[4:-1])})
        start_ts = (now - offset).timestamp()
    else:
        start_ts = float(start)
    end_ts = now.timestamp() if end == "now" else float(end)

    resp = requests.get(
        f"{PROMETHEUS_URL}/api/v1/query_range",
        params={"query": query, "start": start_ts, "end": end_ts, "step": step},
        timeout=10,  # fail fast if Prometheus is unreachable
    )
    resp.raise_for_status()
    data = resp.json()

    if data["status"] != "success" or not data["data"]["result"]:
        print(f"No data returned for query: {query}")
        return pd.DataFrame()

    frames = []
    for series in data["data"]["result"]:
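        metric_label = series["metric"].get("__name__", "")
        label = ", ".join(f"{k}={v}" for k, v in series["metric"].items() if k != "__name__")
        # Arithmetic expressions (such as a ratio of rates) drop __name__, so fall
        # back to the labels, then to the query text, to avoid an empty column name
        if metric_label and label:
            name = f"{metric_label} ({label})"
        else:
            name = metric_label or label or query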

        values = series["values"]
        timestamps = [datetime.fromtimestamp(float(v[0])) for v in values]
        vals = [float(v[1]) for v in values]
        df = pd.DataFrame({"timestamp": timestamps, name: vals})
        df.set_index("timestamp", inplace=True)
        frames.append(df)

    if frames:
        return pd.concat(frames, axis=1)
    return pd.DataFrame()


print("Prometheus query helper loaded.")
print(f"Querying: {PROMETHEUS_URL}")

## Section 2: Confidence Distribution Over Time

Prediction confidence is one of the strongest drift signals. When the model receives
out-of-distribution data, it often becomes less certain about its predictions, and the
average confidence drops.
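
The queries in this notebook divide the rate of a `_sum` series by the rate of the matching `_count` series, which yields the mean observed value over the window instead of the mean since the process started. The sketch below reproduces that arithmetic with two made-up scrapes of the cumulative counters; it treats `rate()` as a simple per-second increase and ignores Prometheus's extrapolation and counter-reset handling.

```python
# Two hypothetical scrapes of the cumulative _sum and _count series, 60 s apart.
# The numbers are invented to illustrate the arithmetic.
scrape_start = {"sum": 450.0, "count": 500}
scrape_end = {"sum": 531.0, "count": 600}
window_seconds = 60.0

# rate() is approximately the per-second increase of a counter over the window.
sum_rate = (scrape_end["sum"] - scrape_start["sum"]) / window_seconds
count_rate = (scrape_end["count"] - scrape_start["count"]) / window_seconds

# The window length cancels, leaving the mean confidence of the 100 new predictions.
window_mean = sum_rate / count_rate
lifetime_mean = scrape_end["sum"] / scrape_end["count"]

print(f"Mean over the window:  {window_mean:.3f}")
print(f"Mean since startup:    {lifetime_mean:.3f}")
```

With these numbers the windowed mean (0.81) sits well below the lifetime mean (0.885), which is why the windowed ratio reacts to drift faster than a running average would.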

In [None]:
# Query average confidence over time
confidence_query = (
    "rate(model_prediction_confidence_sum[1m]) / "
    "rate(model_prediction_confidence_count[1m])"
)

df_confidence = query_prometheus_range(confidence_query, start="now-30m")

if not df_confidence.empty:
    fig, ax = plt.subplots(figsize=(12, 5))
    df_confidence.plot(ax=ax)
    ax.set_title("Average Prediction Confidence Over Time", fontsize=14)
    ax.set_ylabel("Confidence")
    ax.set_xlabel("Time")
    ax.set_ylim(0.4, 1.0)
    ax.axhline(y=0.75, color="red", linestyle="--", alpha=0.7, label="Alert threshold (0.75)")
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

    print("\nConfidence Statistics:")
    print(df_confidence.describe())
else:
    print("No confidence data available. Run the drift simulator first.")

## Section 3: Text Length Distribution Shifts

Changes in input text length can indicate distribution shift. Different categories of
text (politics vs. baseball) may have characteristically different lengths.
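
As a rough illustration of what such a shift looks like numerically, the sketch below compares the mean length of a recent window against a baseline window. The data is synthetic: the means (800 and 1200 characters), the spread, and the window sizes are all invented, and relative change in the mean is just one simple way to summarize the shift.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)

# Synthetic per-request text lengths: the first 300 requests mimic a baseline,
# the last 100 mimic a shifted input mix. Both distributions are invented.
lengths = pd.Series(
    np.concatenate([rng.normal(800, 50, size=300), rng.normal(1200, 50, size=100)])
)

baseline_mean = lengths.iloc[:300].mean()
recent_mean = lengths.iloc[-100:].mean()
relative_change = (recent_mean - baseline_mean) / baseline_mean

print(f"Baseline mean length: {baseline_mean:.0f} characters")
print(f"Recent mean length:   {recent_mean:.0f} characters")
print(f"Relative change:      {relative_change:+.1%}")
```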

In [None]:
# Query average text length over time
text_length_query = (
    "rate(model_input_text_length_sum[1m]) / "
    "rate(model_input_text_length_count[1m])"
)

df_text_length = query_prometheus_range(text_length_query, start="now-30m")

if not df_text_length.empty:
    fig, ax = plt.subplots(figsize=(12, 5))
    df_text_length.plot(ax=ax, color="purple")
    ax.set_title("Average Input Text Length Over Time", fontsize=14)
    ax.set_ylabel("Characters")
    ax.set_xlabel("Time")
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

    print("\nText Length Statistics:")
    print(df_text_length.describe())
else:
    print("No text length data available. Run the drift simulator first.")

## Section 4: Population Stability Index (PSI)

PSI measures how much a distribution has shifted compared to a reference distribution.
It is commonly used in production ML systems to quantify drift.

**Interpretation:**
- PSI < 0.1: No significant shift
- 0.1 <= PSI < 0.2: Moderate shift (investigate)
- PSI >= 0.2: Significant shift (action required)

We'll compute PSI on the confidence distribution. The cell below demonstrates the
calculation on synthetic data; on live data you would compare the first window (baseline)
to subsequent windows.
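
Before binning real data, it can help to work the formula by hand. The sketch below uses three made-up bin proportions and sums the per-bin terms `(current - reference) * ln(current / reference)`; the arrays are illustrative only.

```python
import numpy as np

# Made-up bin proportions for a three-bin example; each array sums to 1.
reference_pct = np.array([0.5, 0.3, 0.2])
current_pct = np.array([0.3, 0.3, 0.4])

# Each bin contributes (current - reference) * ln(current / reference).
# Every term is non-negative because both factors share the same sign.
contributions = (current_pct - reference_pct) * np.log(current_pct / reference_pct)
psi_by_hand = float(contributions.sum())

for i, term in enumerate(contributions):
    print(f"bin {i}: contribution = {term:.4f}")
print(f"PSI = {psi_by_hand:.4f}")
```

The unchanged middle bin contributes nothing, and the total of roughly 0.24 falls in the "significant shift" band from the interpretation list above.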

In [None]:
def calculate_psi(reference: np.ndarray, current: np.ndarray, bins: int = 10) -> float:
    """Calculate Population Stability Index between two distributions.

    PSI = sum( (current_pct - reference_pct) * ln(current_pct / reference_pct) )

    Args:
        reference: Array of values from the reference (baseline) distribution.
        current: Array of values from the current distribution.
        bins: Number of bins for discretization.

    Returns:
        PSI value (float). Higher = more drift.
    """
    # Create equal-width bins spanning the combined range of both distributions
    breakpoints = np.linspace(
        min(reference.min(), current.min()),
        max(reference.max(), current.max()),
        bins + 1,
    )

    # Compute bin proportions
    ref_counts = np.histogram(reference, bins=breakpoints)[0]
    cur_counts = np.histogram(current, bins=breakpoints)[0]

    # Convert to proportions with small epsilon to avoid division by zero
    eps = 1e-6
    ref_pct = (ref_counts + eps) / (ref_counts.sum() + eps * bins)
    cur_pct = (cur_counts + eps) / (cur_counts.sum() + eps * bins)

    # PSI formula
    psi = np.sum((cur_pct - ref_pct) * np.log(cur_pct / ref_pct))
    return float(psi)


# Demonstrate PSI with synthetic data
np.random.seed(42)

# Simulate baseline confidence (high: Beta(20, 2) has mean 20/22, about 0.91)
baseline_confidence = np.random.beta(20, 2, size=200)  # skewed high

# Simulate drifted confidence (lower, more spread)
slight_drift = np.random.beta(10, 3, size=200)
moderate_drift = np.random.beta(5, 4, size=200)
heavy_drift = np.random.beta(3, 5, size=200)

# Calculate PSI for each phase
psi_slight = calculate_psi(baseline_confidence, slight_drift)
psi_moderate = calculate_psi(baseline_confidence, moderate_drift)
psi_heavy = calculate_psi(baseline_confidence, heavy_drift)

print("PSI Values (Confidence Distribution):")
print(f"  Baseline vs Baseline:  {calculate_psi(baseline_confidence, baseline_confidence):.4f} (should be ~0)")
print(f"  Baseline vs Slight:    {psi_slight:.4f}")
print(f"  Baseline vs Moderate:  {psi_moderate:.4f}")
print(f"  Baseline vs Heavy:     {psi_heavy:.4f}")
print()
print("Interpretation:")
print("  PSI < 0.1   = No significant shift")
print("  PSI 0.1-0.2 = Moderate shift (investigate)")
print("  PSI >= 0.2  = Significant shift (action required)")

In [None]:
# Visualize the distributions and PSI
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

distributions = [
    ("Baseline", baseline_confidence, 0.0),
    ("Slight Drift", slight_drift, psi_slight),
    ("Moderate Drift", moderate_drift, psi_moderate),
    ("Heavy Drift", heavy_drift, psi_heavy),
]

for ax, (name, dist, psi) in zip(axes.flat, distributions):
    ax.hist(baseline_confidence, bins=20, alpha=0.5, label="Baseline", color="blue", density=True)
    ax.hist(dist, bins=20, alpha=0.5, label=name, color="orange", density=True)
    ax.set_title(f"{name} (PSI = {psi:.4f})", fontsize=12)
    ax.set_xlabel("Confidence")
    ax.set_ylabel("Density")
    ax.legend()
    ax.set_xlim(0, 1)
    ax.grid(True, alpha=0.3)

    # Color code the PSI
    if psi < 0.1:
        ax.set_facecolor("#f0fff0")  # light green
    elif psi < 0.2:
        ax.set_facecolor("#fffff0")  # light yellow
    else:
        ax.set_facecolor("#fff0f0")  # light red

plt.suptitle("Confidence Distribution Shift (PSI Analysis)", fontsize=14, fontweight="bold")
plt.tight_layout()
plt.show()

## Section 5: Building a Drift Alert Threshold

Now let's combine our drift signals into a practical alerting strategy.
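
One way to sketch the "both signals, sustained" idea offline is with a boolean AND followed by a rolling window. The example below uses synthetic series sampled every 15 seconds and reuses the 0.75 confidence and 0.6 entropy thresholds from this notebook; the injected drift episode and the choice of 20 consecutive samples (roughly 5 minutes) are assumptions made for illustration.

```python
import pandas as pd

# Synthetic series sampled every 15 s for 20 minutes; values are invented.
index = pd.date_range("2024-01-01 12:00", periods=80, freq=pd.Timedelta(seconds=15))
confidence = pd.Series(0.90, index=index)
entropy = pd.Series(0.30, index=index)

# Inject a 10-minute drift episode (40 samples) where both signals breach.
confidence.iloc[30:70] = 0.65
entropy.iloc[30:70] = 0.75

# Both conditions must hold at the same instant.
breach = (confidence < 0.75) & (entropy > 0.6)

# Require the breach to persist for 20 consecutive samples (about 5 minutes at 15 s).
samples_required = 20
sustained = breach.astype(int).rolling(samples_required).sum() == samples_required

print(f"Samples in breach:      {int(breach.sum())}")
print(f"Samples with alert on:  {int(sustained.sum())}")
if sustained.any():
    print(f"Alert first fires at:   {sustained.idxmax()}")
```

The alert turns on only after the breach has persisted for the full window, so a brief dip in confidence alone would not trigger it.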

In [None]:
# Query entropy over time
entropy_query = (
    "rate(model_prediction_entropy_sum[1m]) / "
    "rate(model_prediction_entropy_count[1m])"
)

df_entropy = query_prometheus_range(entropy_query, start="now-30m")

if not df_entropy.empty and not df_confidence.empty:
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 8), sharex=True)

    # Confidence panel
    df_confidence.plot(ax=ax1, color="green")
    ax1.axhline(y=0.75, color="red", linestyle="--", alpha=0.7, label="Alert: Confidence < 0.75")
    ax1.set_title("Drift Signal: Confidence", fontsize=12)
    ax1.set_ylabel("Confidence")
    ax1.set_ylim(0.4, 1.0)
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Fill red where confidence below threshold
    for col in df_confidence.columns:
        ax1.fill_between(
            df_confidence.index,
            df_confidence[col],
            0.75,
            where=df_confidence[col] < 0.75,
            alpha=0.3,
            color="red",
        )

    # Entropy panel
    df_entropy.plot(ax=ax2, color="orange")
    ax2.axhline(y=0.6, color="red", linestyle="--", alpha=0.7, label="Alert: Entropy > 0.6")
    ax2.set_title("Drift Signal: Entropy", fontsize=12)
    ax2.set_ylabel("Entropy (bits)")
    ax2.set_xlabel("Time")
    ax2.set_ylim(0, 1.0)
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    # Fill red where entropy above threshold
    for col in df_entropy.columns:
        ax2.fill_between(
            df_entropy.index,
            df_entropy[col],
            0.6,
            where=df_entropy[col] > 0.6,
            alpha=0.3,
            color="red",
        )

    plt.suptitle("Drift Detection: Combined Signals with Alert Thresholds", fontsize=14, fontweight="bold")
    plt.tight_layout()
    plt.show()
else:
    print("Insufficient data. Run the drift simulator and try again.")

In [None]:
# Summary: Recommended alert rules
print("=" * 60)
print("RECOMMENDED DRIFT ALERT RULES")
print("=" * 60)
print()
print("1. Confidence Drop Alert")
print("   PromQL: rate(model_prediction_confidence_sum[5m])")
print("           / rate(model_prediction_confidence_count[5m]) < 0.75")
print("   Severity: Warning")
print("   Action: Investigate input data distribution")
print()
print("2. High Entropy Alert")
print("   PromQL: rate(model_prediction_entropy_sum[5m])")
print("           / rate(model_prediction_entropy_count[5m]) > 0.6")
print("   Severity: Warning")
print("   Action: Check for out-of-distribution inputs")
print()
print("3. PSI Threshold (batch job)")
print("   Calculate PSI on confidence distributions")
print("   every hour, comparing to the training baseline.")
print("   PSI >= 0.2 => Significant drift, consider retraining")
print()
print("4. Combined Alert (most robust)")
print("   Fire when BOTH confidence < 0.75 AND entropy > 0.6")
print("   for more than 5 minutes. This reduces false positives.")
print("=" * 60)

## Key Takeaways

1. **Confidence drops when models see unfamiliar data** -- this is one of the simplest drift signals to monitor, though it works best alongside other metrics.
2. **Entropy increases with uncertainty** -- for binary classification, the maximum entropy is 1.0 bit (a coin flip).
3. **PSI gives you a single number** -- use it to quantify drift and set thresholds for automated alerts.
4. **Combine multiple signals** -- no single metric is perfect; using confidence + entropy together reduces false positives.
5. **Detection is not remediation** -- once you detect drift, you still need a plan: retrain, fall back to a simpler model, or flag for human review.
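
To make takeaway 2 concrete, the sketch below computes Shannon entropy in bits for a two-class prediction at a few confidence levels. It assumes a base-2 logarithm, which matches the "Entropy (bits)" axis label used in this notebook; the helper name `binary_entropy_bits` is hypothetical, and how the served model actually exports its entropy metric is not shown here.

```python
import numpy as np


def binary_entropy_bits(p: float) -> float:
    """Shannon entropy, in bits, of a two-class prediction with probability p."""
    probs = np.array([p, 1.0 - p])
    probs = probs[probs > 0]  # treat 0 * log2(0) as 0
    return float(-(probs * np.log2(probs)).sum())


for p in (0.5, 0.75, 0.9, 0.99):
    print(f"p = {p:.2f} -> entropy = {binary_entropy_bits(p):.3f} bits")
```

Entropy peaks at 1 bit when the model is split 50/50 and shrinks toward 0 as the prediction becomes more confident.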