# In [None]:
# TU Q091-A · Equilibrium climate sensitivity range reasoning (MVP, fixed OpenAI client)
#
# This notebook is designed as a single-cell Colab script:
# 1. It defines a small synthetic item bank about equilibrium climate sensitivity (ECS).
# 2. It asks a chat model to estimate ECS and explain the estimate.
# 3. It probes the explanation and computes a scalar tension observable T_ECS_range.
#
# You can:
# - Read the code and printed tables without running any live API calls.
# - Optionally enter an OpenAI API key when prompted to reproduce the experiment.
#   The key is requested via a hidden prompt and is NOT stored in the notebook text.

import sys
import subprocess
import math
import json
import textwrap
from dataclasses import dataclass
from typing import Optional, Dict, Any, List

import os
from getpass import getpass

import pandas as pd
import matplotlib.pyplot as plt

# ---------------------------------------------------------------------
# 0. OpenAI client setup (new-style API)
# ---------------------------------------------------------------------

def ensure_openai_installed() -> None:
    """
    Make sure an ``openai>=1.0.0`` package (the one exposing ``OpenAI``) is importable.

    If ``from openai import OpenAI`` fails, pip-install ``openai>=1.0.0`` into the
    running interpreter. A failed install is reported rather than silently
    ignored; callers must still handle a later ImportError.
    """
    try:
        from openai import OpenAI  # noqa: F401
        return
    except ImportError:
        print("[setup] openai package not found. Installing openai>=1.0.0 ...")

    result = subprocess.run(
        [sys.executable, "-m", "pip", "install", "-U", "openai>=1.0.0"],
        check=False,
    )
    if result.returncode != 0:
        print(
            f"[setup] pip install failed (exit code {result.returncode}); "
            "live OpenAI calls will not be available."
        )
        return

    import importlib

    # The import system caches directory listings; refresh them so the freshly
    # installed package is found in this same process.
    importlib.invalidate_caches()
    # If an old (<1.0) openai was imported by the failed attempt above, it is
    # still cached in sys.modules and would shadow the upgraded package.
    for name in [m for m in sys.modules if m == "openai" or m.startswith("openai.")]:
        del sys.modules[name]


def build_openai_client() -> Optional["OpenAI"]:
    """
    Ask the user whether to run live OpenAI calls and, if so, build a client.

    Offline mode (``None``) is returned when:
      * the user does not answer "y"/"yes",
      * the openai>=1.0 package cannot be imported even after an install attempt,
      * or no API key is entered.

    The openai package is only installed/imported after live mode is chosen, so
    offline runs never invoke pip. The key is read with a hidden prompt and is
    passed to the client without being printed.

    Returns:
        An ``OpenAI`` client bound to the entered key, or None for offline mode.
    """
    choice = input(
        "\n[Q091-A] Run with live OpenAI calls? (y/n) "
        "(n = offline mode, just show item bank and exit): "
    ).strip().lower()

    if choice not in ("y", "yes"):
        print("\n[mode] Offline mode selected. "
              "No OpenAI calls will be made. "
              "You can still inspect the item bank and code.")
        return None

    ensure_openai_installed()
    try:
        from openai import OpenAI  # type: ignore
    except ImportError as err:
        # Install failed or an incompatible version is present: degrade gracefully.
        print(f"\n[mode] Could not import openai>=1.0 ({err}). Staying in offline mode.")
        return None

    api_key = getpass("\n[secure] Enter your OpenAI API key (input is hidden). "
                      "Press Enter to cancel and stay offline: ").strip()

    if not api_key:
        print("\n[mode] No API key provided. Staying in offline mode.")
        return None

    # Build client with explicit API key (not printed anywhere).
    client = OpenAI(api_key=api_key)
    print("\n[setup] OpenAI client initialised. "
          "Model calls will use your key inside this runtime only.")
    return client


# ---------------------------------------------------------------------
# 1. Data structures and synthetic item bank
# ---------------------------------------------------------------------

# Global plausibility window for ECS (°C per CO₂ doubling). Model intervals are
# scored by how much of them falls inside [ECS_GLOBAL_MIN, ECS_GLOBAL_MAX], and
# interval sharpness is measured relative to this window's width.
ECS_GLOBAL_MIN, ECS_GLOBAL_MAX = 0.5, 6.0

