In [1]:
import importlib
import random

import numpy as np
import pandas as pd
import torch
from sklearn.metrics import auc, precision_recall_curve, roc_curve
from torch.utils.data import DataLoader, random_split

from suitability.datasets.wilds import get_wilds_dataset, get_wilds_model
from suitability.filter import suitability

importlib.reload(suitability)

from suitability.filter.evals import split_dataset_into_folds
from suitability.filter.suitability import SuitabilityFilter

# Seed every RNG the notebook draws from so a fresh "Restart & Run All"
# reproduces the same numbers.
random.seed(32)
np.random.seed(32)
# torch's global generator drives `random_split` and the shuffling DataLoaders
# used further down; without this the test/regressor partition changes per run.
torch.manual_seed(32)

In [None]:
# Root directory handed to the WILDS dataset/model loaders below.
# NOTE(review): absolute, machine-specific path - adjust before running elsewhere.
root_dir = "/mfsnic/u/apouget/"

# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Datasets to sweep ("iwildcam" is currently left out).
datasets_to_test = ["fmow", "civilcomments", "rxrx1", "amazon"]
# Prefixes prepended to "val" / "test" when requesting a dataset split.
splits_to_test = ["id_", ""]
# Spacing of the target-accuracy grid built around the test accuracy.
step_size = 0.01
# Number of repeated fold splits of the user data per dataset/split.
num_random_tries = 100

# Empty result tables; one row per (dataset, split) and per
# (dataset, split, feature, type) respectively.
sf_result_columns = ["Dataset", "Split", "ROC AUC", "PR AUC"]
feat_result_columns = ["Dataset", "Split", "Feature", "Type", "ROC AUC", "PR AUC"]
df_sf = pd.DataFrame(columns=sf_result_columns)
df_feat = pd.DataFrame(columns=feat_result_columns)


def compute_roc_pr_auc(p_vals, ground_truth):
    """Score how well `p_vals` rank the binary `ground_truth` labels.

    Both inputs are flattened before scoring, so they may have any (matching)
    shape. Higher values in `p_vals` are treated as evidence for the positive
    class, which is why callers negate p-values where needed.

    Returns:
        (roc_auc, pr_auc): area under the ROC curve and area under the
        precision-recall curve, both computed with sklearn's `auc`.
    """
    labels = ground_truth.flatten()
    scores = p_vals.flatten()

    # roc_curve returns (fpr, tpr, thresholds); only the curve itself is needed.
    fpr, tpr, _ = roc_curve(labels, scores)
    roc_auc = auc(fpr, tpr)

    precision, recall, _ = precision_recall_curve(labels, scores)
    pr_auc = auc(recall, precision)

    return roc_auc, pr_auc


# Display names for the per-feature rows, indexed by feature position.
# NOTE(review): assumed to follow the column order of the feature arrays
# returned by SuitabilityFilter.get_features - confirm against the filter.
features = np.array(
    [
        "Conf Max",
        "Conf Std",
        "Conf Entropy",
        "Logit Mean",
        "Logit Max",
        "Logit Std",
        "Logit Top 2 Diff",
        "Loss",
        "Margin Loss",
        "Class Conf Ratio",
        "Top-k Conf Sum",
        "Energy",
    ]
)
# Feature positions whose individual-test p-values are negated before being
# scored; every other feature's p-values are scored as-is.
# NOTE(review): presumably tied to the direction of each one-sided test - confirm.
negated_feature_indices = {0, 1, 4, 6, 9, 10}

# Result rows are collected in plain lists and turned into DataFrames once,
# after the loop. Growing the frames with pd.concat on every iteration is
# quadratic and concatenating onto the initially empty frame emits a pandas
# warning. If the run is interrupted, the partial results are in these lists.
sf_records = []
feat_records = []

