## Evaluating the suitability filter on FMoW

In [1]:
import importlib
import random
from itertools import chain, combinations

import numpy as np
import pandas as pd
import torch

from suitability.datasets.wilds import get_wilds_dataset, get_wilds_model
from suitability.filter import suitability_efficient

importlib.reload(suitability_efficient)

from suitability.filter.suitability_efficient import SuitabilityFilter, get_sf_features

# Set seeds for reproducibility: seeds Python's `random` and NumPy's global RNG.
# NOTE(review): torch's RNG is not seeded here — confirm nothing below depends on it.
random.seed(32)
np.random.seed(32)

### Define & evaluate all possible splits

In [None]:
# Each entry is a (WILDS split name, metadata pre-filter) pair. The lists are
# generated instead of spelled out; ordering matches the hand-written version:
# per-year entries first, then per-region entries, then region x year pairs.

# In-distribution validation data: years 2002-2012, then every region.
id_val_splits = [("id_val", {"year": year}) for year in range(2002, 2013)] + [
    ("id_val", {"region": region})
    for region in ("Asia", "Europe", "Africa", "Americas", "Oceania")
]

# In-distribution test data: same years and regions as above.
id_test_splits = [("id_test", {"year": year}) for year in range(2002, 2013)] + [
    ("id_test", {"region": region})
    for region in ("Asia", "Europe", "Africa", "Americas", "Oceania")
]

# Out-of-distribution validation data: years 2013-2015, every region, and
# region x year pairs for Europe, Asia and the Americas (region-major order).
ood_val_splits = (
    [("val", {"year": year}) for year in (2013, 2014, 2015)]
    + [
        ("val", {"region": region})
        for region in ("Asia", "Europe", "Africa", "Americas", "Oceania")
    ]
    + [
        ("val", {"region": region, "year": year})
        for region in ("Europe", "Asia", "Americas")
        for year in (2013, 2014, 2015)
    ]
)

# Out-of-distribution test data: years 2016-2017 with the same layout.
ood_test_splits = (
    [("test", {"year": year}) for year in (2016, 2017)]
    + [
        ("test", {"region": region})
        for region in ("Asia", "Europe", "Africa", "Americas", "Oceania")
    ]
    + [
        ("test", {"region": region, "year": year})
        for region in ("Europe", "Asia", "Americas")
        for year in (2016, 2017)
    ]
)

In [None]:
# Evaluate the ERM model's accuracy on every (split, pre-filter) pair defined
# above and write one CSV row per pair.
data_name = "fmow"
root_dir = "/mfsnic/u/apouget/"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the model once and switch it to evaluation mode.
model = get_wilds_model(data_name, root_dir, algorithm="ERM")
model = model.to(device)
model.eval()

all_splits = id_val_splits + id_test_splits + ood_val_splits + ood_test_splits

# Collect plain records and build the DataFrame once at the end. The previous
# row-by-row growth relied on the private `DataFrame._append` (the public
# `append` was removed in pandas 2.0) and re-copied the frame on every iteration.
result_columns = ["split", "year", "region", "num_samples", "accuracy"]
result_records = []

for split, pre_filter in all_splits:
    data = get_wilds_dataset(
        data_name,
        root_dir,
        split,
        batch_size=64,
        shuffle=False,
        num_workers=4,
        pre_filter=pre_filter,
    )
    # The same loader is passed for both data arguments; only the per-sample
    # correctness returned by `get_correct` is used here.
    suitability_filter = SuitabilityFilter(model, data, data, device)
    corr = suitability_filter.get_correct(data)

    num_samples = len(data.dataset)
    accuracy = np.mean(corr)
    # A filter that does not constrain a dimension is reported as "ALL".
    year = pre_filter.get("year", "ALL")
    region = pre_filter.get("region", "ALL")
    result_records.append(
        {
            "split": split,
            "year": year,
            "region": region,
            "num_samples": num_samples,
            "accuracy": accuracy,
        }
    )

# Passing `columns` keeps the column order (and the header) even with no records.
results = pd.DataFrame(result_records, columns=result_columns)
results.to_csv("suitability/results/data_splits/fmow_ERM_0_last.csv", index=False)

### Evaluate suitability filter

In [None]:
# ID subsets: the same year buckets and regions are taken from both ID splits.
# Filter values are lists here. The inner tuple of dict literals is re-evaluated
# for each split, so every entry gets its own dict object.
valid_id_splits = [
    (split, subset_filter)
    for split in ("id_val", "id_test")
    for subset_filter in (
        {"year": [2002, 2003, 2004, 2005, 2006]},
        {"year": [2007, 2008, 2009]},
        {"year": [2010]},
        {"year": [2011]},
        {"year": [2012]},
        {"region": ["Asia"]},
        {"region": ["Europe"]},
        {"region": ["Americas"]},
    )
]

# OOD subsets. The single-dimension "val" filters use one-element lists, while
# the region x year pairs and all "test" filters use scalar values.
valid_ood_splits = (
    [("val", {"year": [year]}) for year in (2013, 2014, 2015)]
    + [
        ("val", {"region": [region]})
        for region in ("Asia", "Europe", "Africa", "Americas", "Oceania")
    ]
    + [
        ("val", {"region": region, "year": year})
        for region in ("Europe", "Asia", "Americas")
        for year in (2013, 2014, 2015)
    ]
    + [("test", {"year": year}) for year in (2016, 2017)]
    + [
        ("test", {"region": region})
        for region in ("Asia", "Europe", "Africa", "Americas", "Oceania")
    ]
    + [
        ("test", {"region": region, "year": year})
        for region in ("Europe", "Asia", "Americas")
        for year in (2016, 2017)
    ]
)

print(
    f"Number of valid id splits: {len(valid_id_splits)}, number of valid ood splits: {len(valid_ood_splits)}"
)

Number of valid id splits: 16, number of valid ood splits: 30


## Evaluate suitability filter

In [76]:
import os
import pickle
import random

import numpy as np
import pandas as pd
import torch

from suitability.datasets.wilds import get_wilds_dataset, get_wilds_model
from suitability.filter.suitability_efficient import SuitabilityFilter, get_sf_features

def load_cached_features(cache_path):
    """Load a pickled feature cache from ``cache_path``.

    Returns the unpickled object, or ``None`` (after printing a warning) when
    the file does not exist. Returning ``None`` keeps the target variable
    bound on every run, so later cells cannot silently reuse a stale value
    left in the kernel by an earlier execution.
    """
    if not os.path.exists(cache_path):
        print(f"Warning: feature cache not found: {cache_path}")
        return None
    # NOTE: pickle can execute arbitrary code; only load caches written by this project.
    with open(cache_path, "rb") as f:
        return pickle.load(f)


# Combined cache: features for all ID data in a single structure.
cache_file_id_combined = "suitability/results/features/fmow_ERM_last_0_combined_id.pkl"
features_id_combined = load_cached_features(cache_file_id_combined)

# Individual cache: features keyed by (split, filter-string).
cache_file_id_individual = "suitability/results/features/fmow_ERM_last_0_id.pkl"
features_id_individual = load_cached_features(cache_file_id_individual)

In [None]:
# Inspect the feature vector stored at position 1 of the id_test / Asia entry in the per-subset cache.
features_id_individual[("id_test", "{'region': ['Asia']}")]["features"][1]

array([ 9.99883533e-01,  1.26985088e-01,  1.36794173e-03, -1.35100746e+01,
        4.84576941e+00,  4.60921860e+00,  1.01284866e+01,  1.16460695e-04,
       -1.01284847e+01,  2.50463672e+04,  9.99990404e-01, -4.84588575e+00],
      dtype=float32)

In [None]:
# Look up the index stored alongside that same position 1 in the per-subset cache.
features_id_individual[("id_test", "{'region': ['Asia']}")]["indices"][1]

120

In [None]:
# Fetch row 120 of the combined cache (120 is the index reported for position 1 of the Asia subset).
# NOTE(review): presumably a cross-check against the per-subset vector; the recorded outputs differ — verify the index mapping.
features_id_combined["features"][120]

