diff --git a/binrw/tests/error.rs b/binrw/tests/error.rs index e2593ac4..3f6b5446 100644 --- a/binrw/tests/error.rs +++ b/binrw/tests/error.rs @@ -241,7 +241,7 @@ fn no_seek_data_enum() { (BacktraceFrame::Message(m), BacktraceFrame::Custom(e)) => { assert_eq!(m, "rewinding after a failure"); match e.downcast_ref::() { - Some(binrw::Error::AssertFail { pos, .. }) => assert_eq!(*pos, 0), + Some(binrw::Error::AssertFail { pos, .. }) => assert_eq!(*pos, 1), e => panic!("unexpected error {:?}", e), } } diff --git a/binrw_derive/src/binrw/codegen/mod.rs b/binrw_derive/src/binrw/codegen/mod.rs index a3552e9e..d5a6af0c 100644 --- a/binrw_derive/src/binrw/codegen/mod.rs +++ b/binrw_derive/src/binrw/codegen/mod.rs @@ -5,7 +5,8 @@ mod write_options; use crate::{ binrw::parser::{ - Assert, AssertionError, CondEndian, Imports, Input, ParseResult, PassedArgs, StructField, + Assert, AssertionError, CondEndian, Imports, Input, Map, ParseResult, PassedArgs, + StructField, }, named_args::{arg_type_name, derive_from_imports}, util::{quote_spanned_any, IdentStr}, @@ -14,8 +15,8 @@ use proc_macro2::{Span, TokenStream}; use quote::{quote, quote_spanned, ToTokens}; use sanitization::{ ARGS, ARGS_LIFETIME, ARGS_MACRO, ASSERT, ASSERT_ERROR_FN, BINREAD_TRAIT, BINWRITE_TRAIT, - BIN_ERROR, BIN_RESULT, ENDIAN_ENUM, OPT, POS, READER, READ_TRAIT, SEEK_TRAIT, TEMP, WRITER, - WRITE_TRAIT, + BIN_ERROR, BIN_RESULT, ENDIAN_ENUM, OPT, POS, READER, READ_TRAIT, RESTORE_POSITION, SEEK_TRAIT, + TEMP, WRITER, WRITE_TRAIT, }; use syn::{spanned::Spanned, DeriveInput, Ident, Type}; @@ -194,11 +195,47 @@ fn generate_trait_impl( } } +#[derive(Clone)] +#[must_use] +struct PosEmitter { + used: core::cell::Cell, + stream_var: TokenStream, +} + +impl PosEmitter { + fn new(stream_var: &TokenStream) -> Self { + Self { + used: core::cell::Cell::new(false), + stream_var: stream_var.clone(), + } + } + + fn finish(self) -> TokenStream { + self.used + .get() + .then(|| { + let var = self.stream_var; + quote! { + let #POS = #SEEK_TRAIT::stream_position(#var)?; + } + }) + .unwrap_or_default() + } + + fn pos(&self) -> &'static IdentStr { + self.used.set(true); + &POS + } +} + fn get_args_lifetime(span: proc_macro2::Span) -> syn::Lifetime { syn::Lifetime::new(&format!("'{ARGS_LIFETIME}"), span) } -fn get_assertions(assertions: &[Assert]) -> impl Iterator + '_ { +fn get_assertions<'a>( + pos_emitter: &'a PosEmitter, + assertions: &'a [Assert], +) -> impl Iterator + 'a { assertions.iter().map( |Assert { kw_span, @@ -215,8 +252,10 @@ fn get_assertions(assertions: &[Assert]) -> impl Iterator + } }; + let pos = pos_emitter.pos(); + quote_spanned_any! {*kw_span=> - #ASSERT(#condition, #POS, #error_fn)?; + #ASSERT(#condition, #pos, #error_fn)?; } }, ) @@ -303,6 +342,32 @@ fn get_passed_args(field: &StructField, stream: IdentStr) -> Option } } +fn get_rewind( + input: &Input, + var: &TokenStream, + pos_emitter: PosEmitter, +) -> (TokenStream, TokenStream) { + let needs_rewind = match input.map() { + Map::None => match input { + Input::UnitStruct(_) | Input::Enum(_) => input.magic().is_some(), + Input::Struct(_) => true, + Input::UnitOnlyEnum(e) => e.map.as_repr().is_some(), + }, + Map::Try(_) | Map::Map(_) | Map::Repr(_) => true, + }; + + let rewind = needs_rewind + .then(|| { + let pos = pos_emitter.pos(); + quote! { + .or_else(#RESTORE_POSITION::(#var, #pos)) + } + }) + .unwrap_or_default(); + + (pos_emitter.finish(), rewind) +} + fn get_try_calc(pos: IdentStr, ty: &Type, calc: &TokenStream) -> TokenStream { let map_err = get_map_err(pos, calc.span()); quote_spanned! {ty.span()=> { diff --git a/binrw_derive/src/binrw/codegen/read_options.rs b/binrw_derive/src/binrw/codegen/read_options.rs index ac5e5624..226feb34 100644 --- a/binrw_derive/src/binrw/codegen/read_options.rs +++ b/binrw_derive/src/binrw/codegen/read_options.rs @@ -2,14 +2,13 @@ mod r#enum; mod map; mod r#struct; -use super::{get_assertions, get_destructured_imports}; +use super::{get_assertions, get_destructured_imports, PosEmitter}; use crate::{ binrw::{ codegen::{ get_endian, sanitization::{ - ARGS, ASSERT_MAGIC, MAP_READER_TYPE_HINT, OPT, POS, READER, RESTORE_POSITION, - SEEK_TRAIT, + ARGS, ASSERT_MAGIC, MAP_READER_TYPE_HINT, OPT, POS, READER, SEEK_TRAIT, }, }, parser::{Input, Magic, Map}, @@ -23,43 +22,33 @@ use r#struct::{generate_struct, generate_unit_struct}; use syn::{spanned::Spanned, Ident}; pub(crate) fn generate(input: &Input, derive_input: &syn::DeriveInput) -> TokenStream { + let reader_var = input.stream_ident_or(READER); let name = Some(&derive_input.ident); - let (inner, needs_rewind) = match input.map() { + let pos_emitter = PosEmitter::new(&reader_var); + let inner = match input.map() { Map::None => match input { - Input::UnitStruct(_) => (generate_unit_struct(input, name, None), false), - Input::Struct(s) => (generate_struct(input, name, s), true), - Input::Enum(e) => (generate_data_enum(input, name, e), false), - Input::UnitOnlyEnum(e) => ( - generate_unit_enum(input, name, e), - e.map.as_repr().is_some(), - ), + Input::UnitStruct(_) => generate_unit_struct(input, name, None, &pos_emitter), + Input::Struct(s) => generate_struct(input, name, s, &pos_emitter), + Input::Enum(e) => generate_data_enum(input, name, e, &pos_emitter), + Input::UnitOnlyEnum(e) => generate_unit_enum(input, name, e, &pos_emitter), }, - Map::Try(map) => (map::generate_try_map(input, name, map), true), - Map::Map(map) => (map::generate_map(input, name, map), true), + Map::Try(map) => map::generate_try_map(input, name, map, &pos_emitter), + Map::Map(map) => map::generate_map(input, name, map, &pos_emitter), Map::Repr(ty) => match input { - Input::UnitOnlyEnum(e) => (generate_unit_enum(input, name, e), true), - _ => ( - map::generate_try_map( - input, - name, - "e! { <#ty as core::convert::TryInto<_>>::try_into }, - ), - true, + Input::UnitOnlyEnum(e) => generate_unit_enum(input, name, e, &pos_emitter), + _ => map::generate_try_map( + input, + name, + "e! { <#ty as core::convert::TryInto<_>>::try_into }, + &pos_emitter, ), }, }; - let reader_var = input.stream_ident_or(READER); - - let rewind = (needs_rewind || input.magic().is_some()).then(|| { - quote! { - .or_else(#RESTORE_POSITION::(#reader_var, #POS)) - } - }); - + let (set_pos, rewind) = super::get_rewind(input, &reader_var, pos_emitter); quote! { let #reader_var = #READER; - let #POS = #SEEK_TRAIT::stream_position(#reader_var)?; + #set_pos (|| { #inner })()#rewind @@ -70,15 +59,17 @@ struct PreludeGenerator<'input> { input: &'input Input, reader_var: TokenStream, out: TokenStream, + pos_emitter: &'input PosEmitter, } impl<'input> PreludeGenerator<'input> { - fn new(input: &'input Input) -> Self { + fn new(input: &'input Input, pos_emitter: &'input PosEmitter) -> Self { let reader_var = input.stream_ident_or(READER); Self { input, reader_var, out: TokenStream::new(), + pos_emitter, } } @@ -111,7 +102,7 @@ impl<'input> PreludeGenerator<'input> { fn add_magic_pre_assertion(mut self) -> Self { let head = self.out; let magic = get_magic(self.input.magic(), &self.reader_var, OPT); - let pre_assertions = get_assertions(self.input.pre_assertions()); + let pre_assertions = get_assertions(self.pos_emitter, self.input.pre_assertions()); self.out = quote! { #head #magic diff --git a/binrw_derive/src/binrw/codegen/read_options/enum.rs b/binrw_derive/src/binrw/codegen/read_options/enum.rs index 010300ff..80eeb8f8 100644 --- a/binrw_derive/src/binrw/codegen/read_options/enum.rs +++ b/binrw_derive/src/binrw/codegen/read_options/enum.rs @@ -3,9 +3,12 @@ use super::{ PreludeGenerator, }; use crate::binrw::{ - codegen::sanitization::{ - BACKTRACE_FRAME, BIN_ERROR, ERROR_BASKET, OPT, POS, READER, READ_METHOD, - RESTORE_POSITION_VARIANT, TEMP, WITH_CONTEXT, + codegen::{ + sanitization::{ + BACKTRACE_FRAME, BIN_ERROR, ERROR_BASKET, OPT, READER, READ_METHOD, + RESTORE_POSITION_VARIANT, TEMP, WITH_CONTEXT, + }, + PosEmitter, }, parser::{Enum, EnumErrorMode, EnumVariant, Input, UnitEnumField, UnitOnlyEnum}, }; @@ -17,16 +20,22 @@ pub(super) fn generate_unit_enum( input: &Input, name: Option<&Ident>, en: &UnitOnlyEnum, + pos_emitter: &PosEmitter, ) -> TokenStream { - let prelude = PreludeGenerator::new(input) + let prelude = PreludeGenerator::new(input, pos_emitter) .add_imports(name) .add_endian() .add_magic_pre_assertion() .finish(); let read = match en.map.as_repr() { - Some(repr) => generate_unit_enum_repr(&input.stream_ident_or(READER), repr, &en.fields), - None => generate_unit_enum_magic(&input.stream_ident_or(READER), &en.fields), + Some(repr) => generate_unit_enum_repr( + &input.stream_ident_or(READER), + repr, + &en.fields, + pos_emitter, + ), + None => generate_unit_enum_magic(&input.stream_ident_or(READER), &en.fields, pos_emitter), }; quote! { @@ -39,6 +48,7 @@ fn generate_unit_enum_repr( reader_var: &TokenStream, repr: &TokenStream, variants: &[UnitEnumField], + pos_emitter: &PosEmitter, ) -> TokenStream { let clauses = variants.iter().map(|variant| { let ident = &variant.ident; @@ -54,12 +64,14 @@ fn generate_unit_enum_repr( } }); + let pos = pos_emitter.pos(); + quote! { let #TEMP: #repr = #READ_METHOD(#reader_var, #OPT, ())?; #(#clauses else)* { Err(#WITH_CONTEXT( #BIN_ERROR::NoVariantMatch { - pos: #POS, + pos: #pos, }, #BACKTRACE_FRAME::Message({ extern crate alloc; @@ -70,7 +82,11 @@ fn generate_unit_enum_repr( } } -fn generate_unit_enum_magic(reader_var: &TokenStream, variants: &[UnitEnumField]) -> TokenStream { +fn generate_unit_enum_magic( + reader_var: &TokenStream, + variants: &[UnitEnumField], + pos_emitter: &PosEmitter, +) -> TokenStream { // group fields by the type (Kind) of their magic value, preserve the order let group_by_magic_type = variants.iter().fold( Vec::new(), @@ -94,6 +110,8 @@ fn generate_unit_enum_magic(reader_var: &TokenStream, variants: &[UnitEnumField] }, ); + let pos = pos_emitter.pos(); + // for each type (Kind), read and try to match the magic of each field let try_each_magic_type = group_by_magic_type.into_iter().map(|(_kind, fields)| { let amp = fields[0].magic.as_ref().map(|magic| magic.add_ref()); @@ -120,7 +138,7 @@ fn generate_unit_enum_magic(reader_var: &TokenStream, variants: &[UnitEnumField] let body = quote! { match #amp #READ_METHOD(#reader_var, #OPT, ())? { #(#matches,)* - _ => Err(#BIN_ERROR::NoVariantMatch { pos: #POS }) + _ => Err(#BIN_ERROR::NoVariantMatch { pos: #pos }) } }; @@ -129,14 +147,14 @@ fn generate_unit_enum_magic(reader_var: &TokenStream, variants: &[UnitEnumField] #body })() { v @ Ok(_) => return v, - Err(#TEMP) => { #RESTORE_POSITION_VARIANT(#reader_var, #POS, #TEMP)?; } + Err(#TEMP) => { #RESTORE_POSITION_VARIANT(#reader_var, #pos, #TEMP)?; } } } }); let return_error = quote! { Err(#BIN_ERROR::NoVariantMatch { - pos: #POS + pos: #pos }) }; @@ -146,9 +164,23 @@ fn generate_unit_enum_magic(reader_var: &TokenStream, variants: &[UnitEnumField] } } -pub(super) fn generate_data_enum(input: &Input, name: Option<&Ident>, en: &Enum) -> TokenStream { +pub(super) fn generate_data_enum( + input: &Input, + name: Option<&Ident>, + en: &Enum, + pos_emitter: &PosEmitter, +) -> TokenStream { let return_all_errors = en.error_mode != EnumErrorMode::ReturnUnexpectedError; + let prelude = PreludeGenerator::new(input, pos_emitter) + .add_imports(name) + .add_endian() + .add_magic_pre_assertion() + .reset_position_after_magic() + .finish(); + + let pos = pos_emitter.pos(); + let (create_error_basket, return_error) = if return_all_errors { ( quote! { @@ -157,7 +189,7 @@ pub(super) fn generate_data_enum(input: &Input, name: Option<&Ident>, en: &Enum) }, quote! { Err(#BIN_ERROR::EnumErrors { - pos: #POS, + pos: #pos, variant_errors: #ERROR_BASKET }) }, @@ -167,23 +199,16 @@ pub(super) fn generate_data_enum(input: &Input, name: Option<&Ident>, en: &Enum) TokenStream::new(), quote! { Err(#BIN_ERROR::NoVariantMatch { - pos: #POS + pos: #pos, }) }, ) }; - let prelude = PreludeGenerator::new(input) - .add_imports(name) - .add_endian() - .add_magic_pre_assertion() - .reset_position_after_magic() - .finish(); - let reader_var = input.stream_ident_or(READER); let try_each_variant = en.variants.iter().map(|variant| { - let body = generate_variant_impl(en, variant); + let body = generate_variant_impl(en, variant, pos_emitter); let handle_error = if return_all_errors { let name = variant.ident().to_string(); @@ -200,7 +225,7 @@ pub(super) fn generate_data_enum(input: &Input, name: Option<&Ident>, en: &Enum) })() { ok @ Ok(_) => return ok, Err(error) => { - #RESTORE_POSITION_VARIANT(#reader_var, #POS, error).map(|#TEMP| { + #RESTORE_POSITION_VARIANT(#reader_var, #pos, error).map(|#TEMP| { #handle_error })?; } @@ -216,19 +241,27 @@ pub(super) fn generate_data_enum(input: &Input, name: Option<&Ident>, en: &Enum) } } -fn generate_variant_impl(en: &Enum, variant: &EnumVariant) -> TokenStream { +fn generate_variant_impl( + en: &Enum, + variant: &EnumVariant, + pos_emitter: &PosEmitter, +) -> TokenStream { let input = Input::Struct(variant.clone().into()); match variant { - EnumVariant::Variant { ident, options } => StructGenerator::new(&input, options) - .read_fields( - None, - Some(&format!("{}::{}", en.ident.as_ref().unwrap(), &ident)), - ) - .initialize_value_with_assertions(Some(ident), &en.assertions) - .return_value() - .finish(), - - EnumVariant::Unit(options) => generate_unit_struct(&input, None, Some(&options.ident)), + EnumVariant::Variant { ident, options } => { + StructGenerator::new(&input, options, pos_emitter) + .read_fields( + None, + Some(&format!("{}::{}", en.ident.as_ref().unwrap(), &ident)), + ) + .initialize_value_with_assertions(Some(ident), &en.assertions) + .return_value() + .finish() + } + + EnumVariant::Unit(options) => { + generate_unit_struct(&input, None, Some(&options.ident), pos_emitter) + } } } diff --git a/binrw_derive/src/binrw/codegen/read_options/map.rs b/binrw_derive/src/binrw/codegen/read_options/map.rs index b743be7b..ac0cf184 100644 --- a/binrw_derive/src/binrw/codegen/read_options/map.rs +++ b/binrw_derive/src/binrw/codegen/read_options/map.rs @@ -3,6 +3,7 @@ use crate::binrw::{ codegen::{ get_assertions, get_map_err, sanitization::{ARGS, OPT, POS, READER, READ_METHOD, THIS}, + PosEmitter, }, parser::Input, }; @@ -10,15 +11,21 @@ use proc_macro2::TokenStream; use quote::quote; use syn::{spanned::Spanned, Ident}; -pub(crate) fn generate_map(input: &Input, name: Option<&Ident>, map: &TokenStream) -> TokenStream { - let prelude = PreludeGenerator::new(input) +pub(super) fn generate_map( + input: &Input, + name: Option<&Ident>, + map: &TokenStream, + pos_emitter: &PosEmitter, +) -> TokenStream { + let prelude = PreludeGenerator::new(input, pos_emitter) .add_imports(name) .add_endian() .add_magic_pre_assertion() .finish(); let destructure_ref = destructure_ref(input); - let assertions = field_asserts(input).chain(get_assertions(input.assertions())); + let assertions = + field_asserts(input, pos_emitter).chain(get_assertions(pos_emitter, input.assertions())); let reader_var = input.stream_ident_or(READER); // TODO: replace args with top-level arguments and only @@ -42,20 +49,22 @@ pub(crate) fn generate_map(input: &Input, name: Option<&Ident>, map: &TokenStrea } } -pub(crate) fn generate_try_map( +pub(super) fn generate_try_map( input: &Input, name: Option<&Ident>, map: &TokenStream, + pos_emitter: &PosEmitter, ) -> TokenStream { let map_err = get_map_err(POS, map.span()); - let prelude = PreludeGenerator::new(input) + let prelude = PreludeGenerator::new(input, pos_emitter) .add_imports(name) .add_endian() .add_magic_pre_assertion() .finish(); let destructure_ref = destructure_ref(input); - let assertions = field_asserts(input).chain(get_assertions(input.assertions())); + let assertions = + field_asserts(input, pos_emitter).chain(get_assertions(pos_emitter, input.assertions())); let reader_var = input.stream_ident_or(READER); // TODO: replace args with top-level arguments and only @@ -100,13 +109,16 @@ fn destructure_ref(input: &Input) -> Option { } } -fn field_asserts(input: &Input) -> impl Iterator + '_ { +fn field_asserts<'a>( + input: &'a Input, + pos_emitter: &'a PosEmitter, +) -> impl Iterator + 'a { match input { Input::Struct(input) => either::Left( input .fields .iter() - .flat_map(|field| get_assertions(&field.assertions)), + .flat_map(|field| get_assertions(pos_emitter, &field.assertions)), ), _ => either::Right(core::iter::empty()), } diff --git a/binrw_derive/src/binrw/codegen/read_options/struct.rs b/binrw_derive/src/binrw/codegen/read_options/struct.rs index 10a3114d..b7194479 100644 --- a/binrw_derive/src/binrw/codegen/read_options/struct.rs +++ b/binrw_derive/src/binrw/codegen/read_options/struct.rs @@ -1,6 +1,7 @@ use super::{get_magic, PreludeGenerator}; #[cfg(feature = "verbose-backtrace")] use crate::binrw::backtrace::BacktraceFrame; +use crate::binrw::codegen::PosEmitter; use crate::binrw::parser::Assert; use crate::{ binrw::{ @@ -26,8 +27,9 @@ pub(super) fn generate_unit_struct( input: &Input, name: Option<&Ident>, variant_ident: Option<&Ident>, + pos_emitter: &PosEmitter, ) -> TokenStream { - let prelude = get_prelude(input, name); + let prelude = get_prelude(input, name, false, pos_emitter); let return_type = get_return_type(variant_ident); quote! { #prelude @@ -35,8 +37,13 @@ pub(super) fn generate_unit_struct( } } -pub(super) fn generate_struct(input: &Input, name: Option<&Ident>, st: &Struct) -> TokenStream { - StructGenerator::new(input, st) +pub(super) fn generate_struct( + input: &Input, + name: Option<&Ident>, + st: &Struct, + pos_emitter: &PosEmitter, +) -> TokenStream { + StructGenerator::new(input, st, pos_emitter) .read_fields(name, None) .initialize_value_with_assertions(None, &[]) .return_value() @@ -47,14 +54,20 @@ pub(super) struct StructGenerator<'input> { input: &'input Input, st: &'input Struct, out: TokenStream, + pos_emitter: &'input PosEmitter, } impl<'input> StructGenerator<'input> { - pub(super) fn new(input: &'input Input, st: &'input Struct) -> Self { + pub(super) fn new( + input: &'input Input, + st: &'input Struct, + pos_emitter: &'input PosEmitter, + ) -> Self { Self { input, st, out: TokenStream::new(), + pos_emitter, } } @@ -85,8 +98,8 @@ impl<'input> StructGenerator<'input> { } fn add_assertions(mut self, extra_assertions: &[Assert]) -> Self { - let assertions = - get_assertions(&self.st.assertions).chain(get_assertions(extra_assertions)); + let assertions = get_assertions(self.pos_emitter, &self.st.assertions) + .chain(get_assertions(self.pos_emitter, extra_assertions)); let head = self.out; self.out = quote! { #head @@ -97,12 +110,12 @@ impl<'input> StructGenerator<'input> { } pub(super) fn read_fields(mut self, name: Option<&Ident>, variant_name: Option<&str>) -> Self { - let prelude = get_prelude(self.input, name); - let read_fields = self - .st - .fields - .iter() - .map(|field| generate_field(self.input, field, name, variant_name)); + let prelude = get_prelude(self.input, name, variant_name.is_some(), self.pos_emitter); + let mut pos_emitter = Some(self.pos_emitter); + let read_fields = + self.st.fields.iter().map(|field| { + generate_field(self.input, field, name, variant_name, pos_emitter.take()) + }); self.out = quote! { #prelude #(#read_fields)* @@ -146,13 +159,14 @@ fn generate_field( field: &StructField, name: Option<&Ident>, variant_name: Option<&str>, + pos_emitter: Option<&PosEmitter>, ) -> TokenStream { // temp + ignore == just don't bother if field.is_temp(false) && matches!(field.field_mode, FieldMode::Default) { return TokenStream::new(); } - FieldGenerator::new(input, field) + FieldGenerator::new(input, field, pos_emitter) .read_value() .wrap_map_stream() .try_conversion(name, variant_name) @@ -177,12 +191,22 @@ struct FieldGenerator<'field> { reader_var: TokenStream, endian_var: TokenStream, args_var: Option, + pos_emitter: Cow<'field, PosEmitter>, } impl<'field> FieldGenerator<'field> { - fn new(input: &Input, field: &'field StructField) -> Self { + fn new( + input: &Input, + field: &'field StructField, + pos_emitter: Option<&'field PosEmitter>, + ) -> Self { let (reader_var, endian_var, args_var) = make_field_vars(input, field); + let pos_emitter = pos_emitter.map_or_else( + || Cow::Owned(PosEmitter::new(&input.stream_ident_or(READER))), + Cow::Borrowed, + ); + Self { field, out: TokenStream::new(), @@ -190,6 +214,7 @@ impl<'field> FieldGenerator<'field> { reader_var, endian_var, args_var, + pos_emitter, } } @@ -259,7 +284,7 @@ impl<'field> FieldGenerator<'field> { } fn append_assertions(mut self) -> Self { - let assertions = get_assertions(&self.field.assertions); + let assertions = get_assertions(&self.pos_emitter, &self.field.assertions); let head = self.out; self.out = quote! { #head @@ -279,7 +304,16 @@ impl<'field> FieldGenerator<'field> { } fn finish(self) -> TokenStream { - self.out + let rest = self.out; + if let Cow::Owned(pos_emitter) = self.pos_emitter { + let pos = pos_emitter.finish(); + quote! { + #pos + #rest + } + } else { + rest + } } fn map_value(mut self) -> Self { @@ -586,13 +620,20 @@ fn get_err_context( } } -fn get_prelude(input: &Input, name: Option<&Ident>) -> TokenStream { - PreludeGenerator::new(input) +fn get_prelude( + input: &Input, + name: Option<&Ident>, + is_enum_variant: bool, + pos_emitter: &PosEmitter, +) -> TokenStream { + let mut prelude = PreludeGenerator::new(input, pos_emitter) .add_imports(name) .add_endian() - .add_magic_pre_assertion() - .add_map_stream() - .finish() + .add_magic_pre_assertion(); + if is_enum_variant { + prelude = prelude.reset_position_after_magic(); + } + prelude.add_map_stream().finish() } fn generate_seek_after(reader_var: &TokenStream, field: &StructField) -> TokenStream { diff --git a/binrw_derive/src/binrw/codegen/write_options.rs b/binrw_derive/src/binrw/codegen/write_options.rs index abbe386e..b418b8be 100644 --- a/binrw_derive/src/binrw/codegen/write_options.rs +++ b/binrw_derive/src/binrw/codegen/write_options.rs @@ -5,7 +5,10 @@ mod struct_field; use super::get_map_err; use crate::binrw::{ - codegen::sanitization::{OPT, POS, SEEK_TRAIT, WRITER, WRITE_METHOD}, + codegen::{ + sanitization::{OPT, POS, WRITER, WRITE_METHOD}, + PosEmitter, + }, parser::{Input, Map}, }; use proc_macro2::TokenStream; @@ -16,10 +19,14 @@ use syn::{spanned::Spanned, Ident}; pub(crate) fn generate(input: &Input, derive_input: &syn::DeriveInput) -> TokenStream { let name = Some(&derive_input.ident); + let writer_var = input.stream_ident_or(WRITER); + let pos_emitter = PosEmitter::new(&writer_var); let inner = match input.map() { Map::None => match input { - Input::UnitStruct(s) | Input::Struct(s) => generate_struct(input, name, s), - Input::Enum(e) => generate_data_enum(input, name, e), + Input::UnitStruct(s) | Input::Struct(s) => { + generate_struct(input, name, s, &pos_emitter) + } + Input::Enum(e) => generate_data_enum(input, name, e, &pos_emitter), Input::UnitOnlyEnum(e) => generate_unit_enum(input, name, e), }, Map::Try(map) | Map::Map(map) => generate_map(input, name, map), @@ -29,14 +36,14 @@ pub(crate) fn generate(input: &Input, derive_input: &syn::DeriveInput) -> TokenS }, }; - let writer_var = input.stream_ident_or(WRITER); - + let (set_pos, rewind) = super::get_rewind(input, &writer_var, pos_emitter); quote! { let #writer_var = #WRITER; - let #POS = #SEEK_TRAIT::stream_position(#writer_var)?; - #inner - - Ok(()) + #set_pos + (|| { + #inner + Ok(()) + })()#rewind } } diff --git a/binrw_derive/src/binrw/codegen/write_options/enum.rs b/binrw_derive/src/binrw/codegen/write_options/enum.rs index 1d135a08..1d29fbbe 100644 --- a/binrw_derive/src/binrw/codegen/write_options/enum.rs +++ b/binrw_derive/src/binrw/codegen/write_options/enum.rs @@ -1,6 +1,9 @@ use super::{prelude::PreludeGenerator, r#struct::StructGenerator}; use crate::binrw::{ - codegen::sanitization::{OPT, WRITER, WRITE_METHOD}, + codegen::{ + sanitization::{OPT, WRITER, WRITE_METHOD}, + PosEmitter, + }, parser::{Enum, EnumVariant, Input, UnitEnumField, UnitOnlyEnum}, }; use proc_macro2::{Ident, TokenStream}; @@ -25,8 +28,13 @@ pub(crate) fn generate_unit_enum( .finish() } -pub(crate) fn generate_data_enum(input: &Input, name: Option<&Ident>, en: &Enum) -> TokenStream { - EnumGenerator::new(input, name, en, input.stream_ident_or(WRITER)) +pub(super) fn generate_data_enum( + input: &Input, + name: Option<&Ident>, + en: &Enum, + pos_emitter: &PosEmitter, +) -> TokenStream { + EnumGenerator::new(input, name, en, input.stream_ident_or(WRITER), pos_emitter) .write_variants() .prefix_prelude() .finish() @@ -38,6 +46,7 @@ struct EnumGenerator<'a> { name: Option<&'a Ident>, writer_var: TokenStream, out: TokenStream, + pos_emitter: &'a PosEmitter, } impl<'a> EnumGenerator<'a> { @@ -46,6 +55,7 @@ impl<'a> EnumGenerator<'a> { name: Option<&'a Ident>, en: &'a Enum, writer_var: TokenStream, + pos_emitter: &'a PosEmitter, ) -> Self { Self { input, @@ -53,6 +63,7 @@ impl<'a> EnumGenerator<'a> { en, writer_var, out: TokenStream::new(), + pos_emitter, } } @@ -67,7 +78,7 @@ impl<'a> EnumGenerator<'a> { let writer_var = &self.writer_var; let writing = match variant { EnumVariant::Variant { options, .. } => { - StructGenerator::new(None, options, None, &self.writer_var) + StructGenerator::new(None, options, None, &self.writer_var, self.pos_emitter) .write_fields() .prefix_prelude() .finish() diff --git a/binrw_derive/src/binrw/codegen/write_options/struct.rs b/binrw_derive/src/binrw/codegen/write_options/struct.rs index a51cd2da..06139709 100644 --- a/binrw_derive/src/binrw/codegen/write_options/struct.rs +++ b/binrw_derive/src/binrw/codegen/write_options/struct.rs @@ -1,35 +1,48 @@ use super::{prelude::PreludeGenerator, struct_field::write_field}; use crate::binrw::{ - codegen::{get_assertions, sanitization::WRITER}, + codegen::{get_assertions, sanitization::WRITER, PosEmitter}, parser::{Input, Struct}, }; use proc_macro2::TokenStream; use quote::quote; use syn::Ident; -pub(super) fn generate_struct(input: &Input, name: Option<&Ident>, st: &Struct) -> TokenStream { - StructGenerator::new(Some(input), st, name, &input.stream_ident_or(WRITER)) - .write_fields() - .prefix_assertions() - .prefix_prelude() - .prefix_borrow_fields() - .finish() +pub(super) fn generate_struct( + input: &Input, + name: Option<&Ident>, + st: &Struct, + pos_emitter: &PosEmitter, +) -> TokenStream { + StructGenerator::new( + Some(input), + st, + name, + &input.stream_ident_or(WRITER), + pos_emitter, + ) + .write_fields() + .prefix_assertions() + .prefix_prelude() + .prefix_borrow_fields() + .finish() } -pub(crate) struct StructGenerator<'input> { +pub(super) struct StructGenerator<'input> { input: Option<&'input Input>, st: &'input Struct, name: Option<&'input Ident>, writer_var: &'input TokenStream, out: TokenStream, + pos_emitter: &'input PosEmitter, } impl<'input> StructGenerator<'input> { - pub(crate) fn new( + pub(super) fn new( input: Option<&'input Input>, st: &'input Struct, name: Option<&'input Ident>, writer_var: &'input TokenStream, + pos_emitter: &'input PosEmitter, ) -> Self { Self { input, @@ -37,6 +50,7 @@ impl<'input> StructGenerator<'input> { name, writer_var, out: TokenStream::new(), + pos_emitter, } } @@ -52,7 +66,7 @@ impl<'input> StructGenerator<'input> { } fn prefix_assertions(mut self) -> Self { - let assertions = get_assertions(&self.st.assertions); + let assertions = get_assertions(self.pos_emitter, &self.st.assertions); let out = self.out; self.out = quote! { @@ -68,7 +82,7 @@ impl<'input> StructGenerator<'input> { .st .fields .iter() - .map(|field| write_field(self.writer_var, field)); + .map(|field| write_field(self.writer_var, field, self.pos_emitter)); self.out = quote! { #(#write_fields)* diff --git a/binrw_derive/src/binrw/codegen/write_options/struct_field.rs b/binrw_derive/src/binrw/codegen/write_options/struct_field.rs index 1e8a310b..8577cc9a 100644 --- a/binrw_derive/src/binrw/codegen/write_options/struct_field.rs +++ b/binrw_derive/src/binrw/codegen/write_options/struct_field.rs @@ -10,6 +10,7 @@ use crate::{ WRITE_MAP_ARGS_TYPE_HINT, WRITE_MAP_INPUT_TYPE_HINT, WRITE_METHOD, WRITE_TRY_MAP_ARGS_TYPE_HINT, WRITE_ZEROES, }, + PosEmitter, }, parser::{FieldMode, Map, StructField}, }, @@ -21,8 +22,12 @@ use proc_macro2::TokenStream; use quote::{quote, quote_spanned, ToTokens}; use syn::{spanned::Spanned, Ident}; -pub(crate) fn write_field(writer_var: &TokenStream, field: &StructField) -> TokenStream { - StructFieldGenerator::new(field, writer_var) +pub(super) fn write_field( + writer_var: &TokenStream, + field: &StructField, + pos_emitter: &PosEmitter, +) -> TokenStream { + StructFieldGenerator::new(field, writer_var, pos_emitter) .write_field() .wrap_map_stream() .prefix_map_value() @@ -42,24 +47,32 @@ struct StructFieldGenerator<'input> { outer_writer_var: &'input TokenStream, writer_var: Cow<'input, TokenStream>, out: TokenStream, + pos_emitter: &'input PosEmitter, } impl<'a> StructFieldGenerator<'a> { - fn new(field: &'a StructField, outer_writer_var: &'a TokenStream) -> Self { + fn new( + field: &'a StructField, + outer_writer_var: &'a TokenStream, + pos_emitter: &'a PosEmitter, + ) -> Self { + let writer_var = if field.map_stream.is_some() { + Cow::Owned(make_ident(&field.ident, "reader").into_token_stream()) + } else { + Cow::Borrowed(outer_writer_var) + }; + Self { field, outer_writer_var, - writer_var: if field.map_stream.is_some() { - Cow::Owned(make_ident(&field.ident, "reader").into_token_stream()) - } else { - Cow::Borrowed(outer_writer_var) - }, + writer_var, out: TokenStream::new(), + pos_emitter, } } fn prefix_assertions(mut self) -> Self { - let assertions = get_assertions(&self.field.assertions); + let assertions = get_assertions(self.pos_emitter, &self.field.assertions); let out = self.out; self.out = quote! {