Skip to content

Commit be1119f

Browse files
authored
Support thinking models from anthropic. (#1555)
<!-- ELLIPSIS_HIDDEN --> > [!IMPORTANT] > Add support for Anthropic thinking models by introducing `TestThinking` function and `CustomStory` model, updating response handling and integration tests. > > - **Behavior**: > - Add `TestThinking` function to handle thinking events, returning `CustomStory`. > - Update `parse_anthropic_response` in `response_handler.rs` to handle multiple content types. > - **Models**: > - Add `CustomStory` model to `types.ts`, `types.py`, and `types.rb`. > - Update `AnthropicMessageContent` enum in `types.rs` to include new content types. > - **Integration Tests**: > - Add `TestThinking` to `clients.baml` and `anthropic.baml`. > - Add tests for `TestThinking` in `test_functions.py`. > - **Misc**: > - Remove `Drop` implementation from `FunctionResultStream` in `stream.rs`. > - Update `openapi.yaml` to include `TestThinking` endpoint. > > <sup>This description was created by </sup>[<img alt="Ellipsis" src="https://img.shields.io/badge/Ellipsis-blue?color=175173">](https://www.ellipsis.dev?ref=BoundaryML%2Fbaml&utm_source=github&utm_medium=referral)<sup> for ae0e94f. It will automatically update as commits are pushed.</sup> <!-- ELLIPSIS_HIDDEN -->
1 parent 80ba612 commit be1119f

File tree

35 files changed

+693
-99
lines changed

35 files changed

+693
-99
lines changed

engine/baml-runtime/src/internal/llm_client/primitive/anthropic/response_handler.rs

Lines changed: 59 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
use anyhow::Result;
22
use baml_types::BamlMap;
33

4-
use super::types::{AnthropicMessageResponse, MessageChunk, StopReason};
5-
use crate::internal::llm_client::{primitive::request::RequestBuilder, traits::WithClient, ErrorCode, LLMCompleteResponse, LLMCompleteResponseMetadata, LLMErrorResponse, LLMResponse};
4+
use super::types::{AnthropicMessageContent, AnthropicMessageResponse, MessageChunk, StopReason};
5+
use crate::internal::llm_client::{
6+
primitive::request::RequestBuilder, traits::WithClient, ErrorCode, LLMCompleteResponse,
7+
LLMCompleteResponseMetadata, LLMErrorResponse, LLMResponse,
8+
};
69
use anyhow::Context;
710
use serde::Deserialize;
811
use serde_json::Value;
@@ -24,11 +27,13 @@ pub fn parse_anthropic_response<C: WithClient + RequestBuilder>(
2427
instant_now: web_time::Instant,
2528
model_name: Option<String>,
2629
) -> LLMResponse {
27-
let response = match AnthropicMessageResponse::deserialize(&response_body).context(format!(
28-
"Failed to parse into a response accepted by {}: {}",
29-
std::any::type_name::<AnthropicMessageResponse>(),
30-
response_body
31-
)).map_err(|e| LLMErrorResponse {
30+
let response = match AnthropicMessageResponse::deserialize(&response_body)
31+
.context(format!(
32+
"Failed to parse into a response accepted by {}: {}",
33+
std::any::type_name::<AnthropicMessageResponse>(),
34+
response_body
35+
))
36+
.map_err(|e| LLMErrorResponse {
3237
client: client.context().name.to_string(),
3338
model: model_name.clone(),
3439
prompt: to_prompt(prompt),
@@ -37,33 +42,41 @@ pub fn parse_anthropic_response<C: WithClient + RequestBuilder>(
3742
latency: instant_now.elapsed(),
3843
message: format!("{:?}", e),
3944
code: ErrorCode::Other(2),
40-
})
41-
{
45+
}) {
4246
Ok(response) => response,
4347
Err(e) => return LLMResponse::LLMFailure(e),
4448
};
4549

50+
println!("response: {:?}", response.content);
51+
52+
let content = response
53+
.content
54+
.iter()
55+
.filter_map(|v| match v {
56+
AnthropicMessageContent::Text { text } => Some(text),
57+
_ => None,
58+
})
59+
.next();
4660

47-
if response.content.len() != 1 {
61+
let content = if let Some(content) = content {
62+
content
63+
} else {
4864
return LLMResponse::LLMFailure(LLMErrorResponse {
4965
client: client.context().name.to_string(),
5066
model: model_name.clone(),
5167
prompt: to_prompt(prompt),
5268
start_time: system_now,
5369
request_options: client.request_options().clone(),
5470
latency: instant_now.elapsed(),
55-
message: format!(
56-
"Expected exactly one content block, got {}",
57-
response.content.len()
58-
),
59-
code: ErrorCode::Other(200),
71+
message: "Anthropic response contains no text".to_string(),
72+
code: ErrorCode::Other(2),
6073
});
61-
}
74+
};
6275

6376
LLMResponse::Success(LLMCompleteResponse {
6477
client: client.context().name.to_string(),
6578
prompt: to_prompt(prompt),
66-
content: response.content[0].text.clone(),
79+
content: content.to_string(),
6780
start_time: system_now,
6881
latency: instant_now.elapsed(),
6982
request_options: client.request_options().clone(),
@@ -84,7 +97,6 @@ pub fn parse_anthropic_response<C: WithClient + RequestBuilder>(
8497
})
8598
}
8699

87-
88100
pub fn scan_anthropic_response_stream(
89101
client_name: &str,
90102
request_options: &BamlMap<String, serde_json::Value>,
@@ -98,14 +110,16 @@ pub fn scan_anthropic_response_stream(
98110
let inner = match accumulated {
99111
Ok(accumulated) => accumulated,
100112
// We'll just keep the first error and return it
101-
Err(e) => return Ok(())
113+
Err(e) => return Ok(()),
102114
};
103115

104-
let event = match MessageChunk::deserialize(&event_body).context(format!(
105-
"Failed to parse into a response accepted by {}: {}",
106-
std::any::type_name::<MessageChunk>(),
107-
event_body
108-
)).map_err(|e| LLMErrorResponse {
116+
let event = match MessageChunk::deserialize(&event_body)
117+
.context(format!(
118+
"Failed to parse into a response accepted by {}: {}",
119+
std::any::type_name::<MessageChunk>(),
120+
event_body
121+
))
122+
.map_err(|e| LLMErrorResponse {
109123
client: client_name.to_string(),
110124
model: model_name.clone(),
111125
prompt: prompt.clone(),
@@ -114,8 +128,7 @@ pub fn scan_anthropic_response_stream(
114128
latency: instant_now.elapsed(),
115129
message: format!("{:?}", e),
116130
code: ErrorCode::Other(2),
117-
})
118-
{
131+
}) {
119132
Ok(response) => response,
120133
Err(e) => return Err(e),
121134
};
@@ -128,15 +141,18 @@ pub fn scan_anthropic_response_stream(
128141
body.stop_reason,
129142
Some(StopReason::StopSequence) | Some(StopReason::EndTurn)
130143
);
131-
inner.finish_reason =
132-
body.stop_reason.as_ref().map(ToString::to_string);
144+
inner.finish_reason = body.stop_reason.as_ref().map(ToString::to_string);
133145
inner.prompt_tokens = Some(body.usage.input_tokens);
134146
inner.output_tokens = Some(body.usage.output_tokens);
135-
inner.total_tokens =
136-
Some(body.usage.input_tokens + body.usage.output_tokens);
147+
inner.total_tokens = Some(body.usage.input_tokens + body.usage.output_tokens);
137148
}
138149
MessageChunk::ContentBlockDelta(event) => {
139-
inner.content += &event.delta.text;
150+
match event.delta {
151+
super::types::ContentBlockDelta::TextDelta { text } => {
152+
inner.content += &text;
153+
}
154+
_ => (),
155+
}
140156
}
141157
MessageChunk::ContentBlockStart(_) => (),
142158
MessageChunk::ContentBlockStop(_) => (),
@@ -154,25 +170,22 @@ pub fn scan_anthropic_response_stream(
154170
.as_ref()
155171
.map(|r| serde_json::to_string(r).unwrap_or("".into()));
156172
inner.output_tokens = Some(body.usage.output_tokens);
157-
inner.total_tokens = Some(
158-
inner.prompt_tokens.unwrap_or(0) + body.usage.output_tokens,
159-
);
173+
inner.total_tokens = Some(inner.prompt_tokens.unwrap_or(0) + body.usage.output_tokens);
160174
}
161175
MessageChunk::MessageStop => (),
162176
MessageChunk::Error(err) => {
163-
return Err(
164-
LLMErrorResponse {
165-
client: client_name.to_string(),
166-
model: model_name.clone(),
167-
prompt: prompt.clone(),
168-
request_options: request_options.clone(),
169-
start_time: system_now.clone(),
170-
latency: instant_now.elapsed(),
171-
message: err.message,
172-
code: ErrorCode::Other(2),
173-
}
174-
);
177+
return Err(LLMErrorResponse {
178+
client: client_name.to_string(),
179+
model: model_name.clone(),
180+
prompt: prompt.clone(),
181+
request_options: request_options.clone(),
182+
start_time: system_now.clone(),
183+
latency: instant_now.elapsed(),
184+
message: err.message,
185+
code: ErrorCode::Other(2),
186+
});
175187
}
188+
MessageChunk::Other => (),
176189
};
177190

178191
inner.latency = instant_now.elapsed();

engine/baml-runtime/src/internal/llm_client/primitive/anthropic/types.rs

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,30 @@ use serde::{Deserialize, Serialize};
22

33
// https://docs.anthropic.com/claude/reference/messages_post
44
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
5-
pub struct AnthropicMessageContent {
6-
pub r#type: String,
7-
pub text: String,
5+
#[serde(tag = "type", rename_all = "snake_case")]
6+
pub enum AnthropicMessageContent {
7+
// type: text
8+
Text {
9+
text: String,
10+
},
11+
// // type: tool_use
12+
ToolUse {
13+
id: Option<String>,
14+
input: serde_json::Value,
15+
name: String,
16+
},
17+
// // type: thinking
18+
// Thinking {
19+
// signature: Option<String>,
20+
// thinking: String,
21+
// },
22+
// type: redacted_thinking
23+
RedactedThinking {
24+
data: String,
25+
},
26+
// fallback for unknown types
27+
#[serde(other)]
28+
Other,
829
}
930

1031
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
@@ -98,6 +119,9 @@ pub enum MessageChunk {
98119
/// Message stop chunk.
99120
MessageStop,
100121
Error(AnthropicErrorInner),
122+
/// Fallback for unknown types
123+
#[serde(other)]
124+
Other,
101125
}
102126

103127
/// The message start chunk.
@@ -122,7 +146,16 @@ pub struct ContentBlockDeltaChunk {
122146
/// The index.
123147
pub index: u32,
124148
/// The text delta content block.
125-
pub delta: TextDeltaContentBlock,
149+
pub delta: ContentBlockDelta,
150+
}
151+
152+
#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)]
153+
#[serde(tag = "type", rename_all = "snake_case")]
154+
pub enum ContentBlockDelta {
155+
TextDelta { text: String },
156+
SignatureDelta { signature: String },
157+
ThinkingDelta { thinking: String },
158+
Other,
126159
}
127160

128161
/// The content block stop chunk.
@@ -178,8 +211,7 @@ mod tests {
178211

179212
let chunk = MessageChunk::ContentBlockDelta(ContentBlockDeltaChunk {
180213
index: 0,
181-
delta: TextDeltaContentBlock {
182-
_type: ContentType::TextDelta,
214+
delta: ContentBlockDelta::TextDelta {
183215
text: "Hello".to_string(),
184216
},
185217
});

engine/baml-runtime/src/types/stream.rs

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,6 @@ pub struct FunctionResultStream {
3737
pub(crate) collectors: Vec<Arc<Collector>>,
3838
}
3939

40-
impl Drop for FunctionResultStream {
41-
fn drop(&mut self) {
42-
log::info!("Dropping FunctionResultStream: {}", self.function_name);
43-
}
44-
}
45-
4640
#[cfg(target_arch = "wasm32")]
4741
// JsFuture is !Send, so when building for WASM, we have to drop that requirement from StreamCallback
4842
static_assertions::assert_impl_all!(FunctionResultStream: Send);

integ-tests/baml_src/clients.baml

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,20 @@ client<llm> Sonnet {
265265
}
266266
}
267267

268+
269+
client<llm> SonnetThinking {
270+
provider anthropic
271+
options {
272+
model "claude-3-7-sonnet-20250219"
273+
api_key env.ANTHROPIC_API_KEY
274+
max_tokens 2048
275+
thinking {
276+
type "enabled"
277+
budget_tokens 1024
278+
}
279+
}
280+
}
281+
268282
client<llm> Claude {
269283
provider anthropic
270284
options {

integ-tests/baml_src/test-files/providers/anthropic.baml

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,22 @@ function TestCaching(input: string, not_cached: string) -> string {
2525
{{ _.role('user') }}
2626
{{ not_cached }}
2727
"#
28+
}
29+
30+
class CustomStory {
31+
title string
32+
characters string[]
33+
content string
34+
}
35+
36+
function TestThinking(input: string) -> CustomStory {
37+
client SonnetThinking
38+
prompt #"
39+
{{ _.role('system') }}
40+
Generate the following story
41+
{{ ctx.output_format }}
42+
43+
{{ _.role('user') }}
44+
{{ input }}
45+
"#
2846
}

integ-tests/openapi/baml_client/openapi.yaml

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2146,6 +2146,19 @@ paths:
21462146
title: TestSingleFallbackClientResponse
21472147
type: string
21482148
operationId: TestSingleFallbackClient
2149+
/call/TestThinking:
2150+
post:
2151+
requestBody:
2152+
$ref: '#/components/requestBodies/TestThinking'
2153+
responses:
2154+
'200':
2155+
description: Successful operation
2156+
content:
2157+
application/json:
2158+
schema:
2159+
title: TestThinkingResponse
2160+
$ref: '#/components/schemas/CustomStory'
2161+
operationId: TestThinking
21492162
/call/TestUniverseQuestion:
21502163
post:
21512164
requestBody:
@@ -4797,6 +4810,22 @@ components:
47974810
$ref: '#/components/schemas/BamlOptions'
47984811
required: []
47994812
additionalProperties: false
4813+
TestThinking:
4814+
required: true
4815+
content:
4816+
application/json:
4817+
schema:
4818+
title: TestThinkingRequest
4819+
type: object
4820+
properties:
4821+
input:
4822+
type: string
4823+
__baml_options__:
4824+
nullable: true
4825+
$ref: '#/components/schemas/BamlOptions'
4826+
required:
4827+
- input
4828+
additionalProperties: false
48004829
TestUniverseQuestion:
48014830
required: true
48024831
content:
@@ -5239,6 +5268,22 @@ components:
52395268
required:
52405269
- primary
52415270
additionalProperties: false
5271+
CustomStory:
5272+
type: object
5273+
properties:
5274+
characters:
5275+
type: array
5276+
items:
5277+
type: string
5278+
content:
5279+
type: string
5280+
title:
5281+
type: string
5282+
required:
5283+
- characters
5284+
- content
5285+
- title
5286+
additionalProperties: false
52425287
CustomTaskResult:
52435288
type: object
52445289
properties:

0 commit comments

Comments
 (0)