In [1]:
%load_ext autoreload
%autoreload 2

In [10]:
import numpy as np

import torch

from counterfactuals.datasets import MoonsDataset, WineDataset, DigitsDataset, BlobsDataset
from counterfactuals.cf_methods.ppcef import PPCEF
from counterfactuals.discriminative_models import LogisticRegression, MultinomialLogisticRegression
from counterfactuals.generative_models import MaskedAutoregressiveFlow, KDE
from counterfactuals.losses import BinaryDiscLoss, MulticlassDiscLoss
from counterfactuals.metrics.metrics import evaluate_cf

In [3]:
dataset = MoonsDataset("../data/moons.csv")
train_dataloader = dataset.train_dataloader(batch_size=16, shuffle=True)
test_dataloader = dataset.test_dataloader(batch_size=16, shuffle=False)

In [4]:
disc_model = MultinomialLogisticRegression(dataset.X_test.shape[1], len(np.unique(dataset.y_train)))
disc_model.fit(train_dataloader, test_dataloader)

  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Epoch 15, Train Loss: 0.4201, Test Loss: 0.4183:   8%|▊         | 16/200 [00:01<00:22,  8.04it/s]


In [5]:
gen_model = KDE()
gen_model.fit(train_dataloader, test_dataloader)

Train log-likelihood: 0.9672112464904785
Test log-likelihood: 0.9333597421646118


In [6]:
cf = PPCEF(
    gen_model=gen_model,
    disc_model=disc_model,
    disc_model_criterion=MulticlassDiscLoss(),
    neptune_run=None,
)

In [7]:
median_log_prob = torch.median(gen_model.predict_log_prob(test_dataloader))

In [8]:
X_cf, X_orig, y_orig, y_target, _ = cf.search_batch(test_dataloader, alpha=100, delta=median_log_prob)

  0%|          | 0/13 [00:00<?, ?it/s]

100%|██████████| 13/13 [00:22<00:00,  1.74s/it]


In [9]:
X_cf

array([[ 0.5901091 ,  0.5953224 ],
       [ 0.6051247 ,  0.490601  ],
       [ 0.4248429 ,  0.39826542],
       [ 0.57481855,  0.6088258 ],
       [ 0.3980022 ,  0.4366023 ],
       [ 0.5893086 ,  0.7378911 ],
       [ 0.3871987 ,  0.392809  ],
       [ 0.52437687,  0.67729896],
       [ 0.58507466,  0.7551102 ],
       [ 0.50321954,  0.30564415],
       [ 0.46704024,  0.3204397 ],
       [ 0.47479036,  0.325736  ],
       [ 0.58810055,  0.7447532 ],
       [ 0.78114533,  0.2955253 ],
       [ 0.61352205,  0.58307016],
       [ 0.46969232,  0.32222185],
       [ 0.1233225 ,  0.5491207 ],
       [ 0.24086773,  0.7040381 ],
       [ 0.69494   ,  0.23414077],
       [ 0.87063706,  0.58160764],
       [ 0.6144753 ,  0.13273111],
       [ 0.77753687,  0.35819444],
       [ 0.4802255 ,  0.37732857],
       [ 0.24351425,  0.63735414],
       [ 0.4267679 ,  0.546086  ],
       [ 0.493058  ,  0.21399653],
       [ 0.42593867,  0.5730516 ],
       [ 0.14733194,  0.4018188 ],
       [ 0.65417504,