# In [None]:
from pathlib import Path
from typing import List, Tuple
import re

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import pyarrow as pa
import pyarrow.parquet as pq

from setup_plot import setup_local, get_colors


# ---- Inputs ---------------------------------------------------------------
INPUT_SET1 = "ldms.parquet"               # LDMS samples, read batch by batch
BATCH_SIZE = 250_000                       # rows per parquet batch

NODES_80GB_FILE = Path("nodes_80gb.txt")   # hostnames of the 80 GB nodes
SACCT_CSV = Path("sacct.csv")              # '|'-delimited sacct export

# ---- Overlays to draw (each entry is a key of AI_COLS) --------------------
PLOTS = ["fp32", "fp64vec", "tensor", "pseudo64"]

# ---- Histogram binning over arithmetic intensity (FLOP/Byte) --------------
AI_MIN = 1e-3
AI_MAX = 1e3
BARS_PER_DECADE = 5

# ---- Axis ticks -------------------------------------------------------------
X_TICKS = [1e-3, 1e-2, 1e-1, 1.0, 10.0, 100.0, 1000.0]
Y_TICKS_LEFT = list(range(0, 26, 5))                   # percent of samples
Y_TICKS_RIGHT = [1e-3, 1e-2, 1e-1, 1.0, 10.0, 100.0]   # roofline, TF/s

# DCGM activity metrics as named in the LDMS parquet. They are expected to be
# fractions in [0, 1]; the histogram code treats values > 1 as invalid.
DCGM_PREFIX = "nersc_ldms_dcgm_"
ACTIVE_COLS = [
    DCGM_PREFIX + f"{unit}_active"
    for unit in ("fp16", "fp32", "fp64", "tensor", "dram")
]
# Columns loaded from the parquet: job/host identity plus the activity metrics.
NEEDED_COLS: List[str] = ["JobID", "hostname", *ACTIVE_COLS]

# A100 peak throughput (FLOP/s).
PEAK_FLOPS_FP16 = 78.0e12
PEAK_FLOPS_FP32 = 19.5e12
PEAK_FLOPS_FP64_VECTOR = 9.7e12
PEAK_FLOPS_FP64_TENSOR = 19.5e12

# Peak HBM bandwidth (B/s) of the 40 GB and 80 GB A100 variants.
HBM_40 = 1.555e12
HBM_80 = 2.039e12

# Compute ceilings shared by both memory variants, keyed by plot ridge_key.
_COMPUTE_ROOFS = {
    "FP32": PEAK_FLOPS_FP32,
    "FP64": PEAK_FLOPS_FP64_VECTOR,
    "TNSR": PEAK_FLOPS_FP64_TENSOR,
}

# Roofline peaks per variant: the compute roofs plus the HBM bandwidth roof.
PEAKS_40 = {**_COMPUTE_ROOFS, "HBM": HBM_40}
PEAKS_80 = {**_COMPUTE_ROOFS, "HBM": HBM_80}

# Ridge points (FLOP/Byte) = compute peak / HBM bandwidth. The "HBM" entry is
# NOT a ridge point: it holds the bandwidth itself (B/s), which the plotting
# code multiplies by AI to draw the memory-bound slope of the roofline.
RIDGE_40 = {**{key: roof / HBM_40 for key, roof in _COMPUTE_ROOFS.items()}, "HBM": HBM_40}
RIDGE_80 = {**{key: roof / HBM_80 for key, roof in _COMPUTE_ROOFS.items()}, "HBM": HBM_80}

# Arithmetic-intensity variant computed for each plot key (see PLOTS); the
# values select the FLOP numerator inside hist_from_parquet_edges.
AI_COLS = dict(
    fp32="AI_fp32",
    fp64vec="AI_fp64_vector",
    tensor="AI_tensor_fp64",
    pseudo64="pseudo64_AI_tensor_fp64",
)


def load_nodes_from_file(path: Path) -> set[str]:
    """Read the set of 80 GB node hostnames from a text file.

    Hostnames may be separated by newlines, whitespace and/or commas. Blank
    lines are ignored, and ``#`` starts a comment that runs to the end of the
    line, so ``nid001  # spare node`` yields only ``nid001``.

    Args:
        path: Text file listing hostnames (a ``str`` path is also accepted).

    Returns:
        The set of hostnames found in the file.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
    """
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"80GB node list not found: {path.resolve()}")

    # "utf-8-sig" drops a leading BOM that would otherwise stick to the first
    # hostname and make it never match.
    text = path.read_text(encoding="utf-8-sig", errors="replace")

    nodes: set[str] = set()
    for line in text.splitlines():
        # Cut the whole trailing comment. Skipping only the '#' token, as
        # before, turned the words after it into bogus hostnames.
        content = line.split("#", 1)[0]
        nodes.update(content.replace(",", " ").split())
    return nodes


