diff --git a/problems/helion.yaml b/problems/helion.yaml new file mode 100644 index 00000000..32f9b0e8 --- /dev/null +++ b/problems/helion.yaml @@ -0,0 +1,29 @@ +name: Helion Kernel Challenge +deadline: "2026-03-14" +description: "GPU kernel challenges inspired by Helion kernel ideas — convolution, quantization, and gated deltanet operators from production LLM architectures." +problems: + - directory: helion/causal_conv1d_py + name: causal_conv1d + deadline: "2026-03-14 00:00" + gpus: + - NVIDIA + - directory: helion/fp8_quant_py + name: fp8_quant + deadline: "2026-03-14 00:00" + gpus: + - NVIDIA + - directory: helion/gated_deltanet_chunk_fwd_h_py + name: gated_deltanet_chunk_fwd_h + deadline: "2026-03-14 00:00" + gpus: + - NVIDIA + - directory: helion/gated_deltanet_chunk_fwd_o_py + name: gated_deltanet_chunk_fwd_o + deadline: "2026-03-14 00:00" + gpus: + - NVIDIA + - directory: helion/gated_deltanet_recompute_w_u_py + name: gated_deltanet_recompute_w_u + deadline: "2026-03-14 00:00" + gpus: + - NVIDIA diff --git a/problems/helion/causal_conv1d_py/reference.py b/problems/helion/causal_conv1d_py/reference.py new file mode 100644 index 00000000..e132fbf5 --- /dev/null +++ b/problems/helion/causal_conv1d_py/reference.py @@ -0,0 +1,35 @@ +import torch +import torch.nn.functional as F +from task import input_t, output_t +from utils import make_match_reference, DeterministicContext + + +def generate_input(B: int, D: int, S: int, W: int, seed: int) -> input_t: + gen = torch.Generator(device="cuda") + gen.manual_seed(seed) + x = torch.randn(B, D, S, dtype=torch.float32, device="cuda", generator=gen).contiguous() + weight = torch.randn(D, W, dtype=torch.float32, device="cuda", generator=gen).contiguous() + bias = torch.randn(D, dtype=torch.float32, device="cuda", generator=gen).contiguous() + return x, weight, bias + + +def ref_kernel(data: input_t) -> output_t: + with DeterministicContext(): + x, weight, bias = data + B, D, S = x.shape + W = weight.shape[1] + + # Causal 
(left) padding + x_padded = F.pad(x, (W - 1, 0)) + + # Depthwise conv1d (groups=D) + output = F.conv1d( + x_padded, + weight.unsqueeze(1), # [D, 1, W] + bias=bias, + groups=D, + ) + return output + + +check_implementation = make_match_reference(ref_kernel, rtol=1e-4, atol=1e-4) diff --git a/problems/helion/causal_conv1d_py/submission.py b/problems/helion/causal_conv1d_py/submission.py new file mode 100644 index 00000000..ba89f5ad --- /dev/null +++ b/problems/helion/causal_conv1d_py/submission.py @@ -0,0 +1,14 @@ +from task import input_t, output_t + + +def custom_kernel(data: input_t) -> output_t: + import torch + import torch.nn.functional as F + + x, weight, bias = data + W = weight.shape[1] + D = x.shape[1] + + x_padded = F.pad(x, (W - 1, 0)) + output = F.conv1d(x_padded, weight.unsqueeze(1), bias=bias, groups=D) + return output diff --git a/problems/helion/causal_conv1d_py/task.py b/problems/helion/causal_conv1d_py/task.py new file mode 100644 index 00000000..00a02fe6 --- /dev/null +++ b/problems/helion/causal_conv1d_py/task.py @@ -0,0 +1,12 @@ +from typing import TypedDict, TypeVar +import torch + +input_t = TypeVar("input_t", bound=tuple[torch.Tensor, torch.Tensor, torch.Tensor]) +output_t = TypeVar("output_t", bound=torch.Tensor) + +class TestSpec(TypedDict): + B: int + D: int + S: int + W: int + seed: int diff --git a/problems/helion/causal_conv1d_py/task.yml b/problems/helion/causal_conv1d_py/task.yml new file mode 100644 index 00000000..8ef81809 --- /dev/null +++ b/problems/helion/causal_conv1d_py/task.yml @@ -0,0 +1,51 @@ +files: + - {"name": "submission.py", "source": "@SUBMISSION@"} + - {"name": "task.py", "source": "task.py"} + - {"name": "utils.py", "source": "../utils.py"} + - {"name": "reference.py", "source": "reference.py"} + - {"name": "eval.py", "source": "../eval.py"} + +lang: "py" + +description: | + Implement a causal depthwise 1D convolution kernel. + + This is a core component of Mamba/Mamba-2 architectures. 
Each channel is + convolved independently (depthwise) with causal (left) zero-padding so that + output[t] depends only on input[t-W+1:t+1]. + + For each batch b, channel d, and time t: + out[b, d, t] = bias[d] + sum_{k=0}^{W-1} weight[d, k] * x[b, d, t - W + 1 + k] + where out-of-bounds values are treated as zero. + + Input: tuple(x, weight, bias) where: + - x: torch.Tensor of shape [B, D, S] (float32) + - weight: torch.Tensor of shape [D, W] (float32) + - bias: torch.Tensor of shape [D] (float32) + + Output: torch.Tensor of shape [B, D, S] (float32) + +config: + main: "eval.py" + +templates: + Python: "../template.py" + +tests: + - {"B": 1, "D": 64, "S": 64, "W": 4, "seed": 4242} + - {"B": 2, "D": 128, "S": 128, "W": 4, "seed": 5236} + - {"B": 1, "D": 256, "S": 256, "W": 3, "seed": 1001} + - {"B": 1, "D": 128, "S": 64, "W": 8, "seed": 5531} + - {"B": 4, "D": 64, "S": 128, "W": 4, "seed": 9173} + +benchmarks: + - {"B": 1, "D": 768, "S": 512, "W": 4, "seed": 31232} + - {"B": 1, "D": 768, "S": 2048, "W": 4, "seed": 4052} + - {"B": 1, "D": 1536, "S": 2048, "W": 4, "seed": 2146} + - {"B": 1, "D": 2560, "S": 2048, "W": 4, "seed": 3129} + - {"B": 1, "D": 2560, "S": 4096, "W": 4, "seed": 54352} + +test_timeout: 180 +benchmark_timeout: 180 +ranked_timeout: 420 +ranking_by: "geom" diff --git a/problems/helion/eval.py b/problems/helion/eval.py new file mode 100644 index 00000000..f99aba46 --- /dev/null +++ b/problems/helion/eval.py @@ -0,0 +1,378 @@ +import base64 +import dataclasses +import multiprocessing +import re +import time +import os +import sys +import math +from pathlib import Path +from typing import Any, Optional + +import torch.cuda + +from utils import set_seed, clear_l2_cache +try: + from task import TestSpec +except ImportError: + TestSpec = dict + +from reference import check_implementation, generate_input + + +class PopcornOutput: + def __init__(self, fd: int): + self.file = os.fdopen(fd, 'w') + os.set_inheritable(fd, False) + + def __enter__(self): + 
return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.file.close() + + def print(self, *args, **kwargs): + print(*args, **kwargs, file=self.file, flush=True) + + def log(self, key, value): + self.print(f"{key}: {value}") + + +@dataclasses.dataclass +class TestCase: + args: dict + spec: str + + +def _combine(a: int, b: int) -> int: + # combine two integers into one: + # we need this to generate a secret seed based on the test-level seed and + # the global secret seed. + # the test-level seeds are public knowledge, and typically relatively small numbers, + # so we need to make sure they don't provide any useful info for the full seed. + # This Cantor construction ensures that if the secret seed is a large number, + # then so is the overall seed. + return int(a + (a+b)*(a+b+1)//2) + + +def get_test_cases(file_name: str, seed: Optional[int]) -> list[TestCase]: + try: + content = Path(file_name).read_text() + except Exception as E: + print(f"Could not open test file`{file_name}`: {E}", file=sys.stderr) + exit(113) + + tests = [] + lines = content.splitlines() + match = r"\s*([a-zA-Z_]\w*):\s*([a-zA-Z_]\w*|[+-]?[0-9]+)\s*" + for line in lines: + parts = line.split(";") + case = {} + for part in parts: + matched = re.match(match, part) + if not re.fullmatch(match, part): + print(f"invalid test case: '{line}': '{part}'", file=sys.stderr) + exit(113) + key = matched[1] + val = matched[2] + try: + val = int(val) + except ValueError: + if val == "true": + val = True + elif val == "false": + val = False + + case[key] = val + tests.append(TestCase(spec=line, args=case)) + + if seed is not None: + for test in tests: + if "seed" in test.args: + test.args["seed"] = _combine(test.args["seed"], seed) + + return tests + + +@dataclasses.dataclass +class Stats: + runs: int + mean: float + std: float + err: float + best: float + worst: float + + +def calculate_stats(durations: list[int]): + """ + Calculate statistical data from a list of durations. 
+ + @param durations: A list of durations in nanoseconds. + @return: A Stats object containing the number of runs, mean, standard deviation, error, best, and worst durations. + """ + runs = len(durations) + total = sum(durations) + best = min(durations) + worst = max(durations) + + avg = total / runs + variance = sum(map(lambda x: (x - avg)**2, durations)) + std = math.sqrt(variance / (runs - 1)) + err = std / math.sqrt(runs) + + return Stats(runs=runs, mean=avg, std=std, err=err, best=float(best), + worst=float(worst)) + + +def _clone_data(data): + """ + Recursively goes through data and clones all tensors. + """ + if isinstance(data, tuple): + return tuple(_clone_data(x) for x in data) + elif isinstance(data, list): + return [_clone_data(x) for x in data] + elif isinstance(data, dict): + return {k: _clone_data(v) for k, v in data.items()} + elif isinstance(data, torch.Tensor): + return data.clone() + else: + return data + + +def _run_single_test(test: TestCase): + """ + Runs a single test case. Do not call directly + """ + from submission import custom_kernel + data = generate_input(**test.args) + torch.cuda.synchronize() + submission_output = custom_kernel(_clone_data(data)) + torch.cuda.synchronize() + return check_implementation(data, submission_output) + + +def run_single_test(pool: multiprocessing.Pool, test: TestCase): + """ + Runs a single test in another process. + """ + return pool.apply(_run_single_test, (test,)) + + +def run_testing(logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase]): + """ + Executes the actual test case code and checks for correctness. + + @param logger: A PopcornOutput object used for logging test results. + @param tests: A list of TestCase objects representing the test cases to be executed. + @return: An integer representing the exit status: 0 if all tests pass, otherwise 112. 
+ """ + passed = True + logger.log("test-count", len(tests)) + for idx, test in enumerate(tests): + logger.log(f"test.{idx}.spec", test.spec) + good, message = run_single_test(pool, test) + if not good: + logger.log(f"test.{idx}.status", "fail") + logger.log(f"test.{idx}.error", message) + passed = False + else: + logger.log(f"test.{idx}.status", "pass") + if message: + logger.log(f"test.{idx}.message", message) + + if passed: + logger.log("check", "pass") + return 0 + else: + logger.log("check", "fail") + return 112 + + +def _run_single_benchmark(test: TestCase, recheck: bool, max_repeats: int, max_time_ns: float) -> Stats | Any: + """ + Runs one benchmark. Do not call directly. + """ + from submission import custom_kernel + + durations = [] + # generate input data once + data = generate_input(**test.args) + check_copy = _clone_data(data) + # first, one obligatory correctness check + output = custom_kernel(data) + good, message = check_implementation(check_copy, output) + if not good: + return message + + # now, do multiple timing runs without further correctness testing + # there is an upper bound of 100 runs, and a lower bound of 3 runs; + # otherwise, we repeat until we either measure at least 10 full seconds, + # or the relative error of the mean is below 1%. 
+ + bm_start_time = time.perf_counter_ns() + for i in range(max_repeats): + if recheck: + # ensure we use a different seed for every benchmark + if "seed" in test.args: + test.args["seed"] += 13 + + data = generate_input(**test.args) + check_copy = _clone_data(data) + torch.cuda.synchronize() + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + clear_l2_cache() + + start_event.record() + output = custom_kernel(data) + end_event.record() + torch.cuda.synchronize() + duration = start_event.elapsed_time(end_event) * 1e6 # Convert ms to ns + + if recheck: + good, message = check_implementation(check_copy, output) + if not good: + return message + + del output + durations.append(duration) + + if i > 1: + total_bm_duration = time.perf_counter_ns() - bm_start_time + stats = calculate_stats(durations) + # stop if either + # a) relative error dips below 0.1% + # b) we exceed the total time limit for benchmarking the kernel + # c) we exceed 2 minutes of total wallclock time. + if stats.err / stats.mean < 0.001 or stats.mean * stats.runs > max_time_ns or total_bm_duration > 120e9: + break + + return calculate_stats(durations) + + +def run_single_benchmark(pool: multiprocessing.Pool, test: TestCase, recheck: bool, max_repeats: int, + max_time_ns: float): + """ + For a particular test case, check correctness (if applicable) and grab runtime results. + + @param pool: Process on which the benchmark will be launched. + @param test: TestCase object. + @param recheck: Flag for whether to explicitly check functional correctness. + @param max_repeats: Number of trials to repeat. + @param max_time_ns: Timeout time in nanoseconds. + @return: A Stats object for this particular benchmark case or an error if the test fails. 
+ """ + return pool.apply(_run_single_benchmark, (test, recheck, max_repeats, max_time_ns)) + + +def run_benchmarking(logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase]): + """ + Executes benchmarking code for a CUDA Kernel and logs runtimes. + + @param logger: A PopcornOutput object used for logging benchmark results. + @param pool: Process on which the benchmarks will be launched. + @param tests: A list of TestCase objects representing the test cases to be benchmarked. + @return: An integer representing the exit status: 0 if all benchmarks pass, otherwise 112. + """ + # warm up + run_single_benchmark(pool, tests[0], False, 100, 10e7) + + passed = True + logger.log("benchmark-count", len(tests)) + for idx, test in enumerate(tests): + logger.log(f"benchmark.{idx}.spec", test.spec) + result = run_single_benchmark(pool, test, False, 100, 10e9) + if isinstance(result, Stats): + for field in dataclasses.fields(Stats): + logger.log(f"benchmark.{idx}.{field.name}", getattr(result, field.name)) + else: + passed = False + logger.log(f"benchmark.{idx}.status", "fail") + logger.log(f"benchmark.{idx}.error", result) + + if passed: + logger.log("check", "pass") + return 0 + else: + logger.log("check", "fail") + return 112 + + +def run_single_profile(test: TestCase) -> str: + """ + Runs a single test case. 
Do not call directly + """ + from submission import custom_kernel + from torch.profiler import profile, record_function, ProfilerActivity + data = generate_input(**test.args) + torch.cuda.synchronize() + + with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: + submission_output = custom_kernel(_clone_data(data)) + torch.cuda.synchronize() + return prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=20) + + +def run_profiling(logger: PopcornOutput, tests: list[TestCase]): + logger.log("benchmark-count", len(tests)) + for idx, test in enumerate(tests): + logger.log(f"benchmark.{idx}.spec", test.spec) + report = run_single_profile(test) + logger.log(f"benchmark.{idx}.report", base64.b64encode(report.encode("utf-8"), b"+*").decode("utf-8")) + logger.log("check", "pass") + return 0 + + +def main(): + fd = os.getenv("POPCORN_FD") + if not fd: + return 111 + + if len(sys.argv) < 3: + return 2 + + mode = sys.argv[1] + seed = os.getenv("POPCORN_SEED") + os.unsetenv("POPCORN_SEED") + seed = int(seed) if seed else None + set_seed(seed or 42) + tests = get_test_cases(sys.argv[2], seed) + + with PopcornOutput(int(fd)) as logger: + import multiprocessing + mp_context = multiprocessing.get_context('spawn') + with mp_context.Pool(1) as pool: + if mode == "test": + return run_testing(logger, pool, tests) + if mode == "benchmark": + return run_benchmarking(logger, pool, tests) + + if mode == "leaderboard": + # warmup + run_single_benchmark(pool, tests[0], False, 100, 1e7) + logger.log("benchmark-count", len(tests)) + passed = True + for i in range(len(tests)): + result = run_single_benchmark(pool, tests[i], True, 100, 30e9) + logger.log(f"benchmark.{i}.spec", tests[i].spec) + if isinstance(result, Stats): + for field in dataclasses.fields(Stats): + logger.log(f"benchmark.{i}.{field.name}", getattr(result, field.name)) + else: + passed = False + logger.log(f"benchmark.{i}.status", "fail") + logger.log(f"benchmark.{i}.error", str(result)) # 
TODO: Make sure result implements __str__? + break + + logger.log("check", "pass" if passed else "fail") + elif mode == "profile": + run_profiling(logger, tests) + else: + # TODO: Implement script mode + return 2 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/problems/helion/fp8_quant_py/reference.py b/problems/helion/fp8_quant_py/reference.py new file mode 100644 index 00000000..b8792248 --- /dev/null +++ b/problems/helion/fp8_quant_py/reference.py @@ -0,0 +1,59 @@ +import torch +from task import input_t, output_t +from utils import verbose_allclose + +FP8_MAX = 448.0 +FP8_MIN = -448.0 +FP8_EPS = 1e-10 + + +def generate_input(num_tokens: int, hidden_dim: int, group_size: int, seed: int) -> input_t: + gen = torch.Generator(device="cuda") + gen.manual_seed(seed) + x = torch.randn(num_tokens, hidden_dim, dtype=torch.float32, device="cuda", generator=gen).contiguous() + x_q = torch.empty(num_tokens, hidden_dim, dtype=torch.float32, device="cuda").contiguous() + x_s = torch.empty(num_tokens, hidden_dim // group_size, dtype=torch.float32, device="cuda").contiguous() + return x, x_q, x_s + + +def ref_kernel(data: input_t) -> output_t: + x, x_q, x_s = data + num_tokens, hidden_dim = x.shape + num_groups = x_s.shape[1] + group_size = hidden_dim // num_groups + + x_f32 = x.float() + x_grouped = x_f32.reshape(num_tokens, num_groups, group_size) + + # Per-group absmax + absmax = x_grouped.abs().amax(dim=-1).clamp(min=FP8_EPS) + + # Scale = absmax / fp8_max + scale = absmax / FP8_MAX + + # Quantize + quantized = (x_grouped / scale.unsqueeze(-1)).clamp(FP8_MIN, FP8_MAX) + quantized = quantized.reshape(num_tokens, hidden_dim) + + x_q[...] = quantized + x_s[...] 
= scale + return x_q, x_s + + +def check_implementation(data, output): + expected = ref_kernel(data) + expected_q, expected_s = expected + received_q, received_s = output + + reasons_q = verbose_allclose(received_q, expected_q, rtol=1e-3, atol=1e-3) + reasons_s = verbose_allclose(received_s, expected_s, rtol=1e-4, atol=1e-6) + + reasons = [] + if reasons_q: + reasons.append("quantized values mismatch: " + " ".join(reasons_q)) + if reasons_s: + reasons.append("scales mismatch: " + " ".join(reasons_s)) + + if reasons: + return False, " | ".join(reasons) + return True, "" diff --git a/problems/helion/fp8_quant_py/submission.py b/problems/helion/fp8_quant_py/submission.py new file mode 100644 index 00000000..39cf1d08 --- /dev/null +++ b/problems/helion/fp8_quant_py/submission.py @@ -0,0 +1,25 @@ +from task import input_t, output_t + + +FP8_MAX = 448.0 +FP8_MIN = -448.0 +FP8_EPS = 1e-10 + + +def custom_kernel(data: input_t) -> output_t: + x, x_q, x_s = data + num_tokens, hidden_dim = x.shape + num_groups = x_s.shape[1] + group_size = hidden_dim // num_groups + + x_f32 = x.float() + x_grouped = x_f32.reshape(num_tokens, num_groups, group_size) + + absmax = x_grouped.abs().amax(dim=-1).clamp(min=FP8_EPS) + scale = absmax / FP8_MAX + quantized = (x_grouped / scale.unsqueeze(-1)).clamp(FP8_MIN, FP8_MAX) + quantized = quantized.reshape(num_tokens, hidden_dim) + + x_q[...] = quantized + x_s[...] 
= scale + return x_q, x_s diff --git a/problems/helion/fp8_quant_py/task.py b/problems/helion/fp8_quant_py/task.py new file mode 100644 index 00000000..8fb6c1f0 --- /dev/null +++ b/problems/helion/fp8_quant_py/task.py @@ -0,0 +1,11 @@ +from typing import TypedDict, TypeVar +import torch + +input_t = TypeVar("input_t", bound=tuple[torch.Tensor, torch.Tensor, torch.Tensor]) +output_t = TypeVar("output_t", bound=tuple[torch.Tensor, torch.Tensor]) + +class TestSpec(TypedDict): + num_tokens: int + hidden_dim: int + group_size: int + seed: int diff --git a/problems/helion/fp8_quant_py/task.yml b/problems/helion/fp8_quant_py/task.yml new file mode 100644 index 00000000..d8288ead --- /dev/null +++ b/problems/helion/fp8_quant_py/task.yml @@ -0,0 +1,58 @@ +files: + - {"name": "submission.py", "source": "@SUBMISSION@"} + - {"name": "task.py", "source": "task.py"} + - {"name": "utils.py", "source": "../utils.py"} + - {"name": "reference.py", "source": "reference.py"} + - {"name": "eval.py", "source": "../eval.py"} + +lang: "py" + +description: | + Implement a per-token-group FP8 E4M3 quantization kernel. + + This is THE standard activation quantization method in production LLM inference + (DeepSeek-V3, Llama 3, Qwen3). It dynamically quantizes activations to FP8 + format with per-group scale factors for W8A8 quantized inference. + + For each group of `group_size` contiguous elements: + 1. absmax = max(|x_group|) + 2. scale = max(absmax, eps) / 448.0 + 3. x_q = clamp(x / scale, -448.0, 448.0) + + Where 448.0 is the max representable value in FP8 E4M3 format. + + NOTE: Output is float32 clamped to FP8 range (for broad GPU compatibility). 
+ + Input: tuple(x, x_q, x_s) where: + - x: torch.Tensor of shape [num_tokens, hidden_dim] (float32) + - x_q: pre-allocated output [num_tokens, hidden_dim] (float32) + - x_s: pre-allocated scales [num_tokens, hidden_dim // group_size] (float32) + + Output: tuple(x_q, x_s) where: + - x_q: quantized values [num_tokens, hidden_dim] (float32, clamped to FP8 range) + - x_s: per-group scale factors [num_tokens, hidden_dim // group_size] (float32) + +config: + main: "eval.py" + +templates: + Python: "../template.py" + +tests: + - {"num_tokens": 1, "hidden_dim": 256, "group_size": 64, "seed": 4242} + - {"num_tokens": 4, "hidden_dim": 512, "group_size": 128, "seed": 5236} + - {"num_tokens": 16, "hidden_dim": 1024, "group_size": 64, "seed": 1001} + - {"num_tokens": 1, "hidden_dim": 4096, "group_size": 128, "seed": 5531} + - {"num_tokens": 8, "hidden_dim": 4096, "group_size": 128, "seed": 9173} + +benchmarks: + - {"num_tokens": 1, "hidden_dim": 4096, "group_size": 128, "seed": 31232} + - {"num_tokens": 16, "hidden_dim": 4096, "group_size": 128, "seed": 4052} + - {"num_tokens": 256, "hidden_dim": 4096, "group_size": 128, "seed": 2146} + - {"num_tokens": 256, "hidden_dim": 8192, "group_size": 128, "seed": 3129} + - {"num_tokens": 4096, "hidden_dim": 7168, "group_size": 128, "seed": 54352} + +test_timeout: 180 +benchmark_timeout: 180 +ranked_timeout: 420 +ranking_by: "geom" diff --git a/problems/helion/gated_deltanet_chunk_fwd_h_py/reference.py b/problems/helion/gated_deltanet_chunk_fwd_h_py/reference.py new file mode 100644 index 00000000..ecee1896 --- /dev/null +++ b/problems/helion/gated_deltanet_chunk_fwd_h_py/reference.py @@ -0,0 +1,79 @@ +import torch +from task import input_t, output_t +from utils import verbose_allclose + +CHUNK_SIZE = 64 + + +def generate_input(B: int, T: int, H: int, K: int, V: int, use_initial_state: bool, seed: int) -> input_t: + gen = torch.Generator(device="cuda") + gen.manual_seed(seed) + k = torch.randn(B, T, H, K, dtype=torch.float32, 
device="cuda", generator=gen).contiguous() + w = torch.randn(B, T, H, K, dtype=torch.float32, device="cuda", generator=gen).contiguous() + u = torch.randn(B, T, H, V, dtype=torch.float32, device="cuda", generator=gen).contiguous() + # Use negative values for g to keep exp(g) bounded in (0, 1] and prevent overflow + g = -torch.abs(torch.randn(B, T, H, dtype=torch.float32, device="cuda", generator=gen)).contiguous() + if use_initial_state: + initial_state = torch.randn(B, H, K, V, dtype=torch.float32, device="cuda", generator=gen).contiguous() + else: + initial_state = torch.zeros(B, H, K, V, dtype=torch.float32, device="cuda").contiguous() + return k, w, u, g, initial_state + + +def ref_kernel(data: input_t) -> output_t: + k, w, u, g, initial_state = data + B, T, H, K = k.shape + V = u.shape[-1] + BT = CHUNK_SIZE + NT = T // BT + + h = torch.empty(B, NT, H, K, V, dtype=torch.float32, device=k.device) + v_new = torch.empty_like(u) + + for b in range(B): + for hh in range(H): + b_h = initial_state[b, hh].float().clone() # [K, V] + + for c in range(NT): + cs = c * BT + ce = cs + BT + + # Store current state + h[b, c, hh] = b_h + + # v_new = u - w @ h_state + b_w = w[b, cs:ce, hh].float() # [BT, K] + b_u = u[b, cs:ce, hh].float() # [BT, V] + b_v = b_u - torch.matmul(b_w, b_h) # [BT, V] + v_new[b, cs:ce, hh] = b_v + + # Gating + b_g = g[b, cs:ce, hh].float() # [BT] + b_g_last = b_g[-1] + b_v_gated = b_v * torch.exp(b_g_last - b_g)[:, None] + + # Decay and update + b_h = b_h * torch.exp(b_g_last) + b_k = k[b, cs:ce, hh].float() # [BT, K] + b_h = b_h + torch.matmul(b_k.T, b_v_gated) + + return h, v_new + + +def check_implementation(data, output): + expected = ref_kernel(data) + exp_h, exp_v = expected + got_h, got_v = output + + reasons_h = verbose_allclose(got_h, exp_h, rtol=1e-2, atol=1e-2) + reasons_v = verbose_allclose(got_v, exp_v, rtol=1e-2, atol=1e-2) + + reasons = [] + if reasons_h: + reasons.append("h mismatch: " + " ".join(reasons_h)) + if reasons_v: + 
reasons.append("v_new mismatch: " + " ".join(reasons_v)) + + if reasons: + return False, " | ".join(reasons) + return True, "" diff --git a/problems/helion/gated_deltanet_chunk_fwd_h_py/submission.py b/problems/helion/gated_deltanet_chunk_fwd_h_py/submission.py new file mode 100644 index 00000000..38fa590a --- /dev/null +++ b/problems/helion/gated_deltanet_chunk_fwd_h_py/submission.py @@ -0,0 +1,39 @@ +from task import input_t, output_t + + +def custom_kernel(data: input_t) -> output_t: + import torch + + k, w, u, g, initial_state = data + B, T, H, K = k.shape + V = u.shape[-1] + BT = 64 + NT = T // BT + + h = torch.empty(B, NT, H, K, V, dtype=torch.float32, device=k.device) + v_new = torch.empty_like(u) + + for b in range(B): + for hh in range(H): + b_h = initial_state[b, hh].float().clone() + + for c in range(NT): + cs = c * BT + ce = cs + BT + + h[b, c, hh] = b_h + + b_w = w[b, cs:ce, hh].float() + b_u = u[b, cs:ce, hh].float() + b_v = b_u - torch.matmul(b_w, b_h) + v_new[b, cs:ce, hh] = b_v + + b_g = g[b, cs:ce, hh].float() + b_g_last = b_g[-1] + b_v_gated = b_v * torch.exp(b_g_last - b_g)[:, None] + + b_h = b_h * torch.exp(b_g_last) + b_k = k[b, cs:ce, hh].float() + b_h = b_h + torch.matmul(b_k.T, b_v_gated) + + return h, v_new diff --git a/problems/helion/gated_deltanet_chunk_fwd_h_py/task.py b/problems/helion/gated_deltanet_chunk_fwd_h_py/task.py new file mode 100644 index 00000000..435bb18c --- /dev/null +++ b/problems/helion/gated_deltanet_chunk_fwd_h_py/task.py @@ -0,0 +1,14 @@ +from typing import TypedDict, TypeVar +import torch + +input_t = TypeVar("input_t", bound=tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]) +output_t = TypeVar("output_t", bound=tuple[torch.Tensor, torch.Tensor]) + +class TestSpec(TypedDict): + B: int + T: int + H: int + K: int + V: int + use_initial_state: bool + seed: int diff --git a/problems/helion/gated_deltanet_chunk_fwd_h_py/task.yml b/problems/helion/gated_deltanet_chunk_fwd_h_py/task.yml new 
file mode 100644 index 00000000..5567bfae --- /dev/null +++ b/problems/helion/gated_deltanet_chunk_fwd_h_py/task.yml @@ -0,0 +1,69 @@ +files: + - {"name": "submission.py", "source": "@SUBMISSION@"} + - {"name": "task.py", "source": "task.py"} + - {"name": "utils.py", "source": "../utils.py"} + - {"name": "reference.py", "source": "reference.py"} + - {"name": "eval.py", "source": "../eval.py"} + +lang: "py" + +description: | + Implement the chunk_fwd_h (inter-chunk state recurrence) kernel for Gated DeltaNet. + + This kernel maintains a hidden state h of shape [K, V] across chunks and computes + v_new (corrected values) for each chunk. It is the sequential bottleneck in the + chunkwise parallel forward pass of Gated DeltaNet (arXiv:2412.06464, ICLR 2025). + + The sequence is divided into chunks of BT=64 timesteps. Processing is sequential + across chunks but parallel across (B, H) and within each chunk: + + For each (b, h) pair, starting with h_state = initial_state[b, h] (zeros or provided): + For each chunk c = 0, 1, ..., NT-1: + 1. Store: h_out[b, c, h] = h_state + 2. Compute: v_new = u - w @ h_state + 3. Gate: v_gated[t] = v_new[t] * exp(g[last_t] - g[t]) + 4. Decay: h_state = h_state * exp(g[last_t]) + 5. Update: h_state = h_state + k^T @ v_gated + + Input: tuple(k, w, u, g, initial_state) where: + - k: torch.Tensor of shape [B, T, H, K] (float32) — keys + - w: torch.Tensor of shape [B, T, H, K] (float32) — WY-transformed keys + - u: torch.Tensor of shape [B, T, H, V] (float32) — WY-transformed values + - g: torch.Tensor of shape [B, T, H] (float32) — cumulative gate + - initial_state: torch.Tensor of shape [B, H, K, V] (float32) — initial hidden state (zeros or random) + + Output: tuple(h, v_new) where: + - h: torch.Tensor of shape [B, NT, H, K, V] (float32) — per-chunk hidden states + - v_new: torch.Tensor of shape [B, T, H, V] (float32) — corrected values + + Constraint: T must be a multiple of 64. NT = T // 64. 
+ + See also: Helion examples/gdn_fwd_h.py for a related implementation + (simpler variant that returns only h, without v_new or initial_state support). + +config: + main: "eval.py" + +templates: + Python: "../template.py" + +tests: + - {"B": 1, "T": 64, "H": 2, "K": 64, "V": 64, "use_initial_state": false, "seed": 4242} + - {"B": 2, "T": 128, "H": 4, "K": 64, "V": 64, "use_initial_state": true, "seed": 5236} + - {"B": 1, "T": 256, "H": 4, "K": 64, "V": 128, "use_initial_state": false, "seed": 1001} + - {"B": 1, "T": 64, "H": 1, "K": 128, "V": 128, "use_initial_state": true, "seed": 5531} + - {"B": 2, "T": 128, "H": 2, "K": 100, "V": 100, "use_initial_state": true, "seed": 9173} + +benchmarks: + - {"B": 1, "T": 64, "H": 1, "K": 64, "V": 64, "use_initial_state": false, "seed": 31232} + - {"B": 2, "T": 512, "H": 3, "K": 64, "V": 64, "use_initial_state": true, "seed": 4052} + - {"B": 2, "T": 1024, "H": 3, "K": 64, "V": 64, "use_initial_state": false, "seed": 2146} + - {"B": 3, "T": 1024, "H": 4, "K": 100, "V": 100, "use_initial_state": true, "seed": 3129} + - {"B": 4, "T": 1024, "H": 4, "K": 128, "V": 128, "use_initial_state": false, "seed": 54352} + - {"B": 2, "T": 1536, "H": 4, "K": 128, "V": 128, "use_initial_state": true, "seed": 71234} + - {"B": 4, "T": 2048, "H": 8, "K": 64, "V": 64, "use_initial_state": true, "seed": 82345} + +test_timeout: 180 +benchmark_timeout: 180 +ranked_timeout: 420 +ranking_by: "geom" diff --git a/problems/helion/gated_deltanet_chunk_fwd_o_py/reference.py b/problems/helion/gated_deltanet_chunk_fwd_o_py/reference.py new file mode 100644 index 00000000..0078d9e5 --- /dev/null +++ b/problems/helion/gated_deltanet_chunk_fwd_o_py/reference.py @@ -0,0 +1,59 @@ +import torch +from task import input_t, output_t +from utils import make_match_reference + +CHUNK_SIZE = 64 + + +def generate_input(B: int, T: int, H: int, K: int, V: int, seed: int) -> input_t: + gen = torch.Generator(device="cuda") + gen.manual_seed(seed) + NT = T // CHUNK_SIZE + q = 
torch.randn(B, T, H, K, dtype=torch.float32, device="cuda", generator=gen).contiguous() + k = torch.randn(B, T, H, K, dtype=torch.float32, device="cuda", generator=gen).contiguous() + v_new = torch.randn(B, T, H, V, dtype=torch.float32, device="cuda", generator=gen).contiguous() + h = torch.randn(B, NT, H, K, V, dtype=torch.float32, device="cuda", generator=gen).contiguous() + # Use negative values for g to keep exp(g) bounded in (0, 1] + g = -torch.abs(torch.randn(B, T, H, dtype=torch.float32, device="cuda", generator=gen)).contiguous() + return q, k, v_new, h, g + + +def ref_kernel(data: input_t) -> output_t: + q, k, v_new, h, g = data + B, T, H, K = q.shape + V = v_new.shape[-1] + BT = CHUNK_SIZE + scale = K ** -0.5 + + o = torch.empty_like(v_new) + causal = torch.tril(torch.ones(BT, BT, device=q.device, dtype=torch.bool)) + + for cs in range(0, T, BT): + ce = cs + BT + c_idx = cs // BT + + # Reshape to [B, H, BT, ...] for batched matmul + b_q = q[:, cs:ce, :, :].permute(0, 2, 1, 3).float() # [B, H, BT, K] + b_k = k[:, cs:ce, :, :].permute(0, 2, 1, 3).float() # [B, H, BT, K] + b_v = v_new[:, cs:ce, :, :].permute(0, 2, 1, 3).float() # [B, H, BT, V] + b_h = h[:, c_idx, :, :, :].float() # [B, H, K, V] + b_g = g[:, cs:ce, :].permute(0, 2, 1).float() # [B, H, BT] + + # Inter-chunk: q @ h * exp(g) + inter = torch.matmul(b_q, b_h) # [B, H, BT, V] + inter = inter * torch.exp(b_g).unsqueeze(-1) + + # Intra-chunk: causal(q @ k^T * exp(g_diff)) @ v_new + attn = torch.matmul(b_q, b_k.transpose(-1, -2)) # [B, H, BT, BT] + g_diff = b_g.unsqueeze(-1) - b_g.unsqueeze(-2) # [B, H, BT, BT] + attn = attn * torch.exp(g_diff) + attn = attn.masked_fill(~causal, 0.0) + intra = torch.matmul(attn, b_v) # [B, H, BT, V] + + b_o = (inter + intra) * scale + o[:, cs:ce, :, :] = b_o.permute(0, 2, 1, 3) + + return o + + +check_implementation = make_match_reference(ref_kernel, rtol=1e-3, atol=1e-3) diff --git a/problems/helion/gated_deltanet_chunk_fwd_o_py/submission.py 
def custom_kernel(data: "input_t") -> "output_t":
    """Baseline fwd_o submission: chunkwise reference algorithm in PyTorch.

    Per BT=64 chunk:
        inter = q @ h * exp(g)
        intra = causal(q @ k^T * exp(g_i - g_j)) @ v_new
        o     = (inter + intra) * K**-0.5

    Input: (q, k, v_new, h, g); returns o of shape [B, T, H, V].
    """
    import torch

    q, k, v_new, h, g = data
    # Only T (loop bound) and K (scale) are needed from the shapes.
    T = q.shape[1]
    K = q.shape[-1]
    BT = 64
    scale = K ** -0.5

    o = torch.empty_like(v_new)
    # Keep entry (i, j) only when j <= i.
    causal = torch.tril(torch.ones(BT, BT, device=q.device, dtype=torch.bool))

    for cs in range(0, T, BT):
        ce = cs + BT
        c_idx = cs // BT

        # Move to [B, H, BT, ...] layout for batched matmul.
        b_q = q[:, cs:ce, :, :].permute(0, 2, 1, 3).float()
        b_k = k[:, cs:ce, :, :].permute(0, 2, 1, 3).float()
        b_v = v_new[:, cs:ce, :, :].permute(0, 2, 1, 3).float()
        b_h = h[:, c_idx, :, :, :].float()
        b_g = g[:, cs:ce, :].permute(0, 2, 1).float()

        # Inter-chunk contribution from the carried state h.
        inter = torch.matmul(b_q, b_h)
        inter = inter * torch.exp(b_g).unsqueeze(-1)

        # Intra-chunk causal attention with gate-difference decay.
        attn = torch.matmul(b_q, b_k.transpose(-1, -2))
        g_diff = b_g.unsqueeze(-1) - b_g.unsqueeze(-2)
        attn = attn * torch.exp(g_diff)
        attn = attn.masked_fill(~causal, 0.0)
        intra = torch.matmul(attn, b_v)

        b_o = (inter + intra) * scale
        o[:, cs:ce, :, :] = b_o.permute(0, 2, 1, 3)

    return o
b/problems/helion/gated_deltanet_chunk_fwd_o_py/task.yml @@ -0,0 +1,61 @@ +files: + - {"name": "submission.py", "source": "@SUBMISSION@"} + - {"name": "task.py", "source": "task.py"} + - {"name": "utils.py", "source": "../utils.py"} + - {"name": "reference.py", "source": "reference.py"} + - {"name": "eval.py", "source": "../eval.py"} + +lang: "py" + +description: | + Implement the chunk_fwd_o (output computation) kernel for Gated DeltaNet. + + This kernel computes the final output by combining inter-chunk (state-based) + and intra-chunk (attention-based) contributions for the chunkwise parallel + forward pass of Gated DeltaNet (arXiv:2412.06464, ICLR 2025). + + The sequence is divided into chunks of BT=64 timesteps. For each chunk + independently: + inter = q @ h * exp(g) + intra = causal_mask(q @ k^T * exp(g[:, None] - g[None, :])) @ v_new + output = (inter + intra) * scale + + where scale = K^(-0.5), and causal_mask zeros out entries where row < col. + + Input: tuple(q, k, v_new, h, g) where: + - q: torch.Tensor of shape [B, T, H, K] (float32) — queries + - k: torch.Tensor of shape [B, T, H, K] (float32) — keys + - v_new: torch.Tensor of shape [B, T, H, V] (float32) — corrected values + - h: torch.Tensor of shape [B, NT, H, K, V] (float32) — per-chunk states + - g: torch.Tensor of shape [B, T, H] (float32) — cumulative gate + + Output: torch.Tensor of shape [B, T, H, V] (float32) + + Constraint: T must be a multiple of 64. NT = T // 64. scale = K^(-0.5). 
+ +config: + main: "eval.py" + +templates: + Python: "../template.py" + +tests: + - {"B": 1, "T": 64, "H": 2, "K": 64, "V": 64, "seed": 4242} + - {"B": 2, "T": 128, "H": 4, "K": 64, "V": 64, "seed": 5236} + - {"B": 1, "T": 256, "H": 4, "K": 64, "V": 128, "seed": 1001} + - {"B": 1, "T": 64, "H": 1, "K": 128, "V": 128, "seed": 5531} + - {"B": 2, "T": 128, "H": 2, "K": 100, "V": 100, "seed": 9173} + +benchmarks: + - {"B": 1, "T": 64, "H": 1, "K": 64, "V": 64, "seed": 31232} + - {"B": 2, "T": 512, "H": 3, "K": 64, "V": 64, "seed": 4052} + - {"B": 2, "T": 1024, "H": 3, "K": 64, "V": 64, "seed": 2146} + - {"B": 3, "T": 1024, "H": 4, "K": 100, "V": 100, "seed": 3129} + - {"B": 4, "T": 1024, "H": 4, "K": 128, "V": 128, "seed": 54352} + - {"B": 2, "T": 1536, "H": 4, "K": 128, "V": 128, "seed": 71234} + - {"B": 4, "T": 2048, "H": 8, "K": 64, "V": 64, "seed": 82345} + +test_timeout: 180 +benchmark_timeout: 180 +ranked_timeout: 420 +ranking_by: "geom" diff --git a/problems/helion/gated_deltanet_recompute_w_u_py/reference.py b/problems/helion/gated_deltanet_recompute_w_u_py/reference.py new file mode 100644 index 00000000..99750dda --- /dev/null +++ b/problems/helion/gated_deltanet_recompute_w_u_py/reference.py @@ -0,0 +1,61 @@ +import torch +from task import input_t, output_t +from utils import verbose_allclose + +CHUNK_SIZE = 64 + + +def generate_input(B: int, T: int, H: int, K: int, V: int, seed: int) -> input_t: + gen = torch.Generator(device="cuda") + gen.manual_seed(seed) + k = torch.randn(B, T, H, K, dtype=torch.float32, device="cuda", generator=gen).contiguous() + v = torch.randn(B, T, H, V, dtype=torch.float32, device="cuda", generator=gen).contiguous() + beta = torch.randn(B, T, H, dtype=torch.float32, device="cuda", generator=gen).contiguous() + A = torch.randn(B, T, H, CHUNK_SIZE, dtype=torch.float32, device="cuda", generator=gen).contiguous() + # Use negative values for g to keep exp(g) bounded in (0, 1] + g = -torch.abs(torch.randn(B, T, H, dtype=torch.float32, 
CHUNK_SIZE = 64  # BT: fixed chunk length; T is required to be a multiple of this.


def ref_kernel(data: "input_t") -> "output_t":
    """Reference recompute_w_u for Gated DeltaNet.

    For each independent BT=64 chunk:
        u = A @ (v * beta)            (WY-transformed values)
        w = A @ (k * beta * exp(g))   (WY-transformed keys)

    Input: (k, v, beta, A, g); returns (w, u) with the shapes of (k, v).
    """
    k, v, beta, A, g = data
    # Only T is needed from the shapes; outputs inherit shape via empty_like.
    T = k.shape[1]
    BT = CHUNK_SIZE

    w = torch.empty_like(k)
    u = torch.empty_like(v)

    for cs in range(0, T, BT):
        ce = cs + BT
        # Reshape to [B, H, BT, BT] for batched matmul
        A_bh = A[:, cs:ce, :, :].permute(0, 2, 1, 3).float()

        # u = A @ (v * beta[..., None])
        vb = (v[:, cs:ce, :, :] * beta[:, cs:ce, :, None]).permute(0, 2, 1, 3).float()
        u[:, cs:ce, :, :] = torch.matmul(A_bh, vb).permute(0, 2, 1, 3)

        # w = A @ (k * beta[..., None] * exp(g)[..., None])
        kb = (k[:, cs:ce, :, :] * beta[:, cs:ce, :, None] * torch.exp(g[:, cs:ce, :, None])).permute(0, 2, 1, 3).float()
        w[:, cs:ce, :, :] = torch.matmul(A_bh, kb).permute(0, 2, 1, 3)

    return w, u


def check_implementation(data, output):
    """Compare a submission's (w, u) pair against ref_kernel (rtol=atol=1e-3).

    Returns (ok, message); the message enumerates w and u mismatches.
    """
    # Guard against malformed submissions before unpacking, so a wrong return
    # type produces a graceful failure instead of an unpack TypeError.
    if not isinstance(output, (tuple, list)) or len(output) != 2:
        return False, "expected output to be a (w, u) pair of tensors"

    expected = ref_kernel(data)
    exp_w, exp_u = expected
    got_w, got_u = output

    reasons_w = verbose_allclose(got_w, exp_w, rtol=1e-3, atol=1e-3)
    reasons_u = verbose_allclose(got_u, exp_u, rtol=1e-3, atol=1e-3)

    reasons = []
    if reasons_w:
        reasons.append("w mismatch: " + " ".join(reasons_w))
    if reasons_u:
        reasons.append("u mismatch: " + " ".join(reasons_u))

    if reasons:
        return False, " | ".join(reasons)
    return True, ""


def custom_kernel(data: "input_t") -> "output_t":
    """Baseline recompute_w_u submission (mirrors the reference algorithm).

    u = A @ (v * beta) and w = A @ (k * beta * exp(g)), per BT-sized chunk,
    where BT is taken from A's trailing dimension.
    """
    import torch

    k, v, beta, A, g = data
    T = k.shape[1]
    BT = A.shape[-1]  # chunk length is carried in A's trailing dim

    w = torch.empty_like(k)
    u = torch.empty_like(v)

    for cs in range(0, T, BT):
        ce = cs + BT
        A_bh = A[:, cs:ce, :, :].permute(0, 2, 1, 3).float()

        vb = (v[:, cs:ce, :, :] * beta[:, cs:ce, :, None]).permute(0, 2, 1, 3).float()
        u[:, cs:ce, :, :] = torch.matmul(A_bh, vb).permute(0, 2, 1, 3)

        kb = (k[:, cs:ce, :, :] * beta[:, cs:ce, :, None] * torch.exp(g[:, cs:ce, :, None])).permute(0, 2, 1, 3).float()
        w[:, cs:ce, :, :] = torch.matmul(A_bh, kb).permute(0, 2, 1, 3)

    return w, u
+ For each chunk independently: + u = A @ diag(beta) @ v (WY-transformed values) + w = A @ diag(beta * exp(g)) @ k (WY-transformed keys) + + Equivalently: + u = A @ (v * beta[:, None]) + w = A @ (k * beta[:, None] * exp(g)[:, None]) + + where A is a [BT, BT] WY representation matrix per chunk. + + Input: tuple(k, v, beta, A, g) where: + - k: torch.Tensor of shape [B, T, H, K] (float32) — keys + - v: torch.Tensor of shape [B, T, H, V] (float32) — values + - beta: torch.Tensor of shape [B, T, H] (float32) — gating coefficients + - A: torch.Tensor of shape [B, T, H, BT] (float32) — WY matrix (BT=64) + - g: torch.Tensor of shape [B, T, H] (float32) — cumulative gate + + Output: tuple(w, u) where: + - w: torch.Tensor of shape [B, T, H, K] (float32) — WY-transformed keys + - u: torch.Tensor of shape [B, T, H, V] (float32) — WY-transformed values + + Constraint: T must be a multiple of 64. + +config: + main: "eval.py" + +templates: + Python: "../template.py" + +tests: + - {"B": 1, "T": 64, "H": 2, "K": 64, "V": 64, "seed": 4242} + - {"B": 2, "T": 128, "H": 4, "K": 64, "V": 64, "seed": 5236} + - {"B": 1, "T": 256, "H": 4, "K": 64, "V": 128, "seed": 1001} + - {"B": 1, "T": 64, "H": 1, "K": 128, "V": 128, "seed": 5531} + - {"B": 2, "T": 128, "H": 2, "K": 100, "V": 100, "seed": 9173} + +benchmarks: + - {"B": 1, "T": 64, "H": 1, "K": 64, "V": 64, "seed": 31232} + - {"B": 2, "T": 512, "H": 3, "K": 64, "V": 64, "seed": 4052} + - {"B": 2, "T": 1024, "H": 3, "K": 64, "V": 64, "seed": 2146} + - {"B": 3, "T": 1024, "H": 4, "K": 100, "V": 100, "seed": 3129} + - {"B": 4, "T": 1024, "H": 4, "K": 128, "V": 128, "seed": 54352} + - {"B": 2, "T": 1536, "H": 4, "K": 128, "V": 128, "seed": 71234} + - {"B": 4, "T": 2048, "H": 8, "K": 64, "V": 64, "seed": 82345} + +test_timeout: 180 +benchmark_timeout: 180 +ranked_timeout: 420 +ranking_by: "geom" diff --git a/problems/helion/template.py b/problems/helion/template.py new file mode 100644 index 00000000..4aec6a6c --- /dev/null +++ 
import os
import random

import numpy as np
import torch


def set_seed(seed: int = 42) -> None:
    """Seed the python, numpy and torch RNGs (including all CUDA devices)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)


