## Evaluating the suitability filter on FMoW

In [1]:
import importlib
import random

import numpy as np
import pandas as pd
import torch
from suitability.filter import suitability_efficient

# Re-execute the already-imported module so that edits made to
# suitability_efficient on disk are picked up without restarting the kernel.
importlib.reload(suitability_efficient)

# Imported only after the reload so the name binds to the reloaded class definition.
from suitability.filter.suitability_efficient import SuitabilityFilter

# Set seeds for reproducibility
# (seeds Python's `random` and NumPy's global RNG only; torch is imported above
# but is not seeded here)
random.seed(32)
np.random.seed(32)

## Evaluate suitability filter

In [1]:
import numpy as np

def calculate_ece_and_bias(probs, correct, n_bins=10):
    """
    Calculate the Expected Calibration Error (ECE) and Calibration Bias (CB).

    ECE is the sample-weighted average, over ``n_bins`` equal-width bins on
    [0, 1], of ``|bin accuracy - bin confidence|``. CB is the signed global
    gap ``mean(probs) - mean(correct)`` and does not depend on the binning.

    Args:
        probs (array-like): Predicted probabilities, shape (n_samples,).
        correct (array-like): Binary correctness labels (0 or 1), shape (n_samples,).
        n_bins (int): Number of equal-width bins used for the ECE.

    Returns:
        tuple: (ECE, CB), where:
            - ECE (float): Expected Calibration Error.
            - CB (float): Calibration Bias (positive = overestimation, negative = underestimation).
    """
    # Accept lists as well as arrays; boolean-mask indexing below needs ndarrays.
    probs = np.asarray(probs)
    correct = np.asarray(correct)

    # Bin edges: n_bins equal-width bins covering [0, 1].
    bins = np.linspace(0, 1, n_bins + 1)

    # np.digitize uses half-open bins [edge_i, edge_{i+1}), so a probability of
    # exactly 1.0 lands one past the last bin. Clip so that such samples are
    # counted in the last bin instead of being silently dropped from the ECE
    # (values outside [0, 1] are likewise folded into the nearest edge bin).
    bin_indices = np.clip(np.digitize(probs, bins) - 1, 0, n_bins - 1)

    ece = 0.0
    for i in range(n_bins):
        bin_mask = bin_indices == i
        bin_count = np.sum(bin_mask)
        if bin_count == 0:  # Skip empty bins
            continue

        # Mean correctness and mean predicted probability within the bin.
        bin_accuracy = np.mean(correct[bin_mask])
        bin_confidence = np.mean(probs[bin_mask])

        # Weight each bin by the fraction of samples it contains.
        bin_weight = bin_count / len(correct)
        ece += bin_weight * np.abs(bin_accuracy - bin_confidence)

    # Global (unbinned) signed gap between mean confidence and mean accuracy.
    cb = np.mean(probs) - np.mean(correct)

    return ece, cb

# ID SPLIT SUBSET EVALS

In [2]:
import pickle
import random

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

from suitability.filter.suitability_efficient import SuitabilityFilter
from suitability.filter.tests import non_inferiority_ttest

# Set seeds for reproducibility. The fold shuffling below draws from NumPy's
# global RNG, so results depend on this seed and on the iteration order.
random.seed(32)
np.random.seed(32)

# Configuration
data_name = "fmow"
root_dir = "/mfsnic/projects/suitability/"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Identifiers used to locate the cached features of each trained model
algorithm = "ERM"
model_type = "last"
seeds = [0, 1, 2]

# Define suitability filter and experiment parameters.
# None of these depend on the model seed, so they are set once up front.
classifiers = [
    "logistic_regression"
]  # "logistic_regression", "svm", "random_forest", "gradient_boosting", "mlp", "decision_tree"]
margins = [0, 0.005, 0.01, 0.05]
normalize = True
calibrated = True
feature_subsets = [
    # Alternative configurations (disabled): single features, then cumulative subsets.
    # [0],
    # [1],
    # [2],
    # [3],
    # [4],
    # [5],
    # [6],
    # [7],
    # [8],
    # [9],
    # [10],
    # [11],
    # [4, 11],
    # [4, 11, 8],
    # [4, 11, 8, 6],
    # [4, 11, 8, 6, 2],
    # [4, 11, 8, 6, 2, 1],
    # [4, 11, 8, 6, 2, 1, 0],
    # [4, 11, 8, 6, 2, 1, 0, 7],
    # [4, 11, 8, 6, 2, 1, 0, 7, 3],
    # [4, 11, 8, 6, 2, 1, 0, 7, 3, 10],
    # [4, 11, 8, 6, 2, 1, 0, 7, 3, 10, 5],
    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
]
num_fold_arr = [15]

# Load the split indices. The file name does not depend on the model seed, so
# it is read once instead of once per seed.
# NOTE: pickle.load can execute arbitrary code; only load trusted cache files.
split_cache_file = f"suitability/results/split_indices/{data_name}_id.pkl"
with open(split_cache_file, "rb") as f:
    id_split_dict = pickle.load(f)

