Optimized WMMA (Wave Matrix Multiply-Accumulate) operations for AMD gfx1151 (RDNA3.5 / Strix Halo) architecture
Based on llama.cpp rocWMMA optimizations (PR #16827) and Sébastien Vince's Deep Dive into Matrix Optimization on AMD GPUs.
- Executive Summary
- Performance Results
- Building
- Usage
- Testing
- Architecture Details
- Kernel Variants
- Key Optimizations
- Optimization Techniques
- Profiling and Analysis
- Remaining Gap to rocBLAS
- File Structure
- References
Successfully implemented and optimized a rocWMMA GEMM kernel for PyTorch targeting gfx1151, achieving 21.6 TFLOPS peak (36% of theoretical 59.4 TFLOPS) with correct results across all test configurations.
This represents a 4× improvement over the initial implementation (5.4 → 21.6 TFLOPS) through systematic optimization, reaching 53% of rocBLAS FP16 performance.
| Metric | Value |
|---|---|
| Peak TFLOPS | 21.6 (4096×4096) |
| Utilization | 36% of 59.4 TFLOPS peak |
| vs rocBLAS FP16 | 53% (rocBLAS: ~41 TFLOPS) |
| Improvement | 4× over baseline |
| Correctness | ✅ All tests pass (rel_err < 1%) |
| Configuration | WMMA TFLOPS | % Peak | rocBLAS FP16 | % of rocBLAS | Status |
|---|---|---|---|---|---|
| 512×512×512 | 12.5 | 21.0% | 20.6 | 60% | ✅ |
| 1024×1024×1024 | 14.6 | 24.6% | 37.0 | 39% | ✅ |
| 2048×2048×2048 | 20.0 | 33.7% | 38.5 | 52% | ✅ |
| 4096×4096×4096 | 21.6 | 36.4% | 41.0 | 53% | ✅ |
rocBLAS achieves 69% of peak (41/59.4 TFLOPS) while our kernel achieves 36% of peak.
All 13 kernel variants pass correctness tests (< 1% relative error):
| Kernel | Time (ms) | TFLOPS | % of Peak | Status |
|---|---|---|---|---|
| matmul_zerocopy | 3.33 | 20.61 | 34.7% | ✅ Best |
| matmul_adaptive | 3.35 | 20.52 | 34.5% | ✅ |
| matmul_asmOpt | 3.35 | 20.49 | 34.5% | ✅ |
| matmul | 3.36 | 20.46 | 34.4% | ✅ |
| matmul_coop | 3.44 | 19.98 | 33.6% | ✅ New |
| matmul_native | 3.45 | 19.92 | 33.5% | ✅ |
| matmul_kunroll | 3.75 | 18.31 | 30.8% | ✅ |
| matmul_swizzled | 3.84 | 17.90 | 30.1% | ✅ |
| matmul_noPrefetch | 3.90 | 17.60 | 29.6% | ✅ |
| matmul_xor_optimized | 3.98 | 17.26 | 29.1% | ✅ |
| matmul_quad | 4.04 | 17.01 | 28.6% | ✅ |
| matmul_hilbert | 5.16 | 13.33 | 22.4% | ✅ |
| matmul_highOcc | 6.68 | 10.28 | 17.3% | ✅ |
| PyTorch (reference) | 1.92 | 35.86 | 60.4% | — |
Peak Theoretical: 59.4 TFLOPS (gfx1151 FP16 WMMA)
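The TFLOPS and % of Peak columns follow from the standard GEMM operation count, 2·M·N·K, divided by the measured time. The sketch below reproduces them under two stated assumptions: the 512 FLOPs per CU per clock constant is inferred from 59.4 TFLOPS ÷ (40 CUs × 2.9 GHz) rather than quoted from a spec, and the problem shape is taken to be M = N = 4096, K = 2048, which the listed times appear consistent with.

```python
CUS = 40                 # gfx1151 compute units (architecture table)
CLOCK_HZ = 2.9e9         # peak clock
FLOPS_PER_CU_CLK = 512   # inferred: 59.4e12 / (40 * 2.9e9); an assumption, not a spec value


def peak_tflops(cus: int = CUS, clock_hz: float = CLOCK_HZ,
                flops_per_cu_clk: int = FLOPS_PER_CU_CLK) -> float:
    """Theoretical FP16 WMMA peak in TFLOPS."""
    return cus * clock_hz * flops_per_cu_clk / 1e12


def achieved_tflops(m: int, n: int, k: int, time_ms: float) -> float:
    """Achieved TFLOPS for an M x N x K GEMM (2*M*N*K FLOPs) that took time_ms milliseconds."""
    return 2 * m * n * k / (time_ms * 1e-3) / 1e12


if __name__ == "__main__":
    peak = peak_tflops()
    for name, ms in (("matmul_zerocopy", 3.33), ("PyTorch", 1.92)):
        t = achieved_tflops(4096, 4096, 2048, ms)  # assumed shape: 4096x4096x2048
        print(f"{name}: {t:.2f} TFLOPS ({100 * t / peak:.1f}% of {peak:.1f})")
```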
- matmul_zerocopy performs best (20.61 TFLOPS, 57.5% of PyTorch)
  - Swizzled B matrix with zero-copy stores
  - Best for large matrices
- Top 4 kernels are within 1% of each other (~20.5 TFLOPS)
  - zerocopy, adaptive, asmOpt, and standard all perform similarly
  - Use `matmul_adaptive` for automatic selection
- XOR swizzle kernels now work correctly
  - `matmul_swizzled` and `matmul_xor_optimized` both pass
  - Fixed fragment loading pattern (load ROW, not column)
- HighOcc underperforms (10.28 TFLOPS)
  - Lower register pressure doesn't compensate for reduced compute intensity
  - Not recommended for current workloads
All correctness tests pass with acceptable FP16 precision:
| Test | Relative Error | Status |
|---|---|---|
| 512×512×64 | < 0.04% | ✅ |
| 2048×2048×128 | < 0.03% | ✅ |
| 4096×4096×2048 | < 0.03% | ✅ |
| GEMM α=2.0, β=0.5 | < 0.001% | ✅ |
| GEMM in-place | < 0.001% | ✅ |
The project includes a Docker environment with ROCm 7.10, PyTorch, and all dependencies pre-configured for gfx1151.
1. Build the Docker image:
```bash
cd /path/to/wmma_ops
docker compose -f docker/docker-compose.benchmark.yml build
```

2. Run the build and test suite:

```bash
# Create .env file if it doesn't exist (required by docker-compose)
touch .env

# Build and test
docker compose -f docker/docker-compose.benchmark.yml run --rm benchmark \
  bash -c "export LD_LIBRARY_PATH=/opt/venv/lib/python3.12/site-packages/torch/lib:\$LD_LIBRARY_PATH && \
  cd /workspace/wmma_ops && ./build_and_test.sh"
```

3. Interactive development:

```bash
# Start an interactive shell in the container
docker compose -f docker/docker-compose.benchmark.yml run --rm benchmark bash

# Inside the container:
export LD_LIBRARY_PATH=/opt/venv/lib/python3.12/site-packages/torch/lib:$LD_LIBRARY_PATH
cd /workspace/wmma_ops
pip install -e . --no-build-isolation
python test_rocwmma_patch.py
```

If you have ROCm and PyTorch installed on your host system:

```bash
cd /path/to/wmma_ops
pip install -e . --no-build-isolation
```

Requirements:
- Toolkit: ROCm 7.9 or 7.10-preview
- Compiler: HIP compiler (`hipcc`)
- Python: PyTorch with ROCm support
- Target: gfx1151 (RDNA3.5 / Strix Halo)
```python
import torch
import wmma_ops

# Create test tensors (FP16 input)
A = torch.randn(4096, 2048, device='cuda', dtype=torch.float16)
B = torch.randn(2048, 4096, device='cuda', dtype=torch.float16)

# Use optimized WMMA matmul (FP32 output)
C = wmma_ops.matmul(A, B)

# Alternative: specify tile configuration
C = wmma_ops.matmul_tiled(A, B, 1)  # 1 = 128×64 tile

# Verify correctness against an FP32 reference (as the test suite does)
C_ref = torch.matmul(A.float(), B.float())
abs_err = (C - C_ref).abs().max().item()
rel_err = abs_err / C_ref.abs().max().item()
print(f"Max abs error: {abs_err:.4g}, relative: {rel_err:.2e}")
```

| Function | Description |
|---|---|
| `wmma_ops.matmul(A, B)` | Standard optimized kernel (recommended) |
| `wmma_ops.matmul_adaptive(A, B)` | Auto-selects optimal tile configuration |
| `wmma_ops.matmul_tiled(A, B, config)` | Explicit tile configuration (0-3) |
| `wmma_ops.matmul_kunroll(A, B)` | K-unrolled variant (2× fewer syncs) |
| `wmma_ops.matmul_noPrefetch(A, B)` | Without register prefetch |
| `wmma_ops.matmul_highOcc(A, B)` | High-occupancy variant |
| `wmma_ops.matmul_quad(A, B)` | Quad-buffered variant |
| `wmma_ops.matmul_native(A, B)` | gfx1151-specific with explicit intrinsics |
| `wmma_ops.matmul_zerocopy(A, B)` | Swizzled B with zero-copy stores (fastest) |
| `wmma_ops.matmul_asmOpt(A, B)` | Assembly-optimized scheduling hints |
| `wmma_ops.matmul_hilbert(A, B)` | Hilbert curve tile mapping for L2 locality |
| `wmma_ops.matmul_swizzled(A, B)` | XOR-swizzled LDS (bank conflict-free) |
| `wmma_ops.matmul_xor_optimized(A, B)` | Optimized XOR swizzle variant |
| `wmma_ops.matmul_coop(A, B)` | Cooperative loading (half the threads load A, half load B) |
| Function | Description |
|---|---|
| `wmma_ops.gemm(A, B, alpha=1.0, beta=0.0, C=None)` | Standard GEMM with fused scaling |
| `wmma_ops.gemm_adaptive(A, B, alpha=1.0, beta=0.0, C=None)` | Auto-tuned GEMM with scaling |
| `wmma_ops.gemm_inplace(A, B, C, alpha=1.0, beta=0.0)` | In-place GEMM (modifies C directly) |
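For sanity-checking `gemm` and `gemm_inplace` on small inputs, a plain-Python reference of the fused epilogue C = α·(A @ B) + β·C can be handy. This is an illustrative sketch on nested lists, not the extension's implementation; treating `C=None` as a zero matrix is an assumption.

```python
def gemm_reference(A, B, alpha=1.0, beta=0.0, C=None):
    """Return alpha * (A @ B) + beta * C for nested-list matrices (C=None treated as zeros)."""
    m, k = len(A), len(A[0])
    n = len(B[0])
    assert len(B) == k, "inner dimensions must match"
    out = []
    for i in range(m):
        row = []
        for j in range(n):
            acc = sum(A[i][p] * B[p][j] for p in range(k))   # one dot product per output
            prev = C[i][j] if C is not None else 0.0
            row.append(alpha * acc + beta * prev)            # fused alpha/beta epilogue
        out.append(row)
    return out


A = [[1, 2], [3, 4]]
B = [[5, 6], [7, 8]]
print(gemm_reference(A, B, alpha=2.0, beta=0.5, C=[[2, 2], [2, 2]]))
```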
| Function | Description |
|---|---|
| `wmma_ops.flash_attention(Q, K, V, causal=False, scale=-1.0)` | Flash Attention v2 forward pass |
Flash Attention Usage:
```python
import torch
import wmma_ops

# Input: Q, K, V tensors of shape [B, H, N, D]
# B = batch size, H = num heads, N = sequence length, D = head dimension
Q = torch.randn(2, 8, 512, 64, device="cuda", dtype=torch.float16)
K = torch.randn(2, 8, 512, 64, device="cuda", dtype=torch.float16)
V = torch.randn(2, 8, 512, 64, device="cuda", dtype=torch.float16)

# Non-causal attention
O = wmma_ops.flash_attention(Q, K, V)

# Causal attention (for autoregressive models)
O_causal = wmma_ops.flash_attention(Q, K, V, causal=True)

# Custom scale (the default, scale=-1.0, uses 1/sqrt(D))
O_scaled = wmma_ops.flash_attention(Q, K, V, scale=0.1)
```

Current Status:
- ✅ Correctness: Matches PyTorch SDPA (<0.001 max error)
- ⚠️ Performance: Simple scalar kernel, ~2-10x slower than PyTorch SDPA
- 🔜 TODO: WMMA-accelerated tiled implementation for better performance
Usage Example (GEMM with scaling):

```python
import wmma_ops

# A, B: FP16 tensors from the matmul example above.
# C_prev: an existing FP32 result to accumulate into (FP32 C is assumed here).
C_prev = wmma_ops.matmul(A, B)

# C = 2.0 * (A @ B) + 0.5 * C_prev
C = wmma_ops.gemm(A, B, alpha=2.0, beta=0.5, C=C_prev)

# In-place: C = 1.5 * (A @ B) + 0.3 * C (modifies C directly)
wmma_ops.gemm_inplace(A, B, C, alpha=1.5, beta=0.3)
```

- For Production Use: Use `matmul_adaptive` - best overall performance
- For Specific Sizes:
  - Small (512): Use `matmul` (Standard)
  - Medium (1024): Use `matmul_adaptive` (selects K-Unroll)
  - Large (2048+): Use `matmul_adaptive` (selects Standard 128×64)
```bash
# Full build + test
docker compose -f docker/docker-compose.benchmark.yml run --rm benchmark \
  bash -c "export LD_LIBRARY_PATH=/opt/venv/lib/python3.12/site-packages/torch/lib:\$LD_LIBRARY_PATH && \
  cd /workspace/wmma_ops && ./build_and_test.sh"

# Or run tests only (after building)
docker compose -f docker/docker-compose.benchmark.yml run --rm benchmark \
  bash -c "export LD_LIBRARY_PATH=/opt/venv/lib/python3.12/site-packages/torch/lib:\$LD_LIBRARY_PATH && \
  cd /workspace/wmma_ops && python test_rocwmma_patch.py"
```

Run the tests directly:

```bash
python test_rocwmma_patch.py
```

Profile:

```bash
python rocprof_wmma.py
```

Auto-tune:

```bash
# Quick benchmark (no Optuna required)
python autotune.py --quick

# Full Optuna tuning (requires: pip install optuna)
python autotune.py --trials 20

# Tune specific size
python autotune.py --size 4096 4096 2048
```

Benchmark methodology:
- Hardware: AMD gfx1151 (Strix Halo / RDNA3.5)
- Toolkit: ROCm 7.9 or 7.10-preview
- Warmup: 3 iterations
- Benchmark: 20 iterations
- Correctness: Compared against PyTorch FP32 matmul
- Tolerance: < 1% relative error for correctness pass
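The methodology above (3 warmup runs, 20 timed runs, < 1% relative error) can be sketched as a small harness. Defining relative error as max|out − ref| / max|ref| is an assumption; the test suite may use a different norm. Names here are hypothetical and the values are flattened Python sequences rather than tensors.

```python
import time
from typing import Callable, Sequence


def bench(fn: Callable[[], object], warmup: int = 3, iters: int = 20) -> float:
    """Mean wall-clock milliseconds per call after `warmup` untimed calls.

    For GPU work, fn should end with a device synchronize (e.g. torch.cuda.synchronize())
    so the timer measures kernel execution rather than only the launch.
    """
    for _ in range(warmup):
        fn()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - start) / iters * 1e3


def relative_error(out: Sequence[float], ref: Sequence[float]) -> float:
    """max|out - ref| / max|ref| over flattened values (assumed definition)."""
    num = max(abs(o - r) for o, r in zip(out, ref))
    den = max(abs(r) for r in ref)
    return num / den


def passes(out: Sequence[float], ref: Sequence[float], tol: float = 0.01) -> bool:
    """Correctness criterion: relative error below 1% by default."""
    return relative_error(out, ref) < tol
```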
| Component | Specification |
|---|---|
| Wavefront Size | 32 threads (Wave32) |
| SIMD Units per CU | 2 × SIMD32 (dual-issue capable) |
| Compute Units | 40 CUs |
| Peak Clock | 2.9 GHz |
| Peak TFLOPS | 59.4 (FP16 WMMA) |
| LDS | 64 KB per CU |
| VGPR File | 192 KB per SIMD (1.5× larger than mobile RDNA3) |
| Memory | LPDDR5X (~256 GB/s) |
| Parameter | Value |
|---|---|
| Instruction | v_wmma_f32_16x16x16_f16 |
| Tile Size | 16×16×16 (M×N×K) |
| Input Type | FP16 (half16) |
| Accumulator Type | FP32 (float8) |
| Wave Size | 32 (w32 suffix) |
| Latency | ~32 cycles |
For detailed information on fragment layouts, see docs/wmma_fragment_layout_rdna3.md.
Key Points:
- RDNA3 WMMA requires lane replication: lanes 0-15 and lanes 16-31 must contain identical data for A and B fragments
- A Matrix: Each lane loads one ROW of A (all 16 K values)
- B Matrix: Transposed layout in LDS, each lane loads one ROW of transposed B (= one column of original B)
- C/D Matrix: Row-major format, lanes 0-15 cover even rows, lanes 16-31 cover odd rows
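The key points can be expressed as pure index math. This sketch (hypothetical function names, no GPU needed) encodes the layout as described here: each lane holds one full row of A and one full column of B (lanes 16-31 replicating lanes 0-15), and accumulator element `i` of lane `l` sits at row `2*i + l/16`, column `l%16`. The self-check confirms every C/D element has exactly one owner.

```python
WAVE = 32
TILE = 16


def a_fragment_coords(lane: int) -> list[tuple[int, int]]:
    """(row, k) pairs of A held by `lane`: one full row, replicated in lanes 16-31."""
    row = lane % TILE
    return [(row, k) for k in range(TILE)]


def b_fragment_coords(lane: int) -> list[tuple[int, int]]:
    """(k, col) pairs of B held by `lane`: one full column (= one row of transposed B)."""
    col = lane % TILE
    return [(k, col) for k in range(TILE)]


def c_fragment_coords(lane: int) -> list[tuple[int, int]]:
    """(row, col) of the 8 FP32 accumulator elements c[i] held by `lane`."""
    col = lane % TILE
    row_offset = lane // TILE            # 0 for lanes 0-15, 1 for lanes 16-31
    return [(2 * i + row_offset, col) for i in range(8)]


if __name__ == "__main__":
    owned = [rc for lane in range(WAVE) for rc in c_fragment_coords(lane)]
    assert sorted(owned) == [(r, c) for r in range(TILE) for c in range(TILE)]
    assert all(a_fragment_coords(l) == a_fragment_coords(l + 16) for l in range(16))
    print("layout checks passed")
```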
Kernel: `matmul` (Standard)
Configuration: 128×64 tile, 4×2 warps, 2×2 register blocking
Features:
- Double buffering (overlaps loads with compute)
- GMEM spreading (register prefetch interleaved with MMA)
- Vectorized `half8` global loads
- Transposed B in LDS for `col_major` fragment access
- LDS padding (+8 halfs) to avoid bank conflicts

Performance: 21.9 TFLOPS (36.8% peak) at 4096×4096
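The standard configuration's work split can be checked with a little index math: 4×2 warps, each owning a 32×32 sub-tile made of 2×2 WMMA tiles, should cover the 128×64 block exactly once with 8 warps (256 threads). The warp ordering below (row-major over the warp grid) is a hypothetical choice for illustration, not necessarily what `wmma_gemm.hip` does.

```python
BLOCK_M, BLOCK_N = 128, 64
WARPS_M, WARPS_N = 4, 2
WMMA = 16
REG_M, REG_N = 2, 2  # 2x2 register blocking -> 32x32 output per warp


def warp_tiles(warp_id: int) -> list[tuple[int, int]]:
    """Top-left (row, col) of each 16x16 WMMA tile owned by warp_id inside the block."""
    warp_m = warp_id // WARPS_N          # hypothetical row-major warp ordering
    warp_n = warp_id % WARPS_N
    warp_m_base = warp_m * REG_M * WMMA  # 32 rows per warp
    warp_n_base = warp_n * REG_N * WMMA  # 32 cols per warp
    return [(warp_m_base + i * WMMA, warp_n_base + j * WMMA)
            for i in range(REG_M) for j in range(REG_N)]


if __name__ == "__main__":
    n_warps = WARPS_M * WARPS_N
    print(f"{n_warps} warps x 32 lanes = {n_warps * 32} threads per block")
    tiles = sorted(t for w in range(n_warps) for t in warp_tiles(w))
    expected = [(r, c) for r in range(0, BLOCK_M, WMMA) for c in range(0, BLOCK_N, WMMA)]
    assert tiles == expected  # 32 WMMA tiles cover the 128x64 block exactly once
```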
Kernel: `matmul_adaptive`
Configuration: Auto-selects optimal variant based on matrix dimensions
Selection Logic:
- Small matrices (< 512): Uses 64×64 tiles
- Medium matrices (512-2048): Uses 128×64 tiles with K-unroll when beneficial
- Large matrices (> 2048): Uses 128×64 standard tiles
Performance: 21.6 TFLOPS (36.4% peak) at 4096×4096, best average across sizes
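The selection logic reads as a simple threshold function. This is a hypothetical sketch, not the code in `wmma_tile_selection.hpp`: it assumes "matrix size" means max(M, N) and reuses the 768-1536 K window quoted for `matmul_kunroll` as the "when beneficial" criterion.

```python
def select_variant(m: int, n: int, k: int) -> str:
    """Hypothetical size-based tile selection mirroring the rules above."""
    size = max(m, n)            # assumption: which dimension drives the choice is not documented
    if size < 512:
        return "64x64"
    if size <= 2048:
        if 768 <= k <= 1536:    # assumed "K-unroll when beneficial" window
            return "128x64+kunroll"
        return "128x64"
    return "128x64"


for shape in ((256, 256, 256), (1024, 1024, 1024), (4096, 4096, 4096)):
    print(shape, select_variant(*shape))
```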
Kernel: `matmul_kunroll`
Configuration: 128×64 tile with 2× K-unrolling
Features:
- Processes 2× BLOCK_K per iteration
- Reduces `__syncthreads` overhead by 50%
- Best for K dimensions in 768-1536 range

Performance: 17.2 TFLOPS (29% peak) at 4096×4096, but 15.7 TFLOPS at 1024×1024
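The 50% barrier reduction is simple arithmetic. Assuming one `__syncthreads()` per main-loop iteration (consistent with the bottleneck table's 128 barriers for K=2048 at BLOCK_K=16), the count per output tile is K / (BLOCK_K × unroll):

```python
import math


def barriers_per_tile(k: int, block_k: int = 16, unroll: int = 1) -> int:
    """Main-loop barrier count when each iteration consumes unroll * block_k of K.

    Assumes exactly one barrier per iteration.
    """
    return math.ceil(k / (block_k * unroll))


for k in (1024, 2048, 4096):
    print(f"K={k}: {barriers_per_tile(k)} barriers, "
          f"{barriers_per_tile(k, unroll=2)} with 2x K-unroll")
```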
Kernel: `matmul_noPrefetch`
Simplification: Removes register prefetch phase
Trade-off: Lower register pressure (~64 VGPRs) but less latency hiding
Performance: 19.2 TFLOPS (32% peak) at 4096×4096
Kernel: `matmul_highOcc`
Configuration: 64×32 tile, 4×1 warps, 2×1 register blocking
Goal: Maximize waves/CU by reducing VGPRs to ~50
Result: Worse performance (9.8 TFLOPS); latency hiding matters more than occupancy for this compute-bound workload
Configuration: 128×64 tile, explicit inline assembly fences
Features:
- `lds_fence()`, `vmem_fence()`, `full_fence()` via inline asm
- Interleaved prefetch pattern (global loads between WMMA ops)
- `__builtin_prefetch` for software prefetch
- `amdgpu_waves_per_eu(4, 8)` occupancy hint

Result: About the same as adaptive (~20 TFLOPS); async copy hardware is not exposed in HIP
| Optimization | Impact | Description |
|---|---|---|
| 2×2 Register Blocking | +80% | 4 WMMA tiles per warp (32×32 output) |
| Double Buffering | +20% | Ping-pong LDS buffers |
| GMEM Spreading | +15% | Prefetch into registers, interleaved with MMA |
| Vectorized `half8` Loads | +25% | 128-bit global loads |
| 128×64 Tile Shape | +25% | Optimal A-matrix reuse |
| Transposed B in LDS | Required | Matches col_major fragment layout |
| LDS Padding (+8 halfs) | +15-20% | Eliminates bank conflicts (stride 24 vs 16) |
| Pointer Increment | +2% | Reduces VALU pressure in main loop |
| amdgpu_waves_per_eu | +2% | Compiler hint for occupancy targeting |
| Epilogue Fusion (α/β) | Saves 1 pass | Fused C = αAB + βC avoids separate scaling kernel |
- Explicit Wave32: Force/confirm wave32 compilation for gfx11 targets and assert at runtime using `__AMDGCN_WAVEFRONT_SIZE == 32`.
- Compiler Hints: Use `__restrict__` on A/B/C pointers and `__builtin_assume_aligned(ptr, 16)` on vectorized paths to encourage `global_load_b128` generation.
- Wide LDS Reads: Convert LDS fragment loads from 16 scalar half loads into two `ds_read_b128` + pack. This reduces LGKM overhead and bank conflict probability.
| Optimization | Result | Reason |
|---|---|---|
| BLOCK_K=32 | Slower | Added loop overhead outweighed benefits |
| Odd LDS Stride (17, 33) | Slower | Forces scalar LDS access |
| Pre-transpose B on Host | Slower | Host overhead > kernel savings |
| Triple Buffering | Broken | Correctness issues with rotation |
| Inline Assembly Scheduling | No Change | Compiler scheduling already optimal |
| High-Occupancy Variant | Slower | Latency hiding > occupancy for this workload |
The current implementation uses LDS padding to avoid bank conflicts:
```cpp
#define LDS_PAD 8
constexpr int A_STRIDE = BLOCK_K + LDS_PAD;  // 16 + 8 = 24 halfs = 48 bytes
```

Problems with padding:
- Wastes LDS memory: 8 extra halfs per row = 16 bytes wasted per row
  - For the A tile (BLOCK_M=128 rows): 128 × 16 = 2 KB wasted per buffer
  - With double buffering: 4 KB wasted for A alone; the padded 64-row transposed B tile adds another 2 KB (6 KB total)
- Doesn't guarantee conflict-free access: Padding only helps if all accesses are sequential
- Limits vectorized access alignment: a 24-half (48-byte) stride keeps every row 16-byte (128-bit) aligned, but rows are no longer aligned to wider power-of-two boundaries
XOR swizzle transforms memory indices so that bank conflicts are mathematically impossible:
```
Original index: (row, col)
Swizzled index: (row, col XOR f(row))
```
Where f(row) is chosen such that threads accessing different rows but same logical column will hit different banks.
For BLOCK_K=16, KPACK=8:
- K_GROUPS = 16 / 8 = 2
- Swizzle: `k_group_swizzled = k_group ^ (row & 1)`
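The swizzle formula can be checked in a few lines. This sketch (hypothetical helper names; unpadded row stride of BLOCK_K halfs) XORs the 8-half group index with the row's low bit and confirms two properties the kernel relies on: each row's columns are permuted rather than merged, and the XOR is its own inverse, which is why unswizzling on load applies the same XOR again.

```python
BLOCK_K = 16
KPACK = 8                          # halfs per 128-bit group
K_GROUPS = BLOCK_K // KPACK        # 2
K_GROUPS_MASK = K_GROUPS - 1       # 1, so the swizzle uses (row & 1)


def to_physical_col(row: int, col: int) -> int:
    """Column where logical (row, col) is stored: XOR the k-group with the row's low bits."""
    k_group, k_local = divmod(col, KPACK)
    return (k_group ^ (row & K_GROUPS_MASK)) * KPACK + k_local


def swizzled_flat_index(row: int, col: int) -> int:
    """Flat LDS index with an unpadded stride of BLOCK_K halfs."""
    return row * BLOCK_K + to_physical_col(row, col)


for row in range(4):
    phys = [to_physical_col(row, c) for c in range(BLOCK_K)]
    assert sorted(phys) == list(range(BLOCK_K))                              # permutation within the row
    assert [to_physical_col(row, p) for p in phys] == list(range(BLOCK_K))   # XOR undoes itself
```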
Memory Savings:
- Padding approach: 18,432 bytes (with 2 buffers)
- XOR Swizzle: 12,288 bytes
- Savings: 6,144 bytes (33%)
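These figures follow from counting halfs: the A tile (BLOCK_M × BLOCK_K) plus the transposed B tile (BLOCK_N × BLOCK_K), each row occupying the padded or unpadded stride, times two buffers. A sketch assuming the standard kernel's BLOCK_M=128, BLOCK_N=64, BLOCK_K=16:

```python
HALF_BYTES = 2


def lds_bytes(block_m: int = 128, block_n: int = 64, block_k: int = 16,
              pad: int = 0, buffers: int = 2) -> int:
    """LDS bytes for A (block_m rows) and transposed B (block_n rows) across all buffers."""
    stride = block_k + pad                             # halfs per LDS row
    per_buffer = (block_m + block_n) * stride * HALF_BYTES
    return per_buffer * buffers


padded = lds_bytes(pad=8)    # stride 24
swizzled = lds_bytes(pad=0)  # stride 16
print(padded, swizzled, padded - swizzled, f"{(padded - swizzled) / padded:.0%}")
# 18432 12288 6144 33%
```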
Status: ✅ Implemented and correct in wmma_xor_swizzle.hpp, but slower than padding (~15-20% slower).
| Approach | TFLOPS | LDS Usage | Bank Conflicts |
|---|---|---|---|
| Padding (stride=24) | 20-21 | 18.4 KB | Low (stride 24 staggers rows across banks) |
| XOR Swizzle (stride=16) | 17-18 | 12.3 KB | None (mathematically eliminated) |
Why XOR Swizzle is Slower:
- B matrix transpose stores: Each scalar store requires computing `Swizzle::to_physical()` (division, modulo, XOR)
- Flat 1D array indexing: More VALU overhead than 2D array with padding
- RDNA3 LDS bank conflict penalty: May not be severe enough to justify swizzle computation overhead
- Compiler optimization: 2D arrays with padding are easier for the compiler to optimize
Recommendation: Use padding approach for gfx1151. The 33% LDS savings from XOR swizzle doesn't compensate for the ~15-20% performance loss.
When data is stored swizzled in LDS, fragment loading must account for the swizzle transformation. The following fixes are required:
Fix 1: Fragment Loading with XOR Swizzle Inversion
When loading fragments from swizzled LDS, you must invert the swizzle to get the correct data layout for WMMA:
```cpp
// INCORRECT: Direct access ignores swizzle
int row = warp_m_base + lane % 16;  // Each lane owns one row of the A tile
for (int k = 0; k < 16; k++) {
    a0[k] = A_lds[curr][SwzA::to_flat(row, k)];
}
```

```cpp
// CORRECT: Invert XOR swizzle
int row = warp_m_base + lane % 16;  // Each lane owns one row of the A tile (all 16 K values)
for (int k = 0; k < 16; k++) {
    // Invert XOR swizzle: find which swizzled column holds original column k
    int k_group_orig = k / 8;
    int k_local = k % 8;
    int k_group_swz = k_group_orig ^ (row & SwzA::K_GROUPS_MASK);
    int frag_col_swz = k_group_swz * 8 + k_local;
    a0[k] = *reinterpret_cast<const _Float16*>(&A_lds[curr][SwzA::to_flat(row, frag_col_swz)]);
}
```

For the B matrix (similar fix with transposed layout):

```cpp
int frag_row_orig = lane % 16;  // Original row in transposed B layout
for (int kk = 0; kk < 16; kk++) {
    int n = warp_n_base + frag_row_orig;
    // Invert XOR swizzle for B
    int k_group_orig = kk / 8;
    int k_local = kk % 8;
    int k_group_swz = k_group_orig ^ (n & SwzB::K_GROUPS_MASK);
    int k_swz = k_group_swz * 8 + k_local;
    b0[kk] = *reinterpret_cast<const _Float16*>(&B_lds[curr][SwzB::to_flat(n, k_swz)]);
}
```

Fix 2: Correct Epilogue Store Pattern

The WMMA fragment layout stores elements in a specific pattern. Each element `c_frag[i]` stores to row `i*2 + (lane/16)`, column `lane%16`:

```cpp
// INCORRECT: Wrong fragment layout assumption
int frag_row = lane % 16;
int frag_col_off = (lane / 16) * 8;
for (int e = 0; e < 8; e++) {
    int local_c = frag_col_off + e;
    C[gr0 * N + gc0] = c00[e];  // WRONG!
}

// CORRECT: Proper WMMA fragment layout
int frag_col = lane % 16;          // Column is fixed per lane
int frag_row_offset = lane / 16;   // 0 for lanes 0-15, 1 for lanes 16-31
for (int i = 0; i < 8; i++) {
    int frag_row = i * 2 + frag_row_offset;  // Rows: 0,2,4,...,14 or 1,3,5,...,15
    int gr0 = block_m + warp_m_base + frag_row;
    int gc0 = block_n + warp_n_base + frag_col;
    if (gr0 < M && gc0 < N) C[gr0 * N + gc0] = c00[i];
    // For 2×2 register blocking, handle all 4 tiles:
    int gc1 = gc0 + 16;  // Tile [0][1]
    if (gr0 < M && gc1 < N) C[gr0 * N + gc1] = c01[i];
    int gr1 = gr0 + 16;  // Tile [1][0]
    if (gr1 < M && gc0 < N) C[gr1 * N + gc0] = c10[i];
    if (gr1 < M && gc1 < N) C[gr1 * N + gc1] = c11[i];  // Tile [1][1]
}
```

Complete Helper Function Example:

```cpp
template<typename SwzA>
__device__ __forceinline__ void load_a_frag_swizzled(
    half16& a_frag,
    const __half* lds_base,
    int warp_m_base,
    int lane
) {
    const int row = warp_m_base + (lane % 16);  // Each lane loads one full row of A
    #pragma unroll
    for (int k = 0; k < 16; k++) {
        // Invert XOR swizzle
        int k_group_orig = k / 8;
        int k_local = k % 8;
        int k_group_swz = k_group_orig ^ (row & SwzA::K_GROUPS_MASK);
        int frag_col_swz = k_group_swz * 8 + k_local;
        a_frag[k] = *reinterpret_cast<const _Float16*>(&lds_base[SwzA::to_flat(row, frag_col_swz)]);
    }
}
```

Alternative Approach: If the swizzle unswizzling is too complex or has performance overhead, consider:
- Store data swizzled (for bank conflict avoidance during global→LDS load)
- Unswizzle during LDS→Fragment load into a temporary buffer
- Load fragments from unswizzled buffer
This adds an extra LDS copy step but simplifies the fragment loading code.
Testing Recommendations:
- Start with correctness: Test with small matrices (128×128) before performance
- Compare against reference: Use non-swizzled kernel as reference
- Verify swizzle math: Test XOR swizzle inversion logic separately
- Check fragment layout: Verify fragment loading matches expected WMMA layout (see `docs/wmma_fragment_layout_rdna3.md`)
- Profile bank conflicts: Use rocprof to verify XOR swizzle actually reduces conflicts
Simple row-major launch can cause L2 thrashing for large matrices. Column-major or chunked tile processing improves L2 cache locality.
Expected Impact: 5-15% improvement for large matrices (4096×4096+)
Status: Implemented in wmma_optimizations.hpp but not yet integrated into main kernels.
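One common form of chunked tile processing is grouped ordering (the scheme used in Triton's matmul tutorial): consecutive block IDs walk down a group of `group_m` tile-rows before moving to the next tile-column, so blocks resident at the same time share B columns and nearby A rows in L2. This is an illustrative sketch; `group_m=8` and the exact mapping are assumptions, not necessarily what `wmma_optimizations.hpp` implements.

```python
def grouped_tile(pid: int, tiles_m: int, tiles_n: int, group_m: int = 8) -> tuple[int, int]:
    """Map a linear block id to (tile_m, tile_n) using grouped (chunked) ordering."""
    tiles_per_group = group_m * tiles_n
    group_id = pid // tiles_per_group
    first_m = group_id * group_m
    rows_in_group = min(tiles_m - first_m, group_m)   # the last group may be shorter
    local = pid % tiles_per_group
    tile_m = first_m + local % rows_in_group           # walk down the rows first
    tile_n = local // rows_in_group                    # then advance one column
    return tile_m, tile_n


# 4096x4096 output with 128x64 tiles -> 32 x 64 tiles
print([grouped_tile(p, 32, 64) for p in range(10)])
```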
Split-K assigns partial K-dimension slices to different work-groups, improving utilization for skinny matrices (small M or N, large K).
Example: M=16, N=4096, K=4096
- Without Split-K: Only 64 tiles → 60% utilization
- With Split-K factor 4: 256 tiles → 90% utilization
Status: Implemented in wmma_optimizations.hpp but not yet benchmarked.
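A minimal Split-K sketch in plain Python: each of `splits` work-groups computes a partial product over one slice of K, and the partials are summed in a reduction step (on a GPU, typically a second pass or atomic adds into C). The work-group count assumes the 128×64 tile, which reproduces the 64 → 256 figures above; helper names are hypothetical.

```python
import math


def num_work_groups(m: int, n: int, block_m: int = 128, block_n: int = 64, splits: int = 1) -> int:
    """Output tiles times K-splits: the number of independent work-groups launched."""
    return math.ceil(m / block_m) * math.ceil(n / block_n) * splits


def matmul_k_range(A, B, k0, k1):
    """Partial product A[:, k0:k1] @ B[k0:k1, :] (what one Split-K work-group computes)."""
    return [[sum(A[i][p] * B[p][j] for p in range(k0, k1)) for j in range(len(B[0]))]
            for i in range(len(A))]


def split_k_matmul(A, B, splits):
    """Compute A @ B by summing `splits` partial products over near-equal K slices."""
    k = len(B)
    bounds = [k * s // splits for s in range(splits + 1)]
    partials = [matmul_k_range(A, B, bounds[s], bounds[s + 1]) for s in range(splits)]
    # Reduction step: sum the partial tiles element-wise.
    return [[sum(P[i][j] for P in partials) for j in range(len(B[0]))]
            for i in range(len(A))]


print(num_work_groups(16, 4096), num_work_groups(16, 4096, splits=4))  # 64 256
```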
Current Register Usage: Estimated ~91-92 VGPRs per thread (needs verification via roc-obj-utils or rocprof)
To verify actual VGPR usage:
```bash
# Extract from compiled kernel
roc-obj-utils --disassemble kernel.hsaco | grep -A 20 "COMPUTE_PGM_RSRC"
# Look for .vgprsnum value
```

Compiler Hints:

```cpp
__launch_bounds__(256, 2)
__attribute__((amdgpu_waves_per_eu(4, 8)))
// Note: amdgpu_num_vgpr does NOT work with templates in HIP
// Only amdgpu_waves_per_eu is supported with templates
```

Important: The `amdgpu_num_vgpr` attribute is not supported with template kernels in HIP. Use `amdgpu_waves_per_eu` to hint occupancy instead.
The current kernel double-buffers Global→LDS but not LDS→register, so LDS fragment loads and WMMA compute are serialized.
Potential Improvement: 5-10% by overlapping LDS loads with computation.
Trade-off: Doubles the fragment registers (four `half16` fragments at 8 VGPRs each, so from 32 to 64 VGPRs), raising total usage from ~91 to roughly ~123 VGPRs, which may reduce occupancy.
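Both Global→LDS double buffering (already implemented) and the proposed LDS→register variant follow the same ping-pong pattern. The toy model below shows only the ordering: the load for step k+1 is issued into the other buffer before step k's multiply-accumulate consumes the current one. It is a sequential Python illustration with hypothetical names, not kernel code; on the GPU the benefit comes from the next load being in flight while the WMMA executes.

```python
def double_buffered_dot(a, b, trace=None):
    """Dot product with a two-slot ping-pong buffer; optionally records the op order."""
    n = len(a)
    buf = [None, None]
    buf[0] = (a[0], b[0])                     # prologue: fill the first buffer
    if trace is not None:
        trace.append("load 0")
    acc = 0
    for k in range(n):
        curr, nxt = k % 2, (k + 1) % 2
        if k + 1 < n:
            buf[nxt] = (a[k + 1], b[k + 1])   # prefetch the next step into the other slot
            if trace is not None:
                trace.append(f"load {k + 1}")
        x, y = buf[curr]
        acc += x * y                          # "MMA" on the current slot
        if trace is not None:
            trace.append(f"mma {k}")
    return acc


ops = []
print(double_buffered_dot([1, 2, 3], [4, 5, 6], ops), ops)
```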
| K Dimension | Regime | Limiting Factor |
|---|---|---|
| K < 256 | Memory-bound | LPDDR5X bandwidth |
| K ≥ 512 | Compute-bound | WMMA throughput |
| Bottleneck | Contribution | Notes |
|---|---|---|
| `__syncthreads` overhead | ~20% | 128 barriers for K=2048 |
| Occupancy (5 waves/CU) | ~15% | 91 VGPRs limits to 5 waves |
| LDS bank conflicts (B scatter) | ~10% | Transpose pattern causes conflicts |
| WMMA pipeline bubbles | ~10% | Data dependencies within wave |
| Metric | Value |
|---|---|
| Peak Compute | 59.4 TFLOPS |
| Peak Memory BW | 256 GB/s |
| Ridge Point | ~232 FLOPs/byte (59.4 TFLOPS ÷ 256 GB/s) |
| Our Intensity (4096×4096, K=2048) | ~682 FLOPs/byte |
| Regime | Compute-bound ✅ |
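The ridge point is peak compute divided by peak bandwidth, and a GEMM's arithmetic intensity is 2·M·N·K divided by the bytes moved. This sketch assumes ideal reuse, i.e. each matrix crosses DRAM exactly once (FP16 A and B read, FP32 C written); under that assumption a 4096×4096×2048 problem comes out at about 682.7 FLOPs/byte, well above the ridge.

```python
PEAK_FLOPS = 59.4e12   # FP16 WMMA peak (FLOP/s)
PEAK_BW = 256e9        # LPDDR5X bandwidth (bytes/s)


def ridge_point(peak_flops: float = PEAK_FLOPS, bw: float = PEAK_BW) -> float:
    """FLOPs/byte at which the roofline switches from memory-bound to compute-bound."""
    return peak_flops / bw


def gemm_intensity(m: int, n: int, k: int, in_bytes: int = 2, out_bytes: int = 4) -> float:
    """FLOPs per DRAM byte, assuming A and B are read once and C is written once."""
    flops = 2 * m * n * k
    traffic = in_bytes * (m * k + k * n) + out_bytes * m * n
    return flops / traffic


ridge = ridge_point()
ai = gemm_intensity(4096, 4096, 2048)
print(f"ridge point: {ridge:.0f} FLOPs/byte")
print(f"4096x4096x2048: {ai:.1f} FLOPs/byte -> {'compute' if ai > ridge else 'memory'}-bound")
```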
| Component | Count | Status |
|---|---|---|
| WMMA Instructions | 4 | ✅ Correct (2×2 blocking) |
| Global Loads (`global_load_b128`) | 4 | ✅ Vectorized |
| LDS Stores (`ds_store_b128`) | 8 | ✅ Vectorized |
| Dual-Issue (`v_dual_*`) | 6-21 | ✅ Active |
| Wait Counters (`s_waitcnt`) | 10 | |
```asm
s_waitcnt lgkmcnt(0)                                          ; Wait for LDS
v_wmma_f32_16x16x16_f16 v[1:8], v[57:64], v[49:56], v[1:8]    ; MMA 0
v_wmma_f32_16x16x16_f16 v[9:16], v[57:64], v[41:48], v[9:16]  ; MMA 1
v_wmma_f32_16x16x16_f16 v[17:24], v[33:40], v[49:56], v[17:24] ; MMA 2
v_wmma_f32_16x16x16_f16 v[25:32], v[33:40], v[41:48], v[25:32] ; MMA 3
```

Key Finding: The compiler groups WMMAs together intentionally. Attempts to interleave them with inline assembly did not improve performance; the hardware scheduler handles dual-issue at runtime.
Profile LDS bank conflicts:
```bash
rocprof --stats -o profile.csv -i metrics.txt ./your_kernel
# metrics.txt should include:
#   LDSBankConflict
#   LDSInstructions
#   LDSBankConflictCycles

# View results
grep -E "LDSBankConflict|LDSInstructions" profile.csv
```

Interpretation:
- `LDSBankConflict / LDSInstructions` = conflict rate (target: < 5%)
- High conflict rate indicates need for swizzling/padding
Our kernel achieves 53% of rocBLAS performance (~21.6 vs ~41 TFLOPS). The gap is due to:
| Factor | Description |
|---|---|
| Hand-tuned Assembly | rocBLAS uses offline-optimized ISA with perfect scheduling |
| Adaptive Tile Selection | rocBLAS selects optimal tile per matrix dimension |
| Async Copy Hardware | Uses hardware async LDS loads (not exposed in HIP) |
| Register Allocation | Compiler-level register allocation vs manual tuning |
| Multi-kernel Fusion | rocBLAS fuses alpha/beta scaling |
We attempted several gfx1151-specific optimizations without portability constraints:
| Optimization | Result | Notes |
|---|---|---|
| Async global-to-LDS | ❌ Not available | __builtin_amdgcn_global_load_lds not exposed in ROCm 7.x for gfx1151 |
| Hardware prefetch | ❌ Not available | s_prefetch_data instruction not supported on RDNA3.5 |
| Explicit waitcnt | ⚖️ No improvement | lds_fence(), vmem_fence() via inline asm perform same as compiler-managed |
| Interleaved prefetch | ⚖️ No improvement | WMMA ops (128 cycles) complete before global loads (400-800 cycles) |
| Software prefetch | ⚖️ Minimal impact | __builtin_prefetch adds ~1% improvement |
Based on extensive research of AMD ROCm docs, LLVM AMDGPU backend, and GPUOpen resources (as of late 2025):
- No `__hip_ds_copy_async` or similar hardware intrinsics exist for gfx1151
- `llvm.amdgcn.load.to.lds` exists but is synchronous (lowers to `global_load_lds` + `s_waitcnt`)
- Async behavior must be achieved through:
  - HIP runtime APIs (`hipMemcpyAsync` with streams) - host-device only
  - Manual overlap with `s_waitcnt vmcnt(x)` to allow in-flight loads - already implemented
  - Queue-level async (separate compute/copy queues) - not applicable for LDS
The only ISA pattern available for global-to-LDS:
```asm
buffer_load_dword v1, v0, s[sgpr0:sgpr3], ...   ; Global load
s_waitcnt vmcnt(0)                              ; Wait for load
ds_write_b32 v2, v1                             ; Write to LDS
s_waitcnt lgkmcnt(0)                            ; Wait for LDS visibility
```

GFX12 has `s_wait_dscnt` but still no async. ASYNC LDS and tensor ops are not covered by the memory model implemented by the AMDGPU backend; waits aren't inserted automatically and must be emitted explicitly.
Remaining options for closing the gap:
- AMD exposing true async LDS intrinsics - requires hardware/firmware changes, not just software
- Write in AMDGCN assembly directly with perfect scheduling (impractical, loses compiler optimizations)
- Use rocBLAS/hipBLASLt for production (recommended - these use internal AMD optimizations)
The theoretical maximum without vendor-level assembly optimization is approximately 55-60% of peak (~33-36 TFLOPS). Current optimizations in progress (XOR swizzle, L2 rasterization) could potentially reach 40-45% of peak (~24-27 TFLOPS).
| Version | TFLOPS | Change | Key Optimization |
|---|---|---|---|
| Baseline | 5.4 | — | Basic WMMA implementation |
| + Multi-column blocks | 6.2 | +15% | BLOCK_N=64 for B reuse |
| + 2×2 Register Blocking | 9.7 | +56% | 4 accumulators per warp |
| + Vectorized Loads | 11.5 | +19% | half8 global loads |
| + Double Buffering | 13.2 | +15% | Ping-pong LDS |
| + 128×64 Tile Shape | 16.5 | +25% | Increased A reuse |
| + GMEM Spreading | 17.4 | +5% | Register prefetch |
| + LDS Padding | 20.9 | +20% | Bank conflict elimination |
- Tile shape matters more than expected: 128×64 (tall) significantly outperforms 64×64 (square) due to A-matrix reuse
- Prefetching > Occupancy: For compute-bound workloads, hiding latency via prefetch beats maximizing waves/CU
- Vectorized access is critical: LPDDR5X heavily penalizes scalar loads
- Compiler scheduling is smart: Inline assembly attempts didn't improve on LLVM's scheduling
- LDS transpose is required: `col_major` B fragments need transposed data
- Odd strides hurt more than help: Bank conflict avoidance via odd strides forces scalar access
Several optimizations from the development notes were attempted:
| Optimization | Result | Notes |
|---|---|---|
| Vectorized C Writes | ❌ 0.72x slower | Extra 32KB LDS + sync overhead hurt more than coalescing helped |
| Cooperative Loading | ⚖️ 1.01-1.06x | Marginal gains at small K, slight regression at large K |
| BLOCK_K=32 (K-unroll) | ❌ 0.89-0.97x | Increased LDS/register pressure hurt occupancy |
| XOR Swizzle LDS | ❌ 0.85x slower | Swizzle computation overhead > bank conflict savings |
Conclusion: The standard kernel is already well-optimized with double-buffered LDS, interleaved prefetch, and LDS padding. Further gains likely require split-K parallelism or assembly-level tuning.
```
wmma_ops/
├── README.md                     # This file
├── setup.py                      # Build configuration
├── wmma_gemm.hip                 # Main kernel implementation & pybind
├── wmma_kernels_optimized.hpp    # Optimized kernel variants (kunroll, quad, hilbert, etc.)
├── wmma_tile_mapping.hpp         # Hilbert curve tile mapping for L2 locality
├── wmma_xor_swizzle.hpp          # XOR swizzle, rasterization, Split-K
├── wmma_tile_selection.hpp       # Adaptive tile configuration
├── wmma_device_helpers.hpp       # Fragment loading helpers
├── rocwmma_patch/
│   └── rocwmma_gfx1151.hpp       # Custom rocWMMA patch header
├── docs/
│   └── WMMA_DEVELOPMENT_NOTES.md # Consolidated development documentation
├── examples/                     # Reference implementations from other projects
├── autotune.py                   # Optuna-based auto-tuner
├── test_rocwmma_patch.py         # Test suite
├── benchmark_summary.py          # Benchmark utilities
└── build_in_docker.sh            # Docker build script
```
- rocWMMA Documentation (ROCm)
- Deep Dive into Matrix Optimization on AMD GPUs (Sébastien Vince)
  - Blog Post
  - Achievement: 49 TFLOPS on FP32 GEMM (60% faster than rocBLAS)
- AMD RDNA™ 3.5 ISA Reference Guide
- LLVM AMDGPU Usage
  - LLVM Documentation
  - Code object metadata, register usage, ISA details
- rocBLAS Documentation (ROCm)
- llama.cpp PR #16827 - Original optimizations
- rocWMMA Library - AMD's rocWMMA library
- rocWMMA Samples - Reference kernels and usage patterns
- AMD Matrix Instruction Calculator - WMMA layout verification
Based on rocWMMA library (MIT License) and llama.cpp optimizations.