Skip to content

Commit

Permalink
tcp server timeouts
Browse files Browse the repository at this point in the history
  • Loading branch information
bluejekyll committed Nov 21, 2016
1 parent 16cb3c3 commit 8859b5c
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 9 deletions.
2 changes: 2 additions & 0 deletions server/src/server/mod.rs
Expand Up @@ -19,10 +19,12 @@
mod request_stream;
mod server;
mod server_future;
mod timeout_stream;

pub use self::request_stream::Request;
pub use self::request_stream::RequestStream;
pub use self::request_stream::ResponseHandle;
#[allow(deprecated)]
pub use self::server::Server;
pub use self::server_future::ServerFuture;
pub use self::timeout_stream::TimeoutStream;
26 changes: 19 additions & 7 deletions server/src/server/server_future.rs
Expand Up @@ -7,6 +7,7 @@
use std;
use std::io;
use std::sync::Arc;
use std::time::Duration;

use futures::{Async, Future, Poll};
use futures::stream::Stream;
Expand All @@ -17,7 +18,7 @@ use trust_dns::op::RequestHandler;
use trust_dns::udp::UdpStream;
use trust_dns::tcp::TcpStream;

use ::server::{Request, RequestStream, ResponseHandle};
use ::server::{Request, RequestStream, ResponseHandle, TimeoutStream};
use ::authority::Catalog;

// TODO, would be nice to have a Slab for buffers here...
Expand All @@ -34,8 +35,8 @@ impl ServerFuture {
})
}

