In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from torch.utils.data import DataLoader, TensorDataset

from cel.cf_methods import Artelt
from cel.datasets.file_dataset import FileDataset
from cel.models import LogisticRegression

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Prepare the data

dataset = FileDataset(config_path="../config/datasets/moons.yaml")
# dataset = AdultDataset()

# Get the split data that's already available
X_train = dataset.X_train
X_test = dataset.X_test
y_train = dataset.y_train
y_test = dataset.y_test

train_dataset = TensorDataset(
    torch.tensor(X_train, dtype=torch.float32),
    torch.tensor(y_train, dtype=torch.float32),
)
test_dataset = TensorDataset(
    torch.tensor(X_test, dtype=torch.float32), torch.tensor(y_test, dtype=torch.float32)
)

train_dataloader = DataLoader(train_dataset, batch_size=1024, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=1024, shuffle=False)

In [4]:
# Train a discriminative model
num_inputs = X_train.shape[1]
num_targets = 1

discrimaiative_model = LogisticRegression(
    num_inputs=num_inputs,
    num_targets=num_targets,
)
discrimaiative_model.fit(train_dataloader, test_dataloader, epochs=10000, patience=600, lr=0.01)

Epoch 1805, Train: 0.2590, test: 0.2770, patience: 600:  18%|█▊        | 1805/10000 [00:04<00:20, 390.42it/s]


In [5]:
cf_method = Artelt(
    disc_model=discrimaiative_model,
)
cf_method.fit_density_estimators(X_train, y_train)
cf_method.explain_dataloader(test_dataloader, alpha=0.1)

bandwidth: 0.1
n_components: 6
density_threshold: 0.6217726970580395
Plausible counterfactual generator for label 0.0 fitted
bandwidth: 0.1
n_components: 5
density_threshold: 0.5112252967515254
Plausible counterfactual generator for label 1.0 fitted


100%|██████████| 200/200 [00:03<00:00, 66.52it/s]


ExplanationResult(x_cfs=array([[ 7.69234726e-01,  3.15612214e-01],
       [-8.06974719e-01, -4.99947600e-02],
       [ 1.55494233e+00,  4.97665921e-01],
       [-8.68727508e-01, -6.43031481e-02],
       [ 1.14251579e+00,  4.02101883e-01],
       [ 2.53547559e-01,  1.96128255e-01],
       [ 1.43354457e+00,  4.69537555e-01],
       [-5.88659039e-01,  5.91309791e-04],
       [ 5.71199755e-01,  2.69333122e-01],
       [-7.43363256e-01, -3.52560625e-02],
       [ 6.10965686e-01,  2.78546899e-01],
       [ 8.63318185e-01,  3.37022588e-01],
       [ 6.18946938e-01,  2.80792444e-01],
       [ 1.75029176e+00,  5.42927938e-01],
       [ 1.47468523e+00,  4.79070000e-01],
       [ 3.71849640e-01,  2.23538995e-01],
       [ 1.16688590e+00,  4.07748126e-01],
       [-6.42329930e-01, -1.18462317e-02],
       [ 2.39997996e-01,  1.92593606e-01],
       [ 8.60791480e-01,  3.36827988e-01],
       [ 1.67930552e+00,  5.26481801e-01],
       [ 4.83437753e-02,  1.48184900e-01],
       [ 1.00731111e+00,  3.70