Skip to content

Commit

Permalink
hydra auto-gen cairo1 unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
feltroidprime committed Jun 28, 2024
1 parent 27efa87 commit 05132a0
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 13 deletions.
4 changes: 2 additions & 2 deletions hydra/extension_field_modulo_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,7 @@ def summarize(self):
def compile_circuit(self, function_name: str = None):
self.values_segment = self.values_segment.non_interactive_transform()
if self.compilation_mode == 0:
return self.compile_circuit_cairo_zero(function_name)
return self.compile_circuit_cairo_zero(function_name), None
elif self.compilation_mode == 1:
return self.compile_circuit_cairo_1(function_name)

Expand Down Expand Up @@ -756,7 +756,7 @@ def write_stack(
code += f"let res=array![{','.join(['o'+str(i) for i, _ in enumerate(outputs_refs)])}];\n"
code += "return res;\n"
code += "}\n"
return code
return code, function_name


if __name__ == "__main__":
Expand Down
79 changes: 68 additions & 11 deletions hydra/precompiled_circuits/all_circuits.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
N_LIMBS,
)
from hydra.hints import neg_3
from hydra.hints.io import int_array_to_u384_array
from random import seed, randint
from enum import Enum
from tools.gnark_cli import GnarkCLI
Expand Down Expand Up @@ -86,7 +87,8 @@ def __init__(
self.generic_over_curve = False
self.compilation_mode = compilation_mode
if auto_run:
self.circuit: ModuloCircuit = self._run_circuit_inner(self.build_input())
self.input = self.build_input()
self.circuit: ModuloCircuit = self._run_circuit_inner(self.input)

@abstractmethod
def build_input(self) -> list[PyFelt]:
Expand Down Expand Up @@ -777,14 +779,28 @@ def compilation_mode_to_file_header(mode: int) -> str:
use core::circuit::{
RangeCheck96, AddMod, MulMod, u96, CircuitElement, CircuitInput, circuit_add, circuit_sub,
circuit_mul, circuit_inverse, EvalCircuitResult, EvalCircuitTrait, u384, CircuitOutputsTrait,
CircuitModulus, FillInputResultTrait, CircuitInputs, FillInputResult, CircuitDefinition,
CircuitModulus, AddInputResultTrait, CircuitInputs, CircuitDefinition,
CircuitData, CircuitInputAccumulator
};
use garaga::definitions::{get_a, get_b, get_p, get_g, get_min_one, G1Point};
use core::option::Option;
"""


def cairo1_tests_header() -> str:
return """
#[cfg(test)]
mod tests {
use core::traits::TryInto;
use core::circuit::{
RangeCheck96, AddMod, MulMod, u96, CircuitElement, CircuitInput, circuit_add, circuit_sub,
circuit_mul, circuit_inverse, EvalCircuitResult, EvalCircuitTrait, u384,
CircuitOutputsTrait, CircuitModulus, AddInputResultTrait, CircuitInputs,
};
"""


def main(
PRECOMPILED_CIRCUITS_DIR: str = "src/fustat/precompiled_circuits/",
CIRCUITS_TO_COMPILE: dict = ALL_FUSTAT_CIRCUITS,
Expand All @@ -802,6 +818,8 @@ def to_snake_case(s: str) -> str:
filenames_used = set([v["filename"] for v in CIRCUITS_TO_COMPILE.values()])
codes = {filename: set() for filename in filenames_used}
selector_functions = {filename: set() for filename in filenames_used}
cairo1_tests_functions = {filename: set() for filename in filenames_used}
cairo1_full_function_names = {filename: set() for filename in filenames_used}

files = {
f: open(f"{PRECOMPILED_CIRCUITS_DIR}{f}.cairo", "w") for f in filenames_used
Expand All @@ -826,6 +844,7 @@ def compile_circuit(

circuits: list[BaseModuloCircuit] = []
compiled_circuits: list[str] = []
full_function_names: list[str] = []

if params is None:
circuit_instance = circuit_class(
Expand All @@ -848,13 +867,38 @@ def compile_circuit(
)
function_name += f"_{params[i][param_name]}" if params else ""

compiled_circuit = circuit_instance.circuit.compile_circuit(
function_name=function_name
compiled_circuit, full_function_name = (
circuit_instance.circuit.compile_circuit(function_name=function_name)
)

compiled_circuits.append(compiled_circuit)

return compiled_circuits, selector_function
if compilation_mode == 1:
circuit_input = circuit_instance.input
circuit_output = circuit_instance.circuit.output
full_function_names.append(full_function_name)
cairo1_tests_functions[filename_key].add(
write_cairo1_test(
full_function_name,
circuit_input,
circuit_output,
curve_id.value,
)
)

return compiled_circuits, selector_function, full_function_names

def write_cairo1_test(function_name: str, input: list, output: list, curve_id: int):
code = f"""
#[test]
fn test_{function_name}_{CurveID(curve_id).name}() {{
let input = {int_array_to_u384_array(input)};
let output = {int_array_to_u384_array(output)};
let result = {function_name}(input, {curve_id});
assert_eq!(result, output);
}}
"""
return code

def build_selector_function(
circuit_id: CircuitID, circuit_instance: BaseModuloCircuit, params: list[dict]
Expand Down Expand Up @@ -941,14 +985,16 @@ def build_selector_function(
for curve_id in [CurveID.BN254, CurveID.BLS12_381]:
for circuit_id, circuit_info in CIRCUITS_TO_COMPILE.items():
filename_key = circuit_info["filename"]
compiled_circuits, selectors = compile_circuit(
compiled_circuits, selectors, full_function_names = compile_circuit(
curve_id,
circuit_info["class"],
circuit_id,
circuit_info["params"],
)
codes[filename_key].update(compiled_circuits)
selector_functions[filename_key].update(selectors)
if compilation_mode == 1:
cairo1_full_function_names[filename_key].update(full_function_names)

# Write selector functions and compiled circuit codes to their respective files
print(f"Writing circuits and selectors to .cairo files...")
Expand All @@ -960,6 +1006,16 @@ def build_selector_function(
# Write the compiled circuit codes
for compiled_circuit in codes[filename]:
files[filename].write(compiled_circuit + "\n")

if compilation_mode == 1:
files[filename].write(cairo1_tests_header() + "\n")
files[filename].write(
f"use super::{{{','.join(cairo1_full_function_names[filename])}}};\n"
)
for cairo1_test in cairo1_tests_functions[filename]:
files[filename].write(cairo1_test + "\n")
files[filename].write("}\n")

else:
print(f"Warning: No file associated with filename '{filename}'")

Expand All @@ -969,7 +1025,7 @@ def build_selector_function(

def format_cairo_files_in_parallel(filenames, compilation_mode):
if compilation_mode == 0:
print(f"Formatting .cairo files in parallel...")
print(f"Formatting .cairo zero files in parallel...")
cairo_files = [f"{PRECOMPILED_CIRCUITS_DIR}{f}.cairo" for f in filenames]
with ProcessPoolExecutor() as executor:
futures = [
Expand Down Expand Up @@ -1021,10 +1077,11 @@ def format_cairo_files_in_parallel(filenames, compilation_mode):
"params": None,
"filename": "ec",
},
# CircuitID.EVAL_FUNCTION_CHALLENGE_DUPL: {
# "class": EvalFunctionChallengeDuplCircuit,
# "params": [{"n_points": k} for k in [1, 2, 3]],
# },
CircuitID.EVAL_FUNCTION_CHALLENGE_DUPL: {
"class": EvalFunctionChallengeDuplCircuit,
"params": [{"n_points": k} for k in [1, 2, 3]],
"filename": "ec",
},
CircuitID.FP12_MUL: {
"class": FP12MulCircuit,
"params": None,
Expand Down

0 comments on commit 05132a0

Please sign in to comment.