In [13]:
import ctypes as ct
import numpy as np
import os

lib_path = os.path.join(os.getcwd(), "libflash_attention.so")
lib = ct.CDLL(lib_path)

HalfPtr = ct.POINTER(ct.c_uint16)
FloatPtr = ct.POINTER(ct.c_float)

lib.run_tensor_flash_attention_host_half.argtypes = [
    HalfPtr, HalfPtr, HalfPtr, HalfPtr,
    ct.c_int, ct.c_int, ct.c_int, ct.c_int, ct.c_int,
    ct.c_void_p,
    ct.POINTER(ct.c_float)
]
lib.run_tensor_flash_attention_host_half.restype = None

B, H, L, D, tile = 32, 16, 512, 128, 64
size = B * H * L * D

Q = (np.random.randn(size).astype(np.float16))
Q = Q + 0.01 * (np.random.randn(size).astype(np.float16))
K = (np.random.randn(size).astype(np.float16))
K = K + 0.01 * (np.random.randn(size).astype(np.float16))
V = (np.random.randn(size).astype(np.float16))
V = V + 0.01 * (np.random.randn(size).astype(np.float16))
O = np.zeros(size, dtype=np.float16)

Q_p = Q.ctypes.data_as(HalfPtr)
K_p = K.ctypes.data_as(HalfPtr)
V_p = V.ctypes.data_as(HalfPtr)
O_p = O.ctypes.data_as(HalfPtr)

elapsed = ct.c_float(0.0)
elapsed_p = ct.pointer(elapsed)

lib.run_tensor_flash_attention_host_half(
    Q_p, K_p, V_p, O_p,
    B, H, L, D, tile,
    None,
    elapsed_p
)

print("Elapsed ms:", elapsed.value)

Elapsed ms: 8.304415702819824


In [14]:
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

Q_t = torch.from_numpy(Q.reshape(B, H, L, D).astype(np.float16)).to(device)
K_t = torch.from_numpy(K.reshape(B, H, L, D).astype(np.float16)).to(device)
V_t = torch.from_numpy(V.reshape(B, H, L, D).astype(np.float16)).to(device)

try:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cuda.enable_flash_sdp(True)
except Exception:
    pass

O_ref = torch.nn.functional.scaled_dot_product_attention(
    Q_t, K_t, V_t, attn_mask=None, dropout_p=0.0, is_causal=False
)

O_ref_np = O_ref.detach().cpu().numpy().reshape(-1)
O_np = O

mae = np.mean(np.abs(O_ref_np.astype(np.float32) - O_np.astype(np.float32)))

Using device: cuda


In [15]:
import numpy as np

assert 'O_ref_np' in globals() and 'O' in globals(), "Reference or output not found"
O_kernel = O.astype(np.float32)
O_ref32 = O_ref_np.astype(np.float32)

diff = O_ref32 - O_kernel
abs_diff = np.abs(diff)

mae = float(np.mean(abs_diff))
nonzero_mask = np.abs(O_ref32) > 0
if np.any(nonzero_mask):
    mre = float(np.mean(np.abs(diff[nonzero_mask]) / np.abs(O_ref32[nonzero_mask])))
else:
    mre = float('nan')

In [16]:
import torch
import numpy as np

assert 'Q_t' in globals() and 'K_t' in globals() and 'V_t' in globals(), "Run the PyTorch prep cell first"

try:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cuda.enable_flash_sdp(True)
except Exception:
    pass

Q16 = Q_t.to(torch.float16)
K16 = K_t.to(torch.float16)
V16 = V_t.to(torch.float16)

O_flash16 = torch.nn.functional.scaled_dot_product_attention(
    Q16, K16, V16, attn_mask=None, dropout_p=0.0, is_causal=False
)

O_flash16_np = O_flash16.detach().cpu().numpy().reshape(-1).astype(np.float16)

O_kernel16 = O.astype(np.float16)
diff16 = O_flash16_np.astype(np.float16).astype(np.float32) - O_kernel16.astype(np.float16).astype(np.float32)
abs_diff16 = np.abs(diff16)

mae16 = float(np.mean(abs_diff16))
ref32 = O_flash16_np.astype(np.float32)
nonzero_mask16 = np.abs(ref32) > 0
mre16 = float(np.mean(np.abs(diff16[nonzero_mask16]) / np.abs(ref32[nonzero_mask16]))) if np.any(nonzero_mask16) else float('nan')

In [17]:
import numpy as np

def flash_attention_numpy(Q_bhl_d, K_bhl_d, V_bhl_d, tile: int, return_fp16: bool = True):
    assert Q_bhl_d.ndim == 4 and K_bhl_d.ndim == 4 and V_bhl_d.ndim == 4
    B, H, L, D = Q_bhl_d.shape
    assert K_bhl_d.shape == (B, H, L, D)
    assert V_bhl_d.shape == (B, H, L, D)
    assert tile > 0 and tile <= L

    Q = Q_bhl_d.astype(np.float32, copy=False)
    K = K_bhl_d.astype(np.float32, copy=False)
    V = V_bhl_d.astype(np.float32, copy=False)

    scale = 1.0 / np.sqrt(float(D))

    O = np.zeros((B, H, L, D), dtype=np.float32)

    m_i = np.full((B, H, L), -np.inf, dtype=np.float32)
    l_i = np.zeros((B, H, L), dtype=np.float32)
    o_i = np.zeros((B, H, L, D), dtype=np.float32)

    for start in range(0, L, tile):
        end = min(start + tile, L)
        K_tile = K[:, :, start:end, :]
        V_tile = V[:, :, start:end, :]

        S = np.einsum('bhld,bhTd->bh lT', Q, K_tile, optimize=True) * scale
        m_ij = np.max(S, axis=-1)
        m_new = np.maximum(m_i, m_ij)
        S_shift = S - m_new[..., None]
        P = np.exp(S_shift)
        alpha = np.exp(m_i - m_new)
        l_new = alpha * l_i + np.sum(P, axis=-1)
        PV = np.einsum('bh lT,bhTd->bh ld', P, V_tile, optimize=True)
        o_new = alpha[..., None] * o_i + PV
        m_i = m_new
        l_i = l_new
        o_i = o_new

    O = o_i / l_i[..., None]

    if return_fp16:
        return O.astype(np.float16)
    return O

