In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import numpy as np

from counterfactuals.datasets import LawDataset, AdultDataset, GermanCreditDataset
from counterfactuals.cf_methods.ppcef import PPCEF
from counterfactuals.generative_models import MaskedAutoregressiveFlow
from counterfactuals.discriminative_models import MultilayerPerceptron
from counterfactuals.losses import MulticlassDiscLoss
from counterfactuals.metrics import evaluate_cf
from counterfactuals.datasets.utils import (
    dequantize,
)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
datasets = {
    "adult": (
        AdultDataset("../data/adult.csv"),
        "adult_disc_model.pt",
        "adult_flow.pth",
    ),
    "law": (
        LawDataset("../data/law.csv"),
        "law_disc_model.pt",
        "law_flow.pth",
    ),
    "german": (
        GermanCreditDataset("../data/german_credit.csv"),
        "german_disc_model.pt",
        "german_flow.pth",
    ),
}

dataset, disc_model_path, gen_model_path = datasets["adult"]

In [4]:
# disc_model = MultilayerPerceptron(dataset.X_test.shape[1], [512, 512], 2)
disc_model = MultilayerPerceptron(dataset.X_test.shape[1], [256, 256], 2)
# disc_model.fit(
#     dataset.train_dataloader(batch_size=128, shuffle=True),
#     dataset.test_dataloader(batch_size=128, shuffle=False),
#     epochs=5000,
#     patience=100,
#     lr=1e-3,
#     checkpoint_path=disc_model_path,
# )
disc_model.load(disc_model_path)

  self.load_state_dict(torch.load(path))


In [5]:
y_pred = disc_model.predict(dataset.X_test).detach().numpy().flatten()
print("Test accuracy:", (y_pred == dataset.y_test).mean())

Test accuracy: 0.8209734377399048


In [6]:
dataset.y_train = disc_model.predict(dataset.X_train).detach().numpy()
dataset.y_test = disc_model.predict(dataset.X_test).detach().numpy()

In [7]:
gen_model = MaskedAutoregressiveFlow(
    features=dataset.X_train.shape[1],
    hidden_features=16,
    num_blocks_per_layer=4,
    num_layers=8,
    context_features=1,
    batch_norm_within_layers=True,
    batch_norm_between_layers=True,
    use_random_permutations=True,
)
train_dataloader = dataset.train_dataloader(
    batch_size=256, shuffle=True, noise_lvl=0.03
)
test_dataloader = dataset.test_dataloader(batch_size=256, shuffle=False)

# gen_model.fit(
#     train_dataloader,
#     train_dataloader,
#     learning_rate=1e-3,
#     patience=100,
#     num_epochs=500,
#     checkpoint_path=gen_model_path,
# )
gen_model.load(gen_model_path)

  self.load_state_dict(torch.load(path))


In [8]:
dequantizer, _ = dequantize(dataset)

In [9]:
cf = PPCEF(
    gen_model=gen_model,
    disc_model=disc_model,
    disc_model_criterion=MulticlassDiscLoss(),
    neptune_run=None,
)

target_class = 0
X_test_origin = dataset.X_test[dataset.y_test != target_class]
y_test_origin = dataset.y_test[dataset.y_test != target_class]

cf_dataloader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(
        torch.tensor(X_test_origin).float(),
        torch.tensor(y_test_origin).float(),
    ),
    batch_size=1024,
    shuffle=False,
)


log_prob_threshold = torch.quantile(gen_model.predict_log_prob(cf_dataloader), 0.25)
dataset = AdultDataset("../data/adult.csv")
X_test_origin = dataset.X_test[dataset.y_test != target_class]
y_test_origin = dataset.y_test[dataset.y_test != target_class]

cf_dataloader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(
        torch.tensor(X_test_origin).float(),
        torch.tensor(y_test_origin).float(),
    ),
    batch_size=1024,
    shuffle=False,
)


deltas, X_orig, y_orig, y_target, logs = cf.explain_dataloader(
    cf_dataloader,
    alpha=100,
    log_prob_threshold=log_prob_threshold,
    epochs=10000,
    lr=0.001,
    categorical_intervals=dataset.categorical_features_lists,
)
log_prob_threshold

