Skip to content

Commit e7c45c2

Browse files
authored
bedrock collector (#1703)
- **Add bedrock http request and response** - **update tests** <!-- ELLIPSIS_HIDDEN --> ---- > [!IMPORTANT] > Introduces Bedrock HTTP request/response handling with `CollectorInterceptor` for tracing and updates AWS client tests for enhanced error handling and dynamic configuration. > > - **Bedrock HTTP Request and Response**: > - Added `CollectorInterceptor` in `aws_client.rs` to log HTTP requests and responses to `BAML_TRACER`. > - Implemented `smithy_json_headers()` to convert headers to JSON format. > - Updated `client_anyhow()` to use `CollectorInterceptor` for tracing. > - **Testing**: > - Added `TestOpenAIDummyClient` in `dummy-clients.baml` and corresponding test cases in `test_client_response.py`. > - Enhanced AWS tests in `aws.test.ts` to cover various scenarios including invalid credentials and dynamic client configuration. > - **Dependencies**: > - Added `futures` and `reqwest` to `Cargo.lock` and `Cargo.toml` for async operations and HTTP requests. > > <sup>This description was created by </sup>[<img alt="Ellipsis" src="https://img.shields.io/badge/Ellipsis-blue?color=175173">](https://www.ellipsis.dev?ref=BoundaryML%2Fbaml&utm_source=github&utm_medium=referral)<sup> for e9b304f. It will automatically update as commits are pushed.</sup> <!-- ELLIPSIS_HIDDEN -->
1 parent f945204 commit e7c45c2

36 files changed

Lines changed: 1361 additions & 35 deletions

File tree

engine/Cargo.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

engine/baml-runtime/src/internal/llm_client/primitive/aws/aws_client.rs

Lines changed: 147 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,35 @@
11
use std::collections::HashMap;
2+
use std::sync::Arc;
23

34
use aws_config::Region;
45
use aws_config::{identity::IdentityCache, retry::RetryConfig, BehaviorVersion, ConfigLoader};
56
use aws_credential_types::Credentials;
7+
use aws_sdk_bedrockruntime::config::Intercept;
8+
use aws_sdk_bedrockruntime::Client as BedrockRuntimeClient;
69
use aws_sdk_bedrockruntime::{self as bedrock, operation::converse::ConverseOutput};
710

811
use anyhow::{Context, Result};
912
use aws_smithy_json::serialize::JsonObjectWriter;
1013
use aws_smithy_runtime_api::client::result::SdkError;
14+
use aws_smithy_runtime_api::http::Headers;
1115
use aws_smithy_types::Blob;
12-
use baml_types::tracing::events::HttpRequestId;
16+
use baml_types::tracing::events::{
17+
ContentId, FunctionId, HTTPBody, HTTPRequest, HTTPResponse, HttpRequestId, TraceData,
18+
TraceEvent, TraceLevel,
19+
};
1320
use baml_types::{BamlMap, BamlMediaContent};
1421
use baml_types::{BamlMedia, BamlMediaType};
1522
use futures::stream;
1623
use internal_baml_core::ir::ClientWalker;
1724
use internal_baml_jinja::{ChatMessagePart, RenderContext_Client, RenderedChatMessage};
18-
use internal_llm_client::aws_bedrock::ResolvedAwsBedrock;
25+
use internal_llm_client::aws_bedrock::{self, ResolvedAwsBedrock};
1926
use internal_llm_client::{
2027
AllowedRoleMetadata, ClientProvider, ResolvedClientProperty, UnresolvedClientProperty,
2128
};
2229
use secrecy::ExposeSecret;
2330
use serde::Deserialize;
2431
use serde_json::{json, Map};
32+
use uuid::Uuid;
2533
use web_time::Instant;
2634
use web_time::SystemTime;
2735

@@ -36,7 +44,8 @@ use crate::internal::llm_client::{
3644
ErrorCode, LLMCompleteResponse, LLMCompleteResponseMetadata, LLMErrorResponse, LLMResponse,
3745
ModelFeatures, ResolveMediaUrls,
3846
};
39-
use crate::{RenderCurlSettings, RuntimeContext};
47+
use crate::tracingv2::storage::storage::BAML_TRACER;
48+
use crate::{json_body, JsonBodyInput, RenderCurlSettings, RuntimeContext};
4049

4150
// represents client that interacts with the Bedrock API
4251
pub struct AwsClient {
@@ -73,6 +82,121 @@ fn resolve_properties(
7382
Ok(props)
7483
}
7584

85+
#[derive(Debug)]
86+
struct CollectorInterceptor {
87+
span_id: Option<Uuid>,
88+
http_request_id: HttpRequestId,
89+
}
90+
91+
impl CollectorInterceptor {
92+
fn new(span_id: Option<Uuid>, http_request_id: HttpRequestId) -> Self {
93+
Self {
94+
span_id,
95+
http_request_id,
96+
}
97+
}
98+
}
99+
100+
pub fn smithy_json_headers(headers: &Headers) -> serde_json::Value {
101+
let mut json_headers = serde_json::Map::new();
102+
for (key, value) in headers.iter() {
103+
json_headers.insert(key.to_string(), json!(value));
104+
}
105+
serde_json::Value::Object(json_headers)
106+
}
107+
108+
impl aws_smithy_runtime_api::client::interceptors::Intercept for CollectorInterceptor {
109+
fn name(&self) -> &'static str {
110+
"CollectorInterceptor"
111+
}
112+
113+
fn read_before_attempt(
114+
&self,
115+
context: &aws_sdk_bedrockruntime::config::interceptors::BeforeTransmitInterceptorContextRef<
116+
'_,
117+
>,
118+
_runtime_components: &aws_sdk_bedrockruntime::config::RuntimeComponents,
119+
_cfg: &mut aws_smithy_types::config_bag::ConfigBag,
120+
) -> std::result::Result<(), aws_sdk_bedrockruntime::error::BoxError> {
121+
if let Some(span_id) = self.span_id.clone() {
122+
let request = context.request();
123+
let headers = smithy_json_headers(request.headers());
124+
let body = if let Some(bytes) = request.body().bytes() {
125+
json_body(JsonBodyInput::Bytes(bytes)).unwrap_or_default()
126+
} else {
127+
serde_json::Value::Null
128+
};
129+
130+
BAML_TRACER.lock().unwrap().put(Arc::new(TraceEvent {
131+
span_id: FunctionId(span_id.to_string()),
132+
event_id: ContentId(uuid::Uuid::new_v4().to_string()),
133+
span_chain: vec![],
134+
timestamp: web_time::SystemTime::now(),
135+
callsite: "".to_string(),
136+
verbosity: TraceLevel::Info,
137+
content: TraceData::RawLLMRequest(Arc::new(HTTPRequest {
138+
id: self.http_request_id.clone(),
139+
url: request.uri().to_string(),
140+
method: request.method().to_string(),
141+
headers,
142+
body: HTTPBody::new(request.body().bytes().unwrap_or_default().to_vec().into()),
143+
})),
144+
tags: Default::default(),
145+
}));
146+
}
147+
148+
Ok(())
149+
}
150+
151+
fn read_after_attempt(
152+
&self,
153+
context: &aws_sdk_bedrockruntime::config::interceptors::FinalizerInterceptorContextRef<'_>,
154+
_runtime_components: &aws_sdk_bedrockruntime::config::RuntimeComponents,
155+
_cfg: &mut aws_smithy_types::config_bag::ConfigBag,
156+
) -> std::result::Result<(), aws_sdk_bedrockruntime::error::BoxError> {
157+
if let Some(span_id) = self.span_id.clone() {
158+
let trace_level = if let Some(response) = context.response() {
159+
if response.status().is_success() {
160+
TraceLevel::Info
161+
} else {
162+
TraceLevel::Error
163+
}
164+
} else {
165+
TraceLevel::Error
166+
};
167+
168+
if let Some(response) = context.response() {
169+
let headers = smithy_json_headers(response.headers());
170+
let body = if let Some(bytes) = response.body().bytes() {
171+
json_body(JsonBodyInput::Bytes(bytes)).unwrap_or_default()
172+
} else {
173+
serde_json::Value::Null
174+
};
175+
176+
BAML_TRACER.lock().unwrap().put(Arc::new(TraceEvent {
177+
span_id: FunctionId(span_id.to_string()),
178+
event_id: ContentId(uuid::Uuid::new_v4().to_string()),
179+
span_chain: vec![],
180+
timestamp: web_time::SystemTime::now(),
181+
callsite: "".to_string(),
182+
verbosity: trace_level,
183+
content: TraceData::RawLLMResponse(Arc::new(HTTPResponse {
184+
request_id: self.http_request_id.clone(),
185+
status: response.status().as_u16(),
186+
headers,
187+
body: HTTPBody::new(
188+
response.body().bytes().unwrap_or_default().to_vec().into(),
189+
),
190+
})),
191+
tags: Default::default(),
192+
}));
193+
}
194+
}
195+
196+
Ok(())
197+
}
198+
}
199+
76200
impl AwsClient {
77201
pub fn dynamic_new(client: &ClientProperty, ctx: &RuntimeContext) -> Result<AwsClient> {
78202
let properties = resolve_properties(&client.provider, &client.unresolved_options()?, ctx)?;
@@ -136,7 +260,11 @@ impl AwsClient {
136260
// Note: This function necessarily exposes secret keys when they are provided, so it should
137261
// only be called while generating real requests to the provider, not when rendering raw
138262
// cURL previews.
139-
async fn client_anyhow(&self) -> Result<bedrock::Client> {
263+
async fn client_anyhow(
264+
&self,
265+
span_id: Option<Uuid>,
266+
http_request_id: &HttpRequestId,
267+
) -> Result<bedrock::Client> {
140268
#[cfg(target_arch = "wasm32")]
141269
let mut loader = super::wasm::load_aws_config();
142270
#[cfg(not(target_arch = "wasm32"))]
@@ -221,7 +349,11 @@ impl AwsClient {
221349
}
222350

223351
let config = loader.load().await;
224-
Ok(bedrock::Client::new(&config))
352+
353+
let bedrock_config = aws_sdk_bedrockruntime::config::Builder::from(&config)
354+
.interceptor(CollectorInterceptor::new(span_id, http_request_id.clone()))
355+
.build();
356+
Ok(BedrockRuntimeClient::from_conf(bedrock_config))
225357
}
226358

227359
async fn chat_anyhow<'r>(&self, response: &'r ConverseOutput) -> Result<&'r String> {
@@ -388,7 +520,10 @@ impl WithStreamChat for AwsClient {
388520
let request_options = Default::default();
389521
let prompt = internal_baml_jinja::RenderedPrompt::Chat(chat_messages.to_vec());
390522

391-
let aws_client = match self.client_anyhow().await {
523+
let aws_client = match self
524+
.client_anyhow(ctx.span_id.clone(), &http_request_id)
525+
.await
526+
{
392527
Ok(c) => c,
393528
Err(e) => {
394529
return Err(LLMResponse::LLMFailure(LLMErrorResponse {
@@ -671,7 +806,7 @@ impl AwsClient {
671806
impl WithChat for AwsClient {
672807
async fn chat(
673808
&self,
674-
_ctx: &RuntimeContext,
809+
ctx: &RuntimeContext,
675810
chat_messages: &[RenderedChatMessage],
676811
http_request_id: HttpRequestId,
677812
) -> LLMResponse {
@@ -681,7 +816,10 @@ impl WithChat for AwsClient {
681816
let request_options = Default::default();
682817
let prompt = internal_baml_jinja::RenderedPrompt::Chat(chat_messages.to_vec());
683818

684-
let aws_client = match self.client_anyhow().await {
819+
let aws_client = match self
820+
.client_anyhow(ctx.span_id.clone(), &http_request_id)
821+
.await
822+
{
685823
Ok(c) => c,
686824
Err(e) => {
687825
return LLMResponse::LLMFailure(LLMErrorResponse {
@@ -697,7 +835,7 @@ impl WithChat for AwsClient {
697835
}
698836
};
699837

700-
let request = match self.build_request(_ctx, chat_messages) {
838+
let request = match self.build_request(ctx, chat_messages) {
701839
Ok(r) => r,
702840
Err(e) => {
703841
return LLMResponse::LLMFailure(LLMErrorResponse {

engine/baml-schema-wasm/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@ crate-type = ["cdylib", "rlib"]
1414
elided_named_lifetimes = "deny"
1515

1616
[dependencies]
17+
reqwest.workspace = true
1718
anyhow.workspace = true
19+
futures.workspace = true
1820
baml-runtime = { path = "../baml-runtime", features = [
1921
"internal",
2022
], default-features = false }

engine/language_client_typescript/package.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,11 @@
6464
"scripts": {
6565
"artifacts": "napi artifacts",
6666
"build": "pnpm build:napi-release && pnpm build:ts_build",
67-
"build:debug": "pnpm build:napi-debug && pnpm build:ts_build && pnpm napi create-npm-dirs && pnpm artifacts",
67+
"build:debug": "pnpm build:napi-debug && pnpm build:ts_build",
6868
"build:napi-release": "pnpm build:napi-debug --release",
6969
"build:napi-debug": "napi build --js ./native.js --dts ./native.d.ts --platform",
7070
"build:ts_build": "tsc ./typescript_src/*.ts --outDir ./ --module nodenext --module nodenext --allowJs --declaration true --declarationMap true || true && pnpm build:ts_build_local",
71+
"build:local": "pnpm build:debug && pnpm napi create-npm-dirs && pnpm artifacts",
7172
"build:napi-debug-local": "napi build -o ./artifacts --js ./native.js --dts ./native.d.ts --platform",
7273
"build:ts_build_local": "tsc ./typescript_src/*.ts --outDir ./artifacts --module nodenext --module nodenext --allowJs --declaration true --declarationMap true || true",
7374
"format": "run-p format:biome format:rs format:toml",
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
client OpenAIDummyClient {
2+
provider openai-generic
3+
options {
4+
api_key env.OPENAI_API_KEY
5+
model "gpt-4o-mini"
6+
base_url "http://localhost:8000"
7+
}
8+
}
9+
10+
function TestOpenAIDummyClient(input: string) -> string {
11+
client OpenAIDummyClient
12+
prompt #"
13+
{{ _.role("user") }}
14+
{{ input }}
15+
"#
16+
}

integ-tests/openapi/baml_client/openapi.yaml

Lines changed: 29 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)