In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import torch
from matplotlib import pyplot as plt

from counterfactuals.datasets import MoonsDataset, LawDataset
from counterfactuals.cf_methods.regional_ppcef import RPPCEF
from counterfactuals.generative_models import MaskedAutoregressiveFlow
from counterfactuals.discriminative_models import LogisticRegression
from counterfactuals.losses import BinaryDiscLoss
from counterfactuals.metrics.metrics import evaluate_cf

In [None]:
# Select the dataset for this experiment. Only the selected dataset is loaded, so
# `dataset` and the dataloaders below always refer to the same data.
# If you switch datasets, make sure the model checkpoints loaded below match it.
USE_LAW_DATASET = False  # True -> LawDataset, False -> MoonsDataset

if USE_LAW_DATASET:
    dataset = LawDataset("../data/law.csv")
else:
    dataset = MoonsDataset("../data/moons.csv")

train_dataloader = dataset.train_dataloader(batch_size=1024, shuffle=True)
test_dataloader = dataset.test_dataloader(batch_size=1024, shuffle=False)

In [None]:
# Classifier to be explained: LogisticRegression(n_features, 1), where n_features is
# the number of columns of X_test (the second argument is presumably the output dim).
# The checkpoint directory is derived from the dataset class name so it always matches
# the loaded data (for MoonsDataset: ../models/MoonsDataset/...).
disc_model = LogisticRegression(dataset.X_test.shape[1], 1)
disc_model.load(f"../models/{type(dataset).__name__}/disc_model_LogisticRegression.pt")
# To retrain instead of loading the checkpoint:
# disc_model.fit(train_dataloader, test_dataloader, epochs=3000, lr=0.003)

In [None]:
# Test accuracy of the classifier against the dataset's labels. Run this before the
# cell that overwrites dataset.y_test with the classifier's own predictions; after that
# it trivially reports 1.0.
# Both sides are flattened so an (n,) vs (n, 1) shape mismatch cannot silently
# broadcast into an (n, n) comparison and produce a meaningless "accuracy".
test_pred = disc_model.predict(dataset.X_test).detach().numpy().reshape(-1)
np.mean(test_pred == np.asarray(dataset.y_test).reshape(-1))

In [None]:
# Conditional normalizing flow over the input features; context_features=1 means it is
# conditioned on a single value per sample (presumably the class label — confirm).
gen_model = MaskedAutoregressiveFlow(
    features=dataset.X_train.shape[1], hidden_features=4, context_features=1
)

# NOTE: mutates `dataset` in place. From here on, y_train / y_test hold the classifier's
# predictions instead of the ground-truth labels, and later cells (CF-search selection,
# metrics) use these predicted labels.
dataset.y_train = disc_model.predict(dataset.X_train).detach().numpy()
dataset.y_test = disc_model.predict(dataset.X_test).detach().numpy()

# Checkpoint directory follows the dataset class name (../models/MoonsDataset/... for Moons).
gen_model.load(f"../models/{type(dataset).__name__}/gen_model_MaskedAutoregressiveFlow.pt")

# To retrain the flow instead of loading the checkpoint. The noisy loader is built after
# relabelling, so it is the one passed to fit.
# NOTE(review): the original recipe passed train_dataloader here, leaving
# gen_train_dataloader unused — confirm the noisy loader is the intended input.
# gen_train_dataloader = dataset.train_dataloader(batch_size=1024, shuffle=True, noise_lvl=0.03)
# gen_model.fit(gen_train_dataloader, test_dataloader, num_epochs=2000, patience=50)

In [None]:
# How many test points carry label 0 in dataset.y_test (these are the classifier's
# predictions once the relabelling cell above has run).
len(dataset.X_test[dataset.y_test == 0])

In [None]:
# Size K from the number of test points labelled 0.
# NOTE(review): the search below runs on the points labelled 1, yet K is taken from
# the class-0 count — confirm this is intended rather than `dataset.y_test == 1`.
n_class0_test = dataset.X_test[dataset.y_test == 0].shape[0]

cf = RPPCEF(
    K=n_class0_test,
    gen_model=gen_model,
    disc_model=disc_model,
    disc_model_criterion=BinaryDiscLoss(),
    neptune_run=None,
)

# Only the test points labelled 1 are explained.
class1_mask = dataset.y_test == 1
class1_points = torch.utils.data.TensorDataset(
    torch.from_numpy(dataset.X_test[class1_mask]),
    torch.from_numpy(dataset.y_test[class1_mask]),
)
cf_dataloader = torch.utils.data.DataLoader(
    class1_points,
    batch_size=1024,
    shuffle=False,
)

# Median flow log-probability of the points being explained; it is handed to the
# search as `median_log_prob`.
median_log_prob = torch.median(gen_model.predict_log_prob(cf_dataloader))
print(median_log_prob)

deltas, X_orig, y_orig, y_target, logs = cf.search_batch(
    cf_dataloader,
    alpha=100,
    median_log_prob=median_log_prob,
    epochs=10000,
)

