# Edge Classifier Snippet Method – Ensemble / 3-Judge Classifier

**Purpose**

* Classify each `(:Case)-[:CITES_TO]->(:Case)` edge as **Positive / Neutral / Negative / Unknown** using:
  * **Pre-extracted snippets** stored on the relation (`snippet_1 … snippet_N`)
  * Short **case summaries** for citing and cited cases
* Use **three independent LLM “judges”** (Mistral, Llama, Claude) and combine them with a **majority-vote ensemble**.
* Write **per-model** and **global** results back to Neo4j and (optionally) export CSVs for analysis / QA.

---

## What this notebook uses

* **Snippets from the relation**  
  * `snippet_1, snippet_2, …` on each `[:CITES_TO]` relation (created by the Snippet Retriever).
* **Case metadata** (for both source = citing, target = cited):
  * `name`, `decision_date`, `citation_pipe`, optional `opinion_summary`
* **Short citation**  
  * First element from target’s `citation_pipe` (used to anchor the LLMs).
* **Models via AWS Bedrock**:
  * **Mistral 7B Instruct** (`mistral.mistral-7b-instruct-v0:2`)
  * **Llama 3 70B Instruct** (`meta.llama3-70b-instruct-v1:0`)
  * **Claude 3.5 Sonnet** (`anthropic.claude-3-5-sonnet-20240620-v1:0`)

---

## How it works (high level)

1. **Page through edges**  
   * Query `Q_PAGE_REL` to get `(:Case)-[r:CITES_TO]->(:Case)` in batches by `id(r)`.

2. **Gather snippets**  
   * Read `properties(r)` and pull all keys matching `snippet_\d+`.
   * Sort by number and build a labeled block:

     ```text
     snippet 1:

     [text]

     snippet 2:

     [text]
     ```

3. **Build the prompt (shared across models)**  
   Each judge sees the same input:
   * Citing **case name** + **summary**
   * Cited **case name** + **summary**
   * **Cited short citation** (first from `citation_pipe`)
   * **Snippets block**

4. **Call the three judges independently**

   * `classify_with_bedrock_mistral(...)`
   * `classify_with_bedrock_llama(...)`
   * `classify_with_bedrock_claude(...)`

   Each judge returns (or fails to return):

   ```json
   {
     "classification": "Positive|Neutral|Negative",
     "rationale": "Four sentences with one direct quote."
   }
   ```

   If a judge fails (too long / JSON error / empty output, etc.), that model’s label is set to **`"Unknown"`** with a failure message in its per-model rationale.

5. **Ensemble / majority vote**

   * Take the three model labels and compute a **final global label**.
   * Combine the selected judges’ rationales into a single **global rationale**.

6. **Write back to Neo4j**

   For each edge, set:

   * Per-model:

     * `r.mistral_treatment_label`
     * `r.mistral_treatment_rationale`
     * `r.llama_treatment_label`
     * `r.llama_treatment_rationale`
     * `r.claude_treatment_label`
     * `r.claude_treatment_rationale`
   * Global (ensemble):

     * `r.treatment_label`          (final label)
     * `r.treatment_rationale`      (explanation + selected judges’ rationales)
     * `r.treatment_snippet`        (joined snippets)
     * `r.model_used`               (list of model display names)
     * `r.updated_at_utc`           (Neo4j `datetime()`)

7. **Already-labeled edges**

   * If `force=False` and `r.treatment_label` exists and is not `"Unknown"`, the edge is **skipped**.

8. **No snippets**

   * If the relation has **no `snippet_1..N`**, all judges are treated as `Unknown`.
   * The global label is set to **`"Unknown"`** with a rationale explaining that no snippets were available.
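
Steps 7 and 8 amount to two guard checks that run before any judge is called. A minimal sketch of those checks, assuming hypothetical helper names (`should_skip`, `no_snippet_result`) that are not part of the notebook; the dictionary keys mirror the parameter names used by `Q_WRITE_REL_ANNOT`:

```python
from typing import Dict, List, Optional


def should_skip(existing_label: Optional[str], force: bool) -> bool:
    """Skip edges that already carry a real (non-Unknown) label, unless force is set."""
    if force:
        return False
    return bool(existing_label) and existing_label != "Unknown"


def no_snippet_result(snippets: List[str]) -> Optional[Dict[str, str]]:
    """Return an all-Unknown result when there is nothing to classify, else None."""
    if any((s or "").strip() for s in snippets):
        return None
    reason = "No snippets were available on the relation (snippet_1..N)."
    return {
        "mistral_label": "Unknown",
        "llama_label": "Unknown",
        "claude_label": "Unknown",
        "label": "Unknown",
        "rationale": reason,
    }
```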

---

## Ensemble decision logic

Let **M**, **L**, **C** be the three model labels.

1. **Strict majority** (any label appears 2 or 3 times)

   * The global label is that majority label. This also holds when the majority label is `"Unknown"` (see rule 3).

2. **All three different** (no label appears twice)

   * If **any** judge is `"Unknown"` → global label = **`"Unknown"`**.
   * If **none** are `"Unknown"` (one `Positive`, one `Neutral`, one `Negative`) → global label = **`"Neutral"`** by rule.

3. **Unknown-dominated** (rules 1 and 2 restated from the point of view of failed judges)

   * If at least **two** judges return `"Unknown"` → global label = **`"Unknown"`**.
   * If exactly one judge is `"Unknown"` and the other two **disagree** → global label = **`"Unknown"`**.
   * If exactly one judge is `"Unknown"` and the other two **agree** → rule 1 applies and their shared label wins.

4. **Global rationale content**

   * Starts with a short explanation of **why** the final label was chosen (majority vs tie-break rule).
   * Then includes **one or more judges’ rationales**:

     * If global label is **not** `"Unknown"` and not the all-different case:
       → include only the judges that voted for the final label.
     * If global label is `"Unknown"` or all three disagree:
       → include **all three** judges’ rationales (so you can inspect their views).
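
The rules above can be sketched as a single function. This is an illustration of the decision table, not the notebook's own implementation: `ensemble_vote` and the wording of the reason strings are hypothetical.

```python
from collections import Counter
from typing import Dict, List, Tuple

UNKNOWN = "Unknown"


def ensemble_vote(labels: Dict[str, str]) -> Tuple[str, str, List[str]]:
    """Return (global_label, reason, judges_whose_rationales_to_include).

    `labels` maps a judge name to Positive / Neutral / Negative / Unknown.
    """
    judges = list(labels)
    counts = Counter(labels.values())
    top_label, top_count = counts.most_common(1)[0]

    # A label held by 2 or 3 judges wins; a majority of Unknown stays Unknown.
    if top_count >= 2:
        if top_label == UNKNOWN:
            return UNKNOWN, "At least two judges returned Unknown.", judges
        voters = [j for j in judges if labels[j] == top_label]
        return top_label, f"Majority vote ({top_count} of {len(judges)}).", voters

    # From here on no label repeats: all three judges differ.
    if UNKNOWN in counts:
        return UNKNOWN, "One judge failed and the other two disagree.", judges
    return "Neutral", "All three judges disagree; Neutral by rule.", judges
```

The third element of the result follows the rationale rule: only the majority voters when a real label wins, and all three judges when the result is `"Unknown"` or the all-different tie-break.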

---

## Token budgeting & overflow handling

Each model has its own token budgeting step:

* Uses a shared tokenizer to estimate **system + user** tokens.
* If the prompt exceeds the context limit:

  * Compute the overhead of system + empty user prompt.
  * **Trim only the snippet block** down to the remaining allowance.
  * Rebuild the user prompt and recheck.
* If the prompt is **still too long** after trimming:

  * That model returns status **`"too_long"`** and is treated as `"Unknown"` for that edge.
  * The edge can still be labeled globally if the other two judges succeed.

Summaries and metadata are preserved; only snippets are shortened.
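
The trimming procedure described in this section can be sketched as follows. This is an illustration rather than the notebook's code: `fit_prompt`, `ws_count`, and `ws_trim` are hypothetical names, and whitespace splitting stands in for the real tokenizer so that the example has no dependencies.

```python
from typing import Callable, Optional, Tuple


def fit_prompt(
    system: str,
    build_user: Callable[[str], str],  # renders the user prompt for a snippet block
    snippet_block: str,
    max_input_tokens: int,
    count: Callable[[str], int],
    trim: Callable[[str, int], str],
) -> Tuple[Optional[str], str]:
    """Return (user_prompt, status) where status is "ok" or "too_long"."""
    user = build_user(snippet_block)
    if count(system) + count(user) <= max_input_tokens:
        return user, "ok"

    # Everything except the snippets is fixed overhead that is never trimmed.
    overhead = count(system) + count(build_user(""))
    allowance = max_input_tokens - overhead
    if allowance <= 0:
        return None, "too_long"

    # Trim only the snippet block, rebuild, and recheck the full prompt.
    user = build_user(trim(snippet_block, allowance))
    if count(system) + count(user) > max_input_tokens:
        return None, "too_long"
    return user, "ok"


# Whitespace "tokens" stand in for a real tokenizer in this illustration.
def ws_count(text: str) -> int:
    return len(text.split())


def ws_trim(text: str, max_tokens: int) -> str:
    return " ".join(text.split()[:max_tokens])


user, status = fit_prompt(
    system="You are a judge.",
    build_user=lambda block: f"Snippets:\n{block}",
    snippet_block="word " * 50,
    max_input_tokens=20,
    count=ws_count,
    trim=ws_trim,
)
```

The recheck after trimming matters with a real tokenizer, because token counts of concatenated strings are not exactly additive.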

---

## Robust JSON parsing & retries (per model)

For each judge:

* Up to **3 attempts** per edge with **increasingly strict** system instructions:

  1. Normal instructions.
  2. “Output JSON only.”
  3. “Output only a JSON object with keys `classification` and `rationale`.”

* Parsing pipeline:

  * Clean output:

    * Strip backticks / ```json fences.
    * Normalize whitespace and smart quotes.
    * Remove trailing commas.
    * Optionally fix unescaped quotes inside the `rationale` string.
  * Try direct JSON parsing of the **whole string**.
  * If that fails, search for the **largest `{...}` block** and parse that.
  * As a last resort, use a regex-style extraction (for Llama) to salvage `classification` and `rationale`.

* Status codes used:

  * `"ok"`
  * `"json_parse_failed"`
  * `"bad_keys_or_values"`
  * `"too_long"`
  * `"api_error"`, `"api_error_throttled"`, `"api_response_error"`
  * `"empty_response"`

These statuses drive both the per-model `Unknown` handling and the `failed_citations.csv` log.
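
The first two parsing stages (whole string, then the largest `{...}` span) can be sketched like this. `parse_json_lenient` is a hypothetical name used for illustration, and the sketch assumes the text has already been through the cleaning step:

```python
import json
from typing import Any, Optional


def parse_json_lenient(text: str) -> Optional[Any]:
    """Try the whole string first, then the span from the first '{' to the last '}'."""
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass

    start, end = text.find("{"), text.rfind("}")
    if start == -1 or end <= start:
        return None
    try:
        return json.loads(text[start:end + 1])
    except json.JSONDecodeError:
        return None
```

A `None` result is the point at which a caller would fall back to regex-style extraction or record `"json_parse_failed"`.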

---

## Neo4j I/O

**Reads**

* `Q_PAGE_REL`:

  * Source and target IDs, names, decision dates, `citation_pipe`
  * `src_summary`, `tgt_summary`
  * Source URL (CourtListener or case URL)
  * `properties(r)` (for snippets and any existing fields)
  * Existing `r.treatment_label` (used to skip unless `force=True`)

* Snippet extraction:

  * Look for keys matching `snippet_\d+`, sort them numerically, and build the ordered list.
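
`Q_PAGE_REL` is a keyset-pagination query: each page asks for rows with `id(r) > $after_id`, ordered by `rel_id`, up to `$limit`. A minimal sketch of the driving loop, assuming a hypothetical `fetch_page` callable that runs the query and returns its rows as dictionaries, and assuming `-1` as the starting cursor:

```python
from typing import Any, Callable, Dict, Iterator, List

Row = Dict[str, Any]


def iter_relation_pages(
    fetch_page: Callable[[int, int], List[Row]],  # (after_id, limit) -> rows ordered by rel_id
    batch_size: int = 200,
) -> Iterator[List[Row]]:
    """Yield pages until the query returns no rows, advancing the cursor by rel_id."""
    after_id = -1  # assumption: relationship ids are non-negative
    while True:
        rows = fetch_page(after_id, batch_size)
        if not rows:
            return
        yield rows
        after_id = rows[-1]["rel_id"]


# In-memory stand-in for running Q_PAGE_REL with $after_id and $limit.
fake_edges = [{"rel_id": i} for i in range(5)]


def fake_fetch(after_id: int, limit: int) -> List[Row]:
    return [e for e in fake_edges if e["rel_id"] > after_id][:limit]


pages = list(iter_relation_pages(fake_fetch, batch_size=2))
```

Because the cursor is the last `rel_id` seen rather than an offset, edges written during the run do not shift later pages.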

**Writes**

* `Q_WRITE_REL_ANNOT`:

  * Per-model labels / rationales (Mistral, Llama, Claude)
  * Global `treatment_label` + `treatment_rationale`
  * `treatment_snippet` (joined snippets)
  * `model_used` (list of model display names)
  * `updated_at_utc` timestamp
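
The write query expects one parameter per `$name` placeholder. A sketch of assembling that parameter dictionary, where `build_write_params` is a hypothetical helper and the blank-line separator used to join snippets is an assumption (the notebook may join them differently):

```python
from typing import Any, Dict, List


def build_write_params(
    src_id: Any,
    tgt_id: Any,
    per_model: Dict[str, Dict[str, str]],  # e.g. {"mistral": {"label": ..., "rationale": ...}}
    label: str,
    rationale: str,
    snippets: List[str],
    model_used: List[str],
) -> Dict[str, Any]:
    """Build the parameter dict for Q_WRITE_REL_ANNOT."""
    params: Dict[str, Any] = {
        "src_id": src_id,
        "tgt_id": tgt_id,
        "label": label,
        "rationale": rationale,
        "snippet_joined": "\n\n".join(snippets),  # assumed separator
        "model_used": model_used,
    }
    # A judge with no result is recorded as Unknown with an empty rationale.
    for judge in ("mistral", "llama", "claude"):
        result = per_model.get(judge, {})
        params[f"{judge}_label"] = result.get("label", "Unknown")
        params[f"{judge}_rationale"] = result.get("rationale", "")
    return params
```

`updated_at_utc` is not a parameter: the query sets it server-side with `datetime()`.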

---

## CSV outputs (optional)

### Results CSV

With `results_csv=True`, the pipeline writes (by default):

* `edge_classifications_ensemble.csv`

Each row includes:

* Source / Target:

  * IDs, names, decision dates, `citation_pipe`
  * Summaries (source/target)
* Snippets and labels:

  * Joined `Opinion Snippet`
  * `Mistral Citation Evaluation`
  * `LLama Citation Evaluation`
  * `Claude Citation Evaluation`
  * `Global Citation Evaluation`
* Rationales:

  * `Mistral Rationale`
  * `LLama Rationale`
  * `Claude Rationale`
  * `Global Rationale`
* URL:

  * Source case URL (CourtListener if available)

By default only rows from the current run are exported. Set `show_all_labels_in_output_csv=True` to replace `df_results` with **all labeled relations** currently in Neo4j before the file is written.

### Failed classifications CSV

With `failed_csv=True`, the pipeline writes:

* `failed_citations.csv`

Each row captures:

* The same 11 leading columns as the results CSV (source/target metadata and `Opinion Snippet`).
* `Model Name`
* `Attempt` (1–3)
* `Input to LLM` (full prompt used for that attempt)
* `LLM Output` (raw text)
* `Type of Error` (human-readable status from `_describe_error`)

This file is useful to debug JSON issues, timeouts, and token overflows.

### Optional labeled dataset merge

If you pass a labeled CSV (human ground truth) via `append_to_labeled_dataset_csv`:

* The script:

  * Reads your labeled dataset (with expected columns like `Source ID`, `Target ID`, `Chunk`, `Label`, `Rationale`).
  * Joins it with the model outputs on `Source Case ID` and `Target Case ID`.
  * Writes a **model comparison** CSV (default: `"<stem> - model comparison.csv"`), which includes:

    * Human label vs **global model label** and per-model labels.
    * Human rationale vs per-model and global rationales.
    * Both the labeled chunk and the pipeline’s retrieved chunk.
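
The join itself is a left join from the human-labeled rows to the model rows, keyed on the (source, target) case ID pair. A dependency-free sketch, assuming that the labeled file uses `Source ID` / `Target ID`, that the model output uses `Source Case ID` / `Target Case ID`, and that IDs are comparable strings; `merge_labeled_with_model` and the output column names `Human Label` and `Global Model Label` are hypothetical:

```python
from typing import Dict, List, Tuple

Row = Dict[str, str]


def merge_labeled_with_model(labeled: List[Row], model: List[Row]) -> List[Row]:
    """Left-join human-labeled rows to model rows on (source, target) case IDs."""
    by_pair: Dict[Tuple[str, str], Row] = {
        (m["Source Case ID"], m["Target Case ID"]): m for m in model
    }
    merged: List[Row] = []
    for row in labeled:
        # Labeled rows without a model result keep an empty model label.
        match = by_pair.get((row["Source ID"], row["Target ID"]), {})
        merged.append({
            "Source ID": row["Source ID"],
            "Target ID": row["Target ID"],
            "Human Label": row.get("Label", ""),
            "Global Model Label": match.get("Global Citation Evaluation", ""),
        })
    return merged
```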

---

## Key parameters

```python
label_all_citations(
    results_csv: bool = False,
    results_csv_filename: str = "edge_classifications_ensemble.csv",
    batch_size: int = 200,
    echo: bool = True,
    force: bool = False,
    append_to_labeled_dataset_csv: Optional[str] = None,
    labeled_output_csv: Optional[str] = None,   # default: "<labeled_csv_stem> - model comparison.csv"
    show_all_labels_in_output_csv: bool = False,
    failed_csv: bool = False,
)
```

* `results_csv` / `results_csv_filename`
  Control whether to write the main ensemble results CSV.
* `batch_size`
  Number of edges per Neo4j page.
* `echo`
  If `True`, prints per-edge summary lines and diagnostics.
* `force`
  If `True`, re-labels edges even if `treatment_label` already exists and is not `"Unknown"`.
* `append_to_labeled_dataset_csv` / `labeled_output_csv`
  Merge with a human-labeled dataset and export a comparison file.
* `show_all_labels_in_output_csv`
  If `True`, replace `df_results` with **all labeled edges** currently in Neo4j before writing `results_csv`.
* `failed_csv`
  If `True`, write `failed_citations.csv` with all non-OK LLM attempts.

---

## Environment

* Loads `../.env`.

  * Required: `NEO4J_URI`, `NEO4J_USERNAME`, `NEO4J_PASSWORD`
  * Optional: `NEO4J_DATABASE` (default `"neo4j"`), `BEDROCK_REGION` (default `"us-east-1"`)

The model IDs for Mistral, Llama, and Claude are set at the top of the script.

---

## Quick start

```python
label_all_citations(
    force=False,
    echo=True,
    results_csv=True,
    results_csv_filename="edge_classifications_ensemble.csv",
    failed_csv=True,
)
```

This will:

* Page through all `CITES_TO` edges, skipping those that already have a non-`"Unknown"` label (unless `force=True`).
* Run all **three** judges for each edge.
* Write per-model and global labels back to Neo4j.
* Export both:

  * `edge_classifications_ensemble.csv`
  * `failed_citations.csv` (for debugging)

---

## Troubleshooting

* **Many “Unknown” labels**

  * Check `failed_citations.csv` to see if failures come from:

    * Missing snippets (`missing_snippets`)
    * Token limits (`too_long`)
    * JSON parsing issues (`json_parse_failed`, `bad_keys_or_values`)
  * Verify that the Snippet Retriever has populated `snippet_1..N`.

* **Prompt too long**

  * Snippets are already auto-trimmed, but if you still hit `"too_long"`:

    * Reduce number/length of stored snippets per edge, or
    * Lower `max_new_tokens` for one or more models.

* **Frequent JSON parse errors**

  * Set `echo=True` and review the raw outputs.
  * Consider simplifying case summaries or snippet size if the models are over-responding.

* **Disagreement across models**

  * Use the global rationale plus per-model labels/rationales in the CSV to:

    * Inspect where models disagree.
    * Decide if you want to trust the majority rule or override with human review for those edges.
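
When many edges come back `"Unknown"`, a quick tally of the failure log shows which judge and which error dominate. A small sketch using the `Model Name` and `Type of Error` columns listed above; `tally_failures` is a hypothetical helper, and UTF-8 encoding of the file is an assumption:

```python
import csv
from collections import Counter


def tally_failures(path: str) -> Counter:
    """Count rows of the failed-classifications CSV per (model, error type)."""
    counts: Counter = Counter()
    with open(path, newline="", encoding="utf-8") as fh:
        for row in csv.DictReader(fh):
            counts[(row["Model Name"], row["Type of Error"])] += 1
    return counts

# Example use once the pipeline has written the file:
#   for (model, error), n in tally_failures("failed_citations.csv").most_common():
#       print(f"{n:5d}  {model}  {error}")
```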


In [1]:
# Install (if applicable)
! pip install neo4j



In [2]:
! pip install tiktoken



In [3]:
import os, re, time, json, pathlib, logging, datetime as dt, random
from typing import List, Dict, Any, Tuple, Optional
from collections import Counter

import boto3
import pandas as pd
from neo4j import GraphDatabase
from dotenv import load_dotenv
from botocore.exceptions import ClientError, BotoCoreError
import tiktoken

# Quiet noisy logs (incl. Neo4j notifications/deprecations)
for _n in ("neo4j", "neo4j.notifications", "neo4j.work.simple"):
    logging.getLogger(_n).setLevel(logging.ERROR)
os.environ.setdefault("NEO4J_DRIVER_LOG_LEVEL", "ERROR")

# =========================
# Config / ENV
# =========================
# .env is always one level up from this notebook. Load it before any os.getenv()
# call so that values such as BEDROCK_REGION are actually picked up from the file.
load_dotenv("../.env", override=True)

BEDROCK_REGION        = os.getenv("BEDROCK_REGION", "us-east-1")
MISTRAL_MODEL_ID      = "mistral.mistral-7b-instruct-v0:2"
LLAMA_MODEL_ID        = "meta.llama3-70b-instruct-v1:0"
CLAUDE_MODEL_ID       = "anthropic.claude-3-5-sonnet-20240620-v1:0"

MISTRAL_DISPLAY_NAME  = "Mistral 7B Instruct v0.2 (mistral.mistral-7b-instruct-v0:2)"
LLAMA_DISPLAY_NAME    = "Llama 3 70B Instruct (meta.llama3-70b-instruct-v1:0)"
CLAUDE_DISPLAY_NAME   = "Claude 3.5 Sonnet (anthropic.claude-3-5-sonnet-20240620-v1:0)"

NEO4J_URI       = os.getenv("NEO4J_URI")
NEO4J_USERNAME  = os.getenv("NEO4J_USERNAME")
NEO4J_PASSWORD  = os.getenv("NEO4J_PASSWORD")
NEO4J_DATABASE  = os.getenv("NEO4J_DATABASE", "neo4j")

if not (NEO4J_URI and NEO4J_USERNAME and NEO4J_PASSWORD):
    raise RuntimeError("Missing Neo4j connection settings. Check ../.env for NEO4J_URI/NEO4J_USERNAME/NEO4J_PASSWORD.")

# =========================
# Tokenizer for prompt sizing. cl100k_base is an OpenAI encoding, so the counts
# are only an approximation for the Mistral, Llama, and Claude tokenizers.
# =========================
_tok = tiktoken.get_encoding("cl100k_base")

def _count_tokens(text: str) -> int:
    if not text:
        return 0
    return len(_tok.encode(text))

def _trim_to_tokens(text: str, max_tokens: int) -> str:
    if max_tokens <= 0:
        return ""
    token_ids = _tok.encode(text)
    if len(token_ids) <= max_tokens:
        return text
    return _tok.decode(token_ids[:max_tokens])

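def _max_ctx_tokens_for_bedrock(max_new_tokens: int) -> int:
    """
    Input-token budget based on Claude 3.5 Sonnet's ~200k-token context window.

    NOTE: the same budget is applied to all three judges, but it is only a
    meaningful bound for Claude. Mistral 7B Instruct and Llama 3 70B Instruct
    have much smaller context windows, so a prompt that passes this check can
    still be rejected when sent to those models.
    """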
    max_context = 200_000
    buffer_for_overhead = 2_000
    reserve_for_output  = max_new_tokens + 128
    return max_context - buffer_for_overhead - reserve_for_output

# =========================
# Bedrock client (lazy)
# =========================
_bedrock_client = None

def _bedrock():
    global _bedrock_client
    if _bedrock_client is None:
        _bedrock_client = boto3.client("bedrock-runtime", region_name=BEDROCK_REGION)
    return _bedrock_client

## Edge Classification Prompt (Snippet Method)

In [4]:
# =========================
# Prompts
# =========================
SYSTEM_PROMPT = """Role: You are an experienced lawyer specializing in legal citation analysis.
Goal: Your goal is to classify how a citing case treats a cited case (e.g., Positive, Neutral, Negative) when given:
1. Citing Case Name: the name of the citing case.
2. Citing Case Summary: the summary of the citing case.
3. Cited Case Name: the name of the cited case.
4. Cited Case Summary: the summary of the cited case.
5. Snippets (from citing opinion where the cited case appears): List of text snippets where the cited case is cited in the citing opinion text.

CRITICAL RULES (must follow):
1. **If the citing case uses the cited case to establish or support ANY legal rule, doctrine, standard, test, or conclusion, classify as POSITIVE.**
   - This includes situations where the case:
     • recites a standard from the case,
     • cites the case as part of a string cite supporting a legal principle,
     • uses the case as an example consistent with its reasoning,
     • applies reasoning from the cited case.

   - IMPORTANT: The opinion **does NOT need to use words like “follow,” “adopt,” “agree,” or “apply.”**  
     Any supportive or explanatory use counts as Positive.

2. Classify as NEUTRAL **only when the case is mentioned without being used as authority**.
   - Examples:
     • descriptive background
     • illustrating a factual distinction
     • quoting language without using it to support a rule
     • noting procedural history

3. Classify as NEGATIVE when the citing court:
   - criticizes, limits, distinguishes, rejects, or declines to follow the cited case.

4. When writing the rationale:
   - Include ONE short direct quote from the citing opinion.
   - The quote must illustrate why the treatment is Positive, Neutral, or Negative.
   - Provide EXACTLY four sentences.

OUTPUT FORMAT (strict):
{
  "classification": "Positive|Neutral|Negative",
  "rationale": "Four sentences with one direct quote."
}
"""

USER_PROMPT_TMPL = """Now, classify how the citing case treats the cited case. Return your answer as a JSON object: {{"classification": "", "rationale": ""}} where classification is one of Positive, Neutral, or Negative and rationale is a four-sentence explanation that justifies your classification
using evidence from the snippets and case summaries.

Input:
Citing Case Name: {citing_case_name}
Citing Case Summary: {citing_case_summary}
Cited Case Name: {cited_case_name}
Cited Case Citation: {cited_case_citation}
Cited Case Summary: {cited_case_summary}

Snippets (from citing opinion where the cited case appears):
{snippet_block}
"""

In [5]:
def _inst(sys: str, usr: str) -> str:
    """Mistral-style chat wrapper."""
    return f"<s>[INST]{sys}\n{usr}[/INST]"

def _llama_prompt(sys: str, usr: str) -> str:
    """Llama 3 chat wrapper, Bedrock-compatible."""
    return (
        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n"
        f"{sys}\n"
        "<|eot_id|><|start_header_id|>user<|end_header_id|>\n"
        f"{usr}\n"
        "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
    )

def _format_snippet_block_labeled(snippets: List[str]) -> str:
    """
    Render as:
    snippet 1:

    [text]

    snippet 2:

    [text]
    """
    blocks = []
    for i, s in enumerate(snippets, 1):
        ss = (s or "").strip()
        if not ss:
            continue
        blocks.append(f"snippet {i}:\n\n{ss}")
    return "\n\n".join(blocks) if blocks else "snippet 1:\n\n[N/A]"

## Cypher Queries

In [6]:
# =========================
# Cypher queries
# =========================
COL_URL_HEADER = "Source Case CourtListener URL (may be incorrect for some cases)"

Q_PAGE_REL = """
MATCH (s:Case)-[r:CITES_TO]->(t:Case)
WHERE id(r) > $after_id
RETURN id(r) AS rel_id,
       s.id AS src_id, coalesce(s.name,'') AS src_name, s.decision_date AS src_date, s.citation_pipe AS src_cite,
       coalesce(s.court_listener_url, s.url, '') AS src_url,
       coalesce(s.opinion_summary,'') AS src_summary,
       t.id AS tgt_id, coalesce(t.name,'') AS tgt_name, t.decision_date AS tgt_date, t.citation_pipe AS tgt_cite,
       coalesce(t.opinion_summary,'') AS tgt_summary,
       properties(r) AS rel_props,
       r.treatment_label AS existing_label
ORDER BY rel_id
LIMIT $limit
"""

Q_WRITE_REL_ANNOT = """
MATCH (s:Case {id:$src_id})-[r:CITES_TO]->(t:Case {id:$tgt_id})
SET r.mistral_treatment_label     = $mistral_label,
    r.mistral_treatment_rationale = $mistral_rationale,
    r.llama_treatment_label       = $llama_label,
    r.llama_treatment_rationale   = $llama_rationale,
    r.claude_treatment_label      = $claude_label,
    r.claude_treatment_rationale  = $claude_rationale,
    r.treatment_label             = $label,
    r.treatment_rationale         = $rationale,
    r.treatment_snippet           = $snippet_joined,
    r.model_used                  = $model_used,
    r.updated_at_utc              = datetime()
RETURN count(r) AS updated
"""

Q_COUNT_LABELS = """
MATCH ()-[r:CITES_TO]->()
RETURN
  sum(CASE WHEN r.treatment_label = 'Positive' THEN 1 ELSE 0 END) AS pos_total,
  sum(CASE WHEN r.treatment_label = 'Neutral'  THEN 1 ELSE 0 END) AS neu_total,
  sum(CASE WHEN r.treatment_label = 'Negative' THEN 1 ELSE 0 END) AS neg_total,
  sum(CASE WHEN r.treatment_label = 'Unknown'  THEN 1 ELSE 0 END) AS unk_total
"""

Q_ALL_LABELED_REL = """
MATCH (s:Case)-[r:CITES_TO]->(t:Case)
WHERE r.treatment_label IS NOT NULL
RETURN
  s.id AS src_id,
  coalesce(s.name,'') AS src_name,
  s.decision_date AS src_date,
  s.citation_pipe AS src_cite,
  coalesce(s.court_listener_url, s.url, '') AS src_url,
  coalesce(s.opinion_summary,'') AS src_summary,
  t.id AS tgt_id,
  coalesce(t.name,'') AS tgt_name,
  t.decision_date AS tgt_date,
  t.citation_pipe AS tgt_cite,
  coalesce(t.opinion_summary,'') AS tgt_summary,
  coalesce(r.treatment_snippet,'') AS snippet_joined,
  coalesce(r.mistral_treatment_label,'') AS mistral_label,
  coalesce(r.llama_treatment_label,'')  AS llama_label,
  coalesce(r.claude_treatment_label,'') AS claude_label,
  r.treatment_label AS treatment_label,
  coalesce(r.mistral_treatment_rationale,'') AS mistral_rationale,
  coalesce(r.llama_treatment_rationale,'')  AS llama_rationale,
  coalesce(r.claude_treatment_rationale,'') AS claude_rationale,
  coalesce(r.treatment_rationale,'')        AS treatment_rationale
"""


## Helpers

In [7]:
# =========================
# Helpers
# =========================
def _first_citation(citation_pipe: Optional[str]) -> str:
    if not citation_pipe:
        return ""
    parts = [p.strip() for p in re.split(r'[|;]+', citation_pipe) if p.strip()]
    return parts[0] if parts else ""

def _extract_numbered_snippets(rel_props: Dict[str, Any]) -> List[str]:
    """
    Pull snippet_1, snippet_2, ... from properties(r) and return as a list
    sorted by numeric suffix. Skip empty/whitespace-only.
    """
    out: List[Tuple[int, str]] = []
    for k, v in (rel_props or {}).items():
        m = re.fullmatch(r"snippet_(\d+)", k)
        if not m:
            continue
        try:
            idx = int(m.group(1))
        except ValueError:
            continue
        text = str(v or "").strip()
        if text:
            out.append((idx, text))
    out.sort(key=lambda p: p[0])
    return [t for _, t in out]

def _read_csv_fallback(path: str) -> pd.DataFrame:
    encodings = ["utf-8", "utf-8-sig", "cp1252", "latin-1"]
    last_err = None
    for enc in encodings:
        try:
            return pd.read_csv(path, encoding=enc)
        except Exception as e:
            last_err = e
    try:
        with open(path, "r", encoding="cp1252", errors="replace") as fh:
            return pd.read_csv(fh)
    except Exception:
        pass
    raise last_err if last_err else RuntimeError("Failed to read CSV with fallback encodings")

def _ordinal_word(n: int) -> str:
    return {1: "first", 2: "second", 3: "third"}.get(n, f"{n}th")

In [8]:
# =========================
# JSON cleaning / normalization (improved)
# =========================
_BACKTICKS_RE = re.compile(r"^```(?:json)?|```$", re.MULTILINE)

def _clean_json(s: str) -> str:
    """
    Clean common formatting issues:
      - remove ```json fences
      - normalize smart quotes
      - strip a single outer pair of quotes if the whole thing is one big JSON string
      - remove trailing commas before } or ]
    """
    if not isinstance(s, str):
        s = str(s)

    s = _BACKTICKS_RE.sub("", s).strip()
    s = s.replace("“", '"').replace("”", '"').replace("’", "'")

    # If the whole thing is wrapped in one pair of quotes, strip them
    if len(s) >= 2 and s[0] == s[-1] == '"':
        inner = s[1:-1].strip()
        if ("{" in inner and "}" in inner) or ("[" in inner and "]" in inner):
            s = inner

    # Remove trailing commas before object/array close
    s = re.sub(r",\s*([}\]])", r"\1", s)
    return s

def _normalize_from_js(js: Any) -> Tuple[Optional[str], Optional[str]]:
    """
    Given a parsed JSON value (dict or string), try to extract:
      classification ∈ {"Positive","Neutral","Negative"}
      rationale: non-empty string
    Handles nested or stringified JSON inside the 'rationale' field.
    """
    valid_labels = {"Positive", "Neutral", "Negative"}

    # Case 1: dict at top level
    if isinstance(js, dict):
        c = js.get("classification") or js.get("label")
        r = js.get("rationale")

        # If rationale itself is another dict, try inner fields
        if isinstance(r, dict):
            inner_c = r.get("classification") or r.get("label")
            inner_r = r.get("rationale")
            if inner_c in valid_labels and isinstance(inner_r, str) and inner_r.strip():
                return inner_c, inner_r.strip()

        # If rationale is a string that looks like JSON, try to parse it
        if isinstance(r, str):
            r_str = r.strip()
            if len(r_str) >= 2 and r_str[0] == r_str[-1] == '"':
                r_str = r_str[1:-1].strip()
            if r_str.startswith("{") and "classification" in r_str and "rationale" in r_str:
                try:
                    inner_js = json.loads(_clean_json(r_str))
                    inner_c, inner_r = _normalize_from_js(inner_js)
                    if inner_c and inner_r:
                        return inner_c, inner_r
                except Exception:
                    pass

        # Normal case: classification + plain rationale
        if c in valid_labels and isinstance(r, str) and r.strip():
            return c, r.strip()

        return None, None

    # Case 2: top-level string containing JSON
    if isinstance(js, str):
        txt = js.strip()
        if len(txt) >= 2 and txt[0] == txt[-1] == '"':
            txt = txt[1:-1].strip()
        if txt.startswith("{") and "classification" in txt and "rationale" in txt:
            try:
                inner_js = json.loads(_clean_json(txt))
                return _normalize_from_js(inner_js)
            except Exception:
                return None, None

    return None, None

def _extract_rationale_between_markers(s: str) -> Optional[str]:
    """
    Heuristic rationale extractor that ignores JSON validity.

    Strategy:
      - Find `"rationale"` key
      - From the colon, find the first `"` that starts the string
      - Then take everything up to the LAST `"` before the closing `}`

    This allows inner unescaped " characters inside the rationale, since
    we are not trying to parse the JSON string literally.
    """
    s_lower = s.lower()
    key_idx = s_lower.find('"rationale"')
    if key_idx == -1:
        return None

    colon_idx = s.find(":", key_idx)
    if colon_idx == -1:
        return None

    # First quote after the colon = start of rationale string
    first_quote = s.find('"', colon_idx)
    if first_quote == -1:
        return None

    # Try to limit search to this JSON object: last " before the last }
    last_brace = s.rfind("}")
    if last_brace == -1:
        search_end = len(s)
    else:
        search_end = last_brace

    last_quote = s.rfind('"', first_quote + 1, search_end)
    if last_quote == -1:
        # Fallback: last quote in the whole string
        last_quote = s.rfind('"', first_quote + 1)
        if last_quote == -1:
            return None

    rationale = s[first_quote + 1:last_quote]

    # Simple cleanup of common escape sequences
    rationale = rationale.replace('\\"', '"').replace("\\n", " ").replace("\\r", " ")
    rationale = rationale.strip()
    return rationale or None

def _extract_classification_and_rationale(raw_text: str) -> Tuple[Optional[str], Optional[str], str]:
    """
    Extract classification and rationale from a judge's raw output.

    Primary path:
      - Classification is read via regex from "classification" / "label".
      - Rationale is extracted heuristically as everything between
        `"rationale": "` and the closing `}` of the JSON object, so it
        is not cut off by inner unescaped double quotes.

    Fallback:
      - If either part is missing, try json.loads() on the cleaned text
        and read classification / rationale from the parsed value.

    Returns (classification, rationale, status).
    """
    if not raw_text or not str(raw_text).strip():
        return None, None, "empty_response"

    # Normalize markdown fences, newlines, smart quotes, etc.
    s = _clean_json(str(raw_text).strip())
    valid_labels = {"Positive", "Neutral", "Negative"}

    # ---- 1) Classification via regex (safe; label string has no inner quotes) ----
    label_match = re.search(
        r'"(?:classification|label)"\s*:\s*"(?P<label>Positive|Neutral|Negative)"',
        s,
        flags=re.IGNORECASE,
    )
    label: Optional[str] = None
    if label_match:
        label = label_match.group("label").capitalize()
        if label not in valid_labels:
            label = None

    # ---- 2) Rationale via heuristic between "rationale" and closing brace ----
    rationale = _extract_rationale_between_markers(s)

    if label and rationale:
        return label, rationale, "ok"

    # ---- 3) Simple JSON fallback (for weird but valid JSON cases) ----
    try:
        js = json.loads(s)
    except Exception:
        return None, None, "json_parse_failed"

    # Valid JSON that is not a usable object (e.g. a list, or wrong keys) is a
    # content problem, not a parse failure. Reuse the shared normalizer so that
    # nested or stringified payloads are handled consistently.
    c, r = _normalize_from_js(js)
    if c and r:
        return c, r, "ok"
    return None, None, "bad_keys_or_values"


def _describe_error(status: str) -> str:
    mapping = {
        "too_long": "too_long: prompt exceeded maximum context tokens",
        "json_parse_failed": "json_parse_failed: could not parse LLM output as JSON",
        "bad_keys_or_values": "bad_keys_or_values: missing or invalid classification/rationale in JSON",
        "api_error_throttled": "api_error_throttled: Bedrock throttled the request",
        "api_error": "api_error: generic AWS/Bedrock client or service error",
        "api_response_error": "api_response_error: malformed or unexpected API response",
        "empty_response": "empty_response: model returned empty text",
        "missing_snippets": "missing_snippets: no snippets available on the relation (snippet_1..N).",
        "ok": "ok: successful classification",
    }
    return mapping.get(status, status)

## LLM Classify

In [9]:
# =========================
# Per-model classify functions
# =========================
def classify_with_bedrock_claude(
    *,
    citing_case_name: str,
    citing_case_summary: str,
    cited_case_name: str,
    cited_case_citation: str,
    cited_case_summary: str,
    snippet_block_labeled: str,
    max_new_tokens: int = 512,
    retries: int = 3,
    echo: bool = False,
) -> Tuple[
    Optional[str], Optional[str], str, str, int,
    List[Dict[str, Any]], str
]:
    """
    Claude judge.
    Returns classification, rationale, status, raw_output, attempt_used, attempt_logs, user_input.
    status: "ok", "json_parse_failed", "bad_keys_or_values", "too_long",
            "api_error", "api_error_throttled", "api_response_error", "empty_response"
    """
    client = _bedrock()

    user = USER_PROMPT_TMPL.format(
        citing_case_name=(citing_case_name or "").strip(),
        citing_case_summary=(citing_case_summary or "").strip(),
        cited_case_name=(cited_case_name or "").strip(),
        cited_case_citation=(cited_case_citation or "").strip(),
        cited_case_summary=(cited_case_summary or "").strip(),
        snippet_block=snippet_block_labeled,
    )

    attempt_logs: List[Dict[str, Any]] = []

    # Token budget
    max_in = _max_ctx_tokens_for_bedrock(max_new_tokens)
    used = _count_tokens(SYSTEM_PROMPT + "\n\n" + user)

    if used > max_in:
        overage = used - max_in
        snippet_tokens = _count_tokens(snippet_block_labeled)
        target_snippet_tokens = max(512, snippet_tokens - int(overage * 1.2))
        trimmed = _trim_to_tokens(snippet_block_labeled, target_snippet_tokens)

        user = USER_PROMPT_TMPL.format(
            citing_case_name=(citing_case_name or "").strip(),
            citing_case_summary=(citing_case_summary or "").strip(),
            cited_case_name=(cited_case_name or "").strip(),
            cited_case_citation=(cited_case_citation or "").strip(),
            cited_case_summary=(cited_case_summary or "").strip(),
            snippet_block=trimmed,
        )
        used = _count_tokens(SYSTEM_PROMPT + "\n\n" + user)
        if used > max_in:
            if echo:
                print(f"      · [Claude] prompt still too long after trim ({used} > {max_in}); giving up")
            attempt_logs.append({
                "attempt": 1,
                "status": "too_long",
                "raw": "",
                "input": user,
            })
            return None, None, "too_long", "", 1, attempt_logs, user

    last_err = "json_parse_failed"
    last_raw = ""
    attempt = 0  # keeps the final return valid even if retries < 1

    for attempt in range(1, retries + 1):
        if attempt == 1:
            sys = SYSTEM_PROMPT
        elif attempt == 2:
            sys = SYSTEM_PROMPT + "\n\nSTRICT: Output JSON only. No prose, no backticks."
        else:
            sys = SYSTEM_PROMPT + '\n\nSTRICT: Output ONLY a JSON object with keys "classification" and "rationale". Use double quotes. No trailing commas. No markdown.'

        if echo:
            print(f"      · [Claude] attempt {attempt}/{retries}")

        time.sleep(0.5)

        try:
            resp = client.invoke_model(
                modelId=CLAUDE_MODEL_ID,
                body=json.dumps({
                    "anthropic_version": "bedrock-2023-05-31",
                    "max_tokens": max_new_tokens,
                    "temperature": 0.0,
                    "system": sys,
                    "messages": [
                        {"role": "user", "content": user}
                    ],
                }),
                contentType="application/json",
                accept="application/json",
            )

            raw_body = resp["body"].read()
            try:
                body = json.loads(raw_body)
            except Exception as e:
                if echo:
                    print(f"      · [Claude] failed to decode response JSON: {e}")
                last_err = "api_response_error"
                last_raw = raw_body.decode("utf-8", errors="replace") if isinstance(raw_body, (bytes, bytearray)) else str(raw_body)
                attempt_logs.append({
                    "attempt": attempt,
                    "status": last_err,
                    "raw": last_raw,
                    "input": user,
                })
                time.sleep(1.0 * attempt)
                continue

            try:
                text = body["content"][0]["text"]
            except (KeyError, IndexError, TypeError) as e:
                if echo:
                    print(f"      · [Claude] unexpected response structure: {e}")
                last_err = "api_response_error"
                last_raw = json.dumps(body)
                attempt_logs.append({
                    "attempt": attempt,
                    "status": last_err,
                    "raw": last_raw,
                    "input": user,
                })
                time.sleep(1.0 * attempt)
                continue

            last_raw = (text or "").strip()
            if not last_raw:
                if echo:
                    print("      · [Claude] empty response from model")
                status = "empty_response"
                last_err = status
                attempt_logs.append({
                    "attempt": attempt,
                    "status": status,
                    "raw": last_raw,
                    "input": user,
                })
                time.sleep(1.0 * attempt)
                continue

            c, r, status = _extract_classification_and_rationale(last_raw)
            attempt_logs.append({
                "attempt": attempt,
                "status": status,
                "raw": last_raw,
                "input": user,
            })
            if status == "ok" and c and r:
                return c, r, "ok", last_raw, attempt, attempt_logs, user

            last_err = status
            time.sleep(1.0 * attempt)
            continue

        except ClientError as e:
            code = e.response.get("Error", {}).get("Code", "")
            msg = e.response.get("Error", {}).get("Message", "")
            if echo:
                print(f"      · [Claude] API ClientError on attempt {attempt}/{retries}: {code} - {msg}")

            msg_lower = msg.lower() if isinstance(msg, str) else ""
            if "token" in msg_lower and ("length" in msg_lower or "limit" in msg_lower):
                last_err = "too_long"
            elif code in ("ThrottlingException", "TooManyRequestsException") or "rate exceeded" in msg_lower:
                last_err = "api_error_throttled"
            else:
                last_err = "api_error"

            last_raw = msg
            attempt_logs.append({
                "attempt": attempt,
                "status": last_err,
                "raw": last_raw,
                "input": user,
            })

            if last_err == "api_error_throttled":
                backoff = min(60.0, float(2 ** attempt))
                jitter = random.uniform(0.0, 0.5)
                if echo:
                    print(f"      · [Claude] throttled, backing off for {backoff + jitter:.2f} seconds")
                time.sleep(backoff + jitter)
            else:
                time.sleep(1.5 * attempt)
            continue

        except BotoCoreError as e:
            if echo:
                print(f"      · [Claude] BotoCoreError on attempt {attempt}/{retries}: {e}")
            last_err = "api_error"
            last_raw = str(e)
            attempt_logs.append({
                "attempt": attempt,
                "status": last_err,
                "raw": last_raw,
                "input": user,
            })
            time.sleep(1.0 * attempt)
            continue

        except Exception as e:
            if echo:
                print(f"      · [Claude] API error on attempt {attempt}/{retries}: {e}")
            last_err = "api_error"
            last_raw = str(e)
            attempt_logs.append({
                "attempt": attempt,
                "status": last_err,
                "raw": last_raw,
                "input": user,
            })
            time.sleep(1.0 * attempt)
            continue

    return None, None, last_err, last_raw, attempt, attempt_logs, user


def classify_with_bedrock_mistral(
    *,
    citing_case_name: str,
    citing_case_summary: str,
    cited_case_name: str,
    cited_case_citation: str,
    cited_case_summary: str,
    snippet_block_labeled: str,
    max_new_tokens: int = 320,
    retries: int = 3,
    echo: bool = False,
) -> Tuple[
    Optional[str], Optional[str], str, str, int,
    List[Dict[str, Any]], str
]:
    """
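    Mistral judge (text-generation format; the prompt is built with _inst).
    Returns the same 7-tuple and status values as the Claude judge:
    classification, rationale, status, raw_output, attempt_used, attempt_logs, user_input.
    Note: attempt_logs[*]["input"] holds the full instruction prompt, while
    user_input is the user message only.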
    """
    client = _bedrock()

    user = USER_PROMPT_TMPL.format(
        citing_case_name=(citing_case_name or "").strip(),
        citing_case_summary=(citing_case_summary or "").strip(),
        cited_case_name=(cited_case_name or "").strip(),
        cited_case_citation=(cited_case_citation or "").strip(),
        cited_case_summary=(cited_case_summary or "").strip(),
        snippet_block=snippet_block_labeled,
    )

    attempt_logs: List[Dict[str, Any]] = []

    prompt = _inst(SYSTEM_PROMPT, user)
    max_in = _max_ctx_tokens_for_bedrock(max_new_tokens)
    used = _count_tokens(prompt)

    if used > max_in:
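        # Measure the prompt with every field except the snippets filled in, so the
        # snippet allowance accounts for case names and summaries, not just the template.
        prompt_without_snippets = USER_PROMPT_TMPL.format(
            citing_case_name=(citing_case_name or "").strip(),
            citing_case_summary=(citing_case_summary or "").strip(),
            cited_case_name=(cited_case_name or "").strip(),
            cited_case_citation=(cited_case_citation or "").strip(),
            cited_case_summary=(cited_case_summary or "").strip(),
            snippet_block="",
        )
        allowance_for_snippets = max_in - _count_tokens(_inst(SYSTEM_PROMPT, prompt_without_snippets))
        allowance_for_snippets = max(512, allowance_for_snippets)
        trimmed = _trim_to_tokens(snippet_block_labeled, allowance_for_snippets)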
        user = USER_PROMPT_TMPL.format(
            citing_case_name=(citing_case_name or "").strip(),
            citing_case_summary=(citing_case_summary or "").strip(),
            cited_case_name=(cited_case_name or "").strip(),
            cited_case_citation=(cited_case_citation or "").strip(),
            cited_case_summary=(cited_case_summary or "").strip(),
            snippet_block=trimmed,
        )
        prompt = _inst(SYSTEM_PROMPT, user)
        used = _count_tokens(prompt)
        if used > max_in:
            if echo:
                print(f"      · [Mistral] prompt still too long after trim ({used} > {max_in}); giving up")
            attempt_logs.append({
                "attempt": 1,
                "status": "too_long",
                "raw": "",
                "input": prompt,
            })
            return None, None, "too_long", "", 1, attempt_logs, user

    base_payload = {
        "max_tokens": max_new_tokens,
        "temperature": 0,
    }

    last_err = "json_parse_failed"
    last_raw = ""
    attempt = 0  # keeps the final return valid even if retries < 1

    for attempt in range(1, retries + 1):
        if attempt == 1:
            sys = SYSTEM_PROMPT
        elif attempt == 2:
            sys = SYSTEM_PROMPT + "\n\nSTRICT: Output JSON only. No prose, no backticks."
        else:
            sys = SYSTEM_PROMPT + '\n\nSTRICT: Output ONLY a JSON object with keys "classification" and "rationale". Use double quotes. No trailing commas. No markdown.'

        eff_prompt = _inst(sys, user)

        if echo:
            print(f"      · [Mistral] attempt {attempt}/{retries}")

        time.sleep(0.5)

        try:
            resp = client.invoke_model(
                modelId=MISTRAL_MODEL_ID,
                body=json.dumps({"prompt": eff_prompt, **base_payload}),
                contentType="application/json",
                accept="application/json",
            )

            raw_body = resp["body"].read()
            try:
                body = json.loads(raw_body)
            except Exception as e:
                if echo:
                    print(f"      · [Mistral] failed to decode response JSON: {e}")
                last_err = "api_response_error"
                last_raw = raw_body.decode("utf-8", errors="replace") if isinstance(raw_body, (bytes, bytearray)) else str(raw_body)
                attempt_logs.append({
                    "attempt": attempt,
                    "status": last_err,
                    "raw": last_raw,
                    "input": eff_prompt,
                })
                time.sleep(1.0 * attempt)
                continue

            try:
                text = (body.get("outputs") or [{}])[0].get("text", "")
            except Exception as e:
                if echo:
                    print(f"      · [Mistral] unexpected response structure: {e}")
                last_err = "api_response_error"
                last_raw = json.dumps(body)
                attempt_logs.append({
                    "attempt": attempt,
                    "status": last_err,
                    "raw": last_raw,
                    "input": eff_prompt,
                })
                time.sleep(1.0 * attempt)
                continue

            last_raw = (text or "").strip()
            if not last_raw:
                if echo:
                    print("      · [Mistral] empty response from model")
                status = "empty_response"
                last_err = status
                attempt_logs.append({
                    "attempt": attempt,
                    "status": status,
                    "raw": last_raw,
                    "input": eff_prompt,
                })
                time.sleep(1.0 * attempt)
                continue

            c, r, status = _extract_classification_and_rationale(last_raw)
            attempt_logs.append({
                "attempt": attempt,
                "status": status,
                "raw": last_raw,
                "input": eff_prompt,
            })
            if status == "ok" and c and r:
                return c, r, "ok", last_raw, attempt, attempt_logs, user

            last_err = status
            time.sleep(1.0 * attempt)
            continue

        except ClientError as e:
            code = e.response.get("Error", {}).get("Code", "")
            msg = e.response.get("Error", {}).get("Message", "")
            if echo:
                print(f"      · [Mistral] API ClientError on attempt {attempt}/{retries}: {code} - {msg}")

            msg_lower = msg.lower() if isinstance(msg, str) else ""
            if "token" in msg_lower and ("length" in msg_lower or "limit" in msg_lower):
                last_err = "too_long"
            elif code in ("ThrottlingException", "TooManyRequestsException") or "rate exceeded" in msg_lower:
                last_err = "api_error_throttled"
            else:
                last_err = "api_error"

            last_raw = msg
            attempt_logs.append({
                "attempt": attempt,
                "status": last_err,
                "raw": last_raw,
                "input": eff_prompt,
            })

            if last_err == "api_error_throttled":
                backoff = min(60.0, float(2 ** attempt))
                jitter = random.uniform(0.0, 0.5)
                if echo:
                    print(f"      · [Mistral] throttled, backing off for {backoff + jitter:.2f} seconds")
                time.sleep(backoff + jitter)
            else:
                time.sleep(1.5 * attempt)
            continue

        except BotoCoreError as e:
            if echo:
                print(f"      · [Mistral] BotoCoreError on attempt {attempt}/{retries}: {e}")
            last_err = "api_error"
            last_raw = str(e)
            attempt_logs.append({
                "attempt": attempt,
                "status": last_err,
                "raw": last_raw,
                "input": eff_prompt,
            })
            time.sleep(1.0 * attempt)
            continue

        except Exception as e:
            if echo:
                print(f"      · [Mistral] API error on attempt {attempt}/{retries}: {e}")
            last_err = "api_error"
            last_raw = str(e)
            attempt_logs.append({
                "attempt": attempt,
                "status": last_err,
                "raw": last_raw,
                "input": eff_prompt,
            })
            time.sleep(1.0 * attempt)
            continue

    return None, None, last_err, last_raw, attempt, attempt_logs, user


def classify_with_bedrock_llama(
    *,
    citing_case_name: str,
    citing_case_summary: str,
    cited_case_name: str,
    cited_case_citation: str,
    cited_case_summary: str,
    snippet_block_labeled: str,
    max_new_tokens: int = 512,
    retries: int = 3,
    echo: bool = False,
) -> Tuple[
    Optional[str], Optional[str], str, str, int,
    List[Dict[str, Any]], str
]:
    """
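    Llama 3 70B judge (Bedrock; the prompt is built with _llama_prompt).
    Returns the same 7-tuple and status values as the Claude judge:
    classification, rationale, status, raw_output, attempt_used, attempt_logs, user_input.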
    """
    client = _bedrock()

    user = USER_PROMPT_TMPL.format(
        citing_case_name=(citing_case_name or "").strip(),
        citing_case_summary=(citing_case_summary or "").strip(),
        cited_case_name=(cited_case_name or "").strip(),
        cited_case_citation=(cited_case_citation or "").strip(),
        cited_case_summary=(cited_case_summary or "").strip(),
        snippet_block=snippet_block_labeled,
    )

    attempt_logs: List[Dict[str, Any]] = []

    base_prompt = _llama_prompt(SYSTEM_PROMPT, user)
    max_in = _max_ctx_tokens_for_bedrock(max_new_tokens)
    used = _count_tokens(base_prompt)

    if used > max_in:
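        # Measure the prompt with every field except the snippets filled in, so the
        # snippet allowance accounts for case names and summaries, not just the template.
        user_without_snippets = USER_PROMPT_TMPL.format(
            citing_case_name=(citing_case_name or "").strip(),
            citing_case_summary=(citing_case_summary or "").strip(),
            cited_case_name=(cited_case_name or "").strip(),
            cited_case_citation=(cited_case_citation or "").strip(),
            cited_case_summary=(cited_case_summary or "").strip(),
            snippet_block="",
        )
        overhead_tokens = _count_tokens(_llama_prompt(SYSTEM_PROMPT, user_without_snippets))
        allowance_for_snippets = max(512, max_in - overhead_tokens)

        snippet_tokens = _count_tokens(snippet_block_labeled)
        target_snippet_tokens = max(128, min(snippet_tokens, allowance_for_snippets))
        trimmed = _trim_to_tokens(snippet_block_labeled, target_snippet_tokens)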

        user = USER_PROMPT_TMPL.format(
            citing_case_name=(citing_case_name or "").strip(),
            citing_case_summary=(citing_case_summary or "").strip(),
            cited_case_name=(cited_case_name or "").strip(),
            cited_case_citation=(cited_case_citation or "").strip(),
            cited_case_summary=(cited_case_summary or "").strip(),
            snippet_block=trimmed,
        )
        base_prompt = _llama_prompt(SYSTEM_PROMPT, user)
        used = _count_tokens(base_prompt)
        if used > max_in:
            if echo:
                print(f"      · [Llama] prompt still too long after trim ({used} > {max_in}); giving up")
            attempt_logs.append({
                "attempt": 1,
                "status": "too_long",
                "raw": "",
                "input": user,
            })
            return None, None, "too_long", "", 1, attempt_logs, user

    last_err = "json_parse_failed"
    last_raw = ""
    attempt = 0  # keeps the final return valid even if retries < 1

    for attempt in range(1, retries + 1):
        if attempt == 1:
            sys = SYSTEM_PROMPT
        elif attempt == 2:
            sys = SYSTEM_PROMPT + "\n\nSTRICT: Output JSON only. No prose, no backticks."
        else:
            sys = SYSTEM_PROMPT + '\n\nSTRICT: Output ONLY a JSON object with keys "classification" and "rationale". Use double quotes. No trailing commas. No markdown.'

        prompt = _llama_prompt(sys, user)

        if echo:
            print(f"      · [Llama] attempt {attempt}/{retries}")

        time.sleep(0.75)

        try:
            resp = client.invoke_model(
                modelId=LLAMA_MODEL_ID,
                body=json.dumps({
                    "prompt": prompt,
                    "max_gen_len": max_new_tokens,
                    "temperature": 0.0,
                    "top_p": 0.9,
                }),
                contentType="application/json",
                accept="application/json",
            )

            raw_body = resp["body"].read()
            try:
                body = json.loads(raw_body)
            except Exception as e:
                if echo:
                    print(f"      · [Llama] failed to decode response JSON: {e}")
                last_err = "api_response_error"
                last_raw = raw_body.decode("utf-8", errors="replace") if isinstance(raw_body, (bytes, bytearray)) else str(raw_body)
                attempt_logs.append({
                    "attempt": attempt,
                    "status": last_err,
                    "raw": last_raw,
                    "input": user,
                })
                time.sleep(1.5 * attempt)
                continue

            text = ""
            if isinstance(body, dict):
                if isinstance(body.get("generation"), str):
                    text = body["generation"]
                elif isinstance(body.get("output_text"), str):
                    text = body["output_text"]
                elif isinstance(body.get("outputs"), list) and body["outputs"]:
                    first = body["outputs"][0]
                    if isinstance(first, dict):
                        text = first.get("text") or first.get("output_text", "") or ""
                    elif isinstance(first, str):
                        text = first
                elif isinstance(body.get("output"), str):
                    text = body["output"]

            last_raw = (text or "").strip()
            if not last_raw:
                if echo:
                    print("      · [Llama] empty response from model")
                status = "empty_response"
                last_err = status
                attempt_logs.append({
                    "attempt": attempt,
                    "status": status,
                    "raw": last_raw,
                    "input": user,
                })
                time.sleep(1.5 * attempt)
                continue

            c, r, status = _extract_classification_and_rationale(last_raw)
            attempt_logs.append({
                "attempt": attempt,
                "status": status,
                "raw": last_raw,
                "input": user,
            })
            if status == "ok" and c and r:
                return c, r, "ok", last_raw, attempt, attempt_logs, user

            last_err = status
            time.sleep(1.5 * attempt)
            continue

        except ClientError as e:
            code = e.response.get("Error", {}).get("Code", "")
            msg = e.response.get("Error", {}).get("Message", "")
            if echo:
                print(f"      · [Llama] API ClientError on attempt {attempt}/{retries}: {code} - {msg}")

            msg_lower = msg.lower() if isinstance(msg, str) else ""
            if "token" in msg_lower and ("length" in msg_lower or "limit" in msg_lower):
                last_err = "too_long"
            elif code in ("ThrottlingException", "TooManyRequestsException") or "rate exceeded" in msg_lower:
                last_err = "api_error_throttled"
            else:
                last_err = "api_error"

            last_raw = msg
            attempt_logs.append({
                "attempt": attempt,
                "status": last_err,
                "raw": last_raw,
                "input": user,
            })

            if last_err == "api_error_throttled":
                backoff = min(60.0, float(2 ** attempt))
                jitter = random.uniform(0.0, 0.5)
                if echo:
                    print(f"      · [Llama] throttled, backing off for {backoff + jitter:.2f} seconds")
                time.sleep(backoff + jitter)
            else:
                time.sleep(2.0 * attempt)
            continue

        except BotoCoreError as e:
            if echo:
                print(f"      · [Llama] BotoCoreError on attempt {attempt}/{retries}: {e}")
            last_err = "api_error"
            last_raw = str(e)
            attempt_logs.append({
                "attempt": attempt,
                "status": last_err,
                "raw": last_raw,
                "input": user,
            })
            time.sleep(1.5 * attempt)
            continue

        except Exception as e:
            if echo:
                print(f"      · [Llama] API error on attempt {attempt}/{retries}: {e}")
            last_err = "api_error"
            last_raw = str(e)
            attempt_logs.append({
                "attempt": attempt,
                "status": last_err,
                "raw": last_raw,
                "input": user,
            })
            time.sleep(1.5 * attempt)
            continue

    return None, None, last_err, last_raw, attempt, attempt_logs, user

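The three judges above repeat the same `ClientError` handling: map the error code and message to a status, then wait longer when the call was throttled. The following is a minimal sketch of that logic as two pure helpers. The names `status_from_client_error` and `throttle_backoff_seconds` are hypothetical (neither exists in the notebook), and the sample error messages are made up; the rules themselves are copied from the judge functions.

```python
import random


def status_from_client_error(code: str, message: str) -> str:
    """Map a ClientError code/message pair to a judge status, using the judges' rules."""
    msg_lower = message.lower() if isinstance(message, str) else ""
    if "token" in msg_lower and ("length" in msg_lower or "limit" in msg_lower):
        return "too_long"
    if code in ("ThrottlingException", "TooManyRequestsException") or "rate exceeded" in msg_lower:
        return "api_error_throttled"
    return "api_error"


def throttle_backoff_seconds(attempt: int, cap: float = 60.0, max_jitter: float = 0.5) -> float:
    """Exponential backoff (2**attempt seconds, capped) plus a small random jitter."""
    return min(cap, float(2 ** attempt)) + random.uniform(0.0, max_jitter)


# Illustrative inputs only; real Bedrock messages may be worded differently.
print(status_from_client_error("ThrottlingException", "Rate exceeded"))
print(status_from_client_error("ValidationException", "Input exceeds the token limit"))
for attempt in (1, 2, 3):
    print(attempt, round(throttle_backoff_seconds(attempt), 2))
```

With the default `retries=3`, the throttled waits are roughly 2, 4, and 8 seconds, so the 60-second cap only comes into play if `retries` is raised to 6 or more.
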
## Full-dataset fetch for CSV outputs

In [10]:
# =========================
# Full-dataset fetch for CSV outputs
# =========================
def _fetch_all_labeled_results(driver) -> pd.DataFrame:
    """
    Load all currently labeled CITES_TO edges from Neo4j
    and shape them like df_results, with per-model columns.
    """
    with driver.session(database=NEO4J_DATABASE) as s_all:
        data = s_all.run(Q_ALL_LABELED_REL, {}).data()

    rows_all: List[Dict[str, Any]] = []
    for row in data:
        rows_all.append({
            "Source Case ID": row["src_id"],
            "Source Case Name": row["src_name"] or "",
            "Source Case Decision Date": row.get("src_date") or "",
            "Source Case citation_pipe": row.get("src_cite") or "",
            "Source Case Summary": row.get("src_summary") or "",
            "Target Case ID": row["tgt_id"],
            "Target Case Name": row["tgt_name"] or "",
            "Target Case Decision Date": row.get("tgt_date") or "",
            "Target Case citation_pipe": row.get("tgt_cite") or "",
            "Target Case Summary": row.get("tgt_summary") or "",
            "Opinion Snippet": row.get("snippet_joined") or "",
            "Mistral Citation Evaluation": row.get("mistral_label") or "",
            "LLama Citation Evaluation": row.get("llama_label") or "",
            "Claude Citation Evaluation": row.get("claude_label") or "",
            "Global Citation Evaluation": row.get("treatment_label") or "",
            "Mistral Rationale": row.get("mistral_rationale") or "",
            "LLama Rationale": row.get("llama_rationale") or "",
            "Claude Rationale": row.get("claude_rationale") or "",
            "Global Rationale": row.get("treatment_rationale") or "",
            COL_URL_HEADER: row.get("src_url") or "",
        })
    return pd.DataFrame(rows_all)


## Majority vote helper

In [11]:
# =========================
# Majority vote helper
# =========================
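def _compute_final_label(m_label: str, l_label: str, c_label: str) -> str:
    """
    Combine the three per-model labels into the global label.
    Two or three matching labels win; a three-way split returns "Unknown" if any
    judge said "Unknown", otherwise "Neutral".
    """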
    labels = [m_label, l_label, c_label]
    cnt = Counter(labels)
    majority_label, majority_count = cnt.most_common(1)[0]

    # If there is a strict majority (2 or 3 of the same), return that value
    if majority_count >= 2:
        return majority_label

    # No majority: all three different
    if "Unknown" in cnt:
        # One model Unknown and other two disagree -> Unknown
        return "Unknown"

    # All three different and none Unknown -> Neutral by rule
    return "Neutral"

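To make the tie-breaking rules concrete, the sketch below re-implements the same logic as a standalone function and runs it on a few illustrative label triples. It is named `final_label` only to avoid clashing with the notebook's `_compute_final_label`.

```python
from collections import Counter


def final_label(m_label: str, l_label: str, c_label: str) -> str:
    """Standalone copy of the majority-vote rules, for illustration only."""
    counts = Counter([m_label, l_label, c_label])
    label, count = counts.most_common(1)[0]
    if count >= 2:            # two or three judges agree
        return label
    if "Unknown" in counts:   # three-way split that includes an Unknown
        return "Unknown"
    return "Neutral"          # three-way split among Positive / Neutral / Negative


examples = [
    ("Positive", "Positive", "Negative"),
    ("Unknown", "Unknown", "Positive"),
    ("Positive", "Negative", "Unknown"),
    ("Positive", "Neutral", "Negative"),
]
for triple in examples:
    print(triple, "->", final_label(*triple))
```

Two `Unknown` votes outvote a single substantive label, so an edge where two judges fail is reported as `Unknown` even if the third judge succeeded.
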
## Batch driver

In [12]:
# =========================
# Batch driver
# =========================
def label_all_citations(
    *,
    results_csv: bool = False,
    results_csv_filename: str = "edge_classifications_ensemble.csv",
    batch_size: int = 200,
    echo: bool = True,
    force: bool = False,
    append_to_labeled_dataset_csv: Optional[str] = None,
    labeled_output_csv: Optional[str] = None,   # default = "<labeled_csv_stem> - model comparison.csv"
    show_all_labels_in_output_csv: bool = False,
    failed_csv: bool = False,
):
    """
    Ensemble Edge Classifier (3 judges: Mistral, Llama, Claude).

    For each CITES_TO relation:
      - Run all three models.
      - Store per-model label/rationale.
      - Compute final treatment_label by majority vote with the rules given.
      - Build treatment_rationale explaining the decision and including
        majority judges' rationales (or all three when needed).
    """
    driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD))

    # Session with DB name; disable notifications where supported
    try:
        session = driver.session(database=NEO4J_DATABASE, notifications_min_severity="OFF")
    except TypeError:
        session = driver.session(database=NEO4J_DATABASE)

    rows_out: List[Dict[str, Any]] = []
    failed_rows: List[Dict[str, Any]] = []
    after = -1
    processed = 0
    t0 = time.time()

    # diagnostics (global)
    ok_rel = 0
    missing_snippets = 0

    # classification tallies (global final label)
    pos_cnt = neu_cnt = neg_cnt = unk_cnt = 0

    # per-model error counters (optional diagnostics); one key per status the judges can return
    _error_keys = (
        "too_long", "json_parse_failed", "bad_keys_or_values", "api_error",
        "api_error_throttled", "api_response_error", "empty_response", "other",
    )
    model_error_counts = {
        name: {key: 0 for key in _error_keys}
        for name in ("Mistral", "Llama", "Claude")
    }

    # First 11 columns for failed_citations.csv
    first_11_cols = [
        "Source Case ID",
        "Source Case Name",
        "Source Case Decision Date",
        "Source Case citation_pipe",
        "Source Case Summary",
        "Target Case ID",
        "Target Case Name",
        "Target Case Decision Date",
        "Target Case citation_pipe",
        "Target Case Summary",
        "Opinion Snippet",
    ]

    with session as s:
        while True:
            rels = s.run(Q_PAGE_REL, {"after_id": after, "limit": batch_size}).data()
            if not rels:
                if echo:
                    print("All done. No more relations.")
                break

            if echo:
                print(f"\nBatch after rel_id {after}: {len(rels)} relation(s)")

            for row in rels:
                after = row["rel_id"]

                existing_label = (row.get("existing_label") or "").strip()
                if (not force) and existing_label and existing_label.lower() != "unknown":
                    # already globally labeled → skip
                    continue

                src_name = row["src_name"] or ""
                tgt_name = row["tgt_name"] or ""
                src_sum  = row.get("src_summary") or ""
                tgt_sum  = row.get("tgt_summary") or ""
                tgt_cit  = _first_citation(row.get("tgt_cite") or "")

                # Gather snippets
                snippets = _extract_numbered_snippets(row.get("rel_props") or {})
                common_row = {
                    "Source Case ID": row["src_id"],
                    "Source Case Name": src_name,
                    "Source Case Decision Date": row.get("src_date") or "",
                    "Source Case citation_pipe": row.get("src_cite") or "",
                    "Source Case Summary": src_sum,
                    "Target Case ID": row["tgt_id"],
                    "Target Case Name": tgt_name,
                    "Target Case Decision Date": row.get("tgt_date") or "",
                    "Target Case citation_pipe": row.get("tgt_cite") or "",
                    "Target Case Summary": tgt_sum,
                }

                if not snippets:
                    # No snippets: label every judge and the global result "Unknown" and record why
                    missing_snippets += 1
                    global_label = "Unknown"
                    explanation = "The final treatment is 'Unknown' because no snippets were available on the relation (snippet_1..N)."

                    mistral_label = llama_label = claude_label = "Unknown"
                    mistral_rat = llama_rat = claude_rat = "No snippets available on the relation (snippet_1..N)."

                    joined_for_store = ""
                    global_rationale = (
                        explanation
                        + f"\n\nThe rationale for the first judge ('{MISTRAL_DISPLAY_NAME}') is:\n\n"
                        + mistral_rat
                        + f"\n\nThe rationale for the second judge ('{LLAMA_DISPLAY_NAME}') is:\n\n"
                        + llama_rat
                        + f"\n\nThe rationale for the third judge ('{CLAUDE_DISPLAY_NAME}') is:\n\n"
                        + claude_rat
                    )

                    # Write to Neo4j
                    s.run(Q_WRITE_REL_ANNOT, {
                        "src_id": row["src_id"],
                        "tgt_id": row["tgt_id"],
                        "mistral_label": mistral_label,
                        "mistral_rationale": mistral_rat,
                        "llama_label": llama_label,
                        "llama_rationale": llama_rat,
                        "claude_label": claude_label,
                        "claude_rationale": claude_rat,
                        "label": global_label,
                        "rationale": global_rationale,
                        "snippet_joined": joined_for_store,
                        "model_used": [MISTRAL_DISPLAY_NAME, LLAMA_DISPLAY_NAME, CLAUDE_DISPLAY_NAME],
                    })

                    unk_cnt += 1
                    if echo:
                        print(f"{src_name} → {tgt_name}: Global=Unknown (missing snippets)")

                    rows_out.append({
                        **common_row,
                        "Opinion Snippet": joined_for_store,
                        "Mistral Citation Evaluation": mistral_label,
                        "LLama Citation Evaluation": llama_label,
                        "Claude Citation Evaluation": claude_label,
                        "Global Citation Evaluation": global_label,
                        "Mistral Rationale": mistral_rat,
                        "LLama Rationale": llama_rat,
                        "Claude Rationale": claude_rat,
                        "Global Rationale": global_rationale,
                        COL_URL_HEADER: row.get("src_url") or "",
                    })

                    # Failed row (no LLM attempts)
                    failed_rows.append({
                        **{k: common_row.get(k, "") for k in first_11_cols},
                        "Model Name": "",
                        "Attempt": 0,
                        "Input to LLM": "",
                        "LLM Output": "",
                        "Type of Error": _describe_error("missing_snippets"),
                    })

                    processed += 1
                    continue

                # Build snippet block for models
                snippet_block = _format_snippet_block_labeled(snippets)
                joined_for_store = "\n\n".join(snippets)

                # ----- Run three judges -----
                # Mistral
                m_lab, m_rat, m_status, m_raw, m_attempt, m_logs, m_user_input = classify_with_bedrock_mistral(
                    citing_case_name=src_name,
                    citing_case_summary=src_sum,
                    cited_case_name=tgt_name,
                    cited_case_citation=tgt_cit,
                    cited_case_summary=tgt_sum,
                    snippet_block_labeled=snippet_block,
                    max_new_tokens=340,
                    retries=3,
                    echo=echo,
                )
                if not m_lab:
                    label_m = "Unknown"
                    rat_m = f"Classification failed: {_describe_error(m_status)}"
                    # Bump the matching error counter; API-level failures share one bucket
                    if m_status in ("too_long", "json_parse_failed", "bad_keys_or_values"):
                        model_error_counts["Mistral"][m_status] += 1
                    elif m_status in ("api_error", "api_error_throttled", "api_response_error", "empty_response"):
                        model_error_counts["Mistral"]["api_error"] += 1
                    else:
                        model_error_counts["Mistral"]["other"] += 1
                else:
                    label_m = m_lab
                    rat_m = m_rat

                # Llama
                l_lab, l_rat, l_status, l_raw, l_attempt, l_logs, l_user_input = classify_with_bedrock_llama(
                    citing_case_name=src_name,
                    citing_case_summary=src_sum,
                    cited_case_name=tgt_name,
                    cited_case_citation=tgt_cit,
                    cited_case_summary=tgt_sum,
                    snippet_block_labeled=snippet_block,
                    max_new_tokens=340,
                    retries=3,
                    echo=echo,
                )
                if not l_lab:
                    label_l = "Unknown"
                    rat_l = f"Classification failed: {_describe_error(l_status)}"
                    # Bump the matching error counter; API-level failures share one bucket
                    if l_status in ("too_long", "json_parse_failed", "bad_keys_or_values"):
                        model_error_counts["Llama"][l_status] += 1
                    elif l_status in ("api_error", "api_error_throttled", "api_response_error", "empty_response"):
                        model_error_counts["Llama"]["api_error"] += 1
                    else:
                        model_error_counts["Llama"]["other"] += 1
                else:
                    label_l = l_lab
                    rat_l = l_rat

                # Claude
                c_lab, c_rat, c_status, c_raw, c_attempt, c_logs, c_user_input = classify_with_bedrock_claude(
                    citing_case_name=src_name,
                    citing_case_summary=src_sum,
                    cited_case_name=tgt_name,
                    cited_case_citation=tgt_cit,
                    cited_case_summary=tgt_sum,
                    snippet_block_labeled=snippet_block,
                    max_new_tokens=512,
                    retries=3,
                    echo=echo,
                )
                if not c_lab:
                    label_c = "Unknown"
                    rat_c = f"Classification failed: {_describe_error(c_status)}"
                    # Bump the matching error counter; API-level failures share one bucket
                    if c_status in ("too_long", "json_parse_failed", "bad_keys_or_values"):
                        model_error_counts["Claude"][c_status] += 1
                    elif c_status in ("api_error", "api_error_throttled", "api_response_error", "empty_response"):
                        model_error_counts["Claude"]["api_error"] += 1
                    else:
                        model_error_counts["Claude"]["other"] += 1
                else:
                    label_c = c_lab
                    rat_c = c_rat

                # Collect attempt logs into failed_rows (for failed_csv)
                for model_name, logs in [
                    (MISTRAL_DISPLAY_NAME, m_logs),
                    (LLAMA_DISPLAY_NAME,   l_logs),
                    (CLAUDE_DISPLAY_NAME,  c_logs),
                ]:
                    for log in logs:
                        status = log.get("status", "")
                        if status == "ok":
                            continue
                        failed_rows.append({
                            **{k: common_row.get(k, "") for k in first_11_cols},
                            "Model Name": model_name,
                            "Attempt": log.get("attempt", 0),
                            "Input to LLM": log.get("input", ""),
                            "LLM Output": log.get("raw", ""),
                            "Type of Error": _describe_error(status),
                        })

                # ----- Majority vote and global rationale -----
                final_label = _compute_final_label(label_m, label_l, label_c)

                cnt = Counter([label_m, label_l, label_c])
                explanation = ""
                if final_label == "Unknown":
                    if cnt["Unknown"] >= 2:
                        explanation = "The final treatment is 'Unknown' because at least two of the three judges returned 'Unknown' labels."
                    elif "Unknown" in cnt and len(cnt) == 3:
                        explanation = "The final treatment is 'Unknown' because one judge returned 'Unknown' and the other two judges disagreed on the label."
                    else:
                        explanation = "The final treatment is 'Unknown' based on the ensemble voting rules."
                else:
                    if len(cnt) == 3 and "Unknown" not in cnt and final_label == "Neutral":
                        explanation = "The final treatment is 'Neutral' because all three judges disagreed (one Positive, one Neutral, one Negative), so we fall back to 'Neutral' by rule."
                    else:
                        explanation = f"The final treatment is '{final_label}' because it is the majority vote among the three judges."

                # Decide which judges' rationales to include
                judges_info = [
                    ("Mistral", MISTRAL_DISPLAY_NAME, label_m, rat_m),
                    ("Llama",   LLAMA_DISPLAY_NAME,   label_l, rat_l),
                    ("Claude",  CLAUDE_DISPLAY_NAME,  label_c, rat_c),
                ]

                # Default: judges whose label == final_label
                if final_label == "Unknown":
                    judges_to_include = judges_info  # show all in Unknown case
                elif len(cnt) == 3 and "Unknown" not in cnt and final_label == "Neutral":
                    judges_to_include = judges_info  # all three disagreed → include all
                else:
                    judges_to_include = [j for j in judges_info if j[2] == final_label]

                parts = [explanation]
                for idx, (_, display_name, _, rat_text) in enumerate(judges_to_include, 1):
                    parts.append(
                        f"\nThe rationale for the {_ordinal_word(idx)} judge ('{display_name}') is:\n\n{rat_text}"
                    )
                global_rationale = "\n".join(parts).strip()

                # Update global label counts
                if final_label == "Positive":
                    pos_cnt += 1
                elif final_label == "Neutral":
                    neu_cnt += 1
                elif final_label == "Negative":
                    neg_cnt += 1
                else:
                    unk_cnt += 1

                # Write to Neo4j
                s.run(Q_WRITE_REL_ANNOT, {
                    "src_id": row["src_id"],
                    "tgt_id": row["tgt_id"],
                    "mistral_label": label_m,
                    "mistral_rationale": rat_m,
                    "llama_label": label_l,
                    "llama_rationale": rat_l,
                    "claude_label": label_c,
                    "claude_rationale": rat_c,
                    "label": final_label,
                    "rationale": global_rationale,
                    "snippet_joined": joined_for_store,
                    "model_used": [MISTRAL_DISPLAY_NAME, LLAMA_DISPLAY_NAME, CLAUDE_DISPLAY_NAME],
                })

                # Count only relations that received a non-Unknown global label
                if final_label != "Unknown":
                    ok_rel += 1

                if echo:
                    print(
                        f"{src_name} → {tgt_name}: "
                        f"Mistral={label_m}, Llama={label_l}, Claude={label_c} → Global={final_label}"
                    )

                rows_out.append({
                    **common_row,
                    "Opinion Snippet": joined_for_store,
                    "Mistral Citation Evaluation": label_m,
                    "LLama Citation Evaluation": label_l,
                    "Claude Citation Evaluation": label_c,
                    "Global Citation Evaluation": final_label,
                    "Mistral Rationale": rat_m,
                    "LLama Rationale": rat_l,
                    "Claude Rationale": rat_c,
                    "Global Rationale": global_rationale,
                    COL_URL_HEADER: row.get("src_url") or "",
                })

                processed += 1

    # Build df_results from this run
    df_results = pd.DataFrame(rows_out)

    if show_all_labels_in_output_csv:
        df_results = _fetch_all_labeled_results(driver)
        if echo:
            print(f"\nshow_all_labels_in_output_csv=True → df_results replaced with full labeled dataset (rows: {len(df_results)})")
    else:
        if echo:
            print(f"\nshow_all_labels_in_output_csv=False → df_results only has rows from this run (rows: {len(df_results)})")

    # -------- Optional CSV of results --------
    if results_csv:
        df_results.to_csv(results_csv_filename, index=False)
        print(f"\nWrote {len(df_results)} rows → {results_csv_filename}")
    else:
        print(f"\nSkipping results CSV write (results_csv=False). Rows buffered in df_results: {len(df_results)}")

    # -------- Optional CSV of failed classifications --------
    if failed_csv:
        df_failed = pd.DataFrame(failed_rows)
        if not df_failed.empty:
            ordered_cols = first_11_cols + [
                c for c in ["Model Name", "Attempt", "Input to LLM", "LLM Output", "Type of Error"]
                if c in df_failed.columns
            ]
            for c in ordered_cols:
                if c not in df_failed.columns:
                    df_failed[c] = ""
            df_failed = df_failed[ordered_cols]
            df_failed.to_csv("failed_citations.csv", index=False)
            print(f"Wrote {len(df_failed)} failed rows → failed_citations.csv")
        else:
            print("No failed classifications in this run; failed_citations.csv not written.")
    else:
        if echo:
            print(f"Skipping failed CSV write (failed_csv=False). Failed rows in memory: {len(failed_rows)}")

    # -------- Optional labeled dataset join --------
    if append_to_labeled_dataset_csv:
        if labeled_output_csv is None:
            stem = pathlib.Path(append_to_labeled_dataset_csv).stem
            labeled_output_csv = f"{stem} - model comparison.csv"

        try:
            df_label = _read_csv_fallback(append_to_labeled_dataset_csv)
            print(f"Loaded labeled dataset with fallback reader: {append_to_labeled_dataset_csv}")
        except Exception as e:
            df_label = None
            print(f"Could not read labeled dataset CSV: {e}")

        if df_label is not None:
            expected = ["Source ID", "Source Name", "Target ID", "Target Name", "Chunk", "Label", "Rationale"]
            missing = [c for c in expected if c not in df_label.columns]
            if missing:
                print(f"Labeled CSV missing required columns: {missing}. Skipping append.")
            else:
                # Standardize labeled columns
                df_label_std = df_label.rename(columns={
                    "Source ID":   "Source Case ID",
                    "Target ID":   "Target Case ID",
                    "Source Name": "Source Case Name (Labeled)",
                    "Target Name": "Target Case Name (Labeled)",
                    # keep "Rationale" as-is
                })

                # Prepare join frame from current df_results with MODEL-suffixed columns
                df_join = df_results.rename(columns={
                    "Source Case Name":        "Source Case Name (Model)",
                    "Target Case Name":        "Target Case Name (Model)",
                    "Source Case Summary":     "Source Case Summary (Model)",
                    "Target Case Summary":     "Target Case Summary (Model)",
                    "Opinion Snippet":         "Chunk (Pulled by Pipeline)",
                    "Global Citation Evaluation": "Label (Model)",
                    "Global Rationale":           "Rationale (Model)",
                })

                needed_from_results = [
                    "Source Case ID", "Target Case ID",
                    "Source Case Name (Model)", "Target Case Name (Model)",
                    "Source Case Decision Date", "Source Case citation_pipe",
                    "Source Case Summary (Model)",
                    "Target Case Decision Date", "Target Case citation_pipe",
                    "Target Case Summary (Model)",
                    "Chunk (Pulled by Pipeline)",
                    "Label (Model)",
                    "Mistral Citation Evaluation",
                    "LLama Citation Evaluation",
                    "Claude Citation Evaluation",
                    "Rationale (Model)",
                    "Mistral Rationale",
                    "LLama Rationale",
                    "Claude Rationale",
                    COL_URL_HEADER,
                ]
                for c in needed_from_results:
                    if c not in df_join.columns:
                        df_join[c] = ""

                merged = df_label_std.merge(
                    df_join[needed_from_results],
                    on=["Source Case ID", "Target Case ID"],
                    how="inner",
                )

                # Build unified case-name columns (prefer labeled, else model)
                def _prefer_labeled(lbl_col: str, mdl_col: str):
                    lbl_series = merged.get(lbl_col)
                    mdl_series = merged.get(mdl_col)
                    if lbl_series is None and mdl_series is None:
                        # Neither column exists: return blanks aligned to merged's index
                        return pd.Series("", index=merged.index, dtype="object")
                    if lbl_series is None:
                        return mdl_series
                    if mdl_series is None:
                        return lbl_series
                    lbl_clean = lbl_series.fillna("").astype(str)
                    use_lbl = ~lbl_clean.str.fullmatch(r"\s*")
                    out = mdl_series.copy()
                    out[use_lbl] = lbl_series[use_lbl]
                    return out

                merged["Source Case Name"] = _prefer_labeled("Source Case Name (Labeled)", "Source Case Name (Model)")
                merged["Target Case Name"] = _prefer_labeled("Target Case Name (Labeled)", "Target Case Name (Model)")

                # Map per-model label/rationale columns
                merged["Mistral Label"] = merged.get("Mistral Citation Evaluation", "")
                merged["LLama Label"]   = merged.get("LLama Citation Evaluation", "")
                merged["Claude Label"]  = merged.get("Claude Citation Evaluation", "")

                # Final ordered columns. "Rationale" (labeled dataset) and
                # "Rationale (Model)" (pipeline) already carry their final names,
                # so no rename is needed.

                final_cols = [
                    "Source Case ID",
                    "Source Case Name",
                    "Source Case Decision Date",
                    "Source Case citation_pipe",
                    "Source Case Summary (Model)",
                    "Target Case ID",
                    "Target Case Name",
                    "Target Case Decision Date",
                    "Target Case citation_pipe",
                    "Target Case Summary (Model)",
                    "Chunk",                        # from labeled dataset
                    "Chunk (Pulled by Pipeline)",   # from model/pipeline
                    "Label",                        # from labeled dataset
                    "Label (Model)",                # final model label
                    "Mistral Label",
                    "LLama Label",
                    "Claude Label",
                    "Rationale",                    # from labeled dataset
                    "Mistral Rationale",
                    "LLama Rationale",
                    "Claude Rationale",
                    "Rationale (Model)",            # final model rationale
                    COL_URL_HEADER,
                ]
                for c in final_cols:
                    if c not in merged.columns:
                        merged[c] = ""

                # Every column in final_cols exists at this point (missing ones were added above)
                merged = merged[final_cols]

                merged.to_csv(labeled_output_csv, index=False)
                print(f"Wrote labeled+model comparison CSV → {labeled_output_csv} (rows: {len(merged)})")

    # -------- Global label totals (after this run) --------
    with driver.session(database=NEO4J_DATABASE) as s_counts:
        row_counts = s_counts.run(Q_COUNT_LABELS, {}).single()
        total_pos = row_counts["pos_total"] or 0
        total_neu = row_counts["neu_total"] or 0
        total_neg = row_counts["neg_total"] or 0
        total_unk = row_counts["unk_total"] or 0

    dt_min = (time.time() - t0) / 60
    print(f"\nProcessed relations: {processed} | Elapsed: {dt_min:.1f} min")
    print("=== Diagnostics (global final label) ===")
    print(f"  Successful (classified, non-Unknown): {ok_rel}")
    print(f"  Missing snippets:                    {missing_snippets}")
    print("=== Label counts (this run, global) ===")
    print(f"  Positive: {pos_cnt}")
    print(f"  Neutral : {neu_cnt}")
    print(f"  Negative: {neg_cnt}")
    print(f"  Unknown : {unk_cnt}")
    print("=== Dataset label totals (after run, global) ===")
    print(f"  Positive: {total_pos}")
    print(f"  Neutral : {total_neu}")
    print(f"  Negative: {total_neg}")
    print(f"  Unknown : {total_unk}")
    print("=== Per-model error counts (this run) ===")
    for model, counts in model_error_counts.items():
        print(f"  [{model}] too_long={counts['too_long']}, json_parse_failed={counts['json_parse_failed']}, "
              f"bad_keys_or_values={counts['bad_keys_or_values']}, api_error={counts['api_error']}, other={counts['other']}")

# =========================
# Example call (adjust paths as needed)
# =========================
# label_all_citations(
#     force=False,
#     echo=True,
#     results_csv=True,
#     results_csv_filename="edge_classifications_ensemble.csv",
#     append_to_labeled_dataset_csv="Phase One Final Labeled Dataset from WK.csv",
#     labeled_output_csv=None,
#     show_all_labels_in_output_csv=True,
#     failed_csv=True,
# )
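
The explanation strings built in the voting step imply a specific ensemble rule. `_compute_final_label` is defined elsewhere in the notebook and is not reproduced in this section, so the function below (`sketch_final_label`, a hypothetical name) is only a minimal reconstruction from those strings, assuming they describe every case:

```python
from collections import Counter


def sketch_final_label(label_m: str, label_l: str, label_c: str) -> str:
    """Hypothetical stand-in for _compute_final_label, inferred from the explanation text."""
    counts = Counter([label_m, label_l, label_c])
    top_label, top_count = counts.most_common(1)[0]
    if top_count >= 2:
        # Two or three judges agree (this also covers two or more 'Unknown' votes).
        return top_label
    if "Unknown" in counts:
        # One judge returned 'Unknown' and the other two disagreed.
        return "Unknown"
    # Positive, Neutral and Negative each received one vote: fall back to 'Neutral'.
    return "Neutral"


print(sketch_final_label("Positive", "Positive", "Unknown"))   # Positive
print(sketch_final_label("Positive", "Negative", "Unknown"))   # Unknown
print(sketch_final_label("Positive", "Neutral", "Negative"))   # Neutral
```

If the real helper differs (for example, by weighting one judge more heavily), the explanation strings in the loop would need to change with it.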


## Example Run

In [None]:
# Assumes relations already have snippet_1, snippet_2, ... properties.

label_all_citations(
    force=False,
    echo=True
)


Batch after rel_id -1: 200 relation(s)
Niece v. Fitzner → Robert T. McGregor v. Louisiana State University Board of Supervisors: Global=Unknown (missing snippets)
United States v. Agrawal → Coors Brewing Co. v. MENDEZ-TORRES: Global=Unknown (missing snippets)

Batch after rel_id 1152924803141730530: 200 relation(s)

Batch after rel_id 1152924803141730820: 200 relation(s)
Palao v. Fel-Pro., Inc. → Paula S. Skorup v. Modern Door Corporation: Global=Unknown (missing snippets)
Barber Lines A/s v. M/v Donau Maru → Dees v. Austin Travis County Mental Health & Mental Retardation: Global=Unknown (missing snippets)
Conboy v. State → Jensen v. State, Department of Labor & Industry: Global=Unknown (missing snippets)

Batch after rel_id 1152924803141731151: 200 relation(s)
      · [Mistral] attempt 1/3
      · [Llama] attempt 1/3
      · [Llama] API ClientError on attempt 1/3: ValidationException - This model's maximum context length is 8192 tokens. Please reduce the length of the prompt
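
The captured output ends at a `ValidationException` from Llama: the prompt exceeded the model's 8192-token context. One possible mitigation, which is not part of the notebook as shown, is to cap the snippet list before the labeled block is built. The sketch below keeps leading snippets until a rough character budget is reached. The name `trim_snippets_to_budget`, the 4-characters-per-token ratio, and the reserved-token default are all assumptions for illustration, not measured values:

```python
def trim_snippets_to_budget(snippets, max_tokens=8192, reserved_tokens=2000, chars_per_token=4):
    """Keep leading snippets whose combined length fits a rough character budget.

    All defaults are illustrative assumptions: reserved_tokens stands for the
    instructions, case summaries and the model's reply; chars_per_token is a
    crude heuristic, not a real tokenizer.
    """
    budget_chars = max(0, (max_tokens - reserved_tokens) * chars_per_token)
    kept, used = [], 0
    for text in snippets:
        cost = len(text) + 2  # +2 for the blank-line separator between snippets
        if used + cost > budget_chars:
            break
        kept.append(text)
        used += cost
    return kept


# Tiny budget (24 characters) so the effect is visible: the third snippet is dropped.
demo = trim_snippets_to_budget(["a" * 10, "b" * 10, "c" * 10],
                               max_tokens=10, reserved_tokens=4, chars_per_token=4)
print(len(demo))  # 2
```

A character heuristic can still overshoot on token-dense text, so the existing `too_long` handling would remain the backstop.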

## Compare with labeled dataset

In [None]:
# =========================
# Example call
# =========================
# label_all_citations(
#     force=True,
#     echo=True,
#     append_to_labeled_dataset_csv="Phase One Final Labeled Dataset from WK.csv",
#     labeled_output_csv="WK Labeled vs Snippet Method Model Comparison - Ensemble.csv",
#     show_all_labels_in_output_csv=True,
#     failed_csv=True,
# )
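
Before running the comparison, it can help to confirm that the labeled CSV has the columns the join expects (`Source ID`, `Source Name`, `Target ID`, `Target Name`, `Chunk`, `Label`, `Rationale`), because `label_all_citations` skips the append when any of them is missing. A minimal pre-flight sketch using only the standard library follows; `missing_labeled_columns` is a hypothetical helper, not part of the notebook:

```python
import csv
import io

# Column names taken from the `expected` list inside label_all_citations.
EXPECTED_LABELED_COLUMNS = [
    "Source ID", "Source Name", "Target ID", "Target Name",
    "Chunk", "Label", "Rationale",
]


def missing_labeled_columns(csv_text: str) -> list:
    """Return the expected columns that are absent from the CSV header row."""
    reader = csv.reader(io.StringIO(csv_text))
    header = next(reader, [])  # empty input yields an empty header
    return [col for col in EXPECTED_LABELED_COLUMNS if col not in header]


sample = (
    "Source ID,Source Name,Target ID,Target Name,Chunk,Label\n"
    "1,A v. B,2,C v. D,some text,Positive\n"
)
print(missing_labeled_columns(sample))  # ['Rationale']
```

The check is an exact string match, mirroring the notebook's own test, so headers with stray whitespace or different casing are reported as missing.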