diff --git a/benchmarks/backward_equivalence.py b/benchmarks/backward_equivalence.py index 4c2c509..3a8598f 100644 --- a/benchmarks/backward_equivalence.py +++ b/benchmarks/backward_equivalence.py @@ -88,11 +88,12 @@ def prepare_dynamic_mask( ) if attn_bias.shape[-1] > keep_window_size: - topk_indices = torch.topk( + topk_values, topk_indices = torch.topk( attn_bias, keep_window_size, dim=-1, largest=True, sorted=False - ).indices + ) + valid_topk = topk_values != min_dtype attn_mask = torch.zeros_like(attn_bias, dtype=dtype, device=attn_bias.device) - attn_mask = attn_mask.scatter(-1, topk_indices, 1.0) + attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk.to(dtype)) attn_bias = attn_bias.masked_fill(attn_mask == 0.0, min_dtype) else: attn_mask = torch.ones_like(attn_bias, dtype=dtype, device=attn_bias.device) @@ -561,19 +562,19 @@ def test_cuda_backward_equivalence(accuracy_threshold=0.95): (1, 2, 1, 256, 256, 32, False), (1, 2, 1, 512, 512, 32, True), (1, 2, 1, 512, 512, 32, False), - (1, 2, 1, 1024, 1024, 32, True), # some -Inf and Inf in dbias, Idk why + (1, 2, 1, 1024, 1024, 32, True), (1, 2, 1, 1024, 1024, 32, False), (1, 2, 1, 2048, 2048, 32, True), (1, 2, 1, 2048, 2048, 32, False), - (1, 2, 1, 4096, 4096, 32, True), # some NAN in dbias, Idk why + (1, 2, 1, 4096, 4096, 32, True), (1, 2, 1, 4096, 4096, 32, False), (1, 2, 1, 128, 128, 64, True), (1, 2, 1, 128, 128, 64, False), - (1, 2, 1, 256, 256, 64, True), # some NAN in dbias, Idk why + (1, 2, 1, 256, 256, 64, True), (1, 2, 1, 256, 256, 64, False), (1, 2, 1, 512, 512, 64, True), (1, 2, 1, 512, 512, 64, False), - (1, 2, 1, 1024, 1024, 64, True), # some NAN in dbias, Idk why + (1, 2, 1, 1024, 1024, 64, True), # some INF in dbias, Idk why (1, 2, 1, 1024, 1024, 64, False), (1, 2, 1, 2048, 2048, 64, True), (1, 2, 1, 2048, 2048, 64, False), @@ -585,26 +586,26 @@ def test_cuda_backward_equivalence(accuracy_threshold=0.95): (1, 2, 1, 256, 256, 96, False), (1, 2, 1, 512, 512, 96, True), (1, 2, 1, 512, 512, 96, False), - (1, 2, 1, 1024, 1024, 96, True), # some NAN in dbias, Idk why + (1, 2, 1, 1024, 1024, 96, True), # some INF in dbias, Idk why (1, 2, 1, 1024, 1024, 96, False), (1, 2, 1, 2048, 2048, 96, True), (1, 2, 1, 2048, 2048, 96, False), (1, 2, 1, 4096, 4096, 96, True), (1, 2, 1, 4096, 4096, 96, False), - (1, 2, 1, 128, 128, 128, True), # some NAN in dbias, Idk why + (1, 2, 1, 128, 128, 128, True), (1, 2, 1, 128, 128, 128, True), (1, 2, 1, 256, 256, 128, True), (1, 2, 1, 256, 256, 128, False), (1, 2, 1, 512, 512, 128, True), (1, 2, 1, 512, 512, 128, False), - (1, 2, 1, 1024, 1024, 128, True), # some NAN in dbias, Idk why + (1, 2, 1, 1024, 1024, 128, True), # some INF in dbias, Idk why (1, 2, 1, 1024, 1024, 128, False), (1, 2, 1, 2048, 2048, 128, True), (1, 2, 1, 2048, 2048, 128, False), (1, 2, 1, 4096, 4096, 128, True), (1, 2, 1, 4096, 4096, 128, False), - # Not support head_dim > 128 yet in sm 80 + # Not support head_dim > 128 in sm80 yet # (1, 2, 1, 128, 128, 256, True), # (1, 2, 1, 128, 128, 128, False), # (1, 2, 1, 256, 256, 256, True), diff --git a/benchmarks/forward_equivalence.py b/benchmarks/forward_equivalence.py index b3b0883..0035147 100644 --- a/benchmarks/forward_equivalence.py +++ b/benchmarks/forward_equivalence.py @@ -88,11 +88,12 @@ def prepare_dynamic_mask( ) if attn_bias.shape[-1] > keep_window_size: - topk_indices = torch.topk( + topk_values, topk_indices = torch.topk( attn_bias, keep_window_size, dim=-1, largest=True, sorted=False - ).indices + ) + valid_topk = topk_values != min_dtype attn_mask = 
torch.zeros_like(attn_bias, dtype=dtype, device=attn_bias.device) - attn_mask = attn_mask.scatter(-1, topk_indices, 1.0) + attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk.to(dtype)) attn_bias = attn_bias.masked_fill(attn_mask == 0.0, min_dtype) else: attn_mask = torch.ones_like(attn_bias, dtype=dtype, device=attn_bias.device) @@ -518,28 +519,70 @@ def test_cuda_forward_equivalence(accuracy_threshold=0.95): # If you encounter NAN issues when running multiple configurations, try running a single configuration test_configs = [ # (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, is_causal) - (1, 1, 1, 64, 64, 32, True), - (1, 1, 1, 64, 64, 32, False), - (1, 1, 1, 128, 128, 32, True), - (1, 1, 1, 128, 128, 32, False), - (1, 1, 1, 256, 256, 32, True), - (1, 1, 1, 256, 256, 32, False), - (1, 1, 1, 512, 512, 32, True), - (1, 1, 1, 512, 512, 32, False), - (1, 1, 1, 1024, 1024, 32, True), - (1, 1, 1, 1024, 1024, 32, False), - (1, 1, 1, 2048, 2048, 32, True), - (1, 1, 1, 2048, 2048, 32, False), - (1, 1, 1, 4096, 4096, 32, True), - (1, 1, 1, 4096, 4096, 32, False), - (1, 2, 1, 64, 64, 32, True), - (2, 1, 1, 128, 128, 32, True), - (2, 2, 1, 128, 128, 32, True), - (1, 2, 1, 64, 64, 128, True), + (1, 2, 1, 128, 128, 32, True), + (1, 2, 1, 128, 128, 32, False), + (1, 2, 1, 256, 256, 32, True), + (1, 2, 1, 256, 256, 32, False), + (1, 2, 1, 512, 512, 32, True), + (1, 2, 1, 512, 512, 32, False), + (1, 2, 1, 1024, 1024, 32, True), + (1, 2, 1, 1024, 1024, 32, False), + (1, 2, 1, 2048, 2048, 32, True), + (1, 2, 1, 2048, 2048, 32, False), + (1, 2, 1, 4096, 4096, 32, True), + (1, 2, 1, 4096, 4096, 32, False), + + (1, 2, 1, 128, 128, 64, True), + (1, 2, 1, 128, 128, 64, False), + (1, 2, 1, 256, 256, 64, True), + (1, 2, 1, 256, 256, 64, False), + (1, 2, 1, 512, 512, 64, True), + (1, 2, 1, 512, 512, 64, False), + (1, 2, 1, 1024, 1024, 64, True), + (1, 2, 1, 1024, 1024, 64, False), + (1, 2, 1, 2048, 2048, 64, True), + (1, 2, 1, 2048, 2048, 64, False), + (1, 2, 1, 4096, 4096, 64, True), + (1, 2, 1, 4096, 4096, 64, False), + + (1, 2, 1, 128, 128, 96, True), + (1, 2, 1, 128, 128, 96, False), + (1, 2, 1, 256, 256, 96, True), + (1, 2, 1, 256, 256, 96, False), + (1, 2, 1, 512, 512, 96, True), + (1, 2, 1, 512, 512, 96, False), + (1, 2, 1, 1024, 1024, 96, True), + (1, 2, 1, 1024, 1024, 96, False), + (1, 2, 1, 2048, 2048, 96, True), + (1, 2, 1, 2048, 2048, 96, False), + (1, 2, 1, 4096, 4096, 96, True), + (1, 2, 1, 4096, 4096, 96, False), + + (1, 2, 1, 128, 128, 128, True), (1, 2, 1, 128, 128, 128, True), (1, 2, 1, 256, 256, 128, True), - (1, 2, 1, 3, 512, 128, True), - (1, 2, 1, 1, 512, 128, True), + (1, 2, 1, 256, 256, 128, False), + (1, 2, 1, 512, 512, 128, True), + (1, 2, 1, 512, 512, 128, False), + (1, 2, 1, 1024, 1024, 128, True), + (1, 2, 1, 1024, 1024, 128, False), + (1, 2, 1, 2048, 2048, 128, True), + (1, 2, 1, 2048, 2048, 128, False), + (1, 2, 1, 4096, 4096, 128, True), + (1, 2, 1, 4096, 4096, 128, False), + + (1, 2, 1, 128, 128, 128, True), + (1, 2, 1, 128, 128, 128, False), + (1, 2, 1, 256, 256, 256, True), + (1, 2, 1, 256, 256, 256, False), + (1, 2, 1, 512, 512, 256, True), + (1, 2, 1, 512, 512, 256, False), + (1, 2, 1, 1024, 1024, 256, True), + (1, 2, 1, 1024, 1024, 256, False), + (1, 2, 1, 2048, 2048, 256, True), + (1, 2, 1, 2048, 2048, 256, False), + (1, 2, 1, 4096, 4096, 256, True), + (1, 2, 1, 4096, 4096, 256, False), ] device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -672,27 +715,71 @@ def test_triton_forward_equivalence(accuracy_threshold=0.95): # If you 
encounter NAN issues when running multiple configurations, try running a single configuration test_configs = [ # (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, is_causal) - (1, 1, 1, 64, 64, 32, True), - (1, 1, 1, 64, 64, 32, False), - (1, 1, 1, 128, 128, 32, True), - (1, 1, 1, 128, 128, 32, False), - (1, 1, 1, 256, 256, 32, True), - (1, 1, 1, 256, 256, 32, False), - (1, 1, 1, 512, 512, 32, True), - (1, 1, 1, 512, 512, 32, False), - (1, 1, 1, 1024, 1024, 32, True), - (1, 1, 1, 1024, 1024, 32, False), - (1, 1, 1, 2048, 2048, 32, True), - (1, 1, 1, 2048, 2048, 32, False), - (1, 1, 1, 4096, 4096, 32, True), - (1, 1, 1, 4096, 4096, 32, False), - (1, 2, 1, 64, 64, 32, True), - (2, 1, 1, 128, 128, 32, True), - (2, 2, 1, 128, 128, 32, True), - (1, 2, 1, 64, 64, 128, True), + (1, 2, 1, 128, 128, 32, True), + (1, 2, 1, 128, 128, 32, False), + (1, 2, 1, 256, 256, 32, True), + (1, 2, 1, 256, 256, 32, False), + (1, 2, 1, 512, 512, 32, True), + (1, 2, 1, 512, 512, 32, False), + (1, 2, 1, 1024, 1024, 32, True), + (1, 2, 1, 1024, 1024, 32, False), + (1, 2, 1, 2048, 2048, 32, True), + (1, 2, 1, 2048, 2048, 32, False), + (1, 2, 1, 4096, 4096, 32, True), + (1, 2, 1, 4096, 4096, 32, False), + + (1, 2, 1, 128, 128, 64, True), + (1, 2, 1, 128, 128, 64, False), + (1, 2, 1, 256, 256, 64, True), + (1, 2, 1, 256, 256, 64, False), + (1, 2, 1, 512, 512, 64, True), + (1, 2, 1, 512, 512, 64, False), + (1, 2, 1, 1024, 1024, 64, True), + (1, 2, 1, 1024, 1024, 64, False), + (1, 2, 1, 2048, 2048, 64, True), + (1, 2, 1, 2048, 2048, 64, False), + (1, 2, 1, 4096, 4096, 64, True), + (1, 2, 1, 4096, 4096, 64, False), + + (1, 2, 1, 128, 128, 96, True), + (1, 2, 1, 128, 128, 96, False), + (1, 2, 1, 256, 256, 96, True), + (1, 2, 1, 256, 256, 96, False), + (1, 2, 1, 512, 512, 96, True), + (1, 2, 1, 512, 512, 96, False), + (1, 2, 1, 1024, 1024, 96, True), + (1, 2, 1, 1024, 1024, 96, False), + (1, 2, 1, 2048, 2048, 96, True), + (1, 2, 1, 2048, 2048, 96, False), + (1, 2, 1, 4096, 4096, 96, True), + (1, 2, 1, 4096, 4096, 96, False), + + (1, 2, 1, 128, 128, 128, True), (1, 2, 1, 128, 128, 128, True), (1, 2, 1, 256, 256, 128, True), + (1, 2, 1, 256, 256, 128, False), (1, 2, 1, 512, 512, 128, True), + (1, 2, 1, 512, 512, 128, False), + (1, 2, 1, 1024, 1024, 128, True), + (1, 2, 1, 1024, 1024, 128, False), + (1, 2, 1, 2048, 2048, 128, True), + (1, 2, 1, 2048, 2048, 128, False), + (1, 2, 1, 4096, 4096, 128, True), + (1, 2, 1, 4096, 4096, 128, False), + + # Not support head_dim > 128 in triton yet + # (1, 2, 1, 128, 128, 128, True), + # (1, 2, 1, 128, 128, 128, False), + # (1, 2, 1, 256, 256, 256, True), + # (1, 2, 1, 256, 256, 256, False), + # (1, 2, 1, 512, 512, 256, True), + # (1, 2, 1, 512, 512, 256, False), + # (1, 2, 1, 1024, 1024, 256, True), + # (1, 2, 1, 1024, 1024, 256, False), + # (1, 2, 1, 2048, 2048, 256, True), + # (1, 2, 1, 2048, 2048, 256, False), + # (1, 2, 1, 4096, 4096, 256, True), + # (1, 2, 1, 4096, 4096, 256, False), ] device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -843,27 +930,70 @@ def test_flex_forward_equivalence(accuracy_threshold=0.95): # Test configurations for Flex Attention test_configs = [ # (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, is_causal) - (1, 1, 1, 64, 64, 32, True), - (1, 1, 1, 64, 64, 32, False), - (1, 1, 1, 128, 128, 32, True), - (1, 1, 1, 128, 128, 32, False), - (1, 1, 1, 256, 256, 32, True), - (1, 1, 1, 256, 256, 32, False), - (1, 1, 1, 512, 512, 32, True), - (1, 1, 1, 512, 512, 32, False), - (1, 1, 1, 1024, 1024, 32, True), 
- (1, 1, 1, 1024, 1024, 32, False), - (1, 1, 1, 2048, 2048, 32, True), - (1, 1, 1, 2048, 2048, 32, False), - (1, 1, 1, 4096, 4096, 32, True), - (1, 1, 1, 4096, 4096, 32, False), - (1, 2, 1, 64, 64, 32, True), - (2, 1, 1, 128, 128, 32, True), - (2, 2, 1, 128, 128, 32, True), - (1, 2, 1, 64, 64, 128, True), + (1, 2, 1, 128, 128, 32, True), + (1, 2, 1, 128, 128, 32, False), + (1, 2, 1, 256, 256, 32, True), + (1, 2, 1, 256, 256, 32, False), + (1, 2, 1, 512, 512, 32, True), + (1, 2, 1, 512, 512, 32, False), + (1, 2, 1, 1024, 1024, 32, True), + (1, 2, 1, 1024, 1024, 32, False), + (1, 2, 1, 2048, 2048, 32, True), + (1, 2, 1, 2048, 2048, 32, False), + (1, 2, 1, 4096, 4096, 32, True), + (1, 2, 1, 4096, 4096, 32, False), + + (1, 2, 1, 128, 128, 64, True), + (1, 2, 1, 128, 128, 64, False), + (1, 2, 1, 256, 256, 64, True), + (1, 2, 1, 256, 256, 64, False), + (1, 2, 1, 512, 512, 64, True), + (1, 2, 1, 512, 512, 64, False), + (1, 2, 1, 1024, 1024, 64, True), + (1, 2, 1, 1024, 1024, 64, False), + (1, 2, 1, 2048, 2048, 64, True), + (1, 2, 1, 2048, 2048, 64, False), + (1, 2, 1, 4096, 4096, 64, True), + (1, 2, 1, 4096, 4096, 64, False), + + (1, 2, 1, 128, 128, 96, True), + (1, 2, 1, 128, 128, 96, False), + (1, 2, 1, 256, 256, 96, True), + (1, 2, 1, 256, 256, 96, False), + (1, 2, 1, 512, 512, 96, True), + (1, 2, 1, 512, 512, 96, False), + (1, 2, 1, 1024, 1024, 96, True), + (1, 2, 1, 1024, 1024, 96, False), + (1, 2, 1, 2048, 2048, 96, True), + (1, 2, 1, 2048, 2048, 96, False), + (1, 2, 1, 4096, 4096, 96, True), + (1, 2, 1, 4096, 4096, 96, False), + + (1, 2, 1, 128, 128, 128, True), (1, 2, 1, 128, 128, 128, True), (1, 2, 1, 256, 256, 128, True), + (1, 2, 1, 256, 256, 128, False), (1, 2, 1, 512, 512, 128, True), + (1, 2, 1, 512, 512, 128, False), + (1, 2, 1, 1024, 1024, 128, True), + (1, 2, 1, 1024, 1024, 128, False), + (1, 2, 1, 2048, 2048, 128, True), + (1, 2, 1, 2048, 2048, 128, False), + (1, 2, 1, 4096, 4096, 128, True), + (1, 2, 1, 4096, 4096, 128, False), + + (1, 2, 1, 128, 128, 128, True), + (1, 2, 1, 128, 128, 128, False), + (1, 2, 1, 256, 256, 256, True), + (1, 2, 1, 256, 256, 256, False), + (1, 2, 1, 512, 512, 256, True), + (1, 2, 1, 512, 512, 256, False), + (1, 2, 1, 1024, 1024, 256, True), + (1, 2, 1, 1024, 1024, 256, False), + (1, 2, 1, 2048, 2048, 256, True), + (1, 2, 1, 2048, 2048, 256, False), + (1, 2, 1, 4096, 4096, 256, True), + (1, 2, 1, 4096, 4096, 256, False), ] device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -1051,13 +1181,13 @@ def main(): print("\n" + "📍" + " Starting Standard Forward Pass Tests " + "📍") test_results['cuda'] = test_cuda_forward_equivalence(args.accuracy_threshold) - # if args.test_type in ['all', 'triton']: - # print("\n" + "🔥" + " Starting Python vs Triton Tests " + "🔥") - # test_results['triton'] = test_triton_forward_equivalence(args.accuracy_threshold) + if args.test_type in ['all', 'triton']: + print("\n" + "🔥" + " Starting Python vs Triton Tests " + "🔥") + test_results['triton'] = test_triton_forward_equivalence(args.accuracy_threshold) - # if args.test_type in ['all', 'flex']: - # print("\n" + "🌟" + " Starting Python vs Flex Attention Tests " + "🌟") - # test_results['flex'] = test_flex_forward_equivalence(args.accuracy_threshold) + if args.test_type in ['all', 'flex']: + print("\n" + "🌟" + " Starting Python vs Flex Attention Tests " + "🌟") + test_results['flex'] = test_flex_forward_equivalence(args.accuracy_threshold) # Print overall summary diff --git a/benchmarks/forward_performance.py b/benchmarks/forward_performance.py index 
8ba8ec0..95c2920 100644 --- a/benchmarks/forward_performance.py +++ b/benchmarks/forward_performance.py @@ -110,11 +110,12 @@ def prepare_dynamic_mask( ) if attn_bias.shape[-1] > keep_window_size: - topk_indices = torch.topk( + topk_values, topk_indices = torch.topk( attn_bias, keep_window_size, dim=-1, largest=True, sorted=False - ).indices + ) + valid_topk = topk_values != min_dtype attn_mask = torch.zeros_like(attn_bias, dtype=dtype, device=attn_bias.device) - attn_mask = attn_mask.scatter(-1, topk_indices, 1.0) + attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk.to(dtype)) attn_bias = attn_bias.masked_fill(attn_mask == 0.0, min_dtype) else: attn_mask = torch.ones_like(attn_bias, dtype=dtype, device=attn_bias.device) diff --git a/csrc/src/flash_fwd_kernel.h b/csrc/src/flash_fwd_kernel.h index 9c72926..a146602 100644 --- a/csrc/src/flash_fwd_kernel.h +++ b/csrc/src/flash_fwd_kernel.h @@ -218,31 +218,29 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi ); Tensor sMask = make_tensor( sV.data() + size(sV), - typename Kernel_traits::SmemLayoutMask{} + typename Kernel_traits::SmemLayoutAtomPS{} ); Tensor sBias = make_tensor( sMask.data() + size(sMask), - typename Kernel_traits::SmemLayoutBias{} + typename Kernel_traits::SmemLayoutAtomPS{} ); // Global to Shared Memory operation typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); - typename Kernel_traits::GmemTiledCopyMask gmem_tiled_copy_Mask; - auto gmem_thr_copy_Mask = gmem_tiled_copy_Mask.get_thread_slice(tidx); - typename Kernel_traits::GmemTiledCopyBias gmem_tiled_copy_Bias; - auto gmem_thr_copy_Bias = gmem_tiled_copy_Bias.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyMaskBias gmem_tiled_copy_MaskBias; + auto gmem_thr_copy_MaskBias = gmem_tiled_copy_MaskBias.get_thread_slice(tidx); Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); - Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K, nblocksN) + Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K, nblocksN) Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); - Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K, nblocksN) + Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K, nblocksN) Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); - Tensor tMaskgMask = gmem_thr_copy_Mask.partition_S(gMask); // (MaskCPY, MaskCPY_M, MaskCPY_N, nblocksN) - Tensor tMasksMask = gmem_thr_copy_Mask.partition_D(sMask); - Tensor tBiasgBias = gmem_thr_copy_Bias.partition_S(gBias); // (BiasCPY, BiasCPY_M, BiasCPY_N, nblocksN) - Tensor tBiassBias = gmem_thr_copy_Bias.partition_D(sBias); + Tensor tMaskgMask = gmem_thr_copy_MaskBias.partition_S(gMask); // (MaskCPY, MaskCPY_M, MaskCPY_N, nblocksN) + Tensor tMasksMask = gmem_thr_copy_MaskBias.partition_D(sMask); + Tensor tBiasgBias = gmem_thr_copy_MaskBias.partition_S(gBias); // (BiasCPY, BiasCPY_M, BiasCPY_N, nblocksN) + Tensor tBiassBias = gmem_thr_copy_MaskBias.partition_D(sBias); // Matrix Multiply Accumulate typename Kernel_traits::TiledMma tiled_mma; @@ -267,10 +265,10 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); - auto smem_tiled_copy_Mask = 
make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomMask{}, tiled_mma); + auto smem_tiled_copy_Mask = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomPS{}, tiled_mma); auto smem_thr_copy_Mask = smem_tiled_copy_Mask.get_thread_slice(tidx); Tensor tSsMask = smem_thr_copy_Mask.partition_S(sMask); - auto smem_tiled_copy_Bias = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomBias{}, tiled_mma); + auto smem_tiled_copy_Bias = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomPS{}, tiled_mma); auto smem_thr_copy_Bias = smem_tiled_copy_Bias.get_thread_slice(tidx); Tensor tSsBias = smem_thr_copy_Bias.partition_S(sBias); @@ -298,10 +296,10 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // printf("\n"); // } // Repeat the partitioning with identity layouts - Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY, ACPY_M, ACPY_K) -> (blk_m, blk_k) - Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY, BCPY_N, BCPY_K) -> (blk_n, blk_k) - Tensor tMaskcMask = gmem_thr_copy_Mask.partition_S(cMask); // (MaskCPY, MaskCPY_M, MaskCPY_N) -> (blk_m, blk_n) - Tensor tBiascBias = gmem_thr_copy_Bias.partition_S(cBias); // (BiasCPY, BiasCPY_M, BiasCPY_N) -> (blk_m, blk_n) + Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY, ACPY_M, ACPY_K) -> (blk_m, blk_k) + Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY, BCPY_N, BCPY_K) -> (blk_n, blk_k) + Tensor tMaskcMask = gmem_thr_copy_MaskBias.partition_S(cMask); // (MaskCPY, MaskCPY_M, MaskCPY_N) -> (blk_m, blk_n) + Tensor tBiascBias = gmem_thr_copy_MaskBias.partition_S(cBias); // (BiasCPY, BiasCPY_M, BiasCPY_N) -> (blk_m, blk_n) // Allocate predicate tensors for k Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); @@ -354,13 +352,13 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi binfo.actual_seqlen_k - n_block * kBlockN ); FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_Mask, + gmem_tiled_copy_MaskBias, tMaskgMask(_, _, _, n_block), tMasksMask, tMaskcMask, binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN ); FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_Bias, + gmem_tiled_copy_MaskBias, tBiasgBias(_, _, _, n_block), tBiassBias, tBiascBias, binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN @@ -460,13 +458,13 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. 
FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_Mask, + gmem_tiled_copy_MaskBias, tMaskgMask(_, _, _, n_block - 1), tMasksMask, tMaskcMask, binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN ); FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_Bias, + gmem_tiled_copy_MaskBias, tBiasgBias(_, _, _, n_block - 1), tBiassBias, tBiascBias, binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN @@ -558,13 +556,13 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi tKVcKV, tKVpKV ); FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_Mask, + gmem_tiled_copy_MaskBias, tMaskgMask(_, _, _, n_block - 1), tMasksMask, tMaskcMask, binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN ); FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_Bias, + gmem_tiled_copy_MaskBias, tBiasgBias(_, _, _, n_block - 1), tBiassBias, tBiascBias, binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN @@ -845,31 +843,29 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons ); Tensor sMask = make_tensor( sV.data() + size(sV), - typename Kernel_traits::SmemLayoutMask{} + typename Kernel_traits::SmemLayoutAtomPS{} ); Tensor sBias = make_tensor( sMask.data() + size(sMask), - typename Kernel_traits::SmemLayoutBias{} + typename Kernel_traits::SmemLayoutAtomPS{} ); // Global to Shared Memory operation typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); - typename Kernel_traits::GmemTiledCopyMask gmem_tiled_copy_Mask; - auto gmem_thr_copy_Mask = gmem_tiled_copy_Mask.get_thread_slice(tidx); - typename Kernel_traits::GmemTiledCopyBias gmem_tiled_copy_Bias; - auto gmem_thr_copy_Bias = gmem_tiled_copy_Bias.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyMaskBias gmem_tiled_copy_MaskBias; + auto gmem_thr_copy_MaskBias = gmem_tiled_copy_MaskBias.get_thread_slice(tidx); Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); - Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) + Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); - Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) + Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); - Tensor tMaskgMask = gmem_thr_copy_Mask.partition_S(gMask); // (MaskCPY, MaskCPY_M, MaskCPY_N) - Tensor tMasksMask = gmem_thr_copy_Mask.partition_D(sMask); - Tensor tBiasgBias = gmem_thr_copy_Bias.partition_S(gBias); // (BiasCPY, BiasCPY_M, BiasCPY_N) - Tensor tBiassBias = gmem_thr_copy_Bias.partition_D(sBias); + Tensor tMaskgMask = gmem_thr_copy_MaskBias.partition_S(gMask); // (MaskCPY, MaskCPY_M, MaskCPY_N) + Tensor tMasksMask = gmem_thr_copy_MaskBias.partition_D(sMask); + Tensor tBiasgBias = gmem_thr_copy_MaskBias.partition_S(gBias); // (BiasCPY, BiasCPY_M, BiasCPY_N) + Tensor tBiassBias = gmem_thr_copy_MaskBias.partition_D(sBias); // Matrix Multiply Accumulate typename Kernel_traits::TiledMma tiled_mma; @@ -891,10 +887,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); Tensor tOsVt = 
smem_thr_copy_V.partition_S(sVt); - auto smem_tiled_copy_Mask = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomMask{}, tiled_mma); + auto smem_tiled_copy_Mask = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomPS{}, tiled_mma); auto smem_thr_copy_Mask = smem_tiled_copy_Mask.get_thread_slice(tidx); Tensor tSsMask = smem_thr_copy_Mask.partition_S(sMask); - auto smem_tiled_copy_Bias = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomBias{}, tiled_mma); + auto smem_tiled_copy_Bias = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomPS{}, tiled_mma); auto smem_thr_copy_Bias = smem_tiled_copy_Bias.get_thread_slice(tidx); Tensor tSsBias = smem_thr_copy_Bias.partition_S(sBias); @@ -907,10 +903,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons Tensor cMask = make_identity_tensor(make_shape(size<0>(sMask), size<1>(sMask))); // (BLK_M, BLK_N) -> (blk_m, blk_n) Tensor cBias = make_identity_tensor(make_shape(size<0>(sBias), size<1>(sBias))); // (BLK_M, BLK_N) -> (blk_m, blk_n) // Repeat the partitioning with identity layouts - Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY, ACPY_M, ACPY_K) -> (blk_m, blk_k) - Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY, BCPY_N, BCPY_K) -> (blk_n, blk_k) - Tensor tMaskcMask = gmem_thr_copy_Mask.partition_S(cMask); // (MaskCPY, MaskCPY_M, MaskCPY_N) -> (blk_m, blk_n) - Tensor tBiascBias = gmem_thr_copy_Bias.partition_S(cBias); // (BiasCPY, BiasCPY_M, BiasCPY_N) -> (blk_m, blk_n) + Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY, ACPY_M, ACPY_K) -> (blk_m, blk_k) + Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY, BCPY_N, BCPY_K) -> (blk_n, blk_k) + Tensor tMaskcMask = gmem_thr_copy_MaskBias.partition_S(cMask); // (MaskCPY, MaskCPY_M, MaskCPY_N) -> (blk_m, blk_n) + Tensor tBiascBias = gmem_thr_copy_MaskBias.partition_S(cBias); // (BiasCPY, BiasCPY_M, BiasCPY_N) -> (blk_m, blk_n) // Allocate predicate tensors for k Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); @@ -947,13 +943,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons binfo.actual_seqlen_k - n_block * kBlockN ); FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_Mask, + gmem_tiled_copy_MaskBias, tMaskgMask, tMasksMask, tMaskcMask, binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN ); FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_Bias, + gmem_tiled_copy_MaskBias, tBiasgBias, tBiassBias, tBiascBias, binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN @@ -1074,13 +1070,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. 
FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_Mask, + gmem_tiled_copy_MaskBias, tMaskgMask, tMasksMask, tMaskcMask, binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN ); FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_Bias, + gmem_tiled_copy_MaskBias, tBiasgBias, tBiassBias, tBiascBias, binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN @@ -1190,13 +1186,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons tKVcKV, tKVpKV ); FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_Mask, + gmem_tiled_copy_MaskBias, tMaskgMask, tMasksMask, tMaskcMask, binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN ); FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_Bias, + gmem_tiled_copy_MaskBias, tBiasgBias, tBiassBias, tBiascBias, binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN diff --git a/csrc/src/flash_fwd_launch_template.h b/csrc/src/flash_fwd_launch_template.h index 89feeae..d15298e 100644 --- a/csrc/src/flash_fwd_launch_template.h +++ b/csrc/src/flash_fwd_launch_template.h @@ -155,85 +155,96 @@ void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int kBlockM = 64; // Fixed for all head dimensions - constexpr static int kBlockN = Headdim <= 64 ? 64 : (Headdim < 128 ? 64 : 32); + constexpr static int kBlockN = Headdim <= 32 ? 128 : (Headdim <= 64 ? 64 : (Headdim < 128 ? 64 : 32)); run_flash_splitkv_fwd, Is_causal>(params, stream); } template void run_mha_fwd_hdim32(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 32; - run_flash_fwd, Is_causal>(params, stream); + int device; + cudaGetDevice(&device); + int max_smem_per_sm, max_smem_per_block; + cudaError status_ = cudaDeviceGetAttribute( + &max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, device); + status_ = cudaDeviceGetAttribute( + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + if (status_ != cudaSuccess) { + C10_CUDA_CHECK(status_); + } + if (max_smem_per_block >= 176 * 1024) { + run_flash_fwd, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_causal>(params, stream); + } } template void run_mha_fwd_hdim64(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 64; - // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower - // Using block size (64 x 128) is 27% slower for seqlen=2k - // Using block size (128 x 64) is 85% slower for seqlen=2k, because of register spilling - run_flash_fwd, Is_causal>(params, stream); - // run_flash_fwd, Is_causal>(params, stream); - // run_flash_fwd, Is_causal>(params, stream); - // run_flash_fwd, Is_causal>(params, stream); + int device; + cudaGetDevice(&device); + int max_smem_per_sm, max_smem_per_block; + cudaError status_ = cudaDeviceGetAttribute( + &max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, device); + status_ = cudaDeviceGetAttribute( + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + if (status_ != cudaSuccess) { + C10_CUDA_CHECK(status_); + } + if (max_smem_per_block >= 224 * 1024) { + run_flash_fwd, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_causal>(params, stream); + } } template void run_mha_fwd_hdim96(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 96; - auto [cc_major, cc_minor] = 
get_compute_capability(get_current_device()); - bool is_sm8x = cc_major == 8 && cc_minor > 0; - // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), - if (is_sm8x) { - if constexpr(!Is_causal) { - run_flash_fwd, Is_causal>(params, stream); - } else { - run_flash_fwd, Is_causal>(params, stream); - } - } else { + int device; + cudaGetDevice(&device); + int max_smem_per_sm, max_smem_per_block; + cudaError status_ = cudaDeviceGetAttribute( + &max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, device); + status_ = cudaDeviceGetAttribute( + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + if (status_ != cudaSuccess) { + C10_CUDA_CHECK(status_); + } + if (max_smem_per_block >= 160 * 1024) { run_flash_fwd, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_causal>(params, stream); } - // run_flash_fwd, Is_causal>(params, stream); - // run_flash_fwd, Is_causal>(params, stream); - // These two are always slower - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); } template void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 128; - auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); - bool is_sm8x = cc_major == 8 && cc_minor > 0; - // For sm86 or sm89, 64 x 32 (40 KB smem) is the fastest for causal and non-causal since we get 2 CTAs per SM. - // Use block configuration (kBlockM = 64, kBlockN = 64) for better memory alignment - if (is_sm8x) { - if constexpr(!Is_causal) { - run_flash_fwd, Is_causal>(params, stream); - } else { - run_flash_fwd, Is_causal>(params, stream); - } + int device; + cudaGetDevice(&device); + int max_smem_per_sm, max_smem_per_block; + cudaError status_ = cudaDeviceGetAttribute( + &max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, device); + status_ = cudaDeviceGetAttribute( + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + if (status_ != cudaSuccess) { + C10_CUDA_CHECK(status_); + } + if (max_smem_per_block >= 192 * 1024) { + run_flash_fwd, Is_causal>(params, stream); } else { - run_flash_fwd, Is_causal>(params, stream); + // For sm86 or sm89, 64 x 64 (48 KB smem) is the fastest for causal and non-causal since we get 2 CTAs per SM. 
+ // Use block configuration (kBlockM = 64, kBlockN = 64) for better memory alignment + run_flash_fwd, Is_causal>(params, stream); } - // run_flash_fwd, Is_causal>(params, stream); - // run_flash_fwd, Is_causal>(params, stream); - // run_flash_fwd, Is_causal>(params, stream); - // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k - // run_flash_fwd, Is_causal>(params, stream); - // run_flash_fwd, Is_causal>(params, stream); - // 1st ones are good for H100, A100 - // 2nd one is good for A6000 bc we get slightly better occupancy } template void run_mha_fwd_hdim192(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 192; - run_flash_fwd, Is_causal>(params, stream); - // run_flash_fwd, Is_causal>(params, stream); - // run_flash_fwd, Is_causal>(params, stream); - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); + run_flash_fwd, Is_causal>(params, stream); } template @@ -249,12 +260,8 @@ void run_mha_fwd_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream) { if (status_ != cudaSuccess) { C10_CUDA_CHECK(status_); } - // printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block); - - // For A100, we want to run with 64 x 64 (112KB smem). - // For H100 we want to run with 64 x 32 (72KB smem) since then we can get 2 CTAs per SM. - if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) { - run_flash_fwd, Is_causal>(params, stream); + if (max_smem_per_block >= 224 * 1024) { + run_flash_fwd, Is_causal>(params, stream); } else { run_flash_fwd, Is_causal>(params, stream); } diff --git a/csrc/src/kernel_traits.h b/csrc/src/kernel_traits.h index f7a38d2..3319f71 100644 --- a/csrc/src/kernel_traits.h +++ b/csrc/src/kernel_traits.h @@ -73,6 +73,7 @@ struct Flash_fwd_kernel_traits : public Base { static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); static constexpr int kSwizzle = kBlockKSmem == 32 ? 
2 : 3; + static constexpr int kSwizzlePS = 3; using TiledMma = TiledMMA< typename Base::MMA_Atom_Arch, @@ -89,18 +90,11 @@ struct Flash_fwd_kernel_traits : public Base { Stride, _1>>{} ) ); - using SmemLayoutAtomMask = decltype( + using SmemLayoutAtomPS = decltype( composition( - Swizzle{}, - Layout, - Stride<_8, _1>>{} - ) - ); - using SmemLayoutAtomBias = decltype( - composition( - Swizzle{}, - Layout, - Stride<_8, _1>>{} + Swizzle{}, + Layout, Int>, + Stride, _1>>{} ) ); @@ -127,20 +121,13 @@ struct Flash_fwd_kernel_traits : public Base { ); using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{})); - using SmemLayoutMask = decltype( - tile_to_shape( - SmemLayoutAtomMask{}, - Shape, Int>{} - ) - ); - using SmemCopyAtomMask = Copy_Atom, Element>; - using SmemLayoutBias = decltype( + using SmemLayoutPS = decltype( tile_to_shape( - SmemLayoutAtomBias{}, + SmemLayoutAtomPS{}, Shape, Int>{} ) ); - using SmemCopyAtomBias = Copy_Atom, Element>; + using SmemCopyAtomPS = Copy_Atom, Element>; // Shared memory layout for output using SmemLayoutAtomO = decltype( @@ -162,8 +149,8 @@ struct Flash_fwd_kernel_traits : public Base { // Shared memory size calculations static constexpr int kSmemQSize = size(SmemLayoutQ{}) * sizeof(Element); static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); - static constexpr int kSmemMaskSize = size(SmemLayoutMask{}) * sizeof(Element); - static constexpr int kSmemBiasSize = size(SmemLayoutBias{}) * sizeof(Element); + static constexpr int kSmemMaskSize = size(SmemLayoutPS{}) * sizeof(Element); + static constexpr int kSmemBiasSize = size(SmemLayoutPS{}) * sizeof(Element); // Shared memory size with QKV matrices and mask/bias matrices static constexpr int kSmemSize = (Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize) + kSmemMaskSize + kSmemBiasSize; @@ -196,20 +183,13 @@ struct Flash_fwd_kernel_traits : public Base { Layout>{} ) ); // Val layout, 8 vals per read - using GmemTiledCopyMask = decltype( - make_tiled_copy( - Copy_Atom, Element>{}, - GmemLayoutAtom{}, - Layout>{} - ) - ); // Val layout, 4 vals per read - using GmemTiledCopyBias = decltype( + using GmemTiledCopyMaskBias = decltype( make_tiled_copy( - Copy_Atom, Element>{}, + Copy_Atom, Element>{}, GmemLayoutAtom{}, - Layout>{} + Layout>{} ) - ); // Val layout, 4 vals per read + ); // Val layout, 8 vals per read using GmemTiledCopyO = decltype( make_tiled_copy( Copy_Atom, Element>{},
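
The recurring Python change in `benchmarks/backward_equivalence.py`, `forward_equivalence.py`, and `forward_performance.py` replaces `torch.topk(...).indices` with a call that also keeps the top-k values, so that positions whose bias already equals the dtype minimum (fully masked positions) are not re-activated just because top-k returned their indices. The sketch below is a minimal standalone re-implementation of that hunk for illustration only: the function signature is simplified, and `min_dtype` is computed locally rather than taken from the benchmark harness.

```python
import torch


def prepare_dynamic_mask_sketch(attn_bias: torch.Tensor, keep_window_size: int):
    """Illustrative sketch of the updated top-k masking step from the benchmarks.

    The old code scattered a constant 1.0 at the top-k indices, so a row with
    fewer than keep_window_size valid keys would still "keep" slots whose bias
    is min_dtype. The new code scatters valid_topk instead.
    """
    dtype = attn_bias.dtype
    min_dtype = torch.finfo(dtype).min

    if attn_bias.shape[-1] > keep_window_size:
        topk_values, topk_indices = torch.topk(
            attn_bias, keep_window_size, dim=-1, largest=True, sorted=False
        )
        # Only mark a top-k slot as kept if its score is not the mask fill value.
        valid_topk = topk_values != min_dtype
        attn_mask = torch.zeros_like(attn_bias, dtype=dtype, device=attn_bias.device)
        attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk.to(dtype))
        attn_bias = attn_bias.masked_fill(attn_mask == 0.0, min_dtype)
    else:
        attn_mask = torch.ones_like(attn_bias, dtype=dtype, device=attn_bias.device)
    return attn_mask, attn_bias


if __name__ == "__main__":
    # Rows with only 3 valid keys must not end up with 4 "kept" positions.
    bias = torch.full((1, 1, 4, 8), torch.finfo(torch.float32).min)
    bias[..., :3] = torch.randn(1, 1, 4, 3)
    mask, _ = prepare_dynamic_mask_sketch(bias, keep_window_size=4)
    assert mask.sum(dim=-1).max() <= 3
```

The small check at the end exercises exactly the case the patch fixes: with the old constant-1.0 scatter, each row of `mask` would sum to 4 even though only 3 keys are valid.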
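
The `flash_fwd_launch_template.h` changes replace the compute-capability branches with checks on `cudaDevAttrMaxSharedMemoryPerBlockOptin` against fixed thresholds (176, 224, 160, 192, and 224 KiB for head dims 32, 64, 96, 128, and 256). The per-CTA budget those thresholds guard follows the `kSmemSize` formula in `kernel_traits.h`: a Q tile, K and V tiles, and one mask plus one bias tile, now both sized by the shared `SmemLayoutPS`. The helper below is a back-of-envelope estimate under stated assumptions: it ignores swizzle padding and split-KV buffers, assumes 2-byte elements, and the tile sizes in the demo are hypothetical, since the kernel-trait template arguments are not visible in this extract of the diff.

```python
def fwd_smem_bytes(block_m: int, block_n: int, head_dim: int,
                   elem_bytes: int = 2, share_q_k: bool = False) -> int:
    """Rough per-CTA shared-memory footprint of the forward kernel.

    Mirrors kSmemSize in kernel_traits.h:
      (Share_Q_K_smem ? max(Q, KV) : Q + KV) + mask + bias
    """
    q = block_m * head_dim * elem_bytes
    kv = 2 * block_n * head_dim * elem_bytes
    mask = block_m * block_n * elem_bytes
    bias = block_m * block_n * elem_bytes
    qkv = max(q, kv) if share_q_k else q + kv
    return qkv + mask + bias


if __name__ == "__main__":
    # Compare a few candidate tiles for head_dim = 128 against the 192 KiB
    # opt-in limit that run_mha_fwd_hdim128 now checks.
    for bm, bn in [(128, 128), (128, 64), (64, 64)]:
        kib = fwd_smem_bytes(bm, bn, head_dim=128) / 1024
        print(f"hdim128, {bm}x{bn}: ~{kib:.0f} KiB of smem")
```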
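
For the split-KV path, `run_mha_fwd_splitkv_dispatch` keeps `kBlockM` fixed at 64 and only changes the `kBlockN` heuristic. The new ternary reads `Headdim <= 32 ? 128 : (Headdim <= 64 ? 64 : (Headdim < 128 ? 64 : 32))`; since the two middle branches both yield 64, it effectively maps head dims of at most 32 to 128, 33 through 127 to 64, and 128 or more to 32. A direct Python mirror of that expression, for reference only:

```python
def splitkv_block_n(head_dim: int) -> int:
    """Python mirror of the updated kBlockN ternary in run_mha_fwd_splitkv_dispatch."""
    if head_dim <= 32:
        return 128
    if head_dim < 128:   # both <=64 and <128 branches resolve to 64
        return 64
    return 32


assert [splitkv_block_n(d) for d in (32, 64, 96, 128, 256)] == [128, 64, 64, 32, 32]
```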