Skip to content

Commit

Permalink
Updated OcclusionAttribution with new default params and better class…
Browse files Browse the repository at this point in the history
… structure.
  • Loading branch information
nfelnlp committed Dec 17, 2022
1 parent cc49afc commit 000aac5
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 83 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ default_stages: [commit, push]

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.3.0
rev: v4.4.0
hooks:
- id: trailing-whitespace
- id: check-yaml
Expand Down
2 changes: 1 addition & 1 deletion inseq/attr/feat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
LayerIntegratedGradientsAttribution,
SaliencyAttribution,
)
from .occlusion import OcclusionAttribution
from .perturbation_attribution import OcclusionAttribution


__all__ = [
Expand Down
81 changes: 0 additions & 81 deletions inseq/attr/feat/occlusion.py

This file was deleted.

74 changes: 74 additions & 0 deletions inseq/attr/feat/perturbation_attribution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from typing import Any, Dict

import logging

from captum.attr import Occlusion

from ...data import PerturbationFeatureAttributionStepOutput
from ...utils import Registry
from ..attribution_decorators import set_hook, unset_hook
from .attribution_utils import get_source_target_attributions
from .gradient_attribution import FeatureAttribution


logger = logging.getLogger(__name__)


class PerturbationMethodRegistry(FeatureAttribution, Registry):
"""Occlusion-based attribution methods."""

@set_hook
def hook(self, **kwargs):
pass

@unset_hook
def unhook(self, **kwargs):
pass


class OcclusionAttribution(PerturbationMethodRegistry):
"""Occlusion-based attribution method.
Reference implementation:
`https://captum.ai/api/occlusion.html <https://captum.ai/api/occlusion.html>`__.
Usages in other implementations:
`niuzaisheng/AttExplainer <https://github.com/niuzaisheng/AttExplainer/blob/main/baseline_methods/\
explain_baseline_captum.py>`__
`andrewPoulton/explainable-asag <https://github.com/andrewPoulton/explainable-asag/blob/main/explanation.py>`__
`copenlu/xai-benchmark <https://github.com/copenlu/xai-benchmark/blob/master/saliency_gen/\
interpret_grads_occ.py>`__
`DFKI-NLP/thermostat <https://github.com/DFKI-NLP/thermostat/blob/main/src/thermostat/explainers/occlusion.py>`__
"""

method_name = "occlusion"

def __init__(self, attribution_model, **kwargs):
super().__init__(attribution_model)
self.is_layer_attribution = False
self.method = Occlusion(self.attribution_model)

def attribute_step(
self,
attribute_fn_main_args: Dict[str, Any],
attribution_args: Dict[str, Any] = {},
) -> Any:

if "sliding_window_shapes" not in attribution_args:
# Sliding window shapes is defined as a tuple
# First entry is between 1 and length of input
# Second entry is given by the max length of the underlying model
# If not explicitly given via attribution_args, the default is (1, model_max_length)
attribution_args["sliding_window_shapes"] = (1, self.attribution_model.model_max_length)

attr = self.method.attribute(
**attribute_fn_main_args,
**attribution_args,
)

source_attributions, target_attributions = get_source_target_attributions(
attr, self.attribution_model.is_encoder_decoder
)
return PerturbationFeatureAttributionStepOutput(
source_attributions=source_attributions,
target_attributions=target_attributions,
)
2 changes: 2 additions & 0 deletions inseq/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
FeatureAttributionSequenceOutput,
FeatureAttributionStepOutput,
GradientFeatureAttributionStepOutput,
PerturbationFeatureAttributionStepOutput,
)
from .batch import Batch, BatchEmbedding, BatchEncoding, DecoderOnlyBatch, EncoderDecoderBatch
from .viz import show_attributions
Expand All @@ -32,6 +33,7 @@
"FeatureAttributionInput",
"FeatureAttributionStepOutput",
"GradientFeatureAttributionStepOutput",
"PerturbationFeatureAttributionStepOutput",
"FeatureAttributionSequenceOutput",
"FeatureAttributionOutput",
"ModelIdentifier",
Expand Down
20 changes: 20 additions & 0 deletions inseq/data/attribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,3 +474,23 @@ class GradientFeatureAttributionStepOutput(FeatureAttributionStepOutput):
"""

_sequence_cls: Type["FeatureAttributionSequenceOutput"] = GradientFeatureAttributionSequenceOutput


# Perturbation attribution classes


@dataclass(eq=False, repr=False)
class PerturbationFeatureAttributionSequenceOutput(FeatureAttributionSequenceOutput):
"""Raw output of a single sequence of perturbation feature attribution."""

def __post_init__(self):
super().__post_init__()
self._dict_aggregate_fn["source_attributions"]["sequence_aggregate"] = sum_normalize_attributions
self._dict_aggregate_fn["target_attributions"]["sequence_aggregate"] = sum_normalize_attributions


@dataclass(eq=False, repr=False)
class PerturbationFeatureAttributionStepOutput(FeatureAttributionStepOutput):
"""Raw output of a single step of perturbation feature attribution."""

_sequence_cls: Type["FeatureAttributionSequenceOutput"] = PerturbationFeatureAttributionSequenceOutput

0 comments on commit 000aac5

Please sign in to comment.