# Load data

In [1]:
import sage
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

In [2]:
# Load the bank dataset through sage.datasets
df = sage.datasets.bank()

# Encode the yes/no columns as floats: 'yes' -> 1.0, anything else -> 0.0
binary_cols = ['Default', 'Housing', 'Loan']
for col in binary_cols:
    df[col] = df[col].eq('yes').astype(float)
    
# Convert education to an ordinal code.
# The result is assigned back to the column instead of calling
# replace(..., inplace=True) on df['Education']: the in-place form is chained
# assignment, which pandas 2.x warns about and which leaves df unchanged
# under Copy-on-Write (the default from pandas 3.0).
education_codes = {'unknown': 0, 'primary': 1, 'secondary': 2, 'tertiary': 3}
df['Education'] = df['Education'].replace(education_codes)

# Convert month abbreviations to 0-based month indices (same assign-back form)
month_codes = {'jan': 0, 'feb': 1, 'mar': 2, 'apr': 3, 'may': 4, 'jun': 5,
               'jul': 6, 'aug': 7, 'sep': 8, 'oct': 9, 'nov': 10, 'dec': 11}
df['Month'] = df['Month'].replace(month_codes)

# One-hot encode the multi-valued categorical columns (marital, contact,
# prev outcome, job). Each column is expanded into '<Column>-<value>'
# indicator columns, added in the sorted order returned by np.unique, and
# the source column is dropped before moving on to the next one.
for categorical in ('Marital', 'Contact', 'Prev Outcome', 'Job'):
    for value in np.unique(df[categorical].values):
        indicator_name = '{}-{}'.format(categorical, value)
        df[indicator_name] = (df[categorical] == value).astype(float)
    df.drop(columns=categorical, inplace=True)

# Split into inputs X and target Y: every column except 'Success' is a feature
column_names = np.array(df.columns)
values = df.values.astype(float)
X_cols = column_names != 'Success'
X = values[:, X_cols]
Y = values[:, ~X_cols]

# Feature names and groups: features that share the text before the first
# '-' (e.g. the indicator columns created from one categorical column) are
# collected into one group of column indices
feature_names = column_names[X_cols]
prefixes = np.array([name.split('-')[0] for name in feature_names])
group_names = list(np.unique(prefixes))
groups = [np.flatnonzero(prefixes == prefix) for prefix in group_names]

# Train/val/test split: hold out 10% for testing, then 10% of what remains
# for validation; both splits use the same options and seed
split_options = dict(test_size=0.1, random_state=123)
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, **split_options)
X_train, X_val, Y_train, Y_val = train_test_split(
    X_train, Y_train, **split_options)

In [3]:
# Standardize continuous columns, fitting the scaler on the training split only
feature_names = list(feature_names)
num_features = len(feature_names)
continuous_cols = ['Age', 'Balance', 'Day', 'Duration', 'Campaign',
                   'Month', 'Prev Days', 'Prev Contacts']
continuous_inds = list(map(feature_names.index, continuous_cols))
ss = StandardScaler().fit(X_train[:, continuous_inds])

# Apply the fitted scaler to each split, overwriting the columns in place
for split in (X_train, X_val, X_test):
    split[:, continuous_inds] = ss.transform(split[:, continuous_inds])

# Set up imputer

In [4]:
import torch
import torch.nn as nn
from fastshap_torch.utils import MarginalImputer

In [5]:
# Select the compute device. GPU 7 is used when the machine has one;
# otherwise fall back to the first GPU, or to the CPU when CUDA is not
# available, so the notebook does not fail on machines with fewer devices.
preferred_gpu = 7
if torch.cuda.is_available():
    gpu_index = preferred_gpu if torch.cuda.device_count() > preferred_gpu else 0
    device = torch.device('cuda', gpu_index)
else:
    device = torch.device('cpu')

# map_location remaps the stored tensors onto `device` while loading, so a
# checkpoint saved from a different GPU can still be deserialized here.
# NOTE(review): torch.load unpickles the file -- only load checkpoints from
# a trusted source.
model = torch.load('../models/bank_model.pt', map_location=device).eval().to(device=device)

