In [None]:
from pathlib import Path
from collections import defaultdict
import re

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import pyarrow as pa
import pyarrow.dataset as ds
import pyarrow.compute as pc
import pyarrow.parquet as pq

from setup_plot import setup_local, get_colors, get_markers


# Input locations. Relative paths, so they resolve against the current
# working directory.
#   SET1_PARQUET    - Parquet dataset scanned by build_job_occ_parquet()
#   SACCT_CSV       - pipe-delimited sacct export read by read_pipe_csv()
#   NODES_80GB_FILE - hostname list read by load_nodes_from_file()
SET1_PARQUET     = "ldms.parquet"
SACCT_CSV        = Path("sacct.csv")
NODES_80GB_FILE  = Path("nodes_80gb.txt")

# Nominal per-row capacities used as the occupancy divisor:
# 40960 = 40 * 1024 and 81920 = 80 * 1024.
# NOTE(review): presumably MiB, matching the unit of the
# nersc_ldms_dcgm_fb_used column - confirm against the LDMS schema.
CAP_40_MIB = 40960
CAP_80_MIB = 81920
# batch_size handed to the pyarrow dataset scanner.
BATCH_SIZE = 1_000_000


def load_nodes_from_file(path: Path) -> set[str]:
    """
    Read hostnames from a text file into a set.

    Hostnames may be separated by newlines, spaces, tabs or commas, so a line
    can hold one hostname or several.  Everything from the first ``#`` on a
    line to the end of that line is treated as a comment and ignored, which
    covers both whole-line comments and trailing inline comments.  Blank
    lines are skipped and duplicates collapse because the result is a set.

    Args:
        path: Location of the hostname list.

    Returns:
        The set of hostnames found in the file (possibly empty).

    Raises:
        FileNotFoundError: If ``path`` does not exist.
    """
    if not path.exists():
        raise FileNotFoundError(f"80GB node list not found: {path.resolve()}")

    nodes: set[str] = set()
    for ln in path.read_text(encoding="utf-8", errors="replace").splitlines():
        # Cut the line at the first '#'.  Skipping only the token that starts
        # with '#' would still add the remaining comment words as hostnames.
        content = ln.split("#", 1)[0]
        # str.split() with no argument never yields empty tokens, so blank
        # and comment-only lines contribute nothing.
        nodes.update(content.replace(",", " ").split())
    return nodes


def read_pipe_csv(path: Path) -> pd.DataFrame:
    """
    Read a pipe-delimited sacct CSV into a DataFrame of strings.

    The first line is the header.  Each following line is split on ``|`` at
    most ``ncols - 1`` times, so any extra pipes end up inside the last
    column instead of shifting fields.  Rows with fewer fields than the
    header are right-padded with empty strings.

    Args:
        path: Location of the sacct export.

    Returns:
        A DataFrame with one column per header field; every cell is a string
        and ``JobID`` is explicitly cast to ``str``.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        KeyError: If the header has no ``JobID`` column (this includes an
            empty file).
    """
    if not path.exists():
        raise FileNotFoundError(f"sacct file not found: {path.resolve()}")

    rows = []
    with path.open("r", encoding="utf-8", errors="replace") as f:
        # next() with a default: an empty file yields an empty header and
        # falls through to the KeyError below rather than StopIteration.
        names = next(f, "").rstrip("\n").split("|")
        if "JobID" not in names:
            raise KeyError("sacct CSV missing required column: JobID")
        ncols = len(names)

        for ln in f:
            # maxsplit keeps pipes that occur in the final field intact.
            parts = ln.rstrip("\n").split("|", ncols - 1)
            if len(parts) < ncols:
                parts += [""] * (ncols - len(parts))
            rows.append(parts)

    df = pd.DataFrame(rows, columns=names)
    df["JobID"] = df["JobID"].astype(str)
    return df


def requested_hbm_from_constraints(s: str) -> str:
    """
    Classify a constraint expression by the HBM size it asks for.

    The expression is lower-cased and split on ``&``, ``,`` and ``+``; each
    token is stripped of surrounding whitespace and quote characters and then
    compared exactly (substrings such as ``xhbm80g`` do not count).

    Args:
        s: Constraint expression, e.g. ``"gpu&hbm80g"``.  ``None``, the empty
            string and non-string values (such as the float NaN pandas uses
            for missing cells) are accepted and treated as "no constraint".

    Returns:
        ``"requested_80gb"`` if an ``hbm80g`` token is present,
        ``"requested_40gb"`` if an ``hbm40g`` token is present (and no
        ``hbm80g``), otherwise ``"unspecified"``.
    """
    if not isinstance(s, str):
        # Covers None and NaN without calling .lower() on a non-string.
        return "unspecified"

    toks = {
        t.strip().strip("\"'").strip()
        for t in re.split(r"[&,+]", s.lower())
    }
    # hbm80g is checked first, so it wins when both tokens are present.
    if "hbm80g" in toks:
        return "requested_80gb"
    if "hbm40g" in toks:
        return "requested_40gb"
    return "unspecified"


