# In [None]:
from pathlib import Path
from collections import defaultdict

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.ticker import FixedLocator, LogFormatterMathtext, NullLocator

import pyarrow as pa
import pyarrow.dataset as ds
import pyarrow.parquet as pq

from setup_plot import setup_local, get_colors, get_markers


# SET1 (per-GPU LDMS samples): JobID, hostname, gpu_id, ts_ns, and the DCGM
# *_active counters listed in ACTIVE_COLS.
SET1_PARQUET = "ldms_set1.parquet"

# SET2 (per-GPU LDMS samples): JobID, hostname, gpu_id, ts_ns and the energy
# counter nersc_ldms_dcgm_total_energy_consumption (used as last - first per GPU).
SET2_PARQUET = "ldms_set2.parquet"

# Slurm accounting export (pipe-delimited or CSV) with at least JobID, Start, End.
SACCT_CSV = "slurm.csv"
# Hostnames whose GPUs use the HBM_80 bandwidth peak; all other hosts use HBM_40.
NODES_80GB_FILE = "nodes_80gb.txt"

# Outputs of the build_* stages; all three are read by plot_energy_vs_gpu_hours().
LABELS_FILE = "job_label_fractions_fp64only.parquet"
JOB_METRICS_FILE = "job_metrics.parquet"
ENERGY_PER_JOB = "energy_per_job.parquet"

# Rows per record batch when streaming SET1 / SET2.
BATCH_SET1 = 1_000_000
BATCH_SET2 = 1_000_000

# Timezone used to localize the naive sacct Start/End timestamps before converting to UTC.
SACCT_TZ = "US/Pacific"

DCGM_PREFIX = "nersc_ldms_dcgm_"
# DCGM activity ratios required in SET1; samples with any value > 1.0 are dropped
# as invalid by build_fp64only_job_fractions_and_ngpus().
ACTIVE_COLS = [
    f"{DCGM_PREFIX}fp16_active",
    f"{DCGM_PREFIX}fp32_active",
    f"{DCGM_PREFIX}fp64_active",
    f"{DCGM_PREFIX}tensor_active",
    f"{DCGM_PREFIX}dram_active",
]

# A100 peak throughput (FLOP/s). Of these, only PEAK_FLOPS_FP64_VECTOR is used
# by the FP64-only roofline code visible in this section.
PEAK_FLOPS_FP16        = 78.0e12
PEAK_FLOPS_FP32        = 19.5e12
PEAK_FLOPS_FP64_VECTOR = 9.7e12
PEAK_FLOPS_FP64_TENSOR = 19.5e12

# HBM peak bandwidth (B/s) for 40 GB and 80 GB GPUs; selected per sample by
# whether the sample's hostname is listed in NODES_80GB_FILE.
HBM_40 = 1.555e12
HBM_80 = 2.039e12


def load_nodes_from_file(path: str) -> set[str]:
    """Read the set of 80 GB-GPU node hostnames from a text file.

    Hostnames may be separated by whitespace and/or commas, one or several per
    line. Everything from a ``#`` to the end of a line is a comment, so both
    full-line and trailing comments are ignored.

    Args:
        path: Path to the node list.

    Returns:
        The set of hostnames found (possibly empty).

    Raises:
        FileNotFoundError: If ``path`` does not exist.
    """
    p = Path(path)
    if not p.exists():
        raise FileNotFoundError(f"80GB node list not found: {p.resolve()}")
    nodes: set[str] = set()
    for raw in p.read_text(encoding="utf-8", errors="replace").splitlines():
        # Strip the comment part first: checking only whether a token *starts*
        # with "#" let the words after "# ..." leak in as hostnames.
        content = raw.split("#", 1)[0]
        nodes.update(content.replace(",", " ").split())
    return nodes


