# Segments V1 (Edges only) — batch-aware greedy MST

**Goal**  
Create a **greedy MST (Prim)** per dataset and save it under a **solutions** subfolder
inside each dataset. Before creating a solution, the notebook **checks whether one already exists**
and, if so, **skips** that dataset.

**What counts as a dataset?**  
A folder that contains a top-level `nodes.csv` (from your nodes-only generator).

**Outputs** (per dataset, inside `<dataset>/solutions/<solution_name>/`)
- `edges.csv` — columns: `u, v, length_km`
- `adjacency.csv` — `NxN`, 1 if edge exists, else 0
- `cost.csv` — `NxN`, edge length in km (`0` on diagonal; `inf` if no edge)
- `segments_preview.png`
- `solution_meta.json`

**Usage**
- **Option A (batch):** set `dataset_root` to your `00_Datasets` root.
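- **Option B (single dataset):** leave `dataset_root=None` and set `nodes_csv_path` to the dataset's `nodes.csv` file (the file must be named `nodes.csv`; its folder is treated as the dataset).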

Only **new** solutions are created; existing ones are **skipped**.


In [1]:
"""
Jupyter notebook cell — Segments V1 (Edges only) — batch-aware greedy MST

Goal
-----
Create a **greedy MST (Prim)** per dataset and save it under a **solutions subfolder
inside each dataset**. Before creating a solution, the notebook **checks if it already exists**
and **skips** that dataset.

What counts as a dataset?
- A folder that contains a top-level `nodes.csv` (from your nodes-only generator).

Outputs (per dataset, inside `<dataset>/solutions/<solution_name>/`)
- `edges.csv`          : u, v, length_km
- `adjacency.csv`      : NxN, 1/0 connectivity
- `cost.csv`           : NxN, length_km (0 diag, inf if no edge)
- `segments_preview.png`
- `solution_meta.json`

Usage
-----
Option A (batch over many datasets): set `dataset_root` to your 00_Datasets root.
Option B (single dataset): leave `dataset_root=None` and set `nodes_csv_path`.

Only **new** solutions are created; existing ones are **skipped**.
"""
from __future__ import annotations

import os
import json
import time
from dataclasses import dataclass
from typing import List, Tuple, Dict, Iterable

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


# ------------------------------
# Config
# ------------------------------
@dataclass
class SegV1Config:
    """Settings for one Segments V1 (greedy MST) run.

    Two input modes exist: when ``dataset_root`` is truthy, every folder under
    it that holds a ``nodes.csv`` is processed (batch mode); otherwise the
    single file given by ``nodes_csv_path`` is used.
    """

    # Not consumed by Prim; only stored in solution_meta.json for reference.
    seed: int = 42

    # Batch mode: root folder to scan (e.g. "...\\00_Datasets").
    dataset_root: str | None = None

    # Single-dataset mode: only consulted when dataset_root is falsy.
    nodes_csv_path: str | None = None

    # Output location: <dataset>/<solutions_subdir>/<solution_name>/
    solutions_subdir: str = "solutions"
    solution_name: str | None = None  # falsy -> "v1_greedy_mst"

    # Preview plot: annotate each edge with its length (km).
    label_lengths: bool = True


# ------------------------------
# Helpers
# ------------------------------
def _load_nodes(nodes_csv_path: str) -> pd.DataFrame:
    df = pd.read_csv(nodes_csv_path)
    required = {"id", "x_km", "y_km"}
    missing = required - set(df.columns)
    if missing:
        raise ValueError(f"nodes.csv missing columns: {missing}")

    # sort by id and make sure ids are 0..N-1; if not, remap
    df = df.sort_values("id").reset_index(drop=True)
    if not np.array_equal(df["id"].to_numpy(), np.arange(len(df))):
        df.insert(0, "orig_id", df["id"].values)
        df["id"] = np.arange(len(df), dtype=int)
    df = df.set_index("id", drop=False)
    return df


def _pairwise_distances(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    X = x[:, None] - x[None, :]
    Y = y[:, None] - y[None, :]
    return np.sqrt(X * X + Y * Y)


def _solution_dir(dataset_dir: str, solutions_subdir: str, solution_name: str) -> str:
    sol_dir = os.path.join(dataset_dir, solutions_subdir, solution_name)
    os.makedirs(sol_dir, exist_ok=True)
    return sol_dir


def _solution_exists(sol_dir: str) -> bool:
    """Consider it 'exists' if edges.csv is present (lightweight check)."""
    return os.path.exists(os.path.join(sol_dir, "edges.csv"))


def _iter_datasets_with_nodes(dataset_root: str, solutions_subdir: str) -> Iterable[str]:
    """
    Yield dataset directories that contain a top-level nodes.csv.
    Skips any nested 'solutions' subfolders.
    """
    for root, dirs, files in os.walk(dataset_root):
        # do not descend into any 'solutions' directories
        dirs[:] = [d for d in dirs if d != solutions_subdir]
        if "nodes.csv" in files:
            yield root


# ------------------------------
# Core (Segments V1 — Prim's MST)
# ------------------------------
def build_mst_greedy_prim(D: np.ndarray) -> List[Tuple[int, int, float]]:
    """Prim's algorithm over the complete graph described by a weight matrix.

    Args:
        D: square (N, N) matrix; D[u, v] is the cost of connecting u and v.

    Returns:
        N-1 edges ``(parent, child, weight)``, listed by increasing child index
        1..N-1, forming a minimum spanning tree grown from node 0. Returns an
        empty list for N <= 1.

    Raises:
        ValueError: if D is not a square 2-D matrix.
        RuntimeError: if some node never receives a parent (e.g. NaN/inf
            weights leave it unreachable from node 0).
    """
    D = np.asarray(D)
    if D.ndim != 2 or D.shape[0] != D.shape[1]:
        raise ValueError(f"D must be a square 2-D matrix, got shape {D.shape}")
    n = D.shape[0]
    in_mst = np.zeros(n, dtype=bool)
    parent = -np.ones(n, dtype=int)
    key = np.full(n, np.inf)  # cheapest known connection cost into the tree

    # start from node 0 (guarded: an empty matrix has no node 0)
    if n:
        key[0] = 0.0
    for _ in range(n):
        # pick min key among nodes not yet in the MST
        u = int(np.argmin(np.where(in_mst, np.inf, key)))
        in_mst[u] = True
        # Vectorised relaxation: identical to checking each v in turn, because
        # each v's key/parent depends only on D[u, v]; strict '<' keeps ties
        # on the earlier parent, as before.
        improved = ~in_mst & (D[u] < key)
        key[improved] = D[u, improved]
        parent[improved] = u

    # gather edges (skip root 0)
    edges: List[Tuple[int, int, float]] = []
    for v in range(1, n):
        u = int(parent[v])
        if u < 0:
            raise RuntimeError("Unexpected: MST parent missing — check input.")
        edges.append((u, v, float(D[u, v])))
    return edges


def edges_to_matrices(n: int, edges: List[Tuple[int, int, float]]) -> Tuple[np.ndarray, np.ndarray]:
    """Expand an undirected edge list into dense (n, n) matrices.

    Returns:
        ``(A, C)``. A is an integer 0/1 adjacency matrix. C is a float matrix
        holding each edge's weight, with the diagonal initialised to 0.0 and
        ``inf`` wherever no edge exists. Each edge is written in both
        directions, and later duplicates overwrite earlier ones.
    """
    cost = np.full((n, n), np.inf, dtype=float)
    cost[np.diag_indices(n)] = 0.0
    adjacency = np.zeros((n, n), dtype=int)
    for a, b, length in edges:
        both_ways = ([a, b], [b, a])  # cells (a, b) and (b, a)
        adjacency[both_ways] = 1
        cost[both_ways] = length
    return adjacency, cost


def preview_segments(nodes: pd.DataFrame,
                     edges: List[Tuple[int, int, float]],
                     save_path: str,
                     label_lengths: bool = True) -> None:
    """Draw the edge set over the node positions and save it as an image.

    Args:
        nodes: frame with ``x_km``/``y_km`` columns. It must be indexed by node
            id so that ``nodes.loc[u]`` resolves the endpoints named in
            ``edges``. An optional ``pop`` column scales the marker sizes.
        edges: ``(u, v, length_km)`` triples, each drawn as a line segment.
        save_path: output image path (written at 150 dpi).
        label_lengths: if True, print each edge's length at its midpoint.
    """
    fig, ax = plt.subplots(figsize=(7, 7))
    try:
        for u, v, w in edges:
            x1, y1 = nodes.loc[u, ["x_km", "y_km"]]
            x2, y2 = nodes.loc[v, ["x_km", "y_km"]]
            ax.plot([x1, x2], [y1, y2])
            if label_lengths:
                xm, ym = (x1 + x2) / 2.0, (y1 + y2) / 2.0
                ax.text(xm, ym, f"{w:.1f}", fontsize=7, ha="center", va="center",
                        bbox=dict(boxstyle="round,pad=0.15", fc="white", ec="none", alpha=0.6))

        # nodes: marker area grows with sqrt(pop); fall back to a constant size
        # when there is no usable (positive) population to normalise by.
        if "pop" in nodes.columns and nodes["pop"].max() > 0:
            s = 10 + 90 * np.sqrt(nodes["pop"].to_numpy() / nodes["pop"].max())
        else:
            s = np.full(len(nodes), 40.0)
        ax.scatter(nodes["x_km"], nodes["y_km"], s=s)

        ax.set_title("Segments — V1 greedy MST (numbers = length km)")
        ax.set_xlabel("x (km)")
        ax.set_ylabel("y (km)")
        if len(nodes):
            minx, miny = nodes[["x_km", "y_km"]].min()
            maxx, maxy = nodes[["x_km", "y_km"]].max()
            # Pad the data bounds. Exact min/max limits cut the outermost
            # markers/labels in half, and collapse to a zero-width axis when
            # all nodes share a coordinate.
            pad = 0.05 * max(maxx - minx, maxy - miny) or 1.0
            ax.set_xlim(minx - pad, maxx + pad)
            ax.set_ylim(miny - pad, maxy + pad)
        ax.set_aspect("equal", adjustable="box")
        fig.tight_layout()
        fig.savefig(save_path, dpi=150)
    finally:
        # Release the figure even if drawing or saving fails (no leaked figures
        # across a long batch run).
        plt.close(fig)


def write_solution(nodes: pd.DataFrame,
                   edges: List[Tuple[int, int, float]],
                   dataset_dir: str,
                   solutions_subdir: str,
                   solution_name: str,
                   label_lengths: bool,
                   seed: int) -> Dict[str, str]:
    """Write every artifact of one solution and return their paths.

    Files go to ``<dataset_dir>/<solutions_subdir>/<solution_name>/``:
    edges.csv, adjacency.csv, cost.csv, segments_preview.png and, last,
    solution_meta.json.

    Node ids in the artifacts are the 0..N-1 ids of ``nodes``. When ``nodes``
    carries an ``orig_id`` column (``_load_nodes`` adds one when it had to
    renumber the source ids), that column is saved in the metadata as
    ``orig_node_ids`` (position i -> original id), so edges.csv can be mapped
    back to the source nodes.csv.
    """
    sol_dir = _solution_dir(dataset_dir, solutions_subdir, solution_name)

    edges_path = os.path.join(sol_dir, "edges.csv")
    adj_path = os.path.join(sol_dir, "adjacency.csv")
    cost_path = os.path.join(sol_dir, "cost.csv")
    preview_path = os.path.join(sol_dir, "segments_preview.png")
    meta_path = os.path.join(sol_dir, "solution_meta.json")

    # Save artifacts
    pd.DataFrame(edges, columns=["u", "v", "length_km"]).to_csv(edges_path, index=False)
    A, C = edges_to_matrices(len(nodes), edges)
    pd.DataFrame(A).to_csv(adj_path, index=False, header=False)
    pd.DataFrame(C).to_csv(cost_path, index=False, header=False)
    preview_segments(nodes, edges, preview_path, label_lengths=label_lengths)

    meta = {
        "created_at_utc": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "algorithm": "v1_greedy_mst",
        "solution_name": solution_name,
        "seed": seed,
        "dataset_folder": dataset_dir,
        "solution_dir": sol_dir,
        "artifacts": {
            "edges_csv": edges_path,
            "adjacency_csv": adj_path,
            "cost_csv": cost_path,
            "preview_png": preview_path,
        },
        "n_nodes": int(len(nodes)),
        "n_edges": int(len(edges)),
        "total_length_km": float(sum(w for _, _, w in edges)),
    }
    if "orig_id" in nodes.columns:
        # Without this, renumbered ids in edges.csv cannot be related back.
        meta["orig_node_ids"] = nodes["orig_id"].tolist()

    # Written last: its presence implies every artifact above was written.
    with open(meta_path, "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2)

    return {"solution_dir": sol_dir, "edges": edges_path, "adjacency": adj_path,
            "cost": cost_path, "preview": preview_path, "meta": meta_path}


# ------------------------------
# Orchestration
# ------------------------------
def run_for_dataset(dataset_dir: str, cfg: SegV1Config) -> Dict[str, str] | None:
    """Build and save the greedy-MST solution for one dataset folder.

    Returns the artifact paths from ``write_solution`` when a new solution was
    written. Returns None when the folder has no nodes.csv, or when
    ``_solution_exists`` reports an existing solution (announced with a
    "[SKIP]" line).
    """
    nodes_csv = os.path.join(dataset_dir, "nodes.csv")
    if not os.path.exists(nodes_csv):
        return None

    name = cfg.solution_name or "v1_greedy_mst"
    target_dir = os.path.join(dataset_dir, cfg.solutions_subdir, name)
    if _solution_exists(target_dir):
        print(f"[SKIP] Existing solution found → {target_dir}")
        return None

    nodes = _load_nodes(nodes_csv)
    xs = nodes["x_km"].to_numpy()
    ys = nodes["y_km"].to_numpy()
    distance_matrix = _pairwise_distances(xs, ys)
    mst_edges = build_mst_greedy_prim(distance_matrix)
    paths = write_solution(
        nodes,
        mst_edges,
        dataset_dir,
        cfg.solutions_subdir,
        name,
        cfg.label_lengths,
        cfg.seed,
    )

    print(f"[OK]   Created solution → {paths['solution_dir']}")
    return paths


def main(cfg: SegV1Config) -> List[Dict[str, str]]:
    """Run Segments V1 in batch or single-dataset mode.

    Batch mode (``cfg.dataset_root`` truthy) processes every folder under the
    root that contains a nodes.csv. Otherwise ``cfg.nodes_csv_path`` must point
    at a dataset's ``nodes.csv``; its folder is the dataset.

    Returns:
        Artifact-path dicts for the solutions created by this call. Skipped
        datasets are not included.

    Raises:
        ValueError: if neither mode is configured, or ``nodes_csv_path`` is not
            a file named ``nodes.csv``.
        FileNotFoundError: if ``nodes_csv_path`` does not exist.
    """
    created: List[Dict[str, str]] = []

    if cfg.dataset_root:
        # Batch over every dataset under dataset_root
        print(f"Scanning datasets under: {cfg.dataset_root}")
        count = 0
        for dataset_dir in _iter_datasets_with_nodes(cfg.dataset_root, cfg.solutions_subdir):
            count += 1
            out = run_for_dataset(dataset_dir, cfg)
            if out:
                created.append(out)
        print(f"\n[Segments V1] Done. Datasets scanned: {count} | New solutions: {len(created)}")
        return created

    # Single dataset mode
    if not cfg.nodes_csv_path:
        raise ValueError("Provide either dataset_root (batch) or nodes_csv_path (single).")

    # run_for_dataset re-derives the CSV as <dataset_dir>/nodes.csv and quietly
    # returns None if that file is absent. Validate up front, so a typo or a
    # differently named file fails loudly instead of producing an empty result.
    if not os.path.isfile(cfg.nodes_csv_path):
        raise FileNotFoundError(f"nodes_csv_path does not exist: {cfg.nodes_csv_path}")
    if os.path.normcase(os.path.basename(cfg.nodes_csv_path)) != os.path.normcase("nodes.csv"):
        raise ValueError(
            f"nodes_csv_path must point to a file named 'nodes.csv', got: {cfg.nodes_csv_path}"
        )

    dataset_dir = os.path.dirname(cfg.nodes_csv_path)
    out = run_for_dataset(dataset_dir, cfg)
    return [out] if out else []


# ------------------------------
# Run
# ------------------------------
_cfg = SegV1Config(
    # Option A (batch): process every dataset found under this root.
    dataset_root=r"C:\Users\User\Documents\Code\traffic-optimization\00_Datasets",
    # Option B (single): set dataset_root=None and point this at <dataset>\nodes.csv.
    nodes_csv_path=None,
    # Output: <dataset>/<solutions_subdir>/<solution_name or "v1_greedy_mst">/
    solutions_subdir="solutions",
    solution_name=None,
    label_lengths=True,
    seed=42,  # metadata only
)

_ = main(_cfg)


Scanning datasets under: C:\Users\User\Documents\Code\traffic-optimization\00_Datasets
[OK]   Created solution → C:\Users\User\Documents\Code\traffic-optimization\00_Datasets\sv1.2\dv0.1_ds01_seed42_tp3M_nc15\solutions\v1_greedy_mst
[OK]   Created solution → C:\Users\User\Documents\Code\traffic-optimization\00_Datasets\sv1.2\dv0.1_ds02_seed43_tp3.5M_nc18\solutions\v1_greedy_mst
[OK]   Created solution → C:\Users\User\Documents\Code\traffic-optimization\00_Datasets\sv1.2\dv0.1_ds03_seed44_tp4M_nc20\solutions\v1_greedy_mst
[OK]   Created solution → C:\Users\User\Documents\Code\traffic-optimization\00_Datasets\sv1.2\dv0.1_ds04_seed45_tp2.5M_nc22\solutions\v1_greedy_mst
[OK]   Created solution → C:\Users\User\Documents\Code\traffic-optimization\00_Datasets\sv1.2\dv0.1_ds05_seed46_tp5M_nc24\solutions\v1_greedy_mst
[OK]   Created solution → C:\Users\User\Documents\Code\traffic-optimization\00_Datasets\sv1.2\dv0.1_ds06_seed47_tp5.5M_nc26\solutions\v1_greedy_mst
[OK]   Created solution → C:\Us