In [None]:
# Tail of the discriminator-loss history recorded by the CF search (convergence check).
loss_disc_history = logs["cf_search/loss_disc"]
loss_disc_history[-10:]

In [None]:
# One stacked panel per series logged by the CF search, sharing the x axis.
# The figure height scales with the number of series so panels are not squashed,
# and an empty `logs` simply draws nothing (plt.subplots(0, ...) would raise).
if logs:
    fig, axes = plt.subplots(
        len(logs), 1, sharex=True, squeeze=False, figsize=(8, 2 * len(logs))
    )
    for ax, (log_name, log_vals) in zip(axes[:, 0], logs.items()):
        ax.plot(log_vals, label=log_name)
        ax.legend()
    axes[-1, 0].set_xlabel("logged step")
    fig.tight_layout()

In [None]:
# Unpack the three matrices exposed by the first learned delta module. S is later passed
# to evaluate_cf as `S_matrix` (presumably point-to-group assignments — confirm).
first_delta = deltas[0]
M, S, D = first_delta.get_matrices()

In [None]:
# Dimensions of S (a torch tensor; it is .detach()-ed before evaluation below).
S.size()

In [None]:
# Count the columns of S whose entries do not sum to zero (presumably the groups that
# at least one point is assigned to).
S.sum(dim=0).ne(0).sum()

In [None]:
# Counterfactuals = explained points + offsets produced by the first delta module.
Xs_cfs = X_orig + deltas[0]().detach().numpy()

# The search yields a counterfactual for every explained point.
model_returned = np.ones(Xs_cfs.shape[0], dtype=bool)

metrics = evaluate_cf(
    gen_model=gen_model,
    disc_model=disc_model,
    X_cf=Xs_cfs,
    model_returned=model_returned,
    categorical_features=dataset.categorical_features,
    continuous_features=dataset.numerical_features,
    X_train=dataset.X_train,
    y_train=dataset.y_train.reshape(-1),
    X_test=X_orig,
    # Flatten like y_train: y_orig is not guaranteed 1-D (the plotting cell also
    # reshapes it), and mixing (n,) and (n, 1) label arrays risks silent broadcasting.
    y_test=y_orig.reshape(-1),
    median_log_prob=median_log_prob,
    S_matrix=S.detach().numpy(),
)
metrics

In [None]:
# Offsets from the first delta module, converted to NumPy and added to the explained points.
cf_offsets = deltas[0]().detach().numpy()
X_cf = X_orig + cf_offsets

In [None]:
fig, ax = plt.subplots(figsize=(8, 8))

# Explained points whose original label is 1.
explained = y_orig.reshape(-1) == 1

# Background: all test points, coloured by dataset.y_test.
ax.scatter(
    dataset.X_test[:, 0],
    dataset.X_test[:, 1],
    c=dataset.y_test,
    alpha=0.5,
    label="test points",
)
ax.scatter(X_cf[explained, 0], X_cf[explained, 1], c="r", label="counterfactuals")

# A thin grey segment from each original point to its counterfactual (no arrow head).
for before, after in zip(X_orig[explained], X_cf[explained]):
    ax.arrow(
        before[0],
        before[1],
        after[0] - before[0],
        after[1] - before[1],
        head_width=0.0,
        head_length=0.0,
        fc="gray",
        ec="gray",
        alpha=0.5,
        width=0.0001,
    )

ax.set_xlabel("feature 0")
ax.set_ylabel("feature 1")
ax.set_title("Counterfactuals (red) for test points with label 1")
ax.legend()
plt.show()

In [None]:
# Sanity check: torch.linalg.vector_norm(ord=2, dim=1) returns the per-row Euclidean
# norm. Seeded so the output is reproducible; torch is already imported at the top.
sample_points = torch.rand(10, 2, generator=torch.Generator().manual_seed(0))
row_norms = torch.linalg.vector_norm(sample_points, dim=1, ord=2)
torch.testing.assert_close(row_norms, sample_points.pow(2).sum(dim=1).sqrt())
row_norms

In [None]:
# Logged metric key -> short column header, for building a results table.
# Keys are runtime strings and kept verbatim (including "manhatan", which presumably
# must match the logged metric name — do not correct the spelling here).
column_mapping = dict(
    [
        ("parameters/dataset", "Dataset"),
        ("metrics/cf/K_vectors", "K"),
        ("metrics/cf/valid_cf_disc", "Validity"),
        ("metrics/cf/flow_prob_condition_acc", "Prob. Plaus"),
        ("metrics/cf/cf_belongs_to_group", "CFs assigned to Group"),
        ("metrics/cf/flow_log_density_cfs", "Log Dens."),
        ("metrics/cf/dissimilarity_proximity_continuous_euclidean", "L2"),
        ("metrics/cf/dissimilarity_proximity_continuous_manhatan", "L1"),
    ]
)