# Google Colab link:

[https://colab.research.google.com/github/oegedijk/explainerdashboard/blob/master/custom_examples.ipynb](https://colab.research.google.com/github/oegedijk/explainerdashboard/blob/master/custom_examples.ipynb)

Install explainerdashboard:

In [None]:
# Install the explainerdashboard package into the running notebook environment (IPython shell escape).
!pip install explainerdashboard

# Imports

In [None]:
# Load IPython's autoreload extension. Mode 0 leaves automatic reloading
# switched off; change to `%autoreload 2` to reload modules before each cell.
%load_ext autoreload
%autoreload 0

# Display the value of every top-level expression in a cell, not only the last one.
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

In [None]:
from sklearn.ensemble import RandomForestClassifier

from jupyter_dash import JupyterDash

import dash_html_components as html
import dash_bootstrap_components as dbc


from explainerdashboard.explainers import *
from explainerdashboard.datasets import *
from explainerdashboard.dashboard_components import *
from explainerdashboard.dashboards import *

# train model and build explainer

In [None]:
# Load the Titanic train/test split and the passenger names used as row index.
X_train, y_train, X_test, y_test = titanic_survive()
train_names, test_names = titanic_names()

# Human-readable description for each feature, passed to the explainer below.
feature_descriptions = {
    "Sex": "Gender of passenger",
    "Deck": "The deck the passenger had their cabin on",
    "PassengerClass": "The class of the ticket: 1st, 2nd or 3rd class",
    "Fare": "The amount of money people paid",
    "No_of_relatives_on_board": "number of siblings, spouses, parents plus children on board",
    "Embarked": "the port where the passenger boarded the Titanic. Either Southampton, Cherbourg or Queenstown",
    "Age": "Age of the passenger",
    "No_of_siblings_plus_spouses_on_board": "The sum of the number of siblings plus the number of spouses on board",
    "No_of_parents_plus_children_on_board": "The sum of the number of parents plus the number of children on board",
}

# A fixed random_state makes the fitted forest reproducible between runs.
# This matters because the dashboards below contain hard-coded narrative
# about which features matter most.
model = RandomForestClassifier(n_estimators=50, max_depth=15, random_state=42)
model.fit(X_train, y_train)

explainer = RandomForestClassifierExplainer(model, X_test, y_test,
                            cats=['Sex', 'Deck', 'Embarked'],
                            idxs=test_names,  # names of passengers
                            descriptions=feature_descriptions,
                            labels=['Not survived', 'Survived'])

# Quick sanity check: SHAP contributions for the first test passenger.
explainer.plot_shap_contributions(0)

# Design Pattern I: Custom `ExplainerComponent` 

## Constructing the ExplainerComponent

By constructing a custom `ExplainerComponent`, some of the boilerplate gets taken care of for you.
This includes:
 - registering the callbacks for all your subcomponents
 - making it easy to precalculate all lazily calculated dependencies, such as SHAP values, up front (otherwise these may get calculated multiple times, once by each component individually)
 - making it easy to run and start the dashboard by passing it on to an `ExplainerDashboard`
 

Important elements:
- call `super().__init__(explainer, title, name)` at the beginning of `__init__`
- register components with `.register_components()`.
- define your callbacks in `._register_callbacks(self, app)`. This way the ExplainerComponent automatically registers all callbacks of registered components first, and then calls `_register_callbacks(app)`. Alternatively, you just have to make sure that you register all callbacks of subcomponents in the `register_callbacks(app)` method.

By calling `register_components()`, your component callbacks are automatically registered, and you
can calculate all dependencies of your component at once by calling `.calculate_dependencies()`.

Finally you can pass the CustomComponent to `ExplainerDashboard`, without having to write the dash boilerplate yourself.

## Building the layout

Each ExplainerComponent has a `.layout()` method that you can use inside your layout definition. 
All the configuration of the component is done when you instantiate the component. 
You can hide the different toggles (`hide_cats=True`) and titles of the components, and set default values of
their elements (`col='Fare'`).

For example, we set 'Fare' and 'PassengerClass' as the default `col` and `color_col` for the ShapDependenceComponent.


The rest of the layout is done with standard dash bootstrap components (dbc). https://dash-bootstrap-components.opensource.faculty.ai/docs/

In [None]:
class CustomDashboard(ExplainerComponent):
    """Single-page Titanic dashboard built as a custom ``ExplainerComponent``.

    Pairs three plots (model precision, SHAP summary and SHAP dependence) with
    explanatory text. The summary and dependence plots are linked through a
    ``ShapSummaryDependenceConnector``. All subcomponents are passed to
    ``register_components()`` so that their callbacks get registered
    automatically.
    """

    def __init__(self, explainer):
        """Create and register the subcomponents.

        Args:
            explainer: the explainer whose model and data are displayed.
        """
        super().__init__(explainer, title="Titanic Explainer")

        # Show only the plot: all cutoff/bin/multiclass/selector controls are hidden.
        self.precision = PrecisionComponent(explainer,
                                hide_cutoff=True, hide_binsize=True,
                                hide_binmethod=True, hide_multiclass=True,
                                hide_selector=True,
                                cutoff=None)
        # Fixed depth=8 and cats=True, with the corresponding toggles hidden.
        self.shap_summary = ShapSummaryComponent(explainer,
                                hide_title=True, hide_selector=True,
                                hide_depth=True, depth=8,
                                hide_cats=True, cats=True)
        # Defaults: Fare as the feature, colored by PassengerClass.
        self.shap_dependence = ShapDependenceComponent(explainer,
                                hide_title=True, hide_selector=True,
                                hide_cats=True, cats=True,
                                hide_index=True,
                                col='Fare', color_col="PassengerClass")
        # Links the summary plot to the dependence plot (presumably a click on a
        # feature in the summary selects it in the dependence plot; verify).
        self.connector = ShapSummaryDependenceConnector(self.shap_summary, self.shap_dependence)

        self.register_components(self.precision, self.shap_summary,
                                 self.shap_dependence, self.connector)

    def layout(self):
        """Return the Dash layout: three rows, each pairing a plot with text."""
        # Rank once; the text below names the three most important features.
        ranked_cols = self.explainer.columns_ranked_by_shap(cats=True)
        return dbc.Container([
            html.H1("Titanic Explainer"),
            dbc.Row([
                dbc.Col([
                    html.H3("Model Performance"),
                    html.Div("As you can see on the right, the model performs quite well."),
                    html.Div("The higher the predicted probability of survival predicted by "
                             "the model on the basis of learning from examples in the training set"
                             ", the higher is the actual percentage for a person surviving in "
                             "the test set."),
                ], width=4),
                dbc.Col([
                    html.H3("Model Precision Plot"),
                    self.precision.layout()
                ])
            ]),
            dbc.Row([
                dbc.Col([
                    html.H3("Feature Importances Plot"),
                    self.shap_summary.layout()
                ]),
                dbc.Col([
                    html.H3("Feature importances"),
                    html.Div("On the left you can check out for yourself which parameters were the most important."),
                    html.Div(f"{ranked_cols[0]} was the most important"
                             f", followed by {ranked_cols[1]}"
                             f" and {ranked_cols[2]}."),
                    html.Div("If you select 'detailed' you can see the impact of that variable on "
                             "each individual prediction. With 'aggregate' you see the average impact size "
                             "of that variable on the final prediction."),
                    html.Div("With the detailed view you can clearly see that the large impact from Sex "
                             "stems both from males having a much lower chance of survival and females a much "
                             "higher chance.")
                ], width=4)
            ]),
            dbc.Row([
                dbc.Col([
                    html.H3("Relations between features and model output"),
                    html.Div("In the plot to the right you can see that the higher the price "
                             "of the Fare that people paid, the higher the chance of survival. "
                             "Probably the people with more expensive tickets were in higher up cabins, "
                             "and were more likely to make it to a lifeboat."),
                    html.Div("When you color the impacts by the PassengerClass, you can clearly see that "
                             "the more expensive tickets were mostly 1st class, and the cheaper tickets "
                             "mostly 3rd class."),
                    html.Div("On the right you can check out for yourself how different features impact "
                             "the model output."),
                ], width=4),
                dbc.Col([
                    html.H3("Feature impact plot"),
                    self.shap_dependence.layout()
                ]),
            ])
        ])

In [None]:
# Pattern I: the ExplainerComponent subclass itself (not an instance) is passed
# to ExplainerDashboard, which builds and serves the app.
pattern_one_dashboard = ExplainerDashboard(explainer, CustomDashboard, hide_header=True)
pattern_one_dashboard.run()

# Design pattern II: building your own dashboard class
Each ExplainerComponent has a `.layout()`, `.register_callbacks(app)` and `.calculate_dependencies()` method.

So you can define the dashboard any way you see fit, as long as you remember to call those.

In [None]:
class CustomDashboard:
    """Titanic dashboard written as a plain class (no ExplainerComponent base).

    Because nothing is registered automatically, this class itself provides
    the three hooks every dashboard needs: ``layout()``,
    ``register_callbacks(app)`` and ``calculate_dependencies()``. Each hook
    delegates to the subcomponents.
    """

    def __init__(self, explainer):
        """Create the subcomponents.

        Args:
            explainer: the explainer whose model and data are displayed.
        """
        self.explainer = explainer
        # Show only the plot: all cutoff/bin/multiclass/selector controls are hidden.
        self.precision = PrecisionComponent(explainer,
                                hide_cutoff=True, hide_binsize=True, hide_selector=True,
                                hide_binmethod=True, hide_multiclass=True,
                                cutoff=None)
        # Fixed depth=8 and cats=True, with the corresponding toggles hidden.
        self.shap_summary = ShapSummaryComponent(explainer,
                                hide_title=True, hide_selector=True,
                                hide_depth=True, depth=8,
                                hide_cats=True, cats=True)
        # Defaults: Fare as the feature, colored by PassengerClass.
        self.shap_dependence = ShapDependenceComponent(explainer,
                                hide_title=True, hide_selector=True,
                                hide_cats=True, cats=True,
                                hide_index=True,
                                col='Fare', color_col="PassengerClass")
        # Links the summary plot to the dependence plot.
        self.connector = ShapSummaryDependenceConnector(self.shap_summary, self.shap_dependence)

    def _components(self):
        """Return every subcomponent in registration order."""
        return (self.precision, self.shap_summary, self.shap_dependence, self.connector)

    def layout(self):
        """Return the Dash layout: three rows, each pairing a plot with text."""
        # Rank once; the text below names the three most important features.
        ranked_cols = self.explainer.columns_ranked_by_shap(cats=True)
        return dbc.Container([
            html.H1("Titanic Explainer"),
            dbc.Row([
                dbc.Col([
                    html.H3("Model Performance"),
                    html.Div("As you can see on the right, the model performs quite well."),
                    html.Div("The higher the predicted probability of survival predicted by "
                             "the model on the basis of learning from examples in the training set"
                             ", the higher is the actual percentage for a person surviving in "
                             "the test set."),
                ], width=4),
                dbc.Col([
                    html.H3("Model Precision Plot"),
                    self.precision.layout()
                ])
            ]),
            dbc.Row([
                dbc.Col([
                    html.H3("Feature Importances Plot"),
                    self.shap_summary.layout()
                ]),
                dbc.Col([
                    html.H3("Feature importances"),
                    html.Div("On the left you can check out for yourself which parameters were the most important."),
                    html.Div(f"{ranked_cols[0]} was the most important"
                             f", followed by {ranked_cols[1]}"
                             f" and {ranked_cols[2]}."),
                    html.Div("If you select 'detailed' you can see the impact of that variable on "
                             "each individual prediction. With 'aggregate' you see the average impact size "
                             "of that variable on the final prediction."),
                    html.Div("With the detailed view you can clearly see that the large impact from Sex "
                             "stems both from males having a much lower chance of survival and females a much "
                             "higher chance.")
                ], width=4)
            ]),
            dbc.Row([
                dbc.Col([
                    html.H3("Relations between features and model output"),
                    html.Div("In the plot to the right you can see that the higher the price "
                             "of the Fare that people paid, the higher the chance of survival. "
                             "Probably the people with more expensive tickets were in higher up cabins, "
                             "and were more likely to make it to a lifeboat."),
                    html.Div("When you color the impacts by the PassengerClass, you can clearly see that "
                             "the more expensive tickets were mostly 1st class, and the cheaper tickets "
                             "mostly 3rd class."),
                    html.Div("On the right you can check out for yourself how different features impact "
                             "the model output."),
                ], width=4),
                dbc.Col([
                    html.H3("Feature impact plot"),
                    self.shap_dependence.layout()
                ]),
            ])
        ])

    def register_callbacks(self, app):
        """Register the callbacks of every subcomponent on the Dash ``app``."""
        for component in self._components():
            component.register_callbacks(app)

    def calculate_dependencies(self):
        """Precalculate the lazily computed dependencies of every subcomponent."""
        for component in self._components():
            component.calculate_dependencies()
        

In [None]:
# Pattern II: build the plain-class dashboard first, then hand the instance over.
db = CustomDashboard(explainer)
pattern_two_dashboard = ExplainerDashboard(explainer, db, hide_header=True)
pattern_two_dashboard.run()

# Design pattern III: traditional flat dash definition

In [None]:
# Pattern III: build the components and the layout as plain module-level names.
# Show only the plot: all cutoff/bin/multiclass/selector controls are hidden.
precision = PrecisionComponent(explainer,
                        hide_cutoff=True, hide_binsize=True, hide_selector=True,
                        hide_binmethod=True, hide_multiclass=True,
                        cutoff=None)
# Fixed depth=8 and cats=True, with the corresponding toggles hidden.
shap_summary = ShapSummaryComponent(explainer,
                        hide_title=True, hide_selector=True,
                        hide_depth=True, depth=8,
                        hide_cats=True, cats=True)
# Defaults: Fare as the feature, colored by PassengerClass.
shap_dependence = ShapDependenceComponent(explainer,
                        hide_title=True, hide_selector=True,
                        hide_cats=True, cats=True,
                        hide_index=True,
                        col='Fare', color_col="PassengerClass")

# Links the summary plot to the dependence plot.
connector = ShapSummaryDependenceConnector(shap_summary, shap_dependence)

# Rank once; the text below names the three most important features.
ranked_cols = explainer.columns_ranked_by_shap(cats=True)

layout = dbc.Container([
            html.H1("Titanic Explainer"),
            dbc.Row([
                dbc.Col([
                    html.H3("Model Performance"),
                    html.Div("As you can see on the right, the model performs quite well."),
                    html.Div("The higher the predicted probability of survival predicted by "
                             "the model on the basis of learning from examples in the training set"
                             ", the higher is the actual percentage for a person surviving in "
                             "the test set."),
                ], width=4),
                dbc.Col([
                    html.H3("Model Precision Plot"),
                    precision.layout()
                ])
            ]),
            dbc.Row([
                dbc.Col([
                    html.H3("Feature Importances Plot"),
                    shap_summary.layout()
                ]),
                dbc.Col([
                    html.H3("Feature importances"),
                    html.Div("On the left you can check out for yourself which parameters were the most important."),
                    html.Div(f"{ranked_cols[0]} was the most important"
                             f", followed by {ranked_cols[1]}"
                             f" and {ranked_cols[2]}."),
                    html.Div("If you select 'detailed' you can see the impact of that variable on "
                             "each individual prediction. With 'aggregate' you see the average impact size "
                             "of that variable on the final prediction."),
                    html.Div("With the detailed view you can clearly see that the large impact from Sex "
                             "stems both from males having a much lower chance of survival and females a much "
                             "higher chance.")
                ], width=4)
            ]),
            dbc.Row([
                dbc.Col([
                    html.H3("Relations between features and model output"),
                    html.Div("In the plot to the right you can see that the higher the price "
                             "of the Fare that people paid, the higher the chance of survival. "
                             "Probably the people with more expensive tickets were in higher up cabins, "
                             "and were more likely to make it to a lifeboat."),
                    html.Div("When you color the impacts by the PassengerClass, you can clearly see that "
                             "the more expensive tickets were mostly 1st class, and the cheaper tickets "
                             "mostly 3rd class."),
                    html.Div("On the right you can check out for yourself how different features impact "
                             "the model output."),
                ], width=4),
                dbc.Col([
                    html.H3("Feature impact plot"),
                    shap_dependence.layout()
                ]),
            ])
        ])

In [None]:
# Pattern III: wire the flat layout into a JupyterDash app by hand.
app = JupyterDash(__name__, external_stylesheets=[dbc.themes.FLATLY], assets_url_path="")
app.title = "Titanic Explainer"
app.layout = layout

# Nothing registers callbacks automatically here, so each component (and the
# connector) must register its own callbacks on the app before serving.
for component in (precision, shap_summary, shap_dependence, connector):
    component.register_callbacks(app)

app.run_server(port=8053)
