6 changes: 6 additions & 0 deletions captum/_utils/models/linear_model/model.py
@@ -279,6 +279,9 @@ class SkLearnRidge(SkLearnLinearModel):
def __init__(self, **kwargs):
r"""
Factory class. Trains a model with `sklearn.linear_model.Ridge`.

Any arguments provided to the sklearn constructor can be provided
as kwargs here.
"""
super().__init__(**kwargs, sklearn_module="linear_model.Ridge")

@@ -290,6 +293,9 @@ class SkLearnLinearRegression(SkLearnLinearModel):
def __init__(self, **kwargs):
r"""
Factory class. Trains a model with `sklearn.linear_model.LinearRegression`.

Any arguments provided to the sklearn constructor can be provided
as kwargs here.
"""
super().__init__(**kwargs, sklearn_module="linear_model.LinearRegression")

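The docstring additions above note that any constructor kwargs are forwarded to the underlying sklearn estimator. A minimal sketch of what this enables (hedged: assumes sklearn is installed; `alpha` and `fit_intercept` are standard sklearn parameters, not captum-specific):

```python
# Constructor kwargs pass straight through to the wrapped sklearn estimator.
from captum._utils.models.linear_model import SkLearnLinearRegression, SkLearnRidge

ridge = SkLearnRidge(alpha=2.0)                     # -> linear_model.Ridge(alpha=2.0)
ols = SkLearnLinearRegression(fit_intercept=False)  # -> LinearRegression(fit_intercept=False)
```
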
23 changes: 2 additions & 21 deletions captum/attr/_core/kernel_shap.py
@@ -6,31 +6,12 @@
import torch
from torch import Tensor

from captum._utils.models.linear_model import SkLearnLinearRegression
from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
from captum.attr._core.lime import Lime
from captum.log import log_usage


def linear_regression_interpretable_model_trainer(
interpretable_inputs: Tensor, expected_outputs: Tensor, weights: Tensor, **kwargs
):
try:
from sklearn import linear_model
except ImportError:
raise AssertionError(
"Requires sklearn for default interpretable model training with linear "
"regression. Please install sklearn or use a custom interpretable model "
"training function."
)
clf = linear_model.LinearRegression()
clf.fit(
interpretable_inputs.cpu().numpy(),
expected_outputs.cpu().numpy(),
weights.cpu().numpy(),
)
return torch.from_numpy(clf.coef_)


def combination(n: int, k: int) -> int:
try:
# math.comb is only available in Python 3.8+
@@ -86,7 +67,7 @@ def __init__(self, forward_func: Callable) -> None:
Lime.__init__(
self,
forward_func,
linear_regression_interpretable_model_trainer,
SkLearnLinearRegression(),
kernel_shap_similarity_kernel,
)

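With this change, KernelShap trains its surrogate through the `SkLearnLinearRegression` wrapper instead of a free-standing trainer function. A hedged usage sketch (the toy model and sizes are illustrative, not from this PR):

```python
import torch
import torch.nn as nn
from captum.attr import KernelShap

net = nn.Sequential(nn.Linear(5, 3))  # toy model for illustration
inputs = torch.randn(2, 5)

ks = KernelShap(net)  # surrogate is now SkLearnLinearRegression()
attr = ks.attribute(inputs, target=0, n_perturb_samples=200)
```
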
137 changes: 55 additions & 82 deletions captum/attr/_core/lime.py
@@ -7,6 +7,7 @@
import torch
from torch import Tensor
from torch.nn import CosineSimilarity
from torch.utils.data import DataLoader, TensorDataset

from captum._utils.common import (
_expand_additional_forward_args,
@@ -18,6 +19,8 @@
_reduce_list,
_run_forward,
)
from captum._utils.models.linear_model import SkLearnLasso
from captum._utils.models.model import Model
from captum._utils.typing import (
BaselineType,
Literal,
@@ -66,7 +69,7 @@ class LimeBase(PerturbationAttribution):
def __init__(
self,
forward_func: Callable,
train_interpretable_model_func: Callable,
interpretable_model: Model,
similarity_func: Callable,
perturb_func: Callable,
perturb_interpretable_space: bool,
@@ -82,22 +85,25 @@ def __init__(
modification of it. If a batch is provided as input for
attribution, it is expected that forward_func returns a scalar
representing the entire batch.
train_interpretable_model_func (callable): Function which trains
an interpretable model and returns some representation of the
interpretable model. The return type of this will match the
returned type when calling attribute.
The expected signature of this callable is:

train_interpretable_model_func(
interpretable_model (Model): Model object used to train
an interpretable model.
A Model object provides a `fit` method to train the model,
given a dataloader whose batches contain three tensors:
interpretable_inputs: Tensor
[2D num_samples x num_interp_features],
expected_outputs: Tensor [1D num_samples],
weights: Tensor [1D num_samples]
**kwargs: Any
) -> Any (Representation of interpretable model)

All kwargs passed to the attribute method are
provided as keyword arguments (kwargs) to this callable.
The model object must also provide a `representation` method to
access the appropriate coefficients or representation of the
interpretable model after fitting.
Some predefined interpretable linear models are provided in
captum._utils.models.linear_model including wrappers around
SkLearn linear models as well as SGD-based PyTorch linear
models.

Note that calling fit multiple times should retrain the
interpretable model, since each attribution call reuses
the same interpretable model object (a minimal custom-model
sketch follows the __init__ hunks below).
similarity_func (callable): Function which takes a single sample
along with its corresponding interpretable representation
and returns the weight of the interpretable sample for
Expand Down Expand Up @@ -204,7 +210,7 @@ def __init__(
provided as keyword arguments (kwargs) to this callable.
"""
PerturbationAttribution.__init__(self, forward_func)
self.train_interpretable_model_func = train_interpretable_model_func
self.interpretable_model = interpretable_model
self.similarity_func = similarity_func
self.perturb_func = perturb_func
self.perturb_interpretable_space = perturb_interpretable_space
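As a concrete reading of the `fit`/`representation` contract described in the docstring above, here is a minimal custom interpretable model. This is a hedged sketch, not part of captum: the class name is invented, and it requires a PyTorch version that provides `torch.linalg.lstsq`.

```python
import torch
from torch.utils.data import DataLoader


class WeightedLstsqModel:
    """Hypothetical Model: weighted least squares over all provided batches."""

    def __init__(self) -> None:
        self._coefs = None

    def fit(self, dataloader: DataLoader) -> None:
        xs, ys, ws = [], [], []
        for interpretable_inputs, expected_outputs, weights in dataloader:
            xs.append(interpretable_inputs.float())
            ys.append(expected_outputs.float())
            ws.append(weights.float())
        x, y, w = torch.cat(xs), torch.cat(ys), torch.cat(ws)
        # Weighted least squares: scale each row by sqrt of its sample weight.
        root_w = w.sqrt().unsqueeze(1)
        self._coefs = torch.linalg.lstsq(
            root_w * x, root_w * y.unsqueeze(1)
        ).solution.squeeze(1)

    def representation(self) -> torch.Tensor:
        # 1D tensor: one coefficient per interpretable feature.
        return self._coefs
```
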
@@ -230,7 +236,7 @@ def attribute(
n_perturb_samples: int = 50,
perturbations_per_eval: int = 1,
**kwargs
) -> TensorOrTupleOfTensorsGeneric:
) -> Tensor:
r"""
This method attributes the output of the model with given target index
(in case it is provided, otherwise it assumes that output is a
@@ -342,20 +348,12 @@ def attribute(
>>> # score of the target class.
>>>
>>> # For interpretable model training, we will use sklearn
>>> # in this example
>>> from sklearn import linear_model
>>>
>>> # Define interpretable model training function
>>> def linear_regression_interpretable_model_trainer(
>>> interpretable_inputs: Tensor,
>>> expected_outputs: Tensor,
>>> weights: Tensor, **kwargs):
>>> clf = linear_model.LinearRegression()
>>> clf.fit(
>>> interpretable_inputs.cpu().numpy(),
>>> expected_outputs.cpu().numpy(),
>>> weights.cpu().numpy())
>>> return clf.coef_
>>> # linear model in this example. We have provided wrappers
>>> # around sklearn linear models to fit the Model interface.
>>> # Any arguments provided to the sklearn constructor can also
>>> # be provided to the wrapper, e.g.:
>>> # SkLearnLinearModel("linear_model.Ridge", alpha=2.0)
>>> from captum._utils.models.linear_model import SkLearnLinearModel
>>>
>>>
>>> # Define similarity kernel (exponential kernel based on L2 norm)
Expand Down Expand Up @@ -387,7 +385,7 @@ def attribute(
>>> input = torch.randn(2, 5)
>>> # Defining LimeBase interpreter
>>> lime_attr = LimeBase(net,
linear_regression_interpretable_model_trainer,
SkLearnLinearModel("linear_model.Ridge"),
similarity_func=similarity_kernel,
perturb_func=perturb_func,
perturb_interpretable_space=False,
@@ -477,10 +475,13 @@ def attribute(
if len(similarities[0].shape) > 0
else torch.stack(similarities)
)
interp_model = self.train_interpretable_model_func(
combined_interp_inps, combined_outputs, combined_sim, **kwargs
dataset = TensorDataset(
combined_interp_inps, combined_outputs, combined_sim
)
return interp_model
self.interpretable_model.fit(
DataLoader(dataset, batch_size=n_perturb_samples)
)
return self.interpretable_model.representation()

def _evaluate_batch(
self,
Expand Down Expand Up @@ -516,32 +517,6 @@ def multiplies_by_inputs(self):
# for Lime child implementation.


def lasso_interpretable_model_trainer(
interpretable_inputs: Tensor, expected_outputs: Tensor, weights: Tensor, **kwargs
):
try:
import sklearn
from sklearn import linear_model

assert (
sklearn.__version__ >= "0.23.0"
), "Must have sklearn version 0.23.0 or higher to use "
"sample_weight in Lasso regression."
except ImportError:
raise AssertionError(
"Requires sklearn for default interpretable model training with"
" Lasso regression. Please install sklearn or use a custom interpretable"
" model training function."
)
clf = linear_model.Lasso(alpha=kwargs["alpha"] if "alpha" in kwargs else 1.0)
clf.fit(
interpretable_inputs.cpu().numpy(),
expected_outputs.cpu().numpy(),
weights.cpu().numpy(),
)
return torch.from_numpy(clf.coef_)


def default_from_interp_rep_transform(curr_sample, original_inputs, **kwargs):
assert (
"feature_mask" in kwargs
@@ -665,7 +640,7 @@ class Lime(LimeBase):
def __init__(
self,
forward_func: Callable,
train_interpretable_model_func: Callable = lasso_interpretable_model_trainer,
interpretable_model: Model = SkLearnLasso(alpha=1.0),
similarity_func: Callable = get_exp_kernel_similarity_function(),
perturb_func: Callable = default_perturb_func,
) -> None:
@@ -676,31 +651,31 @@ def __init__(

forward_func (callable): The forward function of the model or any
modification of it
train_interpretable_model_func (optional, callable): Function which
trains an interpretable model and returns coefficients
of the interpretable model.
This function is optional, and the default function trains
an interpretable model using Lasso regression, using the
alpha parameter provided when calling attribute.
Using the default function requires having sklearn version
0.23.0 or higher installed.

If a custom function is provided, the expected signature of this
callable is:

train_interpretable_model_func(
interpretable_model (optional, Model): Model object used to train
the interpretable model.

This argument is optional and defaults to SkLearnLasso(alpha=1.0),
which is a wrapper around the Lasso linear model in SkLearn.
This requires having sklearn version >= 0.23 available.

Other predefined interpretable linear models are provided in
captum._utils.models.linear_model.

Alternatively, a custom model object can be provided. It must
expose a `fit` method to train the model, given a dataloader
whose batches contain three tensors:
interpretable_inputs: Tensor
[2D num_samples x num_interp_features],
expected_outputs: Tensor [1D num_samples],
weights: Tensor [1D num_samples]
**kwargs: Any
) -> Tensor [1D num_interp_features]
The return type must be a 1D tensor containing the importance
or attribution of each input feature.

kwargs includes baselines, feature_mask, num_interp_features
(integer, determined from feature mask), and
alpha (for Lasso regression).
The model object must also provide a `representation` method to
access the appropriate coefficients or representation of the
interpretable model after fitting.

Note that calling fit multiple times should retrain the
interpretable model, since each attribution call reuses
the same interpretable model object (a usage sketch follows
this file's diff).
similarity_func (callable): Function which takes a single sample
along with its corresponding interpretable representation
and returns the weight of the interpretable sample for
@@ -764,7 +739,6 @@ def attribute( # type: ignore
feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
n_perturb_samples: int = 25,
perturbations_per_eval: int = 1,
alpha: float = 1.0,
return_input_shape: bool = True,
) -> TensorOrTupleOfTensorsGeneric:
r"""
@@ -1058,7 +1032,6 @@ def attribute( # type: ignore
if is_inputs_tuple
else curr_feature_mask[0],
num_interp_features=num_interp_features,
alpha=alpha,
)
if return_input_shape:
output_list.append(
@@ -1095,7 +1068,6 @@ def attribute( # type: ignore
baselines=baselines if is_inputs_tuple else baselines[0],
feature_mask=feature_mask if is_inputs_tuple else feature_mask[0],
num_interp_features=num_interp_features,
alpha=alpha,
)
if return_input_shape:
return self._convert_output_shape(
@@ -1138,6 +1110,7 @@ def _convert_output_shape(
num_interp_features: int,
is_inputs_tuple: bool,
) -> Union[Tensor, Tuple[Tensor, ...]]:
coefs = coefs.flatten()
attr = [torch.zeros_like(single_inp) for single_inp in formatted_inp]
for tensor_ind in range(len(formatted_inp)):
for single_feature in range(num_interp_features):
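To tie the Lime changes together: the per-call `alpha` kwarg is removed, so regularization strength now lives on the interpretable model object. A hedged usage sketch (toy model and sizes are illustrative):

```python
import torch
import torch.nn as nn
from captum._utils.models.linear_model import SkLearnLasso
from captum.attr import Lime

net = nn.Sequential(nn.Linear(5, 3))  # toy model for illustration
inputs = torch.randn(2, 5)

# alpha is configured on the model object, not passed to attribute():
lime = Lime(net, interpretable_model=SkLearnLasso(alpha=0.5))
attr = lime.attribute(inputs, target=0, n_perturb_samples=200)
```
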
21 changes: 21 additions & 0 deletions sphinx/source/utilities.rst
@@ -24,3 +24,24 @@ Token Reference Base

.. autoclass:: captum.attr.TokenReferenceBase
:members:


Linear Models
^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: captum._utils.models.model.Model
:members:
.. autoclass:: captum._utils.models.linear_model.SkLearnLinearModel
:members:
.. autoclass:: captum._utils.models.linear_model.SkLearnLinearRegression
:members:
.. autoclass:: captum._utils.models.linear_model.SkLearnLasso
:members:
.. autoclass:: captum._utils.models.linear_model.SkLearnRidge
:members:
.. autoclass:: captum._utils.models.linear_model.SGDLinearModel
:members:
.. autoclass:: captum._utils.models.linear_model.SGDLasso
:members:
.. autoclass:: captum._utils.models.linear_model.SGDRidge
:members:
10 changes: 6 additions & 4 deletions tests/attr/test_kernel_shap.py
@@ -23,10 +23,12 @@ def setUp(self) -> None:
super().setUp()
try:
import sklearn # noqa: F401
except ImportError:
raise unittest.SkipTest(
"Skipping Kernel Shap tests, sklearn not available."
)

assert (
sklearn.__version__ >= "0.23.0"
), "Must have sklearn version 0.23.0 or higher"
except (ImportError, AssertionError):
raise unittest.SkipTest("Skipping KernelShap tests, sklearn not available.")

def test_linear_kernel_shap(self) -> None:
net = BasicModel_MultiLayer()
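One caveat in the guard above: comparing version strings lexicographically can misorder releases (e.g. `"0.3.0" >= "0.23.0"` is True as strings). A hedged alternative using the third-party `packaging` helper (an assumption here, not a dependency this PR adds):

```python
import unittest

from packaging import version  # third-party; assumed available


def require_sklearn(min_version: str = "0.23.0") -> None:
    """Skip the calling test unless a new-enough sklearn is importable."""
    try:
        import sklearn
    except ImportError:
        raise unittest.SkipTest("sklearn not available.")
    if version.parse(sklearn.__version__) < version.parse(min_version):
        raise unittest.SkipTest(
            "sklearn >= %s required, found %s" % (min_version, sklearn.__version__)
        )
```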