In [None]:
# default_exp baseline_cfs

In [2]:
# hide
%load_ext autoreload
%autoreload 2
from ipynb_path import *

In [4]:
# export
from counterfactual.import_essentials import *
from counterfactual.utils import *
# from counterfactual.train import *
from counterfactual.training_module import *
from counterfactual.net import *
from counterfactual.interface import ABCBaseModule, LocalExplainerBase, GlobalExplainerBase

from torch.nn.parameter import Parameter
from torchmetrics.functional.classification import accuracy

In [None]:
print(f"pl version: {pl.__version__}")
print(f"torch version: {torch.__version__}")

In [None]:
# export 

class Clamp(torch.autograd.Function):
    """
    Clamp parameter to [0, 1]
    code from: https://discuss.pytorch.org/t/regarding-clamped-learnable-parameter/58474/4
    """
    @staticmethod
    def forward(ctx, input):
        return input.clamp(min=0, max=1)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.clone()

# Vanilla CF

Wachter, S., Mittelstadt, B., & Russell, C. (2017). Counterfactual Explanations Without Opening the Black Box: Automated Decisions and the GDPR. SSRN Electronic Journal. https://doi.org/10.2139/ssrn.3063289

In [None]:
# export

class VanillaCF(LocalExplainerBase):
    def __init__(self, x: torch.tensor, model: BaselineModel, n_iters: int = 1000):
        """vanilla version of counterfactual generation
            - link: https://doi.org/10.2139/ssrn.3063289

        Args:
            x (torch.tensor): input instance
            model (BaselineModel): black-box model
        """
        super().__init__(x, model)
        self.cf = nn.Parameter(self.x.clone(), requires_grad=True)
        self.n_iters = n_iters

    def forward(self):
        cf = self.cf * 1.0
        return cat_normalize(cf, self.model.cat_arrays, len(self.model.continous_cols), False)
        # return cf

    def configure_optimizers(self):
        return torch.optim.RMSprop([self.cf], lr=0.001)

    def compute_regularization_loss(self):
        cat_idx = len(self.model.continous_cols)
        regularization_loss = 0.
        for col in self.model.cat_arrays:
            cat_idx_end = cat_idx + len(col)
            regularization_loss += torch.pow((torch.sum(self.cf[cat_idx: cat_idx_end]) - 1.0), 2)
        return regularization_loss

    def _loss_functions(self, x, c):
        # target
        y_pred = self.model.predict(x)
        y_prime = torch.ones(y_pred.shape) - y_pred

        c_y = self.model(c)
        l_1 = F.binary_cross_entropy(c_y, y_prime.float())
        l_2 = F.mse_loss(x, c)
        return l_1, l_2

    def _loss_compute(self, l_1, l_2):
        return 1.0 * l_1 + 0.5 * l_2

    def generate_cf(self, debug: bool = False):
        optim = self.configure_optimizers()
        clamp = Clamp()
        for i in range(self.n_iters):
            c = self()
            l_1, l_2 = self._loss_functions(self.x, c)
            loss = self._loss_compute(l_1, l_2)
            optim.zero_grad()
            loss.backward()
            optim.step()

            if debug and i % 100 == 0:
                print(f"iter: {i}, loss: {loss.item()}")

            # contrain to [0,1]
            clamp.apply(self.cf)

        cf = self.cf * 1.0
        clamp.apply(self.cf)
        return cat_normalize(cf, self.model.cat_arrays, len(self.model.continous_cols), True)

# Diverse CF

Mothilal, R. K., Sharma, A., & Tan, C. (2020). Explaining Machine Learning Classifiers through Diverse Counterfactual Explanations. Proceedings of the 2020 Conference on Fairness, Accountability, and Transparency, 607–617. https://doi.org/10.1145/3351095.3372850


In [None]:
# export

