In [4]:
# Auto-reload extension for development (optional)
%load_ext autoreload
%autoreload 2

import warnings

import numpy as np
import pandas as pd

from baselines import labeled_data_alone
from model import SSME_KDE
from utils import get_model_values_df, DATASET_INFO, sample_data

# Suppress unnecessary warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
pd.options.mode.chained_assignment = None

# ---------------- Configuration ----------------
seed = 42
n_labeled = 20
n_unlabeled = 1000
dataset = 'CivilComments'
subgroups = None  # Optionally specify comma-separated subgroup column names as a string, e.g. "col_a,col_b"

np.random.seed(seed)

n_classes = DATASET_INFO[dataset]['n_classes']
model_names = DATASET_INFO[dataset]['model_names']
# Strip whitespace so "a, b" and "a,b" refer to the same columns
subgroup_list = [name.strip() for name in subgroups.split(',')] if subgroups is not None else None

# ---------------- Data Loading ----------------
# Load DataFrame with model predictions (one column per entry of `model_names`) and a 'label' column
dataset_df = get_model_values_df(dataset, model_names)

# Split data 50/50 into train and test sets (reproducible via `seed`)
train_dataset_df = dataset_df.sample(frac=0.5, random_state=seed)
test_dataset_df = dataset_df.drop(train_dataset_df.index)
# Dropping by index label only yields a clean split when the index is unique
assert len(train_dataset_df) + len(test_dataset_df) == len(dataset_df), \
    "Train/test split is not disjoint and complete (non-unique index?)."

# ---------------- Data Sampling ----------------
# Sample labeled and unlabeled data from train set
sampled_data, sampled_labels, sampled_true_labels, sampled_data_df = sample_data(
    train_dataset_df, n_labeled, n_unlabeled, model_names, seed, n_classes
)

# Unlabeled examples carry the placeholder label -1
labeled_idxs = np.where(sampled_labels != -1)[0]
unlabeled_idxs = np.where(sampled_labels == -1)[0]

# Validate the sample before anything is sized from it
assert len(labeled_idxs) == n_labeled, "Mismatch in number of labeled examples."
assert len(unlabeled_idxs) == n_unlabeled, "Mismatch in number of unlabeled examples."
n_sampled = len(sampled_labels)

# ---------------- Group Metadata ----------------
# Start with 'global' group for all examples
sampled_groups = [np.full(n_sampled, 'global', dtype=object)]
test_groups = [np.full(len(test_dataset_df), 'global', dtype=object)]

if subgroup_list:
    for subgroup in subgroup_list:
        # `sampled_data` is indexed positionally later (sampled_data[labeled_idxs]), so it is presumably
        # a plain array without named columns; subgroup metadata has to come from the sampled DataFrame.
        # NOTE(review): assumes sampled_data_df rows are in the same order as sampled_labels — confirm in sample_data.
        sampled_groups.append(sampled_data_df[subgroup].values.astype(str))
        test_groups.append(test_dataset_df[subgroup].values.astype(str))

assert all(len(groups) == n_sampled for groups in sampled_groups), \
    "Group arrays must align with the sampled examples."

print(f"Loaded {len(train_dataset_df)} train samples and {len(test_dataset_df)} test samples")
print(f"Sampled {n_labeled} labeled and {n_unlabeled} unlabeled examples")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Loaded 66891 train samples and 66891 test samples
Sampled 20 labeled and 1000 unlabeled examples


In [5]:
# Package each split as a (model predictions, per-grouping memberships, labels) triple.
# The unlabeled triple carries the -1 placeholder labels produced by sample_data.
estimation_labeled_data, estimation_unlabeled_data = (
    (
        sampled_data[idxs],
        [groups[idxs] for groups in sampled_groups],
        sampled_labels[idxs],
    )
    for idxs in (labeled_idxs, unlabeled_idxs)
)

