In [73]:
import sys
sys.path.append('../src')
# from causal_shapley import causal_shapley, predict_proba
from egtoolkit import *
%load_ext autoreload
%autoreload 2

from egtoolkit import *

import random
import warnings
import numpy as np
import pandas as pd
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
import itertools
from itertools import combinations, permutations, chain

import shap

warnings.filterwarnings("ignore")
import networkx as nx
import matplotlib.pyplot as plt
import seaborn as sns

from imblearn.under_sampling import RandomUnderSampler

from pgmpy.models import BayesianNetwork
from pgmpy.inference import VariableElimination
from pgmpy.factors.discrete.CPD import TabularCPD

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Import dataset

In [74]:
# Names of the categorical feature columns: discretized meter readings
# followed by calendar attributes.
categorical_cols = [
    # meter readings
    'global_active_power',
    'global_reactive_power',
    'voltage',
    'global_intensity',
    # sub-metering channels
    'kitchen',
    'laundry',
    'climate_control',
    'other',
    # calendar attributes
    'weekend',
    'month_name',
    'season_name',
    'day_name',
]

In [75]:
# Load the discretized, peak-labelled dataset (path is relative to the notebook's working directory).
dataset = pd.read_csv('./datasets/2.0-discretized-v2-3-peak-label-encoded.csv')

In [76]:
# Peek at the label column; double brackets select a one-column DataFrame, so .values is 2-D (one value per row).
dataset[['peak_warning']].values

array([[False],
       [False],
       [False],
       ...,
       [False],
       [False],
       [False]])

In [77]:
# Every label-related column; dropping all of them leaves only the feature columns.
targets = [
    'peak_label_pred',
    'peak_warning',
    'no_significant_change',
    'lower_than_usual',
]
# Same list without 'peak_warning', so dropping these keeps peak_warning in the frame.
targets_only = [name for name in targets if name != 'peak_warning']

In [91]:
# List the column names of the loaded dataset (features first, label columns last).
dataset.columns

Index(['global_active_power', 'global_reactive_power', 'voltage',
       'global_intensity', 'kitchen', 'laundry', 'climate_control', 'other',
       'weekend', 'month_name', 'season_name', 'day_name', 'peak_label_pred',
       'peak_warning', 'no_significant_change', 'lower_than_usual'],
      dtype='object')

In [92]:
# Same label column as a flat 1-D array (single brackets select a Series).
dataset['peak_warning'].values

array([False, False, False, ..., False, False, False])

In [93]:
# (rows, columns) of the full dataset.
dataset.shape

(41673, 16)

In [94]:
# Number of rows drawn for the background set handed to the SHAP explainer.
sample = 100

# Feature-only rows (all label columns removed); the fixed seed makes the draw repeatable.
background = (
    dataset
    .drop(columns=targets)
    .sample(n=sample, random_state=0)
)

# Index labels of the sampled rows, used to look the same rows up in the full dataset.
indexes = background.index.tolist()

In [95]:
# Show the full rows (features and labels) behind the background sample.
# `indexes` holds index *labels* taken from background.index, so the lookup must be
# label-based (.loc). Positional .iloc only returns the same rows while the frame
# keeps its default 0..n-1 index and would silently pick other rows otherwise.
dataset.loc[indexes]

Unnamed: 0,global_active_power,global_reactive_power,voltage,global_intensity,kitchen,laundry,climate_control,other,weekend,month_name,season_name,day_name,peak_label_pred,peak_warning,no_significant_change,lower_than_usual
35784,Low,Medium,High,Low,Medium,High,Low,Medium,False,november,autumn,thursday,lower_than_usual,False,False,True
8217,Low,Low,High,Low,High,Medium,Medium,Low,True,october,autumn,saturday,no_significant_change,False,True,False
13032,Medium,Low,Very High,Medium,Medium,Medium,Low,High,False,december,winter,thursday,no_significant_change,False,True,False
6450,High,High,Very High,High,Low,High,Very High,High,True,april,spring,sunday,no_significant_change,False,True,False
36805,High,Low,Very High,High,Medium,Medium,High,Low,True,january,winter,saturday,lower_than_usual,False,False,True
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
15711,Medium,Medium,Medium,Medium,Medium,Low,High,Medium,True,july,summer,saturday,peak_warning,True,False,False
20349,Very High,High,Low,Very High,Very High,Very High,High,Very High,False,february,winter,monday,peak_warning,True,False,False
21439,High,High,Medium,High,Very High,Very High,Very High,Low,False,july,summer,tuesday,peak_warning,True,False,False
7083,High,Low,High,High,High,High,Very High,High,True,november,autumn,saturday,no_significant_change,False,True,False


