Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 138 additions & 18 deletions rust/candid/src/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,7 @@ struct Deserializer<'de> {
is_untyped: bool,
config: DecoderConfig,
recursion_depth: crate::utils::RecursionDepth,
primitive_vec_fast_path: Option<PrimitiveType>,
}

impl<'de> Deserializer<'de> {
Expand All @@ -320,6 +321,7 @@ impl<'de> Deserializer<'de> {
is_untyped: false,
config: config.clone(),
recursion_depth: crate::utils::RecursionDepth::new(),
primitive_vec_fast_path: None,
})
}
fn dump_state(&self) -> String {
Expand Down Expand Up @@ -616,11 +618,48 @@ impl<'de> Deserializer<'de> {
}
}

#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum PrimitiveType {
Bool,
Int8,
Int16,
Int32,
Int64,
Nat8,
Nat16,
Nat32,
Nat64,
Float32,
Float64,
}

fn exact_primitive_type(expect: &Type, wire: &Type) -> Option<PrimitiveType> {
match (expect.as_ref(), wire.as_ref()) {
(TypeInner::Bool, TypeInner::Bool) => Some(PrimitiveType::Bool),
(TypeInner::Int8, TypeInner::Int8) => Some(PrimitiveType::Int8),
(TypeInner::Int16, TypeInner::Int16) => Some(PrimitiveType::Int16),
(TypeInner::Int32, TypeInner::Int32) => Some(PrimitiveType::Int32),
(TypeInner::Int64, TypeInner::Int64) => Some(PrimitiveType::Int64),
(TypeInner::Nat8, TypeInner::Nat8) => Some(PrimitiveType::Nat8),
(TypeInner::Nat16, TypeInner::Nat16) => Some(PrimitiveType::Nat16),
(TypeInner::Nat32, TypeInner::Nat32) => Some(PrimitiveType::Nat32),
(TypeInner::Nat64, TypeInner::Nat64) => Some(PrimitiveType::Nat64),
(TypeInner::Float32, TypeInner::Float32) => Some(PrimitiveType::Float32),
(TypeInner::Float64, TypeInner::Float64) => Some(PrimitiveType::Float64),
_ => None,
}
}

