# Version 2 — Population field → gravity-based city nuclei

**Goal:** Start from a global population total and a spatial density field. Let density peaks “attract” nearby population into city nuclei until only cities with ≥ `min_city_pop` remain.

**Method (high-level):**

1. **Density field:** mixture of Gaussians + uniform baseline + low-frequency noise.
2. **Allocate people:** multinomial draw over grid cells ⇒ exact `total_population`.
3. **Detect peaks:** local maxima above a percentile, with non-maximum suppression (min separation).
4. **Gravity assignment:** each cell’s people go to peak `j` with attractiveness
   `A_j = w_j / (dist + eps)^γ`, where `w_j` scales with the peak’s density.
5. **Prune & reassign:** iteratively remove cities with pop `< min_city_pop`, reassign their cells.
6. **Enforce city count:** if after pruning we have `< n_cities`, relax thresholds and/or split the largest city into two viable parts (both ≥ `min_city_pop`) and reassign, until `n_cities` is met or no further split is possible.
7. **City center:** population-weighted centroid of its assigned cells.

**Outputs:**

* `nodes.csv` → `id, x_km, y_km, pop, n_cells_assigned, radius_km`
* `meta.json` → full config, metrics, generator info, hashes
* `preview.png` → scatter (size/color = population, top-3 annotated)
* `population_heatmap.png` → heatmap of allocated people per grid cell

---

## Key parameters (in `V2Config`)

* **Global/population**

  * `total_population` (e.g., `5_000_000`)
  * `seed`
* **Region/discretization**

  * `bbox_km = (minx, miny, maxx, maxy)`
  * `grid_res_km` (cell size; smaller ⇒ more cells/finer detail)
* **Density field**

  * `n_centers` (mixture components)
  * `center_sigma_km_min`, `center_sigma_km_max` (spread per center)
  * `baseline_frac` (uniform floor)
  * `noise_amp`, `noise_grid` (low-freq variability)
* **Peak detection**

  * `peaks_percentile` (initial threshold)
  * `min_peak_separation_km`
* **Gravity assignment**

  * `gamma` (distance exponent), `eps_km` (softening)
* **City pruning / constraints**

  * `min_city_pop` (e.g., `1_000`)
  * `n_cities` (minimum number of cities to produce)
  * `peaks_percentile_floor`, `separation_shrink`, `max_relax_iters` (how we relax detection to grow city count)

---

## Validation

* **Population conservation:** `sum(nodes.pop) == total_population`.
* **Threshold compliance:** `min(nodes.pop) ≥ min_city_pop`.
* **City count:** `len(nodes) ≥ n_cities` (otherwise raise with guidance).
* **Sanity metrics (stored in meta):** pop percentiles, number of peaks used, relaxation summary, grid info.

---

## Tuning tips

* Too few cities? Try **lower** `min_city_pop`, **finer** `grid_res_km`, **lower** `peaks_percentile`, **smaller** `min_peak_separation_km`, or **increase** `n_centers`/reduce `center_sigma`.
* Cities too fragmented? **Increase** `min_peak_separation_km` or **raise** `peaks_percentile`.
* Over-dominant mega-city? **Increase** `gamma` (stronger distance penalty) or **reduce** center sigmas near that core.

---

## Usage

* Run the V2 notebook cell; adjust `V2Config` (especially `total_population`, `grid_res_km`, `min_city_pop`, `n_cities`).
* Artifacts are written to `out_dir`.


In [11]:
"""
Jupyter notebook cell — Version 2 (Nodes only)
Population density → city nuclei via gravity-like clustering

Goal: Start from a total population spread over space with higher/lower density areas,
then let density peaks attract nearby population into city nuclei until only cities with
≥ min_city_pop remain (e.g., 1,000 inhabitants).

- Field: Mixture-of-Gaussians density + low‑frequency noise + uniform baseline
- Allocation: Multinomial over grid cells (so total people is exact)
- Peaks: Local maxima of the density with non‑maximum suppression (min separation)
- Assignment: Each cell’s population goes to the peak with max attractiveness
             A_j = weight_j / (distance + eps)^gamma, weight_j ∝ peak density
- Pruning: Iteratively remove cities with pop < min_city_pop and reassign their cells
- City center: Population‑weighted centroid of assigned cells
- Outputs: nodes.csv, meta.json, preview.png (color = population; top‑3 annotated)

Usage: run this cell. Edit `V2Config` at the end.
"""
from __future__ import annotations