for seed in seeds:
    # Load the cached features for this model seed
    feature_cache_file = (
        f"suitability/results/features/{data_name}_{algorithm}_{model_type}_{seed}.pkl"
    )
    with open(feature_cache_file, "rb") as f:
        full_feature_dict = pickle.load(f)
    id_feature_dict = {
        "id_val": full_feature_dict["id_val"],
        "id_test": full_feature_dict["id_test"],
    }

    # Per-seed result accumulators. The direct non-inferiority test on raw
    # features (non_inferiority_ttest) is disabled in this run, so
    # direct_testing_results stays empty and is not written to disk.
    sf_results = []
    direct_testing_results = []

    # Main loop: keys are (split name, filter) pairs, values are the indices of
    # the user subset within that split.
    for (user_split_name, user_filter), user_split_indices in tqdm(
        id_split_dict.items()
    ):
        print(f"Evaluating user split: {user_split_name} with filter {user_filter}")

        # Validate the split name before it is used as a lookup key, so an
        # unexpected name raises the descriptive ValueError rather than a KeyError.
        if user_split_name == "id_val":
            other_split_name = "id_test"
        elif user_split_name == "id_test":
            other_split_name = "id_val"
        else:
            raise ValueError(f"Invalid split name: {user_split_name}")

        # Get user split features and correctness
        all_features, all_corr = id_feature_dict[user_split_name]
        user_features = all_features[user_split_indices]
        user_corr = all_corr[user_split_indices]
        user_size = len(user_corr)
        user_acc = np.mean(user_corr)

        # Source pool = samples of this split that are NOT in the user subset,
        # plus the whole of the other ID split.
        remaining_indices = np.setdiff1d(np.arange(len(all_corr)), user_split_indices)
        remaining_features = all_features[remaining_indices]
        remaining_corr = all_corr[remaining_indices]
        additional_features, additional_corr = id_feature_dict[other_split_name]
        source_features = np.concatenate([remaining_features, additional_features], axis=0)
        source_corr = np.concatenate([remaining_corr, additional_corr], axis=0)

        for num_folds in num_fold_arr:
            # Shuffle the source pool and cut it into num_folds equal-sized
            # folds; samples beyond num_folds * source_fold_size are left out.
            source_fold_size = len(source_corr) // num_folds
            indices = np.arange(len(source_corr))
            np.random.shuffle(indices)
            fold_indices = [
                indices[i * source_fold_size : (i + 1) * source_fold_size]
                for i in range(num_folds)
            ]

            # Every ordered pair of distinct folds (i, j) is evaluated:
            # fold i supplies the "reg" data, fold j the "test" data.
            for i, reg_indices in enumerate(fold_indices):
                reg_features = source_features[reg_indices]
                reg_corr = source_corr[reg_indices]
                reg_size = len(reg_corr)
                reg_acc = np.mean(reg_corr)

                for j, test_indices in enumerate(fold_indices):
                    if i == j:
                        continue
                    test_features = source_features[test_indices]
                    test_corr = source_corr[test_indices]
                    test_size = len(test_corr)
                    test_acc = np.mean(test_corr)

                    for classifier in classifiers:
                        for feature_subset in feature_subsets:
                            suitability_filter = SuitabilityFilter(
                                test_features,
                                test_corr,
                                reg_features,
                                reg_corr,
                                device,
                                normalize=normalize,
                                feature_subset=feature_subset,
                            )
                            suitability_filter.train_classifier(
                                calibrated=calibrated, classifier=classifier
                            )

                            for margin in margins:
                                # Test suitability filter
                                sf_test = suitability_filter.suitability_test(
                                    user_features=user_features,
                                    margin=margin,
                                    return_predictions=True,
                                )
                                p_value = sf_test["p_value"]
                                # Ground truth: user accuracy is within `margin`
                                # of (or above) the test-fold accuracy.
                                ground_truth = user_acc >= test_acc - margin

                                pred_user = sf_test["user_predictions"]
                                pred_test = sf_test["test_predictions"]

                                # Calculate ECE and CB of the returned
                                # predictions against the true correctness.
                                ece_user, cb_user = calculate_ece_and_bias(
                                    pred_user, user_corr
                                )
                                ece_test, cb_test = calculate_ece_and_bias(
                                    pred_test, test_corr
                                )

                                sf_results.append(
                                    {
                                        "data_name": data_name,
                                        "algorithm": algorithm,
                                        "seed": seed,
                                        "model_type": model_type,
                                        "normalize": normalize,
                                        "calibrated": calibrated,
                                        "margin": margin,
                                        "reg_fold": i,
                                        "reg_size": reg_size,
                                        "reg_acc": reg_acc,
                                        "test_fold": j,
                                        "test_size": test_size,
                                        "test_acc": test_acc,
                                        "user_split": user_split_name,
                                        "user_filter": user_filter,
                                        "user_size": user_size,
                                        "user_acc": user_acc,
                                        "p_value": p_value,
                                        "ground_truth": ground_truth,
                                        "classifier": classifier,
                                        "feature_subset": feature_subset,
                                        "acc_diff": user_acc - test_acc,
                                        "acc_diff_adjusted": user_acc + margin - test_acc,
                                        "ece_user": ece_user,
                                        "cb_user": cb_user,
                                        "ece_test": ece_test,
                                        "cb_test": cb_test,
                                        "mean_pred_user": np.mean(pred_user),
                                        "mean_pred_test": np.mean(pred_test),
                                        "std_pred_user": np.std(pred_user),
                                        "std_pred_test": np.std(pred_test),
                                    }
                                )

    # Save results (one CSV per model seed).
    # NOTE(review): the output lands under sf_evals/irm/ although algorithm is
    # "ERM" -- confirm that this directory is intended.
    sf_evals = pd.DataFrame(sf_results)
    sf_evals.to_csv(
        f"suitability/results/sf_evals/irm/fmow_sf_results_id_calibration_{algorithm}_{model_type}_{seed}_FINAL.csv",
        index=False,
    )


  0%|                                                                                                                                     | 0/16 [00:00<?, ?it/s]

Evaluating user split: id_val with filter {'year': [2002, 2003, 2004, 2005, 2006]}


  6%|███████▊                                                                                                                     | 1/16 [00:12<03:08, 12.57s/it]

Evaluating user split: id_val with filter {'year': [2007, 2008, 2009]}


 12%|███████████████▋                                                                                                             | 2/16 [00:25<02:56, 12.60s/it]

Evaluating user split: id_val with filter {'year': [2010]}


 19%|███████████████████████▍                                                                                                     | 3/16 [00:38<02:45, 12.73s/it]