In [6]:
# Draw a fixed (seeded) subset of training rows, without replacement, to
# serve as background samples
num_background = 128
np.random.seed(0)
inds = np.random.choice(len(X_train), size=num_background, replace=False)
background = X_train[inds]

# Build the imputer from the model, the background samples and the feature
# groups; the model output is passed through a sigmoid link
# NOTE(review): presumably held-out groups are filled in from `background`
# -- confirm against fastshap_torch.utils.MarginalImputer
link = nn.Sigmoid()
imputer = MarginalImputer(model, background, groups=groups, link=link)

# FastSHAP

In [7]:
import torch
import torch.nn as nn
from fastshap_torch import FastSHAP

# Test number of samples

In [8]:
# Sweep the num_samples training argument with paired sampling turned off.
# Each setting trains a fresh explainer and saves it under its own file name.
hidden = 128
for n_samples in (1, 4, 16, 32, 48, 64, 96):
    # Explainer network: three hidden ReLU layers, one output per feature group
    layers = [
        nn.Linear(num_features, hidden), nn.ReLU(inplace=True),
        nn.Linear(hidden, hidden), nn.ReLU(inplace=True),
        nn.Linear(hidden, hidden), nn.ReLU(inplace=True),
        nn.Linear(hidden, len(groups)),
    ]
    explainer = nn.Sequential(*layers).to(device)

    # Wrap the explainer and imputer for training
    fastshap = FastSHAP(explainer, imputer, normalization='additive')

    # Train on the training split; the first 500 validation rows are
    # passed as validation data
    train_options = dict(
        batch_size=32,
        num_samples=n_samples,
        max_epochs=200,
        eff_lambda=0,
        paired_sampling=False,
        validation_samples=128,
        validation_seed=123,
        verbose=False)
    fastshap.train(X_train, X_val[:500], **train_options)

    # Report the smallest entry of the recorded loss list
    best_val_loss = min(fastshap.loss_list)
    print('Best val loss = {:.8f}'.format(best_val_loss))

    # Move the explainer to the CPU, then save it
    modifier = 'samples={}'.format(n_samples)
    explainer.cpu()
    save_path = '../models/bank_marginal_explainer {} nopenalty.pt'.format(modifier)
    torch.save(explainer, save_path)

Best val loss = 0.00267827
Best val loss = 0.00039676
Best val loss = 0.00004859
Best val loss = 0.00004841
Best val loss = 0.00004826
Best val loss = 0.00004794
Best val loss = 0.00004797


In [9]:
# Sweep the num_samples training argument with paired sampling turned on.
# Each setting trains a fresh explainer and saves it under its own file name.
hidden = 128
for n_samples in (4, 16, 32, 48, 64, 96):
    # Explainer network: three hidden ReLU layers, one output per feature group
    layers = [
        nn.Linear(num_features, hidden), nn.ReLU(inplace=True),
        nn.Linear(hidden, hidden), nn.ReLU(inplace=True),
        nn.Linear(hidden, hidden), nn.ReLU(inplace=True),
        nn.Linear(hidden, len(groups)),
    ]
    explainer = nn.Sequential(*layers).to(device)

    # Wrap the explainer and imputer for training
    fastshap = FastSHAP(explainer, imputer, normalization='additive')

    # Train on the training split; the first 500 validation rows are
    # passed as validation data
    train_options = dict(
        batch_size=32,
        num_samples=n_samples,
        max_epochs=200,
        eff_lambda=0,
        paired_sampling=True,
        validation_samples=128,
        validation_seed=123,
        verbose=False)
    fastshap.train(X_train, X_val[:500], **train_options)

    # Report the smallest entry of the recorded loss list
    best_val_loss = min(fastshap.loss_list)
    print('Best val loss = {:.8f}'.format(best_val_loss))

    # Move the explainer to the CPU, then save it
    modifier = 'paired_samples={}'.format(n_samples)
    explainer.cpu()
    save_path = '../models/bank_marginal_explainer {} nopenalty.pt'.format(modifier)
    torch.save(explainer, save_path)

Best val loss = 0.00039314
Best val loss = 0.00004798
Best val loss = 0.00004804
Best val loss = 0.00004823
Best val loss = 0.00004780
Best val loss = 0.00004768