Discriminator loss: 0.0509, Prob loss: 690070.7500: 100%|██████████| 10000/10000 [02:12<00:00, 75.37it/s]    


tensor(-15.4933)

In [10]:
X_cf = X_orig + deltas
X_cf_cat = X_cf.copy()

for interval in dataset.categorical_features_lists:
    max_indices = np.argmax(X_cf_cat[:, interval], axis=1)
    X_cf_cat[:, interval] = np.eye(X_cf_cat[:, interval].shape[1])[max_indices]

In [11]:
# for categorical_features, transform in zip(
#         dataset.categorical_features_lists, dequantizer.named_transformers_
#     ):

#     X_cf[:, categorical_features] = dequantizer.named_transformers_[
#         transform
#     ].inverse_transform(X_cf[:, list(range(len(categorical_features)))])

In [12]:
# X_cf_deq = inverse_dequantize(dataset, dequantizer, X_cf_cat)
dequantizer, _ = dequantize(dataset)
_, X_cf_q = dequantize(dataset, X_cf_cat, dequantizer)
dataset = AdultDataset("../data/adult.csv")

In [13]:
import torch.nn as nn


class DequantizingFlow(nn.Module):
    def __init__(self, gen_model, dequantizer, dataset):
        super().__init__()
        self.gen_model = gen_model
        self.dequantizer = dequantizer
        self.dequantize = dequantize
        self.dataset = dataset

    def forward(self, X, y):
        if isinstance(X, torch.Tensor):
            X = X.numpy()
        _, X = self.dequantize(self.dataset, X, self.dequantizer)
        X = torch.from_numpy(X)
        log_probs = self.gen_model(X, y)
        return log_probs


dequantizing_flow = DequantizingFlow(gen_model, dequantizer, dataset)

In [14]:
evaluate_cf(
    disc_model=disc_model,
    gen_model=dequantizing_flow,
    X_cf=X_cf_cat,
    model_returned=np.ones(X_cf_cat.shape[0]),
    continuous_features=dataset.numerical_features,
    categorical_features=dataset.categorical_features,
    X_train=dataset.X_train,
    y_train=dataset.y_train,
    X_test=X_orig,
    y_test=y_orig,
    median_log_prob=log_prob_threshold,
    y_target=y_target,
)

2025-04-21 15:23:32,805 - counterfactuals.metrics.distances - INFO - Calculating combined distance
2025-04-21 15:23:32,806 - counterfactuals.metrics.distances - INFO - Calculating continuous distance
2025-04-21 15:23:32,806 - counterfactuals.metrics.distances - INFO - Calculating categorical distance
2025-04-21 15:23:32,806 - counterfactuals.metrics.distances - INFO - Calculating combined distance
2025-04-21 15:23:32,807 - counterfactuals.metrics.distances - INFO - Calculating continuous distance
2025-04-21 15:23:32,807 - counterfactuals.metrics.distances - INFO - Calculating categorical distance
2025-04-21 15:23:32,807 - counterfactuals.metrics.distances - INFO - Calculating combined distance
2025-04-21 15:23:32,808 - counterfactuals.metrics.distances - INFO - Calculating continuous distance
2025-04-21 15:23:32,808 - counterfactuals.metrics.distances - INFO - Calculating categorical distance
2025-04-21 15:23:32,808 - counterfactuals.metrics.distances - INFO - Calculating combined dist

{'coverage': 1.0,
 'validity': 0.4032258064516129,
 'actionability': 0.0,
 'sparsity': 0.07675194660734148,
 'proximity_categorical_hamming': 0.031169021221318545,
 'proximity_categorical_jaccard': 0.06003601629521017,
 'proximity_continuous_manhattan': 0.06750814141311082,
 'proximity_continuous_euclidean': 0.06003601629521017,
 'proximity_continuous_mad': 0.6443739354551719,
 'proximity_l2_jaccard': 0.06003601629521017,
 'proximity_mad_hamming': 0.6155069403812803,
 'prob_plausibility': 0.3870967741935484,
 'log_density_cf': -116.41296,
 'log_density_test': -102.39931,
 'lof_scores_cf': 1.2957588,
 'lof_scores_test': 1.196811,
 'isolation_forest_scores_cf': 0.025560530669840622,
 'isolation_forest_scores_test': 0.04085840215068121}