In [None]:
# Reload edited project modules (e.g. `counterfactuals.*`) before each cell runs,
# so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import torch

from counterfactuals.datasets import MoonsDataset
from counterfactuals.cf_methods.ppcef import PPCEF
from counterfactuals.generative_models import MaskedAutoregressiveFlow
from counterfactuals.discriminative_models import LogisticRegression
from counterfactuals.losses import BinaryDiscLoss

In [None]:
# Train a discriminative classifier and a conditional normalizing flow on the
# Moons data, then search PPCEF counterfactuals for the whole test set.

# --- Configuration: tune here rather than inside the calls below. ---
SEED = 42
BATCH_SIZE = 16
FLOW_HIDDEN_FEATURES = 8
ALPHA = 100  # forwarded unchanged to PPCEF.search_batch(alpha=...)

# Seed torch's global RNG so shuffling and weight init are repeatable across
# Restart & Run All. NOTE(review): if MoonsDataset splits/shuffles with numpy
# or `random`, those RNGs need seeding too — TODO confirm.
torch.manual_seed(SEED)

dataset = MoonsDataset("data/moons.csv")
train_dataloader = dataset.train_dataloader(batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = dataset.test_dataloader(batch_size=BATCH_SIZE, shuffle=False)

# Both models take their input dimension from the training features.
n_features = dataset.X_train.shape[1]

# Single-output classifier, paired with BinaryDiscLoss below.
disc_model = LogisticRegression(n_features, 1)
disc_model.fit(train_dataloader, test_dataloader)

# Flow with one context feature — presumably conditioned on the class label
# (TODO confirm against MaskedAutoregressiveFlow).
gen_model = MaskedAutoregressiveFlow(
    features=n_features,
    hidden_features=FLOW_HIDDEN_FEATURES,
    context_features=1,
)
gen_model.fit(train_dataloader, test_dataloader)

cf = PPCEF(
    gen_model=gen_model,
    disc_model=disc_model,
    disc_model_criterion=BinaryDiscLoss(),
    neptune_run=None,
)

# delta = median log-probability the trained flow assigns to the test set
# (presumably a density threshold for plausible counterfactuals — confirm in PPCEF).
median_log_prob = torch.median(gen_model.predict_log_prob(test_dataloader))
X_cf, X_orig, y_orig, y_target, _ = cf.search_batch(
    test_dataloader, alpha=ALPHA, delta=median_log_prob
)

In [None]:
# Display the counterfactuals returned by cf.search_batch (last expression -> cell output).
X_cf