def read_pipe_csv(path: Path) -> pd.DataFrame:
    """Read a '|'-delimited sacct export into a DataFrame of strings.

    The first line is the header. Each data line is split on at most
    ``ncols - 1`` pipes, so any extra '|' characters end up in the LAST
    column. This is what lets a trailing ``SubmitLine`` column contain '|';
    a '|' inside any other column would still shift the fields. Short lines
    are padded with empty strings, and blank lines are skipped.

    Args:
        path: The sacct export (a ``str`` path is also accepted).

    Returns:
        DataFrame with one column per header name. ``JobID`` is cast to str.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        ValueError: If the file has no header line.
        KeyError: If the header has no ``JobID`` column.
    """
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"sacct file not found: {path.resolve()}")

    # Single pass over the file. "utf-8-sig" strips a leading BOM, which would
    # otherwise turn the first header name into "\ufeffJobID".
    with path.open("r", encoding="utf-8-sig", errors="replace") as f:
        header = f.readline().rstrip("\r\n")
        if not header:
            raise ValueError(f"sacct file has no header line: {path.resolve()}")
        names = header.split("|")
        ncols = len(names)

        rows = []
        for line in f:
            line = line.rstrip("\r\n")
            if not line.strip():
                continue  # blank line: would otherwise become an all-empty row
            parts = line.split("|", ncols - 1)
            parts.extend([""] * (ncols - len(parts)))
            rows.append(parts)

    df = pd.DataFrame(rows, columns=names)
    df["JobID"] = df["JobID"].astype(str)
    return df


def requested_hbm_from_constraints(s: str):
    """Classify a constraint expression by the GPU HBM size it requests.

    The text is lower-cased and split into feature names on ``& | , + *``,
    brackets, parentheses, whitespace and quote characters. Quotes appear
    when the value comes from a shell command line such as
    ``-C "gpu&hbm80g"``, and brackets/``|`` from Slurm-style OR expressions.

    Args:
        s: Constraint expression. Non-string input (``None``, pandas NaN) is
            treated as empty.

    Returns:
        ``"requested_80gb"`` if an ``hbm80g`` feature is present, else
        ``"requested_40gb"`` if ``hbm40g`` is present, else ``"unspecified"``.
    """
    text = s.lower() if isinstance(s, str) else ""
    features = {tok for tok in re.split(r"[&|,+*()\[\]\s\"']+", text) if tok}
    if "hbm80g" in features:
        return "requested_80gb"
    if "hbm40g" in features:
        return "requested_40gb"
    return "unspecified"


# First -C / --constraint option on a submit line. The value may be attached
# ("-Cgpu"), given with "=" or whitespace, and may be single- or double-quoted
# (quoted values may contain spaces).
_SUBMIT_CONSTRAINT_RE = re.compile(
    r"""(?:-C|--constraint)\s*=?\s*("[^"]*"|'[^']*'|\S+)"""
)


def requested_hbm_from_submitline(s: str):
    """Classify a submit command line by the HBM size its constraint requests.

    Extracts the first ``-C``/``--constraint`` value and removes surrounding
    quotes. Previously a quoted value kept its quotes, so ``"gpu&hbm80g"``
    produced the feature ``hbm80g"`` and did not match. The value is then
    classified with ``requested_hbm_from_constraints``.

    Args:
        s: The submit line. Non-string input (``None``, pandas NaN) is
            treated as empty.

    Returns:
        ``"requested_80gb"``, ``"requested_40gb"`` or ``"unspecified"``.
    """
    text = s if isinstance(s, str) else ""
    m = _SUBMIT_CONSTRAINT_RE.search(text)
    if m is None:
        return "unspecified"
    value = m.group(1).strip("\"'")
    return requested_hbm_from_constraints(value)


