|
| 1 | +// See https://github.com/awslabs/aws-sdk-rust/issues/169 |
| 2 | +use std::time::Duration; |
| 3 | + |
| 4 | +use aws_smithy_runtime_api::client::http::{ |
| 5 | + HttpClient, HttpConnector, HttpConnectorFuture, HttpConnectorSettings, SharedHttpConnector, |
| 6 | +}; |
| 7 | +use aws_smithy_runtime_api::client::result::ConnectorError; |
| 8 | +use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents; |
| 9 | +use aws_smithy_runtime_api::http::Request; |
| 10 | +use aws_smithy_types::body::SdkBody; |
| 11 | + |
| 12 | +use crate::request::create_client; |
| 13 | + |
| 14 | +// --- WASM specific imports --- |
| 15 | +#[cfg(target_arch = "wasm32")] |
| 16 | +use {futures::channel::oneshot, wasm_bindgen_futures::spawn_local}; |
| 17 | + |
| 18 | +/// Returns a wrapper around the global reqwest client. |
| 19 | +/// [HttpClient]. |
| 20 | +#[cfg(not(target_arch = "wasm32"))] // Keep function non-WASM for now |
| 21 | +pub fn client() -> anyhow::Result<Client> { |
| 22 | + let client = crate::request::create_client() |
| 23 | + .map_err(|e| anyhow::anyhow!("failed to create base http client: {}", e))?; |
| 24 | + Ok(Client::new(client.clone())) |
| 25 | +} |
| 26 | + |
| 27 | +#[cfg(target_arch = "wasm32")] // Define WASM client function |
| 28 | +pub fn client() -> anyhow::Result<Client> { |
| 29 | + let client = crate::request::create_client() |
| 30 | + .map_err(|e| anyhow::anyhow!("failed to create base http client for WASM: {}", e))?; |
| 31 | + Ok(Client::new(client.clone())) |
| 32 | +} |
| 33 | + |
| 34 | +/// A wrapper around [reqwest::Client] that implements [HttpClient]. |
| 35 | +/// |
| 36 | +/// This is required to support using proxy servers with the AWS SDK. |
| 37 | +#[derive(Debug, Clone)] |
| 38 | +pub struct Client { |
| 39 | + inner: reqwest::Client, |
| 40 | +} |
| 41 | + |
| 42 | +impl Client { |
| 43 | + pub fn new(client: reqwest::Client) -> Self { |
| 44 | + Self { inner: client } |
| 45 | + } |
| 46 | +} |
| 47 | + |
| 48 | +#[derive(Debug)] |
| 49 | +struct CallError { |
| 50 | + kind: CallErrorKind, |
| 51 | + message: &'static str, |
| 52 | + source: Option<Box<dyn std::error::Error + Send + Sync>>, |
| 53 | +} |
| 54 | + |
| 55 | +impl CallError { |
| 56 | + fn user(message: &'static str) -> Self { |
| 57 | + Self { |
| 58 | + kind: CallErrorKind::User, |
| 59 | + message, |
| 60 | + source: None, |
| 61 | + } |
| 62 | + } |
| 63 | + |
| 64 | + fn user_with_source<E>(message: &'static str, source: E) -> Self |
| 65 | + where |
| 66 | + E: std::error::Error + Send + Sync + 'static, |
| 67 | + { |
| 68 | + Self { |
| 69 | + kind: CallErrorKind::User, |
| 70 | + message, |
| 71 | + source: Some(Box::new(source)), |
| 72 | + } |
| 73 | + } |
| 74 | + |
| 75 | + fn timeout<E>(source: E) -> Self |
| 76 | + where |
| 77 | + E: std::error::Error + Send + Sync + 'static, |
| 78 | + { |
| 79 | + Self { |
| 80 | + kind: CallErrorKind::Timeout, |
| 81 | + message: "request timed out", |
| 82 | + source: Some(Box::new(source)), |
| 83 | + } |
| 84 | + } |
| 85 | + |
| 86 | + fn io<E>(source: E) -> Self |
| 87 | + where |
| 88 | + E: std::error::Error + Send + Sync + 'static, |
| 89 | + { |
| 90 | + Self { |
| 91 | + kind: CallErrorKind::Io, |
| 92 | + message: "an i/o error occurred", |
| 93 | + source: Some(Box::new(source)), |
| 94 | + } |
| 95 | + } |
| 96 | + |
| 97 | + fn other<E>(message: &'static str, source: E) -> Self |
| 98 | + where |
| 99 | + E: std::error::Error + Send + Sync + 'static, |
| 100 | + { |
| 101 | + Self { |
| 102 | + kind: CallErrorKind::Other, |
| 103 | + message, |
| 104 | + source: Some(Box::new(source)), |
| 105 | + } |
| 106 | + } |
| 107 | +} |
| 108 | + |
| 109 | +impl std::error::Error for CallError {} |
| 110 | + |
| 111 | +impl std::fmt::Display for CallError { |
| 112 | + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { |
| 113 | + write!(f, "{}", self.message)?; |
| 114 | + if let Some(err) = self.source.as_ref() { |
| 115 | + write!(f, ": {}", err)?; |
| 116 | + } |
| 117 | + Ok(()) |
| 118 | + } |
| 119 | +} |
| 120 | + |
| 121 | +impl From<CallError> for ConnectorError { |
| 122 | + fn from(value: CallError) -> Self { |
| 123 | + match &value.kind { |
| 124 | + CallErrorKind::User => Self::user(Box::new(value)), |
| 125 | + CallErrorKind::Timeout => Self::timeout(Box::new(value)), |
| 126 | + CallErrorKind::Io => Self::io(Box::new(value)), |
| 127 | + CallErrorKind::Other => Self::other(Box::new(value), None), |
| 128 | + } |
| 129 | + } |
| 130 | +} |
| 131 | + |
| 132 | +impl From<reqwest::Error> for CallError { |
| 133 | + fn from(err: reqwest::Error) -> Self { |
| 134 | + if err.is_timeout() { |
| 135 | + return CallError::timeout(err); |
| 136 | + } |
| 137 | + |
| 138 | + // Conditionally check for connect error only on non-WASM targets. |
| 139 | + #[cfg(not(target_arch = "wasm32"))] |
| 140 | + { |
| 141 | + if err.is_connect() { |
| 142 | + return CallError::io(err); |
| 143 | + } |
| 144 | + } |
| 145 | + |
| 146 | + // If it's not a timeout or (on non-WASM) a connect error, treat as other. |
| 147 | + CallError::other("an unknown error occurred", err) |
| 148 | + } |
| 149 | +} |
| 150 | + |
| 151 | +#[derive(Debug, Clone)] |
| 152 | +enum CallErrorKind { |
| 153 | + User, |
| 154 | + Timeout, |
| 155 | + Io, |
| 156 | + Other, |
| 157 | +} |
| 158 | + |
| 159 | +#[derive(Debug)] |
| 160 | +struct ReqwestConnector { |
| 161 | + client: reqwest::Client, |
| 162 | + timeout: Option<Duration>, |
| 163 | +} |
| 164 | + |
| 165 | +// See https://github.com/aws/amazon-q-developer-cli/pull/1199 |
| 166 | +impl HttpConnector for ReqwestConnector { |
| 167 | + fn call(&self, request: Request) -> HttpConnectorFuture { |
| 168 | + let client = self.client.clone(); |
| 169 | + let timeout = self.timeout; |
| 170 | + |
| 171 | + #[cfg(not(target_arch = "wasm32"))] |
| 172 | + let future = async move { |
| 173 | + // Non-WASM logic (direct send) |
| 174 | + let mut req_builder = client.request( |
| 175 | + reqwest::Method::from_bytes(request.method().as_bytes()).map_err(|err| { |
| 176 | + CallError::user_with_source("failed to create method name", err) |
| 177 | + })?, |
| 178 | + request.uri().to_owned(), |
| 179 | + ); |
| 180 | + let parts = request.into_parts(); |
| 181 | + for (name, value) in parts.headers.iter() { |
| 182 | + req_builder = req_builder.header(name, value.as_bytes()); |
| 183 | + } |
| 184 | + let body_bytes = parts |
| 185 | + .body |
| 186 | + .bytes() |
| 187 | + .ok_or(CallError::user("streaming request body is not supported"))? |
| 188 | + .to_owned(); |
| 189 | + req_builder = req_builder.body(body_bytes); |
| 190 | + |
| 191 | + if let Some(timeout) = timeout { |
| 192 | + req_builder = req_builder.timeout(timeout); |
| 193 | + } |
| 194 | + |
| 195 | + let reqwest_response = req_builder.send().await.map_err(CallError::from)?; |
| 196 | + |
| 197 | + let http_response = { |
| 198 | + let (parts, body) = http::Response::from(reqwest_response).into_parts(); |
| 199 | + http::Response::from_parts(parts, SdkBody::from_body_1_x(body)) |
| 200 | + }; |
| 201 | + |
| 202 | + Ok( |
| 203 | + aws_smithy_runtime_api::http::Response::try_from(http_response).map_err(|err| { |
| 204 | + CallError::other("failed to convert to a proper response", err) |
| 205 | + })?, |
| 206 | + ) |
| 207 | + }; |
| 208 | + |
| 209 | + #[cfg(target_arch = "wasm32")] |
| 210 | + let future = async move { |
| 211 | + // WASM logic (spawn_local) |
| 212 | + let (tx, rx) = oneshot::channel(); |
| 213 | + |
| 214 | + spawn_local(async move { |
| 215 | + // Use a closure to handle errors |
| 216 | + let result = (async { |
| 217 | + let mut req_builder = client.request( |
| 218 | + reqwest::Method::from_bytes(request.method().as_bytes()).map_err( |
| 219 | + |err| CallError::user_with_source("failed to create method name", err), |
| 220 | + )?, |
| 221 | + request.uri().to_owned(), |
| 222 | + ); |
| 223 | + let parts = request.into_parts(); |
| 224 | + for (name, value) in parts.headers.iter() { |
| 225 | + req_builder = req_builder.header(name, value.as_bytes()); |
| 226 | + } |
| 227 | + let body_bytes = parts |
| 228 | + .body |
| 229 | + .bytes() |
| 230 | + .ok_or(CallError::user("streaming request body is not supported"))? |
| 231 | + .to_owned(); |
| 232 | + req_builder = req_builder.body(body_bytes); |
| 233 | + |
| 234 | + let reqwest_response = req_builder.send().await.map_err(CallError::from)?; |
| 235 | + |
| 236 | + // Use manual construction for WASM response conversion |
| 237 | + let http_response = { |
| 238 | + let status = reqwest_response.status(); |
| 239 | + let headers = reqwest_response.headers().clone(); |
| 240 | + let body_bytes = reqwest_response |
| 241 | + .bytes() |
| 242 | + .await |
| 243 | + .map_err(|e| CallError::other("failed to read response body", e))?; |
| 244 | + |
| 245 | + let mut response_builder = http::Response::builder().status(status); |
| 246 | + |
| 247 | + for (name, value) in headers.iter() { |
| 248 | + response_builder = response_builder.header(name, value); |
| 249 | + } |
| 250 | + |
| 251 | + response_builder |
| 252 | + .body(SdkBody::from(body_bytes)) |
| 253 | + .map_err(|e| CallError::other("failed to build http::Response", e))? |
| 254 | + }; |
| 255 | + |
| 256 | + aws_smithy_runtime_api::http::Response::try_from(http_response).map_err(|err| { |
| 257 | + CallError::other("failed to convert to a proper response", err) |
| 258 | + }) |
| 259 | + }) |
| 260 | + .await; |
| 261 | + |
| 262 | + // Convert the inner Result<_, CallError> to Result<_, ConnectorError> |
| 263 | + let final_result = result.map_err(ConnectorError::from); |
| 264 | + |
| 265 | + let _ = tx.send(final_result); |
| 266 | + }); |
| 267 | + |
| 268 | + rx.await.map_err(|_| { |
| 269 | + ConnectorError::other( |
| 270 | + Box::new(CallError::user("WASM future channel cancelled")), |
| 271 | + None, |
| 272 | + ) |
| 273 | + })? |
| 274 | + }; |
| 275 | + |
| 276 | + HttpConnectorFuture::new(future) |
| 277 | + } |
| 278 | +} |
| 279 | + |
| 280 | +impl HttpClient for Client { |
| 281 | + fn http_connector( |
| 282 | + &self, |
| 283 | + settings: &HttpConnectorSettings, |
| 284 | + _components: &RuntimeComponents, |
| 285 | + ) -> SharedHttpConnector { |
| 286 | + let timeout = if cfg!(target_arch = "wasm32") { |
| 287 | + None // Timeout not directly supported via reqwest on wasm |
| 288 | + } else { |
| 289 | + settings.read_timeout() |
| 290 | + }; |
| 291 | + let connector = ReqwestConnector { |
| 292 | + client: self.inner.clone(), |
| 293 | + timeout, |
| 294 | + }; |
| 295 | + SharedHttpConnector::new(connector) |
| 296 | + } |
| 297 | +} |
| 298 | + |
| 299 | +// --- Non-WASM Implementation using Reqwest --- |
| 300 | +#[cfg(not(target_arch = "wasm32"))] |
| 301 | +mod reqwest_impl { |
| 302 | + use std::time::Duration; |
| 303 | +} |
0 commit comments