In [204]:
import sys
import os
sys.path.append(os.path.join(os.getcwd(), '..'))


import numpy as np
import torch
from torch import nn


import dice_ml


from data import data_loader
from models import model_loader, model_constants

In [91]:
# Load the CREDIT_CARD_DEFAULT dataset (plus its metadata) and the
# LOGISTIC_REGRESSION model that the loader associates with the same dataset name.
dataset, dataset_info = data_loader.load_data(
    data_loader.DatasetName.CREDIT_CARD_DEFAULT
)
model = model_loader.load_model(
    model_constants.ModelType.LOGISTIC_REGRESSION,
    data_loader.DatasetName.CREDIT_CARD_DEFAULT,
)

# The model carries its own adapter; keep a short handle to it.
adapter = model.adapter

### Get the weights and bias from the logistic regression model

In [123]:
# Read the learned parameters off the underlying `model.model` object.
w = model.model.coef_       # 2-D: a single row of per-feature coefficients
b = model.model.intercept_  # 1-D: a single intercept value

# Show both explicitly so that re-running the cell reproduces the output
# recorded beneath it (plain assignments display nothing).
print(w)
print(b)

[[ 0.17913481 -0.08972634 -0.67712912 -0.05839876 -0.12636706 -0.08014985
  -0.10174463 -0.09461647  0.05835037 -0.07734015 -0.07670007 -0.13692418
   0.05415821  0.14608283  0.20286413  0.16240486  0.04927758  0.01076174
   0.00396594  0.04528845]]
[0.18713934]


## Create a simple LR model and verify that it returns the same values as the original model

In [206]:
def sigmoid(z):
    """Numerically stable logistic function ``1 / (1 + exp(-z))``.

    The textbook formula evaluates ``exp(-z)``, which overflows (and emits a
    RuntimeWarning) for large negative ``z``. This version only ever
    exponentiates non-positive numbers, so it cannot overflow.

    Args:
        z: Real scalar or array-like of real values.

    Returns:
        A NumPy scalar for scalar input, otherwise an ndarray with the same
        shape as ``z``; every value lies in ``[0, 1]``.
    """
    z = np.asarray(z)
    # exp(-|z|) is always in (0, 1], hence safe for any magnitude of z.
    decay = np.exp(-np.abs(z))
    # z >= 0: 1 / (1 + e^-z)      (the original formula)
    # z <  0: e^z / (1 + e^z)     (same value, rewritten to avoid e^-z)
    result = np.where(z >= 0, 1 / (1 + decay), decay / (1 + decay))
    # Indexing with () unwraps a 0-d array into a scalar and leaves
    # higher-dimensional arrays as arrays.
    return result[()]


def fake_lr(coef, intercept, data):
    """Hand-rolled logistic-regression prediction.

    Args:
        coef: Coefficient container; only its first row (``coef[0]``) is used.
        intercept: Value added to every linear score.
        data: Feature matrix multiplied against the coefficient row.

    Returns:
        ``sigmoid`` applied to ``data @ coef[0] + intercept``.
    """
    weights = coef[0]
    linear_scores = data @ weights
    logits = linear_scores + intercept
    return sigmoid(logits)

# Adapter-transformed features for the first 10 rows, with the "Y" column
# removed. Deliberately NOT called `data`: a later cell binds `data` to the
# full transformed frame, and sharing the name makes the notebook state
# depend on which cell ran last.
sample_features = adapter.transform(dataset).iloc[:10].drop("Y", axis=1)

original_preds = model.predict_pos_proba(dataset.iloc[:10]).to_numpy()
manual_preds = fake_lr(w, b, sample_features.to_numpy())

print("Original model predictions:")
print(original_preds)

print("New model predictions:")
print(manual_preds)

# Make the "verify" step an actual check instead of a visual comparison of
# two printouts: the cell now fails if the re-implementation drifts.
np.testing.assert_allclose(manual_preds, original_preds, rtol=1e-5)

Original model predictions:
[0.17976778 0.54215535 0.62426223 0.59346296 0.61332876 0.58984144
 0.8786564  0.6438568  0.56752424 0.59791674]
New model predictions:
[0.17976778 0.54215535 0.62426223 0.59346296 0.61332876 0.58984144
 0.8786564  0.6438568  0.56752424 0.59791674]


# Implement the model in PyTorch

