diff --git a/src/lib.rs b/src/lib.rs index 455228a..6b9f49e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,9 +2,10 @@ use std::future::Future; use std::io; use std::pin::Pin; use std::task::{Context, Poll}; +use std::time::Duration; use tokio::io::{AsyncRead, AsyncWrite}; -use tokio::time::{timeout, Duration}; +use tokio::time::timeout; use tokio_io_timeout::TimeoutStream; use hyper::client::connect::{Connect, Connected, Connection}; @@ -58,33 +59,28 @@ where } fn call(&mut self, dst: Uri) -> Self::Future { + let connect_timeout = self.connect_timeout; let read_timeout = self.read_timeout; let write_timeout = self.write_timeout; let connecting = self.connector.call(dst); - if self.connect_timeout.is_none() { - let fut = async move { - let io = connecting.await.map_err(Into::into)?; - - let mut tm = TimeoutConnectorStream::new(TimeoutStream::new(io)); - tm.set_read_timeout(read_timeout); - tm.set_write_timeout(write_timeout); - Ok(Box::pin(tm)) - }; - - return Box::pin(fut); - } - - let connect_timeout = self.connect_timeout.expect("Connect timeout should be set"); - let timeout = timeout(connect_timeout, connecting); - let fut = async move { - let connecting = timeout - .await - .map_err(|e| io::Error::new(io::ErrorKind::TimedOut, e))?; - let io = connecting.map_err(Into::into)?; + let stream = match connect_timeout { + None => { + let io = connecting.await.map_err(Into::into)?; + TimeoutStream::new(io) + } + Some(connect_timeout) => { + let timeout = timeout(connect_timeout, connecting); + let connecting = timeout + .await + .map_err(|e| io::Error::new(io::ErrorKind::TimedOut, e))?; + let io = connecting.map_err(Into::into)?; + TimeoutStream::new(io) + } + }; - let mut tm = TimeoutConnectorStream::new(TimeoutStream::new(io)); + let mut tm = TimeoutConnectorStream::new(stream); tm.set_read_timeout(read_timeout); tm.set_write_timeout(write_timeout); Ok(Box::pin(tm))