# Load data

In [None]:
import os

import sage
from sklearn.model_selection import train_test_split

In [None]:
# Load the bank marketing data through sage's bundled dataset helper.
df = sage.datasets.bank()

# Every column except the last one (the label) is a feature. CatBoost needs the
# positions of the categorical features, so translate their names into indices.
feature_names = list(df.columns[:-1])
categorical_cols = [
    'Job', 'Marital', 'Education', 'Default', 'Housing',
    'Loan', 'Contact', 'Month', 'Prev Outcome',
]
categorical_inds = list(map(feature_names.index, categorical_cols))

In [None]:
# Split data: hold out 10% of the full dataset as the test set, then carve a
# validation set of the same absolute size out of the remaining rows.
holdout_size = int(0.1 * len(df.values))
train, test = train_test_split(df.values, test_size=holdout_size, random_state=123)
train, val = train_test_split(train, test_size=holdout_size, random_state=123)

# The last column holds the integer label; everything before it is a feature.
Y_train, Y_val, Y_test = (part[:, -1].copy().astype(int) for part in (train, val, test))
train, val, test = (part[:, :-1].copy() for part in (train, val, test))

# Load model

In [None]:
import pickle
from catboost import CatBoostClassifier

In [None]:
# NOTE: pickle.load can execute arbitrary code -- only load model files you trust.
model_path = 'trained_models/bank model.pkl'
with open(model_path, 'rb') as model_file:
    model = pickle.load(model_file)

# Perturbed data

In [None]:
import numpy as np
from sklearn.metrics import log_loss
import matplotlib.pyplot as plt

In [None]:
# Build corrupted copies of the test set, each altering exactly one feature.
# Duration is rescaled by a factor of 60 in each direction; the "Secs"/"Hours"
# labels imply it is stored in minutes -- NOTE(review): confirm in sage.datasets.bank.
duration_index = feature_names.index('Duration')

test_seconds = test.copy()
test_seconds[:, duration_index] = test[:, duration_index] * 60

test_hours = test.copy()
test_hours[:, duration_index] = test[:, duration_index] / 60

# Replace every month with the following one (December wraps around to January).
month_index = feature_names.index('Month')
months = ['jan', 'feb', 'mar', 'apr', 'may', 'jun',
          'jul', 'aug', 'sep', 'oct', 'nov', 'dec']
test_month = test.copy()
test_month[:, month_index] = [months[(months.index(month) + 1) % 12]
                              for month in test[:, month_index]]

In [None]:
# Calculate performance: cross entropy of a constant predictor that outputs the
# training class frequencies (base rate), then the model on clean train/validation
# data and on each perturbed copy of the test set.
p = np.array([np.sum(Y_train == i) for i in np.unique(Y_train)]) / len(Y_train)
base_ce = log_loss(Y_test, p[np.newaxis].repeat(len(test), 0))
train_ce = log_loss(Y_train, model.predict_proba(train))
val_ce = log_loss(Y_val, model.predict_proba(val))
test_ce = log_loss(Y_test, model.predict_proba(test))  # clean test loss; not shown in the plot below
seconds_ce = log_loss(Y_test, model.predict_proba(test_seconds))
hours_ce = log_loss(Y_test, model.predict_proba(test_hours))
month_ce = log_loss(Y_test, model.predict_proba(test_month))

# Plot. Bar heights and value annotations come from the same list so they cannot
# drift apart.
losses = [base_ce, train_ce, val_ce, seconds_ce, hours_ce, month_ce]
positions = np.arange(len(losses))
fig, ax = plt.subplots(figsize=(8, 6))
ax.bar(positions, losses,
       color=['tab:blue', 'tab:cyan', 'tab:purple', 'crimson', 'firebrick', 'indianred'])
for i, ce in enumerate(losses):
    ax.text(i - 0.25, ce + 0.007, '{:.3f}'.format(ce), fontsize=16)

# 0.94 leaves room for the labels on the original results; grow the axis when a
# larger loss would otherwise push its bar or label off the top of the plot.
ax.set_ylim(0, max(0.94, max(losses) + 0.05))

ax.set_xticks(positions)
ax.set_xticklabels(['Base Rate', 'Train', 'Validation', r'Duration $\rightarrow$ Secs',
                    r'Duration $\rightarrow$ Hours', r'Month $\rightarrow$ + 1'],
                   rotation=45, rotation_mode='anchor', ha='right')

ax.tick_params(labelsize=16)
ax.set_ylabel('Cross Entropy Loss', fontsize=18)
ax.set_title('Loss Comparison', fontsize=20)

fig.tight_layout()
plt.show()

# Generate explanations

In [None]:
# Setup and calculate. SAGE values are estimated by sampling, so results differ from
# run to run. Seeding NumPy's global RNG makes reruns repeatable, assuming sage draws
# from that RNG (NOTE(review): confirm for the installed sage version).
np.random.seed(0)

# Background data for the marginal imputer: the first 512 training rows
# (train_test_split shuffled them above).
imputer = sage.MarginalImputer(model, train[:512])
estimator = sage.PermutationEstimator(imputer, 'cross entropy')

# Reference explanation on clean validation data, then one per corrupted test set.
# thresh is passed straight to the estimator (presumably its convergence threshold).
sage_val = estimator(val, Y_val, thresh=0.01)
sage_seconds = estimator(test_seconds, Y_test, thresh=0.01)
sage_hours = estimator(test_hours, Y_test, thresh=0.01)
sage_month = estimator(test_month, Y_test, thresh=0.01)

