# In [None]:
import os
import duckdb
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import pyarrow as pa
import pyarrow.dataset as ds
import pyarrow.parquet as pq

# LDMS DCGM sample files. Both are scanned for nersc_ldms_dcgm_* counter
# columns; SET2 also supplies the energy counter used for per-job energy.
SET1 = "ldms_set1.parquet"
SET2 = "ldms_set2.parquet"

# Slurm accounting export. Its Start/End columns are interpreted as naive local
# times in SACCT_TZ. Use the canonical IANA zone name: "US/Pacific" is only a
# backward-compatibility alias and is absent from some trimmed tz databases.
SACCT_CSV = "slurm.csv"
SACCT_TZ = "America/Los_Angeles"

# Intermediate per-job tables written by the build_* steps and read back for
# the correlation analysis.
MEANS_PATH = "job_means_all.parquet"
METRICS_PATH = "job_metrics.parquet"
ENERGY_PATH = "energy_per_job.parquet"

# Output directory, created up front (exist_ok makes re-running harmless).
OUT_DIR = "corr"
os.makedirs(OUT_DIR, exist_ok=True)

def list_counters(fname: str) -> list[str]:
    """Return the ``nersc_ldms_dcgm_*`` column names of a Parquet file, in schema order.

    Args:
        fname: Path to the Parquet file to inspect.

    Returns:
        Column names that start with ``nersc_ldms_dcgm_``.
    """
    # Double single quotes so a path containing "'" is still a valid SQL
    # string literal instead of terminating it early.
    path_sql = fname.replace("'", "''")
    con = duckdb.connect()
    try:
        cols = con.execute(f"DESCRIBE SELECT * FROM parquet_scan('{path_sql}')").fetchall()
    finally:
        con.close()
    # DESCRIBE yields one row per column; field 0 is the column name.
    return [c[0] for c in cols if c[0].startswith("nersc_ldms_dcgm_")]

def avg_expr(cols: list[str], overlap: set[str], suffix_overlaps: bool = False) -> str:
    """Build a comma-separated SQL select list of ``avg(col) AS alias`` items.

    Args:
        cols: Column names to average, emitted in this order.
        overlap: Names that also exist in the other input set.
        suffix_overlaps: When true, columns found in ``overlap`` are aliased
            with an ``_s2`` suffix so they do not collide after a join.

    Returns:
        The select-list fragment; an empty string when ``cols`` is empty.
    """
    def _alias(name: str) -> str:
        if suffix_overlaps and name in overlap:
            return f"{name}_s2"
        return name

    return ",\n       ".join(f"avg({name}) AS {_alias(name)}" for name in cols)

