Value Zeroing attribution method (#173)
gsarti committed Feb 28, 2024
1 parent f4d82b8 commit d09e827
Showing 30 changed files with 1,252 additions and 174 deletions.
5 changes: 4 additions & 1 deletion CHANGELOG.md
@@ -6,6 +6,8 @@

- Support for multi-GPU attribution ([#238](https://github.com/inseq-team/inseq/pull/238))
- Added `inseq attribute-context` CLI command to support the [PECoRe framework] for detecting and attributing context reliance in generative LMs ([#237](https://github.com/inseq-team/inseq/pull/237))
- Added `value_zeroing` (`inseq.attr.feat.perturbation_attribution.ValueZeroingAttribution`) attribution method ([#173](https://github.com/inseq-team/inseq/pull/173))
- `value_zeroing` and `attention` use scores from the last generation step to produce outputs more efficiently (`is_final_step_method = True`) ([#173](https://github.com/inseq-team/inseq/pull/173)).

## 🔧 Fixes & Refactoring

@@ -26,4 +28,5 @@

## 💥 Breaking Changes

*No changes*
- If `attention` is used as the attribution method in `model.attribute`, `step_scores` cannot be extracted at the same time, since the method no longer requires iterating over the full sequence ([#173](https://github.com/inseq-team/inseq/pull/173)). As an alternative, step scores can be extracted separately using the `dummy` attribution method (i.e. no attribution).
- BOS is always included in target-side attribution and generated sequences if present. ([#173](https://github.com/inseq-team/inseq/pull/173))
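A rough usage sketch of the entries above (the model name and inputs are illustrative, and `probability` is assumed to be among the built-in step functions):

```python
import inseq

# Value Zeroing is loaded like any other attribution method.
model = inseq.load_model("gpt2", "value_zeroing")
out = model.attribute("The manager greeted the interns because they")
out.show()

# value_zeroing and attention are final-step methods, so step scores cannot be
# requested in the same attribute() call; they can be recovered separately with
# the no-op "dummy" attribution method mentioned in the breaking changes above.
dummy_model = inseq.load_model("gpt2", "dummy")
step_out = dummy_model.attribute(
    "The manager greeted the interns because they",
    step_scores=["probability"],
)
```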
2 changes: 2 additions & 0 deletions README.md
@@ -147,6 +147,8 @@ Use the `inseq.list_feature_attribution_methods` function to list all available

- `lime`: ["Why Should I Trust You?": Explaining the Predictions of Any Classifier](https://arxiv.org/abs/1602.04938) (Ribeiro et al., 2016)

- `value_zeroing`: [Quantifying Context Mixing in Transformers](https://aclanthology.org/2023.eacl-main.245/) (Mohebbi et al., 2023)

#### Step functions

Step functions are used to extract custom scores from the model at each step of the attribution process with the `step_scores` argument in `model.attribute`. They can also be used as targets for attribution methods relying on model outputs (e.g. gradient-based methods) by passing them as the `attributed_fn` argument. The following step functions are currently supported:
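A minimal sketch of the two usages described above (the model name and input are illustrative; `probability` is one of the built-in step functions):

```python
import inseq

model = inseq.load_model("Helsinki-NLP/opus-mt-en-fr", "input_x_gradient")
text = "Hello everyone, welcome to the Inseq library!"

# 1) Extract a custom score at every generation step alongside attribution scores.
out = model.attribute(text, step_scores=["probability"])

# 2) Use the same step function as the target of a gradient-based attribution method.
out_prob_target = model.attribute(text, attributed_fn="probability")
```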
19 changes: 17 additions & 2 deletions docs/source/main_classes/feature_attribution.rst
@@ -17,7 +17,7 @@ Attribution Methods
.. autoclass:: inseq.attr.FeatureAttribution
:members:

Gradient Attribution Methods
Gradient-based Attribution Methods
-----------------------------------------------------------------------------------------------------------------------

.. autoclass:: inseq.attr.feat.GradientAttributionRegistry
@@ -67,7 +67,7 @@ Layer Attribution Methods
:members:


Attention Attribution Methods
Internals-based Attribution Methods
-----------------------------------------------------------------------------------------------------------------------

.. autoclass:: inseq.attr.feat.InternalsAttributionRegistry
@@ -76,3 +76,18 @@

.. autoclass:: inseq.attr.feat.AttentionWeightsAttribution
:members:

Perturbation-based Attribution Methods
-----------------------------------------------------------------------------------------------------------------------

.. autoclass:: inseq.attr.feat.PerturbationAttributionRegistry
:members:

.. autoclass:: inseq.attr.feat.OcclusionAttribution
:members:

.. autoclass:: inseq.attr.feat.LimeAttribution
:members:

.. autoclass:: inseq.attr.feat.ValueZeroingAttribution
:members:
4 changes: 4 additions & 0 deletions inseq/attr/feat/__init__.py
@@ -17,6 +17,8 @@
from .perturbation_attribution import (
LimeAttribution,
OcclusionAttribution,
PerturbationAttributionRegistry,
ValueZeroingAttribution,
)

__all__ = [
@@ -39,4 +41,6 @@
"OcclusionAttribution",
"LimeAttribution",
"SequentialIntegratedGradientsAttribution",
"ValueZeroingAttribution",
"PerturbationAttributionRegistry",
]
8 changes: 6 additions & 2 deletions inseq/attr/feat/attribution_utils.py
@@ -144,11 +144,15 @@ def extract_args(
def get_source_target_attributions(
attr: Union[StepAttributionTensor, tuple[StepAttributionTensor, StepAttributionTensor]],
is_encoder_decoder: bool,
has_sequence_scores: bool = False,
) -> tuple[Optional[StepAttributionTensor], Optional[StepAttributionTensor]]:
if isinstance(attr, tuple):
if is_encoder_decoder:
return (attr[0], attr[1]) if len(attr) > 1 else (attr[0], None)
if has_sequence_scores:
return (attr[0], attr[1], attr[2])
else:
return (attr[0], attr[1]) if len(attr) > 1 else (attr[0], None)
else:
return (None, attr[0])
return (None, None, attr[0]) if has_sequence_scores else (None, attr[0])
else:
return (attr, None) if is_encoder_decoder else (None, attr)
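A toy illustration of how the updated helper splits attribution tuples (tensor shapes are arbitrary, and the import path simply mirrors the file shown above):

```python
import torch

from inseq.attr.feat.attribution_utils import get_source_target_attributions

src, tgt = torch.rand(1, 5), torch.rand(1, 3)

# Encoder-decoder, plain (source, target) pair.
source, target = get_source_target_attributions((src, tgt), is_encoder_decoder=True)

# Decoder-only models have no source-side attributions.
source, target = get_source_target_attributions((tgt,), is_encoder_decoder=False)  # source is None

# With has_sequence_scores=True, a third element is passed through unchanged.
seq_scores = torch.rand(1, 3, 3)
source, target, sequence = get_source_target_attributions(
    (src, tgt, seq_scores), is_encoder_decoder=True, has_sequence_scores=True
)
```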
49 changes: 43 additions & 6 deletions inseq/attr/feat/feature_attribution.py
@@ -114,6 +114,7 @@ def __init__(self, attribution_model: "AttributionModel", hook_to_model: bool =
self.use_hidden_states: bool = False
self.use_predicted_target: bool = True
self.use_model_config: bool = False
self.is_final_step_method: bool = False
if hook_to_model:
self.hook(**kwargs)

@@ -272,6 +273,35 @@ def _run_compatibility_checks(self, attributed_fn) -> None:
" method."
)

@staticmethod
def _build_multistep_output_from_single_step(
single_step_output: FeatureAttributionStepOutput,
attr_pos_start: int,
attr_pos_end: int,
) -> list[FeatureAttributionStepOutput]:
if single_step_output.step_scores:
raise ValueError("step_scores are not supported for final step attribution methods.")
num_seq = len(single_step_output.prefix)
steps = []
for pos_idx in range(attr_pos_start, attr_pos_end):
step_output = single_step_output.clone_empty()
step_output.source = single_step_output.source
step_output.prefix = [single_step_output.prefix[seq_idx][:pos_idx] for seq_idx in range(num_seq)]
step_output.target = (
single_step_output.target
if pos_idx == attr_pos_end - 1
else [[single_step_output.prefix[seq_idx][pos_idx]] for seq_idx in range(num_seq)]
)
if single_step_output.source_attributions is not None:
step_output.source_attributions = single_step_output.source_attributions[:, :, pos_idx - 1]
if single_step_output.target_attributions is not None:
step_output.target_attributions = single_step_output.target_attributions[:, :pos_idx, pos_idx - 1]
single_step_output.step_scores = {}
if single_step_output.sequence_scores is not None:
step_output.sequence_scores = single_step_output.sequence_scores
steps.append(step_output)
return steps

def format_contrastive_targets(
self,
target_sequences: TextSequences,
@@ -416,9 +446,9 @@ def attribute(
target_lengths=targets_lengths,
method_name=self.method_name,
show=show_progress,
pretty=pretty_progress,
pretty=False if self.is_final_step_method else pretty_progress,
attr_pos_start=attr_pos_start,
attr_pos_end=attr_pos_end,
attr_pos_end=1 if self.is_final_step_method else attr_pos_end,
)
whitespace_indexes = find_char_indexes(sequences.targets, " ")
attribution_outputs = []
@@ -427,6 +457,8 @@

# Attribution loop for generation
for step in range(attr_pos_start, iter_pos_end):
if self.is_final_step_method and step != iter_pos_end - 1:
continue
tgt_ids, tgt_mask = batch.get_step_target(step, with_attention=True)
step_output = self.filtered_attribute_step(
batch[:step],
@@ -450,7 +482,7 @@
contrast_targets_alignments=contrast_targets_alignments,
)
attribution_outputs.append(step_output)
if pretty_progress:
if pretty_progress and not self.is_final_step_method:
tgt_tokens = batch.target_tokens
skipped_prefixes = tok2string(self.attribution_model, tgt_tokens, end=attr_pos_start)
attributed_sentences = tok2string(self.attribution_model, tgt_tokens, attr_pos_start, step + 1)
@@ -471,12 +503,17 @@
end = datetime.now()
close_progress_bar(pbar, show=show_progress, pretty=pretty_progress)
batch.detach().to("cpu")
if self.is_final_step_method:
attribution_outputs = self._build_multistep_output_from_single_step(
attribution_outputs[0],
attr_pos_start=attr_pos_start,
attr_pos_end=iter_pos_end,
)
out = FeatureAttributionOutput(
sequence_attributions=FeatureAttributionSequenceOutput.from_step_attributions(
attributions=attribution_outputs,
tokenized_target_sentences=target_tokens_with_ids,
pad_id=self.attribution_model.pad_token,
has_bos_token=self.attribution_model.is_encoder_decoder,
pad_token=self.attribution_model.pad_token,
attr_pos_end=attr_pos_end,
),
step_attributions=attribution_outputs if output_step_attributions else None,
@@ -593,7 +630,7 @@ def filtered_attribute_step(
step_output.step_scores[score] = get_step_scores(score, step_fn_args, step_fn_extra_args).to("cpu")
# Reinsert finished sentences
if target_attention_mask is not None and is_filtered:
step_output.remap_from_filtered(target_attention_mask, orig_batch)
step_output.remap_from_filtered(target_attention_mask, orig_batch, self.is_final_step_method)
step_output = step_output.detach().to("cpu")
return step_output

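A simplified sketch of the re-slicing performed by `_build_multistep_output_from_single_step`, using a plain tensor with a hypothetical shape instead of a `FeatureAttributionStepOutput`:

```python
import torch

# Hypothetical source attributions collected in a single final-step pass:
# (batch_size, source_len, generated_len)
final_source_attr = torch.rand(2, 6, 4)

attr_pos_start, attr_pos_end = 1, 5
per_step_source = [
    final_source_attr[:, :, pos_idx - 1]  # same slice as in the method above
    for pos_idx in range(attr_pos_start, attr_pos_end)
]
assert len(per_step_source) == attr_pos_end - attr_pos_start
assert per_step_source[0].shape == (2, 6)
```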
22 changes: 15 additions & 7 deletions inseq/attr/feat/internals_attribution.py
@@ -16,12 +16,12 @@
import logging
from typing import Any, Optional

import torch
from captum._utils.typing import TensorOrTupleOfTensorsGeneric
from captum.attr._utils.attribution import Attribution

from ...data import MultiDimensionalFeatureAttributionStepOutput
from ...utils import Registry
from ...utils.typing import MultiLayerMultiUnitScoreTensor
from ...utils.typing import InseqAttribution, MultiLayerMultiUnitScoreTensor
from .feature_attribution import FeatureAttribution

logger = logging.getLogger(__name__)
@@ -38,7 +38,7 @@ class AttentionWeightsAttribution(InternalsAttributionRegistry):

method_name = "attention"

class AttentionWeights(Attribution):
class AttentionWeights(InseqAttribution):
@staticmethod
def has_convergence_delta() -> bool:
return False
@@ -74,9 +74,14 @@ def attribute(
:class:`~inseq.data.MultiDimensionalFeatureAttributionStepOutput`: A step output containing attention
weights for each layer and head, with shape :obj:`(batch_size, seq_len, n_layers, n_heads)`.
"""
# We adopt the format [batch_size, sequence_length, num_layers, num_heads]
# We adopt the format [batch_size, sequence_length, sequence_length, num_layers, num_heads]
# for consistency with other multi-unit methods (e.g. gradient attribution)
decoder_self_attentions = decoder_self_attentions[..., -1, :].to("cpu").clone().permute(0, 3, 1, 2)
decoder_self_attentions = decoder_self_attentions.to("cpu").clone().permute(0, 4, 3, 1, 2)
decoder_self_attentions = torch.where(
decoder_self_attentions == 0,
(torch.ones_like(decoder_self_attentions) * float("nan")),
decoder_self_attentions,
)
if self.forward_func.is_encoder_decoder:
sequence_scores = {}
if len(inputs) > 1:
@@ -85,10 +90,11 @@
target_attributions = None
sequence_scores["decoder_self_attentions"] = decoder_self_attentions
sequence_scores["encoder_self_attentions"] = (
encoder_self_attentions.to("cpu").clone().permute(0, 3, 4, 1, 2)
encoder_self_attentions.to("cpu").clone().permute(0, 4, 3, 1, 2)
)
cross_attentions = cross_attentions.to("cpu").clone().permute(0, 4, 3, 1, 2)
return MultiDimensionalFeatureAttributionStepOutput(
source_attributions=cross_attentions[..., -1, :].to("cpu").clone().permute(0, 3, 1, 2),
source_attributions=cross_attentions,
target_attributions=target_attributions,
sequence_scores=sequence_scores,
_num_dimensions=2, # num_layers, num_heads
@@ -106,6 +112,8 @@ def __init__(self, attribution_model, **kwargs):
self.use_attention_weights = True
# Does not rely on predicted output (i.e. decoding strategy agnostic)
self.use_predicted_target = False
# Needs only the final generation step to extract scores
self.is_final_step_method = True
self.method = self.AttentionWeights(attribution_model)

def attribute_step(
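A brief user-side sketch of the updated `attention` method (the model and input are illustrative):

```python
import inseq

model = inseq.load_model("gpt2", "attention")

# A single pass over the final generation step is used (is_final_step_method=True),
# and attention weights that are exactly zero (e.g. masked positions) are turned
# into NaN, as in the diff above. Per the breaking changes, step_scores cannot be
# requested together with this method.
out = model.attribute("The capital of France is")
out.show()
```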
2 changes: 2 additions & 0 deletions inseq/attr/feat/ops/__init__.py
@@ -2,10 +2,12 @@
from .lime import Lime
from .monotonic_path_builder import MonotonicPathBuilder
from .sequential_integrated_gradients import SequentialIntegratedGradients
from .value_zeroing import ValueZeroing

__all__ = [
"DiscretetizedIntegratedGradients",
"MonotonicPathBuilder",
"ValueZeroing",
"Lime",
"SequentialIntegratedGradients",
]