import json
import os
import time
import hashlib
from dataclasses import dataclass
from typing import Tuple, Dict, Any, List

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.ticker import StrMethodFormatter


# ------------------------------
# Config
# ------------------------------
@dataclass
class V2Config:
    seed: int = 42
    total_population: int = 5_000_000
    bbox_km: Tuple[float, float, float, float] = (0.0, 0.0, 200.0, 200.0)  # (minx, miny, maxx, maxy)

    # Grid
    grid_res_km: float = 2.0  # cell size (km). 200 km / 2 km ⇒ 100×100 grid

    # Density mixture (centers auto if None)
    n_centers: int = 4
    center_sigma_km_min: float = 12.0
    center_sigma_km_max: float = 35.0
    baseline_frac: float = 0.05   # baseline level as a fraction of mean Gaussian field

    # Low‑frequency noise (coarse grid upsampled bilinearly)
    noise_amp: float = 0.30       # multiplicative amplitude; 0.3 ⇒ ×(1±0.3)
    noise_grid: Tuple[int, int] = (10, 10)  # (rows, cols) of coarse noise grid

    # Peak detection (on density field)
    peaks_percentile: float = 92.0  # keep local maxima above this percentile
    min_peak_separation_km: float = 8.0

    # Gravity assignment
    gamma: float = 1.7             # distance exponent in attractiveness
    eps_km: float = 0.5            # small distance softening (km)

    # City pruning
    min_city_pop: int = 1_000

    n_cities: int = 8                    # target minimum number of cities
    peaks_percentile_floor: float = 60.0 # how far we can relax the threshold
    separation_shrink: float = 0.85      # shrink factor for min peak separation per relax step
    max_relax_iters: int = 10            # max relax attempts to reach n_cities


    # Output & metadata
    out_dir: str = "maps/sv1.2/dv0.1_v2_density_cities"
    crs: str = "EPSG:3857"
    schema_version: str = "1.2"  # optional new cols (`radius_km`, `n_cells_assigned`)
    dataset_version: str = "0.1"


# ------------------------------
# Helpers
# ------------------------------

def set_seed(seed: int) -> None:
    np.random.seed(seed)


def _bbox_arrays(cfg: V2Config):
    minx, miny, maxx, maxy = cfg.bbox_km
    W, H = maxx - minx, maxy - miny
    nx = int(np.ceil(W / cfg.grid_res_km))
    ny = int(np.ceil(H / cfg.grid_res_km))
    x = minx + (np.arange(nx) + 0.5) * cfg.grid_res_km
    y = miny + (np.arange(ny) + 0.5) * cfg.grid_res_km
    X, Y = np.meshgrid(x, y)  # shape (ny, nx)
    return X, Y, x, y, nx, ny


def _dirichlet_weights(k: int) -> np.ndarray:
    w = np.random.rand(k)
    w = w + 0.01  # avoid zeros
    return w / w.sum()


def _upsample_bilinear(coarse: np.ndarray, ny: int, nx: int) -> np.ndarray:
    """Simple bilinear upsample using two 1D interpolations (no SciPy)."""
    cy, cx = coarse.shape
    x_old = np.linspace(0.0, 1.0, cx)
    x_new = np.linspace(0.0, 1.0, nx)
    # interp along x for each row
    tmp = np.array([np.interp(x_new, x_old, coarse[i, :]) for i in range(cy)])  # (cy, nx)
    y_old = np.linspace(0.0, 1.0, cy)
    y_new = np.linspace(0.0, 1.0, ny)
    # interp along y for each column
    out = np.array([np.interp(y_new, y_old, tmp[:, j]) for j in range(nx)]).T  # (ny, nx)
    return out


