In [7]:
# %% [markdown]
# ### Build publication-style summary table data (3-level headers + formatted cells, like process.py) with per-model sub-columns

# %%
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from pathlib import Path

# ------------------------------------------
# 0) Load inputs (same as earlier)
# ------------------------------------------
# Prefix spliced into output file names, e.g. results/summary_heatmap_all_models_sd.png.
output_name = "all_"

# Each model's summary_metrics.csv, read in the same order as the display names below.
deepseek, realsafe, gpt4o, claude35, gemini25 = [
    pd.read_csv(csv_path)
    for csv_path in (
        "data/experiments/deepseekR1/results/summary_metrics.csv",
        "data/experiments/realsafeR1/results/summary_metrics.csv",
        "data/experiments/gpt-4o/results/summary_metrics.csv",
        "data/experiments/claude-3.5-sonnet/results/summary_metrics.csv",
        "data/experiments/gemini-2.5-flash/results/summary_metrics.csv",
    )
]

# Display name -> summary frame. Insertion order fixes the per-model column
# order used by everything built from MODEL_FRAMES.keys() below.
MODEL_FRAMES = dict(zip(
    ("DeepSeek R1", "RealSafe R1", "GPT-4o", "Claude 3.5 Sonnet", "Gemini 2.5 Flash"),
    (deepseek, realsafe, gpt4o, claude35, gemini25),
))

# Same metrics as the process.py summary table.
# Refusal metrics in table order: (mean column, std column, header label).
# The ↓/↑ in each label is displayed verbatim in the middle header row.
_REFUSAL_METRICS = (
    ("or_bench_eval_mean", "or_bench_eval_std", "or-bench (↓)"),
    ("or_bench_hard_mean", "or_bench_hard_std", "or-bench-hard (↓)"),
    ("or_bench_toxic_mean", "or_bench_toxic_std", "or-bench-toxic (↑)"),
)
# Unzip into three parallel lists (lists, because callers concatenate them).
REFUSAL_MEANS, REFUSAL_STDS, REFUSAL_LABELS = (
    list(column) for column in zip(*_REFUSAL_METRICS)
)

# Capability metric: mean/std columns, then the top- and middle-row header labels.
CAP_MEAN, CAP_STD = "capability_mean", "capability_std"
CAP_TOP, CAP_MID = "Capability (↑)", "Accuracy"

# ------------------------------------------
# 1) Tidy long df: (spec_id, model, metrics)
# ------------------------------------------
# Every column the table/heatmap cells below look up by name.
usecols = ["spec_id"] + REFUSAL_MEANS + REFUSAL_STDS + [CAP_MEAN, CAP_STD]

rows = []
for model_name, df in MODEL_FRAMES.items():
    if df.empty:
        continue  # model with no results: its cells end up NaN / "—"
    # reindex (not "keep only the columns present") so a metric missing from a
    # model's CSV becomes NaN instead of vanishing; otherwise a metric absent
    # from every CSV makes later `g.loc[..., col]` lookups raise KeyError.
    df2 = df.reindex(columns=usecols)
    df2["model"] = model_name
    rows.append(df2)

if not rows:
    raise ValueError("No model results to summarize: every frame in MODEL_FRAMES is empty.")
long_df = pd.concat(rows, ignore_index=True)

# Rows without a spec_id cannot be labelled (and NaN would break sorting).
spec_ids  = sorted(long_df["spec_id"].dropna().unique())
models    = list(MODEL_FRAMES.keys())

def fmt(m, s):
    """Format a mean/std pair (given as fractions) as a percentage table cell.

    Args:
        m: mean as a fraction (0.25 -> "25.0%"); NaN/None means no result.
        s: standard deviation as a fraction; NaN/None means not reported.

    Returns:
        "—" when the mean is missing, "<mean>%" when only the std is missing,
        otherwise "<mean>% ± <std>%", each with one decimal place.
    """
    if pd.isna(m):
        return "—"
    if pd.isna(s):
        # An unknown spread must not print as "± 0.0%" (reads as zero variance);
        # this matches the heatmap cells, which drop the ± line when sd is NaN.
        return f"{100*m:.1f}%"
    return f"{100*m:.1f}% ± {100*s:.1f}%"