In [98]:
def custom_predict(data):
    """Stand-in "model" for SHAP that echoes a boolean label as 0/1.

    Parameters
    ----------
    data : pandas.DataFrame or array-like supporting ``data[:, -1]``
        A DataFrame must contain a ``peak_warning`` column holding booleans.
        For any other input the last column is used as the label.

    Returns
    -------
    numpy.ndarray
        Column vector of shape ``(n_rows, 1)`` with the 0/1 encoded labels.

    Notes
    -----
    The input/output shapes and the raw input are printed on every call as a
    debugging aid.

    NOTE(review): in the array branch every value is passed through
    ``bool()``, so any non-empty string in the last column is encoded as 1.
    When the caller supplies feature-only rows (no ``peak_warning`` column),
    all predictions are therefore 1 and SHAP attributions collapse to ~0 --
    confirm this is intended.
    """
    bool_to_int = {True: 1, False: 0}

    print('=================')
    print("Input shape to custom_predict:", data.shape)

    if not isinstance(data, pd.DataFrame):
        print('non dataframe instance')
        last_column = data[:, -1]
        labels = np.array([bool_to_int[bool(value)] for value in last_column])
    else:
        print('dataframe instance')
        labels = data['peak_warning'].apply(lambda flag: bool_to_int[flag]).values

    print('data', data)

    # SHAP expects one output row per input row, so return a column vector.
    labels = labels.reshape(-1, 1)
    print("Output shape from custom_predict:", labels.shape)
    print('')
    return labels
# Wrap custom_predict in a KernelExplainer, using the sampled rows as the background set.
# NOTE(review): `background` has every label column dropped, so custom_predict takes its
# last-column branch on feature values -- confirm the resulting predictions are meaningful.
explainer = shap.KernelExplainer(custom_predict, background)

Input shape to custom_predict: (100, 12)
non dataframe instance
data [['Low' 'Medium' 'High' ... 'november' 'autumn' 'thursday']
 ['Low' 'Low' 'High' ... 'october' 'autumn' 'saturday']
 ['Medium' 'Low' 'Very High' ... 'december' 'winter' 'thursday']
 ...
 ['High' 'High' 'Medium' ... 'july' 'summer' 'tuesday']
 ['High' 'Low' 'High' ... 'november' 'autumn' 'saturday']
 ['Low' 'Medium' 'Medium' ... 'may' 'spring' 'tuesday']]
Output shape from custom_predict: (100, 1)



In [99]:
# Explain a single row: the first dataset row with the label columns removed
# (the 0:1 slice keeps it a one-row DataFrame rather than a Series).
observation = dataset.drop(columns=targets).iloc[0:1]
print("Observation shape before SHAP:", observation.shape)

# Any failure is only printed, not re-raised; in that case `shap_values` is not assigned by this cell.
try:
    shap_values = explainer.shap_values(observation)
except Exception as e:
    print("Error during SHAP computation:", e)


Observation shape before SHAP: (1, 12)


  0%|          | 0/1 [00:00<?, ?it/s]INFO:shap:num_full_subsets = 3
INFO:shap:remaining_weight_vector = array([0.35673839, 0.32616082, 0.31710079])
INFO:shap:num_paired_subset_sizes = 5
INFO:shap:weight_left = 0.2929005745959313


Input shape to custom_predict: (1, 12)
non dataframe instance
data [['Very High' 'Low' 'Low' 'Very High' 'High' 'Medium' 'High' 'Very High'
  True 'march' 'spring' 'saturday']]
Output shape from custom_predict: (1, 1)

Input shape to custom_predict: (207200, 12)
non dataframe instance
data [['Very High' 'Medium' 'High' ... 'november' 'autumn' 'thursday']
 ['Very High' 'Low' 'High' ... 'october' 'autumn' 'saturday']
 ['Very High' 'Low' 'Very High' ... 'december' 'winter' 'thursday']
 ...
 ['Very High' 'High' 'Medium' ... 'july' 'summer' 'saturday']
 ['Very High' 'Low' 'High' ... 'november' 'autumn' 'saturday']
 ['Very High' 'Medium' 'Medium' ... 'may' 'spring' 'saturday']]
Output shape from custom_predict: (207200, 1)



INFO:shap:phi = array([1.29611511e-17, 1.23360402e-17, 8.78090263e-18, 1.76465671e-17,
       9.17646157e-18, 7.65517750e-18, 5.09888927e-18, 9.43483593e-18,
       9.55179021e-18, 1.69228112e-18, 9.71089857e-18, 6.97730719e-18])
100%|██████████| 1/1 [00:00<00:00,  1.64it/s]


In [109]:
# Raw ndarray of the background sample (feature columns only).
# NOTE(review): the following cell rebinds `data` to a DataFrame and the execution counts are
# out of order -- confirm which definition the later cells are meant to use.
data = background.values

In [152]:
# Keep the feature columns plus `peak_warning` (only the other label columns are dropped),
# which leaves the label as the last column of `data`.
data = dataset.drop(columns=targets_only)
data

