In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch

from tqdm import tqdm

from counterfactuals.datasets.moons import MoonsDataset
from counterfactuals.discriminative_models import LogisticRegression
from counterfactuals.generative_models import MaskedAutoregressiveFlow
from counterfactuals.losses import BinaryDiscLoss
from counterfactuals.cf_methods.regional_ppcef import RPPCEF, GCE

In [None]:
def plot_model_distribution(model, median_prob=None, disc_model=None, grid_size=200):
    """Draw density contours of ``model`` for context 0 and context 1 on a 2-D grid.

    The grid spans x in [-1.5, 2.5] and y in [-0.75, 1.25]. The model is queried
    once with an all-zeros context column and once with an all-ones context
    column; its outputs are exponentiated before plotting.

    Args:
        model: Callable ``model(points, context)`` taking a ``(P, 2)`` tensor of
            grid points and a ``(P, 1)`` context tensor, returning one value per
            point (treated here as a log-density).
        median_prob: Optional log-density threshold. When given, the region where
            the context-1 density is at least ``exp(median_prob)`` is shaded.
            May be a Python float, a NumPy scalar or a 0-dim tensor.
        disc_model: Unused; kept so existing call sites keep working.
        grid_size: Number of grid points per axis (default 200).

    Returns:
        The matplotlib ``Axes`` the contours were drawn on.
    """
    fig, ax = plt.subplots(1, 1)
    fig.set_size_inches(20, 12)

    xline = torch.linspace(-1.5, 2.5, grid_size)
    yline = torch.linspace(-0.75, 1.25, grid_size)
    # "ij" is the layout torch.meshgrid already used by default; passing it
    # explicitly avoids the warning emitted when the argument is omitted.
    xgrid, ygrid = torch.meshgrid(xline, yline, indexing="ij")
    xyinput = torch.stack([xgrid.reshape(-1), ygrid.reshape(-1)], dim=1)

    # Derive the batch size and output shape from the grid so they cannot
    # drift apart from the resolution chosen above.
    n_points = xyinput.shape[0]
    with torch.no_grad():
        zgrid0 = model(xyinput, torch.zeros(n_points, 1)).exp().reshape(xgrid.shape)
        zgrid1 = model(xyinput, torch.ones(n_points, 1)).exp().reshape(xgrid.shape)

    x_np = xgrid.numpy()
    y_np = ygrid.numpy()
    zgrid0 = zgrid0.numpy()
    zgrid1 = zgrid1.numpy()

    if median_prob is not None:
        # Convert to a plain float first so tensors (including ones that track
        # gradients) and NumPy scalars are all accepted as contour levels.
        density_threshold = float(np.exp(float(median_prob)))
        ax.contourf(
            x_np,
            y_np,
            zgrid1,
            levels=[density_threshold, density_threshold + 10.00],
            alpha=0.1,
            colors="#DC143C",
        )

    # Thin contour lines: grey for context 0, orange for context 1.
    for density, colormap in ((zgrid0, "Greys"), (zgrid1, "Oranges")):
        ax.contour(
            x_np,
            y_np,
            density,
            levels=10,
            cmap=colormap,
            linewidths=0.4,
            antialiased=True,
        )
    return ax

In [None]:
# Load the moons data and split the test set by class label.
dataset = MoonsDataset(file_path="../data/moons.csv")

# Counterfactuals start in `origin_class`; the target is the other binary label.
origin_class = 0
target_class = np.abs(1 - origin_class)

# Select the test rows (features, then labels) belonging to each class.
X_test_origin, X_test_target = (
    dataset.X_test[dataset.y_test == label] for label in (origin_class, target_class)
)
y_test_origin, y_test_target = (
    dataset.y_test[dataset.y_test == label] for label in (origin_class, target_class)
)
# if cf_method in ["ARES", "GLOBAL_CE"]:
#     K = 1
# elif cfg.counterfactuals.K is not None:
#     K = cfg.counterfactuals.K
# else:
#     K = X_test_origin.shape[0]

