Allow contrastive attribution with shorter contrastive targets (#207)
gsarti committed Jul 28, 2023
1 parent d108595 commit 5a217b0
Showing 5 changed files with 27 additions and 13 deletions.
README.md (3 changes: 2 additions & 1 deletion)
@@ -163,9 +163,10 @@ Step functions are used to extract custom scores from the model at each step of
- `perplexity`: Perplexity of the target token.
- `contrast_prob`: Probability of the target token when different contrastive inputs are provided to the model. Equivalent to `probability` when no contrastive inputs are provided.
- `pcxmi`: Point-wise Contextual Cross-Mutual Information (P-CXMI) for the target token given original and contrastive contexts [(Yin et al. 2021)](https://arxiv.org/abs/2109.07446).
- `kl_divergence`: KL divergence of the predictive distribution given original and contrastive contexts. Can be limited to top-K most likely target token options using the `top_k` parameter.
- `kl_divergence`: KL divergence of the predictive distribution given original and contrastive contexts. Can be restricted to most likely target token options using the `top_k` and `top_p` parameters.
- `contrast_prob_diff`: Difference in probability between original and foil target tokens pair, can be used for contrastive evaluation as in [Contrastive Attribution](https://aclanthology.org/2022.emnlp-main.14/) (Yin and Neubig, 2022).
- `mc_dropout_prob_avg`: Average probability of the target token across multiple samples using [MC Dropout](https://arxiv.org/abs/1506.02142) (Gal and Ghahramani, 2016).
- `top_p_size`: The number of tokens with cumulative probability greater than `top_p` in the predictive distribution of the model.

The following example computes contrastive attributions using the `contrast_prob_diff` step function:
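A minimal sketch of such a call (the model name, sentences, and contrastive target below are illustrative placeholders rather than the README's exact example):

```python
import inseq

# Any generative model supported by inseq works; a small MT model is used for illustration.
model = inseq.load_model("Helsinki-NLP/opus-mt-en-fr", "saliency")

out = model.attribute(
    "The cat is on the mat.",
    "Le chat est sur le tapis.",                   # original target
    attributed_fn="contrast_prob_diff",            # attribute the original-vs-foil probability gap
    contrast_targets="Le chat est sur la natte.",  # contrastive (foil) target
    step_scores=["contrast_prob_diff"],            # also record the raw score at every step
)
out.show()
```

With this commit, `contrast_targets` no longer needs to match the original target token-for-token: positions it does not cover are filled with 1:1 position alignments, as implemented in the changes below.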

inseq/attr/feat/attribution_utils.py (2 changes: 1 addition & 1 deletion)
@@ -106,7 +106,7 @@ def join_token_ids(
curr_seq = []
for pos_idx, (token, token_idx) in enumerate(zip(target_tokens_seq, input_ids_seq)):
contrast_pos_idx = get_aligned_idx(pos_idx, alignments_seq)
if token != contrast_target_tokens_seq[contrast_pos_idx]:
if contrast_pos_idx != -1 and token != contrast_target_tokens_seq[contrast_pos_idx]:
curr_seq.append(TokenWithId(f"{contrast_target_tokens_seq[contrast_pos_idx]}{token}", -1))
else:
curr_seq.append(TokenWithId(token, token_idx))
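A toy illustration of the new guard (the tokens, alignments, and `aligned_idx` helper below are hypothetical; the real `get_aligned_idx` now returns -1 for uncovered positions, see `alignment_utils.py` below):

```python
# Hypothetical tokens: the contrastive target is one token shorter than the original.
target_tokens   = ["▁The", "▁cat", "▁sat", "</s>"]
contrast_tokens = ["▁A", "▁dog", "</s>"]
alignments      = [(0, 0), (1, 1), (3, 2)]  # original position 2 has no counterpart

def aligned_idx(pos, pairs):
    # mimics get_aligned_idx after this commit: -1 when no pair covers the position
    matches = [c for o, c in pairs if o == pos]
    return min(matches) if matches else -1

for pos, tok in enumerate(target_tokens):
    c_pos = aligned_idx(pos, alignments)
    if c_pos != -1 and tok != contrast_tokens[c_pos]:
        print(f"{contrast_tokens[c_pos]} -> {tok}")  # shown as a contrastive token pair
    else:
        print(tok)                                   # unaligned or identical: kept as-is
```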
inseq/attr/feat/feature_attribution.py (1 change: 1 addition & 0 deletions)
@@ -327,6 +327,7 @@ def attribute(
contrast_batch.target_tokens, as_targets=as_targets
),
special_tokens=self.attribution_model.special_tokens,
start_pos=attr_pos_start,
)
attributed_fn_args["contrast_targets_alignments"] = contrast_targets_alignments
if "contrast_targets" in step_scores_args:
inseq/models/attribution_model.py (2 changes: 2 additions & 0 deletions)
@@ -153,6 +153,7 @@ def format_contrast_targets_alignments(
contrast_sequences: List[str],
contrast_tokens: List[List[str]],
special_tokens: List[str] = [],
start_pos: int = 0,
) -> Tuple[DecoderOnlyBatch, Optional[List[List[Tuple[int, int]]]]]:
# Ensure that the contrast_targets_alignments are in the correct format (list of lists of idxs pairs)
if contrast_targets_alignments:
@@ -180,6 +181,7 @@
contrast_tokens=c_tok,
fill_missing=True,
special_tokens=special_tokens,
start_pos=start_pos,
)
)
return adjusted_alignments
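From the user's side, the pairs handled here correspond to the optional `contrast_targets_alignments` argument of `attribute()` (one list of `(original_idx, contrast_idx)` pairs per sequence), and `start_pos` is the attribution start position `attr_pos_start`. A hedged usage sketch with illustrative model, sentences, and token indices:

```python
import inseq

model = inseq.load_model("Helsinki-NLP/opus-mt-en-fr", "saliency")  # illustrative model
out = model.attribute(
    "The cat is on the mat.",
    "Le chat est sur le tapis.",
    attributed_fn="contrast_prob_diff",
    contrast_targets="Le chat est sur la natte.",
    # alignments only need to cover the attributed span; the indices here are made up
    contrast_targets_alignments=[[(4, 4), (5, 5)]],
    attr_pos_start=4,  # passed down to the alignment utilities as start_pos
)
```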
inseq/utils/alignment_utils.py (32 changes: 21 additions & 11 deletions)
@@ -9,7 +9,6 @@
import torch
from transformers import AutoModel, AutoTokenizer, PreTrainedModel, PreTrainedTokenizerBase

from .errors import MissingAlignmentsError
from .misc import clean_tokens

logger = logging.getLogger(__name__)
@@ -232,11 +231,22 @@ def auto_align_sequences(
clean_a_tokens, removed_a_token_idxs = clean_tokens(a_tokens, filter_special_tokens)
clean_b_tokens, removed_b_token_idxs = clean_tokens(b_tokens, filter_special_tokens)
if len(removed_a_token_idxs) != len(removed_b_token_idxs):
raise ValueError(
logger.warning(
"The number of special tokens in the target and contrast sequences do not match. "
"Please provide sequences with the same number of special tokens."
"Trying to match special tokens based on their identity."
)
aligned_special_tokens = [(rm_a, rm_b) for rm_a, rm_b in zip(removed_a_token_idxs, removed_b_token_idxs)]
removed_a_tokens = [a_tokens[idx] for idx in removed_a_token_idxs]
removed_b_tokens = [b_tokens[idx] for idx in removed_b_token_idxs]
aligned_special_tokens = []
for curr_idx, rm_a in enumerate(removed_a_tokens):
rm_a_idx = removed_a_token_idxs[curr_idx]
if rm_a not in removed_b_tokens:
aligned_special_tokens.append((rm_a_idx, rm_a_idx))
else:
rm_b_idx = removed_b_token_idxs[removed_b_tokens.index(rm_a)]
aligned_special_tokens.append((rm_a_idx, rm_b_idx))
else:
aligned_special_tokens = [(rm_a, rm_b) for rm_a, rm_b in zip(removed_a_token_idxs, removed_b_token_idxs)]
a_word_to_token_align = align_tokenizations(a_words, clean_a_tokens)
b_word_to_token_align = align_tokenizations(b_words, clean_b_tokens)
# 3. Propagate word-level alignments to token-level alignments.
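A worked toy example of the identity-based fallback above, with hypothetical token sequences:

```python
# The original sequence has two special tokens, the contrastive one only has </s>,
# so the counts differ and special tokens are matched by identity instead.
a_tokens = ["<pad>", "▁Hello", "</s>"]
b_tokens = ["▁Hi", "</s>"]
removed_a_token_idxs, removed_b_token_idxs = [0, 2], [1]  # special-token positions

removed_a_tokens = [a_tokens[idx] for idx in removed_a_token_idxs]  # ["<pad>", "</s>"]
removed_b_tokens = [b_tokens[idx] for idx in removed_b_token_idxs]  # ["</s>"]
aligned_special_tokens = []
for curr_idx, rm_a in enumerate(removed_a_tokens):
    rm_a_idx = removed_a_token_idxs[curr_idx]
    if rm_a not in removed_b_tokens:
        # no counterpart in the contrast sequence: align the token with its own position
        aligned_special_tokens.append((rm_a_idx, rm_a_idx))
    else:
        rm_b_idx = removed_b_token_idxs[removed_b_tokens.index(rm_a)]
        aligned_special_tokens.append((rm_a_idx, rm_b_idx))

print(aligned_special_tokens)  # [(0, 0), (2, 1)]
```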
@@ -272,6 +282,7 @@ def get_adjusted_alignments(
do_sort: bool = True,
fill_missing: bool = False,
special_tokens: List[str] = [],
start_pos: int = 0,
) -> List[Tuple[int, int]]:
is_auto_aligned = False
if fill_missing and not target_tokens:
@@ -301,7 +312,7 @@
# Filling alignments with missing tokens
if fill_missing:
filled_alignments = []
for pair_idx in range(len(target_tokens)):
for pair_idx in range(start_pos, len(target_tokens)):
match_pairs = [pair for pair in alignments if pair[0] == pair_idx]

if not match_pairs:
@@ -314,10 +325,11 @@
filled_alignments.append(valid_match)
if alignments != filled_alignments:
logger.warning(
f"Provided alignments do not cover all {len(target_tokens)} tokens from the original sequence.\n"
"Filling missing position with 1:1 position alignments."
f"Provided alignments do not cover all {len(target_tokens) - start_pos} tokens from the original"
" sequence.\nFilling missing position with 1:1 position alignments."
)
if is_auto_aligned:
filled_alignments = [(a_idx, b_idx) for a_idx, b_idx in filled_alignments if a_idx >= start_pos]
logger.warning(
f"Using {ALIGN_MODEL_ID} for automatic alignments. Provide custom alignments for non-linguistic "
f"sequences, or for languages not covered by the aligner.\nGenerated alignments: {filled_alignments}"
@@ -331,10 +343,8 @@ def get_aligned_idx(a_idx: int, alignments: List[Tuple[int, int]]) -> int:
# Find all alignment pairs for the current original target
aligned_idxs = [t_idx for s_idx, t_idx in alignments if s_idx == a_idx]
if not aligned_idxs:
raise MissingAlignmentsError(
f"No alignment found for token at index {a_idx}. "
"Please provide alignment pairs that cover all original target tokens."
)
# To be handled separately
return -1
# Select the minimum index to identify the next target token
return min(aligned_idxs)
return a_idx
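A quick sketch of the new return contract, assuming `get_aligned_idx` can be imported from `inseq.utils.alignment_utils` (callers such as `join_token_ids` above now treat -1 as "no counterpart"):

```python
from inseq.utils.alignment_utils import get_aligned_idx

alignments = [(0, 0), (1, 1), (3, 2)]  # hypothetical (original_idx, contrast_idx) pairs
print(get_aligned_idx(1, alignments))  # 1: aligned position in the contrast sequence
print(get_aligned_idx(2, alignments))  # -1: no alignment, left to the caller to handle
print(get_aligned_idx(5, []))          # 5: with no alignments at all, same index is returned
```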
