In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from pathlib import Path

import numpy as np
import torch

from counterfactuals.datasets import LawDataset
from counterfactuals.cf_methods.ppcef import PPCEF
from counterfactuals.generative_models import MaskedAutoregressiveFlow
from counterfactuals.discriminative_models import MultilayerPerceptron
from counterfactuals.losses import MulticlassDiscLoss
from counterfactuals.metrics import evaluate_cf

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Build the dataset wrapper from the CSV file; the path is relative, so it is
# resolved against the notebook's current working directory.
dataset = LawDataset("../data/law.csv")

In [4]:
# Display the feature column names exposed by the dataset object.
dataset.feature_columns

['lsat', 'gpa', 'zfygpa', 'race']

In [5]:
# Classifier whose decisions will be explained. The input width is taken from
# the test feature matrix; the remaining arguments are presumably the hidden
# layer sizes ([512, 512]) and the number of output classes (2) -- confirm
# against MultilayerPerceptron's signature.
disc_checkpoint = "law_disc_model.pt"
disc_model = MultilayerPerceptron(dataset.X_test.shape[1], [512, 512], 2)

# Train only when no checkpoint is present, so a fresh checkout can still be
# executed top to bottom while repeat runs skip the expensive fit.
# NOTE(review): assumes fit() writes its weights to `checkpoint_path`.
if not Path(disc_checkpoint).exists():
    disc_model.fit(
        dataset.train_dataloader(batch_size=128, shuffle=True),
        dataset.test_dataloader(batch_size=128, shuffle=False),
        epochs=5000,
        patience=100,
        lr=1e-3,
        checkpoint_path=disc_checkpoint,
    )

# Always restore from the checkpoint so the weights in memory are the saved ones,
# regardless of whether training ran in this session.
disc_model.load(disc_checkpoint)

  self.load_state_dict(torch.load(path))


In [6]:
# Predict on the held-out features, move the result out of the autograd graph
# into a flat NumPy vector, and report the fraction matching the stored labels.
test_outputs = disc_model.predict(dataset.X_test)
y_pred = test_outputs.detach().numpy().flatten()
test_accuracy = (y_pred == dataset.y_test).mean()
print("Test accuracy:", test_accuracy)

Test accuracy: 0.7680180180180181


In [7]:
def _predicted_labels(features):
    """Return the classifier's predictions for `features` as a detached NumPy array."""
    return disc_model.predict(features).detach().numpy()


# Overwrite the stored labels of both splits with the classifier's own predictions.
# NOTE: after this cell, dataset.y_test no longer holds ground truth, so
# re-running the accuracy cell above would compare predictions with themselves.
dataset.y_train = _predicted_labels(dataset.X_train)
dataset.y_test = _predicted_labels(dataset.X_test)

In [8]:
# Density model used by the counterfactual search: a masked autoregressive flow
# over the feature vector with a single context feature (presumably the class
# label -- confirm against MaskedAutoregressiveFlow).
flow_checkpoint = "law_flow.pth"
gen_model = MaskedAutoregressiveFlow(
    features=dataset.X_train.shape[1],
    hidden_features=16,
    num_blocks_per_layer=4,
    num_layers=8,
    context_features=1,
    batch_norm_within_layers=True,
    batch_norm_between_layers=True,
    use_random_permutations=True,
)
# The training loader is built with noise_lvl=0.03; the test loader is built
# without a noise argument and without shuffling.
train_dataloader = dataset.train_dataloader(
    batch_size=256, shuffle=True, noise_lvl=0.03
)
test_dataloader = dataset.test_dataloader(batch_size=256, shuffle=False)

# Train only when no checkpoint is present, so the notebook runs end to end on a
# fresh checkout while repeat runs skip the fit.
# The second positional argument is assumed to be the validation loader (as in
# the classifier's fit call); the held-out loader is passed here instead of the
# training loader so that `patience` is driven by data the flow was not fit on.
# NOTE(review): assumes fit() writes its weights to `checkpoint_path`.
if not Path(flow_checkpoint).exists():
    gen_model.fit(
        train_dataloader,
        test_dataloader,
        learning_rate=1e-3,
        patience=100,
        num_epochs=500,
        checkpoint_path=flow_checkpoint,
    )

# Always restore from the checkpoint so the weights in memory are the saved ones.
gen_model.load(flow_checkpoint)

  self.load_state_dict(torch.load(path))


In [9]:
# torch.nn.functional.softmax(torch.rand(3, 4), dim=1)

In [10]:
# Assemble the counterfactual explainer from the flow, the classifier and a
# multiclass discriminator loss; no Neptune run is attached.
cf = PPCEF(
    disc_model=disc_model,
    gen_model=gen_model,
    disc_model_criterion=MulticlassDiscLoss(),
    neptune_run=None,
)

# One unshuffled pass over the test split, reused for both the threshold and the search.
cf_dataloader = dataset.test_dataloader(batch_size=1024, shuffle=False)

# Threshold = lower quartile (q = 0.25) of the flow's log-probabilities on that split.
test_log_probs = gen_model.predict_log_prob(cf_dataloader)
log_prob_threshold = torch.quantile(test_log_probs, 0.25)

