In [6]:
import numpy as np
import torch
from counterfactuals.datasets import WineDataset
from counterfactuals.cf_methods.ppcef import PPCEF
from counterfactuals.generative_models import MaskedAutoregressiveFlow
from counterfactuals.discriminative_models import MultilayerPerceptron
from counterfactuals.losses import BinaryDiscLoss
from counterfactuals.metrics import evaluate_cf

# Load the wine CSV and build batched loaders for its train and test splits.
# Only the training loader is shuffled; the test loader keeps a fixed order.
dataset = WineDataset("../data/wine.csv")
train_dataloader, test_dataloader = (
    dataset.train_dataloader(batch_size=128, shuffle=True),
    dataset.test_dataloader(batch_size=128, shuffle=False),
)

# Discriminative model: an MLP sized to the number of feature columns in
# X_train, configured with hidden sizes [256, 256], a single output and
# dropout 0.2.
disc_model = MultilayerPerceptron(
    input_size=dataset.X_train.shape[1],
    hidden_layer_sizes=[256, 256],
    target_size=1,
    dropout=0.2,
)
# NOTE(review): the test loader is passed as the second loader here and is
# reused later for counterfactual evaluation; it presumably drives the
# patience-based early stopping -- confirm a separate validation split is not
# wanted.
disc_model.fit(train_dataloader, test_dataloader, epochs=5000, patience=300, lr=1e-3)

# Generative model: a masked autoregressive flow over the same feature columns,
# with one context feature (presumably the class label -- TODO confirm).
gen_model = MaskedAutoregressiveFlow(
    features=dataset.X_train.shape[1],
    hidden_features=8,
    context_features=1,
)
# Fit on the training loader, with the test loader passed as the second loader.
gen_model.fit(
    train_dataloader,
    test_dataloader,
    num_epochs=1000,
)

# Counterfactual explainer built from the two fitted models; no Neptune run is
# attached, so nothing is logged remotely from this cell.
cf = PPCEF(
    gen_model=gen_model, disc_model=disc_model, disc_model_criterion=BinaryDiscLoss(), neptune_run=None
)

# Explain the whole test split using one large, unshuffled batch size.
cf_dataloader = dataset.test_dataloader(batch_size=1024, shuffle=False)

# Plausibility threshold: the 25% quantile of the flow's log-probabilities on
# the very loader that is being explained.
log_prob_threshold = torch.quantile(
    gen_model.predict_log_prob(cf_dataloader),
    q=0.25,
)

deltas, X_orig, y_orig, y_target, logs = cf.explain_dataloader(
    cf_dataloader,
    alpha=100,
    log_prob_threshold=log_prob_threshold,
    epochs=4000,
)

# The explainer returns perturbations; the counterfactual points themselves are
# the original inputs shifted by those perturbations.
X_cf = X_orig + deltas
print(X_cf)
# Score the counterfactuals. The call is left as a bare trailing expression so
# that the notebook renders the returned metrics as the cell output.
evaluate_cf(
    # Fitted models.
    disc_model=disc_model,
    gen_model=gen_model,
    # Counterfactuals; every row is flagged with 1.0 in `model_returned`.
    X_cf=X_cf,
    model_returned=np.ones(len(X_cf)),
    y_target=y_target,
    # Original test instances the counterfactuals were derived from.
    X_test=X_orig,
    y_test=y_orig,
    # Training data passed as reference.
    X_train=dataset.X_train,
    y_train=dataset.y_train,
    # Feature groups as declared by the dataset object.
    continuous_features=dataset.numerical_features,
    categorical_features=dataset.categorical_features,
    # NOTE(review): despite the parameter name, the value passed is the 0.25
    # quantile threshold computed earlier in this cell, not a median.
    median_log_prob=log_prob_threshold,
)

  from .autonotebook import tqdm as notebook_tqdm
  self.load_state_dict(torch.load(path))
Epoch 1226, Train: -853.5295, test: -986.7457, patience: 300:  25%|██▍       | 1227/5000 [00:06<00:18, 199.71it/s]  
Epoch 237, Train: -7.7728, test: -5.8891, patience: 20:  24%|██▎       | 237/1000 [00:03<00:10, 74.58it/s]
  self.load_state_dict(torch.load(path))
Discriminator loss: 0.0000, Prob loss: 0.0428: 100%|██████████| 4000/4000 [00:11<00:00, 355.63it/s] 


[[ 0.59217846  0.37478423  0.62249774  0.6134726   0.42478135  0.29927152
   0.2043211   0.52385867  0.34732658  0.5297014   0.3271137   0.12562147
   0.41266748]
 [ 0.80285114  0.5892716   0.5858578   0.47232234  0.4124051   0.60578495
   0.43588093  0.31217784  0.37242544  0.325127    0.36111823  0.72017735
   0.5587177 ]
 [ 0.49347317  0.26686007  0.6637576   0.6881596   0.44715855  0.47811982
   0.26924318  0.3243528   0.43642783  0.41863766  0.19248445  0.12993595
   0.18402112]
 [ 0.383875    0.03666166  0.41637328  0.39590898  0.1559631   0.85276824
   0.6979956   0.21022251  0.6201558   0.2791566   0.45568448  0.56031305
   0.28721526]
 [ 0.48052353  0.4455378   0.3628083   0.3688509   0.27116096  0.69860935
   0.5996438   0.16172627  0.75230217  0.16186948  0.33010444  0.639772
   0.34667066]
 [ 0.5959895   0.24659315  0.47873083  0.11883532  0.33925673  0.5118383
   0.4342995   0.3102938   0.31505787  0.32459283  0.40167475  0.7090547
   0.48506808]
 [ 0.43993777  0.27928025 

{'coverage': 1.0,
 'validity': 1.0,
 'actionability': 0.0,
 'sparsity': 0.9978632478632479,
 'proximity_categorical_hamming': nan,
 'proximity_categorical_jaccard': 0.3567284325657089,
 'proximity_continuous_manhattan': 1.115057898982842,
 'proximity_continuous_euclidean': 0.3567284325657089,
 'proximity_continuous_mad': 8.275870707299974,
 'proximity_l2_jaccard': 0.3567284325657089,
 'proximity_mad_hamming': nan,
 'prob_plausibility': 0.9444444444444444,
 'log_density_cf': 4.728581,
 'log_density_test': -4.348091,
 'lof_scores_cf': 1.0247798,
 'lof_scores_test': 1.083412,
 'isolation_forest_scores_cf': 0.06842495049559408,
 'isolation_forest_scores_test': 0.03483643362366132}