Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: make sure that structs are serialized correctly #34

Merged
merged 3 commits into from
Apr 24, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 58 additions & 34 deletions src/ser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ impl<'a, W: enc::Write> serde::Serializer for &'a mut Serializer<W> {
type SerializeTupleStruct = BoundedCollect<'a, W>;
type SerializeTupleVariant = BoundedCollect<'a, W>;
type SerializeMap = CollectMap<'a, W>;
type SerializeStruct = BoundedCollect<'a, W>;
type SerializeStructVariant = BoundedCollect<'a, W>;
type SerializeStruct = CollectMap<'a, W>;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you educate me on BoundedCollet vs CollectMap?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The names might not be great, but the idea now is that BoundedCollect does take the input in the same order as it comes in without further checks, the CollectMap does the whole sorting thing.

Those structs then implement specific traits so that they can be used with those type definitions.

type SerializeStructVariant = CollectMap<'a, W>;

#[inline]
fn serialize_bool(self, v: bool) -> Result<Self::Ok, Self::Error> {
Expand Down Expand Up @@ -265,7 +265,7 @@ impl<'a, W: enc::Write> serde::Serializer for &'a mut Serializer<W> {
len: usize,
) -> Result<Self::SerializeStruct, Self::Error> {
enc::MapStartBounded(len).encode(&mut self.writer)?;
Ok(BoundedCollect { ser: self })
Ok(CollectMap::new(self))
}

#[inline]
Expand All @@ -279,7 +279,7 @@ impl<'a, W: enc::Write> serde::Serializer for &'a mut Serializer<W> {
enc::MapStartBounded(1).encode(&mut self.writer)?;
variant.encode(&mut self.writer)?;
enc::MapStartBounded(len).encode(&mut self.writer)?;
Ok(BoundedCollect { ser: self })
Ok(CollectMap::new(self))
}

#[inline]
Expand Down Expand Up @@ -429,14 +429,50 @@ pub struct CollectMap<'a, W> {
ser: &'a mut Serializer<W>,
}

impl<'a, W> CollectMap<'a, W> {
impl<'a, W> CollectMap<'a, W>
where
W: enc::Write,
{
fn new(ser: &'a mut Serializer<W>) -> Self {
Self {
buffer: BufWriter::new(Vec::new()),
entries: Vec::new(),
ser,
}
}

fn serialize<T: Serialize + ?Sized>(
&mut self,
maybe_key: Option<&'static str>,
value: &T,
) -> Result<(), EncodeError<W::Error>> {
// Instantiate a new serializer, so that the buffer can be re-used.
let mut mem_serializer = Serializer::new(&mut self.buffer);
if let Some(key) = maybe_key {
key.serialize(&mut mem_serializer)
.map_err(|_| EncodeError::Msg("Struct key cannot be serialized.".to_string()))?;
}
value
.serialize(&mut mem_serializer)
.map_err(|_| EncodeError::Msg("Struct value cannot be serialized.".to_string()))?;

self.entries.push(self.buffer.buffer().to_vec());
self.buffer.clear();

Ok(())
}

fn end(mut self) -> Result<(), EncodeError<W::Error>> {
// This sorting step makes sure we have the expected order of the keys. Byte-wise
// comparison gives us the right order as keys in DAG-CBOR are always (text) strings, hence
// have the same CBOR major type 3. The length of the string is encoded in the following
// bits. This means that a shorter string sorts before a longer string.
vmx marked this conversation as resolved.
Show resolved Hide resolved
self.entries.sort_unstable();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this is enough. We're stuck with the 7049 rules of length-first sorting. https://ipld.io/specs/codecs/dag-cbor/spec/#strictness

A good extension to your tests would be to add a new field to the struct "a1" which should come after "b".

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for pointing out and proposing better tests. Though I think it'ss correct. In a previous PR Steb asked me to clarify this a bit in a comment, but it still doesn't seem to be clear. When you've an idea to word it in a better way, please let me know.

As CBOR prefixes the strings with the length, shorter strings are sorted first. I even think this is the reason they came up with this (counter-intuitive) ordering. So that you can take the full encoded value (not just the string itself) and memcmp it for sorting.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Importantly, even though the length itself is variable length, the length itself is either:

  1. 0-23: inlined into the 5-bit additional data field.
  2. 24-27: 1-8 byte length.

Which means that lexicographical sorting actually works.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, right, I missed that this was sorting over encoded values; and yeah, neat that it actually works given the flexible prefix!

for entry in self.entries {
self.ser.writer.push(&entry)?;
}
Ok(())
}
}

impl<W> serde::ser::SerializeMap for CollectMap<'_, W>
Expand All @@ -448,7 +484,8 @@ where

#[inline]
fn serialize_key<T: Serialize + ?Sized>(&mut self, key: &T) -> Result<(), Self::Error> {
// Instantiate a new serializer, so that the buffer can be re-used.
// The key needs to be add to the buffer without any further operations. Serializing the
// value will then do the necessary flushing etc.
let mut mem_serializer = Serializer::new(&mut self.buffer);
key.serialize(&mut mem_serializer)
.map_err(|_| EncodeError::Msg("Map key cannot be serialized.".to_string()))?;
Expand All @@ -457,34 +494,20 @@ where

#[inline]
fn serialize_value<T: Serialize + ?Sized>(&mut self, value: &T) -> Result<(), Self::Error> {
// Instantiate a new serializer, so that the buffer can be re-used.
let mut mem_serializer = Serializer::new(&mut self.buffer);
value
.serialize(&mut mem_serializer)
.map_err(|_| EncodeError::Msg("Map value cannot be serialized.".to_string()))?;

self.entries.push(self.buffer.buffer().to_vec());
self.buffer.clear();

Ok(())
self.serialize(None, value)
}

#[inline]
fn end(mut self) -> Result<Self::Ok, Self::Error> {
fn end(self) -> Result<Self::Ok, Self::Error> {
enc::MapStartBounded(self.entries.len()).encode(&mut self.ser.writer)?;
// This sorting step makes sure we have the expected order of the keys. Byte-wise
// comparison gives us the right order as keys in DAG-CBOR are always (text) strings, hence
// have the same CBOR major type 3. The length of the string is encoded in the following
// bits. This means that a shorter string sorts before a longer string.
self.entries.sort_unstable();
for entry in self.entries {
self.ser.writer.push(&entry)?;
}
Ok(())
self.end()
}
}

impl<W: enc::Write> serde::ser::SerializeStruct for BoundedCollect<'_, W> {
impl<W> serde::ser::SerializeStruct for CollectMap<'_, W>
where
W: enc::Write,
{
type Ok = ();
type Error = EncodeError<W::Error>;

Expand All @@ -494,17 +517,19 @@ impl<W: enc::Write> serde::ser::SerializeStruct for BoundedCollect<'_, W> {
key: &'static str,
value: &T,
) -> Result<(), Self::Error> {
key.serialize(&mut *self.ser)?;
value.serialize(&mut *self.ser)
self.serialize(Some(key), value)
}

#[inline]
fn end(self) -> Result<Self::Ok, Self::Error> {
Ok(())
self.end()
}
}

impl<W: enc::Write> serde::ser::SerializeStructVariant for BoundedCollect<'_, W> {
impl<W> serde::ser::SerializeStructVariant for CollectMap<'_, W>
where
W: enc::Write,
{
type Ok = ();
type Error = EncodeError<W::Error>;

Expand All @@ -514,13 +539,12 @@ impl<W: enc::Write> serde::ser::SerializeStructVariant for BoundedCollect<'_, W>
key: &'static str,
value: &T,
) -> Result<(), Self::Error> {
key.serialize(&mut *self.ser)?;
value.serialize(&mut *self.ser)
self.serialize(Some(key), value)
}

#[inline]
fn end(self) -> Result<Self::Ok, Self::Error> {
Ok(())
self.end()
}
}

Expand Down
44 changes: 44 additions & 0 deletions tests/ser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::{collections::BTreeMap, iter};

use serde::de::value::{self, MapDeserializer, SeqDeserializer};
use serde_bytes::{ByteBuf, Bytes};
use serde_derive::Serialize;
use serde_ipld_dagcbor::{
from_slice,
ser::{BufWriter, Serializer},
Expand Down Expand Up @@ -190,3 +191,46 @@ fn test_non_unbound_list() {
let result = serializer.into_inner().into_inner();
assert_eq!(result, expected);
}

#[test]
fn test_struct_canonical() {
#[derive(Serialize)]
struct First {
a: u32,
b: u32,
}
#[derive(Serialize)]
struct Second {
b: u32,
a: u32,
}

let first = First { a: 1, b: 2 };
let second = Second { a: 1, b: 2 };

let first_bytes = serde_ipld_dagcbor::to_vec(&first).unwrap();
let second_bytes = serde_ipld_dagcbor::to_vec(&second).unwrap();

assert_eq!(first_bytes, second_bytes);
}

#[test]
fn test_struct_variant_canonical() {
#[derive(Serialize)]
enum First {
Data { a: u8, b: u8, c: u8 },
}

#[derive(Serialize)]
enum Second {
Data { b: u8, c: u8, a: u8 },
}

let first = First::Data { a: 1, b: 2, c: 3 };
let second = Second::Data { a: 1, b: 2, c: 3 };

let first_bytes = serde_ipld_dagcbor::to_vec(&first).unwrap();
let second_bytes = serde_ipld_dagcbor::to_vec(&second).unwrap();

assert_eq!(first_bytes, second_bytes);
}
2 changes: 1 addition & 1 deletion tests/std_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ testcase!(test_person_struct,
year_of_birth: 1906,
profession: Some("computer scientist".to_string()),
},
"a3646e616d656c477261636520486f707065726d796561725f6f665f62697274681907726a70726f66657373696f6e72636f6d707574657220736369656e74697374");
"a3646e616d656c477261636520486f707065726a70726f66657373696f6e72636f6d707574657220736369656e746973746d796561725f6f665f6269727468190772");

#[derive(Debug, PartialEq, Deserialize, Serialize)]
struct OptionalPerson {
Expand Down