From 24b5bcee1ac5ffcfea9db2f05430967aa71f0194 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Sat, 14 Mar 2026 10:45:44 -0700 Subject: [PATCH 1/2] Fix NaN in gated_deltanet_chunk_fwd_o reference and submission The reference kernel computed exp(g_i - g_j) before applying the causal mask. When g values are very negative (cumulative sums of negative increments), the upper-triangle differences g_i - g_j overflow exp() to inf, and inf * 0 (causal mask) produces NaN. Fix: zero out g_diff in the upper triangle before calling exp(), so we never compute exp(large_positive). The submission kernel had a similar issue with exp(-g) overflowing, because it scaled q by exp(g) and k by exp(-g) separately; it now computes exp(g_i - g_j) directly and selects the result with torch.where under the causal mask, so overflowed upper-triangle entries are discarded. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../gated_deltanet_chunk_fwd_o_py/reference.py | 8 +++++--- .../gated_deltanet_chunk_fwd_o_py/submission.py | 17 +++++++++++------ 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/problems/helion/gated_deltanet_chunk_fwd_o_py/reference.py b/problems/helion/gated_deltanet_chunk_fwd_o_py/reference.py index 2ce7062e..467d9afb 100644 --- a/problems/helion/gated_deltanet_chunk_fwd_o_py/reference.py +++ b/problems/helion/gated_deltanet_chunk_fwd_o_py/reference.py @@ -103,9 +103,11 @@ def ref_kernel(data: input_t) -> output_t: v_c = v_new.float().reshape(B, NT, C, H, V).permute(0, 1, 3, 2, 4) g_c = g.float().reshape(B, NT, C, H).permute(0, 1, 3, 2) o_inter = (q_c @ h.float()) * torch.exp(g_c).unsqueeze(-1) - qk = q_c @ k_c.transpose(-1, -2) * torch.exp(g_c.unsqueeze(-1) - g_c.unsqueeze(-2)) - causal = torch.tril(torch.ones(C, C, device=q.device)) - o = (o_inter + (qk * causal) @ v_c) * scale + causal = torch.tril(torch.ones(C, C, dtype=torch.bool, device=q.device)) + g_diff = g_c.unsqueeze(-1) - g_c.unsqueeze(-2) + g_diff = torch.where(causal, g_diff, torch.zeros_like(g_diff)) + qk = q_c @ k_c.transpose(-1, -2) * torch.exp(g_diff) * causal + o = (o_inter + qk @ v_c) * scale return o.permute(0, 1, 3, 2, 4).reshape(B, T, H, V).to(q.dtype) diff --git 
a/problems/helion/gated_deltanet_chunk_fwd_o_py/submission.py b/problems/helion/gated_deltanet_chunk_fwd_o_py/submission.py index 81e8fcc4..eb4de947 100644 --- a/problems/helion/gated_deltanet_chunk_fwd_o_py/submission.py +++ b/problems/helion/gated_deltanet_chunk_fwd_o_py/submission.py @@ -54,15 +54,20 @@ def kernel( c_idx = tile_t.begin // C g_vals = g[b_idx, tile_t, h_idx] - q_s = q[b_idx, tile_t, h_idx, :] * torch.exp(g_vals)[:, None] - k_s = k[b_idx, tile_t, h_idx, :] * torch.exp(-g_vals)[:, None] + q_tile = q[b_idx, tile_t, h_idx, :] + k_tile = k[b_idx, tile_t, h_idx, :] + v_tile = v[b_idx, tile_t, h_idx, :] - sim = hl.dot(q_s, k_s.T) + # intra-chunk: q @ k^T * exp(g_i - g_j), with causal mask + qk = hl.dot(q_tile, k_tile.T) idx = hl.arange(tile_t.block_size) - mask = idx[:, None] >= idx[None, :] - sim = torch.where(mask, sim, 0.0) - local_out = hl.dot(sim.to(v.dtype), v[b_idx, tile_t, h_idx, :]) + g_diff = g_vals[:, None] - g_vals[None, :] + causal_mask = idx[:, None] >= idx[None, :] + sim = torch.where(causal_mask, qk * torch.exp(g_diff), 0.0) + local_out = hl.dot(sim.to(v.dtype), v_tile) + # inter-chunk: (q @ h) * exp(g) + q_s = q_tile * torch.exp(g_vals)[:, None] global_out = hl.dot(q_s, h[b_idx, c_idx, h_idx, :, :]) out[b_idx, tile_t, h_idx, :] = ((global_out + local_out) * scale).to(out.dtype) From a611712a554aada5e1cc5620981c1ff88a671d34 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Sat, 14 Mar 2026 10:49:42 -0700 Subject: [PATCH 2/2] Fix NaN in _chunk_scaled_dot_kkt_fwd_eager across all 3 gated deltanet kernels Zero out g_diff outside the strict lower triangle before calling exp(), preventing inf * 0 = NaN when upper-triangle g differences overflow. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- problems/helion/gated_deltanet_chunk_fwd_h_py/reference.py | 3 ++- problems/helion/gated_deltanet_chunk_fwd_o_py/reference.py | 3 ++- problems/helion/gated_deltanet_recompute_w_u_py/reference.py | 3 ++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/problems/helion/gated_deltanet_chunk_fwd_h_py/reference.py b/problems/helion/gated_deltanet_chunk_fwd_h_py/reference.py index 0f8ea9cc..30031e0e 100644 --- a/problems/helion/gated_deltanet_chunk_fwd_h_py/reference.py +++ b/problems/helion/gated_deltanet_chunk_fwd_h_py/reference.py @@ -19,8 +19,9 @@ def _chunk_scaled_dot_kkt_fwd_eager(k, g_cumsum, beta, chunk_size): g_c = g_cumsum.float().reshape(B, NT, C, H).permute(0, 1, 3, 2) beta_c = beta.float().reshape(B, NT, C, H).permute(0, 1, 3, 2) kkt = k_c @ k_c.transpose(-1, -2) - g_diff = g_c.unsqueeze(-1) - g_c.unsqueeze(-2) strict_lower = torch.tril(torch.ones(C, C, device=k.device), diagonal=-1) + g_diff = g_c.unsqueeze(-1) - g_c.unsqueeze(-2) + g_diff = g_diff * strict_lower A = kkt * beta_c.unsqueeze(-1) * torch.exp(g_diff) * strict_lower return A.permute(0, 1, 3, 2, 4).reshape(B, T, H, C).to(torch.float32) diff --git a/problems/helion/gated_deltanet_chunk_fwd_o_py/reference.py b/problems/helion/gated_deltanet_chunk_fwd_o_py/reference.py index 467d9afb..e7775920 100644 --- a/problems/helion/gated_deltanet_chunk_fwd_o_py/reference.py +++ b/problems/helion/gated_deltanet_chunk_fwd_o_py/reference.py @@ -19,8 +19,9 @@ def _chunk_scaled_dot_kkt_fwd_eager(k, g_cumsum, beta, chunk_size): g_c = g_cumsum.float().reshape(B, NT, C, H).permute(0, 1, 3, 2) beta_c = beta.float().reshape(B, NT, C, H).permute(0, 1, 3, 2) kkt = k_c @ k_c.transpose(-1, -2) - g_diff = g_c.unsqueeze(-1) - g_c.unsqueeze(-2) strict_lower = torch.tril(torch.ones(C, C, device=k.device), diagonal=-1) + g_diff = g_c.unsqueeze(-1) - g_c.unsqueeze(-2) + g_diff = g_diff * strict_lower A = kkt * beta_c.unsqueeze(-1) * torch.exp(g_diff) * strict_lower 
return A.permute(0, 1, 3, 2, 4).reshape(B, T, H, C).to(torch.float32) diff --git a/problems/helion/gated_deltanet_recompute_w_u_py/reference.py b/problems/helion/gated_deltanet_recompute_w_u_py/reference.py index 44b31387..b6fe41ae 100644 --- a/problems/helion/gated_deltanet_recompute_w_u_py/reference.py +++ b/problems/helion/gated_deltanet_recompute_w_u_py/reference.py @@ -19,8 +19,9 @@ def _chunk_scaled_dot_kkt_fwd_eager(k, g_cumsum, beta, chunk_size): g_c = g_cumsum.float().reshape(B, NT, C, H).permute(0, 1, 3, 2) beta_c = beta.float().reshape(B, NT, C, H).permute(0, 1, 3, 2) kkt = k_c @ k_c.transpose(-1, -2) - g_diff = g_c.unsqueeze(-1) - g_c.unsqueeze(-2) strict_lower = torch.tril(torch.ones(C, C, device=k.device), diagonal=-1) + g_diff = g_c.unsqueeze(-1) - g_c.unsqueeze(-2) + g_diff = g_diff * strict_lower A = kkt * beta_c.unsqueeze(-1) * torch.exp(g_diff) * strict_lower return A.permute(0, 1, 3, 2, 4).reshape(B, T, H, C).to(torch.float32)