def get_device(use_cuda: bool = True) -> torch.device:
    """Get the appropriate device (GPU or CPU)."""
    if use_cuda:
        if torch.cuda.is_available():
            return torch.device("cuda")
        elif torch.backends.mps.is_available():
            return torch.device("mps")
        else:
            print("No compatible GPU found. Falling back to CPU.")
    return torch.device("cpu")


# Adapted from https://github.com/linkedin/Liger-Kernel/blob/main/test/utils.py
@torch.no_grad()
def verbose_allclose(
    received: torch.Tensor,
    expected: torch.Tensor,
    rtol=1e-05,
    atol=1e-08,
    max_print=5
) -> list[str]:
    """
    Check that two tensors are element-wise equal within a tolerance, providing detailed information about mismatches.

    Parameters:
        received (torch.Tensor): Tensor we actually got.
        expected (torch.Tensor): Tensor we expected to receive.
        rtol (float): Relative tolerance; relative to expected
        atol (float): Absolute tolerance.
        max_print (int): Maximum number of mismatched elements to print.

    Returns:
        Empty list if the tensors are close; otherwise a list of messages
        describing the mismatches (count, first few offending indices).
    """
    # Check if the shapes of the tensors match
    if received.shape != expected.shape:
        return ["SIZE MISMATCH"]

    # Calculate the difference between the tensors
    diff = torch.abs(received - expected)

    # Determine the tolerance
    tolerance = atol + rtol * torch.abs(expected)

    # Find tolerance mismatched elements
    tol_mismatched = diff > tolerance

    # Find nan mismatched elements (NaN in exactly one of the two tensors)
    nan_mismatched = torch.logical_xor(torch.isnan(received), torch.isnan(expected))

    # Find +inf mismatched elements
    posinf_mismatched = torch.logical_xor(torch.isposinf(received), torch.isposinf(expected))
    # Find -inf mismatched elements
    neginf_mismatched = torch.logical_xor(torch.isneginf(received), torch.isneginf(expected))

    # Find all mismatched elements
    mismatched = torch.logical_or(
        torch.logical_or(tol_mismatched, nan_mismatched),
        torch.logical_or(posinf_mismatched, neginf_mismatched),
    )

    mismatched_indices = torch.nonzero(mismatched)

    # Count the number of mismatched elements
    num_mismatched = mismatched.count_nonzero().item()

    # Generate detailed information if there are mismatches
    if num_mismatched >= 1:
        mismatch_details = [f"Number of mismatched elements: {num_mismatched}"]

        for index in mismatched_indices[:max_print]:
            i = tuple(index.tolist())
            mismatch_details.append(f"ERROR AT {i}: {received[i]} {expected[i]}")
        if num_mismatched > max_print:
            mismatch_details.append(f"... and {num_mismatched - max_print} more mismatched elements.")
        return mismatch_details

    return []