@dataclass
class ECSItem:
    """
    One synthetic ECS reasoning item with its ground-truth bucket and range.

    Raises:
        ValueError: at construction if ``bucket_true`` is not exactly one of
            "LOW", "MEDIUM", "HIGH" (the labels scores are compared against),
            or if ``ecs_min_true > ecs_max_true``.
    """

    item_id: str  # short identifier, e.g. "C01"
    title: str  # stylised description shown to the model
    bucket_true: str  # ground-truth qualitative bucket
    ecs_min_true: float  # lower end of the ground-truth ECS range (°C)
    ecs_max_true: float  # upper end of the ground-truth ECS range (°C)

    # Un-annotated class attribute: not a dataclass field.
    _VALID_BUCKETS = ("LOW", "MEDIUM", "HIGH")

    def __post_init__(self) -> None:
        # A mis-cased or misspelled label would never equal a normalised model
        # bucket and would silently score 0, so reject it up front.
        if self.bucket_true not in self._VALID_BUCKETS:
            raise ValueError(
                f"bucket_true must be one of {self._VALID_BUCKETS}, got {self.bucket_true!r}"
            )
        if self.ecs_min_true > self.ecs_max_true:
            raise ValueError(
                f"ecs_min_true ({self.ecs_min_true}) must not exceed "
                f"ecs_max_true ({self.ecs_max_true})"
            )


def build_item_bank() -> List[ECSItem]:
    """
    Return the fixed synthetic item bank (8 items, C01-C08) for ECS reasoning.

    These items are not real datasets; they are stylised descriptions used to
    probe consistency at the effective layer.

    NOTE(review): C04's title says "around 2.5–3.0°C" while its true range is
    2.5–3.5; C02/C05 are labelled HIGH with lower bounds of 4.0. Confirm these
    ground-truth values are intended.
    """
    # Columns: item_id, title, bucket_true, ecs_min_true, ecs_max_true
    table = [
        ("C01", "Historical warming with multi-line evidence (medium ECS)", "MEDIUM", 2.0, 4.0),
        ("C02", "Paleoclimate strong-response case (high ECS)", "HIGH", 4.0, 6.0),
        ("C03", "Energy-balance study with weak feedbacks (low ECS)", "LOW", 1.0, 2.0),
        ("C04", "Multi-source constraint narrowing around 2.5–3.0°C", "MEDIUM", 2.5, 3.5),
        ("C05", "High-end ensemble member emphasising strong positive feedbacks", "HIGH", 4.0, 5.5),
        ("C06", "Historical-only fit with cautious priors (medium ECS)", "MEDIUM", 1.8, 4.2),
        ("C07", "Short-term variability emphasised, but multiple lines of evidence", "MEDIUM", 2.0, 3.5),
        ("C08", "Hypothetical strong-stabilising feedback world (very low ECS)", "LOW", 0.5, 1.5),
    ]
    return [
        ECSItem(item_id=ident, title=label, bucket_true=bucket, ecs_min_true=lo, ecs_max_true=hi)
        for ident, label, bucket, lo, hi in table
    ]


def item_bank_dataframe(items: List[ECSItem]) -> pd.DataFrame:
    """Tabulate the item bank: one row per item, one column per ECSItem field."""
    columns = ("item_id", "title", "bucket_true", "ecs_min_true", "ecs_max_true")
    records = []
    for entry in items:
        records.append({name: getattr(entry, name) for name in columns})
    return pd.DataFrame(records)


# ---------------------------------------------------------------------
# 2. Helper functions for JSON parsing and scoring
# ---------------------------------------------------------------------

def extract_json_from_text(text: Optional[str]) -> Dict[str, Any]:
    """
    Extract the first decodable JSON object from a chat completion.

    The model is told to respond with pure JSON, but we defensively tolerate
    surrounding prose or Markdown fences. Decoding starts at each "{" in turn
    and stops at the end of that object, so trailing text (even text that
    contains braces) does not break parsing.

    Args:
        text: raw message content; ``None`` or "" is treated as "no JSON".

    Returns:
        The decoded JSON object as a dict.

    Raises:
        ValueError: if no JSON object can be decoded from ``text``.
    """
    if not text:
        raise ValueError("No JSON object found in model output.")

    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            # raw_decode parses one value starting at `start` and ignores the rest.
            obj, _end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            obj = None
        if isinstance(obj, dict):
            return obj
        # NOTE: if an outer object is malformed, a later "{" may yield a nested
        # object; callers should validate the keys they need.
        start = text.find("{", start + 1)

    raise ValueError("No JSON object found in model output.")