# Run the search, then unpack the five returned pieces.
explanation = cf.explain_dataloader(
    cf_dataloader,
    alpha=100,
    log_prob_threshold=log_prob_threshold,
    epochs=20000,
)
deltas, X_orig, y_orig, y_target, logs = explanation

Discriminator loss: 0.0148, Prob loss: 6.8239: 100%|██████████| 20000/20000 [03:18<00:00, 100.57it/s] 


In [11]:
# Counterfactuals as produced by the search: the original points shifted by the
# deltas returned from cf.explain_dataloader, with no post-processing.
X_cf = X_orig + deltas
# A vector of ones, one entry per counterfactual row.
all_returned = np.ones(len(X_cf))

evaluate_cf(
    X_cf=X_cf,
    y_target=y_target,
    model_returned=all_returned,
    disc_model=disc_model,
    gen_model=gen_model,
    X_train=dataset.X_train,
    y_train=dataset.y_train,
    X_test=X_orig,
    y_test=y_orig,
    continuous_features=dataset.numerical_features,
    categorical_features=dataset.categorical_features,
    # Despite the parameter name, the value passed is the 0.25-quantile
    # threshold computed alongside the search, not a median.
    median_log_prob=log_prob_threshold,
)

{'coverage': 1.0,
 'validity': 0.49099099099099097,
 'actionability': 0.0,
 'sparsity': 1.0,
 'proximity_categorical_hamming': 0.8814228418272965,
 'proximity_categorical_jaccard': 0.8814228418272965,
 'proximity_continuous_manhattan': 0.9667337269361462,
 'proximity_continuous_euclidean': 0.8814228418272965,
 'proximity_continuous_mad': 2.693339854304844,
 'proximity_l2_jaccard': 0.8814228418272965,
 'proximity_mad_hamming': 2.693339854304844,
 'prob_plausibility': 0.0,
 'log_density_cf': -inf,
 'log_density_test': 16.452646,
 'lof_scores_cf': 20.599035,
 'lof_scores_test': 1.0911144,
 'isolation_forest_scores_cf': -0.049082294113076706,
 'isolation_forest_scores_test': 0.07767746732699544}

In [12]:
# torch.nn.functional.gumbel_softmax(torch.rand(4, 3), tau=0.1, dim=1)

In [44]:
# Start from the raw counterfactuals and work on a copy so X_cf stays untouched.
X_cf = X_orig + deltas
X_cf_cat = X_cf.copy()

# Snap columns 3 onward to a hard one-hot row: the largest entry in each row
# becomes 1 and every other entry 0. (A gumbel-softmax relaxation with tau=0.1
# was tried here previously and left disabled.)
# NOTE(review): columns 3+ are presumably the one-hot encoding of 'race' -- confirm.
cat_scores = X_cf_cat[:, 3:]
n_categories = cat_scores.shape[1]
max_indices = cat_scores.argmax(axis=1)
one_hot_rows = np.eye(n_categories)
X_cf_cat[:, 3:] = one_hot_rows[max_indices]

In [45]:
# Display the dimensions of the counterfactual array.
X_cf.shape

(444, 11)

In [46]:
# Total absolute difference between projected counterfactuals and originals over
# columns 3 onward. If those columns are one-hot (presumably -- confirm), each
# row that switched category contributes 2 to this sum.
categorical_shift = X_cf_cat[:, 3:] - X_orig[:, 3:]
np.abs(categorical_shift).sum()

160.0

In [47]:
# Same evaluation as for the raw counterfactuals, but on the version whose
# columns 3+ were snapped to one-hot rows (X_cf_cat).
projected_eval_inputs = dict(
    disc_model=disc_model,
    gen_model=gen_model,
    X_cf=X_cf_cat,
    # A vector of ones, one entry per counterfactual row.
    model_returned=np.ones(len(X_cf_cat)),
    continuous_features=dataset.numerical_features,
    categorical_features=dataset.categorical_features,
    X_train=dataset.X_train,
    y_train=dataset.y_train,
    X_test=X_orig,
    y_test=y_orig,
    # Despite the parameter name, this is the 0.25-quantile threshold, not a median.
    median_log_prob=log_prob_threshold,
    y_target=y_target,
)
evaluate_cf(**projected_eval_inputs)

{'coverage': 1.0,
 'validity': 1.0,
 'actionability': 0.0,
 'sparsity': 0.30548730548730546,
 'proximity_categorical_hamming': 0.17837772099491042,
 'proximity_categorical_jaccard': 0.2766578192750087,
 'proximity_continuous_manhattan': 0.3592241762799381,
 'proximity_continuous_euclidean': 0.2766578192750087,
 'proximity_continuous_mad': 2.014457961167594,
 'proximity_l2_jaccard': 0.2766578192750087,
 'proximity_mad_hamming': 1.9161778628874957,
 'prob_plausibility': 0.7027027027027027,
 'log_density_cf': 20.070389,
 'log_density_test': 16.452646,
 'lof_scores_cf': 1.3323857,
 'lof_scores_test': 1.0911144,
 'isolation_forest_scores_cf': 0.036989625748853006,
 'isolation_forest_scores_test': 0.07767746732699544}