In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import numpy as np

In [3]:
from nflows.flows import MaskedAutoregressiveFlow

from counterfactuals.datasets.heloc import HelocDataset
from counterfactuals.datasets.moons import MoonsDataset
from counterfactuals.datasets.law import LawDataset
from counterfactuals.datasets.compas import CompasDataset

from counterfactuals.optimizers.approach_gen_disc import ApproachGenDisc

from sklearn.linear_model import LogisticRegression
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import classification_report

from counterfactuals.metrics.metrics import (
    perc_valid_cf,
    perc_valid_actionable_cf,
    continuous_distance,
    categorical_distance,
    distance_l2_jaccard,
    distance_mad_hamming,
    plausibility,
    delta_proba
)

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
dataset = CompasDataset(file_path="../data/origin/compas_two_years.csv")
train_dataloader = dataset.train_dataloader(batch_size=64, shuffle=True)
test_dataloader = dataset.test_dataloader(batch_size=64, shuffle=False)

In [5]:
# clf = LogisticRegression()
clf = MLPClassifier((128, 64), max_iter=100)
clf.fit(dataset.X_train, dataset.y_train)
y_pred_train = clf.predict(dataset.X_train)
y_pred_test = clf.predict(dataset.X_test)
print(classification_report(dataset.y_test, y_pred_test, output_dict=False))

              precision    recall  f1-score   support

         0.0       0.79      0.77      0.78       423
         1.0       0.78      0.79      0.78       419

    accuracy                           0.78       842
   macro avg       0.78      0.78      0.78       842
weighted avg       0.78      0.78      0.78       842





In [6]:
flow = MaskedAutoregressiveFlow(features=dataset.X_test.shape[1], hidden_features=4, context_features=None)
cf = ApproachGenDisc(gen_model=flow, disc_model=clf)

In [7]:
cf.train_model(
    train_loader=train_dataloader,
    test_loader=test_dataloader,
    epochs=200,
    verbose=True
)

Epochs:   2%|▏         | 3/200 [00:00<00:17, 11.02it/s]

Epoch 0, Train: 43.587529828471524, test: 40.9719249180385


Epochs:   6%|▋         | 13/200 [00:01<00:15, 11.78it/s]

Epoch 10, Train: -0.2981909133734242, test: -0.4081937564270837


Epochs:  12%|█▏        | 23/200 [00:01<00:15, 11.68it/s]

Epoch 20, Train: -1.7054029251298597, test: -1.7455938820328032


Epochs:  16%|█▋        | 33/200 [00:02<00:13, 12.17it/s]

Epoch 30, Train: -2.7367403891778763, test: -2.654904763613428


Epochs:  22%|██▏       | 43/200 [00:03<00:12, 12.20it/s]

Epoch 40, Train: -5.598197437101795, test: -5.734824333872114


Epochs:  26%|██▋       | 53/200 [00:04<00:12, 12.17it/s]

Epoch 50, Train: -8.505825988708004, test: -8.692775419780187


Epochs:  32%|███▏      | 63/200 [00:05<00:11, 12.22it/s]

Epoch 60, Train: -11.140110646524738, test: -11.858657564435687


Epochs:  36%|███▋      | 73/200 [00:06<00:10, 12.21it/s]

Epoch 70, Train: -14.252436945515294, test: -14.949844496590751


Epochs:  42%|████▏     | 83/200 [00:06<00:09, 12.23it/s]

Epoch 80, Train: -16.196782081357895, test: -16.682552610124862


Epochs:  46%|████▋     | 93/200 [00:07<00:08, 12.30it/s]

Epoch 90, Train: -17.096112558918616, test: -17.896042619432723


Epochs:  52%|█████▏    | 103/200 [00:08<00:07, 12.17it/s]

Epoch 100, Train: -17.7854548423521, test: -18.370494161333358


Epochs:  56%|█████▋    | 113/200 [00:09<00:07, 12.22it/s]

Epoch 110, Train: -19.10747460396059, test: -19.635623659406388


Epochs:  62%|██████▏   | 123/200 [00:10<00:06, 12.18it/s]

Epoch 120, Train: -18.915714756135017, test: -19.579544884817942


Epochs:  66%|██████▋   | 133/200 [00:10<00:05, 12.23it/s]

