Target prefix-constrained generation (#172)
gsarti committed Apr 19, 2023
1 parent 9686ab3 commit a4a43e2
Showing 3 changed files with 43 additions and 18 deletions.
8 changes: 4 additions & 4 deletions inseq/attr/feat/feature_attribution.py
@@ -19,7 +19,7 @@
import logging
from abc import abstractmethod
from datetime import datetime
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

from torchtyping import TensorType
from transformers import set_seed
@@ -150,7 +150,7 @@ def load(
@batched
def prepare_and_attribute(
self,
sources: Sequence[str],
sources: FeatureAttributionInput,
targets: FeatureAttributionInput,
attr_pos_start: Optional[int] = None,
attr_pos_end: Optional[int] = None,
@@ -172,9 +172,9 @@ def prepare_and_attribute(
and the :meth:`~inseq.models.AttributionModel.prepare_inputs_for_attribution` method.
Args:
sources (:obj:`list(str)`): The sources provided to the
sources (:obj:`FeatureAttributionInput`): The sources provided to the
:meth:`~inseq.attr.feat.FeatureAttribution.prepare` method.
targets (:obj:`FeatureAttributionInput): The targets provided to the
targets (:obj:`FeatureAttributionInput`): The targets provided to the
:meth:`~inseq.attr.feat.FeatureAttribution.prepare` method.
attr_pos_start (:obj:`int`, `optional`): The initial position for performing
sequence attribution. Defaults to 0.
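The widened annotation brings ``sources`` in line with ``targets``: both are now typed as ``FeatureAttributionInput``. A hedged illustration of the call shapes this permits, assuming (as the ``targets`` parameter already suggests) that the type covers both a single string and a list of strings; ``attribution`` is a placeholder for an already-loaded attribution method:

```python
# Illustrative call shapes only; `attribution` stands in for a
# FeatureAttribution instance obtained from a loaded inseq model.
attribution.prepare_and_attribute("Hello world", targets="Bonjour le monde")
attribution.prepare_and_attribute(
    ["Hello world", "Good morning"],
    targets=["Bonjour le monde", "Bonjour"],
)
```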
12 changes: 12 additions & 0 deletions inseq/commands/attribute.py
@@ -26,6 +26,17 @@ class AttributeBaseArgs:
"help": "Performs the attribution procedure including the generated prefix at every step.",
},
)
generate_from_target_prefix: bool = field(
default=False,
metadata={
"help": (
"Whether the ``generated_texts`` should be used as"
"target prefixes for the generation process. If False, the ``generated_texts`` will be used as full"
"targets. This option is only available for encoder-decoder models, since the same behavior can be"
"achieved by modifying the input texts for decoder-only models. Default: False."
)
},
)
step_scores: List[str] = field(
default_factory=list, metadata={"help": "Adds step scores to the attribution output."}
)
@@ -165,6 +176,7 @@ def attribute(input_texts, generated_texts, args: AttributeBaseArgs):
generation_args={"max_new_tokens": args.max_gen_length},
attr_pos_start=args.start_pos,
attr_pos_end=args.end_pos,
generate_from_target_prefix=args.generate_from_target_prefix,
)
if args.viz_path:
print(f"Saving visualization to {args.viz_path}")
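For context, the command above ultimately funnels these flags into ``model.attribute``. A hedged sketch of the equivalent library call, with an illustrative model and attribution method; the keyword arguments mirror the CLI fields in this diff (``max_gen_length``, ``start_pos``, ``end_pos``, and the new ``generate_from_target_prefix``):

```python
import inseq

# Model name and attribution method are illustrative, not part of the commit.
model = inseq.load_model("Helsinki-NLP/opus-mt-en-fr", "saliency")
out = model.attribute(
    input_texts=["The cat is on the mat."],
    generated_texts=["Le chat"],             # treated as a target prefix
    generation_args={"max_new_tokens": 20},  # mirrors --max_gen_length
    generate_from_target_prefix=True,        # the flag added here
)
```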
41 changes: 27 additions & 14 deletions inseq/models/attribution_model.py
@@ -166,6 +166,8 @@ def attribute(
attributed_fn: Union[str, Callable[..., SingleScorePerStepTensor], None] = None,
device: Optional[str] = None,
batch_size: Optional[int] = None,
generate_from_target_prefix: bool = False,
generation_args: Optional[Dict[str, Any]] = None,
**kwargs,
) -> FeatureAttributionOutput:
"""Perform sequential attribution of input texts for every token in generated texts using the specified method.
@@ -210,6 +212,10 @@
device will be used.
batch_size (:obj:`int`, `optional`): The batch size to use to dilute the attribution computation over the
set of inputs. If no batch size is provided, the full set of input texts will be attributed at once.
generate_from_target_prefix (:obj:`bool`, `optional`): Whether the ``generated_texts`` should be used as
target prefixes for the generation process. If False, the ``generated_texts`` will be used as full
targets. This option is only available for encoder-decoder models, since the same behavior can be
achieved by modifying the input texts for decoder-only models. Default: False.
**kwargs: Additional keyword arguments. These can include keyword arguments for the attribution method, for
the generation process or for the attributed function. Generation arguments can be provided explicitly
as a dictionary named ``generation_args``.
@@ -223,26 +229,32 @@
raise ValueError("At least one text must be provided to perform attribution.")
if attribute_target and not self.is_encoder_decoder:
logger.warning("attribute_target parameter is set to True, but will be ignored (not an encoder-decoder).")
attribute_target = False
if generate_from_target_prefix and not self.is_encoder_decoder:
logger.warning(
"generate_from_target_prefix parameter is set to True, but will be ignored (not an encoder-decoder)."
)
generate_from_target_prefix = False
generation_args = generation_args if generation_args is not None else {}
original_device = self.device
if device is not None:
self.device = device
input_texts, generated_texts = format_input_texts(input_texts, generated_texts)
if batch_size is not None:
n_batches = len(input_texts) // batch_size + ((len(input_texts) % batch_size) > 0)
logger.info(f"Splitting input texts into {n_batches} batches of size {batch_size}.")
constrained_decoding = generated_texts is not None
orig_input_texts = input_texts
# If constrained decoding is not enabled, we need to generate the
# generated texts from the input texts.
generation_args = kwargs.pop("generation_args", {})
if constrained_decoding and generation_args:
logger.warning(
f"Generation arguments {generation_args} are provided, but constrained decoding is enabled. "
"Generation arguments will be ignored."
)
if not constrained_decoding:
has_generated_texts = generated_texts is not None
# If constrained decoding is not enabled, output texts are generated from input texts.
if not has_generated_texts or generate_from_target_prefix:
encoded_input = self.encode(input_texts, return_baseline=True, include_eos_baseline=include_eos_baseline)
if generate_from_target_prefix:
decoder_input = self.encode(generated_texts, as_targets=True)
generation_args["decoder_input_ids"] = decoder_input.input_ids
generated_texts = self.generate(encoded_input, return_generation_output=False, **generation_args)
else:
if generation_args:
logger.warning(
f"Generation arguments {generation_args} are provided, but will be ignored (constrained decoding)."
)
logger.debug(f"reference_texts={generated_texts}")
attribution_method = self.get_attribution_method(method, override_default_attribution)
attributed_fn = self.get_attributed_fn(attributed_fn)
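The block above is the core of the feature: when ``generate_from_target_prefix`` is set, the target prefix is encoded with ``as_targets=True`` and passed to ``generate`` as ``decoder_input_ids``, so generation resumes right after the prefix. A minimal sketch of that mechanism using raw ``transformers`` (model name and texts are illustrative; inseq's ``encode`` presumably handles the equivalent bookkeeping internally):

```python
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

name = "Helsinki-NLP/opus-mt-en-fr"  # illustrative encoder-decoder model
tok = AutoTokenizer.from_pretrained(name)
model = AutoModelForSeq2SeqLM.from_pretrained(name)

enc = tok("The cat is on the mat.", return_tensors="pt")
# Tokenize the target-side prefix, drop its trailing EOS, and prepend the
# decoder start token so generate() continues right after the prefix.
prefix_ids = tok(text_target="Le chat", return_tensors="pt").input_ids[:, :-1]
start = torch.full((1, 1), model.config.decoder_start_token_id, dtype=torch.long)
decoder_input_ids = torch.cat([start, prefix_ids], dim=1)

out = model.generate(**enc, decoder_input_ids=decoder_input_ids, max_new_tokens=20)
print(tok.batch_decode(out, skip_special_tokens=True))
```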
@@ -256,7 +268,7 @@
assert all(
generated_texts[idx].startswith(input_texts[idx]) for idx in range(len(input_texts))
), "Forced generations with decoder-only models must start with the input texts."
if constrained_decoding and len(input_texts) > 1:
if has_generated_texts and len(input_texts) > 1:
logger.info(
"Batched constrained decoding is currently not supported for decoder-only models."
" Using batch size of 1."
@@ -289,12 +301,13 @@
step_scores_args=step_scores_args,
)
attribution_output = FeatureAttributionOutput.merge_attributions(attribution_outputs)
attribution_output.info["input_texts"] = orig_input_texts
attribution_output.info["input_texts"] = input_texts
attribution_output.info["generated_texts"] = (
[generated_texts] if isinstance(generated_texts, str) else generated_texts
)
attribution_output.info["generation_args"] = generation_args
attribution_output.info["constrained_decoding"] = constrained_decoding
attribution_output.info["constrained_decoding"] = has_generated_texts
attribution_output.info["generate_from_target_prefix"] = generate_from_target_prefix
if device and original_device:
self.device = original_device
return attribution_output
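Taken together, the option switches how ``generated_texts`` is interpreted: as a full forced target (the default) or as a decoder-side prefix that seeds free generation. A hedged end-to-end sketch, with illustrative model, method, and texts:

```python
import inseq

model = inseq.load_model("Helsinki-NLP/opus-mt-en-fr", "saliency")

# Default: generated_texts is the complete, forced target sequence.
out_forced = model.attribute(
    input_texts="The cat is on the mat.",
    generated_texts="Le chat est sur le tapis.",
)

# New: generated_texts seeds generation as a target prefix; the model
# free-generates the continuation, which is then attributed.
out_prefix = model.attribute(
    input_texts="The cat is on the mat.",
    generated_texts="Le chat",
    generate_from_target_prefix=True,
)
print(out_prefix.info["generated_texts"])  # prefix plus generated continuation
```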