def requested_hbm_from_submitline(s: str) -> str:
    """
    Classify a job's submit command line by the HBM size it asks for.

    Looks for the first ``-C`` / ``--constraint`` option (value attached,
    separated by whitespace, or joined with ``=``) and classifies its value
    with ``requested_hbm_from_constraints``.

    Args:
        s: The submit command line.  ``None``, the empty string and
            non-string values are treated as "no constraint given".

    Returns:
        ``"requested_80gb"``, ``"requested_40gb"`` or ``"unspecified"``.
    """
    if not isinstance(s, str):
        return "unspecified"
    # (?<!\S) anchors the flag to the start of an argument.  Without it the
    # "-C" inside e.g. "--job-name=train-CNN" is taken as the option, and
    # because only the first match is used the real constraint is missed.
    m = re.search(r"(?<!\S)(?:-C|--constraint)\s*=?\s*([^\s]+)", s)
    return requested_hbm_from_constraints(m.group(1)) if m else "unspecified"


# Load the 80 GB hostname list once, when this cell/module runs.  Raises
# FileNotFoundError if NODES_80GB_FILE does not exist.
nodes_80gb_set = load_nodes_from_file(NODES_80GB_FILE)
# The same hostnames as a sorted Arrow string array; used as the value_set
# for pc.is_in() inside build_job_occ_parquet().
nodes80_arr = pa.array(sorted(nodes_80gb_set), type=pa.string())

# Per-job result table.  Assigned by build_job_occ_parquet() and read by
# plot_fb_used_80_req80(); stays None until the build step has run.
_OCC_TABLE = None


def build_job_occ_parquet():
    """
    Scan the LDMS Parquet dataset and build the per-job occupancy table.

    Streams ``SET1_PARQUET`` in record batches of ``BATCH_SIZE`` rows, reading
    only the ``JobID``, ``hostname`` and ``nersc_ldms_dcgm_fb_used`` columns.
    Each row gets a nominal capacity (``CAP_80_MIB`` when its hostname is in
    ``nodes80_arr``, ``CAP_40_MIB`` otherwise) and an occupancy equal to
    ``fb_used / capacity``; null usage counts as 0.  For every job the largest
    capacity and the largest occupancy seen across all batches are kept.

    Nothing is returned and, despite the function name, nothing is written to
    disk: a ``pyarrow.Table`` with columns ``JobID`` (string),
    ``capacity_nom_mib`` (int64) and ``max_occ_nom`` (float64), sorted by
    ``JobID``, is stored in the module-level ``_OCC_TABLE``.
    """
    global _OCC_TABLE

    peak_cap = defaultdict(int)
    peak_occ = defaultdict(float)

    # Arrow scalars that do not change between batches.
    cap80 = pa.scalar(CAP_80_MIB, pa.int64())
    cap40 = pa.scalar(CAP_40_MIB, pa.int64())
    zero = pa.scalar(0.0)

    batches = ds.dataset(SET1_PARQUET, format="parquet").scanner(
        columns=["JobID", "hostname", "nersc_ldms_dcgm_fb_used"],
        batch_size=BATCH_SIZE,
        use_threads=True,
    ).to_batches()

    for batch in batches:
        job_ids = pc.cast(batch["JobID"], pa.string())
        used = pc.fill_null(
            pc.cast(batch["nersc_ldms_dcgm_fb_used"], pa.float64()), zero
        )

        on_80gb = pc.is_in(batch["hostname"], value_set=nodes80_arr)
        capacity = pc.if_else(on_80gb, cap80, cap40)
        occupancy = pc.divide(used, pc.cast(capacity, pa.float64()))

        # Reduce inside Arrow first so the Python loop below only sees one
        # row per job per batch.
        per_job = (
            pa.table({"JobID": job_ids, "cap_nom": capacity, "occ_nom": occupancy})
            .group_by("JobID")
            .aggregate([("cap_nom", "max"), ("occ_nom", "max")])
        )

        reduced = zip(
            per_job["JobID"].to_pylist(),
            per_job["cap_nom_max"].to_pylist(),
            per_job["occ_nom_max"].to_pylist(),
        )
        for job_id, cap, occ in reduced:
            if job_id is None:
                continue
            key = str(job_id)

            cap_val = 0 if cap is None else int(cap)
            occ_val = float("nan") if occ is None else float(occ)

            # Indexing the defaultdicts registers the job even when neither
            # comparison succeeds (a NaN occupancy never compares greater).
            if cap_val > peak_cap[key]:
                peak_cap[key] = cap_val
            if occ_val > peak_occ[key]:
                peak_occ[key] = occ_val

    ordered = sorted(peak_cap)
    _OCC_TABLE = pa.table({
        "JobID": pa.array(ordered, type=pa.string()),
        "capacity_nom_mib": pa.array(
            [peak_cap[j] for j in ordered], type=pa.int64()
        ),
        "max_occ_nom": pa.array(
            [peak_occ.get(j, float("nan")) for j in ordered], type=pa.float64()
        ),
    })