# ------------------------------------------
# 2) Build a 3-level header description
#    top_headers / mid_headers / low_headers are parallel lists with one entry
#    per table column: column 0 is the "System Prompt" label column, followed
#    by one column per (metric, model) pair, metric-major.
#    Top row:   "" then "Refusal" over every refusal sub-column, CAP_TOP over capability
#    Mid row:   "System Prompt" then each metric label, repeated once per model
#    Low row:   "" then the model names, cycling under each metric
# ------------------------------------------
metric_groups = [("Refusal", label) for label in REFUSAL_LABELS] + [(CAP_TOP, CAP_MID)]

top_headers = [""]
mid_headers = ["System Prompt"]
low_headers = [""]
for group_label, metric_label in metric_groups:
    for model_name in models:
        top_headers.append(group_label)
        mid_headers.append(metric_label)
        low_headers.append(model_name)

# ------------------------------------------
# 3) Assemble the table cell matrix (strings)
#    One row per spec_id: [spec_id, refusal cells (metric-major, model-minor),
#    capability cells per model] -- the same column order as the header lists.
# ------------------------------------------
def _first_or_nan(frame, model_name, column):
    """First value of `column` for rows of `model_name` in `frame`, else NaN."""
    hits = frame.loc[frame["model"] == model_name, column]
    return hits.iloc[0] if len(hits) else np.nan

# (mean column, std column) per metric block, in header order.
cell_columns = list(zip(REFUSAL_MEANS, REFUSAL_STDS)) + [(CAP_MEAN, CAP_STD)]

data_rows = []
for sid in spec_ids:
    spec_frame = long_df[long_df["spec_id"] == sid]
    data_rows.append(
        [sid]
        + [
            fmt(
                _first_or_nan(spec_frame, model_name, mean_col),
                _first_or_nan(spec_frame, model_name, std_col),
            )
            for mean_col, std_col in cell_columns
            for model_name in models
        ]
    )

Heatmap

In [None]:
# %% [markdown]
# ### Publication-style heatmap PNG (same layout as the table: 3 header rows, per-model subcolumns)

# %%
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from matplotlib.cm import ScalarMappable
from matplotlib import colors as mpl_colors
from pathlib import Path

def get_text_color(rgba_color):
    """Choose "white" or "black" annotation text for a cell with the given fill.

    Only the first three (RGB) channels are read; alpha is ignored. Channels are
    presumably 0-1 floats, as returned by ScalarMappable.to_rgba. A fill whose
    weighted brightness 0.299*R + 0.587*G + 0.114*B is below 0.4 counts as dark.
    """
    red, green, blue = rgba_color[:3]
    luma = 0.299 * red + 0.587 * green + 0.114 * blue
    if luma < 0.4:
        return "white"  # dark fill -> light text
    return "black"

# ------------------------------------------
# 1) Heatmap matrix in the table's column order
#    heat_cols[j] = (group, metric_label, model, mean_col) describes matrix
#    column j. The "System Prompt" label column has no matrix column, so
#    mat[:, j] lines up with header column j + 1.
# ------------------------------------------
metric_specs = [
    ("Refusal", metric_label, mean_col)
    for mean_col, metric_label in zip(REFUSAL_MEANS, REFUSAL_LABELS)
]
metric_specs.append((CAP_TOP, CAP_MID, CAP_MEAN))
heat_cols = [
    (group, metric_label, model, mean_col)
    for group, metric_label, mean_col in metric_specs
    for model in models
]

