Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Use Socket2 SockAddr for Windows #433

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ name = "trip"
thiserror = "1.0.39"
derive_more = "0.99.17"
arrayvec = "0.7.2"
socket2 = { version = "0.5.1", features = [ "all" ] }
socket2 = { git="https://github.com/fujiapple852/socket2", features = [ "all" ] }

# TUI dependencies
anyhow = "1.0.68"
Expand Down
163 changes: 36 additions & 127 deletions src/tracing/net/platform/windows.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,16 @@ use crate::tracing::net::channel::MAX_PACKET_SIZE;
use crate::tracing::net::platform::windows::adapter::Adapters;
use crate::tracing::net::socket::TracerSocket;
use socket2::{Domain, Protocol, SockAddr, Type};
use std::ffi::c_void;
use std::io::{Error, ErrorKind, Result};
use std::mem::{size_of, zeroed};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4};
use std::os::windows::prelude::AsRawSocket;
use std::ptr::{addr_of, addr_of_mut, null_mut};
use std::time::Duration;
use windows_sys::Win32::Foundation::{WAIT_FAILED, WAIT_TIMEOUT};
use windows_sys::Win32::Networking::WinSock::{
AF_INET, AF_INET6, FD_CONNECT, FD_WRITE, ICMP_ERROR_INFO, IN6_ADDR, IN6_ADDR_0, IN_ADDR,
IN_ADDR_0, IPPROTO_RAW, IPPROTO_TCP, SIO_ROUTING_INTERFACE_QUERY, SOCKADDR_IN, SOCKADDR_IN6,
SOCKADDR_IN6_0, SOCKADDR_STORAGE, SOCKET_ERROR, SOL_SOCKET, SO_ERROR, SO_PORT_SCALABILITY,
AF_INET, AF_INET6, FD_CONNECT, FD_WRITE, ICMP_ERROR_INFO, IPPROTO_RAW, IPPROTO_TCP,
SIO_ROUTING_INTERFACE_QUERY, SOCKET_ERROR, SOL_SOCKET, SO_ERROR, SO_PORT_SCALABILITY,
SO_REUSE_UNICASTPORT, TCP_FAIL_CONNECT_ON_ICMP_ERROR, TCP_ICMP_ERROR_INFO, WSABUF, WSADATA,
WSAEADDRNOTAVAIL, WSAECONNREFUSED, WSAEHOSTUNREACH, WSAEINPROGRESS, WSA_IO_INCOMPLETE,
WSA_IO_PENDING,
Expand Down Expand Up @@ -102,7 +100,7 @@ pub struct Socket {
inner: socket2::Socket,
ol: Box<OVERLAPPED>,
buf: Vec<u8>,
from: Box<SOCKADDR_STORAGE>,
from: Box<SockAddr>,
}

#[allow(clippy::cast_possible_wrap)]
Expand All @@ -117,7 +115,7 @@ impl Socket {

fn new(domain: Domain, ty: Type, protocol: Option<Protocol>) -> Result<Self> {
let inner = socket2::Socket::new(domain, ty, protocol)?;
let from = Box::new(Self::new_sockaddr_storage());
let from = Box::new(Self::new_sockaddr());
let ol = Box::new(Self::new_overlapped());
let buf = vec![0u8; MAX_PACKET_SIZE];
Ok(Self {
Expand Down Expand Up @@ -196,7 +194,7 @@ impl Socket {
fn is_err(res: i32) -> bool {
res == SOCKET_ERROR && Error::last_os_error().raw_os_error() != Some(WSA_IO_PENDING)
}
let mut fromlen = std::mem::size_of::<SOCKADDR_STORAGE>() as i32;
let mut fromlen = self.from.len();
let wbuf = WSABUF {
len: MAX_PACKET_SIZE as u32,
buf: self.buf.as_mut_ptr(),
Expand All @@ -208,7 +206,7 @@ impl Socket {
1,
null_mut(),
&mut 0,
addr_of_mut!(*self.from).cast(),
self.from.as_mut_ptr(),
addr_of_mut!(fromlen),
addr_of_mut!(*self.ol),
None,
Expand Down Expand Up @@ -241,10 +239,8 @@ impl Socket {
unsafe { zeroed::<WSADATA>() }
}

#[allow(unsafe_code)]
fn new_sockaddr_storage() -> SOCKADDR_STORAGE {
// Safety: an all-zero value is valid for SOCKADDR_STORAGE.
unsafe { zeroed::<SOCKADDR_STORAGE>() }
fn new_sockaddr() -> SockAddr {
SockAddr::from(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)))
}

#[allow(unsafe_code)]
Expand Down Expand Up @@ -272,7 +268,6 @@ impl Drop for Socket {
#[allow(clippy::cast_possible_wrap)]
impl TracerSocket for Socket {
fn new_icmp_send_socket_ipv4() -> Result<Self> {
// let sock = Self::new(AF_INET, SOCK_RAW, IPPROTO_RAW)?;
let sock = Self::new(Domain::IPV4, Type::RAW, Some(Protocol::from(IPPROTO_RAW)))?;
sock.set_non_blocking(true)?;
sock.set_header_included(true)?;
Expand Down Expand Up @@ -425,9 +420,8 @@ impl TracerSocket for Socket {
}

fn recv_from(&mut self, buf: &mut [u8]) -> Result<(usize, Option<SocketAddr>)> {
let addr = sockaddrptr_to_ipaddr(addr_of_mut!(*self.from))?;
let len = self.read(buf)?;
Ok((len, Some(SocketAddr::new(addr, 0))))
Ok((len, self.from.as_socket()))
}

// TODO
Expand Down Expand Up @@ -483,132 +477,39 @@ impl TracerSocket for Socket {
}
}

/// NOTE under Windows, we cannot use a bind connect/getsockname as "If the socket
/// is using a connectionless protocol, the address may not be available until I/O
/// occurs on the socket." We use `SIO_ROUTING_INTERFACE_QUERY` instead.
/// Determine the src `IpAddr` used for routing to a given target `IpAddr`.
///
/// under Windows, we cannot use a bind connect/getsockname as "If the socket is using a connectionless protocol, the
/// address may not be available until I/O occurs on the socket.". Therefore we use `SIO_ROUTING_INTERFACE_QUERY`
/// instead.
///
/// Note that the `WSAIoctl` call potentially returns multiple results (see
/// <https://www.winsocketdotnetworkprogramming.com/winsock2programming/winsock2advancedsocketoptionioctl7h.html>),
/// and we currently choose the first one arbitrarily.
#[allow(clippy::cast_sign_loss)]
fn routing_interface_query(target: IpAddr) -> TraceResult<IpAddr> {
let src: *mut c_void = [0; 1024].as_mut_ptr().cast();
let mut src = Socket::new_sockaddr();
let dest = SockAddr::from(SocketAddr::new(target, 0));
let mut bytes = 0;
let socket = match target {
IpAddr::V4(_) => Socket::new_udp_dgram_socket_ipv4(),
IpAddr::V6(_) => Socket::new_udp_dgram_socket_ipv6(),
}?;
let (dest, destlen) = socketaddr_to_sockaddr(SocketAddr::new(target, 0));
syscall!(
WSAIoctl(
socket.inner.as_raw_socket() as _,
SIO_ROUTING_INTERFACE_QUERY,
addr_of!(dest).cast(),
destlen as u32,
src,
1024,
dest.as_ptr().cast(),
dest.len() as u32,
src.as_mut_ptr().cast(),
src.len() as u32,
addr_of_mut!(bytes),
null_mut(),
None,
),
|res| res == SOCKET_ERROR
)?;
// Note that the WSAIoctl call potentially returns multiple results (see
// <https://www.winsocketdotnetworkprogramming.com/winsock2programming/winsock2advancedsocketoptionioctl7h.html>),
// TBD We choose the first one arbitrarily.
let sockaddr = src.cast::<SOCKADDR_STORAGE>();
sockaddrptr_to_ipaddr(sockaddr).map_err(TracerError::IoError)
}

#[allow(unsafe_code)]
fn sockaddrptr_to_ipaddr(sockaddr: *mut SOCKADDR_STORAGE) -> Result<IpAddr> {
// Safety: TODO
match sockaddr_to_socketaddr(unsafe { sockaddr.as_ref().unwrap() }) {
Err(e) => Err(e),
Ok(socketaddr) => match socketaddr {
SocketAddr::V4(socketaddrv4) => Ok(IpAddr::V4(*socketaddrv4.ip())),
SocketAddr::V6(socketaddrv6) => Ok(IpAddr::V6(*socketaddrv6.ip())),
},
}
}

#[allow(unsafe_code)]
fn sockaddr_to_socketaddr(sockaddr: &SOCKADDR_STORAGE) -> Result<SocketAddr> {
let ptr = sockaddr as *const SOCKADDR_STORAGE;
let af = sockaddr.ss_family;
if af == AF_INET {
let sockaddr_in_ptr = ptr.cast::<SOCKADDR_IN>();
// Safety: TODO
let sockaddr_in = unsafe { *sockaddr_in_ptr };
let ipv4addr = u32::from_be(unsafe { sockaddr_in.sin_addr.S_un.S_addr });
let port = sockaddr_in.sin_port;
Ok(SocketAddr::V4(SocketAddrV4::new(
Ipv4Addr::from(ipv4addr),
port,
)))
} else if af == AF_INET6 {
#[allow(clippy::cast_ptr_alignment)]
let sockaddr_in6_ptr = ptr.cast::<SOCKADDR_IN6>();
// Safety: TODO
let sockaddr_in6 = unsafe { *sockaddr_in6_ptr };
// TODO: check endianness
// Safety: TODO
let ipv6addr = unsafe { sockaddr_in6.sin6_addr.u.Byte };
let port = sockaddr_in6.sin6_port;
// Safety: TODO
let scope_id = unsafe { sockaddr_in6.Anonymous.sin6_scope_id };
Ok(SocketAddr::V6(SocketAddrV6::new(
Ipv6Addr::from(ipv6addr),
port,
sockaddr_in6.sin6_flowinfo,
scope_id,
)))
} else {
Err(Error::new(
ErrorKind::Unsupported,
format!("Unsupported address family: {af:?}"),
))
}
}

#[allow(unsafe_code)]
#[allow(clippy::cast_possible_wrap)]
#[must_use]
fn socketaddr_to_sockaddr(socketaddr: SocketAddr) -> (SOCKADDR_STORAGE, i32) {
#[repr(C)]
union SockAddr {
storage: SOCKADDR_STORAGE,
in4: SOCKADDR_IN,
in6: SOCKADDR_IN6,
}

let sockaddr = match socketaddr {
SocketAddr::V4(socketaddrv4) => SockAddr {
in4: SOCKADDR_IN {
sin_family: AF_INET,
sin_port: socketaddrv4.port().to_be(),
sin_addr: IN_ADDR {
S_un: IN_ADDR_0 {
S_addr: u32::from(*socketaddrv4.ip()).to_be(),
},
},
sin_zero: [0; 8],
},
},
SocketAddr::V6(socketaddrv6) => SockAddr {
in6: SOCKADDR_IN6 {
sin6_family: AF_INET6,
sin6_port: socketaddrv6.port().to_be(),
sin6_flowinfo: socketaddrv6.flowinfo(),
sin6_addr: IN6_ADDR {
u: IN6_ADDR_0 {
Byte: socketaddrv6.ip().octets(),
},
},
Anonymous: SOCKADDR_IN6_0 {
sin6_scope_id: socketaddrv6.scope_id(),
},
},
},
};

(unsafe { sockaddr.storage }, size_of::<SockAddr>() as i32)
Ok(src.as_socket().unwrap().ip())
}

fn lookup_interface_addr(adapters: &Adapters, name: &str) -> TraceResult<IpAddr> {
Expand All @@ -626,7 +527,7 @@ fn lookup_interface_addr(adapters: &Adapters, name: &str) -> TraceResult<IpAddr>

mod adapter {
use crate::tracing::error::{TraceResult, TracerError};
use crate::tracing::net::platform::windows::sockaddrptr_to_ipaddr;
use socket2::SockAddr;
use std::io::Error;
use std::marker::PhantomData;
use std::net::IpAddr;
Expand Down Expand Up @@ -745,7 +646,15 @@ mod adapter {
let first_unicast = (*self.next).FirstUnicastAddress;
let socket_address = (*first_unicast).Address;
let sockaddr = socket_address.lpSockaddr;
sockaddrptr_to_ipaddr(sockaddr.cast()).ok()?

// Safety: TODO
let (_, addr) = SockAddr::try_init(|s, _length| {
// TODO or memcpy?
*s = *sockaddr.cast();
Ok(())
})
.unwrap();
addr.as_socket().unwrap().ip()
};
self.next = (*self.next).Next;
Some(AdapterAddress {
Expand Down