class DiverseCF(LocalExplainerBase):
    def __init__(self, x: torch.tensor, model: CounterfactualTrainingModule, n_iters = 1000):
        """diverse counterfactual explanation
            - link: https://doi.org/10.1145/3351095.3372850

        Args:
            x (torch.tensor): input instance
            model (CounterfactualTrainingModule): black-box model
        """
        self.n_cfs = 5
        super().__init__(x, model)
        # self.cf = nn.Parameter(self.x.repeat(self.n_cfs, 1), requires_grad=True)
        self.cf = nn.Parameter(torch.rand(self.n_cfs, self.x.size(1)), requires_grad=True)
        self.n_iters = n_iters

    def forward(self):
        cf = self.cf * 1.0
        return torch.clamp(cf, 0, 1)

    def configure_optimizers(self):
        return torch.optim.RMSprop([self.cf], lr=0.001)

    def _compute_dist(self, x1, x2):
        return torch.sum(torch.abs(x1 - x2), dim = 0)

    def _compute_proximity_loss(self):
        """Compute the second part (distance from x1) of the loss function."""
        proximity_loss = 0.0
        for i in range(self.n_cfs):
            proximity_loss += self.compute_dist(self.cf[i], self.x1)
        return proximity_loss/(torch.mul(len(self.minx[0]), self.total_CFs))

    def _dpp_style(self, cf):
        det_entries = torch.ones(self.n_cfs, self.n_cfs)
        for i in range(self.n_cfs):
            for j in range(self.n_cfs):
                det_entries[i, j] = self._compute_dist(cf[i], cf[j])

        # implement inverse distance
        det_entries = 1.0 / (1.0 + det_entries)
        det_entries += torch.eye(self.n_cfs) * 0.0001
        return torch.det(det_entries)

    def _compute_diverse_loss(self, c):
        return self._dpp_style(c)

    def _compute_regularization_loss(self):
        cat_idx = len(self.model.continous_cols)
        regularization_loss = 0.
        for i in range(self.n_cfs):
            for col in self.model.cat_arrays:
                cat_idx_end = cat_idx + len(col)
                regularization_loss += torch.pow((torch.sum(self.cf[i][cat_idx: cat_idx_end]) - 1.0), 2)
        return regularization_loss

    def _loss_functions(self, x, c):
        # target
        y_pred = self.model.predict(x)
        y_prime = torch.ones(y_pred.shape) - y_pred

        c_y = self.model(c)
        # yloss
        l_1 = hinge_loss(input=c_y, target=y_prime.float())
        # proximity loss
        l_2 = l1_mean(x, c)
        # diverse loss
        l_3 = self._compute_diverse_loss(c)
        # categorical penalty
        l_4 = self._compute_regularization_loss()
        return l_1, l_2, l_3, l_4

    def _compute_loss(self, *loss_f):
        return sum(loss_f)

    def generate_cf(self, debug: bool = False):
        optim = self.configure_optimizers()
        for i in range(self.n_iters):
            c = self()

            l_1, l_2, l_3, l_4 = self._loss_functions(self.x, c)
            loss = self._compute_loss(l_1, l_2, l_3, l_4)
            optim.zero_grad()
            loss.backward()
            optim.step()

            if  debug and i % 100 == 0:
                print(f"iter: {i}, loss: {loss.item()}")

            # contrain to [0,1]
            self.clamp.apply(self.cf)

        cf = self.cf * 1.0
        cf = torch.clamp(cf, 0, 1)
        # return cf[0]
        return cat_normalize(cf[0].view(1, -1), self.model.cat_arrays, len(self.model.continous_cols), True)

# ProtoCF

In [None]:
# export net

