diff --git a/CHANGELOG.md b/CHANGELOG.md index 601dd527..89705114 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,7 @@ # Unreleased +- Show TCP read bytes instead of body size + # 0.5.5 (2022-09-19) - Add colors to the tui view #64 diff --git a/src/client.rs b/src/client.rs index faece494..1b4875b9 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,9 +1,11 @@ use futures::future::FutureExt; use futures::StreamExt; use rand::prelude::*; +use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use thiserror::Error; +use crate::tcp_stream::CustomTcpStream; use crate::ConnectToEntry; #[derive(Debug, Clone)] @@ -128,6 +130,7 @@ impl ClientBuilder { rng: rand::rngs::StdRng::from_entropy(), }, client: None, + read_bytes_counter: Arc::new(AtomicUsize::new(0)), timeout: self.timeout, http_version: self.http_version, redirect_limit: self.redirect_limit, @@ -195,6 +198,7 @@ pub struct Client { body: Option<&'static [u8]>, dns: DNS, client: Option>, + read_bytes_counter: Arc, timeout: Option, redirect_limit: usize, disable_keepalive: bool, @@ -211,6 +215,7 @@ impl Client { } else { let stream = tokio::net::TcpStream::connect(addr).await?; stream.set_nodelay(true)?; + let stream = CustomTcpStream::new(stream, self.read_bytes_counter.clone()); // stream.set_keepalive(std::time::Duration::from_secs(1).into())?; let (send, conn) = hyper::client::conn::handshake(stream).await?; tokio::spawn(conn); @@ -225,6 +230,7 @@ impl Client { ) -> Result, ClientError> { let stream = tokio::net::TcpStream::connect(addr).await?; stream.set_nodelay(true)?; + let stream = CustomTcpStream::new(stream, self.read_bytes_counter.clone()); let connector = if self.insecure { native_tls::TlsConnector::builder() @@ -251,6 +257,7 @@ impl Client { ) -> Result, ClientError> { let stream = tokio::net::TcpStream::connect(addr).await?; stream.set_nodelay(true)?; + let stream = CustomTcpStream::new(stream, self.read_bytes_counter.clone()); let mut root_cert_store = rustls::RootCertStore::empty(); for cert in rustls_native_certs::load_native_certs()? { @@ -336,14 +343,11 @@ impl Client { let (parts, mut stream) = res.into_parts(); let mut status = parts.status; - let mut len_sum = 0; - while let Some(chunk) = stream.next().await { - len_sum += chunk?.len(); - } + while stream.next().await.is_some() {} if self.redirect_limit != 0 { if let Some(location) = parts.headers.get("Location") { - let (send_request_redirect, new_status, len) = self + let (send_request_redirect, new_status) = self .redirect( send_request, &self.url.clone(), @@ -354,7 +358,6 @@ impl Client { send_request = send_request_redirect; status = new_status; - len_sum = len; } } @@ -364,7 +367,7 @@ impl Client { start, end, status, - len_bytes: len_sum, + len_bytes: self.read_bytes_counter.swap(0, Ordering::Relaxed), connection_time, }; @@ -404,7 +407,6 @@ impl Client { ( hyper::client::conn::SendRequest, http::StatusCode, - usize, ), ClientError, >, @@ -451,28 +453,25 @@ impl Client { )?, ); } + self.read_bytes_counter.store(0, Ordering::Relaxed); let res = send_request.send_request(request).await?; let (parts, mut stream) = res.into_parts(); let mut status = parts.status; - let mut len_sum = 0; - while let Some(chunk) = stream.next().await { - len_sum += chunk?.len(); - } + while stream.next().await.is_some() {} if let Some(location) = parts.headers.get("Location") { - let (send_request_redirect, new_status, len) = self + let (send_request_redirect, new_status) = self .redirect(send_request, &url, location, limit - 1) .await?; send_request = send_request_redirect; status = new_status; - len_sum = len; } if let Some(send_request_base) = send_request_base { - Ok((send_request_base, status, len_sum)) + Ok((send_request_base, status)) } else { - Ok((send_request, status, len_sum)) + Ok((send_request, status)) } } .boxed() @@ -546,7 +545,6 @@ pub async fn work( n_tasks: usize, n_workers: usize, ) { - use std::sync::atomic::{AtomicUsize, Ordering}; let counter = Arc::new(AtomicUsize::new(0)); let futures = (0..n_workers) diff --git a/src/main.rs b/src/main.rs index 0a9e1457..9a301d82 100644 --- a/src/main.rs +++ b/src/main.rs @@ -11,6 +11,7 @@ mod client; mod histogram; mod monitor; mod printer; +mod tcp_stream; mod timescale; use client::{ClientError, RequestResult}; diff --git a/src/tcp_stream.rs b/src/tcp_stream.rs new file mode 100644 index 00000000..3e3e0bcc --- /dev/null +++ b/src/tcp_stream.rs @@ -0,0 +1,69 @@ +// Ported from https://github.com/lnx-search/rewrk/pull/6 +use std::pin::Pin; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; +use std::task::{Context, Poll}; +use tokio::io::ReadBuf; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::net::TcpStream; + +use std::io::{IoSlice, Result}; + +pub struct CustomTcpStream { + inner: TcpStream, + counter: Arc, +} + +impl CustomTcpStream { + pub fn new(stream: TcpStream, counter: Arc) -> Self { + Self { + inner: stream, + counter, + } + } +} + +impl AsyncRead for CustomTcpStream { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let result = Pin::new(&mut self.inner).poll_read(cx, buf); + + self.counter + .fetch_add(buf.filled().len(), Ordering::Relaxed); + + result + } +} + +impl AsyncWrite for CustomTcpStream { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.inner).poll_write(cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_shutdown(cx) + } + + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + Pin::new(&mut self.inner).poll_write_vectored(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + self.inner.is_write_vectored() + } +}