In [18]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [51]:
import torch
import numpy as np
import pandas as pd

from counterfactuals.datasets import LawDataset, AdultDataset, GermanCreditDataset
from counterfactuals.cf_methods.ppcef import PPCEF
from counterfactuals.generative_models import MaskedAutoregressiveFlow
from counterfactuals.discriminative_models import MultilayerPerceptron
from counterfactuals.losses import MulticlassDiscLoss
from counterfactuals.metrics import evaluate_cf
from counterfactuals.datasets.utils import (
    dequantize,
)

In [20]:
datasets = {
    "adult": (
        AdultDataset("../data/adult.csv"),
        "adult_disc_model.pt",
        "adult_flow.pth",
    ),
    "law": (
        LawDataset("../data/law.csv"),
        "law_disc_model.pt",
        "law_flow.pth",
    ),
    "german": (
        GermanCreditDataset("../data/german_credit.csv"),
        "german_disc_model.pt",
        "german_flow.pth",
    ),
}

dataset, disc_model_path, gen_model_path = datasets["adult"]

In [21]:
# disc_model = MultilayerPerceptron(dataset.X_test.shape[1], [512, 512], 2)
disc_model = MultilayerPerceptron(dataset.X_test.shape[1], [256, 256], 2)
# disc_model.fit(
#     dataset.train_dataloader(batch_size=128, shuffle=True),
#     dataset.test_dataloader(batch_size=128, shuffle=False),
#     epochs=5000,
#     patience=100,
#     lr=1e-3,
#     checkpoint_path=disc_model_path,
# )
disc_model.load(disc_model_path)

  self.load_state_dict(torch.load(path))


In [22]:
y_pred = disc_model.predict(dataset.X_test).detach().numpy().flatten()
print("Test accuracy:", (y_pred == dataset.y_test).mean())

Test accuracy: 0.84765625


In [23]:
dataset.y_train = disc_model.predict(dataset.X_train).detach().numpy()
dataset.y_test = disc_model.predict(dataset.X_test).detach().numpy()

In [24]:
dequantizer, _ = dequantize(dataset)

In [25]:
gen_model = MaskedAutoregressiveFlow(
    features=dataset.X_train.shape[1],
    hidden_features=16,
    num_blocks_per_layer=4,
    num_layers=8,
    context_features=1,
    batch_norm_within_layers=True,
    batch_norm_between_layers=True,
    use_random_permutations=True,
)
train_dataloader = dataset.train_dataloader(
    batch_size=256, shuffle=True, noise_lvl=0.03
)
test_dataloader = dataset.test_dataloader(batch_size=256, shuffle=False)

# gen_model.fit(
#     train_dataloader,
#     train_dataloader,
#     learning_rate=1e-3,
#     patience=100,
#     num_epochs=500,
#     checkpoint_path=gen_model_path,
# )
gen_model.load(gen_model_path)

  self.load_state_dict(torch.load(path))


In [26]:
dataset, _, _ = datasets["adult"]

In [27]:
class DequantizerWrapper:
    def __init__(self, dequantizer):
        self.dequantizer = dequantizer

    def __call__(self, x):
        data_copy = x.copy()

        for i, group in enumerate(dataset.categorical_features_lists):
            transformer_name = f"cat_group_{i}"

            group_data = data_copy[:, group]

            transformed_data = self.dequantizer.named_transformers_[
                transformer_name
            ].transform(group_data)

            for j, feature_idx in enumerate(group):
                data_copy[:, feature_idx] = transformed_data[:, j]

        return data_copy

In [28]:
dataset = AdultDataset("../data/adult.csv")

In [29]:
from counterfactuals.datasets.torch_utils import TorchCategoricalTransformer

In [42]:
dividers = [1 for _ in range(len(dataset.numerical_features))] + [
    2 for _ in range(len(dataset.categorical_features))
]

In [53]:
dequantizer_torch = TorchCategoricalTransformer(dividers)

In [14]:
# dequantizer_wrapper = DequantizerWrapper(dequantizer)

In [63]:
pd.DataFrame(dataset.X_test)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28
0,0.273973,1.000000,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0
1,0.410959,0.500000,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0
2,0.465753,0.397959,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0
3,0.328767,0.346939,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0
4,0.232877,0.397959,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
251,0.397260,0.397959,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,1.0,0.0
252,0.424658,0.397959,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,1.0
253,0.191781,0.397959,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,1.0
254,0.547945,0.295918,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0


