## Evaluating the suitability filter on FMoW

In [1]:
import importlib
import random
from itertools import chain, combinations

import numpy as np
import pandas as pd
import torch

from suitability.datasets.wilds import get_wilds_dataset, get_wilds_model
from suitability.filter import suitability_efficient

importlib.reload(suitability_efficient)

from suitability.filter.suitability_efficient import SuitabilityFilter, get_sf_features

# Set seeds for reproducibility.
# torch is imported and used for the model/device in this notebook, so its
# global generator is seeded alongside Python's and NumPy's.
random.seed(32)
np.random.seed(32)
torch.manual_seed(32)

### Define & evaluate all possible splits

In [None]:
# "id_val" subsets: one filter per year 2002-2012, followed by one per region.
id_val_splits = [("id_val", {"year": year}) for year in range(2002, 2013)] + [
    ("id_val", {"region": region})
    for region in ("Asia", "Europe", "Africa", "Americas", "Oceania")
]

# "id_test" subsets: one filter per year 2002-2012, followed by one per region.
id_test_splits = [("id_test", {"year": year}) for year in range(2002, 2013)] + [
    ("id_test", {"region": region})
    for region in ("Asia", "Europe", "Africa", "Americas", "Oceania")
]

# "val" subsets: per year (2013-2015), per region, then region x year for the
# three regions listed in the last clause.
ood_val_splits = (
    [("val", {"year": year}) for year in (2013, 2014, 2015)]
    + [
        ("val", {"region": region})
        for region in ("Asia", "Europe", "Africa", "Americas", "Oceania")
    ]
    + [
        ("val", {"region": region, "year": year})
        for region in ("Europe", "Asia", "Americas")
        for year in (2013, 2014, 2015)
    ]
)

# "test" subsets: per year (2016-2017), per region, then region x year for the
# three regions listed in the last clause.
ood_test_splits = (
    [("test", {"year": year}) for year in (2016, 2017)]
    + [
        ("test", {"region": region})
        for region in ("Asia", "Europe", "Africa", "Americas", "Oceania")
    ]
    + [
        ("test", {"region": region, "year": year})
        for region in ("Europe", "Asia", "Americas")
        for year in (2016, 2017)
    ]
)

In [None]:
# Evaluate the ERM model on every split/filter combination defined above and
# write one row per combination (size + accuracy) to a CSV file.
data_name = "fmow"
root_dir = "/mfsnic/u/apouget/"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = get_wilds_model(data_name, root_dir, algorithm="ERM")
model = model.to(device)
model.eval()

all_splits = id_val_splits + id_test_splits + ood_val_splits + ood_test_splits

# Rows are collected in a plain list and turned into a DataFrame once, after
# the loop. The previous version called the private `DataFrame._append` per
# iteration, which copies the whole frame each time and is not public API.
result_rows = []

for split, pre_filter in all_splits:
    data = get_wilds_dataset(
        data_name,
        root_dir,
        split,
        batch_size=64,
        shuffle=False,
        num_workers=4,
        pre_filter=pre_filter,
    )
    # NOTE(review): this call passes the model and two loaders, whereas the
    # evaluation cell further down constructs SuitabilityFilter from feature /
    # correctness arrays -- confirm which signature the current class expects.
    suitability_filter = SuitabilityFilter(model, data, data, device)
    corr = suitability_filter.get_correct(data)

    num_samples = len(data.dataset)
    accuracy = np.mean(corr)
    # A filter that does not constrain a dimension is recorded as "ALL".
    year = pre_filter.get("year", "ALL")
    region = pre_filter.get("region", "ALL")
    result_rows.append(
        {
            "split": split,
            "year": year,
            "region": region,
            "num_samples": num_samples,
            "accuracy": accuracy,
        }
    )

# Passing `columns` keeps the column order (and the header of an empty result)
# identical to the original frame.
results = pd.DataFrame(
    result_rows, columns=["split", "year", "region", "num_samples", "accuracy"]
)

results.to_csv("suitability/results/data_splits/fmow_ERM_0_last.csv", index=False)

### Define the splits used to evaluate the suitability filter

In [None]:
# ID user splits: the same eight filters applied to "id_val" and then "id_test".
# The filter tuple is re-evaluated for every split name, so each entry gets its
# own dict/list objects.
valid_id_splits = [
    (split_name, split_filter)
    for split_name in ("id_val", "id_test")
    for split_filter in (
        {"year": [2002, 2003, 2004, 2005, 2006]},
        {"year": [2007, 2008, 2009]},
        {"year": [2010]},
        {"year": [2011]},
        {"year": [2012]},
        {"region": ["Asia"]},
        {"region": ["Europe"]},
        {"region": ["Americas"]},
    )
]

# OOD user splits. Single-key "val" filters hold one-element lists; the combined
# region/year filters and every "test" filter hold scalars.
valid_ood_splits = (
    [("val", {"year": [year]}) for year in (2013, 2014, 2015)]
    + [
        ("val", {"region": [region]})
        for region in ("Asia", "Europe", "Africa", "Americas", "Oceania")
    ]
    + [
        ("val", {"region": region, "year": year})
        for region in ("Europe", "Asia", "Americas")
        for year in (2013, 2014, 2015)
    ]
    + [("test", {"year": year}) for year in (2016, 2017)]
    + [
        ("test", {"region": region})
        for region in ("Asia", "Europe", "Africa", "Americas", "Oceania")
    ]
    + [
        ("test", {"region": region, "year": year})
        for region in ("Europe", "Asia", "Americas")
        for year in (2016, 2017)
    ]
)

print(
    f"Number of valid id splits: {len(valid_id_splits)}, number of valid ood splits: {len(valid_ood_splits)}"
)

Number of valid id splits: 16, number of valid ood splits: 30


## Evaluate suitability filter

In [76]:
import os
import pickle
import random

import numpy as np
import pandas as pd
import torch

from suitability.datasets.wilds import get_wilds_dataset, get_wilds_model
from suitability.filter.suitability_efficient import SuitabilityFilter, get_sf_features

# Load the cached ID feature dictionaries (combined and per-split).
#
# The previous `os.path.exists` guards were removed on purpose: the cells that
# follow index into both objects unconditionally, so when a cache file is
# missing it is better to stop here with a FileNotFoundError than to leave the
# names undefined, or bound to values from an earlier execution of this cell.
#
# NOTE: unpickling can execute arbitrary code; only load cache files that were
# written by this project.
cache_file_id_combined = "suitability/results/features/fmow_ERM_last_0_combined_id.pkl"
with open(cache_file_id_combined, "rb") as f:
    features_id_combined = pickle.load(f)