def plot_fb_used_80_req80():
    """
    Plot peak HBM usage for jobs that requested and ran on 80 GB nodes.

    Reads the per-job table from the module-level ``_OCC_TABLE`` and the
    sacct export at ``SACCT_CSV``.  A job is kept when

    * its nominal capacity in the table equals ``CAP_80_MIB``, and
    * sacct marks it as ``requested_80gb``, taken from the ``Constraints``
      column when present and, for rows still ``unspecified``, from the
      ``SubmitLine`` column when present.

    Peak occupancy of the kept jobs is converted to percent, clipped to
    [0, 100] and binned into ten 10-point bins.  The figure shows the bin
    counts as bars (left axis) and the cumulative percentage of jobs as a
    line (right axis), then calls ``plt.show()``.

    Raises:
        RuntimeError: If ``_OCC_TABLE`` has not been populated yet.
        FileNotFoundError, KeyError: Propagated from ``read_pipe_csv``.
    """
    if _OCC_TABLE is None:
        # Fail with an actionable message instead of
        # "'NoneType' object has no attribute 'combine_chunks'".
        raise RuntimeError(
            "job occupancy table has not been built; "
            "call build_job_occ_parquet() first"
        )

    T = _OCC_TABLE.combine_chunks()
    occ = T.to_pandas()
    occ["JobID"] = occ["JobID"].astype(str)

    # Collapse to one value per JobID, indexed by JobID.
    job_cap     = occ.groupby("JobID", as_index=True)["capacity_nom_mib"].max()
    job_max_occ = occ.groupby("JobID", as_index=True)["max_occ_nom"].max()

    sacct_df = read_pipe_csv(SACCT_CSV)

    if "Constraints" in sacct_df.columns:
        sacct_df["requested_gpu_mem"] = sacct_df["Constraints"].apply(requested_hbm_from_constraints)
    else:
        sacct_df["requested_gpu_mem"] = "unspecified"

    # Fall back to the submit command line only for rows the Constraints
    # column could not classify.
    if "SubmitLine" in sacct_df.columns:
        mask_unspec = sacct_df["requested_gpu_mem"].eq("unspecified")
        sacct_df.loc[mask_unspec, "requested_gpu_mem"] = sacct_df.loc[mask_unspec, "SubmitLine"].apply(requested_hbm_from_submitline)

    req80_ids = set(
        sacct_df.loc[sacct_df["requested_gpu_mem"].eq("requested_80gb"), "JobID"].astype(str)
    )

    # Keep jobs that both requested 80 GB and have the 80 GB nominal capacity.
    mask_80placed = job_cap.eq(CAP_80_MIB)
    keep_ids = job_cap.index[mask_80placed & job_cap.index.isin(req80_ids)]
    vals = job_max_occ.loc[keep_ids].values

    # Occupancy is a fraction of nominal capacity; clip so values outside
    # [0, 100] % land in the first/last bin instead of being dropped.
    vals_pct = np.clip(vals * 100.0, 0, 100)

    edges = np.linspace(0, 100, 11)
    counts, edges = np.histogram(vals_pct, bins=edges)

    # Guard the division so an empty selection plots a flat zero line.
    cdf_vals = (np.cumsum(counts) / counts.sum() * 100.0) if counts.sum() > 0 else np.zeros_like(counts, dtype=float)
    bin_left  = edges[:-1]
    bin_width = np.diff(edges)
    bin_cent  = bin_left + bin_width / 2.0

    setup_local()
    colors  = get_colors()
    markers = get_markers()

    fig, ax1 = plt.subplots()

    ax1.bar(bin_left, counts, width=bin_width, align="edge",
            color=colors[2], edgecolor="black", label="Number of jobs")

    ax1.set_xlabel("Peak HBM usage (% of capacity)", fontsize=17)
    ax1.set_ylabel("Number of jobs", fontsize=17)
    ax1.set_xlim(0, 100)
    ax1.set_xticks(np.arange(0, 101, 10))
    ax1.tick_params(axis="x", labelsize=16)
    ax1.tick_params(axis="y", labelsize=16)
    ax1.grid(axis="y", linestyle="--", alpha=0.7)
    ax1.set_title("Distribution of jobs by HBM_USED (80 GB)", fontsize=18, pad=18)

    # NOTE(review): each cumulative value covers jobs up to the bin's right
    # edge but is drawn at the bin centre - confirm this placement is intended.
    ax2 = ax1.twinx()
    ax2.plot(bin_cent, cdf_vals, color=colors[0], marker=markers[2],
             label="CDF (number of jobs)", linewidth=2, clip_on=False)
    ax2.set_ylabel("Cumulative percentage (%)", fontsize=17)
    ax2.set_ylim(0, 100)
    ax2.set_yticks([0, 20, 40, 60, 80, 100])
    ax2.tick_params(axis="y", labelsize=16)

    # Merge the handles of both axes into a single legend.
    h1, l1 = ax1.get_legend_handles_labels()
    h2, l2 = ax2.get_legend_handles_labels()
    ax1.legend(h1 + h2, l1 + l2, loc="best", fontsize=16, framealpha=0.5)

    plt.tight_layout()
    plt.show()


# Run the pipeline.  Order matters: the build step assigns _OCC_TABLE, which
# the plotting step reads.
build_job_occ_parquet()
plot_fb_used_80_req80()
