
Add OcclusionAttribution and LimeAttribution #145

Merged: 24 commits merged into main from occlusion on Feb 27, 2023

Conversation

@nfelnlp (Collaborator) commented Oct 24, 2022

Description

Added classes for Occlusion methods and a wrapper for the Captum implementation of Zeiler & Fergus (2013).
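
Example usage (a minimal sketch; the method ids "occlusion" and "lime" are assumptions based on the class names):

import inseq

# Load a model with one of the new perturbation-based attribution methods.
model = inseq.load_model("Helsinki-NLP/opus-mt-en-de", "occlusion")
out = model.attribute("Hello everyone, this is an example.")
out.show()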

Related Issue

Related to #107

Type of Change

  • 🚀 New feature (non-breaking change which adds functionality)

@nfelnlp added the enhancement label on Oct 24, 2022
@github-actions (bot) left a comment:

Hello @nfelnlp, thank you for submitting a PR! We will respond as soon as possible.

[Two resolved review threads on inseq/attr/feat/occlusion.py and inseq/attr/feat/__init__.py (outdated)]

self,
batch: EncoderDecoderBatch,
target_ids: TargetIdsTensor,
attribute_target: bool = False,

A Member commented:

I don't know why I thought that custom attribution functions were only applicable to gradient-based methods, but in theory, there should be no problem in using them for captum.attr.Occlusion since it accepts a forward_func as input. Parameters should be updated here to reflect the changes merged from #138.
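
A minimal sketch of that idea (score_fn is a hypothetical stand-in for the custom attribution target, not an inseq or Captum function):

from captum.attr import Occlusion

# Occlusion only needs a callable mapping inputs to scores, so a custom
# attribution function can be plugged in directly as forward_func.
def forward_func(*inputs):
    return score_fn(*inputs)  # hypothetical scoring helper

occlusion = Occlusion(forward_func)
# occlusion.attribute(inputs, sliding_window_shapes=(1, embed_dim), ...)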

[Three more resolved review threads on inseq/attr/feat/occlusion.py (outdated)]

@gsarti changed the title from "Add OcclusionRegistry and OcclusionAttribution." to "[WIP] Add OcclusionRegistry and OcclusionAttribution." on Oct 24, 2022
@nfelnlp changed the title from "[WIP] Add OcclusionRegistry and OcclusionAttribution." to "[WIP] Add PerturbationMethodRegistry and OcclusionAttribution." on Dec 17, 2022

if "sliding_window_shapes" not in attribution_args:
# Sliding window shapes is defined as a tuple
# First entry is between 1 and length of input
# Second entry is given by the max length of the underlying model

A Member commented:

I'm a bit puzzled by the second entry: the max length you take here is the max generation length that the model can handle, but my understanding was that this would be the hidden_size from the model config, to ensure that there is no partial masking of token embeddings. Could you clarify?

@nfelnlp (author) replied:

You're right, I mistakenly grabbed the first attribute in self.attribution_model that happened to have 512 as its size. This was careless.
The second entry should rather be based on the embedding size, right?
Does accessing it via self.attribution_model.get_embedding_layer() make sense?
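
Something like this, perhaps (a sketch, assuming get_embedding_layer() returns an nn.Embedding):

# Default to occluding one token (all of its embedding dimensions) at a time.
if "sliding_window_shapes" not in attribution_args:
    embed_dim = self.attribution_model.get_embedding_layer().embedding_dim
    attribution_args["sliding_window_shapes"] = (1, embed_dim)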


def __post_init__(self):
super().__post_init__()
self._dict_aggregate_fn["source_attributions"]["sequence_aggregate"] = sum_normalize_attributions

A Member commented:

Do perturbation attributions have shape [attributed_text_length, generated_text_length, hidden_size] like the ones generated by gradient methods? sum_normalize_attributions casts the 3D tensor above to a 2D tensor for visualization, but I thought that for occlusion this wouldn't be needed.

If it is indeed not the case, then we would not need a specific class for PerturbationFeatureAttribution methods and could simply stick to the base FeatureAttributionSequenceOutput and FeatureAttributionStepOutput for the moment.

@nfelnlp (author) replied:

Yes, when I left out this aggregation step from the __post_init__, I had a 3D tensor that resulted in a shape violation here:

if attr.source_attributions is not None:
    assert len(attr.source_attributions.shape) == 2
if attr.target_attributions is not None:
    assert len(attr.target_attributions.shape) == 2

I assume this will apply to other perturbation methods as well.
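
For reference, the aggregation does roughly this (a sketch, not the actual sum_normalize_attributions implementation):

import torch

def sum_normalize(attr: torch.Tensor) -> torch.Tensor:
    # [attributed_len, generated_len, hidden_size] -> [attributed_len, generated_len]
    attr = attr.sum(dim=-1)
    # L2-normalize each generation step's scores over the attributed tokens.
    return attr / attr.norm(p=2, dim=0, keepdim=True)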

@nfelnlp (author) added:

I agree with you: let's have the return objects be FeatureAttributionSequenceOutput and FeatureAttributionStepOutput for now.

@@ -5,7 +5,7 @@ default_stages: [commit, push]

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.3.0
+    rev: v4.4.0

A Member commented:

This should already be in the main branch if you merged main!

@nfelnlp (author) replied:

Sorry, this was an accident. Thanks for pointing it out!

[Resolved review thread on inseq/attr/feat/perturbation_attribution.py (outdated)]

logger = logging.getLogger(__name__)


class PerturbationMethodRegistry(FeatureAttribution, Registry):

A Member commented:

Let's call it PerturbationAttribution to keep it consistent with GradientAttribution. We might want to bulk change them to add the Registry specification at a later time though!



class PerturbationMethodRegistry(FeatureAttribution, Registry):
"""Occlusion-based attribution methods."""

A Member commented:

Change to """Perturbation-based attribution method registry."""

@nfelnlp (author) commented Jan 16, 2023

The occlusion sliding_window_shapes default is now based on the dimension of the attribution model's embedding layer.
GradientShap worked pretty well out of the box.

LimeBase does not, however, and is tricky: with the way perturb_func and token_similarity_kernel are implemented, I end up putting a 3D tensor through the surrogate model (SkLearnLinearModel) and get the following error:

Traceback (most recent call last):
  File "/home/nfel/.cache/pypoetry/virtualenvs/inseq-wcIi0sem-py3.10/lib/python3.10/site-packages/captum/attr/_core/lime.py", line 525, in attribute
    self.interpretable_model.fit(DataLoader(dataset, batch_size=batch_count))
  File "/home/nfel/.cache/pypoetry/virtualenvs/inseq-wcIi0sem-py3.10/lib/python3.10/site-packages/captum/_utils/models/linear_model/model.py", line 270, in fit
    return super().fit(
  File "/home/nfel/.cache/pypoetry/virtualenvs/inseq-wcIi0sem-py3.10/lib/python3.10/site-packages/captum/_utils/models/linear_model/model.py", line 123, in fit
    return self.train_fn(
  File "/home/nfel/.cache/pypoetry/virtualenvs/inseq-wcIi0sem-py3.10/lib/python3.10/site-packages/captum/_utils/models/linear_model/train.py", line 328, in sklearn_train_linear_model
    sklearn_model.fit(x, y, sample_weight=w, **fit_kwargs)
  File "/home/nfel/.cache/pypoetry/virtualenvs/inseq-wcIi0sem-py3.10/lib/python3.10/site-packages/sklearn/linear_model/_ridge.py", line 1126, in fit
    X, y = self._validate_data(
  File "/home/nfel/.cache/pypoetry/virtualenvs/inseq-wcIi0sem-py3.10/lib/python3.10/site-packages/sklearn/base.py", line 554, in _validate_data
    X, y = check_X_y(X, y, **check_params)
  File "/home/nfel/.cache/pypoetry/virtualenvs/inseq-wcIi0sem-py3.10/lib/python3.10/site-packages/sklearn/utils/validation.py", line 1104, in check_X_y
    X = check_array(
  File "/home/nfel/.cache/pypoetry/virtualenvs/inseq-wcIi0sem-py3.10/lib/python3.10/site-packages/sklearn/utils/validation.py", line 913, in check_array
    raise ValueError(
ValueError: Found array with dim 3. Ridge expected <= 2.

Process finished with exit code 1

It's not clear to me yet which dimensions the original and perturbed tensors should have. Also, I wasn't sure how to apply the perturbation (mask) to the 3D tensor, especially in cases with a batch size larger than 1.

@nfelnlp (author) commented Jan 21, 2023

Update on LIME: LimeBase (and the accompanying linear model) does not support attributing more than one example at a time: pytorch/captum#905 (comment)

So it makes sense to handle one 2D tensor at a time. I thought of having this loop inside attribute_step of LimeAttribution, but it feels quite dirty: I unpack the tuples, check for every entry whether it is a tensor, and take its b-th entry, so that self.method.attribute can be called with tensors stripped of their batch dimension.

attrs = []
# Attribute one example at a time, since LimeBase cannot handle batched inputs.
for b in range(len(attribute_fn_main_args['inputs'][0])):
    # Select the b-th entry of every tensor argument, dropping the batch dim.
    single_input = tuple(
        inp[b] if isinstance(inp, torch.Tensor) else inp
        for inp in attribute_fn_main_args['inputs']
    )
    single_additional_forward_args = tuple(
        arg[b] if isinstance(arg, torch.Tensor) else arg
        for arg in attribute_fn_main_args['additional_forward_args']
    )
    single_attribute_fn_main_args = {
        'inputs': single_input,
        'additional_forward_args': single_additional_forward_args,
    }

    single_attr = self.method.attribute(
        **single_attribute_fn_main_args,
        **attribution_args,
    )
    attrs.append(single_attr)
# Restore the batch dimension by stacking the per-example attributions.
attr = torch.stack(attrs, dim=0)

Would you recommend following this path, or should I handle unpacking the batch in other ways, @gsarti? Maybe overriding the attribute function in LimeAttribution is more sensible? Or adding hooks? In DIG, there's also captum.attr._utils.batching._batch_attribution, which sounds like it could help, but I don't quite get it yet. These are the three alternatives I've come up with thus far, so I think it makes sense to ask you first before I proceed.

The error message I get after the above change (putting the examples of a batch through attribute one at a time) is this one:

ValueError: not enough values to unpack (expected 2, got 1)
Traceback (most recent call last):
  File "/home/nfel/.cache/pypoetry/virtualenvs/inseq-wcIi0sem-py3.10/lib/python3.10/site-packages/captum/attr/_core/lime.py", line 479, in attribute
    model_out = self._evaluate_batch(
  File "/home/nfel/.cache/pypoetry/virtualenvs/inseq-wcIi0sem-py3.10/lib/python3.10/site-packages/captum/attr/_core/lime.py", line 535, in _evaluate_batch
    model_out = _run_forward(
  File "/home/nfel/.cache/pypoetry/virtualenvs/inseq-wcIi0sem-py3.10/lib/python3.10/site-packages/captum/_utils/common.py", line 456, in _run_forward
    output = forward_func(
  File "/home/nfel/.cache/pypoetry/virtualenvs/inseq-wcIi0sem-py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/nfel/PycharmProjects/inseq/inseq/models/encoder_decoder.py", line 273, in forward
    output = self.get_forward_output(
  File "/home/nfel/PycharmProjects/inseq/inseq/models/encoder_decoder.py", line 248, in get_forward_output
    return self.model(
  File "/home/nfel/.cache/pypoetry/virtualenvs/inseq-wcIi0sem-py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/nfel/.cache/pypoetry/virtualenvs/inseq-wcIi0sem-py3.10/lib/python3.10/site-packages/transformers/models/marian/modeling_marian.py", line 1444, in forward
    outputs = self.model(
  File "/home/nfel/.cache/pypoetry/virtualenvs/inseq-wcIi0sem-py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/nfel/.cache/pypoetry/virtualenvs/inseq-wcIi0sem-py3.10/lib/python3.10/site-packages/transformers/models/marian/modeling_marian.py", line 1224, in forward
    encoder_outputs = self.encoder(
  File "/home/nfel/.cache/pypoetry/virtualenvs/inseq-wcIi0sem-py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/nfel/.cache/pypoetry/virtualenvs/inseq-wcIi0sem-py3.10/lib/python3.10/site-packages/transformers/models/marian/modeling_marian.py", line 751, in forward
    embed_pos = self.embed_positions(input_shape)
  File "/home/nfel/.cache/pypoetry/virtualenvs/inseq-wcIi0sem-py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/nfel/.cache/pypoetry/virtualenvs/inseq-wcIi0sem-py3.10/lib/python3.10/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/nfel/.cache/pypoetry/virtualenvs/inseq-wcIi0sem-py3.10/lib/python3.10/site-packages/transformers/models/marian/modeling_marian.py", line 136, in forward
    bsz, seq_len = input_ids_shape[:2]
ValueError: not enough values to unpack (expected 2, got 1)

Process finished with exit code 1

The underlying model (Helsinki-NLP/opus-mt-en-de) expects a 3D tensor here, of course. The second challenge right now is to find the right place to turn the single batch entry (2D) back into a 3D tensor.

Modified token_similarity_kernel and perturb_func to work with the current instance-wise attribution:
@staticmethod
def token_similarity_kernel(
    original_input: tuple,
    perturbed_input: tuple,
    perturbed_interpretable_input: torch.Tensor,
    **kwargs,
) -> torch.Tensor:
    original_input_tensor = original_input[0]
    perturbed_input_tensor = perturbed_input[0]
    assert original_input_tensor.shape == perturbed_input_tensor.shape
    # Similarity = fraction of positions left unchanged by the perturbation.
    similarity = torch.sum(original_input_tensor == perturbed_input_tensor) / len(original_input_tensor)
    return similarity

def perturb_func(
    self,
    original_input: tuple,  # always needs to be the last argument before **kwargs due to "partial"
    **kwargs: Any,
) -> tuple:
    """Sampling function."""
    original_input_tensor = original_input[0]
    # Random binary mask: masked positions are replaced with the PAD token id.
    mask = torch.randint(low=0, high=2, size=original_input_tensor.size()).to(self.attribution_model.device)
    perturbed_input = (
        original_input_tensor * mask + (1 - mask) * self.attribution_model.tokenizer.pad_token_id
    )
    # Wrap in a one-element tuple; the original tuple({perturbed_input}) built
    # a set first, which only worked by accident for a single tensor.
    return (perturbed_input,)

There seem to be many ways to make this work, but I haven't found a clean and safe way yet.
Any help is much appreciated!

@nfelnlp (author) commented Jan 25, 2023

I implemented a rough solution for reshaping the 3D tensor into a 2D one for LimeBase to handle.
However, I'm getting weird, almost uniform attribution scores, probably because either the similarity function is incorrect or the reshape I'm doing (see below) is nonsensical. I'm leaning towards the latter.

[Screenshot from 2023-01-25: attribution map with near-uniform scores]

The problem was that the linear model in LimeBase expects a 2D tensor of shape (n_samples x "everything else"), so my idea was to apply .view(-1).unsqueeze(0) to the original 3D tensor (1 x input_dim x embedding_dim).

""" Modification of original attribute function:
Squeeze the batch dimension out of interpretable_inps
-> 2D tensor (n_samples ✕ (input_dim * embedding_dim))
"""
combined_interp_inps = torch.cat([i.view(-1).unsqueeze(dim=0) for i in interpretable_inps]).double()
combined_outputs = (torch.cat(outputs) if len(outputs[0].shape) > 0 else torch.stack(outputs)).double()
combined_sim = (
torch.cat(similarities) if len(similarities[0].shape) > 0 else torch.stack(similarities)
).double()
dataset = TensorDataset(combined_interp_inps, combined_outputs, combined_sim)
self.interpretable_model.fit(DataLoader(dataset, batch_size=batch_count))
""" Second modification:
Reshape of the learned representation
-> 3D tensor (b=1 ✕ input_dim ✕ embedding_dim)
"""
return self.interpretable_model.representation().reshape(inp_tensor.shape)

Do you have a suggestion on what might be the problem?

@nfelnlp (author) commented Jan 26, 2023

LIME is ready for testing. I used the perturb_func implemented in Thermostat which gave me more sensible results.

@nfelnlp (author) commented Jan 29, 2023

LIME results are still somewhat strange. I'm not sure if all values are supposed to be positive.
Implementing an alternative perturbation function might be sensible.
I found one in the Captum tutorial on image and text classification.

@gsarti (Member) commented Feb 5, 2023

Hey @nfelnlp,

I finally had time to start reviewing the perturbation methods; here are some thoughts:

Occlusion

  • The occlusion sliding_window_shapes presently does not account for attribute_target=True, raising the error Must provide sliding window dimensions for each input tensor. Add a check that sets the value to a tuple of tuples (see the Captum docs) when this condition holds and no sliding_window_shapes value is specified.
  • I am not quite sure we are achieving the desired default behavior of blank-out (see e.g. the ALTI baseline implementation) for the occlusion method. What we want is equivalent to FeatureAblation with a default feature_mask separating individual tokens (i.e. [seq_len, embed_dim] with every i-th tensor [1, embed_dim] filled with value i to indicate it should be masked as a single feature; a sketch follows after this list). Could you confirm that the method achieves this behavior, or adapt it to do so?
  • I am not sure that with these changes in place the final output will require aggregation as is currently the case for PerturbationFeatureAttributionSequenceOutput; this will need to be verified.
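
A sketch of the token-level feature mask from the second point (shapes assumed [batch, seq_len, embed_dim]; illustrative, not a final implementation):

import torch

def token_feature_mask(seq_len: int, embed_dim: int) -> torch.Tensor:
    # Every embedding dimension of token i shares feature id i, so each token
    # is ablated as a single unit by FeatureAblation.
    return torch.arange(seq_len).unsqueeze(1).expand(seq_len, embed_dim).unsqueeze(0)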

Lime

  • Input values for original and perturbed inputs in Lime.token_similarity_kernel should be made uniform so that both are passed as tuples
  • The Lime method will require a seed argument to allow reproducible training of the surrogate model
  • Since we use the UNK token as the default baseline for IG, it would probably be sensible to also have it as the default mask token in Lime.perturb_func (a sketch follows after this list).
  • Ensure from_interp_rep_transform and to_interp_rep_transform can be user-defined values when calling the LimeBase.__init__ method, using Lime.__init__ params.
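
For the third Lime point, a sketch of a UNK-based default, mirroring the existing perturb_func but swapping pad_token_id for unk_token_id:

def perturb_func(self, original_input: tuple, **kwargs) -> tuple:
    original_input_tensor = original_input[0]
    # Random binary mask; masked positions get the UNK token id instead of PAD.
    mask = torch.randint(low=0, high=2, size=original_input_tensor.size()).to(self.attribution_model.device)
    unk_id = self.attribution_model.tokenizer.unk_token_id
    return (original_input_tensor * mask + (1 - mask) * unk_id,)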

GradientSHAP

  • Consider renaming the method id to gradient_shap to make it more evident that we're approximating SHAP through gradients.

@gsarti added this to the 0.3.4 milestone on Feb 6, 2023
@nfelnlp (author) commented Feb 20, 2023

Thank you very much for your feedback! I think we're approaching a publishable version for this branch.

In Occlusion, I managed to correct the sliding_window_shapes behavior for attribute_target=True.
For LIME, I fixed all points except the one about the seed. I didn't find anything about setting seeds in Captum's implementation of linear models from sklearn. This needs a more thorough investigation into where and how exactly the seed can be set.
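
As a rough general-purpose fallback (an assumption on my side, not tied to Captum's linear-model internals), seeding the usual RNG sources before fitting the surrogate model should make runs repeatable:

import random

import numpy as np
import torch

def set_seed(seed: int) -> None:
    # Fix all common RNG sources so repeated LIME runs draw the same samples.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
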
I also renamed GradientSHAP.

The remaining points (the second and third for Occlusion and the second for LIME) will require more time to investigate and implement.

Let me know what you think of the proposed changes and how close we are to merging. Thanks a lot!

@nfelnlp (author) commented Feb 21, 2023

> I am not sure that with these changes in place the final output will require aggregation as is currently the case for PerturbationFeatureAttributionSequenceOutput; this will need to be verified.

If I left the aggregation (using sum_normalize_attributions) out, I would get the following error for attributing with "lime":

  File "/home/nfel/PycharmProjects/inseq/test.py", line 19, in <module>
    mt_out.show()
  File "/home/nfel/PycharmProjects/inseq/inseq/data/attribution.py", line 506, in show
    attr.show(min_val, max_val, display, return_html, aggregator, **kwargs)
  File "/home/nfel/PycharmProjects/inseq/inseq/data/attribution.py", line 200, in show
    aggregated = self.aggregate(aggregator, **kwargs) if do_aggregation else self
  File "/home/nfel/PycharmProjects/inseq/inseq/data/aggregator.py", line 130, in aggregate
    return aggregator.aggregate(self, **kwargs)
  File "/home/nfel/PycharmProjects/inseq/inseq/data/aggregator.py", line 161, in aggregate
    aggregated = super().aggregate(attr, aggregate_fn=aggregate_fn, **kwargs)
  File "/home/nfel/PycharmProjects/inseq/inseq/data/aggregator.py", line 74, in aggregate
    cls.end_aggregation_hook(aggregated, **kwargs)
  File "/home/nfel/PycharmProjects/inseq/inseq/data/aggregator.py", line 174, in end_aggregation_hook
    assert len(attr.source_attributions.shape) == 2
AssertionError

Does it mean that

  1. that the end_aggregation_hook in the SequenceAttributionAggregator needs a custom implementation for perturbation-based methods,
  2. that the shape of the source_attributions should be different for perturbation-based methods,
  3. or that PerturbationFeatureAttributionSequenceOutput should in fact not inherit from FeatureAttributionSequenceOutput, whose .show() function always has do_aggregation set to True?

@gsarti (Member) commented Feb 21, 2023

The checks in end_aggregation_hook are there to verify that everything has the right shape for visualization, so the fact that it fails at that point simply tells you that the attribution output is indeed 3D, and that you will need to deal with the hidden dimension somehow. Now, whether this dimension needs to be simply squeezed out right inside the attribute_step call by picking one score per embedding (if they are all the same), or aggregated like we do for gradient methods as a post-processing step, depends on the attribution output!
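
In code, the first option amounts to something like this (a sketch; shape assumed [attributed_len, generated_len, hidden_size]):

import torch

def collapse_hidden_dim(attr: torch.Tensor) -> torch.Tensor:
    # If scores are constant along the hidden dimension, keep a single copy;
    # otherwise aggregate, e.g. as the gradient methods do.
    if torch.allclose(attr, attr[..., :1].expand_as(attr)):
        return attr[..., 0]
    return attr.sum(dim=-1)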

@nfelnlp (author) commented Feb 25, 2023

(copied from the comment in the code I added with the last commit)
Captum's _evaluate_batch function for LIME does not account for multiple inputs when encoder-decoder models and attribute_target=True are used. The model output has length two, and an error is raised whether the inputs have length one (a list containing a tuple) or length two (the tuple unpacked from the list). A workaround will be added soon.

@gsarti changed the title from "[WIP] Add PerturbationMethodRegistry and OcclusionAttribution." to "Add OcclusionAttribution and LimeAttribution" on Feb 27, 2023
@gsarti merged commit 8d1f602 into main on Feb 27, 2023
@gsarti deleted the occlusion branch on Feb 27, 2023
@gsarti added a commit referencing this pull request on Feb 27, 2023:

* origin/main:
  Add OcclusionAttribution and LimeAttribution (#145)

@gsarti added this to the v0.5 milestone on Jul 21, 2023