cache_file_id_individual = "suitability/results/features/fmow_ERM_last_0_id.pkl"
with open(cache_file_id_individual, "rb") as f:
    features_id_individual = pickle.load(f)

In [None]:
# Show the feature vector stored at position 1 of the individually cached
# ("id_test", region == Asia) entry.
features_id_individual[("id_test", "{'region': ['Asia']}")]["features"][1]

array([ 9.99883533e-01,  1.26985088e-01,  1.36794173e-03, -1.35100746e+01,
        4.84576941e+00,  4.60921860e+00,  1.01284866e+01,  1.16460695e-04,
       -1.01284847e+01,  2.50463672e+04,  9.99990404e-01, -4.84588575e+00],
      dtype=float32)

In [None]:
# Show the value stored at position 1 of the "indices" array of the
# ("id_test", region == Asia) entry.
features_id_individual[("id_test", "{'region': ['Asia']}")]["indices"][1]

120

In [None]:
# Show row 120 of the combined ID feature cache.
# NOTE(review): 120 is hard-coded; it matches the index printed for the
# id_test / Asia entry, yet the recorded feature values differ from that
# entry's row -- confirm whether the two caches share an index space.
features_id_combined["features"][120]

array([  0.94658554,   0.1202942 ,   0.21097937, -14.761689  ,
         1.6514555 ,   4.5276346 ,   2.8811173 ,   0.05489393,
        -2.8811173 ,  17.834188  ,   0.9999731 ,  -1.7063494 ],
      dtype=float32)

In [None]:
import os
import pickle
import random

import numpy as np
import pandas as pd
import torch

from suitability.datasets.wilds import get_wilds_dataset, get_wilds_model
from suitability.filter.suitability_efficient import SuitabilityFilter, get_sf_features

# Fix the Python and NumPy RNG seeds
random.seed(32)
np.random.seed(32)

# Dataset location and compute device
data_name = "fmow"
root_dir = "/mfsnic/projects/suitability/"
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Which checkpoint to evaluate
algorithm = "ERM"
model_type = "last"
seed = 0

# Fetch the checkpoint, move it to the device and switch it to eval mode
model = get_wilds_model(
    data_name, root_dir, algorithm=algorithm, seed=seed, model_type=model_type
).to(device)
model.eval()
print(f"Model loaded to device: {device}")

# Output file for the per-split feature cache, keyed by split name
features_cache_file = (
    f"suitability/results/features/{data_name}_{algorithm}_{model_type}_{seed}.pkl"
)
valid_splits = ["id_val", "id_test", "val", "test"]
splits_features_cache = {}

# Run get_sf_features once over each full (unfiltered) split
for split_name in valid_splits:
    print(f"Computing features for split: {split_name}")
    dataset = get_wilds_dataset(
        data_name,
        root_dir,
        split_name,
        batch_size=64,
        shuffle=False,
        num_workers=4,
    )
    splits_features_cache[split_name] = get_sf_features(dataset, model, device)
print("ID splits features computed")

# Persist the feature cache
with open(features_cache_file, "wb") as f:
    pickle.dump(splits_features_cache, f)

# Resolve every ID (split, filter) pair to its sample indices and cache them
id_cache_file = f"suitability/results/split_indices/{data_name}_id.pkl"

# The same eight filters are applied to "id_val" first and then to "id_test".
valid_id_splits = [
    (split_name, split_filter)
    for split_name in ("id_val", "id_test")
    for split_filter in (
        {"year": [2002, 2003, 2004, 2005, 2006]},
        {"year": [2007, 2008, 2009]},
        {"year": [2010]},
        {"year": [2011]},
        {"year": [2012]},
        {"region": ["Asia"]},
        {"region": ["Europe"]},
        {"region": ["Americas"]},
    )
]

id_splits_indices_cache = {}
for split_name, split_filter in valid_id_splits:
    print(f"Computing indices for split: {split_name} with filter: {split_filter}")
    dataset, indices = get_wilds_dataset(
        data_name,
        root_dir,
        split_name,
        batch_size=64,
        shuffle=False,
        num_workers=4,
        pre_filter=split_filter,
        return_indices=True,
    )
    # The filter dict is not hashable, so its string form is part of the key.
    id_splits_indices_cache[split_name, str(split_filter)] = indices

with open(id_cache_file, "wb") as f:
    pickle.dump(id_splits_indices_cache, f)

# Precompute all ood split indices
ood_cache_file = f"suitability/results/split_indices/{data_name}_ood.pkl"

# Single-key "val" filters hold one-element lists; the combined region/year
# filters and every "test" filter hold scalars. str(split_filter) is part of
# the cache key below, so the literal form of each filter matters.
valid_ood_splits = (
    [("val", {"year": [year]}) for year in (2013, 2014, 2015)]
    + [
        ("val", {"region": [region]})
        for region in ("Asia", "Europe", "Africa", "Americas", "Oceania")
    ]
    + [
        ("val", {"region": region, "year": year})
        for region in ("Europe", "Asia", "Americas")
        for year in (2013, 2014, 2015)
    ]
    + [("test", {"year": year}) for year in (2016, 2017)]
    + [
        ("test", {"region": region})
        for region in ("Asia", "Europe", "Africa", "Americas", "Oceania")
    ]
    + [
        ("test", {"region": region, "year": year})
        for region in ("Europe", "Asia", "Americas")
        for year in (2016, 2017)
    ]
)

ood_splits_indices_cache = {}

for split_name, split_filter in valid_ood_splits:
    dataset, indices = get_wilds_dataset(
        data_name,
        root_dir,
        split_name,
        batch_size=64,
        shuffle=False,
        num_workers=4,
        pre_filter=split_filter,
        return_indices=True,
    )
    ood_splits_indices_cache[(split_name, str(split_filter))] = indices

# Save cache
# Bug fix: pickle the index dictionary. The previous code passed
# `ood_cache_file` (the path string) to pickle.dump, so the file on disk held
# only its own name and none of the computed indices.
with open(ood_cache_file, "wb") as f:
    pickle.dump(ood_splits_indices_cache, f)
print("Features saved")

imported
loading model