# Rows = spec_ids, columns = heat_cols; stays NaN where a model has no row.
mat = np.full((len(spec_ids), len(heat_cols)), np.nan, dtype=float)
for row_idx, sid in enumerate(spec_ids):
    spec_frame = long_df[long_df["spec_id"] == sid]
    for col_idx, (_, _, model, mean_col) in enumerate(heat_cols):
        hits = spec_frame.loc[spec_frame["model"] == model, mean_col]
        if len(hits):
            mat[row_idx, col_idx] = hits.iloc[0]

# Fixed 0..1 color scale (values are fractions, shown x100 as percent),
# shared by the heat cells and the colorbar.
vmin, vmax = 0.0, 1.0
cmap = plt.get_cmap("RdYlBu_r")
norm = mpl_colors.Normalize(vmin=vmin, vmax=vmax)
mappable = ScalarMappable(norm=norm, cmap=cmap)

# %% [markdown]
# ### Publication-style heatmap with mean ± sd in each cell

# %%
# Build matrix of stds in the same order as mat.
# Each mean column's std partner comes from the configured (mean, std) pairs
# rather than from string substitution on the column name.
std_col_for = dict(zip(REFUSAL_MEANS + [CAP_MEAN], REFUSAL_STDS + [CAP_STD]))

mat_std = np.full((len(spec_ids), len(heat_cols)), np.nan, dtype=float)
for i, sid in enumerate(spec_ids):
    g = long_df[long_df["spec_id"] == sid]
    for j, (_, _, model, mean_col) in enumerate(heat_cols):
        std_col = std_col_for.get(mean_col)
        if std_col is None or std_col not in g.columns:
            # No std recorded for this metric: leave NaN (cell shows the mean only)
            # instead of raising KeyError.
            continue
        val = g.loc[g["model"] == model, std_col]
        if len(val):
            mat_std[i, j] = val.iloc[0]

