In [1]:
# Enable IPython's autoreload extension in mode 2, which re-imports modules
# before each cell executes so that edits to imported source files take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [14]:
import numpy as np
import torch
from matplotlib import pyplot as plt

from counterfactuals.datasets import MoonsDataset
from counterfactuals.losses import MulticlassDiscLoss
from counterfactuals.cf_methods import PUMAL
from counterfactuals.generative_models import MaskedAutoregressiveFlow
from counterfactuals.discriminative_models import MultilayerPerceptron
from counterfactuals.metrics import CFMetrics

from counterfactuals.plot_utils import (
    plot_generative_model_distribution,
    plot_classifier_decision_region,
)

In [15]:
# Load the moons dataset from a CSV path that is relative to the notebook's
# working directory.
dataset = MoonsDataset("../data/moons.csv")

In [None]:
# Classifier whose decisions are explained below. The first argument is the
# number of input columns of the test split; [512, 512] and 2 are presumably
# the hidden-layer widths and the number of classes — confirm against
# MultilayerPerceptron's signature.
disc_model = MultilayerPerceptron(dataset.X_test.shape[1], [512, 512], 2)
# Training is disabled; the call is kept for reference on how a checkpoint can
# be produced.
# NOTE(review): this call would write "moons_mlp.pt", whereas the load below
# reads "globe-ce-moons/moons_mlp.pt" — confirm the checkpoint lives there.
# disc_model.fit(
#     dataset.train_dataloader(batch_size=128, shuffle=True),
#     dataset.test_dataloader(batch_size=128, shuffle=False),
#     epochs=5000,
#     patience=100,
#     lr=1e-3,
#     checkpoint_path="moons_mlp.pt",
# )
disc_model.load("globe-ce-moons/moons_mlp.pt")
# Last expression of the cell, so whatever eval() returns is shown as output.
disc_model.eval()

In [None]:
# Classifier predictions on the test split, flattened to a 1-D NumPy array.
y_pred = disc_model.predict(dataset.X_test).detach().numpy().flatten()

# Reference labels are recovered with an argmax over axis 1 of y_test
# (presumably one-hot rows — confirm against the dataset class).
y_true = np.argmax(dataset.y_test, axis=1)

# Fraction of test rows where prediction and reference label agree.
test_accuracy = (y_pred == y_true).mean()
print("Test accuracy:", test_accuracy)

In [7]:
def _predicted_targets(features):
    """Return the classifier's predictions for ``features`` as a column vector
    passed through ``dataset.y_transformer.transform``."""
    predicted = disc_model.predict(features).detach().numpy()
    return dataset.y_transformer.transform(predicted.reshape(-1, 1))


# Overwrite the stored targets of both splits with the classifier's own
# predictions, so everything downstream works with the model's labels rather
# than the ground truth. NOTE: this mutates `dataset` in place; any cell that
# compares predictions with dataset.y_test after this point compares the model
# with itself.
dataset.y_train = _predicted_targets(dataset.X_train)
dataset.y_test = _predicted_targets(dataset.X_test)

In [None]:
# Generative model used for the plausibility term. `features` is taken from the
# number of columns of the training split; context_features=2 presumably
# matches the width of the label encoding — confirm.
gen_model = MaskedAutoregressiveFlow(
    features=dataset.X_train.shape[1],
    hidden_features=16,
    num_blocks_per_layer=2,
    num_layers=5,
    context_features=2,
    batch_norm_within_layers=True,
    batch_norm_between_layers=True,
    use_random_permutations=True,
)
# The training loader is requested with noise_lvl=0.03 and shuffling; the test
# loader is built without either.
train_dataloader = dataset.train_dataloader(
    batch_size=256, shuffle=True, noise_lvl=0.03
)
test_dataloader = dataset.test_dataloader(batch_size=256, shuffle=False)

# Training is disabled; weights are read from an existing checkpoint instead.
# NOTE(review): the disabled call passes train_dataloader twice. If the second
# positional argument is the validation loader, `test_dataloader` (otherwise
# unused in this cell) was probably intended — confirm before re-enabling.
# gen_model.fit(
#     train_dataloader,
#     train_dataloader,
#     learning_rate=1e-3,
#     patience=100,
#     num_epochs=500,
#     checkpoint_path="moons_flow1.pth",
# )
gen_model.load("moons_flow1.pth")

