Skip to content

Commit

Permalink
VZ working for GPT-2, including last layer
Browse files Browse the repository at this point in the history
  • Loading branch information
gsarti committed Apr 20, 2023
1 parent 009bb6d commit 1637a73
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 23 deletions.
29 changes: 20 additions & 9 deletions inseq/attr/feat/ops/value_zeroing.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def attribute(
layers is int or None), or a custom function defined by the user.
metric (:obj:`str`, optional): The similarity metric to use for computing the distance between hidden
states produced with and without the zeroing operation. Default: cosine.
encoder_hidden_states (:obj:`torch.Tensor`, optional): A tensor of shape ``[batch_size, num_layers,
encoder_hidden_states (:obj:`torch.Tensor`, optional): A tensor of shape ``[batch_size, num_layers + 1,
source_seq_len, hidden_size]`` containing hidden states of the encoder. Available only for
encoder-decoders models. Default: None.
decoder_hidden_states (:obj:`torch.Tensor`, optional): A tensor of shape ``[batch_size, num_layers + 1,
Expand Down Expand Up @@ -162,35 +162,46 @@ def attribute(
# Hooks:
# 1. states_extract_and_patch_hook on the transformer block stores corrupted states and force clean states
# as the output of the block forward pass, i.e. the zeroing is done independently across layers.
# 2. value_zeroing_block_hook on the attention module performs the value zeroing by dynamically replacing
# the intermediate "value" tensor in the forward (name is config-dependent) with a zeroed version for the
# specified token index.
# 2. value_zeroing_hook on the attention module performs the value zeroing by replacing the "value" tensor
# during the forward (name is config-dependent) with a zeroed version for the specified token index.
#
# State extraction hooks can be registered only once since they are token-independent
# Skip last block since its states are not used raw, but may have further transformations applied to them
# (e.g. LayerNorm, Dropout). These are extracted separately from the model outputs.
states_extraction_hook_handles = []
for block_idx, block in enumerate(decoder_stack):
for block_idx in range(len(decoder_stack) - 1):
states_extract_and_patch_hook = self.get_states_extract_and_patch_hook(block_idx, hidden_state_idx=0)
states_extraction_hook_handles.append(block.register_forward_hook(states_extract_and_patch_hook))
states_extraction_hook_handles.append(
decoder_stack[block_idx].register_forward_hook(states_extract_and_patch_hook)
)

# Zeroing is done for every token in the target sequence separately (O(n) complexity)
for token_idx in range(tgt_seq_len):
value_zeroing_hook_handles = []
# Value zeroing hooks are registered for every token separately since they are token-dependent
for block in decoder_stack:
attention_module = block.get_submodule(self.forward_func.config.attention_module)
value_zeroing_block_hook = get_post_variable_assignment_hook(
value_zeroing_hook = get_post_variable_assignment_hook(
attention_module,
hook_fn=self.get_value_zeroing_hook(self.forward_func.config.value_vector),
varname=self.forward_func.config.value_vector,
value_zeroing_index=token_idx,
)
value_zeroing_hook_handle = attention_module.register_forward_pre_hook(value_zeroing_block_hook)
value_zeroing_hook_handle = attention_module.register_forward_pre_hook(value_zeroing_hook)
value_zeroing_hook_handles.append(value_zeroing_hook_handle)

# Run forward pass with hooks. Fills self.corrupted_hidden_states with corrupted states across layers
# when zeroing the specified token index.
with torch.no_grad():
self.forward_func(*inputs, *additional_forward_args)
output = self.forward_func.forward_with_output(
*inputs, *additional_forward_args, output_hidden_states=True
)
# Extract last layer states directly from the model outputs
corrupted_states_dict = self.forward_func.get_hidden_states_dict(output)
corrupted_decoder_last_hidden_state = (
corrupted_states_dict["decoder_hidden_states"][:, -1, ...].clone().detach().cpu()
)
self.corrupted_block_output_states[len(decoder_stack) - 1] = corrupted_decoder_last_hidden_state
for handle in value_zeroing_hook_handles:
handle.remove()
for block_idx in range(len(decoder_stack)):
Expand Down
23 changes: 19 additions & 4 deletions inseq/models/attribution_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from .model_decorators import unhooked

ModelOutput = TypeVar("ModelOutput")
CustomForwardOutput = TypeVar("CustomForwardOutput")


logger = logging.getLogger(__name__)
Expand All @@ -55,7 +56,7 @@ def __call__(
use_embeddings: bool,
attributed_fn_argnames: Optional[List[str]],
*args,
) -> LogitsTensor:
) -> CustomForwardOutput:
...


Expand Down Expand Up @@ -109,9 +110,9 @@ def enrich_step_output(
raise NotImplementedError()

@staticmethod
def format_forward_args(forward: ForwardMethod) -> Callable[..., LogitsTensor]:
def format_forward_args(forward: ForwardMethod) -> Callable[..., CustomForwardOutput]:
@wraps(forward)
def formatted_forward_input_wrapper(self, *args, **kwargs):
def formatted_forward_input_wrapper(self, *args, **kwargs) -> CustomForwardOutput:
raise NotImplementedError()

return formatted_forward_input_wrapper
Expand Down Expand Up @@ -568,10 +569,11 @@ def _forward(
use_embeddings: bool = True,
attributed_fn_argnames: Optional[List[str]] = None,
*args,
**kwargs,
) -> LogitsTensor:
assert len(args) == len(attributed_fn_argnames), "Number of arguments and number of argnames must match"
target_ids = target_ids.squeeze(-1)
output = self.get_forward_output(batch, use_embeddings=use_embeddings)
output = self.get_forward_output(batch, use_embeddings=use_embeddings, **kwargs)
logger.debug(f"logits: {pretty_tensor(output.logits)}")
step_function_args = self.formatter.format_step_function_args(
attribution_model=self,
Expand All @@ -582,6 +584,19 @@ def _forward(
)
return attributed_fn(**step_function_args)

def _forward_with_output(
self,
batch: Union[DecoderOnlyBatch, EncoderDecoderBatch],
use_embeddings: bool = True,
*args,
**kwargs,
) -> ModelOutput:
return self.get_forward_output(batch, use_embeddings=use_embeddings, **kwargs)

@formatter.format_forward_args
def forward(self, *args, **kwargs) -> LogitsTensor:
return self._forward(*args, **kwargs)

@formatter.format_forward_args
def forward_with_output(self, *args, **kwargs) -> ModelOutput:
return self._forward_with_output(*args, **kwargs)
19 changes: 14 additions & 5 deletions inseq/models/decoder_only.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union

import torch

Expand All @@ -27,6 +27,8 @@
)
from .attribution_model import AttributionModel, ForwardMethod, InputFormatter, ModelOutput

CustomForwardOutput = TypeVar("CustomForwardOutput")

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -165,8 +167,8 @@ def convert_args_to_batch(
return DecoderOnlyBatch(encoding, embedding)

@staticmethod
def format_forward_args(forward: ForwardMethod) -> Callable[..., LogitsTensor]:
@wraps(forward)
def format_forward_args(forward_fn: ForwardMethod) -> Callable[..., CustomForwardOutput]:
@wraps(forward_fn)
def formatted_forward_input_wrapper(
self: "DecoderOnlyAttributionModel",
forward_tensor: AttributionForwardInputs,
Expand All @@ -177,13 +179,16 @@ def formatted_forward_input_wrapper(
use_embeddings: bool = True,
attributed_fn_argnames: Optional[List[str]] = None,
*args,
) -> LogitsTensor:
**kwargs,
) -> CustomForwardOutput:
batch = self.formatter.convert_args_to_batch(
input_ids=input_ids,
attention_mask=attention_mask,
input_embeds=forward_tensor if use_embeddings else None,
)
return self._forward(batch, target_ids, attributed_fn, use_embeddings, attributed_fn_argnames, *args)
return forward_fn(
self, batch, target_ids, attributed_fn, use_embeddings, attributed_fn_argnames, *args, **kwargs
)

return formatted_forward_input_wrapper

Expand Down Expand Up @@ -217,6 +222,10 @@ def get_forward_output(
def forward(self, *args, **kwargs) -> LogitsTensor:
return self._forward(*args, **kwargs)

@formatter.format_forward_args
def forward_with_output(self, *args, **kwargs) -> ModelOutput:
return self._forward_with_output(*args, **kwargs)

def get_encoder(self) -> torch.nn.Module:
raise NotImplementedError("Decoder-only models do not have an encoder.")

Expand Down
19 changes: 14 additions & 5 deletions inseq/models/encoder_decoder.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union

from ..attr.feat import join_token_ids
from ..data import (
Expand All @@ -25,6 +25,8 @@
)
from .attribution_model import AttributionModel, ForwardMethod, InputFormatter, ModelOutput

CustomForwardOutput = TypeVar("CustomForwardOutput")

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -229,8 +231,8 @@ def convert_args_to_batch(
return EncoderDecoderBatch(source_batch, target_batch)

@staticmethod
def format_forward_args(forward: ForwardMethod) -> Callable[..., LogitsTensor]:
@wraps(forward)
def format_forward_args(forward_fn: ForwardMethod) -> Callable[..., CustomForwardOutput]:
@wraps(forward_fn)
def formatted_forward_input_wrapper(
self: "EncoderDecoderAttributionModel",
encoder_tensors: AttributionForwardInputs,
Expand All @@ -244,7 +246,8 @@ def formatted_forward_input_wrapper(
use_embeddings: bool = True,
attributed_fn_argnames: Optional[List[str]] = None,
*args,
) -> LogitsTensor:
**kwargs,
) -> CustomForwardOutput:
batch = self.formatter.convert_args_to_batch(
encoder_input_ids=encoder_input_ids,
decoder_input_ids=decoder_input_ids,
Expand All @@ -253,7 +256,9 @@ def formatted_forward_input_wrapper(
encoder_input_embeds=encoder_tensors if use_embeddings else None,
decoder_input_embeds=decoder_input_embeds,
)
return self._forward(batch, target_ids, attributed_fn, use_embeddings, attributed_fn_argnames, *args)
return forward_fn(
self, batch, target_ids, attributed_fn, use_embeddings, attributed_fn_argnames, *args, **kwargs
)

return formatted_forward_input_wrapper

Expand Down Expand Up @@ -288,3 +293,7 @@ def get_forward_output(
@formatter.format_forward_args
def forward(self, *args, **kwargs) -> LogitsTensor:
return self._forward(*args, **kwargs)

@formatter.format_forward_args
def forward_with_output(self, *args, **kwargs) -> ModelOutput:
return self._forward_with_output(*args, **kwargs)

0 comments on commit 1637a73

Please sign in to comment.