class AE(BaseModule):
    def __init__(self, configs, encoded_size=5):
        super().__init__(configs)
        input_dim = configs['encoder_dims'][0]
        self.encoder_model = MultilayerPerception([input_dim, 20, 16, 14, 12, encoded_size])
        self.decoder_model = MultilayerPerception([encoded_size, 12, 14, 16, 20, input_dim])

    def forward(self, x):
        z = self.encoded(x)
        x_prime = self.decoder_model(z)
        return x_prime

    def configure_optimizers(self):
        return torch.optim.Adam([p for p in self.parameters() if p.requires_grad], lr=self.lr)

    def encoded(self, x):
        return self.encoder_model(x)

    def training_step(self, batch, batch_idx):
        # batch
        x, _ = batch
        # prediction
        x_prime = self(x)

        loss = F.mse_loss(x_prime, x, reduction='mean')

        self.log('train/loss', loss)

        return loss

    def validation_step(self, batch, batch_idx):
        # batch
        x, _ = batch
        # prediction
        x_prime = self(x)

        loss = F.mse_loss(x_prime, x, reduction='mean')

        self.log('val/val_loss', loss)

        return loss

In [None]:
# export

class ProtoCF(LocalExplainerBase):
    def __init__(self, x: torch.tensor, model: pl.LightningModule, train_loader: DataLoader, ae: AE, n_iters: int = 1000):
        """vanilla version of counterfactual generation
            - link: https://doi.org/10.2139/ssrn.3063289

        Args:
            x (torch.tensor): input instance
            model (pl.LightningModule): black-box model
        """
        super().__init__(x, model)
        self.cf = nn.Parameter(self.x.clone(), requires_grad=True)
        self.sampled_data, _ = next(iter(train_loader))
        self.sampled_label = self.model.predict(self.sampled_data)
        self.ae = ae
        self.ae.freeze()
        self.n_iters = n_iters

    def forward(self):
        cf = self.cf * 1.0
        # return cat_normalize(cf, self.model.cat_arrays, len(self.model.continous_cols), False)
        return cf

    def configure_optimizers(self):
        return torch.optim.RMSprop([self.cf], lr=0.001)

    def compute_regularization_loss(self):
        cat_idx = len(self.model.continous_cols)
        regularization_loss = 0.
        for col in self.model.cat_arrays:
            cat_idx_end = cat_idx + len(col)
            regularization_loss += torch.pow((torch.sum(self.cf[cat_idx: cat_idx_end]) - 1.0), 2)
        return regularization_loss

    def proto(self, data):
        return self.ae.encoded(data).mean(axis=0).view(1, -1)

    def _loss_functions(self, x, c):
        # target
        y_pred = self.model.predict(x)
        y = torch.ones(y_pred.shape) - y_pred

        data = self.sampled_data[self.sampled_label == y]

        l_1 = F.binary_cross_entropy(self.model(c), y)
        l_2 = 0.1 * F.l1_loss(x, c) + F.mse_loss(x, c)
        l_3 = F.mse_loss(self.ae.encoded(c), self.proto(data))

        return l_1, l_2, l_3

    def _loss_compute(self, l_1, l_2, l_3):
        return l_1 + l_2 + l_3 #+ self.compute_regularization_loss()

    def generate_cf(self, debug: bool = False):
        optim = self.configure_optimizers()
        for i in range(self.n_iters):
            c = self()

            l_1, l_2, l_3 = self._loss_functions(self.x, c)
            loss = self._loss_compute(l_1, l_2, l_3)
            optim.zero_grad()
            loss.backward()
            optim.step()

            if debug and i % 100 == 0:
                print(f"iter: {i}, loss: {loss.item()}")

            # contrain to [0,1]
            self.clamp.apply(self.cf)

        cf = self.cf * 1.0
        self.clamp.apply(self.cf)
        # return cf
        return cat_normalize(cf, self.model.cat_arrays, len(self.model.continous_cols), True)

# VAE-CF