In [9]:
# Counterfactuals move points from `source_class` towards `target_class`.
source_class = 0
target_class = 1

# Boolean mask over the test split: rows whose label (argmax over axis 1 of
# y_test) equals the source class. Computed once and reused for X and y.
is_source = np.argmax(dataset.y_test, axis=1) == source_class
X_test_origin = dataset.X_test[is_source]
y_test_origin = dataset.y_test[is_source]

In [10]:
# Explainer for the group-wise setting (K=2), built on the source-class test
# points with no feature marked as non-actionable and no Neptune logging.
cf_method = PUMAL(
    X=X_test_origin,
    cf_method_type="GCE",
    K=2,
    gen_model=gen_model,
    disc_model=disc_model,
    disc_model_criterion=MulticlassDiscLoss(eps=0.01),
    not_actionable_features=None,
    neptune_run=None,
)

# Plausibility threshold: the 0.1 quantile of the flow's log-probabilities over
# the training split (unshuffled loader).
train_dataloader_for_log_prob = dataset.train_dataloader(batch_size=4096, shuffle=False)
train_log_probs = gen_model.predict_log_prob(train_dataloader_for_log_prob)
log_prob_threshold = torch.quantile(train_log_probs, 0.1)

# Loader over the points to explain, as float tensors, in their original order.
origin_features = torch.tensor(X_test_origin).float()
origin_targets = torch.tensor(y_test_origin).float()
cf_dataloader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(origin_features, origin_targets),
    batch_size=4096,
    shuffle=False,
)

In [None]:
# Optimiser settings and loss weights for the group-wise (K=2) run.
optimisation_settings = dict(
    epochs=20000,
    lr=0.01,
    patience=500,
    alpha_dist=1e-1,
    alpha_plaus=10**2,
    alpha_class=10**5,
    alpha_s=10**3,
    alpha_k=10**1,
)
delta, Xs, ys_orig, ys_target = cf_method.explain_dataloader(
    dataloader=cf_dataloader,
    target_class=target_class,
    log_prob_threshold=log_prob_threshold,
    **optimisation_settings,
)

# Matrices exposed by the learned shift object; S is reduced column-wise here
# and row-wise below.
M, S, D = delta.get_matrices()
print(S.sum(axis=0))

# Counterfactuals are the explained points plus the learned shift.
Xs_cfs = Xs + delta().detach().numpy()

# Row-wise maximum of S and the column where it occurs.
values, indexes = S.max(dim=1)
total = len(values)
# Rows whose largest entry in S is exactly 1.
i_correct = indexes[values == 1]
print(f"Correct: {len(i_correct)}/{total}")
# Number of distinct columns among those rows.
print(len(set(i_correct.tolist())))

# Every input column is declared continuous for the metric computation.
n_features = dataset.X_train.shape[1]
metrics = CFMetrics(
    X_cf=Xs_cfs,
    y_target=ys_target,
    X_train=dataset.X_train,
    y_train=dataset.y_train,
    X_test=X_test_origin,
    y_test=y_test_origin,
    disc_model=disc_model,
    gen_model=gen_model,
    continuous_features=list(range(n_features)),
    categorical_features=dataset.categorical_features,
    prob_plausibility_threshold=log_prob_threshold,
)
# Last expression of the cell: the metrics result is displayed as cell output.
metrics.calc_all_metrics()

In [None]:
plt.figure(figsize=(15, 5))

# Group of each explained point: the column of S holding its largest entry.
groups = S.argmax(dim=1)
# Squeezed once; the same tensor is sliced for every group below.
magnitudes = M.squeeze()

# One panel per row of D (the grid has three slots in a single row).
for i in range(D.shape[0]):
    in_group = groups == i

    plt.subplot(1, 3, i + 1)
    plt.bar(range(2), D[i].detach().numpy())

    # Mean / standard deviation of the magnitudes of the points in this group,
    # and the number of points assigned to it.
    group_magnitudes = magnitudes[in_group]
    mean_magn = group_magnitudes.mean(axis=0)
    std_magn = group_magnitudes.std(axis=0)
    n_vectors = in_group.sum()
    plt.title(
        f"CF {i}, # of cfs: {n_vectors},  Magnitude: {mean_magn:.2f} +- {std_magn:.2f}"
    )