Evaluating user split: id_val with filter {'year': [2011]}


 25%|███████████████████████████████▎                                                                                             | 4/16 [00:51<02:36, 13.07s/it]

Evaluating user split: id_val with filter {'year': [2012]}


 31%|███████████████████████████████████████                                                                                      | 5/16 [01:04<02:23, 13.06s/it]

Evaluating user split: id_val with filter {'region': ['Asia']}


 38%|██████████████████████████████████████████████▉                                                                              | 6/16 [01:17<02:10, 13.02s/it]

Evaluating user split: id_val with filter {'region': ['Europe']}


 44%|██████████████████████████████████████████████████████▋                                                                      | 7/16 [01:31<01:58, 13.14s/it]

Evaluating user split: id_val with filter {'region': ['Americas']}


 50%|██████████████████████████████████████████████████████████████▌                                                              | 8/16 [01:44<01:44, 13.09s/it]

Evaluating user split: id_test with filter {'year': [2002, 2003, 2004, 2005, 2006]}


 56%|██████████████████████████████████████████████████████████████████████▎                                                      | 9/16 [01:57<01:32, 13.16s/it]

Evaluating user split: id_test with filter {'year': [2007, 2008, 2009]}


 62%|█████████████████████████████████████████████████████████████████████████████▌                                              | 10/16 [02:10<01:18, 13.02s/it]

Evaluating user split: id_test with filter {'year': [2010]}


 69%|█████████████████████████████████████████████████████████████████████████████████████▎                                      | 11/16 [02:22<01:04, 12.99s/it]

Evaluating user split: id_test with filter {'year': [2011]}


 75%|█████████████████████████████████████████████████████████████████████████████████████████████                               | 12/16 [02:35<00:51, 12.96s/it]

Evaluating user split: id_test with filter {'year': [2012]}


 81%|████████████████████████████████████████████████████████████████████████████████████████████████████▊                       | 13/16 [02:48<00:38, 12.98s/it]

Evaluating user split: id_test with filter {'region': ['Asia']}


 88%|████████████████████████████████████████████████████████████████████████████████████████████████████████████▌               | 14/16 [03:02<00:26, 13.09s/it]

Evaluating user split: id_test with filter {'region': ['Europe']}


 94%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎       | 15/16 [03:15<00:13, 13.15s/it]

Evaluating user split: id_test with filter {'region': ['Americas']}


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [03:28<00:00, 13.03s/it]
  0%|                                                                                                                                     | 0/16 [00:00<?, ?it/s]

Evaluating user split: id_val with filter {'year': [2002, 2003, 2004, 2005, 2006]}


  6%|███████▊                                                                                                                     | 1/16 [00:12<03:13, 12.90s/it]

Evaluating user split: id_val with filter {'year': [2007, 2008, 2009]}


 12%|███████████████▋                                                                                                             | 2/16 [00:25<02:59, 12.83s/it]

Evaluating user split: id_val with filter {'year': [2010]}


 19%|███████████████████████▍                                                                                                     | 3/16 [00:39<02:52, 13.29s/it]

Evaluating user split: id_val with filter {'year': [2011]}


 25%|███████████████████████████████▎                                                                                             | 4/16 [00:52<02:37, 13.15s/it]

Evaluating user split: id_val with filter {'year': [2012]}


 31%|███████████████████████████████████████                                                                                      | 5/16 [01:05<02:24, 13.11s/it]

Evaluating user split: id_val with filter {'region': ['Asia']}


 38%|██████████████████████████████████████████████▉                                                                              | 6/16 [01:18<02:10, 13.06s/it]

Evaluating user split: id_val with filter {'region': ['Europe']}


 44%|██████████████████████████████████████████████████████▋                                                                      | 7/16 [01:32<02:00, 13.42s/it]

Evaluating user split: id_val with filter {'region': ['Americas']}


 50%|██████████████████████████████████████████████████████████████▌                                                              | 8/16 [01:45<01:46, 13.28s/it]

Evaluating user split: id_test with filter {'year': [2002, 2003, 2004, 2005, 2006]}


 56%|██████████████████████████████████████████████████████████████████████▎                                                      | 9/16 [01:58<01:31, 13.12s/it]

Evaluating user split: id_test with filter {'year': [2007, 2008, 2009]}


 62%|█████████████████████████████████████████████████████████████████████████████▌                                              | 10/16 [02:11<01:18, 13.05s/it]

Evaluating user split: id_test with filter {'year': [2010]}


 69%|█████████████████████████████████████████████████████████████████████████████████████▎                                      | 11/16 [02:24<01:05, 13.03s/it]

Evaluating user split: id_test with filter {'year': [2011]}


 75%|█████████████████████████████████████████████████████████████████████████████████████████████                               | 12/16 [02:37<00:52, 13.24s/it]

Evaluating user split: id_test with filter {'year': [2012]}


 81%|████████████████████████████████████████████████████████████████████████████████████████████████████▊                       | 13/16 [02:51<00:39, 13.21s/it]

Evaluating user split: id_test with filter {'region': ['Asia']}


 88%|████████████████████████████████████████████████████████████████████████████████████████████████████████████▌               | 14/16 [03:03<00:26, 13.11s/it]

Evaluating user split: id_test with filter {'region': ['Europe']}


 94%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎       | 15/16 [03:17<00:13, 13.19s/it]

Evaluating user split: id_test with filter {'region': ['Americas']}


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [03:30<00:00, 13.15s/it]
  0%|                                                                                                                                     | 0/16 [00:00<?, ?it/s]

Evaluating user split: id_val with filter {'year': [2002, 2003, 2004, 2005, 2006]}


  6%|███████▊                                                                                                                     | 1/16 [00:13<03:19, 13.32s/it]

Evaluating user split: id_val with filter {'year': [2007, 2008, 2009]}


 12%|███████████████▋                                                                                                             | 2/16 [00:26<03:01, 12.95s/it]

