Skip to content

Enhance mla numeric check mechanism#137

Merged
msaroufim merged 1 commit intogpu-mode:mainfrom
danielhua23:danie/mla
Mar 18, 2026
Merged

Enhance mla numeric check mechanism#137
msaroufim merged 1 commit intogpu-mode:mainfrom
danielhua23:danie/mla

Conversation

@danielhua23
Copy link
Copy Markdown
Contributor

No description provided.

Copilot AI review requested due to automatic review settings March 18, 2026 13:50
Copy link
Copy Markdown

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR adjusts correctness checking for the AMD 202602 Mixed-MLA problem by adding an MLA-specific “mismatch ratio” allowance in the shared match_reference helper, and by loosening the Mixed-MLA reference tolerances. It also refactors the Mixed-MLA submission template to use an aiter persistent MLA decode path and expands the inline documentation.

Changes:

  • Add _is_mla_case detection and an MLA-only mismatch-ratio “pass with warning” path in match_reference.
  • Update Mixed-MLA submission template to call aiter mla_decode_fwd with persistent-mode metadata helpers.
  • Loosen Mixed-MLA check_implementation tolerances (rtol/atol) in reference.py.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 4 comments.

File Description
problems/amd_202602/utils.py Adds MLA-case detection and a mismatch-ratio relaxation path in match_reference.
problems/amd_202602/mixed-mla/submission.py Refactors the template to an aiter-based decode reference path and adds quant/dequant helpers and metadata plumbing.
problems/amd_202602/mixed-mla/reference.py Loosens check_implementation rtol/atol thresholds.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines 24 to +26
import torch.nn.functional as F
from task import input_t, output_t
from utils import make_match_reference
Comment on lines +268 to +274
page_size=PAGE_SIZE,
nhead_kv=nkv,
sm_scale=SM_SCALE,
logit_cap=0.0,
num_kv_splits=NUM_KV_SPLITS,
q_scale=q_scale,
kv_scale=kv_scale,
Comment on lines +152 to +160
# Only for MLA: aligned with aiter
if (not good) and _is_mla_case(data) and output.shape == expected.shape:
mismatch_mask = ~torch.isclose(output, expected, rtol=rtol, atol=atol)
mismatch_ratio = (mismatch_mask.sum() / output.numel()).item()
if mismatch_ratio <= tol_err_ratio:
return True, (
f"warning: mismatch_ratio={mismatch_ratio:.6f} "
f"(<= tol_err_ratio={tol_err_ratio}) with rtol={rtol}, atol={atol}"
)


check_implementation = make_match_reference(ref_kernel, rtol=2e-02, atol=2e-02)
check_implementation = make_match_reference(ref_kernel, rtol=1e-01, atol=1e-01)
@msaroufim msaroufim merged commit 7382df0 into gpu-mode:main Mar 18, 2026
3 of 4 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants