In [1]:
import ctypes as ct
import numpy as np
import os

# Load the shared library from the current working directory (adjust path if needed).
lib_path = os.path.join(os.getcwd(), "libflash_attention.so")
lib = ct.CDLL(lib_path)

# Define types
# __half* -> use POINTER(c_uint16) (binary-compatible with float16: both are 2-byte storage)
HalfPtr = ct.POINTER(ct.c_uint16)
FloatPtr = ct.POINTER(ct.c_float)

# Function signatures (extern "C" in wrapper.cuh)
# void run_tensor_flash_attention_host_half(
#     const __half* hQ, const __half* hK, const __half* hV, __half* hO,
#     int B, int H, int L, int D, int tile, cudaStream_t stream = 0, float* elapsed_ms = nullptr);
# C default arguments are not visible through ctypes, so every parameter is declared
# and passed explicitly.
lib.run_tensor_flash_attention_host_half.argtypes = [
    HalfPtr, HalfPtr, HalfPtr, HalfPtr,
    ct.c_int, ct.c_int, ct.c_int, ct.c_int, ct.c_int,
    ct.c_void_p,  # stream (nullptr)
    FloatPtr,     # elapsed_ms (nullable)
]
lib.run_tensor_flash_attention_host_half.restype = None

# Prepare inputs
B, H, L, D, tile = 32, 16, 512, 128, 64
size = B * H * L * D

# Seed the generator so that Restart & Run All reproduces the same inputs, and therefore
# the same error metrics in the comparison cells below.
SEED = 0
rng = np.random.default_rng(SEED)

# Create flat numpy float16 buffers. Sampling directly in float32 halves the size of the
# temporary array compared with the float64 default before the cast to float16.
Q = rng.standard_normal(size, dtype=np.float32).astype(np.float16)
K = rng.standard_normal(size, dtype=np.float32).astype(np.float16)
V = rng.standard_normal(size, dtype=np.float32).astype(np.float16)
O = np.zeros(size, dtype=np.float16)

# Get ctypes pointers (uint16 underlying storage). The pointers do not own the memory,
# so Q, K, V and O must stay alive for the duration of the call.
Q_p = Q.ctypes.data_as(HalfPtr)
K_p = K.ctypes.data_as(HalfPtr)
V_p = V.ctypes.data_as(HalfPtr)
O_p = O.ctypes.data_as(HalfPtr)

# elapsed time capture (optional)
elapsed = ct.c_float(0.0)
elapsed_p = ct.pointer(elapsed)

# Call the function (stream = None/0); the library writes its result into O in place.
lib.run_tensor_flash_attention_host_half(
    Q_p, K_p, V_p, O_p,
    B, H, L, D, tile,
    None,
    elapsed_p
)

print("Elapsed ms:", elapsed.value)
print("Output sample:", O[:10])

Elapsed ms: 7.799744129180908
Output sample: [-0.0359   -0.06085  -0.002062 -0.0999   -0.1879   -0.015526  0.1666
  0.06696   0.03867  -0.1951  ]


In [2]:
import torch

# Ensure PyTorch uses GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

# View the flat buffers as (B, H, L, D). `astype(..., copy=False)` only converts when the
# dtype differs, so float16 inputs are handed to torch without an extra host-side copy.
Q_t = torch.from_numpy(Q.reshape(B, H, L, D).astype(np.float16, copy=False)).to(device)
K_t = torch.from_numpy(K.reshape(B, H, L, D).astype(np.float16, copy=False)).to(device)
V_t = torch.from_numpy(V.reshape(B, H, L, D).astype(np.float16, copy=False)).to(device)

# Compute scaled dot-product attention: softmax((Q @ K^T) / sqrt(D)) @ V
# The math is carried out in float32 and only the final result is rounded to float16.
scale = 1.0 / np.sqrt(D)
# (B,H,L,D) x (B,H,D,L) -> (B,H,L,L)
scores = torch.matmul(Q_t.to(torch.float32), K_t.transpose(-1, -2).to(torch.float32)) * scale
attn = torch.softmax(scores, dim=-1)
O_ref = torch.matmul(attn, V_t.to(torch.float32))  # (B,H,L,D) in float32
O_ref = O_ref.to(torch.float16)

# Flatten to match O's layout
O_ref_np = O_ref.detach().cpu().numpy().reshape(-1)

