## Evaluating the suitability filter on FMoW

In [16]:
import importlib
import random

import numpy as np
import pandas as pd
import torch

from suitability.datasets.wilds import get_wilds_dataset, get_wilds_model
from suitability.filter import suitability

importlib.reload(suitability)

from suitability.filter.suitability import SuitabilityFilter

# Seed every RNG this notebook relies on so that repeated runs are comparable.
# torch is seeded as well because DataLoader shuffling (used below with
# shuffle=True) draws from torch's global generator, not from `random`/NumPy.
random.seed(32)
np.random.seed(32)
torch.manual_seed(32)

### Define & evaluate all possible splits

In [None]:
# Every candidate FMoW subset is a (split name, pre_filter dict) pair.
# The lists are generated instead of spelled out; ordering is: per-year
# filters, then per-region filters, then (OOD only) region x year filters.
_ALL_REGIONS = ("Asia", "Europe", "Africa", "Americas", "Oceania")
# Regions for which combined region+year OOD subsets are evaluated.
_CROSS_REGIONS = ("Europe", "Asia", "Americas")


def _enumerate_splits(split_name, years, with_cross=False):
    """Return the list of (split_name, pre_filter) pairs for one split.

    Yields one filter per year, one per region and, when ``with_cross`` is
    true, one per (cross region, year) combination, in that order.
    """
    by_year = [(split_name, {"year": yr}) for yr in years]
    by_region = [(split_name, {"region": reg}) for reg in _ALL_REGIONS]
    crossed = []
    if with_cross:
        crossed = [
            (split_name, {"region": reg, "year": yr})
            for reg in _CROSS_REGIONS
            for yr in years
        ]
    return by_year + by_region + crossed


# In-distribution splits cover the years 2002-2012.
id_val_splits = _enumerate_splits("id_val", range(2002, 2013))
id_test_splits = _enumerate_splits("id_test", range(2002, 2013))

# Out-of-distribution splits: "val" covers 2013-2015, "test" covers 2016-2017.
ood_val_splits = _enumerate_splits("val", (2013, 2014, 2015), with_cross=True)
ood_test_splits = _enumerate_splits("test", (2016, 2017), with_cross=True)

In [None]:
# Baseline sweep: for every (split, pre_filter) subset defined above, load the
# subset, measure the model's accuracy on it and record the subset size.
# One CSV row is written per subset.
data_name = "fmow"
root_dir = "/mfsnic/u/apouget/"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = get_wilds_model(data_name, root_dir, algorithm="ERM")
model = model.to(device)
model.eval()

all_splits = id_val_splits + id_test_splits + ood_val_splits + ood_test_splits

result_columns = ["split", "year", "region", "num_samples", "accuracy"]
# Rows are accumulated in a plain list and turned into a DataFrame once at the
# end. Growing a DataFrame with the private `_append` copies the whole frame
# on every call and relies on a non-public pandas API.
result_rows = []

for split, pre_filter in all_splits:
    data = get_wilds_dataset(
        data_name,
        root_dir,
        split,
        batch_size=64,
        shuffle=False,
        num_workers=4,
        pre_filter=pre_filter,
    )
    # The same loader is passed for both data arguments; only get_correct()
    # is used here, so no regressor is trained in this cell.
    suitability_filter = SuitabilityFilter(model, data, data, device)
    # presumably one correctness indicator per sample - its mean is reported
    # as the accuracy below; verify against SuitabilityFilter.get_correct
    corr = suitability_filter.get_correct(data)

    num_samples = len(data.dataset)
    accuracy = np.mean(corr)
    # A filter key that is absent means the subset is not restricted on it.
    year = pre_filter.get("year", "ALL")
    region = pre_filter.get("region", "ALL")

    result_rows.append(
        {
            "split": split,
            "year": year,
            "region": region,
            "num_samples": num_samples,
            "accuracy": accuracy,
        }
    )

# Passing `columns` keeps the column order fixed (and yields an empty frame
# with the right header if no split was evaluated).
results = pd.DataFrame(result_rows, columns=result_columns)

results.to_csv("suitability/results/data_splits/fmow_ERM_0_last.csv", index=False)

### Evaluate suitability filter

In [None]:
# Subsets used for the suitability-filter evaluation, again as
# (split name, pre_filter dict) pairs. Note the two value forms that are kept
# on purpose: ID filters and the single-key "val" filters carry list values,
# while the combined "val" filters and all "test" filters carry scalars.
_OOD_REGIONS = ("Asia", "Europe", "Africa", "Americas", "Oceania")
_OOD_CROSS_REGIONS = ("Europe", "Asia", "Americas")
_OOD_VAL_YEARS = (2013, 2014, 2015)
_OOD_TEST_YEARS = (2016, 2017)


def _id_filters():
    """Return fresh pre_filter dicts for the in-distribution subsets."""
    year_groups = (
        range(2002, 2007),
        range(2007, 2010),
        (2010,),
        (2011,),
        (2012,),
    )
    filters = [{"year": list(group)} for group in year_groups]
    filters.extend({"region": [name]} for name in ("Asia", "Europe", "Americas"))
    return filters


# The same set of filters is applied to "id_val" first, then to "id_test".
valid_id_splits = [
    (split_name, flt)
    for split_name in ("id_val", "id_test")
    for flt in _id_filters()
]