def build_job_means_all():
    """Write one row per job with the mean of every DCGM counter to MEANS_PATH.

    Means are computed in two stages: first per (JobID, hostname, gpu_id), then
    averaged across those GPUs, so each GPU contributes equally to the job mean
    regardless of how many samples it has. SET2 counters whose names also occur
    in SET1 are aliased with an ``_s2`` suffix. SET2 is LEFT-joined onto SET1,
    so jobs that appear only in SET2 are not written.

    Raises:
        ValueError: if either input has no ``nersc_ldms_dcgm_`` columns (the
            generated select lists would otherwise be empty, invalid SQL).
    """
    counters1 = list_counters(SET1)
    counters2 = list_counters(SET2)
    if not counters1 or not counters2:
        raise ValueError(
            f"no nersc_ldms_dcgm_ counters found in {SET1 if not counters1 else SET2}"
        )
    overlap = set(counters1).intersection(counters2)

    # Double single quotes so the paths stay valid inside SQL string literals.
    set1_sql = SET1.replace("'", "''")
    set2_sql = SET2.replace("'", "''")
    means_sql = MEANS_PATH.replace("'", "''")

    con = duckdb.connect()
    try:
        con.execute("PRAGMA memory_limit='15GB';")
        con.execute(f"PRAGMA threads={os.cpu_count() or 1};")

        con.execute(f"""
        COPY (
          WITH
          s1_gpu AS (
            SELECT
              JobID::VARCHAR AS JobID,
              hostname,
              gpu_id,
              {avg_expr(counters1, overlap, suffix_overlaps=False)}
            FROM parquet_scan('{set1_sql}')
            GROUP BY JobID, hostname, gpu_id
          ),
          s1_job AS (
            SELECT
              JobID,
              {avg_expr(counters1, overlap, suffix_overlaps=False)}
            FROM s1_gpu
            GROUP BY JobID
          ),
          s2_gpu AS (
            SELECT
              JobID::VARCHAR AS JobID,
              hostname,
              gpu_id,
              {avg_expr(counters2, overlap, suffix_overlaps=False)}
            FROM parquet_scan('{set2_sql}')
            GROUP BY JobID, hostname, gpu_id
          ),
          s2_job AS (
            SELECT
              JobID,
              {avg_expr(counters2, overlap, suffix_overlaps=True)}
            FROM s2_gpu
            GROUP BY JobID
          )
          SELECT *
          FROM s1_job AS s1
          LEFT JOIN s2_job AS s2 USING (JobID)
        )
        TO '{means_sql}' (FORMAT PARQUET, COMPRESSION 'SNAPPY');
        """)
    finally:
        # Release the connection even when the query fails.
        con.close()

def build_job_metrics():
    """Write per-job GPU count and wall-clock duration to METRICS_PATH.

    Output columns:
      * ``ngpus``: number of distinct (hostname, gpu_id) pairs per JobID in SET1.
      * ``duration_hours``: latest End minus earliest Start over all SACCT_CSV
        rows sharing a JobID, clipped at 0.
      * ``npus``: per-JobID max, included only when SACCT_CSV has that column.

    Start/End are parsed as naive local times in SACCT_TZ and converted to UTC.
    Local times that are ambiguous (DST fall-back) or nonexistent (DST
    spring-forward) become NaT. Those rows are dropped together with
    unparseable timestamps instead of raising. Only jobs present in both
    sources are written.
    """
    # Double single quotes so the path stays a valid SQL string literal.
    set1_sql = SET1.replace("'", "''")
    con = duckdb.connect()
    try:
        con.execute("PRAGMA memory_limit='15GB';")
        con.execute(f"PRAGMA threads={os.cpu_count() or 1};")

        ng = con.execute(f"""
          SELECT
            JobID::VARCHAR AS JobID,
            COUNT(*)::BIGINT AS ngpus
          FROM (
            SELECT DISTINCT JobID, hostname, gpu_id
            FROM parquet_scan('{set1_sql}')
          )
          GROUP BY JobID
        """).df()
    finally:
        con.close()

    # Read the header once, then request only the wanted columns that exist.
    header = set(pd.read_csv(SACCT_CSV, nrows=0).columns)
    usecols = [c for c in ["JobID", "Start", "End", "npus"] if c in header]
    # Read JobID as text. Otherwise a column with blanks is parsed as float and
    # "123" becomes "123.0", which never matches the DuckDB VARCHAR JobID.
    dtype = {"JobID": str} if "JobID" in usecols else None
    sacct = pd.read_csv(SACCT_CSV, usecols=usecols, dtype=dtype, low_memory=False)
    sacct["JobID"] = sacct["JobID"].astype(str)

    def _local_to_utc(col: str) -> pd.Series:
        """Parse naive SACCT_TZ timestamps in ``col`` and convert them to UTC."""
        parsed = pd.to_datetime(sacct[col], errors="coerce")
        localized = parsed.dt.tz_localize(SACCT_TZ, ambiguous="NaT", nonexistent="NaT")
        return localized.dt.tz_convert("UTC")

    sacct["start_time"] = _local_to_utc("Start")
    sacct["end_time"] = _local_to_utc("End")
    sacct = sacct.dropna(subset=["start_time", "end_time"])

    # A JobID can span several accounting rows; take the widest time window.
    agg_spec = {
        "start_time": ("start_time", "min"),
        "end_time": ("end_time", "max"),
    }
    if "npus" in sacct.columns:
        agg_spec["npus"] = ("npus", "max")
    sacct = sacct.groupby("JobID", as_index=False).agg(**agg_spec)

    elapsed = sacct["end_time"] - sacct["start_time"]
    sacct["duration_hours"] = (elapsed.dt.total_seconds() / 3600.0).clip(lower=0)

    out = sacct.merge(ng, on="JobID", how="inner")
    cols_out = ["JobID"]
    if "npus" in out.columns:
        cols_out.append("npus")
    cols_out += ["ngpus", "duration_hours"]

    out[cols_out].to_parquet(METRICS_PATH, index=False)

