In [None]:
import torch
import numpy as np

from counterfactuals.datasets import PolishBankDataset, WineDataset
from counterfactuals.discriminative_models import MultilayerPerceptron

# from counterfactuals.losses import MulticlassDiscLoss
from counterfactuals.generative_models import MaskedAutoregressiveFlow
from counterfactuals.cf_methods.ppcef import PPCEF
from counterfactuals.metrics.metrics import evaluate_cf

from sklearn.metrics import classification_report

import matplotlib.pyplot as plt
from matplotlib import cm

In [2]:
# Probe for a CUDA device first ...
device = "cuda" if torch.cuda.is_available() else "cpu"
# ... then pin the notebook to the CPU: this unconditional assignment
# overrides the probe above, so every later cell sees "cpu".
device = "cpu"

In [None]:
# Dataset selection: keep exactly one `dataset = ...` line active.
# The classifier cell further down is built with target_size=3, so the active
# dataset is expected to have three classes.
# dataset = BlobsDataset("../data/blobs.csv")
# dataset = PolishBankDataset("../data/polish_bankruptcy.csv")
# dataset = MoonsDataset("../data/moons.csv")
dataset = WineDataset("../data/wine.csv")

# Loaders reused by the later cells that fit the flow and compute the
# log-probability threshold.
train_dataloader = dataset.train_dataloader(
    batch_size=128,
    shuffle=True,
    noise_lvl=0,
)
test_dataloader = dataset.test_dataloader(batch_size=128, shuffle=False)

In [None]:
# Discriminative model: an MLP with two hidden layers (256 and 128 units)
# and three outputs, sized to the dataset's feature count.
n_features = dataset.X_test.shape[1]
disc_model = MultilayerPerceptron(
    input_size=n_features,
    target_size=3,
    hidden_layer_sizes=[256, 128],
)

# Dedicated loaders are built for this fit rather than reusing the shared
# ones (presumably batch size 128, shuffled train / ordered test -- the
# arguments are positional, so confirm against the dataset class).
disc_train_loader = dataset.train_dataloader(128, True)
disc_test_loader = dataset.test_dataloader(128, False)
disc_model.fit(disc_train_loader, disc_test_loader, epochs=500, lr=1e-5)

# Per-class report of the classifier's predictions on the test split.
preds = disc_model.predict(dataset.X_test)
print(classification_report(dataset.y_test.flatten(), preds.numpy()))

In [None]:
# Generative model: a masked autoregressive flow over the dataset features
# with a single context feature (later cells pass the class label as context).
flow_kwargs = dict(
    features=dataset.X_test.shape[1],
    hidden_features=8,
    num_layers=2,
    context_features=1,
    device=device,
)
gen_model = MaskedAutoregressiveFlow(**flow_kwargs)
gen_model.fit(train_dataloader, test_dataloader, num_epochs=2000)

In [None]:
# Contour plot of the flow's output for one class context on a 2-D grid.
# The grid only has two coordinates, so the plot is skipped (instead of
# failing inside the flow) when the dataset does not have exactly two features.
GRID_RES = 200  # grid points per axis
plot_class = 1  # class label used as the flow context (0, 1 or 2)

n_features = dataset.X_test.shape[1]
if n_features != 2:
    print(f"Skipping density plot: it needs 2 features, dataset has {n_features}.")
else:
    fig, ax = plt.subplots(1, 1)
    fig.set_size_inches(8, 5)

    xline = torch.linspace(-0.5, 1.5, GRID_RES)
    yline = torch.linspace(-0.5, 1.5, GRID_RES)
    # indexing="ij" is the layout torch.meshgrid used implicitly before;
    # passing it explicitly keeps the result and avoids the deprecation warning.
    xgrid, ygrid = torch.meshgrid(xline, yline, indexing="ij")
    xyinput = torch.cat([xgrid.reshape(-1, 1), ygrid.reshape(-1, 1)], dim=1)

    context = torch.full((xyinput.shape[0], 1), plot_class, dtype=torch.float32)

    with torch.no_grad():
        zgrid = gen_model(xyinput, context)
        # Exponentiated for display (presumably the flow returns log-densities).
        zgrid = zgrid.reshape(GRID_RES, GRID_RES).exp().numpy()

    cs = ax.contourf(xgrid.numpy(), ygrid.numpy(), zgrid, levels=100, cmap=cm.PuBu_r)
    ax.set(
        title=f"Flow output for class context {plot_class}",
        xlabel="feature 0",
        ylabel="feature 1",
    )
    cbar = fig.colorbar(cs)  # noqa: F841

In [None]:
# Classify each test sample with the generative model alone: evaluate the
# flow under every class context and keep the class with the highest value.
# The context tensors do not depend on the sample, so they are built once.
class_contexts = [torch.Tensor([label]) for label in range(3)]

y_preds = []
with torch.no_grad():
    for x in dataset.X_test:
        x = torch.from_numpy(x).view(1, -1)  # single sample as a (1, n_features) batch
        scores = [gen_model(x, context).item() for context in class_contexts]
        y_preds.append(np.argmax(scores))
