Skip to content

Commit

Permalink
Fix LIME and Occlusion outputs (#220)
Browse files Browse the repository at this point in the history
  • Loading branch information
gsarti committed Oct 6, 2023
1 parent 8520469 commit 72febbb
Show file tree
Hide file tree
Showing 6 changed files with 943 additions and 790 deletions.
13 changes: 9 additions & 4 deletions inseq/attr/feat/perturbation_attribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from captum.attr import Occlusion

from ...data import (
CoarseFeatureAttributionSequenceOutput,
CoarseFeatureAttributionStepOutput,
GranularFeatureAttributionStepOutput,
)
from ...utils import Registry
Expand Down Expand Up @@ -46,7 +46,7 @@ def attribute_step(
self,
attribute_fn_main_args: Dict[str, Any],
attribution_args: Dict[str, Any] = {},
) -> CoarseFeatureAttributionSequenceOutput:
) -> CoarseFeatureAttributionStepOutput:
r"""Sliding window shapes is defined as a tuple.
First entry is between 1 and length of input.
Second entry is given by the embedding dimension of the underlying model.
Expand Down Expand Up @@ -76,7 +76,7 @@ def attribute_step(
if target_attributions is not None:
target_attributions = target_attributions[:, :, 0].abs()

return CoarseFeatureAttributionSequenceOutput(
return CoarseFeatureAttributionStepOutput(
source_attributions=source_attributions,
target_attributions=target_attributions,
)
Expand Down Expand Up @@ -111,4 +111,9 @@ def attribute_step(
raise NotImplementedError(
"LIME attribution with attribute_target=True currently not supported for encoder-decoder models."
)
super().attribute_step(attribute_fn_main_args, attribution_args)
out = super().attribute_step(attribute_fn_main_args, attribution_args)
return GranularFeatureAttributionStepOutput(
source_attributions=out.source_attributions,
target_attributions=out.target_attributions,
sequence_scores=out.sequence_scores,
)
3 changes: 2 additions & 1 deletion inseq/data/attribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def from_step_attributions(
"""
attr = attributions[0]
num_sequences = len(attr.prefix)
if not all([len(attr.prefix) == num_sequences for attr in attributions]):
if not all(len(attr.prefix) == num_sequences for attr in attributions):
raise ValueError("All the attributions must include the same number of sequences.")
seq_attributions = []
sources = None
Expand Down Expand Up @@ -716,6 +716,7 @@ class CoarseFeatureAttributionSequenceOutput(FeatureAttributionSequenceOutput):

def __post_init__(self):
super().__post_init__()
self._aggregator = []


@dataclass(eq=False, repr=False)
Expand Down
Loading

0 comments on commit 72febbb

Please sign in to comment.