In [1]:
from tests.test_attention import flash_backward_results, _attention_and_lse, _make_attn_inputs
from cs336_systems.flashattention import FlashAttentionTriton, AttentionPytorch
from cs336_systems.flashattention_triton_autotune import FlashAttentionTritonAutotune
from cs336_systems.flashattention_triton_backward import FlashAttentionTritonBackward
from cs336_systems.flashattention_triton_optimized import FlashAttentionTritonOptimized

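# Sanity check (commented out): q/k gradients from the optimized Triton kernel should match the PyTorch reference.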
# impl_0 = AttentionPytorch.apply
# q_0, k_0, v_0, do_0 = _make_attn_inputs(device='cuda')
# o_0 = impl_0(q_0, k_0, v_0, True)
# o_0.backward(do_0)

# impl_1 = FlashAttentionTritonOptimized.apply
# q_1, k_1, v_1, do_1 = _make_attn_inputs(device='cuda')
# o_1 = impl_1(q_1, k_1, v_1, True)
# o_1.backward(do_1)
#
# import torch
# torch.testing.assert_close(q_0.grad, q_1.grad, rtol=1e-2, atol=1e-2)
# torch.testing.assert_close(k_0.grad, k_1.grad, rtol=1e-2, atol=1e-2)

In [None]:
class FlashAttention2:
    """Callable adapter that binds fixed ``B_q`` / ``B_k`` values to an implementation.

    ``impl`` must expose ``apply(Q, K, V, is_causal, B_q, B_k)``. The two block
    values are stored at construction time and appended to every call.
    (B_q / B_k are presumably the query / key tile sizes -- confirm against the
    kernel implementations.)
    """

    def __init__(self, impl, B_q=16, B_k=16):
        """Remember the implementation and the two block-size arguments."""
        self.impl, self.B_q, self.B_k = impl, B_q, B_k

    def __call__(self, Q, K, V, is_causal=True):
        """Forward ``Q``, ``K``, ``V`` and ``is_causal`` to ``impl.apply`` with the stored block sizes."""
        block_sizes = (self.B_q, self.B_k)
        return self.impl.apply(Q, K, V, is_causal, *block_sizes)

# # Usage (wrap one of the imported implementations, e.g. FlashAttentionTriton):
# flash_small = FlashAttention2(FlashAttentionTriton, B_q=8, B_k=8)
# o = flash_small(q, k, v, False)

In [None]:
import torch
import triton
def test_timing_flash_forward_backward(test, impl, n_heads, d_head, sequence_length, dtype=torch.bfloat16, device='cuda', B_q=16, B_k=16):
    q, k, v = torch.randn(
        3, n_heads, sequence_length, d_head, device=device, dtype=dtype, requires_grad=True
    )
    
    flash = torch.compile(FlashAttention2(impl, B_q, B_k))
    # sanity check; it would fail without compiling if precision in triton is not implemented right
    # flash = FlashAttention2(impl, B_q, B_k)
    
    def flash_forward():
        o = flash(q, k, v, True)

    def flash_forward_backward():
        o = flash(q, k, v, True)
        loss = o.sum()
        loss.backward()

    if test == "forward":
        results = triton.testing.do_bench(flash_forward, rep=1000, warmup=1000)
    elif test == "forward_backward":
        results = triton.testing.do_bench(flash_forward_backward, rep=1000, warmup=1000)
    else:
        raise ValueError("Wrong selection.")
    print(results)

In [None]:
# Time forward+backward of the baseline Triton kernel (16 heads, head dim 128, 16k tokens).
# The helper prints the timing itself; the trailing ';' keeps the cell from echoing a return value.
test_timing_flash_forward_backward(
    test="forward_backward",
    impl=FlashAttentionTriton,
    n_heads=16,
    d_head=128,
    sequence_length=16384,
    dtype=torch.bfloat16,
);

## The leaderboard

In [None]:
import torch
import triton
from cs336_systems.flashattention import FlashAttentionTriton, AttentionPytorch
from cs336_systems.flashattention_triton_autotune import FlashAttentionTritonAutotune
from cs336_systems.flashattention_triton_backward import FlashAttentionTritonBackward
from cs336_systems.flashattention_triton_optimized import FlashAttentionTritonOptimized

def test_timing_flash_forward_backward(impl):
    n_heads = 16
    d_head = 64
    sequence_length = 16384
    q, k, v = torch.randn(
        3, n_heads, sequence_length, d_head, device="cuda", dtype=torch.bfloat16, requires_grad=True
    )
    
    flash = torch.compile(impl.apply)

    def flash_forward_backward():
        o = flash(q, k, v, True)
        loss = o.sum()
        loss.backward()

    results = triton.testing.do_bench(flash_forward_backward, rep=10000, warmup=1000)
    # print(results)
    return results

# Benchmark every implementation with the leaderboard defaults and print one
# "<class name>: <timing>" line per implementation. The result is bound to
# `elapsed_ms` rather than `time` so the notebook-global name does not shadow
# the standard-library `time` module in later cells.
for impl in [
    AttentionPytorch, FlashAttentionTriton, FlashAttentionTritonAutotune,
    FlashAttentionTritonBackward, FlashAttentionTritonOptimized
]:
    elapsed_ms = test_timing_flash_forward_backward(impl)
    print(f"{impl.__name__}: {elapsed_ms:.3f}")

# Output (runtime in ms per forward+backward call, as returned by do_bench):
# AttentionPytorch: 32.250
# FlashAttentionTriton: 39.116
# FlashAttentionTritonAutotune: 27.936
# FlashAttentionTritonBackward: 43.512
# FlashAttentionTritonOptimized: 9.752