From 11cb5d918c0428197163e507fbebd6e26f06b729 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Tue, 10 Mar 2026 19:16:24 -0700 Subject: [PATCH] Add reference Helion kernel implementations --- .../helion/causal_conv1d_py/submission.py | 55 +++++++++++-- problems/helion/fp8_quant_py/submission.py | 62 ++++++++++---- .../reference.py | 4 +- .../submission.py | 80 +++++++++++++------ .../submission.py | 78 +++++++++++------- .../submission.py | 75 +++++++++++++---- 6 files changed, 264 insertions(+), 90 deletions(-) diff --git a/problems/helion/causal_conv1d_py/submission.py b/problems/helion/causal_conv1d_py/submission.py index ba89f5ad..32037a94 100644 --- a/problems/helion/causal_conv1d_py/submission.py +++ b/problems/helion/causal_conv1d_py/submission.py @@ -1,14 +1,53 @@ from task import input_t, output_t +import torch +import helion +import helion.language as hl -def custom_kernel(data: input_t) -> output_t: - import torch - import torch.nn.functional as F +# NOTE: This is an intentionally inefficient baseline implementation. +@helion.kernel( + static_shapes=True, + config=helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1), +) +def conv1d_kernel( + x_pad: torch.Tensor, # (B, D, L) zero-padded input + w: torch.Tensor, # (D, W) filter coefficients + b: torch.Tensor, # (D,) additive offset +) -> torch.Tensor: + B = x_pad.size(0) + D = x_pad.size(1) + L = x_pad.size(2) + W = hl.specialize(w.size(1)) + N = L - W + 1 + + y = torch.empty(B, D, N, dtype=x_pad.dtype, device=x_pad.device) + + for rb, rd, rs in hl.tile([B, D, N], block_size=[1, None, None]): + bi = rb.begin + acc1 = hl.zeros([rd, rs], dtype=torch.float32) + acc2 = hl.zeros([rd, rs], dtype=torch.float32) + acc3 = hl.zeros([rd, rs], dtype=torch.float32) + for j in range(W): + c1 = w[rd, j].to(torch.float32) + x1 = hl.load(x_pad, [bi, rd, rs.index + j]).to(torch.float32) + acc1 = acc1 + x1 * c1[:, None] + c2 = w[rd, j].to(torch.float32) + x2 = hl.load(x_pad, [bi, rd, rs.index + j]).to(torch.float32) + acc2 = acc2 + x2 * c2[:, None] + c3 = w[rd, j].to(torch.float32) + x3 = hl.load(x_pad, [bi, rd, rs.index + j]).to(torch.float32) + acc3 = acc3 + x3 * c3[:, None] + acc = (acc1 + acc2 + acc3) / 3.0 + acc = acc + b[rd].to(torch.float32)[:, None] + y[rb, rd, rs] = acc[None, :, :].to(y.dtype) + + return y + + +def custom_kernel(data: input_t) -> output_t: x, weight, bias = data W = weight.shape[1] - D = x.shape[1] - - x_padded = F.pad(x, (W - 1, 0)) - output = F.conv1d(x_padded, weight.unsqueeze(1), bias=bias, groups=D) - return output + pad_zeros = torch.zeros(x.shape[0], x.shape[1], W - 1, dtype=x.dtype, device=x.device) + padded = torch.cat([pad_zeros, x], dim=2) + return conv1d_kernel(padded, weight, bias) diff --git a/problems/helion/fp8_quant_py/submission.py b/problems/helion/fp8_quant_py/submission.py index 39cf1d08..f8f9888b 100644 --- a/problems/helion/fp8_quant_py/submission.py +++ b/problems/helion/fp8_quant_py/submission.py @@ -1,25 +1,59 @@ from task import input_t, output_t +import torch +import helion +import helion.language as hl -FP8_MAX = 448.0 -FP8_MIN = -448.0 -FP8_EPS = 1e-10 + +# NOTE: This is an intentionally inefficient baseline implementation. +@helion.kernel( + static_shapes=True, + config=helion.Config(block_sizes=[1], num_warps=1, num_stages=1), +) +def normalize_to_range( + data: torch.Tensor, # [N, G] input rows + scales_out: torch.Tensor, # [N] output normalization factors +) -> torch.Tensor: + nrows = data.size(0) + ncols = hl.specialize(data.size(1)) + MAX_VAL = 448.0 + + qout = torch.empty(nrows, ncols, dtype=torch.float32, device=data.device) + + for rr in hl.tile(nrows): + row = data[rr, :].to(torch.float32) + + abs1 = torch.abs(row) + amax1 = torch.amax(abs1, -1) + abs2 = torch.abs(row) + amax2 = torch.amax(abs2, -1) + abs3 = torch.abs(row) + amax3 = torch.amax(abs3, -1) + amax = (amax1 + amax2 + amax3) / 3.0 + amax = torch.clamp(amax, min=1e-10) + scale = amax / MAX_VAL + + q1 = row / scale[:, None] + q2 = row / scale[:, None] + q3 = row / scale[:, None] + qout[rr, :] = (q1 + q2 + q3) / 3.0 + scales_out[rr] = scale + + return qout def custom_kernel(data: input_t) -> output_t: x, x_q, x_s = data - num_tokens, hidden_dim = x.shape - num_groups = x_s.shape[1] - group_size = hidden_dim // num_groups + T, H = x.shape + G = x_s.shape[1] + gsz = H // G + N = T * G - x_f32 = x.float() - x_grouped = x_f32.reshape(num_tokens, num_groups, group_size) + flat_in = x.reshape(N, gsz) + flat_s = x_s.reshape(N) - absmax = x_grouped.abs().amax(dim=-1).clamp(min=FP8_EPS) - scale = absmax / FP8_MAX - quantized = (x_grouped / scale.unsqueeze(-1)).clamp(FP8_MIN, FP8_MAX) - quantized = quantized.reshape(num_tokens, hidden_dim) + flat_q = normalize_to_range(flat_in, flat_s) - x_q[...] = quantized - x_s[...] = scale + x_q[...] = flat_q.reshape(T, H) + x_s[...] = flat_s.reshape(T, G) return x_q, x_s diff --git a/problems/helion/gated_deltanet_chunk_fwd_h_py/reference.py b/problems/helion/gated_deltanet_chunk_fwd_h_py/reference.py index 8863335d..8b3668c2 100644 --- a/problems/helion/gated_deltanet_chunk_fwd_h_py/reference.py +++ b/problems/helion/gated_deltanet_chunk_fwd_h_py/reference.py @@ -61,8 +61,8 @@ def check_implementation(data, output): exp_h, exp_v = expected got_h, got_v = output - reasons_h = verbose_allclose(got_h, exp_h, rtol=1e-2, atol=1e-2) - reasons_v = verbose_allclose(got_v, exp_v, rtol=1e-2, atol=1e-2) + reasons_h = verbose_allclose(got_h, exp_h, rtol=2e-2, atol=2e-2) + reasons_v = verbose_allclose(got_v, exp_v, rtol=2e-2, atol=2e-2) reasons = [] if reasons_h: diff --git a/problems/helion/gated_deltanet_chunk_fwd_h_py/submission.py b/problems/helion/gated_deltanet_chunk_fwd_h_py/submission.py index e85c0260..528a61cc 100644 --- a/problems/helion/gated_deltanet_chunk_fwd_h_py/submission.py +++ b/problems/helion/gated_deltanet_chunk_fwd_h_py/submission.py @@ -1,39 +1,69 @@ from task import input_t, output_t +import torch +import helion +import helion.language as hl -def custom_kernel(data: input_t) -> output_t: - import torch - k, w, u, g = data +# NOTE: This is an intentionally inefficient baseline implementation. +@helion.kernel( + static_shapes=True, + dot_precision="ieee", + config=helion.Config(block_sizes=[], num_warps=1, num_stages=1), +) +def chunk_state_pass( + k: torch.Tensor, # [B, T, H, K] + w: torch.Tensor, # [B, T, H, K] + u: torch.Tensor, # [B, T, H, V] + g: torch.Tensor, # [B, T, H] +) -> tuple[torch.Tensor, torch.Tensor]: B, T, H, K = k.shape V = u.shape[-1] - BT = 64 - NT = T // BT + C = 64 + K = hl.specialize(K) + V = hl.specialize(V) + + NT = (T + C - 1) // C + h_out = torch.empty(B, NT, H, K, V, dtype=k.dtype, device=k.device) + v_out = torch.empty_like(u) + + BH = B * H - h = torch.empty(B, NT, H, K, V, dtype=torch.float32, device=k.device) - v_new = torch.empty_like(u) + for flat, tv in hl.tile([BH, V], block_size=[1, 8]): + b_idx = flat.begin // H + h_idx = flat.begin % H + state = hl.zeros([K, tv], dtype=torch.float32) - for b in range(B): - for hh in range(H): - b_h = torch.zeros(K, V, dtype=torch.float32, device=k.device) + for tc in hl.tile(T, block_size=C): + chunk_idx = tc.begin // C + t_end = min(tc.begin + C, T) - 1 - for c in range(NT): - cs = c * BT - ce = cs + BT + h_out[b_idx, chunk_idx, h_idx, :, tv] = state.to(k.dtype) - h[b, c, hh] = b_h + proj1 = hl.dot( + w[b_idx, tc, h_idx, :], state, out_dtype=torch.float32 + ) + proj2 = hl.dot( + w[b_idx, tc, h_idx, :], state, out_dtype=torch.float32 + ) + proj = (proj1 + proj2) * 0.5 + diff = u[b_idx, tc, h_idx, tv].to(torch.float32) - proj + v_out[b_idx, tc, h_idx, tv] = diff.to(u.dtype) - b_w = w[b, cs:ce, hh].float() - b_u = u[b, cs:ce, hh].float() - b_v = b_u - torch.matmul(b_w, b_h) - v_new[b, cs:ce, hh] = b_v + g_end = g[b_idx, t_end, h_idx].to(torch.float32) + g_t = g[b_idx, tc, h_idx].to(torch.float32) + valid = tc.index < T + alpha = torch.where(valid, torch.exp(g_end - g_t), 0.0) + k_adj = k[b_idx, tc, h_idx, :] * alpha[:, None] - b_g = g[b, cs:ce, hh].float() - b_g_last = b_g[-1] - b_v_gated = b_v * torch.exp(b_g_last - b_g)[:, None] + state = state * torch.exp(g_end) + upd1 = hl.dot(k_adj.T, diff, out_dtype=torch.float32) + upd2 = hl.dot(k_adj.T, diff, out_dtype=torch.float32) + state = state + (upd1 + upd2) * 0.5 - b_h = b_h * torch.exp(b_g_last) - b_k = k[b, cs:ce, hh].float() - b_h = b_h + torch.matmul(b_k.T, b_v_gated) + return h_out, v_out - return h, v_new + +def custom_kernel(data: input_t) -> output_t: + k, w, u, g = data + return chunk_state_pass(k, w, u, g) diff --git a/problems/helion/gated_deltanet_chunk_fwd_o_py/submission.py b/problems/helion/gated_deltanet_chunk_fwd_o_py/submission.py index 0b5f02cd..8e2a2f53 100644 --- a/problems/helion/gated_deltanet_chunk_fwd_o_py/submission.py +++ b/problems/helion/gated_deltanet_chunk_fwd_o_py/submission.py @@ -1,38 +1,62 @@ from task import input_t, output_t +import torch +import helion +import helion.language as hl + + +# NOTE: This is an intentionally inefficient baseline implementation. +@helion.kernel( + static_shapes=True, + dot_precision="ieee", + config=helion.Config(block_sizes=[], num_warps=1, num_stages=1), +) +def gated_chunk_attn( + q: torch.Tensor, # [B, T, H, K] + k: torch.Tensor, # [B, T, H, K] + v: torch.Tensor, # [B, T, H, V] + h: torch.Tensor, # [B, NT, H, K, V] + g: torch.Tensor, # [B, T, H] + scale: float, +) -> torch.Tensor: + B, T, H, K = q.shape + V = v.shape[-1] + C = 64 + K = hl.specialize(K) + V = hl.specialize(V) -def custom_kernel(data: input_t) -> output_t: - import torch + out = torch.empty_like(v) - q, k, v_new, h, g = data - B, T, H, K = q.shape - V = v_new.shape[-1] - BT = 64 - scale = K ** -0.5 + BH = B * H + for flat_bh, tile_t in hl.tile([BH, T], block_size=[1, C]): + b_idx = flat_bh.begin // H + h_idx = flat_bh.begin % H + c_idx = tile_t.begin // C - o = torch.empty_like(v_new) - causal = torch.tril(torch.ones(BT, BT, device=q.device, dtype=torch.bool)) + g_vals = g[b_idx, tile_t, h_idx] + q_s = q[b_idx, tile_t, h_idx, :] * torch.exp(g_vals)[:, None] + k_s = k[b_idx, tile_t, h_idx, :] * torch.exp(-g_vals)[:, None] - for cs in range(0, T, BT): - ce = cs + BT - c_idx = cs // BT + sim1 = hl.dot(q_s, k_s.T) + sim2 = hl.dot(q_s, k_s.T) + sim = (sim1 + sim2) * 0.5 + idx = hl.arange(tile_t.block_size) + mask = idx[:, None] >= idx[None, :] + sim = torch.where(mask, sim, 0.0) + local1 = hl.dot(sim.to(v.dtype), v[b_idx, tile_t, h_idx, :]) + local2 = hl.dot(sim.to(v.dtype), v[b_idx, tile_t, h_idx, :]) + local_out = (local1 + local2) * 0.5 - b_q = q[:, cs:ce, :, :].permute(0, 2, 1, 3).float() - b_k = k[:, cs:ce, :, :].permute(0, 2, 1, 3).float() - b_v = v_new[:, cs:ce, :, :].permute(0, 2, 1, 3).float() - b_h = h[:, c_idx, :, :, :].float() - b_g = g[:, cs:ce, :].permute(0, 2, 1).float() + glob1 = hl.dot(q_s, h[b_idx, c_idx, h_idx, :, :]) + glob2 = hl.dot(q_s, h[b_idx, c_idx, h_idx, :, :]) + global_out = (glob1 + glob2) * 0.5 - inter = torch.matmul(b_q, b_h) - inter = inter * torch.exp(b_g).unsqueeze(-1) + out[b_idx, tile_t, h_idx, :] = ((global_out + local_out) * scale).to(out.dtype) - attn = torch.matmul(b_q, b_k.transpose(-1, -2)) - g_diff = b_g.unsqueeze(-1) - b_g.unsqueeze(-2) - attn = attn * torch.exp(g_diff) - attn = attn.masked_fill(~causal, 0.0) - intra = torch.matmul(attn, b_v) + return out - b_o = (inter + intra) * scale - o[:, cs:ce, :, :] = b_o.permute(0, 2, 1, 3) - return o +def custom_kernel(data: input_t) -> output_t: + q, k, v_new, h, g = data + scale = q.shape[-1] ** -0.5 + return gated_chunk_attn(q, k, v_new, h, g, scale) diff --git a/problems/helion/gated_deltanet_recompute_w_u_py/submission.py b/problems/helion/gated_deltanet_recompute_w_u_py/submission.py index ec50c3cf..918f519a 100644 --- a/problems/helion/gated_deltanet_recompute_w_u_py/submission.py +++ b/problems/helion/gated_deltanet_recompute_w_u_py/submission.py @@ -1,25 +1,72 @@ from task import input_t, output_t +import torch +import helion +import helion.language as hl -def custom_kernel(data: input_t) -> output_t: - import torch - k, v, beta, A, g = data +# NOTE: This is an intentionally inefficient baseline implementation. +@helion.kernel( + static_shapes=True, + dot_precision="ieee", + config=helion.Config(block_sizes=[], num_warps=1, num_stages=1), +) +def project_kv( + k: torch.Tensor, # [B, T, H, K] + v: torch.Tensor, # [B, T, H, V] + beta: torch.Tensor, # [B, T, H] + A: torch.Tensor, # [B, T, H, BT] + g: torch.Tensor, # [B, T, H] +) -> tuple[torch.Tensor, torch.Tensor]: B, T, H, K = k.shape V = v.shape[-1] - BT = A.shape[-1] + C = hl.specialize(A.shape[-1]) + K = hl.specialize(K) + V = hl.specialize(V) + + w_out = torch.empty_like(k) + u_out = torch.empty_like(v) + + BH = B * H + for flat_bh, rt in hl.tile([BH, T], block_size=[1, C]): + b_idx = flat_bh.begin // H + h_idx = flat_bh.begin % H + + w_acc1 = hl.zeros([rt, K], dtype=torch.float32) + u_acc1 = hl.zeros([rt, V], dtype=torch.float32) + w_acc2 = hl.zeros([rt, K], dtype=torch.float32) + u_acc2 = hl.zeros([rt, V], dtype=torch.float32) + + for ci in range(C): + t_ci = rt.begin + ci + a_col = A[b_idx, rt, h_idx, ci].to(torch.float32) + coeff_ci = beta[b_idx, t_ci, h_idx].to(torch.float32) + decay_ci = torch.exp(g[b_idx, t_ci, h_idx].to(torch.float32)) - w = torch.empty_like(k) - u = torch.empty_like(v) + k_ci = k[b_idx, t_ci, h_idx, :].to(torch.float32) + v_ci = v[b_idx, t_ci, h_idx, :].to(torch.float32) - for cs in range(0, T, BT): - ce = cs + BT - A_bh = A[:, cs:ce, :, :].permute(0, 2, 1, 3).float() + w_acc1 = w_acc1 + a_col[:, None] * (k_ci * coeff_ci * decay_ci)[None, :] + u_acc1 = u_acc1 + a_col[:, None] * (v_ci * coeff_ci)[None, :] - vb = (v[:, cs:ce, :, :] * beta[:, cs:ce, :, None]).permute(0, 2, 1, 3).float() - u[:, cs:ce, :, :] = torch.matmul(A_bh, vb).permute(0, 2, 1, 3) + for ci in range(C - 1, -1, -1): + t_ci = rt.begin + ci + a_col = A[b_idx, rt, h_idx, ci].to(torch.float32) + coeff_ci = beta[b_idx, t_ci, h_idx].to(torch.float32) + decay_ci = torch.exp(g[b_idx, t_ci, h_idx].to(torch.float32)) - kb = (k[:, cs:ce, :, :] * beta[:, cs:ce, :, None] * torch.exp(g[:, cs:ce, :, None])).permute(0, 2, 1, 3).float() - w[:, cs:ce, :, :] = torch.matmul(A_bh, kb).permute(0, 2, 1, 3) + k_ci = k[b_idx, t_ci, h_idx, :].to(torch.float32) + v_ci = v[b_idx, t_ci, h_idx, :].to(torch.float32) - return w, u + w_acc2 = w_acc2 + a_col[:, None] * (k_ci * coeff_ci * decay_ci)[None, :] + u_acc2 = u_acc2 + a_col[:, None] * (v_ci * coeff_ci)[None, :] + + w_out[b_idx, rt, h_idx, :] = ((w_acc1 + w_acc2) * 0.5).to(k.dtype) + u_out[b_idx, rt, h_idx, :] = ((u_acc1 + u_acc2) * 0.5).to(v.dtype) + + return w_out, u_out + + +def custom_kernel(data: input_t) -> output_t: + k, v, beta, A, g = data + return project_kv(k, v, beta, A, g)