diff --git a/.github/workflows/sqlx.yml b/.github/workflows/sqlx.yml index dfa5f20d89..b2f81b75ad 100644 --- a/.github/workflows/sqlx.yml +++ b/.github/workflows/sqlx.yml @@ -22,6 +22,7 @@ jobs: runs-on: ubuntu-24.04 strategy: matrix: + # Note: because `async-std` is deprecated, we only check it in a single job to save CI time. runtime: [ async-std, async-global-executor, smol, tokio ] tls: [ native-tls, rustls, none ] timeout-minutes: 30 diff --git a/Cargo.lock b/Cargo.lock index 41fdb5cdb5..61d2e7d7b6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3543,6 +3543,7 @@ dependencies = [ "async-global-executor 3.1.0", "async-io", "async-std", + "async-task", "base64 0.22.1", "bigdecimal", "bit-vec", diff --git a/sqlx-cli/Cargo.toml b/sqlx-cli/Cargo.toml index 9891e80ee0..d69048e698 100644 --- a/sqlx-cli/Cargo.toml +++ b/sqlx-cli/Cargo.toml @@ -55,9 +55,11 @@ features = [ [features] default = ["postgres", "sqlite", "mysql", "native-tls", "completions", "sqlx-toml"] +# TLS options rustls = ["sqlx/tls-rustls"] native-tls = ["sqlx/tls-native-tls"] +# databases mysql = ["sqlx/mysql"] postgres = ["sqlx/postgres"] sqlite = ["sqlx/sqlite", "_sqlite"] diff --git a/sqlx-core/Cargo.toml b/sqlx-core/Cargo.toml index 5d547f5f37..58c5b67e05 100644 --- a/sqlx-core/Cargo.toml +++ b/sqlx-core/Cargo.toml @@ -20,11 +20,13 @@ any = [] json = ["serde", "serde_json"] # for conditional compilation -_rt-async-global-executor = ["async-global-executor", "_rt-async-io"] +_rt-async-global-executor = ["async-global-executor", "_rt-async-io", "_rt-async-task"] _rt-async-io = ["async-io", "async-fs"] # see note at async-fs declaration _rt-async-std = ["async-std", "_rt-async-io"] -_rt-smol = ["smol", "_rt-async-io"] +_rt-async-task = ["async-task"] +_rt-smol = ["smol", "_rt-async-io", "_rt-async-task"] _rt-tokio = ["tokio", "tokio-stream"] + _tls-native-tls = ["native-tls"] _tls-rustls-aws-lc-rs = ["_tls-rustls", "rustls/aws-lc-rs", "webpki-roots"] _tls-rustls-ring-webpki = ["_tls-rustls", "rustls/ring", "webpki-roots"] @@ -68,6 +70,8 @@ mac_address = { workspace = true, optional = true } uuid = { workspace = true, optional = true } async-io = { version = "2.4.1", optional = true } +async-task = { version = "4.7.1", optional = true } + # work around bug in async-fs 2.0.0, which references futures-lite dependency wrongly, see https://github.com/launchbadge/sqlx/pull/3791#issuecomment-3043363281 async-fs = { version = "2.1", optional = true } base64 = { version = "0.22.0", default-features = false, features = ["std"] } diff --git a/sqlx-core/src/net/socket/mod.rs b/sqlx-core/src/net/socket/mod.rs index 4760c359af..0f9aae61b4 100644 --- a/sqlx-core/src/net/socket/mod.rs +++ b/sqlx-core/src/net/socket/mod.rs @@ -1,18 +1,14 @@ +use std::future::Future; use std::io; use std::path::Path; use std::pin::Pin; use std::task::{ready, Context, Poll}; -use std::{ - future::Future, - net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs}, -}; +pub use buffered::{BufferedSocket, WriteBuffer}; use bytes::BufMut; use cfg_if::cfg_if; -pub use buffered::{BufferedSocket, WriteBuffer}; - -use crate::{io::ReadBuf, rt::spawn_blocking}; +use crate::io::ReadBuf; mod buffered; @@ -146,10 +142,7 @@ where pub trait WithSocket { type Output; - fn with_socket( - self, - socket: S, - ) -> impl std::future::Future + Send; + fn with_socket(self, socket: S) -> impl Future + Send; } pub struct SocketIntoBox; @@ -193,98 +186,67 @@ pub async fn connect_tcp( port: u16, with_socket: Ws, ) -> crate::Result { - // IPv6 addresses in URLs will be wrapped in brackets and the `url` crate doesn't trim those. - let host = host.trim_matches(&['[', ']'][..]); - - let addresses = if let Ok(addr) = host.parse::() { - let addr = SocketAddrV4::new(addr, port); - vec![SocketAddr::V4(addr)].into_iter() - } else if let Ok(addr) = host.parse::() { - let addr = SocketAddrV6::new(addr, port, 0, 0); - vec![SocketAddr::V6(addr)].into_iter() - } else { - let host = host.to_string(); - spawn_blocking(move || { - let addr = (host.as_str(), port); - ToSocketAddrs::to_socket_addrs(&addr) - }) - .await? - }; - - let mut last_err = None; - - // Loop through all the Socket Addresses that the hostname resolves to - for socket_addr in addresses { - match connect_tcp_address(socket_addr).await { - Ok(stream) => return Ok(with_socket.with_socket(stream).await), - Err(e) => last_err = Some(e), - } + #[cfg(feature = "_rt-tokio")] + if crate::rt::rt_tokio::available() { + return Ok(with_socket + .with_socket(tokio::net::TcpStream::connect((host, port)).await?) + .await); } - // If we reach this point, it means we failed to connect to any of the addresses. - // Return the last error we encountered, or a custom error if the hostname didn't resolve to any address. - Err(match last_err { - Some(err) => err, - None => io::Error::new( - io::ErrorKind::AddrNotAvailable, - "Hostname did not resolve to any addresses", - ) - .into(), - }) -} - -async fn connect_tcp_address(socket_addr: SocketAddr) -> crate::Result { cfg_if! { - if #[cfg(feature = "_rt-tokio")] { - if crate::rt::rt_tokio::available() { - use tokio::net::TcpStream; - - let stream = TcpStream::connect(socket_addr).await?; - stream.set_nodelay(true)?; - - Ok(stream) - } else { - crate::rt::missing_rt(socket_addr) - } - } else if #[cfg(feature = "_rt-async-io")] { - use async_io::Async; - use std::net::TcpStream; - - let stream = Async::::connect(socket_addr).await?; - stream.get_ref().set_nodelay(true)?; - - Ok(stream) + if #[cfg(feature = "_rt-async-io")] { + Ok(with_socket.with_socket(connect_tcp_async_io(host, port).await?).await) } else { - crate::rt::missing_rt(socket_addr); - #[allow(unreachable_code)] - Ok(()) + crate::rt::missing_rt((host, port, with_socket)) } } } -// Work around `impl Socket`` and 'unability to specify test build cargo feature'. -// `connect_tcp_address` compilation would fail without this impl with -// 'cannot infer return type' error. -impl Socket for () { - fn try_read(&mut self, _: &mut dyn ReadBuf) -> io::Result { - unreachable!() - } +/// Open a TCP socket to `host` and `port`. +/// +/// If `host` is a hostname, attempt to connect to each address it resolves to. +/// +/// This implements the same behavior as [`tokio::net::TcpStream::connect()`]. +#[cfg(feature = "_rt-async-io")] +async fn connect_tcp_async_io(host: &str, port: u16) -> crate::Result { + use async_io::Async; + use std::net::{IpAddr, TcpStream, ToSocketAddrs}; - fn try_write(&mut self, _: &[u8]) -> io::Result { - unreachable!() - } + // IPv6 addresses in URLs will be wrapped in brackets and the `url` crate doesn't trim those. + let host = host.trim_matches(&['[', ']'][..]); - fn poll_read_ready(&mut self, _: &mut Context<'_>) -> Poll> { - unreachable!() + if let Ok(addr) = host.parse::() { + return Ok(Async::::connect((addr, port)).await?); } - fn poll_write_ready(&mut self, _: &mut Context<'_>) -> Poll> { - unreachable!() - } + let host = host.to_string(); + + let addresses = crate::rt::spawn_blocking(move || { + let addr = (host.as_str(), port); + ToSocketAddrs::to_socket_addrs(&addr) + }) + .await?; - fn poll_shutdown(&mut self, _: &mut Context<'_>) -> Poll> { - unreachable!() + let mut last_err = None; + + // Loop through all the Socket Addresses that the hostname resolves to + for socket_addr in addresses { + match Async::::connect(socket_addr).await { + Ok(stream) => return Ok(stream), + Err(e) => last_err = Some(e), + } } + + // If we reach this point, it means we failed to connect to any of the addresses. + // Return the last error we encountered, or a custom error if the hostname didn't resolve to any address. + Err(last_err + .unwrap_or_else(|| { + io::Error::new( + io::ErrorKind::AddrNotAvailable, + "Hostname did not resolve to any addresses", + ) + }) + .into()) } /// Connect a Unix Domain Socket at the given path. diff --git a/sqlx-core/src/rt/mod.rs b/sqlx-core/src/rt/mod.rs index 2d7c8e27e9..273a1bfcd9 100644 --- a/sqlx-core/src/rt/mod.rs +++ b/sqlx-core/src/rt/mod.rs @@ -9,12 +9,6 @@ use cfg_if::cfg_if; #[cfg(feature = "_rt-async-io")] pub mod rt_async_io; -#[cfg(feature = "_rt-async-global-executor")] -pub mod rt_async_global_executor; - -#[cfg(feature = "_rt-smol")] -pub mod rt_smol; - #[cfg(feature = "_rt-tokio")] pub mod rt_tokio; @@ -23,14 +17,16 @@ pub mod rt_tokio; pub struct TimeoutError; pub enum JoinHandle { - #[cfg(feature = "_rt-async-global-executor")] - AsyncGlobalExecutor(rt_async_global_executor::JoinHandle), #[cfg(feature = "_rt-async-std")] AsyncStd(async_std::task::JoinHandle), - #[cfg(feature = "_rt-smol")] - Smol(rt_smol::JoinHandle), + #[cfg(feature = "_rt-tokio")] Tokio(tokio::task::JoinHandle), + + // Implementation shared by `smol` and `async-global-executor` + #[cfg(feature = "_rt-async-task")] + AsyncTask(Option>), + // `PhantomData` requires `T: Unpin` _Phantom(PhantomData T>), } @@ -41,7 +37,6 @@ pub async fn timeout(duration: Duration, f: F) -> Result(_unused: T) -> ! { panic!("this functionality requires a Tokio context") } - panic!("one of the `runtime-async-global-executor`, `runtime-async-std`, `runtime-smol`, or `runtime-tokio` feature must be enabled") + panic!("one of the `runtime` features of SQLx must be enabled") } impl Future for JoinHandle { @@ -178,16 +175,20 @@ impl Future for JoinHandle { #[track_caller] fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match &mut *self { - #[cfg(feature = "_rt-async-global-executor")] - Self::AsyncGlobalExecutor(handle) => Pin::new(handle).poll(cx), #[cfg(feature = "_rt-async-std")] Self::AsyncStd(handle) => Pin::new(handle).poll(cx), - #[cfg(feature = "_rt-smol")] - Self::Smol(handle) => Pin::new(handle).poll(cx), + + #[cfg(feature = "_rt-async-task")] + Self::AsyncTask(task) => Pin::new(task) + .as_pin_mut() + .expect("BUG: task taken") + .poll(cx), + #[cfg(feature = "_rt-tokio")] Self::Tokio(handle) => Pin::new(handle) .poll(cx) .map(|res| res.expect("spawned task panicked")), + Self::_Phantom(_) => { let _ = cx; unreachable!("runtime should have been checked on spawn") @@ -195,3 +196,19 @@ impl Future for JoinHandle { } } } + +impl Drop for JoinHandle { + fn drop(&mut self) { + match self { + // `async_task` cancels on-drop by default. + // We need to explicitly detach to match Tokio and `async-std`. + #[cfg(feature = "_rt-async-task")] + Self::AsyncTask(task) => { + if let Some(task) = task.take() { + task.detach(); + } + } + _ => (), + } + } +} diff --git a/sqlx-core/src/rt/rt_async_global_executor/join_handle.rs b/sqlx-core/src/rt/rt_async_global_executor/join_handle.rs deleted file mode 100644 index 580883e21f..0000000000 --- a/sqlx-core/src/rt/rt_async_global_executor/join_handle.rs +++ /dev/null @@ -1,30 +0,0 @@ -use std::{ - future::Future, - pin::Pin, - task::{Context, Poll}, -}; - -use async_global_executor::Task; - -pub struct JoinHandle { - pub task: Option>, -} - -impl Drop for JoinHandle { - fn drop(&mut self) { - if let Some(task) = self.task.take() { - task.detach(); - } - } -} - -impl Future for JoinHandle { - type Output = T; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - match self.task.as_mut() { - Some(task) => Future::poll(Pin::new(task), cx), - None => unreachable!("JoinHandle polled after dropping"), - } - } -} diff --git a/sqlx-core/src/rt/rt_async_global_executor/mod.rs b/sqlx-core/src/rt/rt_async_global_executor/mod.rs deleted file mode 100644 index 65a56c8764..0000000000 --- a/sqlx-core/src/rt/rt_async_global_executor/mod.rs +++ /dev/null @@ -1,5 +0,0 @@ -mod join_handle; -pub use join_handle::*; - -pub mod yield_now; -pub use yield_now::*; diff --git a/sqlx-core/src/rt/rt_async_global_executor/yield_now.rs b/sqlx-core/src/rt/rt_async_global_executor/yield_now.rs deleted file mode 100644 index 1adb55e0f4..0000000000 --- a/sqlx-core/src/rt/rt_async_global_executor/yield_now.rs +++ /dev/null @@ -1,28 +0,0 @@ -use std::{ - future::Future, - pin::Pin, - task::{Context, Poll}, -}; - -pub fn yield_now() -> impl Future { - YieldNow(false) -} - -struct YieldNow(bool); - -impl Future for YieldNow { - type Output = (); - - // The futures executor is implemented as a FIFO queue, so all this future - // does is re-schedule the future back to the end of the queue, giving room - // for other futures to progress. - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - if !self.0 { - self.0 = true; - cx.waker().wake_by_ref(); - Poll::Pending - } else { - Poll::Ready(()) - } - } -} diff --git a/sqlx-core/src/rt/rt_smol/join_handle.rs b/sqlx-core/src/rt/rt_smol/join_handle.rs deleted file mode 100644 index 6702733c4a..0000000000 --- a/sqlx-core/src/rt/rt_smol/join_handle.rs +++ /dev/null @@ -1,30 +0,0 @@ -use std::{ - future::Future, - pin::Pin, - task::{Context, Poll}, -}; - -use smol::Task; - -pub struct JoinHandle { - pub task: Option>, -} - -impl Drop for JoinHandle { - fn drop(&mut self) { - if let Some(task) = self.task.take() { - task.detach(); - } - } -} - -impl Future for JoinHandle { - type Output = T; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - match self.task.as_mut() { - Some(task) => Future::poll(Pin::new(task), cx), - None => unreachable!("JoinHandle polled after dropping"), - } - } -} diff --git a/sqlx-core/src/rt/rt_smol/mod.rs b/sqlx-core/src/rt/rt_smol/mod.rs deleted file mode 100644 index 0b620d5116..0000000000 --- a/sqlx-core/src/rt/rt_smol/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -mod join_handle; -pub use join_handle::*;