Fix handling payload timer after payload got consumed (#366)
fafhrd91 committed May 29, 2024
1 parent 3b49828 commit 9c29de1
Showing 7 changed files with 138 additions and 96 deletions.
9 changes: 3 additions & 6 deletions ntex-service/src/lib.rs
@@ -6,8 +6,7 @@
unreachable_pub,
missing_debug_implementations
)]

use std::{future::Future, rc::Rc};
use std::rc::Rc;

mod and_then;
mod apply;
@@ -183,11 +182,9 @@ pub trait ServiceFactory<Req, Cfg = ()> {
type InitError;

/// Create and return a new service value asynchronously.
fn create(
&self,
cfg: Cfg,
) -> impl Future<Output = Result<Self::Service, Self::InitError>>;
async fn create(&self, cfg: Cfg) -> Result<Self::Service, Self::InitError>;

#[inline]
/// Create and return a new service value asynchronously and wrap into a container
async fn pipeline(&self, cfg: Cfg) -> Result<Pipeline<Self::Service>, Self::InitError>
where
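With this change, `ServiceFactory::create` is declared directly as an `async fn` instead of returning an explicit `impl Future`. A minimal caller-side sketch under that assumption (the `build_service` helper and its generics are illustrative, not part of ntex-service):

```rust
use ntex_service::ServiceFactory;

// Hypothetical helper: with `create` now an `async fn`, constructing a
// service from any factory is a single `.await`.
async fn build_service<F, Req, Cfg>(
    factory: &F,
    cfg: Cfg,
) -> Result<F::Service, F::InitError>
where
    F: ServiceFactory<Req, Cfg>,
{
    // `create` resolves to the freshly built service or the factory's init error.
    factory.create(cfg).await
}
```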
4 changes: 4 additions & 0 deletions ntex/CHANGES.md
@@ -1,5 +1,9 @@
# Changes

## [2.0.1] - 2024-05-29

* http: Fix handling payload timer after payload got consumed

## [2.0.0] - 2024-05-28

* Use "async fn" for Service::ready() and Service::shutdown()
6 changes: 3 additions & 3 deletions ntex/Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "ntex"
version = "2.0.0"
version = "2.0.1"
authors = ["ntex contributors <team@ntex.rs>"]
description = "Framework for composable network services"
readme = "README.md"
@@ -63,10 +63,10 @@ ntex-router = "0.5.3"
ntex-service = "3.0"
ntex-macros = "0.1.3"
ntex-util = "2.0"
ntex-bytes = "0.1.25"
ntex-bytes = "0.1.27"
ntex-server = "2.0"
ntex-h2 = "1.0"
ntex-rt = "0.4.12"
ntex-rt = "0.4.13"
ntex-io = "2.0"
ntex-net = "2.0"
ntex-tls = "2.0"
15 changes: 8 additions & 7 deletions ntex/src/http/h1/control.rs
@@ -4,6 +4,7 @@ use crate::http::message::CurrentIo;
use crate::http::{body::Body, h1::Codec, Request, Response, ResponseError};
use crate::io::{Filter, Io, IoBoxed};

#[derive(Debug)]
pub enum Control<F, Err> {
/// New request is loaded
NewRequest(NewRequest),
@@ -40,19 +41,19 @@ bitflags::bitflags! {

#[derive(Debug)]
pub(super) enum ControlResult {
// handle request expect
/// handle request expect
Expect(Request),
// handle request upgrade
/// handle request upgrade
Upgrade(Request),
// forward request to publish service
/// forward request to publish service
Publish(Request),
// forward request to publish service
/// forward request to publish service
PublishUpgrade(Request),
// send response
/// send response
Response(Response<()>, Body),
// send response
/// send response
ResponseWithIo(Response<()>, Body, IoBoxed),
// drop connection
/// drop connection
Stop,
}

2 changes: 2 additions & 0 deletions ntex/src/http/h1/default.rs
@@ -20,6 +20,7 @@ where
type Service = DefaultControlService;
type InitError = io::Error;

#[inline]
async fn create(&self, _: ()) -> Result<Self::Service, Self::InitError> {
Ok(DefaultControlService)
}
@@ -33,6 +34,7 @@ where
type Response = ControlAck;
type Error = io::Error;

#[inline]
async fn call(
&self,
req: Control<F, Err>,
157 changes: 85 additions & 72 deletions ntex/src/http/h1/dispatcher.rs
@@ -507,13 +507,20 @@
fn poll_request(&mut self, cx: &mut Context<'_>) -> Poll<State<F, C, S, B>> {
if self.payload.is_some() {
if let Some(st) = ready!(self.poll_request_payload(cx)) {
return Poll::Ready(st);
Poll::Ready(st)
} else {
Poll::Pending
}
} else {
// check for io changes, could close while waiting for service call
match ready!(self.io.poll_status_update(cx)) {
IoStatusUpdate::KeepAlive => Poll::Pending,
IoStatusUpdate::Stop | IoStatusUpdate::PeerGone(_) => {
Poll::Ready(self.stop())
}
IoStatusUpdate::WriteBackpressure => Poll::Pending,
}
} else if self.poll_io_closed(cx) {
// check if io is closed
return Poll::Ready(self.stop());
}
Poll::Pending
}

fn set_payload_error(&mut self, err: PayloadError) {
@@ -580,6 +587,7 @@
self.payload.as_mut().unwrap().1.feed_data(chunk);
}
Ok(PayloadItem::Eof) => {
self.flags.remove(Flags::READ_PL_TIMEOUT);
self.payload.as_mut().unwrap().1.feed_eof();
self.payload = None;
break;
@@ -651,76 +659,66 @@
}
}

/// check for io changes, could close while waiting for service call
fn poll_io_closed(&self, cx: &mut Context<'_>) -> bool {
match self.io.poll_status_update(cx) {
Poll::Pending => false,
Poll::Ready(
IoStatusUpdate::KeepAlive
| IoStatusUpdate::Stop
| IoStatusUpdate::PeerGone(_),
) => true,
Poll::Ready(IoStatusUpdate::WriteBackpressure) => false,
}
}

fn handle_timeout(&mut self) -> Result<(), ProtocolError> {
// check read rate
if self
.flags
.intersects(Flags::READ_PL_TIMEOUT | Flags::READ_HDRS_TIMEOUT)
{
let cfg = if self.flags.contains(Flags::READ_HDRS_TIMEOUT) {
&self.config.headers_read_rate
let cfg = if self.flags.contains(Flags::READ_HDRS_TIMEOUT) {
&self.config.headers_read_rate
} else if self.flags.contains(Flags::READ_PL_TIMEOUT) {
&self.config.payload_read_rate
} else {
return Ok(());
};

if let Some(ref cfg) = cfg {
let total = if self.flags.contains(Flags::READ_HDRS_TIMEOUT) {
let total = (self.read_remains - self.read_consumed)
.try_into()
.unwrap_or(u16::MAX);
self.read_remains = 0;
total
} else {
&self.config.payload_read_rate
let total = (self.read_remains + self.read_consumed)
.try_into()
.unwrap_or(u16::MAX);
self.read_consumed = 0;
total
};

if let Some(ref cfg) = cfg {
let total = if self.flags.contains(Flags::READ_HDRS_TIMEOUT) {
let total = (self.read_remains - self.read_consumed)
.try_into()
.unwrap_or(u16::MAX);
self.read_remains = 0;
total
} else {
let total = (self.read_remains + self.read_consumed)
.try_into()
.unwrap_or(u16::MAX);
self.read_consumed = 0;
total
};

if total > cfg.rate {
// update max timeout
if !cfg.max_timeout.is_zero() {
self.read_max_timeout =
Seconds(self.read_max_timeout.0.saturating_sub(cfg.timeout.0));
}
if total > cfg.rate {
// update max timeout
if !cfg.max_timeout.is_zero() {
self.read_max_timeout =
Seconds(self.read_max_timeout.0.saturating_sub(cfg.timeout.0));
}

// start timer for next period
if cfg.max_timeout.is_zero() || !self.read_max_timeout.is_zero() {
log::trace!(
"{}: Bytes read rate {:?}, extend timer",
self.io.tag(),
total
);
self.io.start_timer(cfg.timeout);
return Ok(());
}
// start timer for next period
if cfg.max_timeout.is_zero() || !self.read_max_timeout.is_zero() {
log::trace!(
"{}: Bytes read rate {:?}, extend timer",
self.io.tag(),
total
);
self.io.start_timer(cfg.timeout);
return Ok(());
}
}
}

log::trace!("{}: Timeout during reading", self.io.tag());
if self.flags.contains(Flags::READ_PL_TIMEOUT) {
self.set_payload_error(PayloadError::Io(io::Error::new(
io::ErrorKind::TimedOut,
"Keep-alive",
)));
Err(ProtocolError::SlowPayloadTimeout)
log::trace!(
"{}: Timeout during reading, {:?}",
self.io.tag(),
self.flags
);
if self.flags.contains(Flags::READ_PL_TIMEOUT) {
self.set_payload_error(PayloadError::Io(io::Error::new(
io::ErrorKind::TimedOut,
"Keep-alive",
)));
Err(ProtocolError::SlowPayloadTimeout)
} else {
Err(ProtocolError::SlowRequestTimeout)
}
} else {
Err(ProtocolError::SlowRequestTimeout)
Ok(())
}
}
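
The rewritten `handle_timeout` above reduces to a per-period read-rate check: when the read timer fires, compare the bytes received during the period against the configured rate, charge the period against the overall timeout budget, and either re-arm the timer or give up. A self-contained model of that decision, using simplified stand-in types rather than the dispatcher's own `Seconds` and config structs:

```rust
/// Outcome of one firing of the read timer (simplified model, not dispatcher code).
#[derive(Debug, PartialEq)]
enum Verdict {
    /// Enough bytes arrived; re-arm the timer for another period of this length.
    Extend { period_secs: u16 },
    /// Read rate too low, or the capped budget is exhausted; report a timeout.
    Timeout,
}

/// Read-rate settings for one direction (headers or payload).
struct ReadRate {
    rate: u16,             // minimum bytes expected per period
    timeout_secs: u16,     // length of one period
    max_timeout_secs: u16, // overall budget; 0 means "no cap"
}

fn on_timer(bytes_this_period: u16, budget_secs: &mut u16, cfg: &ReadRate) -> Verdict {
    if bytes_this_period > cfg.rate {
        // Charge one period against the overall budget, if a cap is configured.
        if cfg.max_timeout_secs != 0 {
            *budget_secs = budget_secs.saturating_sub(cfg.timeout_secs);
        }
        // Re-arm unless the capped budget has just run out.
        if cfg.max_timeout_secs == 0 || *budget_secs != 0 {
            return Verdict::Extend { period_secs: cfg.timeout_secs };
        }
    }
    Verdict::Timeout
}
```

In the dispatcher, the `Timeout` branch maps to `ProtocolError::SlowPayloadTimeout` when the payload read timer was active and to `ProtocolError::SlowRequestTimeout` otherwise; if neither read-timeout flag is set, the new code simply returns `Ok(())`.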

@@ -731,7 +729,6 @@
// got parsed frame
if decoded.item.is_some() {
self.read_remains = 0;
self.io.stop_timer();
self.flags.remove(
Flags::READ_KA_TIMEOUT | Flags::READ_HDRS_TIMEOUT | Flags::READ_PL_TIMEOUT,
);
@@ -741,16 +738,16 @@
} else if self.read_remains == 0 && decoded.remains == 0 {
// no new data, start keep-alive timer
if self.codec.keepalive() {
if !self.flags.contains(Flags::READ_KA_TIMEOUT) {
if !self.flags.contains(Flags::READ_KA_TIMEOUT)
&& self.config.keep_alive_enabled()
{
log::debug!(
"{}: Start keep-alive timer {:?}",
self.io.tag(),
self.config.keep_alive
);
self.flags.insert(Flags::READ_KA_TIMEOUT);
if self.config.keep_alive_enabled() {
self.io.start_timer(self.config.keep_alive);
}
self.io.start_timer(self.config.keep_alive);
}
} else {
self.io.close();
Expand All @@ -765,7 +762,8 @@ where

// we got new data but not enough to parse single frame
// start read timer
self.flags.remove(Flags::READ_KA_TIMEOUT);
self.flags
.remove(Flags::READ_KA_TIMEOUT | Flags::READ_PL_TIMEOUT);
self.flags.insert(Flags::READ_HDRS_TIMEOUT);

self.read_consumed = 0;
Expand All @@ -781,6 +779,8 @@ where
self.read_remains = decoded.remains as u32;
self.read_consumed += decoded.consumed as u32;
} else if let Some(ref cfg) = self.config.payload_read_rate {
log::debug!("{}: Start payload timer {:?}", self.io.tag(), cfg.timeout);

// start payload timer
self.flags.insert(Flags::READ_PL_TIMEOUT);

@@ -1298,6 +1298,8 @@
async fn test_payload_timeout() {
let mark = Arc::new(AtomicUsize::new(0));
let mark2 = mark.clone();
let err_mark = Arc::new(AtomicUsize::new(0));
let err_mark2 = err_mark.clone();

let (client, server) = Io::create();
client.remote_buffer_cap(4096);
@@ -1332,7 +1334,17 @@
Rc::new(DispatcherConfig::new(
config,
svc.into_service(),
DefaultControlService,
fn_service(move |msg: Control<_, _>| {
if let Control::ProtocolError(ref err) = msg {
if matches!(err.err(), ProtocolError::SlowPayloadTimeout) {
err_mark2.store(
err_mark2.load(Ordering::Relaxed) + 1,
Ordering::Relaxed,
);
}
}
async move { Ok::<_, io::Error>(msg.ack()) }
}),
)),
);
crate::rt::spawn(disp);
@@ -1347,5 +1359,6 @@
sleep(Millis(750)).await;
}
assert!(mark.load(Ordering::Relaxed) == 1536);
assert!(err_mark.load(Ordering::Relaxed) == 1);
}
}