From 9b1c682ee90ad98a0363e28b15122d3efa7d1670 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Fri, 13 Mar 2026 02:27:04 -0700 Subject: [PATCH] Switch helion eval to CUDA graph capture/replay for benchmarking - Add _do_bench_cudagraph() for stable kernel timing using captured CUDA graphs with L2 cache clearing and overhead subtraction - Add _copy_data_inplace() to feed new inputs into graph buffers without recapturing, used during recheck correctness passes - Capture kernels in CUDA graphs during testing (_run_single_test) to validate that submissions are graph-capturable - Add run_local() for local eval without Popcorn infrastructure (usage: python eval.py ) - Defer imports of problem-directory modules (reference, submission, utils, task) to runtime instead of module level --- problems/helion/eval.py | 347 +++++++++++++++++++++++++++++++--------- 1 file changed, 273 insertions(+), 74 deletions(-) diff --git a/problems/helion/eval.py b/problems/helion/eval.py index f99aba46..92db9afd 100644 --- a/problems/helion/eval.py +++ b/problems/helion/eval.py @@ -11,29 +11,21 @@ import torch.cuda -from utils import set_seed, clear_l2_cache -try: - from task import TestSpec -except ImportError: - TestSpec = dict - -from reference import check_implementation, generate_input - class PopcornOutput: def __init__(self, fd: int): self.file = os.fdopen(fd, 'w') os.set_inheritable(fd, False) - + def __enter__(self): return self - + def __exit__(self, exc_type, exc_val, exc_tb): self.file.close() - + def print(self, *args, **kwargs): print(*args, **kwargs, file=self.file, flush=True) - + def log(self, key, value): self.print(f"{key}: {value}") @@ -141,16 +133,144 @@ def _clone_data(data): return data +def _copy_data_inplace(dst, src): + """ + Recursively copy tensor data from src into dst (same structure, same shapes). + Used to feed new inputs into CUDA graph buffers without recapturing. + """ + if isinstance(dst, torch.Tensor): + dst.copy_(src) + elif isinstance(dst, (tuple, list)): + for d, s in zip(dst, src): + _copy_data_inplace(d, s) + elif isinstance(dst, dict): + for k in dst: + _copy_data_inplace(dst[k], src[k]) + + +def _do_bench_cudagraph(fn, rep_ms=100, return_mode="mean", clear_l2=True): + """ + Benchmark fn using CUDA graphs with optional L2 cache clearing. + Based on triton.testing.do_bench_cudagraph + triton-lang/triton#8384. + + :param fn: Callable to benchmark (no args). + :param rep_ms: Target repetition time per measurement in milliseconds. + :param return_mode: "min", "max", "mean", "median", or "all" (list of ms). + :param clear_l2: If True, flush L2 cache before each invocation and subtract + the flushing overhead from reported times. + :return: Time(s) in milliseconds. + """ + assert return_mode in ["min", "max", "mean", "median", "all"] + + # 256 MB cache tensor — larger than any current GPU L2 + cache = torch.empty(32 * 1024 * 1024, dtype=torch.int64, device="cuda") if clear_l2 else None + + def maybe_clear_cache(): + if cache is not None: + cache.zero_() + + with torch.cuda.stream(torch.cuda.Stream()): + # warmup + maybe_clear_cache() + fn() + + # step 1 — estimate per-call time + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + for _ in range(5): + maybe_clear_cache() + fn() + end_event.record() + torch.cuda.synchronize() + estimate_ms = start_event.elapsed_time(end_event) / 5 + + n_repeat = max(1, int(rep_ms / estimate_ms)) if estimate_ms > 0 else 1000 + + # step 2 — capture graph with n_repeat unrolled calls + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + for _ in range(n_repeat): + maybe_clear_cache() + fn() + torch.cuda.synchronize() + + # step 3 — if L2 clearing enabled, capture a separate graph to measure + # the clearing overhead so we can subtract it + cache_clear_graph = None + if clear_l2: + cache_clear_graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(cache_clear_graph): + for _ in range(n_repeat): + maybe_clear_cache() + torch.cuda.synchronize() + + # step 4 — measure + n_retries = 10 + cache_clear_times = [] + total_times = [] + for _ in range(n_retries): + if cache_clear_graph is not None: + s = torch.cuda.Event(enable_timing=True) + e = torch.cuda.Event(enable_timing=True) + s.record() + cache_clear_graph.replay() + e.record() + torch.cuda.synchronize() + cache_clear_times.append(s.elapsed_time(e) / n_repeat) + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + g.replay() + end_event.record() + torch.cuda.synchronize() + total_times.append(start_event.elapsed_time(end_event) / n_repeat) + + if clear_l2: + ret = [max(0, t - c) for t, c in zip(total_times, cache_clear_times)] + else: + ret = total_times + + if return_mode == "all": + return ret + elif return_mode == "min": + return min(ret) + elif return_mode == "max": + return max(ret) + elif return_mode == "mean": + return sum(ret) / len(ret) + elif return_mode == "median": + return sorted(ret)[len(ret) // 2] + + def _run_single_test(test: TestCase): """ - Runs a single test case. Do not call directly + Runs a single test case via CUDA graph capture + replay. + This validates that the kernel is capturable and produces correct output. """ from submission import custom_kernel + from reference import check_implementation, generate_input + data = generate_input(**test.args) + check_copy = _clone_data(data) + + # Warmup call to trigger JIT compilation (outside graph capture) + _ = custom_kernel(_clone_data(data)) torch.cuda.synchronize() - submission_output = custom_kernel(_clone_data(data)) + + # Capture and replay through CUDA graph + input_data = _clone_data(data) + try: + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + output = custom_kernel(input_data) + except Exception as e: + return False, f"Failed to capture kernel in CUDA graph: {e}" + g.replay() torch.cuda.synchronize() - return check_implementation(data, submission_output) + + return check_implementation(check_copy, output) def run_single_test(pool: multiprocessing.Pool, test: TestCase): @@ -190,81 +310,75 @@ def run_testing(logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[T return 112 -def _run_single_benchmark(test: TestCase, recheck: bool, max_repeats: int, max_time_ns: float) -> Stats | Any: +def _run_single_benchmark(test: TestCase, recheck: bool, rep_ms: int) -> Stats | Any: """ Runs one benchmark. Do not call directly. + + Correctness is verified via CUDA graph capture + replay first. + Timing only runs if all correctness checks pass. + + :param test: Test case with input arguments. + :param recheck: If True, run additional correctness checks with varying seeds. + :param rep_ms: Target repetition time per measurement in milliseconds. """ from submission import custom_kernel + from reference import check_implementation, generate_input - durations = [] - # generate input data once data = generate_input(**test.args) check_copy = _clone_data(data) - # first, one obligatory correctness check - output = custom_kernel(data) + + # Warmup (JIT compilation) + _ = custom_kernel(_clone_data(data)) + torch.cuda.synchronize() + + # Capture in CUDA graph and run initial correctness check + input_data = _clone_data(data) + try: + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + output = custom_kernel(input_data) + except Exception as e: + return f"Failed to capture kernel in CUDA graph: {e}" + g.replay() + torch.cuda.synchronize() good, message = check_implementation(check_copy, output) if not good: return message - # now, do multiple timing runs without further correctness testing - # there is an upper bound of 100 runs, and a lower bound of 3 runs; - # otherwise, we repeat until we either measure at least 10 full seconds, - # or the relative error of the mean is below 1%. - - bm_start_time = time.perf_counter_ns() - for i in range(max_repeats): - if recheck: - # ensure we use a different seed for every benchmark + if recheck: + # Reuse the captured graph with new input data for each seed + for i in range(10): if "seed" in test.args: test.args["seed"] += 13 - - data = generate_input(**test.args) - check_copy = _clone_data(data) - torch.cuda.synchronize() - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - clear_l2_cache() - - start_event.record() - output = custom_kernel(data) - end_event.record() - torch.cuda.synchronize() - duration = start_event.elapsed_time(end_event) * 1e6 # Convert ms to ns - - if recheck: + new_data = generate_input(**test.args) + check_copy = _clone_data(new_data) + _copy_data_inplace(input_data, new_data) + g.replay() + torch.cuda.synchronize() good, message = check_implementation(check_copy, output) if not good: return message - del output - durations.append(duration) - - if i > 1: - total_bm_duration = time.perf_counter_ns() - bm_start_time - stats = calculate_stats(durations) - # stop if either - # a) relative error dips below 0.1% - # b) we exceed the total time limit for benchmarking the kernel - # c) we exceed 2 minutes of total wallclock time. - if stats.err / stats.mean < 0.001 or stats.mean * stats.runs > max_time_ns or total_bm_duration > 120e9: - break - + # Timing (only reached if all correctness checks passed) + data = generate_input(**test.args) + fn = lambda: custom_kernel(data) + times_ms = _do_bench_cudagraph(fn, rep_ms=rep_ms, return_mode="all", clear_l2=True) + time.sleep(10) # GPU cooldown to avoid thermal throttling + durations = [t * 1e6 for t in times_ms] # convert ms to ns return calculate_stats(durations) -def run_single_benchmark(pool: multiprocessing.Pool, test: TestCase, recheck: bool, max_repeats: int, - max_time_ns: float): +def run_single_benchmark(pool: multiprocessing.Pool, test: TestCase, recheck: bool, rep_ms: int): """ - For a particular test case, check correctness (if applicable) and grab runtime results. - - @param pool: Process on which the benchmark will be launched. - @param test: TestCase object. - @param recheck: Flag for whether to explicitly check functional correctness. - @param max_repeats: Number of trials to repeat. - @param max_time_ns: Timeout time in nanoseconds. - @return: A Stats object for this particular benchmark case or an error if the test fails. + Run a benchmark in a subprocess. + + :param pool: Process pool. + :param test: TestCase object. + :param recheck: Flag for whether to explicitly check functional correctness. + :param rep_ms: Target repetition time per measurement in milliseconds. + :return: A Stats object or an error string. """ - return pool.apply(_run_single_benchmark, (test, recheck, max_repeats, max_time_ns)) + return pool.apply(_run_single_benchmark, (test, recheck, rep_ms)) def run_benchmarking(logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase]): @@ -277,13 +391,13 @@ def run_benchmarking(logger: PopcornOutput, pool: multiprocessing.Pool, tests: l @return: An integer representing the exit status: 0 if all benchmarks pass, otherwise 112. """ # warm up - run_single_benchmark(pool, tests[0], False, 100, 10e7) + run_single_benchmark(pool, tests[0], False, 20) passed = True logger.log("benchmark-count", len(tests)) for idx, test in enumerate(tests): logger.log(f"benchmark.{idx}.spec", test.spec) - result = run_single_benchmark(pool, test, False, 100, 10e9) + result = run_single_benchmark(pool, test, False, 100) if isinstance(result, Stats): for field in dataclasses.fields(Stats): logger.log(f"benchmark.{idx}.{field.name}", getattr(result, field.name)) @@ -305,6 +419,7 @@ def run_single_profile(test: TestCase) -> str: Runs a single test case. Do not call directly """ from submission import custom_kernel + from reference import generate_input from torch.profiler import profile, record_function, ProfilerActivity data = generate_input(**test.args) torch.cuda.synchronize() @@ -325,14 +440,98 @@ def run_profiling(logger: PopcornOutput, tests: list[TestCase]): return 0 +def run_local(): + """ + Local eval mode: reads task.yml from a problem directory, runs correctness tests + and benchmarks, prints results to stdout. No Popcorn infrastructure needed. + + Usage: python eval.py + mode: test, benchmark, or both + problem_dir: path to the problem directory containing task.yml + """ + import yaml + + if len(sys.argv) < 3: + print("Usage: python eval.py ", file=sys.stderr) + print(" mode: test, benchmark, or both", file=sys.stderr) + print(" problem_dir: path to problem directory containing task.yml", file=sys.stderr) + return 1 + + mode = sys.argv[1] + problem_dir = Path(sys.argv[2]) + + if mode not in ("test", "benchmark", "both"): + print(f"Unknown mode '{mode}'. Use 'test', 'benchmark', or 'both'.", file=sys.stderr) + return 1 + + problem_dir = problem_dir.resolve() + task_path = problem_dir / "task.yml" + if not task_path.exists(): + print(f"Error: task.yml not found in {problem_dir}", file=sys.stderr) + return 1 + + task = yaml.safe_load(task_path.read_text()) + + # chdir into the problem directory so that `from submission import ...` works + os.chdir(problem_dir) + sys.path.insert(0, str(problem_dir)) + + from utils import set_seed + + set_seed(42) + exit_code = 0 + + # --- Correctness tests --- + if mode in ("test", "both"): + tests = [TestCase(args=dict(t), spec=str(t)) for t in task.get("tests", [])] + print(f"Running {len(tests)} correctness tests...") + all_passed = True + for idx, test in enumerate(tests): + good, message = _run_single_test(test) + status = "PASS" if good else "FAIL" + print(f" Test {idx}: {status} {test.spec}") + if not good: + print(f" {message}") + all_passed = False + if all_passed: + print("All tests passed.") + else: + print("Some tests FAILED.") + exit_code = 1 + + # --- Benchmarks --- + if mode in ("benchmark", "both"): + benchmarks = [TestCase(args=dict(t), spec=str(t)) for t in task.get("benchmarks", [])] + print(f"\nRunning {len(benchmarks)} benchmarks...") + + # Warmup + _run_single_benchmark(benchmarks[0], False, 20) + + for idx, bench in enumerate(benchmarks): + result = _run_single_benchmark(bench, False, 100) + if isinstance(result, Stats): + mean_ms = result.mean / 1e6 # Stats stores ns + min_ms = result.best / 1e6 + max_ms = result.worst / 1e6 + print(f" Benchmark {idx}: {mean_ms:.4f} ms (min={min_ms:.4f}, max={max_ms:.4f}) {bench.spec}") + else: + print(f" Benchmark {idx}: FAIL (correctness) {bench.spec}") + print(f" {result}") + exit_code = 1 + + return exit_code + + def main(): fd = os.getenv("POPCORN_FD") if not fd: - return 111 + return run_local() if len(sys.argv) < 3: return 2 + from utils import set_seed + mode = sys.argv[1] seed = os.getenv("POPCORN_SEED") os.unsetenv("POPCORN_SEED") @@ -351,11 +550,11 @@ def main(): if mode == "leaderboard": # warmup - run_single_benchmark(pool, tests[0], False, 100, 1e7) + run_single_benchmark(pool, tests[0], False, 20) logger.log("benchmark-count", len(tests)) passed = True for i in range(len(tests)): - result = run_single_benchmark(pool, tests[i], True, 100, 30e9) + result = run_single_benchmark(pool, tests[i], True, 200) logger.log(f"benchmark.{i}.spec", tests[i].spec) if isinstance(result, Stats): for field in dataclasses.fields(Stats): @@ -363,7 +562,7 @@ def main(): else: passed = False logger.log(f"benchmark.{i}.status", "fail") - logger.log(f"benchmark.{i}.error", str(result)) # TODO: Make sure result implements __str__? + logger.log(f"benchmark.{i}.error", str(result)) break logger.log("check", "pass" if passed else "fail")