Skip to content

Commit

Permalink
add is_generated to Text
Browse files Browse the repository at this point in the history
  • Loading branch information
mmoskal committed Jun 14, 2024
1 parent 9abd972 commit 1654501
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 9 deletions.
6 changes: 3 additions & 3 deletions controllers/llguidance_ctrl/run_g.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,6 @@ def character_maker2(lm, id, description, valid_weapons):
}}"""
return lm

grm = character_maker2(1, "A nimble fighter", ["axe", "sword", "bow"])
prompt = ""

grm = "Write a number: " + gen("text", max_tokens=3)
grm = "Q: 1000 + 3\nA: " + gen("text", regex="[0-9]+", max_tokens=20)
grm = "Q: 1000 + 3\nA: " + gen("text", regex="[0-9]+", max_tokens=2)
Expand Down Expand Up @@ -230,6 +227,9 @@ def character_maker2(lm, id, description, valid_weapons):
+ gen("score", regex=r"[0-9]")
)

grm = character_maker2(1, "A nimble fighter", ["axe", "sword", "bow"])
prompt = ""


# grm = "Q: 7 * 8\nA: " + gen("text", regex="[0-9]+", max_tokens=20) + "\n"

Expand Down
3 changes: 1 addition & 2 deletions controllers/llguidance_ctrl/src/gctrl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,7 @@ impl AiciCtrl for Runner {
}
fn mid_process(&mut self, arg: MidProcessArg) -> MidProcessResult {
let r = self.tok_parser.mid_process(arg);
let is_final = r.is_stop();
for v in self.reporter.get_progress(&mut self.tok_parser, is_final) {
for v in self.reporter.get_progress(&mut self.tok_parser, &r) {
json_out(&v);
}
r
Expand Down
12 changes: 9 additions & 3 deletions controllers/llguidance_ctrl/src/output.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use aici_abi::bytes::to_hex_string;
use aici_abi::{bytes::to_hex_string, MidProcessResult};
use serde::{Deserialize, Serialize};

use crate::{earley, TokenParser};
Expand Down Expand Up @@ -27,6 +27,7 @@ pub enum ParserOutput {
bytes: BytesOutput,
log_prob: f64,
num_tokens: usize,
is_generated: bool,
stats: ParserStats,
},
}
Expand Down Expand Up @@ -58,6 +59,7 @@ pub struct Reporter {
text_ptr: usize,
token_ptr: usize,
prev_stats: earley::ParserStats,
is_generated: bool,
}

impl Reporter {
Expand All @@ -67,13 +69,14 @@ impl Reporter {
text_ptr: 0,
token_ptr: tok_parser.num_tokens(),
prev_stats: tok_parser.parser_stats().clone(),
is_generated: false,
}
}

pub fn get_progress(
&mut self,
tok_parser: &mut TokenParser,
is_final: bool,
mid_res: &MidProcessResult,
) -> Vec<ParserOutput> {
let mut res = vec![];

Expand Down Expand Up @@ -112,12 +115,15 @@ impl Reporter {
bytes: new_text.into(),
log_prob: 0.0, // TODO
num_tokens: num_tokens - self.token_ptr,
is_generated: self.is_generated,
stats,
});
self.text_ptr += new_text.len();
self.token_ptr = num_tokens;

if is_final {
self.is_generated = mid_res.branches.len() >= 1 && mid_res.branches[0].splices.len() == 0;

if mid_res.is_stop() {
res.push(ParserOutput::FinalText {
bytes: tok_parser.final_bytes().into(),
});
Expand Down
2 changes: 1 addition & 1 deletion py/llguidance/rust/py.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ impl LLInterpreter {
});
let is_final = r.is_stop();
let mut res = PyMidProcessResult {
progress: self.reporter.get_progress(&mut self.inner, is_final),
progress: self.reporter.get_progress(&mut self.inner, &r),
stop: is_final,
backtrack: 0,
temperature: self.temperature,
Expand Down

0 comments on commit 1654501

Please sign in to comment.