# Compare against kernel output O (numpy float16 flat)
O_np = O  # already numpy float16 flat

# Guard against silent broadcasting if the two buffers ever disagree in size.
assert O_ref_np.shape == O_np.shape, "Reference and kernel output sizes differ"

# Compute mean absolute error (differences taken in float32)
mae = np.mean(np.abs(O_ref_np.astype(np.float32) - O_np.astype(np.float32)))

Using device: cuda


In [3]:
import numpy as np

# The reference (O_ref_np) and the kernel output (O) must already exist in this session.
assert all(name in globals() for name in ('O_ref_np', 'O')), "Reference or output not found"

# Promote both arrays to float32 before subtracting.
O_ref32 = O_ref_np.astype(np.float32)
O_kernel = O.astype(np.float32)

# Signed and absolute element-wise error (reference minus kernel).
diff = O_ref32 - O_kernel
abs_diff = np.abs(diff)

# Mean absolute error over every element.
mae = float(abs_diff.mean())

# Mean relative error, restricted to elements whose reference value is non-zero so the
# division is always defined; NaN when no such element exists.
nonzero_mask = np.abs(O_ref32) > 0
mre = (
    float(np.mean(abs_diff[nonzero_mask] / np.abs(O_ref32[nonzero_mask])))
    if nonzero_mask.any()
    else float('nan')
)


In [4]:
import torch
import numpy as np

# The (B, H, L, D) tensors from the PyTorch prep cell must exist before this cell runs.
assert 'Q_t' in globals() and 'K_t' in globals() and 'V_t' in globals(), "Run the PyTorch prep cell first"

# Enable PyTorch Flash SDP when available. This is best-effort: a failure here is
# reported (not hidden) and the attention call below still runs.
try:
    torch.backends.cuda.matmul.allow_tf32 = True  # optional for perf
    torch.backends.cuda.enable_flash_sdp(True)
except Exception as exc:
    print('Could not enable the Flash SDP backend; continuing with defaults:', exc)

# Use scaled_dot_product_attention, which may dispatch to a FlashAttention backend
# depending on the device and PyTorch build.
# Input shapes: (B, H, L, D)
Q16 = Q_t.to(torch.float16)
K16 = K_t.to(torch.float16)
V16 = V_t.to(torch.float16)

# PyTorch API expects (B,H,L,D) and will compute attention along L
O_flash16 = torch.nn.functional.scaled_dot_product_attention(
    Q16, K16, V16, attn_mask=None, dropout_p=0.0, is_causal=False
)

# Flatten and move to CPU; copy=False avoids re-copying data that is already float16.
O_flash16_np = O_flash16.detach().cpu().numpy().reshape(-1).astype(np.float16, copy=False)

# Kernel output as float16 (no copy when O already has that dtype)
O_kernel16 = O.astype(np.float16, copy=False)

# Differences are taken in float32 so small float16 errors do not underflow.
ref32 = O_flash16_np.astype(np.float32)
diff16 = ref32 - O_kernel16.astype(np.float32)
abs_diff16 = np.abs(diff16)

mae16 = float(np.mean(abs_diff16))
# Mean relative error over elements with a non-zero reference value; NaN if there are none.
nonzero_mask16 = np.abs(ref32) > 0
mre16 = float(np.mean(abs_diff16[nonzero_mask16] / np.abs(ref32[nonzero_mask16]))) if np.any(nonzero_mask16) else float('nan')


In [5]:
import numpy as np

