In [37]:
import numpy 
import pandas as pd
import dice_ml
import xgboost as xgb
import warnings
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer

warnings.filterwarnings("ignore") 
pd.options.display.max_rows = 500

In [38]:
TARGET = 'TSV'

In [39]:
data = pd.read_csv('../data/TotalClothingValue.csv', index_col=0)
data = data.dropna()

In [40]:
mapping = {-2: 0, -1: 1, 0: 2, 1: 3, 2: 4}
data[TARGET] = data[TARGET].replace(mapping)

In [51]:
columns = ['TSV', 'TPV', 'TCV', 'TSL']
columns.remove(TARGET)

In [54]:
# Working only on TSV now
data = data.drop(columns=columns)

In [55]:
features = data.drop(TARGET, axis=1).columns.tolist()
target = data[TARGET]

In [56]:
datasetX = data.drop(TARGET, axis=1)

In [57]:
x_train, x_test, y_train, y_test = train_test_split(datasetX,
                                                    target,
                                                    test_size=0.2,
                                                    random_state=0)

In [58]:
categorical_features = x_train.columns.difference(features)

# We create the preprocessing pipelines for both numeric and categorical data.
numeric_transformer = Pipeline(steps=[
    ('scaler', StandardScaler())])

categorical_transformer = Pipeline(steps=[
    ('onehot', OneHotEncoder(handle_unknown='ignore'))])

transformations = ColumnTransformer(
    transformers=[
        ('num', numeric_transformer, features),
        ('cat', categorical_transformer, categorical_features)])

regr = Pipeline(steps=[('preprocessor', transformations),
                        ('classifier', xgb.XGBClassifier())])
model = regr.fit(x_train, y_train)

In [59]:
y_pred = model.predict(x_test)

In [60]:
from sklearn.metrics import accuracy_score
print(accuracy_score(y_test, y_pred))

0.6519607843137255


In [8]:
import pickle
#pickle.dump(model, open('../models/tsv+3_full.pkl', 'wb'))
#model = pickle.load(open('../models/tsv+3_full.pkl', 'rb'))

In [61]:
d = dice_ml.Data(dataframe=data, continuous_features=features, outcome_name=TARGET)
m = dice_ml.Model(model=model, backend='sklearn', model_type='classifier')

In [62]:
exp = dice_ml.Dice(d, m, method='random')

In [76]:
always_immutable = ['DAY', 'AvgMaxDailyTemp','AvgMinDailyTemp','School','StartTime', 'SchoolType']
freezed = always_immutable + []

features_to_vary = data.columns.difference(freezed).to_list()

In [77]:
features_to_vary.remove(TARGET)

In [65]:
query_instances = x_test[:]


In [None]:
query_instances

In [78]:
cf = exp.generate_counterfactuals(query_instances=query_instances, total_CFs=4, desired_class=3, features_to_vary=features_to_vary)

  0%|          | 0/408 [00:00<?, ?it/s]

  2%|▏         | 10/408 [00:06<04:10,  1.59it/s]

No Counterfactuals found for the given configuration, perhaps try with different parameters... ; total time taken: 00 min 00 sec


  3%|▎         | 11/408 [00:06<04:47,  1.38it/s]

Only 1 (required 4)  Diverse Counterfactuals found for the given configuration, perhaps try with different parameters... ; total time taken: 00 min 00 sec


  4%|▍         | 16/408 [00:10<05:18,  1.23it/s]

Only 1 (required 4)  Diverse Counterfactuals found for the given configuration, perhaps try with different parameters... ; total time taken: 00 min 00 sec


  5%|▌         | 22/408 [00:15<05:22,  1.20it/s]

Only 3 (required 4)  Diverse Counterfactuals found for the given configuration, perhaps try with different parameters... ; total time taken: 00 min 00 sec


  7%|▋         | 28/408 [00:19<04:19,  1.46it/s]

No Counterfactuals found for the given configuration, perhaps try with different parameters... ; total time taken: 00 min 00 sec


  7%|▋         | 29/408 [00:20<05:08,  1.23it/s]

Only 2 (required 4)  Diverse Counterfactuals found for the given configuration, perhaps try with different parameters... ; total time taken: 00 min 01 sec


  9%|▊         | 35/408 [00:24<04:23,  1.41it/s]

Only 1 (required 4)  Diverse Counterfactuals found for the given configuration, perhaps try with different parameters... ; total time taken: 00 min 00 sec


 11%|█         | 43/408 [00:28<03:49,  1.59it/s]

No Counterfactuals found for the given configuration, perhaps try with different parameters... ; total time taken: 00 min 00 sec


 15%|█▌        | 63/408 [00:39<03:29,  1.65it/s]