In [None]:
# export net
class VAE(pl.LightningModule):
    def __init__(self, input_dims, encoded_size=5):
        super().__init__()
        self.encoder_mean = MultilayerPerception([input_dims + 1, 20, 16, 14, 12, encoded_size])
        self.encoder_var = MultilayerPerception([input_dims + 1, 20, 16, 14, 12, encoded_size])
        self.decoder_mean = MultilayerPerception([encoded_size + 1, 12, 14, 16, 20, input_dims])

    def encoder(self, x):
        mean = self.encoder_mean(x)
        logvar = 0.5+ self.encoder_var(x)
        return mean, logvar

    def decoder(self, z):
        mean = self.decoder_mean(z)
        return mean

    def sample_latent_code(self, mean, logvar):
        eps = torch.randn_like(logvar)
        return mean + torch.sqrt(logvar) * eps

    def normal_likelihood(self, x, mean, logvar, raxis=1):
        return torch.sum( -.5 * ((x - mean)*(1./logvar)*(x-mean) + torch.log(logvar) ), axis=1)

    def forward(self, x, c):
        """
        x: input instance
        c: target y
        """
        c = c.view(c.shape[0], 1)
        c = torch.tensor(c).float()
        res = {}
        mc_samples = 50
        em, ev = self.encoder(torch.cat((x, c), 1))
        res['em'] = em
        res['ev'] = ev
        res['z'] = []
        res['x_pred'] = []
        res['mc_samples'] = mc_samples
        for i in range(mc_samples):
            z = self.sample_latent_code(em, ev)
            x_pred = self.decoder(torch.cat((z, c), 1))
            res['z'].append(z)
            res['x_pred'].append(x_pred)
        return res

    def compute_elbo(self, x, c, model):
        c= c.clone().detach().float()
        c=c.view(c.shape[0], 1)
        em, ev = self.encoder(torch.cat((x,c),1))
        kl_divergence = 0.5*torch.mean(em**2 + ev - torch.log(ev) - 1, axis=1)

        z = self.sample_latent_code(em, ev)
        dm= self.decoder( torch.cat((z,c),1) )
        log_px_z = torch.tensor(0.0)

        x_pred= dm
        return torch.mean(log_px_z), torch.mean(kl_divergence), x, x_pred, model.predict(x_pred)

