From ce0b6f7cd189bbdfe33d49bdf7c62651abeecfba Mon Sep 17 00:00:00 2001 From: Will Feng Date: Sat, 14 Mar 2026 11:49:57 -0700 Subject: [PATCH] Tighten rtol/atol from 1e-2 to 1e-3 for Helion kernels All four Helion kernels (causal_conv1d, gated_deltanet_chunk_fwd_h, chunk_fwd_o, recompute_w_u) operate in float32 but used overly loose tolerances of rtol=1e-2, atol=1e-2. Tighten to 1e-3 to better catch numerical bugs while still allowing for accumulation chain error. Co-Authored-By: Claude Opus 4.6 (1M context) --- problems/helion/causal_conv1d_py/reference.py | 2 +- problems/helion/gated_deltanet_chunk_fwd_h_py/reference.py | 4 ++-- problems/helion/gated_deltanet_chunk_fwd_o_py/reference.py | 2 +- problems/helion/gated_deltanet_recompute_w_u_py/reference.py | 4 ++-- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/problems/helion/causal_conv1d_py/reference.py b/problems/helion/causal_conv1d_py/reference.py index 268838fc..0d2ae2f5 100644 --- a/problems/helion/causal_conv1d_py/reference.py +++ b/problems/helion/causal_conv1d_py/reference.py @@ -32,4 +32,4 @@ def ref_kernel(data: input_t) -> output_t: return output -check_implementation = make_match_reference(ref_kernel, rtol=1e-2, atol=1e-2) +check_implementation = make_match_reference(ref_kernel, rtol=1e-3, atol=1e-3) diff --git a/problems/helion/gated_deltanet_chunk_fwd_h_py/reference.py b/problems/helion/gated_deltanet_chunk_fwd_h_py/reference.py index 30031e0e..9d9b7204 100644 --- a/problems/helion/gated_deltanet_chunk_fwd_h_py/reference.py +++ b/problems/helion/gated_deltanet_chunk_fwd_h_py/reference.py @@ -96,8 +96,8 @@ def check_implementation(data, output): exp_h, exp_v = expected got_h, got_v = output - reasons_h = verbose_allclose(got_h.float(), exp_h.float(), rtol=1e-2, atol=1e-2) - reasons_v = verbose_allclose(got_v.float(), exp_v.float(), rtol=1e-2, atol=1e-2) + reasons_h = verbose_allclose(got_h.float(), exp_h.float(), rtol=1e-3, atol=1e-3) + reasons_v = verbose_allclose(got_v.float(), exp_v.float(), rtol=1e-3, atol=1e-3) reasons = [] if reasons_h: diff --git a/problems/helion/gated_deltanet_chunk_fwd_o_py/reference.py b/problems/helion/gated_deltanet_chunk_fwd_o_py/reference.py index e7775920..54be0f2f 100644 --- a/problems/helion/gated_deltanet_chunk_fwd_o_py/reference.py +++ b/problems/helion/gated_deltanet_chunk_fwd_o_py/reference.py @@ -112,4 +112,4 @@ def ref_kernel(data: input_t) -> output_t: return o.permute(0, 1, 3, 2, 4).reshape(B, T, H, V).to(q.dtype) -check_implementation = make_match_reference(ref_kernel, rtol=1e-2, atol=1e-2) +check_implementation = make_match_reference(ref_kernel, rtol=1e-3, atol=1e-3) diff --git a/problems/helion/gated_deltanet_recompute_w_u_py/reference.py b/problems/helion/gated_deltanet_recompute_w_u_py/reference.py index b6fe41ae..bd7c1507 100644 --- a/problems/helion/gated_deltanet_recompute_w_u_py/reference.py +++ b/problems/helion/gated_deltanet_recompute_w_u_py/reference.py @@ -72,8 +72,8 @@ def check_implementation(data, output): exp_w, exp_u = expected got_w, got_u = output - reasons_w = verbose_allclose(got_w, exp_w, rtol=1e-2, atol=1e-2) - reasons_u = verbose_allclose(got_u, exp_u, rtol=1e-2, atol=1e-2) + reasons_w = verbose_allclose(got_w, exp_w, rtol=1e-3, atol=1e-3) + reasons_u = verbose_allclose(got_u, exp_u, rtol=1e-3, atol=1e-3) reasons = [] if reasons_w: