From 13ae99ea186949fe0b25e66f700f8630d3b2f113 Mon Sep 17 00:00:00 2001 From: Colin Snover Date: Mon, 23 Oct 2023 16:47:27 -0500 Subject: [PATCH] Implement usage tracking in codegen for `stream_position` calls This should increase the accuracy of positioning information in errors and ensures `stream_position` calls will only be emitted when the result may actually be used. There are still too many `stream_position` calls in part because there is no API to coordinate between parent and child objects to prevent children from doing their own position management when the parent is already doing it. Improving this will probably require some kind of rudimentary stack tracking, or at least to split the API so that `read_options` is an entrypoint and then there is e.g. `read_inner` that does not try to restore position on error. In future, it should be possible for authors to configure trading off worse error handling with performance by e.g. making error position `Option` and emitting `None`s instead of a variable reference, and giving some option to not leave the stream in a known good state on failure. Refs jam1garner/binrw#220. --- binrw/tests/error.rs | 2 +- binrw_derive/src/binrw/codegen/mod.rs | 75 ++++++++++++- .../src/binrw/codegen/read_options.rs | 55 ++++------ .../src/binrw/codegen/read_options/enum.rs | 101 ++++++++++++------ .../src/binrw/codegen/read_options/map.rs | 28 +++-- .../src/binrw/codegen/read_options/struct.rs | 83 ++++++++++---- .../src/binrw/codegen/write_options.rs | 25 +++-- .../src/binrw/codegen/write_options/enum.rs | 19 +++- .../src/binrw/codegen/write_options/struct.rs | 38 ++++--- .../codegen/write_options/struct_field.rs | 31 ++++-- 10 files changed, 322 insertions(+), 135 deletions(-) diff --git a/binrw/tests/error.rs b/binrw/tests/error.rs index e2593ac4..3f6b5446 100644 --- a/binrw/tests/error.rs +++ b/binrw/tests/error.rs @@ -241,7 +241,7 @@ fn no_seek_data_enum() { (BacktraceFrame::Message(m), BacktraceFrame::Custom(e)) => { assert_eq!(m, "rewinding after a failure"); match e.downcast_ref::() { - Some(binrw::Error::AssertFail { pos, .. }) => assert_eq!(*pos, 0), + Some(binrw::Error::AssertFail { pos, .. }) => assert_eq!(*pos, 1), e => panic!("unexpected error {:?}", e), } } diff --git a/binrw_derive/src/binrw/codegen/mod.rs b/binrw_derive/src/binrw/codegen/mod.rs index a3552e9e..d5a6af0c 100644 --- a/binrw_derive/src/binrw/codegen/mod.rs +++ b/binrw_derive/src/binrw/codegen/mod.rs @@ -5,7 +5,8 @@ mod write_options; use crate::{ binrw::parser::{ - Assert, AssertionError, CondEndian, Imports, Input, ParseResult, PassedArgs, StructField, + Assert, AssertionError, CondEndian, Imports, Input, Map, ParseResult, PassedArgs, + StructField, }, named_args::{arg_type_name, derive_from_imports}, util::{quote_spanned_any, IdentStr}, @@ -14,8 +15,8 @@ use proc_macro2::{Span, TokenStream}; use quote::{quote, quote_spanned, ToTokens}; use sanitization::{ ARGS, ARGS_LIFETIME, ARGS_MACRO, ASSERT, ASSERT_ERROR_FN, BINREAD_TRAIT, BINWRITE_TRAIT, - BIN_ERROR, BIN_RESULT, ENDIAN_ENUM, OPT, POS, READER, READ_TRAIT, SEEK_TRAIT, TEMP, WRITER, - WRITE_TRAIT, + BIN_ERROR, BIN_RESULT, ENDIAN_ENUM, OPT, POS, READER, READ_TRAIT, RESTORE_POSITION, SEEK_TRAIT, + TEMP, WRITER, WRITE_TRAIT, }; use syn::{spanned::Spanned, DeriveInput, Ident, Type}; @@ -194,11 +195,47 @@ fn generate_trait_impl( } } +#[derive(Clone)] +#[must_use] +struct PosEmitter { + used: core::cell::Cell, + stream_var: TokenStream, +} + +impl PosEmitter { + fn new(stream_var: &TokenStream) -> Self { + Self { + used: core::cell::Cell::new(false), + stream_var: stream_var.clone(), + } + } + + fn finish(self) -> TokenStream { + self.used + .get() + .then(|| { + let var = self.stream_var; + quote! { + let #POS = #SEEK_TRAIT::stream_position(#var)?; + } + }) + .unwrap_or_default() + } + + fn pos(&self) -> &'static IdentStr { + self.used.set(true); + &POS + } +} + fn get_args_lifetime(span: proc_macro2::Span) -> syn::Lifetime { syn::Lifetime::new(&format!("'{ARGS_LIFETIME}"), span) } -fn get_assertions(assertions: &[Assert]) -> impl Iterator + '_ { +fn get_assertions<'a>( + pos_emitter: &'a PosEmitter, + assertions: &'a [Assert], +) -> impl Iterator + 'a { assertions.iter().map( |Assert { kw_span, @@ -215,8 +252,10 @@ fn get_assertions(assertions: &[Assert]) -> impl Iterator + } }; + let pos = pos_emitter.pos(); + quote_spanned_any! {*kw_span=> - #ASSERT(#condition, #POS, #error_fn)?; + #ASSERT(#condition, #pos, #error_fn)?; } }, ) @@ -303,6 +342,32 @@ fn get_passed_args(field: &StructField, stream: IdentStr) -> Option } } +fn get_rewind( + input: &Input, + var: &TokenStream, + pos_emitter: PosEmitter, +) -> (TokenStream, TokenStream) { + let needs_rewind = match input.map() { + Map::None => match input { + Input::UnitStruct(_) | Input::Enum(_) => input.magic().is_some(), + Input::Struct(_) => true, + Input::UnitOnlyEnum(e) => e.map.as_repr().is_some(), + }, + Map::Try(_) | Map::Map(_) | Map::Repr(_) => true, + }; + + let rewind = needs_rewind + .then(|| { + let pos = pos_emitter.pos(); + quote! { + .or_else(#RESTORE_POSITION::(#var, #pos)) + } + }) + .unwrap_or_default(); + + (pos_emitter.finish(), rewind) +} + fn get_try_calc(pos: IdentStr, ty: &Type, calc: &TokenStream) -> TokenStream { let map_err = get_map_err(pos, calc.span()); quote_spanned! {ty.span()=> { diff --git a/binrw_derive/src/binrw/codegen/read_options.rs b/binrw_derive/src/binrw/codegen/read_options.rs index ac5e5624..226feb34 100644 --- a/binrw_derive/src/binrw/codegen/read_options.rs +++ b/binrw_derive/src/binrw/codegen/read_options.rs @@ -2,14 +2,13 @@ mod r#enum; mod map; mod r#struct; -use super::{get_assertions, get_destructured_imports}; +use super::{get_assertions, get_destructured_imports, PosEmitter}; use crate::{ binrw::{ codegen::{ get_endian, sanitization::{ - ARGS, ASSERT_MAGIC, MAP_READER_TYPE_HINT, OPT, POS, READER, RESTORE_POSITION, - SEEK_TRAIT, + ARGS, ASSERT_MAGIC, MAP_READER_TYPE_HINT, OPT, POS, READER, SEEK_TRAIT, }, }, parser::{Input, Magic, Map}, @@ -23,43 +22,33 @@ use r#struct::{generate_struct, generate_unit_struct}; use syn::{spanned::Spanned, Ident}; pub(crate) fn generate(input: &Input, derive_input: &syn::DeriveInput) -> TokenStream { + let reader_var = input.stream_ident_or(READER); let name = Some(&derive_input.ident); - let (inner, needs_rewind) = match input.map() { + let pos_emitter = PosEmitter::new(&reader_var); + let inner = match input.map() { Map::None => match input { - Input::UnitStruct(_) => (generate_unit_struct(input, name, None), false), - Input::Struct(s) => (generate_struct(input, name, s), true), - Input::Enum(e) => (generate_data_enum(input, name, e), false), - Input::UnitOnlyEnum(e) => ( - generate_unit_enum(input, name, e), - e.map.as_repr().is_some(), - ), + Input::UnitStruct(_) => generate_unit_struct(input, name, None, &pos_emitter), + Input::Struct(s) => generate_struct(input, name, s, &pos_emitter), + Input::Enum(e) => generate_data_enum(input, name, e, &pos_emitter), + Input::UnitOnlyEnum(e) => generate_unit_enum(input, name, e, &pos_emitter), }, - Map::Try(map) => (map::generate_try_map(input, name, map), true), - Map::Map(map) => (map::generate_map(input, name, map), true), + Map::Try(map) => map::generate_try_map(input, name, map, &pos_emitter), + Map::Map(map) => map::generate_map(input, name, map, &pos_emitter), Map::Repr(ty) => match input { - Input::UnitOnlyEnum(e) => (generate_unit_enum(input, name, e), true), - _ => ( - map::generate_try_map( - input, - name, - "e! { <#ty as core::convert::TryInto<_>>::try_into }, - ), - true, + Input::UnitOnlyEnum(e) => generate_unit_enum(input, name, e, &pos_emitter), + _ => map::generate_try_map( + input, + name, + "e! { <#ty as core::convert::TryInto<_>>::try_into }, + &pos_emitter, ), }, }; - let reader_var = input.stream_ident_or(READER); - - let rewind = (needs_rewind || input.magic().is_some()).then(|| { - quote! { - .or_else(#RESTORE_POSITION::(#reader_var, #POS)) - } - }); - + let (set_pos, rewind) = super::get_rewind(input, &reader_var, pos_emitter); quote! { let #reader_var = #READER; - let #POS = #SEEK_TRAIT::stream_position(#reader_var)?; + #set_pos (|| { #inner })()#rewind @@ -70,15 +59,17 @@ struct PreludeGenerator<'input> { input: &'input Input, reader_var: TokenStream, out: TokenStream, + pos_emitter: &'input PosEmitter, } impl<'input> PreludeGenerator<'input> { - fn new(input: &'input Input) -> Self { + fn new(input: &'input Input, pos_emitter: &'input PosEmitter) -> Self { let reader_var = input.stream_ident_or(READER); Self { input, reader_var, out: TokenStream::new(), + pos_emitter, } } @@ -111,7 +102,7 @@ impl<'input> PreludeGenerator<'input> { fn add_magic_pre_assertion(mut self) -> Self { let head = self.out; let magic = get_magic(self.input.magic(), &self.reader_var, OPT); - let pre_assertions = get_assertions(self.input.pre_assertions()); + let pre_assertions = get_assertions(self.pos_emitter, self.input.pre_assertions()); self.out = quote! { #head #magic diff --git a/binrw_derive/src/binrw/codegen/read_options/enum.rs b/binrw_derive/src/binrw/codegen/read_options/enum.rs index 010300ff..80eeb8f8 100644 --- a/binrw_derive/src/binrw/codegen/read_options/enum.rs +++ b/binrw_derive/src/binrw/codegen/read_options/enum.rs @@ -3,9 +3,12 @@ use super::{ PreludeGenerator, }; use crate::binrw::{ - codegen::sanitization::{ - BACKTRACE_FRAME, BIN_ERROR, ERROR_BASKET, OPT, POS, READER, READ_METHOD, - RESTORE_POSITION_VARIANT, TEMP, WITH_CONTEXT, + codegen::{ + sanitization::{ + BACKTRACE_FRAME, BIN_ERROR, ERROR_BASKET, OPT, READER, READ_METHOD, + RESTORE_POSITION_VARIANT, TEMP, WITH_CONTEXT, + }, + PosEmitter, }, parser::{Enum, EnumErrorMode, EnumVariant, Input, UnitEnumField, UnitOnlyEnum}, }; @@ -17,16 +20,22 @@ pub(super) fn generate_unit_enum( input: &Input, name: Option<&Ident>, en: &UnitOnlyEnum, + pos_emitter: &PosEmitter, ) -> TokenStream { - let prelude = PreludeGenerator::new(input) + let prelude = PreludeGenerator::new(input, pos_emitter) .add_imports(name) .add_endian() .add_magic_pre_assertion() .finish(); let read = match en.map.as_repr() { - Some(repr) => generate_unit_enum_repr(&input.stream_ident_or(READER), repr, &en.fields), - None => generate_unit_enum_magic(&input.stream_ident_or(READER), &en.fields), + Some(repr) => generate_unit_enum_repr( + &input.stream_ident_or(READER), + repr, + &en.fields, + pos_emitter, + ), + None => generate_unit_enum_magic(&input.stream_ident_or(READER), &en.fields, pos_emitter), }; quote! { @@ -39,6 +48,7 @@ fn generate_unit_enum_repr( reader_var: &TokenStream, repr: &TokenStream, variants: &[UnitEnumField], + pos_emitter: &PosEmitter, ) -> TokenStream { let clauses = variants.iter().map(|variant| { let ident = &variant.ident; @@ -54,12 +64,14 @@ fn generate_unit_enum_repr( } }); + let pos = pos_emitter.pos(); + quote! { let #TEMP: #repr = #READ_METHOD(#reader_var, #OPT, ())?; #(#clauses else)* { Err(#WITH_CONTEXT( #BIN_ERROR::NoVariantMatch { - pos: #POS, + pos: #pos, }, #BACKTRACE_FRAME::Message({ extern crate alloc; @@ -70,7 +82,11 @@ fn generate_unit_enum_repr( } } -fn generate_unit_enum_magic(reader_var: &TokenStream, variants: &[UnitEnumField]) -> TokenStream { +fn generate_unit_enum_magic( + reader_var: &TokenStream, + variants: &[UnitEnumField], + pos_emitter: &PosEmitter, +) -> TokenStream { // group fields by the type (Kind) of their magic value, preserve the order let group_by_magic_type = variants.iter().fold( Vec::new(), @@ -94,6 +110,8 @@ fn generate_unit_enum_magic(reader_var: &TokenStream, variants: &[UnitEnumField] }, ); + let pos = pos_emitter.pos(); + // for each type (Kind), read and try to match the magic of each field let try_each_magic_type = group_by_magic_type.into_iter().map(|(_kind, fields)| { let amp = fields[0].magic.as_ref().map(|magic| magic.add_ref()); @@ -120,7 +138,7 @@ fn generate_unit_enum_magic(reader_var: &TokenStream, variants: &[UnitEnumField] let body = quote! { match #amp #READ_METHOD(#reader_var, #OPT, ())? { #(#matches,)* - _ => Err(#BIN_ERROR::NoVariantMatch { pos: #POS }) + _ => Err(#BIN_ERROR::NoVariantMatch { pos: #pos }) } }; @@ -129,14 +147,14 @@ fn generate_unit_enum_magic(reader_var: &TokenStream, variants: &[UnitEnumField] #body })() { v @ Ok(_) => return v, - Err(#TEMP) => { #RESTORE_POSITION_VARIANT(#reader_var, #POS, #TEMP)?; } + Err(#TEMP) => { #RESTORE_POSITION_VARIANT(#reader_var, #pos, #TEMP)?; } } } }); let return_error = quote! { Err(#BIN_ERROR::NoVariantMatch { - pos: #POS + pos: #pos }) }; @@ -146,9 +164,23 @@ fn generate_unit_enum_magic(reader_var: &TokenStream, variants: &[UnitEnumField] } } -pub(super) fn generate_data_enum(input: &Input, name: Option<&Ident>, en: &Enum) -> TokenStream { +pub(super) fn generate_data_enum( + input: &Input, + name: Option<&Ident>, + en: &Enum, + pos_emitter: &PosEmitter, +) -> TokenStream { let return_all_errors = en.error_mode != EnumErrorMode::ReturnUnexpectedError; + let prelude = PreludeGenerator::new(input, pos_emitter) + .add_imports(name) + .add_endian() + .add_magic_pre_assertion() + .reset_position_after_magic() + .finish(); + + let pos = pos_emitter.pos(); + let (create_error_basket, return_error) = if return_all_errors { ( quote! { @@ -157,7 +189,7 @@ pub(super) fn generate_data_enum(input: &Input, name: Option<&Ident>, en: &Enum) }, quote! { Err(#BIN_ERROR::EnumErrors { - pos: #POS, + pos: #pos, variant_errors: #ERROR_BASKET }) }, @@ -167,23 +199,16 @@ pub(super) fn generate_data_enum(input: &Input, name: Option<&Ident>, en: &Enum) TokenStream::new(), quote! { Err(#BIN_ERROR::NoVariantMatch { - pos: #POS + pos: #pos, }) }, ) }; - let prelude = PreludeGenerator::new(input) - .add_imports(name) - .add_endian() - .add_magic_pre_assertion() - .reset_position_after_magic() - .finish(); - let reader_var = input.stream_ident_or(READER); let try_each_variant = en.variants.iter().map(|variant| { - let body = generate_variant_impl(en, variant); + let body = generate_variant_impl(en, variant, pos_emitter); let handle_error = if return_all_errors { let name = variant.ident().to_string(); @@ -200,7 +225,7 @@ pub(super) fn generate_data_enum(input: &Input, name: Option<&Ident>, en: &Enum) })() { ok @ Ok(_) => return ok, Err(error) => { - #RESTORE_POSITION_VARIANT(#reader_var, #POS, error).map(|#TEMP| { + #RESTORE_POSITION_VARIANT(#reader_var, #pos, error).map(|#TEMP| { #handle_error })?; } @@ -216,19 +241,27 @@ pub(super) fn generate_data_enum(input: &Input, name: Option<&Ident>, en: &Enum) } } -fn generate_variant_impl(en: &Enum, variant: &EnumVariant) -> TokenStream { +fn generate_variant_impl( + en: &Enum, + variant: &EnumVariant, + pos_emitter: &PosEmitter, +) -> TokenStream { let input = Input::Struct(variant.clone().into()); match variant { - EnumVariant::Variant { ident, options } => StructGenerator::new(&input, options) - .read_fields( - None, - Some(&format!("{}::{}", en.ident.as_ref().unwrap(), &ident)), - ) - .initialize_value_with_assertions(Some(ident), &en.assertions) - .return_value() - .finish(), - - EnumVariant::Unit(options) => generate_unit_struct(&input, None, Some(&options.ident)), + EnumVariant::Variant { ident, options } => { + StructGenerator::new(&input, options, pos_emitter) + .read_fields( + None, + Some(&format!("{}::{}", en.ident.as_ref().unwrap(), &ident)), + ) + .initialize_value_with_assertions(Some(ident), &en.assertions) + .return_value() + .finish() + } + + EnumVariant::Unit(options) => { + generate_unit_struct(&input, None, Some(&options.ident), pos_emitter) + } } } diff --git a/binrw_derive/src/binrw/codegen/read_options/map.rs b/binrw_derive/src/binrw/codegen/read_options/map.rs index b743be7b..ac0cf184 100644 --- a/binrw_derive/src/binrw/codegen/read_options/map.rs +++ b/binrw_derive/src/binrw/codegen/read_options/map.rs @@ -3,6 +3,7 @@ use crate::binrw::{ codegen::{ get_assertions, get_map_err, sanitization::{ARGS, OPT, POS, READER, READ_METHOD, THIS}, + PosEmitter, }, parser::Input, }; @@ -10,15 +11,21 @@ use proc_macro2::TokenStream; use quote::quote; use syn::{spanned::Spanned, Ident}; -pub(crate) fn generate_map(input: &Input, name: Option<&Ident>, map: &TokenStream) -> TokenStream { - let prelude = PreludeGenerator::new(input) +pub(super) fn generate_map( + input: &Input, + name: Option<&Ident>, + map: &TokenStream, + pos_emitter: &PosEmitter, +) -> TokenStream { + let prelude = PreludeGenerator::new(input, pos_emitter) .add_imports(name) .add_endian() .add_magic_pre_assertion() .finish(); let destructure_ref = destructure_ref(input); - let assertions = field_asserts(input).chain(get_assertions(input.assertions())); + let assertions = + field_asserts(input, pos_emitter).chain(get_assertions(pos_emitter, input.assertions())); let reader_var = input.stream_ident_or(READER); // TODO: replace args with top-level arguments and only @@ -42,20 +49,22 @@ pub(crate) fn generate_map(input: &Input, name: Option<&Ident>, map: &TokenStrea } } -pub(crate) fn generate_try_map( +pub(super) fn generate_try_map( input: &Input, name: Option<&Ident>, map: &TokenStream, + pos_emitter: &PosEmitter, ) -> TokenStream { let map_err = get_map_err(POS, map.span()); - let prelude = PreludeGenerator::new(input) + let prelude = PreludeGenerator::new(input, pos_emitter) .add_imports(name) .add_endian() .add_magic_pre_assertion() .finish(); let destructure_ref = destructure_ref(input); - let assertions = field_asserts(input).chain(get_assertions(input.assertions())); + let assertions = + field_asserts(input, pos_emitter).chain(get_assertions(pos_emitter, input.assertions())); let reader_var = input.stream_ident_or(READER); // TODO: replace args with top-level arguments and only @@ -100,13 +109,16 @@ fn destructure_ref(input: &Input) -> Option { } } -fn field_asserts(input: &Input) -> impl Iterator + '_ { +fn field_asserts<'a>( + input: &'a Input, + pos_emitter: &'a PosEmitter, +) -> impl Iterator + 'a { match input { Input::Struct(input) => either::Left( input .fields .iter() - .flat_map(|field| get_assertions(&field.assertions)), + .flat_map(|field| get_assertions(pos_emitter, &field.assertions)), ), _ => either::Right(core::iter::empty()), } diff --git a/binrw_derive/src/binrw/codegen/read_options/struct.rs b/binrw_derive/src/binrw/codegen/read_options/struct.rs index 10a3114d..b7194479 100644 --- a/binrw_derive/src/binrw/codegen/read_options/struct.rs +++ b/binrw_derive/src/binrw/codegen/read_options/struct.rs @@ -1,6 +1,7 @@ use super::{get_magic, PreludeGenerator}; #[cfg(feature = "verbose-backtrace")] use crate::binrw::backtrace::BacktraceFrame; +use crate::binrw::codegen::PosEmitter; use crate::binrw::parser::Assert; use crate::{ binrw::{ @@ -26,8 +27,9 @@ pub(super) fn generate_unit_struct( input: &Input, name: Option<&Ident>, variant_ident: Option<&Ident>, + pos_emitter: &PosEmitter, ) -> TokenStream { - let prelude = get_prelude(input, name); + let prelude = get_prelude(input, name, false, pos_emitter); let return_type = get_return_type(variant_ident); quote! { #prelude @@ -35,8 +37,13 @@ pub(super) fn generate_unit_struct( } } -pub(super) fn generate_struct(input: &Input, name: Option<&Ident>, st: &Struct) -> TokenStream { - StructGenerator::new(input, st) +pub(super) fn generate_struct( + input: &Input, + name: Option<&Ident>, + st: &Struct, + pos_emitter: &PosEmitter, +) -> TokenStream { + StructGenerator::new(input, st, pos_emitter) .read_fields(name, None) .initialize_value_with_assertions(None, &[]) .return_value() @@ -47,14 +54,20 @@ pub(super) struct StructGenerator<'input> { input: &'input Input, st: &'input Struct, out: TokenStream, + pos_emitter: &'input PosEmitter, } impl<'input> StructGenerator<'input> { - pub(super) fn new(input: &'input Input, st: &'input Struct) -> Self { + pub(super) fn new( + input: &'input Input, + st: &'input Struct, + pos_emitter: &'input PosEmitter, + ) -> Self { Self { input, st, out: TokenStream::new(), + pos_emitter, } } @@ -85,8 +98,8 @@ impl<'input> StructGenerator<'input> { } fn add_assertions(mut self, extra_assertions: &[Assert]) -> Self { - let assertions = - get_assertions(&self.st.assertions).chain(get_assertions(extra_assertions)); + let assertions = get_assertions(self.pos_emitter, &self.st.assertions) + .chain(get_assertions(self.pos_emitter, extra_assertions)); let head = self.out; self.out = quote! { #head @@ -97,12 +110,12 @@ impl<'input> StructGenerator<'input> { } pub(super) fn read_fields(mut self, name: Option<&Ident>, variant_name: Option<&str>) -> Self { - let prelude = get_prelude(self.input, name); - let read_fields = self - .st - .fields - .iter() - .map(|field| generate_field(self.input, field, name, variant_name)); + let prelude = get_prelude(self.input, name, variant_name.is_some(), self.pos_emitter); + let mut pos_emitter = Some(self.pos_emitter); + let read_fields = + self.st.fields.iter().map(|field| { + generate_field(self.input, field, name, variant_name, pos_emitter.take()) + }); self.out = quote! { #prelude #(#read_fields)* @@ -146,13 +159,14 @@ fn generate_field( field: &StructField, name: Option<&Ident>, variant_name: Option<&str>, + pos_emitter: Option<&PosEmitter>, ) -> TokenStream { // temp + ignore == just don't bother if field.is_temp(false) && matches!(field.field_mode, FieldMode::Default) { return TokenStream::new(); } - FieldGenerator::new(input, field) + FieldGenerator::new(input, field, pos_emitter) .read_value() .wrap_map_stream() .try_conversion(name, variant_name) @@ -177,12 +191,22 @@ struct FieldGenerator<'field> { reader_var: TokenStream, endian_var: TokenStream, args_var: Option, + pos_emitter: Cow<'field, PosEmitter>, } impl<'field> FieldGenerator<'field> { - fn new(input: &Input, field: &'field StructField) -> Self { + fn new( + input: &Input, + field: &'field StructField, + pos_emitter: Option<&'field PosEmitter>, + ) -> Self { let (reader_var, endian_var, args_var) = make_field_vars(input, field); + let pos_emitter = pos_emitter.map_or_else( + || Cow::Owned(PosEmitter::new(&input.stream_ident_or(READER))), + Cow::Borrowed, + ); + Self { field, out: TokenStream::new(), @@ -190,6 +214,7 @@ impl<'field> FieldGenerator<'field> { reader_var, endian_var, args_var, + pos_emitter, } } @@ -259,7 +284,7 @@ impl<'field> FieldGenerator<'field> { } fn append_assertions(mut self) -> Self { - let assertions = get_assertions(&self.field.assertions); + let assertions = get_assertions(&self.pos_emitter, &self.field.assertions); let head = self.out; self.out = quote! { #head @@ -279,7 +304,16 @@ impl<'field> FieldGenerator<'field> { } fn finish(self) -> TokenStream { - self.out + let rest = self.out; + if let Cow::Owned(pos_emitter) = self.pos_emitter { + let pos = pos_emitter.finish(); + quote! { + #pos + #rest + } + } else { + rest + } } fn map_value(mut self) -> Self { @@ -586,13 +620,20 @@ fn get_err_context( } } -fn get_prelude(input: &Input, name: Option<&Ident>) -> TokenStream { - PreludeGenerator::new(input) +fn get_prelude( + input: &Input, + name: Option<&Ident>, + is_enum_variant: bool, + pos_emitter: &PosEmitter, +) -> TokenStream { + let mut prelude = PreludeGenerator::new(input, pos_emitter) .add_imports(name) .add_endian() - .add_magic_pre_assertion() - .add_map_stream() - .finish() + .add_magic_pre_assertion(); + if is_enum_variant { + prelude = prelude.reset_position_after_magic(); + } + prelude.add_map_stream().finish() } fn generate_seek_after(reader_var: &TokenStream, field: &StructField) -> TokenStream { diff --git a/binrw_derive/src/binrw/codegen/write_options.rs b/binrw_derive/src/binrw/codegen/write_options.rs index abbe386e..b418b8be 100644 --- a/binrw_derive/src/binrw/codegen/write_options.rs +++ b/binrw_derive/src/binrw/codegen/write_options.rs @@ -5,7 +5,10 @@ mod struct_field; use super::get_map_err; use crate::binrw::{ - codegen::sanitization::{OPT, POS, SEEK_TRAIT, WRITER, WRITE_METHOD}, + codegen::{ + sanitization::{OPT, POS, WRITER, WRITE_METHOD}, + PosEmitter, + }, parser::{Input, Map}, }; use proc_macro2::TokenStream; @@ -16,10 +19,14 @@ use syn::{spanned::Spanned, Ident}; pub(crate) fn generate(input: &Input, derive_input: &syn::DeriveInput) -> TokenStream { let name = Some(&derive_input.ident); + let writer_var = input.stream_ident_or(WRITER); + let pos_emitter = PosEmitter::new(&writer_var); let inner = match input.map() { Map::None => match input { - Input::UnitStruct(s) | Input::Struct(s) => generate_struct(input, name, s), - Input::Enum(e) => generate_data_enum(input, name, e), + Input::UnitStruct(s) | Input::Struct(s) => { + generate_struct(input, name, s, &pos_emitter) + } + Input::Enum(e) => generate_data_enum(input, name, e, &pos_emitter), Input::UnitOnlyEnum(e) => generate_unit_enum(input, name, e), }, Map::Try(map) | Map::Map(map) => generate_map(input, name, map), @@ -29,14 +36,14 @@ pub(crate) fn generate(input: &Input, derive_input: &syn::DeriveInput) -> TokenS }, }; - let writer_var = input.stream_ident_or(WRITER); - + let (set_pos, rewind) = super::get_rewind(input, &writer_var, pos_emitter); quote! { let #writer_var = #WRITER; - let #POS = #SEEK_TRAIT::stream_position(#writer_var)?; - #inner - - Ok(()) + #set_pos + (|| { + #inner + Ok(()) + })()#rewind } } diff --git a/binrw_derive/src/binrw/codegen/write_options/enum.rs b/binrw_derive/src/binrw/codegen/write_options/enum.rs index 1d135a08..1d29fbbe 100644 --- a/binrw_derive/src/binrw/codegen/write_options/enum.rs +++ b/binrw_derive/src/binrw/codegen/write_options/enum.rs @@ -1,6 +1,9 @@ use super::{prelude::PreludeGenerator, r#struct::StructGenerator}; use crate::binrw::{ - codegen::sanitization::{OPT, WRITER, WRITE_METHOD}, + codegen::{ + sanitization::{OPT, WRITER, WRITE_METHOD}, + PosEmitter, + }, parser::{Enum, EnumVariant, Input, UnitEnumField, UnitOnlyEnum}, }; use proc_macro2::{Ident, TokenStream}; @@ -25,8 +28,13 @@ pub(crate) fn generate_unit_enum( .finish() } -pub(crate) fn generate_data_enum(input: &Input, name: Option<&Ident>, en: &Enum) -> TokenStream { - EnumGenerator::new(input, name, en, input.stream_ident_or(WRITER)) +pub(super) fn generate_data_enum( + input: &Input, + name: Option<&Ident>, + en: &Enum, + pos_emitter: &PosEmitter, +) -> TokenStream { + EnumGenerator::new(input, name, en, input.stream_ident_or(WRITER), pos_emitter) .write_variants() .prefix_prelude() .finish() @@ -38,6 +46,7 @@ struct EnumGenerator<'a> { name: Option<&'a Ident>, writer_var: TokenStream, out: TokenStream, + pos_emitter: &'a PosEmitter, } impl<'a> EnumGenerator<'a> { @@ -46,6 +55,7 @@ impl<'a> EnumGenerator<'a> { name: Option<&'a Ident>, en: &'a Enum, writer_var: TokenStream, + pos_emitter: &'a PosEmitter, ) -> Self { Self { input, @@ -53,6 +63,7 @@ impl<'a> EnumGenerator<'a> { en, writer_var, out: TokenStream::new(), + pos_emitter, } } @@ -67,7 +78,7 @@ impl<'a> EnumGenerator<'a> { let writer_var = &self.writer_var; let writing = match variant { EnumVariant::Variant { options, .. } => { - StructGenerator::new(None, options, None, &self.writer_var) + StructGenerator::new(None, options, None, &self.writer_var, self.pos_emitter) .write_fields() .prefix_prelude() .finish() diff --git a/binrw_derive/src/binrw/codegen/write_options/struct.rs b/binrw_derive/src/binrw/codegen/write_options/struct.rs index a51cd2da..06139709 100644 --- a/binrw_derive/src/binrw/codegen/write_options/struct.rs +++ b/binrw_derive/src/binrw/codegen/write_options/struct.rs @@ -1,35 +1,48 @@ use super::{prelude::PreludeGenerator, struct_field::write_field}; use crate::binrw::{ - codegen::{get_assertions, sanitization::WRITER}, + codegen::{get_assertions, sanitization::WRITER, PosEmitter}, parser::{Input, Struct}, }; use proc_macro2::TokenStream; use quote::quote; use syn::Ident; -pub(super) fn generate_struct(input: &Input, name: Option<&Ident>, st: &Struct) -> TokenStream { - StructGenerator::new(Some(input), st, name, &input.stream_ident_or(WRITER)) - .write_fields() - .prefix_assertions() - .prefix_prelude() - .prefix_borrow_fields() - .finish() +pub(super) fn generate_struct( + input: &Input, + name: Option<&Ident>, + st: &Struct, + pos_emitter: &PosEmitter, +) -> TokenStream { + StructGenerator::new( + Some(input), + st, + name, + &input.stream_ident_or(WRITER), + pos_emitter, + ) + .write_fields() + .prefix_assertions() + .prefix_prelude() + .prefix_borrow_fields() + .finish() } -pub(crate) struct StructGenerator<'input> { +pub(super) struct StructGenerator<'input> { input: Option<&'input Input>, st: &'input Struct, name: Option<&'input Ident>, writer_var: &'input TokenStream, out: TokenStream, + pos_emitter: &'input PosEmitter, } impl<'input> StructGenerator<'input> { - pub(crate) fn new( + pub(super) fn new( input: Option<&'input Input>, st: &'input Struct, name: Option<&'input Ident>, writer_var: &'input TokenStream, + pos_emitter: &'input PosEmitter, ) -> Self { Self { input, @@ -37,6 +50,7 @@ impl<'input> StructGenerator<'input> { name, writer_var, out: TokenStream::new(), + pos_emitter, } } @@ -52,7 +66,7 @@ impl<'input> StructGenerator<'input> { } fn prefix_assertions(mut self) -> Self { - let assertions = get_assertions(&self.st.assertions); + let assertions = get_assertions(self.pos_emitter, &self.st.assertions); let out = self.out; self.out = quote! { @@ -68,7 +82,7 @@ impl<'input> StructGenerator<'input> { .st .fields .iter() - .map(|field| write_field(self.writer_var, field)); + .map(|field| write_field(self.writer_var, field, self.pos_emitter)); self.out = quote! { #(#write_fields)* diff --git a/binrw_derive/src/binrw/codegen/write_options/struct_field.rs b/binrw_derive/src/binrw/codegen/write_options/struct_field.rs index 1e8a310b..8577cc9a 100644 --- a/binrw_derive/src/binrw/codegen/write_options/struct_field.rs +++ b/binrw_derive/src/binrw/codegen/write_options/struct_field.rs @@ -10,6 +10,7 @@ use crate::{ WRITE_MAP_ARGS_TYPE_HINT, WRITE_MAP_INPUT_TYPE_HINT, WRITE_METHOD, WRITE_TRY_MAP_ARGS_TYPE_HINT, WRITE_ZEROES, }, + PosEmitter, }, parser::{FieldMode, Map, StructField}, }, @@ -21,8 +22,12 @@ use proc_macro2::TokenStream; use quote::{quote, quote_spanned, ToTokens}; use syn::{spanned::Spanned, Ident}; -pub(crate) fn write_field(writer_var: &TokenStream, field: &StructField) -> TokenStream { - StructFieldGenerator::new(field, writer_var) +pub(super) fn write_field( + writer_var: &TokenStream, + field: &StructField, + pos_emitter: &PosEmitter, +) -> TokenStream { + StructFieldGenerator::new(field, writer_var, pos_emitter) .write_field() .wrap_map_stream() .prefix_map_value() @@ -42,24 +47,32 @@ struct StructFieldGenerator<'input> { outer_writer_var: &'input TokenStream, writer_var: Cow<'input, TokenStream>, out: TokenStream, + pos_emitter: &'input PosEmitter, } impl<'a> StructFieldGenerator<'a> { - fn new(field: &'a StructField, outer_writer_var: &'a TokenStream) -> Self { + fn new( + field: &'a StructField, + outer_writer_var: &'a TokenStream, + pos_emitter: &'a PosEmitter, + ) -> Self { + let writer_var = if field.map_stream.is_some() { + Cow::Owned(make_ident(&field.ident, "reader").into_token_stream()) + } else { + Cow::Borrowed(outer_writer_var) + }; + Self { field, outer_writer_var, - writer_var: if field.map_stream.is_some() { - Cow::Owned(make_ident(&field.ident, "reader").into_token_stream()) - } else { - Cow::Borrowed(outer_writer_var) - }, + writer_var, out: TokenStream::new(), + pos_emitter, } } fn prefix_assertions(mut self) -> Self { - let assertions = get_assertions(&self.field.assertions); + let assertions = get_assertions(self.pos_emitter, &self.field.assertions); let out = self.out; self.out = quote! {