From 15fe8a59edf28dcbad655d2598421036641b723e Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Tue, 3 Mar 2026 18:13:36 -0800 Subject: [PATCH 1/4] Add AMD February 2026 competition: mxfp4-mm, moe-mxfp4, mixed-mla 3 problems targeting MI355X from AMD-AIM/reference-kernels@20260209. Also fixes eval.py regex to support underscored keys and booleans. --- problems/amd_202602.yaml | 19 + problems/amd_202602/eval.py | 387 ++++++++++++++++++++ problems/amd_202602/mixed-mla/README.md | 200 ++++++++++ problems/amd_202602/mixed-mla/reference.py | 372 +++++++++++++++++++ problems/amd_202602/mixed-mla/submission.py | 185 ++++++++++ problems/amd_202602/mixed-mla/task.py | 36 ++ problems/amd_202602/mixed-mla/task.yml | 95 +++++ problems/amd_202602/moe-mxfp4/README.md | 198 ++++++++++ problems/amd_202602/moe-mxfp4/reference.py | 323 ++++++++++++++++ problems/amd_202602/moe-mxfp4/submission.py | 66 ++++ problems/amd_202602/moe-mxfp4/task.py | 28 ++ problems/amd_202602/moe-mxfp4/task.yml | 122 ++++++ problems/amd_202602/mxfp4-mm/reference.py | 108 ++++++ problems/amd_202602/mxfp4-mm/submission.py | 32 ++ problems/amd_202602/mxfp4-mm/task.py | 21 ++ problems/amd_202602/mxfp4-mm/task.yml | 62 ++++ problems/amd_202602/utils.py | 147 ++++++++ 17 files changed, 2401 insertions(+) create mode 100644 problems/amd_202602.yaml create mode 100644 problems/amd_202602/eval.py create mode 100644 problems/amd_202602/mixed-mla/README.md create mode 100644 problems/amd_202602/mixed-mla/reference.py create mode 100644 problems/amd_202602/mixed-mla/submission.py create mode 100644 problems/amd_202602/mixed-mla/task.py create mode 100644 problems/amd_202602/mixed-mla/task.yml create mode 100644 problems/amd_202602/moe-mxfp4/README.md create mode 100644 problems/amd_202602/moe-mxfp4/reference.py create mode 100644 problems/amd_202602/moe-mxfp4/submission.py create mode 100644 problems/amd_202602/moe-mxfp4/task.py create mode 100644 problems/amd_202602/moe-mxfp4/task.yml create mode 100644 problems/amd_202602/mxfp4-mm/reference.py create mode 100644 problems/amd_202602/mxfp4-mm/submission.py create mode 100644 problems/amd_202602/mxfp4-mm/task.py create mode 100644 problems/amd_202602/mxfp4-mm/task.yml create mode 100644 problems/amd_202602/utils.py diff --git a/problems/amd_202602.yaml b/problems/amd_202602.yaml new file mode 100644 index 00000000..a75894f1 --- /dev/null +++ b/problems/amd_202602.yaml @@ -0,0 +1,19 @@ +name: AMD Developer Challenge February 2026 +deadline: "2026-03-15 06:00" +description: "AMD Developer Challenge: MXFP4 matrix multiplication, Mixture-of-Experts, and Multi-head Latent Attention optimized for MI355X." 
+problems: + - directory: amd_202602/mxfp4-mm + name: amd-mxfp4-mm + deadline: "2026-03-15 06:00" + gpus: + - MI355X + - directory: amd_202602/moe-mxfp4 + name: amd-moe-mxfp4 + deadline: "2026-03-15 06:00" + gpus: + - MI355X + - directory: amd_202602/mixed-mla + name: amd-mixed-mla + deadline: "2026-03-15 06:00" + gpus: + - MI355X diff --git a/problems/amd_202602/eval.py b/problems/amd_202602/eval.py new file mode 100644 index 00000000..2df7ef58 --- /dev/null +++ b/problems/amd_202602/eval.py @@ -0,0 +1,387 @@ +import base64 +import dataclasses +import multiprocessing +import re +import time +import os +import sys +import math +from pathlib import Path +from typing import Any, Optional + +import torch.cuda + +from utils import set_seed, clear_l2_cache_large as clear_l2_cache +try: + from task import TestSpec +except ImportError: + TestSpec = dict + +from reference import check_implementation, generate_input + + +class PopcornOutput: + def __init__(self, fd: int): + self.file = os.fdopen(fd, 'w') + os.set_inheritable(fd, False) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.file.close() + + def print(self, *args, **kwargs): + print(*args, **kwargs, file=self.file, flush=True) + + def log(self, key, value): + self.print(f"{key}: {value}") + + +@dataclasses.dataclass +class TestCase: + args: dict + spec: str + + +def _combine(a: int, b: int) -> int: + # combine two integers into one: + # we need this to generate a secret seed based on the test-level seed and + # the global secret seed. + # the test-level seeds are public knowledge, and typically relatively small numbers, + # so we need to make sure they don't provide any useful info for the full seed. + # This Cantor construction ensures that if the secret seed is a large number, + # then so is the overall seed. + return int(a + (a+b)*(a+b+1)//2) + + +def get_test_cases(file_name: str, seed: Optional[int]) -> list[TestCase]: + try: + content = Path(file_name).read_text() + except Exception as E: + print(f"Could not open test file`{file_name}`: {E}", file=sys.stderr) + exit(113) + + tests = [] + lines = content.splitlines() + match = r"\s*([a-zA-Z_]\w*):\s*([a-zA-Z_]\w*|[+-]?[0-9]+)\s*" + for line in lines: + parts = line.split(";") + case = {} + for part in parts: + matched = re.match(match, part) + if not re.fullmatch(match, part): + print(f"invalid test case: '{line}': '{part}'", file=sys.stderr) + exit(113) + key = matched[1] + val = matched[2] + try: + val = int(val) + except ValueError: + if val == "true": + val = True + elif val == "false": + val = False + + case[key] = val + tests.append(TestCase(spec=line, args=case)) + + if seed is not None: + for test in tests: + if "seed" in test.args: + test.args["seed"] = _combine(test.args["seed"], seed) + + return tests + + +@dataclasses.dataclass +class Stats: + runs: int + mean: float + std: float + err: float + best: float + worst: float + + +def calculate_stats(durations: list[int]): + """ + Calculate statistical data from a list of durations. + + @param durations: A list of durations in nanoseconds. + @return: A Stats object containing the number of runs, mean, standard deviation, error, best, and worst durations. 
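+    The err field is the standard error of the mean, i.e. std / sqrt(runs).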
+ """ + runs = len(durations) + total = sum(durations) + best = min(durations) + worst = max(durations) + + avg = total / runs + variance = sum(map(lambda x: (x - avg)**2, durations)) + std = math.sqrt(variance / (runs - 1)) + err = std / math.sqrt(runs) + + return Stats(runs=runs, mean=avg, std=std, err=err, best=float(best), + worst=float(worst)) + + +def _clone_data(data): + """ + Recursively goes through data and clones all tensors. + """ + if isinstance(data, tuple): + return tuple(_clone_data(x) for x in data) + elif isinstance(data, list): + return [_clone_data(x) for x in data] + elif isinstance(data, dict): + return {k: _clone_data(v) for k, v in data.items()} + elif isinstance(data, torch.Tensor): + return data.clone() + else: + return data + + +def wrap_check_implementation(data, submission_output): + # Old version returned just a single string, new version + # returns (bool, str); this function ensures compatibility with old + # problem definitions. + result = check_implementation(data, submission_output) + if isinstance(result, tuple): + return result + else: + return not bool(result), result + + +def _run_single_test(test: TestCase): + """ + Runs a single test case. Do not call directly + """ + from submission import custom_kernel + data = generate_input(**test.args) + torch.cuda.synchronize() + submission_output = custom_kernel(_clone_data(data)) + torch.cuda.synchronize() + return wrap_check_implementation(data, submission_output) + + +def run_single_test(pool: multiprocessing.Pool, test: TestCase): + """ + Runs a single test in another process. + """ + return pool.apply(_run_single_test, (test,)) + + +def run_testing(logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase]): + """ + Executes the actual test case code and checks for correctness. + + @param logger: A PopcornOutput object used for logging test results. + @param tests: A list of TestCase objects representing the test cases to be executed. + @return: An integer representing the exit status: 0 if all tests pass, otherwise 112. + """ + passed = True + logger.log("test-count", len(tests)) + for idx, test in enumerate(tests): + logger.log(f"test.{idx}.spec", test.spec) + good, message = run_single_test(pool, test) + if not good: + logger.log(f"test.{idx}.status", "fail") + logger.log(f"test.{idx}.error", message) + passed = False + else: + logger.log(f"test.{idx}.status", "pass") + if message: + logger.log(f"test.{idx}.message", message) + + if passed: + logger.log("check", "pass") + return 0 + else: + logger.log("check", "fail") + return 112 + + +def _run_single_benchmark(test: TestCase, recheck: bool, max_repeats: int, max_time_ns: float) -> Stats | Any: + """ + Runs one benchmark. Do not call directly. + """ + from submission import custom_kernel + + durations = [] + # generate input data once + data = generate_input(**test.args) + check_copy = _clone_data(data) + # first, one obligatory correctness check + output = custom_kernel(data) + good, message = wrap_check_implementation(check_copy, output) + if not good: + return message + + # now, do multiple timing runs without further correctness testing + # there is an upper bound of 100 runs, and a lower bound of 3 runs; + # otherwise, we repeat until we either measure at least 10 full seconds, + # or the relative error of the mean is below 1%. 
+ + bm_start_time = time.perf_counter_ns() + for i in range(max_repeats): + if recheck: + # ensure we use a different seed for every benchmark + if "seed" in test.args: + test.args["seed"] += 13 + + data = generate_input(**test.args) + check_copy = _clone_data(data) + torch.cuda.synchronize() + clear_l2_cache() + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + output = custom_kernel(data) + end_event.record() + torch.cuda.synchronize() + + if recheck: + good, message = check_implementation(check_copy, output) + if not good: + return message + + del output + durations.append(start_event.elapsed_time(end_event) * 1e6) + + if i > 1: + total_bm_duration = time.perf_counter_ns() - bm_start_time + stats = calculate_stats(durations) + # stop if either + # a) relative error dips below 0.1% + # b) we exceed the total time limit for benchmarking the kernel + # c) we exceed 2 minutes of total wallclock time. + if stats.err / stats.mean < 0.001 or stats.mean * stats.runs > max_time_ns or total_bm_duration > 120e9: + break + + return calculate_stats(durations) + + +def run_single_benchmark(pool: multiprocessing.Pool, test: TestCase, recheck: bool, max_repeats: int, + max_time_ns: float): + """ + For a particular test case, check correctness (if applicable) and grab runtime results. + + @param pool: Process on which the benchmark will be launched. + @param test: TestCase object. + @param recheck: Flag for whether to explicitly check functional correctness. + @param max_repeats: Number of trials to repeat. + @param max_time_ns: Timeout time in nanoseconds. + @return: A Stats object for this particular benchmark case or an error if the test fails. + """ + return pool.apply(_run_single_benchmark, (test, recheck, max_repeats, max_time_ns)) + + +def run_benchmarking(logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase]): + """ + Executes benchmarking code for a CUDA Kernel and logs runtimes. + + @param logger: A PopcornOutput object used for logging benchmark results. + @param pool: Process on which the benchmarks will be launched. + @param tests: A list of TestCase objects representing the test cases to be benchmarked. + @return: An integer representing the exit status: 0 if all benchmarks pass, otherwise 112. + """ + # warm up + run_single_benchmark(pool, tests[0], False, 100, 10e7) + + passed = True + logger.log("benchmark-count", len(tests)) + for idx, test in enumerate(tests): + logger.log(f"benchmark.{idx}.spec", test.spec) + result = run_single_benchmark(pool, test, False, 1000, 50e9) + if isinstance(result, Stats): + for field in dataclasses.fields(Stats): + logger.log(f"benchmark.{idx}.{field.name}", getattr(result, field.name)) + else: + passed = False + logger.log(f"benchmark.{idx}.status", "fail") + logger.log(f"benchmark.{idx}.error", result) + + if passed: + logger.log("check", "pass") + return 0 + else: + logger.log("check", "fail") + return 112 + + +def run_single_profile(test: TestCase) -> str: + """ + Runs a single test case. 
Do not call directly + """ + from submission import custom_kernel + from torch.profiler import profile, record_function, ProfilerActivity + data = generate_input(**test.args) + torch.cuda.synchronize() + + with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: + submission_output = custom_kernel(_clone_data(data)) + torch.cuda.synchronize() + return prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=20) + + +def run_profiling(logger: PopcornOutput, tests: list[TestCase]): + logger.log("benchmark-count", len(tests)) + for idx, test in enumerate(tests): + logger.log(f"benchmark.{idx}.spec", test.spec) + report = run_single_profile(test) + logger.log(f"benchmark.{idx}.report", base64.b64encode(report.encode("utf-8"), b"+*").decode("utf-8")) + logger.log("check", "pass") + return 0 + + +def main(): + fd = os.getenv("POPCORN_FD") + if not fd: + return 111 + + if len(sys.argv) < 3: + return 2 + + mode = sys.argv[1] + seed = os.getenv("POPCORN_SEED") + os.unsetenv("POPCORN_SEED") + seed = int(seed) if seed else None + set_seed(seed or 42) + tests = get_test_cases(sys.argv[2], seed) + + with PopcornOutput(int(fd)) as logger: + import multiprocessing + mp_context = multiprocessing.get_context('spawn') + with mp_context.Pool(1) as pool: + if mode == "test": + return run_testing(logger, pool, tests) + if mode == "benchmark": + return run_benchmarking(logger, pool, tests) + + if mode == "leaderboard": + # warmup + run_single_benchmark(pool, tests[0], False, 100, 1e7) + logger.log("benchmark-count", len(tests)) + passed = True + for i in range(len(tests)): + result = run_single_benchmark(pool, tests[i], True, 100, 30e9) + logger.log(f"benchmark.{i}.spec", tests[i].spec) + if isinstance(result, Stats): + for field in dataclasses.fields(Stats): + logger.log(f"benchmark.{i}.{field.name}", getattr(result, field.name)) + else: + passed = False + logger.log(f"benchmark.{i}.status", "fail") + logger.log(f"benchmark.{i}.error", str(result)) # TODO: Make sure result implements __str__? + break + + logger.log("check", "pass" if passed else "fail") + elif mode == "profile": + run_profiling(logger, tests) + else: + # TODO: Implement script mode + return 2 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/problems/amd_202602/mixed-mla/README.md b/problems/amd_202602/mixed-mla/README.md new file mode 100644 index 00000000..24cae9ce --- /dev/null +++ b/problems/amd_202602/mixed-mla/README.md @@ -0,0 +1,200 @@ +# MLA (Multi-head Latent Attention) Decode Kernel + +## Description + +Implement a custom MLA attention decode kernel optimized for MI355X. + +This is the **inner attention kernel** from DeepSeek R1's `forward_absorb` MLA path. +The absorbed query and compressed KV cache are provided directly — you implement the +attention computation with variable-length batching. + +The reference uses **aiter MLA a8w8 decode kernel** (`mla_decode_fwd`, fp8 Q + fp8 KV, +persistent mode). On MI355X, a8w8 is ~2-3x faster than bf16 with negligible accuracy loss. +The reference quantizes Q to fp8 on-the-fly and uses pre-quantized fp8 KV from `kv_data["fp8"]`. 
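+
+For orientation, a minimal sketch of the dynamic per-tensor fp8 quantization the reference
+applies to Q (it mirrors `quantize_fp8` in `reference.py`; `aiter.dtypes.fp8` resolves to
+the platform's fp8 dtype):
+
+```python
+import torch
+from aiter import dtypes as aiter_dtypes
+
+def quantize_fp8(t: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+    # per-tensor dynamic scale: amax / fp8_max, then clamp-and-cast
+    finfo = torch.finfo(aiter_dtypes.fp8)
+    scale = t.abs().amax().clamp(min=1e-12) / finfo.max
+    q = (t / scale).clamp(finfo.min, finfo.max).to(aiter_dtypes.fp8)
+    return q, scale.to(torch.float32).reshape(1)  # dequantize: q.to(bf16) * scale
+```
+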
+ +## DeepSeek R1 Forward-Absorb MLA Config + +| Parameter | Value | Notes | +|---|---|---| +| num_heads | 16 | Query heads (after TP split) | +| num_kv_heads | 1 | Single shared latent KV head | +| kv_lora_rank | 512 | Latent dimension | +| qk_rope_head_dim | 64 | RoPE embedding dimension | +| qk_head_dim | 576 | kv_lora_rank + qk_rope_head_dim (absorbed q/k dim) | +| v_head_dim | 512 | = kv_lora_rank (output dim) | +| sm_scale | 1/sqrt(576) | | +| q dtype | bfloat16 | Input always bf16; reference quantizes to fp8 on-the-fly | +| kv dtype | bf16 / fp8 / mxfp4 | All three provided simultaneously | +| mode | decode | q_seq_len=1, kv_seq_len up to 8k | + +## Reference Kernel + +The reference (`ref_kernel`) is configurable via two globals in `reference.py`: + +| `Q_DTYPE` | `KV_DTYPE` | Aiter kernel dispatched | Description | +|---|---|---|---| +| `"fp8"` (default) | `"fp8"` (default) | `mla_a8w8_qh16_qseqlen1_gqaratio16_ps` | fp8 Q + fp8 KV — fastest | +| `"bf16"` | `"fp8"` | `mla_a16w8_qh16_m16x4_n16x1_coex0_mask1_ps` | bf16 Q + fp8 KV | +| `"bf16"` | `"bf16"` | `mla_a16w16_qh16_m16x4_n16x1_coex0_mask1_ps` | bf16 Q + bf16 KV — highest precision | + +**Note**: `Q_DTYPE="fp8"` + `KV_DTYPE="bf16"` is not a valid combination (no a8w16 kernel exists). + +### Reference Latency (MI355X) + +| Case | a8w8 (us) | a16w16 (us) | a8w8 speedup | +|---|---|---|---| +| bs=4, kv=1k | ~118 | ~162 | 1.4x | +| bs=4, kv=8k | ~113 | ~177 | 1.6x | +| bs=64, kv=8k | ~171 | ~353 | 2.1x | +| bs=256, kv=8k | ~349 | ~814 | 2.3x | + +## KV Buffer Format (forward_absorb) + +The compressed KV buffer has `qk_head_dim=576` dimensions: +- **Full 576 dims** are used as **keys** (for Q@K^T score computation) +- **First 512 dims** (kv_lora_rank) are used as **values** (for output computation) + +## KV Cache Quantization + +| dtype | kv_buffer | kv_scale | Quantization | Bandwidth | +|---|---|---|---|---| +| bf16 | bfloat16 `(total_kv, 1, 576)` | None | No quantization | 1x | +| fp8 | fp8 `(total_kv, 1, 576)` | scalar float32 | Dynamic per-tensor (sglang `scaled_fp8_quant`) | 2x savings | +| mxfp4 | fp4x2 `(total_kv, 1, 288)` | fp8_e8m0 `(total_kv, N_blocks)` | Block-32 MXFP4 (aiter `dynamic_mxfp4_quant`) | 4x savings | + +### FP8 quantization (sglang `scaled_fp8_quant`) + +- **Granularity**: per-tensor +- **Scale**: `kv_scale = max(abs(kv_bf16)) / fp8_max` +- **Quantize**: `kv_fp8 = (kv_bf16 / kv_scale).clamp(...).to(fp8)` +- **Dequantize**: `kv_bf16 ≈ kv_fp8.to(bf16) * kv_scale` +- **kv_scale**: scalar float32 tensor + +### MXFP4 quantization (aiter `dynamic_mxfp4_quant`) + +- **Granularity**: per-block of 32 elements +- **FP4 format**: E2M1 — values `[0, 0.5, 1, 1.5, 2, 3, 4, 6]`, max = 6.0 +- **Scale format**: E8M0 — exponent-only scale stored in `aiter.dtypes.fp8_e8m0` +- **Packing**: 2 FP4 values packed per byte (low nibble = even index, high nibble = odd index) +- **kv_buffer**: `(total_kv, 1, 288)` in `aiter.dtypes.fp4x2` — packed FP4 data +- **kv_scale**: `(total_kv, N_blocks)` in `aiter.dtypes.fp8_e8m0` — per-block E8M0 scale factors +- **Dequantize**: `aiter.utility.fp4_utils.mxfp4_to_f32` + `e8m0_to_f32` for block-wise scaling + +### aiter dtype reference + +| Logical type | aiter dtype | PyTorch native (if available) | Fallback | +|---|---|---|---| +| fp4x2 | `aiter.dtypes.fp4x2` | `torch.float4_e2m1fn_x2` | `torch.uint8` | +| fp8_e8m0 | `aiter.dtypes.fp8_e8m0` | `torch.float8_e8m0fnu` | `torch.uint8` | +| fp8 | `aiter.dtypes.fp8` | `torch.float8_e4m3fnuz` (gfx942) / `torch.float8_e4m3fn` (gfx950+) | `torch.uint8` 
|
+
+## Input
+
+A tuple `(q, kv_data, qo_indptr, kv_indptr, config)`:
+
+```
+q:         (total_q, 16, 576)  bfloat16 — absorbed queries
+kv_data:   dict with three KV cache formats (see below)
+qo_indptr: (batch_size + 1,)   int32 — query segment pointers
+kv_indptr: (batch_size + 1,)   int32 — KV segment pointers
+config:    dict — MLA parameters
+```
+
+### kv_data dict
+
+All three KV cache formats are provided simultaneously. Each entry is either a
+`Tensor` (bf16) or a `(Tensor, Tensor)` tuple (quantized buffer + scale):
+
+```python
+kv_data = {
+    "bf16": kv_buffer_bf16,                      # Tensor (total_kv, 1, 576) bfloat16
+    "fp8": (kv_buffer_fp8, kv_scale_fp8),        # (fp8 Tensor, scalar float32)
+    "mxfp4": (kv_buffer_mxfp4, kv_scale_mxfp4),  # (fp4x2 Tensor, fp8_e8m0 Tensor)
+}
+```
+
+### config dict
+
+```python
+config = {
+    "batch_size": int,
+    "num_heads": 16,
+    "num_kv_heads": 1,
+    "qk_head_dim": 576,
+    "kv_lora_rank": 512,
+    "qk_rope_head_dim": 64,
+    "v_head_dim": 512,
+    "q_seq_len": 1,
+    "kv_seq_len": int,       # varies per test case (1024 or 8192)
+    "sm_scale": 0.04166...,  # 1/sqrt(576)
+}
+```
+
+## Output
+
+```
+attention_output: (total_q, 16, 512) bfloat16
+```
+
+## Optimization Opportunities
+
+The reference is already a highly optimized aiter a8w8 persistent kernel. To beat it, consider:
+
+1. **MXFP4 KV cache**: 4x bandwidth savings over bf16, 2x over fp8. Two strategies:
+
+   **Strategy A — Fuse dequantization with attention (keep Q in bf16/fp8):**
+   Load quantized KV tiles from HBM, dequantize in registers/LDS to bf16, and
+   immediately compute QK^T and softmax·V — never writing the decompressed KV back
+   to HBM. This eliminates the extra read/write of the bf16 intermediate buffer,
+   roughly quartering the memory traffic for mxfp4 compared to the naive
+   dequant-then-attend approach.
+
+   **Strategy B — Quantize Q to match KV precision (full low-precision compute):**
+   Dynamically quantize Q from bf16 → mxfp4 (per-block scaling), then compute QK^T
+   entirely in fp4×fp4 using MFMA instructions on MI355X. The softmax is still done
+   in fp32 for numerical stability, and V accumulation uses fp4×fp4 → fp32. This
+   trades a small amount of accuracy for significantly higher throughput on the
+   matrix units.
+
+2. **Custom split-K / split-batch scheduling**: the aiter kernel uses 32-way KV splits
+   with reduce; a different split strategy or tile size may be more efficient for certain
+   batch/seq_len combinations.
+
+3. **MQA pattern**: 1 KV head shared across 16 query heads — minimize redundant KV loads
+   by loading KV once and broadcasting across all query heads in shared memory/LDS.
+
+4. **Variable-length batching**: indptr-based segmented attention across batch elements.
+
+5. **Split K/V from buffer**: full 576 dims for keys, first 512 for values — potential
+   for separate tiling strategies for the score and output stages.
+
+## Accuracy
+
+Submissions are checked against the a8w8 reference with `rtol=1e-02, atol=1e-02`
+(matching `check_implementation` in `reference.py`).
+
+Measured accuracy of different approaches vs bf16 torch ground truth:
+
+| Approach | max abs diff | Notes |
+|---|---|---|
+| aiter a8w8 (reference) | 2.6e-05 — 8.0e-05 | fp8 quantization + kernel accumulation |
+| torch fp8 (scaled_mm) | 2e-06 — 1.5e-05 | Closest to bf16 |
+| torch mxfp4 | 2.1e-04 — 8.3e-04 | 4-bit quantization noise |
+
+All approaches are well within the tolerance.
+
+## Benchmark Cases
+
+All three KV formats (bf16, fp8, mxfp4) are provided in every test case.
+ +| batch_size | q_seq_len | kv_seq_len | +|---|---|---| +| 4 | 1 | 1024 | +| 4 | 1 | 8192 | +| 32 | 1 | 1024 | +| 32 | 1 | 8192 | +| 64 | 1 | 1024 | +| 64 | 1 | 8192 | +| 256 | 1 | 1024 | +| 256 | 1 | 8192 | + +Ranking is by **geometric mean** of benchmark latencies. diff --git a/problems/amd_202602/mixed-mla/reference.py b/problems/amd_202602/mixed-mla/reference.py new file mode 100644 index 00000000..b0dee591 --- /dev/null +++ b/problems/amd_202602/mixed-mla/reference.py @@ -0,0 +1,372 @@ +""" +Reference implementation for MLA (Multi-head Latent Attention) decode kernel. + +Uses aiter MLA kernels (mla_decode_fwd) as the reference. +DeepSeek R1 forward_absorb MLA: absorbed q (576), compressed kv_buffer (576), +output v_head_dim = kv_lora_rank = 512. + +The input provides: + q: (total_q, 16, 576) bfloat16 — absorbed query + kv_data: dict with KV cache in three formats: + "bf16": Tensor (total_kv, 1, 576) bfloat16 — highest precision + "fp8": (Tensor, Tensor) kv_buffer fp8 + scalar scale — per-tensor quantized + "mxfp4": (Tensor, Tensor) kv_buffer fp4x2 + fp8_e8m0 — block-32 quantized + The reference quantizes Q to fp8 on-the-fly inside ref_kernel. + +The reference kernel quantizes Q to fp8 on-the-fly and uses fp8 KV (a8w8 kernel), +which is ~2-3x faster than bf16 on MI355X with negligible accuracy loss. + +Decode only — persistent mode with get_mla_metadata_v1. +""" + +import torch +import torch.nn.functional as F +from task import input_t, output_t +from utils import make_match_reference + +from aiter.mla import mla_decode_fwd +from aiter import dtypes as aiter_dtypes +from aiter import get_mla_metadata_info_v1, get_mla_metadata_v1 +from aiter.utility.fp4_utils import ( + dynamic_mxfp4_quant, + mxfp4_to_f32, + e8m0_to_f32, +) + +# --------------------------------------------------------------------------- +# DeepSeek R1 latent MQA constants (forward_absorb path) +# https://huggingface.co/deepseek-ai/DeepSeek-R1-0528/blob/main/config.json +# --------------------------------------------------------------------------- +NUM_HEADS = 16 +NUM_KV_HEADS = 1 +KV_LORA_RANK = 512 +QK_ROPE_HEAD_DIM = 64 +QK_HEAD_DIM = KV_LORA_RANK + QK_ROPE_HEAD_DIM # 576 +V_HEAD_DIM = KV_LORA_RANK # 512 +SM_SCALE = 1.0 / (QK_HEAD_DIM ** 0.5) + +PAGE_SIZE = 1 +NUM_KV_SPLITS = 32 + +# FP8 dtype (platform-specific via aiter) +FP8_DTYPE = aiter_dtypes.fp8 + +# Query dtype for the reference kernel: "fp8" or "bf16" +Q_DTYPE = "fp8" + +# KV cache dtype for the reference kernel: "fp8" or "bf16" +KV_DTYPE = "fp8" + + +# --------------------------------------------------------------------------- +# FP8 quantization (sglang style: dynamic per-tensor) +# --------------------------------------------------------------------------- + +def quantize_fp8(tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """ + Dynamic per-tensor FP8 quantization (following sglang scaled_fp8_quant). + + Args: + tensor: bf16 tensor to quantize + + Returns: + (fp8_tensor, scale) where scale is a scalar float32 tensor. 
+ Dequantize: fp8_tensor.to(bf16) * scale + """ + finfo = torch.finfo(FP8_DTYPE) + amax = tensor.abs().amax().clamp(min=1e-12) + scale = amax / finfo.max + fp8_tensor = (tensor / scale).clamp(min=finfo.min, max=finfo.max).to(FP8_DTYPE) + return fp8_tensor, scale.to(torch.float32).reshape(1) + + +# --------------------------------------------------------------------------- +# MXFP4 quantization (aiter native: block-32, fp4x2 + fp8_e8m0 dtypes) +# Uses aiter.utility.fp4_utils.dynamic_mxfp4_quant +# --------------------------------------------------------------------------- + +def quantize_mxfp4(tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """ + MXFP4 block-wise quantization using aiter's dynamic_mxfp4_quant. + + Block size = 32. Each block gets an E8M0 scale factor. + Two FP4 E2M1 values are packed per byte. + + Args: + tensor: bf16 tensor of shape [B, M, N] (N must be divisible by 32) + + Returns: + (fp4_data, scale_e8m0) + - fp4_data: shape [B, M, N//2] in aiter_dtypes.fp4x2 + - scale_e8m0: shape [B*M, ceil(N/32)] padded, in aiter_dtypes.fp8_e8m0 + """ + orig_shape = tensor.shape # (B, M, N) + B, M, N = orig_shape + + # dynamic_mxfp4_quant expects 2D: (B*M, N) + tensor_2d = tensor.reshape(B * M, N) + fp4_data_2d, scale_e8m0 = dynamic_mxfp4_quant(tensor_2d) + + # Reshape fp4_data back to 3D: (B, M, N//2) + fp4_data = fp4_data_2d.view(B, M, N // 2) + + return fp4_data, scale_e8m0 + + +def dequantize_mxfp4( + fp4_data: torch.Tensor, + scale_e8m0: torch.Tensor, + orig_shape: tuple, + dtype: torch.dtype = torch.bfloat16, +) -> torch.Tensor: + """ + Dequantize MXFP4 tensor using aiter utilities. + + Note: dynamic_mxfp4_quant may pad both row and block dimensions in scale_e8m0. + We trim scales to match the actual data dimensions. + + Args: + fp4_data: packed FP4 data, shape [B, M, N//2] in fp4x2 or uint8 + scale_e8m0: E8M0 block scale factors (possibly padded) in fp8_e8m0 + orig_shape: original (B, M, N) for reshaping + dtype: output dtype + + Returns: + Dequantized tensor of shape orig_shape. + """ + B, M, N = orig_shape + num_rows = B * M + block_size = 32 + num_blocks = N // block_size # actual blocks needed (e.g. 
576/32 = 18) + + # Unpack FP4 to float32: mxfp4_to_f32 expects (..., N//2) -> (..., N) + fp4_data_2d = fp4_data.reshape(num_rows, N // 2) + float_vals = mxfp4_to_f32(fp4_data_2d) # (num_rows, N) + + # Convert E8M0 scales to float32 and trim padded dimensions + scale_f32 = e8m0_to_f32(scale_e8m0) # (padded_rows, padded_blocks) + scale_f32 = scale_f32[:num_rows, :num_blocks] # (num_rows, num_blocks) + + # Apply block scales + float_vals_blocked = float_vals.view(num_rows, num_blocks, block_size) + scaled = float_vals_blocked * scale_f32.unsqueeze(-1) + + return scaled.view(B, M, N).to(dtype) + + +# --------------------------------------------------------------------------- +# Persistent mode metadata helpers +# --------------------------------------------------------------------------- + +def _make_mla_decode_metadata( + batch_size: int, + max_q_len: int, + nhead: int, + nhead_kv: int, + q_dtype: torch.dtype, + kv_dtype: torch.dtype, + qo_indptr: torch.Tensor, + kv_indptr: torch.Tensor, + kv_last_page_len: torch.Tensor, + num_kv_splits: int = NUM_KV_SPLITS, +): + """Allocate and populate work buffers for persistent mla_decode_fwd.""" + info = get_mla_metadata_info_v1( + batch_size, max_q_len, nhead, q_dtype, kv_dtype, + is_sparse=False, fast_mode=False, + num_kv_splits=num_kv_splits, intra_batch_mode=True, + ) + work = [torch.empty(s, dtype=t, device="cuda") for s, t in info] + (work_metadata, work_indptr, work_info_set, + reduce_indptr, reduce_final_map, reduce_partial_map) = work + + # Populate the metadata buffers + get_mla_metadata_v1( + qo_indptr, kv_indptr, kv_last_page_len, + nhead // nhead_kv, # num_heads_per_head_k + nhead_kv, # num_heads_k + True, # is_causal + work_metadata, work_info_set, work_indptr, + reduce_indptr, reduce_final_map, reduce_partial_map, + page_size=PAGE_SIZE, + kv_granularity=max(PAGE_SIZE, 16), + max_seqlen_qo=max_q_len, + uni_seqlen_qo=max_q_len, + fast_mode=False, + max_split_per_batch=num_kv_splits, + intra_batch_mode=True, + dtype_q=q_dtype, + dtype_kv=kv_dtype, + ) + + return { + "work_meta_data": work_metadata, + "work_indptr": work_indptr, + "work_info_set": work_info_set, + "reduce_indptr": reduce_indptr, + "reduce_final_map": reduce_final_map, + "reduce_partial_map": reduce_partial_map, + } + + +# --------------------------------------------------------------------------- +# Aiter reference kernel (decode only) +# --------------------------------------------------------------------------- + +def _aiter_mla_decode( + q: torch.Tensor, + kv_buffer: torch.Tensor, + qo_indptr: torch.Tensor, + kv_indptr: torch.Tensor, + config: dict, + q_scale: torch.Tensor | None = None, + kv_scale: torch.Tensor | None = None, +) -> torch.Tensor: + """ + MLA decode attention using aiter persistent-mode kernel. 
+ + Supports multiple Q/KV dtype combinations: + - Q_DTYPE="fp8": fp8 Q + fp8 KV (a8w8) — fastest on MI355X + - Q_DTYPE="bf16": bf16 Q + bf16 KV (a16w16) — highest precision + + q: (total_q, num_heads, 576) fp8 or bf16 + kv_buffer: (total_kv, 1, 576) fp8 or bf16 + q_scale: scalar float32 (required for fp8 Q, None for bf16) + kv_scale: scalar float32 (required for fp8 KV, None for bf16) + """ + batch_size = config["batch_size"] + nq = config["num_heads"] + nkv = config["num_kv_heads"] + dq = config["qk_head_dim"] + dv = config["v_head_dim"] + q_seq_len = config["q_seq_len"] + + total_kv_len = int(kv_indptr[-1].item()) + kv_indices = torch.arange(total_kv_len, dtype=torch.int32, device="cuda") + + # Reshape kv_buffer to 4D for aiter: (total_kv, page_size, nhead_kv, dim) + kv_buffer_4d = kv_buffer.view(kv_buffer.shape[0], PAGE_SIZE, nkv, kv_buffer.shape[-1]) + + max_q_len = q_seq_len + kv_last_page_len = (kv_indptr[1:] - kv_indptr[:-1]).to(torch.int32) + + # Build persistent-mode metadata + meta = _make_mla_decode_metadata( + batch_size, max_q_len, nq, nkv, + q.dtype, kv_buffer.dtype, + qo_indptr, kv_indptr, kv_last_page_len, + num_kv_splits=NUM_KV_SPLITS, + ) + + o = torch.empty((q.shape[0], nq, dv), dtype=torch.bfloat16, device="cuda") + mla_decode_fwd( + q.view(-1, nq, dq), + kv_buffer_4d, + o, + qo_indptr, + kv_indptr, + kv_indices, + kv_last_page_len, + max_q_len, + page_size=PAGE_SIZE, + nhead_kv=nkv, + sm_scale=SM_SCALE, + logit_cap=0.0, + num_kv_splits=NUM_KV_SPLITS, + q_scale=q_scale, + kv_scale=kv_scale, + intra_batch_mode=True, + **meta, + ) + return o + + +# --------------------------------------------------------------------------- +# generate_input / ref_kernel / check_implementation +# --------------------------------------------------------------------------- + +def generate_input(batchsize: int, qseqlen: int, kvseqlen: int, seed: int) -> input_t: + """ + Generate absorbed q and compressed kv_buffer for MLA decode. 
+ + Returns all three KV cache formats in kv_data dict: + kv_data = { + "bf16": Tensor — (total_kv, 1, 576) bfloat16 + "fp8": (Tensor, Tensor) — kv_buffer fp8 + scalar scale + "mxfp4": (Tensor, Tensor) — kv_buffer fp4x2 + fp8_e8m0 scale + } + """ + gen = torch.Generator(device="cuda") + gen.manual_seed(seed) + + total_q = batchsize * qseqlen + total_kv = batchsize * kvseqlen + + # Absorbed query: (total_q, num_heads, 576) bf16 + q = torch.randn( + (total_q, NUM_HEADS, QK_HEAD_DIM), + dtype=torch.bfloat16, device="cuda", generator=gen, + ) * 0.02 + + # Compressed KV buffer: (total_kv, 1, 576) bf16 — the source of truth + kv_buffer_bf16 = torch.randn( + (total_kv, NUM_KV_HEADS, QK_HEAD_DIM), + dtype=torch.bfloat16, device="cuda", generator=gen, + ) * 0.02 + + # Quantize KV to fp8 + kv_buffer_fp8, kv_scale_fp8 = quantize_fp8(kv_buffer_bf16) + + # Quantize KV to mxfp4 + kv_buffer_mxfp4, kv_scale_mxfp4 = quantize_mxfp4(kv_buffer_bf16) + + # All three KV formats: bf16 is a Tensor, fp8/mxfp4 are (Tensor, Tensor) tuples + kv_data = { + "bf16": kv_buffer_bf16, + "fp8": (kv_buffer_fp8, kv_scale_fp8), + "mxfp4": (kv_buffer_mxfp4, kv_scale_mxfp4), + } + + qo_indptr = torch.arange(0, batchsize + 1, dtype=torch.int32, device="cuda") * qseqlen + kv_indptr = torch.arange(0, batchsize + 1, dtype=torch.int32, device="cuda") * kvseqlen + + config = { + "batch_size": batchsize, + "num_heads": NUM_HEADS, + "num_kv_heads": NUM_KV_HEADS, + "qk_head_dim": QK_HEAD_DIM, + "kv_lora_rank": KV_LORA_RANK, + "qk_rope_head_dim": QK_ROPE_HEAD_DIM, + "v_head_dim": V_HEAD_DIM, + "q_seq_len": qseqlen, + "kv_seq_len": kvseqlen, + "sm_scale": SM_SCALE, + } + + return (q, kv_data, qo_indptr, kv_indptr, config) + + +def ref_kernel(data: input_t) -> output_t: + """Reference MLA decode attention. Uses Q_DTYPE and KV_DTYPE to select kernel variant.""" + q, kv_data, qo_indptr, kv_indptr, config = data + + # Resolve Q + if Q_DTYPE == "fp8": + q_input, q_scale = quantize_fp8(q) + else: + q_input, q_scale = q, None + + # Resolve KV + if KV_DTYPE == "fp8": + kv_buffer_fp8, kv_scale = kv_data["fp8"] + kv_input = kv_buffer_fp8 + else: + kv_input, kv_scale = kv_data["bf16"], None + + return _aiter_mla_decode( + q_input, kv_input, qo_indptr, kv_indptr, config, + q_scale=q_scale, kv_scale=kv_scale, + ) + + +check_implementation = make_match_reference(ref_kernel, rtol=1e-02, atol=1e-02) diff --git a/problems/amd_202602/mixed-mla/submission.py b/problems/amd_202602/mixed-mla/submission.py new file mode 100644 index 00000000..b6ff1309 --- /dev/null +++ b/problems/amd_202602/mixed-mla/submission.py @@ -0,0 +1,185 @@ +""" +MLA (Multi-head Latent Attention) decode kernel — submission template. + +Implement custom_kernel() to beat the aiter a8w8 reference (fp8 Q + fp8 KV). 
+ +DeepSeek R1 forward_absorb MLA config: + num_heads = 16 (query heads, after TP split) + num_kv_heads = 1 (shared latent KV head) + kv_lora_rank = 512 (latent dim) + qk_rope_head_dim = 64 (RoPE dim) + qk_head_dim = 576 (kv_lora_rank + qk_rope_head_dim, absorbed q/k dim) + v_head_dim = 512 (= kv_lora_rank, output dim) + sm_scale = 1/sqrt(576) + +KV buffer format (forward_absorb): + - Full 576 dims used as keys (for Q@K^T score computation) + - First 512 dims (kv_lora_rank) used as values (for output computation) + +Input tuple: + q: (total_q, 16, 576) bfloat16 — absorbed query + kv_data: dict with three KV cache formats: + kv_data["bf16"] — Tensor (total_kv, 1, 576) bfloat16 + kv_data["fp8"] — (Tensor, Tensor): kv_buffer fp8 (total_kv,1,576) + scalar scale + kv_data["mxfp4"] — (Tensor, Tensor): kv_buffer fp4x2 (total_kv,1,288) + fp8_e8m0 scale + qo_indptr: (batch_size + 1,) int32 — query segment pointers + kv_indptr: (batch_size + 1,) int32 — KV segment pointers + config: dict with MLA parameters + +Output: + attention output: (total_q, 16, 512) bfloat16 + +The reference uses aiter's a8w8 persistent MLA kernel (fp8 Q + fp8 KV), +which is ~2-3x faster than bf16. To beat it, consider: + 1. Use mxfp4 KV cache for even lower memory bandwidth + - Fuse dequantization with attention to avoid bf16 materialization + 2. Custom kernel with tighter memory access patterns + 3. MQA: 1 KV head shared across 16 query heads — minimize redundant memory loads + 4. Variable-length batching: indptr-based segmented attention + 5. Split K/V from buffer: full 576 dims for keys, first 512 dims for values +""" + +import torch +import torch.nn.functional as F +from task import input_t, output_t + +from aiter import dtypes as aiter_dtypes +FP8_DTYPE = aiter_dtypes.fp8 + +# QKV dtype for custom_kernel dispatch: "bf16", "fp8", or "mxfp4" +QKV_DTYPE = "fp8" + + +# --------------------------------------------------------------------------- +# Dispatcher: select kernel based on QKV_DTYPE +# --------------------------------------------------------------------------- + +def custom_kernel(data: input_t) -> output_t: + """Dispatch to the appropriate kernel based on QKV_DTYPE.""" + if QKV_DTYPE == "fp8": + return custom_kernel_fp8(data) + elif QKV_DTYPE == "bf16": + return custom_kernel_bf16(data) + else: + raise ValueError(f"Invalid QKV_DTYPE: {QKV_DTYPE}") + +# --------------------------------------------------------------------------- +# FP8 quantization helper (per-tensor, sglang style) +# --------------------------------------------------------------------------- + +def quantize_fp8(tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Dynamic per-tensor FP8 quantization. Returns (fp8_tensor, scale).""" + finfo = torch.finfo(FP8_DTYPE) + amax = tensor.abs().amax().clamp(min=1e-12) + scale = amax / finfo.max + fp8_tensor = (tensor / scale).clamp(min=finfo.min, max=finfo.max).to(FP8_DTYPE) + return fp8_tensor, scale.to(torch.float32).reshape(1) + + +# --------------------------------------------------------------------------- +# Baseline: bf16 Q + bf16 KV — naive torch attention +# --------------------------------------------------------------------------- + +def custom_kernel_bf16(data: input_t) -> output_t: + q, kv_data, qo_indptr, kv_indptr, config = data + + num_heads = config["num_heads"] + kv_lora_rank = config["kv_lora_rank"] + sm_scale = config["sm_scale"] + + # This naive baseline uses bf16 KV directly. 
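+    # It loops over the batch in Python and materializes full fp32 score
+    # matrices, so expect it to be far slower than the aiter reference.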
+ # For better performance, use kv_data["fp8"] or kv_data["mxfp4"] + # which are (kv_buffer, kv_scale) tuples. See docstring for optimization opportunities. + kv_buffer_bf16 = kv_data["bf16"] + + batch_size = qo_indptr.shape[0] - 1 + out_list = [] + + for i in range(batch_size): + q_s, q_e = int(qo_indptr[i].item()), int(qo_indptr[i + 1].item()) + kv_s, kv_e = int(kv_indptr[i].item()), int(kv_indptr[i + 1].item()) + + qi = q[q_s:q_e] # (seq_q, nhead, 576) + kvc = kv_buffer_bf16[kv_s:kv_e, 0] # (seq_kv, 576) squeeze kv_heads dim + + # Key: full 576 dims; Value: first 512 dims (kv_lora_rank) + ki = kvc # (seq_kv, 576) — broadcast over heads + vi = kvc[:, :kv_lora_rank] # (seq_kv, 512) + + # Attention: (nhead, seq_q, 576) @ (576, seq_kv) → (nhead, seq_q, seq_kv) + qi_t = qi.float().permute(1, 0, 2) # (nhead, seq_q, 576) + scores = torch.matmul(qi_t * sm_scale, ki.float().T) # (nhead, seq_q, seq_kv) + + scores = F.softmax(scores, dim=-1) + + # Output: (nhead, seq_q, seq_kv) @ (seq_kv, 512) → (nhead, seq_q, 512) + oi = torch.matmul(scores, vi.float()) # (nhead, seq_q, 512) + oi = oi.permute(1, 0, 2) # (seq_q, nhead, 512) + out_list.append(oi.to(torch.bfloat16)) + + return torch.cat(out_list, dim=0) + + + + +# --------------------------------------------------------------------------- +# FP8 Q + FP8 KV — torch._scaled_mm based attention +# +# Quantize Q to fp8, use fp8 KV from kv_data["fp8"]. +# QK^T and softmax@V both use torch._scaled_mm for fp8×fp8 matmul. +# --------------------------------------------------------------------------- + +def custom_kernel_fp8(data: input_t) -> output_t: + q, kv_data, qo_indptr, kv_indptr, config = data + + num_heads = config["num_heads"] + kv_lora_rank = config["kv_lora_rank"] + qk_head_dim = config["qk_head_dim"] + sm_scale = config["sm_scale"] + + # FP8 KV buffer and scale + kv_buffer_fp8, kv_scale_fp8 = kv_data["fp8"] + kv_fp8_2d = kv_buffer_fp8.view(-1, qk_head_dim) # (total_kv, 576) fp8 + + # Quantize Q to fp8 + q_fp8, q_scale = quantize_fp8(q) # q_fp8: (total_q, 16, 576) fp8 + + batch_size = qo_indptr.shape[0] - 1 + out_list = [] + + scale_one = torch.ones(1, dtype=torch.float32, device="cuda") + + for i in range(batch_size): + q_s, q_e = int(qo_indptr[i].item()), int(qo_indptr[i + 1].item()) + kv_s, kv_e = int(kv_indptr[i].item()), int(kv_indptr[i + 1].item()) + seq_q = q_e - q_s + seq_kv = kv_e - kv_s + + # Q: (seq_q * nhead, 576) fp8, K: (seq_kv, 576) fp8 + qi_fp8 = q_fp8[q_s:q_e].reshape(seq_q * num_heads, qk_head_dim) # (seq_q*16, 576) + ki_fp8 = kv_fp8_2d[kv_s:kv_e] # (seq_kv, 576) + + # QK^T via _scaled_mm: (seq_q*16, 576) @ (seq_kv, 576).T -> (seq_q*16, seq_kv) + # _scaled_mm expects (M,K) @ (N,K).T where b is row-major contiguous + raw_scores = torch._scaled_mm( + qi_fp8, ki_fp8.t(), + scale_a=q_scale, scale_b=kv_scale_fp8, + out_dtype=torch.float32, + ) + # raw_scores: (seq_q*16, seq_kv) + scores = raw_scores.view(seq_q, num_heads, seq_kv).permute(1, 0, 2) # (nhead, seq_q, seq_kv) + scores = scores * sm_scale + scores = F.softmax(scores, dim=-1) + + # V: first 512 dims of KV buffer (bf16 for softmax@V since scores are float) + kv_bf16 = kv_data["bf16"] + vi = kv_bf16[kv_s:kv_e, 0, :kv_lora_rank].float() # (seq_kv, 512) + + # softmax @ V: (nhead, seq_q, seq_kv) @ (seq_kv, 512) -> (nhead, seq_q, 512) + oi = torch.matmul(scores, vi) + oi = oi.permute(1, 0, 2) # (seq_q, nhead, 512) + out_list.append(oi.to(torch.bfloat16)) + + return torch.cat(out_list, dim=0) + + diff --git a/problems/amd_202602/mixed-mla/task.py 
b/problems/amd_202602/mixed-mla/task.py new file mode 100644 index 00000000..7aff7b6a --- /dev/null +++ b/problems/amd_202602/mixed-mla/task.py @@ -0,0 +1,36 @@ +import torch +from typing import TypeVar, TypedDict, Union + +# DeepSeek R1 MLA forward_absorb format: +# +# Input: (q, kv_data, qo_indptr, kv_indptr, config) +# q: (total_q, num_heads, qk_head_dim) bfloat16 +# kv_data: dict with three KV cache formats: +# "bf16": Tensor (total_kv, 1, 576) bfloat16 +# "fp8": (Tensor, Tensor) kv_buffer fp8 (total_kv, 1, 576) + scalar scale +# "mxfp4": (Tensor, Tensor) kv_buffer fp4x2 (total_kv, 1, 288) + fp8_e8m0 scale +# qo_indptr: (batch_size + 1,) int32 +# kv_indptr: (batch_size + 1,) int32 +# config: dict with MLA parameters +# +# where qk_head_dim = kv_lora_rank + qk_rope_head_dim = 512 + 64 = 576 +# +# Output: attention output tensor (total_q, num_heads, v_head_dim) bfloat16 +# where v_head_dim = kv_lora_rank = 512 +# +# The kv_buffer stores the compressed KV representation: +# - Full 576 dims used as keys (for Q@K^T score computation) +# - First 512 dims (kv_lora_rank) used as values (for output computation) + +input_t = TypeVar( + "input_t", + bound=tuple[torch.Tensor, dict, torch.Tensor, torch.Tensor, dict], +) +output_t = TypeVar("output_t", bound=torch.Tensor) + + +class TestSpec(TypedDict): + batchsize: int + qseqlen: int + kvseqlen: int + seed: int diff --git a/problems/amd_202602/mixed-mla/task.yml b/problems/amd_202602/mixed-mla/task.yml new file mode 100644 index 00000000..c0a5d5a6 --- /dev/null +++ b/problems/amd_202602/mixed-mla/task.yml @@ -0,0 +1,95 @@ +# name: mla-py + +files: + - {"name": "submission.py", "source": "@SUBMISSION@"} + - {"name": "task.py", "source": "task.py"} + - {"name": "utils.py", "source": "../utils.py"} + - {"name": "reference.py", "source": "reference.py"} + - {"name": "eval.py", "source": "../eval.py"} + +lang: "py" + +description: | + Implement a custom MLA (Multi-head Latent Attention) decode kernel optimized for MI355X. + + This is the inner attention kernel from DeepSeek R1's forward_absorb MLA path. + The absorbed query and compressed KV cache are provided directly — you only need to + implement the **attention** computation with variable-length batching (indptr). + + The reference uses aiter a8w8 MLA decode kernel (mla_decode_fwd, fp8 Q + fp8 KV, + persistent mode), which is ~2-3x faster than bf16 on MI355X. 
+ + DeepSeek R1 forward_absorb MLA config: + - num_heads = 16 (query heads, after TP split) + - num_kv_heads = 1 (shared latent KV head) + - kv_lora_rank = 512 + - qk_rope_head_dim = 64 + - qk_head_dim = 576 (kv_lora_rank + qk_rope_head_dim, absorbed q/k dim) + - v_head_dim = 512 (= kv_lora_rank, output dim) + - sm_scale = 1/sqrt(576) + - dtype: q=bfloat16 + - decode only (q_seq_len=1, kv_seq_len up to 8k) + + KV buffer format (forward_absorb): + - Full 576 dims are used as keys (for Q@K^T score computation) + - First 512 dims (kv_lora_rank) are used as values (for output computation) + + Input tuple: (q, kv_data, qo_indptr, kv_indptr, config) + - q: (total_q, 16, 576) bfloat16 — absorbed query + - kv_data: dict with three KV cache formats: + kv_data["bf16"] — Tensor (total_kv, 1, 576) bfloat16 + kv_data["fp8"] — (Tensor, Tensor): kv_buffer fp8 + scalar scale + kv_data["mxfp4"] — (Tensor, Tensor): kv_buffer fp4x2 + fp8_e8m0 scale + - qo_indptr: (batch_size+1,) int32 — query segment pointers + - kv_indptr: (batch_size+1,) int32 — KV segment pointers + - config: dict with MLA parameters + + Return: + - attention output: (total_q, 16, 512) bfloat16 + + Key optimization opportunities: + 1. Use mxfp4 KV cache for even lower memory bandwidth (4x savings over bf16) + - Fuse dequantization with attention to skip bf16 materialization + 2. Custom kernel with tighter memory access patterns + 3. MQA: 1 KV head shared across 16 query heads — minimize redundant memory loads + 4. Decode: q_seq_len=1, kv_seq_len up to 8k — memory-bound workload + 5. Variable-length batching: indptr-based segmented attention + 6. Split K/V from buffer: full 576 dims for keys, first 512 dims for values + + The ranking criteria is the geometric mean of the benchmark results. + +config: + main: "eval.py" + +templates: + Python: "submission.py" + +test_timeout: 900 +benchmark_timeout: 900 +ranked_timeout: 1200 + +tests: + # bs=4 + - {"batchsize": 4, "qseqlen": 1, "kvseqlen": 1024, "seed": 4220} + # bs=32 + - {"batchsize": 32, "qseqlen": 1, "kvseqlen": 1024, "seed": 5412} + # bs=64 + - {"batchsize": 64, "qseqlen": 1, "kvseqlen": 8192, "seed": 1360} + # bs=256 + - {"batchsize": 256, "qseqlen": 1, "kvseqlen": 8192, "seed": 9826} + +benchmarks: + # bs=4 + - {"batchsize": 4, "qseqlen": 1, "kvseqlen": 1024, "seed": 4217} + - {"batchsize": 4, "qseqlen": 1, "kvseqlen": 8192, "seed": 4220} + # bs=32 + - {"batchsize": 32, "qseqlen": 1, "kvseqlen": 1024, "seed": 5412} + - {"batchsize": 32, "qseqlen": 1, "kvseqlen": 8192, "seed": 5415} + # bs=64 + - {"batchsize": 64, "qseqlen": 1, "kvseqlen": 1024, "seed": 1357} + - {"batchsize": 64, "qseqlen": 1, "kvseqlen": 8192, "seed": 1360} + # bs=256 + - {"batchsize": 256, "qseqlen": 1, "kvseqlen": 1024, "seed": 9823} + - {"batchsize": 256, "qseqlen": 1, "kvseqlen": 8192, "seed": 9826} + +ranking_by: "geom" diff --git a/problems/amd_202602/moe-mxfp4/README.md b/problems/amd_202602/moe-mxfp4/README.md new file mode 100644 index 00000000..ab664f7d --- /dev/null +++ b/problems/amd_202602/moe-mxfp4/README.md @@ -0,0 +1,198 @@ +# MXFP4 Mixture-of-Experts (MoE) Fused Kernel + +## Description + +Implement a DeepSeek-R1 style MXFP4 Mixture-of-Experts (MoE) fused kernel optimized for AMD Instinct MI355X GPU. + +The kernel fuses the complete MoE forward pass into a 2-stage pipeline: +1. **Stage 1**: MXFP4 GEMM (gate+up projection) + SwiGLU activation +2. 
**Stage 2**: MXFP4 GEMM (down projection) + weighted reduction across top-k experts + +The reference uses **AITER `fused_moe`** with `QuantType.per_1x32` (MXFP4 block scaling, block_size=32). + +## DeepSeek-R1 MoE Architecture + +| Parameter | Value | Notes | +|---|---|---| +| hidden_size | 7168 | Model hidden dimension | +| moe_intermediate_size | 2048 | Per-expert intermediate dimension | +| n_routed_experts | 256 | Routed experts (EP-off) or 32 (EP-on, 8-way split) | +| n_shared_experts | 1 | Always selected with weight=1.0 | +| top_k (routed) | 8 | Routed experts per token | +| total_top_k | 9 | 8 routed + 1 shared | +| MoE layers | 58 | Layers 3–60 | + +## Kernel Flow + +For each token `i` and each assigned expert `j`: + +``` +(1) Quant activations: hidden_states -> MXFP4 (aiter per-1x32 dynamic quantization) + +(2) Stage 1 GEMM + SwiGLU activation: + gate = x_i @ W_gate_j.T # [d_hidden] x [d_expert, d_hidden].T -> [d_expert] + up = x_i @ W_up_j.T # [d_hidden] x [d_expert, d_hidden].T -> [d_expert] + intermediate = SiLU(gate) * up # SwiGLU activation -> [d_expert] + (W_gate and W_up are fused as gate_up_weight: one a4w4 GEMM + fused activation) + +(3) Stage 2 GEMM: + expert_out = intermediate @ W_down_j.T # [d_expert] x [d_hidden, d_expert].T -> [d_hidden] + +(4) Weighted reduction: + output_i += w_ij * expert_out # accumulate across top_k experts +``` + +All weight GEMMs are **a4w4** (MXFP4 activations x MXFP4 weights, per-1x32 block scaling). +The AITER CK kernel fuses all of the above into a 2-stage pipeline across all tokens and experts. + +## Weight Layout & Pre-shuffling + +Weights are provided in two layouts: + +| Layout | Description | Use case | +|---|---|---| +| **Raw** | Original MXFP4 quantized weights | PyTorch reference / custom kernels | +| **Pre-shuffled** | `shuffle_weight(w, layout=(16,16))` + `e8m0_shuffle(scale)` | AITER CK kernel (tile-coalesced layout) | + +The (16,16) shuffle rearranges weight tiles for coalesced memory access by CK GEMM instructions. +Scale shuffling (`e8m0_shuffle`) reorders E8M0 block scales to match the shuffled weight layout. + +You may use either layout — raw weights if you implement your own tiling, or pre-shuffled weights +for direct use with AITER/CK kernels. 
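+
+As a minimal sketch (mirroring the calls in `reference.py` and the
+`shuffle_weight(w, layout=(16,16))` signature quoted above; `E_total` stands for the
+per-GPU expert count), the pre-shuffled layout is derived from the raw weights like this:
+
+```python
+from aiter.ops.shuffle import shuffle_weight
+from aiter.utility import fp4_utils
+
+# (16,16) tile shuffle of the packed fp4x2 weights for the CK GEMM
+gate_up_weight_shuffled = shuffle_weight(gate_up_weight, layout=(16, 16))
+down_weight_shuffled = shuffle_weight(down_weight, layout=(16, 16))
+
+# block scales are flattened per expert, then reordered to match the shuffled tiles
+gate_up_scale_shuffled = fp4_utils.e8m0_shuffle(gate_up_weight_scale.reshape(E_total, -1))
+down_scale_shuffled = fp4_utils.e8m0_shuffle(down_weight_scale.reshape(E_total, -1))
+```
+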
+ +## MXFP4 Quantization Details + +| Property | Value | +|---|---| +| FP4 format | E2M1 — values `[0, 0.5, 1, 1.5, 2, 3, 4, 6]`, max = 6.0 | +| Scale format | E8M0 — exponent-only (power-of-2 scale) | +| Block size | 32 elements per scale | +| Packing | 2 FP4 values per byte (`fp4x2`): low nibble = even index, high nibble = odd index | +| Padding | Dimensions padded to 256-alignment for CK kernel | + +### aiter dtype reference + +| Logical type | aiter dtype | PyTorch native (if available) | Fallback | +|---|---|---|---| +| fp4x2 | `aiter.dtypes.fp4x2` | `torch.float4_e2m1fn_x2` | `torch.uint8` | +| fp8_e8m0 | `aiter.dtypes.fp8_e8m0` | `torch.float8_e8m0fnu` | `torch.uint8` | + +## Input + +A tuple of tensors and a config dict: + +``` +(hidden_states, + gate_up_weight, down_weight, # fp4x2 raw + gate_up_weight_scale, down_weight_scale, # e8m0 raw + gate_up_weight_shuffled, down_weight_shuffled, # fp4x2 pre-shuffled + gate_up_weight_scale_shuffled, down_weight_scale_shuffled, # e8m0 pre-shuffled + topk_weights, topk_ids, + config) +``` + +### Tensor shapes + +| Tensor | Shape | Dtype | Notes | +|---|---|---|---| +| `hidden_states` | `[M, d_hidden]` | bfloat16 | Input activations (M = batch of tokens) | +| `gate_up_weight` | `[E, 2*d_expert_pad, d_hidden_pad//2]` | fp4x2 | Fused gate+up weights, raw | +| `down_weight` | `[E, d_hidden_pad, d_expert_pad//2]` | fp4x2 | Down projection weights, raw | +| `gate_up_weight_scale` | `[E, 2*d_expert_pad, d_hidden_pad//32]` | e8m0 | Block scales for gate_up, raw | +| `down_weight_scale` | `[E, d_hidden_pad, d_expert_pad//32]` | e8m0 | Block scales for down, raw | +| `gate_up_weight_shuffled` | `[E, 2*d_expert_pad, d_hidden_pad//2]` | fp4x2 | Pre-shuffled for CK | +| `down_weight_shuffled` | `[E, d_hidden_pad, d_expert_pad//2]` | fp4x2 | Pre-shuffled for CK | +| `gate_up_weight_scale_shuffled` | `[padded, flat]` | e8m0 | Pre-shuffled for CK | +| `down_weight_scale_shuffled` | `[padded, flat]` | e8m0 | Pre-shuffled for CK | +| `topk_weights` | `[M, total_top_k]` | float32 | Routing weights | +| `topk_ids` | `[M, total_top_k]` | int32 | Expert indices (see below) | + +### topk_ids format + +- First `n_experts_per_token` columns: routed expert IDs `[0, n_routed_experts)` +- Last `n_shared_experts` columns: shared expert IDs `[n_routed_experts, n_routed_experts + n_shared_experts)` +- Shared experts are always selected with weight = 1.0 + +### config dict + +```python +config = { + "d_hidden": int, # hidden dimension (e.g. 7168) + "d_expert": int, # expert intermediate dimension (e.g. 2048 or 256) + "d_hidden_pad": int, # d_hidden padded to 256-alignment + "d_expert_pad": int, # d_expert padded to 256-alignment + "n_routed_experts": int, # number of routed experts + "n_shared_experts": int, # number of shared experts (1) + "n_experts_per_token": int, # routed top-k (8) + "total_top_k": int, # routed + shared (9) + "bs": int, # batch size (number of tokens) +} +``` + +## Output + +``` +output: [M, d_hidden] bfloat16 +``` + +## Reference Performance + +AITER `fused_moe` with MXFP4 (E includes shared expert, top_k = routed + shared): + +| bs | E | d_hidden | d_expert | top_k | time (us) | +|---|---|---|---|---|---| +| 4 | 257 | 7168 | 256 | 9 | 46.9 | +| 64 | 257 | 7168 | 256 | 9 | 187.7 | +| 256 | 257 | 7168 | 256 | 9 | 245.7 | +| 64 | 33 | 7168 | 2048 | 9 | 220.6 | +| 256 | 33 | 7168 | 2048 | 9 | 276.4 | +| 1024 | 33 | 7168 | 2048 | 9 | 572.2 | + +## Optimization Opportunities + +The AITER CK `fused_moe` kernel is already well-optimized. 
To beat it, consider: + +1. **Custom tiling / scheduling**: The CK kernel uses a fixed tile strategy. For small batch sizes + (bs=4) or highly skewed expert distributions, a custom schedule may reduce idle waves. + +2. **Activation quantization fusion**: The reference quantizes activations separately before the + GEMM. Fusing dynamic MXFP4 quantization into the Stage 1 GEMM prologue saves one global + memory round-trip. + +3. **Inter-stage fusion**: The reference runs Stage 1 and Stage 2 as separate kernel launches. + Fusing both stages (gate_up GEMM → SwiGLU → down GEMM → accumulate) into a single kernel + eliminates the intermediate buffer write/read between stages. + +4. **Expert-parallel wave scheduling**: With 257 experts but only 9 active per token, most + expert slots are empty. A work-stealing or compact-dispatch strategy can minimize wasted + wavefronts. + +5. **Shared expert fusion**: The shared expert is always selected for all tokens. It could be + computed as a dense GEMM (no routing overhead) and fused with the routed expert reduction. + +6. **Split-K for large M**: For bs=1024 with EP-on (E=33, d_expert=2048), the GEMMs are large + enough to benefit from split-K parallelism within each expert. + +## Accuracy + +Submissions are checked against the AITER reference with `rtol=1e-2, atol=1e-2`. + +## Benchmark Cases + +### EP-off (all 257 experts on 1 GPU, d_expert=256) + +| bs | E | d_hidden | d_expert | top_k | +|---|---|---|---|---| +| 4 | 257 | 7168 | 256 | 9 | +| 64 | 257 | 7168 | 256 | 9 | +| 256 | 257 | 7168 | 256 | 9 | + +### EP-on (EP=8, 33 experts per GPU, d_expert=2048) + +| bs | E | d_hidden | d_expert | top_k | +|---|---|---|---|---| +| 64 | 33 | 7168 | 2048 | 9 | +| 256 | 33 | 7168 | 2048 | 9 | +| 1024 | 33 | 7168 | 2048 | 9 | + +Ranking is by **geometric mean** of benchmark latencies. diff --git a/problems/amd_202602/moe-mxfp4/reference.py b/problems/amd_202602/moe-mxfp4/reference.py new file mode 100644 index 00000000..3bdb8fcd --- /dev/null +++ b/problems/amd_202602/moe-mxfp4/reference.py @@ -0,0 +1,323 @@ +from utils import make_match_reference +from task import input_t, output_t +import torch +import torch.nn.functional as F +from typing import Dict, Tuple, Optional +import math + +from aiter import ActivationType, QuantType +from aiter.fused_moe import fused_moe +from aiter.utility import fp4_utils +from aiter.ops.shuffle import shuffle_weight + + +# ────────────────────────────────────────────────────────────────────── +# Constants +# ────────────────────────────────────────────────────────────────────── +MXFP4_BLOCK_SIZE = 32 +PAD_ALIGN = 256 + + +def _pad_to(x: int, align: int) -> int: + return (x + align - 1) // align * align + + +# ────────────────────────────────────────────────────────────────────── +# generate_input: produce all tensors needed by ref_kernel +# +# Models DeepSeek-R1 MoE layer shapes: +# - d_hidden = 7168 +# - d_expert = moe_intermediate_size (full=2048, or TP-split) +# - E_total = n_routed_experts + n_shared_experts (257 or 33) +# - top_k_total = nexpertspertoken + nsharedexperts (8+1=9) +# +# ────────────────────────────────────────────────────────────────────── +def generate_input( + dhidden: int, + dexpert: int, + nroutedexperts: int, + nexpertspertoken: int, + nsharedexperts: int, + bs: int, + seed: int, +) -> input_t: + d_hidden = dhidden + d_expert = dexpert + n_routed_experts = nroutedexperts + n_shared_experts = nsharedexperts + routed_top_k = nexpertspertoken + total_top_k = routed_top_k + n_shared_experts # e.g. 
8 + 1 = 9 + E_total = n_routed_experts + n_shared_experts # e.g. 256 + 1 = 257 + M = bs # number of tokens + + # Padded dimensions (AITER MXFP4 requires 256-alignment) + d_hidden_pad = _pad_to(d_hidden, PAD_ALIGN) + d_expert_pad = _pad_to(d_expert, PAD_ALIGN) + + config = { + "d_hidden": d_hidden, + "d_expert": d_expert, + "d_hidden_pad": d_hidden_pad, + "d_expert_pad": d_expert_pad, + "n_routed_experts": n_routed_experts, + "n_shared_experts": n_shared_experts, + "n_experts_per_token": routed_top_k, + "total_top_k": total_top_k, + "bs": M, + } + + gen = torch.Generator(device='cuda') + gen.manual_seed(seed) + + # ── hidden_states [M, d_hidden] ── + hidden_states = torch.randn( + (M, d_hidden), device='cuda', dtype=torch.bfloat16, generator=gen, + ) + + # ── Router: softmax top-k (routed experts only) ── + router_weight = torch.randn( + (n_routed_experts, d_hidden), device='cuda', dtype=torch.bfloat16, generator=gen, + ) / math.sqrt(d_hidden) + router_logits = F.linear(hidden_states, router_weight) # [M, n_routed_experts] + scores = router_logits.softmax(dim=-1) + routed_weights, routed_ids = torch.topk( + scores, k=routed_top_k, dim=-1, sorted=False + ) + routed_weights = routed_weights.to(torch.float32) + routed_ids = routed_ids.to(torch.int32) + + # ── Append shared expert(s): always selected, weight = 1.0 ── + # Shared experts are indexed as n_routed_experts, n_routed_experts+1, ... + shared_ids = torch.arange( + n_routed_experts, E_total, device='cuda', dtype=torch.int32 + ).unsqueeze(0).expand(M, -1) # [M, n_shared_experts] + shared_weights = torch.ones( + (M, n_shared_experts), device='cuda', dtype=torch.float32 + ) + + topk_ids = torch.cat([routed_ids, shared_ids], dim=-1) # [M, total_top_k] + topk_weights = torch.cat([routed_weights, shared_weights], dim=-1) # [M, total_top_k] + + # ── Expert weights: bf16 -> quantize to MXFP4 ── + # Generate weights for ALL experts (routed + shared) + # gate_up = fused [gate_proj; up_proj] per expert: [2*d_expert_pad, d_hidden_pad] + # down = down_proj per expert: [d_hidden_pad, d_expert_pad] + gate_up_q_list, gate_up_s_list = [], [] + down_q_list, down_s_list = [], [] + + for _ in range(E_total): + # gate_proj + up_proj -> fused [2*d_expert_pad, d_hidden_pad] + gate_bf16 = torch.randn( + (d_expert_pad, d_hidden_pad), device='cuda', dtype=torch.bfloat16, generator=gen + ) / math.sqrt(d_hidden) + up_bf16 = torch.randn( + (d_expert_pad, d_hidden_pad), device='cuda', dtype=torch.bfloat16, generator=gen + ) / math.sqrt(d_hidden) + gate_up_bf16 = torch.cat([gate_bf16, up_bf16], dim=0) + + # down_proj -> [d_hidden_pad, d_expert_pad] + down_bf16 = torch.randn( + (d_hidden_pad, d_expert_pad), device='cuda', dtype=torch.bfloat16, generator=gen + ) / math.sqrt(d_expert) + + # Quantize to MXFP4 + gu_q, gu_s = fp4_utils.dynamic_mxfp4_quant(gate_up_bf16) + dn_q, dn_s = fp4_utils.dynamic_mxfp4_quant(down_bf16) + + gate_up_q_list.append(gu_q) + gate_up_s_list.append(gu_s) + down_q_list.append(dn_q) + down_s_list.append(dn_s) + + # Stack into [E_total, ...] tensors — raw (before shuffle) + gate_up_weight = torch.stack(gate_up_q_list) # [E_total, 2*d_expert_pad, d_hidden_pad//2] fp4x2 + gate_up_weight_scale = torch.stack(gate_up_s_list) # [E_total, 2*d_expert_pad, scale_K] e8m0 + down_weight = torch.stack(down_q_list) # [E_total, d_hidden_pad, d_expert_pad//2] fp4x2 + down_weight_scale = torch.stack(down_s_list) # [E_total, d_hidden_pad, scale_K] e8m0 + + # Pre-shuffled weight. You can also shuffle the weights yourself before calling the kernel. 
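+    # shuffle_weight reorders each expert's fp4x2 weights into the (16,16) tile-coalesced
+    # layout the CK a4w4 kernel reads (per the task description); e8m0_shuffle applies the
+    # matching reordering to the block scales after they are flattened to [E_total, -1].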
+    gate_up_weight_shuffled = shuffle_weight(gate_up_weight.clone())
+    down_weight_shuffled = shuffle_weight(down_weight.clone())
+    gate_up_weight_scale_shuffled = fp4_utils.e8m0_shuffle(gate_up_weight_scale.reshape(E_total, -1))
+    down_weight_scale_shuffled = fp4_utils.e8m0_shuffle(down_weight_scale.reshape(E_total, -1))
+
+    return (
+        hidden_states,                  # [M, d_hidden] bf16
+        gate_up_weight,                 # [E_total, 2*d_expert_pad, d_hidden_pad//2] fp4x2 (raw)
+        down_weight,                    # [E_total, d_hidden_pad, d_expert_pad//2] fp4x2 (raw)
+        gate_up_weight_scale,           # [E_total, 2*d_expert_pad, scale_K] e8m0 (raw)
+        down_weight_scale,              # [E_total, d_hidden_pad, scale_K] e8m0 (raw)
+        gate_up_weight_shuffled,        # [E_total, 2*d_expert_pad, d_hidden_pad//2] fp4x2 (pre-shuffled)
+        down_weight_shuffled,           # [E_total, d_hidden_pad, d_expert_pad//2] fp4x2 (pre-shuffled)
+        gate_up_weight_scale_shuffled,  # [padded, flat] e8m0 (pre-shuffled)
+        down_weight_scale_shuffled,     # [padded, flat] e8m0 (pre-shuffled)
+        topk_weights,                   # [M, total_top_k] float32
+        topk_ids,                       # [M, total_top_k] int32
+        config,
+    )
+
+
+
+# ──────────────────────────────────────────────────────────────────────
+# _dequant_mxfp4: dequantize MXFP4 (fp4x2 + e8m0 block scales) to float32
+# ──────────────────────────────────────────────────────────────────────
+def _dequant_mxfp4(weight_fp4, scale_e8m0):
+    """
+    Dequantize MXFP4 weight to float32.
+
+    weight_fp4: [N, K//2] fp4x2 (raw, not shuffled)
+    scale_e8m0: [padded_N, ceil(K/32)] e8m0 (M-dim padded to 256-align by dynamic_mxfp4_quant)
+
+    Returns: [N, K] float32
+    """
+    # fp4x2 -> float32 lookup: [N, K]
+    w_f32 = fp4_utils.mxfp4_to_f32(weight_fp4)   # [N, K]
+    # e8m0 -> float32 power-of-2 scale: [padded_N, scale_K]
+    s_f32 = fp4_utils.e8m0_to_f32(scale_e8m0)    # [padded_N, scale_K]
+    N, K = w_f32.shape
+    # Trim scale rows to match weight rows (scale M-dim is padded to 256)
+    s_f32 = s_f32[:N, :]
+    # Broadcast scale across block_size=32 columns
+    s_f32 = s_f32.repeat_interleave(MXFP4_BLOCK_SIZE, dim=-1)[:, :K]   # [N, K]
+    return w_f32 * s_f32
+
+# ──────────────────────────────────────────────────────────────────────
+# ref_kernel_pytorch: pure PyTorch implementation (dequant + matmul);
+# not invoked by the harness, kept for reference only
+# ──────────────────────────────────────────────────────────────────────
+def ref_kernel_pytorch(data: input_t) -> output_t:
+    """
+    Pure PyTorch reference: dequantize MXFP4 weights -> bf16 matmul -> SwiGLU -> matmul.
+    Uses the raw (un-shuffled) weights.
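+    Per-token, per-expert Python loops make it orders of magnitude slower than the
+    fused kernel; use it only as a numerical cross-check on small shapes.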
+ """ + ( + hidden_states, # [M, d_hidden] bf16 + gate_up_weight, # [E, 2*d_expert_pad, d_hidden_pad//2] fp4x2 + down_weight, # [E, d_hidden_pad, d_expert_pad//2] fp4x2 + gate_up_weight_scale, # [E, 2*d_expert_pad, scale_K] e8m0 + down_weight_scale, # [E, d_hidden_pad, scale_K] e8m0 + gate_up_weight_shuffled, + down_weight_shuffled, + gate_up_weight_scale_shuffled, + down_weight_scale_shuffled, + topk_weights, # [M, top_k] float32 + topk_ids, # [M, top_k] int32 + config, + ) = data + + d_hidden = config["d_hidden"] + d_expert = config["d_expert"] + d_hidden_pad = config["d_hidden_pad"] + d_expert_pad = config["d_expert_pad"] + M = hidden_states.shape[0] + top_k = topk_ids.shape[1] + E = gate_up_weight.shape[0] + + # Dequantize all expert weights to float32 + # gate_up: [E, 2*d_expert_pad, d_hidden_pad] -> trim to [E, 2*d_expert, d_hidden] + # down: [E, d_hidden_pad, d_expert_pad] -> trim to [E, d_hidden, d_expert] + gate_up_dq = torch.stack([ + _dequant_mxfp4(gate_up_weight[e], gate_up_weight_scale[e]) + for e in range(E) + ]) # [E, 2*d_expert_pad, d_hidden_pad] + gate_up_dq = gate_up_dq[:, :2 * d_expert, :d_hidden].to(torch.bfloat16) + + down_dq = torch.stack([ + _dequant_mxfp4(down_weight[e], down_weight_scale[e]) + for e in range(E) + ]) # [E, d_hidden_pad, d_expert_pad] + down_dq = down_dq[:, :d_hidden, :d_expert].to(torch.bfloat16) + + # Split gate_up -> gate [E, d_expert, d_hidden], up [E, d_expert, d_hidden] + gate_w, up_w = gate_up_dq.chunk(2, dim=1) # each [E, d_expert, d_hidden] + + # Per-token MoE forward + output = torch.zeros((M, d_hidden), dtype=torch.bfloat16, device=hidden_states.device) + + for i in range(M): + x = hidden_states[i] # [d_hidden] + for k in range(top_k): + eid = topk_ids[i, k].item() + w = topk_weights[i, k].item() + + # Stage 1: gate_proj + up_proj + SwiGLU + gate_out = F.silu(x @ gate_w[eid].T) # [d_expert] + up_out = x @ up_w[eid].T # [d_expert] + intermediate = gate_out * up_out # [d_expert] + + # Stage 2: down_proj + # down_dq[eid] is [d_hidden, d_expert], .T is [d_expert, d_hidden] + expert_out = intermediate @ down_dq[eid].T # [d_hidden] + + output[i] += w * expert_out + + return output + + + +# ────────────────────────────────────────────────────────────────────── +# ref_kernel: calls AITER fused_moe with MXFP4 quantized weights +# ────────────────────────────────────────────────────────────────────── +def ref_kernel(data: input_t) -> output_t: + """ + Reference implementation using AITER's fused_moe kernel with MXFP4 quantized weights. 
+ + Input data tuple (E = n_routed_experts + n_shared_experts, total_top_k = routed + shared): + hidden_states: [M, d_hidden] bf16 + gate_up_weight: [E, 2*d_expert_pad, d_hidden_pad//2] fp4x2 (raw, before shuffle) + down_weight: [E, d_hidden_pad, d_expert_pad//2] fp4x2 (raw, before shuffle) + gate_up_weight_scale: [E, 2*d_expert_pad, scale_K] e8m0 (raw, before shuffle) + down_weight_scale: [E, d_hidden_pad, scale_K] e8m0 (raw, before shuffle) + gate_up_weight_shuffled: [E, 2*d_expert_pad, d_hidden_pad//2] fp4x2 (pre-shuffled) + down_weight_shuffled: [E, d_hidden_pad, d_expert_pad//2] fp4x2 (pre-shuffled) + gate_up_weight_scale_shuffled:[padded, flat] e8m0 (pre-shuffled) + down_weight_scale_shuffled: [padded, flat] e8m0 (pre-shuffled) + topk_weights: [M, total_top_k] float32 + topk_ids: [M, total_top_k] int32 + config: dict + + Returns: + output: [M, d_hidden] bf16 + """ + ( + hidden_states, + gate_up_weight, + down_weight, + gate_up_weight_scale, + down_weight_scale, + gate_up_weight_shuffled, + down_weight_shuffled, + gate_up_weight_scale_shuffled, + down_weight_scale_shuffled, + topk_weights, + topk_ids, + config, + ) = data + + hidden_pad = config["d_hidden_pad"] - config["d_hidden"] + intermediate_pad = config["d_expert_pad"] - config["d_expert"] + + output = fused_moe( + hidden_states, + gate_up_weight_shuffled, + down_weight_shuffled, + topk_weights, + topk_ids, + expert_mask=None, + activation=ActivationType.Silu, + quant_type=QuantType.per_1x32, # MXFP4 uses per_1x32 block scaling + doweight_stage1=False, + w1_scale=gate_up_weight_scale_shuffled, + w2_scale=down_weight_scale_shuffled, + a1_scale=None, + a2_scale=None, + hidden_pad=hidden_pad, + intermediate_pad=intermediate_pad, + ) + + return output + + + +check_implementation = make_match_reference(ref_kernel, rtol=5e-2, atol=5e-2) diff --git a/problems/amd_202602/moe-mxfp4/submission.py b/problems/amd_202602/moe-mxfp4/submission.py new file mode 100644 index 00000000..a771b32c --- /dev/null +++ b/problems/amd_202602/moe-mxfp4/submission.py @@ -0,0 +1,66 @@ +import torch +from typing import Dict +from task import input_t, output_t + +from aiter import ActivationType, QuantType +from aiter.fused_moe import fused_moe + + +def custom_kernel(data: input_t) -> output_t: + """ + Submission template for DeepSeek-R1 MXFP4 MoE kernel. 
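+    The body below simply forwards to aiter's fused_moe, reproducing the reference;
+    replace it with your optimized implementation.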
+ + Input data tuple: + hidden_states: [M, d_hidden] bf16 + gate_up_weight: [E, 2*d_expert_pad, d_hidden_pad//2] fp4x2 (raw) + down_weight: [E, d_hidden_pad, d_expert_pad//2] fp4x2 (raw) + gate_up_weight_scale: [E, 2*d_expert_pad, scale_K] e8m0 (raw) + down_weight_scale: [E, d_hidden_pad, scale_K] e8m0 (raw) + gate_up_weight_shuffled: [E, 2*d_expert_pad, d_hidden_pad//2] fp4x2 (shuffled) + down_weight_shuffled: [E, d_hidden_pad, d_expert_pad//2] fp4x2 (shuffled) + gate_up_weight_scale_shuffled:[padded, flat] e8m0 (shuffled) + down_weight_scale_shuffled: [padded, flat] e8m0 (shuffled) + topk_weights: [M, total_top_k] float32 + topk_ids: [M, total_top_k] int32 + config: dict + + Returns: + output: [M, d_hidden] bf16 + """ + ( + hidden_states, + gate_up_weight, + down_weight, + gate_up_weight_scale, + down_weight_scale, + gate_up_weight_shuffled, + down_weight_shuffled, + gate_up_weight_scale_shuffled, + down_weight_scale_shuffled, + topk_weights, + topk_ids, + config, + ) = data + + hidden_pad = config["d_hidden_pad"] - config["d_hidden"] + intermediate_pad = config["d_expert_pad"] - config["d_expert"] + + output = fused_moe( + hidden_states, + gate_up_weight_shuffled, + down_weight_shuffled, + topk_weights, + topk_ids, + expert_mask=None, + activation=ActivationType.Silu, + quant_type=QuantType.per_1x32, + doweight_stage1=False, + w1_scale=gate_up_weight_scale_shuffled, + w2_scale=down_weight_scale_shuffled, + a1_scale=None, + a2_scale=None, + hidden_pad=hidden_pad, + intermediate_pad=intermediate_pad, + ) + + return output diff --git a/problems/amd_202602/moe-mxfp4/task.py b/problems/amd_202602/moe-mxfp4/task.py new file mode 100644 index 00000000..a19edc83 --- /dev/null +++ b/problems/amd_202602/moe-mxfp4/task.py @@ -0,0 +1,28 @@ +from typing import TypeVar, Tuple, Dict +import torch + +input_t = TypeVar("input_t", bound=Tuple[ + torch.Tensor, # hidden_states [M, d_hidden] + torch.Tensor, # gate_up_weight [E, 2*d_expert_pad, d_hidden_pad//2] fp4x2 (raw) + torch.Tensor, # down_weight [E, d_hidden_pad, d_expert_pad//2] fp4x2 (raw) + torch.Tensor, # gate_up_weight_scale [E, 2*d_expert_pad, scale_K] e8m0 (raw) + torch.Tensor, # down_weight_scale [E, d_hidden_pad, scale_K] e8m0 (raw) + torch.Tensor, # gate_up_weight_shuffled [E, 2*d_expert_pad, d_hidden_pad//2] fp4x2 (shuffled) + torch.Tensor, # down_weight_shuffled [E, d_hidden_pad, d_expert_pad//2] fp4x2 (shuffled) + torch.Tensor, # gate_up_weight_scale_shuffled [padded, flat] e8m0 (shuffled) + torch.Tensor, # down_weight_scale_shuffled [padded, flat] e8m0 (shuffled) + torch.Tensor, # topk_weights [M, total_top_k] + torch.Tensor, # topk_ids [M, total_top_k] + Dict, # config +]) +output_t = TypeVar("output_t", bound=torch.Tensor) + + +class TestSpec: + dhidden: int # hidden dimension (7168 for DeepSeek-R1) + dexpert: int # intermediate dimension per expert (per partition) + nroutedexperts: int # number of local routed experts on this GPU + nexpertspertoken: int # top-k routed experts per token (8 for DeepSeek-R1) + nsharedexperts: int # number of shared experts (1 for DeepSeek-R1), always selected + bs: int # batch size = number of tokens in this batch + seed: int diff --git a/problems/amd_202602/moe-mxfp4/task.yml b/problems/amd_202602/moe-mxfp4/task.yml new file mode 100644 index 00000000..12d14e57 --- /dev/null +++ b/problems/amd_202602/moe-mxfp4/task.yml @@ -0,0 +1,122 @@ +# name: 3_moe_mxfp4 + +files: + - {"name": "submission.py", "source": "@SUBMISSION@"} + - {"name": "task.py", "source": "task.py"} + - {"name": "utils.py", "source": 
"../utils.py"} + - {"name": "reference.py", "source": "reference.py"} + - {"name": "eval.py", "source": "../eval.py"} + +lang: "py" + +description: | + You will implement a DeepSeek-R1 style MXFP4 Mixture-of-Experts (MoE) fused kernel optimized for AMD Instinct MI355X GPU. + + To be explicit, you will be given a tuple of tensors: + ``` + (hidden_states, + gate_up_weight, down_weight, # fp4x2 raw + gate_up_weight_scale, down_weight_scale, # e8m0 raw + gate_up_weight_shuffled, down_weight_shuffled, # fp4x2 pre-shuffled + gate_up_weight_scale_shuffled, down_weight_scale_shuffled, # e8m0 pre-shuffled + topk_weights, topk_ids, + config) + ``` + where: + * `hidden_states` is M x d_hidden in bfloat16 (the input activations, M = batch of tokens) + * `gate_up_weight` is [E, 2*d_expert_pad, d_hidden_pad//2] in MXFP4 (fp4x2), raw layout. + Fused gate + up projection weights for each expert. E = number of local experts. + * `down_weight` is [E, d_hidden_pad, d_expert_pad//2] in MXFP4 (fp4x2), raw layout. + Down projection weights for each expert. + * `gate_up_weight_scale` is [E, 2*d_expert_pad, d_hidden_pad//32] in E8M0, raw layout. + Block scales (block_size=32) for gate_up_weight. + * `down_weight_scale` is [E, d_hidden_pad, d_expert_pad//32] in E8M0, raw layout. + Block scales for down_weight. + * `gate_up_weight_shuffled` / `down_weight_shuffled` are the same weights shuffled to + (16,16) tile-coalesced layout for the CK kernel. + * `gate_up_weight_scale_shuffled` / `down_weight_scale_shuffled` are the scales after + e8m0_shuffle, flattened to [padded, flat]. + * `topk_weights` is [M, total_top_k] float32: routing weights (routed experts + shared experts). + * `topk_ids` is [M, total_top_k] int32: expert indices. First nexpertspertoken columns are + routed expert ids (0..n_routed-1), last nsharedexperts columns are shared expert ids + (n_routed..n_routed+n_shared-1). Shared experts are always selected with weight=1.0. + * `config` is a dict with: d_hidden, d_expert, d_hidden_pad, d_expert_pad, + n_routed_experts, n_shared_experts, n_experts_per_token, total_top_k, bs. + + Then, the fused_moe kernel flow is: + (1) Quant activations to MXFP4: aiter per-1x32 dynamic quantization of hidden_states. + (2) Stage 1 GEMM + activation (per token i, per assigned expert j): + - gate = x_i @ W_gate_j.T # [d_hidden] x [d_expert, d_hidden].T -> [d_expert] + - up = x_i @ W_up_j.T # [d_hidden] x [d_expert, d_hidden].T -> [d_expert] + - intermediate = SiLU(gate) * up # SwiGLU activation, -> [d_expert] + (W_gate and W_up are fused as gate_up_weight, so this is one a4w4 GEMM + fused activation) + (3) Stage 2 GEMM: + - expert_out = intermediate @ W_down_j.T # [d_expert] x [d_hidden, d_expert].T -> [d_hidden] + (4) Weighted reduction: + - output_i += w_ij * expert_out # accumulate across top_k experts + All weight GEMMs are a4w4 (MXFP4 activations x MXFP4 weights, per-1x32 block scaling). + The AITER CK kernel fuses all of the above into a 2-stage pipeline across all tokens and experts. + + DeepSeek-R1 MoE specs: + - hidden_size = 7168, moe_intermediate_size = 2048 + - 256 routed experts + 1 shared expert (total 257), top-8 routed + 1 shared = 9 per token + - 58 MoE layers (layer 3-60) + - The shared expert processes ALL tokens unconditionally (weight=1.0) + + d_hidden_pad and d_expert_pad are the dimensions padded to 256-alignment for the CK kernel. + + The ranking criteria is the geometric mean of the benchmark results. 
+ + ``` + The AITER reference performance is (E includes shared expert, top_k = routed + shared): + bs E d_hidden d_expert top_k time[us] + 4 257 7168 256 9 46.9 + 64 257 7168 256 9 187.7 + 256 257 7168 256 9 245.7 + 64 33 7168 2048 9 220.6 + 256 33 7168 2048 9 276.4 + 1024 33 7168 2048 9 572.2 + ``` + + Input: + - hidden_states: [M, d_hidden] bf16 + - gate_up_weight: [E, 2*d_expert_pad, d_hidden_pad//2] fp4x2 (raw, before shuffle) + - down_weight: [E, d_hidden_pad, d_expert_pad//2] fp4x2 (raw, before shuffle) + - gate_up_weight_scale: [E, 2*d_expert_pad, d_hidden_pad//32] e8m0 (raw, before shuffle) + - down_weight_scale: [E, d_hidden_pad, d_expert_pad//32] e8m0 (raw, before shuffle) + - gate_up_weight_shuffled: [E, 2*d_expert_pad, d_hidden_pad//2] fp4x2 (pre-shuffled for CK) + - down_weight_shuffled: [E, d_hidden_pad, d_expert_pad//2] fp4x2 (pre-shuffled for CK) + - gate_up_weight_scale_shuffled: [padded, flat] e8m0 (pre-shuffled for CK) + - down_weight_scale_shuffled: [padded, flat] e8m0 (pre-shuffled for CK) + - topk_weights: [M, total_top_k] float32 + - topk_ids: [M, total_top_k] int32 + - config: dict with dimensions + + Output: + - output: [M, d_hidden] bf16 + +config: + main: "eval.py" + +templates: + Python: "submission.py" + +test_timeout: 540 +benchmark_timeout: 540 +ranked_timeout: 840 +ranking_by: "geom" + +tests: + - {"dhidden": 4096, "dexpert": 1024, "nroutedexperts": 16, "nexpertspertoken": 4, "nsharedexperts": 1, "bs": 8, "seed": 9371} + - {"dhidden": 7168, "dexpert": 2048, "nroutedexperts": 32, "nexpertspertoken": 8, "nsharedexperts": 1, "bs": 32, "seed": 2291} + - {"dhidden": 4096, "dexpert": 1536, "nroutedexperts": 64, "nexpertspertoken": 6, "nsharedexperts": 1, "bs": 128, "seed": 81934} + +benchmarks: + # EP off (all 257 experts on 1 GPU): E=257, top_k=9 (8 routed + 1 shared) + - {"dhidden": 7168, "dexpert": 256, "nroutedexperts": 256, "nexpertspertoken": 8, "nsharedexperts": 1, "bs": 4, "seed": 9371} + - {"dhidden": 7168, "dexpert": 256, "nroutedexperts": 256, "nexpertspertoken": 8, "nsharedexperts": 1, "bs": 64, "seed": 2291} + - {"dhidden": 7168, "dexpert": 256, "nroutedexperts": 256, "nexpertspertoken": 8, "nsharedexperts": 1, "bs": 256, "seed": 81934} + # EP on (EP=8, 33 experts per GPU): E=33, top_k=9 (8 routed + 1 shared) + - {"dhidden": 7168, "dexpert": 2048, "nroutedexperts": 32, "nexpertspertoken": 8, "nsharedexperts": 1, "bs": 64, "seed": 2291} + - {"dhidden": 7168, "dexpert": 2048, "nroutedexperts": 32, "nexpertspertoken": 8, "nsharedexperts": 1, "bs": 256, "seed": 81934} + - {"dhidden": 7168, "dexpert": 2048, "nroutedexperts": 32, "nexpertspertoken": 8, "nsharedexperts": 1, "bs": 1024, "seed": 81934} diff --git a/problems/amd_202602/mxfp4-mm/reference.py b/problems/amd_202602/mxfp4-mm/reference.py new file mode 100644 index 00000000..3c26d348 --- /dev/null +++ b/problems/amd_202602/mxfp4-mm/reference.py @@ -0,0 +1,108 @@ +""" +FP4 quant + FP4 GEMM reference: bf16 A, MXFP4 B -> MXFP4 per-1x32 quant A -> gemm_a4w4 -> bf16 C. +Quant logic follows aiter op_tests/test_gemm_a4w4.py (get_triton_quant(QuantType.per_1x32)). +""" +import torch +from task import input_t, output_t +from utils import make_match_reference +from aiter import QuantType,dtypes +import aiter +from aiter.ops.shuffle import shuffle_weight +# K must be divisible by 64 (scale group 32 and fp4 pack 2) +SCALE_GROUP_SIZE = 32 + +def generate_input(m: int, n: int, k: int, seed: int):# -> input_t: + """ + Generate random bf16 inputs A [m, k], B [n, k] and quantized MXFP4 B, shuffled B and B_scale. 
+
+    Returns:
+        Tuple of (A, B, B_q, B_shuffle, B_scale_sh); A and B are bf16 on cuda.
+    """
+    assert k % 64 == 0, "k must be divisible by 64 (scale group 32 and fp4 pack 2)"
+    gen = torch.Generator(device="cuda")
+    gen.manual_seed(seed)
+    A = torch.randn((m, k), dtype=torch.bfloat16, device="cuda", generator=gen)
+    B = torch.randn((n, k), dtype=torch.bfloat16, device="cuda", generator=gen)
+
+    # quantized mxfp4 B
+    quant_func = aiter.get_triton_quant(QuantType.per_1x32)
+    B_q, B_scale_sh = quant_func(B, shuffle=True)
+
+    # shuffle B(weight) to (16,16) tile coalesced
+    B_shuffle = shuffle_weight(B_q, layout=(16, 16))
+    return (A, B, B_q, B_shuffle, B_scale_sh)
+
+def run_torch_fp4_mm(
+    x: torch.Tensor,
+    w: torch.Tensor,
+    x_scales: torch.Tensor,
+    w_scales: torch.Tensor,
+    dtype: torch.dtype = torch.bfloat16,
+) -> torch.Tensor:
+    """
+    PyTorch reference: dequant MXFP4 + E8M0 scale -> f32 -> mm -> dtype.
+    Same logic as aiter op_tests/test_gemm_a4w4.run_torch.
+    x: [m, k//2] fp4 packed, w: [n, k//2] fp4 packed
+    x_scales: [m, k//32] E8M0, w_scales: [n, k//32] E8M0
+    Returns: [m, n] in dtype
+    """
+    from aiter.utility import fp4_utils
+
+    m, _ = x.shape
+    n, _ = w.shape
+    # fp4 packed -> f32
+    x_f32 = fp4_utils.mxfp4_to_f32(x)
+    w_f32 = fp4_utils.mxfp4_to_f32(w)
+    # E8M0 scale: [*, k//32] -> repeat 32 along k -> f32
+    x_scales = x_scales[:m].repeat_interleave(SCALE_GROUP_SIZE, dim=1)
+    x_scales_f32 = fp4_utils.e8m0_to_f32(x_scales)
+    x_f32 = x_f32 * x_scales_f32
+    w_scales = w_scales[:n].repeat_interleave(SCALE_GROUP_SIZE, dim=1)
+    w_scales_f32 = fp4_utils.e8m0_to_f32(w_scales)
+    w_f32 = w_f32 * w_scales_f32
+    return torch.mm(x_f32, w_f32.T).to(dtype)[:m, :n]
+
+
+def ref_kernel(data: input_t) -> output_t:
+    """
+    Reference: MXFP4 per-1x32 quant on A and B; both PyTorch ref and gemm_a4w4 are given.
+    Returns gemm_a4w4 for check_implementation.
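+    The commented-out torch path below (run_torch_fp4_mm) spells out the dequantization
+    semantics if you want a slow numerical cross-check.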
+ """ + A, B, B_q, B_shuffle, B_scale_sh = data + A = A.contiguous() + B = B.contiguous() + m, k = A.shape + n, _ = B.shape + + # 1) PyTorch impl just for your reference: dequant fp4 + e8m0 -> f32 -> mm -> bf16 + # Per-1x32 MXFP4 quant + # quant_func = aiter.get_triton_quant(QuantType.per_1x32) + # quant_func(x, shuffle=False) -> (dtypes.fp4x2, scale); scale layout matches gemm_a4w4 + # A_q, A_scale = quant_func(A, shuffle=False) + # B_q, B_scale = quant_func(B, shuffle=False) + + # gemm_a4w4 expects A [M,K/2], B [N,K/2] as dtypes.fp4x2; A_scale/B_scale [*,K/32] E8M0 + # quant_func returns scale as dtypes.fp8_e8m0; gemm_a4w4 accepts E8M0, no view to uint8 needed + # slice to exact shapes [m,k_scale] / [n,k_scale] (quant may return padded scale) + + # k_scale = k // SCALE_GROUP_SIZE + # A_scale = A_scale[:m, :k_scale].contiguous() + # B_scale = B_scale[:n, :k_scale].contiguous() + # out_torch = run_torch_fp4_mm(A_q, B_q, A_scale, B_scale, torch.bfloat16) + + # 2) aiter.gemm_a4w4 path: needs shuffled B_q and shuffled scales (see test_gemm_a4w4.py:102-105) + # Per-1x32 MXFP4 quant + quant_func = aiter.get_triton_quant(QuantType.per_1x32) + A_q, A_scale_sh = quant_func(A, shuffle=True) + # to be noted, aiter also has other a4w4 implements using triton, https://github.com/ROCm/aiter/blob/main/aiter/ops/triton/gemm/basic/gemm_afp4wfp4.py + out_gemm = aiter.gemm_a4w4( + A_q, + B_shuffle, + A_scale_sh, + B_scale_sh, + dtype=dtypes.bf16, + bpreshuffle=True, + ) + return out_gemm + +check_implementation = make_match_reference(ref_kernel, rtol=1e-02, atol=1e-02) \ No newline at end of file diff --git a/problems/amd_202602/mxfp4-mm/submission.py b/problems/amd_202602/mxfp4-mm/submission.py new file mode 100644 index 00000000..33b7ebbb --- /dev/null +++ b/problems/amd_202602/mxfp4-mm/submission.py @@ -0,0 +1,32 @@ +""" +FP4 quant + FP4 GEMM reference: bf16 A, MXFP4 B -> MXFP4 per-1x32 quant A -> gemm_a4w4 -> bf16 C. +Quant logic follows aiter op_tests/test_gemm_a4w4.py (get_triton_quant(QuantType.per_1x32)). +""" +from task import input_t, output_t + + +def custom_kernel(data: input_t) -> output_t: + """ + Reference: MXFP4 per-1x32 quant on A; B_shuffle, B_scale_sh from generate_input. + gemm_a4w4 with bpreshuffle=True. + """ + import aiter + from aiter import QuantType, dtypes + + A, B, B_q, B_shuffle, B_scale_sh = data + A = A.contiguous() + B = B.contiguous() + m, k = A.shape + n, _ = B.shape + + quant_func = aiter.get_triton_quant(QuantType.per_1x32) + A_q, A_scale_sh = quant_func(A, shuffle=True) + out_gemm = aiter.gemm_a4w4( + A_q, + B_shuffle, + A_scale_sh, + B_scale_sh, + dtype=dtypes.bf16, + bpreshuffle=True, + ) + return out_gemm diff --git a/problems/amd_202602/mxfp4-mm/task.py b/problems/amd_202602/mxfp4-mm/task.py new file mode 100644 index 00000000..a258eac0 --- /dev/null +++ b/problems/amd_202602/mxfp4-mm/task.py @@ -0,0 +1,21 @@ +""" +quant + FP4 GEMM: bf16 A, B -> MXFP4 1x32 per-block quant -> gemm_a4w4 -> bf16 C. +""" +import torch +from typing import TypeVar, TypedDict + +# Input: (A, B, B_q, B_shuffle, B_scale_sh) from generate_input. +# A [m,k], B [n,k] bf16; B_q quantized MXFP4; B_shuffle = shuffle_weight(B_q,(16,16)); B_scale_sh from quant(B, shuffle=True). +# Output: bf16 C [m, n]. 
+input_t = TypeVar(
+    "input_t",
+    bound=tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
+)
+output_t = TypeVar("output_t", bound=torch.Tensor)
+
+
+class TestSpec(TypedDict):
+    m: int
+    n: int
+    k: int
+    seed: int
diff --git a/problems/amd_202602/mxfp4-mm/task.yml b/problems/amd_202602/mxfp4-mm/task.yml
new file mode 100644
index 00000000..a8906f82
--- /dev/null
+++ b/problems/amd_202602/mxfp4-mm/task.yml
@@ -0,0 +1,62 @@
+# name: mxfp4-mm
+
+files:
+  - {"name": "submission.py", "source": "@SUBMISSION@"}
+  - {"name": "task.py", "source": "task.py"}
+  - {"name": "utils.py", "source": "../utils.py"}
+  - {"name": "reference.py", "source": "reference.py"}
+  - {"name": "eval.py", "source": "../eval.py"}
+
+lang: "py"
+
+description: |
+  You will implement a quantization function and a block-scaled MXFP4 matrix-matrix multiplication kernel optimized for the AMD Instinct MI355X GPU.
+  To be explicit, you will be given a tuple of tensors:
+  ```
+  (A, B, B_q, B_shuffle, B_scale_sh)
+  ```
+  where:
+  * `A` is M x K in K-major order in bfloat16
+  * `B` is N x K in K-major order in bfloat16
+  * `B_q` is N x K/2 in K-major order in MXFP4
+  * `B_shuffle` is N x K/2 in shuffled order in MXFP4, shuffled to (16,16) tile coalesced
+  * `B_scale_sh` is * x K/32 in E8M0, where * means the row dimension may be padded.
+
+  Then, the kernel flow is bf16 A, MXFP4 B -> MXFP4 per-1x32 quant A -> gemm_a4w4 -> BF16 C [m,n].
+  To be specific, the invocation flow is:
+  (1) Quant A to MXFP4: aiter.get_triton_quant(QuantType.per_1x32).
+  (2) GEMM: aiter.gemm_a4w4.
+  m, n divisible by 64; k divisible by 64.
+
+  The ranking criteria is the geometric mean of the benchmark results.
+  Please note that this is the elimination round: the top 5 on the ranked leaderboard advance to the next round, end-to-end optimization of the DeepSeek-R1-MXFP4 and GPT-OSS-MXFP4 models.
+  ```
+  The aiter performance is:
+  M    N     K     time[us]
+  4    2880  512   8.198
+  16   2112  7168  20.873
+  32   4096  512   9.462
+  64   7168  2048  12.738
+  64   2880  512   9.873
+  128  2112  7168  27.284
+  256  3072  1536  12.219
+  256  7168  2048  13.506
+  ```
+config:
+  main: "eval.py"
+
+tests:
+  - {"m": 8, "n": 2112, "k": 7168, "seed": 124}
+  - {"m": 16, "n": 3072, "k": 1536, "seed": 6635}
+  - {"m": 64, "n": 3072, "k": 1536, "seed": 45}
+  - {"m": 256, "n": 2880, "k": 512, "seed": 78}
+
+benchmarks:
+  - {"m": 4, "n": 2880, "k": 512, "seed": 4565}
+  - {"m": 16, "n": 2112, "k": 7168, "seed": 15}
+  - {"m": 32, "n": 4096, "k": 512, "seed": 457}
+  - {"m": 64, "n": 7168, "k": 2048, "seed": 687}
+  - {"m": 64, "n": 2880, "k": 512, "seed": 54}
+  - {"m": 128, "n": 2112, "k": 7168, "seed": 24}
+  - {"m": 256, "n": 3072, "k": 1536, "seed": 7856}
+  - {"m": 256, "n": 7168, "k": 2048, "seed": 223}
\ No newline at end of file
diff --git a/problems/amd_202602/utils.py b/problems/amd_202602/utils.py
new file mode 100644
index 00000000..45f120ed
--- /dev/null
+++ b/problems/amd_202602/utils.py
@@ -0,0 +1,147 @@
+import random
+from typing import Tuple
+
+import numpy as np
+import torch
+
+
+def set_seed(seed=42):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(seed)
+        torch.cuda.manual_seed_all(seed)
+
+
+def get_device(use_cuda: bool = True) -> torch.device:
+    """Get the appropriate device (GPU or CPU)."""
+    if use_cuda:
+        if torch.cuda.is_available():
+            return torch.device("cuda")
+        elif torch.backends.mps.is_available():
+            return torch.device("mps")
+        else:
+            print("No compatible GPU found.
Falling back to CPU.")
+    return torch.device("cpu")
+
+
+# Adapted from https://github.com/linkedin/Liger-Kernel/blob/main/test/utils.py
+@torch.no_grad()
+def verbose_allclose(
+    received: torch.Tensor,
+    expected: torch.Tensor,
+    rtol=1e-05,
+    atol=1e-08,
+    max_print=5
+) -> Tuple[bool, list[str]]:
+    """
+    Assert that two tensors are element-wise equal within a tolerance, providing detailed information about mismatches.
+
+    Parameters:
+    received (torch.Tensor): Tensor we actually got.
+    expected (torch.Tensor): Tensor we expected to receive.
+    rtol (float): Relative tolerance; relative to expected.
+    atol (float): Absolute tolerance.
+    max_print (int): Maximum number of mismatched elements to print.
+    """
+    # Check if the shapes of the tensors match
+    if received.shape != expected.shape:
+        return False, ["SIZE MISMATCH"]
+
+    # Calculate the difference between the tensors
+    diff = torch.abs(received.to(torch.float32) - expected.to(torch.float32))
+
+    # Determine the tolerance
+    tolerance = atol + rtol * torch.abs(expected)
+
+    # Find tolerance mismatched elements
+    tol_mismatched = diff > tolerance
+
+    # Find nan mismatched elements
+    nan_mismatched = torch.logical_xor(torch.isnan(received), torch.isnan(expected))
+
+    # Find +inf mismatched elements
+    posinf_mismatched = torch.logical_xor(torch.isposinf(received), torch.isposinf(expected))
+    # Find -inf mismatched elements
+    neginf_mismatched = torch.logical_xor(torch.isneginf(received), torch.isneginf(expected))
+
+    # Find all mismatched elements
+    mismatched = torch.logical_or(
+        torch.logical_or(tol_mismatched, nan_mismatched),
+        torch.logical_or(posinf_mismatched, neginf_mismatched),
+    )
+
+    mismatched_indices = torch.nonzero(mismatched)
+
+    # Count the number of mismatched elements
+    num_mismatched = mismatched.count_nonzero().item()
+
+    # Generate detailed information if there are mismatches
+    if num_mismatched >= 1:
+        mismatch_details = [f"Number of mismatched elements: {num_mismatched}"]
+
+        for index in mismatched_indices[:max_print]:
+            i = tuple(index.tolist())
+            mismatch_details.append(f"ERROR at {i}: {received[i]} {expected[i]}")
+        if num_mismatched > max_print:
+            mismatch_details.append(f"... and {num_mismatched - max_print} more mismatched elements.")
+        return False, mismatch_details
+
+    return True, [f"Maximum error: {torch.max(diff)}"]
+
+
+@torch.no_grad()
+def verbose_allequal(received: torch.Tensor, expected: torch.Tensor, max_print: int = 5) -> Tuple[bool, list[str]]:
+    """
+    Assert that two tensors are element-wise perfectly equal, providing detailed information about mismatches.
+
+    Parameters:
+    received (torch.Tensor): Tensor we actually got.
+    expected (torch.Tensor): Tensor we expected to receive.
+    max_print (int): Maximum number of mismatched elements to print.
+
+    Returns:
+    (True, []) if the tensors are equal; otherwise (False, detailed mismatch information).
+    """
+    mismatched = torch.not_equal(received, expected)
+    mismatched_indices = torch.nonzero(mismatched)
+
+    # Count the number of mismatched elements
+    num_mismatched = mismatched.count_nonzero().item()
+
+    # Generate detailed information if there are mismatches
+    if num_mismatched >= 1:
+        mismatch_details = [f"Number of mismatched elements: {num_mismatched}"]
+
+        for index in mismatched_indices[:max_print]:
+            i = tuple(index.tolist())
+            mismatch_details.append(f"...
and {num_mismatched - max_print} more mismatched elements.") + return False, mismatch_details + + return True, [] + + +def match_reference(data, output, reference: callable, rtol=1e-05, atol=1e-08): + """ + Convenient "default" implementation for tasks' `check_implementation` function. + """ + expected = reference(data) + good, reasons = verbose_allclose(output, expected, rtol=rtol, atol=atol) + + if len(reasons) > 0: + return good, "\\n".join(reasons) + + return good, '' + + +def make_match_reference(reference: callable, **kwargs): + def wrapped(data, output): + return match_reference(data, output, reference=reference, **kwargs) + return wrapped + +def clear_l2_cache_large(): + dummy = torch.randn((16000, 1024, 1024), device="cuda") + del dummy From 3cca4453019fd3a384e79300fc6cfb15dcda83e8 Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Tue, 3 Mar 2026 21:00:21 -0800 Subject: [PATCH 2/4] Document known moe-mxfp4 non-determinism issue on MI355X aiter's fused_moe kernel produces different results across calls with identical inputs on gfx950, causing the reference submission to fail correctness checks against itself. --- problems/amd_202602/moe-mxfp4/task.yml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/problems/amd_202602/moe-mxfp4/task.yml b/problems/amd_202602/moe-mxfp4/task.yml index 12d14e57..c7304282 100644 --- a/problems/amd_202602/moe-mxfp4/task.yml +++ b/problems/amd_202602/moe-mxfp4/task.yml @@ -65,6 +65,11 @@ description: | d_hidden_pad and d_expert_pad are the dimensions padded to 256-alignment for the CK kernel. + **Known issue:** The reference submission (which calls aiter's fused_moe) is non-deterministic + on MI355X — it does not pass correctness checks against itself. This appears to be an aiter + fused_moe kernel bug on gfx950. Submissions will be evaluated on benchmark performance only + until this is resolved. + The ranking criteria is the geometric mean of the benchmark results. ``` From d452068cad17d0646b64ed9b86e51371c24523d7 Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Wed, 4 Mar 2026 07:29:39 -0800 Subject: [PATCH 3/4] Update AMD problems with latest changes from AMD-AIM - mixed-mla: Add tp (tensor parallel) parameter, variable num_heads, qseqlen=4 prefill cases, updated test/benchmark shapes - moe-mxfp4: Updated benchmark shapes with TP=4/TP=8 variants, different batch sizes - mxfp4-mm: Added m=32 benchmark, adjusted shape set --- problems/amd_202602/mixed-mla/reference.py | 16 ++++-- problems/amd_202602/mixed-mla/submission.py | 11 ++-- problems/amd_202602/mixed-mla/task.py | 4 +- problems/amd_202602/mixed-mla/task.yml | 57 +++++++++++---------- problems/amd_202602/moe-mxfp4/task.yml | 33 ++++++------ problems/amd_202602/mxfp4-mm/task.yml | 8 +-- 6 files changed, 70 insertions(+), 59 deletions(-) diff --git a/problems/amd_202602/mixed-mla/reference.py b/problems/amd_202602/mixed-mla/reference.py index b0dee591..e1d91a72 100644 --- a/problems/amd_202602/mixed-mla/reference.py +++ b/problems/amd_202602/mixed-mla/reference.py @@ -6,7 +6,7 @@ output v_head_dim = kv_lora_rank = 512. 
The input provides: - q: (total_q, 16, 576) bfloat16 — absorbed query + q: (total_q, num_heads, 576) bfloat16 — absorbed query (num_heads = 128 // tp) kv_data: dict with KV cache in three formats: "bf16": Tensor (total_kv, 1, 576) bfloat16 — highest precision "fp8": (Tensor, Tensor) kv_buffer fp8 + scalar scale — per-tensor quantized @@ -37,7 +37,7 @@ # DeepSeek R1 latent MQA constants (forward_absorb path) # https://huggingface.co/deepseek-ai/DeepSeek-R1-0528/blob/main/config.json # --------------------------------------------------------------------------- -NUM_HEADS = 16 +TOTAL_NUM_HEADS = 128 NUM_KV_HEADS = 1 KV_LORA_RANK = 512 QK_ROPE_HEAD_DIM = 64 @@ -285,10 +285,13 @@ def _aiter_mla_decode( # generate_input / ref_kernel / check_implementation # --------------------------------------------------------------------------- -def generate_input(batchsize: int, qseqlen: int, kvseqlen: int, seed: int) -> input_t: +def generate_input(batchsize: int, qseqlen: int, kvseqlen: int, tp: int, seed: int) -> input_t: """ Generate absorbed q and compressed kv_buffer for MLA decode. + Args: + tp: tensor parallelism degree (4 or 8). num_heads = TOTAL_NUM_HEADS // tp. + Returns all three KV cache formats in kv_data dict: kv_data = { "bf16": Tensor — (total_kv, 1, 576) bfloat16 @@ -296,6 +299,9 @@ def generate_input(batchsize: int, qseqlen: int, kvseqlen: int, seed: int) -> in "mxfp4": (Tensor, Tensor) — kv_buffer fp4x2 + fp8_e8m0 scale } """ + assert TOTAL_NUM_HEADS % tp == 0, f"TOTAL_NUM_HEADS ({TOTAL_NUM_HEADS}) must be divisible by tp ({tp})" + num_heads = TOTAL_NUM_HEADS // tp + gen = torch.Generator(device="cuda") gen.manual_seed(seed) @@ -304,7 +310,7 @@ def generate_input(batchsize: int, qseqlen: int, kvseqlen: int, seed: int) -> in # Absorbed query: (total_q, num_heads, 576) bf16 q = torch.randn( - (total_q, NUM_HEADS, QK_HEAD_DIM), + (total_q, num_heads, QK_HEAD_DIM), dtype=torch.bfloat16, device="cuda", generator=gen, ) * 0.02 @@ -332,7 +338,7 @@ def generate_input(batchsize: int, qseqlen: int, kvseqlen: int, seed: int) -> in config = { "batch_size": batchsize, - "num_heads": NUM_HEADS, + "num_heads": num_heads, "num_kv_heads": NUM_KV_HEADS, "qk_head_dim": QK_HEAD_DIM, "kv_lora_rank": KV_LORA_RANK, diff --git a/problems/amd_202602/mixed-mla/submission.py b/problems/amd_202602/mixed-mla/submission.py index b6ff1309..e85525b8 100644 --- a/problems/amd_202602/mixed-mla/submission.py +++ b/problems/amd_202602/mixed-mla/submission.py @@ -4,7 +4,8 @@ Implement custom_kernel() to beat the aiter a8w8 reference (fp8 Q + fp8 KV). 
DeepSeek R1 forward_absorb MLA config: - num_heads = 16 (query heads, after TP split) + total_num_heads = 128 (query heads before TP split) + num_heads = 128 // tp (query heads per device, tp=4 → 32, tp=8 → 16) num_kv_heads = 1 (shared latent KV head) kv_lora_rank = 512 (latent dim) qk_rope_head_dim = 64 (RoPE dim) @@ -17,24 +18,24 @@ - First 512 dims (kv_lora_rank) used as values (for output computation) Input tuple: - q: (total_q, 16, 576) bfloat16 — absorbed query + q: (total_q, num_heads, 576) bfloat16 — absorbed query kv_data: dict with three KV cache formats: kv_data["bf16"] — Tensor (total_kv, 1, 576) bfloat16 kv_data["fp8"] — (Tensor, Tensor): kv_buffer fp8 (total_kv,1,576) + scalar scale kv_data["mxfp4"] — (Tensor, Tensor): kv_buffer fp4x2 (total_kv,1,288) + fp8_e8m0 scale qo_indptr: (batch_size + 1,) int32 — query segment pointers kv_indptr: (batch_size + 1,) int32 — KV segment pointers - config: dict with MLA parameters + config: dict with MLA parameters (includes num_heads computed from tp) Output: - attention output: (total_q, 16, 512) bfloat16 + attention output: (total_q, num_heads, 512) bfloat16 The reference uses aiter's a8w8 persistent MLA kernel (fp8 Q + fp8 KV), which is ~2-3x faster than bf16. To beat it, consider: 1. Use mxfp4 KV cache for even lower memory bandwidth - Fuse dequantization with attention to avoid bf16 materialization 2. Custom kernel with tighter memory access patterns - 3. MQA: 1 KV head shared across 16 query heads — minimize redundant memory loads + 3. MQA: 1 KV head shared across num_heads query heads — minimize redundant memory loads 4. Variable-length batching: indptr-based segmented attention 5. Split K/V from buffer: full 576 dims for keys, first 512 dims for values """ diff --git a/problems/amd_202602/mixed-mla/task.py b/problems/amd_202602/mixed-mla/task.py index 7aff7b6a..a4c13320 100644 --- a/problems/amd_202602/mixed-mla/task.py +++ b/problems/amd_202602/mixed-mla/task.py @@ -5,13 +5,14 @@ # # Input: (q, kv_data, qo_indptr, kv_indptr, config) # q: (total_q, num_heads, qk_head_dim) bfloat16 +# num_heads = 128 // tp (tp=4 → 32, tp=8 → 16) # kv_data: dict with three KV cache formats: # "bf16": Tensor (total_kv, 1, 576) bfloat16 # "fp8": (Tensor, Tensor) kv_buffer fp8 (total_kv, 1, 576) + scalar scale # "mxfp4": (Tensor, Tensor) kv_buffer fp4x2 (total_kv, 1, 288) + fp8_e8m0 scale # qo_indptr: (batch_size + 1,) int32 # kv_indptr: (batch_size + 1,) int32 -# config: dict with MLA parameters +# config: dict with MLA parameters (includes num_heads computed from tp) # # where qk_head_dim = kv_lora_rank + qk_rope_head_dim = 512 + 64 = 576 # @@ -33,4 +34,5 @@ class TestSpec(TypedDict): batchsize: int qseqlen: int kvseqlen: int + tp: int seed: int diff --git a/problems/amd_202602/mixed-mla/task.yml b/problems/amd_202602/mixed-mla/task.yml index c0a5d5a6..08053062 100644 --- a/problems/amd_202602/mixed-mla/task.yml +++ b/problems/amd_202602/mixed-mla/task.yml @@ -20,7 +20,8 @@ description: | persistent mode), which is ~2-3x faster than bf16 on MI355X. 
DeepSeek R1 forward_absorb MLA config: - - num_heads = 16 (query heads, after TP split) + - total_num_heads = 128 (query heads before TP split) + - num_heads = 128 // tp (query heads per device, tp=4 → 32, tp=8 → 16) - num_kv_heads = 1 (shared latent KV head) - kv_lora_rank = 512 - qk_rope_head_dim = 64 @@ -28,31 +29,31 @@ description: | - v_head_dim = 512 (= kv_lora_rank, output dim) - sm_scale = 1/sqrt(576) - dtype: q=bfloat16 - - decode only (q_seq_len=1, kv_seq_len up to 8k) + - q_seq_len = 1 or 4, kv_seq_len up to 8k KV buffer format (forward_absorb): - Full 576 dims are used as keys (for Q@K^T score computation) - First 512 dims (kv_lora_rank) are used as values (for output computation) Input tuple: (q, kv_data, qo_indptr, kv_indptr, config) - - q: (total_q, 16, 576) bfloat16 — absorbed query + - q: (total_q, num_heads, 576) bfloat16 — absorbed query - kv_data: dict with three KV cache formats: kv_data["bf16"] — Tensor (total_kv, 1, 576) bfloat16 kv_data["fp8"] — (Tensor, Tensor): kv_buffer fp8 + scalar scale kv_data["mxfp4"] — (Tensor, Tensor): kv_buffer fp4x2 + fp8_e8m0 scale - qo_indptr: (batch_size+1,) int32 — query segment pointers - kv_indptr: (batch_size+1,) int32 — KV segment pointers - - config: dict with MLA parameters + - config: dict with MLA parameters (includes num_heads computed from tp) Return: - - attention output: (total_q, 16, 512) bfloat16 + - attention output: (total_q, num_heads, 512) bfloat16 Key optimization opportunities: 1. Use mxfp4 KV cache for even lower memory bandwidth (4x savings over bf16) - Fuse dequantization with attention to skip bf16 materialization 2. Custom kernel with tighter memory access patterns - 3. MQA: 1 KV head shared across 16 query heads — minimize redundant memory loads - 4. Decode: q_seq_len=1, kv_seq_len up to 8k — memory-bound workload + 3. MQA: 1 KV head shared across num_heads query heads — minimize redundant memory loads + 4. q_seq_len=1 or 4, kv_seq_len up to 8k — memory-bound workload 5. Variable-length batching: indptr-based segmented attention 6. 
Split K/V from buffer: full 576 dims for keys, first 512 dims for values
@@ -69,27 +70,29 @@ benchmark_timeout: 900
 ranked_timeout: 1200
 
 tests:
-  # bs=4
-  - {"batchsize": 4, "qseqlen": 1, "kvseqlen": 1024, "seed": 4220}
-  # bs=32
-  - {"batchsize": 32, "qseqlen": 1, "kvseqlen": 1024, "seed": 5412}
-  # bs=64
-  - {"batchsize": 64, "qseqlen": 1, "kvseqlen": 8192, "seed": 1360}
-  # bs=256
-  - {"batchsize": 256, "qseqlen": 1, "kvseqlen": 8192, "seed": 9826}
+  # bs=4, tp=8
+  - {"batchsize": 4, "qseqlen": 1, "kvseqlen": 1024, "tp": 8, "seed": 4220}
+  - {"batchsize": 4, "qseqlen": 4, "kvseqlen": 1024, "tp": 8, "seed": 4231}
+  # bs=32, tp=4
+  - {"batchsize": 32, "qseqlen": 1, "kvseqlen": 1024, "tp": 4, "seed": 5412}
+  - {"batchsize": 32, "qseqlen": 4, "kvseqlen": 8192, "tp": 4, "seed": 5423}
+  # bs=128
+  - {"batchsize": 128, "qseqlen": 1, "kvseqlen": 8192, "tp": 8, "seed": 7816}
+  - {"batchsize": 128, "qseqlen": 4, "kvseqlen": 8192, "tp": 4, "seed": 7827}
 
 benchmarks:
-  # bs=4
-  - {"batchsize": 4, "qseqlen": 1, "kvseqlen": 1024, "seed": 4217}
-  - {"batchsize": 4, "qseqlen": 1, "kvseqlen": 8192, "seed": 4220}
-  # bs=32
-  - {"batchsize": 32, "qseqlen": 1, "kvseqlen": 1024, "seed": 5412}
-  - {"batchsize": 32, "qseqlen": 1, "kvseqlen": 8192, "seed": 5415}
-  # bs=64
-  - {"batchsize": 64, "qseqlen": 1, "kvseqlen": 1024, "seed": 1357}
-  - {"batchsize": 64, "qseqlen": 1, "kvseqlen": 8192, "seed": 1360}
-  # bs=256
-  - {"batchsize": 256, "qseqlen": 1, "kvseqlen": 1024, "seed": 9823}
-  - {"batchsize": 256, "qseqlen": 1, "kvseqlen": 8192, "seed": 9826}
+  # bs=4, tp=4
+  - {"batchsize": 4, "qseqlen": 1, "kvseqlen": 1024, "tp": 4, "seed": 4237}
+  - {"batchsize": 4, "qseqlen": 4, "kvseqlen": 8192, "tp": 4, "seed": 4251}
+  # bs=32, tp=8
+  - {"batchsize": 32, "qseqlen": 1, "kvseqlen": 8192, "tp": 8, "seed": 5415}
+  - {"batchsize": 32, "qseqlen": 4, "kvseqlen": 1024, "tp": 8, "seed": 5420}
+  # bs=32, tp=4
+  - {"batchsize": 32, "qseqlen": 1, "kvseqlen": 1024, "tp": 4, "seed": 5432}
+  - {"batchsize": 32, "qseqlen": 4, "kvseqlen": 8192, "tp": 4, "seed": 5443}
+  # bs=128, tp=8
+  - {"batchsize": 128, "qseqlen": 1, "kvseqlen": 8192, "tp": 8, "seed": 7816}
+  - {"batchsize": 128, "qseqlen": 4, "kvseqlen": 8192, "tp": 8, "seed": 7824}
 
+
 ranking_by: "geom"
diff --git a/problems/amd_202602/moe-mxfp4/task.yml b/problems/amd_202602/moe-mxfp4/task.yml
index c7304282..23a0e4b7 100644
--- a/problems/amd_202602/moe-mxfp4/task.yml
+++ b/problems/amd_202602/moe-mxfp4/task.yml
@@ -75,12 +75,13 @@ description: |
   ```
   The AITER reference performance is (E includes shared expert, top_k = routed + shared):
   bs    E    d_hidden  d_expert  top_k  time[us]
-  4     257  7168      256       9      46.9
-  64    257  7168      256       9      187.7
-  256   257  7168      256       9      245.7
-  64    33   7168      2048      9      220.6
-  256   33   7168      2048      9      276.4
-  1024  33   7168      2048      9      572.2
+  16    257  7168      256       9      152.7
+  128   257  7168      256       9      239.0
+  512   257  7168      256       9      336.5
+  16    33   7168      512       9      106.2
+  128   33   7168      512       9      141.1
+  512   33   7168      512       9      225.0
+  512   33   7168      2048      9      380.4
   ```
 
   Input:
@@ -112,16 +113,18 @@ ranked_timeout: 840
 ranking_by: "geom"
 
 tests:
-  - {"dhidden": 4096, "dexpert": 1024, "nroutedexperts": 16, "nexpertspertoken": 4, "nsharedexperts": 1, "bs": 8, "seed": 9371}
+  - {"dhidden": 4096, "dexpert": 1024, "nroutedexperts": 256, "nexpertspertoken": 8, "nsharedexperts": 1, "bs": 8, "seed": 9371}
   - {"dhidden": 7168, "dexpert": 2048, "nroutedexperts": 32, "nexpertspertoken": 8, "nsharedexperts": 1, "bs": 32, "seed": 2291}
   - {"dhidden": 4096, "dexpert": 1536, "nroutedexperts": 64, "nexpertspertoken": 6, "nsharedexperts": 1,
"bs": 128, "seed": 81934} benchmarks: - # EP off (all 257 experts on 1 GPU): E=257, top_k=9 (8 routed + 1 shared) - - {"dhidden": 7168, "dexpert": 256, "nroutedexperts": 256, "nexpertspertoken": 8, "nsharedexperts": 1, "bs": 4, "seed": 9371} - - {"dhidden": 7168, "dexpert": 256, "nroutedexperts": 256, "nexpertspertoken": 8, "nsharedexperts": 1, "bs": 64, "seed": 2291} - - {"dhidden": 7168, "dexpert": 256, "nroutedexperts": 256, "nexpertspertoken": 8, "nsharedexperts": 1, "bs": 256, "seed": 81934} - # EP on (EP=8, 33 experts per GPU): E=33, top_k=9 (8 routed + 1 shared) - - {"dhidden": 7168, "dexpert": 2048, "nroutedexperts": 32, "nexpertspertoken": 8, "nsharedexperts": 1, "bs": 64, "seed": 2291} - - {"dhidden": 7168, "dexpert": 2048, "nroutedexperts": 32, "nexpertspertoken": 8, "nsharedexperts": 1, "bs": 256, "seed": 81934} - - {"dhidden": 7168, "dexpert": 2048, "nroutedexperts": 32, "nexpertspertoken": 8, "nsharedexperts": 1, "bs": 1024, "seed": 81934} + # TP=8 + - {"dhidden": 7168, "dexpert": 256, "nroutedexperts": 256, "nexpertspertoken": 8, "nsharedexperts": 1, "bs": 16, "seed": 9371} + - {"dhidden": 7168, "dexpert": 256, "nroutedexperts": 256, "nexpertspertoken": 8, "nsharedexperts": 1, "bs": 128, "seed": 2291} + - {"dhidden": 7168, "dexpert": 256, "nroutedexperts": 256, "nexpertspertoken": 8, "nsharedexperts": 1, "bs": 512, "seed": 81934} + # TP=4 + - {"dhidden": 7168, "dexpert": 512, "nroutedexperts": 32, "nexpertspertoken": 8, "nsharedexperts": 1, "bs": 16, "seed": 2291} + - {"dhidden": 7168, "dexpert": 512, "nroutedexperts": 32, "nexpertspertoken": 8, "nsharedexperts": 1, "bs": 128, "seed": 81934} + - {"dhidden": 7168, "dexpert": 512, "nroutedexperts": 32, "nexpertspertoken": 8, "nsharedexperts": 1, "bs": 512, "seed": 81934} + # EP on + - {"dhidden": 7168, "dexpert": 2048, "nroutedexperts": 32, "nexpertspertoken": 8, "nsharedexperts": 1, "bs": 512, "seed": 81934} diff --git a/problems/amd_202602/mxfp4-mm/task.yml b/problems/amd_202602/mxfp4-mm/task.yml index a8906f82..1e519c78 100644 --- a/problems/amd_202602/mxfp4-mm/task.yml +++ b/problems/amd_202602/mxfp4-mm/task.yml @@ -36,11 +36,9 @@ description: | 4 2880 512 8.198 16 2112 7168 20.873 32 4096 512 9.462 + 32 2880 512 9.173 64 7168 2048 12.738 - 64 2880 512 9.873 - 128 2112 7168 27.284 256 3072 1536 12.219 - 256 7168 2048 13.506 ``` config: main: "eval.py" @@ -55,8 +53,6 @@ benchmarks: - {"m": 4, "n": 2880, "k": 512, "seed": 4565} - {"m": 16, "n": 2112, "k": 7168, "seed": 15} - {"m": 32, "n": 4096, "k": 512, "seed": 457} + - {"m": 32, "n": 2880, "k": 512, "seed": 54} - {"m": 64, "n": 7168, "k": 2048, "seed": 687} - - {"m": 64, "n": 2880, "k": 512, "seed": 54} - - {"m": 128, "n": 2112, "k": 7168, "seed": 24} - {"m": 256, "n": 3072, "k": 1536, "seed": 7856} - - {"m": 256, "n": 7168, "k": 2048, "seed": 223} \ No newline at end of file From 9e712af42bd7288c298066f5496302c8b37531e9 Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Wed, 4 Mar 2026 08:13:15 -0800 Subject: [PATCH 4/4] Increase AMD problem timeouts to 30 minutes aiter JIT compilation on first run can take 10+ minutes on MI355X, causing test timeouts. Bump all timeouts to 1800s (30 min). 
--- problems/amd_202602/mixed-mla/task.yml | 6 +++--- problems/amd_202602/moe-mxfp4/task.yml | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/problems/amd_202602/mixed-mla/task.yml b/problems/amd_202602/mixed-mla/task.yml index 08053062..a7d03ad7 100644 --- a/problems/amd_202602/mixed-mla/task.yml +++ b/problems/amd_202602/mixed-mla/task.yml @@ -65,9 +65,9 @@ config: templates: Python: "submission.py" -test_timeout: 900 -benchmark_timeout: 900 -ranked_timeout: 1200 +test_timeout: 1800 +benchmark_timeout: 1800 +ranked_timeout: 1800 tests: # bs=4, tp=8 diff --git a/problems/amd_202602/moe-mxfp4/task.yml b/problems/amd_202602/moe-mxfp4/task.yml index 23a0e4b7..2e5d3608 100644 --- a/problems/amd_202602/moe-mxfp4/task.yml +++ b/problems/amd_202602/moe-mxfp4/task.yml @@ -107,9 +107,9 @@ config: templates: Python: "submission.py" -test_timeout: 540 -benchmark_timeout: 540 -ranked_timeout: 840 +test_timeout: 1800 +benchmark_timeout: 1800 +ranked_timeout: 1800 ranking_by: "geom" tests: