In [None]:
# default_exp sklearn

In [None]:
# export
from typing import Iterable

In [3]:
from IPython.display import HTML
import sklearn.datasets
from sklearn.linear_model import LinearRegression, LogisticRegression

In [4]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2
%load_ext lab_black

# sklearn

> Inspect scikit-learn models.

In [5]:
# hide
from nbdev.showdoc import *

In [6]:
# export
def generate_linear_model_html(
    model,
    feature_names: Iterable[str],
    target_name: str,
    intercept_formatter: str = ".2f",
    coef_formatter: str = ".2f",
) -> str:
    """Generate an HTML equation that characterizes a linear model

    Model components are color-coded as follows:
    - target: red
    - intercept: purple
    - coefficients: green
    - features: blue

    Names are inserted into the markup verbatim; escape them beforehand
    if they may contain HTML special characters.

    Parameters
    ----------
    model
        Fitted scikit-learn linear model of the form
        `y = b0 + b1 * x1 + ...`; only its `coef_` and `intercept_`
        attributes are read
    feature_names : Iterable[str]
        Feature names in the order in which they were given to the model.
        Any iterable is accepted, including generators.
    target_name : str
        Name of target variable `y`
    intercept_formatter : str, optional
        Format specifier for model intercept
    coef_formatter : str, optional
        Format specifier for model coefficients

    Returns
    -------
    str
        HTML fragment with one colored `<span>` per model component

    Raises
    ------
    ValueError
        If the number of feature names differs from the number of
        coefficients
    """
    # Materialize once: the annotation promises any iterable, but an
    # unsized one (e.g. a generator) cannot be passed to `len`.
    feature_names = list(feature_names)
    if len(model.coef_) != len(feature_names):
        raise ValueError(
            "len(model.coef_) != len(feature_names) "
            f"({len(model.coef_)} != {len(feature_names)})"
        )
    model_string = f"""
        <span style='color:red'>{target_name}</span>
        = <span style='color:purple'>{model.intercept_:{intercept_formatter}}</span>
    """
    for coef, feature_name in zip(model.coef_, feature_names):
        # Render the sign as a separate operator so terms read "a + b - c".
        sign = "+" if coef >= 0 else "-"
        model_string += f"""
            <span style='color:green'>{sign} {abs(coef):{coef_formatter}}</span>
            * <span style='color:blue'>{feature_name}</span>
        """
    return model_string

In [7]:
# Demo: fit a LinearRegression on the diabetes dataset and render its equation.
diabetes = sklearn.datasets.load_diabetes()
X = diabetes["data"]
y = diabetes["target"]

# The HTML object is the cell's last expression, so the notebook displays it.
HTML(
    generate_linear_model_html(
        LinearRegression().fit(X, y),
        feature_names=diabetes["feature_names"],
        target_name="progression",
    )
)

In [8]:
# export
def generate_logistic_model_html(
    model,
    feature_names: Iterable[str],
    target_names: Iterable[str],
    intercept_formatter: str = ".2f",
    coef_formatter: str = ".2f",
) -> str:
    """Generate an HTML equation that characterizes a logistic
    regression model

    One `<p>` element is produced per row of `model.coef_`.

    Model components are color-coded as follows:
    - target: red
    - intercept: purple
    - coefficients: green
    - features: blue

    Names are inserted into the markup verbatim; escape them beforehand
    if they may contain HTML special characters.

    Parameters
    ----------
    model
        Fitted scikit-learn linear model of the form
        `log-odds(y) = b0 + b1 * x1 + ...`; only its `coef_` and
        `intercept_` attributes are read
    feature_names : Iterable[str]
        Feature names in the order in which they were given to the model.
        Any iterable is accepted, including generators.
    target_names : Iterable[str]
        Names of the values of the target variable, in class order.
        For a binary model (a single coefficient row) given two names,
        the equation is labelled with the second name, because that is
        the class the single row describes.
    intercept_formatter : str, optional
        Format specifier for model intercept
    coef_formatter : str, optional
        Format specifier for model coefficients

    Returns
    -------
    str
        HTML fragment with one `<p>...</p>` equation per coefficient row

    Raises
    ------
    ValueError
        If any coefficient row's length differs from the number of
        feature names
    """
    # Materialize once: the annotations promise any iterable, but an
    # unsized one (e.g. a generator) cannot be passed to `len`, and
    # `feature_names` is traversed once per coefficient row.
    feature_names = list(feature_names)
    target_names = list(target_names)
    for coefs in model.coef_:
        if len(coefs) != len(feature_names):
            raise ValueError(
                "len(coefs) != len(feature_names) "
                f"({len(coefs)} != {len(feature_names)})"
            )
    # scikit-learn keeps a single coefficient row for binary problems and
    # that row gives the log-odds of the second class; pairing it with the
    # first name would label the equation with the wrong class.
    if len(model.coef_) == 1 and len(target_names) == 2:
        target_names = target_names[1:]
    model_string = ""
    for target_name, coefs, intercept in zip(
        target_names, model.coef_, model.intercept_
    ):
        # Open a paragraph per equation so every `</p>` below has a match.
        model_string += "<p>"
        model_string += f"""
            <span style='color:red'>log-odds({target_name})</span>
            = <span style='color:purple'>{intercept:{intercept_formatter}}</span>
        """
        for coef, feature_name in zip(coefs, feature_names):
            # Render the sign as a separate operator so terms read "a + b - c".
            sign = "+" if coef >= 0 else "-"
            model_string += f"""
                <span style='color:green'>{sign} {abs(coef):{coef_formatter}}</span>
                * <span style='color:blue'>{feature_name}</span>
            """
        model_string += "</p>"
    return model_string

In [9]:
# Demo: fit a LogisticRegression on the iris dataset and render its equations.
iris = sklearn.datasets.load_iris()
X = iris["data"]
y = iris["target"]

# The HTML object is the cell's last expression, so the notebook displays it.
HTML(
    generate_logistic_model_html(
        LogisticRegression(max_iter=1_000).fit(X, y),
        feature_names=iris["feature_names"],
        target_names=iris["target_names"],
    )
)

In [10]:
# Import only the name this cell uses instead of star-importing the module,
# so the notebook namespace is not silently filled with nbdev internals.
from nbdev.export import notebook2script

# Run nbdev's notebook-to-script conversion for the project's notebooks.
notebook2script()

Converted 00_sklearn.ipynb.
Converted index.ipynb.