/// register a UDP socket. Should be bound before calling this function.
pub fn register_socket(&mut self, socket: std::net::UdpSocket) {
/// Register a UDP socket. Should be bound before calling this function.
pub fn register_socket(&self, socket: std::net::UdpSocket) {
// create the new UdpStream
let (buf_stream, stream_handle) = UdpStream::with_bound(socket, self.io_loop.handle());
let request_stream = RequestStream::new(buf_stream, stream_handle);
Expand All @@ -51,9 +52,19 @@ impl ServerFuture {
);
}

/// register a TcpListener to the Server. This should already be bound to either an IPv6 or an
/// Register a TcpListener to the Server. This should already be bound to either an IPv6 or an
/// IPv4 address.
pub fn register_listener(&mut self, listener: std::net::TcpListener) {
///
/// To make the server more resilient to DOS issues, there is a timeout. Care should be taken
/// to not make this too low depending on use cases.
///
/// # Arguments
/// * `listener` - a bound and listenting TCP socket
/// * `timeout` - timeout duration of incoming requests, any connection that does not send
/// requests within this time period will be closed. In the future it should be
/// possible to create long-lived queries, but these should be from trusted sources
/// only, this would require some type of whitelisting.
pub fn register_listener(&self, listener: std::net::TcpListener, timeout: Duration) {
let handle = self.io_loop.handle();
let catalog = self.catalog.clone();
// TODO: this is an awkward interface with socketaddr...
Expand All @@ -67,15 +78,16 @@ impl ServerFuture {
debug!("accepted request from: {}", src_addr);
// take the created stream...
let (buf_stream, stream_handle) = TcpStream::with_tcp_stream(tcp_stream, handle.clone());
let request_stream = RequestStream::new(buf_stream, stream_handle);
let timeout_stream = try!(TimeoutStream::new(buf_stream, timeout, handle.clone()));
let request_stream = RequestStream::new(timeout_stream, stream_handle);
let catalog = catalog.clone();

// and spawn to the io_loop
handle.spawn(
request_stream.for_each(move |(request, response_handle)| {
Self::handle_request(request, response_handle, catalog.clone())
})
.map_err(|e| debug!("error in TCP request_stream handler: {}", e))
.map_err(move |e| debug!("error in TCP request_stream src: {:?} error: {}", src_addr, e))
);

Ok(())
Expand Down
50 changes: 50 additions & 0 deletions server/src/server/timeout_stream.rs
@@ -0,0 +1,50 @@
use std::io;
use std::mem;
use std::time::Duration;

use futures::{Async, Future, Poll};
use futures::stream::Stream;
use tokio_core::reactor::{Handle, Timeout};

/// This wraps the underlying Stream in a timeout.
///
/// Any `Ok(Async::Ready(_))` from the underlying Stream will reset the timeout.
pub struct TimeoutStream<S> {
stream: S,
reactor_handle: Handle,
timeout_duration: Duration,
timeout: Timeout,
}

impl<S> TimeoutStream<S> {
pub fn new(stream: S, timeout_duration: Duration, reactor_handle: Handle) -> io::Result<Self> {
// store a Timeout for this message before sending
let timeout = try!(Timeout::new(timeout_duration, &reactor_handle));
Ok(TimeoutStream{ stream: stream, reactor_handle: reactor_handle, timeout_duration: timeout_duration, timeout: timeout })
}
}

impl<S, I> Stream for TimeoutStream<S>
where S: Stream<Item=I, Error=io::Error> {
type Item = I;
type Error = io::Error;

// somehow insert a timeout here...
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
match self.stream.poll() {
r @ Ok(Async::Ready(_)) | r @ Err(_) => {
// reset the timeout to wait for the next request...
let timeout = try!(Timeout::new(self.timeout_duration, &self.reactor_handle));
drop(mem::replace(&mut self.timeout, timeout));

return r
},
Ok(Async::NotReady) => {
// otherwise poll the timeout
match try_ready!(self.timeout.poll()) {
() => return Err(io::Error::new(io::ErrorKind::TimedOut, format!("nothing ready in {:?}", self.timeout_duration))),
}
}
}
}
}
5 changes: 3 additions & 2 deletions server/tests/server_future_tests.rs
Expand Up @@ -2,8 +2,9 @@ extern crate mio;
extern crate trust_dns;
extern crate trust_dns_server;

use std::thread;
use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4, UdpSocket, TcpListener};
use std::thread;
use std::time::Duration;

use trust_dns::client::*;
use trust_dns::op::*;
Expand Down Expand Up @@ -107,7 +108,7 @@ fn server_thread_udp(udp_socket: UdpSocket) {
fn server_thread_tcp(tcp_listener: TcpListener) {
let catalog = new_catalog();
let mut server = ServerFuture::new(catalog).expect("new tcp server failed");
server.register_listener(tcp_listener);
server.register_listener(tcp_listener, Duration::from_secs(30));

server.listen().unwrap();
}
53 changes: 53 additions & 0 deletions server/tests/timeout_stream_tests.rs
@@ -0,0 +1,53 @@
extern crate futures;
extern crate tokio_core;
extern crate trust_dns_server;

use std::io;
use std::time::Duration;
use futures::{Async, Poll};
use futures::stream::{iter, Stream};
use tokio_core::reactor::Core;

use trust_dns_server::server::TimeoutStream;

#[test]
fn test_no_timeout() {
let sequence = iter(vec![Ok(1), Err("error"), Ok(2)]).map_err(|e| io::Error::new(io::ErrorKind::Other, e));
let mut core = Core::new().expect("could not get core");

let timeout_stream = TimeoutStream::new(sequence, Duration::from_secs(360), core.handle()).expect("could not create timeout_stream");

let (val, timeout_stream) = core.run(timeout_stream.into_future()).ok().expect("first run failed");
assert_eq!(val, Some(1));

let error = core.run(timeout_stream.into_future());
assert!(error.is_err());

let (_, timeout_stream) = error.err().unwrap();

let (val, timeout_stream) = core.run(timeout_stream.into_future()).ok().expect("third run failed");
assert_eq!(val, Some(2));

let (val, _) = core.run(timeout_stream.into_future()).ok().expect("fourth run failed");
assert!(val.is_none())
}

struct NeverStream {}

impl Stream for NeverStream {
type Item = ();
type Error = io::Error;

// somehow insert a timeout here...
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
Ok(Async::NotReady)
}
}

#[test]
fn test_timeout() {
let mut core = Core::new().expect("could not get core");
let timeout_stream = TimeoutStream::new(NeverStream{}, Duration::from_millis(1), core.handle()).expect("could not create timeout_stream");

assert!(core.run(timeout_stream.into_future()).is_err());
}

0 comments on commit 8859b5c

Please sign in to comment.