diff --git a/deno_webgpu/command_encoder.rs b/deno_webgpu/command_encoder.rs index 20dfe0db09..b82fba92ea 100644 --- a/deno_webgpu/command_encoder.rs +++ b/deno_webgpu/command_encoder.rs @@ -254,13 +254,14 @@ pub fn op_webgpu_command_encoder_begin_compute_pass( None }; + let instance = state.borrow::(); + let command_encoder = &command_encoder_resource.1; let descriptor = wgpu_core::command::ComputePassDescriptor { label: Some(label), timestamp_writes: timestamp_writes.as_ref(), }; - let compute_pass = - wgpu_core::command::ComputePass::new(command_encoder_resource.1, &descriptor); + let compute_pass = gfx_select!(command_encoder => instance.command_encoder_create_compute_pass_dyn(*command_encoder, &descriptor)); let rid = state .resource_table diff --git a/deno_webgpu/compute_pass.rs b/deno_webgpu/compute_pass.rs index 2cdea2c8f2..fb499e7e06 100644 --- a/deno_webgpu/compute_pass.rs +++ b/deno_webgpu/compute_pass.rs @@ -10,7 +10,9 @@ use std::cell::RefCell; use super::error::WebGpuResult; -pub(crate) struct WebGpuComputePass(pub(crate) RefCell); +pub(crate) struct WebGpuComputePass( + pub(crate) RefCell>, +); impl Resource for WebGpuComputePass { fn name(&self) -> Cow { "webGPUComputePass".into() @@ -31,10 +33,10 @@ pub fn op_webgpu_compute_pass_set_pipeline( .resource_table .get::(compute_pass_rid)?; - wgpu_core::command::compute_commands::wgpu_compute_pass_set_pipeline( - &mut compute_pass_resource.0.borrow_mut(), - compute_pipeline_resource.1, - ); + compute_pass_resource + .0 + .borrow_mut() + .set_pipeline(state.borrow(), compute_pipeline_resource.1)?; Ok(WebGpuResult::empty()) } @@ -52,12 +54,10 @@ pub fn op_webgpu_compute_pass_dispatch_workgroups( .resource_table .get::(compute_pass_rid)?; - wgpu_core::command::compute_commands::wgpu_compute_pass_dispatch_workgroups( - &mut compute_pass_resource.0.borrow_mut(), - x, - y, - z, - ); + compute_pass_resource + .0 + .borrow_mut() + .dispatch_workgroups(state.borrow(), x, y, z); Ok(WebGpuResult::empty()) } @@ -77,11 +77,10 @@ pub fn op_webgpu_compute_pass_dispatch_workgroups_indirect( .resource_table .get::(compute_pass_rid)?; - wgpu_core::command::compute_commands::wgpu_compute_pass_dispatch_workgroups_indirect( - &mut compute_pass_resource.0.borrow_mut(), - buffer_resource.1, - indirect_offset, - ); + compute_pass_resource + .0 + .borrow_mut() + .dispatch_workgroups_indirect(state.borrow(), buffer_resource.1, indirect_offset)?; Ok(WebGpuResult::empty()) } @@ -90,24 +89,15 @@ pub fn op_webgpu_compute_pass_dispatch_workgroups_indirect( #[serde] pub fn op_webgpu_compute_pass_end( state: &mut OpState, - #[smi] command_encoder_rid: ResourceId, #[smi] compute_pass_rid: ResourceId, ) -> Result { - let command_encoder_resource = - state - .resource_table - .get::(command_encoder_rid)?; - let command_encoder = command_encoder_resource.1; let compute_pass_resource = state .resource_table .take::(compute_pass_rid)?; - let compute_pass = &compute_pass_resource.0.borrow(); - let instance = state.borrow::(); - gfx_ok!(command_encoder => instance.command_encoder_run_compute_pass( - command_encoder, - compute_pass - )) + compute_pass_resource.0.borrow_mut().run(state.borrow())?; + + Ok(WebGpuResult::empty()) } #[op2] @@ -137,12 +127,12 @@ pub fn op_webgpu_compute_pass_set_bind_group( let dynamic_offsets_data: &[u32] = &dynamic_offsets_data[start..start + len]; - wgpu_core::command::compute_commands::wgpu_compute_pass_set_bind_group( - &mut compute_pass_resource.0.borrow_mut(), + compute_pass_resource.0.borrow_mut().set_bind_group( + state.borrow(), index, bind_group_resource.1, dynamic_offsets_data, - ); + )?; Ok(WebGpuResult::empty()) } @@ -158,8 +148,8 @@ pub fn op_webgpu_compute_pass_push_debug_group( .resource_table .get::(compute_pass_rid)?; - wgpu_core::command::compute_commands::wgpu_compute_pass_push_debug_group( - &mut compute_pass_resource.0.borrow_mut(), + compute_pass_resource.0.borrow_mut().push_debug_group( + state.borrow(), group_label, 0, // wgpu#975 ); @@ -177,9 +167,10 @@ pub fn op_webgpu_compute_pass_pop_debug_group( .resource_table .get::(compute_pass_rid)?; - wgpu_core::command::compute_commands::wgpu_compute_pass_pop_debug_group( - &mut compute_pass_resource.0.borrow_mut(), - ); + compute_pass_resource + .0 + .borrow_mut() + .pop_debug_group(state.borrow()); Ok(WebGpuResult::empty()) } @@ -195,8 +186,8 @@ pub fn op_webgpu_compute_pass_insert_debug_marker( .resource_table .get::(compute_pass_rid)?; - wgpu_core::command::compute_commands::wgpu_compute_pass_insert_debug_marker( - &mut compute_pass_resource.0.borrow_mut(), + compute_pass_resource.0.borrow_mut().insert_debug_marker( + state.borrow(), marker_label, 0, // wgpu#975 ); diff --git a/tests/tests/compute_pass_resource_ownership.rs b/tests/tests/compute_pass_resource_ownership.rs index 863c03b304..8a07b76201 100644 --- a/tests/tests/compute_pass_resource_ownership.rs +++ b/tests/tests/compute_pass_resource_ownership.rs @@ -115,6 +115,10 @@ async fn compute_pass_resource_ownership(ctx: TestingContext) { cpass.set_bind_group(0, &bind_group, &[]); cpass.dispatch_workgroups_indirect(&indirect_buffer, 0); + // TODO: + // write_timestamp + // begin_pipeline_statistics_query + // Now drop all resources we set. Then do a device poll to make sure the resources are really not dropped too early, no matter what. drop(pipeline); drop(bind_group); diff --git a/wgpu-core/src/command/compute.rs b/wgpu-core/src/command/compute.rs index e6c4b7b978..6eadaa8b00 100644 --- a/wgpu-core/src/command/compute.rs +++ b/wgpu-core/src/command/compute.rs @@ -271,6 +271,14 @@ impl Global { ComputePass::new(parent_id, desc) } + pub fn command_encoder_create_compute_pass_dyn( + &self, + parent_id: id::CommandEncoderId, + desc: &ComputePassDescriptor, + ) -> Box { + Box::new(ComputePass::::new(parent_id, desc)) + } + pub fn command_encoder_run_compute_pass( &self, pass: &ComputePass, diff --git a/wgpu-core/src/command/dyn_compute_pass.rs b/wgpu-core/src/command/dyn_compute_pass.rs new file mode 100644 index 0000000000..b7ffea3d42 --- /dev/null +++ b/wgpu-core/src/command/dyn_compute_pass.rs @@ -0,0 +1,136 @@ +use wgt::WasmNotSendSync; + +use crate::{global, hal_api::HalApi, id}; + +use super::{ComputePass, ComputePassError}; + +/// Trait for type erasing ComputePass. +// TODO(#5124): wgpu-core's ComputePass trait should not be hal type dependent. +// Practically speaking this allows us merge gfx_select with type erasure: +// The alternative would be to introduce ComputePassId which then first needs to be looked up and then dispatch via gfx_select. +pub trait DynComputePass: std::fmt::Debug + WasmNotSendSync { + fn run(&mut self, context: &global::Global) -> Result<(), ComputePassError>; + fn set_bind_group( + &mut self, + context: &global::Global, + index: u32, + bind_group_id: id::BindGroupId, + offsets: &[wgt::DynamicOffset], + ) -> Result<(), ComputePassError>; + fn set_pipeline( + &mut self, + context: &global::Global, + pipeline_id: id::ComputePipelineId, + ) -> Result<(), ComputePassError>; + fn set_push_constant(&mut self, context: &global::Global, offset: u32, data: &[u8]); + fn dispatch_workgroups( + &mut self, + context: &global::Global, + groups_x: u32, + groups_y: u32, + groups_z: u32, + ); + fn dispatch_workgroups_indirect( + &mut self, + context: &global::Global, + buffer_id: id::BufferId, + offset: wgt::BufferAddress, + ) -> Result<(), ComputePassError>; + fn push_debug_group(&mut self, context: &global::Global, label: &str, color: u32); + fn pop_debug_group(&mut self, context: &global::Global); + fn insert_debug_marker(&mut self, context: &global::Global, label: &str, color: u32); + fn write_timestamp( + &mut self, + context: &global::Global, + query_set_id: id::QuerySetId, + query_index: u32, + ) -> Result<(), ComputePassError>; + fn begin_pipeline_statistics_query( + &mut self, + context: &global::Global, + query_set_id: id::QuerySetId, + query_index: u32, + ) -> Result<(), ComputePassError>; + fn end_pipeline_statistics_query(&mut self, context: &global::Global); +} + +impl DynComputePass for ComputePass { + fn run(&mut self, context: &global::Global) -> Result<(), ComputePassError> { + context.command_encoder_run_compute_pass(self) + } + + fn set_bind_group( + &mut self, + context: &global::Global, + index: u32, + bind_group_id: id::BindGroupId, + offsets: &[wgt::DynamicOffset], + ) -> Result<(), ComputePassError> { + context.compute_pass_set_bind_group(self, index, bind_group_id, offsets) + } + + fn set_pipeline( + &mut self, + context: &global::Global, + pipeline_id: id::ComputePipelineId, + ) -> Result<(), ComputePassError> { + context.compute_pass_set_pipeline(self, pipeline_id) + } + + fn set_push_constant(&mut self, context: &global::Global, offset: u32, data: &[u8]) { + context.compute_pass_set_push_constant(self, offset, data) + } + + fn dispatch_workgroups( + &mut self, + context: &global::Global, + groups_x: u32, + groups_y: u32, + groups_z: u32, + ) { + context.compute_pass_dispatch_workgroups(self, groups_x, groups_y, groups_z) + } + + fn dispatch_workgroups_indirect( + &mut self, + context: &global::Global, + buffer_id: id::BufferId, + offset: wgt::BufferAddress, + ) -> Result<(), ComputePassError> { + context.compute_pass_dispatch_workgroups_indirect(self, buffer_id, offset) + } + + fn push_debug_group(&mut self, context: &global::Global, label: &str, color: u32) { + context.compute_pass_push_debug_group(self, label, color) + } + + fn pop_debug_group(&mut self, context: &global::Global) { + context.compute_pass_pop_debug_group(self) + } + + fn insert_debug_marker(&mut self, context: &global::Global, label: &str, color: u32) { + context.compute_pass_insert_debug_marker(self, label, color) + } + + fn write_timestamp( + &mut self, + context: &global::Global, + query_set_id: id::QuerySetId, + query_index: u32, + ) -> Result<(), ComputePassError> { + context.compute_pass_write_timestamp(self, query_set_id, query_index) + } + + fn begin_pipeline_statistics_query( + &mut self, + context: &global::Global, + query_set_id: id::QuerySetId, + query_index: u32, + ) -> Result<(), ComputePassError> { + context.compute_pass_begin_pipeline_statistics_query(self, query_set_id, query_index) + } + + fn end_pipeline_statistics_query(&mut self, context: &global::Global) { + context.compute_pass_end_pipeline_statistics_query(self) + } +} diff --git a/wgpu-core/src/command/mod.rs b/wgpu-core/src/command/mod.rs index 97eb455c7a..252fc89f7c 100644 --- a/wgpu-core/src/command/mod.rs +++ b/wgpu-core/src/command/mod.rs @@ -5,6 +5,7 @@ mod clear; mod compute; mod compute_command; mod draw; +mod dyn_compute_pass; mod memory_init; mod query; mod render; @@ -14,8 +15,8 @@ use std::sync::Arc; pub(crate) use self::clear::clear_texture; pub use self::{ - bundle::*, clear::ClearError, compute::*, compute_command::ComputeCommand, draw::*, query::*, - render::*, transfer::*, + bundle::*, clear::ClearError, compute::*, compute_command::ComputeCommand, draw::*, + dyn_compute_pass::DynComputePass, query::*, render::*, transfer::*, }; pub(crate) use allocator::CommandAllocator; diff --git a/wgpu-core/src/lib.rs b/wgpu-core/src/lib.rs index 19d12d3390..4779910055 100644 --- a/wgpu-core/src/lib.rs +++ b/wgpu-core/src/lib.rs @@ -301,17 +301,6 @@ macro_rules! gfx_select { other => panic!("Unexpected backend {:?}", other), } }; - - ($id:expr => $method:ident $params:tt) => { - match $id.backend() { - wgt::Backend::Vulkan => $crate::gfx_if_vulkan!($method::<$crate::api::Vulkan> $params), - wgt::Backend::Metal => $crate::gfx_if_metal!($method::<$crate::api::Metal> $params), - wgt::Backend::Dx12 => $crate::gfx_if_dx12!($method::<$crate::api::Dx12> $params), - wgt::Backend::Gl => $crate::gfx_if_gles!($method::<$crate::api::Gles> $params), - wgt::Backend::Empty => $crate::gfx_if_empty!($method::<$crate::api::Empty> $params), - other => panic!("Unexpected backend {:?}", other), - } - }; } #[cfg(feature = "api_log_info")] diff --git a/wgpu/src/backend/wgpu_core.rs b/wgpu/src/backend/wgpu_core.rs index 55469ba859..efedee66cd 100644 --- a/wgpu/src/backend/wgpu_core.rs +++ b/wgpu/src/backend/wgpu_core.rs @@ -25,7 +25,6 @@ use std::{ use wgc::{ command::{bundle_ffi::*, render_commands::*}, gfx_select, - hal_api::HalApi, id::CommandEncoderId, }; use wgc::{device::DeviceLostClosure, id::TextureViewId}; @@ -473,141 +472,10 @@ impl Queue { #[derive(Debug)] pub struct ComputePass { - pass: Box, + pass: Box, error_sink: ErrorSink, } -/// Trait for type erasing ComputePass. -// TODO(#5124): wgpu-core's ComputePass trait should not be hal type dependent. -// Practically speaking this allows us merge gfx_select with type erasure: -// The alternative would be to introduce ComputePassId which then first needs to be looked up and then dispatch via gfx_select. -trait DynComputePass: std::fmt::Debug + WasmNotSendSync { - fn run(&mut self, context: &wgc::global::Global) -> Result<(), wgc::command::ComputePassError>; - fn set_bind_group( - &mut self, - context: &wgc::global::Global, - index: u32, - bind_group_id: wgc::id::BindGroupId, - offsets: &[wgt::DynamicOffset], - ) -> Result<(), wgc::command::ComputePassError>; - fn set_pipeline( - &mut self, - context: &wgc::global::Global, - pipeline_id: wgc::id::ComputePipelineId, - ) -> Result<(), wgc::command::ComputePassError>; - fn set_push_constant(&mut self, context: &wgc::global::Global, offset: u32, data: &[u8]); - fn dispatch_workgroups( - &mut self, - context: &wgc::global::Global, - groups_x: u32, - groups_y: u32, - groups_z: u32, - ); - fn dispatch_workgroups_indirect( - &mut self, - context: &wgc::global::Global, - buffer_id: wgc::id::BufferId, - offset: wgt::BufferAddress, - ) -> Result<(), wgc::command::ComputePassError>; - fn push_debug_group(&mut self, context: &wgc::global::Global, label: &str, color: u32); - fn pop_debug_group(&mut self, context: &wgc::global::Global); - fn insert_debug_marker(&mut self, context: &wgc::global::Global, label: &str, color: u32); - fn write_timestamp( - &mut self, - context: &wgc::global::Global, - query_set_id: wgc::id::QuerySetId, - query_index: u32, - ) -> Result<(), wgc::command::ComputePassError>; - fn begin_pipeline_statistics_query( - &mut self, - context: &wgc::global::Global, - query_set_id: wgc::id::QuerySetId, - query_index: u32, - ) -> Result<(), wgc::command::ComputePassError>; - fn end_pipeline_statistics_query(&mut self, context: &wgc::global::Global); -} - -impl DynComputePass for wgc::command::ComputePass { - fn run(&mut self, context: &wgc::global::Global) -> Result<(), wgc::command::ComputePassError> { - context.command_encoder_run_compute_pass(self) - } - - fn set_bind_group( - &mut self, - context: &wgc::global::Global, - index: u32, - bind_group_id: wgc::id::BindGroupId, - offsets: &[wgt::DynamicOffset], - ) -> Result<(), wgc::command::ComputePassError> { - context.compute_pass_set_bind_group(self, index, bind_group_id, offsets) - } - - fn set_pipeline( - &mut self, - context: &wgc::global::Global, - pipeline_id: wgc::id::ComputePipelineId, - ) -> Result<(), wgc::command::ComputePassError> { - context.compute_pass_set_pipeline(self, pipeline_id) - } - - fn set_push_constant(&mut self, context: &wgc::global::Global, offset: u32, data: &[u8]) { - context.compute_pass_set_push_constant(self, offset, data) - } - - fn dispatch_workgroups( - &mut self, - context: &wgc::global::Global, - groups_x: u32, - groups_y: u32, - groups_z: u32, - ) { - context.compute_pass_dispatch_workgroups(self, groups_x, groups_y, groups_z) - } - - fn dispatch_workgroups_indirect( - &mut self, - context: &wgc::global::Global, - buffer_id: wgc::id::BufferId, - offset: wgt::BufferAddress, - ) -> Result<(), wgc::command::ComputePassError> { - context.compute_pass_dispatch_workgroups_indirect(self, buffer_id, offset) - } - - fn push_debug_group(&mut self, context: &wgc::global::Global, label: &str, color: u32) { - context.compute_pass_push_debug_group(self, label, color) - } - - fn pop_debug_group(&mut self, context: &wgc::global::Global) { - context.compute_pass_pop_debug_group(self) - } - - fn insert_debug_marker(&mut self, context: &wgc::global::Global, label: &str, color: u32) { - context.compute_pass_insert_debug_marker(self, label, color) - } - - fn write_timestamp( - &mut self, - context: &wgc::global::Global, - query_set_id: wgc::id::QuerySetId, - query_index: u32, - ) -> Result<(), wgc::command::ComputePassError> { - context.compute_pass_write_timestamp(self, query_set_id, query_index) - } - - fn begin_pipeline_statistics_query( - &mut self, - context: &wgc::global::Global, - query_set_id: wgc::id::QuerySetId, - query_index: u32, - ) -> Result<(), wgc::command::ComputePassError> { - context.compute_pass_begin_pipeline_statistics_query(self, query_set_id, query_index) - } - - fn end_pipeline_statistics_query(&mut self, context: &wgc::global::Global) { - context.compute_pass_end_pipeline_statistics_query(self) - } -} - #[derive(Debug)] pub struct CommandEncoder { error_sink: ErrorSink, @@ -1967,22 +1835,10 @@ impl crate::Context for ContextWgpuCore { end_of_pass_write_index: tw.end_of_pass_write_index, }); - fn create_dyn_compute_pass( - context: &ContextWgpuCore, - encoder: wgc::id::CommandEncoderId, - desc: &wgc::command::ComputePassDescriptor<'_>, - ) -> Box { - Box::new( - context - .0 - .command_encoder_create_compute_pass::(encoder, desc), - ) - } - ( Unused, Self::ComputePassData { - pass: gfx_select!(encoder => create_dyn_compute_pass(self, *encoder, &wgc::command::ComputePassDescriptor { + pass: gfx_select!(encoder => self.0.command_encoder_create_compute_pass_dyn(*encoder, &wgc::command::ComputePassDescriptor { label: desc.label.map(Borrowed), timestamp_writes: timestamp_writes.as_ref(), })),