@torch.no_grad()
def verbose_allequal(received: torch.Tensor, expected: torch.Tensor, max_print: int = 5):
    """
    Check that two tensors are element-wise perfectly equal, providing detailed information about mismatches.

    Parameters:
        received (torch.Tensor): Tensor we actually got.
        expected (torch.Tensor): Tensor we expected to receive.
        max_print (int): Maximum number of mismatched elements to print.

    Returns:
        Empty list if tensors are equal, otherwise a list of messages with
        detailed error information.
    """
    mismatched = torch.not_equal(received, expected)
    mismatched_indices = torch.nonzero(mismatched)

    # Count the number of mismatched elements
    num_mismatched = mismatched.count_nonzero().item()

    # Generate detailed information if there are mismatches
    if num_mismatched >= 1:
        mismatch_details = [f"Number of mismatched elements: {num_mismatched}"]

        for index in mismatched_indices[:max_print]:
            i = tuple(index.tolist())
            mismatch_details.append(f"ERROR AT {i}: {received[i]} {expected[i]}")
        if num_mismatched > max_print:
            mismatch_details.append(f"... and {num_mismatched - max_print} more mismatched elements.")
        return mismatch_details

    return []


def match_reference(data, output, reference: callable, rtol=1e-05, atol=1e-08) -> tuple[bool, str]:
    """
    Convenient "default" implementation for tasks' `check_implementation` function.

    Runs `reference` on `data` and compares `output` against the result with
    verbose_allclose. Returns (ok, message); message is empty on success.
    """
    expected = reference(data)
    reasons = verbose_allclose(output, expected, rtol=rtol, atol=atol)

    if len(reasons) > 0:
        return False, "mismatch found! custom implementation doesn't match reference: " + " ".join(reasons)

    return True, ''


