# `F.multi_head_attention_forward`

Source (pytorch/torch/nn/functional.py):
- https://github.com/pytorch/pytorch/blob/main/torch/nn/functional.py#L6244-L6695

In [None]:
def multi_head_attention_forward(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    embed_dim_to_check: int,
    num_heads: int,
    in_proj_weight: Optional[Tensor],
    in_proj_bias: Optional[Tensor],
    bias_k: Optional[Tensor],
    bias_v: Optional[Tensor],
    add_zero_attn: bool,
    dropout_p: float,
    out_proj_weight: Tensor,
    out_proj_bias: Optional[Tensor],
    training: bool = True,
    key_padding_mask: Optional[Tensor] = None,
    need_weights: bool = True,
    attn_mask: Optional[Tensor] = None,
    use_separate_proj_weight: bool = False,
    q_proj_weight: Optional[Tensor] = None,
    k_proj_weight: Optional[Tensor] = None,
    v_proj_weight: Optional[Tensor] = None,
    static_k: Optional[Tensor] = None,
    static_v: Optional[Tensor] = None,
    average_attn_weights: bool = True,
    is_causal: bool = False,
) -> tuple[Tensor, Optional[Tensor]]:
    """Run a multi-head attention forward pass on sequence-first inputs.

    Pipeline: validate shapes -> project query/key/value -> optionally append
    ``bias_k``/``bias_v`` and an all-zero slot to keys/values -> merge
    ``key_padding_mask`` into ``attn_mask`` -> compute attention -> apply the
    output projection.

    When ``need_weights`` is True the attention matrix is built explicitly
    (``baddbmm``/``bmm`` + ``softmax``) so it can be returned. Otherwise the
    work is delegated to ``scaled_dot_product_attention`` and no weights are
    returned.

    Args:
        query: ``(tgt_len, bsz, embed_dim)``, or ``(tgt_len, embed_dim)`` for
            unbatched input (a batch dim is inserted at position 1).
        key: ``(src_len, bsz, *)``, or ``(src_len, *)`` when unbatched.
        value: same shape as ``key``. When ``use_separate_proj_weight`` is
            True, only the sequence and batch dims must match.
        embed_dim_to_check: expected size of ``query``'s last dimension.
        num_heads: number of attention heads; must divide ``embed_dim``.
        in_proj_weight: packed q/k/v projection weight, required when
            ``use_separate_proj_weight`` is False.
        in_proj_bias: packed q/k/v projection bias. On the separate-weight
            path it is split into three equal chunks.
        bias_k, bias_v: optional extra key/value position appended along the
            sequence dim. They must be given together and cannot be combined
            with ``static_k``/``static_v``.
        add_zero_attn: if True, append one all-zero key/value position.
        dropout_p: dropout probability on the attention weights. It is forced
            to 0.0 when ``training`` is False.
        out_proj_weight, out_proj_bias: output projection parameters.
        training: whether dropout is active.
        key_padding_mask: mask over key positions, ``(bsz, src_len)``
            (unbatched: a batch dim is inserted at position 0). It is passed
            through ``_canonical_mask`` with ``target_type=query.dtype``.
        need_weights: if True, return attention weights (forces the explicit
            math path).
        attn_mask: 2D ``(tgt_len, src_len)`` or 3D
            ``(bsz * num_heads, tgt_len, src_len)``.
        use_separate_proj_weight: use ``q/k/v_proj_weight`` instead of
            ``in_proj_weight``.
        q_proj_weight, k_proj_weight, v_proj_weight: separate projection
            weights, all required when ``use_separate_proj_weight`` is True.
        static_k, static_v: precomputed keys/values shaped
            ``(bsz * num_heads, seq, head_dim)``. They replace the projected
            ones after reshaping.
        average_attn_weights: if True, average the returned weights over
            heads.
        is_causal: hint that ``attn_mask`` is causal. ``attn_mask`` must
            still be supplied.

    Returns:
        ``(attn_output, attn_output_weights)``:

        - ``attn_output`` is ``(tgt_len, bsz, out_dim)``, where ``out_dim``
          is the out-projection's output size. The batch dim is squeezed for
          unbatched input.
        - ``attn_output_weights`` is ``(bsz, num_heads, tgt_len, src_len)``,
          or ``(bsz, tgt_len, src_len)`` when averaged, or ``None`` when
          ``need_weights`` is False. ``src_len`` includes any appended
          bias/zero positions.

    Raises:
        AssertionError: on embed-dim or head-divisibility mismatch, key/value
            shape mismatch, missing projection weights, wrong
            ``static_k``/``static_v`` sizes, or only one of
            ``bias_k``/``bias_v`` being set.
        RuntimeError: if ``is_causal`` is set without ``attn_mask``, or if
            ``attn_mask`` has an unsupported shape or rank.
    """
    # Let tensor subclasses / __torch_function__ overrides intercept the call.
    tens_ops = (
        query,
        key,
        value,
        in_proj_weight,
        in_proj_bias,
        bias_k,
        bias_v,
        out_proj_weight,
        out_proj_bias,
    )
    if has_torch_function(tens_ops):
        return handle_torch_function(
            multi_head_attention_forward,
            tens_ops,
            query,
            key,
            value,
            embed_dim_to_check,
            num_heads,
            in_proj_weight,
            in_proj_bias,
            bias_k,
            bias_v,
            add_zero_attn,
            dropout_p,
            out_proj_weight,
            out_proj_bias,
            training=training,
            key_padding_mask=key_padding_mask,
            need_weights=need_weights,
            attn_mask=attn_mask,
            is_causal=is_causal,
            use_separate_proj_weight=use_separate_proj_weight,
            q_proj_weight=q_proj_weight,
            k_proj_weight=k_proj_weight,
            v_proj_weight=v_proj_weight,
            static_k=static_k,
            static_v=static_v,
            average_attn_weights=average_attn_weights,
        )

    # _mha_shape_check reports whether the inputs carry a batch dim
    # (presumably also validating ranks -- its body is not shown here).
    is_batched = _mha_shape_check(
        query, key, value, key_padding_mask, attn_mask, num_heads
    )

    # For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input
    # is batched, run the computation and before returning squeeze the
    # batch dimension so that the output doesn't carry this temporary batch dimension.
    if not is_batched:
        # unsqueeze if the input is unbatched
        query = query.unsqueeze(1)
        key = key.unsqueeze(1)
        value = value.unsqueeze(1)
        if key_padding_mask is not None:
            key_padding_mask = key_padding_mask.unsqueeze(0)

    # set up shape vars
    tgt_len, bsz, embed_dim = query.shape
    src_len, _, _ = key.shape

    key_padding_mask = _canonical_mask(
        mask=key_padding_mask,
        mask_name="key_padding_mask",
        other_type=_none_or_dtype(attn_mask),
        other_name="attn_mask",
        target_type=query.dtype,
    )

    # is_causal is only a hint; an explicit attn_mask is still required.
    if is_causal and attn_mask is None:
        raise RuntimeError(
            "Need attn_mask if specifying the is_causal hint. "
            "You may use the Transformer module method "
            "`generate_square_subsequent_mask` to create this mask."
        )

    if is_causal and key_padding_mask is None and not need_weights:
        # When we have a kpm or need weights, we need attn_mask.
        # Otherwise, we pass the is_causal hint through to SDPA as its
        # is_causal indicator.
        # NOTE(review): attn_mask is dropped here, so bias_k/bias_v and
        # add_zero_attn below add no padded mask column for their extra key
        # positions. Confirm that SDPA's is_causal alignment gives the intended
        # behavior in that case.
        attn_mask = None
    else:
        attn_mask = _canonical_mask(
            mask=attn_mask,
            mask_name="attn_mask",
            other_type=None,
            other_name="",
            target_type=query.dtype,
            check_other=False,
        )

        if key_padding_mask is not None:
            # We have the attn_mask, and use that to merge kpm into it.
            # Turn off use of is_causal hint, as the merged mask is no
            # longer causal.
            is_causal = False

    if embed_dim != embed_dim_to_check:
        raise AssertionError(
            f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
        )
    if isinstance(embed_dim, torch.Tensor):
        # embed_dim can be a tensor when JIT tracing
        head_dim = embed_dim.div(num_heads, rounding_mode="trunc")
    else:
        head_dim = embed_dim // num_heads
    if head_dim * num_heads != embed_dim:
        raise AssertionError(
            f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
        )
    if use_separate_proj_weight:
        # allow MHA to have different embedding dimensions when separate projection weights are used
        if key.shape[:2] != value.shape[:2]:
            raise AssertionError(
                f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
            )
    else:
        if key.shape != value.shape:
            raise AssertionError(
                f"key shape {key.shape} does not match value shape {value.shape}"
            )

    #
    # compute in-projection
    #
    if not use_separate_proj_weight:
        if in_proj_weight is None:
            raise AssertionError(
                "use_separate_proj_weight is False but in_proj_weight is None"
            )
        q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
    else:
        if q_proj_weight is None:
            raise AssertionError(
                "use_separate_proj_weight is True but q_proj_weight is None"
            )
        if k_proj_weight is None:
            raise AssertionError(
                "use_separate_proj_weight is True but k_proj_weight is None"
            )
        if v_proj_weight is None:
            raise AssertionError(
                "use_separate_proj_weight is True but v_proj_weight is None"
            )
        # The packed bias is split into equal q/k/v chunks on this path too.
        if in_proj_bias is None:
            b_q = b_k = b_v = None
        else:
            b_q, b_k, b_v = in_proj_bias.chunk(3)
        q, k, v = _in_projection(
            query,
            key,
            value,
            q_proj_weight,
            k_proj_weight,
            v_proj_weight,
            b_q,
            b_k,
            b_v,
        )

    # prep attention mask

    if attn_mask is not None:
        # Normalize attn_mask to rank 3. A 2D (tgt_len, src_len) mask becomes
        # (1, tgt_len, src_len) and broadcasts over batch*heads. A 3D mask
        # must already be (bsz * num_heads, tgt_len, src_len).
        if attn_mask.dim() == 2:
            correct_2d_size = (tgt_len, src_len)
            if attn_mask.shape != correct_2d_size:
                raise RuntimeError(
                    f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}."
                )
            attn_mask = attn_mask.unsqueeze(0)
        elif attn_mask.dim() == 3:
            correct_3d_size = (bsz * num_heads, tgt_len, src_len)
            if attn_mask.shape != correct_3d_size:
                raise RuntimeError(
                    f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}."
                )
        else:
            raise RuntimeError(
                f"attn_mask's dimension {attn_mask.dim()} is not supported"
            )

    # Append bias_k/bias_v as one extra position along the sequence dim
    # (torch.cat's default dim 0). repeat(1, bsz, 1) tiles them across the
    # batch dim (currently second). Assumes bias_k/bias_v are
    # (1, 1, embed_dim) -- TODO confirm against the caller. Both masks gain one
    # trailing column, via pad(..., (0, 1)), for the new key position.
    if bias_k is not None and bias_v is not None:
        if static_k is not None:
            raise AssertionError("bias cannot be added to static key.")
        if static_v is not None:
            raise AssertionError("bias cannot be added to static value.")
        k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
        v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
        if attn_mask is not None:
            # pyrefly: ignore [bad-argument-type]
            attn_mask = pad(attn_mask, (0, 1))
        if key_padding_mask is not None:
            # pyrefly: ignore [bad-argument-type]
            key_padding_mask = pad(key_padding_mask, (0, 1))
    else:
        if bias_k is not None:
            raise AssertionError("bias_k is set but bias_v is None")
        if bias_v is not None:
            raise AssertionError("bias_v is set but bias_k is None")

    #
    # reshape q, k, v for multihead attention and make them batch first
    #
    # Each becomes (bsz * num_heads, seq_len, head_dim). The static_k/static_v
    # overrides are only size-checked and used as given.
    # pyrefly: ignore [no-matching-overload]
    q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    if static_k is None:
        # pyrefly: ignore [no-matching-overload]
        k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
    else:
        # TODO finish disentangling control flow so we don't do in-projections when statics are passed
        if static_k.size(0) != bsz * num_heads:
            raise AssertionError(
                f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
            )
        if static_k.size(2) != head_dim:
            raise AssertionError(
                f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
            )
        k = static_k
    if static_v is None:
        # pyrefly: ignore [no-matching-overload]
        v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
    else:
        # TODO finish disentangling control flow so we don't do in-projections when statics are passed
        if static_v.size(0) != bsz * num_heads:
            raise AssertionError(
                f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
            )
        if static_v.size(2) != head_dim:
            raise AssertionError(
                f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
            )
        v = static_v

    # Append one all-zero key/value position along the sequence dim (dim 1,
    # now that batch*heads is first). Both masks gain a matching trailing
    # column.
    if add_zero_attn:
        zero_attn_shape = (bsz * num_heads, 1, head_dim)
        k = torch.cat(
            # pyrefly: ignore [no-matching-overload]
            [k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)],
            dim=1,
        )
        v = torch.cat(
            # pyrefly: ignore [no-matching-overload]
            [v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)],
            dim=1,
        )
        if attn_mask is not None:
            # pyrefly: ignore [bad-argument-type]
            attn_mask = pad(attn_mask, (0, 1))
        if key_padding_mask is not None:
            # pyrefly: ignore [bad-argument-type]
            key_padding_mask = pad(key_padding_mask, (0, 1))

    # update source sequence length after adjustments
    src_len = k.size(1)

    # merge key padding and attention masks
    if key_padding_mask is not None:
        if not torch.jit.is_scripting() and not torch.jit.is_tracing():
            _check_key_padding_mask(key_padding_mask, src_len, bsz)

        # (bsz, src_len) -> (bsz * num_heads, 1, src_len): copy the mask to
        # every head. The singleton dim broadcasts over tgt_len.
        key_padding_mask = (
            key_padding_mask.view(bsz, 1, 1, src_len)
            .expand(-1, num_heads, -1, -1)
            .reshape(bsz * num_heads, 1, src_len)
        )
        # Additive merge. This presumably relies on _canonical_mask having
        # converted both masks to additive float masks of query.dtype -- confirm.
        if attn_mask is None:
            attn_mask = key_padding_mask
        else:
            attn_mask = attn_mask + key_padding_mask

    # adjust dropout probability
    if not training:
        dropout_p = 0.0

    #
    # (deep breath) calculate attention and out projection
    #

    if need_weights:
        # Explicit math path: build the attention matrix so it can be
        # returned. E is the per-head dim because q is
        # (bsz * num_heads, tgt_len, head_dim). Multiplying by sqrt(1/E)
        # equals dividing by sqrt(E).
        _B, _Nt, E = q.shape
        q_scaled = q * math.sqrt(1.0 / float(E))

        # Appears unreachable: is_causal without attn_mask is rejected above,
        # and attn_mask is only set to None on the not-need_weights path.
        # (Assumes _canonical_mask keeps a non-None mask non-None.)
        if is_causal and attn_mask is None:
            raise AssertionError("FIXME: is_causal not implemented for need_weights")

        # Scores = mask + (Q / sqrt(E)) K^T. baddbmm fuses the add. The mask
        # itself is not scaled. is_causal is not consulted here; any causal
        # structure must come from attn_mask.
        if attn_mask is not None:
            attn_output_weights = torch.baddbmm(
                attn_mask, q_scaled, k.transpose(-2, -1)
            )
        else:
            attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
        attn_output_weights = softmax(attn_output_weights, dim=-1)
        if dropout_p > 0.0:
            attn_output_weights = dropout(attn_output_weights, p=dropout_p)

        attn_output = torch.bmm(attn_output_weights, v)

        # (bsz * num_heads, tgt_len, head_dim) -> (tgt_len * bsz, embed_dim)
        # so the out-projection acts on concatenated heads.
        attn_output = (
            # pyrefly: ignore [no-matching-overload]
            attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
        )
        attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
        attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))

        # optionally average attention weights over heads
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        if average_attn_weights:
            attn_output_weights = attn_output_weights.mean(dim=1)

        if not is_batched:
            # squeeze the output if input was unbatched
            attn_output = attn_output.squeeze(1)
            attn_output_weights = attn_output_weights.squeeze(0)
        return attn_output, attn_output_weights
    else:
        # attn_mask can be either (L,S) or (N*num_heads, L, S)
        # if attn_mask's shape is (1, L, S) we need to unsqueeze to (1, 1, L, S)
        # in order to match the input for SDPA of (N, num_heads, L, S)
        if attn_mask is not None:
            if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
                attn_mask = attn_mask.unsqueeze(0)
            else:
                attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)

        # SDPA takes (bsz, num_heads, seq_len, head_dim) tensors.
        q = q.view(bsz, num_heads, tgt_len, head_dim)
        # pyrefly: ignore [no-matching-overload]
        k = k.view(bsz, num_heads, src_len, head_dim)
        # pyrefly: ignore [no-matching-overload]
        v = v.view(bsz, num_heads, src_len, head_dim)

        attn_output = scaled_dot_product_attention(
            q, k, v, attn_mask, dropout_p, is_causal
        )
        # Output presumably matches q's layout (bsz, num_heads, tgt_len,
        # head_dim). permute -> (tgt_len, bsz, num_heads, head_dim), then
        # flatten the heads for the out-projection.
        attn_output = (
            # pyrefly: ignore [no-matching-overload]
            attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
        )

        attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
        attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
        if not is_batched:
            # squeeze the output if input was unbatched
            attn_output = attn_output.squeeze(1)
        return attn_output, None