def generate_density_field(cfg: V2Config) -> Tuple[np.ndarray, Dict[str, Any]]:
    X, Y, x, y, nx, ny = _bbox_arrays(cfg)

    # Mixture of Gaussians
    centers = np.column_stack([
        np.random.uniform(x.min(), x.max(), size=cfg.n_centers),
        np.random.uniform(y.min(), y.max(), size=cfg.n_centers),
    ])
    sigmas = np.random.uniform(cfg.center_sigma_km_min, cfg.center_sigma_km_max, size=cfg.n_centers)
    weights = _dirichlet_weights(cfg.n_centers)

    G = np.zeros((ny, nx), dtype=float)
    for (cx, cy), s, w in zip(centers, sigmas, weights):
        G += w * np.exp(-((X - cx) ** 2 + (Y - cy) ** 2) / (2 * s * s))

    # Baseline
    baseline = cfg.baseline_frac * (G.mean() + 1e-9)

    # Low-frequency noise
    if cfg.noise_amp > 0:
        ngy, ngx = cfg.noise_grid
        coarse = np.random.rand(ngy, ngx)
        noise = _upsample_bilinear(coarse, ny, nx)
        noise = (noise - 0.5) * 2.0  # ~[-1,1]
        field = (G + baseline) * (1.0 + cfg.noise_amp * noise)
        field = np.clip(field, a_min=baseline * 0.1, a_max=None)
    else:
        field = G + baseline

    info = {
        "centers": centers.tolist(),
        "sigmas": sigmas.tolist(),
        "weights": weights.tolist(),
        "baseline": baseline,
        "noise_amp": cfg.noise_amp,
        "noise_grid": cfg.noise_grid,
    }
    return field, info


def _find_local_maxima(field: np.ndarray, percentile: float, min_sep_cells: int) -> List[Tuple[int, int, float]]:
    ny, nx = field.shape
    thr = np.percentile(field, percentile)
    peaks: List[Tuple[int, int, float]] = []
    for i in range(1, ny - 1):
        for j in range(1, nx - 1):
            v = field[i, j]
            if v < thr:
                continue
            nb = field[i-1:i+2, j-1:j+2]
            if v >= nb.max():
                peaks.append((i, j, float(v)))
    # Non-maximum suppression by min_sep_cells (greedy)
    peaks.sort(key=lambda t: t[2], reverse=True)
    accepted: List[Tuple[int, int, float]] = []
    for i, j, v in peaks:
        ok = True
        for ia, ja, _ in accepted:
            if (i - ia) ** 2 + (j - ja) ** 2 < (min_sep_cells ** 2):
                ok = False
                break
        if ok:
            accepted.append((i, j, v))
    return accepted


def _assign_cells_to_peaks(counts: np.ndarray, field: np.ndarray, peaks: List[Tuple[int, int, float]], cfg: V2Config) -> np.ndarray:
    """Return array of shape (ny, nx) with city index per cell (−1 if no peaks)."""
    ny, nx = counts.shape
    if not peaks:
        return -np.ones((ny, nx), dtype=int)

    # Precompute peak weights and coordinates in km
    X, Y, x, y, nx2, ny2 = _bbox_arrays(cfg)
    assert nx2 == nx and ny2 == ny

    px = np.array([x[int(j)] for (_, j, _) in peaks])  # careful: peaks store (row=i, col=j)
    py = np.array([y[int(i)] for (i, _, _) in peaks])

    peak_weight = np.array([v for (_, _, v) in peaks], dtype=float)
    peak_weight = peak_weight / (peak_weight.max() + 1e-12)

    # We compute attractiveness for each peak: w / (dist + eps)^gamma
    eps2 = (cfg.eps_km ** 2)

    # Flatten coordinates for vectorization
    XX = X.reshape(-1)
    YY = Y.reshape(-1)
    counts_flat = counts.reshape(-1)

    # Only consider cells with people
    active_idx = np.where(counts_flat > 0)[0]
    XXa = XX[active_idx][:, None]
    YYa = YY[active_idx][:, None]

    # distances to peaks (active cells × n_peaks)
    dx = XXa - px[None, :]
    dy = YYa - py[None, :]
    dist2 = dx * dx + dy * dy

    attractiveness = peak_weight[None, :] / np.power(dist2 + eps2, cfg.gamma / 2.0)
    best = np.argmax(attractiveness, axis=1)

    assign = -np.ones(XX.shape[0], dtype=int)
    assign[active_idx] = best
    return assign.reshape(ny, nx)