In [None]:
# Classifier on the 2-D inputs; weights are read from a saved checkpoint.
disc_model = LogisticRegression(input_size=2, target_size=1)
disc_model.load("../models/MoonsDataset/disc_model_LogisticRegression.pt")

# Conditional flow over the 2-D inputs with a 1-D context; also loaded from disk.
flow = MaskedAutoregressiveFlow(features=2, hidden_features=4, context_features=1)
flow.load("../models/MoonsDataset/gen_model_MaskedAutoregressiveFlow.pt")

# GCE is sized from the origin-class test split: N rows, D columns, and K = N.
n_origin = X_test_origin.shape[0]
n_features = X_test_origin.shape[1]
delta = GCE(N=n_origin, D=n_features, K=n_origin)

# Counterfactual search object wiring together the delta module and both models.
cf = RPPCEF(
    delta=delta,
    gen_model=flow,
    disc_model=disc_model,
    disc_model_criterion=BinaryDiscLoss(),
    neptune_run=None,  # no experiment tracking in this notebook
)

In [None]:
# Threshold = 0.25 quantile of the flow's log-probabilities over the
# unshuffled training loader (batch size 1024).
log_prob_threshold = flow.predict_log_prob(
    dataset.train_dataloader(1024, shuffle=False)
).quantile(0.25)
log_prob_threshold

In [None]:
# Wrap the origin-class test split as float32 tensors in a single-pass loader.
origin_tensor_dataset = torch.utils.data.TensorDataset(
    torch.tensor(X_test_origin, dtype=torch.float32),
    torch.tensor(y_test_origin, dtype=torch.float32),
)
cf_dataloader = torch.utils.data.DataLoader(
    origin_tensor_dataset, batch_size=4096, shuffle=False
)

# Run the batched counterfactual search; the fifth return value is discarded.
deltas, Xs, ys_orig, ys_target, _ = cf.search_batch(
    dataloader=cf_dataloader,
    median_log_prob=log_prob_threshold,
    epochs=20_000,
    patience=500,
    lr=3e-3,
    alpha=1000,
    alpha_s=1000,
    alpha_k=1000,
)

In [None]:
def entropy_loss(prob_dist):
    """Return the Shannon entropy (natural log) of each row of ``prob_dist``.

    Entries are floored at 1e-9 before the logarithm so that zeros do not
    produce ``-inf``/``nan``. The sum runs over ``dim=1``, giving one value per row.
    """
    floored = prob_dist.clamp(min=1e-9)
    return -(floored * floored.log()).sum(dim=1)


# Three hand-picked starting points that the optimisation below will move.
x_origin = torch.tensor([[-0.6, 0.8], [-0.6, 1.0], [0.4, 0.4]], requires_grad=False)


N = x_origin.shape[0]  # number of starting points
D = x_origin.shape[1]  # number of features per point
K = 3  # third GCE argument (presumably the number of components — TODO confirm)


# One context value per point: 0 for the origin side, 1 for the target side.
context_origin = torch.zeros(N, 1)
context_target = torch.ones(N, 1)


blcf = GCE(N, D, K)

optimizer = torch.optim.Adam(blcf.parameters(), lr=0.005)
min_loss = np.inf
no_improve = 0

num_iterations = 2000
patience = 100
alpha = 10
# lamda = 1

# Snapshots of the points over time, starting with the untouched origins.
# The plotting cell slices this as p_history[n::N], so every snapshot is
# expected to have N rows.
p_history = [x_origin.detach().numpy().copy()]

pbar = tqdm(range(num_iterations))