def flash_attention_numpy(Q_bhl_d, K_bhl_d, V_bhl_d, tile: int, return_fp16: bool = True):
    """Blockwise, numerically stable attention in NumPy (FlashAttention-style).

    Computes ``softmax(Q @ K^T / sqrt(D)) @ V`` by streaming over key/value tiles of
    length ``tile`` along the sequence axis. For every query row a running maximum,
    a running softmax normaliser and an unnormalised output accumulator are kept and
    rescaled whenever a new tile raises the maximum, so the full (L, L) score matrix
    is never materialised. Reference: FlashAttention (Dao et al.).

    Args:
        Q_bhl_d: Queries, array of shape (B, H, L, D).
        K_bhl_d: Keys, array with the same shape as ``Q_bhl_d``.
        V_bhl_d: Values, array with the same shape as ``Q_bhl_d``.
        tile: Number of key/value positions processed per step; ``0 < tile <= L``.
            The last tile is shorter when ``tile`` does not divide ``L``.
        return_fp16: When True (default) the result is cast to float16; otherwise
            the float32 result is returned.

    Returns:
        Array of shape (B, H, L, D), float16 or float32 depending on ``return_fp16``.

    Raises:
        AssertionError: If an input is not 4-D, the shapes differ, or ``tile`` is
            outside ``(0, L]``.
    """
    assert Q_bhl_d.ndim == 4 and K_bhl_d.ndim == 4 and V_bhl_d.ndim == 4, "Expect (B,H,L,D)"
    B, H, L, D = Q_bhl_d.shape
    assert K_bhl_d.shape == (B, H, L, D)
    assert V_bhl_d.shape == (B, H, L, D)
    assert tile > 0 and tile <= L

    # Work in float32 for stability; the cast to float16 happens only at the end.
    Q = Q_bhl_d.astype(np.float32, copy=False)
    K = K_bhl_d.astype(np.float32, copy=False)
    V = V_bhl_d.astype(np.float32, copy=False)

    # Keep the scale as a float32 scalar: an np.float64 scalar would promote every
    # intermediate (and the result) to float64 under NumPy 2 promotion rules.
    scale = np.float32(1.0 / np.sqrt(float(D)))

    # Running state per (b, h, query position):
    #   m_i: maximum score seen so far,
    #   l_i: sum of exp(score - m_i) seen so far,
    #   o_i: sum of exp(score - m_i) * V seen so far (unnormalised output).
    m_i = np.full((B, H, L), -np.inf, dtype=np.float32)
    l_i = np.zeros((B, H, L), dtype=np.float32)
    o_i = np.zeros((B, H, L, D), dtype=np.float32)

    # Iterate over K/V tiles along the sequence axis.
    for start in range(0, L, tile):
        end = min(start + tile, L)
        K_tile = K[:, :, start:end, :]   # (B,H,T,D)
        V_tile = V[:, :, start:end, :]   # (B,H,T,D)

        # Scores of every query against this tile: (B,H,L,D) @ (B,H,D,T) -> (B,H,L,T)
        S = np.matmul(Q, np.swapaxes(K_tile, -1, -2)) * scale

        # New running maximum after seeing this tile.
        m_new = np.maximum(m_i, S.max(axis=-1))  # (B,H,L)

        # Tile probabilities, shifted by the new maximum so exp() cannot overflow.
        P = np.exp(S - m_new[..., None])  # (B,H,L,T)

        # alpha rescales the previously accumulated sums from the old maximum to the new
        # one; it is exp(-inf) == 0 on the first tile, where nothing has been accumulated.
        alpha = np.exp(m_i - m_new)  # (B,H,L)

        l_i = alpha * l_i + P.sum(axis=-1)                    # (B,H,L)
        o_i = alpha[..., None] * o_i + np.matmul(P, V_tile)   # (B,H,L,D)
        m_i = m_new

    # Normalise the accumulated values by the softmax denominator.
    O = o_i / l_i[..., None]

    if return_fp16:
        return O.astype(np.float16)
    return O

# Compare the NumPy FlashAttention output with the kernel output and, when present, with
# the PyTorch SDPA output. The comparison is skipped only when the inputs from the first
# cell are missing; any other failure is allowed to surface instead of being reported as
# "inputs not prepared".
_required_names = ('Q', 'K', 'V', 'O', 'B', 'H', 'L', 'D')
_missing_names = [name for name in _required_names if name not in globals()]

if _missing_names:
    print('Skip comparisons until inputs are prepared:', 'missing ' + ', '.join(_missing_names))
