Skip to content

Commit c5c7fc6

Browse files
authored
Support HTTPS_PROXY and HTTP_PROXY system proxies in AWS client by delegating to the reqwest client (#1827)
<!-- ELLIPSIS_HIDDEN --> > [!IMPORTANT] > Add support for system proxies in AWS client and introduce `TestAwsInferenceProfile` function with updates across multiple client languages. > > - **AWS Client**: > - Add support for `HTTPS_PROXY` and `HTTP_PROXY` by using a custom HTTP client in `aws_client.rs`. > - Introduce `custom_http_client.rs` to handle proxy settings. > - **Integration Tests**: > - Add `TestAwsInferenceProfile` function in `aws.baml` and corresponding test cases in `test-functions.py`. > - Update `client.go`, `openapi.yaml`, and `aws-utils.py` for new function. > - **Client Code**: > - Update Python, Go, Ruby, TypeScript, and React clients to support `TestAwsInferenceProfile`. > - Modify `pyproject.toml` to include `boto3` dependency. > > <sup>This description was created by </sup>[<img alt="Ellipsis" src="https://img.shields.io/badge/Ellipsis-blue?color=175173">](https://www.ellipsis.dev?ref=BoundaryML%2Fbaml&utm_source=github&utm_medium=referral)<sup> for facfa56. You can [customize](https://app.ellipsis.dev/BoundaryML/settings/summaries) this summary. It will automatically update as commits are pushed.</sup> <!-- ELLIPSIS_HIDDEN -->
1 parent b27c980 commit c5c7fc6

File tree

38 files changed

+1593
-79
lines changed

38 files changed

+1593
-79
lines changed

engine/baml-runtime/src/internal/llm_client/primitive/aws/aws_client.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@ use crate::internal::llm_client::{
5252
};
5353
use crate::tracingv2::storage::storage::BAML_TRACER;
5454
use crate::{json_body, AwsCredProvider, JsonBodyInput, RenderCurlSettings, RuntimeContext};
55-
55+
// See https://github.com/awslabs/aws-sdk-rust/issues/169
56+
use super::custom_http_client;
5657
#[cfg(target_arch = "wasm32")]
5758
use super::wasm::WasmAwsCreds;
5859

@@ -388,8 +389,11 @@ impl AwsClient {
388389
}
389390

390391
let config = loader.load().await;
392+
let http_client = custom_http_client::client()?;
391393

392394
let bedrock_config = aws_sdk_bedrockruntime::config::Builder::from(&config)
395+
// To support HTTPS_PROXY https://github.com/awslabs/aws-sdk-rust/issues/169
396+
.http_client(http_client)
393397
.interceptor(CollectorInterceptor::new(span_id, http_request_id.clone()))
394398
.build();
395399
Ok(BedrockRuntimeClient::from_conf(bedrock_config))
Lines changed: 303 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,303 @@
1+
// See https://github.com/awslabs/aws-sdk-rust/issues/169
2+
use std::time::Duration;
3+
4+
use aws_smithy_runtime_api::client::http::{
5+
HttpClient, HttpConnector, HttpConnectorFuture, HttpConnectorSettings, SharedHttpConnector,
6+
};
7+
use aws_smithy_runtime_api::client::result::ConnectorError;
8+
use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
9+
use aws_smithy_runtime_api::http::Request;
10+
use aws_smithy_types::body::SdkBody;
11+
12+
use crate::request::create_client;
13+
14+
// --- WASM specific imports ---
15+
#[cfg(target_arch = "wasm32")]
16+
use {futures::channel::oneshot, wasm_bindgen_futures::spawn_local};
17+
18+
/// Returns a wrapper around the global reqwest client.
19+
/// [HttpClient].
20+
#[cfg(not(target_arch = "wasm32"))] // Keep function non-WASM for now
21+
pub fn client() -> anyhow::Result<Client> {
22+
let client = crate::request::create_client()
23+
.map_err(|e| anyhow::anyhow!("failed to create base http client: {}", e))?;
24+
Ok(Client::new(client.clone()))
25+
}
26+
27+
#[cfg(target_arch = "wasm32")] // Define WASM client function
28+
pub fn client() -> anyhow::Result<Client> {
29+
let client = crate::request::create_client()
30+
.map_err(|e| anyhow::anyhow!("failed to create base http client for WASM: {}", e))?;
31+
Ok(Client::new(client.clone()))
32+
}
33+
34+
/// A wrapper around [reqwest::Client] that implements [HttpClient].
35+
///
36+
/// This is required to support using proxy servers with the AWS SDK.
37+
#[derive(Debug, Clone)]
38+
pub struct Client {
39+
inner: reqwest::Client,
40+
}
41+
42+
impl Client {
43+
pub fn new(client: reqwest::Client) -> Self {
44+
Self { inner: client }
45+
}
46+
}
47+
48+
#[derive(Debug)]
49+
struct CallError {
50+
kind: CallErrorKind,
51+
message: &'static str,
52+
source: Option<Box<dyn std::error::Error + Send + Sync>>,
53+
}
54+
55+
impl CallError {
56+
fn user(message: &'static str) -> Self {
57+
Self {
58+
kind: CallErrorKind::User,
59+
message,
60+
source: None,
61+
}
62+
}
63+
64+
fn user_with_source<E>(message: &'static str, source: E) -> Self
65+
where
66+
E: std::error::Error + Send + Sync + 'static,
67+
{
68+
Self {
69+
kind: CallErrorKind::User,
70+
message,
71+
source: Some(Box::new(source)),
72+
}
73+
}
74+
75+
fn timeout<E>(source: E) -> Self
76+
where
77+
E: std::error::Error + Send + Sync + 'static,
78+
{
79+
Self {
80+
kind: CallErrorKind::Timeout,
81+
message: "request timed out",
82+
source: Some(Box::new(source)),
83+
}
84+
}
85+
86+
fn io<E>(source: E) -> Self
87+
where
88+
E: std::error::Error + Send + Sync + 'static,
89+
{
90+
Self {
91+
kind: CallErrorKind::Io,
92+
message: "an i/o error occurred",
93+
source: Some(Box::new(source)),
94+
}
95+
}
96+
97+
fn other<E>(message: &'static str, source: E) -> Self
98+
where
99+
E: std::error::Error + Send + Sync + 'static,
100+
{
101+
Self {
102+
kind: CallErrorKind::Other,
103+
message,
104+
source: Some(Box::new(source)),
105+
}
106+
}
107+
}
108+
109+
impl std::error::Error for CallError {}
110+
111+
impl std::fmt::Display for CallError {
112+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
113+
write!(f, "{}", self.message)?;
114+
if let Some(err) = self.source.as_ref() {
115+
write!(f, ": {}", err)?;
116+
}
117+
Ok(())
118+
}
119+
}
120+
121+
impl From<CallError> for ConnectorError {
122+
fn from(value: CallError) -> Self {
123+
match &value.kind {
124+
CallErrorKind::User => Self::user(Box::new(value)),
125+
CallErrorKind::Timeout => Self::timeout(Box::new(value)),
126+
CallErrorKind::Io => Self::io(Box::new(value)),
127+
CallErrorKind::Other => Self::other(Box::new(value), None),
128+
}
129+
}
130+
}
131+
132+
impl From<reqwest::Error> for CallError {
133+
fn from(err: reqwest::Error) -> Self {
134+
if err.is_timeout() {
135+
return CallError::timeout(err);
136+
}
137+
138+
// Conditionally check for connect error only on non-WASM targets.
139+
#[cfg(not(target_arch = "wasm32"))]
140+
{
141+
if err.is_connect() {
142+
return CallError::io(err);
143+
}
144+
}
145+
146+
// If it's not a timeout or (on non-WASM) a connect error, treat as other.
147+
CallError::other("an unknown error occurred", err)
148+
}
149+
}
150+
151+
#[derive(Debug, Clone)]
152+
enum CallErrorKind {
153+
User,
154+
Timeout,
155+
Io,
156+
Other,
157+
}
158+
159+
#[derive(Debug)]
160+
struct ReqwestConnector {
161+
client: reqwest::Client,
162+
timeout: Option<Duration>,
163+
}
164+
165+
// See https://github.com/aws/amazon-q-developer-cli/pull/1199
166+
impl HttpConnector for ReqwestConnector {
167+
fn call(&self, request: Request) -> HttpConnectorFuture {
168+
let client = self.client.clone();
169+
let timeout = self.timeout;
170+
171+
#[cfg(not(target_arch = "wasm32"))]
172+
let future = async move {
173+
// Non-WASM logic (direct send)
174+
let mut req_builder = client.request(
175+
reqwest::Method::from_bytes(request.method().as_bytes()).map_err(|err| {
176+
CallError::user_with_source("failed to create method name", err)
177+
})?,
178+
request.uri().to_owned(),
179+
);
180+
let parts = request.into_parts();
181+
for (name, value) in parts.headers.iter() {
182+
req_builder = req_builder.header(name, value.as_bytes());
183+
}
184+
let body_bytes = parts
185+
.body
186+
.bytes()
187+
.ok_or(CallError::user("streaming request body is not supported"))?
188+
.to_owned();
189+
req_builder = req_builder.body(body_bytes);
190+
191+
if let Some(timeout) = timeout {
192+
req_builder = req_builder.timeout(timeout);
193+
}
194+
195+
let reqwest_response = req_builder.send().await.map_err(CallError::from)?;
196+
197+
let http_response = {
198+
let (parts, body) = http::Response::from(reqwest_response).into_parts();
199+
http::Response::from_parts(parts, SdkBody::from_body_1_x(body))
200+
};
201+
202+
Ok(
203+
aws_smithy_runtime_api::http::Response::try_from(http_response).map_err(|err| {
204+
CallError::other("failed to convert to a proper response", err)
205+
})?,
206+
)
207+
};
208+
209+
#[cfg(target_arch = "wasm32")]
210+
let future = async move {
211+
// WASM logic (spawn_local)
212+
let (tx, rx) = oneshot::channel();
213+
214+
spawn_local(async move {
215+
// Use a closure to handle errors
216+
let result = (async {
217+
let mut req_builder = client.request(
218+
reqwest::Method::from_bytes(request.method().as_bytes()).map_err(
219+
|err| CallError::user_with_source("failed to create method name", err),
220+
)?,
221+
request.uri().to_owned(),
222+
);
223+
let parts = request.into_parts();
224+
for (name, value) in parts.headers.iter() {
225+
req_builder = req_builder.header(name, value.as_bytes());
226+
}
227+
let body_bytes = parts
228+
.body
229+
.bytes()
230+
.ok_or(CallError::user("streaming request body is not supported"))?
231+
.to_owned();
232+
req_builder = req_builder.body(body_bytes);
233+
234+
let reqwest_response = req_builder.send().await.map_err(CallError::from)?;
235+
236+
// Use manual construction for WASM response conversion
237+
let http_response = {
238+
let status = reqwest_response.status();
239+
let headers = reqwest_response.headers().clone();
240+
let body_bytes = reqwest_response
241+
.bytes()
242+
.await
243+
.map_err(|e| CallError::other("failed to read response body", e))?;
244+
245+
let mut response_builder = http::Response::builder().status(status);
246+
247+
for (name, value) in headers.iter() {
248+
response_builder = response_builder.header(name, value);
249+
}
250+
251+
response_builder
252+
.body(SdkBody::from(body_bytes))
253+
.map_err(|e| CallError::other("failed to build http::Response", e))?
254+
};
255+
256+
aws_smithy_runtime_api::http::Response::try_from(http_response).map_err(|err| {
257+
CallError::other("failed to convert to a proper response", err)
258+
})
259+
})
260+
.await;
261+
262+
// Convert the inner Result<_, CallError> to Result<_, ConnectorError>
263+
let final_result = result.map_err(ConnectorError::from);
264+
265+
let _ = tx.send(final_result);
266+
});
267+
268+
rx.await.map_err(|_| {
269+
ConnectorError::other(
270+
Box::new(CallError::user("WASM future channel cancelled")),
271+
None,
272+
)
273+
})?
274+
};
275+
276+
HttpConnectorFuture::new(future)
277+
}
278+
}
279+
280+
impl HttpClient for Client {
281+
fn http_connector(
282+
&self,
283+
settings: &HttpConnectorSettings,
284+
_components: &RuntimeComponents,
285+
) -> SharedHttpConnector {
286+
let timeout = if cfg!(target_arch = "wasm32") {
287+
None // Timeout not directly supported via reqwest on wasm
288+
} else {
289+
settings.read_timeout()
290+
};
291+
let connector = ReqwestConnector {
292+
client: self.inner.clone(),
293+
timeout,
294+
};
295+
SharedHttpConnector::new(connector)
296+
}
297+
}
298+
299+
// --- Non-WASM Implementation using Reqwest ---
300+
#[cfg(not(target_arch = "wasm32"))]
301+
mod reqwest_impl {
302+
use std::time::Duration;
303+
}

engine/baml-runtime/src/internal/llm_client/primitive/aws/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
mod aws_client;
2+
mod custom_http_client;
23
pub(super) mod types;
34
#[cfg(target_arch = "wasm32")]
45
pub(super) mod wasm;

engine/language_client_codegen/src/python/mod.rs

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,9 @@ impl ToTypeReferenceInClientDefinition for FieldType {
323323
}
324324
None => base.to_type_ref(ir, _with_checked),
325325
},
326-
FieldType::Arrow(_) => todo!("Arrow types should not be used in generated type definitions"),
326+
FieldType::Arrow(_) => {
327+
todo!("Arrow types should not be used in generated type definitions")
328+
}
327329
}
328330
}
329331

