# Counterfactual Training Data Extraction Experiment

In [1]:
import pandas as pd
from sklearnex import patch_sklearn
patch_sklearn()
import sklearn.ensemble as es
from sklearn.tree import DecisionTreeClassifier
from sklearn.neural_network import MLPClassifier
import random
import logging
import warnings
import dice_ml

Intel(R) Extension for Scikit-learn* enabled (https://github.com/intel/scikit-learn-intelex)


In [2]:
threads = 15

logging.basicConfig()

logger = logging.getLogger('xai-privacy')

In [3]:
from experiment_setup import run_all_experiments
from experiment_setup import get_heart_disease_dataset
from experiment_setup import get_census_dataset

In [4]:
DATASET_HALF = True

data_heart_dict, data_heart_num_dict, data_heart_cat_dict = get_heart_disease_dataset(halve_dataset=DATASET_HALF)
data_census_dict, data_census_num_dict, data_census_cat_dict = get_census_dataset(halve_dataset=DATASET_HALF)

data_heart = data_heart_dict['dataset']
data_heart_num = data_heart_num_dict['dataset']
data_heart_cat = data_heart_cat_dict['dataset']
data_census = data_census_dict['dataset']
data_census_num = data_census_num_dict['dataset']
data_census_cat = data_census_cat_dict['dataset']
outcome_name_heart = data_heart_dict['outcome']
numeric_features_heart = data_heart_dict['num']

Feature Age: removed 0 rows for missing values.
Feature RestingBP: removed 59 rows for missing values.
Feature Cholesterol: removed 27 rows for missing values.
Feature FastingBS: add unknown category 2.0
Feature RestingECG: add unknown category 3.0
Feature MaxHR: removed 0 rows for missing values.
Feature Oldpeak: removed 7 rows for missing values.
Feature ST_Slope: add unknown category 4.0
Feature CA: add unknown category 4.0
Feature Thal: add unknown category 8.0
Dropped 71 of 548
Dropped 72 of 548
Dropped 71 of 548
Dropped: 2399 of 32561
census: Dropped 1256 of 15081
num: Dropped 8827 of 15081
cat: Dropped 4850 of 15081


This notebook will test whether training data extraction is possible with counterfactuals (CF) that are drawn from the training data. Training data extraction means an attacker can find out the feature values of samples from the training data without prior knowledge of them. The attacker only has access to the model's prediction function and the explanation.

This attack should be trivial because any counterfactual that is shown as an explanation was picked directly from the training data.
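Why the returned counterfactuals are training rows can be shown with a simplified sketch. Roughly speaking, DiCE's `kdtree` method looks up the training samples that the model assigns to the desired class and returns those closest to the query. The sketch below assumes plain Euclidean distance and brute-force search in place of a kd-tree. A kd-tree only speeds up the nearest-neighbour search, and DiCE's own feature scaling and post-processing are ignored here. The toy model `predict` and the helper `nearest_training_counterfactuals` are hypothetical.

```python
import math


def predict(x):
    # Hypothetical toy model: class 1 if the feature sum exceeds 10
    return int(sum(x) > 10)


def nearest_training_counterfactuals(query, training_data, model, k=2):
    """Return the k training samples closest to `query` whose predicted
    class differs from the query's prediction (brute-force search)."""
    query_class = model(query)
    # Only training samples with the opposite prediction are candidates
    candidates = [row for row in training_data if model(row) != query_class]
    candidates.sort(key=lambda row: math.dist(query, row))
    return candidates[:k]


training_data = [(1.0, 2.0), (3.0, 4.0), (6.0, 5.0), (7.0, 8.0), (2.0, 9.5)]
query = (2.0, 3.0)  # predicted class 0
cfs = nearest_training_counterfactuals(query, training_data, predict, k=2)
print(cfs)  # [(6.0, 5.0), (2.0, 9.5)]
```

Every returned counterfactual is, by construction, an element of `training_data`. This is exactly what makes the extraction attack trivial.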

The idea of counterfactual training data extraction is as follows: the attacker repeatedly queries the model with random input values and requests a counterfactual explanation for each query. To generate these inputs, the attacker needs to know the minimum and maximum value of each numeric feature in the training data and the possible values of each categorical feature. The counterfactuals returned for these queries are the extracted training data.
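The sketch below shows one way such random queries could be generated. It assumes values are drawn uniformly at random; the source only states that the inputs are random within the known ranges. The feature ranges and category lists are hypothetical placeholders, not statistics of the actual heart dataset. `random_query` is an illustrative helper, not part of `cf_attack.py`.

```python
import random


def random_query(numeric_ranges, categorical_values, rng):
    """Draw one query: numeric features uniformly within [min, max],
    categorical features uniformly from their known values."""
    query = {}
    for name, (low, high) in numeric_ranges.items():
        query[name] = rng.uniform(low, high)
    for name, values in categorical_values.items():
        query[name] = rng.choice(values)
    return query


# Hypothetical attacker knowledge for a subset of the heart features
numeric_ranges = {'Age': (28.0, 77.0), 'MaxHR': (60.0, 202.0)}
categorical_values = {'Sex': ['0.0', '1.0'], 'ChestPainType': ['1.0', '2.0', '3.0', '4.0']}

rng = random.Random(0)  # fixed seed for reproducibility
queries = [random_query(numeric_ranges, categorical_values, rng) for _ in range(12)]
```

Each query would then be passed to the explainer. Whatever counterfactuals come back are candidate training samples.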

The attack, including the `train_explainer` and `training_data_extraction_model_access` functions, is implemented in a separate module, which we import first:

In [5]:
# The attack code must be imported from a module so that the multiprocessing pool works. See cf_attack.py for the implementation of the attack.
from cf_attack import CounterfactualTDE

# Executing Training Data Extraction

We now generate five counterfactuals for the first sample from the training data to demonstrate counterfactual explanations in general.

In [6]:
features = data_heart.drop(outcome_name_heart, axis=1)
labels = data_heart[outcome_name_heart]

# Train a random forest on training data.
model = es.RandomForestClassifier(random_state=0)
model = model.fit(features, labels)

# Train explainer
d = dice_ml.Data(dataframe=data_heart, continuous_features=numeric_features_heart, outcome_name=outcome_name_heart)

m = dice_ml.Model(model=model, backend="sklearn", model_type='classifier')
# Generating counterfactuals from training data (kd-tree)
exp = dice_ml.Dice(d, m, method="kdtree")

In [7]:
e1 = exp.generate_counterfactuals(features[0:1], total_CFs=5, desired_class="opposite")
e1.visualize_as_dataframe(display_sparse_df=False)

100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  2.51it/s]

Query instance (original outcome : 0)





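| | Age | Sex | ChestPainType | RestingBP | Cholesterol | FastingBS | RestingECG | MaxHR | ExerciseAngina | Oldpeak | ST_Slope | CA | Thal | HeartDisease |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 62.0 | 1.0 | 2.0 | 140.0 | 271.0 | 0.0 | 0.0 | 152.0 | 0.0 | 1.0 | 1.0 | 4.0 | 8.0 | 0.0 |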



Diverse Counterfactual set without sparsity correction (new outcome: 1)


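| | Age | Sex | ChestPainType | RestingBP | Cholesterol | FastingBS | RestingECG | MaxHR | ExerciseAngina | Oldpeak | ST_Slope | CA | Thal |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 439 | 62.0 | 0.0 | 4.0 | 140.0 | 268.0 | 0.0 | 2.0 | 160.0 | 0.0 | 3.6 | 3.0 | 2.0 | 3.0 |
| 90 | 57.0 | 1.0 | 2.0 | 124.0 | 261.0 | 0.0 | 0.0 | 141.0 | 0.0 | 0.3 | 1.0 | 0.0 | 7.0 |
| 166 | 60.0 | 1.0 | 2.0 | 160.0 | 267.0 | 1.0 | 1.0 | 157.0 | 0.0 | 0.5 | 2.0 | 4.0 | 8.0 |
| 103 | 66.0 | 1.0 | 4.0 | 112.0 | 261.0 | 0.0 | 0.0 | 140.0 | 0.0 | 1.5 | 1.0 | 4.0 | 8.0 |
| 370 | 67.0 | 1.0 | 1.0 | 142.0 | 270.0 | 1.0 | 0.0 | 125.0 | 0.0 | 2.5 | 1.0 | 4.0 | 8.0 |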


The counterfactuals are similar to the query sample but receive the opposite prediction (1 instead of 0). These are the two defining properties of counterfactual explanations.

We now run a small proof of concept of the experiment with debug logging enabled to demonstrate how it works.

In [8]:
logger.setLevel(logging.DEBUG)
logging.root.setLevel(logging.ERROR)

EXP = CounterfactualTDE(data_heart, numeric_features_heart, outcome_name_heart, random_state=0)
EXP.training_data_extraction_experiment(num_queries=12, model=es.RandomForestClassifier(random_state=0), model_access=False)

logger.setLevel(logging.INFO)

DEBUG:xai-privacy:Numeric Features: ['Age', 'RestingBP', 'Cholesterol', 'MaxHR', 'Oldpeak']
DEBUG:xai-privacy:Categorical Features: ['CA', 'ChestPainType', 'ExerciseAngina', 'FastingBS', 'RestingECG', 'ST_Slope', 'Sex', 'Thal']
100%|██████████████████████████████████████████████████████████████████████████████████| 12/12 [00:03<00:00,  3.09it/s]
DEBUG:xai-privacy:Sample 0: Counterfactuals 
 [[65.0 '1.0' '4.0' 120.0 177.0 '0.0' '0.0' 140.0 '0.0' 0.4 '1.0' '0.0'
  '7.0']
 [50.0 '1.0' '4.0' 140.0 129.0 '0.0' '0.0' 135.0 '0.0' 0.0 '4.0' '4.0'
  '8.0']
 [55.0 '1.0' '4.0' 150.0 160.0 '0.0' '1.0' 150.0 '0.0' 0.0 '4.0' '4.0'
  '8.0']
 [41.0 '1.0' '4.0' 110.0 172.0 '0.0' '2.0' 158.0 '0.0' 0.0 '1.0' '0.0'
  '7.0']]
DEBUG:xai-privacy:Sample 1: Counterfactuals 
 [[56.0 '1.0' '2.0' 126.0 166.0 '0.0' '1.0' 140.0 '0.0' 0.0 '4.0' '4.0'
  '8.0']
 [50.0 '1.0' '4.0' 140.0 129.0 '0.0' '0.0' 135.0 '0.0' 0.0 '4.0' '4.0'
  '8.0']
 [35.0 '0.0' '1.0' 120.0 160.0 '0.0' '1.0' 185.0 '0.0' 0.0 '4.0' '4.0'
  '8.0']

DEBUG:xai-privacy:Extracted sample: [ 55.   1.   4. 150. 160.   0.   1. 150.   0.   0.   4.   4.   8.]
DEBUG:xai-privacy:Appears in training data at indices [262]
DEBUG:xai-privacy:Extracted sample: [ 56.   1.   2. 126. 166.   0.   1. 140.   0.   0.   4.   4.   8.]
DEBUG:xai-privacy:Appears in training data at indices [271]
DEBUG:xai-privacy:Extracted sample: [ 57.    0.    4.  120.  354.    0.    0.  163.    1.    0.6   1.    0.
   3. ]
DEBUG:xai-privacy:Appears in training data at indices [285]
DEBUG:xai-privacy:Extracted sample: [ 57.    1.    3.  150.  126.    1.    0.  173.    0.    0.2   1.    1.
   7. ]
DEBUG:xai-privacy:Appears in training data at indices [293]
DEBUG:xai-privacy:Extracted sample: [ 57.   1.   4. 130. 311.   2.   1. 148.   1.   2.   2.   4.   8.]
DEBUG:xai-privacy:Appears in training data at indices [300]
DEBUG:xai-privacy:Extracted sample: [ 58.   0.   2. 136. 319.   1.   2. 152.   0.   0.   1.   2.   3.]
DEBUG:xai-privacy:Appears in training data at indices [3

Total time: 4.13s (training model: 0.03s, training explainer: 0.01s, experiment: 4.09s)
Number of extracted samples: 34
Number of accurate extracted samples: 34
Precision: 1.0, recall: 2.8333333333333335


The proof of concept shows that each extracted sample is an actual training sample (precision of 100%). Recall is above 100% here because it is computed relative to the number of queries (34 extracted samples from 12 queries), and each query can return several counterfactuals. In the full experiment, the number of queries equals the number of training samples. Since the attack cannot extract more distinct samples than exist in the training data, recall is then at most 100% and measures the fraction of the training data that was extracted.
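The bookkeeping behind these numbers can be sketched as follows. This is a simplified illustration, not the code in `cf_attack.py`. It assumes that a counterfactual returned for several queries is counted only once. It also assumes that recall divides the number of accurately extracted samples by the number of queries, which is consistent with the logged 34 / 12 ≈ 2.83. The helper names `evaluate_extraction` and `precision_recall` are hypothetical.

```python
def evaluate_extraction(counterfactuals_per_query, training_data):
    """Deduplicate the counterfactuals returned over all queries and keep
    those distinct samples that occur in the training data."""
    extracted = []
    seen = set()
    for cfs in counterfactuals_per_query:
        for cf in cfs:
            if cf not in seen:  # count each distinct sample only once
                seen.add(cf)
                extracted.append(cf)
    training_set = set(training_data)
    accurate = [cf for cf in extracted if cf in training_set]
    return extracted, accurate


def precision_recall(extracted, accurate, num_queries):
    precision = len(accurate) / len(extracted)
    recall = len(accurate) / num_queries  # relative to queries, so it can exceed 1
    return precision, recall


training_data = [(1.0, 2.0), (3.0, 4.0), (5.0, 6.0), (7.0, 8.0)]
# Two queries, each answered with two counterfactuals; (3.0, 4.0) is returned twice
counterfactuals_per_query = [[(1.0, 2.0), (3.0, 4.0)], [(3.0, 4.0), (5.0, 6.0)]]
extracted, accurate = evaluate_extraction(counterfactuals_per_query, training_data)
precision, recall = precision_recall(extracted, accurate, num_queries=len(counterfactuals_per_query))
print(precision, recall)  # 1.0 1.5
```

Two queries yield three distinct training samples, so recall is 1.5, mirroring the proof of concept where 12 queries yield 34 samples.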

Next, we run the actual experiment. We first define the table that will hold the results of all experiment variations, then execute every variation for each dataset. We vary the model between a decision tree, a random forest and a neural network, each using the default configuration of scikit-learn.

In [9]:
results_ = {'dataset': [], 'model': [], 'precision': [], 'recall': []}

results = pd.DataFrame(data = results_)

In [10]:
dataset_dicts = [data_heart_dict, data_heart_num_dict, data_heart_cat_dict, data_census_dict, data_census_num_dict, data_census_cat_dict]

dt_dict = {'name': 'decision tree', 'model': DecisionTreeClassifier}
rf_dict = {'name': 'random forest', 'model': es.RandomForestClassifier}
nn_dict = {'name': 'neural network', 'model': MLPClassifier}

model_dicts = [dt_dict, rf_dict, nn_dict]

# Set the number of queries for each dataset to the number of samples in that dataset
num_queries_dict = { 'heart': len(data_heart), 'heart numeric': len(data_heart_num), 'heart categorical': len(data_heart_cat), 'census': len(data_census), 'census numeric': len(data_census_num), 'census categorical': len(data_census_cat)}

In [11]:
# remove pandas warnings
warnings.simplefilter(action='ignore', category=pd.errors.PerformanceWarning)

In [12]:
# This will run the experiment for each dataset and model combination

results = run_all_experiments(CounterfactualTDE, dataset_dicts, model_dicts, random_state=0, num_queries=num_queries_dict, model_access=False, threads=threads, results_table=results, is_mem_inf=False, convert_cat_to_str=True)

dataset: heart, model: decision tree
Total time: 37.46s (training model: 0.01s, training explainer: 0.01s, experiment: 37.43s)
Number of extracted samples: 343
Number of accurate extracted samples: 343
Precision: 1.0, recall: 0.7190775681341719
dataset: heart, model: random forest
Total time: 38.22s (training model: 0.02s, training explainer: 0.01s, experiment: 38.18s)
Number of extracted samples: 334
Number of accurate extracted samples: 334
Precision: 1.0, recall: 0.70020964360587
dataset: heart, model: neural network




Total time: 38.66s (training model: 1.06s, training explainer: 0.01s, experiment: 37.59s)
Number of extracted samples: 311
Number of accurate extracted samples: 311
Precision: 1.0, recall: 0.6519916142557652
dataset: heart numeric, model: decision tree
Total time: 18.07s (training model: 0.01s, training explainer: 0.00s, experiment: 18.06s)
Number of extracted samples: 376
Number of accurate extracted samples: 376
Precision: 1.0, recall: 0.7899159663865546
dataset: heart numeric, model: random forest
Total time: 18.60s (training model: 0.03s, training explainer: 0.05s, experiment: 18.52s)
Number of extracted samples: 334
Number of accurate extracted samples: 334
Precision: 1.0, recall: 0.7016806722689075
dataset: heart numeric, model: neural network




Total time: 18.06s (training model: 0.67s, training explainer: 0.02s, experiment: 17.38s)
Number of extracted samples: 353
Number of accurate extracted samples: 353
Precision: 1.0, recall: 0.7415966386554622
dataset: heart categorical, model: decision tree
Total time: 14.09s (training model: 0.02s, training explainer: 0.02s, experiment: 14.06s)
Number of extracted samples: 414
Number of accurate extracted samples: 414
Precision: 1.0, recall: 0.8679245283018868
dataset: heart categorical, model: random forest




Total time: 21.45s (training model: 0.26s, training explainer: 0.02s, experiment: 21.17s)
Number of extracted samples: 382
Number of accurate extracted samples: 382
Precision: 1.0, recall: 0.80083857442348
dataset: heart categorical, model: neural network




Total time: 15.11s (training model: 1.11s, training explainer: 0.02s, experiment: 13.99s)
Number of extracted samples: 366
Number of accurate extracted samples: 366
Precision: 1.0, recall: 0.7672955974842768
dataset: census, model: decision tree
Total time: 617.93s (training model: 0.46s, training explainer: 0.05s, experiment: 617.42s)
Number of extracted samples: 1172
Number of accurate extracted samples: 1172
Precision: 1.0, recall: 0.08477396021699819
dataset: census, model: random forest




Total time: 1165.49s (training model: 7.86s, training explainer: 0.06s, experiment: 1157.57s)
Number of extracted samples: 880
Number of accurate extracted samples: 880
Precision: 1.0, recall: 0.06365280289330923
dataset: census, model: neural network




Total time: 680.72s (training model: 33.28s, training explainer: 0.05s, experiment: 647.39s)
Number of extracted samples: 1190
Number of accurate extracted samples: 1190
Precision: 1.0, recall: 0.08607594936708861
dataset: census numeric, model: decision tree
Total time: 68.57s (training model: 0.02s, training explainer: 0.00s, experiment: 68.56s)
Number of extracted samples: 866
Number of accurate extracted samples: 866
Precision: 1.0, recall: 0.13847137831787656
dataset: census numeric, model: random forest
Total time: 109.89s (training model: 0.11s, training explainer: 0.00s, experiment: 109.78s)
Number of extracted samples: 658
Number of accurate extracted samples: 658
Precision: 1.0, recall: 0.10521266389510713
dataset: census numeric, model: neural network




Total time: 95.08s (training model: 7.37s, training explainer: 0.00s, experiment: 87.72s)
Number of extracted samples: 734
Number of accurate extracted samples: 734
Precision: 1.0, recall: 0.1173648864726575
dataset: census categorical, model: decision tree
Total time: 459.49s (training model: 0.28s, training explainer: 0.08s, experiment: 459.12s)
Number of extracted samples: 8239
Number of accurate extracted samples: 8239
Precision: 1.0, recall: 0.8052976248656045
dataset: census categorical, model: random forest




Total time: 767.68s (training model: 4.73s, training explainer: 0.09s, experiment: 762.85s)
Number of extracted samples: 8032
Number of accurate extracted samples: 8032
Precision: 1.0, recall: 0.7850649985338677
dataset: census categorical, model: neural network




Total time: 464.50s (training model: 25.21s, training explainer: 0.08s, experiment: 439.22s)
Number of extracted samples: 7073
Number of accurate extracted samples: 7073
Precision: 1.0, recall: 0.6913302707457727


# Results

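Precision is the fraction of extracted samples that actually come from the training data.

Recall is the ratio of the number of extracted training samples to the number of queries. Since the number of queries equals the number of training samples in this experiment, recall is the fraction of the training data that was extracted.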
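As a worked example, the logged counts for the heart dataset with a decision tree can be plugged into these definitions. This assumes, consistent with the proof of concept (34 / 12 ≈ 2.83), that the recall denominator is the number of queries. Here it equals the 477 heart samples remaining after preprocessing (548 minus the 71 dropped rows in the first "Dropped 71 of 548" line):

$$
\text{precision} = \frac{N_{\text{accurate}}}{N_{\text{extracted}}} = \frac{343}{343} = 1.0,
\qquad
\text{recall} = \frac{N_{\text{accurate}}}{N_{\text{queries}}} = \frac{343}{548 - 71} = \frac{343}{477} \approx 0.719
$$

This matches the logged recall of 0.7190775681341719.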

In [13]:
results

| | dataset | model | precision | recall |
|---|---|---|---|---|
| 0 | heart | decision tree | 1.0 | 0.719078 |
| 1 | heart | random forest | 1.0 | 0.70021 |
| 2 | heart | neural network | 1.0 | 0.651992 |
| 3 | heart numeric | decision tree | 1.0 | 0.789916 |
| 4 | heart numeric | random forest | 1.0 | 0.701681 |
| 5 | heart numeric | neural network | 1.0 | 0.741597 |
| 6 | heart categorical | decision tree | 1.0 | 0.867925 |
| 7 | heart categorical | random forest | 1.0 | 0.800839 |
| 8 | heart categorical | neural network | 1.0 | 0.767296 |
| 9 | census | decision tree | 1.0 | 0.084774 |


In [14]:
file_name = 'results/2-1-cf-training-data-extraction-results'
if DATASET_HALF:
    file_name += '_dataset_size_halved'
results.to_csv(file_name + '.csv', index=False, na_rep='NaN', float_format='%.3f')

# Discussion

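In our experiments (with halved datasets), training data extraction with counterfactuals drawn from the training data reaches a recall between 70% and 79% for the numeric heart data but only between 11% and 14% for the numeric census data, and between 69% and 87% for categorical data. For the datasets with mixed feature types, recall lies between 65% and 72% (heart) and between 6% and 9% (census). Since every returned counterfactual is an actual training sample, the attack cannot produce any false positive samples, so precision is always 100%.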
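The recall range per dataset can be recomputed from the values printed in the experiment log. The sketch below copies those values by hand (model order: decision tree, random forest, neural network). With the `results` DataFrame in memory, the same ranges could instead be read off grouped by `dataset`. The helper `recall_ranges` is hypothetical.

```python
# Recall values copied from the experiment log
# (order: decision tree, random forest, neural network)
recalls = {
    'heart': [0.7190775681341719, 0.70020964360587, 0.6519916142557652],
    'heart numeric': [0.7899159663865546, 0.7016806722689075, 0.7415966386554622],
    'heart categorical': [0.8679245283018868, 0.80083857442348, 0.7672955974842768],
    'census': [0.08477396021699819, 0.06365280289330923, 0.08607594936708861],
    'census numeric': [0.13847137831787656, 0.10521266389510713, 0.1173648864726575],
    'census categorical': [0.8052976248656045, 0.7850649985338677, 0.6913302707457727],
}


def recall_ranges(recalls):
    """Return the (min, max) recall over the three models for each dataset."""
    return {dataset: (min(values), max(values)) for dataset, values in recalls.items()}


for dataset, (low, high) in recall_ranges(recalls).items():
    print(f"{dataset}: {low:.0%} to {high:.0%}")
```

For example, this prints `heart: 65% to 72%` and `census: 6% to 9%`.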