In [88]:
import matplotlib

In [None]:
# Group-wise teaser figure: original points (coloured by group), their
# counterfactuals (orange) and a grey segment joining each pair, drawn over the
# generative-model distribution and the classifier decision region.
fig, ax = plt.subplots(1, 1)

# Colour of the original points, indexed by group id.
group_colors = [
    "red",
    "blue",
    "green",
    "orange",
    "purple",
    "brown",
    "pink",
    "gray",
    "olive",
    "cyan",
]

# Group membership does not change inside the loop, so compute it once.
group_ids = S.argmax(dim=1)

for group_i in range(S.shape[1]):
    in_group = group_ids == group_i
    xs_group = Xs[in_group]
    xs_cfs_group = Xs_cfs[in_group]

    # `c` is a named colour, so no colormap is passed: matplotlib ignores
    # `cmap` in that case and emits a UserWarning.
    ax.scatter(
        xs_cfs_group[:, 0],
        xs_cfs_group[:, 1],
        c="orange",
        s=40,
        alpha=0.6,
    )
    ax.scatter(
        xs_group[:, 0],
        xs_group[:, 1],
        c=group_colors[group_i],
        s=40,
        alpha=0.6,
    )

    # One segment per (original, counterfactual) pair, using the first two
    # columns as plot coordinates.
    for (x0, y0), (x1, y1) in zip(xs_group[:, :2], xs_cfs_group[:, :2]):
        ax.arrow(
            x0,
            y0,
            x1 - x0,
            y1 - y0,
            # Zero head width hides the arrow head; the negative head length is
            # kept from the original (presumably to adjust where the segment
            # ends — confirm).
            head_width=0.00,
            head_length=-0.05,
            fc="grey",
            ec="grey",
            alpha=0.5,
        )

plot_generative_model_distribution(ax, gen_model, log_prob_threshold, 2)
plot_classifier_decision_region(ax, disc_model)

# Hide both axes and the surrounding frame.
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
for side in ("top", "right", "bottom", "left"):
    ax.spines[side].set_visible(False)

plt.tight_layout()
# plt.savefig("teaser_groupwise.pdf", dpi=300)
plt.show()

# Global

In [188]:
# Explainer for the global setting (K=1), built on the same source-class test
# points with no feature marked as non-actionable and no Neptune logging.
cf_method = PUMAL(
    X=X_test_origin,
    cf_method_type="GCE",
    K=1,
    gen_model=gen_model,
    disc_model=disc_model,
    disc_model_criterion=MulticlassDiscLoss(eps=0.01),
    not_actionable_features=None,
    neptune_run=None,
)

# Plausibility threshold: the 0.1 quantile of the flow's log-probabilities over
# the training split (unshuffled loader).
train_dataloader_for_log_prob = dataset.train_dataloader(batch_size=4096, shuffle=False)
train_log_probs = gen_model.predict_log_prob(train_dataloader_for_log_prob)
log_prob_threshold = torch.quantile(train_log_probs, 0.1)

# Loader over the points to explain, as float tensors, in their original order.
origin_features = torch.tensor(X_test_origin).float()
origin_targets = torch.tensor(y_test_origin).float()
cf_dataloader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(origin_features, origin_targets),
    batch_size=4096,
    shuffle=False,
)

In [None]:
# Optimiser settings and loss weights for the global (K=1) run; the alpha_s and
# alpha_k terms are switched off.
# NOTE(review): patience (500) exceeds epochs (100), so patience-based early
# stopping presumably cannot trigger in this run — confirm that is intended.
optimisation_settings = dict(
    epochs=100,
    lr=0.001,
    patience=500,
    alpha_dist=1e-1,
    alpha_plaus=10**3,
    alpha_class=10**8,
    alpha_s=0,
    alpha_k=0,
)
delta, Xs, ys_orig, ys_target = cf_method.explain_dataloader(
    dataloader=cf_dataloader,
    target_class=target_class,
    log_prob_threshold=log_prob_threshold,
    **optimisation_settings,
)

# Matrices exposed by the learned shift object; S is reduced column-wise here
# and row-wise below.
M, S, D = delta.get_matrices()
print(S.sum(axis=0))

