In [3]:
from typing import Optional

import torch
def prepare_attention_mask(
    attention_mask: Optional[torch.Tensor],
    sequence_length: int,
    target_length: int,
    dtype: torch.dtype = torch.float32,
    min_dtype: float = -1.0,
    batch_size: int = 1,
    device: Optional[torch.device] = None,
    cache_position: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Build a 4D additive attention mask of shape `(batch_size, 1, sequence_length, target_length)`.

    Each entry is 0 where a query may attend to a key and `min_dtype` where it may not. A key is
    blocked when it lies after the query's position (causality) or when the 2D `attention_mask`
    marks it as padding (0). If `attention_mask` is already 4D it is returned unchanged.

    Args:
        attention_mask (`torch.Tensor` or `None`):
            A 2D mask of shape `(batch_size, mask_length)` with 1 for real tokens and 0 for padding
            (`mask_length <= target_length`), an already-prepared 4D mask which is returned as-is,
            or `None` for a purely causal mask.
        sequence_length (`int`):
            The number of query positions being processed.
        target_length (`int`):
            The number of key positions. When generating with a static cache, the mask should be as
            long as the static cache so the not-yet-filled part of the cache is masked as well.
        dtype (`torch.dtype`):
            The dtype of the returned mask.
        min_dtype (`float`):
            Fill value for blocked positions. In a real model this is the minimum value
            representable with `dtype` (`torch.finfo(dtype).min`); the -1.0 default presumably just
            keeps printed masks readable. Must be non-zero, since 0 means "may attend".
        batch_size (`int`):
            Size of the leading dimension. When a 2D `attention_mask` is given, its first dimension
            must equal `batch_size` (or be 1).
        device (`torch.device`, *optional*):
            Device to build the mask on. Defaults to `attention_mask.device` when a mask is given,
            otherwise to torch's default device.
        cache_position (`torch.Tensor`, *optional*):
            1D tensor with the absolute position of each of the `sequence_length` query tokens;
            query row `i` may attend to keys `0..cache_position[i]`. Defaults to
            `torch.arange(sequence_length)`, i.e. query `i` sits at key position `i`.

    Returns:
        `torch.Tensor`: the 4D mask. Without an `attention_mask` this is an expanded (shared-memory)
        view; with one it is a freshly allocated tensor.
    """
    # Already-prepared 4D masks pass through untouched.
    if attention_mask is not None and attention_mask.dim() == 4:
        return attention_mask

    if device is None and attention_mask is not None:
        device = attention_mask.device
    if cache_position is None:
        # Query row i sits at key position i: the layout torch.triu(..., diagonal=1) produces.
        cache_position = torch.arange(sequence_length, device=device)

    # blocked[i, j] is True when key j lies strictly after query i's position. Deriving the mask
    # from positions handles every sequence_length uniformly, including a single-token query,
    # which must still be allowed to attend to its own position.
    key_positions = torch.arange(target_length, device=device)
    blocked = key_positions > cache_position.reshape(-1, 1)
    causal_mask = torch.zeros(
        (sequence_length, target_length), dtype=dtype, device=device
    ).masked_fill(blocked, min_dtype)

    # expand() only creates a broadcast view over the batch dimension; nothing is copied yet.
    causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
    if attention_mask is not None:
        causal_mask = causal_mask.clone()  # materialize the expanded view before editing in place
        mask_length = attention_mask.shape[-1]
        # Block keys that causality allows (0) but the 2D mask marks as padding (0). Testing both
        # conditions explicitly does not depend on how min_dtype and the mask values add up.
        padding_mask = (causal_mask[:, :, :, :mask_length] == 0) & (
            attention_mask[:, None, None, :] == 0
        )
        causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
            padding_mask, min_dtype
        )

    return causal_mask

In [7]:
s = [1, 1, 1, 0, 1]  # 1 = real token, 0 = padding token (position 3)
seq_length = len(s)

attention_mask = torch.tensor(s).reshape(1, seq_length)
mask4d = prepare_attention_mask(
    attention_mask=attention_mask,
    sequence_length=seq_length,
    target_length=seq_length,
)

# Lower triangle is 0 (causal), upper triangle is -1, and padding column 3 is -1 for every
# query row, so query 3 (the padding token itself) cannot even attend to its own position.
expected_mask4d = torch.tensor([[[[0., -1., -1., -1., -1.],
                                  [0., 0., -1., -1., -1.],
                                  [0., 0., 0., -1., -1.],
                                  [0., 0., 0., -1., -1.],
                                  [0., 0., 0., -1., 0.]]]])
assert torch.equal(mask4d, expected_mask4d)
mask4d


tensor([[[[ 0., -1., -1., -1., -1.],
          [ 0.,  0., -1., -1., -1.],
          [ 0.,  0.,  0., -1., -1.],
          [ 0.,  0.,  0., -1., -1.],
          [ 0.,  0.,  0., -1.,  0.]]]])