for i in pbar:
    optimizer.zero_grad()
    x_prim = blcf(x_origin)

    loss_components = cf.search_step(
        x_prim,
        x_origin,
        context_origin,
        context_target,
        alpha=alpha,
        median_log_prob=log_prob_threshold,
    )
    mean_loss = (
        loss_components["dist"].reshape(-1, 1) + alpha * loss_components["loss_disc"]
    ).mean()
    # mean_loss = loss_components["loss"].mean()
    # Entropy Loss to enforce class assignment!
    mean_loss = mean_loss + 100 * entropy_loss(blcf.sparsemax(blcf.s)).mean()
    # Show the scalar value rather than the tensor repr (which includes grad_fn).
    pbar.set_description(f"Loss: {mean_loss.item():.4f}")

    mean_loss.backward()
    optimizer.step()

    if mean_loss.item() < min_loss:
        min_loss = mean_loss.item()
        # Patience counts *consecutive* non-improving steps, so start over
        # whenever a new best loss is reached.
        no_improve = 0
    else:
        no_improve += 1

    # Record a snapshot every 10th step, when stopping early, and on the very
    # last iteration so the trajectory always ends at the latest iterate.
    stop_early = no_improve > patience
    if stop_early or i % 10 == 0 or i == num_iterations - 1:
        p_history.append(x_prim.detach().numpy().copy())
    if stop_early:
        break

p_history = np.concatenate(p_history)

In [None]:
# entropy_loss(blcf.sparsemax(blcf.s).sum(axis=0) / blcf.sparsemax(blcf.s).sum())
# Entropy loss to enforce the smallest possible number of components.

In [None]:
# Loss terms returned by cf.search_step on the last executed loop iteration.
loss_components

In [None]:
# Inspect the fitted GCE module through its get_matrices() accessor
# (presumably the learned assignment/shift matrices — TODO confirm against GCE).
blcf.get_matrices()

In [None]:
# Classifier outputs for the starting points and for x_prim, the last candidates
# computed in the optimisation loop; displayed together as a tuple.
disc_model(x_origin), disc_model(x_prim)

In [None]:
## Distribution Plot
# Density contours for both contexts plus the shaded above-threshold region.
ax = plot_model_distribution(cf.gen_model, log_prob_threshold, cf.disc_model)

## Classifier Line
# Assumes the classifier's first parameter is a (1, 2) weight matrix and the
# second a scalar bias — TODO confirm. The line w1*x + w2*y + b = 0 is then
# rewritten as y = m*x + c.
disc_params = list(disc_model.parameters())
w1, w2 = disc_params[0].detach().cpu().numpy()[0]
b = disc_params[1].detach().cpu().numpy().item()
m = -w1 / w2
c = -b / w2
xmin, xmax = -1.5, 2.5
ymin, ymax = -1.5, 2.5  # NOTE(review): defined but never used in this cell
xd = np.array([xmin, xmax])
yd = m * xd + c
ax.plot(xd, yd, "#ADD8E6", lw=2.0, ls="dashed")
ax.axis("off")

# p_history stacks N-row snapshots, so every N-th row belongs to the same point.
for point_idx in range(N):
    trajectory = p_history[point_idx::N]

    ## Arrows
    # One faint arrow per pair of consecutive snapshots.
    for (x_from, y_from), (x_to, y_to) in zip(trajectory[:-1, :2], trajectory[1:, :2]):
        ax.arrow(
            x_from,
            y_from,
            x_to - x_from,
            y_to - y_from,
            width=0.005,
            lw=0.001,
            length_includes_head=True,
            alpha=0.4,
            color="orange",
        )

    # Start point (black), intermediate points (orange), end point (crimson).
    start, middle, end = trajectory[0:1], trajectory[1:-1], trajectory[-1:]
    ax.scatter(start[:, 0], start[:, 1], c="k", s=100, alpha=0.8)
    ax.scatter(middle[:, 0], middle[:, 1], c="orange", s=20, alpha=0.8)
    ax.scatter(end[:, 0], end[:, 1], c="#DC143C", s=100, alpha=0.8)

    # Straight arrow from the start to the final position.
    ax.arrow(
        trajectory[0, 0],
        trajectory[0, 1],
        trajectory[-1, 0] - trajectory[0, 0],
        trajectory[-1, 1] - trajectory[0, 1],
        width=0.01,
        lw=0.001,
        length_includes_head=True,
        alpha=0.5,
        color="k",
    )

# Hiding the axes once is enough; the assignment suppresses the cell output.
_ = ax.axis("off")

## Save Figure
ax.figure.tight_layout()
ax.figure.savefig("Ours_Dist.pdf")