In [9]:
%load_ext autoreload
%autoreload 2 
import numpy as np
import pandas as pd
import warnings

warnings.simplefilter(action='ignore', category=FutureWarning)

from utils import get_model_values_df, DATASET_INFO, sample_data
from baselines import labeled_data_alone

pd.options.mode.chained_assignment = None

# --- Configuration ---
seed = 42
nl = 20     # number of labeled examples to sample
nu = 1000   # number of unlabeled examples to sample
sim = True  # NOTE(review): not referenced in the cells shown here — confirm it is still needed
np.random.seed(seed)  # seeds NumPy's global RNG; later sampling calls also receive `seed`

dataset = 'CivilComments'
# Comma-separated subgroup column names (e.g. 'a,b'), or None to evaluate
# only the 'global' group.
subgroups = None
n_classes = DATASET_INFO[dataset]['n_classes']
model_names = DATASET_INFO[dataset]['model_names']

# Parse `subgroups` into column names. Whitespace around names ("a, b") and
# empty entries ("a,,b", trailing comma, "") are ignored so they cannot turn
# into bogus column lookups; None when no names remain.
if subgroups is not None:
    subgroup_list = [name.strip() for name in subgroups.split(',') if name.strip()] or None
else:
    subgroup_list = None

# --- Load DataFrame with model predictions ---
dataset_df = get_model_values_df(dataset, model_names)

# Random 50/50 split: train rows are sampled, test rows are everything else.
train_dataset_df = dataset_df.sample(frac=0.5, random_state=seed)
test_dataset_df = dataset_df[~dataset_df.index.isin(train_dataset_df.index)]

# The isin-based complement only partitions the data when index labels are
# unique; with duplicated labels, rows would silently vanish from the test set.
assert len(train_dataset_df) + len(test_dataset_df) == len(dataset_df), \
    "Train/test split lost rows; dataset_df index is probably not unique."

# --- Sample labeled and unlabeled data from train set ---
sampled_data, sampled_labels, sampled_true_labels, sampled_data_df = sample_data(
    train_dataset_df, nl, nu, model_names, seed, n_classes
)

# Unlabeled examples carry the label -1.
labeled_idxs = np.where(sampled_labels != -1)[0]
unlabeled_idxs = np.where(sampled_labels == -1)[0]

# Sanity checks, before anything is built on top of the sample.
assert len(labeled_idxs) == nl, "Number of labeled examples does not match nl."
assert len(unlabeled_idxs) == nu, "Number of unlabeled examples does not match nu."

# Assign group memberships: every example is in 'global', plus one membership
# array per requested subgroup column.
sampled_groups = [np.full(len(sampled_labels), 'global')]
test_groups = [np.full(len(test_dataset_df), 'global')]

if subgroup_list:
    for subgroup in subgroup_list:
        # Subgroup columns live on the DataFrame. `sampled_data` is indexed
        # positionally elsewhere (sampled_data[labeled_idxs]) and has no named
        # columns, so it cannot be looked up by subgroup name.
        # NOTE(review): assumes sampled_data_df rows align with sampled_data — confirm in sample_data.
        sampled_groups.append(sampled_data_df[subgroup].to_numpy())
        test_groups.append(test_dataset_df[subgroup].to_numpy())

print(f"Loaded {len(train_dataset_df)} train samples and {len(test_dataset_df)} test samples")
print(f"Sampled {nl} labeled and {nu} unlabeled examples")


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Loaded 66891 train samples and 66891 test samples
Sampled 20 labeled and 1000 unlabeled examples


In [None]:
# Split the sampled pool into two bundles with the layout
# (model predictions, list of per-group membership arrays, labels).
# The unlabeled bundle's labels are all -1, since unlabeled_idxs selects
# exactly the entries where sampled_labels == -1.
estimation_labeled_data, estimation_unlabeled_data = (
    (sampled_data[idxs], [groups[idxs] for groups in sampled_groups], sampled_labels[idxs])
    for idxs in (labeled_idxs, unlabeled_idxs)
)

# Full held-out split in the same layout, labelled from its 'label' column.
test_data = (
    test_dataset_df[model_names].values,
    test_groups,
    test_dataset_df['label'].values,
)

### Baseline: labeled data alone

Estimate each model's metrics from the `nl` labeled examples only.

In [11]:
# Baseline: estimate metrics from the labeled examples only
method_config = {'dataset': dataset, 'subgroups': subgroups}
metrics_df = labeled_data_alone(estimation_labeled_data, method_config)

# Map each model index back to its name
metrics_df['model'] = metrics_df['model_idx'].map(lambda idx: model_names[idx])

# Display key metrics, in the same column order as the SSME and
# ground-truth tables so the three can be compared side by side
print("\n=== Estimated Performance Metrics ===")
print(metrics_df[['model', 'ece', 'auc', 'auprc', 'acc']].to_string(index=False))


=== Estimated Performance Metrics ===
        model  acc      ece      auc    auprc
    alg_CORAL 0.75 0.148426 0.781250 0.429861
      alg_ERM 0.90 0.102048 1.000000 1.000000
      alg_IRM 0.80 0.211246 0.937500 0.679167
alg_ERM_seed1 0.90 0.102431 0.968750 0.916667
alg_ERM_seed2 0.90 0.106434 1.000000 1.000000
alg_IRM_seed1 0.85 0.158549 0.921875 0.645833
alg_IRM_seed2 0.80 0.219561 0.921875 0.645833


### SSME

Estimate the same metrics with `SSME_KDE`, which uses both the labeled and the unlabeled examples.

In [8]:
from model import SSME_KDE

# Estimate metrics with SSME_KDE, which uses both the labeled and the unlabeled examples
method_config = {'dataset': dataset, 'subgroups': subgroups}
ssme_metrics_df = SSME_KDE(estimation_labeled_data, estimation_unlabeled_data, method_config)

# Map each model index back to its name
ssme_metrics_df['model'] = ssme_metrics_df['model_idx'].map(lambda idx: model_names[idx])

# Display key metrics
print("\n=== Estimated Performance Metrics ===")
print(ssme_metrics_df[['model', 'ece', 'auc', 'auprc', 'acc']].to_string(index=False))

X_preds shape:  (1020, 7)


  0%|          | 0/20 [00:00<?, ?it/s]

100%|██████████| 20/20 [00:09<00:00,  2.03it/s]

Estimated priors:  [0.866, 0.134]

=== Estimated Performance Metrics ===
        model      ece      auc    auprc      acc
    alg_CORAL 0.075502 0.929264 0.587280 0.863725
      alg_ERM 0.040767 0.969579 0.794922 0.941176
      alg_IRM 0.069773 0.960916 0.745116 0.916667
alg_ERM_seed1 0.047401 0.970522 0.783655 0.940196
alg_ERM_seed2 0.024967 0.973233 0.817928 0.945098
alg_IRM_seed1 0.073047 0.942631 0.723688 0.915686
alg_IRM_seed2 0.080486 0.956676 0.728252 0.908824





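### Ground truth

Metrics computed on the full held-out test split (the half of the data not used for sampling), as a reference for the estimates above.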

In [None]:
# Reference metrics: run the labeled-only evaluator on the whole held-out
# test split. Relies on `method_config` defined in the cells above.
gt_metrics_df = labeled_data_alone(test_data, method_config)
gt_metrics_df['model'] = [model_names[idx] for idx in gt_metrics_df['model_idx']]

print("\n=== Ground Truth Performance Metrics ===")
print(gt_metrics_df.loc[:, ['model', 'ece', 'auc', 'auprc', 'acc']].to_string(index=False))


=== Ground Truth Performance Metrics ===
        model      ece      auc    auprc      acc
    alg_CORAL 0.059625 0.863927 0.399186 0.882914
      alg_ERM 0.060140 0.939604 0.725134 0.923263
      alg_IRM 0.105904 0.916018 0.660259 0.881957
alg_ERM_seed1 0.061192 0.938858 0.726129 0.923039
alg_ERM_seed2 0.049065 0.942639 0.737038 0.922516
alg_IRM_seed1 0.096066 0.911152 0.667215 0.892258
alg_IRM_seed2 0.101711 0.920637 0.656295 0.887399
