In [1]:
import pandas as pd
import numpy as np
from typing import Iterable, Union, Sequence, Optional

def summarize_best_superparam(
    task_array_path: str,
    res_path: str,
    setting_identifiers: Iterable[str],
    sup_hyperparam: Union[str, Sequence[str]],
    *,
    sup_name: str = "super_param",          # name of the coalesced column when a list is supplied
    dropna_in_groups: bool = False,
    enforce_mutual_exclusive: bool = False,
) -> pd.DataFrame:
    """
    Pick, for every setting, the super-parameter value with the highest mean utility.

    The task table and the result table are read as tab-separated files and
    left-joined on ``task_id``. ``subgroup_utility`` is averaged within each
    combination of `setting_identifiers` plus the super-parameter, and then, for
    each combination of the *other* identifiers, the super-parameter value with
    the largest mean is kept.

    Parameters
    ----------
    task_array_path, res_path
        Paths of the TSV files holding the task definitions and the results.
    setting_identifiers
        Columns that define a setting. Any super-parameter column listed here is
        removed from the grouping keys (it is represented by the unified column).
    sup_hyperparam
        Either one column name, or a list/tuple of column names that are
        coalesced left-to-right (first non-NaN value per row) into `sup_name`.
    sup_name
        Name of the coalesced column; only used when a list/tuple is supplied.
    dropna_in_groups
        Forwarded to ``groupby(dropna=...)``. With the default ``False``, rows
        whose keys (including the super-parameter) are NaN form their own groups.
    enforce_mutual_exclusive
        If True, raise when more than one of the listed columns is non-NaN on a row.

    Returns
    -------
    DataFrame with columns:
      <setting_identifiers (excluding any sup columns)>,
      <super-parameter column>,            # named `sup_name` if list, else the single column name
      max_subgroup_utility_mean
    Settings whose mean utility is NaN for every super-parameter value have no
    maximum and are omitted; if nothing remains the frame is empty.

    Raises
    ------
    KeyError
        If ``subgroup_utility`` or a super-parameter column is missing.
    ValueError
        If `enforce_mutual_exclusive` is set and the columns overlap on a row.
    """
    task_df = pd.read_csv(task_array_path, sep="\t", index_col=None)
    res_df = pd.read_csv(res_path, sep="\t", index_col=None)
    merged = res_df.merge(task_df, how="left", on="task_id")

    if "subgroup_utility" not in merged.columns:
        raise KeyError("Column 'subgroup_utility' not found in merged data.")

    # Build unified super-parameter column
    if isinstance(sup_hyperparam, (list, tuple)):
        sup_cols = list(dict.fromkeys(sup_hyperparam))  # de-duplicate, keep order
        missing = [c for c in sup_cols if c not in merged.columns]
        if missing:
            raise KeyError(f"Missing sup_hyperparam columns: {missing}")
        if enforce_mutual_exclusive:
            nn = merged[sup_cols].notna().sum(axis=1)
            if (nn > 1).any():
                raise ValueError("More than one of the provided sup_hyperparam columns is non-NaN on some rows.")
        # Coalesce left->right to a single working column named sup_name:
        # after a row-wise backfill the first column holds the first non-NaN value.
        merged[sup_name] = merged[sup_cols].bfill(axis=1).iloc[:, 0]
        work_sup_col = sup_name
        # Remove sup cols from identifiers to avoid duplicate grouping keys
        setting_identifiers = [k for k in setting_identifiers if k not in sup_cols]
        output_value_col = sup_name
    else:
        if sup_hyperparam not in merged.columns:
            raise KeyError(f"'{sup_hyperparam}' not found in merged columns.")
        work_sup_col = str(sup_hyperparam)
        setting_identifiers = [k for k in setting_identifiers if k != work_sup_col]
        output_value_col = work_sup_col

    setting_identifiers = list(dict.fromkeys(setting_identifiers))
    group_keys = setting_identifiers + [work_sup_col]

    # Mean aggregation
    mean_col = "subgroup_utility_mean"
    agg = (
        merged
        .groupby(group_keys, dropna=dropna_in_groups)["subgroup_utility"]
        .mean()
        .reset_index()
        .rename(columns={"subgroup_utility": mean_col})
    )

    # A group whose utilities are all NaN has a NaN mean. Such rows can never be
    # a maximum, and leaving them in makes idxmax return NaN (or raise on newer
    # pandas) for settings that consist only of them, which then breaks `.loc`.
    agg = agg[agg[mean_col].notna()]

    # Argmax over the super-parameter for each combo of the other identifiers
    other_keys = [k for k in group_keys if k != work_sup_col]
    if agg.empty:
        # Nothing to maximise: keep the (empty) frame so the columns survive.
        best = agg.copy()
    elif len(other_keys) == 0:
        best = agg.loc[[agg[mean_col].idxmax()]].copy()
    else:
        idx = agg.groupby(other_keys, dropna=dropna_in_groups)[mean_col].idxmax()
        best = agg.loc[idx]
    best = best.reset_index(drop=True)

    # Final names and ordering
    best = best.rename(columns={mean_col: "max_subgroup_utility_mean"})
    best = best[setting_identifiers + [work_sup_col, "max_subgroup_utility_mean"]]
    best = best.rename(columns={work_sup_col: output_value_col})

    return best


