In [None]:
#@title Import necessary libraries
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import mean_squared_error
from sklearn.metrics import accuracy_score
from sklearn.datasets import fetch_california_housing
import shap
import matplotlib.pyplot as plt
from matplotlib.widgets import Button
import shap
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display

#**Predetermined explanation function**

To support you in model development, we provide a predetermined model explanation function called Model_Explainer(). This function provides you with tools to investigate the influence of variables on model predictions. You can visualize the ranking of the most important variables as well as the learned patterns.

The function allows you to create three different visualizations:

1.   The **bar plot** shows you the ranking of the most important variables. With this plot, you can see which variables contribute the most to the model's prediction

2.   The **beeswarm plot** visualizes the influence of each variable in greater detail. Each point represents an observation of the test data set. The x-axis shows the (positive or negative) influence of the variables. The coloring shows the expression of the respective variable. Thus, red points in the left value range of the x-axis represent test data points in which high values of the variable had a negative influence on the prediction. Further explanations are available [here](https://shap.readthedocs.io/en/latest/example_notebooks/api_examples/plots/beeswarm.html).

3. The **interaction plot** shows how the influence of a variable depends on the expression of another variable. An exact explanation of this plot can be found [here](https://shap.readthedocs.io/en/latest/example_notebooks/api_examples/plots/scatter.html).



You will find the function in the **next cell**. You only have to execute the cell, then you can use the function as you wish.

The Model_Explainer() function uses three arguments:

- *model* is the trained machine learning model
- *X_train* is the training data set (excluding the target variable)
- *X_test* is the test data set (excluding the target variable), i.e. the data set with which you evaluate the performance of your model
Example application of the function: Model_Explainer(model = ML_classifier, X_train = X_train, X_test = X_train)


Example:
Model_Explainer(model = ML_classifier, X_train = X_train, X_test = X_train)

In [None]:
#@title Model explanation function classification

### Classification tasks

import shap
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, HTML
import IPython

# Configure the Jupyter display environment for interactivity
def configure_jupyter_display():
    display(HTML('''
        <script src="/static/components/requirejs/require.js"></script>
        <script>
          requirejs.config({
            paths: {
              base: '/static/base',
              plotly: 'https://cdn.plot.ly/plotly-latest.min.js?noext',
            },
          });
        </script>
        '''))

# Main function to display SHAP plots with controlled replacement of old plots
def Model_Explainer(model, X_train, X_test):
    configure_jupyter_display()

    # Limit the dataset for SHAP calculation to enhance performance
    if len(X_test) > 200:
        X_test = X_test.sample(n=200, random_state=42)

    # Initialize SHAP explainer and calculate SHAP values
    try:
        explainer = shap.TreeExplainer(model)
        shap_values = explainer.shap_values(X_test)
        # Use TreeExplainer for tree-based models for better performance
    except Exception as e:
        print(f"Warning: TreeExplainer failed with error: {e}")
        print("Using KernelExplainer instead.")
        background_data = shap.sample(X_train, 100)  # Sample 100 instances
        explainer = shap.LinearExplainer(model, X_train)
        shap_values = explainer.shap_values(X_test)

        # Check if SHAP values have more than two dimensions and adjust
    if len(shap_values.shape) > 2:
        # Assuming the SHAP values shape is like (n_classes, n_samples, n_features)
        # and we want to take the SHAP values for the positive class which is usually at index 1
        shap_values = shap_values[:, :, 1]
        print("Adjusted SHAP values to two dimensions.")

    print("Shape of SHAP values:", shap_values.shape)

    # Set a default style (like 'ggplot') for grid-like visuals
    plt.style.use('ggplot')  # You can also use 'bmh' or 'classic'

    # Set global plot size and other parameters
    plt.rcParams.update({
        "figure.figsize": (10, 6),  # Set default figure size
        "figure.dpi": 100,          # Set default DPI for clarity
        "axes.titlesize": 16,
        "axes.labelsize": 14,
        "xtick.labelsize": 12,
        "ytick.labelsize": 12,
        "legend.fontsize": 12,
    })

    # Output widgets with added margin for spacing
    output_bar = widgets.Output(layout=widgets.Layout(margin='0 20px 0 0'))
    output_summary = widgets.Output(layout=widgets.Layout(margin='0 20px 0 0'))
    output_interaction = widgets.Output()

    # Helper functions to display each plot type within its Output widget
    def show_bar_plot(_=None):
        with output_bar:
            output_bar.clear_output()  # Clear previous plot
            plt.figure()
            plt.figure(figsize=(10, 6))
            shap.summary_plot(shap_values, X_test, plot_type="bar", show=True)  # Use class 1 SHAP values
            plt.show()

    def show_summary_plot(_=None):
        with output_summary:
            output_summary.clear_output()  # Clear previous plot
            plt.figure()
            plt.figure(figsize=(10, 6))
            shap.summary_plot(shap_values, X_test, show=True)  # Use class 1 SHAP values
            plt.show()

    def show_interaction_plot(_=None):
        feature_idx = feature_names.index(feature_dropdown.value)
        with output_interaction:
            output_interaction.clear_output()  # Clear previous plot
            plt.figure()
            plt.figure(figsize=(20, 6))
            shap.dependence_plot(feature_idx, shap_values, X_test, show=True)  # Use class 1 SHAP values
            plt.show()

    # Set up widgets for interactive feature selection
    feature_names = X_test.columns.tolist()
    feature_dropdown = widgets.Dropdown(options=feature_names, description="Feature:")

    # Display layout with organized widget layout
    button_bar = widgets.Button(description="Show Bar Plot")
    button_summary = widgets.Button(description="Show Summary Plot")
    button_interaction = widgets.Button(description="Show Interaction Plot")

    # Link buttons to plotting functions
    button_bar.on_click(show_bar_plot)
    button_summary.on_click(show_summary_plot)
    button_interaction.on_click(show_interaction_plot)

    # Arrange buttons and plots horizontally with spacing
    button_box = widgets.HBox([button_bar, button_summary, button_interaction, feature_dropdown])
    plot_box = widgets.HBox([output_bar, output_summary, output_interaction])  # Adjusted order here

    # Arrange layout with VBox for a structured display
    interaction_box = widgets.VBox([
        widgets.HTML("<h3>SHAP Plot Selector</h3>"),
        button_box,
        plot_box
    ])

    # Display the organized layout
    display(interaction_box)

In [11]:
#@title Load Data
# url = "https://raw.githubusercontent.com/caradamm/testCara123/main/data.csv" # Replace with your file name
url = "https://raw.githubusercontent.com/caradamm/XAI_HousePricePrediction/main/data/data.csv"
# https://github.com/caradamm/XAI_HousePricePrediction/blob/main/data/data.csv
df = pd.read_csv(url)
df.head()


Unnamed: 0,garden\tbasement\televator\tbalcony\tfloor (storey)\tnmbr of rooms\tconstruction year\tunemployment\tAnteil Gruenenwaehler\tprice
0,FALSE\tTRUE\tFALSE\tTRUE\t1\t3\t2005\t1\t2\t1
1,FALSE\tFALSE\tTRUE\tFALSE\t11\t3\t2024\t2\t2\t1
2,FALSE\tTRUE\tFALSE\tFALSE\t1\t6\t1956\t2\t1\t0
3,TRUE\tTRUE\tTRUE\tTRUE\t-1\t5\t1971\t1\t2\t1
4,FALSE\tTRUE\tTRUE\tTRUE\t2\t2\t2021\t3\t1\t1


#Your code goes here

In [None]:
#@title Call model explanation function

Model_Explainer(model = model_bin, X_train = X_train, X_test = X_test,)