Skip to content

Commit

Permalink
Call visit_item_fn_mut in a custom
Browse files Browse the repository at this point in the history
visit_item_fn_mut can cause wired behavior
because it will also visit inner functions.
A clean solution to #189 is to take the responsibility
of the two extraction separated.
Also we removed the inner visit_item_fn_mut for
all others extractions because are useless and can cause
also wired bugs
  • Loading branch information
la10736 committed Apr 8, 2023
1 parent d891ee3 commit 530a8c4
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 33 deletions.
4 changes: 2 additions & 2 deletions rstest/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@ default = ["async-timeout"]
[dependencies]
futures = {version = "0.3.21", optional = true}
futures-timer = {version = "3.0.2", optional = true}
rstest_macros = {version = "0.17.0", path = "../rstest_macros", default-features = false}
rstest_macros = {version = "0.18.0", path = "../rstest_macros", default-features = false}

[dev-dependencies]
actix-rt = "2.7.0"
async-std = {version = "1.12.0", features = ["attributes"]}
lazy_static = "1.4.0"
mytest = {package = "rstest", version = "0.16.0", default-features = false}
mytest = {package = "rstest", version = "0.17.0", default-features = false}
pretty_assertions = "1.2.1"
rstest_reuse = {version = "0.5.0", path = "../rstest_reuse"}
rstest_test = {version = "0.11.0", path = "../rstest_test"}
Expand Down
2 changes: 1 addition & 1 deletion rstest_macros/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ keywords = ["test", "fixture"]
license = "MIT/Apache-2.0"
name = "rstest_macros"
repository = "https://github.com/la10736/rstest"
version = "0.17.0"
version = "0.18.0"