Evaluating user split: id_val with filter {'year': [2010]}


 19%|███████████████████████▍                                                                                                     | 3/16 [00:39<02:49, 13.03s/it]

Evaluating user split: id_val with filter {'year': [2011]}


 25%|███████████████████████████████▎                                                                                             | 4/16 [00:52<02:36, 13.03s/it]

Evaluating user split: id_val with filter {'year': [2012]}


 31%|███████████████████████████████████████                                                                                      | 5/16 [01:05<02:23, 13.07s/it]

Evaluating user split: id_val with filter {'region': ['Asia']}


 38%|██████████████████████████████████████████████▉                                                                              | 6/16 [01:18<02:11, 13.17s/it]

Evaluating user split: id_val with filter {'region': ['Europe']}


 44%|██████████████████████████████████████████████████████▋                                                                      | 7/16 [01:32<01:59, 13.26s/it]

Evaluating user split: id_val with filter {'region': ['Americas']}


 50%|██████████████████████████████████████████████████████████████▌                                                              | 8/16 [01:45<01:45, 13.19s/it]

Evaluating user split: id_test with filter {'year': [2002, 2003, 2004, 2005, 2006]}


 56%|██████████████████████████████████████████████████████████████████████▎                                                      | 9/16 [01:57<01:31, 13.03s/it]

Evaluating user split: id_test with filter {'year': [2007, 2008, 2009]}


 62%|█████████████████████████████████████████████████████████████████████████████▌                                              | 10/16 [02:10<01:17, 12.94s/it]

Evaluating user split: id_test with filter {'year': [2010]}


 69%|█████████████████████████████████████████████████████████████████████████████████████▎                                      | 11/16 [02:23<01:05, 13.01s/it]

Evaluating user split: id_test with filter {'year': [2011]}


 75%|█████████████████████████████████████████████████████████████████████████████████████████████                               | 12/16 [02:36<00:52, 13.01s/it]

Evaluating user split: id_test with filter {'year': [2012]}


 81%|████████████████████████████████████████████████████████████████████████████████████████████████████▊                       | 13/16 [02:49<00:39, 13.01s/it]

Evaluating user split: id_test with filter {'region': ['Asia']}


 88%|████████████████████████████████████████████████████████████████████████████████████████████████████████████▌               | 14/16 [03:02<00:26, 13.01s/it]

Evaluating user split: id_test with filter {'region': ['Europe']}


 94%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎       | 15/16 [03:16<00:13, 13.27s/it]

Evaluating user split: id_test with filter {'region': ['Americas']}


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [03:29<00:00, 13.10s/it]


# OOD SPLIT SUBSET EVALS

In [3]:
import pickle
import random

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

from suitability.filter.suitability_efficient import SuitabilityFilter
from suitability.filter.tests import non_inferiority_ttest

# Set seeds for reproducibility
random.seed(32)
np.random.seed(32)

# Configuration
data_name = "fmow"
root_dir = "/mfsnic/projects/suitability/"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the model
algorithm = "ERM"
model_type = "last"
seeds = [0, 1, 2]