def normalise_bucket(raw: Optional[str]) -> Optional[str]:
    """
    Map a free-form bucket label to "LOW", "MEDIUM", "HIGH", or None.

    Matching is case-insensitive and by substring, checked in the order
    LOW, HIGH, MED. For example, "medium" maps to MEDIUM, while an ambiguous
    "LOW-MEDIUM" maps to LOW.

    Returns None for anything that is not a string (the value comes straight
    from model JSON and may be null, a number, or a list) or that contains
    none of the keywords.
    """
    if not isinstance(raw, str):
        return None
    s = raw.strip().upper()
    if "LOW" in s:
        return "LOW"
    if "HIGH" in s:
        return "HIGH"
    if "MED" in s:
        return "MEDIUM"
    return None


def compute_range_plausibility(
    ecs_low: Optional[float],
    ecs_high: Optional[float],
    global_min: float,
    global_max: float,
) -> float:
    """
    Return the fraction of [ecs_low, ecs_high] that lies inside [global_min, global_max].

    Returns 0.0 in any of these cases:
      * either bound is missing,
      * the interval is empty or inverted (ecs_high <= ecs_low),
      * the interval does not intersect the global window.

    Otherwise returns overlap / interval width, clamped to [0, 1].
    """
    if ecs_low is None or ecs_high is None or ecs_high <= ecs_low:
        return 0.0
    inside = min(ecs_high, global_max) - max(ecs_low, global_min)
    # `not >` also treats a NaN overlap as "no overlap".
    if not inside > 0.0:
        return 0.0
    fraction = inside / (ecs_high - ecs_low)
    return max(0.0, min(1.0, fraction))


def compute_bucket_correctness(
    bucket_true: str,
    bucket_estimate: Optional[str],
    bucket_from_explanation: Optional[str],
) -> float:
    """
    Return a score in {0.0, 0.5, 1.0}: 0.5 credit each for the stated bucket
    and for the bucket inferred from the explanation matching ``bucket_true``.
    """
    hits = [
        bucket
        for bucket in (bucket_estimate, bucket_from_explanation)
        if bucket is not None and bucket == bucket_true
    ]
    return 0.5 * len(hits)


def compute_self_consistency(
    bucket_estimate: Optional[str],
    bucket_from_explanation: Optional[str],
) -> float:
    """
    Return 1.0 when the stated bucket equals the bucket inferred from the
    explanation, else 0.0. A missing bucket on either side counts as 0.0.
    """
    if bucket_estimate is not None and bucket_from_explanation is not None:
        return float(bucket_estimate == bucket_from_explanation)
    return 0.0


def compute_sharpness(
    ecs_low: Optional[float],
    ecs_high: Optional[float],
    global_min: float,
    global_max: float,
) -> float:
    """
    Return how narrow [ecs_low, ecs_high] is relative to the global window.

    The result is 1 - width / window_width, floored at 0.0. The window width
    is guarded against zero with a 1e-6 minimum. Missing bounds or an
    empty/inverted interval give 0.0.
    """
    if ecs_low is None or ecs_high is None:
        return 0.0
    if ecs_high <= ecs_low:
        return 0.0
    span = max(1e-6, global_max - global_min)
    frac = min(1.0, (ecs_high - ecs_low) / span)
    return max(0.0, 1.0 - frac)


def combine_tension(
    range_plausibility: float,
    bucket_correctness: float,
    self_consistency: float,
    *,
    w_plaus: float = 0.4,
    w_bucket: float = 0.4,
    w_self: float = 0.2,
) -> float:
    """
    Combine three scores into a scalar tension T_ECS_range in [0, 1].

    Higher T means more tension (worse behavior). T is the weighted mean of
    (1 - score) over the three scores, clamped to [0, 1].

    Args:
        range_plausibility, bucket_correctness, self_consistency: scores where
            1.0 is best.
        w_plaus, w_bucket, w_self: keyword-only weights (defaults 0.4/0.4/0.2,
            matching the original fixed weighting).

    Raises:
        ValueError: if the weights sum to a non-positive value.
    """
    total = w_plaus + w_bucket + w_self
    if total <= 0.0:
        raise ValueError("Tension weights must sum to a positive value.")

    t = (
        w_plaus * (1.0 - range_plausibility)
        + w_bucket * (1.0 - bucket_correctness)
        + w_self * (1.0 - self_consistency)
    ) / total

    return max(0.0, min(1.0, t))


