From bb6d0d2314aafc7ffde89ff5585cd68fd6d4d66c Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Wed, 4 Mar 2026 08:28:08 -0800 Subject: [PATCH] Revert "Add MI355X details (#108)" This reverts commit 9f717a2761f7ee56b5bd6ed4fdc03df21c9863c3. --- problems/amd_202602.yaml | 19 - problems/amd_202602/eval.py | 387 -------------------- problems/amd_202602/mixed-mla/README.md | 200 ---------- problems/amd_202602/mixed-mla/reference.py | 378 ------------------- problems/amd_202602/mixed-mla/submission.py | 186 ---------- problems/amd_202602/mixed-mla/task.py | 38 -- problems/amd_202602/mixed-mla/task.yml | 98 ----- problems/amd_202602/moe-mxfp4/README.md | 198 ---------- problems/amd_202602/moe-mxfp4/reference.py | 323 ---------------- problems/amd_202602/moe-mxfp4/submission.py | 66 ---- problems/amd_202602/moe-mxfp4/task.py | 28 -- problems/amd_202602/moe-mxfp4/task.yml | 130 ------- problems/amd_202602/mxfp4-mm/reference.py | 108 ------ problems/amd_202602/mxfp4-mm/submission.py | 32 -- problems/amd_202602/mxfp4-mm/task.py | 21 -- problems/amd_202602/mxfp4-mm/task.yml | 58 --- problems/amd_202602/utils.py | 147 -------- 17 files changed, 2417 deletions(-) delete mode 100644 problems/amd_202602.yaml delete mode 100644 problems/amd_202602/eval.py delete mode 100644 problems/amd_202602/mixed-mla/README.md delete mode 100644 problems/amd_202602/mixed-mla/reference.py delete mode 100644 problems/amd_202602/mixed-mla/submission.py delete mode 100644 problems/amd_202602/mixed-mla/task.py delete mode 100644 problems/amd_202602/mixed-mla/task.yml delete mode 100644 problems/amd_202602/moe-mxfp4/README.md delete mode 100644 problems/amd_202602/moe-mxfp4/reference.py delete mode 100644 problems/amd_202602/moe-mxfp4/submission.py delete mode 100644 problems/amd_202602/moe-mxfp4/task.py delete mode 100644 problems/amd_202602/moe-mxfp4/task.yml delete mode 100644 problems/amd_202602/mxfp4-mm/reference.py delete mode 100644 problems/amd_202602/mxfp4-mm/submission.py delete mode 100644 problems/amd_202602/mxfp4-mm/task.py delete mode 100644 problems/amd_202602/mxfp4-mm/task.yml delete mode 100644 problems/amd_202602/utils.py diff --git a/problems/amd_202602.yaml b/problems/amd_202602.yaml deleted file mode 100644 index a75894f1..00000000 --- a/problems/amd_202602.yaml +++ /dev/null @@ -1,19 +0,0 @@ -name: AMD Developer Challenge February 2026 -deadline: "2026-03-15 06:00" -description: "AMD Developer Challenge: MXFP4 matrix multiplication, Mixture-of-Experts, and Multi-head Latent Attention optimized for MI355X." 
-problems: - - directory: amd_202602/mxfp4-mm - name: amd-mxfp4-mm - deadline: "2026-03-15 06:00" - gpus: - - MI355X - - directory: amd_202602/moe-mxfp4 - name: amd-moe-mxfp4 - deadline: "2026-03-15 06:00" - gpus: - - MI355X - - directory: amd_202602/mixed-mla - name: amd-mixed-mla - deadline: "2026-03-15 06:00" - gpus: - - MI355X diff --git a/problems/amd_202602/eval.py b/problems/amd_202602/eval.py deleted file mode 100644 index 2df7ef58..00000000 --- a/problems/amd_202602/eval.py +++ /dev/null @@ -1,387 +0,0 @@ -import base64 -import dataclasses -import multiprocessing -import re -import time -import os -import sys -import math -from pathlib import Path -from typing import Any, Optional - -import torch.cuda - -from utils import set_seed, clear_l2_cache_large as clear_l2_cache -try: - from task import TestSpec -except ImportError: - TestSpec = dict - -from reference import check_implementation, generate_input - - -class PopcornOutput: - def __init__(self, fd: int): - self.file = os.fdopen(fd, 'w') - os.set_inheritable(fd, False) - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.file.close() - - def print(self, *args, **kwargs): - print(*args, **kwargs, file=self.file, flush=True) - - def log(self, key, value): - self.print(f"{key}: {value}") - - -@dataclasses.dataclass -class TestCase: - args: dict - spec: str - - -def _combine(a: int, b: int) -> int: - # combine two integers into one: - # we need this to generate a secret seed based on the test-level seed and - # the global secret seed. - # the test-level seeds are public knowledge, and typically relatively small numbers, - # so we need to make sure they don't provide any useful info for the full seed. - # This Cantor construction ensures that if the secret seed is a large number, - # then so is the overall seed. - return int(a + (a+b)*(a+b+1)//2) - - -def get_test_cases(file_name: str, seed: Optional[int]) -> list[TestCase]: - try: - content = Path(file_name).read_text() - except Exception as E: - print(f"Could not open test file`{file_name}`: {E}", file=sys.stderr) - exit(113) - - tests = [] - lines = content.splitlines() - match = r"\s*([a-zA-Z_]\w*):\s*([a-zA-Z_]\w*|[+-]?[0-9]+)\s*" - for line in lines: - parts = line.split(";") - case = {} - for part in parts: - matched = re.match(match, part) - if not re.fullmatch(match, part): - print(f"invalid test case: '{line}': '{part}'", file=sys.stderr) - exit(113) - key = matched[1] - val = matched[2] - try: - val = int(val) - except ValueError: - if val == "true": - val = True - elif val == "false": - val = False - - case[key] = val - tests.append(TestCase(spec=line, args=case)) - - if seed is not None: - for test in tests: - if "seed" in test.args: - test.args["seed"] = _combine(test.args["seed"], seed) - - return tests - - -@dataclasses.dataclass -class Stats: - runs: int - mean: float - std: float - err: float - best: float - worst: float - - -def calculate_stats(durations: list[int]): - """ - Calculate statistical data from a list of durations. - - @param durations: A list of durations in nanoseconds. - @return: A Stats object containing the number of runs, mean, standard deviation, error, best, and worst durations. 
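    For example (illustrative numbers only): for durations [10, 12, 11] ns the mean is 11.0,
    the sample standard deviation is sqrt(((10-11)**2 + (12-11)**2 + (11-11)**2) / 2) = 1.0,
    and the standard error of the mean is 1.0 / sqrt(3) ≈ 0.577.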
- """ - runs = len(durations) - total = sum(durations) - best = min(durations) - worst = max(durations) - - avg = total / runs - variance = sum(map(lambda x: (x - avg)**2, durations)) - std = math.sqrt(variance / (runs - 1)) - err = std / math.sqrt(runs) - - return Stats(runs=runs, mean=avg, std=std, err=err, best=float(best), - worst=float(worst)) - - -def _clone_data(data): - """ - Recursively goes through data and clones all tensors. - """ - if isinstance(data, tuple): - return tuple(_clone_data(x) for x in data) - elif isinstance(data, list): - return [_clone_data(x) for x in data] - elif isinstance(data, dict): - return {k: _clone_data(v) for k, v in data.items()} - elif isinstance(data, torch.Tensor): - return data.clone() - else: - return data - - -def wrap_check_implementation(data, submission_output): - # Old version returned just a single string, new version - # returns (bool, str); this function ensures compatibility with old - # problem definitions. - result = check_implementation(data, submission_output) - if isinstance(result, tuple): - return result - else: - return not bool(result), result - - -def _run_single_test(test: TestCase): - """ - Runs a single test case. Do not call directly - """ - from submission import custom_kernel - data = generate_input(**test.args) - torch.cuda.synchronize() - submission_output = custom_kernel(_clone_data(data)) - torch.cuda.synchronize() - return wrap_check_implementation(data, submission_output) - - -def run_single_test(pool: multiprocessing.Pool, test: TestCase): - """ - Runs a single test in another process. - """ - return pool.apply(_run_single_test, (test,)) - - -def run_testing(logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase]): - """ - Executes the actual test case code and checks for correctness. - - @param logger: A PopcornOutput object used for logging test results. - @param tests: A list of TestCase objects representing the test cases to be executed. - @return: An integer representing the exit status: 0 if all tests pass, otherwise 112. - """ - passed = True - logger.log("test-count", len(tests)) - for idx, test in enumerate(tests): - logger.log(f"test.{idx}.spec", test.spec) - good, message = run_single_test(pool, test) - if not good: - logger.log(f"test.{idx}.status", "fail") - logger.log(f"test.{idx}.error", message) - passed = False - else: - logger.log(f"test.{idx}.status", "pass") - if message: - logger.log(f"test.{idx}.message", message) - - if passed: - logger.log("check", "pass") - return 0 - else: - logger.log("check", "fail") - return 112 - - -def _run_single_benchmark(test: TestCase, recheck: bool, max_repeats: int, max_time_ns: float) -> Stats | Any: - """ - Runs one benchmark. Do not call directly. - """ - from submission import custom_kernel - - durations = [] - # generate input data once - data = generate_input(**test.args) - check_copy = _clone_data(data) - # first, one obligatory correctness check - output = custom_kernel(data) - good, message = wrap_check_implementation(check_copy, output) - if not good: - return message - - # now, do multiple timing runs without further correctness testing - # there is an upper bound of 100 runs, and a lower bound of 3 runs; - # otherwise, we repeat until we either measure at least 10 full seconds, - # or the relative error of the mean is below 1%. 
- - bm_start_time = time.perf_counter_ns() - for i in range(max_repeats): - if recheck: - # ensure we use a different seed for every benchmark - if "seed" in test.args: - test.args["seed"] += 13 - - data = generate_input(**test.args) - check_copy = _clone_data(data) - torch.cuda.synchronize() - clear_l2_cache() - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - start_event.record() - output = custom_kernel(data) - end_event.record() - torch.cuda.synchronize() - - if recheck: - good, message = check_implementation(check_copy, output) - if not good: - return message - - del output - durations.append(start_event.elapsed_time(end_event) * 1e6) - - if i > 1: - total_bm_duration = time.perf_counter_ns() - bm_start_time - stats = calculate_stats(durations) - # stop if either - # a) relative error dips below 0.1% - # b) we exceed the total time limit for benchmarking the kernel - # c) we exceed 2 minutes of total wallclock time. - if stats.err / stats.mean < 0.001 or stats.mean * stats.runs > max_time_ns or total_bm_duration > 120e9: - break - - return calculate_stats(durations) - - -def run_single_benchmark(pool: multiprocessing.Pool, test: TestCase, recheck: bool, max_repeats: int, - max_time_ns: float): - """ - For a particular test case, check correctness (if applicable) and grab runtime results. - - @param pool: Process on which the benchmark will be launched. - @param test: TestCase object. - @param recheck: Flag for whether to explicitly check functional correctness. - @param max_repeats: Number of trials to repeat. - @param max_time_ns: Timeout time in nanoseconds. - @return: A Stats object for this particular benchmark case or an error if the test fails. - """ - return pool.apply(_run_single_benchmark, (test, recheck, max_repeats, max_time_ns)) - - -def run_benchmarking(logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase]): - """ - Executes benchmarking code for a CUDA Kernel and logs runtimes. - - @param logger: A PopcornOutput object used for logging benchmark results. - @param pool: Process on which the benchmarks will be launched. - @param tests: A list of TestCase objects representing the test cases to be benchmarked. - @return: An integer representing the exit status: 0 if all benchmarks pass, otherwise 112. - """ - # warm up - run_single_benchmark(pool, tests[0], False, 100, 10e7) - - passed = True - logger.log("benchmark-count", len(tests)) - for idx, test in enumerate(tests): - logger.log(f"benchmark.{idx}.spec", test.spec) - result = run_single_benchmark(pool, test, False, 1000, 50e9) - if isinstance(result, Stats): - for field in dataclasses.fields(Stats): - logger.log(f"benchmark.{idx}.{field.name}", getattr(result, field.name)) - else: - passed = False - logger.log(f"benchmark.{idx}.status", "fail") - logger.log(f"benchmark.{idx}.error", result) - - if passed: - logger.log("check", "pass") - return 0 - else: - logger.log("check", "fail") - return 112 - - -def run_single_profile(test: TestCase) -> str: - """ - Runs a single test case. 
Do not call directly - """ - from submission import custom_kernel - from torch.profiler import profile, record_function, ProfilerActivity - data = generate_input(**test.args) - torch.cuda.synchronize() - - with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: - submission_output = custom_kernel(_clone_data(data)) - torch.cuda.synchronize() - return prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=20) - - -def run_profiling(logger: PopcornOutput, tests: list[TestCase]): - logger.log("benchmark-count", len(tests)) - for idx, test in enumerate(tests): - logger.log(f"benchmark.{idx}.spec", test.spec) - report = run_single_profile(test) - logger.log(f"benchmark.{idx}.report", base64.b64encode(report.encode("utf-8"), b"+*").decode("utf-8")) - logger.log("check", "pass") - return 0 - - -def main(): - fd = os.getenv("POPCORN_FD") - if not fd: - return 111 - - if len(sys.argv) < 3: - return 2 - - mode = sys.argv[1] - seed = os.getenv("POPCORN_SEED") - os.unsetenv("POPCORN_SEED") - seed = int(seed) if seed else None - set_seed(seed or 42) - tests = get_test_cases(sys.argv[2], seed) - - with PopcornOutput(int(fd)) as logger: - import multiprocessing - mp_context = multiprocessing.get_context('spawn') - with mp_context.Pool(1) as pool: - if mode == "test": - return run_testing(logger, pool, tests) - if mode == "benchmark": - return run_benchmarking(logger, pool, tests) - - if mode == "leaderboard": - # warmup - run_single_benchmark(pool, tests[0], False, 100, 1e7) - logger.log("benchmark-count", len(tests)) - passed = True - for i in range(len(tests)): - result = run_single_benchmark(pool, tests[i], True, 100, 30e9) - logger.log(f"benchmark.{i}.spec", tests[i].spec) - if isinstance(result, Stats): - for field in dataclasses.fields(Stats): - logger.log(f"benchmark.{i}.{field.name}", getattr(result, field.name)) - else: - passed = False - logger.log(f"benchmark.{i}.status", "fail") - logger.log(f"benchmark.{i}.error", str(result)) # TODO: Make sure result implements __str__? - break - - logger.log("check", "pass" if passed else "fail") - elif mode == "profile": - run_profiling(logger, tests) - else: - # TODO: Implement script mode - return 2 - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/problems/amd_202602/mixed-mla/README.md b/problems/amd_202602/mixed-mla/README.md deleted file mode 100644 index 24cae9ce..00000000 --- a/problems/amd_202602/mixed-mla/README.md +++ /dev/null @@ -1,200 +0,0 @@ -# MLA (Multi-head Latent Attention) Decode Kernel - -## Description - -Implement a custom MLA attention decode kernel optimized for MI355X. - -This is the **inner attention kernel** from DeepSeek R1's `forward_absorb` MLA path. -The absorbed query and compressed KV cache are provided directly — you implement the -attention computation with variable-length batching. - -The reference uses **aiter MLA a8w8 decode kernel** (`mla_decode_fwd`, fp8 Q + fp8 KV, -persistent mode). On MI355X, a8w8 is ~2-3x faster than bf16 with negligible accuracy loss. -The reference quantizes Q to fp8 on-the-fly and uses pre-quantized fp8 KV from `kv_data["fp8"]`. 
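For orientation, here is a minimal, unoptimized sketch of the computation a submission has to reproduce, written in plain PyTorch against the bf16 KV cache (the helper name `naive_mla_decode` is illustrative, not part of the task API). It follows the layout described below: the full 576 dims of the KV buffer act as keys, the first 512 dims act as values, and batching is segmented via the indptr arrays; for the decode cases here (q_seq_len = 1) no causal mask is needed.

```python
import torch

def naive_mla_decode(q, kv_bf16, qo_indptr, kv_indptr, kv_lora_rank=512, sm_scale=576 ** -0.5):
    # q: (total_q, num_heads, 576) bf16; kv_bf16: (total_kv, 1, 576) bf16
    out = []
    for i in range(qo_indptr.numel() - 1):
        qi = q[qo_indptr[i]:qo_indptr[i + 1]].float()           # (seq_q, nhead, 576)
        kv = kv_bf16[kv_indptr[i]:kv_indptr[i + 1], 0].float()  # (seq_kv, 576)
        # full 576 dims act as keys; first 512 dims act as values
        scores = torch.einsum("qhd,kd->hqk", qi * sm_scale, kv).softmax(dim=-1)
        oi = torch.einsum("hqk,kv->qhv", scores, kv[:, :kv_lora_rank])
        out.append(oi.to(torch.bfloat16))
    return torch.cat(out, dim=0)
```

Competitive submissions replace this per-segment loop with a fused kernel that operates on the fp8 or mxfp4 KV cache directly.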
- -## DeepSeek R1 Forward-Absorb MLA Config - -| Parameter | Value | Notes | -|---|---|---| -| num_heads | 16 | Query heads (after TP split) | -| num_kv_heads | 1 | Single shared latent KV head | -| kv_lora_rank | 512 | Latent dimension | -| qk_rope_head_dim | 64 | RoPE embedding dimension | -| qk_head_dim | 576 | kv_lora_rank + qk_rope_head_dim (absorbed q/k dim) | -| v_head_dim | 512 | = kv_lora_rank (output dim) | -| sm_scale | 1/sqrt(576) | | -| q dtype | bfloat16 | Input always bf16; reference quantizes to fp8 on-the-fly | -| kv dtype | bf16 / fp8 / mxfp4 | All three provided simultaneously | -| mode | decode | q_seq_len=1, kv_seq_len up to 8k | - -## Reference Kernel - -The reference (`ref_kernel`) is configurable via two globals in `reference.py`: - -| `Q_DTYPE` | `KV_DTYPE` | Aiter kernel dispatched | Description | -|---|---|---|---| -| `"fp8"` (default) | `"fp8"` (default) | `mla_a8w8_qh16_qseqlen1_gqaratio16_ps` | fp8 Q + fp8 KV — fastest | -| `"bf16"` | `"fp8"` | `mla_a16w8_qh16_m16x4_n16x1_coex0_mask1_ps` | bf16 Q + fp8 KV | -| `"bf16"` | `"bf16"` | `mla_a16w16_qh16_m16x4_n16x1_coex0_mask1_ps` | bf16 Q + bf16 KV — highest precision | - -**Note**: `Q_DTYPE="fp8"` + `KV_DTYPE="bf16"` is not a valid combination (no a8w16 kernel exists). - -### Reference Latency (MI355X) - -| Case | a8w8 (us) | a16w16 (us) | a8w8 speedup | -|---|---|---|---| -| bs=4, kv=1k | ~118 | ~162 | 1.4x | -| bs=4, kv=8k | ~113 | ~177 | 1.6x | -| bs=64, kv=8k | ~171 | ~353 | 2.1x | -| bs=256, kv=8k | ~349 | ~814 | 2.3x | - -## KV Buffer Format (forward_absorb) - -The compressed KV buffer has `qk_head_dim=576` dimensions: -- **Full 576 dims** are used as **keys** (for Q@K^T score computation) -- **First 512 dims** (kv_lora_rank) are used as **values** (for output computation) - -## KV Cache Quantization - -| dtype | kv_buffer | kv_scale | Quantization | Bandwidth | -|---|---|---|---|---| -| bf16 | bfloat16 `(total_kv, 1, 576)` | None | No quantization | 1x | -| fp8 | fp8 `(total_kv, 1, 576)` | scalar float32 | Dynamic per-tensor (sglang `scaled_fp8_quant`) | 2x savings | -| mxfp4 | fp4x2 `(total_kv, 1, 288)` | fp8_e8m0 `(total_kv, N_blocks)` | Block-32 MXFP4 (aiter `dynamic_mxfp4_quant`) | 4x savings | - -### FP8 quantization (sglang `scaled_fp8_quant`) - -- **Granularity**: per-tensor -- **Scale**: `kv_scale = max(abs(kv_bf16)) / fp8_max` -- **Quantize**: `kv_fp8 = (kv_bf16 / kv_scale).clamp(...).to(fp8)` -- **Dequantize**: `kv_bf16 ≈ kv_fp8.to(bf16) * kv_scale` -- **kv_scale**: scalar float32 tensor - -### MXFP4 quantization (aiter `dynamic_mxfp4_quant`) - -- **Granularity**: per-block of 32 elements -- **FP4 format**: E2M1 — values `[0, 0.5, 1, 1.5, 2, 3, 4, 6]`, max = 6.0 -- **Scale format**: E8M0 — exponent-only scale stored in `aiter.dtypes.fp8_e8m0` -- **Packing**: 2 FP4 values packed per byte (low nibble = even index, high nibble = odd index) -- **kv_buffer**: `(total_kv, 1, 288)` in `aiter.dtypes.fp4x2` — packed FP4 data -- **kv_scale**: `(total_kv, N_blocks)` in `aiter.dtypes.fp8_e8m0` — per-block E8M0 scale factors -- **Dequantize**: `aiter.utility.fp4_utils.mxfp4_to_f32` + `e8m0_to_f32` for block-wise scaling - -### aiter dtype reference - -| Logical type | aiter dtype | PyTorch native (if available) | Fallback | -|---|---|---|---| -| fp4x2 | `aiter.dtypes.fp4x2` | `torch.float4_e2m1fn_x2` | `torch.uint8` | -| fp8_e8m0 | `aiter.dtypes.fp8_e8m0` | `torch.float8_e8m0fnu` | `torch.uint8` | -| fp8 | `aiter.dtypes.fp8` | `torch.float8_e4m3fnuz` (gfx942) / `torch.float8_e4m3fn` (gfx950+) | `torch.uint8` 
| - -## Input - -A tuple `(q, kv_data, qo_indptr, kv_indptr, config)`: - -``` -q: (total_q, 16, 576) bfloat16 — absorbed queries -kv_data: dict with three KV cache formats (see below) -qo_indptr: (batch_size + 1,) int32 — query segment pointers -kv_indptr: (batch_size + 1,) int32 — KV segment pointers -config: dict — MLA parameters -``` - -### kv_data dict - -All three KV cache formats are provided simultaneously. Each entry is either a -`Tensor` (bf16) or a `(Tensor, Tensor)` tuple (quantized buffer + scale): - -```python -kv_data = { - "bf16": kv_buffer_bf16, # Tensor (total_kv, 1, 576) bfloat16 - "fp8": (kv_buffer_fp8, kv_scale_fp8), # (fp8 Tensor, scalar float32) - "mxfp4": (kv_buffer_mxfp4, kv_scale_mxfp4), # (fp4x2 Tensor, fp8_e8m0 Tensor) -} -``` - -### config dict - -```python -config = { - "batch_size": int, - "num_heads": 16, - "num_kv_heads": 1, - "qk_head_dim": 576, - "kv_lora_rank": 512, - "qk_rope_head_dim": 64, - "v_head_dim": 512, - "q_seq_len": 1, - "kv_seq_len": int, # varies per test case (1024 or 8192) - "sm_scale": 0.04166..., # 1/sqrt(576) -} -``` - -## Output - -``` -attention_output: (total_q, 16, 512) bfloat16 -``` - -## Optimization Opportunities - -The reference is already a highly optimized aiter a8w8 persistent kernel. To beat it, consider: - -1. **MXFP4 KV cache**: 4x bandwidth savings over bf16, 2x over fp8. Two strategies: - - **Strategy A — Fuse dequantization with attention (keep Q in bf16/fp8):** - Load quantized KV tiles from HBM, dequantize in registers/LDS to bf16, and - immediately compute QK^T and softmax·V — never writing the decompressed KV back - to HBM. This eliminates the extra read/write of the bf16 intermediate buffer, - roughly quartering the memory traffic for mxfp4 compared to the naive - dequant-then-attend approach. - - **Strategy B — Quantize Q to match KV precision (full low-precision compute):** - Dynamically quantize Q from bf16 → mxfp4 (per-block scaling), then compute QK^T - entirely in fp4×fp4 using MFMA instructions on MI355X. The softmax is still done - in fp32 for numerical stability, and V accumulation uses fp4×fp4 → fp32. This - trades a small amount of accuracy for significantly higher throughput on the - matrix units. - -2. **Custom split-K / split-batch scheduling**: the aiter kernel uses 32-way KV splits - with reduce; a different split strategy or tile size may be more efficient for certain - batch/seq_len combinations. - -3. **MQA pattern**: 1 KV head shared across 16 query heads — minimize redundant KV loads - by loading KV once and broadcasting across all query heads in shared memory/LDS. - -4. **Variable-length batching**: indptr-based segmented attention across batch elements. - -5. **Split K/V from buffer**: full 576 dims for keys, first 512 for values — potential - for separate tiling strategies for the score and output stages. - -## Accuracy - -Submissions are checked against the a8w8 reference with `rtol=2e-02, atol=8e-03`. - -Measured accuracy of different approaches vs bf16 torch ground truth: - -| Approach | max abs diff | Notes | -|---|---|---| -| aiter a8w8 (reference) | 2.6e-05 — 8.0e-05 | fp8 quantization + kernel accumulation | -| torch fp8 (scaled_mm) | 2e-06 — 1.5e-05 | Closest to bf16 | -| torch mxfp4 | 2.1e-04 — 8.3e-04 | 4-bit quantization noise | - -All approaches are well within the tolerance. - -## Benchmark Cases - -All three KV formats (bf16, fp8, mxfp4) are provided in every test case. 
- -| batch_size | q_seq_len | kv_seq_len | -|---|---|---| -| 4 | 1 | 1024 | -| 4 | 1 | 8192 | -| 32 | 1 | 1024 | -| 32 | 1 | 8192 | -| 64 | 1 | 1024 | -| 64 | 1 | 8192 | -| 256 | 1 | 1024 | -| 256 | 1 | 8192 | - -Ranking is by **geometric mean** of benchmark latencies. diff --git a/problems/amd_202602/mixed-mla/reference.py b/problems/amd_202602/mixed-mla/reference.py deleted file mode 100644 index e1d91a72..00000000 --- a/problems/amd_202602/mixed-mla/reference.py +++ /dev/null @@ -1,378 +0,0 @@ -""" -Reference implementation for MLA (Multi-head Latent Attention) decode kernel. - -Uses aiter MLA kernels (mla_decode_fwd) as the reference. -DeepSeek R1 forward_absorb MLA: absorbed q (576), compressed kv_buffer (576), -output v_head_dim = kv_lora_rank = 512. - -The input provides: - q: (total_q, num_heads, 576) bfloat16 — absorbed query (num_heads = 128 // tp) - kv_data: dict with KV cache in three formats: - "bf16": Tensor (total_kv, 1, 576) bfloat16 — highest precision - "fp8": (Tensor, Tensor) kv_buffer fp8 + scalar scale — per-tensor quantized - "mxfp4": (Tensor, Tensor) kv_buffer fp4x2 + fp8_e8m0 — block-32 quantized - The reference quantizes Q to fp8 on-the-fly inside ref_kernel. - -The reference kernel quantizes Q to fp8 on-the-fly and uses fp8 KV (a8w8 kernel), -which is ~2-3x faster than bf16 on MI355X with negligible accuracy loss. - -Decode only — persistent mode with get_mla_metadata_v1. -""" - -import torch -import torch.nn.functional as F -from task import input_t, output_t -from utils import make_match_reference - -from aiter.mla import mla_decode_fwd -from aiter import dtypes as aiter_dtypes -from aiter import get_mla_metadata_info_v1, get_mla_metadata_v1 -from aiter.utility.fp4_utils import ( - dynamic_mxfp4_quant, - mxfp4_to_f32, - e8m0_to_f32, -) - -# --------------------------------------------------------------------------- -# DeepSeek R1 latent MQA constants (forward_absorb path) -# https://huggingface.co/deepseek-ai/DeepSeek-R1-0528/blob/main/config.json -# --------------------------------------------------------------------------- -TOTAL_NUM_HEADS = 128 -NUM_KV_HEADS = 1 -KV_LORA_RANK = 512 -QK_ROPE_HEAD_DIM = 64 -QK_HEAD_DIM = KV_LORA_RANK + QK_ROPE_HEAD_DIM # 576 -V_HEAD_DIM = KV_LORA_RANK # 512 -SM_SCALE = 1.0 / (QK_HEAD_DIM ** 0.5) - -PAGE_SIZE = 1 -NUM_KV_SPLITS = 32 - -# FP8 dtype (platform-specific via aiter) -FP8_DTYPE = aiter_dtypes.fp8 - -# Query dtype for the reference kernel: "fp8" or "bf16" -Q_DTYPE = "fp8" - -# KV cache dtype for the reference kernel: "fp8" or "bf16" -KV_DTYPE = "fp8" - - -# --------------------------------------------------------------------------- -# FP8 quantization (sglang style: dynamic per-tensor) -# --------------------------------------------------------------------------- - -def quantize_fp8(tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - """ - Dynamic per-tensor FP8 quantization (following sglang scaled_fp8_quant). - - Args: - tensor: bf16 tensor to quantize - - Returns: - (fp8_tensor, scale) where scale is a scalar float32 tensor. 
- Dequantize: fp8_tensor.to(bf16) * scale - """ - finfo = torch.finfo(FP8_DTYPE) - amax = tensor.abs().amax().clamp(min=1e-12) - scale = amax / finfo.max - fp8_tensor = (tensor / scale).clamp(min=finfo.min, max=finfo.max).to(FP8_DTYPE) - return fp8_tensor, scale.to(torch.float32).reshape(1) - - -# --------------------------------------------------------------------------- -# MXFP4 quantization (aiter native: block-32, fp4x2 + fp8_e8m0 dtypes) -# Uses aiter.utility.fp4_utils.dynamic_mxfp4_quant -# --------------------------------------------------------------------------- - -def quantize_mxfp4(tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - """ - MXFP4 block-wise quantization using aiter's dynamic_mxfp4_quant. - - Block size = 32. Each block gets an E8M0 scale factor. - Two FP4 E2M1 values are packed per byte. - - Args: - tensor: bf16 tensor of shape [B, M, N] (N must be divisible by 32) - - Returns: - (fp4_data, scale_e8m0) - - fp4_data: shape [B, M, N//2] in aiter_dtypes.fp4x2 - - scale_e8m0: shape [B*M, ceil(N/32)] padded, in aiter_dtypes.fp8_e8m0 - """ - orig_shape = tensor.shape # (B, M, N) - B, M, N = orig_shape - - # dynamic_mxfp4_quant expects 2D: (B*M, N) - tensor_2d = tensor.reshape(B * M, N) - fp4_data_2d, scale_e8m0 = dynamic_mxfp4_quant(tensor_2d) - - # Reshape fp4_data back to 3D: (B, M, N//2) - fp4_data = fp4_data_2d.view(B, M, N // 2) - - return fp4_data, scale_e8m0 - - -def dequantize_mxfp4( - fp4_data: torch.Tensor, - scale_e8m0: torch.Tensor, - orig_shape: tuple, - dtype: torch.dtype = torch.bfloat16, -) -> torch.Tensor: - """ - Dequantize MXFP4 tensor using aiter utilities. - - Note: dynamic_mxfp4_quant may pad both row and block dimensions in scale_e8m0. - We trim scales to match the actual data dimensions. - - Args: - fp4_data: packed FP4 data, shape [B, M, N//2] in fp4x2 or uint8 - scale_e8m0: E8M0 block scale factors (possibly padded) in fp8_e8m0 - orig_shape: original (B, M, N) for reshaping - dtype: output dtype - - Returns: - Dequantized tensor of shape orig_shape. - """ - B, M, N = orig_shape - num_rows = B * M - block_size = 32 - num_blocks = N // block_size # actual blocks needed (e.g. 
576/32 = 18) - - # Unpack FP4 to float32: mxfp4_to_f32 expects (..., N//2) -> (..., N) - fp4_data_2d = fp4_data.reshape(num_rows, N // 2) - float_vals = mxfp4_to_f32(fp4_data_2d) # (num_rows, N) - - # Convert E8M0 scales to float32 and trim padded dimensions - scale_f32 = e8m0_to_f32(scale_e8m0) # (padded_rows, padded_blocks) - scale_f32 = scale_f32[:num_rows, :num_blocks] # (num_rows, num_blocks) - - # Apply block scales - float_vals_blocked = float_vals.view(num_rows, num_blocks, block_size) - scaled = float_vals_blocked * scale_f32.unsqueeze(-1) - - return scaled.view(B, M, N).to(dtype) - - -# --------------------------------------------------------------------------- -# Persistent mode metadata helpers -# --------------------------------------------------------------------------- - -def _make_mla_decode_metadata( - batch_size: int, - max_q_len: int, - nhead: int, - nhead_kv: int, - q_dtype: torch.dtype, - kv_dtype: torch.dtype, - qo_indptr: torch.Tensor, - kv_indptr: torch.Tensor, - kv_last_page_len: torch.Tensor, - num_kv_splits: int = NUM_KV_SPLITS, -): - """Allocate and populate work buffers for persistent mla_decode_fwd.""" - info = get_mla_metadata_info_v1( - batch_size, max_q_len, nhead, q_dtype, kv_dtype, - is_sparse=False, fast_mode=False, - num_kv_splits=num_kv_splits, intra_batch_mode=True, - ) - work = [torch.empty(s, dtype=t, device="cuda") for s, t in info] - (work_metadata, work_indptr, work_info_set, - reduce_indptr, reduce_final_map, reduce_partial_map) = work - - # Populate the metadata buffers - get_mla_metadata_v1( - qo_indptr, kv_indptr, kv_last_page_len, - nhead // nhead_kv, # num_heads_per_head_k - nhead_kv, # num_heads_k - True, # is_causal - work_metadata, work_info_set, work_indptr, - reduce_indptr, reduce_final_map, reduce_partial_map, - page_size=PAGE_SIZE, - kv_granularity=max(PAGE_SIZE, 16), - max_seqlen_qo=max_q_len, - uni_seqlen_qo=max_q_len, - fast_mode=False, - max_split_per_batch=num_kv_splits, - intra_batch_mode=True, - dtype_q=q_dtype, - dtype_kv=kv_dtype, - ) - - return { - "work_meta_data": work_metadata, - "work_indptr": work_indptr, - "work_info_set": work_info_set, - "reduce_indptr": reduce_indptr, - "reduce_final_map": reduce_final_map, - "reduce_partial_map": reduce_partial_map, - } - - -# --------------------------------------------------------------------------- -# Aiter reference kernel (decode only) -# --------------------------------------------------------------------------- - -def _aiter_mla_decode( - q: torch.Tensor, - kv_buffer: torch.Tensor, - qo_indptr: torch.Tensor, - kv_indptr: torch.Tensor, - config: dict, - q_scale: torch.Tensor | None = None, - kv_scale: torch.Tensor | None = None, -) -> torch.Tensor: - """ - MLA decode attention using aiter persistent-mode kernel. 
- - Supports multiple Q/KV dtype combinations: - - Q_DTYPE="fp8": fp8 Q + fp8 KV (a8w8) — fastest on MI355X - - Q_DTYPE="bf16": bf16 Q + bf16 KV (a16w16) — highest precision - - q: (total_q, num_heads, 576) fp8 or bf16 - kv_buffer: (total_kv, 1, 576) fp8 or bf16 - q_scale: scalar float32 (required for fp8 Q, None for bf16) - kv_scale: scalar float32 (required for fp8 KV, None for bf16) - """ - batch_size = config["batch_size"] - nq = config["num_heads"] - nkv = config["num_kv_heads"] - dq = config["qk_head_dim"] - dv = config["v_head_dim"] - q_seq_len = config["q_seq_len"] - - total_kv_len = int(kv_indptr[-1].item()) - kv_indices = torch.arange(total_kv_len, dtype=torch.int32, device="cuda") - - # Reshape kv_buffer to 4D for aiter: (total_kv, page_size, nhead_kv, dim) - kv_buffer_4d = kv_buffer.view(kv_buffer.shape[0], PAGE_SIZE, nkv, kv_buffer.shape[-1]) - - max_q_len = q_seq_len - kv_last_page_len = (kv_indptr[1:] - kv_indptr[:-1]).to(torch.int32) - - # Build persistent-mode metadata - meta = _make_mla_decode_metadata( - batch_size, max_q_len, nq, nkv, - q.dtype, kv_buffer.dtype, - qo_indptr, kv_indptr, kv_last_page_len, - num_kv_splits=NUM_KV_SPLITS, - ) - - o = torch.empty((q.shape[0], nq, dv), dtype=torch.bfloat16, device="cuda") - mla_decode_fwd( - q.view(-1, nq, dq), - kv_buffer_4d, - o, - qo_indptr, - kv_indptr, - kv_indices, - kv_last_page_len, - max_q_len, - page_size=PAGE_SIZE, - nhead_kv=nkv, - sm_scale=SM_SCALE, - logit_cap=0.0, - num_kv_splits=NUM_KV_SPLITS, - q_scale=q_scale, - kv_scale=kv_scale, - intra_batch_mode=True, - **meta, - ) - return o - - -# --------------------------------------------------------------------------- -# generate_input / ref_kernel / check_implementation -# --------------------------------------------------------------------------- - -def generate_input(batchsize: int, qseqlen: int, kvseqlen: int, tp: int, seed: int) -> input_t: - """ - Generate absorbed q and compressed kv_buffer for MLA decode. - - Args: - tp: tensor parallelism degree (4 or 8). num_heads = TOTAL_NUM_HEADS // tp. 
- - Returns all three KV cache formats in kv_data dict: - kv_data = { - "bf16": Tensor — (total_kv, 1, 576) bfloat16 - "fp8": (Tensor, Tensor) — kv_buffer fp8 + scalar scale - "mxfp4": (Tensor, Tensor) — kv_buffer fp4x2 + fp8_e8m0 scale - } - """ - assert TOTAL_NUM_HEADS % tp == 0, f"TOTAL_NUM_HEADS ({TOTAL_NUM_HEADS}) must be divisible by tp ({tp})" - num_heads = TOTAL_NUM_HEADS // tp - - gen = torch.Generator(device="cuda") - gen.manual_seed(seed) - - total_q = batchsize * qseqlen - total_kv = batchsize * kvseqlen - - # Absorbed query: (total_q, num_heads, 576) bf16 - q = torch.randn( - (total_q, num_heads, QK_HEAD_DIM), - dtype=torch.bfloat16, device="cuda", generator=gen, - ) * 0.02 - - # Compressed KV buffer: (total_kv, 1, 576) bf16 — the source of truth - kv_buffer_bf16 = torch.randn( - (total_kv, NUM_KV_HEADS, QK_HEAD_DIM), - dtype=torch.bfloat16, device="cuda", generator=gen, - ) * 0.02 - - # Quantize KV to fp8 - kv_buffer_fp8, kv_scale_fp8 = quantize_fp8(kv_buffer_bf16) - - # Quantize KV to mxfp4 - kv_buffer_mxfp4, kv_scale_mxfp4 = quantize_mxfp4(kv_buffer_bf16) - - # All three KV formats: bf16 is a Tensor, fp8/mxfp4 are (Tensor, Tensor) tuples - kv_data = { - "bf16": kv_buffer_bf16, - "fp8": (kv_buffer_fp8, kv_scale_fp8), - "mxfp4": (kv_buffer_mxfp4, kv_scale_mxfp4), - } - - qo_indptr = torch.arange(0, batchsize + 1, dtype=torch.int32, device="cuda") * qseqlen - kv_indptr = torch.arange(0, batchsize + 1, dtype=torch.int32, device="cuda") * kvseqlen - - config = { - "batch_size": batchsize, - "num_heads": num_heads, - "num_kv_heads": NUM_KV_HEADS, - "qk_head_dim": QK_HEAD_DIM, - "kv_lora_rank": KV_LORA_RANK, - "qk_rope_head_dim": QK_ROPE_HEAD_DIM, - "v_head_dim": V_HEAD_DIM, - "q_seq_len": qseqlen, - "kv_seq_len": kvseqlen, - "sm_scale": SM_SCALE, - } - - return (q, kv_data, qo_indptr, kv_indptr, config) - - -def ref_kernel(data: input_t) -> output_t: - """Reference MLA decode attention. Uses Q_DTYPE and KV_DTYPE to select kernel variant.""" - q, kv_data, qo_indptr, kv_indptr, config = data - - # Resolve Q - if Q_DTYPE == "fp8": - q_input, q_scale = quantize_fp8(q) - else: - q_input, q_scale = q, None - - # Resolve KV - if KV_DTYPE == "fp8": - kv_buffer_fp8, kv_scale = kv_data["fp8"] - kv_input = kv_buffer_fp8 - else: - kv_input, kv_scale = kv_data["bf16"], None - - return _aiter_mla_decode( - q_input, kv_input, qo_indptr, kv_indptr, config, - q_scale=q_scale, kv_scale=kv_scale, - ) - - -check_implementation = make_match_reference(ref_kernel, rtol=1e-02, atol=1e-02) diff --git a/problems/amd_202602/mixed-mla/submission.py b/problems/amd_202602/mixed-mla/submission.py deleted file mode 100644 index e85525b8..00000000 --- a/problems/amd_202602/mixed-mla/submission.py +++ /dev/null @@ -1,186 +0,0 @@ -""" -MLA (Multi-head Latent Attention) decode kernel — submission template. - -Implement custom_kernel() to beat the aiter a8w8 reference (fp8 Q + fp8 KV). 
- -DeepSeek R1 forward_absorb MLA config: - total_num_heads = 128 (query heads before TP split) - num_heads = 128 // tp (query heads per device, tp=4 → 32, tp=8 → 16) - num_kv_heads = 1 (shared latent KV head) - kv_lora_rank = 512 (latent dim) - qk_rope_head_dim = 64 (RoPE dim) - qk_head_dim = 576 (kv_lora_rank + qk_rope_head_dim, absorbed q/k dim) - v_head_dim = 512 (= kv_lora_rank, output dim) - sm_scale = 1/sqrt(576) - -KV buffer format (forward_absorb): - - Full 576 dims used as keys (for Q@K^T score computation) - - First 512 dims (kv_lora_rank) used as values (for output computation) - -Input tuple: - q: (total_q, num_heads, 576) bfloat16 — absorbed query - kv_data: dict with three KV cache formats: - kv_data["bf16"] — Tensor (total_kv, 1, 576) bfloat16 - kv_data["fp8"] — (Tensor, Tensor): kv_buffer fp8 (total_kv,1,576) + scalar scale - kv_data["mxfp4"] — (Tensor, Tensor): kv_buffer fp4x2 (total_kv,1,288) + fp8_e8m0 scale - qo_indptr: (batch_size + 1,) int32 — query segment pointers - kv_indptr: (batch_size + 1,) int32 — KV segment pointers - config: dict with MLA parameters (includes num_heads computed from tp) - -Output: - attention output: (total_q, num_heads, 512) bfloat16 - -The reference uses aiter's a8w8 persistent MLA kernel (fp8 Q + fp8 KV), -which is ~2-3x faster than bf16. To beat it, consider: - 1. Use mxfp4 KV cache for even lower memory bandwidth - - Fuse dequantization with attention to avoid bf16 materialization - 2. Custom kernel with tighter memory access patterns - 3. MQA: 1 KV head shared across num_heads query heads — minimize redundant memory loads - 4. Variable-length batching: indptr-based segmented attention - 5. Split K/V from buffer: full 576 dims for keys, first 512 dims for values -""" - -import torch -import torch.nn.functional as F -from task import input_t, output_t - -from aiter import dtypes as aiter_dtypes -FP8_DTYPE = aiter_dtypes.fp8 - -# QKV dtype for custom_kernel dispatch: "bf16", "fp8", or "mxfp4" -QKV_DTYPE = "fp8" - - -# --------------------------------------------------------------------------- -# Dispatcher: select kernel based on QKV_DTYPE -# --------------------------------------------------------------------------- - -def custom_kernel(data: input_t) -> output_t: - """Dispatch to the appropriate kernel based on QKV_DTYPE.""" - if QKV_DTYPE == "fp8": - return custom_kernel_fp8(data) - elif QKV_DTYPE == "bf16": - return custom_kernel_bf16(data) - else: - raise ValueError(f"Invalid QKV_DTYPE: {QKV_DTYPE}") - -# --------------------------------------------------------------------------- -# FP8 quantization helper (per-tensor, sglang style) -# --------------------------------------------------------------------------- - -def quantize_fp8(tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - """Dynamic per-tensor FP8 quantization. 
Returns (fp8_tensor, scale).""" - finfo = torch.finfo(FP8_DTYPE) - amax = tensor.abs().amax().clamp(min=1e-12) - scale = amax / finfo.max - fp8_tensor = (tensor / scale).clamp(min=finfo.min, max=finfo.max).to(FP8_DTYPE) - return fp8_tensor, scale.to(torch.float32).reshape(1) - - -# --------------------------------------------------------------------------- -# Baseline: bf16 Q + bf16 KV — naive torch attention -# --------------------------------------------------------------------------- - -def custom_kernel_bf16(data: input_t) -> output_t: - q, kv_data, qo_indptr, kv_indptr, config = data - - num_heads = config["num_heads"] - kv_lora_rank = config["kv_lora_rank"] - sm_scale = config["sm_scale"] - - # This naive baseline uses bf16 KV directly. - # For better performance, use kv_data["fp8"] or kv_data["mxfp4"] - # which are (kv_buffer, kv_scale) tuples. See docstring for optimization opportunities. - kv_buffer_bf16 = kv_data["bf16"] - - batch_size = qo_indptr.shape[0] - 1 - out_list = [] - - for i in range(batch_size): - q_s, q_e = int(qo_indptr[i].item()), int(qo_indptr[i + 1].item()) - kv_s, kv_e = int(kv_indptr[i].item()), int(kv_indptr[i + 1].item()) - - qi = q[q_s:q_e] # (seq_q, nhead, 576) - kvc = kv_buffer_bf16[kv_s:kv_e, 0] # (seq_kv, 576) squeeze kv_heads dim - - # Key: full 576 dims; Value: first 512 dims (kv_lora_rank) - ki = kvc # (seq_kv, 576) — broadcast over heads - vi = kvc[:, :kv_lora_rank] # (seq_kv, 512) - - # Attention: (nhead, seq_q, 576) @ (576, seq_kv) → (nhead, seq_q, seq_kv) - qi_t = qi.float().permute(1, 0, 2) # (nhead, seq_q, 576) - scores = torch.matmul(qi_t * sm_scale, ki.float().T) # (nhead, seq_q, seq_kv) - - scores = F.softmax(scores, dim=-1) - - # Output: (nhead, seq_q, seq_kv) @ (seq_kv, 512) → (nhead, seq_q, 512) - oi = torch.matmul(scores, vi.float()) # (nhead, seq_q, 512) - oi = oi.permute(1, 0, 2) # (seq_q, nhead, 512) - out_list.append(oi.to(torch.bfloat16)) - - return torch.cat(out_list, dim=0) - - - - -# --------------------------------------------------------------------------- -# FP8 Q + FP8 KV — torch._scaled_mm based attention -# -# Quantize Q to fp8, use fp8 KV from kv_data["fp8"]. -# QK^T and softmax@V both use torch._scaled_mm for fp8×fp8 matmul. 
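# Note: in the sketch below only the QK^T GEMM actually goes through torch._scaled_mm; the
# softmax @ V product multiplies the fp32 probabilities against bf16 values with a plain
# torch.matmul, so only the score computation runs in fp8.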
-# --------------------------------------------------------------------------- - -def custom_kernel_fp8(data: input_t) -> output_t: - q, kv_data, qo_indptr, kv_indptr, config = data - - num_heads = config["num_heads"] - kv_lora_rank = config["kv_lora_rank"] - qk_head_dim = config["qk_head_dim"] - sm_scale = config["sm_scale"] - - # FP8 KV buffer and scale - kv_buffer_fp8, kv_scale_fp8 = kv_data["fp8"] - kv_fp8_2d = kv_buffer_fp8.view(-1, qk_head_dim) # (total_kv, 576) fp8 - - # Quantize Q to fp8 - q_fp8, q_scale = quantize_fp8(q) # q_fp8: (total_q, 16, 576) fp8 - - batch_size = qo_indptr.shape[0] - 1 - out_list = [] - - scale_one = torch.ones(1, dtype=torch.float32, device="cuda") - - for i in range(batch_size): - q_s, q_e = int(qo_indptr[i].item()), int(qo_indptr[i + 1].item()) - kv_s, kv_e = int(kv_indptr[i].item()), int(kv_indptr[i + 1].item()) - seq_q = q_e - q_s - seq_kv = kv_e - kv_s - - # Q: (seq_q * nhead, 576) fp8, K: (seq_kv, 576) fp8 - qi_fp8 = q_fp8[q_s:q_e].reshape(seq_q * num_heads, qk_head_dim) # (seq_q*16, 576) - ki_fp8 = kv_fp8_2d[kv_s:kv_e] # (seq_kv, 576) - - # QK^T via _scaled_mm: (seq_q*16, 576) @ (seq_kv, 576).T -> (seq_q*16, seq_kv) - # _scaled_mm expects (M,K) @ (N,K).T where b is row-major contiguous - raw_scores = torch._scaled_mm( - qi_fp8, ki_fp8.t(), - scale_a=q_scale, scale_b=kv_scale_fp8, - out_dtype=torch.float32, - ) - # raw_scores: (seq_q*16, seq_kv) - scores = raw_scores.view(seq_q, num_heads, seq_kv).permute(1, 0, 2) # (nhead, seq_q, seq_kv) - scores = scores * sm_scale - scores = F.softmax(scores, dim=-1) - - # V: first 512 dims of KV buffer (bf16 for softmax@V since scores are float) - kv_bf16 = kv_data["bf16"] - vi = kv_bf16[kv_s:kv_e, 0, :kv_lora_rank].float() # (seq_kv, 512) - - # softmax @ V: (nhead, seq_q, seq_kv) @ (seq_kv, 512) -> (nhead, seq_q, 512) - oi = torch.matmul(scores, vi) - oi = oi.permute(1, 0, 2) # (seq_q, nhead, 512) - out_list.append(oi.to(torch.bfloat16)) - - return torch.cat(out_list, dim=0) - - diff --git a/problems/amd_202602/mixed-mla/task.py b/problems/amd_202602/mixed-mla/task.py deleted file mode 100644 index a4c13320..00000000 --- a/problems/amd_202602/mixed-mla/task.py +++ /dev/null @@ -1,38 +0,0 @@ -import torch -from typing import TypeVar, TypedDict, Union - -# DeepSeek R1 MLA forward_absorb format: -# -# Input: (q, kv_data, qo_indptr, kv_indptr, config) -# q: (total_q, num_heads, qk_head_dim) bfloat16 -# num_heads = 128 // tp (tp=4 → 32, tp=8 → 16) -# kv_data: dict with three KV cache formats: -# "bf16": Tensor (total_kv, 1, 576) bfloat16 -# "fp8": (Tensor, Tensor) kv_buffer fp8 (total_kv, 1, 576) + scalar scale -# "mxfp4": (Tensor, Tensor) kv_buffer fp4x2 (total_kv, 1, 288) + fp8_e8m0 scale -# qo_indptr: (batch_size + 1,) int32 -# kv_indptr: (batch_size + 1,) int32 -# config: dict with MLA parameters (includes num_heads computed from tp) -# -# where qk_head_dim = kv_lora_rank + qk_rope_head_dim = 512 + 64 = 576 -# -# Output: attention output tensor (total_q, num_heads, v_head_dim) bfloat16 -# where v_head_dim = kv_lora_rank = 512 -# -# The kv_buffer stores the compressed KV representation: -# - Full 576 dims used as keys (for Q@K^T score computation) -# - First 512 dims (kv_lora_rank) used as values (for output computation) - -input_t = TypeVar( - "input_t", - bound=tuple[torch.Tensor, dict, torch.Tensor, torch.Tensor, dict], -) -output_t = TypeVar("output_t", bound=torch.Tensor) - - -class TestSpec(TypedDict): - batchsize: int - qseqlen: int - kvseqlen: int - tp: int - seed: int diff --git 
a/problems/amd_202602/mixed-mla/task.yml b/problems/amd_202602/mixed-mla/task.yml deleted file mode 100644 index a7d03ad7..00000000 --- a/problems/amd_202602/mixed-mla/task.yml +++ /dev/null @@ -1,98 +0,0 @@ -# name: mla-py - -files: - - {"name": "submission.py", "source": "@SUBMISSION@"} - - {"name": "task.py", "source": "task.py"} - - {"name": "utils.py", "source": "../utils.py"} - - {"name": "reference.py", "source": "reference.py"} - - {"name": "eval.py", "source": "../eval.py"} - -lang: "py" - -description: | - Implement a custom MLA (Multi-head Latent Attention) decode kernel optimized for MI355X. - - This is the inner attention kernel from DeepSeek R1's forward_absorb MLA path. - The absorbed query and compressed KV cache are provided directly — you only need to - implement the **attention** computation with variable-length batching (indptr). - - The reference uses aiter a8w8 MLA decode kernel (mla_decode_fwd, fp8 Q + fp8 KV, - persistent mode), which is ~2-3x faster than bf16 on MI355X. - - DeepSeek R1 forward_absorb MLA config: - - total_num_heads = 128 (query heads before TP split) - - num_heads = 128 // tp (query heads per device, tp=4 → 32, tp=8 → 16) - - num_kv_heads = 1 (shared latent KV head) - - kv_lora_rank = 512 - - qk_rope_head_dim = 64 - - qk_head_dim = 576 (kv_lora_rank + qk_rope_head_dim, absorbed q/k dim) - - v_head_dim = 512 (= kv_lora_rank, output dim) - - sm_scale = 1/sqrt(576) - - dtype: q=bfloat16 - - q_seq_len = 1 or 4, kv_seq_len up to 8k - - KV buffer format (forward_absorb): - - Full 576 dims are used as keys (for Q@K^T score computation) - - First 512 dims (kv_lora_rank) are used as values (for output computation) - - Input tuple: (q, kv_data, qo_indptr, kv_indptr, config) - - q: (total_q, num_heads, 576) bfloat16 — absorbed query - - kv_data: dict with three KV cache formats: - kv_data["bf16"] — Tensor (total_kv, 1, 576) bfloat16 - kv_data["fp8"] — (Tensor, Tensor): kv_buffer fp8 + scalar scale - kv_data["mxfp4"] — (Tensor, Tensor): kv_buffer fp4x2 + fp8_e8m0 scale - - qo_indptr: (batch_size+1,) int32 — query segment pointers - - kv_indptr: (batch_size+1,) int32 — KV segment pointers - - config: dict with MLA parameters (includes num_heads computed from tp) - - Return: - - attention output: (total_q, num_heads, 512) bfloat16 - - Key optimization opportunities: - 1. Use mxfp4 KV cache for even lower memory bandwidth (4x savings over bf16) - - Fuse dequantization with attention to skip bf16 materialization - 2. Custom kernel with tighter memory access patterns - 3. MQA: 1 KV head shared across num_heads query heads — minimize redundant memory loads - 4. q_seq_len=1 or 4, kv_seq_len up to 8k — memory-bound workload - 5. Variable-length batching: indptr-based segmented attention - 6. Split K/V from buffer: full 576 dims for keys, first 512 dims for values - - The ranking criteria is the geometric mean of the benchmark results. 
- -config: - main: "eval.py" - -templates: - Python: "submission.py" - -test_timeout: 1800 -benchmark_timeout: 1800 -ranked_timeout: 1800 - -tests: - # bs=4, tp=8 - - {"batchsize": 4, "qseqlen": 1, "kvseqlen": 1024, "tp": 8, "seed": 4220} - - {"batchsize": 4, "qseqlen": 4, "kvseqlen": 1024, "tp": 8, "seed": 4231} - # bs=32, tp=4 - - {"batchsize": 32, "qseqlen": 1, "kvseqlen": 1024, "tp": 4, "seed": 5412} - - {"batchsize": 32, "qseqlen": 4, "kvseqlen": 8192, "tp": 4, "seed": 5423} - # bs=128, tp=8 - - {"batchsize": 128, "qseqlen": 1, "kvseqlen": 8192, "tp": 8, "seed": 7816} - - {"batchsize": 128, "qseqlen": 4, "kvseqlen": 8192, "tp": 4, "seed": 7827} - -benchmarks: - # bs=4, tp=4 - - {"batchsize": 4, "qseqlen": 1, "kvseqlen": 1024, "tp": 4, "seed": 4237} - - {"batchsize": 4, "qseqlen": 4, "kvseqlen": 8192, "tp": 4, "seed": 4251} - # bs=32, tp=8 - - {"batchsize": 32, "qseqlen": 1, "kvseqlen": 8192, "tp": 8, "seed": 5415} - - {"batchsize": 32, "qseqlen": 4, "kvseqlen": 1024, "tp": 8, "seed": 5420} - # bs=32, tp=4 - - {"batchsize": 32, "qseqlen": 1, "kvseqlen": 1024, "tp": 4, "seed": 5432} - - {"batchsize": 32, "qseqlen": 4, "kvseqlen": 8192, "tp": 4, "seed": 5443} - # bs=128, tp=8 - - {"batchsize": 128, "qseqlen": 1, "kvseqlen": 8192, "tp": 8, "seed": 7816} - - {"batchsize": 128, "qseqlen": 4, "kvseqlen": 8192, "tp": 8, "seed": 7824} - - -ranking_by: "geom" diff --git a/problems/amd_202602/moe-mxfp4/README.md b/problems/amd_202602/moe-mxfp4/README.md deleted file mode 100644 index ab664f7d..00000000 --- a/problems/amd_202602/moe-mxfp4/README.md +++ /dev/null @@ -1,198 +0,0 @@ -# MXFP4 Mixture-of-Experts (MoE) Fused Kernel - -## Description - -Implement a DeepSeek-R1 style MXFP4 Mixture-of-Experts (MoE) fused kernel optimized for AMD Instinct MI355X GPU. - -The kernel fuses the complete MoE forward pass into a 2-stage pipeline: -1. **Stage 1**: MXFP4 GEMM (gate+up projection) + SwiGLU activation -2. **Stage 2**: MXFP4 GEMM (down projection) + weighted reduction across top-k experts - -The reference uses **AITER `fused_moe`** with `QuantType.per_1x32` (MXFP4 block scaling, block_size=32). - -## DeepSeek-R1 MoE Architecture - -| Parameter | Value | Notes | -|---|---|---| -| hidden_size | 7168 | Model hidden dimension | -| moe_intermediate_size | 2048 | Per-expert intermediate dimension | -| n_routed_experts | 256 | Routed experts (EP-off) or 32 (EP-on, 8-way split) | -| n_shared_experts | 1 | Always selected with weight=1.0 | -| top_k (routed) | 8 | Routed experts per token | -| total_top_k | 9 | 8 routed + 1 shared | -| MoE layers | 58 | Layers 3–60 | - -## Kernel Flow - -For each token `i` and each assigned expert `j`: - -``` -(1) Quant activations: hidden_states -> MXFP4 (aiter per-1x32 dynamic quantization) - -(2) Stage 1 GEMM + SwiGLU activation: - gate = x_i @ W_gate_j.T # [d_hidden] x [d_expert, d_hidden].T -> [d_expert] - up = x_i @ W_up_j.T # [d_hidden] x [d_expert, d_hidden].T -> [d_expert] - intermediate = SiLU(gate) * up # SwiGLU activation -> [d_expert] - (W_gate and W_up are fused as gate_up_weight: one a4w4 GEMM + fused activation) - -(3) Stage 2 GEMM: - expert_out = intermediate @ W_down_j.T # [d_expert] x [d_hidden, d_expert].T -> [d_hidden] - -(4) Weighted reduction: - output_i += w_ij * expert_out # accumulate across top_k experts -``` - -All weight GEMMs are **a4w4** (MXFP4 activations x MXFP4 weights, per-1x32 block scaling). -The AITER CK kernel fuses all of the above into a 2-stage pipeline across all tokens and experts. 
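As a shape-level illustration of steps (2)-(4), the loop below spells out the per-token dataflow in plain PyTorch, assuming the expert weights have already been dequantized to bf16 (`gate_w`/`up_w`: `[E, d_expert, d_hidden]`, `down_w`: `[E, d_hidden, d_expert]`; the function name is illustrative). It mirrors the pure-PyTorch reference in `reference.py` and is meant for understanding the math, not for performance.

```python
import torch
import torch.nn.functional as F

def moe_forward_dense(x, gate_w, up_w, down_w, topk_weights, topk_ids):
    # x: [M, d_hidden] bf16; topk_weights: [M, top_k] fp32; topk_ids: [M, top_k] int32
    out = torch.zeros_like(x)
    for i in range(x.shape[0]):
        for k in range(topk_ids.shape[1]):
            e = int(topk_ids[i, k])
            w = float(topk_weights[i, k])
            # Stage 1: gate/up projections + SwiGLU
            inter = F.silu(x[i] @ gate_w[e].T) * (x[i] @ up_w[e].T)  # [d_expert]
            # Stage 2: down projection, weighted accumulation across the top-k experts
            out[i] += w * (inter @ down_w[e].T)                      # [d_hidden]
    return out
```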
- -## Weight Layout & Pre-shuffling - -Weights are provided in two layouts: - -| Layout | Description | Use case | -|---|---|---| -| **Raw** | Original MXFP4 quantized weights | PyTorch reference / custom kernels | -| **Pre-shuffled** | `shuffle_weight(w, layout=(16,16))` + `e8m0_shuffle(scale)` | AITER CK kernel (tile-coalesced layout) | - -The (16,16) shuffle rearranges weight tiles for coalesced memory access by CK GEMM instructions. -Scale shuffling (`e8m0_shuffle`) reorders E8M0 block scales to match the shuffled weight layout. - -You may use either layout — raw weights if you implement your own tiling, or pre-shuffled weights -for direct use with AITER/CK kernels. - -## MXFP4 Quantization Details - -| Property | Value | -|---|---| -| FP4 format | E2M1 — values `[0, 0.5, 1, 1.5, 2, 3, 4, 6]`, max = 6.0 | -| Scale format | E8M0 — exponent-only (power-of-2 scale) | -| Block size | 32 elements per scale | -| Packing | 2 FP4 values per byte (`fp4x2`): low nibble = even index, high nibble = odd index | -| Padding | Dimensions padded to 256-alignment for CK kernel | - -### aiter dtype reference - -| Logical type | aiter dtype | PyTorch native (if available) | Fallback | -|---|---|---|---| -| fp4x2 | `aiter.dtypes.fp4x2` | `torch.float4_e2m1fn_x2` | `torch.uint8` | -| fp8_e8m0 | `aiter.dtypes.fp8_e8m0` | `torch.float8_e8m0fnu` | `torch.uint8` | - -## Input - -A tuple of tensors and a config dict: - -``` -(hidden_states, - gate_up_weight, down_weight, # fp4x2 raw - gate_up_weight_scale, down_weight_scale, # e8m0 raw - gate_up_weight_shuffled, down_weight_shuffled, # fp4x2 pre-shuffled - gate_up_weight_scale_shuffled, down_weight_scale_shuffled, # e8m0 pre-shuffled - topk_weights, topk_ids, - config) -``` - -### Tensor shapes - -| Tensor | Shape | Dtype | Notes | -|---|---|---|---| -| `hidden_states` | `[M, d_hidden]` | bfloat16 | Input activations (M = batch of tokens) | -| `gate_up_weight` | `[E, 2*d_expert_pad, d_hidden_pad//2]` | fp4x2 | Fused gate+up weights, raw | -| `down_weight` | `[E, d_hidden_pad, d_expert_pad//2]` | fp4x2 | Down projection weights, raw | -| `gate_up_weight_scale` | `[E, 2*d_expert_pad, d_hidden_pad//32]` | e8m0 | Block scales for gate_up, raw | -| `down_weight_scale` | `[E, d_hidden_pad, d_expert_pad//32]` | e8m0 | Block scales for down, raw | -| `gate_up_weight_shuffled` | `[E, 2*d_expert_pad, d_hidden_pad//2]` | fp4x2 | Pre-shuffled for CK | -| `down_weight_shuffled` | `[E, d_hidden_pad, d_expert_pad//2]` | fp4x2 | Pre-shuffled for CK | -| `gate_up_weight_scale_shuffled` | `[padded, flat]` | e8m0 | Pre-shuffled for CK | -| `down_weight_scale_shuffled` | `[padded, flat]` | e8m0 | Pre-shuffled for CK | -| `topk_weights` | `[M, total_top_k]` | float32 | Routing weights | -| `topk_ids` | `[M, total_top_k]` | int32 | Expert indices (see below) | - -### topk_ids format - -- First `n_experts_per_token` columns: routed expert IDs `[0, n_routed_experts)` -- Last `n_shared_experts` columns: shared expert IDs `[n_routed_experts, n_routed_experts + n_shared_experts)` -- Shared experts are always selected with weight = 1.0 - -### config dict - -```python -config = { - "d_hidden": int, # hidden dimension (e.g. 7168) - "d_expert": int, # expert intermediate dimension (e.g. 
2048 or 256) - "d_hidden_pad": int, # d_hidden padded to 256-alignment - "d_expert_pad": int, # d_expert padded to 256-alignment - "n_routed_experts": int, # number of routed experts - "n_shared_experts": int, # number of shared experts (1) - "n_experts_per_token": int, # routed top-k (8) - "total_top_k": int, # routed + shared (9) - "bs": int, # batch size (number of tokens) -} -``` - -## Output - -``` -output: [M, d_hidden] bfloat16 -``` - -## Reference Performance - -AITER `fused_moe` with MXFP4 (E includes shared expert, top_k = routed + shared): - -| bs | E | d_hidden | d_expert | top_k | time (us) | -|---|---|---|---|---|---| -| 4 | 257 | 7168 | 256 | 9 | 46.9 | -| 64 | 257 | 7168 | 256 | 9 | 187.7 | -| 256 | 257 | 7168 | 256 | 9 | 245.7 | -| 64 | 33 | 7168 | 2048 | 9 | 220.6 | -| 256 | 33 | 7168 | 2048 | 9 | 276.4 | -| 1024 | 33 | 7168 | 2048 | 9 | 572.2 | - -## Optimization Opportunities - -The AITER CK `fused_moe` kernel is already well-optimized. To beat it, consider: - -1. **Custom tiling / scheduling**: The CK kernel uses a fixed tile strategy. For small batch sizes - (bs=4) or highly skewed expert distributions, a custom schedule may reduce idle waves. - -2. **Activation quantization fusion**: The reference quantizes activations separately before the - GEMM. Fusing dynamic MXFP4 quantization into the Stage 1 GEMM prologue saves one global - memory round-trip. - -3. **Inter-stage fusion**: The reference runs Stage 1 and Stage 2 as separate kernel launches. - Fusing both stages (gate_up GEMM → SwiGLU → down GEMM → accumulate) into a single kernel - eliminates the intermediate buffer write/read between stages. - -4. **Expert-parallel wave scheduling**: With 257 experts but only 9 active per token, most - expert slots are empty. A work-stealing or compact-dispatch strategy can minimize wasted - wavefronts. - -5. **Shared expert fusion**: The shared expert is always selected for all tokens. It could be - computed as a dense GEMM (no routing overhead) and fused with the routed expert reduction. - -6. **Split-K for large M**: For bs=1024 with EP-on (E=33, d_expert=2048), the GEMMs are large - enough to benefit from split-K parallelism within each expert. - -## Accuracy - -Submissions are checked against the AITER reference with `rtol=1e-2, atol=1e-2`. - -## Benchmark Cases - -### EP-off (all 257 experts on 1 GPU, d_expert=256) - -| bs | E | d_hidden | d_expert | top_k | -|---|---|---|---|---| -| 4 | 257 | 7168 | 256 | 9 | -| 64 | 257 | 7168 | 256 | 9 | -| 256 | 257 | 7168 | 256 | 9 | - -### EP-on (EP=8, 33 experts per GPU, d_expert=2048) - -| bs | E | d_hidden | d_expert | top_k | -|---|---|---|---|---| -| 64 | 33 | 7168 | 2048 | 9 | -| 256 | 33 | 7168 | 2048 | 9 | -| 1024 | 33 | 7168 | 2048 | 9 | - -Ranking is by **geometric mean** of benchmark latencies. 
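As a concrete illustration of optimization opportunity 5 above: because the shared expert is applied to every token with weight 1.0, it can be pulled out of the routed dispatch and computed as three dense GEMMs over the whole batch, then added to the routed-expert output. A minimal sketch, assuming bf16 (dequantized) shared-expert weights and hypothetical names:

```python
import torch
import torch.nn.functional as F

def shared_expert_dense(x, gate_w_s, up_w_s, down_w_s):
    # x: [M, d_hidden]; gate_w_s / up_w_s: [d_expert, d_hidden]; down_w_s: [d_hidden, d_expert]
    inter = F.silu(x @ gate_w_s.T) * (x @ up_w_s.T)  # [M, d_expert]
    return inter @ down_w_s.T                        # [M, d_hidden]

# output = routed_moe_output + shared_expert_dense(hidden_states, gate_w_s, up_w_s, down_w_s)
```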
diff --git a/problems/amd_202602/moe-mxfp4/reference.py b/problems/amd_202602/moe-mxfp4/reference.py deleted file mode 100644 index 3bdb8fcd..00000000 --- a/problems/amd_202602/moe-mxfp4/reference.py +++ /dev/null @@ -1,323 +0,0 @@ -from utils import make_match_reference -from task import input_t, output_t -import torch -import torch.nn.functional as F -from typing import Dict, Tuple, Optional -import math - -from aiter import ActivationType, QuantType -from aiter.fused_moe import fused_moe -from aiter.utility import fp4_utils -from aiter.ops.shuffle import shuffle_weight - - -# ────────────────────────────────────────────────────────────────────── -# Constants -# ────────────────────────────────────────────────────────────────────── -MXFP4_BLOCK_SIZE = 32 -PAD_ALIGN = 256 - - -def _pad_to(x: int, align: int) -> int: - return (x + align - 1) // align * align - - -# ────────────────────────────────────────────────────────────────────── -# generate_input: produce all tensors needed by ref_kernel -# -# Models DeepSeek-R1 MoE layer shapes: -# - d_hidden = 7168 -# - d_expert = moe_intermediate_size (full=2048, or TP-split) -# - E_total = n_routed_experts + n_shared_experts (257 or 33) -# - top_k_total = nexpertspertoken + nsharedexperts (8+1=9) -# -# ────────────────────────────────────────────────────────────────────── -def generate_input( - dhidden: int, - dexpert: int, - nroutedexperts: int, - nexpertspertoken: int, - nsharedexperts: int, - bs: int, - seed: int, -) -> input_t: - d_hidden = dhidden - d_expert = dexpert - n_routed_experts = nroutedexperts - n_shared_experts = nsharedexperts - routed_top_k = nexpertspertoken - total_top_k = routed_top_k + n_shared_experts # e.g. 8 + 1 = 9 - E_total = n_routed_experts + n_shared_experts # e.g. 256 + 1 = 257 - M = bs # number of tokens - - # Padded dimensions (AITER MXFP4 requires 256-alignment) - d_hidden_pad = _pad_to(d_hidden, PAD_ALIGN) - d_expert_pad = _pad_to(d_expert, PAD_ALIGN) - - config = { - "d_hidden": d_hidden, - "d_expert": d_expert, - "d_hidden_pad": d_hidden_pad, - "d_expert_pad": d_expert_pad, - "n_routed_experts": n_routed_experts, - "n_shared_experts": n_shared_experts, - "n_experts_per_token": routed_top_k, - "total_top_k": total_top_k, - "bs": M, - } - - gen = torch.Generator(device='cuda') - gen.manual_seed(seed) - - # ── hidden_states [M, d_hidden] ── - hidden_states = torch.randn( - (M, d_hidden), device='cuda', dtype=torch.bfloat16, generator=gen, - ) - - # ── Router: softmax top-k (routed experts only) ── - router_weight = torch.randn( - (n_routed_experts, d_hidden), device='cuda', dtype=torch.bfloat16, generator=gen, - ) / math.sqrt(d_hidden) - router_logits = F.linear(hidden_states, router_weight) # [M, n_routed_experts] - scores = router_logits.softmax(dim=-1) - routed_weights, routed_ids = torch.topk( - scores, k=routed_top_k, dim=-1, sorted=False - ) - routed_weights = routed_weights.to(torch.float32) - routed_ids = routed_ids.to(torch.int32) - - # ── Append shared expert(s): always selected, weight = 1.0 ── - # Shared experts are indexed as n_routed_experts, n_routed_experts+1, ... 
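    # e.g. with n_routed_experts=256 and n_shared_experts=1, every token's topk_ids gains the
    # extra id 256 and topk_weights gains a trailing 1.0 column.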
- shared_ids = torch.arange( - n_routed_experts, E_total, device='cuda', dtype=torch.int32 - ).unsqueeze(0).expand(M, -1) # [M, n_shared_experts] - shared_weights = torch.ones( - (M, n_shared_experts), device='cuda', dtype=torch.float32 - ) - - topk_ids = torch.cat([routed_ids, shared_ids], dim=-1) # [M, total_top_k] - topk_weights = torch.cat([routed_weights, shared_weights], dim=-1) # [M, total_top_k] - - # ── Expert weights: bf16 -> quantize to MXFP4 ── - # Generate weights for ALL experts (routed + shared) - # gate_up = fused [gate_proj; up_proj] per expert: [2*d_expert_pad, d_hidden_pad] - # down = down_proj per expert: [d_hidden_pad, d_expert_pad] - gate_up_q_list, gate_up_s_list = [], [] - down_q_list, down_s_list = [], [] - - for _ in range(E_total): - # gate_proj + up_proj -> fused [2*d_expert_pad, d_hidden_pad] - gate_bf16 = torch.randn( - (d_expert_pad, d_hidden_pad), device='cuda', dtype=torch.bfloat16, generator=gen - ) / math.sqrt(d_hidden) - up_bf16 = torch.randn( - (d_expert_pad, d_hidden_pad), device='cuda', dtype=torch.bfloat16, generator=gen - ) / math.sqrt(d_hidden) - gate_up_bf16 = torch.cat([gate_bf16, up_bf16], dim=0) - - # down_proj -> [d_hidden_pad, d_expert_pad] - down_bf16 = torch.randn( - (d_hidden_pad, d_expert_pad), device='cuda', dtype=torch.bfloat16, generator=gen - ) / math.sqrt(d_expert) - - # Quantize to MXFP4 - gu_q, gu_s = fp4_utils.dynamic_mxfp4_quant(gate_up_bf16) - dn_q, dn_s = fp4_utils.dynamic_mxfp4_quant(down_bf16) - - gate_up_q_list.append(gu_q) - gate_up_s_list.append(gu_s) - down_q_list.append(dn_q) - down_s_list.append(dn_s) - - # Stack into [E_total, ...] tensors — raw (before shuffle) - gate_up_weight = torch.stack(gate_up_q_list) # [E_total, 2*d_expert_pad, d_hidden_pad//2] fp4x2 - gate_up_weight_scale = torch.stack(gate_up_s_list) # [E_total, 2*d_expert_pad, scale_K] e8m0 - down_weight = torch.stack(down_q_list) # [E_total, d_hidden_pad, d_expert_pad//2] fp4x2 - down_weight_scale = torch.stack(down_s_list) # [E_total, d_hidden_pad, scale_K] e8m0 - - # Pre-shuffled weight. You can also shuffle the weights yourself before calling the kernel. - gate_up_weight_shuffled = shuffle_weight(gate_up_weight.clone()) - down_weight_shuffled = shuffle_weight(down_weight.clone()) - gate_up_weight_scale_shuffled = fp4_utils.e8m0_shuffle(gate_up_weight_scale.reshape(E_total, -1)) - down_weight_scale_shuffled = fp4_utils.e8m0_shuffle(down_weight_scale.reshape(E_total, -1)) - - return ( - hidden_states, # [M, d_hidden] bf16 - gate_up_weight, # [E_total, 2*d_expert_pad, d_hidden_pad//2] fp4x2 (raw) - down_weight, # [E_total, d_hidden_pad, d_expert_pad//2] fp4x2 (raw) - gate_up_weight_scale, # [E_total, 2*d_expert_pad, scale_K] e8m0 (raw) - down_weight_scale, # [E_total, d_hidden_pad, scale_K] e8m0 (raw) - gate_up_weight_shuffled, # [E_total, 2*d_expert_pad, d_hidden_pad//2] fp4x2 (pre-shuffled) - down_weight_shuffled, # [E_total, d_hidden_pad, d_expert_pad//2] fp4x2 (pre-shuffled) - gate_up_weight_scale_shuffled, # [padded, flat] e8m0 (pre-shuffled) - down_weight_scale_shuffled, # [padded, flat] e8m0 (pre-shuffled) - topk_weights, # [M, total_top_k] float32 - topk_ids, # [M, total_top_k] int32 - config, - ) - - - - -# ────────────────────────────────────────────────────────────────────── -# ref_kernel_pytorch: pure PyTorch implementation (dequant + matmul) -# ────────────────────────────────────────────────────────────────────── -def _dequant_mxfp4(weight_fp4, scale_e8m0): - """ - Dequantize MXFP4 weight to float32. 
- - weight_fp4: [N, K//2] fp4x2 (raw, not shuffled) - scale_e8m0: [padded_N, ceil(K/32)] e8m0 (M-dim padded to 256-align by dynamic_mxfp4_quant) - - Returns: [N, K] float32 - """ - # fp4x2 -> float32 lookup: [N, K] - w_f32 = fp4_utils.mxfp4_to_f32(weight_fp4) # [N, K] - # e8m0 -> float32 power-of-2 scale: [padded_N, scale_K] - s_f32 = fp4_utils.e8m0_to_f32(scale_e8m0) # [padded_N, scale_K] - N, K = w_f32.shape - # Trim scale rows to match weight rows (scale M-dim is padded to 256) - s_f32 = s_f32[:N, :] - # Broadcast scale across block_size=32 columns - s_f32 = s_f32.repeat_interleave(MXFP4_BLOCK_SIZE, dim=-1)[:, :K] # [N, K] - return w_f32 * s_f32 - -# ────────────────────────────────────────────────────────────────────── -# ref_kernel_pytorch: pure PyTorch implementation (dequant + matmul) -# will not run. only for reference -# ────────────────────────────────────────────────────────────────────── -def ref_kernel_pytorch(data: input_t) -> output_t: - """ - Pure PyTorch reference: dequantize MXFP4 weights -> bf16 matmul -> SwiGLU -> matmul. - Uses the raw (un-shuffled) weights. - """ - ( - hidden_states, # [M, d_hidden] bf16 - gate_up_weight, # [E, 2*d_expert_pad, d_hidden_pad//2] fp4x2 - down_weight, # [E, d_hidden_pad, d_expert_pad//2] fp4x2 - gate_up_weight_scale, # [E, 2*d_expert_pad, scale_K] e8m0 - down_weight_scale, # [E, d_hidden_pad, scale_K] e8m0 - gate_up_weight_shuffled, - down_weight_shuffled, - gate_up_weight_scale_shuffled, - down_weight_scale_shuffled, - topk_weights, # [M, top_k] float32 - topk_ids, # [M, top_k] int32 - config, - ) = data - - d_hidden = config["d_hidden"] - d_expert = config["d_expert"] - d_hidden_pad = config["d_hidden_pad"] - d_expert_pad = config["d_expert_pad"] - M = hidden_states.shape[0] - top_k = topk_ids.shape[1] - E = gate_up_weight.shape[0] - - # Dequantize all expert weights to float32 - # gate_up: [E, 2*d_expert_pad, d_hidden_pad] -> trim to [E, 2*d_expert, d_hidden] - # down: [E, d_hidden_pad, d_expert_pad] -> trim to [E, d_hidden, d_expert] - gate_up_dq = torch.stack([ - _dequant_mxfp4(gate_up_weight[e], gate_up_weight_scale[e]) - for e in range(E) - ]) # [E, 2*d_expert_pad, d_hidden_pad] - gate_up_dq = gate_up_dq[:, :2 * d_expert, :d_hidden].to(torch.bfloat16) - - down_dq = torch.stack([ - _dequant_mxfp4(down_weight[e], down_weight_scale[e]) - for e in range(E) - ]) # [E, d_hidden_pad, d_expert_pad] - down_dq = down_dq[:, :d_hidden, :d_expert].to(torch.bfloat16) - - # Split gate_up -> gate [E, d_expert, d_hidden], up [E, d_expert, d_hidden] - gate_w, up_w = gate_up_dq.chunk(2, dim=1) # each [E, d_expert, d_hidden] - - # Per-token MoE forward - output = torch.zeros((M, d_hidden), dtype=torch.bfloat16, device=hidden_states.device) - - for i in range(M): - x = hidden_states[i] # [d_hidden] - for k in range(top_k): - eid = topk_ids[i, k].item() - w = topk_weights[i, k].item() - - # Stage 1: gate_proj + up_proj + SwiGLU - gate_out = F.silu(x @ gate_w[eid].T) # [d_expert] - up_out = x @ up_w[eid].T # [d_expert] - intermediate = gate_out * up_out # [d_expert] - - # Stage 2: down_proj - # down_dq[eid] is [d_hidden, d_expert], .T is [d_expert, d_hidden] - expert_out = intermediate @ down_dq[eid].T # [d_hidden] - - output[i] += w * expert_out - - return output - - - -# ────────────────────────────────────────────────────────────────────── -# ref_kernel: calls AITER fused_moe with MXFP4 quantized weights -# ────────────────────────────────────────────────────────────────────── -def ref_kernel(data: input_t) -> output_t: - """ - Reference 
implementation using AITER's fused_moe kernel with MXFP4 quantized weights. - - Input data tuple (E = n_routed_experts + n_shared_experts, total_top_k = routed + shared): - hidden_states: [M, d_hidden] bf16 - gate_up_weight: [E, 2*d_expert_pad, d_hidden_pad//2] fp4x2 (raw, before shuffle) - down_weight: [E, d_hidden_pad, d_expert_pad//2] fp4x2 (raw, before shuffle) - gate_up_weight_scale: [E, 2*d_expert_pad, scale_K] e8m0 (raw, before shuffle) - down_weight_scale: [E, d_hidden_pad, scale_K] e8m0 (raw, before shuffle) - gate_up_weight_shuffled: [E, 2*d_expert_pad, d_hidden_pad//2] fp4x2 (pre-shuffled) - down_weight_shuffled: [E, d_hidden_pad, d_expert_pad//2] fp4x2 (pre-shuffled) - gate_up_weight_scale_shuffled:[padded, flat] e8m0 (pre-shuffled) - down_weight_scale_shuffled: [padded, flat] e8m0 (pre-shuffled) - topk_weights: [M, total_top_k] float32 - topk_ids: [M, total_top_k] int32 - config: dict - - Returns: - output: [M, d_hidden] bf16 - """ - ( - hidden_states, - gate_up_weight, - down_weight, - gate_up_weight_scale, - down_weight_scale, - gate_up_weight_shuffled, - down_weight_shuffled, - gate_up_weight_scale_shuffled, - down_weight_scale_shuffled, - topk_weights, - topk_ids, - config, - ) = data - - hidden_pad = config["d_hidden_pad"] - config["d_hidden"] - intermediate_pad = config["d_expert_pad"] - config["d_expert"] - - output = fused_moe( - hidden_states, - gate_up_weight_shuffled, - down_weight_shuffled, - topk_weights, - topk_ids, - expert_mask=None, - activation=ActivationType.Silu, - quant_type=QuantType.per_1x32, # MXFP4 uses per_1x32 block scaling - doweight_stage1=False, - w1_scale=gate_up_weight_scale_shuffled, - w2_scale=down_weight_scale_shuffled, - a1_scale=None, - a2_scale=None, - hidden_pad=hidden_pad, - intermediate_pad=intermediate_pad, - ) - - return output - - - -check_implementation = make_match_reference(ref_kernel, rtol=5e-2, atol=5e-2) diff --git a/problems/amd_202602/moe-mxfp4/submission.py b/problems/amd_202602/moe-mxfp4/submission.py deleted file mode 100644 index a771b32c..00000000 --- a/problems/amd_202602/moe-mxfp4/submission.py +++ /dev/null @@ -1,66 +0,0 @@ -import torch -from typing import Dict -from task import input_t, output_t - -from aiter import ActivationType, QuantType -from aiter.fused_moe import fused_moe - - -def custom_kernel(data: input_t) -> output_t: - """ - Submission template for DeepSeek-R1 MXFP4 MoE kernel. 
- - Input data tuple: - hidden_states: [M, d_hidden] bf16 - gate_up_weight: [E, 2*d_expert_pad, d_hidden_pad//2] fp4x2 (raw) - down_weight: [E, d_hidden_pad, d_expert_pad//2] fp4x2 (raw) - gate_up_weight_scale: [E, 2*d_expert_pad, scale_K] e8m0 (raw) - down_weight_scale: [E, d_hidden_pad, scale_K] e8m0 (raw) - gate_up_weight_shuffled: [E, 2*d_expert_pad, d_hidden_pad//2] fp4x2 (shuffled) - down_weight_shuffled: [E, d_hidden_pad, d_expert_pad//2] fp4x2 (shuffled) - gate_up_weight_scale_shuffled:[padded, flat] e8m0 (shuffled) - down_weight_scale_shuffled: [padded, flat] e8m0 (shuffled) - topk_weights: [M, total_top_k] float32 - topk_ids: [M, total_top_k] int32 - config: dict - - Returns: - output: [M, d_hidden] bf16 - """ - ( - hidden_states, - gate_up_weight, - down_weight, - gate_up_weight_scale, - down_weight_scale, - gate_up_weight_shuffled, - down_weight_shuffled, - gate_up_weight_scale_shuffled, - down_weight_scale_shuffled, - topk_weights, - topk_ids, - config, - ) = data - - hidden_pad = config["d_hidden_pad"] - config["d_hidden"] - intermediate_pad = config["d_expert_pad"] - config["d_expert"] - - output = fused_moe( - hidden_states, - gate_up_weight_shuffled, - down_weight_shuffled, - topk_weights, - topk_ids, - expert_mask=None, - activation=ActivationType.Silu, - quant_type=QuantType.per_1x32, - doweight_stage1=False, - w1_scale=gate_up_weight_scale_shuffled, - w2_scale=down_weight_scale_shuffled, - a1_scale=None, - a2_scale=None, - hidden_pad=hidden_pad, - intermediate_pad=intermediate_pad, - ) - - return output diff --git a/problems/amd_202602/moe-mxfp4/task.py b/problems/amd_202602/moe-mxfp4/task.py deleted file mode 100644 index a19edc83..00000000 --- a/problems/amd_202602/moe-mxfp4/task.py +++ /dev/null @@ -1,28 +0,0 @@ -from typing import TypeVar, Tuple, Dict -import torch - -input_t = TypeVar("input_t", bound=Tuple[ - torch.Tensor, # hidden_states [M, d_hidden] - torch.Tensor, # gate_up_weight [E, 2*d_expert_pad, d_hidden_pad//2] fp4x2 (raw) - torch.Tensor, # down_weight [E, d_hidden_pad, d_expert_pad//2] fp4x2 (raw) - torch.Tensor, # gate_up_weight_scale [E, 2*d_expert_pad, scale_K] e8m0 (raw) - torch.Tensor, # down_weight_scale [E, d_hidden_pad, scale_K] e8m0 (raw) - torch.Tensor, # gate_up_weight_shuffled [E, 2*d_expert_pad, d_hidden_pad//2] fp4x2 (shuffled) - torch.Tensor, # down_weight_shuffled [E, d_hidden_pad, d_expert_pad//2] fp4x2 (shuffled) - torch.Tensor, # gate_up_weight_scale_shuffled [padded, flat] e8m0 (shuffled) - torch.Tensor, # down_weight_scale_shuffled [padded, flat] e8m0 (shuffled) - torch.Tensor, # topk_weights [M, total_top_k] - torch.Tensor, # topk_ids [M, total_top_k] - Dict, # config -]) -output_t = TypeVar("output_t", bound=torch.Tensor) - - -class TestSpec: - dhidden: int # hidden dimension (7168 for DeepSeek-R1) - dexpert: int # intermediate dimension per expert (per partition) - nroutedexperts: int # number of local routed experts on this GPU - nexpertspertoken: int # top-k routed experts per token (8 for DeepSeek-R1) - nsharedexperts: int # number of shared experts (1 for DeepSeek-R1), always selected - bs: int # batch size = number of tokens in this batch - seed: int diff --git a/problems/amd_202602/moe-mxfp4/task.yml b/problems/amd_202602/moe-mxfp4/task.yml deleted file mode 100644 index 2e5d3608..00000000 --- a/problems/amd_202602/moe-mxfp4/task.yml +++ /dev/null @@ -1,130 +0,0 @@ -# name: 3_moe_mxfp4 - -files: - - {"name": "submission.py", "source": "@SUBMISSION@"} - - {"name": "task.py", "source": "task.py"} - - {"name": "utils.py", 
"source": "../utils.py"} - - {"name": "reference.py", "source": "reference.py"} - - {"name": "eval.py", "source": "../eval.py"} - -lang: "py" - -description: | - You will implement a DeepSeek-R1 style MXFP4 Mixture-of-Experts (MoE) fused kernel optimized for AMD Instinct MI355X GPU. - - To be explicit, you will be given a tuple of tensors: - ``` - (hidden_states, - gate_up_weight, down_weight, # fp4x2 raw - gate_up_weight_scale, down_weight_scale, # e8m0 raw - gate_up_weight_shuffled, down_weight_shuffled, # fp4x2 pre-shuffled - gate_up_weight_scale_shuffled, down_weight_scale_shuffled, # e8m0 pre-shuffled - topk_weights, topk_ids, - config) - ``` - where: - * `hidden_states` is M x d_hidden in bfloat16 (the input activations, M = batch of tokens) - * `gate_up_weight` is [E, 2*d_expert_pad, d_hidden_pad//2] in MXFP4 (fp4x2), raw layout. - Fused gate + up projection weights for each expert. E = number of local experts. - * `down_weight` is [E, d_hidden_pad, d_expert_pad//2] in MXFP4 (fp4x2), raw layout. - Down projection weights for each expert. - * `gate_up_weight_scale` is [E, 2*d_expert_pad, d_hidden_pad//32] in E8M0, raw layout. - Block scales (block_size=32) for gate_up_weight. - * `down_weight_scale` is [E, d_hidden_pad, d_expert_pad//32] in E8M0, raw layout. - Block scales for down_weight. - * `gate_up_weight_shuffled` / `down_weight_shuffled` are the same weights shuffled to - (16,16) tile-coalesced layout for the CK kernel. - * `gate_up_weight_scale_shuffled` / `down_weight_scale_shuffled` are the scales after - e8m0_shuffle, flattened to [padded, flat]. - * `topk_weights` is [M, total_top_k] float32: routing weights (routed experts + shared experts). - * `topk_ids` is [M, total_top_k] int32: expert indices. First nexpertspertoken columns are - routed expert ids (0..n_routed-1), last nsharedexperts columns are shared expert ids - (n_routed..n_routed+n_shared-1). Shared experts are always selected with weight=1.0. - * `config` is a dict with: d_hidden, d_expert, d_hidden_pad, d_expert_pad, - n_routed_experts, n_shared_experts, n_experts_per_token, total_top_k, bs. - - Then, the fused_moe kernel flow is: - (1) Quant activations to MXFP4: aiter per-1x32 dynamic quantization of hidden_states. - (2) Stage 1 GEMM + activation (per token i, per assigned expert j): - - gate = x_i @ W_gate_j.T # [d_hidden] x [d_expert, d_hidden].T -> [d_expert] - - up = x_i @ W_up_j.T # [d_hidden] x [d_expert, d_hidden].T -> [d_expert] - - intermediate = SiLU(gate) * up # SwiGLU activation, -> [d_expert] - (W_gate and W_up are fused as gate_up_weight, so this is one a4w4 GEMM + fused activation) - (3) Stage 2 GEMM: - - expert_out = intermediate @ W_down_j.T # [d_expert] x [d_hidden, d_expert].T -> [d_hidden] - (4) Weighted reduction: - - output_i += w_ij * expert_out # accumulate across top_k experts - All weight GEMMs are a4w4 (MXFP4 activations x MXFP4 weights, per-1x32 block scaling). - The AITER CK kernel fuses all of the above into a 2-stage pipeline across all tokens and experts. - - DeepSeek-R1 MoE specs: - - hidden_size = 7168, moe_intermediate_size = 2048 - - 256 routed experts + 1 shared expert (total 257), top-8 routed + 1 shared = 9 per token - - 58 MoE layers (layer 3-60) - - The shared expert processes ALL tokens unconditionally (weight=1.0) - - d_hidden_pad and d_expert_pad are the dimensions padded to 256-alignment for the CK kernel. 
- - **Known issue:** The reference submission (which calls aiter's fused_moe) is non-deterministic - on MI355X — it does not pass correctness checks against itself. This appears to be an aiter - fused_moe kernel bug on gfx950. Submissions will be evaluated on benchmark performance only - until this is resolved. - - The ranking criteria is the geometric mean of the benchmark results. - - ``` - The AITER reference performance is (E includes shared expert, top_k = routed + shared): - bs E d_hidden d_expert top_k time[us] - 16 257 7168 256 9 152.7 - 128 257 7168 256 9 239.0 - 512 257 7168 256 9 336.5 - 16 33 7168 512 9 106.2 - 128 33 7168 512 9 141.1 - 512 33 7168 512 9 225.0 - 512 33 7168 2048 9 380.4 - ``` - - Input: - - hidden_states: [M, d_hidden] bf16 - - gate_up_weight: [E, 2*d_expert_pad, d_hidden_pad//2] fp4x2 (raw, before shuffle) - - down_weight: [E, d_hidden_pad, d_expert_pad//2] fp4x2 (raw, before shuffle) - - gate_up_weight_scale: [E, 2*d_expert_pad, d_hidden_pad//32] e8m0 (raw, before shuffle) - - down_weight_scale: [E, d_hidden_pad, d_expert_pad//32] e8m0 (raw, before shuffle) - - gate_up_weight_shuffled: [E, 2*d_expert_pad, d_hidden_pad//2] fp4x2 (pre-shuffled for CK) - - down_weight_shuffled: [E, d_hidden_pad, d_expert_pad//2] fp4x2 (pre-shuffled for CK) - - gate_up_weight_scale_shuffled: [padded, flat] e8m0 (pre-shuffled for CK) - - down_weight_scale_shuffled: [padded, flat] e8m0 (pre-shuffled for CK) - - topk_weights: [M, total_top_k] float32 - - topk_ids: [M, total_top_k] int32 - - config: dict with dimensions - - Output: - - output: [M, d_hidden] bf16 - -config: - main: "eval.py" - -templates: - Python: "submission.py" - -test_timeout: 1800 -benchmark_timeout: 1800 -ranked_timeout: 1800 -ranking_by: "geom" - -tests: - - {"dhidden": 4096, "dexpert": 1024, "nroutedexperts": 256, "nexpertspertoken": 8, "nsharedexperts": 1, "bs": 8, "seed": 9371} - - {"dhidden": 7168, "dexpert": 2048, "nroutedexperts": 32, "nexpertspertoken": 8, "nsharedexperts": 1, "bs": 32, "seed": 2291} - - {"dhidden": 4096, "dexpert": 1536, "nroutedexperts": 64, "nexpertspertoken": 6, "nsharedexperts": 1, "bs": 128, "seed": 81934} - -benchmarks: - # TP=8 - - {"dhidden": 7168, "dexpert": 256, "nroutedexperts": 256, "nexpertspertoken": 8, "nsharedexperts": 1, "bs": 16, "seed": 9371} - - {"dhidden": 7168, "dexpert": 256, "nroutedexperts": 256, "nexpertspertoken": 8, "nsharedexperts": 1, "bs": 128, "seed": 2291} - - {"dhidden": 7168, "dexpert": 256, "nroutedexperts": 256, "nexpertspertoken": 8, "nsharedexperts": 1, "bs": 512, "seed": 81934} - # TP=4 - - {"dhidden": 7168, "dexpert": 512, "nroutedexperts": 32, "nexpertspertoken": 8, "nsharedexperts": 1, "bs": 16, "seed": 2291} - - {"dhidden": 7168, "dexpert": 512, "nroutedexperts": 32, "nexpertspertoken": 8, "nsharedexperts": 1, "bs": 128, "seed": 81934} - - {"dhidden": 7168, "dexpert": 512, "nroutedexperts": 32, "nexpertspertoken": 8, "nsharedexperts": 1, "bs": 512, "seed": 81934} - # EP on - - {"dhidden": 7168, "dexpert": 2048, "nroutedexperts": 32, "nexpertspertoken": 8, "nsharedexperts": 1, "bs": 512, "seed": 81934} diff --git a/problems/amd_202602/mxfp4-mm/reference.py b/problems/amd_202602/mxfp4-mm/reference.py deleted file mode 100644 index 3c26d348..00000000 --- a/problems/amd_202602/mxfp4-mm/reference.py +++ /dev/null @@ -1,108 +0,0 @@ -""" -FP4 quant + FP4 GEMM reference: bf16 A, MXFP4 B -> MXFP4 per-1x32 quant A -> gemm_a4w4 -> bf16 C. -Quant logic follows aiter op_tests/test_gemm_a4w4.py (get_triton_quant(QuantType.per_1x32)). 
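Layout note (illustrative): MXFP4 packs two FP4 values per byte (dtypes.fp4x2), and each group
of 32 consecutive values along K shares one E8M0 power-of-two scale, so a bf16 [n, k] tensor
becomes an [n, k//2] packed tensor plus a (possibly row-padded) [n, k//32] scale tensor.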
-""" -import torch -from task import input_t, output_t -from utils import make_match_reference -from aiter import QuantType,dtypes -import aiter -from aiter.ops.shuffle import shuffle_weight -# K must be divisible by 64 (scale group 32 and fp4 pack 2) -SCALE_GROUP_SIZE = 32 - -def generate_input(m: int, n: int, k: int, seed: int):# -> input_t: - """ - Generate random bf16 inputs A [m, k], B [n, k] and quantized MXFP4 B, shuffled B and B_scale. - - Returns: - Tuple of (A, B), both bf16 on cuda. - """ - assert k % 64 == 0, "k must be divisible by 64 (scale group 32 and fp4 pack 2)" - gen = torch.Generator(device="cuda") - gen.manual_seed(seed) - A = torch.randn((m, k), dtype=torch.bfloat16, device="cuda", generator=gen) - B = torch.randn((n, k), dtype=torch.bfloat16, device="cuda", generator=gen) - - # quantized mxfp4 B - quant_func = aiter.get_triton_quant(QuantType.per_1x32) - B_q, B_scale_sh = quant_func(B, shuffle=True) - - # shuffle B(weight) to (16,16) tile coalesced - B_shuffle = shuffle_weight(B_q, layout=(16, 16)) - return (A, B, B_q, B_shuffle, B_scale_sh) - -def run_torch_fp4_mm( - x: torch.Tensor, - w: torch.Tensor, - x_scales: torch.Tensor, - w_scales: torch.Tensor, - dtype: torch.dtype = torch.bfloat16, -) -> torch.Tensor: - """ - PyTorch reference: dequant MXFP4 + E8M0 scale -> f32 -> mm -> dtype. - Same logic as aiter op_tests/test_gemm_a4w4.run_torch. - x: [m, k//2] fp4 packed, w: [n, k//2] fp4 packed - x_scales: [m, k//32] E8M0, w_scales: [n, k//32] E8M0 - Returns: [m, n] in dtype - """ - from aiter.utility import fp4_utils - - m, _ = x.shape - n, _ = w.shape - # fp4 packed -> f32 - x_f32 = fp4_utils.mxfp4_to_f32(x) - w_f32 = fp4_utils.mxfp4_to_f32(w) - # E8M0 scale: [*, k//32] -> repeat 32 along k -> f32 - x_scales = x_scales[:m].repeat_interleave(SCALE_GROUP_SIZE, dim=1) - x_scales_f32 = fp4_utils.e8m0_to_f32(x_scales) - x_f32 = x_f32 * x_scales_f32 - w_scales = w_scales[:n].repeat_interleave(SCALE_GROUP_SIZE, dim=1) - w_scales_f32 = fp4_utils.e8m0_to_f32(w_scales) - w_f32 = w_f32 * w_scales_f32 - return torch.mm(x_f32, w_f32.T).to(dtype)[:m, :n] - - -def ref_kernel(data: input_t) -> output_t: - """ - Reference: MXFP4 per-1x32 quant on A and B; both PyTorch ref and gemm_a4w4 are given. - Returns gemm_a4w4 for check_implementation. 
- """ - A, B, B_q, B_shuffle, B_scale_sh = data - A = A.contiguous() - B = B.contiguous() - m, k = A.shape - n, _ = B.shape - - # 1) PyTorch impl just for your reference: dequant fp4 + e8m0 -> f32 -> mm -> bf16 - # Per-1x32 MXFP4 quant - # quant_func = aiter.get_triton_quant(QuantType.per_1x32) - # quant_func(x, shuffle=False) -> (dtypes.fp4x2, scale); scale layout matches gemm_a4w4 - # A_q, A_scale = quant_func(A, shuffle=False) - # B_q, B_scale = quant_func(B, shuffle=False) - - # gemm_a4w4 expects A [M,K/2], B [N,K/2] as dtypes.fp4x2; A_scale/B_scale [*,K/32] E8M0 - # quant_func returns scale as dtypes.fp8_e8m0; gemm_a4w4 accepts E8M0, no view to uint8 needed - # slice to exact shapes [m,k_scale] / [n,k_scale] (quant may return padded scale) - - # k_scale = k // SCALE_GROUP_SIZE - # A_scale = A_scale[:m, :k_scale].contiguous() - # B_scale = B_scale[:n, :k_scale].contiguous() - # out_torch = run_torch_fp4_mm(A_q, B_q, A_scale, B_scale, torch.bfloat16) - - # 2) aiter.gemm_a4w4 path: needs shuffled B_q and shuffled scales (see test_gemm_a4w4.py:102-105) - # Per-1x32 MXFP4 quant - quant_func = aiter.get_triton_quant(QuantType.per_1x32) - A_q, A_scale_sh = quant_func(A, shuffle=True) - # to be noted, aiter also has other a4w4 implements using triton, https://github.com/ROCm/aiter/blob/main/aiter/ops/triton/gemm/basic/gemm_afp4wfp4.py - out_gemm = aiter.gemm_a4w4( - A_q, - B_shuffle, - A_scale_sh, - B_scale_sh, - dtype=dtypes.bf16, - bpreshuffle=True, - ) - return out_gemm - -check_implementation = make_match_reference(ref_kernel, rtol=1e-02, atol=1e-02) \ No newline at end of file diff --git a/problems/amd_202602/mxfp4-mm/submission.py b/problems/amd_202602/mxfp4-mm/submission.py deleted file mode 100644 index 33b7ebbb..00000000 --- a/problems/amd_202602/mxfp4-mm/submission.py +++ /dev/null @@ -1,32 +0,0 @@ -""" -FP4 quant + FP4 GEMM reference: bf16 A, MXFP4 B -> MXFP4 per-1x32 quant A -> gemm_a4w4 -> bf16 C. -Quant logic follows aiter op_tests/test_gemm_a4w4.py (get_triton_quant(QuantType.per_1x32)). -""" -from task import input_t, output_t - - -def custom_kernel(data: input_t) -> output_t: - """ - Reference: MXFP4 per-1x32 quant on A; B_shuffle, B_scale_sh from generate_input. - gemm_a4w4 with bpreshuffle=True. - """ - import aiter - from aiter import QuantType, dtypes - - A, B, B_q, B_shuffle, B_scale_sh = data - A = A.contiguous() - B = B.contiguous() - m, k = A.shape - n, _ = B.shape - - quant_func = aiter.get_triton_quant(QuantType.per_1x32) - A_q, A_scale_sh = quant_func(A, shuffle=True) - out_gemm = aiter.gemm_a4w4( - A_q, - B_shuffle, - A_scale_sh, - B_scale_sh, - dtype=dtypes.bf16, - bpreshuffle=True, - ) - return out_gemm diff --git a/problems/amd_202602/mxfp4-mm/task.py b/problems/amd_202602/mxfp4-mm/task.py deleted file mode 100644 index a258eac0..00000000 --- a/problems/amd_202602/mxfp4-mm/task.py +++ /dev/null @@ -1,21 +0,0 @@ -""" -quant + FP4 GEMM: bf16 A, B -> MXFP4 1x32 per-block quant -> gemm_a4w4 -> bf16 C. -""" -import torch -from typing import TypeVar, TypedDict - -# Input: (A, B, B_q, B_shuffle, B_scale_sh) from generate_input. -# A [m,k], B [n,k] bf16; B_q quantized MXFP4; B_shuffle = shuffle_weight(B_q,(16,16)); B_scale_sh from quant(B, shuffle=True). -# Output: bf16 C [m, n]. 
-input_t = TypeVar( - "input_t", - bound=tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], -) -output_t = TypeVar("output_t", bound=torch.Tensor) - - -class TestSpec(TypedDict): - m: int - n: int - k: int - seed: int diff --git a/problems/amd_202602/mxfp4-mm/task.yml b/problems/amd_202602/mxfp4-mm/task.yml deleted file mode 100644 index 1e519c78..00000000 --- a/problems/amd_202602/mxfp4-mm/task.yml +++ /dev/null @@ -1,58 +0,0 @@ -# name: mxfp4-mm - -files: - - {"name": "submission.py", "source": "@SUBMISSION@"} - - {"name": "task.py", "source": "task.py"} - - {"name": "utils.py", "source": "../utils.py"} - - {"name": "reference.py", "source": "reference.py"} - - {"name": "eval.py", "source": "../eval.py"} - -lang: "py" - -description: | - You will implement a quantize func and block scaled MXFP4 matrix-matrix multiplication kernel optimized for AMD Instinct MI355X GPU. - To be explicit, you will be given a tuple of tensors: - ``` - (A, B, B_q, B_shuffle, B_scale_sh) - ``` - where: - * `A` is M x K in K-major order in bfloat16 - * `B` is N x K in K-major order in bfloat16 - * `B_q` is N x K/2 in K-major order in MXFP4 - * `B_shuffle` is N x K/2 in shuffled order in MXFP4, shuffled to (16,16) tile coalesced - * `B_scale_sh` is * x K/32 in E8M0, * means it will be padded. - - Then, the kernel flow is bf16 A, MXFP4 B -> MXFP4 per-1x32 quant A -> gemm_a4w4 -> BF16 C [m,n]. - To be specific, the invocation flow is: - (1) Quant A to MXFP4: aiter.get_triton_quant(QuantType.per_1x32). - (2) GEMM: aiter.gemm_a4w4. - m, n divisible by 64; k divisible by 64. - - The ranking criteria is the geometric mean of the benchmark results. - Pls note that this is the elimination round, whoever rank top5 are selected into the next round, e2e optimization for deepseek-R1-MXFP4 and GPTOSS-MXFP4 mdoel - ``` - The aiter performance is: - M N K time[us] - 4 2880 512 8.198 - 16 2112 7168 20.873 - 32 4096 512 9.462 - 32 2880 512 9.173 - 64 7168 2048 12.738 - 256 3072 1536 12.219 - ``` -config: - main: "eval.py" - -tests: - - {"m": 8, "n": 2112, "k": 7168, "seed": 124} - - {"m": 16, "n": 3072, "k": 1536, "seed": 6635} - - {"m": 64, "n": 3072, "k": 1536, "seed": 45} - - {"m": 256, "n": 2880, "k": 512, "seed": 78} - -benchmarks: - - {"m": 4, "n": 2880, "k": 512, "seed": 4565} - - {"m": 16, "n": 2112, "k": 7168, "seed": 15} - - {"m": 32, "n": 4096, "k": 512, "seed": 457} - - {"m": 32, "n": 2880, "k": 512, "seed": 54} - - {"m": 64, "n": 7168, "k": 2048, "seed": 687} - - {"m": 256, "n": 3072, "k": 1536, "seed": 7856} diff --git a/problems/amd_202602/utils.py b/problems/amd_202602/utils.py deleted file mode 100644 index 45f120ed..00000000 --- a/problems/amd_202602/utils.py +++ /dev/null @@ -1,147 +0,0 @@ -import random -from typing import Tuple - -import numpy as np -import torch - - -def set_seed(seed=42): - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - -def get_device(use_cuda: bool = True) -> torch.device: - """Get the appropriate device (GPU or CPU).""" - if use_cuda: - if torch.cuda.is_available(): - return torch.device("cuda") - elif torch.backends.mps.is_available(): - return torch.device("mps") - else: - print("No compatible GPU found. 
Falling back to CPU.") - return torch.device("cpu") - - -# Adapted from https://github.com/linkedin/Liger-Kernel/blob/main/test/utils.py -@torch.no_grad() -def verbose_allclose( - received: torch.Tensor, - expected: torch.Tensor, - rtol=1e-05, - atol=1e-08, - max_print=5 -) -> Tuple[bool, list[str]]: - """ - Assert that two tensors are element-wise equal within a tolerance, providing detailed information about mismatches. - - Parameters: - received (torch.Tensor): Tensor we actually got. - expected (torch.Tensor): Tensor we expected to receive. - rtol (float): Relative tolerance; relative to expected - atol (float): Absolute tolerance. - max_print (int): Maximum number of mismatched elements to print. - """ - # Check if the shapes of the tensors match - if received.shape != expected.shape: - return False, ["SIZE MISMATCH"] - - # Calculate the difference between the tensors - diff = torch.abs(received.to(torch.float32) - expected.to(torch.float32)) - - # Determine the tolerance - tolerance = atol + rtol * torch.abs(expected) - - # Find tolerance mismatched elements - tol_mismatched = diff > tolerance - - # Find nan mismatched elements - nan_mismatched = torch.logical_xor(torch.isnan(received), torch.isnan(expected)) - - # Find +inf mismatched elements - posinf_mismatched = torch.logical_xor(torch.isposinf(received), torch.isposinf(expected)) - # Find -inf mismatched elements - neginf_mismatched = torch.logical_xor(torch.isneginf(received), torch.isneginf(expected)) - - # Find all mismatched elements - mismatched = torch.logical_or( - torch.logical_or(tol_mismatched, nan_mismatched), - torch.logical_or(posinf_mismatched, neginf_mismatched), - ) - - mismatched_indices = torch.nonzero(mismatched) - - # Count the number of mismatched elements - num_mismatched = mismatched.count_nonzero().item() - - # Generate detailed information if there are mismatches - if num_mismatched >= 1: - mismatch_details = [f"Number of mismatched elements: {num_mismatched}"] - - for index in mismatched_indices[:max_print]: - i = tuple(index.tolist()) - mismatch_details.append(f"ERROR at {i}: {received[i]} {expected[i]}") - if num_mismatched > max_print: - mismatch_details.append(f"... and {num_mismatched - max_print} more mismatched elements.") - return False, mismatch_details - - return True, [f"Maximum error: {torch.max(diff)}"] - - -@torch.no_grad() -def verbose_allequal(received: torch.Tensor, expected: torch.Tensor, max_print: int = 5) -> Tuple[bool, list[str]]: - """ - Assert that two tensors are element-wise perfectly equal, providing detailed information about mismatches. - - Parameters: - received (torch.Tensor): Tensor we actually got. - expected (torch.Tensor): Tensor we expected to receive. - max_print (int): Maximum number of mismatched elements to print. - - Returns: - Empty string if tensors are equal, otherwise detailed error information - """ - mismatched = torch.not_equal(received, expected) - mismatched_indices = torch.nonzero(mismatched) - - # Count the number of mismatched elements - num_mismatched = mismatched.count_nonzero().item() - - # Generate detailed information if there are mismatches - if num_mismatched >= 1: - mismatch_details = [f"Number of mismatched elements: {num_mismatched}"] - - for index in mismatched_indices[:max_print]: - i = tuple(index.tolist()) - mismatch_details.append(f"ERROR at {i}: {received[i]} {expected[i]}") - if num_mismatched > max_print: - mismatch_details.append(f"... 
and {num_mismatched - max_print} more mismatched elements.") - return False, mismatch_details - - return True, [] - - -def match_reference(data, output, reference: callable, rtol=1e-05, atol=1e-08): - """ - Convenient "default" implementation for tasks' `check_implementation` function. - """ - expected = reference(data) - good, reasons = verbose_allclose(output, expected, rtol=rtol, atol=atol) - - if len(reasons) > 0: - return good, "\\n".join(reasons) - - return good, '' - - -def make_match_reference(reference: callable, **kwargs): - def wrapped(data, output): - return match_reference(data, output, reference=reference, **kwargs) - return wrapped - -def clear_l2_cache_large(): - dummy = torch.randn((16000, 1024, 1024), device="cuda") - del dummy
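For context on how these helpers fit together, a small sketch (assumed usage, not part of this patch) of timing a kernel with CUDA events while evicting cached data between runs via clear_l2_cache_large:

```
import torch
from utils import clear_l2_cache_large

def time_once_us(fn, *args) -> float:
    """Time one call of fn(*args) in microseconds, flushing caches beforehand."""
    clear_l2_cache_large()                  # allocate/free a large buffer to evict resident data
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) * 1e3    # Event.elapsed_time() returns ms; convert to us
```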