In [None]:
import os
import duckdb
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import pyarrow.parquet as pq
from collections import defaultdict
from pathlib import Path

from setup_plot import setup_local


# Inputs: the telemetry Parquet file and the text file listing hostnames that
# are treated as 80 GB nodes (see load_nodes_80gb).
SET1: str = "ldms_set1.parquet"
NODES_80GB_FILE: str = "nodes_80gb.txt"

# Parquet artifacts written by the build_* steps and read back for plotting.
OUT_MEANS: str = "job_means_all.parquet"
OUT_METRICS: str = "job_metrics.parquet"
OUT_LABELS: str = "job_label_fractions_fp64only.parquet"

# Telemetry column averaged into the per-job "mean_utilization" metric.
UTIL_COL: str = "nersc_ldms_dcgm_gpu_utilization"

def build_job_means_all():
    """Write per-job mean FP64 and DRAM activity from ``SET1`` to ``OUT_MEANS``.

    Means are computed in two stages. The first stage averages per
    (JobID, hostname, gpu_id). The second averages those per-GPU means
    within each JobID, so every GPU weighs equally in the job mean
    regardless of how many samples it reported. JobID is cast to VARCHAR.

    The DuckDB connection is closed even if a statement fails.
    """
    COUNTERS = [
        "nersc_ldms_dcgm_fp64_active",
        "nersc_ldms_dcgm_dram_active",
    ]

    def avg_expr(cols):
        # Render "avg(c) AS c" for each column as a comma-separated SELECT list.
        return ",\n       ".join([f"avg({c}) AS {c}" for c in cols])

    con = duckdb.connect()
    try:
        con.execute("PRAGMA memory_limit='15GB';")
        con.execute(f"PRAGMA threads={os.cpu_count() or 1};")

        con.execute(f"""
    COPY (
      SELECT
        JobID,
        {avg_expr(COUNTERS)}
      FROM (
        SELECT
          JobID::VARCHAR AS JobID,
          hostname,
          gpu_id,
          {avg_expr(COUNTERS)}
        FROM parquet_scan('{SET1}')
        GROUP BY JobID, hostname, gpu_id
      )
      GROUP BY JobID
    )
    TO '{OUT_MEANS}' (FORMAT PARQUET, COMPRESSION 'SNAPPY');
    """)
    finally:
        # Release the connection on both success and failure paths.
        con.close()

def build_job_metrics_mean_util():
    """Write per-job mean of ``UTIL_COL`` from ``SET1`` to ``OUT_METRICS``.

    The output column is ``mean_utilization``. It is computed in two stages:
    first a per-(JobID, hostname, gpu_id) average, then the average of those
    per-GPU means within each JobID. JobID is cast to VARCHAR.

    The DuckDB connection is closed even if a statement fails.
    """
    con = duckdb.connect()
    try:
        con.execute("PRAGMA memory_limit='15GB';")
        con.execute(f"PRAGMA threads={os.cpu_count() or 1};")

        con.execute(f"""
    COPY (
      SELECT
        JobID,
        AVG(gpu_mean) AS mean_utilization
      FROM (
        SELECT
          JobID::VARCHAR AS JobID,
          hostname,
          gpu_id,
          AVG({UTIL_COL}) AS gpu_mean
        FROM parquet_scan('{SET1}')
        GROUP BY JobID, hostname, gpu_id
      )
      GROUP BY JobID
    )
    TO '{OUT_METRICS}' (FORMAT PARQUET, COMPRESSION 'SNAPPY');
    """)
    finally:
        # Release the connection on both success and failure paths.
        con.close()

def load_nodes_80gb(path: str) -> set[str]:
    """Return the set of hostnames listed in the text file at *path*.

    A line may hold several hostnames separated by commas and/or whitespace.
    Everything from a ``#`` to the end of the line is a comment, so both
    full-line comments and trailing comments (``nid001  # note``) are ignored.
    Blank lines are skipped.

    Raises:
        FileNotFoundError: if *path* does not exist.
    """
    p = Path(path)
    if not p.exists():
        raise FileNotFoundError(f"Missing {p.resolve()}")
    nodes: set[str] = set()
    for raw in p.read_text(encoding="utf-8", errors="replace").splitlines():
        # Strip the comment before tokenizing. Checking tokens one by one
        # would skip the "#" but keep the words that follow it as hostnames.
        content = raw.split("#", 1)[0]
        nodes.update(content.replace(",", " ").split())
    return nodes

