Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions packages/compiler/src/bin/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,14 @@ enum Commands {
Decomposed {
#[arg(short, long)]
decomposed_regex_path: String,
#[arg(short, long)]
#[arg(long)]
halo2_dir_path: Option<String>,
#[arg(short, long)]
circom_file_path: Option<String>,
#[arg(short, long)]
template_name: Option<String>,
#[arg(long)]
noir_file_path: Option<String>,
#[arg(short, long)]
gen_substrs: Option<bool>,
},
Expand All @@ -74,12 +76,14 @@ enum Commands {
raw_regex: String,
#[arg(short, long)]
substrs_json_path: Option<String>,
#[arg(short, long)]
#[arg(long)]
halo2_dir_path: Option<String>,
#[arg(short, long)]
circom_file_path: Option<String>,
#[arg(short, long)]
template_name: Option<String>,
#[arg(long)]
noir_file_path: Option<String>,
#[arg(short, long)]
gen_substrs: Option<bool>,
},
Expand All @@ -99,6 +103,7 @@ fn process_decomposed(cli: Cli) {
halo2_dir_path,
circom_file_path,
template_name,
noir_file_path,
gen_substrs,
} = cli.command
{
Expand All @@ -107,6 +112,7 @@ fn process_decomposed(cli: Cli) {
halo2_dir_path.as_deref(),
circom_file_path.as_deref(),
template_name.as_deref(),
noir_file_path.as_deref(),
gen_substrs,
) {
eprintln!("Error: {}", e);
Expand All @@ -122,6 +128,7 @@ fn process_raw(cli: Cli) {
halo2_dir_path,
circom_file_path,
template_name,
noir_file_path,
gen_substrs,
} = cli.command
{
Expand All @@ -131,6 +138,7 @@ fn process_raw(cli: Cli) {
halo2_dir_path.as_deref(),
circom_file_path.as_deref(),
template_name.as_deref(),
noir_file_path.as_deref(),
gen_substrs,
) {
eprintln!("Error: {}", e);
Expand Down
11 changes: 11 additions & 0 deletions packages/compiler/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
mod circom;
mod errors;
mod halo2;
mod noir;
mod regex;
mod structs;
mod wasm;
Expand All @@ -9,6 +10,7 @@ use circom::gen_circom_template;
use errors::CompilerError;
use halo2::gen_halo2_tables;
use itertools::Itertools;
use noir::gen_noir_fn;
use regex::{create_regex_and_dfa_from_str_and_defs, get_regex_and_dfa};
use std::{fs::File, path::PathBuf};
use structs::{DecomposedRegexConfig, RegexAndDFA, SubstringDefinitionsJson};
Expand Down Expand Up @@ -55,6 +57,7 @@ fn generate_outputs(
halo2_dir_path: Option<&str>,
circom_file_path: Option<&str>,
circom_template_name: Option<&str>,
noir_file_path: Option<&str>,
num_public_parts: usize,
gen_substrs: bool,
) -> Result<(), CompilerError> {
Expand Down Expand Up @@ -86,6 +89,10 @@ fn generate_outputs(
)?;
}

if let Some(noir_file_path) = noir_file_path {
gen_noir_fn(regex_and_dfa, &PathBuf::from(noir_file_path))?;
}

Ok(())
}

Expand All @@ -107,6 +114,7 @@ pub fn gen_from_decomposed(
halo2_dir_path: Option<&str>,
circom_file_path: Option<&str>,
circom_template_name: Option<&str>,
noir_file_path: Option<&str>,
gen_substrs: Option<bool>,
) -> Result<(), CompilerError> {
let mut decomposed_regex_config: DecomposedRegexConfig =
Expand All @@ -126,6 +134,7 @@ pub fn gen_from_decomposed(
halo2_dir_path,
circom_file_path,
circom_template_name,
noir_file_path,
num_public_parts,
gen_substrs,
)?;
Expand Down Expand Up @@ -153,6 +162,7 @@ pub fn gen_from_raw(
halo2_dir_path: Option<&str>,
circom_file_path: Option<&str>,
template_name: Option<&str>,
noir_file_path: Option<&str>,
gen_substrs: Option<bool>,
) -> Result<(), CompilerError> {
let substrs_defs_json = load_substring_definitions_json(substrs_json_path)?;
Expand All @@ -167,6 +177,7 @@ pub fn gen_from_raw(
halo2_dir_path,
circom_file_path,
template_name,
noir_file_path,
num_public_parts,
gen_substrs,
)?;
Expand Down
120 changes: 120 additions & 0 deletions packages/compiler/src/noir.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
use std::{collections::HashSet, fs::File, io::Write, iter::FromIterator, path::Path};

use itertools::Itertools;

use crate::structs::RegexAndDFA;

const ACCEPT_STATE_ID: &str = "accept";

/// Generates the Noir source for `regex_and_dfa` and writes it to `path`.
///
/// The source text is rendered in full before the output file is opened, so a
/// panic while rendering never leaves a freshly created file behind.
///
/// # Errors
///
/// Returns the underlying I/O error if the file cannot be created, written,
/// or flushed.
pub fn gen_noir_fn(regex_and_dfa: &RegexAndDFA, path: &Path) -> Result<(), std::io::Error> {
    let source = to_noir_fn(regex_and_dfa);
    File::create(path).and_then(|mut out| {
        out.write_all(source.as_bytes())?;
        out.flush()
    })
}

/// Renders `regex_and_dfa` as Noir source text.
///
/// The output consists of a `regex_match` function that walks the input bytes
/// through a flat transition table, followed by the `make_lookup_table`
/// function that fills that table. The table is indexed by
/// `state * 256 + byte`; every slot starts at `0` and only the transitions
/// listed in the DFA (plus the accept-state fallbacks below) are assigned.
///
/// CR and LF characters in the regex pattern are written as `\r` / `\n` in the
/// generated `// regex: ...` comment so that the comment always stays on a
/// single line.
///
/// # Panics
///
/// Panics if no state in the DFA has the accept state type.
fn to_noir_fn(regex_and_dfa: &RegexAndDFA) -> String {
    let accept_state_ids = regex_and_dfa
        .dfa
        .states
        .iter()
        .filter(|s| s.state_type == ACCEPT_STATE_ID)
        .map(|s| s.state_id)
        .collect_vec();
    assert!(!accept_state_ids.is_empty(), "no accept states");

    const BYTE_SIZE: u32 = 256; // u8 size

    // curr_state + char_code -> next_state
    let mut rows: Vec<(usize, u8, usize)> = vec![];

    for state in regex_and_dfa.dfa.states.iter() {
        for (&tran_next_state_id, tran) in &state.transitions {
            for &char_code in tran {
                rows.push((state.state_id, char_code, tran_next_state_id));
            }
        }
        if state.state_type == ACCEPT_STATE_ID {
            // Char codes that already have an explicit transition out of this state.
            let existing_char_codes: HashSet<u8> = HashSet::from_iter(
                state
                    .transitions
                    .iter()
                    .flat_map(|(_, tran)| tran.iter().copied()),
            );
            let next_state_id = if regex_and_dfa.has_end_anchor {
                0 // reset if we encounter another char after we reach the end anchor
            } else {
                state.state_id // no end anchor? Just stay in the same state
            };
            // Walking the byte range in ascending order keeps the output deterministic.
            rows.extend(
                (0..=u8::MAX)
                    .filter(|char_code| !existing_char_codes.contains(char_code))
                    .map(|char_code| (state.state_id, char_code, next_state_id)),
            );
        }
    }

    let mut lookup_table_body = String::new();
    for (curr_state_id, char_code, next_state_id) in rows {
        lookup_table_body +=
            &format!("table[{curr_state_id} * {BYTE_SIZE} + {char_code}] = {next_state_id};\n",);
    }

    lookup_table_body = indent(&lookup_table_body);
    let table_size = BYTE_SIZE as usize * regex_and_dfa.dfa.states.len();
    let lookup_table = format!(
        r#"
comptime fn make_lookup_table() -> [Field; {table_size}] {{
let mut table = [0; {table_size}];
{lookup_table_body}

table
}}
"#
    );

    let final_states_condition_body = accept_state_ids
        .iter()
        .map(|id| format!("(s == {id})"))
        .collect_vec()
        .join(" | ");

    // The pattern is emitted inside a single-line comment: a raw CR/LF would end
    // the comment early and leak the rest of the pattern into the generated code.
    // `to_string()` only relies on the `Display` impl that `format!` already needs.
    let regex_pattern = regex_and_dfa
        .regex_pattern
        .to_string()
        .replace('\r', "\\r")
        .replace('\n', "\\n");

    let fn_body = format!(
        r#"
global table = comptime {{ make_lookup_table() }};
pub fn regex_match<let N: u32>(input: [u8; N]) {{
// regex: {regex_pattern}
let mut s = 0;
for i in 0..input.len() {{
s = table[s * {BYTE_SIZE} + input[i] as Field];
}}
assert({final_states_condition_body}, f"no match: {{s}}");
}}
"#
    );
    format!(
        r#"
{fn_body}
{lookup_table}
"#
    )
    .trim()
    .to_owned()
}

/// Prefixes every non-blank line of `s` with a single space.
///
/// Lines are delimited by `\n`. Lines that are empty or contain only
/// whitespace are copied through unchanged, and the line breaks of the input
/// are preserved exactly.
fn indent(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for (idx, line) in s.split("\n").enumerate() {
        // Re-insert the separator that `split` consumed between lines.
        if idx > 0 {
            out.push('\n');
        }
        if !line.trim().is_empty() {
            out.push_str(" ");
        }
        out.push_str(line);
    }
    out
}