Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Unroll loops iteratively #4779

Merged
merged 7 commits into from
Apr 11, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion compiler/noirc_evaluator/src/ssa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ pub(crate) fn optimize_into_acir(
.run_pass(Ssa::mem2reg, "After Mem2Reg:")
.run_pass(Ssa::as_slice_optimization, "After `as_slice` optimization")
.try_run_pass(Ssa::evaluate_assert_constant, "After Assert Constant:")?
.try_run_pass(Ssa::unroll_loops, "After Unrolling:")?
.try_run_pass(unroll_all_acir_loops, "After Unrolling:")?
.run_pass(Ssa::simplify_cfg, "After Simplifying:")
.run_pass(Ssa::flatten_cfg, "After Flattening:")
.run_pass(Ssa::remove_bit_shifts, "After Removing Bit Shifts:")
Expand All @@ -75,6 +75,31 @@ pub(crate) fn optimize_into_acir(
time("SSA to ACIR", print_timings, || ssa.into_acir(&brillig, abi_distinctness))
}

/// Loop unrolling can return errors, since ACIR functions need to be fully unrolled.
/// This meta-pass will keep trying to unroll loops and simplifying the SSA until no more errors are found.
fn unroll_all_acir_loops(mut ssa: Ssa) -> Result<Ssa, RuntimeError> {
sirasistant marked this conversation as resolved.
Show resolved Hide resolved
// Try to unroll loops first:
let mut unroll_errors;
(ssa, unroll_errors) = ssa.try_to_unroll_loops();

// Keep unrolling until no more errors are found
while !unroll_errors.is_empty() {
let prev_unroll_err_count = unroll_errors.len();

// Simplify the SSA before retrying
ssa = ssa.simplify_cfg();
sirasistant marked this conversation as resolved.
Show resolved Hide resolved
ssa = ssa.mem2reg();

// Unroll again
(ssa, unroll_errors) = ssa.try_to_unroll_loops();
// If we didn't manage to unroll any more loops, exit
if unroll_errors.len() == prev_unroll_err_count {
sirasistant marked this conversation as resolved.
Show resolved Hide resolved
return Err(unroll_errors[0].clone());
sirasistant marked this conversation as resolved.
Show resolved Hide resolved
}
}
Ok(ssa)
}

// Helper to time SSA passes
fn time<T>(name: &str, print_timings: bool, f: impl FnOnce() -> T) -> T {
let start_time = chrono::Utc::now().time();
Expand Down
36 changes: 16 additions & 20 deletions compiler/noirc_evaluator/src/ssa/opt/unrolling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,12 @@ use crate::{
use fxhash::FxHashMap as HashMap;

impl Ssa {
/// Unroll all loops in each SSA function.
/// Tries to unroll all loops in each SSA function.
/// If any loop cannot be unrolled, it is left as-is or in a partially unrolled state.
/// Returns the ssa along with all unrolling errors encountered
#[tracing::instrument(level = "trace", skip(self))]
pub(crate) fn unroll_loops(mut self) -> Result<Ssa, RuntimeError> {
pub(crate) fn try_to_unroll_loops(mut self) -> (Ssa, Vec<RuntimeError>) {
let mut errors = vec![];
for function in self.functions.values_mut() {
// Loop unrolling in brillig can lead to a code explosion currently. This can
// also be true for ACIR, but we have no alternative to unrolling in ACIR.
Expand All @@ -46,12 +48,9 @@ impl Ssa {
continue;
}

// This check is always true with the addition of the above guard, but I'm
// keeping it in case the guard on brillig functions is ever removed.
let abort_on_error = matches!(function.runtime(), RuntimeType::Acir(_));
find_all_loops(function).unroll_each_loop(function, abort_on_error)?;
errors.extend(find_all_loops(function).unroll_each_loop(function));
}
Ok(self)
(self, errors)
}
}

Expand Down Expand Up @@ -115,34 +114,29 @@ fn find_all_loops(function: &Function) -> Loops {
impl Loops {
/// Unroll all loops within a given function.
/// Any loops which fail to be unrolled (due to using non-constant indices) will be unmodified.
fn unroll_each_loop(
mut self,
function: &mut Function,
abort_on_error: bool,
) -> Result<(), RuntimeError> {
fn unroll_each_loop(mut self, function: &mut Function) -> Vec<RuntimeError> {
let mut unroll_errors = vec![];
while let Some(next_loop) = self.yet_to_unroll.pop() {
// If we've previously modified a block in this loop we need to refresh the context.
// This happens any time we have nested loops.
if next_loop.blocks.iter().any(|block| self.modified_blocks.contains(block)) {
let mut new_context = find_all_loops(function);
new_context.failed_to_unroll = self.failed_to_unroll;
return new_context.unroll_each_loop(function, abort_on_error);
return new_context.unroll_each_loop(function);
}

// Don't try to unroll the loop again if it is known to fail
if !self.failed_to_unroll.contains(&next_loop.header) {
match unroll_loop(function, &self.cfg, &next_loop) {
Ok(_) => self.modified_blocks.extend(next_loop.blocks),
Err(call_stack) if abort_on_error => {
return Err(RuntimeError::UnknownLoopBound { call_stack });
}
Err(_) => {
Err(call_stack) => {
self.failed_to_unroll.insert(next_loop.header);
unroll_errors.push(RuntimeError::UnknownLoopBound { call_stack });
}
}
}
}
Ok(())
unroll_errors
}
}

Expand Down Expand Up @@ -585,7 +579,8 @@ mod tests {
// }
// The final block count is not 1 because unrolling creates some unnecessary jmps.
// If a simplify cfg pass is ran afterward, the expected block count will be 1.
let ssa = ssa.unroll_loops().expect("All loops should be unrolled");
let (ssa, errors) = ssa.try_to_unroll_loops();
assert_eq!(errors.len(), 0, "All loops should be unrolled");
assert_eq!(ssa.main().reachable_blocks().len(), 5);
}

Expand Down Expand Up @@ -634,6 +629,7 @@ mod tests {
assert_eq!(ssa.main().reachable_blocks().len(), 4);

// Expected that we failed to unroll the loop
assert!(ssa.unroll_loops().is_err());
let (_, errors) = ssa.try_to_unroll_loops();
assert_eq!(errors.len(), 1, "Expected to fail to unroll loop");
}
}
6 changes: 6 additions & 0 deletions test_programs/execution_success/slice_loop/Nargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
[package]
name = "slice_loop"
type = "bin"
authors = [""]

[dependencies]
11 changes: 11 additions & 0 deletions test_programs/execution_success/slice_loop/Prover.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
[[points]]
x = "1"
y = "2"

[[points]]
x = "3"
y = "4"

[[points]]
x = "5"
y = "6"
26 changes: 26 additions & 0 deletions test_programs/execution_success/slice_loop/src/main.nr
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
struct Point {
x: Field,
y: Field,
}

impl Point {
fn serialize(self) -> [Field; 2] {
[self.x, self.y]
}
}

fn sum(values: [Field]) -> Field {
let mut sum = 0;
for value in values {
sum = sum + value;
}
sum
}

fn main(points: [Point; 3]) {
let mut serialized_points = &[];
for point in points {
serialized_points = serialized_points.append(point.serialize().as_slice());
}
assert_eq(sum(serialized_points), 21);
}
Loading