Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 68 additions & 5 deletions src/encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ use std::{
vec,
};

const REPLACEMENT: &str = "\u{FFFD}";

// Parsed representation of a message header.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ParsedHeader {
Expand Down Expand Up @@ -1059,6 +1061,7 @@ pub struct StreamableParser {
stop_tokens: HashSet<Rank>,
last_content_delta: Option<String>,
undecoded_tokens: Vec<Rank>,
undecoded_bytes: Vec<u8>,
options: ParseOptions,
}

Expand Down Expand Up @@ -1105,6 +1108,7 @@ impl StreamableParser {
stop_tokens,
last_content_delta: None,
undecoded_tokens: Vec::new(),
undecoded_bytes: Vec::new(),
options,
})
}
Expand Down Expand Up @@ -1214,14 +1218,59 @@ impl StreamableParser {
match self
.encoding
.tokenizer()
.decode_utf8(&self.undecoded_tokens)
.decode_bytes(&self.undecoded_tokens)
{
Ok(decoded) => {
content_tokens.extend(self.undecoded_tokens.iter().copied());
self.last_content_delta = Some(decoded);
Ok(decoded_bytes) => {
self.undecoded_bytes.extend(decoded_bytes.iter().copied());
match String::from_utf8(self.undecoded_bytes.clone()) {
Ok(decoded_str) => {
self.encoding
.render_text_into(&decoded_str, content_tokens)?;
self.last_content_delta = Some(decoded_str);
self.undecoded_bytes.clear();
}
Err(e) => {
let utf8_error = e.utf8_error();
let decoded_bytes = e.into_bytes();

let valid_len = utf8_error.valid_up_to();

let mut content_delta = String::new();
if valid_len > 0 {
let valid_str = String::from_utf8(
decoded_bytes[..valid_len].to_vec(),
)
.unwrap();
self.encoding
.render_text_into(&valid_str, content_tokens)?;
content_delta.push_str(&valid_str);
self.undecoded_bytes.drain(..valid_len);
}

match utf8_error.error_len() {
Some(error_len) => {
self.encoding.render_text_into(
REPLACEMENT,
content_tokens,
)?;
content_delta.push_str(REPLACEMENT);
self.undecoded_bytes.drain(..error_len);
}
None => {
// waiting on next byte in our utf-8 sequence
self.last_content_delta = None;
}
}

if !content_delta.is_empty() {
self.last_content_delta = Some(content_delta);
}
}
}
self.undecoded_tokens.clear();
}
Err(_) => {
// Bytes not yet valid utf-8, wait on the next token
self.last_content_delta = None;
}
}
Expand All @@ -1233,7 +1282,20 @@ impl StreamableParser {
true
};
if is_eos {
let text = self.encoding.tokenizer().decode_utf8(content_tokens)?;
// Our rendered content tokens are valid utf-8, so we can decode them directly
let content_text = self.encoding.tokenizer().decode_utf8(content_tokens)?;
// Decode any remaining undecoded tokens, replacing any invalid tokens with the replacement character
let tokens_text = match self
.encoding
.tokenizer()
.decode_utf8(self.undecoded_tokens.clone())
{
Ok(text) => text,
Err(_) => REPLACEMENT.to_string(),
};
// Decode any remaining undecoded bytes, replacing any invalid bytes with the replacement character
let bytes_text = String::from_utf8_lossy(&self.undecoded_bytes);
let text = content_text + &tokens_text + &bytes_text;
let message = Message {
author: header.author.clone(),
recipient: header.recipient.clone(),
Expand All @@ -1245,6 +1307,7 @@ impl StreamableParser {
self.state = StreamState::ExpectStart;
self.last_content_delta = None;
self.undecoded_tokens.clear();
self.undecoded_bytes.clear();
}
}
}
Expand Down
147 changes: 147 additions & 0 deletions src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -739,3 +739,150 @@ fn test_tool_call_with_channel_before_recipient_and_constrain_adjacent() {
.with_content_type("<|constrain|>json")];
assert_eq!(parsed, expected);
}

#[test]
fn test_streamable_parser_does_not_leak_bytes_between_messages() {
    // This test ensures that any partially decoded bytes from the first message
    // do not leak into the content of the next message.
    let encoding = load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss).unwrap();

    // Token 9552 is known (in this tokenizer) to expand to bytes that form an
    // incomplete UTF-8 sequence when used alone, which exercises the streaming
    // invalid-UTF-8 path. Construct two assistant messages back-to-back where
    // the first ends with dangling (incomplete) UTF-8 bytes and the second is
    // simple ASCII text.
    let first_prefix = "<|start|>assistant<|message|>";
    let first_suffix = "<|end|>";
    let second_prefix = "<|start|>assistant<|message|>";
    let second_suffix = "<|end|>";

    let mut tokens = Vec::new();
    tokens.extend(
        encoding
            .tokenizer()
            .encode_with_special_tokens(first_prefix),
    );
    // Two tokens whose trailing bytes form an incomplete UTF-8 sequence, so the
    // first message ends with undecoded bytes still buffered in the parser.
    tokens.push(9552);
    tokens.push(9552);
    tokens.extend(
        encoding
            .tokenizer()
            .encode_with_special_tokens(first_suffix),
    );

    // Second message should be clean and unaffected.
    tokens.extend(
        encoding
            .tokenizer()
            .encode_with_special_tokens(second_prefix),
    );
    tokens.extend(encoding.tokenizer().encode_with_special_tokens("Hi"));
    tokens.extend(
        encoding
            .tokenizer()
            .encode_with_special_tokens(second_suffix),
    );

    let mut parser = StreamableParser::new(encoding, None).unwrap();
    for t in tokens {
        parser.process(t).unwrap();
    }

    let messages = parser.messages();
    assert_eq!(messages.len(), 2, "expected two parsed messages");

    // Verify the second message content is exactly "Hi" (no leaked replacement chars or bytes).
    let second = &messages[1];
    let expected_second = Message::from_role_and_content(Role::Assistant, "Hi");
    assert_eq!(
        second, &expected_second,
        "second message must be clean and isolated"
    );
}

