Skip to content

Commit

Permalink
Merge pull request #17 from ModProg/new_type_enum
Browse files Browse the repository at this point in the history
allow newtype enum variants containing structs to be structs
  • Loading branch information
ecton committed May 20, 2023
2 parents aab818f + b557fe2 commit aeb5e56
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 45 deletions.
145 changes: 100 additions & 45 deletions src/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,20 @@ use crate::tokenizer::{self, Integer};

pub struct Deserializer<'de> {
parser: BetterPeekable<Parser<'de>>,
newtype_state: Option<NewtypeState>,
}

#[derive(Clone, Copy, Eq, PartialEq)]
enum NewtypeState {
StructVariant,
TupleVariant,
}

impl<'de> Deserializer<'de> {
pub fn new(source: &'de str, config: Config) -> Self {
Self {
parser: BetterPeekable::new(Parser::new(source, config.include_comments(false))),
newtype_state: None,
}
}

Expand Down Expand Up @@ -86,8 +94,20 @@ impl<'de> Deserializer<'de> {
}
}
}

fn set_newtype_state(&mut self, state: NewtypeState) -> NewtypeStateModification {
let old_state = self.newtype_state.replace(state);
NewtypeStateModification(old_state)
}

fn finish_newtype(&mut self, modification: NewtypeStateModification) -> Option<NewtypeState> {
core::mem::replace(&mut self.newtype_state, modification.0)
}
}

#[must_use]
struct NewtypeStateModification(Option<NewtypeState>);