[lib]
proc-macro = true
Expand Down
9 changes: 5 additions & 4 deletions rstest_macros/src/parse/fixture.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use syn::{

use super::{
arguments::ArgumentsInfo, extract_argument_attrs, extract_default_return_type,
extract_defaults, extract_fixtures, extract_partials_return_type, future::extract_futures,
extract_defaults, extract_fixtures, extract_partials_return_type, future::{extract_futures, extract_global_awt},
parse_vector_trailing_till_double_comma, Attributes, ExtendWithFunctionAttrs, Fixture,
};
use crate::{error::ErrorsVec, parse::extract_once, refident::RefIdent, utils::attr_is};
Expand Down Expand Up @@ -57,14 +57,16 @@ impl ExtendWithFunctionAttrs for FixtureInfo {
default_return_type,
partials_return_type,
once,
futures
futures,
global_awt
) = merge_errors!(
extract_fixtures(item_fn),
extract_defaults(item_fn),
extract_default_return_type(item_fn),
extract_partials_return_type(item_fn),
extract_once(item_fn),
extract_futures(item_fn)
extract_futures(item_fn),
extract_global_awt(item_fn)
)?;
self.data.items.extend(
fixtures
Expand All @@ -79,7 +81,6 @@ impl ExtendWithFunctionAttrs for FixtureInfo {
self.attributes.set_partial_return_type(id, return_type);
}
self.arguments.set_once(once);
let (futures, global_awt) = futures;
self.arguments.set_global_await(global_awt);
self.arguments.set_futures(futures.into_iter());
Ok(())
Expand Down
55 changes: 41 additions & 14 deletions rstest_macros/src/parse/future.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,18 @@ use super::{arguments::FutureArg, extract_argument_attrs};

pub(crate) fn extract_futures(
item_fn: &mut ItemFn,
) -> Result<(Vec<(Ident, FutureArg)>, bool), ErrorsVec> {
) -> Result<Vec<(Ident, FutureArg)>, ErrorsVec> {
let mut extractor = FutureFunctionExtractor::default();
extractor.visit_item_fn_mut(item_fn);
extractor.take()
}

pub(crate) fn extract_global_awt(item_fn: &mut ItemFn) -> Result<bool, ErrorsVec> {
let mut extractor = GlobalAwtExtractor::default();
extractor.visit_item_fn_mut(item_fn);
extractor.take()
}

pub(crate) trait MaybeFutureImplType {
fn as_future_impl_type(&self) -> Option<&Type>;

Expand Down Expand Up @@ -50,26 +56,24 @@ fn can_impl_future(ty: &Type) -> bool {
)
}

/// Simple struct used to visit function attributes and extract future args to
/// implement the boilerplate.
/// Simple struct used to visit function attributes and extract global awt.
#[derive(Default)]
struct FutureFunctionExtractor {
futures: Vec<(Ident, FutureArg)>,
struct GlobalAwtExtractor {
awt: bool,
errors: Vec<syn::Error>,
}

impl FutureFunctionExtractor {
pub(crate) fn take(self) -> Result<(Vec<(Ident, FutureArg)>, bool), ErrorsVec> {
impl GlobalAwtExtractor {
pub(crate) fn take(self) -> Result<bool, ErrorsVec> {
if self.errors.is_empty() {
Ok((self.futures, self.awt))
Ok(self.awt)
} else {
Err(self.errors.into())
}
}
}

impl VisitMut for FutureFunctionExtractor {
impl VisitMut for GlobalAwtExtractor {
fn visit_item_fn_mut(&mut self, node: &mut ItemFn) {
let attrs = std::mem::take(&mut node.attrs);
let (awts, remain): (Vec<_>, Vec<_>) = attrs.into_iter().partition(|a| attr_is(a, "awt"));
Expand All @@ -87,9 +91,28 @@ impl VisitMut for FutureFunctionExtractor {
std::cmp::Ordering::Less => false,
};
node.attrs = remain;
syn::visit_mut::visit_item_fn_mut(self, node);
}
}

/// Simple struct used to visit function attributes and extract future args to
/// implement the boilerplate.
#[derive(Default)]
struct FutureFunctionExtractor {
futures: Vec<(Ident, FutureArg)>,
errors: Vec<syn::Error>,
}

impl FutureFunctionExtractor {
pub(crate) fn take(self) -> Result<Vec<(Ident, FutureArg)>, ErrorsVec> {
if self.errors.is_empty() {
Ok(self.futures)
} else {
Err(self.errors.into())
}
}
}

impl VisitMut for FutureFunctionExtractor {
fn visit_fn_arg_mut(&mut self, node: &mut FnArg) {
if matches!(node, FnArg::Receiver(_)) {
return;
Expand Down Expand Up @@ -158,7 +181,8 @@ mod should {
let mut item_fn: ItemFn = item_fn.ast();
let orig = item_fn.clone();

let (futures, awt) = extract_futures(&mut item_fn).unwrap();
let composed_tuple!(futures, awt) =
merge_errors!(extract_futures(&mut item_fn), extract_global_awt(&mut item_fn)).unwrap();

assert_eq!(orig, item_fn);
assert!(futures.is_empty());
Expand All @@ -168,6 +192,7 @@ mod should {
#[rstest]
#[case::simple("fn f(#[future] a: u32) {}", "fn f(a: u32) {}", &[("a", FutureArg::Define)], false)]
#[case::global_awt("#[awt] fn f(a: u32) {}", "fn f(a: u32) {}", &[], true)]
#[case::global_awt_with_inner_function("#[awt] fn f(a: u32) { fn g(){} }", "fn f(a: u32) { fn g(){} }", &[], true)]
#[case::simple_awaited("fn f(#[future(awt)] a: u32) {}", "fn f(a: u32) {}", &[("a", FutureArg::Await)], false)]
#[case::simple_awaited_and_global("#[awt] fn f(#[future(awt)] a: u32) {}", "fn f(a: u32) {}", &[("a", FutureArg::Await)], true)]
#[case::more_than_one(
Expand Down Expand Up @@ -199,7 +224,8 @@ mod should {
let mut item_fn: ItemFn = item_fn.ast();
let expected: ItemFn = expected.ast();

let (futures, awt) = extract_futures(&mut item_fn).unwrap();
let composed_tuple!(futures, awt) =
merge_errors!(extract_futures(&mut item_fn), extract_global_awt(&mut item_fn)).unwrap();

assert_eq!(expected, item_fn);
assert_eq!(
Expand Down Expand Up @@ -239,7 +265,7 @@ mod should {
let mut item_fn: ItemFn = item_fn.ast();
let expected: ItemFn = expected.ast();

let _ = extract_futures(&mut item_fn);
let _ = extract_global_awt(&mut item_fn);

assert_eq!(item_fn, expected);
}
Expand All @@ -253,7 +279,8 @@ mod should {
fn raise_error(#[case] item_fn: &str, #[case] message: &str) {
let mut item_fn: ItemFn = item_fn.ast();

let err = extract_futures(&mut item_fn).unwrap_err();
let err =
merge_errors!(extract_futures(&mut item_fn), extract_global_awt(&mut item_fn)).unwrap_err();

assert_in!(format!("{:?}", err), message);
}
Expand Down
11 changes: 3 additions & 8 deletions rstest_macros/src/parse/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -295,8 +295,6 @@ impl VisitMut for DefaultTypeFunctionExtractor {
} else {
Ok(data)
};

syn::visit_mut::visit_item_fn_mut(self, node);
}
}

Expand Down Expand Up @@ -354,8 +352,6 @@ impl VisitMut for PartialsTypeFunctionExtractor {
} else {
Ok(data)
};

syn::visit_mut::visit_item_fn_mut(self, node);
}
}

Expand Down Expand Up @@ -392,7 +388,6 @@ impl VisitMut for IsOnceAttributeFunctionExtractor {
.collect::<Vec<_>>()
.into()),
};
syn::visit_mut::visit_item_fn_mut(self, node);
}
}

Expand Down Expand Up @@ -439,7 +434,8 @@ impl VisitMut for CasesFunctionExtractor {
if attr_starts_with(&attr, &case) {
match attr.parse_args::<Expressions>() {
Ok(expressions) => {
let description = attr.path().segments.iter().nth(1).map(|p| p.ident.clone());
let description =
attr.path().segments.iter().nth(1).map(|p| p.ident.clone());
self.0.push(TestCase {
args: expressions.into(),
attrs: std::mem::take(&mut attrs_buffer),
Expand All @@ -453,7 +449,6 @@ impl VisitMut for CasesFunctionExtractor {
}
}
node.attrs = std::mem::take(&mut attrs_buffer);
syn::visit_mut::visit_item_fn_mut(self, node);
}
}

Expand Down Expand Up @@ -674,7 +669,7 @@ pub(crate) mod arguments {
pub(crate) struct ArgumentsInfo {
args: HashMap<Ident, ArgumentInfo>,
is_global_await: bool,
once: Option<Ident>
once: Option<Ident>,
}

impl ArgumentsInfo {
Expand Down
8 changes: 4 additions & 4 deletions rstest_macros/src/parse/rstest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use syn::{
Ident, ItemFn, Token,
};

use super::testcase::TestCase;
use super::{testcase::TestCase, future::extract_global_awt};
use super::{
arguments::ArgumentsInfo, check_timeout_attrs, extract_case_args, extract_cases,
extract_excluded_trace, extract_fixtures, extract_value_list, future::extract_futures,
Expand Down Expand Up @@ -44,13 +44,13 @@ impl Parse for RsTestInfo {

impl ExtendWithFunctionAttrs for RsTestInfo {
fn extend_with_function_attrs(&mut self, item_fn: &mut ItemFn) -> Result<(), ErrorsVec> {
let composed_tuple!(_inner, excluded, _timeout, futures) = merge_errors!(
let composed_tuple!(_inner, excluded, _timeout, futures, global_awt) = merge_errors!(
self.data.extend_with_function_attrs(item_fn),
extract_excluded_trace(item_fn),
check_timeout_attrs(item_fn),
extract_futures(item_fn)
extract_futures(item_fn),
extract_global_awt(item_fn)
)?;
let (futures, global_awt) = futures;
self.attributes.add_notraces(excluded);
self.arguments.set_global_await(global_awt);
self.arguments.set_futures(futures.into_iter());
Expand Down

0 comments on commit 530a8c4

Please sign in to comment.