Skip to content

Commit

Permalink
Add some general types and untilities for more serde processing (#163)
Browse files Browse the repository at this point in the history
* Add test API for building bolt bytes

* Make from_bytes_ref usable

* Add derive feature to lib dependencies

* Add serializer to flatten types into a map

* Add some future error cases

* Implement new Summary enum

* Add generic stream response

* Format cargo.toml

* Also deserialze tuples as lists

* Implement detail message Record

* Add debug type to debug a packstream bytes sequence

* Make all fields in the Summary optional

* Add new response type that is Summary | Detail

* Rename Mode to Type

* Add into_error

* Add some allow_unused to make clippy happier

* Don't use let-else because MSRV

* Don't use saturating_sub_unsigned because MSRV

* Mark error enums as non_exhaustive
  • Loading branch information
knutwalker committed Jan 21, 2024
1 parent 1c416c5 commit 7294ef7
Show file tree
Hide file tree
Showing 10 changed files with 2,536 additions and 90 deletions.
16 changes: 12 additions & 4 deletions lib/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
[package]
name = "neo4rs"
version = "0.8.0-alpha.1"
authors = ["Neo4j Labs <devrel@neo4j.com>", "John Pradeep Vincent <yehohanan7@gmail.com>"]
authors = [
"Neo4j Labs <devrel@neo4j.com>",
"John Pradeep Vincent <yehohanan7@gmail.com>",
]
edition = "2021"
description = "Neo4j driver in rust"
license = "MIT"
Expand All @@ -16,7 +19,10 @@ rust-version = "1.63"
[dependencies]
async-trait = "0.1.0"
bytes = { version = "1.5.0", features = ["serde"] }
chrono = { version = "0.4.23", features = [ "std", "serde"], default_features = false }
chrono = { version = "0.4.23", features = [
"std",
"serde",
], default_features = false }
chrono-tz = "0.8.3"
deadpool = "0.9.0"
delegate = "0.10.0"
Expand All @@ -25,15 +31,17 @@ log = "0.4"
neo4rs-macros = { version = "0.3.0", path = "../macros" }
paste = "1.0.0"
pin-project-lite = "0.2.9"
serde = "1.0.0"
serde = { version = "1.0.185", features = ["derive"] } # TODO: eliminate derive
thiserror = "1.0.7"
tokio = { version = "1.5.0", features = ["full"] }
tokio-rustls = "0.24.0"
url = "2.0.0"
webpki-roots = "0.23.0"

[dev-dependencies]
lenient_semver = { version = "0.4.2", default_features = false, features = ["version_lite"] }
lenient_semver = { version = "0.4.2", default_features = false, features = [
"version_lite",
] }
pretty_env_logger = "0.4.0"
serde = { version = "1.0.185", features = ["derive"] }
serde_bytes = "0.11.0"
Expand Down
119 changes: 79 additions & 40 deletions lib/src/bolt/de.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
use std::{
fmt,
marker::PhantomData,
ops::{BitOr, BitOrAssign},
};
use std::{fmt, marker::PhantomData};

use bytes::{Buf, Bytes};
use serde::{
Expand All @@ -22,7 +18,7 @@ impl<'a: 'de, 'de> de::Deserializer<'de> for Deserializer<'a> {

forward_to_deserialize_any! {
bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 str
string newtype_struct tuple_struct ignored_any
string newtype_struct ignored_any
map unit_struct struct enum identifier
}

Expand All @@ -37,14 +33,26 @@ impl<'a: 'de, 'de> de::Deserializer<'de> for Deserializer<'a> {
where
V: Visitor<'de>,
{
self.parse_next_item(Visitation::MAP_AS_SEQ, visitor)
self.parse_next_item(Visitation::MapAsSeq, visitor)
}

fn deserialize_tuple<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
visitor.visit_seq(ItemsParser::new(len, self.bytes))
self.parse_next_item(Visitation::SeqAsTuple(len), visitor)
}

fn deserialize_tuple_struct<V>(
self,
_name: &'static str,
len: usize,
visitor: V,
) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
self.deserialize_tuple(len, visitor)
}

