In [1]:
import numpy as np
import torch

from counterfactuals.datasets import MoonsDataset
from counterfactuals.cf_methods.ppcef import PPCEF
from counterfactuals.generative_models import MaskedAutoregressiveFlow
from counterfactuals.discriminative_models import MultilayerPerceptron
from counterfactuals.losses import BinaryDiscLoss
from counterfactuals.metrics import evaluate_cf


dataset = MoonsDataset("../data/moons.csv")
train_dataloader = dataset.train_dataloader(batch_size=128, shuffle=True)
test_dataloader = dataset.test_dataloader(batch_size=128, shuffle=False)

disc_model = MultilayerPerceptron(
    input_size=2, hidden_layer_sizes=[256, 256], target_size=1, dropout=0.2
)
disc_model.fit(
    train_dataloader,
    test_dataloader,
    epochs=5000,
    patience=300,
    lr=1e-3,
)

gen_model = MaskedAutoregressiveFlow(
    features=dataset.X_train.shape[1], hidden_features=8, context_features=1
)
gen_model.fit(train_dataloader, test_dataloader, num_epochs=1000)

cf = PPCEF(
    gen_model=gen_model,
    disc_model=disc_model,
    disc_model_criterion=BinaryDiscLoss(),
    neptune_run=None,
)
cf_dataloader = dataset.test_dataloader(batch_size=1024, shuffle=False)
log_prob_threshold = torch.quantile(gen_model.predict_log_prob(cf_dataloader), 0.25)
deltas, X_orig, y_orig, y_target, logs = cf.explain_dataloader(
    cf_dataloader, alpha=100, log_prob_threshold=log_prob_threshold, epochs=4000
)
X_cf = X_orig + deltas
print(X_cf)
evaluate_cf(
    disc_model=disc_model,
    gen_model=gen_model,
    X_cf=X_cf,
    model_returned=np.ones(X_cf.shape[0]),
    continuous_features=dataset.numerical_features,
    categorical_features=dataset.categorical_features,
    X_train=dataset.X_train,
    y_train=dataset.y_train,
    X_test=X_orig,
    y_test=y_orig,
    median_log_prob=log_prob_threshold,
    y_target=y_target,
)

  from .autonotebook import tqdm as notebook_tqdm
  self.load_state_dict(torch.load(path))
Epoch 1733, Train: 0.0298, test: 0.0366, patience: 300:  35%|███▍      | 1734/5000 [00:28<00:53, 60.98it/s]
Epoch 403, Train: -1.5278, test: -1.4541, patience: 20:  40%|████      | 403/1000 [00:15<00:22, 26.37it/s]
  self.load_state_dict(torch.load(path))
Discriminator loss: 0.0000, Prob loss: 0.0000:  94%|█████████▍| 3774/4000 [00:14<00:00, 255.28it/s]


[[0.6461668  0.48673075]
 [0.64491254 0.48646665]
 [0.42095536 0.48698193]
 [0.64061767 0.4877382 ]
 [0.35129285 0.60633886]
 [0.65817857 0.50366366]
 [0.347574   0.6028753 ]
 [0.59215605 0.55519074]
 [0.6580796  0.5030301 ]
 [0.5004674  0.28209668]
 [0.4318893  0.41733497]
 [0.4616859  0.3444159 ]
 [0.6580493  0.50309277]
 [0.69330925 0.19144821]
 [0.646605   0.48684222]
 [0.43635416 0.40590066]
 [0.34836894 0.6049624 ]
 [0.3508606  0.6061292 ]
 [0.64535636 0.48630798]
 [0.65846413 0.53983504]
 [0.64379805 0.48631877]
 [0.6466717  0.48672044]
 [0.63260174 0.49570203]
 [0.34080726 0.5844625 ]
 [0.5344524  0.6598058 ]
 [0.6455662  0.4862625 ]
 [0.5154522  0.6973327 ]
 [0.3442777  0.5960312 ]
 [0.6449921  0.4863658 ]
 [0.34223115 0.5909497 ]
 [0.3483961  0.6049528 ]
 [0.43713617 0.40359363]
 [0.41584408 0.50511765]
 [0.64554864 0.48649818]
 [0.65792835 0.5027333 ]
 [0.50775015 0.7047355 ]
 [0.64404005 0.48653632]
 [0.36927402 0.60536087]
 [0.47682446 0.32357457]
 [0.34825334 0.6049093 ]


{'coverage': 1.0,
 'validity': 1.0,
 'actionability': 0.0,
 'sparsity': 1.0,
 'proximity_categorical_hamming': nan,
 'proximity_categorical_jaccard': 0.27648498851905734,
 'proximity_continuous_manhattan': 0.3554813292331812,
 'proximity_continuous_euclidean': 0.27648498851905734,
 'proximity_continuous_mad': 1.8272110244122948,
 'proximity_l2_jaccard': 0.27648498851905734,
 'proximity_mad_hamming': nan,
 'prob_plausibility': 0.9951219512195122,
 'log_density_cf': 1.367367,
 'log_density_test': -36.42818,
 'lof_scores_cf': 1.04192,
 'lof_scores_test': 1.0409402,
 'isolation_forest_scores_cf': 0.027092389010673755,
 'isolation_forest_scores_test': 0.0041604418163049064}