In [1]:
import torch


from counterfactuals.datasets import AdultDataset
from counterfactuals.dequantization import DequantizationOriginal

In [2]:
# Build the dataset object from the CSV file; the path is relative to the
# directory the notebook kernel is started in.
dataset = AdultDataset("../data/adult.csv")

In [3]:
# Shorthand for the dataset's X_test attribute (presumably the held-out
# feature matrix -- confirm against AdultDataset).
X_test = dataset.X_test

In [8]:
# Preview the first two rows as a tensor (the displayed rows have 21 columns).
torch.from_numpy(X_test[0:2])

tensor([[0.4110, 0.3980, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         1.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000],
        [0.1096, 0.3980, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         1.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 1.0000,
         0.0000, 0.0000, 0.0000]])

In [5]:
# Instantiate the dequantization layer from the dataset's categorical column groups.
# NOTE(review): at this point the name resolves to the class imported from
# counterfactuals.dequantization; a class with the same name is defined later
# in this notebook and shadows it from then on.
deq = DequantizationOriginal(dataset.categorical_features_lists)

In [6]:
# Forward (dequantizing) pass on the first two rows; the displayed result has
# 5 columns for the 21-column input.
x_deq = deq(torch.from_numpy(X_test[0:2]))
x_deq

tensor([[ 0.4110,  0.3980,  3.3353, -1.0966, -2.6848],
        [ 0.1096,  0.3980,  3.4401,  0.4775, -0.2634]])

In [9]:
# Reverse pass: map the dequantized rows back.
x_out = deq(x_deq, reverse=True)

In [10]:
# Round-trip check: element-wise equality between the original rows and the
# result of forward followed by reverse.
torch.from_numpy(X_test[0:2]) == x_out

tensor([[True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True]])

In [51]:
import torch
import torch.nn as nn


class DequantizationOriginal(nn.Module):
    """Dequantization layer for inputs that mix numerical and one-hot columns.

    Forward direction (``reverse=False``): every one-hot group is collapsed to
    its category index, uniform noise in [0, 1) is added, the value is divided
    by the number of categories and mapped to the real line with a logit.
    Reverse direction (``reverse=True``): the logit is undone and the value is
    floored back to a category index. Numerical columns are passed through
    unchanged in both directions.

    Note: log-determinant bookkeeping is not implemented; ``ldj`` is accepted
    by the methods for interface compatibility and handed back unchanged.
    """

    def __init__(self, categorical_features_lists, alpha=1e-5):
        """
        Args:
            categorical_features_lists: list of lists of column indices; each
                inner list holds the one-hot columns of one categorical
                feature. Every column index below
                ``categorical_features_lists[0][0]`` is treated as numerical.
                Must be non-empty (an empty list raises ``IndexError``).
            alpha: small constant that is used to scale the original input.
                Prevents dealing with values very close to 0 and 1 when
                inverting the sigmoid.
        """
        super().__init__()
        self.alpha = alpha
        self.categorical_features_lists = categorical_features_lists
        self.numerical_features = list(range(0, categorical_features_lists[0][0]))
        # Number of categories per categorical feature. Registered as a
        # non-persistent buffer so that ``module.to(device)`` moves it along
        # with the module while ``state_dict()`` stays unchanged.
        self.register_buffer(
            "quants",
            torch.tensor(
                [len(x) for x in categorical_features_lists],
                dtype=torch.get_default_dtype(),
            ),
            persistent=False,
        )

    def forward(self, z, ldj=None, reverse=False):
        """Apply the transformation (or its inverse) to a 2-D batch.

        Args:
            z: ``reverse=False``: tensor whose columns are the numerical
                features followed by the one-hot groups.
                ``reverse=True``: tensor whose columns are the numerical
                features followed by one dequantized value per categorical
                feature.
            ldj: unused; accepted for interface compatibility.
            reverse: selects the direction, see the class docstring.

        Returns:
            Tensor with the numerical columns followed by one column per
            categorical feature. In the reverse direction these columns hold
            integer-valued category indices stored as floats (not one-hot).
        """
        num_feat = z[:, self.numerical_features]
        if not reverse:
            z = self.from_one_hot(z)
            z, ldj = self.dequant(z, ldj)
            z, ldj = self.sigmoid(z, ldj, reverse=True)
        else:
            quants = self.quants.to(z.device)
            z = z[:, len(self.numerical_features):]
            z, ldj = self.sigmoid(z, ldj, reverse=False)
            # Scale back to [0, n_categories) and floor to the category index;
            # the clamp guards against values pushed just outside the valid
            # range by floating-point error or extreme logits.
            z = torch.floor(z * quants)
            z = torch.minimum(torch.clamp(z, min=0), quants - 1)
        return torch.hstack([num_feat, z])

    def from_one_hot(self, z):
        """Collapse every one-hot group of ``z`` to its argmax index.

        Returns:
            Float tensor of shape ``(z.shape[0], number of categorical
            features)`` allocated on the same device as ``z``.
        """
        cat_feat = torch.zeros(
            (z.shape[0], len(self.categorical_features_lists)), device=z.device
        )
        for i, i_cat in enumerate(self.categorical_features_lists):
            cat_feat[:, i] = torch.argmax(z[:, i_cat], dim=1)
        return cat_feat

    def sigmoid(self, z, ldj=None, reverse=False):
        """Invertible sigmoid transformation with alpha rescaling.

        ``reverse=True`` maps values in [0, 1] to logits after shrinking them
        to ``[alpha / 2, 1 - alpha / 2]``. ``reverse=False`` applies the
        sigmoid and then undoes that shrinking, so the two directions are
        exact inverses of each other (up to floating-point error).

        Returns:
            Tuple ``(transformed z, ldj)``; ``ldj`` is returned unchanged.
        """
        if not reverse:
            z = torch.sigmoid(z)
            # Undo the alpha scaling applied in the reverse=True branch;
            # without this the round trip is biased towards 0.5 and the
            # subsequent floor can select a neighbouring category.
            z = (z - 0.5 * self.alpha) / (1 - self.alpha)
        else:
            # Scale to prevent boundaries 0 and 1
            z = z * (1 - self.alpha) + 0.5 * self.alpha
            z = torch.log(z) - torch.log(1 - z)
        return z, ldj

    def dequant(self, z, ldj=None):
        """Transform discrete category indices to continuous volumes.

        Adds uniform noise in [0, 1) to each index and divides by the number
        of categories of the corresponding feature.

        Returns:
            Tuple ``(float32 tensor, ldj)``; ``ldj`` is returned unchanged.
        """
        z = z.to(torch.float32)
        z = z + torch.rand_like(z)
        z = z / self.quants.to(z.device)
        return z, ldj

