## Notebook to debug structure of QuantMultiheadAttention in order to get it working

### everything has been copy pasted from brevitas.nn.quant_mha

In [1]:
import math
from typing import Optional, Tuple, Union
import warnings

from packaging import version
import torch
from torch import Tensor
from torch.nn import Module
from torch.nn import Parameter
import torch.nn.functional as F
from torch.nn.init import constant_
from torch.nn.init import xavier_normal_
from torch.nn.init import xavier_uniform_

from brevitas import torch_version
from brevitas.nn import QuantIdentity
from brevitas.nn import QuantLinear
from brevitas.nn.utils import check_tensors_same_ptr
from brevitas.quant.scaled_int import Int8ActPerTensorFloat
from brevitas.quant.scaled_int import Int8WeightPerTensorFloat
from brevitas.quant.scaled_int import Int32Bias
from brevitas.quant.scaled_int import Uint8ActPerTensorFloat
from brevitas.quant_tensor import QuantTensor


  return torch._C._cuda_getDeviceCount() > 0
No CUDA runtime is found, using CUDA_HOME='/usr/local/cuda'


In [69]:
#######copy pasted from quant_mha.py 

from logging import getLogger

class QuantMultiheadAttention(Module):
    """"
    Args:
        embed_dim: Total dimension of the model.
        num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
            across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
        dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
        bias: If specified, adds bias to input / output projection layers. Default: ``True``.
        add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``.
        add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1.
            Default: ``False``.
        kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
        vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).

    """

    def __init__(
            self,
            embed_dim,
            num_heads,
            dropout=0.,
            bias=True,
            add_bias_kv=False,
            add_zero_attn=False,
            kdim=None,
            vdim=None,
            packed_in_proj=True,
            in_proj_input_quant=Int8ActPerTensorFloat,
            in_proj_weight_quant=Int8WeightPerTensorFloat,
            in_proj_bias_quant=Int32Bias,
            softmax_input_quant=None,
            attn_output_weights_quant=Uint8ActPerTensorFloat,
            q_scaled_quant=Int8ActPerTensorFloat,
            k_transposed_quant=Int8ActPerTensorFloat,
            v_quant=Int8ActPerTensorFloat,
            out_proj_input_quant=Int8ActPerTensorFloat,
            out_proj_weight_quant=Int8WeightPerTensorFloat,
            out_proj_bias_quant=Int32Bias,
            out_proj_output_quant=None,
            batch_first=False,
            return_quant_tensor=False,
            device=None,
            dtype=None,
            **kwargs) -> None:
        super(QuantMultiheadAttention, self).__init__()
        
        self.log = getLogger(self.__class__.__name__)
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_first = batch_first
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        def filter_kwargs(prefix):
            return {k[len(prefix):]: v for k, v in kwargs.items() if k.startswith(prefix)}

        if self._qkv_same_embed_dim and packed_in_proj:
            self.in_proj = QuantLinear(
                out_features=3 * embed_dim,
                in_features=embed_dim,
                bias=bias,
                input_quant=in_proj_input_quant,
                weight_quant=in_proj_weight_quant,
                bias_quant=in_proj_bias_quant,
                device=device,
                dtype=dtype,
                **filter_kwargs('in_proj_'))
            self.q_proj = self.k_proj = self.v_proj = None
        else:
            self.q_proj = QuantLinear(
                out_features=embed_dim,
                in_features=embed_dim,
                bias=bias,
                input_quant=in_proj_input_quant,
                weight_quant=in_proj_weight_quant,
                bias_quant=in_proj_bias_quant,
                device=device,
                dtype=dtype,
                **filter_kwargs('in_proj_'))
            self.k_proj = QuantLinear(
                out_features=embed_dim,
                in_features=self.kdim,
                bias=bias,
                input_quant=in_proj_input_quant,
                weight_quant=in_proj_weight_quant,
                bias_quant=in_proj_bias_quant,
                device=device,
                dtype=dtype,
                **filter_kwargs('in_proj_'))
            self.v_proj = QuantLinear(
                out_features=embed_dim,
                in_features=self.vdim,
                bias=bias,
                input_quant=in_proj_input_quant,
                weight_quant=in_proj_weight_quant,
                bias_quant=in_proj_bias_quant,
                device=device,
                dtype=dtype,
                **filter_kwargs('in_proj_'))
            self.in_proj = None

        # Keep compatibility with this regression between 1.6.0 and 1.8.2, where bias is always enabled
        # https://github.com/pytorch/pytorch/issues/52257
        out_proj_bias = bias or (version.parse('1.8.2') >= torch_version >= version.parse('1.6.0'))

        self.out_proj = QuantLinear(
            embed_dim,
            embed_dim,
            bias=out_proj_bias,
            input_quant=out_proj_input_quant,
            weight_quant=out_proj_weight_quant,
            bias_quant=out_proj_bias_quant,
            output_quant=out_proj_output_quant,
            return_quant_tensor=return_quant_tensor,
            device=device,
            dtype=dtype,
            **filter_kwargs('out_proj_'))

        if add_bias_kv:
            self.bias_k = Parameter(torch.empty((1, 1, embed_dim), device=device, dtype=dtype))
            self.bias_v = Parameter(torch.empty((1, 1, embed_dim), device=device, dtype=dtype))
        else:
            self.bias_k = self.bias_v = None

        self.softmax_input_quant = QuantIdentity(
            act_quant=softmax_input_quant, **filter_kwargs('softmax_input_'))
        self.attn_output_weights_quant = QuantIdentity(
            act_quant=attn_output_weights_quant, **filter_kwargs('attn_output_weights_'))
        self.q_scaled_quant = QuantIdentity(act_quant=q_scaled_quant, **filter_kwargs('q_scaled_'))
        self.k_transposed_quant = QuantIdentity(
            act_quant=k_transposed_quant, **filter_kwargs('k_transposed_'))
        self.v_quant = QuantIdentity(act_quant=v_quant, **filter_kwargs('v_'))

        self.add_zero_attn = add_zero_attn
        self._reset_parameters()

    def _reset_parameters(self):
        if self.in_proj is not None:
            xavier_uniform_(self.in_proj.weight)
            if self.in_proj.bias is not None:
                constant_(self.in_proj.bias, 0.)
        else:
            xavier_uniform_(self.q_proj.weight)
            xavier_uniform_(self.k_proj.weight)
            xavier_uniform_(self.v_proj.weight)
            if self.q_proj.bias is not None:
                constant_(self.q_proj.bias, 0.)
            if self.k_proj.bias is not None:
                constant_(self.k_proj.bias, 0.)
            if self.v_proj.bias is not None:
                constant_(self.v_proj.bias, 0.)

        if self.out_proj.bias is not None:
            constant_(self.out_proj.bias, 0.)
        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)

    def mha_shape_check(
            self,
            query: Union[Tensor, QuantTensor],
            key: Union[Tensor, QuantTensor],
            value: Union[Tensor, QuantTensor],
            key_padding_mask: Optional[Tensor],
            attn_mask: Optional[Tensor],
            num_heads: int):
        # Verifies the expected shape for `query, `key`, `value`, `key_padding_mask` and `attn_mask`
        # and returns if the input is batched or not.
        # Raises an error if `query` is not 2-D (unbatched) or 3-D (batched) tensor.

        # Shape check.
        if query.dim() == 3:
            # Batched Inputs
            is_batched = True
            assert key.dim() == 3 and value.dim() == 3, \
                ("For batched (3-D) `query`, expected `key` and `value` to be 3-D"
                 f" but found {key.dim()}-D and {value.dim()}-D tensors respectively")
            if key_padding_mask is not None:
                assert key_padding_mask.dim() == 2, \
                    ("For batched (3-D) `query`, expected `key_padding_mask` to be `None` or 2-D"
                     f" but found {key_padding_mask.dim()}-D tensor instead")
            if attn_mask is not None:
                assert attn_mask.dim() in (2, 3), \
                    ("For batched (3-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D"
                     f" but found {attn_mask.dim()}-D tensor instead")
        elif query.dim() == 2:
            # Unbatched Inputs
            is_batched = False
            assert key.dim() == 2 and value.dim() == 2, \
                ("For unbatched (2-D) `query`, expected `key` and `value` to be 2-D"
                 f" but found {key.dim()}-D and {value.dim()}-D tensors respectively")

            if key_padding_mask is not None:
                assert key_padding_mask.dim() == 1, \
                    ("For unbatched (2-D) `query`, expected `key_padding_mask` to be `None` or 1-D"
                     f" but found {key_padding_mask.dim()}-D tensor instead")

            if attn_mask is not None:
                assert attn_mask.dim() in (2, 3), \
                    ("For unbatched (2-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D"
                     f" but found {attn_mask.dim()}-D tensor instead")
                if attn_mask.dim() == 3:
                    expected_shape = (num_heads, query.shape[0], key.shape[0])
                    assert attn_mask.shape == expected_shape, \
                        (f"Expected `attn_mask` shape to be {expected_shape} but got {attn_mask.shape}")
        else:
            raise AssertionError(
                f"query should be unbatched 2D or batched 3D tensor but received {query.dim()}-D query tensor"
            )

        return is_batched

    def multi_head_attention(
            self,
            query: Union[Tensor, QuantTensor],
            key: Union[Tensor, QuantTensor],
            value: Union[Tensor, QuantTensor],
            embed_dim_to_check: int,
            num_heads: int,
            bias_k: Optional[Tensor],
            bias_v: Optional[Tensor],
            add_zero_attn: bool,
            dropout_p: float,
            training: bool = True,
            key_padding_mask: Optional[Tensor] = None,
            need_weights: bool = True,
            attn_mask: Optional[Tensor] = None,
            use_separate_proj_weight: bool = False,
            static_k: Optional[Tensor] = None,
            static_v: Optional[Tensor] = None,
            average_attn_weights: bool = True) -> Tuple[Tensor, Optional[Tensor]]:
        r"""
        Args:
            query, key, value: map a query and a set of key-value pairs to an output.
                See "Attention Is All You Need" for more details.
            embed_dim_to_check: total dimension of the model.
            num_heads: parallel attention heads.
            in_proj_weight, in_proj_bias: input projection weight and bias.
            bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
            add_zero_attn: add a new batch of zeros to the key and
                           value sequences at dim=1.
            dropout_p: probability of an element to be zeroed.
            out_proj_weight, out_proj_bias: the output projection weight and bias.
            training: apply dropout if is ``True``.
            key_padding_mask: if provided, specified padding elements in the key will
                be ignored by the attention. This is an binary mask. When the value is True,
                the corresponding value on the attention layer will be filled with -inf.
            need_weights: output attn_output_weights.
            attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
                the batches while a 3D mask allows to specify a different mask for the entries of each batch.
            use_separate_proj_weight: the function accept the proj. weights for query, key,
                and value in different forms. If false, in_proj_weight will be used, which is
                a combination of q_proj_weight, k_proj_weight, v_proj_weight.
            q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
            static_k, static_v: static key and value used for attention operators.
            average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across heads.
                Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an effect
                when ``need_weights=True.``. Default: True


        Shape:
            Inputs:
            - query: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
              the embedding dimension.
            - key: :math:`(S, E)` or :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
              the embedding dimension.
            - value: :math:`(S, E)` or :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
              the embedding dimension.
            - key_padding_mask: :math:`(S)` or :math:`(N, S)` where N is the batch size, S is the source sequence length.
              If a FloatTensor is provided, it will be directly added to the value.
              If a BoolTensor is provided, the positions with the
              value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
            - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
              3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
              S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
              positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
              while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
              are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
              is provided, it will be added to the attention weight.
            - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
              N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
            - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
              N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.

            Outputs:
            - attn_output: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
              E is the embedding dimension.
            - attn_output_weights: Only returned when ``need_weights=True``. If ``average_attn_weights=True``, returns
              attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
              :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
              :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
              head of shape :math:`(num_heads, L, S)` when input is unbatched or :math:`(N, num_heads, L, S)`.
        """
        "! comments starting with ! are from me"
        "is_batched is always True if q,k,v are three dimensional, false otherwise"
        is_batched = self.mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads)
        log.critical(f'info test logging')
        log.critical(f'critical test logging')
        log.warning(f'warn test logging')
        # For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input
        # is batched, run the computation and before returning squeeze the
        # batch dimension so that the output doesn't carry this temporary batch dimension.
        if not is_batched:
            # unsqueeze if the input is unbatched
            query = query.unsqueeze(1)
            key = key.unsqueeze(1)
            value = value.unsqueeze(1)
            if key_padding_mask is not None:
                key_padding_mask = key_padding_mask.unsqueeze(0)
        
        self.log.critical(f'sizes after batch checking: {query.size()}, {key.size()}, {value.size()}')

        # set up shape vars
        "! bsz is probably batch size, why is that the case if input was transposed in forward pass"
        tgt_len, bsz, embed_dim = query.shape
        src_len, _, _ = key.shape
        if key_padding_mask is not None:
            self.log.critical(f'key_padding_mask is: {key_padding_mask}')
            _kpm_dtype = key_padding_mask.dtype
            if _kpm_dtype != torch.bool and not torch.is_floating_point(key_padding_mask):
                raise AssertionError(
                    "only bool and floating types of key_padding_mask are supported")
        assert embed_dim == embed_dim_to_check, \
            f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
        #! calculate head shape of embedding dimension
        if isinstance(embed_dim, torch.Tensor):
            # embed_dim can be a tensor when JIT tracing
            head_dim = embed_dim.div(num_heads, rounding_mode='trunc')
        else:
            head_dim = embed_dim // num_heads
        assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
        if use_separate_proj_weight:
            # allow MHA to have different embedding dimensions when separate projection weights are used
            """! 
            key not transposed yet -> same shape as query hopefully
            after attention layer, tensor should have the feature length of 

            """
            assert key.shape[:2] == value.shape[:2], \
                f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
        else:
            assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"

        #
        # compute in-projection
        #
        def chunk(x, num=3, dim=-1):
            _len, _bsz, _dim = x.shape
            "! make 4d tensor of 3d tensor, same shape but add num parameter between context and embedding size. why?"
            x = x.reshape(_len, _bsz, num, dim)
            return x[:, :, 0, :], x[:, :, 1, :], x[:, :, 2, :]
        

        if self.in_proj is not None:
            self.log.critical(f'self.in_proj is not None')
            if check_tensors_same_ptr([key, query, value]):
                self.log.critical(f'q,k,v are the same tensor')
                # Mark dimensions through named tensors.
                if not torch._C._get_tracing_state():
                    if isinstance(query, QuantTensor):
                        query.value.rename_('L', 'N', 'E')
                    else:
                        query.rename_('L', 'N', 'E')
                # self-attention
                # q, k, v = self.in_proj(query).chunk(3, dim=-1)
                q, k, v = chunk(self.in_proj(query), num=3, dim=-1)
                self.log.critical(f'q,k,v have been chunked q:{q.size()}, k: {k.size()}, v: {v.size()}')

            else:
                raise RuntimeError(
                    "Packed in_proj is supported only for self-attention with k is v is q. Set packed_in_proj=False."
                )
        else:
            self.log.critical(f'self.in_proj is None')
            assert self.q_proj is not None, "use_separate_proj_weight is True but q_proj is None"
            assert self.k_proj is not None, "use_separate_proj_weight is True but k_proj is None"
            assert self.v_proj is not None, "use_separate_proj_weight is True but v_proj is None"
            # Mark dimensions through named tensors.
            if not torch._C._get_tracing_state():
                for t in [query, key, value]:
                    if isinstance(t, QuantTensor):
                        t.value.rename_('L', 'N', 'E')
                    else:
                        t.rename_('L', 'N', 'E')
            q, k, v = self.q_proj(query), self.k_proj(key), self.v_proj(value)

        # Remove names to avoid errors downstream
        if not torch._C._get_tracing_state():
            for t in [q, k, v]:
                t.rename_(None)

        # prep attention mask
        log.critical(f'attn_mask is: {type(attn_mask)}')
        if attn_mask is not None:
            if attn_mask.dtype == torch.uint8:
                warnings.warn(
                    "Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead."
                )
                attn_mask = attn_mask.to(torch.bool)
            else:
                assert attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, \
                    f"Only float, byte, and bool types are supported for attn_mask, not {attn_mask.dtype}"
            # ensure attn_mask's dim is 3
            if attn_mask.dim() == 2:
                correct_2d_size = (tgt_len, src_len)
                if attn_mask.shape != correct_2d_size:
                    raise RuntimeError(
                        f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}."
                    )
                attn_mask = attn_mask.unsqueeze(0)
            elif attn_mask.dim() == 3:
                correct_3d_size = (bsz * num_heads, tgt_len, src_len)
                if attn_mask.shape != correct_3d_size:
                    raise RuntimeError(
                        f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}."
                    )
            else:
                raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")

        # add bias along batch dimension (currently second)
        if bias_k is not None and bias_v is not None:
            assert static_k is None, "bias cannot be added to static key."
            assert static_v is None, "bias cannot be added to static value."
            k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = F.pad(attn_mask, (0, 1))
            if key_padding_mask is not None:
                key_padding_mask = F.pad(key_padding_mask, (0, 1))
        else:
            assert bias_k is None
            assert bias_v is None

        #
        # reshape q, k, v for multihead attention and make em batch first
        #
        "! batch size and feature length dimension are swapped, why?"
        "! check if batch_first has been performed during forward pass and q,k,v have been transposed if batch_first = False"
        q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
        if static_k is None:
            k = k.contiguous().view(k.shape[0], bsz * num_heads, head_dim).transpose(1,2)#(0, 1)
        else:
            # TODO finish disentangling control flow so we don't do in-projections when statics are passed
            assert static_k.size(0) == bsz * num_heads, \
                f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
            assert static_k.size(2) == head_dim, \
                f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
            k = static_k
        if static_v is None:
            "! unsure why v should be transposed"
            v = v.contiguous().view(v.shape[0], bsz * num_heads, head_dim)#.transpose(0, 1)
        else:
            # TODO finish disentangling control flow so we don't do in-projections when statics are passed
            assert static_v.size(0) == bsz * num_heads, \
                f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
            assert static_v.size(2) == head_dim, \
                f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
            v = static_v

        
        # add zero attention along batch dimension (now first)
        log.critical(f'add_zero_attn is: {type(add_zero_attn)}')
        if add_zero_attn:
            zero_attn_shape = (bsz * num_heads, 1, head_dim)
            k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=0)#dim=1)
            v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=0)#dim=1)
            if attn_mask is not None:
                attn_mask = F.pad(attn_mask, (0, 1))
            if key_padding_mask is not None:
                key_padding_mask = F.pad(key_padding_mask, (0, 1))

        # update source sequence length after adjustments
        src_len = k.size(1)

        # merge key padding and attention masks
        if key_padding_mask is not None:
            assert key_padding_mask.shape == (bsz, src_len), \
                f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
            key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len). \
                expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
            if attn_mask is None:
                attn_mask = key_padding_mask
            elif attn_mask.dtype == torch.bool:
                attn_mask = attn_mask.logical_or(key_padding_mask)
            else:
                attn_mask = attn_mask.masked_fill(key_padding_mask, float("-inf"))

        # convert mask to float
        if attn_mask is not None and attn_mask.dtype == torch.bool:
            new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
            new_attn_mask.masked_fill_(attn_mask, float("-inf"))
            attn_mask = new_attn_mask
        log.critical(f'attn_mask: {attn_mask}. size:{attn_mask.size() if attn_mask is not None else "None"}')
        # adjust dropout probability
        if not training:
            dropout_p = 0.0

        #
        # (deep breath) calculate attention and out projection
        #

        B, Nt, E = q.shape
        "! only q is divided, not matrix multiplication of q and k"
        "! should square root be quantized in some way?"
        q_scaled = q / math.sqrt(E)
        k_transposed = k.transpose(-2, -1)
        #log.critical(f'before transposed: {}')
        # Quantize q_scaled and k_transposed
        q_scaled = self.q_scaled_quant(q_scaled)
        k_transposed = self.k_transposed_quant(k_transposed).transpose(-2, -1)
        self.log.critical(f'q_scaled: {q_scaled.size()}, k_transposed: {k_transposed.size()}')
        if attn_mask is not None:
            attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k_transposed)
        #!--------------------------------------------
        else:
            attn_output_weights = torch.bmm(q_scaled, k_transposed)

        # Quantize the input to softmax, if any
        attn_output_weights = self.softmax_input_quant(attn_output_weights)

        attn_output_weights = F.softmax(attn_output_weights, dim=-1)
        if dropout_p > 0.0:
            attn_output_weights = F.dropout(attn_output_weights, p=dropout_p)

        # Quantize attn_output_weights and value
        attn_output_weights = self.attn_output_weights_quant(attn_output_weights)
        v = self.v_quant(v)

        attn_output = torch.bmm(attn_output_weights, v)
        # preserve the 3D input compared to the float version to be able to do row wise scaling
        attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        # Set dim names for PTQ algorithms that requires it
        if not torch._C._get_tracing_state():
            attn_output.rename_('L', 'N', 'E')
        attn_output = self.out_proj(attn_output)
        # Remove names to avoid errors un unsupported downstream ops
        if not torch._C._get_tracing_state():
            if isinstance(attn_output, QuantTensor):
                attn_output.value.rename_(None)
            else:
                attn_output.rename_(None)

        if need_weights:
            # optionally average attention weights over heads
            attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
            if average_attn_weights:
                attn_output_weights = attn_output_weights.sum(dim=1) / num_heads

            if not is_batched:
                # squeeze the output if input was unbatched
                attn_output = attn_output.squeeze(1)
                attn_output_weights = attn_output_weights.squeeze(0)
            return attn_output, attn_output_weights
        else:
            if not is_batched:
                # squeeze the output if input was unbatched
                attn_output = attn_output.squeeze(1)
            return attn_output, None

    def forward(
            self,
            query: Tensor,
            key: Tensor,
            value: Tensor,
            key_padding_mask: Optional[Tensor] = None,
            need_weights: bool = True,
            attn_mask: Optional[Tensor] = None,
            average_attn_weights: bool = True) -> Tuple[Tensor, Optional[Tensor]]:
        r"""
    Args:
        query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False``
            or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length,
            :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``.
            Queries are compared against key-value pairs to produce the output.
            See "Attention Is All You Need" for more details.
        key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False``
            or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length,
            :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``.
            See "Attention Is All You Need" for more details.
        value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when
            ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source
            sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``.
            See "Attention Is All You Need" for more details.
        key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key``
            to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`.
            Binary and byte masks are supported.
            For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
            the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value.
        need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
            Default: ``True``.
        attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
            :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
            :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be
            broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
            Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the
            corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the
            corresponding position is not allowed to attend. For a float mask, the mask values will be added to
            the attention weight.
        average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
            heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
            effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)

    Outputs:
        - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched,
          :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``,
          where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the
          embedding dimension ``embed_dim``.
        - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
          returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
          :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
          :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
          head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`.

        .. note::
            `batch_first` argument is ignored for unbatched inputs.
        """
        is_batched = query.dim() == 3
        if any([hasattr(t, 'is_nested') and t.is_nested for t in (query, key, value)]):
            raise RuntimeError("Nested inputs not supported for quantization.")
        log.critical(f'before dim=0: {query.size()}')
        if self.batch_first and is_batched:
            log.critical(f'dim=0 is batch size')
            # make sure that the transpose op does not affect the "is" property
            if key is value:
                if query is key:
                    query = key = value = query.transpose(1, 0)
                else:
                    query, key = [x.transpose(1, 0) for x in (query, key)]
                    value = key
            else:
                query, key, value = [x.transpose(1, 0) for x in (query, key, value)]
        log.critical(f'{query.size()}')
        attn_output, attn_output_weights = self.multi_head_attention(
            query=query,
            key=key,
            value=value,
            embed_dim_to_check=self.embed_dim,
            num_heads=self.num_heads,
            bias_k=self.bias_k,
            bias_v=self.bias_v,
            add_zero_attn=self.add_zero_attn,
            dropout_p=self.dropout,
            training=self.training,
            key_padding_mask=key_padding_mask,
            need_weights=need_weights,
            attn_mask=attn_mask,
            average_attn_weights=average_attn_weights)
        if self.batch_first and is_batched:
            return attn_output.transpose(1, 0), attn_output_weights
        else:
            return attn_output, attn_output_weights

    def _load_from_state_dict(
            self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
            error_msgs):

        def set_bias(proj_name, value):
            bias_name = f'{prefix}{proj_name}_proj.bias'
            state_dict[bias_name] = value

        def set_weight(proj_name, value):
            key = f'{prefix}{proj_name}_proj.weight'
            state_dict[key] = value

        for name, value in list(state_dict.items()):
            if prefix + 'in_proj_weight' in name:
                if self.in_proj is not None:
                    set_weight('in', value)
                # We might have set packed_in_proj=False, which is absent in the original float implementation
                else:
                    if not value.size(0) % 3 == 0:
                        raise RuntimeError("in_proj dim 0 doesn't divide evenly into 3 tensors.")
                    q_proj, k_proj, v_proj = torch.chunk(value, 3, dim=0)
                    set_weight('q', q_proj)
                    set_weight('k', k_proj)
                    set_weight('v', v_proj)
                del state_dict[name]
            elif prefix + 'q_proj_weight' in name:
                set_weight('q', value)
                del state_dict[name]
            elif prefix + 'k_proj_weight' in name:
                set_weight('k', value)
                del state_dict[name]
            elif prefix + 'v_proj_weight' in name:
                set_weight('v', value)
                del state_dict[name]
            elif prefix + 'in_proj_bias' in name:
                if self.in_proj is not None:
                    set_bias('in', value)
                else:
                    q_proj, k_proj, v_proj = torch.chunk(value, 3, dim=0)
                    set_bias('q', q_proj)
                    set_bias('k', k_proj)
                    set_bias('v', v_proj)
                del state_dict[name]
        super(QuantMultiheadAttention, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

        


In [70]:
#size: 3,16,64
#number of heads: 2
#-> 64 / 2 = 32
#simulate values from embedding, skip positional encoding
wte = torch.nn.Embedding(16,64)
tokens = torch.randint(16, (3,16))
embeddings = wte(tokens)
embeddings.size()
attn_batchfirst = QuantMultiheadAttention(num_heads=2, embed_dim=64,batch_first=True)
attn_no_batchfirst = QuantMultiheadAttention(num_heads=2, embed_dim=64,batch_first=False)

In [61]:
embeddings.transpose(0,1).size()

torch.Size([16, 3, 64])

size of q and k is [32,3,32] \
matrix of 32x32 makes sense as embedding dimension is split in half by the number of heads \
but the batch size (3) is the second dimension instead of the first \
the q and k matrices also have the size of the embedding dimension instead of the context \
if q_scaled only needs to perform scaling, why do they have to be transposed?

In [72]:
q,k,v = [embeddings for _ in range(3)]
#quantmha probably assumes shape to always be (seq, batch, features)
#from torch mha: batch_first: If True, then the input and output tensors are provided as (batch, seq, feature). Default: False (seq, batch, feature).
#in forward pass, if batch_first=True and dim=3, then input is transposed to have dim=0 be context and dim=1 to have batch
attn_batchfirst(q,k,v)

before dim=0: torch.Size([3, 16, 64])
dim=0 is batch size
torch.Size([16, 3, 64])
info test logging
critical test logging
warn test logging
sizes after batch checking: torch.Size([16, 3, 64]), torch.Size([16, 3, 64]), torch.Size([16, 3, 64])
self.in_proj is not None
q,k,v are the same tensor
q,k,v have been chunked q:torch.Size([16, 3, 64]), k: torch.Size([16, 3, 64]), v: torch.Size([16, 3, 64])
attn_mask is: <class 'NoneType'>
add_zero_attn is: <class 'bool'>
attn_mask: None. size:None
q_scaled: torch.Size([6, 16, 32]), k_transposed: torch.Size([16, 32, 6])


RuntimeError: Expected size for first two dimensions of batch2 tensor to be: [6, 32] but got: [16, 32].

In [32]:
#testing in_proj and chunking
#for some reason, chunking constructs a 4d tensor and then retrieves three new tensors
#in_proj reshapes the tensor at the embedding dimension, new shape: embd_dim * 3
in_projected = attn_no_batchfirst.in_proj(embeddings)
print(in_projected.size())
def chunk(x, num=3, dim=-1):
    _len, _bsz, _dim = x.shape
    "! make 4d tensor of 3d tensor, same shape but add num parameter between context and embedding size. why?"
    x = x.reshape(_len, _bsz, num, dim)
    return x[:, :, 0, :], x[:, :, 1, :], x[:, :, 2, :]
chunks = chunk(in_projected)[0]
#batches are removed
q,k,v = chunks
print(q.size(), k.size(), v.size())
chunks.equal(embeddings)


torch.Size([3, 16, 192])
torch.Size([16, 64]) torch.Size([16, 64]) torch.Size([16, 64])


False

In [None]:
"""
    forward pass of mha is delegated
    q,k,v are expected
    
"""
attn_output, attn_output_weights = self.multi_head_attention(
    query=query,
    key=key,
    value=value,
    embed_dim_to_check=self.embed_dim,
    num_heads=self.num_heads,
    bias_k=self.bias_k,
    bias_v=self.bias_v,
    add_zero_attn=self.add_zero_attn,
    dropout_p=self.dropout,
    training=self.training,
    key_padding_mask=key_padding_mask,
    need_weights=need_weights,
    attn_mask=attn_mask,
    average_attn_weights=average_attn_weights)

## test if multiheadattention of pytorch works 

In [49]:
torch_mha = torch.nn.MultiheadAttention(embed_dim=64, num_heads=2, batch_first=True)

In [52]:
#works, who would have guessed
embeddings = embeddings.rename(None)
attn, weights = torch_mha(embeddings,embeddings,embeddings)
attn.size()

torch.Size([3, 16, 64])

In [None]:
#torch mha source code without comments (https://pytorch.org/docs/stable/_modules/torch/nn/modules/activation.html#MultiheadAttention)
class MultiheadAttention(Module):

    __constants__ = ['batch_first']
    bias_k: Optional[torch.Tensor]
    bias_v: Optional[torch.Tensor]

    def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False,
                 kdim=None, vdim=None, batch_first=False, device=None, dtype=None) -> None:
        if embed_dim <= 0 or num_heads <= 0:
            raise ValueError(
                f"embed_dim and num_heads must be greater than 0,"
                f" got embed_dim={embed_dim} and num_heads={num_heads} instead"
            )
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_first = batch_first
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        if not self._qkv_same_embed_dim:
            self.q_proj_weight = Parameter(torch.empty((embed_dim, embed_dim), **factory_kwargs))
            self.k_proj_weight = Parameter(torch.empty((embed_dim, self.kdim), **factory_kwargs))
            self.v_proj_weight = Parameter(torch.empty((embed_dim, self.vdim), **factory_kwargs))
            self.register_parameter('in_proj_weight', None)
        else:
            self.in_proj_weight = Parameter(torch.empty((3 * embed_dim, embed_dim), **factory_kwargs))
            self.register_parameter('q_proj_weight', None)
            self.register_parameter('k_proj_weight', None)
            self.register_parameter('v_proj_weight', None)

        if bias:
            self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
        else:
            self.register_parameter('in_proj_bias', None)
        self.out_proj = NonDynamicallyQuantizableLinear(embed_dim, embed_dim, bias=bias, **factory_kwargs)

        if add_bias_kv:
            self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
            self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self._reset_parameters()

    def _reset_parameters(self):
        if self._qkv_same_embed_dim:
            xavier_uniform_(self.in_proj_weight)
        else:
            xavier_uniform_(self.q_proj_weight)
            xavier_uniform_(self.k_proj_weight)
            xavier_uniform_(self.v_proj_weight)

        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.)
            constant_(self.out_proj.bias, 0.)
        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)

    def __setstate__(self, state):
        # Support loading old MultiheadAttention checkpoints generated by v1.1.0
        if '_qkv_same_embed_dim' not in state:
            state['_qkv_same_embed_dim'] = True

        super().__setstate__(state)

    def forward(
            self,
            query: Tensor,
            key: Tensor,
            value: Tensor,
            key_padding_mask: Optional[Tensor] = None,
            need_weights: bool = True,
            attn_mask: Optional[Tensor] = None,
            average_attn_weights: bool = True,
            is_causal : bool = False) -> Tuple[Tensor, Optional[Tensor]]:
        why_not_fast_path = ''
        if ((attn_mask is not None and torch.is_floating_point(attn_mask))
           or (key_padding_mask is not None) and torch.is_floating_point(key_padding_mask)):
            why_not_fast_path = "floating-point masks are not supported for fast path."

        is_batched = query.dim() == 3

        key_padding_mask = F._canonical_mask(
            mask=key_padding_mask,
            mask_name="key_padding_mask",
            other_type=F._none_or_dtype(attn_mask),
            other_name="attn_mask",
            target_type=query.dtype
        )

        attn_mask = F._canonical_mask(
            mask=attn_mask,
            mask_name="attn_mask",
            other_type=None,
            other_name="",
            target_type=query.dtype,
            check_other=False,
        )


        if not is_batched:
            why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
        elif query is not key or key is not value:
            # When lifting this restriction, don't forget to either
            # enforce that the dtypes all match or test cases where
            # they don't!
            why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
        elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
            why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
        elif self.in_proj_weight is None:
            why_not_fast_path = "in_proj_weight was None"
        elif query.dtype != self.in_proj_weight.dtype:
            # this case will fail anyway, but at least they'll get a useful error message.
            why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
        elif self.training:
            why_not_fast_path = "training is enabled"
        elif (self.num_heads % 2) != 0:
            why_not_fast_path = "self.num_heads is not even"
        elif not self.batch_first:
            why_not_fast_path = "batch_first was not True"
        elif self.bias_k is not None:
            why_not_fast_path = "self.bias_k was not None"
        elif self.bias_v is not None:
            why_not_fast_path = "self.bias_v was not None"
        elif self.add_zero_attn:
            why_not_fast_path = "add_zero_attn was enabled"
        elif not self._qkv_same_embed_dim:
            why_not_fast_path = "_qkv_same_embed_dim was not True"
        elif query.is_nested and (key_padding_mask is not None or attn_mask is not None):
            why_not_fast_path = "supplying both src_key_padding_mask and src_mask at the same time \
                                 is not supported with NestedTensor input"
        elif torch.is_autocast_enabled():
            why_not_fast_path = "autocast is enabled"

        if not why_not_fast_path:
            tensor_args = (
                query,
                key,
                value,
                self.in_proj_weight,
                self.in_proj_bias,
                self.out_proj.weight,
                self.out_proj.bias,
            )
            # We have to use list comprehensions below because TorchScript does not support
            # generator expressions.
            if torch.overrides.has_torch_function(tensor_args):
                why_not_fast_path = "some Tensor argument has_torch_function"
            elif _is_make_fx_tracing():
                why_not_fast_path = "we are running make_fx tracing"
            elif not all(_check_arg_device(x) for x in tensor_args):
                why_not_fast_path = ("some Tensor argument's device is neither one of "
                                     f"cpu, cuda or {torch.utils.backend_registration._privateuse1_backend_name}")
            elif torch.is_grad_enabled() and any(_arg_requires_grad(x) for x in tensor_args):
                why_not_fast_path = ("grad is enabled and at least one of query or the "
                                     "input/output projection weights or biases requires_grad")
            if not why_not_fast_path:
                merged_mask, mask_type = self.merge_masks(attn_mask, key_padding_mask, query)

                if self.in_proj_bias is not None and self.in_proj_weight is not None:
                    return torch._native_multi_head_attention(
                        query,
                        key,
                        value,
                        self.embed_dim,
                        self.num_heads,
                        self.in_proj_weight,
                        self.in_proj_bias,
                        self.out_proj.weight,
                        self.out_proj.bias,
                        merged_mask,
                        need_weights,
                        average_attn_weights,
                        mask_type)

        any_nested = query.is_nested or key.is_nested or value.is_nested
        assert not any_nested, ("MultiheadAttention does not support NestedTensor outside of its fast path. " +
                                f"The fast path was not hit because {why_not_fast_path}")

        if self.batch_first and is_batched:
            # make sure that the transpose op does not affect the "is" property
            if key is value:
                if query is key:
                    query = key = value = query.transpose(1, 0)
                else:
                    query, key = (x.transpose(1, 0) for x in (query, key))
                    value = key
            else:
                query, key, value = (x.transpose(1, 0) for x in (query, key, value))

        if not self._qkv_same_embed_dim:
            attn_output, attn_output_weights = F.multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask,
                use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight,
                average_attn_weights=average_attn_weights,
                is_causal=is_causal)
        else:
            attn_output, attn_output_weights = F.multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask,
                need_weights=need_weights,
                attn_mask=attn_mask,
                average_attn_weights=average_attn_weights,
                is_causal=is_causal)
        if self.batch_first and is_batched:
            return attn_output.transpose(1, 0), attn_output_weights
        else:
            return attn_output, attn_output_weights


    def merge_masks(self, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor],
                    query: Tensor) -> Tuple[Optional[Tensor], Optional[int]]:
        r"""
        Determine mask type and combine masks if necessary. If only one mask is provided, that mask
        and the corresponding mask type will be returned. If both masks are provided, they will be both
        expanded to shape ``(batch_size, num_heads, seq_len, seq_len)``, combined with logical ``or``
        and mask type 2 will be returned
        Args:
            attn_mask: attention mask of shape ``(seq_len, seq_len)``, mask type 0
            key_padding_mask: padding mask of shape ``(batch_size, seq_len)``, mask type 1
            query: query embeddings of shape ``(batch_size, seq_len, embed_dim)``
        Returns:
            merged_mask: merged mask
            mask_type: merged mask type (0, 1, or 2)
        """
        mask_type: Optional[int] = None
        merged_mask: Optional[Tensor] = None

        if key_padding_mask is not None:
            mask_type = 1
            merged_mask = key_padding_mask

        if attn_mask is not None:
            # In this branch query can't be a nested tensor, so it has a shape
            batch_size, seq_len, _ = query.shape
            mask_type = 2

            # Always expands attn_mask to 4D
            if attn_mask.dim() == 3:
                attn_mask_expanded = attn_mask.view(batch_size, -1, seq_len, seq_len)
            else:  # attn_mask.dim() == 2:
                attn_mask_expanded = attn_mask.view(1, 1, seq_len, seq_len).expand(batch_size, self.num_heads, -1, -1)
            merged_mask = attn_mask_expanded

            if key_padding_mask is not None:
                key_padding_mask_expanded = key_padding_mask.view(batch_size, 1, 1, seq_len).expand(-1, self.num_heads, -1, -1)
                merged_mask = attn_mask_expanded + key_padding_mask_expanded

        # no attn_mask and no key_padding_mask, returns None, None
        return merged_mask, mask_type