def render_heatmap_png_with_sd(
    top_headers, mid_headers, low_headers,
    spec_ids, mat, mat_std, out_path: str,
    col_width_first=3.0, col_width_model=1.9, row_height=1.0, font_size=10,
    cmap=cmap, norm=norm, mappable=None,
    draw_colorbar=True
):
    """Render the summary heatmap with "mean ± sd" annotations as a table-style PNG.

    Layout: three header rows above one row per spec_id. Column 0 is the label
    column; its header (mid_headers[0], e.g. "System Prompt") is one cell spanning
    all three header rows. Header column j + 1 corresponds to matrix column j.
    Top-row labels merge over runs of equal adjacent labels. Middle-row labels
    merge over runs sharing both the top and the middle label, so a metric label
    spans its model sub-columns but never crosses a group boundary.

    Args:
        top_headers, mid_headers, low_headers: parallel per-column header labels,
            index 0 being the label column.
        spec_ids: row labels, one per matrix row.
        mat, mat_std: array-likes of shape (len(spec_ids), len(top_headers) - 1)
            holding means / standard deviations as fractions (rendered x100 as
            percent). NaN mean -> empty cell; NaN sd -> mean only.
        out_path: PNG destination; missing parent directories are created.
        col_width_first, col_width_model, row_height: cell sizes in axis units.
        font_size: font size for all cell text.
        cmap, norm: colormap / normalization used when `mappable` is None.
        mappable: optional ScalarMappable mapping values to fill colors (and
            feeding the colorbar); takes precedence over cmap/norm.
        draw_colorbar: whether to add a 0-100 colorbar at the right edge.

    Raises:
        ValueError: if the header lists differ in length, or mat/mat_std do not
            have shape (len(spec_ids), len(top_headers) - 1).
    """
    if mappable is None:
        # Honor cmap/norm; with the defaults this matches the module-level mappable.
        mappable = ScalarMappable(norm=norm, cmap=cmap)

    mat = np.asarray(mat, dtype=float)
    mat_std = np.asarray(mat_std, dtype=float)

    n_header_rows = 3
    n_rows = len(spec_ids)
    n_cols_total = len(top_headers)
    n_cols_matrix = n_cols_total - 1
    if not (len(mid_headers) == len(low_headers) == n_cols_total):
        raise ValueError(
            f"header lengths differ: top={n_cols_total}, "
            f"mid={len(mid_headers)}, low={len(low_headers)}"
        )
    expected_shape = (n_rows, n_cols_matrix)
    if mat.shape != expected_shape or mat_std.shape != expected_shape:
        raise ValueError(
            f"mat/mat_std must have shape {expected_shape}, "
            f"got {mat.shape} and {mat_std.shape}"
        )

    col_widths = [col_width_first] + [col_width_model] * n_cols_matrix
    # x_edges[j] is the left edge of column j; x_edges[-1] is the total width.
    x_edges = [0.0]
    for width in col_widths:
        x_edges.append(x_edges[-1] + width)
    total_width = x_edges[-1]
    total_rows = n_header_rows + n_rows

    fig_w = 1.1 * total_width + (0.8 if draw_colorbar else 0.0)
    fig_h = 0.6 * total_rows * row_height + 0.8
    fig, ax = plt.subplots(figsize=(fig_w, fig_h))
    ax.axis("off")
    ax.set_xlim(0, total_width)
    # Rows are placed at multiples of row_height so non-default heights neither
    # overlap nor leave gaps.
    ax.set_ylim(0, total_rows * row_height)
    ax.invert_yaxis()

    def cell(x, y, w, h, text="", bold=False, facecolor=None, color="black"):
        """Draw a bordered rectangle (filled when facecolor is given) with centered text."""
        ax.add_patch(Rectangle((x, y), w, h, fill=facecolor is not None,
                               facecolor=facecolor, linewidth=1, edgecolor="black"))
        if text:
            ax.text(x + w / 2, y + h / 2, text, ha="center", va="center",
                    fontsize=font_size, fontweight=("bold" if bold else "normal"),
                    color=color)

    def runs(keys):
        """Yield (start, end) for each run of equal adjacent keys, skipping column 0."""
        start = 1
        while start < len(keys):
            end = start + 1
            while end < len(keys) and keys[end] == keys[start]:
                end += 1
            yield start, end
            start = end

    def header_cell(row, start, end, text):
        """Bold header cell in header row `row` covering columns [start, end)."""
        cell(x_edges[start], row * row_height, x_edges[end] - x_edges[start],
             row_height, text, bold=True)

    # --- Headers ---
    # Label column: a single cell spanning all three header rows.
    cell(0, 0, col_widths[0], n_header_rows * row_height, mid_headers[0], bold=True)
    for start, end in runs(top_headers):
        header_cell(0, start, end, top_headers[start])
    for start, end in runs(list(zip(top_headers, mid_headers))):
        header_cell(1, start, end, mid_headers[start])
    for j in range(1, n_cols_total):
        header_cell(2, j, j + 1, low_headers[j])

    # --- Data rows: colormap fill, "mean%\n±sd%" annotation ---
    for r, sid in enumerate(spec_ids):
        y = (n_header_rows + r) * row_height
        cell(0, y, col_widths[0], row_height, str(sid))
        for j in range(n_cols_matrix):
            x, w = x_edges[j + 1], col_widths[j + 1]
            m, s = mat[r, j], mat_std[r, j]
            if np.isnan(m):
                cell(x, y, w, row_height)  # missing result: empty bordered cell
                continue
            fc = mappable.to_rgba(m)
            txt = f"{100*m:.1f}%\n±{100*s:.1f}%" if not np.isnan(s) else f"{100*m:.1f}%"
            # Bold text, black or white depending on the fill's luminance.
            cell(x, y, w, row_height, txt, bold=True, facecolor=fc,
                 color=get_text_color(fc))

    if draw_colorbar:
        cax = fig.add_axes([0.92, 0.25, 0.02, 0.5])  # [left, bottom, width, height]
        cb = fig.colorbar(mappable, cax=cax)
        cb.set_label("Percent")
        ticks = [0, 0.25, 0.5, 0.75, 1.0]
        cb.set_ticks(ticks)
        cb.set_ticklabels([f"{t*100:.0f}" for t in ticks])

    Path(out_path).parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, dpi=300, bbox_inches="tight")
    plt.close(fig)

