In [1]:
%load_ext autoreload
%autoreload 2

In [6]:
import torch
import numpy as np
from counterfactuals.datasets import MoonsDataset
from counterfactuals.losses import MulticlassDiscLoss
from counterfactuals.cf_methods import PUMAL
from counterfactuals.generative_models import MaskedAutoregressiveFlow
from counterfactuals.discriminative_models import MultilayerPerceptron
from counterfactuals.metrics import CFMetrics

In [3]:
dataset = MoonsDataset("../data/moons.csv")

In [4]:
disc_model = MultilayerPerceptron(
    dataset.X_test.shape[1], [512, 512], dataset.y_test.shape[1]
)
disc_model.fit(
    dataset.train_dataloader(batch_size=128, shuffle=True),
    dataset.test_dataloader(batch_size=128, shuffle=False),
    epochs=5000,
    patience=100,
    lr=1e-3,
)
disc_model.eval()

Epoch 357, Train: 0.0027, test: 0.0330, patience: 100:   7%|▋         | 358/5000 [00:07<01:36, 48.01it/s]
  self.load_state_dict(torch.load(path))


MultilayerPerceptron(
  (layers): ModuleList(
    (0): Linear(in_features=2, out_features=512, bias=True)
    (1): Linear(in_features=512, out_features=512, bias=True)
    (2): Linear(in_features=512, out_features=2, bias=True)
  )
  (relu): ReLU()
  (dropout): Dropout(p=0.2, inplace=False)
  (final_activation): Softmax(dim=1)
  (criterion): CrossEntropyLoss()
)

In [7]:
y_pred = disc_model.predict(dataset.X_test).detach().numpy().flatten()
print("Test accuracy:", (y_pred == np.argmax(dataset.y_test, axis=1)).mean())

Test accuracy: 1.0


In [8]:
dataset.y_train = dataset.y_transformer.transform(
    disc_model.predict(dataset.X_train).detach().numpy().reshape(-1, 1)
)
dataset.y_test = dataset.y_transformer.transform(
    disc_model.predict(dataset.X_test).detach().numpy().reshape(-1, 1)
)

In [9]:
gen_model = MaskedAutoregressiveFlow(
    features=dataset.X_train.shape[1],
    hidden_features=16,
    num_blocks_per_layer=2,
    num_layers=5,
    context_features=2,
    batch_norm_within_layers=True,
    batch_norm_between_layers=True,
    use_random_permutations=True,
)
train_dataloader = dataset.train_dataloader(
    batch_size=256, shuffle=True, noise_lvl=0.03
)
test_dataloader = dataset.test_dataloader(batch_size=256, shuffle=False)

gen_model.fit(
    train_dataloader,
    train_dataloader,
    learning_rate=1e-3,
    patience=100,
    num_epochs=500,
    checkpoint_path="moons_flow1.pth",
)

Epoch 216, Train: -0.9051, test: -0.9427, patience: 100:  43%|████▎     | 216/500 [00:09<00:12, 22.02it/s]
  self.load_state_dict(torch.load(path))


In [10]:
source_class = 0
target_class = 1
X_test_origin = dataset.X_test[np.argmax(dataset.y_test, axis=1) == source_class]
y_test_origin = dataset.y_test[np.argmax(dataset.y_test, axis=1) == source_class]

In [14]:
cf_method = PUMAL(
    X=X_test_origin,
    cf_method_type="GCE",
    K=6,
    gen_model=gen_model,
    disc_model=disc_model,
    disc_model_criterion=MulticlassDiscLoss(eps=0.01),
    not_actionable_features=None,
    neptune_run=None,
)

train_dataloader_for_log_prob = dataset.train_dataloader(batch_size=4096, shuffle=False)
log_prob_threshold = torch.quantile(
    gen_model.predict_log_prob(train_dataloader_for_log_prob),
    0.25,
)

cf_dataloader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(
        torch.tensor(X_test_origin).float(),
        torch.tensor(y_test_origin).float(),
    ),
    batch_size=4096,
    shuffle=False,
)

In [15]:
delta, Xs, ys_orig, ys_target = cf_method.explain_dataloader(
    dataloader=cf_dataloader,
    target_class=target_class,
    epochs=20000,
    lr=0.01,
    patience=500,
    alpha_dist=1e-1,
    alpha_plaus=10**4,
    alpha_class=10**5,
    alpha_s=10**4,
    alpha_k=10**3,
    alpha_d=10**2,
    log_prob_threshold=log_prob_threshold,
)
Xs_cfs = Xs + delta().detach().numpy()

metrics = CFMetrics(
    X_cf=Xs_cfs,
    y_target=ys_target,
    X_train=dataset.X_train,
    y_train=dataset.y_train,
    X_test=X_test_origin,
    y_test=y_test_origin,
    disc_model=disc_model,
    gen_model=gen_model,
    continuous_features=list(range(dataset.X_train.shape[1])),
    categorical_features=dataset.categorical_features,
    prob_plausibility_threshold=log_prob_threshold,
)
metrics.calc_all_metrics()

loss: 5811.2456, dist: 0.0612, max_inner: 466.0587, loss_disc: 0.0000, delta_loss: 5345.1807:  90%|█████████ | 18019/20000 [01:49<00:12, 164.27it/s]       


{'coverage': 1.0,
 'validity': 1.0,
 'actionability': 0.0,
 'sparsity': 1.0,
 'proximity_categorical_hamming': nan,
 'proximity_categorical_jaccard': 0.4575844545879927,
 'proximity_continuous_manhattan': 0.6115565376505436,
 'proximity_continuous_euclidean': 0.4575844545879927,
 'proximity_continuous_mad': 3.1594508689584084,
 'proximity_l2_jaccard': 0.4575844545879927,
 'proximity_mad_hamming': nan,
 'prob_plausibility': 0.8737864077669902,
 'log_density_cf': 1.5956664,
 'log_density_test': 1.3673558,
 'lof_scores_cf': 1.0299549,
 'lof_scores_test': 1.0530212,
 'isolation_forest_scores_cf': 0.006346049550843414,
 'isolation_forest_scores_test': 0.003182315741079119}