From 06b15c2cc86411d561ccd77bbbb6fa10ed412221 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Mon, 13 May 2024 22:00:06 +0000 Subject: [PATCH] simplify ParseResults --- controllers/guidance_ctrl/run_g.py | 6 +- controllers/guidance_ctrl/src/earley/bench.rs | 8 +- controllers/guidance_ctrl/src/earley/mod.rs | 2 +- .../guidance_ctrl/src/earley/parser.rs | 85 +++++++++---------- 4 files changed, 47 insertions(+), 54 deletions(-) diff --git a/controllers/guidance_ctrl/run_g.py b/controllers/guidance_ctrl/run_g.py index ba68c8e8..46bf32cc 100644 --- a/controllers/guidance_ctrl/run_g.py +++ b/controllers/guidance_ctrl/run_g.py @@ -92,6 +92,9 @@ def main(): lm += gen("words", regex=r"[A-Z ]+", stop="\n") grm = lm + grm = select(["1", "12", "123"], name="the number") + prompt = "<|user|>\nPick a number:\n<|computer|>\n" + # read current script file # with open(__file__) as f: @@ -101,9 +104,10 @@ def main(): print(len(b64)) mod_id = pyaici.cli.build_rust(".") if "127.0.0.1" in pyaici.rest.base_url: - pyaici.rest.tag_module(mod_id, ["guidance_ctrl-latest"]) + pyaici.rest.tag_module(mod_id, ["guidance_ctrl-latest", "guidance"]) pyaici.rest.log_level = 2 res = pyaici.rest.run_controller( + prompt=prompt, controller=mod_id, controller_arg=json.dumps({"guidance_b64": b64}), temperature=0.0, diff --git a/controllers/guidance_ctrl/src/earley/bench.rs b/controllers/guidance_ctrl/src/earley/bench.rs index 21c31e35..4f3b7eab 100644 --- a/controllers/guidance_ctrl/src/earley/bench.rs +++ b/controllers/guidance_ctrl/src/earley/bench.rs @@ -1,7 +1,7 @@ use aici_abi::toktree; use super::Parser; -use crate::earley::{from_guidance::earley_grm_from_guidance, parser::ParseResult}; +use crate::earley::from_guidance::earley_grm_from_guidance; pub fn earley_test(trie: toktree::TokTrie) { let g_bytes = include_bytes!("../../../aici_abi/grammars/json0.guidance"); @@ -18,15 +18,13 @@ pub fn earley_test(trie: toktree::TokTrie) { let grm = cfg.compile(); let mut parser = Parser::new(grm.clone()); - let mut last_res = ParseResult::Reject; for b in input { - last_res = parser.scan(*b); - if last_res == ParseResult::Reject { + if !parser.scan(*b) { println!("reject"); break; } } - if last_res != ParseResult::Accept { + if !parser.is_accepting() { println!("final non-accept"); } diff --git a/controllers/guidance_ctrl/src/earley/mod.rs b/controllers/guidance_ctrl/src/earley/mod.rs index 6f81e0db..b8a9832d 100644 --- a/controllers/guidance_ctrl/src/earley/mod.rs +++ b/controllers/guidance_ctrl/src/earley/mod.rs @@ -7,7 +7,7 @@ pub use byteset::ByteSet; pub use from_guidance::earley_grm_from_guidance; #[allow(unused_imports)] pub use grammar::{Grammar, ModelVariable}; -pub use parser::{Parser, ParseResult}; +pub use parser::Parser; #[cfg(not(target_arch = "wasm32"))] pub mod bench; diff --git a/controllers/guidance_ctrl/src/earley/parser.rs b/controllers/guidance_ctrl/src/earley/parser.rs index 8397ba84..94c10242 100644 --- a/controllers/guidance_ctrl/src/earley/parser.rs +++ b/controllers/guidance_ctrl/src/earley/parser.rs @@ -76,13 +76,6 @@ pub struct Stats { pub all_items: usize, } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum ParseResult { - Accept, - Reject, - Continue, -} - struct Row { first_item: usize, last_item: usize, @@ -141,7 +134,6 @@ pub struct Parser { rows: Vec, row_infos: Vec, stats: Stats, - is_accepting: bool, last_collapse: usize, token_idx: usize, } @@ -260,7 +252,6 @@ impl Parser { captures: vec![], scratch, stats: Stats::default(), - is_accepting: false, last_collapse: 0, token_idx: 0, }; @@ -273,7 +264,18 @@ impl Parser { } pub fn is_accepting(&self) -> bool { - self.is_accepting + for idx in self.curr_row().item_indices() { + let item = self.scratch.items[idx]; + let rule = item.rule_idx(); + let after_dot = self.grammar.sym_idx_at(rule); + if after_dot == CSymIdx::NULL { + let lhs = self.grammar.sym_idx_of(item.rule_idx()); + if lhs == self.grammar.start() { + return true; + } + } + } + false } fn item_to_string(&self, idx: usize) -> String { @@ -358,6 +360,7 @@ impl Parser { self.assert_definitive(); let mut byte_idx = 1; // row_infos[0] has just the 0 byte let mut tok_idx = 0; + debug!("apply_tokens: {:?}", tokens); for t in tokens { for b in trie.token(*t).iter() { if num_skip > 0 { @@ -366,7 +369,7 @@ impl Parser { } if byte_idx >= self.row_infos.len() { - if self.scan(*b) == ParseResult::Reject { + if !self.scan(*b) { return "parse reject"; } if byte_idx >= self.row_infos.len() { @@ -434,10 +437,10 @@ impl Parser { pub fn force_bytes(&mut self) -> Vec { self.assert_definitive(); + debug!("force_bytes"); let mut bytes = vec![]; while let Some(b) = self.forced_byte() { - let res = self.scan(b); - if res == ParseResult::Reject { + if !self.scan(b) { // shouldn't happen? break; } @@ -466,7 +469,7 @@ impl Parser { } fn forced_byte(&self) -> Option { - if self.is_accepting { + if self.is_accepting() { // we're not forced when in accepting state return None; } @@ -497,7 +500,7 @@ impl Parser { } } - pub fn hide_item(&mut self, sym: CSymIdx, row_idx: usize) -> ParseResult { + pub fn hide_item(&mut self, sym: CSymIdx, row_idx: usize) -> bool { info!("hide_item: {} {}", self.grammar.sym_data(sym).name, row_idx); let row_range = self.rows[row_idx].item_indices(); @@ -531,7 +534,7 @@ impl Parser { } #[inline(always)] - pub fn scan(&mut self, b: u8) -> ParseResult { + pub fn scan(&mut self, b: u8) -> bool { let row_idx = self.rows.len() - 1; let last = self.rows[row_idx].last_item; let mut i = self.rows[row_idx].first_item; @@ -563,12 +566,11 @@ impl Parser { } #[inline(always)] - fn push_row(&mut self, mut agenda_ptr: usize, byte: u8) -> ParseResult { + fn push_row(&mut self, mut agenda_ptr: usize, byte: u8) -> bool { let curr_idx = self.rows.len(); let mut commit_item = Item::NULL; self.stats.rows += 1; - self.is_accepting = false; while agenda_ptr < self.scratch.row_end { let mut item_idx = agenda_ptr; @@ -583,9 +585,7 @@ impl Parser { if after_dot == CSymIdx::NULL { let flags = self.grammar.sym_flags_of(rule); - let lhs = self.grammar.sym_idx_of(item.rule_idx()); - // complete - self.is_accepting = self.is_accepting || lhs == self.grammar.start(); + let lhs = self.grammar.sym_idx_of(rule); if self.scratch.definitive && flags.capture() { let var_name = self @@ -671,31 +671,27 @@ impl Parser { } let row_len = self.scratch.row_len(); - self.stats.all_items += row_len; if row_len == 0 { - assert!(!self.is_accepting); - return ParseResult::Reject; - } - - self.rows.push(Row { - first_item: self.scratch.row_start, - last_item: self.scratch.row_end, - }); + false + } else { + self.stats.all_items += row_len; - if self.scratch.definitive { - self.row_infos.drain((self.rows.len() - 1)..); - self.row_infos.push(RowInfo { - byte, - commit_item, - token_idx: self.token_idx, + self.rows.push(Row { + first_item: self.scratch.row_start, + last_item: self.scratch.row_end, }); - } - if self.is_accepting { - ParseResult::Accept - } else { - ParseResult::Continue + if self.scratch.definitive { + self.row_infos.drain((self.rows.len() - 1)..); + self.row_infos.push(RowInfo { + byte, + commit_item, + token_idx: self.token_idx, + }); + } + + true } } } @@ -755,12 +751,7 @@ impl Recognizer for Parser { } fn try_push_byte(&mut self, byte: u8) -> bool { - let res = self.scan(byte); - if res == ParseResult::Reject { - false - } else { - true - } + self.scan(byte) } }