In [1]:
# Reload edited project modules (e.g. the `counterfactuals` package) automatically
# before each cell executes, so library changes are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import torch
from matplotlib import pyplot as plt

from counterfactuals.datasets import MoonsDataset
from counterfactuals.cf_methods.rppcef import RPPCEF
from counterfactuals.generative_models import MaskedAutoregressiveFlow
from counterfactuals.discriminative_models import LogisticRegression
from counterfactuals.losses import BinaryDiscLoss

In [3]:
def plot_model_distribution(
    model,
    median_prob=None,
    disc_model=None,
    resolution=200,
    draw_contours=False,
):
    """Create a 20x12-inch figure and optionally overlay class-conditional density contours.

    Args:
        model: conditional density model called as ``model(x, context)`` on an
            (N, 2) grid of points with an (N, 1) context of all 0s or all 1s.
            Its output is exponentiated, so it is assumed to return log-densities
            (TODO confirm against the flow's ``forward``).
        median_prob: optional log-density threshold. When given and
            ``draw_contours`` is True, the band where the class-1 density lies in
            ``[exp(median_prob), exp(median_prob) + 10]`` is shaded red.
        disc_model: accepted for call compatibility; currently unused.
        resolution: number of grid points along each axis.
        draw_contours: when False (the default, matching the earlier behaviour
            where every contour call was commented out) the empty axes are
            returned without evaluating ``model`` at all.

    Returns:
        The matplotlib ``Axes`` of the newly created figure.
    """
    fig, ax = plt.subplots(1, 1, figsize=(20, 12))

    if not draw_contours:
        # Nothing is drawn, so skip the expensive grid evaluation entirely.
        return ax

    xline = torch.linspace(-1.5, 2.5, resolution)
    yline = torch.linspace(-0.75, 1.25, resolution)
    # "ij" is the current default; passing it explicitly silences the
    # torch.meshgrid UserWarning about the default changing in the future.
    xgrid, ygrid = torch.meshgrid(xline, yline, indexing="ij")
    xyinput = torch.cat([xgrid.reshape(-1, 1), ygrid.reshape(-1, 1)], dim=1)
    # Derive the batch size from the grid instead of hard-coding 200 * 200.
    n_points = xyinput.shape[0]

    with torch.no_grad():
        zgrid0 = model(xyinput, torch.zeros(n_points, 1)).exp().reshape(resolution, resolution)
        zgrid1 = model(xyinput, torch.ones(n_points, 1)).exp().reshape(resolution, resolution)

    xs, ys = xgrid.numpy(), ygrid.numpy()
    zgrid0 = zgrid0.numpy()
    zgrid1 = zgrid1.numpy()

    if median_prob is not None:
        # float() accepts both Python floats and 0-d tensors, so the contour
        # levels are plain numbers rather than tensors.
        median_density = float(np.exp(float(median_prob)))
        ax.contourf(
            xs,
            ys,
            zgrid1,
            levels=[median_density, median_density + 10.00],
            alpha=0.1,
            colors="#DC143C",
        )

    ax.contour(xs, ys, zgrid0, levels=10, cmap="Greys", linewidths=0.4, antialiased=True)
    ax.contour(xs, ys, zgrid1, levels=10, cmap="Oranges", linewidths=0.4, antialiased=True)
    return ax

In [4]:
# dataset = LawDataset("../data/law.csv")
# train_dataloader = dataset.train_dataloader(batch_size=1024, shuffle=True)
# test_dataloader = dataset.test_dataloader(batch_size=1024, shuffle=False)

In [5]:
# Load the moons data and the two pretrained models used by the counterfactual search.
dataset = MoonsDataset("../data/moons.csv")
train_dataloader = dataset.train_dataloader(batch_size=1024, shuffle=True)
test_dataloader = dataset.test_dataloader(batch_size=1024, shuffle=False)

# Classifier whose decisions are to be flipped. The constructor arguments must match
# those used when the checkpoint was saved — NOTE(review): confirm against the training script.
disc_model = LogisticRegression(dataset.X_test.shape[1], 1)
disc_model.load("../models/MoonsDataset/disc_model_LogisticRegression.pt")

# Re-labeling (disabled): would replace the dataset labels with the classifier's own predictions.
# dataset.y_train = disc_model.predict(dataset.X_train).detach().numpy()
# dataset.y_test = disc_model.predict(dataset.X_test).detach().numpy()

# Conditional flow over the features. context_features=1 presumably carries the class
# label as context (plot_model_distribution feeds it all-0 / all-1 columns) — confirm.
gen_model = MaskedAutoregressiveFlow(
    features=dataset.X_train.shape[1], hidden_features=4, context_features=1
)
gen_model.load("../models/MoonsDataset/gen_model_MaskedAutoregressiveFlow.pt")

In [6]:
# Held-out accuracy of the classifier. Both sides are flattened so that an (n, 1)
# prediction compared with (n,) labels cannot silently broadcast to an (n, n)
# boolean matrix and report a meaningless score.
y_test_pred = disc_model.predict(dataset.X_test).detach().cpu().numpy().ravel()
test_accuracy = np.mean(y_test_pred == np.ravel(dataset.y_test))
print(f"Discriminative model accuracy: {test_accuracy}")

Discriminative model accuracy: 0.848780487804878


In [7]:
# Number of test points labelled 0 (the factual set searched below);
# used as the largest K in the option grid.
class0_test_points = dataset.X_test[dataset.y_test == 0]
max_samples = len(class0_test_points)