In [65]:
pd.DataFrame(dequantizer_torch(torch.tensor(dataset.X_test)).numpy())

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28
0,0.273973,1.000000,-0.170297,-0.934777,2.933995,-1.650979,-1.178828,-0.997925,-1.261674,1.993535,-0.235556,-1.585705,-1.628724,-1.306919,-1.329526,0.724896,-0.757682,-0.926177,-0.583801,1.171040,-1.387497,-0.786720,-0.538330,-0.737916,-0.633819,-1.420151,0.978214,-1.325772,1.222409
1,0.410959,0.500000,-0.377397,-0.467473,0.746473,-0.236984,-2.189113,-0.884691,-0.636885,-0.797540,-1.235138,-1.827757,-1.209382,0.567342,-0.525872,1.941185,-1.616271,-1.423881,-0.351181,-2.321211,-1.461822,2.235583,-1.843709,-1.147979,-1.838447,-0.965943,0.214690,-1.151894,1.699486
2,0.465753,0.397959,-0.603044,-0.976545,1.336800,-0.679214,-0.574513,-1.911274,-0.833291,2.319466,-1.701604,-0.939376,-0.637641,-1.224242,-0.530315,1.156975,-0.796727,-2.093104,-0.721547,1.479964,-1.263314,-0.545035,-0.570605,-2.470812,-1.421225,2.302171,-0.317172,-0.278790,0.955953
3,0.328767,0.346939,-0.998668,-1.247520,0.250421,-1.270245,-1.082763,-2.700841,-0.839767,-0.801358,-0.365925,-1.636058,1.087114,-2.504772,-0.564647,1.490057,-0.132079,-0.510187,-1.022088,2.006005,-1.560177,-1.263640,-1.248062,-1.794741,-1.012490,-0.381033,2.396269,-0.587415,0.563770
4,0.232877,0.397959,-1.287165,-2.231602,1.477300,-1.394033,-1.265493,0.805548,-0.807351,-1.035606,-1.060286,-1.544641,-1.030638,-0.932146,-2.007001,-1.980569,-0.459191,1.466244,-0.752555,-0.984483,-1.491629,2.228874,-0.567328,-1.670483,-1.199868,-1.240672,0.976455,0.334471,-0.916839
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
251,0.397260,0.397959,-0.756036,-0.620906,2.224864,-1.007287,-2.346214,-1.485558,-0.257130,-2.582365,-1.208548,-1.525567,-1.414763,1.286874,0.737389,-0.623332,-1.563213,-1.078997,-0.815224,-0.788494,-1.694611,-0.815186,1.422865,-0.455965,-0.266515,-0.385061,1.704159,1.833749,-3.746720
252,0.424658,0.397959,1.349905,-1.141819,-0.658627,-0.675131,-1.672927,-1.097559,-1.309924,-1.674365,-2.471585,-1.089029,-1.788346,1.547317,0.732031,-2.279099,-0.463367,-0.699261,-0.580376,-1.012158,-0.948822,-2.855185,-2.677552,-1.016929,2.083496,-1.026978,0.264562,-1.064726,2.402987
253,0.191781,0.397959,-1.195338,-0.136758,0.431342,-0.623203,1.891596,-1.117687,-0.613693,-1.798489,-2.156335,-1.430445,-1.012069,-0.885025,-2.026349,0.856738,-2.385995,-2.587605,-1.620425,-1.002967,-0.632535,-2.075387,-2.376551,-1.349528,0.627372,-0.848665,0.747060,-1.486922,2.259174
254,0.547945,0.295918,-2.307264,-1.696024,1.017839,-0.895737,-1.050861,-2.239344,-1.319688,-0.342777,-2.027091,-1.738401,1.400112,-0.681375,-1.154985,-0.675662,-2.626636,-1.658227,1.273172,-0.481925,-1.179945,-1.781787,-0.449309,1.614726,-1.808484,1.279507,-2.168596,1.262677,-0.769372


In [66]:
cf = PPCEF(
    gen_model=gen_model,
    disc_model=disc_model,
    disc_model_criterion=MulticlassDiscLoss(),
    neptune_run=None,
)