Epoch 130, Train: -19.159069645789362, test: -19.48715795789446


Epochs:  72%|███████▏  | 143/200 [00:11<00:04, 12.23it/s]

Epoch 140, Train: -19.953757532181278, test: -19.762725830078125


Epochs:  76%|███████▋  | 153/200 [00:12<00:03, 12.20it/s]

Epoch 150, Train: -19.91091315977035, test: -20.95369257245745


Epochs:  82%|████████▏ | 163/200 [00:13<00:03, 12.15it/s]

Epoch 160, Train: -19.46411508129489, test: -19.560198102678573


Epochs:  86%|████████▋ | 173/200 [00:14<00:02, 12.17it/s]

Epoch 170, Train: -20.995983246834047, test: -21.2900573185512


Epochs:  92%|█████████▏| 183/200 [00:15<00:01, 12.19it/s]

Epoch 180, Train: -20.055115607477003, test: -21.225250925336564


Epochs:  96%|█████████▋| 193/200 [00:15<00:00, 12.17it/s]

Epoch 190, Train: -21.634810909148186, test: -21.794276646205358


Epochs: 100%|██████████| 200/200 [00:16<00:00, 12.11it/s]


In [8]:
cf.test_model(test_loader=test_dataloader)

              precision    recall  f1-score   support

         0.0       0.50      1.00      0.67       423
         1.0       0.00      0.00      0.00       419

    accuracy                           0.50       842
   macro avg       0.25      0.50      0.33       842
weighted avg       0.25      0.50      0.34       842



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [9]:
Xs_cfs = cf.generate_counterfactuals(Xs=dataset.X_test[:100], ys=dataset.y_test[:100], num_epochs=100, lr=0.005, alpha=20, beta=0.01)
Xs_cfs = torch.concat(Xs_cfs).detach()

0it [00:00, ?it/s]

100it [00:25,  3.85it/s]


In [11]:
y_cfs_pred_mlpc = clf.predict(Xs_cfs)
y_orig_pred_mlpc = clf.predict(dataset.X_test[:100])

In [12]:
X = dataset.X_test[:100]
ys_cfs_pred = clf.predict(Xs_cfs)
ys_orig_pred = clf.predict(dataset.X_test[:100])
ys_orig = dataset.y_test[:100].flatten()

{
    "valid_cf": perc_valid_cf(ys_orig_pred, y_cf=ys_cfs_pred),
    "valid_cf_for_orig_data": perc_valid_cf(ys_orig, y_cf=ys_cfs_pred),
    # "perc_valid_actionable_cf": perc_valid_actionable_cf(X=dataset.X_test[:100], X_cf=Xs_cfs, y=ys_orig_pred, y_cf=ys_cfs_pred, actionable_features=[1,2]),
    "continuous_distance": continuous_distance(X=X, X_cf=Xs_cfs, continuous_features=dataset.continuous_columns, metric='mad', X_all=dataset.X_test),
    "categorical_distance": categorical_distance(X=X, X_cf=Xs_cfs, categorical_features=dataset.categorical_columns, metric='jaccard', agg='mean'),
    "distance_l2_jaccard": distance_l2_jaccard(X=X, X_cf=Xs_cfs, continuous_features=dataset.continuous_columns, categorical_features=dataset.categorical_columns),
    "distance_mad_hamming": distance_mad_hamming(X=X, X_cf=Xs_cfs,
                                                 continuous_features=dataset.continuous_columns, categorical_features=dataset.categorical_columns, X_all=X, agg='mean'),
    "plausibility": plausibility(
        X, Xs_cfs, ys_orig,
        continuous_features_all=dataset.continuous_columns,
        categorical_features_all=dataset.categorical_columns,
        X_train=dataset.X_train,
        ratio_cont=None
    ),
    "delta_probability": delta_proba(X, Xs_cfs, classifier=clf),
    "log_density": np.mean(cf.predict_model(test_dataloader))
}

{'valid_cf': 0.02,
 'valid_cf_for_orig_data': 0.22,
 'continuous_distance': 10.4251913975653,
 'categorical_distance': 0.48,
 'distance_l2_jaccard': 0.27139890054267696,
 'distance_mad_hamming': 5.112228220476975,
 'plausibility': 7.521560593275668,
 'delta_probability': 0.020853292,
 'log_density': 21.179031}