diff --git a/captum/_utils/models/linear_model/model.py b/captum/_utils/models/linear_model/model.py
index 97c56f597a..d1335089b0 100644
--- a/captum/_utils/models/linear_model/model.py
+++ b/captum/_utils/models/linear_model/model.py
@@ -279,6 +279,9 @@ class SkLearnRidge(SkLearnLinearModel):
     def __init__(self, **kwargs):
         r"""
         Factory class. Trains a model with `sklearn.linear_model.Ridge`.
+
+        Any arguments supported by the sklearn constructor can be provided
+        as kwargs here.
         """
         super().__init__(**kwargs, sklearn_module="linear_model.Ridge")
 
@@ -290,6 +293,9 @@ class SkLearnLinearRegression(SkLearnLinearModel):
     def __init__(self, **kwargs):
         r"""
         Factory class. Trains a model with `sklearn.linear_model.LinearRegression`.
+
+        Any arguments supported by the sklearn constructor can be provided
+        as kwargs here.
         """
         super().__init__(**kwargs, sklearn_module="linear_model.LinearRegression")
 
diff --git a/captum/attr/_core/kernel_shap.py b/captum/attr/_core/kernel_shap.py
index c6a996ac42..9602f892b1 100644
--- a/captum/attr/_core/kernel_shap.py
+++ b/captum/attr/_core/kernel_shap.py
@@ -6,31 +6,12 @@
 import torch
 from torch import Tensor
 
+from captum._utils.models.linear_model import SkLearnLinearRegression
 from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
 from captum.attr._core.lime import Lime
 from captum.log import log_usage
 
 
-def linear_regression_interpretable_model_trainer(
-    interpretable_inputs: Tensor, expected_outputs: Tensor, weights: Tensor, **kwargs
-):
-    try:
-        from sklearn import linear_model
-    except ImportError:
-        raise AssertionError(
-            "Requires sklearn for default interpretable model training with linear "
-            "regression. Please install sklearn or use a custom interpretable model "
-            "training function."
-        )
-    clf = linear_model.LinearRegression()
-    clf.fit(
-        interpretable_inputs.cpu().numpy(),
-        expected_outputs.cpu().numpy(),
-        weights.cpu().numpy(),
-    )
-    return torch.from_numpy(clf.coef_)
-
-
 def combination(n: int, k: int) -> int:
     try:
         # Combination only available in Python 3.8
@@ -86,7 +67,7 @@ def __init__(self, forward_func: Callable) -> None:
         Lime.__init__(
             self,
             forward_func,
-            linear_regression_interpretable_model_trainer,
+            SkLearnLinearRegression(),
             kernel_shap_similarity_kernel,
         )
 
diff --git a/captum/attr/_core/lime.py b/captum/attr/_core/lime.py
index 120d287502..3fc4965c0a 100644
--- a/captum/attr/_core/lime.py
+++ b/captum/attr/_core/lime.py
@@ -7,6 +7,7 @@
 import torch
 from torch import Tensor
 from torch.nn import CosineSimilarity
+from torch.utils.data import DataLoader, TensorDataset
 
 from captum._utils.common import (
     _expand_additional_forward_args,
@@ -18,6 +19,8 @@
     _reduce_list,
     _run_forward,
 )
+from captum._utils.models.linear_model import SkLearnLasso
+from captum._utils.models.model import Model
 from captum._utils.typing import (
     BaselineType,
     Literal,
@@ -66,7 +69,7 @@ class LimeBase(PerturbationAttribution):
     def __init__(
         self,
         forward_func: Callable,
-        train_interpretable_model_func: Callable,
+        interpretable_model: Model,
         similarity_func: Callable,
         perturb_func: Callable,
         perturb_interpretable_space: bool,
@@ -82,22 +85,25 @@ def __init__(
                 modification of it. If a batch is provided as input for
                 attribution, it is expected that forward_func returns a scalar
                 representing the entire batch.
 
-            train_interpretable_model_func (callable): Function which trains
-                an interpretable model and returns some representation of the
-                interpretable model. The return type of this will match the
-                returned type when calling attribute.
-                The expected signature of this callable is:
-
-                train_interpretable_model_func(
+            interpretable_model (Model): Model object used to train the
+                interpretable model. A Model object provides a `fit` method
+                to train on a dataloader whose batches contain three tensors:
                     interpretable_inputs: Tensor [2D num_samples x num_interp_features],
                     expected_outputs: Tensor [1D num_samples],
                     weights: Tensor [1D num_samples]
-                    **kwargs: Any
-                ) -> Any (Representation of interpretable model)
-                All kwargs passed to the attribute method are
-                provided as keyword arguments (kwargs) to this callable.
+                The model object must also provide a `representation` method to
+                access the appropriate coefficients or representation of the
+                interpretable model after fitting.
+                Some predefined interpretable linear models are provided in
+                captum._utils.models.linear_model, including wrappers around
+                SkLearn linear models as well as SGD-based PyTorch linear
+                models.
+
+                Note that calling fit multiple times should retrain the
+                interpretable model, since each attribution call reuses
+                the same given interpretable model object.
 
             similarity_func (callable): Function which takes a single sample
                 along with its corresponding interpretable representation
                 and returns the weight of the interpretable sample for
@@ -204,7 +210,7 @@ def __init__(
                 provided as keyword arguments (kwargs) to this callable.
         """
         PerturbationAttribution.__init__(self, forward_func)
-        self.train_interpretable_model_func = train_interpretable_model_func
+        self.interpretable_model = interpretable_model
         self.similarity_func = similarity_func
         self.perturb_func = perturb_func
         self.perturb_interpretable_space = perturb_interpretable_space
@@ -230,7 +236,7 @@ def attribute(
         n_perturb_samples: int = 50,
         perturbations_per_eval: int = 1,
         **kwargs
-    ) -> TensorOrTupleOfTensorsGeneric:
+    ) -> Tensor:
         r"""
         This method attributes the output of the model with given target
         index (in case it is provided, otherwise it assumes that output is a
@@ -342,20 +348,12 @@ def attribute(
         >>> # score of the target class.
         >>>
         >>> # For interpretable model training, we will use sklearn
-        >>> # in this example
-        >>> from sklearn import linear_model
-        >>>
-        >>> # Define interpretable model training function
-        >>> def linear_regression_interpretable_model_trainer(
-        >>>     interpretable_inputs: Tensor,
-        >>>     expected_outputs: Tensor,
-        >>>     weights: Tensor, **kwargs):
-        >>>     clf = linear_model.LinearRegression()
-        >>>     clf.fit(
-        >>>         interpretable_inputs.cpu().numpy(),
-        >>>         expected_outputs.cpu().numpy(),
-        >>>         weights.cpu().numpy())
-        >>>     return clf.coef_
+        >>> # linear model in this example. We have provided wrappers
+        >>> # around sklearn linear models to fit the Model interface.
+        >>> # Any arguments provided to the sklearn constructor can also
+        >>> # be provided to the wrapper, e.g.:
+        >>> # SkLearnLinearModel("linear_model.Ridge", alpha=2.0)
+        >>> from captum._utils.models.linear_model import SkLearnLinearModel
         >>>
         >>>
         >>> # Define similarity kernel (exponential kernel based on L2 norm)
@@ -387,7 +385,7 @@ def attribute(
         >>> input = torch.randn(2, 5)
         >>> # Defining LimeBase interpreter
         >>> lime_attr = LimeBase(net,
-                                 linear_regression_interpretable_model_trainer,
+                                 SkLearnLinearModel("linear_model.Ridge"),
                                  similarity_func=similarity_kernel,
                                  perturb_func=perturb_func,
                                  perturb_interpretable_space=False,
@@ -477,10 +475,13 @@ def attribute(
                 if len(similarities[0].shape) > 0
                 else torch.stack(similarities)
             )
-            interp_model = self.train_interpretable_model_func(
-                combined_interp_inps, combined_outputs, combined_sim, **kwargs
+            dataset = TensorDataset(
+                combined_interp_inps, combined_outputs, combined_sim
             )
-            return interp_model
+            self.interpretable_model.fit(
+                DataLoader(dataset, batch_size=n_perturb_samples)
+            )
+            return self.interpretable_model.representation()
 
     def _evaluate_batch(
         self,
@@ -516,32 +517,6 @@ def multiplies_by_inputs(self):
     # for Lime child implementation.
 
 
-def lasso_interpretable_model_trainer(
-    interpretable_inputs: Tensor, expected_outputs: Tensor, weights: Tensor, **kwargs
-):
-    try:
-        import sklearn
-        from sklearn import linear_model
-
-        assert (
-            sklearn.__version__ >= "0.23.0"
-        ), "Must have sklearn version 0.23.0 or higher to use "
-        "sample_weight in Lasso regression."
-    except ImportError:
-        raise AssertionError(
-            "Requires sklearn for default interpretable model training with"
-            " Lasso regression. Please install sklearn or use a custom interpretable"
-            " model training function."
-        )
-    clf = linear_model.Lasso(alpha=kwargs["alpha"] if "alpha" in kwargs else 1.0)
-    clf.fit(
-        interpretable_inputs.cpu().numpy(),
-        expected_outputs.cpu().numpy(),
-        weights.cpu().numpy(),
-    )
-    return torch.from_numpy(clf.coef_)
-
-
 def default_from_interp_rep_transform(curr_sample, original_inputs, **kwargs):
     assert (
         "feature_mask" in kwargs
@@ -665,7 +640,7 @@ class Lime(LimeBase):
     def __init__(
         self,
         forward_func: Callable,
-        train_interpretable_model_func: Callable = lasso_interpretable_model_trainer,
+        interpretable_model: Model = SkLearnLasso(alpha=1.0),
         similarity_func: Callable = get_exp_kernel_similarity_function(),
         perturb_func: Callable = default_perturb_func,
     ) -> None:
@@ -676,31 +651,31 @@ def __init__(
 
             forward_func (callable): The forward function of the model or any
                 modification of it
 
-            train_interpretable_model_func (optional, callable): Function which
-                trains an interpretable model and returns coefficients
-                of the interpretable model.
-                This function is optional, and the default function trains
-                an interpretable model using Lasso regression, using the
-                alpha parameter provided when calling attribute.
-                Using the default function requires having sklearn version
-                0.23.0 or higher installed.
-
-                If a custom function is provided, the expected signature of this
-                callable is:
-
-                train_interpretable_model_func(
+            interpretable_model (optional, Model): Model object used to train
+                the interpretable model.
+
+                This argument is optional and defaults to SkLearnLasso(alpha=1.0),
+                which is a wrapper around the Lasso linear model in SkLearn.
+                This requires having sklearn version >= 0.23 available.
+
+                Other predefined interpretable linear models are provided in
+                captum._utils.models.linear_model.
+
+                Alternatively, a custom model object must provide a `fit` method
+                to train the model, given a dataloader whose batches contain
+                three tensors:
                     interpretable_inputs: Tensor [2D num_samples x num_interp_features],
                     expected_outputs: Tensor [1D num_samples],
                     weights: Tensor [1D num_samples]
-                    **kwargs: Any
-                ) -> Tensor [1D num_interp_features]
-                The return type must be a 1D tensor containing the importance
-                or attribution of each input feature.
-                kwargs includes baselines, feature_mask, num_interp_features
-                (integer, determined from feature mask), and
-                alpha (for Lasso regression).
+                The model object must also provide a `representation` method to
+                access the appropriate coefficients or representation of the
+                interpretable model after fitting.
+
+                Note that calling fit multiple times should retrain the
+                interpretable model, since each attribution call reuses
+                the same given interpretable model object.
 
             similarity_func (callable): Function which takes a single sample
                 along with its corresponding interpretable representation
                 and returns the weight of the interpretable sample for
@@ -764,7 +739,6 @@ def attribute(  # type: ignore
         feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
         n_perturb_samples: int = 25,
         perturbations_per_eval: int = 1,
-        alpha: float = 1.0,
         return_input_shape: bool = True,
     ) -> TensorOrTupleOfTensorsGeneric:
         r"""
@@ -1058,7 +1032,6 @@ def attribute(  # type: ignore
                     if is_inputs_tuple
                     else curr_feature_mask[0],
                     num_interp_features=num_interp_features,
-                    alpha=alpha,
                 )
                 if return_input_shape:
                     output_list.append(
@@ -1095,7 +1068,6 @@ def attribute(  # type: ignore
                 baselines=baselines if is_inputs_tuple else baselines[0],
                 feature_mask=feature_mask if is_inputs_tuple else feature_mask[0],
                 num_interp_features=num_interp_features,
-                alpha=alpha,
             )
             if return_input_shape:
                 return self._convert_output_shape(
@@ -1138,6 +1110,7 @@ def _convert_output_shape(
         num_interp_features: int,
         is_inputs_tuple: bool,
     ) -> Union[Tensor, Tuple[Tensor, ...]]:
+        coefs = coefs.flatten()
         attr = [torch.zeros_like(single_inp) for single_inp in formatted_inp]
         for tensor_ind in range(len(formatted_inp)):
             for single_feature in range(num_interp_features):
diff --git a/sphinx/source/utilities.rst b/sphinx/source/utilities.rst
index 289a129aad..f4e3d7ace6 100644
--- a/sphinx/source/utilities.rst
+++ b/sphinx/source/utilities.rst
@@ -24,3 +24,24 @@ Token Reference Base
 
 .. autoclass:: captum.attr.TokenReferenceBase
     :members:
+
+
+Linear Models
+^^^^^^^^^^^^^
+
+.. autoclass:: captum._utils.models.model.Model
+    :members:
+.. autoclass:: captum._utils.models.linear_model.SkLearnLinearModel
+    :members:
+.. autoclass:: captum._utils.models.linear_model.SkLearnLinearRegression
+    :members:
+.. autoclass:: captum._utils.models.linear_model.SkLearnLasso
+    :members:
+.. autoclass:: captum._utils.models.linear_model.SkLearnRidge
+    :members:
+.. autoclass:: captum._utils.models.linear_model.SGDLinearModel
+    :members:
+.. autoclass:: captum._utils.models.linear_model.SGDLasso
+    :members:
+.. autoclass:: captum._utils.models.linear_model.SGDRidge
+    :members:
diff --git a/tests/attr/test_kernel_shap.py b/tests/attr/test_kernel_shap.py
index c7c779a88d..48b8acb926 100644
--- a/tests/attr/test_kernel_shap.py
+++ b/tests/attr/test_kernel_shap.py
@@ -23,10 +23,16 @@ def setUp(self) -> None:
         super().setUp()
         try:
             import sklearn  # noqa: F401
-        except ImportError:
-            raise unittest.SkipTest(
-                "Skipping Kernel Shap tests, sklearn not available."
-            )
+
+            # Compare parsed (major, minor) components; a plain string
+            # comparison would order "0.9.0" after "0.23.0".
+            assert (
+                tuple(int(v) for v in sklearn.__version__.split(".")[:2]) >= (0, 23)
+            ), "Must have sklearn version 0.23.0 or higher"
+        except (ImportError, AssertionError):
+            raise unittest.SkipTest(
+                "Skipping KernelShap tests, sklearn 0.23.0 or later is required."
+            )
 
     def test_linear_kernel_shap(self) -> None:
         net = BasicModel_MultiLayer()
diff --git a/tests/attr/test_lime.py b/tests/attr/test_lime.py
index a44f8b18be..ae76769b86 100644
--- a/tests/attr/test_lime.py
+++ b/tests/attr/test_lime.py
@@ -6,13 +6,9 @@
 import torch
 from torch import Tensor
 
+from captum._utils.models.linear_model import SkLearnLasso
 from captum._utils.typing import BaselineType, TensorOrTupleOfTensorsGeneric
-from captum.attr._core.lime import (
-    Lime,
-    LimeBase,
-    get_exp_kernel_similarity_function,
-    lasso_interpretable_model_trainer,
-)
+from captum.attr._core.lime import Lime, LimeBase, get_exp_kernel_similarity_function
 from captum.attr._utils.batching import _batch_example_iterator
 from captum.attr._utils.common import (
     _construct_default_feature_mask,
@@ -202,6 +198,7 @@ def test_multi_input_lime_with_mask(self) -> None:
             expected,
             additional_input=(1,),
             feature_mask=(mask1, mask2, mask3),
+            n_perturb_samples=500,
             expected_coefs_only=[251.0, 591.0, 0.0],
         )
         expected_with_baseline = (
@@ -390,7 +387,6 @@ def _lime_test_assert(
         baselines: BaselineType = None,
         target: Union[None, int] = 0,
         n_perturb_samples: int = 100,
-        alpha: float = 1.0,
         delta: float = 1.0,
         batch_attr: bool = False,
     ) -> None:
@@ -407,7 +403,6 @@ def _lime_test_assert(
             baselines=baselines,
             perturbations_per_eval=batch_size,
             n_perturb_samples=n_perturb_samples,
-            alpha=alpha,
         )
         assertTensorTuplesAlmostEqual(
             self, attributions, expected_attr, delta=delta, mode="max"
@@ -422,7 +417,6 @@ def _lime_test_assert(
             baselines=baselines,
             perturbations_per_eval=batch_size,
             n_perturb_samples=n_perturb_samples,
-            alpha=alpha,
             return_input_shape=False,
         )
         assertTensorAlmostEqual(
@@ -431,7 +425,7 @@ def _lime_test_assert(
 
         lime_alt = LimeBase(
             model,
-            lasso_interpretable_model_trainer,
+            SkLearnLasso(alpha=1.0),
            get_exp_kernel_similarity_function("euclidean", 1000.0),
             alt_perturb_func,
             False,
@@ -465,7 +459,6 @@ def _lime_test_assert(
             baselines=baselines,
             perturbations_per_eval=batch_size,
             n_perturb_samples=n_perturb_samples,
-            alpha=alpha,
             num_interp_features=num_interp_features,
         )
         assertTensorAlmostEqual(
@@ -500,7 +493,6 @@ def _lime_test_assert(
             baselines=curr_baselines,
             perturbations_per_eval=batch_size,
             n_perturb_samples=n_perturb_samples,
-            alpha=alpha,
             num_interp_features=num_interp_features,
         )
         assertTensorAlmostEqual(
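For reference, the `Model` interface that this diff migrates Lime and KernelShap onto is small: per the docstrings above, `fit` receives a DataLoader whose batches contain (interpretable_inputs, expected_outputs, weights) tensors, and `representation` returns the fitted coefficient tensor. The sketch below is illustrative only — the class name and the closed-form weighted least-squares solver are this note's assumptions, not part of the PR; it relies only on the interface the docstrings describe.

```python
import torch
from torch import Tensor
from torch.utils.data import DataLoader

from captum._utils.models.model import Model


class ClosedFormWeightedRegression(Model):
    """Illustrative custom Model: weighted least squares in closed form."""

    def __init__(self) -> None:
        self.coefs = torch.tensor([])

    def fit(self, train_data: DataLoader):
        # Re-fit from scratch on each call; LimeBase reuses the same model
        # object across attribute() calls, as noted in the docstrings above.
        xs, ys, ws = [], [], []
        for interp_inps, outputs, weights in train_data:
            xs.append(interp_inps)
            ys.append(outputs)
            ws.append(weights)
        x = torch.cat(xs).double()  # [num_samples, num_interp_features]
        y = torch.cat(ys).double()  # [num_samples]
        w = torch.cat(ws).double()  # [num_samples]
        # Weighted normal equations: (X^T W X) beta = X^T W y.
        xtw = x.t() * w  # scales each sample (column of x.t()) by its weight
        self.coefs = (torch.pinverse(xtw @ x) @ (xtw @ y)).float()

    def representation(self) -> Tensor:
        # Lime flattens this 1D coefficient tensor into per-feature attributions.
        return self.coefs
```

Assuming `Model` can be subclassed directly, such an object can be passed wherever the diff passes a predefined wrapper, e.g. `Lime(net, ClosedFormWeightedRegression())`, or as the second argument to `LimeBase`.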