Fix aggregate_contiguous (#247)
* fix aggregate_contiguous

* special_symbol -> special_chars
gsarti authored Jan 12, 2024
1 parent dfea66f commit 7503576
Showing 3 changed files with 37 additions and 26 deletions.
26 changes: 13 additions & 13 deletions inseq/data/aggregator.py
@@ -691,10 +691,10 @@ class SubwordAggregator(ContiguousSpanAggregator):
             preserved (e.g. [0.3, -0.7, 0.1] -> -0.7).
         aggregate_source (bool, optional): Whether to aggregate over the source sequence. Defaults to True.
         aggregate_target (bool, optional): Whether to aggregate over the target sequence. Defaults to True.
-        special_symbol (str, optional): Symbol used to identify subwords. Defaults to '▁', used by SentencePiece.
-            If is_suffix_symbol=True, then this symbol is used to identify parts to be aggregated (e.g. # in WordPiece,
-            ['phen', '##omen', '##al']). Otherwise, it identifies the roots that should be preserved (e.g. ▁ in
-            SentencePiece, ['▁phen', 'omen', 'al']).
+        special_chars (str or tuple of str, optional): One or more characters used to identify subword boundaries.
+            Defaults to '▁', used by SentencePiece. If is_suffix_symbol=True, then this symbol is used to identify
+            parts to be aggregated (e.g. # in WordPiece, ['phen', '##omen', '##al']). Otherwise, it identifies the
+            roots that should be preserved (e.g. ▁ in SentencePiece, ['▁phen', 'omen', 'al']).
         is_suffix_symbol (bool, optional): Whether the special symbol is used to identify suffixes or prefixes.
             Defaults to False.
     """
@@ -707,33 +707,33 @@ def aggregate(
         attr: "FeatureAttributionSequenceOutput",
         aggregate_source: bool = True,
         aggregate_target: bool = True,
-        special_symbol: str = "▁",
+        special_chars: Union[str, Tuple[str, ...]] = "▁",
         is_suffix_symbol: bool = False,
         **kwargs,
     ):
         source_spans = []
         target_spans = []
         if aggregate_source:
-            source_spans = cls.get_spans(attr.source, special_symbol, is_suffix_symbol)
+            source_spans = cls.get_spans(attr.source, special_chars, is_suffix_symbol)
         if aggregate_target:
-            target_spans = cls.get_spans(attr.target, special_symbol, is_suffix_symbol)
+            target_spans = cls.get_spans(attr.target, special_chars, is_suffix_symbol)
         return super().aggregate(attr, source_spans=source_spans, target_spans=target_spans, **kwargs)

     @staticmethod
-    def get_spans(tokens: List[TokenWithId], special_symbol: str, is_suffix_symbol: bool):
+    def get_spans(tokens: List[TokenWithId], special_chars: Union[str, Tuple[str, ...]], is_suffix_symbol: bool):
         spans = []
         last_prefix_idx = 0
-        has_special_symbol = any(sym in token.token for token in tokens for sym in special_symbol)
-        if not has_special_symbol:
+        has_special_chars = any(sym in token.token for token in tokens for sym in special_chars)
+        if not has_special_chars:
             logger.warning(
-                f"ATTENTION: The {special_symbol} symbol is currently used for subword aggregation, but no instances "
-                "have been detected in the sequence. Change the special symbols using e.g. special_symbol=('Ġ', 'Ċ')"
+                f"The {special_chars} character is currently used for subword aggregation, but no instances "
+                "have been detected in the sequence. Change the special symbols using e.g. special_chars=('Ġ', 'Ċ')"
                 ", and set is_suffix_symbol=True if they are used as suffix word separators (e.g. Hello</w> world</w>)"
             )
             return spans
         for curr_idx, token in enumerate(tokens):
             # Suffix if token start with special suffix symbol, or if it doesn't have the special prefix symbol.
-            is_suffix = token.token.startswith(special_symbol) == is_suffix_symbol
+            is_suffix = token.token.startswith(special_chars) == is_suffix_symbol
             if is_suffix:
                 if curr_idx == len(tokens) - 1 and curr_idx - last_prefix_idx > 1:
                     spans.append((last_prefix_idx, curr_idx))
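
For reference, a minimal usage sketch of the renamed argument; the model name and input strings are illustrative, and the aggregator call mirrors the updated test in tests/data/test_aggregator.py:

```python
import inseq

# Illustrative model and inputs; any tokenizer using 'Ġ'/'Ċ' word markers applies.
model = inseq.load_model("gpt2", "saliency")
out = model.attribute("Hello, world! I am,:.", "Hello, world! I am,:.!,. Last")

# GPT-2's BPE marks word starts with 'Ġ' and newlines with 'Ċ', so both are passed
# as special_chars (this keyword replaces the former special_symbol argument).
aggregated = out.aggregate("subwords", special_chars=("Ġ", "Ċ")).aggregate()
```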
35 changes: 23 additions & 12 deletions inseq/utils/torch_utils.py
@@ -143,26 +143,37 @@ def aggregate_contiguous(
     t: torch.Tensor,
     spans: Sequence[Tuple[int, int]],
     aggregate_fn: Optional[Callable] = None,
-    aggregate_dim: int = 1,
+    aggregate_dim: int = 0,
 ):
+    """Given a tensor, aggregate contiguous spans of the tensor along a given dimension using the provided
+    aggregation function. If no aggregation function is provided, the mean is used.
+
+    Args:
+        t: Tensor to aggregate
+        spans: Sequence of (start, end) tuples indicating contiguous spans to aggregate
+        aggregate_fn: Aggregation function to use. If None, torch.mean is used.
+        aggregate_dim: Dimension to aggregate along. Default is 0.
+    """
     if not spans:
         return t
     if aggregate_fn is None:
         aggregate_fn = torch.mean
-    while t.ndim < 2:
-        t = t.unsqueeze(-1)
-    t = t.transpose(aggregate_dim, 1)
+    if aggregate_dim > t.ndim:
+        raise ValueError(f"aggregate_dim {aggregate_dim} is greater than tensor dimension {t.ndim}")
+    if aggregate_dim != 0:
+        t = t.transpose(aggregate_dim, 0)
     slices = []
     base_val = 0
     for start, end in spans:
         if start > base_val:
-            slices.append(t[:, base_val:start, ...])
-        slices.append(aggregate_fn(t[:, start:end, ...], dim=1).unsqueeze(1))
+            slices.append(t[base_val:start, ...])
+        slices.append(aggregate_fn(t[start:end, ...], dim=0).unsqueeze(0))
         base_val = end
-    slices.append(t[:, base_val:])
-    out_cat = torch.cat(slices, dim=1).transpose(1, aggregate_dim)
-    if 1 in out_cat.shape:
-        out_cat = out_cat.transpose(1, 0).squeeze(0)
+    if base_val < t.shape[0]:
+        slices.append(t[base_val:, ...])
+    out_cat = torch.cat(slices, dim=0)
+    if aggregate_dim != 0:
+        out_cat = out_cat.transpose(aggregate_dim, 0)
     return out_cat
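
A minimal sketch of the fixed behavior, assuming aggregate_contiguous can be imported from inseq.utils.torch_utils; the tensor and spans are arbitrary illustrative values:

```python
import torch

from inseq.utils.torch_utils import aggregate_contiguous

# 6x2 tensor whose rows 1-2 and 3-4 are averaged along dim 0 (the new default
# aggregate_dim); rows 0 and 5 are passed through unchanged.
t = torch.arange(12.0).reshape(6, 2)
spans = [(1, 3), (3, 5)]
out = aggregate_contiguous(t, spans)
print(out.shape)  # expected: torch.Size([4, 2])
```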


@@ -174,8 +185,8 @@ def get_front_padding(t: torch.Tensor, pad: int = 0, dim: int = 1) -> List[int]:


 def get_sequences_from_batched_steps(bsteps: List[torch.Tensor]) -> List[torch.Tensor]:
-    """Given a sequence of batched step tensors of shape (batch_size, ...) builds a sequence
-    of tensors of shape (len(sequence), ...) where each resulting tensor is the aggregation
+    """Given a sequence of batched step tensors of shape (batch_size, seq_len, ...) builds a sequence
+    of tensors of shape (seq_len, ...) where each resulting tensor is the aggregation
     across batch steps for every batch element.
     Input tensors will be padded with nans up to max length in non-uniform dimensions to allow for stacking.
2 changes: 1 addition & 1 deletion tests/data/test_aggregator.py
@@ -63,7 +63,7 @@ def test_continuous_span_aggregator(saliency_mt_model: HuggingfaceEncoderDecoder

 def test_span_aggregator_with_prefix(saliency_gpt_model: HuggingfaceDecoderOnlyModel):
     out = saliency_gpt_model.attribute("Hello, world! I am,:.", "Hello, world! I am,:.!,. Last")
-    aggregated = out.aggregate("subwords", special_symbol=("Ġ", "Ċ")).aggregate()
+    aggregated = out.aggregate("subwords", special_chars=("Ġ", "Ċ")).aggregate()
     assert aggregated[0].target_attributions.shape == (5, 2)
     assert aggregated[0].attr_pos_start == 3
     assert aggregated[0].attr_pos_end == 5
