Skip to content

Commit

Permalink
libstd/net: Add peek APIs to UdpSocket and TcpStream
Browse files Browse the repository at this point in the history
These methods enable socket reads without side-effects. That is,
repeated calls to peek() return identical data. This is accomplished
by providing the POSIX flag MSG_PEEK to the underlying socket read
operations.

This also moves the current implementation of recv_from out of the
platform-independent sys_common and into respective sys/windows and
sys/unix implementations. This allows for more platform-dependent
implementations.
  • Loading branch information
Tyler Julian committed Feb 4, 2017
1 parent 9749df5 commit a40be08
Show file tree
Hide file tree
Showing 8 changed files with 251 additions and 17 deletions.
1 change: 1 addition & 0 deletions src/libstd/lib.rs
Expand Up @@ -275,6 +275,7 @@
#![feature(oom)]
#![feature(optin_builtin_traits)]
#![feature(panic_unwind)]
#![feature(peek)]
#![feature(placement_in_syntax)]
#![feature(prelude_import)]
#![feature(pub_restricted)]
Expand Down
54 changes: 54 additions & 0 deletions src/libstd/net/tcp.rs
Expand Up @@ -296,6 +296,29 @@ impl TcpStream {
self.0.write_timeout()
}

/// Receives data on the socket from the remote adress to which it is
/// connected, without removing that data from the queue. On success,
/// returns the number of bytes peeked.
///
/// Successive calls return the same data. This is accomplished by passing
/// `MSG_PEEK` as a flag to the underlying `recv` system call.
///
/// # Examples
///
/// ```no_run
/// #![feature(peek)]
/// use std::net::TcpStream;
///
/// let stream = TcpStream::connect("127.0.0.1:8000")
/// .expect("couldn't bind to address");
/// let mut buf = [0; 10];
/// let len = stream.peek(&mut buf).expect("peek failed");
/// ```
#[unstable(feature = "peek", issue = "38980")]
pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
self.0.peek(buf)
}

/// Sets the value of the `TCP_NODELAY` option on this socket.
///
/// If set, this option disables the Nagle algorithm. This means that
Expand Down Expand Up @@ -1405,4 +1428,35 @@ mod tests {
Err(e) => panic!("unexpected error {}", e),
}
}

#[test]
fn peek() {
each_ip(&mut |addr| {
let (txdone, rxdone) = channel();

let srv = t!(TcpListener::bind(&addr));
let _t = thread::spawn(move|| {
let mut cl = t!(srv.accept()).0;
cl.write(&[1,3,3,7]).unwrap();
t!(rxdone.recv());
});

let mut c = t!(TcpStream::connect(&addr));
let mut b = [0; 10];
for _ in 1..3 {
let len = c.peek(&mut b).unwrap();
assert_eq!(len, 4);
}
let len = c.read(&mut b).unwrap();
assert_eq!(len, 4);

t!(c.set_nonblocking(true));
match c.peek(&mut b) {
Ok(_) => panic!("expected error"),
Err(ref e) if e.kind() == ErrorKind::WouldBlock => {}
Err(e) => panic!("unexpected error {}", e),
}
t!(txdone.send(()));
})
}
}
97 changes: 97 additions & 0 deletions src/libstd/net/udp.rs
Expand Up @@ -83,6 +83,30 @@ impl UdpSocket {
self.0.recv_from(buf)
}

/// Receives data from the socket, without removing it from the queue.
///
/// Successive calls return the same data. This is accomplished by passing
/// `MSG_PEEK` as a flag to the underlying `recvfrom` system call.
///
/// On success, returns the number of bytes peeked and the address from
/// whence the data came.
///
/// # Examples
///
/// ```no_run
/// #![feature(peek)]
/// use std::net::UdpSocket;
///
/// let socket = UdpSocket::bind("127.0.0.1:34254").expect("couldn't bind to address");
/// let mut buf = [0; 10];
/// let (number_of_bytes, src_addr) = socket.peek_from(&mut buf)
/// .expect("Didn't receive data");
/// ```
#[unstable(feature = "peek", issue = "38980")]
pub fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
self.0.peek_from(buf)
}

/// Sends data on the socket to the given address. On success, returns the
/// number of bytes written.
///
Expand Down Expand Up @@ -579,6 +603,37 @@ impl UdpSocket {
self.0.recv(buf)
}