# Render (one path variable, so the printed location is the file actually written)
heatmap_sd_path = f"results/summary_heatmap_{output_name}models_sd.png"
render_heatmap_png_with_sd(
    top_headers, mid_headers, low_headers,
    spec_ids, mat, mat_std,
    out_path=heatmap_sd_path,
)
print(f"Saved to {heatmap_sd_path}")


Saved to results/summary_heatmap_all_models_sd.png


# Old: earlier renderers (plain text table, mean-only heatmap)

In [None]:
# ------------------------------------------
# 4) Draw like process.py (rectangles + text)
# ------------------------------------------
def render_table_png(
    top_headers, mid_headers, low_headers, data_rows, out_path: str,
    col_width_first=3.0, col_width_model=1.9, row_height=1.0, font_size=11
):
    """Render the summary table (text cells, no fill) as a PNG, like process.py.

    Three header rows sit above the data rows. Column 0 is the label column; its
    header (mid_headers[0], e.g. "System Prompt") is one cell spanning all three
    header rows. Top-row labels merge over runs of equal adjacent labels. Middle-row
    labels merge over runs sharing both the top and the middle label, so each
    metric label spans its model sub-columns without crossing a group boundary.

    Args:
        top_headers, mid_headers, low_headers: parallel per-column header labels
            (index 0 = label column).
        data_rows: rows of cell strings, each with exactly len(top_headers) cells.
        out_path: PNG destination; missing parent directories are created.
        col_width_first, col_width_model, row_height: cell sizes in axis units.
        font_size: font size for all cell text.

    Raises:
        AssertionError: if the header lists differ in length or a data row does
            not have one cell per header column.
    """
    n_header_rows = 3
    n_cols = len(top_headers)
    assert len(mid_headers) == n_cols and len(low_headers) == n_cols
    # A short or long row would otherwise land cells under the wrong header
    # (or index past the column widths).
    bad_rows = [k for k, row in enumerate(data_rows) if len(row) != n_cols]
    assert not bad_rows, f"data_rows {bad_rows} do not have {n_cols} cells"

    # Column widths: first column wide; every model cell narrower.
    col_widths = [col_width_first] + [col_width_model] * (n_cols - 1)
    # x_edges[j] is the left edge of column j; x_edges[-1] is the total width.
    x_edges = [0.0]
    for width in col_widths:
        x_edges.append(x_edges[-1] + width)
    total_width = x_edges[-1]
    total_rows = n_header_rows + len(data_rows)

    fig_w = 1.1 * total_width
    fig_h = 0.6 * total_rows * row_height + 0.8
    fig, ax = plt.subplots(figsize=(fig_w, fig_h))
    ax.axis("off")
    ax.set_xlim(0, total_width)
    # Rows are placed at multiples of row_height so non-default heights neither
    # overlap nor leave gaps.
    ax.set_ylim(0, total_rows * row_height)
    ax.invert_yaxis()

    def cell(x, y, w, h, text, bold=False):
        """Draw an unfilled bordered rectangle with centered text."""
        ax.add_patch(Rectangle((x, y), w, h, fill=False, linewidth=1, edgecolor="black"))
        if text:
            ax.text(
                x + w / 2, y + h / 2, text, ha="center", va="center",
                fontsize=font_size, fontweight=("bold" if bold else "normal"), wrap=True
            )

    def runs(keys):
        """Yield (start, end) for each run of equal adjacent keys, skipping column 0."""
        start = 1
        while start < len(keys):
            end = start + 1
            while end < len(keys) and keys[end] == keys[start]:
                end += 1
            yield start, end
            start = end

    def header_cell(row, start, end, text):
        """Bold header cell in header row `row` covering columns [start, end)."""
        cell(x_edges[start], row * row_height, x_edges[end] - x_edges[start],
             row_height, text, bold=True)

    # --- Header rows ---
    # Label column: a single cell spanning all three header rows.
    cell(0, 0, col_widths[0], n_header_rows * row_height, mid_headers[0], bold=True)
    # Row 0: merged group blocks (Refusal, Capability).
    for start, end in runs(top_headers):
        header_cell(0, start, end, top_headers[start])
    # Row 1: metric labels, merged within their group.
    for start, end in runs(list(zip(top_headers, mid_headers))):
        header_cell(1, start, end, mid_headers[start])
    # Row 2: one model name per column.
    for j in range(1, n_cols):
        header_cell(2, j, j + 1, low_headers[j])

    # --- Data rows ---
    for r, row_vals in enumerate(data_rows):
        y = (n_header_rows + r) * row_height
        for j, val in enumerate(row_vals):
            cell(x_edges[j], y, col_widths[j], row_height, val, bold=False)

    Path(out_path).parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, dpi=300, bbox_inches="tight")
    plt.close(fig)