for seed in seeds:
    # Load the features
    feature_cache_file = (
        f"suitability/results/features/{data_name}_{algorithm}_{model_type}_{seed}.pkl"
    )
    with open(feature_cache_file, "rb") as f:
        full_feature_dict = pickle.load(f)

    # Load the split indices
    split_cache_file = f"suitability/results/split_indices/{data_name}_ood.pkl"
    with open(split_cache_file, "rb") as f:
        ood_split_dict = pickle.load(f)

    # Define suitability filter and experiment parameters
    classifiers = [
        "logistic_regression"
    ]  # "logistic_regression", "svm", "random_forest", "gradient_boosting", "mlp", "decision_tree"]
    margins = [0, 0.005, 0.01, 0.05]
    normalize = True
    calibrated = True
    sf_results = []
    direct_testing_results = []
    feature_subsets = [
        # [0],
        # [1],
        # [2],
        # [3],
        # [4],
        # [5],
        # [6],
        # [7],
        # [8],
        # [9],
        # [10],
        # [11],
        # [4, 11],
        # [4, 11, 8],
        # [4, 11, 8, 6],
        # [4, 11, 8, 6, 2],
        # [4, 11, 8, 6, 2, 1],
        # [4, 11, 8, 6, 2, 1, 0],
        # [4, 11, 8, 6, 2, 1, 0, 7],
        # [4, 11, 8, 6, 2, 1, 0, 7, 3],
        # [4, 11, 8, 6, 2, 1, 0, 7, 3, 10],
        # [4, 11, 8, 6, 2, 1, 0, 7, 3, 10, 5],
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
    ]
    num_fold_arr = [15]

    id_features_val, id_corr_val = full_feature_dict["id_val"]
    id_features_test, id_corr_test = full_feature_dict["id_test"]
    source_features = np.concatenate([id_features_val, id_features_test], axis=0)
    source_corr = np.concatenate([id_corr_val, id_corr_test], axis=0)


    # Main loop
    for user_split_name, user_filter in tqdm(ood_split_dict.keys()):
        print(f"Evaluating user split: {user_split_name} with filter {user_filter}")

        # Get user split indices
        user_split_indices = ood_split_dict[(user_split_name, user_filter)]

        # Get user split features and correctness
        all_features, all_corr = full_feature_dict[user_split_name]
        user_features = all_features[user_split_indices]
        user_corr = all_corr[user_split_indices]
        user_size = len(user_corr)
        user_acc = np.mean(user_corr)

        for num_folds in num_fold_arr:
            source_fold_size = len(source_corr) // num_folds
            indices = np.arange(len(source_corr))
            np.random.shuffle(indices)
            fold_indices = [
                indices[i * source_fold_size : (i + 1) * source_fold_size]
                for i in range(num_folds)
            ]

            for i, reg_indices in enumerate(fold_indices):
                reg_features = source_features[reg_indices]
                reg_corr = source_corr[reg_indices]
                reg_size = len(reg_corr)
                reg_acc = np.mean(reg_corr)

                for j, test_indices in enumerate(fold_indices):
                    if i == j:
                        continue
                    test_features = source_features[test_indices]
                    test_corr = source_corr[test_indices]
                    test_size = len(test_corr)
                    test_acc = np.mean(test_corr)

                    for classifier in classifiers:
                        for feature_subset in feature_subsets:
                            suitability_filter = SuitabilityFilter(
                                test_features,
                                test_corr,
                                reg_features,
                                reg_corr,
                                device,
                                normalize=normalize,
                                feature_subset=feature_subset,
                            )
                            suitability_filter.train_classifier(
                                calibrated=calibrated, classifier=classifier
                            )

                            for margin in margins:
                                # Test suitability filter
                                sf_test = suitability_filter.suitability_test(
                                    user_features=user_features, margin=margin, return_predictions=True
                                )
                                p_value = sf_test["p_value"]
                                ground_truth = user_acc >= test_acc - margin

                                pred_user = sf_test["user_predictions"]
                                # Read the test-side predictions out of `sf_test`.
                                # NOTE(review): presumably per-sample probabilities in [0, 1], since
                                # they are passed as `probs` to calculate_ece_and_bias — confirm.
                                pred_test = sf_test["test_predictions"]

                                # Calculate ECE and CB separately for the user sample and the test
                                # sample, using the helper's default of 10 bins. Per the helper,
                                # CB = mean(probs) - mean(correct), so a positive value means the
                                # predictions overestimate the observed correctness.
                                ece_user, cb_user = calculate_ece_and_bias(
                                    pred_user, user_corr
                                )
                                ece_test, cb_test = calculate_ece_and_bias(
                                    pred_test, test_corr
                                )

                                # Append one flat record for this configuration; the keys become
                                # the column names once sf_results is turned into a DataFrame.
                                sf_results.append(
                                    {
                                        # Run / model identification
                                        "data_name": data_name,
                                        "algorithm": algorithm,
                                        "seed": seed,
                                        "model_type": model_type,
                                        # Filter configuration
                                        "normalize": normalize,
                                        "calibrated": calibrated,
                                        "margin": margin,
                                        # NOTE(review): `i` / `j` look like fold indices of enclosing
                                        # loops that start above this excerpt — confirm.
                                        "reg_fold": i,
                                        "reg_size": reg_size,
                                        "reg_acc": reg_acc,
                                        "test_fold": j,
                                        "test_size": test_size,
                                        "test_acc": test_acc,
                                        # User-data subset under evaluation
                                        "user_split": user_split_name,
                                        "user_filter": user_filter,
                                        "user_size": user_size,
                                        "user_acc": user_acc,
                                        # Filter decision and reference label
                                        "p_value": p_value,
                                        "ground_truth": ground_truth,
                                        "classifier": classifier,
                                        "feature_subset": feature_subset,
                                        # Positive when the user sample is more accurate than the test sample
                                        "acc_diff": user_acc - test_acc,
                                        # Same gap with `margin` added to the user accuracy
                                        "acc_diff_adjusted": user_acc + margin - test_acc,
                                        # Calibration metrics computed above
                                        "ece_user": ece_user,
                                        "cb_user": cb_user,
                                        "ece_test": ece_test,
                                        "cb_test": cb_test,
                                        # Mean and standard deviation of the raw prediction values
                                        # (np.std default, i.e. population std with ddof=0)
                                        "mean_pred_user": np.mean(pred_user),
                                        "mean_pred_test": np.mean(pred_test),
                                        "std_pred_user": np.std(pred_user),
                                        "std_pred_test": np.std(pred_test),
                                    }
                                )

                                # Run non-inferiority test on features directly
                                # if (
                                #     len(feature_subset) == 1
                                #     and margin == 0
                                #     and classifier == "logistic_regression"
                                #     and (j == 0 or (i == 0 and j == 1))
                                # ):
                                #     test_feature_subset = test_features[:, feature_subset].flatten()
                                #     user_feature_subset = user_features[:, feature_subset].flatten()
                                #     test_1 = non_inferiority_ttest(
                                #         test_feature_subset,
                                #         user_feature_subset,
                                #         increase_good=True,
                                #     )
                                #     test_2 = non_inferiority_ttest(
                                #         test_feature_subset,
                                #         user_feature_subset,
                                #         increase_good=False,
                                #     )
                                #     direct_testing_results.append(
                                #         {
                                #             "data_name": data_name,
                                #             "algorithm": algorithm,
                                #             "seed": seed,
                                #             "model_type": model_type,
                                #             "test_fold": j,
                                #             "test_size": test_size,
                                #             "test_acc": test_acc,
                                #             "user_split": user_split_name,
                                #             "user_filter": user_filter,
                                #             "user_size": user_size,
                                #             "user_acc": user_acc,
                                #             "p_value_increase_good": test_1["p_value"],
                                #             "p_value_decrease_good": test_2["p_value"],
                                #             "ground_truth": ground_truth,
                                #             "feature_subset": feature_subset,
                                #             "acc_diff": user_acc - test_acc,
                                #         }
                                #     )

    # Save results: turn every record accumulated in sf_results into one table
    # and write it as CSV (index=False leaves the DataFrame row index out of the file).
    # This write sits at a much shallower indent than the appends, so it runs only
    # after the inner loops have finished for the current outer iteration.
    # NOTE(review): whether sf_results is reset between writes is not visible in
    # this excerpt — if it is not, each file also contains rows from earlier iterations.
    # NOTE(review): the directory component is the literal "erm" while the file
    # name interpolates {algorithm} — confirm this is intended for other algorithms.
    # NOTE(review): to_csv does not create missing parent directories, so the
    # target folder must already exist — confirm it is created elsewhere.
    sf_evals = pd.DataFrame(sf_results)
    sf_evals.to_csv(
        f"suitability/results/sf_evals/erm/fmow_sf_results_ood_calibration_{algorithm}_{model_type}_{seed}_FINAL.csv",
        index=False,
    )
    # direct_testing_evals = pd.DataFrame(direct_testing_results)
    # direct_testing_evals.to_csv(
    #     f"suitability/results/sf_evals/erm/fmow_direct_testing_results_ood_subsets_{algorithm}_{model_type}_{seed}_NEW.csv",
    #     index=False,
    # )


  0%|                                                                                                                                     | 0/30 [00:00<?, ?it/s]

