In [1]:
# hide
# skip
! [ -e /content ] && pip install -Uqq model_inspector nbdev
# For Colab: the line above installs the dependencies only when `/content` exists. Restart the runtime after running this cell!

In [2]:
# default_exp inspect/any_model

# Any Model

> Inspector functionality for any model

In [3]:
# export
from typing import Optional

import pandas as pd
import sklearn.inspection
from fastcore.basics import basic_repr, store_attr
from matplotlib.axes import Axes
from model_inspector.delegate import delegates
from model_inspector.explore import plot_column_clusters, show_correlation
from sklearn.base import BaseEstimator
from sklearn.dummy import DummyClassifier, DummyRegressor
from sklearn.utils import check_X_y
from sklearn.utils.multiclass import type_of_target
from sklearn.utils.validation import check_is_fitted

In [9]:
# export
class _Inspector:
    """Model inspector base class

    Bundles a fitted estimator with the feature matrix `X` and target
    `y` that the inspection methods below operate on.

    Users should use `get_inspector` to generate appropriate
    `_Inspector` objects rather than instantiating this class or its
    subclasses directly.
    """

    def __init__(self, model: BaseEstimator, X: pd.DataFrame, y: pd.Series):
        """Validate the inputs and store them as attributes

        Parameters:
        - `model`: Fitted estimator; `check_is_fitted` raises if it has
        not been fitted
        - `X`: Feature matrix
        - `y`: Target values
        """
        check_is_fitted(model)
        check_X_y(X, y)
        # Dummy estimators are deliberately exempt from the feature-count
        # check against what the model saw when it was fitted.
        if not isinstance(model, (DummyClassifier, DummyRegressor)):
            model._check_n_features(X, reset=False)

        # Saves `model`, `X`, and `y` on `self` and records their names
        # in `self.__stored_args__` (used by `methods`).
        store_attr()

    __repr__ = basic_repr(["model"])

    @delegates(sklearn.inspection.PartialDependenceDisplay.from_estimator)
    def plot_dependence(self, *args, **kwargs) -> Axes:
        """Plot partial dependence

        Positional and keyword arguments are forwarded to
        `sklearn.inspection.PartialDependenceDisplay.from_estimator`
        after `self.model` and `self.X`, so the first positional
        argument is `features`.

        Returns the `axes_` attribute of the resulting display.
        """
        # `self.model` and `self.X` are passed positionally so that
        # caller-supplied positional arguments fill the parameters that
        # follow them instead of colliding with `estimator`/`X`.
        return sklearn.inspection.PartialDependenceDisplay.from_estimator(
            self.model, self.X, *args, **kwargs
        ).axes_

    @delegates(sklearn.inspection.permutation_importance)
    def permutation_importance(
        self,
        sort: bool = True,
        **kwargs,
    ) -> pd.Series:
        """Calculate permutation importance

        - `sort`: Sort features by decreasing importance
        - `kwargs`: Passed to
        `sklearn.inspection.permutation_importance`. `n_jobs` defaults
        to -1 unless specified.

        Returns a `pd.Series` of mean importances indexed by the
        columns of `self.X`.
        """
        kwargs = {"n_jobs": -1, **kwargs}

        # The function must be referenced through its module: the bare
        # name `permutation_importance` is not imported at module level,
        # and class attributes are not visible as bare names inside a
        # method body.
        result = sklearn.inspection.permutation_importance(
            self.model, self.X, self.y, **kwargs
        )
        importances = pd.Series(
            result["importances_mean"],
            index=self.X.columns,
        )
        if sort:
            importances = importances.sort_values(ascending=False)
        return importances

    def plot_permutation_importance(
        self,
        ax: Optional[Axes] = None,
        importance_kwargs: Optional[dict] = None,
        plot_kwargs: Optional[dict] = None,
    ) -> Axes:
        """Plot permutation importance as a horizontal bar chart

        Parameters:
        - `ax`: Matplotlib `Axes` object. Plot will be added to this object
        if provided; otherwise a new `Axes` object will be generated.
        - `importance_kwargs`: kwargs to pass to
        `sklearn.inspection.permutation_importance`
        - `plot_kwargs`: kwargs to pass to `pd.Series.plot.barh`
        """
        if importance_kwargs is None:
            importance_kwargs = {}
        # reversing the order to compensate for `barh` reversing it
        importance = self.permutation_importance(**importance_kwargs).iloc[::-1]

        # Copy so the caller's dict is never mutated.
        plot_kwargs = {} if plot_kwargs is None else dict(plot_kwargs)
        if ax is not None:
            # An explicitly supplied `ax` takes precedence over any `ax`
            # entry in `plot_kwargs`.
            plot_kwargs["ax"] = ax
        ax = importance.plot.barh(**plot_kwargs)
        ax.set(title="Feature importances")
        # The bars just drawn are the most recently added container; on
        # a reused `Axes` the first container may belong to another plot.
        ax.bar_label(ax.containers[-1], fmt="%.2f")
        # extending plot on the right to accommodate labels
        ax.set_xlim((ax.get_xlim()[0], ax.get_xlim()[1] * 1.05))
        return ax

    @delegates(show_correlation)
    def show_correlation(self, **kwargs) -> Axes:
        """Show a correlation matrix for `self.X` and `self.y`

        `self.X` and `self.y` are concatenated column-wise and passed,
        together with `kwargs`, to `model_inspector.explore.show_correlation`.

        If output is not rendering properly when you reopen a notebook,
        make sure the notebook is trusted.
        """
        # Inside the method body the bare name resolves to the imported
        # module-level function, not to this method.
        return show_correlation(
            df=pd.concat((self.X, self.y), axis="columns"),
            **kwargs,
        )

    @delegates(plot_column_clusters)
    def plot_feature_clusters(self, **kwargs) -> Axes:
        """Plot a dendrogram based on feature correlations

        Keyword arguments are passed to `plot_column_clusters` along
        with `self.X`:

        - `corr_method`: Method of correlation to pass to `df.corr()`
        - `ax`: Matplotlib `Axes` object. Plot will be added to this object
        if provided; otherwise a new `Axes` object will be generated.
        """
        return plot_column_clusters(self.X, **kwargs)

    @property
    def methods(self) -> list:
        """Show available methods

        Lists every name in `dir(self)` except dunder names, the stored
        constructor arguments, and `methods` itself.
        """
        return [
            i
            for i in dir(self)
            if not i.startswith("__")
            and i not in self.__stored_args__
            and i != "methods"
        ]

In [10]:
# export
# Export list written with single underscores (not `__all__`); it names the
# private `_Inspector` class explicitly.
# NOTE(review): presumably nbdev reads `_all_` in an `# export` cell and adds
# these names to the generated module's `__all__` — confirm against nbdev docs.
_all_ = ["_Inspector"]