Skip to content

Commit

Permalink
simplify ParseResults
Browse files Browse the repository at this point in the history
  • Loading branch information
mmoskal committed May 13, 2024
1 parent 88520db commit 06b15c2
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 54 deletions.
6 changes: 5 additions & 1 deletion controllers/guidance_ctrl/run_g.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ def main():
lm += gen("words", regex=r"[A-Z ]+", stop="\n")
grm = lm

grm = select(["1", "12", "123"], name="the number")
prompt = "<|user|>\nPick a number:\n<|computer|>\n"


# read current script file
# with open(__file__) as f:
Expand All @@ -101,9 +104,10 @@ def main():
print(len(b64))
mod_id = pyaici.cli.build_rust(".")
if "127.0.0.1" in pyaici.rest.base_url:
pyaici.rest.tag_module(mod_id, ["guidance_ctrl-latest"])
pyaici.rest.tag_module(mod_id, ["guidance_ctrl-latest", "guidance"])
pyaici.rest.log_level = 2
res = pyaici.rest.run_controller(
prompt=prompt,
controller=mod_id,
controller_arg=json.dumps({"guidance_b64": b64}),
temperature=0.0,
Expand Down
8 changes: 3 additions & 5 deletions controllers/guidance_ctrl/src/earley/bench.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use aici_abi::toktree;

use super::Parser;
use crate::earley::{from_guidance::earley_grm_from_guidance, parser::ParseResult};
use crate::earley::from_guidance::earley_grm_from_guidance;

pub fn earley_test(trie: toktree::TokTrie) {
let g_bytes = include_bytes!("../../../aici_abi/grammars/json0.guidance");
Expand All @@ -18,15 +18,13 @@ pub fn earley_test(trie: toktree::TokTrie) {
let grm = cfg.compile();

let mut parser = Parser::new(grm.clone());
let mut last_res = ParseResult::Reject;
for b in input {
last_res = parser.scan(*b);
if last_res == ParseResult::Reject {
if !parser.scan(*b) {
println!("reject");
break;
}
}
if last_res != ParseResult::Accept {
if !parser.is_accepting() {
println!("final non-accept");
}

Expand Down
2 changes: 1 addition & 1 deletion controllers/guidance_ctrl/src/earley/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ pub use byteset::ByteSet;
pub use from_guidance::earley_grm_from_guidance;
#[allow(unused_imports)]
pub use grammar::{Grammar, ModelVariable};
pub use parser::{Parser, ParseResult};
pub use parser::Parser;

#[cfg(not(target_arch = "wasm32"))]
pub mod bench;
85 changes: 38 additions & 47 deletions controllers/guidance_ctrl/src/earley/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,6 @@ pub struct Stats {
pub all_items: usize,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ParseResult {
Accept,
Reject,
Continue,
}

struct Row {
first_item: usize,
last_item: usize,
Expand Down Expand Up @@ -141,7 +134,6 @@ pub struct Parser {
rows: Vec<Row>,
row_infos: Vec<RowInfo>,
stats: Stats,
is_accepting: bool,
last_collapse: usize,
token_idx: usize,
}
Expand Down Expand Up @@ -260,7 +252,6 @@ impl Parser {
captures: vec![],
scratch,
stats: Stats::default(),
is_accepting: false,
last_collapse: 0,
token_idx: 0,
};
Expand All @@ -273,7 +264,18 @@ impl Parser {
}

pub fn is_accepting(&self) -> bool {
self.is_accepting
for idx in self.curr_row().item_indices() {
let item = self.scratch.items[idx];
let rule = item.rule_idx();
let after_dot = self.grammar.sym_idx_at(rule);
if after_dot == CSymIdx::NULL {
let lhs = self.grammar.sym_idx_of(item.rule_idx());
if lhs == self.grammar.start() {
return true;
}
}
}
false
}

fn item_to_string(&self, idx: usize) -> String {
Expand Down Expand Up @@ -358,6 +360,7 @@ impl Parser {
self.assert_definitive();
let mut byte_idx = 1; // row_infos[0] has just the 0 byte
let mut tok_idx = 0;
debug!("apply_tokens: {:?}", tokens);
for t in tokens {
for b in trie.token(*t).iter() {
if num_skip > 0 {
Expand All @@ -366,7 +369,7 @@ impl Parser {
}

if byte_idx >= self.row_infos.len() {
if self.scan(*b) == ParseResult::Reject {
if !self.scan(*b) {
return "parse reject";
}
if byte_idx >= self.row_infos.len() {
Expand Down Expand Up @@ -434,10 +437,10 @@ impl Parser {

pub fn force_bytes(&mut self) -> Vec<u8> {
self.assert_definitive();
debug!("force_bytes");
let mut bytes = vec![];
while let Some(b) = self.forced_byte() {
let res = self.scan(b);
if res == ParseResult::Reject {
if !self.scan(b) {
// shouldn't happen?
break;
}
Expand Down Expand Up @@ -466,7 +469,7 @@ impl Parser {
}

fn forced_byte(&self) -> Option<u8> {
if self.is_accepting {
if self.is_accepting() {
// we're not forced when in accepting state
return None;
}
Expand Down Expand Up @@ -497,7 +500,7 @@ impl Parser {
}
}

pub fn hide_item(&mut self, sym: CSymIdx, row_idx: usize) -> ParseResult {
pub fn hide_item(&mut self, sym: CSymIdx, row_idx: usize) -> bool {
info!("hide_item: {} {}", self.grammar.sym_data(sym).name, row_idx);

let row_range = self.rows[row_idx].item_indices();
Expand Down Expand Up @@ -531,7 +534,7 @@ impl Parser {
}

#[inline(always)]
pub fn scan(&mut self, b: u8) -> ParseResult {
pub fn scan(&mut self, b: u8) -> bool {
let row_idx = self.rows.len() - 1;
let last = self.rows[row_idx].last_item;
let mut i = self.rows[row_idx].first_item;
Expand Down Expand Up @@ -563,12 +566,11 @@ impl Parser {
}

#[inline(always)]
fn push_row(&mut self, mut agenda_ptr: usize, byte: u8) -> ParseResult {
fn push_row(&mut self, mut agenda_ptr: usize, byte: u8) -> bool {
let curr_idx = self.rows.len();
let mut commit_item = Item::NULL;

self.stats.rows += 1;
self.is_accepting = false;

while agenda_ptr < self.scratch.row_end {
let mut item_idx = agenda_ptr;
Expand All @@ -583,9 +585,7 @@ impl Parser {

if after_dot == CSymIdx::NULL {
let flags = self.grammar.sym_flags_of(rule);
let lhs = self.grammar.sym_idx_of(item.rule_idx());
// complete
self.is_accepting = self.is_accepting || lhs == self.grammar.start();
let lhs = self.grammar.sym_idx_of(rule);

if self.scratch.definitive && flags.capture() {
let var_name = self
Expand Down Expand Up @@ -671,31 +671,27 @@ impl Parser {
}

let row_len = self.scratch.row_len();
self.stats.all_items += row_len;

if row_len == 0 {
assert!(!self.is_accepting);
return ParseResult::Reject;
}

self.rows.push(Row {
first_item: self.scratch.row_start,
last_item: self.scratch.row_end,
});
false
} else {
self.stats.all_items += row_len;

if self.scratch.definitive {
self.row_infos.drain((self.rows.len() - 1)..);
self.row_infos.push(RowInfo {
byte,
commit_item,
token_idx: self.token_idx,
self.rows.push(Row {
first_item: self.scratch.row_start,
last_item: self.scratch.row_end,
});
}

if self.is_accepting {
ParseResult::Accept
} else {
ParseResult::Continue
if self.scratch.definitive {
self.row_infos.drain((self.rows.len() - 1)..);
self.row_infos.push(RowInfo {
byte,
commit_item,
token_idx: self.token_idx,
});
}

true
}
}
}
Expand Down Expand Up @@ -755,12 +751,7 @@ impl Recognizer for Parser {
}

fn try_push_byte(&mut self, byte: u8) -> bool {
let res = self.scan(byte);
if res == ParseResult::Reject {
false
} else {
true
}
self.scan(byte)
}
}

Expand Down

0 comments on commit 06b15c2

Please sign in to comment.