In [None]:
# export
class VAE_CF(CounterfactualTrainingModule):
    def __init__(self, config: Dict, model: pl.LightningModule):
        """
        config: basic configs
        model: the black-box model to be explained
        """
        super().__init__(config)
        self.model = model
        self.model.freeze()
        self.vae = VAE(input_dims=self.enc_dims[0])
        # validity_reg set to 42.0
        # according to https://interpret.ml/DiCE/notebooks/DiCE_getting_started_feasible.html#Generate-counterfactuals-using-a-VAE-model
        self.validity_reg = config['validity_reg'] if 'validity_reg' in config.keys() else 1.0

    def model_forward(self, x):
        """lazy implementation since this method is actually not needed"""
        recon_err, kl_err, x_true, x_pred, cf_label = self.vae.compute_elbo(x, 1 - self.model.predict(x), self.model)
        # return y, c
        return cf_label, x_pred

    def configure_optimizers(self):
        return torch.optim.Adam([p for p in self.parameters() if p.requires_grad], lr=self.lr)

    def predict(self, x):
        return self.model.predict(x)

    def compute_loss(self, out, x, y):
        em = out['em']
        ev = out['ev']
        z = out['z']
        dm = out['x_pred']
        mc_samples = out['mc_samples']
        #KL Divergence
        kl_divergence = 0.5*torch.mean(em**2 + ev - torch.log(ev) - 1, axis=1)

        #Reconstruction Term
        #Proximity: L1 Loss
        x_pred = dm[0]
        cat_idx = len(self.continous_cols)
        # recon_err = - \
        #     torch.sum(torch.abs(x[:, cat_idx:-1] -
        #                         x_pred[:, cat_idx:-1]), axis=1)
        recon_err = - torch.sum(torch.abs(x - x_pred), axis=1)

        # Sum to 1 over the categorical indexes of a feature
        for col in self.cat_arrays:
            cat_end_idx = cat_idx + len(col)
            temp = - \
                torch.abs(1.0 - x_pred[:, cat_idx: cat_end_idx].sum(axis=1))
            recon_err += temp

        #Validity
        c_y = self.model(x_pred)
        validity_loss = torch.zeros(1, device=self.device)
        validity_loss += hinge_loss(input=c_y, target=y.float())

        for i in range(1, mc_samples):
            x_pred = dm[i]

            # recon_err += - \
            #     torch.sum(torch.abs(x[:, cat_idx:-1] -
            #                         x_pred[:, cat_idx:-1]), axis=1)
            recon_err += - torch.sum(torch.abs(x - x_pred), axis=1)

            # Sum to 1 over the categorical indexes of a feature
            for col in self.cat_arrays:
                cat_end_idx = cat_idx + len(col)
                temp = - \
                    torch.abs(1.0 - x_pred[:, cat_idx: cat_end_idx].sum(axis=1))
                recon_err += temp

            #Validity
            c_y = self.model(x_pred)
            validity_loss += hinge_loss(c_y, y.float())

        recon_err = recon_err / mc_samples
        validity_loss = -1 * self.validity_reg * validity_loss / mc_samples

        return -torch.mean(recon_err - kl_divergence) - validity_loss


    def training_step(self, batch, batch_idx):
        # batch
        x, _ = batch
        # prediction
        y_hat = self.model.predict(x)
        # target
        y = 1.0 - y_hat

        out = self.vae(x, y)
        loss = self.compute_loss(out, x, y)

        self.log('train/loss', loss)

        return loss

    def validation_step(self, batch, batch_idx):
        # batch
        x, _ = batch
        # prediction
        y_hat = self.model.predict(x)
        # target
        y = 1.0 - y_hat

        out = self.vae(x, y)
        loss = self.compute_loss(out, x, y)

        _, _, _, x_pred, cf_label = self.vae.compute_elbo(x, y, self.model)

        cf_proximity = torch.abs(x - x_pred).sum(dim=1).mean()
        cf_accuracy = accuracy(cf_label, y)

        self.log('val/val_loss', loss)
        self.log('val/proximity', cf_proximity)
        self.log('val/cf_accuracy', cf_accuracy)

        return loss

    def validation_epoch_end(self, val_outs):
        return

    def generate_cf(self, x):
        self.vae.freeze()
        y_hat = self.model.predict(x)
        recon_err, kl_err, x_true, x_pred, cf_label = self.vae.compute_elbo(x, 1.-y_hat, self.model)
        return self.model.cat_normalize(x_pred, hard=True)

# Test

## Configs

In [None]:
m_configs = {
    'data_dir': 'data/s_adult.csv',
    'lr':0.01, 
    'batch_size': 2048,
    'lambda_1': 1.,
    'lambda_2': .01,
    'lambda_3': 1.,
    'threshold': 1., 
    'continous_cols': ['age', 'hours_per_week'],
    'discret_cols': ['workclass', 'education', 'marital_status', 'occupation', 'race', 'gender'],
    'encoder_dims': [29, 50, 10],
    'decoder_dims': [10, 10],
    'explainer_dims': [10, 50],
    'loss_1': 'mse',
    'loss_2': 'mse',
    'loss_3': 'mse'
}
# trainer configs
t_configs = {
    'max_epochs': 100,
#     'deterministic': True,
#     'gradient_clip_val': 0.5,
    'num_sanity_val_steps': 0,
#     'callbacks': [early_stopping],
    'accelerator': 'ddp',
    'gpus': 1,
#     debug
#     'weights_summary': 'full',
#     'fast_dev_run': True,
    'track_grad_norm':2
}

## Quickly init a model

In [None]:
model = load_model('../saved_weights/adult/baseline/epoch=55-step=10695.ckpt', 56)

GPU available: True, used: False
TPU available: None, using: 0 TPU cores
x_cont: (32561, 2), x_cat: (32561, 27)
(32561, 29)

  | Name  | Type       | Params | In sizes | Out sizes
------------------------------------------------------------
0 | model | Sequential | 2.3 K  | [1, 29]  | [1, 1]   
------------------------------------------------------------
2.3 K     Trainable params
0         Non-trainable params
2.3 K     Total params


Training: 0it [00:00, ?it/s]