Only 1 (required 4)  Diverse Counterfactuals found for the given configuration, perhaps try with different parameters... ; total time taken: 00 min 00 sec


 17%|█▋        | 68/408 [00:42<03:25,  1.65it/s]

No Counterfactuals found for the given configuration, perhaps try with different parameters... ; total time taken: 00 min 00 sec


 17%|█▋        | 70/408 [00:43<03:42,  1.52it/s]

No Counterfactuals found for the given configuration, perhaps try with different parameters... ; total time taken: 00 min 00 sec


 18%|█▊        | 74/408 [00:46<04:06,  1.35it/s]

Only 3 (required 4)  Diverse Counterfactuals found for the given configuration, perhaps try with different parameters... ; total time taken: 00 min 00 sec


 19%|█▉        | 78/408 [00:48<03:59,  1.38it/s]

Only 1 (required 4)  Diverse Counterfactuals found for the given configuration, perhaps try with different parameters... ; total time taken: 00 min 01 sec


 21%|██        | 85/408 [00:53<03:56,  1.37it/s]

Only 1 (required 4)  Diverse Counterfactuals found for the given configuration, perhaps try with different parameters... ; total time taken: 00 min 00 sec


 26%|██▌       | 105/408 [01:04<03:32,  1.43it/s]

Only 1 (required 4)  Diverse Counterfactuals found for the given configuration, perhaps try with different parameters... ; total time taken: 00 min 00 sec


 27%|██▋       | 109/408 [01:07<04:10,  1.19it/s]

Only 3 (required 4)  Diverse Counterfactuals found for the given configuration, perhaps try with different parameters... ; total time taken: 00 min 01 sec


 27%|██▋       | 112/408 [01:10<03:52,  1.27it/s]

Only 3 (required 4)  Diverse Counterfactuals found for the given configuration, perhaps try with different parameters... ; total time taken: 00 min 01 sec


 41%|████      | 167/408 [01:38<02:58,  1.35it/s]

Only 3 (required 4)  Diverse Counterfactuals found for the given configuration, perhaps try with different parameters... ; total time taken: 00 min 00 sec


 44%|████▍     | 180/408 [01:45<02:08,  1.78it/s]

No Counterfactuals found for the given configuration, perhaps try with different parameters... ; total time taken: 00 min 00 sec


 46%|████▌     | 188/408 [01:50<02:24,  1.52it/s]

Only 3 (required 4)  Diverse Counterfactuals found for the given configuration, perhaps try with different parameters... ; total time taken: 00 min 00 sec


 49%|████▊     | 198/408 [01:55<02:04,  1.69it/s]

No Counterfactuals found for the given configuration, perhaps try with different parameters... ; total time taken: 00 min 00 sec


 50%|█████     | 206/408 [02:00<02:21,  1.43it/s]

No Counterfactuals found for the given configuration, perhaps try with different parameters... ; total time taken: 00 min 00 sec


 52%|█████▏    | 213/408 [02:04<02:09,  1.51it/s]

Only 3 (required 4)  Diverse Counterfactuals found for the given configuration, perhaps try with different parameters... ; total time taken: 00 min 00 sec


 54%|█████▍    | 220/408 [02:08<01:44,  1.81it/s]

Only 1 (required 4)  Diverse Counterfactuals found for the given configuration, perhaps try with different parameters... ; total time taken: 00 min 00 sec


 54%|█████▍    | 221/408 [02:09<02:06,  1.48it/s]

No Counterfactuals found for the given configuration, perhaps try with different parameters... ; total time taken: 00 min 00 sec


 56%|█████▌    | 228/408 [02:13<01:58,  1.52it/s]

Only 1 (required 4)  Diverse Counterfactuals found for the given configuration, perhaps try with different parameters... ; total time taken: 00 min 00 sec


 60%|█████▉    | 243/408 [02:21<01:39,  1.65it/s]

Only 1 (required 4)  Diverse Counterfactuals found for the given configuration, perhaps try with different parameters... ; total time taken: 00 min 00 sec


 63%|██████▎   | 258/408 [02:28<01:16,  1.95it/s]

No Counterfactuals found for the given configuration, perhaps try with different parameters... ; total time taken: 00 min 00 sec


 67%|██████▋   | 275/408 [02:37<01:24,  1.58it/s]

No Counterfactuals found for the given configuration, perhaps try with different parameters... ; total time taken: 00 min 00 sec


 68%|██████▊   | 279/408 [02:39<01:21,  1.58it/s]

No Counterfactuals found for the given configuration, perhaps try with different parameters... ; total time taken: 00 min 00 sec


 73%|███████▎  | 296/408 [02:48<01:01,  1.81it/s]