def compare_chiseling(
    df: pd.DataFrame,
    setting_col: str,                          # the column X that identifies settings (varies by table)
    *,
    strategy_col: str = "strategy",
    metric_col: str = "max_subgroup_utility_mean",
    p_col: str = "p",
    chiseling_label: str = "Chiseling",
    dropna_setting: bool = False,              # forwarded to groupby(dropna=...): False keeps NaN keys as their own group, True drops them
) -> pd.DataFrame:
    """
    Compare the Chiseling strategy with its best competitor at every setting level.

    For each level of `setting_col`, compute:
      - chiseling_metric / chiseling_p = Chiseling's `metric_col` / `p_col`
      - best_other_strategy = argmax non-Chiseling strategy on `metric_col`
      - best_other_metric / best_other_p = that strategy's `metric_col` / `p_col`
      - chiseling_to_best_ratio = chiseling_metric / best_other_metric
        (NaN when the denominator is missing or zero)

    Assumes `df` has columns: [strategy_col, setting_col, p_col, metric_col].
    If there are multiple rows per (setting, strategy), the row with the *largest* `metric_col`
    is selected for that (setting, strategy). Rows whose metric is NaN are never selected.
    The index of `df` is ignored, so any index type is accepted.

    `dropna_setting` is passed unchanged to every ``groupby(dropna=...)`` call:
    with the default ``False`` NaN keys form their own group, with ``True`` rows
    with NaN keys are dropped.

    Returns one row per setting level, sorted by `setting_col` (NaN first). Levels
    that lack either Chiseling or a competitor still appear, with NaN in the
    missing columns. If `df` contains no non-Chiseling rows at all, the competitor
    and ratio columns are entirely missing values.

    Raises KeyError if a required column is absent.
    """
    required = {strategy_col, setting_col, p_col, metric_col}
    missing = required - set(df.columns)
    if missing:
        raise KeyError(f"DataFrame is missing required columns: {sorted(missing)}")

    def _best_rows(frame: pd.DataFrame, keys: list) -> pd.DataFrame:
        # Return, for each group of `keys`, the row of `frame` with the largest metric.
        if frame.empty:
            return frame
        winners = frame.groupby(keys, dropna=dropna_setting)[metric_col].idxmax()
        return frame.loc[winners.astype(int)]

    # Keep only the columns we need, on a fresh positional index: idxmax returns
    # index *labels*, so the caller's index (strings, floats, MultiIndex, ...)
    # must not leak into the label lookups below.
    work = df[[strategy_col, setting_col, p_col, metric_col]].reset_index(drop=True)

    # NaN metrics can never win an argmax; removing them up front also avoids
    # asking idxmax about all-NaN groups.
    work = work[work[metric_col].notna()]

    # Reduce to one row per (setting, strategy): keep the row with the largest metric
    # (handles accidental duplicates gracefully)
    work = _best_rows(work, [setting_col, strategy_col]).reset_index(drop=True)

    # Split into chiseling vs others
    chis = work[work[strategy_col] == chiseling_label].copy()
    oth = work[work[strategy_col] != chiseling_label].copy()

    if len(oth) == 0:
        # No competitors anywhere: report the Chiseling rows with missing
        # competitor columns. Numeric columns use NaN so they stay float, as in
        # the main path below.
        result = chis[[setting_col]].copy()
        result["chiseling_metric"] = chis[metric_col]
        result["chiseling_p"] = chis[p_col]
        result["best_other_strategy"] = pd.NA
        result["best_other_metric"] = np.nan
        result["best_other_p"] = np.nan
        result["chiseling_to_best_ratio"] = np.nan
        return result.sort_values(setting_col, na_position="first").reset_index(drop=True)

    # Best non-Chiseling per setting
    best_other = _best_rows(oth, [setting_col]).rename(
        columns={
            strategy_col: "best_other_strategy",
            metric_col: "best_other_metric",
            p_col: "best_other_p",
        }
    )[[setting_col, "best_other_strategy", "best_other_metric", "best_other_p"]]

    # Take the (unique) Chiseling row per setting and prepare its columns
    chis = _best_rows(chis, [setting_col]).rename(
        columns={metric_col: "chiseling_metric", p_col: "chiseling_p"}
    )[[setting_col, "chiseling_metric", "chiseling_p"]]

    # Outer merge so a setting missing on either side is still reported
    out = chis.merge(best_other, on=setting_col, how="outer", validate="one_to_one")

    # Safe ratio: NaN if denominator is 0 or missing
    denom = out["best_other_metric"]
    num = out["chiseling_metric"]
    out["chiseling_to_best_ratio"] = (num / denom).where(denom.notna() & (denom != 0))

    # Final ordering
    out = out[[setting_col,
               "chiseling_metric", "chiseling_p",
               "best_other_strategy", "best_other_metric", "best_other_p",
               "chiseling_to_best_ratio"]]

    return out.sort_values(setting_col, na_position="first").reset_index(drop=True)