In [207]:
class TorchLR(nn.Module):
    """Logistic regression as a PyTorch module with pre-computed parameters.

    Computes ``sigmoid(linear(x))`` where ``linear`` is a single
    ``nn.Linear(n_features, 1)`` layer whose weight and bias are copied from
    the supplied values.

    Args:
        w: Coefficients, either 1-D with shape ``(n_features,)`` or 2-D with a
            single row ``(1, n_features)``. NumPy arrays, sequences and
            tensors are accepted.
        b: Intercept; a scalar or a container holding exactly one value.

    Raises:
        ValueError: If ``w`` is neither 1-D nor a single-row 2-D array.
    """

    def __init__(self, w, b):
        super().__init__()
        # as_tensor handles arrays, sequences and existing tensors uniformly
        # (torch.tensor() warns when handed a tensor) and casts to the layer dtype.
        weight = torch.as_tensor(w, dtype=torch.float32)
        if weight.ndim == 2 and weight.shape[0] == 1:
            # Accept the full single-row coefficient matrix as well as its row;
            # previously this case produced a 1-input layer and a failing copy.
            weight = weight[0]
        if weight.ndim != 1:
            raise ValueError(
                "w must be 1-D or a single-row 2-D array, got shape "
                f"{tuple(weight.shape)}"
            )
        # reshape(1) accepts a scalar or a one-element container and rejects the rest.
        bias = torch.as_tensor(b, dtype=torch.float32).reshape(1)

        self.layer = nn.Linear(weight.shape[0], 1, dtype=torch.float32)
        # Overwrite the random initialisation without recording the copy in autograd.
        with torch.no_grad():
            self.layer.weight.copy_(weight.reshape(1, -1))
            self.layer.bias.copy_(bias)
        self.sigmoid = nn.Sigmoid()

    def forward(self, data):
        """Apply the linear layer followed by the sigmoid.

        Args:
            data: Float tensor whose last dimension has ``n_features`` entries.

        Returns:
            Tensor of sigmoid outputs whose last dimension has size 1.
        """
        return self.sigmoid(self.layer(data))


### Test the PyTorch model with DiCE

In [209]:
# Model side: rebuild the classifier in PyTorch from the first coefficient row
# and the intercept, then wrap it for DiCE using the "PYT" backend.
clf = TorchLR(w[0], b)
m = dice_ml.Model(model=clf, backend="PYT")

# Data side: DiCE's data interface is built on the adapter-transformed frame.
data = adapter.transform(dataset)
d = dice_ml.Data(
    dataframe=data,
    continuous_features=dataset_info.continuous_features,
    outcome_name=adapter.label_column,
)

# Explainer using the "gradient" method.
dice = dice_ml.Dice(
    data_interface=d,
    model_interface=m,
    method="gradient",
)

In [168]:
# Points of interest: five rows picked by position from the instances whose
# "Y" value is -1, with the "Y" column removed.
POIs = (
    data.loc[data["Y"] == -1]
    .iloc[[1039, 192, 5, 242, 23]]
    .drop(columns="Y")
)

Running the recourse with the same arguments produces the same results.

In [180]:
# Baseline run: two counterfactuals of class 1 for the first point of
# interest, using no explicit weight arguments.
result = dice.generate_counterfactuals(
    query_instances=POIs.iloc[:1],
    total_CFs=2,
    desired_class=1,
)
# Pass the counterfactual frame through the adapter's inverse transform for display.
adapter.inverse_transform(result.cf_examples_list[0].final_cfs_df)

100%|██████████| 1/1 [02:19<00:00, 139.30s/it]

Diverse Counterfactuals found! total time taken: 00 min 16 sec





The recourse also ignores the various `*_weight` arguments (`sparsity_weight`, `diversity_weight`, `proximity_weight`).

In [199]:
# Same query as the baseline, but with sparsity and diversity weights set to
# zero and the proximity weight set to 5.
result3 = dice.generate_counterfactuals(
    query_instances=POIs.iloc[:1],
    total_CFs=2,
    desired_class=1,
    sparsity_weight=0,
    diversity_weight=0,
    proximity_weight=5,
)
# Pass the counterfactual frame through the adapter's inverse transform for display.
adapter.inverse_transform(result3.cf_examples_list[0].final_cfs_df)

100%|██████████| 1/1 [03:10<00:00, 190.40s/it]

Diverse Counterfactuals found! total time taken: 00 min 20 sec





In [203]:
# Same query again with a third weight setting:
# sparsity 0.2, diversity 0.2, proximity 0.5.
result4 = dice.generate_counterfactuals(
    query_instances=POIs.iloc[:1],
    total_CFs=2,
    desired_class=1,
    sparsity_weight=0.2,
    diversity_weight=0.2,
    proximity_weight=0.5,
)
# Pass the counterfactual frame through the adapter's inverse transform for display.
adapter.inverse_transform(result4.cf_examples_list[0].final_cfs_df)

100%|██████████| 1/1 [02:20<00:00, 140.22s/it]

Diverse Counterfactuals found! total time taken: 00 min 21 sec





Unnamed: 0,LIMIT_BAL,AGE,PAY_1,PAY_2,PAY_3,PAY_4,PAY_5,PAY_6,BILL_AMT1,BILL_AMT2,...,BILL_AMT4,BILL_AMT5,BILL_AMT6,PAY_AMT1,PAY_AMT2,PAY_AMT3,PAY_AMT4,PAY_AMT5,PAY_AMT6,Y
0,297753.9375,40.8102,0.344048,0.308905,0.29519,0.254095,0.21481,0.223286,124630.851562,120197.101562,...,107569.539062,100945.46875,98215.96875,5625.100586,12886.78418,8929.982422,8806.604492,9832.210938,10164.326172,1
1,297753.9375,41.138721,0.344048,0.308905,0.29519,0.254095,0.21481,0.223286,124630.851562,120197.101562,...,107569.539062,100945.46875,98215.96875,5625.100586,12360.105469,8627.44043,9302.266602,10270.052734,10793.949219,1
