17 changes: 16 additions & 1 deletion src/encoding.rs
@@ -835,7 +835,22 @@ impl Render<Message> for HarmonyEncoding {

         // finally content type
         if let Some(content_type) = &message.content_type {
-            self.render_text_into(format!(" {content_type}"), into)?;
+            // <|constrain|> is a unique case which needs to be tokenized as a special token
+            if let Some(constrain_marker) = self.mapped_format_token(FormattingToken::ConstrainedFormat) {
+                if content_type.starts_with(constrain_marker) {
+                    // Render the space, then the constrain marker as a special token, then the rest as text (if any)
+                    self.render_text_into(" ", into)?;
+                    self.render_formatting_token_into(FormattingToken::ConstrainedFormat, into)?;
+                    let rest = &content_type[constrain_marker.len()..];
+                    if !rest.is_empty() {
+                        self.render_text_into(rest, into)?;
+                    }
+                } else {
+                    self.render_text_into(format!(" {content_type}"), into)?;
+                }
+            } else {
+                self.render_text_into(format!(" {content_type}"), into)?;
+            }
         }
 
         self.render_formatting_token_into(FormattingToken::Message, into)?;
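For reference, here is a minimal sketch of what this rendering change means at the Python API level. It mirrors the new test added below and is illustrative only; the import path is assumed from this repo's Python bindings.

# Illustrative sketch; import path assumed from the repo's Python tests.
from openai_harmony import (
    Conversation,
    HarmonyEncodingName,
    Message,
    Role,
    load_harmony_encoding,
)

encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)

# A tool-call message whose content_type starts with the constrain marker.
message = (
    Message.from_role_and_content(Role.ASSISTANT, '{"location": "Tokyo"}')
    .with_channel("commentary")
    .with_recipient("functions.get_weather")
    .with_content_type("<|constrain|>json")
)

tokens = encoding.render_conversation(Conversation.from_messages([message]))
text = encoding.decode_utf8(tokens)

# With the special-token rendering above, decoding and re-encoding is lossless:
# "<|constrain|>" comes back as the same single special token, not as plain text.
assert encoding.encode(text, allowed_special="all") == tokens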
34 changes: 33 additions & 1 deletion tests/test_harmony.py
@@ -233,6 +233,36 @@ def test_simple_tool_call(encoding_name):
     assert parsed == expected
 
 
+@pytest.mark.parametrize(
+    "encoding_name",
+    [
+        HarmonyEncodingName.HARMONY_GPT_OSS,
+    ],
+)
+def test_tool_call_with_constrain_tokenized_correctly(encoding_name):
"""
Despite passing <|constrain|> as a string in "content_type" it has to be kept as a special token.
"""
+    encoding = load_harmony_encoding(encoding_name)
+    text = (
+        "<|start|>assistant to=functions.get_weather<|channel|>commentary"
+        ' <|constrain|>json<|message|>{"location": "Tokyo"}<|call|>'
+    )
+    tokens = encoding.encode(text, allowed_special="all")
+    parsed = encoding.parse_messages_from_completion_tokens(tokens, role=None)
+    expected = [
+        Message.from_role_and_content(Role.ASSISTANT, '{"location": "Tokyo"}')
+        .with_channel("commentary")
+        .with_recipient("functions.get_weather")
+        .with_content_type("<|constrain|>json"),
+    ]
+    assert parsed == expected
+
+    rendered = encoding.render_conversation(Conversation.from_messages(expected))
+    assert text == encoding.decode_utf8(tokens)
+    assert rendered == tokens
+
+
 @pytest.mark.parametrize(
     "encoding_name",
     [
@@ -248,7 +278,7 @@ def test_tool_call_with_constrain_marker_adjacent(encoding_name):
     encoding = load_harmony_encoding(encoding_name)
     text = (
         "<|start|>assistant to=functions.get_weather<|channel|>commentary"
-        '<|constrain|>json<|message|>{"location": "Tokyo"}<|end|>'
+        '<|constrain|>json<|message|>{"location": "Tokyo"}<|call|>'
     )
     tokens = encoding.encode(text, allowed_special="all")
     parsed = encoding.parse_messages_from_completion_tokens(tokens, role=None)
@@ -702,6 +732,8 @@ def test_does_not_drop_if_ongoing_analysis():
     )
 
     assert encoding.decode_utf8(tokens) == expected_output
+    # ensure that the <|constrain|>json part is tokenized correctly, with <|constrain|> as a special token
+    assert encoding.encode(expected_output, allowed_special="all") == tokens
 
 
 def test_preserve_cot():