Skip to content

Commit

Permalink
FIX: Fixes unexpected behaviour for Llava / LLama & AWQ Fused modules…
Browse files Browse the repository at this point in the history
… + revert #30070 at the same time (#30317)

* Update awq.py

* style

* revert felix PR

* fix

* add felix comments
  • Loading branch information
younesbelkada committed Apr 18, 2024
1 parent 005b957 commit 5728b5a
Show file tree
Hide file tree
Showing 7 changed files with 230 additions and 75 deletions.
27 changes: 25 additions & 2 deletions src/transformers/integrations/awq.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,8 @@ def fuse_awq_modules(model, quantization_config):
else:
raise ValueError("Fusing is only supported for the AutoAWQ backend")

fused_attention_modules = []

for name, module in model.named_modules():
if modules_to_not_convert is not None:
if any(module_name_to_not_convert in name for module_name_to_not_convert in modules_to_not_convert):
Expand All @@ -241,7 +243,23 @@ def fuse_awq_modules(model, quantization_config):
_fuse_awq_mlp(model, name, modules_to_fuse["mlp"], module, QuantFusedMLP)

# Replace attention layers
_fuse_awq_attention_layers(model, module, modules_to_fuse, name, QuantAttentionFused)
attention_has_been_fused = _fuse_awq_attention_layers(
model, module, modules_to_fuse, name, QuantAttentionFused
)

if attention_has_been_fused:
fused_attention_modules.append(name.split(".")[0])

# For AWQ fused + Llama we need to set `config._attn_implementation` = "custom" to avoid unexpected behavior and pass
# `None` attention mask to the fused attention modules as now the attention mask is dropped by our models and dealt
# by the `AttentionMaskConverter` module.
if len(fused_attention_modules) > 0:
for module_name, module in model.named_modules():
if any(
module_name in fused_attention_modules for fused_attention_parent_module in fused_attention_modules
):
if hasattr(module, "config") and hasattr(module.config, "_attn_implementation"):
module.config._attn_implementation = "custom"
return model


