In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from torch.utils.data import DataLoader, TensorDataset

from cel.cf_methods import WACH_OURS
from cel.datasets.file_dataset import FileDataset
from cel.losses import BinaryDiscLoss
from cel.models import MLPClassifier

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Prepare the data

dataset = FileDataset(config_path="../config/datasets/moons.yaml")
# dataset = AdultDataset()

# Get the split data that's already available
X_train = dataset.X_train
X_test = dataset.X_test
y_train = dataset.y_train
y_test = dataset.y_test

train_dataset = TensorDataset(
    torch.tensor(X_train, dtype=torch.float32),
    torch.tensor(y_train, dtype=torch.float32),
)
test_dataset = TensorDataset(
    torch.tensor(X_test, dtype=torch.float32), torch.tensor(y_test, dtype=torch.float32)
)

train_dataloader = DataLoader(train_dataset, batch_size=1024, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=1024, shuffle=False)

In [4]:
# Train a discriminative model
num_inputs = X_train.shape[1]
num_targets = 1

discrimaiative_model = MLPClassifier(
    num_inputs=num_inputs,
    num_targets=num_targets,
    hidden_layer_sizes=[64, 32],  # Added required parameter
)
discrimaiative_model.fit(train_dataloader, test_dataloader, epochs=10000, patience=600, lr=0.01)

Epoch 1359, Train: 0.1870, test: 0.2177, patience: 600:  14%|█▎        | 1360/10000 [00:05<00:34, 249.65it/s]


In [5]:
cf_method = WACH_OURS(
    disc_model=discrimaiative_model,
    disc_model_criterion=BinaryDiscLoss(),
)
cf_method.explain_dataloader(test_dataloader, alpha=0.1)

Discriminator loss: 3.2510: 100%|██████████| 1000/1000 [00:00<00:00, 1149.73it/s]


ExplanationResult(x_cfs=array([[ 0.76920015, -0.4074115 ],
       [-0.80701715,  0.21206526],
       [ 1.5549009 , -0.28536332],
       [-0.8685572 ,  0.02867057],
       [ 1.1424952 , -0.61480695],
       [ 0.25353828, -0.34041315],
       [ 1.4335032 , -0.31750926],
       [-0.5886516 ,  0.89827585],
       [ 0.5712299 ,  0.79136056],
       [-0.7432979 ,  0.7697115 ],
       [ 0.6109952 ,  0.8194406 ],
       [ 0.8632514 ,  0.8046588 ],
       [ 0.61894965, -0.54713273],
       [ 1.7503433 , -0.12857457],
       [ 1.4746438 , -0.3150234 ],
       [ 0.3718798 , -0.2970614 ],
       [ 1.166896  , -0.30731162],
       [-0.6422947 ,  0.7452182 ],
       [ 0.23998053,  0.83485323],
       [ 0.86079216,  0.22050826],
       [ 1.6791238 , -0.16613352],
       [ 0.04834374,  0.48454222],
       [ 1.0073118 ,  0.29731613],
       [ 1.7472593 ,  0.07271036],
       [ 1.0848262 , -0.46076483],
       [ 1.7644666 , -0.1206787 ],
       [-0.05640486,  0.99886346],
       [ 1.6642902 , -0.2636177