Skip to content

Commit

Permalink
fix eos handling
Browse files Browse the repository at this point in the history
  • Loading branch information
mmoskal committed May 20, 2024
1 parent d3f8a1d commit 7504bbd
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 3 deletions.
4 changes: 4 additions & 0 deletions controllers/guidance_ctrl/src/earley/grammar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ pub enum ModelVariable {
}

impl ModelVariable {
pub fn eos_token() -> Self {
ModelVariable::SpecialToken(SpecialToken::EndOfSentence)
}

#[allow(dead_code)]
pub fn to_string(&self) -> String {
match self {
Expand Down
25 changes: 25 additions & 0 deletions controllers/guidance_ctrl/src/earley/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,31 @@ impl Parser {
self.push_row(agenda_ptr, last_byte)
}

/// Tries to advance the parser over the model variable `mv`.
///
/// A fresh scratch row is started, then every item of the current row is
/// inspected: when the symbol data looked up for the item declares a model
/// variable equal to `mv`, the item's `advance_dot()` result is added to the
/// scratch row.
///
/// Returns `false` when no item matched (the scratch row stayed empty);
/// otherwise returns whatever `push_row` reports for the new row.
pub fn scan_model_variable(&mut self, mv: ModelVariable) -> bool {
    // Trace only when the scratch area is flagged as definitive.
    if self.scratch.definitive {
        debug!(" scan mv: {:?}", mv);
    }

    let prev_last_item = self.curr_row().last_item;
    self.scratch.new_row(prev_last_item);

    for src_idx in self.curr_row().item_indices() {
        let candidate = self.scratch.items[src_idx];
        let expects_mv = matches!(
            &self.grammar.sym_data_at(candidate.rule_idx()).props.model_variable,
            Some(declared) if mv == *declared
        );
        if expects_mv {
            self.scratch
                .add_unique(candidate.advance_dot(), src_idx, "scan_model_variable");
        }
    }

    // Nothing in the current row could consume `mv`.
    if self.scratch.row_len() == 0 {
        return false;
    }

    // NOTE(review): 0 is passed where other callers pass the last scanned
    // byte; presumably a placeholder since no byte is consumed here - confirm.
    self.push_row(self.scratch.row_start, 0)
}

#[inline(always)]
pub fn scan(&mut self, b: u8) -> bool {
let row_idx = self.rows.len() - 1;
Expand Down
15 changes: 12 additions & 3 deletions controllers/guidance_ctrl/src/tokenparser.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::earley::{earley_grm_from_guidance, Parser};
use crate::earley::{earley_grm_from_guidance, ModelVariable, Parser};
use aici_abi::{MidProcessArg, MidProcessResult, TokenId, TokenizerEnv};
use anyhow::Result;

Expand Down Expand Up @@ -95,7 +95,7 @@ impl TokenParser {
grm_bytes
}

pub fn mid_process(&mut self, arg: MidProcessArg) -> MidProcessResult {
pub fn mid_process(&mut self, mut arg: MidProcessArg) -> MidProcessResult {
let start_time = std::time::Instant::now();

infoln!("\n");
Expand All @@ -106,7 +106,16 @@ impl TokenParser {
arg.backtrack,
trie.tokens_dbg(&arg.tokens)
);
arg.save_tokens(&mut self.llm_tokens);

if arg.tokens.contains(&trie.eos_token()) {
assert!(arg.tokens.len() == 1);
if self.parser.scan_model_variable(ModelVariable::eos_token()) {
// it got scanned correctly, so we remove it
arg.tokens.clear();
}
} else {
arg.save_tokens(&mut self.llm_tokens);
}

let new_bytes = trie.decode(&arg.tokens);
self.llm_bytes.extend_from_slice(&new_bytes);
Expand Down

0 comments on commit 7504bbd

Please sign in to comment.