macro_rules! primitive_impl {
($ty:ident, $type:expr, $cost:literal, $($value:tt)*) => {
($ty:ident, $type:expr, $fast:expr, $cost:literal, $($value:tt)*) => {
paste::item! {
fn [<deserialize_ $ty>]<V>(self, visitor: V) -> Result<V::Value>
where V: Visitor<'de> {
if self.primitive_vec_fast_path == Some($fast) {
self.add_cost($cost)?;
let val = self.input.$($value)*().map_err(|_| Error::msg(format!("Cannot read {} value", stringify!($type))))?;
return visitor.[<visit_ $ty>](val);
}
self.unroll_type()?;
check!(*self.expect_type == $type && *self.wire_type == $type, stringify!($type));
self.add_cost($cost)?;
Expand Down Expand Up @@ -697,16 +736,64 @@ impl<'de> de::Deserializer<'de> for &mut Deserializer<'de> {
v
}

primitive_impl!(i8, TypeInner::Int8, 1, read_i8);
primitive_impl!(i16, TypeInner::Int16, 2, read_i16::<LittleEndian>);
primitive_impl!(i32, TypeInner::Int32, 4, read_i32::<LittleEndian>);
primitive_impl!(i64, TypeInner::Int64, 8, read_i64::<LittleEndian>);
primitive_impl!(u8, TypeInner::Nat8, 1, read_u8);
primitive_impl!(u16, TypeInner::Nat16, 2, read_u16::<LittleEndian>);
primitive_impl!(u32, TypeInner::Nat32, 4, read_u32::<LittleEndian>);
primitive_impl!(u64, TypeInner::Nat64, 8, read_u64::<LittleEndian>);
primitive_impl!(f32, TypeInner::Float32, 4, read_f32::<LittleEndian>);
primitive_impl!(f64, TypeInner::Float64, 8, read_f64::<LittleEndian>);
primitive_impl!(i8, TypeInner::Int8, PrimitiveType::Int8, 1, read_i8);
primitive_impl!(
i16,
TypeInner::Int16,
PrimitiveType::Int16,
2,
read_i16::<LittleEndian>
);
primitive_impl!(
i32,
TypeInner::Int32,
PrimitiveType::Int32,
4,
read_i32::<LittleEndian>
);
primitive_impl!(
i64,
TypeInner::Int64,
PrimitiveType::Int64,
8,
read_i64::<LittleEndian>
);
primitive_impl!(u8, TypeInner::Nat8, PrimitiveType::Nat8, 1, read_u8);
primitive_impl!(
u16,
TypeInner::Nat16,
PrimitiveType::Nat16,
2,
read_u16::<LittleEndian>
);
primitive_impl!(
u32,
TypeInner::Nat32,
PrimitiveType::Nat32,
4,
read_u32::<LittleEndian>
);
primitive_impl!(
u64,
TypeInner::Nat64,
PrimitiveType::Nat64,
8,
read_u64::<LittleEndian>
);
primitive_impl!(
f32,
TypeInner::Float32,
PrimitiveType::Float32,
4,
read_f32::<LittleEndian>
);
primitive_impl!(
f64,
TypeInner::Float64,
PrimitiveType::Float64,
8,
read_f64::<LittleEndian>
);

fn is_human_readable(&self) -> bool {
false
Expand Down Expand Up @@ -752,10 +839,17 @@ impl<'de> de::Deserializer<'de> for &mut Deserializer<'de> {
self.add_cost(1)?;
visitor.visit_unit()
}
// Bool is handled separately from `primitive_impl!` because its wire encoding
// uses `BoolValue::read` rather than a plain numeric read.
fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
if self.primitive_vec_fast_path == Some(PrimitiveType::Bool) {
self.add_cost(1)?;
let res = BoolValue::read(&mut self.input)?;
return visitor.visit_bool(res.0);
}
self.unroll_type()?;
check!(
*self.expect_type == TypeInner::Bool && *self.wire_type == TypeInner::Bool,
Expand Down Expand Up @@ -849,7 +943,16 @@ impl<'de> de::Deserializer<'de> for &mut Deserializer<'de> {
let expect = e.clone();
let wire = self.table.trace_type_with_depth(w, &self.recursion_depth)?;
let len = Len::read(&mut self.input)?.0;
visitor.visit_seq(Compound::new(self, Style::Vector { len, expect, wire }))
let exact_primitive = exact_primitive_type(&expect, &wire);
visitor.visit_seq(Compound::new(
self,
Style::Vector {
len,
expect,
wire,
exact_primitive,
},
))
}
(TypeInner::Record(_), TypeInner::Record(_)) => {
let expect = self.expect_type.clone();
Expand Down Expand Up @@ -1051,6 +1154,7 @@ enum Style {
len: usize,
expect: Type,
wire: Type,
exact_primitive: Option<PrimitiveType>,
},
Struct {
expect: Type,
Expand Down Expand Up @@ -1113,14 +1217,22 @@ impl<'de> de::SeqAccess<'de> for Compound<'_, 'de> {
ref mut len,
ref expect,
ref wire,
exact_primitive,
} => {
if *len == 0 {
return Ok(None);
}
*len -= 1;
self.de.expect_type = expect.clone();
self.de.wire_type = wire.clone();
seed.deserialize(&mut *self.de).map(Some)
let old_fast_path = self.de.primitive_vec_fast_path;
if let Some(exact_primitive) = exact_primitive {
self.de.primitive_vec_fast_path = Some(exact_primitive);
} else {
self.de.expect_type = expect.clone();
self.de.wire_type = wire.clone();
}
let result = seed.deserialize(&mut *self.de).map(Some);
self.de.primitive_vec_fast_path = old_fast_path;
result
}
Style::Struct {
ref expect,
Expand Down Expand Up @@ -1167,6 +1279,14 @@ impl<'de> de::SeqAccess<'de> for Compound<'_, 'de> {
}
}

impl Drop for Compound<'_, '_> {
fn drop(&mut self) {
// Reset fast-path state so it cannot leak if this Compound is dropped
// before all elements are consumed (e.g., on an error path).
self.de.primitive_vec_fast_path = None;
}
}

impl<'de> de::MapAccess<'de> for Compound<'_, 'de> {
type Error = Error;
fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>>
Expand Down Expand Up @@ -1337,22 +1457,22 @@ impl<'de> de::VariantAccess<'de> for Compound<'_, 'de> {
T: de::DeserializeSeed<'de>,
{
self.de.add_cost(1)?;
seed.deserialize(self.de)
seed.deserialize(&mut *self.de)
}

fn tuple_variant<V>(self, len: usize, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
self.de.add_cost(1)?;
de::Deserializer::deserialize_tuple(self.de, len, visitor)
de::Deserializer::deserialize_tuple(&mut *self.de, len, visitor)
}

fn struct_variant<V>(self, fields: &'static [&'static str], visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
self.de.add_cost(1)?;
de::Deserializer::deserialize_struct(self.de, "_", fields, visitor)
de::Deserializer::deserialize_struct(&mut *self.de, "_", fields, visitor)
}
}
50 changes: 50 additions & 0 deletions rust/candid/tests/compatibility_vectors.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
use candid::{CandidType, Decode, Deserialize, Encode};

#[test]
fn primitive_vector_decode_stays_compatible_with_extra_args() {
let values = vec![-1i16, 0, 1, 42, i16::MIN, i16::MAX];
let bytes = Encode!(&values, &123u8).unwrap();

let decoded_values = Decode!(&bytes, Vec<i16>).unwrap();
assert_eq!(decoded_values, values);

let (decoded_values, trailing) = Decode!(&bytes, Vec<i16>, u8).unwrap();
assert_eq!(decoded_values, values);
assert_eq!(trailing, 123);
}

#[test]
fn nested_primitive_vector_decode() {
// Outer vec is non-primitive so exact_primitive is None; inner vecs use the fast path.
let values: Vec<Vec<i16>> = vec![vec![1, 2], vec![], vec![3, i16::MIN, i16::MAX]];
let bytes = Encode!(&values).unwrap();
let decoded: Vec<Vec<i16>> = Decode!(&bytes, Vec<Vec<i16>>).unwrap();
assert_eq!(decoded, values);
}

#[test]
fn struct_with_primitive_vector_field() {
// Ensures primitive_vec_fast_path is correctly restored when a vec<primitive>
// appears as a struct field alongside other fields.
#[derive(CandidType, Deserialize, PartialEq, Debug)]
struct S {
xs: Vec<i32>,
y: u8,
}
let s = S {
xs: vec![1, -1, i32::MAX],
y: 42,
};
let bytes = Encode!(&s).unwrap();
let decoded = Decode!(&bytes, S).unwrap();
assert_eq!(decoded, s);
}

#[test]
fn mismatched_rust_type_does_not_use_fast_path() {
// Wire type is vec nat16 but Rust target is Vec<u32>: expect and wire differ,
// so exact_primitive is None and the normal type-checking path rejects it.
let values: Vec<u16> = vec![1, 2, 3];
let bytes = Encode!(&values).unwrap();
assert!(Decode!(&bytes, Vec<u32>).is_err());
}
Loading