diff --git a/src/de.rs b/src/de.rs index 37c72ddf..7500bdec 100644 --- a/src/de.rs +++ b/src/de.rs @@ -1149,6 +1149,18 @@ where } } +fn is_plain_or_tagged_literal_scalar( + expected: &str, + scalar: &Scalar, + tagged_already: bool, +) -> bool { + match (scalar.style, &scalar.tag, tagged_already) { + (ScalarStyle::Plain, _, _) => true, + (ScalarStyle::Literal, Some(tag), false) => tag == expected, + _ => false, + } +} + fn invalid_type(event: &Event, exp: &dyn Expected) -> Error { enum Void {} @@ -1250,11 +1262,14 @@ impl<'de, 'document> de::Deserializer<'de> for &mut DeserializerFromEvents<'de, where V: Visitor<'de>, { + let tagged_already = self.current_enum.is_some(); let (next, mark) = self.next_event_mark()?; loop { match next { Event::Alias(mut pos) => break self.jump(&mut pos)?.deserialize_bool(visitor), - Event::Scalar(scalar) if scalar.style == ScalarStyle::Plain => { + Event::Scalar(scalar) + if is_plain_or_tagged_literal_scalar(Tag::BOOL, scalar, tagged_already) => + { if let Ok(value) = str::from_utf8(&scalar.value) { if let Some(boolean) = parse_bool(value) { break visitor.visit_bool(boolean); @@ -1293,11 +1308,14 @@ impl<'de, 'document> de::Deserializer<'de> for &mut DeserializerFromEvents<'de, where V: Visitor<'de>, { + let tagged_already = self.current_enum.is_some(); let (next, mark) = self.next_event_mark()?; loop { match next { Event::Alias(mut pos) => break self.jump(&mut pos)?.deserialize_i64(visitor), - Event::Scalar(scalar) if scalar.style == ScalarStyle::Plain => { + Event::Scalar(scalar) + if is_plain_or_tagged_literal_scalar(Tag::INT, scalar, tagged_already) => + { if let Ok(value) = str::from_utf8(&scalar.value) { if let Some(int) = parse_signed_int(value, i64::from_str_radix) { break visitor.visit_i64(int); @@ -1315,11 +1333,14 @@ impl<'de, 'document> de::Deserializer<'de> for &mut DeserializerFromEvents<'de, where V: Visitor<'de>, { + let tagged_already = self.current_enum.is_some(); let (next, mark) = self.next_event_mark()?; loop { match next { Event::Alias(mut pos) => break self.jump(&mut pos)?.deserialize_i128(visitor), - Event::Scalar(scalar) if scalar.style == ScalarStyle::Plain => { + Event::Scalar(scalar) + if is_plain_or_tagged_literal_scalar(Tag::INT, scalar, tagged_already) => + { if let Ok(value) = str::from_utf8(&scalar.value) { if let Some(int) = parse_signed_int(value, i128::from_str_radix) { break visitor.visit_i128(int); @@ -1358,11 +1379,14 @@ impl<'de, 'document> de::Deserializer<'de> for &mut DeserializerFromEvents<'de, where V: Visitor<'de>, { + let tagged_already = self.current_enum.is_some(); let (next, mark) = self.next_event_mark()?; loop { match next { Event::Alias(mut pos) => break self.jump(&mut pos)?.deserialize_u64(visitor), - Event::Scalar(scalar) if scalar.style == ScalarStyle::Plain => { + Event::Scalar(scalar) + if is_plain_or_tagged_literal_scalar(Tag::INT, scalar, tagged_already) => + { if let Ok(value) = str::from_utf8(&scalar.value) { if let Some(int) = parse_unsigned_int(value, u64::from_str_radix) { break visitor.visit_u64(int); @@ -1380,11 +1404,14 @@ impl<'de, 'document> de::Deserializer<'de> for &mut DeserializerFromEvents<'de, where V: Visitor<'de>, { + let tagged_already = self.current_enum.is_some(); let (next, mark) = self.next_event_mark()?; loop { match next { Event::Alias(mut pos) => break self.jump(&mut pos)?.deserialize_u128(visitor), - Event::Scalar(scalar) if scalar.style == ScalarStyle::Plain => { + Event::Scalar(scalar) + if is_plain_or_tagged_literal_scalar(Tag::INT, scalar, tagged_already) => + { if let Ok(value) = str::from_utf8(&scalar.value) { if let Some(int) = parse_unsigned_int(value, u128::from_str_radix) { break visitor.visit_u128(int); @@ -1409,11 +1436,14 @@ impl<'de, 'document> de::Deserializer<'de> for &mut DeserializerFromEvents<'de, where V: Visitor<'de>, { + let tagged_already = self.current_enum.is_some(); let (next, mark) = self.next_event_mark()?; loop { match next { Event::Alias(mut pos) => break self.jump(&mut pos)?.deserialize_f64(visitor), - Event::Scalar(scalar) if scalar.style == ScalarStyle::Plain => { + Event::Scalar(scalar) + if is_plain_or_tagged_literal_scalar(Tag::FLOAT, scalar, tagged_already) => + { if let Ok(value) = str::from_utf8(&scalar.value) { if let Some(float) = parse_f64(value) { break visitor.visit_f64(float); diff --git a/tests/test_de.rs b/tests/test_de.rs index 0d16d5e7..fea81007 100644 --- a/tests/test_de.rs +++ b/tests/test_de.rs @@ -569,3 +569,28 @@ fn test_empty_scalar() { }; test_de(yaml, &expected); } + +#[test] +fn test_python_safe_dump() { + #[derive(Deserialize, PartialEq, Debug)] + struct Frob { + foo: u32, + } + + // This matches output produced by PyYAML's `yaml.safe_dump` when using the + // default_style parameter. + // + // >>> import yaml + // >>> d = {"foo": 7200} + // >>> print(yaml.safe_dump(d, default_style="|")) + // "foo": !!int |- + // 7200 + // + let yaml = indoc! {r#" + "foo": !!int |- + 7200 + "#}; + + let expected = Frob { foo: 7200 }; + test_de(yaml, &expected); +}