Skip to content

Commit 94bdb6b

Browse files
authored
feature: PyTraverse derive macro for traverse object's childrens(like CPython's tp_traverse) (RustPython#4872)
1 parent 2c90b12 commit 94bdb6b

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

59 files changed

+940
-120
lines changed

derive-impl/src/lib.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ mod pyclass;
1818
mod pymodule;
1919
mod pypayload;
2020
mod pystructseq;
21+
mod pytraverse;
2122

2223
use error::{extract_spans, Diagnostic};
2324
use proc_macro2::TokenStream;
@@ -77,3 +78,7 @@ pub fn py_freeze(input: TokenStream, compiler: &dyn Compiler) -> TokenStream {
7778
pub fn pypayload(input: DeriveInput) -> TokenStream {
7879
result_to_tokens(pypayload::impl_pypayload(input))
7980
}
81+
82+
pub fn pytraverse(item: DeriveInput) -> TokenStream {
83+
result_to_tokens(pytraverse::impl_pytraverse(item))
84+
}

derive-impl/src/pyclass.rs

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,8 +413,59 @@ pub(crate) fn impl_pyclass(attr: AttributeArgs, item: Item) -> Result<TokenStrea
413413
attrs,
414414
)?;
415415

416+
const ALLOWED_TRAVERSE_OPTS: &[&str] = &["manual"];
417+
// try to know if it have a `#[pyclass(trace)]` exist on this struct
418+
// TODO(discord9): rethink on auto detect `#[Derive(PyTrace)]`
419+
420+
// 1. no `traverse` at all: generate a dummy try_traverse
421+
// 2. `traverse = "manual"`: generate a try_traverse, but not #[derive(Traverse)]
422+
// 3. `traverse`: generate a try_traverse, and #[derive(Traverse)]
423+
let (maybe_trace_code, derive_trace) = {
424+
if class_meta.inner()._has_key("traverse")? {
425+
let maybe_trace_code = quote! {
426+
impl ::rustpython_vm::object::MaybeTraverse for #ident {
427+
const IS_TRACE: bool = true;
428+
fn try_traverse(&self, tracer_fn: &mut ::rustpython_vm::object::TraverseFn) {
429+
::rustpython_vm::object::Traverse::traverse(self, tracer_fn);
430+
}
431+
}
432+
};
433+
// if the key `traverse` exist but not as key-value, _optional_str return Err(...)
434+
// so we need to check if it is Ok(Some(...))
435+
let value = class_meta.inner()._optional_str("traverse");
436+
let derive_trace = if let Ok(Some(s)) = value {
437+
if !ALLOWED_TRAVERSE_OPTS.contains(&s.as_str()) {
438+
bail_span!(
439+
item,
440+
"traverse attribute only accept {ALLOWED_TRAVERSE_OPTS:?} as value or no value at all",
441+
);
442+
}
443+
assert_eq!(s, "manual");
444+
quote! {}
445+
} else {
446+
quote! {#[derive(Traverse)]}
447+
};
448+
(maybe_trace_code, derive_trace)
449+
} else {
450+
(
451+
// a dummy impl, which do nothing
452+
// #attrs
453+
quote! {
454+
impl ::rustpython_vm::object::MaybeTraverse for #ident {
455+
fn try_traverse(&self, tracer_fn: &mut ::rustpython_vm::object::TraverseFn) {
456+
// do nothing
457+
}
458+
}
459+
},
460+
quote! {},
461+
)
462+
}
463+
};
464+
416465
let ret = quote! {
466+
#derive_trace
417467
#item
468+
#maybe_trace_code
418469
#class_def
419470
};
420471
Ok(ret)

