Skip to content

Commit 0f24d66

Browse files
committed
pyclass macro to recognize Py/PyRef pattern
1 parent fd02825 commit 0f24d66

File tree

5 files changed

+120
-46
lines changed

5 files changed

+120
-46
lines changed

derive-impl/src/pyclass.rs

Lines changed: 114 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -105,39 +105,114 @@ pub(crate) fn impl_pyimpl(attr: AttributeArgs, item: Item) -> Result<TokenStream
105105
Item::Impl(mut imp) => {
106106
extract_items_into_context(&mut context, imp.items.iter_mut());
107107

108-
let ty = &imp.self_ty;
108+
let (impl_ty, payload_guess) = match imp.self_ty.as_ref() {
109+
syn::Type::Path(syn::TypePath {
110+
path: syn::Path { segments, .. },
111+
..
112+
}) if segments.len() == 1 => {
113+
let segment = &segments[0];
114+
let payload_ty = if segment.ident == "Py" || segment.ident == "PyRef" {
115+
match &segment.arguments {
116+
syn::PathArguments::AngleBracketed(
117+
syn::AngleBracketedGenericArguments { args, .. },
118+
) if args.len() == 1 => {
119+
let arg = &args[0];
120+
match arg {
121+
syn::GenericArgument::Type(syn::Type::Path(
122+
syn::TypePath {
123+
path: syn::Path { segments, .. },
124+
..
125+
},
126+
)) if segments.len() == 1 => segments[0].ident.clone(),
127+
_ => {
128+
return Err(syn::Error::new_spanned(
129+
segment,
130+
"Py{Ref}<T> is expected but Py{Ref}<?> is found",
131+
))
132+
}
133+
}
134+
}
135+
_ => {
136+
return Err(syn::Error::new_spanned(
137+
segment,
138+
"Py{Ref}<T> is expected but Py{Ref}? is found",
139+
))
140+
}
141+
}
142+
} else {
143+
if !matches!(segment.arguments, syn::PathArguments::None) {
144+
return Err(syn::Error::new_spanned(
145+
segment,
146+
"PyImpl can only be implemented for Py{Ref}<T> or T",
147+
));
148+
}
149+
segment.ident.clone()
150+
};
151+
(segment.ident.clone(), payload_ty)
152+
}
153+
_ => {
154+
return Err(syn::Error::new_spanned(
155+
imp.self_ty,
156+
"PyImpl can only be implemented for Py{Ref}<T> or T",
157+
))
158+
}
159+
};
160+
109161
let ExtractedImplAttrs {
162+
payload: attr_payload,
110163
with_impl,
111164
flags,
112165
with_slots,
113-
} = extract_impl_attrs(attr, &Ident::new(&quote!(ty).to_string(), ty.span()))?;
114-
166+
} = extract_impl_attrs(attr, &impl_ty)?;
167+
let payload_ty = attr_payload.unwrap_or(payload_guess);
115168
let getset_impl = &context.getset_items;
116169
let member_impl = &context.member_items;
117170
let extend_impl = context.impl_extend_items.validate()?;
118171
let slots_impl = context.extend_slots_items.validate()?;
119172
let class_extensions = &context.class_extensions;
120-
quote! {
121-
#imp
122-
impl ::rustpython_vm::class::PyClassImpl for #ty {
123-
const TP_FLAGS: ::rustpython_vm::types::PyTypeFlags = #flags;
124173

125-
fn impl_extend_class(
174+
let extra_methods = iter_chain![
175+
parse_quote! {
176+
fn __extend_py_class(
126177
ctx: &::rustpython_vm::Context,
127178
class: &'static ::rustpython_vm::Py<::rustpython_vm::builtins::PyType>,
128179
) {
129180
#getset_impl
130181
#member_impl
131182
#extend_impl
132-
#with_impl
133183
#(#class_extensions)*
134184
}
135-
136-
fn extend_slots(slots: &mut ::rustpython_vm::types::PyTypeSlots) {
137-
#with_slots
185+
},
186+
parse_quote! {
187+
fn __extend_slots(slots: &mut ::rustpython_vm::types::PyTypeSlots) {
138188
#slots_impl
139189
}
190+
},
191+
];
192+
imp.items.extend(extra_methods);
193+
let is_main_impl = impl_ty == payload_ty;
194+
if is_main_impl {
195+
quote! {
196+
#imp
197+
impl ::rustpython_vm::class::PyClassImpl for #payload_ty {
198+
const TP_FLAGS: ::rustpython_vm::types::PyTypeFlags = #flags;
199+
200+
fn impl_extend_class(
201+
ctx: &::rustpython_vm::Context,
202+
class: &'static ::rustpython_vm::Py<::rustpython_vm::builtins::PyType>,
203+
) {
204+
#impl_ty::__extend_py_class(ctx, class);
205+
#with_impl
206+
}
207+
208+
fn extend_slots(slots: &mut ::rustpython_vm::types::PyTypeSlots) {
209+
#impl_ty::__extend_slots(slots);
210+
#with_slots
211+
}
212+
}
140213
}
214+
} else {
215+
imp.into_token_stream()
141216
}
142217
}
143218
Item::Trait(mut trai) => {
@@ -1163,6 +1238,7 @@ impl MemberItemMeta {
11631238
}
11641239

11651240
struct ExtractedImplAttrs {
1241+
payload: Option<Ident>,
11661242
with_impl: TokenStream,
11671243
with_slots: TokenStream,
11681244
flags: TokenStream,
@@ -1182,6 +1258,7 @@ fn extract_impl_attrs(attr: AttributeArgs, item: &Ident) -> Result<ExtractedImpl
11821258
}
11831259
}
11841260
}];
1261+
let mut payload = None;
11851262

