In [1]:
import pandas as pd
from sklearn.utils import resample
from sklearn.model_selection import train_test_split

# --- Stratified 3-way split function ---
def stratified_split(df, test_frac=0.1, val_frac=0.1, seed=42):
    y = df[['energy_loss','alpha','q0']].astype(str).agg('_'.join, axis=1)
    df_train, df_temp = train_test_split(df, test_size=test_frac+val_frac, stratify=y, random_state=seed)
    y_temp = df_temp[['energy_loss','alpha','q0']].astype(str).agg('_'.join, axis=1)
    df_val, df_test = train_test_split(df_temp,
                                       test_size=val_frac/(test_frac+val_frac),
                                       stratify=y_temp,
                                       random_state=seed)
    return df_train, df_val, df_test



In [2]:
# --- Config ---
INPUT_CSV = "/home/arsalan/wsu-grid/hm_jetscapeml_source/data/" \
"jet_ml_benchmark_config_01_to_09_alpha_0.2_0.3_0.4_q0_1.5_2.0_2.5_MMAT_MLBT_size_7200000_balanced_unshuffled/" \
"file_labels_aggregated_g500.csv"
TRAIN_CSV = "/home/arsalan/wsu-grid/hm_jetscapeml_source/data/" \
"jet_ml_benchmark_config_01_to_09_alpha_0.2_0.3_0.4_q0_1.5_2.0_2.5_MMAT_MLBT_size_7200000_balanced_unshuffled/" \
"file_labels_aggregated_g500_train.csv"
VAL_CSV = "/home/arsalan/wsu-grid/hm_jetscapeml_source/data/" \
"jet_ml_benchmark_config_01_to_09_alpha_0.2_0.3_0.4_q0_1.5_2.0_2.5_MMAT_MLBT_size_7200000_balanced_unshuffled/" \
"file_labels_aggregated_g500_val.csv"
TEST_CSV = "/home/arsalan/wsu-grid/hm_jetscapeml_source/data/" \
"jet_ml_benchmark_config_01_to_09_alpha_0.2_0.3_0.4_q0_1.5_2.0_2.5_MMAT_MLBT_size_7200000_balanced_unshuffled/" \
"file_labels_aggregated_g500_test.csv"
TARGET_TOTAL = 1008
SEED = 42

# --- Load full data ---
df = pd.read_csv(INPUT_CSV)

# --- Create balanced 1000-sample dataset (equal per label combo) ---
label_cols = ['energy_loss', 'alpha', 'q0']
df['label_combo'] = df[label_cols].astype(str).agg('_'.join, axis=1)
n_classes = df['label_combo'].nunique()
samples_per_class = TARGET_TOTAL // n_classes




In [3]:
print(samples_per_class)
print(n_classes)
TARGET_TOTAL % n_classes 

84
12


0

In [4]:
if TARGET_TOTAL % n_classes != 0:
    raise ValueError(f"{TARGET_TOTAL} is not divisible by {n_classes} unique label combinations.")

In [5]:
df_balanced = (
    df.groupby('label_combo', group_keys=False)
    .apply(lambda g: resample(g, replace=True, n_samples=samples_per_class, random_state=SEED))
    .reset_index(drop=True)
)

  df.groupby('label_combo', group_keys=False)


In [6]:
df_balanced = df_balanced.drop(columns=['label_combo'])

In [7]:
# --- Stratified split into train/val/test ---
df_train, df_val, df_test = stratified_split(df_balanced, test_frac=0.1, val_frac=0.1, seed=SEED)

# --- Save to disk ---
df_train.to_csv(TRAIN_CSV, index=False)
df_val.to_csv(VAL_CSV, index=False)
df_test.to_csv(VAL_CSV, index=False)

print(f"✅ Train/val split completed and saved:")
print(f"  → Train: {len(df_train)} rows → {TRAIN_CSV}")
print(f"  → Val:   {len(df_val)} rows   → {VAL_CSV}")
print(f"  → Test:  {len(df_test)} rows   → {TEST_CSV}")

✅ Train/val split completed and saved:
  → Train: 806 rows → /home/arsalan/wsu-grid/hm_jetscapeml_source/data/jet_ml_benchmark_config_01_to_09_alpha_0.2_0.3_0.4_q0_1.5_2.0_2.5_MMAT_MLBT_size_7200000_balanced_unshuffled/file_labels_aggregated_g500_train.csv
  → Val:   101 rows   → /home/arsalan/wsu-grid/hm_jetscapeml_source/data/jet_ml_benchmark_config_01_to_09_alpha_0.2_0.3_0.4_q0_1.5_2.0_2.5_MMAT_MLBT_size_7200000_balanced_unshuffled/file_labels_aggregated_g500_val.csv
  → Test:  101 rows   → /home/arsalan/wsu-grid/hm_jetscapeml_source/data/jet_ml_benchmark_config_01_to_09_alpha_0.2_0.3_0.4_q0_1.5_2.0_2.5_MMAT_MLBT_size_7200000_balanced_unshuffled/file_labels_aggregated_g500_test.csv