## Excerpt: the three masking controls (`key_padding_mask`, `attn_mask`, `is_causal`)

In [None]:
    # Excerpt from multi_head_attention_forward. The _canonical_mask arguments
    # and the error message are abbreviated. The kpm reshape to
    # (bsz * num_heads, 1, src_len) that precedes the final merge is omitted.
    #
    # 1) Normalize key_padding_mask first.
    key_padding_mask = _canonical_mask(
        mask=key_padding_mask,
    )
    # 2) is_causal is only a hint; an explicit attn_mask is still required.
    if is_causal and attn_mask is None:
        raise RuntimeError(
            "Need attn_mask if specifying the is_causal hint. "
        )
    # 3) With no kpm and no weights requested, drop attn_mask and pass the
    #    is_causal hint to SDPA instead.
    if is_causal and key_padding_mask is None and not need_weights:
        attn_mask = None
    else:
        attn_mask = _canonical_mask(
            mask=attn_mask,
        )
        # The kpm will be merged into attn_mask, so the combined mask is no
        # longer purely causal; switch the hint off.
        if key_padding_mask is not None:
            is_causal = False
    # 4) Fold the kpm into attn_mask: use it directly when there is no
    #    attn_mask, otherwise add the two masks.
    if key_padding_mask is not None:
        if attn_mask is None:
            attn_mask = key_padding_mask
        else:
            attn_mask = attn_mask + key_padding_mask