In [None]:
x = torch.rand(1, 29)
x

tensor([[8.7551e-01, 8.3667e-01, 8.5154e-02, 6.7861e-01, 4.8256e-01, 1.4477e-01,
         3.7129e-01, 8.2560e-01, 6.8110e-01, 1.0998e-01, 6.4115e-01, 4.2497e-01,
         3.1698e-01, 6.1735e-01, 5.8713e-01, 8.0798e-01, 7.3314e-06, 5.0367e-01,
         7.4309e-01, 2.9842e-01, 9.4241e-01, 3.8378e-01, 6.7887e-01, 3.9197e-01,
         6.5418e-01, 6.4981e-01, 2.8277e-01, 1.5848e-01, 2.7166e-01]])

## VanillaCF

In [None]:
%%time
cf = VanillaCF(x, model)
cf.generate_cf(1000)

CPU times: user 5.1 s, sys: 0 ns, total: 5.1 s
Wall time: 3.42 s


tensor([[0.3636, 1.0516, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 1.0000, 1.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000,
         1.0000, 0.0000]], grad_fn=<CopySlices>)

## DiverseCF

In [None]:
%%time
cf = DiverseCF(x, model)
cf.generate_cf(1000)

  return F.l1_loss(x, c, reduction='mean') / x.abs().mean() # MAD


CPU times: user 16.7 s, sys: 10.8 ms, total: 16.7 s
Wall time: 9.01 s


tensor([0.8001, 0.5417, 0.2331, 0.7017, 0.0000, 0.1092, 0.0928, 0.0000, 0.0618,
        0.0000, 0.5129, 0.4254, 0.4727, 0.6172, 0.5870, 0.5621, 0.3842, 0.5036,
        0.7432, 0.2988, 0.9423, 0.3837, 0.6788, 0.5285, 0.6543, 0.6498, 0.3424,
        0.3928, 0.2716], grad_fn=<SelectBackward>)

## ProtoCF

In [None]:
result = train(AE(m_configs), t_configs)
ae = result['module']

GPU available: False, used: False
GPU available: False, used: False
TPU available: None, using: 0 TPU cores
TPU available: None, using: 0 TPU cores
hyper parameters: "batch_size":     128
"continous_cols": ['age', 'hours_per_week']
"data_dir":       ../data/s_adult.csv
"decoder_dims":   [10, 10]
"discret_cols":   ['workclass', 'education', 'marital_status', 'occupation', 'race', 'gender']
"encoder_dims":   [29, 50, 10]
"explainer_dims": [10, 50]
"lambda_1":       1.0
"lambda_2":       0.01
"lambda_3":       1.0
"loss_1":         mse
"loss_2":         mse
"loss_3":         mse
"lr":             0.01
"threshold":      1.0
hyper parameters: "batch_size":     128
"continous_cols": ['age', 'hours_per_week']
"data_dir":       ../data/s_adult.csv
"decoder_dims":   [10, 10]
"discret_cols":   ['workclass', 'education', 'marital_status', 'occupation', 'race', 'gender']
"encoder_dims":   [29, 50, 10]
"explainer_dims": [10, 50]
"lambda_1":       1.0
"lambda_2":       0.01
"lambda_3":       1.0
"lo

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…




In [None]:
cf = ProtoCF(x=x, model=model, train_loader=ae.train_dataloader(), ae=ae)
cf.generate_cf(1000, debug=True)

