In [30]:
import torch
import inspect
import sys
from transformers import AutoModelForCausalLM, AutoTokenizer
import types
from typing import Optional, Union, Unpack, Any

# torch.backends.cuda.enable_flash_sdp(False)
# torch.backends.cuda.enable_mem_efficient_sdp(False)

# Hub checkpoint used throughout the notebook.
model_name = "Qwen/Qwen3-1.7B"

# Qwen3 ships natively with transformers (this notebook imports
# `transformers.models.qwen3.modeling_qwen3` directly), so the loaders are
# called without `trust_remote_code=True`: that flag only opts in to running
# Python code hosted in the Hub repository, which is not needed here.
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Half-precision weights placed on the GPU, with the FlashAttention-2 backend
# selected at load time.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="cuda",   # or "auto"
    attn_implementation="flash_attention_2",
)

# Dump the module tree to see the attention / MLP layout of each decoder layer.
print(model)


Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.37it/s]


Qwen3ForCausalLM(
  (model): Qwen3Model(
    (embed_tokens): Embedding(151936, 2048)
    (layers): ModuleList(
      (0-27): 28 x Qwen3DecoderLayer(
        (self_attn): Qwen3Attention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=1024, bias=False)
          (v_proj): Linear(in_features=2048, out_features=1024, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (q_norm): Qwen3RMSNorm((128,), eps=1e-06)
          (k_norm): Qwen3RMSNorm((128,), eps=1e-06)
        )
        (mlp): Qwen3MLP(
          (gate_proj): Linear(in_features=2048, out_features=6144, bias=False)
          (up_proj): Linear(in_features=2048, out_features=6144, bias=False)
          (down_proj): Linear(in_features=6144, out_features=2048, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): Qwen3RMSNorm((2048,), eps=1e-06)
        (post_attention_layer

In [32]:
def test_inference(
    prompt: str = "Explain FlashAttention in one sentence.",
    max_new_tokens: int = 32,
) -> None:
    """Generate a short completion for ``prompt`` and print the decoded text.

    Uses the notebook-level ``tokenizer`` and ``model``.

    Args:
        prompt: Text to feed to the model. Defaults to the prompt used
            throughout this notebook.
        max_new_tokens: Passed to ``model.generate`` as ``max_new_tokens``.

    Returns:
        None. The decoded sequence (``out[0]``, special tokens skipped) is
        printed rather than returned so that calling this as the last
        statement of a cell does not add an extra ``Out[...]`` repr.
    """
    # Follow the model's own device instead of hard-coding "cuda", so the
    # helper keeps working when the model is loaded with another device_map.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
        )

    print(tokenizer.decode(out[0], skip_special_tokens=True))

In [33]:
# Baseline run: the model still uses the "flash_attention_2" backend chosen at load time.
test_inference()

Explain FlashAttention in one sentence. FlashAttention is a memory-efficient attention mechanism that enables large-scale training of transformer models by using a flash-like algorithm to process attention queries in a more efficient way,


In [34]:
from transformers.models.qwen3.modeling_qwen3 import ALL_ATTENTION_FUNCTIONS

In [40]:
# Keep a handle on the function currently registered under "flash_attention_2"
# so that a wrapper can delegate to it.
org_attn_forward = ALL_ATTENTION_FUNCTIONS["flash_attention_2"]

In [47]:
def patched_flash_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    dropout: float = 0.0,
    scaling: Optional[float] = None,
    sliding_window: Optional[int] = None,
    softcap: Optional[float] = None,
    **kwargs,
) -> tuple[torch.Tensor, None]:
    """Print every argument of an attention-backend call, then delegate to ``org_attn_forward``.

    The function has the same signature as the stock backend it wraps, so it
    can be registered in ``ALL_ATTENTION_FUNCTIONS`` as a drop-in replacement.
    A full debug block is printed on every invocation.

    Args:
        module: The attention module issuing the call; only its type is printed.
        query, key, value: Attention inputs; shape, dtype and device are printed.
        attention_mask: Optional mask; described like the tensors, or reported
            as ``None``.
        dropout, scaling, sliding_window, softcap: Scalar settings, printed and
            forwarded unchanged.
        **kwargs: Any extra keyword arguments; tensors are summarised, other
            values are printed together with their type. Forwarded unchanged.

    Returns:
        ``(attn_output, None)`` where ``attn_output`` is the first element
        returned by ``org_attn_forward``.
    """
    print("\n===== FlashAttention Backend Debug =====")

    print("module:", type(module))
    scalar_settings = (
        ("dropout", dropout),
        ("scaling", scaling),
        ("sliding_window", sliding_window),
        ("softcap", softcap),
    )
    for label, setting in scalar_settings:
        print(f"{label}:", setting)

    print("\n--- Tensor inputs ---")
    # Labels are padded by hand so the three lines align in the output.
    for label, tensor in (("query:", query), ("key:  ", key), ("value:", value)):
        print(label, tensor.shape, tensor.dtype, tensor.device)

    if attention_mask is None:
        print("attention_mask: None")
    else:
        print(
            "attention_mask:",
            attention_mask.shape,
            attention_mask.dtype,
            attention_mask.device,
        )

    print("\n--- kwargs ---")
    if not kwargs:
        print("kwargs: (empty)")
    for name, extra in kwargs.items():
        if torch.is_tensor(extra):
            print(f"{name}: Tensor(shape={extra.shape}, dtype={extra.dtype}, device={extra.device})")
        else:
            print(f"{name}: {extra} (type={type(extra)})")

    print("===== End Debug =====\n")

    # Forward the optional settings by keyword rather than by position: the
    # attention backends do not all share the same positional order after
    # `scaling`, so positional forwarding would mis-bind (or reject) arguments
    # if `org_attn_forward` pointed at a backend other than the flash one.
    attn_output, _ = org_attn_forward(
        module,
        query,
        key,
        value,
        attention_mask,
        dropout=dropout,
        scaling=scaling,
        sliding_window=sliding_window,
        softcap=softcap,
        **kwargs,
    )
    return attn_output, None

# Register the logging wrapper under its own key (the stock "flash_attention_2"
# entry is left untouched) and point the already-loaded model's config at it.
# NOTE(review): `_attn_implementation` is a private config attribute — confirm
# that assigning it directly is still honoured by the installed transformers.
# NOTE(review): the stock flash backend forwards `module.config._attn_implementation`
# to `_flash_attention_forward`; verify that the custom name 'my_flash' is
# accepted there when this cell runs on a fresh kernel.
ALL_ATTENTION_FUNCTIONS['my_flash'] = patched_flash_attention_forward
model.config._attn_implementation = 'my_flash'

In [48]:
# Same prompt again, now routed through the instrumented 'my_flash' backend.
test_inference()


===== FlashAttention Backend Debug =====
module: <class 'transformers.models.qwen3.modeling_qwen3.Qwen3Attention'>
dropout: 0.0
scaling: 0.08838834764831845
sliding_window: None
softcap: None

--- Tensor inputs ---
query: torch.Size([1, 16, 8, 128]) torch.float16 cuda:0
key:   torch.Size([1, 8, 8, 128]) torch.float16 cuda:0
value: torch.Size([1, 8, 8, 128]) torch.float16 cuda:0
attention_mask: None

--- kwargs ---
position_ids: Tensor(shape=torch.Size([1, 8]), dtype=torch.int64, device=cuda:0)
use_cache: True (type=<class 'bool'>)
===== End Debug =====


===== FlashAttention Backend Debug =====
module: <class 'transformers.models.qwen3.modeling_qwen3.Qwen3Attention'>
dropout: 0.0
scaling: 0.08838834764831845
sliding_window: None
softcap: None

--- Tensor inputs ---
query: torch.Size([1, 16, 8, 128]) torch.float16 cuda:0
key:   torch.Size([1, 8, 8, 128]) torch.float16 cuda:0
value: torch.Size([1, 8, 8, 128]) torch.float16 cuda:0
attention_mask: None

--- kwargs ---
position_ids: Tensor

In [None]:
flash_attention_2: <function flash_attention_forward at 0x78aab3d299e0>
def flash_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    dropout: float = 0.0,
    scaling: Optional[float] = None,
    sliding_window: Optional[int] = None,
    softcap: Optional[float] = None,
    **kwargs,
) -> tuple[torch.Tensor, None]:
    """FlashAttention-2 backend: validate the inputs, transpose them and call ``_flash_attention_forward``.

    Args:
        module: Calling attention module. ``module.is_causal`` is the fallback
            causality flag, ``module.config`` supplies the implementation name
            (and, for float32 inputs, possibly the target dtype), and
            ``module.layer_idx`` is forwarded when present.
        query, key, value: Attention inputs. The sequence length is read from
            ``query.shape[2]`` and axes 1 and 2 are swapped before the kernel
            call — presumably (batch, heads, seq, head_dim); TODO confirm.
        attention_mask: Forwarded unchanged to ``_flash_attention_forward``.
        dropout, scaling, sliding_window, softcap: Forwarded as ``dropout``,
            ``softmax_scale``, ``sliding_window`` and ``softcap``.
        **kwargs: ``is_causal`` (if given and not ``None``) overrides
            ``module.is_causal``; ``output_attentions`` / ``head_mask`` only
            trigger a one-time warning; everything else is forwarded.

    Returns:
        ``(attn_output, None)`` — no attention weights are returned.

    Raises:
        ValueError: If any dimension of ``query`` is zero.
    """
    if kwargs.get("output_attentions", False) or kwargs.get("head_mask") is not None:
        logger.warning_once(
            "`flash_attention_2` does not support `output_attentions=True` or `head_mask`."
            " Please set your attention to `eager` if you want any of these features."
        )

    # This is before the transpose
    seq_len = query.shape[2]

    if any(dim == 0 for dim in query.shape):
        # Report the offending shape: the message previously had an empty
        # placeholder ("has shape  with ..."), which hid the useful detail.
        raise ValueError(
            f"Tensor query has shape {tuple(query.shape)} with a zero dimension.\n"
            "FlashAttention does not support inputs with dim=0.\n"
            "Please check your input shapes or use SDPA instead."
        )
    # FA2 uses non-transposed inputs
    query = query.transpose(1, 2)
    key = key.transpose(1, 2)
    value = value.transpose(1, 2)

    # In PEFT, usually we cast the layer norms in float32 for training stability reasons
    # therefore the input hidden states gets silently casted in float32. Hence, we need
    # cast them back in the correct dtype just to be sure everything works as expected.
    # This might slowdown training & inference so it is recommended to not cast the LayerNorms
    # in fp32. (usually our RMSNorm modules handle it correctly)
    target_dtype = None
    if query.dtype == torch.float32:
        if torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()
        # Handle the case where the model is quantized
        elif hasattr(module.config, "_pre_quantization_dtype"):
            target_dtype = module.config._pre_quantization_dtype
        else:
            # Fall back to the dtype of the first Linear layer found in the module.
            target_dtype = next(layer for layer in module.modules() if isinstance(layer, torch.nn.Linear)).weight.dtype

    # Instead of relying on the value set in the module directly, we use the is_causal passed in kwargs if it is presented
    is_causal = kwargs.pop("is_causal", None)
    if is_causal is None:
        is_causal = module.is_causal

    attn_output = _flash_attention_forward(
        query,
        key,
        value,
        attention_mask,
        query_length=seq_len,
        is_causal=is_causal,
        dropout=dropout,
        softmax_scale=scaling,
        sliding_window=sliding_window,
        softcap=softcap,
        use_top_left_mask=_use_top_left_mask,
        target_dtype=target_dtype,
        attn_implementation=module.config._attn_implementation,
        layer_idx=getattr(module, "layer_idx", None),
        **kwargs,
    )

    return attn_output, None

sdpa: <function sdpa_attention_forward at 0x78aab3c22160>
def sdpa_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    dropout: float = 0.0,
    scaling: Optional[float] = None,
    is_causal: Optional[bool] = None,
    **kwargs,
) -> tuple[torch.Tensor, None]:
    """Attention backend built on ``torch.nn.functional.scaled_dot_product_attention``.

    Args:
        module: Calling attention module. If it has ``num_key_value_groups``,
            key/value are either expanded with ``repeat_kv`` or SDPA is called
            with ``enable_gqa=True``, depending on ``use_gqa_in_sdpa``.
            ``module.is_causal`` (default ``True`` when absent) feeds the
            inferred ``is_causal`` value.
        query, key, value: Attention inputs passed to SDPA. ``query.shape[2]``
            is used as the query length — presumably (batch, heads, seq,
            head_dim); TODO confirm.
        attention_mask: Optional mask. A 4-D mask is cropped on its last axis
            to ``key.shape[-2]``. When ``_is_torch_npu_available`` is truthy, a
            non-boolean mask is converted to a logically inverted boolean mask.
        dropout: Passed as ``dropout_p``.
        scaling: Passed as ``scale``.
        is_causal: Explicit causality flag. When ``None`` it is inferred: causal
            only if the query length is greater than 1, no mask is given and
            the module is causal.
        **kwargs: Only ``output_attentions`` / ``head_mask`` are inspected (to
            emit a one-time warning); nothing from here reaches the SDPA call.

    Returns:
        ``(attn_output, None)`` where ``attn_output`` is the SDPA result with
        axes 1 and 2 swapped and made contiguous; no attention weights are
        returned.
    """
    if kwargs.get("output_attentions", False) or kwargs.get("head_mask") is not None:
        logger.warning_once(
            "`sdpa` attention does not support `output_attentions=True` or `head_mask`."
            " Please set your attention to `eager` if you want any of these features."
        )
    sdpa_kwargs = {}
    if hasattr(module, "num_key_value_groups"):
        if not use_gqa_in_sdpa(attention_mask, key):
            key = repeat_kv(key, module.num_key_value_groups)
            value = repeat_kv(value, module.num_key_value_groups)
        else:
            sdpa_kwargs = {"enable_gqa": True}

    if attention_mask is not None and attention_mask.ndim == 4:
        attention_mask = attention_mask[:, :, :, : key.shape[-2]]

    # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
    # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
    # Note that it is important to check first for the shape, otherwise compile will fail with `argument 'is_causal' must be bool, not SymBool`
    if is_causal is None:
        # The last condition is for encoder (decoder) models which specify this by passing their own `is_causal` flag
        # This is mainly due to those models having mixed implementations for encoder, decoder, and encoder-decoder attns
        is_causal = query.shape[2] > 1 and attention_mask is None and getattr(module, "is_causal", True)

    # Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor.
    # We convert it to a bool for the SDPA kernel that only accepts bools.
    if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
        is_causal = is_causal.item()

    # When `is_causal = False` and the `attention_mask` is not of boolean type, the Ascend NPU's SDPA interface cannot utilize the FlashAttentionScore operator,
    # and falls back to small-operator concatenation. To invoke the FlashAttentionScore, the attention_mask must be converted to boolean type.
    # This adaptation ensures the `attention_mask` meets the requirement for using FlashAttentionScore.
    if _is_torch_npu_available:
        if attention_mask is not None and attention_mask.dtype != torch.bool:
            # Convert to boolean type, making sdpa to force call FlashAttentionScore to improve performance.
            attention_mask = torch.logical_not(attention_mask.bool()).to(query.device)

    attn_output = torch.nn.functional.scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=attention_mask,
        dropout_p=dropout,
        scale=scaling,
        is_causal=is_causal,
        **sdpa_kwargs,
    )
    # Swap axes 1 and 2 of the result and make it contiguous before returning.
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, None