else:
    B_, H_, L_, D_ = B, H, L, D
    Q4 = Q.reshape(B_, H_, L_, D_)
    K4 = K.reshape(B_, H_, L_, D_)
    V4 = V.reshape(B_, H_, L_, D_)

    # Match the kernel's tile size when the first cell defined it; otherwise use 64.
    tile_size = globals().get('tile', 64)

    # NumPy FlashAttention output (rounded to float16, then widened for the error math)
    O_np_flash16 = flash_attention_numpy(Q4, K4, V4, tile=tile_size, return_fp16=True)
    O_np_flash32 = O_np_flash16.astype(np.float32)

    # Kernel output prepared earlier
    O_kernel16 = O.astype(np.float16)
    O_kernel32 = O_kernel16.astype(np.float32)

    # NumPy FlashAttention vs kernel; the relative error uses the kernel output as the
    # denominator and ignores elements where it is exactly zero.
    diff_k = O_np_flash32.reshape(-1) - O_kernel32.reshape(-1)
    abs_diff_k = np.abs(diff_k)
    mae_k = float(np.mean(abs_diff_k))
    ref32_k = O_kernel32.reshape(-1)
    nz_k = np.abs(ref32_k) > 0
    mre_k = float(np.mean(abs_diff_k[nz_k] / np.abs(ref32_k[nz_k]))) if np.any(nz_k) else float('nan')

    # NumPy FlashAttention vs PyTorch Flash SDP if available; the relative error uses the
    # PyTorch output as the denominator.
    if 'O_flash16_np' in globals():
        O_torch32 = O_flash16_np.astype(np.float32)
        diff_t = O_np_flash32.reshape(-1) - O_torch32.reshape(-1)
        abs_diff_t = np.abs(diff_t)
        mae_t = float(np.mean(abs_diff_t))
        ref32_t = O_torch32.reshape(-1)
        nz_t = np.abs(ref32_t) > 0
        mre_t = float(np.mean(abs_diff_t[nz_t] / np.abs(ref32_t[nz_t]))) if np.any(nz_t) else float('nan')
    else:
        print('PyTorch Flash SDP output not found; run the PyTorch Flash cell first to compare.')

In [None]:
# Summary: Precision comparison
# - Kernel vs NumPy FlashAttention
# - Kernel vs PyTorch Flash SDP
# - NumPy vs PyTorch Flash SDP

import numpy as np

def mae_mre(abs_diff, reference):
    """Return ``{'MAE': ..., 'MRE': ...}`` for an absolute-error array.

    Args:
        abs_diff: Element-wise absolute differences.
        reference: Array of the same shape used as the denominator of the relative error.

    Returns:
        Dict with the mean absolute error and the mean relative error. The relative
        error is averaged only over elements where ``reference`` is non-zero and is
        NaN when there is no such element.
    """
    nonzero = np.abs(reference) > 0
    if np.any(nonzero):
        mre = float(np.mean(abs_diff[nonzero] / np.abs(reference[nonzero])))
    else:
        mre = float('nan')
    return {'MAE': float(np.mean(abs_diff)), 'MRE': mre}


summary = {}

# Work out which results exist in this session, so that every comparison below depends
# only on the arrays it actually uses.
has_kernel = 'O' in globals()
has_numpy = 'O_np_flash32' in globals()
has_torch = 'O_flash16_np' in globals()

# The kernel output is shared by two comparisons, so it is prepared up front rather than
# inside one of the branches.
if has_kernel:
    O_kernel32 = O.astype(np.float32)
if has_torch:
    O_torch32 = O_flash16_np.astype(np.float32)

# Kernel vs NumPy (relative error measured against the kernel output)
if has_kernel and has_numpy:
    diff_kn = O_kernel32.reshape(-1) - O_np_flash32.reshape(-1)
    ad_kn = np.abs(diff_kn)
    summary['kernel_vs_numpy'] = mae_mre(ad_kn, O_kernel32.reshape(-1))

# Kernel vs PyTorch (relative error measured against the PyTorch output)
if has_kernel and has_torch:
    diff_kt = O_kernel32.reshape(-1) - O_torch32.reshape(-1)
    ad_kt = np.abs(diff_kt)
    summary['kernel_vs_torch'] = mae_mre(ad_kt, O_torch32.reshape(-1))

# NumPy vs PyTorch (relative error measured against the PyTorch output)
if has_numpy and has_torch:
    diff_nt = O_np_flash32.reshape(-1) - O_torch32.reshape(-1)
    ad_nt = np.abs(diff_nt)
    summary['numpy_vs_torch'] = mae_mre(ad_nt, O_torch32.reshape(-1))

for k, v in summary.items():
    print(k)
    for mk, mv in v.items():
        print(f"  {mk}: {mv}")

kernel_vs_numpy
  MAE: 6.86005296302028e-05
  MRE: 0.008922836743295193
kernel_vs_torch
  MAE: 7.20318712410517e-05
  MRE: 0.009019171819090843
numpy_vs_torch
  MAE: 1.023419918055879e-05
  MRE: 0.0016797683201730251