# Counterfactuals are the explained points plus the learned shift.
Xs_cfs = Xs + delta().detach().numpy()

# Row-wise maximum of S and the column where it occurs.
values, indexes = S.max(dim=1)
total = len(values)
# Rows whose largest entry in S is exactly 1.
i_correct = indexes[values == 1]
print(f"Correct: {len(i_correct)}/{total}")
# Number of distinct columns among those rows.
print(len(set(i_correct.tolist())))

# Every input column is declared continuous for the metric computation.
n_features = dataset.X_train.shape[1]
metrics = CFMetrics(
    X_cf=Xs_cfs,
    y_target=ys_target,
    X_train=dataset.X_train,
    y_train=dataset.y_train,
    X_test=X_test_origin,
    y_test=y_test_origin,
    disc_model=disc_model,
    gen_model=gen_model,
    continuous_features=list(range(n_features)),
    categorical_features=dataset.categorical_features,
    prob_plausibility_threshold=log_prob_threshold,
)
# Last expression of the cell: the metrics result is displayed as cell output.
metrics.calc_all_metrics()

In [None]:
# Global teaser figure: original points (coloured by group), their
# counterfactuals (orange) and a grey segment joining each pair, drawn over the
# generative-model distribution and the classifier decision region. The result
# is written to "teaser_global.pdf".
fig, ax = plt.subplots(1, 1)

# Colour of the original points, indexed by group id.
group_colors = [
    "blue",
    "red",
    "green",
    "orange",
    "purple",
    "brown",
    "pink",
    "gray",
    "olive",
    "cyan",
]

# Group membership does not change inside the loop, so compute it once.
group_ids = S.argmax(dim=1)

for group_i in range(S.shape[1]):
    in_group = group_ids == group_i
    xs_group = Xs[in_group]
    xs_cfs_group = Xs_cfs[in_group]

    # `c` is a named colour, so no colormap is passed: matplotlib ignores
    # `cmap` in that case and emits a UserWarning.
    ax.scatter(
        xs_cfs_group[:, 0],
        xs_cfs_group[:, 1],
        c="orange",
        s=40,
        alpha=0.6,
    )
    ax.scatter(
        xs_group[:, 0],
        xs_group[:, 1],
        c=group_colors[group_i],
        s=40,
        alpha=0.6,
    )

    # One segment per (original, counterfactual) pair, using the first two
    # columns as plot coordinates.
    for (x0, y0), (x1, y1) in zip(xs_group[:, :2], xs_cfs_group[:, :2]):
        ax.arrow(
            x0,
            y0,
            x1 - x0,
            y1 - y0,
            # Zero head width hides the arrow head; the negative head length is
            # kept from the original (presumably to adjust where the segment
            # ends — confirm).
            head_width=0.00,
            head_length=-0.05,
            fc="grey",
            ec="grey",
            alpha=0.5,
        )

plot_generative_model_distribution(ax, gen_model, log_prob_threshold, 2)
plot_classifier_decision_region(ax, disc_model)

# Hide both axes and the surrounding frame.
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
for side in ("top", "right", "bottom", "left"):
    ax.spines[side].set_visible(False)

plt.tight_layout()
plt.savefig("teaser_global.pdf", dpi=300)
plt.show()

# Local

In [224]:
# Explainer for the local setting: K is passed as None (how PUMAL interprets
# None is not visible here — see its implementation). Otherwise identical to
# the previous setups: same points, no non-actionable features, no Neptune run.
cf_method = PUMAL(
    X=X_test_origin,
    cf_method_type="GCE",
    K=None,
    gen_model=gen_model,
    disc_model=disc_model,
    disc_model_criterion=MulticlassDiscLoss(eps=0.01),
    not_actionable_features=None,
    neptune_run=None,
)

# Plausibility threshold: the 0.1 quantile of the flow's log-probabilities over
# the training split (unshuffled loader).
train_dataloader_for_log_prob = dataset.train_dataloader(batch_size=4096, shuffle=False)
train_log_probs = gen_model.predict_log_prob(train_dataloader_for_log_prob)
log_prob_threshold = torch.quantile(train_log_probs, 0.1)

# Loader over the points to explain, as float tensors, in their original order.
origin_features = torch.tensor(X_test_origin).float()
origin_targets = torch.tensor(y_test_origin).float()
cf_dataloader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(origin_features, origin_targets),
    batch_size=4096,
    shuffle=False,
)