def is_effectively_coherent(
    range_plausibility: float,
    bucket_correctness: float,
    self_consistency: float,
    t_ecs_range: float,
) -> bool:
    """
    Return True when every component score is at least 0.7 and the combined
    tension is at most 0.4.
    """
    component_scores = (range_plausibility, bucket_correctness, self_consistency)
    if not all(score >= 0.7 for score in component_scores):
        return False
    return t_ecs_range <= 0.4


# ---------------------------------------------------------------------
# 3. Prompt builders and model calls
# ---------------------------------------------------------------------

def build_estimate_messages(item: "ECSItem") -> List[Dict[str, str]]:
    """
    Build the chat messages asking the model to estimate ECS for ``item``.

    Only ``item.item_id`` and ``item.title`` are shown to the model. The model
    is asked for a point estimate, a 66% interval, a LOW/MEDIUM/HIGH bucket and
    a short explanation, returned as pure JSON.

    The template is dedented *before* the item fields are substituted. If it
    were dedented after interpolation, a multi-line title with unindented
    lines would make ``textwrap.dedent`` find no common prefix and leave the
    whole prompt indented.
    """
    template = textwrap.dedent(
        """
        You are an assistant reasoning about equilibrium climate sensitivity (ECS).

        Item id: {item_id}
        Title: {title}

        Think of ECS as the long-run global mean surface temperature change (in °C)
        after a doubling of atmospheric CO₂.

        Step 1: Read the description and imagine a stylised evidence pattern.
        Step 2: Give a point estimate ECS_est in °C per CO₂ doubling.
        Step 3: Give a 66% confidence interval [ECS_low, ECS_high].
        Step 4: Assign a qualitative bucket:
            - "LOW" if ECS is roughly below 2°C,
            - "MEDIUM" if ECS is roughly between about 2–4.5°C,
            - "HIGH" if ECS is clearly above about 4.5°C.

        Respond ONLY with a JSON object, no Markdown, with keys:
            "ecs_estimate": float,
            "ecs_low": float,
            "ecs_high": float,
            "bucket_estimate": "LOW" | "MEDIUM" | "HIGH",
            "explanation": string

        The explanation should be 3–6 sentences in plain English.
        """
    ).strip()
    # The template has no other braces, so str.format only fills these two fields.
    description = template.format(item_id=item.item_id, title=item.title)

    return [
        {"role": "system", "content": "You are a careful climate reasoning assistant."},
        {"role": "user", "content": description},
    ]


def build_probe_messages(explanation: str) -> List[Dict[str, str]]:
    """
    Build the chat messages asking a checker to infer the bucket from
    ``explanation`` alone. The checker must reply with JSON containing
    ``bucket_from_explanation`` and ``confidence``.

    The template is dedented *before* the explanation is substituted. Model
    explanations can contain newlines; dedenting after interpolation would
    find no common indent and leave the whole prompt indented.
    """
    template = textwrap.dedent(
        """
        You are checking the internal consistency of a climate sensitivity explanation.

        Below is an explanation that discusses equilibrium climate sensitivity (ECS).
        Your job is to infer which qualitative ECS bucket it supports.

        Explanation:
        \"\"\"{explanation}\"\"\"

        Buckets:
            - "LOW"    : ECS clearly below about 2°C per CO₂ doubling.
            - "MEDIUM" : ECS in the broad 2–4.5°C range.
            - "HIGH"   : ECS clearly above about 4.5°C.

        Respond ONLY with a JSON object, no Markdown, with keys:
            "bucket_from_explanation": "LOW" | "MEDIUM" | "HIGH",
            "confidence": float between 0 and 1

        The confidence is how sure you are that the explanation points to that bucket.
        """
    ).strip()
    # str.format parses only the template, so braces inside the explanation are safe.
    instruction = template.format(explanation=explanation)

    return [
        {"role": "system", "content": "You are a strict consistency checker."},
        {"role": "user", "content": instruction},
    ]


def call_chat_json(
    client: "OpenAI",
    model: str,
    messages: List[Dict[str, str]],
    temperature: float = 0.2,
) -> Dict[str, Any]:
    """
    Send ``messages`` via ``client.chat.completions.create`` and parse the
    first choice's content as a JSON object.

    Raises:
        ValueError: if the response has no choices, the first message has no
            text content, or no JSON object can be extracted.
        Any exception raised by the client call itself propagates unchanged.
    """
    response = client.chat.completions.create(
        model=model,
        messages=messages,
        temperature=temperature,
    )
    if not response.choices:
        raise ValueError("Chat completion returned no choices.")
    content = response.choices[0].message.content
    if content is None:
        # Report a clear error instead of failing inside the JSON extractor.
        raise ValueError("Chat completion returned no text content.")
    return extract_json_from_text(content)