try:
    assert 'Q' in globals() and 'K' in globals() and 'V' in globals()
    B_, H_, L_, D_ = B, H, L, D
    Q4 = Q.reshape(B_, H_, L_, D_)
    K4 = K.reshape(B_, H_, L_, D_)
    V4 = V.reshape(B_, H_, L_, D_)

    tile_size = 64
    O_np_flash16 = flash_attention_numpy(Q4, K4, V4, tile=tile_size, return_fp16=True)
    O_np_flash32 = O_np_flash16.astype(np.float32)

    O_kernel16 = O.astype(np.float16)
    O_kernel32 = O_kernel16.astype(np.float32)

    diff_k = O_np_flash32.reshape(-1) - O_kernel32.reshape(-1)
    abs_diff_k = np.abs(diff_k)
    mae_k = float(np.mean(abs_diff_k))
    ref32_k = O_kernel32.reshape(-1)
    nz_k = np.abs(ref32_k) > 0
    mre_k = float(np.mean(np.abs(diff_k[nz_k]) / np.abs(ref32_k[nz_k]))) if np.any(nz_k) else float('nan')

    print('NumPy vs Kernel:')
    print(f'  MAE:   {mae_k}')
    print(f'  MRE:   {mre_k}')

    if 'O_flash16_np' in globals():
        O_torch32 = O_flash16_np.astype(np.float32)
        diff_t = O_np_flash32.reshape(-1) - O_torch32.reshape(-1)
        abs_diff_t = np.abs(diff_t)
        mae_t = float(np.mean(abs_diff_t))
        ref32_t = O_torch32.reshape(-1)
        nz_t = np.abs(ref32_t) > 0
        mre_t = float(np.mean(np.abs(diff_t[nz_t]) / np.abs(ref32_t[nz_t]))) if np.any(nz_t) else float('nan')

        print('NumPy vs PyTorch Flash SDP:')
        print(f'  MAE:   {mae_t}')
        print(f'  MRE:   {mre_t}')
    else:
        print('PyTorch Flash SDP output not found; run the PyTorch Flash cell first to compare.')
except Exception as e:
    print('Skip comparisons until inputs are prepared:', e)

NumPy vs Kernel:
  MAE:   5.7072058552876115e-05
  MRE:   0.006194089539349079
NumPy vs PyTorch Flash SDP:
  MAE:   1.0230122825305443e-05
  MRE:   0.0016731568612158298
NumPy vs PyTorch Flash SDP:
  MAE:   1.0230122825305443e-05
  MRE:   0.0016731568612158298


In [None]:


import numpy as np

summary = {}

if 'O_np_flash32' in globals():
    O_kernel32 = O.astype(np.float32)
    diff_kn = O_kernel32.reshape(-1) - O_np_flash32.reshape(-1)
    ad_kn = np.abs(diff_kn)
    summary['kernel_vs_numpy'] = {
        'MAE': float(np.mean(ad_kn)),
        'MRE': float(np.mean(np.abs(diff_kn[np.abs(O_kernel32.reshape(-1))>0]) / np.abs(O_kernel32.reshape(-1)[np.abs(O_kernel32.reshape(-1))>0]))) if np.any(np.abs(O_kernel32.reshape(-1))>0) else float('nan'),
    }

if 'O_flash16_np' in globals():
    O_torch32 = O_flash16_np.astype(np.float32)
    diff_kt = O_kernel32.reshape(-1) - O_torch32.reshape(-1)
    ad_kt = np.abs(diff_kt)
    summary['kernel_vs_torch'] = {
        'MAE': float(np.mean(ad_kt)),
        'MRE': float(np.mean(np.abs(diff_kt[np.abs(O_torch32.reshape(-1))>0]) / np.abs(O_torch32.reshape(-1)[np.abs(O_torch32.reshape(-1))>0]))) if np.any(np.abs(O_torch32.reshape(-1))>0) else float('nan'),
    }

if 'O_np_flash32' in globals() and 'O_flash16_np' in globals():
    O_torch32 = O_flash16_np.astype(np.float32)
    diff_nt = O_np_flash32.reshape(-1) - O_torch32.reshape(-1)
    ad_nt = np.abs(diff_nt)
    summary['numpy_vs_torch'] = {
        'MAE': float(np.mean(ad_nt)),
        'MRE': float(np.mean(np.abs(diff_nt[np.abs(O_torch32.reshape(-1))>0]) / np.abs(O_torch32.reshape(-1)[np.abs(O_torch32.reshape(-1))>0]))) if np.any(np.abs(O_torch32.reshape(-1))>0) else float('nan'),
    }
    
for k, v in summary.items():
    print(k)
    for mk, mv in v.items():
        print(f"  {mk}: {mv}")

kernel_vs_numpy
  MAE: 5.7072058552876115e-05
  MRE: 0.006194089539349079
kernel_vs_torch
  MAE: 6.050314550520852e-05
  MRE: 0.007585969753563404
numpy_vs_torch
  MAE: 1.0230122825305443e-05
  MRE: 0.0016731568612158298