In [57]:
# Re-create the layer; the name now resolves to the DequantizationOriginal class
# defined in this notebook rather than the one imported at the top.
deq = DequantizationOriginal(dataset.categorical_features_lists)

In [58]:
# First five rows of X_test as a float32 tensor.
x = torch.from_numpy(X_test[:5]).float()

In [102]:
# Display the input batch for reference.
x

tensor([[0.4110, 0.3980, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         1.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000],
        [0.1096, 0.3980, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         1.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 1.0000,
         0.0000, 0.0000, 0.0000],
        [0.6301, 0.3980, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         1.0000, 0.0000, 0.0000],
        [0.3425, 0.5510, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         1.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 1.0000],
        [0.3288, 0.6531, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000,
         0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000,
         0.0000, 0.0000, 0.0000]])

In [103]:
# Forward (dequantizing) pass. The categorical columns are randomised by
# torch.rand_like inside dequant, so they differ between runs unless a seed is set.
deq_x = deq(x, reverse=False)
deq_x

tensor([[ 0.4110,  0.3980,  3.0588, -0.8408, -2.1959],
        [ 0.1096,  0.3980,  2.1952,  0.8099, -0.4018],
        [ 0.6301,  0.3980, -1.2952, -1.1783,  0.4731],
        [ 0.3425,  0.5510,  3.7515, -0.8798,  6.3128],
        [ 0.3288,  0.6531,  0.5378, -0.9485, -0.5795]])

In [104]:
# Reverse pass: with the notebook-local class the categorical columns come back
# as category indices (one column per feature), not as one-hot groups.
deq(deq_x, reverse=True)

tensor([[0.4110, 0.3980, 7.0000, 1.0000, 0.0000],
        [0.1096, 0.3980, 7.0000, 3.0000, 2.0000],
        [0.6301, 0.3980, 1.0000, 1.0000, 3.0000],
        [0.3425, 0.5510, 7.0000, 1.0000, 5.0000],
        [0.3288, 0.6531, 5.0000, 1.0000, 2.0000]])