diff --git a/crates/bevy_macro_utils/src/attrs.rs b/crates/bevy_macro_utils/src/attrs.rs index 196c4868452dd..4f2ed491e7e0d 100644 --- a/crates/bevy_macro_utils/src/attrs.rs +++ b/crates/bevy_macro_utils/src/attrs.rs @@ -31,3 +31,17 @@ pub fn get_lit_str(attr_name: Symbol, lit: &syn::Lit) -> syn::Result<&syn::LitSt )) } } + +pub fn get_lit_bool(attr_name: Symbol, lit: &syn::Lit) -> syn::Result { + if let syn::Lit::Bool(lit) = lit { + Ok(lit.value()) + } else { + Err(syn::Error::new_spanned( + lit, + format!( + "expected {} attribute to be a bool value, `true` or `false`: `{} = ...`", + attr_name, attr_name + ), + )) + } +} diff --git a/crates/bevy_render/macros/src/as_bind_group.rs b/crates/bevy_render/macros/src/as_bind_group.rs index f0c196dfd308c..30888a5227055 100644 --- a/crates/bevy_render/macros/src/as_bind_group.rs +++ b/crates/bevy_render/macros/src/as_bind_group.rs @@ -1,17 +1,17 @@ -use bevy_macro_utils::BevyManifest; +use bevy_macro_utils::{get_lit_bool, get_lit_str, BevyManifest, Symbol}; use proc_macro::TokenStream; use proc_macro2::{Ident, Span}; -use quote::quote; +use quote::{quote, ToTokens}; use syn::{ - parse::ParseStream, parse_macro_input, token::Comma, Data, DataStruct, DeriveInput, Field, - Fields, LitInt, + parse::{Parse, ParseStream}, + punctuated::Punctuated, + Data, DataStruct, Error, Fields, LitInt, LitStr, NestedMeta, Result, Token, }; -const BINDING_ATTRIBUTE_NAME: &str = "binding"; -const UNIFORM_ATTRIBUTE_NAME: &str = "uniform"; -const TEXTURE_ATTRIBUTE_NAME: &str = "texture"; -const SAMPLER_ATTRIBUTE_NAME: &str = "sampler"; -const BIND_GROUP_DATA_ATTRIBUTE_NAME: &str = "bind_group_data"; +const UNIFORM_ATTRIBUTE_NAME: Symbol = Symbol("uniform"); +const TEXTURE_ATTRIBUTE_NAME: Symbol = Symbol("texture"); +const SAMPLER_ATTRIBUTE_NAME: Symbol = Symbol("sampler"); +const BIND_GROUP_DATA_ATTRIBUTE_NAME: Symbol = Symbol("bind_group_data"); #[derive(Copy, Clone, Debug)] enum BindingType { @@ -29,13 +29,11 @@ enum BindingState<'a> { }, OccupiedConvertedUniform, OccupiedMergableUniform { - uniform_fields: Vec<&'a Field>, + uniform_fields: Vec<&'a syn::Field>, }, } -pub fn derive_as_bind_group(input: TokenStream) -> TokenStream { - let ast = parse_macro_input!(input as DeriveInput); - +pub fn derive_as_bind_group(ast: syn::DeriveInput) -> Result { let manifest = BevyManifest::default(); let render_path = manifest.get_path("bevy_render"); let asset_path = manifest.get_path("bevy_asset"); @@ -56,18 +54,7 @@ pub fn derive_as_bind_group(input: TokenStream) -> TokenStream { attr_prepared_data_ident = Some(prepared_data_ident); } } else if attr_ident == UNIFORM_ATTRIBUTE_NAME { - let (binding_index, converted_shader_type) = attr - .parse_args_with(|input: ParseStream| { - let binding_index = input - .parse::() - .and_then(|i| i.base10_parse::())?; - input.parse::()?; - let converted_shader_type = input.parse::()?; - Ok((binding_index, converted_shader_type)) - }) - .unwrap_or_else(|_| { - panic!("struct-level uniform bindings must be in the format: uniform(BINDING_INDEX, ConvertedShaderType)"); - }); + let (binding_index, converted_shader_type) = get_uniform_binding_attr(attr)?; binding_impls.push(quote! {{ use #render_path::render_resource::AsBindGroupShaderType; @@ -118,7 +105,12 @@ pub fn derive_as_bind_group(input: TokenStream) -> TokenStream { fields: Fields::Named(fields), .. }) => &fields.named, - _ => panic!("Expected a struct with named fields"), + _ => { + return Err(Error::new_spanned( + ast, + "Expected a struct with named fields", + )); + } }; // Read field-level attributes @@ -140,17 +132,7 @@ pub fn derive_as_bind_group(input: TokenStream) -> TokenStream { continue; }; - let binding_index = attr - .parse_args_with(|input: ParseStream| { - let binding_index = input - .parse::() - .and_then(|i| i.base10_parse::()) - .expect("binding index was not a valid u32"); - Ok(binding_index) - }) - .unwrap_or_else(|_| { - panic!("Invalid `{}` attribute format", BINDING_ATTRIBUTE_NAME) - }); + let (binding_index, nested_meta_items) = get_binding_nested_attr(attr)?; let field_name = field.ident.as_ref().unwrap(); let required_len = binding_index as usize + 1; @@ -175,23 +157,36 @@ pub fn derive_as_bind_group(input: TokenStream) -> TokenStream { } }); BindingState::Occupied { - binding_type, - ident: field_name, - }}, + binding_type, + ident: field_name, + } + } } - }, - BindingState::Occupied { binding_type, ident: occupied_ident} => panic!( - "The '{field_name}' field cannot be assigned to binding {binding_index} because it is already occupied by the field '{occupied_ident}' of type {binding_type:?}." - ), - BindingState::OccupiedConvertedUniform => panic!( - "The '{field_name}' field cannot be assigned to binding {binding_index} because it is already occupied by a struct-level uniform binding at the same index." - ), - BindingState::OccupiedMergableUniform { uniform_fields } => { - match binding_type { - BindingType::Uniform => { - uniform_fields.push(field); - }, - _ => {panic!("The '{field_name}' field cannot be assigned to binding {binding_index} because it is already occupied by a {:?}.", BindingType::Uniform)}, + } + BindingState::Occupied { + binding_type, + ident: occupied_ident, + } => { + return Err(Error::new_spanned( + attr, + format!("The '{field_name}' field cannot be assigned to binding {binding_index} because it is already occupied by the field '{occupied_ident}' of type {binding_type:?}.") + )); + } + BindingState::OccupiedConvertedUniform => { + return Err(Error::new_spanned( + attr, + format!("The '{field_name}' field cannot be assigned to binding {binding_index} because it is already occupied by a struct-level uniform binding at the same index.") + )); + } + BindingState::OccupiedMergableUniform { uniform_fields } => match binding_type { + BindingType::Uniform => { + uniform_fields.push(field); + } + _ => { + return Err(Error::new_spanned( + attr, + format!("The '{field_name}' field cannot be assigned to binding {binding_index} because it is already occupied by a {:?}.", BindingType::Uniform) + )); } }, } @@ -200,6 +195,16 @@ pub fn derive_as_bind_group(input: TokenStream) -> TokenStream { BindingType::Uniform => { /* uniform codegen is deferred to account for combined uniform bindings */ } BindingType::Texture => { + let TextureAttrs { + dimension, + sample_type, + multisampled, + visibility, + } = get_texture_attrs(nested_meta_items)?; + + let visibility = + visibility.hygenic_quote("e! { #render_path::render_resource }); + binding_impls.push(quote! { #render_path::render_resource::OwnedBindingResource::TextureView({ let handle: Option<&#asset_path::Handle<#render_path::texture::Image>> = (&self.#field_name).into(); @@ -211,20 +216,28 @@ pub fn derive_as_bind_group(input: TokenStream) -> TokenStream { }) }); - binding_layouts.push(quote!{ + binding_layouts.push(quote! { #render_path::render_resource::BindGroupLayoutEntry { binding: #binding_index, - visibility: #render_path::render_resource::ShaderStages::all(), + visibility: #visibility, ty: #render_path::render_resource::BindingType::Texture { - multisampled: false, - sample_type: #render_path::render_resource::TextureSampleType::Float { filterable: true }, - view_dimension: #render_path::render_resource::TextureViewDimension::D2, + multisampled: #multisampled, + sample_type: #render_path::render_resource::#sample_type, + view_dimension: #render_path::render_resource::#dimension, }, count: None, } }); } BindingType::Sampler => { + let SamplerAttrs { + sampler_binding_type, + visibility, + } = get_sampler_attrs(nested_meta_items)?; + + let visibility = + visibility.hygenic_quote("e! { #render_path::render_resource }); + binding_impls.push(quote! { #render_path::render_resource::OwnedBindingResource::Sampler({ let handle: Option<&#asset_path::Handle<#render_path::texture::Image>> = (&self.#field_name).into(); @@ -239,8 +252,8 @@ pub fn derive_as_bind_group(input: TokenStream) -> TokenStream { binding_layouts.push(quote!{ #render_path::render_resource::BindGroupLayoutEntry { binding: #binding_index, - visibility: #render_path::render_resource::ShaderStages::all(), - ty: #render_path::render_resource::BindingType::Sampler(#render_path::render_resource::SamplerBindingType::Filtering), + visibility: #visibility, + ty: #render_path::render_resource::BindingType::Sampler(#render_path::render_resource::#sampler_binding_type), count: None, } }); @@ -297,6 +310,7 @@ pub fn derive_as_bind_group(input: TokenStream) -> TokenStream { &format!("_{struct_name}AsBindGroupUniformStructBindGroup{binding_index}"), Span::call_site(), ); + let field_name = uniform_fields.iter().map(|f| f.ident.as_ref().unwrap()); let field_type = uniform_fields.iter().map(|f| &f.ty); field_struct_impls.push(quote! { @@ -348,7 +362,7 @@ pub fn derive_as_bind_group(input: TokenStream) -> TokenStream { (prepared_data.clone(), prepared_data) }; - TokenStream::from(quote! { + Ok(TokenStream::from(quote! { #(#field_struct_impls)* impl #impl_generics #render_path::render_resource::AsBindGroup for #struct_name #ty_generics #where_clause { @@ -385,5 +399,467 @@ pub fn derive_as_bind_group(input: TokenStream) -> TokenStream { }) } } + })) +} + +/// Represents the arguments for the `uniform` binding attribute. +/// +/// If parsed, represents an attribute +/// like `#[uniform(LitInt, Ident)]` +struct UniformBindingMeta { + lit_int: LitInt, + _comma: Token![,], + ident: Ident, +} + +/// Represents the arguments for any general binding attribute. +/// +/// If parsed, represents an attribute +/// like `#[foo(LitInt, ...)]` where the rest is optional `NestedMeta`. +enum BindingMeta { + IndexOnly(LitInt), + IndexWithOptions(BindingIndexOptions), +} + +/// Represents the arguments for an attribute with a list of arguments. +/// +/// This represents, for example, `#[texture(0, dimension = "2d_array")]`. +struct BindingIndexOptions { + lit_int: LitInt, + _comma: Token![,], + meta_list: Punctuated, +} + +impl Parse for BindingMeta { + fn parse(input: ParseStream) -> Result { + if input.peek2(Token![,]) { + input.parse().map(Self::IndexWithOptions) + } else { + input.parse().map(Self::IndexOnly) + } + } +} + +impl Parse for BindingIndexOptions { + fn parse(input: ParseStream) -> Result { + Ok(Self { + lit_int: input.parse()?, + _comma: input.parse()?, + meta_list: input.parse_terminated(NestedMeta::parse)?, + }) + } +} + +impl Parse for UniformBindingMeta { + fn parse(input: ParseStream) -> Result { + Ok(Self { + lit_int: input.parse()?, + _comma: input.parse()?, + ident: input.parse()?, + }) + } +} + +fn get_uniform_binding_attr(attr: &syn::Attribute) -> Result<(u32, Ident)> { + let uniform_binding_meta = attr.parse_args_with(UniformBindingMeta::parse)?; + + let binding_index = uniform_binding_meta.lit_int.base10_parse()?; + let ident = uniform_binding_meta.ident; + + Ok((binding_index, ident)) +} + +fn get_binding_nested_attr(attr: &syn::Attribute) -> Result<(u32, Vec)> { + let binding_meta = attr.parse_args_with(BindingMeta::parse)?; + + match binding_meta { + BindingMeta::IndexOnly(lit_int) => Ok((lit_int.base10_parse()?, Vec::new())), + BindingMeta::IndexWithOptions(BindingIndexOptions { + lit_int, + _comma: _, + meta_list, + }) => Ok((lit_int.base10_parse()?, meta_list.into_iter().collect())), + } +} + +#[derive(Default)] +enum ShaderStageVisibility { + #[default] + All, + None, + Flags(VisibilityFlags), +} + +#[derive(Default)] +struct VisibilityFlags { + vertex: bool, + fragment: bool, + compute: bool, +} + +impl ShaderStageVisibility { + fn vertex_fragment() -> Self { + Self::Flags(VisibilityFlags::vertex_fragment()) + } +} + +impl VisibilityFlags { + fn vertex_fragment() -> Self { + Self { + vertex: true, + fragment: true, + ..Default::default() + } + } +} + +impl ShaderStageVisibility { + fn hygenic_quote(&self, path: &proc_macro2::TokenStream) -> proc_macro2::TokenStream { + match self { + ShaderStageVisibility::All => quote! { #path::ShaderStages::all() }, + ShaderStageVisibility::None => quote! { #path::ShaderStages::NONE }, + ShaderStageVisibility::Flags(flags) => { + let mut quoted = Vec::new(); + + if flags.vertex { + quoted.push(quote! { #path::ShaderStages::VERTEX }); + } + if flags.fragment { + quoted.push(quote! { #path::ShaderStages::FRAGMENT }); + } + if flags.compute { + quoted.push(quote! { #path::ShaderStages::COMPUTE }); + } + + quote! { #(#quoted)|* } + } + } + } +} + +const VISIBILITY: Symbol = Symbol("visibility"); +const VISIBILITY_VERTEX: Symbol = Symbol("vertex"); +const VISIBILITY_FRAGMENT: Symbol = Symbol("fragment"); +const VISIBILITY_COMPUTE: Symbol = Symbol("compute"); +const VISIBILITY_ALL: Symbol = Symbol("all"); +const VISIBILITY_NONE: Symbol = Symbol("none"); + +fn get_visibility_flag_value( + nested_metas: &Punctuated, +) -> Result { + let mut visibility = VisibilityFlags::vertex_fragment(); + + for meta in nested_metas { + use syn::{Meta::Path, NestedMeta::Meta}; + match meta { + // Parse `visibility(all)]`. + Meta(Path(path)) if path == VISIBILITY_ALL => { + return Ok(ShaderStageVisibility::All) + } + // Parse `visibility(none)]`. + Meta(Path(path)) if path == VISIBILITY_NONE => { + return Ok(ShaderStageVisibility::None) + } + // Parse `visibility(vertex, ...)]`. + Meta(Path(path)) if path == VISIBILITY_VERTEX => { + visibility.vertex = true; + } + // Parse `visibility(fragment, ...)]`. + Meta(Path(path)) if path == VISIBILITY_FRAGMENT => { + visibility.fragment = true; + } + // Parse `visibility(compute, ...)]`. + Meta(Path(path)) if path == VISIBILITY_COMPUTE => { + visibility.compute = true; + } + Meta(Path(path)) => return Err(Error::new_spanned( + path, + "Not a valid visibility flag. Must be `all`, `none`, or a list-combination of `vertex`, `fragment` and/or `compute`." + )), + _ => return Err(Error::new_spanned( + meta, + "Invalid visibility format: `visibility(...)`.", + )), + } + } + + Ok(ShaderStageVisibility::Flags(visibility)) +} + +#[derive(Default)] +enum BindingTextureDimension { + D1, + #[default] + D2, + D2Array, + Cube, + CubeArray, + D3, +} + +enum BindingTextureSampleType { + Float { filterable: bool }, + Depth, + Sint, + Uint, +} + +impl ToTokens for BindingTextureDimension { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + tokens.extend(match self { + BindingTextureDimension::D1 => quote! { TextureViewDimension::D1 }, + BindingTextureDimension::D2 => quote! { TextureViewDimension::D2 }, + BindingTextureDimension::D2Array => quote! { TextureViewDimension::D2Array }, + BindingTextureDimension::Cube => quote! { TextureViewDimension::Cube }, + BindingTextureDimension::CubeArray => quote! { TextureViewDimension::CubeArray }, + BindingTextureDimension::D3 => quote! { TextureViewDimension::D3 }, + }); + } +} + +impl ToTokens for BindingTextureSampleType { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + tokens.extend(match self { + BindingTextureSampleType::Float { filterable } => { + quote! { TextureSampleType::Float { filterable: #filterable } } + } + BindingTextureSampleType::Depth => quote! { TextureSampleType::Depth }, + BindingTextureSampleType::Sint => quote! { TextureSampleType::Sint }, + BindingTextureSampleType::Uint => quote! { TextureSampleType::Uint }, + }); + } +} + +struct TextureAttrs { + dimension: BindingTextureDimension, + sample_type: BindingTextureSampleType, + multisampled: bool, + visibility: ShaderStageVisibility, +} + +impl Default for BindingTextureSampleType { + fn default() -> Self { + BindingTextureSampleType::Float { filterable: true } + } +} + +impl Default for TextureAttrs { + fn default() -> Self { + Self { + dimension: Default::default(), + sample_type: Default::default(), + multisampled: true, + visibility: Default::default(), + } + } +} + +const DIMENSION: Symbol = Symbol("dimension"); +const SAMPLE_TYPE: Symbol = Symbol("sample_type"); +const FILTERABLE: Symbol = Symbol("filterable"); +const MULTISAMPLED: Symbol = Symbol("multisampled"); + +// Values for `dimension` attribute. +const DIM_1D: &str = "1d"; +const DIM_2D: &str = "2d"; +const DIM_3D: &str = "3d"; +const DIM_2D_ARRAY: &str = "2d_array"; +const DIM_CUBE: &str = "cube"; +const DIM_CUBE_ARRAY: &str = "cube_array"; + +// Values for sample `type` attribute. +const FLOAT: &str = "float"; +const DEPTH: &str = "depth"; +const S_INT: &str = "s_int"; +const U_INT: &str = "u_int"; + +fn get_texture_attrs(metas: Vec) -> Result { + let mut dimension = Default::default(); + let mut sample_type = Default::default(); + let mut multisampled = Default::default(); + let mut filterable = None; + let mut filterable_ident = None; + + let mut visibility = ShaderStageVisibility::vertex_fragment(); + + for meta in metas { + use syn::{ + Meta::{List, NameValue}, + NestedMeta::Meta, + }; + match meta { + // Parse #[texture(0, dimension = "...")]. + Meta(NameValue(m)) if m.path == DIMENSION => { + let value = get_lit_str(DIMENSION, &m.lit)?; + dimension = get_texture_dimension_value(value)?; + } + // Parse #[texture(0, sample_type = "...")]. + Meta(NameValue(m)) if m.path == SAMPLE_TYPE => { + let value = get_lit_str(SAMPLE_TYPE, &m.lit)?; + sample_type = get_texture_sample_type_value(value)?; + } + // Parse #[texture(0, multisampled = "...")]. + Meta(NameValue(m)) if m.path == MULTISAMPLED => { + multisampled = get_lit_bool(MULTISAMPLED, &m.lit)?; + } + // Parse #[texture(0, filterable = "...")]. + Meta(NameValue(m)) if m.path == FILTERABLE => { + filterable = get_lit_bool(FILTERABLE, &m.lit)?.into(); + filterable_ident = m.path.into(); + } + // Parse #[texture(0, visibility(...))]. + Meta(List(m)) if m.path == VISIBILITY => { + visibility = get_visibility_flag_value(&m.nested)?; + } + Meta(NameValue(m)) => { + return Err(Error::new_spanned( + m.path, + "Not a valid name. Available attributes: `dimension`, `sample_type`, `multisampled`, or `filterable`." + )); + } + _ => { + return Err(Error::new_spanned( + meta, + "Not a name value pair: `foo = \"...\"`", + )); + } + } + } + + // Resolve `filterable` since the float + // sample type is the one that contains the value. + if let Some(filterable) = filterable { + let path = filterable_ident.unwrap(); + match sample_type { + BindingTextureSampleType::Float { filterable: _ } => { + sample_type = BindingTextureSampleType::Float { filterable } + } + _ => { + return Err(Error::new_spanned( + path, + "Type must be `float` to use the `filterable` attribute.", + )); + } + }; + } + + Ok(TextureAttrs { + dimension, + sample_type, + multisampled, + visibility, }) } + +fn get_texture_dimension_value(lit_str: &LitStr) -> Result { + match lit_str.value().as_str() { + DIM_1D => Ok(BindingTextureDimension::D1), + DIM_2D => Ok(BindingTextureDimension::D2), + DIM_2D_ARRAY => Ok(BindingTextureDimension::D2Array), + DIM_3D => Ok(BindingTextureDimension::D3), + DIM_CUBE => Ok(BindingTextureDimension::Cube), + DIM_CUBE_ARRAY => Ok(BindingTextureDimension::CubeArray), + + _ => Err(Error::new_spanned( + lit_str, + "Not a valid dimension. Must be `1d`, `2d`, `2d_array`, `3d`, `cube` or `cube_array`.", + )), + } +} + +fn get_texture_sample_type_value(lit_str: &LitStr) -> Result { + match lit_str.value().as_str() { + FLOAT => Ok(BindingTextureSampleType::Float { filterable: true }), + DEPTH => Ok(BindingTextureSampleType::Depth), + S_INT => Ok(BindingTextureSampleType::Sint), + U_INT => Ok(BindingTextureSampleType::Uint), + + _ => Err(Error::new_spanned( + lit_str, + "Not a valid sample type. Must be `float`, `depth`, `s_int` or `u_int`.", + )), + } +} + +#[derive(Default)] +struct SamplerAttrs { + sampler_binding_type: SamplerBindingType, + visibility: ShaderStageVisibility, +} + +#[derive(Default)] +enum SamplerBindingType { + #[default] + Filtering, + NonFiltering, + Comparison, +} + +impl ToTokens for SamplerBindingType { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + tokens.extend(match self { + SamplerBindingType::Filtering => quote! { SamplerBindingType::Filtering }, + SamplerBindingType::NonFiltering => quote! { SamplerBindingType::NonFiltering }, + SamplerBindingType::Comparison => quote! { SamplerBindingType::Comparison }, + }); + } +} + +const SAMPLER_TYPE: Symbol = Symbol("sampler_type"); + +const FILTERING: &str = "filtering"; +const NON_FILTERING: &str = "non_filtering"; +const COMPARISON: &str = "comparison"; + +fn get_sampler_attrs(metas: Vec) -> Result { + let mut sampler_binding_type = Default::default(); + let mut visibility = ShaderStageVisibility::vertex_fragment(); + + for meta in metas { + use syn::{ + Meta::{List, NameValue}, + NestedMeta::Meta, + }; + match meta { + // Parse #[sampler(0, sampler_type = "..."))]. + Meta(NameValue(m)) if m.path == SAMPLER_TYPE => { + let value = get_lit_str(DIMENSION, &m.lit)?; + sampler_binding_type = get_sampler_binding_type_value(value)?; + } + // Parse #[sampler(0, visibility(...))]. + Meta(List(m)) if m.path == VISIBILITY => { + visibility = get_visibility_flag_value(&m.nested)?; + } + Meta(NameValue(m)) => { + return Err(Error::new_spanned( + m.path, + "Not a valid name. Available attributes: `sampler_type`.", + )); + } + _ => { + return Err(Error::new_spanned( + meta, + "Not a name value pair: `foo = \"...\"`", + )); + } + } + } + + Ok(SamplerAttrs { + sampler_binding_type, + visibility, + }) +} + +fn get_sampler_binding_type_value(lit_str: &LitStr) -> Result { + match lit_str.value().as_str() { + FILTERING => Ok(SamplerBindingType::Filtering), + NON_FILTERING => Ok(SamplerBindingType::NonFiltering), + COMPARISON => Ok(SamplerBindingType::Comparison), + + _ => Err(Error::new_spanned( + lit_str, + "Not a valid dimension. Must be `filtering`, `non_filtering`, or `comparison`.", + )), + } +} diff --git a/crates/bevy_render/macros/src/lib.rs b/crates/bevy_render/macros/src/lib.rs index 5e0851a69a3dc..863b48baf5fa0 100644 --- a/crates/bevy_render/macros/src/lib.rs +++ b/crates/bevy_render/macros/src/lib.rs @@ -3,6 +3,7 @@ mod extract_resource; use bevy_macro_utils::BevyManifest; use proc_macro::TokenStream; +use syn::{parse_macro_input, DeriveInput}; pub(crate) fn bevy_render_path() -> syn::Path { BevyManifest::default() @@ -18,5 +19,7 @@ pub fn derive_extract_resource(input: TokenStream) -> TokenStream { #[proc_macro_derive(AsBindGroup, attributes(uniform, texture, sampler, bind_group_data))] pub fn derive_as_bind_group(input: TokenStream) -> TokenStream { - as_bind_group::derive_as_bind_group(input) + let input = parse_macro_input!(input as DeriveInput); + + as_bind_group::derive_as_bind_group(input).unwrap_or_else(|err| err.to_compile_error().into()) } diff --git a/examples/shader/array_texture.rs b/examples/shader/array_texture.rs index 59873724765bd..2c288625d8fb9 100644 --- a/examples/shader/array_texture.rs +++ b/examples/shader/array_texture.rs @@ -2,17 +2,7 @@ use bevy::{ asset::LoadState, prelude::*, reflect::TypeUuid, - render::{ - render_asset::RenderAssets, - render_resource::{ - AsBindGroup, AsBindGroupError, BindGroupDescriptor, BindGroupEntry, BindGroupLayout, - BindGroupLayoutDescriptor, BindGroupLayoutEntry, BindingResource, BindingType, - OwnedBindingResource, PreparedBindGroup, SamplerBindingType, ShaderRef, ShaderStages, - TextureSampleType, TextureViewDimension, - }, - renderer::RenderDevice, - texture::FallbackImage, - }, + render::render_resource::{AsBindGroup, ShaderRef}, }; /// This example illustrates how to create a texture for use with a `texture_2d_array` shader @@ -98,9 +88,11 @@ fn create_array_texture( } } -#[derive(Debug, Clone, TypeUuid)] +#[derive(AsBindGroup, Debug, Clone, TypeUuid)] #[uuid = "9c5a0ddf-1eaf-41b4-9832-ed736fd26af3"] struct ArrayTextureMaterial { + #[texture(0, dimension = "2d_array")] + #[sampler(1)] array_texture: Handle, } @@ -109,68 +101,3 @@ impl Material for ArrayTextureMaterial { "shaders/array_texture.wgsl".into() } } - -impl AsBindGroup for ArrayTextureMaterial { - type Data = (); - - fn as_bind_group( - &self, - layout: &BindGroupLayout, - render_device: &RenderDevice, - images: &RenderAssets, - _fallback_image: &FallbackImage, - ) -> Result, AsBindGroupError> { - let image = images - .get(&self.array_texture) - .ok_or(AsBindGroupError::RetryNextUpdate)?; - let bind_group = render_device.create_bind_group(&BindGroupDescriptor { - entries: &[ - BindGroupEntry { - binding: 0, - resource: BindingResource::TextureView(&image.texture_view), - }, - BindGroupEntry { - binding: 1, - resource: BindingResource::Sampler(&image.sampler), - }, - ], - label: Some("array_texture_material_bind_group"), - layout, - }); - - Ok(PreparedBindGroup { - bind_group, - bindings: vec![ - OwnedBindingResource::TextureView(image.texture_view.clone()), - OwnedBindingResource::Sampler(image.sampler.clone()), - ], - data: (), - }) - } - - fn bind_group_layout(render_device: &RenderDevice) -> BindGroupLayout { - render_device.create_bind_group_layout(&BindGroupLayoutDescriptor { - entries: &[ - // Array Texture - BindGroupLayoutEntry { - binding: 0, - visibility: ShaderStages::FRAGMENT, - ty: BindingType::Texture { - multisampled: false, - sample_type: TextureSampleType::Float { filterable: true }, - view_dimension: TextureViewDimension::D2Array, - }, - count: None, - }, - // Array Texture Sampler - BindGroupLayoutEntry { - binding: 1, - visibility: ShaderStages::FRAGMENT, - ty: BindingType::Sampler(SamplerBindingType::Filtering), - count: None, - }, - ], - label: None, - }) - } -}