# ---------------------------------------------------------------------
# 4. Main experiment loop
# ---------------------------------------------------------------------

def _required_float(payload: Dict[str, Any], key: str) -> float:
    """Return ``payload[key]`` as a float; raise ValueError if missing, boolean, or non-numeric."""
    value = payload.get(key)
    if value is None or isinstance(value, bool):
        raise ValueError(f"Missing or non-numeric '{key}' in model output: {value!r}")
    try:
        return float(value)
    except (TypeError, ValueError) as err:
        raise ValueError(f"Missing or non-numeric '{key}' in model output: {value!r}") from err


def _evaluate_item(client: "OpenAI", model_name: str, item: ECSItem) -> Dict[str, Any]:
    """
    Run the estimate and probe calls for one item and return its result row.

    Any exception while calling the model, parsing its JSON, or scoring is
    reported. The scores are then reset to their pessimistic defaults (all
    0.0, T_ECS_range = 1.0, not coherent), so one bad item never aborts the
    run. Raw values parsed before the failure are kept in the row for
    diagnosis.
    """
    ecs_estimate: Optional[float] = None
    ecs_low: Optional[float] = None
    ecs_high: Optional[float] = None
    bucket_estimate: Optional[str] = None
    bucket_from_explanation: Optional[str] = None
    range_plausibility = 0.0
    bucket_correctness = 0.0
    self_consistency = 0.0
    sharpness = 0.0
    t_ecs_range = 1.0
    coherent_flag = False

    try:
        # Step 1: estimate ECS and explanation
        est_json = call_chat_json(
            client, model_name, build_estimate_messages(item), temperature=0.2
        )
        ecs_estimate = _required_float(est_json, "ecs_estimate")
        ecs_low = _required_float(est_json, "ecs_low")
        ecs_high = _required_float(est_json, "ecs_high")
        bucket_estimate = normalise_bucket(est_json.get("bucket_estimate"))
        explanation = str(est_json.get("explanation", "")).strip()

        # Step 2: probe explanation only (the checker never sees the stated bucket)
        probe_json = call_chat_json(
            client, model_name, build_probe_messages(explanation), temperature=0.0
        )
        bucket_from_explanation = normalise_bucket(
            probe_json.get("bucket_from_explanation")
        )

        # Step 3: compute scores
        range_plausibility = compute_range_plausibility(
            ecs_low, ecs_high, ECS_GLOBAL_MIN, ECS_GLOBAL_MAX
        )
        bucket_correctness = compute_bucket_correctness(
            item.bucket_true, bucket_estimate, bucket_from_explanation
        )
        self_consistency = compute_self_consistency(
            bucket_estimate, bucket_from_explanation
        )
        sharpness = compute_sharpness(
            ecs_low, ecs_high, ECS_GLOBAL_MIN, ECS_GLOBAL_MAX
        )
        t_ecs_range = combine_tension(
            range_plausibility, bucket_correctness, self_consistency
        )
        coherent_flag = is_effectively_coherent(
            range_plausibility,
            bucket_correctness,
            self_consistency,
            t_ecs_range,
        )

    except Exception as e:  # per-item boundary: report, then fall back to defaults
        print(f"[error] OpenAI call or parsing failed for item {item.item_id}: {e}")
        print("[info] Marking this item as maximal tension (T_ECS_range = 1.0).")
        range_plausibility = 0.0
        bucket_correctness = 0.0
        self_consistency = 0.0
        sharpness = 0.0
        t_ecs_range = 1.0
        coherent_flag = False

    return {
        "item_id": item.item_id,
        "title": item.title,
        "bucket_true": item.bucket_true,
        "ecs_estimate": ecs_estimate,
        "ecs_low": ecs_low,
        "ecs_high": ecs_high,
        "bucket_estimate": bucket_estimate,
        "bucket_from_explanation": bucket_from_explanation,
        "range_plausibility": range_plausibility,
        "bucket_correctness": bucket_correctness,
        "self_consistency": self_consistency,
        "sharpness": sharpness,
        "T_ECS_range": t_ecs_range,
        "is_effective_coherent": coherent_flag,
    }


