Skip to content

Commit

Permalink
Fix ContiguousSpanAggregation for prefixed generations (#213)
Browse files Browse the repository at this point in the history
  • Loading branch information
gsarti committed Aug 3, 2023
1 parent d4663af commit 33aa13b
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 6 deletions.
2 changes: 1 addition & 1 deletion inseq/data/aggregation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class AbsMaxAggregationFunction(AggregationFunction):
aggregation_function_name = "absmax"

def __call__(self, scores: torch.Tensor, dim: int) -> ScoreTensor:
return scores.gather(dim, scores.abs().argmax(dim, keepdim=True)).squeeze(dim)
return scores.gather(dim, torch.nan_to_num(scores).abs().argmax(dim, keepdim=True)).squeeze(dim)


class VectorNormAggregationFunction(AggregationFunction):
Expand Down
56 changes: 54 additions & 2 deletions inseq/data/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,7 @@ def validate_spans(cls, span_sequence: "FeatureAttributionSequenceOutput", spans
spans = cls.format_spans(spans)
prev_span_max = -1
for span in spans:
assert len(span) == 2, f"Spans must contain at least two indexes, got {spans}"
assert len(span) == 2, f"Spans must contain two indexes, got {spans}"
assert span[1] > span[0] + 1, f"Spans must be non-empty, got {spans}"
assert (
span[0] >= prev_span_max
Expand All @@ -578,6 +578,14 @@ def _aggregate_sequential_scores(scores, x_spans, y_spans, aggregate_fn):
scores_aggregated_x = aggregate_contiguous(scores_aggregated_y, x_spans, aggregate_fn, aggregate_dim=0)
return scores_aggregated_x

@staticmethod
def _relativize_target_spans(spans: List[Tuple[int, int]], start: int):
    """Re-express target spans relative to the position where attribution starts.

    Args:
        spans: ``(begin, end)`` index pairs over the full target sequence.
        start: Index of the first attributed target position (0 when there is no prefix).

    Returns:
        ``spans`` itself when ``start`` is 0 or ``spans`` is empty. Otherwise a new list in which
        spans ending at or before ``start`` are dropped and the remaining ones are shifted left by
        ``start``, with the begin index clamped to 0 for spans that begin before ``start``.
    """
    # Nothing to rescale without a prefix or without spans.
    if start == 0 or not spans:
        return spans
    relative_spans = []
    for span in spans:
        # Keep only spans that reach past the unattributed prefix.
        if span[1] > start:
            # A span straddling the boundary is clamped to begin at the first generated position.
            rel_begin = span[0] - start if span[0] > start else 0
            relative_spans.append((rel_begin, span[1] - start))
    return relative_spans

@staticmethod
def aggregate_source(attr, source_spans, **kwargs):
    """Aggregate the source tokens of ``attr`` over ``source_spans``.

    Delegates to ``aggregate_token_sequence``; extra keyword arguments are accepted and ignored.
    """
    merged_source = aggregate_token_sequence(attr.source, source_spans)
    return merged_source
Expand All @@ -590,6 +598,8 @@ def aggregate_target(attr, target_spans, **kwargs):
def aggregate_source_attributions(attr, source_spans, target_spans, aggregate_fn, **kwargs):
if attr.source_attributions is None:
return attr.source_attributions
# Handle the case in which generation starts from a prefix
target_spans = ContiguousSpanAggregator._relativize_target_spans(target_spans, attr.attr_pos_start)
# First aggregate along generated target sequence, then along attributed source
return ContiguousSpanAggregator._aggregate_sequential_scores(
attr.source_attributions, source_spans, target_spans, aggregate_fn
Expand All @@ -599,16 +609,20 @@ def aggregate_source_attributions(attr, source_spans, target_spans, aggregate_fn
def aggregate_target_attributions(attr, target_spans, aggregate_fn, **kwargs):
if attr.target_attributions is None:
return attr.target_attributions
# Handle the case in which generation starts from a prefix
gen_spans = ContiguousSpanAggregator._relativize_target_spans(target_spans, attr.attr_pos_start)
# First aggregate along generated target sequence, then along attributed prefix
return ContiguousSpanAggregator._aggregate_sequential_scores(
attr.target_attributions, target_spans, target_spans, aggregate_fn
attr.target_attributions, target_spans, gen_spans, aggregate_fn
)

@staticmethod
def aggregate_step_scores(attr, target_spans, aggregate_fn, **kwargs):
if not attr.step_scores:
return attr.step_scores
out_dict = {}
# Handle the case in which generation starts from a prefix
target_spans = ContiguousSpanAggregator._relativize_target_spans(target_spans, attr.attr_pos_start)
for name, step_scores in attr.step_scores.items():
agg_fn = aggregate_fn[name] if isinstance(aggregate_fn, dict) else aggregate_fn
out_dict[name] = aggregate_contiguous(step_scores, target_spans, agg_fn, aggregate_dim=0)
Expand All @@ -620,6 +634,8 @@ def aggregate_sequence_scores(attr, source_spans, target_spans, aggregate_fn, **
if not attr.sequence_scores:
return attr.sequence_scores
out_dict = {}
# Handle the case in which generation starts from a prefix
target_spans = ContiguousSpanAggregator._relativize_target_spans(target_spans, attr.attr_pos_start)
for name, step_scores in attr.sequence_scores.items():
aggregate_fn = aggregate_fn[name] if isinstance(aggregate_fn, dict) else aggregate_fn
if name.startswith("decoder"):
Expand All @@ -636,6 +652,34 @@ def aggregate_sequence_scores(attr, source_spans, target_spans, aggregate_fn, **
)
return out_dict

@staticmethod
def aggregate_attr_pos_start(attr, target_spans, **kwargs):
    """Compute the attribution start position after the target spans have been merged.

    Args:
        attr: Object exposing the original ``attr_pos_start`` index.
        target_spans: ``(begin, end)`` index pairs over the full target sequence.
        **kwargs: Accepted for interface compatibility; unused.

    Returns:
        The start position shifted left by the number of tokens removed through merging before it.
        ``attr.attr_pos_start`` is returned unchanged when ``target_spans`` is empty.

    Raises:
        RuntimeError: If more than one span straddles the original start position.
    """
    original_start = attr.attr_pos_start
    if not target_spans:
        return original_start

    merged_before_start = 0
    straddling = []
    for span in target_spans:
        if span[1] <= original_start:
            # Span lies fully inside the prefix: its tokens collapse into a single one.
            merged_before_start += span[1] - span[0] - 1
        elif span[0] < original_start:
            # Span begins in the prefix and ends inside the generation.
            straddling.append(span)

    if len(straddling) > 1:
        raise RuntimeError(f"Multiple overlapping spans detected for the starting position {attr.attr_pos_start}.")

    shifted_start = original_start - merged_before_start
    if straddling:
        # Tokens before and after the starting position are merged together. The merged token is kept
        # whole among the attributed targets, while the merged scores only reflect the portion that was
        # actually attributed. E.g. if "Hello world" is the prefix, ", how are you?" is the generation
        # and "world," is formed during merging, "world," is included in the attributed targets but
        # only the scores of "," are used (i.e. no aggregation, since it is a single token).
        shifted_start -= original_start - straddling[0][0]
    return shifted_start

@staticmethod
def aggregate_attr_pos_end(attr, target_spans, **kwargs):
    """Compute the attribution end position after the target spans have been merged.

    Args:
        attr: Object exposing the original ``attr_pos_start`` and ``attr_pos_end`` indices.
        target_spans: ``(begin, end)`` index pairs over the full target sequence.
        **kwargs: Forwarded to ``aggregate_attr_pos_start``.

    Returns:
        The merged start position plus the length of the attributed region once the tokens merged
        inside it are discounted. ``attr.attr_pos_end`` is returned unchanged when ``target_spans``
        is empty.
    """
    if not target_spans:
        return attr.attr_pos_end
    merged_start = ContiguousSpanAggregator.aggregate_attr_pos_start(attr, target_spans, **kwargs)
    # Only spans that reach into the attributed region shorten it.
    generation_spans = ContiguousSpanAggregator._relativize_target_spans(target_spans, attr.attr_pos_start)
    merged_in_generation = 0
    for span in generation_spans:
        merged_in_generation += span[1] - span[0] - 1
    attributed_length = attr.attr_pos_end - attr.attr_pos_start
    return merged_start + (attributed_length - merged_in_generation)


class SubwordAggregator(ContiguousSpanAggregator):
"""Aggregates over subwords by automatic detecting contiguous subword spans.
Expand Down Expand Up @@ -679,6 +723,14 @@ def aggregate(
def get_spans(tokens: List[TokenWithId], special_symbol: str, is_suffix_symbol: bool):
spans = []
last_prefix_idx = 0
has_special_symbol = any(sym in token.token for token in tokens for sym in special_symbol)
if not has_special_symbol:
logger.warning(
f"ATTENTION: The {special_symbol} symbol is currently used for subword aggregation, but no instances "
"have been detected in the sequence. Change the special symbols using e.g. special_symbol=('Ġ', 'Ċ')"
", and set is_suffix_symbol=True if they are used as suffix word separators (e.g. Hello</w> world</w>)"
)
return spans
for curr_idx, token in enumerate(tokens):
# Suffix if token start with special suffix symbol, or if it doesn't have the special prefix symbol.
is_suffix = token.token.startswith(special_symbol) == is_suffix_symbol
Expand Down
3 changes: 1 addition & 2 deletions inseq/data/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,7 @@ def get_heatmap_type(
)
elif heatmap_type == "Target":
if attribution.target_attributions is not None:
mask = np.ones_like(attribution.target_attributions.numpy()) * float("nan")
mask = np.tril(mask, k=-attribution.attr_pos_start)
mask = np.where(attribution.target_attributions.numpy() == 0, float("nan"), 0)
target_attributions = attribution.target_attributions.numpy() + mask
else:
target_attributions = None
Expand Down
15 changes: 14 additions & 1 deletion tests/data/test_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
SequenceAttributionAggregator,
SubwordAggregator,
)
from inseq.models import HuggingfaceEncoderDecoderModel
from inseq.models import HuggingfaceDecoderOnlyModel, HuggingfaceEncoderDecoderModel

EXAMPLES_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../fixtures/aggregator.json")
EXAMPLES = json.load(open(EXAMPLES_FILE))
Expand All @@ -23,6 +23,11 @@ def saliency_mt_model() -> HuggingfaceEncoderDecoderModel:
return inseq.load_model("Helsinki-NLP/opus-mt-en-it", "saliency", device="cpu")


@fixture(scope="session")
def saliency_gpt_model() -> HuggingfaceDecoderOnlyModel:
    """Session-scoped fixture: the ``gpt2`` model loaded on CPU with the ``saliency`` method."""
    model = inseq.load_model("gpt2", "saliency", device="cpu")
    return model


def test_sequence_attribution_aggregator(saliency_mt_model: HuggingfaceEncoderDecoderModel):
out = saliency_mt_model.attribute(
"This is a test.",
Expand Down Expand Up @@ -56,6 +61,14 @@ def test_continuous_span_aggregator(saliency_mt_model: HuggingfaceEncoderDecoder
assert out_agg.step_scores["probability"].shape == (4,)


def test_span_aggregator_with_prefix(saliency_gpt_model: HuggingfaceDecoderOnlyModel):
    """Check subword aggregation when the attributed generation starts after a forced prefix."""
    out = saliency_gpt_model.attribute("Hello, world! I am,:.", "Hello, world! I am,:.!,. Last")
    # "Ġ" (U+0120) and "Ċ" (U+010A) are the byte-level BPE markers GPT-2 uses for a leading space and
    # a newline; they were garbled into CJK characters by a wrong decoding of their UTF-8 bytes.
    aggregated = out.aggregate("subwords", special_symbol=("Ġ", "Ċ")).aggregate()
    assert aggregated[0].target_attributions.shape == (5, 2)
    assert aggregated[0].attr_pos_start == 3
    assert aggregated[0].attr_pos_end == 5


def test_aggregator_pipeline(saliency_mt_model: HuggingfaceEncoderDecoderModel):
out = saliency_mt_model.attribute(
"This is a test.", attribute_target=True, step_scores=["probability"], device="cpu", show_progress=False
Expand Down

0 comments on commit 33aa13b

Please sign in to comment.