In [1]:
# Automatically reload edited project modules (e.g. `counterfactuals`) before executing code.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import numpy as np

from counterfactuals.datasets import LawDataset, AdultDataset, GermanCreditDataset
from counterfactuals.cf_methods.ppcef import PPCEF
from counterfactuals.generative_models import MaskedAutoregressiveFlow
from counterfactuals.discriminative_models import MultilayerPerceptron
from counterfactuals.losses import MulticlassDiscLoss
from counterfactuals.metrics import evaluate_cf

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Which experiment to run; must be a key of `datasets` below.
DATASET_NAME = "law"

# name -> (dataset, classifier checkpoint, flow checkpoint, evaluation-flow checkpoint).
# Every entry must carry the same four fields so the unpacking below works for any
# DATASET_NAME; the *_flow_eval.pth names follow the "law" naming convention.
# NOTE: all three datasets are constructed here, so (presumably) every CSV must exist.
datasets = {
    "adult": (
        AdultDataset("../data/adult.csv"),
        "adult_disc_model.pt",
        "adult_flow.pth",
        "adult_flow_eval.pth",
    ),
    "law": (
        LawDataset("../data/law.csv"),
        "law_disc_model.pt",
        "law_flow.pth",
        "law_flow_eval.pth",
    ),
    "german": (
        GermanCreditDataset("../data/german_credit.csv"),
        "german_disc_model.pt",
        "german_flow.pth",
        "german_flow_eval.pth",
    ),
}

dataset, disc_model_path, gen_model_path, gen_model_eval_path = datasets[DATASET_NAME]

In [31]:
from counterfactuals.datasets.utils import CustomCategoricalTransformer
from sklearn.compose import ColumnTransformer


def dequantize(dataset, data=None, transformer=None):
    """
    Apply dequantization to the categorical features only.

    One ``CustomCategoricalTransformer`` is created per group in
    ``dataset.categorical_features_lists`` and registered in a
    ``ColumnTransformer`` under the name ``"cat_group_<i>"``. Continuous
    columns are never transformed.

    Behaviour depends on which arguments are supplied:

    * ``transformer`` and ``data`` both None: build a new transformer, fit it
      on ``dataset.X_train`` and assign NEW arrays to ``dataset.X_train`` and
      ``dataset.X_test`` whose categorical columns hold the transformed values.
      Calling this mode again on the same ``dataset`` re-fits on, and
      re-transforms, the already-dequantized arrays.
    * ``transformer`` None, ``data`` given: build a new transformer, fit it on
      ``dataset.X_train`` (``dataset`` is not modified) and transform ``data``.
    * ``transformer`` given: use it as-is (no refitting) to transform ``data``;
      nothing happens when ``data`` is None.

    Parameters
    ----------
    dataset : Dataset object
        Provides ``categorical_features_lists`` (groups of column indices) and,
        when fitting, ``X_train`` / ``X_test``.
    data : np.ndarray, optional
        External 2-D array to transform. A copy is transformed; ``data`` itself
        is not modified.
    transformer : ColumnTransformer, optional
        A previously fitted transformer, i.e. the FIRST element of the tuple
        returned by an earlier call.

    Returns
    -------
    tuple
        Always ``(transformer, transformed_data)``; ``transformed_data`` is
        None when ``data`` is None. Unpack it, e.g.
        ``dequantizer, _ = dequantize(dataset)``.

    Raises
    ------
    TypeError
        If ``transformer`` has no ``named_transformers_`` while ``data`` is
        given (typically because the whole returned tuple was passed instead
        of its first element).
    """
    if transformer is None:
        transformer = ColumnTransformer(
            transformers=[
                (f"cat_group_{i}", CustomCategoricalTransformer(), group)
                for i, group in enumerate(dataset.categorical_features_lists)
            ],
            remainder="drop",  # continuous columns are not part of the output
        )

        if data is None:
            # Snapshot the pre-fit values; the results are written into these
            # copies, never into the arrays currently held by `dataset`.
            X_train_new = dataset.X_train.copy()
            X_test_new = dataset.X_test.copy()

            cat_train = transformer.fit_transform(dataset.X_train)
            cat_test = transformer.transform(dataset.X_test)

            # ColumnTransformer stacks the groups' outputs in registration
            # order, so output column k belongs to the k-th index of the
            # flattened groups.
            cat_columns = [
                idx for group in dataset.categorical_features_lists for idx in group
            ]
            for k, col in enumerate(cat_columns):
                X_train_new[:, col] = cat_train[:, k]
                X_test_new[:, col] = cat_test[:, k]

            # Swap in the new arrays only once every column has been written.
            dataset.X_train = X_train_new
            dataset.X_test = X_test_new
            return transformer, None

        # Fit on the dataset's training split; `data` is transformed below.
        transformer.fit(dataset.X_train)
    elif data is not None and not hasattr(transformer, "named_transformers_"):
        raise TypeError(
            "`transformer` must be a fitted ColumnTransformer (the first element "
            f"returned by dequantize), got {type(transformer).__name__}"
        )

    if data is None:
        # A transformer was supplied but there is nothing to transform.
        return transformer, None

    transformed = data.copy()
    for i, group in enumerate(dataset.categorical_features_lists):
        group_out = transformer.named_transformers_[f"cat_group_{i}"].transform(
            transformed[:, group]
        )
        # Overwrite only this group's columns; continuous columns are kept.
        for j, col in enumerate(group):
            transformed[:, col] = group_out[:, j]
    return transformer, transformed


