diff --git a/async-nats/src/error.rs b/async-nats/src/error.rs index 13cd633af..567143723 100644 --- a/async-nats/src/error.rs +++ b/async-nats/src/error.rs @@ -73,7 +73,7 @@ where /// Enables wrapping source errors to the crate-specific error type /// by additionally specifying the kind of the target error. -trait WithKind +pub(crate) trait WithKind where Kind: Clone + Debug + Display + PartialEq, Self: Into, diff --git a/async-nats/src/header.rs b/async-nats/src/header.rs index fcc7d3ba7..30fd41dc1 100644 --- a/async-nats/src/header.rs +++ b/async-nats/src/header.rs @@ -22,6 +22,7 @@ use std::{collections::HashMap, fmt, slice::Iter, str::FromStr}; +use crate::error::Error; use bytes::Bytes; use serde::{Deserialize, Serialize}; @@ -312,7 +313,7 @@ impl FromStr for HeaderValue { fn from_str(s: &str) -> Result { if s.contains(['\r', '\n']) { - return Err(ParseHeaderValueError); + return Err(s.into()); } Ok(HeaderValue { @@ -339,19 +340,26 @@ impl HeaderValue { } } -#[derive(Debug, Clone)] -pub struct ParseHeaderValueError; +#[derive(Clone, Debug, PartialEq)] +pub struct ParseHeaderValueErrorKind(String); -impl fmt::Display for ParseHeaderValueError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl fmt::Display for ParseHeaderValueErrorKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( f, - r#"invalid character found in header value (value cannot contain '\r' or '\n')"# + r#"invalid character found in header value (value cannot contain '\r' or '\n'): {:?}"#, + self.0 ) } } -impl std::error::Error for ParseHeaderValueError {} +pub type ParseHeaderValueError = Error; + +impl From<&str> for ParseHeaderValueError { + fn from(value: &str) -> Self { + Self::new(ParseHeaderValueErrorKind(value.into())) + } +} pub trait IntoHeaderName { fn into_header_name(self) -> HeaderName; @@ -568,7 +576,9 @@ impl FromStr for HeaderName { fn from_str(s: &str) -> Result { if s.contains(|c: char| c == ':' || (c as u8) < 33 || (c as u8) > 126) { - return Err(ParseHeaderNameError); + return Err(ParseHeaderNameError::new(ParseHeaderNameErrorKind( + s.into(), + ))); } match StandardHeader::from_bytes(s.as_ref()) { @@ -600,16 +610,22 @@ impl AsRef for HeaderName { } } -#[derive(Debug, Clone)] -pub struct ParseHeaderNameError; +#[derive(Clone, Debug, PartialEq)] +pub struct ParseHeaderNameErrorKind(String); -impl std::fmt::Display for ParseHeaderNameError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "invalid header name (name cannot contain non-ascii alphanumeric characters other than '-')") +impl fmt::Display for ParseHeaderNameErrorKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "invalid header name (name cannot contain non-ascii alphanumeric characters other than '-'): {:?}", self.0) } } -impl std::error::Error for ParseHeaderNameError {} +pub type ParseHeaderNameError = Error; + +impl From<&str> for ParseHeaderNameError { + fn from(value: &str) -> Self { + Self::new(ParseHeaderNameErrorKind(value.into())) + } +} #[cfg(test)] mod tests { diff --git a/async-nats/src/jetstream/stream.rs b/async-nats/src/jetstream/stream.rs index db9973e41..8290c4221 100644 --- a/async-nats/src/jetstream/stream.rs +++ b/async-nats/src/jetstream/stream.rs @@ -15,10 +15,11 @@ #[cfg(feature = "server_2_10")] use std::collections::HashMap; +use std::fmt::Formatter; use std::{ fmt::{self, Debug, Display}, future::IntoFuture, - io::{self, ErrorKind}, + io::ErrorKind, pin::Pin, str::FromStr, task::Poll, @@ -1311,31 +1312,64 @@ fn is_continuation(c: char) -> bool { } const HEADER_LINE: &str = "NATS/1.0"; +#[derive(Clone, Copy, Debug, PartialEq)] +pub enum ParseHeadersErrorKind { + HeaderInfoMissing, + InvalidHeader, + InvalidVersionLine, + InvalidStatusCode, + MalformedHeader, +} + +impl Display for ParseHeadersErrorKind { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + Self::InvalidHeader => write!(f, "invalid header"), + Self::InvalidStatusCode => write!(f, "invalid status code"), + Self::InvalidVersionLine => write!(f, "version line does not start with NATS/1.0"), + Self::HeaderInfoMissing => write!(f, "expected header information not found"), + Self::MalformedHeader => write!(f, "malformed header line"), + } + } +} + +pub type ParseHeadersError = Error; + +impl From for ParseHeadersError { + fn from(e: ParseHeaderNameError) -> Self { + e.with_kind(ParseHeadersErrorKind::MalformedHeader) + } +} + +impl From for ParseHeadersError { + fn from(e: ParseHeaderValueError) -> Self { + e.with_kind(ParseHeadersErrorKind::MalformedHeader) + } +} + +impl From for ParseHeadersError { + fn from(e: InvalidStatusCode) -> Self { + e.with_kind(ParseHeadersErrorKind::InvalidStatusCode) + } +} + #[allow(clippy::type_complexity)] fn parse_headers( buf: &[u8], -) -> Result<(Option, Option, Option), crate::Error> { +) -> Result<(Option, Option, Option), ParseHeadersError> { let mut headers = HeaderMap::new(); let mut maybe_status: Option = None; let mut maybe_description: Option = None; let mut lines = if let Ok(line) = std::str::from_utf8(buf) { line.lines().peekable() } else { - return Err(Box::new(std::io::Error::new( - ErrorKind::Other, - "invalid header", - ))); + return Err(ParseHeadersErrorKind::InvalidHeader.into()); }; if let Some(line) = lines.next() { let line = line .strip_prefix(HEADER_LINE) - .ok_or_else(|| { - Box::new(std::io::Error::new( - ErrorKind::Other, - "version line does not start with NATS/1.0", - )) - })? + .ok_or_else(|| ParseHeadersError::new(ParseHeadersErrorKind::InvalidHeader))? .trim(); match line.split_once(' ') { @@ -1355,10 +1389,7 @@ fn parse_headers( } } } else { - return Err(Box::new(std::io::Error::new( - ErrorKind::Other, - "expected header information not found", - ))); + return Err(ParseHeadersErrorKind::HeaderInfoMissing.into()); }; while let Some(line) = lines.next() { @@ -1373,16 +1404,9 @@ fn parse_headers( s.push_str(v.trim()); } - headers.insert( - HeaderName::from_str(k)?, - HeaderValue::from_str(&s) - .map_err(|err| Box::new(io::Error::new(ErrorKind::Other, err)))?, - ); + headers.insert(HeaderName::from_str(k)?, HeaderValue::from_str(&s)?); } else { - return Err(Box::new(std::io::Error::new( - ErrorKind::Other, - "malformed header line", - ))); + return Err(ParseHeadersErrorKind::MalformedHeader.into()); } } @@ -1528,6 +1552,9 @@ pub struct External { pub delivery_prefix: Option, } +use crate::error::WithKind; +use crate::header::{ParseHeaderNameError, ParseHeaderValueError}; +use crate::status::InvalidStatusCode; use std::marker::PhantomData; #[derive(Debug, Default)]