Model loaded to device: cuda
Computing features for split: id_val
Computing features for split: id_test
Computing features for split: val
Computing features for split: test
ID splits features computed
Features saved


In [None]:
import pickle

# List the (split, filter) keys stored in the ID split-index cache.
# The file is opened through a context manager so the handle is closed as soon
# as the pickle has been read; the bare open(...) call used before never
# closed it.
with open("/h/321/apouget/suitability/results/split_indices/fmow_id.pkl", "rb") as f:
    x = pickle.load(f)
print(x.keys())

dict_keys([('id_val', "{'year': [2002, 2003, 2004, 2005, 2006]}"), ('id_val', "{'year': [2007, 2008, 2009]}"), ('id_val', "{'year': [2010]}"), ('id_val', "{'year': [2011]}"), ('id_val', "{'year': [2012]}"), ('id_val', "{'region': ['Asia']}"), ('id_val', "{'region': ['Europe']}"), ('id_val', "{'region': ['Americas']}"), ('id_test', "{'year': [2002, 2003, 2004, 2005, 2006]}"), ('id_test', "{'year': [2007, 2008, 2009]}"), ('id_test', "{'year': [2010]}"), ('id_test', "{'year': [2011]}"), ('id_test', "{'year': [2012]}"), ('id_test', "{'region': ['Asia']}"), ('id_test', "{'region': ['Europe']}"), ('id_test', "{'region': ['Americas']}")])


In [22]:
# Take the index array of the first cached (split, filter) key.
# next(iter(...)) replaces the for/break idiom and raises StopIteration on an
# empty cache; the loop version would skip its body and leave `indices` bound
# to whatever an earlier cell had assigned.
a, b = next(iter(x))
indices = x[(a, b)]
print(len(indices))

1420


In [None]:
import pickle

# Check that the cached "id_val" feature array can be sliced with the split
# indices selected above. The context manager closes the file handle, which
# the bare open(...) call used before never did.
with open("/h/321/apouget/suitability/results/features/fmow_ERM_last_0.pkl", "rb") as f:
    y = pickle.load(f)
print(y["id_val"][0][indices].shape)

(1420, 12)


In [99]:
# Dataset location and compute device
data_name = "fmow"
root_dir = "/mfsnic/projects/suitability/"
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Output file for the ID split indices
cache_file = f"suitability/results/split_indices/{data_name}_id.pkl"

# The same eight filters are applied to "id_val" first and then to "id_test".
valid_id_splits = [
    (split_name, split_filter)
    for split_name in ("id_val", "id_test")
    for split_filter in (
        {"year": [2002, 2003, 2004, 2005, 2006]},
        {"year": [2007, 2008, 2009]},
        {"year": [2010]},
        {"year": [2011]},
        {"year": [2012]},
        {"region": ["Asia"]},
        {"region": ["Europe"]},
        {"region": ["Americas"]},
    )
]

splits_indices_cache = {}


# Resolve each (split, filter) pair to the indices returned by the loader
for split_name, split_filter in valid_id_splits:
    print(f"Computing indices for split: {split_name} with filter: {split_filter}")
    dataset, indices = get_wilds_dataset(
        data_name,
        root_dir,
        split_name,
        batch_size=64,
        shuffle=False,
        num_workers=4,
        pre_filter=split_filter,
        return_indices=True,
    )
    # The filter dict is not hashable, so its string form is part of the key.
    splits_indices_cache[split_name, str(split_filter)] = indices

# Write the index cache to disk (the message below still says "Features")
with open(cache_file, "wb") as f:
    pickle.dump(splits_indices_cache, f)
print("Features saved")

Computing indices for split: id_val with filter: {'year': [2002, 2003, 2004, 2005, 2006]}
Computing indices for split: id_val with filter: {'year': [2007, 2008, 2009]}
Computing indices for split: id_val with filter: {'year': [2010]}
Computing indices for split: id_val with filter: {'year': [2011]}
Computing indices for split: id_val with filter: {'year': [2012]}
Computing indices for split: id_val with filter: {'region': ['Asia']}
Computing indices for split: id_val with filter: {'region': ['Europe']}
Computing indices for split: id_val with filter: {'region': ['Americas']}
Computing indices for split: id_test with filter: {'year': [2002, 2003, 2004, 2005, 2006]}
Computing indices for split: id_test with filter: {'year': [2007, 2008, 2009]}
Computing indices for split: id_test with filter: {'year': [2010]}
Computing indices for split: id_test with filter: {'year': [2011]}
Computing indices for split: id_test with filter: {'year': [2012]}
Computing indices for split: id_test with filte

In [100]:
# Dataset location and compute device
data_name = "fmow"
root_dir = "/mfsnic/projects/suitability/"
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Output file for the OOD split indices
cache_file = f"suitability/results/split_indices/{data_name}_ood.pkl"

# Single-key "val" filters hold one-element lists; the combined region/year
# filters and every "test" filter hold scalars.
valid_ood_splits = (
    [("val", {"year": [year]}) for year in (2013, 2014, 2015)]
    + [
        ("val", {"region": [region]})
        for region in ("Asia", "Europe", "Africa", "Americas", "Oceania")
    ]
    + [
        ("val", {"region": region, "year": year})
        for region in ("Europe", "Asia", "Americas")
        for year in (2013, 2014, 2015)
    ]
    + [("test", {"year": year}) for year in (2016, 2017)]
    + [
        ("test", {"region": region})
        for region in ("Asia", "Europe", "Africa", "Americas", "Oceania")
    ]
    + [
        ("test", {"region": region, "year": year})
        for region in ("Europe", "Asia", "Americas")
        for year in (2016, 2017)
    ]
)

splits_indices_cache = {}


# Resolve each (split, filter) pair to the indices returned by the loader
for split_name, split_filter in valid_ood_splits:
    print(f"Computing indices for split: {split_name} with filter: {split_filter}")
    dataset, indices = get_wilds_dataset(
        data_name,
        root_dir,
        split_name,
        batch_size=64,
        shuffle=False,
        num_workers=4,
        pre_filter=split_filter,
        return_indices=True,
    )
    # The filter dict is not hashable, so its string form is part of the key.
    splits_indices_cache[split_name, str(split_filter)] = indices

# Write the index cache to disk (the message below still says "Features")
with open(cache_file, "wb") as f:
    pickle.dump(splits_indices_cache, f)
print("Features saved")