11861263
for attr in attr {
11871264
match attr {
@@ -1191,18 +1268,19 @@ fn extract_impl_attrs(attr: AttributeArgs, item: &Ident) -> Result<ExtractedImpl
11911268
let NestedMeta::Meta(Meta::Path(path)) = meta else {
11921269
bail_span!(meta, "#[pyclass(with(...))] arguments should be paths")
11931270
};
1194-
let (extend_class, extend_slots) = if path.is_ident("PyRef") {
1195-
// special handling for PyRef
1196-
(
1197-
quote!(PyRef::<Self>::impl_extend_class),
1198-
quote!(PyRef::<Self>::extend_slots),
1199-
)
1200-
} else {
1201-
(
1202-
quote!(<Self as #path>::__extend_py_class),
1203-
quote!(<Self as #path>::__extend_slots),
1204-
)
1205-
};
1271+
let (extend_class, extend_slots) =
1272+
if path.is_ident("PyRef") || path.is_ident("Py") {
1273+
// special handling for PyRef
1274+
(
1275+
quote!(#path::<Self>::__extend_py_class),
1276+
quote!(#path::<Self>::__extend_slots),
1277+
)
1278+
} else {
1279+
(
1280+
quote!(<Self as #path>::__extend_py_class),
1281+
quote!(<Self as #path>::__extend_slots),
1282+
)
1283+
};
12061284
let item_span = item.span().resolved_at(Span::call_site());
12071285
withs.push(quote_spanned! { path.span() =>
12081286
#extend_class(ctx, class);
@@ -1227,11 +1305,23 @@ fn extract_impl_attrs(attr: AttributeArgs, item: &Ident) -> Result<ExtractedImpl
12271305
bail_span!(path, "Unknown pyimpl attribute")
12281306
}
12291307
}
1308+
NestedMeta::Meta(Meta::NameValue(syn::MetaNameValue { path, lit, .. })) => {
1309+
if path.is_ident("payload") {
1310+
if let syn::Lit::Str(lit) = lit {
1311+
payload = Some(Ident::new(&lit.value(), lit.span()));
1312+
} else {
1313+
bail_span!(lit, "payload must be a string literal")
1314+
}
1315+
} else {
1316+
bail_span!(path, "Unknown pyimpl attribute")
1317+
}
1318+
}
12301319
attr => bail_span!(attr, "Unknown pyimpl attribute"),
12311320
}
12321321
}
12331322

12341323
Ok(ExtractedImplAttrs {
1324+
payload,
12351325
with_impl: quote! {
12361326
#(#withs)*
12371327
},

vm/src/builtins/code.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -440,5 +440,5 @@ impl ToPyObject for bytecode::CodeObject {
440440
}
441441

442442
pub fn init(ctx: &Context) {
443-
PyRef::<PyCode>::extend_class(ctx, ctx.types.code_type);
443+
PyCode::extend_class(ctx, ctx.types.code_type);
444444
}

vm/src/builtins/frame.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ use crate::{
1313
use num_traits::Zero;
1414

1515
pub fn init(context: &Context) {
16-
FrameRef::extend_class(context, context.types.frame_type);
16+
Frame::extend_class(context, context.types.frame_type);
1717
}
1818

1919
#[pyclass(with(Constructor, PyRef, Representable))]
@@ -35,7 +35,7 @@ impl Representable for Frame {
3535
}
3636

3737
#[pyclass]
38-
impl FrameRef {
38+
impl PyRef<Frame> {
3939
#[pymethod]
4040
fn clear(self) {
4141
// TODO

vm/src/class.rs

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
use crate::{
44
builtins::{PyBaseObject, PyBoundMethod, PyType, PyTypeRef},
55
identifier,
6-
object::{Py, PyObjectPayload, PyObjectRef, PyRef},
6+
object::{Py, PyObjectRef},
77
types::{hash_not_implemented, PyTypeFlags, PyTypeSlots},
88
vm::Context,
99
};
@@ -63,18 +63,6 @@ pub trait PyClassDef {
6363
const UNHASHABLE: bool = false;
6464
}
6565

66-
impl<T> PyClassDef for PyRef<T>
67-
where
68-
T: PyObjectPayload + PyClassDef,
69-
{
70-
const NAME: &'static str = T::NAME;
71-
const MODULE_NAME: Option<&'static str> = T::MODULE_NAME;
72-
const TP_NAME: &'static str = T::TP_NAME;
73-
const DOC: Option<&'static str> = T::DOC;
74-
const BASICSIZE: usize = T::BASICSIZE;
75-
const UNHASHABLE: bool = T::UNHASHABLE;
76-
}
77-
7866
pub trait PyClassImpl: PyClassDef {
7967
const TP_FLAGS: PyTypeFlags = PyTypeFlags::DEFAULT;
8068

vm/src/stdlib/io.rs

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3065,8 +3065,6 @@ mod _io {
30653065
closed: AtomicCell<bool>,
30663066
}
30673067

3068-
type StringIORef = PyRef<StringIO>;
3069-
30703068
#[derive(FromArgs)]
30713069
struct StringIONewArgs {
30723070
#[pyarg(positional, optional)]
@@ -3135,7 +3133,7 @@ mod _io {
31353133
}
31363134

31373135
#[pyclass]
3138-
impl StringIORef {
3136+
impl PyRef<StringIO> {
31393137
//write string to underlying vector
31403138
#[pymethod]
31413139
fn write(self, data: PyStrRef, vm: &VirtualMachine) -> PyResult {
@@ -3216,8 +3214,6 @@ mod _io {
32163214
exports: AtomicCell<usize>,
32173215
}
32183216

3219-
type BytesIORef = PyRef<BytesIO>;
3220-
32213217
impl Constructor for BytesIO {
32223218
type Args = OptionalArg<Option<PyBytesRef>>;
32233219

@@ -3261,7 +3257,7 @@ mod _io {
32613257
}
32623258

32633259
#[pyclass]
3264-
impl BytesIORef {
3260+
impl PyRef<BytesIO> {
32653261
#[pymethod]
32663262
fn write(self, data: ArgBytesLike, vm: &VirtualMachine) -> PyResult<u64> {
32673263
let mut buffer = self.try_resizable(vm)?;

0 commit comments

Comments
 (0)