Evaluating user split: val with filter {'year': [2013]}


  3%|████▏                                                                                                                        | 1/30 [00:13<06:34, 13.60s/it]

Evaluating user split: val with filter {'year': [2014]}


  7%|████████▎                                                                                                                    | 2/30 [00:27<06:33, 14.04s/it]

Evaluating user split: val with filter {'year': [2015]}


 10%|████████████▌                                                                                                                | 3/30 [00:43<06:36, 14.68s/it]

Evaluating user split: val with filter {'region': ['Asia']}


 13%|████████████████▋                                                                                                            | 4/30 [00:57<06:19, 14.58s/it]

Evaluating user split: val with filter {'region': ['Europe']}


 17%|████████████████████▊                                                                                                        | 5/30 [01:12<06:05, 14.62s/it]

Evaluating user split: val with filter {'region': ['Africa']}


 20%|█████████████████████████                                                                                                    | 6/30 [01:25<05:33, 13.90s/it]

Evaluating user split: val with filter {'region': ['Americas']}


 23%|█████████████████████████████▏                                                                                               | 7/30 [01:39<05:23, 14.05s/it]

Evaluating user split: val with filter {'region': ['Oceania']}


 27%|█████████████████████████████████▎                                                                                           | 8/30 [01:52<05:04, 13.85s/it]

Evaluating user split: val with filter {'region': 'Europe', 'year': 2013}


 30%|█████████████████████████████████████▌                                                                                       | 9/30 [02:05<04:44, 13.54s/it]

Evaluating user split: val with filter {'region': 'Europe', 'year': 2014}


 33%|█████████████████████████████████████████▎                                                                                  | 10/30 [02:18<04:28, 13.42s/it]

Evaluating user split: val with filter {'region': 'Europe', 'year': 2015}


 37%|█████████████████████████████████████████████▍                                                                              | 11/30 [02:32<04:15, 13.45s/it]

Evaluating user split: val with filter {'region': 'Asia', 'year': 2013}


 40%|█████████████████████████████████████████████████▌                                                                          | 12/30 [02:44<03:57, 13.20s/it]

Evaluating user split: val with filter {'region': 'Asia', 'year': 2014}


 43%|█████████████████████████████████████████████████████▋                                                                      | 13/30 [02:58<03:45, 13.29s/it]

Evaluating user split: val with filter {'region': 'Asia', 'year': 2015}


 47%|█████████████████████████████████████████████████████████▊                                                                  | 14/30 [03:11<03:31, 13.20s/it]

Evaluating user split: val with filter {'region': 'Americas', 'year': 2013}


 50%|██████████████████████████████████████████████████████████████                                                              | 15/30 [03:24<03:15, 13.06s/it]

Evaluating user split: val with filter {'region': 'Americas', 'year': 2014}


 53%|██████████████████████████████████████████████████████████████████▏                                                         | 16/30 [03:37<03:02, 13.07s/it]

Evaluating user split: val with filter {'region': 'Americas', 'year': 2015}


 57%|██████████████████████████████████████████████████████████████████████▎                                                     | 17/30 [03:50<02:51, 13.18s/it]

Evaluating user split: test with filter {'year': 2016}


 60%|██████████████████████████████████████████████████████████████████████████▍                                                 | 18/30 [04:08<02:53, 14.45s/it]

Evaluating user split: test with filter {'year': 2017}


 63%|██████████████████████████████████████████████████████████████████████████████▌                                             | 19/30 [04:22<02:38, 14.37s/it]

Evaluating user split: test with filter {'region': 'Asia'}


 67%|██████████████████████████████████████████████████████████████████████████████████▋                                         | 20/30 [04:36<02:21, 14.18s/it]

Evaluating user split: test with filter {'region': 'Europe'}


 70%|██████████████████████████████████████████████████████████████████████████████████████▊                                     | 21/30 [04:50<02:07, 14.17s/it]

Evaluating user split: test with filter {'region': 'Africa'}


 73%|██████████████████████████████████████████████████████████████████████████████████████████▉                                 | 22/30 [05:03<01:52, 14.02s/it]

Evaluating user split: test with filter {'region': 'Americas'}


 77%|███████████████████████████████████████████████████████████████████████████████████████████████                             | 23/30 [05:18<01:39, 14.26s/it]

Evaluating user split: test with filter {'region': 'Oceania'}


 80%|███████████████████████████████████████████████████████████████████████████████████████████████████▏                        | 24/30 [05:31<01:22, 13.75s/it]

Evaluating user split: test with filter {'region': 'Europe', 'year': 2016}


 83%|███████████████████████████████████████████████████████████████████████████████████████████████████████▎                    | 25/30 [05:44<01:08, 13.73s/it]

Evaluating user split: test with filter {'region': 'Europe', 'year': 2017}


 87%|███████████████████████████████████████████████████████████████████████████████████████████████████████████▍                | 26/30 [05:57<00:53, 13.41s/it]

Evaluating user split: test with filter {'region': 'Asia', 'year': 2016}


 90%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████▌            | 27/30 [06:11<00:40, 13.54s/it]