macro_rules! deserialize_int_impl {
($de_name:ident, $visit_name:ident, $conv_name:ident) => {
fn $de_name<V>(self, visitor: V) -> Result<V::Value, Self::Error>
Expand Down Expand Up @@ -503,37 +523,59 @@ impl<'de> serde::de::Deserializer<'de> for &mut Deserializer<'de> {
where
V: serde::de::Visitor<'de>,
{
let is_parsing_newtype_tuple =
matches!(self.newtype_state, Some(NewtypeState::TupleVariant));
let next_token_is_nested_tuple = matches!(
self.parser.peek(),
Some(Ok(Event {
kind: EventKind::BeginNested {
kind: Nested::Tuple,
..
},
..
}))
);
self.with_error_context(|de| {
match de.parser.next().transpose()? {
Some(Event {
kind: EventKind::BeginNested { name, kind },
location,
}) => {
if name.map_or(false, |name| name != struct_name) {
return Err(DeserializerError::new(
location,
ErrorKind::NameMismatch(struct_name),
));
}
if is_parsing_newtype_tuple {
if next_token_is_nested_tuple {
// We have a multi-nested newtype situation here, and to enable
// parsing the `)` easily, we need to "take over" by erasing the
// current newtype state.
de.parser.next();
return visitor.visit_seq(SequenceDeserializer::new(de));
}
} else {
match de.parser.next().transpose()? {
Some(Event {
kind: EventKind::BeginNested { name, kind },
location,
}) => {
if name.map_or(false, |name| name != struct_name) {
return Err(DeserializerError::new(
location,
ErrorKind::NameMismatch(struct_name),
));
}

if kind != Nested::Tuple {
if kind != Nested::Tuple {
return Err(DeserializerError::new(
location,
ErrorKind::ExpectedTupleStruct,
));
}
}
Some(other) => {
return Err(DeserializerError::new(
location,
other.location,
ErrorKind::ExpectedTupleStruct,
));
}
}
Some(other) => {
return Err(DeserializerError::new(
other.location,
ErrorKind::ExpectedTupleStruct,
));
}
None => {
return Err(DeserializerError::new(
None,
parser::ErrorKind::UnexpectedEof,
))
None => {
return Err(DeserializerError::new(
None,
parser::ErrorKind::UnexpectedEof,
))
}
}
}

Expand Down Expand Up @@ -582,7 +624,9 @@ impl<'de> serde::de::Deserializer<'de> for &mut Deserializer<'de> {
kind: EventKind::BeginNested { name, kind },
location,
}) => {
if name.map_or(false, |name| name != struct_name) {
if name.map_or(false, |name| name != struct_name)
&& !matches!(de.newtype_state, Some(NewtypeState::StructVariant))
{
return Err(DeserializerError::new(
location,
ErrorKind::NameMismatch(struct_name),
Expand Down Expand Up @@ -1073,26 +1117,37 @@ impl<'a, 'de> VariantAccess<'de> for EnumVariantAccessor<'a, 'de> {
T: serde::de::DeserializeSeed<'de>,
{
if let EnumVariantAccessor::Nested(deserializer) = self {
let nested_event = deserializer
.parser
.next()
.expect("variant access matched Nested")?;
deserializer.with_error_start(nested_event.location.start, |de| {
let result = seed.deserialize(&mut *de)?;
loop {
if let Event {
kind: EventKind::EndNested,
..
} = de
.parser
.next()
.transpose()?
.expect("eof handled by parser")
{
return Ok(result);
}
let modification = match deserializer.parser.peek() {
Some(Ok(Event {
kind:
EventKind::BeginNested {
kind: Nested::Tuple,
..
},
..
})) => {
let _begin = deserializer.parser.next();
Some(deserializer.set_newtype_state(NewtypeState::TupleVariant))
}
})
Some(Ok(Event {
kind:
EventKind::BeginNested {
kind: Nested::Map, ..
},
..
})) => Some(deserializer.set_newtype_state(NewtypeState::StructVariant)),
_ => None,
};
let result = deserializer.with_error_context(|de| seed.deserialize(&mut *de))?;
if let Some(modification) = modification {
if deserializer.finish_newtype(modification) == Some(NewtypeState::TupleVariant) {
// SequenceDeserializer has a loop in its drop to eat the
// remaining events until the end
drop(SequenceDeserializer::new(&mut *deserializer));
}
}

Ok(result)
} else {
Err(DeserializerError::new(None, ErrorKind::ExpectedTupleStruct))
}
Expand Down
34 changes: 34 additions & 0 deletions src/tests/serde.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,11 +137,24 @@ enum UntaggedEnum {
Unit(UnitStruct),
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
enum TaggedEnum {
Tuple(bool, bool),
Struct { a: u64 },
NewtypeStruct(SimpleStruct),
NewtypeTuple(SimpleTuple),
NewtypeBool(bool),
Unit,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
struct SimpleStruct {
a: u64,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
struct SimpleTuple(u64, bool);

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
struct NewtypeBool(bool);

Expand Down Expand Up @@ -171,6 +184,27 @@ fn deserialize_any() {
assert_eq!(untagged, Some(UntaggedEnum::Unit(UnitStruct)));
}

#[test]
fn deserialize_tagged() {
let tagged: TaggedEnum = crate::from_str("Tuple (true, false)").unwrap();
assert_eq!(tagged, TaggedEnum::Tuple(true, false));
let tagged: TaggedEnum = crate::from_str("Struct {a: 1}").unwrap();
assert_eq!(tagged, TaggedEnum::Struct { a: 1 });

let tagged: TaggedEnum = crate::from_str("NewtypeStruct {a: 1}").unwrap();
assert_eq!(tagged, TaggedEnum::NewtypeStruct(SimpleStruct { a: 1 }));
let tagged: TaggedEnum = crate::from_str("NewtypeStruct({a: 1})").unwrap();
assert_eq!(tagged, TaggedEnum::NewtypeStruct(SimpleStruct { a: 1 }));
let tagged: TaggedEnum = crate::from_str("NewtypeTuple(1, false)").unwrap();
assert_eq!(tagged, TaggedEnum::NewtypeTuple(SimpleTuple(1, false)));
let tagged: TaggedEnum = crate::from_str("NewtypeTuple((1, false))").unwrap();
assert_eq!(tagged, TaggedEnum::NewtypeTuple(SimpleTuple(1, false)));
let tagged: TaggedEnum = crate::from_str("NewtypeBool(true)").unwrap();
assert_eq!(tagged, TaggedEnum::NewtypeBool(true));
let tagged: TaggedEnum = crate::from_str("Unit").unwrap();
assert_eq!(tagged, TaggedEnum::Unit);
}

#[test]
fn value_from_serialize() {
let original = StructOfEverything::default();
Expand Down

0 comments on commit aeb5e56

Please sign in to comment.