Unnamed: 0,global_active_power,global_reactive_power,voltage,global_intensity,kitchen,laundry,climate_control,other,weekend,month_name,season_name,day_name,peak_warning
0,Very High,Low,Low,Very High,High,Medium,High,Very High,True,march,spring,saturday,False
1,High,Low,High,High,Low,Low,High,High,False,january,winter,thursday,False
2,Very High,High,High,Very High,Medium,Medium,Very High,Very High,False,january,winter,friday,False
3,Low,Low,Low,Low,Medium,Low,Low,Medium,False,february,winter,friday,False
4,Very High,Low,Very High,Very High,Low,Low,Low,Very High,True,december,winter,saturday,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...
41668,Low,Low,Very High,Low,Low,Low,Low,Medium,True,january,winter,saturday,False
41669,Medium,Medium,Very High,Medium,High,Medium,Medium,High,False,february,winter,tuesday,False
41670,Medium,Very High,High,Medium,Low,High,Low,High,False,april,spring,tuesday,False
41671,Low,Low,Very High,Low,High,High,Medium,Low,True,january,winter,saturday,False


In [200]:
# Seed for the row draw below.
rand_state = 10

# Draw the five rows once, then split them into features and label so the two
# arrays are aligned row by row.
sampled_rows = data.sample(5, random_state=rand_state)
observation = sampled_rows.iloc[:, :-1].values  # every column except the last
real_target = sampled_rows.iloc[:, -1].values   # last column only
real_target

array([False, False, False,  True, False])

In [202]:
def find_outcome(data, observation):
    """Look up the outcome of the first row whose features equal ``observation``.

    Parameters
    ----------
    data : object exposing ``.values`` as a 2-D array (e.g. a DataFrame)
        Each row is read as feature values followed by the outcome in the
        last position.
    observation : array-like
        Feature values compared element-wise with each row's features.

    Returns
    -------
    The last entry of the first matching row, or ``None`` when no row matches.

    The matching row and the observation are printed when a match is found.
    """
    for row in data.values:
        # Skip rows whose features differ from the observation in any position.
        if not (row[:-1] == observation).all():
            continue
        print('Found!')
        print('Record: ', row)
        print('Observation: ', observation)
        return row[-1]
    # Scanned every row without a match.
    return None

# Look up each sampled observation in `data` and print the outcome that find_outcome returns for it.
for obs in observation:
    outcome = find_outcome(data, obs)
    print('outcome', outcome)


Found!
Record:  ['Low' 'Medium' 'Very High' 'Low' 'Very High' 'High' 'High' 'Medium' False
 'june' 'summer' 'monday' False]
Observation:  ['Low' 'Medium' 'Very High' 'Low' 'Very High' 'High' 'High' 'Medium' False
 'june' 'summer' 'monday']
outcome False
Found!
Record:  ['Very High' 'Very High' 'High' 'Very High' 'Medium' 'Very High' 'Low'
 'Medium' False 'march' 'spring' 'monday' False]
Observation:  ['Very High' 'Very High' 'High' 'Very High' 'Medium' 'Very High' 'Low'
 'Medium' False 'march' 'spring' 'monday']
outcome False
Found!
Record:  ['Medium' 'High' 'High' 'Medium' 'Low' 'Very High' 'Low' 'Medium' False
 'august' 'summer' 'friday' False]
Observation:  ['Medium' 'High' 'High' 'Medium' 'Low' 'Very High' 'Low' 'Medium' False
 'august' 'summer' 'friday']
outcome False
Found!
Record:  ['Very High' 'High' 'Low' 'Very High' 'Very High' 'High' 'Very High'
 'Very High' False 'october' 'autumn' 'tuesday' True]
Observation:  ['Very High' 'High' 'Low' 'Very High' 'Very High' 'High' 'Very 

In [148]:
def find_outcome(data, observation):
    """Return the last entry of the first record matching ``observation``'s first row.

    Parameters
    ----------
    data : iterable of array-like records
        Each record is compared without its last entry.
    observation : object exposing ``.values`` as a 2-D array (e.g. a DataFrame)
        Only the first row is used, and its last entry is ignored as well.

    Returns
    -------
    The last entry of the first matching record, or ``None`` when none match.

    NOTE(review): this redefines ``find_outcome`` with a different contract
    than the definition earlier in the file (DataFrame ``data``, plain feature
    array ``observation``); whichever cell runs last wins -- confirm which one
    should be kept.
    """
    for row in data:
        wanted = observation.values[0][:-1]
        # Move on unless every compared position agrees.
        if not (row[:-1] == wanted).all():
            continue
        return row[-1]
    return None

# Look up the recorded outcome for `observation` and report it.
# NOTE(review): the find_outcome defined in this cell reads `observation.values[0]` and iterates
# `data` directly, so it needs a DataFrame observation and row-wise iterable data -- confirm both
# names still hold those types when this cell runs.
outcome = find_outcome(data, observation)
print("The outcome for observation", observation, "is:", outcome)


The outcome for observation   global_active_power global_reactive_power voltage global_intensity kitchen  \
0           Very High                   Low     Low        Very High    High   

  laundry climate_control      other  weekend month_name season_name  day_name  
0  Medium            High  Very High     True      march      spring  saturday   is: None
