@@ -105,39 +105,114 @@ pub(crate) fn impl_pyimpl(attr: AttributeArgs, item: Item) -> Result<TokenStream
105105 Item :: Impl ( mut imp) => {
106106 extract_items_into_context ( & mut context, imp. items . iter_mut ( ) ) ;
107107
108- let ty = & imp. self_ty ;
108+ let ( impl_ty, payload_guess) = match imp. self_ty . as_ref ( ) {
109+ syn:: Type :: Path ( syn:: TypePath {
110+ path : syn:: Path { segments, .. } ,
111+ ..
112+ } ) if segments. len ( ) == 1 => {
113+ let segment = & segments[ 0 ] ;
114+ let payload_ty = if segment. ident == "Py" || segment. ident == "PyRef" {
115+ match & segment. arguments {
116+ syn:: PathArguments :: AngleBracketed (
117+ syn:: AngleBracketedGenericArguments { args, .. } ,
118+ ) if args. len ( ) == 1 => {
119+ let arg = & args[ 0 ] ;
120+ match arg {
121+ syn:: GenericArgument :: Type ( syn:: Type :: Path (
122+ syn:: TypePath {
123+ path : syn:: Path { segments, .. } ,
124+ ..
125+ } ,
126+ ) ) if segments. len ( ) == 1 => segments[ 0 ] . ident . clone ( ) ,
127+ _ => {
128+ return Err ( syn:: Error :: new_spanned (
129+ segment,
130+ "Py{Ref}<T> is expected but Py{Ref}<?> is found" ,
131+ ) )
132+ }
133+ }
134+ }
135+ _ => {
136+ return Err ( syn:: Error :: new_spanned (
137+ segment,
138+ "Py{Ref}<T> is expected but Py{Ref}? is found" ,
139+ ) )
140+ }
141+ }
142+ } else {
143+ if !matches ! ( segment. arguments, syn:: PathArguments :: None ) {
144+ return Err ( syn:: Error :: new_spanned (
145+ segment,
146+ "PyImpl can only be implemented for Py{Ref}<T> or T" ,
147+ ) ) ;
148+ }
149+ segment. ident . clone ( )
150+ } ;
151+ ( segment. ident . clone ( ) , payload_ty)
152+ }
153+ _ => {
154+ return Err ( syn:: Error :: new_spanned (
155+ imp. self_ty ,
156+ "PyImpl can only be implemented for Py{Ref}<T> or T" ,
157+ ) )
158+ }
159+ } ;
160+
109161 let ExtractedImplAttrs {
162+ payload : attr_payload,
110163 with_impl,
111164 flags,
112165 with_slots,
113- } = extract_impl_attrs ( attr, & Ident :: new ( & quote ! ( ty ) . to_string ( ) , ty . span ( ) ) ) ?;
114-
166+ } = extract_impl_attrs ( attr, & impl_ty ) ?;
167+ let payload_ty = attr_payload . unwrap_or ( payload_guess ) ;
115168 let getset_impl = & context. getset_items ;
116169 let member_impl = & context. member_items ;
117170 let extend_impl = context. impl_extend_items . validate ( ) ?;
118171 let slots_impl = context. extend_slots_items . validate ( ) ?;
119172 let class_extensions = & context. class_extensions ;
120- quote ! {
121- #imp
122- impl :: rustpython_vm:: class:: PyClassImpl for #ty {
123- const TP_FLAGS : :: rustpython_vm:: types:: PyTypeFlags = #flags;
124173
125- fn impl_extend_class(
174+ let extra_methods = iter_chain ! [
175+ parse_quote! {
176+ fn __extend_py_class(
126177 ctx: & :: rustpython_vm:: Context ,
127178 class: & ' static :: rustpython_vm:: Py <:: rustpython_vm:: builtins:: PyType >,
128179 ) {
129180 #getset_impl
130181 #member_impl
131182 #extend_impl
132- #with_impl
133183 #( #class_extensions) *
134184 }
135-
136- fn extend_slots ( slots : & mut :: rustpython_vm :: types :: PyTypeSlots ) {
137- #with_slots
185+ } ,
186+ parse_quote! {
187+ fn __extend_slots ( slots : & mut :: rustpython_vm :: types :: PyTypeSlots ) {
138188 #slots_impl
139189 }
190+ } ,
191+ ] ;
192+ imp. items . extend ( extra_methods) ;
193+ let is_main_impl = impl_ty == payload_ty;
194+ if is_main_impl {
195+ quote ! {
196+ #imp
197+ impl :: rustpython_vm:: class:: PyClassImpl for #payload_ty {
198+ const TP_FLAGS : :: rustpython_vm:: types:: PyTypeFlags = #flags;
199+
200+ fn impl_extend_class(
201+ ctx: & :: rustpython_vm:: Context ,
202+ class: & ' static :: rustpython_vm:: Py <:: rustpython_vm:: builtins:: PyType >,
203+ ) {
204+ #impl_ty:: __extend_py_class( ctx, class) ;
205+ #with_impl
206+ }
207+
208+ fn extend_slots( slots: & mut :: rustpython_vm:: types:: PyTypeSlots ) {
209+ #impl_ty:: __extend_slots( slots) ;
210+ #with_slots
211+ }
212+ }
140213 }
214+ } else {
215+ imp. into_token_stream ( )
141216 }
142217 }
143218 Item :: Trait ( mut trai) => {
@@ -1163,6 +1238,7 @@ impl MemberItemMeta {
11631238}
11641239
11651240struct ExtractedImplAttrs {
1241+ payload : Option < Ident > ,
11661242 with_impl : TokenStream ,
11671243 with_slots : TokenStream ,
11681244 flags : TokenStream ,
@@ -1182,6 +1258,7 @@ fn extract_impl_attrs(attr: AttributeArgs, item: &Ident) -> Result<ExtractedImpl
11821258 }
11831259 }
11841260 } ] ;
1261+ let mut payload = None ;
11851262
11861263 for attr in attr {
11871264 match attr {
@@ -1191,18 +1268,19 @@ fn extract_impl_attrs(attr: AttributeArgs, item: &Ident) -> Result<ExtractedImpl
11911268 let NestedMeta :: Meta ( Meta :: Path ( path) ) = meta else {
11921269 bail_span ! ( meta, "#[pyclass(with(...))] arguments should be paths" )
11931270 } ;
1194- let ( extend_class, extend_slots) = if path. is_ident ( "PyRef" ) {
1195- // special handling for PyRef
1196- (
1197- quote ! ( PyRef :: <Self >:: impl_extend_class) ,
1198- quote ! ( PyRef :: <Self >:: extend_slots) ,
1199- )
1200- } else {
1201- (
1202- quote ! ( <Self as #path>:: __extend_py_class) ,
1203- quote ! ( <Self as #path>:: __extend_slots) ,
1204- )
1205- } ;
1271+ let ( extend_class, extend_slots) =
1272+ if path. is_ident ( "PyRef" ) || path. is_ident ( "Py" ) {
1273+ // special handling for PyRef
1274+ (
1275+ quote ! ( #path:: <Self >:: __extend_py_class) ,
1276+ quote ! ( #path:: <Self >:: __extend_slots) ,
1277+ )
1278+ } else {
1279+ (
1280+ quote ! ( <Self as #path>:: __extend_py_class) ,
1281+ quote ! ( <Self as #path>:: __extend_slots) ,
1282+ )
1283+ } ;
12061284 let item_span = item. span ( ) . resolved_at ( Span :: call_site ( ) ) ;
12071285 withs. push ( quote_spanned ! { path. span( ) =>
12081286 #extend_class( ctx, class) ;
@@ -1227,11 +1305,23 @@ fn extract_impl_attrs(attr: AttributeArgs, item: &Ident) -> Result<ExtractedImpl
12271305 bail_span ! ( path, "Unknown pyimpl attribute" )
12281306 }
12291307 }
1308+ NestedMeta :: Meta ( Meta :: NameValue ( syn:: MetaNameValue { path, lit, .. } ) ) => {
1309+ if path. is_ident ( "payload" ) {
1310+ if let syn:: Lit :: Str ( lit) = lit {
1311+ payload = Some ( Ident :: new ( & lit. value ( ) , lit. span ( ) ) ) ;
1312+ } else {
1313+ bail_span ! ( lit, "payload must be a string literal" )
1314+ }
1315+ } else {
1316+ bail_span ! ( path, "Unknown pyimpl attribute" )
1317+ }
1318+ }
12301319 attr => bail_span ! ( attr, "Unknown pyimpl attribute" ) ,
12311320 }
12321321 }
12331322
12341323 Ok ( ExtractedImplAttrs {
1324+ payload,
12351325 with_impl : quote ! {
12361326 #( #withs) *
12371327 } ,
0 commit comments