# Extracting & Plotting Feature Names & Importance from Scikit-Learn Pipelines

If you have ever been tasked with productionalizing a machine learning model, you probably know that the Scikit-Learn library offers one of the best ways, if not the best way, of creating production-quality machine learning workflows. The ecosystem's [Pipeline](https://scikit-learn.org/stable/modules/generated/sklearn.pipeline.Pipeline.html), [ColumnTransformer](https://scikit-learn.org/stable/modules/generated/sklearn.compose.ColumnTransformer.html), [preprocessors](https://scikit-learn.org/stable/modules/classes.html#module-sklearn.preprocessing), [imputers](https://scikit-learn.org/stable/modules/classes.html#module-sklearn.impute) & [feature selection](https://scikit-learn.org/stable/modules/classes.html#module-sklearn.feature_selection) classes are powerful tools that transform raw data into model-ready features.

However, before anyone lets you deploy to production, you will want at least a minimal understanding of how the new model works. The most common way to explain a black-box model is to plot its feature names and importance values. If you have ever tried to extract the feature names from a heterogeneous dataset processed by a ColumnTransformer, you know that this is no easy task. Exhaustive Internet searches only turned up places where others have [asked](https://github.com/scikit-learn/scikit-learn/issues/6424) [the](https://github.com/scikit-learn/scikit-learn/pull/6431) [same](https://github.com/scikit-learn/scikit-learn/pull/12627) [question](https://github.com/scikit-learn/scikit-learn/pull/13307) or offered a [partial answer](https://github.com/scikit-learn/scikit-learn/issues/12525), rather than a comprehensive and satisfying solution.

To remedy this situation, I have developed a class called `FeatureImportance` that will extract feature names and importance values from a Pipeline instance. It then uses the Plotly library to plot the feature importance using only a few lines of code. In this post, I will load a fitted Pipeline, demonstrate how to use my class and then give an overview of how it works. The complete code can be found [here](https://www.kaggle.com/kylegilde/feature-importance) or at the end of this blog post.

There are two things I should note before continuing:

1. I credit Joey Gao's code on [this thread](https://github.com/scikit-learn/scikit-learn/issues/12525#issuecomment-436217100) with showing the way to tackle this problem.

2. My post assumes that you have worked with Scikit-Learn and Pandas before and are familiar with how ColumnTransformer, Pipeline & preprocessing classes facilitate reproducible feature engineering processes. If you need a refresher, check out this [Scikit-Learn example](https://scikit-learn.org/stable/auto_examples/compose/plot_column_transformer_mixed_types.html).

## Creating a Pipeline


For the purposes of demonstration, I've written a script called [fit_pipeline_ames.py](https://www.kaggle.com/kylegilde/fit-pipeline-ames). It loads the [Ames housing training data from Kaggle](https://www.kaggle.com/c/house-prices-advanced-regression-techniques/data) and fits a moderately complex Pipeline. The `pipe` instance contains the following 4 steps:

1. The [ColumnTransformer](https://scikit-learn.org/stable/modules/generated/sklearn.compose.ColumnTransformer.html) instance is composed of 3 Pipelines, containing a total of 4 transformer instances, including [SimpleImputer](https://scikit-learn.org/stable/modules/generated/sklearn.impute.SimpleImputer.html), [OneHotEncoder](https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.OneHotEncoder.html) & [GLMMEncoder](http://contrib.scikit-learn.org/category_encoders/glmm.html) from the [category_encoders](https://contrib.scikit-learn.org/category_encoders/) package. See my [previous blog post](https://towardsdatascience.com/building-columntransformers-dynamically-1-6354bd08aa54) for a full explanation of how I dynamically constructed this particular ColumnTransformer.

2. The [VarianceThreshold](https://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.VarianceThreshold.html) uses the default threshold of 0, which removes any features that contain only a single value. Some models will fail if a feature has no variance.

3. The [SelectPercentile](https://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.SelectPercentile.html) uses the [f_regression](https://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.f_regression.html) scoring function with a percentile threshold of 90. These settings retain the top 90% of features and discard the bottom 10%.

4. The [CatBoostRegressor](https://catboost.ai/docs/concepts/python-reference_catboostregressor.html) model is fit to the `SalePrice` dependent variable using the features created and selected in the preceding steps.
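The `fit_pipeline_ames.py` script itself is not reproduced in this post. As a rough, self-contained sketch of the same four-step shape, the Pipeline below is fit on a small, randomly generated DataFrame whose column names only imitate Ames columns. It assumes scikit-learn alone: `RandomForestRegressor` stands in for `CatBoostRegressor` and the `GLMMEncoder` branch is left out, so treat it as an illustration rather than the author's script.

```python
import numpy as np
import pandas as pd
from sklearn.compose import ColumnTransformer, make_column_selector
from sklearn.ensemble import RandomForestRegressor
from sklearn.feature_selection import SelectPercentile, VarianceThreshold, f_regression
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline, make_pipeline
from sklearn.preprocessing import OneHotEncoder

# Randomly generated stand-in for the Ames training data (not real house prices)
rng = np.random.default_rng(0)
n = 200
X = pd.DataFrame({
    "GrLivArea": rng.normal(1500, 400, n),
    "LotArea": rng.normal(10000, 3000, n),
    "Constant": np.ones(n),  # a single value, so VarianceThreshold should drop it
    "Neighborhood": rng.choice(["A", "B", "C"], n),
})
X.loc[::10, "LotArea"] = np.nan  # missing values, so add_indicator adds a flag column
y = 100 * X["GrLivArea"] + rng.normal(0, 1e4, n)

# Step 1: a ColumnTransformer built from Pipelines (numeric + one-hot branches)
column_transformer = ColumnTransformer([
    ("numeric_pipeline",
     make_pipeline(SimpleImputer(strategy="median", add_indicator=True)),
     make_column_selector(dtype_include="number")),
    ("oh_pipeline",
     make_pipeline(SimpleImputer(strategy="constant", fill_value="missing"),
                   OneHotEncoder(handle_unknown="ignore")),
     make_column_selector(dtype_exclude="number")),
])

pipe = Pipeline([
    ("column_transformer", column_transformer),
    ("variance_threshold", VarianceThreshold()),  # step 2: default threshold=0
    ("select_percentile", SelectPercentile(f_regression, percentile=90)),  # step 3
    ("model", RandomForestRegressor(n_estimators=50, random_state=0)),  # step 4: stand-in for CatBoost
])
pipe.fit(X, y)
```

Because `Constant` holds a single value, `VarianceThreshold` removes it before `SelectPercentile` ranks the remaining columns with `f_regression`; the model then only ever sees the columns that survived both selectors, which is exactly why the feature names have to be tracked through every step.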

```python
from fit_pipeline_ames import *
pipe
```

```
Pipeline(steps=[('column_transformer',
                 ColumnTransformer(n_jobs=4,
                                   transformers=[('numeric_pipeline',
                                                  Pipeline(steps=[('simpleimputer',
                                                                   SimpleImputer(add_indicator=True,
                                                                                 strategy='median'))]),
                                                  <sklearn.compose._column_transformer.make_column_selector object at 0x7fd44216a9d0>),
                                                 ('oh_pipeline',
                                                  Pipeline(steps=[('simpleimputer',
                                                                   SimpleImputer(strategy='constant'...
                                                  <function select_oh_features at 0x7fd43ff814d0>),
                                                 ('hc_pipeline',
```

## Plotting FeatureImportance


With the help of FeatureImportance, we can extract the feature names and importance values and plot them with 3 lines of code.

```python
from feature_importance import FeatureImportance
feature_importance = FeatureImportance(pipe)
feature_importance.plot(top_n_features=25)
```

The `plot` method takes a number of arguments that control the plot's display. The most important ones are the following:

- `top_n_features`: This controls how many features will be plotted. The default value is 100. The plot's title will indicate this value as well as how many features there are in total. To plot all features, just set `top_n_features` to a number at least as large as the total number of features.

- `rank_features`: This argument controls whether the integer ranks are displayed in front of the feature names. The default is `True`. I find that this aids interpretation, especially when comparing the feature importance from multiple models.

- `max_scale`: This determines whether the importance values are divided by the maximum absolute value & multiplied by 100. The default is `True`. I find that this offers an intuitive way to compare how important other features are vis-à-vis the most important one. For instance, in the plot above, we can say that `GrLivArea` is about 81% as important to the model as the top feature, `OverallQual`.
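To make the `max_scale` and `rank_features` arithmetic concrete, here is a minimal sketch that mirrors what `plot` does to the data before drawing, without the string padding or the Plotly call. The importance values are hypothetical, chosen only so that the second feature lands at 81% of the first.

```python
import pandas as pd

# Hypothetical importance values, not taken from the Ames model
importances = pd.Series(
    {"OverallQual": 25.0, "GrLivArea": 20.25, "GarageCars": 5.0, "LotArea": 2.5}
)
top_n_features = 3

plot_df = (
    importances.nlargest(top_n_features)  # keep the top_n_features largest values
    .sort_values()                        # ascending: Plotly draws the first category at the bottom
    .to_frame("value")
    .rename_axis("feature")
    .reset_index()
)

# max_scale=True: divide by the largest absolute value and multiply by 100
plot_df["value"] = plot_df["value"].abs() / plot_df["value"].abs().max() * 100

# rank_features=True: the last row holds the largest value, so it gets rank 1
ranks = range(len(plot_df), 0, -1)
plot_df["feature"] = [f"{rank}. {name}" for rank, name in zip(ranks, plot_df["feature"])]

print(plot_df)  # values are 20.0, 81.0, 100.0 (up to float rounding)
```

After scaling, the top feature always reads 100, so every other bar can be read directly as a percentage of it.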

## How It Works

The `FeatureImportance` class should be instantiated using a fitted Pipeline instance. (You can also set the `verbose` argument to `True` if you want all of the diagnostics printed to your console.) My class validates that this Pipeline starts with a `ColumnTransformer` instance and ends with a regression or classification model that has the `feature_importances_` attribute. As intermediate steps, the Pipeline can have any number (including zero) of instances of classes from [sklearn.feature_selection](https://scikit-learn.org/stable/modules/feature_selection.html).
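A minimal sketch of those structural checks, written as a hypothetical standalone helper `validate_pipeline` that is not part of `FeatureImportance`. It raises exceptions instead of using `assert` (a choice made for this sketch, since `assert` statements are skipped when Python runs with `-O`), and the toy Pipeline uses a `StandardScaler`-only ColumnTransformer and a `RandomForestRegressor` purely as stand-ins.

```python
import numpy as np
from sklearn.compose import ColumnTransformer
from sklearn.ensemble import RandomForestRegressor
from sklearn.feature_selection import VarianceThreshold
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.utils.validation import check_is_fitted


def validate_pipeline(pipe):
    """Hypothetical helper mirroring the checks described above."""
    if not isinstance(pipe, Pipeline):
        raise TypeError("Input isn't a Pipeline")
    if not isinstance(pipe[0], ColumnTransformer):
        raise TypeError("The first step isn't a ColumnTransformer")
    check_is_fitted(pipe[0])  # raises NotFittedError if the Pipeline was never fit
    if not hasattr(pipe[-1], "feature_importances_"):
        raise TypeError("The last step has no feature_importances_ attribute")
    # Intermediate steps only matter if they are selectors exposing get_support()
    return [name for name, step in pipe.steps[1:-1] if hasattr(step, "get_support")]


X = np.random.default_rng(0).normal(size=(50, 3))
y = X[:, 0] + 0.1 * X[:, 1]
pipe = Pipeline([
    ("column_transformer", ColumnTransformer([("scale", StandardScaler(), [0, 1, 2])])),
    ("variance_threshold", VarianceThreshold()),
    ("model", RandomForestRegressor(n_estimators=10, random_state=0)),
]).fit(X, y)

print(validate_pipeline(pipe))  # ['variance_threshold']
```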

The `FeatureImportance` class is composed of 4 methods.

1. `get_feature_names_from_col_transformer` was the hardest method to devise. It iterates through the fitted `ColumnTransformer`'s `transformers_`, uses the `hasattr` function to discern what type of class we are dealing with and pulls the feature names accordingly. (Special Note: If the ColumnTransformer contains Pipelines and if one of the transformers in the Pipeline is adding completely new columns, it must come last in the pipeline. For example, OneHotEncoder, [MissingIndicator](https://scikit-learn.org/stable/modules/generated/sklearn.impute.MissingIndicator.html) & SimpleImputer(add_indicator=True) add columns to the dataset that didn't exist before, so they should come last in the Pipeline.)

2. `get_selected_features` calls `get_feature_names_from_col_transformer`. Then it checks whether the main Pipeline contains any classes from sklearn.feature_selection based upon the existence of the `get_support` method. If it does, this method returns only the feature names that were retained by the selector class or classes.

3. `get_feature_importance` calls `get_selected_features` and then creates a Pandas Series whose values are the model's feature importance values and whose index is the feature names created by the first 2 methods. This Series is then stored in the `feature_importance` attribute.

4. `plot` calls `get_feature_importance` and plots the output based upon the specifications.
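The core of the first three methods can be condensed into a short sketch. It assumes a hypothetical, numeric-only Pipeline whose ColumnTransformer holds a single `SimpleImputer(add_indicator=True)` (so the only new column is a missing-value flag), followed by a `SelectKBest` selector; OneHotEncoder is left out to keep the sketch short. The `_missing_flag` suffix follows the naming convention used in the complete code below.

```python
import numpy as np
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.ensemble import RandomForestRegressor
from sklearn.feature_selection import SelectKBest, f_regression
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline, make_pipeline

# Hypothetical toy data: only column "b" has missing values
rng = np.random.default_rng(0)
X = pd.DataFrame(rng.normal(size=(100, 3)), columns=["a", "b", "c"])
X.loc[::5, "b"] = np.nan
y = 3 * X["a"] + rng.normal(size=100)

pipe = Pipeline([
    ("column_transformer", ColumnTransformer([
        # the column-adding imputer is the last (here, only) step of its Pipeline
        ("num", make_pipeline(SimpleImputer(add_indicator=True)), ["a", "b", "c"]),
    ])),
    ("select", SelectKBest(f_regression, k=2)),
    ("model", RandomForestRegressor(n_estimators=50, random_state=0)),
]).fit(X, y)

# 1. Names coming out of the ColumnTransformer
names = []
for _, transformer, columns in pipe[0].transformers_:
    if isinstance(transformer, str) and transformer == "drop":
        continue  # dropped columns produce no output
    if isinstance(transformer, Pipeline):
        transformer = transformer.steps[-1][1]  # the last step defines the output columns
    columns = list(columns)
    if getattr(transformer, "add_indicator", False):
        # SimpleImputer appends one flag per column that had missing values at fit time
        flags = [columns[idx] + "_missing_flag" for idx in transformer.indicator_.features_]
        names += columns + flags
    else:
        names += columns

# 2. Apply every selector's boolean mask, in pipeline order
selected, discarded = names, []
for _, step in pipe.steps:
    if hasattr(step, "get_support"):
        mask = step.get_support()
        discarded += [n for n, keep in zip(selected, mask) if not keep]
        selected = [n for n, keep in zip(selected, mask) if keep]

# 3. Pair the surviving names with the model's importances
importance = pd.Series(pipe[-1].feature_importances_, index=selected).sort_values(ascending=False)
print(importance)
```

Computing `discarded` before reassigning `selected` matters: both comprehensions must zip the same, unfiltered list against the mask.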

## Complete Code

The complete code is shown below and can also be found [here](https://www.kaggle.com/kylegilde/feature-importance). If you create a Pipeline that you believe should be supported by FeatureImportance but is not, please provide a reproducible example, and I will consider making the necessary changes.

The original notebook for this blog post can be found [here](https://www.kaggle.com/kylegilde/extracting-scikit-feature-names-importances). Stay tuned for further posts on training & regularizing models with Scikit-Learn ColumnTransformers and Pipelines. Let me know if you found this post helpful or have any ideas for improvement. Thanks!

```python
import numpy as np
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.utils.validation import check_is_fitted
import plotly.express as px


class FeatureImportance:

    """
    Extract & Plot the Feature Names & Importance Values from a Scikit-Learn Pipeline.

    The input is a Pipeline that starts with a ColumnTransformer & ends with a regression or classification model.
    As intermediate steps, the Pipeline can have any number (including zero) of instances from sklearn.feature_selection.

    Note:
    If the ColumnTransformer contains Pipelines and if one of the transformers in the Pipeline is adding completely new columns,
    it must come last in the pipeline. For example, OneHotEncoder, MissingIndicator & SimpleImputer(add_indicator=True) add columns
    to the dataset that didn't exist before, so they should come last in the Pipeline.

    Parameters
    ----------
    pipeline : a fitted Scikit-Learn Pipeline where a ColumnTransformer is the first element and the model estimator is the last element
    verbose : a boolean. Whether to print all of the diagnostics. Default is False.

    Attributes
    ----------
    feature_importance : A Pandas Series containing the feature importance values and feature names as the index.
    discarded_features : The feature names that were not selected by a sklearn.feature_selection instance.
    plot_importances_dt : A Pandas DataFrame containing the subset of features and values that are actually displayed in the plot.
    """
    def __init__(self, pipeline, verbose=False):
        self.pipeline = pipeline
        self.verbose = verbose


    def get_feature_names_from_col_transformer(self, verbose=None):

        """
        Get the column names from a ColumnTransformer containing transformers & pipelines

        Parameters
        ----------
        verbose : a boolean indicating whether to print summaries.
            default = False

        Returns
        -------
        a list of the correct feature names

        Note:
        If the ColumnTransformer contains Pipelines and if one of the transformers in the Pipeline is adding completely new columns,
        it must come last in the pipeline. For example, OneHotEncoder, MissingIndicator & SimpleImputer(add_indicator=True) add columns
        to the dataset that didn't exist before, so they should come last in the Pipeline.

        Inspiration: https://github.com/scikit-learn/scikit-learn/issues/12525
        """

        if verbose is None:
            verbose = self.verbose

        if verbose: print('\n\n---------\nRunning get_feature_names_from_col_transformer\n---------\n')

        column_transformer = self.pipeline[0]
        assert isinstance(column_transformer, ColumnTransformer), "Input isn't a ColumnTransformer"
        check_is_fitted(column_transformer)

        new_feature_names = []

        for i, transformer_item in enumerate(column_transformer.transformers_):

            transformer_name, transformer, orig_feature_names = transformer_item
            orig_feature_names = list(orig_feature_names)

            if verbose:
                print('\n\n', i, '. Transformer/Pipeline: ', transformer_name, ',',
                      transformer.__class__.__name__, '\n')
                print('\tn_orig_feature_names:', len(orig_feature_names))

            # dropped columns (including a dropped remainder) produce no output
            if isinstance(transformer, str) and transformer == 'drop':
                continue

            if isinstance(transformer, Pipeline):
                # if pipeline, get the last transformer in the Pipeline
                transformer = transformer.steps[-1][1]

            if hasattr(transformer, 'get_feature_names'):

                if 'input_features' in transformer.get_feature_names.__code__.co_varnames:
                    names = list(transformer.get_feature_names(orig_feature_names))
                else:
                    names = list(transformer.get_feature_names())

            elif hasattr(transformer, 'indicator_') and transformer.add_indicator:
                # is this transformer one of the imputers & did it call the MissingIndicator?
                missing_indicator_indices = transformer.indicator_.features_
                missing_indicators = [orig_feature_names[idx] + '_missing_flag'
                                      for idx in missing_indicator_indices]
                names = orig_feature_names + missing_indicators

            elif hasattr(transformer, 'features_'):
                # is this a MissingIndicator class? It outputs only the indicator columns
                missing_indicator_indices = transformer.features_
                names = [orig_feature_names[idx] + '_missing_flag'
                         for idx in missing_indicator_indices]

            else:
                names = orig_feature_names

            if verbose:
                print('\tn_new_features:', len(names))
                print('\tnew_features:\n', names)

            new_feature_names.extend(names)

        return new_feature_names


    def get_selected_features(self, verbose=None):
        """
        Get the Feature Names that were retained after Feature Selection (sklearn.feature_selection)

        Parameters
        ----------
        verbose : a boolean indicating whether to print summaries. default = False

        Returns
        -------
        a list of the correct feature names
        """

        if verbose is None:
            verbose = self.verbose

        assert isinstance(self.pipeline, Pipeline), "Input isn't a Pipeline"

        features = self.get_feature_names_from_col_transformer(verbose=verbose)

        if verbose: print('\n\n---------\nRunning get_selected_features\n---------\n')

        all_discarded_features = []

        for i, step_item in enumerate(self.pipeline.steps):

            step_name, step = step_item

            if hasattr(step, 'get_support'):

                if verbose: print('\nStep ', i, ": ", step_name, ',',
                                  step.__class__.__name__, '\n')

                check_is_fitted(step)

                feature_mask = step.get_support()
                # compute the discarded names before `features` is overwritten,
                # so both lists are zipped against the same, unfiltered names
                discarded_features = [feature for feature, is_retained in zip(features, feature_mask)
                                      if not is_retained]
                features = [feature for feature, is_retained in zip(features, feature_mask)
                            if is_retained]
                all_discarded_features.extend(discarded_features)

                if verbose:
                    print(f'\t{len(features)} retained, {len(discarded_features)} discarded')
                    if len(discarded_features) > 0:
                        print('\n\tdiscarded_features:\n\n', discarded_features)

        self.discarded_features = all_discarded_features

        return features

    def get_feature_importance(self):

        """
        Creates a Pandas Series where values are the feature importance values from the model and feature names are set as the index.

        This Series is stored in the `feature_importance` attribute.

        Returns
        -------
        A pandas Series containing the feature importance values and feature names as the index.
        """

        assert isinstance(self.pipeline, Pipeline), "Input isn't a Pipeline"

        features = self.get_selected_features()

        assert hasattr(self.pipeline[-1], 'feature_importances_'),\
            "The last element in the pipeline isn't an estimator with a feature_importances_ attribute"

        importance_values = self.pipeline[-1].feature_importances_

        assert len(features) == len(importance_values),\
            "The number of feature names & importance values doesn't match"

        feature_importance = pd.Series(importance_values, index=features)
        self.feature_importance = feature_importance

        return feature_importance


    def plot(self, top_n_features=100, rank_features=True, max_scale=True,
             display_imp_values=True, display_imp_value_decimals=1,
             height_per_feature=25, orientation='h', width=750, height=None,
             str_pad_width=15, yaxes_tickfont_family='Courier New',
             yaxes_tickfont_size=15):
        """
        Plot the Feature Names & Importances

        Parameters
        ----------
        top_n_features : the number of features to plot, default is 100
        rank_features : whether to rank the features with integers, default is True
        max_scale : Should the importance values be scaled by the maximum value & multiplied by 100? Default is True.
        display_imp_values : Should the importance values be displayed? Default is True.
        display_imp_value_decimals : If display_imp_values is True, how many decimal places should be displayed. Default is 1.
        height_per_feature : if height is None, the plot height is calculated by top_n_features * height_per_feature.
            This allows all the features enough space to be displayed
        orientation : the plot orientation, 'h' (default) or 'v'
        width : the width of the plot, default is 750
        height : the height of the plot, the default is top_n_features * height_per_feature
        str_pad_width : When rank_features=True, the width that each feature name is left-padded to with spaces,
            which puts space between the rank integer and the feature name.
            This will enable the rank integers to line up with each other for easier reading.
            Default is 15. If you have long feature names, you can increase this number to make the integers line up more.
            It can also be set to 0.
        yaxes_tickfont_family : the font for the feature names. Default is Courier New.
        yaxes_tickfont_size : the font size for the feature names. Default is 15.

        Returns
        -------
        None. The Plotly figure is displayed with fig.show().
        """
        if height is None:
            height = top_n_features * height_per_feature

        # prep the data

        all_importances = self.get_feature_importance()
        n_all_importances = len(all_importances)

        plot_importances_dt =\
            all_importances\
            .sort_values()\
            .to_frame('value')\
            .nlargest(top_n_features, 'value')\
            .sort_values('value', ascending=True)\
            .rename_axis('feature')\
            .reset_index()

        if max_scale:
            plot_importances_dt['value'] = \
                                plot_importances_dt.value.abs() /\
                                plot_importances_dt.value.abs().max() * 100

        self.plot_importances_dt = plot_importances_dt.copy()

        if n_all_importances <= top_n_features:
            title_text = 'All Feature Importances'
        else:
            title_text = f'Top {top_n_features} (of {n_all_importances}) Feature Importances'

        if rank_features:
            padded_features = \
                plot_importances_dt.feature\
                .str.pad(width=str_pad_width)\
                .values

            ranked_features =\
                plot_importances_dt.index\
                .to_series()\
                .sort_values(ascending=False)\
                .add(1)\
                .astype(str)\
                .str.cat(padded_features, sep='. ')\
                .values

            plot_importances_dt['feature'] = ranked_features

        if display_imp_values:
            text = plot_importances_dt.value.round(display_imp_value_decimals)
        else:
            text = None

        # create the plot

        fig = px.bar(plot_importances_dt,
                     x='value',
                     y='feature',
                     orientation=orientation,
                     width=width,
                     height=height,
                     text=text)
        fig.update_layout(title_text=title_text, title_x=0.5)
        fig.update(layout_showlegend=False)
        fig.update_yaxes(tickfont=dict(family=yaxes_tickfont_family,
                                       size=yaxes_tickfont_size),
                         title='')
        fig.show()
```