Fast softmax (huggingface#972)
Co-authored-by: Dudi Lester <160421192+dudilester@users.noreply.github.com>
Co-authored-by: Sayantan Sarkar <sasarkar@habana.ai>
3 people authored and imangohari1 committed Jun 12, 2024
1 parent 482a174 commit d26b9ae
Showing 7 changed files with 62 additions and 6 deletions.
@@ -0,0 +1,10 @@
{
"method": "HOOKS",
"mode": "QUANTIZE",
"observer": "maxabs",
"scale_method": "ACT_MAXABS_POW2_WEIGHTS_PCS_OPT_POW2",
"allowlist": {"types": [], "names": []},
"blocklist": {"types": [], "names": []},
"dump_stats_path": "./hqt_output/measure",
"dump_stats_xlsx_path": "./hqt_output/measure/fp8stats.xlsx"
}
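
This FP8 quantization JSON is presumably consumed through the QUANT_CONFIG environment variable (the file's path is not shown in this view), which the Llama attention change below also checks to decide whether recompute is enabled on the next-token path. A minimal sketch, with an illustrative, hypothetical path:

import os

# Sketch only: the path below is illustrative, not the file's actual location in the repo.
os.environ["QUANT_CONFIG"] = "quantization_config/my_fp8_config.json"

# The decode path in modeling_llama.py enables recompute with exactly this kind of check:
use_recompute = True if os.getenv("QUANT_CONFIG", "") else False
assert use_recompute is True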
5 changes: 5 additions & 0 deletions examples/text-generation/run_generation.py
@@ -242,6 +242,11 @@ def setup_parser(parser):
action="store_true",
help="Whether to enable Habana Flash Attention in causal mode on first token generation.",
)
parser.add_argument(
"--flash_attention_fast_softmax",
action="store_true",
help="Whether to enable Habana Flash Attention in fast softmax mode.",
)
parser.add_argument(
"--book_source",
action="store_true",
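
A minimal standalone sketch of how the new store_true flag behaves (the real run_generation.py parser defines many more arguments; only the new one is reproduced here):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--flash_attention_fast_softmax",
    action="store_true",
    help="Whether to enable Habana Flash Attention in fast softmax mode.",
)

# store_true flags default to False and flip to True only when passed on the command line.
assert parser.parse_args([]).flash_attention_fast_softmax is False
assert parser.parse_args(["--flash_attention_fast_softmax"]).flash_attention_fast_softmax is True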
3 changes: 3 additions & 0 deletions examples/text-generation/run_lm_eval.py
@@ -85,6 +85,9 @@ def __init__(self, tokenizer, model, args, options):
self.model_inputs.update(
{
"attn_softmax_bf16": self.options.attn_softmax_bf16,
"use_flash_attention": self.options.use_flash_attention,
"flash_attention_recompute": self.options.flash_attention_recompute,
"flash_attention_causal_mask": self.options.flash_attention_causal_mask,
}
)
if args.warmup:
1 change: 1 addition & 0 deletions examples/text-generation/utils.py
@@ -455,6 +455,7 @@ def setup_generation_config(args, model, assistant_model, tokenizer):
generation_config.use_flash_attention = args.use_flash_attention
generation_config.flash_attention_recompute = args.flash_attention_recompute
generation_config.flash_attention_causal_mask = args.flash_attention_causal_mask
generation_config.flash_attention_fast_softmax = args.flash_attention_fast_softmax
generation_config.trust_remote_code = args.trust_remote_code

return generation_config
3 changes: 3 additions & 0 deletions optimum/habana/transformers/generation/configuration_utils.py
@@ -35,6 +35,8 @@ class GaudiGenerationConfig(GenerationConfig):
Whether to enable recompute when using Habana flash attention.
flash_attention_causal_mask (`bool`, *optional*):
Whether to enable causal_mask when using Habana flash attention.
flash_attention_fast_softmax (`bool`, *optional*):
Whether to use fast softmax with reduced precision when using Habana flash attention.
"""

def __init__(self, **kwargs):
@@ -51,4 +53,5 @@ def __init__(self, **kwargs):
self.use_flash_attention = kwargs.get("use_flash_attention", None)
self.flash_attention_recompute = kwargs.get("flash_attention_recompute", None)
self.flash_attention_causal_mask = kwargs.get("flash_attention_causal_mask", None)
self.flash_attention_fast_softmax = kwargs.get("flash_attention_fast_softmax", None)
self.use_fused_rope = kwargs.get("use_fused_rope", None)
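
A minimal sketch of the resulting behavior, assuming optimum-habana is installed and that the import path matches the repo layout: a flag that is not passed stays None (and is coerced to False later in generate()), while an explicitly passed one is stored as-is.

from optimum.habana.transformers.generation import GaudiGenerationConfig

config = GaudiGenerationConfig(use_flash_attention=True, flash_attention_fast_softmax=True)
assert config.flash_attention_fast_softmax is True

# When the kwarg is omitted, kwargs.get(..., None) leaves the attribute as None.
default_config = GaudiGenerationConfig()
assert default_config.flash_attention_fast_softmax is None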
3 changes: 3 additions & 0 deletions optimum/habana/transformers/generation/utils.py
@@ -990,6 +990,9 @@ def generate(
model_kwargs["use_flash_attention"] = generation_config.use_flash_attention
model_kwargs["flash_attention_recompute"] = True if generation_config.flash_attention_recompute else False
model_kwargs["flash_attention_causal_mask"] = True if generation_config.flash_attention_causal_mask else False
model_kwargs["flash_attention_fast_softmax"] = (
True if generation_config.flash_attention_fast_softmax else False
)
model_kwargs["num_virtual_tokens"] = num_virtual_tokens

if not self.config.is_encoder_decoder:
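
The `True if ... else False` pattern above just coerces an unset (None) generation-config attribute into a plain bool before it lands in model_kwargs; a tiny equivalent for reference:

value = None  # an unset generation-config attribute
assert (True if value else False) is False
assert bool(value) is False  # equivalent, slightly more idiomatic spelling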
43 changes: 37 additions & 6 deletions optimum/habana/transformers/models/llama/modeling_llama.py
@@ -1,4 +1,5 @@
import math
import os
import warnings
from typing import List, Optional, Tuple, Union

@@ -221,6 +222,16 @@ def gaudi_llama_repeat_kv(
return query_states, key_states, value_states, attention_mask


# FusedScaledDotProductAttention
class ModuleFusedSDPA(torch.nn.Module):
def __init__(self, fusedSDPA):
super().__init__()
self._hpu_kernel_fsdpa = fusedSDPA

def forward(self, query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode):
return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode)


class Matmul(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -278,6 +289,7 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
self.matmul_av = Matmul()
self.k_cache = KVCache()
self.v_cache = KVCache()
self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) if FusedSDPA else None
self.inp_seq_len = -1
self.norm_factor = 1.0 / math.sqrt(self.head_dim)

@@ -325,6 +337,7 @@ def pre_attn_forward(
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
flash_attention_fast_softmax: Optional[bool] = False,
cache_idx: int = None,
num_virtual_tokens: int = None,
**kwargs,
@@ -339,6 +352,7 @@ def pre_attn_forward(
- add new args use_flash_attention
- add new arg flash_attention_recompute
- add new arg flash_attention_causal_mask
- add new arg flash_attention_fast_softmax
- add new arg num_virtual_tokens
"""
bsz, q_len, _ = hidden_states.size()
@@ -434,22 +448,27 @@ def pre_attn_forward(
if use_flash_attention and FusedSDPA:
import habana_frameworks.torch.hpu as ht

softmax_mode = "fast" if flash_attention_fast_softmax else "None"

if q_len == 1:
# next token
with ht.sdp_kernel(enable_recompute=False):
attn_output = FusedSDPA.apply(
query_states, key_states, value_states, attention_mask, 0.0, False, None
use_recompute = True if os.getenv("QUANT_CONFIG", "") else False
with ht.sdp_kernel(enable_recompute=use_recompute):
attn_output = self.fused_scaled_dot_product_attention(
query_states, key_states, value_states, attention_mask, 0.0, False, None, softmax_mode
)
else:
# first token
if flash_attention_causal_mask:
# causal masking on first token requires inputs to be of the same length
with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
attn_output = FusedSDPA.apply(query_states, key_states, value_states, None, 0.0, True, None)
attn_output = self.fused_scaled_dot_product_attention(
query_states, key_states, value_states, None, 0.0, True, None, softmax_mode
)
else:
with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
attn_output = FusedSDPA.apply(
query_states, key_states, value_states, attention_mask, 0.0, False, None
attn_output = self.fused_scaled_dot_product_attention(
query_states, key_states, value_states, attention_mask, 0.0, False, None, softmax_mode
)

else:
@@ -543,6 +562,7 @@ def forward(
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
flash_attention_fast_softmax: Optional[bool] = False,
cache_idx: int = None,
num_virtual_tokens: int = None,
**kwargs,
@@ -556,6 +576,7 @@ def forward(
- add new args use_flash_attention
- add new arg flash_attention_recompute
- add new arg flash_attention_causal_mask
- add new arg flash_attention_fast_softmax
"""
if "padding_mask" in kwargs:
warnings.warn(
@@ -577,6 +598,7 @@ def forward(
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
flash_attention_causal_mask=flash_attention_causal_mask,
flash_attention_fast_softmax=flash_attention_fast_softmax,
cache_idx=cache_idx,
num_virtual_tokens=num_virtual_tokens,
**kwargs,
@@ -610,6 +632,7 @@ def pre_attn(
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
flash_attention_fast_softmax: Optional[bool] = False,
cache_idx: int = None,
num_virtual_tokens: int = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
@@ -628,6 +651,7 @@ def pre_attn(
use_flash_attention,
flash_attention_recompute,
flash_attention_causal_mask,
flash_attention_fast_softmax,
cache_idx=cache_idx,
num_virtual_tokens=num_virtual_tokens,
)
@@ -714,6 +738,7 @@ def forward(
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
flash_attention_fast_softmax: Optional[bool] = False,
cache_idx: int = None,
lazy_mode: Optional[bool] = True,
num_virtual_tokens: int = None,
@@ -727,6 +752,7 @@ def forward(
- add new args use_flash_attention
- add new arg flash_attention_recompute
- add new arg flash_attention_causal_mask
- add new arg flash_attention_fast_softmax
- add new arg lazy_mode
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -842,6 +868,7 @@ def forward(
use_flash_attention,
flash_attention_recompute,
flash_attention_causal_mask,
flash_attention_fast_softmax,
None,
)
else:
@@ -859,6 +886,7 @@ def forward(
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
flash_attention_causal_mask=flash_attention_causal_mask,
flash_attention_fast_softmax=flash_attention_fast_softmax,
cache_idx=cache_idx,
num_virtual_tokens=num_virtual_tokens,
)
@@ -932,6 +960,7 @@ def forward(
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
flash_attention_fast_softmax: Optional[bool] = False,
cache_idx: int = None,
lazy_mode: Optional[bool] = True,
num_virtual_tokens: int = None,
@@ -963,6 +992,7 @@ def forward(
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
flash_attention_causal_mask=flash_attention_causal_mask,
flash_attention_fast_softmax=flash_attention_fast_softmax,
cache_idx=cache_idx,
lazy_mode=lazy_mode,
num_virtual_tokens=num_virtual_tokens,
@@ -1103,6 +1133,7 @@ def prepare_inputs_for_generation(
"use_flash_attention": kwargs.get("use_flash_attention"),
"flash_attention_recompute": kwargs.get("flash_attention_recompute"),
"flash_attention_causal_mask": kwargs.get("flash_attention_causal_mask"),
"flash_attention_fast_softmax": kwargs.get("flash_attention_fast_softmax"),
"cache_idx": kwargs.get("cache_idx"),
"lazy_mode": kwargs.get("lazy_mode"),
"num_virtual_tokens": kwargs.get("num_virtual_tokens"),
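
Pulling the attention changes together: a device-agnostic sketch of how softmax_mode and recompute are selected in pre_attn_forward. FusedSDPA and ht.sdp_kernel are HPU-only, so this stub only mirrors the argument selection, not the fused kernel itself, and the helper name is made up for illustration.

import os

def pick_fused_sdpa_args(q_len, flash_attention_fast_softmax, flash_attention_recompute):
    """Hypothetical helper mirroring the branch logic added in pre_attn_forward."""
    softmax_mode = "fast" if flash_attention_fast_softmax else "None"
    if q_len == 1:
        # next-token step: recompute is enabled only when an FP8 QUANT_CONFIG is active
        enable_recompute = True if os.getenv("QUANT_CONFIG", "") else False
    else:
        # first-token (prompt) step: honor the user-facing recompute flag
        enable_recompute = flash_attention_recompute
    return softmax_mode, enable_recompute

assert pick_fused_sdpa_args(128, flash_attention_fast_softmax=True, flash_attention_recompute=True) == ("fast", True)
assert pick_fused_sdpa_args(1, flash_attention_fast_softmax=False, flash_attention_recompute=True)[0] == "None"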