/// Receives data on the socket from the remote adress to which it is
/// connected, without removing that data from the queue. On success,
/// returns the number of bytes peeked.
///
/// Successive calls return the same data. This is accomplished by passing
/// `MSG_PEEK` as a flag to the underlying `recv` system call.
///
/// # Errors
///
/// This method will fail if the socket is not connected. The `connect` method
/// will connect this socket to a remote address.
///
/// # Examples
///
/// ```no_run
/// #![feature(peek)]
/// use std::net::UdpSocket;
///
/// let socket = UdpSocket::bind("127.0.0.1:34254").expect("couldn't bind to address");
/// socket.connect("127.0.0.1:8080").expect("connect function failed");
/// let mut buf = [0; 10];
/// match socket.peek(&mut buf) {
/// Ok(received) => println!("received {} bytes", received),
/// Err(e) => println!("peek function failed: {:?}", e),
/// }
/// ```
#[unstable(feature = "peek", issue = "38980")]
pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
self.0.peek(buf)
}

/// Moves this UDP socket into or out of nonblocking mode.
///
/// On Unix this corresponds to calling fcntl, and on Windows this
Expand Down Expand Up @@ -869,6 +924,48 @@ mod tests {
assert_eq!(b"hello world", &buf[..]);
}

#[test]
fn connect_send_peek_recv() {
each_ip(&mut |addr, _| {
let socket = t!(UdpSocket::bind(&addr));
t!(socket.connect(addr));

t!(socket.send(b"hello world"));

for _ in 1..3 {
let mut buf = [0; 11];
let size = t!(socket.peek(&mut buf));
assert_eq!(b"hello world", &buf[..]);
assert_eq!(size, 11);
}

let mut buf = [0; 11];
let size = t!(socket.recv(&mut buf));
assert_eq!(b"hello world", &buf[..]);
assert_eq!(size, 11);
})
}

#[test]
fn peek_from() {
each_ip(&mut |addr, _| {
let socket = t!(UdpSocket::bind(&addr));
t!(socket.send_to(b"hello world", &addr));

for _ in 1..3 {
let mut buf = [0; 11];
let (size, _) = t!(socket.peek_from(&mut buf));
assert_eq!(b"hello world", &buf[..]);
assert_eq!(size, 11);
}

let mut buf = [0; 11];
let (size, _) = t!(socket.recv_from(&mut buf));
assert_eq!(b"hello world", &buf[..]);
assert_eq!(size, 11);
})
}

#[test]
fn ttl() {
let ttl = 100;
Expand Down
45 changes: 42 additions & 3 deletions src/libstd/sys/unix/net.rs
Expand Up @@ -10,12 +10,13 @@

use ffi::CStr;
use io;
use libc::{self, c_int, size_t, sockaddr, socklen_t, EAI_SYSTEM};
use libc::{self, c_int, c_void, size_t, sockaddr, socklen_t, EAI_SYSTEM, MSG_PEEK};
use mem;
use net::{SocketAddr, Shutdown};
use str;
use sys::fd::FileDesc;
use sys_common::{AsInner, FromInner, IntoInner};
use sys_common::net::{getsockopt, setsockopt};
use sys_common::net::{getsockopt, setsockopt, sockaddr_to_addr};
use time::Duration;

pub use sys::{cvt, cvt_r};
Expand Down Expand Up @@ -155,8 +156,46 @@ impl Socket {
self.0.duplicate().map(Socket)
}

fn recv_with_flags(&self, buf: &mut [u8], flags: c_int) -> io::Result<usize> {
let ret = cvt(unsafe {
libc::recv(self.0.raw(),
buf.as_mut_ptr() as *mut c_void,
buf.len(),
flags)
})?;
Ok(ret as usize)
}

pub fn read(&self, buf: &mut [u8]) -> io::Result<usize> {
self.0.read(buf)
self.recv_with_flags(buf, 0)
}

pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
self.recv_with_flags(buf, MSG_PEEK)
}

fn recv_from_with_flags(&self, buf: &mut [u8], flags: c_int)
-> io::Result<(usize, SocketAddr)> {
let mut storage: libc::sockaddr_storage = unsafe { mem::zeroed() };
let mut addrlen = mem::size_of_val(&storage) as libc::socklen_t;

let n = cvt(unsafe {
libc::recvfrom(self.0.raw(),
buf.as_mut_ptr() as *mut c_void,
buf.len(),
flags,
&mut storage as *mut _ as *mut _,
&mut addrlen)
})?;
Ok((n as usize, sockaddr_to_addr(&storage, addrlen as usize)?))
}

pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
self.recv_from_with_flags(buf, 0)
}

pub fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
self.recv_from_with_flags(buf, MSG_PEEK)
}