In [8]:
# Each configuration sets K (passed to RPPCEF) together with the weights of the
# plausibility and search penalties passed to `search_batch`.
options = [
    {"K": 1, "alpha_plausability": 0, "alpha_search": 0},
    {"K": max_samples, "alpha_plausability": 0, "alpha_search": 0},
    {"K": max_samples, "alpha_plausability": 0, "alpha_search": 1000},
    {"K": 1, "alpha_plausability": 1000, "alpha_search": 0},
    {"K": max_samples, "alpha_plausability": 1000, "alpha_search": 0},
    {"K": max_samples, "alpha_plausability": 1000, "alpha_search": 1000},
]

ALPHA = 1000  # value passed as `alpha` to search_batch for every configuration
EPOCHS = 10_000  # optimisation epochs per configuration
SEED = 0  # reseeded per configuration so each figure is reproducible on its own
BOUNDARY_X_RANGE = (-1.5, 2.5)  # x-extent of the drawn decision line
PLOT_XLIM = (-0.10, 1.05)
PLOT_YLIM = (-0.10, 1.10)


def plot_decision_boundary(ax, model, x_range=BOUNDARY_X_RANGE):
    """Draw the line ``w1 * x + w2 * y + b = 0`` of a 2-feature linear classifier on ``ax``.

    Assumes the first two entries of ``model.parameters()`` are a weight of
    shape (1, 2) and a single-element bias, as for one linear layer.
    """
    params = list(model.parameters())
    w1, w2 = params[0].detach().cpu().numpy()[0]
    b = params[1].detach().cpu().numpy().item()
    slope, intercept = -w1 / w2, -b / w2
    xd = np.array(x_range)
    ax.plot(xd, slope * xd + intercept, "#ADD8E6", lw=2.0, ls="dashed")


def plot_counterfactual_moves(ax, X_orig, X_cf):
    """Scatter factual points (translucent) and counterfactuals (red), joined by grey arrows."""
    ax.scatter(X_orig[:, 0], X_orig[:, 1], alpha=0.5)
    ax.scatter(X_cf[:, 0], X_cf[:, 1], c="r")
    for before, after in zip(X_orig, X_cf):
        ax.arrow(
            before[0],
            before[1],
            after[0] - before[0],
            after[1] - before[1],
            fc="gray",
            ec="gray",
            alpha=0.5,
            width=0.0001,
        )


# The factual set (test points labelled 0) is identical for every configuration,
# so its DataLoader is built once instead of on each iteration.
class0_mask = dataset.y_test == 0
cf_dataloader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(
        torch.from_numpy(dataset.X_test[class0_mask]),
        torch.from_numpy(dataset.y_test[class0_mask]),
    ),
    batch_size=1024,
    shuffle=False,
)

for opt in options:
    print(opt)
    K = opt["K"]
    alpha_plausability = opt["alpha_plausability"]
    alpha_search = opt["alpha_search"]

    # The search may involve randomness (e.g. initialisation); reseeding here
    # makes each configuration reproducible independently of run order.
    torch.manual_seed(SEED)
    np.random.seed(SEED)

    cf = RPPCEF(
        K=K,
        gen_model=gen_model,
        disc_model=disc_model,
        disc_model_criterion=BinaryDiscLoss(),
        neptune_run=None,
    )

    # Kept inside the loop (as originally) rather than assuming search_batch
    # leaves gen_model's weights untouched.
    median_log_prob = torch.median(gen_model.predict_log_prob(cf_dataloader))

    deltas, X_orig, y_orig, y_target, logs = cf.search_batch(
        cf_dataloader,
        alpha=ALPHA,
        alpha_plausability=alpha_plausability,
        alpha_search=alpha_search,
        median_log_prob=median_log_prob,
        epochs=EPOCHS,
    )

    M, S, D = deltas[0].get_matrices()
    # Counts the columns of S with a non-zero sum, reported as the vectors in use.
    print(f"Number of vectors: {(S.sum(axis=0) != 0).sum()}")

    X_cf = X_orig + deltas[0]().detach().numpy()

    ax = plot_model_distribution(cf.gen_model, median_log_prob, cf.disc_model)
    fig = ax.figure
    plot_decision_boundary(ax, disc_model)
    plot_counterfactual_moves(ax, X_orig, X_cf)

    ax.set_xlim(PLOT_XLIM)
    ax.set_ylim(PLOT_YLIM)
    fig.tight_layout()

    fig.savefig(f"K_{K}_Search_{alpha_search}_Plausability_{alpha_plausability}.pdf")
    # Close this specific figure so the large figures don't accumulate in memory.
    plt.close(fig)

{'K': 1, 'alpha_plausability': 0, 'alpha_search': 0}


  0%|                                                                                                         …

Number of vectors: 1


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


{'K': 103, 'alpha_plausability': 0, 'alpha_search': 0}


  0%|                                                                                                         …

Number of vectors: 48
{'K': 103, 'alpha_plausability': 0, 'alpha_search': 1000}


  0%|                                                                                                         …

Number of vectors: 5
{'K': 1, 'alpha_plausability': 1000, 'alpha_search': 0}


  0%|                                                                                                         …

Number of vectors: 1
{'K': 103, 'alpha_plausability': 1000, 'alpha_search': 0}


  0%|                                                                                                         …

Number of vectors: 40
{'K': 103, 'alpha_plausability': 1000, 'alpha_search': 1000}


  0%|                                                                                                         …

Number of vectors: 9
