diff --git a/flash_dmattn/utils/mask.py b/flash_dmattn/utils/mask.py
index dd73019..e8954f6 100644
--- a/flash_dmattn/utils/mask.py
+++ b/flash_dmattn/utils/mask.py
@@ -17,6 +17,35 @@ import torch
 
 
 
+def block_smooth(
+    attention_mask: torch.Tensor,
+    key_len: int,
+    block_size: int,
+):
+    if block_size <= 0:
+        raise ValueError(f"block_size must be a positive integer, got {block_size}.")
+
+    if block_size > 1:
+        full_len = (key_len // block_size) * block_size
+
+        if full_len:
+            block_view = attention_mask[..., :full_len]
+            block_shape = (*block_view.shape[:-1], full_len // block_size, block_size)
+            blocks = block_view.view(*block_shape)
+            block_counts = blocks.sum(dim=-1).to(torch.int64)
+            block_keep = (block_counts * 2) > block_size
+            blocks.copy_(block_keep.unsqueeze(-1).expand_as(blocks))
+
+        if key_len > full_len:
+            tail_slice = attention_mask[..., full_len:]
+            tail_len = tail_slice.shape[-1]
+            tail_counts = tail_slice.sum(dim=-1, keepdim=True).to(torch.int64)
+            tail_keep = (tail_counts * 2) > tail_len
+            tail_slice.copy_(tail_keep.expand_as(tail_slice))
+
+    return attention_mask
+
+
 def topk_mask(
     attention_bias: torch.Tensor,
     attention_mask: Optional[torch.Tensor],
@@ -42,14 +71,11 @@ def topk_mask(
         attention_mask (Tensor): The attention mask tensor of shape ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
     """
 
-    if block_size is not None:
-        if int(block_size) != block_size or block_size <= 0:
-            raise ValueError(f"block_size must be a positive integer, got {block_size}.")
-        block_size = int(block_size)
+    attention_bias = attention_bias.detach()
 
     attention_bias = attention_bias.masked_fill(~attention_mask, min_dtype) if attention_mask is not None else attention_bias
     topk_values, topk_indices = torch.topk(
-        attention_bias.to(torch.float),
+        attention_bias,
         window_size, dim=-1, largest=True, sorted=False
     )
     attention_mask = torch.zeros_like(
@@ -58,22 +84,11 @@
 
     if block_size is not None and block_size > 1:
         key_len = attention_mask.shape[-1]
-        full_len = (key_len // block_size) * block_size
-
-        if full_len:
-            block_view = attention_mask[..., :full_len]
-            block_shape = (*block_view.shape[:-1], full_len // block_size, block_size)
-            blocks = block_view.view(*block_shape)
-            block_counts = blocks.sum(dim=-1).to(torch.int32)
-            block_keep = (block_counts * 2) > block_size
-            blocks.copy_(block_keep.unsqueeze(-1).expand_as(blocks))
-
-        if key_len > full_len:
-            tail_slice = attention_mask[..., full_len:]
-            tail_len = tail_slice.shape[-1]
-            tail_counts = tail_slice.sum(dim=-1, keepdim=True).to(torch.int32)
-            tail_keep = (tail_counts * 2) > tail_len
-            tail_slice.copy_(tail_keep.expand_as(tail_slice))
+        attention_mask = block_smooth(
+            attention_mask=attention_mask,
+            key_len=key_len,
+            block_size=block_size
+        )
 
     return attention_mask
 
@@ -101,33 +116,18 @@ def relu_mask(
         attention_mask (Tensor): The attention mask tensor of shape ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
     """
 
-    if block_size is not None:
-        if int(block_size) != block_size or block_size <= 0:
-            raise ValueError(f"block_size must be a positive integer, got {block_size}.")
-        block_size = int(block_size)
-
+    attention_bias = attention_bias.detach()
 
     attention_bias = attention_bias.masked_fill(~attention_mask, min_dtype) if attention_mask is not None else attention_bias
     attention_mask = attention_bias > 0
 
     if block_size is not None and block_size > 1:
        key_len = attention_mask.shape[-1]
-        full_len = (key_len // block_size) * block_size
-
-        if full_len:
-            block_view = attention_mask[..., :full_len]
-            block_shape = (*block_view.shape[:-1], full_len // block_size, block_size)
-            blocks = block_view.view(*block_shape)
-            block_counts = blocks.sum(dim=-1).to(torch.int32)
-            block_keep = (block_counts * 2) > block_size
-            blocks.copy_(block_keep.unsqueeze(-1).expand_as(blocks))
-
-        if key_len > full_len:
-            tail_slice = attention_mask[..., full_len:]
-            tail_len = tail_slice.shape[-1]
-            tail_counts = tail_slice.sum(dim=-1, keepdim=True).to(torch.int32)
-            tail_keep = (tail_counts * 2) > tail_len
-            tail_slice.copy_(tail_keep.expand_as(tail_slice))
+        attention_mask = block_smooth(
+            attention_mask=attention_mask,
+            key_len=key_len,
+            block_size=block_size
+        )
 
     return attention_mask