fn deserialize_char<V>(self, visitor: V) -> Result<V::Value, Self::Error>
Expand Down Expand Up @@ -100,7 +108,7 @@ impl<'a: 'de, 'de> de::Deserializer<'de> for Deserializer<'a> {
where
V: Visitor<'de>,
{
self.parse_next_item(Visitation::BYTES_AS_BYTES, visitor)
self.parse_next_item(Visitation::BytesAsBytes, visitor)
}

fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value, Self::Error>
Expand Down Expand Up @@ -150,12 +158,26 @@ impl<'de> Deserializer<'de> {
return Err(Error::Empty);
}

if let Visitation::SeqAsTuple(2) = v {
return if self.bytes[0] == 0x92 {
self.bytes.advance(1);
Self::parse_list(v, 2, self.bytes, visitor)
} else {
visitor.visit_seq(ItemsParser::new(2, self.bytes))
};
}

Self::parse(v, self.bytes, visitor)
}

fn skip_next_item(self) -> Result<(), Error> {
self.parse_next_item(Visitation::BytesAsBytes, de::IgnoredAny)
.map(|_| ())
}

fn parse<V: Visitor<'de>>(
v: Visitation,
bytes: &mut Bytes,
bytes: &'de mut Bytes,
visitor: V,
) -> Result<V::Value, Error> {
let marker = bytes.get_u8();
Expand All @@ -164,7 +186,7 @@ impl<'de> Deserializer<'de> {

match hi {
0x8 => Self::parse_string(lo as _, bytes, visitor),
0x9 => Self::parse_list(lo as _, bytes, visitor),
0x9 => Self::parse_list(v, lo as _, bytes, visitor),
0xA => Self::parse_map(v, lo as _, bytes, visitor),
0xB => Self::parse_struct(lo as _, bytes, visitor),
0xC => match lo {
Expand All @@ -185,9 +207,9 @@ impl<'de> Deserializer<'de> {
0x0 => Self::parse_string(bytes.get_u8() as _, bytes, visitor),
0x1 => Self::parse_string(bytes.get_u16() as _, bytes, visitor),
0x2 => Self::parse_string(bytes.get_u32() as _, bytes, visitor),
0x4 => Self::parse_list(bytes.get_u8() as _, bytes, visitor),
0x5 => Self::parse_list(bytes.get_u16() as _, bytes, visitor),
0x6 => Self::parse_list(bytes.get_u32() as _, bytes, visitor),
0x4 => Self::parse_list(v, bytes.get_u8() as _, bytes, visitor),
0x5 => Self::parse_list(v, bytes.get_u16() as _, bytes, visitor),
0x6 => Self::parse_list(v, bytes.get_u32() as _, bytes, visitor),
0x8 => Self::parse_map(v, bytes.get_u8() as _, bytes, visitor),
0x9 => Self::parse_map(v, bytes.get_u16() as _, bytes, visitor),
0xA => Self::parse_map(v, bytes.get_u32() as _, bytes, visitor),
Expand All @@ -202,40 +224,53 @@ impl<'de> Deserializer<'de> {
fn parse_bytes<V: Visitor<'de>>(
v: Visitation,
len: usize,
bytes: &mut Bytes,
bytes: &'de mut Bytes,
visitor: V,
) -> Result<V::Value, Error> {
debug_assert!(bytes.len() >= len);

let bytes = bytes.split_to(len);
if v.visit_bytes_as_bytes() {
visitor.visit_bytes(&bytes)
let bytes: &'de [u8] = unsafe { std::mem::transmute(bytes.as_ref()) };
visitor.visit_borrowed_bytes(bytes)
} else {
visitor.visit_seq(SeqDeserializer::new(bytes.into_iter()))
}
}

fn parse_string<V: Visitor<'de>>(
len: usize,
bytes: &mut Bytes,
bytes: &'de mut Bytes,
visitor: V,
) -> Result<V::Value, Error> {
debug_assert!(bytes.len() >= len);

let bytes = bytes.split_to(len);
let bytes: &'de [u8] = unsafe { std::mem::transmute(bytes.as_ref()) };

match std::str::from_utf8(&bytes) {
Ok(s) => visitor.visit_str(s),
match std::str::from_utf8(bytes) {
Ok(s) => visitor.visit_borrowed_str(s),
Err(e) => Err(Error::InvalidUtf8(e)),
}
}

fn parse_list<V: Visitor<'de>>(
v: Visitation,
len: usize,
bytes: &mut Bytes,
visitor: V,
) -> Result<V::Value, Error> {
visitor.visit_seq(ItemsParser::new(len, bytes))
let items = ItemsParser::new(len, bytes);
match v {
Visitation::SeqAsTuple(tuple_len) => match len.checked_sub(tuple_len) {
None => Err(Error::InvalidLength {
expected: tuple_len,
actual: len,
}),
Some(excess) => visitor.visit_seq(items.with_excess(excess)),
},
_ => visitor.visit_seq(items),
}
}

fn parse_map<V: Visitor<'de>>(
Expand Down Expand Up @@ -266,16 +301,23 @@ impl<'de> Deserializer<'de> {

struct ItemsParser<'a> {
len: usize,
excess: usize,
bytes: SharedBytes<'a>,
}

impl<'a> ItemsParser<'a> {
fn new(len: usize, bytes: &'a mut Bytes) -> Self {
Self {
len,
excess: 0,
bytes: SharedBytes::new(bytes),
}
}

fn with_excess(mut self, excess: usize) -> Self {
self.excess = excess;
self
}
}

impl<'a, 'de> SeqAccess<'de> for ItemsParser<'a> {
Expand All @@ -286,6 +328,10 @@ impl<'a, 'de> SeqAccess<'de> for ItemsParser<'a> {
T: DeserializeSeed<'de>,
{
if self.len == 0 {
let bytes = self.bytes.get();
for _ in 0..self.excess {
Deserializer { bytes }.skip_next_item()?;
}
return Ok(None);
}
self.len -= 1;
Expand Down Expand Up @@ -429,32 +475,21 @@ impl<'a, 'de> EnumAccess<'de> for StructParser<'a> {
}

#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
struct Visitation(u8);
enum Visitation {
#[default]
Default,
BytesAsBytes,
MapAsSeq,
SeqAsTuple(usize),
}

impl Visitation {
const BYTES_AS_BYTES: Self = Self(1);
const MAP_AS_SEQ: Self = Self(2);

fn visit_bytes_as_bytes(self) -> bool {
self.0 & Self::BYTES_AS_BYTES.0 != 0
matches!(self, Self::BytesAsBytes)
}

fn visit_map_as_seq(self) -> bool {
self.0 & Self::MAP_AS_SEQ.0 != 0
}
}

impl BitOr for Visitation {
type Output = Self;

fn bitor(self, rhs: Self) -> Self::Output {
Self(self.0 | rhs.0)
}
}

impl BitOrAssign for Visitation {
fn bitor_assign(&mut self, rhs: Self) {
self.0 |= rhs.0;
matches!(self, Self::MapAsSeq)
}
}

Expand All @@ -477,6 +512,7 @@ impl<'a> SharedBytes<'a> {
}

#[derive(Debug, Clone, thiserror::Error)]
#[non_exhaustive]
pub enum Error {
#[error("Not enough data to parse a bolt stream.")]
Empty,
Expand All @@ -487,6 +523,9 @@ pub enum Error {
#[error("The bytes do no contain valid UTF-8 to produce a string: {0}")]
InvalidUtf8(#[source] std::str::Utf8Error),

#[error("Invalid length: expected {expected}, actual {actual}")]
InvalidLength { expected: usize, actual: usize },

// TODO: copy DeError
#[error("Deserialization error: {0}")]
DeserializationError(String),
Expand Down
Loading

0 comments on commit 7294ef7

Please sign in to comment.