
glovepost/wmma_ops


rocWMMA Patch for PyTorch on gfx1151

Optimized WMMA (Wave Matrix Multiply-Accumulate) operations for AMD gfx1151 (RDNA3.5 / Strix Halo) architecture

Based on llama.cpp rocWMMA optimizations (PR #16827) and Sébastien Vince's Deep Dive into Matrix Optimization on AMD GPUs.


Executive Summary

Successfully implemented and optimized a rocWMMA GEMM kernel for PyTorch targeting gfx1151, achieving 21.6 TFLOPS peak (36% of theoretical 59.4 TFLOPS) with correct results across all test configurations.

This represents a 4× improvement over the initial implementation (5.4 → 21.6 TFLOPS) through systematic optimization, reaching 53% of rocBLAS FP16 performance.

Quick Stats

| Metric | Value |
|---|---|
| Peak TFLOPS | 21.6 (4096×4096) |
| Utilization | 36% of 59.4 TFLOPS peak |
| vs rocBLAS FP16 | 53% (rocBLAS: ~41 TFLOPS) |
| Improvement | 4× over baseline |
| Correctness | ✅ All tests pass (rel_err < 1%) |

Performance Results

Final Benchmarks (Adaptive Tile Selection with K-Unrolling)

| Configuration | WMMA TFLOPS | % Peak | rocBLAS FP16 TFLOPS | % of rocBLAS |
|---|---|---|---|---|
| 512×512×512 | 12.5 | 21.0% | 20.6 | 61% |
| 1024×1024×1024 | 14.6 | 24.6% | 37.0 | 39% |
| 2048×2048×2048 | 20.0 | 33.7% | 38.5 | 52% |
| 4096×4096×4096 | 21.6 | 36.4% | 41.0 | 53% |

rocBLAS achieves 69% of peak (41/59.4 TFLOPS) while our kernel achieves 36% of peak.

Kernel Variant Comparison (4096×4096×2048)

All 13 kernel variants pass correctness tests (< 1% relative error):

| Kernel | Time (ms) | TFLOPS | % of Peak | Status |
|---|---|---|---|---|
| matmul_zerocopy | 3.33 | 20.61 | 34.7% | ✅ Best |
| matmul_adaptive | 3.35 | 20.52 | 34.5% | |
| matmul_asmOpt | 3.35 | 20.49 | 34.5% | |
| matmul | 3.36 | 20.46 | 34.4% | |
| matmul_coop | 3.44 | 19.98 | 33.6% | ✅ New |
| matmul_native | 3.45 | 19.92 | 33.5% | |
| matmul_kunroll | 3.75 | 18.31 | 30.8% | |
| matmul_swizzled | 3.84 | 17.90 | 30.1% | |
| matmul_noPrefetch | 3.90 | 17.60 | 29.6% | |
| matmul_xor_optimized | 3.98 | 17.26 | 29.1% | |
| matmul_quad | 4.04 | 17.01 | 28.6% | |
| matmul_hilbert | 5.16 | 13.33 | 22.4% | |
| matmul_highOcc | 6.68 | 10.28 | 17.3% | |
| PyTorch (reference) | 1.92 | 35.86 | 60.4% | |

Peak Theoretical: 59.4 TFLOPS (gfx1151 FP16 WMMA)
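The TFLOPS and % of Peak columns follow from the usual GEMM operation count of 2·M·N·K FLOPs (one multiply and one add per multiply-accumulate), divided by the measured time and then by the 59.4 TFLOPS peak. A small sketch of that conversion; the helper names are illustrative and not part of wmma_ops:

```python
def gemm_tflops(m: int, n: int, k: int, time_ms: float) -> float:
    """Throughput of one M×N×K GEMM that took time_ms milliseconds."""
    flops = 2 * m * n * k                 # one multiply + one add per MAC
    return flops / (time_ms * 1e-3) / 1e12


def percent_of_peak(tflops: float, peak_tflops: float = 59.4) -> float:
    """Share of the gfx1151 FP16 WMMA peak, in percent."""
    return 100.0 * tflops / peak_tflops


# matmul_zerocopy row: 4096×4096×2048 in 3.33 ms
t = gemm_tflops(4096, 4096, 2048, 3.33)
print(f"{t:.2f} TFLOPS, {percent_of_peak(t):.1f}% of peak")
```

Recomputing from the rounded 3.33 ms gives about 20.64 TFLOPS rather than 20.61; the difference is within the rounding of the time column.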

Key Findings

  1. matmul_zerocopy performs best (20.61 TFLOPS, 57.5% of PyTorch)

    • Swizzled B matrix with zero-copy stores
    • Best for large matrices
  2. Top 4 kernels are within 1% of each other (~20.5 TFLOPS)

    • zerocopy, adaptive, asmOpt, standard all perform similarly
    • Use matmul_adaptive for automatic selection
  3. XOR swizzle kernels now work correctly

    • matmul_swizzled and matmul_xor_optimized both pass
    • Fixed fragment loading pattern (load ROW, not column)
  4. HighOcc underperforms (10.28 TFLOPS)

    • Lower register pressure doesn't compensate for reduced compute intensity
    • Not recommended for current workloads

Correctness Tests

All correctness tests pass with acceptable FP16 precision:

| Test | Relative Error |
|---|---|
| 512×512×64 | < 0.04% |
| 2048×2048×128 | < 0.03% |
| 4096×4096×2048 | < 0.03% |
| GEMM α=2.0, β=0.5 | < 0.001% |
| GEMM in-place | < 0.001% |

Building

Using Docker (Recommended)

The project includes a Docker environment with ROCm 7.10, PyTorch, and all dependencies pre-configured for gfx1151.

1. Build the Docker image:

```shell
cd /path/to/wmma_ops
docker compose -f docker/docker-compose.benchmark.yml build
```

2. Run the build and test suite:

```shell
# Create .env file if it doesn't exist (required by docker-compose)
touch .env

# Build and test
docker compose -f docker/docker-compose.benchmark.yml run --rm benchmark \
  bash -c "export LD_LIBRARY_PATH=/opt/venv/lib/python3.12/site-packages/torch/lib:\$LD_LIBRARY_PATH && \
           cd /workspace/wmma_ops && ./build_and_test.sh"
```

3. Interactive development:

```shell
# Start an interactive shell in the container
docker compose -f docker/docker-compose.benchmark.yml run --rm benchmark bash

# Inside the container:
export LD_LIBRARY_PATH=/opt/venv/lib/python3.12/site-packages/torch/lib:$LD_LIBRARY_PATH
cd /workspace/wmma_ops
pip install -e . --no-build-isolation
python test_rocwmma_patch.py
```

Using pip (Host Installation)

If you have ROCm and PyTorch installed on your host system:

```shell
cd /path/to/wmma_ops
pip install -e . --no-build-isolation
```

Build Requirements

  • Toolkit: ROCm 7.9 or 7.10-preview
  • Compiler: HIP compiler (hipcc)
  • Python: PyTorch with ROCm support
  • Target: gfx1151 (RDNA3.5 / Strix Halo)

Usage

Simple Matrix Multiply (C = A × B)

```python
import torch
import wmma_ops

# Create test tensors (FP16 input); ROCm builds of PyTorch use the 'cuda' device name
A = torch.randn(4096, 2048, device='cuda', dtype=torch.float16)
B = torch.randn(2048, 4096, device='cuda', dtype=torch.float16)

# Use optimized WMMA matmul (FP32 output)
C = wmma_ops.matmul(A, B)

# Alternative: specify tile configuration
C = wmma_ops.matmul_tiled(A, B, 1)  # 1 = 128×64 tile

# Verify correctness against an FP32 reference
C_ref = torch.matmul(A.float(), B.float())
rel_err = ((C - C_ref).norm() / C_ref.norm()).item()
print(f"Max abs error: {(C - C_ref).abs().max().item():.4f}, relative error: {rel_err:.2e}")
```

Available Functions

Matrix Multiply Functions

| Function | Description |
|---|---|
| wmma_ops.matmul(A, B) | Standard optimized kernel (recommended) |
| wmma_ops.matmul_adaptive(A, B) | Auto-selects optimal tile configuration |
| wmma_ops.matmul_tiled(A, B, config) | Explicit tile configuration (0-3) |
| wmma_ops.matmul_kunroll(A, B) | K-unrolled variant (2× fewer syncs) |
| wmma_ops.matmul_noPrefetch(A, B) | Without register prefetch |
| wmma_ops.matmul_highOcc(A, B) | High-occupancy variant |
| wmma_ops.matmul_quad(A, B) | Quad-buffered variant |
| wmma_ops.matmul_native(A, B) | gfx1151-specific with explicit intrinsics |
| wmma_ops.matmul_zerocopy(A, B) | Swizzled B with zero-copy stores (fastest) |
| wmma_ops.matmul_asmOpt(A, B) | Assembly-optimized scheduling hints |
| wmma_ops.matmul_hilbert(A, B) | Hilbert curve tile mapping for L2 locality |
| wmma_ops.matmul_swizzled(A, B) | XOR-swizzled LDS (bank conflict-free) |
| wmma_ops.matmul_xor_optimized(A, B) | Optimized XOR swizzle variant |
| wmma_ops.matmul_coop(A, B) | Cooperative loading (half threads load A, half load B) |

BLAS-Style GEMM (C = α × A × B + β × C)

| Function | Description |
|---|---|
| wmma_ops.gemm(A, B, alpha=1.0, beta=0.0, C=None) | Standard GEMM with fused scaling |
| wmma_ops.gemm_adaptive(A, B, alpha=1.0, beta=0.0, C=None) | Auto-tuned GEMM with scaling |
| wmma_ops.gemm_inplace(A, B, C, alpha=1.0, beta=0.0) | In-place GEMM (modifies C directly) |

Flash Attention

| Function | Description |
|---|---|
| wmma_ops.flash_attention(Q, K, V, causal=False, scale=-1.0) | Flash Attention v2 forward pass |

Flash Attention Usage:

```python
import torch
import wmma_ops

# Input: Q, K, V tensors of shape [B, H, N, D]
# B = batch size, H = num heads, N = sequence length, D = head dimension
Q = torch.randn(2, 8, 512, 64, device="cuda", dtype=torch.float16)
K = torch.randn(2, 8, 512, 64, device="cuda", dtype=torch.float16)
V = torch.randn(2, 8, 512, 64, device="cuda", dtype=torch.float16)

# Non-causal attention
O = wmma_ops.flash_attention(Q, K, V)

# Causal attention (for autoregressive models)
O_causal = wmma_ops.flash_attention(Q, K, V, causal=True)

# Custom scale (default: 1/sqrt(D))
O_scaled = wmma_ops.flash_attention(Q, K, V, scale=0.1)
```

Current Status:

  • ✅ Correctness: Matches PyTorch SDPA (<0.001 max error)
  • ⚠️ Performance: Simple scalar kernel, ~2-10x slower than PyTorch SDPA
  • 🔜 TODO: WMMA-accelerated tiled implementation for better performance

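Usage Example (BLAS-style GEMM with scaling):

```python
import torch
import wmma_ops

A = torch.randn(4096, 2048, device='cuda', dtype=torch.float16)
B = torch.randn(2048, 4096, device='cuda', dtype=torch.float16)
C_prev = wmma_ops.matmul(A, B)  # e.g. an earlier FP32 result of shape [M, N]

# C = 2.0 * (A @ B) + 0.5 * C_prev
C = wmma_ops.gemm(A, B, alpha=2.0, beta=0.5, C=C_prev)

# In-place: C = 1.5 * (A @ B) + 0.3 * C (modifies C directly)
wmma_ops.gemm_inplace(A, B, C, alpha=1.5, beta=0.3)
```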

Recommendations

  • For Production Use: Use matmul_adaptive - best overall performance
  • For Specific Sizes:
    • Small (512): Use matmul (Standard)
    • Medium (1024): Use matmul_adaptive (selects K-Unroll)
    • Large (2048+): Use matmul_adaptive (selects Standard 128×64)

Testing

Run Test Suite (Docker)

```shell
# Full build + test
docker compose -f docker/docker-compose.benchmark.yml run --rm benchmark \
  bash -c "export LD_LIBRARY_PATH=/opt/venv/lib/python3.12/site-packages/torch/lib:\$LD_LIBRARY_PATH && \
           cd /workspace/wmma_ops && ./build_and_test.sh"

# Or run tests only (after building)
docker compose -f docker/docker-compose.benchmark.yml run --rm benchmark \
  bash -c "export LD_LIBRARY_PATH=/opt/venv/lib/python3.12/site-packages/torch/lib:\$LD_LIBRARY_PATH && \
           cd /workspace/wmma_ops && python test_rocwmma_patch.py"
```

Run Test Suite (Host)

```shell
python test_rocwmma_patch.py
```

Run Profiling

```shell
python rocprof_wmma.py
```

Run Benchmarks

```shell
# Quick benchmark (no Optuna required)
python autotune.py --quick

# Full Optuna tuning (requires: pip install optuna)
python autotune.py --trials 20

# Tune specific size
python autotune.py --size 4096 4096 2048
```

Benchmark Methodology

  • Hardware: AMD gfx1151 (Strix Halo / RDNA3.5)
  • Toolkit: ROCm 7.9 or 7.10-preview
  • Warmup: 3 iterations
  • Benchmark: 20 iterations
  • Correctness: Compared against PyTorch FP32 matmul
  • Tolerance: < 1% relative error for correctness pass
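A minimal timing and correctness harness in the spirit of this methodology is sketched below. It is not the repository's benchmark code: the `sync` hook must block until queued GPU work finishes (pass `torch.cuda.synchronize`, which also works on ROCm builds of PyTorch), and the norm-wise relative error is an assumption, since the test suite may measure error differently.

```python
import time
from typing import Callable, Sequence


def benchmark_ms(fn: Callable[[], object], warmup: int = 3, iters: int = 20,
                 sync: Callable[[], None] = lambda: None) -> float:
    """Mean time of fn() in milliseconds after `warmup` untimed calls.

    The default no-op sync only suits CPU callables; GPU launches are
    asynchronous, so pass a real synchronize for kernel timing.
    """
    for _ in range(warmup):           # warmup: JIT, caches, clock ramp-up
        fn()
    sync()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    sync()                            # include the last launch in the timing
    return (time.perf_counter() - start) * 1e3 / iters


def relative_error(out: Sequence[float], ref: Sequence[float]) -> float:
    """||out - ref||_2 / ||ref||_2 over flattened values."""
    num = sum((o - r) ** 2 for o, r in zip(out, ref)) ** 0.5
    den = sum(r ** 2 for r in ref) ** 0.5
    return num / den


def passes(out: Sequence[float], ref: Sequence[float], tol: float = 0.01) -> bool:
    """Correctness gate: relative error below 1% by default."""
    return relative_error(out, ref) < tol
```

With PyTorch tensors, flatten both results (for example `C.flatten().tolist()`) or compute the same norms with `torch.linalg.norm` to avoid the Python loop.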

Architecture Details

gfx1151 (RDNA3.5 / Strix Halo) Specifications

| Component | Specification |
|---|---|
| Wavefront Size | 32 threads (Wave32) |
| SIMD Units per CU | 2 × SIMD32 (dual-issue capable) |
| Compute Units | 40 CUs |
| Peak Clock | 2.9 GHz |
| Peak TFLOPS | 59.4 (FP16 WMMA) |
| LDS | 64 KB per CU |
| VGPR File | 192 KB per SIMD (1.5× larger than mobile RDNA3) |
| Memory | LPDDR5X (~256 GB/s) |
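The 59.4 TFLOPS peak can be cross-checked against the other rows of the table. Back-computing from 40 CUs at 2.9 GHz gives the FP16 WMMA rate each CU must sustain; the 512 FLOP/clock/CU value below is derived from the table, not quoted from an AMD datasheet. It also implies one 16×16×16 WMMA per SIMD every ~32 cycles, which lines up with the ~32-cycle figure in the WMMA table.

```python
CUS, SIMDS_PER_CU, CLOCK_HZ = 40, 2, 2.9e9    # from the specification table
PEAK_TFLOPS = 59.4                             # quoted FP16 WMMA peak

# FLOPs in one v_wmma_f32_16x16x16_f16: 16*16*16 multiply-accumulates
WMMA_FLOPS = 2 * 16 * 16 * 16                  # 8192

# Throughput each CU / SIMD must sustain to reach the quoted peak
flops_per_clk_cu = PEAK_TFLOPS * 1e12 / (CUS * CLOCK_HZ)
flops_per_clk_simd = flops_per_clk_cu / SIMDS_PER_CU
cycles_per_wmma = WMMA_FLOPS / flops_per_clk_simd

print(f"{flops_per_clk_cu:.0f} FLOP/clk/CU, "
      f"{cycles_per_wmma:.0f} cycles per WMMA per SIMD")

# Going the other way: 512 FLOP/clk/CU reproduces the peak
print(f"{CUS * CLOCK_HZ * 512 / 1e12:.1f} TFLOPS")
```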

WMMA Instruction

| Parameter | Value |
|---|---|
| Instruction | v_wmma_f32_16x16x16_f16 |
| Tile Size | 16×16×16 (M×N×K) |
| Input Type | FP16 (half16) |
| Accumulator Type | FP32 (float8) |
| Wave Size | 32 (w32 suffix) |
| Latency | ~32 cycles |

WMMA Fragment Layout

For detailed information on fragment layouts, see docs/wmma_fragment_layout_rdna3.md.

Key Points:

  • RDNA3 WMMA requires lane replication: lanes 0-15 and lanes 16-31 must contain identical data for A and B fragments
  • A Matrix: Each lane loads one ROW of A (all 16 K values)
  • B Matrix: Transposed layout in LDS, each lane loads one ROW of transposed B (= one column of original B)
  • C/D Matrix: Row-major format, lanes 0-15 cover even rows, lanes 16-31 cover odd rows
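The lane mapping described in these key points can be checked with a small CPU emulation. The sketch places A rows, B columns and C/D elements exactly as listed above and verifies that scattering every lane's eight accumulator values back into a 16×16 tile reproduces A × B. It models only where data lives (which lane owns what), not how the hardware routes operands between lanes; the function names are hypothetical.

```python
import random

TILE = 16  # v_wmma_f32_16x16x16_f16: M = N = K = 16


def a_frag(A, lane):
    # lanes l and l + 16 hold the same ROW of A (all 16 K values)
    return [A[lane % 16][k] for k in range(TILE)]


def b_frag(B, lane):
    # one COLUMN of B, i.e. one row of the transposed B kept in LDS
    return [B[k][lane % 16] for k in range(TILE)]


def c_coord(lane, i):
    # element i of a lane's float8 accumulator: row 2*i + lane//16, column lane%16
    return 2 * i + lane // 16, lane % 16


def emulate_wmma(A, B):
    """Fill each lane's 8 accumulators with the (row, col) results the layout assigns it."""
    a = [a_frag(A, lane) for lane in range(32)]
    b = [b_frag(B, lane) for lane in range(32)]
    c = [[0.0] * 8 for _ in range(32)]
    for lane in range(32):
        for i in range(8):
            row, col = c_coord(lane, i)
            # A row `row` sits in lane `row`'s fragment, B column `col` in lane `col`'s
            c[lane][i] = sum(a[row][k] * b[col][k] for k in range(TILE))
    return c


random.seed(0)
A = [[random.uniform(-1, 1) for _ in range(TILE)] for _ in range(TILE)]
B = [[random.uniform(-1, 1) for _ in range(TILE)] for _ in range(TILE)]
c = emulate_wmma(A, B)

# Scatter the per-lane results back into a 16×16 D and compare with A @ B
D = [[None] * TILE for _ in range(TILE)]
for lane in range(32):
    for i in range(8):
        r, col = c_coord(lane, i)
        D[r][col] = c[lane][i]
ref = [[sum(A[r][k] * B[k][col] for k in range(TILE)) for col in range(TILE)]
       for r in range(TILE)]
print(max(abs(D[r][col] - ref[r][col]) for r in range(TILE) for col in range(TILE)))
```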

Kernel Variants

1. Standard Kernel (Optimal)

Configuration: 128×64 tile, 4×2 warps, 2×2 register blocking

Features:

  • Double buffering (overlaps loads with compute)
  • GMEM spreading (register prefetch interleaved with MMA)
  • Vectorized half8 global loads
  • Transposed B in LDS for col_major fragment access
  • LDS padding (+8 halfs) to avoid bank conflicts

Performance: 21.9 TFLOPS (36.8% peak) at 4096×4096
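The double-buffering feature listed above rests on one invariant: between two barriers, no LDS buffer is both written and read. A simplified CPU model of the ping-pong schedule with a checker for that invariant is sketched below. It flattens the kernel's register-prefetch (GMEM spreading) step into a single "load" event and uses hypothetical names; it is not the kernel's actual instruction order.

```python
def double_buffered_schedule(num_k_tiles):
    """Events ('load' | 'compute' | 'barrier', k_tile, lds_buffer) of a ping-pong loop."""
    events = [("load", 0, 0), ("barrier", None, None)]    # prologue: fill buffer 0
    for t in range(num_k_tiles):
        cur, nxt = t % 2, (t + 1) % 2
        if t + 1 < num_k_tiles:
            events.append(("load", t + 1, nxt))            # prefetch the next K tile
        events.append(("compute", t, cur))                 # WMMAs on the current tile
        events.append(("barrier", None, None))             # one sync per K tile
    return events


def check_schedule(events):
    """Raise AssertionError if a warp could read a half-written buffer or
    overwrite a buffer that another warp may still be reading."""
    resident = {}                     # buffer -> K tile stored in it
    written, read = set(), set()      # buffers touched since the last barrier
    for kind, tile, buf in events:
        if kind == "barrier":
            written.clear()
            read.clear()
        elif kind == "load":
            assert buf not in read, f"load of tile {tile} races with readers of buffer {buf}"
            resident[buf] = tile
            written.add(buf)
        else:                         # compute
            assert buf not in written, f"tile {tile} read before its barrier"
            assert resident.get(buf) == tile, f"tile {tile} not resident in buffer {buf}"
            read.add(buf)
    return True


events = double_buffered_schedule(2048 // 16)    # K = 2048, BLOCK_K = 16
print(check_schedule(events), sum(e[0] == "barrier" for e in events))
```

Collapsing both buffers into one makes the checker fail on the first prefetch, which is exactly the hazard the second buffer removes.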

2. Adaptive Kernel (Recommended)

Configuration: Auto-selects optimal variant based on matrix dimensions

Selection Logic:

  • Small matrices (< 512): Uses 64×64 tiles
  • Medium matrices (512-2048): Uses 128×64 tiles with K-unroll when beneficial
  • Large matrices (> 2048): Uses 128×64 standard tiles

Performance: 21.6 TFLOPS (36.4% peak) at 4096×4096, best average across sizes
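The selection rules above can be read as a small decision function. The sketch is a hypothetical interpretation, not the logic in wmma_tile_selection.hpp: it keys the size bands off max(M, N) and treats "K-unroll when beneficial" as K in the 768-1536 range quoted for the K-Unroll kernel.

```python
def select_config(m: int, n: int, k: int) -> str:
    """Illustrative adaptive tile choice following the listed size bands.

    Assumptions: size = max(M, N); K-unroll is 'beneficial' for 768 <= K <= 1536.
    """
    size = max(m, n)
    if size < 512:
        return "64x64"
    if size <= 2048 and 768 <= k <= 1536:
        return "128x64-kunroll"
    return "128x64"


for dim in (256, 512, 1024, 2048, 4096):
    print(dim, select_config(dim, dim, dim))
```

Under these assumptions the square sizes map the same way as the Recommendations section: 512 uses the standard kernel, 1024 picks K-Unroll, and 2048+ uses standard 128×64 tiles.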

3. K-Unroll Kernel

Configuration: 128×64 tile with 2× K-unrolling

Features:

  • Processes 2× BLOCK_K per iteration
  • Reduces __syncthreads overhead by 50%
  • Best for K dimensions in 768-1536 range

Performance: 17.2 TFLOPS (29% peak) at 4096×4096, but 15.7 TFLOPS at 1024×1024
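The 50% barrier reduction is simple counting: a loop that consumes BLOCK_K values of K per LDS refill needs K / BLOCK_K barriers (128 for K=2048 and BLOCK_K=16, the figure in the Bottleneck Analysis), and consuming 2 × BLOCK_K per refill halves that. The sketch models a single output element on the CPU, assumes K is a multiple of the step, and is not the kernel's code.

```python
def k_loop(a_row, b_col, block_k=16, unroll=1):
    """Dot product over K in steps of unroll*block_k; returns (value, barriers).

    One barrier per iteration models the __syncthreads between LDS refills;
    unroll=2 consumes two BLOCK_K slices per refill.
    """
    k_total = len(a_row)
    step = block_k * unroll
    assert k_total % step == 0, "sketch assumes K divisible by the step"
    acc, barriers = 0.0, 0
    for k0 in range(0, k_total, step):
        for k in range(k0, k0 + step):
            acc += a_row[k] * b_col[k]
        barriers += 1
    return acc, barriers


a, b = [1.0] * 2048, [0.5] * 2048
print(k_loop(a, b), k_loop(a, b, unroll=2))
```

The arithmetic is identical in both cases; only the number of synchronization points changes, which is why the benefit depends on how much of the runtime the barriers cost.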

4. No-Prefetch Kernel

Simplification: Removes register prefetch phase

Trade-off: Lower register pressure (~64 VGPRs) but less latency hiding

Performance: 19.2 TFLOPS (32% peak) at 4096×4096

5. High-Occupancy Kernel

Configuration: 64×32 tile, 4×1 warps, 2×1 register blocking

Goal: Maximize waves/CU by reducing VGPRs to ~50

Result: Worse performance (9.8 TFLOPS) — latency hiding more important than occupancy for this compute-bound workload

6. Native Kernel (gfx1151-specific)

Configuration: 128×64 tile, explicit inline assembly fences

Features:

  • lds_fence(), vmem_fence(), full_fence() via inline asm
  • Interleaved prefetch pattern (global loads between WMMA ops)
  • __builtin_prefetch for software prefetch
  • amdgpu_waves_per_eu(4, 8) occupancy hint

Result: ~same as adaptive (~20 TFLOPS) — async copy hardware not exposed in HIP


Key Optimizations

✅ Implemented & Verified

| Optimization | Impact | Description |
|---|---|---|
| 2×2 Register Blocking | +80% | 4 WMMA tiles per warp (32×32 output) |
| Double Buffering | +20% | Ping-pong LDS buffers |
| GMEM Spreading | +15% | Prefetch into registers, interleaved with MMA |
| Vectorized half8 Loads | +25% | 128-bit global loads |
| 128×64 Tile Shape | +25% | Optimal A-matrix reuse |
| Transposed B in LDS | Required | Matches col_major fragment layout |
| LDS Padding (+8 halfs) | +15-20% | Eliminates bank conflicts (stride 24 vs 16) |
| Pointer Increment | +2% | Reduces VALU pressure in main loop |
| amdgpu_waves_per_eu | +2% | Compiler hint for occupancy targeting |
| Epilogue Fusion (α/β) | Saves 1 pass | Fused C = αAB + βC avoids separate scaling kernel |

Low-Risk Perf Tweaks

  • Explicit Wave32: Force/confirm wave32 compilation for gfx11 targets and check it at compile time: __AMDGCN_WAVEFRONT_SIZE is a preprocessor macro, so the check is an #if or static_assert on __AMDGCN_WAVEFRONT_SIZE == 32 rather than a runtime assert.
  • Compiler Hints: Use __restrict__ on A/B/C pointers and __builtin_assume_aligned(ptr, 16) on vectorized paths to encourage global_load_b128 generation.
  • Wide LDS Reads: Convert LDS fragment loads from 16 scalar half loads into two ds_load_b128 (the gfx11 name for ds_read_b128) + pack. This reduces LGKM overhead and bank conflict probability.

❌ Tested But Not Beneficial

| Optimization | Result | Reason |
|---|---|---|
| BLOCK_K=32 | Slower | Added loop overhead outweighed benefits |
| Odd LDS Stride (17, 33) | Slower | Forces scalar LDS access |
| Pre-transpose B on Host | Slower | Host overhead > kernel savings |
| Triple Buffering | Broken | Correctness issues with rotation |
| Inline Assembly Scheduling | No Change | Compiler scheduling already optimal |
| High-Occupancy Variant | Slower | Latency hiding > occupancy for this workload |

Optimization Techniques

LDS Bank Conflict Elimination

Current Approach: Padding

The current implementation uses LDS padding to avoid bank conflicts:

```cpp
#define LDS_PAD 8
constexpr int A_STRIDE = BLOCK_K + LDS_PAD;  // 16 + 8 = 24 halfs = 48 bytes
```
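Why an 8-half pad helps can be seen by mapping addresses to banks. The sketch assumes 32 LDS banks of 4 bytes each and a pattern where 16 lanes read the same column from 16 consecutive rows; both are illustrative assumptions, and the kernel's real access patterns (for example the transposed B stores) differ.

```python
BANKS = 32          # assumption: 32 LDS banks
BANK_BYTES = 4      # assumption: each bank is 4 bytes wide
HALF_BYTES = 2      # FP16


def conflict_degree(stride_halfs: int, lanes: int = 16, col: int = 0) -> int:
    """Largest number of distinct dwords landing in one bank when `lanes`
    lanes each read element `col` of consecutive rows (row-major FP16)."""
    dwords_per_bank = {}
    for row in range(lanes):
        byte_addr = (row * stride_halfs + col) * HALF_BYTES
        dword = byte_addr // BANK_BYTES
        dwords_per_bank.setdefault(dword % BANKS, set()).add(dword)
    return max(len(d) for d in dwords_per_bank.values())


for stride in (16, 24):   # BLOCK_K alone vs. BLOCK_K + LDS_PAD
    print(f"stride {stride} halfs: {conflict_degree(stride)}-way")
```

Under these assumptions a 16-half stride piles four rows onto each bank, while the padded 24-half stride spreads them to two per bank: for this pattern padding reduces the conflict rather than removing it.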

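Problems with padding:

  1. Wastes LDS memory: 8 extra halfs per row = 16 bytes wasted per row
    • For BLOCK_M=128 rows: 128 × 16 = 2KB wasted per A buffer
    • With double buffering: 4KB wasted for A, 6KB in total once the 64-row transposed B tile is included
  2. Doesn't guarantee conflict-free access: padding shifts each row to a different bank offset, which helps regular strided patterns but does not rule out conflicts for every access pattern
  3. Reduces row alignment: a 24-half (48-byte) stride keeps every row 16-byte aligned, which is enough for 128-bit accesses, but rows no longer start on 32- or 64-byte boundaries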

Alternative: XOR-Based LDS Swizzle

XOR swizzle transforms memory indices so that bank conflicts cannot occur for the access pattern the swizzle is designed around:

Original index: (row, col)
Swizzled index: (row, col XOR f(row))

Where f(row) is chosen such that threads accessing different rows but same logical column will hit different banks.

For BLOCK_K=16, KPACK=8:

  • K_GROUPS = 16 / 8 = 2
  • Swizzle: k_group_swizzled = k_group ^ (row & 1)
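A minimal model of this BLOCK_K=16, KPACK=8 swizzle: the 8-half k-group index is XORed with the low row bit, so odd rows swap their two halves. Because XOR is its own inverse, reading back with the same transform recovers the logical layout, which is what Fix 1 below relies on. The names mirror the text, but this is an illustrative sketch, not wmma_xor_swizzle.hpp.

```python
BLOCK_K, KPACK = 16, 8
K_GROUPS = BLOCK_K // KPACK        # 2
K_GROUPS_MASK = K_GROUPS - 1       # 1 (assumes K_GROUPS is a power of two)


def swizzle_col(row: int, col: int) -> int:
    """Physical column for logical (row, col): XOR the k-group with the row bits."""
    k_group, k_local = divmod(col, KPACK)
    return (k_group ^ (row & K_GROUPS_MASK)) * KPACK + k_local


# Store a 4×16 tile swizzled, then read it back with the same XOR
rows = 4
logical = [[r * BLOCK_K + c for c in range(BLOCK_K)] for r in range(rows)]
physical = [[None] * BLOCK_K for _ in range(rows)]
for r in range(rows):
    for c in range(BLOCK_K):
        physical[r][swizzle_col(r, c)] = logical[r][c]

readback = [[physical[r][swizzle_col(r, c)] for c in range(BLOCK_K)] for r in range(rows)]
print(readback == logical)
```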

Memory Savings:

  • Padding approach: 18,432 bytes (with 2 buffers)
  • XOR Swizzle: 12,288 bytes
  • Savings: 6,144 bytes (33%)
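These LDS figures can be reproduced from the tile sizes: A occupies BLOCK_M rows and the transposed B occupies BLOCK_N rows, each row BLOCK_K halfs wide (plus LDS_PAD when padded), times two buffers.

```python
HALF = 2                                          # bytes per FP16 element
BLOCK_M, BLOCK_N, BLOCK_K, PAD, BUFFERS = 128, 64, 16, 8, 2


def lds_bytes(row_halfs: int) -> int:
    """A tile (BLOCK_M rows) + transposed B tile (BLOCK_N rows), double-buffered."""
    return (BLOCK_M + BLOCK_N) * row_halfs * HALF * BUFFERS


padded = lds_bytes(BLOCK_K + PAD)                 # stride 24
swizzled = lds_bytes(BLOCK_K)                     # stride 16 with XOR swizzle
a_waste = BLOCK_M * PAD * HALF * BUFFERS          # padding spent on A alone
print(padded, swizzled, padded - swizzled, a_waste)
```

The 6,144-byte difference is the full padding overhead: 4 KB for A plus 2 KB for the transposed B tile.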

Status: ✅ Implemented and correct in wmma_xor_swizzle.hpp, but slower than padding (~15-20% slower).

Performance Analysis: XOR Swizzle vs Padding

| Approach | TFLOPS | LDS Usage | Bank Conflicts |
|---|---|---|---|
| Padding (stride=24) | 20-21 | 18.4 KB | Low (padded stride offsets rows across banks) |
| XOR Swizzle (stride=16) | 17-18 | 12.3 KB | None (mathematically eliminated) |

Why XOR Swizzle is Slower:

  1. B matrix transpose stores: Each scalar store requires computing Swizzle::to_physical() (division, modulo, XOR)
  2. Flat 1D array indexing: More VALU overhead than 2D array with padding
  3. RDNA3 LDS bank conflict penalty: May not be severe enough to justify swizzle computation overhead
  4. Compiler optimization: 2D arrays with padding are easier for the compiler to optimize

Recommendation: Use padding approach for gfx1151. The 33% LDS savings from XOR swizzle doesn't compensate for the ~15-20% performance loss.

Critical Implementation Fixes for XOR Swizzle

When data is stored swizzled in LDS, fragment loading must account for the swizzle transformation. The following fixes are required:

Fix 1: Fragment Loading with XOR Swizzle Inversion

When loading fragments from swizzled LDS, you must invert the swizzle to get the correct data layout for WMMA:

```cpp
// INCORRECT: Direct access ignores swizzle
int frag_row = lane % 16;  // each lane owns one ROW of the A tile
for (int k = 0; k < 16; k++) {
    a0[k] = A_lds[curr][SwzA::to_flat(warp_m_base + frag_row, k)];
}

// CORRECT: Invert XOR swizzle
int row = warp_m_base + lane % 16;  // same row; lanes 16-31 replicate lanes 0-15

for (int k = 0; k < 16; k++) {
    // Invert XOR swizzle: find which swizzled column holds original column k
    int k_group_orig = k / 8;
    int k_local = k % 8;
    int k_group_swz = k_group_orig ^ (row & SwzA::K_GROUPS_MASK);
    int k_swz = k_group_swz * 8 + k_local;

    a0[k] = *reinterpret_cast<const _Float16*>(&A_lds[curr][SwzA::to_flat(row, k_swz)]);
}
```

For B matrix (similar fix with transposed layout):

```cpp
int frag_row_orig = lane % 16;  // Original row in transposed B layout

for (int kk = 0; kk < 16; kk++) {
    int n = warp_n_base + frag_row_orig;
    
    // Invert XOR swizzle for B
    int k_group_orig = kk / 8;
    int k_local = kk % 8;
    int k_group_swz = k_group_orig ^ (n & SwzB::K_GROUPS_MASK);
    int k_swz = k_group_swz * 8 + k_local;
    
    b0[kk] = *reinterpret_cast<const _Float16*>(&B_lds[curr][SwzB::to_flat(n, k_swz)]);
}
```

Fix 2: Correct Epilogue Store Pattern

WMMA fragment layout stores elements in a specific pattern. Each element c_frag[i] stores to row i*2 + (lane/16), column lane%16:

```cpp
// INCORRECT: Wrong fragment layout assumption
int frag_row = lane % 16;
int frag_col_off = (lane / 16) * 8;
for (int e = 0; e < 8; e++) {
    int local_c = frag_col_off + e;
    C[gr0 * N + gc0] = c00[e];  // WRONG!
}

// CORRECT: Proper WMMA fragment layout
int frag_col = lane % 16;           // Column is fixed per lane
int frag_row_offset = lane / 16;    // 0 for lanes 0-15, 1 for lanes 16-31

for (int i = 0; i < 8; i++) {
    int frag_row = i * 2 + frag_row_offset;  // Rows: 0,2,4,...,14 or 1,3,5,...,15
    
    int gr0 = block_m + warp_m_base + frag_row;
    int gc0 = block_n + warp_n_base + frag_col;
    
    if (gr0 < M && gc0 < N) C[gr0 * N + gc0] = c00[i];
    
    // For 2×2 register blocking, handle all 4 tiles:
    int gc1 = gc0 + 16;  // Tile [0][1]
    if (gr0 < M && gc1 < N) C[gr0 * N + gc1] = c01[i];
    
    int gr1 = gr0 + 16;  // Tile [1][0]
    if (gr1 < M && gc0 < N) C[gr1 * N + gc0] = c10[i];
    
    if (gr1 < M && gc1 < N) C[gr1 * N + gc1] = c11[i];  // Tile [1][1]
}
```
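With 2×2 register blocking the corrected epilogue covers a 32×32 warp tile: each lane writes 8 rows × 4 tiles = 32 values, and 32 lanes × 32 values = 1024 = 32 × 32 outputs. A quick coverage check, using a hypothetical helper that mirrors the index math of the CORRECT loop:

```python
def warp_tile_coords(lane: int):
    """(row, col) within a 32×32 warp tile for each of a lane's 4×8 outputs."""
    frag_col, frag_row_offset = lane % 16, lane // 16
    for i in range(8):
        r = i * 2 + frag_row_offset
        for tm in (0, 1):           # tm=0: c00/c01, tm=1: c10/c11 (gr1 = gr0 + 16)
            for tn in (0, 1):       # tn=1: c01/c11 (gc1 = gc0 + 16)
                yield r + 16 * tm, frag_col + 16 * tn


coords = [rc for lane in range(32) for rc in warp_tile_coords(lane)]
print(len(coords), len(set(coords)))
```

Every position of the warp tile is written exactly once, so the bounds checks (gr < M, gc < N) only matter at matrix edges.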

Complete Helper Function Example:

```cpp
template<typename SwzA>
__device__ __forceinline__ void load_a_frag_swizzled(
    half16& a_frag,
    const __half* lds_base,
    int warp_m_base,
    int frag_row        // lane % 16: the A row this lane supplies
) {
    const int row = warp_m_base + frag_row;

    #pragma unroll
    for (int k = 0; k < 16; k++) {
        // Invert XOR swizzle (XOR is its own inverse)
        int k_group_orig = k / 8;
        int k_local = k % 8;
        int k_group_swz = k_group_orig ^ (row & SwzA::K_GROUPS_MASK);
        int k_swz = k_group_swz * 8 + k_local;

        a_frag[k] = *reinterpret_cast<const _Float16*>(&lds_base[SwzA::to_flat(row, k_swz)]);
    }
}
```

Alternative Approach: If inverting the swizzle is too complex or adds too much overhead, consider:

  1. Store data swizzled (for bank conflict avoidance during global→LDS load)
  2. Unswizzle during LDS→Fragment load into a temporary buffer
  3. Load fragments from unswizzled buffer

This adds an extra LDS copy step but simplifies the fragment loading code.

Testing Recommendations:

  1. Start with correctness: Test with small matrices (128×128) before performance
  2. Compare against reference: Use non-swizzled kernel as reference
  3. Verify swizzle math: Test XOR swizzle inversion logic separately
  4. Check fragment layout: Verify fragment loading matches expected WMMA layout (see docs/wmma_fragment_layout_rdna3.md)
  5. Profile bank conflicts: Use rocprof to verify XOR swizzle actually reduces conflicts

L2 Cache Tile Rasterization

Simple row-major launch can cause L2 thrashing for large matrices. Column-major or chunked tile processing improves L2 cache locality.

Expected Impact: 5-15% improvement for large matrices (4096×4096+)

Status: Implemented in wmma_optimizations.hpp but not yet integrated into main kernels.
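One common chunked ordering is the grouped (GROUP_M-style) mapping familiar from Triton's matmul tutorial: consecutive work-group IDs walk a band of `group_m` tile rows column by column, so blocks running at the same time share A row-panels and B column-panels. The sketch below compares it with row-major order; `group_m=8` is an arbitrary illustrative choice, and the mapping in wmma_optimizations.hpp may differ.

```python
def row_major_tile(pid, tiles_m, tiles_n):
    """Naive launch order: sweep a full row of output tiles before moving down."""
    return pid // tiles_n, pid % tiles_n


def grouped_tile(pid, tiles_m, tiles_n, group_m=8):
    """Chunked order: walk `group_m` tile rows column by column."""
    tiles_per_group = group_m * tiles_n
    group = pid // tiles_per_group
    first_m = group * group_m
    rows_in_group = min(tiles_m - first_m, group_m)   # last group may be short
    local = pid % tiles_per_group
    return first_m + local % rows_in_group, local // rows_in_group


def footprint(mapper, count, tiles_m, tiles_n):
    """A row-panels plus B column-panels touched by the first `count` blocks."""
    tiles = [mapper(p, tiles_m, tiles_n) for p in range(count)]
    return len({m for m, _ in tiles}) + len({n for _, n in tiles})


# 4096×4096 output with 128×64 tiles -> 32 × 64 tiles
TM, TN = 4096 // 128, 4096 // 64
print("row-major:", footprint(row_major_tile, 64, TM, TN))
print("grouped:  ", footprint(grouped_tile, 64, TM, TN))
```

For a window of 64 consecutive blocks, row-major order touches 1 A panel and 64 B panels, while the grouped order touches 8 of each, so far less distinct data competes for L2.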

Split-K for Skinny Matrices

Split-K assigns partial K-dimension slices to different work-groups, improving utilization for skinny matrices (small M or N, large K).

Example: M=16, N=4096, K=4096

  • Without Split-K: Only 64 tiles → 60% utilization
  • With Split-K factor 4: 256 tiles → 90% utilization

Status: Implemented in wmma_optimizations.hpp but not yet benchmarked.
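With 128×64 tiles, M=16 fits in a single tile row and N=4096 gives 4096/64 = 64 tiles; splitting K four ways yields the 256 work-groups quoted above. The sketch below shows the split itself on a tiny integer example: each K slice produces a partial product (one work-group each in a kernel) and the partials are summed. Helper names are hypothetical, and the real implementation may reduce partial results differently, for example with atomics or a separate reduction pass.

```python
import math


def num_tiles(m, n, block_m=128, block_n=64, split_k=1):
    """Work-groups launched for an M×N output with the given tile and split factor."""
    return math.ceil(m / block_m) * math.ceil(n / block_n) * split_k


def split_k_matmul(A, B, split_k):
    """C = A @ B computed as split_k partial GEMMs over K slices, then reduced."""
    m, k, n = len(A), len(B), len(B[0])
    assert k % split_k == 0, "sketch assumes K divisible by split_k"
    chunk = k // split_k
    partials = []
    for s in range(split_k):                      # one work-group per slice
        ks = range(s * chunk, (s + 1) * chunk)
        partials.append([[sum(A[i][kk] * B[kk][j] for kk in ks) for j in range(n)]
                         for i in range(m)])
    # reduction step: sum the partial C tiles
    return [[sum(p[i][j] for p in partials) for j in range(n)] for i in range(m)]


print(num_tiles(16, 4096), num_tiles(16, 4096, split_k=4))
A = [[1, 2, 3, 4], [5, 6, 7, 8]]
B = [[1, 0], [0, 1], [1, 1], [2, -1]]
print(split_k_matmul(A, B, split_k=2))
```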

Register Pressure Management

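Current Register Usage: Estimated ~91-92 VGPRs per thread (needs verification, e.g. via compiler remarks or rocprof)

To verify actual VGPR usage:

```shell
# Option 1: compiler remarks report VGPR/SGPR counts and occupancy per kernel
hipcc --offload-arch=gfx1151 -O3 -Rpass-analysis=kernel-resource-usage -c kernel.hip -o kernel.o

# Option 2: read the AMDGPU metadata of an extracted code object
/opt/rocm/llvm/bin/llvm-readelf --notes kernel.hsaco | grep -E "\.name:|\.vgpr_count:"
```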

Compiler Hints:

```cpp
__launch_bounds__(256, 2)
__attribute__((amdgpu_waves_per_eu(4, 8)))
// Note: amdgpu_num_vgpr does NOT work with templates in HIP
// Only amdgpu_waves_per_eu is supported with templates
```

Important: The amdgpu_num_vgpr attribute is not supported with template kernels in HIP. Use amdgpu_waves_per_eu to hint occupancy instead.

LDS-to-Register Double Buffering

The current kernel double-buffers Global→LDS but not LDS→Register, so LDS loads and WMMA compute are serialized.

Potential Improvement: 5-10% by overlapping LDS loads with computation.

Trade-off: Doubles the fragment registers (the 2×2 blocking holds two A and two B half16 fragments, 32 VGPRs or 128 bytes per lane; a second set brings this to 64 VGPRs or 256 bytes per lane), which may reduce occupancy.
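The register cost follows from the fragment types: each lane holds 16 FP16 values per A or B fragment and 8 FP32 values per accumulator, and a VGPR is 4 bytes per lane. These per-lane counts match the 8-register operand ranges (such as v[57:64]) in the ISA listing under ISA Analysis. A small arithmetic sketch with illustrative names:

```python
VGPR_BYTES = 4   # one 32-bit VGPR per lane


def vgprs(elems_per_lane: int, bytes_per_elem: int) -> int:
    """VGPRs needed to hold a per-lane vector of the given size."""
    return elems_per_lane * bytes_per_elem // VGPR_BYTES


HALF16 = vgprs(16, 2)     # one A or B fragment (e.g. v[57:64])
FLOAT8 = vgprs(8, 4)      # one FP32 accumulator fragment (e.g. v[1:8])

fragments = 2 * HALF16 + 2 * HALF16     # 2×2 blocking: two A and two B fragments
accumulators = 4 * FLOAT8               # c00, c01, c10, c11
print(f"fragments: {fragments} VGPRs ({fragments * VGPR_BYTES} bytes/lane)")
print(f"accumulators: {accumulators} VGPRs")
print(f"LDS->register double buffering adds another {fragments} VGPRs")
```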


Profiling and Analysis

Performance Characteristics

| K Dimension | Regime | Limiting Factor |
|---|---|---|
| K < 256 | Memory-bound | LPDDR5X bandwidth |
| K ≥ 512 | Compute-bound | WMMA throughput |

Bottleneck Analysis (Compute-Bound Regime)

| Bottleneck | Contribution | Notes |
|---|---|---|
| __syncthreads overhead | ~20% | 128 barriers for K=2048 |
| Occupancy (5 waves/CU) | ~15% | 91 VGPRs limits to 5 waves |
| LDS bank conflicts (B scatter) | ~10% | Transpose pattern causes conflicts |
| WMMA pipeline bubbles | ~10% | Data dependencies within wave |

Roofline Analysis

| Metric | Value |
|---|---|
| Peak Compute | 59.4 TFLOPS |
| Peak Memory BW | 256 GB/s |
| Ridge Point | ~232 FLOP/byte (59.4 TFLOPS ÷ 256 GB/s) |
| Our Intensity (M=N=4096, K=2048) | ~682 FLOP/byte |
| Regime | Compute-bound ✅ |
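Both roofline numbers can be recomputed directly. The intensity below counts compulsory traffic only (FP16 A and B read once, FP32 C written once), which reproduces the ~682 FLOP/byte figure for M=N=4096, K=2048; real traffic is higher because tiles are re-read, so this is an upper bound on intensity.

```python
PEAK_FLOPS = 59.4e12      # FP16 WMMA peak
PEAK_BW = 256e9           # LPDDR5X bytes/s

ridge = PEAK_FLOPS / PEAK_BW    # FLOP/byte where the two roofs meet


def intensity(m, n, k, in_bytes=2, out_bytes=4):
    """Compulsory-traffic intensity: read A and B once, write C once."""
    flops = 2 * m * n * k
    traffic = (m * k + k * n) * in_bytes + m * n * out_bytes
    return flops / traffic


def attainable_tflops(ai):
    """Roofline bound: the lower of the compute and bandwidth roofs."""
    return min(PEAK_FLOPS, ai * PEAK_BW) / 1e12


print(f"ridge: {ridge:.0f} FLOP/B")
for k in (256, 512, 2048):
    ai = intensity(4096, 4096, k)
    print(f"K={k}: {ai:.0f} FLOP/B, roof {attainable_tflops(ai):.1f} TFLOPS")
```

Under this model K=512 at M=N=4096 sits just below the ridge (~228 vs ~232 FLOP/byte), so the 256/512 boundary in the regime table is approximate.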

ISA Analysis

Generated Assembly Inspection

| Component | Count | Status |
|---|---|---|
| WMMA Instructions | 4 | ✅ Correct (2×2 blocking) |
| Global Loads (global_load_b128) | 4 | ✅ Vectorized |
| LDS Stores (ds_store_b128) | 8 | ✅ Vectorized |
| Dual-Issue (v_dual_*) | 6-21 | ✅ Active |
| Wait Counters (s_waitcnt) | 10 | ⚠️ Necessary overhead |

WMMA Instruction Pattern

```asm
s_waitcnt lgkmcnt(0)                                          ; Wait for LDS
v_wmma_f32_16x16x16_f16 v[1:8], v[57:64], v[49:56], v[1:8]    ; MMA 0
v_wmma_f32_16x16x16_f16 v[9:16], v[57:64], v[41:48], v[9:16]  ; MMA 1
v_wmma_f32_16x16x16_f16 v[17:24], v[33:40], v[49:56], v[17:24] ; MMA 2
v_wmma_f32_16x16x16_f16 v[25:32], v[33:40], v[41:48], v[25:32] ; MMA 3
```

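Key Finding: The compiler groups the WMMAs together. Attempts to interleave them with other instructions via inline assembly did not improve performance. In wave32, dual-issue is encoded by the compiler as VOPD (v_dual_*) instructions, which the generated code already contains (see the table above).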

Profiling Commands

Profile LDS bank conflicts:

```shell
rocprof --stats -o profile.csv -i metrics.txt ./your_kernel

# metrics.txt should include:
# LDSBankConflict
# LDSInstructions
# LDSBankConflictCycles

# View results
cat profile.csv | grep -E "LDSBankConflict|LDSInstructions"
```

Interpretation:

  • LDSBankConflict / LDSInstructions = conflict rate (target: < 5%)
  • High conflict rate indicates need for swizzling/padding

Remaining Gap to rocBLAS

Our kernel achieves 53% of rocBLAS performance (~21.6 vs ~41 TFLOPS). The gap is due to:

| Factor | Description |
|---|---|
| Hand-tuned Assembly | rocBLAS uses offline-optimized ISA with perfect scheduling |
| Adaptive Tile Selection | rocBLAS selects optimal tile per matrix dimension |
| Async Copy Hardware | Uses hardware async LDS loads (not exposed in HIP) |
| Register Allocation | Compiler-level register allocation vs manual tuning |
| Multi-kernel Fusion | rocBLAS fuses alpha/beta scaling |

Architecture-Specific Optimizations Explored

We attempted several gfx1151-specific optimizations without portability constraints:

| Optimization | Result | Notes |
|---|---|---|
| Async global-to-LDS | ❌ Not available | __builtin_amdgcn_global_load_lds not exposed in ROCm 7.x for gfx1151 |
| Hardware prefetch | ❌ Not available | s_prefetch_data instruction not supported on RDNA3.5 |
| Explicit waitcnt | ⚖️ No improvement | lds_fence(), vmem_fence() via inline asm perform same as compiler-managed |
| Interleaved prefetch | ⚖️ No improvement | WMMA ops (128 cycles) complete before global loads (400-800 cycles) |
| Software prefetch | ⚖️ Minimal impact | __builtin_prefetch adds ~1% improvement |

Definitive Finding: No Async LDS Intrinsics for gfx1151

Based on extensive research of AMD ROCm docs, LLVM AMDGPU backend, and GPUOpen resources (as of late 2025):

  • No __hip_ds_copy_async or similar hardware intrinsics exist for gfx1151
  • llvm.amdgcn.load.to.lds exists but is synchronous (lowers to global_load_lds + s_waitcnt)
  • Async behavior must be achieved through:
    • HIP runtime APIs (hipMemcpyAsync with streams) - host-device only
    • Manual overlap with s_waitcnt vmcnt(x) to allow in-flight loads - already implemented
    • Queue-level async (separate compute/copy queues) - not applicable for LDS

The only ISA pattern available for global-to-LDS:

```asm
buffer_load_b32 v1, v0, s[sgpr0:sgpr3], ...  ; Global load (gfx11 name for buffer_load_dword)
s_waitcnt vmcnt(0)                            ; Wait for load
ds_store_b32 v2, v1                           ; Write to LDS (gfx11 name for ds_write_b32)
s_waitcnt lgkmcnt(0)                          ; Wait for LDS visibility
```

GFX12 has s_wait_dscnt but still no async. ASYNC LDS and tensor ops are not covered by the memory model implemented by the AMDGPU backend; waits aren't inserted automatically and must be emitted explicitly.

What Would Close the Gap

  1. AMD exposing true async LDS intrinsics - requires hardware/firmware changes, not just software
  2. Write in AMDGCN assembly directly with perfect scheduling (impractical, loses compiler optimizations)
  3. Use rocBLAS/hipBLASLt for production (recommended - these use internal AMD optimizations)

Expected Final Performance

The estimated ceiling without vendor-level assembly optimization is approximately 55-60% of peak (~33-36 TFLOPS). The remaining candidates, L2 rasterization and Split-K, could potentially reach 40-45% of peak (~24-27 TFLOPS); XOR swizzle was measured slower than padding (see above).


Optimization Journey

Evolution of Performance

| Version | TFLOPS | Change | Key Optimization |
|---|---|---|---|
| Baseline | 5.4 | | Basic WMMA implementation |
| + Multi-column blocks | 6.2 | +15% | BLOCK_N=64 for B reuse |
| + 2×2 Register Blocking | 9.7 | +56% | 4 accumulators per warp |
| + Vectorized Loads | 11.5 | +19% | half8 global loads |
| + Double Buffering | 13.2 | +15% | Ping-pong LDS |
| + 128×64 Tile Shape | 16.5 | +25% | Increased A reuse |
| + GMEM Spreading | 17.4 | +5% | Register prefetch |
| + LDS Padding | 20.9 | +20% | Bank conflict elimination |

Lessons Learned

  1. Tile shape matters more than expected: 128×64 (tall) significantly outperforms 64×64 (square) due to A-matrix reuse
  2. Prefetching > Occupancy: For compute-bound workloads, hiding latency via prefetch beats maximizing waves/CU
  3. Vectorized access is critical: LPDDR5X heavily penalizes scalar loads
  4. Compiler scheduling is smart: Inline assembly attempts didn't improve on LLVM's scheduling
  5. LDS transpose is required: col_major B fragments need transposed data
  6. Odd strides hurt more than help: Bank conflict avoidance via odd strides forces scalar access

Optimization Attempts (December 2025)

Several optimizations from the development notes were attempted:

| Optimization | Result | Notes |
|---|---|---|
| Vectorized C Writes | ❌ 0.72× (slower) | Extra 32KB LDS + sync overhead hurt more than coalescing helped |
| Cooperative Loading | ⚖️ 1.01-1.06× | Marginal gains at small K, slight regression at large K |
| BLOCK_K=32 (K-unroll) | ❌ 0.89-0.97× | Increased LDS/register pressure hurt occupancy |
| XOR Swizzle LDS | ❌ 0.85× (slower) | Swizzle computation overhead > bank conflict savings |

Conclusion: The standard kernel is already well-optimized with double-buffered LDS, interleaved prefetch, and LDS padding. Further gains likely require split-K parallelism or assembly-level tuning.


File Structure

```text
wmma_ops/
├── README.md                        # This file
├── setup.py                         # Build configuration
├── wmma_gemm.hip                    # Main kernel implementation & pybind
├── wmma_kernels_optimized.hpp       # Optimized kernel variants (kunroll, quad, hilbert, etc.)
├── wmma_tile_mapping.hpp            # Hilbert curve tile mapping for L2 locality
├── wmma_xor_swizzle.hpp             # XOR swizzle, rasterization, Split-K
├── wmma_tile_selection.hpp          # Adaptive tile configuration
├── wmma_device_helpers.hpp          # Fragment loading helpers
├── rocwmma_patch/
│   └── rocwmma_gfx1151.hpp          # Custom rocWMMA patch header
├── docs/
│   └── WMMA_DEVELOPMENT_NOTES.md    # Consolidated development documentation
├── examples/                        # Reference implementations from other projects
├── autotune.py                      # Optuna-based auto-tuner
├── test_rocwmma_patch.py            # Test suite
├── benchmark_summary.py             # Benchmark utilities
└── build_in_docker.sh               # Docker build script
```

References

Primary Resources

  1. rocWMMA Documentation (ROCm)

  2. Deep Dive into Matrix Optimization on AMD GPUs (Sébastien Vince)

    • Blog Post
    • Achievement: 49 TFLOPS on FP32 GEMM (60% faster than rocBLAS)
  3. AMD RDNA™ 3.5 ISA Reference Guide

  4. LLVM AMDGPU Usage

  5. rocBLAS Documentation (ROCm)

Implementation References


License

Based on rocWMMA library (MIT License) and llama.cpp optimizations.
