Skip to content

Commit

Permalink
Implement usage tracking in codegen for stream_position calls
Browse files Browse the repository at this point in the history
This should increase the accuracy of positioning information in
errors and ensures `stream_position` calls will only be emitted
when the result may actually be used.

There are still too many `stream_position` calls in part because
there is no API to coordinate between parent and child objects to
prevent children from doing their own position management when the
parent is already doing it. Improving this will probably require
some kind of rudimentary stack tracking, or at least to split the
API so that `read_options` is an entrypoint and then there is e.g.
`read_inner` that does not try to restore position on error.

In future, it should be possible for authors to configure trading
off worse error handling with performance by e.g. making error
position `Option<u64>` and emitting `None`s instead of a variable
reference, and giving some option to not leave the stream in a
known good state on failure.

Refs jam1garner#220.
  • Loading branch information
csnover committed Oct 25, 2023
1 parent 6b2173b commit 13ae99e
Show file tree
Hide file tree
Showing 10 changed files with 322 additions and 135 deletions.
2 changes: 1 addition & 1 deletion binrw/tests/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ fn no_seek_data_enum() {
(BacktraceFrame::Message(m), BacktraceFrame::Custom(e)) => {
assert_eq!(m, "rewinding after a failure");
match e.downcast_ref::<binrw::Error>() {
Some(binrw::Error::AssertFail { pos, .. }) => assert_eq!(*pos, 0),
Some(binrw::Error::AssertFail { pos, .. }) => assert_eq!(*pos, 1),
e => panic!("unexpected error {:?}", e),
}
}
Expand Down
75 changes: 70 additions & 5 deletions binrw_derive/src/binrw/codegen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ mod write_options;

use crate::{
binrw::parser::{
Assert, AssertionError, CondEndian, Imports, Input, ParseResult, PassedArgs, StructField,
Assert, AssertionError, CondEndian, Imports, Input, Map, ParseResult, PassedArgs,
StructField,
},
named_args::{arg_type_name, derive_from_imports},
util::{quote_spanned_any, IdentStr},
Expand All @@ -14,8 +15,8 @@ use proc_macro2::{Span, TokenStream};
use quote::{quote, quote_spanned, ToTokens};
use sanitization::{
ARGS, ARGS_LIFETIME, ARGS_MACRO, ASSERT, ASSERT_ERROR_FN, BINREAD_TRAIT, BINWRITE_TRAIT,
BIN_ERROR, BIN_RESULT, ENDIAN_ENUM, OPT, POS, READER, READ_TRAIT, SEEK_TRAIT, TEMP, WRITER,
WRITE_TRAIT,
BIN_ERROR, BIN_RESULT, ENDIAN_ENUM, OPT, POS, READER, READ_TRAIT, RESTORE_POSITION, SEEK_TRAIT,
TEMP, WRITER, WRITE_TRAIT,
};
use syn::{spanned::Spanned, DeriveInput, Ident, Type};

Expand Down Expand Up @@ -194,11 +195,47 @@ fn generate_trait_impl<const WRITE: bool>(
}
}

#[derive(Clone)]
#[must_use]
struct PosEmitter {
used: core::cell::Cell<bool>,
stream_var: TokenStream,
}

impl PosEmitter {
fn new(stream_var: &TokenStream) -> Self {
Self {
used: core::cell::Cell::new(false),
stream_var: stream_var.clone(),
}
}

fn finish(self) -> TokenStream {
self.used
.get()
.then(|| {
let var = self.stream_var;
quote! {
let #POS = #SEEK_TRAIT::stream_position(#var)?;
}
})
.unwrap_or_default()
}

fn pos(&self) -> &'static IdentStr {
self.used.set(true);
&POS
}
}

fn get_args_lifetime(span: proc_macro2::Span) -> syn::Lifetime {
syn::Lifetime::new(&format!("'{ARGS_LIFETIME}"), span)
}

fn get_assertions(assertions: &[Assert]) -> impl Iterator<Item = TokenStream> + '_ {
fn get_assertions<'a>(
pos_emitter: &'a PosEmitter,
assertions: &'a [Assert],
) -> impl Iterator<Item = TokenStream> + 'a {
assertions.iter().map(
|Assert {
kw_span,
Expand All @@ -215,8 +252,10 @@ fn get_assertions(assertions: &[Assert]) -> impl Iterator<Item = TokenStream> +
}
};

let pos = pos_emitter.pos();

quote_spanned_any! {*kw_span=>
#ASSERT(#condition, #POS, #error_fn)?;
#ASSERT(#condition, #pos, #error_fn)?;
}
},
)
Expand Down Expand Up @@ -303,6 +342,32 @@ fn get_passed_args(field: &StructField, stream: IdentStr) -> Option<TokenStream>
}
}