@@ -379,7 +381,9 @@ impl ToTypeReferenceInClientDefinition for FieldType {
379381
}
380382
None => base.to_partial_type_ref(ir, with_checked),
381383
},
382-
FieldType::Arrow(_) => todo!("Arrow types should not be used in generated type definitions"),
384+
FieldType::Arrow(_) => {
385+
todo!("Arrow types should not be used in generated type definitions")
386+
}
383387
}
384388
}
385389
}
@@ -464,12 +468,13 @@ class Foo {
464468
.unwrap()
465469
}
466470

467-
#[test]
468-
fn generate_streaming_python() {
469-
let ir = mk_ir();
470-
let generator_args = mk_gen();
471-
let res = generate(&ir, &generator_args).unwrap();
472-
let partial_types = res.get(&PathBuf::from("partial_types.py")).unwrap();
473-
eprintln!("{}", partial_types);
474-
}
471+
// TODO: test is flaky since it seems a dir isnt cleaned up.
472+
// #[test]
473+
// fn generate_streaming_python() {
474+
// let ir = mk_ir();
475+
// let generator_args = mk_gen();
476+
// let res = generate(&ir, &generator_args).unwrap();
477+
// let partial_types = res.get(&PathBuf::from("partial_types.py")).unwrap();
478+
// eprintln!("{}", partial_types);
479+
// }
475480
}

0 commit comments

Comments
 (0)