In [1]:
import numpy as np
import torch
from counterfactuals.datasets import LawDataset
from counterfactuals.cf_methods.ppcef import PPCEF
from counterfactuals.generative_models import MaskedAutoregressiveFlow
from counterfactuals.discriminative_models import MultilayerPerceptron
from counterfactuals.losses import BinaryDiscLoss
from counterfactuals.metrics import evaluate_cf

# Seed NumPy and PyTorch before anything else runs so that the stochastic
# steps in this cell (e.g. shuffle=True batches, dropout during training)
# are as repeatable as possible across "Restart Kernel -> Run All".
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Load the dataset from CSV and build the loaders passed to both models' fit().
dataset = LawDataset("../data/law.csv")
# shuffle=True for the training loader; the test loader keeps a fixed order.
train_dataloader = dataset.train_dataloader(batch_size=128, shuffle=True)
test_dataloader = dataset.test_dataloader(batch_size=128, shuffle=False)

# Discriminative model: a MultilayerPerceptron whose input width equals the
# number of columns in the training matrix, with two hidden layers of 256
# units, target_size=1 and dropout 0.2.
n_features = dataset.X_train.shape[1]
disc_model = MultilayerPerceptron(
    input_size=n_features,
    hidden_layer_sizes=[256, 256],
    target_size=1,
    dropout=0.2,
)

# NOTE(review): the test loader is passed as the second argument and the
# progress bar reports a "test" loss next to "patience" -- presumably it
# drives early stopping; confirm a held-out validation split is not wanted.
disc_model.fit(train_dataloader, test_dataloader, epochs=5000, patience=300, lr=1e-3)

# Generative model: a MaskedAutoregressiveFlow over the same feature columns,
# configured with one context feature and 8 hidden features.
flow_input_dim = dataset.X_train.shape[1]
gen_model = MaskedAutoregressiveFlow(
    features=flow_input_dim,
    hidden_features=8,
    context_features=1,
)
gen_model.fit(
    train_dataloader,
    test_dataloader,
    num_epochs=1000,
)

# Counterfactual explainer built from the fitted flow and classifier.
# neptune_run=None: no Neptune run object is supplied.
cf = PPCEF(
    gen_model=gen_model, disc_model=disc_model, disc_model_criterion=BinaryDiscLoss(), neptune_run=None
)

# Separate, larger-batch pass over the test split for explanation.
cf_dataloader = dataset.test_dataloader(batch_size=1024, shuffle=False)

# Threshold = 0.25 quantile of the flow's log-probabilities on that same loader.
cf_log_probs = gen_model.predict_log_prob(cf_dataloader)
log_prob_threshold = torch.quantile(cf_log_probs, 0.25)

deltas, X_orig, y_orig, y_target, logs = cf.explain_dataloader(
    cf_dataloader,
    alpha=100,
    log_prob_threshold=log_prob_threshold,
    epochs=4000,
)

# Counterfactuals are the original points shifted by the returned deltas.
X_cf = X_orig + deltas
print(X_cf)
# Score the counterfactuals. The result is bound to a name and left as the
# trailing expression so the notebook still renders it as the cell output.
cf_metrics = evaluate_cf(
    # models
    disc_model=disc_model,
    gen_model=gen_model,
    # counterfactuals; model_returned is an all-ones vector, one entry per row
    X_cf=X_cf,
    model_returned=np.ones(X_cf.shape[0]),
    # feature groups as exposed by the dataset object
    continuous_features=dataset.numerical_features,
    categorical_features=dataset.categorical_features,
    # reference data
    X_train=dataset.X_train,
    y_train=dataset.y_train,
    X_test=X_orig,
    y_test=y_orig,
    # NOTE(review): the value passed here is the 0.25 quantile computed
    # earlier in this cell, not a median -- confirm this is intended.
    median_log_prob=log_prob_threshold,
    y_target=y_target,
)
cf_metrics

  from .autonotebook import tqdm as notebook_tqdm
  self.load_state_dict(torch.load(path))
Epoch 329, Train: 0.4698, test: 0.5159, patience: 300:   7%|▋         | 330/5000 [00:10<02:22, 32.85it/s]
Epoch 111, Train: -1.2250, test: -1.3029, patience: 20:  11%|█         | 111/1000 [00:07<01:01, 14.56it/s]
  self.load_state_dict(torch.load(path))
Discriminator loss: 0.0000, Prob loss: 0.0000:  26%|██▌       | 1032/4000 [00:05<00:14, 201.61it/s]


[[0.83324206 0.44325075 0.5241975 ]
 [0.7297297  0.6666667  0.37310925]
 [0.6756757  0.42857143 0.5546219 ]
 ...
 [0.6286705  0.5275927  0.5796142 ]
 [0.46296933 0.6939928  0.6594068 ]
 [0.59345514 0.510941   0.58612925]]


{'coverage': 1.0,
 'validity': 0.6216216216216216,
 'actionability': 0.20045045045045046,
 'sparsity': 0.7995495495495496,
 'proximity_categorical_hamming': nan,
 'proximity_categorical_jaccard': 0.1747434908992034,
 'proximity_continuous_manhattan': 0.2953791396179493,
 'proximity_continuous_euclidean': 0.1747434908992034,
 'proximity_continuous_mad': 2.469260804409134,
 'proximity_l2_jaccard': 0.1747434908992034,
 'proximity_mad_hamming': nan,
 'prob_plausibility': 1.0,
 'log_density_cf': 1.6123302,
 'log_density_test': 0.0727248,
 'lof_scores_cf': 1.0429159,
 'lof_scores_test': 1.0629689,
 'isolation_forest_scores_cf': 0.0517592142653124,
 'isolation_forest_scores_test': 0.029737427250345106}