Expand Down Expand Up @@ -332,8 +350,10 @@ def _fuse_awq_attention_layers(model, module, modules_to_fuse, current_module_na
"""
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV

module_has_been_fused = False

if len(modules_to_fuse["attention"]) == 0:
return
return module_has_been_fused

if hasattr(module, modules_to_fuse["attention"][0]):
# First, we pack the QKV layers together
Expand Down Expand Up @@ -394,6 +414,9 @@ def _fuse_awq_attention_layers(model, module, modules_to_fuse, current_module_na
setattr(parent, child_name, fused_attention_layer.to(previous_device))

del q_proj, k_proj, v_proj, o_proj
module_has_been_fused = True

return module_has_been_fused


def post_init_awq_exllama_modules(model, exllama_config):
Expand Down
99 changes: 63 additions & 36 deletions src/transformers/modeling_attn_mask_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,63 @@ def _unmask_unattended(

return expanded_mask.mul(~torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True))

@staticmethod
def _ignore_causal_mask_sdpa(
attention_mask: Optional[torch.Tensor],
inputs_embeds: torch.Tensor,
past_key_values_length: int,
sliding_window: Optional[int] = None,
) -> bool:
"""
Detects whether the optional user-specified attention_mask & the automatically created causal mask can be ignored in case PyTorch's SDPA is used, rather relying on SDPA's `is_causal` argument.
In case no token is masked in the `attention_mask` argument, if `query_length == 1` or
`key_value_length == query_length`, we rather rely on SDPA `is_causal` argument to use causal/non-causal masks,
allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed).
"""

batch_size, query_length = inputs_embeds.shape[0], inputs_embeds.shape[1]
key_value_length = query_length + past_key_values_length

is_tracing = (
torch.jit.is_tracing()
or isinstance(inputs_embeds, torch.fx.Proxy)
or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
)

ignore_causal_mask = False

if attention_mask is None:
# TODO: When tracing with TorchDynamo with fullgraph=True, the model is recompiled depending on the input shape, thus SDPA's `is_causal` argument is rightfully updated (see https://gist.github.com/fxmarty/1313f39037fc1c112508989628c57363). However, when using `torch.export` or
# or `torch.onnx.dynamo_export`, we must pass an example input, and `is_causal` behavior is hard-coded. If a user exports a model with q_len > 1, the exported model will hard-code `is_causal=True` which is in general wrong (see https://github.com/pytorch/pytorch/issues/108108).
# Thus, we currently can NOT set `ignore_causal_mask = True` here. We would need a `torch._dynamo.is_exporting()` flag.
#
# Besides, jit.trace can not handle the `q_len > 1` condition for `is_causal` (`TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor`).
if (
not is_tracing
and (query_length == 1 or key_value_length == query_length)
and (sliding_window is None or key_value_length < sliding_window)
):
ignore_causal_mask = True
elif sliding_window is None or key_value_length < sliding_window:
if len(attention_mask.shape) == 4:
expected_shape = (batch_size, 1, query_length, key_value_length)
if tuple(attention_mask.shape) != expected_shape:
raise ValueError(
f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
)
elif not is_tracing and torch.all(attention_mask == 1):
if query_length == 1 or key_value_length == query_length:
# For query_length == 1, causal attention and bi-directional attention are the same.
ignore_causal_mask = True

# Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation
# may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
# Reference: https://github.com/pytorch/pytorch/issues/108108
# TODO: maybe revisit this with https://github.com/pytorch/pytorch/pull/114823 in PyTorch 2.3.

return ignore_causal_mask


def _prepare_4d_causal_attention_mask(
attention_mask: Optional[torch.Tensor],
Expand Down Expand Up @@ -305,7 +362,6 @@ def _prepare_4d_causal_attention_mask_for_sdpa(
attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)

key_value_length = input_shape[-1] + past_key_values_length
_, query_length = input_shape

# torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
# used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
Expand All @@ -316,41 +372,12 @@ def _prepare_4d_causal_attention_mask_for_sdpa(
or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
)

ignore_causal_mask = False

if attention_mask is None:
if (
not is_tracing
and (query_length == 1 or key_value_length == query_length)
and (sliding_window is None or key_value_length < sliding_window)
):
ignore_causal_mask = True
elif sliding_window is None or key_value_length < sliding_window:
# 4d mask is passed through
if len(attention_mask.shape) == 4:
expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
if tuple(attention_mask.shape) != expected_shape:
raise ValueError(
f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
)
else:
# if the 4D mask has correct shape - invert it and fill with negative infinity
inverted_mask = 1.0 - attention_mask.to(inputs_embeds.dtype)
attention_mask = inverted_mask.masked_fill(
inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
)
return attention_mask

elif not is_tracing and torch.all(attention_mask == 1):
if query_length == 1:
# For query_length == 1, causal attention and bi-directional attention are the same.
ignore_causal_mask = True
elif key_value_length == query_length:
ignore_causal_mask = True

# Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation
# may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
# Reference: https://github.com/pytorch/pytorch/issues/108108
ignore_causal_mask = AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length,
sliding_window=sliding_window,
)

if ignore_causal_mask:
expanded_4d_mask = None
Expand Down
36 changes: 27 additions & 9 deletions src/transformers/models/cohere/modeling_cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,12 +590,15 @@ def forward(
key_states = key_states.contiguous()
value_states = value_states.contiguous()

# In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash Attention 2 backend, rather
# relying on the `is_causal` argument.
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=causal_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
is_causal=causal_mask is None and q_len > 1,
)

attn_output = attn_output.transpose(1, 2).contiguous()
Expand Down Expand Up @@ -908,9 +911,7 @@ def forward(
if position_ids is None:
position_ids = cache_position.unsqueeze(0)

causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_seen_tokens + inputs_embeds.shape[1]
)
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens)

# embed positions
hidden_states = inputs_embeds
Expand Down Expand Up @@ -974,24 +975,41 @@ def forward(
attentions=all_self_attns,
)

# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
def _update_causal_mask(self, attention_mask, input_tensor, cache_position, current_length):
def _update_causal_mask(
self,
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_seen_tokens: int,
):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114

if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None

if self.config._attn_implementation == "sdpa":
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument,
# in order to dispatch on Flash Attention 2.
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens
):
return None

dtype, device = input_tensor.dtype, input_tensor.device
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"): # static cache
target_length = self.config.max_position_embeddings
else: # dynamic cache
target_length = (
attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length + 1
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)

causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
Expand Down
36 changes: 27 additions & 9 deletions src/transformers/models/gemma/modeling_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,12 +570,15 @@ def forward(
key_states = key_states.contiguous()
value_states = value_states.contiguous()

# In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash Attention 2 backend, rather
# relying on the `is_causal` argument.
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=causal_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
is_causal=causal_mask is None and q_len > 1,
)

attn_output = attn_output.transpose(1, 2).contiguous()
Expand Down Expand Up @@ -888,9 +891,7 @@ def forward(
if position_ids is None:
position_ids = cache_position.unsqueeze(0)

causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_seen_tokens + inputs_embeds.shape[1]
)
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens)

# embed positions
hidden_states = inputs_embeds
Expand Down Expand Up @@ -960,24 +961,41 @@ def forward(
attentions=all_self_attns,
)

# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
def _update_causal_mask(self, attention_mask, input_tensor, cache_position, current_length):
def _update_causal_mask(
self,
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_seen_tokens: int,
):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114

if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None

if self.config._attn_implementation == "sdpa":
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument,
# in order to dispatch on Flash Attention 2.
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens
):
return None

dtype, device = input_tensor.dtype, input_tensor.device
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"): # static cache
target_length = self.config.max_position_embeddings
else: # dynamic cache
target_length = (
attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length + 1
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)

causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
Expand Down

0 comments on commit 5728b5a

Please sign in to comment.