def inverse_dequantize(dataset, dequantizer, data=None):
    """
    Undo dequantization on the categorical columns; other columns are copied as-is.

    Each group in ``dataset.categorical_features_lists`` is passed through
    ``dequantizer.named_transformers_["cat_group_<i>"].inverse_transform`` and
    the result is written back column by column.

    Parameters
    ----------
    dataset : Dataset object
        Supplies ``categorical_features_lists`` and, when ``data`` is None,
        the ``X_train`` / ``X_test`` arrays to invert.
    dequantizer : ColumnTransformer
        The fitted transformer (first element returned by ``dequantize``).
    data : np.ndarray, optional
        External array to invert. A copy is inverted; ``data`` is not modified.

    Returns
    -------
    np.ndarray or None
        The inverted copy when ``data`` is given. Otherwise None, after
        assigning inverted copies to ``dataset.X_train`` and ``dataset.X_test``.
    """

    def _invert(matrix):
        # Operate on a copy so the caller's array is never modified.
        restored_matrix = matrix.copy()
        for group_no, columns in enumerate(dataset.categorical_features_lists):
            step = dequantizer.named_transformers_[f"cat_group_{group_no}"]
            restored = step.inverse_transform(restored_matrix[:, columns])
            for pos, column in enumerate(columns):
                restored_matrix[:, column] = restored[:, pos]
        return restored_matrix

    if data is not None:
        return _invert(data)

    train_restored = _invert(dataset.X_train)
    test_restored = _invert(dataset.X_test)
    dataset.X_train = train_restored
    dataset.X_test = test_restored
    return None

In [5]:
# Set to True to retrain the classifier and overwrite `disc_model_path`;
# otherwise the saved checkpoint is loaded.
TRAIN_DISC_MODEL = False

# Presumably (input size, hidden layer sizes, number of classes) — confirm
# against MultilayerPerceptron's signature.
disc_model = MultilayerPerceptron(dataset.X_test.shape[1], [256, 256], 2)

if TRAIN_DISC_MODEL:
    disc_model.fit(
        dataset.train_dataloader(batch_size=128, shuffle=True),
        dataset.test_dataloader(batch_size=128, shuffle=False),
        epochs=5000,
        patience=100,
        lr=1e-3,
        checkpoint_path=disc_model_path,
    )
disc_model.load(disc_model_path)

  self.load_state_dict(torch.load(path))


In [6]:
y_pred = disc_model.predict(dataset.X_test).detach().numpy().flatten()
# Flatten the labels as well: if y_test were shaped (n, 1), comparing it with the
# 1-D predictions would broadcast to an (n, n) matrix and yield a meaningless score.
test_accuracy = (y_pred == np.ravel(dataset.y_test)).mean()
print("Test accuracy:", test_accuracy)

Test accuracy: 0.7454954954954955


