problems/helion/causal_conv1d_py/submission.py (106 changes: 67 additions & 39 deletions)
@@ -5,49 +5,77 @@
import helion.language as hl


# Per-shape configs: map (B, D, S, W) to optimized helion.Config objects.
# Autotune locally for each shape, then paste the best config here.
SHAPE_CONFIGS: dict[tuple, helion.Config] = {
# Test shapes
(1, 64, 64, 4): helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1), # TODO: use any config that passes correctness check
(2, 128, 128, 4): helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1), # TODO: use any config that passes correctness check
(1, 256, 256, 3): helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1), # TODO: use any config that passes correctness check
(1, 128, 64, 8): helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1), # TODO: use any config that passes correctness check
(4, 64, 128, 4): helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1), # TODO: use any config that passes correctness check
# Benchmark shapes
(1, 768, 512, 4): helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1), # TODO: replace with your autotuned config
(1, 768, 2048, 4): helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1), # TODO: replace with your autotuned config
(1, 1536, 2048, 4): helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1), # TODO: replace with your autotuned config
(1, 2560, 2048, 4): helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1), # TODO: replace with your autotuned config
(1, 2560, 4096, 4): helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1), # TODO: replace with your autotuned config
}


# Optional: add advanced_controls_file to your Config for extra performance (see docs).
# Autotune with autotune_search_acf to find the best ACF, then hardcode it:
# helion.Config(..., advanced_controls_file="/opt/booster_pack/causal_conv_0.acf")


# NOTE: This is an intentionally inefficient baseline implementation.
@helion.kernel(
static_shapes=True,
config=helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1),
)
def conv1d_kernel(
x_pad: torch.Tensor, # (B, D, L) zero-padded input
w: torch.Tensor, # (D, W) filter coefficients
b: torch.Tensor, # (D,) additive offset
) -> torch.Tensor:
B = x_pad.size(0)
D = x_pad.size(1)
L = x_pad.size(2)
W = hl.specialize(w.size(1))
N = L - W + 1

y = torch.empty(B, D, N, dtype=x_pad.dtype, device=x_pad.device)

for rb, rd, rs in hl.tile([B, D, N], block_size=[1, None, None]):
bi = rb.begin
acc1 = hl.zeros([rd, rs], dtype=torch.float32)
acc2 = hl.zeros([rd, rs], dtype=torch.float32)
acc3 = hl.zeros([rd, rs], dtype=torch.float32)
for j in range(W):
c1 = w[rd, j].to(torch.float32)
x1 = hl.load(x_pad, [bi, rd, rs.index + j]).to(torch.float32)
acc1 = acc1 + x1 * c1[:, None]
c2 = w[rd, j].to(torch.float32)
x2 = hl.load(x_pad, [bi, rd, rs.index + j]).to(torch.float32)
acc2 = acc2 + x2 * c2[:, None]
c3 = w[rd, j].to(torch.float32)
x3 = hl.load(x_pad, [bi, rd, rs.index + j]).to(torch.float32)
acc3 = acc3 + x3 * c3[:, None]
acc = (acc1 + acc2 + acc3) / 3.0
acc = acc + b[rd].to(torch.float32)[:, None]
y[rb, rd, rs] = acc[None, :, :].to(y.dtype)

return y
def _make_kernel(config: helion.Config):
@helion.kernel(static_shapes=True, config=config)
def kernel(
x_pad: torch.Tensor, # (B, D, L) zero-padded input
w: torch.Tensor, # (D, W) filter coefficients
b: torch.Tensor, # (D,) additive offset
) -> torch.Tensor:
B = x_pad.size(0)
D = x_pad.size(1)
L = x_pad.size(2)
W = hl.specialize(w.size(1))
N = L - W + 1

y = torch.empty(B, D, N, dtype=x_pad.dtype, device=x_pad.device)

for rb, rd, rs in hl.tile([B, D, N], block_size=[1, None, None]):
bi = rb.begin
acc1 = hl.zeros([rd, rs], dtype=torch.float32)
acc2 = hl.zeros([rd, rs], dtype=torch.float32)
acc3 = hl.zeros([rd, rs], dtype=torch.float32)
for j in range(W):
c1 = w[rd, j].to(torch.float32)
x1 = hl.load(x_pad, [bi, rd, rs.index + j]).to(torch.float32)
acc1 = acc1 + x1 * c1[:, None]
c2 = w[rd, j].to(torch.float32)
x2 = hl.load(x_pad, [bi, rd, rs.index + j]).to(torch.float32)
acc2 = acc2 + x2 * c2[:, None]
c3 = w[rd, j].to(torch.float32)
x3 = hl.load(x_pad, [bi, rd, rs.index + j]).to(torch.float32)
acc3 = acc3 + x3 * c3[:, None]
acc = (acc1 + acc2 + acc3) / 3.0
acc = acc + b[rd].to(torch.float32)[:, None]
y[rb, rd, rs] = acc[None, :, :].to(y.dtype)

return y

return kernel


_KERNELS = {shape: _make_kernel(cfg) for shape, cfg in SHAPE_CONFIGS.items()}


def custom_kernel(data: input_t) -> output_t:
x, weight, bias = data
B, D, S = x.shape
W = weight.shape[1]
pad_zeros = torch.zeros(x.shape[0], x.shape[1], W - 1, dtype=x.dtype, device=x.device)
kernel = _KERNELS[(B, D, S, W)]
pad_zeros = torch.zeros(B, D, W - 1, dtype=x.dtype, device=x.device)
padded = torch.cat([pad_zeros, x], dim=2)
return conv1d_kernel(padded, weight, bias)
return kernel(padded, weight, bias)
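The dispatch above only covers shapes listed in SHAPE_CONFIGS, and each pasted config still has to pass the correctness check. One convenient local check is that this causal convolution is equivalent to a depthwise torch.nn.functional.conv1d with W - 1 zeros of left padding. The snippet below is a sketch for local testing only; custom_kernel and the (x, weight, bias) layout come from this file, while the float16 dtype, CUDA device, and tolerances are assumptions.

import torch
import torch.nn.functional as F

def check_shape(B, D, S, W, dtype=torch.float16, device="cuda"):
    # Random inputs in the (x, weight, bias) layout consumed by custom_kernel.
    x = torch.randn(B, D, S, dtype=dtype, device=device)
    weight = torch.randn(D, W, dtype=dtype, device=device)
    bias = torch.randn(D, dtype=dtype, device=device)

    got = custom_kernel((x, weight, bias))

    # Reference: depthwise (groups=D) cross-correlation with W - 1 zeros of
    # left padding, truncated back to S samples -- the same causal convolution
    # the Helion kernel computes.
    ref = F.conv1d(x.float(), weight.float().unsqueeze(1), bias.float(),
                   padding=W - 1, groups=D)[..., :S]
    torch.testing.assert_close(got.float(), ref, rtol=2e-2, atol=2e-2)

# Example: check_shape(1, 64, 64, 4) exercises the first test shape in SHAPE_CONFIGS.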
problems/helion/fp8_quant_py/submission.py (104 changes: 61 additions & 43 deletions)
@@ -5,52 +5,68 @@
import helion.language as hl
from pathlib import Path

COFIG_DICT={
"block_sizes": [1],
"num_warps": 1,
"num_stages": 1,

# Per-shape configs: map (num_tokens, hidden_dim, group_size) to optimized helion.Config objects.
# Autotune locally for each shape, then paste the best config here.
SHAPE_CONFIGS: dict[tuple, helion.Config] = {
# Test shapes
(1, 256, 64): helion.Config(block_sizes=[1], num_warps=1, num_stages=1), # TODO: use any config that passes correctness check
(4, 512, 128): helion.Config(block_sizes=[1], num_warps=1, num_stages=1), # TODO: use any config that passes correctness check
(16, 1024, 64): helion.Config(block_sizes=[1], num_warps=1, num_stages=1), # TODO: use any config that passes correctness check
(1, 4096, 128): helion.Config(block_sizes=[1], num_warps=1, num_stages=1), # TODO: use any config that passes correctness check
(8, 4096, 128): helion.Config(block_sizes=[1], num_warps=1, num_stages=1), # TODO: use any config that passes correctness check
# Benchmark shapes
# (1, 4096, 128) already covered above
(16, 4096, 128): helion.Config(block_sizes=[1], num_warps=1, num_stages=1), # TODO: replace with your autotuned config
(256, 4096, 128): helion.Config(block_sizes=[1], num_warps=1, num_stages=1), # TODO: replace with your autotuned config
(256, 8192, 128): helion.Config(block_sizes=[1], num_warps=1, num_stages=1), # TODO: replace with your autotuned config
(4096, 7168, 128): helion.Config(block_sizes=[1], num_warps=1, num_stages=1), # TODO: replace with your autotuned config
}

ACF_FILE = "booster_pack/fp8_group_quant_0.acf"
if Path(ACF_FILE).exists():
print(f"Using ACF file: {ACF_FILE}")
COFIG_DICT["advanced_controls_file"] = ACF_FILE

# Optional: add advanced_controls_file to your Config for extra performance (see docs).
# Autotune with autotune_search_acf to find the best ACF, then hardcode it:
# helion.Config(..., advanced_controls_file="/opt/booster_pack/fp8_group_quant_0.acf")


# NOTE: This is an intentionally inefficient baseline implementation.
@helion.kernel(
static_shapes=True,
config=helion.Config(**COFIG_DICT),
)
def normalize_to_range(
data: torch.Tensor, # [N, G] input rows
scales_out: torch.Tensor, # [N] output normalization factors
) -> torch.Tensor:
nrows = data.size(0)
ncols = hl.specialize(data.size(1))
MAX_VAL = 448.0

qout = torch.empty(nrows, ncols, dtype=torch.float32, device=data.device)

for rr in hl.tile(nrows):
row = data[rr, :].to(torch.float32)

abs1 = torch.abs(row)
amax1 = torch.amax(abs1, -1)
abs2 = torch.abs(row)
amax2 = torch.amax(abs2, -1)
abs3 = torch.abs(row)
amax3 = torch.amax(abs3, -1)
amax = (amax1 + amax2 + amax3) / 3.0
amax = torch.clamp(amax, min=1e-10)
scale = amax / MAX_VAL

q1 = row / scale[:, None]
q2 = row / scale[:, None]
q3 = row / scale[:, None]
qout[rr, :] = (q1 + q2 + q3) / 3.0
scales_out[rr] = scale

return qout
def _make_kernel(config: helion.Config):
@helion.kernel(static_shapes=True, config=config)
def kernel(
data: torch.Tensor, # [N, G] input rows
scales_out: torch.Tensor, # [N] output normalization factors
) -> torch.Tensor:
nrows = data.size(0)
ncols = hl.specialize(data.size(1))
MAX_VAL = 448.0

qout = torch.empty(nrows, ncols, dtype=torch.float32, device=data.device)

for rr in hl.tile(nrows):
row = data[rr, :].to(torch.float32)

abs1 = torch.abs(row)
amax1 = torch.amax(abs1, -1)
abs2 = torch.abs(row)
amax2 = torch.amax(abs2, -1)
abs3 = torch.abs(row)
amax3 = torch.amax(abs3, -1)
amax = (amax1 + amax2 + amax3) / 3.0
amax = torch.clamp(amax, min=1e-10)
scale = amax / MAX_VAL

q1 = row / scale[:, None]
q2 = row / scale[:, None]
q3 = row / scale[:, None]
qout[rr, :] = (q1 + q2 + q3) / 3.0
scales_out[rr] = scale

return qout

return kernel


_KERNELS = {shape: _make_kernel(cfg) for shape, cfg in SHAPE_CONFIGS.items()}


@@ -60,10 +76,12 @@ def custom_kernel(data: input_t) -> output_t:
gsz = H // G
N = T * G

kernel = _KERNELS[(T, H, gsz)]

flat_in = x.reshape(N, gsz)
flat_s = x_s.reshape(N)

flat_q = normalize_to_range(flat_in, flat_s)
flat_q = kernel(flat_in, flat_s)

x_q[...] = flat_q.reshape(T, H)
x_s[...] = flat_s.reshape(T, G)
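For reference, the kernel above amounts to per-group absmax scaling: custom_kernel flattens the (T, H) activations into T * G rows of gsz elements, and each row is divided by scale = clamp(absmax, 1e-10) / 448.0, with one scale written out per group. The plain-PyTorch restatement below is a sketch for sanity-checking a config locally; the 448.0 ceiling, the 1e-10 clamp, and the float32 intermediates come from the kernel, while the function name is an illustrative assumption.

import torch

def ref_group_quant(x: torch.Tensor, group_size: int, max_val: float = 448.0):
    # x: [T, H] with H divisible by group_size. Returns the scaled values
    # [T, H] and per-group scales [T, H // group_size], mirroring the kernel.
    T, H = x.shape
    G = H // group_size
    rows = x.reshape(T * G, group_size).to(torch.float32)
    amax = rows.abs().amax(dim=-1).clamp(min=1e-10)
    scale = amax / max_val
    q = rows / scale[:, None]
    return q.reshape(T, H), scale.reshape(T, G)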
problems/helion/gated_deltanet_chunk_fwd_h_py/submission.py (124 changes: 76 additions & 48 deletions)
@@ -5,65 +5,93 @@
import helion.language as hl


# Per-shape configs: map (B, T, H, K, V) to optimized helion.Config objects.
# Autotune locally for each shape, then paste the best config here.
SHAPE_CONFIGS: dict[tuple, helion.Config] = {
# Test shapes
(1, 64, 2, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: use any config that passes correctness check
(2, 128, 4, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: use any config that passes correctness check
(1, 256, 4, 64, 128): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: use any config that passes correctness check
# Benchmark shapes
(1, 64, 1, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: replace with your autotuned config
(2, 512, 3, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: replace with your autotuned config
(2, 1024, 3, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: replace with your autotuned config
(3, 1024, 4, 100, 100): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: replace with your autotuned config
(4, 1024, 4, 128, 128): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: replace with your autotuned config
(2, 1536, 4, 128, 128): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: replace with your autotuned config
(4, 2048, 8, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: replace with your autotuned config
}


# Optional: add advanced_controls_file to your Config for extra performance (see docs).
# Autotune with autotune_search_acf to find the best ACF, then hardcode it:
# helion.Config(..., advanced_controls_file="/opt/booster_pack/chunk_fwd_h_0.acf")


# NOTE: This is an intentionally inefficient baseline implementation.
@helion.kernel(
static_shapes=True,
dot_precision="ieee",
config=helion.Config(block_sizes=[], num_warps=1, num_stages=1),
)
def chunk_state_pass(
k: torch.Tensor, # [B, T, H, K]
w: torch.Tensor, # [B, T, H, K]
u: torch.Tensor, # [B, T, H, V]
g: torch.Tensor, # [B, T, H]
) -> tuple[torch.Tensor, torch.Tensor]:
B, T, H, K = k.shape
V = u.shape[-1]
C = 64
K = hl.specialize(K)
V = hl.specialize(V)
def _make_kernel(config: helion.Config):
@helion.kernel(static_shapes=True, dot_precision="ieee", config=config)
def kernel(
k: torch.Tensor, # [B, T, H, K]
w: torch.Tensor, # [B, T, H, K]
u: torch.Tensor, # [B, T, H, V]
g: torch.Tensor, # [B, T, H]
) -> tuple[torch.Tensor, torch.Tensor]:
B, T, H, K = k.shape
V = u.shape[-1]
C = 64
K = hl.specialize(K)
V = hl.specialize(V)

NT = (T + C - 1) // C
h_out = torch.empty(B, NT, H, K, V, dtype=k.dtype, device=k.device)
v_out = torch.empty_like(u)
NT = (T + C - 1) // C
h_out = torch.empty(B, NT, H, K, V, dtype=k.dtype, device=k.device)
v_out = torch.empty_like(u)

BH = B * H
BH = B * H

for flat, tv in hl.tile([BH, V], block_size=[1, 8]):
b_idx = flat.begin // H
h_idx = flat.begin % H
state = hl.zeros([K, tv], dtype=torch.float32)
for flat, tv in hl.tile([BH, V], block_size=[1, 8]):
b_idx = flat.begin // H
h_idx = flat.begin % H
state = hl.zeros([K, tv], dtype=torch.float32)

for tc in hl.tile(T, block_size=C):
chunk_idx = tc.begin // C
t_end = min(tc.begin + C, T) - 1
for tc in hl.tile(T, block_size=C):
chunk_idx = tc.begin // C
t_end = min(tc.begin + C, T) - 1

h_out[b_idx, chunk_idx, h_idx, :, tv] = state.to(k.dtype)
h_out[b_idx, chunk_idx, h_idx, :, tv] = state.to(k.dtype)

proj1 = hl.dot(
w[b_idx, tc, h_idx, :], state, out_dtype=torch.float32
)
proj2 = hl.dot(
w[b_idx, tc, h_idx, :], state, out_dtype=torch.float32
)
proj = (proj1 + proj2) * 0.5
diff = u[b_idx, tc, h_idx, tv].to(torch.float32) - proj
v_out[b_idx, tc, h_idx, tv] = diff.to(u.dtype)
proj1 = hl.dot(
w[b_idx, tc, h_idx, :], state, out_dtype=torch.float32
)
proj2 = hl.dot(
w[b_idx, tc, h_idx, :], state, out_dtype=torch.float32
)
proj = (proj1 + proj2) * 0.5
diff = u[b_idx, tc, h_idx, tv].to(torch.float32) - proj
v_out[b_idx, tc, h_idx, tv] = diff.to(u.dtype)

g_end = g[b_idx, t_end, h_idx].to(torch.float32)
g_t = g[b_idx, tc, h_idx].to(torch.float32)
valid = tc.index < T
alpha = torch.where(valid, torch.exp(g_end - g_t), 0.0)
k_adj = k[b_idx, tc, h_idx, :] * alpha[:, None]
g_end = g[b_idx, t_end, h_idx].to(torch.float32)
g_t = g[b_idx, tc, h_idx].to(torch.float32)
valid = tc.index < T
alpha = torch.where(valid, torch.exp(g_end - g_t), 0.0)
k_adj = k[b_idx, tc, h_idx, :] * alpha[:, None]

state = state * torch.exp(g_end)
upd1 = hl.dot(k_adj.T, diff, out_dtype=torch.float32)
upd2 = hl.dot(k_adj.T, diff, out_dtype=torch.float32)
state = state + (upd1 + upd2) * 0.5
state = state * torch.exp(g_end)
upd1 = hl.dot(k_adj.T, diff, out_dtype=torch.float32)
upd2 = hl.dot(k_adj.T, diff, out_dtype=torch.float32)
state = state + (upd1 + upd2) * 0.5

return h_out, v_out
return h_out, v_out

return kernel


_KERNELS = {shape: _make_kernel(cfg) for shape, cfg in SHAPE_CONFIGS.items()}


def custom_kernel(data: input_t) -> output_t:
k, w, u, g = data
return chunk_state_pass(k, w, u, g)
B, T, H, K = k.shape
V = u.shape[-1]
kernel = _KERNELS[(B, T, H, K, V)]
return kernel(k, w, u, g)
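To make the recurrence explicit: for each (batch, head) pair the kernel walks T in chunks of C = 64 timesteps, records the incoming state in h_out, emits v_out = u - w @ state, then decays the state by exp(g_end) and adds k_adj^T @ (u - w @ state), where k_adj rescales k by exp(g_end - g_t). The loop below is a slow pure-PyTorch restatement of that math, useful only as a readability aid or a correctness oracle on tiny shapes; the function name and the use of torch.zeros for h_out are illustrative choices, not part of the submission.

import torch

def ref_chunk_fwd_h(k, w, u, g, C=64):
    # k, w: [B, T, H, K]; u: [B, T, H, V]; g: [B, T, H] (shapes follow the
    # kernel comments above). Returns (h_out [B, NT, H, K, V], v_out like u).
    B, T, H, K = k.shape
    V = u.shape[-1]
    NT = (T + C - 1) // C
    h_out = torch.zeros(B, NT, H, K, V, dtype=k.dtype, device=k.device)
    v_out = torch.empty_like(u)
    for b in range(B):
        for h in range(H):
            state = torch.zeros(K, V, dtype=torch.float32, device=k.device)
            for c in range(NT):
                s, e = c * C, min((c + 1) * C, T)
                h_out[b, c, h] = state.to(k.dtype)   # state entering the chunk
                w_c = w[b, s:e, h].float()           # [Tc, K]
                u_c = u[b, s:e, h].float()           # [Tc, V]
                g_c = g[b, s:e, h].float()           # [Tc]
                diff = u_c - w_c @ state             # [Tc, V]
                v_out[b, s:e, h] = diff.to(u.dtype)
                g_end = g_c[-1]                      # gate at the chunk's last step
                k_adj = k[b, s:e, h].float() * torch.exp(g_end - g_c)[:, None]
                state = state * torch.exp(g_end) + k_adj.T @ diff
    return h_out, v_out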