array([  0.94658554,   0.1202942 ,   0.21097937, -14.761689  ,
         1.6514555 ,   4.5276346 ,   2.8811173 ,   0.05489393,
        -2.8811173 ,  17.834188  ,   0.9999731 ,  -1.7063494 ],
      dtype=float32)

In [None]:
import os
import pickle
import random

import numpy as np
import pandas as pd
import torch

from suitability.datasets.wilds import get_wilds_dataset, get_wilds_model
from suitability.filter.suitability_efficient import SuitabilityFilter, get_sf_features

# Load one trained model and cache its suitability-filter features for each of
# the four WILDS splits (unfiltered), keyed by split name.

# Set seeds for reproducibility
random.seed(32)
np.random.seed(32)

# Configuration
data_name = "fmow"
root_dir = "/mfsnic/projects/suitability/"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the model
algorithm = "ERM"
model_type = "last"
seed = 0
model = get_wilds_model(
    data_name, root_dir, algorithm=algorithm, seed=seed, model_type=model_type
)
model = model.to(device)
model.eval()
print(f"Model loaded to device: {device}")

# Output path of the feature cache; the name encodes dataset, algorithm, model type and seed
features_cache_file = (
    f"suitability/results/features/{data_name}_{algorithm}_{model_type}_{seed}.pkl"
)
valid_splits = ["id_val", "id_test", "val", "test"]
splits_features_cache = {}

# Precompute features for every split in `valid_splits` (no pre_filter is applied here)
for split_name in valid_splits:
    print(f"Computing features for split: {split_name}")
    dataset = get_wilds_dataset(
        data_name,
        root_dir,
        split_name,
        batch_size=64,
        shuffle=False,
        num_workers=4,
    )
    # Stored as returned by get_sf_features; a later cell unpacks each entry as (features, correctness).
    splits_features_cache[split_name] = get_sf_features(dataset, model, device)
# NOTE: this message is printed after all four splits, not only the ID ones.
print("ID splits features computed")

# Save feature cache
with open(features_cache_file, "wb") as f:
    pickle.dump(splits_features_cache, f)

# Resolve each ID subset filter to its sample indices and cache the result on disk.
id_cache_file = f"suitability/results/split_indices/{data_name}_id.pkl"

# The same year buckets and regions are applied to both ID splits.
valid_id_splits = [
    (split, subset_filter)
    for split in ("id_val", "id_test")
    for subset_filter in (
        {"year": [2002, 2003, 2004, 2005, 2006]},
        {"year": [2007, 2008, 2009]},
        {"year": [2010]},
        {"year": [2011]},
        {"year": [2012]},
        {"region": ["Asia"]},
        {"region": ["Europe"]},
        {"region": ["Americas"]},
    )
]

# Keys are (split name, str(filter)) so they stay hashable and picklable.
id_splits_indices_cache = {}
for split_name, split_filter in valid_id_splits:
    print(f"Computing indices for split: {split_name} with filter: {split_filter}")
    dataset, indices = get_wilds_dataset(
        data_name,
        root_dir,
        split_name,
        pre_filter=split_filter,
        return_indices=True,
        batch_size=64,
        shuffle=False,
        num_workers=4,
    )
    id_splits_indices_cache.update({(split_name, str(split_filter)): indices})

# Persist the index cache.
with open(id_cache_file, mode="wb") as f:
    pickle.dump(id_splits_indices_cache, f)

# Precompute all ood split indices and cache them on disk, keyed by
# (split name, str(filter)).
ood_cache_file = f"suitability/results/split_indices/{data_name}_ood.pkl"

# Single-dimension "val" filters use one-element lists; the region x year pairs
# and all "test" filters use scalar values. The str() of each filter becomes
# part of the cache key, so the exact spelling matters.
valid_ood_splits = [
    ("val", {"year": [2013]}),
    ("val", {"year": [2014]}),
    ("val", {"year": [2015]}),
    ("val", {"region": ["Asia"]}),
    ("val", {"region": ["Europe"]}),
    ("val", {"region": ["Africa"]}),
    ("val", {"region": ["Americas"]}),
    ("val", {"region": ["Oceania"]}),
    ("val", {"region": "Europe", "year": 2013}),
    ("val", {"region": "Europe", "year": 2014}),
    ("val", {"region": "Europe", "year": 2015}),
    ("val", {"region": "Asia", "year": 2013}),
    ("val", {"region": "Asia", "year": 2014}),
    ("val", {"region": "Asia", "year": 2015}),
    ("val", {"region": "Americas", "year": 2013}),
    ("val", {"region": "Americas", "year": 2014}),
    ("val", {"region": "Americas", "year": 2015}),
    ("test", {"year": 2016}),
    ("test", {"year": 2017}),
    ("test", {"region": "Asia"}),
    ("test", {"region": "Europe"}),
    ("test", {"region": "Africa"}),
    ("test", {"region": "Americas"}),
    ("test", {"region": "Oceania"}),
    ("test", {"region": "Europe", "year": 2016}),
    ("test", {"region": "Europe", "year": 2017}),
    ("test", {"region": "Asia", "year": 2016}),
    ("test", {"region": "Asia", "year": 2017}),
    ("test", {"region": "Americas", "year": 2016}),
    ("test", {"region": "Americas", "year": 2017}),
]

ood_splits_indices_cache = {}

for split_name, split_filter in valid_ood_splits:
    dataset, indices = get_wilds_dataset(
        data_name,
        root_dir,
        split_name,
        batch_size=64,
        shuffle=False,
        num_workers=4,
        pre_filter=split_filter,
        return_indices=True,
    )
    ood_splits_indices_cache[(split_name, str(split_filter))] = indices

# Save cache. Fix: dump the index dictionary itself; the previous code pickled
# the cache *file name* string, leaving an unusable OOD index cache on disk.
with open(ood_cache_file, "wb") as f:
    pickle.dump(ood_splits_indices_cache, f)
print("Features saved")

imported
loading model




Model loaded to device: cuda
Computing features for split: id_val
Computing features for split: id_test
Computing features for split: val
Computing features for split: test
ID splits features computed
Features saved


In [None]:
import pickle

# Load the cached ID split indices and list the available (split, filter) keys.
# A context manager closes the file handle deterministically; the previous
# `pickle.load(open(...))` form left it open until garbage collection.
with open("/h/321/apouget/suitability/results/split_indices/fmow_id.pkl", "rb") as f:
    x = pickle.load(f)
print(x.keys())

dict_keys([('id_val', "{'year': [2002, 2003, 2004, 2005, 2006]}"), ('id_val', "{'year': [2007, 2008, 2009]}"), ('id_val', "{'year': [2010]}"), ('id_val', "{'year': [2011]}"), ('id_val', "{'year': [2012]}"), ('id_val', "{'region': ['Asia']}"), ('id_val', "{'region': ['Europe']}"), ('id_val', "{'region': ['Americas']}"), ('id_test', "{'year': [2002, 2003, 2004, 2005, 2006]}"), ('id_test', "{'year': [2007, 2008, 2009]}"), ('id_test', "{'year': [2010]}"), ('id_test', "{'year': [2011]}"), ('id_test', "{'year': [2012]}"), ('id_test', "{'region': ['Asia']}"), ('id_test', "{'region': ['Europe']}"), ('id_test', "{'region': ['Americas']}")])


In [22]:
# Take the index array of the first (split, filter) key only — the loop exits
# after one iteration — and report how many indices it holds.
for a, b in x.keys():
    indices = x[(a, b)]
    break
print(len(indices))

1420


In [None]:
import pickle

# Load the cached per-split features and check that indexing the first element
# of the "id_val" entry with the subset indices yields one row per index.
# A context manager closes the file handle deterministically; the previous
# `pickle.load(open(...))` form left it open until garbage collection.
with open("/h/321/apouget/suitability/results/features/fmow_ERM_last_0.pkl", "rb") as f:
    y = pickle.load(f)
print(y["id_val"][0][indices].shape)

(1420, 12)


