diff --git a/src/encoding.rs b/src/encoding.rs
index 1999372..60257e7 100644
--- a/src/encoding.rs
+++ b/src/encoding.rs
@@ -9,6 +9,8 @@ use std::{
     vec,
 };
 
+const REPLACEMENT: &str = "\u{FFFD}";
+
 // Parsed representation of a message header.
 #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
 pub struct ParsedHeader {
@@ -1059,6 +1061,7 @@ pub struct StreamableParser {
     stop_tokens: HashSet<Rank>,
     last_content_delta: Option<String>,
     undecoded_tokens: Vec<Rank>,
+    undecoded_bytes: Vec<u8>,
     options: ParseOptions,
 }
 
@@ -1105,6 +1108,7 @@ impl StreamableParser {
             stop_tokens,
             last_content_delta: None,
             undecoded_tokens: Vec::new(),
+            undecoded_bytes: Vec::new(),
             options,
         })
     }
@@ -1214,14 +1218,59 @@ impl StreamableParser {
                 match self
                     .encoding
                     .tokenizer()
-                    .decode_utf8(&self.undecoded_tokens)
+                    .decode_bytes(&self.undecoded_tokens)
                 {
-                    Ok(decoded) => {
-                        content_tokens.extend(self.undecoded_tokens.iter().copied());
-                        self.last_content_delta = Some(decoded);
+                    Ok(decoded_bytes) => {
+                        self.undecoded_bytes.extend(decoded_bytes.iter().copied());
+                        match String::from_utf8(self.undecoded_bytes.clone()) {
+                            Ok(decoded_str) => {
+                                self.encoding
+                                    .render_text_into(&decoded_str, content_tokens)?;
+                                self.last_content_delta = Some(decoded_str);
+                                self.undecoded_bytes.clear();
+                            }
+                            Err(e) => {
+                                let utf8_error = e.utf8_error();
+                                let decoded_bytes = e.into_bytes();
+
+                                let valid_len = utf8_error.valid_up_to();
+
+                                let mut content_delta = String::new();
+                                if valid_len > 0 {
+                                    let valid_str = String::from_utf8(
+                                        decoded_bytes[..valid_len].to_vec(),
+                                    )
+                                    .unwrap();
+                                    self.encoding
+                                        .render_text_into(&valid_str, content_tokens)?;
+                                    content_delta.push_str(&valid_str);
+                                    self.undecoded_bytes.drain(..valid_len);
+                                }
+
+                                match utf8_error.error_len() {
+                                    Some(error_len) => {
+                                        self.encoding.render_text_into(
+                                            REPLACEMENT,
+                                            content_tokens,
+                                        )?;
+                                        content_delta.push_str(REPLACEMENT);
+                                        self.undecoded_bytes.drain(..error_len);
+                                    }
+                                    None => {
+                                        // Waiting on the next byte in our utf-8 sequence
+                                        self.last_content_delta = None;
+                                    }
+                                }
+
+                                if !content_delta.is_empty() {
+                                    self.last_content_delta = Some(content_delta);
+                                }
+                            }
+                        }
                         self.undecoded_tokens.clear();
                     }
                     Err(_) => {
+                        // Tokens could not be decoded into bytes (e.g. an unknown token id); keep them buffered
                        self.last_content_delta = None;
                     }
                 }
@@ -1233,7 +1282,20 @@ impl StreamableParser {
                     true
                 };
                 if is_eos {
-                    let text = self.encoding.tokenizer().decode_utf8(content_tokens)?;
+                    // Our rendered content tokens are valid utf-8, so we can decode them directly
+                    let content_text = self.encoding.tokenizer().decode_utf8(content_tokens)?;
+                    // Decode any remaining undecoded tokens, falling back to the replacement character if they are invalid
+                    let tokens_text = match self
+                        .encoding
+                        .tokenizer()
+                        .decode_utf8(self.undecoded_tokens.clone())
+                    {
+                        Ok(text) => text,
+                        Err(_) => REPLACEMENT.to_string(),
+                    };
+                    // Decode any remaining undecoded bytes, replacing any invalid bytes with the replacement character
+                    let bytes_text = String::from_utf8_lossy(&self.undecoded_bytes);
+                    let text = content_text + &tokens_text + &bytes_text;
                     let message = Message {
                         author: header.author.clone(),
                         recipient: header.recipient.clone(),
@@ -1245,6 +1307,7 @@ impl StreamableParser {
                     self.state = StreamState::ExpectStart;
                     self.last_content_delta = None;
                     self.undecoded_tokens.clear();
+                    self.undecoded_bytes.clear();
                 }
             }
         }
diff --git a/src/tests.rs b/src/tests.rs
index 922be79..7aba934 100644
--- a/src/tests.rs
+++ b/src/tests.rs
@@ -739,3 +739,150 @@ fn test_tool_call_with_channel_before_recipient_and_constrain_adjacent() {
.with_content_type("<|constrain|>json")]; assert_eq!(parsed, expected); } + +#[test] +fn test_streamable_parser_does_not_leak_bytes_between_messages() { + // This test ensures that any partially decoded bytes from the first message + // do not leak into the content of the next message. + let encoding = load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss).unwrap(); + + // 9552 is known (in this tokenizer) to expand to bytes that form an incomplete + // UTF-8 sequence when used alone, which exercises the streaming invalid-UTF-8 path. + // Construct two assistant messages back-to-back where the first includes invalid + // UTF-8 bytes and the second is simple ASCII text. + let first_prefix = "<|start|>assistant<|message|>"; + let first_suffix = "<|end|>"; + let second_prefix = "<|start|>assistant<|message|>"; + let second_suffix = "<|end|>"; + + let mut tokens = Vec::new(); + tokens.extend( + encoding + .tokenizer() + .encode_with_special_tokens(first_prefix), + ); + // Two invalid tokens to ensure we end the first message with incomplete UTF-8 bytes. + tokens.push(9552); + tokens.push(9552); + tokens.extend( + encoding + .tokenizer() + .encode_with_special_tokens(first_suffix), + ); + + // Second message should be clean and unaffected. + tokens.extend( + encoding + .tokenizer() + .encode_with_special_tokens(second_prefix), + ); + tokens.extend(encoding.tokenizer().encode_with_special_tokens("Hi")); + tokens.extend( + encoding + .tokenizer() + .encode_with_special_tokens(second_suffix), + ); + + let mut parser = StreamableParser::new(encoding, None).unwrap(); + for t in tokens { + parser.process(t).unwrap(); + } + + let messages = parser.messages(); + assert_eq!(messages.len(), 2, "expected two parsed messages"); + + // Verify the second message content is exactly "Hi" (no leaked replacement chars or bytes). 
+    let second = &messages[1];
+    let expected_second = Message::from_role_and_content(Role::Assistant, "Hi");
+    assert_eq!(
+        second, &expected_second,
+        "second message must be clean and isolated"
+    );
+}
+
+#[test]
+fn test_streamable_parser_flushes_partial_bytes_on_eos() {
+    let encoding = load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss).unwrap();
+
+    let mut tokens = encoding
+        .tokenizer()
+        .encode_with_special_tokens("<|start|>assistant<|message|>");
+    tokens.push(9552);
+    tokens.extend(encoding.tokenizer().encode_with_special_tokens("Hi"));
+
+    let mut parser = StreamableParser::new(encoding.clone(), None).unwrap();
+    for token in tokens {
+        parser.process(token).unwrap();
+    }
+    parser.process_eos().unwrap();
+
+    let messages = parser.messages();
+    assert_eq!(messages.len(), 1, "expected a single message after EOS");
+    let expected = Message::from_role_and_content(Role::Assistant, " \u{FFFD}Hi");
+    assert_eq!(messages[0], expected);
+    assert_eq!(parser.last_content_delta().unwrap(), None);
+    assert_eq!(parser.current_content().unwrap(), "");
+}
+
+#[test]
+fn test_streamable_parser_waits_for_multi_token_utf8_sequence() {
+    use std::collections::HashSet;
+
+    let encoding = load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss).unwrap();
+    let mut parser = StreamableParser::new(encoding.clone(), None).unwrap();
+
+    let start_tokens = encoding
+        .tokenizer()
+        .encode_with_special_tokens("<|start|>assistant<|message|>");
+    for token in &start_tokens {
+        parser.process(*token).unwrap();
+    }
+
+    let emoji_tokens = encoding.tokenizer().encode("💖", &HashSet::new()).0;
+    assert!(
+        emoji_tokens.len() >= 2,
+        "expected multi-token emoji encoding"
+    );
+
+    parser.process(emoji_tokens[0]).unwrap();
+    assert_eq!(parser.last_content_delta().unwrap(), None);
+    assert_eq!(parser.current_content().unwrap(), "");
+
+    parser.process(emoji_tokens[1]).unwrap();
+    assert_eq!(parser.last_content_delta().unwrap(), Some("💖".to_string()));
+    assert_eq!(parser.current_content().unwrap(), "💖");
+
+    let end_tokens = encoding.tokenizer().encode_with_special_tokens("<|end|>");
+    for token in end_tokens {
+        parser.process(token).unwrap();
+    }
+
+    let messages = parser.messages();
+    assert_eq!(messages.len(), 1, "expected a single completed message");
+    let expected = Message::from_role_and_content(Role::Assistant, "💖");
+    assert_eq!(messages[0], expected);
+}
+
+#[test]
+fn test_parse_completion_with_invalid_content_token_errors_on_eos() {
+    let encoding = load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss).unwrap();
+    let mut parser = StreamableParser::new(encoding.clone(), None).unwrap();
+
+    let start_tokens = encoding.tokenizer().encode_with_special_tokens(
+        "<|start|>assistant<|channel|>analysis<|message|>Practice invalid token handling.",
+    );
+    for token in &start_tokens {
+        parser.process(*token).unwrap();
+    }
+
+    parser.process(u32::MAX).unwrap();
+    parser.process_eos().unwrap();
+
+    let messages = parser.messages();
+    assert_eq!(messages.len(), 1);
+    let parsed_message = &messages[0];
+    let expected_message =
+        Message::from_role_and_content(Role::Assistant, "Practice invalid token handling.\u{FFFD}")
+            .with_channel("analysis");
+    assert_eq!(parsed_message, &expected_message);
+}
diff --git a/tests/test_harmony.py b/tests/test_harmony.py
index 761bcef..dbb9925 100644
--- a/tests/test_harmony.py
+++ b/tests/test_harmony.py
@@ -1088,3 +1088,159 @@ def test_streamable_parser_missing_message_token_tool_call(
         .with_content_type("json"),
     ]
     assert parser.messages == expected
+
+
+def test_streamable_parser_invalid_utf8_decoding():
+    encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
+
+    # Confirm our token sequence is invalid utf-8
+    # token 9552 corresponds to the bytes [32, 240, 159]
+    # 32 is a space; 240, 159 is an incomplete (hence invalid) utf-8 sequence
+    invalid_token_sequence = [9552, 9552]
+    with pytest.raises(HarmonyError):
+        encoding.decode_utf8(invalid_token_sequence)
+
+    prefix_tokens = encoding.encode("<|start|>assistant<|message|>", allowed_special="all")
+    suffix_tokens = encoding.encode("worked<|end|>", allowed_special="all")
+    tokens = prefix_tokens + invalid_token_sequence + suffix_tokens
+    parser = StreamableParser(encoding, None)
+    for token in tokens:
+        parser.process(token)
+
+    expected = [
+        # Confirm we got the utf-8 replacement characters for the invalid sequences
+        # and the remaining valid utf-8 sequence
+        Message.from_role_and_content(Role.ASSISTANT, " \uFFFD \uFFFDworked"),
+    ]
+    assert parser.messages == expected
+
+
+def test_streamable_parser_invalid_utf8_decoding_split_across_tokens():
+    encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
+
+    valid_token_sequence = encoding.encode("XY")
+    encoding.decode_utf8(valid_token_sequence)
+
+    # Confirm prepending a specific token makes the sequence invalid utf-8.
+    # Token 9552 ends with the start of a multi-byte utf-8 sequence,
+    # which means prepending it to our previously valid sequence
+    # makes it invalid utf-8
+    invalid_token_sequence = [9552] + valid_token_sequence
+    with pytest.raises(HarmonyError):
+        encoding.decode_utf8(invalid_token_sequence)
+
+    prefix_tokens = encoding.encode("<|start|>assistant<|message|>", allowed_special="all")
+    suffix_tokens = encoding.encode("<|end|>", allowed_special="all")
+    tokens = prefix_tokens + invalid_token_sequence + suffix_tokens
+    parser = StreamableParser(encoding, None)
+    for token in tokens:
+        parser.process(token)
+
+    expected = [
+        # One utf-8 replacement character but otherwise kept our space
+        # (from token 9552) and "X" and "Y" tokens
+        Message.from_role_and_content(Role.ASSISTANT, " \uFFFDXY"),
+    ]
+    assert parser.messages == expected
+
+
+def test_streamable_parser_invalid_utf8_decoding_multi_byte_token():
+    encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
+
+    # Valid utf-8 sequence
+    valid_token_sequence = encoding.encode(" interesting")
+    encoding.decode_utf8(valid_token_sequence)
+
+    # Confirm prepending a specific token makes the sequence invalid utf-8.
+    # Token 9552 ends with the start of a multi-byte utf-8 sequence,
+    # which means prepending it to our previously valid sequence
+    # makes it invalid utf-8
+    invalid_token_sequence = [9552] + valid_token_sequence
+    with pytest.raises(HarmonyError):
+        encoding.decode_utf8(invalid_token_sequence)
+
+    prefix_tokens = encoding.encode("<|start|>assistant<|message|>", allowed_special="all")
+    suffix_tokens = encoding.encode("<|end|>", allowed_special="all")
+    tokens = prefix_tokens + invalid_token_sequence + suffix_tokens
+    parser = StreamableParser(encoding, None)
+    for token in tokens:
+        parser.process(token)
+
+    expected = [
+        # One utf-8 replacement character and the contents of our second token,
+        # which maps to the text " interesting"
+        Message.from_role_and_content(Role.ASSISTANT, " \uFFFD interesting"),
+    ]
+    assert parser.messages == expected
+
+
+def test_streamable_parser_invalid_utf8_decoding_multi_byte_token_no_eos_marker():
+    """Ensure we don't leave partially decoded tokens behind when there is no EOS marker."""
+    encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
+
+    # Valid utf-8 sequence
+    valid_token_sequence = encoding.encode(" interesting")
+    encoding.decode_utf8(valid_token_sequence)
+
+    # Confirm prepending a specific token makes the sequence invalid utf-8.
+    # Token 9552 ends with the start of a multi-byte utf-8 sequence,
+    # which means prepending it to our previously valid sequence
+    # makes it invalid utf-8
+    invalid_token_sequence = [9552] + valid_token_sequence
+    with pytest.raises(HarmonyError):
+        encoding.decode_utf8(invalid_token_sequence)
+
+    prefix_tokens = encoding.encode("<|start|>assistant<|message|>", allowed_special="all")
+    suffix_tokens = encoding.encode(" story")
+    tokens = prefix_tokens + invalid_token_sequence + suffix_tokens
+    parser = StreamableParser(encoding, None)
+
+    content_deltas = []
+    for token in tokens:
+        parser.process(token)
+        if parser.last_content_delta is not None:
+            content_deltas.append(parser.last_content_delta)
+
+    # No EOS, so no full message, but make sure we have the current content
+    assert parser.current_content == " \uFFFD interesting story"
+
+    # Ensure all the deltas combine to form our expected content
+    assert "".join(content_deltas) == " \uFFFD interesting story"
+
+    # Confirm we can keep accumulating content deltas and content
+    one_more_token = encoding.encode("Y")[0]
+    parser.process(one_more_token)
+    assert parser.last_content_delta == "Y"
+    assert parser.current_content == " \uFFFD interesting storyY"
+
+
+def test_streamable_parser_tricky_utf8_decoding():
+    """Try text with various types of utf-8 sequences that are more likely to fail."""
+    encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
+
+    tricky_utf8_text = (
+        "Hello Müller, Γειά σου, Привет, שלום, مرحبا, नमस्ते, こんにちは, 안녕하세요,"
+        " 你好. Normalized (naïve) vs. decomposed (naïve) characters. "
+        "Some emojis: 😊👋🏾👨‍👩‍👧‍👦🇺🇸."
+    )
+    valid_token_sequence = encoding.encode(tricky_utf8_text)
+
+    prefix_tokens = encoding.encode("<|start|>assistant<|message|>", allowed_special="all")
+    suffix_tokens = encoding.encode("<|end|>", allowed_special="all")
+    tokens = prefix_tokens + valid_token_sequence + suffix_tokens
+    parser = StreamableParser(encoding, None)
+
+    content_deltas = []
+    for token in tokens:
+        parser.process(token)
+        if parser.last_content_delta is not None:
+            content_deltas.append(parser.last_content_delta)
+
+    expected = [
+        Message.from_role_and_content(Role.ASSISTANT, tricky_utf8_text),
+    ]
+    # Ensure we got the entirety of our tricky utf-8 text as message content
+    assert parser.messages == expected
+
+    # Ensure if we're accumulating content deltas we still get the full utf-8 text
+    assert "".join(content_deltas) == tricky_utf8_text
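
Note on the pattern: the streaming branch in src/encoding.rs above implements standard incremental UTF-8 decoding: buffer raw bytes, emit the longest valid prefix, substitute U+FFFD when error_len() reports a definitely invalid sequence, hold back an incomplete trailing sequence (error_len() == None) until more bytes arrive, and lossily decode whatever is left at EOS. Below is a minimal, self-contained sketch of that pattern using only the Rust standard library; the Utf8Stream type and its feed/flush methods are illustrative names for this note, not part of this crate's API.

const REPLACEMENT: &str = "\u{FFFD}";

struct Utf8Stream {
    // Bytes received so far that are not yet known to be complete UTF-8.
    pending: Vec<u8>,
}

impl Utf8Stream {
    fn new() -> Self {
        Self { pending: Vec::new() }
    }

    // Feed a chunk of bytes; returns whatever text became decodable.
    fn feed(&mut self, bytes: &[u8]) -> String {
        self.pending.extend_from_slice(bytes);
        let mut out = String::new();
        loop {
            match std::str::from_utf8(&self.pending) {
                Ok(valid) => {
                    // Everything buffered is valid UTF-8; emit it all.
                    out.push_str(valid);
                    self.pending.clear();
                    return out;
                }
                Err(err) => {
                    // Emit the valid prefix, if any.
                    let valid_len = err.valid_up_to();
                    out.push_str(std::str::from_utf8(&self.pending[..valid_len]).unwrap());
                    match err.error_len() {
                        // Definitely invalid bytes: replace them and keep scanning.
                        Some(error_len) => {
                            out.push_str(REPLACEMENT);
                            self.pending.drain(..valid_len + error_len);
                        }
                        // Incomplete trailing sequence: keep it for the next feed.
                        None => {
                            self.pending.drain(..valid_len);
                            return out;
                        }
                    }
                }
            }
        }
    }

    // At end of stream, lossily decode whatever is left over.
    fn flush(&mut self) -> String {
        let text = String::from_utf8_lossy(&self.pending).into_owned();
        self.pending.clear();
        text
    }
}

fn main() {
    let mut stream = Utf8Stream::new();
    // "💖" is F0 9F 92 96 in UTF-8; split it across two chunks.
    assert_eq!(stream.feed(&[0xF0, 0x9F]), "");
    assert_eq!(stream.feed(&[0x92, 0x96, b'!']), "💖!");
    // A sequence still incomplete at end of stream becomes a single U+FFFD.
    assert_eq!(stream.feed(&[0xF0, 0x9F]), "");
    assert_eq!(stream.flush(), "\u{FFFD}");
}

The flush step mirrors the EOS handling in the diff, where String::from_utf8_lossy replaces any leftover partial sequence with the replacement character.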