In [7]:
# Replace both label splits with the classifier's own predictions, so downstream
# cells reason about what the model predicts rather than the ground truth.
for split in ("train", "test"):
    split_features = getattr(dataset, f"X_{split}")
    split_predictions = disc_model.predict(split_features).detach().numpy()
    setattr(dataset, f"y_{split}", split_predictions)

In [8]:
# Set to True to refit this flow and overwrite `gen_model_eval_path`;
# otherwise the saved checkpoint is loaded.
RETRAIN_GEN_MODEL_EVAL = False

# NOTE(review): gen_model_eval is not used by any later cell visible here.
gen_model_eval = MaskedAutoregressiveFlow(
    features=dataset.X_train.shape[1],
    hidden_features=16,
    num_blocks_per_layer=4,
    num_layers=8,
    context_features=1,  # presumably the class label as conditioning context
    batch_norm_within_layers=True,
    batch_norm_between_layers=True,
    use_random_permutations=True,
)
train_dataloader = dataset.train_dataloader(
    batch_size=256, shuffle=True, noise_lvl=0.03
)
test_dataloader = dataset.test_dataloader(batch_size=256, shuffle=False)

if RETRAIN_GEN_MODEL_EVAL:
    gen_model_eval.fit(
        train_dataloader,
        test_dataloader,  # held-out loader for the test loss / patience, not the training loader
        learning_rate=1e-3,
        patience=100,
        num_epochs=500,
        checkpoint_path=gen_model_eval_path,
    )
gen_model_eval.load(gen_model_eval_path)

  self.load_state_dict(torch.load(path))


In [9]:
# `dequantize` returns a (transformer, data) pair; with no `data` argument it fits
# on dataset.X_train and assigns dequantized arrays to dataset.X_train / X_test.
# Keep only the fitted transformer — later cells call `.named_transformers_` on it.
# NOTE: not idempotent; re-running this cell fits on the already-dequantized arrays.
dequantizer, _ = dequantize(dataset)

In [13]:
# Set to True to refit the flow used by PPCEF and overwrite `gen_model_path`;
# otherwise the saved checkpoint is loaded.
RETRAIN_GEN_MODEL = False

gen_model = MaskedAutoregressiveFlow(
    features=dataset.X_train.shape[1],
    hidden_features=16,
    num_blocks_per_layer=4,
    num_layers=8,
    context_features=1,  # presumably the class label as conditioning context
    batch_norm_within_layers=True,
    batch_norm_between_layers=True,
    use_random_permutations=True,
)
train_dataloader = dataset.train_dataloader(
    batch_size=256, shuffle=True, noise_lvl=0.03
)
test_dataloader = dataset.test_dataloader(batch_size=256, shuffle=False)

if RETRAIN_GEN_MODEL:
    gen_model.fit(
        train_dataloader,
        test_dataloader,  # held-out loader for the test loss / patience, not the training loader
        learning_rate=1e-3,
        patience=100,
        num_epochs=500,
        checkpoint_path=gen_model_path,
    )
gen_model.load(gen_model_path)

Epoch 499, Train: 2.7664, test: 2.3346, patience: 39: 100%|██████████| 500/500 [01:19<00:00,  6.28it/s]  
  self.load_state_dict(torch.load(path))


In [14]:
cf = PPCEF(
    gen_model=gen_model,
    disc_model=disc_model,
    disc_model_criterion=MulticlassDiscLoss(),
    neptune_run=None,
)

# Explain only the test points whose (predicted) label differs from target_class.
target_class = 0
origin_mask = dataset.y_test != target_class
X_test_origin = dataset.X_test[origin_mask]
y_test_origin = dataset.y_test[origin_mask]

origin_tensors = torch.utils.data.TensorDataset(
    torch.tensor(X_test_origin).float(),
    torch.tensor(y_test_origin).float(),
)
cf_dataloader = torch.utils.data.DataLoader(
    origin_tensors, batch_size=1024, shuffle=False
)

# Plausibility threshold: 25th percentile of the flow's log-density over the points
# being explained.
# NOTE(review): the flow uses batch norm — confirm predict_log_prob runs in eval mode.
origin_log_probs = gen_model.predict_log_prob(cf_dataloader)
log_prob_threshold = torch.quantile(origin_log_probs, 0.25)

