Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/sqlx.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ jobs:
runs-on: ubuntu-24.04
strategy:
matrix:
# Note: because `async-std` is deprecated, we only check it in a single job to save CI time.
runtime: [ async-std, async-global-executor, smol, tokio ]
tls: [ native-tls, rustls, none ]
timeout-minutes: 30
Expand Down
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions sqlx-cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,11 @@ features = [
[features]
default = ["postgres", "sqlite", "mysql", "native-tls", "completions", "sqlx-toml"]

# TLS options
rustls = ["sqlx/tls-rustls"]
native-tls = ["sqlx/tls-native-tls"]

# databases
mysql = ["sqlx/mysql"]
postgres = ["sqlx/postgres"]
sqlite = ["sqlx/sqlite", "_sqlite"]
Expand Down
8 changes: 6 additions & 2 deletions sqlx-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@ any = []
json = ["serde", "serde_json"]

# for conditional compilation
_rt-async-global-executor = ["async-global-executor", "_rt-async-io"]
_rt-async-global-executor = ["async-global-executor", "_rt-async-io", "_rt-async-task"]
_rt-async-io = ["async-io", "async-fs"] # see note at async-fs declaration
_rt-async-std = ["async-std", "_rt-async-io"]
_rt-smol = ["smol", "_rt-async-io"]
_rt-async-task = ["async-task"]
_rt-smol = ["smol", "_rt-async-io", "_rt-async-task"]
_rt-tokio = ["tokio", "tokio-stream"]

_tls-native-tls = ["native-tls"]
_tls-rustls-aws-lc-rs = ["_tls-rustls", "rustls/aws-lc-rs", "webpki-roots"]
_tls-rustls-ring-webpki = ["_tls-rustls", "rustls/ring", "webpki-roots"]
Expand Down Expand Up @@ -68,6 +70,8 @@ mac_address = { workspace = true, optional = true }
uuid = { workspace = true, optional = true }

async-io = { version = "2.4.1", optional = true }
async-task = { version = "4.7.1", optional = true }

# work around bug in async-fs 2.0.0, which references futures-lite dependency wrongly, see https://github.com/launchbadge/sqlx/pull/3791#issuecomment-3043363281
async-fs = { version = "2.1", optional = true }
base64 = { version = "0.22.0", default-features = false, features = ["std"] }
Expand Down
140 changes: 51 additions & 89 deletions sqlx-core/src/net/socket/mod.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,14 @@
use std::future::Future;
use std::io;
use std::path::Path;
use std::pin::Pin;
use std::task::{ready, Context, Poll};
use std::{
future::Future,
net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs},
};

pub use buffered::{BufferedSocket, WriteBuffer};
use bytes::BufMut;
use cfg_if::cfg_if;

pub use buffered::{BufferedSocket, WriteBuffer};

use crate::{io::ReadBuf, rt::spawn_blocking};
use crate::io::ReadBuf;

mod buffered;

Expand Down Expand Up @@ -146,10 +142,7 @@ where
pub trait WithSocket {
type Output;

fn with_socket<S: Socket>(
self,
socket: S,
) -> impl std::future::Future<Output = Self::Output> + Send;
fn with_socket<S: Socket>(self, socket: S) -> impl Future<Output = Self::Output> + Send;
}