def build_energy_per_job(set2_path: str, out_path: str, batch_size: int = 1_000_000):
    """Compute per-job GPU energy from energy-counter samples and write it to Parquet.

    The counter is treated as cumulative. For every (JobID, hostname, gpu_id),
    the energy delta is the counter value at the latest ``ts_ns`` minus the
    value at the earliest ``ts_ns``. Per-GPU deltas are then summed per JobID.
    The input is read in record batches rather than loaded whole.

    Args:
        set2_path: Parquet file/dataset containing the LDMS samples.
        out_path: Destination Parquet file with columns ``JobID`` and
            ``total_energy_joules``.
        batch_size: Maximum rows per record batch read from ``set2_path``.

    Rows missing any key, the timestamp, or the energy value are skipped.
    When several samples share the extreme timestamp, the one read first
    (for the earliest) or last (for the latest) is used.
    """
    COL_JOB = "JobID"
    COL_HOST = "hostname"
    COL_GPU = "gpu_id"
    COL_TS = "ts_ns"
    # NOTE(review): confirm the unit of this counter. NVIDIA DCGM documents its
    # total-energy field in millijoules; if LDMS forwards it unchanged, the
    # "total_energy_joules" column written below is 1000x the joule value.
    COL_ENE = "nersc_ldms_dcgm_total_energy_consumption"
    KEYS = [COL_JOB, COL_HOST, COL_GPU]

    dataset = ds.dataset(set2_path, format="parquet")
    scanner = dataset.scanner(
        columns=[COL_JOB, COL_HOST, COL_GPU, COL_TS, COL_ENE],
        batch_size=batch_size,
        use_threads=True,
    )

    # For each batch keep only the earliest and latest sample of every GPU key.
    # Selecting by ts_ns (not by row position) makes the result independent of
    # whether rows and batches arrive in time order.
    first_parts = []
    last_parts = []
    for batch in scanner.to_batches():
        pdf = pa.Table.from_batches([batch]).to_pandas(types_mapper=pd.ArrowDtype)

        pdf[COL_JOB] = pdf[COL_JOB].astype("string")
        pdf[COL_HOST] = pdf[COL_HOST].astype("string")
        pdf[COL_GPU] = pd.to_numeric(pdf[COL_GPU], errors="coerce").astype("Int64")
        pdf[COL_TS] = pd.to_numeric(pdf[COL_TS], errors="coerce").astype("Int64")
        pdf[COL_ENE] = pd.to_numeric(pdf[COL_ENE], errors="coerce").astype("Float64")

        pdf = pdf.dropna(subset=[COL_TS, COL_ENE, COL_GPU, COL_JOB, COL_HOST])
        if pdf.empty:
            continue
        # No missing values remain, so plain NumPy dtypes are safe and give a
        # truly stable sort below.
        pdf = pdf.astype({COL_GPU: "int64", COL_TS: "int64", COL_ENE: "float64"})

        # Stable sort: samples with equal timestamps keep their read order.
        pdf = pdf.sort_values(COL_TS, kind="mergesort")
        first_parts.append(pdf.drop_duplicates(subset=KEYS, keep="first"))
        last_parts.append(pdf.drop_duplicates(subset=KEYS, keep="last"))

    if first_parts:
        # Reduce per-batch candidates to one earliest / latest sample per key.
        # concat preserves batch order, so the stable sort again breaks
        # timestamp ties by read order.
        firsts = (
            pd.concat(first_parts, ignore_index=True)
            .sort_values(COL_TS, kind="mergesort")
            .drop_duplicates(subset=KEYS, keep="first")[KEYS + [COL_ENE]]
            .rename(columns={COL_ENE: "first_e"})
        )
        lasts = (
            pd.concat(last_parts, ignore_index=True)
            .sort_values(COL_TS, kind="mergesort")
            .drop_duplicates(subset=KEYS, keep="last")[KEYS + [COL_ENE]]
            .rename(columns={COL_ENE: "last_e"})
        )
        # Both frames hold exactly the same key set, so an inner join is complete.
        df_gpu = firsts.merge(lasts, on=KEYS, how="inner")
        df_gpu["energy_delta"] = df_gpu["last_e"] - df_gpu["first_e"]
        df_gpu = df_gpu[[COL_JOB, "energy_delta"]]
    else:
        df_gpu = pd.DataFrame(
            {COL_JOB: pd.Series(dtype="string"), "energy_delta": pd.Series(dtype="float64")}
        )

    df_job = df_gpu.groupby(COL_JOB, as_index=False).agg(
        total_energy_joules=("energy_delta", "sum")
    )
    df_job.to_parquet(out_path, index=False)