ProtoCF initialized.
iter: 0, loss: 6.1465535163879395
l_1: 0.0, l_2: 0.0, l_3: 0.14655336737632751
iter: 100, loss: 6.0259175300598145
l_1: 0.015172960236668587, l_2: 0.0018211350543424487, l_3: 0.00983387790620327
iter: 200, loss: 6.016486167907715
l_1: 0.01386270672082901, l_2: 0.002581034554168582, l_3: 0.001332893269136548
iter: 300, loss: 6.016288757324219
l_1: 0.014074395410716534, l_2: 0.002700776094570756, l_3: 0.0008641568128950894
iter: 400, loss: 6.016404628753662
l_1: 0.01418527215719223, l_2: 0.0027024508453905582, l_3: 0.0008683655178174376
iter: 500, loss: 6.016188144683838
l_1: 0.013960369862616062, l_2: 0.0027024406008422375, l_3: 0.0008763322839513421
iter: 600, loss: 6.016372203826904
l_1: 0.014170968905091286, l_2: 0.002703545382246375, l_3: 0.000849399424623698
iter: 700, loss: 6.016396999359131
l_1: 0.014149093069136143, l_2: 0.002701388904824853, l_3: 0.0008973746444098651
iter: 800, loss: 6.016327381134033
l_1: 0.014108755625784397, l_2: 0.002702898345887661, l

tensor([[0.9893, 0.9910, 0.8431, 0.9772, 0.4053, 0.0504, 0.6392, 0.8146, 0.4124,
         0.5149, 0.6525, 0.9059, 0.5502, 0.0373, 0.3220, 0.4081, 0.0209, 0.7489,
         0.6025, 0.2613, 0.9205, 0.7139, 0.0665, 0.4821, 0.3095, 0.4590, 0.4964,
         0.9253, 0.7672]], grad_fn=<MulBackward0>)

## VAE-CF

In [None]:
cf = VAE_CF(m_configs, model=model)
result = train(
    cf, 
    t_configs,
    logger=pl_loggers.TestTubeLogger(Path('../log/'), name="adult/vae")
)

x = torch.rand(100, 29)
result['module'].generate_cf(x)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
Using environment variable NODE_RANK for node rank (0).
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
hyper parameters: "batch_size":     2048
"continous_cols": ['age', 'hours_per_week']
"data_dir":       data/s_adult.csv
"decoder_dims":   [10, 10]
"discret_cols":   ['workclass', 'education', 'marital_status', 'occupation', 'race', 'gender']
"encoder_dims":   [29, 50, 10]
"explainer_dims": [10, 50]
"lambda_1":       1.0
"lambda_2":       0.01
"lambda_3":       1.0
"loss_1":         mse
"loss_2":         mse
"loss_3":         mse
"lr":             0.01
"threshold":      1.0
x_cont: (32561, 2), x_cat: (32561, 27)
(32561, 29)

  | Name  | Type          | Params | In sizes | Out sizes
---------------------------------------------------------------
0 | model | BaselineModel | 2.3 K  | [1, 29]  | [1]      
1 | vae   | VAE           | 4.8 K  | ?        | ?        
--------------------------------------------------------------

Epoch 0:   0%|          | 0/16 [00:00<?, ?it/s] 

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fd0bc471830>
Traceback (most recent call last):
  File "/opt/conda/envs/pytorch/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1203, in __del__
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fd0bc471830>
Traceback (most recent call last):
  File "/opt/conda/envs/pytorch/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1203, in __del__
        self._shutdown_workers()

# Misc

In [None]:
dummy = pd.read_csv('../data/dummy_data.csv')
dummy[:5]

Unnamed: 0,x1,x2,x3,y
0,28.869472,75.537317,13.732009,0.0
1,56.541628,51.057476,14.176149,0.0
2,54.259902,46.058342,12.92356,1.0
3,43.165512,56.31358,12.536208,1.0
4,24.003729,26.398063,10.360779,1.0


In [None]:
input_instance = dummy[7500:]
x = model.transform(input_instance)

cf = VanillaCF(x=x, model=model)
r = cf.generate_cf(10000).detach()


iter: 0, loss: 1.7241566181182861
iter: 100, loss: 0.14336611330509186
iter: 200, loss: 0.10569803416728973


KeyboardInterrupt: 

In [None]:
model.freeze()
r = torch.tensor(r)
model.check_cont_robustness(x, r, model.predict(r))

  r = torch.tensor(r)


(tensor(0), 0)

In [None]:
def proximity(x, c):
    return torch.abs(x - c).sum(dim=1).mean()