## Excerpt: explicit math implementation (`need_weights=True`) vs. SDPA kernels (`need_weights=False`)


In [None]:
    # Simplified excerpt from multi_head_attention_forward. The reshapes,
    # out-projection and head-averaging are omitted. E is the per-head dim
    # (q.shape[-1]).
    if need_weights:
        # Explicit math path: build the attention matrix so it can be
        # returned.
        # q_scaled = q / sqrt(E)
        q_scaled = q * math.sqrt(1.0 / float(E))
        if attn_mask is not None:
            # baddbmm computes attn_mask + q_scaled @ k^T in one call.
            attn_output_weights = torch.baddbmm(
                attn_mask, q_scaled, k.transpose(-2, -1)
            )
        else:
            attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
        # A = softmax(Q K^T / sqrt(E) + mask) -- the mask is added after
        # scaling; it is not itself divided by sqrt(E).
        attn_output_weights = softmax(attn_output_weights, dim=-1)
        if dropout_p > 0.0:
            attn_output_weights = dropout(attn_output_weights, p=dropout_p)
        # O = A V
        attn_output = torch.bmm(attn_output_weights, v)
        return attn_output, attn_output_weights
    else:
        # Delegate to scaled_dot_product_attention, which dispatches to a
        # backend kernel. No attention weights are produced.
        attn_output = scaled_dot_product_attention(
            q, k, v, attn_mask, dropout_p, is_causal
        )
        return attn_output, None

## Supplement: `scaled_dot_product_attention` (SDPA)

Source (pytorch/aten/src/ATen/native/transformers/attention.cpp):
- https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/transformers/attention.cpp#L718C1-L848

### Pseudocode
```
function scaled_dot_product_attention(query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa):
    
    # Step 1: Select backend
    backend ← choose_best_backend(device, inputs)
    
    # Step 2: Dispatch to backend
    switch backend:
        case CUDNN:
            return cudnn_sdpa(query, key, value, attn_mask, ...)
        case FLASH:
            return flash_sdpa(query, key, value, ...)
        case EFFICIENT:
            return efficient_sdpa(query, key, value, attn_mask, ...)
        case MATH:
            return math_sdpa(query, key, value, attn_mask, ...)  # fallback
```