for data_name in datasets_to_test:
    for split in splits_to_test:
        # This dataset/split combination is deliberately skipped.
        if data_name == "civilcomments" and split == "id_":
            continue
        print(f"Testing {data_name} for split {split}")

        # Get data: the validation split is halved at random into a "test"
        # part and a part used to fit the filter's regressor; the test split
        # plays the role of the user data.
        dataset_val = get_wilds_dataset(
            data_name,
            root_dir,
            split + "val",
            batch_size=64,
            shuffle=False,
            num_workers=4,
        ).dataset
        test_subset, regressor_subset = random_split(dataset_val, [0.5, 0.5])
        test_data = DataLoader(
            test_subset, batch_size=64, shuffle=False, num_workers=4
        )
        regressor_data = DataLoader(
            regressor_subset, batch_size=64, shuffle=True, num_workers=4
        )
        user_data = get_wilds_dataset(
            data_name,
            root_dir,
            split + "test",
            batch_size=64,
            shuffle=True,
            num_workers=4,
        )

        print(
            f"Test size: {len(test_data.dataset)}, Regressor size: {len(regressor_data.dataset)}, User data size: {len(user_data.dataset)}"
        )

        # Get model
        model = get_wilds_model(data_name, root_dir, algorithm="ERM")
        model = model.to(device)
        model.eval()

        # Construct suitability filter
        suitability_filter = SuitabilityFilter(model, test_data, regressor_data, device)
        test_features, test_corr = suitability_filter.get_features(test_data)
        suitability_filter.train_regressor()
        all_user_features, all_user_corr = suitability_filter.get_features(user_data)

        # Evaluate suitability on a grid of target accuracies centred around
        # the accuracy measured on the test half.
        test_acc = np.mean(test_corr)
        target_accuracies = np.arange(
            test_acc - step_size / 2 - 4 * step_size,
            test_acc + step_size / 2 + 5 * step_size,
            step_size,
        )
        num_acc_folds = len(target_accuracies)
        num_features = np.shape(all_user_features)[1]

        corrs = np.zeros((num_random_tries, num_acc_folds))
        p_vals_sf = np.zeros((num_random_tries, num_acc_folds))
        p_vals_feat = np.zeros((num_features, num_random_tries, num_acc_folds))
        p_vals_feat_reg = np.zeros((num_features, num_random_tries, num_acc_folds))

        for j in range(num_random_tries):
            folds, actual_accuracies = split_dataset_into_folds(
                all_user_corr, target_accuracies=target_accuracies
            )
            corrs[j] = np.array(actual_accuracies)

            for i, fold_indices in enumerate(folds):
                user_features = all_user_features[fold_indices]
                p_vals_sf[j, i] = suitability_filter.suitability_test(
                    user_features=user_features, margin=0
                )["p_value"]
                feat_tests = suitability_filter.suitability_test_for_individual_features(
                    user_features=user_features
                )
                # `feat_result` (not `test`) so the split created above is not shadowed.
                for fi, feat_result in enumerate(feat_tests):
                    p_vals_feat[fi, j, i] = feat_result["p_value"]
                    p_vals_feat_reg[fi, j, i] = (
                        suitability_filter.suitability_test_for_feature_subset(
                            feature_subset=[fi], user_features=user_features
                        )["p_value"]
                    )

        # Compute AUCs: a fold counts as positive when its actual accuracy
        # reaches the test accuracy.
        ground_truth = corrs >= test_acc

        # Compute and store suitability filter AUC values
        roc_auc_sf, pr_auc_sf = compute_roc_pr_auc(-p_vals_sf, ground_truth)
        sf_records.append([data_name, split, roc_auc_sf, pr_auc_sf])

        # Compute AUCs for each feature: first the direct per-feature test,
        # then the single-feature regressor variant.
        for fi, feature in enumerate(features):
            if fi in negated_feature_indices:
                feat_scores = -p_vals_feat[fi, :, :]
            else:
                feat_scores = p_vals_feat[fi, :, :]
            roc_auc_feat, pr_auc_feat = compute_roc_pr_auc(feat_scores, ground_truth)
            feat_records.append(
                [data_name, split, feature, "Feature", roc_auc_feat, pr_auc_feat]
            )

            roc_auc_feat_reg, pr_auc_feat_reg = compute_roc_pr_auc(
                -p_vals_feat_reg[fi, :, :], ground_truth
            )
            feat_records.append(
                [
                    data_name,
                    split,
                    feature,
                    "Feature (Reg)",
                    roc_auc_feat_reg,
                    pr_auc_feat_reg,
                ]
            )

# Materialise the result tables once, in the order the rows were produced.
df_sf = pd.DataFrame(sf_records, columns=["Dataset", "Split", "ROC AUC", "PR AUC"])
df_feat = pd.DataFrame(
    feat_records,
    columns=["Dataset", "Split", "Feature", "Type", "ROC AUC", "PR AUC"],
)

# Save DataFrames to CSV (optional)
df_sf.to_csv("suitability/results/wilds_suitability_filter_auc.csv", index=False)
df_feat.to_csv("suitability/results/wilds_feature_auc.csv", index=False)

print(df_sf)
print(df_feat)

Testing fmow for split id_
Test size: 5742, Regressor size: 5741, User data size: 11327


  df_sf = pd.concat(
  df_feat = pd.concat(


Testing fmow for split 
Test size: 9958, Regressor size: 9957, User data size: 22108


