In [None]:
# Re-import edited project modules (e.g. `counterfactuals.*`) automatically
# before each cell runs, so library changes are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [None]:
import torch
from matplotlib import pyplot as plt

from counterfactuals.datasets import MoonsDataset
from counterfactuals.cf_methods.ppcef import PPCEF
from counterfactuals.generative_models import MaskedAutoregressiveFlow
from counterfactuals.discriminative_models import LogisticRegression
from counterfactuals.losses import BinaryDiscLoss

In [None]:
# Seed torch before anything consumes its RNG (loader shuffling, model
# initialisation, the counterfactual search). This makes Restart & Run All
# reproducible. NOTE(review): any splitting done inside MoonsDataset via
# numpy/sklearn is not covered by this seed; confirm.
SEED = 42
torch.manual_seed(SEED)

# Load the two-moons data and build the loaders used to train the classifier.
# The training loader is shuffled, and the test loader keeps a fixed order.
dataset = MoonsDataset("../data/moons.csv")
train_dataloader = dataset.train_dataloader(batch_size=1024, shuffle=True)
test_dataloader = dataset.test_dataloader(batch_size=1024, shuffle=False)

In [None]:
# Discriminative model: logistic regression over all input features with a
# single output, which is later paired with BinaryDiscLoss in PPCEF.
n_features = dataset.X_test.shape[1]
disc_model = LogisticRegression(n_features, 1)
disc_model.fit(
    train_dataloader,
    test_dataloader,
    epochs=1000,
    lr=0.01,
)

In [None]:
# Generative model: a masked autoregressive flow over the input features with
# one context feature. The context is presumably the class label; confirm in
# MaskedAutoregressiveFlow.
gen_model = MaskedAutoregressiveFlow(
    features=dataset.X_train.shape[1], hidden_features=8, context_features=1
)
# The flow gets its own training loader built with noise_lvl=0.03 (presumably
# input-noise augmentation). Fit on this loader, not on the classifier's
# train_dataloader; otherwise the noisy loader would be built and never used.
gen_train_dataloader = dataset.train_dataloader(
    batch_size=1024, shuffle=True, noise_lvl=0.03
)
gen_model.fit(gen_train_dataloader, test_dataloader, num_epochs=1000)

In [None]:
# Build the PPCEF explainer from the trained flow and classifier, without
# Neptune experiment tracking.
cf = PPCEF(
    disc_model_criterion=BinaryDiscLoss(),
    gen_model=gen_model,
    disc_model=disc_model,
    neptune_run=None,
)

# Counterfactuals are searched for the test set, in a fixed (unshuffled) order.
cf_dataloader = dataset.test_dataloader(batch_size=1024, shuffle=False)

# Median log-probability of the test points under the flow. It is passed to the
# search as `median_log_prob` (presumably a plausibility threshold; confirm
# in PPCEF).
test_log_probs = gen_model.predict_log_prob(cf_dataloader)
median_log_prob = torch.median(test_log_probs)
print(median_log_prob)

search_results = cf.search_batch(
    cf_dataloader,
    alpha=100,
    median_log_prob=median_log_prob,
    epochs=4000,
)
deltas, X_orig, y_orig, y_target, logs = search_results

In [None]:
# Show the last ten recorded discriminator-loss values to check whether the
# search has settled.
disc_loss_history = logs["cf_search/loss_disc"]
disc_loss_history[-10:]

In [None]:
# Plot every logged search metric in its own stacked panel on an explicit
# figure. The figure height grows with the number of panels so they stay
# readable, and all panels share the x-axis. The guard avoids
# plt.subplots(0, 1), which matplotlib rejects, when nothing was logged.
if logs:
    fig_logs, log_axes = plt.subplots(
        len(logs), 1, figsize=(8, 2.5 * len(logs)), sharex=True, squeeze=False
    )
    for ax_log, (log_name, log_vals) in zip(log_axes[:, 0], logs.items()):
        ax_log.plot(log_vals, label=log_name)
        ax_log.legend()
    log_axes[-1, 0].set_xlabel("logged step")
    fig_logs.tight_layout()

In [None]:
# Counterfactual points: each original point shifted by the delta returned
# from the PPCEF search.
X_cf = X_orig + deltas

In [None]:
# Plot the original test points coloured by label. For the points labelled 1,
# also plot their counterfactuals in red and join each original to its
# counterfactual with a thin grey segment.
fig, ax = plt.subplots(figsize=(8, 8))

# y_orig is flattened because it may come back as a column (n, 1).
is_class_1 = y_orig.reshape(-1) == 1
ax.scatter(X_orig[:, 0], X_orig[:, 1], c=y_orig.reshape(-1))
# Reuse X_cf (= X_orig + deltas) so the points and segments use the same coordinates.
ax.scatter(X_cf[is_class_1, 0], X_cf[is_class_1, 1], c="r", label="counterfactual")
for before, after in zip(X_orig[is_class_1], X_cf[is_class_1]):
    # A zero-size head draws the arrow as a plain connector line.
    ax.arrow(
        before[0],
        before[1],
        after[0] - before[0],
        after[1] - before[1],
        head_width=0.0,
        head_length=0.0,
        fc="gray",
        ec="gray",
        alpha=0.5,
        width=0.0001,
    )
ax.set(
    title="PPCEF counterfactuals (Moons test set)",
    xlabel="feature 0",
    ylabel="feature 1",
)
ax.legend();

In [None]:
# Preview only the first few counterfactuals instead of dumping the whole array.
X_cf[:10]