def load_req80_ids(sacct_csv: Path) -> set[str]:
    """Collect the JobIDs of jobs that explicitly asked for 80 GB GPUs.

    The ``Constraints`` column is classified first. Rows it leaves
    ``"unspecified"`` fall back to the ``-C``/``--constraint`` option parsed
    from ``SubmitLine``. Either column may be absent from the export.

    Args:
        sacct_csv: Path to the '|'-delimited sacct export.

    Returns:
        The set of JobID strings classified as ``"requested_80gb"``.
    """
    table = read_pipe_csv(sacct_csv)

    has_constraints = "Constraints" in table.columns
    table["requested_gpu_mem"] = (
        table["Constraints"].apply(requested_hbm_from_constraints)
        if has_constraints
        else "unspecified"
    )

    if "SubmitLine" in table.columns:
        # Only consult the submit line where Constraints gave no answer.
        undecided = table["requested_gpu_mem"] == "unspecified"
        table.loc[undecided, "requested_gpu_mem"] = (
            table.loc[undecided, "SubmitLine"].apply(requested_hbm_from_submitline)
        )

    wants_80 = table["requested_gpu_mem"] == "requested_80gb"
    return set(table.loc[wants_80, "JobID"].astype(str))


def make_log_edges(ai_min: float, ai_max: float, bars_per_decade: int) -> np.ndarray:
    """Return log-spaced histogram bin edges spanning ``[ai_min, ai_max]``.

    The number of bins is ``round(decades * bars_per_decade)``, with a
    minimum of one. A very narrow range therefore still yields two edges
    instead of the single edge ``np.histogram`` rejects. The first and last
    edges are set exactly to ``ai_min``/``ai_max``, so a range filter on
    those bounds agrees with the histogram range despite ``logspace``
    rounding.

    Args:
        ai_min: Lower bound, must be > 0 (log scale).
        ai_max: Upper bound, must be > ``ai_min``.
        bars_per_decade: Bins per factor of 10, must be > 0.

    Returns:
        1-D float array of ``nbins + 1`` increasing edges.

    Raises:
        ValueError: On a non-positive/NaN or inverted range, or
            ``bars_per_decade <= 0``.
    """
    if not (ai_min > 0 and ai_max > ai_min):
        raise ValueError(f"need 0 < ai_min < ai_max, got ai_min={ai_min!r}, ai_max={ai_max!r}")
    if not bars_per_decade > 0:
        raise ValueError(f"bars_per_decade must be > 0, got {bars_per_decade!r}")

    lo = np.log10(ai_min)
    hi = np.log10(ai_max)
    nbins = max(1, int(round((hi - lo) * bars_per_decade)))
    edges = np.logspace(lo, hi, nbins + 1)
    edges[0], edges[-1] = ai_min, ai_max
    return edges


