From 801f1678aa10a914c37cad88b7ef695680aeba10 Mon Sep 17 00:00:00 2001 From: xuan25 Date: Sun, 18 Feb 2024 22:19:01 +0000 Subject: [PATCH 01/14] Middleware for ReAgent --- inseq/attr/feat/ops/__init__.py | 2 + inseq/attr/feat/ops/reagent.py | 78 +++++++++++++++++++++ inseq/attr/feat/perturbation_attribution.py | 29 ++++++++ 3 files changed, 109 insertions(+) create mode 100644 inseq/attr/feat/ops/reagent.py diff --git a/inseq/attr/feat/ops/__init__.py b/inseq/attr/feat/ops/__init__.py index 388ab042..48011294 100644 --- a/inseq/attr/feat/ops/__init__.py +++ b/inseq/attr/feat/ops/__init__.py @@ -1,11 +1,13 @@ from .discretized_integrated_gradients import DiscretetizedIntegratedGradients from .lime import Lime from .monotonic_path_builder import MonotonicPathBuilder +from .reagent import ReAGent from .sequential_integrated_gradients import SequentialIntegratedGradients __all__ = [ "DiscretetizedIntegratedGradients", "MonotonicPathBuilder", "Lime", + "ReAGent", "SequentialIntegratedGradients", ] diff --git a/inseq/attr/feat/ops/reagent.py b/inseq/attr/feat/ops/reagent.py new file mode 100644 index 00000000..ee5dd7c9 --- /dev/null +++ b/inseq/attr/feat/ops/reagent.py @@ -0,0 +1,78 @@ + +from typing import Any, Callable, Union, cast +from captum.attr._utils.attribution import PerturbationAttribution +from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric +from torch import Tensor +import torch +from ReAGent.src.rationalization.rationalizer.aggregate_rationalizer import AggregateRationalizer +from ReAGent.src.rationalization.rationalizer.importance_score_evaluator.delta_prob import DeltaProbImportanceScoreEvaluator +from ReAGent.src.rationalization.rationalizer.stopping_condition_evaluator.top_k import TopKStoppingConditionEvaluator +from ReAGent.src.rationalization.rationalizer.token_replacement.token_replacer.uniform import UniformTokenReplacer + +from ReAGent.src.rationalization.rationalizer.token_replacement.token_sampler.postag import POSTagTokenSampler +from ReAGent.src.rationalization.rationalizer.token_replacement.token_sampler.uniform import UniformTokenSampler + +class ReAGent(PerturbationAttribution): + + def __init__( + self, + attribution_model: Callable, + ) -> None: + PerturbationAttribution.__init__(self, forward_func=attribution_model) + + # TODO: Handle parameters via args + model = self.forward_func + tokenizer = self.forward_func.tokenizer + stopping_top_k = 1 + rational_size = 1 + rational_size_ratio = 1 + replacing = 0 + max_steps = 1 + batch = 1 + device = 'cpu' + + token_sampler = POSTagTokenSampler(tokenizer=tokenizer, device=device) + # token_sampler = UniformTokenSampler(tokenizer=tokenizer) + + stopping_condition_evaluator = TopKStoppingConditionEvaluator( + model=model, + token_sampler=token_sampler, + top_k=stopping_top_k, + top_n=rational_size, + top_n_ratio=rational_size_ratio, + tokenizer=tokenizer + ) + + importance_score_evaluator = DeltaProbImportanceScoreEvaluator( + model=model, + tokenizer=tokenizer, + token_replacer=UniformTokenReplacer( + token_sampler=token_sampler, + ratio=replacing + ), + stopping_condition_evaluator=stopping_condition_evaluator, + max_steps=max_steps + ) + + + self.rationalizer = AggregateRationalizer( + importance_score_evaluator=importance_score_evaluator, + batch_size=batch, + overlap_threshold=2, + overlap_strict_pos=True, + top_n=rational_size, + top_n_ratio=rational_size_ratio + ) + + def attribute( # type: ignore + self, + inputs: TensorOrTupleOfTensorsGeneric, + target: TargetType = None, + 
additional_forward_args: Any = None, + ) -> Union[ + TensorOrTupleOfTensorsGeneric, + tuple[TensorOrTupleOfTensorsGeneric, Tensor], + ]: + # TODO: Actual ReAgent implementation + res = torch.rand(inputs[0].shape) + return (res,) diff --git a/inseq/attr/feat/perturbation_attribution.py b/inseq/attr/feat/perturbation_attribution.py index fbebb780..617df692 100644 --- a/inseq/attr/feat/perturbation_attribution.py +++ b/inseq/attr/feat/perturbation_attribution.py @@ -11,6 +11,7 @@ from .attribution_utils import get_source_target_attributions from .gradient_attribution import FeatureAttribution from .ops import Lime +from .ops import ReAGent logger = logging.getLogger(__name__) @@ -117,3 +118,31 @@ def attribute_step( target_attributions=out.target_attributions, sequence_scores=out.sequence_scores, ) + +class ReAGentAttribution(PerturbationAttributionRegistry): + + method_name = "ReAGent" + + def __init__(self, attribution_model, **kwargs): + super().__init__(attribution_model) + self.method = ReAGent(attribution_model=self.attribution_model, **kwargs) + + def attribute_step( + self, + attribute_fn_main_args: dict[str, Any], + attribution_args: dict[str, Any] = {}, + ) -> GranularFeatureAttributionStepOutput: + if len(attribute_fn_main_args["inputs"]) > 1: + # Captum's `_evaluate_batch` function for LIME does not account for multiple inputs when encoder-decoder + # models and attribute_target=True are used. The model output is of length two and if the inputs are either + # of length one (list containing a tuple) or of length two (tuple unpacked from the list), an error is + # raised. A workaround will be added soon. + raise NotImplementedError( + "ReAgent attribution with attribute_target=True currently not supported for encoder-decoder models." + ) + out = super().attribute_step(attribute_fn_main_args, attribution_args) + return GranularFeatureAttributionStepOutput( + source_attributions=out.source_attributions, + target_attributions=out.target_attributions, + sequence_scores=out.sequence_scores, + ) From 8a47b2ee6f37a51e1b85222c29b5ecd95447d9fe Mon Sep 17 00:00:00 2001 From: xuan25 Date: Sun, 18 Feb 2024 23:14:29 +0000 Subject: [PATCH 02/14] ReAGent implementation --- inseq/attr/feat/ops/reagent.py | 95 ++++++++++++++------- inseq/attr/feat/perturbation_attribution.py | 9 +- 2 files changed, 70 insertions(+), 34 deletions(-) diff --git a/inseq/attr/feat/ops/reagent.py b/inseq/attr/feat/ops/reagent.py index ee5dd7c9..710f1c8c 100644 --- a/inseq/attr/feat/ops/reagent.py +++ b/inseq/attr/feat/ops/reagent.py @@ -7,72 +7,109 @@ from ReAGent.src.rationalization.rationalizer.aggregate_rationalizer import AggregateRationalizer from ReAGent.src.rationalization.rationalizer.importance_score_evaluator.delta_prob import DeltaProbImportanceScoreEvaluator from ReAGent.src.rationalization.rationalizer.stopping_condition_evaluator.top_k import TopKStoppingConditionEvaluator +from ReAGent.src.rationalization.rationalizer.stopping_condition_evaluator.dummy import DummyStoppingConditionEvaluator from ReAGent.src.rationalization.rationalizer.token_replacement.token_replacer.uniform import UniformTokenReplacer from ReAGent.src.rationalization.rationalizer.token_replacement.token_sampler.postag import POSTagTokenSampler from ReAGent.src.rationalization.rationalizer.token_replacement.token_sampler.uniform import UniformTokenSampler class ReAGent(PerturbationAttribution): + r""" + ReAGent + + Args: + forward_func (callable): The forward function of the model or any + modification of it + rational_size (int): Top n 
tokens, ranked by importance score, that are kept (not replaced) during the probing inference.
+            top_n_ratio will be used if top_n has been set to 0
+        rational_size_ratio (float): Use a ratio of the input length to control the rational size
+        stopping_condition_top_k (int): The stop condition is met when the target appears in the top k predictions
+        replacing_ratio (float): Ratio of tokens to replace in each probing step
+        max_probe_steps (int): Maximum number of probing steps
+        num_probes (int): Number of probes run in parallel
+
+    References:
+        `ReAGent: A Model-agnostic Feature Attribution Method for Generative Language Models
+        `_
+
+    Examples:
+        ```
+        import inseq
+
+        model = inseq.load_model("gpt2-medium", "ReAGent",
+            rational_size=5,
+            rational_size_ratio=None,
+            stopping_condition_top_k=3,
+            replacing_ratio=0.3,
+            max_probe_steps=3000,
+            num_probes=8)
+        out = model.attribute(
+            "Super Mario Land is a game that developed by",
+        )
+        out.show()
+        ```
+    """

     def __init__(
         self,
         attribution_model: Callable,
+        rational_size: int=5,
+        rational_size_ratio: float=None,
+        stopping_condition_top_k: int=3,
+        replacing_ratio: float=0.3,
+        max_probe_steps: int=3000,
+        num_probes: int=8,
     ) -> None:
         PerturbationAttribution.__init__(self, forward_func=attribution_model)

-        # TODO: Handle parameters via args
-        model = self.forward_func
-        tokenizer = self.forward_func.tokenizer
-        stopping_top_k = 1
-        rational_size = 1
-        rational_size_ratio = 1
-        replacing = 0
-        max_steps = 1
-        batch = 1
-        device = 'cpu'
+        model = attribution_model.model
+        tokenizer = attribution_model.tokenizer

-        token_sampler = POSTagTokenSampler(tokenizer=tokenizer, device=device)
-        # token_sampler = UniformTokenSampler(tokenizer=tokenizer)
+        token_sampler = POSTagTokenSampler(tokenizer=tokenizer, device=model.device)

         stopping_condition_evaluator = TopKStoppingConditionEvaluator(
-            model=model,
-            token_sampler=token_sampler,
-            top_k=stopping_top_k,
-            top_n=rational_size,
-            top_n_ratio=rational_size_ratio,
+            model=model,
+            token_sampler=token_sampler,
+            top_k=stopping_condition_top_k,
+            top_n=rational_size,
+            top_n_ratio=rational_size_ratio,
             tokenizer=tokenizer
         )
+        # stopping_condition_evaluator = DummyStoppingConditionEvaluator()

         importance_score_evaluator = DeltaProbImportanceScoreEvaluator(
-            model=model,
-            tokenizer=tokenizer,
+            model=model,
+            tokenizer=tokenizer,
             token_replacer=UniformTokenReplacer(
-                token_sampler=token_sampler,
-                ratio=replacing
+                token_sampler=token_sampler,
+                ratio=replacing_ratio
             ),
             stopping_condition_evaluator=stopping_condition_evaluator,
-            max_steps=max_steps
+            max_steps=max_probe_steps
         )
-
         self.rationalizer = AggregateRationalizer(
             importance_score_evaluator=importance_score_evaluator,
-            batch_size=batch,
-            overlap_threshold=2,
+            batch_size=num_probes,
+            overlap_threshold=0,
             overlap_strict_pos=True,
-            top_n=rational_size,
+            top_n=rational_size,
             top_n_ratio=rational_size_ratio
         )

+    @override
     def attribute(  # type: ignore
         self,
         inputs: TensorOrTupleOfTensorsGeneric,
-        target: TargetType = None,
+        _target: TargetType = None,
         additional_forward_args: Any = None,
     ) -> Union[
         TensorOrTupleOfTensorsGeneric,
         tuple[TensorOrTupleOfTensorsGeneric, Tensor],
     ]:
-        # TODO: Actual ReAgent implementation
-        res = torch.rand(inputs[0].shape)
+        """Compute ReAGent importance scores for the given inputs.
+        """
+        self.rationalizer.rationalize(additional_forward_args[0], additional_forward_args[1])
+        mean_important_score = torch.unsqueeze(self.rationalizer.mean_important_score, 0)
+        res = torch.unsqueeze(mean_important_score, 2).repeat(1, 1, inputs[0].shape[2])
         return (res,)
diff --git
a/inseq/attr/feat/perturbation_attribution.py b/inseq/attr/feat/perturbation_attribution.py index 617df692..3fed9dbc 100644 --- a/inseq/attr/feat/perturbation_attribution.py +++ b/inseq/attr/feat/perturbation_attribution.py @@ -120,6 +120,9 @@ def attribute_step( ) class ReAGentAttribution(PerturbationAttributionRegistry): + """ReAGent-based attribution method. + The main part of the code is in ops/reagent.py. + """ method_name = "ReAGent" @@ -133,12 +136,8 @@ def attribute_step( attribution_args: dict[str, Any] = {}, ) -> GranularFeatureAttributionStepOutput: if len(attribute_fn_main_args["inputs"]) > 1: - # Captum's `_evaluate_batch` function for LIME does not account for multiple inputs when encoder-decoder - # models and attribute_target=True are used. The model output is of length two and if the inputs are either - # of length one (list containing a tuple) or of length two (tuple unpacked from the list), an error is - # raised. A workaround will be added soon. raise NotImplementedError( - "ReAgent attribution with attribute_target=True currently not supported for encoder-decoder models." + "ReAgent attribution not supported for encoder-decoder models." ) out = super().attribute_step(attribute_fn_main_args, attribution_args) return GranularFeatureAttributionStepOutput( From 15ed157b75618d6c5deb73f3dcb101dee949efd6 Mon Sep 17 00:00:00 2001 From: xuan25 Date: Sun, 18 Feb 2024 23:14:54 +0000 Subject: [PATCH 03/14] Add reagent_core --- inseq/attr/feat/ops/reagent.py | 18 +- .../reagent_core/aggregate_rationalizer.py | 294 ++++++++++++++++++ inseq/attr/feat/ops/reagent_core/base.py | 11 + .../importance_score_evaluator/base.py | 66 ++++ .../importance_score_evaluator/delta_prob.py | 245 +++++++++++++++ .../ops/reagent_core/sample_rationalizer.py | 245 +++++++++++++++ .../stopping_condition_evaluator/base.py | 20 ++ .../stopping_condition_evaluator/dummy.py | 38 +++ .../stopping_condition_evaluator/top_k.py | 102 ++++++ .../token_replacement/token_replacer/base.py | 43 +++ .../token_replacer/ranking.py | 68 ++++ .../token_replacer/threshold.py | 63 ++++ .../token_replacer/uniform.py | 51 +++ .../token_replacement/token_sampler/base.py | 22 ++ .../token_sampler/inferential.py | 45 +++ .../token_sampler/inferential_m.py | 91 ++++++ .../token_replacement/token_sampler/postag.py | 78 +++++ .../token_sampler/uniform.py | 55 ++++ .../ops/reagent_core/utils/serializing.py | 71 +++++ .../feat/ops/reagent_core/utils/traceable.py | 14 + 20 files changed, 1629 insertions(+), 11 deletions(-) create mode 100644 inseq/attr/feat/ops/reagent_core/aggregate_rationalizer.py create mode 100644 inseq/attr/feat/ops/reagent_core/base.py create mode 100644 inseq/attr/feat/ops/reagent_core/importance_score_evaluator/base.py create mode 100644 inseq/attr/feat/ops/reagent_core/importance_score_evaluator/delta_prob.py create mode 100644 inseq/attr/feat/ops/reagent_core/sample_rationalizer.py create mode 100644 inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator/base.py create mode 100644 inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator/dummy.py create mode 100644 inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator/top_k.py create mode 100644 inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/base.py create mode 100644 inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/ranking.py create mode 100644 inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/threshold.py create mode 100644 
inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/uniform.py create mode 100644 inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/base.py create mode 100644 inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/inferential.py create mode 100644 inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/inferential_m.py create mode 100644 inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/postag.py create mode 100644 inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/uniform.py create mode 100644 inseq/attr/feat/ops/reagent_core/utils/serializing.py create mode 100644 inseq/attr/feat/ops/reagent_core/utils/traceable.py diff --git a/inseq/attr/feat/ops/reagent.py b/inseq/attr/feat/ops/reagent.py index 710f1c8c..5bc9c772 100644 --- a/inseq/attr/feat/ops/reagent.py +++ b/inseq/attr/feat/ops/reagent.py @@ -1,17 +1,14 @@ - -from typing import Any, Callable, Union, cast +from typing import Any, Callable, Union +from typing_extensions import override from captum.attr._utils.attribution import PerturbationAttribution from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric from torch import Tensor import torch -from ReAGent.src.rationalization.rationalizer.aggregate_rationalizer import AggregateRationalizer -from ReAGent.src.rationalization.rationalizer.importance_score_evaluator.delta_prob import DeltaProbImportanceScoreEvaluator -from ReAGent.src.rationalization.rationalizer.stopping_condition_evaluator.top_k import TopKStoppingConditionEvaluator -from ReAGent.src.rationalization.rationalizer.stopping_condition_evaluator.dummy import DummyStoppingConditionEvaluator -from ReAGent.src.rationalization.rationalizer.token_replacement.token_replacer.uniform import UniformTokenReplacer - -from ReAGent.src.rationalization.rationalizer.token_replacement.token_sampler.postag import POSTagTokenSampler -from ReAGent.src.rationalization.rationalizer.token_replacement.token_sampler.uniform import UniformTokenSampler +from .reagent_core.aggregate_rationalizer import AggregateRationalizer +from .reagent_core.importance_score_evaluator.delta_prob import DeltaProbImportanceScoreEvaluator +from .reagent_core.stopping_condition_evaluator.top_k import TopKStoppingConditionEvaluator +from .reagent_core.token_replacement.token_replacer.uniform import UniformTokenReplacer +from .reagent_core.token_replacement.token_sampler.postag import POSTagTokenSampler class ReAGent(PerturbationAttribution): r""" @@ -75,7 +72,6 @@ def __init__( top_n_ratio=rational_size_ratio, tokenizer=tokenizer ) - # stopping_condition_evaluator = DummyStoppingConditionEvaluator() importance_score_evaluator = DeltaProbImportanceScoreEvaluator( model=model, diff --git a/inseq/attr/feat/ops/reagent_core/aggregate_rationalizer.py b/inseq/attr/feat/ops/reagent_core/aggregate_rationalizer.py new file mode 100644 index 00000000..66f6bee3 --- /dev/null +++ b/inseq/attr/feat/ops/reagent_core/aggregate_rationalizer.py @@ -0,0 +1,294 @@ + +import math +from typing import Union + +import torch +from .base import BaseRationalizer +from .importance_score_evaluator.base import BaseImportanceScoreEvaluator + +from typing_extensions import override + + +class AggregateRationalizer(BaseRationalizer): + """AggregateRationalizer + + """ + + def __init__(self, importance_score_evaluator: BaseImportanceScoreEvaluator, batch_size: int, overlap_threshold: int, overlap_strict_pos: bool = True, top_n: float = 0, top_n_ratio: float = 0) -> None: + """Constructor + + Args: + 
importance_score_evaluator: A ImportanceScoreEvaluator + batch_size: Batch size for aggregate + overlap_threshold: Overlap threshold of rational tokens within a batch + overlap_strict_pos: Whether overlap strict to position ot not + top_n: Rational size + top_n_ratio: Use ratio of sequence to define rational size + + """ + super().__init__(importance_score_evaluator) + + self.batch_size = batch_size + self.overlap_threshold = overlap_threshold + self.overlap_strict_pos = overlap_strict_pos + self.top_n = top_n + self.top_n_ratio = top_n_ratio + + assert overlap_strict_pos == True, "overlap_strict_pos = False not been supported yet" + + def get_separate_rational(self, input_ids, tokenizer) -> Union[torch.Tensor, list[list[str]]]: + + tokens = [ [ tokenizer.decode([input_ids[0, i]]) for i in s] for s in self.pos_top_n ] + + return self.pos_top_n, tokens + + @torch.no_grad() + def rationalize(self, input_ids: torch.Tensor, target_id: torch.Tensor) -> torch.Tensor: + """Compute rational of a sequence on a target + + Args: + input_ids: The sequence [batch, sequence] (first dimension need to be 1) + target_id: The target [batch] + + Return: + pos_top_n: rational position in the sequence [batch, rational_size] + + """ + assert input_ids.shape[0] == 1, "the first dimension of input (batch_size) need to be 1" + + batch_input_ids = input_ids.repeat(self.batch_size, 1) + + batch_importance_score = self.importance_score_evaluator.evaluate(batch_input_ids, target_id) + + important_score_masked = batch_importance_score * torch.unsqueeze(self.importance_score_evaluator.stop_mask, -1) + self.mean_important_score = torch.sum(important_score_masked, dim=0) / torch.sum(self.importance_score_evaluator.stop_mask) + + pos_sorted = torch.argsort(batch_importance_score, dim=-1, descending=True) + + top_n = self.top_n + + if top_n == 0: + top_n = int(math.ceil(self.top_n_ratio * input_ids.shape[-1])) + + pos_top_n = pos_sorted[:, :top_n] + self.pos_top_n = pos_top_n + + if self.overlap_strict_pos: + count_overlap = torch.bincount(pos_top_n.flatten(), minlength=input_ids.shape[1]) + pos_top_n_overlap = torch.unsqueeze(torch.nonzero(count_overlap >= self.overlap_threshold, as_tuple=True)[0], 0) + return pos_top_n_overlap + else: + token_id_top_n = input_ids[0, pos_top_n] + count_overlap = torch.bincount(token_id_top_n.flatten(), minlength=input_ids.shape[1]) + token_id_top_n_overlap = torch.unsqueeze(torch.nonzero(count_overlap >= self.overlap_threshold, as_tuple=True)[0], 0) + # TODO: Convert back to pos + raise NotImplementedError("TODO") + + + @override + def trace_start(self) -> None: + """Start tracing + + """ + super().trace_start() + + self.importance_score_evaluator.trace_start() + + @override + def trace_stop(self) -> None: + """Stop tracing + + """ + super().trace_stop() + + self.importance_score_evaluator.trace_stop() + + +@torch.no_grad() +def main(): + + from stopping_condition_evaluator.top_k import \ + TopKStoppingConditionEvaluator + from token_replacement.token_replacer.uniform import UniformTokenReplacer + from token_replacement.token_sampler.inferential import \ + InferentialTokenSampler + from token_replacement.token_sampler.postag import POSTagTokenSampler + from token_replacement.token_sampler.uniform import UniformTokenSampler + from transformers import AutoModelWithLMHead, AutoTokenizer + + from rationalization.rationalizer.importance_score_evaluator.delta_prob import \ + DeltaProbImportanceScoreEvaluator + from utils.serializing import serialize_rational + + # ======== model loading ======== 
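+    # NOTE: AutoModelWithLMHead is deprecated in recent transformers releases;
+    # AutoModelForCausalLM is the usual drop-in replacement for decoder-only models such as GPT-2.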
+ # Load model from Hugging Face + model = AutoModelWithLMHead.from_pretrained("gpt2-medium") + tokenizer = AutoTokenizer.from_pretrained("gpt2-medium") + + model.cuda() + model.eval() + + # ======== prepare data ======== + + # batch with size 1 + input_string = [ + # "I love eating breakfast in the", + "When my flight landed in Thailand. I was staying in the capital city of" + # "When my flight landed in Thailand, I converted my currency and slowly fell asleep. I was staying in the capital city of" + # "When my flight landed in Thailand, I converted my currency and slowly fell asleep. (I had a terrifying dream about my grandmother, but that's a story for another time). I was staying in the capital city of" + ] + + # generate prediction + input_ids = tokenizer(input_string, return_tensors='pt')['input_ids'].to(model.device) + generated_input = model.generate(input_ids=input_ids, max_length=80, do_sample=False) + print(' generated input -->', [ [ tokenizer.decode(token) for token in seq] for seq in generated_input ]) + + # extract target from prediction + target_id = generated_input[:, input_ids.shape[1]] + print(' target -->', [ tokenizer.decode(token) for token in target_id ]) + + # ======== hyper-parameters ======== + + # replacing ratio during importance score updating + updating_replacing_ratio = 0.3 + # keep top n word based on importance score for both stop condition evaluation and rationalization + rationale_size_ratio = None + rational_size = 5 + # stop when target exist in top k predictions + stop_condition_tolerance = 5 + + # Batch size for aggregate + aggregate_batch_size = 5 + # Overlap threshold of rational tokens within a batch + overlap_threshold = 3 + # Whether overlap strict to position ot not + overlap_strict_pos = True + + # ======== rationalization ======== + + approach_sample_replacing_token = "uniform" + # approach_sample_replacing_token = "inference" + # approach_sample_replacing_token = "postag" + + # prepare rationalizer + if approach_sample_replacing_token == "uniform": + # Approach 1: sample replacing token from uniform distribution + rationalizer = AggregateRationalizer( + importance_score_evaluator=DeltaProbImportanceScoreEvaluator( + model=model, + tokenizer=tokenizer, + token_replacer=UniformTokenReplacer( + token_sampler=UniformTokenSampler(tokenizer), + ratio=updating_replacing_ratio + ), + stopping_condition_evaluator=TopKStoppingConditionEvaluator( + model=model, + token_sampler=UniformTokenSampler(tokenizer), + top_k=stop_condition_tolerance, + top_n=rational_size, + top_n_ratio=rationale_size_ratio, + tokenizer=tokenizer + ) + ), + batch_size=aggregate_batch_size, + overlap_threshold=overlap_threshold, + overlap_strict_pos=overlap_strict_pos, + top_n=rational_size, + top_n_ratio=rationale_size_ratio + ) + elif approach_sample_replacing_token == "inference": + # Approach 2: sample replacing token from model inference + rationalizer = AggregateRationalizer( + importance_score_evaluator=DeltaProbImportanceScoreEvaluator( + model=model, + tokenizer=tokenizer, + token_replacer=UniformTokenReplacer( + token_sampler=InferentialTokenSampler(tokenizer=tokenizer, model=model), + ratio=updating_replacing_ratio + ), + stopping_condition_evaluator=TopKStoppingConditionEvaluator( + model=model, + token_sampler=InferentialTokenSampler(tokenizer=tokenizer, model=model), + top_k=stop_condition_tolerance, + top_n=rational_size, + top_n_ratio=rationale_size_ratio, + tokenizer=tokenizer + ) + ), + batch_size=aggregate_batch_size, + overlap_threshold=overlap_threshold, + 
overlap_strict_pos=overlap_strict_pos, + top_n=rational_size, + top_n_ratio=rationale_size_ratio + ) + elif approach_sample_replacing_token == "postag": + # Approach 3: sample replacing token from uniform distribution on a set of words with the same POS tag + ts = POSTagTokenSampler(tokenizer=tokenizer, device=input_ids.device) # Initialize POSTagTokenSampler takes time so share it + rationalizer = AggregateRationalizer( + importance_score_evaluator=DeltaProbImportanceScoreEvaluator( + model=model, + tokenizer=tokenizer, + token_replacer=UniformTokenReplacer( + token_sampler=ts, + ratio=updating_replacing_ratio + ), + stopping_condition_evaluator=TopKStoppingConditionEvaluator( + model=model, + token_sampler=ts, + top_k=stop_condition_tolerance, + top_n=rational_size, + top_n_ratio=rationale_size_ratio, + tokenizer=tokenizer + ) + ), + batch_size=aggregate_batch_size, + overlap_threshold=overlap_threshold, + overlap_strict_pos=overlap_strict_pos, + top_n=rational_size, + top_n_ratio=rationale_size_ratio + ) + else: + raise ValueError("Invalid approach_sample_replacing_token") + + rationalizer.trace_start() + + # rationalization + pos_rational = rationalizer.rationalize(input_ids, generated_input[:, input_ids.shape[1]]) + + # convert results + + print() + print(f"========================") + print() + print(f'Input --> {input_string[0]}') + print(f'Target --> {tokenizer.decode(target_id[0])}') + print(f"Rational positions --> {pos_rational}") + print(f"Rational words -->") + for i in range(pos_rational.shape[0]): + ids_rational = input_ids[0, pos_rational[i]] + text_rational = [ tokenizer.decode([id_rational]) for id_rational in ids_rational ] + print(f"{text_rational}") + + # output + + serialize_rational( + "rationalization_results/demo.json", + -1, + input_ids[0], + target_id[0], + pos_rational[0], + tokenizer, + rationalizer.importance_score_evaluator.important_score[0], + compact=False, + comments= { + "message": "This is a demo output. 
[comments] is an optional field", + "model": "gpt2-medium", + "approach_type": approach_sample_replacing_token + }, + trace_rationalizer=rationalizer + ) + + rationalizer.trace_stop() + +if __name__ == '__main__': + main() diff --git a/inseq/attr/feat/ops/reagent_core/base.py b/inseq/attr/feat/ops/reagent_core/base.py new file mode 100644 index 00000000..f7e10cb3 --- /dev/null +++ b/inseq/attr/feat/ops/reagent_core/base.py @@ -0,0 +1,11 @@ +from .importance_score_evaluator.base import BaseImportanceScoreEvaluator +from .utils.traceable import Traceable + + +class BaseRationalizer(Traceable): + + def __init__(self, importance_score_evaluator: BaseImportanceScoreEvaluator) -> None: + super().__init__() + + self.importance_score_evaluator = importance_score_evaluator + self.mean_important_score = None diff --git a/inseq/attr/feat/ops/reagent_core/importance_score_evaluator/base.py b/inseq/attr/feat/ops/reagent_core/importance_score_evaluator/base.py new file mode 100644 index 00000000..831b3c45 --- /dev/null +++ b/inseq/attr/feat/ops/reagent_core/importance_score_evaluator/base.py @@ -0,0 +1,66 @@ +import torch +from transformers import AutoModelWithLMHead, AutoTokenizer +from typing_extensions import override + +from ..stopping_condition_evaluator.base import StoppingConditionEvaluator +from ..token_replacement.token_replacer.base import TokenReplacer +from ..utils.traceable import Traceable + + +class BaseImportanceScoreEvaluator(Traceable): + """Importance Score Evaluator + + """ + + def __init__(self, model: AutoModelWithLMHead, tokenizer: AutoTokenizer) -> None: + """Base Constructor + + Args: + model: A Huggingface AutoModelWithLMHead model + tokenizer: A Huggingface AutoTokenizer + + """ + + self.model = model + self.tokenizer = tokenizer + + self.important_score = None + + self.trace_importance_score = None + self.trace_target_likelihood_original = None + + def evaluate(self, input_ids: torch.Tensor, target_id: torch.Tensor) -> torch.Tensor: + """Evaluate importance score of input sequence + + Args: + input_ids: input sequence [batch, sequence] + target_id: target token [batch] + + Return: + importance_score: evaluated importance score for each token in the input [batch, sequence] + + """ + + raise NotImplementedError() + + @override + def trace_start(self): + """Start tracing + + """ + super().trace_start() + + self.trace_importance_score = [] + self.trace_target_likelihood_original = -1 + self.stopping_condition_evaluator.trace_start() + + @override + def trace_stop(self): + """Stop tracing + + """ + super().trace_stop() + + self.trace_importance_score = None + self.trace_target_likelihood_original = None + self.stopping_condition_evaluator.trace_stop() diff --git a/inseq/attr/feat/ops/reagent_core/importance_score_evaluator/delta_prob.py b/inseq/attr/feat/ops/reagent_core/importance_score_evaluator/delta_prob.py new file mode 100644 index 00000000..12a9f7f4 --- /dev/null +++ b/inseq/attr/feat/ops/reagent_core/importance_score_evaluator/delta_prob.py @@ -0,0 +1,245 @@ +import logging + +import torch +from transformers import AutoModelWithLMHead, AutoTokenizer + +from ..stopping_condition_evaluator.base import StoppingConditionEvaluator +from ..token_replacement.token_replacer.base import TokenReplacer +from .base import BaseImportanceScoreEvaluator + + +class DeltaProbImportanceScoreEvaluator(BaseImportanceScoreEvaluator): + """Importance Score Evaluator + + """ + + def __init__(self, model: AutoModelWithLMHead, tokenizer: AutoTokenizer, token_replacer: TokenReplacer, 
stopping_condition_evaluator: StoppingConditionEvaluator, max_steps: float) -> None: + """Constructor + + Args: + model: A Huggingface AutoModelWithLMHead model + tokenizer: A Huggingface AutoTokenizer + token_replacer: A TokenReplacer + stopping_condition_evaluator: A StoppingConditionEvaluator + + """ + + super().__init__(model, tokenizer) + + self.token_replacer = token_replacer + self.stopping_condition_evaluator = stopping_condition_evaluator + self.max_steps = max_steps + + self.important_score = None + self.trace_importance_score = None + self.trace_target_likelihood_original = None + self.num_steps = 0 + + def update_importance_score(self, logit_importance_score: torch.Tensor, input_ids: torch.Tensor, target_id: torch.Tensor, prob_original_target: torch.Tensor) -> torch.Tensor: + """Update importance score by one step + + Args: + logit_importance_score: Current importance score in logistic scale [batch] + input_ids: input tensor [batch, sequence] + target_id: target tensor [batch] + prob_original_target: predictive probability of the target on the original sequence [batch] + + Return: + logit_importance_score: updated importance score in logistic scale [batch] + + """ + # Randomly replace a set of tokens R to form a new sequence \hat{y_{1...t}} + + input_ids_replaced, mask_replacing = self.token_replacer.sample(input_ids) + + logging.debug(f"Replacing mask: { mask_replacing }") + logging.debug(f"Replaced sequence: { [[ self.tokenizer.decode(seq[i]) for i in range(input_ids_replaced.shape[1]) ] for seq in input_ids_replaced ] }") + + # Inference \hat{p^{(y)}} = p(y_{t+1}|\hat{y_{1...t}}) + + logits_replaced = self.model(input_ids_replaced)['logits'] + prob_replaced_target = torch.softmax(logits_replaced[:, input_ids_replaced.shape[1] - 1, :], -1)[:, target_id] + self.trace_prob_original_target = prob_replaced_target + + # Compute changes delta = p^{(y)} - \hat{p^{(y)}} + + delta_prob_target = prob_original_target - prob_replaced_target + logging.debug(f"likelihood delta: { delta_prob_target }") + + # Update importance scores based on delta (magnitude) and replacement (direction) + + delta_score = mask_replacing * delta_prob_target + ~mask_replacing * -delta_prob_target + # TODO: better solution? 
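+        # delta_score lies in [-1, 1]: a replaced token receives +delta (positive when replacing it
+        # lowered the target probability, i.e. the token mattered), while kept tokens receive -delta.
+        # Mapping the score to (0, 1) and applying torch.logit lets each probing step accumulate
+        # additively in log-odds space; the softmax over the accumulated scores then yields a
+        # normalized importance distribution over the sequence.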
+ # Rescaling from [-1, 1] to [0, 1] before logit function + logit_delta_score = torch.logit(delta_score * 0.5 + 0.5) + logit_importance_score = logit_importance_score + logit_delta_score + logging.debug(f"Updated importance score: { torch.softmax(logit_importance_score, -1) }") + + return logit_importance_score + + def evaluate(self, input_ids: torch.Tensor, target_id: torch.Tensor) -> torch.Tensor: + """Evaluate importance score of input sequence + + Args: + input_ids: input sequence [batch, sequence] + target_id: target token [batch] + + Return: + importance_score: evaluated importance score for each token in the input [batch, sequence] + + """ + + self.stop_mask = torch.zeros([input_ids.shape[0]], dtype=torch.bool, device=input_ids.device) + + # Inference p^{(y)} = p(y_{t+1}|y_{1...t}) + + logits_original = self.model(input_ids)['logits'] + prob_original_target = torch.softmax(logits_original[:, input_ids.shape[1] - 1, :], -1)[:, target_id] + + if self.trace_target_likelihood_original != None: + self.trace_target_likelihood_original = prob_original_target + + # Initialize importance score s for each token in the sequence y_{1...t} + + logit_importance_score = torch.rand(input_ids.shape, device=input_ids.device) + logging.debug(f"Initialize importance score -> { torch.softmax(logit_importance_score, -1) }") + + # TODO: limit max steps + self.num_steps = 0 + while self.num_steps < self.max_steps: + self.num_steps += 1 + + # Update importance score + logit_importance_score_update = self.update_importance_score(logit_importance_score, input_ids, target_id, prob_original_target) + logit_importance_score = ~torch.unsqueeze(self.stop_mask, 1) * logit_importance_score_update + torch.unsqueeze(self.stop_mask, 1) * logit_importance_score + + self.important_score = torch.softmax(logit_importance_score, -1) + if self.trace_importance_score != None: + self.trace_importance_score.append(self.important_score) + + # Evaluate stop condition + self.stop_mask = self.stop_mask | self.stopping_condition_evaluator.evaluate(input_ids, target_id, self.important_score) + if torch.prod(self.stop_mask) > 0: + break + + logging.info(f"Importance score evaluated in {self.num_steps} steps.") + + return torch.softmax(logit_importance_score, -1) + + + + + +class DeltaProbImportanceScoreEvaluator_imp(BaseImportanceScoreEvaluator): + """Importance Score Evaluator + + """ + + def __init__(self, model: AutoModelWithLMHead, tokenizer: AutoTokenizer, token_replacer: TokenReplacer, stopping_condition_evaluator: StoppingConditionEvaluator) -> None: + """Constructor + + Args: + model: A Huggingface AutoModelWithLMHead model + tokenizer: A Huggingface AutoTokenizer + token_replacer: A TokenReplacer + stopping_condition_evaluator: A StoppingConditionEvaluator + + """ + + self.model = model + self.tokenizer = tokenizer + self.token_replacer = token_replacer + self.stopping_condition_evaluator = stopping_condition_evaluator + self.important_score = None + + self.trace_importance_score = None + self.trace_target_likelihood_original = None + self.num_steps = 0 + + def update_importance_score(self, logit_importance_score: torch.Tensor, input_ids: torch.Tensor, target_id: torch.Tensor, prob_original_target: torch.Tensor) -> torch.Tensor: + """Update importance score by one step + + Args: + logit_importance_score: Current importance score in logistic scale [batch] + input_ids: input tensor [batch, sequence] + target_id: target tensor [batch] + prob_original_target: predictive probability of the target on the original sequence [batch] 
+ + Return: + logit_importance_score: updated importance score in logistic scale [batch] + + """ + # Randomly replace a set of tokens R to form a new sequence \hat{y_{1...t}} + + input_ids_replaced, mask_replacing = self.token_replacer.sample(input_ids) + + logging.debug(f"Replacing mask: { mask_replacing }") + logging.debug(f"Replaced sequence: { [[ self.tokenizer.decode(seq[i]) for i in range(input_ids_replaced.shape[1]) ] for seq in input_ids_replaced ] }") + + # Inference \hat{p^{(y)}} = p(y_{t+1}|\hat{y_{1...t}}) + + logits_replaced = self.model(input_ids_replaced)['logits'] + prob_replaced_target = torch.softmax(logits_replaced[:, input_ids_replaced.shape[1] - 1, :], -1)[:, target_id] + self.trace_prob_original_target = prob_replaced_target + + # Compute changes delta = p^{(y)} - \hat{p^{(y)}} + + delta_prob_target = prob_original_target - prob_replaced_target + logging.debug(f"likelihood delta: { delta_prob_target }") + + # Update importance scores based on delta (magnitude) and replacement (direction) + + delta_score = mask_replacing * delta_prob_target + ~mask_replacing * -delta_prob_target + logit_importance_score = logit_importance_score + delta_score + logging.debug(f"Updated importance score: { torch.softmax(logit_importance_score, -1) }") + + return logit_importance_score + + def evaluate(self, input_ids: torch.Tensor, target_id: torch.Tensor) -> torch.Tensor: + """Evaluate importance score of input sequence + + Args: + input_ids: input sequence [batch, sequence] + target_id: target token [batch] + + Return: + importance_score: evaluated importance score for each token in the input [batch, sequence] + + """ + + self.stop_mask = torch.zeros([input_ids.shape[0]], dtype=torch.bool, device=input_ids.device) + + # Inference p^{(y)} = p(y_{t+1}|y_{1...t}) + + logits_original = self.model(input_ids)['logits'] + prob_original_target = torch.softmax(logits_original[:, input_ids.shape[1] - 1, :], -1)[:, target_id] + + if self.trace_target_likelihood_original != None: + self.trace_target_likelihood_original = prob_original_target + + # Initialize importance score s for each token in the sequence y_{1...t} + + logit_importance_score = torch.zeros(input_ids.shape, device=input_ids.device) + logging.debug(f"Initialize importance score -> { torch.softmax(logit_importance_score, -1) }") + + # TODO: limit max steps + self.num_steps = 0 + while True: + self.num_steps += 1 + + # Update importance score + logit_importance_score_update = self.update_importance_score(logit_importance_score, input_ids, target_id, prob_original_target) + logit_importance_score = ~torch.unsqueeze(self.stop_mask, 1) * logit_importance_score_update + torch.unsqueeze(self.stop_mask, 1) * logit_importance_score + + self.important_score = torch.softmax(logit_importance_score, -1) + if self.trace_importance_score != None: + self.trace_importance_score.append(self.important_score) + + # Evaluate stop condition + self.stop_mask = self.stop_mask | self.stopping_condition_evaluator.evaluate(input_ids, target_id, self.important_score) + if torch.prod(self.stop_mask) > 0: + break + + logging.info(f"Importance score evaluated in {self.num_steps} steps.") + + return torch.softmax(logit_importance_score, -1) diff --git a/inseq/attr/feat/ops/reagent_core/sample_rationalizer.py b/inseq/attr/feat/ops/reagent_core/sample_rationalizer.py new file mode 100644 index 00000000..8078a1f3 --- /dev/null +++ b/inseq/attr/feat/ops/reagent_core/sample_rationalizer.py @@ -0,0 +1,245 @@ + +import math + +import torch +from .base import 
BaseRationalizer +from .importance_score_evaluator.base import BaseImportanceScoreEvaluator +from typing_extensions import override + + +class SampleRationalizer(BaseRationalizer): + """SampleRationalizer + + """ + + def __init__(self, importance_score_evaluator: BaseImportanceScoreEvaluator, top_n: float = 0, top_n_ratio: float = 0) -> None: + """Constructor + + Args: + importance_score_evaluator: A ImportanceScoreEvaluator + top_n: Rational size + top_n_ratio: Use ratio of sequence to define rational size + + """ + super().__init__(importance_score_evaluator) + + self.top_n = top_n + self.top_n_ratio = top_n_ratio + + @torch.no_grad() + def rationalize(self, input_ids: torch.Tensor, target_id: torch.Tensor) -> torch.Tensor: + """Compute rational of a sequence on a target + + Args: + input_ids: The sequence [batch, sequence] + target_id: The target [batch] + + Return: + pos_top_n: rational position in the sequence [batch, rational_size] + + """ + batch_importance_score = self.importance_score_evaluator.evaluate(input_ids, target_id) + + self.mean_important_score = torch.mean(batch_importance_score, dim=0) + + pos_sorted = torch.argsort(batch_importance_score, dim=-1, descending=True) + + top_n = self.top_n + + if top_n == 0: + top_n = int(math.ceil(self.top_n_ratio * input_ids.shape[-1])) + + pos_top_n = pos_sorted[:, :top_n] + + return pos_top_n + + @override + def trace_start(self) -> None: + """Start tracing + + """ + super().trace_start() + + self.importance_score_evaluator.trace_start() + + @override + def trace_stop(self) -> None: + """Stop tracing + + """ + super().trace_stop() + + self.importance_score_evaluator.trace_stop() + +@torch.no_grad() +def main(): + + from stopping_condition_evaluator.top_k import \ + TopKStoppingConditionEvaluator + from token_replacement.token_replacer.uniform import UniformTokenReplacer + from token_replacement.token_sampler.inferential import \ + InferentialTokenSampler + from token_replacement.token_sampler.postag import POSTagTokenSampler + from token_replacement.token_sampler.uniform import UniformTokenSampler + from transformers import AutoModelWithLMHead, AutoTokenizer + + from rationalization.rationalizer.importance_score_evaluator.delta_prob import \ + DeltaProbImportanceScoreEvaluator + from utils.serializing import serialize_rational + + # ======== model loading ======== + # Load model from Hugging Face + model = AutoModelWithLMHead.from_pretrained("gpt2-medium") + tokenizer = AutoTokenizer.from_pretrained("gpt2-medium") + + model.cuda() + model.eval() + + # ======== prepare data ======== + + # batch with size 1 + input_string = [ + # "I love eating breakfast in the", + "When my flight landed in Thailand. I was staying in the capital city of" + # "When my flight landed in Thailand, I converted my currency and slowly fell asleep. I was staying in the capital city of" + # "When my flight landed in Thailand, I converted my currency and slowly fell asleep. (I had a terrifying dream about my grandmother, but that's a story for another time). 
I was staying in the capital city of" + ] + + # generate prediction + input_ids = tokenizer(input_string, return_tensors='pt')['input_ids'].to(model.device) + generated_input = model.generate(input_ids=input_ids, max_length=80, do_sample=False) + print(' generated input -->', [ [ tokenizer.decode(token) for token in seq] for seq in generated_input ]) + + # extract target from prediction + target_id = generated_input[:, input_ids.shape[1]] + print(' target -->', [ tokenizer.decode(token) for token in target_id ]) + + # ======== hyper-parameters ======== + + # replacing ratio during importance score updating + updating_replacing_ratio = 0.3 + # keep top n word based on importance score for both stop condition evaluation and rationalization + rationale_size_ratio = None + rational_size = 5 + # stop when target exist in top k predictions + stop_condition_tolerance = 5 + + # ======== rationalization ======== + + approach_sample_replacing_token = "uniform" + # approach_sample_replacing_token = "inference" + # approach_sample_replacing_token = "postag" + + # prepare rationalizer + if approach_sample_replacing_token == "uniform": + # Approach 1: sample replacing token from uniform distribution + rationalizer = SampleRationalizer( + importance_score_evaluator=DeltaProbImportanceScoreEvaluator( + model=model, + tokenizer=tokenizer, + token_replacer=UniformTokenReplacer( + token_sampler=UniformTokenSampler(tokenizer), + ratio=updating_replacing_ratio + ), + stopping_condition_evaluator=TopKStoppingConditionEvaluator( + model=model, + token_sampler=UniformTokenSampler(tokenizer), + top_k=stop_condition_tolerance, + top_n=rational_size, + top_n_ratio=rationale_size_ratio, + tokenizer=tokenizer + ) + ), + top_n=rational_size, + top_n_ratio=rationale_size_ratio + ) + elif approach_sample_replacing_token == "inference": + # Approach 2: sample replacing token from model inference + rationalizer = SampleRationalizer( + importance_score_evaluator=DeltaProbImportanceScoreEvaluator( + model=model, + tokenizer=tokenizer, + token_replacer=UniformTokenReplacer( + token_sampler=InferentialTokenSampler(tokenizer=tokenizer, model=model), + ratio=updating_replacing_ratio + ), + stopping_condition_evaluator=TopKStoppingConditionEvaluator( + model=model, + token_sampler=InferentialTokenSampler(tokenizer=tokenizer, model=model), + top_k=stop_condition_tolerance, + top_n=rational_size, + top_n_ratio=rationale_size_ratio, + tokenizer=tokenizer + ) + ), + top_n=rational_size, + top_n_ratio=rationale_size_ratio + ) + elif approach_sample_replacing_token == "postag": + # Approach 3: sample replacing token from uniform distribution on a set of words with the same POS tag + ts = POSTagTokenSampler(tokenizer=tokenizer, device=input_ids.device) # Initialize POSTagTokenSampler takes time so share it + rationalizer = SampleRationalizer( + importance_score_evaluator=DeltaProbImportanceScoreEvaluator( + model=model, + tokenizer=tokenizer, + token_replacer=UniformTokenReplacer( + token_sampler=ts, + ratio=updating_replacing_ratio + ), + stopping_condition_evaluator=TopKStoppingConditionEvaluator( + model=model, + token_sampler=ts, + top_k=stop_condition_tolerance, + top_n=rational_size, + top_n_ratio=rationale_size_ratio, + tokenizer=tokenizer + ) + ), + top_n=rational_size, + top_n_ratio=rationale_size_ratio + ) + else: + raise ValueError("Invalid approach_sample_replacing_token") + + rationalizer.trace_start() + + # rationalization + pos_rational = rationalizer.rationalize(input_ids, generated_input[:, input_ids.shape[1]]) + + # 
convert results + + print() + print(f"========================") + print() + print(f'Input --> {input_string[0]}') + print(f'Target --> {tokenizer.decode(target_id[0])}') + print(f"Rational positions --> {pos_rational}") + print(f"Rational words -->") + for i in range(pos_rational.shape[0]): + ids_rational = input_ids[0, pos_rational[i]] + text_rational = [ tokenizer.decode([id_rational]) for id_rational in ids_rational ] + print(f"{text_rational}") + + # output + + serialize_rational( + "rationalization_results/demo.json", + -1, + input_ids[0], + target_id[0], + pos_rational[0], + tokenizer, + rationalizer.importance_score_evaluator.important_score[0], + compact=False, + comments= { + "message": "This is a demo output. [comments] is an optional field", + "model": "gpt2-medium", + "approach_type": approach_sample_replacing_token + }, + trace_rationalizer=rationalizer + ) + + rationalizer.trace_stop() + +if __name__ == '__main__': + main() diff --git a/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator/base.py b/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator/base.py new file mode 100644 index 00000000..a92d16e9 --- /dev/null +++ b/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator/base.py @@ -0,0 +1,20 @@ +import torch + +from ..utils.traceable import Traceable + + +class StoppingConditionEvaluator(Traceable): + """Base class for Stopping Condition Evaluators + + """ + + def __init__(self): + """Base Constructor + + """ + self.trace_target_likelihood = [] + + def evaluate(self, input_ids: torch.Tensor, target_id: torch.Tensor, importance_score: torch.Tensor) -> torch.Tensor: + """Base evaluate + + """ diff --git a/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator/dummy.py b/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator/dummy.py new file mode 100644 index 00000000..eb5c5b17 --- /dev/null +++ b/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator/dummy.py @@ -0,0 +1,38 @@ +import torch +from typing_extensions import override + +from .base import StoppingConditionEvaluator + + +class DummyStoppingConditionEvaluator(StoppingConditionEvaluator): + """ + Stopping Condition Evaluator which stop when target exist in top k predictions, + while top n tokens based on importance_score are not been replaced. 
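+    In practice this dummy variant unconditionally reports the stop condition as met, so the
+    importance score obtained after a single probing step is used as-is.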
+    """
+
+    @override
+    def __init__(self) -> None:
+        """Constructor
+
+        """
+        super().__init__()
+
+    @override
+    def evaluate(self, input_ids: torch.Tensor, target_id: torch.Tensor, importance_score: torch.Tensor) -> torch.Tensor:
+        """Evaluate stop condition
+
+        Args:
+            input_ids: Input sequence [batch, sequence]
+            target_id: Target token [batch]
+            importance_score: Importance score of the input [batch, sequence]
+
+        Return:
+            Whether the stop condition is achieved [batch]
+
+        """
+        super().evaluate(input_ids, target_id, importance_score)
+
+        match_hit = torch.ones([input_ids.shape[0]], dtype=torch.bool, device=input_ids.device)
+
+        # Stop flags for each sample in the batch
+        return match_hit
diff --git a/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator/top_k.py b/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator/top_k.py
new file mode 100644
index 00000000..cb16d506
--- /dev/null
+++ b/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator/top_k.py
@@ -0,0 +1,102 @@
+import logging
+
+import torch
+from transformers import AutoModelWithLMHead, AutoTokenizer
+from typing_extensions import override
+
+from ..token_replacement.token_replacer.ranking import RankingTokenReplacer
+from ..token_replacement.token_sampler.base import TokenSampler
+from .base import StoppingConditionEvaluator
+
+
+class TopKStoppingConditionEvaluator(StoppingConditionEvaluator):
+    """
+    Stopping condition evaluator that stops when the target token appears in the top k predictions
+    while the top n tokens (ranked by importance score) are kept unreplaced.
+    """
+
+    @override
+    def __init__(self, model: AutoModelWithLMHead, token_sampler: TokenSampler, top_k: int, top_n: int = 0, top_n_ratio: float = 0, tokenizer: AutoTokenizer = None) -> None:
+        """Constructor
+
+        Args:
+            model: A Huggingface AutoModelWithLMHead.
+            token_sampler: A TokenSampler to sample replacement tokens
+            top_k: The stop condition is met when the target appears in the top k predictions
+            top_n: Number of top-importance tokens that are kept (not replaced) during the probing inference.
+ top_n_ratio will be used if top_n has been set to 0 + top_n_ratio: Use ratio of input length to control the top n + tokenizer: (Optional) Used for logging top_k_words at each step + + """ + super().__init__() + + self.model = model + self.token_sampler = token_sampler + self.top_k = top_k + self.token_replacer = RankingTokenReplacer(self.token_sampler, top_n, top_n_ratio) + self.tokenizer = tokenizer + + @override + def evaluate(self, input_ids: torch.Tensor, target_id: torch.Tensor, importance_score: torch.Tensor) -> torch.Tensor: + """Evaluate stop condition + + Args: + input_ids: Input sequence [batch, sequence] + target_id: Target token [batch] + importance_score: Importance score of the input [batch, sequence] + + Return: + Whether the stop condition achieved [batch] + + """ + super().evaluate(input_ids, target_id, importance_score) + + # Replace tokens with low importance score and then inference \hat{y^{(e)}_{t+1}} + + self.token_replacer.set_score(importance_score) + input_ids_replaced, mask_replacing = self.token_replacer.sample(input_ids) + + logging.debug(f"Replacing mask based on importance score -> { mask_replacing }") + + # Whether the result \hat{y^{(e)}_{t+1}} consistent with y_{t+1} + + assert input_ids_replaced.requires_grad == False, "Error: auto-diff engine not disabled" + with torch.no_grad(): + logits_replaced = self.model(input_ids_replaced)['logits'] + + if self.trace_target_likelihood != None: + self.trace_target_likelihood.append(torch.softmax(logits_replaced, dim=-1)[:, -1, target_id]) + + ids_prediction_sorted = torch.argsort(logits_replaced[:, -1 ,:], descending=True) + ids_prediction_top_k = ids_prediction_sorted[:, :self.top_k] + + if self.tokenizer: + top_k_words = [ [ self.tokenizer.decode([token_id]) for token_id in seq] for seq in ids_prediction_top_k ] + logging.debug(f"Top K words -> {top_k_words}") + + match_mask = ids_prediction_top_k == target_id + match_hit = torch.sum(match_mask, dim=-1, dtype=torch.bool) + + # Stop flags for each sample in the batch + return match_hit + + @override + def trace_start(self) -> None: + """Start tracing + + """ + super().trace_start() + + self.token_sampler.trace_start() + self.token_replacer.trace_start() + + @override + def trace_stop(self) -> None: + """Stop tracing + + """ + super().trace_stop() + + self.token_sampler.trace_stop() + self.token_replacer.trace_stop() diff --git a/inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/base.py b/inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/base.py new file mode 100644 index 00000000..15e6a842 --- /dev/null +++ b/inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/base.py @@ -0,0 +1,43 @@ +from typing import Union + +import torch +from typing_extensions import override + +from ...utils.traceable import Traceable +from ..token_sampler.base import TokenSampler + + +class TokenReplacer(Traceable): + """ + Base class for token replacers + + """ + + def __init__(self, token_sampler: TokenSampler) -> None: + """Base Constructor + + """ + self.token_sampler = token_sampler + + def sample(self, input: torch.Tensor) -> Union[torch.Tensor, torch.Tensor]: + """Base sample + + """ + + @override + def trace_start(self): + """Start tracing + + """ + super().trace_start() + + self.token_sampler.trace_start() + + @override + def trace_stop(self): + """Stop tracing + + """ + super().trace_stop() + + self.token_sampler.trace_stop() diff --git a/inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/ranking.py 
b/inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/ranking.py new file mode 100644 index 00000000..7ef6b1d0 --- /dev/null +++ b/inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/ranking.py @@ -0,0 +1,68 @@ +import math +from typing import Union + +import torch +from typing_extensions import override + +from ..token_sampler.base import TokenSampler +from .base import TokenReplacer + + +class RankingTokenReplacer(TokenReplacer): + """Replace tokens in a sequence based on top-N ranking + + """ + + @override + def __init__(self, token_sampler: TokenSampler, top_n: int = 0, top_n_ratio: float = 0, replace_greater: bool = False) -> None: + """Constructor + + Args: + token_sampler: A TokenSampler for sampling replace token. + top_n: Top N as the threshold. If top_n is 0, use top_n_ratio instead. + top_n_ratio: Use ratio of input to control to top_n + replace_greater: Whether replace top-n. Otherwise, replace the rests. + + """ + super().__init__(token_sampler) + + self.top_n = top_n + self.top_n_ratio = top_n_ratio + self.replace_greater = replace_greater + + def set_score(self, value: torch.Tensor) -> None: + + pos_sorted = torch.argsort(value, descending=True) + + top_n = self.top_n + + + if top_n == 0: + top_n = int(math.ceil(self.top_n_ratio * value.shape[-1])) + + pos_top_n = pos_sorted[..., :top_n] + + if not self.replace_greater: + self.mask_replacing = torch.ones(value.shape, device=value.device, dtype=torch.bool).scatter(-1, pos_top_n, 0) + else: + self.mask_replacing = torch.zeros(value.shape, device=value.device, dtype=torch.bool).scatter(-1, pos_top_n, 1) + + @override + def sample(self, input: torch.Tensor) -> Union[torch.Tensor, torch.Tensor]: + """Sample a sequence + + Args: + input: input sequence [batch, sequence] + + Returns: + input_replaced: A replaced sequence [batch, sequence] + mask_replacing: Identify which token has been replaced [batch, sequence] + + """ + super().sample(input) + + token_sampled = self.token_sampler.sample(input) + + input_replaced = input * ~self.mask_replacing + token_sampled * self.mask_replacing + + return input_replaced, self.mask_replacing diff --git a/inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/threshold.py b/inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/threshold.py new file mode 100644 index 00000000..a45bbe1e --- /dev/null +++ b/inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/threshold.py @@ -0,0 +1,63 @@ + +from typing import Union + +import torch +from typing_extensions import override + +from ..token_sampler.base import TokenSampler +from .base import TokenReplacer + + +class ThresholdTokenReplacer(TokenReplacer): + """Replace tokens in a sequence based on a threshold + + """ + + @override + def __init__(self, token_sampler: TokenSampler, threshold: float, replace_greater: bool = False) -> None: + """Constructor + + Args: + token_sampler: A TokenSampler for sampling replace token. + threshold: replacing threshold + replace_greater: Whether replace top-n. Otherwise, replace the rests. 
+ + """ + super().__init__(token_sampler) + + self.threshold = threshold + self.replace_greater = replace_greater + + def set_value(self, value: torch.Tensor) -> None: + """Set the value for threshold control + + Args: + value: value [batch, sequence] + + """ + if not self.replace_greater: + self.mask_replacing = value < self.threshold + else: + self.mask_replacing = value > self.threshold + + @override + def sample(self, input: torch.Tensor) -> Union[torch.Tensor, torch.Tensor]: + """Sample a sequence + + Args: + input: input sequence [batch, sequence] + + Returns: + input_replaced: A replaced sequence [batch, sequence] + mask_replacing: Identify which token has been replaced [batch, sequence] + + """ + super().sample(input) + + token_sampled = self.token_sampler.sample(input) + + input_replaced = input * ~self.mask_replacing + token_sampled * self.mask_replacing + + return input_replaced, self.mask_replacing + + \ No newline at end of file diff --git a/inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/uniform.py b/inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/uniform.py new file mode 100644 index 00000000..5c4bcf75 --- /dev/null +++ b/inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/uniform.py @@ -0,0 +1,51 @@ + +from typing import Union + +import torch +from typing_extensions import override + +from ..token_sampler.base import TokenSampler +from .base import TokenReplacer + + +class UniformTokenReplacer(TokenReplacer): + """Replace tokens in a sequence where selecting is base on uniform distribution + + """ + + @override + def __init__(self, token_sampler: TokenSampler, ratio: float) -> None: + """Constructor + + Args: + token_sampler: A TokenSampler for sampling replace token. + ratio: replacing ratio + + """ + super().__init__(token_sampler) + + self.ratio = ratio + + @override + def sample(self, input: torch.Tensor) -> Union[torch.Tensor, torch.Tensor]: + """Sample a sequence + + Args: + input: input sequence [batch, sequence] + + Returns: + input_replaced: A replaced sequence [batch, sequence] + mask_replacing: Identify which token has been replaced [batch, sequence] + + """ + super().sample(input) + + sample_uniform = torch.rand(input.shape, device=input.device) + mask_replacing = sample_uniform < self.ratio + + token_sampled = self.token_sampler.sample(input) + + input_replaced = input * ~mask_replacing + token_sampled * mask_replacing + + return input_replaced, mask_replacing + \ No newline at end of file diff --git a/inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/base.py b/inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/base.py new file mode 100644 index 00000000..16b4ee19 --- /dev/null +++ b/inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/base.py @@ -0,0 +1,22 @@ +import torch +from typing_extensions import override + +from ...utils.traceable import Traceable + + +class TokenSampler(Traceable): + """Base class for token samplers + + """ + + @override + def __init__(self) -> None: + """Base Constructor + + """ + super().__init__() + + def sample(self, input: torch.Tensor) -> torch.Tensor: + """Base sample + + """ diff --git a/inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/inferential.py b/inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/inferential.py new file mode 100644 index 00000000..1d48e277 --- /dev/null +++ b/inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/inferential.py @@ -0,0 +1,45 @@ +import torch +from 
transformers import AutoModelWithLMHead, AutoTokenizer +from typing_extensions import override + +from .base import TokenSampler + + +class InferentialTokenSampler(TokenSampler): + """Sample tokens from a seq-2-seq model + + """ + + @override + def __init__(self, tokenizer: AutoTokenizer, model: AutoModelWithLMHead) -> None: + """Constructor + + Args: + tokenizer: A Huggingface AutoTokenizer. + model: A Huggingface AutoModelWithLMHead for inferring the output. + + """ + super().__init__() + + self.tokenizer = tokenizer + self.model = model + + @override + def sample(self, input: torch.Tensor) -> torch.Tensor: + """Sample a tensor + + Args: + input: input tensor [batch, sequence] + + Returns: + token_inferences: sampled (placement) tokens by inference + + """ + super().sample(input) + + logits_replacing = self.model(input)['logits'] + ids_infer = torch.argmax(logits_replacing, dim=-1) + + token_inferences = torch.cat([ input[:, 0:1], ids_infer[:, :-1] ], dim=1) + + return token_inferences diff --git a/inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/inferential_m.py b/inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/inferential_m.py new file mode 100644 index 00000000..d02bb605 --- /dev/null +++ b/inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/inferential_m.py @@ -0,0 +1,91 @@ +import torch +from transformers import AutoModelWithLMHead, AutoTokenizer +from typing_extensions import override + +from .base import TokenSampler + + +class InferentialMTokenSampler(TokenSampler): + """Sample tokens from a seq-2-seq model + + """ + + @override + def __init__(self, source_tokenizer: AutoTokenizer, sampler_tokenizer: AutoTokenizer, sampler_model: AutoModelWithLMHead) -> None: + """Constructor + + Args: + source_tokenizer: A Huggingface AutoTokenizer for decoding the inputs. + sampler_tokenizer: A Huggingface AutoTokenizer for inferring the output. + sampler_model: A Huggingface AutoModelWithLMHead for inferring the output.
+ + """ + super().__init__() + + self.source_tokenizer = source_tokenizer + self.sampler_tokenizer = sampler_tokenizer + self.sampler_model = sampler_model + + @override + def sample(self, inputs: torch.Tensor) -> torch.Tensor: + """Sample a tensor + + Args: + inputs: input tensor [batch, sequence] + + Returns: + token_inferences: sampled (placement) tokens by inference + + """ + super().sample(inputs) + + batch_li = [] + for seq_i in torch.arange(inputs.shape[0]): + seq_li = [] + for pos_i in torch.arange(inputs.shape[1]): + + # first token + if pos_i == 0: + seq_li.append(inputs[seq_i, 0]) + continue + + # following tokens + + probe_prefix = torch.tensor([self.sampler_tokenizer.encode(self.source_tokenizer.decode(inputs[seq_i, :pos_i]))], device=inputs.device) + probe_prefix = probe_prefix[:,:-1] # trim + output_replacing_m = self.sampler_model(probe_prefix) + logits_replacing_m = output_replacing_m['logits'] + logits_replacing_m_last = logits_replacing_m[:,-1] + id_infer_m = torch.argmax(logits_replacing_m_last, dim=-1) + + seq_li.append(id_infer_m.item()) + + batch_li.append(seq_li) + + res = torch.tensor(batch_li, device=inputs.device) + + return res + +if __name__ == "__main__": + from transformers import AutoModelForCausalLM, AutoTokenizer + device = "cpu" + + source_tokenizer = AutoTokenizer.from_pretrained("gpt2", cache_dir="cache") + source_model = AutoModelForCausalLM.from_pretrained("gpt2", cache_dir="cache").to(device) + source_model.eval() + + sampler_tokenizer = AutoTokenizer.from_pretrained("roberta-base", cache_dir="cache") + sampler_model = AutoModelForCausalLM.from_pretrained("roberta-base", cache_dir="cache").to(device) + sampler_model.eval() + + sampler = InferentialMTokenSampler(source_tokenizer, sampler_tokenizer, sampler_model) + + text = "This is a test sequence" + inputs = torch.tensor([ source_tokenizer.encode(text) ], device=device) + + outputs = sampler.sample(inputs) + + print(outputs) + print(source_tokenizer.decode(outputs[0])) + + diff --git a/inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/postag.py b/inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/postag.py new file mode 100644 index 00000000..0cb1590d --- /dev/null +++ b/inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/postag.py @@ -0,0 +1,78 @@ +import nltk +import torch +from transformers import AutoTokenizer +from typing_extensions import override + +from .base import TokenSampler + + +class POSTagTokenSampler(TokenSampler): + """Sample tokens from Uniform distribution on a set of words with the same POS tag + + """ + + @override + def __init__(self, tokenizer: AutoTokenizer, device=None) -> None: + """Constructor + + Args: + tokenizer: A Huggingface AutoTokenizer. 
+ + """ + super().__init__() + + self.tokenizer = tokenizer + + # extract mapping from postag to words + # debug_mapping_postag_to_group_word = {} + mapping_postag_to_group_token_id = {} + + for i in range(tokenizer.vocab_size): + word = tokenizer.decode([i]) + _, tag = nltk.pos_tag([word.strip()])[0] + if tag not in mapping_postag_to_group_token_id: + # debug_mapping_postag_to_group_word[tag] = [] + mapping_postag_to_group_token_id[tag] = [] + # debug_mapping_postag_to_group_word[tag].append(word) + mapping_postag_to_group_token_id[tag].append(i) + + if i % 5000 == 0: + print(f"[POSTagTokenSampler] Loading vocab from tokenizer - {i / tokenizer.vocab_size * 100:.2f}%") + + # create tag_id for postags + self.list_postag = [ tag for tag in mapping_postag_to_group_token_id.keys() ] + num_postags = len(self.list_postag) + + # build mapping from tag_id to word group + list_group_token_id = [ torch.tensor(mapping_postag_to_group_token_id[postag], dtype=torch.long, device=device) for postag in self.list_postag ] + + # build mapping from token_id to tag_id + self.mapping_token_id_to_tag_id = torch.zeros([tokenizer.vocab_size], dtype=torch.long, device=device) + for tag_id, group_token_id in enumerate(list_group_token_id): + self.mapping_token_id_to_tag_id[group_token_id] = tag_id + + # build mapping from tag_id to token_id + # postag groups are concat together, index them via compact_idx = group_offsets[tag_id] + group_idx + self.group_sizes = torch.tensor([ group_token_id.shape[0] for group_token_id in list_group_token_id ], dtype=torch.long, device=device) + self.group_offsets = torch.sum(torch.tril(torch.ones([num_postags, num_postags], device=device), diagonal=-1) * self.group_sizes, dim=-1) + self.compact_group_token_id = torch.cat(list_group_token_id) + + @override + def sample(self, input: torch.Tensor) -> torch.Tensor: + """Sample a input + + Args: + input: input tensor [batch, sequence] + + Returns: + token_sampled: A sampled tensor where its shape is the same with the input + + """ + super().sample(input) + + tag_id_input = self.mapping_token_id_to_tag_id[input] + sample_uniform = torch.rand(input.shape, device=input.device) + compact_group_idx = (sample_uniform * self.group_sizes[tag_id_input] + self.group_offsets[tag_id_input]).type(torch.long) + token_sampled = self.compact_group_token_id[compact_group_idx] + + return token_sampled diff --git a/inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/uniform.py b/inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/uniform.py new file mode 100644 index 00000000..3d49a914 --- /dev/null +++ b/inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/uniform.py @@ -0,0 +1,55 @@ +import torch +from transformers import AutoTokenizer +from typing_extensions import override + +from .base import TokenSampler + + +class UniformTokenSampler(TokenSampler): + """Sample tokens from Uniform distribution + + """ + + @override + def __init__(self, tokenizer: AutoTokenizer) -> None: + """Constructor + + Args: + tokenizer: A Huggingface AutoTokenizer. 
+ + """ + super().__init__() + + self.tokenizer = tokenizer + + # masking tokens + avail_mask = torch.ones(tokenizer.vocab_size) + + # mask out special tokens + avail_mask[tokenizer.bos_token_id] = 0 + avail_mask[tokenizer.eos_token_id] = 0 + avail_mask[tokenizer.unk_token_id] = 0 + + # collect available tokens + self.avail_tokens = torch.arange(tokenizer.vocab_size)[avail_mask != 0] + + @override + def sample(self, input: torch.Tensor) -> torch.Tensor: + """Sample a tensor + + Args: + input: input tensor [batch, sequence] + + Returns: + token_uniform: A sampled tensor where its shape is the same with the input + + """ + super().sample(input) + + # sample idx form uniform distribution + sample_uniform = torch.rand(input.shape, device=input.device) + sample_uniform_idx = (sample_uniform * self.avail_tokens.shape[0]).type(torch.int32) + # map idx to tokens + token_uniform = self.avail_tokens.to(sample_uniform_idx)[sample_uniform_idx] + + return token_uniform diff --git a/inseq/attr/feat/ops/reagent_core/utils/serializing.py b/inseq/attr/feat/ops/reagent_core/utils/serializing.py new file mode 100644 index 00000000..dfa85df1 --- /dev/null +++ b/inseq/attr/feat/ops/reagent_core/utils/serializing.py @@ -0,0 +1,71 @@ +import json + +import torch +from transformers import AutoTokenizer + +from ..base import BaseRationalizer + + +def serialize_rational( + filename: str, + id: int, + token_inputs: torch.Tensor, + token_target: torch.Tensor, + position_rational: torch.Tensor, + tokenizer: AutoTokenizer, + important_score: torch.Tensor, + comments: dict = None, + compact: bool = False, + trace_rationalizer: BaseRationalizer = None, + trace_batch_idx: int = 0, + schema_file: str = "../docs/rationalization.schema.json" +) -> None: + """Serialize rationalization result to a json file + + Args: + filename: Filename to store json file + id: id of the record + token_inputs: token_inputs [sequence] + token_target: token_target [1] + position_rational: position of rational tokens [rational] + tokenizer: A Huggingface AutoTokenizer + important_score: final important score of tokens [sequence] + comments: (Optional) A dictionary of comments + compact: Whether store json file in a compact style + trace_rationalizer: (Optional) A Rationalizer with trace started to store trace information + trace_batch_idx: trace index in the batch, if applicable + schema_file: location of the json schema file + + """ + data = { + "$schema": schema_file, + "id": id, + "input-text": [tokenizer.decode([i]) for i in token_inputs], + "input-tokens": [i.item() for i in token_inputs], + "target-text": tokenizer.decode([token_target]), + "target-token": token_target.item(), + "rational-size": position_rational.shape[0], + "rational-positions": [i.item() for i in position_rational], + "rational-text": [tokenizer.decode([i]) for i in token_inputs[position_rational]], + "rational-tokens": [i.item() for i in token_inputs[position_rational]], + } + + if important_score != None: + data["importance-scores"] = [i.item() for i in important_score] + + if comments: + data["comments"] = comments + + if trace_rationalizer: + trace = { + "importance-scores": [ [ v.item() for v in i[trace_batch_idx] ] for i in trace_rationalizer.importance_score_evaluator.trace_importance_score ], + "target-likelihood-original": trace_rationalizer.importance_score_evaluator.trace_target_likelihood_original[trace_batch_idx].item(), + "target-likelihood": [ i[trace_batch_idx].item() for i in 
trace_rationalizer.importance_score_evaluator.stopping_condition_evaluator.trace_target_likelihood ] + } + data["trace"] = trace + + indent = None if compact else 4 + json_str = json.dumps(data, indent=indent) + + with open(filename, 'w') as f_output: + f_output.write(json_str) diff --git a/inseq/attr/feat/ops/reagent_core/utils/traceable.py b/inseq/attr/feat/ops/reagent_core/utils/traceable.py new file mode 100644 index 00000000..94588e8c --- /dev/null +++ b/inseq/attr/feat/ops/reagent_core/utils/traceable.py @@ -0,0 +1,14 @@ +class Traceable: + """Traceable base + + """ + + def trace_start(self) -> None: + """Base trace_start + + """ + + def trace_stop(self) -> None: + """Base trace_stop + + """ From 36fde624fa3aa3b316261c902d1935208e7d1a88 Mon Sep 17 00:00:00 2001 From: Xuan25 Date: Mon, 19 Feb 2024 19:29:43 +0000 Subject: [PATCH 04/14] lint for ReAGent --- inseq/attr/feat/ops/reagent.py | 39 ++-- .../reagent_core/aggregate_rationalizer.py | 198 +++++++++--------- inseq/attr/feat/ops/reagent_core/base.py | 1 - .../importance_score_evaluator/base.py | 18 +- .../importance_score_evaluator/delta_prob.py | 110 ++++++---- .../ops/reagent_core/sample_rationalizer.py | 170 +++++++-------- .../stopping_condition_evaluator/base.py | 16 +- .../stopping_condition_evaluator/dummy.py | 10 +- .../stopping_condition_evaluator/top_k.py | 40 ++-- .../token_replacement/token_replacer/base.py | 18 +- .../token_replacer/ranking.py | 20 +- .../token_replacer/threshold.py | 11 +- .../token_replacer/uniform.py | 8 +- .../token_replacement/token_sampler/base.py | 12 +- .../token_sampler/inferential.py | 10 +- .../token_sampler/inferential_m.py | 34 +-- .../token_replacement/token_sampler/postag.py | 27 ++- .../token_sampler/uniform.py | 8 +- .../ops/reagent_core/utils/serializing.py | 22 +- .../feat/ops/reagent_core/utils/traceable.py | 12 +- inseq/attr/feat/perturbation_attribution.py | 8 +- 21 files changed, 399 insertions(+), 393 deletions(-) diff --git a/inseq/attr/feat/ops/reagent.py b/inseq/attr/feat/ops/reagent.py index 5bc9c772..8e3e9765 100644 --- a/inseq/attr/feat/ops/reagent.py +++ b/inseq/attr/feat/ops/reagent.py @@ -1,15 +1,18 @@ from typing import Any, Callable, Union -from typing_extensions import override -from captum.attr._utils.attribution import PerturbationAttribution + +import torch from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric +from captum.attr._utils.attribution import PerturbationAttribution from torch import Tensor -import torch +from typing_extensions import override + from .reagent_core.aggregate_rationalizer import AggregateRationalizer from .reagent_core.importance_score_evaluator.delta_prob import DeltaProbImportanceScoreEvaluator from .reagent_core.stopping_condition_evaluator.top_k import TopKStoppingConditionEvaluator from .reagent_core.token_replacement.token_replacer.uniform import UniformTokenReplacer from .reagent_core.token_replacement.token_sampler.postag import POSTagTokenSampler + class ReAGent(PerturbationAttribution): r""" ReAGent @@ -33,8 +36,8 @@ class ReAGent(PerturbationAttribution): ``` import inseq - model = inseq.load_model("gpt2-medium", "ReAGent", - rational_size=5, + model = inseq.load_model("gpt2-medium", "ReAGent", + rational_size=5, rational_size_ratio=None, stopping_condition_top_k=3, replacing_ratio=0.3, @@ -50,12 +53,12 @@ class ReAGent(PerturbationAttribution): def __init__( self, attribution_model: Callable, - rational_size: int=5, - rational_size_ratio: float=None, - stopping_condition_top_k: int=3, - replacing_ratio: 
float=0.3, - max_probe_steps: int=3000, - num_probes: int=8, + rational_size: int = 5, + rational_size_ratio: float = None, + stopping_condition_top_k: int = 3, + replacing_ratio: float = 0.3, + max_probe_steps: int = 3000, + num_probes: int = 8, ) -> None: PerturbationAttribution.__init__(self, forward_func=attribution_model) @@ -70,18 +73,15 @@ def __init__( top_k=stopping_condition_top_k, top_n=rational_size, top_n_ratio=rational_size_ratio, - tokenizer=tokenizer + tokenizer=tokenizer, ) importance_score_evaluator = DeltaProbImportanceScoreEvaluator( model=model, tokenizer=tokenizer, - token_replacer=UniformTokenReplacer( - token_sampler=token_sampler, - ratio=replacing_ratio - ), + token_replacer=UniformTokenReplacer(token_sampler=token_sampler, ratio=replacing_ratio), stopping_condition_evaluator=stopping_condition_evaluator, - max_steps=max_probe_steps + max_steps=max_probe_steps, ) self.rationalizer = AggregateRationalizer( @@ -90,7 +90,7 @@ def __init__( overlap_threshold=0, overlap_strict_pos=True, top_n=rational_size, - top_n_ratio=rational_size_ratio + top_n_ratio=rational_size_ratio, ) @override @@ -103,8 +103,7 @@ def attribute( # type: ignore TensorOrTupleOfTensorsGeneric, tuple[TensorOrTupleOfTensorsGeneric, Tensor], ]: - """Implement attribute - """ + """Implement attribute""" self.rationalizer.rationalize(additional_forward_args[0], additional_forward_args[1]) mean_important_score = torch.unsqueeze(self.rationalizer.mean_important_score, 0) res = torch.unsqueeze(mean_important_score, 2).repeat(1, 1, inputs[0].shape[2]) diff --git a/inseq/attr/feat/ops/reagent_core/aggregate_rationalizer.py b/inseq/attr/feat/ops/reagent_core/aggregate_rationalizer.py index 66f6bee3..b19e326a 100644 --- a/inseq/attr/feat/ops/reagent_core/aggregate_rationalizer.py +++ b/inseq/attr/feat/ops/reagent_core/aggregate_rationalizer.py @@ -1,20 +1,25 @@ - import math from typing import Union import torch +from typing_extensions import override + from .base import BaseRationalizer from .importance_score_evaluator.base import BaseImportanceScoreEvaluator -from typing_extensions import override - class AggregateRationalizer(BaseRationalizer): - """AggregateRationalizer - - """ - - def __init__(self, importance_score_evaluator: BaseImportanceScoreEvaluator, batch_size: int, overlap_threshold: int, overlap_strict_pos: bool = True, top_n: float = 0, top_n_ratio: float = 0) -> None: + """AggregateRationalizer""" + + def __init__( + self, + importance_score_evaluator: BaseImportanceScoreEvaluator, + batch_size: int, + overlap_threshold: int, + overlap_strict_pos: bool = True, + top_n: float = 0, + top_n_ratio: float = 0, + ) -> None: """Constructor Args: @@ -34,11 +39,10 @@ def __init__(self, importance_score_evaluator: BaseImportanceScoreEvaluator, bat self.top_n = top_n self.top_n_ratio = top_n_ratio - assert overlap_strict_pos == True, "overlap_strict_pos = False not been supported yet" + assert overlap_strict_pos, "overlap_strict_pos = False not been supported yet" def get_separate_rational(self, input_ids, tokenizer) -> Union[torch.Tensor, list[list[str]]]: - - tokens = [ [ tokenizer.decode([input_ids[0, i]]) for i in s] for s in self.pos_top_n ] + tokens = [[tokenizer.decode([input_ids[0, i]]) for i in s] for s in self.pos_top_n] return self.pos_top_n, tokens @@ -59,9 +63,13 @@ def rationalize(self, input_ids: torch.Tensor, target_id: torch.Tensor) -> torch batch_input_ids = input_ids.repeat(self.batch_size, 1) batch_importance_score = self.importance_score_evaluator.evaluate(batch_input_ids, 
target_id) - - important_score_masked = batch_importance_score * torch.unsqueeze(self.importance_score_evaluator.stop_mask, -1) - self.mean_important_score = torch.sum(important_score_masked, dim=0) / torch.sum(self.importance_score_evaluator.stop_mask) + + important_score_masked = batch_importance_score * torch.unsqueeze( + self.importance_score_evaluator.stop_mask, -1 + ) + self.mean_important_score = torch.sum(important_score_masked, dim=0) / torch.sum( + self.importance_score_evaluator.stop_mask + ) pos_sorted = torch.argsort(batch_importance_score, dim=-1, descending=True) @@ -75,30 +83,29 @@ def rationalize(self, input_ids: torch.Tensor, target_id: torch.Tensor) -> torch if self.overlap_strict_pos: count_overlap = torch.bincount(pos_top_n.flatten(), minlength=input_ids.shape[1]) - pos_top_n_overlap = torch.unsqueeze(torch.nonzero(count_overlap >= self.overlap_threshold, as_tuple=True)[0], 0) + pos_top_n_overlap = torch.unsqueeze( + torch.nonzero(count_overlap >= self.overlap_threshold, as_tuple=True)[0], 0 + ) return pos_top_n_overlap else: token_id_top_n = input_ids[0, pos_top_n] count_overlap = torch.bincount(token_id_top_n.flatten(), minlength=input_ids.shape[1]) - token_id_top_n_overlap = torch.unsqueeze(torch.nonzero(count_overlap >= self.overlap_threshold, as_tuple=True)[0], 0) + _token_id_top_n_overlap = torch.unsqueeze( + torch.nonzero(count_overlap >= self.overlap_threshold, as_tuple=True)[0], 0 + ) # TODO: Convert back to pos raise NotImplementedError("TODO") - @override def trace_start(self) -> None: - """Start tracing - - """ + """Start tracing""" super().trace_start() self.importance_score_evaluator.trace_start() @override def trace_stop(self) -> None: - """Stop tracing - - """ + """Stop tracing""" super().trace_stop() self.importance_score_evaluator.trace_stop() @@ -106,18 +113,14 @@ def trace_stop(self) -> None: @torch.no_grad() def main(): - - from stopping_condition_evaluator.top_k import \ - TopKStoppingConditionEvaluator + from rationalization.rationalizer.importance_score_evaluator.delta_prob import DeltaProbImportanceScoreEvaluator + from stopping_condition_evaluator.top_k import TopKStoppingConditionEvaluator from token_replacement.token_replacer.uniform import UniformTokenReplacer - from token_replacement.token_sampler.inferential import \ - InferentialTokenSampler + from token_replacement.token_sampler.inferential import InferentialTokenSampler from token_replacement.token_sampler.postag import POSTagTokenSampler from token_replacement.token_sampler.uniform import UniformTokenSampler from transformers import AutoModelWithLMHead, AutoTokenizer - from rationalization.rationalizer.importance_score_evaluator.delta_prob import \ - DeltaProbImportanceScoreEvaluator from utils.serializing import serialize_rational # ======== model loading ======== @@ -127,7 +130,7 @@ def main(): model.cuda() model.eval() - + # ======== prepare data ======== # batch with size 1 @@ -138,14 +141,14 @@ def main(): # "When my flight landed in Thailand, I converted my currency and slowly fell asleep. (I had a terrifying dream about my grandmother, but that's a story for another time). 
I was staying in the capital city of" ] - # generate prediction - input_ids = tokenizer(input_string, return_tensors='pt')['input_ids'].to(model.device) - generated_input = model.generate(input_ids=input_ids, max_length=80, do_sample=False) - print(' generated input -->', [ [ tokenizer.decode(token) for token in seq] for seq in generated_input ]) + # generate prediction + input_ids = tokenizer(input_string, return_tensors="pt")["input_ids"].to(model.device) + generated_input = model.generate(input_ids=input_ids, max_length=80, do_sample=False) + print(" generated input -->", [[tokenizer.decode(token) for token in seq] for seq in generated_input]) # extract target from prediction target_id = generated_input[:, input_ids.shape[1]] - print(' target -->', [ tokenizer.decode(token) for token in target_id ]) + print(" target -->", [tokenizer.decode(token) for token in target_id]) # ======== hyper-parameters ======== @@ -165,7 +168,7 @@ def main(): overlap_strict_pos = True # ======== rationalization ======== - + approach_sample_replacing_token = "uniform" # approach_sample_replacing_token = "inference" # approach_sample_replacing_token = "postag" @@ -175,81 +178,79 @@ def main(): # Approach 1: sample replacing token from uniform distribution rationalizer = AggregateRationalizer( importance_score_evaluator=DeltaProbImportanceScoreEvaluator( - model=model, - tokenizer=tokenizer, + model=model, + tokenizer=tokenizer, token_replacer=UniformTokenReplacer( - token_sampler=UniformTokenSampler(tokenizer), - ratio=updating_replacing_ratio + token_sampler=UniformTokenSampler(tokenizer), ratio=updating_replacing_ratio ), stopping_condition_evaluator=TopKStoppingConditionEvaluator( - model=model, - token_sampler=UniformTokenSampler(tokenizer), - top_k=stop_condition_tolerance, - top_n=rational_size, - top_n_ratio=rationale_size_ratio, - tokenizer=tokenizer - ) - ), + model=model, + token_sampler=UniformTokenSampler(tokenizer), + top_k=stop_condition_tolerance, + top_n=rational_size, + top_n_ratio=rationale_size_ratio, + tokenizer=tokenizer, + ), + ), batch_size=aggregate_batch_size, overlap_threshold=overlap_threshold, overlap_strict_pos=overlap_strict_pos, - top_n=rational_size, - top_n_ratio=rationale_size_ratio + top_n=rational_size, + top_n_ratio=rationale_size_ratio, ) elif approach_sample_replacing_token == "inference": # Approach 2: sample replacing token from model inference rationalizer = AggregateRationalizer( importance_score_evaluator=DeltaProbImportanceScoreEvaluator( - model=model, - tokenizer=tokenizer, + model=model, + tokenizer=tokenizer, token_replacer=UniformTokenReplacer( - token_sampler=InferentialTokenSampler(tokenizer=tokenizer, model=model), - ratio=updating_replacing_ratio + token_sampler=InferentialTokenSampler(tokenizer=tokenizer, model=model), + ratio=updating_replacing_ratio, ), stopping_condition_evaluator=TopKStoppingConditionEvaluator( - model=model, - token_sampler=InferentialTokenSampler(tokenizer=tokenizer, model=model), - top_k=stop_condition_tolerance, - top_n=rational_size, - top_n_ratio=rationale_size_ratio, - tokenizer=tokenizer - ) - ), + model=model, + token_sampler=InferentialTokenSampler(tokenizer=tokenizer, model=model), + top_k=stop_condition_tolerance, + top_n=rational_size, + top_n_ratio=rationale_size_ratio, + tokenizer=tokenizer, + ), + ), batch_size=aggregate_batch_size, overlap_threshold=overlap_threshold, overlap_strict_pos=overlap_strict_pos, - top_n=rational_size, - top_n_ratio=rationale_size_ratio + top_n=rational_size, + 
top_n_ratio=rationale_size_ratio, ) elif approach_sample_replacing_token == "postag": # Approach 3: sample replacing token from uniform distribution on a set of words with the same POS tag - ts = POSTagTokenSampler(tokenizer=tokenizer, device=input_ids.device) # Initialize POSTagTokenSampler takes time so share it + ts = POSTagTokenSampler( + tokenizer=tokenizer, device=input_ids.device + ) # Initialize POSTagTokenSampler takes time so share it rationalizer = AggregateRationalizer( importance_score_evaluator=DeltaProbImportanceScoreEvaluator( - model=model, - tokenizer=tokenizer, - token_replacer=UniformTokenReplacer( - token_sampler=ts, - ratio=updating_replacing_ratio - ), + model=model, + tokenizer=tokenizer, + token_replacer=UniformTokenReplacer(token_sampler=ts, ratio=updating_replacing_ratio), stopping_condition_evaluator=TopKStoppingConditionEvaluator( - model=model, - token_sampler=ts, - top_k=stop_condition_tolerance, - top_n=rational_size, - top_n_ratio=rationale_size_ratio, - tokenizer=tokenizer - ) - ), + model=model, + token_sampler=ts, + top_k=stop_condition_tolerance, + top_n=rational_size, + top_n_ratio=rationale_size_ratio, + tokenizer=tokenizer, + ), + ), batch_size=aggregate_batch_size, overlap_threshold=overlap_threshold, overlap_strict_pos=overlap_strict_pos, - top_n=rational_size, - top_n_ratio=rationale_size_ratio + top_n=rational_size, + top_n_ratio=rationale_size_ratio, ) else: raise ValueError("Invalid approach_sample_replacing_token") - + rationalizer.trace_start() # rationalization @@ -258,37 +259,38 @@ def main(): # convert results print() - print(f"========================") + print("========================") print() - print(f'Input --> {input_string[0]}') - print(f'Target --> {tokenizer.decode(target_id[0])}') + print(f"Input --> {input_string[0]}") + print(f"Target --> {tokenizer.decode(target_id[0])}") print(f"Rational positions --> {pos_rational}") - print(f"Rational words -->") + print("Rational words -->") for i in range(pos_rational.shape[0]): ids_rational = input_ids[0, pos_rational[i]] - text_rational = [ tokenizer.decode([id_rational]) for id_rational in ids_rational ] + text_rational = [tokenizer.decode([id_rational]) for id_rational in ids_rational] print(f"{text_rational}") # output serialize_rational( - "rationalization_results/demo.json", - -1, - input_ids[0], - target_id[0], - pos_rational[0], - tokenizer, + "rationalization_results/demo.json", + -1, + input_ids[0], + target_id[0], + pos_rational[0], + tokenizer, rationalizer.importance_score_evaluator.important_score[0], compact=False, - comments= { + comments={ "message": "This is a demo output. 
[comments] is an optional field", "model": "gpt2-medium", - "approach_type": approach_sample_replacing_token + "approach_type": approach_sample_replacing_token, }, - trace_rationalizer=rationalizer + trace_rationalizer=rationalizer, ) rationalizer.trace_stop() -if __name__ == '__main__': + +if __name__ == "__main__": main() diff --git a/inseq/attr/feat/ops/reagent_core/base.py b/inseq/attr/feat/ops/reagent_core/base.py index f7e10cb3..8512ae06 100644 --- a/inseq/attr/feat/ops/reagent_core/base.py +++ b/inseq/attr/feat/ops/reagent_core/base.py @@ -3,7 +3,6 @@ class BaseRationalizer(Traceable): - def __init__(self, importance_score_evaluator: BaseImportanceScoreEvaluator) -> None: super().__init__() diff --git a/inseq/attr/feat/ops/reagent_core/importance_score_evaluator/base.py b/inseq/attr/feat/ops/reagent_core/importance_score_evaluator/base.py index 831b3c45..e627cac7 100644 --- a/inseq/attr/feat/ops/reagent_core/importance_score_evaluator/base.py +++ b/inseq/attr/feat/ops/reagent_core/importance_score_evaluator/base.py @@ -2,15 +2,11 @@ from transformers import AutoModelWithLMHead, AutoTokenizer from typing_extensions import override -from ..stopping_condition_evaluator.base import StoppingConditionEvaluator -from ..token_replacement.token_replacer.base import TokenReplacer from ..utils.traceable import Traceable class BaseImportanceScoreEvaluator(Traceable): - """Importance Score Evaluator - - """ + """Importance Score Evaluator""" def __init__(self, model: AutoModelWithLMHead, tokenizer: AutoTokenizer) -> None: """Base Constructor @@ -23,7 +19,7 @@ def __init__(self, model: AutoModelWithLMHead, tokenizer: AutoTokenizer) -> None self.model = model self.tokenizer = tokenizer - + self.important_score = None self.trace_importance_score = None @@ -42,12 +38,10 @@ def evaluate(self, input_ids: torch.Tensor, target_id: torch.Tensor) -> torch.Te """ raise NotImplementedError() - + @override def trace_start(self): - """Start tracing - - """ + """Start tracing""" super().trace_start() self.trace_importance_score = [] @@ -56,9 +50,7 @@ def trace_start(self): @override def trace_stop(self): - """Stop tracing - - """ + """Stop tracing""" super().trace_stop() self.trace_importance_score = None diff --git a/inseq/attr/feat/ops/reagent_core/importance_score_evaluator/delta_prob.py b/inseq/attr/feat/ops/reagent_core/importance_score_evaluator/delta_prob.py index 12a9f7f4..7923d3b2 100644 --- a/inseq/attr/feat/ops/reagent_core/importance_score_evaluator/delta_prob.py +++ b/inseq/attr/feat/ops/reagent_core/importance_score_evaluator/delta_prob.py @@ -9,11 +9,16 @@ class DeltaProbImportanceScoreEvaluator(BaseImportanceScoreEvaluator): - """Importance Score Evaluator - - """ - - def __init__(self, model: AutoModelWithLMHead, tokenizer: AutoTokenizer, token_replacer: TokenReplacer, stopping_condition_evaluator: StoppingConditionEvaluator, max_steps: float) -> None: + """Importance Score Evaluator""" + + def __init__( + self, + model: AutoModelWithLMHead, + tokenizer: AutoTokenizer, + token_replacer: TokenReplacer, + stopping_condition_evaluator: StoppingConditionEvaluator, + max_steps: float, + ) -> None: """Constructor Args: @@ -35,7 +40,13 @@ def __init__(self, model: AutoModelWithLMHead, tokenizer: AutoTokenizer, token_r self.trace_target_likelihood_original = None self.num_steps = 0 - def update_importance_score(self, logit_importance_score: torch.Tensor, input_ids: torch.Tensor, target_id: torch.Tensor, prob_original_target: torch.Tensor) -> torch.Tensor: + def update_importance_score( + self, + 
logit_importance_score: torch.Tensor, + input_ids: torch.Tensor, + target_id: torch.Tensor, + prob_original_target: torch.Tensor, + ) -> torch.Tensor: """Update importance score by one step Args: @@ -53,11 +64,13 @@ def update_importance_score(self, logit_importance_score: torch.Tensor, input_id input_ids_replaced, mask_replacing = self.token_replacer.sample(input_ids) logging.debug(f"Replacing mask: { mask_replacing }") - logging.debug(f"Replaced sequence: { [[ self.tokenizer.decode(seq[i]) for i in range(input_ids_replaced.shape[1]) ] for seq in input_ids_replaced ] }") - + logging.debug( + f"Replaced sequence: { [[ self.tokenizer.decode(seq[i]) for i in range(input_ids_replaced.shape[1]) ] for seq in input_ids_replaced ] }" + ) + # Inference \hat{p^{(y)}} = p(y_{t+1}|\hat{y_{1...t}}) - logits_replaced = self.model(input_ids_replaced)['logits'] + logits_replaced = self.model(input_ids_replaced)["logits"] prob_replaced_target = torch.softmax(logits_replaced[:, input_ids_replaced.shape[1] - 1, :], -1)[:, target_id] self.trace_prob_original_target = prob_replaced_target @@ -93,10 +106,10 @@ def evaluate(self, input_ids: torch.Tensor, target_id: torch.Tensor) -> torch.Te # Inference p^{(y)} = p(y_{t+1}|y_{1...t}) - logits_original = self.model(input_ids)['logits'] + logits_original = self.model(input_ids)["logits"] prob_original_target = torch.softmax(logits_original[:, input_ids.shape[1] - 1, :], -1)[:, target_id] - if self.trace_target_likelihood_original != None: + if self.trace_target_likelihood_original is not None: self.trace_target_likelihood_original = prob_original_target # Initialize importance score s for each token in the sequence y_{1...t} @@ -108,34 +121,42 @@ def evaluate(self, input_ids: torch.Tensor, target_id: torch.Tensor) -> torch.Te self.num_steps = 0 while self.num_steps < self.max_steps: self.num_steps += 1 - + # Update importance score - logit_importance_score_update = self.update_importance_score(logit_importance_score, input_ids, target_id, prob_original_target) - logit_importance_score = ~torch.unsqueeze(self.stop_mask, 1) * logit_importance_score_update + torch.unsqueeze(self.stop_mask, 1) * logit_importance_score + logit_importance_score_update = self.update_importance_score( + logit_importance_score, input_ids, target_id, prob_original_target + ) + logit_importance_score = ( + ~torch.unsqueeze(self.stop_mask, 1) * logit_importance_score_update + + torch.unsqueeze(self.stop_mask, 1) * logit_importance_score + ) self.important_score = torch.softmax(logit_importance_score, -1) - if self.trace_importance_score != None: + if self.trace_importance_score is not None: self.trace_importance_score.append(self.important_score) # Evaluate stop condition - self.stop_mask = self.stop_mask | self.stopping_condition_evaluator.evaluate(input_ids, target_id, self.important_score) + self.stop_mask = self.stop_mask | self.stopping_condition_evaluator.evaluate( + input_ids, target_id, self.important_score + ) if torch.prod(self.stop_mask) > 0: break - + logging.info(f"Importance score evaluated in {self.num_steps} steps.") return torch.softmax(logit_importance_score, -1) - - - class DeltaProbImportanceScoreEvaluator_imp(BaseImportanceScoreEvaluator): - """Importance Score Evaluator - - """ - - def __init__(self, model: AutoModelWithLMHead, tokenizer: AutoTokenizer, token_replacer: TokenReplacer, stopping_condition_evaluator: StoppingConditionEvaluator) -> None: + """Importance Score Evaluator""" + + def __init__( + self, + model: AutoModelWithLMHead, + tokenizer: AutoTokenizer, + 
token_replacer: TokenReplacer, + stopping_condition_evaluator: StoppingConditionEvaluator, + ) -> None: """Constructor Args: @@ -156,7 +177,13 @@ def __init__(self, model: AutoModelWithLMHead, tokenizer: AutoTokenizer, token_r self.trace_target_likelihood_original = None self.num_steps = 0 - def update_importance_score(self, logit_importance_score: torch.Tensor, input_ids: torch.Tensor, target_id: torch.Tensor, prob_original_target: torch.Tensor) -> torch.Tensor: + def update_importance_score( + self, + logit_importance_score: torch.Tensor, + input_ids: torch.Tensor, + target_id: torch.Tensor, + prob_original_target: torch.Tensor, + ) -> torch.Tensor: """Update importance score by one step Args: @@ -174,11 +201,13 @@ def update_importance_score(self, logit_importance_score: torch.Tensor, input_id input_ids_replaced, mask_replacing = self.token_replacer.sample(input_ids) logging.debug(f"Replacing mask: { mask_replacing }") - logging.debug(f"Replaced sequence: { [[ self.tokenizer.decode(seq[i]) for i in range(input_ids_replaced.shape[1]) ] for seq in input_ids_replaced ] }") - + logging.debug( + f"Replaced sequence: { [[ self.tokenizer.decode(seq[i]) for i in range(input_ids_replaced.shape[1]) ] for seq in input_ids_replaced ] }" + ) + # Inference \hat{p^{(y)}} = p(y_{t+1}|\hat{y_{1...t}}) - logits_replaced = self.model(input_ids_replaced)['logits'] + logits_replaced = self.model(input_ids_replaced)["logits"] prob_replaced_target = torch.softmax(logits_replaced[:, input_ids_replaced.shape[1] - 1, :], -1)[:, target_id] self.trace_prob_original_target = prob_replaced_target @@ -211,10 +240,10 @@ def evaluate(self, input_ids: torch.Tensor, target_id: torch.Tensor) -> torch.Te # Inference p^{(y)} = p(y_{t+1}|y_{1...t}) - logits_original = self.model(input_ids)['logits'] + logits_original = self.model(input_ids)["logits"] prob_original_target = torch.softmax(logits_original[:, input_ids.shape[1] - 1, :], -1)[:, target_id] - if self.trace_target_likelihood_original != None: + if self.trace_target_likelihood_original is not None: self.trace_target_likelihood_original = prob_original_target # Initialize importance score s for each token in the sequence y_{1...t} @@ -226,20 +255,27 @@ def evaluate(self, input_ids: torch.Tensor, target_id: torch.Tensor) -> torch.Te self.num_steps = 0 while True: self.num_steps += 1 - + # Update importance score - logit_importance_score_update = self.update_importance_score(logit_importance_score, input_ids, target_id, prob_original_target) - logit_importance_score = ~torch.unsqueeze(self.stop_mask, 1) * logit_importance_score_update + torch.unsqueeze(self.stop_mask, 1) * logit_importance_score + logit_importance_score_update = self.update_importance_score( + logit_importance_score, input_ids, target_id, prob_original_target + ) + logit_importance_score = ( + ~torch.unsqueeze(self.stop_mask, 1) * logit_importance_score_update + + torch.unsqueeze(self.stop_mask, 1) * logit_importance_score + ) self.important_score = torch.softmax(logit_importance_score, -1) - if self.trace_importance_score != None: + if self.trace_importance_score is not None: self.trace_importance_score.append(self.important_score) # Evaluate stop condition - self.stop_mask = self.stop_mask | self.stopping_condition_evaluator.evaluate(input_ids, target_id, self.important_score) + self.stop_mask = self.stop_mask | self.stopping_condition_evaluator.evaluate( + input_ids, target_id, self.important_score + ) if torch.prod(self.stop_mask) > 0: break - + logging.info(f"Importance score evaluated in 
{self.num_steps} steps.") return torch.softmax(logit_importance_score, -1) diff --git a/inseq/attr/feat/ops/reagent_core/sample_rationalizer.py b/inseq/attr/feat/ops/reagent_core/sample_rationalizer.py index 8078a1f3..29878217 100644 --- a/inseq/attr/feat/ops/reagent_core/sample_rationalizer.py +++ b/inseq/attr/feat/ops/reagent_core/sample_rationalizer.py @@ -1,18 +1,18 @@ - import math import torch +from typing_extensions import override + from .base import BaseRationalizer from .importance_score_evaluator.base import BaseImportanceScoreEvaluator -from typing_extensions import override class SampleRationalizer(BaseRationalizer): - """SampleRationalizer - - """ + """SampleRationalizer""" - def __init__(self, importance_score_evaluator: BaseImportanceScoreEvaluator, top_n: float = 0, top_n_ratio: float = 0) -> None: + def __init__( + self, importance_score_evaluator: BaseImportanceScoreEvaluator, top_n: float = 0, top_n_ratio: float = 0 + ) -> None: """Constructor Args: @@ -41,50 +41,43 @@ def rationalize(self, input_ids: torch.Tensor, target_id: torch.Tensor) -> torch batch_importance_score = self.importance_score_evaluator.evaluate(input_ids, target_id) self.mean_important_score = torch.mean(batch_importance_score, dim=0) - + pos_sorted = torch.argsort(batch_importance_score, dim=-1, descending=True) top_n = self.top_n if top_n == 0: top_n = int(math.ceil(self.top_n_ratio * input_ids.shape[-1])) - + pos_top_n = pos_sorted[:, :top_n] return pos_top_n @override def trace_start(self) -> None: - """Start tracing - - """ + """Start tracing""" super().trace_start() self.importance_score_evaluator.trace_start() @override def trace_stop(self) -> None: - """Stop tracing - - """ + """Stop tracing""" super().trace_stop() self.importance_score_evaluator.trace_stop() + @torch.no_grad() def main(): - - from stopping_condition_evaluator.top_k import \ - TopKStoppingConditionEvaluator + from rationalization.rationalizer.importance_score_evaluator.delta_prob import DeltaProbImportanceScoreEvaluator + from stopping_condition_evaluator.top_k import TopKStoppingConditionEvaluator from token_replacement.token_replacer.uniform import UniformTokenReplacer - from token_replacement.token_sampler.inferential import \ - InferentialTokenSampler + from token_replacement.token_sampler.inferential import InferentialTokenSampler from token_replacement.token_sampler.postag import POSTagTokenSampler from token_replacement.token_sampler.uniform import UniformTokenSampler from transformers import AutoModelWithLMHead, AutoTokenizer - from rationalization.rationalizer.importance_score_evaluator.delta_prob import \ - DeltaProbImportanceScoreEvaluator from utils.serializing import serialize_rational # ======== model loading ======== @@ -94,7 +87,7 @@ def main(): model.cuda() model.eval() - + # ======== prepare data ======== # batch with size 1 @@ -105,14 +98,14 @@ def main(): # "When my flight landed in Thailand, I converted my currency and slowly fell asleep. (I had a terrifying dream about my grandmother, but that's a story for another time). 
I was staying in the capital city of" ] - # generate prediction - input_ids = tokenizer(input_string, return_tensors='pt')['input_ids'].to(model.device) - generated_input = model.generate(input_ids=input_ids, max_length=80, do_sample=False) - print(' generated input -->', [ [ tokenizer.decode(token) for token in seq] for seq in generated_input ]) + # generate prediction + input_ids = tokenizer(input_string, return_tensors="pt")["input_ids"].to(model.device) + generated_input = model.generate(input_ids=input_ids, max_length=80, do_sample=False) + print(" generated input -->", [[tokenizer.decode(token) for token in seq] for seq in generated_input]) # extract target from prediction target_id = generated_input[:, input_ids.shape[1]] - print(' target -->', [ tokenizer.decode(token) for token in target_id ]) + print(" target -->", [tokenizer.decode(token) for token in target_id]) # ======== hyper-parameters ======== @@ -125,7 +118,7 @@ def main(): stop_condition_tolerance = 5 # ======== rationalization ======== - + approach_sample_replacing_token = "uniform" # approach_sample_replacing_token = "inference" # approach_sample_replacing_token = "postag" @@ -135,72 +128,70 @@ def main(): # Approach 1: sample replacing token from uniform distribution rationalizer = SampleRationalizer( importance_score_evaluator=DeltaProbImportanceScoreEvaluator( - model=model, - tokenizer=tokenizer, + model=model, + tokenizer=tokenizer, token_replacer=UniformTokenReplacer( - token_sampler=UniformTokenSampler(tokenizer), - ratio=updating_replacing_ratio + token_sampler=UniformTokenSampler(tokenizer), ratio=updating_replacing_ratio ), stopping_condition_evaluator=TopKStoppingConditionEvaluator( - model=model, - token_sampler=UniformTokenSampler(tokenizer), - top_k=stop_condition_tolerance, - top_n=rational_size, - top_n_ratio=rationale_size_ratio, - tokenizer=tokenizer - ) - ), - top_n=rational_size, - top_n_ratio=rationale_size_ratio + model=model, + token_sampler=UniformTokenSampler(tokenizer), + top_k=stop_condition_tolerance, + top_n=rational_size, + top_n_ratio=rationale_size_ratio, + tokenizer=tokenizer, + ), + ), + top_n=rational_size, + top_n_ratio=rationale_size_ratio, ) elif approach_sample_replacing_token == "inference": # Approach 2: sample replacing token from model inference rationalizer = SampleRationalizer( importance_score_evaluator=DeltaProbImportanceScoreEvaluator( - model=model, - tokenizer=tokenizer, + model=model, + tokenizer=tokenizer, token_replacer=UniformTokenReplacer( - token_sampler=InferentialTokenSampler(tokenizer=tokenizer, model=model), - ratio=updating_replacing_ratio + token_sampler=InferentialTokenSampler(tokenizer=tokenizer, model=model), + ratio=updating_replacing_ratio, ), stopping_condition_evaluator=TopKStoppingConditionEvaluator( - model=model, - token_sampler=InferentialTokenSampler(tokenizer=tokenizer, model=model), - top_k=stop_condition_tolerance, - top_n=rational_size, - top_n_ratio=rationale_size_ratio, - tokenizer=tokenizer - ) - ), - top_n=rational_size, - top_n_ratio=rationale_size_ratio + model=model, + token_sampler=InferentialTokenSampler(tokenizer=tokenizer, model=model), + top_k=stop_condition_tolerance, + top_n=rational_size, + top_n_ratio=rationale_size_ratio, + tokenizer=tokenizer, + ), + ), + top_n=rational_size, + top_n_ratio=rationale_size_ratio, ) elif approach_sample_replacing_token == "postag": # Approach 3: sample replacing token from uniform distribution on a set of words with the same POS tag - ts = POSTagTokenSampler(tokenizer=tokenizer, 
device=input_ids.device) # Initialize POSTagTokenSampler takes time so share it + ts = POSTagTokenSampler( + tokenizer=tokenizer, device=input_ids.device + ) # Initialize POSTagTokenSampler takes time so share it rationalizer = SampleRationalizer( importance_score_evaluator=DeltaProbImportanceScoreEvaluator( - model=model, - tokenizer=tokenizer, - token_replacer=UniformTokenReplacer( - token_sampler=ts, - ratio=updating_replacing_ratio - ), + model=model, + tokenizer=tokenizer, + token_replacer=UniformTokenReplacer(token_sampler=ts, ratio=updating_replacing_ratio), stopping_condition_evaluator=TopKStoppingConditionEvaluator( - model=model, - token_sampler=ts, - top_k=stop_condition_tolerance, - top_n=rational_size, - top_n_ratio=rationale_size_ratio, - tokenizer=tokenizer - ) - ), - top_n=rational_size, - top_n_ratio=rationale_size_ratio + model=model, + token_sampler=ts, + top_k=stop_condition_tolerance, + top_n=rational_size, + top_n_ratio=rationale_size_ratio, + tokenizer=tokenizer, + ), + ), + top_n=rational_size, + top_n_ratio=rationale_size_ratio, ) else: raise ValueError("Invalid approach_sample_replacing_token") - + rationalizer.trace_start() # rationalization @@ -209,37 +200,38 @@ def main(): # convert results print() - print(f"========================") + print("========================") print() - print(f'Input --> {input_string[0]}') - print(f'Target --> {tokenizer.decode(target_id[0])}') + print(f"Input --> {input_string[0]}") + print(f"Target --> {tokenizer.decode(target_id[0])}") print(f"Rational positions --> {pos_rational}") - print(f"Rational words -->") + print("Rational words -->") for i in range(pos_rational.shape[0]): ids_rational = input_ids[0, pos_rational[i]] - text_rational = [ tokenizer.decode([id_rational]) for id_rational in ids_rational ] + text_rational = [tokenizer.decode([id_rational]) for id_rational in ids_rational] print(f"{text_rational}") # output serialize_rational( - "rationalization_results/demo.json", - -1, - input_ids[0], - target_id[0], - pos_rational[0], - tokenizer, + "rationalization_results/demo.json", + -1, + input_ids[0], + target_id[0], + pos_rational[0], + tokenizer, rationalizer.importance_score_evaluator.important_score[0], compact=False, - comments= { + comments={ "message": "This is a demo output. 
[comments] is an optional field", "model": "gpt2-medium", - "approach_type": approach_sample_replacing_token + "approach_type": approach_sample_replacing_token, }, - trace_rationalizer=rationalizer + trace_rationalizer=rationalizer, ) rationalizer.trace_stop() -if __name__ == '__main__': + +if __name__ == "__main__": main() diff --git a/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator/base.py b/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator/base.py index a92d16e9..82a08909 100644 --- a/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator/base.py +++ b/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator/base.py @@ -4,17 +4,13 @@ class StoppingConditionEvaluator(Traceable): - """Base class for Stopping Condition Evaluators - - """ + """Base class for Stopping Condition Evaluators""" def __init__(self): - """Base Constructor - - """ + """Base Constructor""" self.trace_target_likelihood = [] - def evaluate(self, input_ids: torch.Tensor, target_id: torch.Tensor, importance_score: torch.Tensor) -> torch.Tensor: - """Base evaluate - - """ + def evaluate( + self, input_ids: torch.Tensor, target_id: torch.Tensor, importance_score: torch.Tensor + ) -> torch.Tensor: + """Base evaluate""" diff --git a/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator/dummy.py b/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator/dummy.py index eb5c5b17..f98efe52 100644 --- a/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator/dummy.py +++ b/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator/dummy.py @@ -6,19 +6,19 @@ class DummyStoppingConditionEvaluator(StoppingConditionEvaluator): """ - Stopping Condition Evaluator which stop when target exist in top k predictions, + Stopping Condition Evaluator which stop when target exist in top k predictions, while top n tokens based on importance_score are not been replaced. """ @override def __init__(self) -> None: - """Constructor - - """ + """Constructor""" super().__init__() @override - def evaluate(self, input_ids: torch.Tensor, target_id: torch.Tensor, importance_score: torch.Tensor) -> torch.Tensor: + def evaluate( + self, input_ids: torch.Tensor, target_id: torch.Tensor, importance_score: torch.Tensor + ) -> torch.Tensor: """Evaluate stop condition Args: diff --git a/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator/top_k.py b/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator/top_k.py index cb16d506..f3461d33 100644 --- a/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator/top_k.py +++ b/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator/top_k.py @@ -11,14 +11,22 @@ class TopKStoppingConditionEvaluator(StoppingConditionEvaluator): """ - Stopping Condition Evaluator which stop when target exist in top k predictions, + Stopping Condition Evaluator which stop when target exist in top k predictions, while top n tokens based on importance_score are not been replaced. """ @override - def __init__(self, model: AutoModelWithLMHead, token_sampler: TokenSampler, top_k: int, top_n: int = 0, top_n_ratio: float = 0, tokenizer: AutoTokenizer = None) -> None: + def __init__( + self, + model: AutoModelWithLMHead, + token_sampler: TokenSampler, + top_k: int, + top_n: int = 0, + top_n_ratio: float = 0, + tokenizer: AutoTokenizer = None, + ) -> None: """Constructor - + Args: model: A Huggingface AutoModelWithLMHead. 
token_sampler: A TokenSampler to sample replacement tokens @@ -38,7 +46,9 @@ def __init__(self, model: AutoModelWithLMHead, token_sampler: TokenSampler, top_ self.tokenizer = tokenizer @override - def evaluate(self, input_ids: torch.Tensor, target_id: torch.Tensor, importance_score: torch.Tensor) -> torch.Tensor: + def evaluate( + self, input_ids: torch.Tensor, target_id: torch.Tensor, importance_score: torch.Tensor + ) -> torch.Tensor: """Evaluate stop condition Args: @@ -53,7 +63,7 @@ def evaluate(self, input_ids: torch.Tensor, target_id: torch.Tensor, importance_ super().evaluate(input_ids, target_id, importance_score) # Replace tokens with low importance score and then inference \hat{y^{(e)}_{t+1}} - + self.token_replacer.set_score(importance_score) input_ids_replaced, mask_replacing = self.token_replacer.sample(input_ids) @@ -61,18 +71,18 @@ def evaluate(self, input_ids: torch.Tensor, target_id: torch.Tensor, importance_ # Whether the result \hat{y^{(e)}_{t+1}} consistent with y_{t+1} - assert input_ids_replaced.requires_grad == False, "Error: auto-diff engine not disabled" + assert not input_ids_replaced.requires_grad, "Error: auto-diff engine not disabled" with torch.no_grad(): - logits_replaced = self.model(input_ids_replaced)['logits'] + logits_replaced = self.model(input_ids_replaced)["logits"] - if self.trace_target_likelihood != None: + if self.trace_target_likelihood is not None: self.trace_target_likelihood.append(torch.softmax(logits_replaced, dim=-1)[:, -1, target_id]) - ids_prediction_sorted = torch.argsort(logits_replaced[:, -1 ,:], descending=True) - ids_prediction_top_k = ids_prediction_sorted[:, :self.top_k] + ids_prediction_sorted = torch.argsort(logits_replaced[:, -1, :], descending=True) + ids_prediction_top_k = ids_prediction_sorted[:, : self.top_k] if self.tokenizer: - top_k_words = [ [ self.tokenizer.decode([token_id]) for token_id in seq] for seq in ids_prediction_top_k ] + top_k_words = [[self.tokenizer.decode([token_id]) for token_id in seq] for seq in ids_prediction_top_k] logging.debug(f"Top K words -> {top_k_words}") match_mask = ids_prediction_top_k == target_id @@ -83,9 +93,7 @@ def evaluate(self, input_ids: torch.Tensor, target_id: torch.Tensor, importance_ @override def trace_start(self) -> None: - """Start tracing - - """ + """Start tracing""" super().trace_start() self.token_sampler.trace_start() @@ -93,9 +101,7 @@ def trace_start(self) -> None: @override def trace_stop(self) -> None: - """Stop tracing - - """ + """Stop tracing""" super().trace_stop() self.token_sampler.trace_stop() diff --git a/inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/base.py b/inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/base.py index 15e6a842..04260f5c 100644 --- a/inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/base.py +++ b/inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/base.py @@ -14,30 +14,22 @@ class TokenReplacer(Traceable): """ def __init__(self, token_sampler: TokenSampler) -> None: - """Base Constructor - - """ + """Base Constructor""" self.token_sampler = token_sampler def sample(self, input: torch.Tensor) -> Union[torch.Tensor, torch.Tensor]: - """Base sample + """Base sample""" - """ - @override def trace_start(self): - """Start tracing - - """ + """Start tracing""" super().trace_start() self.token_sampler.trace_start() @override def trace_stop(self): - """Stop tracing - - """ + """Stop tracing""" super().trace_stop() - + self.token_sampler.trace_stop() diff --git 
a/inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/ranking.py b/inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/ranking.py index 7ef6b1d0..56541e98 100644 --- a/inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/ranking.py +++ b/inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/ranking.py @@ -9,12 +9,12 @@ class RankingTokenReplacer(TokenReplacer): - """Replace tokens in a sequence based on top-N ranking - - """ + """Replace tokens in a sequence based on top-N ranking""" @override - def __init__(self, token_sampler: TokenSampler, top_n: int = 0, top_n_ratio: float = 0, replace_greater: bool = False) -> None: + def __init__( + self, token_sampler: TokenSampler, top_n: int = 0, top_n_ratio: float = 0, replace_greater: bool = False + ) -> None: """Constructor Args: @@ -31,11 +31,9 @@ def __init__(self, token_sampler: TokenSampler, top_n: int = 0, top_n_ratio: flo self.replace_greater = replace_greater def set_score(self, value: torch.Tensor) -> None: - pos_sorted = torch.argsort(value, descending=True) top_n = self.top_n - if top_n == 0: top_n = int(math.ceil(self.top_n_ratio * value.shape[-1])) @@ -43,9 +41,13 @@ def set_score(self, value: torch.Tensor) -> None: pos_top_n = pos_sorted[..., :top_n] if not self.replace_greater: - self.mask_replacing = torch.ones(value.shape, device=value.device, dtype=torch.bool).scatter(-1, pos_top_n, 0) + self.mask_replacing = torch.ones(value.shape, device=value.device, dtype=torch.bool).scatter( + -1, pos_top_n, 0 + ) else: - self.mask_replacing = torch.zeros(value.shape, device=value.device, dtype=torch.bool).scatter(-1, pos_top_n, 1) + self.mask_replacing = torch.zeros(value.shape, device=value.device, dtype=torch.bool).scatter( + -1, pos_top_n, 1 + ) @override def sample(self, input: torch.Tensor) -> Union[torch.Tensor, torch.Tensor]: @@ -53,7 +55,7 @@ def sample(self, input: torch.Tensor) -> Union[torch.Tensor, torch.Tensor]: Args: input: input sequence [batch, sequence] - + Returns: input_replaced: A replaced sequence [batch, sequence] mask_replacing: Identify which token has been replaced [batch, sequence] diff --git a/inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/threshold.py b/inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/threshold.py index a45bbe1e..90084fae 100644 --- a/inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/threshold.py +++ b/inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/threshold.py @@ -1,4 +1,3 @@ - from typing import Union import torch @@ -9,9 +8,7 @@ class ThresholdTokenReplacer(TokenReplacer): - """Replace tokens in a sequence based on a threshold - - """ + """Replace tokens in a sequence based on a threshold""" @override def __init__(self, token_sampler: TokenSampler, threshold: float, replace_greater: bool = False) -> None: @@ -30,7 +27,7 @@ def __init__(self, token_sampler: TokenSampler, threshold: float, replace_greate def set_value(self, value: torch.Tensor) -> None: """Set the value for threshold control - + Args: value: value [batch, sequence] @@ -46,7 +43,7 @@ def sample(self, input: torch.Tensor) -> Union[torch.Tensor, torch.Tensor]: Args: input: input sequence [batch, sequence] - + Returns: input_replaced: A replaced sequence [batch, sequence] mask_replacing: Identify which token has been replaced [batch, sequence] @@ -59,5 +56,3 @@ def sample(self, input: torch.Tensor) -> Union[torch.Tensor, torch.Tensor]: input_replaced = input * ~self.mask_replacing + token_sampled * 
self.mask_replacing return input_replaced, self.mask_replacing - - \ No newline at end of file diff --git a/inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/uniform.py b/inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/uniform.py index 5c4bcf75..5fac259a 100644 --- a/inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/uniform.py +++ b/inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/uniform.py @@ -1,4 +1,3 @@ - from typing import Union import torch @@ -9,9 +8,7 @@ class UniformTokenReplacer(TokenReplacer): - """Replace tokens in a sequence where selecting is base on uniform distribution - - """ + """Replace tokens in a sequence where selection is based on a uniform distribution""" @override def __init__(self, token_sampler: TokenSampler, ratio: float) -> None: @@ -32,7 +29,7 @@ def sample(self, input: torch.Tensor) -> Union[torch.Tensor, torch.Tensor]: Args: input: input sequence [batch, sequence] - + Returns: input_replaced: A replaced sequence [batch, sequence] mask_replacing: Identify which token has been replaced [batch, sequence] @@ -48,4 +45,3 @@ def sample(self, input: torch.Tensor) -> Union[torch.Tensor, torch.Tensor]: input_replaced = input * ~mask_replacing + token_sampled * mask_replacing return input_replaced, mask_replacing - \ No newline at end of file diff --git a/inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/base.py b/inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/base.py index 16b4ee19..05d31d71 100644 --- a/inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/base.py +++ b/inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/base.py @@ -5,18 +5,12 @@ class TokenSampler(Traceable): - """Base class for token samplers + """Base class for token samplers""" - """ - @override def __init__(self) -> None: - """Base Constructor - - """ + """Base Constructor""" super().__init__() def sample(self, input: torch.Tensor) -> torch.Tensor: - """Base sample - - """ + """Base sample""" diff --git a/inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/inferential.py b/inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/inferential.py index 1d48e277..0322517a 100644 --- a/inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/inferential.py +++ b/inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/inferential.py @@ -6,9 +6,7 @@ class InferentialTokenSampler(TokenSampler): - """Sample tokens from a seq-2-seq model - - """ + """Sample tokens from a seq-2-seq model""" @override def __init__(self, tokenizer: AutoTokenizer, model: AutoModelWithLMHead) -> None: @@ -30,16 +28,16 @@ def sample(self, input: torch.Tensor) -> torch.Tensor: Args: input: input tensor [batch, sequence] - + Returns: token_inferences: sampled (placement) tokens by inference """ super().sample(input) - logits_replacing = self.model(input)['logits'] + logits_replacing = self.model(input)["logits"] ids_infer = torch.argmax(logits_replacing, dim=-1) - token_inferences = torch.cat([ input[:, 0:1], ids_infer[:, :-1] ], dim=1) + token_inferences = torch.cat([input[:, 0:1], ids_infer[:, :-1]], dim=1) return token_inferences diff --git a/inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/inferential_m.py b/inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/inferential_m.py index d02bb605..4526fd38 100644 --- a/inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/inferential_m.py +++ 
b/inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/inferential_m.py @@ -6,12 +6,12 @@ class InferentialMTokenSampler(TokenSampler): - """Sample tokens from a seq-2-seq model - - """ + """Sample tokens from a seq-2-seq model""" @override - def __init__(self, source_tokenizer: AutoTokenizer, sampler_tokenizer: AutoTokenizer, sampler_model: AutoModelWithLMHead) -> None: + def __init__( + self, source_tokenizer: AutoTokenizer, sampler_tokenizer: AutoTokenizer, sampler_model: AutoModelWithLMHead + ) -> None: """Constructor Args: @@ -32,7 +32,7 @@ def sample(self, inputs: torch.Tensor) -> torch.Tensor: Args: inputs: input tensor [batch, sequence] - + Returns: token_inferences: sampled (placement) tokens by inference @@ -43,31 +43,35 @@ def sample(self, inputs: torch.Tensor) -> torch.Tensor: for seq_i in torch.arange(inputs.shape[0]): seq_li = [] for pos_i in torch.arange(inputs.shape[1]): - # first token if pos_i == 0: - seq_li.append(inputs[seq_i, 0]) - continue + seq_li.append(inputs[seq_i, 0]) + continue # following tokens - probe_prefix = torch.tensor([self.sampler_tokenizer.encode(self.source_tokenizer.decode(inputs[seq_i, :pos_i]))], device=inputs.device) - probe_prefix = probe_prefix[:,:-1] # trim + probe_prefix = torch.tensor( + [self.sampler_tokenizer.encode(self.source_tokenizer.decode(inputs[seq_i, :pos_i]))], + device=inputs.device, + ) + probe_prefix = probe_prefix[:, :-1] # trim output_replacing_m = self.sampler_model(probe_prefix) - logits_replacing_m = output_replacing_m['logits'] - logits_replacing_m_last = logits_replacing_m[:,-1] + logits_replacing_m = output_replacing_m["logits"] + logits_replacing_m_last = logits_replacing_m[:, -1] id_infer_m = torch.argmax(logits_replacing_m_last, dim=-1) seq_li.append(id_infer_m.item()) batch_li.append(seq_li) - + res = torch.tensor(batch_li, device=inputs.device) return res + if __name__ == "__main__": from transformers import AutoModelForCausalLM, AutoTokenizer + device = "cpu" source_tokenizer = AutoTokenizer.from_pretrained("gpt2", cache_dir="cache") @@ -81,11 +85,9 @@ def sample(self, inputs: torch.Tensor) -> torch.Tensor: sampler = InferentialMTokenSampler(source_tokenizer, sampler_tokenizer, sampler_model) text = "This is a test sequence" - inputs = torch.tensor([ source_tokenizer.encode(text) ], device=device) + inputs = torch.tensor([source_tokenizer.encode(text)], device=device) outputs = sampler.sample(inputs) print(outputs) print(source_tokenizer.decode(outputs[0])) - - diff --git a/inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/postag.py b/inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/postag.py index 0cb1590d..ae99f623 100644 --- a/inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/postag.py +++ b/inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/postag.py @@ -7,9 +7,7 @@ class POSTagTokenSampler(TokenSampler): - """Sample tokens from Uniform distribution on a set of words with the same POS tag - - """ + """Sample tokens from Uniform distribution on a set of words with the same POS tag""" @override def __init__(self, tokenizer: AutoTokenizer, device=None) -> None: @@ -26,7 +24,7 @@ def __init__(self, tokenizer: AutoTokenizer, device=None) -> None: # extract mapping from postag to words # debug_mapping_postag_to_group_word = {} mapping_postag_to_group_token_id = {} - + for i in range(tokenizer.vocab_size): word = tokenizer.decode([i]) _, tag = nltk.pos_tag([word.strip()])[0] @@ -40,11 +38,14 @@ def __init__(self, tokenizer: AutoTokenizer, 
device=None) -> None: print(f"[POSTagTokenSampler] Loading vocab from tokenizer - {i / tokenizer.vocab_size * 100:.2f}%") # create tag_id for postags - self.list_postag = [ tag for tag in mapping_postag_to_group_token_id.keys() ] + self.list_postag = list(mapping_postag_to_group_token_id.keys()) num_postags = len(self.list_postag) # build mapping from tag_id to word group - list_group_token_id = [ torch.tensor(mapping_postag_to_group_token_id[postag], dtype=torch.long, device=device) for postag in self.list_postag ] + list_group_token_id = [ + torch.tensor(mapping_postag_to_group_token_id[postag], dtype=torch.long, device=device) + for postag in self.list_postag + ] # build mapping from token_id to tag_id self.mapping_token_id_to_tag_id = torch.zeros([tokenizer.vocab_size], dtype=torch.long, device=device) @@ -53,8 +54,12 @@ def __init__(self, tokenizer: AutoTokenizer, device=None) -> None: # build mapping from tag_id to token_id # postag groups are concat together, index them via compact_idx = group_offsets[tag_id] + group_idx - self.group_sizes = torch.tensor([ group_token_id.shape[0] for group_token_id in list_group_token_id ], dtype=torch.long, device=device) - self.group_offsets = torch.sum(torch.tril(torch.ones([num_postags, num_postags], device=device), diagonal=-1) * self.group_sizes, dim=-1) + self.group_sizes = torch.tensor( + [group_token_id.shape[0] for group_token_id in list_group_token_id], dtype=torch.long, device=device + ) + self.group_offsets = torch.sum( + torch.tril(torch.ones([num_postags, num_postags], device=device), diagonal=-1) * self.group_sizes, dim=-1 + ) self.compact_group_token_id = torch.cat(list_group_token_id) @override @@ -63,7 +68,7 @@ def sample(self, input: torch.Tensor) -> torch.Tensor: Args: input: input tensor [batch, sequence] - + Returns: token_sampled: A sampled tensor where its shape is the same with the input @@ -72,7 +77,9 @@ def sample(self, input: torch.Tensor) -> torch.Tensor: tag_id_input = self.mapping_token_id_to_tag_id[input] sample_uniform = torch.rand(input.shape, device=input.device) - compact_group_idx = (sample_uniform * self.group_sizes[tag_id_input] + self.group_offsets[tag_id_input]).type(torch.long) + compact_group_idx = (sample_uniform * self.group_sizes[tag_id_input] + self.group_offsets[tag_id_input]).type( + torch.long + ) token_sampled = self.compact_group_token_id[compact_group_idx] return token_sampled diff --git a/inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/uniform.py b/inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/uniform.py index 3d49a914..8401f942 100644 --- a/inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/uniform.py +++ b/inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/uniform.py @@ -6,9 +6,7 @@ class UniformTokenSampler(TokenSampler): - """Sample tokens from Uniform distribution - - """ + """Sample tokens from Uniform distribution""" @override def __init__(self, tokenizer: AutoTokenizer) -> None: @@ -32,14 +30,14 @@ def __init__(self, tokenizer: AutoTokenizer) -> None: # collect available tokens self.avail_tokens = torch.arange(tokenizer.vocab_size)[avail_mask != 0] - + @override def sample(self, input: torch.Tensor) -> torch.Tensor: """Sample a tensor Args: input: input tensor [batch, sequence] - + Returns: token_uniform: A sampled tensor where its shape is the same with the input diff --git a/inseq/attr/feat/ops/reagent_core/utils/serializing.py b/inseq/attr/feat/ops/reagent_core/utils/serializing.py index dfa85df1..9502ab7f 100644 --- 
a/inseq/attr/feat/ops/reagent_core/utils/serializing.py +++ b/inseq/attr/feat/ops/reagent_core/utils/serializing.py @@ -18,10 +18,10 @@ def serialize_rational( compact: bool = False, trace_rationalizer: BaseRationalizer = None, trace_batch_idx: int = 0, - schema_file: str = "../docs/rationalization.schema.json" + schema_file: str = "../docs/rationalization.schema.json", ) -> None: """Serialize rationalization result to a json file - + Args: filename: Filename to store json file id: id of the record @@ -50,7 +50,7 @@ def serialize_rational( "rational-tokens": [i.item() for i in token_inputs[position_rational]], } - if important_score != None: + if important_score is not None: data["importance-scores"] = [i.item() for i in important_score] if comments: @@ -58,14 +58,22 @@ def serialize_rational( if trace_rationalizer: trace = { - "importance-scores": [ [ v.item() for v in i[trace_batch_idx] ] for i in trace_rationalizer.importance_score_evaluator.trace_importance_score ], - "target-likelihood-original": trace_rationalizer.importance_score_evaluator.trace_target_likelihood_original[trace_batch_idx].item(), - "target-likelihood": [ i[trace_batch_idx].item() for i in trace_rationalizer.importance_score_evaluator.stopping_condition_evaluator.trace_target_likelihood ] + "importance-scores": [ + [v.item() for v in i[trace_batch_idx]] + for i in trace_rationalizer.importance_score_evaluator.trace_importance_score + ], + "target-likelihood-original": trace_rationalizer.importance_score_evaluator.trace_target_likelihood_original[ + trace_batch_idx + ].item(), + "target-likelihood": [ + i[trace_batch_idx].item() + for i in trace_rationalizer.importance_score_evaluator.stopping_condition_evaluator.trace_target_likelihood + ], } data["trace"] = trace indent = None if compact else 4 json_str = json.dumps(data, indent=indent) - with open(filename, 'w') as f_output: + with open(filename, "w") as f_output: f_output.write(json_str) diff --git a/inseq/attr/feat/ops/reagent_core/utils/traceable.py b/inseq/attr/feat/ops/reagent_core/utils/traceable.py index 94588e8c..93e16ae5 100644 --- a/inseq/attr/feat/ops/reagent_core/utils/traceable.py +++ b/inseq/attr/feat/ops/reagent_core/utils/traceable.py @@ -1,14 +1,8 @@ class Traceable: - """Traceable base - - """ + """Traceable base""" def trace_start(self) -> None: - """Base trace_start - - """ + """Base trace_start""" def trace_stop(self) -> None: - """Base trace_stop - - """ + """Base trace_stop""" diff --git a/inseq/attr/feat/perturbation_attribution.py b/inseq/attr/feat/perturbation_attribution.py index 3fed9dbc..138d4bfb 100644 --- a/inseq/attr/feat/perturbation_attribution.py +++ b/inseq/attr/feat/perturbation_attribution.py @@ -10,8 +10,7 @@ from ...utils import Registry from .attribution_utils import get_source_target_attributions from .gradient_attribution import FeatureAttribution -from .ops import Lime -from .ops import ReAGent +from .ops import Lime, ReAGent logger = logging.getLogger(__name__) @@ -119,6 +118,7 @@ def attribute_step( sequence_scores=out.sequence_scores, ) + class ReAGentAttribution(PerturbationAttributionRegistry): """ReAGent-based attribution method. The main part of the code is in ops/reagent.py. @@ -136,9 +136,7 @@ def attribute_step( attribution_args: dict[str, Any] = {}, ) -> GranularFeatureAttributionStepOutput: if len(attribute_fn_main_args["inputs"]) > 1: - raise NotImplementedError( - "ReAgent attribution not supported for encoder-decoder models." 
- ) + raise NotImplementedError("ReAgent attribution not supported for encoder-decoder models.") out = super().attribute_step(attribute_fn_main_args, attribution_args) return GranularFeatureAttributionStepOutput( source_attributions=out.source_attributions, From 06d5f603d489b75badc989d3564fb12ee8981eb0 Mon Sep 17 00:00:00 2001 From: Xuan25 Date: Mon, 19 Feb 2024 19:30:31 +0000 Subject: [PATCH 05/14] update dependencies for ReAGent --- pyproject.toml | 1 + requirements.txt | 12 ++++++++++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index cf400855..decee7c2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ dependencies = [ "torch>=2.1.1", "matplotlib>=3.5.3", "tqdm>=4.64.0", + "nltk>=3.8.1", "nvidia-cublas-cu11>=11.10.3.66; sys_platform=='Linux'", "nvidia-cuda-cupti-cu11>=11.7.101; sys_platform=='Linux'", "nvidia-cuda-nvrtc-cu11>=11.7.99; sys_platform=='Linux'", diff --git a/requirements.txt b/requirements.txt index 9f392d72..ba99f21d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,12 @@ -# This file was autogenerated by uv v0.1.2 via the following command: +# This file was autogenerated by uv v0.1.5 via the following command: # uv pip compile pyproject.toml -o requirements.txt captum==0.7.0 certifi==2024.2.2 # via requests charset-normalizer==3.3.2 # via requests +click==8.1.7 + # via nltk contourpy==1.2.0 # via matplotlib cycler==0.12.1 @@ -29,6 +31,8 @@ idna==3.6 jaxtyping==0.2.25 jinja2==3.1.3 # via torch +joblib==1.3.2 + # via nltk kiwisolver==1.4.5 # via matplotlib markdown-it-py==3.0.0 @@ -43,6 +47,7 @@ mpmath==1.3.0 # via sympy networkx==3.2.1 # via torch +nltk==3.8.1 numpy==1.26.4 # via # captum @@ -70,7 +75,9 @@ pyyaml==6.0.1 # huggingface-hub # transformers regex==2023.12.25 - # via transformers + # via + # nltk + # transformers requests==2.31.0 # via # huggingface-hub @@ -92,6 +99,7 @@ tqdm==4.66.2 # via # captum # huggingface-hub + # nltk # transformers transformers==4.37.2 typeguard==2.13.3 From 82d6e93417aabc13f5d257f32cb99a3b6a28843a Mon Sep 17 00:00:00 2001 From: Gabriele Sarti Date: Mon, 19 Feb 2024 23:14:11 +0100 Subject: [PATCH 06/14] Added caching for POSTagTokenSampler, minor fixes --- inseq/attr/feat/ops/__init__.py | 4 +- inseq/attr/feat/ops/reagent.py | 29 +++--- inseq/attr/feat/ops/reagent_core/__init__.py | 0 .../token_replacer/ranking.py | 2 +- .../token_replacer/uniform.py | 2 +- .../token_sampler/inferential_m.py | 4 - .../feat/ops/reagent_core/token_sampler.py | 90 +++++++++++++++++++ inseq/attr/feat/perturbation_attribution.py | 18 ++-- inseq/utils/__init__.py | 2 + inseq/utils/import_utils.py | 5 ++ 10 files changed, 131 insertions(+), 25 deletions(-) create mode 100644 inseq/attr/feat/ops/reagent_core/__init__.py create mode 100644 inseq/attr/feat/ops/reagent_core/token_sampler.py diff --git a/inseq/attr/feat/ops/__init__.py b/inseq/attr/feat/ops/__init__.py index 48011294..abe53d29 100644 --- a/inseq/attr/feat/ops/__init__.py +++ b/inseq/attr/feat/ops/__init__.py @@ -1,13 +1,13 @@ from .discretized_integrated_gradients import DiscretetizedIntegratedGradients from .lime import Lime from .monotonic_path_builder import MonotonicPathBuilder -from .reagent import ReAGent +from .reagent import Reagent from .sequential_integrated_gradients import SequentialIntegratedGradients __all__ = [ "DiscretetizedIntegratedGradients", "MonotonicPathBuilder", "Lime", - "ReAGent", + "Reagent", "SequentialIntegratedGradients", ] diff --git a/inseq/attr/feat/ops/reagent.py 
b/inseq/attr/feat/ops/reagent.py index 8e3e9765..95fd8988 100644 --- a/inseq/attr/feat/ops/reagent.py +++ b/inseq/attr/feat/ops/reagent.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Union +from typing import TYPE_CHECKING, Any, Union import torch from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric @@ -10,12 +10,21 @@ from .reagent_core.importance_score_evaluator.delta_prob import DeltaProbImportanceScoreEvaluator from .reagent_core.stopping_condition_evaluator.top_k import TopKStoppingConditionEvaluator from .reagent_core.token_replacement.token_replacer.uniform import UniformTokenReplacer -from .reagent_core.token_replacement.token_sampler.postag import POSTagTokenSampler +from .reagent_core.token_sampler import POSTagTokenSampler +if TYPE_CHECKING: + from ....models import HuggingfaceModel -class ReAGent(PerturbationAttribution): - r""" - ReAGent + +class Reagent(PerturbationAttribution): + r"""Recursive attribution generator (ReAGent) method. + + Measures importance as the drop in prediction probability produced by replacing a token with a plausible + alternative predicted by a LM. + + Reference implementation: + `ReAGent: A Model-agnostic Feature Attribution Method for Generative Language Models + `__ Args: forward_func (callable): The forward function of the model or any @@ -28,10 +37,6 @@ class ReAGent(PerturbationAttribution): max_probe_steps (int): max_probe_steps num_probes (int): number of probes in parallel - References: - `ReAGent: A Model-agnostic Feature Attribution Method for Generative Language Models - `_ - Examples: ``` import inseq @@ -52,7 +57,7 @@ class ReAGent(PerturbationAttribution): def __init__( self, - attribution_model: Callable, + attribution_model: "HuggingfaceModel", rational_size: int = 5, rational_size_ratio: float = None, stopping_condition_top_k: int = 3, @@ -65,7 +70,9 @@ def __init__( model = attribution_model.model tokenizer = attribution_model.tokenizer - token_sampler = POSTagTokenSampler(tokenizer=tokenizer, device=model.device) + token_sampler = POSTagTokenSampler( + tokenizer=tokenizer, identifier=attribution_model.model_name, device=attribution_model.device + ) stopping_condition_evaluator = TopKStoppingConditionEvaluator( model=model, diff --git a/inseq/attr/feat/ops/reagent_core/__init__.py b/inseq/attr/feat/ops/reagent_core/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/ranking.py b/inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/ranking.py index 56541e98..fe842652 100644 --- a/inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/ranking.py +++ b/inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/ranking.py @@ -63,7 +63,7 @@ def sample(self, input: torch.Tensor) -> Union[torch.Tensor, torch.Tensor]: """ super().sample(input) - token_sampled = self.token_sampler.sample(input) + token_sampled = self.token_sampler(input) input_replaced = input * ~self.mask_replacing + token_sampled * self.mask_replacing diff --git a/inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/uniform.py b/inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/uniform.py index 5fac259a..4c663bf1 100644 --- a/inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/uniform.py +++ b/inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/uniform.py @@ -40,7 +40,7 @@ def sample(self, input: torch.Tensor) -> Union[torch.Tensor, torch.Tensor]: sample_uniform = 
torch.rand(input.shape, device=input.device) mask_replacing = sample_uniform < self.ratio - token_sampled = self.token_sampler.sample(input) + token_sampled = self.token_sampler(input) input_replaced = input * ~mask_replacing + token_sampled * mask_replacing diff --git a/inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/inferential_m.py b/inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/inferential_m.py index 4526fd38..954b43ae 100644 --- a/inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/inferential_m.py +++ b/inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/inferential_m.py @@ -2,8 +2,6 @@ from transformers import AutoModelWithLMHead, AutoTokenizer from typing_extensions import override -from .base import TokenSampler - class InferentialMTokenSampler(TokenSampler): """Sample tokens from a seq-2-seq model""" @@ -37,8 +35,6 @@ def sample(self, inputs: torch.Tensor) -> torch.Tensor: token_inferences: sampled (placement) tokens by inference """ - super().sample(inputs) - batch_li = [] for seq_i in torch.arange(inputs.shape[0]): seq_li = [] diff --git a/inseq/attr/feat/ops/reagent_core/token_sampler.py b/inseq/attr/feat/ops/reagent_core/token_sampler.py new file mode 100644 index 00000000..73cfb7e7 --- /dev/null +++ b/inseq/attr/feat/ops/reagent_core/token_sampler.py @@ -0,0 +1,90 @@ +import logging +from abc import ABC, abstractmethod +from collections import defaultdict +from pathlib import Path +from typing import Any, Optional, Union + +import torch +from transformers import AutoTokenizer, PreTrainedTokenizerBase + +from .....utils import INSEQ_ARTIFACTS_CACHE, cache_results, is_nltk_available +from .....utils.typing import IdsTensor + +logger = logging.getLogger(__name__) + + +class TokenSampler(ABC): + """Base class for token samplers""" + + @abstractmethod + def __call__(self, input: IdsTensor, **kwargs) -> IdsTensor: + """Sample tokens according to the specified strategy.""" + pass + + +class POSTagTokenSampler(TokenSampler): + """Sample tokens from Uniform distribution on a set of words with the same POS tag.""" + + def __init__( + self, + tokenizer: Union[str, PreTrainedTokenizerBase], + identifier: str = "pos_tag_sampler", + save_cache: bool = True, + overwrite_cache: bool = False, + cache_dir: Path = INSEQ_ARTIFACTS_CACHE / "pos_tag_sampler_cache", + device: Optional[str] = None, + tokenizer_kwargs: Optional[dict[str, Any]] = {}, + ) -> None: + if isinstance(tokenizer, PreTrainedTokenizerBase): + self.tokenizer = tokenizer + else: + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer, **tokenizer_kwargs) + cache_filename = cache_dir / f"{identifier.split('/')[-1]}.pkl" + self.pos2ids = self.build_pos_mapping_from_vocab( + cache_dir, + cache_filename, + save_cache, + overwrite_cache, + tokenizer=self.tokenizer, + ) + num_postags = len(self.pos2ids) + self.id2pos = torch.zeros([self.tokenizer.vocab_size], dtype=torch.long, device=device) + for pos_idx, ids in enumerate(self.pos2ids.values()): + self.id2pos[ids] = pos_idx + self.num_ids_per_pos = torch.tensor( + [len(ids) for ids in self.pos2ids.values()], dtype=torch.long, device=device + ) + self.offsets = torch.sum( + torch.tril(torch.ones([num_postags, num_postags], device=device), diagonal=-1) * self.num_ids_per_pos, + dim=-1, + ) + self.compact_idx = torch.cat( + tuple(torch.tensor(v, dtype=torch.long, device=device) for v in self.pos2ids.values()) + ) + + @staticmethod + @cache_results + def build_pos_mapping_from_vocab( + tokenizer: PreTrainedTokenizerBase, + 
log_every: int = 5000, + ) -> dict[str, list[int]]: + """Build mapping from POS tags to list of token ids from tokenizer's vocabulary.""" + if not is_nltk_available(): + raise ImportError("nltk is required to build POS tag mapping. Please install nltk.") + import nltk + + nltk.download("averaged_perceptron_tagger") + pos2ids = defaultdict(list) + for i in range(tokenizer.vocab_size): + word = tokenizer.decode([i]) + _, tag = nltk.pos_tag([word.strip()])[0] + pos2ids[tag].append(i) + if i % log_every == 0: + logger.info(f"Loading vocab from tokenizer - {i / tokenizer.vocab_size * 100:.2f}%") + return pos2ids + + def __call__(self, input_ids: IdsTensor) -> IdsTensor: + input_ids_pos = self.id2pos[input_ids] + sample_uniform = torch.rand(input_ids.shape, device=input_ids.device) + compact_group_idx = (sample_uniform * self.num_ids_per_pos[input_ids_pos] + self.offsets[input_ids_pos]).long() + return self.compact_idx[compact_group_idx] diff --git a/inseq/attr/feat/perturbation_attribution.py b/inseq/attr/feat/perturbation_attribution.py index 138d4bfb..24a4aa68 100644 --- a/inseq/attr/feat/perturbation_attribution.py +++ b/inseq/attr/feat/perturbation_attribution.py @@ -10,7 +10,7 @@ from ...utils import Registry from .attribution_utils import get_source_target_attributions from .gradient_attribution import FeatureAttribution -from .ops import Lime, ReAGent +from .ops import Lime, Reagent logger = logging.getLogger(__name__) @@ -119,16 +119,22 @@ def attribute_step( ) -class ReAGentAttribution(PerturbationAttributionRegistry): - """ReAGent-based attribution method. - The main part of the code is in ops/reagent.py. +class ReagentAttribution(PerturbationAttributionRegistry): + """Recursive attribution generator (ReAGent) method. + + Measures importance as the drop in prediction probability produced by replacing a token with a plausible + alternative predicted by a LM. 
+ + Reference implementation: + `ReAGent: A Model-agnostic Feature Attribution Method for Generative Language Models + `__ """ - method_name = "ReAGent" + method_name = "reagent" def __init__(self, attribution_model, **kwargs): super().__init__(attribution_model) - self.method = ReAGent(attribution_model=self.attribution_model, **kwargs) + self.method = Reagent(attribution_model=self.attribution_model, **kwargs) def attribute_step( self, diff --git a/inseq/utils/__init__.py b/inseq/utils/__init__.py index 29f81615..7dabecbb 100644 --- a/inseq/utils/__init__.py +++ b/inseq/utils/__init__.py @@ -13,6 +13,7 @@ is_datasets_available, is_ipywidgets_available, is_joblib_available, + is_nltk_available, is_scikitlearn_available, is_sentencepiece_available, is_transformers_available, @@ -94,6 +95,7 @@ "is_datasets_available", "is_captum_available", "is_joblib_available", + "is_nltk_available", "check_device", "get_default_device", "ndarray_to_bin_str", diff --git a/inseq/utils/import_utils.py b/inseq/utils/import_utils.py index cbd03420..2a1ccc2d 100644 --- a/inseq/utils/import_utils.py +++ b/inseq/utils/import_utils.py @@ -7,6 +7,7 @@ _datasets_available = find_spec("datasets") is not None _captum_available = find_spec("captum") is not None _joblib_available = find_spec("joblib") is not None +_nltk_available = find_spec("nltk") is not None def is_ipywidgets_available(): @@ -35,3 +36,7 @@ def is_captum_available(): def is_joblib_available(): return _joblib_available + + +def is_nltk_available(): + return _nltk_available From be3b9836706361ffa94eb1cf74958b0860f635f1 Mon Sep 17 00:00:00 2001 From: Xuan25 Date: Thu, 29 Feb 2024 01:39:57 +0000 Subject: [PATCH 07/14] reagent: adapt codebase structure & cleanup unused classes --- inseq/attr/feat/ops/reagent.py | 10 +- .../reagent_core/aggregate_rationalizer.py | 296 ------------------ inseq/attr/feat/ops/reagent_core/base.py | 10 - ..._prob.py => importance_score_evaluator.py} | 164 ++-------- .../importance_score_evaluator/base.py | 58 ---- .../feat/ops/reagent_core/rationalizer.py | 115 +++++++ .../ops/reagent_core/sample_rationalizer.py | 237 -------------- ...p_k.py => stopping_condition_evaluator.py} | 75 +++-- .../stopping_condition_evaluator/base.py | 16 - .../stopping_condition_evaluator/dummy.py | 38 --- .../token_replacement/token_replacer/base.py | 35 --- .../token_replacer/threshold.py | 58 ---- .../token_replacer/uniform.py | 47 --- .../token_replacement/token_sampler/base.py | 16 - .../token_sampler/inferential.py | 43 --- .../token_sampler/inferential_m.py | 89 ------ .../token_replacement/token_sampler/postag.py | 85 ----- .../token_sampler/uniform.py | 53 ---- .../ranking.py => token_replacer.py} | 71 ++++- .../feat/ops/reagent_core/token_sampler.py | 23 +- .../ops/reagent_core/utils/serializing.py | 79 ----- .../feat/ops/reagent_core/utils/traceable.py | 8 - 22 files changed, 289 insertions(+), 1337 deletions(-) delete mode 100644 inseq/attr/feat/ops/reagent_core/aggregate_rationalizer.py delete mode 100644 inseq/attr/feat/ops/reagent_core/base.py rename inseq/attr/feat/ops/reagent_core/{importance_score_evaluator/delta_prob.py => importance_score_evaluator.py} (51%) delete mode 100644 inseq/attr/feat/ops/reagent_core/importance_score_evaluator/base.py create mode 100644 inseq/attr/feat/ops/reagent_core/rationalizer.py delete mode 100644 inseq/attr/feat/ops/reagent_core/sample_rationalizer.py rename inseq/attr/feat/ops/reagent_core/{stopping_condition_evaluator/top_k.py => stopping_condition_evaluator.py} (64%) delete mode 100644 
inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator/base.py delete mode 100644 inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator/dummy.py delete mode 100644 inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/base.py delete mode 100644 inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/threshold.py delete mode 100644 inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/uniform.py delete mode 100644 inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/base.py delete mode 100644 inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/inferential.py delete mode 100644 inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/inferential_m.py delete mode 100644 inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/postag.py delete mode 100644 inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/uniform.py rename inseq/attr/feat/ops/reagent_core/{token_replacement/token_replacer/ranking.py => token_replacer.py} (50%) delete mode 100644 inseq/attr/feat/ops/reagent_core/utils/serializing.py delete mode 100644 inseq/attr/feat/ops/reagent_core/utils/traceable.py diff --git a/inseq/attr/feat/ops/reagent.py b/inseq/attr/feat/ops/reagent.py index 95fd8988..4e189b63 100644 --- a/inseq/attr/feat/ops/reagent.py +++ b/inseq/attr/feat/ops/reagent.py @@ -6,10 +6,10 @@ from torch import Tensor from typing_extensions import override -from .reagent_core.aggregate_rationalizer import AggregateRationalizer -from .reagent_core.importance_score_evaluator.delta_prob import DeltaProbImportanceScoreEvaluator -from .reagent_core.stopping_condition_evaluator.top_k import TopKStoppingConditionEvaluator -from .reagent_core.token_replacement.token_replacer.uniform import UniformTokenReplacer +from .reagent_core.importance_score_evaluator import DeltaProbImportanceScoreEvaluator +from .reagent_core.rationalizer import AggregateRationalizer +from .reagent_core.stopping_condition_evaluator import TopKStoppingConditionEvaluator +from .reagent_core.token_replacer import UniformTokenReplacer from .reagent_core.token_sampler import POSTagTokenSampler if TYPE_CHECKING: @@ -111,7 +111,7 @@ def attribute( # type: ignore tuple[TensorOrTupleOfTensorsGeneric, Tensor], ]: """Implement attribute""" - self.rationalizer.rationalize(additional_forward_args[0], additional_forward_args[1]) + self.rationalizer(additional_forward_args[0], additional_forward_args[1]) mean_important_score = torch.unsqueeze(self.rationalizer.mean_important_score, 0) res = torch.unsqueeze(mean_important_score, 2).repeat(1, 1, inputs[0].shape[2]) return (res,) diff --git a/inseq/attr/feat/ops/reagent_core/aggregate_rationalizer.py b/inseq/attr/feat/ops/reagent_core/aggregate_rationalizer.py deleted file mode 100644 index b19e326a..00000000 --- a/inseq/attr/feat/ops/reagent_core/aggregate_rationalizer.py +++ /dev/null @@ -1,296 +0,0 @@ -import math -from typing import Union - -import torch -from typing_extensions import override - -from .base import BaseRationalizer -from .importance_score_evaluator.base import BaseImportanceScoreEvaluator - - -class AggregateRationalizer(BaseRationalizer): - """AggregateRationalizer""" - - def __init__( - self, - importance_score_evaluator: BaseImportanceScoreEvaluator, - batch_size: int, - overlap_threshold: int, - overlap_strict_pos: bool = True, - top_n: float = 0, - top_n_ratio: float = 0, - ) -> None: - """Constructor - - Args: - importance_score_evaluator: A ImportanceScoreEvaluator - batch_size: Batch size for 
aggregate - overlap_threshold: Overlap threshold of rational tokens within a batch - overlap_strict_pos: Whether overlap strict to position ot not - top_n: Rational size - top_n_ratio: Use ratio of sequence to define rational size - - """ - super().__init__(importance_score_evaluator) - - self.batch_size = batch_size - self.overlap_threshold = overlap_threshold - self.overlap_strict_pos = overlap_strict_pos - self.top_n = top_n - self.top_n_ratio = top_n_ratio - - assert overlap_strict_pos, "overlap_strict_pos = False not been supported yet" - - def get_separate_rational(self, input_ids, tokenizer) -> Union[torch.Tensor, list[list[str]]]: - tokens = [[tokenizer.decode([input_ids[0, i]]) for i in s] for s in self.pos_top_n] - - return self.pos_top_n, tokens - - @torch.no_grad() - def rationalize(self, input_ids: torch.Tensor, target_id: torch.Tensor) -> torch.Tensor: - """Compute rational of a sequence on a target - - Args: - input_ids: The sequence [batch, sequence] (first dimension need to be 1) - target_id: The target [batch] - - Return: - pos_top_n: rational position in the sequence [batch, rational_size] - - """ - assert input_ids.shape[0] == 1, "the first dimension of input (batch_size) need to be 1" - - batch_input_ids = input_ids.repeat(self.batch_size, 1) - - batch_importance_score = self.importance_score_evaluator.evaluate(batch_input_ids, target_id) - - important_score_masked = batch_importance_score * torch.unsqueeze( - self.importance_score_evaluator.stop_mask, -1 - ) - self.mean_important_score = torch.sum(important_score_masked, dim=0) / torch.sum( - self.importance_score_evaluator.stop_mask - ) - - pos_sorted = torch.argsort(batch_importance_score, dim=-1, descending=True) - - top_n = self.top_n - - if top_n == 0: - top_n = int(math.ceil(self.top_n_ratio * input_ids.shape[-1])) - - pos_top_n = pos_sorted[:, :top_n] - self.pos_top_n = pos_top_n - - if self.overlap_strict_pos: - count_overlap = torch.bincount(pos_top_n.flatten(), minlength=input_ids.shape[1]) - pos_top_n_overlap = torch.unsqueeze( - torch.nonzero(count_overlap >= self.overlap_threshold, as_tuple=True)[0], 0 - ) - return pos_top_n_overlap - else: - token_id_top_n = input_ids[0, pos_top_n] - count_overlap = torch.bincount(token_id_top_n.flatten(), minlength=input_ids.shape[1]) - _token_id_top_n_overlap = torch.unsqueeze( - torch.nonzero(count_overlap >= self.overlap_threshold, as_tuple=True)[0], 0 - ) - # TODO: Convert back to pos - raise NotImplementedError("TODO") - - @override - def trace_start(self) -> None: - """Start tracing""" - super().trace_start() - - self.importance_score_evaluator.trace_start() - - @override - def trace_stop(self) -> None: - """Stop tracing""" - super().trace_stop() - - self.importance_score_evaluator.trace_stop() - - -@torch.no_grad() -def main(): - from rationalization.rationalizer.importance_score_evaluator.delta_prob import DeltaProbImportanceScoreEvaluator - from stopping_condition_evaluator.top_k import TopKStoppingConditionEvaluator - from token_replacement.token_replacer.uniform import UniformTokenReplacer - from token_replacement.token_sampler.inferential import InferentialTokenSampler - from token_replacement.token_sampler.postag import POSTagTokenSampler - from token_replacement.token_sampler.uniform import UniformTokenSampler - from transformers import AutoModelWithLMHead, AutoTokenizer - - from utils.serializing import serialize_rational - - # ======== model loading ======== - # Load model from Hugging Face - model = AutoModelWithLMHead.from_pretrained("gpt2-medium") - 
tokenizer = AutoTokenizer.from_pretrained("gpt2-medium") - - model.cuda() - model.eval() - - # ======== prepare data ======== - - # batch with size 1 - input_string = [ - # "I love eating breakfast in the", - "When my flight landed in Thailand. I was staying in the capital city of" - # "When my flight landed in Thailand, I converted my currency and slowly fell asleep. I was staying in the capital city of" - # "When my flight landed in Thailand, I converted my currency and slowly fell asleep. (I had a terrifying dream about my grandmother, but that's a story for another time). I was staying in the capital city of" - ] - - # generate prediction - input_ids = tokenizer(input_string, return_tensors="pt")["input_ids"].to(model.device) - generated_input = model.generate(input_ids=input_ids, max_length=80, do_sample=False) - print(" generated input -->", [[tokenizer.decode(token) for token in seq] for seq in generated_input]) - - # extract target from prediction - target_id = generated_input[:, input_ids.shape[1]] - print(" target -->", [tokenizer.decode(token) for token in target_id]) - - # ======== hyper-parameters ======== - - # replacing ratio during importance score updating - updating_replacing_ratio = 0.3 - # keep top n word based on importance score for both stop condition evaluation and rationalization - rationale_size_ratio = None - rational_size = 5 - # stop when target exist in top k predictions - stop_condition_tolerance = 5 - - # Batch size for aggregate - aggregate_batch_size = 5 - # Overlap threshold of rational tokens within a batch - overlap_threshold = 3 - # Whether overlap strict to position ot not - overlap_strict_pos = True - - # ======== rationalization ======== - - approach_sample_replacing_token = "uniform" - # approach_sample_replacing_token = "inference" - # approach_sample_replacing_token = "postag" - - # prepare rationalizer - if approach_sample_replacing_token == "uniform": - # Approach 1: sample replacing token from uniform distribution - rationalizer = AggregateRationalizer( - importance_score_evaluator=DeltaProbImportanceScoreEvaluator( - model=model, - tokenizer=tokenizer, - token_replacer=UniformTokenReplacer( - token_sampler=UniformTokenSampler(tokenizer), ratio=updating_replacing_ratio - ), - stopping_condition_evaluator=TopKStoppingConditionEvaluator( - model=model, - token_sampler=UniformTokenSampler(tokenizer), - top_k=stop_condition_tolerance, - top_n=rational_size, - top_n_ratio=rationale_size_ratio, - tokenizer=tokenizer, - ), - ), - batch_size=aggregate_batch_size, - overlap_threshold=overlap_threshold, - overlap_strict_pos=overlap_strict_pos, - top_n=rational_size, - top_n_ratio=rationale_size_ratio, - ) - elif approach_sample_replacing_token == "inference": - # Approach 2: sample replacing token from model inference - rationalizer = AggregateRationalizer( - importance_score_evaluator=DeltaProbImportanceScoreEvaluator( - model=model, - tokenizer=tokenizer, - token_replacer=UniformTokenReplacer( - token_sampler=InferentialTokenSampler(tokenizer=tokenizer, model=model), - ratio=updating_replacing_ratio, - ), - stopping_condition_evaluator=TopKStoppingConditionEvaluator( - model=model, - token_sampler=InferentialTokenSampler(tokenizer=tokenizer, model=model), - top_k=stop_condition_tolerance, - top_n=rational_size, - top_n_ratio=rationale_size_ratio, - tokenizer=tokenizer, - ), - ), - batch_size=aggregate_batch_size, - overlap_threshold=overlap_threshold, - overlap_strict_pos=overlap_strict_pos, - top_n=rational_size, - top_n_ratio=rationale_size_ratio, - 
) - elif approach_sample_replacing_token == "postag": - # Approach 3: sample replacing token from uniform distribution on a set of words with the same POS tag - ts = POSTagTokenSampler( - tokenizer=tokenizer, device=input_ids.device - ) # Initialize POSTagTokenSampler takes time so share it - rationalizer = AggregateRationalizer( - importance_score_evaluator=DeltaProbImportanceScoreEvaluator( - model=model, - tokenizer=tokenizer, - token_replacer=UniformTokenReplacer(token_sampler=ts, ratio=updating_replacing_ratio), - stopping_condition_evaluator=TopKStoppingConditionEvaluator( - model=model, - token_sampler=ts, - top_k=stop_condition_tolerance, - top_n=rational_size, - top_n_ratio=rationale_size_ratio, - tokenizer=tokenizer, - ), - ), - batch_size=aggregate_batch_size, - overlap_threshold=overlap_threshold, - overlap_strict_pos=overlap_strict_pos, - top_n=rational_size, - top_n_ratio=rationale_size_ratio, - ) - else: - raise ValueError("Invalid approach_sample_replacing_token") - - rationalizer.trace_start() - - # rationalization - pos_rational = rationalizer.rationalize(input_ids, generated_input[:, input_ids.shape[1]]) - - # convert results - - print() - print("========================") - print() - print(f"Input --> {input_string[0]}") - print(f"Target --> {tokenizer.decode(target_id[0])}") - print(f"Rational positions --> {pos_rational}") - print("Rational words -->") - for i in range(pos_rational.shape[0]): - ids_rational = input_ids[0, pos_rational[i]] - text_rational = [tokenizer.decode([id_rational]) for id_rational in ids_rational] - print(f"{text_rational}") - - # output - - serialize_rational( - "rationalization_results/demo.json", - -1, - input_ids[0], - target_id[0], - pos_rational[0], - tokenizer, - rationalizer.importance_score_evaluator.important_score[0], - compact=False, - comments={ - "message": "This is a demo output. 
[comments] is an optional field", - "model": "gpt2-medium", - "approach_type": approach_sample_replacing_token, - }, - trace_rationalizer=rationalizer, - ) - - rationalizer.trace_stop() - - -if __name__ == "__main__": - main() diff --git a/inseq/attr/feat/ops/reagent_core/base.py b/inseq/attr/feat/ops/reagent_core/base.py deleted file mode 100644 index 8512ae06..00000000 --- a/inseq/attr/feat/ops/reagent_core/base.py +++ /dev/null @@ -1,10 +0,0 @@ -from .importance_score_evaluator.base import BaseImportanceScoreEvaluator -from .utils.traceable import Traceable - - -class BaseRationalizer(Traceable): - def __init__(self, importance_score_evaluator: BaseImportanceScoreEvaluator) -> None: - super().__init__() - - self.importance_score_evaluator = importance_score_evaluator - self.mean_important_score = None diff --git a/inseq/attr/feat/ops/reagent_core/importance_score_evaluator/delta_prob.py b/inseq/attr/feat/ops/reagent_core/importance_score_evaluator.py similarity index 51% rename from inseq/attr/feat/ops/reagent_core/importance_score_evaluator/delta_prob.py rename to inseq/attr/feat/ops/reagent_core/importance_score_evaluator.py index 7923d3b2..8a139bc6 100644 --- a/inseq/attr/feat/ops/reagent_core/importance_score_evaluator/delta_prob.py +++ b/inseq/attr/feat/ops/reagent_core/importance_score_evaluator.py @@ -1,96 +1,33 @@ import logging +from abc import ABC, abstractmethod import torch from transformers import AutoModelWithLMHead, AutoTokenizer +from typing_extensions import override -from ..stopping_condition_evaluator.base import StoppingConditionEvaluator -from ..token_replacement.token_replacer.base import TokenReplacer -from .base import BaseImportanceScoreEvaluator +from .stopping_condition_evaluator import StoppingConditionEvaluator +from .token_replacer import TokenReplacer -class DeltaProbImportanceScoreEvaluator(BaseImportanceScoreEvaluator): +class BaseImportanceScoreEvaluator(ABC): """Importance Score Evaluator""" - def __init__( - self, - model: AutoModelWithLMHead, - tokenizer: AutoTokenizer, - token_replacer: TokenReplacer, - stopping_condition_evaluator: StoppingConditionEvaluator, - max_steps: float, - ) -> None: - """Constructor + def __init__(self, model: AutoModelWithLMHead, tokenizer: AutoTokenizer) -> None: + """Base Constructor Args: model: A Huggingface AutoModelWithLMHead model tokenizer: A Huggingface AutoTokenizer - token_replacer: A TokenReplacer - stopping_condition_evaluator: A StoppingConditionEvaluator """ - super().__init__(model, tokenizer) - - self.token_replacer = token_replacer - self.stopping_condition_evaluator = stopping_condition_evaluator - self.max_steps = max_steps + self.model = model + self.tokenizer = tokenizer self.important_score = None - self.trace_importance_score = None - self.trace_target_likelihood_original = None - self.num_steps = 0 - - def update_importance_score( - self, - logit_importance_score: torch.Tensor, - input_ids: torch.Tensor, - target_id: torch.Tensor, - prob_original_target: torch.Tensor, - ) -> torch.Tensor: - """Update importance score by one step - - Args: - logit_importance_score: Current importance score in logistic scale [batch] - input_ids: input tensor [batch, sequence] - target_id: target tensor [batch] - prob_original_target: predictive probability of the target on the original sequence [batch] - - Return: - logit_importance_score: updated importance score in logistic scale [batch] - - """ - # Randomly replace a set of tokens R to form a new sequence \hat{y_{1...t}} - - input_ids_replaced, mask_replacing = 
self.token_replacer.sample(input_ids) - - logging.debug(f"Replacing mask: { mask_replacing }") - logging.debug( - f"Replaced sequence: { [[ self.tokenizer.decode(seq[i]) for i in range(input_ids_replaced.shape[1]) ] for seq in input_ids_replaced ] }" - ) - - # Inference \hat{p^{(y)}} = p(y_{t+1}|\hat{y_{1...t}}) - - logits_replaced = self.model(input_ids_replaced)["logits"] - prob_replaced_target = torch.softmax(logits_replaced[:, input_ids_replaced.shape[1] - 1, :], -1)[:, target_id] - self.trace_prob_original_target = prob_replaced_target - - # Compute changes delta = p^{(y)} - \hat{p^{(y)}} - - delta_prob_target = prob_original_target - prob_replaced_target - logging.debug(f"likelihood delta: { delta_prob_target }") - - # Update importance scores based on delta (magnitude) and replacement (direction) - - delta_score = mask_replacing * delta_prob_target + ~mask_replacing * -delta_prob_target - # TODO: better solution? - # Rescaling from [-1, 1] to [0, 1] before logit function - logit_delta_score = torch.logit(delta_score * 0.5 + 0.5) - logit_importance_score = logit_importance_score + logit_delta_score - logging.debug(f"Updated importance score: { torch.softmax(logit_importance_score, -1) }") - - return logit_importance_score - def evaluate(self, input_ids: torch.Tensor, target_id: torch.Tensor) -> torch.Tensor: + @abstractmethod + def __call__(self, input_ids: torch.Tensor, target_id: torch.Tensor) -> torch.Tensor: """Evaluate importance score of input sequence Args: @@ -101,61 +38,20 @@ def evaluate(self, input_ids: torch.Tensor, target_id: torch.Tensor) -> torch.Te importance_score: evaluated importance score for each token in the input [batch, sequence] """ - - self.stop_mask = torch.zeros([input_ids.shape[0]], dtype=torch.bool, device=input_ids.device) - - # Inference p^{(y)} = p(y_{t+1}|y_{1...t}) - - logits_original = self.model(input_ids)["logits"] - prob_original_target = torch.softmax(logits_original[:, input_ids.shape[1] - 1, :], -1)[:, target_id] - - if self.trace_target_likelihood_original is not None: - self.trace_target_likelihood_original = prob_original_target - - # Initialize importance score s for each token in the sequence y_{1...t} - - logit_importance_score = torch.rand(input_ids.shape, device=input_ids.device) - logging.debug(f"Initialize importance score -> { torch.softmax(logit_importance_score, -1) }") - - # TODO: limit max steps - self.num_steps = 0 - while self.num_steps < self.max_steps: - self.num_steps += 1 - - # Update importance score - logit_importance_score_update = self.update_importance_score( - logit_importance_score, input_ids, target_id, prob_original_target - ) - logit_importance_score = ( - ~torch.unsqueeze(self.stop_mask, 1) * logit_importance_score_update - + torch.unsqueeze(self.stop_mask, 1) * logit_importance_score - ) - - self.important_score = torch.softmax(logit_importance_score, -1) - if self.trace_importance_score is not None: - self.trace_importance_score.append(self.important_score) - - # Evaluate stop condition - self.stop_mask = self.stop_mask | self.stopping_condition_evaluator.evaluate( - input_ids, target_id, self.important_score - ) - if torch.prod(self.stop_mask) > 0: - break - - logging.info(f"Importance score evaluated in {self.num_steps} steps.") - - return torch.softmax(logit_importance_score, -1) + raise NotImplementedError() -class DeltaProbImportanceScoreEvaluator_imp(BaseImportanceScoreEvaluator): +class DeltaProbImportanceScoreEvaluator(BaseImportanceScoreEvaluator): """Importance Score Evaluator""" + @override def 
__init__( self, model: AutoModelWithLMHead, tokenizer: AutoTokenizer, token_replacer: TokenReplacer, stopping_condition_evaluator: StoppingConditionEvaluator, + max_steps: float, ) -> None: """Constructor @@ -167,14 +63,13 @@ def __init__( """ - self.model = model - self.tokenizer = tokenizer + super().__init__(model, tokenizer) + self.token_replacer = token_replacer self.stopping_condition_evaluator = stopping_condition_evaluator - self.important_score = None + self.max_steps = max_steps - self.trace_importance_score = None - self.trace_target_likelihood_original = None + self.important_score = None self.num_steps = 0 def update_importance_score( @@ -198,7 +93,7 @@ def update_importance_score( """ # Randomly replace a set of tokens R to form a new sequence \hat{y_{1...t}} - input_ids_replaced, mask_replacing = self.token_replacer.sample(input_ids) + input_ids_replaced, mask_replacing = self.token_replacer(input_ids) logging.debug(f"Replacing mask: { mask_replacing }") logging.debug( @@ -219,12 +114,16 @@ def update_importance_score( # Update importance scores based on delta (magnitude) and replacement (direction) delta_score = mask_replacing * delta_prob_target + ~mask_replacing * -delta_prob_target - logit_importance_score = logit_importance_score + delta_score + # TODO: better solution? + # Rescaling from [-1, 1] to [0, 1] before logit function + logit_delta_score = torch.logit(delta_score * 0.5 + 0.5) + logit_importance_score = logit_importance_score + logit_delta_score logging.debug(f"Updated importance score: { torch.softmax(logit_importance_score, -1) }") return logit_importance_score - def evaluate(self, input_ids: torch.Tensor, target_id: torch.Tensor) -> torch.Tensor: + @override + def __call__(self, input_ids: torch.Tensor, target_id: torch.Tensor) -> torch.Tensor: """Evaluate importance score of input sequence Args: @@ -243,17 +142,14 @@ def evaluate(self, input_ids: torch.Tensor, target_id: torch.Tensor) -> torch.Te logits_original = self.model(input_ids)["logits"] prob_original_target = torch.softmax(logits_original[:, input_ids.shape[1] - 1, :], -1)[:, target_id] - if self.trace_target_likelihood_original is not None: - self.trace_target_likelihood_original = prob_original_target - # Initialize importance score s for each token in the sequence y_{1...t} - logit_importance_score = torch.zeros(input_ids.shape, device=input_ids.device) + logit_importance_score = torch.rand(input_ids.shape, device=input_ids.device) logging.debug(f"Initialize importance score -> { torch.softmax(logit_importance_score, -1) }") # TODO: limit max steps self.num_steps = 0 - while True: + while self.num_steps < self.max_steps: self.num_steps += 1 # Update importance score @@ -266,11 +162,9 @@ def evaluate(self, input_ids: torch.Tensor, target_id: torch.Tensor) -> torch.Te ) self.important_score = torch.softmax(logit_importance_score, -1) - if self.trace_importance_score is not None: - self.trace_importance_score.append(self.important_score) # Evaluate stop condition - self.stop_mask = self.stop_mask | self.stopping_condition_evaluator.evaluate( + self.stop_mask = self.stop_mask | self.stopping_condition_evaluator( input_ids, target_id, self.important_score ) if torch.prod(self.stop_mask) > 0: diff --git a/inseq/attr/feat/ops/reagent_core/importance_score_evaluator/base.py b/inseq/attr/feat/ops/reagent_core/importance_score_evaluator/base.py deleted file mode 100644 index e627cac7..00000000 --- a/inseq/attr/feat/ops/reagent_core/importance_score_evaluator/base.py +++ /dev/null @@ -1,58 +0,0 @@ -import 
torch -from transformers import AutoModelWithLMHead, AutoTokenizer -from typing_extensions import override - -from ..utils.traceable import Traceable - - -class BaseImportanceScoreEvaluator(Traceable): - """Importance Score Evaluator""" - - def __init__(self, model: AutoModelWithLMHead, tokenizer: AutoTokenizer) -> None: - """Base Constructor - - Args: - model: A Huggingface AutoModelWithLMHead model - tokenizer: A Huggingface AutoTokenizer - - """ - - self.model = model - self.tokenizer = tokenizer - - self.important_score = None - - self.trace_importance_score = None - self.trace_target_likelihood_original = None - - def evaluate(self, input_ids: torch.Tensor, target_id: torch.Tensor) -> torch.Tensor: - """Evaluate importance score of input sequence - - Args: - input_ids: input sequence [batch, sequence] - target_id: target token [batch] - - Return: - importance_score: evaluated importance score for each token in the input [batch, sequence] - - """ - - raise NotImplementedError() - - @override - def trace_start(self): - """Start tracing""" - super().trace_start() - - self.trace_importance_score = [] - self.trace_target_likelihood_original = -1 - self.stopping_condition_evaluator.trace_start() - - @override - def trace_stop(self): - """Stop tracing""" - super().trace_stop() - - self.trace_importance_score = None - self.trace_target_likelihood_original = None - self.stopping_condition_evaluator.trace_stop() diff --git a/inseq/attr/feat/ops/reagent_core/rationalizer.py b/inseq/attr/feat/ops/reagent_core/rationalizer.py new file mode 100644 index 00000000..1b4511a1 --- /dev/null +++ b/inseq/attr/feat/ops/reagent_core/rationalizer.py @@ -0,0 +1,115 @@ +import math +from abc import ABC, abstractmethod + +import torch +from typing_extensions import override + +from .importance_score_evaluator import BaseImportanceScoreEvaluator + + +class BaseRationalizer(ABC): + def __init__(self, importance_score_evaluator: BaseImportanceScoreEvaluator) -> None: + super().__init__() + + self.importance_score_evaluator = importance_score_evaluator + self.mean_important_score = None + + @abstractmethod + def __call__(self, input_ids: torch.Tensor, target_id: torch.Tensor) -> torch.Tensor: + """Compute rational of a sequence on a target + + Args: + input_ids: The sequence [batch, sequence] (first dimension need to be 1) + target_id: The target [batch] + + Return: + pos_top_n: rational position in the sequence [batch, rational_size] + + """ + raise NotImplementedError() + + +class AggregateRationalizer(BaseRationalizer): + """AggregateRationalizer""" + + @override + def __init__( + self, + importance_score_evaluator: BaseImportanceScoreEvaluator, + batch_size: int, + overlap_threshold: int, + overlap_strict_pos: bool = True, + top_n: float = 0, + top_n_ratio: float = 0, + ) -> None: + """Constructor + + Args: + importance_score_evaluator: A ImportanceScoreEvaluator + batch_size: Batch size for aggregate + overlap_threshold: Overlap threshold of rational tokens within a batch + overlap_strict_pos: Whether overlap strict to position ot not + top_n: Rational size + top_n_ratio: Use ratio of sequence to define rational size + + """ + super().__init__(importance_score_evaluator) + + self.batch_size = batch_size + self.overlap_threshold = overlap_threshold + self.overlap_strict_pos = overlap_strict_pos + self.top_n = top_n + self.top_n_ratio = top_n_ratio + + assert overlap_strict_pos, "overlap_strict_pos = False not been supported yet" + + @override + @torch.no_grad() + def __call__(self, input_ids: torch.Tensor, 
target_id: torch.Tensor) -> torch.Tensor: + """Compute rational of a sequence on a target + + Args: + input_ids: The sequence [batch, sequence] (first dimension need to be 1) + target_id: The target [batch] + + Return: + pos_top_n: rational position in the sequence [batch, rational_size] + + """ + assert input_ids.shape[0] == 1, "the first dimension of input (batch_size) need to be 1" + + batch_input_ids = input_ids.repeat(self.batch_size, 1) + + batch_importance_score = self.importance_score_evaluator(batch_input_ids, target_id) + + important_score_masked = batch_importance_score * torch.unsqueeze( + self.importance_score_evaluator.stop_mask, -1 + ) + self.mean_important_score = torch.sum(important_score_masked, dim=0) / torch.sum( + self.importance_score_evaluator.stop_mask + ) + + pos_sorted = torch.argsort(batch_importance_score, dim=-1, descending=True) + + top_n = self.top_n + + if top_n == 0: + top_n = int(math.ceil(self.top_n_ratio * input_ids.shape[-1])) + + pos_top_n = pos_sorted[:, :top_n] + self.pos_top_n = pos_top_n + + if self.overlap_strict_pos: + count_overlap = torch.bincount(pos_top_n.flatten(), minlength=input_ids.shape[1]) + pos_top_n_overlap = torch.unsqueeze( + torch.nonzero(count_overlap >= self.overlap_threshold, as_tuple=True)[0], 0 + ) + return pos_top_n_overlap + else: + token_id_top_n = input_ids[0, pos_top_n] + count_overlap = torch.bincount(token_id_top_n.flatten(), minlength=input_ids.shape[1]) + _token_id_top_n_overlap = torch.unsqueeze( + torch.nonzero(count_overlap >= self.overlap_threshold, as_tuple=True)[0], 0 + ) + # TODO: Convert back to pos + raise NotImplementedError("TODO") diff --git a/inseq/attr/feat/ops/reagent_core/sample_rationalizer.py b/inseq/attr/feat/ops/reagent_core/sample_rationalizer.py deleted file mode 100644 index 29878217..00000000 --- a/inseq/attr/feat/ops/reagent_core/sample_rationalizer.py +++ /dev/null @@ -1,237 +0,0 @@ -import math - -import torch -from typing_extensions import override - -from .base import BaseRationalizer -from .importance_score_evaluator.base import BaseImportanceScoreEvaluator - - -class SampleRationalizer(BaseRationalizer): - """SampleRationalizer""" - - def __init__( - self, importance_score_evaluator: BaseImportanceScoreEvaluator, top_n: float = 0, top_n_ratio: float = 0 - ) -> None: - """Constructor - - Args: - importance_score_evaluator: A ImportanceScoreEvaluator - top_n: Rational size - top_n_ratio: Use ratio of sequence to define rational size - - """ - super().__init__(importance_score_evaluator) - - self.top_n = top_n - self.top_n_ratio = top_n_ratio - - @torch.no_grad() - def rationalize(self, input_ids: torch.Tensor, target_id: torch.Tensor) -> torch.Tensor: - """Compute rational of a sequence on a target - - Args: - input_ids: The sequence [batch, sequence] - target_id: The target [batch] - - Return: - pos_top_n: rational position in the sequence [batch, rational_size] - - """ - batch_importance_score = self.importance_score_evaluator.evaluate(input_ids, target_id) - - self.mean_important_score = torch.mean(batch_importance_score, dim=0) - - pos_sorted = torch.argsort(batch_importance_score, dim=-1, descending=True) - - top_n = self.top_n - - if top_n == 0: - top_n = int(math.ceil(self.top_n_ratio * input_ids.shape[-1])) - - pos_top_n = pos_sorted[:, :top_n] - - return pos_top_n - - @override - def trace_start(self) -> None: - """Start tracing""" - super().trace_start() - - self.importance_score_evaluator.trace_start() - - @override - def trace_stop(self) -> None: - """Stop tracing""" - 
super().trace_stop() - - self.importance_score_evaluator.trace_stop() - - -@torch.no_grad() -def main(): - from rationalization.rationalizer.importance_score_evaluator.delta_prob import DeltaProbImportanceScoreEvaluator - from stopping_condition_evaluator.top_k import TopKStoppingConditionEvaluator - from token_replacement.token_replacer.uniform import UniformTokenReplacer - from token_replacement.token_sampler.inferential import InferentialTokenSampler - from token_replacement.token_sampler.postag import POSTagTokenSampler - from token_replacement.token_sampler.uniform import UniformTokenSampler - from transformers import AutoModelWithLMHead, AutoTokenizer - - from utils.serializing import serialize_rational - - # ======== model loading ======== - # Load model from Hugging Face - model = AutoModelWithLMHead.from_pretrained("gpt2-medium") - tokenizer = AutoTokenizer.from_pretrained("gpt2-medium") - - model.cuda() - model.eval() - - # ======== prepare data ======== - - # batch with size 1 - input_string = [ - # "I love eating breakfast in the", - "When my flight landed in Thailand. I was staying in the capital city of" - # "When my flight landed in Thailand, I converted my currency and slowly fell asleep. I was staying in the capital city of" - # "When my flight landed in Thailand, I converted my currency and slowly fell asleep. (I had a terrifying dream about my grandmother, but that's a story for another time). I was staying in the capital city of" - ] - - # generate prediction - input_ids = tokenizer(input_string, return_tensors="pt")["input_ids"].to(model.device) - generated_input = model.generate(input_ids=input_ids, max_length=80, do_sample=False) - print(" generated input -->", [[tokenizer.decode(token) for token in seq] for seq in generated_input]) - - # extract target from prediction - target_id = generated_input[:, input_ids.shape[1]] - print(" target -->", [tokenizer.decode(token) for token in target_id]) - - # ======== hyper-parameters ======== - - # replacing ratio during importance score updating - updating_replacing_ratio = 0.3 - # keep top n word based on importance score for both stop condition evaluation and rationalization - rationale_size_ratio = None - rational_size = 5 - # stop when target exist in top k predictions - stop_condition_tolerance = 5 - - # ======== rationalization ======== - - approach_sample_replacing_token = "uniform" - # approach_sample_replacing_token = "inference" - # approach_sample_replacing_token = "postag" - - # prepare rationalizer - if approach_sample_replacing_token == "uniform": - # Approach 1: sample replacing token from uniform distribution - rationalizer = SampleRationalizer( - importance_score_evaluator=DeltaProbImportanceScoreEvaluator( - model=model, - tokenizer=tokenizer, - token_replacer=UniformTokenReplacer( - token_sampler=UniformTokenSampler(tokenizer), ratio=updating_replacing_ratio - ), - stopping_condition_evaluator=TopKStoppingConditionEvaluator( - model=model, - token_sampler=UniformTokenSampler(tokenizer), - top_k=stop_condition_tolerance, - top_n=rational_size, - top_n_ratio=rationale_size_ratio, - tokenizer=tokenizer, - ), - ), - top_n=rational_size, - top_n_ratio=rationale_size_ratio, - ) - elif approach_sample_replacing_token == "inference": - # Approach 2: sample replacing token from model inference - rationalizer = SampleRationalizer( - importance_score_evaluator=DeltaProbImportanceScoreEvaluator( - model=model, - tokenizer=tokenizer, - token_replacer=UniformTokenReplacer( - 
token_sampler=InferentialTokenSampler(tokenizer=tokenizer, model=model), - ratio=updating_replacing_ratio, - ), - stopping_condition_evaluator=TopKStoppingConditionEvaluator( - model=model, - token_sampler=InferentialTokenSampler(tokenizer=tokenizer, model=model), - top_k=stop_condition_tolerance, - top_n=rational_size, - top_n_ratio=rationale_size_ratio, - tokenizer=tokenizer, - ), - ), - top_n=rational_size, - top_n_ratio=rationale_size_ratio, - ) - elif approach_sample_replacing_token == "postag": - # Approach 3: sample replacing token from uniform distribution on a set of words with the same POS tag - ts = POSTagTokenSampler( - tokenizer=tokenizer, device=input_ids.device - ) # Initialize POSTagTokenSampler takes time so share it - rationalizer = SampleRationalizer( - importance_score_evaluator=DeltaProbImportanceScoreEvaluator( - model=model, - tokenizer=tokenizer, - token_replacer=UniformTokenReplacer(token_sampler=ts, ratio=updating_replacing_ratio), - stopping_condition_evaluator=TopKStoppingConditionEvaluator( - model=model, - token_sampler=ts, - top_k=stop_condition_tolerance, - top_n=rational_size, - top_n_ratio=rationale_size_ratio, - tokenizer=tokenizer, - ), - ), - top_n=rational_size, - top_n_ratio=rationale_size_ratio, - ) - else: - raise ValueError("Invalid approach_sample_replacing_token") - - rationalizer.trace_start() - - # rationalization - pos_rational = rationalizer.rationalize(input_ids, generated_input[:, input_ids.shape[1]]) - - # convert results - - print() - print("========================") - print() - print(f"Input --> {input_string[0]}") - print(f"Target --> {tokenizer.decode(target_id[0])}") - print(f"Rational positions --> {pos_rational}") - print("Rational words -->") - for i in range(pos_rational.shape[0]): - ids_rational = input_ids[0, pos_rational[i]] - text_rational = [tokenizer.decode([id_rational]) for id_rational in ids_rational] - print(f"{text_rational}") - - # output - - serialize_rational( - "rationalization_results/demo.json", - -1, - input_ids[0], - target_id[0], - pos_rational[0], - tokenizer, - rationalizer.importance_score_evaluator.important_score[0], - compact=False, - comments={ - "message": "This is a demo output. 
[comments] is an optional field", - "model": "gpt2-medium", - "approach_type": approach_sample_replacing_token, - }, - trace_rationalizer=rationalizer, - ) - - rationalizer.trace_stop() - - -if __name__ == "__main__": - main() diff --git a/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator/top_k.py b/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator.py similarity index 64% rename from inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator/top_k.py rename to inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator.py index f3461d33..43e0b564 100644 --- a/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator/top_k.py +++ b/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator.py @@ -1,12 +1,33 @@ import logging +from abc import ABC, abstractmethod import torch from transformers import AutoModelWithLMHead, AutoTokenizer from typing_extensions import override -from ..token_replacement.token_replacer.ranking import RankingTokenReplacer -from ..token_replacement.token_sampler.base import TokenSampler -from .base import StoppingConditionEvaluator +from .token_replacer import RankingTokenReplacer +from .token_sampler import TokenSampler + + +class StoppingConditionEvaluator(ABC): + """Base class for Stopping Condition Evaluators""" + + @abstractmethod + def __call__( + self, input_ids: torch.Tensor, target_id: torch.Tensor, importance_score: torch.Tensor + ) -> torch.Tensor: + """Evaluate stop condition according to the specified strategy. + + Args: + input_ids: Input sequence [batch, sequence] + target_id: Target token [batch] + importance_score: Importance score of the input [batch, sequence] + + Return: + Whether the stop condition achieved [batch] + + """ + raise NotImplementedError() class TopKStoppingConditionEvaluator(StoppingConditionEvaluator): @@ -37,8 +58,6 @@ def __init__( tokenizer: (Optional) Used for logging top_k_words at each step """ - super().__init__() - self.model = model self.token_sampler = token_sampler self.top_k = top_k @@ -46,7 +65,7 @@ def __init__( self.tokenizer = tokenizer @override - def evaluate( + def __call__( self, input_ids: torch.Tensor, target_id: torch.Tensor, importance_score: torch.Tensor ) -> torch.Tensor: """Evaluate stop condition @@ -60,12 +79,10 @@ def evaluate( Whether the stop condition achieved [batch] """ - super().evaluate(input_ids, target_id, importance_score) - # Replace tokens with low importance score and then inference \hat{y^{(e)}_{t+1}} self.token_replacer.set_score(importance_score) - input_ids_replaced, mask_replacing = self.token_replacer.sample(input_ids) + input_ids_replaced, mask_replacing = self.token_replacer(input_ids) logging.debug(f"Replacing mask based on importance score -> { mask_replacing }") @@ -75,9 +92,6 @@ def evaluate( with torch.no_grad(): logits_replaced = self.model(input_ids_replaced)["logits"] - if self.trace_target_likelihood is not None: - self.trace_target_likelihood.append(torch.softmax(logits_replaced, dim=-1)[:, -1, target_id]) - ids_prediction_sorted = torch.argsort(logits_replaced[:, -1, :], descending=True) ids_prediction_top_k = ids_prediction_sorted[:, : self.top_k] @@ -91,18 +105,33 @@ def evaluate( # Stop flags for each sample in the batch return match_hit - @override - def trace_start(self) -> None: - """Start tracing""" - super().trace_start() - self.token_sampler.trace_start() - self.token_replacer.trace_start() +class DummyStoppingConditionEvaluator(StoppingConditionEvaluator): + """ + Stopping Condition Evaluator which stop when target exist in 
top k predictions, + while top n tokens based on importance_score are not been replaced. + """ + + @override + def __init__(self) -> None: + """Constructor""" @override - def trace_stop(self) -> None: - """Stop tracing""" - super().trace_stop() + def __call__( + self, input_ids: torch.Tensor, target_id: torch.Tensor, importance_score: torch.Tensor + ) -> torch.Tensor: + """Evaluate stop condition + + Args: + input_ids: Input sequence [batch, sequence] + target_id: Target token [batch] + importance_score: Importance score of the input [batch, sequence] + + Return: + Whether the stop condition achieved [batch] + + """ + match_hit = torch.ones([input_ids.shape[0]], dtype=torch.bool, device=input_ids.device) - self.token_sampler.trace_stop() - self.token_replacer.trace_stop() + # Stop flags for each sample in the batch + return match_hit diff --git a/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator/base.py b/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator/base.py deleted file mode 100644 index 82a08909..00000000 --- a/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator/base.py +++ /dev/null @@ -1,16 +0,0 @@ -import torch - -from ..utils.traceable import Traceable - - -class StoppingConditionEvaluator(Traceable): - """Base class for Stopping Condition Evaluators""" - - def __init__(self): - """Base Constructor""" - self.trace_target_likelihood = [] - - def evaluate( - self, input_ids: torch.Tensor, target_id: torch.Tensor, importance_score: torch.Tensor - ) -> torch.Tensor: - """Base evaluate""" diff --git a/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator/dummy.py b/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator/dummy.py deleted file mode 100644 index f98efe52..00000000 --- a/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator/dummy.py +++ /dev/null @@ -1,38 +0,0 @@ -import torch -from typing_extensions import override - -from .base import StoppingConditionEvaluator - - -class DummyStoppingConditionEvaluator(StoppingConditionEvaluator): - """ - Stopping Condition Evaluator which stop when target exist in top k predictions, - while top n tokens based on importance_score are not been replaced. 
- """ - - @override - def __init__(self) -> None: - """Constructor""" - super().__init__() - - @override - def evaluate( - self, input_ids: torch.Tensor, target_id: torch.Tensor, importance_score: torch.Tensor - ) -> torch.Tensor: - """Evaluate stop condition - - Args: - input_ids: Input sequence [batch, sequence] - target_id: Target token [batch] - importance_score: Importance score of the input [batch, sequence] - - Return: - Whether the stop condition achieved [batch] - - """ - super().evaluate(input_ids, target_id, importance_score) - - match_hit = torch.ones([input_ids.shape[0]], dtype=torch.bool, device=input_ids.device) - - # Stop flags for each sample in the batch - return match_hit diff --git a/inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/base.py b/inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/base.py deleted file mode 100644 index 04260f5c..00000000 --- a/inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/base.py +++ /dev/null @@ -1,35 +0,0 @@ -from typing import Union - -import torch -from typing_extensions import override - -from ...utils.traceable import Traceable -from ..token_sampler.base import TokenSampler - - -class TokenReplacer(Traceable): - """ - Base class for token replacers - - """ - - def __init__(self, token_sampler: TokenSampler) -> None: - """Base Constructor""" - self.token_sampler = token_sampler - - def sample(self, input: torch.Tensor) -> Union[torch.Tensor, torch.Tensor]: - """Base sample""" - - @override - def trace_start(self): - """Start tracing""" - super().trace_start() - - self.token_sampler.trace_start() - - @override - def trace_stop(self): - """Stop tracing""" - super().trace_stop() - - self.token_sampler.trace_stop() diff --git a/inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/threshold.py b/inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/threshold.py deleted file mode 100644 index 90084fae..00000000 --- a/inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/threshold.py +++ /dev/null @@ -1,58 +0,0 @@ -from typing import Union - -import torch -from typing_extensions import override - -from ..token_sampler.base import TokenSampler -from .base import TokenReplacer - - -class ThresholdTokenReplacer(TokenReplacer): - """Replace tokens in a sequence based on a threshold""" - - @override - def __init__(self, token_sampler: TokenSampler, threshold: float, replace_greater: bool = False) -> None: - """Constructor - - Args: - token_sampler: A TokenSampler for sampling replace token. - threshold: replacing threshold - replace_greater: Whether replace top-n. Otherwise, replace the rests. 
- - """ - super().__init__(token_sampler) - - self.threshold = threshold - self.replace_greater = replace_greater - - def set_value(self, value: torch.Tensor) -> None: - """Set the value for threshold control - - Args: - value: value [batch, sequence] - - """ - if not self.replace_greater: - self.mask_replacing = value < self.threshold - else: - self.mask_replacing = value > self.threshold - - @override - def sample(self, input: torch.Tensor) -> Union[torch.Tensor, torch.Tensor]: - """Sample a sequence - - Args: - input: input sequence [batch, sequence] - - Returns: - input_replaced: A replaced sequence [batch, sequence] - mask_replacing: Identify which token has been replaced [batch, sequence] - - """ - super().sample(input) - - token_sampled = self.token_sampler.sample(input) - - input_replaced = input * ~self.mask_replacing + token_sampled * self.mask_replacing - - return input_replaced, self.mask_replacing diff --git a/inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/uniform.py b/inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/uniform.py deleted file mode 100644 index 4c663bf1..00000000 --- a/inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/uniform.py +++ /dev/null @@ -1,47 +0,0 @@ -from typing import Union - -import torch -from typing_extensions import override - -from ..token_sampler.base import TokenSampler -from .base import TokenReplacer - - -class UniformTokenReplacer(TokenReplacer): - """Replace tokens in a sequence where selecting is base on uniform distribution""" - - @override - def __init__(self, token_sampler: TokenSampler, ratio: float) -> None: - """Constructor - - Args: - token_sampler: A TokenSampler for sampling replace token. - ratio: replacing ratio - - """ - super().__init__(token_sampler) - - self.ratio = ratio - - @override - def sample(self, input: torch.Tensor) -> Union[torch.Tensor, torch.Tensor]: - """Sample a sequence - - Args: - input: input sequence [batch, sequence] - - Returns: - input_replaced: A replaced sequence [batch, sequence] - mask_replacing: Identify which token has been replaced [batch, sequence] - - """ - super().sample(input) - - sample_uniform = torch.rand(input.shape, device=input.device) - mask_replacing = sample_uniform < self.ratio - - token_sampled = self.token_sampler(input) - - input_replaced = input * ~mask_replacing + token_sampled * mask_replacing - - return input_replaced, mask_replacing diff --git a/inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/base.py b/inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/base.py deleted file mode 100644 index 05d31d71..00000000 --- a/inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/base.py +++ /dev/null @@ -1,16 +0,0 @@ -import torch -from typing_extensions import override - -from ...utils.traceable import Traceable - - -class TokenSampler(Traceable): - """Base class for token samplers""" - - @override - def __init__(self) -> None: - """Base Constructor""" - super().__init__() - - def sample(self, input: torch.Tensor) -> torch.Tensor: - """Base sample""" diff --git a/inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/inferential.py b/inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/inferential.py deleted file mode 100644 index 0322517a..00000000 --- a/inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/inferential.py +++ /dev/null @@ -1,43 +0,0 @@ -import torch -from transformers import AutoModelWithLMHead, AutoTokenizer -from typing_extensions import 
override - -from .base import TokenSampler - - -class InferentialTokenSampler(TokenSampler): - """Sample tokens from a seq-2-seq model""" - - @override - def __init__(self, tokenizer: AutoTokenizer, model: AutoModelWithLMHead) -> None: - """Constructor - - Args: - tokenizer: A Huggingface AutoTokenizer. - model: A Huggingface AutoModelWithLMHead for inference the output. - - """ - super().__init__() - - self.tokenizer = tokenizer - self.model = model - - @override - def sample(self, input: torch.Tensor) -> torch.Tensor: - """Sample a tensor - - Args: - input: input tensor [batch, sequence] - - Returns: - token_inferences: sampled (placement) tokens by inference - - """ - super().sample(input) - - logits_replacing = self.model(input)["logits"] - ids_infer = torch.argmax(logits_replacing, dim=-1) - - token_inferences = torch.cat([input[:, 0:1], ids_infer[:, :-1]], dim=1) - - return token_inferences diff --git a/inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/inferential_m.py b/inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/inferential_m.py deleted file mode 100644 index 954b43ae..00000000 --- a/inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/inferential_m.py +++ /dev/null @@ -1,89 +0,0 @@ -import torch -from transformers import AutoModelWithLMHead, AutoTokenizer -from typing_extensions import override - - -class InferentialMTokenSampler(TokenSampler): - """Sample tokens from a seq-2-seq model""" - - @override - def __init__( - self, source_tokenizer: AutoTokenizer, sampler_tokenizer: AutoTokenizer, sampler_model: AutoModelWithLMHead - ) -> None: - """Constructor - - Args: - source_tokenizer: A Huggingface AutoTokenizer for decoding the inputs. - sampler_tokenizer: A Huggingface AutoTokenizer for inference the output. - sampler_model: A Huggingface AutoModelWithLMHead for inference the output. 
- - """ - super().__init__() - - self.source_tokenizer = source_tokenizer - self.sampler_tokenizer = sampler_tokenizer - self.sampler_model = sampler_model - - @override - def sample(self, inputs: torch.Tensor) -> torch.Tensor: - """Sample a tensor - - Args: - inputs: input tensor [batch, sequence] - - Returns: - token_inferences: sampled (placement) tokens by inference - - """ - batch_li = [] - for seq_i in torch.arange(inputs.shape[0]): - seq_li = [] - for pos_i in torch.arange(inputs.shape[1]): - # first token - if pos_i == 0: - seq_li.append(inputs[seq_i, 0]) - continue - - # following tokens - - probe_prefix = torch.tensor( - [self.sampler_tokenizer.encode(self.source_tokenizer.decode(inputs[seq_i, :pos_i]))], - device=inputs.device, - ) - probe_prefix = probe_prefix[:, :-1] # trim - output_replacing_m = self.sampler_model(probe_prefix) - logits_replacing_m = output_replacing_m["logits"] - logits_replacing_m_last = logits_replacing_m[:, -1] - id_infer_m = torch.argmax(logits_replacing_m_last, dim=-1) - - seq_li.append(id_infer_m.item()) - - batch_li.append(seq_li) - - res = torch.tensor(batch_li, device=inputs.device) - - return res - - -if __name__ == "__main__": - from transformers import AutoModelForCausalLM, AutoTokenizer - - device = "cpu" - - source_tokenizer = AutoTokenizer.from_pretrained("gpt2", cache_dir="cache") - source_model = AutoModelForCausalLM.from_pretrained("gpt2", cache_dir="cache").to(device) - source_model.eval() - - sampler_tokenizer = AutoTokenizer.from_pretrained("roberta-base", cache_dir="cache") - sampler_model = AutoModelForCausalLM.from_pretrained("roberta-base", cache_dir="cache").to(device) - sampler_model.eval() - - sampler = InferentialMTokenSampler(source_tokenizer, sampler_tokenizer, sampler_model) - - text = "This is a test sequence" - inputs = torch.tensor([source_tokenizer.encode(text)], device=device) - - outputs = sampler.sample(inputs) - - print(outputs) - print(source_tokenizer.decode(outputs[0])) diff --git a/inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/postag.py b/inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/postag.py deleted file mode 100644 index ae99f623..00000000 --- a/inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/postag.py +++ /dev/null @@ -1,85 +0,0 @@ -import nltk -import torch -from transformers import AutoTokenizer -from typing_extensions import override - -from .base import TokenSampler - - -class POSTagTokenSampler(TokenSampler): - """Sample tokens from Uniform distribution on a set of words with the same POS tag""" - - @override - def __init__(self, tokenizer: AutoTokenizer, device=None) -> None: - """Constructor - - Args: - tokenizer: A Huggingface AutoTokenizer. 
- - """ - super().__init__() - - self.tokenizer = tokenizer - - # extract mapping from postag to words - # debug_mapping_postag_to_group_word = {} - mapping_postag_to_group_token_id = {} - - for i in range(tokenizer.vocab_size): - word = tokenizer.decode([i]) - _, tag = nltk.pos_tag([word.strip()])[0] - if tag not in mapping_postag_to_group_token_id: - # debug_mapping_postag_to_group_word[tag] = [] - mapping_postag_to_group_token_id[tag] = [] - # debug_mapping_postag_to_group_word[tag].append(word) - mapping_postag_to_group_token_id[tag].append(i) - - if i % 5000 == 0: - print(f"[POSTagTokenSampler] Loading vocab from tokenizer - {i / tokenizer.vocab_size * 100:.2f}%") - - # create tag_id for postags - self.list_postag = list(mapping_postag_to_group_token_id.keys()) - num_postags = len(self.list_postag) - - # build mapping from tag_id to word group - list_group_token_id = [ - torch.tensor(mapping_postag_to_group_token_id[postag], dtype=torch.long, device=device) - for postag in self.list_postag - ] - - # build mapping from token_id to tag_id - self.mapping_token_id_to_tag_id = torch.zeros([tokenizer.vocab_size], dtype=torch.long, device=device) - for tag_id, group_token_id in enumerate(list_group_token_id): - self.mapping_token_id_to_tag_id[group_token_id] = tag_id - - # build mapping from tag_id to token_id - # postag groups are concat together, index them via compact_idx = group_offsets[tag_id] + group_idx - self.group_sizes = torch.tensor( - [group_token_id.shape[0] for group_token_id in list_group_token_id], dtype=torch.long, device=device - ) - self.group_offsets = torch.sum( - torch.tril(torch.ones([num_postags, num_postags], device=device), diagonal=-1) * self.group_sizes, dim=-1 - ) - self.compact_group_token_id = torch.cat(list_group_token_id) - - @override - def sample(self, input: torch.Tensor) -> torch.Tensor: - """Sample a input - - Args: - input: input tensor [batch, sequence] - - Returns: - token_sampled: A sampled tensor where its shape is the same with the input - - """ - super().sample(input) - - tag_id_input = self.mapping_token_id_to_tag_id[input] - sample_uniform = torch.rand(input.shape, device=input.device) - compact_group_idx = (sample_uniform * self.group_sizes[tag_id_input] + self.group_offsets[tag_id_input]).type( - torch.long - ) - token_sampled = self.compact_group_token_id[compact_group_idx] - - return token_sampled diff --git a/inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/uniform.py b/inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/uniform.py deleted file mode 100644 index 8401f942..00000000 --- a/inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/uniform.py +++ /dev/null @@ -1,53 +0,0 @@ -import torch -from transformers import AutoTokenizer -from typing_extensions import override - -from .base import TokenSampler - - -class UniformTokenSampler(TokenSampler): - """Sample tokens from Uniform distribution""" - - @override - def __init__(self, tokenizer: AutoTokenizer) -> None: - """Constructor - - Args: - tokenizer: A Huggingface AutoTokenizer. 
- - """ - super().__init__() - - self.tokenizer = tokenizer - - # masking tokens - avail_mask = torch.ones(tokenizer.vocab_size) - - # mask out special tokens - avail_mask[tokenizer.bos_token_id] = 0 - avail_mask[tokenizer.eos_token_id] = 0 - avail_mask[tokenizer.unk_token_id] = 0 - - # collect available tokens - self.avail_tokens = torch.arange(tokenizer.vocab_size)[avail_mask != 0] - - @override - def sample(self, input: torch.Tensor) -> torch.Tensor: - """Sample a tensor - - Args: - input: input tensor [batch, sequence] - - Returns: - token_uniform: A sampled tensor where its shape is the same with the input - - """ - super().sample(input) - - # sample idx form uniform distribution - sample_uniform = torch.rand(input.shape, device=input.device) - sample_uniform_idx = (sample_uniform * self.avail_tokens.shape[0]).type(torch.int32) - # map idx to tokens - token_uniform = self.avail_tokens.to(sample_uniform_idx)[sample_uniform_idx] - - return token_uniform diff --git a/inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/ranking.py b/inseq/attr/feat/ops/reagent_core/token_replacer.py similarity index 50% rename from inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/ranking.py rename to inseq/attr/feat/ops/reagent_core/token_replacer.py index fe842652..4757c579 100644 --- a/inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/ranking.py +++ b/inseq/attr/feat/ops/reagent_core/token_replacer.py @@ -1,11 +1,36 @@ import math +from abc import ABC, abstractmethod from typing import Union import torch from typing_extensions import override -from ..token_sampler.base import TokenSampler -from .base import TokenReplacer +from .token_sampler import TokenSampler + + +class TokenReplacer(ABC): + """ + Base class for token replacers + + """ + + def __init__(self, token_sampler: TokenSampler) -> None: + """Base Constructor""" + self.token_sampler = token_sampler + + @abstractmethod + def __call__(self, input: torch.Tensor) -> Union[torch.Tensor, torch.Tensor]: + """Replace tokens according to the specified strategy. + + Args: + input: input sequence [batch, sequence] + + Returns: + input_replaced: A replaced sequence [batch, sequence] + mask_replacing: Identify which token has been replaced [batch, sequence] + + """ + raise NotImplementedError() class RankingTokenReplacer(TokenReplacer): @@ -50,7 +75,7 @@ def set_score(self, value: torch.Tensor) -> None: ) @override - def sample(self, input: torch.Tensor) -> Union[torch.Tensor, torch.Tensor]: + def __call__(self, input: torch.Tensor) -> Union[torch.Tensor, torch.Tensor]: """Sample a sequence Args: @@ -61,10 +86,48 @@ def sample(self, input: torch.Tensor) -> Union[torch.Tensor, torch.Tensor]: mask_replacing: Identify which token has been replaced [batch, sequence] """ - super().sample(input) token_sampled = self.token_sampler(input) input_replaced = input * ~self.mask_replacing + token_sampled * self.mask_replacing return input_replaced, self.mask_replacing + + +class UniformTokenReplacer(TokenReplacer): + """Replace tokens in a sequence where selecting is base on uniform distribution""" + + @override + def __init__(self, token_sampler: TokenSampler, ratio: float) -> None: + """Constructor + + Args: + token_sampler: A TokenSampler for sampling replace token. 
+ ratio: replacing ratio + + """ + super().__init__(token_sampler) + + self.ratio = ratio + + @override + def __call__(self, input: torch.Tensor) -> Union[torch.Tensor, torch.Tensor]: + """Sample a sequence + + Args: + input: input sequence [batch, sequence] + + Returns: + input_replaced: A replaced sequence [batch, sequence] + mask_replacing: Identify which token has been replaced [batch, sequence] + + """ + + sample_uniform = torch.rand(input.shape, device=input.device) + mask_replacing = sample_uniform < self.ratio + + token_sampled = self.token_sampler(input) + + input_replaced = input * ~mask_replacing + token_sampled * mask_replacing + + return input_replaced, mask_replacing diff --git a/inseq/attr/feat/ops/reagent_core/token_sampler.py b/inseq/attr/feat/ops/reagent_core/token_sampler.py index 73cfb7e7..35da438a 100644 --- a/inseq/attr/feat/ops/reagent_core/token_sampler.py +++ b/inseq/attr/feat/ops/reagent_core/token_sampler.py @@ -6,6 +6,7 @@ import torch from transformers import AutoTokenizer, PreTrainedTokenizerBase +from typing_extensions import override from .....utils import INSEQ_ARTIFACTS_CACHE, cache_results, is_nltk_available from .....utils.typing import IdsTensor @@ -18,8 +19,16 @@ class TokenSampler(ABC): @abstractmethod def __call__(self, input: IdsTensor, **kwargs) -> IdsTensor: - """Sample tokens according to the specified strategy.""" - pass + """Sample tokens according to the specified strategy. + + Args: + input: input tensor [batch, sequence] + + Returns: + token_uniform: A sampled tensor where its shape is the same with the input + + """ + raise NotImplementedError() class POSTagTokenSampler(TokenSampler): @@ -83,7 +92,17 @@ def build_pos_mapping_from_vocab( logger.info(f"Loading vocab from tokenizer - {i / tokenizer.vocab_size * 100:.2f}%") return pos2ids + @override def __call__(self, input_ids: IdsTensor) -> IdsTensor: + """Sample a tensor + + Args: + input: input tensor [batch, sequence] + + Returns: + token_uniform: A sampled tensor where its shape is the same with the input + + """ input_ids_pos = self.id2pos[input_ids] sample_uniform = torch.rand(input_ids.shape, device=input_ids.device) compact_group_idx = (sample_uniform * self.num_ids_per_pos[input_ids_pos] + self.offsets[input_ids_pos]).long() diff --git a/inseq/attr/feat/ops/reagent_core/utils/serializing.py b/inseq/attr/feat/ops/reagent_core/utils/serializing.py deleted file mode 100644 index 9502ab7f..00000000 --- a/inseq/attr/feat/ops/reagent_core/utils/serializing.py +++ /dev/null @@ -1,79 +0,0 @@ -import json - -import torch -from transformers import AutoTokenizer - -from ..base import BaseRationalizer - - -def serialize_rational( - filename: str, - id: int, - token_inputs: torch.Tensor, - token_target: torch.Tensor, - position_rational: torch.Tensor, - tokenizer: AutoTokenizer, - important_score: torch.Tensor, - comments: dict = None, - compact: bool = False, - trace_rationalizer: BaseRationalizer = None, - trace_batch_idx: int = 0, - schema_file: str = "../docs/rationalization.schema.json", -) -> None: - """Serialize rationalization result to a json file - - Args: - filename: Filename to store json file - id: id of the record - token_inputs: token_inputs [sequence] - token_target: token_target [1] - position_rational: position of rational tokens [rational] - tokenizer: A Huggingface AutoTokenizer - important_score: final important score of tokens [sequence] - comments: (Optional) A dictionary of comments - compact: Whether store json file in a compact style - trace_rationalizer: (Optional) A 
Rationalizer with trace started to store trace information - trace_batch_idx: trace index in the batch, if applicable - schema_file: location of the json schema file - - """ - data = { - "$schema": schema_file, - "id": id, - "input-text": [tokenizer.decode([i]) for i in token_inputs], - "input-tokens": [i.item() for i in token_inputs], - "target-text": tokenizer.decode([token_target]), - "target-token": token_target.item(), - "rational-size": position_rational.shape[0], - "rational-positions": [i.item() for i in position_rational], - "rational-text": [tokenizer.decode([i]) for i in token_inputs[position_rational]], - "rational-tokens": [i.item() for i in token_inputs[position_rational]], - } - - if important_score is not None: - data["importance-scores"] = [i.item() for i in important_score] - - if comments: - data["comments"] = comments - - if trace_rationalizer: - trace = { - "importance-scores": [ - [v.item() for v in i[trace_batch_idx]] - for i in trace_rationalizer.importance_score_evaluator.trace_importance_score - ], - "target-likelihood-original": trace_rationalizer.importance_score_evaluator.trace_target_likelihood_original[ - trace_batch_idx - ].item(), - "target-likelihood": [ - i[trace_batch_idx].item() - for i in trace_rationalizer.importance_score_evaluator.stopping_condition_evaluator.trace_target_likelihood - ], - } - data["trace"] = trace - - indent = None if compact else 4 - json_str = json.dumps(data, indent=indent) - - with open(filename, "w") as f_output: - f_output.write(json_str) diff --git a/inseq/attr/feat/ops/reagent_core/utils/traceable.py b/inseq/attr/feat/ops/reagent_core/utils/traceable.py deleted file mode 100644 index 93e16ae5..00000000 --- a/inseq/attr/feat/ops/reagent_core/utils/traceable.py +++ /dev/null @@ -1,8 +0,0 @@ -class Traceable: - """Traceable base""" - - def trace_start(self) -> None: - """Base trace_start""" - - def trace_stop(self) -> None: - """Base trace_stop""" From 504838b62021d5db03ef40a8ff02274104709353 Mon Sep 17 00:00:00 2001 From: Xuan25 Date: Thu, 29 Feb 2024 04:39:19 +0000 Subject: [PATCH 08/14] reagent: adaptation for encoder-decoder models --- inseq/attr/feat/ops/reagent.py | 8 +++- .../importance_score_evaluator.py | 42 ++++++++++++------- .../feat/ops/reagent_core/rationalizer.py | 12 ++++-- .../stopping_condition_evaluator.py | 27 ++++++++++-- 4 files changed, 67 insertions(+), 22 deletions(-) diff --git a/inseq/attr/feat/ops/reagent.py b/inseq/attr/feat/ops/reagent.py index 4e189b63..aef40553 100644 --- a/inseq/attr/feat/ops/reagent.py +++ b/inseq/attr/feat/ops/reagent.py @@ -111,7 +111,13 @@ def attribute( # type: ignore tuple[TensorOrTupleOfTensorsGeneric, Tensor], ]: """Implement attribute""" - self.rationalizer(additional_forward_args[0], additional_forward_args[1]) + if len(additional_forward_args) == 6: + # decoder only + self.rationalizer(additional_forward_args[0], additional_forward_args[1]) + elif len(additional_forward_args) == 9: + # encoder-decoder + self.rationalizer(additional_forward_args[1], additional_forward_args[3], additional_forward_args[2]) + mean_important_score = torch.unsqueeze(self.rationalizer.mean_important_score, 0) res = torch.unsqueeze(mean_important_score, 2).repeat(1, 1, inputs[0].shape[2]) return (res,) diff --git a/inseq/attr/feat/ops/reagent_core/importance_score_evaluator.py b/inseq/attr/feat/ops/reagent_core/importance_score_evaluator.py index 8a139bc6..17b833a3 100644 --- a/inseq/attr/feat/ops/reagent_core/importance_score_evaluator.py +++ 
b/inseq/attr/feat/ops/reagent_core/importance_score_evaluator.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod import torch -from transformers import AutoModelWithLMHead, AutoTokenizer +from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer from typing_extensions import override from .stopping_condition_evaluator import StoppingConditionEvaluator @@ -12,11 +12,11 @@ class BaseImportanceScoreEvaluator(ABC): """Importance Score Evaluator""" - def __init__(self, model: AutoModelWithLMHead, tokenizer: AutoTokenizer) -> None: + def __init__(self, model: AutoModelForCausalLM | AutoModelForSeq2SeqLM, tokenizer: AutoTokenizer) -> None: """Base Constructor Args: - model: A Huggingface AutoModelWithLMHead model + model: A Huggingface AutoModelForCausalLM or AutoModelForSeq2SeqLM model tokenizer: A Huggingface AutoTokenizer """ @@ -27,12 +27,15 @@ def __init__(self, model: AutoModelWithLMHead, tokenizer: AutoTokenizer) -> None self.important_score = None @abstractmethod - def __call__(self, input_ids: torch.Tensor, target_id: torch.Tensor) -> torch.Tensor: + def __call__( + self, input_ids: torch.Tensor, target_id: torch.Tensor, decoder_input_ids: torch.Tensor = None + ) -> torch.Tensor: """Evaluate importance score of input sequence Args: input_ids: input sequence [batch, sequence] target_id: target token [batch] + decoder_input_ids (optional): decoder input sequence for AutoModelForSeq2SeqLM [batch, sequence] Return: importance_score: evaluated importance score for each token in the input [batch, sequence] @@ -47,7 +50,7 @@ class DeltaProbImportanceScoreEvaluator(BaseImportanceScoreEvaluator): @override def __init__( self, - model: AutoModelWithLMHead, + model: AutoModelForCausalLM | AutoModelForSeq2SeqLM, tokenizer: AutoTokenizer, token_replacer: TokenReplacer, stopping_condition_evaluator: StoppingConditionEvaluator, @@ -56,7 +59,7 @@ def __init__( """Constructor Args: - model: A Huggingface AutoModelWithLMHead model + model: A Huggingface AutoModelForCausalLM or AutoModelForSeq2SeqLM model tokenizer: A Huggingface AutoTokenizer token_replacer: A TokenReplacer stopping_condition_evaluator: A StoppingConditionEvaluator @@ -78,6 +81,7 @@ def update_importance_score( input_ids: torch.Tensor, target_id: torch.Tensor, prob_original_target: torch.Tensor, + decoder_input_ids: torch.Tensor = None, ) -> torch.Tensor: """Update importance score by one step @@ -86,6 +90,7 @@ def update_importance_score( input_ids: input tensor [batch, sequence] target_id: target tensor [batch] prob_original_target: predictive probability of the target on the original sequence [batch] + decoder_input_ids (optional): decoder input sequence for AutoModelForSeq2SeqLM [batch, sequence] Return: logit_importance_score: updated importance score in logistic scale [batch] @@ -102,9 +107,12 @@ def update_importance_score( # Inference \hat{p^{(y)}} = p(y_{t+1}|\hat{y_{1...t}}) - logits_replaced = self.model(input_ids_replaced)["logits"] - prob_replaced_target = torch.softmax(logits_replaced[:, input_ids_replaced.shape[1] - 1, :], -1)[:, target_id] - self.trace_prob_original_target = prob_replaced_target + if decoder_input_ids is None: + logits_replaced = self.model(input_ids_replaced)["logits"] + else: + logits_replaced = self.model(input_ids=input_ids_replaced, decoder_input_ids=decoder_input_ids)["logits"] + + prob_replaced_target = torch.softmax(logits_replaced[:, -1, :], -1)[:, target_id] # Compute changes delta = p^{(y)} - \hat{p^{(y)}} @@ -123,12 +131,15 @@ def update_importance_score( return 
logit_importance_score @override - def __call__(self, input_ids: torch.Tensor, target_id: torch.Tensor) -> torch.Tensor: + def __call__( + self, input_ids: torch.Tensor, target_id: torch.Tensor, decoder_input_ids: torch.Tensor = None + ) -> torch.Tensor: """Evaluate importance score of input sequence Args: input_ids: input sequence [batch, sequence] target_id: target token [batch] + decoder_input_ids (optional): decoder input sequence for AutoModelForSeq2SeqLM [batch, sequence] Return: importance_score: evaluated importance score for each token in the input [batch, sequence] @@ -138,9 +149,12 @@ def __call__(self, input_ids: torch.Tensor, target_id: torch.Tensor) -> torch.Te self.stop_mask = torch.zeros([input_ids.shape[0]], dtype=torch.bool, device=input_ids.device) # Inference p^{(y)} = p(y_{t+1}|y_{1...t}) + if decoder_input_ids is None: + logits_original = self.model(input_ids)["logits"] + else: + logits_original = self.model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)["logits"] - logits_original = self.model(input_ids)["logits"] - prob_original_target = torch.softmax(logits_original[:, input_ids.shape[1] - 1, :], -1)[:, target_id] + prob_original_target = torch.softmax(logits_original[:, -1, :], -1)[:, target_id] # Initialize importance score s for each token in the sequence y_{1...t} @@ -154,7 +168,7 @@ def __call__(self, input_ids: torch.Tensor, target_id: torch.Tensor) -> torch.Te # Update importance score logit_importance_score_update = self.update_importance_score( - logit_importance_score, input_ids, target_id, prob_original_target + logit_importance_score, input_ids, target_id, prob_original_target, decoder_input_ids ) logit_importance_score = ( ~torch.unsqueeze(self.stop_mask, 1) * logit_importance_score_update @@ -165,7 +179,7 @@ def __call__(self, input_ids: torch.Tensor, target_id: torch.Tensor) -> torch.Te # Evaluate stop condition self.stop_mask = self.stop_mask | self.stopping_condition_evaluator( - input_ids, target_id, self.important_score + input_ids, target_id, self.important_score, decoder_input_ids ) if torch.prod(self.stop_mask) > 0: break diff --git a/inseq/attr/feat/ops/reagent_core/rationalizer.py b/inseq/attr/feat/ops/reagent_core/rationalizer.py index 1b4511a1..58eb116f 100644 --- a/inseq/attr/feat/ops/reagent_core/rationalizer.py +++ b/inseq/attr/feat/ops/reagent_core/rationalizer.py @@ -15,12 +15,15 @@ def __init__(self, importance_score_evaluator: BaseImportanceScoreEvaluator) -> self.mean_important_score = None @abstractmethod - def __call__(self, input_ids: torch.Tensor, target_id: torch.Tensor) -> torch.Tensor: + def __call__( + self, input_ids: torch.Tensor, target_id: torch.Tensor, decoder_input_ids: torch.Tensor = None + ) -> torch.Tensor: """Compute rational of a sequence on a target Args: input_ids: The sequence [batch, sequence] (first dimension need to be 1) target_id: The target [batch] + decoder_input_ids (optional): decoder input sequence for AutoModelForSeq2SeqLM [batch, sequence] Return: pos_top_n: rational position in the sequence [batch, rational_size] @@ -65,12 +68,15 @@ def __init__( @override @torch.no_grad() - def __call__(self, input_ids: torch.Tensor, target_id: torch.Tensor) -> torch.Tensor: + def __call__( + self, input_ids: torch.Tensor, target_id: torch.Tensor, decoder_input_ids: torch.Tensor = None + ) -> torch.Tensor: """Compute rational of a sequence on a target Args: input_ids: The sequence [batch, sequence] (first dimension need to be 1) target_id: The target [batch] + decoder_input_ids (optional): decoder 
input sequence for AutoModelForSeq2SeqLM [batch, sequence] Return: pos_top_n: rational position in the sequence [batch, rational_size] @@ -80,7 +86,7 @@ def __call__(self, input_ids: torch.Tensor, target_id: torch.Tensor) -> torch.Te batch_input_ids = input_ids.repeat(self.batch_size, 1) - batch_importance_score = self.importance_score_evaluator(batch_input_ids, target_id) + batch_importance_score = self.importance_score_evaluator(batch_input_ids, target_id, decoder_input_ids) important_score_masked = batch_importance_score * torch.unsqueeze( self.importance_score_evaluator.stop_mask, -1 diff --git a/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator.py b/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator.py index 43e0b564..51a8da8b 100644 --- a/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator.py +++ b/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator.py @@ -14,7 +14,11 @@ class StoppingConditionEvaluator(ABC): @abstractmethod def __call__( - self, input_ids: torch.Tensor, target_id: torch.Tensor, importance_score: torch.Tensor + self, + input_ids: torch.Tensor, + target_id: torch.Tensor, + importance_score: torch.Tensor, + decoder_input_ids: torch.Tensor = None, ) -> torch.Tensor: """Evaluate stop condition according to the specified strategy. @@ -22,6 +26,7 @@ def __call__( input_ids: Input sequence [batch, sequence] target_id: Target token [batch] importance_score: Importance score of the input [batch, sequence] + decoder_input_ids (optional): decoder input sequence for AutoModelForSeq2SeqLM [batch, sequence] Return: Whether the stop condition achieved [batch] @@ -66,7 +71,11 @@ def __init__( @override def __call__( - self, input_ids: torch.Tensor, target_id: torch.Tensor, importance_score: torch.Tensor + self, + input_ids: torch.Tensor, + target_id: torch.Tensor, + importance_score: torch.Tensor, + decoder_input_ids: torch.Tensor = None, ) -> torch.Tensor: """Evaluate stop condition @@ -74,6 +83,7 @@ def __call__( input_ids: Input sequence [batch, sequence] target_id: Target token [batch] importance_score: Importance score of the input [batch, sequence] + decoder_input_ids (optional): decoder input sequence for AutoModelForSeq2SeqLM [batch, sequence] Return: Whether the stop condition achieved [batch] @@ -90,7 +100,12 @@ def __call__( assert not input_ids_replaced.requires_grad, "Error: auto-diff engine not disabled" with torch.no_grad(): - logits_replaced = self.model(input_ids_replaced)["logits"] + if decoder_input_ids is None: + logits_replaced = self.model(input_ids_replaced)["logits"] + else: + logits_replaced = self.model(input_ids=input_ids_replaced, decoder_input_ids=decoder_input_ids)[ + "logits" + ] ids_prediction_sorted = torch.argsort(logits_replaced[:, -1, :], descending=True) ids_prediction_top_k = ids_prediction_sorted[:, : self.top_k] @@ -118,7 +133,11 @@ def __init__(self) -> None: @override def __call__( - self, input_ids: torch.Tensor, target_id: torch.Tensor, importance_score: torch.Tensor + self, + input_ids: torch.Tensor, + target_id: torch.Tensor, + importance_score: torch.Tensor, + decoder_input_ids: torch.Tensor = None, ) -> torch.Tensor: """Evaluate stop condition From cd39d47b4ddc36bdbdbb49bb33a02d20ac63734a Mon Sep 17 00:00:00 2001 From: Xuan25 Date: Thu, 29 Feb 2024 05:06:23 +0000 Subject: [PATCH 09/14] reagent: adapt type annotation --- .../importance_score_evaluator.py | 29 ++++++++-------- .../feat/ops/reagent_core/rationalizer.py | 11 ++++--- .../stopping_condition_evaluator.py | 33 ++++++++++--------- 
.../feat/ops/reagent_core/token_replacer.py | 8 +++-- 4 files changed, 46 insertions(+), 35 deletions(-) diff --git a/inseq/attr/feat/ops/reagent_core/importance_score_evaluator.py b/inseq/attr/feat/ops/reagent_core/importance_score_evaluator.py index 17b833a3..bd751655 100644 --- a/inseq/attr/feat/ops/reagent_core/importance_score_evaluator.py +++ b/inseq/attr/feat/ops/reagent_core/importance_score_evaluator.py @@ -1,10 +1,13 @@ import logging from abc import ABC, abstractmethod +from typing import Optional import torch +from jaxtyping import Float from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer from typing_extensions import override +from .....utils.typing import IdsTensor, MultipleScoresPerStepTensor, TargetIdsTensor from .stopping_condition_evaluator import StoppingConditionEvaluator from .token_replacer import TokenReplacer @@ -28,8 +31,8 @@ def __init__(self, model: AutoModelForCausalLM | AutoModelForSeq2SeqLM, tokenize @abstractmethod def __call__( - self, input_ids: torch.Tensor, target_id: torch.Tensor, decoder_input_ids: torch.Tensor = None - ) -> torch.Tensor: + self, input_ids: IdsTensor, target_id: TargetIdsTensor, decoder_input_ids: Optional[IdsTensor] = None + ) -> MultipleScoresPerStepTensor: """Evaluate importance score of input sequence Args: @@ -77,23 +80,23 @@ def __init__( def update_importance_score( self, - logit_importance_score: torch.Tensor, - input_ids: torch.Tensor, - target_id: torch.Tensor, - prob_original_target: torch.Tensor, - decoder_input_ids: torch.Tensor = None, - ) -> torch.Tensor: + logit_importance_score: MultipleScoresPerStepTensor, + input_ids: IdsTensor, + target_id: TargetIdsTensor, + prob_original_target: Float[torch.Tensor, "batch_size 1"], + decoder_input_ids: Optional[IdsTensor] = None, + ) -> MultipleScoresPerStepTensor: """Update importance score by one step Args: - logit_importance_score: Current importance score in logistic scale [batch] + logit_importance_score: Current importance score in logistic scale [batch, sequence] input_ids: input tensor [batch, sequence] target_id: target tensor [batch] - prob_original_target: predictive probability of the target on the original sequence [batch] + prob_original_target: predictive probability of the target on the original sequence [batch, 1] decoder_input_ids (optional): decoder input sequence for AutoModelForSeq2SeqLM [batch, sequence] Return: - logit_importance_score: updated importance score in logistic scale [batch] + logit_importance_score: updated importance score in logistic scale [batch, sequence] """ # Randomly replace a set of tokens R to form a new sequence \hat{y_{1...t}} @@ -132,8 +135,8 @@ def update_importance_score( @override def __call__( - self, input_ids: torch.Tensor, target_id: torch.Tensor, decoder_input_ids: torch.Tensor = None - ) -> torch.Tensor: + self, input_ids: IdsTensor, target_id: TargetIdsTensor, decoder_input_ids: Optional[IdsTensor] = None + ) -> MultipleScoresPerStepTensor: """Evaluate importance score of input sequence Args: diff --git a/inseq/attr/feat/ops/reagent_core/rationalizer.py b/inseq/attr/feat/ops/reagent_core/rationalizer.py index 58eb116f..92d0e01b 100644 --- a/inseq/attr/feat/ops/reagent_core/rationalizer.py +++ b/inseq/attr/feat/ops/reagent_core/rationalizer.py @@ -1,9 +1,12 @@ import math from abc import ABC, abstractmethod +from typing import Optional import torch +from jaxtyping import Int64 from typing_extensions import override +from .....utils.typing import IdsTensor, TargetIdsTensor from 
.importance_score_evaluator import BaseImportanceScoreEvaluator @@ -16,8 +19,8 @@ def __init__(self, importance_score_evaluator: BaseImportanceScoreEvaluator) -> @abstractmethod def __call__( - self, input_ids: torch.Tensor, target_id: torch.Tensor, decoder_input_ids: torch.Tensor = None - ) -> torch.Tensor: + self, input_ids: IdsTensor, target_id: TargetIdsTensor, decoder_input_ids: Optional[IdsTensor] = None + ) -> Int64[torch.Tensor, "batch_size other_dims"]: """Compute rational of a sequence on a target Args: @@ -69,8 +72,8 @@ def __init__( @override @torch.no_grad() def __call__( - self, input_ids: torch.Tensor, target_id: torch.Tensor, decoder_input_ids: torch.Tensor = None - ) -> torch.Tensor: + self, input_ids: IdsTensor, target_id: TargetIdsTensor, decoder_input_ids: Optional[IdsTensor] = None + ) -> Int64[torch.Tensor, "batch_size other_dims"]: """Compute rational of a sequence on a target Args: diff --git a/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator.py b/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator.py index 51a8da8b..647042e9 100644 --- a/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator.py +++ b/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator.py @@ -1,10 +1,13 @@ import logging from abc import ABC, abstractmethod +from typing import Optional import torch +from jaxtyping import Int64 from transformers import AutoModelWithLMHead, AutoTokenizer from typing_extensions import override +from .....utils.typing import IdsTensor, MultipleScoresPerStepTensor, TargetIdsTensor from .token_replacer import RankingTokenReplacer from .token_sampler import TokenSampler @@ -15,11 +18,11 @@ class StoppingConditionEvaluator(ABC): @abstractmethod def __call__( self, - input_ids: torch.Tensor, - target_id: torch.Tensor, - importance_score: torch.Tensor, - decoder_input_ids: torch.Tensor = None, - ) -> torch.Tensor: + input_ids: IdsTensor, + target_id: TargetIdsTensor, + importance_score: MultipleScoresPerStepTensor, + decoder_input_ids: Optional[IdsTensor] = None, + ) -> Int64[torch.Tensor, "batch_size"]: """Evaluate stop condition according to the specified strategy. 
Args: @@ -72,11 +75,11 @@ def __init__( @override def __call__( self, - input_ids: torch.Tensor, - target_id: torch.Tensor, - importance_score: torch.Tensor, - decoder_input_ids: torch.Tensor = None, - ) -> torch.Tensor: + input_ids: IdsTensor, + target_id: TargetIdsTensor, + importance_score: MultipleScoresPerStepTensor, + decoder_input_ids: Optional[IdsTensor] = None, + ) -> Int64[torch.Tensor, "batch_size"]: """Evaluate stop condition Args: @@ -134,11 +137,11 @@ def __init__(self) -> None: @override def __call__( self, - input_ids: torch.Tensor, - target_id: torch.Tensor, - importance_score: torch.Tensor, - decoder_input_ids: torch.Tensor = None, - ) -> torch.Tensor: + input_ids: IdsTensor, + target_id: TargetIdsTensor, + importance_score: MultipleScoresPerStepTensor, + decoder_input_ids: Optional[IdsTensor] = None, + ) -> Int64[torch.Tensor, "batch_size"]: """Evaluate stop condition Args: diff --git a/inseq/attr/feat/ops/reagent_core/token_replacer.py b/inseq/attr/feat/ops/reagent_core/token_replacer.py index 4757c579..038dc5b6 100644 --- a/inseq/attr/feat/ops/reagent_core/token_replacer.py +++ b/inseq/attr/feat/ops/reagent_core/token_replacer.py @@ -3,8 +3,10 @@ from typing import Union import torch +from jaxtyping import Int64 from typing_extensions import override +from .....utils.typing import IdsTensor from .token_sampler import TokenSampler @@ -19,7 +21,7 @@ def __init__(self, token_sampler: TokenSampler) -> None: self.token_sampler = token_sampler @abstractmethod - def __call__(self, input: torch.Tensor) -> Union[torch.Tensor, torch.Tensor]: + def __call__(self, input: IdsTensor) -> Union[IdsTensor, Int64[torch.Tensor, "batch_size seq_len"]]: """Replace tokens according to the specified strategy. Args: @@ -75,7 +77,7 @@ def set_score(self, value: torch.Tensor) -> None: ) @override - def __call__(self, input: torch.Tensor) -> Union[torch.Tensor, torch.Tensor]: + def __call__(self, input: IdsTensor) -> Union[IdsTensor, Int64[torch.Tensor, "batch_size seq_len"]]: """Sample a sequence Args: @@ -111,7 +113,7 @@ def __init__(self, token_sampler: TokenSampler, ratio: float) -> None: self.ratio = ratio @override - def __call__(self, input: torch.Tensor) -> Union[torch.Tensor, torch.Tensor]: + def __call__(self, input: IdsTensor) -> Union[IdsTensor, Int64[torch.Tensor, "batch_size seq_len"]]: """Sample a sequence Args: From 00d5fb4d12433e589d2ad9af66cb4392cedcb9ba Mon Sep 17 00:00:00 2001 From: Xuan25 Date: Thu, 29 Feb 2024 15:29:06 +0000 Subject: [PATCH 10/14] reagent: implement attribute_target for encoder-decoder models --- inseq/attr/feat/ops/reagent.py | 13 ++++-- .../importance_score_evaluator.py | 40 +++++++++++++++---- .../feat/ops/reagent_core/rationalizer.py | 21 ++++++++-- .../stopping_condition_evaluator.py | 19 ++++++++- inseq/attr/feat/perturbation_attribution.py | 2 - 5 files changed, 78 insertions(+), 17 deletions(-) diff --git a/inseq/attr/feat/ops/reagent.py b/inseq/attr/feat/ops/reagent.py index aef40553..52964630 100644 --- a/inseq/attr/feat/ops/reagent.py +++ b/inseq/attr/feat/ops/reagent.py @@ -111,12 +111,19 @@ def attribute( # type: ignore tuple[TensorOrTupleOfTensorsGeneric, Tensor], ]: """Implement attribute""" - if len(additional_forward_args) == 6: - # decoder only - self.rationalizer(additional_forward_args[0], additional_forward_args[1]) + if len(additional_forward_args) == 8: + # encoder-decoder with target + self.rationalizer(additional_forward_args[0], additional_forward_args[2], additional_forward_args[1], True) + + mean_important_score = 
torch.unsqueeze(self.rationalizer.mean_important_score, 0) + res = torch.unsqueeze(mean_important_score, 2).repeat(1, 1, inputs[0].shape[2]) + return (res[:, : additional_forward_args[0].shape[1], :], res[:, additional_forward_args[0].shape[1] :, :]) elif len(additional_forward_args) == 9: # encoder-decoder self.rationalizer(additional_forward_args[1], additional_forward_args[3], additional_forward_args[2]) + elif len(additional_forward_args) == 6: + # decoder only + self.rationalizer(additional_forward_args[0], additional_forward_args[1]) mean_important_score = torch.unsqueeze(self.rationalizer.mean_important_score, 0) res = torch.unsqueeze(mean_important_score, 2).repeat(1, 1, inputs[0].shape[2]) diff --git a/inseq/attr/feat/ops/reagent_core/importance_score_evaluator.py b/inseq/attr/feat/ops/reagent_core/importance_score_evaluator.py index bd751655..8b1042af 100644 --- a/inseq/attr/feat/ops/reagent_core/importance_score_evaluator.py +++ b/inseq/attr/feat/ops/reagent_core/importance_score_evaluator.py @@ -31,7 +31,11 @@ def __init__(self, model: AutoModelForCausalLM | AutoModelForSeq2SeqLM, tokenize @abstractmethod def __call__( - self, input_ids: IdsTensor, target_id: TargetIdsTensor, decoder_input_ids: Optional[IdsTensor] = None + self, + input_ids: IdsTensor, + target_id: TargetIdsTensor, + decoder_input_ids: Optional[IdsTensor] = None, + attribute_target: bool = False, ) -> MultipleScoresPerStepTensor: """Evaluate importance score of input sequence @@ -39,6 +43,7 @@ def __call__( input_ids: input sequence [batch, sequence] target_id: target token [batch] decoder_input_ids (optional): decoder input sequence for AutoModelForSeq2SeqLM [batch, sequence] + attribute_target: whether attribute target for encoder-decoder models Return: importance_score: evaluated importance score for each token in the input [batch, sequence] @@ -85,6 +90,7 @@ def update_importance_score( target_id: TargetIdsTensor, prob_original_target: Float[torch.Tensor, "batch_size 1"], decoder_input_ids: Optional[IdsTensor] = None, + attribute_target: bool = False, ) -> MultipleScoresPerStepTensor: """Update importance score by one step @@ -94,6 +100,7 @@ def update_importance_score( target_id: target tensor [batch] prob_original_target: predictive probability of the target on the original sequence [batch, 1] decoder_input_ids (optional): decoder input sequence for AutoModelForSeq2SeqLM [batch, sequence] + attribute_target: whether attribute target for encoder-decoder models Return: logit_importance_score: updated importance score in logistic scale [batch, sequence] @@ -101,7 +108,12 @@ def update_importance_score( """ # Randomly replace a set of tokens R to form a new sequence \hat{y_{1...t}} - input_ids_replaced, mask_replacing = self.token_replacer(input_ids) + if not attribute_target: + input_ids_replaced, mask_replacing = self.token_replacer(input_ids) + else: + ids_replaced, mask_replacing = self.token_replacer(torch.cat((input_ids, decoder_input_ids), 1)) + input_ids_replaced = ids_replaced[:, : input_ids.shape[1]] + decoder_input_ids_replaced = ids_replaced[:, input_ids.shape[1] :] logging.debug(f"Replacing mask: { mask_replacing }") logging.debug( @@ -112,8 +124,12 @@ def update_importance_score( if decoder_input_ids is None: logits_replaced = self.model(input_ids_replaced)["logits"] - else: + elif not attribute_target: logits_replaced = self.model(input_ids=input_ids_replaced, decoder_input_ids=decoder_input_ids)["logits"] + else: + logits_replaced = self.model(input_ids=input_ids_replaced, 
decoder_input_ids=decoder_input_ids_replaced)[ + "logits" + ] prob_replaced_target = torch.softmax(logits_replaced[:, -1, :], -1)[:, target_id] @@ -135,7 +151,11 @@ def update_importance_score( @override def __call__( - self, input_ids: IdsTensor, target_id: TargetIdsTensor, decoder_input_ids: Optional[IdsTensor] = None + self, + input_ids: IdsTensor, + target_id: TargetIdsTensor, + decoder_input_ids: Optional[IdsTensor] = None, + attribute_target: bool = False, ) -> MultipleScoresPerStepTensor: """Evaluate importance score of input sequence @@ -143,6 +163,7 @@ def __call__( input_ids: input sequence [batch, sequence] target_id: target token [batch] decoder_input_ids (optional): decoder input sequence for AutoModelForSeq2SeqLM [batch, sequence] + attribute_target: whether attribute target for encoder-decoder models Return: importance_score: evaluated importance score for each token in the input [batch, sequence] @@ -161,7 +182,12 @@ def __call__( # Initialize importance score s for each token in the sequence y_{1...t} - logit_importance_score = torch.rand(input_ids.shape, device=input_ids.device) + if not attribute_target: + logit_importance_score = torch.rand(input_ids.shape, device=input_ids.device) + else: + logit_importance_score = torch.rand( + (input_ids.shape[0], input_ids.shape[1] + decoder_input_ids.shape[1]), device=input_ids.device + ) logging.debug(f"Initialize importance score -> { torch.softmax(logit_importance_score, -1) }") # TODO: limit max steps @@ -171,7 +197,7 @@ def __call__( # Update importance score logit_importance_score_update = self.update_importance_score( - logit_importance_score, input_ids, target_id, prob_original_target, decoder_input_ids + logit_importance_score, input_ids, target_id, prob_original_target, decoder_input_ids, attribute_target ) logit_importance_score = ( ~torch.unsqueeze(self.stop_mask, 1) * logit_importance_score_update @@ -182,7 +208,7 @@ def __call__( # Evaluate stop condition self.stop_mask = self.stop_mask | self.stopping_condition_evaluator( - input_ids, target_id, self.important_score, decoder_input_ids + input_ids, target_id, self.important_score, decoder_input_ids, attribute_target ) if torch.prod(self.stop_mask) > 0: break diff --git a/inseq/attr/feat/ops/reagent_core/rationalizer.py b/inseq/attr/feat/ops/reagent_core/rationalizer.py index 92d0e01b..ed933650 100644 --- a/inseq/attr/feat/ops/reagent_core/rationalizer.py +++ b/inseq/attr/feat/ops/reagent_core/rationalizer.py @@ -19,7 +19,11 @@ def __init__(self, importance_score_evaluator: BaseImportanceScoreEvaluator) -> @abstractmethod def __call__( - self, input_ids: IdsTensor, target_id: TargetIdsTensor, decoder_input_ids: Optional[IdsTensor] = None + self, + input_ids: IdsTensor, + target_id: TargetIdsTensor, + decoder_input_ids: Optional[IdsTensor] = None, + attribute_target: bool = False, ) -> Int64[torch.Tensor, "batch_size other_dims"]: """Compute rational of a sequence on a target @@ -27,6 +31,7 @@ def __call__( input_ids: The sequence [batch, sequence] (first dimension need to be 1) target_id: The target [batch] decoder_input_ids (optional): decoder input sequence for AutoModelForSeq2SeqLM [batch, sequence] + attribute_target: whether attribute target for encoder-decoder models Return: pos_top_n: rational position in the sequence [batch, rational_size] @@ -72,7 +77,11 @@ def __init__( @override @torch.no_grad() def __call__( - self, input_ids: IdsTensor, target_id: TargetIdsTensor, decoder_input_ids: Optional[IdsTensor] = None + self, + input_ids: IdsTensor, + target_id: 
TargetIdsTensor, + decoder_input_ids: Optional[IdsTensor] = None, + attribute_target: bool = False, ) -> Int64[torch.Tensor, "batch_size other_dims"]: """Compute rational of a sequence on a target @@ -80,6 +89,7 @@ def __call__( input_ids: The sequence [batch, sequence] (first dimension need to be 1) target_id: The target [batch] decoder_input_ids (optional): decoder input sequence for AutoModelForSeq2SeqLM [batch, sequence] + attribute_target: whether attribute target for encoder-decoder models Return: pos_top_n: rational position in the sequence [batch, rational_size] @@ -88,8 +98,13 @@ def __call__( assert input_ids.shape[0] == 1, "the first dimension of input (batch_size) need to be 1" batch_input_ids = input_ids.repeat(self.batch_size, 1) + batch_decoder_input_ids = ( + decoder_input_ids.repeat(self.batch_size, 1) if decoder_input_ids is not None else None + ) - batch_importance_score = self.importance_score_evaluator(batch_input_ids, target_id, decoder_input_ids) + batch_importance_score = self.importance_score_evaluator( + batch_input_ids, target_id, batch_decoder_input_ids, attribute_target + ) important_score_masked = batch_importance_score * torch.unsqueeze( self.importance_score_evaluator.stop_mask, -1 diff --git a/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator.py b/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator.py index 647042e9..b3b3f6ae 100644 --- a/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator.py +++ b/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator.py @@ -22,6 +22,7 @@ def __call__( target_id: TargetIdsTensor, importance_score: MultipleScoresPerStepTensor, decoder_input_ids: Optional[IdsTensor] = None, + attribute_target: bool = False, ) -> Int64[torch.Tensor, "batch_size"]: """Evaluate stop condition according to the specified strategy. 
@@ -30,6 +31,7 @@ def __call__( target_id: Target token [batch] importance_score: Importance score of the input [batch, sequence] decoder_input_ids (optional): decoder input sequence for AutoModelForSeq2SeqLM [batch, sequence] + attribute_target: whether attribute target for encoder-decoder models Return: Whether the stop condition achieved [batch] @@ -79,6 +81,7 @@ def __call__( target_id: TargetIdsTensor, importance_score: MultipleScoresPerStepTensor, decoder_input_ids: Optional[IdsTensor] = None, + attribute_target: bool = False, ) -> Int64[torch.Tensor, "batch_size"]: """Evaluate stop condition @@ -87,6 +90,7 @@ def __call__( target_id: Target token [batch] importance_score: Importance score of the input [batch, sequence] decoder_input_ids (optional): decoder input sequence for AutoModelForSeq2SeqLM [batch, sequence] + attribute_target: whether attribute target for encoder-decoder models Return: Whether the stop condition achieved [batch] @@ -95,7 +99,12 @@ def __call__( # Replace tokens with low importance score and then inference \hat{y^{(e)}_{t+1}} self.token_replacer.set_score(importance_score) - input_ids_replaced, mask_replacing = self.token_replacer(input_ids) + if not attribute_target: + input_ids_replaced, mask_replacing = self.token_replacer(input_ids) + else: + ids_replaced, mask_replacing = self.token_replacer(torch.cat((input_ids, decoder_input_ids), 1)) + input_ids_replaced = ids_replaced[:, : input_ids.shape[1]] + decoder_input_ids_replaced = ids_replaced[:, input_ids.shape[1] :] logging.debug(f"Replacing mask based on importance score -> { mask_replacing }") @@ -105,10 +114,14 @@ def __call__( with torch.no_grad(): if decoder_input_ids is None: logits_replaced = self.model(input_ids_replaced)["logits"] - else: + elif not attribute_target: logits_replaced = self.model(input_ids=input_ids_replaced, decoder_input_ids=decoder_input_ids)[ "logits" ] + else: + logits_replaced = self.model( + input_ids=input_ids_replaced, decoder_input_ids=decoder_input_ids_replaced + )["logits"] ids_prediction_sorted = torch.argsort(logits_replaced[:, -1, :], descending=True) ids_prediction_top_k = ids_prediction_sorted[:, : self.top_k] @@ -141,6 +154,7 @@ def __call__( target_id: TargetIdsTensor, importance_score: MultipleScoresPerStepTensor, decoder_input_ids: Optional[IdsTensor] = None, + attribute_target: bool = False, ) -> Int64[torch.Tensor, "batch_size"]: """Evaluate stop condition @@ -148,6 +162,7 @@ def __call__( input_ids: Input sequence [batch, sequence] target_id: Target token [batch] importance_score: Importance score of the input [batch, sequence] + attribute_target: whether attribute target for encoder-decoder models Return: Whether the stop condition achieved [batch] diff --git a/inseq/attr/feat/perturbation_attribution.py b/inseq/attr/feat/perturbation_attribution.py index 24a4aa68..87bb3196 100644 --- a/inseq/attr/feat/perturbation_attribution.py +++ b/inseq/attr/feat/perturbation_attribution.py @@ -141,8 +141,6 @@ def attribute_step( attribute_fn_main_args: dict[str, Any], attribution_args: dict[str, Any] = {}, ) -> GranularFeatureAttributionStepOutput: - if len(attribute_fn_main_args["inputs"]) > 1: - raise NotImplementedError("ReAgent attribution not supported for encoder-decoder models.") out = super().attribute_step(attribute_fn_main_args, attribution_args) return GranularFeatureAttributionStepOutput( source_attributions=out.source_attributions, From 68132de422d6a2781e8ca38953c4de96562e28ac Mon Sep 17 00:00:00 2001 From: Xuan25 Date: Thu, 29 Feb 2024 15:48:58 +0000 
Subject: [PATCH 11/14] reagent: increase the default num_probes --- inseq/attr/feat/ops/reagent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/inseq/attr/feat/ops/reagent.py b/inseq/attr/feat/ops/reagent.py index 52964630..420bf7d1 100644 --- a/inseq/attr/feat/ops/reagent.py +++ b/inseq/attr/feat/ops/reagent.py @@ -63,7 +63,7 @@ def __init__( stopping_condition_top_k: int = 3, replacing_ratio: float = 0.3, max_probe_steps: int = 3000, - num_probes: int = 8, + num_probes: int = 16, ) -> None: PerturbationAttribution.__init__(self, forward_func=attribution_model) From fbf8daf677cf4d91ac3e25a07b4e260e4670a3ad Mon Sep 17 00:00:00 2001 From: Gabriele Sarti Date: Thu, 4 Apr 2024 19:14:49 +0200 Subject: [PATCH 12/14] Various fixes to style, imports and naming --- inseq/attr/feat/ops/reagent.py | 76 ++++++------- inseq/attr/feat/ops/reagent_core/__init__.py | 13 +++ .../importance_score_evaluator.py | 2 + .../feat/ops/reagent_core/rationalizer.py | 34 +++--- .../stopping_condition_evaluator.py | 84 +++++---------- .../feat/ops/reagent_core/token_replacer.py | 100 +++++++----------- .../feat/ops/reagent_core/token_sampler.py | 2 - pyproject.toml | 4 +- requirements-dev.txt | 13 ++- requirements.txt | 12 +-- 10 files changed, 148 insertions(+), 192 deletions(-) diff --git a/inseq/attr/feat/ops/reagent.py b/inseq/attr/feat/ops/reagent.py index 420bf7d1..88c5af77 100644 --- a/inseq/attr/feat/ops/reagent.py +++ b/inseq/attr/feat/ops/reagent.py @@ -2,21 +2,23 @@ import torch from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric -from captum.attr._utils.attribution import PerturbationAttribution from torch import Tensor from typing_extensions import override -from .reagent_core.importance_score_evaluator import DeltaProbImportanceScoreEvaluator -from .reagent_core.rationalizer import AggregateRationalizer -from .reagent_core.stopping_condition_evaluator import TopKStoppingConditionEvaluator -from .reagent_core.token_replacer import UniformTokenReplacer -from .reagent_core.token_sampler import POSTagTokenSampler +from ....utils.typing import InseqAttribution +from .reagent_core import ( + AggregateRationalizer, + DeltaProbImportanceScoreEvaluator, + POSTagTokenSampler, + TopKStoppingConditionEvaluator, + UniformTokenReplacer, +) if TYPE_CHECKING: from ....models import HuggingfaceModel -class Reagent(PerturbationAttribution): +class Reagent(InseqAttribution): r"""Recursive attribution generator (ReAGent) method. Measures importance as the drop in prediction probability produced by replacing a token with a plausible @@ -27,30 +29,30 @@ class Reagent(PerturbationAttribution): `__ Args: - forward_func (callable): The forward function of the model or any - modification of it - rational_size (int): Top n tokens based on importance_score are not been replaced during the prediction inference. - top_n_ratio will be used if top_n has been set to 0 - rational_size_ratio (float): TUse ratio of input length to control the top n - stopping_condition_top_k (int): Stop condition achieved when target exist in top k predictions + forward_func (callable): The forward function of the model or any modification of it + keep_top_n (int): If set to a value greater than 0, the top n tokens based on their importance score will be + kept during the prediction inference. If set to 0, the top n will be determined by ``keep_ratio``. + keep_ratio (float): If ``keep_top_n`` is set to 0, this specifies the proportion of tokens to keep. 
+ invert_keep: If specified, the top tokens selected either via ``keep_top_n`` or ``keep_ratio`` will be + replaced instead of being kept. + stopping_condition_top_k (int): Threshold indicating that the stop condition achieved when the predicted target + exist in top k predictions replacing_ratio (float): replacing ratio of tokens for probing max_probe_steps (int): max_probe_steps num_probes (int): number of probes in parallel - Examples: + Example: ``` import inseq - model = inseq.load_model("gpt2-medium", "ReAGent", - rational_size=5, - rational_size_ratio=None, - stopping_condition_top_k=3, - replacing_ratio=0.3, - max_probe_steps=3000, - num_probes=8) - out = model.attribute( - "Super Mario Land is a game that developed by", + model = inseq.load_model("gpt2-medium", "reagent", + keep_top_n=5, + stopping_condition_top_k=3, + replacing_ratio=0.3, + max_probe_steps=3000, + num_probes=8 ) + out = model.attribute("Super Mario Land is a game that developed by") out.show() ``` """ @@ -58,35 +60,33 @@ class Reagent(PerturbationAttribution): def __init__( self, attribution_model: "HuggingfaceModel", - rational_size: int = 5, - rational_size_ratio: float = None, + keep_top_n: int = 5, + keep_ratio: float = None, + invert_keep: bool = False, stopping_condition_top_k: int = 3, replacing_ratio: float = 0.3, max_probe_steps: int = 3000, num_probes: int = 16, ) -> None: - PerturbationAttribution.__init__(self, forward_func=attribution_model) + super().__init__(attribution_model) model = attribution_model.model tokenizer = attribution_model.tokenizer + model_name = attribution_model.model_name - token_sampler = POSTagTokenSampler( - tokenizer=tokenizer, identifier=attribution_model.model_name, device=attribution_model.device - ) - + sampler = POSTagTokenSampler(tokenizer=tokenizer, identifier=model_name, device=attribution_model.device) stopping_condition_evaluator = TopKStoppingConditionEvaluator( model=model, - token_sampler=token_sampler, + sampler=sampler, top_k=stopping_condition_top_k, - top_n=rational_size, - top_n_ratio=rational_size_ratio, - tokenizer=tokenizer, + keep_top_n=keep_top_n, + keep_ratio=keep_ratio, + invert_keep=invert_keep, ) - importance_score_evaluator = DeltaProbImportanceScoreEvaluator( model=model, tokenizer=tokenizer, - token_replacer=UniformTokenReplacer(token_sampler=token_sampler, ratio=replacing_ratio), + token_replacer=UniformTokenReplacer(sampler=sampler, ratio=replacing_ratio), stopping_condition_evaluator=stopping_condition_evaluator, max_steps=max_probe_steps, ) @@ -96,8 +96,8 @@ def __init__( batch_size=num_probes, overlap_threshold=0, overlap_strict_pos=True, - top_n=rational_size, - top_n_ratio=rational_size_ratio, + keep_top_n=keep_top_n, + keep_ratio=keep_ratio, ) @override diff --git a/inseq/attr/feat/ops/reagent_core/__init__.py b/inseq/attr/feat/ops/reagent_core/__init__.py index e69de29b..13917c00 100644 --- a/inseq/attr/feat/ops/reagent_core/__init__.py +++ b/inseq/attr/feat/ops/reagent_core/__init__.py @@ -0,0 +1,13 @@ +from .importance_score_evaluator import DeltaProbImportanceScoreEvaluator +from .rationalizer import AggregateRationalizer +from .stopping_condition_evaluator import TopKStoppingConditionEvaluator +from .token_replacer import UniformTokenReplacer +from .token_sampler import POSTagTokenSampler + +__all__ = [ + "DeltaProbImportanceScoreEvaluator", + "AggregateRationalizer", + "TopKStoppingConditionEvaluator", + "UniformTokenReplacer", + "POSTagTokenSampler", +] diff --git a/inseq/attr/feat/ops/reagent_core/importance_score_evaluator.py 
b/inseq/attr/feat/ops/reagent_core/importance_score_evaluator.py index 8b1042af..a74a161b 100644 --- a/inseq/attr/feat/ops/reagent_core/importance_score_evaluator.py +++ b/inseq/attr/feat/ops/reagent_core/importance_score_evaluator.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging from abc import ABC, abstractmethod from typing import Optional diff --git a/inseq/attr/feat/ops/reagent_core/rationalizer.py b/inseq/attr/feat/ops/reagent_core/rationalizer.py index ed933650..7fff9d25 100644 --- a/inseq/attr/feat/ops/reagent_core/rationalizer.py +++ b/inseq/attr/feat/ops/reagent_core/rationalizer.py @@ -13,7 +13,6 @@ class BaseRationalizer(ABC): def __init__(self, importance_score_evaluator: BaseImportanceScoreEvaluator) -> None: super().__init__() - self.importance_score_evaluator = importance_score_evaluator self.mean_important_score = None @@ -50,8 +49,8 @@ def __init__( batch_size: int, overlap_threshold: int, overlap_strict_pos: bool = True, - top_n: float = 0, - top_n_ratio: float = 0, + keep_top_n: int = 0, + keep_ratio: float = 0, ) -> None: """Constructor @@ -60,18 +59,16 @@ def __init__( batch_size: Batch size for aggregate overlap_threshold: Overlap threshold of rational tokens within a batch overlap_strict_pos: Whether overlap strict to position ot not - top_n: Rational size - top_n_ratio: Use ratio of sequence to define rational size + keep_top_n: Rational size + keep_ratio: Use ratio of sequence to define rational size """ super().__init__(importance_score_evaluator) - self.batch_size = batch_size self.overlap_threshold = overlap_threshold self.overlap_strict_pos = overlap_strict_pos - self.top_n = top_n - self.top_n_ratio = top_n_ratio - + self.keep_top_n = keep_top_n + self.keep_ratio = keep_ratio assert overlap_strict_pos, "overlap_strict_pos = False not been supported yet" @override @@ -115,11 +112,7 @@ def __call__( pos_sorted = torch.argsort(batch_importance_score, dim=-1, descending=True) - top_n = self.top_n - - if top_n == 0: - top_n = int(math.ceil(self.top_n_ratio * input_ids.shape[-1])) - + top_n = int(math.ceil(self.keep_ratio * input_ids.shape[-1])) if not self.keep_top_n else self.keep_top_n pos_top_n = pos_sorted[:, :top_n] self.pos_top_n = pos_top_n @@ -130,10 +123,11 @@ def __call__( ) return pos_top_n_overlap else: - token_id_top_n = input_ids[0, pos_top_n] - count_overlap = torch.bincount(token_id_top_n.flatten(), minlength=input_ids.shape[1]) - _token_id_top_n_overlap = torch.unsqueeze( - torch.nonzero(count_overlap >= self.overlap_threshold, as_tuple=True)[0], 0 - ) + raise NotImplementedError("overlap_strict_pos = False not been supported yet") + # TODO: Convert back to pos - raise NotImplementedError("TODO") + # token_id_top_n = input_ids[0, pos_top_n] + # count_overlap = torch.bincount(token_id_top_n.flatten(), minlength=input_ids.shape[1]) + # _token_id_top_n_overlap = torch.unsqueeze( + # torch.nonzero(count_overlap >= self.overlap_threshold, as_tuple=True)[0], 0 + # ) diff --git a/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator.py b/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator.py index b3b3f6ae..c73ef764 100644 --- a/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator.py +++ b/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator.py @@ -3,9 +3,7 @@ from typing import Optional import torch -from jaxtyping import Int64 -from transformers import AutoModelWithLMHead, AutoTokenizer -from typing_extensions import override +from transformers import AutoModelForCausalLM from .....utils.typing 
import IdsTensor, MultipleScoresPerStepTensor, TargetIdsTensor from .token_replacer import RankingTokenReplacer @@ -23,7 +21,7 @@ def __call__( importance_score: MultipleScoresPerStepTensor, decoder_input_ids: Optional[IdsTensor] = None, attribute_target: bool = False, - ) -> Int64[torch.Tensor, "batch_size"]: + ) -> TargetIdsTensor: """Evaluate stop condition according to the specified strategy. Args: @@ -34,7 +32,7 @@ def __call__( attribute_target: whether attribute target for encoder-decoder models Return: - Whether the stop condition achieved [batch] + Boolean flag per sequence signaling whether the stop condition was reached [batch] """ raise NotImplementedError() @@ -42,39 +40,36 @@ def __call__( class TopKStoppingConditionEvaluator(StoppingConditionEvaluator): """ - Stopping Condition Evaluator which stop when target exist in top k predictions, + Evaluator stopping when target exist among the top k predictions, while top n tokens based on importance_score are not been replaced. """ - @override def __init__( self, - model: AutoModelWithLMHead, - token_sampler: TokenSampler, + model: AutoModelForCausalLM, + sampler: TokenSampler, top_k: int, - top_n: int = 0, - top_n_ratio: float = 0, - tokenizer: AutoTokenizer = None, + keep_top_n: int = 0, + keep_ratio: float = 0, + invert_keep: bool = False, ) -> None: - """Constructor + """Constructor for the TopKStoppingConditionEvaluator class. Args: - model: A Huggingface AutoModelWithLMHead. - token_sampler: A TokenSampler to sample replacement tokens - top_k: Stop condition achieved when target exist in top k predictions - top_n: Top n tokens based on importance_score are not been replaced during the prediction inference. - top_n_ratio will be used if top_n has been set to 0 - top_n_ratio: Use ratio of input length to control the top n - tokenizer: (Optional) Used for logging top_k_words at each step - + model: A Huggingface AutoModelForCausalLM. + sampler: A :class:`~inseq.attr.feat.ops.reagent_core.TokenSampler` object to sample replacement tokens. + top_k: Top K predictions in which the target must be included in order to achieve the stopping condition. + keep_top_n: If set to a value greater than 0, the top n tokens based on their importance score will be + kept, and the rest will be flagged for replacement. If set to 0, the top n will be determined by + ``keep_ratio``. + keep_ratio: If ``keep_top_n`` is set to 0, this specifies the proportion of tokens to keep. + invert_keep: If specified, the top tokens selected either via ``keep_top_n`` or ``keep_ratio`` will be + replaced instead of being kept. 
""" self.model = model - self.token_sampler = token_sampler self.top_k = top_k - self.token_replacer = RankingTokenReplacer(self.token_sampler, top_n, top_n_ratio) - self.tokenizer = tokenizer + self.replacer = RankingTokenReplacer(sampler, keep_top_n, keep_ratio, invert_keep) - @override def __call__( self, input_ids: IdsTensor, @@ -82,7 +77,7 @@ def __call__( importance_score: MultipleScoresPerStepTensor, decoder_input_ids: Optional[IdsTensor] = None, attribute_target: bool = False, - ) -> Int64[torch.Tensor, "batch_size"]: + ) -> TargetIdsTensor: """Evaluate stop condition Args: @@ -93,23 +88,20 @@ def __call__( attribute_target: whether attribute target for encoder-decoder models Return: - Whether the stop condition achieved [batch] - + Boolean flag per sequence signaling whether the stop condition was reached [batch] """ # Replace tokens with low importance score and then inference \hat{y^{(e)}_{t+1}} - - self.token_replacer.set_score(importance_score) + self.replacer.set_score(importance_score) if not attribute_target: - input_ids_replaced, mask_replacing = self.token_replacer(input_ids) + input_ids_replaced, mask_replacing = self.replacer(input_ids) else: - ids_replaced, mask_replacing = self.token_replacer(torch.cat((input_ids, decoder_input_ids), 1)) + ids_replaced, mask_replacing = self.replacer(torch.cat((input_ids, decoder_input_ids), 1)) input_ids_replaced = ids_replaced[:, : input_ids.shape[1]] decoder_input_ids_replaced = ids_replaced[:, input_ids.shape[1] :] logging.debug(f"Replacing mask based on importance score -> { mask_replacing }") # Whether the result \hat{y^{(e)}_{t+1}} consistent with y_{t+1} - assert not input_ids_replaced.requires_grad, "Error: auto-diff engine not disabled" with torch.no_grad(): if decoder_input_ids is None: @@ -126,10 +118,6 @@ def __call__( ids_prediction_sorted = torch.argsort(logits_replaced[:, -1, :], descending=True) ids_prediction_top_k = ids_prediction_sorted[:, : self.top_k] - if self.tokenizer: - top_k_words = [[self.tokenizer.decode([token_id]) for token_id in seq] for seq in ids_prediction_top_k] - logging.debug(f"Top K words -> {top_k_words}") - match_mask = ids_prediction_top_k == target_id match_hit = torch.sum(match_mask, dim=-1, dtype=torch.bool) @@ -143,19 +131,7 @@ class DummyStoppingConditionEvaluator(StoppingConditionEvaluator): while top n tokens based on importance_score are not been replaced. 
""" - @override - def __init__(self) -> None: - """Constructor""" - - @override - def __call__( - self, - input_ids: IdsTensor, - target_id: TargetIdsTensor, - importance_score: MultipleScoresPerStepTensor, - decoder_input_ids: Optional[IdsTensor] = None, - attribute_target: bool = False, - ) -> Int64[torch.Tensor, "batch_size"]: + def __call__(self, input_ids: IdsTensor, **kwargs) -> TargetIdsTensor: """Evaluate stop condition Args: @@ -165,10 +141,6 @@ def __call__( attribute_target: whether attribute target for encoder-decoder models Return: - Whether the stop condition achieved [batch] - + Boolean flag per sequence signaling whether the stop condition was reached [batch] """ - match_hit = torch.ones([input_ids.shape[0]], dtype=torch.bool, device=input_ids.device) - - # Stop flags for each sample in the batch - return match_hit + return torch.ones([input_ids.shape[0]], dtype=torch.bool, device=input_ids.device) diff --git a/inseq/attr/feat/ops/reagent_core/token_replacer.py b/inseq/attr/feat/ops/reagent_core/token_replacer.py index 038dc5b6..0d889144 100644 --- a/inseq/attr/feat/ops/reagent_core/token_replacer.py +++ b/inseq/attr/feat/ops/reagent_core/token_replacer.py @@ -1,9 +1,7 @@ import math from abc import ABC, abstractmethod -from typing import Union import torch -from jaxtyping import Int64 from typing_extensions import override from .....utils.typing import IdsTensor @@ -16,12 +14,11 @@ class TokenReplacer(ABC): """ - def __init__(self, token_sampler: TokenSampler) -> None: - """Base Constructor""" - self.token_sampler = token_sampler + def __init__(self, sampler: TokenSampler) -> None: + self.sampler = sampler @abstractmethod - def __call__(self, input: IdsTensor) -> Union[IdsTensor, Int64[torch.Tensor, "batch_size seq_len"]]: + def __call__(self, input: IdsTensor) -> tuple[IdsTensor, IdsTensor]: """Replace tokens according to the specified strategy. Args: @@ -29,7 +26,7 @@ def __call__(self, input: IdsTensor) -> Union[IdsTensor, Int64[torch.Tensor, "ba Returns: input_replaced: A replaced sequence [batch, sequence] - mask_replacing: Identify which token has been replaced [batch, sequence] + replacement_mask: Boolean mask identifying which token has been replaced [batch, sequence] """ raise NotImplementedError() @@ -40,96 +37,75 @@ class RankingTokenReplacer(TokenReplacer): @override def __init__( - self, token_sampler: TokenSampler, top_n: int = 0, top_n_ratio: float = 0, replace_greater: bool = False + self, sampler: TokenSampler, keep_top_n: int = 0, keep_ratio: float = 0, invert_keep: bool = False ) -> None: - """Constructor + """Constructor for the RankingTokenReplacer class. Args: - token_sampler: A TokenSampler for sampling replace token. - top_n: Top N as the threshold. If top_n is 0, use top_n_ratio instead. - top_n_ratio: Use ratio of input to control to top_n - replace_greater: Whether replace top-n. Otherwise, replace the rests. - + sampler: A :class:`~inseq.attr.feat.ops.reagent_core.TokenSampler` object for sampling replacement tokens. + keep_top_n: If set to a value greater than 0, the top n tokens based on their importance score will be + kept, and the rest will be flagged for replacement. If set to 0, the top n will be determined by + ``keep_ratio``. + keep_ratio: If ``keep_top_n`` is set to 0, this specifies the proportion of tokens to keep. + invert_keep: If specified, the top tokens selected either via ``keep_top_n`` or ``keep_ratio`` will be + replaced instead of being kept. 
""" - super().__init__(token_sampler) - - self.top_n = top_n - self.top_n_ratio = top_n_ratio - self.replace_greater = replace_greater + super().__init__(sampler) + self.keep_top_n = keep_top_n + self.keep_ratio = keep_ratio + self.invert_keep = invert_keep def set_score(self, value: torch.Tensor) -> None: pos_sorted = torch.argsort(value, descending=True) - - top_n = self.top_n - - if top_n == 0: - top_n = int(math.ceil(self.top_n_ratio * value.shape[-1])) - + top_n = int(math.ceil(self.keep_ratio * value.shape[-1])) if not self.keep_top_n else self.keep_top_n pos_top_n = pos_sorted[..., :top_n] - - if not self.replace_greater: - self.mask_replacing = torch.ones(value.shape, device=value.device, dtype=torch.bool).scatter( - -1, pos_top_n, 0 - ) - else: - self.mask_replacing = torch.zeros(value.shape, device=value.device, dtype=torch.bool).scatter( - -1, pos_top_n, 1 - ) + self.replacement_mask = torch.ones_like(value, device=value.device, dtype=torch.bool).scatter( + -1, pos_top_n, self.invert_keep + ) @override - def __call__(self, input: IdsTensor) -> Union[IdsTensor, Int64[torch.Tensor, "batch_size seq_len"]]: + def __call__(self, input: IdsTensor) -> tuple[IdsTensor, IdsTensor]: """Sample a sequence Args: - input: input sequence [batch, sequence] + input: Input sequence of ids of shape [batch, sequence] Returns: input_replaced: A replaced sequence [batch, sequence] - mask_replacing: Identify which token has been replaced [batch, sequence] - + replacement_mask: Boolean mask identifying which token has been replaced [batch, sequence] """ - - token_sampled = self.token_sampler(input) - - input_replaced = input * ~self.mask_replacing + token_sampled * self.mask_replacing - - return input_replaced, self.mask_replacing + token_sampled = self.sampler(input) + input_replaced = input * ~self.replacement_mask + token_sampled * self.replacement_mask + return input_replaced, self.replacement_mask class UniformTokenReplacer(TokenReplacer): """Replace tokens in a sequence where selecting is base on uniform distribution""" @override - def __init__(self, token_sampler: TokenSampler, ratio: float) -> None: + def __init__(self, sampler: TokenSampler, ratio: float) -> None: """Constructor Args: - token_sampler: A TokenSampler for sampling replace token. - ratio: replacing ratio - + sampler: A :class:`~inseq.attr.feat.ops.reagent_core.TokenSampler` object for sampling replacement tokens. + ratio: Ratio of tokens to replace in the sequence. 
""" - super().__init__(token_sampler) - + super().__init__(sampler) self.ratio = ratio @override - def __call__(self, input: IdsTensor) -> Union[IdsTensor, Int64[torch.Tensor, "batch_size seq_len"]]: + def __call__(self, input: IdsTensor) -> tuple[IdsTensor, IdsTensor]: """Sample a sequence Args: - input: input sequence [batch, sequence] + input: Input sequence of ids of shape [batch, sequence] Returns: input_replaced: A replaced sequence [batch, sequence] - mask_replacing: Identify which token has been replaced [batch, sequence] - + replacement_mask: Boolean mask identifying which token has been replaced [batch, sequence] """ - sample_uniform = torch.rand(input.shape, device=input.device) - mask_replacing = sample_uniform < self.ratio - - token_sampled = self.token_sampler(input) - - input_replaced = input * ~mask_replacing + token_sampled * mask_replacing - - return input_replaced, mask_replacing + replacement_mask = sample_uniform < self.ratio + token_sampled = self.sampler(input) + input_replaced = input * ~replacement_mask + token_sampled * replacement_mask + return input_replaced, replacement_mask diff --git a/inseq/attr/feat/ops/reagent_core/token_sampler.py b/inseq/attr/feat/ops/reagent_core/token_sampler.py index 35da438a..7ca41bf2 100644 --- a/inseq/attr/feat/ops/reagent_core/token_sampler.py +++ b/inseq/attr/feat/ops/reagent_core/token_sampler.py @@ -26,7 +26,6 @@ def __call__(self, input: IdsTensor, **kwargs) -> IdsTensor: Returns: token_uniform: A sampled tensor where its shape is the same with the input - """ raise NotImplementedError() @@ -101,7 +100,6 @@ def __call__(self, input_ids: IdsTensor) -> IdsTensor: Returns: token_uniform: A sampled tensor where its shape is the same with the input - """ input_ids_pos = self.id2pos[input_ids] sample_uniform = torch.rand(input_ids.shape, device=input_ids.device) diff --git a/pyproject.toml b/pyproject.toml index 81154396..cb7c4781 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,7 +49,6 @@ dependencies = [ "torch>=2.1.1", "matplotlib>=3.5.3", "tqdm>=4.64.0", - "nltk>=3.8.1", "nvidia-cublas-cu11>=11.10.3.66; sys_platform=='Linux'", "nvidia-cuda-cupti-cu11>=11.7.101; sys_platform=='Linux'", "nvidia-cuda-nvrtc-cu11>=11.7.99; sys_platform=='Linux'", @@ -94,6 +93,9 @@ notebook = [ "ipykernel>=6.29.2", "ipywidgets>=8.1.2" ] +nltk = [ + "nltk>=3.8.1", +] [project.urls] homepage = "https://github.com/inseq-team/inseq" diff --git a/requirements-dev.txt b/requirements-dev.txt index 92a9ca95..3f1083e4 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,4 +1,4 @@ -# This file was autogenerated by uv v0.1.2 via the following command: +# This file was autogenerated by uv via the following command: # uv pip compile --all-extras pyproject.toml -o requirements-dev.txt aiohttp==3.9.3 # via @@ -32,6 +32,7 @@ charset-normalizer==3.3.2 # via requests click==8.1.7 # via + # nltk # pydoclint # safety # typer @@ -123,7 +124,9 @@ jinja2==3.1.3 # sphinx # torch joblib==1.3.2 - # via scikit-learn + # via + # nltk + # scikit-learn jupyter-client==8.6.0 # via ipykernel jupyter-core==5.7.1 @@ -160,6 +163,7 @@ nest-asyncio==1.6.0 # via ipykernel networkx==3.2.1 # via torch +nltk==3.8.1 nodeenv==1.8.0 # via pre-commit numpy==1.26.4 @@ -258,7 +262,9 @@ pyzmq==25.1.2 # ipykernel # jupyter-client regex==2023.12.25 - # via transformers + # via + # nltk + # transformers requests==2.31.0 # via # datasets @@ -351,6 +357,7 @@ tqdm==4.66.2 # captum # datasets # huggingface-hub + # nltk # transformers traitlets==5.14.1 # via diff --git 
a/requirements.txt b/requirements.txt index 38d86737..93809632 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,12 +1,10 @@ -# This file was autogenerated by uv v0.1.5 via the following command: +# This file was autogenerated by uv via the following command: # uv pip compile pyproject.toml -o requirements.txt captum==0.7.0 certifi==2024.2.2 # via requests charset-normalizer==3.3.2 # via requests -click==8.1.7 - # via nltk contourpy==1.2.0 # via matplotlib cycler==0.12.1 @@ -31,8 +29,6 @@ idna==3.6 jaxtyping==0.2.25 jinja2==3.1.3 # via torch -joblib==1.3.2 - # via nltk kiwisolver==1.4.5 # via matplotlib markdown-it-py==3.0.0 @@ -47,7 +43,6 @@ mpmath==1.3.0 # via sympy networkx==3.2.1 # via torch -nltk==3.8.1 numpy==1.26.4 # via # captum @@ -75,9 +70,7 @@ pyyaml==6.0.1 # huggingface-hub # transformers regex==2023.12.25 - # via - # nltk - # transformers + # via transformers requests==2.31.0 # via # huggingface-hub @@ -99,7 +92,6 @@ tqdm==4.66.2 # via # captum # huggingface-hub - # nltk # transformers transformers==4.38.1 typeguard==2.13.3 From d96e00804b6220b9516f327160ef58e0499a42d6 Mon Sep 17 00:00:00 2001 From: Gabriele Sarti Date: Thu, 4 Apr 2024 19:38:45 +0200 Subject: [PATCH 13/14] Bump cryptography package (fix safety) --- pyproject.toml | 2 +- requirements-dev.txt | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index cb7c4781..379b8060 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,7 +74,7 @@ docs = [ ] lint = [ "bandit>=1.7.4", - "safety>=2.2.0", + "safety>=3.1.0", "pydoclint>=0.4.0", "pre-commit>=2.19.0", "pytest>=7.2.0", diff --git a/requirements-dev.txt b/requirements-dev.txt index 3f1083e4..2fda0ec1 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -44,7 +44,7 @@ contourpy==1.2.0 # via matplotlib coverage==7.4.1 # via pytest-cov -cryptography==42.0.2 +cryptography==42.0.5 # via authlib cycler==0.12.1 # via matplotlib @@ -286,7 +286,7 @@ ruamel-yaml-clib==0.2.8 ruff==0.2.1 safetensors==0.4.2 # via transformers -safety==3.0.1 +safety==3.1.0 safety-schemas==0.0.2 # via safety scikit-learn==1.4.0 From 1862b840d126fef471ae1981e7025352f327a9ea Mon Sep 17 00:00:00 2001 From: Gabriele Sarti Date: Sat, 13 Apr 2024 11:42:29 +0200 Subject: [PATCH 14/14] Finished revising initial ReAGent implementation --- README.md | 8 +++- .../main_classes/feature_attribution.rst | 23 ++++++++- inseq/attr/feat/__init__.py | 2 + inseq/attr/feat/ops/reagent.py | 36 +++++++------- .../importance_score_evaluator.py | 40 ++++------------ .../feat/ops/reagent_core/rationalizer.py | 29 +++++------- .../stopping_condition_evaluator.py | 20 ++------ inseq/attr/feat/perturbation_attribution.py | 47 +++++++++++++++++-- 8 files changed, 118 insertions(+), 87 deletions(-) diff --git a/README.md b/README.md index 5e09d734..583d8a88 100644 --- a/README.md +++ b/README.md @@ -149,6 +149,8 @@ Use the `inseq.list_feature_attribution_methods` function to list all available - `value_zeroing`: [Quantifying Context Mixing in Transformers](https://aclanthology.org/2023.eacl-main.245/) (Mohebbi et al. 2023) +- `reagent`: [ReAGent: A Model-agnostic Feature Attribution Method for Generative Language Models](https://arxiv.org/abs/2402.00794) (Zhao et al., 2024) + #### Step functions Step functions are used to extract custom scores from the model at each step of the attribution process with the `step_scores` argument in `model.attribute`. They can also be used as targets for attribution methods relying on model outputs (e.g. 
gradient-based methods) by passing them as the `attributed_fn` argument. The following step functions are currently supported: @@ -303,7 +305,10 @@ If you use Inseq in your research we suggest to include a mention to the specifi ## Research using Inseq -Inseq has been used in various research projects. A list of known publications that use Inseq to conduct interpretability analyses of generative models is shown below. If you know more, please let us know or submit a pull request (*last updated: February 2024*). +Inseq has been used in various research projects. A list of known publications that use Inseq to conduct interpretability analyses of generative models is shown below. + +> [!TIP] +> Last update: April 2024. Please open a pull request to add your publication to the list.
2023 @@ -324,6 +329,7 @@ Inseq has been used in various research projects. A list of known publications t
  1. LLMCheckup: Conversational Examination of Large Language Models via Interpretability Tools (Wang et al., 2024)
  2. ReAGent: A Model-agnostic Feature Attribution Method for Generative Language Models (Zhao et al., 2024)
  3. + Revisiting subword tokenization: A case study on affixal negation in large language models (Truong et al., 2024)
diff --git a/docs/source/main_classes/feature_attribution.rst b/docs/source/main_classes/feature_attribution.rst index 1f282626..d7c4f5fc 100644 --- a/docs/source/main_classes/feature_attribution.rst +++ b/docs/source/main_classes/feature_attribution.rst @@ -90,4 +90,25 @@ Perturbation-based Attribution Methods :members: .. autoclass:: inseq.attr.feat.ValueZeroingAttribution - :members: \ No newline at end of file + :members: + +.. autoclass:: inseq.attr.feat.ReagentAttribution + :members: + + .. automethod:: __init__ + +.. code:: python + + import inseq + + model = inseq.load_model( + "gpt2-medium", + "reagent", + keep_top_n=5, + stopping_condition_top_k=3, + replacing_ratio=0.3, + max_probe_steps=3000, + num_probes=8 + ) + out = model.attribute("Super Mario Land is a game that developed by") + out.show() diff --git a/inseq/attr/feat/__init__.py b/inseq/attr/feat/__init__.py index 2b25778a..7a81014f 100644 --- a/inseq/attr/feat/__init__.py +++ b/inseq/attr/feat/__init__.py @@ -18,6 +18,7 @@ LimeAttribution, OcclusionAttribution, PerturbationAttributionRegistry, + ReagentAttribution, ValueZeroingAttribution, ) @@ -43,4 +44,5 @@ "SequentialIntegratedGradientsAttribution", "ValueZeroingAttribution", "PerturbationAttributionRegistry", + "ReagentAttribution", ] diff --git a/inseq/attr/feat/ops/reagent.py b/inseq/attr/feat/ops/reagent.py index 88c5af77..080a2131 100644 --- a/inseq/attr/feat/ops/reagent.py +++ b/inseq/attr/feat/ops/reagent.py @@ -111,20 +111,24 @@ def attribute( # type: ignore tuple[TensorOrTupleOfTensorsGeneric, Tensor], ]: """Implement attribute""" - if len(additional_forward_args) == 8: - # encoder-decoder with target - self.rationalizer(additional_forward_args[0], additional_forward_args[2], additional_forward_args[1], True) - - mean_important_score = torch.unsqueeze(self.rationalizer.mean_important_score, 0) - res = torch.unsqueeze(mean_important_score, 2).repeat(1, 1, inputs[0].shape[2]) - return (res[:, : additional_forward_args[0].shape[1], :], res[:, additional_forward_args[0].shape[1] :, :]) - elif len(additional_forward_args) == 9: - # encoder-decoder - self.rationalizer(additional_forward_args[1], additional_forward_args[3], additional_forward_args[2]) - elif len(additional_forward_args) == 6: - # decoder only - self.rationalizer(additional_forward_args[0], additional_forward_args[1]) - - mean_important_score = torch.unsqueeze(self.rationalizer.mean_important_score, 0) - res = torch.unsqueeze(mean_important_score, 2).repeat(1, 1, inputs[0].shape[2]) + # encoder-decoder + if self.forward_func.is_encoder_decoder: + # with target-side attribution + if len(inputs) > 1: + self.rationalizer( + additional_forward_args[0], additional_forward_args[2], additional_forward_args[1], True + ) + mean_importance_score = torch.unsqueeze(self.rationalizer.mean_importance_score, 0) + res = torch.unsqueeze(mean_importance_score, 2).repeat(1, 1, inputs[0].shape[2]) + return ( + res[:, : additional_forward_args[0].shape[1], :], + res[:, additional_forward_args[0].shape[1] :, :], + ) + # source-side only + else: + self.rationalizer(additional_forward_args[1], additional_forward_args[3], additional_forward_args[2]) + # decoder-only + self.rationalizer(additional_forward_args[0], additional_forward_args[1]) + mean_importance_score = torch.unsqueeze(self.rationalizer.mean_importance_score, 0) + res = torch.unsqueeze(mean_importance_score, 2).repeat(1, 1, inputs[0].shape[2]) return (res,) diff --git a/inseq/attr/feat/ops/reagent_core/importance_score_evaluator.py 
b/inseq/attr/feat/ops/reagent_core/importance_score_evaluator.py index a74a161b..adc79b7b 100644 --- a/inseq/attr/feat/ops/reagent_core/importance_score_evaluator.py +++ b/inseq/attr/feat/ops/reagent_core/importance_score_evaluator.py @@ -23,13 +23,10 @@ def __init__(self, model: AutoModelForCausalLM | AutoModelForSeq2SeqLM, tokenize Args: model: A Huggingface AutoModelForCausalLM or AutoModelForSeq2SeqLM model tokenizer: A Huggingface AutoTokenizer - """ - self.model = model self.tokenizer = tokenizer - - self.important_score = None + self.importance_score = None @abstractmethod def __call__( @@ -73,16 +70,12 @@ def __init__( tokenizer: A Huggingface AutoTokenizer token_replacer: A TokenReplacer stopping_condition_evaluator: A StoppingConditionEvaluator - """ - super().__init__(model, tokenizer) - self.token_replacer = token_replacer self.stopping_condition_evaluator = stopping_condition_evaluator self.max_steps = max_steps - - self.important_score = None + self.importance_score = None self.num_steps = 0 def update_importance_score( @@ -106,10 +99,8 @@ def update_importance_score( Return: logit_importance_score: updated importance score in logistic scale [batch, sequence] - """ # Randomly replace a set of tokens R to form a new sequence \hat{y_{1...t}} - if not attribute_target: input_ids_replaced, mask_replacing = self.token_replacer(input_ids) else: @@ -123,32 +114,23 @@ def update_importance_score( ) # Inference \hat{p^{(y)}} = p(y_{t+1}|\hat{y_{1...t}}) - - if decoder_input_ids is None: - logits_replaced = self.model(input_ids_replaced)["logits"] - elif not attribute_target: - logits_replaced = self.model(input_ids=input_ids_replaced, decoder_input_ids=decoder_input_ids)["logits"] - else: - logits_replaced = self.model(input_ids=input_ids_replaced, decoder_input_ids=decoder_input_ids_replaced)[ - "logits" - ] - + kwargs = {"input_ids": input_ids_replaced} + if decoder_input_ids is not None: + kwargs["decoder_input_ids"] = decoder_input_ids_replaced if attribute_target else decoder_input_ids + logits_replaced = self.model(**kwargs)["logits"] prob_replaced_target = torch.softmax(logits_replaced[:, -1, :], -1)[:, target_id] # Compute changes delta = p^{(y)} - \hat{p^{(y)}} - delta_prob_target = prob_original_target - prob_replaced_target logging.debug(f"likelihood delta: { delta_prob_target }") # Update importance scores based on delta (magnitude) and replacement (direction) - delta_score = mask_replacing * delta_prob_target + ~mask_replacing * -delta_prob_target # TODO: better solution? 
# Rescaling from [-1, 1] to [0, 1] before logit function logit_delta_score = torch.logit(delta_score * 0.5 + 0.5) logit_importance_score = logit_importance_score + logit_delta_score logging.debug(f"Updated importance score: { torch.softmax(logit_importance_score, -1) }") - return logit_importance_score @override @@ -169,9 +151,7 @@ def __call__( Return: importance_score: evaluated importance score for each token in the input [batch, sequence] - """ - self.stop_mask = torch.zeros([input_ids.shape[0]], dtype=torch.bool, device=input_ids.device) # Inference p^{(y)} = p(y_{t+1}|y_{1...t}) @@ -183,7 +163,6 @@ def __call__( prob_original_target = torch.softmax(logits_original[:, -1, :], -1)[:, target_id] # Initialize importance score s for each token in the sequence y_{1...t} - if not attribute_target: logit_importance_score = torch.rand(input_ids.shape, device=input_ids.device) else: @@ -196,7 +175,6 @@ def __call__( self.num_steps = 0 while self.num_steps < self.max_steps: self.num_steps += 1 - # Update importance score logit_importance_score_update = self.update_importance_score( logit_importance_score, input_ids, target_id, prob_original_target, decoder_input_ids, attribute_target @@ -205,16 +183,14 @@ def __call__( ~torch.unsqueeze(self.stop_mask, 1) * logit_importance_score_update + torch.unsqueeze(self.stop_mask, 1) * logit_importance_score ) - - self.important_score = torch.softmax(logit_importance_score, -1) + self.importance_score = torch.softmax(logit_importance_score, -1) # Evaluate stop condition self.stop_mask = self.stop_mask | self.stopping_condition_evaluator( - input_ids, target_id, self.important_score, decoder_input_ids, attribute_target + input_ids, target_id, self.importance_score, decoder_input_ids, attribute_target ) if torch.prod(self.stop_mask) > 0: break logging.info(f"Importance score evaluated in {self.num_steps} steps.") - return torch.softmax(logit_importance_score, -1) diff --git a/inseq/attr/feat/ops/reagent_core/rationalizer.py b/inseq/attr/feat/ops/reagent_core/rationalizer.py index 7fff9d25..ab7c2be8 100644 --- a/inseq/attr/feat/ops/reagent_core/rationalizer.py +++ b/inseq/attr/feat/ops/reagent_core/rationalizer.py @@ -14,7 +14,7 @@ class BaseRationalizer(ABC): def __init__(self, importance_score_evaluator: BaseImportanceScoreEvaluator) -> None: super().__init__() self.importance_score_evaluator = importance_score_evaluator - self.mean_important_score = None + self.mean_importance_score = None @abstractmethod def __call__( @@ -59,9 +59,10 @@ def __init__( batch_size: Batch size for aggregate overlap_threshold: Overlap threshold of rational tokens within a batch overlap_strict_pos: Whether overlap strict to position ot not - keep_top_n: Rational size - keep_ratio: Use ratio of sequence to define rational size - + keep_top_n: If set to a value greater than 0, the top n tokens based on their importance score will be + kept, and the rest will be flagged for replacement. If set to 0, the top n will be determined by + ``keep_ratio``. + keep_ratio: If ``keep_top_n`` is set to 0, this specifies the proportion of tokens to keep. 
""" super().__init__(importance_score_evaluator) self.batch_size = batch_size @@ -69,7 +70,7 @@ def __init__( self.overlap_strict_pos = overlap_strict_pos self.keep_top_n = keep_top_n self.keep_ratio = keep_ratio - assert overlap_strict_pos, "overlap_strict_pos = False not been supported yet" + assert overlap_strict_pos, "overlap_strict_pos = False is not supported yet" @override @torch.no_grad() @@ -83,9 +84,10 @@ def __call__( """Compute rational of a sequence on a target Args: - input_ids: The sequence [batch, sequence] (first dimension need to be 1) - target_id: The target [batch] - decoder_input_ids (optional): decoder input sequence for AutoModelForSeq2SeqLM [batch, sequence] + input_ids: A tensor of ids of shape [batch, sequence_len] + target_id: A tensor of predicted targets of size [batch] + decoder_input_ids (optional): A tensor of ids representing the decoder input sequence for + ``AutoModelForSeq2SeqLM``, with shape [batch, sequence_len] attribute_target: whether attribute target for encoder-decoder models Return: @@ -93,29 +95,23 @@ def __call__( """ assert input_ids.shape[0] == 1, "the first dimension of input (batch_size) need to be 1" - batch_input_ids = input_ids.repeat(self.batch_size, 1) batch_decoder_input_ids = ( decoder_input_ids.repeat(self.batch_size, 1) if decoder_input_ids is not None else None ) - batch_importance_score = self.importance_score_evaluator( batch_input_ids, target_id, batch_decoder_input_ids, attribute_target ) - - important_score_masked = batch_importance_score * torch.unsqueeze( + importance_score_masked = batch_importance_score * torch.unsqueeze( self.importance_score_evaluator.stop_mask, -1 ) - self.mean_important_score = torch.sum(important_score_masked, dim=0) / torch.sum( + self.mean_importance_score = torch.sum(importance_score_masked, dim=0) / torch.sum( self.importance_score_evaluator.stop_mask ) - pos_sorted = torch.argsort(batch_importance_score, dim=-1, descending=True) - top_n = int(math.ceil(self.keep_ratio * input_ids.shape[-1])) if not self.keep_top_n else self.keep_top_n pos_top_n = pos_sorted[:, :top_n] self.pos_top_n = pos_top_n - if self.overlap_strict_pos: count_overlap = torch.bincount(pos_top_n.flatten(), minlength=input_ids.shape[1]) pos_top_n_overlap = torch.unsqueeze( @@ -124,7 +120,6 @@ def __call__( return pos_top_n_overlap else: raise NotImplementedError("overlap_strict_pos = False not been supported yet") - # TODO: Convert back to pos # token_id_top_n = input_ids[0, pos_top_n] # count_overlap = torch.bincount(token_id_top_n.flatten(), minlength=input_ids.shape[1]) diff --git a/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator.py b/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator.py index c73ef764..fd3bb67d 100644 --- a/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator.py +++ b/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator.py @@ -56,7 +56,7 @@ def __init__( """Constructor for the TopKStoppingConditionEvaluator class. Args: - model: A Huggingface AutoModelForCausalLM. + model: A Huggingface ``AutoModelForCausalLM``. sampler: A :class:`~inseq.attr.feat.ops.reagent_core.TokenSampler` object to sample replacement tokens. top_k: Top K predictions in which the target must be included in order to achieve the stopping condition. 
keep_top_n: If set to a value greater than 0, the top n tokens based on their importance score will be @@ -104,24 +104,14 @@ def __call__( # Whether the result \hat{y^{(e)}_{t+1}} consistent with y_{t+1} assert not input_ids_replaced.requires_grad, "Error: auto-diff engine not disabled" with torch.no_grad(): - if decoder_input_ids is None: - logits_replaced = self.model(input_ids_replaced)["logits"] - elif not attribute_target: - logits_replaced = self.model(input_ids=input_ids_replaced, decoder_input_ids=decoder_input_ids)[ - "logits" - ] - else: - logits_replaced = self.model( - input_ids=input_ids_replaced, decoder_input_ids=decoder_input_ids_replaced - )["logits"] - + kwargs = {"input_ids": input_ids_replaced} + if decoder_input_ids is not None: + kwargs["decoder_input_ids"] = decoder_input_ids_replaced if attribute_target else decoder_input_ids + logits_replaced = self.model(**kwargs)["logits"] ids_prediction_sorted = torch.argsort(logits_replaced[:, -1, :], descending=True) ids_prediction_top_k = ids_prediction_sorted[:, : self.top_k] - match_mask = ids_prediction_top_k == target_id match_hit = torch.sum(match_mask, dim=-1, dtype=torch.bool) - - # Stop flags for each sample in the batch return match_hit diff --git a/inseq/attr/feat/perturbation_attribution.py b/inseq/attr/feat/perturbation_attribution.py index 47dd8767..498093af 100644 --- a/inseq/attr/feat/perturbation_attribution.py +++ b/inseq/attr/feat/perturbation_attribution.py @@ -1,5 +1,5 @@ import logging -from typing import Any +from typing import TYPE_CHECKING, Any from captum.attr import Occlusion @@ -13,6 +13,9 @@ from .gradient_attribution import FeatureAttribution from .ops import Lime, Reagent, ValueZeroing +if TYPE_CHECKING: + from ...models import HuggingfaceModel + logger = logging.getLogger(__name__) @@ -127,15 +130,49 @@ class ReagentAttribution(PerturbationAttributionRegistry): alternative predicted by a LM. Reference implementation: - `ReAGent: A Model-agnostic Feature Attribution Method for Generative Language Models - `__ + `ReAGent: A Model-agnostic Feature Attribution Method for Generative Language Models `__ """ method_name = "reagent" - def __init__(self, attribution_model, **kwargs): + def __init__( + self, + attribution_model: "HuggingfaceModel", + keep_top_n: int = 5, + keep_ratio: float = None, + invert_keep: bool = False, + stopping_condition_top_k: int = 3, + replacing_ratio: float = 0.3, + max_probe_steps: int = 3000, + num_probes: int = 16, + ): + """ReAGent method constructor. + + Args: + keep_top_n (:obj:`int`, `optional`): If set to a value greater than 0, the top n tokens based on their importance score will be + kept during inference. If set to 0, the top n will be determined by ``keep_ratio``. Default: ``5``. + keep_ratio (:obj:`float`, `optional`): If ``keep_top_n`` is set to 0, this specifies the proportion of tokens to keep. + invert_keep (:obj:`bool`, `optional`): If specified, the top tokens selected either via ``keep_top_n`` or ``keep_ratio`` will be + replaced instead of being kept. Default: ``False``. + stopping_condition_top_k (:obj:`int`, `optional`): Threshold for the stopping condition: probing stops once the predicted target + appears among the top k predictions. Default: ``3``. + replacing_ratio (:obj:`float`, `optional`): Ratio of tokens replaced at each probing step. Default: ``0.3``. + max_probe_steps (:obj:`int`, `optional`): Max number of steps before stopping the probing. Default: ``3000``. + num_probes (:obj:`int`, `optional`): Number of probes performed in parallel.
Default: ``16``. + """ super().__init__(attribution_model) - self.method = Reagent(attribution_model=self.attribution_model, **kwargs) + # Custom target attribution is currently not supported + self.use_predicted_target = False + self.method = Reagent( + attribution_model=self.attribution_model, + keep_top_n=keep_top_n, + keep_ratio=keep_ratio, + invert_keep=invert_keep, + stopping_condition_top_k=stopping_condition_top_k, + replacing_ratio=replacing_ratio, + max_probe_steps=max_probe_steps, + num_probes=num_probes, + ) def attribute_step( self,
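A usage sketch for the new constructor signature, assuming ``inseq.load_model`` forwards the extra keyword arguments to the method constructor (model name, prompt, and hyperparameter values are illustrative only):

    import inseq

    # Load a decoder-only model with ReAGent as the attribution method
    model = inseq.load_model(
        "gpt2",
        "reagent",
        keep_top_n=5,
        stopping_condition_top_k=3,
        replacing_ratio=0.3,
        max_probe_steps=3000,
        num_probes=16,
    )
    # Attribute a generation and visualize the importance scores
    out = model.attribute("The first president of the United States was")
    out.show()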