pub fn read_to_end(&self, buf: &mut Vec<u8>) -> io::Result<usize> {
Expand Down
1 change: 1 addition & 0 deletions src/libstd/sys/windows/c.rs
Expand Up @@ -244,6 +244,7 @@ pub const IP_ADD_MEMBERSHIP: c_int = 12;
pub const IP_DROP_MEMBERSHIP: c_int = 13;
pub const IPV6_ADD_MEMBERSHIP: c_int = 12;
pub const IPV6_DROP_MEMBERSHIP: c_int = 13;
pub const MSG_PEEK: c_int = 0x2;

#[repr(C)]
pub struct ip_mreq {
Expand Down
44 changes: 42 additions & 2 deletions src/libstd/sys/windows/net.rs
Expand Up @@ -147,19 +147,59 @@ impl Socket {
Ok(socket)
}

pub fn read(&self, buf: &mut [u8]) -> io::Result<usize> {
fn recv_with_flags(&self, buf: &mut [u8], flags: c_int) -> io::Result<usize> {
// On unix when a socket is shut down all further reads return 0, so we
// do the same on windows to map a shut down socket to returning EOF.
let len = cmp::min(buf.len(), i32::max_value() as usize) as i32;
unsafe {
match c::recv(self.0, buf.as_mut_ptr() as *mut c_void, len, 0) {
match c::recv(self.0, buf.as_mut_ptr() as *mut c_void, len, flags) {
-1 if c::WSAGetLastError() == c::WSAESHUTDOWN => Ok(0),
-1 => Err(last_error()),
n => Ok(n as usize)
}
}
}

pub fn read(&self, buf: &mut [u8]) -> io::Result<usize> {
self.recv_with_flags(buf, 0)
}

pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
self.recv_with_flags(buf, c::MSG_PEEK)
}

fn recv_from_with_flags(&self, buf: &mut [u8], flags: c_int)
-> io::Result<(usize, SocketAddr)> {
let mut storage: c::SOCKADDR_STORAGE_LH = unsafe { mem::zeroed() };
let mut addrlen = mem::size_of_val(&storage) as c::socklen_t;
let len = cmp::min(buf.len(), <wrlen_t>::max_value() as usize) as wrlen_t;

// On unix when a socket is shut down all further reads return 0, so we
// do the same on windows to map a shut down socket to returning EOF.
unsafe {
match c::recvfrom(self.0,
buf.as_mut_ptr() as *mut c_void,
len,
flags,
&mut storage as *mut _ as *mut _,
&mut addrlen) {
-1 if c::WSAGetLastError() == c::WSAESHUTDOWN => {
Ok((0, net::sockaddr_to_addr(&storage, addrlen as usize)?))
},
-1 => Err(last_error()),
n => Ok((n as usize, net::sockaddr_to_addr(&storage, addrlen as usize)?)),
}
}
}

pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
self.recv_from_with_flags(buf, 0)
}

pub fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
self.recv_from_with_flags(buf, c::MSG_PEEK)
}

pub fn read_to_end(&self, buf: &mut Vec<u8>) -> io::Result<usize> {
let mut me = self;
(&mut me).read_to_end(buf)
Expand Down
24 changes: 13 additions & 11 deletions src/libstd/sys_common/net.rs
Expand Up @@ -91,7 +91,7 @@ fn sockname<F>(f: F) -> io::Result<SocketAddr>
}
}

fn sockaddr_to_addr(storage: &c::sockaddr_storage,
pub fn sockaddr_to_addr(storage: &c::sockaddr_storage,
len: usize) -> io::Result<SocketAddr> {
match storage.ss_family as c_int {
c::AF_INET => {
Expand Down Expand Up @@ -222,6 +222,10 @@ impl TcpStream {
self.inner.timeout(c::SO_SNDTIMEO)
}

pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
self.inner.peek(buf)
}

pub fn read(&self, buf: &mut [u8]) -> io::Result<usize> {
self.inner.read(buf)
}
Expand Down Expand Up @@ -441,17 +445,11 @@ impl UdpSocket {
}

pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
let mut storage: c::sockaddr_storage = unsafe { mem::zeroed() };
let mut addrlen = mem::size_of_val(&storage) as c::socklen_t;
let len = cmp::min(buf.len(), <wrlen_t>::max_value() as usize) as wrlen_t;
self.inner.recv_from(buf)
}

let n = cvt(unsafe {
c::recvfrom(*self.inner.as_inner(),
buf.as_mut_ptr() as *mut c_void,
len, 0,
&mut storage as *mut _ as *mut _, &mut addrlen)
})?;
Ok((n as usize, sockaddr_to_addr(&storage, addrlen as usize)?))
pub fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
self.inner.peek_from(buf)
}

pub fn send_to(&self, buf: &[u8], dst: &SocketAddr) -> io::Result<usize> {
Expand Down Expand Up @@ -578,6 +576,10 @@ impl UdpSocket {
self.inner.read(buf)
}

pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
self.inner.peek(buf)
}

pub fn send(&self, buf: &[u8]) -> io::Result<usize> {
let len = cmp::min(buf.len(), <wrlen_t>::max_value() as usize) as wrlen_t;
let ret = cvt(unsafe {
Expand Down

0 comments on commit a40be08

Please sign in to comment.