## Process: summarize each experiment and compare Chiseling with the best competing strategy

In [2]:
# One entry per experiment, laid out positionally as
#   [experiment name, setting-identifier columns, setting column used for the comparison].
# The experiment name is also substituted into the two path templates below.
res_data = [
    [
        "naive_chiseling_vs_naive_data_splitting",
        ["n", "strategy", "n_burn_in", "train_ratio"],
        "n",
    ],
    [
        "naive_chiseling_vs_naive_data_splitting_kang_schafer",
        ["n", "strategy", "n_burn_in", "train_ratio"],
        "n",
    ],
    [
        "binary_regression",
        ["n", "theta", "tau", "subgroup_size", "strategy", "margin_width",
         "n_burn_in", "train_ratio", "bonf_strategy"],
        "subgroup_size",
    ],
    [
        "heterogeneous_linear_rct",
        ["n", "theta", "tau", "subgroup_size", "strategy", "margin_width",
         "n_burn_in", "train_ratio", "bonf_strategy"],
        "subgroup_size",
    ],
    [
        "kang_schafer",
        ["n", "tau", "subgroup_size", "strategy", "margin_width",
         "n_burn_in", "train_ratio", "bonf_strategy"],
        "subgroup_size",
    ],
    [
        "bart_analysis",
        ["test_thresh", "strategy", "n_burn_in", "train_ratio",
         "bonf_strategy", "margin_width"],
        "test_thresh",
    ],
]

# Columns coalesced into a single super-parameter column, and that column's name.
sup_hyperparam = ["n_burn_in", "train_ratio"]
sup_name = "p"

# Path templates: the task array takes the experiment name once, the combined
# results file takes it twice (directory and file stem).
task_array_path_format = "../../task_arrays/{}.tasks.tsv"
res_path_format = "../../../output/{}/{}.combined.tsv"

In [3]:
# Only these strategies take part in the comparison
# (chiseling, data split, simul data split, t-test).
kept_strategies = [
    "Chiseling",
    "DataSplittingStrategy",
    "SimulDataSplittingStrategy",
    "TTestStrategy",
]

all_res_df = {}
for rd in res_data:
    experiment, identifiers, setting_col = rd

    # Best super-parameter value (and its mean utility) per setting
    res_df = summarize_best_superparam(
        task_array_path_format.format(experiment),
        res_path_format.format(experiment, experiment),
        setting_identifiers=identifiers,
        sup_hyperparam=sup_hyperparam,
        sup_name=sup_name,
    )

    # Reduce res_df to the strategies of interest
    if "strategy" in res_df.columns:
        wanted = res_df.strategy.isin(kept_strategies)
        res_df = res_df[wanted]

    # For Chiseling only margin_width = 1 is kept; other strategies pass through
    if "margin_width" in res_df.columns:
        not_chiseling = res_df.strategy != "Chiseling"
        unit_margin = res_df.margin_width == 1
        res_df = res_df[not_chiseling | unit_margin]

    # Process results.
    # NOTE(review): compare_chiseling groups by the setting column (and strategy)
    # only, so if several identifier combinations share a setting level the
    # Chiseling row and the competitor row may come from different
    # combinations - confirm this is intended.
    res_df = compare_chiseling(res_df, setting_col)
    all_res_df[experiment] = res_df

