Copyright 2019 The Google Research Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

     http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

# Understanding Black-box Model Predictions using RL-LIM

 * Jinsung Yoon, Sercan O Arik, Tomas Pfister, "RL-LIM: Reinforcement Learning-based Locally Interpretable Modeling", arXiv preprint arXiv:1909.12367 (2019) - https://arxiv.org/abs/1909.12367
 
This notebook describes how to explain black-box models using "Reinforcement Learning based Locally Interpretable Modeling (RL-LIM)". 

RL-LIM is a state-of-the-art locally interpretable modeling method. It is often challenging to develop a globally interpretable model that performs at the level of 'black-box' models. To go beyond this performance limitation, a promising direction is locally interpretable modeling, which explains a single prediction instead of the entire model. Methodologically, a globally interpretable model fits a single inherently interpretable model (such as a linear model or a shallow decision tree) to the entire training set, whereas a locally interpretable model fits an inherently interpretable model locally, i.e. for each instance individually, by distilling knowledge from a high-performance black-box model.

Such locally interpretable models are very useful for real-world AI deployments to provide succinct and human-like explanations to users. They can be used to identify systematic failure cases (e.g. by seeking common trends in input dependence for failure cases), detect biases (e.g. by quantifying feature importance for a particular variable), and provide actionable feedback to improve a model (e.g. understand failure cases and what training data to collect).

You need:

**Training / Probe / Testing sets** 
 * If you don't have a probe set, you can construct one by holding out a small portion of the training set and keeping the rest for training.
 * The training / probe / testing datasets should be saved under the './repo/data_files/' directory as 'train.csv', 'probe.csv', and 'test.csv'.
 * In this notebook, we create the 'train.csv', 'probe.csv', and 'test.csv' files from the Facebook Comment Volume dataset (https://archive.ics.uci.edu/ml/datasets/Facebook+Comment+Volume+Dataset) as an example.
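
If no probe set is available, the hold-out split described above can be sketched as follows. This is a minimal illustration and not the logic of the repository's `load_facebook_data`; the helper name `split_train_probe`, the toy columns, and the 10% rate are assumptions made for the example.

```python
import numpy as np
import pandas as pd


def split_train_probe(df, probe_rate=0.1, seed=0):
    """Randomly splits a DataFrame into (train, probe) parts (hypothetical helper)."""
    rng = np.random.RandomState(seed)
    idx = rng.permutation(len(df))
    n_probe = int(round(probe_rate * len(df)))
    probe_df = df.iloc[idx[:n_probe]].reset_index(drop=True)
    train_df = df.iloc[idx[n_probe:]].reset_index(drop=True)
    return train_df, probe_df


# Toy table standing in for a real training set
full_train = pd.DataFrame({
    'x1': np.arange(100),
    'x2': np.arange(100) * 2.0,
    'y': np.arange(100) % 7,
})

train_df, probe_df = split_train_probe(full_train, probe_rate=0.1, seed=0)
print(len(train_df), len(probe_df))
```

Each part could then be written with `to_csv(..., index=False)` as 'train.csv' and 'probe.csv'. The column layout that `preprocess_data` expects is not specified in this notebook, so check it before substituting your own files.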

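## Prerequisites

 * Install the lightgbm package.
 * Install the GitPython package (it provides the git module used to clone the repository below).
 * Clone https://github.com/google-research/google-research.git into the './repo' directory under the current directory.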

In [1]:
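# Installs additional packages
import importlib
import subprocess
import sys

import IPython

def import_or_install(package):
    """Imports the package; if it is missing, installs it and restarts the kernel."""
    try:
        importlib.import_module(package)
    except ImportError:
        # pip.main() is not available in pip 10 and later, so pip is run as a subprocess
        subprocess.check_call([sys.executable, '-m', 'pip', 'install', package])
        # Restarts the kernel so that the newly installed package can be imported
        app = IPython.Application.instance()
        app.kernel.do_shutdown(True)

import_or_install('lightgbm')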

In [2]:
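import os
from git import Repo  # Provided by the GitPython package

# Directory (under the current working directory) that holds the cloned repository
repo_dir = os.path.join(os.getcwd(), 'repo')

# Creates the directory if it does not exist yet
os.makedirs(repo_dir, exist_ok=True)

# Clones the GitHub repository only if the directory is still empty
if not os.listdir(repo_dir):
    git_url = "https://github.com/google-research/google-research.git"
    Repo.clone_from(git_url, repo_dir)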

## Necessary packages and function calls

 * ridge: Ridge regression, used as the interpretable model.
 * lightgbm: LightGBM, used as the black-box model.
 * load_facebook_data: Data loader for the Facebook Comment Volume dataset.
 * preprocess_data: Data extraction and normalization.
 * rllim: RL-LIM class for training the instance-wise weight estimator.
 * rllim_metrics: Evaluation metrics for the locally interpretable models (overall performance and fidelity).

In [3]:
import numpy as np
import pandas as pd

from sklearn.linear_model import Ridge
import lightgbm

# Sets current directory
os.chdir(repo_dir)

from rllim.data_loading import load_facebook_data, preprocess_data
from rllim import rllim
from rllim.rllim_metrics import fidelity_metrics, overall_performance_metrics

## Data loading

 * Load the training, probe, and testing datasets and save them as train.csv, probe.csv, and test.csv in the './repo/data_files/' directory.
 * If you have your own 'train.csv', 'probe.csv', and 'test.csv' files, you can skip this example dataset construction and put them directly under the './repo/data_files/' directory.


In [3]:
# Fractions of the training data used as training and probe sets (10% is held out as the probe set).
# The Facebook Comment Volume dataset comes with an explicit testing set.
dict_rate = dict()
dict_rate['train'] = 0.9
dict_rate['probe'] = 0.1

# Random seed
seed = 0

# Loads data
load_facebook_data(dict_rate, seed)

print('Finished data loading.')

Finished data loading.


## Data preprocessing

 * Extract features and labels from train.csv, probe.csv, test.csv in './repo/data_files/' directory.
 * Normalize the features of training, probe, and testing sets.

In [4]:
# Normalization methods: either 'minmax' or 'standard'
normalization = 'minmax' 

# Extracts features and labels, and then normalizes the features
x_train, y_train, x_probe, y_probe, x_test, y_test, col_names = \
    preprocess_data(normalization, 'train.csv', 'probe.csv', 'test.csv')

print('Finished data preprocess.')


Finished data preprocess.
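
The two normalization options can be sketched as below. This is a minimal NumPy illustration, not the code inside `preprocess_data`. It assumes that the statistics are computed on the training features only and then reused for the probe and testing sets, which this notebook does not state. The names `fit_scaler` and `apply_scaler` and the toy arrays are hypothetical.

```python
import numpy as np


def fit_scaler(x_train, method='minmax'):
    """Computes per-feature offset and scale from the training features only."""
    if method == 'minmax':
        offset = x_train.min(axis=0)
        scale = x_train.max(axis=0) - offset
    elif method == 'standard':
        offset = x_train.mean(axis=0)
        scale = x_train.std(axis=0)
    else:
        raise ValueError("method must be 'minmax' or 'standard'")
    # Guards against division by zero for constant features
    scale = np.where(scale == 0, 1.0, scale)
    return offset, scale


def apply_scaler(x, offset, scale):
    """Applies previously fitted statistics to any split."""
    return (x - offset) / scale


# Toy features standing in for the real splits
toy_train = np.array([[0.0, 10.0], [5.0, 20.0], [10.0, 30.0]])
toy_test = np.array([[5.0, 40.0]])

offset, scale = fit_scaler(toy_train, method='minmax')
toy_train_norm = apply_scaler(toy_train, offset, scale)
toy_test_norm = apply_scaler(toy_test, offset, scale)
print(toy_train_norm)
print(toy_test_norm)  # Test values can fall outside [0, 1]
```

Under this assumption, normalized testing values are not guaranteed to stay inside the training range, as the second feature of the toy test row shows.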


## Step 0: Black-box model training

This is the preliminary stage for RL-LIM. We train a black-box model (in this notebook, LightGBM) on the training set (x_train, y_train). If you already have a saved pre-trained black-box model, you can skip this stage and load it into bb_model instead. You also need to specify whether the problem is regression or classification.

 * Note that bb_model must provide fit and either predict (for regression) or predict_proba (for classification) as methods.

In [5]:
# Problem specification
problem = 'regression' # or 'classification'

# Initializes black-box model
if problem == 'regression':
    bb_model = lightgbm.LGBMRegressor()
elif problem == 'classification':
    bb_model = lightgbm.LGBMClassifier()

# Trains black-box model
bb_model = bb_model.fit(x_train, y_train)

print('Finished black-box model training.')

Finished black-box model training.
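
When substituting your own black-box model, the interface requirement above can be checked up front. The following is a small sketch; `check_black_box_interface` and the toy `MeanRegressor` class are hypothetical names introduced for illustration and are not part of the rllim package.

```python
def check_black_box_interface(model, problem):
    """Raises TypeError if the model lacks the methods needed for the problem."""
    if problem == 'regression':
        required = ['fit', 'predict']
    elif problem == 'classification':
        required = ['fit', 'predict_proba']
    else:
        raise ValueError("problem must be 'regression' or 'classification'")
    missing = [name for name in required
               if not callable(getattr(model, name, None))]
    if missing:
        raise TypeError('black-box model is missing: ' + ', '.join(missing))


class MeanRegressor:
    """Toy stand-in for a custom black-box regressor (hypothetical)."""

    def fit(self, x, y):
        self.mean_ = sum(y) / len(y)
        return self

    def predict(self, x):
        return [self.mean_ for _ in x]


toy_model = MeanRegressor().fit([[0.0], [1.0]], [2.0, 4.0])
check_black_box_interface(toy_model, 'regression')  # Passes silently
print(toy_model.predict([[5.0]]))
```

The same toy model would fail the check for 'classification', because it does not define predict_proba.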


## Step 1: Auxiliary dataset construction

Using the pre-trained black-box model, we create auxiliary training (x_train, y_train_hat) and probe (x_probe, y_probe_hat) datasets, in which the labels are replaced by the black-box predictions. These auxiliary datasets are used to train the instance-wise weight estimator and the locally interpretable models.

In [6]:
# Constructs auxiliary datasets
if problem == 'regression':
    y_train_hat = bb_model.predict(x_train)
    y_probe_hat = bb_model.predict(x_probe)
elif problem == 'classification':
    y_train_hat = bb_model.predict_proba(x_train)[:, 1]
    y_probe_hat = bb_model.predict_proba(x_probe)[:, 1]
    
print('Finished auxiliary dataset construction.')

Finished auxiliary dataset construction.


## Step 2: Interpretable baseline training

To improve the stability of the instance-wise weight estimator training, a baseline model has been observed to be beneficial. We use a globally interpretable model (Ridge regression in this notebook) optimized to replicate the predictions of the black-box model.

1. **Input**: 
 * Interpretable model, fit globally on the auxiliary training set: Ridge regression (this can be switched to a shallow tree). The model must provide fit and either predict (for regression) or predict_proba (for classification) as methods.
 
 
2. **Output**:
 * Trained interpretable baseline model: Function that tries to replicate the predictions of the black-box model using a globally interpretable model.

In [7]:
# Define interpretable baseline model
baseline = Ridge(alpha=1)

# Trains interpretable baseline model
baseline.fit(x_train, y_train_hat)

print('Finished interpretable baseline training.')

Finished interpretable baseline training.


## Step 3: Train instance-wise weight estimator

We train an instance-wise weight estimator with reinforcement learning, using the auxiliary training (x_train, y_train_hat) and probe (x_probe, y_probe_hat) datasets.

1. **Input**: 
 * Network parameters: Parameters of the instance-wise weight estimator network.
 * Locally interpretable model: Ridge regression (this can be switched to a shallow tree). The model must provide fit and either predict (for regression) or predict_proba (for classification) as methods.
 
 
2. **Output**:
 * Instance-wise weight estimator: Function that takes the auxiliary training set and a testing sample as inputs and estimates a weight for each training sample. These weights are used to construct the locally interpretable model for that testing sample.

In [8]:
# Instance-wise weight estimator network parameters
parameters = dict()
parameters['hidden_dim'] = 100
parameters['iterations'] = 2000
parameters['num_layers'] = 5
parameters['batch_size'] = 5000
parameters['batch_size_inner'] = 10
parameters['lambda'] = 1.0

# Defines locally interpretable model
interp_model = Ridge(alpha = 1)

# Checkpoint file name
checkpoint_file_name = './tmp/model.ckpt'

# Initializes RL-LIM
rllim_class = rllim.Rllim(x_train, y_train_hat, x_probe, y_probe_hat, parameters, 
                          interp_model, baseline, checkpoint_file_name)

# Trains RL-LIM
rllim_class.rllim_train()

print('Finished instance-wise weight estimator training.')

## Output functions
# Instance-wise weight estimation for x_test[0, :]
dve_out = rllim_class.instancewise_weight_estimator(x_train, y_train_hat, x_test[0, :])

# Interpretable predictions (test_y_fit) and instance-wise explanations (test_coef) for x_test[0, :]
test_y_fit, test_coef = rllim_class.rllim_interpreter(x_train, y_train_hat, x_test[0, :], interp_model)

print('Finished instance-wise weight estimations, instance-wise predictions, and local explanations.')



100%|██████████| 2000/2000 [36:02<00:00,  1.02it/s] 


Finished instance-wise weight estimator training.
INFO:tensorflow:Restoring parameters from ./tmp/model.ckpt
INFO:tensorflow:Restoring parameters from ./tmp/model.ckpt


100%|██████████| 1/1 [00:01<00:00,  1.11s/it]

Finished instance-wise weight estimations, instance-wise predictions, and local explanations.





## Step 4: Interpretable inference

Unlike in Step 3 (training the instance-wise weight estimator), we now keep the trained instance-wise weight estimator fixed (without the sampler and the interpretable baseline) and only fit the locally interpretable model at inference. Given a test instance, we obtain the selection probabilities from the instance-wise weight estimator and, using these as sample weights, fit the locally interpretable model via weighted optimization.

1. **Input**: 
 * Locally interpretable model: Ridge regression (this can be switched to a shallow tree). The model must provide fit and either predict (for regression) or predict_proba (for classification) as methods.
 
 
2. **Output**:
 * Instance-wise explanations (test_coef): Fitted intercept and coefficients of the locally interpretable model for each testing sample.
 * Interpretable predictions (test_y_fit): Predictions of the fitted locally interpretable model for each testing sample.

In [9]:
# Train locally interpretable models and output instance-wise explanations (test_coef) and
# interpretable predictions (test_y_fit) 
test_y_fit, test_coef = rllim_class.rllim_interpreter(x_train, y_train_hat, x_test, interp_model)

print('Finished instance-wise predictions and local explanations.')

INFO:tensorflow:Restoring parameters from ./tmp/model.ckpt


100%|██████████| 900/900 [15:26<00:00,  1.03s/it]

Finished instance-wise predictions and local explanations.
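
The weighted fit described in this step can be sketched for a single test instance as follows. This is a simplified illustration, not the implementation of `rllim_interpreter`. The trained weight estimator is replaced by a Gaussian kernel around the test instance (an assumption, used only to produce plausible weights), and the closed-form `weighted_ridge` helper is hypothetical. It is intended to mirror what scikit-learn's `Ridge.fit(X, y, sample_weight=...)` does, with an unpenalized intercept.

```python
import numpy as np


def weighted_ridge(x, y, weights, alpha=1.0):
    """Closed-form weighted ridge fit; returns [intercept, coef_1, ...]."""
    xb = np.hstack([np.ones((x.shape[0], 1)), x])
    penalty = alpha * np.eye(xb.shape[1])
    penalty[0, 0] = 0.0  # The intercept is not penalized
    lhs = xb.T @ (weights[:, None] * xb) + penalty
    rhs = xb.T @ (weights * y)
    return np.linalg.solve(lhs, rhs)


# Toy auxiliary training set: "black-box predictions" of a nonlinear function
toy_x = np.linspace(-1.0, 1.0, 201).reshape(-1, 1)
toy_y_hat = toy_x[:, 0] ** 2

# Test instance to explain
x_query = np.array([0.5])

# Stand-in weights (assumption): Gaussian kernel around the test instance
toy_weights = np.exp(-np.sum((toy_x - x_query) ** 2, axis=1) / (2 * 0.1 ** 2))

# Fits the local model and evaluates it at the test instance
local_coef = weighted_ridge(toy_x, toy_y_hat, toy_weights, alpha=1e-3)
local_pred = local_coef[0] + local_coef[1:] @ x_query
print('local explanation [intercept, slope]:', local_coef)
print('local prediction at x_query:', local_pred)
```

In this toy setting the local slope comes out close to 1, the derivative of x² at 0.5, whereas an unweighted linear fit over the symmetric range would have a slope near 0. That difference is the reason for fitting a separate weighted model per instance.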





## Evaluation

We use two quantitative metrics (overall performance and fidelity) and one qualitative metric (instance-wise explanations) to evaluate the locally interpretable models.

 * Overall performance: Difference between ground truth labels (y_test) and interpretable predictions (test_y_fit).
 * Fidelity: Difference between black-box model predictions (y_test_hat) and interpretable predictions (test_y_fit).
 * Instance-wise explanations: Qualitative examples of instance-wise explanations.
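
The two quantitative metrics differ only in the reference they compare against. The sketch below writes them out in NumPy on toy values. The helpers `mae`, `rmse`, and `r2` are hypothetical stand-ins and are only assumed to match what `rllim_metrics` computes. In particular, treating the black-box predictions as the reference in the R2 score is an assumption.

```python
import numpy as np


def mae(reference, prediction):
    """Mean absolute error."""
    return float(np.mean(np.abs(reference - prediction)))


def rmse(reference, prediction):
    """Root mean squared error."""
    return float(np.sqrt(np.mean((reference - prediction) ** 2)))


def r2(reference, prediction):
    """Coefficient of determination with `reference` as the target."""
    ss_res = np.sum((reference - prediction) ** 2)
    ss_tot = np.sum((reference - np.mean(reference)) ** 2)
    return float(1.0 - ss_res / ss_tot)


# Toy values standing in for y_test, y_test_hat, and test_y_fit
toy_y_test = np.array([1.0, 2.0, 3.0, 4.0])      # Ground truth labels
toy_y_test_hat = np.array([1.5, 2.0, 2.5, 4.5])  # Black-box predictions
toy_y_fit = np.array([1.0, 2.5, 2.5, 4.0])       # Interpretable predictions

# Overall performance: interpretable predictions vs. ground truth
print('overall MAE:', mae(toy_y_test, toy_y_fit))
# Fidelity: interpretable predictions vs. black-box predictions
print('fidelity MAE:', mae(toy_y_test_hat, toy_y_fit))
print('fidelity R2:', round(r2(toy_y_test_hat, toy_y_fit), 4))
```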

### 1. Overall performance

 * We use Mean Absolute Error (MAE) as the metric for overall performance. Users can replace MAE with RMSE or another metric.

In [10]:
# Overall performance
mae = overall_performance_metrics(y_test, test_y_fit, metric='mae')
print('Overall performance of RL-LIM in terms of MAE: ' + str(np.round(mae, 4)))

Overall performance of RL-LIM in terms of MAE: 24.7097


### 2. Fidelity

 * We use R2 score and Mean Absolute Error (MAE) as the metrics for the fidelity. 

In [11]:
# Black-box model predictions
y_test_hat = bb_model.predict(x_test)

# Fidelity in terms of MAE
mae = fidelity_metrics(y_test_hat, test_y_fit, metric='mae')
print('Fidelity of RL-LIM in terms of MAE: ' + str(np.round(mae, 4)))

# Fidelity in terms of R2 Score
r2 = fidelity_metrics(y_test_hat, test_y_fit, metric='r2')
print('Fidelity of RL-LIM in terms of R2 Score: ' + str(np.round(r2, 4)))

Fidelity of RL-LIM in terms of MAE: 20.5316
Fidelity of RL-LIM in terms of R2 Score: 0.433


### 3. Instance-wise explanations

 * We qualitatively demonstrate the local explanations of 5 testing samples using the fitted coefficients of the locally interpretable model (Ridge regression).
 * To run this cell, the interpretable model must expose intercept_ and coef_ as attributes. They hold the fitted locally interpretable model's intercept and coefficients.

In [12]:
# Local explanations of n samples
n = 5
local_explanations = test_coef[:n, :]

# Make pandas dataframe
final_col_names = np.concatenate((np.asarray(['intercept']), col_names), axis = 0) 
pd.DataFrame(data=local_explanations, index=range(n), columns=final_col_names)

|  | intercept | Page Popularity/likes | Page Check | Page talk | Page Category | min # of comments | min # of comments in last 24 hours | min # of comments in last 48 hours | min # of comments in the first 24 hours | min # of comments in last 48 to last 24 hours | ... | post was published on Thursday | post was published on Friday | post was published on Saturday | basetime (Sunday) | basetime (Monday) | basetime (Tuesday) | basetime (Wednesday) | basetime (Thursday) | basetime (Friday) | basetime (Saturday) |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | -103.330512 | 4.422585 | -15.001978 | 4.496328 | 1.229217 | 4.497695 | 18.962609 | 32.218586 | 29.317612 | 25.75556 | ... | 0.295788 | -0.005714 | -0.559625 | 0.078723 | 0.61073 | 0.452924 | 0.481491 | -0.878218 | -0.243942 | -0.501708 |
| 1 | -104.206197 | 4.25365 | -14.915397 | 3.943046 | 1.137022 | 4.459325 | 18.666929 | 32.59171 | 29.570496 | 26.151958 | ... | 0.130149 | 0.01597 | -0.533755 | 0.159478 | 0.58526 | 0.235655 | 0.421461 | -0.735738 | -0.339109 | -0.327006 |
| 2 | -103.286616 | 4.186166 | -14.795392 | 3.534093 | 1.126795 | 4.319873 | 18.569088 | 32.592334 | 29.231428 | 26.181185 | ... | 0.176272 | -0.096792 | -0.56279 | 0.202239 | 0.594071 | 0.24888 | 0.31107 | -0.700757 | -0.387459 | -0.268045 |
| 3 | -101.693317 | 4.348262 | -14.905574 | 5.691516 | 1.223091 | 4.993201 | 18.508501 | 31.850354 | 29.68688 | 23.989133 | ... | 0.204816 | -0.003839 | -0.53076 | 0.177499 | 0.65965 | 0.291278 | 0.422315 | -0.789438 | -0.333941 | -0.427364 |
| 4 | -101.229288 | 4.31802 | -14.496846 | 5.55079 | 1.20983 | 4.939441 | 18.472585 | 31.814553 | 29.207907 | 24.08257 | ... | 0.230303 | -0.065082 | -0.528969 | 0.089551 | 0.525554 | 0.416373 | 0.367239 | -0.741127 | -0.35491 | -0.302679 |
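
With many features, a wide coefficient table is hard to scan. One possible way to summarize a single explanation is to rank the features by absolute coefficient, sketched below. The `top_features` helper, the feature names, and the coefficient values are hypothetical. The sketch assumes the same layout as test_coef above, with the intercept in the first column.

```python
import numpy as np
import pandas as pd

# Hypothetical feature names and explanations laid out as [intercept, coef_1, ...]
toy_col_names = np.asarray(['feature_a', 'feature_b', 'feature_c'])
toy_test_coef = np.array([
    [-1.0, 0.2, -3.0, 1.5],
    [0.5, 2.0, 0.1, -0.4],
])


def top_features(coef_row, names, k=2):
    """Returns the k coefficients with the largest magnitude, skipping the intercept."""
    weights = pd.Series(coef_row[1:], index=names)
    order = weights.abs().sort_values(ascending=False).index[:k]
    return weights[order]


# Most influential features for the first explained sample
top = top_features(toy_test_coef[0], toy_col_names, k=2)
print(top)
```

Comparing magnitudes across features in this way is meaningful only to the extent that the features share a common scale, which is one reason the features are normalized during preprocessing.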
