diff --git a/src/generation.rs b/src/generation.rs index 072a7a1..ca27055 100644 --- a/src/generation.rs +++ b/src/generation.rs @@ -386,6 +386,17 @@ impl DeserializationCode { self.content } } + +impl From for DeserializationCode { + fn from(content: BlocksOrLines) -> Self { + Self { + content, + read_len_used: false, + throws: false, + } + } +} + /// Context as to how to generate deserialization code. /// formats as {before}{}{after} in a line within the body param, allowing freedom e.g.: /// * {let x = }{}{;} - creation of variables @@ -425,7 +436,7 @@ impl<'a> DeserializeBeforeAfter<'a> { // Result -> T (false, true) => format!("?{}", self.after), // T -> - (true, false) => format!("){}", self.before), + (true, false) => format!("){}", self.after), // expected == found, nothing to be done (false, false) | (true, true) => self.after.to_owned(), } @@ -1669,7 +1680,7 @@ impl GenerationScope { format!( "{}.iter().map(|e| {}).sum()", config.expr, - ty.definite_info("e", types, cli) + ty.definite_info("e", true, types, cli) ) } } @@ -2361,6 +2372,7 @@ impl GenerationScope { types, variant, variant_final_exprs.is_empty(), + None, &mut deser_code.content, cli, ); @@ -3089,6 +3101,11 @@ impl GenerationScope { tag: Option, cli: &Cli, ) { + // I don't believe this is even possible (wouldn't be a single CBOR value + nowhere to embed) + // Just sanity checking since it's not handled in the wrapper code here + assert!(variants + .iter() + .all(|v| !matches!(v.data, EnumVariantData::Inlined(_)))); // Rust only generate_enum(self, types, name, variants, None, true, tag, cli); if cli.wasm { @@ -3923,7 +3940,7 @@ fn create_deserialize_impls( ident: &RustIdent, rep: Option, tag: Option, - len_info: RustStructCBORLen, + len_info: Option, generate_deserialize_embedded: bool, store_encoding: Option<&str>, deser_body: &mut dyn CodeBlock, @@ -3960,7 +3977,9 @@ fn create_deserialize_impls( )); } } - add_deserialize_initial_len_check(deser_body, len_info); + if let Some(len_info) = len_info { + add_deserialize_initial_len_check(deser_body, len_info); + } if generate_deserialize_embedded { deser_body.line( "let ret = Self::deserialize_as_embedded_group(raw, &mut read_len, len);", @@ -3980,7 +3999,9 @@ fn create_deserialize_impls( )); } } - add_deserialize_initial_len_check(deser_body, len_info); + if let Some(len_info) = len_info { + add_deserialize_initial_len_check(deser_body, len_info); + } if generate_deserialize_embedded { deser_body.line( "let ret = Self::deserialize_as_embedded_group(raw, &mut read_len, len);", @@ -3993,7 +4014,9 @@ fn create_deserialize_impls( //deser_body.line("Self::deserialize_as_embedded_group(serializer)"); } let deser_embedded_impl = if generate_deserialize_embedded { - add_deserialize_final_len_check(deser_body, rep, len_info, cli); + if let Some(len_info) = len_info { + add_deserialize_final_len_check(deser_body, rep, len_info, cli); + } deser_body.line("ret"); let mut embedded_impl = codegen::Impl::new(name); embedded_impl.impl_trait("DeserializeEmbeddedGroup"); @@ -5097,7 +5120,7 @@ fn codegen_struct( name, Some(record.rep), tag, - &record.definite_info(types, cli), + &record.definite_info("self", false, types, cli), len_encoding_var .map(|var| { format!("self.encodings.as_ref().map(|encs| encs.{var}).unwrap_or_default()") @@ -5475,7 +5498,7 @@ fn codegen_struct( ser_func.line(format!( "let deser_order = self.encodings.as_ref().filter(|encs| {}encs.orig_deser_order.len() == {}).map(|encs| encs.orig_deser_order.clone()).unwrap_or_else(|| {});", check_canonical, - record.definite_info(types, cli), + record.definite_info("self", false, types, cli), serialization_order)); let mut ser_loop = Block::new("for field_index in deser_order"); let mut ser_loop_match = Block::new("match field_index"); @@ -5700,7 +5723,7 @@ fn codegen_struct( name, Some(record.rep), tag, - record.cbor_len_info(types), + Some(record.cbor_len_info(types)), types.is_plain_group(name), len_encoding_var, &mut deser_scaffolding, @@ -5826,9 +5849,10 @@ fn codegen_group_choices( }; match fields { Some(fields) => { + let inlined = matches!(&variant.data, EnumVariantData::Inlined(_)); let ctor_fields: Vec<&RustField> = fields .iter() - .filter(|f| !f.optional && !f.rust_type.is_fixed_value()) + .filter(|f| (!f.optional || inlined) && !f.rust_type.is_fixed_value()) .collect(); let can_fail = ctor_fields .iter() @@ -5865,10 +5889,12 @@ fn codegen_group_choices( } else { output_comma = true; } - new_func.arg(&field.name, field.rust_type.for_wasm_param(types)); + // always okay - if not inlined this field would be skipped earlier + assert!(!field.optional || inlined); + let wasm_param_type = field.to_embedded_rust_type(); + new_func.arg(&field.name, wasm_param_type.for_wasm_param(types)); ctor.push_str(&ToWasmBoundaryOperations::format( - field - .rust_type + wasm_param_type .from_wasm_boundary_clone(types, &field.name, false) .into_iter(), )); @@ -5957,21 +5983,48 @@ fn add_wasm_enum_getters( let mut add_variant_functions = |ty: &RustType| { let enum_gen_info = EnumVariantInRust::new(types, variant, rep, cli); let mut as_variant = codegen::Function::new(format!("as_{}", variant.name_as_var())); - as_variant - .arg_ref_self() - .vis("pub") - .ret(&format!("Option<{}>", ty.for_wasm_return(types))); + as_variant.arg_ref_self().vis("pub"); let mut variant_match = Block::new("match &self.0"); - variant_match.line(format!( - "{}::{}{} => Some({}),", - rust_crate_struct_from_wasm(types, name, cli), - variant.name, - enum_gen_info.capture_ignore_encodings(), - ty.to_wasm_boundary(types, &enum_gen_info.names[0], true) - )); - variant_match.line("_ => None,"); - as_variant.push_block(variant_match); - s_impl.push_fn(as_variant); + // unfortunately wasm_bindgen doesn't support nested options so we must flatten + // this is a bit ambiguous but it's better than nothing + let supported = if let ConceptualRustType::Optional(inner) = ty.resolve_alias_shallow() + { + if let ConceptualRustType::Optional(_) = inner.resolve_alias_shallow() { + // don't bother - it's triple nested (optional nullable field?). + // this seems to be unable to parseas ? (T / null) + // but we'll keep this here as it makes it easy to make this + // the behavior for skipping vs condensing on double nested ones (optional fields) + println!("skipping {}::as_{}() due to triple nested Options unsupported by wasm_bindgen", name, variant.name_as_var()); + false + } else { + as_variant + .ret(ty.for_wasm_return(types)) + .doc(format!("Returns None if not {} variant OR it is but it's set to None\nThis is to get around wasm_bindgen not supporting Option>", variant.name)); + variant_match.line(format!( + "{}::{}{} => {},", + rust_crate_struct_from_wasm(types, name, cli), + variant.name, + enum_gen_info.capture_ignore_encodings(), + ty.to_wasm_boundary(types, &enum_gen_info.names[0], true) + )); + true + } + } else { + as_variant.ret(&format!("Option<{}>", ty.for_wasm_return(types))); + variant_match.line(format!( + "{}::{}{} => Some({}),", + rust_crate_struct_from_wasm(types, name, cli), + variant.name, + enum_gen_info.capture_ignore_encodings(), + ty.to_wasm_boundary(types, &enum_gen_info.names[0], true) + )); + true + }; + if supported { + variant_match.line("_ => None,"); + as_variant.push_block(variant_match); + s_impl.push_fn(as_variant); + } }; match &variant.data { EnumVariantData::RustType(ty) => { @@ -5992,7 +6045,7 @@ fn add_wasm_enum_getters( "multiple non-fixed not allowed right now for embedding into enums" ); if let Some(&field) = non_fixed_types.first() { - add_variant_functions(&field.rust_type); + add_variant_functions(field.to_embedded_rust_type().as_ref()); } } } @@ -6096,7 +6149,11 @@ impl EnumVariantInRust { for field in record.fields.iter() { if !field.rust_type.is_fixed_value() { names.push(field.name.clone()); - enum_types.push(field.rust_type.for_rust_member(types, false, cli)); + enum_types.push( + field + .to_embedded_rust_type() + .for_rust_member(types, false, cli), + ); } } for enc_field in &enc_fields { @@ -6262,27 +6319,41 @@ fn make_enum_variant_return_if_deserialized( types: &IntermediateTypes, variant: &EnumVariant, no_enum_types: bool, + len_check: Option<(RustStructCBORLen, Representation)>, deser_body: &mut dyn CodeBlock, cli: &Cli, ) -> Block { + let (before, after) = if len_check.is_some() && !no_enum_types { + ("let ret = ", ";") + } else { + ("", "") + }; let variant_deser_code = if no_enum_types { let mut code = gen_scope.generate_deserialize( types, (variant.rust_type()).into(), - DeserializeBeforeAfter::new("", "", false), + DeserializeBeforeAfter::new(before, after, false), DeserializeConfig::new(&variant.name_as_var()), cli, ); + if let Some((len_info, rep)) = len_check { + code = surround_in_len_checks(code, len_info, rep, cli); + } code.content.line("Ok(())"); code } else { - gen_scope.generate_deserialize( + let mut code = gen_scope.generate_deserialize( types, (variant.rust_type()).into(), - DeserializeBeforeAfter::new("", "", true), + DeserializeBeforeAfter::new(before, after, true), DeserializeConfig::new(&variant.name_as_var()), cli, - ) + ); + if let Some((len_info, rep)) = len_check { + code = surround_in_len_checks(code, len_info, rep, cli); + code.content.line("ret"); + } + code }; match variant_deser_code.content.as_single_line() { Some(single_line) if !variant_deser_code.throws => { @@ -6305,6 +6376,20 @@ fn make_enum_variant_return_if_deserialized( } } +fn surround_in_len_checks( + mut main_deser_code: DeserializationCode, + len_info: RustStructCBORLen, + rep: Representation, + cli: &Cli, +) -> DeserializationCode { + let mut len_check_before = DeserializationCode::default(); + add_deserialize_initial_len_check(&mut len_check_before.content, len_info); + main_deser_code.add_to_code(&mut len_check_before); + main_deser_code = len_check_before; + add_deserialize_final_len_check(&mut main_deser_code.content, Some(rep), len_info, cli); + main_deser_code +} + fn make_inline_deser_code( gen_scope: &mut GenerationScope, types: &IntermediateTypes, @@ -6317,12 +6402,6 @@ fn make_inline_deser_code( let mut variant_deser_code = generate_array_struct_deserialization( gen_scope, types, name, record, tag, false, false, cli, ); - add_deserialize_final_len_check( - &mut variant_deser_code.deser_code.content, - Some(record.rep), - record.cbor_len_info(types), - cli, - ); // generate_constructor zips the expressions with the names in the enum_gen_info // so just make sure we're in the same order as returned above assert_eq!( @@ -6340,6 +6419,12 @@ fn make_inline_deser_code( expr }) .collect(); + variant_deser_code.deser_code = surround_in_len_checks( + variant_deser_code.deser_code, + record.cbor_len_info(types), + record.rep, + cli, + ); enum_gen_info.generate_constructor( &mut variant_deser_code.deser_code.content, "Ok(", @@ -6391,9 +6476,6 @@ fn generate_enum( ser_func.line(format!("serializer.write_tag({tag}u64)?;")); } let mut ser_array_match_block = Block::new("match self"); - // we use Dynamic to avoid having any length checks here since we don't know what they are yet without testing the variants - // and it's not worth looking into and complicating things on the off chance that all variants are the same length. - let len_info = RustStructCBORLen::Dynamic; let mut deser_func = make_deserialization_function("deserialize"); let mut error_annotator = make_err_annotate_block(name.as_ref(), "", ""); let deser_body: &mut dyn CodeBlock = if cli.annotate_fields { @@ -6423,7 +6505,7 @@ fn generate_enum( name, rep, tag, - len_info, + None, false, outer_encoding_var, deser_body, @@ -6434,11 +6516,13 @@ fn generate_enum( // We avoid checking ALL variants if we can figure it out by instead checking the type. // This only works when the variants don't have first types in common. let mut non_overlapping_types_match = { - // uses to_byte() instead of directly since Ord not implemented for cbor_event::Type let mut first_types = BTreeSet::new(); let mut duplicates = false; for variant in variants.iter() { for first_type in variant.cbor_types(types) { + // to_byte(0) is used since cbor_event::Type doesn't implement + // Ord or Hash so we can't put it in a set. Since we fix the lenth + // to always 0 this still remains a 1-to-1 mapping to Type. if !first_types.insert(first_type.to_byte(0)) { duplicates = true; } @@ -6567,7 +6651,10 @@ fn generate_enum( .iter() .filter(|field| !field.rust_type.is_fixed_value()) .map(|field| { - new_func.arg(&field.name, field.rust_type.for_rust_move(types, cli)); + new_func.arg( + &field.name, + field.to_embedded_rust_type().for_rust_move(types, cli), + ); field.name.clone() }) .collect(); @@ -6717,7 +6804,7 @@ fn generate_enum( rep.expect("can't inline in type choices"), "serializer", "len_encoding", - &record.definite_info(types, cli), + &record.definite_info("", true, types, cli), cli, ); generate_array_struct_serialization( @@ -6818,12 +6905,23 @@ fn generate_enum( } None => { let mut return_if_deserialized = match &variant.data { - EnumVariantData::RustType(_) => { + EnumVariantData::RustType(ty) => { let mut return_if_deserialized = make_enum_variant_return_if_deserialized( gen_scope, types, variant, enum_gen_info.types.is_empty(), + rep.map(|r| { + let len_info = match ty.conceptual_type.resolve_alias_shallow() { + ConceptualRustType::Rust(ident) + if types.is_plain_group(ident) => + { + types.rust_struct(ident).unwrap().cbor_len_info(types) + } + _ => RustStructCBORLen::Fixed(1), + }; + (len_info, r) + }), deser_body, cli, ); @@ -6882,16 +6980,6 @@ fn generate_enum( } ser_func.push_block(ser_array_match_block); ser_impl.push_fn(ser_func); - // TODO: we pass in a dummy Fixed to avoid the check since we don't support optional fields for plain groups - // which causes an issue with group choices of plain groups where if we generate_deserialize() with - // optional_field = true then we hit asserts (not supported) and - // optional_field = false causes this final check to fail since no elements were read - // A possible workaround for this could be to read it beforehand if possible but - // that gets complicated for optional fields inside those plain groups so we'll - // just avoid this check instead for this one case. - // This can cause issues when there are overlapping (CBOR field-wise) variants inlined here. - // Issue: https://github.com/dcSpark/cddl-codegen/issues/175 - add_deserialize_final_len_check(deser_body, rep, RustStructCBORLen::Fixed(0), cli); match non_overlapping_types_match { Some((mut deser_type_match, deser_covers_all_types)) => { if !deser_covers_all_types { diff --git a/src/intermediate.rs b/src/intermediate.rs index acf9d84..1c30060 100644 --- a/src/intermediate.rs +++ b/src/intermediate.rs @@ -1,6 +1,7 @@ use cbor_event::Special as CBORSpecial; use cbor_event::{Special, Type as CBORType}; use cddl::ast::parent::ParentVisitor; +use std::borrow::Cow; use std::collections::{BTreeMap, BTreeSet}; // TODO: move all of these generation specifics into generation.rs @@ -1222,23 +1223,31 @@ impl RustType { } } - // See comment in RustStruct::definite_info(), this is the same, returns a string expression - // which evaluates to the length. - // self_expr is an expression that evaluates to this RustType (e.g. member, etc) at the point where - // the return of this function will be used. - pub fn definite_info(&self, self_expr: &str, types: &IntermediateTypes, cli: &Cli) -> String { + /// See comment in RustStruct::definite_info(), this is the same, returns a string expression + /// which evaluates to the length. + /// self_expr is an expression that evaluates to this RustType (e.g. member, etc) at the point where + /// the return of this function will be used. + /// self_is_ref whether the above expression is by-ref + pub fn definite_info( + &self, + self_expr: &str, + self_is_ref: bool, + types: &IntermediateTypes, + cli: &Cli, + ) -> String { match self.expanded_field_count(types) { Some(count) => count.to_string(), None => match self.conceptual_type.resolve_alias_shallow() { ConceptualRustType::Optional(ty) => format!( - "match {} {{ Some(x) => {}, None => 1 }}", + "match {}{} {{ Some(x) => {}, None => 1 }}", + if self_is_ref { "" } else { "&" }, self_expr, - ty.definite_info("x", types, cli) + ty.definite_info("x", true, types, cli) ), ConceptualRustType::Rust(ident) => { if types.is_plain_group(ident) { match types.rust_structs.get(ident) { - Some(rs) => rs.definite_info(types, cli), + Some(rs) => rs.definite_info(self_expr, self_is_ref, types, cli), None => panic!( "rust struct {} not found but referenced by {:?}", ident, self @@ -2104,6 +2113,16 @@ impl RustField { key, } } + + pub fn to_embedded_rust_type(&self) -> Cow { + if self.optional { + Cow::Owned(RustType::new(ConceptualRustType::Optional(Box::new( + self.rust_type.clone(), + )))) + } else { + Cow::Borrowed(&self.rust_type) + } + } } #[derive(Clone, Debug, Copy)] @@ -2310,14 +2329,25 @@ impl RustStruct { } } - // Even if fixed_field_count() == None, this will return an expression for - // a definite length, e.g. with optional field checks in the expression - // This is useful for definite-length serialization - pub fn definite_info(&self, types: &IntermediateTypes, cli: &Cli) -> String { + /// Even if fixed_field_count() == None, this will return an expression for + /// a definite length, e.g. with optional field checks in the expression + /// This is useful for definite-length serialization + /// self_expr is an expression that evaluates to this struct (e.g. "self") at the point where + /// the return of this function will be used. + /// self_is_ref whether the above expression is by-ref + pub fn definite_info( + &self, + self_expr: &str, + self_is_ref: bool, + types: &IntermediateTypes, + cli: &Cli, + ) -> String { match &self.variant { - RustStructType::Record(record) => record.definite_info(types, cli), - RustStructType::Table { .. } => String::from("self.0.len() as u64"), - RustStructType::Array { .. } => String::from("self.0.len() as u64"), + RustStructType::Record(record) => { + record.definite_info(self_expr, self_is_ref, types, cli) + } + RustStructType::Table { .. } => format!("{self_expr}.0.len() as u64"), + RustStructType::Array { .. } => format!("{self_expr}.0.len() as u64"), RustStructType::TypeChoice { .. } => { unreachable!("I don't think type choices should be using length?") } @@ -2357,7 +2387,7 @@ impl RustStruct { } } - fn _cbor_len_info(&self, types: &IntermediateTypes) -> RustStructCBORLen { + pub fn cbor_len_info(&self, types: &IntermediateTypes) -> RustStructCBORLen { match &self.variant { RustStructType::Record(record) => record.cbor_len_info(types), RustStructType::Table { .. } => RustStructCBORLen::Dynamic, @@ -2451,8 +2481,19 @@ impl RustRecord { Some(count) } - // This is guaranteed - pub fn definite_info(&self, types: &IntermediateTypes, cli: &Cli) -> String { + /// This is guaranteed + /// If inlined_enum is set, assumes the field is accessible via a local reference e.g. match branch + /// Otherwise assumes it's a field e.g. self.name + /// self_expr is an expression that evaluates to this struct (e.g. "self") at the point where + /// the return of this function will be used. + /// self_is_ref whether the above expression is by-ref + pub fn definite_info( + &self, + self_expr: &str, + self_is_ref: bool, + types: &IntermediateTypes, + cli: &Cli, + ) -> String { match self.fixed_field_count(types) { Some(count) => count.to_string(), None => { @@ -2467,33 +2508,53 @@ impl RustRecord { if !conditional_field_expr.is_empty() { conditional_field_expr.push_str(" + "); } - let (field_expr, field_contribution) = match self.rep { - Representation::Array => { - ("x", field.rust_type.definite_info("x", types, cli)) - } - // maps are defined by their keys instead (although they shouldn't have multi-length values either...) - Representation::Map => ("_", String::from("1")), + let self_field_expr = if self_expr.is_empty() { + Cow::Borrowed(&field.name) + } else { + Cow::Owned(format!("{}.{}", self_expr, field.name)) }; if let Some(default_value) = &field.rust_type.config.default { + let field_contribution = match self.rep { + Representation::Array => Cow::Owned(field.rust_type.definite_info( + &self_field_expr, + true, + types, + cli, + )), + // maps are defined by their keys instead (although they shouldn't have multi-length values either...) + Representation::Map => Cow::Borrowed("1"), + }; if cli.preserve_encodings { conditional_field_expr.push_str(&format!( - "if self.{} != {} || self.encodings.as_ref().map(|encs| encs.{}_default_present).unwrap_or(false) {{ {} }} else {{ 0 }}", + "if {}.{} != {} || self.encodings.as_ref().map(|encs| encs.{}_default_present).unwrap_or(false) {{ {} }} else {{ 0 }}", + self_expr, field.name, default_value.to_primitive_str_compare(), field.name, field_contribution)); } else { conditional_field_expr.push_str(&format!( - "if self.{} != {} {{ {} }} else {{ 0 }}", + "if {}.{} != {} {{ {} }} else {{ 0 }}", + self_expr, field.name, default_value.to_primitive_str_compare(), field_contribution )); } } else { + let (field_expr, field_contribution) = match self.rep { + Representation::Array => { + ("x", field.rust_type.definite_info("x", true, types, cli)) + } + // maps are defined by their keys instead (although they shouldn't have multi-length values either...) + Representation::Map => ("_", String::from("1")), + }; conditional_field_expr.push_str(&format!( - "match &self.{} {{ Some({}) => {}, None => 0 }}", - field.name, field_expr, field_contribution + "match {}{} {{ Some({}) => {}, None => 0 }}", + if self_is_ref { "" } else { "&" }, + self_field_expr, + field_expr, + field_contribution )); } } else { @@ -2509,6 +2570,7 @@ impl RustRecord { } let field_len_expr = field.rust_type.definite_info( &format!("self.{}", field.name), + false, types, cli, ); diff --git a/tests/core/input.cddl b/tests/core/input.cddl index f16560f..eb9ed93 100644 --- a/tests/core/input.cddl +++ b/tests/core/input.cddl @@ -171,4 +171,23 @@ bounds_group_choice = [ hash // ; @name c 1, x: hash, y: hash -] \ No newline at end of file +] + +enum_opt_embed_fields = [ + ; @name ea + 1 // + ; @name eb + 1, ?text, 5 // + ; @name ec + 1, uint, 7 // +; doesn't parse but would result in triple nesting so worth testing if we can ever parse it +; 1, ? (text / null), #6.9(9) + ; @name ed + 1, uint, ?text // + ; @name ee + 1, uint, ?bytes, uint // + ; @name ef + 1, ? non_overlapping_type_choice_some, #6.11(11) // + ; @name eg + 1, ? overlapping_inlined, #6.13(13) +] diff --git a/tests/core/tests.rs b/tests/core/tests.rs index 904ab76..256dbae 100644 --- a/tests/core/tests.rs +++ b/tests/core/tests.rs @@ -215,13 +215,12 @@ mod tests { #[test] fn overlapping_inlined() { - // this test won't work until https://github.com/dcSpark/cddl-codegen/issues/175 is resolved. let overlap0 = OverlappingInlined::new_one(); deser_test(&overlap0); let overlap1 = OverlappingInlined::new_two(9); - //deser_test(&overlap1); + deser_test(&overlap1); let overlap2 = OverlappingInlined::new_three(5, "overlapping".into()); - //deser_test(&overlap2); + deser_test(&overlap2); } #[test] @@ -392,4 +391,40 @@ mod tests { let mut set_group_choice: std::collections::BTreeSet = std::collections::BTreeSet::new(); set_group_choice.insert(GroupChoice::GroupChoice1(37)); } + + #[test] + fn enum_opt_embed_fields() { + let a = EnumOptEmbedFields::new_ea(); + deser_test(&a); + let b1 = EnumOptEmbedFields::new_eb(Some("Hello".to_owned())); + deser_test(&b1); + let b2 = EnumOptEmbedFields::new_eb(None); + deser_test(&b2); + let c = EnumOptEmbedFields::new_ec(100); + deser_test(&c); + let mut d1 = EnumOptEmbedFields::new_ed(1); + match &mut d1 { + EnumOptEmbedFields::Ed(ed) => ed.index_2 = Some("Goodbye".to_owned()), + _ => panic!(), + } + deser_test(&d1); + let d2 = EnumOptEmbedFields::new_ed(2); + deser_test(&d2); + let mut e1 = EnumOptEmbedFields::new_ee(0, 0); + match &mut e1 { + EnumOptEmbedFields::Ee(ee) => ee.index_2 = Some(vec![0xBA, 0xAD, 0xF0, 0x0D]), + _ => panic!(), + } + deser_test(&e1); + let e2 = EnumOptEmbedFields::new_ee(u64::MAX, u64::MAX); + deser_test(&e2); + let f1 = EnumOptEmbedFields::new_ef(Some(NonOverlappingTypeChoiceSome::U64(5))); + deser_test(&f1); + let f2 = EnumOptEmbedFields::new_ef(None); + deser_test(&f2); + let g1 = EnumOptEmbedFields::new_eg(Some(OverlappingInlined::new_two(0))); + deser_test(&g1); + let g2 = EnumOptEmbedFields::new_eg(None); + deser_test(&g2); + } } diff --git a/tests/preserve-encodings/input.cddl b/tests/preserve-encodings/input.cddl index 7dd2065..4d16d53 100644 --- a/tests/preserve-encodings/input.cddl +++ b/tests/preserve-encodings/input.cddl @@ -111,4 +111,32 @@ bounds_group_choice = [ hash // ; @name c 1, x: hash, y: hash +] + +overlapping_inlined = [ + ; @name one + 0 // + ; @name two + 0, uint // + ; @name three + 0, uint, text +] + +enum_opt_embed_fields = [ + ; @name ea + 1 // + ; @name eb + 1, ?text, 5 // + ; @name ec + 1, uint, 7 // +; doesn't parse but would result in triple nesting so worth testing if we can ever parse it +; 1, ? (text / null), #6.9(9) + ; @name ed + 1, uint, ?text // + ; @name ee + 1, uint, ?bytes, uint // + ; @name ef + 1, ? non_overlapping_type_choice_some, #6.11(11) // + ; @name eg + 1, ? overlapping_inlined, #6.13(13) ] \ No newline at end of file diff --git a/tests/preserve-encodings/tests.rs b/tests/preserve-encodings/tests.rs index da2b780..fb74116 100644 --- a/tests/preserve-encodings/tests.rs +++ b/tests/preserve-encodings/tests.rs @@ -366,6 +366,49 @@ mod tests { } } + #[test] + fn overlapping_inlined() { + let def_encodings = vec![Sz::Inline, Sz::One, Sz::Two, Sz::Four, Sz::Eight]; + let str_11_encodings = vec![ + StringLenSz::Len(Sz::One), + StringLenSz::Len(Sz::Inline), + StringLenSz::Indefinite(vec![(5, Sz::Two), (6, Sz::One)]), + StringLenSz::Indefinite(vec![(2, Sz::Inline), (0, Sz::Inline), (9, Sz::Four)]), + ]; + for def_enc in &def_encodings { + // one + let irregular_bytes_one = vec![ + arr_sz(1, *def_enc), + cbor_int(0, *def_enc), + ].into_iter().flatten().clone().collect::>(); + let irregular_one = OverlappingInlined::from_cbor_bytes(&irregular_bytes_one).unwrap(); + assert_eq!(irregular_bytes_one, irregular_one.to_cbor_bytes()); + assert!(matches!(irregular_one, OverlappingInlined::One { .. })); + // two + let irregular_bytes_two = vec![ + vec![ARR_INDEF], + cbor_int(0, *def_enc), + cbor_int(u64::MAX as i128, Sz::Eight), + vec![BREAK], + ].into_iter().flatten().clone().collect::>(); + let irregular_two = OverlappingInlined::from_cbor_bytes(&irregular_bytes_two).unwrap(); + assert_eq!(irregular_bytes_two, irregular_two.to_cbor_bytes()); + assert!(matches!(irregular_two, OverlappingInlined::Two { .. })); + for str_enc in &str_11_encodings { + // three + let irregular_bytes_three = vec![ + arr_sz(3, *def_enc), + cbor_int(0, *def_enc), + cbor_int(0, *def_enc), + cbor_str_sz("overlapping", str_enc.clone()), + ].into_iter().flatten().clone().collect::>(); + let irregular_three = OverlappingInlined::from_cbor_bytes(&irregular_bytes_three).unwrap(); + assert_eq!(irregular_bytes_three, irregular_three.to_cbor_bytes()); + assert!(matches!(irregular_three, OverlappingInlined::Three { .. })); + } + } + } + #[test] fn non_overlapping_type_choice_some() { let def_encodings = vec![Sz::Inline, Sz::One, Sz::Two, Sz::Four, Sz::Eight]; @@ -808,4 +851,150 @@ mod tests { let mut set_non_overlap: std::collections::HashSet = std::collections::HashSet::new(); set_non_overlap.insert(NonOverlappingTypeChoiceSome::new_uint(0)); } + + #[test] + fn enum_opt_embed_fields() { + let def_encodings = vec![Sz::Inline, Sz::One, Sz::Two, Sz::Four, Sz::Eight]; + let str_3_encodings = vec![ + StringLenSz::Len(Sz::Eight), + StringLenSz::Len(Sz::Inline), + StringLenSz::Indefinite(vec![(1, Sz::Two), (2, Sz::One)]), + StringLenSz::Indefinite(vec![(2, Sz::Inline), (0, Sz::Inline), (1, Sz::Four)]), + ]; + for str_enc in &str_3_encodings { + for def_enc in &def_encodings { + for opt_present in [false, true] { + // a + let irregular_bytes_a = vec![ + vec![ARR_INDEF], + cbor_int(1, *def_enc), + vec![BREAK], + ].into_iter().flatten().clone().collect::>(); + let irregular_a = EnumOptEmbedFields::from_cbor_bytes(&irregular_bytes_a).unwrap(); + assert_eq!(irregular_bytes_a, irregular_a.to_cbor_bytes()); + assert!(matches!(irregular_a, EnumOptEmbedFields::Ea { .. })); + // b (Some) + let irregular_bytes_b1 = vec![ + vec![ARR_INDEF], + cbor_int(1, *def_enc), + cbor_str_sz("foo", str_enc.clone()), + cbor_int(5, *def_enc), + vec![BREAK], + ].into_iter().flatten().clone().collect::>(); + let irregular_b1 = EnumOptEmbedFields::from_cbor_bytes(&irregular_bytes_b1).unwrap(); + assert_eq!(irregular_bytes_b1, irregular_b1.to_cbor_bytes()); + assert!(matches!(irregular_b1, EnumOptEmbedFields::Eb { .. })); + // b (None) + let irregular_bytes_b2 = vec![ + arr_sz(2, *def_enc), + cbor_int(1, *def_enc), + cbor_int(5, *def_enc), + ].into_iter().flatten().clone().collect::>(); + let irregular_b2 = EnumOptEmbedFields::from_cbor_bytes(&irregular_bytes_b2).unwrap(); + assert_eq!(irregular_bytes_b2, irregular_b2.to_cbor_bytes()); + assert!(matches!(irregular_b2, EnumOptEmbedFields::Eb { .. })); + // c + let irregular_bytes_c = vec![ + vec![ARR_INDEF], + cbor_int(1, *def_enc), + cbor_int(u64::MAX as i128, Sz::Eight), + cbor_int(7, *def_enc), + vec![BREAK], + ].into_iter().flatten().clone().collect::>(); + let irregular_c = EnumOptEmbedFields::from_cbor_bytes(&irregular_bytes_c).unwrap(); + assert_eq!(irregular_bytes_c, irregular_c.to_cbor_bytes()); + assert!(matches!(irregular_c, EnumOptEmbedFields::Ec { .. })); + // d (Some) + let irregular_bytes_d1 = vec![ + arr_sz(3, *def_enc), + cbor_int(1, *def_enc), + cbor_int(0, *def_enc), + cbor_str_sz("bar", str_enc.clone()), + ].into_iter().flatten().clone().collect::>(); + let irregular_d1 = EnumOptEmbedFields::from_cbor_bytes(&irregular_bytes_d1).unwrap(); + assert_eq!(irregular_bytes_d1, irregular_d1.to_cbor_bytes()); + assert!(matches!(irregular_d1, EnumOptEmbedFields::Ed { .. })); + // d (None) + let irregular_bytes_d2 = vec![ + vec![ARR_INDEF], + cbor_int(1, *def_enc), + cbor_int(u64::MAX as i128, Sz::Eight), + vec![BREAK], + ].into_iter().flatten().clone().collect::>(); + let irregular_d2 = EnumOptEmbedFields::from_cbor_bytes(&irregular_bytes_d2).unwrap(); + assert_eq!(irregular_bytes_d2, irregular_d2.to_cbor_bytes()); + assert!(matches!(irregular_d2, EnumOptEmbedFields::Ed { .. })); + // e (Some) + let irregular_bytes_e1 = vec![ + vec![ARR_INDEF], + cbor_int(1, *def_enc), + cbor_int(0, *def_enc), + cbor_bytes_sz(vec![0x00, 0x01, 0x02], str_enc.clone()), + cbor_int(u64::MAX as i128, Sz::Eight), + vec![BREAK], + ].into_iter().flatten().clone().collect::>(); + let irregular_e1 = EnumOptEmbedFields::from_cbor_bytes(&irregular_bytes_e1).unwrap(); + assert_eq!(irregular_bytes_e1, irregular_e1.to_cbor_bytes()); + assert!(matches!(irregular_e1, EnumOptEmbedFields::Ee { .. })); + // e (None) + let irregular_bytes_e2 = vec![ + arr_sz(3, *def_enc), + cbor_int(1, *def_enc), + cbor_int(u64::MAX as i128, Sz::Eight), + cbor_int(0, *def_enc), + ].into_iter().flatten().clone().collect::>(); + let irregular_e2 = EnumOptEmbedFields::from_cbor_bytes(&irregular_bytes_e2).unwrap(); + assert_eq!(irregular_bytes_e2, irregular_e2.to_cbor_bytes()); + assert!(matches!(irregular_e2, EnumOptEmbedFields::Ee { .. })); + // f (Some) + let irregular_bytes_f1 = vec![ + arr_sz(3, *def_enc), + cbor_int(1, *def_enc), + cbor_int(u64::MAX as i128, Sz::Eight), + cbor_tag_sz(11, *def_enc), + cbor_int(11, *def_enc), + ].into_iter().flatten().clone().collect::>(); + let irregular_f1 = EnumOptEmbedFields::from_cbor_bytes(&irregular_bytes_f1).unwrap(); + assert_eq!(irregular_bytes_f1, irregular_f1.to_cbor_bytes()); + assert!(matches!(irregular_f1, EnumOptEmbedFields::Ef { .. })); + // f (None) + let irregular_bytes_f2 = vec![ + vec![ARR_INDEF], + cbor_int(1, *def_enc), + cbor_tag_sz(11, *def_enc), + cbor_int(11, *def_enc), + vec![BREAK], + ].into_iter().flatten().clone().collect::>(); + let irregular_f2 = EnumOptEmbedFields::from_cbor_bytes(&irregular_bytes_f2).unwrap(); + assert_eq!(irregular_bytes_f2, irregular_f2.to_cbor_bytes()); + assert!(matches!(irregular_f2, EnumOptEmbedFields::Ef { .. })); + // g (Some) + let irregular_bytes_g1 = vec![ + vec![ARR_INDEF], + cbor_int(1, *def_enc), + arr_sz(3, *def_enc), + cbor_int(0, *def_enc), + cbor_int(3, *def_enc), + cbor_str_sz("xyz", str_enc.clone()), + cbor_tag_sz(13, *def_enc), + cbor_int(13, *def_enc), + vec![BREAK], + ].into_iter().flatten().clone().collect::>(); + let irregular_g1 = EnumOptEmbedFields::from_cbor_bytes(&irregular_bytes_g1).unwrap(); + assert_eq!(irregular_bytes_g1, irregular_g1.to_cbor_bytes()); + assert!(matches!(irregular_g1, EnumOptEmbedFields::Eg { .. })); + // g (None) + let irregular_bytes_g2 = vec![ + arr_sz(2, *def_enc), + cbor_int(1, *def_enc), + cbor_tag_sz(13, *def_enc), + cbor_int(13, *def_enc), + ].into_iter().flatten().clone().collect::>(); + let irregular_g2 = EnumOptEmbedFields::from_cbor_bytes(&irregular_bytes_g2).unwrap(); + assert_eq!(irregular_bytes_g2, irregular_g2.to_cbor_bytes()); + assert!(matches!(irregular_g2, EnumOptEmbedFields::Eg { .. })); + } + } + } + } }