diff --git a/src/net/mod.rs b/src/net/mod.rs index ad7b79e..1600edc 100644 --- a/src/net/mod.rs +++ b/src/net/mod.rs @@ -1,7 +1,29 @@ //! Async network abstractions. +use std::io::{self, ErrorKind}; +use wasip2::sockets::network::ErrorCode; + mod tcp_listener; mod tcp_stream; pub use tcp_listener::*; pub use tcp_stream::*; + +fn to_io_err(err: ErrorCode) -> io::Error { + match err { + ErrorCode::Unknown => ErrorKind::Other.into(), + ErrorCode::AccessDenied => ErrorKind::PermissionDenied.into(), + ErrorCode::NotSupported => ErrorKind::Unsupported.into(), + ErrorCode::InvalidArgument => ErrorKind::InvalidInput.into(), + ErrorCode::OutOfMemory => ErrorKind::OutOfMemory.into(), + ErrorCode::Timeout => ErrorKind::TimedOut.into(), + ErrorCode::WouldBlock => ErrorKind::WouldBlock.into(), + ErrorCode::InvalidState => ErrorKind::InvalidData.into(), + ErrorCode::AddressInUse => ErrorKind::AddrInUse.into(), + ErrorCode::ConnectionRefused => ErrorKind::ConnectionRefused.into(), + ErrorCode::ConnectionReset => ErrorKind::ConnectionReset.into(), + ErrorCode::ConnectionAborted => ErrorKind::ConnectionAborted.into(), + ErrorCode::ConcurrencyConflict => ErrorKind::AlreadyExists.into(), + _ => ErrorKind::Other.into(), + } +} diff --git a/src/net/tcp_listener.rs b/src/net/tcp_listener.rs index 830d673..9f6d67f 100644 --- a/src/net/tcp_listener.rs +++ b/src/net/tcp_listener.rs @@ -1,12 +1,11 @@ use wasip2::sockets::network::Ipv4SocketAddress; -use wasip2::sockets::tcp::{ErrorCode, IpAddressFamily, IpSocketAddress, TcpSocket}; +use wasip2::sockets::tcp::{IpAddressFamily, IpSocketAddress, TcpSocket}; use crate::io; use crate::iter::AsyncIterator; -use std::io::ErrorKind; use std::net::SocketAddr; -use super::TcpStream; +use super::{to_io_err, TcpStream}; use crate::runtime::AsyncPollable; /// A TCP socket server, listening for connections. @@ -81,29 +80,6 @@ impl<'a> AsyncIterator for Incoming<'a> { } } -pub(super) fn to_io_err(err: ErrorCode) -> io::Error { - match err { - wasip2::sockets::network::ErrorCode::Unknown => ErrorKind::Other.into(), - wasip2::sockets::network::ErrorCode::AccessDenied => ErrorKind::PermissionDenied.into(), - wasip2::sockets::network::ErrorCode::NotSupported => ErrorKind::Unsupported.into(), - wasip2::sockets::network::ErrorCode::InvalidArgument => ErrorKind::InvalidInput.into(), - wasip2::sockets::network::ErrorCode::OutOfMemory => ErrorKind::OutOfMemory.into(), - wasip2::sockets::network::ErrorCode::Timeout => ErrorKind::TimedOut.into(), - wasip2::sockets::network::ErrorCode::WouldBlock => ErrorKind::WouldBlock.into(), - wasip2::sockets::network::ErrorCode::InvalidState => ErrorKind::InvalidData.into(), - wasip2::sockets::network::ErrorCode::AddressInUse => ErrorKind::AddrInUse.into(), - wasip2::sockets::network::ErrorCode::ConnectionRefused => { - ErrorKind::ConnectionRefused.into() - } - wasip2::sockets::network::ErrorCode::ConnectionReset => ErrorKind::ConnectionReset.into(), - wasip2::sockets::network::ErrorCode::ConnectionAborted => { - ErrorKind::ConnectionAborted.into() - } - wasip2::sockets::network::ErrorCode::ConcurrencyConflict => ErrorKind::AlreadyExists.into(), - _ => ErrorKind::Other.into(), - } -} - fn sockaddr_from_wasi(addr: IpSocketAddress) -> std::net::SocketAddr { use wasip2::sockets::network::Ipv6SocketAddress; match addr { diff --git a/src/net/tcp_stream.rs b/src/net/tcp_stream.rs index fc6ef99..af3674a 100644 --- a/src/net/tcp_stream.rs +++ b/src/net/tcp_stream.rs @@ -1,9 +1,17 @@ +use std::io::ErrorKind; +use std::net::{SocketAddr, ToSocketAddrs}; +use wasip2::sockets::instance_network::instance_network; +use wasip2::sockets::network::Ipv4SocketAddress; +use wasip2::sockets::tcp::{IpAddressFamily, IpSocketAddress}; +use wasip2::sockets::tcp_create_socket::create_tcp_socket; use wasip2::{ io::streams::{InputStream, OutputStream}, sockets::tcp::TcpSocket, }; +use super::to_io_err; use crate::io::{self, AsyncInputStream, AsyncOutputStream}; +use crate::runtime::AsyncPollable; /// A TCP stream between a local and a remote socket. pub struct TcpStream { @@ -20,12 +28,61 @@ impl TcpStream { socket, } } + + /// Opens a TCP connection to a remote host. + /// + /// `addr` is an address of the remote host. Anything which implements the + /// [`ToSocketAddrs`] trait can be supplied as the address. If `addr` + /// yields multiple addresses, connect will be attempted with each of the + /// addresses until a connection is successful. If none of the addresses + /// result in a successful connection, the error returned from the last + /// connection attempt (the last address) is returned. + pub async fn connect(addr: impl ToSocketAddrs) -> io::Result { + let addrs = addr.to_socket_addrs()?; + let mut last_err = None; + for addr in addrs { + match TcpStream::connect_addr(addr).await { + Ok(stream) => return Ok(stream), + Err(e) => last_err = Some(e), + } + } + + Err(last_err.unwrap_or_else(|| { + io::Error::new(ErrorKind::InvalidInput, "could not resolve to any address") + })) + } + + /// Establishes a connection to the specified `addr`. + pub async fn connect_addr(addr: SocketAddr) -> io::Result { + let family = match addr { + SocketAddr::V4(_) => IpAddressFamily::Ipv4, + SocketAddr::V6(_) => IpAddressFamily::Ipv6, + }; + let socket = create_tcp_socket(family).map_err(to_io_err)?; + let network = instance_network(); + + let remote_address = match addr { + SocketAddr::V4(addr) => { + let ip = addr.ip().octets(); + let address = (ip[0], ip[1], ip[2], ip[3]); + let port = addr.port(); + IpSocketAddress::Ipv4(Ipv4SocketAddress { port, address }) + } + SocketAddr::V6(_) => todo!("IPv6 not yet supported in `wstd::net::TcpStream`"), + }; + socket + .start_connect(&network, remote_address) + .map_err(to_io_err)?; + let pollable = AsyncPollable::new(socket.subscribe()); + pollable.wait_for().await; + let (input, output) = socket.finish_connect().map_err(to_io_err)?; + + Ok(TcpStream::new(input, output, socket)) + } + /// Returns the socket address of the remote peer of this TCP connection. pub fn peer_addr(&self) -> io::Result { - let addr = self - .socket - .remote_address() - .map_err(super::tcp_listener::to_io_err)?; + let addr = self.socket.remote_address().map_err(to_io_err)?; Ok(format!("{addr:?}")) }