Skip to content

Commit

Permalink
Refactor ParseError
Browse files Browse the repository at this point in the history
Instead of the `crate::Error`, the <private> method `stream::parse_headers`
will return the specialized error with kind enumerating the possible
causes of parsing errors. This change will help to specialize errors
error returned by public methods.

Along this change, I also refactored 2 errors, namely
the `ParseHeaderNameError` and the `ParseHeaderValueError`
to convert them to specialization of generic `crate::error::Error<_>`.

These changes are collected in the same commit because all of them
describe errors occurred during either building or parsing headers.
  • Loading branch information
nepalez committed Sep 22, 2023
1 parent 02fd086 commit ba6866d
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 40 deletions.
2 changes: 1 addition & 1 deletion async-nats/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ where

/// Enables wrapping source errors to the crate-specific error type
/// by additionally specifying the kind of the target error.
trait WithKind<Kind>
pub(crate) trait WithKind<Kind>
where
Kind: Clone + Debug + Display + PartialEq,
Self: Into<crate::Error>,
Expand Down
44 changes: 30 additions & 14 deletions async-nats/src/header.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

use std::{collections::HashMap, fmt, slice::Iter, str::FromStr};

use crate::error::Error;
use bytes::Bytes;
use serde::{Deserialize, Serialize};