def hist_from_parquet_edges(parquet_path: str, ai_col: str, capacity: float,
                            edges: np.ndarray, nodes80: set[str],
                            req_jobids: set[str] | None = None) -> Tuple[np.ndarray, int]:
    pf = pq.ParquetFile(parquet_path)
    present = set(pf.schema_arrow.names)
    missing = [c for c in NEEDED_COLS if c not in present]
    if missing:
        raise KeyError(f"Missing required columns in {parquet_path}: {missing}")

    counts = np.zeros(len(edges) - 1, dtype=float)
    total_inrange = 0

    for batch in pf.iter_batches(columns=NEEDED_COLS, batch_size=BATCH_SIZE):
        df = batch.to_pandas()
        if len(df) == 0:
            continue

        fp16 = pd.to_numeric(df[f"{DCGM_PREFIX}fp16_active"], errors="coerce")
        fp32 = pd.to_numeric(df[f"{DCGM_PREFIX}fp32_active"], errors="coerce")
        fp64 = pd.to_numeric(df[f"{DCGM_PREFIX}fp64_active"], errors="coerce")
        tens = pd.to_numeric(df[f"{DCGM_PREFIX}tensor_active"], errors="coerce")
        dram = pd.to_numeric(df[f"{DCGM_PREFIX}dram_active"], errors="coerce")

        missing_any = fp16.isna() | fp32.isna() | fp64.isna() | tens.isna() | dram.isna()
        bad_gt1 = (fp16 > 1.0) | (fp32 > 1.0) | (fp64 > 1.0) | (tens > 1.0) | (dram > 1.0)
        all_fp_zero = (fp16.eq(0.0)) & (fp32.eq(0.0)) & (fp64.eq(0.0)) & (tens.eq(0.0))

        keep = ~(missing_any | bad_gt1 | all_fp_zero)
        if not keep.any():
            continue

        df = df.loc[keep].copy()
        fp16 = fp16.loc[keep].astype(float)
        fp32 = fp32.loc[keep].astype(float)
        fp64 = fp64.loc[keep].astype(float)
        tens = tens.loc[keep].astype(float)
        dram = dram.loc[keep].astype(float)

        host = df["hostname"].astype(str)
        is80 = host.isin(nodes80).to_numpy()
        capacity_gib = np.where(is80, 80.0, 40.0)

        mask = np.isfinite(capacity_gib) & np.isclose(capacity_gib, capacity, atol=0.6)

        if req_jobids is not None:
            mask &= df["JobID"].astype(str).isin(req_jobids).to_numpy()

        if not mask.any():
            continue

        peak_hbm_bps = np.where(is80, HBM_80, HBM_40)

        flops_fp16 = fp16.to_numpy() * PEAK_FLOPS_FP16
        flops_fp32 = fp32.to_numpy() * PEAK_FLOPS_FP32
        flops_fp64 = fp64.to_numpy() * PEAK_FLOPS_FP64_VECTOR
        flops_tens_fp64 = tens.to_numpy() * PEAK_FLOPS_FP64_TENSOR
        hbm_bps = dram.to_numpy() * peak_hbm_bps

        mask_hbm = hbm_bps > 0

        vals_ai = np.full(len(df), np.nan, dtype=float)
        if ai_col == "AI_fp32":
            np.divide(flops_fp32, hbm_bps, out=vals_ai, where=mask_hbm)
        elif ai_col == "AI_fp64_vector":
            np.divide(flops_fp64, hbm_bps, out=vals_ai, where=mask_hbm)
        elif ai_col == "AI_tensor_fp64":
            np.divide(flops_tens_fp64, hbm_bps, out=vals_ai, where=mask_hbm)
        elif ai_col == "pseudo64_AI_tensor_fp64":
            pseudo64_flops = 0.25*flops_fp16 + 0.5*flops_fp32 + (flops_fp64 + flops_tens_fp64)
            np.divide(pseudo64_flops, hbm_bps, out=vals_ai, where=mask_hbm)
        else:
            raise KeyError(f"Unknown ai_col: {ai_col}")

        vals = vals_ai[mask]
        vals = vals[np.isfinite(vals)]
        if vals.size == 0:
            continue

        vals = vals[(vals > 0) & (vals >= AI_MIN) & (vals <= AI_MAX)]
        if vals.size == 0:
            continue

        h, _ = np.histogram(vals, bins=edges)
        counts += h
        total_inrange += int(vals.size)

    return counts, total_inrange


