# Case Labeler – Continuous Time Weights

**Purpose**

* Compute a **case-level label** (`"Good"`, `"Bad"`, `"Moderate"`, or `"Unknown"`) for each `(:Case)` node based on:

  * Incoming `(:Case)-[:CITES_TO]->(:Case)` edges
  * **Court level** of the citing case
  * **Continuous, recency-based time weights** on citations
  * Optionally, **jurisdiction-specific weights** on citations
  * **Configurable thresholds** and **label priority**
* Write the label, decision level, and a **human-readable rationale** back to the `Case` node.
* Optionally, produce a **CSV** with detailed per-court-level metrics.

---

## What this script uses

**From Neo4j**

* `(:Case)` nodes:

  * `c.id` (internal case ID used in the graph)
  * `c.name` (case name, may be empty string)
* `(:Case)-[:CITES_TO]->(:Case)` edges:

  * `r.treatment_label` (edge-level label: Positive / Neutral / Negative / Unknown)
  * `src.decision_date` (decision date of the citing case)
* `(:Case)-[:HEARD_IN]->(:Court)` edges:

  * `ct.court_level` (integer 1–5, where lower usually means higher court)
* `(:Case)-[:UNDER_JURISDICTION]->(:Jurisdiction)` edges:

  * `j.jurisdiction_name` (used for optional jurisdiction weights)

**Environment**

* Loaded from `../.env`:

  * `NEO4J_URI`
  * `NEO4J_USERNAME`
  * `NEO4J_PASSWORD`
  * `NEO4J_DATABASE` (optional, default `"neo4j"`)

---

## How it works (high level)

1. **Collect global time statistics**

   * Query all citing cases with:

     ```cypher
     MATCH (s:Case)-[:CITES_TO]->(:Case)
     WHERE s.decision_date IS NOT NULL
     RETURN DISTINCT s.decision_date AS decision_date
     ```

   * Convert `decision_date` to Python ordinals (integers).

   * Compute:

     * Lower quartile $Q_1$
     * Median $Q_2$
     * Upper quartile $Q_3$
     * Global minimum and maximum dates.

   * By default (`default_tmin_tmax=True`):

     * $t_{\min} = Q_1$
     * $t_{\max} = \max(\text{decision date})$ over citing cases.

2. **Page through Case nodes**

   * Use `Q_PAGE_CASES` to page `Case` nodes by `id(c)` in batches of `_CASE_BATCH_SIZE` (default 200).
   * Only include cases where:

     * `force=True`, or
     * `c.case_label IS NULL`.

3. **Collect incoming citations per case**

   * For each case:

     ```cypher
     MATCH (src:Case)-[r:CITES_TO]->(tgt:Case {id:$case_id})
     OPTIONAL MATCH (src)-[:HEARD_IN]->(ct:Court)
     OPTIONAL MATCH (src)-[:UNDER_JURISDICTION]->(j:Jurisdiction)
     RETURN r.treatment_label   AS label,
            src.decision_date   AS decision_date,
            ct.court_level      AS court_level,
            j.jurisdiction_name AS jurisdiction_name
     ```

   * Group edges by `court_level` (1–5).
   * Skip edges with missing or invalid `court_level`.

4. **Compute metrics per court level**

   For each court level $L$:

   * Normalize `r.treatment_label` into one of:

     * `"Positive"`, `"Negative"`, `"Neutral"`, `"Unknown"`

   * For each edge $e$ at level $L$:

     * Compute a **continuous time weight** $\alpha(e)$ based on its citing `decision_date`
       (and any configured jurisdiction weight for its `jurisdiction_name`).
     * Add that weight to the appropriate label’s bucket.

   * Compute, per label:

     * `counts[label]` = number of edges
     * `weights[label]` = sum of time + jurisdiction weights
     * `proportions[label]` = weighted share in the denominator

   * Store in:

     * `per_level_metrics[level]` (counts, weights, proportions, denom)
     * `per_level_counts[level]` (raw counts, total)

5. **Decide a label for each court level**

   * For each level, compare **proportions** to **configured thresholds**:

     * `Pos_p` for `"Positive"`
     * `Neg_p` for `"Negative"`
     * `Neu_p` for `"Neutral"`
     * `Unk_p` for `"Unknown"` (if `include_unknown=True`)

   * A label becomes a **candidate** if:

     $$ p_{\text{label}} \;\ge\; \text{threshold}_{\text{label}} $$

   * If multiple labels are candidates, apply **label priority** to pick a single **driver label**:

     * Default priority: `Unknown > Negative > Neutral > Positive`
     * Or a custom order such as `["pos", "neg", "neu", "unk"]`.

   * Map the driver label to a **case label**:

     * `"Positive"` $\to$ `"Good"`
     * `"Negative"` $\to$ `"Bad"`
     * `"Neutral"`  $\to$ `"Moderate"`
     * `"Unknown"`  $\to$ `"Unknown"`

6. **Walk court levels (highest to lowest)**

   * Identify all levels that have any citations.
   * Determine the **highest level with citations** (numerically smallest).
   * If `lower_level_court=False`:

     * Only use the highest level.
     * If no label at that level passes its threshold, classify as `"Moderate"`.
   * If `lower_level_court=True`:

     * Walk levels from **highest to lowest**.
     * At each level, try to decide a label.
     * Stop at the first level where a label passes its threshold.
     * If no level yields a label, default to `"Moderate"` at the lowest level with citations.

7. **Build a human-readable rationale**

   For each case, `label_rationale` includes:

   * Case name + **total number of incoming citations**.
   * A **per-court-level citation count** summary, e.g.
     `Supreme Court: X, Court of Appeals: Y, …`.
   * The **decision court level** and a breakdown at that level:

     * Weighted proportions (Positive / Negative / Neutral, and optionally Unknown).
     * Raw counts for each label.
   * A short explanation of **why** that label was chosen:

     * Which label was dominant and met its threshold.
     * Or that no label met thresholds, so the case is `"Moderate"`.
   * If a lower court was used:

     * A note that the highest level had **mixed / balanced signals**, so the decision moved down.

   The rationale also reports:

   * The **configured thresholds** $(\text{Pos}_p, \text{Neg}_p, \text{Neu}_p, \text{Unk}_p)$.
   * The **label priority order** used at the decision level.

   When `include_unknown=False`:

   * Unknown is **not** included in the main share and count breakdown.
   * Threshold explanation focuses on Positive / Negative / Neutral.

8. **Write results back to Neo4j**

   * For each case, run:

     ```cypher
     MATCH (c:Case {id:$case_id})
     SET c.case_label                      = $case_label,
         c.court_level_case_label_decision = $decision_level,
         c.label_rationale                 = $label_rationale,
         c.updated_at_utc                  = datetime()
     RETURN c.id AS case_id
     ```

   * Fields:

     * `c.case_label` ∈ `"Good" | "Bad" | "Moderate" | "Unknown"`.
     * `c.court_level_case_label_decision` = **human-readable court name**
       (e.g. `"Supreme Court"`, `"Court of Appeals"`) where the decision was made,
       or `null`/empty if there is no decision court.
     * `c.label_rationale` = long explanatory text.
     * `c.updated_at_utc` = timestamp.

9. **Optional CSV export**

   * If `results_csv=True`, build one row per case with:

     * Case metadata (`Case ID`, `Case Name`, `Total Number of Citations`).
     * For each court level 1–5:

       * Counts per label.
       * Time + jurisdiction weighted sums per label.
       * Proportions per label.
     * `Court Level Decision`
     * `Case Label`
     * `Rationale`

   * Written to `results_csv_filename` (default `"case_labeled_results.csv"`).
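The court-level walk in step 6 can be sketched as follows. This is a minimal illustration, not the script's own code: `walk_court_levels` and its `level_labels` argument (a mapping from court level to the label decided at that level, or `None` when no label passed its threshold) are hypothetical names, and the handling of cases without any citations is deliberately left out.

```python
from typing import Dict, Optional, Tuple


def walk_court_levels(
    level_labels: Dict[int, Optional[str]],
    lower_level_court: bool = True,
) -> Tuple[Optional[str], Optional[int]]:
    """Return (case_label, decision_level) following the step 6 rules.

    level_labels holds only levels that have citations; the value is the
    label decided at that level, or None if nothing passed its threshold.
    """
    levels = sorted(level_labels)  # ascending: level 1 is the highest court
    if not levels:
        # Placeholder: cases without citations are not covered by this sketch.
        return None, None

    if not lower_level_court:
        # Only the highest level counts; fall back to "Moderate" there.
        top = levels[0]
        return level_labels[top] or "Moderate", top

    # Walk from highest to lowest and stop at the first decisive level.
    for level in levels:
        if level_labels[level] is not None:
            return level_labels[level], level

    # No level was decisive: "Moderate" at the lowest level with citations.
    return "Moderate", levels[-1]


print(walk_court_levels({1: None, 2: "Bad"}, lower_level_court=True))
print(walk_court_levels({1: None, 2: "Bad"}, lower_level_court=False))
```

With `lower_level_court=True` the first call moves down to level 2 and returns `"Bad"`; with `lower_level_court=False` the same input stays at level 1 and returns `"Moderate"`.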

---

## Continuous time weighting

### Recency window

Let $t_e$ be the citing decision date (converted to ordinal) for edge $e$.

The code defines a **recency window** $[t_{\min}, t_{\max}]$.

* **Default** (`default_tmin_tmax=True`):

  * $t_{\min}$ = lower quartile $Q_1$ of all citing decision dates
  * $t_{\max}$ = maximum citing decision date in the dataset

* **Custom** (`default_tmin_tmax=False`):

  * You pass `tmin_tmax = [tmin, tmax]`
  * Both are parsed into ordinals; must satisfy $t_{\max} > t_{\min}$.
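As a rough illustration of the default window, the sketch below derives $t_{\min}$ and $t_{\max}$ from a handful of made-up dates. It uses the standard library's `statistics.quantiles` with `method="inclusive"`, which should correspond to the linear interpolation that pandas applies by default; the function name `default_recency_window` is hypothetical, and the sketch assumes at least two dates.

```python
from datetime import date
from statistics import quantiles
from typing import Iterable, Tuple


def default_recency_window(dates: Iterable[date]) -> Tuple[int, int]:
    """Return (t_min, t_max) as ordinals: Q1 and the maximum of the dates."""
    ordinals = sorted(d.toordinal() for d in dates)
    # quantiles(..., n=4) returns the three cut points Q1, Q2, Q3.
    q1, _q2, _q3 = quantiles(ordinals, n=4, method="inclusive")
    return int(round(q1)), ordinals[-1]


# Made-up citing decision dates, for illustration only.
sample = [
    date(2000, 1, 1),
    date(2005, 1, 1),
    date(2010, 1, 1),
    date(2015, 1, 1),
    date(2020, 1, 1),
]
t_min, t_max = default_recency_window(sample)
print(date.fromordinal(t_min), date.fromordinal(t_max))
```

For these five dates, $Q_1$ falls exactly on the second date, so the window runs from 2005-01-01 to 2020-01-01 and every citation decided before 2005 receives no recency boost.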

### Normalized recency

For an edge $e$, with ordinal date $t_e$, the code computes a **normalized recency**:

$$
r_e = \frac{t_e - t_{\min}}{t_{\max} - t_{\min}}
$$

Then it clamps $r_e$ to the range $[0, 1]$:

* If $t_e \le t_{\min}$, then $r_e = 0$
* If $t_e \ge t_{\max}$, then $r_e = 1$

If the recency window is not valid (missing dates or $t_{\max} \le t_{\min}$), it falls back to no recency effect, i.e. all edges get base weight $1$ (before any jurisdiction weight).
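A small worked example of the clamped formula, assuming a custom window from 2010-01-01 to 2020-01-01 (dates chosen purely for illustration; `normalized_recency` is a stand-in name for this sketch):

```python
from datetime import date
from typing import Optional


def normalized_recency(
    t_e: Optional[int], t_min: Optional[int], t_max: Optional[int]
) -> Optional[float]:
    """Return r_e clamped to [0, 1], or None if the date or window is unusable."""
    if t_e is None or t_min is None or t_max is None or t_max <= t_min:
        return None
    r = (t_e - t_min) / (t_max - t_min)
    return min(1.0, max(0.0, r))


t_min = date(2010, 1, 1).toordinal()
t_max = date(2020, 1, 1).toordinal()

for d in (date(2005, 6, 1), date(2015, 1, 1), date(2023, 3, 1)):
    print(d, normalized_recency(d.toordinal(), t_min, t_max))
```

The 2005 date is clamped to $0$, the 2023 date is clamped to $1$, and 2015-01-01 lies exactly halfway through the window (1826 of 3652 days), giving $r_e = 0.5$.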

### Time weight range

The **base time-weight component** is always in the range $[1, \text{MAX WEIGHT}]$.

* **Default** (`default_time_weight=True`):

  $$ \text{MAX WEIGHT} = 2.5 $$

  So the base component is $\in [1, 2.5]$, and a **fully recent** edge ($r_e = 1$) has base weight $2.5$.

* **Custom** (`default_time_weight=False`):

  * You pass:

    ```python
    time_weight = [1.0, MAX_WEIGHT]
    ```

  * The first element must be $1.0$ (minimum base weight).

  * The second element sets $\text{MAX WEIGHT} \ge 1.0$.

When **no jurisdiction weights** are configured, the final $\alpha(e)$ equals this base component.  
When jurisdiction weights are configured, $\alpha(e)$ can be **larger than** `MAX_WEIGHT` (see below).
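The constraints on `time_weight` can be checked up front. The helper below is a hypothetical sketch of such a check (`resolve_max_weight` is not a name taken from the script, and the script's actual error messages may differ); it only encodes the rules stated above.

```python
from typing import List, Optional

DEFAULT_MAX_WEIGHT = 2.5  # default maximum base weight described above


def resolve_max_weight(
    default_time_weight: bool = True,
    time_weight: Optional[List[float]] = None,
) -> float:
    """Return MAX_WEIGHT, validating a custom [1.0, MAX_WEIGHT] pair."""
    if default_time_weight:
        return DEFAULT_MAX_WEIGHT
    if time_weight is None or len(time_weight) != 2:
        raise ValueError("time_weight must be a two-element list [1.0, MAX_WEIGHT].")
    low, high = float(time_weight[0]), float(time_weight[1])
    if low != 1.0:
        raise ValueError("The first element of time_weight must be 1.0.")
    if high < 1.0:
        raise ValueError("MAX_WEIGHT must be >= 1.0.")
    return high


print(resolve_max_weight())                    # default range [1, 2.5]
print(resolve_max_weight(False, [1.0, 4.0]))   # custom range [1, 4]
```

Setting `MAX_WEIGHT` to exactly `1.0` is allowed by these rules and simply switches the recency effect off, since every edge then has base weight $1$.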

### Linear vs non-linear recency

Let $\text{MAX WEIGHT}$ be the maximum base weight.

1. **Linear recency** (`non_linear_recency_effect=False`, default)

   Base component:

   $$ \alpha_{\text{base}}(e) = 1 + (\text{MAX WEIGHT} - 1)\, r_e $$

2. **Quadratic recency** (`non_linear_recency_effect=True`)

   Base component:

   $$ \alpha_{\text{base}}(e) = 1 + (\text{MAX WEIGHT} - 1)\, r_e^2 $$

   For $0 < r_e < 1$, $r_e^2 < r_e$, so mid-range dates get **less** boost and very recent edges get a **stronger relative boost**.

In both cases, the final weight is:

$$
\alpha(e) = \alpha_{\text{base}}(e) + J_i
$$

where $J_i$ is an optional **jurisdiction weight** (default $J_i = 0$).  
If the recency window is invalid or the date is missing, the code uses $\alpha_{\text{base}}(e) = 1$ for that edge (so $\alpha(e) = 1 + J_i$).
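To make the difference between the two shapes concrete, the sketch below evaluates both formulas for a few values of $r_e$ with the default $\text{MAX WEIGHT} = 2.5$. The function `alpha` is an illustrative stand-in that takes $r_e$ directly rather than a date.

```python
from typing import Optional


def alpha(
    r_e: Optional[float],
    max_weight: float = 2.5,
    quadratic: bool = False,
    j_i: float = 0.0,
) -> float:
    """Final edge weight: base time component plus jurisdiction weight J_i."""
    if r_e is None:
        # Missing date or invalid window: no recency effect.
        return 1.0 + j_i
    r_eff = r_e * r_e if quadratic else r_e
    return 1.0 + (max_weight - 1.0) * r_eff + j_i


for r in (0.0, 0.5, 1.0):
    print(r, alpha(r), alpha(r, quadratic=True))

# A fully recent edge with a jurisdiction weight of 1.0 exceeds MAX_WEIGHT.
print(alpha(1.0, j_i=1.0))
```

At $r_e = 0.5$ the linear form gives $1.75$ while the quadratic form gives $1.375$; both agree at the endpoints ($1$ and $2.5$). Adding $J_i = 1.0$ to a fully recent edge yields $3.5$.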

---

## Weights and proportions (per court level)

Fix a single court level $L$. Let:

* $\mathcal{E}_{\text{Pos}}$, $\mathcal{E}_{\text{Neg}}$, $\mathcal{E}_{\text{Neu}}$, $\mathcal{E}_{\text{Unk}}$
  be the sets of incoming `:CITES_TO` edges at level $L$ with labels
  **Positive**, **Negative**, **Neutral**, **Unknown**, respectively.

Let $\alpha(e)$ be the full weight for edge $e$ (time + jurisdiction).

### Weighted sums

For each label:

$$
w_{\text{Pos}} = \sum_{e \in \mathcal{E}_{\text{Pos}}} \alpha(e), \qquad
w_{\text{Neg}} = \sum_{e \in \mathcal{E}_{\text{Neg}}} \alpha(e),
$$

$$
w_{\text{Neu}} = \sum_{e \in \mathcal{E}_{\text{Neu}}} \alpha(e), \qquad
w_{\text{Unk}} = \sum_{e \in \mathcal{E}_{\text{Unk}}} \alpha(e).
$$

If `include_unknown=False`, the code **still** counts Unknown edges, but skips their weights entirely: $w_{\text{Unk}}$ stays $0$ and contributes nothing to the denominator.

### Proportions

If `include_unknown=True`, the denominator is:

$$
D = w_{\text{Pos}} + w_{\text{Neg}} + w_{\text{Neu}} + w_{\text{Unk}},
$$

and the proportions are:

$$
p_{\text{Pos}} = \frac{w_{\text{Pos}}}{D}, \quad
p_{\text{Neg}} = \frac{w_{\text{Neg}}}{D}, \quad
p_{\text{Neu}} = \frac{w_{\text{Neu}}}{D}, \quad
p_{\text{Unk}} = \frac{w_{\text{Unk}}}{D}.
$$

If `include_unknown=False`, the denominator excludes Unknown:

$$
D = w_{\text{Pos}} + w_{\text{Neg}} + w_{\text{Neu}},
$$

and the proportions are:

$$
p_{\text{Pos}} = \frac{w_{\text{Pos}}}{D}, \quad
p_{\text{Neg}} = \frac{w_{\text{Neg}}}{D}, \quad
p_{\text{Neu}} = \frac{w_{\text{Neu}}}{D},
$$

while $p_{\text{Unk}}$ is treated as $0$ for scoring (though Unknown counts are still reported in CSV and rationale when relevant).
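The effect of `include_unknown` on the denominator is easiest to see with numbers. The weighted sums below are invented for illustration, and `proportions` is a simplified stand-in for the per-level computation.

```python
from typing import Dict, Tuple

LABELS = ("Positive", "Negative", "Neutral", "Unknown")


def proportions(
    weights: Dict[str, float], include_unknown: bool = True
) -> Tuple[Dict[str, float], float]:
    """Return (proportions per label, denominator D) for one court level."""
    scored = [lab for lab in LABELS if include_unknown or lab != "Unknown"]
    denom = sum(weights.get(lab, 0.0) for lab in scored)
    props = {lab: 0.0 for lab in LABELS}  # Unknown stays 0 when excluded
    if denom > 0.0:
        for lab in scored:
            props[lab] = weights.get(lab, 0.0) / denom
    return props, denom


# Invented weighted sums for a single court level.
w = {"Positive": 3.0, "Negative": 1.0, "Neutral": 0.0, "Unknown": 1.0}

print(proportions(w, include_unknown=True))
print(proportions(w, include_unknown=False))
```

With Unknown included, $D = 5$ and $p_{\text{Pos}} = 0.6$; with Unknown excluded, $D = 4$ and $p_{\text{Pos}} = 0.75$. Excluding Unknown can therefore push another label over its threshold.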

---

## Label thresholds and priorities

### Thresholds

Configured via `label_thresholds`:

```python
label_thresholds = {
    "Pos_p": 0.55,  # Positive proportion threshold
    "Neg_p": 0.55,  # Negative proportion threshold
    "Neu_p": 0.55,  # Neutral proportion threshold
    "Unk_p": 0.55,  # Unknown proportion threshold
}
```

* Default (if not provided): $0.55$ for all four.
* Per court level:

  * Compute $(p_{\text{Pos}}, p_{\text{Neg}}, p_{\text{Neu}}, p_{\text{Unk}})$.
  * A label is a **candidate** if $p_{\text{label}} \ge \text{threshold}_{\text{label}}$.
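A minimal sketch of the candidate test, assuming the thresholds are passed as a complete four-key dictionary (how the script fills in missing keys is not shown here, and `candidate_labels` is an illustrative name):

```python
from typing import Dict, Optional, Set

DEFAULT_THRESHOLDS = {"Pos_p": 0.55, "Neg_p": 0.55, "Neu_p": 0.55, "Unk_p": 0.55}
THRESHOLD_KEY = {
    "Positive": "Pos_p",
    "Negative": "Neg_p",
    "Neutral": "Neu_p",
    "Unknown": "Unk_p",
}


def candidate_labels(
    props: Dict[str, float],
    thresholds: Optional[Dict[str, float]] = None,
    include_unknown: bool = True,
) -> Set[str]:
    """Return every label whose proportion meets its own threshold."""
    thr = thresholds if thresholds is not None else DEFAULT_THRESHOLDS
    considered = ["Positive", "Negative", "Neutral"]
    if include_unknown:
        considered.append("Unknown")
    return {
        lab for lab in considered
        if props.get(lab, 0.0) >= thr[THRESHOLD_KEY[lab]]
    }


props = {"Positive": 0.6, "Negative": 0.4, "Neutral": 0.0, "Unknown": 0.0}
print(candidate_labels(props))  # default thresholds
lowered = {"Pos_p": 0.5, "Neg_p": 0.4, "Neu_p": 0.55, "Unk_p": 0.55}
print(sorted(candidate_labels(props, lowered)))  # lowered thresholds
```

Because the proportions sum to $1$, at most one label can reach a threshold above $0.5$. With the default $0.55$ everywhere there is never more than one candidate; several candidates only appear once thresholds are lowered, as in the second call.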

### Label priority

Used to break ties when **multiple labels pass their thresholds** at the same level.

**Default** (`default_label_priority=True`):

* Priority (from highest to lowest):

  ```text
  Unknown > Negative > Neutral > Positive
  ```

* Internally:

  ```python
  ["unk", "neg", "neu", "pos"]
  ```

* If `include_unknown=False`, `"Unknown"` is removed from the effective order.

**Custom** (`default_label_priority=False`):

* You provide `label_priority`, e.g.:

  ```python
  label_priority = ["pos", "neg", "neu", "unk"]
  ```

* Accepted values (case-insensitive):

  * `pos`, `positive`, `good`
  * `neg`, `negative`, `bad`
  * `neu`, `neutral`, `moderate`, `mod`
  * `unk`, `unknown`

* These are mapped to canonical labels:

  * `"Positive"`, `"Negative"`, `"Neutral"`, `"Unknown"`

The script then:

* Removes duplicates.
* Drops `"Unknown"` if `include_unknown=False`.
* Uses this ordered list to pick **one driver label** from the candidate set at each level.
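The tie-break itself is a first-match scan over the priority order. The sketch below shows it on a candidate set that contains two labels; `pick_driver` and `CASE_LABEL` are illustrative names for this example.

```python
from typing import List, Optional, Set

DEFAULT_PRIORITY = ["Unknown", "Negative", "Neutral", "Positive"]
CASE_LABEL = {
    "Positive": "Good",
    "Negative": "Bad",
    "Neutral": "Moderate",
    "Unknown": "Unknown",
}


def pick_driver(
    candidates: Set[str],
    priority_order: List[str],
    include_unknown: bool = True,
) -> Optional[str]:
    """Return the first label in priority order that is a candidate."""
    order = [
        lab for lab in priority_order
        if include_unknown or lab != "Unknown"
    ]
    for lab in order:
        if lab in candidates:
            return lab
    return None


candidates = {"Positive", "Negative"}

driver = pick_driver(candidates, DEFAULT_PRIORITY)
print(driver, "->", CASE_LABEL[driver])

custom = ["Positive", "Negative", "Neutral", "Unknown"]
driver = pick_driver(candidates, custom)
print(driver, "->", CASE_LABEL[driver])
```

Under the default order the case is labeled `"Bad"` because Negative outranks Positive; under the custom order the same candidates produce `"Good"`.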

---

## Jurisdiction weights (optional)

You can give extra weight to citations from specific jurisdictions via the `jurisdictions` argument:

```python
jurisdictions = {
    "California": 1.0,      # fixed extra weight J_i
    "Alabama":   "Default", # interpreted as MAX_WEIGHT / 2
}
```

* Keys must match `j.jurisdiction_name` in Neo4j.
* Values can be:

  * A float (extra weight $J_i$ added to every edge from that jurisdiction), or
  * The string `"Default"`, which the script interprets as `MAX_WEIGHT / 2`.

For each edge $e$ from a jurisdiction with configured weight $J_i$:

* The final weight is $\alpha(e) = \alpha_{\text{base}}(e) + J_i$.

If `jurisdictions` is `None` or empty, all $J_i = 0$ and weights are purely time-based.
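One way to turn the `jurisdictions` argument into numeric weights is sketched below. `resolve_jurisdiction_weights` is a hypothetical helper, and the exact-match test on the string `"Default"` is an assumption about how the script recognizes that value.

```python
from typing import Any, Dict, Optional


def resolve_jurisdiction_weights(
    jurisdictions: Optional[Dict[str, Any]],
    max_weight: float = 2.5,
) -> Dict[str, float]:
    """Map each configured jurisdiction name to a numeric extra weight J_i."""
    if not jurisdictions:
        return {}
    resolved: Dict[str, float] = {}
    for name, value in jurisdictions.items():
        if value == "Default":
            resolved[name] = max_weight / 2  # "Default" means MAX_WEIGHT / 2
        else:
            resolved[name] = float(value)
    return resolved


weights = resolve_jurisdiction_weights(
    {"California": 1.0, "Alabama": "Default"}, max_weight=2.5
)
print(weights)

# Jurisdictions that are not configured get J_i = 0.
print(weights.get("Texas", 0.0))
```

With the default $\text{MAX WEIGHT} = 2.5$, `"Default"` resolves to $1.25$, so a fully recent citation from Alabama would carry a total weight of $2.5 + 1.25 = 3.75$.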

---

## Key parameters (continuous time version)

```python
label_all_cases(
    *,
    force: bool = False,
    echo: bool = False,
    lower_level_court: bool = True,
    include_unknown: bool = True,
    label_thresholds: Optional[Dict[str, float]] = None,
    default_label_priority: bool = True,
    label_priority: Optional[List[str]] = None,
    default_tmin_tmax: bool = True,
    tmin_tmax: Optional[List[Any]] = None,
    default_time_weight: bool = True,
    time_weight: Optional[List[float]] = None,
    non_linear_recency_effect: bool = False,
    jurisdictions: Optional[Dict[str, Any]] = None,
    results_csv: bool = False,
    results_csv_filename: str = "case_labeled_results.csv",
)
```

* **Recency window**

  * `default_tmin_tmax=True`:

    * Uses $t_{\min} = Q_1$, $t_{\max} = \max(\text{decision date})$.
  * `default_tmin_tmax=False`:

    * You must set `tmin_tmax = [tmin, tmax]`.

* **Time weights**

  * `default_time_weight=True`:

    * Uses base $\alpha_{\text{base}}(e) \in [1, 2.5]$.
  * `default_time_weight=False`:

    * You must set `time_weight = [1.0, MAX_WEIGHT]`.

* **Shape of recency effect**

  * `non_linear_recency_effect=False`:

    * Linear: $\alpha_{\text{base}}(e) = 1 + (\text{MAX WEIGHT} - 1)\, r_e$.
  * `non_linear_recency_effect=True`:

    * Quadratic: $\alpha_{\text{base}}(e) = 1 + (\text{MAX WEIGHT} - 1)\, r_e^2$.

* **Jurisdiction weights**

  * `jurisdictions=None`:

    * No jurisdiction boost; $\alpha(e) = \alpha_{\text{base}}(e)$.
  * `jurisdictions={...}`:

    * Adds $J_i$ per jurisdiction: $\alpha(e) = \alpha_{\text{base}}(e) + J_i$.

The rest of the behavior (thresholding, priority, lower-court walk, rationale, CSV) follows the logic described above, using these continuous, date-based weights (and optional jurisdiction boosts) instead of discrete time buckets.


In [1]:
# Install (if applicable)
! pip install neo4j pandas python-dotenv



In [2]:
# =========================
# Imports
# =========================
import os
import time
import logging
from datetime import date, datetime
from typing import Dict, Any, Optional, List, Tuple
from collections import defaultdict

import pandas as pd
from neo4j import GraphDatabase
from dotenv import load_dotenv

# Quiet noisy logs (incl. Neo4j notifications/deprecations)
for _n in ("neo4j", "neo4j.notifications", "neo4j.work.simple"):
    logging.getLogger(_n).setLevel(logging.ERROR)
os.environ.setdefault("NEO4J_DRIVER_LOG_LEVEL", "ERROR")

# =========================
# Config / ENV
# =========================
# .env is always one level up from this notebook/script
load_dotenv("../.env", override=True)
NEO4J_URI       = os.getenv("NEO4J_URI")
NEO4J_USERNAME  = os.getenv("NEO4J_USERNAME")
NEO4J_PASSWORD  = os.getenv("NEO4J_PASSWORD")
NEO4J_DATABASE  = os.getenv("NEO4J_DATABASE", "neo4j")

if not (NEO4J_URI and NEO4J_USERNAME and NEO4J_PASSWORD):
    raise RuntimeError(
        "Missing Neo4j connection settings. "
        "Check ../.env for NEO4J_URI / NEO4J_USERNAME / NEO4J_PASSWORD."
    )

# Internal batch sizes
_CASE_BATCH_SIZE = 200
_EDGE_BATCH_SIZE = 2000  # for paging CITES_TO edges

In [3]:
# =========================
# Jurisdictions and Court Levels
# =========================

VALID_JURISDICTIONS = {
    "Alabama",
    "Alaska",
    "Alaska Court of Appeal",
    "Arizona",
    "Arizona Court of Appeal",
    "Arkansas",
    "Arkansas Court of Appeal",
    "Board of Immigration Appeals",
    "California",
    "California Court of Appeal",
    "Colorado",
    "Colorado Court of Appeal",
    "Connecticut",
    "Delaware",
    "Federal Supreme Court",
    "Florida",
    "Florida Court of Appeal",
    "Georgia",
    "Georgia Court of Appeal",
    "Hawaii",
    "Idaho",
    "Idaho Court of Appeal",
    "Illinois",
    "Indiana",
    "Indiana Court of Appeal",
    "Iowa",
    "Iowa Court of Appeal",
    "Kansas",
    "Kansas Court of Appeal",
    "Kentucky",
    "Kentucky Court of Appeal",
    "Louisiana",
    "Louisiana Court of Appeal",
    "Maine",
    "Maryland",
    "Massachusetts",
    "Merit Systems Protection Board",
    "Michigan",
    "Michigan Court of Appeal",
    "Minnesota",
    "Minnesota Court of Appeal",
    "Mississippi",
    "Mississippi Court of Appeal",
    "Missouri",
    "Missouri Court of Appeal",
    "Montana",
    "Nebraska",
    "Nebraska Court of Appeal",
    "Nevada",
    "New Hampshire",
    "New Jersey",
    "New Jersey Court of Appeal",
    "New Mexico",
    "New Mexico Court of Appeal",
    "New York",
    "North Carolina",
    "North Carolina Court of Appeal",
    "North Dakota",
    "Northern Mariana Islands",
    "Office of Legal Counsel",
    "Ohio",
    "Ohio Court of Appeal",
    "Oklahoma",
    "Oregon",
    "Oregon Court of Appeal",
    "Pennsylvania",
    "Puerto Rico",
    "Rhode Island",
    "South Carolina",
    "South Carolina Court of Appeal",
    "South Dakota",
    "Tennessee",
    "Tennessee Court of Appeal",
    "Texas",
    "U.S. Court of Appeals for the Armed Forces",
    "U.S. Court of Appeals for the D.C. Circuit",
    "U.S. Court of Appeals for the Eighth Circuit",
    "U.S. Court of Appeals for the Eleventh Circuit",
    "U.S. Court of Appeals for the Federal Circuit",
    "U.S. Court of Appeals for the Fifth Circuit",
    "U.S. Court of Appeals for the First Circuit",
    "U.S. Court of Appeals for the Fourth Circuit",
    "U.S. Court of Appeals for the Ninth Circuit",
    "U.S. Court of Appeals for the Second Circuit",
    "U.S. Court of Appeals for the Seventh Circuit",
    "U.S. Court of Appeals for the Sixth Circuit",
    "U.S. Court of Appeals for the Tenth Circuit",
    "U.S. Court of Appeals for the Third Circuit",
    "U.S. Court of Appeals for Veterans Claims",
    "U.S. Court of Federal Claims",
    "U.S. Court of International Trade",
    "U.S. District Court for the Central District of California",
    "U.S. District Court for the District of Colorado",
    "U.S. District Court for the District of Columbia",
    "U.S. District Court for the District of Hawaii",
    "U.S. District Court for the District of Maryland",
    "U.S. District Court for the District of Massachusetts",
    "U.S. District Court for the District of Minnesota",
    "U.S. District Court for the District of New Jersey",
    "U.S. District Court for the District of Oregon",
    "U.S. District Court for the District of the Virgin Islands",
    "U.S. District Court for the Eastern District of California",
    "U.S. District Court for the Eastern District of New York",
    "U.S. District Court for the Middle District of Louisiana",
    "U.S. District Court for the Middle District of Pennsylvania",
    "U.S. District Court for the Northern District of California",
    "U.S. District Court for the Southern District of California",
    "U.S. District Court for the Southern District of New York",
    "U.S. Tax Court",
    "Unknown",
    "Utah",
    "Utah Court of Appeal",
    "Vermont",
    "Virginia",
    "Virginia Court of Appeal",
    "Washington",
    "Washington Court of Appeal",
    "Wisconsin",
    "Wisconsin Court of Appeal",
    "Wyoming",
}

COURT_LEVEL_NAMES: Dict[int, str] = {
    1: "Supreme Court",
    2: "Court of Appeals",
    3: "District Court",
    4: "State Court",
    5: "Unknown Court",
}

## Cypher Queries

In [4]:
# =========================
# Cypher Queries
# =========================

# All decision dates for computing global time stats (using citing cases)
Q_GET_TIME_DATES = """
MATCH (s:Case)-[:CITES_TO]->(:Case)
WHERE s.decision_date IS NOT NULL
RETURN DISTINCT s.decision_date AS decision_date
"""

# Count cases that need labeling (respecting `force`)
Q_COUNT_CASES = """
MATCH (c:Case)
WHERE $force = true OR c.case_label IS NULL
RETURN count(c) AS n
"""

# Page through Case nodes using internal id(c)
Q_PAGE_CASES = """
MATCH (c:Case)
WHERE id(c) > $after_id
  AND ($force = true OR c.case_label IS NULL)
RETURN id(c) AS neo_id,
       c.id   AS case_id,
       coalesce(c.name,'') AS case_name
ORDER BY neo_id
LIMIT $limit
"""

# Incoming citations for a given Case, with court level, citing decision date, and jurisdiction
# Assumes:
#   (src:Case)-[:HEARD_IN]->(ct:Court {court_level: 1..5})
#   (src:Case)-[:UNDER_JURISDICTION]->(j:Jurisdiction {jurisdiction_name: ...})
Q_INCOMING_EDGES_FOR_CASE = """
MATCH (src:Case)-[r:CITES_TO]->(tgt:Case {id:$case_id})
OPTIONAL MATCH (src)-[:HEARD_IN]->(ct:Court)
OPTIONAL MATCH (src)-[:UNDER_JURISDICTION]->(j:Jurisdiction)
RETURN r.treatment_label      AS label,
       src.decision_date      AS decision_date,
       ct.court_level         AS court_level,
       j.jurisdiction_name    AS jurisdiction_name
"""

# Write final label back to the Case node
Q_WRITE_CASE_LABEL = """
MATCH (c:Case {id:$case_id})
SET c.case_label                       = $case_label,
    c.court_level_case_label_decision  = $decision_level,
    c.label_rationale                  = $label_rationale,
    c.updated_at_utc                   = datetime()
RETURN c.id AS case_id
"""

# Count CITES_TO edges relevant for this run (only into cases we are labeling,
# unless force=True where we use all edges).
Q_COUNT_CITES_EDGES = """
MATCH (src:Case)-[r:CITES_TO]->(tgt:Case)
WHERE $force = true OR tgt.case_label IS NULL
RETURN count(r) AS n
"""

# Page through CITES_TO edges to compute and store recency/alpha scores.
# Only edges into cases we are labeling are included.
Q_PAGE_CITES_EDGES = """
MATCH (src:Case)-[r:CITES_TO]->(tgt:Case)
WHERE id(r) > $after_id
  AND ($force = true OR tgt.case_label IS NULL)
OPTIONAL MATCH (src)-[:UNDER_JURISDICTION]->(j:Jurisdiction)
RETURN id(r)                AS rel_id,
       src.decision_date    AS decision_date,
       j.jurisdiction_name  AS jurisdiction_name
ORDER BY rel_id
LIMIT $limit
"""

# Batch-write recency_re and influence_score_alpha onto edges
Q_WRITE_EDGE_SCORES = """
UNWIND $rows AS row
MATCH ()-[r:CITES_TO]->()
WHERE id(r) = row.rel_id
SET r.recency_re            = row.recency_re,
    r.influence_score_alpha = row.influence_score_alpha
"""


## Helper Functions

In [5]:
# =========================
# Helper Functions
# =========================

def _to_ordinal(date_value: Any) -> Optional[int]:
    """
    Convert a Neo4j date / datetime / ISO string to a Python ordinal (int days).
    Returns None if the value cannot be parsed.
    """
    if date_value is None:
        return None

    # Already a date or datetime
    if isinstance(date_value, date) and not isinstance(date_value, datetime):
        return date_value.toordinal()
    if isinstance(date_value, datetime):
        return date_value.date().toordinal()

    # Fallback: parse the string form as ISO 8601 ("YYYY-MM-DD" or a full datetime).
    # datetime.fromisoformat accepts both shapes, so one code path covers either.
    try:
        return datetime.fromisoformat(str(date_value)).toordinal()
    except (TypeError, ValueError):
        return None


def _compute_time_quartiles(session) -> Dict[str, Optional[int]]:
    """
    Compute global Q1, Q2 (median), Q3, and min/max over decision_date of citing cases.
    Returns ordinals in a dict:
      {"q1": int|None, "q2": int|None, "q3": int|None,
       "min": int|None, "max": int|None}.
    """
    rows = session.run(Q_GET_TIME_DATES).data()
    ordinals: List[int] = []
    for row in rows:
        ordv = _to_ordinal(row.get("decision_date"))
        if ordv is not None:
            ordinals.append(ordv)

    if not ordinals:
        # No dates available
        return {"q1": None, "q2": None, "q3": None, "min": None, "max": None}

    s = pd.Series(ordinals, dtype="float")
    q1 = int(round(s.quantile(0.25)))
    q2 = int(round(s.quantile(0.50)))
    q3 = int(round(s.quantile(0.75)))
    dmin = int(s.min())
    dmax = int(s.max())
    return {"q1": q1, "q2": q2, "q3": q3, "min": dmin, "max": dmax}


def _normalize_edge_label(raw: Any) -> str:
    """
    Normalize edge treatment labels into one of:
    "Positive", "Neutral", "Negative", "Unknown".
    """
    if raw is None:
        return "Unknown"
    txt = str(raw).strip().lower()
    if txt == "positive":
        return "Positive"
    if txt == "negative":
        return "Negative"
    if txt in ("neutral", "moderate"):
        # Edge-level label "Neutral" is treated as Neutral; if "Moderate" appears, map to Neutral.
        return "Neutral"
    if txt == "unknown":
        return "Unknown"
    return "Unknown"


def _court_level_to_name(level: Optional[int]) -> str:
    """
    Map a numeric court level (1–5) to a human-readable court name.
    Returns "" for None.
    """
    if level is None:
        return ""
    return COURT_LEVEL_NAMES.get(level, f"Court level {level}")


def _compute_normalized_recency(
    decision_date: Any,
    tmin_ord: Optional[int],
    tmax_ord: Optional[int],
) -> Optional[float]:
    """
    Compute normalized recency r_e in [0, 1] based on decision_date
    and the global window [tmin_ord, tmax_ord].

    Returns:
      - float in [0,1] if recency can be computed
      - None if there is no valid date or window.
    """
    ordv = _to_ordinal(decision_date)
    if (
        ordv is None
        or tmin_ord is None
        or tmax_ord is None
        or tmax_ord <= tmin_ord
    ):
        return None

    span = float(tmax_ord - tmin_ord)
    r = (ordv - tmin_ord) / span
    if r <= 0.0:
        return 0.0
    if r >= 1.0:
        return 1.0
    return float(r)


def _compute_alpha(
    decision_date: Any,
    jurisdiction_name: Optional[str],
    tmin_ord: Optional[int],
    tmax_ord: Optional[int],
    max_weight: float,
    non_linear_recency_effect: bool,
    jurisdiction_weights: Optional[Dict[str, float]],
) -> float:
    """
    Compute the weight alpha(e) for one edge:

        alpha(e) = 1 + (MAX_WEIGHT - 1) * r_eff (+ optional Ji)

    where:
      - r_eff = r_e or r_e^2, depending on non_linear_recency_effect
      - r_e is the normalized recency in [0, 1]
      - Ji is an additional jurisdiction weight if the citing jurisdiction
        is in `jurisdiction_weights`. If the user passed "Default" for
        that jurisdiction, Ji = MAX_WEIGHT / 2 by construction.
    """
    base = 1.0
    r = _compute_normalized_recency(decision_date, tmin_ord, tmax_ord)

    if r is not None:
        if non_linear_recency_effect:
            r_eff = r * r
        else:
            r_eff = r
        base = 1.0 + (max_weight - 1.0) * r_eff

    # --- jurisdiction component ---
    Ji = 0.0
    if jurisdiction_weights and jurisdiction_name:
        Ji = jurisdiction_weights.get(str(jurisdiction_name), 0.0)

    return base + Ji


def _compute_level_metrics(
    edges: List[Dict[str, Any]],
    include_unknown: bool,
    tmin_ord: Optional[int],
    tmax_ord: Optional[int],
    max_weight: float,
    non_linear_recency_effect: bool,
    jurisdiction_weights: Optional[Dict[str, float]],
) -> Dict[str, Any]:
    """
    Compute counts, time-weighted sums, and proportions for one court level.

    Returns a dict:
    {
      "counts":      {label -> int},
      "weights":     {label -> float},
      "proportions": {label -> float},
      "denom":       float
    }
    where labels are: "Positive", "Neutral", "Negative", "Unknown".
    """
    labels = ("Positive", "Neutral", "Negative", "Unknown")
    counts = {lab: 0 for lab in labels}
    weights = {lab: 0.0 for lab in labels}

    for e in edges:
        lab = _normalize_edge_label(e.get("label"))
        counts[lab] += 1

        # If we are excluding Unknown from the scoring, skip its weight
        if lab == "Unknown" and not include_unknown:
            continue

        w = _compute_alpha(
            decision_date=e.get("decision_date"),
            jurisdiction_name=e.get("jurisdiction_name"),
            tmin_ord=tmin_ord,
            tmax_ord=tmax_ord,
            max_weight=max_weight,
            non_linear_recency_effect=non_linear_recency_effect,
            jurisdiction_weights=jurisdiction_weights,
        )
        weights[lab] += w

    if include_unknown:
        denom = (
            weights["Positive"]
            + weights["Neutral"]
            + weights["Negative"]
            + weights["Unknown"]
        )
    else:
        denom = (
            weights["Positive"]
            + weights["Neutral"]
            + weights["Negative"]
        )

    if denom > 0.0:
        pos_p = weights["Positive"] / denom
        neu_p = weights["Neutral"] / denom
        neg_p = weights["Negative"] / denom
        unk_p = (weights["Unknown"] / denom) if include_unknown else 0.0
    else:
        pos_p = neu_p = neg_p = unk_p = 0.0

    return {
        "counts": counts,
        "weights": weights,
        "proportions": {
            "Positive": pos_p,
            "Neutral": neu_p,
            "Negative": neg_p,
            "Unknown": unk_p,
        },
        "denom": denom,
    }


def _normalize_priority_list(
    label_priority: List[str],
    include_unknown: bool,
) -> List[str]:
    """
    Normalize user-specified label priority into canonical labels:
    "Positive", "Negative", "Neutral", "Unknown".

    Removes duplicates and, if include_unknown=False, removes "Unknown".
    """
    if not label_priority:
        raise ValueError("label_priority must be a non-empty list.")

    mapping = {
        "pos": "Positive",
        "positive": "Positive",
        "good": "Positive",

        "neg": "Negative",
        "negative": "Negative",
        "bad": "Negative",

        "neu": "Neutral",
        "neutral": "Neutral",
        "moderate": "Neutral",
        "mod": "Neutral",

        "unk": "Unknown",
        "unknown": "Unknown",
    }

    result: List[str] = []
    for item in label_priority:
        if item is None:
            continue
        key = str(item).strip().lower()
        if key not in mapping:
            raise ValueError(
                f"Unrecognized label in label_priority: {item!r}. "
                "Allowed values (case-insensitive) include: "
                "pos/positive/good, neg/negative/bad, "
                "neu/neutral/moderate/mod, unk/unknown."
            )
        canon = mapping[key]
        if canon not in result:
            result.append(canon)

    if not include_unknown:
        result = [lab for lab in result if lab != "Unknown"]

    if not result:
        raise ValueError(
            "Effective label_priority is empty after applying include_unknown setting."
        )

    return result


def _decide_label_from_metrics(
    metrics: Dict[str, Any],
    include_unknown: bool,
    label_thresholds: Dict[str, float],
    priority_order: List[str],
) -> Tuple[Optional[str], Optional[str]]:
    """
    Decide case-level label for a single court level based on proportions.

    label_thresholds keys:
      - "Pos_p"  → threshold for Positive proportion
      - "Neg_p"  → threshold for Negative proportion
      - "Neu_p"  → threshold for Neutral proportion
      - "Unk_p"  → threshold for Unknown proportion

    priority_order is a list of canonical labels (subset of):
      ["Positive", "Negative", "Neutral", "Unknown"]
    in the order of preference to break ties when several labels
    pass their thresholds.

    Returns (case_label, driver_label) where:
    - case_label ∈ {"Good","Bad","Moderate","Unknown"} or None
    - driver_label ∈ {"Positive","Negative","Neutral","Unknown"} or None
    """
    if metrics["denom"] <= 0.0:
        return None, None

    props = metrics["proportions"]

    thr_map = {
        "Positive": label_thresholds["Pos_p"],
        "Negative": label_thresholds["Neg_p"],
        "Neutral":  label_thresholds["Neu_p"],
        "Unknown":  label_thresholds["Unk_p"],
    }

    # Labels we consider for thresholding, depending on include_unknown
    considered_labels = ["Positive", "Negative", "Neutral"]
    if include_unknown:
        considered_labels.append("Unknown")

    # Collect labels that pass their own threshold
    candidates = set()
    for lab in considered_labels:
        if props.get(lab, 0.0) >= thr_map[lab]:
            candidates.add(lab)

    if not candidates:
        return None, None

    # Apply priority order to pick the driver label
    chosen: Optional[str] = None
    for lab in priority_order:
        if lab in candidates:
            chosen = lab
            break

    if chosen is None:
        # No overlap between candidates and priority_order → treat as no decision
        return None, None

    if chosen == "Positive":
        return "Good", "Positive"
    if chosen == "Negative":
        return "Bad", "Negative"
    if chosen == "Neutral":
        return "Moderate", "Neutral"
    # chosen == "Unknown"
    return "Unknown", "Unknown"


def _build_label_rationale(
    case_name: str,
    total_cites: int,
    per_level_counts: Dict[int, Dict[str, int]],
    per_level_metrics: Dict[int, Dict[str, Any]],
    decision_level: Optional[int],
    case_label: str,
    driver_label: Optional[str],
    include_unknown: bool,
    used_lower_level: bool,
    label_thresholds: Dict[str, float],
    priority_order: List[str],
) -> str:
    """
    Build the human-readable rationale string for one case.
    Uses human-readable court names instead of numeric court levels.
    Designed so that lawyers can understand the reasoning without
    needing to know implementation details.
    """
    lines: List[str] = []

    # -----------------------------
    # 1. Case-level label and edge case with no citations
    # -----------------------------
    if decision_level is None or total_cites == 0:
        lines.append(
            f"The case '{case_name}' is labeled '{case_label}'."
        )
        lines.append(
            "The case has no incoming citations with a known court level, "
            "so there is no clear precedential signal to tilt it toward a "
            "favorable or unfavorable treatment."
        )
        return " ".join(lines)

    decision_court_name = COURT_LEVEL_NAMES.get(
        decision_level, f"Court level {decision_level}"
    )

    # Short, human-readable summary tied to the deciding court
    if driver_label == "Positive":
        summary_sentence = (
            f"The case '{case_name}' is labeled '{case_label}' based on citations "
            f"from the {decision_court_name}, where the balance of weighted "
            "citations is predominantly positive."
        )
    elif driver_label == "Negative":
        summary_sentence = (
            f"The case '{case_name}' is labeled '{case_label}' based on citations "
            f"from the {decision_court_name}, where the balance of weighted "
            "citations is predominantly negative."
        )
    elif driver_label == "Neutral":
        summary_sentence = (
            f"The case '{case_name}' is labeled '{case_label}' based on citations "
            f"from the {decision_court_name}, where the weighted signals are "
            "neither strongly positive nor strongly negative and instead cluster "
            "around a neutral treatment."
        )
    elif driver_label == "Unknown":
        summary_sentence = (
            f"The case '{case_name}' is labeled '{case_label}' because most of the "
            f"weighted citations to this case at the {decision_court_name} are "
            "labeled as 'Unknown'."
        )
    else:
        # No single label met its threshold; treated as balanced
        summary_sentence = (
            f"The case '{case_name}' is labeled '{case_label}' because, across the "
            "courts with citations, no single treatment category clearly exceeds "
            "its required share; the precedential signals are balanced."
        )

    lines.append(summary_sentence)

    # -----------------------------
    # 2. How many citations and from which courts
    # -----------------------------
    lines.append(
        f"The case has {total_cites} incoming citation(s)."
    )

    # Counts by court, using court names
    level_fragments = []
    for lvl in range(1, 6):
        lvl_total = per_level_counts.get(lvl, {}).get("total", 0)
        court_name = COURT_LEVEL_NAMES.get(lvl, f"Court level {lvl}")
        level_fragments.append(f"{court_name}: {lvl_total}")
    levels_str = ", ".join(level_fragments)
    lines.append(
        f"By court, citations are distributed as follows: {levels_str}."
    )

    # -----------------------------
    # 3. Details at the deciding court
    # -----------------------------
    dec_metrics = per_level_metrics.get(decision_level, {})
    dec_counts = per_level_counts.get(decision_level, {})
    dec_props = dec_metrics.get("proportions", {})

    dec_total = dec_counts.get("total", 0)
    c_pos = dec_counts.get("Positive", 0)
    c_neg = dec_counts.get("Negative", 0)
    c_neu = dec_counts.get("Neutral", 0)
    c_unk = dec_counts.get("Unknown", 0)

    p_pos = dec_props.get("Positive", 0.0)
    p_neg = dec_props.get("Negative", 0.0)
    p_neu = dec_props.get("Neutral", 0.0)
    p_unk = dec_props.get("Unknown", 0.0)

    if include_unknown:
        share_str = (
            f"Positive={p_pos:.2f}, Negative={p_neg:.2f}, "
            f"Neutral={p_neu:.2f}, Unknown={p_unk:.2f}"
        )
        count_str = (
            f"{c_pos} positive, {c_neg} negative, "
            f"{c_neu} neutral, {c_unk} unknown"
        )
    else:
        share_str = (
            f"Positive={p_pos:.2f}, Negative={p_neg:.2f}, Neutral={p_neu:.2f}"
        )
        count_str = (
            f"{c_pos} positive, {c_neg} negative, {c_neu} neutral"
        )

    if driver_label in {"Positive", "Negative", "Neutral", "Unknown"}:
        decision_clause = (
            f"At the {decision_court_name}, the label is driven by "
            f"{driver_label.lower()} treatment: based on {dec_total} citation(s) "
            f"at this court, the weighted proportions are {share_str}, coming from "
            f"{count_str} citation(s)."
        )
    else:
        decision_clause = (
            f"At the {decision_court_name}, based on {dec_total} citation(s), "
            f"the weighted proportions are {share_str}, coming from {count_str} "
            "citation(s), but no single label reaches its required share, so the "
            "case is treated as 'Moderate' overall."
        )

    lines.append(decision_clause)

    # -----------------------------
    # 4. If we moved down from a higher court, explain why
    # -----------------------------
    levels_with_cites = [
        lvl for lvl in range(1, 6)
        if per_level_counts.get(lvl, {}).get("total", 0) > 0
    ]

    if used_lower_level and levels_with_cites:
        highest_level = min(levels_with_cites)
        if decision_level != highest_level:
            highest_court_name = COURT_LEVEL_NAMES.get(
                highest_level, f"Court level {highest_level}"
            )
            hl_metrics = per_level_metrics.get(highest_level, {})
            hl_props = hl_metrics.get("proportions", {})
            hl_denom = hl_metrics.get("denom", 0.0)

            if hl_denom > 0.0:
                hl_p_pos = hl_props.get("Positive", 0.0)
                hl_p_neg = hl_props.get("Negative", 0.0)
                hl_p_neu = hl_props.get("Neutral", 0.0)
                hl_p_unk = hl_props.get("Unknown", 0.0)

                if include_unknown:
                    hl_share_str = (
                        f"Positive={hl_p_pos:.2f}, Negative={hl_p_neg:.2f}, "
                        f"Neutral={hl_p_neu:.2f}, Unknown={hl_p_unk:.2f}"
                    )
                else:
                    hl_share_str = (
                        f"Positive={hl_p_pos:.2f}, Negative={hl_p_neg:.2f}, "
                        f"Neutral={hl_p_neu:.2f}"
                    )

                lines.append(
                    f"At the {highest_court_name} level (the highest court that cites this case), "
                    f"the weighted proportions are {hl_share_str}. "
                    "Because no single treatment label at that court met its configured threshold, "
                    f"the algorithm looked to the next lower court and ultimately relied on the "
                    f"{decision_court_name}, where the distribution shown above provided a clearer "
                    "signal for the final label."
                )
            else:
                lines.append(
                    f"The highest court with citations is the {highest_court_name}, "
                    "but there were not enough labeled citations at that level to meet any "
                    f"threshold, so the algorithm instead relied on citations from the "
                    f"{decision_court_name} to determine the label."
                )

    # -----------------------------
    # 5. How thresholds and weighting are applied
    # -----------------------------
    pos_thr = label_thresholds["Pos_p"]
    neg_thr = label_thresholds["Neg_p"]
    neu_thr = label_thresholds["Neu_p"]
    unk_thr = label_thresholds["Unk_p"]

    thr_parts = [
        f"Positive >= {pos_thr:.2f}",
        f"Negative >= {neg_thr:.2f}",
        f"Neutral >= {neu_thr:.2f}",
    ]
    if include_unknown:
        thr_parts.append(f"Unknown >= {unk_thr:.2f}")
    thr_str = ", ".join(thr_parts)

    if priority_order:
        priority_str = " > ".join(priority_order)
        lines.append(
            "At each court, the model uses time- and jurisdiction-weighted citation counts "
            "to compute the share of positive, negative, neutral (and, if included, unknown) "
            "treatment. A label can drive the case outcome at that court only if its weighted "
            f"share meets its configured threshold. For this run, the share thresholds are: {thr_str}. "
            "If more than one label meets its threshold, the priority order "
            f"({priority_str}) is used to select the controlling label."
        )
    else:
        lines.append(
            "At each court, the model uses time- and jurisdiction-weighted citation counts "
            "to compute the share of positive, negative, and neutral treatment. "
            f"A label can drive the case outcome at that court only if its weighted share "
            f"meets its configured threshold. For this run, the share thresholds are: {thr_str}."
        )

    return " ".join(lines)


def _precompute_edge_scores(
    session,
    tmin_ord: Optional[int],
    tmax_ord: Optional[int],
    max_weight: float,
    non_linear_recency_effect: bool,
    jurisdiction_weights: Optional[Dict[str, float]],
    force: bool,
    echo: bool = False,
) -> None:
    """
    Compute and store recency_re and influence_score_alpha for CITES_TO edges
    that matter for this run:

      - If force=True: all (src:Case)-[r:CITES_TO]->(tgt:Case) edges.
      - If force=False: only edges where tgt.case_label IS NULL.

    - recency_re: normalized recency r_e in [0,1] based on the citing case decision_date.
    - influence_score_alpha: alpha(e) = 1 + (MAX_WEIGHT - 1) * r_eff + Ji,
      where r_eff = r_e or r_e^2 (if non_linear_recency_effect=True) and Ji is
      the jurisdiction weight if configured.

    The loop has an explicit safety stop to avoid unbounded scanning.
    """
    # How many edges should we touch?
    edge_count_rows = session.run(
        Q_COUNT_CITES_EDGES, {"force": bool(force)}
    ).data()
    total_expected = edge_count_rows[0]["n"] if edge_count_rows else 0

    if echo:
        print(f"Total CITES_TO edges to score (for this run): {total_expected}")

    if total_expected == 0:
        return

    after = -1
    total_edges = 0
    t0 = time.time()
    last_print = t0

    while True:
        # Safety: if for some reason we have already processed as many edges
        # as expected (or more), stop.
        if total_edges >= total_expected:
            break

        edge_rows = session.run(
            Q_PAGE_CITES_EDGES,
            {
                "after_id": after,
                "limit": _EDGE_BATCH_SIZE,
                "force": bool(force),
            },
        ).data()

        if not edge_rows:
            break

        rows_to_write: List[Dict[str, Any]] = []
        for er in edge_rows:
            rel_id = er["rel_id"]
            decision_date = er.get("decision_date")
            jurisdiction_name = er.get("jurisdiction_name")

            # recency_re
            r = _compute_normalized_recency(decision_date, tmin_ord, tmax_ord)

            # alpha using the same helper as the labeling logic
            alpha = _compute_alpha(
                decision_date=decision_date,
                jurisdiction_name=jurisdiction_name,
                tmin_ord=tmin_ord,
                tmax_ord=tmax_ord,
                max_weight=max_weight,
                non_linear_recency_effect=non_linear_recency_effect,
                jurisdiction_weights=jurisdiction_weights,
            )

            rows_to_write.append(
                {
                    "rel_id": rel_id,
                    "recency_re": float(r) if r is not None else None,
                    "influence_score_alpha": float(alpha),
                }
            )

        if rows_to_write:
            session.run(Q_WRITE_EDGE_SCORES, {"rows": rows_to_write})

        total_edges += len(rows_to_write)
        after = edge_rows[-1]["rel_id"]

        if echo and (time.time() - last_print >= 5.0):
            print(
                f"Computed citation scores for {total_edges} CITES_TO edge(s)..."
            )
            last_print = time.time()

    if echo:
        elapsed_min = (time.time() - t0) / 60.0
        print(
            f"Finished computing citation scores for {total_edges} CITES_TO edge(s) "
            f"in {elapsed_min:.1f} minutes."
        )
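
The docstring of `_precompute_edge_scores` states the edge weight as alpha(e) = 1 + (MAX_WEIGHT - 1) * r_eff + Ji. The sketch below works through that formula in isolation. It is not the notebook's implementation: `sketch_recency` and `sketch_alpha` are hypothetical stand-ins for `_compute_normalized_recency` and `_compute_alpha`, the linear, clamped normalization of the decision date is an assumption, and so is treating a missing date or missing window as "no recency effect". The jurisdiction name `"Example State"` is made up for illustration.

```python
from datetime import date
from typing import Dict, Optional


def sketch_recency(
    decision_date: Optional[date],
    tmin_ord: Optional[int],
    tmax_ord: Optional[int],
) -> Optional[float]:
    """Assumed normalization: linear position in [tmin, tmax], clamped to [0, 1]."""
    if decision_date is None or tmin_ord is None or tmax_ord is None:
        return None  # no date or no recency window
    if tmax_ord <= tmin_ord:
        return None  # degenerate window
    r = (decision_date.toordinal() - tmin_ord) / (tmax_ord - tmin_ord)
    return min(1.0, max(0.0, r))


def sketch_alpha(
    decision_date: Optional[date],
    jurisdiction_name: Optional[str],
    tmin_ord: Optional[int],
    tmax_ord: Optional[int],
    max_weight: float = 2.5,
    non_linear: bool = False,
    jurisdiction_weights: Optional[Dict[str, float]] = None,
) -> float:
    """alpha(e) = 1 + (max_weight - 1) * r_eff + Ji, as stated in the docstring."""
    r = sketch_recency(decision_date, tmin_ord, tmax_ord)
    if r is None:
        r_eff = 0.0  # assumption: no recency bonus without a usable date/window
    else:
        r_eff = r * r if non_linear else r
    # Ji is 0.0 for jurisdictions that were not configured.
    ji = (jurisdiction_weights or {}).get(jurisdiction_name, 0.0)
    return 1.0 + (max_weight - 1.0) * r_eff + ji


# A 1000-day window makes the midpoint exactly r_e = 0.5.
tmin = date(2000, 1, 1).toordinal()
tmax = tmin + 1000
midpoint = date.fromordinal(tmin + 500)

print(sketch_alpha(midpoint, None, tmin, tmax))                   # 1.75
print(sketch_alpha(midpoint, None, tmin, tmax, non_linear=True))  # 1.375
```

With quadratic recency, a citation from the middle of the window gains less weight (1.375 instead of 1.75), so only the most recent citations approach `max_weight`.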

## Label All Cases
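
For each case, `label_all_cases` applies `_decide_label_from_metrics` at the highest court that cites the case and, when `lower_level_court=True`, walks down to lower courts until some label qualifies. The following is a minimal sketch of that rule under simplifying assumptions: `sketch_decide` and `sketch_walk_levels` are hypothetical names, thresholds are keyed by canonical label rather than by the `"Pos_p"`-style keys, and the input is a dict of already-computed proportions that contains only court levels with citations.

```python
from typing import Dict, List, Optional, Tuple

CASE_LABEL = {
    "Positive": "Good",
    "Negative": "Bad",
    "Neutral": "Moderate",
    "Unknown": "Unknown",
}


def sketch_decide(
    props: Dict[str, float],
    thresholds: Dict[str, float],
    priority: List[str],
) -> Optional[str]:
    """Return the first label, in priority order, whose share meets its threshold."""
    for lab in priority:
        if props.get(lab, 0.0) >= thresholds[lab]:
            return lab
    return None


def sketch_walk_levels(
    props_by_level: Dict[int, Dict[str, float]],
    thresholds: Dict[str, float],
    priority: List[str],
    lower_level_court: bool = True,
) -> Tuple[str, Optional[int]]:
    """Return (case_label, decision_level); level 1 is the highest court."""
    levels = sorted(props_by_level)
    if not levels:
        return "Unknown", None  # no citations with a usable court level
    to_check = levels if lower_level_court else levels[:1]
    for lvl in to_check:
        driver = sketch_decide(props_by_level[lvl], thresholds, priority)
        if driver is not None:
            return CASE_LABEL[driver], lvl
    # No label qualified at any checked level.
    return "Moderate", to_check[-1]


priority = ["Unknown", "Negative", "Neutral", "Positive"]
default_thr = {lab: 0.55 for lab in priority}
low_thr = {lab: 0.40 for lab in priority}

tie = {1: {"Positive": 0.45, "Negative": 0.45, "Neutral": 0.10}}
mixed_top = {
    1: {"Positive": 0.34, "Negative": 0.33, "Neutral": 0.33},
    3: {"Positive": 0.80, "Negative": 0.10, "Neutral": 0.10},
}

print(sketch_walk_levels(tie, low_thr, priority))            # ('Bad', 1)
print(sketch_walk_levels(mixed_top, default_thr, priority))  # ('Good', 3)
print(sketch_walk_levels(mixed_top, default_thr, priority,
                         lower_level_court=False))           # ('Moderate', 1)
```

Because the shares at a court level sum to at most 1, the default threshold of 0.55 for every label lets at most one label qualify. The priority order only affects the outcome when thresholds are lowered far enough for two labels to pass at once, as in the first call.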

# =========================
# Core function: label_all_cases
# =========================

def label_all_cases(
    *,
    force: bool = False,
    echo: bool = False,
    lower_level_court: bool = True,
    include_unknown: bool = True,
    label_thresholds: Optional[Dict[str, float]] = None,
    default_label_priority: bool = True,
    label_priority: Optional[List[str]] = None,
    default_tmin_tmax: bool = True,
    tmin_tmax: Optional[List[Any]] = None,
    default_time_weight: bool = True,
    time_weight: Optional[List[float]] = None,
    non_linear_recency_effect: bool = False,
    jurisdictions: Optional[Dict[str, Any]] = None,
    results_csv: bool = False,
    results_csv_filename: str = "case_labeled_results.csv",
):
    """
    Label Case nodes as "Good", "Bad", "Moderate", or "Unknown" based on incoming
    CITES_TO edges, court levels, continuous time-weighted citation counts, and
    optional jurisdiction weights.

    At the start of each run, this function also computes citation scores for
    the relevant CITES_TO edges and stores:
      - r.recency_re            (normalized recency r_e in [0,1])
      - r.influence_score_alpha (alpha(e) consistent with the labeling logic).

    Parameters
    ----------
    force : bool
        If True, overwrite existing case_label values and precompute scores for
        all case->case CITES_TO edges.
        If False, only label cases where case_label is null/missing and only
        precompute scores for edges into those cases.
    echo : bool
        If True, print progress information.
    lower_level_court : bool
        If False: only use the highest court level with citations; if no label
        reaches its threshold there, classify as "Moderate" immediately.
        If True: if no label reaches its threshold at the highest court level,
        repeat the procedure at the next lower court level with citations, and
        so on; if no level yields a label, classify as "Moderate".
    include_unknown : bool
        If True: include edges with treatment_label="Unknown" in the weights and
        proportions.
        If False: ignore Unknown edges in the weighting and proportions.
    label_thresholds : dict or None
        Per-label thresholds for the proportions. Must contain:
        {"Pos_p","Neg_p","Neu_p","Unk_p"}.
    default_label_priority : bool
        If True, use the default label priority:
            Unknown > Negative > Neutral > Positive
        to break ties when multiple labels pass their thresholds.
    label_priority : list of str or None
        Only used when default_label_priority=False.
    default_tmin_tmax : bool
        If True (default): use dataset-based defaults for the recency window:
            t_min = Q1, t_max = max decision date.
    tmin_tmax : list or None
        Only used when default_tmin_tmax=False.
    default_time_weight : bool
        If True (default): use alpha(e) in [1, 2.5].
    time_weight : list of float or None
        Only used when default_time_weight=False.
    non_linear_recency_effect : bool
        If False (default): use linear recency.
        If True: use quadratic recency (r_e^2).
    jurisdictions : dict or None
        Optional dictionary of jurisdictions that should receive extra weight.
        Keys are jurisdiction names (must match VALID_JURISDICTIONS exactly).
        Values can be:
          - a specific float Ji, or
          - the string "Default", which is interpreted as Ji = MAX_WEIGHT / 2.
    results_csv : bool
        If True: write a CSV file with per-case, per-court-level metrics.
    results_csv_filename : str
        Name of the CSV file to write when results_csv=True.
    """
    # --- label threshold configuration ---
    if label_thresholds is None:
        label_thresholds = {
            "Pos_p": 0.55,
            "Neg_p": 0.55,
            "Neu_p": 0.55,
            "Unk_p": 0.55,
        }
    else:
        required_thr = {"Pos_p", "Neg_p", "Neu_p", "Unk_p"}
        missing_thr = required_thr.difference(label_thresholds.keys())
        if missing_thr:
            raise ValueError(
                f"label_thresholds is missing required keys: {sorted(missing_thr)}"
            )
        # cast to float
        label_thresholds = {
            k: float(label_thresholds[k]) for k in ("Pos_p", "Neg_p", "Neu_p", "Unk_p")
        }

    # --- label priority configuration ---
    if default_label_priority and label_priority is not None:
        raise ValueError(
            "You provided label_priority but default_label_priority=True. "
            "Set default_label_priority=False to use custom label_priority."
        )

    if default_label_priority:
        base_priority = ["unk", "neg", "neu", "pos"]  # Unknown > Negative > Neutral > Positive
    else:
        if label_priority is None:
            raise ValueError(
                "label_priority must be provided when default_label_priority=False."
            )
        base_priority = label_priority

    # --- tmin/tmax configuration (recency window) ---
    if default_tmin_tmax and tmin_tmax is not None:
        raise ValueError(
            "You provided tmin_tmax but default_tmin_tmax=True. "
            "Set default_tmin_tmax=False to use custom tmin_tmax."
        )

    # --- time weight range configuration ---
    if default_time_weight and time_weight is not None:
        raise ValueError(
            "You provided time_weight but default_time_weight=True. "
            "Set default_time_weight=False to use custom time_weight."
        )

    if default_time_weight:
        max_weight = 2.5  # alpha(e) in [1, 2.5]
    else:
        if time_weight is None or len(time_weight) != 2:
            raise ValueError(
                "time_weight must be a 2-element list [1.0, MAX_WEIGHT] "
                "when default_time_weight=False."
            )
        min_w, max_w = time_weight
        min_w = float(min_w)
        max_w = float(max_w)
        if abs(min_w - 1.0) > 1e-8:
            raise ValueError(
                f"time_weight[0] must be 1.0 (got {min_w}). "
                "The minimum alpha(e) is fixed at 1.0."
            )
        if max_w < 1.0:
            raise ValueError(
                f"time_weight[1] must be >= 1.0 (got {max_w})."
            )
        max_weight = max_w

    # --- jurisdiction weights configuration ---
    jurisdiction_weights: Optional[Dict[str, float]] = None
    if jurisdictions is not None:
        if not isinstance(jurisdictions, dict):
            raise ValueError(
                "jurisdictions must be a dict of {jurisdiction_name: value}."
            )

        invalid = [name for name in jurisdictions.keys() if name not in VALID_JURISDICTIONS]
        if invalid:
            raise ValueError(
                "Invalid jurisdiction name(s) in 'jurisdictions': "
                + ", ".join(sorted(invalid))
            )

        jurisdiction_weights = {}
        for name, val in jurisdictions.items():
            if isinstance(val, str) and val.strip().lower() == "default":
                w = max_weight / 2.0
            else:
                try:
                    w = float(val)
                except (TypeError, ValueError) as exc:
                    # Chain the original error so the failed conversion stays visible.
                    raise ValueError(
                        f"Jurisdiction weight for {name!r} must be a float or 'Default'."
                    ) from exc
            jurisdiction_weights[name] = w

    # --- connect to Neo4j ---
    driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD))

    results_rows: List[Dict[str, Any]] = []
    case_label_counts = {"Good": 0, "Bad": 0, "Moderate": 0, "Unknown": 0}

    try:
        try:
            session = driver.session(
                database=NEO4J_DATABASE,
                notifications_min_severity="OFF",
            )
        except TypeError:
            # Older drivers may not support notifications_min_severity
            session = driver.session(database=NEO4J_DATABASE)

        with session as s:
            # --- global time stats (for default tmin/tmax) ---
            time_stats = _compute_time_quartiles(s)

            if default_tmin_tmax:
                q1_ord = time_stats.get("q1")
                tmax_global = time_stats.get("max")
                if q1_ord is None or tmax_global is None or tmax_global <= q1_ord:
                    # Degenerate or missing → no recency effect
                    tmin_ord = None
                    tmax_ord = None
                else:
                    tmin_ord = q1_ord
                    tmax_ord = tmax_global
            else:
                if tmin_tmax is None or len(tmin_tmax) != 2:
                    raise ValueError(
                        "tmin_tmax must be a 2-element list [tmin, tmax] "
                        "when default_tmin_tmax=False."
                    )
                tmin_raw, tmax_raw = tmin_tmax
                tmin_ord = _to_ordinal(tmin_raw)
                tmax_ord = _to_ordinal(tmax_raw)
                if tmin_ord is None or tmax_ord is None:
                    raise ValueError(
                        "Could not parse tmin or tmax into valid dates. "
                        "They can be date/datetime objects or ISO date strings."
                    )
                if tmax_ord <= tmin_ord:
                    raise ValueError(
                        f"tmax must be strictly greater than tmin (got tmin={tmin_raw}, tmax={tmax_raw})."
                    )

            # --- normalize label priority now that include_unknown is known ---
            priority_order = _normalize_priority_list(
                base_priority,
                include_unknown=include_unknown,
            )

            # --- count total cases to label ---
            total_cases = s.run(
                Q_COUNT_CASES,
                {"force": bool(force)},
            ).data()[0]["n"]

            if echo:
                print(f"Total cases to label: {total_cases}")
                if tmin_ord is not None and tmax_ord is not None:
                    tmin_dt = datetime.fromordinal(tmin_ord).date()
                    tmax_dt = datetime.fromordinal(tmax_ord).date()
                    print(
                        f"Using recency window tmin={tmin_dt.isoformat()}, "
                        f"tmax={tmax_dt.isoformat()}, max_weight={max_weight}, "
                        f"non_linear_recency_effect={non_linear_recency_effect}."
                    )
                else:
                    print(
                        "No valid recency window; all edges will have alpha(e) = 1.0 "
                        "(before any jurisdiction weight Ji)."
                    )
                if jurisdiction_weights:
                    print(
                        f"Using jurisdiction weights for {len(jurisdiction_weights)} jurisdiction(s)."
                    )

            # --- precompute citation scores on relevant CITES_TO edges ---
            _precompute_edge_scores(
                session=s,
                tmin_ord=tmin_ord,
                tmax_ord=tmax_ord,
                max_weight=max_weight,
                non_linear_recency_effect=non_linear_recency_effect,
                jurisdiction_weights=jurisdiction_weights,
                force=bool(force),
                echo=echo,
            )

            # --- main paging loop over Case nodes ---
            after = -1
            processed_cases = 0
            t0 = time.time()
            last_print = time.time()

            while True:
                case_rows = s.run(
                    Q_PAGE_CASES,
                    {
                        "after_id": after,
                        "limit": _CASE_BATCH_SIZE,
                        "force": bool(force),
                    },
                ).data()

                if not case_rows:
                    break

                for c_row in case_rows:
                    after = c_row["neo_id"]
                    case_id = c_row["case_id"]
                    case_name = c_row["case_name"] or ""

                    # --- fetch incoming CITES_TO edges for this case ---
                    edge_rows = s.run(
                        Q_INCOMING_EDGES_FOR_CASE,
                        {"case_id": case_id},
                    ).data()

                    edges_by_level: Dict[int, List[Dict[str, Any]]] = defaultdict(list)

                    for er in edge_rows:
                        lvl = er.get("court_level")
                        if lvl is None:
                            # If no court level, skip this edge for scoring
                            continue
                        try:
                            lvl_int = int(lvl)
                        except (TypeError, ValueError):
                            continue
                        if not (1 <= lvl_int <= 5):
                            continue

                        edges_by_level[lvl_int].append(
                            {
                                "label": er.get("label"),
                                "decision_date": er.get("decision_date"),
                                "jurisdiction_name": er.get("jurisdiction_name"),
                            }
                        )

                    # --- compute per-level metrics (counts, weights, proportions) ---
                    per_level_metrics: Dict[int, Dict[str, Any]] = {}
                    per_level_counts: Dict[int, Dict[str, int]] = {}

                    total_cites = 0

                    for lvl in range(1, 6):
                        lvl_edges = edges_by_level.get(lvl, [])
                        metrics = _compute_level_metrics(
                            lvl_edges,
                            include_unknown=include_unknown,
                            tmin_ord=tmin_ord,
                            tmax_ord=tmax_ord,
                            max_weight=max_weight,
                            non_linear_recency_effect=non_linear_recency_effect,
                            jurisdiction_weights=jurisdiction_weights,
                        )
                        per_level_metrics[lvl] = metrics

                        counts = metrics["counts"]
                        lvl_total = (
                            counts["Positive"]
                            + counts["Neutral"]
                            + counts["Negative"]
                            + counts["Unknown"]
                        )
                        per_level_counts[lvl] = {
                            "total": lvl_total,
                            "Positive": counts["Positive"],
                            "Neutral": counts["Neutral"],
                            "Negative": counts["Negative"],
                            "Unknown": counts["Unknown"],
                        }
                        total_cites += lvl_total

                    levels_with_cites = [
                        lvl for lvl in range(1, 6)
                        if per_level_counts[lvl]["total"] > 0
                    ]

                    # --- decide label for this case ---
                    used_lower_level = False
                    decision_level: Optional[int] = None  # numeric level 1–5
                    case_label: str
                    driver_label: Optional[str] = None

                    if total_cites == 0:
                        # No signal at all → label as "Unknown" at the case level
                        case_label = "Unknown"
                        decision_level = None
                        driver_label = None
                    else:
                        if not levels_with_cites:
                            # Should not happen if total_cites > 0, but be safe.
                            case_label = "Unknown"
                            decision_level = None
                            driver_label = None
                        else:
                            highest_level = min(levels_with_cites)

                            if not lower_level_court:
                                # Only consider highest-level court; if no threshold,
                                # classify as Moderate immediately.
                                metrics_h = per_level_metrics[highest_level]
                                label_h, driver_h = _decide_label_from_metrics(
                                    metrics_h,
                                    include_unknown=include_unknown,
                                    label_thresholds=label_thresholds,
                                    priority_order=priority_order,
                                )
                                if label_h is None:
                                    case_label = "Moderate"
                                    decision_level = highest_level
                                    driver_label = None
                                else:
                                    case_label = label_h
                                    decision_level = highest_level
                                    driver_label = driver_h
                            else:
                                # Walk from the highest court (lowest level number)
                                # down to the lowest court (highest level number).
                                case_label = "Moderate"
                                decision_level = levels_with_cites[-1]  # default
                                driver_label = None

                                for idx, lvl in enumerate(sorted(levels_with_cites)):
                                    lvl_metrics = per_level_metrics[lvl]
                                    lvl_label, lvl_driver = _decide_label_from_metrics(
                                        lvl_metrics,
                                        include_unknown=include_unknown,
                                        label_thresholds=label_thresholds,
                                        priority_order=priority_order,
                                    )
                                    if lvl_label is not None:
                                        case_label = lvl_label
                                        decision_level = lvl
                                        driver_label = lvl_driver
                                        # If idx > 0, we used a lower level than the highest
                                        used_lower_level = (idx > 0)
                                        break
                                else:
                                    # No level reached its threshold; keep "Moderate".
                                    # Set decision_level to the lowest court (largest
                                    # level number) that was checked.
                                    decision_level = sorted(levels_with_cites)[-1]
                                    driver_label = None
                                    used_lower_level = (len(levels_with_cites) > 1)

                    # --- build rationale text ---
                    label_rationale = _build_label_rationale(
                        case_name=case_name,
                        total_cites=total_cites,
                        per_level_counts=per_level_counts,
                        per_level_metrics=per_level_metrics,
                        decision_level=decision_level,
                        case_label=case_label,
                        driver_label=driver_label,
                        include_unknown=include_unknown,
                        used_lower_level=used_lower_level,
                        label_thresholds=label_thresholds,
                        priority_order=priority_order,
                    )

                    # --- convert decision level to court name for storage ---
                    decision_level_name = (
                        _court_level_to_name(decision_level)
                        if decision_level is not None
                        else None
                    )

                    # --- write back to Neo4j ---
                    s.run(
                        Q_WRITE_CASE_LABEL,
                        {
                            "case_id": case_id,
                            "case_label": case_label,
                            "decision_level": decision_level_name,
                            "label_rationale": label_rationale,
                        },
                    )

                    # --- update counters ---
                    case_label_counts[case_label] = (
                        case_label_counts.get(case_label, 0) + 1
                    )

                    # --- accumulate CSV row ---
                    if results_csv:
                        row_data: Dict[str, Any] = {}
                        row_data["Case ID"] = case_id
                        row_data["Case Name"] = case_name
                        row_data["Total Number of Citations"] = total_cites

                        for lvl in range(1, 6):
                            counts = per_level_counts[lvl]
                            metrics = per_level_metrics[lvl]
                            weights = metrics["weights"]
                            props = metrics["proportions"]

                            # Counts
                            row_data[
                                f"Number of Citations from Court Level {lvl}"
                            ] = counts["total"]
                            row_data[
                                f"Number of Positive Citations from Court Level {lvl}"
                            ] = counts["Positive"]
                            row_data[
                                f"Number of Neutral Citations from Court Level {lvl}"
                            ] = counts["Neutral"]
                            row_data[
                                f"Number of Negative Citations from Court Level {lvl}"
                            ] = counts["Negative"]
                            row_data[
                                f"Number of Unknown Citations from Court Level {lvl}"
                            ] = counts["Unknown"]

                            # Weights
                            row_data[
                                f"Positive Weight from Court Level {lvl}"
                            ] = float(weights["Positive"])
                            row_data[
                                f"Neutral Weight from Court Level {lvl}"
                            ] = float(weights["Neutral"])
                            row_data[
                                f"Negative Weight from Court Level {lvl}"
                            ] = float(weights["Negative"])
                            # For Unknown, if include_unknown=False this will be 0.0
                            row_data[
                                f"Unknown Weight from Court Level {lvl}"
                            ] = float(weights["Unknown"])

                            # Proportions
                            row_data[
                                f"Positive Proportion from Court Level {lvl}"
                            ] = float(props["Positive"])
                            row_data[
                                f"Neutral Proportion from Court Level {lvl}"
                            ] = float(props["Neutral"])
                            row_data[
                                f"Negative Proportion from Court Level {lvl}"
                            ] = float(props["Negative"])
                            row_data[
                                f"Unknown Proportion from Court Level {lvl}"
                            ] = float(props["Unknown"])

                        row_data["Court Level Decision"] = (
                            decision_level_name if decision_level_name is not None else ""
                        )
                        row_data["Case Label"] = case_label
                        row_data["Rationale"] = label_rationale

                        results_rows.append(row_data)

                    processed_cases += 1

                    # Progress printing every ~5 seconds
                    now = time.time()
                    if echo and (now - last_print >= 5.0):
                        print(
                            f"Labeled {processed_cases} / {total_cases} cases "
                            f"({(processed_cases / max(total_cases, 1)) * 100:.1f}%)."
                        )
                        last_print = now

            # --- final summary ---
            if echo:
                elapsed_min = (time.time() - t0) / 60.0
                print(f"\nCompleted labeling in {elapsed_min:.1f} minutes.")
                print("Case label counts:")
                print(f"  Good     : {case_label_counts.get('Good', 0)}")
                print(f"  Bad      : {case_label_counts.get('Bad', 0)}")
                print(f"  Moderate : {case_label_counts.get('Moderate', 0)}")
                print(f"  Unknown  : {case_label_counts.get('Unknown', 0)}")

            # --- write CSV if requested ---
            if results_csv:
                df = pd.DataFrame(results_rows)

                # Build an explicit column order for the CSV output
                columns: List[str] = [
                    "Case ID",
                    "Case Name",
                    "Total Number of Citations",
                ]
                for lvl in range(1, 6):
                    columns.extend(
                        [
                            f"Number of Citations from Court Level {lvl}",
                            f"Number of Positive Citations from Court Level {lvl}",
                            f"Number of Neutral Citations from Court Level {lvl}",
                            f"Number of Negative Citations from Court Level {lvl}",
                            f"Number of Unknown Citations from Court Level {lvl}",
                            f"Positive Weight from Court Level {lvl}",
                            f"Neutral Weight from Court Level {lvl}",
                            f"Negative Weight from Court Level {lvl}",
                            f"Unknown Weight from Court Level {lvl}",
                            f"Positive Proportion from Court Level {lvl}",
                            f"Neutral Proportion from Court Level {lvl}",
                            f"Negative Proportion from Court Level {lvl}",
                            f"Unknown Proportion from Court Level {lvl}",
                        ]
                    )
                columns.extend(
                    [
                        "Court Level Decision",
                        "Case Label",
                        "Rationale",
                    ]
                )

                # Ensure all columns exist
                for col in columns:
                    if col not in df.columns:
                        df[col] = ""

                df = df[columns]
                df.to_csv(results_csv_filename, index=False)
                if echo:
                    print(f"\nWrote case-level results CSV → {results_csv_filename}")

    finally:
        driver.close()

# =========================
# Example call (commented)
# =========================
# label_all_cases(
#     force=False,
#     echo=True,
#     lower_level_court=True,
#     include_unknown=True,
#     label_thresholds=None,
#     default_label_priority=True,
#     label_priority=None,
#     default_tmin_tmax=True,
#     tmin_tmax=None,
#     default_time_weight=True,
#     time_weight=None,
#     non_linear_recency_effect=False,
#     jurisdictions={
#         "Alabama": "Default",   # Ji = MAX_WEIGHT / 2
#         "California": 1.0,      # explicit Ji
#     },
#     results_csv=True,
#     results_csv_filename="case_labeled_results_continuous_time_with_jurisdictions.csv",
# )
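
The level walk in the labeling step can be read in isolation. The following is a minimal, simplified sketch rather than the script's own helper: `pick_label` and the `decide` callable are hypothetical names, with `decide` standing in for `_decide_label_from_metrics`, and the toy threshold rule in the usage lines is an assumption made purely for illustration.

```python
from typing import Callable, Dict, List, Optional, Tuple

# Hypothetical stand-in for _decide_label_from_metrics: takes one level's
# metrics and returns a label, or None if no threshold is met.
Decide = Callable[[dict], Optional[str]]


def pick_label(
    per_level_metrics: Dict[int, dict],
    levels_with_cites: List[int],
    decide: Decide,
    lower_level_court: bool = True,
) -> Tuple[str, Optional[int], bool]:
    """Return (case_label, decision_level, used_lower_level)."""
    if not levels_with_cites:
        return "Unknown", None, False

    levels = sorted(levels_with_cites)  # level 1 is the highest court
    if not lower_level_court:
        levels = levels[:1]  # only the highest court with citations

    for idx, lvl in enumerate(levels):
        label = decide(per_level_metrics[lvl])
        if label is not None:
            # idx > 0 means a court below the highest one drove the label
            return label, lvl, idx > 0

    # No level met a threshold: fall back to "Moderate" at the last level checked.
    return "Moderate", levels[-1], len(levels) > 1


# Toy rule (an assumption): "Good" if the positive proportion is at least 0.5.
def toy_decide(metrics: dict) -> Optional[str]:
    return "Good" if metrics["proportions"]["Positive"] >= 0.5 else None


metrics = {
    1: {"proportions": {"Positive": 0.2}},
    3: {"proportions": {"Positive": 0.8}},
}
print(pick_label(metrics, [1, 3], toy_decide))  # ('Good', 3, True)
```

With `lower_level_court=False`, the same input stops at level 1 and falls back to `"Moderate"`, which mirrors the single-level branch in the code above.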


## Example Run

In [7]:
label_all_cases(
    force=True,
    echo=True,
    lower_level_court=True,
    include_unknown=True,
    label_thresholds=None,
    default_label_priority=True,
    label_priority=None,
    default_tmin_tmax=True,
    tmin_tmax=None,
    default_time_weight=True,
    time_weight=None,
    non_linear_recency_effect=False
)

Total cases to label: 3648
Using recency window tmin=2001-06-14, tmax=2025-10-08, max_weight=2.5, non_linear_recency_effect=False.
Total CITES_TO edges to score (for this run): 5491
Finished computing citation scores for 5491 CITES_TO edge(s) in 0.0 minutes.
Labeled 50 / 3648 cases (1.4%).
Labeled 113 / 3648 cases (3.1%).
Labeled 182 / 3648 cases (5.0%).
Labeled 252 / 3648 cases (6.9%).
Labeled 324 / 3648 cases (8.9%).
Labeled 396 / 3648 cases (10.9%).
Labeled 467 / 3648 cases (12.8%).
Labeled 537 / 3648 cases (14.7%).
Labeled 606 / 3648 cases (16.6%).
Labeled 678 / 3648 cases (18.6%).
Labeled 749 / 3648 cases (20.5%).
Labeled 819 / 3648 cases (22.5%).
Labeled 891 / 3648 cases (24.4%).
Labeled 963 / 3648 cases (26.4%).
Labeled 1036 / 3648 cases (28.4%).
Labeled 1109 / 3648 cases (30.4%).
Labeled 1181 / 3648 cases (32.4%).
Labeled 1251 / 3648 cases (34.3%).
Labeled 1323 / 3648 cases (36.3%).
Labeled 1394 / 3648 cases (38.2%).
Labeled 1465 / 3648 cases (40.2%).
Labeled 1537 / 3648 cases 

In [8]:
# label_all_cases(
#     force=True,
#     echo=True,
#     lower_level_court=True,
#     include_unknown=True,
#     label_thresholds=None,
#     default_label_priority=True,
#     label_priority=None,
#     default_tmin_tmax=True,
#     tmin_tmax=None,
#     default_time_weight=True,
#     time_weight=None,
#     non_linear_recency_effect=False,
#     jurisdictions={
#         "Alabama": "Default",   # Ji = MAX_WEIGHT / 2
#         "U.S. Court of Appeals for the Sixth Circuit": "Default",   # Ji = MAX_WEIGHT / 2
#     },
#     results_csv=True,
#     results_csv_filename="case_labeled_results_continuous_time.csv",
# )
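
With `results_csv=True`, each case becomes one row whose columns follow a fixed order: three identifying columns, thirteen per-court-level columns for each of levels 1 to 5, and three closing columns. The sketch below rebuilds that column list and writes rows with the standard library's `csv.DictWriter` instead of pandas. This is a swapped-in technique, chosen only to keep the example dependency-free. `build_columns` and `rows_to_csv` are hypothetical helper names, not part of the script.

```python
import csv
import io
from typing import Any, Dict, List

KINDS = ["Positive", "Neutral", "Negative", "Unknown"]


def build_columns(levels: range = range(1, 6)) -> List[str]:
    """Rebuild the CSV column order used by the labeling script."""
    cols = ["Case ID", "Case Name", "Total Number of Citations"]
    for lvl in levels:
        cols.append(f"Number of Citations from Court Level {lvl}")
        cols += [f"Number of {k} Citations from Court Level {lvl}" for k in KINDS]
        cols += [f"{k} Weight from Court Level {lvl}" for k in KINDS]
        cols += [f"{k} Proportion from Court Level {lvl}" for k in KINDS]
    cols += ["Court Level Decision", "Case Label", "Rationale"]
    return cols


def rows_to_csv(rows: List[Dict[str, Any]]) -> str:
    """Serialize rows to CSV text, filling missing columns with ''."""
    buf = io.StringIO()
    # restval="" plays the role of the script's "ensure all columns exist"
    # step; extrasaction="ignore" silently drops keys outside the layout.
    writer = csv.DictWriter(
        buf,
        fieldnames=build_columns(),
        restval="",
        extrasaction="ignore",
        lineterminator="\n",
    )
    writer.writeheader()
    writer.writerows(rows)
    return buf.getvalue()


columns = build_columns()
print(len(columns))  # 71
text = rows_to_csv([{"Case ID": 1, "Case Label": "Good"}])
```

Generating the names from one helper keeps the header and the per-row keys from drifting apart, which is the main risk when the same f-strings are written out twice.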

In [9]:
# thresholds = {
#     "Pos_p": 0.55,  
#     "Neg_p": 0.75,  
#     "Neu_p": 0.75,  
#     "Unk_p": 0.55,  
# }

# label_all_cases(
#     force=True,
#     echo=True,
#     lower_level_court=True,
#     include_unknown=True,
#     label_thresholds=thresholds,
#     default_label_priority=True,
#     label_priority=None,
#     default_tmin_tmax=True,
#     tmin_tmax=None,
#     default_time_weight=True,
#     time_weight=None,
#     non_linear_recency_effect=False,
#     jurisdictions={
#         "Alabama": "Default",   # Ji = MAX_WEIGHT / 2
#         "U.S. Court of Appeals for the Sixth Circuit": "Default",   # Ji = MAX_WEIGHT / 2
#     },
#     results_csv=True,
#     results_csv_filename="case_labeled_results_continuous_time.csv",
# )
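
The threshold keys in the cell above (`Pos_p`, `Neg_p`, `Neu_p`, `Unk_p`) presumably set a minimum weighted proportion for each treatment label. The internals of `_decide_label_from_metrics` are not shown in this section, so the following is only an illustrative sketch under stated assumptions: a label fires when its proportion is at least its threshold, and candidates are tried in priority order. The treatment-to-case-label mapping and the priority order used here are guesses, not the script's confirmed defaults. `decide_from_proportions` and `LABEL_MAP` are hypothetical names.

```python
from typing import Dict, List, Optional, Tuple

# Assumed mapping from treatment label to (threshold key, case label).
LABEL_MAP: Dict[str, Tuple[str, str]] = {
    "Positive": ("Pos_p", "Good"),
    "Negative": ("Neg_p", "Bad"),
    "Neutral": ("Neu_p", "Moderate"),
    "Unknown": ("Unk_p", "Unknown"),
}


def decide_from_proportions(
    proportions: Dict[str, float],
    thresholds: Dict[str, float],
    priority_order: List[str],
    include_unknown: bool = True,
) -> Tuple[Optional[str], Optional[str]]:
    """Return (case_label, driver_label), or (None, None) if nothing qualifies."""
    for treatment in priority_order:
        if treatment == "Unknown" and not include_unknown:
            continue
        key, case_label = LABEL_MAP[treatment]
        # Assumed rule: the proportion must meet or exceed its threshold.
        if proportions.get(treatment, 0.0) >= thresholds[key]:
            return case_label, treatment
    return None, None


thresholds = {"Pos_p": 0.55, "Neg_p": 0.75, "Neu_p": 0.75, "Unk_p": 0.55}
priority = ["Negative", "Positive", "Neutral", "Unknown"]  # assumed order

props = {"Positive": 0.60, "Neutral": 0.25, "Negative": 0.10, "Unknown": 0.05}
print(decide_from_proportions(props, thresholds, priority))  # ('Good', 'Positive')
```

Under these assumptions, lowering `Pos_p` relative to `Neg_p` makes a "Good" label easier to reach than a "Bad" one, which is the kind of asymmetry the custom thresholds in the cell above appear to encode.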