
# T3 Annotation Instructions (From Assignment PDF)

## Step 1: Identify the Pearl Level
- **L1 (Association):** Observational relationships only (P(Y | X)).
- **L2 (Intervention):** Explicit or implicit interventions (do(X)).
- **L3 (Counterfactual):** Hypothetical alternatives (“What would have happened if X had not occurred?”).

## Step 2: Decide the Label
- **YES:** Claim is supported as stated.
- **NO:** Claim is invalid due to a causal trap.
- **AMBIGUOUS:** Insufficient information to evaluate the claim.

## Step 3: Assign Trap Type (NO cases only)
Exactly **one** trap type must be assigned, using this strict order:
1. Confounding
2. Reverse Causation
3. Selection Bias / Collider Bias
4. Simpson’s Paradox
5. Regression to the Mean
6. Goodhart’s Law
7. Feedback Loops
8. Preemption (L3 only)
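
The strict order above can be read as a tie-breaking rule: when more than one trap plausibly applies, the earliest entry in the list wins. The sketch below works under that reading. The short identifiers and the `pick_trap` helper are illustrative assumptions, not names from the assignment PDF.

```python
# Priority order copied from the numbered list above.
# The short identifiers are illustrative, not taken from the PDF.
TRAP_PRIORITY = [
    "CONFOUNDING",
    "REVERSE",
    "SELECTION",   # stands for "Selection Bias / Collider Bias"
    "SIMPSONS",
    "REGRESSION",
    "GOODHART",
    "FEEDBACK",
    "PREEMPTION",  # L3 only
]


def pick_trap(candidates, pearl_level):
    """Return the highest-priority applicable trap, or None if nothing applies."""
    for trap in TRAP_PRIORITY:
        if trap == "PREEMPTION" and pearl_level != "L3":
            continue  # Preemption is reserved for counterfactual (L3) cases
        if trap in candidates:
            return trap
    return None


print(pick_trap({"FEEDBACK", "REVERSE"}, "L2"))  # REVERSE
print(pick_trap({"PREEMPTION"}, "L2"))           # None
```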

## Subtype Rules
- Subtypes are optional.
- Use only when clearly applicable.
- If unsure, leave subtype empty.

## Global Constraints
- One instance → one Pearl level → one trap type (if NO).
- Prefer the explanation requiring the **minimal causal graph**.
- AMBIGUOUS cases must have `trap = NONE`.
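
The label and trap rules above are mechanical enough to check automatically. The following is a minimal sketch of such a check; `check_annotation` and its messages are hypothetical, and it assumes the trap is passed as a short code with `"NONE"` meaning no trap.

```python
VALID_LEVELS = {"L1", "L2", "L3"}
VALID_LABELS = {"YES", "NO", "AMBIGUOUS"}


def check_annotation(pearl_level, label, trap):
    """Return a list of rule violations (empty if the annotation is consistent)."""
    problems = []
    if pearl_level not in VALID_LEVELS:
        problems.append(f"unknown Pearl level: {pearl_level!r}")
    if label not in VALID_LABELS:
        problems.append(f"unknown label: {label!r}")
    # Step 3 assigns a trap type to NO cases only.
    if label in {"YES", "AMBIGUOUS"} and trap != "NONE":
        problems.append(f"{label} cases must have trap = NONE")
    if label == "NO" and trap == "NONE":
        problems.append("NO cases need exactly one trap type")
    if trap == "PREEMPTION" and pearl_level != "L3":
        problems.append("Preemption is only valid at L3")
    return problems


print(check_annotation("L2", "NO", "CONFOUNDING"))         # []
print(check_annotation("L1", "AMBIGUOUS", "CONFOUNDING"))  # one violation
```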


In [None]:

def convert_to_final_format(case_data: dict, pearl_level: str, trap_type: str, case_id: str) -> dict:
    PEARL_LEVEL_NAME = {
        "L1": "Association",
        "L2": "Intervention",
        "L3": "Counterfactual",
    }

    TRAP_TYPE_MAP = {
        "NONE": ("NONE", "None"),
        "CONFOUNDING": ("CONF", "Confounding"),
        "REVERSE": ("REVERSE", "Reverse Causation"),
        "SELECTION": ("SELECTION", "Selection Bias"),
        "COLLIDER": ("COLLIDER", "Collider Bias"),
        "SIMPSONS": ("SIMPSONS", "Simpson’s Paradox"),
        "REGRESSION": ("REGRESSION", "Regression to the Mean"),
        "SURVIVORSHIP": ("SURVIVORSHIP", "Survivorship Bias"),
        "GOODHART": ("GOODHART", "Goodhart’s Law"),
        "BASE_RATE": ("BASE_RATE", "Base-rate Neglect"),
        "FEEDBACK": ("FEEDBACK", "Feedback Loops"),
        "PREEMPTION": ("PREEMPTION", "Preemption"),
        "CONFOUNDER_MEDIATOR_ERROR": ("CONF-MED", "Confounder–Mediator Error"),
    }

    trap_code, trap_name = TRAP_TYPE_MAP.get(
        trap_type, (trap_type, trap_type.replace("_", " ").title())
    )

    label = str(case_data.get("label", "NO")).upper().strip()
    if label not in {"YES", "NO", "AMBIGUOUS"}:
        label = "NO"

    LABEL_NAME = {"YES": "SUPPORTED", "NO": "FLAWED", "AMBIGUOUS": "UNCERTAIN"}
    is_ambiguous = (label == "AMBIGUOUS") or bool(case_data.get("is_ambiguous", False))

    vars_in = case_data.get("variables", {}) or {}
    x_name = vars_in.get("X", "") if isinstance(vars_in, dict) else ""
    y_name = vars_in.get("Y", "") if isinstance(vars_in, dict) else ""
    z_name = vars_in.get("Z", []) if isinstance(vars_in, dict) else []

    Z_list = []
    if isinstance(z_name, list):
        for zn in z_name:
            if str(zn).strip():
                Z_list.append({"name": str(zn).strip(), "role": "common_cause"})
    else:
        if str(z_name).strip():
            Z_list.append({"name": str(z_name).strip(), "role": "common_cause"})

    final_case = {
        "id": f"T3-BucketLarge-E-{case_id}",
        "bucket": "BucketLarge-E",
        "case_id": case_id,

        "pearl_level": pearl_level,
        "pearl_level_name": PEARL_LEVEL_NAME.get(pearl_level, ""),

        "domain_id": "D1",
        "domain_name": "Daily Life",

        "scenario": str(case_data.get("scenario", "") or ""),
        "claim": str(case_data.get("claim", "") or ""),

        "variables": {
            "X": {"name": str(x_name or ""), "role": "exposure"},
            "Y": {"name": str(y_name or ""), "role": "outcome"},
            "Z": Z_list,
        },

        "trap": {
            "type": trap_code,
            "type_name": trap_name,
            "subtype": str(case_data.get("trap_subtype", "") or ""),
            "subtype_name": "",
        },

        "difficulty": str(case_data.get("difficulty", "Medium") or "Medium"),

        "subdomain": str(case_data.get("subdomain", "") or ""),
        "causal_structure": str(case_data.get("causal_structure", "") or ""),
        "key_insight": str(case_data.get("key_insight", "") or ""),

        "hidden_timestamp": {"question": "", "answer": ""},

        "conditional_answers": case_data.get("conditional_answers", []) or [
            {"condition": "", "answer": "", "rationale": ""}
        ],

        "wise_refusal": str(case_data.get("wise_refusal", "") or ""),

        "label": label,
        "label_name": LABEL_NAME.get(label, ""),
        "is_ambiguous": bool(is_ambiguous),

        "hidden_structure": {
            "dag_edges": [["X", "Y"]],
            "notes": "",
        },

        "gold_rationale": str(case_data.get("gold_rationale", "") or ""),

        "source": {
            "origin": "generated",
            "generator": "llm_draft_human_verified",
            "seed_case_ref": "",
        },

        "annotation": {
            "num_annotators": 2,
            "agreement": "ai_generated",
            "adjudicated": False,
        },
    }
    return final_case


# T³ Benchmark Case Generator

This notebook generates 230 new causal reasoning cases for your T³ benchmark assignment.

**Distribution:**
- L1 (Association): ~25 cases (11%)
- L2 (Intervention): ~147 cases (64%)
- L3 (Counterfactual): ~58 cases (25%)

**Estimated Time:** 6-8 hours  
**Cost:** ~$2-3 (Claude Sonnet 4 API)
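
As a quick check on the level split, the sketch below applies the 11% and 64% shares to a total of 230 and gives the remainder to L3, mirroring the arithmetic used later in `generate_all_cases`. The helper name `level_split` is an assumption made for illustration.

```python
def level_split(total, l1_share=0.11, l2_share=0.64):
    """Split a case budget across Pearl levels; L3 receives the remainder."""
    l1 = int(total * l1_share)  # truncates, as the generator does
    l2 = int(total * l2_share)
    l3 = total - l1 - l2
    return {"L1": l1, "L2": l2, "L3": l3}


print(level_split(230))  # {'L1': 25, 'L2': 147, 'L3': 58}
```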

## Setup

In [None]:
# Install required packages
# !pip install anthropic tqdm -q


In [None]:
import json
import anthropic
import os
import random
from typing import Dict, List
from tqdm.notebook import tqdm
import time
from IPython.display import display, JSON, Markdown

## Configuration

⚠️ **Set your API credentials below** (the generation cells call `gpt-4.1` through the Azure OpenAI client):

In [None]:
from openai import AzureOpenAI
from pydantic import BaseModel, Field
from dotenv import load_dotenv

load_dotenv(dotenv_path='/content/genie-worksheets/.env')

client = AzureOpenAI(
    api_version="",
    azure_endpoint="",
    api_key="",
)

response = client.chat.completions.create(
    messages=[
        {
            "role": "system",
            "content": "You are a helpful assistant.",
        },
        {
            "role": "user",
            "content": "Say 'hi.'",
        }
    ],
    model="gpt-4.1"
)

print(response.choices[0].message.content)

Hi.


In [None]:
# SET YOUR API KEY HERE
API_KEY = ""  # Replace with your actual key

# Or use environment variable
if API_KEY == "":
    API_KEY = os.environ.get("ANTHROPIC_API_KEY", "")

if not API_KEY:
    print("⚠️ WARNING: API key not set!")
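else:
    print("✓ API key configured")
    # Use a separate name so the Azure OpenAI `client` used by
    # generate_single_case (client.chat.completions) is not overwritten.
    anthropic_client = anthropic.Anthropic(api_key=API_KEY)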



## Helper Functions

In [None]:

import re
from typing import Dict, List, Tuple

# ----------------------------
# Trap types (depend on Pearl level) — from T3 cheat sheet (Table 3/5)
# ----------------------------
TRAP_TYPES_BY_PEARL: Dict[str, List[str]] = {
    "L1": [
        "CONFOUNDING",
        "REVERSE",
        "SELECTION",
        "COLLIDER",
        "SIMPSONS",
        "REGRESSION",
        "SURVIVORSHIP",
        "BASE_RATE",
        "GOODHART",
    ],
    "L2": [
        "CONFOUNDING",
        "REVERSE",
        "SELECTION",
        "COLLIDER",
        "CONFOUNDER_MEDIATOR_ERROR",
        "SIMPSONS",
        "GOODHART",
        "FEEDBACK",
    ],
    "L3": [
        "PREEMPTION",
        "CONFOUNDING",
        "REVERSE",
        "CONFOUNDER_MEDIATOR_ERROR",
        "FEEDBACK",
        "SELECTION",
    ],
}

# ----------------------------
# Trap guides (short reminders used in prompts)
# ----------------------------
TRAP_GUIDES: Dict[str, str] = {
    "NONE": "No causal trap applies (YES/AMBIGUOUS).",
    "CONFOUNDING": "A common cause Z influences both X and Y, creating a misleading association/effect.",
    "REVERSE": "The outcome Y (or its causes) influences X (Y→X), so associational evidence is misread as X→Y.",
    "SELECTION": "Inference is distorted because we condition on a selected/filtered subset (who is observed depends on variables).",
    "COLLIDER": "Conditioning on a common effect of X and Y (X→Z←Y) induces spurious association.",
    "SIMPSONS": "Aggregated trends reverse within subgroups; the overall association differs from stratified associations.",
    "REGRESSION": "Extreme observations naturally revert toward the mean due to noise/variation, not a causal effect.",
    "SURVIVORSHIP": "Failures/non-survivors are missing; only those that remain observed bias the conclusion.",
    "BASE_RATE": "Ignoring priors/base rates or confusing conditional probabilities (P(A|B) vs P(B|A)) invalidates the conclusion.",
    "GOODHART": "Optimizing a proxy metric breaks its correlation with the true target; the metric stops measuring what matters.",
    "FEEDBACK": "Bidirectional/adaptive causation: actions change outcomes which in turn change future actions (a loop).",
    "CONFOUNDER_MEDIATOR_ERROR": "Incorrectly adjusting/fixing a mediator or post-treatment variable breaks causal interpretation.",
    "PREEMPTION": "In counterfactuals, an alternative cause would have produced the outcome anyway, undermining the stated cause.",
}

# ----------------------------
# Optional subtypes (use ONLY when clearly applicable)
# ----------------------------
TRAP_SUBTYPES: Dict[Tuple[str, str], List[str]] = {
    # L1
    ("L1", "CONFOUNDING"): ["Confounding_by_Indication", "Omitted_Variable", "Socioeconomic"],
    ("L1", "REVERSE"): ["Outcome-driven_Selection", "Policy_Endogeneity"],
    ("L1", "SELECTION"): ["Sampling-on-the-Outcome", "Attrition_Bias", "Case-Control_Sampling"],
    ("L1", "COLLIDER"): ["Conditioning_on_Participation"],
    ("L1", "SIMPSONS"): ["Aggregation_Bias", "Imbalanced_Group_Composition"],
    ("L1", "REGRESSION"): ["Extreme-Group_Selection", "Noise-Induced_Extremes"],
    ("L1", "SURVIVORSHIP"): ["Selective_Observation", "Historical_Filtering"],
    ("L1", "BASE_RATE"): ["Prior_Ignorance", "Conditional_Fallacy"],
    ("L1", "GOODHART"): ["Static_Metric_Gaming", "Proxy_Drift"],

    # L2
    ("L2", "CONFOUNDING"): ["Unblocked_Backdoor", "Time-varying_Confounding"],
    ("L2", "REVERSE"): ["Reactive_Intervention"],
    ("L2", "SELECTION"): ["Post-intervention_Selection"],
    ("L2", "COLLIDER"): ["Conditioning_on_Compliance"],
    ("L2", "CONFOUNDER_MEDIATOR_ERROR"): ["Mediator_Adjustment_Error"],
    ("L2", "SIMPSONS"): ["Stratified_Intervention_Reversal"],
    ("L2", "GOODHART"): ["Policy_Target_Gaming"],
    ("L2", "FEEDBACK"): ["Policy–Response_Loop"],

    # L3
    ("L3", "PREEMPTION"): ["Early_Preemption", "Late_Preemption"],
    ("L3", "CONFOUNDING"): ["Cross-world_Confounder"],
    ("L3", "REVERSE"): ["Outcome-dependent_Worlds"],
    ("L3", "CONFOUNDER_MEDIATOR_ERROR"): ["Mediator_Fixing_Error"],
    ("L3", "FEEDBACK"): ["Dynamic_World_Divergence"],
    ("L3", "SELECTION"): ["Counterfactual_Conditioning"],
}

# ----------------------------
# Canonicalization (prevents KeyError from alias names)
# ----------------------------
_TRAP_ALIASES: Dict[str, str] = {
    # regression
    "REGRESSION TO MEAN": "REGRESSION",
    "REGRESSION TO THE MEAN": "REGRESSION",
    "REGRESSION_TO_MEAN": "REGRESSION",
    "REGRESSION_TO_THE_MEAN": "REGRESSION",

    # base rate
    "BASE RATE NEGLECT": "BASE_RATE",
    "BASE-RATE NEGLECT": "BASE_RATE",
    "BASE_RATE_NEGLECT": "BASE_RATE",

    # confounder–mediator
    "CONF-MED": "CONFOUNDER_MEDIATOR_ERROR",
    "CONF MED": "CONFOUNDER_MEDIATOR_ERROR",
    "CONF_MED": "CONFOUNDER_MEDIATOR_ERROR",
    "CONFOUND_MEDIATOR_ERROR": "CONFOUNDER_MEDIATOR_ERROR",
    "CONF-MED ERROR": "CONFOUNDER_MEDIATOR_ERROR",

    # simpsons
    "SIMPSON'S": "SIMPSONS",
    "SIMPSONS PARADOX": "SIMPSONS",

    # misc
    "REVERSE CAUSATION": "REVERSE",
    "SELECTION BIAS": "SELECTION",
    "COLLIDER BIAS": "COLLIDER",
    "SURVIVORSHIP BIAS": "SURVIVORSHIP",
    "GOODHART'S LAW": "GOODHART",
    "GOODHARTS LAW": "GOODHART",

    # people sometimes put "COUNTERFACTUAL" as a trap by mistake
    "COUNTERFACTUAL": "PREEMPTION",
    "NO TRAP": "NONE",
    "NONE": "NONE",
}

def canonical_trap_type(trap_type: str) -> str:
    t = (trap_type or "").strip().upper()
    t = t.replace("’", "'").replace("–", "-").replace("—", "-")
    t = re.sub(r"\s+", " ", t)
    return _TRAP_ALIASES.get(t, t)


In [None]:

import json
from typing import List, Optional

_NO_STYLE_EXEMPLARS = [
    {
        "statistical_structure": (
            "The Statistical Structure. The extreme initial outcome naturally moves toward average on repeat measurement."
        ),
        "correct_reasoning": (
            "Correct Reasoning. The apparent change is driven by regression to the mean (or selection on extremes), not the claimed cause."
        ),
        "wise_refusal": (
            "Wise Refusal. \"This pattern can arise from regression to the mean: extreme cases often improve even without the intervention.\""
        ),
    },
    {
        "statistical_structure": (
            "The Statistical Structure. Raw counts can rise because exposure/denominator is larger, even if the rate is unchanged."
        ),
        "correct_reasoning": (
            "Correct Reasoning. Compare rates (per unit exposure), not totals; otherwise you may mistake base-rate differences for a causal effect."
        ),
        "wise_refusal": (
            "Wise Refusal. \"Check the base rate (e.g., per mile/per user). A higher count may just mean more exposure, not higher risk.\""
        ),
    },
    {
        "statistical_structure": (
            "The Statistical Structure. The dataset includes only a selected subset, so the observed relationship is distorted by who is included."
        ),
        "correct_reasoning": (
            "Correct Reasoning. Selection (or collider conditioning) can create or flip associations; you need the missing cases to evaluate the claim."
        ),
        "wise_refusal": (
            "Wise Refusal. \"This could be selection bias: if inclusion depends on variables related to both X and Y, the observed link can be misleading.\""
        ),
    },
]

def generate_case_prompt(
    pearl_level: str,
    trap_type: str,
    case_number: int,
    *,
    trap_guide: str,
    allowed_subtypes: Optional[List[str]] = None,
    desired_label: Optional[str] = None,
) -> str:
    """
    Strict prompt:
    - Domain fixed to Daily Life
    - Forces JSON-only output with stable keys
    - Enforces Pearl-level semantics
    - Enforces trap_type exactly (and optional subtype from allowed list)
    - Enforces desired label distribution (YES/NO/AMBIGUOUS)
    - Adds NO-style exemplars (Statistical Structure / Correct Reasoning / Wise Refusal)
    """
    pearl_level = (pearl_level or "").strip().upper()
    trap_type = (trap_type or "").strip().upper()
    allowed_subtypes = allowed_subtypes or []
    desired_label = (desired_label or "").strip().upper() if desired_label else None

    # For YES/AMBIGUOUS, there is no trap type (trap = NONE)
    effective_trap_type = trap_type
    if desired_label in {"YES", "AMBIGUOUS"}:
        effective_trap_type = "NONE"

    pearl_instructions = {
        "L1": (
            "Pearl level is L1 (Association). Use ONLY observational/correlational language. "
            "Do NOT use intervention language ('if we do', 'will cause') or counterfactual phrasing."
        ),
        "L2": (
            "Pearl level is L2 (Intervention). The claim MUST be an intervention/causal effect claim (do(X)). "
            "Use action language like 'if we do/assign/increase X'."
        ),
        "L3": (
            "Pearl level is L3 (Counterfactual). The claim MUST be explicitly counterfactual, comparing the actual world "
            "to a hypothetical world ('Had X not occurred...', 'If X had been different...')."
        ),
    }[pearl_level]

    output_schema = {
        "domain": "Daily Life",
        "scenario": "2–4 sentences describing the real-world situation/data",
        "claim": "1 sentence claim consistent with the Pearl level",
        "variables": {
            "X": "exposure or action",
            "Y": "outcome",
            "Z": ["optional list of confounders/selection variables (strings)"]
        },
        "label": "YES | NO | AMBIGUOUS",
        "trap_type": effective_trap_type,
        "trap_subtype": "optional (empty string if none)",
        "gold_rationale": "2–4 sentences explaining the causal reasoning; if NO, explicitly reference the trap mechanism",
        "wise_refusal": "1–3 sentences in plain language explaining why the claim is flawed/uncertain (or why it holds), in the style of the exemplars",
        "difficulty": "Easy | Medium | Hard"
    }

    label_constraint = ""
    if desired_label in {"YES", "NO", "AMBIGUOUS"}:
        label_constraint = f'- label MUST be exactly "{desired_label}".\n'
        if desired_label == "YES":
            label_constraint += (
                "- The scenario MUST directly support the claim as stated (no missing info).\n"
                "- Do NOT describe any causal trap.\n"
            )
        elif desired_label == "AMBIGUOUS":
            label_constraint += (
                "- The scenario MUST be missing a critical piece of information so the claim cannot be verified.\n"
                "- Do NOT describe any causal trap.\n"
            )
        else:  # NO
            label_constraint += (
                "- The scenario MUST make the claim invalid due to the specified trap mechanism.\n"
            )

    # PDF label definitions (must follow)
    label_definitions = """LABEL DEFINITIONS (MUST FOLLOW EXACTLY):

- YES:
  The claim is supported AS STATED by the scenario under the given Pearl level.
  The scenario provides sufficient information, and no causal or statistical
  assumption is violated. Do NOT invoke any causal trap.

- NO:
  The claim is INVALID AS STATED due to a violated causal or statistical assumption.
  The error MUST be explained by the specified causal trap (trap_type).
  There must be exactly one causal failure mode.

- AMBIGUOUS:
  The claim cannot be definitively evaluated given the available information.
  Critical information is missing (e.g., timing, controls, comparison group,
  intervention clarity). Do NOT invoke a causal trap.
""".strip()

    # Add NO-style exemplars only when desired_label == NO
    exemplar_block = ""
    if desired_label == "NO":
        ex = _NO_STYLE_EXEMPLARS
        exemplar_block = (
            "NO-CASE STYLE EXEMPLARS (match this structure/tone):\n"
            "- Use headings in your rationale/writing style, e.g., 'The Statistical Structure', 'Correct Reasoning', 'Wise Refusal'.\n"
            "Examples:\n"
            + "\n".join(
                [
                    f"* {e['statistical_structure']}\n  {e['correct_reasoning']}\n  {e['wise_refusal']}"
                    for e in ex
                ]
            )
        )

    return f"""
You are generating ONE dataset instance.

HARD CONSTRAINTS:
- Output MUST be a single valid JSON object and NOTHING ELSE (no markdown, no commentary).
- Use double quotes for all keys and strings.
- Do NOT include trailing commas.
- The JSON MUST include exactly these top-level keys:
  domain, scenario, claim, variables, label, trap_type, trap_subtype, gold_rationale, wise_refusal, difficulty

DOMAIN (FIXED):
- domain MUST be exactly: "Daily Life"

PEARL LEVEL:
- pearl_level = "{pearl_level}"
- {pearl_instructions}

LABEL CONSTRAINT:
{label_constraint.strip() if label_constraint else "- label can be YES, NO, or AMBIGUOUS."}

{label_definitions}

TRAP REQUIREMENT:
- If label is YES or AMBIGUOUS: trap_type MUST be "NONE" (no trap applies).
- If label is NO: trap_type MUST be exactly: "{effective_trap_type}"
- trap guide: {trap_guide}
- Allowed subtypes (optional): {json.dumps(allowed_subtypes if effective_trap_type!="NONE" else [])}
Rules:
- If label = "NO": the reasoning error MUST be explained by the specified causal trap.
- If label = "YES" or "AMBIGUOUS": set trap_subtype = "".

VARIABLE RULES:
- variables.Z MUST be a JSON array (use [] if none).
- X and Y should be short, concrete phrases.

QUALITY:
- Use realistic, non-sensitive daily-life scenarios (home, work, school, habits, tech use, commuting, etc.).

{exemplar_block}

OUTPUT JSON ONLY matching this template:
{json.dumps(output_schema, ensure_ascii=False, indent=2)}
""".strip()


In [None]:
def convert_to_final_format_old(case_data: Dict, pearl_level: str, trap_type: str) -> Dict:
    """Convert generated case to final JSON format."""

    # Determine variable role for Z (best-effort; mostly for downstream visualization)
    if trap_type == "COLLIDER":
        z_role = "collider"
    elif trap_type == "CONFOUNDER_MEDIATOR_ERROR":
        z_role = "mediator_or_post_treatment"
    else:
        z_role = "common_cause_or_context"

    # Label: these generated cases are intended to be flawed (NO)
    label = "NO"

    # Build a simple illustrative DAG (best-effort)
    dag_edges: List[List[str]] = []
    if trap_type == "CONFOUNDING":
        dag_edges = [["Z", "X"], ["Z", "Y"]]
    elif trap_type == "REVERSE":
        dag_edges = [["Y", "X"]]
    elif trap_type == "SELECTION":
        # Selection is really about conditioning on inclusion; approximate with Z affecting both
        dag_edges = [["Z", "X"], ["Z", "Y"]]
    elif trap_type == "COLLIDER":
        dag_edges = [["X", "Z"], ["Y", "Z"]]
    elif trap_type == "SIMPSONS":
        dag_edges = [["Z", "X"], ["Z", "Y"], ["X", "Y"]]
    elif trap_type == "REGRESSION":
        dag_edges = [["X", "Y"]]
    elif trap_type == "SURVIVORSHIP":
        dag_edges = [["X", "Z"], ["Y", "Z"]]
    elif trap_type == "GOODHART":
        dag_edges = [["X", "Y"]]
    elif trap_type == "BASE_RATE":
        dag_edges = [["X", "Y"]]
    elif trap_type == "FEEDBACK":
        dag_edges = [["X", "Y"], ["Y", "X"]]
    elif trap_type == "CONFOUNDER_MEDIATOR_ERROR":
        dag_edges = [["X", "Z"], ["Z", "Y"]]
    elif trap_type == "PREEMPTION":
        dag_edges = [["X", "Y"], ["Z", "Y"]]
    else:
        dag_edges = [["X", "Y"]]

    final_case = {
        "id": case_data.get("case_id", f"1.{case_number_from_id(case_data.get('case_id', '1.0'))}"),
        "pearl_level": pearl_level,
        "domain": case_data.get("subdomain", "General"),
        "scenario": case_data.get("scenario", ""),
        "claim": case_data.get("claim", ""),
        "label": label,
        "is_ambiguous": False,
        "trap": {
            "type": trap_type,
            "type_name": trap_type.replace("_", " ").title(),
            "subtype": case_data.get("trap_subtype", ""),
            "subtype_name": case_data.get("trap_subtype", "").replace("_", " ").title()
        },
        "variables": {
            "X": case_data.get("variables", {}).get("X", {}).get("description", ""),
            "Y": case_data.get("variables", {}).get("Y", {}).get("description", ""),
            "Z": [case_data.get("variables", {}).get("Z", {}).get("description", "")]
        },
        "gold_rationale": case_data.get("wise_refusal", ""),
        "metadata": {
            "title": case_data.get("title", ""),
            "difficulty": case_data.get("difficulty", "Easy"),
            "z_role": z_role,
            "dag_edges": dag_edges,
            "key_insight": case_data.get("key_insight", ""),
            "causal_structure": case_data.get("causal_structure", ""),
        },
        "source": {
            "origin": "generated",
            "generator": "t3_case_generator.ipynb",
            "generation_date": time.strftime("%Y-%m-%d")
        },
        "annotation": {
            "num_annotators": 1,
            "agreement": "ai_generated",
            "adjudicated": False
        }
    }

    return final_case


def case_number_from_id(case_id: str) -> int:
    """Extract numeric suffix from an id like '1.23'."""
    try:
        return int(str(case_id).split(".")[-1])
    except Exception:
        return 0


In [None]:

import json
import time

def generate_single_case(
    pearl_level: str,
    trap_type: str,
    case_number: int,
    max_retries: int = 4,
    desired_label: str | None = None,
) -> dict:
    """Generate a single case; returns schema-shaped dict. If desired_label is set, enforce it."""
    pearl_level = (pearl_level or "").strip().upper()
    trap_type = canonical_trap_type(trap_type)
    desired_label = (desired_label or "").strip().upper() if desired_label else None

    # For YES/AMBIGUOUS, there is no trap type
    if desired_label in {"YES", "AMBIGUOUS"}:
        trap_type = "NONE"

    allowed = TRAP_TYPES_BY_PEARL.get(pearl_level, [])
    if trap_type != "NONE" and trap_type not in allowed:
        return {
            "id": f"T3-BucketLarge-E-1.{case_number}",
            "case_id": f"1.{case_number}",
            "error": f"trap_type '{trap_type}' not allowed for pearl_level '{pearl_level}'. Allowed: {allowed}"
        }

    case_id = f"1.{case_number}"
    trap_guide = TRAP_GUIDES[trap_type]
    allowed_subtypes = TRAP_SUBTYPES.get((pearl_level, trap_type), [])

    def _call_model(feedback: str | None = None) -> dict:
        prompt = generate_case_prompt(
            pearl_level=pearl_level,
            trap_type=trap_type,
            case_number=case_number,
            trap_guide=trap_guide,
            allowed_subtypes=allowed_subtypes,
            desired_label=desired_label,
        )
        if feedback:
            prompt = prompt + "\n\nCORRECTION:\n" + feedback

        response = client.chat.completions.create(
            model="gpt-4.1",
            messages=[
                {"role": "system", "content": "Output valid JSON only. No markdown. No extra text."},
                {"role": "user", "content": prompt},
            ],
            max_tokens=1200,
            temperature=0.35,
        )
        text = response.choices[0].message.content or ""
        start = text.find("{")
        end = text.rfind("}")
        if start == -1 or end == -1 or end <= start:
            raise json.JSONDecodeError("No JSON object found", text, 0)
        obj = json.loads(text[start:end+1])
        return obj

    feedback = None
    for attempt in range(max_retries):
        try:
            obj = _call_model(feedback)

            # enforce fixed domain and trap_type
            obj["domain"] = "Daily Life"
            obj["trap_type"] = trap_type

            lab = str(obj.get("label", "")).strip().upper()
            if lab not in {"YES", "NO", "AMBIGUOUS"}:
                lab = "NO"
                obj["label"] = "NO"

            if lab in {"YES", "AMBIGUOUS"}:
                obj["trap_type"] = "NONE"
                obj["trap_subtype"] = ""

            # enforce desired label if requested
            if desired_label in {"YES", "NO", "AMBIGUOUS"} and lab != desired_label:
                feedback = f'Your previous JSON had label="{lab}". Re-generate with label EXACTLY "{desired_label}". Keep trap_type "{trap_type}".'
                time.sleep(0.6)
                continue

            # Pass the post-validation trap type so YES/AMBIGUOUS cases are stored as NONE.
            return convert_to_final_format(obj, pearl_level, obj["trap_type"], case_id)

        except json.JSONDecodeError:
            feedback = "Your previous output was not valid JSON. Output ONLY one valid JSON object."
            time.sleep(0.8)
            continue
        except Exception as e:
            feedback = f"Exception occurred: {type(e).__name__}: {e}. Output valid JSON only."
            time.sleep(0.8)
            continue

    return {
        "id": f"T3-BucketLarge-E-{case_id}",
        "case_id": case_id,
        "error": f"Failed after retries (desired_label={desired_label})",
    }
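
The reply parsing inside `_call_model` relies on a simple technique: take everything from the first `{` to the last `}` and hand it to `json.loads`. Below is a standalone sketch of that step. `extract_json_object` is a hypothetical helper name and the sample reply is made up.

```python
import json


def extract_json_object(text):
    """Parse the outermost {...} span in text, ignoring surrounding chatter."""
    start = text.find("{")
    end = text.rfind("}")
    if start == -1 or end <= start:
        raise ValueError("no JSON object found")
    # json.JSONDecodeError (a ValueError subclass) propagates if the span is malformed.
    return json.loads(text[start:end + 1])


# Made-up model reply with extra text around the JSON payload.
reply = 'Here is the case:\n{"label": "NO", "variables": {"Z": []}}\nDone.'
print(extract_json_object(reply))  # {'label': 'NO', 'variables': {'Z': []}}
```

This approach tolerates a leading or trailing sentence, but a stray brace in the surrounding text would still break parsing, which is why the retry loop treats a decode failure as a cue to re-prompt.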


## Part 1: Test Generation (10 Cases)

**Run this first** to verify everything works before generating all 230 cases.

In [None]:

from typing import List, Dict
from tqdm import tqdm
import time

def generate_test_cases(num_cases: int = 10) -> List[Dict]:
    """
    Generate a small sanity-check set with explicit YES/NO/AMBIGUOUS labels.
    IMPORTANT: YES and AMBIGUOUS must have trap_type = NONE.
    """

    print(f"Generating {num_cases} test cases...\n")

    # (Pearl level, trap, desired_label)
    case_specs = [
        ("L1", "REGRESSION", "NO"),
        ("L1", "BASE_RATE", "NO"),
        ("L2", "CONFOUNDER_MEDIATOR_ERROR", "NO"),
        ("L2", "SELECTION", "NO"),

        # YES cases -> trap must be NONE
        ("L1", "NONE", "YES"),
        ("L2", "NONE", "YES"),

        # AMBIGUOUS cases -> trap must be NONE
        ("L1", "NONE", "AMBIGUOUS"),
        ("L2", "NONE", "AMBIGUOUS"),

        # L3 mix
        ("L3", "PREEMPTION", "NO"),
        ("L3", "NONE", "YES"),
    ]

    cases = []
    for i, (level, trap, desired_label) in enumerate(
        tqdm(case_specs[:num_cases], desc="Generating test cases")
    ):
        case = generate_single_case(
            level,
            trap,
            1000 + i,
            desired_label=desired_label
        )

        if case and "error" not in case:
            cases.append(case)
            cid = case.get("case_id") or case.get("id") or f"1.{1000+i}"
            print(f"✓ {cid} [{level} | {desired_label} | trap={case.get('trap',{}).get('type','')}]")
        else:
            err = case.get("error", "") if isinstance(case, dict) else ""
            print(f"✗ Failed to generate case {1000 + i}: {err}")

        time.sleep(0.7)

    return cases


In [None]:
# RUN THIS: Generate 10 test cases
test_cases = generate_test_cases(10)

print(f"\n{'='*60}")
print(f"Generated {len(test_cases)} test cases")
print(f"{'='*60}")

Generating 10 test cases...



Generating test cases:   0%|          | 0/10 [00:00<?, ?it/s]

✓ 1.1000: Star Student Slump [L1 - REGRESSION TO MEAN]
✓ 1.1001: Lucky Lottery Retailer [L1 - BASE RATE NEGLECT]
✓ 1.1002: Coffee Cup Perks [L2 - CONF-MED]
✓ 1.1003: Weekend Tutor Trap [L2 - SELECTION]
✓ 1.1004: Nightlight Anxiety Trap [L2 - REVERSE]
✓ 1.1005: Parking Permit Paradox [L2 - COLLIDER]
✓ 1.1006: Streaming & Sleepless Nights [L2 - CONF-MED]
✓ 1.1007: Free Coffee Promotion Effect [L3 - COUNTERFACTUAL]
✓ 1.1008: Missed Train, Lost Job [L3 - COUNTERFACTUAL]
✓ 1.1009: Frequent Library Visits [L2 - SELECTION]

Generated 10 test cases


In [None]:
# View a sample case
if test_cases:
    print("Sample Case:")
    print("=" * 60)
    sample = test_cases[0]
    print(f"Title: {sample.get('title')}")
    print(f"Level: {sample.get('pearl_level_name')}")
    print(f"Trap: {sample.get('trap', {}).get('type')}")
    print(f"\nScenario: {sample.get('scenario')}")
    print(f"\nWise Refusal: {sample.get('wise_refusal')}")
    print("\nFull JSON:")
    display(JSON(sample))

Sample Case:
Title: Star Student Slump
Level: Association
Trap: REGRESSION TO MEAN

Scenario: In the first semester, Jamie earns nearly perfect grades, much higher than most students. When Jamie later receives lower grades in the second semester, classmates say the tougher subjects caused Jamie to suddenly lose their 'gift.'

Wise Refusal: A wise AI would point out that Jamie's grades may be regressing to the mean and the drop is expected after an exceptionally strong performance. It's incorrect to assume that only new negative factors are responsible for the change without considering normal variation.

Full JSON:


<IPython.core.display.JSON object>

In [None]:
# Save test cases
test_output = "t3_test_cases.json"
with open(test_output, 'w') as f:
    json.dump({
        "metadata": {
            "total_cases": len(test_cases),
            "generation_date": time.strftime("%Y-%m-%d %H:%M:%S"),
            "model": "claude-sonnet-4-20250514"
        },
        "cases": test_cases
    }, f, indent=2)

print(f"✓ Saved to: {test_output}")

## Part 2: Full Generation (230 Cases)

⚠️ **This will take 6-8 hours and cost ~$2-3**

Only run this after reviewing the test cases above!

In [None]:
def generate_all_cases(target_total: int = 230) -> List[Dict]:
    """Generate all cases maintaining distribution"""

    # Calculate distribution
    l1_count = int(target_total * 0.11)  # 11%
    l2_count = int(target_total * 0.64)  # 64%
    l3_count = target_total - l1_count - l2_count

    print(f"Generating {target_total} cases:")
    print(f"  L1 (Association): {l1_count}")
    print(f"  L2 (Intervention): {l2_count}")
    print(f"  L3 (Counterfactual): {l3_count}")
    print()

    all_cases = []
    case_counter = 50  # Start numbering at 50, leaving a gap after the existing 45 cases

    # Generate L1 cases
    print("\n=== Generating L1 (Association) Cases ===")
    l1_trap_types = TRAP_TYPES["L1"]
    for i in tqdm(range(l1_count), desc="L1 Cases"):
        trap_type = random.choice(l1_trap_types)
        case = generate_single_case("L1", trap_type, case_counter)
        if "error" not in case:
            all_cases.append(case)
        case_counter += 1
        time.sleep(0.5)

    # Generate L2 cases
    print("\n=== Generating L2 (Intervention) Cases ===")
    l2_trap_types = TRAP_TYPES["L2"]
    for i in tqdm(range(l2_count), desc="L2 Cases"):
        trap_type = random.choice(l2_trap_types)
        case = generate_single_case("L2", trap_type, case_counter)
        if "error" not in case:
            all_cases.append(case)
        case_counter += 1
        time.sleep(0.5)

    # Generate L3 cases
    print("\n=== Generating L3 (Counterfactual) Cases ===")
    for i in tqdm(range(l3_count), desc="L3 Cases"):
        case = generate_single_case("L3", "COUNTERFACTUAL", case_counter)
        if "error" not in case:
            all_cases.append(case)
        case_counter += 1
        time.sleep(0.5)

    return all_cases
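
For the default target of 230, the truncating arithmetic in `generate_all_cases` requests 25 L1, 147 L2, and 58 L3 cases. Cases that come back with an `"error"` key are skipped, so the final totals can be lower. A minimal standalone sketch of the same split (the helper name `split_by_level` is hypothetical and not part of the notebook):

```python
from typing import Dict


def split_by_level(target_total: int) -> Dict[str, int]:
    """Mirror the 11% / 64% / remainder split used in generate_all_cases."""
    l1 = int(target_total * 0.11)   # truncates toward zero
    l2 = int(target_total * 0.64)   # truncates toward zero
    l3 = target_total - l1 - l2     # L3 absorbs whatever is left
    return {"L1": l1, "L2": l2, "L3": l3}


split = split_by_level(230)
print(split)
```

Because L3 takes the remainder, the three counts always add up to `target_total`, even though the L1 and L2 shares are rounded down.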

In [None]:
# RUN THIS: Generate all 230 cases (takes 6-8 hours!)
# Comment out if you don't want to run yet

# all_cases = generate_all_cases(target_total=230)

# print(f"\n{'='*60}")
# print(f"✓ Successfully generated {len(all_cases)} cases")
# print(f"{'='*60}")

In [None]:
# Save all cases
# Uncomment when you've generated all cases

# output_file = "t3_generated_cases_full.json"
# with open(output_file, 'w') as f:
#     json.dump({
#         "metadata": {
#             "total_cases": len(all_cases),
#             "generation_date": time.strftime("%Y-%m-%d %H:%M:%S"),
#             "model": "claude-sonnet-4-20250514",
#             "bucket": "BucketLarge-E",
#             "domain": "Daily Life & Psychology"
#         },
#         "cases": all_cases
#     }, f, indent=2)

# # Print statistics
# l1_count = sum(1 for c in all_cases if c.get("pearl_level") == "L1")
# l2_count = sum(1 for c in all_cases if c.get("pearl_level") == "L2")
# l3_count = sum(1 for c in all_cases if c.get("pearl_level") == "L3")

# print(f"\nFinal Distribution:")
# print(f"  L1: {l1_count} ({l1_count/len(all_cases)*100:.1f}%)")
# print(f"  L2: {l2_count} ({l2_count/len(all_cases)*100:.1f}%)")
# print(f"  L3: {l3_count} ({l3_count/len(all_cases)*100:.1f}%)")
# print(f"\n✓ Saved to: {output_file}")

## Part 3: Batch Generation with Checkpoints (RECOMMENDED)

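This part defines checkpoint helpers that save progress to `t3_checkpoint.json`, intended for generating cases in batches of 50 and saving after each batch. The `batch_generate_all` driver in this part does not call those helpers; it writes each case to a JSONL file as it is generated.  
**Safer than generating all 230 at once!**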
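
The batch-and-checkpoint idea can be sketched as a loop that saves the full case list after every batch and resumes from the saved file on the next run. This is an illustration only: `run_with_checkpoints` and `fake_case` are hypothetical names, the stand-in generator makes no API call, and the sketch assumes every call succeeds.

```python
import json
import os
import tempfile
from typing import Callable, Dict, List


def run_with_checkpoints(
    make_case: Callable[[int], Dict],
    target_total: int,
    checkpoint_path: str,
    batch_size: int = 50,
) -> List[Dict]:
    """Generate cases in batches, saving the full list after each batch."""
    # Resume from an earlier run if a checkpoint file already exists
    cases: List[Dict] = []
    if os.path.exists(checkpoint_path):
        with open(checkpoint_path, "r", encoding="utf-8") as f:
            cases = json.load(f).get("cases", [])

    while len(cases) < target_total:
        n = min(batch_size, target_total - len(cases))
        start = len(cases) + 1
        cases.extend(make_case(start + i) for i in range(n))
        # Persist progress so an interruption loses at most one batch
        with open(checkpoint_path, "w", encoding="utf-8") as f:
            json.dump({"cases": cases}, f, indent=2)
    return cases


def fake_case(number: int) -> Dict:
    # Stand-in for the real generator (assumption, for illustration only)
    return {"case_id": f"1.{number}", "title": f"Placeholder {number}"}


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "checkpoint.json")
    first = run_with_checkpoints(fake_case, 120, path)    # simulates an interrupted run
    resumed = run_with_checkpoints(fake_case, 230, path)  # picks up where it left off
    print(len(first), len(resumed))
```

A real driver would also need to decide what to do with failed calls, for example retrying them or capping the number of attempts, so the loop cannot spin forever.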

In [None]:
# Checkpoint file
CHECKPOINT_FILE = "t3_checkpoint.json"
BATCH_SIZE = 50

In [None]:
def load_checkpoint() -> List[Dict]:
    """Load previously generated cases"""
    if os.path.exists(CHECKPOINT_FILE):
        with open(CHECKPOINT_FILE, 'r') as f:
            data = json.load(f)
            return data.get("cases", [])
    return []

def save_checkpoint(cases: List[Dict]):
    """Save progress"""
    with open(CHECKPOINT_FILE, 'w') as f:
        json.dump({
            "metadata": {
                "total_cases": len(cases),
                "last_updated": time.strftime("%Y-%m-%d %H:%M:%S")
            },
            "cases": cases
        }, f, indent=2)
    print(f"✓ Checkpoint saved: {len(cases)} total cases")

In [None]:
def generate_batch(pearl_level: str, trap_types: List[str], start_number: int, count: int) -> List[Dict]:
    """Generate a batch of cases"""

    batch = []
    for i in range(count):
        trap_type = random.choice(trap_types)
        case_number = start_number + i

        case = generate_single_case(pearl_level, trap_type, case_number)
        if case and "error" not in case:
            batch.append(case)
            print(f"  ✓ {case['case_id']}: {case.get('title', 'Untitled')}")

        time.sleep(0.7)

    return batch

In [None]:
# Check existing progress
existing_cases = load_checkpoint()
print(f"Found {len(existing_cases)} existing cases in checkpoint")

In [None]:

from tqdm import tqdm
import random
import time

def batch_generate_all(
    target_total: int = 230,
    out_path: str = "t3_bucketlarge_e.jsonl",
    seed: int = 13,
) -> List[Dict]:
    """
    Generate target_total examples with YES/NO/AMBIGUOUS in a 1:1:1 ratio.
    Writes one JSON object per line to out_path (JSONL) and returns the valid cases.
    """
    random.seed(seed)

    # Build desired label schedule: equal counts as close as possible
    base = target_total // 3
    rem = target_total % 3
    labels = (["YES"] * base) + (["NO"] * base) + (["AMBIGUOUS"] * base)
    # Distribute the remainder (0, 1, or 2 extra labels)
    if rem >= 1:
        labels.append("YES")
    if rem >= 2:
        labels.append("NO")
    random.shuffle(labels)

    # Pearl level is drawn uniformly at random for each case
    pearl_levels = ["L1", "L2", "L3"]

    all_cases = []
    with open(out_path, "w", encoding="utf-8") as f:
        for idx, desired_label in enumerate(tqdm(labels, desc="Generating cases")):
            case_number = idx + 1

            level = random.choice(pearl_levels)
            # Only NO cases carry a trap; YES and AMBIGUOUS use NONE
            if desired_label in {"YES", "AMBIGUOUS"}:
                trap = "NONE"
            else:
                trap = random.choice(TRAP_TYPES_BY_PEARL[level])

            case = generate_single_case(level, trap, case_number, desired_label=desired_label)

            if isinstance(case, dict) and "error" not in case:
                all_cases.append(case)
                f.write(json.dumps(case, ensure_ascii=False) + "\n")
            else:
                # Error rows are written too, for debugging; remove this branch to skip them
                err = case.get("error") if isinstance(case, dict) else "Unknown error"
                f.write(json.dumps({"case_id": f"1.{case_number}", "error": err}, ensure_ascii=False) + "\n")

            time.sleep(0.4)

    print(f"\nWrote {len(all_cases)} valid cases to {out_path}")
    # Quick label-ratio report
    counts = {"YES": 0, "NO": 0, "AMBIGUOUS": 0}
    for c in all_cases:
        lab = str(c.get("label", "")).upper()
        if lab in counts:
            counts[lab] += 1
    print("Label counts:", counts)

    # Return the valid cases so the cells that follow can use them
    return all_cases
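
With `target_total=230`, the label schedule built in `batch_generate_all` holds 77 YES, 77 NO, and 76 AMBIGUOUS entries: 230 = 3 × 76 + 2, and the two leftover slots go to YES and NO. The sketch below reproduces only that step. The function name `build_label_schedule` is hypothetical, and it shuffles with a local `random.Random(seed)` instead of the global seed, so the order may differ from the notebook's while the counts stay the same.

```python
import random
from collections import Counter
from typing import List


def build_label_schedule(target_total: int, seed: int = 13) -> List[str]:
    """Return a shuffled list of labels as close to 1:1:1 as possible."""
    base, rem = divmod(target_total, 3)
    labels = ["YES"] * base + ["NO"] * base + ["AMBIGUOUS"] * base
    # Hand out the 0, 1, or 2 leftover slots
    if rem >= 1:
        labels.append("YES")
    if rem >= 2:
        labels.append("NO")
    random.Random(seed).shuffle(labels)  # local RNG, leaves global state alone
    return labels


schedule = build_label_schedule(230)
print(Counter(schedule))
```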


In [None]:
# RUN THIS: Generate all cases (RECOMMENDED METHOD)
# Each case is written to the JSONL output as it is generated

final_cases = batch_generate_all(target_total=230)

print(f"\n{'='*60}")
print(f"✓ COMPLETE: {len(final_cases)} total cases")
print(f"{'='*60}")

In [None]:
# Save final output
final_output = "t3_generated_cases_final.json"
with open(final_output, 'w') as f:
    json.dump({
        "metadata": {
            "total_cases": len(final_cases),
            "generation_date": time.strftime("%Y-%m-%d %H:%M:%S"),
            "model": "claude-sonnet-4-20250514",
            "bucket": "BucketLarge-E",
            "domain": "Daily Life & Psychology"
        },
        "cases": final_cases
    }, f, indent=2)

# Print final statistics
l1_final = sum(1 for c in final_cases if c.get("pearl_level") == "L1")
l2_final = sum(1 for c in final_cases if c.get("pearl_level") == "L2")
l3_final = sum(1 for c in final_cases if c.get("pearl_level") == "L3")

total_final = len(final_cases) or 1  # avoid dividing by zero if no cases were generated
print("\nFinal Distribution:")
print(f"  L1: {l1_final} ({l1_final/total_final*100:.1f}%)")
print(f"  L2: {l2_final} ({l2_final/total_final*100:.1f}%)")
print(f"  L3: {l3_final} ({l3_final/total_final*100:.1f}%)")
print(f"\n✓ Saved to: {final_output}")

## Analysis & Quality Check

In [None]:
# Load generated cases for analysis
# Change filename as needed
with open('t3_test_cases.json', 'r') as f:
    data = json.load(f)
    cases_to_analyze = data['cases']

print(f"Analyzing {len(cases_to_analyze)} cases...")
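
`batch_generate_all` writes JSONL, one object per line with error rows mixed in, so the `json.load` call used here for `t3_test_cases.json` will not read that file directly. A minimal loader sketch for the JSONL output (the function name `load_jsonl_cases` and the demo rows are hypothetical):

```python
import json
import os
import tempfile
from typing import Dict, List


def load_jsonl_cases(path: str) -> List[Dict]:
    """Read a JSONL file and keep only rows that have no "error" key."""
    cases = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # tolerate blank lines
            row = json.loads(line)
            if "error" not in row:
                cases.append(row)
    return cases


# Small demo using a throwaway file with one valid row and one error row
with tempfile.TemporaryDirectory() as tmp:
    demo_path = os.path.join(tmp, "demo.jsonl")
    with open(demo_path, "w", encoding="utf-8") as f:
        f.write(json.dumps({"case_id": "1.1", "label": "YES"}) + "\n")
        f.write(json.dumps({"case_id": "1.2", "error": "timeout"}) + "\n")
    demo_cases = load_jsonl_cases(demo_path)
    print(demo_cases)
```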

In [None]:
# Distribution analysis
from collections import Counter

pearl_levels = Counter(c.get('pearl_level') for c in cases_to_analyze)
trap_types = Counter(c.get('trap', {}).get('type') for c in cases_to_analyze)
difficulties = Counter(c.get('difficulty') for c in cases_to_analyze)
subdomains = Counter(c.get('subdomain') for c in cases_to_analyze)

print("Pearl Level Distribution:")
for level, count in pearl_levels.most_common():
    print(f"  {level}: {count} ({count/len(cases_to_analyze)*100:.1f}%)")

print("\nTrap Type Distribution:")
for trap, count in trap_types.most_common():
    print(f"  {trap}: {count}")

print("\nDifficulty Distribution:")
for diff, count in difficulties.most_common():
    print(f"  {diff}: {count}")

print("\nTop Subdomains:")
for subdomain, count in subdomains.most_common(10):
    print(f"  {subdomain}: {count}")
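
The annotation instructions assign trap types to NO cases only and require `trap = NONE` for AMBIGUOUS cases, so a label/trap consistency check fits alongside the distribution counts. The sketch below assumes each case has a top-level `label` and a `trap` dict with a `type` key, as read elsewhere in this notebook. The function names are hypothetical.

```python
from typing import Dict, List


def trap_type_of(case: Dict) -> str:
    """Return the trap type in upper case, treating a missing trap as NONE."""
    trap = case.get("trap")
    if isinstance(trap, dict):
        trap = trap.get("type")
    return str(trap or "NONE").upper()


def check_label_trap(cases: List[Dict]) -> List[str]:
    """List cases whose label and trap type disagree with the annotation rules."""
    problems = []
    for c in cases:
        cid = c.get("case_id", "?")
        label = str(c.get("label", "")).upper()
        trap = trap_type_of(c)
        if label not in {"YES", "NO", "AMBIGUOUS"}:
            problems.append(f"{cid}: unknown label {label!r}")
        elif label == "NO" and trap == "NONE":
            problems.append(f"{cid}: NO case without a trap type")
        elif label != "NO" and trap != "NONE":
            problems.append(f"{cid}: {label} case has trap {trap}")
    return problems


demo = [
    {"case_id": "1.1", "label": "NO", "trap": {"type": "REVERSE"}},
    {"case_id": "1.2", "label": "AMBIGUOUS", "trap": {"type": "SELECTION"}},
    {"case_id": "1.3", "label": "NO", "trap": {"type": "NONE"}},
]
issues = check_label_trap(demo)
for issue in issues:
    print(issue)
```

Running `check_label_trap(cases_to_analyze)` on real output gives a list of cases to review by hand before the quality report.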

In [None]:
# Sample cases from each level
print("Sample L1 Case:")
l1_cases = [c for c in cases_to_analyze if c.get('pearl_level') == 'L1']
if l1_cases:
    sample = l1_cases[0]
    print(f"Title: {sample.get('title')}")
    print(f"Scenario: {sample.get('scenario')}")
    print(f"Trap: {sample.get('trap', {}).get('type')}")
    print(f"Wise Refusal: {sample.get('wise_refusal')}\n")

print("\nSample L2 Case:")
l2_cases = [c for c in cases_to_analyze if c.get('pearl_level') == 'L2']
if l2_cases:
    sample = l2_cases[0]
    print(f"Title: {sample.get('title')}")
    print(f"Scenario: {sample.get('scenario')}")
    print(f"Trap: {sample.get('trap', {}).get('type')}")
    print(f"Wise Refusal: {sample.get('wise_refusal')}\n")

print("\nSample L3 Case:")
l3_cases = [c for c in cases_to_analyze if c.get('pearl_level') == 'L3']
if l3_cases:
    sample = l3_cases[0]
    print(f"Title: {sample.get('title')}")
    print(f"Scenario: {sample.get('scenario')}")
    print(f"Wise Refusal: {sample.get('wise_refusal')}")

## Summary

You now have:
1. ✅ Test cases (10) to verify quality
2. ✅ Full generation capability (230 cases)
3. ✅ Batch generation with checkpoints (safest method)
4. ✅ Analysis tools

**Next Steps:**
1. Review test cases for quality
2. Run batch generation (recommended)
3. Analyze distribution
4. Edit any low-quality cases
5. Write your quality analysis report
6. Submit!