In [99]:
# Resolve every ID subset filter to its sample indices and store them on disk.
data_name = "fmow"
root_dir = "/mfsnic/projects/suitability/"
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Destination of the ID split-index cache.
cache_file = f"suitability/results/split_indices/{data_name}_id.pkl"

# The same year buckets and regions are applied to both ID splits.
valid_id_splits = [
    (split, subset_filter)
    for split in ("id_val", "id_test")
    for subset_filter in (
        {"year": [2002, 2003, 2004, 2005, 2006]},
        {"year": [2007, 2008, 2009]},
        {"year": [2010]},
        {"year": [2011]},
        {"year": [2012]},
        {"region": ["Asia"]},
        {"region": ["Europe"]},
        {"region": ["Americas"]},
    )
]

# Maps (split name, str(filter)) to the indices returned for that subset.
splits_indices_cache = {}

for split_name, split_filter in valid_id_splits:
    print(f"Computing indices for split: {split_name} with filter: {split_filter}")
    dataset, indices = get_wilds_dataset(
        data_name,
        root_dir,
        split_name,
        pre_filter=split_filter,
        return_indices=True,
        batch_size=64,
        shuffle=False,
        num_workers=4,
    )
    splits_indices_cache.update({(split_name, str(split_filter)): indices})

# Write the index cache to disk.
with open(cache_file, mode="wb") as f:
    pickle.dump(splits_indices_cache, f)
print("Features saved")

Computing indices for split: id_val with filter: {'year': [2002, 2003, 2004, 2005, 2006]}
Computing indices for split: id_val with filter: {'year': [2007, 2008, 2009]}
Computing indices for split: id_val with filter: {'year': [2010]}
Computing indices for split: id_val with filter: {'year': [2011]}
Computing indices for split: id_val with filter: {'year': [2012]}
Computing indices for split: id_val with filter: {'region': ['Asia']}
Computing indices for split: id_val with filter: {'region': ['Europe']}
Computing indices for split: id_val with filter: {'region': ['Americas']}
Computing indices for split: id_test with filter: {'year': [2002, 2003, 2004, 2005, 2006]}
Computing indices for split: id_test with filter: {'year': [2007, 2008, 2009]}
Computing indices for split: id_test with filter: {'year': [2010]}
Computing indices for split: id_test with filter: {'year': [2011]}
Computing indices for split: id_test with filter: {'year': [2012]}
Computing indices for split: id_test with filte

In [100]:
# Resolve every OOD subset filter to its sample indices and store them on disk.
data_name = "fmow"
root_dir = "/mfsnic/projects/suitability/"
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Destination of the OOD split-index cache.
cache_file = f"suitability/results/split_indices/{data_name}_ood.pkl"

# Single-dimension "val" filters use one-element lists; the region x year pairs
# and all "test" filters use scalar values. The str() of each filter is part of
# the cache key, so this spelling is preserved exactly.
valid_ood_splits = (
    [("val", {"year": [year]}) for year in (2013, 2014, 2015)]
    + [
        ("val", {"region": [region]})
        for region in ("Asia", "Europe", "Africa", "Americas", "Oceania")
    ]
    + [
        ("val", {"region": region, "year": year})
        for region in ("Europe", "Asia", "Americas")
        for year in (2013, 2014, 2015)
    ]
    + [("test", {"year": year}) for year in (2016, 2017)]
    + [
        ("test", {"region": region})
        for region in ("Asia", "Europe", "Africa", "Americas", "Oceania")
    ]
    + [
        ("test", {"region": region, "year": year})
        for region in ("Europe", "Asia", "Americas")
        for year in (2016, 2017)
    ]
)

# Maps (split name, str(filter)) to the indices returned for that subset.
splits_indices_cache = {}

for split_name, split_filter in valid_ood_splits:
    print(f"Computing indices for split: {split_name} with filter: {split_filter}")
    dataset, indices = get_wilds_dataset(
        data_name,
        root_dir,
        split_name,
        pre_filter=split_filter,
        return_indices=True,
        batch_size=64,
        shuffle=False,
        num_workers=4,
    )
    splits_indices_cache.update({(split_name, str(split_filter)): indices})

# Write the index cache to disk.
with open(cache_file, mode="wb") as f:
    pickle.dump(splits_indices_cache, f)
print("Features saved")