derive-impl/src/pytraverse.rs

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
use proc_macro2::TokenStream;
2+
use quote::quote;
3+
use syn::{Attribute, DeriveInput, Field, Meta, MetaList, NestedMeta, Result};
4+
5+
struct TraverseAttr {
6+
/// set to `true` if the attribute is `#[pytraverse(skip)]`
7+
skip: bool,
8+
}
9+
10+
const ATTR_TRAVERSE: &str = "pytraverse";
11+
12+
/// get the `#[pytraverse(..)]` attribute from the struct
13+
fn valid_get_traverse_attr_from_meta_list(list: &MetaList) -> Result<TraverseAttr> {
14+
let find_skip_and_only_skip = || {
15+
let len = list.nested.len();
16+
if len != 1 {
17+
return None;
18+
}
19+
let mut iter = list.nested.iter();
20+
// we have checked the length, so unwrap is safe
21+
let first_arg = iter.next().unwrap();
22+
let skip = match first_arg {
23+
NestedMeta::Meta(Meta::Path(path)) => match path.is_ident("skip") {
24+
true => true,
25+
false => return None,
26+
},
27+
_ => return None,
28+
};
29+
Some(skip)
30+
};
31+
let skip = find_skip_and_only_skip().ok_or_else(|| {
32+
err_span!(
33+
list,
34+
"only support attr is #[pytraverse(skip)], got arguments: {:?}",
35+
list.nested
36+
)
37+
})?;
38+
Ok(TraverseAttr { skip })
39+
}
40+
41+
/// only accept `#[pytraverse(skip)]` for now
42+
fn pytraverse_arg(attr: &Attribute) -> Option<Result<TraverseAttr>> {
43+
if !attr.path.is_ident(ATTR_TRAVERSE) {
44+
return None;
45+
}
46+
let ret = || {
47+
let parsed = attr.parse_meta()?;
48+
if let Meta::List(list) = parsed {
49+
valid_get_traverse_attr_from_meta_list(&list)
50+
} else {
51+
bail_span!(attr, "pytraverse must be a list, like #[pytraverse(skip)]")
52+
}
53+
};
54+
Some(ret())
55+
}
56+
57+
fn field_to_traverse_code(field: &Field) -> Result<TokenStream> {
58+
let pytraverse_attrs = field
59+
.attrs
60+
.iter()
61+
.filter_map(pytraverse_arg)
62+
.collect::<std::result::Result<Vec<_>, _>>()?;
63+
let do_trace = if pytraverse_attrs.len() > 1 {
64+
bail_span!(
65+
field,
66+
"found multiple #[pytraverse] attributes on the same field, expect at most one"
67+
)
68+
} else if pytraverse_attrs.is_empty() {
69+
// default to always traverse every field
70+
true
71+
} else {
72+
!pytraverse_attrs[0].skip
73+
};
74+
let name = field.ident.as_ref().ok_or_else(|| {
75+
syn::Error::new_spanned(
76+
field.clone(),
77+
"Field should have a name in non-tuple struct",
78+
)
79+
})?;
80+
if do_trace {
81+
Ok(quote!(
82+
::rustpython_vm::object::Traverse::traverse(&self.#name, tracer_fn);
83+
))
84+
} else {
85+
Ok(quote!())
86+
}
87+
}
88+
89+
/// not trace corresponding field
90+
fn gen_trace_code(item: &mut DeriveInput) -> Result<TokenStream> {
91+
match &mut item.data {
92+
syn::Data::Struct(s) => {
93+
let fields = &mut s.fields;
94+
match fields {
95+
syn::Fields::Named(ref mut fields) => {
96+
let res: Vec<TokenStream> = fields
97+
.named
98+
.iter_mut()
99+
.map(|f| -> Result<TokenStream> { field_to_traverse_code(f) })
100+
.collect::<Result<_>>()?;
101+
let res = res.into_iter().collect::<TokenStream>();
102+
Ok(res)
103+
}
104+
syn::Fields::Unnamed(fields) => {
105+
let res: TokenStream = (0..fields.unnamed.len())
106+
.map(|i| {
107+
let i = syn::Index::from(i);
108+
quote!(
109+
::rustpython_vm::object::Traverse::traverse(&self.#i, tracer_fn);
110+
)
111+
})
112+
.collect();
113+
Ok(res)
114+
}
115+
_ => Err(syn::Error::new_spanned(
116+
fields,
117+
"Only named and unnamed fields are supported",
118+
)),
119+
}
120+
}
121+
_ => Err(syn::Error::new_spanned(item, "Only structs are supported")),
122+
}
123+
}
124+
125+
pub(crate) fn impl_pytraverse(mut item: DeriveInput) -> Result<TokenStream> {
126+
let trace_code = gen_trace_code(&mut item)?;
127+
128+
let ty = &item.ident;
129+
130+
let ret = quote! {
131+
unsafe impl ::rustpython_vm::object::Traverse for #ty {
132+
fn traverse(&self, tracer_fn: &mut ::rustpython_vm::object::TraverseFn) {
133+
#trace_code
134+
}
135+
}
136+
};
137+
Ok(ret)
138+
}

derive-impl/src/util.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,10 @@ impl ItemMetaInner {
178178
Ok(value)
179179
}
180180

181+
pub fn _has_key(&self, key: &str) -> Result<bool> {
182+
Ok(matches!(self.meta_map.get(key), Some((_, _))))
183+
}
184+
181185
pub fn _bool(&self, key: &str) -> Result<bool> {
182186
let value = if let Some((_, meta)) = self.meta_map.get(key) {
183187
match meta {
@@ -263,8 +267,14 @@ impl ItemMeta for AttrItemMeta {
263267
pub(crate) struct ClassItemMeta(ItemMetaInner);
264268

265269
impl ItemMeta for ClassItemMeta {
266-
const ALLOWED_NAMES: &'static [&'static str] =
267-
&["module", "name", "base", "metaclass", "unhashable"];
270+
const ALLOWED_NAMES: &'static [&'static str] = &[
271+
"module",
272+
"name",
273+
"base",
274+
"metaclass",
275+
"unhashable",
276+
"traverse",
277+
];
268278

269279
fn from_inner(inner: ItemMetaInner) -> Self {
270280
Self(inner)

derive/src/lib.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,3 +91,27 @@ pub fn pypayload(input: TokenStream) -> TokenStream {
9191
let input = parse_macro_input!(input);
9292
derive_impl::pypayload(input).into()
9393
}
94+
95+
/// use on struct with named fields like `struct A{x:PyRef<B>, y:PyRef<C>}` to impl `Traverse` for datatype.
96+
///
97+
/// use `#[pytraverse(skip)]` on fields you wish not to trace
98+
///
99+
/// add `trace` attr to `#[pyclass]` to make it impl `MaybeTraverse` that will call `Traverse`'s `traverse` method so make it
100+
/// traceable(Even from type-erased PyObject)(i.e. write `#[pyclass(trace)]`).
101+
/// # Example
102+
/// ```rust, ignore
103+
/// #[pyclass(module = false, traverse)]
104+
/// #[derive(Default, Traverse)]
105+
/// pub struct PyList {
106+
/// elements: PyRwLock<Vec<PyObjectRef>>,
107+
/// #[pytraverse(skip)]
108+
/// len: AtomicCell<usize>,
109+
/// }
110+
/// ```
111+
/// This create both `MaybeTraverse` that call `Traverse`'s `traverse` method and `Traverse` that impl `Traverse`
112+
/// for `PyList` which call elements' `traverse` method and ignore `len` field.
113+
#[proc_macro_derive(Traverse, attributes(pytraverse))]
114+
pub fn pytraverse(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
115+
let item = parse_macro_input!(item);
116+
derive_impl::pytraverse(item).into()
117+
}

stdlib/src/array.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1399,7 +1399,7 @@ mod array {
13991399
}
14001400

14011401
#[pyattr]
1402-
#[pyclass(name = "arrayiterator")]
1402+
#[pyclass(name = "arrayiterator", traverse)]
14031403
#[derive(Debug, PyPayload)]
14041404
pub struct PyArrayIter {
14051405
internal: PyMutex<PositionIterInternal<PyArrayRef>>,
@@ -1434,12 +1434,13 @@ mod array {
14341434
}
14351435
}
14361436

1437-
#[derive(FromArgs)]
1437+
#[derive(FromArgs, Traverse)]
14381438
struct ReconstructorArgs {
14391439
#[pyarg(positional)]
14401440
arraytype: PyTypeRef,
14411441
#[pyarg(positional)]
14421442
typecode: PyStrRef,
1443+
#[pytraverse(skip)]
14431444
#[pyarg(positional)]
14441445
mformat_code: MachineFormatCode,
14451446
#[pyarg(positional)]

stdlib/src/bisect.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ mod _bisect {
88
PyObjectRef, PyResult, VirtualMachine,
99
};
1010

11-
#[derive(FromArgs)]
11+
#[derive(FromArgs, Traverse)]
1212
struct BisectArgs {
1313
a: PyObjectRef,
1414
x: PyObjectRef,

stdlib/src/contextvars.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,10 @@ mod _contextvars {
8080
}
8181

8282
#[pyattr]
83-
#[pyclass(name)]
83+
#[pyclass(name, traverse)]
8484
#[derive(Debug, PyPayload)]
8585
struct ContextVar {
86+
#[pytraverse(skip)]
8687
#[allow(dead_code)] // TODO: RUSTPYTHON
8788
name: String,
8889
#[allow(dead_code)] // TODO: RUSTPYTHON
@@ -161,7 +162,7 @@ mod _contextvars {
161162
#[derive(Debug, PyPayload)]
162163
struct ContextToken {}
163164

164-
#[derive(FromArgs)]
165+
#[derive(FromArgs, Traverse)]
165166
struct ContextTokenOptions {
166167
#[pyarg(positional)]
167168
#[allow(dead_code)] // TODO: RUSTPYTHON

stdlib/src/csv.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,10 +152,11 @@ mod _csv {
152152
reader: csv_core::Reader,
153153
}
154154

155-
#[pyclass(no_attr, module = "_csv", name = "reader")]
155+
#[pyclass(no_attr, module = "_csv", name = "reader", traverse)]
156156
#[derive(PyPayload)]
157157
pub(super) struct Reader {
158158
iter: PyIter,
159+
#[pytraverse(skip)]
159160
state: PyMutex<ReadState>,
160161
}
161162

@@ -242,10 +243,11 @@ mod _csv {
242243
writer: csv_core::Writer,
243244
}
244245

245-
#[pyclass(no_attr, module = "_csv", name = "writer")]
246+
#[pyclass(no_attr, module = "_csv", name = "writer", traverse)]
246247
#[derive(PyPayload)]
247248
pub(super) struct Writer {
248249
write: PyObjectRef,
250+
#[pytraverse(skip)]
249251
state: PyMutex<WriteState>,
250252
}
251253

stdlib/src/grp.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,14 @@ mod grp {
1313
use std::ptr::NonNull;
1414

1515
#[pyattr]
16-
#[pyclass(module = "grp", name = "struct_group")]
16+
#[pyclass(module = "grp", name = "struct_group", traverse)]
1717
#[derive(PyStructSequence)]
1818
struct Group {
19+
#[pytraverse(skip)]
1920
gr_name: String,
21+
#[pytraverse(skip)]
2022
gr_passwd: String,
23+
#[pytraverse(skip)]
2124
gr_gid: u32,
2225
gr_mem: PyListRef,
2326
}

0 commit comments

Comments
 (0)