def read_sacct_start_end(path: str) -> pd.DataFrame:
    """
    Load a Slurm sacct export and return one [start, end] window per JobID.

    If the header line contains ``|`` the file is parsed as pipe-delimited by
    hand: each row is split into at most as many fields as the header has
    (extra ``|`` characters stay inside the last field) and short rows are
    padded with empty strings. Otherwise the file goes through
    ``pandas.read_csv``.

    Start/End are parsed as naive local times in ``SACCT_TZ`` and converted to
    UTC. Rows whose timestamps do not parse, or fall on an ambiguous or
    nonexistent local time, are dropped. Repeated JobIDs are collapsed to the
    earliest start and the latest end.

    Returns:
        DataFrame with columns ``JobID`` (str), ``start_time`` and
        ``end_time`` (tz-aware, UTC).

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        KeyError: If any of the JobID/Start/End columns is missing.
    """
    src = Path(path)
    if not src.exists():
        raise FileNotFoundError(f"sacct file not found: {src.resolve()}")

    raw = None
    with src.open("r", encoding="utf-8", errors="replace") as fh:
        header = fh.readline()
        if "|" in header:
            cols = header.rstrip("\n").split("|")
            width = len(cols)
            records = []
            for line in fh:
                # maxsplit keeps any surplus "|" inside the final field.
                fields = line.rstrip("\n").split("|", width - 1)
                fields.extend([""] * (width - len(fields)))
                records.append(fields)
            raw = pd.DataFrame(records, columns=cols)
    if raw is None:
        raw = pd.read_csv(src, low_memory=False)

    for required in ("JobID", "Start", "End"):
        if required not in raw.columns:
            raise KeyError(f"sacct file missing required column: {required}")

    def to_utc(series: pd.Series) -> pd.Series:
        local = pd.to_datetime(series, errors="coerce")
        localized = local.dt.tz_localize(SACCT_TZ, ambiguous="NaT", nonexistent="NaT")
        return localized.dt.tz_convert("UTC")

    windows = pd.DataFrame({
        "JobID": raw["JobID"].astype(str),
        "start_time": to_utc(raw["Start"]),
        "end_time": to_utc(raw["End"]),
    }).dropna(subset=["start_time", "end_time"])

    return windows.groupby("JobID", as_index=False).agg(
        start_time=("start_time", "min"),
        end_time=("end_time", "max"),
    )


def _label_samples_fp64only(
    fp16: np.ndarray,
    fp32: np.ndarray,
    fp64: np.ndarray,
    tens: np.ndarray,
    dram: np.ndarray,
    missing_any: np.ndarray,
    is80: np.ndarray,
) -> np.ndarray:
    """Return per-sample labels ("Compute-intensive"/"Memory-intensive"/"Other").

    Inputs are equal-length 1-D arrays: the five activity ratios as floats,
    ``missing_any`` (True where any ratio was missing) and ``is80`` (True
    where the sample's host uses the 80 GB HBM peak). Only samples with no
    missing ratio are classified:

      * DRAM traffic <= 0 with any compute activity -> "Compute-intensive";
      * DRAM traffic > 0 -> FP64 arithmetic intensity (achieved FP64 FLOP/s
        over achieved HBM B/s) compared with the FP64 ridge point
        PEAK_FLOPS_FP64_VECTOR / HBM peak: below -> "Memory-intensive",
        otherwise "Compute-intensive".

    Everything else stays "Other".
    """
    n = len(dram)
    peak_hbm = np.where(is80, HBM_80, HBM_40)
    achieved_hbm = dram * peak_hbm
    ridge_fp64 = PEAK_FLOPS_FP64_VECTOR / peak_hbm
    achieved_fp64 = fp64 * PEAK_FLOPS_FP64_VECTOR

    ai_fp64 = np.full(n, np.nan, dtype=float)
    np.divide(achieved_fp64, achieved_hbm, out=ai_fp64, where=(achieved_hbm > 0))

    any_compute = (fp16 > 0) | (fp32 > 0) | (fp64 > 0) | (tens > 0)
    eligible = ~missing_any

    labels = np.full(n, "Other", dtype=object)
    # Compute activity without measurable DRAM traffic: intensity is unbounded.
    labels[eligible & any_compute & (achieved_hbm <= 0)] = "Compute-intensive"

    mask_pos = eligible & (achieved_hbm > 0) & np.isfinite(ai_fp64) & np.isfinite(ridge_fp64)
    mem = mask_pos & (ai_fp64 < ridge_fp64)
    labels[mem] = "Memory-intensive"
    labels[mask_pos & ~mem] = "Compute-intensive"
    return labels


