Feature/reagent branch for ReAGent #250

Merged · 17 commits · Apr 13, 2024
8 changes: 7 additions & 1 deletion README.md
@@ -149,6 +149,8 @@ Use the `inseq.list_feature_attribution_methods` function to list all available

- `value_zeroing`: [Quantifying Context Mixing in Transformers](https://aclanthology.org/2023.eacl-main.245/) (Mohebbi et al. 2023)

- `reagent`: [ReAGent: A Model-agnostic Feature Attribution Method for Generative Language Models](https://arxiv.org/abs/2402.00794) (Zhao et al., 2024)

#### Step functions

Step functions are used to extract custom scores from the model at each step of the attribution process with the `step_scores` argument in `model.attribute`. They can also be used as targets for attribution methods relying on model outputs (e.g. gradient-based methods) by passing them as the `attributed_fn` argument. The following step functions are currently supported:
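For instance, a minimal sketch (model and prompt are illustrative) that records the built-in `probability` step score alongside a saliency attribution:

```python
import inseq

# Load a decoder-only model with the saliency attribution method
model = inseq.load_model("gpt2", "saliency")

# Request the target token probability at every generation step
# in addition to the attribution scores
out = model.attribute(
    "The capital of France is",
    step_scores=["probability"],
)
out.show()
```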
@@ -303,7 +305,10 @@ If you use Inseq in your research we suggest to include a mention to the specifi

## Research using Inseq

Inseq has been used in various research projects. A list of known publications that use Inseq to conduct interpretability analyses of generative models is shown below. If you know more, please let us know or submit a pull request (*last updated: February 2024*).
Inseq has been used in various research projects. A list of known publications that use Inseq to conduct interpretability analyses of generative models is shown below.

> [!TIP]
> Last update: April 2024. Please open a pull request to add your publication to the list.

<details>
<summary><b>2023</b></summary>
@@ -324,6 +329,7 @@ Inseq has been used in various research projects. A list of known publications t
<ol>
<li><a href="https://arxiv.org/abs/2401.12576">LLMCheckup: Conversational Examination of Large Language Models via Interpretability Tools</a> (Wang et al., 2024)</li>
<li><a href="https://arxiv.org/abs/2402.00794">ReAGent: A Model-agnostic Feature Attribution Method for Generative Language Models</a> (Zhao et al., 2024)</li>
<li><a href="https://arxiv.org/abs/2404.02421">Revisiting subword tokenization: A case study on affixal negation in large language models</a> (Truong et al., 2024)</li>
</ol>

</details>
23 changes: 22 additions & 1 deletion docs/source/main_classes/feature_attribution.rst
@@ -90,4 +90,25 @@ Perturbation-based Attribution Methods
:members:

.. autoclass:: inseq.attr.feat.ValueZeroingAttribution
:members:
:members:

.. autoclass:: inseq.attr.feat.ReagentAttribution
:members:

.. automethod:: __init__

.. code:: python

import inseq

model = inseq.load_model(
"gpt2-medium",
"reagent",
keep_top_n=5,
stopping_condition_top_k=3,
replacing_ratio=0.3,
max_probe_steps=3000,
num_probes=8
)
out = model.attribute("Super Mario Land is a game that developed by")
out.show()
2 changes: 2 additions & 0 deletions inseq/attr/feat/__init__.py
@@ -18,6 +18,7 @@
LimeAttribution,
OcclusionAttribution,
PerturbationAttributionRegistry,
ReagentAttribution,
ValueZeroingAttribution,
)

@@ -43,4 +44,5 @@
"SequentialIntegratedGradientsAttribution",
"ValueZeroingAttribution",
"PerturbationAttributionRegistry",
"ReagentAttribution",
]
2 changes: 2 additions & 0 deletions inseq/attr/feat/ops/__init__.py
@@ -1,6 +1,7 @@
from .discretized_integrated_gradients import DiscretetizedIntegratedGradients
from .lime import Lime
from .monotonic_path_builder import MonotonicPathBuilder
from .reagent import Reagent
from .sequential_integrated_gradients import SequentialIntegratedGradients
from .value_zeroing import ValueZeroing

@@ -9,5 +10,6 @@
"MonotonicPathBuilder",
"ValueZeroing",
"Lime",
"Reagent",
"SequentialIntegratedGradients",
]
134 changes: 134 additions & 0 deletions inseq/attr/feat/ops/reagent.py
@@ -0,0 +1,134 @@
from typing import TYPE_CHECKING, Any, Union

import torch
from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric
from torch import Tensor
from typing_extensions import override

from ....utils.typing import InseqAttribution
from .reagent_core import (
AggregateRationalizer,
DeltaProbImportanceScoreEvaluator,
POSTagTokenSampler,
TopKStoppingConditionEvaluator,
UniformTokenReplacer,
)

if TYPE_CHECKING:
from ....models import HuggingfaceModel


class Reagent(InseqAttribution):
r"""Recursive attribution generator (ReAGent) method.

Measures importance as the drop in prediction probability produced by replacing a token with a plausible
alternative predicted by an LM.

Reference implementation:
`ReAGent: A Model-agnostic Feature Attribution Method for Generative Language Models
<https://arxiv.org/abs/2402.00794>`__

Args:
forward_func (callable): The forward function of the model or any modification of it
keep_top_n (int): If set to a value greater than 0, the top n tokens based on their importance score will be
kept during inference. If set to 0, the number of tokens to keep is determined by ``keep_ratio``.
keep_ratio (float): If ``keep_top_n`` is set to 0, this specifies the proportion of tokens to keep.
invert_keep (bool): If set to ``True``, the top tokens selected either via ``keep_top_n`` or ``keep_ratio`` will be
replaced instead of being kept.
stopping_condition_top_k (int): The stopping condition is reached when the predicted target token appears
among the top k predictions.
replacing_ratio (float): Ratio of tokens to replace in each probing step.
max_probe_steps (int): Maximum number of probing steps.
num_probes (int): Number of probes run in parallel.

Example:
```
import inseq

model = inseq.load_model("gpt2-medium", "reagent",
keep_top_n=5,
stopping_condition_top_k=3,
replacing_ratio=0.3,
max_probe_steps=3000,
num_probes=8
)
out = model.attribute("Super Mario Land is a game that developed by")
out.show()
```
"""

def __init__(
self,
attribution_model: "HuggingfaceModel",
keep_top_n: int = 5,
keep_ratio: float = None,
invert_keep: bool = False,
stopping_condition_top_k: int = 3,
replacing_ratio: float = 0.3,
max_probe_steps: int = 3000,
num_probes: int = 16,
) -> None:
super().__init__(attribution_model)

model = attribution_model.model
tokenizer = attribution_model.tokenizer
model_name = attribution_model.model_name

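# POS-tag-aware sampler that proposes plausible replacement tokens for perturbation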
sampler = POSTagTokenSampler(tokenizer=tokenizer, identifier=model_name, device=attribution_model.device)
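# Stopping condition: the target token is recovered among the model's top-k predictions
# when only the highest-scoring tokens are kept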
stopping_condition_evaluator = TopKStoppingConditionEvaluator(
model=model,
sampler=sampler,
top_k=stopping_condition_top_k,
keep_top_n=keep_top_n,
keep_ratio=keep_ratio,
invert_keep=invert_keep,
)
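# Importance scores measured as the change in target probability when tokens are replaced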
importance_score_evaluator = DeltaProbImportanceScoreEvaluator(
model=model,
tokenizer=tokenizer,
token_replacer=UniformTokenReplacer(sampler=sampler, ratio=replacing_ratio),
stopping_condition_evaluator=stopping_condition_evaluator,
max_steps=max_probe_steps,
)

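# Aggregate importance scores across num_probes parallel probing runs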
self.rationalizer = AggregateRationalizer(
importance_score_evaluator=importance_score_evaluator,
batch_size=num_probes,
overlap_threshold=0,
overlap_strict_pos=True,
keep_top_n=keep_top_n,
keep_ratio=keep_ratio,
)

@override
def attribute( # type: ignore
self,
inputs: TensorOrTupleOfTensorsGeneric,
_target: TargetType = None,
additional_forward_args: Any = None,
) -> Union[
TensorOrTupleOfTensorsGeneric,
tuple[TensorOrTupleOfTensorsGeneric, Tensor],
]:
"""Implement attribute"""
# encoder-decoder
if self.forward_func.is_encoder_decoder:
# with target-side attribution
if len(inputs) > 1:
self.rationalizer(
additional_forward_args[0], additional_forward_args[2], additional_forward_args[1], True
)
mean_importance_score = torch.unsqueeze(self.rationalizer.mean_importance_score, 0)
res = torch.unsqueeze(mean_importance_score, 2).repeat(1, 1, inputs[0].shape[2])
return (
res[:, : additional_forward_args[0].shape[1], :],
res[:, additional_forward_args[0].shape[1] :, :],
)
# source-side only
else:
self.rationalizer(additional_forward_args[1], additional_forward_args[3], additional_forward_args[2])
# decoder-only
self.rationalizer(additional_forward_args[0], additional_forward_args[1])
mean_importance_score = torch.unsqueeze(self.rationalizer.mean_importance_score, 0)
res = torch.unsqueeze(mean_importance_score, 2).repeat(1, 1, inputs[0].shape[2])
return (res,)
13 changes: 13 additions & 0 deletions inseq/attr/feat/ops/reagent_core/__init__.py
@@ -0,0 +1,13 @@
from .importance_score_evaluator import DeltaProbImportanceScoreEvaluator
from .rationalizer import AggregateRationalizer
from .stopping_condition_evaluator import TopKStoppingConditionEvaluator
from .token_replacer import UniformTokenReplacer
from .token_sampler import POSTagTokenSampler

__all__ = [
"DeltaProbImportanceScoreEvaluator",
"AggregateRationalizer",
"TopKStoppingConditionEvaluator",
"UniformTokenReplacer",
"POSTagTokenSampler",
]