In [1]:
# Automatically reload imported modules (e.g. the `counterfactuals` package used below)
# before each cell executes, so edits to library code are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [7]:
import torch

from counterfactuals.datasets import MoonsDataset
from counterfactuals.cf_methods.ppcef import PPCEF
from counterfactuals.discriminative_models import LogisticRegression
from counterfactuals.generative_models import MaskedAutoregressiveFlow
from counterfactuals.losses import BinaryDiscLoss

In [3]:
# Seed torch's global RNG so that the shuffled training order and (presumably) the random
# weight initialisation of the models trained below are reproducible under
# Restart Kernel -> Run All.
# NOTE(review): if MoonsDataset performs its own random split with numpy/sklearn,
# that RNG needs seeding too — not visible from here.
SEED = 42
torch.manual_seed(SEED)

DATA_PATH = "../data/moons.csv"
BATCH_SIZE = 16

dataset = MoonsDataset(DATA_PATH)
# Shuffle only the training split; the test split is iterated in a fixed order.
train_dataloader = dataset.train_dataloader(batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = dataset.test_dataloader(batch_size=BATCH_SIZE, shuffle=False)

In [4]:
# Classifier that PPCEF searches against (paired with BinaryDiscLoss further down).
# NOTE(review): the trailing `1` is presumably the number of output units for a binary
# target — confirm against LogisticRegression's signature.
n_features = dataset.X_test.shape[1]
disc_model = LogisticRegression(n_features, 1)
disc_model.fit(train_dataloader)

Epoch 199, Loss: 0.2639: 100%|██████████| 200/200 [00:02<00:00, 78.10it/s]


In [5]:
# Density model: passed to PPCEF and used below to compute log-probabilities for the
# plausibility threshold. Trained on the train split and evaluated on the test split
# (the progress bar reports both losses).
# NOTE(review): the positional arguments 4 and 1 are architecture hyperparameters whose
# exact meaning is not visible here — confirm against MaskedAutoregressiveFlow's signature.
n_features = dataset.X_test.shape[1]
gen_model = MaskedAutoregressiveFlow(n_features, 4, 1)
gen_model.fit(train_dataloader, test_dataloader)

Epoch 99, Train: 0.5980, test: 0.6135: 100%|██████████| 100/100 [00:22<00:00,  4.40it/s]


In [11]:
# Assemble the PPCEF counterfactual explainer from the trained classifier and flow.
# No Neptune run is attached, so nothing is logged to experiment tracking.
disc_loss = BinaryDiscLoss()
cf = PPCEF(
    disc_model=disc_model,
    gen_model=gen_model,
    disc_model_criterion=disc_loss,
    neptune_run=None,
)

In [12]:
# Plausibility threshold (`delta` below): the median log-probability that the flow
# assigns to the test points. For an even count, torch's median returns the lower
# of the two middle values rather than their mean.
# NOTE(review): this threshold is estimated on the same test split that is searched
# in the next cell — consider whether the training split is the intended reference.
test_log_probs = gen_model.predict_log_prob(test_dataloader)
median_log_prob = test_log_probs.median()

In [13]:
# Run the counterfactual search over every batch of the test split.
# NOTE(review): ALPHA is presumably a loss-weighting coefficient — confirm its meaning in
# PPCEF.search_batch. The three outputs are presumably the counterfactuals, the original
# inputs and their labels, as the variable names suggest.
ALPHA = 100
X_cf, X_orig, y_orig = cf.search_batch(
    test_dataloader,
    alpha=ALPHA,
    delta=median_log_prob,
)

  0%|          | 0/7 [00:00<?, ?it/s]

100%|██████████| 7/7 [00:01<00:00,  4.25it/s]