def _city_stats_from_assignment(assign: np.ndarray, counts: np.ndarray, cfg: V2Config) -> Tuple[pd.DataFrame, Dict[int, np.ndarray]]:
    """Compute city populations, centroids, radius, and keep cell masks per city."""
    X, Y, *_ = _bbox_arrays(cfg)
    ny, nx = counts.shape

    city_ids = np.unique(assign[assign >= 0])
    masks: Dict[int, np.ndarray] = {}
    rows = []
    for cid in city_ids:
        mask = assign == cid
        pop = int(counts[mask].sum())
        if pop <= 0:
            continue
        masks[cid] = mask
        # Pop-weighted centroid
        w = counts[mask].astype(float)
        xs = X[mask]
        ys = Y[mask]
        x_c = float((w * xs).sum() / w.sum())
        y_c = float((w * ys).sum() / w.sum())
        n_cells = int(mask.sum())
        area_km2 = n_cells * (cfg.grid_res_km ** 2)
        radius_km = float(np.sqrt(area_km2 / np.pi))
        rows.append({"city_id": int(cid), "x_km": x_c, "y_km": y_c, "pop": pop, "n_cells_assigned": n_cells, "radius_km": radius_km})

    df = pd.DataFrame(rows).sort_values("pop", ascending=False).reset_index(drop=True)
    return df, masks


def _prune_and_reassign(assign: np.ndarray, counts: np.ndarray, peaks: List[Tuple[int, int, float]], cfg: V2Config) -> Tuple[np.ndarray, List[int]]:
    """Iteratively remove cities below threshold and reassign their cells to survivors."""
    ny, nx = counts.shape
    while True:
        df, masks = _city_stats_from_assignment(assign, counts, cfg)
        if df.empty:
            raise RuntimeError("No cities formed — check parameters.")
        low = df[df["pop"] < cfg.min_city_pop]
        if low.empty or len(df) == 1:
            # Done
            survivors = df["city_id"].tolist()
            return assign, survivors
        # Remove the smallest city under threshold
        remove_id = int(low.sort_values("pop").iloc[0]["city_id"])
        # Reassign its cells to the best among survivors
        survivors = [int(cid) for cid in df["city_id"].tolist() if cid != remove_id]

        # Build survivors peak subset
        surv_peaks = [peaks[cid] for cid in survivors]
        # Temporarily set these cells to -1 to be reassigned
        rem_mask = masks[remove_id]
        assign[rem_mask] = -1

        # Reassign only the removed cells by recomputing best survivor for those cells
        # Compute attractiveness for survivors
        X, Y, x, y, nx2, ny2 = _bbox_arrays(cfg)
        px = np.array([x[int(j)] for (i, j, v) in surv_peaks])
        py = np.array([y[int(i)] for (i, j, v) in surv_peaks])
        peak_weight = np.array([v for (i, j, v) in surv_peaks], dtype=float)
        peak_weight = peak_weight / (peak_weight.max() + 1e-12)
        eps2 = (cfg.eps_km ** 2)

        idx_cells = np.where(rem_mask.reshape(-1))[0]
        XX = X.reshape(-1)[idx_cells][:, None]
        YY = Y.reshape(-1)[idx_cells][:, None]
        dx = XX - px[None, :]
        dy = YY - py[None, :]
        dist2 = dx * dx + dy * dy
        attractiveness = peak_weight[None, :] / np.power(dist2 + eps2, cfg.gamma / 2.0)
        best = np.argmax(attractiveness, axis=1)
        reassigned = np.array([survivors[b] for b in best], dtype=int)

        # Write back
        flat_assign = assign.reshape(-1)
        flat_assign[idx_cells] = reassigned
        assign = flat_assign.reshape(ny, nx)


