Skip to content

Commit

Permalink
nested parsing fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
mmoskal committed Jun 11, 2024
1 parent 6c9f9a2 commit be3077c
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 20 deletions.
4 changes: 2 additions & 2 deletions controllers/ag2_ctrl/run_g.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,8 +218,8 @@ def character_maker2(lm, id, description, valid_weapons):
prompt = "How much is 2 + 2? "
grm = gen(name="test", max_tokens=30, regex=r"\d+")

prompt = "About J. Random Hacker:\n"
grm = gen_json_object("hacker", max_tokens=50) + "\nScore (0-9): " + gen("score", regex=r"[0-9]")
prompt = "Three things about J. Random Hacker:\n"
grm = gen_json_object("hacker", max_tokens=150) + "\nScore (0-9): " + gen("score", regex=r"[0-9]")

# grm = "Q: 7 * 8\nA: " + gen("text", regex="[0-9]+", max_tokens=20) + "\n"

Expand Down
48 changes: 32 additions & 16 deletions controllers/ag2_ctrl/src/earley/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use super::{
lexerspec::{Lexeme, LexemeIdx, LexerSpec},
};

const TRACE: bool = true;
const TRACE: bool = false;
const DEBUG: bool = true;
const INFO: bool = true;

Expand Down Expand Up @@ -365,13 +365,34 @@ impl Parser {
&self.grammar
}

/// Iterator over the dot positions (rule indices) of every Earley item
/// in the current row.
fn after_dots(&self) -> impl Iterator<Item = RuleIdx> + '_ {
    let items = &self.scratch.items;
    self.curr_row()
        .item_indices()
        .map(move |item_idx| items[item_idx].rule_idx())
}

/// Symbol data for the symbol right after the dot of every item in the
/// current row (see [`Self::after_dots`]).
fn after_dots_symdata(&self) -> impl Iterator<Item = &CSymbol> + '_ {
    let grammar = &self.grammar;
    self.after_dots().map(move |rule_pos| grammar.sym_data_at(rule_pos))
}

/// True if any item in the current row has a terminal or a nested-grammar
/// symbol after its dot (ignoring SKIP and completed items), i.e. the
/// parser could still consume further input.
pub fn can_advance(&self) -> bool {
    let skip_sym = self.grammar.lexeme_to_sym_idx(LexemeIdx::SKIP);
    self.after_dots_symdata().any(|sym| {
        // completed items (NULL after dot) and skip lexemes don't count
        sym.idx != skip_sym
            && sym.idx != CSymIdx::NULL
            && (sym.is_terminal || sym.gen_grammar.is_some())
    })
}

pub fn is_accepting(&self) -> bool {
for idx in self.curr_row().item_indices() {
let item = self.scratch.items[idx];
let rule = item.rule_idx();
let after_dot = self.grammar.sym_idx_at(rule);
for pos in self.after_dots() {
let after_dot = self.grammar.sym_idx_at(pos);
if after_dot == CSymIdx::NULL {
let lhs = self.grammar.sym_idx_of(item.rule_idx());
let lhs = self.grammar.sym_idx_of(pos);
if lhs == self.grammar.start() {
return true;
}
Expand Down Expand Up @@ -498,9 +519,7 @@ impl Parser {

pub fn temperature(&self) -> f32 {
let mut temp = 0.0f32;
for i in self.curr_row().item_indices() {
let item = self.scratch.items[i];
let data = self.grammar.sym_data_at(item.rule_idx());
for data in self.after_dots_symdata() {
if data.is_terminal {
temp = temp.max(data.props.temperature);
}
Expand Down Expand Up @@ -727,9 +746,7 @@ impl Parser {

pub fn model_variables(&self) -> Vec<ModelVariable> {
let mut vars = vec![];
for i in self.curr_row().item_indices() {
let item = self.scratch.items[i];
let sym_data = self.grammar.sym_data_at(item.rule_idx());
for sym_data in self.after_dots_symdata() {
if let Some(ref mv) = sym_data.props.model_variable {
if !vars.contains(mv) {
vars.push(mv.clone());
Expand Down Expand Up @@ -777,10 +794,9 @@ impl Parser {
let mut res: Option<GenGrammarOptions> = None;
let mut res_idx = None;
let mut gen_grm = vec![];
for i in self.curr_row().item_indices() {
let item = self.scratch.items[i];
let idx = self.grammar.sym_idx_at(item.rule_idx());
let sym_data = self.grammar.sym_data_at(item.rule_idx());
for pos in self.after_dots() {
let idx = self.grammar.sym_idx_at(pos);
let sym_data = self.grammar.sym_data_at(pos);
if let Some(ref gg) = sym_data.gen_grammar {
// break ties by preferring the one with the lowest grammar number
if res.is_none() || res.as_ref().unwrap().grammar.0 > gg.grammar.0 {
Expand Down
13 changes: 12 additions & 1 deletion controllers/ag2_ctrl/src/tokenparser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,14 @@ impl TokenParser {
}
}

let inner_done = {
let is_accepting = self.parser.is_accepting();
let can_advance = self.parser.can_advance();
let inner_done = is_accepting && !can_advance;
infoln!("inner_done: {inner_done}; can_advance: {can_advance}; accept: {is_accepting}");
inner_done
};

let trie = self.token_env.tok_trie();
let mut set = trie.alloc_token_set();
// self.parser.print_row(self.parser.num_rows() - 1);
Expand All @@ -303,7 +311,10 @@ impl TokenParser {
set.disallow_token(self.first_token_of_eos_marker);
}

if self.max_tokens_parser == 0 || (set.num_set() == 1 && set.is_allowed(trie.eos_token())) {
if inner_done
|| self.max_tokens_parser == 0
|| (set.num_set() == 1 && set.is_allowed(trie.eos_token()))
{
if self.parser_stack.is_empty() {
infoln!("only eos token allowed, stopping");
return MidProcessResult::stop();
Expand Down
2 changes: 1 addition & 1 deletion py/guidance

0 comments on commit be3077c

Please sign in to comment.