# Load data

In [None]:
import shap  # https://github.com/slundberg/shap
import shapreg  # https://github.com/iancovert/shapley-regression
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

In [None]:
# Load and split data: 80/20 train/test, then the training portion is split
# 80/20 again into train/validation.
X_train, X_test, Y_train, Y_test = train_test_split(
    *shap.datasets.adult(), test_size=0.2, random_state=42
)
X_train, X_val, Y_train, Y_val = train_test_split(
    X_train, Y_train, test_size=0.2, random_state=0
)

# Data scaling: standardize every split with statistics from the training
# split only. Fit on the raw ndarray (not the DataFrame) so that fit and
# transform see the same input type; fitting on the DataFrame and then
# transforming `.values` makes scikit-learn warn that X has no valid feature
# names.
num_features = X_train.shape[1]
feature_names = X_train.columns.tolist()
ss = StandardScaler()
ss.fit(X_train.values)
X_train = ss.transform(X_train.values)
X_val = ss.transform(X_val.values)
X_test = ss.transform(X_test.values)

# Train model

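A LightGBM gradient boosting model is trained on the Adult (census income) dataset. If a model was saved by a previous run (`census_model.pkl`), it is loaded instead.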

In [None]:
import pickle
import os.path
import lightgbm as lgb

In [None]:
# Train a LightGBM booster with early stopping on the validation split and
# cache it, or reuse the cached booster from a previous run.
if not os.path.isfile("census_model.pkl"):
    # Booster configuration: binary objective evaluated with log-loss.
    params = dict(
        max_bin=512,
        learning_rate=0.05,
        boosting_type="gbdt",
        objective="binary",
        metric="binary_logloss",
        num_leaves=10,
        verbose=-1,
        min_data=100,
        boost_from_average=True,
    )

    # Wrap the scaled splits as LightGBM datasets.
    d_train = lgb.Dataset(X_train, label=Y_train)
    d_val = lgb.Dataset(X_val, label=Y_val)

    # Up to 10000 boosting rounds; stop after 50 rounds without validation
    # improvement and log the metric every 1000 rounds.
    callbacks = [lgb.early_stopping(stopping_rounds=50), lgb.log_evaluation(1000)]
    model = lgb.train(
        params,
        d_train,
        num_boost_round=10000,
        valid_sets=[d_val],
        callbacks=callbacks,
    )

    # Cache the trained booster for later runs.
    with open("census_model.pkl", "wb") as handle:
        pickle.dump(model, handle)
else:
    print("Loading saved model")
    # NOTE: unpickling can execute arbitrary code; only load a cache file that
    # this notebook wrote itself.
    with open("census_model.pkl", "rb") as handle:
        model = pickle.load(handle)

# Surrogate

In [None]:
import torch
import torch.optim as optim
import numpy as np
from torch.utils.data import Dataset, TensorDataset, DataLoader
from torch.utils.data import RandomSampler, BatchSampler
from fastshap.utils import UniformSampler, DatasetRepeat
from copy import deepcopy
from tqdm.auto import tqdm


def validate(surrogate, loss_fn, data_loader):
    """
    Compute the per-example mean loss of a surrogate over a validation loader.

    Each batch's loss is weighted by the batch size, so a short final batch
    does not skew the average.

    Args:
      surrogate: wrapper called as ``surrogate(x, S)`` whose underlying torch
        module is available as ``surrogate.surrogate``.
      loss_fn: loss function, called as ``loss_fn(pred, y)``.
      data_loader: iterable yielding ``(x, y, S)`` batches.

    Returns:
      The running mean of ``loss_fn`` outputs (0 when the loader is empty).
    """
    target = next(surrogate.surrogate.parameters()).device
    seen = 0
    running = 0
    with torch.no_grad():
        for inputs, labels, subsets in data_loader:
            batch_loss = loss_fn(
                surrogate(inputs.to(target), subsets.to(target)),
                labels.to(target),
            )
            count = len(inputs)
            seen += count
            # Incremental weighted mean: m_new = m + n * (l - m) / N.
            running = running + count * (batch_loss - running) / seen
    return running


def generate_labels(dataset, model, batch_size):
    """
    Run a predictive model over a dataset and collect its outputs.

    Args:
      dataset: dataset whose items are 1-tuples ``(x,)``.
      model: predictive model. A ``torch.nn.Module`` receives inputs on the
        device of its first parameter; any other callable receives CPU inputs.
      batch_size: minibatch size.

    Returns:
      The model outputs for every batch, moved to CPU and concatenated along
      the first dimension.
    """
    target = (
        next(model.parameters()).device
        if isinstance(model, torch.nn.Module)
        else torch.device("cpu")
    )
    with torch.no_grad():
        outputs = [
            model(inputs.to(target)).cpu()
            for (inputs,) in DataLoader(dataset, batch_size=batch_size)
        ]
    return torch.cat(outputs)


class Surrogate:
    """
    Wrapper around a surrogate model that accepts masked inputs.

    The wrapped ``torch.nn.Module`` is called as ``surrogate((x, S))``, where
    ``S`` is a coalition mask over players. Players are individual features,
    or feature groups when ``groups`` is given; group masks are expanded to
    feature masks in ``__call__``.

    Args:
      surrogate: surrogate model (torch.nn.Module taking an ``(x, S)`` tuple).
      num_features: number of features.
      groups: (optional) feature groups, represented by a list of lists. Every
        index in ``range(num_features)`` must appear in exactly one group.
    """

    def __init__(self, surrogate, num_features, groups=None):
        # Store surrogate model.
        self.surrogate = surrogate

        # Store feature groups.
        if groups is None:
            self.num_players = num_features
            self.groups_matrix = None
        else:
            # Verify groups cover each feature exactly once. np.array_equal
            # also handles a length mismatch, which elementwise == does not
            # handle reliably across NumPy versions.
            inds_list = []
            for group in groups:
                inds_list += list(group)
            assert np.array_equal(np.sort(inds_list), np.arange(num_features))

            # Map groups to features: row i is the 0/1 indicator of group i.
            self.num_players = len(groups)
            device = next(surrogate.parameters()).device
            self.groups_matrix = torch.zeros(
                len(groups), num_features, dtype=torch.float32, device=device
            )
            for i, group in enumerate(groups):
                self.groups_matrix[i, group] = 1

    @staticmethod
    def _as_float_tensor(data):
        """Convert an np.ndarray to a float32 tensor; return anything else unchanged."""
        if isinstance(data, np.ndarray):
            return torch.tensor(data, dtype=torch.float32)
        return data

    @staticmethod
    def _repeat_rows(tensor, times):
        """Tile ``tensor`` ``times`` times along dim 0, leaving other dims intact."""
        return tensor.repeat(times, *([1] * (tensor.dim() - 1)))

    @staticmethod
    def _sample_subsets(sampler, num_samples, seed):
        """Optionally seed torch's RNG, then draw ``num_samples`` subsets."""
        if seed is not None:
            torch.manual_seed(seed)
        return sampler.sample(num_samples)

    @staticmethod
    def _make_train_loader(train_set, batch_size):
        """Loader that samples with replacement in full batches (~one pass per epoch)."""
        num_samples = int(np.ceil(len(train_set) / batch_size)) * batch_size
        random_sampler = RandomSampler(
            train_set, replacement=True, num_samples=num_samples
        )
        batch_sampler = BatchSampler(
            random_sampler, batch_size=batch_size, drop_last=True
        )
        return DataLoader(train_set, batch_sampler=batch_sampler)

    def train(
        self,
        train_data,
        val_data,
        batch_size,
        max_epochs,
        loss_fn,
        validation_samples=1,
        validation_batch_size=None,
        lr=1e-3,
        min_lr=1e-5,
        lr_factor=0.5,
        lookback=5,
        training_seed=None,
        validation_seed=None,
        bar=False,
        verbose=False,
    ):
        """
        Train surrogate model.

        Args:
          train_data: training data with inputs and the original model's
            predictions (np.ndarray tuple, torch.Tensor tuple,
            torch.utils.data.Dataset).
          val_data: validation data with inputs and the original model's
            predictions (np.ndarray tuple, torch.Tensor tuple,
            torch.utils.data.Dataset).
          batch_size: minibatch size.
          max_epochs: maximum training epochs.
          loss_fn: loss function (e.g., fastshap.KLDivLoss).
          validation_samples: number of samples per validation example.
          validation_batch_size: validation minibatch size (defaults to
            batch_size).
          lr: initial learning rate (Adam).
          min_lr: minimum learning rate.
          lr_factor: learning rate decrease factor.
          lookback: lookback window for early stopping.
          training_seed: random seed for training.
          validation_seed: random seed for generating validation data.
          bar: whether to show a tqdm progress bar for each epoch.
          verbose: verbosity.

        Raises:
          ValueError: if train_data or val_data has an unsupported type.
        """
        # Set up train dataset.
        if isinstance(train_data, tuple):
            x_train, y_train = train_data
            train_set = TensorDataset(
                self._as_float_tensor(x_train), self._as_float_tensor(y_train)
            )
        elif isinstance(train_data, Dataset):
            train_set = train_data
        else:
            raise ValueError(
                "train_data must be either tuple of tensors or a " "PyTorch Dataset"
            )
        train_loader = self._make_train_loader(train_set, batch_size)

        # Set up validation dataset: every example is repeated
        # validation_samples times, each copy paired with its own subset.
        sampler = UniformSampler(self.num_players)
        if isinstance(val_data, tuple):
            x_val, y_val = val_data
            x_val = self._as_float_tensor(x_val)
            y_val = self._as_float_tensor(y_val)
            # Count examples with len(x_val): len(val_data) is the tuple's
            # length (2), which would give S_val the wrong number of rows.
            S_val = self._sample_subsets(
                sampler, len(x_val) * validation_samples, validation_seed
            )
            val_set = TensorDataset(
                self._repeat_rows(x_val, validation_samples),
                self._repeat_rows(y_val, validation_samples),
                S_val,
            )
        elif isinstance(val_data, Dataset):
            S_val = self._sample_subsets(
                sampler, len(val_data) * validation_samples, validation_seed
            )
            val_set = DatasetRepeat([val_data, TensorDataset(S_val)])
        else:
            raise ValueError(
                "val_data must be either tuple of tensors or a " "PyTorch Dataset"
            )

        if validation_batch_size is None:
            validation_batch_size = batch_size
        val_loader = DataLoader(val_set, batch_size=validation_batch_size)

        def prepare_batch(batch, device):
            # Targets come straight from the training dataset.
            x, y = batch
            return x.to(device), y.to(device)

        optimizer = optim.Adam(self.surrogate.parameters(), lr=lr)
        self._fit(
            train_loader,
            val_loader,
            prepare_batch,
            sampler,
            loss_fn,
            optimizer,
            max_epochs=max_epochs,
            min_lr=min_lr,
            lr_factor=lr_factor,
            lookback=lookback,
            training_seed=training_seed,
            bar=bar,
            verbose=verbose,
        )

    def train_original_model(
        self,
        train_data,
        val_data,
        original_model,
        batch_size,
        max_epochs,
        loss_fn,
        validation_samples=1,
        validation_batch_size=None,
        lr=1e-3,
        min_lr=1e-5,
        lr_factor=0.5,
        lookback=5,
        training_seed=None,
        validation_seed=None,
        bar=False,
        verbose=False,
        optimizer=None,
        train_loader=None,
        random_sampler=None,
        batch_sampler=None,
    ):
        """
        Train surrogate model with labels provided by the original model.

        Args:
          train_data: training data with inputs only (np.ndarray, torch.Tensor,
            torch.utils.data.Dataset). Only used to build a loader when
            ``train_loader`` is None.
          val_data: validation data with inputs only (np.ndarray, torch.Tensor,
            torch.utils.data.Dataset).
          original_model: original predictive model (e.g., torch.nn.Module or
            a callable mapping a tensor batch to predictions).
          batch_size: minibatch size for the default training loader and the
            default validation batch size.
          max_epochs: maximum training epochs.
          loss_fn: loss function (e.g., fastshap.KLDivLoss).
          validation_samples: number of samples per validation example.
          validation_batch_size: validation minibatch size.
          lr: learning rate for the default Adam optimizer (ignored when
            ``optimizer`` is given).
          min_lr: minimum learning rate.
          lr_factor: learning rate decrease factor.
          lookback: lookback window for early stopping.
          training_seed: random seed for training.
          validation_seed: random seed for generating validation data.
          bar: whether to show a tqdm progress bar for each epoch.
          verbose: verbosity.
          optimizer: optimizer over the surrogate's parameters (e.g. an Opacus
            DP optimizer). If None, Adam with learning rate ``lr`` is used.
          train_loader: optional pre-built loader yielding ``(x,)`` batches
            (e.g. an Opacus private loader); used as-is when given.
          random_sampler, batch_sampler: ignored; accepted so callers that also
            pass the sampler objects returned by ``setup_data`` keep working.

        Raises:
          ValueError: if train_data or val_data has an unsupported type.
        """
        # Set up train data loader unless the caller supplied one (a DP
        # loader must be used as-is for the privacy accounting to hold).
        if train_loader is None:
            train_data = self._as_float_tensor(train_data)
            if isinstance(train_data, torch.Tensor):
                train_set = TensorDataset(train_data)
            elif isinstance(train_data, Dataset):
                train_set = train_data
            else:
                raise ValueError(
                    "train_data must be either np.ndarray, tensor or a "
                    "PyTorch Dataset"
                )
            train_loader = self._make_train_loader(train_set, batch_size)

        if optimizer is None:
            optimizer = optim.Adam(self.surrogate.parameters(), lr=lr)

        # Set up validation dataset.
        if validation_batch_size is None:
            validation_batch_size = batch_size
        val_data = self._as_float_tensor(val_data)
        if not isinstance(val_data, (torch.Tensor, Dataset)):
            raise ValueError(
                "val_data must be either np.ndarray, tensor or a PyTorch Dataset"
            )
        sampler = UniformSampler(self.num_players)
        S_val = self._sample_subsets(
            sampler, len(val_data) * validation_samples, validation_seed
        )

        # Label each validation example once with the original model, then
        # repeat the labels to line up with the repeated examples.
        if isinstance(val_data, torch.Tensor):
            y_val = generate_labels(
                TensorDataset(val_data), original_model, validation_batch_size
            )
            val_set = TensorDataset(
                self._repeat_rows(val_data, validation_samples),
                self._repeat_rows(y_val, validation_samples),
                S_val,
            )
        else:
            y_val = generate_labels(val_data, original_model, validation_batch_size)
            val_set = DatasetRepeat(
                [
                    val_data,
                    TensorDataset(self._repeat_rows(y_val, validation_samples), S_val),
                ]
            )
        val_loader = DataLoader(val_set, batch_size=validation_batch_size)

        def prepare_batch(batch, device):
            # Targets are the original model's predictions on the batch.
            (x,) = batch
            x = x.to(device)
            with torch.no_grad():
                y = original_model(x)
            return x, y

        self._fit(
            train_loader,
            val_loader,
            prepare_batch,
            sampler,
            loss_fn,
            optimizer,
            max_epochs=max_epochs,
            min_lr=min_lr,
            lr_factor=lr_factor,
            lookback=lookback,
            training_seed=training_seed,
            bar=bar,
            verbose=verbose,
        )

    def _fit(
        self,
        train_loader,
        val_loader,
        prepare_batch,
        sampler,
        loss_fn,
        optimizer,
        max_epochs,
        min_lr,
        lr_factor,
        lookback,
        training_seed,
        bar,
        verbose,
    ):
        """
        Shared optimization loop with LR scheduling and early stopping.

        Trains for up to ``max_epochs`` epochs and evaluates the validation
        loss after each one. The learning rate is reduced on plateaus, and
        training stops once ``lookback`` epochs pass without improvement.
        The parameters with the lowest validation loss are restored at the end.

        Args:
          train_loader: iterable of training batches.
          val_loader: validation loader yielding ``(x, y, S)`` batches.
          prepare_batch: callable ``(batch, device) -> (x, y)`` that moves a
            training batch to ``device`` and provides its targets.
          sampler: subset sampler exposing ``sample(n)``.
          loss_fn, optimizer, max_epochs, min_lr, lr_factor, lookback,
          training_seed, bar, verbose: as documented on ``train``.
        """
        surrogate = self.surrogate
        device = next(surrogate.parameters()).device
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            factor=lr_factor,
            patience=lookback // 2,
            min_lr=min_lr,
            verbose=verbose,
        )
        best_loss = validate(self, loss_fn, val_loader).item()
        best_epoch = 0
        best_model = deepcopy(surrogate)
        loss_list = [best_loss]
        if training_seed is not None:
            torch.manual_seed(training_seed)

        surrogate.train()
        for epoch in range(max_epochs):
            batch_iter = (
                tqdm(train_loader, desc="Training epoch") if bar else train_loader
            )

            for batch in batch_iter:
                x, y = prepare_batch(batch, device)

                # Size subsets from the actual batch: private loaders (Opacus
                # Poisson sampling) yield batches of varying size.
                S = sampler.sample(len(x)).to(device=device)

                pred = self(x, S)
                loss = loss_fn(pred, y)

                # Clear gradients through the optimizer: a DP optimizer also
                # resets its per-sample/accumulated gradient state here.
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # Evaluate validation loss.
            surrogate.eval()
            val_loss = validate(self, loss_fn, val_loader).item()
            surrogate.train()

            # Print progress.
            if verbose:
                print("----- Epoch = {} -----".format(epoch + 1))
                print("Val loss = {:.4f}".format(val_loss))
                print("")
            scheduler.step(val_loss)
            loss_list.append(val_loss)

            # Check if best model.
            if val_loss < best_loss:
                best_loss = val_loss
                best_model = deepcopy(surrogate)
                best_epoch = epoch
                if verbose:
                    print("New best epoch, loss = {:.4f}".format(val_loss))
                    print("")
            elif epoch - best_epoch == lookback:
                if verbose:
                    print("Stopping early")
                break

        # Clean up: restore the best parameters in place.
        for param, best_param in zip(surrogate.parameters(), best_model.parameters()):
            param.data = best_param.data
        self.loss_list = loss_list
        surrogate.eval()

    def __call__(self, x, S):
        """
        Evaluate surrogate model.

        Args:
          x: input examples.
          S: coalitions (one mask row per example, over players). With feature
            groups, group masks are expanded to feature masks first.
        """
        if self.groups_matrix is not None:
            S = torch.mm(S, self.groups_matrix)

        return self.surrogate((x, S))

In [None]:
def setup_data(train_data, batch_size):
    """
    Build a training loader that samples with replacement in full batches.

    The sampler draws enough examples (rounded up to a whole number of
    batches) for roughly one pass over the data per epoch.

    Args:
      train_data: inputs only (np.ndarray, torch.Tensor or PyTorch Dataset);
        arrays are converted to float32 tensors.
      batch_size: minibatch size.

    Returns:
      ``(train_loader, random_sampler, batch_sampler)``.

    Raises:
      ValueError: if train_data has an unsupported type.
    """
    data = (
        torch.tensor(train_data, dtype=torch.float32)
        if isinstance(train_data, np.ndarray)
        else train_data
    )

    if isinstance(data, torch.Tensor):
        dataset = TensorDataset(data)
    elif isinstance(data, Dataset):
        dataset = data
    else:
        raise ValueError("train_data must be either tensor or a PyTorch Dataset")

    total_draws = (len(dataset) + batch_size - 1) // batch_size * batch_size
    sampler = RandomSampler(dataset, replacement=True, num_samples=total_draws)
    batches = BatchSampler(sampler, batch_size=batch_size, drop_last=True)
    return DataLoader(dataset, batch_sampler=batches), sampler, batches

# Train surrogate

In [None]:
import torch
import torch.nn as nn
from fastshap.utils import MaskLayer1d
from fastshap import KLDivLoss

In [None]:
from opacus.validators import ModuleValidator
from opacus import PrivacyEngine
from opacus.utils.batch_memory_manager import BatchMemoryManager

In [None]:
# Select device: use the GPU when CUDA is available, otherwise fall back to
# the CPU so the notebook still runs on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Set up original model: wrap the LightGBM booster so it maps a torch batch
# to two-column class probabilities [1 - p, p] on the input's device, where p
# is the booster's (binary-objective) prediction.
def original_model(x):
    pred = model.predict(x.cpu().numpy())
    pred = np.stack([1 - pred, pred]).T
    return torch.tensor(pred, dtype=torch.float32, device=x.device)


# Check for model
if os.path.isfile("census_surrogate.pt"):
    # Reuse the saved surrogate without retraining, mirroring how the saved
    # LightGBM model and explainer are reused.
    print("Loading saved surrogate model")
    # weights_only=False is needed to unpickle a full nn.Module on recent
    # PyTorch versions; only load a file this notebook produced.
    surr = torch.load("census_surrogate.pt", weights_only=False).to(device)
    surrogate = Surrogate(surr, num_features)

else:
    batch_size = 64
    train_loader, random_sampler, batch_sampler = setup_data(
        train_data=X_train, batch_size=batch_size
    )
    print("Setup data")

    # Create surrogate model: MaskLayer1d(append=True) doubles the input
    # width (features plus mask), hence 2 * num_features inputs.
    surr = nn.Sequential(
        MaskLayer1d(value=0, append=True),
        nn.Linear(2 * num_features, 128),
        # nn.ELU(inplace=True),
        # nn.Linear(128, 128),
        # nn.ELU(inplace=True),
        nn.Linear(128, 2),
    )
    surr = ModuleValidator.fix(surr)
    ModuleValidator.validate(surr, strict=False)
    print("Fixing model")

    # Differential-privacy hyperparameters.
    MAX_GRAD_NORM = 1.2
    EPSILON = 50.0
    DELTA = 1e-5
    lr = 1e-3
    EPOCHS = 100

    privacy_engine = PrivacyEngine()

    optimizer = optim.Adam(surr.parameters(), lr=lr)
    print("Created optimizer")

    # Wrap model, optimizer and loader so training satisfies the
    # (EPSILON, DELTA) target over EPOCHS epochs.
    surr, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
        module=surr,
        optimizer=optimizer,
        data_loader=train_loader,
        epochs=EPOCHS,
        target_epsilon=EPSILON,
        target_delta=DELTA,
        max_grad_norm=MAX_GRAD_NORM,
    )

    print("Made model private")

    surr = surr.to(device)

    # Set up surrogate object: we pass the model that we have defined as
    # surrogate and the number of input features of the training dataset
    # we used to train the black box.
    surrogate = Surrogate(surr, num_features)
    print("Created surrogate")

    # Train
    # What happens inside train_original_model:
    # - A UniformSampler is created to sample feature subsets (coalitions).
    # - For validation, each example is repeated validation_samples times and
    #   the sampler draws a matrix S_val with
    #   len(validation_set) * validation_samples rows, one 0/1 mask per row
    #   selecting which features the surrogate gets to see.
    # - The original model labels each validation example once; the labels
    #   are repeated to line up with the repeated examples.
    # - The validation set therefore has the structure
    #   [repeated_val_data, repeated_predictions, masks (S_val)].
    # - The optimizer is the privacy-wrapped Adam optimizer created above and
    #   passed in explicitly.
    # - The training loop then:
    #    - iterates over batches of the private training loader
    #    - computes the original model's prediction for each batch
    #    - computes the surrogate's prediction on the batch masked by a freshly
    #      sampled subset matrix S
    #    - computes the loss between the surrogate and original predictions
    #    - backpropagates and updates the surrogate
    #    - after each epoch, evaluates the surrogate on the validation set,
    #      steps the learning-rate scheduler and checks early stopping

    print("Starting training original model")
    surrogate.train_original_model(
        X_train,  # We pass the training dataset of the black box to the surrogate object
        X_val,  # We pass the validation dataset of the black box to the surrogate object
        original_model,  # black box we want to explain
        batch_size=batch_size,
        max_epochs=EPOCHS,
        loss_fn=KLDivLoss(),
        validation_samples=10,  # this number is multiplied with the length of the validation dataset
        validation_batch_size=10000,  # size of the mini batch
        verbose=True,
        lr=lr,
        optimizer=optimizer,
        train_loader=train_loader,
        random_sampler=random_sampler,
        batch_sampler=batch_sampler,
        bar=True,
    )

    # Save surrogate (currently disabled, so the load branch above only runs
    # if census_surrogate.pt was created some other way).
    # surr.cpu()
    # torch.save(surr, 'census_surrogate.pt')
    # surr.to(device)

# Train FastSHAP

In [None]:
from fastshap import FastSHAP

In [None]:
# Check for model
if os.path.isfile("census_explainer.pt"):
    print("Loading saved explainer model")
    # The file holds a pickled nn.Module saved by the branch below;
    # weights_only=False is needed to load a full module on recent PyTorch
    # versions. Only load a file this notebook produced.
    explainer = torch.load("census_explainer.pt", weights_only=False).to(device)
    fastshap = FastSHAP(
        explainer, surrogate, normalization="additive", link=nn.Softmax(dim=-1)
    )

else:
    # Create explainer model: outputs 2 * num_features values, presumably one
    # attribution per feature for each of the surrogate's 2 outputs.
    explainer = nn.Sequential(
        nn.Linear(num_features, 128),
        nn.ReLU(inplace=True),
        nn.Linear(128, 128),
        nn.ReLU(inplace=True),
        nn.Linear(128, 2 * num_features),
    ).to(device)

    # Set up FastSHAP object
    fastshap = FastSHAP(
        explainer, surrogate, normalization="additive", link=nn.Softmax(dim=-1)
    )

    # Train (validating on the first 100 validation rows)
    fastshap.train(
        X_train,
        X_val[:100],
        batch_size=32,
        num_samples=32,
        max_epochs=200,
        validation_samples=128,
        verbose=True,
    )

    # Save explainer (moved to CPU so the file loads without a GPU)
    explainer.cpu()
    torch.save(explainer, "census_explainer.pt")
    explainer.to(device)

# Compare with KernelSHAP

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Setup for KernelSHAP
def imputer(x, S):
    """
    Evaluate the surrogate on inputs and coalition masks for shapreg.

    Used as the value function of ``shapreg.games.PredictionGame``.

    Args:
      x: array of input examples (converted to float32 on ``device``).
      S: array of coalition masks, one row per example.

    Returns:
      np.ndarray of softmax class probabilities from the surrogate.
    """
    # Inference only: no_grad avoids building an autograd graph on each of
    # the many KernelSHAP evaluations.
    with torch.no_grad():
        x = torch.tensor(x, dtype=torch.float32, device=device)
        S = torch.tensor(S, dtype=torch.float32, device=device)
        pred = surrogate(x, S).softmax(dim=-1)
    return pred.cpu().numpy()

In [None]:
# Select example (np.random is not seeded in this cell, so the chosen row can
# change between runs)
ind = np.random.choice(len(X_test))
x = X_test[ind : ind + 1]  # slice keeps a leading batch dimension of size 1
y = int(Y_test[ind])  # true label; later used to pick the class column to plot

# Run FastSHAP: [0] selects the values for the single example
fastshap_values = fastshap.shap_values(x)[0]

# Run KernelSHAP to convergence on the surrogate-backed game for this
# example; return_all=True also returns intermediate estimates (used by the
# comparison plot via all_results["values"] and all_results["iters"]).
game = shapreg.games.PredictionGame(imputer, x)
shap_values, all_results = shapreg.shapley.ShapleyRegression(
    game,
    batch_size=32,
    paired_sampling=False,
    detect_convergence=True,
    bar=True,
    return_all=True,
)

In [None]:
# Create figure
plt.figure(figsize=(9, 5.5))

# Bar chart: for each feature, three side-by-side bars for class `y`:
# converged KernelSHAP ("True SHAP values"), FastSHAP, and KernelSHAP after a
# limited sample budget.
width = 0.75
kernelshap_iters = 128
positions = np.arange(num_features)

# all_results["iters"] lists the sample counts at which intermediate
# KernelSHAP estimates were recorded. If no estimate was recorded at exactly
# `kernelshap_iters` (e.g. the run converged earlier), .index() would raise
# ValueError, so use the closest recorded count (identical when an exact
# match exists).
recorded_iters = np.asarray(list(all_results["iters"]))
snapshot = int(np.argmin(np.abs(recorded_iters - kernelshap_iters)))
snapshot_iters = int(recorded_iters[snapshot])

plt.bar(
    positions - width / 3,
    shap_values.values[:, y],
    width / 3,
    label="True SHAP values",
    color="tab:gray",
)
plt.bar(
    positions,
    fastshap_values[:, y],
    width / 3,
    label="FastSHAP",
    color="tab:green",
)
plt.bar(
    positions + width / 3,
    all_results["values"][snapshot][:, y],
    width / 3,
    label="KernelSHAP @ {}".format(snapshot_iters),
    color="tab:red",
)

# Annotations
plt.legend(fontsize=16)
plt.tick_params(labelsize=14)
plt.ylabel("SHAP Values", fontsize=16)
plt.title("Census Explanation Example", fontsize=18)
plt.xticks(
    positions,
    feature_names,
    rotation=35,
    rotation_mode="anchor",
    ha="right",
)

plt.tight_layout()
plt.show()