Copyright (c) Microsoft Corporation. All rights reserved.

Licensed under the MIT License.

# Example of integration with pytorch/captum
_**This notebook shows how to create an interpret-community style explanation from captum attributions and view it in the dashboard.**_


## Table of Contents

1. [Introduction](#Introduction)
1. [Project](#Project)
1. [Setup](#Setup)
1. [Create a captum model and explain it](#Explain)
    1. Define input and baseline tensors
    1. Run integrated gradients to get an explanation
    1. Create an interpret-community style explanation
        1. Generate global explanations
        1. Generate local explanations
1. [Visualize results](#Visualize)
1. [Next steps](#Next)

<a id='Introduction'></a>
## 1. Introduction

This notebook illustrates how to integrate captum explanations with interpret-community visualizations.

<a id='Project'></a>       
## 2. Project

The goal of this project is to run an IntegratedGradients explainer from captum and visualize it in the ExplanationDashboard.
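Integrated Gradients attributes a prediction $F(x)$ to each input feature by integrating the model's gradient along the straight line from a baseline $x'$ to the input $x$: $\mathrm{IG}_i(x) = (x_i - x'_i)\int_0^1 \frac{\partial F(x' + \alpha (x - x'))}{\partial x_i}\, d\alpha$. The sketch below approximates that integral with a midpoint Riemann sum in plain NumPy, for a hypothetical function $F(x) = \sum_i x_i^2$ whose gradient $2x$ is written by hand. It only illustrates the technique and is not captum's implementation (by default captum uses a Gauss-Legendre rule instead). It also shows the completeness property: the attributions sum to $F(x) - F(x')$.

```python
import numpy as np

def f(x):
    # Hypothetical model: F(x) = sum of squared features, one output per row
    return np.sum(x ** 2, axis=-1)

def grad_f(x):
    # Analytic gradient of F with respect to each feature
    return 2.0 * x

def integrated_gradients(x, baseline, grad_fn, n_steps=50):
    """Midpoint Riemann-sum approximation of Integrated Gradients."""
    # Midpoints of n_steps equal sub-intervals of [0, 1]
    alphas = (np.arange(n_steps) + 0.5) / n_steps
    total_grad = np.zeros_like(x)
    for alpha in alphas:
        point = baseline + alpha * (x - baseline)
        total_grad += grad_fn(point)
    avg_grad = total_grad / n_steps
    # Scale the averaged gradient by the input-baseline difference
    return (x - baseline) * avg_grad

x = np.array([[1.0, 2.0, 3.0]])
baseline = np.zeros_like(x)
attr = integrated_gradients(x, baseline, grad_f)

# Completeness: summed attributions minus the change in model output
completeness_gap = float(attr.sum() - (f(x) - f(baseline))[0])
print(attr)
print(completeness_gap)
```

For this quadratic function the midpoint rule is exact, so the attributions come out as roughly `[1, 4, 9]` and the gap is essentially zero.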

<a id='Setup'></a>
## 3. Setup

If you are using Jupyter Notebook, the widget extensions should be installed automatically with the package.
If you are using JupyterLab, run the following command:
```
(myenv) $ jupyter labextension install @jupyter-widgets/jupyterlab-manager
```


<a id='Explain'></a>
## 4. Create a captum model and explain it (adapted from captum's main page)

In [None]:
# Adapted from the example on captum's main page; the `output` layer is an addition
# https://captum.ai/
import numpy as np

import torch
import torch.nn as nn

from captum.attr import IntegratedGradients

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin1 = nn.Linear(3, 3)
        self.relu = nn.ReLU()
        self.lin2 = nn.Linear(3, 2)
        self.output = nn.Linear(2, 1)

        # initialize weights and biases
        self.lin1.weight = nn.Parameter(torch.arange(-4.0, 5.0).view(3, 3))
        self.lin1.bias = nn.Parameter(torch.zeros(1,3))
        self.lin2.weight = nn.Parameter(torch.arange(-3.0, 3.0).view(2, 3))
        self.lin2.bias = nn.Parameter(torch.ones(1,2))

    def forward(self, input):
        return self.output(self.lin2(self.relu(self.lin1(input))))


# Fix the random seeds *before* building the model so that the randomly
# initialized `output` layer, and all later computations, are deterministic
torch.manual_seed(123)
np.random.seed(123)

model = ToyModel()
model.eval()
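Seeding only makes random draws that happen *after* the seed call reproducible. Because `self.output` is initialized randomly inside `ToyModel.__init__`, seeding has to happen before the model is constructed for those weights to repeat across runs. The NumPy sketch below illustrates the same principle, with a hypothetical `draw_weights` helper standing in for a layer's random initialization.

```python
import numpy as np

def draw_weights(shape):
    # Hypothetical stand-in for a layer's random weight initialization
    return np.random.rand(*shape)

# Seed first, then initialize: the draw is reproducible
np.random.seed(123)
first = draw_weights((2, 1))

# Re-seeding with the same value replays the same sequence
np.random.seed(123)
second = draw_weights((2, 1))

print(np.array_equal(first, second))  # True
```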

### Define input and baseline tensors

In [None]:
input = torch.rand(100, 3)
baseline = torch.zeros(100, 3)
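Integrated Gradients evaluates the model at points on the straight line between each baseline row and its input row. With an all-zeros baseline, those points are simply scaled copies of the input. The NumPy sketch below builds that path for a batch shaped like the one above (100 rows, 3 features); the array names and `n_steps` value are illustrative choices, not captum parameters.

```python
import numpy as np

rng = np.random.default_rng(123)
inputs = rng.random((100, 3))        # stand-in for torch.rand(100, 3)
baseline = np.zeros((100, 3))        # all-zeros baseline, as in the notebook
n_steps = 5

# alphas run from 0 (baseline) to 1 (input); shape (n_steps, 1, 1) for broadcasting
alphas = np.linspace(0.0, 1.0, n_steps).reshape(-1, 1, 1)
path = baseline + alphas * (inputs - baseline)   # shape (n_steps, 100, 3)

print(path.shape)                                                      # (5, 100, 3)
print(np.allclose(path[0], baseline), np.allclose(path[-1], inputs))   # True True
```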

### Run integrated gradients to get an explanation

In [None]:
ig = IntegratedGradients(model)
attributions, delta = ig.attribute(input, baseline, target=0, return_convergence_delta=True)
# optionally print feature attributions
# print('IG Attributions:', attributions)
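With `return_convergence_delta=True`, captum also returns `delta`, which measures how far the approximated attributions are from satisfying completeness: the difference between a row's summed attributions and $F(\text{input}) - F(\text{baseline})$. A delta close to zero suggests the number of integration steps is adequate. The sketch below reproduces that check in NumPy for a hypothetical nonlinear function $F(x) = \sum_i x_i^3$, using a left Riemann sum so the error visibly shrinks as `n_steps` grows; it illustrates the idea and is not captum's code.

```python
import numpy as np

def f(x):
    # Hypothetical nonlinear model: F(x) = sum of cubed features per row
    return np.sum(x ** 3, axis=-1)

def grad_f(x):
    # Analytic gradient of F
    return 3.0 * x ** 2

def ig_left_riemann(x, baseline, n_steps):
    # Left Riemann sum: alphas 0, 1/n, ..., (n-1)/n
    alphas = np.arange(n_steps) / n_steps
    grads = [grad_f(baseline + a * (x - baseline)) for a in alphas]
    return (x - baseline) * np.mean(grads, axis=0)

x = np.array([[1.0, 2.0, 3.0]])
baseline = np.zeros_like(x)

abs_deltas = []
for n_steps in (5, 50, 500):
    attr = ig_left_riemann(x, baseline, n_steps)
    # Convergence delta: summed attributions minus the change in model output
    delta = attr.sum(axis=-1) - (f(x) - f(baseline))
    abs_deltas.append(float(np.abs(delta[0])))
    print(n_steps, abs_deltas[-1])
```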

### Create an interpret-community style explanation

In [None]:
from interpret_community.captum import CaptumAdapter
adapter = CaptumAdapter(features=['A', 'B', 'C'])
global_explanation = adapter.create_global(attributions, evaluation_examples=np.array(input))
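`create_global` takes the per-example attributions (one row per example, one column per feature) and summarizes them into a single global importance value per feature. As an assumption for illustration, the sketch below uses the mean absolute attribution per feature, a common aggregation; check the interpret-community documentation for the exact rule the adapter applies. The attribution matrix is hypothetical data.

```python
import numpy as np

# Hypothetical local attributions: 4 examples x 3 features (A, B, C)
local_attr = np.array([
    [ 0.5, -1.0, 0.1],
    [-0.5,  2.0, 0.1],
    [ 0.5, -1.0, 0.3],
    [-0.5,  2.0, 0.3],
])

# Assumed aggregation: mean absolute attribution per feature
global_importance = np.mean(np.abs(local_attr), axis=0)
# For contrast: a signed mean lets positive and negative effects cancel
signed_mean = np.mean(local_attr, axis=0)

print(global_importance)   # approximately [0.5, 1.5, 0.2]
print(signed_mean)
```

Taking absolute values keeps feature A, whose positive and negative contributions cancel on average, from looking unimportant.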

In [None]:
# Sorted global importance values (aggregated from the IG attributions)
print('ranked global importance values: {}'.format(global_explanation.get_ranked_global_values()))
# Corresponding feature names
print('ranked global importance names: {}'.format(global_explanation.get_ranked_global_names()))
# Feature ranks (based on original order of features)
print('global importance rank: {}'.format(global_explanation.global_importance_rank))

In [None]:
# Print out a dictionary that holds the sorted feature importance names and values
print('global importance dict: {}'.format(global_explanation.get_feature_importance_dict()))
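The ranked values, ranked names, feature rank, and importance dictionary are different views of the same global importance vector sorted in descending order. A minimal NumPy sketch of that relationship, using hypothetical importance values for features A, B and C (interpret-community's tie-breaking and return types may differ):

```python
import numpy as np

features = ['A', 'B', 'C']
global_importance = np.array([0.5, 1.5, 0.2])   # hypothetical values

# Indices of features ordered from most to least important
rank = np.argsort(-global_importance)
ranked_values = global_importance[rank].tolist()
ranked_names = [features[i] for i in rank]
importance_dict = dict(zip(ranked_names, ranked_values))

print(rank.tolist())     # [1, 0, 2]
print(ranked_names)      # ['B', 'A', 'C']
print(importance_dict)   # {'B': 1.5, 'A': 0.5, 'C': 0.2}
```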

### Explain overall model predictions as a collection of local (instance-level) explanations

In [None]:
# Integrated Gradients attributions for every feature of every example in `input`
print('local importance values: {}'.format(global_explanation.local_importance_values))
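The local importance values hold one row of attributions per example and one column per feature, so instance-level questions reduce to array indexing. The hypothetical sketch below finds the most influential feature (largest absolute attribution) for each row; the nested list stands in for the explanation's values, which may be returned as a list of lists rather than an array.

```python
import numpy as np

features = ['A', 'B', 'C']
# Hypothetical local attributions: rows are examples, columns are features
local_values = [
    [0.2, -0.9,  0.1],
    [0.7,  0.3, -0.1],
]

local = np.asarray(local_values)
top_idx = np.argmax(np.abs(local), axis=1)   # strongest feature per row
top_features = [features[i] for i in top_idx]
print(top_features)   # ['B', 'A']
```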

<a id='Visualize'></a>
## 5. Visualize
Load the visualization dashboard

In [None]:
from interpret_community.widget import ExplanationDashboard

In [None]:
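from interpret_community.common.model_wrapper import wrap_model
from interpret_community.dataset.dataset_wrapper import DatasetWrapper
import pandas as pd
# Wrap the PyTorch model so the dashboard can call it for predictions;
# the model has a single continuous output, so it is treated as regression
wrapped_model, _ = wrap_model(model, DatasetWrapper(input), model_task='regression')
# Evaluation data as a DataFrame whose columns match the explanation's feature names
dataset = pd.DataFrame(np.array(input), columns=['A', 'B', 'C'])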
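The dashboard needs a model it can call on tabular (NumPy or pandas) data, which is presumably why the PyTorch module is wrapped before being passed in. The sketch below shows the general idea with a hypothetical `PredictWrapper` class that exposes a scikit-learn-style `predict` around an arbitrary callable. It is not interpret-community's `wrap_model`, and a NumPy linear function stands in for the network so the example runs without PyTorch.

```python
import numpy as np
import pandas as pd

class PredictWrapper:
    """Hypothetical adapter exposing predict() over NumPy or pandas input."""

    def __init__(self, forward_fn):
        self._forward_fn = forward_fn

    def predict(self, data):
        # Accept a DataFrame or array-like and hand the model a float array
        values = data.values if isinstance(data, pd.DataFrame) else np.asarray(data)
        output = self._forward_fn(values.astype(np.float32))
        # Flatten (n, 1) regression output to shape (n,)
        return np.asarray(output).reshape(-1)

weights = np.array([[1.0], [2.0], [3.0]], dtype=np.float32)  # hypothetical model
wrapped = PredictWrapper(lambda x: x @ weights)

dataset = pd.DataFrame([[1.0, 0.0, 0.0], [1.0, 1.0, 1.0]], columns=['A', 'B', 'C'])
print(wrapped.predict(dataset))  # [1. 6.]
```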

In [None]:
ExplanationDashboard(global_explanation, wrapped_model, datasetX=dataset)

<a id='Next'></a>
## 6. Next steps
Learn about other use cases of the interpret-community package:

1. [Training time: regression problem](./explain-regression-local.ipynb)
1. [Training time: multiclass classification problem](./explain-multiclass-classification-local.ipynb)
1. Explain models with engineered features:
    1. [Simple feature transformations](./simple-feature-transformations-explain-local.ipynb)
    1. [Advanced feature transformations](./advanced-feature-transformations-explain-local.ipynb)