def correlation_heatmap():
    """Draw a Spearman correlation heatmap of per-job counter means and job metrics.

    The function:
      1. Joins MEANS_PATH (per-job DCGM counter means), METRICS_PATH (ngpus,
         duration_hours, optional npus) and ENERGY_PATH (total_energy_joules)
         on JobID.
      2. Derives GPU_POWER = total_energy_joules / (ngpus * duration_hours * 3600).
      3. Drops identifier/helper columns and zero-variance columns.
      4. Plots the correlation matrix, annotating only cells where |r| >= 0.5.

    Returns:
        pandas.DataFrame: the Spearman correlation matrix. Rows/columns are
        ordered by descending |r| against GPU_UTIL when that column is present.

    Raises:
        ValueError: if no non-constant numeric column remains to correlate.
    """
    PREFIX = "nersc_ldms_dcgm_"
    THR = 0.5

    # Only the column names are needed; read_schema keeps no file object around.
    means_cols_all = pq.read_schema(MEANS_PATH).names
    metrics_cols_all = pq.read_schema(METRICS_PATH).names

    # Per-job counter means, renamed from nersc_ldms_dcgm_<x> to <x>.
    counter_cols = [c for c in means_cols_all if c.startswith(PREFIX)]
    means = pd.read_parquet(MEANS_PATH, columns=["JobID"] + counter_cols)
    means["JobID"] = means["JobID"].astype(str)
    means = means.rename(columns={c: c[len(PREFIX):] for c in counter_cols})

    drop_means = {
        "fb_free", "gr_engine_active", "sm_occupancy", "tensor_hmma_active",
        "memory_clock", "memory_temp", "power_usage", "mem_copy_utilization"
    }
    means = means.drop(columns=[c for c in drop_means if c in means.columns])

    pretty_map = {
        "dram_active":        "DRAM_ACTV",
        "fb_used":            "HBM_USED",
        "fp16_active":        "FP16_ACTV",
        "fp32_active":        "FP32_ACTV",
        "fp64_active":        "FP64_ACTV",
        "sm_active":          "SM_ACTV",
        "gpu_utilization":    "GPU_UTIL",
        "nvlink_rx_bytes":    "NVLINK_RX",
        "nvlink_tx_bytes":    "NVLINK_TX",
        "pcie_rx_bytes":      "PCIE_RX",
        "pcie_tx_bytes":      "PCIE_TX",
        "gpu_temp":           "GPU_TEMP",
        "tensor_active":      "TNSR_ACTV",
    }
    means = means.rename(columns=pretty_map)

    # Job-level metrics and energy.
    extras_candidates = ["JobID", "npus", "ngpus", "duration_hours"]
    extras_present = [c for c in extras_candidates if c in metrics_cols_all]
    metrics = pd.read_parquet(METRICS_PATH, columns=extras_present)
    if "JobID" in metrics.columns:
        metrics["JobID"] = metrics["JobID"].astype(str)

    energy = pd.read_parquet(ENERGY_PATH, columns=["JobID", "total_energy_joules"])
    energy["JobID"] = energy["JobID"].astype(str)

    df = means.merge(metrics, on="JobID", how="left").merge(energy, on="JobID", how="left")

    for c in ["total_energy_joules", "ngpus", "duration_hours"]:
        if c in df.columns:
            df[c] = pd.to_numeric(df[c], errors="coerce")

    # Energy per GPU-second. This is watts only if total_energy_joules really
    # is in joules. A zero-length job gives den == 0 and an infinite power.
    # A single inf makes the column's std NaN, which would silently remove
    # GPU_POWER from the matrix below, so treat it as missing instead.
    den = df.get("ngpus", np.nan) * df.get("duration_hours", np.nan) * 3600.0
    power = df["total_energy_joules"] / den
    df["GPU_POWER"] = power.replace([np.inf, -np.inf], np.nan)

    drop_from_num = ["JobID", "total_energy_consumption", "total_energy_joules", "ngpus", "duration_hours"]
    num = df.drop(columns=[c for c in drop_from_num if c in df.columns])

    for c in num.columns:
        num[c] = pd.to_numeric(num[c], errors="coerce")

    # Constant or all-NaN columns have no defined rank correlation.
    std = num.std(numeric_only=True, ddof=0)
    num = num[std[std > 0].index.tolist()]
    if num.shape[1] == 0:
        raise ValueError("no non-constant numeric columns left to correlate")

    corr = num.corr(method="spearman")

    # Order rows/columns by strength of association with GPU_UTIL, if present.
    if "GPU_UTIL" in corr.columns:
        order = corr["GPU_UTIL"].abs().sort_values(ascending=False).index.tolist()
        corr = corr.loc[order, order]

    # Annotate only the strong correlations; weaker cells get an empty label.
    sel = corr.abs() >= THR
    ann = corr.round(2).astype(object).where(sel, "")

    n_cols = corr.shape[1]
    fig_w = max(14, min(0.8 * n_cols, 60))
    fig_h = max(12, min(0.8 * n_cols, 60))

    plt.figure(figsize=(fig_w, fig_h))
    ax = sns.heatmap(
        corr,
        annot=ann, fmt="",
        cmap="coolwarm",
        vmin=-1, vmax=1,
        linewidths=2,
        linecolor="black",
        cbar=True,
        annot_kws={"size": 26, "color": "white"},
    )

    plt.xticks(rotation=90, fontsize=28)
    plt.yticks(rotation=0, fontsize=28)
    ax.collections[0].colorbar.ax.tick_params(labelsize=28)

    plt.title("Correlation Heatmap of Mean Counter Values", fontsize=30, pad=12)
    plt.tight_layout()
    plt.show()
    return corr

# Run the pipeline in dependency order: the three build steps write the
# Parquet tables that correlation_heatmap() reads back.
build_job_means_all()
build_job_metrics()
build_energy_per_job(set2_path=SET2, out_path=ENERGY_PATH)
corr = correlation_heatmap()
corr  # trailing bare expression so the notebook displays the matrix
