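This notebook is based on the [DiCE advanced-options example notebook](https://github.com/interpretml/DiCE/blob/78ca5391467ba12e38730e71577fbe421d9f0ba2/docs/source/notebooks/DiCE_with_advanced_options.ipynb). It trains a small TensorFlow classifier on the Adult Income data, generates counterfactual explanations with DiCE, and prototypes ipywidgets controls for configuring them.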

In [30]:
# import DiCE
import dice_ml
from dice_ml.utils import helpers # helper functions

# Tensorflow libraries
import tensorflow as tf
from tensorflow import keras

In [31]:
# Load the Adult Income dataset through DiCE's helper module; the preview below
# shows its columns, including the `income` outcome column.
dataset = helpers.load_adult_income_dataset()

In [32]:
# Peek at the first five rows of the raw (un-encoded) data.
dataset.head()

Unnamed: 0,age,workclass,education,marital_status,occupation,race,gender,hours_per_week,income
0,39,Government,Bachelors,Single,White-Collar,White,Male,40,0
1,50,Self-Employed,Bachelors,Married,White-Collar,White,Male,13,0
2,38,Private,HS-grad,Divorced,Blue-Collar,White,Male,40,0
3,53,Private,School,Married,Blue-Collar,Other,Male,40,0
4,28,Private,Bachelors,Married,Professional,Other,Female,40,0


In [33]:
# Wrap the raw dataframe in DiCE's data interface.
# - `age` and `hours_per_week` are declared continuous.
# - The other feature columns are categorical; they are one-hot encoded for the model (see X_train.info below).
# - `income` is the outcome to explain.
d = dice_ml.Data(
    dataframe=dataset,
    continuous_features=['age', 'hours_per_week'],
    outcome_name='income',
)

In [34]:
# Fix the global random seeds so that training (and therefore the generated
# counterfactuals) is as reproducible as possible across runs.
from numpy.random import seed
tf.random.set_seed(2)  # TensorFlow's global seed
seed(1)                # NumPy's global seed

In [6]:
# Normalize and one-hot encode the data, then split it (presumably train/test);
# only the first part is used for fitting here.
train, _ = d.split_data(d.normalize_data(d.one_hot_encoded_data))
X_train = train.drop(columns='income')  # every encoded feature column
y_train = train[['income']]             # single-column frame holding the target

# Small feed-forward binary classifier: one L1-regularized ReLU layer of 20 units,
# followed by a single sigmoid output unit.
ann_model = keras.Sequential()
ann_model.add(keras.layers.Dense(20, input_shape=(X_train.shape[1],), kernel_regularizer=keras.regularizers.l1(0.001), activation=tf.nn.relu))
ann_model.add(keras.layers.Dense(1, activation=tf.nn.sigmoid))

ann_model.compile(loss='binary_crossentropy', optimizer=tf.keras.optimizers.Adam(0.01), metrics=['accuracy'])

# Training for 100 epochs takes some time; set verbose=1 to see the progress.
# class_weight={0: 1, 1: 2} weights class-1 examples twice as heavily in the loss
# (presumably to offset class imbalance -- confirm against the label distribution).
# Keeping the History object suppresses its repr in the cell output and leaves the
# per-epoch loss/accuracy available for inspection.
history = ann_model.fit(X_train, y_train, validation_split=0.20, epochs=100, verbose=0, class_weight={0: 1, 1: 2})

<tensorflow.python.keras.callbacks.History at 0x7fbf37115160>

In [9]:
# List the model-input columns and their dtypes: the continuous features are
# float64, and each categorical value became its own uint8 one-hot column.
X_train.info(verbose=True)

<class 'pandas.core.frame.DataFrame'>
Int64Index: 26048 entries, 16313 to 10863
Data columns (total 29 columns):
 #   Column                    Non-Null Count  Dtype  
---  ------                    --------------  -----  
 0   age                       26048 non-null  float64
 1   hours_per_week            26048 non-null  float64
 2   workclass_Government      26048 non-null  uint8  
 3   workclass_Other/Unknown   26048 non-null  uint8  
 4   workclass_Private         26048 non-null  uint8  
 5   workclass_Self-Employed   26048 non-null  uint8  
 6   education_Assoc           26048 non-null  uint8  
 7   education_Bachelors       26048 non-null  uint8  
 8   education_Doctorate       26048 non-null  uint8  
 9   education_HS-grad         26048 non-null  uint8  
 10  education_Masters         26048 non-null  uint8  
 11  education_Prof-school     26048 non-null  uint8  
 12  education_School          26048 non-null  uint8  
 13  education_Some-college    26048 non-null  uint8  
 14  ma

In [7]:
# DiCE backend string for TensorFlow: 'TF' followed by the major version (e.g. 'TF2').
# Split on '.' so the full major version is used rather than only the first
# character of the version string.
backend = 'TF' + tf.__version__.split('.')[0]
m = dice_ml.Model(model=ann_model, backend=backend)

In [8]:
# query instance in the form of a dictionary; keys: feature name, values: feature value
# The instance to explain: feature name -> raw value, using the same category
# labels as the dataset (continuous features stay numeric).
query_instance = dict(
    age=22,
    workclass='Private',
    education='HS-grad',
    marital_status='Single',
    occupation='Service',
    race='White',
    gender='Female',
    hours_per_week=45,
)

In [9]:
def counterfactual_modeling(data, model, query, number_CFs, visualizing, desired_class="opposite", **cf_kwargs):
    """Generate and display DiCE counterfactuals for a single query instance.

    Parameters
    ----------
    data : dice_ml.Data
        DiCE data interface describing the features and the outcome.
    model : dice_ml.Model
        DiCE wrapper around the trained classifier.
    query : dict
        Feature name -> value mapping for the instance to explain.
    number_CFs : int
        Number of counterfactuals to request (passed as ``total_CFs``).
    visualizing : bool
        If truthy, show only the features that changed in each counterfactual;
        otherwise show every feature.
    desired_class : optional
        Target class for the counterfactuals; defaults to ``"opposite"``.
    **cf_kwargs
        Extra keyword arguments forwarded unchanged to
        ``generate_counterfactuals``. Examples are ``proximity_weight``,
        ``diversity_weight`` and ``feature_weights``; check them against the
        installed DiCE version's signature.

    Returns
    -------
    None
        The result is rendered through ``visualize_as_dataframe``.
    """
    exp = dice_ml.Dice(data, model)
    dice_exp = exp.generate_counterfactuals(query, total_CFs=number_CFs, desired_class=desired_class, **cf_kwargs)

    # Test truthiness rather than `is True`, so that 1 or numpy.bool_(True)
    # also select the changes-only view.
    if visualizing:
        dice_exp.visualize_as_dataframe(show_only_changes=True)
    else:
        dice_exp.visualize_as_dataframe()

In [15]:
# Ask for four counterfactuals of the opposite class, showing only the changed features.
counterfactual_modeling(d, m, query_instance, number_CFs=4, visualizing=True)

Diverse Counterfactuals found! total time taken: 00 min 38 sec
Query instance (original outcome : 0)


Unnamed: 0,age,workclass,education,marital_status,occupation,race,gender,hours_per_week,income
0,22.0,Private,HS-grad,Single,Service,White,Female,45.0,0.009411



Diverse Counterfactual set (new outcome : 1)


Unnamed: 0,age,workclass,education,marital_status,occupation,race,gender,hours_per_week,income
0,57.0,-,Doctorate,-,White-Collar,-,-,-,0.724
1,36.0,-,Prof-school,Married,-,-,-,37.0,0.869
2,-,Self-Employed,Doctorate,Married,-,-,-,-,0.755
3,43.0,-,-,Married,White-Collar,-,-,63.0,0.822


In [11]:
# get MAD
mads = d.get_mads(normalized=True)

# create feature weights
feature_weights = {}
for feature in mads:
    feature_weights[feature] = round(1/mads[feature], 2)
print(feature_weights)


{'age': 7.3, 'hours_per_week': 24.5}


In [12]:
import ipywidgets as widgets
from IPython.display import display
x = False
# Scratch check that checkbox change events reach a Python callback.
slider = widgets.Checkbox(value=x, description="hello")

def on_change(v):
    """Print 3 when the checkbox becomes ticked and 4 when it is cleared.

    `v` is the ipywidgets change dict; its 'new' entry is the checkbox's new value.
    """
    print(3 if v['new'] else 4)

slider.observe(on_change, names='value')
display(slider)

Checkbox(value=False, description='hello')

3
4
3
4
3
4


In [105]:
from ipywidgets import interact
from IPython.display import display, clear_output

column_names = list(dataset.columns)

print('choose the output:')

dropdown = widgets.Dropdown(options=column_names)

# Single-element list holding the *current* outcome choice; later cells read
# output_var[0]. It is seeded with the dropdown's initial value so it is never
# empty, even if the user keeps the default selection.
output_var = [dropdown.value]

def get_and_plot(b):
    """Record the newly selected outcome column (nothing is plotted).

    `b` is the ipywidgets change dict passed by `observe`. The list is replaced
    in place rather than appended to, so output_var[0] always reflects the
    latest choice instead of the first one ever made.
    """
    output_var[:] = [dropdown.value]

display(dropdown)
dropdown.observe(get_and_plot, names='value')

choose the output:


Dropdown(options=('age', 'workclass', 'education', 'marital_status', 'occupation', 'race', 'gender', 'hours_pe…

In [106]:
# Features are every column except the selected outcome. The dropdown's current
# value is read directly, so this works even if the selection was never changed.
feature_names = [name for name in column_names if name != dropdown.value]

#data
print('continous feature?')
# Keep a handle on every checkbox, keyed by feature name, so the ticked features
# can be read back later; `display` itself returns None. `widgets.Checkbox` is
# used because the bare name `Checkbox` is only imported in a later cell.
continuous_checks = {name: widgets.Checkbox(description=name) for name in feature_names}
for check in continuous_checks.values():
    display(check)

continous feature?


Checkbox(value=False, description='age')

Checkbox(value=False, description='workclass')

Checkbox(value=False, description='education')

Checkbox(value=False, description='marital_status')

Checkbox(value=False, description='occupation')

Checkbox(value=False, description='race')

Checkbox(value=False, description='gender')

Checkbox(value=False, description='hours_per_week')

In [102]:
#query
print('Input Query:')
# One text box per feature, kept in a dict keyed by feature name so the entered
# query can be read back later. Reassigning one variable in the loop would leave
# only the last box reachable. Values arrive as strings, so numeric features
# will need converting before use.
query_inputs = {}
for name in feature_names:
    query_inputs[name] = widgets.Text(value='', description=name, disabled=False)
    display(query_inputs[name])

Input Query:


Text(value='', description='age')

Text(value='', description='workclass')

Text(value='', description='education')

Text(value='', description='marital_status')

Text(value='', description='occupation')

Text(value='', description='race')

Text(value='', description='gender')

Text(value='', description='hours_per_week')

In [108]:
# Slider choosing how many counterfactual explanations to generate (1-5);
# presumably meant to feed `number_CFs` of counterfactual_modeling.
num_exp = widgets.IntSlider(
    min=1,
    max=5,
    description='Number of Explanations',
)
display(num_exp)

IntSlider(value=1, description='Number of Explanations', max=5, min=1)

In [123]:
from ipywidgets import Checkbox, VBox

# One checkbox per feature, revealed only while the toggle below is ticked.
feature_weight_checks = [Checkbox(description=name) for name in feature_names]

# Dedicated names, not the generic `vb`/`top_toggle`: the next cell rebinds those,
# and the handler would then look up (and mutate) the wrong container.
feature_weight_box = VBox(children=[])
feature_weight_toggle = Checkbox(description='Assign Feature Weights?')

def add_checks(button):
    """Show the per-feature checkboxes while the toggle is ticked, hide them otherwise.

    `button` is the ipywidgets change dict; its 'new' entry is the toggle's new value.
    Only the container's children change; the container is displayed once below,
    so toggling repeatedly does not stack extra copies in the output.
    """
    feature_weight_box.children = feature_weight_checks if button['new'] else []

feature_weight_toggle.observe(add_checks, names='value')
display(feature_weight_toggle, feature_weight_box)

Checkbox(value=False, description='Assign Feature Weights?')

VBox(children=(Checkbox(value=False, description='age'), Checkbox(value=False, description='workclass'), Check…

In [122]:
from ipywidgets import interactive
from IPython.display import display
from ipywidgets import Checkbox, VBox

cb1 = widgets.FloatSlider(description='Proximity Weight', max=10)
cb2 = widgets.FloatSlider(description='Diversity Weight', max=10)

# Container and toggle get their own names so that no other cell can rebind
# what the handler below refers to.
tuning_box = VBox(children=[])
tuning_toggle = Checkbox(description='Tune proximity/diversity?')

def add_2(button):
    """Show the proximity/diversity sliders while the toggle is ticked, hide them otherwise.

    `button` is the ipywidgets change dict; its 'new' entry is the toggle's new value.
    The container is displayed once below and only its children change here, so
    toggling does not append a new VBox view to the output each time.
    """
    tuning_box.children = [cb1, cb2] if button['new'] else []

tuning_toggle.observe(add_2, names='value')
display(tuning_toggle, tuning_box)

Checkbox(value=False, description='Tune proximity/diversity?')

VBox(children=(FloatSlider(value=0.0, description='Proximity Weight', max=10.0), FloatSlider(value=0.0, descri…

VBox(children=(FloatSlider(value=0.0, description='Proximity Weight', max=10.0), FloatSlider(value=0.0, descri…

In [94]:
new_widg = widgets.Checkbox(
    value=False,
    description='Only highlight changes in explanations',
    disabled=False,
    indent=False
)
display(new_widg)

# Current state of the checkbox, kept in sync by `variable`. It is presumably
# meant for the `visualizing` argument of counterfactual_modeling.
highlight_changes = new_widg.value

def variable(button):
    """Record whether explanations should show only the changed features.

    `button` is the ipywidgets change dict; its 'new' entry is the checkbox's new value.
    """
    global highlight_changes
    highlight_changes = bool(button['new'])

# Register the handler; without this the checkbox state is never captured.
new_widg.observe(variable, names='value')
        
        

Checkbox(value=False, description='Only highlight changes in explanations', indent=False)

In [125]:
button = widgets.Button(
    description='Generate!',
)

def plot_on_click(b):
    """Click handler for the Generate button (placeholder).

    `b` is the Button instance that was clicked.
    TODO: call counterfactual_modeling with the values collected by the widgets above.
    """
    print('hi', 3)

# Register explicitly instead of using `@button.on_click` as a decorator:
# on_click returns None, so the decorator form would rebind `plot_on_click` to None.
button.on_click(plot_on_click)

display(button)

Button(description='Generate!', style=ButtonStyle())

hi 3