def build_fp64only_label_fractions():
    """Label telemetry samples against an FP64 roofline; write per-job time fractions.

    Reads ``SET1`` in batches of ``BATCH`` rows.

    Dropped samples: any activity column > 1.0, or FP16/FP32/FP64/tensor
    activity all exactly 0.

    Each kept sample gets one label:

    * ``Compute-intensive``: any FP16/FP32/FP64/tensor activity is non-zero
      while achieved DRAM bandwidth is <= 0. It also applies when FP64
      arithmetic intensity is at or above the ridge point.
    * ``Memory-intensive``: FP64 arithmetic intensity is below the ridge point.
    * ``Other``: everything else, including samples with any missing
      activity value.

    The quantities used are:

    * FP64 arithmetic intensity =
      ``fp64_active * PEAK_FLOPS_FP64_VECTOR / (dram_active * peak_hbm)``.
    * Ridge point = ``PEAK_FLOPS_FP64_VECTOR / peak_hbm``.
    * ``peak_hbm`` is ``HBM_80`` for hostnames in ``NODES_80GB_FILE`` and
      ``HBM_40`` otherwise.

    Time is attributed per (JobID, hostname, gpu_id) stream. The gap to the
    next kept sample is credited to the earlier sample's label. Gaps that
    are not positive are ignored. NOTE(review): this presumably expects each
    GPU's samples to arrive in timestamp order; confirm for ``SET1``. Dropped
    samples create no boundary, so their time goes to the previous kept
    label.

    Writes ``OUT_LABELS`` with these columns:

    * JobID, time_seconds (``ts_ns`` deltas / 1e9).
    * The compute/memory/other fractions.
    * sample_count, which counts kept samples.

    The file is written with those columns even when no job has positive
    time.

    Raises:
        KeyError: if ``SET1`` lacks a required column.
        FileNotFoundError: if ``NODES_80GB_FILE`` is missing.
    """
    DCGM_PREFIX = "nersc_ldms_dcgm_"
    ACTIVE_COLS = [
        f"{DCGM_PREFIX}fp16_active",
        f"{DCGM_PREFIX}fp32_active",
        f"{DCGM_PREFIX}fp64_active",
        f"{DCGM_PREFIX}tensor_active",
        f"{DCGM_PREFIX}dram_active",
    ]

    # Peak FP64 vector FLOP/s and peak HBM bandwidth (bytes/s) per memory size.
    # NOTE(review): these match NVIDIA A100 40/80 GB datasheet figures; confirm
    # that is the target GPU.
    PEAK_FLOPS_FP64_VECTOR = 9.7e12
    HBM_40 = 1.555e12
    HBM_80 = 2.039e12
    BATCH = 1_000_000

    # Output schema, fixed so that an empty result still has these columns.
    OUT_COLS = [
        "JobID",
        "time_seconds",
        "frac_time_compute_fp64only",
        "frac_time_memory_fp64only",
        "frac_time_other_fp64only",
        "sample_count",
    ]

    nodes80 = load_nodes_80gb(NODES_80GB_FILE)

    need_cols = ["JobID", "hostname", "gpu_id", "ts_ns"] + ACTIVE_COLS
    pf = pq.ParquetFile(SET1)
    present = set(pf.schema_arrow.names)
    missing = [c for c in need_cols if c not in present]
    if missing:
        raise KeyError(f"Missing required columns in {SET1}: {missing}")

    # (JobID, hostname, gpu_id) -> (last timestamp, last label); persists
    # across batches so that intervals spanning a batch boundary are counted.
    state = {}
    job_time = defaultdict(lambda: {"Compute-intensive": 0, "Memory-intensive": 0, "Other": 0})
    job_total = defaultdict(int)
    job_samples = defaultdict(int)

    for batch in pf.iter_batches(columns=need_cols, batch_size=BATCH):
        df = batch.to_pandas()
        if df.empty:
            continue

        job  = df["JobID"].astype(str)
        host = df["hostname"].astype(str)
        gpu  = pd.to_numeric(df["gpu_id"], errors="coerce")
        ts   = pd.to_numeric(df["ts_ns"], errors="coerce")

        fp16 = pd.to_numeric(df[f"{DCGM_PREFIX}fp16_active"], errors="coerce")
        fp32 = pd.to_numeric(df[f"{DCGM_PREFIX}fp32_active"], errors="coerce")
        fp64 = pd.to_numeric(df[f"{DCGM_PREFIX}fp64_active"], errors="coerce")
        tens = pd.to_numeric(df[f"{DCGM_PREFIX}tensor_active"], errors="coerce")
        dram = pd.to_numeric(df[f"{DCGM_PREFIX}dram_active"], errors="coerce")

        missing_any = fp16.isna() | fp32.isna() | fp64.isna() | tens.isna() | dram.isna()
        bad_gt1 = (fp16 > 1.0) | (fp32 > 1.0) | (fp64 > 1.0) | (tens > 1.0) | (dram > 1.0)
        all_fp_zero = (fp16.eq(0.0)) & (fp32.eq(0.0)) & (fp64.eq(0.0)) & (tens.eq(0.0))

        # Drop out-of-range samples and samples with no compute activity at all.
        keep = ~(bad_gt1 | all_fp_zero)
        if not keep.any():
            continue

        job  = job.loc[keep].to_numpy()
        host = host.loc[keep].to_numpy()
        gpu  = gpu.loc[keep].to_numpy()
        ts   = ts.loc[keep].to_numpy()

        fp16 = fp16.loc[keep].to_numpy(dtype=float, copy=False)
        fp32 = fp32.loc[keep].to_numpy(dtype=float, copy=False)
        fp64 = fp64.loc[keep].to_numpy(dtype=float, copy=False)
        tens = tens.loc[keep].to_numpy(dtype=float, copy=False)
        dram = dram.loc[keep].to_numpy(dtype=float, copy=False)
        missing_any = missing_any.loc[keep].to_numpy(dtype=bool, copy=False)

        is80 = np.fromiter((h in nodes80 for h in host), dtype=bool, count=len(host))
        peak_hbm = np.where(is80, HBM_80, HBM_40)
        achieved_hbm = dram * peak_hbm

        ridge_fp64 = PEAK_FLOPS_FP64_VECTOR / peak_hbm
        achieved_fp64 = fp64 * PEAK_FLOPS_FP64_VECTOR

        # Arithmetic intensity (FLOP/byte); NaN where there is no DRAM traffic.
        ai_fp64 = np.full(len(job), np.nan, dtype=float)
        np.divide(achieved_fp64, achieved_hbm, out=ai_fp64, where=(achieved_hbm > 0))

        any_compute = (fp16 > 0) | (fp32 > 0) | (fp64 > 0) | (tens > 0)
        eligible = ~missing_any

        labels = np.full(len(job), "Other", dtype=object)

        # Compute activity with no memory traffic: infinite intensity.
        mask_inf = eligible & any_compute & (achieved_hbm <= 0)
        labels[mask_inf] = "Compute-intensive"

        mask_pos = eligible & (achieved_hbm > 0) & np.isfinite(ai_fp64) & np.isfinite(ridge_fp64)
        mem = mask_pos & (ai_fp64 < ridge_fp64)
        comp = mask_pos & ~mem
        labels[mem] = "Memory-intensive"
        labels[comp] = "Compute-intensive"

        for j, h, g, t, lab in zip(job, host, gpu, ts, labels):
            job_samples[j] += 1
            if not np.isfinite(g) or not np.isfinite(t):
                continue
            key = (j, h, int(g))
            t = int(t)

            if key in state:
                last_ts, last_lab = state[key]
                dt = t - last_ts
                # Non-positive gaps (duplicates / out-of-order) add no time.
                if dt > 0:
                    job_time[j][last_lab] += dt
                    job_total[j] += dt

            state[key] = (t, lab)

    rows = []
    for j, tot in job_total.items():
        if tot <= 0:
            continue
        c = job_time[j]["Compute-intensive"]
        m = job_time[j]["Memory-intensive"]
        o = job_time[j]["Other"]
        rows.append({
            "JobID": j,
            "time_seconds": float(tot) / 1e9,
            "frac_time_compute_fp64only": c / tot,
            "frac_time_memory_fp64only":  m / tot,
            "frac_time_other_fp64only":   o / tot,
            "sample_count": int(job_samples.get(j, 0)),
        })

    # Explicit columns: an empty `rows` would otherwise give a column-less
    # frame and sort_values("JobID") would raise KeyError.
    out = pd.DataFrame(rows, columns=OUT_COLS).sort_values("JobID")
    out.to_parquet(OUT_LABELS, index=False)

