
Feature selection in EBMs via post-processing with the LASSO #460

Closed · brandongreenwell-8451 opened this issue Aug 4, 2023 · 21 comments

brandongreenwell-8451 (Contributor) commented Aug 4, 2023

It's possible to produce a comparably accurate, but much leaner, EBM by using the LASSO to effectively zero out less relevant terms (both main and interaction effects) and reweight the remaining contributions with simple coefficients. The more general idea is discussed in Friedman and Popescu (2003), who refer to the approach as "importance sampled learning ensembles" (ISLE); it is also briefly discussed in Chapter 16 of The Elements of Statistical Learning. The basic idea is to use the LASSO to post-process a tree ensemble in the hope of producing a much smaller model that's faster at prediction time without sacrificing much in the way of accuracy, and in some cases even improving it.
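To make the generic recipe concrete, here is a minimal ISLE-style sketch in Python (my own illustration, not code from the paper; it assumes scikit-learn and generic in-memory training arrays X_trn and y_trn):

import numpy as np
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.linear_model import LassoCV

# Fit a stand-in tree ensemble
gbm = GradientBoostingRegressor(n_estimators=500, random_state=0)
gbm.fit(X_trn, y_trn)

# Use each individual tree's predictions as the columns of a new design matrix
tree_preds = np.column_stack(
    [tree[0].predict(X_trn) for tree in gbm.estimators_]
)

# Post-process with the LASSO: zero out less useful trees and reweight the rest
lasso = LassoCV(cv=5).fit(tree_preds, y_trn)
np.count_nonzero(lasso.coef_)  # number of trees retained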

I think the idea would work reasonably well for EBMs too; in that case, we can apply the same idea to the individual shape functions (or term contributions). Below is an example in R (using reticulate to call the interpret library directly and fit an EBM); this can certainly be done in Python as well, but I'm less familiar with sklearn's Lasso module.

library(reticulate)  # to interact with Python's interpret library
library(glmnet)

#
# User needs to make sure reticulate is set up properly and that the interpret library is available.
# For details on reticulate, see https://rstudio.github.io/reticulate/.
#

# Load the interpret module
interpret <- import("interpret")

# Read in the ALS data
url <- "https://web.stanford.edu/~hastie/CASI_files/DATA/ALS.txt"
als <- read.table(url, header = TRUE)

# Split into train/test sets
trn <- als[!als$testset, -1]  # training data w/o testset column
tst <- als[als$testset, -1]  # test data w/o testset column
X_trn <- subset(trn, select = -dFRS)
X_tst <- subset(tst, select = -dFRS)
y_trn <- trn$dFRS
y_tst <- tst$dFRS

# Fit a basic EBM regressor (by calling the Python interpret library via reticulate)
EBR <- interpret$glassbox$ExplainableBoostingRegressor
ebm <- EBR(inner_bags = 25L, outer_bags = 25L)
ebm$fit(X_trn, y = y_trn)

# Mean squared error function
mse <- function(y, yhat, na.rm = FALSE) {
  mean((y - yhat) ^ 2, na.rm = na.rm)
}

# Compute test MSE from the original model
mse(y_tst, yhat = ebm$predict(X_tst))
# [1] 0.2654716

# Function to grab matrix of individual term contributions (no intercept)
predict_terms <- function(object, newdata) {
  contrib <- object$predict_and_contrib(newdata)[[2L]]  # grab the second component
  colnames(contrib) <- object$term_names_  # add column names
  contrib  # Note: rowSums(contrib) + object$intercept_ == object$predict(newdata)
}

# Compute matrix of individual term contributions for the train and test sets
contrib_trn <- predict_terms(ebm, newdata = X_trn)
contrib_tst <- predict_terms(ebm, newdata = X_tst)

# Fit the LASSO regularization path using the term contributions as inputs
ebm_lasso <- glmnet(
  x = contrib_trn,      # individual term contributions
  y = y_trn,            # original response variable
  lower.limits = 0,     # coefficients should be non-negative
  standardize = FALSE,  # no need to standardize
  family = "gaussian"   # least squares regression
)

# Assess performance of fit using an independent test set
perf <- assess.glmnet(
  object = ebm_lasso,   # fitted LASSO model
  newx = contrib_tst,  # test set contributions
  newy = y_tst,        # same response variable (test set)
  family = "gaussian"  # for MSE and MAE metrics
)
perf <- do.call(cbind, args = perf)  # bind results into matrix

# Data frame of results (one row for each value of lambda)
ebm_lasso_results <- as.data.frame(cbind(
  "num_terms" = ebm_lasso$df,  # number of non-zero coefficients for each lambda
  perf,  # performance metrics (i.e., MSE and MAE for each lambda)
  "lambda" = ebm_lasso$lambda
))

# Sort in ascending order of the number of terms
head(ebm_lasso_results <- 
       ebm_lasso_results[order(ebm_lasso_results$num_terms), ])
#    num_terms       mse       mae      lambda
# s0         0 0.3203724 0.4583318 0.004549271
# s1         2 0.3123229 0.4522811 0.004145126
# s2         3 0.3053020 0.4462072 0.003776885
# s3         5 0.2995632 0.4409205 0.003441357
# s6         6 0.2859880 0.4283262 0.002603260
# s4         7 0.2945709 0.4363560 0.003135636

# Print results corresponding to smallest test MSE
ebm_lasso_results[which.min(ebm_lasso_results$mse), ]
#     num_terms       mse       mae       lambda
# s27        20 0.2665173 0.4046484 0.0003690054

# Grab "best" lambda
lambda <- ebm_lasso_results[which.min(ebm_lasso_results$mse), "lambda"]

# Plot results (left: LASSO regularization path; right: test MSE vs. number of terms)
par(mfrow = c(1, 2))
palette("Okabe-Ito")
plot(ebm_lasso, xvar = "lambda", col = adjustcolor(3, alpha.f = 0.3))
abline(v = log(lambda), lty = 2, col = 1)
plot(ebm_lasso_results[, c("num_terms", "mse")], type = "l", las = 1,
     xlab = "Number of terms", ylab = "Test MSE")
abline(h = min(ebm_lasso_results$mse), col = 1, lty = 2)
abline(h = mse(y_tst, yhat = ebm$predict(X_tst)), lty = 2, col = 2)
legend("topright", inset = 0.03, bty = "n", col = c(2, 1), lty = 2,
       legend = c("Original (379 terms)", "Post LASSO (20 terms)"), cex = 0.8)
palette("default")

[image: LASSO regularization path (left) and test MSE versus number of terms (right)]

Compare this to Figure 17.13 in Computer Age Statistical Inference (free to read online), where they applied the same idea to these data, but with a GBM instead.

Thoughts on the idea?

EDIT: adding a pure Python example

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from interpret.glassbox import ExplainableBoostingRegressor as EBR
from sklearn.base import is_classifier
from sklearn.metrics import mean_squared_error
from sklearn.linear_model import Lasso, lasso_path


# Read in ALS data (analysis adapted from "Computer Age Statistical Inference")
url = 'https://hastie.su.domains/CASI_files/DATA/ALS.txt'
als = pd.read_csv(url, delim_whitespace=True)
als.head()

# Split into train/test (and drop test set indicator)
als_trn = als[als['testset'] == False].drop('testset', axis=1)
als_tst = als[als['testset'] == True].drop('testset', axis=1)

# Separate predictors from response
X_trn = als_trn.drop('dFRS', axis=1)
X_tst = als_tst.drop('dFRS', axis=1)
y_trn = als_trn['dFRS']
y_tst = als_tst['dFRS']

# Fit an EBM
ebm = EBR(inner_bags=25, outer_bags=25)
ebm.fit(X_trn, y=y_trn)
#show(ebm.explain_global())

# Compute test MSE
mean_squared_error(y_tst, ebm.predict(X_tst))

# Function to get individual term contributions from a fitted EBM
def get_term_contributions(fit, X, as_frame=True, intercept=False):
    if is_classifier(fit) and 2 < len(fit.classes_):
        raise ValueError("multiclass models are not currently supported")
    contrib = fit.predict_and_contrib(X)[1]
    if as_frame:
        contrib = pd.DataFrame(contrib, columns=fit.term_names_)
        if intercept:
            contrib['intercept'] = fit.intercept_
    return contrib

# Quick sanity check (term contributions + intercept should sum to the final
# prediction on the inverse link scale)
a = get_term_contributions(ebm, X_tst, intercept=True).sum(axis=1).values
b = ebm.predict(X_tst)
np.allclose(a, b)  # should be True

# Get term contributions for train and test sets; used for inputs into LASSO model
X_trn_tc = get_term_contributions(ebm, X=X_trn, intercept=False)
X_tst_tc = get_term_contributions(ebm, X=X_tst, intercept=False)

# Fit LASSO path and grab alpha values
ebm_lasso_path = lasso_path(X_trn_tc, y=y_trn, positive=True)
alphas = ebm_lasso_path[0]

# Fit a LASSO for each alpha; compute test MSE and number of non-zero coefs
mses = []
nterms = []
for alpha in alphas:
    _lasso = Lasso(alpha=alpha, positive=True)
    _lasso.fit(X_trn_tc, y=y_trn)
    _mse = mean_squared_error(y_tst, y_pred=_lasso.predict(X_tst_tc))
    mses.append(_mse)
    nterms.append(np.count_nonzero(_lasso.coef_))

# Grab alpha corresponding to smallest test error
alpha_best = alphas[np.argmin(mses)]

# Plot results
sns.scatterplot(x=nterms, y=mses)
plt.show()

# Fit a LASSO using the "best" alpha 
ebm_lasso = Lasso(alpha=alpha_best, positive=True)
ebm_lasso.fit(X_trn_tc, y=y_trn)
np.count_nonzero(ebm_lasso.coef_)  # number non-zero coefficients
# 20

# Compare intercepts
ebm.intercept_, ebm_lasso.intercept_
# (-0.6800701947503642, -0.6802357522044189)

# Compare fits
mean_squared_error(y_tst, y_pred=ebm.predict(X_tst))
# 0.2654716402696781
mean_squared_error(y_tst, y_pred=ebm_lasso.predict(X_tst_tc))
# 0.2665129843046598

[image: test MSE versus number of non-zero terms]

brandongreenwell-8451 (Contributor) commented Aug 4, 2023

FWIW, the EBM's intercept should probably be used as an offset in the linear model (analogous to the initial fit value in boosting), but I did not see much of a difference in this example.
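For reference, a rough sketch of the offset idea in Python; sklearn's Lasso has no offset argument, so one workaround is to subtract the intercept from the response before fitting (this reuses X_trn_tc, X_tst_tc, y_trn, and alpha_best from the example above):

# Treat the EBM intercept as a fixed offset by removing it from the target
y_trn_off = y_trn - ebm.intercept_
lasso_off = Lasso(alpha=alpha_best, positive=True, fit_intercept=False)
lasso_off.fit(X_trn_tc, y_trn_off)

# Add the offset back at prediction time
pred = lasso_off.predict(X_tst_tc) + ebm.intercept_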

paulbkoch (Collaborator) commented:

Hi @brandongreenwell-8451 -- What a delightful idea on how to reduce model complexity, and it's also nice that it has been successful in other domains which I see mentioned in your book. One of the areas we get criticized for in papers that compare EBMs to other glassbox models is the lack of sparsity in EBMs, which is somewhat amusing to me since it's fairly easy to drop terms in an EBM. I suspect using LASSO as you suggest would prove to be better than many alternatives for handling this.

Is this something you'd be interested in contributing to interpret? It's nice that it can be a stand-alone post process function.

brandongreenwell-8451 (Contributor) commented Aug 7, 2023

Thanks @paulbkoch (and for noticing it in the book). Happy to spend some time on it. Any idea which component(s) to focus on modifying? In essence, I think the term contributions (.term_scores_?) just need to be modified (i.e., by multiplying them by a fixed constant given by the corresponding coefficients).

Seems like a simple standalone function/method similar to .monotonize() would work. I do wonder, though, how this would/should impact the output from .explain_global() and .explain_local()? Using the coefficients would seem sensible, but I'd have to think hard about it.

brandongreenwell-8451 (Contributor) commented Aug 7, 2023

Took a stab at a simple implementation:

from sklearn.base import is_classifier
from sklearn.utils.validation import check_is_fitted


def reweight(fit, weights, new_intercept=None):
    """Reweight the individual term contributions from a fitted EBM. This is
    useful for introducing sparsity by post-processing the model via the
    LASSO...

    Args:
        fit: A fitted EBM model (does not currently support multiclass).
        weights: List of weights (one weight for each term in the model).
        new_intercept: Optional new/updated intercept.

    Returns: The modified fit.

    TODO: Add references, more details, etc.

    Note: Adapted from the .monotonize() method.
    """
    check_is_fitted(fit, "has_fitted_")

    if is_classifier(fit) and 2 < len(fit.classes_):
        raise ValueError("multiclass models are not currently supported")

    if len(weights) != len(fit.term_names_):
        raise ValueError("need to supply one weight for each term")

    # Copy any fields we overwrite in case someone has a shallow copy of the fit
    term_scores = fit.term_scores_.copy()

    for idw, w in enumerate(weights):
        scores = term_scores[idw].copy()
        # The missing and unknown bins are not part of the continuous range
        y = scores[1:-1]
        y *= w  # multiplying by a scalar works for numpy arrays
        scores[1:-1] = y
        term_scores[idw] = scores

    fit.term_scores_ = term_scores

    # Update the intercept
    if new_intercept is not None:
        fit.intercept_ = float(new_intercept)

    return fit


from copy import deepcopy


# Create a (deep) copy of the original model for modification
ebm2 = deepcopy(ebm)

# Reweight terms using LASSO coefficients obtained in first example
reweight(ebm2, weights=ebm_lasso.coef_, new_intercept=ebm_lasso.intercept_)  # from previous example

# Original EBM
mean_squared_error(y_tst, y_pred=ebm.predict(X_tst))
# 0.2654716402696781

# Modified EBM
mean_squared_error(y_tst, y_pred=ebm2.predict(X_tst))
# 0.2665129843046598

# LASSO fit corresponding to modified EBM 
mean_squared_error(y_tst, y_pred=ebm_lasso.predict(X_tst_tc))  # same as above
# 0.2665129843046598

# Shape functions for features zeroed out by the LASSO should be a flat line at zero
from interpret import show
show(ebm2.explain_global())

# TODO: Would it be useful to completely remove these terms from the model 
#   altogether? Seems like a lot of attributes from the original fit would have 
#   to be changed.

brandongreenwell-8451 (Contributor) commented Aug 8, 2023

Looks like this automatically zeroes out the variable importance (VI) scores for the removed terms and correctly adjusts the others by the corresponding LASSO coefficients. This makes sense, since the VI scores still satisfy the same definition (the mean absolute contribution) and the contributions themselves are scaled by the LASSO coefficients in the function above. Nice!
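A quick numerical check of this claim (a sketch only; it assumes the ebm, ebm2, and ebm_lasso objects from above, that the fitted model exposes a term_importances() method, and that the missing/unknown bin scores are zero, since reweight() leaves those bins untouched):

import numpy as np

imp_before = np.asarray(ebm.term_importances())
imp_after = np.asarray(ebm2.term_importances())

# Each VI score should scale by the corresponding (non-negative) LASSO coefficient
np.allclose(imp_after, imp_before * ebm_lasso.coef_)  # should be True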

paulbkoch (Collaborator) commented:

Seems to work well 👍

In terms of which attributes to scale, I think these 3:

self.bagged_scores_ = bagged_scores
self.term_scores_ = term_scores
self.standard_deviations_ = standard_deviations

Regarding the TODO note, yes, removing terms would significantly reduce the time it takes to generate predictions, and would also simplify the model overall for human understanding. Removing terms is fairly simple. If you want to remove term 7 for example, you'd just need to remove the 7th item in each of these lists:

# per-term
self.term_features_ = term_features
self.bin_weights_ = bin_weights
self.bagged_scores_ = bagged_scores
self.term_scores_ = term_scores
self.standard_deviations_ = standard_deviations
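For example, here's a minimal sketch of removing a single term by index using the attribute names listed above (term_names_ may also need updating, depending on the version):

idx = 7  # index of the term to remove
for attr in ("term_features_", "bin_weights_", "bagged_scores_",
             "term_scores_", "standard_deviations_"):
    values = list(getattr(ebm, attr))
    del values[idx]
    setattr(ebm, attr, values)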

bgreenwell commented:

Thanks again @paulbkoch, I’ll modify and test and then work on a PR with some examples. I assume this would work better as a standalone function, as opposed to a method on the object itself?

paulbkoch (Collaborator) commented:

Let's put it right next to the monotonize function inside the object.

bgreenwell commented:

@paulbkoch PR in, it just got flagged by some DCO thing. Let me know if you get a chance to take a look. Tested well on my side with removal of term attributes.

brandongreenwell-8451 (Contributor) commented:

Updated version of the method:

def reweight_terms(self, weights, new_intercept=None):
    """Reweight the individual term contributions. For example, you can
    remove specific terms by setting their corresponding weights to zero.
    This is useful for introducing sparsity by post-processing the model via
    the LASSO. See the EBM documentation for examples and further details.

    Args:
        weights: List of weights (one weight for each term in the model).
            This should be a list or numpy vector (i.e., 1-d array) of
            floats whose i-th element corresponds to the i-th element of the
            `.term_*_` attributes (e.g., `.term_names_`).
        new_intercept: Optional new/updated intercept.

    Returns:
        Itself.

    """
    check_is_fitted(self, "has_fitted_")

    if is_classifier(self) and 2 < len(self.classes_):
        msg = "multiclass models are not currently supported"
        _log.error(msg)
        raise ValueError(msg)

    if len(weights) != len(self.term_names_):
        msg = "need to supply one weight for each term"
        _log.error(msg)
        raise ValueError(msg)

    # Coerce to an array so the zero-weight comparison below also works for lists
    weights = np.asarray(weights, dtype=float)

    # Copy any fields we'll overwrite in case someone has a shallow copy of self
    term_features = self.term_features_.copy()
    term_names = self.term_names_.copy()
    term_scores = self.term_scores_.copy()
    bagged_scores = self.bagged_scores_.copy()
    bin_weights = self.bin_weights_.copy()
    standard_deviations = self.standard_deviations_.copy()

    for idw, w in enumerate(weights):
        scores = term_scores[idw].copy()
        bscores = bagged_scores[idw].copy()
        stdevs = standard_deviations[idw].copy()
        # The missing and unknown bins are not part of the continuous range
        y = scores[1:-1]
        y_bagged = bscores[1:-1]
        y_sd = stdevs[1:-1]
        y *= w  # multiplying by a scalar works for numpy arrays
        y_bagged *= w
        y_sd *= w
        scores[1:-1] = y
        bscores[1:-1] = y_bagged
        stdevs[1:-1] = y_sd
        term_scores[idw] = scores
        bagged_scores[idw] = bscores
        standard_deviations[idw] = stdevs

    # Delete components that have a weight of zero
    is_zero = np.where(weights == 0)[0].tolist()
    if len(is_zero) > 0:
        def remove_indices(x, idx):  # FIXME: More robust way?
            # Remove elements of a list based on the provided indices
            return [i for j, i in enumerate(x) if j not in idx]
        term_features = remove_indices(term_features, idx=is_zero)
        term_names = remove_indices(term_names, idx=is_zero)
        term_scores = remove_indices(term_scores, idx=is_zero)
        bagged_scores = remove_indices(bagged_scores, idx=is_zero)
        standard_deviations = remove_indices(standard_deviations, idx=is_zero)
        bin_weights = remove_indices(bin_weights, idx=is_zero)

    # Update the components of self
    self.term_features_ = term_features
    self.term_names_ = term_names
    self.term_scores_ = term_scores
    self.bagged_scores_ = bagged_scores
    self.standard_deviations_ = standard_deviations
    self.bin_weights_ = bin_weights

    # Update the intercept
    if new_intercept is not None:
        # FIXME: Doesn't seem to like <class 'numpy.float64'>?
        self.intercept_ = float(new_intercept)

    return self

brandongreenwell-8451 (Contributor) commented:

Merged in #469.

Harsha-Nori (Collaborator) commented:

Just wanted to chime in and say that this is a really cool feature @brandongreenwell-8451 :). Thanks for the hard work on getting it into the library! I'm definitely going to play around with it.

By the way, I'm curious if you have thoughts on interpretations of statistical significance when LASSO terms get shrunk to zero. I remember reading ages ago that frequentist interpretations of statistics/p-values from LASSO methods are tough/unreliable. Just curious if you think there's any way to tie a relationship between the two, or if you have abstract ideas on other statistical testing approaches for determining term significance in a GAM.

brandongreenwell-8451 (Contributor) commented Aug 15, 2023

Hi @Harsha-Nori, I hope you (and others) find this approach useful!

You pose an interesting question. I don't think significance is necessarily meaningful for the zeroed-out terms (because they were not selected to be in the final model). However, I do think it's feasible to do inference on coefficients for the remaining terms using post-selection inference techniques. There's an R package, for example, called selectiveInference, that will let you compute p-values and confidence intervals for LASSO model parameters after finding the optimal lambda; in particular, see the docs for the fixedLassoInf() function, which I think could be used in this application in a pretty straightforward manner.

So, in theory, I think it's possible to at least do inference on the EBM term weights if one were to take this approach! Interpretation could be an issue, though, since we're doing inference on the coefficients of the shape functions rather than on the shape functions themselves, if that makes sense?

Some references:

Jason Lee, Dennis Sun, Yuekai Sun, and Jonathan Taylor (2013). Exact post-selection inference, with application to the lasso. arXiv:1311.6238.

Jonathan Taylor and Robert Tibshirani (2016). Post-selection inference for L1-penalized likelihood models. arXiv:1602.07358.

paulbkoch (Collaborator) commented:

Hi @brandongreenwell-8451 -- I added a new function called "sweep" that may be of interest when using LASSO to eliminate terms. The function signature is: sweep(sweep_terms=True, sweep_bins=True, sweep_features=False)

It has two purposes:

  1. To optionally eliminate any features that are no longer needed once terms are removed. This is a bit of a heavyweight change since it will change the format of the numpy array or pandas dataframe that gets passed into the predict function, so by default this functionality is turned off.
  2. The self.bins_ attribute, which contains the information for binning the data in X, is normally simplified to contain only the information required by the EBM. Typically this means that if a feature is not used in a pair, we do not need, and do not store, the lower resolution pair bins. If the user has eliminated terms, they might have eliminated a pair term, which would allow us to eliminate the pair bin definitions for some features. Leaving that information in the model isn't detrimental, but eliminating it allows for a simpler and more compact EBM.

The function is defined as:

def sweep(self, sweep_terms=True, sweep_bins=True, sweep_features=False):

brandongreenwell-8451 (Contributor) commented Aug 24, 2023

Thanks @paulbkoch, I'll be sure to check it out! #1 makes perfect sense. So, in regard to the original application, calling .sweep() after rescaling and removing terms just does some further cleanup of unnecessary components?

Any rough idea for the next release?

paulbkoch (Collaborator) commented:

Yes, that's correct.

Regarding a release, I'm happy to put one out now if that would be helpful.

On the binning, I probably didn't explain the bin sweeping part very well, so here's an example:

Let's say we built an EBM with 2 features. We'd probably get 3 terms from this: "A", "B", and "A & B". If you looked at the ebm.bins_ attribute you'd probably see something like this:

ebm.bins_ = [[[1.5, 2.5, 3.5], [2.5]], [[7.5, 8.5, 9.5], [8.5]]]

The [1.5, 2.5, 3.5] is the higher resolution binning for feature 1, used for mains, and [2.5] is the lower resolution binning for feature 1, used in pairs. Likewise, [7.5, 8.5, 9.5] is the higher resolution binning for feature 2, and [8.5] is its lower resolution pair binning.

Now let's say we set the term "A & B" to zero and subsequently sweep it away, leaving a mains-only model with terms "A" and "B". We no longer need the lower resolution pair binning. If you call sweep(sweep_bins=True), then you would find:

ebm.bins_ = [[[1.5, 2.5, 3.5]], [[7.5, 8.5, 9.5]]]
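So, for this two-feature example, the end-to-end flow would look something like the following sketch (method names as used in this thread):

# Zero out (and remove) the pair term "A & B" ...
ebm.reweight_terms([1.0, 1.0, 0.0])

# ... then sweep away the now-unneeded lower resolution pair bins
ebm.sweep(sweep_terms=True, sweep_bins=True, sweep_features=False)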

brandongreenwell-8451 (Contributor) commented Aug 24, 2023

Gotcha, that makes sense!

And yes, a new release would be nice (although I know you just released an update a few weeks back that brought some performance improvements)! We've had trouble installing from source locally on Macs, so I've been itching to try the new features from within the package, but there's no real rush on my end.

brandongreenwell-8451 (Contributor) commented Aug 24, 2023

On a related note, I've been thinking about LASSO extensions that do something similar. For instance, if we want to follow the usual hierarchical principle, then we might want to make sure the LASSO (or similar) only keeps a pairwise interaction if the associated main effects are in the model (so if "A" is zeroed out, then "A & B" should also be set to zero); similarly with categorical inputs if you, say, use one-hot encoding. This might be possible in current software, but I haven't looked too much.
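As a crude illustration, the hierarchy could be enforced post hoc by zeroing any interaction coefficient whose parent main effects were dropped (a sketch only, not a true hierarchical LASSO; it uses term_features_ to map terms back to features and the ebm/ebm_lasso objects from the earlier example):

import numpy as np

coefs = np.asarray(ebm_lasso.coef_, dtype=float).copy()

# Map each feature index to the index of its main-effect term
main_idx = {feats[0]: i for i, feats in enumerate(ebm.term_features_)
            if len(feats) == 1}

for i, feats in enumerate(ebm.term_features_):
    if len(feats) == 2:  # pairwise interaction term
        if any(coefs[main_idx[f]] == 0 for f in feats if f in main_idx):
            coefs[i] = 0.0  # drop interactions with a zeroed-out main effect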

paulbkoch (Collaborator) commented:

Haha, I've got you covered ;-)

def remove_features(self, features, remove_dependent_terms=False):

I wrote this new "remove_features" function to remove individual features. It's called by the sweep function, but it can also be called independently. If you call it with remove_dependent_terms=True, it will remove the term "A & B" whenever you remove either feature "A" or feature "B" from the EBM. If this works for you as is, great; otherwise, it should be a good template for doing similar EBM manipulations yourself.

Your PR got me excited about model editing, so I wrote a few new functions myself. I think with your two functions and the other new ones, we're in a good place to put out a release focused on model editing. If there's no strong rush, I might put in one more function (model saving) and then make a release.

brandongreenwell-8451 (Contributor) commented:

Awesome, I look forward to the next release!

paulbkoch (Collaborator) commented:

Hi @brandongreenwell-8451 -- The release for v0.4.4 is out. Enjoy. :)