## Inspect results

In [4]:
# Chiseling vs. its best competitor, one row per level of `n`.
all_res_df["naive_chiseling_vs_naive_data_splitting"]

Unnamed: 0,n,chiseling_metric,chiseling_p,best_other_strategy,best_other_metric,best_other_p,chiseling_to_best_ratio
0,500,0.061304,0.35,DataSplittingStrategy,0.050072,0.55,1.224334
1,1500,0.162342,0.4,DataSplittingStrategy,0.149946,0.5,1.08267
2,4000,0.189333,0.7,DataSplittingStrategy,0.187578,0.7,1.009355


In [5]:
# Chiseling vs. its best competitor, one row per level of `n`.
all_res_df["naive_chiseling_vs_naive_data_splitting_kang_schafer"]

Unnamed: 0,n,chiseling_metric,chiseling_p,best_other_strategy,best_other_metric,best_other_p,chiseling_to_best_ratio
0,500,2.014604,0.1,DataSplittingStrategy,1.639226,0.25,1.228997
1,1500,5.422758,0.05,DataSplittingStrategy,3.933755,0.25,1.378519
2,4000,8.530198,0.05,DataSplittingStrategy,7.027009,0.4,1.213916


In [6]:
# Chiseling vs. its best competitor, one row per level of `subgroup_size`.
all_res_df["binary_regression"]

Unnamed: 0,subgroup_size,chiseling_metric,chiseling_p,best_other_strategy,best_other_metric,best_other_p,chiseling_to_best_ratio
0,0.01,0.000193,0.1,SimulDataSplittingStrategy,9.5e-05,0.4,2.031648
1,0.05,0.000849,0.1,SimulDataSplittingStrategy,0.000271,0.7,3.129337
2,0.1,0.003253,0.1,SimulDataSplittingStrategy,0.001301,0.6,2.500047
3,0.25,0.011189,0.5,SimulDataSplittingStrategy,0.007383,0.7,1.515555
4,0.5,0.023211,0.6,SimulDataSplittingStrategy,0.016522,0.7,1.404835


In [7]:
# Chiseling vs. its best competitor, one row per level of `subgroup_size`.
all_res_df["heterogeneous_linear_rct"]

Unnamed: 0,subgroup_size,chiseling_metric,chiseling_p,best_other_strategy,best_other_metric,best_other_p,chiseling_to_best_ratio
0,0.01,0.002063,0.1,SimulDataSplittingStrategy,0.001584,0.4,1.301929
1,0.05,0.006774,0.1,SimulDataSplittingStrategy,0.004673,0.5,1.449608
2,0.1,0.011577,0.2,SimulDataSplittingStrategy,0.007277,0.6,1.590994
3,0.25,0.049362,0.3,SimulDataSplittingStrategy,0.035459,0.6,1.392102
4,0.5,0.141658,0.6,SimulDataSplittingStrategy,0.11794,0.6,1.201105


In [8]:
# Chiseling vs. its best competitor, one row per level of `subgroup_size`.
all_res_df["kang_schafer"]

Unnamed: 0,subgroup_size,chiseling_metric,chiseling_p,best_other_strategy,best_other_metric,best_other_p,chiseling_to_best_ratio
0,0.25,2.438922,0.1,SimulDataSplittingStrategy,1.746012,0.3,1.396853
1,0.5,8.798228,0.1,SimulDataSplittingStrategy,6.51138,0.4,1.351208
2,0.75,18.905165,0.1,TTestStrategy,16.783246,,1.126431


In [9]:
# Chiseling vs. its best competitor, one row per level of `test_thresh`.
all_res_df["bart_analysis"]

Unnamed: 0,test_thresh,chiseling_metric,chiseling_p,best_other_strategy,best_other_metric,best_other_p,chiseling_to_best_ratio
0,0.3,0.013916,0.1,DataSplittingStrategy,0.010542,0.3,1.32003
1,0.35,0.008426,0.1,SimulDataSplittingStrategy,0.004474,0.4,1.883165
2,0.4,0.004683,0.1,SimulDataSplittingStrategy,0.00194,0.3,2.414034
3,0.45,0.000311,0.1,SimulDataSplittingStrategy,9.1e-05,0.1,3.418043
