Commit
Add top_p_size step fn, StepFunctionArgs class (#206)
gsarti committed Jul 26, 2023
1 parent ea9d982 commit 5ad7a7d
Showing 14 changed files with 338 additions and 277 deletions.
48 changes: 24 additions & 24 deletions docs/source/examples/custom_attribute_target.rst
@@ -40,28 +40,28 @@ with Contrastive Explanations" <https://arxiv.org/abs/2202.10419>`__ by Yin and
 by complementing the output probabilities with the ones from their contrastive counterpart, and using the difference between the two as attribution
 target.
 
-We can define such attribution function using the standard template adopted by Inseq.
+We can define such attribution function using the standard template adopted by Inseq. The :class:`~inseq.attr.step_functions.StepFunctionArgs` class is used for convenience to encapsulate all default arguments passed to step functions, namely:
+
+- :obj:`attribution_model`: the attribution model used to compute attributions.
+- :obj:`forward_output`: the output of the forward pass of the attribution model.
+- :obj:`target_ids`: the ids corresponding to the next predicted tokens for the current generation step.
+- :obj:`ids`, :obj:`embeddings` and :obj:`attention mask` corresponding to the model input at the present step, including inputs for the encoder in case of encoder-decoder models.
 
 .. code-block:: python
 
-    from inseq.attr.step_functions import probability_fn
+    from inseq.attr.step_functions import probability_fn, StepFunctionArgs
 
     # Simplified implementation of inseq.attr.step_functions.contrast_prob_diff_fn
    # Works only for encoder-decoder models!
    def example_prob_diff_fn(
-        # Default arguments in attribution_model.forward
-        attribution_model,
-        forward_output,
-        encoder_input_embeds,
-        encoder_attention_mask,
-        decoder_input_ids,
-        decoder_attention_mask,
-        target_ids,
+        # Default arguments for all step functions
+        args: StepFunctionArgs,
        # Extra arguments for our use case
        contrast_ids,
        contrast_attention_mask,
        # We use kwargs to collect unused default arguments
        **kwargs,
    ):
        """Custom attribution function returning the difference between next step probability for
        candidate generation vs. a contrastive alternative, answering the question "Which features
@@ -73,22 +73,24 @@ We can define such attribution function using the standard template adopted by I
            contrast_attention_mask: Tensor containing the attention mask for the contrastive input
        """
        # We truncate contrastive ids and their attention map to the current generation step
-        contrast_decoder_input_ids = contrast_ids[:, : decoder_input_ids.shape[1]].to(attribution_model.device)
-        contrast_decoder_attention_mask = contrast_attention_mask[:, : decoder_attention_mask.shape[1]].to(
-            attribution_model.device
-        )
+        device = args.attribution_model.device
+        len_inputs = args.decoder_input_ids.shape[1]
+        contrast_decoder_input_ids = contrast_ids[:, : len_inputs].to(device)
+        contrast_decoder_attention_mask = contrast_attention_mask[:, : len_inputs].to(device)
        # We select the next contrastive token as target
-        contrast_target_ids = contrast_ids[:, decoder_input_ids.shape[1]].to(attribution_model.device)
+        contrast_target_ids = contrast_ids[:, len_inputs].to(device)
        # Forward pass with the same model used for the main generation, but using contrastive inputs instead
-        contrast_output = attribution_model.model(
-            inputs_embeds=encoder_input_embeds,
-            attention_mask=encoder_attention_mask,
+        contrast_output = args.attribution_model.model(
+            inputs_embeds=args.encoder_input_embeds,
+            attention_mask=args.encoder_attention_mask,
            decoder_input_ids=contrast_decoder_input_ids,
            decoder_attention_mask=contrast_decoder_attention_mask,
        )
        # Return the prob difference as target for attribution
-        model_probs = probability_fn(attribution_model, forward_output, target_ids)
-        contrast_probs = probability_fn(attribution_model, contrast_output, contrast_target_ids)
+        model_probs = probability_fn(args)
+        args.forward_output = contrast_output
+        args.target_ids = contrast_target_ids
+        contrast_probs = probability_fn(args)
        return model_probs - contrast_probs
Besides common arguments such as the attribution model, its outputs after the forward pass and all the input ids
@@ -101,8 +103,6 @@ Now that we have our custom attribution function, integrating it in Inseq is ver
.. code-block:: python
import inseq
from inseq.data.aggregator import AggregatorPipeline
# Register the function defined above
# Since outputs are still probabilities, contiguous tokens can still be aggregated using product
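For reference, the refactored step function above is typically registered and invoked as follows. This is a minimal sketch assuming the ``register_step_function`` and ``attribute`` APIs documented elsewhere in the Inseq docs; the model name, the ``example_prob_diff`` identifier, the example texts, and the ``encode(..., as_targets=True)`` helper call are illustrative assumptions rather than part of this commit:

    import inseq

    # Load a seq2seq model with a gradient-based attribution method (model name is illustrative)
    model = inseq.load_model("Helsinki-NLP/opus-mt-en-fr", "saliency")

    # Register the custom step function under a new identifier
    inseq.register_step_function(fn=example_prob_diff_fn, identifier="example_prob_diff")

    # Encode a contrastive target to obtain contrast_ids / contrast_attention_mask
    # (assumes the encode(..., as_targets=True) helper used in the original example)
    contrast = model.encode("Je suis une traduction contrastive.", as_targets=True)

    # Use the probability difference between regular and contrastive targets as attribution target
    out = model.attribute(
        "I am an example input.",
        attributed_fn="example_prob_diff",
        contrast_ids=contrast.input_ids,
        contrast_attention_mask=contrast.attention_mask,
    )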
19 changes: 6 additions & 13 deletions docs/source/examples/tuned_lens.rst
@@ -76,10 +76,7 @@ The first step is to install the tuned lens library using ``pip install tuned-le
    def confidence_from_prediction_depth(
        # Default arguments for Inseq step functions
-        attribution_model,
-        decoder_input_ids,
-        decoder_attention_mask,
-        target_ids,
+        args,
        # Extra arguments for our use case
        lens: Lens,
        # We use kwargs to collect unused default arguments
@@ -95,11 +92,7 @@ The first step is to install the tuned lens library using ``pip install tuned-le
        14 is the number of layers in the model, plus the embedding layer, plus 1 to account for the case
        where the token is not predicted by the model.
        """
-        batch = attribution_model.formatter.convert_args_to_batch(
-            decoder_input_ids=decoder_input_ids,
-            decoder_input_embeds=None,
-            decoder_attention_mask=decoder_attention_mask,
-        )
+        batch = attribution_model.formatter.convert_args_to_batch(args)
        # Record activations at every model layer
        with record_residual_stream(attribution_model.model) as stream:
            outputs = attribution_model.get_forward_output(batch, use_embeddings=False)
@@ -118,11 +111,11 @@ The first step is to install the tuned lens library using ``pip install tuned-le
        probs = hidden_lps.map(lambda x: x.exp() * 100)
        probs = torch.stack(list(probs))
        top_idx_per_layer = probs.abs().topk(1, dim=-1).indices.squeeze(-1).reshape(-1, num_layers)
-        if target_ids.ndim == 0:
-            target_ids = target_ids.unsqueeze(0)
+        if args.target_ids.ndim == 0:
+            args.target_ids = args.target_ids.unsqueeze(0)
        # Set to max denominator to return 0 only if the target token is not predicted by the model
-        indices = torch.ones_like(target_ids) * (num_layers + 1)
-        for i, t in enumerate(target_ids):
+        indices = torch.ones_like(args.target_ids) * (num_layers + 1)
+        for i, t in enumerate(args.target_ids):
            pos = torch.where(top_idx_per_layer[i, :] == t.int())[0]
            if pos.numel() > 0:
                indices[i] = pos[0] + 1
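Once registered, the tuned-lens score above is requested like any other step score. This is a minimal sketch assuming a GPT-2 attribution model and a ``lens`` object built in the earlier steps of the tutorial; the ``prediction_depth`` identifier and input text are illustrative:

    import inseq

    model = inseq.load_model("gpt2", "saliency")

    # Register the function above as a step score
    inseq.register_step_function(fn=confidence_from_prediction_depth, identifier="prediction_depth")

    # Extra step-function arguments (here the lens) are forwarded through attribute()
    out = model.attribute(
        "Hello, my name is",
        step_scores=["prediction_depth"],
        lens=lens,  # assumed to be created earlier in the tutorial
    )
    out.show()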
16 changes: 14 additions & 2 deletions docs/source/main_classes/step_functions.rst
@@ -13,10 +13,20 @@
 Step Functions
 =======================================================================================================================
 
-The following functions can be used as attribution targets or step functions in the :meth:`inseq.models.AttributionModel.attribute` function call.
-
 .. currentmodule:: inseq.attr.step_functions
 
+Step Functions Default arguments
+-----------------------------------------------------------------------------------------------------------------------
+
+The default arguments passed to all step functions are collected in the :class:`StepFunctionArgs` class.
+
+.. autoclass:: StepFunctionArgs
+
+Pre-registered Step Functions
+-----------------------------------------------------------------------------------------------------------------------
+
+The following functions can be used out-of-the-box as attribution targets or step functions in the :meth:`inseq.models.AttributionModel.attribute` function call simply by passing their string identifier (function name minus the ``_fn`` suffix).
+
 .. autofunction:: logit_fn
 
 .. autofunction:: probability_fn
@@ -36,3 +46,5 @@ The following functions can be used as attribution targets or step functions in
 .. autofunction:: contrast_prob_diff_fn
 
 .. autofunction:: mc_dropout_prob_avg_fn
+
+.. autofunction:: top_p_size_fn
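With the new entry, pre-registered scores are requested by their string identifier. A minimal sketch, with model name and input text chosen for illustration:

    import inseq

    model = inseq.load_model("gpt2", "integrated_gradients")

    # "probability" resolves to probability_fn, "top_p_size" to the newly added top_p_size_fn
    out = model.attribute(
        "The quick brown fox jumps over the lazy dog",
        step_scores=["probability", "top_p_size"],
    )
    out.show()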
4 changes: 2 additions & 2 deletions inseq/attr/__init__.py
@@ -1,7 +1,7 @@
 from .feat import FeatureAttribution, extract_args, list_feature_attribution_methods
 from .step_functions import (
     STEP_SCORES_MAP,
-    get_step_function_reserved_args,
+    StepFunctionArgs,
     list_step_functions,
     register_step_function,
 )
@@ -13,5 +13,5 @@
     "register_step_function",
     "STEP_SCORES_MAP",
     "extract_args",
-    "get_step_function_reserved_args",
+    "StepFunctionArgs",
 ]
11 changes: 5 additions & 6 deletions inseq/attr/feat/feature_attribution.py
@@ -44,7 +44,7 @@
 )
 from ...utils.typing import ModelIdentifier, SingleScorePerStepTensor
 from ..attribution_decorators import batched, set_hook, unset_hook
-from ..step_functions import get_step_scores
+from ..step_functions import get_step_scores, get_step_scores_args
 from .attribution_utils import (
     check_attribute_positions,
     get_source_target_attributions,
@@ -518,16 +518,15 @@ def filtered_attribute_step(
            attribution_args,
        )
        # Format step scores arguments and calculate step scores
-        if len(step_scores) > 0:
-            step_scores_args = self.attribution_model.formatter.format_step_function_args(
+        for step_score in step_scores:
+            step_fn_args = self.attribution_model.formatter.format_step_function_args(
                attribution_model=self.attribution_model,
                forward_output=output,
                target_ids=target_ids,
                batch=batch,
-                **step_scores_args,
            )
-            for step_score in step_scores:
-                step_output.step_scores[step_score] = get_step_scores(step_score, step_scores_args)
+            step_fn_extra_args = get_step_scores_args([step_score], step_scores_args)
+            step_output.step_scores[step_score] = get_step_scores(step_score, step_fn_args, step_fn_extra_args)
        # Reinsert finished sentences
        if target_attention_mask is not None and is_filtered:
            step_output.remap_from_filtered(target_attention_mask, orig_batch)