In [None]:
def cf_accuracy(c_y, y_hat):
    return accuracy(c_y > .5, y_hat < .5)

In [None]:
proximity(x, r)

tensor(0.6029)

In [None]:
y, c = model(x)
c_y, _ = model(r)
cf_accuracy(c_y, y)

tensor(1.)

In [None]:
y, c = model(x)
proximity(x, c)

tensor(0.2822, grad_fn=<MeanBackward0>)

In [None]:
c_y, _ = model(c)
cf_accuracy(c_y, y)

tensor(0.9548)

In [None]:
# model configs
m_configs = {
    'data_dir': '../data/dummy_data.csv',
    'lr':3e-4, 
    'batch_size': 128,
    'lambda_1': 1.,
    'lambda_2': 0.5,
    'lambda_3': 1.,
    'threshold': 1, 
    'continous_cols': ['x1', 'x2', 'x3',],
    'discret_cols': [], 
    'encoder_dims': [3, 100, 10],
    'decoder_dims': [10, 10],
    'explainer_dims': [10, 10]
}
# trainer configs
t_configs = {
    'max_epochs': 100,
#     'checkpoint_callback': checkpoint_callback,
#     'callbacks': [early_stopping]
#     'gpus': 1,
#     debug
#     'weights_summary': 'full',
#     'fast_dev_run': True,
    'track_grad_norm':2
}

In [None]:
result = train(CounterfactualModel(m_configs), t_configs,logger_name = "debug")
model = result['module']

GPU available: False, used: False
GPU available: False, used: False
TPU available: None, using: 0 TPU cores
TPU available: None, using: 0 TPU cores
hyper parameters: "batch_size":     128
"continous_cols": ['x1', 'x2', 'x3']
"data_dir":       ../data/dummy_data.csv
"decoder_dims":   [10, 10]
"discret_cols":   []
"encoder_dims":   [3, 100, 10]
"explainer_dims": [10, 10]
"lambda_1":       1.0
"lambda_2":       0.5
"lambda_3":       1.0
"lr":             0.0003
"threshold":      1
hyper parameters: "batch_size":     128
"continous_cols": ['x1', 'x2', 'x3']
"data_dir":       ../data/dummy_data.csv
"decoder_dims":   [10, 10]
"discret_cols":   []
"encoder_dims":   [3, 100, 10]
"explainer_dims": [10, 10]
"lambda_1":       1.0
"lambda_2":       0.5
"lambda_3":       1.0
"lr":             0.0003
"threshold":      1
x_cont: (10000, 3), x_cat: (10000, 0)
x_cont: (10000, 3), x_cat: (10000, 0)
(10000, 3)
(10000, 3)

  | Name          | Type                 | Params | In sizes | Out sizes
----------

HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…






In [None]:
model = CounterfactualModel(m_configs)
trainer = pl.Trainer(max_epochs=63, resume_from_checkpoint="../log/debug/version_0/checkpoints/epoch=62-step=3716.ckpt")
trainer.fit(model)

GPU available: False, used: False
GPU available: False, used: False
TPU available: None, using: 0 TPU cores
TPU available: None, using: 0 TPU cores
x_cont: (10000, 3), x_cat: (10000, 0)
x_cont: (10000, 3), x_cat: (10000, 0)
(10000, 3)
(10000, 3)

  | Name          | Type                 | Params | In sizes | Out sizes
------------------------------------------------------------------------------
0 | encoder_model | MultilayerPerception | 1.6 K  | [1, 3]   | [1, 10]  
1 | predictor     | Sequential           | 141    | [1, 10]  | [1, 1]   
2 | explainer     | Sequential           | 163    | [1, 10]  | [1, 3]   
------------------------------------------------------------------------------
1.9 K     Trainable params
0         Non-trainable params
1.9 K     Total params

  | Name          | Type                 | Params | In sizes | Out sizes
------------------------------------------------------------------------------
0 | encoder_model | MultilayerPerception | 1.6 K  | [1, 3]   | [1, 10

HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…




1