Skip to content

Commit

Permalink
fix merge_attributions (#undefined)
Browse files Browse the repository at this point in the history
Co-authored-by: Gabriele Sarti <gabriele.sarti996@gmail.com>
  • Loading branch information
DanielSc4 and gsarti committed Aug 2, 2023
1 parent 5a217b0 commit d4663af
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 46 deletions.
2 changes: 2 additions & 0 deletions docs/source/main_classes/main_functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,5 @@ functionalities required for its usage.
.. autofunction:: list_aggregation_functions

.. autofunction:: show_attributions

.. autofunction:: merge_attributions
2 changes: 2 additions & 0 deletions inseq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
FeatureAttributionOutput,
list_aggregation_functions,
list_aggregators,
merge_attributions,
show_attributions,
)
from .models import AttributionModel, list_supported_frameworks, load_model
Expand Down Expand Up @@ -33,4 +34,5 @@ def get_version() -> str:
"list_step_functions",
"list_supported_frameworks",
"register_step_function",
"merge_attributions",
]
2 changes: 2 additions & 0 deletions inseq/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
GranularFeatureAttributionStepOutput,
MultiDimensionalFeatureAttributionStepOutput,
get_batch_from_inputs,
merge_attributions,
)
from .batch import (
Batch,
Expand Down Expand Up @@ -60,6 +61,7 @@
"list_aggregation_functions",
"MultiDimensionalFeatureAttributionStepOutput",
"get_batch_from_inputs",
"merge_attributions",
"list_aggregators",
"slice_batch_from_position",
]
89 changes: 44 additions & 45 deletions inseq/data/attribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,49 @@ def get_batch_from_inputs(
return batch


def merge_attributions(attributions: List["FeatureAttributionOutput"]) -> "FeatureAttributionOutput":
"""Merges multiple :class:`~inseq.data.FeatureAttributionOutput` objects into a single one.
Merging is allowed only if the two outputs match on the fields specified in ``_merge_match_info_fields``.
Args:
attributions (:obj:`list` of :class:`~inseq.data.FeatureAttributionOutput`): The FeatureAttributionOutput
objects to be merged.
Returns:
:class:`~inseq.data.FeatureAttributionOutput`: Merged object.
"""
assert all(
isinstance(x, FeatureAttributionOutput) for x in attributions
), "Only FeatureAttributionOutput objects can be merged."
first = attributions[0]
for match_field in FeatureAttributionOutput._merge_match_info_fields:
assert all(
(
attr.info[match_field] == first.info[match_field]
if match_field in first.info
else match_field not in attr.info
)
for attr in attributions
), f"Cannot merge: incompatible values for field {match_field}"
out_info = first.info.copy()
if "attr_pos_end" in first.info:
out_info.update({"attr_pos_end": max(attr.info["attr_pos_end"] for attr in attributions)})
if "generated_texts" in first.info:
out_info.update({"generated_texts": [text for attr in attributions for text in attr.info["generated_texts"]]})
if "input_texts" in first.info:
out_info.update({"input_texts": [text for attr in attributions for text in attr.info["input_texts"]]})
return FeatureAttributionOutput(
sequence_attributions=[seqattr for attr in attributions for seqattr in attr.sequence_attributions],
step_attributions=(
[stepattr for attr in attributions for stepattr in attr.step_attributions]
if first.step_attributions is not None
else None
),
info=out_info,
)


@dataclass(eq=False, repr=False)
class FeatureAttributionSequenceOutput(TensorWrapper, AggregableMixin):
"""Output produced by a standard attribution method.
Expand Down Expand Up @@ -467,7 +510,7 @@ def __iter__(self):
return iter(self.sequence_attributions)

def __add__(self, other) -> "FeatureAttributionOutput":
return self.merge_attributions([self, other])
return merge_attributions([self, other])

def __radd__(self, other) -> "FeatureAttributionOutput":
return self.__add__(other)
Expand Down Expand Up @@ -610,50 +653,6 @@ def show(
if return_html:
return out_str

@classmethod
def merge_attributions(cls, attributions: List["FeatureAttributionOutput"]) -> "FeatureAttributionOutput":
"""Merges multiple :class:`~inseq.data.FeatureAttributionOutput` objects into a single one.
Merging is allowed only if the two outputs match on the fields specified in ``_merge_match_info_fields``.
Args:
attributions (`list(FeatureAttributionOutput)`): The FeatureAttributionOutput objects to be merged.
Returns:
`FeatureAttributionOutput`: Merged object
"""
assert all(
isinstance(x, FeatureAttributionOutput) for x in attributions
), "Only FeatureAttributionOutput objects can be merged."
first = attributions[0]
for match_field in cls._merge_match_info_fields:
assert all(
(
attr.info[match_field] == first.info[match_field]
if match_field in first.info
else match_field not in attr.info
)
for attr in attributions
), f"Cannot merge: incompatible values for field {match_field}"
out_info = first.info.copy()
if "attr_pos_end" in first.info:
out_info.update({"attr_pos_end": max(attr.info["attr_pos_end"] for attr in attributions)})
if "generated_texts" in first.info:
out_info.update(
{"generated_texts": [text for attr in attributions for text in attr.info["generated_texts"]]}
)
if "input_texts" in first.info:
out_info.update({"input_texts": [text for attr in attributions for text in attr.info["input_texts"]]})
return cls(
sequence_attributions=[seqattr for attr in attributions for seqattr in attr.sequence_attributions],
step_attributions=(
[stepattr for attr in attributions for stepattr in attr.step_attributions]
if first.step_attributions is not None
else None
),
info=out_info,
)

def weight_attributions(self, step_score_id: str):
for i, attr in enumerate(self.sequence_attributions):
self.sequence_attributions[i] = attr.weight_attributions(step_score_id)
Expand Down
3 changes: 2 additions & 1 deletion inseq/models/attribution_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
FeatureAttributionInput,
FeatureAttributionOutput,
FeatureAttributionStepOutput,
merge_attributions,
)
from ..utils import (
MissingAttributionMethodError,
Expand Down Expand Up @@ -451,7 +452,7 @@ def attribute(
attributed_fn_args=attributed_fn_args,
step_scores_args=step_scores_args,
)
attribution_output = FeatureAttributionOutput.merge_attributions(attribution_outputs)
attribution_output = merge_attributions(attribution_outputs)
attribution_output.info["input_texts"] = input_texts
attribution_output.info["generated_texts"] = (
[generated_texts] if isinstance(generated_texts, str) else generated_texts
Expand Down

0 comments on commit d4663af

Please sign in to comment.