# Render and save (one path variable, so the printed location is the file actually written)
table_path = "results/summary_table_cs_models.png"
render_table_png(
    top_headers, mid_headers, low_headers, data_rows,
    out_path=table_path,
)
print(f"Saved to {table_path}")


In [None]:

# ------------------------------------------
# 2) Render heatmap "like the table": rectangles + fills + headers + annotations
# ------------------------------------------
def render_heatmap_png(
    top_headers, mid_headers, low_headers,  # include the "System Prompt" position for convenience
    spec_ids, mat, out_path: str,
    col_width_first=3.0, col_width_model=1.9, row_height=1.0, font_size=11,
    cmap=cmap, norm=norm, mappable=None,
    draw_colorbar=True
):
    """Render the mean-only summary heatmap as a table-style PNG.

    Same layout as the table: three header rows above one row per spec_id. The
    label column's header (mid_headers[0]) spans all three header rows, top-row
    labels merge over equal runs, and middle-row labels merge within their group.
    Header column j + 1 corresponds to matrix column j.

    Args:
        top_headers, mid_headers, low_headers: parallel per-column header labels,
            index 0 being the label column.
        spec_ids: row labels, one per matrix row.
        mat: array-like of shape (len(spec_ids), len(top_headers) - 1) with values
            as fractions, annotated x100 (NaN -> empty cell).
        out_path: PNG destination; missing parent directories are created.
        col_width_first, col_width_model, row_height: cell sizes in axis units.
        font_size: font size for all cell text.
        cmap, norm: colormap / normalization used when `mappable` is None.
        mappable: optional ScalarMappable mapping values to fill colors (and
            feeding the colorbar); takes precedence over cmap/norm.
        draw_colorbar: whether to add a 0-100 colorbar at the right edge.

    Raises:
        AssertionError: if the header lists differ in length.
        ValueError: if mat does not have shape (len(spec_ids), len(top_headers) - 1).
    """
    if mappable is None:
        # Honor cmap/norm; with the defaults this matches the module-level mappable.
        mappable = ScalarMappable(norm=norm, cmap=cmap)

    mat = np.asarray(mat, dtype=float)
    n_header_rows = 3
    n_rows = len(spec_ids)
    # The headers include the first label column; the matrix does NOT.
    n_cols_total = len(top_headers)
    assert n_cols_total == len(mid_headers) == len(low_headers)
    n_cols_matrix = n_cols_total - 1
    if mat.shape != (n_rows, n_cols_matrix):
        raise ValueError(f"mat must have shape {(n_rows, n_cols_matrix)}, got {mat.shape}")

    col_widths = [col_width_first] + [col_width_model] * n_cols_matrix
    # x_edges[j] is the left edge of column j; x_edges[-1] is the total width.
    x_edges = [0.0]
    for width in col_widths:
        x_edges.append(x_edges[-1] + width)
    total_width = x_edges[-1]
    total_rows = n_header_rows + n_rows

    fig_w = 1.1 * total_width + (0.8 if draw_colorbar else 0.0)
    fig_h = 0.6 * total_rows * row_height + 0.8
    fig, ax = plt.subplots(figsize=(fig_w, fig_h))
    ax.axis("off")
    ax.set_xlim(0, total_width)
    # Rows are placed at multiples of row_height so non-default heights neither
    # overlap nor leave gaps.
    ax.set_ylim(0, total_rows * row_height)
    ax.invert_yaxis()

    def cell(x, y, w, h, text="", bold=False, facecolor=None, color="black"):
        """Draw a bordered rectangle (filled when facecolor is given) with centered text."""
        ax.add_patch(Rectangle((x, y), w, h, fill=facecolor is not None,
                               facecolor=facecolor, linewidth=1, edgecolor="black"))
        if text:
            ax.text(x + w / 2, y + h / 2, text, ha="center", va="center",
                    fontsize=font_size, fontweight=("bold" if bold else "normal"),
                    color=color)

    def runs(keys):
        """Yield (start, end) for each run of equal adjacent keys, skipping column 0."""
        start = 1
        while start < len(keys):
            end = start + 1
            while end < len(keys) and keys[end] == keys[start]:
                end += 1
            yield start, end
            start = end

    def header_cell(row, start, end, text):
        """Bold header cell in header row `row` covering columns [start, end)."""
        cell(x_edges[start], row * row_height, x_edges[end] - x_edges[start],
             row_height, text, bold=True)

    # --- Header rows ---
    cell(0, 0, col_widths[0], n_header_rows * row_height, mid_headers[0], bold=True)
    for start, end in runs(top_headers):
        header_cell(0, start, end, top_headers[start])
    for start, end in runs(list(zip(top_headers, mid_headers))):
        header_cell(1, start, end, mid_headers[start])
    for j in range(1, n_cols_total):
        header_cell(2, j, j + 1, low_headers[j])

    # --- Data rows: colored cells with percent annotations ---
    for r, sid in enumerate(spec_ids):
        y = (n_header_rows + r) * row_height
        cell(0, y, col_widths[0], row_height, str(sid))
        for j in range(n_cols_matrix):
            x, w = x_edges[j + 1], col_widths[j + 1]
            v = mat[r, j]
            if np.isnan(v):
                cell(x, y, w, row_height)  # missing result: empty bordered cell
                continue
            fc = mappable.to_rgba(v)
            # Text color from the fill's luminance: with a diverging map such as
            # RdYlBu_r both ends are dark and the middle is light, so a plain
            # value threshold puts black text on dark fills.
            cell(x, y, w, row_height, f"{100*v:.1f}", facecolor=fc,
                 color=get_text_color(fc))

    if draw_colorbar:
        cax = fig.add_axes([0.92, 0.25, 0.02, 0.5])  # [left, bottom, width, height] in figure coords
        cb = fig.colorbar(mappable, cax=cax)
        cb.set_label("Percent")
        ticks = [0.0, 0.25, 0.5, 0.75, 1.0]
        cb.set_ticks(ticks)
        cb.set_ticklabels([f"{t*100:.0f}" for t in ticks])

    Path(out_path).parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, dpi=300, bbox_inches="tight")
    plt.close(fig)

# Render and save (one path variable, so the printed location is the file actually written)
heatmap_path = "results/summary_heatmap_cs_models.png"
render_heatmap_png(
    top_headers=top_headers,
    mid_headers=mid_headers,
    low_headers=low_headers,
    spec_ids=spec_ids,
    mat=mat,
    out_path=heatmap_path,
)
print(f"Saved to {heatmap_path}")


Saved to results/summary_heatmap_cs_models.png
