From 43af1ebe6e4d2bb9ec90167616a3d88bf34c7c0b Mon Sep 17 00:00:00 2001 From: Mike Hommey Date: Sat, 9 Sep 2023 06:28:48 +0900 Subject: [PATCH] Handle bitflags bits method calls --- Cargo.toml | 2 +- src/bindgen/bitflags.rs | 82 +++++++++++++++++++++++++++++++++++++++-- tests/rust/bitflags.rs | 8 ++-- 3 files changed, 84 insertions(+), 8 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index e654e58b8..1402d6ab3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,7 +33,7 @@ heck = "0.4" [dependencies.syn] version = "1.0.88" default-features = false -features = ["clone-impls", "extra-traits", "full", "parsing", "printing"] +features = ["clone-impls", "extra-traits", "fold", "full", "parsing", "printing"] [dev-dependencies] serial_test = "0.5.0" diff --git a/src/bindgen/bitflags.rs b/src/bindgen/bitflags.rs index 28412940e..6ebb1d912 100644 --- a/src/bindgen/bitflags.rs +++ b/src/bindgen/bitflags.rs @@ -3,6 +3,8 @@ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ use proc_macro2::TokenStream; +use std::collections::HashSet; +use syn::fold::Fold; use syn::parse::{Parse, ParseStream, Parser, Result as ParseResult}; // $(#[$outer:meta])* @@ -84,17 +86,86 @@ struct Flag { semicolon_token: Token![;], } +struct FlagValueFold<'a> { + struct_name: &'a syn::Ident, + flag_names: &'a HashSet, +} + +impl<'a> FlagValueFold<'a> { + fn is_self(&self, ident: &syn::Ident) -> bool { + ident == self.struct_name || ident == "Self" + } +} + +impl<'a> Fold for FlagValueFold<'a> { + fn fold_expr(&mut self, node: syn::Expr) -> syn::Expr { + // bitflags 2 doesn't expose `bits` publically anymore, and the documented way to + // combine flags is using the `bits` method, e.g. + // ``` + // bitflags! { + // struct Flags: u8 { + // const A = 1; + // const B = 1 << 1; + // const AB = Flags::A.bits() | Flags::B.bits(); + // } + // } + // ``` + // As we're transforming the struct definition into `struct StructName { bits: T }` + // as far as our bindings generation is concerned, `bits` is available as a field, + // so by replacing `StructName::FLAG.bits()` with `StructName::FLAG.bits`, we make + // e.g. `Flags::AB` available in the generated bindings. + match node { + syn::Expr::MethodCall(syn::ExprMethodCall { + attrs, + receiver, + dot_token, + method, + args, + .. + }) if method == "bits" + && args.is_empty() + && matches!(&*receiver, + syn::Expr::Path(syn::ExprPath { path, .. }) + if path.segments.len() == 2 + && self.is_self(&path.segments.first().unwrap().ident) + && self + .flag_names + .contains(&path.segments.last().unwrap().ident.to_string())) => + { + return syn::Expr::Field(syn::ExprField { + attrs, + base: receiver, + dot_token, + member: syn::Member::Named(method), + }); + } + _ => {} + } + syn::fold::fold_expr(self, node) + } +} + impl Flag { - fn expand(&self, struct_name: &syn::Ident, repr: &syn::Type) -> TokenStream { + fn expand( + &self, + struct_name: &syn::Ident, + repr: &syn::Type, + flag_names: &HashSet, + ) -> TokenStream { let Flag { ref attrs, ref name, ref value, .. } = *self; + let folded_value = FlagValueFold { + struct_name, + flag_names, + } + .fold_expr(value.clone()); quote! { #(#attrs)* - pub const #name : #struct_name = #struct_name { bits: (#value) as #repr }; + pub const #name : #struct_name = #struct_name { bits: (#folded_value) as #repr }; } } } @@ -130,8 +201,13 @@ impl Parse for Flags { impl Flags { fn expand(&self, struct_name: &syn::Ident, repr: &syn::Type) -> TokenStream { let mut ts = quote! {}; + let flag_names = self + .0 + .iter() + .map(|flag| flag.name.to_string()) + .collect::>(); for flag in &self.0 { - ts.extend(flag.expand(struct_name, repr)); + ts.extend(flag.expand(struct_name, repr, &flag_names)); } ts } diff --git a/tests/rust/bitflags.rs b/tests/rust/bitflags.rs index 7e78bd815..1fe1c6e3e 100644 --- a/tests/rust/bitflags.rs +++ b/tests/rust/bitflags.rs @@ -13,11 +13,11 @@ bitflags! { const START = 1 << 1; /// 'end' const END = 1 << 2; - const ALIAS = Self::END.bits; + const ALIAS = Self::END.bits(); /// 'flex-start' const FLEX_START = 1 << 3; - const MIXED = 1 << 4 | AlignFlags::FLEX_START.bits | AlignFlags::END.bits; - const MIXED_SELF = 1 << 5 | AlignFlags::FLEX_START.bits | AlignFlags::END.bits; + const MIXED = 1 << 4 | AlignFlags::FLEX_START.bits() | AlignFlags::END.bits(); + const MIXED_SELF = 1 << 5 | AlignFlags::FLEX_START.bits() | AlignFlags::END.bits(); } } @@ -34,7 +34,7 @@ bitflags! { pub struct LargeFlags: u64 { /// Flag with a very large shift that usually would be narrowed. const LARGE_SHIFT = 1u64 << 44; - const INVERTED = !Self::LARGE_SHIFT.bits; + const INVERTED = !Self::LARGE_SHIFT.bits(); } }