diff --git a/.config/nats.dic b/.config/nats.dic index e2e2b0174..219428adb 100644 --- a/.config/nats.dic +++ b/.config/nats.dic @@ -141,3 +141,5 @@ filter_subjects rollup IoT ObjectMeta +128k +ObjectMetadata diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 000000000..eb5a316cb --- /dev/null +++ b/.dockerignore @@ -0,0 +1 @@ +target diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index cb5d75e00..095b68931 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -66,7 +66,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v2 with: - go-version: '1.19' + go-version: '1.20' - name: Install nats-server run: go install github.com/nats-io/nats-server/v2@${{ matrix.nats_server }} @@ -206,7 +206,7 @@ jobs: - name: Set up go uses: actions/setup-go@v2 with: - go-version: '1.19' + go-version: '1.20' - name: Set up rust run: | diff --git a/async-nats/Cargo.toml b/async-nats/Cargo.toml index 32679f76b..284932599 100644 --- a/async-nats/Cargo.toml +++ b/async-nats/Cargo.toml @@ -16,8 +16,8 @@ categories = ["network-programming", "api-bindings"] [dependencies] memchr = "2.4" bytes = { version = "1.4.0", features = ["serde"] } -futures = { version = "0.3.28", default-features = false, features = ["std", "async-await"] } -nkeys = "0.3.0" +futures = { version = "0.3.28", default-features = false, features = ["std"] } +nkeys = "0.3.1" once_cell = "1.18.0" regex = "1.9.1" serde = { version = "1.0.184", features = ["derive"] } @@ -25,7 +25,6 @@ serde_json = "1.0.104" serde_repr = "0.1.16" http = "0.2.9" tokio = { version = "1.29.0", features = ["macros", "rt", "fs", "net", "sync", "time", "io-util"] } -itoa = "1" url = { version = "2"} tokio-rustls = "0.24" rustls-pemfile = "1.0.2" @@ -41,13 +40,18 @@ ring = "0.16" rand = "0.8" webpki = { package = "rustls-webpki", version = "0.101.2", features = ["alloc", "std"] } +# for -Z minimal-versions +rustls = "0.21.6" # used by tokio-rustls 0.24.0 + [dev-dependencies] criterion = { version = "0.5", features = ["async_tokio"]} nats-server = { path = "../nats-server" } rand = "0.8" tokio = { version = "1.25.0", features = ["rt-multi-thread"] } +futures = { version = "0.3.28", default-features = false, features = ["std", "async-await"] } tracing-subscriber = "0.3" async-nats = {path = ".", features = ["experimental"]} +reqwest = "0.11.18" [features] @@ -55,9 +59,14 @@ service = [] experimental = ["service"] "server_2_10" = [] slow_tests = [] +compatibility_tests = [] [[bench]] name = "main" harness = false lto = true + +[package.metadata.docs.rs] +all-features = true +rustdoc-args = ["--cfg", "docsrs"] diff --git a/async-nats/src/client.rs b/async-nats/src/client.rs index 94a15edf9..57f1449e0 100644 --- a/async-nats/src/client.rs +++ b/async-nats/src/client.rs @@ -346,9 +346,6 @@ impl Client { } None => self.publish_with_reply(subject, inbox, payload).await?, } - self.flush() - .await - .map_err(|err| RequestError::with_source(RequestErrorKind::Other, err))?; let request = match timeout { Some(timeout) => { tokio::time::timeout(timeout, sub.next()) @@ -517,12 +514,11 @@ impl Client { pub async fn flush(&self) -> Result<(), FlushError> { let (tx, rx) = tokio::sync::oneshot::channel(); self.sender - .send(Command::Flush { result: tx }) + .send(Command::Flush { observer: tx }) .await .map_err(|err| FlushError::with_source(FlushErrorKind::SendError, err))?; - // first question mark is an error from rx itself, second for error from flush. + rx.await - .map_err(|err| FlushError::with_source(FlushErrorKind::FlushError, err))? .map_err(|err| FlushError::with_source(FlushErrorKind::FlushError, err))?; Ok(()) } diff --git a/async-nats/src/connection.rs b/async-nats/src/connection.rs index 2790de2a1..de6a092fb 100644 --- a/async-nats/src/connection.rs +++ b/async-nats/src/connection.rs @@ -13,19 +13,30 @@ //! This module provides a connection implementation for communicating with a NATS server. -use std::fmt::Display; +use std::collections::VecDeque; +use std::fmt::{self, Display, Write as _}; +use std::future::{self, Future}; +use std::io::IoSlice; +use std::pin::Pin; use std::str::{self, FromStr}; +use std::task::{Context, Poll}; -use tokio::io::{AsyncRead, AsyncWriteExt}; -use tokio::io::{AsyncReadExt, AsyncWrite}; - -use bytes::{Buf, BytesMut}; -use tokio::io; +use bytes::{Buf, Bytes, BytesMut}; +use tokio::io::{self, AsyncRead, AsyncReadExt, AsyncWrite}; use crate::header::{HeaderMap, HeaderName, IntoHeaderValue}; use crate::status::StatusCode; use crate::{ClientOp, ServerError, ServerOp}; +/// Soft limit for the amount of bytes in [`Connection::write_buf`] +/// and [`Connection::flattened_writes`]. +const SOFT_WRITE_BUF_LIMIT: usize = 65535; +/// How big a single buffer must be before it's written separately +/// instead of being flattened. +const WRITE_FLATTEN_THRESHOLD: usize = 4096; +/// How many buffers to write in a single vectored write call. +const WRITE_VECTORED_CHUNKS: usize = 64; + /// Supertrait enabling trait object for containing both TLS and non TLS `TcpStream` connection. pub(crate) trait AsyncReadWrite: AsyncWrite + AsyncRead + Send + Unpin {} @@ -53,57 +64,82 @@ impl Display for State { /// A framed connection pub(crate) struct Connection { pub(crate) stream: Box, - pub(crate) buffer: BytesMut, + read_buf: BytesMut, + write_buf: VecDeque, + write_buf_len: usize, + flattened_writes: BytesMut, + can_flush: bool, } /// Internal representation of the connection. /// Holds connection with NATS Server and communicates with `Client` via channels. impl Connection { + pub(crate) fn new(stream: Box, read_buffer_capacity: usize) -> Self { + Self { + stream, + read_buf: BytesMut::with_capacity(read_buffer_capacity), + write_buf: VecDeque::new(), + write_buf_len: 0, + flattened_writes: BytesMut::new(), + can_flush: false, + } + } + + /// Returns `true` if no more calls to [`Self::enqueue_write_op`] _should_ be made. + pub(crate) fn is_write_buf_full(&self) -> bool { + self.write_buf_len >= SOFT_WRITE_BUF_LIMIT + } + + /// Returns `true` if [`Self::poll_flush`] should be polled. + pub(crate) fn should_flush(&self) -> bool { + self.can_flush && self.write_buf.is_empty() && self.flattened_writes.is_empty() + } + /// Attempts to read a server operation from the read buffer. /// Returns `None` if there is not enough data to parse an entire operation. pub(crate) fn try_read_op(&mut self) -> Result, io::Error> { - let len = match memchr::memmem::find(&self.buffer, b"\r\n") { + let len = match memchr::memmem::find(&self.read_buf, b"\r\n") { Some(len) => len, None => return Ok(None), }; - if self.buffer.starts_with(b"+OK") { - self.buffer.advance(len + 2); + if self.read_buf.starts_with(b"+OK") { + self.read_buf.advance(len + 2); return Ok(Some(ServerOp::Ok)); } - if self.buffer.starts_with(b"PING") { - self.buffer.advance(len + 2); + if self.read_buf.starts_with(b"PING") { + self.read_buf.advance(len + 2); return Ok(Some(ServerOp::Ping)); } - if self.buffer.starts_with(b"PONG") { - self.buffer.advance(len + 2); + if self.read_buf.starts_with(b"PONG") { + self.read_buf.advance(len + 2); return Ok(Some(ServerOp::Pong)); } - if self.buffer.starts_with(b"-ERR") { - let description = str::from_utf8(&self.buffer[5..len]) + if self.read_buf.starts_with(b"-ERR") { + let description = str::from_utf8(&self.read_buf[5..len]) .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))? .trim_matches('\'') .to_owned(); - self.buffer.advance(len + 2); + self.read_buf.advance(len + 2); return Ok(Some(ServerOp::Error(ServerError::new(description)))); } - if self.buffer.starts_with(b"INFO ") { - let info = serde_json::from_slice(&self.buffer[4..len]) + if self.read_buf.starts_with(b"INFO ") { + let info = serde_json::from_slice(&self.read_buf[4..len]) .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?; - self.buffer.advance(len + 2); + self.read_buf.advance(len + 2); return Ok(Some(ServerOp::Info(Box::new(info)))); } - if self.buffer.starts_with(b"MSG ") { - let line = str::from_utf8(&self.buffer[4..len]).unwrap(); + if self.read_buf.starts_with(b"MSG ") { + let line = str::from_utf8(&self.read_buf[4..len]).unwrap(); let mut args = line.split(' ').filter(|s| !s.is_empty()); // Parse the operation syntax: MSG [reply-to] <#bytes> @@ -139,16 +175,16 @@ impl Connection { // Return early without advancing if there is not enough data read the entire // message - if len + payload_len + 4 > self.buffer.remaining() { + if len + payload_len + 4 > self.read_buf.remaining() { return Ok(None); } let subject = subject.to_owned(); let reply_to = reply_to.map(ToOwned::to_owned); - self.buffer.advance(len + 2); - let payload = self.buffer.split_to(payload_len).freeze(); - self.buffer.advance(2); + self.read_buf.advance(len + 2); + let payload = self.read_buf.split_to(payload_len).freeze(); + self.read_buf.advance(2); let length = payload_len + reply_to.as_ref().map(|reply| reply.len()).unwrap_or(0) @@ -165,9 +201,9 @@ impl Connection { })); } - if self.buffer.starts_with(b"HMSG ") { + if self.read_buf.starts_with(b"HMSG ") { // Extract whitespace-delimited arguments that come after "HMSG". - let line = std::str::from_utf8(&self.buffer[5..len]).unwrap(); + let line = std::str::from_utf8(&self.read_buf[5..len]).unwrap(); let mut args = line.split_whitespace().filter(|s| !s.is_empty()); // [reply-to] <# header bytes><# total bytes> @@ -237,14 +273,14 @@ impl Connection { )); } - if len + total_len + 4 > self.buffer.remaining() { + if len + total_len + 4 > self.read_buf.remaining() { return Ok(None); } - self.buffer.advance(len + 2); - let header = self.buffer.split_to(header_len); - let payload = self.buffer.split_to(total_len - header_len).freeze(); - self.buffer.advance(2); + self.read_buf.advance(len + 2); + let header = self.read_buf.split_to(header_len); + let payload = self.read_buf.split_to(total_len - header_len).freeze(); + self.read_buf.advance(2); let mut lines = std::str::from_utf8(&header) .map_err(|_| { @@ -321,7 +357,7 @@ impl Connection { })); } - let buffer = self.buffer.split_to(len + 2); + let buffer = self.read_buf.split_to(len + 2); let line = str::from_utf8(&buffer).map_err(|_| { io::Error::new(io::ErrorKind::InvalidInput, "unable to parse unknown input") })?; @@ -332,35 +368,62 @@ impl Connection { )) } + pub(crate) fn read_op(&mut self) -> impl Future>> + '_ { + future::poll_fn(|cx| self.poll_read_op(cx)) + } + // TODO: do we want an custom error here? /// Read a server operation from read buffer. /// Blocks until an operation ca be parsed. - pub(crate) async fn read_op(&mut self) -> Result, io::Error> { + pub(crate) fn poll_read_op( + &mut self, + cx: &mut Context<'_>, + ) -> Poll>> { loop { if let Some(op) = self.try_read_op()? { - return Ok(Some(op)); + return Poll::Ready(Ok(Some(op))); } - if 0 == self.stream.read_buf(&mut self.buffer).await? { - if self.buffer.is_empty() { - return Ok(None); - } else { - return Err(io::Error::new(io::ErrorKind::ConnectionReset, "")); - } - } + let read_buf = self.stream.read_buf(&mut self.read_buf); + tokio::pin!(read_buf); + return match read_buf.poll(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(Ok(0)) if self.read_buf.is_empty() => Poll::Ready(Ok(None)), + Poll::Ready(Ok(0)) => Poll::Ready(Err(io::ErrorKind::ConnectionReset.into())), + Poll::Ready(Ok(_n)) => continue, + Poll::Ready(Err(err)) => Poll::Ready(Err(err)), + }; + } + } + + pub(crate) async fn easy_write_and_flush<'a>( + &mut self, + items: impl Iterator, + ) -> io::Result<()> { + for item in items { + self.enqueue_write_op(item); } + + future::poll_fn(|cx| self.poll_write(cx)).await?; + future::poll_fn(|cx| self.poll_flush(cx)).await?; + Ok(()) } /// Writes a client operation to the write buffer. - pub(crate) async fn write_op<'a>(&mut self, item: &'a ClientOp) -> Result<(), io::Error> { + pub(crate) fn enqueue_write_op(&mut self, item: &ClientOp) { + macro_rules! small_write { + ($dst:expr) => { + write!(self.small_write(), $dst).expect("do small write to Connection"); + }; + } + match item { ClientOp::Connect(connect_info) => { - let op = format!( - "CONNECT {}\r\n", - serde_json::to_string(&connect_info) - .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))? - ); - self.stream.write_all(op.as_bytes()).await?; + let json = serde_json::to_vec(&connect_info).expect("serialize `ConnectInfo`"); + + self.write("CONNECT "); + self.write(json); + self.write("\r\n"); } ClientOp::Publish { subject, @@ -368,99 +431,230 @@ impl Connection { respond, headers, } => { - match headers.as_ref() { - Some(headers) if !headers.is_empty() => { - self.stream.write_all(b"HPUB ").await?; - } - _ => { - self.stream.write_all(b"PUB ").await?; - } - } + let verb = match headers.as_ref() { + Some(headers) if !headers.is_empty() => "HPUB", + _ => "PUB", + }; - self.stream.write_all(subject.as_bytes()).await?; - self.stream.write_all(b" ").await?; + small_write!("{verb} {subject} "); if let Some(respond) = respond { - self.stream.write_all(respond.as_bytes()).await?; - self.stream.write_all(b" ").await?; + small_write!("{respond} "); } match headers { Some(headers) if !headers.is_empty() => { let headers = headers.to_bytes(); - let mut header_len_buf = itoa::Buffer::new(); - self.stream - .write_all(header_len_buf.format(headers.len()).as_bytes()) - .await?; - - self.stream.write_all(b" ").await?; - - let mut total_len_buf = itoa::Buffer::new(); - self.stream - .write_all( - total_len_buf - .format(headers.len() + payload.len()) - .as_bytes(), - ) - .await?; - - self.stream.write_all(b"\r\n").await?; - self.stream.write_all(&headers).await?; + let headers_len = headers.len(); + let total_len = headers_len + payload.len(); + small_write!("{headers_len} {total_len}\r\n"); + self.write(headers); } _ => { - let mut len_buf = itoa::Buffer::new(); - self.stream - .write_all(len_buf.format(payload.len()).as_bytes()) - .await?; - self.stream.write_all(b"\r\n").await?; + let payload_len = payload.len(); + small_write!("{payload_len}\r\n"); } } - self.stream.write_all(payload).await?; - self.stream.write_all(b"\r\n").await?; + self.write(Bytes::clone(payload)); + self.write("\r\n"); } ClientOp::Subscribe { sid, subject, queue_group, - } => { - self.stream.write_all(b"SUB ").await?; - self.stream.write_all(subject.as_bytes()).await?; - if let Some(queue_group) = queue_group { - self.stream - .write_all(format!(" {queue_group}").as_bytes()) - .await?; + } => match queue_group { + Some(queue_group) => { + small_write!("SUB {subject} {queue_group} {sid}\r\n"); } - self.stream - .write_all(format!(" {sid}\r\n").as_bytes()) - .await?; - } + None => { + small_write!("SUB {subject} {sid}\r\n"); + } + }, - ClientOp::Unsubscribe { sid, max } => { - self.stream.write_all(b"UNSUB ").await?; - self.stream.write_all(format!("{sid}").as_bytes()).await?; - if let Some(max) = max { - self.stream.write_all(format!(" {max}").as_bytes()).await?; + ClientOp::Unsubscribe { sid, max } => match max { + Some(max) => { + small_write!("UNSUB {sid} {max}\r\n"); } - self.stream.write_all(b"\r\n").await?; - } + None => { + small_write!("UNSUB {sid}\r\n"); + } + }, ClientOp::Ping => { - self.stream.write_all(b"PING\r\n").await?; - self.stream.flush().await?; + self.write("PING\r\n"); } ClientOp::Pong => { - self.stream.write_all(b"PONG\r\n").await?; + self.write("PONG\r\n"); } } + } - Ok(()) + /// Write the internal buffers into the write stream + /// + /// Returns one of the following: + /// + /// * `Poll::Pending` means that we weren't able to fully empty + /// the internal buffers. Compared to [`AsyncWrite::poll_write`], + /// this implementation may do a partial write before yielding. + /// * `Poll::Ready(Ok())` means that the internal write buffers have + /// been emptied or were already empty. + /// * `Poll::Ready(Err(err))` means that writing to the stream failed. + /// Compared to [`AsyncWrite::poll_write`], this implementation + /// may do a partial write before failing. + pub(crate) fn poll_write(&mut self, cx: &mut Context<'_>) -> Poll> { + if !self.stream.is_write_vectored() { + self.poll_write_sequential(cx) + } else { + self.poll_write_vectored(cx) + } + } + + /// Write the internal buffers into the write stream using sequential write operations + /// + /// Writes one chunk at a time. Less efficient. + fn poll_write_sequential(&mut self, cx: &mut Context<'_>) -> Poll> { + loop { + let buf = match self.write_buf.front() { + Some(buf) => &**buf, + None if !self.flattened_writes.is_empty() => &self.flattened_writes, + None => return Poll::Ready(Ok(())), + }; + + debug_assert!(!buf.is_empty()); + + match Pin::new(&mut self.stream).poll_write(cx, buf) { + Poll::Pending => return Poll::Pending, + Poll::Ready(Ok(n)) => { + self.write_buf_len -= n; + self.can_flush = true; + + match self.write_buf.front_mut() { + Some(buf) if n < buf.len() => { + buf.advance(n); + } + Some(_buf) => { + self.write_buf.pop_front(); + } + None => { + self.flattened_writes.advance(n); + } + } + continue; + } + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + } + } + } + + /// Write the internal buffers into the write stream using vectored write operations + /// + /// Writes [`WRITE_VECTORED_CHUNKS`] at a time. More efficient _if_ + /// the underlying writer supports it. + fn poll_write_vectored(&mut self, cx: &mut Context<'_>) -> Poll> { + 'outer: loop { + let mut writes = [IoSlice::new(b""); WRITE_VECTORED_CHUNKS]; + let mut writes_len = 0; + + self.write_buf + .iter() + .take(WRITE_VECTORED_CHUNKS) + .enumerate() + .for_each(|(i, buf)| { + writes[i] = IoSlice::new(buf); + writes_len += 1; + }); + + if writes_len < WRITE_VECTORED_CHUNKS && !self.flattened_writes.is_empty() { + writes[writes_len] = IoSlice::new(&self.flattened_writes); + writes_len += 1; + } + + if writes_len == 0 { + return Poll::Ready(Ok(())); + } + + match Pin::new(&mut self.stream).poll_write_vectored(cx, &writes[..writes_len]) { + Poll::Pending => return Poll::Pending, + Poll::Ready(Ok(mut n)) => { + self.write_buf_len -= n; + self.can_flush = true; + + while let Some(buf) = self.write_buf.front_mut() { + if n < buf.len() { + buf.advance(n); + continue 'outer; + } + + n -= buf.len(); + self.write_buf.pop_front(); + } + + self.flattened_writes.advance(n); + } + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + } + } + } + + /// Write `buf` into the writes buffer + /// + /// If `buf` is smaller than [`WRITE_FLATTEN_THRESHOLD`] + /// flattens it, otherwise appends it to the chunks queue. + /// + /// Empty `buf`s are a no-op. + fn write(&mut self, buf: impl Into) { + let buf = buf.into(); + if buf.is_empty() { + return; + } + + self.write_buf_len += buf.len(); + if buf.len() < WRITE_FLATTEN_THRESHOLD { + self.flattened_writes.extend_from_slice(&buf); + } else { + if !self.flattened_writes.is_empty() { + let buf = self.flattened_writes.split().freeze(); + self.write_buf.push_back(buf); + } + + self.write_buf.push_back(buf); + } + } + + /// Obtain an [`fmt::Write`]r for the small writes buffer. + fn small_write(&mut self) -> impl fmt::Write + '_ { + struct Writer<'a> { + this: &'a mut Connection, + } + + impl<'a> fmt::Write for Writer<'a> { + fn write_str(&mut self, s: &str) -> fmt::Result { + self.this.write_buf_len += s.len(); + self.this.flattened_writes.write_str(s) + } + } + + Writer { this: self } } /// Flush the write buffer, sending all pending data down the current write stream. - pub(crate) async fn flush(&mut self) -> Result<(), io::Error> { - self.stream.flush().await + /// + /// no-op if the write stream didn't need to be flushed. + pub(crate) fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll> { + if !self.can_flush { + return Poll::Ready(Ok(())); + } + + match Pin::new(&mut self.stream).poll_flush(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(Ok(())) => { + self.can_flush = false; + Poll::Ready(Ok(())) + } + Poll::Ready(Err(err)) => Poll::Ready(Err(err)), + } } } @@ -468,16 +662,12 @@ impl Connection { mod read_op { use super::Connection; use crate::{HeaderMap, ServerError, ServerInfo, ServerOp, StatusCode}; - use bytes::BytesMut; use tokio::io::{self, AsyncWriteExt}; #[tokio::test] async fn ok() { let (stream, mut server) = io::duplex(128); - let mut connection = Connection { - stream: Box::new(stream), - buffer: BytesMut::new(), - }; + let mut connection = Connection::new(Box::new(stream), 0); server.write_all(b"+OK\r\n").await.unwrap(); let result = connection.read_op().await.unwrap(); @@ -487,10 +677,7 @@ mod read_op { #[tokio::test] async fn ping() { let (stream, mut server) = io::duplex(128); - let mut connection = Connection { - stream: Box::new(stream), - buffer: BytesMut::new(), - }; + let mut connection = Connection::new(Box::new(stream), 0); server.write_all(b"PING\r\n").await.unwrap(); let result = connection.read_op().await.unwrap(); @@ -500,10 +687,7 @@ mod read_op { #[tokio::test] async fn pong() { let (stream, mut server) = io::duplex(128); - let mut connection = Connection { - stream: Box::new(stream), - buffer: BytesMut::new(), - }; + let mut connection = Connection::new(Box::new(stream), 0); server.write_all(b"PONG\r\n").await.unwrap(); let result = connection.read_op().await.unwrap(); @@ -513,10 +697,7 @@ mod read_op { #[tokio::test] async fn info() { let (stream, mut server) = io::duplex(128); - let mut connection = Connection { - stream: Box::new(stream), - buffer: BytesMut::new(), - }; + let mut connection = Connection::new(Box::new(stream), 0); server.write_all(b"INFO {}\r\n").await.unwrap(); server.flush().await.unwrap(); @@ -543,10 +724,7 @@ mod read_op { #[tokio::test] async fn error() { let (stream, mut server) = io::duplex(128); - let mut connection = Connection { - stream: Box::new(stream), - buffer: BytesMut::new(), - }; + let mut connection = Connection::new(Box::new(stream), 0); server.write_all(b"INFO {}\r\n").await.unwrap(); let result = connection.read_op().await.unwrap(); @@ -568,10 +746,7 @@ mod read_op { #[tokio::test] async fn message() { let (stream, mut server) = io::duplex(128); - let mut connection = Connection { - stream: Box::new(stream), - buffer: BytesMut::new(), - }; + let mut connection = Connection::new(Box::new(stream), 0); server .write_all(b"MSG FOO.BAR 9 11\r\nHello World\r\n") @@ -718,10 +893,7 @@ mod read_op { #[tokio::test] async fn unknown() { let (stream, mut server) = io::duplex(128); - let mut connection = Connection { - stream: Box::new(stream), - buffer: BytesMut::new(), - }; + let mut connection = Connection::new(Box::new(stream), 0); server.write_all(b"ONE\r\n").await.unwrap(); connection.read_op().await.unwrap_err(); @@ -773,27 +945,25 @@ mod read_op { mod write_op { use super::Connection; use crate::{ClientOp, ConnectInfo, HeaderMap, Protocol}; - use bytes::BytesMut; use tokio::io::{self, AsyncBufReadExt, BufReader}; #[tokio::test] async fn publish() { let (stream, server) = io::duplex(128); - let mut connection = Connection { - stream: Box::new(stream), - buffer: BytesMut::new(), - }; + let mut connection = Connection::new(Box::new(stream), 0); connection - .write_op(&ClientOp::Publish { - subject: "FOO.BAR".into(), - payload: "Hello World".into(), - respond: None, - headers: None, - }) + .easy_write_and_flush( + [ClientOp::Publish { + subject: "FOO.BAR".into(), + payload: "Hello World".into(), + respond: None, + headers: None, + }] + .iter(), + ) .await .unwrap(); - connection.flush().await.unwrap(); let mut buffer = String::new(); let mut reader = BufReader::new(server); @@ -802,15 +972,17 @@ mod write_op { assert_eq!(buffer, "PUB FOO.BAR 11\r\nHello World\r\n"); connection - .write_op(&ClientOp::Publish { - subject: "FOO.BAR".into(), - payload: "Hello World".into(), - respond: Some("INBOX.67".into()), - headers: None, - }) + .easy_write_and_flush( + [ClientOp::Publish { + subject: "FOO.BAR".into(), + payload: "Hello World".into(), + respond: Some("INBOX.67".into()), + headers: None, + }] + .iter(), + ) .await .unwrap(); - connection.flush().await.unwrap(); buffer.clear(); reader.read_line(&mut buffer).await.unwrap(); @@ -818,18 +990,20 @@ mod write_op { assert_eq!(buffer, "PUB FOO.BAR INBOX.67 11\r\nHello World\r\n"); connection - .write_op(&ClientOp::Publish { - subject: "FOO.BAR".into(), - payload: "Hello World".into(), - respond: Some("INBOX.67".into()), - headers: Some(HeaderMap::from_iter([( - "Header".parse().unwrap(), - "X".parse().unwrap(), - )])), - }) + .easy_write_and_flush( + [ClientOp::Publish { + subject: "FOO.BAR".into(), + payload: "Hello World".into(), + respond: Some("INBOX.67".into()), + headers: Some(HeaderMap::from_iter([( + "Header".parse().unwrap(), + "X".parse().unwrap(), + )])), + }] + .iter(), + ) .await .unwrap(); - connection.flush().await.unwrap(); buffer.clear(); reader.read_line(&mut buffer).await.unwrap(); @@ -845,20 +1019,19 @@ mod write_op { #[tokio::test] async fn subscribe() { let (stream, server) = io::duplex(128); - let mut connection = Connection { - stream: Box::new(stream), - buffer: BytesMut::new(), - }; + let mut connection = Connection::new(Box::new(stream), 0); connection - .write_op(&ClientOp::Subscribe { - sid: 11, - subject: "FOO.BAR".into(), - queue_group: None, - }) + .easy_write_and_flush( + [ClientOp::Subscribe { + sid: 11, + subject: "FOO.BAR".into(), + queue_group: None, + }] + .iter(), + ) .await .unwrap(); - connection.flush().await.unwrap(); let mut buffer = String::new(); let mut reader = BufReader::new(server); @@ -866,14 +1039,16 @@ mod write_op { assert_eq!(buffer, "SUB FOO.BAR 11\r\n"); connection - .write_op(&ClientOp::Subscribe { - sid: 11, - subject: "FOO.BAR".into(), - queue_group: Some("QUEUE.GROUP".into()), - }) + .easy_write_and_flush( + [ClientOp::Subscribe { + sid: 11, + subject: "FOO.BAR".into(), + queue_group: Some("QUEUE.GROUP".into()), + }] + .iter(), + ) .await .unwrap(); - connection.flush().await.unwrap(); buffer.clear(); reader.read_line(&mut buffer).await.unwrap(); @@ -883,16 +1058,12 @@ mod write_op { #[tokio::test] async fn unsubscribe() { let (stream, server) = io::duplex(128); - let mut connection = Connection { - stream: Box::new(stream), - buffer: BytesMut::new(), - }; + let mut connection = Connection::new(Box::new(stream), 0); connection - .write_op(&ClientOp::Unsubscribe { sid: 11, max: None }) + .easy_write_and_flush([ClientOp::Unsubscribe { sid: 11, max: None }].iter()) .await .unwrap(); - connection.flush().await.unwrap(); let mut buffer = String::new(); let mut reader = BufReader::new(server); @@ -900,13 +1071,15 @@ mod write_op { assert_eq!(buffer, "UNSUB 11\r\n"); connection - .write_op(&ClientOp::Unsubscribe { - sid: 11, - max: Some(2), - }) + .easy_write_and_flush( + [ClientOp::Unsubscribe { + sid: 11, + max: Some(2), + }] + .iter(), + ) .await .unwrap(); - connection.flush().await.unwrap(); buffer.clear(); reader.read_line(&mut buffer).await.unwrap(); @@ -916,16 +1089,15 @@ mod write_op { #[tokio::test] async fn ping() { let (stream, server) = io::duplex(128); - let mut connection = Connection { - stream: Box::new(stream), - buffer: BytesMut::new(), - }; + let mut connection = Connection::new(Box::new(stream), 0); let mut reader = BufReader::new(server); let mut buffer = String::new(); - connection.write_op(&ClientOp::Ping).await.unwrap(); - connection.flush().await.unwrap(); + connection + .easy_write_and_flush([ClientOp::Ping].iter()) + .await + .unwrap(); reader.read_line(&mut buffer).await.unwrap(); @@ -935,16 +1107,15 @@ mod write_op { #[tokio::test] async fn pong() { let (stream, server) = io::duplex(128); - let mut connection = Connection { - stream: Box::new(stream), - buffer: BytesMut::new(), - }; + let mut connection = Connection::new(Box::new(stream), 0); let mut reader = BufReader::new(server); let mut buffer = String::new(); - connection.write_op(&ClientOp::Pong).await.unwrap(); - connection.flush().await.unwrap(); + connection + .easy_write_and_flush([ClientOp::Pong].iter()) + .await + .unwrap(); reader.read_line(&mut buffer).await.unwrap(); @@ -954,36 +1125,35 @@ mod write_op { #[tokio::test] async fn connect() { let (stream, server) = io::duplex(1024); - let mut connection = Connection { - stream: Box::new(stream), - buffer: BytesMut::new(), - }; + let mut connection = Connection::new(Box::new(stream), 0); let mut reader = BufReader::new(server); let mut buffer = String::new(); connection - .write_op(&ClientOp::Connect(ConnectInfo { - verbose: false, - pedantic: false, - user_jwt: None, - nkey: None, - signature: None, - name: None, - echo: false, - lang: "Rust".into(), - version: "1.0.0".into(), - protocol: Protocol::Dynamic, - tls_required: false, - user: None, - pass: None, - auth_token: None, - headers: false, - no_responders: false, - })) + .easy_write_and_flush( + [ClientOp::Connect(ConnectInfo { + verbose: false, + pedantic: false, + user_jwt: None, + nkey: None, + signature: None, + name: None, + echo: false, + lang: "Rust".into(), + version: "1.0.0".into(), + protocol: Protocol::Dynamic, + tls_required: false, + user: None, + pass: None, + auth_token: None, + headers: false, + no_responders: false, + })] + .iter(), + ) .await .unwrap(); - connection.flush().await.unwrap(); reader.read_line(&mut buffer).await.unwrap(); assert_eq!( diff --git a/async-nats/src/connector.rs b/async-nats/src/connector.rs index 6a285d979..d2a16dc4f 100644 --- a/async-nats/src/connector.rs +++ b/async-nats/src/connector.rs @@ -33,7 +33,6 @@ use crate::LANG; use crate::VERSION; use base64::engine::general_purpose::URL_SAFE_NO_PAD; use base64::engine::Engine; -use bytes::BytesMut; use rand::seq::SliceRandom; use rand::thread_rng; use std::cmp; @@ -41,7 +40,6 @@ use std::io; use std::path::PathBuf; use std::sync::Arc; use std::time::Duration; -use tokio::io::BufWriter; use tokio::io::ErrorKind; use tokio::net::TcpStream; use tokio::time::sleep; @@ -102,10 +100,10 @@ impl Connector { }) } - pub(crate) async fn connect(&mut self) -> Result<(ServerInfo, Connection), io::Error> { + pub(crate) async fn connect(&mut self) -> (ServerInfo, Connection) { loop { match self.try_connect().await { - Ok(inner) => return Ok(inner), + Ok(inner) => return inner, Err(error) => { self.events_tx .send(Event::ClientError(ClientError::Other(error.to_string()))) @@ -238,10 +236,10 @@ impl Connector { } connection - .write_op(&ClientOp::Connect(connect_info)) + .easy_write_and_flush( + [ClientOp::Connect(connect_info), ClientOp::Ping].iter(), + ) .await?; - connection.write_op(&ClientOp::Ping).await?; - connection.flush().await?; match connection.read_op().await? { Some(ServerOp::Error(err)) => match err { @@ -296,10 +294,10 @@ impl Connector { tcp_stream.set_nodelay(true)?; - let mut connection = Connection { - stream: Box::new(BufWriter::new(tcp_stream)), - buffer: BytesMut::with_capacity(self.options.read_buffer_capacity.into()), - }; + let mut connection = Connection::new( + Box::new(tcp_stream), + self.options.read_buffer_capacity.into(), + ); let op = connection.read_op().await?; let info = match op { @@ -336,10 +334,10 @@ impl Connector { let domain = rustls::ServerName::try_from(tls_host) .map_err(|err| ConnectError::with_source(crate::ConnectErrorKind::Tls, err))?; - connection = Connection { - stream: Box::new(tls_connector.connect(domain, connection.stream).await?), - buffer: BytesMut::new(), - }; + connection = Connection::new( + Box::new(tls_connector.connect(domain, connection.stream).await?), + 0, + ); }; Ok((*info, connection)) diff --git a/async-nats/src/jetstream/consumer/pull.rs b/async-nats/src/jetstream/consumer/pull.rs index 9b347b6f5..18c38eb78 100644 --- a/async-nats/src/jetstream/consumer/pull.rs +++ b/async-nats/src/jetstream/consumer/pull.rs @@ -153,11 +153,6 @@ impl Consumer { .publish_with_reply(subject, inbox, payload.into()) .await .map_err(|err| BatchRequestError::with_source(BatchRequestErrorKind::Publish, err))?; - self.context - .client - .flush() - .await - .map_err(|err| BatchRequestError::with_source(BatchRequestErrorKind::Flush, err))?; debug!("batch request sent"); Ok(()) } @@ -918,9 +913,6 @@ impl Stream { .publish_with_reply(subject.clone(), inbox.clone(), request.clone()) .await .map(|_| pending_reset); - if let Err(err) = consumer.context.client.flush().await { - debug!("flush failed: {err:?}"); - } // TODO: add tracing instead of ignoring this. request_result_tx .send(result.map(|_| pending_reset).map_err(|err| { diff --git a/async-nats/src/jetstream/consumer/push.rs b/async-nats/src/jetstream/consumer/push.rs index 7040e080e..1a8bd13e4 100644 --- a/async-nats/src/jetstream/consumer/push.rs +++ b/async-nats/src/jetstream/consumer/push.rs @@ -709,9 +709,110 @@ impl<'a> futures::Stream for Ordered<'a> { } } } - None => return Poll::Ready(None), }, Poll::Pending => return Poll::Pending, + Some(subscriber) => match subscriber.as_mut().poll(cx) { + Poll::Ready(subscriber) => { + self.subscriber_future = None; + self.consumer_sequence.store(0, Ordering::Relaxed); + self.subscriber = Some(subscriber.map_err(|err| { + OrderedError::with_source(OrderedErrorKind::Recreate, err) + })?); + } + Poll::Pending => { + return Poll::Pending; + } + }, + } + } + if let Some(subscriber) = self.subscriber.as_mut() { + match subscriber.receiver.poll_recv(cx) { + Poll::Ready(maybe_message) => match maybe_message { + Some(message) => { + self.heartbeat_sleep = None; + match message.status { + Some(StatusCode::IDLE_HEARTBEAT) => { + debug!("received idle heartbeats"); + if let Some(headers) = message.headers.as_ref() { + if let Some(sequence) = + headers.get(crate::header::NATS_LAST_CONSUMER) + { + let sequence: u64 = + sequence.as_str().parse().map_err(|err| { + OrderedError::with_source( + OrderedErrorKind::Other, + err, + ) + })?; + + let last_sequence = + self.consumer_sequence.load(Ordering::Relaxed); + + if sequence != last_sequence { + debug!("hearbeats sequence mismatch. got {}, expected {}, resetting consumer", sequence, last_sequence); + self.subscriber = None; + } + } + } + // flow control. + if let Some(subject) = message.reply.clone() { + trace!("received flow control message"); + let client = self.context.client.clone(); + tokio::task::spawn(async move { + client + .publish(subject, Bytes::from_static(b"")) + .await + .ok(); + }); + } + continue; + } + Some(status) => { + debug!("received status message: {}", status); + continue; + } + None => { + trace!("received a message"); + let jetstream_message = jetstream::message::Message { + message, + context: self.context.clone(), + }; + + let info = jetstream_message.info().map_err(|err| { + OrderedError::with_source(OrderedErrorKind::Other, err) + })?; + trace!("consumer sequence: {:?}, stream sequence {:?}, consumer sequence in message: {:?} stream sequence in message: {:?}", + self.consumer_sequence, + self.stream_sequence, + info.consumer_sequence, + info.stream_sequence); + if info.consumer_sequence + != self.consumer_sequence.load(Ordering::Relaxed) + 1 + { + debug!( + "ordered consumer mismatch. current {}, info: {}", + self.consumer_sequence.load(Ordering::Relaxed), + info.consumer_sequence + ); + self.subscriber = None; + self.consumer_sequence.store(0, Ordering::Relaxed); + continue; + } + self.stream_sequence + .store(info.stream_sequence, Ordering::Relaxed); + self.consumer_sequence + .store(info.consumer_sequence, Ordering::Relaxed); + return Poll::Ready(Some(Ok(jetstream_message))); + } + } + } + None => { + return Poll::Ready(None); + } + }, + Poll::Pending => return Poll::Pending, + } +>>>>>>> main } } } diff --git a/async-nats/src/jetstream/context.rs b/async-nats/src/jetstream/context.rs index 1fd009848..b1d95a68d 100644 --- a/async-nats/src/jetstream/context.rs +++ b/async-nats/src/jetstream/context.rs @@ -18,7 +18,7 @@ use crate::header::{IntoHeaderName, IntoHeaderValue}; use crate::jetstream::account::Account; use crate::jetstream::publish::PublishAck; use crate::jetstream::response::Response; -use crate::{header, Client, Command, HeaderMap, HeaderValue, StatusCode}; +use crate::{header, Client, HeaderMap, HeaderValue, StatusCode}; use bytes::Bytes; use futures::future::BoxFuture; use futures::{Future, StreamExt, TryFutureExt}; @@ -987,7 +987,6 @@ pub struct PublishAckFuture { impl PublishAckFuture { async fn next_with_timeout(mut self) -> Result { - self.subscription.sender.send(Command::TryFlush).await.ok(); let next = tokio::time::timeout(self.timeout, self.subscription.next()) .await .map_err(|_| PublishError::new(PublishErrorKind::TimedOut))?; diff --git a/async-nats/src/jetstream/object_store/mod.rs b/async-nats/src/jetstream/object_store/mod.rs index 219bfce0b..0fafac3c3 100644 --- a/async-nats/src/jetstream/object_store/mod.rs +++ b/async-nats/src/jetstream/object_store/mod.rs @@ -69,7 +69,7 @@ pub struct Config { /// A short description of the purpose of this storage bucket. pub description: Option, /// Maximum age of any value in the bucket, expressed in nanoseconds - #[serde(with = "serde_nanos")] + #[serde(default, with = "serde_nanos")] pub max_age: Duration, /// The type of storage backend, `File` (default) and `Memory` pub storage: StorageType, @@ -118,26 +118,29 @@ impl ObjectStore { { Box::pin(async move { let object_info = self.info(object_name).await?; - if let Some(link) = object_info.link.as_ref() { - if let Some(link_name) = link.name.as_ref() { - let link_name = link_name.clone(); - debug!("getting object via link"); - if link.bucket == self.name { - return self.get(link_name).await; + if let Some(ref options) = object_info.options { + if let Some(link) = options.link.as_ref() { + if let Some(link_name) = link.name.as_ref() { + let link_name = link_name.clone(); + debug!("getting object via link"); + if link.bucket == self.name { + return self.get(link_name).await; + } else { + let bucket = self + .stream + .context + .get_object_store(&link_name) + .await + .map_err(|err| GetError::with_source(GetErrorKind::Other, err))?; + let object = bucket.get(&link_name).await?; + return Ok(object); + } } else { - let bucket = self - .stream - .context - .get_object_store(&link_name) - .await - .map_err(|err| GetError::with_source(GetErrorKind::Other, err))?; - let object = bucket.get(&link_name).await?; - return Ok(object); + return Err(GetError::new(GetErrorKind::BucketLink)); } - } else { - return Err(GetError::new(GetErrorKind::BucketLink)); } } + debug!("not a link. Getting the object"); Ok(Object::new(object_info, self.stream.clone())) }) @@ -273,9 +276,9 @@ impl ObjectStore { data: &mut (impl tokio::io::AsyncRead + std::marker::Unpin), ) -> Result where - ObjectMeta: From, + ObjectMetadata: From, { - let object_meta: ObjectMeta = meta.into(); + let object_meta: ObjectMetadata = meta.into(); let encoded_object_name = encode_object_name(&object_meta.name); if !is_valid_object_name(&encoded_object_name) { @@ -293,7 +296,8 @@ impl ObjectStore { let mut object_chunks = 0; let mut object_size = 0; - let mut buffer = BytesMut::with_capacity(DEFAULT_CHUNK_SIZE); + let chunk_size = object_meta.chunk_size.unwrap_or(DEFAULT_CHUNK_SIZE); + let mut buffer = BytesMut::with_capacity(chunk_size); let mut context = ring::digest::Context::new(&SHA256); loop { @@ -335,7 +339,10 @@ impl ObjectStore { let object_info = ObjectInfo { name: object_meta.name, description: object_meta.description, - link: None, + options: Some(ObjectOptions { + max_chunk_size: Some(chunk_size), + link: None, + }), bucket: self.name.clone(), nuid: object_nuid.to_string(), chunks: object_chunks, @@ -515,7 +522,7 @@ impl ObjectStore { Ok(()) } - /// Updates [Object] [ObjectMeta]. + /// Updates [Object] [ObjectMetadata]. /// /// # Examples /// @@ -530,7 +537,7 @@ impl ObjectStore { /// bucket /// .update_metadata( /// "object", - /// object_store::ObjectMeta { + /// object_store::UpdateMetadata { /// name: "new_name".to_string(), /// description: Some("a new description".to_string()), /// }, @@ -542,7 +549,7 @@ impl ObjectStore { pub async fn update_metadata>( &self, object: A, - metadata: ObjectMeta, + metadata: UpdateMetadata, ) -> Result { let mut info = self.info(object.as_ref()).await?; @@ -669,13 +676,18 @@ impl ObjectStore { if object.deleted { return Err(AddLinkError::new(AddLinkErrorKind::Deleted)); } - if object.link.is_some() { - return Err(AddLinkError::new(AddLinkErrorKind::LinkToLink)); + if let Some(ref options) = object.options { + if options.link.is_some() { + return Err(AddLinkError::new(AddLinkErrorKind::LinkToLink)); + } } - match self.info(&name).await { Ok(info) => { - if info.link.is_none() { + if let Some(options) = info.options { + if options.link.is_none() { + return Err(AddLinkError::new(AddLinkErrorKind::AlreadyExists)); + } + } else { return Err(AddLinkError::new(AddLinkErrorKind::AlreadyExists)); } } @@ -688,9 +700,12 @@ impl ObjectStore { let info = ObjectInfo { name, description: None, - link: Some(ObjectLink { - name: Some(object.name.clone()), - bucket: object.bucket.clone(), + options: Some(ObjectOptions { + link: Some(ObjectLink { + name: Some(object.name.clone()), + bucket: object.bucket.clone(), + }), + max_chunk_size: None, }), bucket: self.name.clone(), nuid: nuid::next().to_string(), @@ -736,8 +751,10 @@ impl ObjectStore { match self.info(&name).await { Ok(info) => { - if info.link.is_none() { - return Err(AddLinkError::new(AddLinkErrorKind::AlreadyExists)); + if let Some(options) = info.options { + if options.link.is_none() { + return Err(AddLinkError::new(AddLinkErrorKind::AlreadyExists)); + } } } Err(err) if err.kind() != InfoErrorKind::NotFound => { @@ -749,7 +766,10 @@ impl ObjectStore { let info = ObjectInfo { name: name.clone(), description: None, - link: Some(ObjectLink { name: None, bucket }), + options: Some(ObjectOptions { + link: Some(ObjectLink { name: None, bucket }), + max_chunk_size: None, + }), bucket: self.name.clone(), nuid: nuid::next().to_string(), size: 0, @@ -1030,6 +1050,12 @@ impl tokio::io::AsyncRead for Object<'_> { } } +#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)] +pub struct ObjectOptions { + pub link: Option, + pub max_chunk_size: Option, +} + /// Meta and instance information about an object. #[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)] pub struct ObjectInfo { @@ -1038,7 +1064,7 @@ pub struct ObjectInfo { /// A short human readable description of the object. pub description: Option, /// Link this object points to, if any. - pub link: Option, + pub options: Option, /// Name of the bucket the object is stored in. pub bucket: String, /// Unique identifier used to uniquely identify this version of the object. @@ -1071,18 +1097,28 @@ pub struct ObjectLink { pub bucket: String, } +#[derive(Debug, Default, Clone, Serialize, Deserialize, Eq, PartialEq)] +pub struct UpdateMetadata { + /// Name of the object + pub name: String, + /// A short human readable description of the object. + pub description: Option, +} + /// Meta information about an object. #[derive(Debug, Default, Clone, Serialize, Deserialize, Eq, PartialEq)] -pub struct ObjectMeta { +pub struct ObjectMetadata { /// Name of the object pub name: String, /// A short human readable description of the object. pub description: Option, + /// Max chunk size. Default is 128k. + pub chunk_size: Option, } -impl From<&str> for ObjectMeta { - fn from(s: &str) -> ObjectMeta { - ObjectMeta { +impl From<&str> for ObjectMetadata { + fn from(s: &str) -> ObjectMetadata { + ObjectMetadata { name: s.to_string(), ..Default::default() } @@ -1104,11 +1140,12 @@ impl AsObjectInfo for &ObjectInfo { } } -impl From for ObjectMeta { +impl From for ObjectMetadata { fn from(info: ObjectInfo) -> Self { - ObjectMeta { + ObjectMetadata { name: info.name, description: info.description, + chunk_size: None, } } } diff --git a/async-nats/src/jetstream/stream.rs b/async-nats/src/jetstream/stream.rs index 0283753b0..914603629 100644 --- a/async-nats/src/jetstream/stream.rs +++ b/async-nats/src/jetstream/stream.rs @@ -32,7 +32,7 @@ use base64::engine::general_purpose::STANDARD; use base64::engine::Engine; use bytes::Bytes; use futures::{future::BoxFuture, TryFutureExt}; -use serde::{Deserialize, Serialize}; +use serde::{Deserialize, Deserializer, Serialize}; use serde_json::json; use time::{serde::rfc3339, OffsetDateTime}; @@ -1153,16 +1153,21 @@ pub enum StorageType { /// Shows config and current state for this stream. #[derive(Debug, Deserialize, Clone)] pub struct Info { - /// The configuration associated with this stream + /// The configuration associated with this stream. pub config: Config, - /// The time that this stream was created + /// The time that this stream was created. #[serde(with = "rfc3339")] pub created: time::OffsetDateTime, - /// Various metrics associated with this stream + /// Various metrics associated with this stream. pub state: State, - - ///information about leader and replicas + /// Information about leader and replicas. pub cluster: Option, + /// Information about mirror config if present. + #[serde(default)] + pub mirror: Option, + /// Information about sources configs if present. + #[serde(default)] + pub sources: Vec, } #[derive(Deserialize)] @@ -1374,6 +1379,40 @@ pub struct PeerInfo { pub lag: Option, } +#[derive(Debug, Clone, Deserialize)] +pub struct SourceInfo { + /// Source name. + pub name: String, + /// Number of messages this source is lagging behind. + pub lag: u64, + /// Last time the source was seen active. + #[serde(deserialize_with = "negative_duration_as_none")] + pub active: Option, + /// Filtering for the source. + #[serde(default)] + pub filter_subject: Option, + /// Source destination subject. + #[serde(default)] + pub subject_transform_dest: Option, + /// List of transforms. + #[serde(default)] + pub subject_transforms: Vec, +} + +fn negative_duration_as_none<'de, D>( + deserializer: D, +) -> Result, D::Error> +where + D: Deserializer<'de>, +{ + let n = i64::deserialize(deserializer)?; + if n.is_negative() { + Ok(None) + } else { + Ok(Some(std::time::Duration::from_nanos(n as u64))) + } +} + /// The response generated by trying to purge a stream. #[derive(Debug, Deserialize, Clone, Copy)] pub struct PurgeResponse { @@ -1422,14 +1461,10 @@ pub struct Source { /// Optional config to set a domain, if source is residing in different one. #[serde(default, skip_serializing_if = "is_default")] pub domain: Option, - /// Optional config to set the subject transform destination + /// Subject transforms for Stream. #[cfg(feature = "server_2_10")] - #[serde( - default, - rename = "subject_transform_dest", - skip_serializing_if = "is_default" - )] - pub subject_transform_destination: Option, + #[serde(default, skip_serializing_if = "is_default")] + pub subject_transforms: Vec, } #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq, Default)] diff --git a/async-nats/src/lib.rs b/async-nats/src/lib.rs index 4e3d14e36..ccda935d0 100644 --- a/async-nats/src/lib.rs +++ b/async-nats/src/lib.rs @@ -118,18 +118,20 @@ #![deny(rustdoc::private_intra_doc_links)] #![deny(rustdoc::invalid_codeblock_attributes)] #![deny(rustdoc::invalid_rust_codeblocks)] +#![cfg_attr(docsrs, feature(doc_auto_cfg))] use thiserror::Error; -use futures::future::FutureExt; -use futures::select; use futures::stream::Stream; +use tokio::sync::oneshot; use tracing::{debug, error}; use core::fmt; use std::collections::HashMap; use std::fmt::Display; +use std::future::Future; use std::iter; +use std::mem; use std::net::{SocketAddr, ToSocketAddrs}; use std::option; use std::pin::Pin; @@ -144,7 +146,7 @@ use bytes::Bytes; use serde::{Deserialize, Serialize}; use serde_repr::{Deserialize_repr, Serialize_repr}; use tokio::io; -use tokio::sync::{mpsc, oneshot}; +use tokio::sync::mpsc; use tokio::task; pub type Error = Box; @@ -286,9 +288,8 @@ pub(crate) enum Command { max: Option, }, Flush { - result: oneshot::Sender>, + observer: oneshot::Sender<()>, }, - TryFlush, } /// `ClientOp` represents all actions of `Client`. @@ -339,7 +340,8 @@ pub(crate) struct ConnectionHandler { pending_pings: usize, info_sender: tokio::sync::watch::Sender, ping_interval: Interval, - flush_interval: Interval, + is_flushing: bool, + flush_observers: Vec>, } impl ConnectionHandler { @@ -348,14 +350,10 @@ impl ConnectionHandler { connector: Connector, info_sender: tokio::sync::watch::Sender, ping_period: Duration, - flush_period: Duration, ) -> ConnectionHandler { let mut ping_interval = interval(ping_period); ping_interval.set_missed_tick_behavior(MissedTickBehavior::Delay); - let mut flush_interval = interval(flush_period); - flush_interval.set_missed_tick_behavior(MissedTickBehavior::Skip); - ConnectionHandler { connection, connector, @@ -364,82 +362,175 @@ impl ConnectionHandler { pending_pings: 0, info_sender, ping_interval, - flush_interval, + is_flushing: false, + flush_observers: Vec::new(), } } - pub(crate) async fn process( - &mut self, - mut receiver: mpsc::Receiver, - ) -> Result<(), io::Error> { - loop { - select! { - _ = self.ping_interval.tick().fuse() => { - self.pending_pings += 1; - - if self.pending_pings > MAX_PENDING_PINGS { - debug!( - "pending pings {}, max pings {}. disconnecting", - self.pending_pings, MAX_PENDING_PINGS - ); - self.handle_disconnect().await?; + pub(crate) async fn process<'a>(&'a mut self, receiver: &'a mut mpsc::Receiver) { + struct ProcessFut<'a> { + handler: &'a mut ConnectionHandler, + receiver: &'a mut mpsc::Receiver, + } + + enum ExitReason { + Disconnected(Option), + Closed, + } + + impl<'a> ProcessFut<'a> { + #[cold] + fn ping(&mut self) -> Poll { + self.handler.pending_pings += 1; + + if self.handler.pending_pings > MAX_PENDING_PINGS { + debug!( + "pending pings {}, max pings {}. disconnecting", + self.handler.pending_pings, MAX_PENDING_PINGS + ); + + Poll::Ready(ExitReason::Disconnected(None)) + } else { + self.handler.connection.enqueue_write_op(&ClientOp::Ping); + self.handler.is_flushing = true; + + Poll::Pending + } + } + } + + impl<'a> Future for ProcessFut<'a> { + type Output = ExitReason; + + /// Drives the connection forward. + /// + /// Returns one of the following: + /// + /// * `Poll::Pending` means that the connection + /// is blocked on all fronts or there are + /// no commands to send or receive + /// * `Poll::Ready(ExitReason::Disconnected(_))` means + /// that an I/O operation failed and the connection + /// is considered dead. + /// * `Poll::Ready(ExitReason::Closed)` means that + /// [`Self::receiver`] was closed, so there's nothing + /// more for us to do than to exit the client. + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + if self.handler.ping_interval.poll_tick(cx).is_ready() { + if let Poll::Ready(exit) = self.ping() { + return Poll::Ready(exit); } + } - if let Err(_err) = self.connection.write_op(&ClientOp::Ping).await { - self.handle_disconnect().await?; + loop { + match self.handler.connection.poll_read_op(cx) { + Poll::Pending => break, + Poll::Ready(Ok(Some(server_op))) => { + self.handler.handle_server_op(server_op); + } + Poll::Ready(Ok(None)) => { + return Poll::Ready(ExitReason::Disconnected(None)) + } + Poll::Ready(Err(err)) => { + return Poll::Ready(ExitReason::Disconnected(Some(err))) + } } + } - self.handle_flush().await?; + // WARNING: after the following loop `handle_command`, + // or other functions which call `enqueue_write_op`, + // cannot be called anymore. Runtime wakeups won't + // trigger a call to `poll_write` + + let mut made_progress = true; + loop { + while !self.handler.connection.is_write_buf_full() { + match self.receiver.poll_recv(cx) { + Poll::Pending => break, + Poll::Ready(Some(cmd)) => { + made_progress = true; + self.handler.handle_command(cmd); + } + Poll::Ready(None) => return Poll::Ready(ExitReason::Closed), + } + } - }, - _ = self.flush_interval.tick().fuse() => { - if let Err(_err) = self.handle_flush().await { - self.handle_disconnect().await?; + // The first round will poll both from + // the `receiver` and the writer, giving + // them both a chance to make progress + // and register `Waker`s. + // + // If writing is `Poll::Pending` we exit. + // + // If writing is completed we can repeat the entire + // cycle as long as the `receiver` doesn't end-up + // `Poll::Pending` immediately. + if !mem::take(&mut made_progress) { + break; } - }, - maybe_command = receiver.recv().fuse() => { - match maybe_command { - Some(command) => if let Err(err) = self.handle_command(command).await { - error!("error handling command {}", err); - } - None => { + + match self.handler.connection.poll_write(cx) { + Poll::Pending => { + // Write buffer couldn't be fully emptied break; } + Poll::Ready(Ok(())) => { + // Write buffer is empty + continue; + } + Poll::Ready(Err(err)) => { + return Poll::Ready(ExitReason::Disconnected(Some(err))) + } } } - maybe_op_result = self.connection.read_op().fuse() => { - match maybe_op_result { - Ok(Some(server_op)) => if let Err(err) = self.handle_server_op(server_op).await { - error!("error handling operation {}", err); - } - Ok(None) => { - if let Err(err) = self.handle_disconnect().await { - error!("error handling operation {}", err); + if !self.handler.is_flushing && self.handler.connection.should_flush() { + self.handler.is_flushing = true; + } + + if self.handler.is_flushing { + match self.handler.connection.poll_flush(cx) { + Poll::Pending => {} + Poll::Ready(Ok(())) => { + self.handler.is_flushing = false; + + for observer in self.handler.flush_observers.drain(..) { + let _ = observer.send(()); } } - Err(op_err) => { - if let Err(err) = self.handle_disconnect().await { - error!("error reconnecting {}. original error={}", err, op_err); - } - }, + Poll::Ready(Err(err)) => { + return Poll::Ready(ExitReason::Disconnected(Some(err))) + } } } + + Poll::Pending } } - self.handle_flush().await?; - - Ok(()) + loop { + let process = ProcessFut { + handler: self, + receiver, + }; + match process.await { + ExitReason::Disconnected(err) => { + debug!(?err, "disconnected"); + + self.handle_disconnect().await; + debug!("reconnected"); + } + ExitReason::Closed => break, + } + } } - async fn handle_server_op(&mut self, server_op: ServerOp) -> Result<(), io::Error> { + fn handle_server_op(&mut self, server_op: ServerOp) { self.ping_interval.reset(); match server_op { ServerOp::Ping => { - self.connection.write_op(&ClientOp::Pong).await?; - self.handle_flush().await?; + self.connection.enqueue_write_op(&ClientOp::Pong); } ServerOp::Pong => { debug!("received PONG"); @@ -489,16 +580,13 @@ impl ConnectionHandler { Err(mpsc::error::TrySendError::Full(_)) => { self.connector .events_tx - .send(Event::SlowConsumer(sid)) - .await + .try_send(Event::SlowConsumer(sid)) .ok(); } Err(mpsc::error::TrySendError::Closed(_)) => { self.subscriptions.remove(&sid); self.connection - .write_op(&ClientOp::Unsubscribe { sid, max: None }) - .await?; - self.handle_flush().await?; + .enqueue_write_op(&ClientOp::Unsubscribe { sid, max: None }); } } } else if sid == MULTIPLEXER_SID { @@ -517,9 +605,7 @@ impl ConnectionHandler { length, }; - sender.send(message).map_err(|_| { - io::Error::new(io::ErrorKind::Other, "request receiver closed") - })?; + let _ = sender.send(message); } } } @@ -528,11 +614,7 @@ impl ConnectionHandler { // TODO: we should probably update advertised server list here too. ServerOp::Info(info) => { if info.lame_duck_mode { - self.connector - .events_tx - .send(Event::LameDuckMode) - .await - .ok(); + self.connector.events_tx.try_send(Event::LameDuckMode).ok(); } } @@ -540,18 +622,9 @@ impl ConnectionHandler { // TODO: don't ignore. } } - - Ok(()) - } - - async fn handle_flush(&mut self) -> Result<(), io::Error> { - self.connection.flush().await?; - self.flush_interval.reset(); - - Ok(()) } - async fn handle_command(&mut self, command: Command) -> Result<(), io::Error> { + fn handle_command(&mut self, command: Command) { self.ping_interval.reset(); match command { @@ -569,38 +642,13 @@ impl ConnectionHandler { } } - if let Err(err) = self - .connection - .write_op(&ClientOp::Unsubscribe { sid, max }) - .await - { - error!("Send failed with {:?}", err); - } + self.connection + .enqueue_write_op(&ClientOp::Unsubscribe { sid, max }); } } - Command::Flush { result } => { - if let Err(_err) = self.handle_flush().await { - if let Err(err) = self.handle_disconnect().await { - result.send(Err(err)).map_err(|_| { - io::Error::new(io::ErrorKind::Other, "one shot failed to be received") - })?; - } else if let Err(err) = self.handle_flush().await { - result.send(Err(err)).map_err(|_| { - io::Error::new(io::ErrorKind::Other, "one shot failed to be received") - })?; - } else { - result.send(Ok(())).map_err(|_| { - io::Error::new(io::ErrorKind::Other, "one shot failed to be received") - })?; - } - } else { - result.send(Ok(())).map_err(|_| { - io::Error::new(io::ErrorKind::Other, "one shot failed to be received") - })?; - } - } - Command::TryFlush => { - self.handle_flush().await?; + Command::Flush { observer } => { + self.is_flushing = true; + self.flush_observers.push(observer); } Command::Subscribe { sid, @@ -618,17 +666,11 @@ impl ConnectionHandler { self.subscriptions.insert(sid, subscription); - if let Err(err) = self - .connection - .write_op(&ClientOp::Subscribe { - sid, - subject, - queue_group, - }) - .await - { - error!("Sending Subscribe failed with {:?}", err); - } + self.connection.enqueue_write_op(&ClientOp::Subscribe { + sid, + subject, + queue_group, + }); } Command::Request { subject, @@ -637,26 +679,18 @@ impl ConnectionHandler { headers, sender, } => { - let (prefix, token) = respond.rsplit_once('.').ok_or_else(|| { - io::Error::new(io::ErrorKind::Other, "malformed request subject") - })?; + let (prefix, token) = respond.rsplit_once('.').expect("malformed request subject"); let multiplexer = if let Some(multiplexer) = self.multiplexer.as_mut() { multiplexer } else { let subject = format!("{}.*", prefix); - if let Err(err) = self - .connection - .write_op(&ClientOp::Subscribe { - sid: MULTIPLEXER_SID, - subject: subject.clone(), - queue_group: None, - }) - .await - { - error!("Sending Subscribe failed with {:?}", err); - } + self.connection.enqueue_write_op(&ClientOp::Subscribe { + sid: MULTIPLEXER_SID, + subject: subject.clone(), + queue_group: None, + }); self.multiplexer.insert(Multiplexer { subject, @@ -674,12 +708,7 @@ impl ConnectionHandler { headers, }; - while let Err(err) = self.connection.write_op(&pub_op).await { - self.handle_disconnect().await?; - error!("Sending Publish failed with {:?}", err); - } - - self.connection.flush().await?; + self.connection.enqueue_write_op(&pub_op); } Command::Publish { @@ -688,69 +717,49 @@ impl ConnectionHandler { respond, headers, } => { - let pub_op = ClientOp::Publish { + self.connection.enqueue_write_op(&ClientOp::Publish { subject, payload, respond, headers, - }; - while let Err(err) = self.connection.write_op(&pub_op).await { - self.handle_disconnect().await?; - error!("Sending Publish failed with {:?}", err); - } + }); } } - - Ok(()) } - async fn handle_disconnect(&mut self) -> io::Result<()> { + async fn handle_disconnect(&mut self) { self.pending_pings = 0; self.connector.events_tx.try_send(Event::Disconnected).ok(); self.connector.state_tx.send(State::Disconnected).ok(); - self.handle_reconnect().await?; - Ok(()) + self.handle_reconnect().await; } - async fn handle_reconnect(&mut self) -> Result<(), io::Error> { - let (info, connection) = self.connector.connect().await?; + async fn handle_reconnect(&mut self) { + let (info, connection) = self.connector.connect().await; self.connection = connection; - self.info_sender.send(info).map_err(|err| { - std::io::Error::new( - ErrorKind::Other, - format!("failed to send info update: {err}"), - ) - })?; + let _ = self.info_sender.send(info); self.subscriptions .retain(|_, subscription| !subscription.sender.is_closed()); for (sid, subscription) in &self.subscriptions { - self.connection - .write_op(&ClientOp::Subscribe { - sid: *sid, - subject: subscription.subject.to_owned(), - queue_group: subscription.queue_group.to_owned(), - }) - .await - .unwrap(); + self.connection.enqueue_write_op(&ClientOp::Subscribe { + sid: *sid, + subject: subscription.subject.to_owned(), + queue_group: subscription.queue_group.to_owned(), + }); } if let Some(multiplexer) = &self.multiplexer { - self.connection - .write_op(&ClientOp::Subscribe { - sid: MULTIPLEXER_SID, - subject: multiplexer.subject.to_owned(), - queue_group: None, - }) - .await - .unwrap(); + self.connection.enqueue_write_op(&ClientOp::Subscribe { + sid: MULTIPLEXER_SID, + subject: multiplexer.subject.to_owned(), + queue_group: None, + }); } self.connector.events_tx.try_send(Event::Connected).ok(); - - Ok(()) } } @@ -774,7 +783,6 @@ pub async fn connect_with_options( options: ConnectOptions, ) -> Result { let ping_period = options.ping_interval; - let flush_period = options.flush_interval; let (events_tx, mut events_rx) = mpsc::channel(128); let (state_tx, state_rx) = tokio::sync::watch::channel(State::Pending); @@ -812,7 +820,7 @@ pub async fn connect_with_options( } let (info_sender, info_watcher) = tokio::sync::watch::channel(info); - let (sender, receiver) = mpsc::channel(options.sender_capacity); + let (sender, mut receiver) = mpsc::channel(options.sender_capacity); let client = Client::new( info_watcher, @@ -831,19 +839,14 @@ pub async fn connect_with_options( task::spawn(async move { if connection.is_none() && options.retry_on_initial_connect { - let (info, connection_ok) = connector.connect().await.unwrap(); + let (info, connection_ok) = connector.connect().await; info_sender.send(info).ok(); connection = Some(connection_ok); } let connection = connection.unwrap(); - let mut connection_handler = ConnectionHandler::new( - connection, - connector, - info_sender, - ping_period, - flush_period, - ); - connection_handler.process(receiver).await + let mut connection_handler = + ConnectionHandler::new(connection, connector, info_sender, ping_period); + connection_handler.process(&mut receiver).await }); Ok(client) @@ -1055,7 +1058,6 @@ impl Subscriber { /// /// let mut subscriber = client.subscribe("test".into()).await?; /// subscriber.unsubscribe_after(3).await?; - /// client.flush().await?; /// /// for _ in 0..3 { /// client.publish("test".into(), "data".into()).await?; diff --git a/async-nats/src/options.rs b/async-nats/src/options.rs index 93d7629ef..89132f96b 100644 --- a/async-nats/src/options.rs +++ b/async-nats/src/options.rs @@ -55,7 +55,6 @@ pub struct ConnectOptions { pub(crate) client_cert: Option, pub(crate) client_key: Option, pub(crate) tls_client_config: Option, - pub(crate) flush_interval: Duration, pub(crate) ping_interval: Duration, pub(crate) subscription_capacity: usize, pub(crate) sender_capacity: usize, @@ -84,7 +83,6 @@ impl fmt::Debug for ConnectOptions { .entry(&"client_cert", &self.client_cert) .entry(&"client_key", &self.client_key) .entry(&"tls_client_config", &"XXXXXXXX") - .entry(&"flush_interval", &self.flush_interval) .entry(&"ping_interval", &self.ping_interval) .entry(&"sender_capacity", &self.sender_capacity) .entry(&"inbox_prefix", &self.inbox_prefix) @@ -108,7 +106,6 @@ impl Default for ConnectOptions { client_cert: None, client_key: None, tls_client_config: None, - flush_interval: Duration::from_millis(1), ping_interval: Duration::from_secs(60), sender_capacity: 128, subscription_capacity: 4096, @@ -568,27 +565,6 @@ impl ConnectOptions { self } - /// Sets the interval for flushing. NATS connection will send buffered data to the NATS Server - /// whenever buffer limit is reached, but it is also necessary to flush once in a while if - /// client is sending rarely and small messages. Flush interval allows to modify that interval. - /// - /// # Examples - /// ```no_run - /// # use tokio::time::Duration; - /// # #[tokio::main] - /// # async fn main() -> Result<(), async_nats::ConnectError> { - /// async_nats::ConnectOptions::new() - /// .flush_interval(Duration::from_millis(100)) - /// .connect("demo.nats.io") - /// .await?; - /// # Ok(()) - /// # } - /// ``` - pub fn flush_interval(mut self, flush_interval: Duration) -> ConnectOptions { - self.flush_interval = flush_interval; - self - } - /// Sets how often Client sends PING message to the server. /// /// # Examples diff --git a/async-nats/tests/compatibility.rs b/async-nats/tests/compatibility.rs new file mode 100644 index 000000000..df1b04d5a --- /dev/null +++ b/async-nats/tests/compatibility.rs @@ -0,0 +1,425 @@ +// Copyright 2020-2022 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#[cfg(feature = "compatibility_tests")] +mod compatibility { + use futures::{pin_mut, stream::Peekable, StreamExt}; + + use core::panic; + use std::{collections::HashMap, pin::Pin}; + + use async_nats::jetstream::{ + self, + object_store::{self, ObjectMetadata, UpdateMetadata}, + }; + use ring::digest::{self, SHA256}; + use serde::{Deserialize, Serialize}; + use tokio::io::AsyncReadExt; + + #[tokio::test] + async fn kv() { + panic!("kv suite not implemented yet") + } + + #[tokio::test] + async fn object_store() { + tracing_subscriber::fmt() + .with_max_level(tracing::Level::DEBUG) + .init(); + let url = std::env::var("NATS_URL").unwrap_or_else(|_| "demo.nats.io".to_string()); + tracing::info!("staring client for object store tests at {}", url); + let client = async_nats::connect(url).await.unwrap(); + + let tests = client + .subscribe("tests.object-store.>".into()) + .await + .unwrap() + .peekable(); + pin_mut!(tests); + + let mut done = client.subscribe("tests.done".into()).await.unwrap(); + + loop { + tokio::select! { + _ = done.next() => { + tracing::info!("object store tests done"); + return; + } + message = tests.as_mut().peek() => { + let test: Test = Test::try_from(message.unwrap()).unwrap(); + match test.suite.as_str() { + "object-store" => { + let object_store = ObjectStore { + client: client.clone(), + }; + match test.test.as_str() { + "default-bucket" => object_store.default_bucket(tests.as_mut()).await, + "custom-bucket" => object_store.custom_bucket(tests.as_mut()).await, + "get-object" => object_store.get_object(tests.as_mut()).await, + "put-object" => object_store.put_object(tests.as_mut()).await, + "update-metadata" => object_store.update_metadata(tests.as_mut()).await, + "watch-updates" => object_store.watch_updates(tests.as_mut()).await, + "watch" => object_store.watch(tests.as_mut()).await, + "get-link" => object_store.get_link(tests.as_mut()).await, + "put-link" => object_store.put_link(tests.as_mut()).await, + unknown => panic!("unkown test: {}", unknown), + } + } + unknown => panic!("not an object store suite: {}", unknown), + } + } + } + } + } + struct Test { + suite: String, + test: String, + } + + impl TryFrom<&async_nats::Message> for Test { + type Error = String; + + fn try_from(message: &async_nats::Message) -> Result { + let mut elements = message.subject.split('.').skip(1); + + let suite = elements + .next() + .ok_or("missing suite token".to_string())? + .to_string(); + let test = elements + .next() + .ok_or("missing test token".to_string())? + .to_string(); + + Ok(Test { suite, test }) + } + } + + struct ObjectStore { + client: async_nats::Client, + } + + type PinnedSubscriber<'a> = Pin<&'a mut Peekable>; + + impl ObjectStore { + async fn default_bucket(&self, mut test_commands: PinnedSubscriber<'_>) { + let create = test_commands.as_mut().next().await.unwrap(); + println!("received first request: {}", create.subject); + + let given: TestRequest> = + serde_json::from_slice(&create.payload).unwrap(); + let jetstream = async_nats::jetstream::new(self.client.clone()); + jetstream + .create_object_store(object_store::Config { + bucket: given.config.get("bucket").unwrap().to_string(), + ..Default::default() + }) + .await + .unwrap(); + + self.client + .publish(create.reply.unwrap(), "".into()) + .await + .unwrap(); + self.client.flush().await.unwrap(); + + let done = test_commands.next().await.unwrap(); + if done.headers.is_some() { + panic!("test failed: {:?}", done.headers); + } else { + println!("test default-bucket PASS"); + } + } + + async fn custom_bucket(&self, mut commands: PinnedSubscriber<'_>) { + let create = commands.as_mut().next().await.unwrap(); + println!("received custom request: {}", create.subject); + + let custom_config: TestRequest = + serde_json::from_slice(&create.payload).unwrap(); + + async_nats::jetstream::new(self.client.clone()) + .create_object_store(custom_config.config) + .await + .unwrap(); + + self.client + .publish(create.reply.unwrap(), "".into()) + .await + .unwrap(); + self.client.flush().await.unwrap(); + + let done = commands.next().await.unwrap(); + if done.headers.is_some() { + panic!("test failed: {:?}", done.headers); + } else { + println!("test custom-bucket PASS"); + } + } + + async fn put_object(&self, mut commands: PinnedSubscriber<'_>) { + #[derive(Debug, Deserialize)] + struct ObjectRequest { + url: String, + bucket: String, + #[serde(flatten)] + test_request: TestRequest, + } + + let object_request = commands.as_mut().next().await.unwrap(); + println!("received third request: {}", object_request.subject); + let reply = object_request.reply.unwrap().clone(); + let object_request: ObjectRequest = + serde_json::from_slice(&object_request.payload).unwrap(); + + let bucket = async_nats::jetstream::new(self.client.clone()) + .get_object_store(object_request.bucket.clone()) + .await + .unwrap(); + + let file = reqwest::get(object_request.url).await.unwrap(); + let contents = file.bytes().await.unwrap(); + + bucket + .put(object_request.test_request.config, &mut contents.as_ref()) + .await + .unwrap(); + + self.client.publish(reply, "".into()).await.unwrap(); + self.client.flush().await.unwrap(); + + let done = commands.next().await.unwrap(); + if done.headers.is_some() { + panic!("test failed: {:?}", done.headers); + } else { + println!("test put-object PASS"); + } + } + + async fn get_object(&self, mut commands: PinnedSubscriber<'_>) { + #[derive(Deserialize)] + struct Command { + object: String, + bucket: String, + } + let get_request = commands.as_mut().next().await.unwrap(); + + let request: Command = serde_json::from_slice(&get_request.payload).unwrap(); + + let bucket = async_nats::jetstream::new(self.client.clone()) + .get_object_store(request.bucket) + .await + .unwrap(); + let mut object = bucket.get(request.object).await.unwrap(); + let mut contents = vec![]; + + object.read_to_end(&mut contents).await.unwrap(); + + let digest = digest::digest(&SHA256, &contents); + + self.client + .publish( + get_request.reply.unwrap(), + digest.as_ref().to_owned().into(), + ) + .await + .unwrap(); + + let done = commands.next().await.unwrap(); + if done.headers.is_some() { + panic!("test failed: {:?}", done.headers); + } else { + println!("test get-object PASS"); + } + } + + async fn put_link(&self, mut commands: PinnedSubscriber<'_>) { + #[derive(Deserialize, Debug)] + struct Command { + object: String, + bucket: String, + link_name: String, + } + let get_request = commands.as_mut().next().await.unwrap(); + + let request: Command = serde_json::from_slice(&get_request.payload).unwrap(); + + let bucket = async_nats::jetstream::new(self.client.clone()) + .get_object_store(request.bucket) + .await + .unwrap(); + let object = bucket.get(request.object).await.unwrap(); + + bucket.add_link(request.link_name, &object).await.unwrap(); + + self.client + .publish(get_request.reply.unwrap(), "".into()) + .await + .unwrap(); + + let done = commands.next().await.unwrap(); + if done.headers.is_some() { + panic!("test failed: {:?}", done.headers); + } else { + println!("test put-link PASS"); + } + } + + async fn get_link(&self, mut commands: PinnedSubscriber<'_>) { + #[derive(Deserialize, Debug)] + struct Command { + object: String, + bucket: String, + } + let get_request = commands.as_mut().next().await.unwrap(); + + let request: Command = serde_json::from_slice(&get_request.payload).unwrap(); + + let bucket = async_nats::jetstream::new(self.client.clone()) + .get_object_store(request.bucket) + .await + .unwrap(); + let mut object = bucket.get(request.object).await.unwrap(); + let mut contents = vec![]; + + object.read_to_end(&mut contents).await.unwrap(); + + let digest = digest::digest(&SHA256, &contents); + + self.client + .publish( + get_request.reply.unwrap(), + digest.as_ref().to_owned().into(), + ) + .await + .unwrap(); + + let done = commands.next().await.unwrap(); + if done.headers.is_some() { + panic!("test failed: {:?}", done.headers); + } else { + println!("test get-object PASS"); + } + } + + async fn update_metadata(&self, mut commands: PinnedSubscriber<'_>) { + #[derive(Deserialize)] + struct Metadata { + object: String, + bucket: String, + config: UpdateMetadata, + } + + let update_command = commands.as_mut().next().await.unwrap(); + + let given: Metadata = serde_json::from_slice(&update_command.payload).unwrap(); + + let object_store = jetstream::new(self.client.clone()) + .get_object_store(given.bucket) + .await + .unwrap(); + + object_store + .update_metadata(given.object, given.config) + .await + .unwrap(); + + self.client + .publish(update_command.reply.unwrap(), "".into()) + .await + .unwrap(); + + let done = commands.next().await.unwrap(); + if done.headers.is_some() { + panic!("test failed: {:?}", done.headers); + } else { + println!("test update-metadata PASS"); + } + } + + async fn watch_updates(&self, mut commands: PinnedSubscriber<'_>) { + #[derive(Deserialize)] + struct Command { + #[allow(dead_code)] + object: String, + bucket: String, + } + let get_request = commands.as_mut().next().await.unwrap(); + + let request: Command = serde_json::from_slice(&get_request.payload).unwrap(); + let bucket = async_nats::jetstream::new(self.client.clone()) + .get_object_store(request.bucket) + .await + .unwrap(); + + let mut watch = bucket.watch().await.unwrap(); + + let info = watch.next().await.unwrap().unwrap(); + + self.client + .publish(get_request.reply.unwrap(), info.digest.unwrap().into()) + .await + .unwrap(); + + let done = commands.next().await.unwrap(); + if done.headers.is_some() { + panic!("test failed: {:?}", done.headers); + } else { + println!("test update-metadata PASS"); + } + } + + async fn watch(&self, mut commands: PinnedSubscriber<'_>) { + #[derive(Deserialize)] + struct Command { + #[allow(dead_code)] + object: String, + bucket: String, + } + let get_request = commands.as_mut().next().await.unwrap(); + + let request: Command = serde_json::from_slice(&get_request.payload).unwrap(); + let bucket = async_nats::jetstream::new(self.client.clone()) + .get_object_store(request.bucket) + .await + .unwrap(); + + let mut watch = bucket.watch_with_history().await.unwrap(); + + let info = watch.next().await.unwrap().unwrap(); + let second_info = watch.next().await.unwrap().unwrap(); + + let response = [info.digest.unwrap(), second_info.digest.unwrap()].join(","); + + self.client + .publish(get_request.reply.unwrap(), response.into()) + .await + .unwrap(); + + let done = commands.next().await.unwrap(); + if done.headers.is_some() { + panic!("test failed: {:?}", done.headers); + } else { + println!("test update-metadata PASS"); + } + } + } + + #[derive(Debug, Clone, Serialize, Deserialize)] + struct TestRequest { + suite: String, + test: String, + command: String, + config: T, + } +} diff --git a/async-nats/tests/configs/docker/Dockerfile b/async-nats/tests/configs/docker/Dockerfile new file mode 100644 index 000000000..3203f9cbd --- /dev/null +++ b/async-nats/tests/configs/docker/Dockerfile @@ -0,0 +1,7 @@ +FROM rust:1.71.1 +WORKDIR /usr/src/nats.rs/async-nats +ARG PROFILE=test +COPY . /usr/src/nats.rs +RUN cargo test --features compatibility_tests --no-run +ENV NATS_URL=localhost:4222 +CMD cargo test --features compatibility_tests compatibility -- --nocapture \ No newline at end of file diff --git a/async-nats/tests/jetstream_tests.rs b/async-nats/tests/jetstream_tests.rs index 5797f7d7b..49188831f 100644 --- a/async-nats/tests/jetstream_tests.rs +++ b/async-nats/tests/jetstream_tests.rs @@ -2733,6 +2733,64 @@ mod jetstream { ); } + #[tokio::test] + #[cfg(feature = "server_2_10")] + async fn stream_subject_transforms() { + let server = nats_server::run_server("tests/configs/jetstream.conf"); + let client = async_nats::connect(server.client_url()).await.unwrap(); + let context = async_nats::jetstream::new(client); + + let subject_transform = stream::SubjectTransform { + source: "foo".to_string(), + destination: "bar".to_string(), + }; + + let source = stream::Source { + name: "source".to_string(), + filter_subject: Some("stream1.foo".to_string()), + ..Default::default() + }; + + let sources = vec![ + source.clone(), + stream::Source { + name: "multi_source".to_string(), + subject_transforms: vec![stream::SubjectTransform { + source: "stream2.foo.>".to_string(), + destination: "foo.>".to_string(), + }], + ..Default::default() + }, + ]; + + let mut stream = context + .create_stream(stream::Config { + name: "filtered".to_string(), + subject_transform: Some(subject_transform.clone()), + sources: Some(sources.clone()), + ..Default::default() + }) + .await + .unwrap(); + + let info = stream.info().await.unwrap(); + assert_eq!(info.config.sources, Some(sources.clone())); + assert_eq!(info.config.subject_transform, Some(subject_transform)); + + let mut stream = context + .create_stream(stream::Config { + name: "mirror".to_string(), + mirror: Some(source.clone()), + ..Default::default() + }) + .await + .unwrap(); + + let info = stream.info().await.unwrap(); + + assert_eq!(info.config.mirror, Some(source)); + } + #[tokio::test] async fn pull_by_bytes() { let server = nats_server::run_server("tests/configs/jetstream.conf"); @@ -3166,6 +3224,8 @@ mod jetstream { #[cfg(feature = "server_2_10")] #[tokio::test] async fn subject_transform() { + use async_nats::jetstream::stream::SubjectTransform; + let server = nats_server::run_server("tests/configs/jetstream.conf"); let client = async_nats::connect(server.client_url()).await.unwrap(); let context = async_nats::jetstream::new(client.clone()); @@ -3195,7 +3255,10 @@ mod jetstream { name: "sourcing".to_string(), sources: Some(vec![async_nats::jetstream::stream::Source { name: "origin".to_string(), - subject_transform_destination: Some("fromtest.>".to_string()), + subject_transforms: vec![SubjectTransform { + source: ">".to_string(), + destination: "fromtest.>".to_string(), + }], ..Default::default() }]), ..Default::default() diff --git a/async-nats/tests/kv_tests.rs b/async-nats/tests/kv_tests.rs index 86903f722..ab8503578 100644 --- a/async-nats/tests/kv_tests.rs +++ b/async-nats/tests/kv_tests.rs @@ -335,7 +335,7 @@ mod kv { .await .unwrap(); - let context = async_nats::jetstream::new(client); + let context = async_nats::jetstream::new(client.clone()); let kv = context .create_key_value(async_nats::jetstream::kv::Config { @@ -352,6 +352,7 @@ mod kv { // check if we get only updated values. This should not pop up in watcher. kv.put("foo", 22.to_string().into()).await.unwrap(); let mut watch = kv.watch("foo").await.unwrap().enumerate(); + client.flush().await.unwrap(); tokio::task::spawn({ let kv = kv.clone(); diff --git a/async-nats/tests/object_store.rs b/async-nats/tests/object_store.rs index ad76211e5..b7366cabe 100644 --- a/async-nats/tests/object_store.rs +++ b/async-nats/tests/object_store.rs @@ -16,7 +16,7 @@ mod object_store { use std::{io, time::Duration}; use async_nats::jetstream::{ - object_store::{AddLinkErrorKind, ObjectMeta}, + object_store::{AddLinkErrorKind, ObjectMetadata, UpdateMetadata}, stream::DirectGetErrorKind, }; use base64::Engine; @@ -81,6 +81,24 @@ mod object_store { tracing::info!("reading content"); object_link.read_to_end(&mut contents).await.unwrap(); assert_eq!(contents, result); + + bucket + .put( + ObjectMetadata { + name: "BAR".to_string(), + description: Some("custom object".to_string()), + chunk_size: Some(64 * 1024), + }, + &mut bytes.as_slice(), + ) + .await + .unwrap(); + + let meta = bucket.get("BAR").await.unwrap(); + assert_eq!( + 64 * 1024, + meta.info.options.unwrap().max_chunk_size.unwrap() + ); } #[tokio::test] @@ -353,9 +371,10 @@ mod object_store { .unwrap(); bucket .put( - ObjectMeta { + ObjectMetadata { name: "Foo".to_string(), description: Some("foo desc".to_string()), + chunk_size: None, }, &mut "dadada".as_bytes(), ) @@ -436,7 +455,7 @@ mod object_store { .await .unwrap(); - let given_metadata = ObjectMeta { + let given_metadata = UpdateMetadata { name: "new_object".to_owned(), description: Some("description".to_string()), }; @@ -502,6 +521,9 @@ mod object_store { assert_eq!( link_info + .options + .as_ref() + .unwrap() .link .as_ref() .unwrap() @@ -511,7 +533,18 @@ mod object_store { .as_str(), "object" ); - assert_eq!(link_info.link.as_ref().unwrap().bucket.as_str(), "bucket"); + assert_eq!( + link_info + .options + .as_ref() + .unwrap() + .link + .as_ref() + .unwrap() + .bucket + .as_str(), + "bucket" + ); let result = bucket .add_link("object", &another_object) @@ -551,7 +584,26 @@ mod object_store { bucket.add_bucket_link("link", "another").await.unwrap(); let link_info = bucket.info("link").await.unwrap(); - assert!(link_info.link.as_ref().unwrap().name.is_none()); - assert_eq!(link_info.link.as_ref().unwrap().bucket.as_str(), "another"); + assert!(link_info + .options + .as_ref() + .unwrap() + .link + .as_ref() + .unwrap() + .name + .is_none()); + assert_eq!( + link_info + .options + .as_ref() + .unwrap() + .link + .as_ref() + .unwrap() + .bucket + .as_str(), + "another" + ); } } diff --git a/nats-server/Cargo.toml b/nats-server/Cargo.toml index b665ff67c..36234f764 100644 --- a/nats-server/Cargo.toml +++ b/nats-server/Cargo.toml @@ -10,7 +10,7 @@ license = "Apache-2.0" lazy_static = "1.4.0" regex = { version = "1.7.1", default-features = false, features = ["std", "unicode-perl"] } url = "2" -json = "0.12" +serde_json = "1.0.104" nuid = "0.5" rand = "0.8" tokio-retry = "0.3.0" diff --git a/nats-server/src/lib.rs b/nats-server/src/lib.rs index 57ba39f2e..fd5ba7a11 100644 --- a/nats-server/src/lib.rs +++ b/nats-server/src/lib.rs @@ -22,6 +22,7 @@ use std::{thread, time::Duration}; use lazy_static::lazy_static; use rand::Rng; use regex::Regex; +use serde_json::{self, Value}; pub struct Server { inner: Inner, @@ -77,8 +78,8 @@ impl Server { let mut r = BufReader::with_capacity(1024, TcpStream::connect(addr).unwrap()); let mut line = String::new(); r.read_line(&mut line).expect("did not receive INFO"); - let si = json::parse(&line["INFO".len()..]).unwrap(); - let port = si["port"].as_u16().expect("could not parse port"); + let si: Value = serde_json::from_str(&line["INFO".len()..]).expect("could not parse INFO"); + let port = si["port"].as_u64().expect("could not parse port") as u16; let mut scheme = "nats://"; if si["tls_required"].as_bool().unwrap_or(false) { scheme = "tls://"; @@ -91,8 +92,8 @@ impl Server { let mut r = BufReader::with_capacity(1024, TcpStream::connect(addr).unwrap()); let mut line = String::new(); r.read_line(&mut line).expect("did not receive INFO"); - let si = json::parse(&line["INFO".len()..]).unwrap(); - si["port"].as_u16().expect("could not parse port") + let si: Value = serde_json::from_str(&line["INFO".len()..]).expect("could not parse INFO"); + si["port"].as_u64().expect("could not parse port") as u16 } // Allow user/pass override.