<a href="https://colab.research.google.com/github/danielsaggau/DICE/blob/master/Cali_DICE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install dice-ml
import numpy as np
import timeit
import random
import dice_ml
import pandas as pd
from dice_ml.utils import helpers  # helper functions
from dice_ml import Dice
from sklearn.compose import ColumnTransformer
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.ensemble import RandomForestRegressor
from sklearn.datasets import fetch_california_housing

In [6]:
housing = fetch_california_housing()
df_housing = pd.DataFrame(housing.data, columns=housing.feature_names)
df_housing['outcome_name'] = pd.Series(housing.target)
df_housing.head()

Unnamed: 0,MedInc,HouseAge,AveRooms,AveBedrms,Population,AveOccup,Latitude,Longitude,outcome_name
0,8.3252,41.0,6.984127,1.02381,322.0,2.555556,37.88,-122.23,4.526
1,8.3014,21.0,6.238137,0.97188,2401.0,2.109842,37.86,-122.22,3.585
2,7.2574,52.0,8.288136,1.073446,496.0,2.80226,37.85,-122.24,3.521
3,5.6431,52.0,5.817352,1.073059,558.0,2.547945,37.85,-122.25,3.413
4,3.8462,52.0,6.281853,1.081081,565.0,2.181467,37.85,-122.25,3.422


In [9]:
continuous_features_cali = df_housing.drop('outcome_name', axis=1).columns.tolist()
target = df_housing['outcome_name']

In [12]:
datasetX = df_housing.drop('outcome_name', axis=1)
x_train, x_test, y_train, y_test = train_test_split(datasetX,
                                                    target,
                                                    test_size=0.2,
                                                    random_state=0)
categorical_features = x_train.columns.difference(continuous_features_cali)

numeric_transformer = Pipeline(steps=[
    ('scaler', StandardScaler())])

categorical_transformer = Pipeline(steps=[
    ('onehot', OneHotEncoder(handle_unknown='ignore'))])

transformations = ColumnTransformer(
    transformers=[
        ('num', numeric_transformer, continuous_features_cali),
        ('cat', categorical_transformer, categorical_features)])

regr_cali = Pipeline(steps=[('preprocessor', transformations),
                              ('regressor', RandomForestRegressor())])
model_cali = regr_cali.fit(x_train, y_train)

In [13]:
d_cali = dice_ml.Data(dataframe=df_housing, continuous_features=continuous_features_cali, outcome_name='outcome_name')
m_cali = dice_ml.Model(model=model_cali, backend="sklearn", model_type='regressor')
exp_genetic_cali = Dice(d_cali, m_cali, method="genetic")

In [16]:
query_instances_cali = x_train[17:19]
genetic_cali = exp_genetic_cali.generate_counterfactuals(query_instances_cali, 
                                                             total_CFs=5, 
                                                             desired_range=[3-8, 4])
genetic_cali.visualize_as_dataframe(show_only_changes=True)

100%|██████████| 2/2 [00:01<00:00,  1.19it/s]

Query instance (original outcome : 3)





Unnamed: 0,MedInc,HouseAge,AveRooms,AveBedrms,Population,AveOccup,Latitude,Longitude,outcome_name
0,6.0145,14.0,6.156749,1.082729,1858.0,2.696662,34.23,-118.87,3.08981



Diverse Counterfactual set (new outcome: [-5, 4])


Unnamed: 0,MedInc,HouseAge,AveRooms,AveBedrms,Population,AveOccup,Latitude,Longitude,outcome_name
0,-,-,6.2,1.1,-,2.7,-,-,-
0,6.8645,23.0,7.3,1.0,-,3.5,33.72,-117.96,3.1908903000000013
0,3.9926,-,5.6,1.0,1867.0,3.1,36.32,-119.77,0.9211099999999991
0,3.0543,23.0,4.6,1.2,-,4.1,36.33,-121.23,1.0773100000000015
0,2.3879,-,4.0,1.0,1867.0,4.0,36.5,-121.43,1.258450000000001


Query instance (original outcome : 2)


Unnamed: 0,MedInc,HouseAge,AveRooms,AveBedrms,Population,AveOccup,Latitude,Longitude,outcome_name
0,3.6413,40.0,4.604736,1.041894,1318.0,2.400729,34.27,-119.26,2.28974



Diverse Counterfactual set (new outcome: [-5, 4])


Unnamed: 0,MedInc,HouseAge,AveRooms,AveBedrms,Population,AveOccup,Latitude,Longitude,outcome_name
0,-,-,4.6,1.0,-,2.4,-,-,-
0,2.625,42.0,4.8,1.0,-,3.5,33.91,-118.27,1.0567099999999996
0,4.3269,-,5.5,1.1,1322.0,5.0,33.8,-117.85,1.9069100000000014
0,1.9688,-,3.9,1.1,1322.0,4.7,34.04,-118.19,1.4529699999999994
0,1.6786,33.0,0.8,1.1,-,3.2,33.9,-118.32,1.6627800999999989