Evaluating user split: test with filter {'region': 'Asia', 'year': 2017}


 93%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋        | 28/30 [06:24<00:26, 13.34s/it]

Evaluating user split: test with filter {'region': 'Americas', 'year': 2016}


 97%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊    | 29/30 [06:38<00:13, 13.59s/it]

Evaluating user split: test with filter {'region': 'Americas', 'year': 2017}


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 30/30 [06:51<00:00, 13.71s/it]
  0%|                                                                                                                                     | 0/30 [00:00<?, ?it/s]

Evaluating user split: val with filter {'year': [2013]}


  3%|████▏                                                                                                                        | 1/30 [00:14<06:57, 14.41s/it]

Evaluating user split: val with filter {'year': [2014]}


  7%|████████▎                                                                                                                    | 2/30 [00:28<06:41, 14.32s/it]

Evaluating user split: val with filter {'year': [2015]}


 10%|████████████▌                                                                                                                | 3/30 [00:44<06:40, 14.82s/it]

Evaluating user split: val with filter {'region': ['Asia']}


 13%|████████████████▋                                                                                                            | 4/30 [00:57<06:14, 14.42s/it]

Evaluating user split: val with filter {'region': ['Europe']}


 17%|████████████████████▊                                                                                                        | 5/30 [01:13<06:07, 14.70s/it]

Evaluating user split: val with filter {'region': ['Africa']}


 20%|█████████████████████████                                                                                                    | 6/30 [01:25<05:36, 14.02s/it]

Evaluating user split: val with filter {'region': ['Americas']}


 23%|█████████████████████████████▏                                                                                               | 7/30 [01:40<05:25, 14.16s/it]

Evaluating user split: val with filter {'region': ['Oceania']}


 27%|█████████████████████████████████▎                                                                                           | 8/30 [01:52<05:00, 13.67s/it]

Evaluating user split: val with filter {'region': 'Europe', 'year': 2013}


 30%|█████████████████████████████████████▌                                                                                       | 9/30 [02:05<04:43, 13.48s/it]

Evaluating user split: val with filter {'region': 'Europe', 'year': 2014}


 33%|█████████████████████████████████████████▎                                                                                  | 10/30 [02:19<04:33, 13.65s/it]

Evaluating user split: val with filter {'region': 'Europe', 'year': 2015}


 37%|█████████████████████████████████████████████▍                                                                              | 11/30 [02:33<04:19, 13.63s/it]

Evaluating user split: val with filter {'region': 'Asia', 'year': 2013}


 40%|█████████████████████████████████████████████████▌                                                                          | 12/30 [02:46<04:00, 13.36s/it]

Evaluating user split: val with filter {'region': 'Asia', 'year': 2014}


 43%|█████████████████████████████████████████████████████▋                                                                      | 13/30 [02:59<03:44, 13.21s/it]

Evaluating user split: val with filter {'region': 'Asia', 'year': 2015}


 47%|█████████████████████████████████████████████████████████▊                                                                  | 14/30 [03:12<03:31, 13.23s/it]

Evaluating user split: val with filter {'region': 'Americas', 'year': 2013}


 50%|██████████████████████████████████████████████████████████████                                                              | 15/30 [03:25<03:19, 13.33s/it]

Evaluating user split: val with filter {'region': 'Americas', 'year': 2014}


 53%|██████████████████████████████████████████████████████████████████▏                                                         | 16/30 [03:38<03:04, 13.20s/it]

Evaluating user split: val with filter {'region': 'Americas', 'year': 2015}


 57%|██████████████████████████████████████████████████████████████████████▎                                                     | 17/30 [03:52<02:52, 13.25s/it]

Evaluating user split: test with filter {'year': 2016}


 60%|██████████████████████████████████████████████████████████████████████████▍                                                 | 18/30 [04:09<02:52, 14.35s/it]

Evaluating user split: test with filter {'year': 2017}


 63%|██████████████████████████████████████████████████████████████████████████████▌                                             | 19/30 [04:24<02:39, 14.50s/it]

Evaluating user split: test with filter {'region': 'Asia'}


 67%|██████████████████████████████████████████████████████████████████████████████████▋                                         | 20/30 [04:37<02:23, 14.33s/it]

Evaluating user split: test with filter {'region': 'Europe'}


 70%|██████████████████████████████████████████████████████████████████████████████████████▊                                     | 21/30 [04:52<02:08, 14.27s/it]

Evaluating user split: test with filter {'region': 'Africa'}


 73%|██████████████████████████████████████████████████████████████████████████████████████████▉                                 | 22/30 [05:05<01:51, 13.92s/it]

Evaluating user split: test with filter {'region': 'Americas'}


 77%|███████████████████████████████████████████████████████████████████████████████████████████████                             | 23/30 [05:19<01:39, 14.17s/it]

Evaluating user split: test with filter {'region': 'Oceania'}


 80%|███████████████████████████████████████████████████████████████████████████████████████████████████▏                        | 24/30 [05:33<01:23, 13.86s/it]

Evaluating user split: test with filter {'region': 'Europe', 'year': 2016}


 83%|███████████████████████████████████████████████████████████████████████████████████████████████████████▎                    | 25/30 [05:47<01:09, 13.91s/it]

Evaluating user split: test with filter {'region': 'Europe', 'year': 2017}


 87%|███████████████████████████████████████████████████████████████████████████████████████████████████████████▍                | 26/30 [05:59<00:54, 13.55s/it]

Evaluating user split: test with filter {'region': 'Asia', 'year': 2016}


 90%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████▌            | 27/30 [06:13<00:40, 13.49s/it]

Evaluating user split: test with filter {'region': 'Asia', 'year': 2017}


 93%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋        | 28/30 [06:26<00:26, 13.45s/it]