Computing indices for split: val with filter: {'year': [2013]}
Computing indices for split: val with filter: {'year': [2014]}
Computing indices for split: val with filter: {'year': [2015]}
Computing indices for split: val with filter: {'region': ['Asia']}
Computing indices for split: val with filter: {'region': ['Europe']}
Computing indices for split: val with filter: {'region': ['Africa']}
Computing indices for split: val with filter: {'region': ['Americas']}
Computing indices for split: val with filter: {'region': ['Oceania']}
Computing indices for split: val with filter: {'region': 'Europe', 'year': 2013}
Computing indices for split: val with filter: {'region': 'Europe', 'year': 2014}
Computing indices for split: val with filter: {'region': 'Europe', 'year': 2015}
Computing indices for split: val with filter: {'region': 'Asia', 'year': 2013}
Computing indices for split: val with filter: {'region': 'Asia', 'year': 2014}
Computing indices for split: val with filter: {'region': 'Asia',

In [1]:
import numpy as np

def calculate_ece_and_bias(probs, correct, n_bins=10):
    """
    Calculate the Expected Calibration Error (ECE) and Calibration Bias (CB).

    Samples are grouped into ``n_bins`` equal-width bins over [0, 1] by their
    predicted probability. ECE is the sample-weighted mean of
    ``|bin accuracy - bin confidence|``; CB is the overall mean predicted
    probability minus the overall mean correctness.

    Args:
        probs (array-like): Predicted probabilities, shape (n_samples,).
        correct (array-like): Binary correctness labels (0 or 1), shape (n_samples,).
        n_bins (int): Number of bins to use for calibration calculation.

    Returns:
        tuple: (ECE, CB), where:
            - ECE (float): Expected Calibration Error.
            - CB (float): Calibration Bias (positive = overestimation, negative = underestimation).
    """
    # Accept any array-like input; boolean-mask indexing below needs ndarrays.
    probs = np.asarray(probs)
    correct = np.asarray(correct)

    bins = np.linspace(0, 1, n_bins + 1)
    ece = 0

    # np.digitize uses half-open intervals [edge_i, edge_{i+1}), so a
    # probability of exactly 1.0 is assigned index n_bins (one past the last
    # bin) and used to be silently left out of the ECE sum. Clipping folds it
    # into the last bin (and any value below 0 into the first).
    bin_indices = np.clip(np.digitize(probs, bins) - 1, 0, n_bins - 1)

    for i in range(n_bins):
        bin_mask = bin_indices == i
        bin_count = np.sum(bin_mask)
        if bin_count == 0:  # Skip empty bins
            continue

        bin_accuracy = np.mean(correct[bin_mask])
        bin_confidence = np.mean(probs[bin_mask])

        # Each bin contributes in proportion to the share of samples it holds.
        ece += (bin_count / len(correct)) * np.abs(bin_accuracy - bin_confidence)

    # Computed over all samples at once rather than bin by bin.
    cb = np.mean(probs) - np.mean(correct)

    return ece, cb

# ID split subset evaluations

In [2]:
import pickle
import random

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

from suitability.filter.suitability_efficient import SuitabilityFilter
from suitability.filter.tests import non_inferiority_ttest

# ID split subset evaluation.
#
# For every seed and every cached ID user split, the user subset is removed
# from its parent split, the remainder is pooled with the other ID split, and
# the pool is shuffled into equal-size folds. Every ordered pair of distinct
# folds (i, j) is then used as ("reg" data, test data) for a SuitabilityFilter,
# which is tested against the user subset at several margins. One CSV of
# results is written per seed.
#
# Relies on `calculate_ece_and_bias` from the previous cell.

# Set seeds for reproducibility
# The global NumPy RNG drives the fold shuffling below and is not re-seeded
# inside the seed loop, so fold assignments differ from one seed to the next.
random.seed(32)
np.random.seed(32)

# Configuration
data_name = "fmow"
root_dir = "/mfsnic/projects/suitability/"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Select which cached model outputs to evaluate (no model is loaded in this
# cell; only precomputed feature files are read).
algorithm = "ERM"
model_type = "last"
seeds = [0, 1, 2]

for seed in seeds:
    # Load the features
    # Each entry is unpacked further down as a (features, correctness) pair.
    feature_cache_file = (
        f"suitability/results/features/{data_name}_{algorithm}_{model_type}_{seed}.pkl"
    )
    with open(feature_cache_file, "rb") as f:
        full_feature_dict = pickle.load(f)
    id_feature_dict = {}
    id_feature_dict["id_val"] = full_feature_dict["id_val"]
    id_feature_dict["id_test"] = full_feature_dict["id_test"]

    # Load the split indices
    # Keys are (split_name, str(filter)) tuples; the path does not depend on
    # the seed, so the same file is re-read on every iteration.
    split_cache_file = f"suitability/results/split_indices/{data_name}_id.pkl"
    with open(split_cache_file, "rb") as f:
        id_split_dict = pickle.load(f)

    # Define suitability filter and experiment parameters
    classifiers = [
        "logistic_regression"
    ]  # "logistic_regression", "svm", "random_forest", "gradient_boosting", "mlp", "decision_tree"]
    margins = [0, 0.005, 0.01, 0.05]
    normalize = True
    calibrated = True
    # Result rows are reset per seed, so each CSV holds one seed only.
    sf_results = []
    # Only filled by the commented-out direct-testing block further down.
    direct_testing_results = []
    # Column indices of the feature matrix handed to the filter; only the full
    # 12-feature set is active, the other subsets are kept commented out.
    feature_subsets = [
        # [0],
        # [1],
        # [2],
        # [3],
        # [4],
        # [5],
        # [6],
        # [7],
        # [8],
        # [9],
        # [10],
        # [11],
        # [4, 11],
        # [4, 11, 8],
        # [4, 11, 8, 6],
        # [4, 11, 8, 6, 2],
        # [4, 11, 8, 6, 2, 1],
        # [4, 11, 8, 6, 2, 1, 0],
        # [4, 11, 8, 6, 2, 1, 0, 7],
        # [4, 11, 8, 6, 2, 1, 0, 7, 3],
        # [4, 11, 8, 6, 2, 1, 0, 7, 3, 10],
        # [4, 11, 8, 6, 2, 1, 0, 7, 3, 10, 5],
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
    ]
    num_fold_arr = [15]

    # Main loop
    for user_split_name, user_filter in tqdm(id_split_dict.keys()):
        print(f"Evaluating user split: {user_split_name} with filter {user_filter}")

        # Get user split indices
        user_split_indices = id_split_dict[(user_split_name, user_filter)]

        # Get user split features and correctness
        all_features, all_corr = id_feature_dict[user_split_name]
        user_features = all_features[user_split_indices]
        user_corr = all_corr[user_split_indices]
        user_size = len(user_corr)
        user_acc = np.mean(user_corr)

        # Re-partition remaining data into folds
        # Source pool = parent split minus the user subset, plus the whole of
        # the other ID split.
        remaining_indices = np.setdiff1d(np.arange(len(all_corr)), user_split_indices)
        remaining_features = all_features[remaining_indices]
        remaining_corr = all_corr[remaining_indices]
        if user_split_name == "id_val":
            other_split_name = "id_test"
        elif user_split_name == "id_test":
            other_split_name = "id_val"
        else:
            raise ValueError(f"Invalid split name: {user_split_name}")
        additional_features, additional_corr = id_feature_dict[other_split_name]
        source_features = np.concatenate([remaining_features, additional_features], axis=0)
        source_corr = np.concatenate([remaining_corr, additional_corr], axis=0)

        for num_folds in num_fold_arr:
            # Integer division: when the pool size is not a multiple of
            # num_folds, the trailing shuffled samples belong to no fold.
            source_fold_size = len(source_corr) // num_folds
            indices = np.arange(len(source_corr))
            np.random.shuffle(indices)
            fold_indices = [
                indices[i * source_fold_size : (i + 1) * source_fold_size]
                for i in range(num_folds)
            ]

            # Outer fold i supplies the "reg" data passed as the filter's
            # third/fourth arguments.
            # NOTE(review): what "reg" stands for is not visible here -- confirm
            # against SuitabilityFilter.
            for i, reg_indices in enumerate(fold_indices):
                reg_features = source_features[reg_indices]
                reg_corr = source_corr[reg_indices]
                reg_size = len(reg_corr)
                reg_acc = np.mean(reg_corr)

                # Inner fold j supplies the test data; i == j is skipped so the
                # two roles never share a fold.
                for j, test_indices in enumerate(fold_indices):
                    if i == j:
                        continue
                    test_features = source_features[test_indices]
                    test_corr = source_corr[test_indices]
                    test_size = len(test_corr)
                    test_acc = np.mean(test_corr)

                    for classifier in classifiers:
                        for feature_subset in feature_subsets:
                            # The filter is built and trained once per
                            # (folds, classifier, feature subset) and reused
                            # for every margin below.
                            suitability_filter = SuitabilityFilter(
                                test_features,
                                test_corr,
                                reg_features,
                                reg_corr,
                                device,
                                normalize=normalize,
                                feature_subset=feature_subset,
                            )
                            suitability_filter.train_classifier(
                                calibrated=calibrated, classifier=classifier
                            )

                            for margin in margins:
                                # Test suitability filter
                                sf_test = suitability_filter.suitability_test(
                                    user_features=user_features, margin=margin, return_predictions=True
                                )
                                p_value = sf_test["p_value"]
                                # Ground truth: user accuracy is no more than
                                # `margin` below the test fold's accuracy.
                                ground_truth = user_acc >= test_acc - margin

                                pred_user = sf_test["user_predictions"]
                                pred_test = sf_test["test_predictions"]

                                # Calculate ECE and CB
                                ece_user, cb_user = calculate_ece_and_bias(
                                    pred_user, user_corr
                                )
                                ece_test, cb_test = calculate_ece_and_bias(
                                    pred_test, test_corr
                                )

                                # One result row per
                                # (user split, reg fold, test fold, classifier,
                                # feature subset, margin).
                                sf_results.append(
                                    {
                                        "data_name": data_name,
                                        "algorithm": algorithm,
                                        "seed": seed,
                                        "model_type": model_type,
                                        "normalize": normalize,
                                        "calibrated": calibrated,
                                        "margin": margin,
                                        "reg_fold": i,
                                        "reg_size": reg_size,
                                        "reg_acc": reg_acc,
                                        "test_fold": j,
                                        "test_size": test_size,
                                        "test_acc": test_acc,
                                        "user_split": user_split_name,
                                        "user_filter": user_filter,
                                        "user_size": user_size,
                                        "user_acc": user_acc,
                                        "p_value": p_value,
                                        "ground_truth": ground_truth,
                                        "classifier": classifier,
                                        "feature_subset": feature_subset,
                                        "acc_diff": user_acc - test_acc,
                                        "acc_diff_adjusted": user_acc + margin - test_acc,
                                        "ece_user": ece_user,
                                        "cb_user": cb_user,
                                        "ece_test": ece_test,
                                        "cb_test": cb_test,
                                        "mean_pred_user": np.mean(pred_user),
                                        "mean_pred_test": np.mean(pred_test),
                                        "std_pred_user": np.std(pred_user),
                                        "std_pred_test": np.std(pred_test),
                                    }
                                )

                                # Disabled: direct non-inferiority test on single
                                # features (the only user of the imported
                                # `non_inferiority_ttest` and of
                                # `direct_testing_results`).
                                # Run non-inferiority test on features directly
                                # if (
                                #     len(feature_subset) == 1
                                #     and margin == 0
                                #     and classifier == "logistic_regression"
                                #     and (j == 0 or (i == 0 and j == 1))
                                # ):
                                #     test_feature_subset = test_features[:, feature_subset].flatten()
                                #     user_feature_subset = user_features[:, feature_subset].flatten()
                                #     test_1 = non_inferiority_ttest(
                                #         test_feature_subset,
                                #         user_feature_subset,
                                #         increase_good=True,
                                #     )
                                #     test_2 = non_inferiority_ttest(
                                #         test_feature_subset,
                                #         user_feature_subset,
                                #         increase_good=False,
                                #     )
                                #     direct_testing_results.append(
                                #         {
                                #             "data_name": data_name,
                                #             "algorithm": algorithm,
                                #             "seed": seed,
                                #             "model_type": model_type,
                                #             "test_fold": j,
                                #             "test_size": test_size,
                                #             "test_acc": test_acc,
                                #             "user_split": user_split_name,
                                #             "user_filter": user_filter,
                                #             "user_size": user_size,
                                #             "user_acc": user_acc,
                                #             "p_value_increase_good": test_1["p_value"],
                                #             "p_value_decrease_good": test_2["p_value"],
                                #             "ground_truth": ground_truth,
                                #             "feature_subset": feature_subset,
                                #             "acc_diff": user_acc - test_acc,
                                #         }
                                #     )


    # Save results
    # NOTE(review): the output directory is named "irm" although `algorithm`
    # is "ERM" in this cell -- confirm the intended location.
    sf_evals = pd.DataFrame(sf_results)
    sf_evals.to_csv(
        f"suitability/results/sf_evals/irm/fmow_sf_results_id_calibration_{algorithm}_{model_type}_{seed}_FINAL.csv",
        index=False,
    )
    # direct_testing_evals = pd.DataFrame(direct_testing_results)
    # direct_testing_evals.to_csv(
    #     f"suitability/results/sf_evals/irm/fmow_direct_testing_results_id_calibration_{algorithm}_{model_type}_{seed}_NEW.csv",
    #     index=False,
    # )


  0%|                                                                                                                                     | 0/16 [00:00<?, ?it/s]

Evaluating user split: id_val with filter {'year': [2002, 2003, 2004, 2005, 2006]}


  6%|███████▊                                                                                                                     | 1/16 [00:12<03:08, 12.57s/it]

Evaluating user split: id_val with filter {'year': [2007, 2008, 2009]}


 12%|███████████████▋                                                                                                             | 2/16 [00:25<02:56, 12.60s/it]

Evaluating user split: id_val with filter {'year': [2010]}


 19%|███████████████████████▍                                                                                                     | 3/16 [00:38<02:45, 12.73s/it]

Evaluating user split: id_val with filter {'year': [2011]}


 25%|███████████████████████████████▎                                                                                             | 4/16 [00:51<02:36, 13.07s/it]

Evaluating user split: id_val with filter {'year': [2012]}


 31%|███████████████████████████████████████                                                                                      | 5/16 [01:04<02:23, 13.06s/it]

Evaluating user split: id_val with filter {'region': ['Asia']}


 38%|██████████████████████████████████████████████▉                                                                              | 6/16 [01:17<02:10, 13.02s/it]

Evaluating user split: id_val with filter {'region': ['Europe']}


 44%|██████████████████████████████████████████████████████▋                                                                      | 7/16 [01:31<01:58, 13.14s/it]

Evaluating user split: id_val with filter {'region': ['Americas']}


 50%|██████████████████████████████████████████████████████████████▌                                                              | 8/16 [01:44<01:44, 13.09s/it]

Evaluating user split: id_test with filter {'year': [2002, 2003, 2004, 2005, 2006]}


 56%|██████████████████████████████████████████████████████████████████████▎                                                      | 9/16 [01:57<01:32, 13.16s/it]

Evaluating user split: id_test with filter {'year': [2007, 2008, 2009]}


 62%|█████████████████████████████████████████████████████████████████████████████▌                                              | 10/16 [02:10<01:18, 13.02s/it]

Evaluating user split: id_test with filter {'year': [2010]}


 69%|█████████████████████████████████████████████████████████████████████████████████████▎                                      | 11/16 [02:22<01:04, 12.99s/it]

Evaluating user split: id_test with filter {'year': [2011]}


 75%|█████████████████████████████████████████████████████████████████████████████████████████████                               | 12/16 [02:35<00:51, 12.96s/it]

Evaluating user split: id_test with filter {'year': [2012]}


 81%|████████████████████████████████████████████████████████████████████████████████████████████████████▊                       | 13/16 [02:48<00:38, 12.98s/it]

Evaluating user split: id_test with filter {'region': ['Asia']}


 88%|████████████████████████████████████████████████████████████████████████████████████████████████████████████▌               | 14/16 [03:02<00:26, 13.09s/it]

Evaluating user split: id_test with filter {'region': ['Europe']}


 94%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎       | 15/16 [03:15<00:13, 13.15s/it]

Evaluating user split: id_test with filter {'region': ['Americas']}


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [03:28<00:00, 13.03s/it]
  0%|                                                                                                                                     | 0/16 [00:00<?, ?it/s]

Evaluating user split: id_val with filter {'year': [2002, 2003, 2004, 2005, 2006]}


  6%|███████▊                                                                                                                     | 1/16 [00:12<03:13, 12.90s/it]

Evaluating user split: id_val with filter {'year': [2007, 2008, 2009]}


 12%|███████████████▋                                                                                                             | 2/16 [00:25<02:59, 12.83s/it]

Evaluating user split: id_val with filter {'year': [2010]}


 19%|███████████████████████▍                                                                                                     | 3/16 [00:39<02:52, 13.29s/it]

Evaluating user split: id_val with filter {'year': [2011]}


 25%|███████████████████████████████▎                                                                                             | 4/16 [00:52<02:37, 13.15s/it]

Evaluating user split: id_val with filter {'year': [2012]}


 31%|███████████████████████████████████████                                                                                      | 5/16 [01:05<02:24, 13.11s/it]

Evaluating user split: id_val with filter {'region': ['Asia']}


 38%|██████████████████████████████████████████████▉                                                                              | 6/16 [01:18<02:10, 13.06s/it]

Evaluating user split: id_val with filter {'region': ['Europe']}


 44%|██████████████████████████████████████████████████████▋                                                                      | 7/16 [01:32<02:00, 13.42s/it]

Evaluating user split: id_val with filter {'region': ['Americas']}


 50%|██████████████████████████████████████████████████████████████▌                                                              | 8/16 [01:45<01:46, 13.28s/it]

Evaluating user split: id_test with filter {'year': [2002, 2003, 2004, 2005, 2006]}


 56%|██████████████████████████████████████████████████████████████████████▎                                                      | 9/16 [01:58<01:31, 13.12s/it]

Evaluating user split: id_test with filter {'year': [2007, 2008, 2009]}


 62%|█████████████████████████████████████████████████████████████████████████████▌                                              | 10/16 [02:11<01:18, 13.05s/it]

Evaluating user split: id_test with filter {'year': [2010]}


 69%|█████████████████████████████████████████████████████████████████████████████████████▎                                      | 11/16 [02:24<01:05, 13.03s/it]

Evaluating user split: id_test with filter {'year': [2011]}


 75%|█████████████████████████████████████████████████████████████████████████████████████████████                               | 12/16 [02:37<00:52, 13.24s/it]

Evaluating user split: id_test with filter {'year': [2012]}


 81%|████████████████████████████████████████████████████████████████████████████████████████████████████▊                       | 13/16 [02:51<00:39, 13.21s/it]

Evaluating user split: id_test with filter {'region': ['Asia']}


 88%|████████████████████████████████████████████████████████████████████████████████████████████████████████████▌               | 14/16 [03:03<00:26, 13.11s/it]

Evaluating user split: id_test with filter {'region': ['Europe']}


 94%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎       | 15/16 [03:17<00:13, 13.19s/it]

Evaluating user split: id_test with filter {'region': ['Americas']}


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [03:30<00:00, 13.15s/it]
  0%|                                                                                                                                     | 0/16 [00:00<?, ?it/s]

Evaluating user split: id_val with filter {'year': [2002, 2003, 2004, 2005, 2006]}


  6%|███████▊                                                                                                                     | 1/16 [00:13<03:19, 13.32s/it]

Evaluating user split: id_val with filter {'year': [2007, 2008, 2009]}


 12%|███████████████▋                                                                                                             | 2/16 [00:26<03:01, 12.95s/it]

Evaluating user split: id_val with filter {'year': [2010]}


 19%|███████████████████████▍                                                                                                     | 3/16 [00:39<02:49, 13.03s/it]

Evaluating user split: id_val with filter {'year': [2011]}


 25%|███████████████████████████████▎                                                                                             | 4/16 [00:52<02:36, 13.03s/it]

Evaluating user split: id_val with filter {'year': [2012]}


 31%|███████████████████████████████████████                                                                                      | 5/16 [01:05<02:23, 13.07s/it]

Evaluating user split: id_val with filter {'region': ['Asia']}


 38%|██████████████████████████████████████████████▉                                                                              | 6/16 [01:18<02:11, 13.17s/it]

Evaluating user split: id_val with filter {'region': ['Europe']}


 44%|██████████████████████████████████████████████████████▋                                                                      | 7/16 [01:32<01:59, 13.26s/it]

Evaluating user split: id_val with filter {'region': ['Americas']}


 50%|██████████████████████████████████████████████████████████████▌                                                              | 8/16 [01:45<01:45, 13.19s/it]

Evaluating user split: id_test with filter {'year': [2002, 2003, 2004, 2005, 2006]}


 56%|██████████████████████████████████████████████████████████████████████▎                                                      | 9/16 [01:57<01:31, 13.03s/it]

Evaluating user split: id_test with filter {'year': [2007, 2008, 2009]}


 62%|█████████████████████████████████████████████████████████████████████████████▌                                              | 10/16 [02:10<01:17, 12.94s/it]

Evaluating user split: id_test with filter {'year': [2010]}


 69%|█████████████████████████████████████████████████████████████████████████████████████▎                                      | 11/16 [02:23<01:05, 13.01s/it]

Evaluating user split: id_test with filter {'year': [2011]}


 75%|█████████████████████████████████████████████████████████████████████████████████████████████                               | 12/16 [02:36<00:52, 13.01s/it]

Evaluating user split: id_test with filter {'year': [2012]}


 81%|████████████████████████████████████████████████████████████████████████████████████████████████████▊                       | 13/16 [02:49<00:39, 13.01s/it]

Evaluating user split: id_test with filter {'region': ['Asia']}


 88%|████████████████████████████████████████████████████████████████████████████████████████████████████████████▌               | 14/16 [03:02<00:26, 13.01s/it]

Evaluating user split: id_test with filter {'region': ['Europe']}


 94%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎       | 15/16 [03:16<00:13, 13.27s/it]

Evaluating user split: id_test with filter {'region': ['Americas']}


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [03:29<00:00, 13.10s/it]


# OOD SPLIT SUBSET EVALS

In [None]:
import pickle
import random

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

from suitability.filter.suitability_efficient import SuitabilityFilter
from suitability.filter.tests import non_inferiority_ttest

# OOD split subset evaluation.
#
# For every seed, the cached features of the in-distribution "id_val" and
# "id_test" splits are pooled into one source set. For every OOD user split,
# the source set is shuffled and cut into `num_folds` equally sized folds;
# every ordered pair (i, j), i != j, of folds provides the "reg" data (fold i)
# and the "test" data (fold j) handed to SuitabilityFilter. The filter is then
# queried on the user split for each margin, once with the plain margin and
# once with a calibration-adjusted margin, and one result row is recorded per
# (user split, fold pair, classifier, feature subset, margin).

# Set seeds for reproducibility
random.seed(32)
np.random.seed(32)

# Configuration
data_name = "fmow"
root_dir = "/mfsnic/projects/suitability/"  # not referenced further down in this cell
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Identifiers selecting which cached feature files are evaluated (one run per seed)
algorithm = "ERM"
model_type = "last"
seeds = [0, 1, 2]

# Number of samples drawn from the test fold and from the user split for
# calibration error estimation in the margin adjustment step
n_cal_samples = 50

# Load the split indices once; the file name does not depend on the seed.
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted cache files.
split_cache_file = f"suitability/results/split_indices/{data_name}_ood.pkl"
with open(split_cache_file, "rb") as f:
    ood_split_dict = pickle.load(f)

for seed in seeds:
    # Load the features
    # NOTE(review): same pickle caveat as above -- the cache must come from a trusted source.
    feature_cache_file = (
        f"suitability/results/features/{data_name}_{algorithm}_{model_type}_{seed}.pkl"
    )
    with open(feature_cache_file, "rb") as f:
        full_feature_dict = pickle.load(f)

    # Define suitability filter and experiment parameters
    classifiers = [
        "logistic_regression"
    ]  # "logistic_regression", "svm", "random_forest", "gradient_boosting", "mlp", "decision_tree"]
    margins = [0, 0.005, 0.01, 0.05] #  0.005, 0.01, 0.05
    normalize = True
    calibrated = True
    sf_results = []  # one dict per evaluated configuration; reset for every seed
    direct_testing_results = []  # never appended to in this cell
    feature_subsets = [
        # [0],
        # [1],
        # [2],
        # [3],
        # [4],
        # [5],
        # [6],
        # [7],
        # [8],
        # [9],
        # [10],
        # [11],
        # [4, 11],
        # [4, 11, 8],
        # [4, 11, 8, 6],
        # [4, 11, 8, 6, 2],
        # [4, 11, 8, 6, 2, 1],
        # [4, 11, 8, 6, 2, 1, 0],
        # [4, 11, 8, 6, 2, 1, 0, 7],
        # [4, 11, 8, 6, 2, 1, 0, 7, 3],
        # [4, 11, 8, 6, 2, 1, 0, 7, 3, 10],
        # [4, 11, 8, 6, 2, 1, 0, 7, 3, 10, 5],
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
    ]
    num_fold_arr = [15]

    # Pool both in-distribution splits into a single source set
    id_features_val, id_corr_val = full_feature_dict["id_val"]
    id_features_test, id_corr_test = full_feature_dict["id_test"]
    source_features = np.concatenate([id_features_val, id_features_test], axis=0)
    source_corr = np.concatenate([id_corr_val, id_corr_test], axis=0)


    # Main loop
    for user_split_name, user_filter in tqdm(ood_split_dict.keys()):
        print(f"Evaluating user split: {user_split_name} with filter {user_filter}")

        # Get user split indices
        user_split_indices = ood_split_dict[(user_split_name, user_filter)]

        # Get user split features and correctness
        all_features, all_corr = full_feature_dict[user_split_name]
        user_features = all_features[user_split_indices]
        user_corr = all_corr[user_split_indices]
        user_size = len(user_corr)
        user_acc = np.mean(user_corr)

        for num_folds in num_fold_arr:
            # Shuffle the source set and cut it into equally sized folds; the
            # len(source_corr) % num_folds trailing samples are left out.
            source_fold_size = len(source_corr) // num_folds
            indices = np.arange(len(source_corr))
            np.random.shuffle(indices)
            fold_indices = [
                indices[i * source_fold_size : (i + 1) * source_fold_size]
                for i in range(num_folds)
            ]

            for i, reg_indices in enumerate(fold_indices):
                reg_features = source_features[reg_indices]
                reg_corr = source_corr[reg_indices]
                reg_size = len(reg_corr)
                reg_acc = np.mean(reg_corr)

                for j, test_indices in enumerate(fold_indices):
                    # The reg fold and the test fold must be different folds
                    if i == j:
                        continue
                    test_features = source_features[test_indices]
                    test_corr = source_corr[test_indices]
                    test_size = len(test_corr)
                    test_acc = np.mean(test_corr)

                    for classifier in classifiers:
                        for feature_subset in feature_subsets:
                            suitability_filter = SuitabilityFilter(
                                test_features,
                                test_corr,
                                reg_features,
                                reg_corr,
                                device,
                                normalize=normalize,
                                feature_subset=feature_subset,
                            )
                            suitability_filter.train_classifier(
                                calibrated=calibrated, classifier=classifier
                            )

                            for margin in margins:
                                # Test suitability filter
                                sf_test = suitability_filter.suitability_test(
                                    user_features=user_features, margin=margin, return_predictions=True
                                )
                                p_value = sf_test["p_value"]
                                # The user split counts as suitable when its accuracy is
                                # at most `margin` below the test fold accuracy.
                                ground_truth = user_acc >= test_acc - margin

                                pred_user = sf_test["user_predictions"]
                                pred_test = sf_test["test_predictions"]


                                # MARGIN ADJUSTMENT PART

                                # Ensure n_cal_samples is not larger than available data
                                n_cal_test = min(n_cal_samples, test_size)
                                n_cal_user = min(n_cal_samples, user_size)

                                # Fall back to the unadjusted margin so that m_prime is
                                # always defined, and never carried over from an earlier
                                # iteration, when either calibration sample would be empty.
                                m_prime = margin

                                if n_cal_test > 0 and n_cal_user > 0:
                                    # Sample indices for calibration estimation
                                    cal_test_indices = np.random.choice(test_size, n_cal_test, replace=False)
                                    cal_user_indices = np.random.choice(user_size, n_cal_user, replace=False)

                                    # Get corresponding predictions and ground truth correctness
                                    pred_test_cal = pred_test[cal_test_indices]
                                    test_corr_cal = test_corr[cal_test_indices] # positions within the test fold

                                    pred_user_cal = pred_user[cal_user_indices]
                                    user_corr_cal = user_corr[cal_user_indices] # positions within the user split

                                    # Calculate calibration errors (Delta): mean prediction
                                    # minus mean observed correctness on the sampled points
                                    delta_test = np.mean(pred_test_cal) - np.mean(test_corr_cal)
                                    delta_u = np.mean(pred_user_cal) - np.mean(user_corr_cal)

                                    # Calculate adjusted margin
                                    m_prime = margin + delta_test - delta_u

                                sf_test_adjusted = suitability_filter.suitability_test(
                                    user_features=user_features, margin=m_prime, return_predictions=True
                                )

                                p_value_ma = sf_test_adjusted["p_value"]

                                # END MARGIN ADJUSTMENT PART

                                sf_results.append(
                                    {
                                        "data_name": data_name,
                                        "algorithm": algorithm,
                                        "seed": seed,
                                        "model_type": model_type,
                                        "normalize": normalize,
                                        "calibrated": calibrated,
                                        "margin": margin,
                                        "reg_fold": i,
                                        "reg_size": reg_size,
                                        "reg_acc": reg_acc,
                                        "test_fold": j,
                                        "test_size": test_size,
                                        "test_acc": test_acc,
                                        "user_split": user_split_name,
                                        "user_filter": user_filter,
                                        "user_size": user_size,
                                        "user_acc": user_acc,
                                        "p_value": p_value,
                                        "ground_truth": ground_truth,
                                        "classifier": classifier,
                                        "feature_subset": feature_subset,
                                        "acc_diff": user_acc - test_acc,
                                        "acc_diff_adjusted": user_acc + margin - test_acc,
                                        "mean_pred_user": np.mean(pred_user),
                                        "mean_pred_test": np.mean(pred_test),
                                        "std_pred_user": np.std(pred_user),
                                        "std_pred_test": np.std(pred_test),
                                        "p_value_ma": p_value_ma,
                                    }
                                )


    # Save results (one CSV per seed)
    # NOTE(review): the "erm" directory is hard-coded while the file name uses
    # `algorithm`; confirm the directory if `algorithm` is changed.
    sf_evals = pd.DataFrame(sf_results)
    sf_evals.to_csv(
        f"suitability/results/sf_evals/erm/fmow_sf_results_ood_calibration_{algorithm}_{model_type}_{seed}_REBUTTAL.csv",
        index=False,
    )


  0%|                                                                                                                                     | 0/30 [00:00<?, ?it/s]

Evaluating user split: val with filter {'year': [2013]}