# ------------------------------
# Main generator / validator / saver
# ------------------------------

def generate_nodes_v2(cfg: V2Config) -> Tuple[pd.DataFrame, Dict[str, Any]]:
    """
    Guarantee >= cfg.n_cities with each city >= cfg.min_city_pop by:
      1) detecting density peaks (relaxing thresholds if needed),
      2) assigning by gravity,
      3) pruning tiny cities,
      4) if still short, SPLITTING the largest city into two (by weighted median along the widest axis)
         and repeating assignment until the target count is reached or no city can be split further.
    """
    set_seed(cfg.seed)

    # ----- Density + allocation -----
    field, field_info = generate_density_field(cfg)
    probs = field / field.sum()
    X, Y, x, y, nx, ny = _bbox_arrays(cfg)
    counts = np.random.multinomial(cfg.total_population, probs.reshape(-1)).reshape(ny, nx)

    def find_peaks(percentile: float, sep_cells: int) -> List[Tuple[int, int, float]]:
        return _find_local_maxima(field, percentile, max(1, sep_cells))

    def assign_and_stats(peaks: List[Tuple[int, int, float]]):
        assign = _assign_cells_to_peaks(counts, field, peaks, cfg)
        df_cities, masks = _city_stats_from_assignment(assign, counts, cfg)
        return assign, df_cities, masks

    def prune_below_threshold(assign: np.ndarray, peaks: List[Tuple[int, int, float]]):
        # use existing prune (merges tiny cities into survivors)
        assign2, survivors = _prune_and_reassign(assign.copy(), counts, peaks, cfg)
        df2, masks2 = _city_stats_from_assignment(assign2, counts, cfg)
        return assign2, df2, masks2

    def split_city_mask(mask: np.ndarray) -> List[Tuple[int, int, float]] | None:
        """Split one city into two by weighted median along widest axis; return two new peak tuples (i,j,value)."""
        w = counts[mask].astype(float)
        if w.sum() < 2 * cfg.min_city_pop or mask.sum() < 2:
            return None
        xs = X[mask]; ys = Y[mask]
        # choose axis with larger variance
        varx, vary = np.var(xs, ddof=0), np.var(ys, ddof=0)
        coord = xs if varx >= vary else ys
        order = np.argsort(coord)
        w_sorted = w[order]
        xs_sorted, ys_sorted = xs[order], ys[order]
        csum = np.cumsum(w_sorted)
        # cut near half but ensure both sides >= min_city_pop
        total = csum[-1]
        cut_idx = np.searchsorted(csum, total / 2.0)
        # expand cut to meet threshold
        left_ok = lambda k: csum[k] >= cfg.min_city_pop
        right_ok = lambda k: (total - csum[k]) >= cfg.min_city_pop
        k = int(np.clip(cut_idx, 1, len(w_sorted) - 2))
        moved = True
        while moved:
            moved = False
            if not left_ok(k):
                k += 1; moved = True
            if not right_ok(k):
                k -= 1; moved = True
            if k <= 0 or k >= len(w_sorted) - 1:
                return None  # cannot split with thresholds
        # weighted centroids for two parts
        wL, wR = w_sorted[:k], w_sorted[k:]
        xL = float(np.average(xs_sorted[:k], weights=wL))
        yL = float(np.average(ys_sorted[:k], weights=wL))
        xR = float(np.average(xs_sorted[k:],  weights=wR))
        yR = float(np.average(ys_sorted[k:],  weights=wR))
        # snap to nearest grid cells
        jL, iL = int(np.argmin(np.abs(x - xL))), int(np.argmin(np.abs(y - yL)))
        jR, iR = int(np.argmin(np.abs(x - xR))), int(np.argmin(np.abs(y - yR)))
        pL = (iL, jL, float(field[iL, jL]))
        pR = (iR, jR, float(field[iR, jR]))
        return [pL, pR]

    # ---- 1) initial peaks with relaxation ----
    current_pct = float(getattr(cfg, "peaks_percentile", 92.0))
    current_sep = int(round(cfg.min_peak_separation_km / cfg.grid_res_km))
    pct_floor = float(getattr(cfg, "peaks_percentile_floor", 60.0))
    shrink = float(getattr(cfg, "separation_shrink", 0.85))
    max_relax = int(getattr(cfg, "max_relax_iters", 10))

    peaks = find_peaks(current_pct, current_sep)
    for _ in range(max_relax):
        assign, df, masks = assign_and_stats(peaks)
        # prune tiny cities (merge) only if we still have >= n_cities afterwards
        assign_p, df_p, masks_p = prune_below_threshold(assign, peaks)
        df = df_p; masks = masks_p; assign = assign_p
        if len(df) >= cfg.n_cities:
            break
        # relax detection if we can
        new_pct = max(pct_floor, current_pct - 5.0)
        new_sep = max(1, int(np.ceil(current_sep * shrink)))
        new_peaks = find_peaks(new_pct, new_sep)
        if len(new_peaks) > len(peaks):  # only accept if we actually got more
            peaks, current_pct, current_sep = new_peaks, new_pct, new_sep
        else:
            break  # no further improvement

    # ---- 2) enforce minimum by splitting largest cities if needed ----
    # always recompute with current peaks
    assign, df, masks = assign_and_stats(peaks)
    # prune tiny ones first (merge them up)
    assign, df, masks = prune_below_threshold(assign, peaks)

    # if still short, repeatedly split the largest splittable city
    attempts = 0
    while len(df) < cfg.n_cities:
        attempts += 1
        if attempts > 200:  # safety
            break
        # pick largest city that can be split
        df_sorted = df.sort_values("pop", ascending=False)
        split_done = False
        for _, row in df_sorted.iterrows():
            cid = int(row["city_id"])
            mask = masks[cid]
            new_two = split_city_mask(mask)
            if new_two is None:
                continue
            # replace this city's peak with two new peaks
            peaks = [p for idx, p in enumerate(peaks) if idx != cid] + new_two
            # reassign & prune
            assign, df, masks = assign_and_stats(peaks)
            assign, df, masks = prune_below_threshold(assign, peaks)
            split_done = True
            if len(df) >= cfg.n_cities:
                break
        if not split_done:
            break  # no city can be split further while respecting min_city_pop

    if len(df) < cfg.n_cities:
        raise RuntimeError(f"Could not reach the target number of cities (got {len(df)}, want {cfg.n_cities}). "
                           f"Try decreasing min_city_pop, increasing total_population, or using finer grid_res_km.")

    # ---- finalize nodes ----
    df_cities = df
    assert int(df_cities["pop"].sum()) == int(cfg.total_population)

    df_nodes = df_cities.rename(columns={"city_id": "id"})[
        ["id", "x_km", "y_km", "pop", "n_cells_assigned", "radius_km"]
    ].copy()
    df_nodes["id"] = df_nodes["id"].astype(int)

    extras = {
        "grid": {"nx": nx, "ny": ny, "res_km": cfg.grid_res_km},
        "peaks_count_used": int(len(peaks)),
        "relaxation": {
            "final_percentile": current_pct,
            "final_sep_cells": current_sep,
            "target_n_cities": int(cfg.n_cities),
            "achieved_n_cities": int(len(df_nodes)),
        },
        "field_info": field_info,
        "counts": counts,
        "field": field,
    }
    return df_nodes, extras