def make_match_reference(reference: callable, **kwargs):
    """Bind `reference` (and tolerances) into a (data, output) -> (ok, msg) checker."""
    def wrapped(data, output):
        return match_reference(data, output, reference=reference, **kwargs)
    return wrapped


class DeterministicContext:
    """Context manager that forces deterministic torch execution.

    Saves and restores the cudnn TF32/determinism flags, the global
    deterministic-algorithms switch, and CUBLAS_WORKSPACE_CONFIG.
    """

    def __init__(self):
        self.allow_tf32 = None
        self.deterministic = None
        self.use_deterministic = None
        self.cublas = None

    def __enter__(self):
        # None (not '') marks "variable was unset" so __exit__ can remove it.
        self.cublas = os.environ.get('CUBLAS_WORKSPACE_CONFIG')
        self.allow_tf32 = torch.backends.cudnn.allow_tf32
        self.deterministic = torch.backends.cudnn.deterministic
        self.use_deterministic = torch.are_deterministic_algorithms_enabled()
        # Deterministic cuBLAS requires a workspace config; without it,
        # use_deterministic_algorithms(True) errors on CUDA matmuls.
        os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
        torch.backends.cudnn.allow_tf32 = False
        torch.backends.cudnn.deterministic = True
        torch.use_deterministic_algorithms(True)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore the exact prior state rather than unconditionally disabling.
        torch.backends.cudnn.allow_tf32 = self.allow_tf32
        torch.backends.cudnn.deterministic = self.deterministic
        torch.use_deterministic_algorithms(self.use_deterministic)
        if self.cublas is None:
            os.environ.pop('CUBLAS_WORKSPACE_CONFIG', None)
        else:
            os.environ['CUBLAS_WORKSPACE_CONFIG'] = self.cublas


def clear_l2_cache():
    """Evict the GPU L2 cache by writing through a large dummy tensor."""
    # import cupy as cp
    # cp.cuda.runtime.deviceSetLimit(cp.cuda.runtime.cudaLimitPersistingL2CacheSize, 0)
    # create a large dummy tensor
    dummy = torch.empty((32, 1024, 1024), dtype=torch.int64, device="cuda")
    # write stuff to it so the allocation actually touches memory
    dummy.fill_(42)
    del dummy