No Counterfactuals found for the given configuration, perhaps try with different parameters... ; total time taken: 00 min 00 sec


 73%|███████▎  | 298/408 [02:49<01:09,  1.59it/s]

Only 3 (required 4)  Diverse Counterfactuals found for the given configuration, perhaps try with different parameters... ; total time taken: 00 min 00 sec


 76%|███████▋  | 312/408 [02:57<00:57,  1.66it/s]

No Counterfactuals found for the given configuration, perhaps try with different parameters... ; total time taken: 00 min 00 sec


 78%|███████▊  | 317/408 [03:00<00:56,  1.61it/s]

Only 1 (required 4)  Diverse Counterfactuals found for the given configuration, perhaps try with different parameters... ; total time taken: 00 min 00 sec


 79%|███████▉  | 324/408 [03:03<00:53,  1.58it/s]

No Counterfactuals found for the given configuration, perhaps try with different parameters... ; total time taken: 00 min 00 sec


 83%|████████▎ | 337/408 [03:12<01:00,  1.18it/s]

Only 2 (required 4)  Diverse Counterfactuals found for the given configuration, perhaps try with different parameters... ; total time taken: 00 min 01 sec


 84%|████████▍ | 343/408 [03:17<00:46,  1.41it/s]

No Counterfactuals found for the given configuration, perhaps try with different parameters... ; total time taken: 00 min 00 sec


 91%|█████████ | 372/408 [03:33<00:26,  1.34it/s]

No Counterfactuals found for the given configuration, perhaps try with different parameters... ; total time taken: 00 min 00 sec


 92%|█████████▏| 375/408 [03:35<00:24,  1.37it/s]

No Counterfactuals found for the given configuration, perhaps try with different parameters... ; total time taken: 00 min 00 sec


 94%|█████████▍| 384/408 [03:40<00:14,  1.68it/s]

Only 2 (required 4)  Diverse Counterfactuals found for the given configuration, perhaps try with different parameters... ; total time taken: 00 min 00 sec


 99%|█████████▊| 402/408 [03:52<00:04,  1.38it/s]

No Counterfactuals found for the given configuration, perhaps try with different parameters... ; total time taken: 00 min 00 sec


100%|██████████| 408/408 [03:56<00:00,  1.72it/s]


In [None]:
cf.visualize_as_dataframe(show_only_changes=True)

In [79]:
r = []
for i in range(len(x_test)):
    r.append(x_test[i:i+1])
    if cf.cf_examples_list[i].final_cfs_df is not None:
        r.append(cf.cf_examples_list[i].final_cfs_df)

r2 = pd.concat(r)

In [80]:
mapping = {0: -2, 1: -1, 2: 0, 3: 1, 4: 2}
r2[TARGET] = r2[TARGET].replace(mapping)

In [81]:
r2

Unnamed: 0,DAY,School,SchoolType,StartTime,AvgMaxDailyTemp,AvgMinDailyTemp,AvgIndoorRelativeHumidity,IndoorTempDuringSurvey,Grade,Age,Gender,FormalClothing,TotalCLOwithChair,SwC,MC,TSV
939,1.0,3.0,1.0,3.0,22.4,10.1,46.3,17.0,5.0,10.0,0.0,1.0,1.16,1.0,4.0,
0,1.0,3.0,1.0,3.0,22.4,10.1,71.6,13.3,5.0,10.0,0.0,0.0,1.58,1.0,4.0,1.0
1,1.0,3.0,1.0,3.0,22.4,10.1,74.7,12.8,3.0,8.0,0.0,1.0,1.56,1.0,4.0,1.0
2,1.0,3.0,1.0,3.0,22.4,10.1,46.3,14.0,3.0,10.0,1.0,1.0,1.58,1.0,3.0,1.0
3,1.0,3.0,1.0,3.0,22.4,10.1,69.5,14.0,5.0,10.0,0.0,1.0,1.56,1.0,2.0,1.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1475,2.0,4.0,1.0,1.0,19.6,6.6,59.6,13.1,3.0,7.0,1.0,1.0,1.33,2.0,4.0,
0,2.0,4.0,1.0,1.0,19.6,6.6,59.6,13.1,3.0,7.0,0.0,0.0,1.58,1.0,4.0,1.0
1,2.0,4.0,1.0,1.0,19.6,6.6,53.4,13.1,4.0,7.0,1.0,0.0,1.56,1.0,4.0,1.0
2,2.0,4.0,1.0,1.0,19.6,6.6,45.5,13.1,5.0,7.0,1.0,1.0,1.56,1.0,4.0,1.0


In [87]:
r2.to_csv(f'../results/{TARGET}/Total_all.csv')