From 374ab8ff9b3ed2e9b67d31e21c5520b1347ecc21 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Mon, 25 May 2026 01:14:51 +0400 Subject: [PATCH 1/2] remove `dynamic_unroll`(cons = 4% more cycles, pros = simpler compiler). --- crates/lean_compiler/snark_lib.py | 10 +- .../lean_compiler/src/a_simplify_lang/mod.rs | 252 +-------------- crates/lean_compiler/src/grammar.pest | 3 +- crates/lean_compiler/src/lang.rs | 47 +-- .../src/parser/parsers/statement.rs | 4 - crates/lean_compiler/tests/test_compiler.rs | 40 --- .../tests/test_data/program_177.py | 288 ------------------ crates/lean_compiler/zkDSL.md | 5 +- crates/lean_vm/src/isa/hint.rs | 15 +- .../rec_aggregation/zkdsl_implem/hashing.py | 33 +- crates/rec_aggregation/zkdsl_implem/main.py | 35 ++- crates/rec_aggregation/zkdsl_implem/utils.py | 1 + 12 files changed, 93 insertions(+), 640 deletions(-) delete mode 100644 crates/lean_compiler/tests/test_data/program_177.py diff --git a/crates/lean_compiler/snark_lib.py b/crates/lean_compiler/snark_lib.py index 06b117e4c..53f119e65 100644 --- a/crates/lean_compiler/snark_lib.py +++ b/crates/lean_compiler/snark_lib.py @@ -22,12 +22,6 @@ def parallel_range(a: int, b: int): return range(a, b) -# dynamic_unroll(start, end, n_bits) returns range(start, end) for Python execution -def dynamic_unroll(start: int, end: int, n_bits: int): - _ = n_bits - return range(start, end) - - # Array - simulates write-once memory with pointer arithmetic class Array: def __init__(self, size: int): @@ -184,6 +178,10 @@ def hint_log2_ceil(n): return log2_ceil(n) +def hint_div_floor(a, b, q_ptr, r_ptr): + _ = a, b, q_ptr, r_ptr + + def hint_witness(name, destination): """Write the next witness entry for `name` into `destination`.""" _ = (name, destination) diff --git a/crates/lean_compiler/src/a_simplify_lang/mod.rs b/crates/lean_compiler/src/a_simplify_lang/mod.rs index 89137b53e..4e1bf3888 100644 --- a/crates/lean_compiler/src/a_simplify_lang/mod.rs +++ b/crates/lean_compiler/src/a_simplify_lang/mod.rs @@ -1,9 +1,4 @@ -use crate::{ - CompilationFlags, F, - a_simplify_lang::post_optimization::propagate_copies, - lang::*, - parser::{ConstArrayValue, parse_program}, -}; +use crate::{F, a_simplify_lang::post_optimization::propagate_copies, lang::*, parser::ConstArrayValue}; use backend::PrimeCharacteristicRing; use lean_vm::{ ALL_POSEIDON16_NAMES, Boolean, BooleanExpr, CustomHint, ExtensionOpMode, FunctionName, @@ -755,43 +750,6 @@ fn compile_time_transform_in_lines( } } - Line::ForLoop { - loop_kind: LoopKind::DynamicUnroll { n_bits }, - iterator, - start, - end, - body, - location, - } => { - let Some(start_val) = start.compile_time_eval(const_arrays, &vector_len_tracker) else { - return Err(format!( - "line {location}: dynamic_unroll start must be a compile-time constant" - )); - }; - let start_val = start_val.to_usize(); - let Some(n_bits_val) = n_bits.compile_time_eval(const_arrays, &vector_len_tracker) else { - return Err(format!( - "line {location}: dynamic_unroll n_bits must be a compile-time constant" - )); - }; - let n_bits_val = n_bits_val.to_usize(); - if n_bits_val < 1 { - return Err(format!( - "line {location}: dynamic_unroll n_bits must be >= 1, got {n_bits_val}" - )); - } - let expanded = expand_dynamic_unroll( - &iterator.clone(), - &end.clone(), - n_bits_val, - start_val, - &body.clone(), - *location, - unroll_counter, - ); - lines.splice(i..=i, expanded); - continue; - } Line::ForLoop { iterator, start, @@ -828,7 +786,7 @@ fn compile_time_transform_in_lines( Line::IfCondition { .. } | Line::Match { .. } | Line::ForLoop { - loop_kind: LoopKind::Unroll | LoopKind::DynamicUnroll { .. }, + loop_kind: LoopKind::Unroll, .. } ) { @@ -1578,14 +1536,11 @@ fn check_block_scoping(block: &[Line], ctx: &mut Context) { start, end, body, - loop_kind, + loop_kind: _, location: _, } => { check_expr_scoping(start, ctx); check_expr_scoping(end, ctx); - if let LoopKind::DynamicUnroll { n_bits } = loop_kind { - check_expr_scoping(n_bits, ctx); - } let mut new_scope_vars = BTreeSet::new(); new_scope_vars.insert(iterator.clone()); ctx.scopes.push(Scope { vars: new_scope_vars }); @@ -2778,7 +2733,7 @@ fn simplify_lines( } => { assert!( matches!(loop_kind, LoopKind::Range | LoopKind::ParallelRange), - "Unrolled/dynamic_unroll loops should have been handled already" + "Unrolled loops should have been handled already" ); let is_parallel = loop_kind.is_parallel(); @@ -3253,7 +3208,6 @@ pub fn find_variable_usage( start, end, body, - loop_kind, .. } => { let (body_internal, body_external) = find_variable_usage(body, const_arrays); @@ -3262,9 +3216,6 @@ pub fn find_variable_usage( external_vars.extend(body_external.difference(&internal_vars).cloned()); on_new_expr(start, &internal_vars, &mut external_vars); on_new_expr(end, &internal_vars, &mut external_vars); - if let LoopKind::DynamicUnroll { n_bits } = loop_kind { - on_new_expr(n_bits, &internal_vars, &mut external_vars); - } } Line::Panic { .. } | Line::LocationReport { .. } => {} Line::VecDeclaration { var, elements, .. } => { @@ -3703,201 +3654,6 @@ fn replace_vars_for_unroll( transform_vars_in_lines(lines, &transform); } -/// Chunk size threshold for splitting large unrolls into hybrid loops. -/// Bits k where 2^k > CHUNK_SIZE will use a runtime outer loop with CHUNK_SIZE inner unroll. -const DYNAMIC_UNROLL_CHUNK_SIZE: usize = 1 << 9; // 512 - -/// Expands `for idx in dynamic_unroll(start, a, n_bits): body` into: -/// 1. Bit decomposition of `a - start` (with constraints) -/// 2. Conditional execution of `body` for each index start..a -/// -/// Computes `n_iters = end - start_val`, decomposes into bits, and offsets -/// each iterator value by the compile-time `start_val`. -/// -/// The expansion template is written in zkDSL for readability, then parsed -/// and post-processed (variable renaming, body splicing, location fixup). -fn expand_dynamic_unroll( - iterator: &Var, - runtime_end: &Expression, - n_bits: usize, - start_val: usize, - body: &[Line], - location: SourceLocation, - unroll_counter: &mut Counter, -) -> Vec { - let id = unroll_counter.get_next(); - let pfx = format!("@du{id}"); - let ps_len = n_bits + 1; - - // The template is the zkDSL expansion of dynamic_unroll, with `end` as the - // runtime bound and `__iter` as a placeholder for the iterator assignment. - // - // Bits are stored in big-endian order: bits[0] is the most significant bit - // (weight 2^(n_bits-1)), bits[n_bits-1] is the least significant (weight 2^0). - // ps has n_bits+1 elements: ps[0]=0, ps[k+1] = ps[k] + bits[k]*2^(n_bits-1-k). - // So ps[k] is the offset (number of indices below bit k), and ps[n_bits] == n_iters. - // - // For large bits (block_size > CHUNK_SIZE), we split into chunks to reduce bytecode size: - // - outer runtime loop over n_chunks = block_size / CHUNK_SIZE - // - inner unroll over CHUNK_SIZE iterations - // For small bits, the range loop has minimal overhead. - - // Build the template with per-bit chunking logic. - // Pre-compute __base_k = start_val + ps[k] once per activated bit, - // so the inner loop stays at 1 ADD per iteration. - let mut loop_body = String::new(); - for k in 0..n_bits { - let block_size = 1usize << (n_bits - 1 - k); - if block_size <= DYNAMIC_UNROLL_CHUNK_SIZE { - // Small block: fully unroll - loop_body.push_str(&format!( - r#" - if bits[{k}] == 1: - __base_{k} = {start_val} + ps[{k}] - for j in unroll(0, {block_size}): - __iter = __base_{k} + j -"# - )); - } else { - // Large block: hybrid loop (runtime outer, unroll inner) - // Use an offset variable to avoid MUL per iteration - let n_chunks = block_size / DYNAMIC_UNROLL_CHUNK_SIZE; - loop_body.push_str(&format!( - r#" - if bits[{k}] == 1: - __offset_{k}: Mut = {start_val} + ps[{k}] - for chunk in range(0, {n_chunks}): - for j in unroll(0, {DYNAMIC_UNROLL_CHUNK_SIZE}): - __iter = __offset_{k} + j - __offset_{k} = __offset_{k} + {DYNAMIC_UNROLL_CHUNK_SIZE} -"# - )); - } - } - let template = format!( - r#" -def __dynamic_unroll_template(end): - n_iters = end - {start_val} - bits = Array({n_bits}) - hint_decompose_bits(n_iters, bits, {n_bits}) - ps = Array({ps_len}) - ps[0] = 0 - for k in unroll(0, {n_bits}): - b = bits[k] - assert b * (1 - b) == 0 - ps[k + 1] = ps[k] + b * 2**({n_bits} - 1 - k) - assert n_iters == ps[{n_bits}] -{loop_body} - return -"# - ); - - let program = parse_program(&crate::ProgramSource::Raw(template), CompilationFlags::default()).unwrap(); - assert_eq!(program.functions.len(), 1); - let func = program.functions.values().next().unwrap(); - let mut lines = func.body.clone(); - - // Strip trailing return + its LocationReport - while matches!( - lines.last(), - Some(Line::FunctionRet { .. } | Line::LocationReport { .. }) - ) { - lines.pop(); - } - // Strip LocationReport lines (they carry template line numbers, not the real ones) - strip_location_reports(&mut lines); - - // Rename all internal variables with @du{id}_ prefix. - // __iter is renamed directly to the user's iterator variable. - let internals: BTreeSet = ["bits", "ps", "k", "j", "b", "chunk", "n_iters"] - .iter() - .map(|s| s.to_string()) - .collect(); - transform_vars_in_lines(&mut lines, &|var: &Var| { - if var == "__iter" { - VarTransform::Rename(iterator.clone()) - } else if var == "end" || internals.contains(var) || var.starts_with("__offset_") || var.starts_with("__base_") - { - VarTransform::Rename(format!("{pfx}_{var}")) - } else { - VarTransform::Keep - } - }); - - // Prepend: @du{id}_end = runtime_end - lines.insert( - 0, - Line::Statement { - targets: vec![AssignmentTarget::Var { - var: format!("{pfx}_end"), - is_mutable: false, - }], - value: runtime_end.clone(), - location, - }, - ); - - // Insert body after every `{iterator} = ...` assignment (the renamed __iter lines) - insert_body_after_var(&mut lines, iterator, body); - - // Fix all source locations to point to the actual dynamic_unroll call site - set_locations_recursive(&mut lines, location); - - lines -} - -fn strip_location_reports(lines: &mut Vec) { - lines.retain(|l| !matches!(l, Line::LocationReport { .. })); - for line in lines.iter_mut() { - for block in line.nested_blocks_mut() { - strip_location_reports(block); - } - } -} - -/// In every nested block, insert `body` lines after each statement that assigns to `var`. -fn insert_body_after_var(lines: &mut [Line], var: &str, body: &[Line]) { - for line in lines.iter_mut() { - for block in line.nested_blocks_mut() { - let mut i = 0; - while i < block.len() { - if matches!(&block[i], Line::Statement { targets, .. } - if targets.iter().any(|t| matches!(t, AssignmentTarget::Var { var: v, .. } if v == var))) - { - let insert_pos = i + 1; - for (j, body_line) in body.iter().enumerate() { - block.insert(insert_pos + j, body_line.clone()); - } - i += 1 + body.len(); - } else { - i += 1; - } - } - insert_body_after_var(block, var, body); - } - } -} - -fn set_locations_recursive(lines: &mut [Line], location: SourceLocation) { - for line in lines { - match line { - Line::Statement { location: loc, .. } - | Line::Assert { location: loc, .. } - | Line::IfCondition { location: loc, .. } - | Line::ForLoop { location: loc, .. } - | Line::Match { location: loc, .. } - | Line::LocationReport { location: loc } - | Line::VecDeclaration { location: loc, .. } - | Line::Push { location: loc, .. } - | Line::Pop { location: loc, .. } => *loc = location, - Line::ForwardDeclaration { .. } | Line::FunctionRet { .. } | Line::Panic { .. } => {} - } - for block in line.nested_blocks_mut() { - set_locations_recursive(block, location); - } - } -} - fn replace_vars_by_const_in_expr(expr: &mut Expression, map: &BTreeMap) { match expr { Expression::Value(value) => match &value { diff --git a/crates/lean_compiler/src/grammar.pest b/crates/lean_compiler/src/grammar.pest index 7b862193c..0ee1c499b 100644 --- a/crates/lean_compiler/src/grammar.pest +++ b/crates/lean_compiler/src/grammar.pest @@ -88,11 +88,10 @@ elif_clause = { "elif" ~ condition ~ ":" ~ newline ~ statement* ~ end_block } else_clause = { "else" ~ ":" ~ newline ~ statement* ~ end_block } -for_statement = { "for" ~ identifier ~ "in" ~ (dynamic_unroll_range | unroll_range | parallel_range | range) ~ ":" ~ newline ~ statement* ~ end_block } +for_statement = { "for" ~ identifier ~ "in" ~ (unroll_range | parallel_range | range) ~ ":" ~ newline ~ statement* ~ end_block } range = { "range" ~ "(" ~ expression ~ "," ~ expression ~ ")" } parallel_range = { "parallel_range" ~ "(" ~ expression ~ "," ~ expression ~ ")" } unroll_range = { "unroll" ~ "(" ~ expression ~ "," ~ expression ~ ")" } -dynamic_unroll_range = { "dynamic_unroll" ~ "(" ~ expression ~ "," ~ expression ~ "," ~ expression ~ ")" } match_statement = { "match" ~ expression ~ ":" ~ newline ~ match_arm* ~ end_block } match_arm = { "case" ~ pattern ~ ":" ~ newline ~ statement* ~ end_block } diff --git a/crates/lean_compiler/src/lang.rs b/crates/lean_compiler/src/lang.rs index ee46c1ebb..5accb007e 100644 --- a/crates/lean_compiler/src/lang.rs +++ b/crates/lean_compiler/src/lang.rs @@ -540,16 +540,11 @@ pub enum LoopKind { Range, ParallelRange, Unroll, - /// `for i in dynamic_unroll(0, a, n_bits): body` — unrolls over runtime-bounded range - /// using bit decomposition. `n_bits` must be compile-time known. - DynamicUnroll { - n_bits: Expression, - }, } impl LoopKind { pub fn is_unroll(&self) -> bool { - matches!(self, Self::Unroll | Self::DynamicUnroll { .. }) + matches!(self, Self::Unroll) } pub fn is_parallel(&self) -> bool { @@ -783,25 +778,17 @@ impl Line { .map(|line| line.to_string_with_indent(indent + 1)) .collect::>() .join("\n"); - match loop_kind { - LoopKind::DynamicUnroll { n_bits } => format!( - "for {} in dynamic_unroll({}, {}, {}) {{\n{}\n{}}}", - iterator, start, end, n_bits, body_str, spaces - ), - _ => { - let range_fn = if loop_kind.is_unroll() { - "unroll" - } else if loop_kind.is_parallel() { - "parallel_range" - } else { - "range" - }; - format!( - "for {} in {}({}, {}) {{\n{}\n{}}}", - iterator, range_fn, start, end, body_str, spaces - ) - } - } + let range_fn = if loop_kind.is_unroll() { + "unroll" + } else if loop_kind.is_parallel() { + "parallel_range" + } else { + "range" + }; + format!( + "for {} in {}({}, {}) {{\n{}\n{}}}", + iterator, range_fn, start, end, body_str, spaces + ) } Self::FunctionRet { return_data } => { let return_data_str = return_data @@ -904,14 +891,8 @@ impl Line { } Self::Assert { boolean, .. } => vec![&mut boolean.left, &mut boolean.right], Self::IfCondition { condition, .. } => vec![&mut condition.left, &mut condition.right], - Self::ForLoop { - start, end, loop_kind, .. - } => { - let mut exprs = vec![start, end]; - if let LoopKind::DynamicUnroll { n_bits } = loop_kind { - exprs.push(n_bits); - } - exprs + Self::ForLoop { start, end, .. } => { + vec![start, end] } Self::FunctionRet { return_data } => return_data.iter_mut().collect(), Self::Push { indices, element, .. } => { diff --git a/crates/lean_compiler/src/parser/parsers/statement.rs b/crates/lean_compiler/src/parser/parsers/statement.rs index 30670c889..4f0016930 100644 --- a/crates/lean_compiler/src/parser/parsers/statement.rs +++ b/crates/lean_compiler/src/parser/parsers/statement.rs @@ -172,10 +172,6 @@ impl Parse for ForStatementParser { let end = ExpressionParser.parse(next_inner_pair(&mut range_inner, "loop end")?, ctx)?; let loop_kind = match rule { Rule::unroll_range => LoopKind::Unroll, - Rule::dynamic_unroll_range => { - let n_bits = ExpressionParser.parse(next_inner_pair(&mut range_inner, "n_bits")?, ctx)?; - LoopKind::DynamicUnroll { n_bits } - } Rule::parallel_range => LoopKind::ParallelRange, _ => LoopKind::Range, }; diff --git a/crates/lean_compiler/tests/test_compiler.rs b/crates/lean_compiler/tests/test_compiler.rs index df77e420b..962d333cf 100644 --- a/crates/lean_compiler/tests/test_compiler.rs +++ b/crates/lean_compiler/tests/test_compiler.rs @@ -151,46 +151,6 @@ fn test_reserved_function_names() { } } -#[test] -fn test_dynamic_unroll_cycles() { - // Verify that dynamic_unroll costs ~2 cycles per iteration - for start in [0u32, 5, 50] { - let program = format!( - r#" -def main(): - a = 0 - end = a[0] - expected = a[1] - acc: Mut = 0 - for i in dynamic_unroll({start}, end, 13): - acc = acc + i - assert acc == expected - return -"# - ); - let bytecode = compile_program(&ProgramSource::Raw(program), DIGEST_LEN); - - let run = |end_val: u32| -> usize { - let expected_sum = (start..end_val).map(|i| i as u64).sum::() as u32; - let public_input = [F::new(end_val), F::new(expected_sum)]; - let result = try_execute_bytecode(&bytecode, &public_input, &ExecutionWitness::default(), false).unwrap(); - result.pcs.len() - }; - - let n_iters_a = 2000u32; - let n_iters_b = 4000u32; - let cycles_a = run(start + n_iters_a); - let cycles_b = run(start + n_iters_b); - let delta = cycles_b - cycles_a; - let extra_iters = n_iters_b - n_iters_a; - let expected_delta = 2 * extra_iters as usize; - // Allow 5% tolerance for fixed overhead per activated bit - let lo = expected_delta * 95 / 100; - let hi = expected_delta * 105 / 100; - assert!(delta >= lo && delta <= hi,); - } -} - #[test] fn debug_file_program() { let index = 167; diff --git a/crates/lean_compiler/tests/test_data/program_177.py b/crates/lean_compiler/tests/test_data/program_177.py deleted file mode 100644 index bc6fa5aeb..000000000 --- a/crates/lean_compiler/tests/test_data/program_177.py +++ /dev/null @@ -1,288 +0,0 @@ -from snark_lib import * - -# Comprehensive test for dynamic_unroll -# Tests: edge cases, mutable accumulators, array writes, nested dynamic_unroll, -# multiple dynamic_unrolls in one function, conditional body logic, -# interaction with regular unroll, and varying n_bits. -# -# Note: n_bits should be kept small (<=8) since the generated code size is O(2^n_bits). - - -def main(): - # --- Edge cases --- - # a = 0: no iterations - z = sum_up_to(0, 4) - assert z == 0 - - # a = 1: single iteration (only i=0) - z1 = sum_up_to(1, 4) - assert z1 == 0 - - # a = 2: two iterations - z2 = sum_up_to(2, 4) - assert z2 == 1 - - # a = 2^n_bits - 1: max value for n_bits=4 - z15 = sum_up_to(15, 4) - assert z15 == 105 - - # power of two - z8 = sum_up_to(8, 4) - assert z8 == 28 - - # --- Basic accumulation with n_bits=8 --- - z7 = sum_up_to(7, 8) - assert z7 == 21 - - z100 = sum_up_to(100, 8) - assert z100 == 4950 - - # --- Array writes via dynamic_unroll --- - buf = Array(16) - fill_squares(buf, 10, 4) - assert buf[0] == 0 - assert buf[1] == 1 - assert buf[4] == 16 - assert buf[9] == 81 - - # --- Nested dynamic_unroll (triangular sum) --- - # for i in 0..a: for j in 0..i: total += 1 - # i=0: 0, i=1: 1, i=2: 2, i=3: 3, i=4: 4 => 10 - tri = triangular(5, 4) - assert tri == 10 - tri0 = triangular(0, 4) - assert tri0 == 0 - tri1 = triangular(1, 4) - assert tri1 == 0 - tri2 = triangular(2, 4) - assert tri2 == 1 - - # --- Two sequential dynamic_unrolls in one function --- - r = double_loop(3, 5, 4) - # first loop: 0+1+2 = 3, second loop: 5*1 = 5, total = 8 - assert r == 8 - # edge: both zero - r0 = double_loop(0, 0, 4) - assert r0 == 0 - - # --- Conditional body: only accumulate even indices --- - e = sum_even_indices(8, 4) - # even indices in 0..8: 0,2,4,6 => sum = 12 - assert e == 12 - e0 = sum_even_indices(0, 4) - assert e0 == 0 - e1 = sum_even_indices(1, 4) - assert e1 == 0 - - # --- dynamic_unroll writing to array + reading back --- - check = Array(8) - write_and_verify(check, 6, 4) - - # --- dynamic_unroll with arithmetic in body --- - poly = eval_polynomial(5, 4) - # sum of i^2 for i in 0..5: 0+1+4+9+16 = 30 - assert poly == 30 - - # --- Nested: outer dynamic_unroll, inner regular unroll --- - m = mixed_loops(4, 4) - # for i in 0..4: for j in unroll(0,3): acc += i+j - # i=0: 0+1+2=3, i=1: 1+2+3=6, i=2: 2+3+4=9, i=3: 3+4+5=12 => 30 - assert m == 30 - - # --- Called with different n_bits for same function --- - s4 = sum_up_to(10, 4) - assert s4 == 45 - s8 = sum_up_to(10, 8) - assert s8 == 45 - - # --- Complex body: sum of squares with algebraic verification --- - sq = sum_squares(100, 8) - # sum_{i=0}^{99} i^2 = 100*99*199/6 = 328350 - # Verify: 6 * sum == a*(a-1)*(2a-1) - assert 6 * sq == 100 * 99 * 199 - - # --- Complex body: array write + accumulate + read back --- - work = Array(256) - wa = write_and_accumulate(work, 50, 8) - # Each entry: work[i] = i*i + 3*i + 7, sum of those for i in 0..50 - # sum = sum(i^2) + 3*sum(i) + 50*7 = 40425 + 3*1225 + 350 = 44450 - assert wa == 44450 - assert work[0] == 7 - assert work[1] == 11 - assert work[49] == 2555 - - # --- Copy array region using dynamic_unroll --- - src = Array(16) - dst = Array(16) - for i in unroll(0, 10): - src[i] = (i + 1) * 7 - copy_array(src, dst, 10, 4) - for i in unroll(0, 10): - assert src[i] == dst[i] - - # --- Large n_bits to test chunking (bits > 10 trigger chunking at 1024 threshold) --- - # n_bits=12: bit 11 has 2^11=2048 iterations, should chunk into 2×1024 - large_sum = sum_up_to(2500, 12) - # sum 0..2499 = 2500*2499/2 = 3123750 - assert large_sum == 3123750 - - # n_bits=14: bit 13 has 2^13=8192 iterations, should chunk into 8×1024 - huge_sum = sum_up_to(10000, 14) - # sum 0..9999 = 10000*9999/2 = 49995000 - assert huge_sum == 49995000 - - # --- Non-zero start: basic sum --- - # sum of 5..10 = 5+6+7+8+9 = 35 - ns1 = sum_from_to(5, 10, 4) - assert ns1 == 35 - - # sum of 3..3 = 0 iterations - ns2 = sum_from_to(3, 3, 4) - assert ns2 == 0 - - # sum of 7..8 = 7 (1 iteration) - ns3 = sum_from_to(7, 8, 4) - assert ns3 == 7 - - # sum of 1..16 = 1+2+...+15 = 120 - ns4 = sum_from_to(1, 16, 4) - assert ns4 == 120 - - # --- Non-zero start: array writes with offset --- - obuf = Array(16) - fill_squares_offset(obuf, 3, 8, 4) - # writes obuf[i-3] = i*i for i in 3..8 - assert obuf[0] == 9 - assert obuf[1] == 16 - assert obuf[2] == 25 - assert obuf[3] == 36 - assert obuf[4] == 49 - - # --- Non-zero start: large n_bits (chunking path) --- - # sum of 100..2600 = sum(100..2599) = sum(0..2599) - sum(0..99) - # = 2600*2599/2 - 100*99/2 = 3378700 - 4950 = 3373750 - large_offset_sum = sum_from_to_large(100, 2600, 12) - assert large_offset_sum == 3373750 - - return - - -def sum_up_to(a, n_bits: Const): - acc: Mut = 0 - for i in dynamic_unroll(0, a, n_bits): - acc = acc + i - return acc - - -def fill_squares(arr, n, n_bits: Const): - for i in dynamic_unroll(0, n, n_bits): - arr[i] = i * i - return - - -def triangular(a, n_bits: Const): - acc: Mut = 0 - for i in dynamic_unroll(0, a, n_bits): - for j in dynamic_unroll(0, i, n_bits): - acc = acc + 1 - return acc - - -def double_loop(a, b, n_bits: Const): - acc: Mut = 0 - for i in dynamic_unroll(0, a, n_bits): - acc = acc + i - for j in dynamic_unroll(0, b, n_bits): - acc = acc + 1 - return acc - - -def sum_even_indices(a, n_bits: Const): - # Big-endian: i_bits[0] = MSB, i_bits[n_bits - 1] = LSB. - # Parity check: i_bits[n_bits - 1] == 0 means even. - acc: Mut = 0 - for i in dynamic_unroll(0, a, n_bits): - i_bits = Array(n_bits) - hint_decompose_bits(i, i_bits, n_bits) - i_ps = Array(n_bits) - i_ps[0] = i_bits[n_bits - 1] - assert i_ps[0] * (1 - i_ps[0]) == 0 - for k in unroll(1, n_bits): - ib = i_bits[n_bits - 1 - k] - assert ib * (1 - ib) == 0 - i_ps[k] = i_ps[k - 1] + ib * 2**k - assert i == i_ps[n_bits - 1] - if i_bits[n_bits - 1] == 0: - acc = acc + i - return acc - - -def write_and_verify(arr, n, n_bits: Const): - for i in dynamic_unroll(0, n, n_bits): - arr[i] = i * 3 + 1 - assert arr[0] == 1 - assert arr[1] == 4 - assert arr[2] == 7 - assert arr[3] == 10 - assert arr[4] == 13 - assert arr[5] == 16 - return - - -def eval_polynomial(a, n_bits: Const): - acc: Mut = 0 - for i in dynamic_unroll(0, a, n_bits): - acc = acc + i * i - return acc - - -def mixed_loops(a, n_bits: Const): - # Outer: dynamic_unroll over runtime bound - # Inner: regular unroll over compile-time bound - acc: Mut = 0 - for i in dynamic_unroll(0, a, n_bits): - for j in unroll(0, 3): - acc = acc + i + j - return acc - - -def sum_squares(a, n_bits: Const): - acc: Mut = 0 - for i in dynamic_unroll(0, a, n_bits): - acc = acc + i * i - return acc - - -def write_and_accumulate(arr, n, n_bits: Const): - acc: Mut = 0 - for i in dynamic_unroll(0, n, n_bits): - val = i * i + 3 * i + 7 - arr[i] = val - acc = acc + val - return acc - - -def copy_array(src, dst, n, n_bits: Const): - for i in dynamic_unroll(0, n, n_bits): - dst[i] = src[i] - return - - -def sum_from_to(start: Const, end, n_bits: Const): - acc: Mut = 0 - for i in dynamic_unroll(start, end, n_bits): - acc = acc + i - return acc - - -def fill_squares_offset(arr, start: Const, end, n_bits: Const): - for i in dynamic_unroll(start, end, n_bits): - arr[i - start] = i * i - return - - -def sum_from_to_large(start: Const, end, n_bits: Const): - acc: Mut = 0 - for i in dynamic_unroll(start, end, n_bits): - acc = acc + i - return acc diff --git a/crates/lean_compiler/zkDSL.md b/crates/lean_compiler/zkDSL.md index f5b867435..218be263b 100644 --- a/crates/lean_compiler/zkDSL.md +++ b/crates/lean_compiler/zkDSL.md @@ -339,8 +339,6 @@ for i in parallel_range(0, n): # iterations executed in parallel (see b ... for i in unroll(0, 4): # unrolled at compile time ... -for i in dynamic_unroll(5, a, n_bits): # start=5 and n_bits compile-time; a runtime, with (a - start) < 2^n_bits - ... ``` Use `unroll` when bounds are const or compile-time expansion is needed. @@ -351,8 +349,6 @@ Use `unroll` when bounds are const or compile-time expansion is needed. - The memory footprint (i.e. total memory usage) must be the same across iterations - XMSS / Merkle hint consumption must be the same across iterations -**`dynamic_unroll`** enables iterating from `start` to a runtime value `a` (where `a - start` is known to be < 2^n_bits) in an unrolled fashion. The compiler automatically generates bit decomposition of `a - start`, verification constraints, and conditional execution for each index. Both `start` and `n_bits` must be compile-time known. - **Mutable variables in non-unrolled loops:** Mutable variables can be modified inside non-unrolled loops. The compiler automatically transforms these into buffer-based implementations: ``` @@ -465,6 +461,7 @@ hints = prover-supplied values at runtime (without adding snark constraints). Li | `hint_decompose_bits` | `(to_decompose, ptr, num_bits, endianness)` | `num_bits` field elements at `ptr` (the 0/1 bit decomposition of `to_decompose`); `endianness` is `0` for big-endian, `1` for little-endian | | `hint_less_than` | `(a, b, result_ptr)` | `1` at `result_ptr` if `a < b` else `0` | | `hint_log2_ceil` | `(n, result_ptr)` | `ceil(log2(n))` at `result_ptr` | +| `hint_div_floor` | `(a, b, q_ptr, r_ptr)` | `floor(a/b)` at `q_ptr` and `a mod b` at `r_ptr` (requires `b != 0`) | | `hint_decompose_bits_xmss` | `(decomposed_ptr, remaining_ptr, to_decompose_ptr, num_to_decompose, chunk_size)` | XMSS-specific decomposition (see `crates/lean_vm/src/isa/hint.rs`) | | `hint_decompose_bits_merkle_whir` | `(decomposed_ptr, remaining_ptr, value, chunk_size)` | Merkle/WHIR-specific decomposition | diff --git a/crates/lean_vm/src/isa/hint.rs b/crates/lean_vm/src/isa/hint.rs index 65297af53..724e6fa27 100644 --- a/crates/lean_vm/src/isa/hint.rs +++ b/crates/lean_vm/src/isa/hint.rs @@ -111,14 +111,16 @@ pub enum CustomHint { DecomposeBits, LessThan, Log2Ceil, + DivFloor, } -pub const CUSTOM_HINTS: [CustomHint; 5] = [ +pub const CUSTOM_HINTS: [CustomHint; 6] = [ CustomHint::DecomposeBitsXMSS, CustomHint::DecomposeBitsMerkleWhir, CustomHint::DecomposeBits, CustomHint::LessThan, CustomHint::Log2Ceil, + CustomHint::DivFloor, ]; impl CustomHint { @@ -129,6 +131,7 @@ impl CustomHint { Self::DecomposeBits => "hint_decompose_bits", Self::LessThan => "hint_less_than", Self::Log2Ceil => "hint_log2_ceil", + Self::DivFloor => "hint_div_floor", } } @@ -139,6 +142,7 @@ impl CustomHint { Self::DecomposeBits => 3, Self::LessThan => 3, Self::Log2Ceil => 2, + Self::DivFloor => 4, } } @@ -198,6 +202,15 @@ impl CustomHint { let res_ptr = args[1].memory_address(ctx.fp)?; ctx.memory.set(res_ptr, F::from_usize(log2_ceil_usize(n)))?; } + Self::DivFloor => { + let a = args[0].read_value(ctx.memory, ctx.fp)?.to_usize(); + let b = args[1].read_value(ctx.memory, ctx.fp)?.to_usize(); + let q_ptr = args[2].memory_address(ctx.fp)?; + let r_ptr = args[3].memory_address(ctx.fp)?; + assert!(b != 0, "hint_div_floor: division by zero"); + ctx.memory.set(q_ptr, F::from_usize(a / b))?; + ctx.memory.set(r_ptr, F::from_usize(a % b))?; + } } Ok(()) } diff --git a/crates/rec_aggregation/zkdsl_implem/hashing.py b/crates/rec_aggregation/zkdsl_implem/hashing.py index 7e740024b..30382e0d0 100644 --- a/crates/rec_aggregation/zkdsl_implem/hashing.py +++ b/crates/rec_aggregation/zkdsl_implem/hashing.py @@ -6,6 +6,7 @@ # memory layout: [public_input (PUBLIC_INPUT_LEN)] [preamble_memory (PREAMBLE_MEMORY_LEN)] [runtime ...] # `preamble_memory` is a region that is filled by the guest program, with usefull constants [0000...][1000...]... PUBLIC_INPUT_LEN = DIGEST_LEN +PARTIAL_UNROLL_BATCH = 64 ZERO_VEC_PTR = PUBLIC_INPUT_LEN ZERO_VEC_LEN = ZERO_VEC_LEN_PLACEHOLDER SAMPLING_DOMAIN_SEPARATOR_PTR = ZERO_VEC_PTR + ZERO_VEC_LEN @@ -104,9 +105,23 @@ def slice_hash(data, num_chunks, dest): return -def slice_hash_dynamic_unroll(data, num_chunks, num_chunks_bits: Const): +@inline +def euclidian_div_runtime(a, b): + # Returns (q, r) with q = floor(a / b) and r = a mod b. + # Requires: + # 1 <= b < 2^14 + # floor(a / b) < 2^16 (so that q*b + r stays well below p) + q: Imu + r: Imu + hint_div_floor(a, b, q, r) + assert r < b + assert q < 2 ** 16 + assert q * b + r == a + return q, r + + +def slice_hash_runtime(data, num_chunks): debug_assert(num_chunks != 0) - debug_assert(num_chunks < 2**num_chunks_bits) iv = build_iv(num_chunks * DIGEST_LEN) @@ -120,11 +135,21 @@ def slice_hash_dynamic_unroll(data, num_chunks, num_chunks_bits: Const): n_iters = num_chunks - 1 state_ptr: Mut = states data_ptr: Mut = data + DIGEST_LEN - for _ in dynamic_unroll(0, n_iters, num_chunks_bits): + + n_chunks_outer, remainder = euclidian_div_runtime(n_iters, PARTIAL_UNROLL_BATCH) + for _ in range(0, n_chunks_outer): + for _ in unroll(0, PARTIAL_UNROLL_BATCH): + new_state = state_ptr + DIGEST_LEN + poseidon16_compress(state_ptr, data_ptr, new_state) + state_ptr = new_state + data_ptr += DIGEST_LEN + + for _ in range(0, remainder): new_state = state_ptr + DIGEST_LEN poseidon16_compress(state_ptr, data_ptr, new_state) state_ptr = new_state - data_ptr = data_ptr + DIGEST_LEN + data_ptr += DIGEST_LEN + return state_ptr diff --git a/crates/rec_aggregation/zkdsl_implem/main.py b/crates/rec_aggregation/zkdsl_implem/main.py index 6b43d78ae..472ff5a9e 100644 --- a/crates/rec_aggregation/zkdsl_implem/main.py +++ b/crates/rec_aggregation/zkdsl_implem/main.py @@ -161,7 +161,7 @@ def main(): return # General path - computed_pubkeys_hash = slice_hash_dynamic_unroll(all_pubkeys, n_sigs, log2_ceil(MAX_N_SIGS)) + computed_pubkeys_hash = slice_hash_runtime(all_pubkeys, n_sigs) copy_8(computed_pubkeys_hash, pubkeys_hash_expected) # Buffer for partition verification @@ -187,16 +187,18 @@ def main(): sub_indices_arr = Array(n_sub) hint_witness("sub_indices", sub_indices_arr) + running_hash: Mut = build_iv(n_sub * PUB_KEY_SIZE) - for j in dynamic_unroll(0, n_sub, log2_ceil(MAX_N_SIGS)): - idx = sub_indices_arr[j] - assert idx < n_total - buffer[idx] = counter - counter += 1 - pk = all_pubkeys + idx * PUB_KEY_SIZE - new_hash = Array(DIGEST_LEN) - poseidon16_compress(running_hash, pk, new_hash) - running_hash = new_hash + n_chunks, remainder = euclidian_div_runtime(n_sub, PARTIAL_UNROLL_BATCH) + j: Mut = 0 + for _ in range(0, n_chunks): + for u in unroll(0, PARTIAL_UNROLL_BATCH): + counter, running_hash = absorb_recursive_pubkey(j + u, sub_indices_arr, n_total, all_pubkeys, buffer, counter, running_hash) + j += PARTIAL_UNROLL_BATCH + # Tail iterations + for _ in range(0, remainder): + counter, running_hash = absorb_recursive_pubkey(j, sub_indices_arr, n_total, all_pubkeys, buffer, counter, running_hash) + j += 1 type1_data_buf = Array(TYPE_1_INPUT_DATA_SIZE_PADDED) type1_data_buf[0] = TYPE_1_FLAG @@ -284,3 +286,16 @@ def ensure_well_formed_input_data(data_buf, bytecode_hash_domsep, flag): data_buf[k] = 0 copy_8(bytecode_hash_domsep, data_buf + BYTECODE_HASH_DOMSEP_OFFSET) return + + +@inline +def absorb_recursive_pubkey(j, sub_indices_arr, n_total, all_pubkeys, buffer, counter_in, running_hash_in): + idx = sub_indices_arr[j] + assert idx < n_total + buffer[idx] = counter_in + new_counter = counter_in + 1 + pk = all_pubkeys + idx * PUB_KEY_SIZE + new_hash = Array(DIGEST_LEN) + poseidon16_compress(running_hash_in, pk, new_hash) + return new_counter, new_hash + diff --git a/crates/rec_aggregation/zkdsl_implem/utils.py b/crates/rec_aggregation/zkdsl_implem/utils.py index fe3d6e7ce..b95eedfad 100644 --- a/crates/rec_aggregation/zkdsl_implem/utils.py +++ b/crates/rec_aggregation/zkdsl_implem/utils.py @@ -822,3 +822,4 @@ def log2_ceil_runtime(n): lambda i: _verify_log2_large(n, i), ) return log2 + From 5ba48e7305dda189ca830bc9041d81f504253147 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Mon, 25 May 2026 02:34:16 +0400 Subject: [PATCH 2/2] use match range --- .../rec_aggregation/zkdsl_implem/hashing.py | 24 +++++++++++++------ crates/rec_aggregation/zkdsl_implem/main.py | 18 +++++++++++--- 2 files changed, 32 insertions(+), 10 deletions(-) diff --git a/crates/rec_aggregation/zkdsl_implem/hashing.py b/crates/rec_aggregation/zkdsl_implem/hashing.py index 30382e0d0..5146ee82e 100644 --- a/crates/rec_aggregation/zkdsl_implem/hashing.py +++ b/crates/rec_aggregation/zkdsl_implem/hashing.py @@ -120,6 +120,17 @@ def euclidian_div_runtime(a, b): return q, r +def absorb_n_hashes_const(n: Const, sp_in, dp_in): + sp: Mut = sp_in + dp: Mut = dp_in + for _ in unroll(0, n): + new_state = sp + DIGEST_LEN + poseidon16_compress(sp, dp, new_state) + sp = new_state + dp += DIGEST_LEN + return sp + + def slice_hash_runtime(data, num_chunks): debug_assert(num_chunks != 0) @@ -144,13 +155,12 @@ def slice_hash_runtime(data, num_chunks): state_ptr = new_state data_ptr += DIGEST_LEN - for _ in range(0, remainder): - new_state = state_ptr + DIGEST_LEN - poseidon16_compress(state_ptr, data_ptr, new_state) - state_ptr = new_state - data_ptr += DIGEST_LEN - - return state_ptr + final_state_ptr = match_range( + remainder, + range(0, PARTIAL_UNROLL_BATCH), + lambda r: absorb_n_hashes_const(r, state_ptr, data_ptr), + ) + return final_state_ptr @inline diff --git a/crates/rec_aggregation/zkdsl_implem/main.py b/crates/rec_aggregation/zkdsl_implem/main.py index 472ff5a9e..f92cd2e9a 100644 --- a/crates/rec_aggregation/zkdsl_implem/main.py +++ b/crates/rec_aggregation/zkdsl_implem/main.py @@ -196,9 +196,13 @@ def main(): counter, running_hash = absorb_recursive_pubkey(j + u, sub_indices_arr, n_total, all_pubkeys, buffer, counter, running_hash) j += PARTIAL_UNROLL_BATCH # Tail iterations - for _ in range(0, remainder): - counter, running_hash = absorb_recursive_pubkey(j, sub_indices_arr, n_total, all_pubkeys, buffer, counter, running_hash) - j += 1 + tail_counter, tail_running_hash = match_range( + remainder, + range(0, PARTIAL_UNROLL_BATCH), + lambda r: absorb_n_pubkeys_const(r, j, sub_indices_arr, n_total, all_pubkeys, buffer, counter, running_hash), + ) + counter = tail_counter + running_hash = tail_running_hash type1_data_buf = Array(TYPE_1_INPUT_DATA_SIZE_PADDED) type1_data_buf[0] = TYPE_1_FLAG @@ -299,3 +303,11 @@ def absorb_recursive_pubkey(j, sub_indices_arr, n_total, all_pubkeys, buffer, co poseidon16_compress(running_hash_in, pk, new_hash) return new_counter, new_hash + +def absorb_n_pubkeys_const(n: Const, j_start, sub_indices_arr, n_total, all_pubkeys, buffer, counter_in, running_hash_in): + counter: Mut = counter_in + running_hash: Mut = running_hash_in + for u in unroll(0, n): + counter, running_hash = absorb_recursive_pubkey(j_start + u, sub_indices_arr, n_total, all_pubkeys, buffer, counter, running_hash) + return counter, running_hash +