# Edge Classifier Snippet Method – Claude 3.5 Sonnet

**Purpose**

* Classify each `(:Case)-[:CITES_TO]->(:Case)` edge as **Positive / Neutral / Negative / Unknown** using:
  * **Pre-extracted snippets** on the relation (`snippet_1 … snippet_N`)
  * Short **case summaries** for citing and cited cases
* Write results back to the relation and (optionally) produce a CSV for analysis and QA.

---

## What this notebook uses

* **Snippets from the relation**  
  * `snippet_1, snippet_2, …` already created by the Snippet Retriever.
* **Case metadata** (for both source = citing, target = cited):
  * `name`, `decision_date`, `citation_pipe`, optional `opinion_summary`
* **Short citation**  
  * First element from target’s `citation_pipe` (used to anchor the model).
* **Model**  
  * AWS Bedrock **Claude 3.5 Sonnet** (`anthropic.claude-3-5-sonnet-20240620-v1:0`).

---

## How it works (high level)

1. **Page edges**  
   * Use `Q_PAGE_REL` to page through `(:Case)-[r:CITES_TO]->(:Case)` relations in batches by `id(r)`.

2. **Gather snippets**  
   * Read `properties(r)` and collect keys matching `snippet_\d+`.
   * Sort by numeric suffix and build a labeled block:

     ```text
     snippet 1:

     [text]

     snippet 2:

     [text]
     ```

3. **Build the prompt**

   Claude sees:

   * Citing case **name** and **summary**
   * Cited case **name** and **summary**
   * **Cited short citation** (first from `citation_pipe`)
   * **Snippets block**

4. **Model output (single judge)**

   Claude is instructed to return strict JSON:

   ```json
   {
     "classification": "Positive|Neutral|Negative",
     "rationale": "Four sentences with one short direct quote."
   }
   ```

   * If the call succeeds and the JSON can be parsed → classification and rationale are stored.
   * If the call fails (too long / JSON error / empty response, etc.) → this edge is labeled **`"Unknown"`** with an error message recorded in the rationale.

5. **Write back to the edge**

   For each relation:

   * `r.treatment_label`          (Claude’s label or `"Unknown"`)
   * `r.treatment_rationale`      (four-sentence rationale or failure message)
   * `r.treatment_snippet`        (all snippets joined)
   * `r.model_used`               (e.g., `"Claude 3.5 Sonnet (…)"`)
   * `r.updated_at_utc`           (timestamp from `datetime()`)

6. **Already-labeled edges**

   * If `force=False` and `r.treatment_label` exists and is not `"Unknown"`, the edge is **skipped**.
   * Set `force=True` if you want to overwrite existing labels.

7. **No snippets**

   * If the relation has **no `snippet_1..N`**, it is labeled **`"Unknown"`** and the rationale explains that there were no snippets to analyze.
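
Steps 2, 6 and 7 can be sketched as below. This is a minimal, self-contained illustration, not the notebook's own code (which lives in `_extract_numbered_snippets` and `_format_snippet_block_labeled`). `gather_snippets`, `labeled_block`, and `route_edge` are hypothetical names, and checking the skip rule before the snippet check is an assumption about ordering.

```python
import re
from typing import Any, Dict, List, Optional

# Relation property keys such as "snippet_1" or "snippet_12".
SNIPPET_KEY = re.compile(r"snippet_(\d+)")


def gather_snippets(rel_props: Dict[str, Any]) -> List[str]:
    """Return non-empty snippet_N values sorted by their numeric suffix N."""
    found = []
    for key, value in (rel_props or {}).items():
        match = SNIPPET_KEY.fullmatch(key)
        text = str(value or "").strip()
        if match and text:
            found.append((int(match.group(1)), text))
    # Sort by the integer suffix; a string sort would put snippet_10 before snippet_2.
    found.sort(key=lambda pair: pair[0])
    return [text for _, text in found]


def labeled_block(snippets: List[str]) -> str:
    """Render snippets as numbered 'snippet i:' blocks separated by blank lines."""
    return "\n\n".join(f"snippet {i}:\n\n{s}" for i, s in enumerate(snippets, 1))


def route_edge(rel_props: Dict[str, Any], existing_label: Optional[str], force: bool) -> str:
    """Return 'skip', 'unknown_no_snippets', or 'classify' for one edge."""
    # Step 6: keep an existing non-Unknown label unless force=True.
    if not force and existing_label and existing_label != "Unknown":
        return "skip"
    # Step 7: nothing to analyze, so the edge will be labeled Unknown.
    if not gather_snippets(rel_props):
        return "unknown_no_snippets"
    return "classify"


props = {"snippet_2": "Second.", "snippet_10": "Tenth.", "snippet_1": " First. "}
print(labeled_block(gather_snippets(props)))
```

Numbering in the block is positional, so `snippet_10` is shown as `snippet 3` when only three snippets exist, which matches how `_format_snippet_block_labeled` enumerates its input.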

---

## Token budgeting & overflow handling

* Estimates total tokens with `tiktoken`'s `cl100k_base` encoding (an OpenAI tokenizer, so counts for Claude are approximate) for:

  * System prompt + user prompt (summaries + snippets).
* If the prompt exceeds the context budget:

  * Compute the overhead for system + an empty user prompt.
  * Use the remaining budget only for **snippets**.
  * Trim the snippet block down to the allowed token count (summaries are preserved).
* If the prompt is **still too long** after trimming:

  * The call is flagged as **`"too_long"`**.
  * The edge is labeled **`"Unknown"`** with a rationale explaining the overflow.
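
The budget arithmetic can be worked through with a small sketch. It mirrors the constants in `_max_ctx_tokens_for_bedrock` (200,000-token context, 2,000-token buffer, output reserve of `max_new_tokens + 128`) and the 512-token snippet floor, but swaps in a whitespace word count for `cl100k_base` so it runs without `tiktoken`. All function names are hypothetical. The preserved part of the prompt (system prompt, names, summaries) is counted before the snippets are sized, which is what "summaries are preserved" requires.

```python
from typing import Tuple

MAX_CONTEXT = 200_000     # Claude 3.5 Sonnet context window (tokens)
OVERHEAD_BUFFER = 2_000   # safety margin, as in _max_ctx_tokens_for_bedrock
OUTPUT_PAD = 128          # extra room added on top of max_new_tokens
SNIPPET_FLOOR = 512       # minimum snippet allowance, as in the notebook


def input_budget(max_new_tokens: int) -> int:
    """Tokens left for system + user prompt after reserving room for the reply."""
    return MAX_CONTEXT - OVERHEAD_BUFFER - (max_new_tokens + OUTPUT_PAD)


def count_tokens(text: str) -> int:
    """Stand-in token counter (whitespace words); the notebook uses cl100k_base."""
    return len(text.split())


def trim_tokens(text: str, max_tokens: int) -> str:
    """Keep only the first max_tokens stand-in tokens."""
    return " ".join(text.split()[:max_tokens])


def fit_snippets(fixed_prompt: str, snippets: str, budget: int) -> Tuple[str, str]:
    """Trim only the snippets so that fixed_prompt + snippets fits in budget.

    fixed_prompt is everything that is preserved (system prompt, names,
    summaries). Returns (snippet_text, status) with status "ok" or "too_long".
    """
    fixed = count_tokens(fixed_prompt)
    if fixed + count_tokens(snippets) <= budget:
        return snippets, "ok"
    allowance = max(SNIPPET_FLOOR, budget - fixed)
    trimmed = trim_tokens(snippets, allowance)
    if fixed + count_tokens(trimmed) > budget:
        # The preserved text leaves less room than the minimum snippet allowance.
        return trimmed, "too_long"
    return trimmed, "ok"
```

With the default `max_new_tokens=1024`, the input budget is 200,000 - 2,000 - 1,152 = 196,848 tokens.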

---

## Robust JSON parsing & retries

Each call to Claude:

* Makes up to **3 attempts** (`retries=3`), with increasing strictness in the system prompt:

  1. Normal instructions.
  2. “Output JSON only. No prose, no backticks.”
  3. “Output ONLY a JSON object with `classification` and `rationale`. Double quotes. No trailing commas. No markdown.”

* JSON cleaning and parsing:

  * Strip Markdown code fences (`` ```json `` / `` ``` ``), collapse newlines into spaces, and trim extra whitespace.
  * Normalize smart quotes.
  * Remove trailing commas before `}` or `]`.
  * If needed, apply a heuristic to fix **unescaped double quotes** inside the `rationale` string.
  * Try parsing the entire output as JSON.
  * If that fails, search for the **largest `{...}` block** and parse that.
  * Support both:

    * Direct JSON object, and
    * JSON nested or stringified inside the `rationale` field.

* Status codes:

  * `"ok"`
  * `"json_parse_failed"`
  * `"bad_keys_or_values"`
  * `"too_long"`
  * `"empty_response"`
  * `"api_error"`, `"api_error_throttled"`, `"api_response_error"`

If the final status is not `"ok"`, the edge is labeled `"Unknown"` and the status is described in the rationale.
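
A simplified sketch of the cleaning, parsing, and `"Unknown"` fallback described above follows. It is not the notebook's parser: it omits the regex label extraction and the unescaped-quote heuristic, approximates the "largest `{...}` block" as the span from the first `{` to the last `}`, and uses a hypothetical failure-message format (the notebook maps statuses through `_describe_error`).

```python
import json
import re
from typing import Optional, Tuple

FENCE = re.compile(r"^```(?:json)?|```$", re.MULTILINE)
TRAILING_COMMA = re.compile(r",\s*([}\]])")
VALID_LABELS = {"Positive", "Neutral", "Negative"}


def clean(raw: str) -> str:
    """Strip fences, collapse newlines, normalize smart quotes, drop trailing commas."""
    s = FENCE.sub("", raw)
    s = " ".join(s.splitlines()).strip()
    s = s.replace("\u201c", '"').replace("\u201d", '"').replace("\u2019", "'")
    return TRAILING_COMMA.sub(r"\1", s)


def parse_reply(raw: str) -> Tuple[Optional[str], Optional[str], str]:
    """Return (classification, rationale, status) from raw model text."""
    if not raw or not raw.strip():
        return None, None, "empty_response"
    s = clean(raw)
    candidates = [s]
    start, end = s.find("{"), s.rfind("}")
    if start != -1 and end > start:
        candidates.append(s[start:end + 1])  # outermost {...} span
    parsed_any = False
    for candidate in candidates:
        try:
            js = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        parsed_any = True
        if isinstance(js, dict):
            label = js.get("classification") or js.get("label")
            rationale = js.get("rationale")
            if (isinstance(label, str) and label in VALID_LABELS
                    and isinstance(rationale, str) and rationale.strip()):
                return label, rationale.strip(), "ok"
    return None, None, ("bad_keys_or_values" if parsed_any else "json_parse_failed")


def edge_annotation(raw: str) -> Tuple[str, str]:
    """Map a reply to (treatment_label, treatment_rationale); failures become Unknown."""
    label, rationale, status = parse_reply(raw)
    if status == "ok":
        return label, rationale
    return "Unknown", f"Classification failed ({status})."
```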

---

## Neo4j I/O

**Reads**

* `Q_PAGE_REL` returns:

  * Source case: `id`, `name`, `decision_date`, `citation_pipe`, `opinion_summary`, URL
  * Target case: `id`, `name`, `decision_date`, `citation_pipe`, `opinion_summary`
  * `properties(r)` (snippets and any previous fields)
  * Existing `r.treatment_label` (for `force` behavior)

* Snippets are pulled from `properties(r)` as all keys matching `snippet_\d+` and sorted numerically.

**Writes**

* `Q_WRITE_REL_ANNOT` sets on the relation:

  * `treatment_label`
  * `treatment_rationale`
  * `treatment_snippet`
  * `model_used`
  * `updated_at_utc`

> Already-labeled edges are **skipped** unless `force=True`.
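
Paging by `id(r)` is keyset pagination: each page starts strictly after the last relation id already seen, so no `SKIP` offset is needed. The sketch below shows the loop with a pluggable `fetch_page(after_id, limit)` callable so it runs without a database; the function names and the starting value of `-1` are assumptions.

```python
from typing import Any, Callable, Dict, Iterator, List

Row = Dict[str, Any]


def iter_relation_pages(
    fetch_page: Callable[[int, int], List[Row]], batch_size: int = 200
) -> Iterator[List[Row]]:
    """Yield pages of rows using keyset pagination on rel_id.

    fetch_page(after_id, limit) must return at most `limit` rows with
    rel_id > after_id, ordered by rel_id ascending (as Q_PAGE_REL does).
    """
    after_id = -1  # assumption: internal relationship ids are non-negative
    while True:
        page = fetch_page(after_id, batch_size)
        if not page:
            return
        yield page
        # Resume strictly after the last id seen; setting properties on r does not change id(r).
        after_id = page[-1]["rel_id"]
```

With the Neo4j Python driver, `fetch_page` could run `session.run(Q_PAGE_REL, after_id=after_id, limit=limit)` and return `[record.data() for record in result]`.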

---

## CSV outputs (optional)

With `results_csv=True`, the notebook writes (by default):

* `edge_classifications_claude.csv`

Each row includes:

* Source / Target:

  * IDs, names, decision dates, `citation_pipe`
  * Summaries (source and target)
* Snippet and label:

  * Joined `Opinion Snippet`
  * **Citation Evaluation** (Claude label or `"Unknown"`)
  * **LLM Rationale** (four-sentence explanation or error message)
* URL:

  * Source case URL (CourtListener if available)
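
One way to lay out such a row is sketched below with the standard `csv` module. The `Opinion Snippet`, `Citation Evaluation`, and `LLM Rationale` headers and the URL header text come from this notebook (`COL_URL_HEADER`); the source/target column names and `write_results` are hypothetical, and the notebook itself may build the file with pandas.

```python
import csv
import io
from typing import Any, Dict, List, TextIO

# URL header text taken from COL_URL_HEADER; the Source/Target names are hypothetical.
URL_HEADER = "Source Case CourtListener URL (may be incorrect for some cases)"
COLUMNS = [
    "Source ID", "Source Name", "Source Date", "Source Citation", "Source Summary",
    "Target ID", "Target Name", "Target Date", "Target Citation", "Target Summary",
    "Opinion Snippet", "Citation Evaluation", "LLM Rationale", URL_HEADER,
]


def write_results(rows: List[Dict[str, Any]], fh: TextIO) -> None:
    """Write a header plus one CSV row per edge; missing fields are left blank."""
    writer = csv.DictWriter(fh, fieldnames=COLUMNS, restval="")
    writer.writeheader()
    writer.writerows(rows)


buf = io.StringIO()
write_results([{"Source Name": "A v. B", "Citation Evaluation": "Unknown"}], buf)
print(buf.getvalue())
```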

### Optional merges

* **Append to labeled dataset** (`append_to_labeled_dataset_csv`)
  Produces a **model comparison** CSV joining your human-labeled dataset with Claude’s labels and rationales.
* **Compare with previous model CSV** (`compare_with_previous_csv`)
  Optional side-by-side diff of an older model CSV vs. the current Claude outputs.

---

## Key parameters

```python
label_all_citations(
    results_csv: bool = False,
    results_csv_filename: str = "edge_classifications_claude.csv",
    batch_size: int = 200,
    echo: bool = True,
    force: bool = False,
    append_to_labeled_dataset_csv: Optional[str] = None,
    labeled_output_csv: Optional[str] = None,
    compare_with_previous_csv: Optional[str] = None,
    comparison_output_csv: Optional[str] = None,
)
```

* `results_csv` / `results_csv_filename`: control whether and where to write the output CSV.
* `batch_size`: number of edges processed per Neo4j page.
* `echo`: if `True`, prints `"<Source> → <Target>: <Label>"` plus batch-level summaries.
* `force`: if `True`, re-labels edges even if `treatment_label` already exists and is not `"Unknown"`.
* `append_to_labeled_dataset_csv`, `compare_with_previous_csv`: optional CSV comparison / benchmarking utilities.

---

## Environment

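* Loads `../.env`. Required:

  * `NEO4J_URI`, `NEO4J_USERNAME`, `NEO4J_PASSWORD`
  * AWS credentials that `boto3` can resolve (for example, environment variables or `~/.aws/credentials`), with access to the Bedrock model
* Optional:

  * `NEO4J_DATABASE` (default `"neo4j"`)
  * `BEDROCK_REGION` (default `"us-east-1"`)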

---

## Quick start

```python
label_all_citations(
    force=False,
    echo=True,
    results_csv=True,
    results_csv_filename="edge_classifications_claude.csv",
)
```

This will:

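* Page through all `CITES_TO` edges, skipping those that already have a label other than `"Unknown"` (unless `force=True`).
* Call **Claude 3.5 Sonnet** for each remaining edge that has snippets (up to 3 attempts per edge by default).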
* Write `treatment_label` / `treatment_rationale` / `treatment_snippet` and metadata back to Neo4j.
* Export `edge_classifications_claude.csv` for review.

---

## Troubleshooting

* **“Unknown” due to missing snippets**
  Run the Snippet Retriever so edges have `snippet_1..N`.

* **“Unknown (too_long)”**
  Snippets are auto-trimmed, but if you still hit token limits:

  * Reduce the number or length of snippets per edge, or
  * Lower `max_new_tokens` (a `classify_with_bedrock` parameter, default `1024`) to give more room to the input.

* **Frequent JSON parse / key errors**
  Retries and JSON cleaning are already enabled. If errors continue:

  * Set `echo=True` and inspect raw Claude outputs.
  * Consider simplifying summaries or prompts if the model is adding extra prose or nested structures.



In [1]:
# Install (if applicable)
! pip install neo4j boto3 pandas python-dotenv



In [2]:
! pip install tiktoken

Collecting tiktoken
  Using cached tiktoken-0.12.0-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (6.7 kB)
Using cached tiktoken-0.12.0-cp312-cp312-manylinux_2_28_x86_64.whl (1.2 MB)
Installing collected packages: tiktoken
Successfully installed tiktoken-0.12.0


In [3]:
import os, re, time, json, pathlib, logging, datetime as dt, random
from typing import List, Dict, Any, Tuple, Optional

import boto3
import pandas as pd
from neo4j import GraphDatabase
from dotenv import load_dotenv
from botocore.exceptions import ClientError, BotoCoreError
import tiktoken

# Quiet noisy logs (incl. Neo4j notifications/deprecations)
for _n in ("neo4j", "neo4j.notifications", "neo4j.work.simple"):
    logging.getLogger(_n).setLevel(logging.ERROR)
os.environ.setdefault("NEO4J_DRIVER_LOG_LEVEL", "ERROR")

# =========================
# Config / ENV
# =========================
# .env is always one level up from this notebook. Load it before reading any
# settings so that BEDROCK_REGION can also be set in the file.
load_dotenv("../.env", override=True)

BEDROCK_REGION   = os.getenv("BEDROCK_REGION", "us-east-1")
BEDROCK_MODEL_ID = "anthropic.claude-3-5-sonnet-20240620-v1:0"

NEO4J_URI       = os.getenv("NEO4J_URI")
NEO4J_USERNAME  = os.getenv("NEO4J_USERNAME")
NEO4J_PASSWORD  = os.getenv("NEO4J_PASSWORD")
NEO4J_DATABASE  = os.getenv("NEO4J_DATABASE", "neo4j")

if not (NEO4J_URI and NEO4J_USERNAME and NEO4J_PASSWORD):
    raise RuntimeError("Missing Neo4j connection settings. Check ../.env for NEO4J_URI/USERNAME/PASSWORD.")

# =========================
# Approximate tokenizer for prompt sizing. cl100k_base is an OpenAI encoding,
# not Claude's own tokenizer, so all counts below are estimates.
# =========================
_tok = tiktoken.get_encoding("cl100k_base")

def _count_tokens(text: str) -> int:
    """Approximate Claude token count."""
    if not text:
        return 0
    return len(_tok.encode(text))

def _trim_to_tokens(text: str, max_tokens: int) -> str:
    """Trim text to max_tokens, counted with the cl100k_base approximation."""
    token_ids = _tok.encode(text)
    if len(token_ids) <= max_tokens:
        return text
    return _tok.decode(token_ids[:max_tokens])

def _max_ctx_tokens_for_bedrock(max_new_tokens: int) -> int:
    """
    Claude 3.5 Sonnet supports ~200k tokens of context.
    Reserve budget for response + overhead.
    """
    max_context = 200_000
    buffer_for_overhead = 2_000
    reserve_for_output  = max_new_tokens + 128
    return max_context - buffer_for_overhead - reserve_for_output

# =========================
# Bedrock client (lazy)
# =========================
_bedrock_client = None

def _bedrock():
    global _bedrock_client
    if _bedrock_client is None:
        _bedrock_client = boto3.client("bedrock-runtime", region_name=BEDROCK_REGION)
    return _bedrock_client

## Edge Classification Prompt (Snippet Method)

In [4]:
# =========================
# Prompts
# =========================
SYSTEM_PROMPT = """Role: You are an experienced lawyer specializing in legal citation analysis. 
Goal: Your goal is to classify how a citing case treats a cited case (e.g., Positive, Neutral, Negative) when given:
1.        Citing Case Name: the name of the citing case.
2.        Citing Case Summary: the summary of the citing case.
3.        Cited Case Name: the name of the cited case
4.        Cited Case Summary: the summary of the cited case.
5.        Snippets (from citing opinion where the cited case appears): List of text snippets where the cited case is cited in the citing opinion text 

CRITICAL RULES (must follow):
1. **If the citing case uses the cited case to establish or support ANY legal rule, doctrine, standard, test, or conclusion, classify as POSITIVE.**
   - This includes situations where the case:
     • recites a standard from the case,
     • cites the case as part of a string cite supporting a legal principle,
     • uses the case as an example consistent with its reasoning,
     • applies reasoning from the cited case.

   - IMPORTANT: The opinion **does NOT need to use words like “follow,” “adopt,” “agree,” or “apply.”**  
     Any supportive or explanatory use counts as Positive.

2. Classify as NEUTRAL **only when the case is mentioned without being used as authority**.
   - Examples:
     • descriptive background
     • illustrating a factual distinction
     • quoting language without using it to support a rule
     • noting procedural history

3. Classify as NEGATIVE when the citing court:
   - criticizes, limits, distinguishes, rejects, or declines to follow the cited case.

4. When writing the rationale:
   - Include ONE short direct quote from the citing opinion.
   - The quote must illustrate why the treatment is Positive, Neutral, or Negative.
   - Provide EXACTLY four sentences.

OUTPUT FORMAT (strict):
{
  "classification": "Positive|Neutral|Negative",
  "rationale": "Four sentences with one direct quote."
}
"""
USER_PROMPT_TMPL = """Classify how the citing case treats the cited case. Return your answer in JSON format: {{"classification": "", "rationale": ""}}, where classification is one of Positive, Neutral, or Negative and rationale is a four-sentence explanation that justifies your classification
using evidence from the snippets and case summaries.

Input:
Citing Case Name: {citing_case_name}
Citing Case Summary: {citing_case_summary}
Cited Case Name: {cited_case_name}
Cited Case Citation: {cited_case_citation}
Cited Case Summary: {cited_case_summary}

Snippets (from citing opinion where the cited case appears):
{snippet_block}
"""

In [5]:
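def _inst(sys: str, usr: str) -> str:
    # Llama/Mistral-style [INST] wrapper. classify_with_bedrock does not use it,
    # because the Claude Messages API takes the system prompt separately.
    return f"<s>[INST]{sys}\n{usr}[/INST]"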

def _format_snippet_block_labeled(snippets: List[str]) -> str:
    """
    Render as:
    snippet 1:

    [text]

    snippet 2:

    [text]
    """
    blocks = []
    for i, s in enumerate(snippets, 1):
        ss = (s or "").strip()
        if not ss:
            continue
        blocks.append(f"snippet {i}:\n\n{ss}")
    return "\n\n".join(blocks) if blocks else "snippet 1:\n\n[N/A]"

## Cypher Queries

In [6]:
# =========================
# Cypher queries
# =========================
COL_URL_HEADER = "Source Case CourtListener URL (may be incorrect for some cases)"

# Paginated fetch of relations for classification
Q_PAGE_REL = """
MATCH (s:Case)-[r:CITES_TO]->(t:Case)
WHERE id(r) > $after_id
RETURN id(r) AS rel_id,
       s.id AS src_id, coalesce(s.name,'') AS src_name, s.decision_date AS src_date, s.citation_pipe AS src_cite,
       coalesce(s.court_listener_url, s.url, '') AS src_url,
       coalesce(s.opinion_summary,'') AS src_summary,
       t.id AS tgt_id, coalesce(t.name,'') AS tgt_name, t.decision_date AS tgt_date, t.citation_pipe AS tgt_cite,
       coalesce(t.opinion_summary,'') AS tgt_summary,
       properties(r) AS rel_props,
       r.treatment_label AS existing_label
ORDER BY rel_id
LIMIT $limit
"""

Q_WRITE_REL_ANNOT = """
MATCH (s:Case {id:$src_id})-[r:CITES_TO]->(t:Case {id:$tgt_id})
SET r.treatment_label     = $label,
    r.treatment_rationale = $rationale,
    r.treatment_snippet   = $snippet_joined,
    r.model_used          = $model_used,
    r.updated_at_utc      = datetime()
RETURN count(r) AS updated
"""

Q_COUNT_LABELS = """
MATCH ()-[r:CITES_TO]->()
RETURN
  sum(CASE WHEN r.treatment_label = 'Positive' THEN 1 ELSE 0 END) AS pos_total,
  sum(CASE WHEN r.treatment_label = 'Neutral'  THEN 1 ELSE 0 END) AS neu_total,
  sum(CASE WHEN r.treatment_label = 'Negative' THEN 1 ELSE 0 END) AS neg_total,
  sum(CASE WHEN r.treatment_label = 'Unknown'  THEN 1 ELSE 0 END) AS unk_total
"""

# Full dataset of currently labeled edges (for show_all_labels_in_output_csv=True)
Q_ALL_LABELED_REL = """
MATCH (s:Case)-[r:CITES_TO]->(t:Case)
WHERE r.treatment_label IS NOT NULL
RETURN
  s.id AS src_id,
  coalesce(s.name,'') AS src_name,
  s.decision_date AS src_date,
  s.citation_pipe AS src_cite,
  coalesce(s.court_listener_url, s.url, '') AS src_url,
  coalesce(s.opinion_summary,'') AS src_summary,
  t.id AS tgt_id,
  coalesce(t.name,'') AS tgt_name,
  t.decision_date AS tgt_date,
  t.citation_pipe AS tgt_cite,
  coalesce(t.opinion_summary,'') AS tgt_summary,
  coalesce(r.treatment_snippet,'') AS snippet_joined,
  r.treatment_label AS treatment_label,
  coalesce(r.treatment_rationale,'') AS treatment_rationale
"""

## Helpers

In [7]:
# =========================
# Helpers
# =========================
def _first_citation(citation_pipe: Optional[str]) -> str:
    if not citation_pipe:
        return ""
    parts = [p.strip() for p in re.split(r'[|;]+', citation_pipe) if p.strip()]
    return parts[0] if parts else ""

def _extract_numbered_snippets(rel_props: Dict[str, Any]) -> List[str]:
    """
    Pull snippet_1, snippet_2, ... from properties(r) and return as a list
    sorted by the numeric suffix. Empty/whitespace-only are skipped.
    """
    out: List[Tuple[int, str]] = []
    for k, v in (rel_props or {}).items():
        m = re.fullmatch(r"snippet_(\d+)", k)
        if not m:
            continue
        try:
            idx = int(m.group(1))
        except ValueError:
            continue
        text = str(v or "").strip()
        if text:
            out.append((idx, text))
    out.sort(key=lambda p: p[0])
    return [t for _, t in out]

def _read_csv_fallback(path: str) -> pd.DataFrame:
    """
    Read CSV trying a few common encodings in order.
    Returns a DataFrame or raises the last exception.
    """
    encodings = ["utf-8", "utf-8-sig", "cp1252", "latin-1"]
    last_err = None
    for enc in encodings:
        try:
            return pd.read_csv(path, encoding=enc)
        except Exception as e:
            last_err = e
    # As a last resort, decode with cp1252 and replace invalid bytes to avoid hard fail
    try:
        with open(path, "r", encoding="cp1252", errors="replace") as fh:
            return pd.read_csv(fh)
    except Exception:
        pass
    raise last_err if last_err else RuntimeError("Failed to read CSV with fallback encodings")

In [8]:
# =========================
# JSON cleaning / normalization
# =========================
_BACKTICKS_RE = re.compile(r"^```(?:json)?|```$", re.MULTILINE)

def _clean_json(s: str) -> str:
    """
    Normalize LLM-style JSON:
      - strip ```json fences
      - collapse newlines into spaces (multi-line strings become single-line)
      - normalize smart quotes
      - unwrap a single outer quote pair if it surrounds a JSON object/array
      - drop dangling commas before } or ]
    """
    if not isinstance(s, str):
        s = str(s)

    # Strip markdown fences
    s = _BACKTICKS_RE.sub("", s)

    # Normalize newlines and carriage returns to spaces so raw line breaks
    # inside strings do not break JSON parsing.
    s = s.replace("\r\n", " ").replace("\r", " ").replace("\n", " ")

    s = s.strip()

    # Normalize smart quotes / apostrophes
    s = s.replace("“", '"').replace("”", '"').replace("’", "'")

    # If the entire payload is quoted as a big JSON string, unwrap it once
    if len(s) >= 2 and s[0] == s[-1] == '"':
        inner = s[1:-1].strip()
        if ("{" in inner and "}" in inner) or ("[" in inner and "]" in inner):
            s = inner

    # Remove trailing commas before object/array close
    s = re.sub(r",\s*([}\]])", r"\1", s)
    return s


def _fix_unescaped_quotes_in_rationale(payload: str) -> str:
    """
    Heuristic: inside the rationale string, replace any *unescaped* " with '.
    This recovers JSON where the model used double quotes for a direct quote
    but did not escape them.
    """
    try:
        key = '"rationale"'
        i = payload.find(key)
        if i == -1:
            return payload

        colon = payload.find(":", i)
        if colon == -1:
            return payload

        # First " after the colon starts the rationale string
        first_q = payload.find('"', colon + 1)
        # Assume rationale is the *last* string value: last " in the payload
        last_q = payload.rfind('"')

        if first_q == -1 or last_q == -1 or last_q <= first_q:
            return payload

        body = payload[first_q + 1:last_q]

        # Replace any unescaped " inside the rationale body with '
        out_chars = []
        for idx, ch in enumerate(body):
            if ch == '"' and (idx == 0 or body[idx - 1] != "\\"):
                out_chars.append("'")
            else:
                out_chars.append(ch)
        fixed_body = "".join(out_chars)

        return payload[:first_q + 1] + fixed_body + payload[last_q:]
    except Exception:
        # On any error, return original unchanged
        return payload


def _normalize_from_js(js: Any) -> Tuple[Optional[str], Optional[str]]:
    """
    Given a parsed JSON value (dict or string), extract:
      classification ∈ {"Positive","Neutral","Negative"}
      rationale: non-empty string

    Handles nested or stringified JSON inside the 'rationale' field.
    """
    valid_labels = {"Positive", "Neutral", "Negative"}

    # Case 1: dict at top level
    if isinstance(js, dict):
        c = js.get("classification") or js.get("label")
        r = js.get("rationale")

        # If rationale itself is another dict, try inner fields
        if isinstance(r, dict):
            inner_c = r.get("classification") or r.get("label")
            inner_r = r.get("rationale")
            if inner_c in valid_labels and isinstance(inner_r, str) and inner_r.strip():
                return inner_c, inner_r.strip()

        # If rationale is a string that looks like JSON, try to parse it
        if isinstance(r, str):
            r_str = r.strip()
            if len(r_str) >= 2 and r_str[0] == r_str[-1] == '"':
                r_str = r_str[1:-1].strip()
            if r_str.startswith("{") and "classification" in r_str and "rationale" in r_str:
                try:
                    inner_js = json.loads(_clean_json(r_str))
                    inner_c, inner_r = _normalize_from_js(inner_js)
                    if inner_c and inner_r:
                        return inner_c, inner_r
                except Exception:
                    pass

        # Normal case: classification + plain rationale
        if c in valid_labels and isinstance(r, str) and r.strip():
            return c, r.strip()

        return None, None

    # Case 2: top-level is a string that itself contains JSON
    if isinstance(js, str):
        txt = js.strip()
        if len(txt) >= 2 and txt[0] == txt[-1] == '"':
            txt = txt[1:-1].strip()
        if txt.startswith("{") and "classification" in txt and "rationale" in txt:
            try:
                inner_js = json.loads(_clean_json(txt))
                return _normalize_from_js(inner_js)
            except Exception:
                return None, None

    return None, None
    
def _extract_rationale_between_markers(s: str) -> Optional[str]:
    """
    Heuristic rationale extractor that ignores JSON validity.

    Strategy:
      - Find `"rationale"` key
      - From the colon, find the first `"` that starts the string
      - Then take everything up to the LAST `"` before the closing `}`

    This allows inner unescaped " characters inside the rationale, since
    we are not trying to parse the JSON string literally.
    """
    s_lower = s.lower()
    key_idx = s_lower.find('"rationale"')
    if key_idx == -1:
        return None

    colon_idx = s.find(":", key_idx)
    if colon_idx == -1:
        return None

    # First quote after the colon = start of rationale string
    first_quote = s.find('"', colon_idx)
    if first_quote == -1:
        return None

    # Try to limit search to this JSON object: last " before the last }
    last_brace = s.rfind("}")
    if last_brace == -1:
        search_end = len(s)
    else:
        search_end = last_brace

    last_quote = s.rfind('"', first_quote + 1, search_end)
    if last_quote == -1:
        # Fallback: last quote in the whole string
        last_quote = s.rfind('"', first_quote + 1)
        if last_quote == -1:
            return None

    rationale = s[first_quote + 1:last_quote]

    # Simple cleanup of common escape sequences
    rationale = rationale.replace('\\"', '"').replace("\\n", " ").replace("\\r", " ")
    rationale = rationale.strip()
    return rationale or None


def _extract_classification_and_rationale(raw_text: str) -> Tuple[Optional[str], Optional[str], str]:
    """
    Extract classification and rationale from Claude output.

    Strategy:
      - Classification is read via regex from "classification" / "label".
      - Rationale is extracted heuristically as everything between
        `"rationale": "` and the closing `}` of the JSON object, so it
        is not cut off by inner unescaped double quotes.

    Fallback:
      - If this fails, we still try a simple json.loads() on the cleaned
        text and read classification / rationale from there.
    """
    if not raw_text or not str(raw_text).strip():
        return None, None, "empty_response"

    # Normalize markdown fences, newlines, smart quotes, etc.
    s = _clean_json(str(raw_text).strip())
    valid_labels = {"Positive", "Neutral", "Negative"}

    # ---- 1) Classification via regex (safe; label string has no inner quotes) ----
    label_match = re.search(
        r'"(?:classification|label)"\s*:\s*"(?P<label>Positive|Neutral|Negative)"',
        s,
        flags=re.IGNORECASE,
    )
    label: Optional[str] = None
    if label_match:
        label = label_match.group("label").capitalize()
        if label not in valid_labels:
            label = None

    # ---- 2) Rationale via heuristic between "rationale" and closing brace ----
    rationale = _extract_rationale_between_markers(s)

    if label and rationale:
        return label, rationale, "ok"

    # ---- 3) JSON fallback (for valid JSON the heuristics above missed) ----
    try:
        js = json.loads(s)
    except Exception:
        return None, None, "json_parse_failed"
    # _normalize_from_js accepts dicts and stringified JSON, and unwraps JSON
    # nested inside "rationale"; a bare js.get() would raise on lists/strings.
    c, r = _normalize_from_js(js)
    if c and r:
        return c, r, "ok"
    return None, None, "bad_keys_or_values"


def _describe_error(status: str) -> str:
    """
    Map internal status codes to a human-readable error description.
    """
    mapping = {
        "too_long": "too_long: prompt exceeded maximum context tokens",
        "json_parse_failed": "json_parse_failed: could not parse LLM output as JSON",
        "bad_keys_or_values": "bad_keys_or_values: missing or invalid classification/rationale in JSON",
        "api_error_throttled": "api_error_throttled: Bedrock throttled the request",
        "api_error": "api_error: generic AWS/Bedrock client or service error",
        "api_response_error": "api_response_error: malformed or unexpected API response",
        "empty_response": "empty_response: model returned empty text",
        "missing_snippets": "missing_snippets: no snippets available on the relation (snippet_1..N).",
        "ok": "ok: successful classification",
    }
    return mapping.get(status, status)


## LLM Classify

In [9]:
# =========================
# LLM classify (Claude 3.5 Sonnet on Bedrock)
# =========================
def classify_with_bedrock(
    *,
    citing_case_name: str,
    citing_case_summary: str,
    cited_case_name: str,
    cited_case_citation: str,
    cited_case_summary: str,
    snippet_block_labeled: str,
    max_new_tokens: int = 1024,
    retries: int = 3,
    echo: bool = False
) -> Tuple[
    Optional[str], Optional[str], str, str, int,
    List[Dict[str, Any]], str
]:
    """
    Returns:
      classification: Optional[str]
      rationale: Optional[str]
      status: final status code ("ok", "json_parse_failed", etc.)
      raw_output: final raw output (string)
      attempt_used: which attempt index produced the final result (1-based)
      attempt_logs: list of per-attempt dicts:
          {
            "attempt": int,
            "status": str,
            "raw": str,
            "input": str  # user prompt string (no system wrapper)
          }
      user_input: the user prompt string sent to the LLM

    status: "ok" | "json_parse_failed" | "bad_keys_or_values" | "too_long"
            | "api_error" | "api_error_throttled" | "api_response_error" | "empty_response"
    """
    client = _bedrock()

    user = USER_PROMPT_TMPL.format(
        citing_case_name=(citing_case_name or "").strip(),
        citing_case_summary=(citing_case_summary or "").strip(),
        cited_case_name=(cited_case_name or "").strip(),
        cited_case_citation=(cited_case_citation or "").strip(),
        cited_case_summary=(cited_case_summary or "").strip(),
        snippet_block=snippet_block_labeled
    )

    attempt_logs: List[Dict[str, Any]] = []

    # -------- Token budget (system + user combined) --------
    max_in = _max_ctx_tokens_for_bedrock(max_new_tokens)

    # First pass usage
    used = _count_tokens(SYSTEM_PROMPT + "\n\n" + user)

    if used > max_in:
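        # Compute overhead for system + user prompt with the real names and
        # summaries but an empty snippet block. Summaries are preserved, so
        # they must count against the budget before the snippets are sized.
        user_without_snippets = USER_PROMPT_TMPL.format(
            citing_case_name=(citing_case_name or "").strip(),
            citing_case_summary=(citing_case_summary or "").strip(),
            cited_case_name=(cited_case_name or "").strip(),
            cited_case_citation=(cited_case_citation or "").strip(),
            cited_case_summary=(cited_case_summary or "").strip(),
            snippet_block=""
        )
        overhead_tokens = _count_tokens(SYSTEM_PROMPT + "\n\n" + user_without_snippets)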

        # Allowance for snippets inside the user prompt
        allowance_for_snippets = max(512, max_in - overhead_tokens)

        snippet_tokens = _count_tokens(snippet_block_labeled)
        target_snippet_tokens = max(128, min(snippet_tokens, allowance_for_snippets))
        trimmed = _trim_to_tokens(snippet_block_labeled, target_snippet_tokens)

        user = USER_PROMPT_TMPL.format(
            citing_case_name=(citing_case_name or "").strip(),
            citing_case_summary=(citing_case_summary or "").strip(),
            cited_case_name=(cited_case_name or "").strip(),
            cited_case_citation=(cited_case_citation or "").strip(),
            cited_case_summary=(cited_case_summary or "").strip(),
            snippet_block=trimmed
        )

        used = _count_tokens(SYSTEM_PROMPT + "\n\n" + user)
        if used > max_in:
            if echo:
                print(f"      · prompt still too long after trim ({used} > {max_in}); giving up")
            attempt_logs.append({
                "attempt": 1,
                "status": "too_long",
                "raw": "",
                "input": user,
            })
            return None, None, "too_long", "", 1, attempt_logs, user

    last_err = "json_parse_failed"
    last_raw = ""

    # -------- Call Bedrock with retries --------
    for attempt in range(1, retries + 1):
        # Progressive strictness in system prompt
        if attempt == 1:
            sys = SYSTEM_PROMPT
        elif attempt == 2:
            sys = SYSTEM_PROMPT + "\n\nSTRICT: Output JSON only. No prose, no backticks."
        else:
            sys = SYSTEM_PROMPT + (
                '\n\nSTRICT: Output ONLY a JSON object with keys '
                '"classification" and "rationale". Use double quotes. '
                "No trailing commas. No markdown."
            )

        if echo:
            print(f"      · attempt {attempt}/{retries}")

        # Small base delay to avoid hammering the API
        time.sleep(0.75)

        try:
            resp = client.invoke_model(
                modelId=BEDROCK_MODEL_ID,
                body=json.dumps({
                    "anthropic_version": "bedrock-2023-05-31",
                    "max_tokens": max_new_tokens,
                    "temperature": 0.0,
                    "system": sys,
                    "messages": [
                        {"role": "user", "content": user}
                    ]
                }),
                contentType="application/json",
                accept="application/json",
            )

            raw_body = resp["body"].read()
            try:
                body = json.loads(raw_body)
            except Exception as e:
                if echo:
                    print(f"      · failed to decode response JSON: {e}")
                last_err = "api_response_error"
                last_raw = (
                    raw_body.decode("utf-8", errors="replace")
                    if isinstance(raw_body, (bytes, bytearray))
                    else str(raw_body)
                )

                attempt_logs.append({
                    "attempt": attempt,
                    "status": last_err,
                    "raw": last_raw,
                    "input": user,
                })
                time.sleep(1.5 * attempt)
                continue

            # Claude response format
            try:
                text = body["content"][0]["text"]
            except (KeyError, IndexError, TypeError) as e:
                if echo:
                    print(f"      · unexpected response structure: {e}")
                last_err = "api_response_error"
                last_raw = json.dumps(body)

                attempt_logs.append({
                    "attempt": attempt,
                    "status": last_err,
                    "raw": last_raw,
                    "input": user,
                })
                time.sleep(1.5 * attempt)
                continue

            last_raw = (text or "").strip()
            if not last_raw:
                if echo:
                    print("      · empty response from model")
                status = "empty_response"
                last_err = status
                attempt_logs.append({
                    "attempt": attempt,
                    "status": status,
                    "raw": last_raw,
                    "input": user,
                })
                time.sleep(1.5 * attempt)
                continue

            # ---- Centralized parsing / normalization ----
            c, r, status = _extract_classification_and_rationale(last_raw)
            attempt_logs.append({
                "attempt": attempt,
                "status": status,
                "raw": last_raw,
                "input": user,
            })

            if status == "ok" and c and r:
                return c, r, "ok", last_raw, attempt, attempt_logs, user

            # Logical / JSON parsing failure
            last_err = status
            time.sleep(1.5 * attempt)
            continue

        except ClientError as e:
            # AWS / Bedrock explicit error with code + message
            code = e.response.get("Error", {}).get("Code", "")
            msg = e.response.get("Error", {}).get("Message", "")

            if echo:
                print(f"      · API ClientError on attempt {attempt}/{retries}: {code} - {msg}")

            msg_lower = msg.lower() if isinstance(msg, str) else ""
            if "token" in msg_lower and ("length" in msg_lower or "limit" in msg_lower):
                last_err = "too_long"
            elif code in ("ThrottlingException", "TooManyRequestsException") or "rate exceeded" in msg_lower:
                last_err = "api_error_throttled"
            else:
                last_err = "api_error"

            last_raw = msg
            attempt_logs.append({
                "attempt": attempt,
                "status": last_err,
                "raw": last_raw,
                "input": user,
            })

            # More conservative backoff on throttling: exponential + small jitter
            if last_err == "api_error_throttled":
                backoff = min(60.0, float(2 ** attempt))
                jitter = random.uniform(0.0, 0.5)
                if echo:
                    print(f"      · throttled, backing off for {backoff + jitter:.2f} seconds")
                time.sleep(backoff + jitter)
            else:
                time.sleep(1.5 * attempt)
            continue

        except BotoCoreError as e:
            if echo:
                print(f"      · BotoCoreError on attempt {attempt}/{retries}: {e}")
            last_err = "api_error"
            last_raw = str(e)
            attempt_logs.append({
                "attempt": attempt,
                "status": last_err,
                "raw": last_raw,
                "input": user,
            })
            time.sleep(1.5 * attempt)
            continue

        except Exception as e:
            if echo:
                print(f"      · API error on attempt {attempt}/{retries}: {e}")
            last_err = "api_error"
            last_raw = str(e)
            attempt_logs.append({
                "attempt": attempt,
                "status": last_err,
                "raw": last_raw,
                "input": user,
            })
            time.sleep(1.5 * attempt)
            continue

    # After all retries
    return None, None, last_err, last_raw, attempt, attempt_logs, user
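
The error handling in `classify_with_bedrock` comes down to two decisions: which status string a `ClientError` maps to, and how long to wait before the next attempt. The sketch below pulls those decisions out as two pure functions so they can be checked without calling Bedrock. The function names are hypothetical. The random generator is passed in, rather than using the module-level `random.uniform` as the notebook does, only so the jitter is reproducible. The sketch leaves out the fixed 0.75 s pause that the notebook adds before every attempt.

```python
import random


def classify_client_error(code: str, message: str) -> str:
    """Map a Bedrock ClientError code/message onto the notebook's status strings.

    Mirrors the branching in classify_with_bedrock: token-length messages become
    "too_long", throttling codes or "rate exceeded" become "api_error_throttled",
    and everything else is a generic "api_error".
    """
    msg = message.lower() if isinstance(message, str) else ""
    if "token" in msg and ("length" in msg or "limit" in msg):
        return "too_long"
    if code in ("ThrottlingException", "TooManyRequestsException") or "rate exceeded" in msg:
        return "api_error_throttled"
    return "api_error"


def backoff_seconds(attempt: int, status: str, rng: random.Random) -> float:
    """Seconds to sleep after a failed attempt (attempt is 1-based).

    Throttling: exponential 2**attempt capped at 60 s, plus 0 to 0.5 s of jitter.
    Any other failure: linear 1.5 * attempt.
    """
    if status == "api_error_throttled":
        return min(60.0, float(2 ** attempt)) + rng.uniform(0.0, 0.5)
    return 1.5 * attempt
```

With the default `retries=3`, a throttled edge waits roughly 2, 4 and 8 seconds plus jitter, which matches the "backing off for 2.32 seconds" line in the example run below. Note that the notebook also sleeps after the final attempt, which delays the Unknown write without giving the model another chance.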

## Full-dataset fetch for CSV Outputs

In [10]:
# =========================
# Full-dataset fetch for CSV outputs
# =========================
def _fetch_all_labeled_results(driver) -> pd.DataFrame:
    """
    Load all currently labeled CITES_TO edges from Neo4j
    and shape them like rows_out / df_results.
    """
    with driver.session(database=NEO4J_DATABASE) as s_all:
        data = s_all.run(Q_ALL_LABELED_REL, {}).data()

    rows_all: List[Dict[str, Any]] = []
    for row in data:
        rows_all.append({
            "Source Case ID": row["src_id"],
            "Source Case Name": row["src_name"] or "",
            "Source Case Decision Date": row.get("src_date") or "",
            "Source Case citation_pipe": row.get("src_cite") or "",
            "Source Case Summary": row.get("src_summary") or "",
            "Target Case ID": row["tgt_id"],
            "Target Case Name": row["tgt_name"] or "",
            "Target Case Decision Date": row.get("tgt_date") or "",
            "Target Case citation_pipe": row.get("tgt_cite") or "",
            "Target Case Summary": row.get("tgt_summary") or "",
            "Opinion Snippet": row.get("snippet_joined") or "",
            "Citation Evaluation": row.get("treatment_label") or "",
            "LLM Rationale": row.get("treatment_rationale") or "",
            COL_URL_HEADER: row.get("src_url") or "",
        })
    return pd.DataFrame(rows_all)
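
`_fetch_all_labeled_results` builds each row by hand and normalizes `None` to `""` field by field. Below is a sketch of a vectorized equivalent. It assumes `Q_ALL_LABELED_REL` returns the same keys that the loop reads. `KEY_TO_HEADER`, `ID_HEADERS` and `shape_labeled_records` are hypothetical names, and `URL_HEADER` is a placeholder for the notebook's `COL_URL_HEADER`, whose value is not shown in this section.

```python
import pandas as pd

URL_HEADER = "Source URL"  # placeholder for COL_URL_HEADER (real value not shown here)

# Record keys read by _fetch_all_labeled_results -> results-CSV headers, in output order.
KEY_TO_HEADER = {
    "src_id": "Source Case ID",
    "src_name": "Source Case Name",
    "src_date": "Source Case Decision Date",
    "src_cite": "Source Case citation_pipe",
    "src_summary": "Source Case Summary",
    "tgt_id": "Target Case ID",
    "tgt_name": "Target Case Name",
    "tgt_date": "Target Case Decision Date",
    "tgt_cite": "Target Case citation_pipe",
    "tgt_summary": "Target Case Summary",
    "snippet_joined": "Opinion Snippet",
    "treatment_label": "Citation Evaluation",
    "treatment_rationale": "LLM Rationale",
    "src_url": URL_HEADER,
}
ID_HEADERS = ["Source Case ID", "Target Case ID"]


def shape_labeled_records(records: list) -> pd.DataFrame:
    """Map Neo4j result dicts to the results-CSV schema in a fixed column order."""
    # reindex adds any key missing from the records as an all-NaN column
    df = pd.DataFrame(records).reindex(columns=list(KEY_TO_HEADER))
    df = df.rename(columns=KEY_TO_HEADER)
    text_cols = [c for c in df.columns if c not in ID_HEADERS]
    # Missing text fields become "", like `row.get(...) or ""` in the loop; IDs are left as-is.
    return df.fillna({c: "" for c in text_cols})
```

Keeping the mapping in one dict makes it easy to compare the 14-column schema with the `rows_out` rows that the batch driver builds. It also tolerates keys that are missing from the records: text columns become empty strings and ID columns become NaN.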


## Batch Driver

In [11]:
# =========================
# Batch driver
# =========================
def label_all_citations(
    *,
    results_csv: bool = False,
    results_csv_filename: str = "edge_classifications_claude.csv",
    batch_size: int = 200,
    echo: bool = True,
    force: bool = False,
    append_to_labeled_dataset_csv: Optional[str] = None,
    labeled_output_csv: Optional[str] = None,   # default = "<labeled_csv_stem> - model comparison.csv"
    compare_with_previous_csv: Optional[str] = None,
    comparison_output_csv: Optional[str] = None, # default = "<prev_csv_stem> - model comparison.csv"
    show_all_labels_in_output_csv: bool = False,
    failed_csv: bool = False,
):
    """
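    Classifies CITES_TO edges using the snippets stored on each relation as snippet_1, snippet_2, ...
    With echo=True, prints one "<Source Name> → <Target Name>: <Label>" line per edge, plus batch
    and retry progress; the final diagnostics summary is always printed.

    Writes r.treatment_label / r.treatment_rationale / r.treatment_snippet (snippets joined with
    blank lines) and the Bedrock model ID to the relation via Q_WRITE_REL_ANNOT.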

    CSV behavior:
      - If show_all_labels_in_output_csv=False (default):
          df_results contains ONLY edges touched in this run.
      - If show_all_labels_in_output_csv=True:
          df_results is replaced with ALL currently labeled edges from Neo4j,
          regardless of whether they were modified in this run.

    Failed CSV:
      - If failed_csv=True:
          writes "failed_citations.csv" with:
            first 11 columns (same as results_csv),
            plus "Attempt", "Input to LLM", "LLM Output", "Type of Error"
          for all edges that ended with label "Unknown" in this run:
          one row per logged LLM attempt (a prompt rejected as too long before any
          call is logged as attempt 1), or a single row with Attempt=0 when the
          edge had no snippets.
    """
    driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD))

    # Best-effort: disable notifications in the session (Neo4j 5+ driver)
    try:
        session = driver.session(database=NEO4J_DATABASE, notifications_min_severity="OFF")
    except TypeError:
        session = driver.session(database=NEO4J_DATABASE)

    rows_out: List[Dict[str, Any]] = []
    failed_rows: List[Dict[str, Any]] = []
    after = -1
    processed = 0
    t0 = time.time()

    # diagnostics counters
    ok_rel = 0
    fail_too_long = fail_json = fail_bad = fail_api = fail_other = 0
    missing_snippets = 0

    # classification tallies (this run)
    pos_cnt = neu_cnt = neg_cnt = unk_cnt = 0

    # First 11 columns to reuse in failed_citations.csv
    first_11_cols = [
        "Source Case ID",
        "Source Case Name",
        "Source Case Decision Date",
        "Source Case citation_pipe",
        "Source Case Summary",
        "Target Case ID",
        "Target Case Name",
        "Target Case Decision Date",
        "Target Case citation_pipe",
        "Target Case Summary",
        "Opinion Snippet",
    ]

    with session as s:
        while True:
            rels = s.run(Q_PAGE_REL, {"after_id": after, "limit": batch_size}).data()
            if not rels:
                if echo:
                    print("All done. No more relations.")
                break

            if echo:
                print(f"\nBatch after rel_id {after}: {len(rels)} relation(s)")

            for row in rels:
                after = row["rel_id"]

                # --- force=False: skip ONLY edges that already have a non-Unknown label ---
                existing_label = (row.get("existing_label") or "").strip()
                if (not force) and existing_label and existing_label.lower() != "unknown":
                    continue

                src_name = row["src_name"] or ""
                tgt_name = row["tgt_name"] or ""
                src_sum  = row.get("src_summary") or ""
                tgt_sum  = row.get("tgt_summary") or ""
                tgt_cit  = _first_citation(row.get("tgt_cite") or "")

                # Gather snippet_1, snippet_2, ... from relation properties
                snippets = _extract_numbered_snippets(row.get("rel_props") or {})
                if not snippets:
                    # No snippets available — mark Unknown and move on
                    reason_text = "No snippets available on the relation (snippet_1..N)."
                    s.run(Q_WRITE_REL_ANNOT, {
                        "src_id": row["src_id"],
                        "tgt_id": row["tgt_id"],
                        "label": "Unknown",
                        "rationale": reason_text,
                        "snippet_joined": "",
                        "model_used": BEDROCK_MODEL_ID
                    })
                    missing_snippets += 1
                    unk_cnt += 1
                    if echo:
                        print(f"{src_name} → {tgt_name}: Unknown")

                    common_row = {
                        "Source Case ID": row["src_id"],
                        "Source Case Name": src_name,
                        "Source Case Decision Date": row.get("src_date") or "",
                        "Source Case citation_pipe": row.get("src_cite") or "",
                        "Source Case Summary": src_sum,
                        "Target Case ID": row["tgt_id"],
                        "Target Case Name": tgt_name,
                        "Target Case Decision Date": row.get("tgt_date") or "",
                        "Target Case citation_pipe": row.get("tgt_cite") or "",
                        "Target Case Summary": tgt_sum,
                        "Opinion Snippet": "",
                    }

                    # record a row anyway for results CSV parity
                    rows_out.append({
                        **common_row,
                        "Citation Evaluation": "Unknown",
                        "LLM Rationale": reason_text,
                        COL_URL_HEADER: row.get("src_url") or "",
                    })

                    # record failed row (Attempt=0 since LLM was not called)
                    failed_rows.append({
                        **common_row,
                        "Attempt": 0,
                        "Input to LLM": "",
                        "LLM Output": "",
                        "Type of Error": _describe_error("missing_snippets"),
                    })

                    processed += 1
                    continue

                # Build labeled block for the LLM and joined snippet for storage/CSV
                snippet_block = _format_snippet_block_labeled(snippets)
                joined_for_store = "\n\n".join(snippets)

                lab, rat, status, raw, attempt_used, attempt_logs, user_input = classify_with_bedrock(
                    citing_case_name=src_name, citing_case_summary=src_sum,
                    cited_case_name=tgt_name, cited_case_citation=tgt_cit, cited_case_summary=tgt_sum,
                    snippet_block_labeled=snippet_block,
                    max_new_tokens=340,
                    retries=3,
                    echo=echo
                )

                common_row = {
                    "Source Case ID": row["src_id"],
                    "Source Case Name": src_name,
                    "Source Case Decision Date": row.get("src_date") or "",
                    "Source Case citation_pipe": row.get("src_cite") or "",
                    "Source Case Summary": src_sum,
                    "Target Case ID": row["tgt_id"],
                    "Target Case Name": tgt_name,
                    "Target Case Decision Date": row.get("tgt_date") or "",
                    "Target Case citation_pipe": row.get("tgt_cite") or "",
                    "Target Case Summary": tgt_sum,
                    "Opinion Snippet": joined_for_store,
                }

                if lab:
                    # Write clean classification + rationale to Neo4j
                    s.run(Q_WRITE_REL_ANNOT, {
                        "src_id": row["src_id"],
                        "tgt_id": row["tgt_id"],
                        "label": lab,
                        "rationale": rat,
                        "snippet_joined": joined_for_store,
                        "model_used": BEDROCK_MODEL_ID
                    })
                    ok_rel += 1
                    if lab == "Positive":
                        pos_cnt += 1
                    elif lab == "Neutral":
                        neu_cnt += 1
                    elif lab == "Negative":
                        neg_cnt += 1

                    if echo:
                        print(f"{src_name} → {tgt_name}: {lab}")

                    # Results CSV row (schema like the old pipeline)
                    rows_out.append({
                        **common_row,
                        "Citation Evaluation": lab,
                        "LLM Rationale": rat,
                        COL_URL_HEADER: row.get("src_url") or "",
                    })
                else:
                    # Unknown with reason
                    if status == "too_long":
                        fail_too_long += 1
                    elif status == "json_parse_failed":
                        fail_json += 1
                    elif status == "bad_keys_or_values":
                        fail_bad += 1
                    elif status in ("api_error", "api_response_error", "api_error_throttled"):
                        fail_api += 1
                    else:
                        fail_other += 1

                    unk_cnt += 1

                    rationale_text = f"Classification failed: {status}"
                    s.run(Q_WRITE_REL_ANNOT, {
                        "src_id": row["src_id"],
                        "tgt_id": row["tgt_id"],
                        "label": "Unknown",
                        "rationale": rationale_text,
                        "snippet_joined": joined_for_store,
                        "model_used": BEDROCK_MODEL_ID
                    })
                    if echo:
                        print(f"{src_name} → {tgt_name}: Unknown")

                    rows_out.append({
                        **common_row,
                        "Citation Evaluation": "Unknown",
                        "LLM Rationale": rationale_text,
                        COL_URL_HEADER: row.get("src_url") or "",
                    })

                    # Record one failed row per attempt
                    for log in attempt_logs:
                        failed_rows.append({
                            **common_row,
                            "Attempt": log.get("attempt", 0),
                            "Input to LLM": log.get("input", ""),
                            "LLM Output": log.get("raw", ""),
                            "Type of Error": _describe_error(log.get("status", "")),
                        })

                processed += 1

    # Build df_results from this run
    df_results = pd.DataFrame(rows_out)

    # Optionally replace with ALL currently labeled edges from Neo4j
    if show_all_labels_in_output_csv:
        df_results = _fetch_all_labeled_results(driver)
        if echo:
            print(f"\nshow_all_labels_in_output_csv=True → df_results replaced with full labeled dataset (rows: {len(df_results)})")
    else:
        if echo:
            print(f"\nshow_all_labels_in_output_csv=False → df_results only has rows from this run (rows: {len(df_results)})")

    # -------- Optional CSV of results --------
    if results_csv:
        df_results.to_csv(results_csv_filename, index=False)
        print(f"\nWrote {len(df_results)} rows → {results_csv_filename}")
    else:
        print(f"\nSkipping results CSV write (results_csv=False). Rows buffered in df_results: {len(df_results)}")

    # -------- Optional CSV of failed classifications --------
    if failed_csv:
        df_failed = pd.DataFrame(failed_rows)
        if not df_failed.empty:
            # Ensure column order: first 11 columns, then Attempt, Input to LLM, LLM Output, Type of Error
            ordered_cols = first_11_cols + [
                c for c in ["Attempt", "Input to LLM", "LLM Output", "Type of Error"]
                if c in df_failed.columns
            ]
            # Add any missing columns as empty
            for c in ordered_cols:
                if c not in df_failed.columns:
                    df_failed[c] = ""
            df_failed = df_failed[ordered_cols]
            df_failed.to_csv("failed_citations.csv", index=False)
            print(f"Wrote {len(df_failed)} failed rows → failed_citations.csv")
        else:
            print("No failed classifications in this run; failed_citations.csv not written.")
    else:
        if echo:
            print(f"Skipping failed CSV write (failed_csv=False). Failed rows in memory: {len(failed_rows)}")

    # -------- Optional labeled dataset join (no duplicate case-name columns) --------
    if append_to_labeled_dataset_csv:
        if labeled_output_csv is None:
            stem = pathlib.Path(append_to_labeled_dataset_csv).stem
            labeled_output_csv = f"{stem} - model comparison.csv"

        try:
            df_label = _read_csv_fallback(append_to_labeled_dataset_csv)
            print(f"Loaded labeled dataset with fallback reader: {append_to_labeled_dataset_csv}")
        except Exception as e:
            df_label = None
            print(f"Could not read labeled dataset CSV: {e}")

        if df_label is not None:
            # Labeled dataset must have these columns:
            # ["Source ID","Source Name","Target ID","Target Name","Chunk","Label","Rationale"]
            expected = ["Source ID","Source Name","Target ID","Target Name","Chunk","Label","Rationale"]
            missing = [c for c in expected if c not in df_label.columns]
            if missing:
                print(f"Labeled CSV missing required columns: {missing}. Skipping append.")
            else:
                # 1) Standardize labeled columns (avoid later header collisions)
                df_label_std = df_label.rename(columns={
                    "Source ID":   "Source Case ID",
                    "Target ID":   "Target Case ID",
                    "Source Name": "Source Case Name (Labeled)",
                    "Target Name": "Target Case Name (Labeled)",
                    "Rationale":   "Rational"  # keep user's original header for labeled rationale
                })

                # 2) Prepare join frame from current df_results with MODEL-suffixed columns
                df_join = df_results.rename(columns={
                    "Source Case Name":        "Source Case Name (Model)",
                    "Target Case Name":        "Target Case Name (Model)",
                    "Source Case Summary":     "Source Case Summary (Model)",
                    "Target Case Summary":     "Target Case Summary (Model)",
                    "Opinion Snippet":         "Chunk (Pulled by Pipeline)",
                    "Citation Evaluation":     "Label (Model)",
                    "LLM Rationale":           "Rational (Model)",
                })

                needed_from_results = [
                    "Source Case ID", "Target Case ID",
                    "Source Case Name (Model)", "Target Case Name (Model)",
                    "Source Case Decision Date", "Source Case citation_pipe",
                    "Source Case Summary (Model)",
                    "Target Case Decision Date", "Target Case citation_pipe",
                    "Target Case Summary (Model)",
                    "Chunk (Pulled by Pipeline)",
                    "Label (Model)", "Rational (Model)",
                    COL_URL_HEADER,
                ]
                for c in needed_from_results:
                    if c not in df_join.columns:
                        df_join[c] = ""

                # 3) Merge labeled ↔ model outputs on IDs
                merged = df_label_std.merge(
                    df_join[needed_from_results],
                    on=["Source Case ID", "Target Case ID"],
                    how="inner"
                )

                # 4) Build unified case-name columns (prefer labeled, else model)
                def _prefer_labeled(lbl, mdl):
                    lbl_series = merged.get(lbl)
                    mdl_series = merged.get(mdl)
                    if lbl_series is None and mdl_series is None:
                        return pd.Series([], dtype="object")
                    if lbl_series is None:
                        return mdl_series
                    if mdl_series is None:
                        return lbl_series
                    # treat empty/whitespace as missing
                    lbl_clean = lbl_series.fillna("").astype(str)
                    use_lbl = ~lbl_clean.str.fullmatch(r"\s*")
                    out = mdl_series.copy()
                    out[use_lbl] = lbl_series[use_lbl]
                    return out

                merged["Source Case Name"] = _prefer_labeled("Source Case Name (Labeled)", "Source Case Name (Model)")
                merged["Target Case Name"] = _prefer_labeled("Target Case Name (Labeled)", "Target Case Name (Model)")

                # 5) Assemble final ordered columns (single 'Source/Target Case Name')
                final_cols = [
                    "Source Case ID",
                    "Source Case Name",
                    "Source Case Decision Date",
                    "Source Case citation_pipe",
                    "Source Case Summary (Model)",
                    "Target Case ID",
                    "Target Case Name",
                    "Target Case Decision Date",
                    "Target Case citation_pipe",
                    "Target Case Summary (Model)",
                    "Chunk",                        # from labeled dataset
                    "Chunk (Pulled by Pipeline)",   # from model/pipeline
                    "Label",                        # from labeled dataset
                    "Label (Model)",                # from model
                    "Rational",                     # from labeled dataset
                    "Rational (Model)",             # from model
                    COL_URL_HEADER,
                ]
                for c in final_cols:
                    if c not in merged.columns:
                        merged[c] = ""

                # 6) Keep only the final columns (this also drops the temporary
                #    "(Labeled)" / "(Model)" case-name columns) and write CSV
                merged = merged[final_cols]

                merged.to_csv(labeled_output_csv, index=False)
                print(f"Wrote labeled+model comparison CSV → {labeled_output_csv} (rows: {len(merged)})")

    # -------- Optional previous vs. recent comparison (schema like before) --------
    if compare_with_previous_csv:
        if comparison_output_csv is None:
            stem = pathlib.Path(compare_with_previous_csv).stem
            comparison_output_csv = f"{stem} - model comparison.csv"

        try:
            df_prev = pd.read_csv(compare_with_previous_csv)
        except Exception as e:
            df_prev = None
            print(f"Could not read previous model CSV: {e}")

        if df_prev is not None:
            prev_cols_map = {
                "Source Case Summary": "Source Case Summary (Previous Model)",
                "Target Case Summary": "Target Case Summary (Previous Model)",
                "Opinion Snippet": "Opinion Snippet (Previous Model)",
                "Citation Evaluation": "Citation Evaluation (Previous Model)",
                "LLM Rationale": "LLM Rationale (Previous Model)",
            }
            cur_cols_map = {
                "Source Case Summary": "Source Case Summary (Recent Model)",
                "Target Case Summary": "Target Case Summary (Recent Model)",
                "Opinion Snippet": "Opinion Snippet (Recent Model)",
                "Citation Evaluation": "Citation Evaluation (Recent Model)",
                "LLM Rationale": "LLM Rationale (Recent Model)",
            }

            prev_ren = df_prev.rename(columns=prev_cols_map)
            cur_ren  = df_results.rename(columns=cur_cols_map)

            join_keys = ["Source Case ID", "Target Case ID"]
            for k in join_keys:
                if k not in prev_ren.columns and k in df_prev.columns:
                    prev_ren[k] = df_prev[k]
                if k not in cur_ren.columns and k in df_results.columns:
                    cur_ren[k] = df_results[k]

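            # Columns present in both frames (case names, dates, citations, URL) would
            # otherwise come out as "_x"/"_y" and never reach the ordered output below.
            # Keep one copy, preferring this run's value and falling back to the previous CSV.
            comp = prev_ren.merge(cur_ren, on=join_keys, how="outer", suffixes=("", "__cur"))
            for cur_col in [c for c in comp.columns if c.endswith("__cur")]:
                base_col = cur_col[: -len("__cur")]
                comp[base_col] = comp[cur_col].combine_first(comp[base_col])
                comp = comp.drop(columns=[cur_col])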

            ordered = [
                "Source Case ID",
                "Source Case Name",
                "Source Case Decision Date",
                "Source Case citation_pipe",
                "Source Case Summary (Previous Model)",
                "Source Case Summary (Recent Model)",
                "Target Case ID",
                "Target Case Name",
                "Target Case Decision Date",
                "Target Case citation_pipe",
                "Target Case Summary (Previous Model)",
                "Target Case Summary (Recent Model)",
                "Opinion Snippet (Previous Model)",
                "Opinion Snippet (Recent Model)",
                "Citation Evaluation (Previous Model)",
                "Citation Evaluation (Recent Model)",
                "LLM Rationale (Previous Model)",
                "LLM Rationale (Recent Model)",
                COL_URL_HEADER,
            ]
            for c in ordered:
                if c not in comp.columns:
                    comp[c] = ""
            comp = comp[ordered]
            comp.to_csv(comparison_output_csv, index=False)
            print(f"Wrote previous vs recent model comparison CSV → {comparison_output_csv} (rows: {len(comp)})")

    # -------- Global label totals (after this run) --------
    with driver.session(database=NEO4J_DATABASE) as s_counts:
        row_counts = s_counts.run(Q_COUNT_LABELS, {}).single()
        total_pos = row_counts["pos_total"] or 0
        total_neu = row_counts["neu_total"] or 0
        total_neg = row_counts["neg_total"] or 0
        total_unk = row_counts["unk_total"] or 0

    # Summary
    dt_min = (time.time() - t0)/60
    print(f"\nProcessed relations: {processed} | Elapsed: {dt_min:.1f} min")
    print("=== Diagnostics ===")
    print(f"  Successful (classified): {ok_rel}")
    print(f"  Missing snippets:        {missing_snippets}")
    print(f"  Fail (too long):         {fail_too_long}")
    print(f"  Fail (JSON parse):       {fail_json}")
    print(f"  Fail (bad keys/values):  {fail_bad}")
    print(f"  Fail (API error):        {fail_api}")
    print(f"  Fail (other):            {fail_other}")
    print("=== Label counts (this run) ===")
    print(f"  Positive: {pos_cnt}")
    print(f"  Neutral : {neu_cnt}")
    print(f"  Negative: {neg_cnt}")
    print(f"  Unknown : {unk_cnt}")
    print("=== Dataset label totals (after run) ===")
    print(f"  Positive: {total_pos}")
    print(f"  Neutral : {total_neu}")
    print(f"  Negative: {total_neg}")
    print(f"  Unknown : {total_unk}")

    driver.close()


# =========================
# Example call
# =========================
# label_all_citations(
#     force=False,
#     echo=True,
#     results_csv=True,
#     results_csv_filename="edge_classifications_claude.csv",
#     append_to_labeled_dataset_csv="Phase One Final Labeled Dataset from WK.csv",
#     labeled_output_csv=None,
#     compare_with_previous_csv=None,
#     comparison_output_csv=None,
#     show_all_labels_in_output_csv=True,
#     failed_csv=True,
# )
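
Two rules in `label_all_citations` determine what a rerun does: which edges are skipped when `force=False`, and which diagnostic counter a failed status increments. The sketch restates both as small functions so the resume behavior can be checked in isolation. The names `STATUS_BUCKET`, `should_skip` and `tally_failures` are hypothetical. Any status not listed in the mapping, such as `"empty_response"`, falls into the "other" bucket, which matches the `else` branch in the driver.

```python
from collections import Counter
from typing import Iterable, Optional

# Failure statuses returned by classify_with_bedrock, grouped the way the
# driver's fail_* counters group them; anything else counts as "other".
STATUS_BUCKET = {
    "too_long": "too_long",
    "json_parse_failed": "json",
    "bad_keys_or_values": "bad_keys_or_values",
    "api_error": "api",
    "api_response_error": "api",
    "api_error_throttled": "api",
}


def should_skip(existing_label: Optional[str], force: bool) -> bool:
    """True when force=False and the edge already has a label other than Unknown."""
    label = (existing_label or "").strip()
    return (not force) and bool(label) and label.lower() != "unknown"


def tally_failures(statuses: Iterable[str]) -> Counter:
    """Count failed edges per diagnostic bucket."""
    return Counter(STATUS_BUCKET.get(s, "other") for s in statuses)
```

Because Unknown edges are never skipped, rerunning with `force=False` after throttling or JSON failures retries only the edges that still lack a real label.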

## Example Run

In [12]:
# Assumes relations already have snippet_1, snippet_2, ... properties.
label_all_citations(force=True,
                    echo=True,
                    results_csv=True, results_csv_filename="edge_classifications_snippet_method - Claude.csv",
                    show_all_labels_in_output_csv=True,
                    failed_csv=True)


Batch after rel_id -1: 68 relation(s)
      · attempt 1/3
Hastings v. Papillion-LaVista School District → Wisbey v. City of Lincoln, Neb.: Positive
      · attempt 1/3
Kirkeberg v. Canadian Pacific Railway → Equal Employment Opportunity Commission, and Judith Keane, Intervening v. Sears, Roebuck & Company: Positive
      · attempt 1/3
Connie M. Gretillat v. Care Initiatives → Ellen Fjellestad v. Pizza Hut of America, Inc.: Neutral
      · attempt 1/3
Keane, Judith v. Sears Roebuck → US Airways, Inc. v. Barnett: Positive
      · attempt 1/3
Equal Employment Opportunity Commission, and Judith Keane, Intervening v. Sears, Roebuck & Company → US Airways, Inc. v. Barnett: Positive
      · attempt 1/3
      · API ClientError on attempt 1/3: ThrottlingException - Too many requests, please wait before trying again.
      · throttled, backing off for 2.32 seconds
      · attempt 2/3
Kiphart v. Saturn Corp. → Roger Monette and Doris Monette v. Electronic Data Systems Corporation: Positive
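
For a quick QA pass over the results CSV, the sketch below tallies labels from the `Citation Evaluation` column. It also groups the Unknown edges by their `LLM Rationale`. For failed edges, that column holds strings such as `Classification failed: too_long` or the missing-snippets message. `summarize_results` and `LABEL_ORDER` are hypothetical names; the column headers are the ones the batch driver writes.

```python
import pandas as pd

LABEL_ORDER = ["Positive", "Neutral", "Negative", "Unknown"]


def summarize_results(df: pd.DataFrame) -> tuple:
    """Label counts in a fixed order, plus failure reasons for Unknown edges."""
    counts = df["Citation Evaluation"].value_counts().reindex(LABEL_ORDER, fill_value=0)
    # For Unknown edges, the rationale column carries the failure reason.
    unknown = df["Citation Evaluation"] == "Unknown"
    reasons = df.loc[unknown, "LLM Rationale"].value_counts()
    return counts, reasons
```

For example, `counts, reasons = summarize_results(pd.read_csv("edge_classifications_snippet_method - Claude.csv"))`. With `show_all_labels_in_output_csv=True`, that file covers every labeled edge in Neo4j, not just the edges touched in this run.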

## Compare with labeled dataset

In [13]:
# label_all_citations(force=True,
#                     echo=True,
#                     append_to_labeled_dataset_csv="Phase One Final Labeled Dataset from WK.csv",
#                     labeled_output_csv="WK Labeled vs Snippet Method Model Comparison - Claude.csv",
#                     show_all_labels_in_output_csv=True,
#                     failed_csv=True)
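
Once the comparison CSV exists, one way to measure how often the model agrees with the hand labels is sketched below. It reads the `Label` and `Label (Model)` columns written by the labeled-dataset join. It excludes edges that the model marked `Unknown` from the agreement rate and builds a confusion table with `pd.crosstab`. The sketch assumes the labeled dataset uses the same Positive / Neutral / Negative vocabulary, with case differences normalized. `agreement_report` is a hypothetical helper, not part of the notebook.

```python
import pandas as pd


def agreement_report(merged: pd.DataFrame) -> tuple:
    """Agreement rate and confusion table between hand labels and model labels.

    Assumes both columns use Positive / Neutral / Negative (plus Unknown for the
    model); labels are stripped and capitalized so "positive" equals "Positive".
    """
    human = merged["Label"].fillna("").astype(str).str.strip().str.capitalize()
    model = merged["Label (Model)"].fillna("").astype(str).str.strip().str.capitalize()

    # Edges the model could not classify appear in the table but are not scored.
    scored = model != "Unknown"
    if scored.any():
        accuracy = float((human[scored] == model[scored]).mean())
    else:
        accuracy = float("nan")

    confusion = pd.crosstab(human, model, rownames=["Label"], colnames=["Label (Model)"])
    return accuracy, confusion
```

For example, `accuracy, confusion = agreement_report(pd.read_csv("WK Labeled vs Snippet Method Model Comparison - Claude.csv"))`. Reporting Unknown separately keeps API or parsing failures from being counted as disagreements.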