print(classification_report(dataset.y_test.flatten(), y_preds))

In [28]:
# class MulticlassDiscLoss(torch.nn.modules.loss._Loss):
#     def __init__(
#         self, size_average=None, reduce=None, reduction: str = "mean", eps=0.02
#     ) -> None:
#         super().__init__(size_average, reduce, reduction)
#         self.eps = eps

#     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
#         return torch.nn.functional.cross_entropy(input, target.view(-1))


class MulticlassDiscLoss(torch.nn.modules.loss._Loss):
    """Penalise the gap between the target-class score and the best score.

    For every sample the loss takes ``input[target] - max(input)``, which is
    zero when the target class already has the highest score and negative
    otherwise. The per-sample gaps are combined with an L2 norm over the
    batch, so the result is a single non-negative scalar.

    ``reduction`` is forwarded to the ``_Loss`` base class and ``eps`` is
    stored on the instance, but ``forward`` reads neither of them.
    """

    def __init__(
        self, size_average=None, reduce=None, reduction: str = "mean", eps=0.02
    ) -> None:
        super().__init__(size_average, reduce, reduction)
        self.eps = eps

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return the L2 norm of the per-sample target-vs-max score gaps.

        Args:
            input: Scores of shape ``(N, C)``; the number of classes is taken
                from ``input.shape[1]``.
            target: Class indices, either ``(N,)`` or ``(N, 1)``. Floating
                point labels are truncated to integers.

        Returns:
            A scalar tensor, ``|| input[n, target[n]] - max_c input[n, c] ||_2``.
        """
        num_classes = input.shape[1]
        # Accept both flat and column-vector targets; for 2-D targets only
        # the first column is used, as before.
        labels = (target if target.dim() == 1 else target[:, 0]).long()
        # Build the mask from the input itself so class count, dtype and
        # device always match the scores.
        one_hot = torch.nn.functional.one_hot(labels, num_classes).to(
            dtype=input.dtype, device=input.device
        )
        dot_product = torch.sum(input * one_hot, dim=1)
        return torch.norm(dot_product - torch.max(input, dim=1).values)

In [29]:
# Instantiate the discriminator loss with its default arguments; it is passed
# to PPCEF as `disc_model_criterion` in the next cell.
loss_fn = MulticlassDiscLoss()

In [30]:
# Assemble the PPCEF counterfactual method from the fitted flow, the fitted
# classifier and the classifier loss.
cf = PPCEF(
    gen_model=gen_model,
    disc_model=disc_model,
    disc_model_criterion=loss_fn,
    neptune_run=None,  # no Neptune run is attached
    # Reuse the notebook-wide device (the same one the flow was built with)
    # instead of repeating the "cpu" literal, so the two cannot drift apart.
    device=device,
)

In [None]:
# `delta` is the median of the flow's predicted log-probabilities over the
# training loader; it is handed to the search and to the evaluation below.
train_log_probs = gen_model.predict_log_prob(train_dataloader)
delta = torch.median(train_log_probs)

# Run the counterfactual search over the test split in large, ordered batches.
cf_dataloader = dataset.test_dataloader(batch_size=1024, shuffle=False)
search_options = dict(epochs=5000, patience=200, lr=0.005, alpha=100, delta=delta)
Xs_cfs, Xs, ys_orig, ys_target, loss_components = cf.search_batch(
    cf_dataloader, **search_options
)

In [None]:
# Display the names of the loss terms returned by the search.
loss_components.keys()

In [None]:
# Plot every recorded loss component in its own row. The number of rows
# follows the number of components instead of being fixed at four, so the
# cell keeps working if the search reports more (or fewer) terms.
if loss_components:
    # squeeze=False keeps `axes` two-dimensional even for a single component.
    fig, axes = plt.subplots(len(loss_components), 1, squeeze=False)
    for ax, (k, v) in zip(axes[:, 0], loss_components.items()):
        ax.plot(v, label=k)
        ax.legend()

In [34]:
# Evaluate the counterfactuals found by the search.
# Every counterfactual row is flagged as returned by the method.
all_returned = np.ones(Xs_cfs.shape[0], dtype=bool)
# The evaluation receives the threshold as a NumPy value, not a tensor.
delta_value = delta.numpy()

metrics = evaluate_cf(
    # models
    gen_model=gen_model,
    disc_model=disc_model,
    # counterfactuals and their targets
    X_cf=Xs_cfs,
    y_target=ys_target,
    model_returned=all_returned,
    # feature groups as declared by the dataset
    categorical_features=dataset.categorical_features,
    continuous_features=dataset.numerical_features,
    # reference data (training labels are flattened, test labels passed as-is)
    X_train=dataset.X_train,
    y_train=dataset.y_train.reshape(-1),
    X_test=dataset.X_test,
    y_test=dataset.y_test,
    delta=delta_value,
)