In [1]:
# Automatically reload edited project modules (e.g. counterfactuals.*) before each
# cell executes, so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import torch

from counterfactuals.datasets import MoonsDataset
from counterfactuals.cf_methods.ppcef import PPCEF
from counterfactuals.discriminative_models import LogisticRegression
from counterfactuals.generative_models import MaskedAutoregressiveFlow, KDE
from counterfactuals.losses import BinaryDiscLoss
from counterfactuals.metrics.metrics import evaluate_cf

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Seed torch's global RNG before anything stochastic runs, so that the shuffled training
# loader and the model fitting below reproduce under "Restart & Run All".
# NOTE(review): assumes the loaders/models draw from torch's default generator (they
# appear to be torch-based) — confirm; if MoonsDataset uses numpy/random internally,
# seed those as well.
SEED = 42
torch.manual_seed(SEED)

BATCH_SIZE = 16  # shared by the train and test loaders

dataset = MoonsDataset("../data/moons.csv")
train_dataloader = dataset.train_dataloader(batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = dataset.test_dataloader(batch_size=BATCH_SIZE, shuffle=False)

In [4]:
# Classifier whose decisions are explained below: a logistic regression with a single
# output (paired with BinaryDiscLoss in the PPCEF cell).
# NOTE(review): constructor args are presumably (in_features, out_features) — confirm.
n_features = dataset.X_test.shape[1]
n_outputs = 1
disc_model = LogisticRegression(n_features, n_outputs)
disc_model.fit(train_dataloader)

Epoch 199, Loss: 0.2630: 100%|██████████| 200/200 [00:02<00:00, 81.12it/s]


In [5]:
# Generative (density) model handed to PPCEF below. fit() takes the train loader
# followed by the test loader; the cell output reports train/test log-likelihoods.
gen_model = KDE()
gen_model.fit(
    train_dataloader,
    test_dataloader,
)

Train log-likelihood: -0.42380446195602417
Test log-likelihood: -0.4324100911617279


In [6]:
# Counterfactual explainer combining the density model and the classifier.
# Neptune experiment tracking is disabled (neptune_run=None).
disc_loss = BinaryDiscLoss()
cf = PPCEF(
    gen_model=gen_model,
    disc_model=disc_model,
    disc_model_criterion=disc_loss,
    neptune_run=None,
)

In [7]:
# Log-density threshold passed as `delta` to the counterfactual search below.
# Calibrated on the *training* data: the test points are the ones being explained and
# evaluated, so deriving the threshold from them would leak test information.
# Note: for an even number of elements torch.median returns the lower of the two
# middle values rather than their mean.
median_log_prob = torch.median(gen_model.predict_log_prob(train_dataloader))

In [8]:
# Search counterfactuals for every test batch. `delta` is the log-density threshold from
# the previous cell.
# NOTE(review): ALPHA presumably weights the plausibility/density term in PPCEF's
# objective — confirm against PPCEF.search_batch.
ALPHA = 100
X_cf, X_orig, y_orig = cf.search_batch(
    test_dataloader,
    alpha=ALPHA,
    delta=median_log_prob,
)

100%|██████████| 7/7 [00:02<00:00,  2.49it/s]