valid_ood_splits = (
    # "val": list-valued single-key filters ...
    [("val", {"year": [yr]}) for yr in _OOD_VAL_YEARS]
    + [("val", {"region": [name]}) for name in _OOD_REGIONS]
    # ... and scalar-valued region x year filters.
    + [
        ("val", {"region": name, "year": yr})
        for name in _OOD_CROSS_REGIONS
        for yr in _OOD_VAL_YEARS
    ]
    # "test": every filter is scalar-valued.
    + [("test", {"year": yr}) for yr in _OOD_TEST_YEARS]
    + [("test", {"region": name}) for name in _OOD_REGIONS]
    + [
        ("test", {"region": name, "year": yr})
        for name in _OOD_CROSS_REGIONS
        for yr in _OOD_TEST_YEARS
    ]
)

print(
    f"Number of valid id splits: {len(valid_id_splits)}, number of valid ood splits: {len(valid_ood_splits)}"
)

Number of valid id splits: 16, number of valid ood splits: 30


In [14]:
# Suitability-filter sweep: for every ordered pair of distinct ID subsets
# (one passed as regressor data, one as test data), train the regressor, then
# run the suitability test against each OOD "user" subset at every margin.
# One CSV row is written per (reg, test, user, margin) combination.
data_name = "fmow"
root_dir = "/mfsnic/u/apouget/"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

algorithm = "ERM"
model_type = "last"
seed = 0
model = get_wilds_model(
    data_name, root_dir, algorithm=algorithm, seed=seed, model_type=model_type
)
model = model.to(device)
model.eval()

normalize = True
calibrated = True
margins = [0, 0.001, 0.005, 0.01, 0.05, 0.1]

sf_eval_columns = [
    "data_name",
    "algorithm",
    "seed",
    "model_type",
    "normalize",
    "calibrated",
    "margin",
    "reg_split",
    "reg_filter",
    "reg_size",
    "reg_acc",
    "test_split",
    "test_filter",
    "test_size",
    "test_acc",
    "user_split",
    "user_filter",
    "user_size",
    "user_acc",
    "p_value",
    "ground_truth",
]
# Rows are collected in a list and converted to a DataFrame once after the
# sweep. The previous per-row `DataFrame._append` is a private pandas API,
# copies the whole frame on each call, and emits a FutureWarning when
# concatenating onto the initially empty frame. Column dtypes are now inferred
# from the data instead of all being `object`.
sf_eval_rows = []

# NOTE: the `[:2]` slices restrict this run to the first two ID subsets and
# the first two OOD subsets; drop them to sweep everything.
for i, reg in enumerate(valid_id_splits[:2]):
    reg_data = get_wilds_dataset(
        data_name,
        root_dir,
        reg[0],
        batch_size=64,
        shuffle=True,
        num_workers=4,
        pre_filter=reg[1],
    )
    reg_size = len(reg_data.dataset)

    for j, test in enumerate(valid_id_splits[:2]):
        # Never use the same subset as both regressor and test data.
        if i == j:
            continue

        test_data = get_wilds_dataset(
            data_name,
            root_dir,
            test[0],
            batch_size=64,
            shuffle=False,
            num_workers=4,
            pre_filter=test[1],
        )
        test_size = len(test_data.dataset)

        suitability_filter = SuitabilityFilter(
            model, reg_data, test_data, device, normalize=normalize
        )
        suitability_filter.train_regressor(calibrated=calibrated)
        reg_acc = np.mean(suitability_filter.regressor_correct)

        for user in valid_ood_splits[:2]:
            user_data = get_wilds_dataset(
                data_name,
                root_dir,
                user[0],
                batch_size=64,
                shuffle=False,
                num_workers=4,
                pre_filter=user[1],
            )
            user_size = len(user_data.dataset)
            user_features, user_corr = suitability_filter.get_features(user_data)

            user_acc = np.mean(user_corr)

            for margin in margins:
                sf_test = suitability_filter.suitability_test(
                    user_features=user_features, margin=margin
                )
                # NOTE(review): test_correct is read only after
                # suitability_test() has run; kept inside the loop in case the
                # attribute is populated by that call - confirm before hoisting.
                test_acc = np.mean(suitability_filter.test_correct)
                p_value = sf_test["p_value"]
                # Ground truth: user accuracy is within `margin` of (or above)
                # the test accuracy.
                ground_truth = user_acc >= test_acc - margin

                sf_eval_rows.append(
                    {
                        "data_name": data_name,
                        "algorithm": algorithm,
                        "seed": seed,
                        "model_type": model_type,
                        "normalize": normalize,
                        "calibrated": calibrated,
                        "margin": margin,
                        "reg_split": reg[0],
                        "reg_filter": reg[1],
                        "reg_size": reg_size,
                        "reg_acc": reg_acc,
                        "test_split": test[0],
                        "test_filter": test[1],
                        "test_size": test_size,
                        "test_acc": test_acc,
                        "user_split": user[0],
                        "user_filter": user[1],
                        "user_size": user_size,
                        "user_acc": user_acc,
                        "p_value": p_value,
                        "ground_truth": ground_truth,
                    }
                )

# Passing `columns` fixes the column order and keeps the header even when no
# row was produced.
sf_evals = pd.DataFrame(sf_eval_rows, columns=sf_eval_columns)

sf_evals.to_csv("suitability/results/sf_evals/fmow_sf_results_1.csv", index=False)

  sf_evals = sf_evals._append(