# Held-out test set in the same triple layout, using its true labels
test_data = (test_dataset_df[model_names].values, test_groups, test_dataset_df['label'].values)

### Labeled data alone

In [6]:
# Baseline: estimate each model's metrics from the labeled examples only
method_config = {'dataset': dataset, 'subgroups': subgroups}
metrics_df = labeled_data_alone(estimation_labeled_data, method_config)

# Map model indices back to model names
metrics_df['model'] = metrics_df['model_idx'].apply(lambda x: model_names[x])

# Same column order as the SSME and ground-truth tables so results line up for comparison
print("\n=== Estimated Performance Metrics (labeled data alone) ===")
print(metrics_df[['model', 'ece', 'auc', 'auprc', 'acc']].to_string(index=False))


=== Estimated Performance Metrics ===
        model  acc      ece      auc    auprc
    alg_CORAL 0.75 0.148426 0.781250 0.429861
      alg_ERM 0.90 0.102048 1.000000 1.000000
      alg_IRM 0.80 0.211246 0.937500 0.679167
alg_ERM_seed1 0.90 0.102431 0.968750 0.916667
alg_ERM_seed2 0.90 0.106434 1.000000 1.000000
alg_IRM_seed1 0.85 0.158549 0.921875 0.645833
alg_IRM_seed2 0.80 0.219561 0.921875 0.645833


### SSME

In [10]:
# SSME: estimate metrics with SSME_KDE using both the labeled examples and the unlabeled pool
# (SSME_KDE is imported in the top import cell)
method_config = {'dataset': dataset, 'subgroups': subgroups}
ssme_metrics_df = SSME_KDE(estimation_labeled_data, estimation_unlabeled_data, method_config)

# Map model indices back to model names
ssme_metrics_df['model'] = ssme_metrics_df['model_idx'].apply(lambda x: model_names[x])

# Display key metrics (column order matches the ground-truth table)
print("\n=== Estimated Performance Metrics (SSME) ===")
print(ssme_metrics_df[['model', 'ece', 'auc', 'auprc', 'acc']].to_string(index=False))

 20%|██        | 4/20 [00:02<00:09,  1.62it/s]

100%|██████████| 20/20 [00:14<00:00,  1.35it/s]


Estimated priors:  [np.float64(0.857), np.float64(0.143)]

=== Estimated Performance Metrics ===
        model      ece      auc    auprc      acc
    alg_CORAL 0.081280 0.930762 0.640757 0.862745
      alg_ERM 0.043703 0.962893 0.802357 0.934314
      alg_IRM 0.063187 0.962031 0.764166 0.923529
alg_ERM_seed1 0.050335 0.959139 0.772711 0.929412
alg_ERM_seed2 0.032324 0.968018 0.808547 0.940196
alg_IRM_seed1 0.060823 0.943967 0.784146 0.926471
alg_IRM_seed2 0.071643 0.962270 0.745536 0.917647


### Ground truth

In [11]:
# Reference values: run the labeled-only estimator on the full held-out test split with its true labels
gt_metrics_df = labeled_data_alone(test_data, method_config)
gt_metrics_df['model'] = [model_names[idx] for idx in gt_metrics_df['model_idx']]

print("\n=== Ground Truth Performance Metrics ===")
print(gt_metrics_df.loc[:, ['model', 'ece', 'auc', 'auprc', 'acc']].to_string(index=False))


=== Ground Truth Performance Metrics ===
        model      ece      auc    auprc      acc
    alg_CORAL 0.059626 0.863927 0.399186 0.882914
      alg_ERM 0.060149 0.939604 0.725134 0.923263
      alg_IRM 0.105903 0.916018 0.660259 0.881957
alg_ERM_seed1 0.061202 0.938858 0.726129 0.923039
alg_ERM_seed2 0.049073 0.942639 0.737038 0.922516
alg_IRM_seed1 0.096068 0.911152 0.667215 0.892258
alg_IRM_seed2 0.101714 0.920637 0.656295 0.887399