def validate_nodes(df_nodes: pd.DataFrame, cfg: V2Config) -> Dict[str, Any]:
    metrics: Dict[str, Any] = {}
    n = len(df_nodes)
    if n == 0:
        raise AssertionError("No cities produced")
    if (df_nodes["pop"] < cfg.min_city_pop).any():
        raise AssertionError("Found a city below min_city_pop after pruning")
    metrics["n_cities"] = int(n)
    metrics["total_population"] = int(df_nodes["pop"].sum())
    metrics["pop_percentiles"] = {q: int(np.percentile(df_nodes["pop"], q)) for q in (5, 25, 50, 75, 90, 95, 99)}
    return metrics

def preview_nodes(df_nodes: pd.DataFrame, cfg: V2Config, save_path: str) -> None:
    minx, miny, maxx, maxy = cfg.bbox_km

    plt.figure(figsize=(6, 6))
    vmax = df_nodes["pop"].max()
    sc = plt.scatter(
        df_nodes["x_km"], df_nodes["y_km"],
        s=10 + 90 * np.sqrt(df_nodes["pop"].values / vmax),
        c=df_nodes["pop"].values.astype(float),
    )
    cbar = plt.colorbar(sc)
    cbar.set_label("Population")
    try:
        cbar.ax.yaxis.set_major_formatter(StrMethodFormatter('{x:,.0f}'))
    except Exception:
        pass

    # Annotate top‑3 by population
    top3 = df_nodes.nlargest(3, "pop").copy()
    dx = 0.01 * (maxx - minx)
    dy = 0.01 * (maxy - miny)
    for _, row in top3.iterrows():
        label = f"{int(row['pop']):,}"
        plt.text(
            row["x_km"] + dx,
            row["y_km"] + dy,
            label,
            fontsize=8,
            ha="left",
            va="bottom",
            bbox=dict(boxstyle="round,pad=0.2", fc="white", ec="none", alpha=0.7),
        )

    plt.title("Nodes — V2 (density → gravity cities)")
    plt.xlabel("x (km)")
    plt.ylabel("y (km)")
    plt.xlim(minx, maxx)
    plt.ylim(miny, maxy)
    plt.gca().set_aspect("equal", adjustable="box")
    plt.tight_layout()
    plt.savefig(save_path, dpi=150)
    plt.close()