def _print_summary(df: pd.DataFrame) -> None:
    """Print the per-item table, overall T_ECS_range statistics, and per-bucket tension."""
    print("\n=== Summary table (one row per item) ===")
    print(df)

    # Overall statistics
    print("\n=== Overall statistics ===")
    mean_t = float(df["T_ECS_range"].mean())
    median_t = float(df["T_ECS_range"].median())
    coherent_rate = float(df["is_effective_coherent"].mean())
    print(f"mean T_ECS_range   : {mean_t:5.3f}")
    print(f"median T_ECS_range : {median_t:5.3f}")
    print(f"coherent item rate : {coherent_rate:5.3f}")

    # Mean tension by bucket
    print("\n=== Mean tension by ground-truth bucket ===")
    df_bucket = (
        df.groupby("bucket_true")["T_ECS_range"]
        .agg(["mean", "median", "count"])
        .reset_index()
        .sort_values("bucket_true")
    )
    print(df_bucket)


def _plot_tension(df: pd.DataFrame) -> None:
    """Show two bar charts: T_ECS_range per item, and mean T_ECS_range per ground-truth bucket."""
    plt.figure(figsize=(10, 4))
    plt.bar(df["item_id"], df["T_ECS_range"])
    plt.ylim(0.0, 1.0)
    plt.xlabel("Item id")
    plt.ylabel("T_ECS_range")
    plt.title("TU Q091-A: T_ECS_range per item")
    plt.tight_layout()
    plt.show()

    plt.figure(figsize=(6, 4))
    df_bucket_mean = (
        df.groupby("bucket_true")["T_ECS_range"]
        .mean()
        .reset_index()
        .sort_values("bucket_true")
    )
    plt.bar(df_bucket_mean["bucket_true"], df_bucket_mean["T_ECS_range"])
    plt.ylim(0.0, 1.0)
    plt.xlabel("Ground-truth bucket")
    plt.ylabel("Mean T_ECS_range")
    plt.title("TU Q091-A: mean T_ECS_range by bucket")
    plt.tight_layout()
    plt.show()


def run_q091_a(client: Optional["OpenAI"], model_name: str = "gpt-4o-mini") -> None:
    """
    Run the TU Q091-A experiment.

    Always prints the synthetic item bank. With ``client=None`` it stops there
    (offline mode). Otherwise it does the following:
      * evaluates every item with two chat calls on ``model_name``,
      * prints per-item scores and summary statistics,
      * shows tension plots.

    Args:
        client: an OpenAI client, or None for offline mode.
        model_name: chat model used for both the estimate and the probe call.
    """
    print("\n=== TU Q091-A · Equilibrium climate sensitivity range reasoning (MVP) ===\n")

    items = build_item_bank()
    df_items = item_bank_dataframe(items)

    print("=== TU Q091-A: Synthetic item bank (equilibrium climate sensitivity) ===")
    print(df_items)
    print(
        "\nNote: These items are synthetic and only used to test consistency of reasoning "
        "at the effective layer.\n"
    )

    if client is None:
        print(
            "[info] No OpenAI client available. "
            "Experiment will not run live calls.\n"
            "You can inspect this cell, the item bank, and the README without any API usage.\n"
        )
        return

    print(f"[run] Starting TU Q091-A experiment with live OpenAI calls on model '{model_name}' ...\n")

    results: List[Dict[str, Any]] = []
    for item in items:
        print(f"--- Running item {item.item_id} · {item.title} ---")
        row = _evaluate_item(client, model_name, item)
        results.append(row)
        print(
            f"    range_plausibility = {row['range_plausibility']:.2f} | "
            f"bucket_correctness = {row['bucket_correctness']:.2f} | "
            f"self_consistency = {row['self_consistency']:.2f} | "
            f"T_ECS_range = {row['T_ECS_range']:.2f} | "
            f"coherent = {row['is_effective_coherent']}"
        )
        print()

    df = pd.DataFrame(results)
    _print_summary(df)
    _plot_tension(df)

    print(
        "\n[done] TU Q091-A run completed. "
        "You can screenshot the tables and plots, or compare with future runs."
    )


# ---------------------------------------------------------------------
# 6. Entry point
# ---------------------------------------------------------------------

if __name__ == "__main__":
    # Prompt for live/offline mode (and a hidden API key if live); None = offline.
    client_obj = build_openai_client()
    run_q091_a(client=client_obj)
