diff --git a/problems/helion/causal_conv1d_py/submission.py b/problems/helion/causal_conv1d_py/submission.py
index 32037a94..92716763 100644
--- a/problems/helion/causal_conv1d_py/submission.py
+++ b/problems/helion/causal_conv1d_py/submission.py
@@ -5,49 +5,77 @@ import helion.language as hl
 
+# Per-shape configs: map (B, D, S, W) to optimized helion.Config objects.
+# Autotune locally for each shape, then paste the best config here.
+SHAPE_CONFIGS: dict[tuple, helion.Config] = {
+    # Test shapes
+    (1, 64, 64, 4): helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1),  # TODO: use any config that passes correctness check
+    (2, 128, 128, 4): helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1),  # TODO: use any config that passes correctness check
+    (1, 256, 256, 3): helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1),  # TODO: use any config that passes correctness check
+    (1, 128, 64, 8): helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1),  # TODO: use any config that passes correctness check
+    (4, 64, 128, 4): helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1),  # TODO: use any config that passes correctness check
+    # Benchmark shapes
+    (1, 768, 512, 4): helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1),  # TODO: replace with your autotuned config
+    (1, 768, 2048, 4): helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1),  # TODO: replace with your autotuned config
+    (1, 1536, 2048, 4): helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1),  # TODO: replace with your autotuned config
+    (1, 2560, 2048, 4): helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1),  # TODO: replace with your autotuned config
+    (1, 2560, 4096, 4): helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1),  # TODO: replace with your autotuned config
+}
+
+
+# Optional: add advanced_controls_file to your Config for extra performance (see docs).
+# Autotune with autotune_search_acf to find the best ACF, then hardcode it:
+# helion.Config(..., advanced_controls_file="/opt/booster_pack/causal_conv_0.acf")
+
+
 # NOTE: This is an intentionally inefficient baseline implementation.
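+# Correctness sketch (not part of the submission; assumes PyTorch's F.conv1d
+# cross-correlation semantics): with a left zero-pad of W - 1, the kernel below
+# computes y[t] = sum_j w[j] * x_pad[t + j], which should match a depthwise conv:
+#
+#     import torch.nn.functional as F
+#     ref = F.conv1d(x, weight.unsqueeze(1), bias, padding=W - 1, groups=D)[..., :S]
+#     torch.testing.assert_close(custom_kernel((x, weight, bias)), ref)
+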
-@helion.kernel(
-    static_shapes=True,
-    config=helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1),
-)
-def conv1d_kernel(
-    x_pad: torch.Tensor,  # (B, D, L) zero-padded input
-    w: torch.Tensor,  # (D, W) filter coefficients
-    b: torch.Tensor,  # (D,) additive offset
-) -> torch.Tensor:
-    B = x_pad.size(0)
-    D = x_pad.size(1)
-    L = x_pad.size(2)
-    W = hl.specialize(w.size(1))
-    N = L - W + 1
-
-    y = torch.empty(B, D, N, dtype=x_pad.dtype, device=x_pad.device)
-
-    for rb, rd, rs in hl.tile([B, D, N], block_size=[1, None, None]):
-        bi = rb.begin
-        acc1 = hl.zeros([rd, rs], dtype=torch.float32)
-        acc2 = hl.zeros([rd, rs], dtype=torch.float32)
-        acc3 = hl.zeros([rd, rs], dtype=torch.float32)
-        for j in range(W):
-            c1 = w[rd, j].to(torch.float32)
-            x1 = hl.load(x_pad, [bi, rd, rs.index + j]).to(torch.float32)
-            acc1 = acc1 + x1 * c1[:, None]
-            c2 = w[rd, j].to(torch.float32)
-            x2 = hl.load(x_pad, [bi, rd, rs.index + j]).to(torch.float32)
-            acc2 = acc2 + x2 * c2[:, None]
-            c3 = w[rd, j].to(torch.float32)
-            x3 = hl.load(x_pad, [bi, rd, rs.index + j]).to(torch.float32)
-            acc3 = acc3 + x3 * c3[:, None]
-        acc = (acc1 + acc2 + acc3) / 3.0
-        acc = acc + b[rd].to(torch.float32)[:, None]
-        y[rb, rd, rs] = acc[None, :, :].to(y.dtype)
-
-    return y
+def _make_kernel(config: helion.Config):
+    @helion.kernel(static_shapes=True, config=config)
+    def kernel(
+        x_pad: torch.Tensor,  # (B, D, L) zero-padded input
+        w: torch.Tensor,  # (D, W) filter coefficients
+        b: torch.Tensor,  # (D,) additive offset
+    ) -> torch.Tensor:
+        B = x_pad.size(0)
+        D = x_pad.size(1)
+        L = x_pad.size(2)
+        W = hl.specialize(w.size(1))
+        N = L - W + 1
+
+        y = torch.empty(B, D, N, dtype=x_pad.dtype, device=x_pad.device)
+
+        for rb, rd, rs in hl.tile([B, D, N], block_size=[1, None, None]):
+            bi = rb.begin
+            acc1 = hl.zeros([rd, rs], dtype=torch.float32)
+            acc2 = hl.zeros([rd, rs], dtype=torch.float32)
+            acc3 = hl.zeros([rd, rs], dtype=torch.float32)
+            for j in range(W):
+                c1 = w[rd, j].to(torch.float32)
+                x1 = hl.load(x_pad, [bi, rd, rs.index + j]).to(torch.float32)
+                acc1 = acc1 + x1 * c1[:, None]
+                c2 = w[rd, j].to(torch.float32)
+                x2 = hl.load(x_pad, [bi, rd, rs.index + j]).to(torch.float32)
+                acc2 = acc2 + x2 * c2[:, None]
+                c3 = w[rd, j].to(torch.float32)
+                x3 = hl.load(x_pad, [bi, rd, rs.index + j]).to(torch.float32)
+                acc3 = acc3 + x3 * c3[:, None]
+            acc = (acc1 + acc2 + acc3) / 3.0
+            acc = acc + b[rd].to(torch.float32)[:, None]
+            y[rb, rd, rs] = acc[None, :, :].to(y.dtype)
+
+        return y
+
+    return kernel
+
+
+_KERNELS = {shape: _make_kernel(cfg) for shape, cfg in SHAPE_CONFIGS.items()}
 
 
 def custom_kernel(data: input_t) -> output_t:
     x, weight, bias = data
+    B, D, S = x.shape
     W = weight.shape[1]
-    pad_zeros = torch.zeros(x.shape[0], x.shape[1], W - 1, dtype=x.dtype, device=x.device)
+    kernel = _KERNELS[(B, D, S, W)]
+    pad_zeros = torch.zeros(B, D, W - 1, dtype=x.dtype, device=x.device)
     padded = torch.cat([pad_zeros, x], dim=2)
-    return conv1d_kernel(padded, weight, bias)
+    return kernel(padded, weight, bias)
diff --git a/problems/helion/fp8_quant_py/submission.py b/problems/helion/fp8_quant_py/submission.py
index e3108d51..4b562fa9 100644
--- a/problems/helion/fp8_quant_py/submission.py
+++ b/problems/helion/fp8_quant_py/submission.py
@@ -5,52 +5,68 @@ import helion.language as hl
 from pathlib import Path
 
-COFIG_DICT={
-    "block_sizes": [1],
-    "num_warps": 1,
-    "num_stages": 1,
+
+# Per-shape configs: map (num_tokens, hidden_dim, group_size) to optimized helion.Config objects.
+# Autotune locally for each shape, then paste the best config here.
+SHAPE_CONFIGS: dict[tuple, helion.Config] = {
+    # Test shapes
+    (1, 256, 64): helion.Config(block_sizes=[1], num_warps=1, num_stages=1),  # TODO: use any config that passes correctness check
+    (4, 512, 128): helion.Config(block_sizes=[1], num_warps=1, num_stages=1),  # TODO: use any config that passes correctness check
+    (16, 1024, 64): helion.Config(block_sizes=[1], num_warps=1, num_stages=1),  # TODO: use any config that passes correctness check
+    (1, 4096, 128): helion.Config(block_sizes=[1], num_warps=1, num_stages=1),  # TODO: use any config that passes correctness check
+    (8, 4096, 128): helion.Config(block_sizes=[1], num_warps=1, num_stages=1),  # TODO: use any config that passes correctness check
+    # Benchmark shapes
+    # (1, 4096, 128) already covered above
+    (16, 4096, 128): helion.Config(block_sizes=[1], num_warps=1, num_stages=1),  # TODO: replace with your autotuned config
+    (256, 4096, 128): helion.Config(block_sizes=[1], num_warps=1, num_stages=1),  # TODO: replace with your autotuned config
+    (256, 8192, 128): helion.Config(block_sizes=[1], num_warps=1, num_stages=1),  # TODO: replace with your autotuned config
+    (4096, 7168, 128): helion.Config(block_sizes=[1], num_warps=1, num_stages=1),  # TODO: replace with your autotuned config
 }
-ACF_FILE = "booster_pack/fp8_group_quant_0.acf"
-if Path(ACF_FILE).exists():
-    print(f"Using ACF file: {ACF_FILE}")
-    COFIG_DICT["advanced_controls_file"] = ACF_FILE
+
+# Optional: add advanced_controls_file to your Config for extra performance (see docs).
+# Autotune with autotune_search_acf to find the best ACF, then hardcode it:
+# helion.Config(..., advanced_controls_file="/opt/booster_pack/fp8_group_quant_0.acf")
+
 
 # NOTE: This is an intentionally inefficient baseline implementation.
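+# Shape arithmetic (worked example; assumes x is [T, H] and x_s is [T, G]):
+#   T=8, H=4096, group size gsz=128  ->  G = H // gsz = 32, N = T * G = 256
+#   Each of the N rows of length gsz gets scale = max(|row|) / 448.0, where
+#   448.0 is the largest finite float8_e4m3fn value.
+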
-@helion.kernel(
-    static_shapes=True,
-    config=helion.Config(**COFIG_DICT),
-)
-def normalize_to_range(
-    data: torch.Tensor,  # [N, G] input rows
-    scales_out: torch.Tensor,  # [N] output normalization factors
-) -> torch.Tensor:
-    nrows = data.size(0)
-    ncols = hl.specialize(data.size(1))
-    MAX_VAL = 448.0
-
-    qout = torch.empty(nrows, ncols, dtype=torch.float32, device=data.device)
-
-    for rr in hl.tile(nrows):
-        row = data[rr, :].to(torch.float32)
-
-        abs1 = torch.abs(row)
-        amax1 = torch.amax(abs1, -1)
-        abs2 = torch.abs(row)
-        amax2 = torch.amax(abs2, -1)
-        abs3 = torch.abs(row)
-        amax3 = torch.amax(abs3, -1)
-        amax = (amax1 + amax2 + amax3) / 3.0
-        amax = torch.clamp(amax, min=1e-10)
-        scale = amax / MAX_VAL
-
-        q1 = row / scale[:, None]
-        q2 = row / scale[:, None]
-        q3 = row / scale[:, None]
-        qout[rr, :] = (q1 + q2 + q3) / 3.0
-        scales_out[rr] = scale
-
-    return qout
+def _make_kernel(config: helion.Config):
+    @helion.kernel(static_shapes=True, config=config)
+    def kernel(
+        data: torch.Tensor,  # [N, G] input rows
+        scales_out: torch.Tensor,  # [N] output normalization factors
+    ) -> torch.Tensor:
+        nrows = data.size(0)
+        ncols = hl.specialize(data.size(1))
+        MAX_VAL = 448.0
+
+        qout = torch.empty(nrows, ncols, dtype=torch.float32, device=data.device)
+
+        for rr in hl.tile(nrows):
+            row = data[rr, :].to(torch.float32)
+
+            abs1 = torch.abs(row)
+            amax1 = torch.amax(abs1, -1)
+            abs2 = torch.abs(row)
+            amax2 = torch.amax(abs2, -1)
+            abs3 = torch.abs(row)
+            amax3 = torch.amax(abs3, -1)
+            amax = (amax1 + amax2 + amax3) / 3.0
+            amax = torch.clamp(amax, min=1e-10)
+            scale = amax / MAX_VAL
+
+            q1 = row / scale[:, None]
+            q2 = row / scale[:, None]
+            q3 = row / scale[:, None]
+            qout[rr, :] = (q1 + q2 + q3) / 3.0
+            scales_out[rr] = scale
+
+        return qout
+
+    return kernel
+
+
+_KERNELS = {shape: _make_kernel(cfg) for shape, cfg in SHAPE_CONFIGS.items()}
 
 
 def custom_kernel(data: input_t) -> output_t:
@@ -60,10 +76,12 @@ def custom_kernel(data: input_t) -> output_t:
     gsz = H // G
     N = T * G
 
+    kernel = _KERNELS[(T, H, gsz)]
+
     flat_in = x.reshape(N, gsz)
     flat_s = x_s.reshape(N)
 
-    flat_q = normalize_to_range(flat_in, flat_s)
+    flat_q = kernel(flat_in, flat_s)
 
     x_q[...] = flat_q.reshape(T, H)
     x_s[...] = flat_s.reshape(T, G)
diff --git a/problems/helion/gated_deltanet_chunk_fwd_h_py/submission.py b/problems/helion/gated_deltanet_chunk_fwd_h_py/submission.py
index 528a61cc..04e0ecfc 100644
--- a/problems/helion/gated_deltanet_chunk_fwd_h_py/submission.py
+++ b/problems/helion/gated_deltanet_chunk_fwd_h_py/submission.py
@@ -5,65 +5,93 @@ import helion.language as hl
 
+# Per-shape configs: map (B, T, H, K, V) to optimized helion.Config objects.
+# Autotune locally for each shape, then paste the best config here.
+SHAPE_CONFIGS: dict[tuple, helion.Config] = {
+    # Test shapes
+    (1, 64, 2, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1),  # TODO: use any config that passes correctness check
+    (2, 128, 4, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1),  # TODO: use any config that passes correctness check
+    (1, 256, 4, 64, 128): helion.Config(block_sizes=[], num_warps=1, num_stages=1),  # TODO: use any config that passes correctness check
+    # Benchmark shapes
+    (1, 64, 1, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1),  # TODO: replace with your autotuned config
+    (2, 512, 3, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1),  # TODO: replace with your autotuned config
+    (2, 1024, 3, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1),  # TODO: replace with your autotuned config
+    (3, 1024, 4, 100, 100): helion.Config(block_sizes=[], num_warps=1, num_stages=1),  # TODO: replace with your autotuned config
+    (4, 1024, 4, 128, 128): helion.Config(block_sizes=[], num_warps=1, num_stages=1),  # TODO: replace with your autotuned config
+    (2, 1536, 4, 128, 128): helion.Config(block_sizes=[], num_warps=1, num_stages=1),  # TODO: replace with your autotuned config
+    (4, 2048, 8, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1),  # TODO: replace with your autotuned config
+}
+
+
+# Optional: add advanced_controls_file to your Config for extra performance (see docs).
+# Autotune with autotune_search_acf to find the best ACF, then hardcode it:
+# helion.Config(..., advanced_controls_file="/opt/booster_pack/chunk_fwd_h_0.acf")
+
+
 # NOTE: This is an intentionally inefficient baseline implementation.
-@helion.kernel(
-    static_shapes=True,
-    dot_precision="ieee",
-    config=helion.Config(block_sizes=[], num_warps=1, num_stages=1),
-)
-def chunk_state_pass(
-    k: torch.Tensor,  # [B, T, H, K]
-    w: torch.Tensor,  # [B, T, H, K]
-    u: torch.Tensor,  # [B, T, H, V]
-    g: torch.Tensor,  # [B, T, H]
-) -> tuple[torch.Tensor, torch.Tensor]:
-    B, T, H, K = k.shape
-    V = u.shape[-1]
-    C = 64
-    K = hl.specialize(K)
-    V = hl.specialize(V)
+def _make_kernel(config: helion.Config):
+    @helion.kernel(static_shapes=True, dot_precision="ieee", config=config)
+    def kernel(
+        k: torch.Tensor,  # [B, T, H, K]
+        w: torch.Tensor,  # [B, T, H, K]
+        u: torch.Tensor,  # [B, T, H, V]
+        g: torch.Tensor,  # [B, T, H]
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        B, T, H, K = k.shape
+        V = u.shape[-1]
+        C = 64
+        K = hl.specialize(K)
+        V = hl.specialize(V)
 
-    NT = (T + C - 1) // C
-    h_out = torch.empty(B, NT, H, K, V, dtype=k.dtype, device=k.device)
-    v_out = torch.empty_like(u)
+        NT = (T + C - 1) // C
+        h_out = torch.empty(B, NT, H, K, V, dtype=k.dtype, device=k.device)
+        v_out = torch.empty_like(u)
 
-    BH = B * H
+        BH = B * H
 
-    for flat, tv in hl.tile([BH, V], block_size=[1, 8]):
-        b_idx = flat.begin // H
-        h_idx = flat.begin % H
-        state = hl.zeros([K, tv], dtype=torch.float32)
+        for flat, tv in hl.tile([BH, V], block_size=[1, 8]):
+            b_idx = flat.begin // H
+            h_idx = flat.begin % H
+            state = hl.zeros([K, tv], dtype=torch.float32)
 
-        for tc in hl.tile(T, block_size=C):
-            chunk_idx = tc.begin // C
-            t_end = min(tc.begin + C, T) - 1
+            for tc in hl.tile(T, block_size=C):
+                chunk_idx = tc.begin // C
+                t_end = min(tc.begin + C, T) - 1
 
-            h_out[b_idx, chunk_idx, h_idx, :, tv] = state.to(k.dtype)
+                h_out[b_idx, chunk_idx, h_idx, :, tv] = state.to(k.dtype)
 
-            proj1 = hl.dot(
-                w[b_idx, tc, h_idx, :], state, out_dtype=torch.float32
-            )
-            proj2 = hl.dot(
-                w[b_idx, tc, h_idx, :], state, out_dtype=torch.float32
-            )
-            proj = (proj1 + proj2) * 0.5
-            diff = u[b_idx, tc, h_idx, tv].to(torch.float32) - proj
-            v_out[b_idx, tc, h_idx, tv] = diff.to(u.dtype)
+                proj1 = hl.dot(
+                    w[b_idx, tc, h_idx, :], state, out_dtype=torch.float32
+                )
+                proj2 = hl.dot(
+                    w[b_idx, tc, h_idx, :], state, out_dtype=torch.float32
+                )
+                proj = (proj1 + proj2) * 0.5
+                diff = u[b_idx, tc, h_idx, tv].to(torch.float32) - proj
+                v_out[b_idx, tc, h_idx, tv] = diff.to(u.dtype)
 
-            g_end = g[b_idx, t_end, h_idx].to(torch.float32)
-            g_t = g[b_idx, tc, h_idx].to(torch.float32)
-            valid = tc.index < T
-            alpha = torch.where(valid, torch.exp(g_end - g_t), 0.0)
-            k_adj = k[b_idx, tc, h_idx, :] * alpha[:, None]
+                g_end = g[b_idx, t_end, h_idx].to(torch.float32)
+                g_t = g[b_idx, tc, h_idx].to(torch.float32)
+                valid = tc.index < T
+                alpha = torch.where(valid, torch.exp(g_end - g_t), 0.0)
+                k_adj = k[b_idx, tc, h_idx, :] * alpha[:, None]
 
-            state = state * torch.exp(g_end)
-            upd1 = hl.dot(k_adj.T, diff, out_dtype=torch.float32)
-            upd2 = hl.dot(k_adj.T, diff, out_dtype=torch.float32)
-            state = state + (upd1 + upd2) * 0.5
+                state = state * torch.exp(g_end)
+                upd1 = hl.dot(k_adj.T, diff, out_dtype=torch.float32)
+                upd2 = hl.dot(k_adj.T, diff, out_dtype=torch.float32)
+                state = state + (upd1 + upd2) * 0.5
 
-    return h_out, v_out
+        return h_out, v_out
+
+    return kernel
+
+
+_KERNELS = {shape: _make_kernel(cfg) for shape, cfg in SHAPE_CONFIGS.items()}
 
 
 def custom_kernel(data: input_t) -> output_t:
     k, w, u, g = data
-    return chunk_state_pass(k, w, u, g)
+    B, T, H, K = k.shape
+    V = u.shape[-1]
+    kernel = _KERNELS[(B, T, H, K, V)]
+    return kernel(k, w, u, g)
diff --git a/problems/helion/gated_deltanet_chunk_fwd_o_py/submission.py b/problems/helion/gated_deltanet_chunk_fwd_o_py/submission.py
index 8e2a2f53..0743521d 100644
--- a/problems/helion/gated_deltanet_chunk_fwd_o_py/submission.py
+++ b/problems/helion/gated_deltanet_chunk_fwd_o_py/submission.py
@@ -5,58 +5,86 @@ import helion.language as hl
 
+# Per-shape configs: map (B, T, H, K, V) to optimized helion.Config objects.
+# Autotune locally for each shape, then paste the best config here.
+SHAPE_CONFIGS: dict[tuple, helion.Config] = {
+    # Test shapes
+    (1, 64, 2, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1),  # TODO: use any config that passes correctness check
+    (2, 128, 4, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1),  # TODO: use any config that passes correctness check
+    (1, 256, 4, 64, 128): helion.Config(block_sizes=[], num_warps=1, num_stages=1),  # TODO: use any config that passes correctness check
+    # Benchmark shapes
+    (1, 64, 1, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1),  # TODO: replace with your autotuned config
+    (2, 512, 3, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1),  # TODO: replace with your autotuned config
+    (2, 1024, 3, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1),  # TODO: replace with your autotuned config
+    (3, 1024, 4, 100, 100): helion.Config(block_sizes=[], num_warps=1, num_stages=1),  # TODO: replace with your autotuned config
+    (4, 1024, 4, 128, 128): helion.Config(block_sizes=[], num_warps=1, num_stages=1),  # TODO: replace with your autotuned config
+    (2, 1536, 4, 128, 128): helion.Config(block_sizes=[], num_warps=1, num_stages=1),  # TODO: replace with your autotuned config
+    (4, 2048, 8, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1),  # TODO: replace with your autotuned config
+}
+
+
+# Optional: add advanced_controls_file to your Config for extra performance (see docs).
+# Autotune with autotune_search_acf to find the best ACF, then hardcode it:
+# helion.Config(..., advanced_controls_file="/opt/booster_pack/chunk_fwd_o_0.acf")
+
+
 # NOTE: This is an intentionally inefficient baseline implementation.
-@helion.kernel(
-    static_shapes=True,
-    dot_precision="ieee",
-    config=helion.Config(block_sizes=[], num_warps=1, num_stages=1),
-)
-def gated_chunk_attn(
-    q: torch.Tensor,  # [B, T, H, K]
-    k: torch.Tensor,  # [B, T, H, K]
-    v: torch.Tensor,  # [B, T, H, V]
-    h: torch.Tensor,  # [B, NT, H, K, V]
-    g: torch.Tensor,  # [B, T, H]
-    scale: float,
-) -> torch.Tensor:
-    B, T, H, K = q.shape
-    V = v.shape[-1]
-    C = 64
-    K = hl.specialize(K)
-    V = hl.specialize(V)
+def _make_kernel(config: helion.Config):
+    @helion.kernel(static_shapes=True, dot_precision="ieee", config=config)
+    def kernel(
+        q: torch.Tensor,  # [B, T, H, K]
+        k: torch.Tensor,  # [B, T, H, K]
+        v: torch.Tensor,  # [B, T, H, V]
+        h: torch.Tensor,  # [B, NT, H, K, V]
+        g: torch.Tensor,  # [B, T, H]
+        scale: float,
+    ) -> torch.Tensor:
+        B, T, H, K = q.shape
+        V = v.shape[-1]
+        C = 64
+        K = hl.specialize(K)
+        V = hl.specialize(V)
 
-    out = torch.empty_like(v)
+        out = torch.empty_like(v)
 
-    BH = B * H
-    for flat_bh, tile_t in hl.tile([BH, T], block_size=[1, C]):
-        b_idx = flat_bh.begin // H
-        h_idx = flat_bh.begin % H
-        c_idx = tile_t.begin // C
+        BH = B * H
+        for flat_bh, tile_t in hl.tile([BH, T], block_size=[1, C]):
+            b_idx = flat_bh.begin // H
+            h_idx = flat_bh.begin % H
+            c_idx = tile_t.begin // C
 
-        g_vals = g[b_idx, tile_t, h_idx]
-        q_s = q[b_idx, tile_t, h_idx, :] * torch.exp(g_vals)[:, None]
-        k_s = k[b_idx, tile_t, h_idx, :] * torch.exp(-g_vals)[:, None]
+            g_vals = g[b_idx, tile_t, h_idx]
+            q_s = q[b_idx, tile_t, h_idx, :] * torch.exp(g_vals)[:, None]
+            k_s = k[b_idx, tile_t, h_idx, :] * torch.exp(-g_vals)[:, None]
 
-        sim1 = hl.dot(q_s, k_s.T)
-        sim2 = hl.dot(q_s, k_s.T)
-        sim = (sim1 + sim2) * 0.5
-        idx = hl.arange(tile_t.block_size)
-        mask = idx[:, None] >= idx[None, :]
-        sim = torch.where(mask, sim, 0.0)
-        local1 = hl.dot(sim.to(v.dtype), v[b_idx, tile_t, h_idx, :])
-        local2 = hl.dot(sim.to(v.dtype), v[b_idx, tile_t, h_idx, :])
-        local_out = (local1 + local2) * 0.5
+            sim1 = hl.dot(q_s, k_s.T)
+            sim2 = hl.dot(q_s, k_s.T)
+            sim = (sim1 + sim2) * 0.5
+            idx = hl.arange(tile_t.block_size)
+            mask = idx[:, None] >= idx[None, :]
+            sim = torch.where(mask, sim, 0.0)
+            local1 = hl.dot(sim.to(v.dtype), v[b_idx, tile_t, h_idx, :])
+            local2 = hl.dot(sim.to(v.dtype), v[b_idx, tile_t, h_idx, :])
+            local_out = (local1 + local2) * 0.5
 
-        glob1 = hl.dot(q_s, h[b_idx, c_idx, h_idx, :, :])
-        glob2 = hl.dot(q_s, h[b_idx, c_idx, h_idx, :, :])
-        global_out = (glob1 + glob2) * 0.5
+            glob1 = hl.dot(q_s, h[b_idx, c_idx, h_idx, :, :])
+            glob2 = hl.dot(q_s, h[b_idx, c_idx, h_idx, :, :])
+            global_out = (glob1 + glob2) * 0.5
 
-        out[b_idx, tile_t, h_idx, :] = ((global_out + local_out) * scale).to(out.dtype)
+            out[b_idx, tile_t, h_idx, :] = ((global_out + local_out) * scale).to(out.dtype)
 
-    return out
+        return out
+
+    return kernel
+
+
+_KERNELS = {shape: _make_kernel(cfg) for shape, cfg in SHAPE_CONFIGS.items()}
 
 
 def custom_kernel(data: input_t) -> output_t:
     q, k, v_new, h, g = data
-    scale = q.shape[-1] ** -0.5
-    return gated_chunk_attn(q, k, v_new, h, g, scale)
+    B, T, H, K = q.shape
+    V = v_new.shape[-1]
+    scale = K ** -0.5
+    kernel = _KERNELS[(B, T, H, K, V)]
+    return kernel(q, k, v_new, h, g, scale)
diff --git a/problems/helion/gated_deltanet_recompute_w_u_py/submission.py b/problems/helion/gated_deltanet_recompute_w_u_py/submission.py
index 918f519a..07fb0691 100644
--- a/problems/helion/gated_deltanet_recompute_w_u_py/submission.py
+++ b/problems/helion/gated_deltanet_recompute_w_u_py/submission.py
@@ -5,68 +5,96 @@ import helion.language as hl
 
+# Per-shape configs: map (B, T, H, K, V) to optimized helion.Config objects.
+# Autotune locally for each shape, then paste the best config here.
+SHAPE_CONFIGS: dict[tuple, helion.Config] = {
+    # Test shapes
+    (1, 64, 2, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1),  # TODO: use any config that passes correctness check
+    (2, 128, 4, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1),  # TODO: use any config that passes correctness check
+    (1, 256, 4, 64, 128): helion.Config(block_sizes=[], num_warps=1, num_stages=1),  # TODO: use any config that passes correctness check
+    # Benchmark shapes
+    (1, 64, 1, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1),  # TODO: replace with your autotuned config
+    (2, 512, 3, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1),  # TODO: replace with your autotuned config
+    (2, 1024, 3, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1),  # TODO: replace with your autotuned config
+    (3, 1024, 4, 100, 100): helion.Config(block_sizes=[], num_warps=1, num_stages=1),  # TODO: replace with your autotuned config
+    (4, 1024, 4, 128, 128): helion.Config(block_sizes=[], num_warps=1, num_stages=1),  # TODO: replace with your autotuned config
+    (2, 1536, 4, 128, 128): helion.Config(block_sizes=[], num_warps=1, num_stages=1),  # TODO: replace with your autotuned config
+    (4, 2048, 8, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1),  # TODO: replace with your autotuned config
+}
+
+
+# Optional: add advanced_controls_file to your Config for extra performance (see docs).
+# Autotune with autotune_search_acf to find the best ACF, then hardcode it:
+# helion.Config(..., advanced_controls_file="/opt/booster_pack/recompute_w_u_fwd_0.acf")
+
+
 # NOTE: This is an intentionally inefficient baseline implementation.
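+# Dense reference for one (batch, head, chunk) slice of length C (a sketch that
+# should match the tiled loops below; A_c is [C, C], k_c is [C, K], v_c is [C, V],
+# beta_c and g_c are [C]):
+#
+#     w_c = A_c @ (k_c * (beta_c * g_c.exp())[:, None])   # -> [C, K]
+#     u_c = A_c @ (v_c * beta_c[:, None])                 # -> [C, V]
+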
-@helion.kernel(
-    static_shapes=True,
-    dot_precision="ieee",
-    config=helion.Config(block_sizes=[], num_warps=1, num_stages=1),
-)
-def project_kv(
-    k: torch.Tensor,  # [B, T, H, K]
-    v: torch.Tensor,  # [B, T, H, V]
-    beta: torch.Tensor,  # [B, T, H]
-    A: torch.Tensor,  # [B, T, H, BT]
-    g: torch.Tensor,  # [B, T, H]
-) -> tuple[torch.Tensor, torch.Tensor]:
-    B, T, H, K = k.shape
-    V = v.shape[-1]
-    C = hl.specialize(A.shape[-1])
-    K = hl.specialize(K)
-    V = hl.specialize(V)
-
-    w_out = torch.empty_like(k)
-    u_out = torch.empty_like(v)
-
-    BH = B * H
-    for flat_bh, rt in hl.tile([BH, T], block_size=[1, C]):
-        b_idx = flat_bh.begin // H
-        h_idx = flat_bh.begin % H
-
-        w_acc1 = hl.zeros([rt, K], dtype=torch.float32)
-        u_acc1 = hl.zeros([rt, V], dtype=torch.float32)
-        w_acc2 = hl.zeros([rt, K], dtype=torch.float32)
-        u_acc2 = hl.zeros([rt, V], dtype=torch.float32)
-
-        for ci in range(C):
-            t_ci = rt.begin + ci
-            a_col = A[b_idx, rt, h_idx, ci].to(torch.float32)
-            coeff_ci = beta[b_idx, t_ci, h_idx].to(torch.float32)
-            decay_ci = torch.exp(g[b_idx, t_ci, h_idx].to(torch.float32))
-
-            k_ci = k[b_idx, t_ci, h_idx, :].to(torch.float32)
-            v_ci = v[b_idx, t_ci, h_idx, :].to(torch.float32)
-
-            w_acc1 = w_acc1 + a_col[:, None] * (k_ci * coeff_ci * decay_ci)[None, :]
-            u_acc1 = u_acc1 + a_col[:, None] * (v_ci * coeff_ci)[None, :]
-
-        for ci in range(C - 1, -1, -1):
-            t_ci = rt.begin + ci
-            a_col = A[b_idx, rt, h_idx, ci].to(torch.float32)
-            coeff_ci = beta[b_idx, t_ci, h_idx].to(torch.float32)
-            decay_ci = torch.exp(g[b_idx, t_ci, h_idx].to(torch.float32))
-
-            k_ci = k[b_idx, t_ci, h_idx, :].to(torch.float32)
-            v_ci = v[b_idx, t_ci, h_idx, :].to(torch.float32)
-
-            w_acc2 = w_acc2 + a_col[:, None] * (k_ci * coeff_ci * decay_ci)[None, :]
-            u_acc2 = u_acc2 + a_col[:, None] * (v_ci * coeff_ci)[None, :]
-
-        w_out[b_idx, rt, h_idx, :] = ((w_acc1 + w_acc2) * 0.5).to(k.dtype)
-        u_out[b_idx, rt, h_idx, :] = ((u_acc1 + u_acc2) * 0.5).to(v.dtype)
-
-    return w_out, u_out
+def _make_kernel(config: helion.Config):
+    @helion.kernel(static_shapes=True, dot_precision="ieee", config=config)
+    def kernel(
+        k: torch.Tensor,  # [B, T, H, K]
+        v: torch.Tensor,  # [B, T, H, V]
+        beta: torch.Tensor,  # [B, T, H]
+        A: torch.Tensor,  # [B, T, H, BT]
+        g: torch.Tensor,  # [B, T, H]
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        B, T, H, K = k.shape
+        V = v.shape[-1]
+        C = hl.specialize(A.shape[-1])
+        K = hl.specialize(K)
+        V = hl.specialize(V)
+
+        w_out = torch.empty_like(k)
+        u_out = torch.empty_like(v)
+
+        BH = B * H
+        for flat_bh, rt in hl.tile([BH, T], block_size=[1, C]):
+            b_idx = flat_bh.begin // H
+            h_idx = flat_bh.begin % H
+
+            w_acc1 = hl.zeros([rt, K], dtype=torch.float32)
+            u_acc1 = hl.zeros([rt, V], dtype=torch.float32)
+            w_acc2 = hl.zeros([rt, K], dtype=torch.float32)
+            u_acc2 = hl.zeros([rt, V], dtype=torch.float32)
+
+            for ci in range(C):
+                t_ci = rt.begin + ci
+                a_col = A[b_idx, rt, h_idx, ci].to(torch.float32)
+                coeff_ci = beta[b_idx, t_ci, h_idx].to(torch.float32)
+                decay_ci = torch.exp(g[b_idx, t_ci, h_idx].to(torch.float32))
+
+                k_ci = k[b_idx, t_ci, h_idx, :].to(torch.float32)
+                v_ci = v[b_idx, t_ci, h_idx, :].to(torch.float32)
+
+                w_acc1 = w_acc1 + a_col[:, None] * (k_ci * coeff_ci * decay_ci)[None, :]
+                u_acc1 = u_acc1 + a_col[:, None] * (v_ci * coeff_ci)[None, :]
+
+            for ci in range(C - 1, -1, -1):
+                t_ci = rt.begin + ci
+                a_col = A[b_idx, rt, h_idx, ci].to(torch.float32)
+                coeff_ci = beta[b_idx, t_ci, h_idx].to(torch.float32)
+                decay_ci = torch.exp(g[b_idx, t_ci, h_idx].to(torch.float32))
+
+                k_ci = k[b_idx, t_ci, h_idx, :].to(torch.float32)
+                v_ci = v[b_idx, t_ci, h_idx, :].to(torch.float32)
+
+                w_acc2 = w_acc2 + a_col[:, None] * (k_ci * coeff_ci * decay_ci)[None, :]
+                u_acc2 = u_acc2 + a_col[:, None] * (v_ci * coeff_ci)[None, :]
+
+            w_out[b_idx, rt, h_idx, :] = ((w_acc1 + w_acc2) * 0.5).to(k.dtype)
+            u_out[b_idx, rt, h_idx, :] = ((u_acc1 + u_acc2) * 0.5).to(v.dtype)
+
+        return w_out, u_out
+
+    return kernel
+
+
+_KERNELS = {shape: _make_kernel(cfg) for shape, cfg in SHAPE_CONFIGS.items()}
 
 
 def custom_kernel(data: input_t) -> output_t:
     k, v, beta, A, g = data
-    return project_kv(k, v, beta, A, g)
+    B, T, H, K = k.shape
+    V = v.shape[-1]
+    kernel = _KERNELS[(B, T, H, K, V)]
+    return kernel(k, v, beta, A, g)
diff --git a/problems/helion/template.py b/problems/helion/template.py
index 4aec6a6c..37d04820 100644
--- a/problems/helion/template.py
+++ b/problems/helion/template.py
@@ -1,5 +1,31 @@
 from task import input_t, output_t
+import torch
+import helion
+import helion.language as hl
+
+
+# Per-shape configs: map input shape tuples to optimized helion.Config objects.
+# Autotune locally for each shape, then paste the best config here.
+# Include all test and benchmark shapes from task.yml.
+SHAPE_CONFIGS: dict[tuple, helion.Config] = {
+    # (shape_dim_1, shape_dim_2, ...): helion.Config(...),  # TODO: replace with your config
+}
+
+
+def _make_kernel(config: helion.Config):
+    @helion.kernel(static_shapes=True, config=config)
+    def kernel(*args: torch.Tensor) -> output_t:  # TODO: replace with your problem's real signature
+        # Your Helion kernel implementation
+        ...
+
+    return kernel
+
+
+_KERNELS = {shape: _make_kernel(cfg) for shape, cfg in SHAPE_CONFIGS.items()}
 
 
 def custom_kernel(data: input_t) -> output_t:
+    # Extract shape key from input tensors to select the right kernel
+    # shape_key = (...)
+    # kernel = _KERNELS[shape_key]
     pass
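+
+
+# Local autotuning sketch (assumption: when `config=` is omitted, Helion
+# autotunes on the first call and logs the winning config):
+#
+#     @helion.kernel(static_shapes=True)  # no config -> autotune on first call
+#     def kernel_to_tune(x: torch.Tensor) -> torch.Tensor:
+#         ...
+#
+#     kernel_to_tune(example_input)  # run once per shape, then copy the logged
+#     # helion.Config(...) into SHAPE_CONFIGS above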