def compute_metrics_hash(metrics: Dict[str, Any]) -> str:
    blob = json.dumps(metrics, sort_keys=True).encode("utf-8")
    return hashlib.sha256(blob).hexdigest()[:16]

def save_artifacts(df_nodes: pd.DataFrame, cfg: V2Config, metrics: Dict[str, Any], extras: Dict[str, Any]) -> Dict[str, str]:
    import matplotlib.pyplot as plt
    from matplotlib.ticker import StrMethodFormatter

    os.makedirs(cfg.out_dir, exist_ok=True)
    nodes_path = os.path.join(cfg.out_dir, "nodes.csv")
    preview_path = os.path.join(cfg.out_dir, "preview.png")
    heatmap_path = os.path.join(cfg.out_dir, "population_heatmap.png")
    meta_path = os.path.join(cfg.out_dir, "meta.json")

    # Save nodes CSV
    df_nodes.to_csv(nodes_path, index=False)

    # Existing scatter preview of cities
    preview_nodes(df_nodes, cfg, preview_path)

    # --- NEW: heatmap of individual distribution (grid cell counts) ---
    counts = extras.get("counts", None)  # expected shape (ny, nx)
    if counts is not None:
        minx, miny, maxx, maxy = cfg.bbox_km
        plt.figure(figsize=(6, 6))
        # imshow with spatial extent to align axes with km coordinates
        plt.imshow(
            counts,
            origin="lower",
            extent=[minx, maxx, miny, maxy],
            aspect="equal",
        )
        cbar = plt.colorbar()
        cbar.set_label("People per cell")
        try:
            cbar.ax.yaxis.set_major_formatter(StrMethodFormatter('{x:,.0f}'))
        except Exception:
            pass
        plt.title("Population distribution (heatmap)")
        plt.xlabel("x (km)")
        plt.ylabel("y (km)")
        plt.tight_layout()
        plt.savefig(heatmap_path, dpi=150)
        plt.close()

    # Meta
    meta = {
        "schema_version": cfg.schema_version,
        "dataset_version": cfg.dataset_version,
        "crs": cfg.crs,
        "seed": cfg.seed,
        "generator": {
            "name": "nodes_v2_density_to_cities",
            "params": {
                "total_population": cfg.total_population,
                "bbox_km": cfg.bbox_km,
                "grid_res_km": cfg.grid_res_km,
                "n_centers": cfg.n_centers,
                "center_sigma_km_min": cfg.center_sigma_km_min,
                "center_sigma_km_max": cfg.center_sigma_km_max,
                "baseline_frac": cfg.baseline_frac,
                "noise_amp": cfg.noise_amp,
                "noise_grid": cfg.noise_grid,
                "peaks_percentile": cfg.peaks_percentile,
                "min_peak_separation_km": cfg.min_peak_separation_km,
                "gamma": cfg.gamma,
                "eps_km": cfg.eps_km,
                "min_city_pop": cfg.min_city_pop,
            },
        },
        "extras_summary": {
            "grid": extras.get("grid"),
            "peaks_count": extras.get("peaks_count"),
        },
        "artifacts": {
            "nodes_csv": nodes_path,
            "preview_png": preview_path,
            "population_heatmap_png": heatmap_path if counts is not None else None,
        },
        "metrics": metrics,
        "metrics_hash": compute_metrics_hash(metrics),
        "created_at_utc": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
    }
    with open(meta_path, "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2)

    return {"nodes": nodes_path, "preview": preview_path, "heatmap": (heatmap_path if counts is not None else None), "meta": meta_path}

# ------------------------------
# Orchestration
# ------------------------------

def main(cfg: V2Config | None = None) -> pd.DataFrame:
    cfg = cfg or V2Config()
    set_seed(cfg.seed)

    df_nodes, extras = generate_nodes_v2(cfg)
    metrics = validate_nodes(df_nodes, cfg)
    paths = save_artifacts(df_nodes, cfg, metrics, extras)

    print("\n[Nodes V2] Build complete:\n" + "-" * 40)
    print(f"Cities: {len(df_nodes)} | Total pop: {metrics['total_population']:,}")
    print(f"Saved: nodes → {paths['nodes']}\n       preview → {paths['preview']}\n       meta → {paths['meta']}")
    print(f"Metrics hash: {compute_metrics_hash(metrics)}")
    return df_nodes


# ------------------------------
# Run
# ------------------------------
_cfg = V2Config(
    seed=42,
    total_population=5_000_000,
    bbox_km=(0.0, 0.0, 200.0, 200.0),
    grid_res_km=2.0,
    n_centers=4,
    center_sigma_km_min=12.0,
    center_sigma_km_max=35.0,
    baseline_frac=0.05,
    noise_amp=0.30,
    noise_grid=(10, 10),
    peaks_percentile=92.0,
    min_peak_separation_km=8.0,
    gamma=1.7,
    eps_km=0.5,
    min_city_pop=1_000,
    n_cities=20,
    peaks_percentile_floor=60.0,
    separation_shrink=0.85,
    max_relax_iters=10,
    out_dir="maps/sv1.2/dv0.1_v2_density_cities",
)

_ = main(_cfg)


[Nodes V2] Build complete:
----------------------------------------
Cities: 20 | Total pop: 5,000,000
Saved: nodes → maps/sv1.2/dv0.1_v2_density_cities\nodes.csv
       preview → maps/sv1.2/dv0.1_v2_density_cities\preview.png
       meta → maps/sv1.2/dv0.1_v2_density_cities\meta.json
Metrics hash: ab376ddd767689b9
