diff --git a/Cargo.toml b/Cargo.toml index 47d8ea3c9fe93..d47dab1d6f7af 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1385,11 +1385,21 @@ name = "post_processing" path = "examples/shader/post_processing.rs" [package.metadata.example.post_processing] -name = "Post Processing" +name = "Post Processing - Render To Texture" description = "A custom post processing effect, using two cameras, with one reusing the render texture of the first one" category = "Shaders" wasm = true +[[example]] +name = "post_process_pass" +path = "examples/shader/post_process_pass.rs" + +[package.metadata.example.post_process_pass] +name = "Post Processing - Custom Render Pass" +description = "A custom post processing effect, using a custom render pass that runs after the main pass" +category = "Shaders" +wasm = true + [[example]] name = "shader_defs" path = "examples/shader/shader_defs.rs" diff --git a/assets/shaders/post_process_pass.wgsl b/assets/shaders/post_process_pass.wgsl new file mode 100644 index 0000000000000..b25b5788cc8a6 --- /dev/null +++ b/assets/shaders/post_process_pass.wgsl @@ -0,0 +1,48 @@ +// This shader computes the chromatic aberration effect + +#import bevy_pbr::utils + +// Since post processing is a fullscreen effect, we use the fullscreen vertex shader provided by bevy. +// This will import a vertex shader that renders a single fullscreen triangle. +// +// A fullscreen triangle is a single triangle that covers the entire screen. +// The box in the top left in that diagram is the screen. The 4 x are the corner of the screen +// +// Y axis +// 1 | x-----x...... +// 0 | | s | . ´ +// -1 | x_____x´ +// -2 | : .´ +// -3 | :´ +// +--------------- X axis +// -1 0 1 2 3 +// +// As you can see, the triangle ends up bigger than the screen. +// +// You don't need to worry about this too much since bevy will compute the correct UVs for you. +#import bevy_core_pipeline::fullscreen_vertex_shader + +@group(0) @binding(0) +var screen_texture: texture_2d; +@group(0) @binding(1) +var texture_sampler: sampler; +struct PostProcessSettings { + intensity: f32, +} +@group(0) @binding(2) +var settings: PostProcessSettings; + +@fragment +fn fragment(in: FullscreenVertexOutput) -> @location(0) vec4 { + // Chromatic aberration strength + let offset_strength = settings.intensity; + + // Sample each color channel with an arbitrary shift + return vec4( + textureSample(screen_texture, texture_sampler, in.uv + vec2(offset_strength, -offset_strength)).r, + textureSample(screen_texture, texture_sampler, in.uv + vec2(-offset_strength, 0.0)).g, + textureSample(screen_texture, texture_sampler, in.uv + vec2(0.0, offset_strength)).b, + 1.0 + ); +} + diff --git a/crates/bevy_core_pipeline/src/bloom/mod.rs b/crates/bevy_core_pipeline/src/bloom/mod.rs index 5ff8cfba4e376..8b5c58d80210c 100644 --- a/crates/bevy_core_pipeline/src/bloom/mod.rs +++ b/crates/bevy_core_pipeline/src/bloom/mod.rs @@ -4,7 +4,7 @@ mod upsampling_pipeline; pub use settings::{BloomCompositeMode, BloomPrefilterSettings, BloomSettings}; -use crate::{core_2d, core_3d}; +use crate::{add_node, core_2d, core_3d}; use bevy_app::{App, Plugin}; use bevy_asset::{load_internal_asset, HandleUntyped}; use bevy_ecs::{ @@ -12,7 +12,7 @@ use bevy_ecs::{ query::{QueryState, With}, schedule::IntoSystemConfig, system::{Commands, Query, Res, ResMut}, - world::World, + world::{FromWorld, World}, }; use bevy_math::UVec2; use bevy_reflect::TypeUuid; @@ -22,7 +22,7 @@ use bevy_render::{ ComponentUniforms, DynamicUniformIndex, ExtractComponentPlugin, UniformComponentPlugin, }, prelude::Color, - render_graph::{Node, NodeRunError, RenderGraph, RenderGraphContext, SlotInfo, SlotType}, + render_graph::{Node, NodeRunError, RenderGraphContext}, render_resource::*, renderer::{RenderContext, RenderDevice}, texture::{CachedTexture, TextureCache}, @@ -79,54 +79,28 @@ impl Plugin for BloomPlugin { )); // Add bloom to the 3d render graph - { - let bloom_node = BloomNode::new(&mut render_app.world); - let mut graph = render_app.world.resource_mut::(); - let draw_3d_graph = graph - .get_sub_graph_mut(crate::core_3d::graph::NAME) - .unwrap(); - draw_3d_graph.add_node(core_3d::graph::node::BLOOM, bloom_node); - draw_3d_graph.add_slot_edge( - draw_3d_graph.input_node().id, - crate::core_3d::graph::input::VIEW_ENTITY, - core_3d::graph::node::BLOOM, - BloomNode::IN_VIEW, - ); - // MAIN_PASS -> BLOOM -> TONEMAPPING - draw_3d_graph.add_node_edge( - crate::core_3d::graph::node::MAIN_PASS, - core_3d::graph::node::BLOOM, - ); - draw_3d_graph.add_node_edge( + add_node::( + render_app, + core_3d::graph::NAME, + core_3d::graph::node::BLOOM, + &[ + core_3d::graph::node::MAIN_PASS, core_3d::graph::node::BLOOM, - crate::core_3d::graph::node::TONEMAPPING, - ); - } + core_3d::graph::node::TONEMAPPING, + ], + ); // Add bloom to the 2d render graph - { - let bloom_node = BloomNode::new(&mut render_app.world); - let mut graph = render_app.world.resource_mut::(); - let draw_2d_graph = graph - .get_sub_graph_mut(crate::core_2d::graph::NAME) - .unwrap(); - draw_2d_graph.add_node(core_2d::graph::node::BLOOM, bloom_node); - draw_2d_graph.add_slot_edge( - draw_2d_graph.input_node().id, - crate::core_2d::graph::input::VIEW_ENTITY, + add_node::( + render_app, + core_2d::graph::NAME, + core_2d::graph::node::BLOOM, + &[ + core_2d::graph::node::MAIN_PASS, core_2d::graph::node::BLOOM, - BloomNode::IN_VIEW, - ); - // MAIN_PASS -> BLOOM -> TONEMAPPING - draw_2d_graph.add_node_edge( - crate::core_2d::graph::node::MAIN_PASS, - core_2d::graph::node::BLOOM, - ); - draw_2d_graph.add_node_edge( - core_2d::graph::node::BLOOM, - crate::core_2d::graph::node::TONEMAPPING, - ); - } + core_2d::graph::node::TONEMAPPING, + ], + ); } } @@ -143,10 +117,8 @@ pub struct BloomNode { )>, } -impl BloomNode { - pub const IN_VIEW: &'static str = "view"; - - pub fn new(world: &mut World) -> Self { +impl FromWorld for BloomNode { + fn from_world(world: &mut World) -> Self { Self { view_query: QueryState::new(world), } @@ -154,10 +126,6 @@ impl BloomNode { } impl Node for BloomNode { - fn input(&self) -> Vec { - vec![SlotInfo::new(Self::IN_VIEW, SlotType::Entity)] - } - fn update(&mut self, world: &mut World) { self.view_query.update_archetypes(world); } @@ -177,7 +145,7 @@ impl Node for BloomNode { let downsampling_pipeline_res = world.resource::(); let pipeline_cache = world.resource::(); let uniforms = world.resource::>(); - let view_entity = graph.get_input_entity(Self::IN_VIEW)?; + let view_entity = graph.view_entity(); let Ok(( camera, view_target, diff --git a/crates/bevy_core_pipeline/src/core_2d/main_pass_2d_node.rs b/crates/bevy_core_pipeline/src/core_2d/main_pass_2d_node.rs index b5660c4c0aa58..f095cf237ffed 100644 --- a/crates/bevy_core_pipeline/src/core_2d/main_pass_2d_node.rs +++ b/crates/bevy_core_pipeline/src/core_2d/main_pass_2d_node.rs @@ -5,7 +5,7 @@ use crate::{ use bevy_ecs::prelude::*; use bevy_render::{ camera::ExtractedCamera, - render_graph::{Node, NodeRunError, RenderGraphContext, SlotInfo, SlotType}, + render_graph::{Node, NodeRunError, RenderGraphContext}, render_phase::RenderPhase, render_resource::{LoadOp, Operations, RenderPassDescriptor}, renderer::RenderContext, @@ -37,10 +37,6 @@ impl MainPass2dNode { } impl Node for MainPass2dNode { - fn input(&self) -> Vec { - vec![SlotInfo::new(MainPass2dNode::IN_VIEW, SlotType::Entity)] - } - fn update(&mut self, world: &mut World) { self.query.update_archetypes(world); } @@ -51,7 +47,7 @@ impl Node for MainPass2dNode { render_context: &mut RenderContext, world: &World, ) -> Result<(), NodeRunError> { - let view_entity = graph.get_input_entity(Self::IN_VIEW)?; + let view_entity = graph.view_entity(); let (camera, transparent_phase, target, camera_2d) = if let Ok(result) = self.query.get_manual(world, view_entity) { result diff --git a/crates/bevy_core_pipeline/src/core_2d/mod.rs b/crates/bevy_core_pipeline/src/core_2d/mod.rs index a94da4d0b3651..6df5a9b32d258 100644 --- a/crates/bevy_core_pipeline/src/core_2d/mod.rs +++ b/crates/bevy_core_pipeline/src/core_2d/mod.rs @@ -25,7 +25,7 @@ use bevy_ecs::prelude::*; use bevy_render::{ camera::Camera, extract_component::ExtractComponentPlugin, - render_graph::{EmptyNode, RenderGraph, SlotInfo, SlotType}, + render_graph::{EmptyNode, RenderGraph}, render_phase::{ batch_phase_system, sort_phase_system, BatchedPhaseItem, CachedRenderPipelinePhaseItem, DrawFunctionId, DrawFunctions, PhaseItem, RenderPhase, @@ -61,46 +61,33 @@ impl Plugin for Core2dPlugin { )); let pass_node_2d = MainPass2dNode::new(&mut render_app.world); - let tonemapping = TonemappingNode::new(&mut render_app.world); - let upscaling = UpscalingNode::new(&mut render_app.world); + let tonemapping = TonemappingNode::from_world(&mut render_app.world); + let upscaling = UpscalingNode::from_world(&mut render_app.world); let mut graph = render_app.world.resource_mut::(); let mut draw_2d_graph = RenderGraph::default(); - draw_2d_graph.add_node(graph::node::MAIN_PASS, pass_node_2d); - draw_2d_graph.add_node(graph::node::TONEMAPPING, tonemapping); + + draw_2d_graph.add_node_with_edges(graph::node::MAIN_PASS, pass_node_2d, &[]); draw_2d_graph.add_node(graph::node::END_MAIN_PASS_POST_PROCESSING, EmptyNode); - draw_2d_graph.add_node(graph::node::UPSCALING, upscaling); - let input_node_id = draw_2d_graph.set_input(vec![SlotInfo::new( - graph::input::VIEW_ENTITY, - SlotType::Entity, - )]); - draw_2d_graph.add_slot_edge( - input_node_id, - graph::input::VIEW_ENTITY, - graph::node::MAIN_PASS, - MainPass2dNode::IN_VIEW, - ); - draw_2d_graph.add_slot_edge( - input_node_id, - graph::input::VIEW_ENTITY, - graph::node::TONEMAPPING, - TonemappingNode::IN_VIEW, - ); - draw_2d_graph.add_slot_edge( - input_node_id, - graph::input::VIEW_ENTITY, - graph::node::UPSCALING, - UpscalingNode::IN_VIEW, - ); - draw_2d_graph.add_node_edge(graph::node::MAIN_PASS, graph::node::TONEMAPPING); - draw_2d_graph.add_node_edge( + + draw_2d_graph.add_node_with_edges( graph::node::TONEMAPPING, - graph::node::END_MAIN_PASS_POST_PROCESSING, + tonemapping, + &[ + graph::node::MAIN_PASS, + graph::node::TONEMAPPING, + graph::node::END_MAIN_PASS_POST_PROCESSING, + ], ); - draw_2d_graph.add_node_edge( - graph::node::END_MAIN_PASS_POST_PROCESSING, + draw_2d_graph.add_node_with_edges( graph::node::UPSCALING, + upscaling, + &[ + graph::node::END_MAIN_PASS_POST_PROCESSING, + graph::node::UPSCALING, + ], ); + graph.add_sub_graph(graph::NAME, draw_2d_graph); } } diff --git a/crates/bevy_core_pipeline/src/core_3d/main_pass_3d_node.rs b/crates/bevy_core_pipeline/src/core_3d/main_pass_3d_node.rs index 5003fbfd538f1..8ea24327fa5ec 100644 --- a/crates/bevy_core_pipeline/src/core_3d/main_pass_3d_node.rs +++ b/crates/bevy_core_pipeline/src/core_3d/main_pass_3d_node.rs @@ -6,7 +6,7 @@ use crate::{ use bevy_ecs::prelude::*; use bevy_render::{ camera::ExtractedCamera, - render_graph::{Node, NodeRunError, RenderGraphContext, SlotInfo, SlotType}, + render_graph::{Node, NodeRunError, RenderGraphContext}, render_phase::RenderPhase, render_resource::{LoadOp, Operations, RenderPassDepthStencilAttachment, RenderPassDescriptor}, renderer::RenderContext, @@ -34,21 +34,19 @@ pub struct MainPass3dNode { >, } -impl MainPass3dNode { - pub const IN_VIEW: &'static str = "view"; - - pub fn new(world: &mut World) -> Self { +impl FromWorld for MainPass3dNode { + fn from_world(world: &mut World) -> Self { Self { query: world.query_filtered(), } } } -impl Node for MainPass3dNode { - fn input(&self) -> Vec { - vec![SlotInfo::new(MainPass3dNode::IN_VIEW, SlotType::Entity)] - } +impl MainPass3dNode { + pub const IN_VIEW: &'static str = "view"; +} +impl Node for MainPass3dNode { fn update(&mut self, world: &mut World) { self.query.update_archetypes(world); } @@ -59,7 +57,7 @@ impl Node for MainPass3dNode { render_context: &mut RenderContext, world: &World, ) -> Result<(), NodeRunError> { - let view_entity = graph.get_input_entity(Self::IN_VIEW)?; + let view_entity = graph.view_entity(); let Ok(( camera, opaque_phase, @@ -71,6 +69,7 @@ impl Node for MainPass3dNode { depth_prepass, normal_prepass, )) = self.query.get_manual(world, view_entity) else { + bevy_utils::tracing::error!("no view_entity"); // No window return Ok(()); }; diff --git a/crates/bevy_core_pipeline/src/core_3d/mod.rs b/crates/bevy_core_pipeline/src/core_3d/mod.rs index 166f6765e0ea2..08421dc97e0fb 100644 --- a/crates/bevy_core_pipeline/src/core_3d/mod.rs +++ b/crates/bevy_core_pipeline/src/core_3d/mod.rs @@ -29,7 +29,7 @@ use bevy_render::{ camera::{Camera, ExtractedCamera}, extract_component::ExtractComponentPlugin, prelude::Msaa, - render_graph::{EmptyNode, RenderGraph, SlotInfo, SlotType}, + render_graph::{EmptyNode, RenderGraph}, render_phase::{ sort_phase_system, CachedRenderPipelinePhaseItem, DrawFunctionId, DrawFunctions, PhaseItem, RenderPhase, @@ -78,57 +78,40 @@ impl Plugin for Core3dPlugin { sort_phase_system::.in_set(RenderSet::PhaseSort), )); - let prepass_node = PrepassNode::new(&mut render_app.world); - let pass_node_3d = MainPass3dNode::new(&mut render_app.world); - let tonemapping = TonemappingNode::new(&mut render_app.world); - let upscaling = UpscalingNode::new(&mut render_app.world); - let mut graph = render_app.world.resource_mut::(); + let prepass_node = PrepassNode::from_world(&mut render_app.world); + let pass_node_3d = MainPass3dNode::from_world(&mut render_app.world); + let tonemapping = TonemappingNode::from_world(&mut render_app.world); + let upscaling = UpscalingNode::from_world(&mut render_app.world); + let mut graph = render_app.world.resource_mut::(); let mut draw_3d_graph = RenderGraph::default(); - draw_3d_graph.add_node(graph::node::PREPASS, prepass_node); - draw_3d_graph.add_node(graph::node::MAIN_PASS, pass_node_3d); - draw_3d_graph.add_node(graph::node::TONEMAPPING, tonemapping); + + draw_3d_graph.add_node_with_edges(graph::node::MAIN_PASS, pass_node_3d, &[]); draw_3d_graph.add_node(graph::node::END_MAIN_PASS_POST_PROCESSING, EmptyNode); - draw_3d_graph.add_node(graph::node::UPSCALING, upscaling); - - let input_node_id = draw_3d_graph.set_input(vec![SlotInfo::new( - graph::input::VIEW_ENTITY, - SlotType::Entity, - )]); - draw_3d_graph.add_slot_edge( - input_node_id, - graph::input::VIEW_ENTITY, + + draw_3d_graph.add_node_with_edges( graph::node::PREPASS, - PrepassNode::IN_VIEW, - ); - draw_3d_graph.add_slot_edge( - input_node_id, - graph::input::VIEW_ENTITY, - graph::node::MAIN_PASS, - MainPass3dNode::IN_VIEW, + prepass_node, + &[graph::node::PREPASS, graph::node::MAIN_PASS], ); - draw_3d_graph.add_slot_edge( - input_node_id, - graph::input::VIEW_ENTITY, + draw_3d_graph.add_node_with_edges( graph::node::TONEMAPPING, - TonemappingNode::IN_VIEW, + tonemapping, + &[ + graph::node::MAIN_PASS, + graph::node::TONEMAPPING, + graph::node::END_MAIN_PASS_POST_PROCESSING, + ], ); - draw_3d_graph.add_slot_edge( - input_node_id, - graph::input::VIEW_ENTITY, - graph::node::UPSCALING, - UpscalingNode::IN_VIEW, - ); - draw_3d_graph.add_node_edge(graph::node::PREPASS, graph::node::MAIN_PASS); - draw_3d_graph.add_node_edge(graph::node::MAIN_PASS, graph::node::TONEMAPPING); - draw_3d_graph.add_node_edge( - graph::node::TONEMAPPING, - graph::node::END_MAIN_PASS_POST_PROCESSING, - ); - draw_3d_graph.add_node_edge( - graph::node::END_MAIN_PASS_POST_PROCESSING, + draw_3d_graph.add_node_with_edges( graph::node::UPSCALING, + upscaling, + &[ + graph::node::END_MAIN_PASS_POST_PROCESSING, + graph::node::UPSCALING, + ], ); + graph.add_sub_graph(graph::NAME, draw_3d_graph); } } diff --git a/crates/bevy_core_pipeline/src/fxaa/mod.rs b/crates/bevy_core_pipeline/src/fxaa/mod.rs index 0892299f24fc8..03cc19e4094ec 100644 --- a/crates/bevy_core_pipeline/src/fxaa/mod.rs +++ b/crates/bevy_core_pipeline/src/fxaa/mod.rs @@ -1,4 +1,4 @@ -use crate::{core_2d, core_3d, fullscreen_vertex_shader::fullscreen_shader_vertex_state}; +use crate::{add_node, core_2d, core_3d, fullscreen_vertex_shader::fullscreen_shader_vertex_state}; use bevy_app::prelude::*; use bevy_asset::{load_internal_asset, HandleUntyped}; use bevy_derive::Deref; @@ -9,7 +9,6 @@ use bevy_reflect::{ use bevy_render::{ extract_component::{ExtractComponent, ExtractComponentPlugin}, prelude::Camera, - render_graph::RenderGraph, render_resource::*, renderer::RenderDevice, texture::BevyDefault, @@ -92,52 +91,27 @@ impl Plugin for FxaaPlugin { .init_resource::>() .add_system(prepare_fxaa_pipelines.in_set(RenderSet::Prepare)); - { - let fxaa_node = FxaaNode::new(&mut render_app.world); - let mut binding = render_app.world.resource_mut::(); - let graph = binding.get_sub_graph_mut(core_3d::graph::NAME).unwrap(); - - graph.add_node(core_3d::graph::node::FXAA, fxaa_node); - - graph.add_slot_edge( - graph.input_node().id, - core_3d::graph::input::VIEW_ENTITY, - core_3d::graph::node::FXAA, - FxaaNode::IN_VIEW, - ); - - graph.add_node_edge( + add_node::( + render_app, + core_3d::graph::NAME, + core_3d::graph::node::FXAA, + &[ core_3d::graph::node::TONEMAPPING, core_3d::graph::node::FXAA, - ); - graph.add_node_edge( - core_3d::graph::node::FXAA, core_3d::graph::node::END_MAIN_PASS_POST_PROCESSING, - ); - } - { - let fxaa_node = FxaaNode::new(&mut render_app.world); - let mut binding = render_app.world.resource_mut::(); - let graph = binding.get_sub_graph_mut(core_2d::graph::NAME).unwrap(); - - graph.add_node(core_2d::graph::node::FXAA, fxaa_node); - - graph.add_slot_edge( - graph.input_node().id, - core_2d::graph::input::VIEW_ENTITY, - core_2d::graph::node::FXAA, - FxaaNode::IN_VIEW, - ); + ], + ); - graph.add_node_edge( + add_node::( + render_app, + core_2d::graph::NAME, + core_2d::graph::node::FXAA, + &[ core_2d::graph::node::TONEMAPPING, core_2d::graph::node::FXAA, - ); - graph.add_node_edge( - core_2d::graph::node::FXAA, core_2d::graph::node::END_MAIN_PASS_POST_PROCESSING, - ); - } + ], + ); } } diff --git a/crates/bevy_core_pipeline/src/fxaa/node.rs b/crates/bevy_core_pipeline/src/fxaa/node.rs index 5050e3c4b3920..ef4a4b15975de 100644 --- a/crates/bevy_core_pipeline/src/fxaa/node.rs +++ b/crates/bevy_core_pipeline/src/fxaa/node.rs @@ -4,7 +4,7 @@ use crate::fxaa::{CameraFxaaPipeline, Fxaa, FxaaPipeline}; use bevy_ecs::prelude::*; use bevy_ecs::query::QueryState; use bevy_render::{ - render_graph::{Node, NodeRunError, RenderGraphContext, SlotInfo, SlotType}, + render_graph::{Node, NodeRunError, RenderGraphContext}, render_resource::{ BindGroup, BindGroupDescriptor, BindGroupEntry, BindingResource, FilterMode, Operations, PipelineCache, RenderPassColorAttachment, RenderPassDescriptor, SamplerDescriptor, @@ -29,8 +29,10 @@ pub struct FxaaNode { impl FxaaNode { pub const IN_VIEW: &'static str = "view"; +} - pub fn new(world: &mut World) -> Self { +impl FromWorld for FxaaNode { + fn from_world(world: &mut World) -> Self { Self { query: QueryState::new(world), cached_texture_bind_group: Mutex::new(None), @@ -39,10 +41,6 @@ impl FxaaNode { } impl Node for FxaaNode { - fn input(&self) -> Vec { - vec![SlotInfo::new(FxaaNode::IN_VIEW, SlotType::Entity)] - } - fn update(&mut self, world: &mut World) { self.query.update_archetypes(world); } @@ -53,7 +51,7 @@ impl Node for FxaaNode { render_context: &mut RenderContext, world: &World, ) -> Result<(), NodeRunError> { - let view_entity = graph.get_input_entity(Self::IN_VIEW)?; + let view_entity = graph.view_entity(); let pipeline_cache = world.resource::(); let fxaa_pipeline = world.resource::(); diff --git a/crates/bevy_core_pipeline/src/lib.rs b/crates/bevy_core_pipeline/src/lib.rs index 4faa1cd83f7f9..6fb89ac6093fc 100644 --- a/crates/bevy_core_pipeline/src/lib.rs +++ b/crates/bevy_core_pipeline/src/lib.rs @@ -34,7 +34,12 @@ use crate::{ }; use bevy_app::{App, Plugin}; use bevy_asset::load_internal_asset; -use bevy_render::{extract_resource::ExtractResourcePlugin, prelude::Shader}; +use bevy_ecs::world::FromWorld; +use bevy_render::{ + extract_resource::ExtractResourcePlugin, + prelude::Shader, + render_graph::{Node, RenderGraph}, +}; #[derive(Default)] pub struct CorePipelinePlugin; @@ -64,3 +69,20 @@ impl Plugin for CorePipelinePlugin { .add_plugin(FxaaPlugin); } } + +/// Utility function to add a [`Node`] to the [`RenderGraph`] +/// * Create the [`Node`] using the [`FromWorld`] implementation +/// * Add it to the graph +/// * Automatically add the required node edges based on the given ordering +pub fn add_node( + render_app: &mut App, + sub_graph_name: &'static str, + node_name: &'static str, + edges: &[&'static str], +) { + let node = T::from_world(&mut render_app.world); + let mut render_graph = render_app.world.resource_mut::(); + + let graph = render_graph.sub_graph_mut(sub_graph_name); + graph.add_node_with_edges(node_name, node, edges); +} diff --git a/crates/bevy_core_pipeline/src/msaa_writeback.rs b/crates/bevy_core_pipeline/src/msaa_writeback.rs index 2f8122d193d4c..9be7edd8d2fb8 100644 --- a/crates/bevy_core_pipeline/src/msaa_writeback.rs +++ b/crates/bevy_core_pipeline/src/msaa_writeback.rs @@ -3,7 +3,7 @@ use bevy_app::{App, Plugin}; use bevy_ecs::prelude::*; use bevy_render::{ camera::ExtractedCamera, - render_graph::{Node, NodeRunError, RenderGraph, RenderGraphContext, SlotInfo, SlotType}, + render_graph::{Node, NodeRunError, RenderGraph, RenderGraphContext}, renderer::RenderContext, view::{Msaa, ViewTarget}, RenderSet, @@ -25,7 +25,6 @@ impl Plugin for MsaaWritebackPlugin { let msaa_writeback_3d = MsaaWritebackNode::new(&mut render_app.world); let mut graph = render_app.world.resource_mut::(); if let Some(core_2d) = graph.get_sub_graph_mut(crate::core_2d::graph::NAME) { - let input_node = core_2d.input_node().id; core_2d.add_node( crate::core_2d::graph::node::MSAA_WRITEBACK, msaa_writeback_2d, @@ -34,16 +33,9 @@ impl Plugin for MsaaWritebackPlugin { crate::core_2d::graph::node::MSAA_WRITEBACK, crate::core_2d::graph::node::MAIN_PASS, ); - core_2d.add_slot_edge( - input_node, - crate::core_2d::graph::input::VIEW_ENTITY, - crate::core_2d::graph::node::MSAA_WRITEBACK, - MsaaWritebackNode::IN_VIEW, - ); } if let Some(core_3d) = graph.get_sub_graph_mut(crate::core_3d::graph::NAME) { - let input_node = core_3d.input_node().id; core_3d.add_node( crate::core_3d::graph::node::MSAA_WRITEBACK, msaa_writeback_3d, @@ -52,12 +44,6 @@ impl Plugin for MsaaWritebackPlugin { crate::core_3d::graph::node::MSAA_WRITEBACK, crate::core_3d::graph::node::MAIN_PASS, ); - core_3d.add_slot_edge( - input_node, - crate::core_3d::graph::input::VIEW_ENTITY, - crate::core_3d::graph::node::MSAA_WRITEBACK, - MsaaWritebackNode::IN_VIEW, - ); } } } @@ -77,9 +63,6 @@ impl MsaaWritebackNode { } impl Node for MsaaWritebackNode { - fn input(&self) -> Vec { - vec![SlotInfo::new(Self::IN_VIEW, SlotType::Entity)] - } fn update(&mut self, world: &mut World) { self.cameras.update_archetypes(world); } @@ -89,7 +72,7 @@ impl Node for MsaaWritebackNode { render_context: &mut RenderContext, world: &World, ) -> Result<(), NodeRunError> { - let view_entity = graph.get_input_entity(Self::IN_VIEW)?; + let view_entity = graph.view_entity(); if let Ok((target, blit_pipeline_id)) = self.cameras.get_manual(world, view_entity) { let blit_pipeline = world.resource::(); let pipeline_cache = world.resource::(); diff --git a/crates/bevy_core_pipeline/src/prepass/node.rs b/crates/bevy_core_pipeline/src/prepass/node.rs index 2687b925c00f5..86fd6726b560b 100644 --- a/crates/bevy_core_pipeline/src/prepass/node.rs +++ b/crates/bevy_core_pipeline/src/prepass/node.rs @@ -3,7 +3,7 @@ use bevy_ecs::query::QueryState; use bevy_render::{ camera::ExtractedCamera, prelude::Color, - render_graph::{Node, NodeRunError, RenderGraphContext, SlotInfo, SlotType}, + render_graph::{Node, NodeRunError, RenderGraphContext}, render_phase::RenderPhase, render_resource::{ LoadOp, Operations, RenderPassColorAttachment, RenderPassDepthStencilAttachment, @@ -33,21 +33,19 @@ pub struct PrepassNode { >, } -impl PrepassNode { - pub const IN_VIEW: &'static str = "view"; - - pub fn new(world: &mut World) -> Self { +impl FromWorld for PrepassNode { + fn from_world(world: &mut World) -> Self { Self { main_view_query: QueryState::new(world), } } } -impl Node for PrepassNode { - fn input(&self) -> Vec { - vec![SlotInfo::new(Self::IN_VIEW, SlotType::Entity)] - } +impl PrepassNode { + pub const IN_VIEW: &'static str = "view"; +} +impl Node for PrepassNode { fn update(&mut self, world: &mut World) { self.main_view_query.update_archetypes(world); } @@ -58,7 +56,7 @@ impl Node for PrepassNode { render_context: &mut RenderContext, world: &World, ) -> Result<(), NodeRunError> { - let view_entity = graph.get_input_entity(Self::IN_VIEW)?; + let view_entity = graph.view_entity(); let Ok(( camera, opaque_prepass_phase, diff --git a/crates/bevy_core_pipeline/src/tonemapping/node.rs b/crates/bevy_core_pipeline/src/tonemapping/node.rs index 357822fc73c6e..62397a81df8cf 100644 --- a/crates/bevy_core_pipeline/src/tonemapping/node.rs +++ b/crates/bevy_core_pipeline/src/tonemapping/node.rs @@ -6,7 +6,7 @@ use bevy_ecs::prelude::*; use bevy_ecs::query::QueryState; use bevy_render::{ render_asset::RenderAssets, - render_graph::{Node, NodeRunError, RenderGraphContext, SlotInfo, SlotType}, + render_graph::{Node, NodeRunError, RenderGraphContext}, render_resource::{ BindGroup, BindGroupDescriptor, BindGroupEntry, BindingResource, BufferId, LoadOp, Operations, PipelineCache, RenderPassColorAttachment, RenderPassDescriptor, @@ -33,10 +33,8 @@ pub struct TonemappingNode { last_tonemapping: Mutex>, } -impl TonemappingNode { - pub const IN_VIEW: &'static str = "view"; - - pub fn new(world: &mut World) -> Self { +impl FromWorld for TonemappingNode { + fn from_world(world: &mut World) -> Self { Self { query: QueryState::new(world), cached_bind_group: Mutex::new(None), @@ -45,11 +43,11 @@ impl TonemappingNode { } } -impl Node for TonemappingNode { - fn input(&self) -> Vec { - vec![SlotInfo::new(TonemappingNode::IN_VIEW, SlotType::Entity)] - } +impl TonemappingNode { + pub const IN_VIEW: &'static str = "view"; +} +impl Node for TonemappingNode { fn update(&mut self, world: &mut World) { self.query.update_archetypes(world); } @@ -60,7 +58,7 @@ impl Node for TonemappingNode { render_context: &mut RenderContext, world: &World, ) -> Result<(), NodeRunError> { - let view_entity = graph.get_input_entity(Self::IN_VIEW)?; + let view_entity = graph.view_entity(); let pipeline_cache = world.resource::(); let tonemapping_pipeline = world.resource::(); let gpu_images = world.get_resource::>().unwrap(); diff --git a/crates/bevy_core_pipeline/src/upscaling/node.rs b/crates/bevy_core_pipeline/src/upscaling/node.rs index 8e66f1eb07f06..79eb784df53f1 100644 --- a/crates/bevy_core_pipeline/src/upscaling/node.rs +++ b/crates/bevy_core_pipeline/src/upscaling/node.rs @@ -3,7 +3,7 @@ use bevy_ecs::prelude::*; use bevy_ecs::query::QueryState; use bevy_render::{ camera::{CameraOutputMode, ExtractedCamera}, - render_graph::{Node, NodeRunError, RenderGraphContext, SlotInfo, SlotType}, + render_graph::{Node, NodeRunError, RenderGraphContext}, render_resource::{ BindGroup, BindGroupDescriptor, BindGroupEntry, BindingResource, LoadOp, Operations, PipelineCache, RenderPassColorAttachment, RenderPassDescriptor, SamplerDescriptor, @@ -26,10 +26,8 @@ pub struct UpscalingNode { cached_texture_bind_group: Mutex>, } -impl UpscalingNode { - pub const IN_VIEW: &'static str = "view"; - - pub fn new(world: &mut World) -> Self { +impl FromWorld for UpscalingNode { + fn from_world(world: &mut World) -> Self { Self { query: QueryState::new(world), cached_texture_bind_group: Mutex::new(None), @@ -37,11 +35,11 @@ impl UpscalingNode { } } -impl Node for UpscalingNode { - fn input(&self) -> Vec { - vec![SlotInfo::new(UpscalingNode::IN_VIEW, SlotType::Entity)] - } +impl UpscalingNode { + pub const IN_VIEW: &'static str = "view"; +} +impl Node for UpscalingNode { fn update(&mut self, world: &mut World) { self.query.update_archetypes(world); } @@ -52,7 +50,7 @@ impl Node for UpscalingNode { render_context: &mut RenderContext, world: &World, ) -> Result<(), NodeRunError> { - let view_entity = graph.get_input_entity(Self::IN_VIEW)?; + let view_entity = graph.view_entity(); let pipeline_cache = world.get_resource::().unwrap(); let blit_pipeline = world.get_resource::().unwrap(); diff --git a/crates/bevy_pbr/src/lib.rs b/crates/bevy_pbr/src/lib.rs index 544984d5c1104..4a6cf637a9aae 100644 --- a/crates/bevy_pbr/src/lib.rs +++ b/crates/bevy_pbr/src/lib.rs @@ -287,11 +287,5 @@ impl Plugin for PbrPlugin { draw_3d_graph::node::SHADOW_PASS, bevy_core_pipeline::core_3d::graph::node::MAIN_PASS, ); - draw_3d_graph.add_slot_edge( - draw_3d_graph.input_node().id, - bevy_core_pipeline::core_3d::graph::input::VIEW_ENTITY, - draw_3d_graph::node::SHADOW_PASS, - ShadowPassNode::IN_VIEW, - ); } } diff --git a/crates/bevy_pbr/src/render/light.rs b/crates/bevy_pbr/src/render/light.rs index cdf1e7ab1548e..42e985fbed42d 100644 --- a/crates/bevy_pbr/src/render/light.rs +++ b/crates/bevy_pbr/src/render/light.rs @@ -15,7 +15,7 @@ use bevy_render::{ color::Color, mesh::Mesh, render_asset::RenderAssets, - render_graph::{Node, NodeRunError, RenderGraphContext, SlotInfo, SlotType}, + render_graph::{Node, NodeRunError, RenderGraphContext}, render_phase::{ CachedRenderPipelinePhaseItem, DrawFunctionId, DrawFunctions, PhaseItem, RenderPhase, }, @@ -1702,10 +1702,6 @@ impl ShadowPassNode { } impl Node for ShadowPassNode { - fn input(&self) -> Vec { - vec![SlotInfo::new(ShadowPassNode::IN_VIEW, SlotType::Entity)] - } - fn update(&mut self, world: &mut World) { self.main_view_query.update_archetypes(world); self.view_light_query.update_archetypes(world); @@ -1717,7 +1713,7 @@ impl Node for ShadowPassNode { render_context: &mut RenderContext, world: &World, ) -> Result<(), NodeRunError> { - let view_entity = graph.get_input_entity(Self::IN_VIEW)?; + let view_entity = graph.view_entity(); if let Ok(view_lights) = self.main_view_query.get_manual(world, view_entity) { for view_light_entity in view_lights.lights.iter().copied() { let (view_light, shadow_phase) = self diff --git a/crates/bevy_render/Cargo.toml b/crates/bevy_render/Cargo.toml index a42ef073b6231..6e262a35a2ebd 100644 --- a/crates/bevy_render/Cargo.toml +++ b/crates/bevy_render/Cargo.toml @@ -39,7 +39,9 @@ bevy_hierarchy = { path = "../bevy_hierarchy", version = "0.11.0-dev" } bevy_log = { path = "../bevy_log", version = "0.11.0-dev" } bevy_math = { path = "../bevy_math", version = "0.11.0-dev" } bevy_mikktspace = { path = "../bevy_mikktspace", version = "0.11.0-dev" } -bevy_reflect = { path = "../bevy_reflect", version = "0.11.0-dev", features = ["bevy"] } +bevy_reflect = { path = "../bevy_reflect", version = "0.11.0-dev", features = [ + "bevy", +] } bevy_render_macros = { path = "macros", version = "0.11.0-dev" } bevy_time = { path = "../bevy_time", version = "0.11.0-dev" } bevy_transform = { path = "../bevy_transform", version = "0.11.0-dev" } @@ -54,12 +56,17 @@ image = { version = "0.24", default-features = false } wgpu = { version = "0.15.0", features = ["spirv"] } wgpu-hal = "0.15.1" codespan-reporting = "0.11.0" -naga = { version = "0.11.0", features = ["glsl-in", "spv-in", "spv-out", "wgsl-in", "wgsl-out"] } +naga = { version = "0.11.0", features = [ + "glsl-in", + "spv-in", + "spv-out", + "wgsl-in", + "wgsl-out", +] } serde = { version = "1", features = ["derive"] } bitflags = "1.2.1" smallvec = { version = "1.6", features = ["union", "const_generics"] } once_cell = "1.4.1" # TODO: replace once_cell with std equivalent if/when this lands: https://github.com/rust-lang/rfcs/pull/2788 -downcast-rs = "1.2.0" thread_local = "1.1" thiserror = "1.0" futures-lite = "1.4.0" @@ -76,5 +83,7 @@ ruzstd = { version = "0.2.4", optional = true } basis-universal = { version = "0.2.0", optional = true } encase = { version = "0.5", features = ["glam"] } # For wgpu profiling using tracing. Use `RUST_LOG=info` to also capture the wgpu spans. -profiling = { version = "1", features = ["profile-with-tracing"], optional = true } +profiling = { version = "1", features = [ + "profile-with-tracing", +], optional = true } async-channel = "1.8" diff --git a/crates/bevy_render/src/camera/camera_driver_node.rs b/crates/bevy_render/src/camera/camera_driver_node.rs index dcee0cba45045..4bc07a1a45e14 100644 --- a/crates/bevy_render/src/camera/camera_driver_node.rs +++ b/crates/bevy_render/src/camera/camera_driver_node.rs @@ -1,6 +1,6 @@ use crate::{ camera::{ExtractedCamera, NormalizedRenderTarget, SortedCameras}, - render_graph::{Node, NodeRunError, RenderGraphContext, SlotValue}, + render_graph::{Node, NodeRunError, RenderGraphContext}, renderer::RenderContext, view::ExtractedWindows, }; @@ -24,6 +24,7 @@ impl Node for CameraDriverNode { fn update(&mut self, world: &mut World) { self.cameras.update_archetypes(world); } + fn run( &self, graph: &mut RenderGraphContext, @@ -37,10 +38,7 @@ impl Node for CameraDriverNode { if let Some(NormalizedRenderTarget::Window(window_ref)) = camera.target { camera_windows.insert(window_ref.entity()); } - graph.run_sub_graph( - camera.render_graph.clone(), - vec![SlotValue::Entity(sorted_camera.entity)], - )?; + graph.run_sub_graph(camera.render_graph.clone(), sorted_camera.entity)?; } } diff --git a/crates/bevy_render/src/render_graph/context.rs b/crates/bevy_render/src/render_graph/context.rs index 6a4ee01afa0ee..ea7d24e2a6d77 100644 --- a/crates/bevy_render/src/render_graph/context.rs +++ b/crates/bevy_render/src/render_graph/context.rs @@ -1,7 +1,3 @@ -use crate::{ - render_graph::{NodeState, RenderGraph, SlotInfos, SlotLabel, SlotType, SlotValue}, - render_resource::{Buffer, Sampler, TextureView}, -}; use bevy_ecs::entity::Entity; use std::borrow::Cow; use thiserror::Error; @@ -10,7 +6,7 @@ use thiserror::Error; /// with the specified `inputs` next. pub struct RunSubGraph { pub name: Cow<'static, str>, - pub inputs: Vec, + pub view_entity: Entity, } /// The context with all graph information required to run a [`Node`](super::Node). @@ -21,180 +17,40 @@ pub struct RunSubGraph { /// /// Sub graphs can be queued for running by adding a [`RunSubGraph`] command to the context. /// After the node has finished running the graph runner is responsible for executing the sub graphs. -pub struct RenderGraphContext<'a> { - graph: &'a RenderGraph, - node: &'a NodeState, - inputs: &'a [SlotValue], - outputs: &'a mut [Option], +pub struct RenderGraphContext { run_sub_graphs: Vec, + view_entity: Option, } -impl<'a> RenderGraphContext<'a> { +impl RenderGraphContext { /// Creates a new render graph context for the `node`. - pub fn new( - graph: &'a RenderGraph, - node: &'a NodeState, - inputs: &'a [SlotValue], - outputs: &'a mut [Option], - ) -> Self { + pub fn new() -> Self { Self { - graph, - node, - inputs, - outputs, run_sub_graphs: Vec::new(), + view_entity: None, } } - /// Returns the input slot values for the node. - #[inline] - pub fn inputs(&self) -> &[SlotValue] { - self.inputs + pub fn view_entity(&self) -> Entity { + self.view_entity.unwrap() } - /// Returns the [`SlotInfos`] of the inputs. - pub fn input_info(&self) -> &SlotInfos { - &self.node.input_slots + pub fn get_view_entity(&self) -> Option { + self.view_entity } - /// Returns the [`SlotInfos`] of the outputs. - pub fn output_info(&self) -> &SlotInfos { - &self.node.output_slots - } - - /// Retrieves the input slot value referenced by the `label`. - pub fn get_input(&self, label: impl Into) -> Result<&SlotValue, InputSlotError> { - let label = label.into(); - let index = self - .input_info() - .get_slot_index(label.clone()) - .ok_or(InputSlotError::InvalidSlot(label))?; - Ok(&self.inputs[index]) - } - - // TODO: should this return an Arc or a reference? - /// Retrieves the input slot value referenced by the `label` as a [`TextureView`]. - pub fn get_input_texture( - &self, - label: impl Into, - ) -> Result<&TextureView, InputSlotError> { - let label = label.into(); - match self.get_input(label.clone())? { - SlotValue::TextureView(value) => Ok(value), - value => Err(InputSlotError::MismatchedSlotType { - label, - actual: value.slot_type(), - expected: SlotType::TextureView, - }), - } - } - - /// Retrieves the input slot value referenced by the `label` as a [`Sampler`]. - pub fn get_input_sampler( - &self, - label: impl Into, - ) -> Result<&Sampler, InputSlotError> { - let label = label.into(); - match self.get_input(label.clone())? { - SlotValue::Sampler(value) => Ok(value), - value => Err(InputSlotError::MismatchedSlotType { - label, - actual: value.slot_type(), - expected: SlotType::Sampler, - }), - } - } - - /// Retrieves the input slot value referenced by the `label` as a [`Buffer`]. - pub fn get_input_buffer(&self, label: impl Into) -> Result<&Buffer, InputSlotError> { - let label = label.into(); - match self.get_input(label.clone())? { - SlotValue::Buffer(value) => Ok(value), - value => Err(InputSlotError::MismatchedSlotType { - label, - actual: value.slot_type(), - expected: SlotType::Buffer, - }), - } - } - - /// Retrieves the input slot value referenced by the `label` as an [`Entity`]. - pub fn get_input_entity(&self, label: impl Into) -> Result { - let label = label.into(); - match self.get_input(label.clone())? { - SlotValue::Entity(value) => Ok(*value), - value => Err(InputSlotError::MismatchedSlotType { - label, - actual: value.slot_type(), - expected: SlotType::Entity, - }), - } - } - - /// Sets the output slot value referenced by the `label`. - pub fn set_output( - &mut self, - label: impl Into, - value: impl Into, - ) -> Result<(), OutputSlotError> { - let label = label.into(); - let value = value.into(); - let slot_index = self - .output_info() - .get_slot_index(label.clone()) - .ok_or_else(|| OutputSlotError::InvalidSlot(label.clone()))?; - let slot = self - .output_info() - .get_slot(slot_index) - .expect("slot is valid"); - if value.slot_type() != slot.slot_type { - return Err(OutputSlotError::MismatchedSlotType { - label, - actual: slot.slot_type, - expected: value.slot_type(), - }); - } - self.outputs[slot_index] = Some(value); - Ok(()) + pub fn set_view_entity(&mut self, view_entity: Entity) { + self.view_entity = Some(view_entity); } /// Queues up a sub graph for execution after the node has finished running. pub fn run_sub_graph( &mut self, name: impl Into>, - inputs: Vec, + view_entity: Entity, ) -> Result<(), RunSubGraphError> { let name = name.into(); - let sub_graph = self - .graph - .get_sub_graph(&name) - .ok_or_else(|| RunSubGraphError::MissingSubGraph(name.clone()))?; - if let Some(input_node) = sub_graph.get_input_node() { - for (i, input_slot) in input_node.input_slots.iter().enumerate() { - if let Some(input_value) = inputs.get(i) { - if input_slot.slot_type != input_value.slot_type() { - return Err(RunSubGraphError::MismatchedInputSlotType { - graph_name: name, - slot_index: i, - actual: input_value.slot_type(), - expected: input_slot.slot_type, - label: input_slot.name.clone().into(), - }); - } - } else { - return Err(RunSubGraphError::MissingInput { - slot_index: i, - slot_name: input_slot.name.clone(), - graph_name: name, - }); - } - } - } else if !inputs.is_empty() { - return Err(RunSubGraphError::SubGraphHasNoInputs(name)); - } - - self.run_sub_graphs.push(RunSubGraph { name, inputs }); - + self.run_sub_graphs.push(RunSubGraph { name, view_entity }); Ok(()) } @@ -205,48 +61,14 @@ impl<'a> RenderGraphContext<'a> { } } +impl Default for RenderGraphContext { + fn default() -> Self { + Self::new() + } +} + #[derive(Error, Debug, Eq, PartialEq)] pub enum RunSubGraphError { #[error("attempted to run sub-graph `{0}`, but it does not exist")] MissingSubGraph(Cow<'static, str>), - #[error("attempted to pass inputs to sub-graph `{0}`, which has no input slots")] - SubGraphHasNoInputs(Cow<'static, str>), - #[error("sub graph (name: `{graph_name:?}`) could not be run because slot `{slot_name}` at index {slot_index} has no value")] - MissingInput { - slot_index: usize, - slot_name: Cow<'static, str>, - graph_name: Cow<'static, str>, - }, - #[error("attempted to use the wrong type for input slot")] - MismatchedInputSlotType { - graph_name: Cow<'static, str>, - slot_index: usize, - label: SlotLabel, - expected: SlotType, - actual: SlotType, - }, -} - -#[derive(Error, Debug, Eq, PartialEq)] -pub enum OutputSlotError { - #[error("output slot `{0:?}` does not exist")] - InvalidSlot(SlotLabel), - #[error("attempted to output a value of type `{actual}` to output slot `{label:?}`, which has type `{expected}`")] - MismatchedSlotType { - label: SlotLabel, - expected: SlotType, - actual: SlotType, - }, -} - -#[derive(Error, Debug, Eq, PartialEq)] -pub enum InputSlotError { - #[error("input slot `{0:?}` does not exist")] - InvalidSlot(SlotLabel), - #[error("attempted to retrieve a value of type `{actual}` from input slot `{label:?}`, which has type `{expected}`")] - MismatchedSlotType { - label: SlotLabel, - expected: SlotType, - actual: SlotType, - }, } diff --git a/crates/bevy_render/src/render_graph/edge.rs b/crates/bevy_render/src/render_graph/edge.rs index 88bfe24f9c913..8c80b74888948 100644 --- a/crates/bevy_render/src/render_graph/edge.rs +++ b/crates/bevy_render/src/render_graph/edge.rs @@ -14,37 +14,10 @@ use super::NodeId; /// with an input slot of the `input_node` to pass additional data along. /// For more information see [`SlotType`](super::SlotType). #[derive(Clone, Debug, Eq, PartialEq)] -pub enum Edge { - /// An edge describing to ordering of both nodes (`output_node` before `input_node`) - /// and connecting the output slot at the `output_index` of the output_node - /// with the slot at the `input_index` of the `input_node`. - SlotEdge { - input_node: NodeId, - input_index: usize, - output_node: NodeId, - output_index: usize, - }, +pub struct Edge { /// An edge describing to ordering of both nodes (`output_node` before `input_node`). - NodeEdge { - input_node: NodeId, - output_node: NodeId, - }, -} - -impl Edge { - /// Returns the id of the `input_node`. - pub fn get_input_node(&self) -> NodeId { - match self { - Edge::SlotEdge { input_node, .. } | Edge::NodeEdge { input_node, .. } => *input_node, - } - } - - /// Returns the id of the `output_node`. - pub fn get_output_node(&self) -> NodeId { - match self { - Edge::SlotEdge { output_node, .. } | Edge::NodeEdge { output_node, .. } => *output_node, - } - } + pub input_node: NodeId, + pub output_node: NodeId, } #[derive(PartialEq, Eq)] diff --git a/crates/bevy_render/src/render_graph/graph.rs b/crates/bevy_render/src/render_graph/graph.rs index e034f345f4e67..3966b2aabcf1b 100644 --- a/crates/bevy_render/src/render_graph/graph.rs +++ b/crates/bevy_render/src/render_graph/graph.rs @@ -1,7 +1,7 @@ use crate::{ render_graph::{ Edge, Node, NodeId, NodeLabel, NodeRunError, NodeState, RenderGraphContext, - RenderGraphError, SlotInfo, SlotLabel, + RenderGraphError, }, renderer::RenderContext, }; @@ -53,7 +53,6 @@ pub struct RenderGraph { nodes: HashMap, node_names: HashMap, NodeId>, sub_graphs: HashMap, RenderGraph>, - input_node: Option, } impl RenderGraph { @@ -71,39 +70,6 @@ impl RenderGraph { } } - /// Creates an [`GraphInputNode`] with the specified slots if not already present. - pub fn set_input(&mut self, inputs: Vec) -> NodeId { - assert!(self.input_node.is_none(), "Graph already has an input node"); - - let id = self.add_node("GraphInputNode", GraphInputNode { inputs }); - self.input_node = Some(id); - id - } - - /// Returns the [`NodeState`] of the input node of this graph. - /// - /// # See also - /// - /// - [`input_node`](Self::input_node) for an unchecked version. - #[inline] - pub fn get_input_node(&self) -> Option<&NodeState> { - self.input_node.and_then(|id| self.get_node_state(id).ok()) - } - - /// Returns the [`NodeState`] of the input node of this graph. - /// - /// # Panics - /// - /// Panics if there is no input node set. - /// - /// # See also - /// - /// - [`get_input_node`](Self::get_input_node) for a version which returns an [`Option`] instead. - #[inline] - pub fn input_node(&self) -> &NodeState { - self.get_input_node().unwrap() - } - /// Adds the `node` with the `name` to the graph. /// If the name is already present replaces it instead. pub fn add_node(&mut self, name: impl Into>, node: T) -> NodeId @@ -119,6 +85,26 @@ impl RenderGraph { id } + /// Adds the `node` with the `name` to the graph. + /// If the name is already present replaces it instead. + /// Also adds `node_edges` based on the order of the given `edges`. + pub fn add_node_with_edges( + &mut self, + name: impl Into>, + node: T, + edges: &[&'static str], + ) -> NodeId + where + T: Node, + { + let id = self.add_node(name, node); + for window in edges.windows(2) { + let [a, b] = window else { break; }; + self.add_node_edge(*a, *b); + } + id + } + /// Removes the `node` with the `name` from the graph. /// If the name is does not exist, nothing happens. pub fn remove_node( @@ -131,36 +117,15 @@ impl RenderGraph { // Remove all edges from other nodes to this one. Note that as we're removing this // node, we don't need to remove its input edges for input_edge in node_state.edges.input_edges().iter() { - match input_edge { - Edge::SlotEdge { output_node, .. } - | Edge::NodeEdge { - input_node: _, - output_node, - } => { - if let Ok(output_node) = self.get_node_state_mut(*output_node) { - output_node.edges.remove_output_edge(input_edge.clone())?; - } - } + if let Ok(output_node) = self.get_node_state_mut(input_edge.output_node) { + output_node.edges.remove_output_edge(input_edge.clone())?; } } // Remove all edges from this node to other nodes. Note that as we're removing this // node, we don't need to remove its output edges for output_edge in node_state.edges.output_edges().iter() { - match output_edge { - Edge::SlotEdge { - output_node: _, - output_index: _, - input_node, - input_index: _, - } - | Edge::NodeEdge { - output_node: _, - input_node, - } => { - if let Ok(input_node) = self.get_node_state_mut(*input_node) { - input_node.edges.remove_input_edge(output_edge.clone())?; - } - } + if let Ok(input_node) = self.get_node_state_mut(output_edge.input_node) { + input_node.edges.remove_input_edge(output_edge.clone())?; } } } @@ -206,140 +171,6 @@ impl RenderGraph { } } - /// Retrieves the [`Node`] referenced by the `label`. - pub fn get_node(&self, label: impl Into) -> Result<&T, RenderGraphError> - where - T: Node, - { - self.get_node_state(label).and_then(|n| n.node()) - } - - /// Retrieves the [`Node`] referenced by the `label` mutably. - pub fn get_node_mut( - &mut self, - label: impl Into, - ) -> Result<&mut T, RenderGraphError> - where - T: Node, - { - self.get_node_state_mut(label).and_then(|n| n.node_mut()) - } - - /// Adds the [`Edge::SlotEdge`] to the graph. This guarantees that the `output_node` - /// is run before the `input_node` and also connects the `output_slot` to the `input_slot`. - /// - /// Fails if any invalid [`NodeLabel`]s or [`SlotLabel`]s are given. - /// - /// # See also - /// - /// - [`add_slot_edge`](Self::add_slot_edge) for an infallible version. - pub fn try_add_slot_edge( - &mut self, - output_node: impl Into, - output_slot: impl Into, - input_node: impl Into, - input_slot: impl Into, - ) -> Result<(), RenderGraphError> { - let output_slot = output_slot.into(); - let input_slot = input_slot.into(); - let output_node_id = self.get_node_id(output_node)?; - let input_node_id = self.get_node_id(input_node)?; - - let output_index = self - .get_node_state(output_node_id)? - .output_slots - .get_slot_index(output_slot.clone()) - .ok_or(RenderGraphError::InvalidOutputNodeSlot(output_slot))?; - let input_index = self - .get_node_state(input_node_id)? - .input_slots - .get_slot_index(input_slot.clone()) - .ok_or(RenderGraphError::InvalidInputNodeSlot(input_slot))?; - - let edge = Edge::SlotEdge { - output_node: output_node_id, - output_index, - input_node: input_node_id, - input_index, - }; - - self.validate_edge(&edge, EdgeExistence::DoesNotExist)?; - - { - let output_node = self.get_node_state_mut(output_node_id)?; - output_node.edges.add_output_edge(edge.clone())?; - } - let input_node = self.get_node_state_mut(input_node_id)?; - input_node.edges.add_input_edge(edge)?; - - Ok(()) - } - - /// Adds the [`Edge::SlotEdge`] to the graph. This guarantees that the `output_node` - /// is run before the `input_node` and also connects the `output_slot` to the `input_slot`. - /// - /// # Panics - /// - /// Any invalid [`NodeLabel`]s or [`SlotLabel`]s are given. - /// - /// # See also - /// - /// - [`try_add_slot_edge`](Self::try_add_slot_edge) for a fallible version. - pub fn add_slot_edge( - &mut self, - output_node: impl Into, - output_slot: impl Into, - input_node: impl Into, - input_slot: impl Into, - ) { - self.try_add_slot_edge(output_node, output_slot, input_node, input_slot) - .unwrap(); - } - - /// Removes the [`Edge::SlotEdge`] from the graph. If any nodes or slots do not exist then - /// nothing happens. - pub fn remove_slot_edge( - &mut self, - output_node: impl Into, - output_slot: impl Into, - input_node: impl Into, - input_slot: impl Into, - ) -> Result<(), RenderGraphError> { - let output_slot = output_slot.into(); - let input_slot = input_slot.into(); - let output_node_id = self.get_node_id(output_node)?; - let input_node_id = self.get_node_id(input_node)?; - - let output_index = self - .get_node_state(output_node_id)? - .output_slots - .get_slot_index(output_slot.clone()) - .ok_or(RenderGraphError::InvalidOutputNodeSlot(output_slot))?; - let input_index = self - .get_node_state(input_node_id)? - .input_slots - .get_slot_index(input_slot.clone()) - .ok_or(RenderGraphError::InvalidInputNodeSlot(input_slot))?; - - let edge = Edge::SlotEdge { - output_node: output_node_id, - output_index, - input_node: input_node_id, - input_index, - }; - - self.validate_edge(&edge, EdgeExistence::Exists)?; - - { - let output_node = self.get_node_state_mut(output_node_id)?; - output_node.edges.remove_output_edge(edge.clone())?; - } - let input_node = self.get_node_state_mut(input_node_id)?; - input_node.edges.remove_input_edge(edge)?; - - Ok(()) - } - /// Adds the [`Edge::NodeEdge`] to the graph. This guarantees that the `output_node` /// is run before the `input_node`. /// @@ -356,7 +187,7 @@ impl RenderGraph { let output_node_id = self.get_node_id(output_node)?; let input_node_id = self.get_node_id(input_node)?; - let edge = Edge::NodeEdge { + let edge = Edge { output_node: output_node_id, input_node: input_node_id, }; @@ -373,7 +204,7 @@ impl RenderGraph { Ok(()) } - /// Adds the [`Edge::NodeEdge`] to the graph. This guarantees that the `output_node` + /// Adds the [`Edge`] to the graph. This guarantees that the `output_node` /// is run before the `input_node`. /// /// # Panics @@ -391,7 +222,7 @@ impl RenderGraph { self.try_add_node_edge(output_node, input_node).unwrap(); } - /// Removes the [`Edge::NodeEdge`] from the graph. If either node does not exist then nothing + /// Removes the [`Edge`] from the graph. If either node does not exist then nothing /// happens. pub fn remove_node_edge( &mut self, @@ -401,7 +232,7 @@ impl RenderGraph { let output_node_id = self.get_node_id(output_node)?; let input_node_id = self.get_node_id(input_node)?; - let edge = Edge::NodeEdge { + let edge = Edge { output_node: output_node_id, input_node: input_node_id, }; @@ -431,68 +262,13 @@ impl RenderGraph { return Err(RenderGraphError::EdgeAlreadyExists(edge.clone())); } - match *edge { - Edge::SlotEdge { - output_node, - output_index, - input_node, - input_index, - } => { - let output_node_state = self.get_node_state(output_node)?; - let input_node_state = self.get_node_state(input_node)?; - - let output_slot = output_node_state - .output_slots - .get_slot(output_index) - .ok_or(RenderGraphError::InvalidOutputNodeSlot(SlotLabel::Index( - output_index, - )))?; - let input_slot = input_node_state.input_slots.get_slot(input_index).ok_or( - RenderGraphError::InvalidInputNodeSlot(SlotLabel::Index(input_index)), - )?; - - if let Some(Edge::SlotEdge { - output_node: current_output_node, - .. - }) = input_node_state.edges.input_edges().iter().find(|e| { - if let Edge::SlotEdge { - input_index: current_input_index, - .. - } = e - { - input_index == *current_input_index - } else { - false - } - }) { - if should_exist == EdgeExistence::DoesNotExist { - return Err(RenderGraphError::NodeInputSlotAlreadyOccupied { - node: input_node, - input_slot: input_index, - occupied_by_node: *current_output_node, - }); - } - } - - if output_slot.slot_type != input_slot.slot_type { - return Err(RenderGraphError::MismatchedNodeSlots { - output_node, - output_slot: output_index, - input_node, - input_slot: input_index, - }); - } - } - Edge::NodeEdge { .. } => { /* nothing to validate here */ } - } - Ok(()) } /// Checks whether the `edge` already exists in the graph. pub fn has_edge(&self, edge: &Edge) -> bool { - let output_node_state = self.get_node_state(edge.get_output_node()); - let input_node_state = self.get_node_state(edge.get_input_node()); + let output_node_state = self.get_node_state(edge.output_node); + let input_node_state = self.get_node_state(edge.input_node); if let Ok(output_node_state) = output_node_state { if output_node_state.edges.output_edges().contains(edge) { if let Ok(input_node_state) = input_node_state { @@ -541,7 +317,7 @@ impl RenderGraph { .edges .input_edges() .iter() - .map(|edge| (edge, edge.get_output_node())) + .map(|edge| (edge, edge.output_node)) .map(move |(edge, output_node_id)| { (edge, self.get_node_state(output_node_id).unwrap()) })) @@ -558,7 +334,7 @@ impl RenderGraph { .edges .output_edges() .iter() - .map(|edge| (edge, edge.get_input_node())) + .map(|edge| (edge, edge.input_node)) .map(move |(edge, input_node_id)| (edge, self.get_node_state(input_node_id).unwrap()))) } @@ -583,14 +359,22 @@ impl RenderGraph { pub fn get_sub_graph_mut(&mut self, name: impl AsRef) -> Option<&mut RenderGraph> { self.sub_graphs.get_mut(name.as_ref()) } + + /// Retrieves the sub graph corresponding to the `name`. + pub fn sub_graph(&self, name: impl AsRef) -> &RenderGraph { + self.sub_graphs.get(name.as_ref()).unwrap() + } + + /// Retrieves the sub graph corresponding to the `name` mutably. + pub fn sub_graph_mut(&mut self, name: impl AsRef) -> &mut RenderGraph { + self.sub_graphs.get_mut(name.as_ref()).unwrap() + } } impl Debug for RenderGraph { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { for node in self.iter_nodes() { - writeln!(f, "{:?}", node.id)?; - writeln!(f, " in: {:?}", node.input_slots)?; - writeln!(f, " out: {:?}", node.output_slots)?; + writeln!(f, "{:?} {:?}", node.id, node.name)?; } Ok(()) @@ -599,29 +383,15 @@ impl Debug for RenderGraph { /// A [`Node`] which acts as an entry point for a [`RenderGraph`] with custom inputs. /// It has the same input and output slots and simply copies them over when run. -pub struct GraphInputNode { - inputs: Vec, -} +pub struct GraphInputNode {} impl Node for GraphInputNode { - fn input(&self) -> Vec { - self.inputs.clone() - } - - fn output(&self) -> Vec { - self.inputs.clone() - } - fn run( &self, - graph: &mut RenderGraphContext, + _graph: &mut RenderGraphContext, _render_context: &mut RenderContext, _world: &World, ) -> Result<(), NodeRunError> { - for i in 0..graph.inputs().len() { - let input = graph.inputs()[i].clone(); - graph.set_output(i, input)?; - } Ok(()) } } @@ -631,41 +401,16 @@ mod tests { use crate::{ render_graph::{ Edge, Node, NodeId, NodeRunError, RenderGraph, RenderGraphContext, RenderGraphError, - SlotInfo, SlotType, }, renderer::RenderContext, }; - use bevy_ecs::world::World; + use bevy_ecs::world::{FromWorld, World}; use bevy_utils::HashSet; #[derive(Debug)] - struct TestNode { - inputs: Vec, - outputs: Vec, - } - - impl TestNode { - pub fn new(inputs: usize, outputs: usize) -> Self { - TestNode { - inputs: (0..inputs) - .map(|i| SlotInfo::new(format!("in_{i}"), SlotType::TextureView)) - .collect(), - outputs: (0..outputs) - .map(|i| SlotInfo::new(format!("out_{i}"), SlotType::TextureView)) - .collect(), - } - } - } + struct TestNode {} impl Node for TestNode { - fn input(&self) -> Vec { - self.inputs.clone() - } - - fn output(&self) -> Vec { - self.outputs.clone() - } - fn run( &self, _: &mut RenderGraphContext, @@ -676,33 +421,33 @@ mod tests { } } + fn input_nodes(name: &'static str, graph: &RenderGraph) -> HashSet { + graph + .iter_node_inputs(name) + .unwrap() + .map(|(_edge, node)| node.id) + .collect::>() + } + + fn output_nodes(name: &'static str, graph: &RenderGraph) -> HashSet { + graph + .iter_node_outputs(name) + .unwrap() + .map(|(_edge, node)| node.id) + .collect::>() + } + #[test] fn test_graph_edges() { let mut graph = RenderGraph::default(); - let a_id = graph.add_node("A", TestNode::new(0, 1)); - let b_id = graph.add_node("B", TestNode::new(0, 1)); - let c_id = graph.add_node("C", TestNode::new(1, 1)); - let d_id = graph.add_node("D", TestNode::new(1, 0)); + let a_id = graph.add_node("A", TestNode {}); + let b_id = graph.add_node("B", TestNode {}); + let c_id = graph.add_node("C", TestNode {}); + let d_id = graph.add_node("D", TestNode {}); - graph.add_slot_edge("A", "out_0", "C", "in_0"); + graph.add_node_edge("A", "C"); graph.add_node_edge("B", "C"); - graph.add_slot_edge("C", 0, "D", 0); - - fn input_nodes(name: &'static str, graph: &RenderGraph) -> HashSet { - graph - .iter_node_inputs(name) - .unwrap() - .map(|(_edge, node)| node.id) - .collect::>() - } - - fn output_nodes(name: &'static str, graph: &RenderGraph) -> HashSet { - graph - .iter_node_outputs(name) - .unwrap() - .map(|(_edge, node)| node.id) - .collect::>() - } + graph.add_node_edge("C", "D"); assert!(input_nodes("A", &graph).is_empty(), "A has no inputs"); assert!( @@ -733,74 +478,64 @@ mod tests { } #[test] - fn test_get_node_typed() { - struct MyNode { - value: usize, - } + fn test_edge_already_exists() { + let mut graph = RenderGraph::default(); + + graph.add_node("A", TestNode {}); + graph.add_node("B", TestNode {}); + + graph.add_node_edge("A", "B"); + assert_eq!( + graph.try_add_node_edge("A", "B"), + Err(RenderGraphError::EdgeAlreadyExists(Edge { + output_node: graph.get_node_id("A").unwrap(), + input_node: graph.get_node_id("B").unwrap(), + })), + "Adding to a duplicate edge should return an error" + ); + } - impl Node for MyNode { + #[test] + fn test_add_node_with_edges() { + struct SimpleNode; + impl Node for SimpleNode { fn run( &self, - _: &mut RenderGraphContext, - _: &mut RenderContext, - _: &World, + _graph: &mut RenderGraphContext, + _render_context: &mut RenderContext, + _world: &World, ) -> Result<(), NodeRunError> { Ok(()) } } + impl FromWorld for SimpleNode { + fn from_world(_world: &mut World) -> Self { + Self + } + } let mut graph = RenderGraph::default(); + let a_id = graph.add_node("A", SimpleNode); + let c_id = graph.add_node("C", SimpleNode); - graph.add_node("A", MyNode { value: 42 }); + // A and C need to exist first + let b_id = graph.add_node_with_edges("B", SimpleNode, &["A", "B", "C"]); - let node: &MyNode = graph.get_node("A").unwrap(); - assert_eq!(node.value, 42, "node value matches"); - - let result: Result<&TestNode, RenderGraphError> = graph.get_node("A"); - assert_eq!( - result.unwrap_err(), - RenderGraphError::WrongNodeType, - "expect a wrong node type error" + assert!( + output_nodes("A", &graph) == HashSet::from_iter(vec![b_id]), + "A -> B" ); - } - - #[test] - fn test_slot_already_occupied() { - let mut graph = RenderGraph::default(); - - graph.add_node("A", TestNode::new(0, 1)); - graph.add_node("B", TestNode::new(0, 1)); - graph.add_node("C", TestNode::new(1, 1)); - - graph.add_slot_edge("A", 0, "C", 0); - assert_eq!( - graph.try_add_slot_edge("B", 0, "C", 0), - Err(RenderGraphError::NodeInputSlotAlreadyOccupied { - node: graph.get_node_id("C").unwrap(), - input_slot: 0, - occupied_by_node: graph.get_node_id("A").unwrap(), - }), - "Adding to a slot that is already occupied should return an error" + assert!( + input_nodes("B", &graph) == HashSet::from_iter(vec![a_id]), + "B -> C" ); - } - - #[test] - fn test_edge_already_exists() { - let mut graph = RenderGraph::default(); - - graph.add_node("A", TestNode::new(0, 1)); - graph.add_node("B", TestNode::new(1, 0)); - - graph.add_slot_edge("A", 0, "B", 0); - assert_eq!( - graph.try_add_slot_edge("A", 0, "B", 0), - Err(RenderGraphError::EdgeAlreadyExists(Edge::SlotEdge { - output_node: graph.get_node_id("A").unwrap(), - output_index: 0, - input_node: graph.get_node_id("B").unwrap(), - input_index: 0, - })), - "Adding to a duplicate edge should return an error" + assert!( + output_nodes("B", &graph) == HashSet::from_iter(vec![c_id]), + "B -> C" + ); + assert!( + input_nodes("C", &graph) == HashSet::from_iter(vec![b_id]), + "B -> C" ); } } diff --git a/crates/bevy_render/src/render_graph/mod.rs b/crates/bevy_render/src/render_graph/mod.rs index 9f5a5b7c712fb..964086a6d5ff9 100644 --- a/crates/bevy_render/src/render_graph/mod.rs +++ b/crates/bevy_render/src/render_graph/mod.rs @@ -2,13 +2,11 @@ mod context; mod edge; mod graph; mod node; -mod node_slot; pub use context::*; pub use edge::*; pub use graph::*; pub use node::*; -pub use node_slot::*; use thiserror::Error; @@ -16,31 +14,10 @@ use thiserror::Error; pub enum RenderGraphError { #[error("node does not exist")] InvalidNode(NodeLabel), - #[error("output node slot does not exist")] - InvalidOutputNodeSlot(SlotLabel), - #[error("input node slot does not exist")] - InvalidInputNodeSlot(SlotLabel), #[error("node does not match the given type")] WrongNodeType, - #[error("attempted to connect a node output slot to an incompatible input node slot")] - MismatchedNodeSlots { - output_node: NodeId, - output_slot: usize, - input_node: NodeId, - input_slot: usize, - }, #[error("attempted to add an edge that already exists")] EdgeAlreadyExists(Edge), #[error("attempted to remove an edge that does not exist")] EdgeDoesNotExist(Edge), - #[error("node has an unconnected input slot")] - UnconnectedNodeInputSlot { node: NodeId, input_slot: usize }, - #[error("node has an unconnected output slot")] - UnconnectedNodeOutputSlot { node: NodeId, output_slot: usize }, - #[error("node input slot already occupied")] - NodeInputSlotAlreadyOccupied { - node: NodeId, - input_slot: usize, - occupied_by_node: NodeId, - }, } diff --git a/crates/bevy_render/src/render_graph/node.rs b/crates/bevy_render/src/render_graph/node.rs index 11db4aca83ac8..f50a539252e19 100644 --- a/crates/bevy_render/src/render_graph/node.rs +++ b/crates/bevy_render/src/render_graph/node.rs @@ -1,13 +1,9 @@ use crate::{ define_atomic_id, - render_graph::{ - Edge, InputSlotError, OutputSlotError, RenderGraphContext, RenderGraphError, - RunSubGraphError, SlotInfo, SlotInfos, SlotType, SlotValue, - }, + render_graph::{Edge, RenderGraphContext, RenderGraphError, RunSubGraphError}, renderer::RenderContext, }; use bevy_ecs::world::World; -use downcast_rs::{impl_downcast, Downcast}; use std::{borrow::Cow, fmt::Debug}; use thiserror::Error; @@ -25,19 +21,7 @@ define_atomic_id!(NodeId); /// A node can produce outputs used as dependencies by other nodes. /// Those inputs and outputs are called slots and are the default way of passing render data /// inside the graph. For more information see [`SlotType`](super::SlotType). -pub trait Node: Downcast + Send + Sync + 'static { - /// Specifies the required input slots for this node. - /// They will then be available during the run method inside the [`RenderGraphContext`]. - fn input(&self) -> Vec { - Vec::new() - } - - /// Specifies the produced output slots for this node. - /// They can then be passed one inside [`RenderGraphContext`] during the run method. - fn output(&self) -> Vec { - Vec::new() - } - +pub trait Node: Send + Sync + 'static { /// Updates internal node state using the current render [`World`] prior to the run method. fn update(&mut self, _world: &mut World) {} @@ -52,14 +36,8 @@ pub trait Node: Downcast + Send + Sync + 'static { ) -> Result<(), NodeRunError>; } -impl_downcast!(Node); - #[derive(Error, Debug, Eq, PartialEq)] pub enum NodeRunError { - #[error("encountered an input slot error")] - InputSlotError(#[from] InputSlotError), - #[error("encountered an output slot error")] - OutputSlotError(#[from] OutputSlotError), #[error("encountered an error when running a sub-graph")] RunSubGraphError(#[from] RunSubGraphError), } @@ -138,42 +116,6 @@ impl Edges { pub fn has_output_edge(&self, edge: &Edge) -> bool { self.output_edges.contains(edge) } - - /// Searches the `input_edges` for a [`Edge::SlotEdge`], - /// which `input_index` matches the `index`; - pub fn get_input_slot_edge(&self, index: usize) -> Result<&Edge, RenderGraphError> { - self.input_edges - .iter() - .find(|e| { - if let Edge::SlotEdge { input_index, .. } = e { - *input_index == index - } else { - false - } - }) - .ok_or(RenderGraphError::UnconnectedNodeInputSlot { - input_slot: index, - node: self.id, - }) - } - - /// Searches the `output_edges` for a [`Edge::SlotEdge`], - /// which `output_index` matches the `index`; - pub fn get_output_slot_edge(&self, index: usize) -> Result<&Edge, RenderGraphError> { - self.output_edges - .iter() - .find(|e| { - if let Edge::SlotEdge { output_index, .. } = e { - *output_index == index - } else { - false - } - }) - .ok_or(RenderGraphError::UnconnectedNodeOutputSlot { - output_slot: index, - node: self.id, - }) - } } /// The internal representation of a [`Node`], with all data required @@ -186,8 +128,6 @@ pub struct NodeState { /// The name of the type that implements [`Node`]. pub type_name: &'static str, pub node: Box, - pub input_slots: SlotInfos, - pub output_slots: SlotInfos, pub edges: Edges, } @@ -207,8 +147,6 @@ impl NodeState { NodeState { id, name: None, - input_slots: node.input().into(), - output_slots: node.output().into(), node: Box::new(node), type_name: std::any::type_name::(), edges: Edges { @@ -218,44 +156,6 @@ impl NodeState { }, } } - - /// Retrieves the [`Node`]. - pub fn node(&self) -> Result<&T, RenderGraphError> - where - T: Node, - { - self.node - .downcast_ref::() - .ok_or(RenderGraphError::WrongNodeType) - } - - /// Retrieves the [`Node`] mutably. - pub fn node_mut(&mut self) -> Result<&mut T, RenderGraphError> - where - T: Node, - { - self.node - .downcast_mut::() - .ok_or(RenderGraphError::WrongNodeType) - } - - /// Validates that each input slot corresponds to an input edge. - pub fn validate_input_slots(&self) -> Result<(), RenderGraphError> { - for i in 0..self.input_slots.len() { - self.edges.get_input_slot_edge(i)?; - } - - Ok(()) - } - - /// Validates that each output slot corresponds to an output edge. - pub fn validate_output_slots(&self) -> Result<(), RenderGraphError> { - for i in 0..self.output_slots.len() { - self.edges.get_output_slot_edge(i)?; - } - - Ok(()) - } } /// A [`NodeLabel`] is used to reference a [`NodeState`] by either its name or [`NodeId`] @@ -322,20 +222,13 @@ impl RunGraphOnViewNode { } impl Node for RunGraphOnViewNode { - fn input(&self) -> Vec { - vec![SlotInfo::new(Self::IN_VIEW, SlotType::Entity)] - } fn run( &self, graph: &mut RenderGraphContext, _render_context: &mut RenderContext, _world: &World, ) -> Result<(), NodeRunError> { - let view_entity = graph.get_input_entity(Self::IN_VIEW)?; - graph.run_sub_graph( - self.graph_name.clone(), - vec![SlotValue::Entity(view_entity)], - )?; + graph.run_sub_graph(self.graph_name.clone(), graph.view_entity())?; Ok(()) } } diff --git a/crates/bevy_render/src/render_graph/node_slot.rs b/crates/bevy_render/src/render_graph/node_slot.rs deleted file mode 100644 index fd5dc388d3812..0000000000000 --- a/crates/bevy_render/src/render_graph/node_slot.rs +++ /dev/null @@ -1,199 +0,0 @@ -use bevy_ecs::entity::Entity; -use std::{borrow::Cow, fmt}; - -use crate::render_resource::{Buffer, Sampler, TextureView}; - -/// A value passed between render [`Nodes`](super::Node). -/// Corresponds to the [`SlotType`] specified in the [`RenderGraph`](super::RenderGraph). -/// -/// Slots can have four different types of values: -/// [`Buffer`], [`TextureView`], [`Sampler`] and [`Entity`]. -/// -/// These values do not contain the actual render data, but only the ids to retrieve them. -#[derive(Debug, Clone)] -pub enum SlotValue { - /// A GPU-accessible [`Buffer`]. - Buffer(Buffer), - /// A [`TextureView`] describes a texture used in a pipeline. - TextureView(TextureView), - /// A texture [`Sampler`] defines how a pipeline will sample from a [`TextureView`]. - Sampler(Sampler), - /// An entity from the ECS. - Entity(Entity), -} - -impl SlotValue { - /// Returns the [`SlotType`] of this value. - pub fn slot_type(&self) -> SlotType { - match self { - SlotValue::Buffer(_) => SlotType::Buffer, - SlotValue::TextureView(_) => SlotType::TextureView, - SlotValue::Sampler(_) => SlotType::Sampler, - SlotValue::Entity(_) => SlotType::Entity, - } - } -} - -impl From for SlotValue { - fn from(value: Buffer) -> Self { - SlotValue::Buffer(value) - } -} - -impl From for SlotValue { - fn from(value: TextureView) -> Self { - SlotValue::TextureView(value) - } -} - -impl From for SlotValue { - fn from(value: Sampler) -> Self { - SlotValue::Sampler(value) - } -} - -impl From for SlotValue { - fn from(value: Entity) -> Self { - SlotValue::Entity(value) - } -} - -/// Describes the render resources created (output) or used (input) by -/// the render [`Nodes`](super::Node). -/// -/// This should not be confused with [`SlotValue`], which actually contains the passed data. -#[derive(Debug, Copy, Clone, Eq, PartialEq)] -pub enum SlotType { - /// A GPU-accessible [`Buffer`]. - Buffer, - /// A [`TextureView`] describes a texture used in a pipeline. - TextureView, - /// A texture [`Sampler`] defines how a pipeline will sample from a [`TextureView`]. - Sampler, - /// An entity from the ECS. - Entity, -} - -impl fmt::Display for SlotType { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let s = match self { - SlotType::Buffer => "Buffer", - SlotType::TextureView => "TextureView", - SlotType::Sampler => "Sampler", - SlotType::Entity => "Entity", - }; - - f.write_str(s) - } -} - -/// A [`SlotLabel`] is used to reference a slot by either its name or index -/// inside the [`RenderGraph`](super::RenderGraph). -#[derive(Debug, Clone, Eq, PartialEq)] -pub enum SlotLabel { - Index(usize), - Name(Cow<'static, str>), -} - -impl From<&SlotLabel> for SlotLabel { - fn from(value: &SlotLabel) -> Self { - value.clone() - } -} - -impl From for SlotLabel { - fn from(value: String) -> Self { - SlotLabel::Name(value.into()) - } -} - -impl From<&'static str> for SlotLabel { - fn from(value: &'static str) -> Self { - SlotLabel::Name(value.into()) - } -} - -impl From> for SlotLabel { - fn from(value: Cow<'static, str>) -> Self { - SlotLabel::Name(value) - } -} - -impl From for SlotLabel { - fn from(value: usize) -> Self { - SlotLabel::Index(value) - } -} - -/// The internal representation of a slot, which specifies its [`SlotType`] and name. -#[derive(Clone, Debug)] -pub struct SlotInfo { - pub name: Cow<'static, str>, - pub slot_type: SlotType, -} - -impl SlotInfo { - pub fn new(name: impl Into>, slot_type: SlotType) -> Self { - SlotInfo { - name: name.into(), - slot_type, - } - } -} - -/// A collection of input or output [`SlotInfos`](SlotInfo) for -/// a [`NodeState`](super::NodeState). -#[derive(Default, Debug)] -pub struct SlotInfos { - slots: Vec, -} - -impl> From for SlotInfos { - fn from(slots: T) -> Self { - SlotInfos { - slots: slots.into_iter().collect(), - } - } -} - -impl SlotInfos { - /// Returns the count of slots. - #[inline] - pub fn len(&self) -> usize { - self.slots.len() - } - - /// Returns true if there are no slots. - #[inline] - pub fn is_empty(&self) -> bool { - self.slots.is_empty() - } - - /// Retrieves the [`SlotInfo`] for the provided label. - pub fn get_slot(&self, label: impl Into) -> Option<&SlotInfo> { - let label = label.into(); - let index = self.get_slot_index(label)?; - self.slots.get(index) - } - - /// Retrieves the [`SlotInfo`] for the provided label mutably. - pub fn get_slot_mut(&mut self, label: impl Into) -> Option<&mut SlotInfo> { - let label = label.into(); - let index = self.get_slot_index(label)?; - self.slots.get_mut(index) - } - - /// Retrieves the index (inside input or output slots) of the slot for the provided label. - pub fn get_slot_index(&self, label: impl Into) -> Option { - let label = label.into(); - match label { - SlotLabel::Index(index) => Some(index), - SlotLabel::Name(ref name) => self.slots.iter().position(|s| s.name == *name), - } - } - - /// Returns an iterator over the slot infos. - pub fn iter(&self) -> impl Iterator { - self.slots.iter() - } -} diff --git a/crates/bevy_render/src/renderer/graph_runner.rs b/crates/bevy_render/src/renderer/graph_runner.rs index e11c61f6ca064..5005de748f75a 100644 --- a/crates/bevy_render/src/renderer/graph_runner.rs +++ b/crates/bevy_render/src/renderer/graph_runner.rs @@ -1,18 +1,14 @@ -use bevy_ecs::world::World; +use bevy_ecs::{prelude::Entity, world::World}; #[cfg(feature = "trace")] use bevy_utils::tracing::info_span; -use bevy_utils::HashMap; -use smallvec::{smallvec, SmallVec}; +use bevy_utils::HashSet; #[cfg(feature = "trace")] use std::ops::Deref; use std::{borrow::Cow, collections::VecDeque}; use thiserror::Error; use crate::{ - render_graph::{ - Edge, NodeId, NodeRunError, NodeState, RenderGraph, RenderGraphContext, SlotLabel, - SlotType, SlotValue, - }, + render_graph::{NodeId, NodeRunError, NodeState, RenderGraph, RenderGraphContext}, renderer::{RenderContext, RenderDevice}, }; @@ -22,33 +18,6 @@ pub(crate) struct RenderGraphRunner; pub enum RenderGraphRunnerError { #[error(transparent)] NodeRunError(#[from] NodeRunError), - #[error("node output slot not set (index {slot_index}, name {slot_name})")] - EmptyNodeOutputSlot { - type_name: &'static str, - slot_index: usize, - slot_name: Cow<'static, str>, - }, - #[error("graph (name: '{graph_name:?}') could not be run because slot '{slot_name}' at index {slot_index} has no value")] - MissingInput { - slot_index: usize, - slot_name: Cow<'static, str>, - graph_name: Option>, - }, - #[error("attempted to use the wrong type for input slot")] - MismatchedInputSlotType { - slot_index: usize, - label: SlotLabel, - expected: SlotType, - actual: SlotType, - }, - #[error( - "node (name: '{node_name:?}') has {slot_count} input slots, but was provided {value_count} values" - )] - MismatchedInputCount { - node_name: Option>, - slot_count: usize, - value_count: usize, - }, } impl RenderGraphRunner { @@ -59,7 +28,7 @@ impl RenderGraphRunner { world: &World, ) -> Result<(), RenderGraphRunnerError> { let mut render_context = RenderContext::new(render_device); - Self::run_graph(graph, None, &mut render_context, world, &[])?; + Self::run_graph(graph, None, &mut render_context, world, None)?; { #[cfg(feature = "trace")] let _span = info_span!("submit_graph_commands").entered(); @@ -70,12 +39,13 @@ impl RenderGraphRunner { fn run_graph( graph: &RenderGraph, + #[allow(unused)] // This is only used in when trace is enabled graph_name: Option>, render_context: &mut RenderContext, world: &World, - inputs: &[SlotValue], + view_entity: Option, ) -> Result<(), RenderGraphRunnerError> { - let mut node_outputs: HashMap> = HashMap::default(); + let mut node_completed: HashSet = HashSet::default(); #[cfg(feature = "trace")] let span = if let Some(name) = &graph_name { info_span!("run_graph", name = name.deref()) @@ -85,100 +55,34 @@ impl RenderGraphRunner { #[cfg(feature = "trace")] let _guard = span.enter(); - // Queue up nodes without inputs, which can be run immediately - let mut node_queue: VecDeque<&NodeState> = graph - .iter_nodes() - .filter(|node| node.input_slots.is_empty()) - .collect(); - - // pass inputs into the graph - if let Some(input_node) = graph.get_input_node() { - let mut input_values: SmallVec<[SlotValue; 4]> = SmallVec::new(); - for (i, input_slot) in input_node.input_slots.iter().enumerate() { - if let Some(input_value) = inputs.get(i) { - if input_slot.slot_type != input_value.slot_type() { - return Err(RenderGraphRunnerError::MismatchedInputSlotType { - slot_index: i, - actual: input_value.slot_type(), - expected: input_slot.slot_type, - label: input_slot.name.clone().into(), - }); - } - input_values.push(input_value.clone()); - } else { - return Err(RenderGraphRunnerError::MissingInput { - slot_index: i, - slot_name: input_slot.name.clone(), - graph_name, - }); - } - } - - node_outputs.insert(input_node.id, input_values); - - for (_, node_state) in graph.iter_node_outputs(input_node.id).expect("node exists") { - node_queue.push_front(node_state); - } - } + // Queue up nodes + let mut node_queue: VecDeque<&NodeState> = graph.iter_nodes().collect(); 'handle_node: while let Some(node_state) = node_queue.pop_back() { // skip nodes that are already processed - if node_outputs.contains_key(&node_state.id) { + if node_completed.contains(&node_state.id) { continue; } - let mut slot_indices_and_inputs: SmallVec<[(usize, SlotValue); 4]> = SmallVec::new(); // check if all dependencies have finished running - for (edge, input_node) in graph + for (_edge, input_node) in graph .iter_node_inputs(node_state.id) .expect("node is in graph") { - match edge { - Edge::SlotEdge { - output_index, - input_index, - .. - } => { - if let Some(outputs) = node_outputs.get(&input_node.id) { - slot_indices_and_inputs - .push((*input_index, outputs[*output_index].clone())); - } else { - node_queue.push_front(node_state); - continue 'handle_node; - } - } - Edge::NodeEdge { .. } => { - if !node_outputs.contains_key(&input_node.id) { - node_queue.push_front(node_state); - continue 'handle_node; - } - } + if !node_completed.contains(&input_node.id) { + node_queue.push_front(node_state); + continue 'handle_node; } } - // construct final sorted input list - slot_indices_and_inputs.sort_by_key(|(index, _)| *index); - let inputs: SmallVec<[SlotValue; 4]> = slot_indices_and_inputs - .into_iter() - .map(|(_, value)| value) - .collect(); - - if inputs.len() != node_state.input_slots.len() { - return Err(RenderGraphRunnerError::MismatchedInputCount { - node_name: node_state.name.clone(), - slot_count: node_state.input_slots.len(), - value_count: inputs.len(), - }); - } - - let mut outputs: SmallVec<[Option; 4]> = - smallvec![None; node_state.output_slots.len()]; { - let mut context = RenderGraphContext::new(graph, node_state, &inputs, &mut outputs); + let mut context = RenderGraphContext::new(); + if let Some(view_entity) = view_entity { + context.set_view_entity(view_entity); + } { #[cfg(feature = "trace")] let _span = info_span!("node", name = node_state.type_name).entered(); - node_state.node.run(&mut context, render_context, world)?; } @@ -191,25 +95,12 @@ impl RenderGraphRunner { Some(run_sub_graph.name), render_context, world, - &run_sub_graph.inputs, + Some(run_sub_graph.view_entity), )?; } } - let mut values: SmallVec<[SlotValue; 4]> = SmallVec::new(); - for (i, output) in outputs.into_iter().enumerate() { - if let Some(value) = output { - values.push(value); - } else { - let empty_slot = node_state.output_slots.get_slot(i).unwrap(); - return Err(RenderGraphRunnerError::EmptyNodeOutputSlot { - type_name: node_state.type_name, - slot_index: i, - slot_name: empty_slot.name.clone(), - }); - } - } - node_outputs.insert(node_state.id, values); + node_completed.insert(node_state.id); for (_, node_state) in graph.iter_node_outputs(node_state.id).expect("node exists") { node_queue.push_front(node_state); diff --git a/crates/bevy_ui/src/render/mod.rs b/crates/bevy_ui/src/render/mod.rs index 029fd5d8a1659..e32905e4b9098 100644 --- a/crates/bevy_ui/src/render/mod.rs +++ b/crates/bevy_ui/src/render/mod.rs @@ -19,7 +19,7 @@ use bevy_render::{ camera::Camera, color::Color, render_asset::RenderAssets, - render_graph::{RenderGraph, RunGraphOnViewNode, SlotInfo, SlotType}, + render_graph::{RenderGraph, RunGraphOnViewNode}, render_phase::{sort_phase_system, AddRenderCommand, DrawFunctions, RenderPhase}, render_resource::*, renderer::{RenderDevice, RenderQueue}, @@ -107,12 +107,6 @@ pub fn build_ui_render(app: &mut App) { bevy_core_pipeline::core_2d::graph::node::MAIN_PASS, draw_ui_graph::node::UI_PASS, ); - graph_2d.add_slot_edge( - graph_2d.input_node().id, - bevy_core_pipeline::core_2d::graph::input::VIEW_ENTITY, - draw_ui_graph::node::UI_PASS, - RunGraphOnViewNode::IN_VIEW, - ); graph_2d.add_node_edge( bevy_core_pipeline::core_2d::graph::node::END_MAIN_PASS_POST_PROCESSING, draw_ui_graph::node::UI_PASS, @@ -141,12 +135,6 @@ pub fn build_ui_render(app: &mut App) { draw_ui_graph::node::UI_PASS, bevy_core_pipeline::core_3d::graph::node::UPSCALING, ); - graph_3d.add_slot_edge( - graph_3d.input_node().id, - bevy_core_pipeline::core_3d::graph::input::VIEW_ENTITY, - draw_ui_graph::node::UI_PASS, - RunGraphOnViewNode::IN_VIEW, - ); } } @@ -154,16 +142,6 @@ fn get_ui_graph(render_app: &mut App) -> RenderGraph { let ui_pass_node = UiPassNode::new(&mut render_app.world); let mut ui_graph = RenderGraph::default(); ui_graph.add_node(draw_ui_graph::node::UI_PASS, ui_pass_node); - let input_node_id = ui_graph.set_input(vec![SlotInfo::new( - draw_ui_graph::input::VIEW_ENTITY, - SlotType::Entity, - )]); - ui_graph.add_slot_edge( - input_node_id, - draw_ui_graph::input::VIEW_ENTITY, - draw_ui_graph::node::UI_PASS, - UiPassNode::IN_VIEW, - ); ui_graph } diff --git a/crates/bevy_ui/src/render/render_pass.rs b/crates/bevy_ui/src/render/render_pass.rs index 63ff0a47f989c..db06cfc1e626f 100644 --- a/crates/bevy_ui/src/render/render_pass.rs +++ b/crates/bevy_ui/src/render/render_pass.rs @@ -37,10 +37,6 @@ impl UiPassNode { } impl Node for UiPassNode { - fn input(&self) -> Vec { - vec![SlotInfo::new(UiPassNode::IN_VIEW, SlotType::Entity)] - } - fn update(&mut self, world: &mut World) { self.ui_view_query.update_archetypes(world); self.default_camera_view_query.update_archetypes(world); @@ -52,7 +48,7 @@ impl Node for UiPassNode { render_context: &mut RenderContext, world: &World, ) -> Result<(), NodeRunError> { - let input_view_entity = graph.get_input_entity(Self::IN_VIEW)?; + let input_view_entity = graph.view_entity(); let Ok((transparent_phase, target, camera_ui)) = self.ui_view_query.get_manual(world, input_view_entity) diff --git a/examples/README.md b/examples/README.md index 8fd8f7c29b8d0..6b30aad3ede15 100644 --- a/examples/README.md +++ b/examples/README.md @@ -278,7 +278,8 @@ Example | Description [Material - GLSL](../examples/shader/shader_material_glsl.rs) | A shader that uses the GLSL shading language [Material - Screenspace Texture](../examples/shader/shader_material_screenspace_texture.rs) | A shader that samples a texture with view-independent UV coordinates [Material Prepass](../examples/shader/shader_prepass.rs) | A shader that uses the various textures generated by the prepass -[Post Processing](../examples/shader/post_processing.rs) | A custom post processing effect, using two cameras, with one reusing the render texture of the first one +[Post Processing - Custom Render Pass](../examples/shader/post_process_pass.rs) | A custom post processing effect, using a custom render pass that runs after the main pass +[Post Processing - Render To Texture](../examples/shader/post_processing.rs) | A custom post processing effect, using two cameras, with one reusing the render texture of the first one [Shader Defs](../examples/shader/shader_defs.rs) | A shader that uses "shaders defs" (a bevy tool to selectively toggle parts of a shader) [Texture Binding Array (Bindless Textures)](../examples/shader/texture_binding_array.rs) | A shader that shows how to bind and sample multiple textures as a binding array (a.k.a. bindless textures). diff --git a/examples/shader/post_process_pass.rs b/examples/shader/post_process_pass.rs new file mode 100644 index 0000000000000..19c5fbffc0328 --- /dev/null +++ b/examples/shader/post_process_pass.rs @@ -0,0 +1,395 @@ +//! This example shows how to create a custom render pass that runs after the main pass +//! and reads the texture generated by the main pass. +//! +//! The example shader is a very simple implementation of chromatic aberration. +//! +//! This is a fairly low level example and assumes some familiarity with rendering concepts and wgpu. + +use bevy::{ + core_pipeline::{ + add_node, clear_color::ClearColorConfig, core_3d, + fullscreen_vertex_shader::fullscreen_shader_vertex_state, + }, + prelude::*, + render::{ + extract_component::{ + ComponentUniforms, ExtractComponent, ExtractComponentPlugin, UniformComponentPlugin, + }, + render_graph::{Node, NodeRunError, RenderGraphContext}, + render_resource::{ + BindGroupDescriptor, BindGroupEntry, BindGroupLayout, BindGroupLayoutDescriptor, + BindGroupLayoutEntry, BindingResource, BindingType, CachedRenderPipelineId, + ColorTargetState, ColorWrites, FragmentState, MultisampleState, Operations, + PipelineCache, PrimitiveState, RenderPassColorAttachment, RenderPassDescriptor, + RenderPipelineDescriptor, Sampler, SamplerBindingType, SamplerDescriptor, ShaderStages, + ShaderType, TextureFormat, TextureSampleType, TextureViewDimension, + }, + renderer::{RenderContext, RenderDevice}, + texture::BevyDefault, + view::{ExtractedView, ViewTarget}, + RenderApp, + }, +}; + +fn main() { + App::new() + .add_plugins(DefaultPlugins.set(AssetPlugin { + // Hot reloading the shader works correctly + watch_for_changes: true, + ..default() + })) + .add_plugin(PostProcessPlugin) + .add_startup_system(setup) + .add_system(rotate) + .add_system(update_settings) + .run(); +} + +/// It is generally encouraged to set up post processing effects as a plugin +struct PostProcessPlugin; +impl Plugin for PostProcessPlugin { + fn build(&self, app: &mut App) { + app + // The settings will be a component that lives in the main world but will + // be extracted to the render world every frame. + // This makes it possible to control the effect from the main world. + // This plugin will take care of extracting it automatically. + // It's important to derive [`ExtractComponent`] on [`PostProcessingSettings`] + // for this plugin to work correctly. + .add_plugin(ExtractComponentPlugin::::default()) + // The settings will also be the data used in the shader. + // This plugin will prepare the component for the GPU by creating a uniform buffer + // and writing the data to that buffer every frame. + .add_plugin(UniformComponentPlugin::::default()); + + // We need to get the render app from the main app + let Ok(render_app) = app.get_sub_app_mut(RenderApp) else { + return; + }; + + // Initialize the pipeline + render_app.init_resource::(); + + // Bevy's renderer uses a render graph which is a collection of nodes in a directed acyclic graph. + // It currently runs on each view/camera and executes each node in the specified order. + // It will make sure that any node that needs a dependency from another node + // only runs when that dependency is done. + // + // Each node can execute arbitrary work, but it generally runs at least one render pass. + // A node only has access to the render world, so if you need data from the main world + // you need to extract it manually or with the plugin like above. + + // Utility function to add a Node to the RenderGraph + // * It creates the Node using the FromWorld implementation + // * Adds it to the graph + // * Automatically adds the required node edges and slot edges based on the given ordering + add_node::( + render_app, + core_3d::graph::NAME, + PostProcessNode::NAME, + &[ + core_3d::graph::node::MAIN_PASS, + PostProcessNode::NAME, + core_3d::graph::node::END_MAIN_PASS_POST_PROCESSING, + ], + ); + } +} + +/// The post process node used for the render graph +struct PostProcessNode { + // The node needs a query to gather data from the ECS in order to do its rendering, + // but it's not a normal system so we need to define it manually. + query: QueryState<&'static ViewTarget, With>, +} + +impl PostProcessNode { + pub const NAME: &str = "post_process"; +} + +impl FromWorld for PostProcessNode { + fn from_world(world: &mut World) -> Self { + Self { + query: QueryState::new(world), + } + } +} + +impl Node for PostProcessNode { + // This will run every frame before the run() method + // The important difference is that `self` is `mut` here + fn update(&mut self, world: &mut World) { + // Since this is not a system we need to update the query manually. + // This is mostly boilerplate. There are plans to remove this in the future. + // For now, you can just copy it. + self.query.update_archetypes(world); + } + + // Runs the node logic + // This is where you encode draw commands. + // + // This will run on every view on which the graph is running. If you don't want your effect to run on every camera, + // you'll need to make sure you have a marker component to identify which camera(s) should run the effect. + fn run( + &self, + graph_context: &mut RenderGraphContext, + render_context: &mut RenderContext, + world: &World, + ) -> Result<(), NodeRunError> { + // Get the entity of the view for the render graph where this node is running + let view_entity = graph_context.view_entity(); + + // We get the data we need from the world based on the view entity passed to the node. + // The data is the query that was defined earlier in the [`PostProcessNode`] + let Ok(view_target) = self.query.get_manual(world, view_entity) else { + return Ok(()); + }; + + // Get the pipeline resource that contains the global data we need to create the render pipeline + let post_process_pipeline = world.resource::(); + + // The pipeline cache is a cache of all previously created pipelines. + // It is required to avoid creating a new pipeline each frame, which is expensive due to shader compilation. + let pipeline_cache = world.resource::(); + + // Get the pipeline from the cache + let Some(pipeline) = pipeline_cache.get_render_pipeline(post_process_pipeline.pipeline_id) else { + return Ok(()); + }; + + // Get the settings uniform binding + let settings_uniforms = world.resource::>(); + let Some(settings_binding) = settings_uniforms.uniforms().binding() else { + return Ok(()); + }; + + // This will start a new "post process write", obtaining two texture + // views from the view target - a `source` and a `destination`. + // `source` is the "current" main texture and you _must_ write into + // `destination` because calling `post_process_write()` on the + // [`ViewTarget`] will internally flip the [`ViewTarget`]'s main + // texture to the `destination` texture. Failing to do so will cause + // the current main texture information to be lost. + let post_process = view_target.post_process_write(); + + // The bind_group gets created each frame. + // + // Normally, you would create a bind_group in the Queue stage, + // but this doesn't work with the post_process_write(). + // The reason it doesn't work is because each post_process_write will + // alternate the source/destination. + // The only way to have the correct source/destination for the bind_group + // is to make sure you get it during the node execution. + let bind_group = render_context + .render_device() + .create_bind_group(&BindGroupDescriptor { + label: Some("post_process_bind_group"), + layout: &post_process_pipeline.layout, + // It's important for this to match the BindGroupLayout defined in + // the PostProcessPipeline + entries: &[ + BindGroupEntry { + binding: 0, + // Make sure to use the source view + resource: BindingResource::TextureView(post_process.source), + }, + BindGroupEntry { + binding: 1, + // Use the sampler created for the pipeline + resource: BindingResource::Sampler(&post_process_pipeline.sampler), + }, + BindGroupEntry { + binding: 2, + // Set the settings binding + resource: settings_binding.clone(), + }, + ], + }); + + // Begin the render pass + let mut render_pass = render_context.begin_tracked_render_pass(RenderPassDescriptor { + label: Some("post_process_pass"), + color_attachments: &[Some(RenderPassColorAttachment { + // We need to specify the post process destination view here + // to make sure we write to the appropriate texture. + view: post_process.destination, + resolve_target: None, + ops: Operations::default(), + })], + depth_stencil_attachment: None, + }); + + // This is mostly just wgpu boilerplate for drawing a fullscreen triangle, + // using the pipeline/bind_group created above + render_pass.set_render_pipeline(pipeline); + render_pass.set_bind_group(0, &bind_group, &[]); + render_pass.draw(0..3, 0..1); + + Ok(()) + } +} + +// This contains global data used by the render pipeline. This will be created once on startup. +#[derive(Resource)] +struct PostProcessPipeline { + layout: BindGroupLayout, + sampler: Sampler, + pipeline_id: CachedRenderPipelineId, +} + +impl FromWorld for PostProcessPipeline { + fn from_world(world: &mut World) -> Self { + let render_device = world.resource::(); + + // We need to define the bind group layout used for our pipeline + let layout = render_device.create_bind_group_layout(&BindGroupLayoutDescriptor { + label: Some("post_process_bind_group_layout"), + entries: &[ + // The screen texture + BindGroupLayoutEntry { + binding: 0, + visibility: ShaderStages::FRAGMENT, + ty: BindingType::Texture { + sample_type: TextureSampleType::Float { filterable: true }, + view_dimension: TextureViewDimension::D2, + multisampled: false, + }, + count: None, + }, + // The sampler that will be used to sample the screen texture + BindGroupLayoutEntry { + binding: 1, + visibility: ShaderStages::FRAGMENT, + ty: BindingType::Sampler(SamplerBindingType::Filtering), + count: None, + }, + // The settings uniform that will control the effect + BindGroupLayoutEntry { + binding: 2, + visibility: ShaderStages::FRAGMENT, + ty: BindingType::Buffer { + ty: bevy::render::render_resource::BufferBindingType::Uniform, + has_dynamic_offset: false, + min_binding_size: None, + }, + count: None, + }, + ], + }); + + // We can create the sampler here since it won't change at runtime and doesn't depend on the view + let sampler = render_device.create_sampler(&SamplerDescriptor::default()); + + // Get the shader handle + let shader = world + .resource::() + .load("shaders/post_process_pass.wgsl"); + + let pipeline_id = world + .resource_mut::() + // This will add the pipeline to the cache and queue it's creation + .queue_render_pipeline(RenderPipelineDescriptor { + label: Some("post_process_pipeline".into()), + layout: vec![layout.clone()], + // This will setup a fullscreen triangle for the vertex state + vertex: fullscreen_shader_vertex_state(), + fragment: Some(FragmentState { + shader, + shader_defs: vec![], + // Make sure this matches the entry point of your shader. + // It can be anything as long as it matches here and in the shader. + entry_point: "fragment".into(), + targets: vec![Some(ColorTargetState { + format: TextureFormat::bevy_default(), + blend: None, + write_mask: ColorWrites::ALL, + })], + }), + // All of the following property are not important for this effect so just use the default values. + // This struct doesn't have the Default trai implemented because not all field can have a default value. + primitive: PrimitiveState::default(), + depth_stencil: None, + multisample: MultisampleState::default(), + push_constant_ranges: vec![], + }); + + Self { + layout, + sampler, + pipeline_id, + } + } +} + +// This is the component that will get passed to the shader +#[derive(Component, Default, Clone, Copy, ExtractComponent, ShaderType)] +struct PostProcessSettings { + intensity: f32, +} + +/// Set up a simple 3D scene +fn setup( + mut commands: Commands, + mut meshes: ResMut>, + mut materials: ResMut>, +) { + // camera + commands.spawn(( + Camera3dBundle { + transform: Transform::from_translation(Vec3::new(0.0, 0.0, 5.0)) + .looking_at(Vec3::default(), Vec3::Y), + camera_3d: Camera3d { + clear_color: ClearColorConfig::Custom(Color::WHITE), + ..default() + }, + ..default() + }, + // Add the setting to the camera. + // This component is also used to determine on which camera to + // run the post processing effect. + PostProcessSettings { intensity: 0.02 }, + )); + + // cube + commands.spawn(( + PbrBundle { + mesh: meshes.add(Mesh::from(shape::Cube { size: 1.0 })), + material: materials.add(Color::rgb(0.8, 0.7, 0.6).into()), + transform: Transform::from_xyz(0.0, 0.5, 0.0), + ..default() + }, + Rotates, + )); + // light + commands.spawn(PointLightBundle { + transform: Transform::from_translation(Vec3::new(0.0, 0.0, 10.0)), + ..default() + }); +} + +#[derive(Component)] +struct Rotates; + +/// Rotates any entity around the x and y axis +fn rotate(time: Res