def build_fp64only_job_fractions_and_ngpus(
    set1_path: str,
    nodes80: set[str],
    out_labels: str,
    batch_size: int = BATCH_SET1,
):
    """
    Label SET1 samples with an FP64-only roofline class and aggregate per job.

    SET1 is streamed in record batches of ``batch_size`` rows. In each batch:
      * samples with any activity ratio > 1.0, or with fp16/fp32/fp64/tensor
        all exactly 0.0, are dropped;
      * the remaining samples are labelled by ``_label_samples_fp64only``;
      * within each (JobID, hostname, gpu_id) series, the gap between two
        consecutive kept samples is credited to the label of the earlier one.

    Writes ``out_labels`` (parquet), one row per job with positive accumulated
    time: JobID, time_seconds (ts_ns gaps / 1e9), frac_time_compute_fp64only,
    frac_time_memory_fp64only, frac_time_other_fp64only and sample_count
    (kept samples, including those with non-finite gpu_id/ts_ns). If no job
    accumulates time, an empty file with those columns is written; the old
    code raised ``KeyError`` from ``sort_values`` in that case.

    NOTE(review): the time gap spanning dropped samples is still charged to
    the previous kept sample's label. Samples are assumed to arrive in
    increasing ts_ns order per series: an out-of-order row adds no time but
    becomes the new reference timestamp. Confirm both are intended.

    Args:
        set1_path: Parquet file with JobID, hostname, gpu_id, ts_ns and ACTIVE_COLS.
        nodes80: Hostnames whose GPUs use the HBM_80 peak.
        out_labels: Destination parquet path.
        batch_size: Rows per record batch.

    Returns:
        Mapping JobID -> number of distinct (hostname, gpu_id) series among
        kept samples with finite gpu_id and ts_ns.

    Raises:
        KeyError: If a required column is missing from ``set1_path``.
    """
    need = ["JobID", "hostname", "gpu_id", "ts_ns"] + ACTIVE_COLS
    pf = pq.ParquetFile(set1_path)
    present = set(pf.schema_arrow.names)
    missing = [c for c in need if c not in present]
    if missing:
        raise KeyError(f"Missing required columns in {set1_path}: {missing}")

    nodes80 = set(nodes80)

    # (job, host, gpu) -> (timestamp of last kept sample, its label)
    state = {}
    job_time = defaultdict(lambda: {"Compute-intensive": 0, "Memory-intensive": 0, "Other": 0})
    job_total = defaultdict(int)
    job_samples = defaultdict(int)

    for batch in pf.iter_batches(columns=need, batch_size=batch_size):
        df = batch.to_pandas()
        if df.empty:
            continue

        fp16 = pd.to_numeric(df[f"{DCGM_PREFIX}fp16_active"], errors="coerce")
        fp32 = pd.to_numeric(df[f"{DCGM_PREFIX}fp32_active"], errors="coerce")
        fp64 = pd.to_numeric(df[f"{DCGM_PREFIX}fp64_active"], errors="coerce")
        tens = pd.to_numeric(df[f"{DCGM_PREFIX}tensor_active"], errors="coerce")
        dram = pd.to_numeric(df[f"{DCGM_PREFIX}dram_active"], errors="coerce")

        missing_any = fp16.isna() | fp32.isna() | fp64.isna() | tens.isna() | dram.isna()
        # Ratios above 1.0 are treated as invalid samples.
        bad_gt1 = (fp16 > 1.0) | (fp32 > 1.0) | (fp64 > 1.0) | (tens > 1.0) | (dram > 1.0)
        # Samples with every compute pipe exactly at 0 are dropped too.
        all_fp_zero = fp16.eq(0.0) & fp32.eq(0.0) & fp64.eq(0.0) & tens.eq(0.0)

        keep = ~(bad_gt1 | all_fp_zero)
        if not keep.any():
            continue

        job = df["JobID"].astype(str).loc[keep].to_numpy()
        host = df["hostname"].astype(str).loc[keep].to_numpy()
        gpu = pd.to_numeric(df["gpu_id"], errors="coerce").loc[keep].to_numpy()
        ts = pd.to_numeric(df["ts_ns"], errors="coerce").loc[keep].to_numpy()

        is80 = np.fromiter((h in nodes80 for h in host), dtype=bool, count=len(host))
        labels = _label_samples_fp64only(
            fp16.loc[keep].to_numpy(dtype=float, copy=False),
            fp32.loc[keep].to_numpy(dtype=float, copy=False),
            fp64.loc[keep].to_numpy(dtype=float, copy=False),
            tens.loc[keep].to_numpy(dtype=float, copy=False),
            dram.loc[keep].to_numpy(dtype=float, copy=False),
            missing_any.loc[keep].to_numpy(dtype=bool, copy=False),
            is80,
        )

        # Sequential per-series time accounting; state carries across batches.
        for j, h, g, t, lab in zip(job, host, gpu, ts, labels):
            job_samples[j] += 1
            if not np.isfinite(g) or not np.isfinite(t):
                continue
            key = (j, h, int(g))
            t = int(t)

            prev = state.get(key)
            if prev is not None:
                dt = t - prev[0]
                # Duplicate or out-of-order timestamps contribute no time.
                if dt > 0:
                    job_time[j][prev[1]] += dt
                    job_total[j] += dt

            state[key] = (t, lab)

    job_ngpus = defaultdict(int)
    for j, _h, _g in state:
        job_ngpus[j] += 1

    out_cols = [
        "JobID",
        "time_seconds",
        "frac_time_compute_fp64only",
        "frac_time_memory_fp64only",
        "frac_time_other_fp64only",
        "sample_count",
    ]
    rows = []
    for j, tot in job_total.items():
        if tot <= 0:
            continue
        per_label = job_time[j]
        rows.append({
            "JobID": j,
            "time_seconds": float(tot) / 1e9,
            "frac_time_compute_fp64only": per_label["Compute-intensive"] / tot,
            "frac_time_memory_fp64only": per_label["Memory-intensive"] / tot,
            "frac_time_other_fp64only": per_label["Other"] / tot,
            "sample_count": int(job_samples.get(j, 0)),
        })

    # Explicit columns so an empty result still has a "JobID" to sort on.
    out = pd.DataFrame(rows, columns=out_cols).sort_values("JobID")
    out.to_parquet(out_labels, index=False)

    return job_ngpus


