From f2b9d9204a1d8bf28a41853845e7701057a11d99 Mon Sep 17 00:00:00 2001 From: Gabriele Sarti Date: Thu, 25 Apr 2024 21:56:16 +0200 Subject: [PATCH] Support Value Zeroing for non-eager attention types (#267) --- inseq/attr/feat/ops/value_zeroing.py | 4 ++-- inseq/models/huggingface_model.py | 3 --- inseq/utils/__init__.py | 3 +-- inseq/utils/hooks.py | 7 +++---- 4 files changed, 6 insertions(+), 11 deletions(-) diff --git a/inseq/attr/feat/ops/value_zeroing.py b/inseq/attr/feat/ops/value_zeroing.py index c2afdbc7..ed95eb12 100644 --- a/inseq/attr/feat/ops/value_zeroing.py +++ b/inseq/attr/feat/ops/value_zeroing.py @@ -14,6 +14,7 @@ import logging from enum import Enum +from types import FrameType from typing import TYPE_CHECKING, Callable, Optional import torch @@ -22,7 +23,6 @@ from torch.utils.hooks import RemovableHandle from ....utils import ( - StackFrame, find_block_stack, get_post_variable_assignment_hook, recursive_get_submodule, @@ -100,7 +100,7 @@ def get_value_zeroing_hook(varname: str = "value") -> Callable[..., None]: """ def value_zeroing_forward_mid_hook( - frame: StackFrame, + frame: FrameType, zeroed_token_index: Optional[int] = None, zeroed_units_indices: Optional[OneOrMoreIndices] = None, batch_size: int = 1, diff --git a/inseq/models/huggingface_model.py b/inseq/models/huggingface_model.py index d6cc3f6b..3ea8bc59 100644 --- a/inseq/models/huggingface_model.py +++ b/inseq/models/huggingface_model.py @@ -95,9 +95,6 @@ def __init__( if isinstance(model, PreTrainedModel): self.model = model else: - if "output_attentions" not in model_kwargs: - model_kwargs["output_attentions"] = True - self.model = self._autoclass.from_pretrained(model, **model_kwargs) self.model_name = self.model.config.name_or_path self.tokenizer_name = tokenizer if isinstance(tokenizer, str) else None diff --git a/inseq/utils/__init__.py b/inseq/utils/__init__.py index 9eb39aba..92a763cf 100644 --- a/inseq/utils/__init__.py +++ b/inseq/utils/__init__.py @@ -8,7 +8,7 @@ MissingAttributionMethodError, UnknownAttributionMethodError, ) -from .hooks import StackFrame, get_post_variable_assignment_hook +from .hooks import get_post_variable_assignment_hook from .import_utils import ( is_accelerate_available, is_captum_available, @@ -127,7 +127,6 @@ "filter_logits", "cli_arg", "get_post_variable_assignment_hook", - "StackFrame", "validate_indices", "pad_with_nan", "recursive_get_submodule", diff --git a/inseq/utils/hooks.py b/inseq/utils/hooks.py index 02472f4e..98fd07e5 100644 --- a/inseq/utils/hooks.py +++ b/inseq/utils/hooks.py @@ -1,14 +1,13 @@ import re from inspect import getsourcelines from sys import gettrace, settrace -from typing import Callable, Optional, TypeVar +from types import FrameType +from typing import Callable, Optional from torch import nn from .misc import get_left_padding -StackFrame = TypeVar("StackFrame") - def get_last_variable_assignment_position( module: nn.Module, @@ -57,7 +56,7 @@ def get_post_variable_assignment_hook( module: nn.Module, varname: str, fname: str = "forward", - hook_fn: Callable[[StackFrame], None] = lambda **kwargs: None, + hook_fn: Callable[[FrameType], None] = lambda **kwargs: None, **kwargs, ) -> Callable[[], None]: """Creates a hook that is called after the last variable assignment in the specified method of a `nn.Module`.