deltas, X_orig, y_orig, y_target, logs = cf.explain_dataloader(
    cf_dataloader,
    alpha=100,
    log_prob_threshold=log_prob_threshold,
    epochs=4000,
    lr=0.001,
    categorical_intervals=dataset.categorical_features_lists,
)
log_prob_threshold

Discriminator loss: 0.0000, Prob loss: 0.0144: 100%|██████████| 4000/4000 [00:33<00:00, 120.08it/s]


tensor(-5.7548)

In [46]:
X_cf = X_orig + deltas
X_cf_cat = X_cf.copy()

# Snap each categorical group to a one-hot vector: the largest coordinate in the
# group becomes 1 and all others become 0.
for interval in dataset.categorical_features_lists:
    group_block = X_cf_cat[:, interval]
    winners = group_block.argmax(axis=1)
    X_cf_cat[:, interval] = np.eye(group_block.shape[1])[winners]

In [155]:
# for categorical_features, transform in zip(
#         dataset.categorical_features_lists, dequantizer.named_transformers_
#     ):

#     X_cf[:, categorical_features] = dequantizer.named_transformers_[
#         transform
#     ].inverse_transform(X_cf[:, list(range(len(categorical_features)))])

In [47]:
# Map the snapped one-hot counterfactuals back into the dequantized representation
# used for dataset.X_train / X_test, reusing the fitted dequantizer (no refit).
X_cf_q = dequantize(dataset, data=X_cf_cat, transformer=dequantizer)[1]

In [48]:
# NOTE(review): `median_log_prob` receives `log_prob_threshold`, which is the 25th
# percentile of the origin points' log-density, not a median — confirm intent.
n_counterfactuals = X_cf_cat.shape[0]
cf_metrics = evaluate_cf(
    disc_model=disc_model,
    gen_model=gen_model,
    X_cf=X_cf_q,
    model_returned=np.ones(n_counterfactuals),  # presumably: every row counted as returned
    continuous_features=dataset.numerical_features,
    categorical_features=dataset.categorical_features,
    X_train=dataset.X_train,
    y_train=dataset.y_train,
    X_test=X_orig,
    y_test=y_orig,
    median_log_prob=log_prob_threshold,
    y_target=y_target,
)
cf_metrics

2025-04-19 11:51:24,347 - counterfactuals.metrics.distances - INFO - Calculating combined distance
2025-04-19 11:51:24,348 - counterfactuals.metrics.distances - INFO - Calculating continuous distance
2025-04-19 11:51:24,348 - counterfactuals.metrics.distances - INFO - Calculating categorical distance
2025-04-19 11:51:24,349 - counterfactuals.metrics.distances - INFO - Calculating combined distance
2025-04-19 11:51:24,349 - counterfactuals.metrics.distances - INFO - Calculating continuous distance
2025-04-19 11:51:24,349 - counterfactuals.metrics.distances - INFO - Calculating categorical distance
2025-04-19 11:51:24,350 - counterfactuals.metrics.distances - INFO - Calculating combined distance
2025-04-19 11:51:24,350 - counterfactuals.metrics.distances - INFO - Calculating continuous distance
2025-04-19 11:51:24,351 - counterfactuals.metrics.distances - INFO - Calculating categorical distance
2025-04-19 11:51:24,352 - counterfactuals.metrics.distances - INFO - Calculating combined dist

{'coverage': 1.0,
 'validity': 0.9151785714285714,
 'actionability': 0.0,
 'sparsity': 0.9962225274725275,
 'proximity_categorical_hamming': 0.8312443619784518,
 'proximity_categorical_jaccard': 0.8312443619784518,
 'proximity_continuous_manhattan': 0.8737806974220007,
 'proximity_continuous_euclidean': 0.8312443619784518,
 'proximity_continuous_mad': 1.6685796865189577,
 'proximity_l2_jaccard': 0.8312443619784518,
 'proximity_mad_hamming': 1.6685796865189577,
 'prob_plausibility': 0.7901785714285714,
 'log_density_cf': -3.6842823,
 'log_density_test': -10.425357,
 'lof_scores_cf': 1.0816156,
 'lof_scores_test': 1.0943108,
 'isolation_forest_scores_cf': 0.07907068300191541,
 'isolation_forest_scores_test': 0.07398354269553695}