# Test other parameters

In [10]:
# Configuration for this run: additive normalization, eff_lambda set to 0
hidden = 128

# Explainer network: three hidden ReLU layers, one output per feature group
layers = [
    nn.Linear(num_features, hidden), nn.ReLU(inplace=True),
    nn.Linear(hidden, hidden), nn.ReLU(inplace=True),
    nn.Linear(hidden, hidden), nn.ReLU(inplace=True),
    nn.Linear(hidden, len(groups)),
]
explainer = nn.Sequential(*layers).to(device)

# Wrap the explainer and imputer for training
fastshap = FastSHAP(explainer, imputer, normalization='additive')

# Train on the training split; the first 500 validation rows are passed as
# validation data
train_options = dict(
    batch_size=32,
    num_samples=32,
    max_epochs=200,
    eff_lambda=0,
    paired_sampling=True,
    validation_samples=128,
    validation_seed=123,
    verbose=False)
fastshap.train(X_train, X_val[:500], **train_options)

# Report the smallest entry of the recorded loss list
best_val_loss = min(fastshap.loss_list)
print('Best val loss = {:.8f}'.format(best_val_loss))

# Move the explainer to the CPU, then save it
modifier = 'nopenalty'
explainer.cpu()
save_path = '../models/bank_marginal_explainer {}.pt'.format(modifier)
torch.save(explainer, save_path)

Best val loss = 0.00004812


In [11]:
# Configuration for this run: no normalization, eff_lambda set to 0.1
hidden = 128

# Explainer network: three hidden ReLU layers, one output per feature group
layers = [
    nn.Linear(num_features, hidden), nn.ReLU(inplace=True),
    nn.Linear(hidden, hidden), nn.ReLU(inplace=True),
    nn.Linear(hidden, hidden), nn.ReLU(inplace=True),
    nn.Linear(hidden, len(groups)),
]
explainer = nn.Sequential(*layers).to(device)

# Wrap the explainer and imputer for training
fastshap = FastSHAP(explainer, imputer, normalization=None)

# Train on the training split; the first 500 validation rows are passed as
# validation data
train_options = dict(
    batch_size=32,
    num_samples=32,
    max_epochs=200,
    eff_lambda=0.1,
    paired_sampling=True,
    validation_samples=128,
    validation_seed=123,
    verbose=False)
fastshap.train(X_train, X_val[:500], **train_options)

# Report the smallest entry of the recorded loss list
best_val_loss = min(fastshap.loss_list)
print('Best val loss = {:.8f}'.format(best_val_loss))

# Move the explainer to the CPU, then save it
modifier = 'nonormalization'
explainer.cpu()
save_path = '../models/bank_marginal_explainer {}.pt'.format(modifier)
torch.save(explainer, save_path)

Best val loss = 0.00003690


In [12]:
# Configuration for this run: no normalization, eff_lambda set to 0
hidden = 128

# Explainer network: three hidden ReLU layers, one output per feature group
layers = [
    nn.Linear(num_features, hidden), nn.ReLU(inplace=True),
    nn.Linear(hidden, hidden), nn.ReLU(inplace=True),
    nn.Linear(hidden, hidden), nn.ReLU(inplace=True),
    nn.Linear(hidden, len(groups)),
]
explainer = nn.Sequential(*layers).to(device)

# Wrap the explainer and imputer for training
fastshap = FastSHAP(explainer, imputer, normalization=None)

# Train on the training split; the first 500 validation rows are passed as
# validation data
train_options = dict(
    batch_size=32,
    num_samples=32,
    max_epochs=200,
    eff_lambda=0,
    paired_sampling=True,
    validation_samples=128,
    validation_seed=123,
    verbose=False)
fastshap.train(X_train, X_val[:500], **train_options)

# Report the smallest entry of the recorded loss list
best_val_loss = min(fastshap.loss_list)
print('Best val loss = {:.8f}'.format(best_val_loss))

# Move the explainer to the CPU, then save it
modifier = 'nopenalty nonormalization'
explainer.cpu()
save_path = '../models/bank_marginal_explainer {}.pt'.format(modifier)
torch.save(explainer, save_path)

Best val loss = 0.00003619