Computing indices for split: val with filter: {'year': [2013]}
Computing indices for split: val with filter: {'year': [2014]}
Computing indices for split: val with filter: {'year': [2015]}
Computing indices for split: val with filter: {'region': ['Asia']}
Computing indices for split: val with filter: {'region': ['Europe']}
Computing indices for split: val with filter: {'region': ['Africa']}
Computing indices for split: val with filter: {'region': ['Americas']}
Computing indices for split: val with filter: {'region': ['Oceania']}
Computing indices for split: val with filter: {'region': 'Europe', 'year': 2013}
Computing indices for split: val with filter: {'region': 'Europe', 'year': 2014}
Computing indices for split: val with filter: {'region': 'Europe', 'year': 2015}
Computing indices for split: val with filter: {'region': 'Asia', 'year': 2013}
Computing indices for split: val with filter: {'region': 'Asia', 'year': 2014}
Computing indices for split: val with filter: {'region': 'Asia',

# ID SPLIT SUBSET EVALS

In [2]:
import pickle
import random

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

from suitability.filter.suitability_efficient import SuitabilityFilter
from suitability.filter.tests import non_inferiority_ttest

# ID-subset evaluation. For every model seed, each cached (split, filter) ID
# subset plays the role of the "user" data. The remaining ID samples are
# shuffled into folds, and every ordered pair of distinct folds (i, j) supplies
# the `reg_*` and `test_*` inputs of a SuitabilityFilter. Results are written to
# one CSV per seed.

# Set seeds for reproducibility
random.seed(32)
np.random.seed(32)

# Configuration
data_name = "fmow"
root_dir = "/mfsnic/projects/suitability/"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Identify which cached features to evaluate (no model is loaded in this cell;
# only precomputed features are read from disk)
algorithm = "ERM"
model_type = "last"
seeds = [0, 1, 2]

for seed in seeds:

    # Load the features
    feature_cache_file = (
        f"suitability/results/features/{data_name}_{algorithm}_{model_type}_{seed}.pkl"
    )
    with open(feature_cache_file, "rb") as f:
        full_feature_dict = pickle.load(f)
    # Keep only the two in-distribution splits; each entry is unpacked below as (features, correctness).
    id_feature_dict = {}
    id_feature_dict["id_val"] = full_feature_dict["id_val"]
    id_feature_dict["id_test"] = full_feature_dict["id_test"]

    # Load the split indices
    # (this file does not depend on `seed`, so the same content is re-read for every seed)
    split_cache_file = f"suitability/results/split_indices/{data_name}_id.pkl"
    with open(split_cache_file, "rb") as f:
        id_split_dict = pickle.load(f)

    # Define suitability filter and experiment parameters
    classifiers = [
        "logistic_regression"
    ]  # "logistic_regression", "svm", "random_forest", "gradient_boosting", "mlp", "decision_tree"]
    margins = [0, 0.005, 0.01, 0.05]
    normalize = True
    calibrated = True
    # Result accumulators are reset for each seed and saved at the end of the seed iteration.
    sf_results = []
    direct_testing_results = []
    # Column indices into the feature matrix: each feature alone, then all 12 together.
    feature_subsets = [
        [0],
        [1],
        [2],
        [3],
        [4],
        [5],
        [6],
        [7],
        [8],
        [9],
        [10],
        [11],
        # [4, 11],
        # [4, 11, 8],
        # [4, 11, 8, 6],
        # [4, 11, 8, 6, 2],
        # [4, 11, 8, 6, 2, 1],
        # [4, 11, 8, 6, 2, 1, 0],
        # [4, 11, 8, 6, 2, 1, 0, 7],
        # [4, 11, 8, 6, 2, 1, 0, 7, 3],
        # [4, 11, 8, 6, 2, 1, 0, 7, 3, 10],
        # [4, 11, 8, 6, 2, 1, 0, 7, 3, 10, 5],
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
    ]
    num_fold_arr = [15]

    # Main loop
    # Keys of id_split_dict are (split name, str(filter)) tuples.
    for user_split_name, user_filter in tqdm(id_split_dict.keys()):
        print(f"Evaluating user split: {user_split_name} with filter {user_filter}")

        # Get user split indices
        user_split_indices = id_split_dict[(user_split_name, user_filter)]

        # Get user split features and correctness
        all_features, all_corr = id_feature_dict[user_split_name]
        user_features = all_features[user_split_indices]
        user_corr = all_corr[user_split_indices]
        user_size = len(user_corr)
        user_acc = np.mean(user_corr)

        # Re-partition remaining data into folds
        # Source pool = samples of the user's split that are NOT in the user subset,
        # plus the whole of the other ID split.
        remaining_indices = np.setdiff1d(np.arange(len(all_corr)), user_split_indices)
        remaining_features = all_features[remaining_indices]
        remaining_corr = all_corr[remaining_indices]
        if user_split_name == "id_val":
            other_split_name = "id_test"
        elif user_split_name == "id_test":
            other_split_name = "id_val"
        else:
            raise ValueError(f"Invalid split name: {user_split_name}")
        additional_features, additional_corr = id_feature_dict[other_split_name]
        source_features = np.concatenate([remaining_features, additional_features], axis=0)
        source_corr = np.concatenate([remaining_corr, additional_corr], axis=0)

        for num_folds in num_fold_arr:
            # Equal-sized folds; integer division drops any leftover samples.
            source_fold_size = len(source_corr) // num_folds
            indices = np.arange(len(source_corr))
            # Uses the global NumPy RNG seeded once at the top of the cell, so fold
            # membership depends on how many shuffles ran before this one.
            np.random.shuffle(indices)
            fold_indices = [
                indices[i * source_fold_size : (i + 1) * source_fold_size]
                for i in range(num_folds)
            ]

            for i, reg_indices in enumerate(fold_indices):
                reg_features = source_features[reg_indices]
                reg_corr = source_corr[reg_indices]
                reg_size = len(reg_corr)
                reg_acc = np.mean(reg_corr)

                for j, test_indices in enumerate(fold_indices):
                    # Never use the same fold for both roles.
                    if i == j:
                        continue
                    test_features = source_features[test_indices]
                    test_corr = source_corr[test_indices]
                    test_size = len(test_corr)
                    test_acc = np.mean(test_corr)

                    for classifier in classifiers:
                        for feature_subset in feature_subsets:
                            # A fresh filter is built and trained for every (i, j, classifier, subset).
                            suitability_filter = SuitabilityFilter(
                                test_features,
                                test_corr,
                                reg_features,
                                reg_corr,
                                device,
                                normalize=normalize,
                                feature_subset=feature_subset,
                            )
                            suitability_filter.train_classifier(
                                calibrated=calibrated, classifier=classifier
                            )

                            for margin in margins:
                                # Test suitability filter
                                sf_test = suitability_filter.suitability_test(
                                    user_features=user_features, margin=margin
                                )
                                p_value = sf_test["p_value"]
                                # Ground truth: user accuracy is at least the test-fold accuracy minus the margin.
                                ground_truth = user_acc >= test_acc - margin

                                sf_results.append(
                                    {
                                        "data_name": data_name,
                                        "algorithm": algorithm,
                                        "seed": seed,
                                        "model_type": model_type,
                                        "normalize": normalize,
                                        "calibrated": calibrated,
                                        "margin": margin,
                                        "reg_fold": i,
                                        "reg_size": reg_size,
                                        "reg_acc": reg_acc,
                                        "test_fold": j,
                                        "test_size": test_size,
                                        "test_acc": test_acc,
                                        "user_split": user_split_name,
                                        "user_filter": user_filter,
                                        "user_size": user_size,
                                        "user_acc": user_acc,
                                        "p_value": p_value,
                                        "ground_truth": ground_truth,
                                        "classifier": classifier,
                                        "feature_subset": feature_subset,
                                        "acc_diff": user_acc - test_acc,
                                        "acc_diff_adjusted": user_acc + margin - test_acc,
                                    }
                                )

                                # Run non-inferiority test on features directly
                                # Only for single-feature subsets at margin 0. The inputs depend on the
                                # test fold j and the user data, not on the reg fold i.
                                # NOTE(review): this condition selects j == 0 for every i >= 1 plus
                                # (i == 0, j == 1), so test fold 0 is recorded repeatedly and folds
                                # 2+ never are — confirm this is intended rather than one run per
                                # test fold.
                                if (
                                    len(feature_subset) == 1
                                    and margin == 0
                                    and classifier == "logistic_regression"
                                    and (j == 0 or (i == 0 and j == 1))
                                ):
                                    test_feature_subset = test_features[:, feature_subset].flatten()
                                    user_feature_subset = user_features[:, feature_subset].flatten()
                                    # Run the test in both directions.
                                    test_1 = non_inferiority_ttest(
                                        test_feature_subset,
                                        user_feature_subset,
                                        increase_good=True,
                                    )
                                    test_2 = non_inferiority_ttest(
                                        test_feature_subset,
                                        user_feature_subset,
                                        increase_good=False,
                                    )
                                    direct_testing_results.append(
                                        {
                                            "data_name": data_name,
                                            "algorithm": algorithm,
                                            "seed": seed,
                                            "model_type": model_type,
                                            "test_fold": j,
                                            "test_size": test_size,
                                            "test_acc": test_acc,
                                            "user_split": user_split_name,
                                            "user_filter": user_filter,
                                            "user_size": user_size,
                                            "user_acc": user_acc,
                                            "p_value_increase_good": test_1["p_value"],
                                            "p_value_decrease_good": test_2["p_value"],
                                            "ground_truth": ground_truth,
                                            "feature_subset": feature_subset,
                                            "acc_diff": user_acc - test_acc,
                                        }
                                    )


    # Save results
    # NOTE(review): the output directory is "irm/" although `algorithm` is "ERM" — confirm the intended location.
    sf_evals = pd.DataFrame(sf_results)
    sf_evals.to_csv(
        f"suitability/results/sf_evals/irm/fmow_sf_results_id_subsets_{algorithm}_{model_type}_{seed}_NEW.csv",
        index=False,
    )
    direct_testing_evals = pd.DataFrame(direct_testing_results)
    direct_testing_evals.to_csv(
        f"suitability/results/sf_evals/irm/fmow_direct_testing_results_id_subsets_{algorithm}_{model_type}_{seed}_NEW.csv",
        index=False,
    )


  0%|                                                                                                                            | 0/16 [00:00<?, ?it/s]

Evaluating user split: id_val with filter {'year': [2002, 2003, 2004, 2005, 2006]}


  6%|███████▎                                                                                                            | 1/16 [01:36<24:06, 96.45s/it]

Evaluating user split: id_val with filter {'year': [2007, 2008, 2009]}


 12%|██████████████▌                                                                                                     | 2/16 [03:13<22:35, 96.84s/it]

Evaluating user split: id_val with filter {'year': [2010]}


 19%|█████████████████████▊                                                                                              | 3/16 [04:51<21:02, 97.12s/it]

Evaluating user split: id_val with filter {'year': [2011]}


 25%|█████████████████████████████                                                                                       | 4/16 [06:28<19:28, 97.34s/it]

Evaluating user split: id_val with filter {'year': [2012]}


 31%|████████████████████████████████████▎                                                                               | 5/16 [08:06<17:52, 97.50s/it]

Evaluating user split: id_val with filter {'region': ['Asia']}


 38%|███████████████████████████████████████████▌                                                                        | 6/16 [09:44<16:16, 97.61s/it]

Evaluating user split: id_val with filter {'region': ['Europe']}


 44%|██████████████████████████████████████████████████▊                                                                 | 7/16 [11:23<14:42, 98.10s/it]

Evaluating user split: id_val with filter {'region': ['Americas']}


 50%|██████████████████████████████████████████████████████████                                                          | 8/16 [13:01<13:03, 97.99s/it]

Evaluating user split: id_test with filter {'year': [2002, 2003, 2004, 2005, 2006]}


 56%|█████████████████████████████████████████████████████████████████▎                                                  | 9/16 [14:38<11:24, 97.85s/it]

Evaluating user split: id_test with filter {'year': [2007, 2008, 2009]}


 62%|███████████████████████████████████████████████████████████████████████▉                                           | 10/16 [16:15<09:45, 97.59s/it]

Evaluating user split: id_test with filter {'year': [2010]}


 69%|███████████████████████████████████████████████████████████████████████████████                                    | 11/16 [17:53<08:08, 97.69s/it]

Evaluating user split: id_test with filter {'year': [2011]}


 75%|██████████████████████████████████████████████████████████████████████████████████████▎                            | 12/16 [19:31<06:30, 97.70s/it]

Evaluating user split: id_test with filter {'year': [2012]}


 81%|█████████████████████████████████████████████████████████████████████████████████████████████▍                     | 13/16 [21:09<04:53, 97.83s/it]

Evaluating user split: id_test with filter {'region': ['Asia']}


 88%|████████████████████████████████████████████████████████████████████████████████████████████████████▋              | 14/16 [22:47<03:15, 97.82s/it]

Evaluating user split: id_test with filter {'region': ['Europe']}


 94%|███████████████████████████████████████████████████████████████████████████████████████████████████████████▊       | 15/16 [24:26<01:38, 98.16s/it]

Evaluating user split: id_test with filter {'region': ['Americas']}


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [26:04<00:00, 97.78s/it]
  0%|                                                                                                                            | 0/16 [00:00<?, ?it/s]

Evaluating user split: id_val with filter {'year': [2002, 2003, 2004, 2005, 2006]}


  6%|███████▎                                                                                                            | 1/16 [01:37<24:16, 97.09s/it]

Evaluating user split: id_val with filter {'year': [2007, 2008, 2009]}


 12%|██████████████▌                                                                                                     | 2/16 [03:13<22:37, 96.94s/it]

Evaluating user split: id_val with filter {'year': [2010]}


 19%|█████████████████████▊                                                                                              | 3/16 [04:51<21:04, 97.29s/it]

Evaluating user split: id_val with filter {'year': [2011]}


 25%|█████████████████████████████                                                                                       | 4/16 [06:29<19:28, 97.35s/it]

Evaluating user split: id_val with filter {'year': [2012]}


 31%|████████████████████████████████████▎                                                                               | 5/16 [08:06<17:53, 97.55s/it]

Evaluating user split: id_val with filter {'region': ['Asia']}


 38%|███████████████████████████████████████████▌                                                                        | 6/16 [09:44<16:14, 97.47s/it]

Evaluating user split: id_val with filter {'region': ['Europe']}


 44%|██████████████████████████████████████████████████▊                                                                 | 7/16 [11:23<14:41, 97.93s/it]

Evaluating user split: id_val with filter {'region': ['Americas']}


 50%|██████████████████████████████████████████████████████████                                                          | 8/16 [13:01<13:03, 97.91s/it]

Evaluating user split: id_test with filter {'year': [2002, 2003, 2004, 2005, 2006]}


 56%|█████████████████████████████████████████████████████████████████▎                                                  | 9/16 [14:37<11:23, 97.60s/it]

Evaluating user split: id_test with filter {'year': [2007, 2008, 2009]}


 62%|███████████████████████████████████████████████████████████████████████▉                                           | 10/16 [16:15<09:45, 97.52s/it]

Evaluating user split: id_test with filter {'year': [2010]}


 69%|███████████████████████████████████████████████████████████████████████████████                                    | 11/16 [17:52<08:07, 97.49s/it]

Evaluating user split: id_test with filter {'year': [2011]}


 75%|██████████████████████████████████████████████████████████████████████████████████████▎                            | 12/16 [19:30<06:30, 97.56s/it]

Evaluating user split: id_test with filter {'year': [2012]}


 81%|█████████████████████████████████████████████████████████████████████████████████████████████▍                     | 13/16 [21:08<04:52, 97.66s/it]

Evaluating user split: id_test with filter {'region': ['Asia']}


 88%|████████████████████████████████████████████████████████████████████████████████████████████████████▋              | 14/16 [22:45<03:15, 97.56s/it]

Evaluating user split: id_test with filter {'region': ['Europe']}


 94%|███████████████████████████████████████████████████████████████████████████████████████████████████████████▊       | 15/16 [24:24<01:38, 98.05s/it]

Evaluating user split: id_test with filter {'region': ['Americas']}


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [26:02<00:00, 97.67s/it]
  0%|                                                                                                                                    | 0/16 [00:00<?, ?it/s]

Evaluating user split: id_val with filter {'year': [2002, 2003, 2004, 2005, 2006]}


  6%|███████▊                                                                                                                    | 1/16 [01:39<24:45, 99.04s/it]

Evaluating user split: id_val with filter {'year': [2007, 2008, 2009]}


 12%|███████████████▌                                                                                                            | 2/16 [03:16<22:54, 98.16s/it]

Evaluating user split: id_val with filter {'year': [2010]}


 19%|███████████████████████▎                                                                                                    | 3/16 [04:54<21:14, 98.06s/it]

Evaluating user split: id_val with filter {'year': [2011]}


 25%|███████████████████████████████                                                                                             | 4/16 [06:32<19:36, 98.01s/it]

Evaluating user split: id_val with filter {'year': [2012]}


 31%|██████████████████████████████████████▊                                                                                     | 5/16 [08:10<17:58, 98.01s/it]

Evaluating user split: id_val with filter {'region': ['Asia']}


 38%|██████████████████████████████████████████████▌                                                                             | 6/16 [09:48<16:19, 97.99s/it]

Evaluating user split: id_val with filter {'region': ['Europe']}


 44%|██████████████████████████████████████████████████████▎                                                                     | 7/16 [11:27<14:45, 98.36s/it]

Evaluating user split: id_val with filter {'region': ['Americas']}


 50%|██████████████████████████████████████████████████████████████                                                              | 8/16 [13:05<13:06, 98.36s/it]

Evaluating user split: id_test with filter {'year': [2002, 2003, 2004, 2005, 2006]}


 56%|█████████████████████████████████████████████████████████████████████▊                                                      | 9/16 [14:42<11:25, 97.93s/it]

Evaluating user split: id_test with filter {'year': [2007, 2008, 2009]}


 62%|████████████████████████████████████████████████████████████████████████████▉                                              | 10/16 [16:19<09:45, 97.61s/it]

Evaluating user split: id_test with filter {'year': [2010]}


 69%|████████████████████████████████████████████████████████████████████████████████████▌                                      | 11/16 [17:57<08:08, 97.67s/it]

Evaluating user split: id_test with filter {'year': [2011]}


 75%|████████████████████████████████████████████████████████████████████████████████████████████▎                              | 12/16 [19:35<06:30, 97.73s/it]

Evaluating user split: id_test with filter {'year': [2012]}


 81%|███████████████████████████████████████████████████████████████████████████████████████████████████▉                       | 13/16 [21:17<04:57, 99.02s/it]

Evaluating user split: id_test with filter {'region': ['Asia']}


 88%|██████████████████████████████████████████████████████████████████████████████████████████████████████████▊               | 14/16 [23:02<03:21, 100.79s/it]

Evaluating user split: id_test with filter {'region': ['Europe']}


 94%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍       | 15/16 [24:47<01:42, 102.24s/it]

Evaluating user split: id_test with filter {'region': ['Americas']}


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [26:32<00:00, 99.55s/it]


# OOD SPLIT SUBSET EVALS

In [1]:
import os
import pickle
import random

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

from suitability.filter.suitability_efficient import SuitabilityFilter
from suitability.filter.tests import non_inferiority_ttest

# Set seeds for reproducibility. Every RNG library imported by this cell is
# seeded: `random`, NumPy (which drives the fold shuffling in the main loop)
# and torch.
random.seed(32)
np.random.seed(32)
torch.manual_seed(32)

# Configuration
data_name = "fmow"
# NOTE(review): absolute, machine-specific path that is not read anywhere in
# this cell; kept only in case other cells rely on it.
root_dir = "/mfsnic/projects/suitability/"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Identify the cached features to evaluate. No model is loaded here: these
# values only select which feature cache file is read for each seed.
algorithm = "ERM"
model_type = "last"
seeds = [0, 1, 2]

# Experiment configuration. None of it depends on the model seed, so it is
# defined once instead of being rebuilt on every iteration of the seed loop.
classifiers = [
    "logistic_regression"
]  # "logistic_regression", "svm", "random_forest", "gradient_boosting", "mlp", "decision_tree"]
margins = [0, 0.005, 0.01, 0.05]
normalize = True
calibrated = True
# Feature subsets: each feature index 0..11 on its own, then all twelve
# together (assumes the cached features have 12 columns -- TODO confirm).
# A cumulative ordering (4, 11, 8, 6, 2, 1, 0, 7, 3, 10, 5) was listed here
# before as commented-out subsets and remains disabled.
num_feature_indices = 12
feature_subsets = [[k] for k in range(num_feature_indices)] + [
    list(range(num_feature_indices))
]
num_fold_arr = [15]

# Create the output directory up front so that a long run cannot fail at the
# very end just because the directory is missing.
results_dir = "suitability/results/sf_evals/erm"
os.makedirs(results_dir, exist_ok=True)

# Load the split indices. The file name does not depend on the seed, so it is
# read once. NOTE: pickle.load can execute arbitrary code; only load cache
# files produced by this project.
split_cache_file = f"suitability/results/split_indices/{data_name}_ood.pkl"
with open(split_cache_file, "rb") as f:
    ood_split_dict = pickle.load(f)

for seed in seeds:
    # Load the cached features for this seed (same pickle caveat as above).
    feature_cache_file = (
        f"suitability/results/features/{data_name}_{algorithm}_{model_type}_{seed}.pkl"
    )
    with open(feature_cache_file, "rb") as f:
        full_feature_dict = pickle.load(f)

    # Result rows are collected per seed and written to per-seed CSV files.
    sf_results = []
    direct_testing_results = []

    # Source data = in-distribution validation + test features, concatenated.
    id_features_val, id_corr_val = full_feature_dict["id_val"]
    id_features_test, id_corr_test = full_feature_dict["id_test"]
    source_features = np.concatenate([id_features_val, id_features_test], axis=0)
    source_corr = np.concatenate([id_corr_val, id_corr_test], axis=0)

    # Main loop over the user (OOD) splits; keys are (split name, filter) pairs.
    for (user_split_name, user_filter), user_split_indices in tqdm(
        ood_split_dict.items()
    ):
        print(f"Evaluating user split: {user_split_name} with filter {user_filter}")

        # Get user split features and correctness
        all_features, all_corr = full_feature_dict[user_split_name]
        user_features = all_features[user_split_indices]
        user_corr = all_corr[user_split_indices]
        user_size = len(user_corr)
        user_acc = np.mean(user_corr)

        for num_folds in num_fold_arr:
            # Shuffle the source indices and cut them into equally sized folds;
            # the len % num_folds leftover samples are not placed in any fold.
            source_fold_size = len(source_corr) // num_folds
            indices = np.arange(len(source_corr))
            np.random.shuffle(indices)
            fold_indices = [
                indices[i * source_fold_size : (i + 1) * source_fold_size]
                for i in range(num_folds)
            ]

            # Every ordered pair of distinct folds (i, j) is evaluated:
            # fold i is passed as the regressor data, fold j as the test data.
            for i, reg_indices in enumerate(fold_indices):
                reg_features = source_features[reg_indices]
                reg_corr = source_corr[reg_indices]
                reg_size = len(reg_corr)
                reg_acc = np.mean(reg_corr)

                for j, test_indices in enumerate(fold_indices):
                    if i == j:
                        continue
                    test_features = source_features[test_indices]
                    test_corr = source_corr[test_indices]
                    test_size = len(test_corr)
                    test_acc = np.mean(test_corr)

                    for classifier in classifiers:
                        for feature_subset in feature_subsets:
                            suitability_filter = SuitabilityFilter(
                                test_features,
                                test_corr,
                                reg_features,
                                reg_corr,
                                device,
                                normalize=normalize,
                                feature_subset=feature_subset,
                            )
                            suitability_filter.train_classifier(
                                calibrated=calibrated, classifier=classifier
                            )

                            for margin in margins:
                                # Test suitability filter
                                sf_test = suitability_filter.suitability_test(
                                    user_features=user_features, margin=margin
                                )
                                p_value = sf_test["p_value"]
                                # The user data counts as suitable when its
                                # accuracy is at most `margin` below the test
                                # fold's accuracy.
                                ground_truth = user_acc >= test_acc - margin

                                sf_results.append(
                                    {
                                        "data_name": data_name,
                                        "algorithm": algorithm,
                                        "seed": seed,
                                        "model_type": model_type,
                                        "normalize": normalize,
                                        "calibrated": calibrated,
                                        "margin": margin,
                                        "reg_fold": i,
                                        "reg_size": reg_size,
                                        "reg_acc": reg_acc,
                                        "test_fold": j,
                                        "test_size": test_size,
                                        "test_acc": test_acc,
                                        "user_split": user_split_name,
                                        "user_filter": user_filter,
                                        "user_size": user_size,
                                        "user_acc": user_acc,
                                        "p_value": p_value,
                                        "ground_truth": ground_truth,
                                        "classifier": classifier,
                                        "feature_subset": feature_subset,
                                        "acc_diff": user_acc - test_acc,
                                        "acc_diff_adjusted": user_acc + margin - test_acc,
                                    }
                                )

                                # Run non-inferiority test on features directly
                                # (single-feature subsets, margin 0 only).
                                # NOTE(review): this condition records test
                                # fold 0 once per regression fold i != 0 and
                                # test fold 1 once (i == 0); no other test fold
                                # is directly tested, although the inputs below
                                # do not depend on i. Confirm this sampling is
                                # intended.
                                if (
                                    len(feature_subset) == 1
                                    and margin == 0
                                    and classifier == "logistic_regression"
                                    and (j == 0 or (i == 0 and j == 1))
                                ):
                                    test_feature_subset = test_features[:, feature_subset].flatten()
                                    user_feature_subset = user_features[:, feature_subset].flatten()
                                    # Run the test in both directions; each
                                    # p-value is stored in its own column.
                                    test_1 = non_inferiority_ttest(
                                        test_feature_subset,
                                        user_feature_subset,
                                        increase_good=True,
                                    )
                                    test_2 = non_inferiority_ttest(
                                        test_feature_subset,
                                        user_feature_subset,
                                        increase_good=False,
                                    )
                                    direct_testing_results.append(
                                        {
                                            "data_name": data_name,
                                            "algorithm": algorithm,
                                            "seed": seed,
                                            "model_type": model_type,
                                            "test_fold": j,
                                            "test_size": test_size,
                                            "test_acc": test_acc,
                                            "user_split": user_split_name,
                                            "user_filter": user_filter,
                                            "user_size": user_size,
                                            "user_acc": user_acc,
                                            "p_value_increase_good": test_1["p_value"],
                                            "p_value_decrease_good": test_2["p_value"],
                                            "ground_truth": ground_truth,
                                            "feature_subset": feature_subset,
                                            "acc_diff": user_acc - test_acc,
                                        }
                                    )

    # Save results (one pair of CSV files per seed)
    sf_evals = pd.DataFrame(sf_results)
    sf_evals.to_csv(
        f"{results_dir}/fmow_sf_results_ood_subsets_{algorithm}_{model_type}_{seed}_NEW.csv",
        index=False,
    )
    direct_testing_evals = pd.DataFrame(direct_testing_results)
    direct_testing_evals.to_csv(
        f"{results_dir}/fmow_direct_testing_results_ood_subsets_{algorithm}_{model_type}_{seed}_NEW.csv",
        index=False,
    )


  0%|                                                                                                         | 0/30 [00:00<?, ?it/s]

Evaluating user split: val with filter {'year': [2013]}


  3%|███▏                                                                                             | 1/30 [01:36<46:30, 96.21s/it]

Evaluating user split: val with filter {'year': [2014]}


  7%|██████▍                                                                                          | 2/30 [03:14<45:31, 97.54s/it]

Evaluating user split: val with filter {'year': [2015]}


 10%|█████████▋                                                                                       | 3/30 [04:56<44:49, 99.60s/it]

Evaluating user split: val with filter {'region': ['Asia']}


 13%|████████████▉                                                                                    | 4/30 [06:33<42:41, 98.52s/it]

Evaluating user split: val with filter {'region': ['Europe']}


 17%|████████████████▏                                                                                | 5/30 [08:13<41:19, 99.16s/it]

Evaluating user split: val with filter {'region': ['Africa']}


 20%|███████████████████▍                                                                             | 6/30 [09:46<38:49, 97.05s/it]

Evaluating user split: val with filter {'region': ['Americas']}


 23%|██████████████████████▋                                                                          | 7/30 [11:26<37:29, 97.79s/it]

Evaluating user split: val with filter {'region': ['Oceania']}


 27%|█████████████████████████▊                                                                       | 8/30 [12:59<35:20, 96.38s/it]

Evaluating user split: val with filter {'region': 'Europe', 'year': 2013}


 30%|█████████████████████████████                                                                    | 9/30 [14:34<33:33, 95.87s/it]

Evaluating user split: val with filter {'region': 'Europe', 'year': 2014}


 33%|████████████████████████████████                                                                | 10/30 [16:09<31:52, 95.62s/it]

Evaluating user split: val with filter {'region': 'Europe', 'year': 2015}


 37%|███████████████████████████████████▏                                                            | 11/30 [17:45<30:22, 95.91s/it]

Evaluating user split: val with filter {'region': 'Asia', 'year': 2013}


 40%|██████████████████████████████████████▍                                                         | 12/30 [19:19<28:33, 95.20s/it]

Evaluating user split: val with filter {'region': 'Asia', 'year': 2014}


 43%|█████████████████████████████████████████▌                                                      | 13/30 [20:54<26:55, 95.03s/it]

Evaluating user split: val with filter {'region': 'Asia', 'year': 2015}


 47%|████████████████████████████████████████████▊                                                   | 14/30 [22:28<25:18, 94.91s/it]

Evaluating user split: val with filter {'region': 'Americas', 'year': 2013}


 50%|████████████████████████████████████████████████                                                | 15/30 [24:02<23:37, 94.52s/it]

Evaluating user split: val with filter {'region': 'Americas', 'year': 2014}


 53%|███████████████████████████████████████████████████▏                                            | 16/30 [25:37<22:06, 94.74s/it]

Evaluating user split: val with filter {'region': 'Americas', 'year': 2015}


 57%|██████████████████████████████████████████████████████▍                                         | 17/30 [27:13<20:36, 95.14s/it]

Evaluating user split: test with filter {'year': 2016}


 60%|█████████████████████████████████████████████████████████▌                                      | 18/30 [29:01<19:47, 98.93s/it]

Evaluating user split: test with filter {'year': 2017}


 63%|████████████████████████████████████████████████████████████▊                                   | 19/30 [30:40<18:07, 98.89s/it]

Evaluating user split: test with filter {'region': 'Asia'}


 67%|████████████████████████████████████████████████████████████████                                | 20/30 [32:17<16:25, 98.51s/it]

Evaluating user split: test with filter {'region': 'Europe'}


 70%|███████████████████████████████████████████████████████████████████▏                            | 21/30 [33:56<14:46, 98.49s/it]

Evaluating user split: test with filter {'region': 'Africa'}


 73%|██████████████████████████████████████████████████████████████████████▍                         | 22/30 [35:31<13:00, 97.60s/it]

Evaluating user split: test with filter {'region': 'Americas'}


 77%|█████████████████████████████████████████████████████████████████████████▌                      | 23/30 [37:12<11:29, 98.51s/it]

Evaluating user split: test with filter {'region': 'Oceania'}


 80%|████████████████████████████████████████████████████████████████████████████▊                   | 24/30 [38:46<09:42, 97.09s/it]

Evaluating user split: test with filter {'region': 'Europe', 'year': 2016}


 83%|████████████████████████████████████████████████████████████████████████████████                | 25/30 [40:24<08:06, 97.29s/it]

Evaluating user split: test with filter {'region': 'Europe', 'year': 2017}


 87%|███████████████████████████████████████████████████████████████████████████████████▏            | 26/30 [41:57<06:25, 96.29s/it]

Evaluating user split: test with filter {'region': 'Asia', 'year': 2016}


 90%|██████████████████████████████████████████████████████████████████████████████████████▍         | 27/30 [43:34<04:49, 96.34s/it]

Evaluating user split: test with filter {'region': 'Asia', 'year': 2017}


 93%|█████████████████████████████████████████████████████████████████████████████████████████▌      | 28/30 [45:09<03:11, 95.88s/it]

Evaluating user split: test with filter {'region': 'Americas', 'year': 2016}


 97%|████████████████████████████████████████████████████████████████████████████████████████████▊   | 29/30 [46:47<01:36, 96.65s/it]

Evaluating user split: test with filter {'region': 'Americas', 'year': 2017}


100%|████████████████████████████████████████████████████████████████████████████████████████████████| 30/30 [48:22<00:00, 96.76s/it]
  0%|                                                                                                         | 0/30 [00:00<?, ?it/s]

Evaluating user split: val with filter {'year': [2013]}


  3%|███▏                                                                                             | 1/30 [01:38<47:35, 98.48s/it]

Evaluating user split: val with filter {'year': [2014]}


  7%|██████▍                                                                                          | 2/30 [03:17<46:07, 98.85s/it]

Evaluating user split: val with filter {'year': [2015]}


 10%|█████████▌                                                                                      | 3/30 [04:59<45:11, 100.43s/it]

Evaluating user split: val with filter {'region': ['Asia']}


 13%|████████████▉                                                                                    | 4/30 [06:37<42:57, 99.13s/it]

Evaluating user split: val with filter {'region': ['Europe']}


 17%|████████████████▏                                                                                | 5/30 [08:17<41:28, 99.55s/it]

Evaluating user split: val with filter {'region': ['Africa']}


 20%|███████████████████▍                                                                             | 6/30 [09:51<39:02, 97.62s/it]

Evaluating user split: val with filter {'region': ['Americas']}


 23%|██████████████████████▋                                                                          | 7/30 [11:30<37:37, 98.14s/it]

Evaluating user split: val with filter {'region': ['Oceania']}


 27%|█████████████████████████▊                                                                       | 8/30 [13:04<35:29, 96.81s/it]

Evaluating user split: val with filter {'region': 'Europe', 'year': 2013}


 30%|█████████████████████████████                                                                    | 9/30 [14:39<33:39, 96.18s/it]

Evaluating user split: val with filter {'region': 'Europe', 'year': 2014}


 33%|████████████████████████████████                                                                | 10/30 [16:14<31:58, 95.94s/it]

Evaluating user split: val with filter {'region': 'Europe', 'year': 2015}


 37%|███████████████████████████████████▏                                                            | 11/30 [17:51<30:28, 96.22s/it]

Evaluating user split: val with filter {'region': 'Asia', 'year': 2013}


 40%|██████████████████████████████████████▍                                                         | 12/30 [19:24<28:34, 95.26s/it]

Evaluating user split: val with filter {'region': 'Asia', 'year': 2014}


 43%|█████████████████████████████████████████▌                                                      | 13/30 [20:59<26:55, 95.06s/it]

Evaluating user split: val with filter {'region': 'Asia', 'year': 2015}


 47%|████████████████████████████████████████████▊                                                   | 14/30 [22:33<25:18, 94.92s/it]

Evaluating user split: val with filter {'region': 'Americas', 'year': 2013}


 50%|████████████████████████████████████████████████                                                | 15/30 [24:08<23:41, 94.74s/it]

Evaluating user split: val with filter {'region': 'Americas', 'year': 2014}


 53%|███████████████████████████████████████████████████▏                                            | 16/30 [25:42<22:05, 94.71s/it]

Evaluating user split: val with filter {'region': 'Americas', 'year': 2015}


 57%|██████████████████████████████████████████████████████▍                                         | 17/30 [27:18<20:34, 95.00s/it]

Evaluating user split: test with filter {'year': 2016}


 60%|█████████████████████████████████████████████████████████▌                                      | 18/30 [29:05<19:45, 98.77s/it]

Evaluating user split: test with filter {'year': 2017}


 63%|████████████████████████████████████████████████████████████▊                                   | 19/30 [30:44<18:06, 98.76s/it]

Evaluating user split: test with filter {'region': 'Asia'}


 67%|████████████████████████████████████████████████████████████████                                | 20/30 [32:22<16:25, 98.55s/it]

Evaluating user split: test with filter {'region': 'Europe'}


 70%|███████████████████████████████████████████████████████████████████▏                            | 21/30 [34:01<14:47, 98.56s/it]

Evaluating user split: test with filter {'region': 'Africa'}


 73%|██████████████████████████████████████████████████████████████████████▍                         | 22/30 [35:36<13:00, 97.52s/it]

Evaluating user split: test with filter {'region': 'Americas'}


 77%|█████████████████████████████████████████████████████████████████████████▌                      | 23/30 [37:16<11:28, 98.39s/it]

Evaluating user split: test with filter {'region': 'Oceania'}


 80%|████████████████████████████████████████████████████████████████████████████▊                   | 24/30 [38:50<09:41, 96.97s/it]

Evaluating user split: test with filter {'region': 'Europe', 'year': 2016}


 83%|████████████████████████████████████████████████████████████████████████████████                | 25/30 [40:27<08:05, 97.14s/it]

Evaluating user split: test with filter {'region': 'Europe', 'year': 2017}


 87%|███████████████████████████████████████████████████████████████████████████████████▏            | 26/30 [42:01<06:24, 96.16s/it]

Evaluating user split: test with filter {'region': 'Asia', 'year': 2016}


 90%|██████████████████████████████████████████████████████████████████████████████████████▍         | 27/30 [43:38<04:48, 96.26s/it]

Evaluating user split: test with filter {'region': 'Asia', 'year': 2017}


 93%|█████████████████████████████████████████████████████████████████████████████████████████▌      | 28/30 [45:13<03:11, 95.85s/it]

Evaluating user split: test with filter {'region': 'Americas', 'year': 2016}


 97%|████████████████████████████████████████████████████████████████████████████████████████████▊   | 29/30 [46:51<01:36, 96.68s/it]

Evaluating user split: test with filter {'region': 'Americas', 'year': 2017}


100%|████████████████████████████████████████████████████████████████████████████████████████████████| 30/30 [48:26<00:00, 96.90s/it]
  0%|                                                                                             | 0/30 [00:00<?, ?it/s]

Evaluating user split: val with filter {'year': [2013]}


  3%|██▊                                                                                  | 1/30 [01:37<47:03, 97.36s/it]

Evaluating user split: val with filter {'year': [2014]}


  7%|█████▋                                                                               | 2/30 [03:16<46:01, 98.61s/it]

Evaluating user split: val with filter {'year': [2015]}


 10%|████████▍                                                                           | 3/30 [04:59<45:15, 100.56s/it]

Evaluating user split: val with filter {'region': ['Asia']}


 13%|███████████▎                                                                         | 4/30 [06:36<43:00, 99.26s/it]

Evaluating user split: val with filter {'region': ['Europe']}


 17%|██████████████▏                                                                      | 5/30 [08:17<41:32, 99.68s/it]

Evaluating user split: val with filter {'region': ['Africa']}


 20%|█████████████████                                                                    | 6/30 [09:51<39:06, 97.79s/it]

Evaluating user split: val with filter {'region': ['Americas']}


 23%|███████████████████▊                                                                 | 7/30 [11:31<37:42, 98.36s/it]

Evaluating user split: val with filter {'region': ['Oceania']}


 27%|██████████████████████▋                                                              | 8/30 [13:04<35:31, 96.88s/it]

Evaluating user split: val with filter {'region': 'Europe', 'year': 2013}


 30%|█████████████████████████▌                                                           | 9/30 [14:39<33:39, 96.17s/it]

Evaluating user split: val with filter {'region': 'Europe', 'year': 2014}


 33%|████████████████████████████                                                        | 10/30 [16:15<32:00, 96.01s/it]

Evaluating user split: val with filter {'region': 'Europe', 'year': 2015}


 37%|██████████████████████████████▊                                                     | 11/30 [17:51<30:25, 96.06s/it]

Evaluating user split: val with filter {'region': 'Asia', 'year': 2013}


 40%|█████████████████████████████████▌                                                  | 12/30 [19:25<28:37, 95.41s/it]

Evaluating user split: val with filter {'region': 'Asia', 'year': 2014}


 43%|████████████████████████████████████▍                                               | 13/30 [20:59<26:57, 95.13s/it]

Evaluating user split: val with filter {'region': 'Asia', 'year': 2015}


 47%|███████████████████████████████████████▏                                            | 14/30 [22:34<25:22, 95.16s/it]

Evaluating user split: val with filter {'region': 'Americas', 'year': 2013}


 50%|██████████████████████████████████████████                                          | 15/30 [24:08<23:42, 94.84s/it]

Evaluating user split: val with filter {'region': 'Americas', 'year': 2014}


 53%|████████████████████████████████████████████▊                                       | 16/30 [25:43<22:07, 94.81s/it]

Evaluating user split: val with filter {'region': 'Americas', 'year': 2015}


 57%|███████████████████████████████████████████████▌                                    | 17/30 [27:20<20:40, 95.39s/it]

Evaluating user split: test with filter {'year': 2016}


 60%|██████████████████████████████████████████████████▍                                 | 18/30 [29:07<19:47, 98.99s/it]

Evaluating user split: test with filter {'year': 2017}


 63%|█████████████████████████████████████████████████████▏                              | 19/30 [30:47<18:09, 99.06s/it]

Evaluating user split: test with filter {'region': 'Asia'}


 67%|████████████████████████████████████████████████████████                            | 20/30 [32:24<16:26, 98.66s/it]

Evaluating user split: test with filter {'region': 'Europe'}


 70%|██████████████████████████████████████████████████████████▊                         | 21/30 [34:03<14:47, 98.63s/it]

Evaluating user split: test with filter {'region': 'Africa'}


 73%|█████████████████████████████████████████████████████████████▌                      | 22/30 [35:39<13:02, 97.78s/it]

Evaluating user split: test with filter {'region': 'Americas'}


 77%|████████████████████████████████████████████████████████████████▍                   | 23/30 [37:19<11:30, 98.60s/it]

Evaluating user split: test with filter {'region': 'Oceania'}


 80%|███████████████████████████████████████████████████████████████████▏                | 24/30 [38:53<09:42, 97.10s/it]

Evaluating user split: test with filter {'region': 'Europe', 'year': 2016}


 83%|██████████████████████████████████████████████████████████████████████              | 25/30 [40:30<08:06, 97.29s/it]

Evaluating user split: test with filter {'region': 'Europe', 'year': 2017}


 87%|████████████████████████████████████████████████████████████████████████▊           | 26/30 [42:05<06:25, 96.38s/it]

Evaluating user split: test with filter {'region': 'Asia', 'year': 2016}


 90%|███████████████████████████████████████████████████████████████████████████▌        | 27/30 [43:41<04:48, 96.24s/it]

Evaluating user split: test with filter {'region': 'Asia', 'year': 2017}


 93%|██████████████████████████████████████████████████████████████████████████████▍     | 28/30 [45:16<03:11, 95.84s/it]

Evaluating user split: test with filter {'region': 'Americas', 'year': 2016}


 97%|█████████████████████████████████████████████████████████████████████████████████▏  | 29/30 [46:55<01:36, 96.93s/it]

Evaluating user split: test with filter {'region': 'Americas', 'year': 2017}


100%|████████████████████████████████████████████████████████████████████████████████████| 30/30 [48:30<00:00, 97.01s/it]