In [None]:
# Save the explanations so the plotting cells can reload them without recomputing.
# Create the output directory first, in case sage's save does not: on a fresh
# checkout 'results/' may not exist yet.
os.makedirs('results', exist_ok=True)
sage_val.save('results/bank_sage_val.pkl')
sage_seconds.save('results/bank_sage_seconds.pkl')
sage_hours.save('results/bank_sage_hours.pkl')
sage_month.save('results/bank_sage_month.pkl')

# Plot

In [None]:
# Reload the saved explanations so the figure below can be drawn without rerunning SAGE.
# NOTE(review): sage.load presumably unpickles the file -- only load results you produced.
sage_val, sage_seconds, sage_hours, sage_month = map(sage.load, (
    'results/bank_sage_val.pkl',
    'results/bank_sage_seconds.pkl',
    'results/bank_sage_hours.pkl',
    'results/bank_sage_month.pkl',
))

In [None]:
# Summary figure: the loss comparison (top left) plus three panels that compare SAGE
# values on clean validation data against SAGE values on each corrupted test set.


def _plot_sage_comparison(ax, reference, comparison, order, names, labels, colors,
                          title, width=0.4):
    """Draw side-by-side SAGE bars for two explanations on ``ax``.

    Args:
        ax: matplotlib Axes to draw on.
        reference: SAGE explanation drawn first (left bar of each pair); must expose
            ``.values`` and ``.std`` arrays with one entry per feature.
        comparison: SAGE explanation drawn second (right bar of each pair).
        order: feature indices giving the left-to-right bar order.
        names: feature names, indexed like ``.values``.
        labels: legend labels for (reference, comparison).
        colors: bar colors for (reference, comparison).
        title: panel title.
        width: width of a single bar; each group of bars is centered on its tick.
    """
    positions = np.arange(len(order))
    series = (reference, comparison)
    for i, (explanation, label, color) in enumerate(zip(series, labels, colors)):
        # Shift each series so the whole group of bars is centered on the tick
        # (derived from width instead of a constant tied to width=0.4).
        offset = width * (i - (len(series) - 1) / 2)
        # Error bars span +/- 1.96 standard deviations (a ~95% normal interval).
        ax.bar(positions + offset, explanation.values[order], width=width,
               color=color, yerr=1.96 * explanation.std[order], capsize=4, label=label)
    ax.legend(loc='lower right', fontsize=18)
    ax.tick_params('y', labelsize=16)
    ax.set_ylabel('SAGE Values', fontsize=18)
    ax.set_xticks(positions)
    ax.set_xticklabels(np.array(names)[order], rotation=45, ha='right',
                       rotation_mode='anchor', fontsize=16)
    ax.set_title(title, fontsize=20)


fig, axarr = plt.subplots(2, 2, figsize=(16, 10))

# Performance comparison
loss_ax = axarr[0, 0]
losses = [base_ce, train_ce, val_ce, seconds_ce, hours_ce, month_ce]
loss_positions = np.arange(len(losses))
loss_ax.bar(loss_positions, losses,
            color=['tab:blue', 'tab:cyan', 'tab:purple', 'crimson', 'firebrick', 'indianred'])
for i, ce in enumerate(losses):
    loss_ax.text(i - 0.25, ce + 0.007, '{:.3f}'.format(ce), fontsize=16)
# 0.97 leaves room for the labels on the original results; grow the axis when a
# larger loss would otherwise push its bar or label off the top of the panel.
loss_ax.set_ylim(0, max(0.97, max(losses) + 0.05))
loss_ax.set_xticks(loss_positions)
loss_ax.set_xticklabels(['Base Rate', 'Train', 'Validation', r'Duration$\rightarrow$Secs',
                         r'Duration$\rightarrow$Hours', r'Month$\rightarrow$+1'],
                        rotation=35, rotation_mode='anchor', ha='right')
loss_ax.tick_params(labelsize=16)
loss_ax.set_ylabel('Cross Entropy Loss', fontsize=18)
loss_ax.set_title('Loss Comparison', fontsize=20)

# Every SAGE panel orders features by validation SAGE value, largest first, so a
# corrupted feature stands out as a bar that departs from its validation bar.
order = np.argsort(sage_val.values)[::-1]

_plot_sage_comparison(axarr[0, 1], sage_val, sage_month, order, feature_names,
                      ('Validation', r'Month$\rightarrow$+1'),
                      ('tab:purple', 'indianred'),
                      'Detecting Corrupted Months')
_plot_sage_comparison(axarr[1, 0], sage_val, sage_seconds, order, feature_names,
                      ('Validation', r'Duration$\rightarrow$Secs'),
                      ('tab:purple', 'crimson'),
                      'Detecting Corrupted Duration (Seconds)')
_plot_sage_comparison(axarr[1, 1], sage_val, sage_hours, order, feature_names,
                      ('Validation', r'Duration$\rightarrow$Hours'),
                      ('tab:purple', 'firebrick'),
                      'Detecting Corrupted Duration (Hours)')

fig.tight_layout()
# savefig does not create missing directories, so make sure 'figures/' exists.
os.makedirs('figures', exist_ok=True)
fig.savefig('figures/model_monitoring.pdf')
plt.show()