Expand Down Expand Up @@ -312,7 +313,7 @@ impl FromStr for HeaderValue {

fn from_str(s: &str) -> Result<Self, Self::Err> {
if s.contains(['\r', '\n']) {
return Err(ParseHeaderValueError);
return Err(s.into());
}

Ok(HeaderValue {
Expand All @@ -339,19 +340,26 @@ impl HeaderValue {
}
}

#[derive(Debug, Clone)]
pub struct ParseHeaderValueError;
#[derive(Clone, Debug, PartialEq)]
pub struct ParseHeaderValueErrorKind(String);

impl fmt::Display for ParseHeaderValueError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
impl fmt::Display for ParseHeaderValueErrorKind {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
r#"invalid character found in header value (value cannot contain '\r' or '\n')"#
r#"invalid character found in header value (value cannot contain '\r' or '\n'): {:?}"#,
self.0
)
}
}

impl std::error::Error for ParseHeaderValueError {}
pub type ParseHeaderValueError = Error<ParseHeaderValueErrorKind>;

impl From<&str> for ParseHeaderValueError {
fn from(value: &str) -> Self {
Self::new(ParseHeaderValueErrorKind(value.into()))
}
}

pub trait IntoHeaderName {
fn into_header_name(self) -> HeaderName;
Expand Down Expand Up @@ -568,7 +576,9 @@ impl FromStr for HeaderName {

fn from_str(s: &str) -> Result<Self, Self::Err> {
if s.contains(|c: char| c == ':' || (c as u8) < 33 || (c as u8) > 126) {
return Err(ParseHeaderNameError);
return Err(ParseHeaderNameError::new(ParseHeaderNameErrorKind(
s.into(),
)));
}

match StandardHeader::from_bytes(s.as_ref()) {
Expand Down Expand Up @@ -600,16 +610,22 @@ impl AsRef<str> for HeaderName {
}
}

#[derive(Debug, Clone)]
pub struct ParseHeaderNameError;
#[derive(Clone, Debug, PartialEq)]
pub struct ParseHeaderNameErrorKind(String);

impl std::fmt::Display for ParseHeaderNameError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "invalid header name (name cannot contain non-ascii alphanumeric characters other than '-')")
impl fmt::Display for ParseHeaderNameErrorKind {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "invalid header name (name cannot contain non-ascii alphanumeric characters other than '-'): {:?}", self.0)
}
}

impl std::error::Error for ParseHeaderNameError {}
pub type ParseHeaderNameError = Error<ParseHeaderNameErrorKind>;

impl From<&str> for ParseHeaderNameError {
fn from(value: &str) -> Self {
Self::new(ParseHeaderNameErrorKind(value.into()))
}
}

#[cfg(test)]
mod tests {
Expand Down
77 changes: 52 additions & 25 deletions async-nats/src/jetstream/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@

#[cfg(feature = "server_2_10")]
use std::collections::HashMap;
use std::fmt::Formatter;
use std::{
fmt::{self, Debug, Display},
future::IntoFuture,
io::{self, ErrorKind},
io::ErrorKind,
pin::Pin,
str::FromStr,
task::Poll,
Expand Down Expand Up @@ -1311,31 +1312,64 @@ fn is_continuation(c: char) -> bool {
}
const HEADER_LINE: &str = "NATS/1.0";

#[derive(Clone, Copy, Debug, PartialEq)]
pub enum ParseHeadersErrorKind {
HeaderInfoMissing,
InvalidHeader,
InvalidVersionLine,
InvalidStatusCode,
MalformedHeader,
}

impl Display for ParseHeadersErrorKind {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
match self {
Self::InvalidHeader => write!(f, "invalid header"),
Self::InvalidStatusCode => write!(f, "invalid status code"),
Self::InvalidVersionLine => write!(f, "version line does not start with NATS/1.0"),
Self::HeaderInfoMissing => write!(f, "expected header information not found"),
Self::MalformedHeader => write!(f, "malformed header line"),
}
}
}

pub type ParseHeadersError = Error<ParseHeadersErrorKind>;

impl From<ParseHeaderNameError> for ParseHeadersError {
fn from(e: ParseHeaderNameError) -> Self {
e.with_kind(ParseHeadersErrorKind::MalformedHeader)
}
}

impl From<ParseHeaderValueError> for ParseHeadersError {
fn from(e: ParseHeaderValueError) -> Self {
e.with_kind(ParseHeadersErrorKind::MalformedHeader)
}
}

impl From<InvalidStatusCode> for ParseHeadersError {
fn from(e: InvalidStatusCode) -> Self {
e.with_kind(ParseHeadersErrorKind::InvalidStatusCode)
}
}

#[allow(clippy::type_complexity)]
fn parse_headers(
buf: &[u8],
) -> Result<(Option<HeaderMap>, Option<StatusCode>, Option<String>), crate::Error> {
) -> Result<(Option<HeaderMap>, Option<StatusCode>, Option<String>), ParseHeadersError> {
let mut headers = HeaderMap::new();
let mut maybe_status: Option<StatusCode> = None;
let mut maybe_description: Option<String> = None;
let mut lines = if let Ok(line) = std::str::from_utf8(buf) {
line.lines().peekable()
} else {
return Err(Box::new(std::io::Error::new(
ErrorKind::Other,
"invalid header",
)));
return Err(ParseHeadersErrorKind::InvalidHeader.into());
};

if let Some(line) = lines.next() {
let line = line
.strip_prefix(HEADER_LINE)
.ok_or_else(|| {
Box::new(std::io::Error::new(
ErrorKind::Other,
"version line does not start with NATS/1.0",
))
})?
.ok_or_else(|| ParseHeadersError::new(ParseHeadersErrorKind::InvalidHeader))?
.trim();

match line.split_once(' ') {
Expand All @@ -1355,10 +1389,7 @@ fn parse_headers(
}
}
} else {
return Err(Box::new(std::io::Error::new(
ErrorKind::Other,
"expected header information not found",
)));
return Err(ParseHeadersErrorKind::HeaderInfoMissing.into());
};

while let Some(line) = lines.next() {
Expand All @@ -1373,16 +1404,9 @@ fn parse_headers(
s.push_str(v.trim());
}

headers.insert(
HeaderName::from_str(k)?,
HeaderValue::from_str(&s)
.map_err(|err| Box::new(io::Error::new(ErrorKind::Other, err)))?,
);
headers.insert(HeaderName::from_str(k)?, HeaderValue::from_str(&s)?);
} else {
return Err(Box::new(std::io::Error::new(
ErrorKind::Other,
"malformed header line",
)));
return Err(ParseHeadersErrorKind::MalformedHeader.into());
}
}

Expand Down Expand Up @@ -1528,6 +1552,9 @@ pub struct External {
pub delivery_prefix: Option<String>,
}

use crate::error::WithKind;
use crate::header::{ParseHeaderNameError, ParseHeaderValueError};
use crate::status::InvalidStatusCode;
use std::marker::PhantomData;

#[derive(Debug, Default)]
Expand Down

0 comments on commit ba6866d

Please sign in to comment.