pub struct SocketIntoBox;
Expand Down Expand Up @@ -193,98 +186,67 @@ pub async fn connect_tcp<Ws: WithSocket>(
port: u16,
with_socket: Ws,
) -> crate::Result<Ws::Output> {
// IPv6 addresses in URLs will be wrapped in brackets and the `url` crate doesn't trim those.
let host = host.trim_matches(&['[', ']'][..]);

let addresses = if let Ok(addr) = host.parse::<Ipv4Addr>() {
let addr = SocketAddrV4::new(addr, port);
vec![SocketAddr::V4(addr)].into_iter()
} else if let Ok(addr) = host.parse::<Ipv6Addr>() {
let addr = SocketAddrV6::new(addr, port, 0, 0);
vec![SocketAddr::V6(addr)].into_iter()
} else {
let host = host.to_string();
spawn_blocking(move || {
let addr = (host.as_str(), port);
ToSocketAddrs::to_socket_addrs(&addr)
})
.await?
};

let mut last_err = None;

// Loop through all the Socket Addresses that the hostname resolves to
for socket_addr in addresses {
match connect_tcp_address(socket_addr).await {
Ok(stream) => return Ok(with_socket.with_socket(stream).await),
Err(e) => last_err = Some(e),
}
#[cfg(feature = "_rt-tokio")]
if crate::rt::rt_tokio::available() {
return Ok(with_socket
.with_socket(tokio::net::TcpStream::connect((host, port)).await?)
.await);
}

// If we reach this point, it means we failed to connect to any of the addresses.
// Return the last error we encountered, or a custom error if the hostname didn't resolve to any address.
Err(match last_err {
Some(err) => err,
None => io::Error::new(
io::ErrorKind::AddrNotAvailable,
"Hostname did not resolve to any addresses",
)
.into(),
})
}

async fn connect_tcp_address(socket_addr: SocketAddr) -> crate::Result<impl Socket> {
cfg_if! {
if #[cfg(feature = "_rt-tokio")] {
if crate::rt::rt_tokio::available() {
use tokio::net::TcpStream;

let stream = TcpStream::connect(socket_addr).await?;
stream.set_nodelay(true)?;

Ok(stream)
} else {
crate::rt::missing_rt(socket_addr)
}
} else if #[cfg(feature = "_rt-async-io")] {
use async_io::Async;
use std::net::TcpStream;

let stream = Async::<TcpStream>::connect(socket_addr).await?;
stream.get_ref().set_nodelay(true)?;

Ok(stream)
if #[cfg(feature = "_rt-async-io")] {
Ok(with_socket.with_socket(connect_tcp_async_io(host, port).await?).await)
} else {
crate::rt::missing_rt(socket_addr);
#[allow(unreachable_code)]
Ok(())
crate::rt::missing_rt((host, port, with_socket))
}
}
}

// Work around `impl Socket`` and 'unability to specify test build cargo feature'.
// `connect_tcp_address` compilation would fail without this impl with
// 'cannot infer return type' error.
impl Socket for () {
fn try_read(&mut self, _: &mut dyn ReadBuf) -> io::Result<usize> {
unreachable!()
}
/// Open a TCP socket to `host` and `port`.
///
/// If `host` is a hostname, attempt to connect to each address it resolves to.
///
/// This implements the same behavior as [`tokio::net::TcpStream::connect()`].
#[cfg(feature = "_rt-async-io")]
async fn connect_tcp_async_io(host: &str, port: u16) -> crate::Result<impl Socket> {
use async_io::Async;
use std::net::{IpAddr, TcpStream, ToSocketAddrs};

fn try_write(&mut self, _: &[u8]) -> io::Result<usize> {
unreachable!()
}
// IPv6 addresses in URLs will be wrapped in brackets and the `url` crate doesn't trim those.
let host = host.trim_matches(&['[', ']'][..]);

fn poll_read_ready(&mut self, _: &mut Context<'_>) -> Poll<io::Result<()>> {
unreachable!()
if let Ok(addr) = host.parse::<IpAddr>() {
return Ok(Async::<TcpStream>::connect((addr, port)).await?);
}

fn poll_write_ready(&mut self, _: &mut Context<'_>) -> Poll<io::Result<()>> {
unreachable!()
}
let host = host.to_string();

let addresses = crate::rt::spawn_blocking(move || {
let addr = (host.as_str(), port);
ToSocketAddrs::to_socket_addrs(&addr)
})
.await?;

fn poll_shutdown(&mut self, _: &mut Context<'_>) -> Poll<io::Result<()>> {
unreachable!()
let mut last_err = None;

// Loop through all the Socket Addresses that the hostname resolves to
for socket_addr in addresses {
match Async::<TcpStream>::connect(socket_addr).await {
Ok(stream) => return Ok(stream),
Err(e) => last_err = Some(e),
}
}

// If we reach this point, it means we failed to connect to any of the addresses.
// Return the last error we encountered, or a custom error if the hostname didn't resolve to any address.
Err(last_err
.unwrap_or_else(|| {
io::Error::new(
io::ErrorKind::AddrNotAvailable,
"Hostname did not resolve to any addresses",
)
})
.into())
}

/// Connect a Unix Domain Socket at the given path.
Expand Down
Loading
Loading