Evaluating user split: test with filter {'region': 'Americas', 'year': 2016}


 97%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊    | 29/30 [06:40<00:13, 13.66s/it]

Evaluating user split: test with filter {'region': 'Americas', 'year': 2017}


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 30/30 [06:53<00:00, 13.79s/it]
  0%|                                                                                                                                     | 0/30 [00:00<?, ?it/s]

Evaluating user split: val with filter {'year': [2013]}


  3%|████▏                                                                                                                        | 1/30 [00:13<06:36, 13.68s/it]

Evaluating user split: val with filter {'year': [2014]}


  7%|████████▎                                                                                                                    | 2/30 [00:28<06:37, 14.19s/it]

Evaluating user split: val with filter {'year': [2015]}


 10%|████████████▌                                                                                                                | 3/30 [00:44<06:46, 15.05s/it]

Evaluating user split: val with filter {'region': ['Asia']}


 13%|████████████████▋                                                                                                            | 4/30 [00:58<06:17, 14.53s/it]

Evaluating user split: val with filter {'region': ['Europe']}


 17%|████████████████████▊                                                                                                        | 5/30 [01:12<06:05, 14.62s/it]

Evaluating user split: val with filter {'region': ['Africa']}


 20%|█████████████████████████                                                                                                    | 6/30 [01:25<05:37, 14.05s/it]

Evaluating user split: val with filter {'region': ['Americas']}


 23%|█████████████████████████████▏                                                                                               | 7/30 [01:40<05:29, 14.34s/it]

Evaluating user split: val with filter {'region': ['Oceania']}


 27%|█████████████████████████████████▎                                                                                           | 8/30 [01:53<05:03, 13.80s/it]

Evaluating user split: val with filter {'region': 'Europe', 'year': 2013}


 30%|█████████████████████████████████████▌                                                                                       | 9/30 [02:06<04:44, 13.56s/it]

Evaluating user split: val with filter {'region': 'Europe', 'year': 2014}


 33%|█████████████████████████████████████████▎                                                                                  | 10/30 [02:19<04:29, 13.48s/it]

Evaluating user split: val with filter {'region': 'Europe', 'year': 2015}


 37%|█████████████████████████████████████████████▍                                                                              | 11/30 [02:33<04:17, 13.53s/it]

Evaluating user split: val with filter {'region': 'Asia', 'year': 2013}


 40%|█████████████████████████████████████████████████▌                                                                          | 12/30 [02:46<04:01, 13.44s/it]

Evaluating user split: val with filter {'region': 'Asia', 'year': 2014}


 43%|█████████████████████████████████████████████████████▋                                                                      | 13/30 [02:59<03:45, 13.27s/it]

Evaluating user split: val with filter {'region': 'Asia', 'year': 2015}


 47%|█████████████████████████████████████████████████████████▊                                                                  | 14/30 [03:12<03:31, 13.24s/it]

Evaluating user split: val with filter {'region': 'Americas', 'year': 2013}


 50%|██████████████████████████████████████████████████████████████                                                              | 15/30 [03:25<03:16, 13.10s/it]

Evaluating user split: val with filter {'region': 'Americas', 'year': 2014}


 53%|██████████████████████████████████████████████████████████████████▏                                                         | 16/30 [03:38<03:03, 13.08s/it]

Evaluating user split: val with filter {'region': 'Americas', 'year': 2015}


 57%|██████████████████████████████████████████████████████████████████████▎                                                     | 17/30 [03:52<02:54, 13.41s/it]

Evaluating user split: test with filter {'year': 2016}


 60%|██████████████████████████████████████████████████████████████████████████▍                                                 | 18/30 [04:09<02:53, 14.42s/it]

Evaluating user split: test with filter {'year': 2017}


 63%|██████████████████████████████████████████████████████████████████████████████▌                                             | 19/30 [04:23<02:38, 14.37s/it]

Evaluating user split: test with filter {'region': 'Asia'}


 67%|██████████████████████████████████████████████████████████████████████████████████▋                                         | 20/30 [04:37<02:21, 14.19s/it]

Evaluating user split: test with filter {'region': 'Europe'}


 70%|██████████████████████████████████████████████████████████████████████████████████████▊                                     | 21/30 [04:52<02:09, 14.37s/it]

Evaluating user split: test with filter {'region': 'Africa'}


 73%|██████████████████████████████████████████████████████████████████████████████████████████▉                                 | 22/30 [05:05<01:52, 14.02s/it]

Evaluating user split: test with filter {'region': 'Americas'}


 77%|███████████████████████████████████████████████████████████████████████████████████████████████                             | 23/30 [05:20<01:39, 14.28s/it]

Evaluating user split: test with filter {'region': 'Oceania'}


 80%|███████████████████████████████████████████████████████████████████████████████████████████████████▏                        | 24/30 [05:32<01:22, 13.75s/it]

Evaluating user split: test with filter {'region': 'Europe', 'year': 2016}


 83%|███████████████████████████████████████████████████████████████████████████████████████████████████████▎                    | 25/30 [05:46<01:09, 13.86s/it]

Evaluating user split: test with filter {'region': 'Europe', 'year': 2017}


 87%|███████████████████████████████████████████████████████████████████████████████████████████████████████████▍                | 26/30 [05:59<00:54, 13.56s/it]

Evaluating user split: test with filter {'region': 'Asia', 'year': 2016}


 90%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████▌            | 27/30 [06:13<00:40, 13.55s/it]

Evaluating user split: test with filter {'region': 'Asia', 'year': 2017}


 93%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋        | 28/30 [06:26<00:26, 13.40s/it]

Evaluating user split: test with filter {'region': 'Americas', 'year': 2016}


 97%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊    | 29/30 [06:40<00:13, 13.63s/it]

Evaluating user split: test with filter {'region': 'Americas', 'year': 2017}


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 30/30 [06:54<00:00, 13.80s/it]