#[test]
fn test_streamable_parser_flushes_partial_bytes_on_eos() {
    // End the stream while bytes from token 9552 (an incomplete UTF-8
    // sequence in this tokenizer) are still buffered: EOS must flush them
    // into the message as a replacement character rather than drop them.
    let encoding = load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss).unwrap();
    let mut parser = StreamableParser::new(encoding.clone(), None).unwrap();

    let mut stream = encoding
        .tokenizer()
        .encode_with_special_tokens("<|start|>assistant<|message|>");
    stream.push(9552);
    stream.extend(encoding.tokenizer().encode_with_special_tokens("Hi"));

    for tok in stream {
        parser.process(tok).unwrap();
    }
    parser.process_eos().unwrap();

    let messages = parser.messages();
    assert_eq!(messages.len(), 1, "expected a single message after EOS");
    assert_eq!(
        messages[0],
        Message::from_role_and_content(Role::Assistant, " \u{FFFD}Hi")
    );
    // After EOS the streaming state must be fully reset.
    assert_eq!(parser.last_content_delta().unwrap(), None);
    assert_eq!(parser.current_content().unwrap(), "");
}

#[test]
fn test_streamable_parser_waits_for_multi_token_utf8_sequence() {
    use std::collections::HashSet;

    // An emoji that spans several tokens: the parser must emit no content
    // delta until the complete UTF-8 sequence has arrived.
    let encoding = load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss).unwrap();
    let mut parser = StreamableParser::new(encoding.clone(), None).unwrap();

    for tok in encoding
        .tokenizer()
        .encode_with_special_tokens("<|start|>assistant<|message|>")
    {
        parser.process(tok).unwrap();
    }

    let emoji_tokens = encoding.tokenizer().encode("💖", &HashSet::new()).0;
    assert!(
        emoji_tokens.len() >= 2,
        "expected multi-token emoji encoding"
    );

    // First token carries only part of the scalar value: nothing is released.
    parser.process(emoji_tokens[0]).unwrap();
    assert_eq!(parser.last_content_delta().unwrap(), None);
    assert_eq!(parser.current_content().unwrap(), "");

    // Second token completes the sequence and the emoji appears as a delta.
    parser.process(emoji_tokens[1]).unwrap();
    assert_eq!(parser.last_content_delta().unwrap(), Some("💖".to_string()));
    assert_eq!(parser.current_content().unwrap(), "💖");

    for tok in encoding.tokenizer().encode_with_special_tokens("<|end|>") {
        parser.process(tok).unwrap();
    }

    let messages = parser.messages();
    assert_eq!(messages.len(), 1, "expected a single completed message");
    assert_eq!(
        messages[0],
        Message::from_role_and_content(Role::Assistant, "💖")
    );
}

#[test]
fn test_parse_completion_with_invalid_content_token_errors_on_eos() {
    // A token id the tokenizer cannot decode (u32::MAX) is buffered; at EOS
    // it must surface as a single replacement character appended to the
    // otherwise-valid content rather than failing the parse.
    let encoding = load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss).unwrap();
    let mut parser = StreamableParser::new(encoding.clone(), None).unwrap();

    for tok in encoding.tokenizer().encode_with_special_tokens(
        "<|start|>assistant<|channel|>analysis<|message|>Practice invalid token handling.",
    ) {
        parser.process(tok).unwrap();
    }

    parser.process(u32::MAX).unwrap();
    parser.process_eos().unwrap();

    let messages = parser.messages();
    assert_eq!(messages.len(), 1);
    let expected_message =
        Message::from_role_and_content(Role::Assistant, "Practice invalid token handling.\u{FFFD}")
            .with_channel("analysis");
    assert_eq!(&messages[0], &expected_message);
}
Loading
Loading