def bin_labels(edges):
    """Build one interval label per histogram bin described by *edges*.

    Bin ``i`` runs from ``edges[i]`` to ``edges[i + 1]``. The lower bound is
    floored and the upper bound ceiled to integers. All bins are rendered
    half-open (``[a,b)``) except the final one, which is closed (``[a,b]``).
    Fewer than two edges yield an empty list.
    """
    last_bin = len(edges) - 2

    def _label(i):
        lo = int(np.floor(edges[i]))
        hi = int(np.ceil(edges[i + 1]))
        return f"[{lo},{hi}]" if i == last_bin else f"[{lo},{hi})"

    return [_label(i) for i in range(last_bin + 1)]

def build_and_plot(joined, job_ids, title_prefix, save_name, out_dir="pssg-plots"):
    """Plot mean GPU utilization on a 10x10 grid of FP64 vs DRAM activity.

    Jobs are binned by ``x_pct`` (x axis) and ``y_pct`` (y axis) into ten
    equal bins on [0, 100]. Each cell shows the mean ``mean_gpu_util`` of
    its jobs on a 0-100 ``Blues`` scale. Cells with no jobs are painted
    orange. Points outside [0, 100] are not counted. Rows with a
    non-finite x, y or utilization are excluded.

    Parameters
    ----------
    joined : pandas.DataFrame
        Must contain ``jobid``, ``x_pct``, ``y_pct`` and ``mean_gpu_util``.
    job_ids : iterable
        Values of ``joined["jobid"]`` to include.
    title_prefix : str
        Figure title.
    save_name : str or None
        File name of the saved figure inside *out_dir*. A falsy value skips
        saving.
    out_dir : str, default "pssg-plots"
        Directory for the saved figure; created if missing.

    Returns
    -------
    tuple of pandas.DataFrame
        ``(mean_df, cnt_df)``: per-cell mean utilization and job counts.
        Rows are y bins and columns are x bins, both labelled with
        ``bin_labels``.
    """
    df = joined[joined["jobid"].isin(job_ids)]
    x = df["x_pct"].to_numpy(dtype=float, na_value=np.nan)
    y = df["y_pct"].to_numpy(dtype=float, na_value=np.nan)
    z = df["mean_gpu_util"].to_numpy(dtype=float, na_value=np.nan)

    # A single NaN weight would make its cell's sum (and hence mean) NaN and
    # the cell would be painted as "empty"; drop such rows up front.
    ok = np.isfinite(x) & np.isfinite(y) & np.isfinite(z)
    x, y, z = x[ok], y[ok], z[ok]

    x_edges = np.linspace(0.0, 100.0, 11)
    y_edges = np.linspace(0.0, 100.0, 11)

    # Per-cell utilization sum and job count; their ratio is the cell mean.
    hist_sum, _, _ = np.histogram2d(x, y, bins=[x_edges, y_edges], weights=z, density=False)
    counts,   _, _ = np.histogram2d(x, y, bins=[x_edges, y_edges], density=False)
    mean_grid = np.divide(hist_sum, counts, out=np.full_like(hist_sum, np.nan), where=counts > 0)

    x_labs = bin_labels(x_edges)
    y_labs = bin_labels(y_edges)
    # histogram2d indexes [x, y]; transpose so rows follow the y axis.
    mean_df = pd.DataFrame(mean_grid.T, index=y_labs, columns=x_labs)
    cnt_df  = pd.DataFrame(counts.T,    index=y_labs, columns=x_labs)

    pd.set_option("display.max_columns", None)
    pd.set_option("display.width", 220)

    setup_local()

    plt.figure(figsize=(10, 8))
    plt.imshow(
        mean_grid.T, origin="lower", aspect="auto",
        extent=[0.0, 100.0, 0.0, 100.0],
        cmap="Blues", vmin=0, vmax=100
    )
    cbar = plt.colorbar(label="")
    cbar.ax.tick_params(labelsize=36)

    plt.xlabel("Mean (of FP64_ACTV) %", fontsize=38)
    plt.ylabel("Mean (of DRAM_ACTV) %", fontsize=38)
    plt.xticks([0, 20, 40, 60, 80, 100], fontsize=36)
    plt.yticks([0, 20, 40, 60, 80, 100], fontsize=36)
    plt.title(title_prefix, pad=16, fontsize=40)

    # Overlay: solid orange on empty cells (value -1), transparent elsewhere (NaN).
    orange_cmap = ListedColormap(["#FF8C00"])
    plt.imshow(
        np.where(np.isnan(mean_grid.T), -1, np.nan),
        origin="lower", aspect="auto",
        extent=[0.0, 100.0, 0.0, 100.0],
        cmap=orange_cmap, vmin=-1, vmax=0, alpha=1
    )

    plt.tight_layout()
    # Save before show(): inline backends close the figure once it is shown.
    if save_name:
        os.makedirs(out_dir, exist_ok=True)
        plt.savefig(os.path.join(out_dir, save_name), bbox_inches="tight")
    plt.show()

    return mean_df, cnt_df