def build_job_metrics_gpu_hours(job_ngpus: dict, sacct_csv: str, out_path: str):
    """
    Write per-job GPU-hours: sacct wall-clock duration times GPUs seen in LDMS.

    Args:
        job_ngpus: Mapping JobID -> number of GPUs observed for that job.
        sacct_csv: sacct export, parsed by ``read_sacct_start_end``.
        out_path: Destination parquet with columns ``JobID`` and ``gpu_hours``.

    Only jobs present in both inputs are written. A negative duration (end
    before start) is clamped to 0.
    """
    windows = read_sacct_start_end(sacct_csv)
    elapsed = windows["end_time"] - windows["start_time"]
    windows["duration_hours"] = (elapsed.dt.total_seconds() / 3600.0).clip(lower=0)

    gpu_counts = (
        pd.Series(job_ngpus, name="ngpus")
        .astype(int)
        .rename_axis("JobID")
        .reset_index()
    )
    gpu_counts["JobID"] = gpu_counts["JobID"].astype(str)

    joined = windows.merge(gpu_counts, on="JobID", how="inner")
    joined["gpu_hours"] = joined["duration_hours"] * joined["ngpus"]

    joined[["JobID", "gpu_hours"]].to_parquet(out_path, index=False)


def build_energy_per_job(set2_path: str, out_path: str):
    """
    Compute per-job energy from the DCGM energy counter and write it to parquet.

    Each (JobID, hostname, gpu_id) series contributes the counter value at its
    latest ``ts_ns`` minus the value at its earliest ``ts_ns``. The per-job
    total is the sum over that job's series.

    Earliest and latest samples are chosen by timestamp, not by position in the
    file. The result therefore does not depend on how rows are ordered within
    or across record batches. The previous implementation took the first and
    last *rows* seen, which is only correct for time-sorted input. On equal
    timestamps, the row seen first is kept as "earliest" and the row seen last
    as "latest".

    Units are those of the input counter. The output column is named
    ``total_energy_joules`` for compatibility, but plot_energy_vs_gpu_hours()
    treats it as millijoules and divides by 1000.

    NOTE(review): if the counter resets inside a job, last - first is wrong
    (possibly negative). This is not detected here.

    Args:
        set2_path: Parquet file/directory readable by ``pyarrow.dataset``.
        out_path: Destination parquet with columns ``JobID`` and
            ``total_energy_joules``.
    """
    COL_JOB = "JobID"
    COL_HOST = "hostname"
    COL_GPU = "gpu_id"
    COL_TS = "ts_ns"
    COL_ENE = "nersc_ldms_dcgm_total_energy_consumption"
    keys = [COL_JOB, COL_HOST, COL_GPU]

    dataset = ds.dataset(set2_path, format="parquet")
    scanner = dataset.scanner(
        columns=[COL_JOB, COL_HOST, COL_GPU, COL_TS, COL_ENE],
        batch_size=BATCH_SET2,
        use_threads=True,
    )

    # (job, host, gpu) -> [earliest_ts, energy_at_earliest, latest_ts, energy_at_latest]
    state = {}

    for batch in scanner.to_batches():
        if batch.num_rows == 0:
            continue
        pdf = pa.Table.from_batches([batch]).to_pandas(types_mapper=pd.ArrowDtype)

        pdf[COL_JOB]  = pdf[COL_JOB].astype("string")
        pdf[COL_HOST] = pdf[COL_HOST].astype("string")
        pdf[COL_GPU]  = pd.to_numeric(pdf[COL_GPU], errors="coerce").astype("Int64")
        pdf[COL_TS]   = pd.to_numeric(pdf[COL_TS],  errors="coerce").astype("Int64")
        pdf[COL_ENE]  = pd.to_numeric(pdf[COL_ENE], errors="coerce").astype("Float64")

        pdf = pdf.dropna(subset=[COL_TS, COL_ENE, COL_GPU, COL_JOB, COL_HOST])
        if pdf.empty:
            continue

        # Within this batch, take each series' earliest/latest sample by time
        # (stable sort keeps file order among equal timestamps).
        pdf = pdf.sort_values(COL_TS, kind="stable")
        ext = pdf.groupby(keys, sort=False, observed=True).agg(
            first_ts=(COL_TS, "first"),
            first_e=(COL_ENE, "first"),
            last_ts=(COL_TS, "last"),
            last_e=(COL_ENE, "last"),
        )

        # Merge with what earlier batches saw for the same series.
        for (j, h, g), f_ts, f_e, l_ts, l_e in zip(
            ext.index, ext["first_ts"], ext["first_e"], ext["last_ts"], ext["last_e"]
        ):
            key = (str(j), str(h), int(g))
            f_ts, l_ts = int(f_ts), int(l_ts)
            f_e, l_e = float(f_e), float(l_e)

            cur = state.get(key)
            if cur is None:
                state[key] = [f_ts, f_e, l_ts, l_e]
                continue
            if f_ts < cur[0]:
                cur[0], cur[1] = f_ts, f_e
            if l_ts >= cur[2]:
                cur[2], cur[3] = l_ts, l_e

    rows_gpu = [(job, st[3] - st[1]) for (job, _h, _g), st in state.items()]

    df_gpu = pd.DataFrame(rows_gpu, columns=["JobID", "energy_delta"])
    df_job = df_gpu.groupby("JobID", as_index=False).agg(
        total_energy_joules=("energy_delta", "sum")
    )

    df_job.to_parquet(out_path, index=False)