def plot_ai_overlay(parquet_path: str, col_key: str, title: str, ridge_key: str, req80_ids: set[str], nodes80: set[str]):
    """Overlay 40 GB vs. requested-80 GB AI histograms on A100 rooflines.

    Left axis: percentage of in-range samples per log-spaced AI bin. It shows
    40 GB nodes (all jobs) and 80 GB nodes (only jobs in ``req80_ids``), each
    normalised to its own sample count. Right axis (log): the roofline
    ``min(AI * HBM bandwidth, compute peak)`` in TF/s for each memory
    variant. Dashed vertical lines mark the two ridge points.

    Args:
        parquet_path: LDMS/DCGM parquet file.
        col_key: Key of ``AI_COLS`` selecting the AI variant.
        title: Figure title.
        ridge_key: Key of ``PEAKS_*``/``RIDGE_*`` ("FP32", "FP64", "TNSR").
        req80_ids: JobIDs that requested 80 GB; filters the 80 GB series.
        nodes80: Hostnames of the 80 GB nodes.

    Returns ``None``. It returns early, with a message, when neither series
    has in-range samples.
    """
    setup_local()
    colors = get_colors()

    edges = make_log_edges(AI_MIN, AI_MAX, BARS_PER_DECADE)
    lefts = edges[:-1]
    widths = np.diff(edges)

    ai_col = AI_COLS[col_key]

    c40, n40 = hist_from_parquet_edges(parquet_path, ai_col, 40.0, edges, nodes80=nodes80, req_jobids=None)
    c80, n80 = hist_from_parquet_edges(parquet_path, ai_col, 80.0, edges, nodes80=nodes80, req_jobids=req80_ids)

    # Previously an empty 40 GB series aborted the plot even when 80 GB data
    # existed. Skip only when both series are empty.
    if n40 == 0 and n80 == 0:
        print(f"[info] no in-range samples for {col_key}")
        return
    if n40 == 0:
        print(f"[info] no 40GB in-range samples for {col_key}")

    pct40 = (c40 / n40) * 100.0 if n40 > 0 else np.zeros_like(c40)
    pct80 = (c80 / n80) * 100.0 if n80 > 0 else np.zeros_like(c80)

    # Sanity check before drawing: each non-empty series must sum to 100 %.
    for n, pct in ((n40, pct40), (n80, pct80)):
        if n > 0:
            assert np.isclose(pct.sum(), 100.0, atol=1e-3), pct.sum()

    fig, ax1 = plt.subplots()

    if n40 > 0:
        ax1.bar(lefts, pct40, width=widths, align="edge", color=colors[2], alpha=0.50, edgecolor="black", label="40 GB")
    if n80 > 0:
        ax1.bar(lefts, pct80, width=widths, align="edge", color=colors[0], alpha=0.50, edgecolor="black", label="80 GB (requested 80)")

    ax1.set_xscale("log")
    ax1.set_xlim(AI_MIN, AI_MAX)
    ax1.set_xticks(X_TICKS)
    ax1.set_xlabel("Arithmetic Intensity (FLOP/Byte)", fontsize=18)
    ax1.set_ylabel("Fraction of samples (%)", fontsize=18)
    ax1.set_yticks(Y_TICKS_LEFT)
    ax1.grid(axis="y", linestyle="--", alpha=0.7)

    ax1.axvline(RIDGE_40[ridge_key], linestyle="--", color=colors[2])
    ax1.axvline(RIDGE_80[ridge_key], linestyle="--", color=colors[0])

    ax2 = ax1.twinx()
    # Sample the roofline over the same AI range as the histogram.
    x = np.logspace(np.log10(AI_MIN), np.log10(AI_MAX), 512)
    # RIDGE_*["HBM"] holds the HBM bandwidth (B/s): the memory-bound slope.
    perf40 = np.minimum(x * RIDGE_40["HBM"], PEAKS_40[ridge_key]) / 1e12
    perf80 = np.minimum(x * RIDGE_80["HBM"], PEAKS_80[ridge_key]) / 1e12

    ax2.plot(x, perf40, color=colors[2], linewidth=2)
    ax2.plot(x, perf80, color=colors[0], linewidth=2, alpha=0.85)
    ax2.set_xscale("log")
    ax2.set_yscale("log")
    ax2.set_ylabel("Performance roofline (TF/s)", fontsize=18)
    ax2.set_yticks(Y_TICKS_RIGHT)
    ax2.set_xticks(X_TICKS)

    ax1.tick_params(axis="x", labelsize=14)
    ax1.tick_params(axis="y", labelsize=14)
    ax2.tick_params(axis="y", labelsize=14)

    ax1.legend(loc="upper left", frameon=True, fontsize=13, framealpha=0.5)
    # Title the primary axes explicitly rather than whichever axes pyplot
    # considers current (the twin, after twinx()).
    ax1.set_title(title, fontsize=19, pad=12)
    fig.tight_layout()
    plt.show()


if __name__ == "__main__":
    nodes_80gb_set = load_nodes_from_file(NODES_80GB_FILE)
    req80_ids = load_req80_ids(SACCT_CSV)

    # (title, ridge_key) for every plot key PLOTS may contain. Previously only
    # "fp64vec" and "tensor" were handled, so "fp32" and "pseudo64" in PLOTS
    # were silently ignored.
    plot_specs = {
        "fp32": ("Distribution of FP32 AI", "FP32"),
        "fp64vec": ("Distribution of FP64 AI", "FP64"),
        "tensor": ("Distribution of Tensor (FP64-tensor) AI", "TNSR"),
        # NOTE(review): pseudo64 folds FP16/FP32/FP64/tensor work into one FP64
        # numerator; comparing against the FP64-tensor roof is an assumption.
        "pseudo64": ("Distribution of pseudo-FP64 AI", "TNSR"),
    }

    for plot_key in PLOTS:
        spec = plot_specs.get(plot_key)
        if spec is None:
            print(f"[warn] unknown plot key {plot_key!r}; skipping")
            continue
        plot_title, ridge = spec
        plot_ai_overlay(INPUT_SET1, plot_key, plot_title, ridge_key=ridge, req80_ids=req80_ids, nodes80=nodes_80gb_set)
