From d58e13112b46cb564e6568b20461b3ab751feb77 Mon Sep 17 00:00:00 2001 From: coord_e Date: Wed, 6 May 2026 17:09:26 +0900 Subject: [PATCH 1/6] Propagate impl info with thrust_macros::context --- thrust-macros/Cargo.toml | 2 +- thrust-macros/src/lib.rs | 88 ++++++++++++++++++++++++++++++++++++++-- 2 files changed, 86 insertions(+), 4 deletions(-) diff --git a/thrust-macros/Cargo.toml b/thrust-macros/Cargo.toml index 4858a70..419cc79 100644 --- a/thrust-macros/Cargo.toml +++ b/thrust-macros/Cargo.toml @@ -10,4 +10,4 @@ proc-macro = true [dependencies] proc-macro2 = "1" quote = "1" -syn = { version = "2", features = ["full"] } +syn = { version = "2", features = ["full", "visit"] } diff --git a/thrust-macros/src/lib.rs b/thrust-macros/src/lib.rs index 0a42979..9978d11 100644 --- a/thrust-macros/src/lib.rs +++ b/thrust-macros/src/lib.rs @@ -6,6 +6,25 @@ use syn::{ Type, TypeParamBound, WherePredicate, }; +#[proc_macro_attribute] +pub fn context(_attr: TokenStream, item: TokenStream) -> TokenStream { + let mut impl_item = syn::parse_macro_input!(item as syn::ItemImpl); + let impl_header = { + let mut header = impl_item.clone(); + header.items.clear(); + header + }; + for item in &mut impl_item.items { + let syn::ImplItem::Fn(item) = item else { + continue; + }; + // TODO: why ::thrust_macros doesn't work here? + item.attrs + .push(syn::parse_quote!(#[thrust::_impl_context(#impl_header)])); + } + impl_item.into_token_stream().into() +} + #[proc_macro_attribute] pub fn requires(attr: TokenStream, item: TokenStream) -> TokenStream { let expr = TokenStream2::from(attr); @@ -97,9 +116,43 @@ pub fn _requires_ensures(attr: TokenStream, item: TokenStream) -> TokenStream { let req_expr = exprs.pop().unwrap().into_value(); let func = parse_macro_input!(item as ItemFn); - ExpandedTokens::new(func, req_expr, ens_expr) - .into_token_stream() - .into() + let impl_context = match extract_impl_context(&func) { + Ok(ctx) => ctx, + Err(e) => return e.to_compile_error().into(), + }; + if mentions_self(&func.sig) && impl_context.is_none() { + let err = syn::Error::new_spanned( + func.sig.ident.clone(), + "Wrap impl block with #[thrust_macros::context] to use requires/ensures on methods", + ) + .to_compile_error(); + return quote! { #err #func }.into(); + } + let mut tokens = ExpandedTokens::new(func, req_expr, ens_expr); + if let Some(ctx) = impl_context { + tokens = tokens.with_impl_context(ctx); + } + tokens.into_token_stream().into() +} + +fn extract_impl_context(func: &syn::ItemFn) -> syn::Result> { + let impl_context_path: syn::Path = syn::parse_quote!(thrust::_impl_context); + let mut impl_context = None; + for attr in &func.attrs { + if attr.path() != &impl_context_path { + continue; + } + + let item = attr.parse_args()?; + if impl_context.is_some() { + return Err(syn::Error::new_spanned( + attr, + "multiple _impl_context attributes found; expected at most one", + )); + } + impl_context = Some(item); + } + Ok(impl_context) } struct ExpandedTokens { @@ -116,6 +169,8 @@ struct ExpandedTokens { model_ty_params: TokenStream2, ret_model_ty: Type, + + impl_context: Option, } impl quote::ToTokens for ExpandedTokens { @@ -157,9 +212,15 @@ impl ExpandedTokens { extended_where, model_ty_params, ret_model_ty, + impl_context: None, } } + pub fn with_impl_context(mut self, impl_item: syn::ItemImpl) -> Self { + self.impl_context = Some(impl_item); + self + } + fn is_extern_spec_fn(&self) -> bool { let extern_spec_fn_path: syn::Path = syn::parse_quote!(thrust::extern_spec_fn); self.func @@ -277,6 +338,27 @@ impl ExpandedTokens { } } +fn mentions_self(sig: &syn::Signature) -> bool { + struct Visitor { + mentions_self: bool, + } + + impl syn::visit::Visit<'_> for Visitor { + fn visit_ident(&mut self, i: &syn::Ident) { + if i == "self" || i == "Self" { + self.mentions_self = true; + } + } + } + + let mut visitor = Visitor { + mentions_self: false, + }; + use syn::visit::Visit as _; + visitor.visit_signature(sig); + visitor.mentions_self +} + /// Returns `` — the generic param list for function definitions, /// without a where clause. fn generic_params_tokens(generics: &Generics) -> TokenStream2 { From 2a9bacf224d9f29c288ead39164f5f185f9650a1 Mon Sep 17 00:00:00 2001 From: coord_e Date: Wed, 6 May 2026 17:18:50 +0900 Subject: [PATCH 2/6] Take outer generics into consideration when building where clause --- thrust-macros/src/lib.rs | 115 +++++++++++++++++++++------------------ 1 file changed, 62 insertions(+), 53 deletions(-) diff --git a/thrust-macros/src/lib.rs b/thrust-macros/src/lib.rs index 9978d11..8436467 100644 --- a/thrust-macros/src/lib.rs +++ b/thrust-macros/src/lib.rs @@ -165,7 +165,6 @@ struct ExpandedTokens { def_generics: TokenStream2, turbofish: TokenStream2, - extended_where: TokenStream2, model_ty_params: TokenStream2, ret_model_ty: Type, @@ -192,8 +191,6 @@ impl ExpandedTokens { let generics = &func.sig.generics; let def_generics = generic_params_tokens(generics); let turbofish = generic_turbofish(generics); - let model_preds = model_where_predicates(generics); - let extended_where = extended_where_clause(generics, &model_preds); let model_ty_params = fn_params_with_model_ty(&func.sig.inputs); let ret_model_ty: Type = match &func.sig.output { @@ -209,7 +206,6 @@ impl ExpandedTokens { ensures_name, def_generics, turbofish, - extended_where, model_ty_params, ret_model_ty, impl_context: None, @@ -221,6 +217,65 @@ impl ExpandedTokens { self } + /// Returns `T: thrust_models::Model` predicates for every type param that does not + /// already carry an `Fn`, `FnOnce`, or `FnMut` bound. + fn model_where_predicates(&self) -> Vec { + let mut generic_type_params: Vec<&syn::TypeParam> = Vec::new(); + for param in &self.func.sig.generics.params { + let GenericParam::Type(tp) = param else { + continue; + }; + generic_type_params.push(tp); + } + if let Some(impl_item) = &self.impl_context { + for param in &impl_item.generics.params { + let GenericParam::Type(tp) = param else { + continue; + }; + generic_type_params.push(tp); + } + } + + let mut predicates: Vec = Vec::new(); + for param in generic_type_params { + let has_fn_bound = param.bounds.iter().any(|b| { + let TypeParamBound::Trait(tb) = b else { + return false; + }; + tb.path.segments.last().map_or(false, |s| { + matches!(s.ident.to_string().as_str(), "Fn" | "FnOnce" | "FnMut") + }) + }); + if has_fn_bound { + continue; + } + let ident = ¶m.ident; + predicates.push(syn::parse_quote!(#ident: thrust_models::Model)); + predicates.push(syn::parse_quote!(<#ident as thrust_models::Model>::Ty: PartialEq)); + } + predicates + } + + /// Builds `where , `. + /// Returns an empty token stream when both sets are empty. + fn extended_where_clause(&self) -> TokenStream2 { + let model_preds = self.model_where_predicates(); + let existing: Vec<&WherePredicate> = self + .func + .sig + .generics + .where_clause + .as_ref() + .map(|wc| wc.predicates.iter().collect()) + .unwrap_or_default(); + + if existing.is_empty() && model_preds.is_empty() { + return quote!(); + } + + quote! { where #(#existing,)* #(#model_preds),* } + } + fn is_extern_spec_fn(&self) -> bool { let extern_spec_fn_path: syn::Path = syn::parse_quote!(thrust::extern_spec_fn); self.func @@ -233,7 +288,7 @@ impl ExpandedTokens { let requires_name = &self.requires_name; let def_generics = &self.def_generics; let model_ty_params = &self.model_ty_params; - let extended_where = &self.extended_where; + let extended_where = self.extended_where_clause(); let req_expr = &self.req_expr; quote! { @@ -250,7 +305,7 @@ impl ExpandedTokens { let ensures_name = &self.ensures_name; let def_generics = &self.def_generics; let model_ty_params = &self.model_ty_params; - let extended_where = &self.extended_where; + let extended_where = self.extended_where_clause(); let ret_model_ty = &self.ret_model_ty; let ens_expr = &self.ens_expr; @@ -279,7 +334,7 @@ impl ExpandedTokens { let extern_spec_name = format_ident!("_thrust_extern_spec_{}", self.func.sig.ident); let def_generics = &self.def_generics; let orig_output = &self.func.sig.output; - let extended_where = &self.extended_where; + let extended_where = self.extended_where_clause(); let requires_name = &self.requires_name; let ensures_name = &self.ensures_name; @@ -386,52 +441,6 @@ fn generic_turbofish(generics: &Generics) -> TokenStream2 { quote!(::<#(#args),*>) } -/// Returns `T: thrust_models::Model` predicates for every type param that does not -/// already carry an `Fn`, `FnOnce`, or `FnMut` bound. -fn model_where_predicates(generics: &Generics) -> Vec { - generics - .params - .iter() - .flat_map(|p| { - let GenericParam::Type(tp) = p else { - return vec![]; - }; - let has_fn_bound = tp.bounds.iter().any(|b| { - let TypeParamBound::Trait(tb) = b else { - return false; - }; - tb.path.segments.last().map_or(false, |s| { - matches!(s.ident.to_string().as_str(), "Fn" | "FnOnce" | "FnMut") - }) - }); - if has_fn_bound { - return vec![]; - } - let ident = &tp.ident; - vec![ - syn::parse_quote!(#ident: thrust_models::Model), - syn::parse_quote!(<#ident as thrust_models::Model>::Ty: PartialEq), - ] - }) - .collect() -} - -/// Builds `where , `. -/// Returns an empty token stream when both sets are empty. -fn extended_where_clause(generics: &Generics, model_preds: &[WherePredicate]) -> TokenStream2 { - let existing: Vec<&WherePredicate> = generics - .where_clause - .as_ref() - .map(|wc| wc.predicates.iter().collect()) - .unwrap_or_default(); - - if existing.is_empty() && model_preds.is_empty() { - return quote!(); - } - - quote! { where #(#existing,)* #(#model_preds),* } -} - /// Maps each typed function parameter `x: T` to `x: ::Ty`. /// Receiver (`self`) parameters are skipped. fn fn_params_with_model_ty( From e79965acdc8208667a3cdd592985be9884bdf495 Mon Sep 17 00:00:00 2001 From: coord_e Date: Wed, 6 May 2026 15:24:05 +0900 Subject: [PATCH 3/6] Enable to mention receivers in formula_fn --- thrust-macros/Cargo.toml | 2 +- thrust-macros/src/lib.rs | 58 ++++++++++++++++++++++++++++------------ 2 files changed, 42 insertions(+), 18 deletions(-) diff --git a/thrust-macros/Cargo.toml b/thrust-macros/Cargo.toml index 419cc79..9fec73b 100644 --- a/thrust-macros/Cargo.toml +++ b/thrust-macros/Cargo.toml @@ -10,4 +10,4 @@ proc-macro = true [dependencies] proc-macro2 = "1" quote = "1" -syn = { version = "2", features = ["full", "visit"] } +syn = { version = "2", features = ["full", "visit", "visit-mut"] } diff --git a/thrust-macros/src/lib.rs b/thrust-macros/src/lib.rs index 8436467..0429c5d 100644 --- a/thrust-macros/src/lib.rs +++ b/thrust-macros/src/lib.rs @@ -183,7 +183,7 @@ impl quote::ToTokens for ExpandedTokens { } impl ExpandedTokens { - pub fn new(func: ItemFn, req_expr: syn::Expr, ens_expr: syn::Expr) -> Self { + pub fn new(func: ItemFn, mut req_expr: syn::Expr, mut ens_expr: syn::Expr) -> Self { let name = &func.sig.ident; let requires_name = format_ident!("_thrust_requires_{}", name); let ensures_name = format_ident!("_thrust_ensures_{}", name); @@ -198,6 +198,11 @@ impl ExpandedTokens { ReturnType::Type(_, ty) => syn::parse_quote!(<#ty as thrust_models::Model>::Ty), }; + if func.sig.receiver().is_some() { + rewrite_self_in_expr(&mut req_expr); + rewrite_self_in_expr(&mut ens_expr); + } + Self { func, req_expr, @@ -414,6 +419,21 @@ fn mentions_self(sig: &syn::Signature) -> bool { visitor.mentions_self } +fn rewrite_self_in_expr(expr: &mut syn::Expr) { + struct Visitor; + + impl syn::visit_mut::VisitMut for Visitor { + fn visit_ident_mut(&mut self, ident: &mut syn::Ident) { + if ident == "self" { + *ident = format_ident!("self_"); + } + } + } + + use syn::visit_mut::VisitMut as _; + Visitor.visit_expr_mut(expr); +} + /// Returns `` — the generic param list for function definitions, /// without a where clause. fn generic_params_tokens(generics: &Generics) -> TokenStream2 { @@ -441,22 +461,26 @@ fn generic_turbofish(generics: &Generics) -> TokenStream2 { quote!(::<#(#args),*>) } -/// Maps each typed function parameter `x: T` to `x: ::Ty`. -/// Receiver (`self`) parameters are skipped. -fn fn_params_with_model_ty( - inputs: &syn::punctuated::Punctuated, -) -> TokenStream2 { - let params: Vec = inputs - .iter() - .filter_map(|arg| { - let FnArg::Typed(pt) = arg else { return None }; - let pat = &pt.pat; - let ty = &pt.ty; - let model_ty: Type = syn::parse_quote!(<#ty as thrust_models::Model>::Ty); - Some(quote!(#pat: #model_ty)) - }) - .collect(); - quote!(#(#params),*) +/// Maps each function parameter `x: T` to `x: ::Ty`. +fn fn_params_with_model_ty<'ast, I>(args: I) -> TokenStream2 +where + I: IntoIterator, +{ + let mut model_inputs: Vec = Vec::new(); + for arg in args { + match arg { + FnArg::Receiver(receiver) => { + let ty = &receiver.ty; + model_inputs.push(syn::parse_quote!(self_: <#ty as thrust_models::Model>::Ty)); + } + FnArg::Typed(pt) => { + let pat = &pt.pat; + let ty = &pt.ty; + model_inputs.push(syn::parse_quote!(#pat: <#ty as thrust_models::Model>::Ty)); + } + } + } + quote!(#(#model_inputs),*) } /// For the extern_spec wrapper: replaces every typed parameter with a fresh `_arg_N` ident, From b3dba1be13619b0a028418ce6c1a1fead3122c8f Mon Sep 17 00:00:00 2001 From: coord_e Date: Wed, 6 May 2026 17:49:03 +0900 Subject: [PATCH 4/6] Allow resolving path on requires_path and ensures_path --- src/analyze.rs | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/analyze.rs b/src/analyze.rs index 523b56d..6d2bfae 100644 --- a/src/analyze.rs +++ b/src/analyze.rs @@ -610,16 +610,10 @@ impl<'tcx> Analyzer<'tcx> { ); continue; }; - let rustc_hir::QPath::Resolved(_, path) = qpath else { + let typeck = self.tcx.typeck(local_def_id); + let rustc_hir::def::Res::Def(_, def_id) = typeck.qpath_res(&qpath, expr.hir_id) else { self.tcx.dcx().span_err( expr.span, - "annotated path is expected to be a resolved path", - ); - continue; - }; - let rustc_hir::def::Res::Def(_, def_id) = path.res else { - self.tcx.dcx().span_err( - path.span, "annotated path is expected to refer to a definition", ); continue; From a6b2c664f1cbcb51b924bb60bec03807042ed3f5 Mon Sep 17 00:00:00 2001 From: coord_e Date: Wed, 6 May 2026 17:58:56 +0900 Subject: [PATCH 5/6] Enable to write extern_spec_fn inside impl --- src/analyze/local_def.rs | 27 ++++++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/src/analyze/local_def.rs b/src/analyze/local_def.rs index 09d0216..69d92c7 100644 --- a/src/analyze/local_def.rs +++ b/src/analyze/local_def.rs @@ -400,11 +400,28 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { /// semicolon) in the function body block. pub fn extern_spec_fn_target_def_id(&self) -> DefId { let node = self.tcx.hir_node_by_def_id(self.local_def_id); - let rustc_hir::Node::Item(item) = node else { - panic!("extern_spec_fn must be a function item"); - }; - let rustc_hir::ItemKind::Fn(_, _, body_id) = item.kind else { - panic!("extern_spec_fn must be a function"); + let body_id = match node { + rustc_hir::Node::Item(item) => { + let rustc_hir::ItemKind::Fn(_, _, body_id) = item.kind else { + panic!("extern_spec_fn must be a function"); + }; + body_id + } + rustc_hir::Node::ImplItem(impl_item) => { + let rustc_hir::ImplItemKind::Fn(_, body_id) = impl_item.kind else { + panic!("extern_spec_fn must be a function"); + }; + body_id + } + rustc_hir::Node::TraitItem(trait_item) => { + let rustc_hir::TraitItemKind::Fn(_, rustc_hir::TraitFn::Provided(body_id)) = + trait_item.kind + else { + panic!("extern_spec_fn must be a function with a body"); + }; + body_id + } + _ => panic!("extern_spec_fn must be a function item or impl item"), }; let body = self.tcx.hir().body(body_id); From 91f199ae50273241076a8ff421313f41b738cac6 Mon Sep 17 00:00:00 2001 From: coord_e Date: Wed, 6 May 2026 17:59:13 +0900 Subject: [PATCH 6/6] Enable to write thrust_macros::{ensures,requires} inside impl --- tests/ui/fail/annot_struct_impl.rs | 53 ++++++++++++++++++++++++++++++ tests/ui/pass/annot_struct_impl.rs | 53 ++++++++++++++++++++++++++++++ thrust-macros/src/lib.rs | 17 +++++++--- 3 files changed, 118 insertions(+), 5 deletions(-) create mode 100644 tests/ui/fail/annot_struct_impl.rs create mode 100644 tests/ui/pass/annot_struct_impl.rs diff --git a/tests/ui/fail/annot_struct_impl.rs b/tests/ui/fail/annot_struct_impl.rs new file mode 100644 index 0000000..5c8719f --- /dev/null +++ b/tests/ui/fail/annot_struct_impl.rs @@ -0,0 +1,53 @@ +//@error-in-other-file: Unsat + +struct VecWrap { + inner: Vec +} + +impl thrust_models::Model for VecWrap where T: thrust_models::Model { + type Ty = ( + thrust_models::model::Array< + thrust_models::model::Int, + ::Ty + >, + thrust_models::model::Int, + ); +} + +#[thrust_macros::context] +impl VecWrap { + #[thrust::trusted] + #[thrust_macros::ensures(result.1 == 0)] + fn new() -> Self { + VecWrap { inner: Vec::new() } + } + + #[thrust::trusted] + #[thrust_macros::ensures((!self).0 == (*self).0.store((*self).1, elem))] + #[thrust_macros::ensures((!self).1 == (*self).1)] + fn push(&mut self, elem: T) { + self.inner.push(elem); + } + + #[thrust::trusted] + #[thrust_macros::ensures(result == (*self).1)] + fn len(&self) -> usize { + self.inner.len() + } + + #[thrust::trusted] + #[thrust_macros::requires(index < (*self).1)] + #[thrust_macros::ensures(*result == (*self).0[index])] + fn get(&self, index: usize) -> &T { + &self.inner[index] + } +} + +fn main() { + let mut v = VecWrap::new(); + v.push(10); + v.push(20); + assert!(v.len() == 2); + assert!(*v.get(0) == 10); + assert!(*v.get(1) == 20); +} diff --git a/tests/ui/pass/annot_struct_impl.rs b/tests/ui/pass/annot_struct_impl.rs new file mode 100644 index 0000000..6d51a16 --- /dev/null +++ b/tests/ui/pass/annot_struct_impl.rs @@ -0,0 +1,53 @@ +//@check-pass + +struct VecWrap { + inner: Vec +} + +impl thrust_models::Model for VecWrap where T: thrust_models::Model { + type Ty = ( + thrust_models::model::Array< + thrust_models::model::Int, + ::Ty + >, + thrust_models::model::Int, + ); +} + +#[thrust_macros::context] +impl VecWrap { + #[thrust::trusted] + #[thrust_macros::ensures(result.1 == 0)] + fn new() -> Self { + VecWrap { inner: Vec::new() } + } + + #[thrust::trusted] + #[thrust_macros::ensures((!self).0 == (*self).0.store((*self).1, elem))] + #[thrust_macros::ensures((!self).1 == (*self).1 + 1)] + fn push(&mut self, elem: T) { + self.inner.push(elem); + } + + #[thrust::trusted] + #[thrust_macros::ensures(result == (*self).1)] + fn len(&self) -> usize { + self.inner.len() + } + + #[thrust::trusted] + #[thrust_macros::requires(index < (*self).1)] + #[thrust_macros::ensures(*result == (*self).0[index])] + fn get(&self, index: usize) -> &T { + &self.inner[index] + } +} + +fn main() { + let mut v = VecWrap::new(); + v.push(10); + v.push(20); + assert!(v.len() == 2); + assert!(*v.get(0) == 10); + assert!(*v.get(1) == 20); +} diff --git a/thrust-macros/src/lib.rs b/thrust-macros/src/lib.rs index 0429c5d..c5ccf9d 100644 --- a/thrust-macros/src/lib.rs +++ b/thrust-macros/src/lib.rs @@ -324,6 +324,11 @@ impl ExpandedTokens { } } + fn path_prefix(&self) -> Option { + self.impl_context.as_ref()?; + Some(quote!(Self::)) + } + fn expand(&self) -> TokenStream2 { let mut func = self.func.clone(); let trusted_path: syn::Path = syn::parse_quote!(thrust::trusted); @@ -344,6 +349,7 @@ impl ExpandedTokens { let requires_name = &self.requires_name; let ensures_name = &self.ensures_name; let turbofish = &self.turbofish; + let path_prefix = self.path_prefix(); let name = &self.func.sig.ident; let (extern_spec_inputs, call_args) = rewrite_inputs_for_call(&self.func.sig.inputs); @@ -358,12 +364,12 @@ impl ExpandedTokens { #[allow(path_statements)] fn #extern_spec_name #def_generics(#extern_spec_inputs) #orig_output #extended_where { #[thrust::requires_path] - #requires_name #turbofish; + #path_prefix #requires_name #turbofish; #[thrust::ensures_path] - #ensures_name #turbofish; + #path_prefix #ensures_name #turbofish; - #name #turbofish(#call_args) + #path_prefix #name #turbofish(#call_args) } } } @@ -372,15 +378,16 @@ impl ExpandedTokens { let requires_name = &self.requires_name; let ensures_name = &self.ensures_name; let turbofish = &self.turbofish; + let path_prefix = self.path_prefix(); let mut func = self.func.clone(); let orig_stmts = func.block.stmts.clone(); func.block = syn::parse_quote!({ #[thrust::requires_path] - #requires_name #turbofish; + #path_prefix #requires_name #turbofish; #[thrust::ensures_path] - #ensures_name #turbofish; + #path_prefix #ensures_name #turbofish; #(#orig_stmts)* });