fn get_rewind(
input: &Input,
var: &TokenStream,
pos_emitter: PosEmitter,
) -> (TokenStream, TokenStream) {
let needs_rewind = match input.map() {
Map::None => match input {
Input::UnitStruct(_) | Input::Enum(_) => input.magic().is_some(),
Input::Struct(_) => true,
Input::UnitOnlyEnum(e) => e.map.as_repr().is_some(),
},
Map::Try(_) | Map::Map(_) | Map::Repr(_) => true,
};

let rewind = needs_rewind
.then(|| {
let pos = pos_emitter.pos();
quote! {
.or_else(#RESTORE_POSITION::<binrw::Error, _, _>(#var, #pos))
}
})
.unwrap_or_default();

(pos_emitter.finish(), rewind)
}

fn get_try_calc(pos: IdentStr, ty: &Type, calc: &TokenStream) -> TokenStream {
let map_err = get_map_err(pos, calc.span());
quote_spanned! {ty.span()=> {
Expand Down
55 changes: 23 additions & 32 deletions binrw_derive/src/binrw/codegen/read_options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,13 @@ mod r#enum;
mod map;
mod r#struct;

use super::{get_assertions, get_destructured_imports};
use super::{get_assertions, get_destructured_imports, PosEmitter};
use crate::{
binrw::{
codegen::{
get_endian,
sanitization::{
ARGS, ASSERT_MAGIC, MAP_READER_TYPE_HINT, OPT, POS, READER, RESTORE_POSITION,
SEEK_TRAIT,
ARGS, ASSERT_MAGIC, MAP_READER_TYPE_HINT, OPT, POS, READER, SEEK_TRAIT,
},
},
parser::{Input, Magic, Map},
Expand All @@ -23,43 +22,33 @@ use r#struct::{generate_struct, generate_unit_struct};
use syn::{spanned::Spanned, Ident};

pub(crate) fn generate(input: &Input, derive_input: &syn::DeriveInput) -> TokenStream {
let reader_var = input.stream_ident_or(READER);
let name = Some(&derive_input.ident);
let (inner, needs_rewind) = match input.map() {
let pos_emitter = PosEmitter::new(&reader_var);
let inner = match input.map() {
Map::None => match input {
Input::UnitStruct(_) => (generate_unit_struct(input, name, None), false),
Input::Struct(s) => (generate_struct(input, name, s), true),
Input::Enum(e) => (generate_data_enum(input, name, e), false),
Input::UnitOnlyEnum(e) => (
generate_unit_enum(input, name, e),
e.map.as_repr().is_some(),
),
Input::UnitStruct(_) => generate_unit_struct(input, name, None, &pos_emitter),
Input::Struct(s) => generate_struct(input, name, s, &pos_emitter),
Input::Enum(e) => generate_data_enum(input, name, e, &pos_emitter),
Input::UnitOnlyEnum(e) => generate_unit_enum(input, name, e, &pos_emitter),
},
Map::Try(map) => (map::generate_try_map(input, name, map), true),
Map::Map(map) => (map::generate_map(input, name, map), true),
Map::Try(map) => map::generate_try_map(input, name, map, &pos_emitter),
Map::Map(map) => map::generate_map(input, name, map, &pos_emitter),
Map::Repr(ty) => match input {
Input::UnitOnlyEnum(e) => (generate_unit_enum(input, name, e), true),
_ => (
map::generate_try_map(
input,
name,
&quote! { <#ty as core::convert::TryInto<_>>::try_into },
),
true,
Input::UnitOnlyEnum(e) => generate_unit_enum(input, name, e, &pos_emitter),
_ => map::generate_try_map(
input,
name,
&quote! { <#ty as core::convert::TryInto<_>>::try_into },
&pos_emitter,
),
},
};

let reader_var = input.stream_ident_or(READER);

let rewind = (needs_rewind || input.magic().is_some()).then(|| {
quote! {
.or_else(#RESTORE_POSITION::<binrw::Error, _, _>(#reader_var, #POS))
}
});

let (set_pos, rewind) = super::get_rewind(input, &reader_var, pos_emitter);
quote! {
let #reader_var = #READER;
let #POS = #SEEK_TRAIT::stream_position(#reader_var)?;
#set_pos
(|| {
#inner
})()#rewind
Expand All @@ -70,15 +59,17 @@ struct PreludeGenerator<'input> {
input: &'input Input,
reader_var: TokenStream,
out: TokenStream,
pos_emitter: &'input PosEmitter,
}

impl<'input> PreludeGenerator<'input> {
fn new(input: &'input Input) -> Self {
fn new(input: &'input Input, pos_emitter: &'input PosEmitter) -> Self {
let reader_var = input.stream_ident_or(READER);
Self {
input,
reader_var,
out: TokenStream::new(),
pos_emitter,
}
}

Expand Down Expand Up @@ -111,7 +102,7 @@ impl<'input> PreludeGenerator<'input> {
fn add_magic_pre_assertion(mut self) -> Self {
let head = self.out;
let magic = get_magic(self.input.magic(), &self.reader_var, OPT);
let pre_assertions = get_assertions(self.input.pre_assertions());
let pre_assertions = get_assertions(self.pos_emitter, self.input.pre_assertions());
self.out = quote! {
#head
#magic
Expand Down

0 comments on commit 13ae99e

Please sign in to comment.