target_class = 0
X_test_origin = dataset.X_test[dataset.y_test != target_class]
y_test_origin = dataset.y_test[dataset.y_test != target_class]

cf_dataloader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(
        torch.tensor(X_test_origin).float(),
        torch.tensor(y_test_origin).float(),
    ),
    batch_size=1024,
    shuffle=False,
)


log_prob_threshold = torch.quantile(gen_model.predict_log_prob(cf_dataloader), 0.25)
deltas, X_orig, y_orig, y_target, logs = cf.explain_dataloader(
    cf_dataloader,
    alpha=100,
    log_prob_threshold=log_prob_threshold,
    epochs=10000,
    lr=0.001,
    categorical_intervals=dataset.categorical_features_lists,
    dequantizer=dequantizer_torch,
)
log_prob_threshold

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]


RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [62, 27]], which is output 0 of AsStridedBackward0, is at version 29; expected version 28 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

In [13]:
X_cf = X_orig + deltas
X_cf_cat = X_cf.copy()

for interval in dataset.categorical_features_lists:
    max_indices = np.argmax(X_cf_cat[:, interval], axis=1)
    X_cf_cat[:, interval] = np.eye(X_cf_cat[:, interval].shape[1])[max_indices]

In [13]:
# for categorical_features, transform in zip(
#         dataset.categorical_features_lists, dequantizer.named_transformers_
#     ):

#     X_cf[:, categorical_features] = dequantizer.named_transformers_[
#         transform
#     ].inverse_transform(X_cf[:, list(range(len(categorical_features)))])

In [15]:
# X_cf_deq = inverse_dequantize(dataset, dequantizer, X_cf_cat)
_, X_cf_q = dequantize(dataset, X_cf_cat, dequantizer)

In [16]:
evaluate_cf(
    disc_model=disc_model,
    gen_model=gen_model,
    X_cf=X_cf_q,
    model_returned=np.ones(X_cf_cat.shape[0]),
    continuous_features=dataset.numerical_features,
    categorical_features=dataset.categorical_features,
    X_train=dataset.X_train,
    y_train=dataset.y_train,
    X_test=X_orig,
    y_test=y_orig,
    median_log_prob=log_prob_threshold,
    y_target=y_target,
)

2025-04-19 14:10:16,301 - counterfactuals.metrics.distances - INFO - Calculating combined distance
2025-04-19 14:10:16,301 - counterfactuals.metrics.distances - INFO - Calculating continuous distance
2025-04-19 14:10:16,302 - counterfactuals.metrics.distances - INFO - Calculating categorical distance
2025-04-19 14:10:16,304 - counterfactuals.metrics.distances - INFO - Calculating combined distance
2025-04-19 14:10:16,305 - counterfactuals.metrics.distances - INFO - Calculating continuous distance
2025-04-19 14:10:16,306 - counterfactuals.metrics.distances - INFO - Calculating categorical distance
2025-04-19 14:10:16,307 - counterfactuals.metrics.distances - INFO - Calculating combined distance
2025-04-19 14:10:16,307 - counterfactuals.metrics.distances - INFO - Calculating continuous distance
2025-04-19 14:10:16,308 - counterfactuals.metrics.distances - INFO - Calculating categorical distance
2025-04-19 14:10:16,308 - counterfactuals.metrics.distances - INFO - Calculating combined dist

{'coverage': 1.0,
 'validity': 0.3090909090909091,
 'actionability': 0.0,
 'sparsity': 1.0,
 'proximity_categorical_hamming': 0.9749532019011216,
 'proximity_categorical_jaccard': 0.9749532019011216,
 'proximity_continuous_manhattan': 0.9884099700085773,
 'proximity_continuous_euclidean': 0.9749532019011216,
 'proximity_continuous_mad': 1.861275817027682,
 'proximity_l2_jaccard': 0.9749532019011216,
 'proximity_mad_hamming': 1.861275817027682,
 'prob_plausibility': 0.18181818181818182,
 'log_density_cf': -45.080257,
 'log_density_test': -82.857475,
 'lof_scores_cf': 1.0623248,
 'lof_scores_test': 1.0580301,
 'isolation_forest_scores_cf': 0.06439388425658807,
 'isolation_forest_scores_test': 0.07902376467016735}