def plot_fp64_dram_heatmaps():
    """Draw the compute-bound and memory-bound FP64/DRAM utilization heat maps.

    Jobs are classified by time fraction:

    * compute-bound when the compute fraction is strictly greater than the
      memory fraction;
    * memory-bound in the reverse case;
    * ties belong to neither.

    Per-job FP64/DRAM means, scaled to percent, are inner-joined on job ID
    with per-job mean utilization. Each class is then passed to
    ``build_and_plot``.
    """
    os.makedirs("pssg-plots", exist_ok=True)

    fractions = pd.read_parquet(
        OUT_LABELS,
        columns=["JobID", "frac_time_compute_fp64only", "frac_time_memory_fp64only"],
    ).astype({"JobID": str})
    comp_frac = fractions["frac_time_compute_fp64only"]
    mem_frac = fractions["frac_time_memory_fp64only"]
    compute_ids = fractions.loc[comp_frac > mem_frac, "JobID"]
    memory_ids = fractions.loc[mem_frac > comp_frac, "JobID"]

    fp64_col = "nersc_ldms_dcgm_fp64_active"
    dram_col = "nersc_ldms_dcgm_dram_active"
    activity = pd.read_parquet(OUT_MEANS, columns=["JobID", fp64_col, dram_col])
    activity = activity.rename(columns={"JobID": "jobid", fp64_col: "fp64", dram_col: "dram"})
    activity["jobid"] = activity["jobid"].astype(str)
    activity["x_pct"] = activity["fp64"] * 100.0
    activity["y_pct"] = activity["dram"] * 100.0

    util = pd.read_parquet(OUT_METRICS, columns=["JobID", "mean_utilization"])
    util = util.rename(columns={"JobID": "jobid", "mean_utilization": "mean_gpu_util"})
    util["jobid"] = util["jobid"].astype(str)

    joined = activity.merge(util, on="jobid", how="inner")

    panels = (
        (compute_ids, "Compute-bound jobs\nMean (of GPU_UTIL) %", "comp_fp64_dram_gputil.pdf"),
        (memory_ids, "Memory-bound jobs\nMean (of GPU_UTIL) %", "mem_fp64_dram_gputil.pdf"),
    )
    for ids, title, file_name in panels:
        build_and_plot(joined, ids, title, file_name)

if __name__ == "__main__":
    # Order matters: the plot step reads the Parquet files written by the
    # three build steps (OUT_MEANS, OUT_METRICS, OUT_LABELS).
    build_job_means_all()
    build_job_metrics_mean_util()
    build_fp64only_label_fractions()
    plot_fp64_dram_heatmaps()