def plot_energy_vs_gpu_hours():
    """
    Box-plot per-job energy against GPU-hours, split by FP64-only roofline class.

    Steps:
      * Label each job "Compute-bound" or "Memory-bound" by whichever FP64-only
        time fraction from LABELS_FILE is larger. Ties and NaNs are excluded.
      * Join with ENERGY_PER_JOB (value treated as mJ, converted to J) and
        JOB_METRICS_FILE (gpu_hours).
      * Keep jobs with positive energy and at least 0.5 GPU-hours.
      * Bin GPU-hours into two log-spaced bins per decade between 1e-1 and 1e5.
        Values outside that range are clamped into the first or last bin.
      * Draw a pair of boxes per bin: 5th-95th percentile whiskers, no fliers.

    The previous version also computed per-bin means and counts that were never
    used; that dead computation has been removed.
    """
    lab = pd.read_parquet(
        LABELS_FILE,
        columns=["JobID", "frac_time_compute_fp64only", "frac_time_memory_fp64only"],
    ).astype({"JobID": str})

    comp_mask = lab["frac_time_compute_fp64only"] > lab["frac_time_memory_fp64only"]
    mem_mask  = lab["frac_time_memory_fp64only"]  > lab["frac_time_compute_fp64only"]
    decided = comp_mask | mem_mask

    labels = pd.DataFrame({
        "JobID": lab.loc[decided, "JobID"],
        "class": np.where(comp_mask[decided], "Compute-bound", "Memory-bound"),
    })

    # The per-job energy value is treated as millijoules; convert to joules.
    epj = (pd.read_parquet(ENERGY_PER_JOB, columns=["JobID", "total_energy_joules"])
             .astype({"JobID": str}))
    epj["total_energy_j"] = pd.to_numeric(epj["total_energy_joules"], errors="coerce") / 1000.0
    epj = epj[["JobID", "total_energy_j"]]

    jm = (pd.read_parquet(JOB_METRICS_FILE, columns=["JobID", "gpu_hours"])
            .astype({"JobID": str}))
    jm["gpu_hours"] = pd.to_numeric(jm["gpu_hours"], errors="coerce")

    df = (labels.merge(epj, on="JobID", how="inner")
                .merge(jm,  on="JobID", how="inner"))

    df = df[(df["total_energy_j"] > 0) & (df["gpu_hours"] >= 0.5)].copy()

    # Bin edges: each decade [a, b] split at its geometric midpoint.
    major_ticks = np.array([1e-1, 1e0, 1e1, 1e2, 1e3, 1e4, 1e5], dtype=float)
    edges_list = []
    for i, (a, b) in enumerate(zip(major_ticks[:-1], major_ticks[1:])):
        sub_edges = a * (b / a) ** (np.linspace(0, 1, 3))
        # A decade's lower edge duplicates the previous decade's upper edge.
        edges_list.append(sub_edges if i == 0 else sub_edges[1:])
    edges = np.concatenate(edges_list)
    widths = np.diff(edges)
    nbins = len(edges) - 1

    bin_idx = np.digitize(df["gpu_hours"].to_numpy(dtype=float), edges, right=False) - 1
    df["gh_bin"] = np.clip(bin_idx, 0, nbins - 1)

    def values_per_bin(class_name: str):
        sub = df[df["class"] == class_name]
        return [sub.loc[sub["gh_bin"] == i, "total_energy_j"].to_numpy() for i in range(nbins)]

    vals_comp = values_per_bin("Compute-bound")
    vals_mem  = values_per_bin("Memory-bound")

    setup_local()
    colors  = get_colors()
    markers = get_markers()

    fig, ax = plt.subplots()

    # Boxes sit left/right of each bin's geometric centre; box width is scaled
    # by position/centre (presumably to look balanced on the log x-axis).
    centers = np.sqrt(edges[:-1] * edges[1:])
    offset  = widths * 0.2

    def draw_boxes(values, positions, facecolor):
        bp = ax.boxplot(
            values, positions=positions, widths=(widths * 0.40) * (positions / centers),
            manage_ticks=False, patch_artist=True, showfliers=False,
            whis=[5, 95], showmeans=False,
            meanprops=dict(marker=markers[1], markerfacecolor=colors[5], markeredgecolor=colors[5]),
        )
        for elem in ("boxes", "whiskers", "caps", "medians"):
            for artist in bp[elem]:
                artist.set_color("white" if elem == "medians" else "black")
        for patch in bp["boxes"]:
            patch.set_facecolor(facecolor)
            patch.set_alpha(1)
            patch.set_edgecolor("black")

    draw_boxes(vals_comp, centers - offset, colors[4])
    draw_boxes(vals_mem,  centers + offset, colors[5])

    ax.set_xscale("log")
    ax.set_yscale("log")
    ax.set_xlim(1e-1, 1e5)
    ax.set_ylim(1e5, 1e11)

    ax.xaxis.set_major_locator(FixedLocator(major_ticks))
    ax.xaxis.set_major_formatter(LogFormatterMathtext(base=10))
    ax.xaxis.set_minor_locator(NullLocator())

    ax.set_yticks([1e5, 1e7, 1e9, 1e11])

    ax.set_xlabel("GPU-hours (number of GPUs x duration)", fontsize=17)
    ax.set_ylabel("Total energy per job (J)", fontsize=17)
    ax.set_title("Distribution of total energy consumption", pad=14, fontsize=18)
    ax.grid(axis="y", linestyle="--", alpha=0.7)
    ax.tick_params(axis="x", labelsize=16)
    ax.tick_params(axis="y", labelsize=16)

    ax.legend(
        frameon=True, fontsize=14, loc="upper left",
        handles=[
            plt.Rectangle((0, 0), 1, 1, facecolor=colors[4], alpha=1, edgecolor="black", label="Compute-bound"),
            plt.Rectangle((0, 0), 1, 1, facecolor=colors[5], alpha=1, edgecolor="black", label="Memory-bound"),
        ],
    )

    fig.set_size_inches(7, 3, forward=True)
    fig.subplots_adjust(left=0.12, right=0.98, bottom=0.19, top=0.86)
    plt.show()



if __name__ == "__main__":
    # End-to-end pipeline; every stage reads/writes the paths defined at the top.
    # 1) Hostnames whose GPUs use the HBM_80 peak in the roofline labelling.
    nodes_80gb_set = load_nodes_from_file(NODES_80GB_FILE)

    # 2) Per-job FP64-only time fractions -> LABELS_FILE; also returns JobID -> #GPUs seen.
    job_ngpus = build_fp64only_job_fractions_and_ngpus(SET1_PARQUET, nodes_80gb_set, LABELS_FILE)

    # 3) GPU-hours = sacct wall-clock duration x #GPUs -> JOB_METRICS_FILE.
    build_job_metrics_gpu_hours(job_ngpus, SACCT_CSV, JOB_METRICS_FILE)

    # 4) Per-job energy from the DCGM energy counter -> ENERGY_PER_JOB.
    build_energy_per_job(SET2_PARQUET, ENERGY_PER_JOB)

    # 5) Join the three outputs and draw the energy vs GPU-hours box plot.
    plot_energy_vs_gpu_hours()