In [None]:
# Optimiser settings and loss weights for the local run; the alpha_s, alpha_d
# and alpha_k terms are switched off. alpha_d is only passed in this run.
optimisation_settings = dict(
    epochs=20000,
    lr=0.01,
    patience=500,
    alpha_dist=1e1,
    alpha_plaus=10**3,
    alpha_class=10**8,
    alpha_s=0,
    alpha_d=0,
    alpha_k=0,
)
delta, Xs, ys_orig, ys_target = cf_method.explain_dataloader(
    dataloader=cf_dataloader,
    target_class=target_class,
    log_prob_threshold=log_prob_threshold,
    **optimisation_settings,
)

# Matrices exposed by the learned shift object; S is reduced column-wise here
# and row-wise below.
M, S, D = delta.get_matrices()
print(S.sum(axis=0))

# Counterfactuals are the explained points plus the learned shift.
Xs_cfs = Xs + delta().detach().numpy()

# Row-wise maximum of S and the column where it occurs.
values, indexes = S.max(dim=1)
total = len(values)
# Rows whose largest entry in S is exactly 1.
i_correct = indexes[values == 1]
print(f"Correct: {len(i_correct)}/{total}")
# Number of distinct columns among those rows.
print(len(set(i_correct.tolist())))

# Every input column is declared continuous for the metric computation.
n_features = dataset.X_train.shape[1]
metrics = CFMetrics(
    X_cf=Xs_cfs,
    y_target=ys_target,
    X_train=dataset.X_train,
    y_train=dataset.y_train,
    X_test=X_test_origin,
    y_test=y_test_origin,
    disc_model=disc_model,
    gen_model=gen_model,
    continuous_features=list(range(n_features)),
    categorical_features=dataset.categorical_features,
    prob_plausibility_threshold=log_prob_threshold,
)
# Last expression of the cell: the metrics result is displayed as cell output.
metrics.calc_all_metrics()

In [None]:
# Local teaser figure: original points, their counterfactuals (orange) and a
# grey segment joining each pair, drawn over the generative-model distribution
# and the classifier decision region. The result is written to
# "teaser_local.pdf".
fig, ax = plt.subplots(1, 1)

# Colour palette for original points. Only the first entry is used in this
# figure, so the palette length does not limit the number of columns of S.
group_colors = [
    "blue",
    "red",
    "green",
    "orange",
    "purple",
    "brown",
    "pink",
    "gray",
    "olive",
    "cyan",
]

# Group membership does not change inside the loop, so compute it once instead
# of twice per column of S.
group_ids = S.argmax(dim=1)

for group_i in range(S.shape[1]):
    in_group = group_ids == group_i
    xs_group = Xs[in_group]
    xs_cfs_group = Xs_cfs[in_group]

    # `c` is a named colour, so no colormap is passed: matplotlib ignores
    # `cmap` in that case and emits a UserWarning.
    ax.scatter(
        xs_cfs_group[:, 0],
        xs_cfs_group[:, 1],
        c="orange",
        s=40,
        alpha=0.6,
    )
    # Every group is drawn in the same colour.
    ax.scatter(
        xs_group[:, 0],
        xs_group[:, 1],
        c=group_colors[0],
        s=40,
        alpha=0.6,
    )

    # One segment per (original, counterfactual) pair, using the first two
    # columns as plot coordinates.
    for (x0, y0), (x1, y1) in zip(xs_group[:, :2], xs_cfs_group[:, :2]):
        ax.arrow(
            x0,
            y0,
            x1 - x0,
            y1 - y0,
            # Zero head width hides the arrow head; the negative head length is
            # kept from the original (presumably to adjust where the segment
            # ends — confirm).
            head_width=0.00,
            head_length=-0.05,
            fc="grey",
            ec="grey",
            alpha=0.5,
        )

plot_generative_model_distribution(ax, gen_model, log_prob_threshold, 2)
plot_classifier_decision_region(ax, disc_model)

# Hide both axes and the surrounding frame.
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
for side in ("top", "right", "bottom", "left"):
    ax.spines[side].set_visible(False)

plt.tight_layout()
plt.savefig("teaser_local.pdf", dpi=300)
plt.show()