Add model config (#216)
gsarti committed Aug 14, 2023
1 parent f8e55f8 commit b6b8b13
Showing 12 changed files with 194 additions and 234 deletions.
2 changes: 2 additions & 0 deletions docs/source/main_classes/main_functions.rst
@@ -26,6 +26,8 @@ functionalities required for its usage.
 
 .. autofunction:: register_step_function
 
+.. autofunction:: register_model_config
+
 .. autofunction:: list_feature_attribution_methods
 
 .. autofunction:: list_aggregators
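
Note: register_model_config is the new public entry point documented above. A minimal usage sketch, assuming it accepts a model class name plus a dict of config fields (the "attention_module" field is illustrative; the actual schema lives in inseq/models/model_config.py):

import inseq

# Hypothetical registration for a custom model class. The name should match
# the class reported by the model (e.g. model.info["model_class"]), and the
# "attention_module" entry is an illustrative config field.
inseq.register_model_config(
    "MyCustomLMHeadModel",
    {"attention_module": "attn"},
)
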
3 changes: 2 additions & 1 deletion inseq/__init__.py
@@ -8,7 +8,7 @@
     merge_attributions,
     show_attributions,
 )
-from .models import AttributionModel, list_supported_frameworks, load_model
+from .models import AttributionModel, list_supported_frameworks, load_model, register_model_config
 from .utils.id_utils import explain


@@ -34,5 +34,6 @@ def get_version() -> str:
     "list_step_functions",
     "list_supported_frameworks",
     "register_step_function",
+    "register_model_config",
     "merge_attributions",
 ]
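
Note: with this export, register_model_config sits alongside the existing package-level entry points. A short sketch of the surrounding API for context (model and method identifiers are arbitrary examples):

import inseq

# Load a supported Hugging Face model with an attribution method attached,
# then attribute a generation and visualize the result.
model = inseq.load_model("gpt2", "attention")
out = model.attribute("Hello world", generation_args={"max_new_tokens": 5})
out.show()
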
42 changes: 38 additions & 4 deletions inseq/attr/feat/feature_attribution.py
@@ -44,7 +44,7 @@
 )
 from ...utils.typing import ModelIdentifier, SingleScorePerStepTensor
 from ..attribution_decorators import batched, set_hook, unset_hook
-from ..step_functions import get_step_scores, get_step_scores_args
+from ..step_functions import get_step_function, get_step_scores, get_step_scores_args
 from .attribution_utils import (
     check_attribute_positions,
     get_source_target_attributions,
@@ -96,6 +96,13 @@ def __init__(self, attribution_model: "AttributionModel", hook_to_model: bool =
             use_baselines (:obj:`bool`, default `False`): Whether a baseline should be used for the attribution method.
             use_attention_weights (:obj:`bool`, default `False`): Whether attention weights are used in the attribution
                 method.
+            use_hidden_states (:obj:`bool`, default `False`): Whether hidden states are used in the attribution method.
+            use_predicted_target (:obj:`bool`, default `True`): Whether the attribution method uses the predicted
+                target for attribution. In case it doesn't, a warning message will be shown if the target is not
+                the default one.
+            use_model_config (:obj:`bool`, default `False`): Whether the attribution method uses the model config. If
+                True, the method will try to load the config matching the model when hooking to the model. Missing
+                configurations can be registered using :meth:`~inseq.models.register_model_config`.
         """
         super().__init__()
         self.attribution_model = attribution_model
@@ -104,6 +111,9 @@ def __init__(self, attribution_model: "AttributionModel", hook_to_model: bool =
         self.target_layer = None
         self.use_baselines: bool = False
         self.use_attention_weights: bool = False
+        self.use_hidden_states: bool = False
+        self.use_predicted_target: bool = True
+        self.use_model_config: bool = False
         if hook_to_model:
             self.hook(**kwargs)
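
Note: the three new flags follow the same opt-in pattern as use_baselines and use_attention_weights. A hypothetical subclass illustrating how a concrete method would set them (not a method shipped with the library):

from inseq.attr.feat import FeatureAttribution

class MyInternalsAttribution(FeatureAttribution):
    """Hypothetical internals-based method using the new flags."""

    def __init__(self, attribution_model, **kwargs):
        super().__init__(attribution_model, hook_to_model=False, **kwargs)
        self.use_hidden_states = True      # request output_hidden_states in the forward pass
        self.use_predicted_target = False  # output-agnostic: a custom attributed_fn only warns
        self.use_model_config = True       # resolve the registered config on hook()
        self.hook(**kwargs)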

@@ -247,6 +257,21 @@ def prepare_and_attribute(
         attribution_output.info["step_scores_args"] = step_scores_args
         return attribution_output
 
+    def _run_compatibility_checks(self, attributed_fn) -> None:
+        default_attributed_fn = get_step_function(self.attribution_model.default_attributed_fn_id)
+        if not self.use_predicted_target and attributed_fn != default_attributed_fn:
+            logger.warning(
+                "Internals attribution methods are output agnostic, since they do not rely on specific output"
+                " targets to compute importance scores. Using a custom attributed function in this context does not"
+                " influence in any way the method's results."
+            )
+        if self.use_model_config and self.attribution_model.is_distributed:
+            raise RuntimeError(
+                "Distributed models are incompatible with attribution methods requiring access to models' internals "
+                "for storing or intervention purposes. Please use a non-distributed model with the current attribution"
+                " method."
+            )
+
     def attribute(
         self,
         batch: Union[DecoderOnlyBatch, EncoderDecoderBatch],
@@ -298,6 +323,7 @@ def attribute(
             raise ValueError(
                 "Layer attribution methods do not support attribute_target=True. Use regular attributions instead."
             )
+        self._run_compatibility_checks(attributed_fn)
         attr_pos_start, attr_pos_end = check_attribute_positions(
             batch.max_generation_length,
             attr_pos_start,
@@ -503,16 +529,20 @@ def filtered_attribute_step(
             forward_batch_embeds=self.forward_batch_embeds,
             use_baselines=self.use_baselines,
         )
-        if len(step_scores) > 0 or self.use_attention_weights:
+        if len(step_scores) > 0 or self.use_attention_weights or self.use_hidden_states:
             with torch.no_grad():
                 output = self.attribution_model.get_forward_output(
                     batch,
                     use_embeddings=self.forward_batch_embeds,
                     output_attentions=self.use_attention_weights,
+                    output_hidden_states=self.use_hidden_states,
                 )
             if self.use_attention_weights:
                 attentions_dict = self.attribution_model.get_attentions_dict(output)
                 attribution_args = {**attribution_args, **attentions_dict}
+            if self.use_hidden_states:
+                hidden_states_dict = self.attribution_model.get_hidden_states_dict(output)
+                attribution_args = {**attribution_args, **hidden_states_dict}
         # Perform attribution step
         step_output = self.attribute_step(
             attribute_main_args,
@@ -572,14 +602,18 @@ def hook(self, **kwargs) -> None:
         r"""Hooks the attribution method to the model. Useful to implement pre-attribution logic
         (e.g. freezing layers, replacing embeddings, raise warnings, etc.).
         """
-        pass
+        from ...models.model_config import get_model_config
+
+        if self.use_model_config and self.attribution_model is not None:
+            self.attribution_model.config = get_model_config(self.attribution_model.info["model_class"])
 
     @unset_hook
     def unhook(self, **kwargs) -> None:
         r"""Unhooks the attribution method from the model. If the model was modified in any way, this
         should restore its initial state.
         """
-        pass
+        if self.use_model_config and self.attribution_model is not None:
+            self.attribution_model.config = None
 
 
 def list_feature_attribution_methods():
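
Note: hook and unhook now attach the registered configuration to the model and clear it on release. The registry behind get_model_config is not shown in this diff; below is a self-contained sketch of the pattern it plausibly follows (names and fields are assumptions, not the actual contents of inseq/models/model_config.py):

from dataclasses import dataclass
from typing import Dict

@dataclass
class ModelConfig:
    # Illustrative field: location of the attention module inside the model tree.
    attention_module: str

MODEL_CONFIGS: Dict[str, ModelConfig] = {}

def register_model_config(model_type: str, config: dict) -> None:
    """Register (or overwrite) the config associated with a model class name."""
    MODEL_CONFIGS[model_type] = ModelConfig(**config)

def get_model_config(model_type: str) -> ModelConfig:
    """Retrieve a registered config, failing loudly when none exists."""
    if model_type not in MODEL_CONFIGS:
        raise KeyError(f"No config registered for {model_type}. Use register_model_config to add one.")
    return MODEL_CONFIGS[model_type]
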
220 changes: 0 additions & 220 deletions inseq/attr/feat/ops/value_zeroing.py

This file was deleted.

15 changes: 10 additions & 5 deletions inseq/attr/step_functions.py
@@ -452,26 +452,31 @@ def check_is_step_function(identifier: str) -> None:
         )
 
 
+def get_step_function(score_identifier: str) -> StepFunction:
+    """Returns the step function corresponding to the provided identifier."""
+    check_is_step_function(score_identifier)
+    return STEP_SCORES_MAP[score_identifier]
+
+
 def get_step_scores(
     score_identifier: str,
     step_fn_args: StepFunctionArgs,
     step_fn_extra_args: Dict[str, Any] = {},
 ) -> SingleScorePerStepTensor:
     """Returns step scores for the target tokens in the batch."""
-    check_is_step_function(score_identifier)
-    return STEP_SCORES_MAP[score_identifier](step_fn_args, **step_fn_extra_args)
+    return get_step_function(score_identifier)(step_fn_args, **step_fn_extra_args)
 
 
 def get_step_scores_args(
     score_identifiers: List[str], kwargs: Dict[str, Any], default_args: Optional[Dict[str, Any]] = None
 ) -> Dict[str, Any]:
     step_scores_args = {}
-    for step_score in score_identifiers:
-        check_is_step_function(step_score)
+    for step_fn_id in score_identifiers:
+        step_fn = get_step_function(step_fn_id)
         step_scores_args.update(
             **extract_signature_args(
                 kwargs,
-                STEP_SCORES_MAP[step_score],
+                step_fn,
                 exclude_args=default_args,
                 return_remaining=False,
             )
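
Note: every identifier lookup now funnels through get_step_function, so functions registered via register_step_function resolve consistently wherever scores are computed. A hedged sketch (the constant-score function and its signature are illustrative, not part of the library):

import torch

import inseq
from inseq.attr.step_functions import get_step_function

def constant_fn(args, constant: float = 1.0):
    """Illustrative step function assigning a constant score to each target token in the batch."""
    return constant * torch.ones(args.target_ids.shape[0])

inseq.register_step_function(fn=constant_fn, identifier="constant")

# The shared lookup used by get_step_scores and get_step_scores_args can now
# resolve the custom identifier:
fn = get_step_function("constant")
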
(Diffs for the remaining changed files are not rendered on this page.)
