From 8859b5c09b4665f19609b09d019949ce2ead8a09 Mon Sep 17 00:00:00 2001 From: Benjamin Fry Date: Mon, 21 Nov 2016 01:19:44 -0800 Subject: [PATCH] tcp server timeouts --- server/src/server/mod.rs | 2 ++ server/src/server/server_future.rs | 26 ++++++++++---- server/src/server/timeout_stream.rs | 50 ++++++++++++++++++++++++++ server/tests/server_future_tests.rs | 5 +-- server/tests/timeout_stream_tests.rs | 53 ++++++++++++++++++++++++++++ 5 files changed, 127 insertions(+), 9 deletions(-) create mode 100644 server/src/server/timeout_stream.rs create mode 100644 server/tests/timeout_stream_tests.rs diff --git a/server/src/server/mod.rs b/server/src/server/mod.rs index 86097d445f..1bc672028e 100644 --- a/server/src/server/mod.rs +++ b/server/src/server/mod.rs @@ -19,6 +19,7 @@ mod request_stream; mod server; mod server_future; +mod timeout_stream; pub use self::request_stream::Request; pub use self::request_stream::RequestStream; @@ -26,3 +27,4 @@ pub use self::request_stream::ResponseHandle; #[allow(deprecated)] pub use self::server::Server; pub use self::server_future::ServerFuture; +pub use self::timeout_stream::TimeoutStream; diff --git a/server/src/server/server_future.rs b/server/src/server/server_future.rs index f113e4c1d1..39923796d7 100644 --- a/server/src/server/server_future.rs +++ b/server/src/server/server_future.rs @@ -7,6 +7,7 @@ use std; use std::io; use std::sync::Arc; +use std::time::Duration; use futures::{Async, Future, Poll}; use futures::stream::Stream; @@ -17,7 +18,7 @@ use trust_dns::op::RequestHandler; use trust_dns::udp::UdpStream; use trust_dns::tcp::TcpStream; -use ::server::{Request, RequestStream, ResponseHandle}; +use ::server::{Request, RequestStream, ResponseHandle, TimeoutStream}; use ::authority::Catalog; // TODO, would be nice to have a Slab for buffers here... @@ -34,8 +35,8 @@ impl ServerFuture { }) } - /// register a UDP socket. Should be bound before calling this function. - pub fn register_socket(&mut self, socket: std::net::UdpSocket) { + /// Register a UDP socket. Should be bound before calling this function. + pub fn register_socket(&self, socket: std::net::UdpSocket) { // create the new UdpStream let (buf_stream, stream_handle) = UdpStream::with_bound(socket, self.io_loop.handle()); let request_stream = RequestStream::new(buf_stream, stream_handle); @@ -51,9 +52,19 @@ impl ServerFuture { ); } - /// register a TcpListener to the Server. This should already be bound to either an IPv6 or an + /// Register a TcpListener to the Server. This should already be bound to either an IPv6 or an /// IPv4 address. - pub fn register_listener(&mut self, listener: std::net::TcpListener) { + /// + /// To make the server more resilient to DOS issues, there is a timeout. Care should be taken + /// to not make this too low depending on use cases. + /// + /// # Arguments + /// * `listener` - a bound and listenting TCP socket + /// * `timeout` - timeout duration of incoming requests, any connection that does not send + /// requests within this time period will be closed. In the future it should be + /// possible to create long-lived queries, but these should be from trusted sources + /// only, this would require some type of whitelisting. + pub fn register_listener(&self, listener: std::net::TcpListener, timeout: Duration) { let handle = self.io_loop.handle(); let catalog = self.catalog.clone(); // TODO: this is an awkward interface with socketaddr... @@ -67,7 +78,8 @@ impl ServerFuture { debug!("accepted request from: {}", src_addr); // take the created stream... let (buf_stream, stream_handle) = TcpStream::with_tcp_stream(tcp_stream, handle.clone()); - let request_stream = RequestStream::new(buf_stream, stream_handle); + let timeout_stream = try!(TimeoutStream::new(buf_stream, timeout, handle.clone())); + let request_stream = RequestStream::new(timeout_stream, stream_handle); let catalog = catalog.clone(); // and spawn to the io_loop @@ -75,7 +87,7 @@ impl ServerFuture { request_stream.for_each(move |(request, response_handle)| { Self::handle_request(request, response_handle, catalog.clone()) }) - .map_err(|e| debug!("error in TCP request_stream handler: {}", e)) + .map_err(move |e| debug!("error in TCP request_stream src: {:?} error: {}", src_addr, e)) ); Ok(()) diff --git a/server/src/server/timeout_stream.rs b/server/src/server/timeout_stream.rs new file mode 100644 index 0000000000..c64c33fcc1 --- /dev/null +++ b/server/src/server/timeout_stream.rs @@ -0,0 +1,50 @@ +use std::io; +use std::mem; +use std::time::Duration; + +use futures::{Async, Future, Poll}; +use futures::stream::Stream; +use tokio_core::reactor::{Handle, Timeout}; + +/// This wraps the underlying Stream in a timeout. +/// +/// Any `Ok(Async::Ready(_))` from the underlying Stream will reset the timeout. +pub struct TimeoutStream { + stream: S, + reactor_handle: Handle, + timeout_duration: Duration, + timeout: Timeout, +} + +impl TimeoutStream { + pub fn new(stream: S, timeout_duration: Duration, reactor_handle: Handle) -> io::Result { + // store a Timeout for this message before sending + let timeout = try!(Timeout::new(timeout_duration, &reactor_handle)); + Ok(TimeoutStream{ stream: stream, reactor_handle: reactor_handle, timeout_duration: timeout_duration, timeout: timeout }) + } +} + +impl Stream for TimeoutStream +where S: Stream { + type Item = I; + type Error = io::Error; + + // somehow insert a timeout here... + fn poll(&mut self) -> Poll, Self::Error> { + match self.stream.poll() { + r @ Ok(Async::Ready(_)) | r @ Err(_) => { + // reset the timeout to wait for the next request... + let timeout = try!(Timeout::new(self.timeout_duration, &self.reactor_handle)); + drop(mem::replace(&mut self.timeout, timeout)); + + return r + }, + Ok(Async::NotReady) => { + // otherwise poll the timeout + match try_ready!(self.timeout.poll()) { + () => return Err(io::Error::new(io::ErrorKind::TimedOut, format!("nothing ready in {:?}", self.timeout_duration))), + } + } + } + } +} diff --git a/server/tests/server_future_tests.rs b/server/tests/server_future_tests.rs index 4c41536919..1a93932f3e 100644 --- a/server/tests/server_future_tests.rs +++ b/server/tests/server_future_tests.rs @@ -2,8 +2,9 @@ extern crate mio; extern crate trust_dns; extern crate trust_dns_server; -use std::thread; use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4, UdpSocket, TcpListener}; +use std::thread; +use std::time::Duration; use trust_dns::client::*; use trust_dns::op::*; @@ -107,7 +108,7 @@ fn server_thread_udp(udp_socket: UdpSocket) { fn server_thread_tcp(tcp_listener: TcpListener) { let catalog = new_catalog(); let mut server = ServerFuture::new(catalog).expect("new tcp server failed"); - server.register_listener(tcp_listener); + server.register_listener(tcp_listener, Duration::from_secs(30)); server.listen().unwrap(); } diff --git a/server/tests/timeout_stream_tests.rs b/server/tests/timeout_stream_tests.rs new file mode 100644 index 0000000000..2e7cc832c1 --- /dev/null +++ b/server/tests/timeout_stream_tests.rs @@ -0,0 +1,53 @@ +extern crate futures; +extern crate tokio_core; +extern crate trust_dns_server; + +use std::io; +use std::time::Duration; +use futures::{Async, Poll}; +use futures::stream::{iter, Stream}; +use tokio_core::reactor::Core; + +use trust_dns_server::server::TimeoutStream; + +#[test] +fn test_no_timeout() { + let sequence = iter(vec![Ok(1), Err("error"), Ok(2)]).map_err(|e| io::Error::new(io::ErrorKind::Other, e)); + let mut core = Core::new().expect("could not get core"); + + let timeout_stream = TimeoutStream::new(sequence, Duration::from_secs(360), core.handle()).expect("could not create timeout_stream"); + + let (val, timeout_stream) = core.run(timeout_stream.into_future()).ok().expect("first run failed"); + assert_eq!(val, Some(1)); + + let error = core.run(timeout_stream.into_future()); + assert!(error.is_err()); + + let (_, timeout_stream) = error.err().unwrap(); + + let (val, timeout_stream) = core.run(timeout_stream.into_future()).ok().expect("third run failed"); + assert_eq!(val, Some(2)); + + let (val, _) = core.run(timeout_stream.into_future()).ok().expect("fourth run failed"); + assert!(val.is_none()) +} + +struct NeverStream {} + +impl Stream for NeverStream { + type Item = (); + type Error = io::Error; + + // somehow insert a timeout here... + fn poll(&mut self) -> Poll, Self::Error> { + Ok(Async::NotReady) + } +} + +#[test] +fn test_timeout() { + let mut core = Core::new().expect("could not get core"); + let timeout_stream = TimeoutStream::new(NeverStream{}, Duration::from_millis(1), core.handle()).expect("could not create timeout_stream"); + + assert!(core.run(timeout_stream.into_future()).is_err()); +}