In [10]:
import torch
import torch.nn.functional as F

# Example shapes for tensors
batch_size = 2
num_heads = 4
seq_len = 5
head_dim = 64

# Example input tensors
query_states = torch.randn(batch_size, num_heads, seq_len, head_dim)
key_states = torch.randn(batch_size, num_heads, seq_len, head_dim)
value_states = torch.randn(batch_size, num_heads, seq_len, head_dim)
attention_mask = None

# map to bf16
import torch
query_states = query_states.to(torch.bfloat16).cuda()
key_states = key_states.to(torch.bfloat16).cuda()
value_states = value_states.to(torch.bfloat16).cuda()

# Using scaled_dot_product_attention
attn_output_sdpa = F.scaled_dot_product_attention(
    query_states,
    key_states,
    value_states,
    attn_mask=attention_mask,
    dropout_p=0,  # Example dropout value
    is_causal=False  # Assuming non-causal attention for simplicity
)

# Manual attention implementation
dk = query_states.size(-1)
scores = torch.matmul(query_states, key_states.transpose(-2, -1)) / torch.sqrt(torch.tensor(dk, dtype=torch.float32))

# Apply the mask
if attention_mask is not None:
    scores += attention_mask

# Softmax
attention_weights = F.softmax(scores, dim=-1)

# Weighted sum
attn_output_manual = torch.matmul(attention_weights, value_states)

from flash_attn import flash_attn_func
attn_output_flash = flash_attn_func(
    query_states.transpose(1,2), key_states.transpose(1,2), value_states.transpose(1,2), 0, softmax_scale=None, causal=False
)


In [11]:
attn_output_manual

tensor([[[[ 7.0312e-02, -1.4453e-01,  1.4844e+00,  ..., -1.9727e-01,
            1.4160e-01,  4.1211e-01],
          [ 5.4688e-01,  1.8457e-01,  1.5547e+00,  ...,  4.1016e-02,
           -2.4121e-01,  5.0391e-01],
          [ 8.8281e-01,  7.4219e-02,  1.7500e+00,  ..., -1.3672e-01,
            6.1035e-04,  2.9102e-01],
          [ 5.8594e-01, -3.7695e-01,  1.7891e+00,  ..., -4.0430e-01,
            1.2109e+00, -8.6975e-04],
          [ 1.0312e+00, -4.7070e-01,  1.9922e+00,  ..., -5.3125e-01,
            1.4844e+00, -1.1292e-02]],

         [[-2.2266e-01, -5.4688e-02, -5.2490e-02,  ..., -2.7734e-01,
            1.2256e-01,  6.6406e-01],
          [ 2.9297e-03,  1.6309e-01, -1.4551e-01,  ..., -3.2617e-01,
            6.6406e-02,  4.6875e-01],
          [ 3.2812e-01,  6.4453e-02, -9.3262e-02,  ..., -5.3516e-01,
           -3.1641e-01,  4.9219e-01],
          [ 4.9023e-01,  1.2891e-01, -3.6328e-01,  ..., -2.6758e-01,
           -3.2617e-01,  7.1094e-01],
          [-3.7109e-01, -9.4238e-02

In [15]:
attn_output_sdpa.shape

torch.Size([2, 4, 5, 64])

In [16]:
attn_output_flash.transpose(1,2)-attn_output_sdpa

tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  