Skip to content

Commit

Permalink
chore(acir)!: Move is_recursive flag to be part of the circuit defi…
Browse files Browse the repository at this point in the history
…nition (#4221)

Resolves #4222

Currently in order to specify whether we want to use a prover that
produces SNARK recursion friendly proofs, we must pass a flag from the
tooling infrastructure. This PR moves it be part of the circuit
definition itself.

The flag now lives on the Builder and is set when we call
`create_circuit` in the acir format. The proof produced when this flag
is true should be friendly for recursive verification inside of another
SNARK. For example, a recursive friendly proof may use Blake3Pedersen
for hashing in its transcript, while we still want a prove that uses
Keccak for its transcript in order to be able to verify SNARKs on
Ethereum.

However, a verifier does not need a full circuit description and should
be able to verify a proof with just the verification key and the proof.
An `is_recursive_circuit` field was thus added to the verification key
as well so that we can specify the accurate verifier to use for a given
proof without the full circuit description.

---------

Signed-off-by: kevaundray <kevtheappdev@gmail.com>
Co-authored-by: ledwards2225 <98505400+ledwards2225@users.noreply.github.com>
Co-authored-by: kevaundray <kevtheappdev@gmail.com>
  • Loading branch information
3 people authored and AztecBot committed Feb 1, 2024
1 parent 9944bb1 commit 9a70040
Show file tree
Hide file tree
Showing 27 changed files with 134 additions and 146 deletions.
4 changes: 4 additions & 0 deletions acvm-repo/acir/codegen/acir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1057,6 +1057,7 @@ namespace Circuit {
Circuit::PublicInputs public_parameters;
Circuit::PublicInputs return_values;
std::vector<std::tuple<Circuit::OpcodeLocation, std::string>> assert_messages;
bool recursive;

friend bool operator==(const Circuit&, const Circuit&);
std::vector<uint8_t> bincodeSerialize() const;
Expand Down Expand Up @@ -4889,6 +4890,7 @@ namespace Circuit {
if (!(lhs.public_parameters == rhs.public_parameters)) { return false; }
if (!(lhs.return_values == rhs.return_values)) { return false; }
if (!(lhs.assert_messages == rhs.assert_messages)) { return false; }
if (!(lhs.recursive == rhs.recursive)) { return false; }
return true;
}

Expand Down Expand Up @@ -4919,6 +4921,7 @@ void serde::Serializable<Circuit::Circuit>::serialize(const Circuit::Circuit &ob
serde::Serializable<decltype(obj.public_parameters)>::serialize(obj.public_parameters, serializer);
serde::Serializable<decltype(obj.return_values)>::serialize(obj.return_values, serializer);
serde::Serializable<decltype(obj.assert_messages)>::serialize(obj.assert_messages, serializer);
serde::Serializable<decltype(obj.recursive)>::serialize(obj.recursive, serializer);
serializer.decrease_container_depth();
}

Expand All @@ -4933,6 +4936,7 @@ Circuit::Circuit serde::Deserializable<Circuit::Circuit>::deserialize(Deserializ
obj.public_parameters = serde::Deserializable<decltype(obj.public_parameters)>::deserialize(deserializer);
obj.return_values = serde::Deserializable<decltype(obj.return_values)>::deserialize(deserializer);
obj.assert_messages = serde::Deserializable<decltype(obj.assert_messages)>::deserialize(deserializer);
obj.recursive = serde::Deserializable<decltype(obj.recursive)>::deserialize(deserializer);
deserializer.decrease_container_depth();
return obj;
}
Expand Down
7 changes: 7 additions & 0 deletions acvm-repo/acir/src/circuit/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ pub struct Circuit {
// c++ code at the moment when it is, due to OpcodeLocation needing a comparison
// implementation which is never generated.
pub assert_messages: Vec<(OpcodeLocation, String)>,

/// States whether the backend should use a SNARK recursion friendly prover.
/// If implemented by a backend, this means that proofs generated with this circuit
/// will be friendly for recursively verifying inside of another SNARK.
pub recursive: bool,
}

impl Circuit {
Expand Down Expand Up @@ -318,6 +323,7 @@ mod tests {
public_parameters: PublicInputs(BTreeSet::from_iter(vec![Witness(2), Witness(12)])),
return_values: PublicInputs(BTreeSet::from_iter(vec![Witness(4), Witness(12)])),
assert_messages: Default::default(),
recursive: false,
};

fn read_write(circuit: Circuit) -> (Circuit, Circuit) {
Expand Down Expand Up @@ -348,6 +354,7 @@ mod tests {
public_parameters: PublicInputs(BTreeSet::from_iter(vec![Witness(2)])),
return_values: PublicInputs(BTreeSet::from_iter(vec![Witness(2)])),
assert_messages: Default::default(),
recursive: false,
};

let json = serde_json::to_string_pretty(&circuit).unwrap();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,11 @@ test_cases.forEach((testInfo) => {

// JS Proving

const proofWithPublicInputs = await program.generateFinalProof(inputs);
const proofWithPublicInputs = await program.generateProof(inputs);

// JS verification

const verified = await program.verifyFinalProof(proofWithPublicInputs);
const verified = await program.verifyProof(proofWithPublicInputs);
expect(verified, 'Proof fails verification in JS').to.be.true;
});

Expand Down
14 changes: 7 additions & 7 deletions compiler/integration-tests/test/browser/recursion.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ await newABICoder();
await initACVM();

const base_relative_path = '../../../../..';
const circuit_main = 'test_programs/execution_success/assert_statement';
const circuit_main = 'test_programs/execution_success/assert_statement_recursive';
const circuit_recursion = 'compiler/integration-tests/circuits/recursion';

async function getCircuit(projectPath: string) {
Expand Down Expand Up @@ -48,15 +48,15 @@ describe('It compiles noir program code, receiving circuit bytes and abi object.

const { witness: main_witnessUint8Array } = await new Noir(main_program).execute(main_inputs);

const main_proof = await main_backend.generateIntermediateProof(main_witnessUint8Array);
const main_verification = await main_backend.verifyIntermediateProof(main_proof);
const main_proof = await main_backend.generateProof(main_witnessUint8Array);
const main_verification = await main_backend.verifyProof(main_proof);

logger.debug('main_verification', main_verification);

expect(main_verification).to.be.true;

const numPublicInputs = 1;
const { proofAsFields, vkAsFields, vkHash } = await main_backend.generateIntermediateProofArtifacts(
const { proofAsFields, vkAsFields, vkHash } = await main_backend.generateRecursiveProofArtifacts(
main_proof,
numPublicInputs,
);
Expand All @@ -76,20 +76,20 @@ describe('It compiles noir program code, receiving circuit bytes and abi object.

const { witness: recursion_witnessUint8Array } = await new Noir(recursion_program).execute(recursion_inputs);

const recursion_proof = await recursion_backend.generateFinalProof(recursion_witnessUint8Array);
const recursion_proof = await recursion_backend.generateProof(recursion_witnessUint8Array);

// Causes an "unreachable" error.
// Due to the fact that it's a non-recursive proof?
//
// const recursion_numPublicInputs = 1;
// const { proofAsFields: recursion_proofAsFields } = await recursion_backend.generateIntermediateProofArtifacts(
// const { proofAsFields: recursion_proofAsFields } = await recursion_backend.generateRecursiveProofArtifacts(
// recursion_proof,
// recursion_numPublicInputs,
// );
//
// logger.debug('recursion_proofAsFields', recursion_proofAsFields);

const recursion_verification = await recursion_backend.verifyFinalProof(recursion_proof);
const recursion_verification = await recursion_backend.verifyProof(recursion_proof);

logger.debug('recursion_verification', recursion_verification);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ it(`smart contract can verify a recursive proof`, async () => {
const fm = createFileManager(basePath);
const innerCompilationResult = await compile(
fm,
join(basePath, './test_programs/execution_success/assert_statement'),
join(basePath, './test_programs/execution_success/assert_statement_recursive'),
);
if (!('program' in innerCompilationResult)) {
throw new Error('Compilation failed');
Expand All @@ -38,17 +38,17 @@ it(`smart contract can verify a recursive proof`, async () => {
const inner = new Noir(innerProgram);

const inner_prover_toml = readFileSync(
join(basePath, `./test_programs/execution_success/assert_statement/Prover.toml`),
join(basePath, `./test_programs/execution_success/assert_statement_recursive/Prover.toml`),
).toString();

const inner_inputs = toml.parse(inner_prover_toml);

const { witness: main_witness } = await inner.execute(inner_inputs);
const intermediate_proof = await inner_backend.generateIntermediateProof(main_witness);
const intermediate_proof = await inner_backend.generateProof(main_witness);

expect(await inner_backend.verifyIntermediateProof(intermediate_proof)).to.be.true;
expect(await inner_backend.verifyProof(intermediate_proof)).to.be.true;

const { proofAsFields, vkAsFields, vkHash } = await inner_backend.generateIntermediateProofArtifacts(
const { proofAsFields, vkAsFields, vkHash } = await inner_backend.generateRecursiveProofArtifacts(
intermediate_proof,
1, // 1 public input
);
Expand All @@ -65,8 +65,8 @@ it(`smart contract can verify a recursive proof`, async () => {
key_hash: vkHash,
};

const recursion_proof = await recursion.generateFinalProof(recursion_inputs);
expect(await recursion.verifyFinalProof(recursion_proof)).to.be.true;
const recursion_proof = await recursion.generateProof(recursion_inputs);
expect(await recursion.verifyProof(recursion_proof)).to.be.true;

// Smart contract verification

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,11 @@ test_cases.forEach((testInfo) => {
const prover_toml = readFileSync(resolve(`${base_relative_path}/${test_case}/Prover.toml`)).toString();
const inputs = toml.parse(prover_toml);

const proofData = await program.generateFinalProof(inputs);
const proofData = await program.generateProof(inputs);

// JS verification

const verified = await program.verifyFinalProof(proofData);
const verified = await program.verifyProof(proofData);
expect(verified, 'Proof fails verification in JS').to.be.true;

// Smart contract verification
Expand Down
2 changes: 2 additions & 0 deletions compiler/noirc_evaluator/src/ssa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ pub fn create_circuit(
enable_brillig_logging: bool,
) -> Result<(Circuit, DebugInfo, Vec<Witness>, Vec<Witness>, Vec<SsaReport>), RuntimeError> {
let func_sig = program.main_function_signature.clone();
let recursive = program.recursive;
let mut generated_acir =
optimize_into_acir(program, enable_ssa_logging, enable_brillig_logging)?;
let opcodes = generated_acir.take_opcodes();
Expand All @@ -111,6 +112,7 @@ pub fn create_circuit(
public_parameters,
return_values,
assert_messages: assert_messages.into_iter().collect(),
recursive,
};

// This converts each im::Vector in the BTreeMap to a Vec
Expand Down
2 changes: 2 additions & 0 deletions compiler/noirc_frontend/src/ast/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ pub enum FunctionKind {
Builtin,
Normal,
Oracle,
Recursive,
}

impl NoirFunction {
Expand Down Expand Up @@ -106,6 +107,7 @@ impl From<FunctionDefinition> for NoirFunction {
Some(FunctionAttribute::Foreign(_)) => FunctionKind::LowLevel,
Some(FunctionAttribute::Test { .. }) => FunctionKind::Normal,
Some(FunctionAttribute::Oracle(_)) => FunctionKind::Oracle,
Some(FunctionAttribute::Recursive) => FunctionKind::Recursive,
None => FunctionKind::Normal,
};

Expand Down
14 changes: 14 additions & 0 deletions compiler/noirc_frontend/src/hir/resolution/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ pub enum ResolverError {
InvalidTypeForEntryPoint { span: Span },
#[error("Nested slices are not supported")]
NestedSlices { span: Span },
#[error("#[recursive] attribute is only allowed on entry points to a program")]
MisplacedRecursiveAttribute { ident: Ident },
#[error("Usage of the `#[foreign]` or `#[builtin]` function attributes are not allowed outside of the Noir standard library")]
LowLevelFunctionOutsideOfStdlib { ident: Ident },
}
Expand Down Expand Up @@ -313,6 +315,18 @@ impl From<ResolverError> for Diagnostic {
"Try to use a constant sized array instead".into(),
span,
),
ResolverError::MisplacedRecursiveAttribute { ident } => {
let name = &ident.0.contents;

let mut diag = Diagnostic::simple_error(
format!("misplaced #[recursive] attribute on function {name} rather than the main function"),
"misplaced #[recursive] attribute".to_string(),
ident.0.span(),
);

diag.add_note("The `#[recursive]` attribute specifies to the backend whether it should use a prover which generates proofs that are friendly for recursive verification in another circuit".to_owned());
diag
}
ResolverError::LowLevelFunctionOutsideOfStdlib { ident } => Diagnostic::simple_error(
"Definition of low-level function outside of standard library".into(),
"Usage of the `#[foreign]` or `#[builtin]` function attributes are not allowed outside of the Noir standard library".into(),
Expand Down
8 changes: 7 additions & 1 deletion compiler/noirc_frontend/src/hir/resolution/resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ impl<'a> Resolver<'a> {
FunctionKind::Builtin | FunctionKind::LowLevel | FunctionKind::Oracle => {
HirFunction::empty()
}
FunctionKind::Normal => {
FunctionKind::Normal | FunctionKind::Recursive => {
let expr_id = self.intern_block(func.def.body);
self.interner.push_expr_location(expr_id, func.def.span, self.file);
HirFunction::unchecked_from_expr(expr_id)
Expand Down Expand Up @@ -923,6 +923,12 @@ impl<'a> Resolver<'a> {
{
self.push_err(ResolverError::NecessaryPub { ident: func.name_ident().clone() });
}
// '#[recursive]' attribute is only allowed for entry point functions
if !self.is_entry_point_function(func) && func.kind == FunctionKind::Recursive {
self.push_err(ResolverError::MisplacedRecursiveAttribute {
ident: func.name_ident().clone(),
});
}

if !self.distinct_allowed(func)
&& func.def.return_distinctness != Distinctness::DuplicationAllowed
Expand Down
2 changes: 1 addition & 1 deletion compiler/noirc_frontend/src/hir_def/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ impl FuncMeta {
pub fn can_ignore_return_type(&self) -> bool {
match self.kind {
FunctionKind::LowLevel | FunctionKind::Builtin | FunctionKind::Oracle => true,
FunctionKind::Normal => false,
FunctionKind::Normal | FunctionKind::Recursive => false,
}
}

Expand Down
4 changes: 4 additions & 0 deletions compiler/noirc_frontend/src/lexer/token.rs
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,7 @@ impl Attribute {
Attribute::Function(FunctionAttribute::Oracle(name.to_string()))
}
["test"] => Attribute::Function(FunctionAttribute::Test(TestScope::None)),
["recursive"] => Attribute::Function(FunctionAttribute::Recursive),
["test", name] => {
validate(name)?;
let malformed_scope =
Expand Down Expand Up @@ -541,6 +542,7 @@ pub enum FunctionAttribute {
Builtin(String),
Oracle(String),
Test(TestScope),
Recursive,
}

impl FunctionAttribute {
Expand Down Expand Up @@ -578,6 +580,7 @@ impl fmt::Display for FunctionAttribute {
FunctionAttribute::Foreign(ref k) => write!(f, "#[foreign({k})]"),
FunctionAttribute::Builtin(ref k) => write!(f, "#[builtin({k})]"),
FunctionAttribute::Oracle(ref k) => write!(f, "#[oracle({k})]"),
FunctionAttribute::Recursive => write!(f, "#[recursive]"),
}
}
}
Expand Down Expand Up @@ -621,6 +624,7 @@ impl AsRef<str> for FunctionAttribute {
FunctionAttribute::Builtin(string) => string,
FunctionAttribute::Oracle(string) => string,
FunctionAttribute::Test { .. } => "",
FunctionAttribute::Recursive => "",
}
}
}
Expand Down
4 changes: 4 additions & 0 deletions compiler/noirc_frontend/src/monomorphization/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,8 @@ pub struct Program {
pub return_distinctness: Distinctness,
pub return_location: Option<Location>,
pub return_visibility: Visibility,
/// Indicates to a backend whether a SNARK-friendly prover should be used.
pub recursive: bool,
}

impl Program {
Expand All @@ -255,13 +257,15 @@ impl Program {
return_distinctness: Distinctness,
return_location: Option<Location>,
return_visibility: Visibility,
recursive: bool,
) -> Program {
Program {
functions,
main_function_signature,
return_distinctness,
return_location,
return_visibility,
recursive,
}
}

Expand Down
4 changes: 4 additions & 0 deletions compiler/noirc_frontend/src/monomorphization/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ pub fn monomorphize(main: node_interner::FuncId, interner: &NodeInterner) -> Pro
meta.return_distinctness,
monomorphizer.return_location,
meta.return_visibility,
meta.kind == FunctionKind::Recursive,
)
}

Expand Down Expand Up @@ -195,6 +196,9 @@ impl<'interner> Monomorphizer<'interner> {
_ => unreachable!("Oracle function must have an oracle attribute"),
}
}
FunctionKind::Recursive => {
unreachable!("Only main can be specified as recursive, which should already be checked");
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[package]
name = "assert_statement_recursive"
type = "bin"
authors = [""]
compiler_version = ">=0.23.0"

[dependencies]
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
x = "3"
y = "3"
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// Tests a very simple program.
//
// The features being tested is assertion
// This is the same as the `assert_statement` test except we specify
// that the backend should use a prover which will construct proofs
// friendly to recursive verification in another SNARK.
#[recursive]
fn main(x: Field, y: pub Field) {
assert(x == y, "x and y are not equal");
assert_eq(x, y, "x and y are not equal");
}
Loading

0 comments on commit 9a70040

Please sign in to comment.